From dfec6d56ca1f642c45150497d0fdf986b2c62c9d Mon Sep 17 00:00:00 2001
From: Mateusz Klimaszewski <mk.klimaszewski@gmail.com>
Date: Mon, 14 Sep 2020 14:28:37 +0200
Subject: [PATCH] Fix herberta training.

---
 README.md                 |  6 +-----
 combo/data/api.py         | 11 +++++++++--
 combo/data/dataset.py     | 19 +++++++++++++++++--
 combo/main.py             | 12 +++++++-----
 combo/predict.py          |  2 +-
 combo/training/trainer.py |  1 +
 setup.py                  |  2 +-
 7 files changed, 37 insertions(+), 16 deletions(-)

diff --git a/README.md b/README.md
index 680673c..c5d0fd0 100644
--- a/README.md
+++ b/README.md
@@ -1,9 +1,5 @@
 ## Installation
 
-### HERBERTA notes:
-
-Install herberta transformers package **before** running command below
-
 Clone this repository and run:
 ```bash
 python setup.py develop
@@ -86,7 +82,7 @@ Input: one sentence per line.
 Output: List of token jsons.
 
 ```bash
-combo --mode predict --model_path your_model_tar_gz --input_file your_text_file --output_file your_output_file --silent
+combo --mode predict --model_path your_model_tar_gz --input_file your_text_file --output_file your_output_file --silent --noconllu_format
 ```
 
 #### Advanced
diff --git a/combo/data/api.py b/combo/data/api.py
index b0763b6..10a3a72 100644
--- a/combo/data/api.py
+++ b/combo/data/api.py
@@ -20,6 +20,7 @@ class Token:
     deprel: Optional[str] = None
     deps: Optional[str] = None
     misc: Optional[str] = None
+    semrel: Optional[str] = None
 
 
 @dataclass_json
@@ -37,8 +38,14 @@ class _TokenList(conllu.TokenList):
         return 'TokenList<' + ', '.join(token['token'] for token in self) + '>'
 
 
-def sentence2conllu(sentence: Sentence) -> conllu.TokenList:
-    tokens = [collections.OrderedDict(t.to_dict()) for t in sentence.tokens]
+def sentence2conllu(sentence: Sentence, keep_semrel: bool = True) -> conllu.TokenList:
+    tokens = []
+    for token in sentence.tokens:
+        token_dict = collections.OrderedDict(token.to_dict())
+        # Remove semrel to have default conllu format.
+        if not keep_semrel:
+            del token_dict["semrel"]
+        tokens.append(token_dict)
     # Range tokens must be tuple not list, this is conllu library requirement
     for t in tokens:
         if type(t["id"]) == list:
diff --git a/combo/data/dataset.py b/combo/data/dataset.py
index b5f5c30..459a755 100644
--- a/combo/data/dataset.py
+++ b/combo/data/dataset.py
@@ -41,7 +41,7 @@ class UniversalDependenciesDatasetReader(allen_data.DatasetReader):
                 "Features and targets cannot share elements! "
                 "Remove {} from either features or targets.".format(intersection)
             )
-        self._use_sem = use_sem
+        self.use_sem = use_sem
 
         # *.conllu readers configuration
         fields = list(parser.DEFAULT_FIELDS)
@@ -49,7 +49,7 @@ class UniversalDependenciesDatasetReader(allen_data.DatasetReader):
         field_parsers = parser.DEFAULT_FIELD_PARSERS
         # Do not make it nullable
         field_parsers.pop("xpostag", None)
-        if self._use_sem:
+        if self.use_sem:
             fields = list(fields)
             fields.append("semrel")
             field_parsers["semrel"] = lambda line, i: line[i]
@@ -113,8 +113,23 @@ class UniversalDependenciesDatasetReader(allen_data.DatasetReader):
                 fields_[target_name] = allen_fields.SequenceLabelField(target_values, text_field,
                                                                        label_namespace=target_name + "_labels")
 
+        # Restore feats fields to string representation
+        # parser.serialize_field doesn't handle key without value
+        for token in tree.tokens:
+            if "feats" in token:
+                feats = token["feats"]
+                if feats:
+                    feats_values = []
+                    for k, v in feats.items():
+                        feats_values.append('='.join((k, v)) if v else k)
+                    field = "|".join(feats_values)
+                else:
+                    field = "_"
+                token["feats"] = field
+
         # metadata
         fields_["metadata"] = allen_fields.MetadataField({"input": tree, "field_names": self.fields})
+
         return allen_data.Instance(fields_)
 
     @staticmethod
diff --git a/combo/main.py b/combo/main.py
index 4dc0056..c7aac87 100644
--- a/combo/main.py
+++ b/combo/main.py
@@ -13,7 +13,7 @@ from allennlp.common import checks as allen_checks, util
 from allennlp.models import archival
 
 from combo import predict
-from combo.data import dataset
+from combo.data import api, dataset
 from combo.utils import checks
 
 logger = logging.getLogger(__name__)
@@ -76,6 +76,8 @@ flags.DEFINE_string(name="model_path", default=None,
                     help="Pretrained model path.")
 flags.DEFINE_string(name="input_file", default=None,
                     help="File to predict path")
+flags.DEFINE_boolean(name="conllu_format", default=True,
+                     help="Prediction based on conllu format (instead of raw text).")
 flags.DEFINE_integer(name="batch_size", default=1,
                      help="Prediction batch size.")
 flags.DEFINE_boolean(name="silent", default=True,
@@ -136,13 +138,13 @@ def run(_):
             model=model,
             dataset_reader=dataset_reader
         )
-        test_path = FLAGS.test_path
-        test_trees = dataset_reader.read(test_path)
+        test_trees = dataset_reader.read(FLAGS.test_path)
         with open(FLAGS.output_file, "w") as file:
             for tree in test_trees:
-                file.writelines(predictor.predict_instance_as_tree(tree).serialize())
+                file.writelines(api.sentence2conllu(predictor.predict_instance(tree),
+                                                    keep_semrel=dataset_reader.use_sem).serialize())
     else:
-        use_dataset_reader = ".conllu" in FLAGS.input_file.lower()
+        use_dataset_reader = FLAGS.conllu_format
         predictor = _get_predictor()
         if use_dataset_reader:
             predictor.line_to_conllu = True
diff --git a/combo/predict.py b/combo/predict.py
index 0ee80a9..ebbb372 100644
--- a/combo/predict.py
+++ b/combo/predict.py
@@ -128,7 +128,7 @@ class SemanticMultitaskPredictor(predictor.Predictor):
         # Check whether serialized (str) tree or token's list
         # Serialized tree has already separators between lines
         if self.line_to_conllu:
-            return sentence2conllu(outputs).serialize()
+            return sentence2conllu(outputs, keep_semrel=self._dataset_reader.use_sem).serialize()
         else:
             return outputs.to_json()
 
diff --git a/combo/training/trainer.py b/combo/training/trainer.py
index 234bdd7..772b9b0 100644
--- a/combo/training/trainer.py
+++ b/combo/training/trainer.py
@@ -127,6 +127,7 @@ class GradientDescentTrainer(training.GradientDescentTrainer):
                 val_reg_loss,
                 num_batches=num_batches,
                 batch_loss=None,
+                batch_reg_loss=None,
                 reset=True,
                 world_size=self._world_size,
                 cuda_device=self.cuda_device,
diff --git a/setup.py b/setup.py
index 228e025..dd21555 100644
--- a/setup.py
+++ b/setup.py
@@ -14,7 +14,7 @@ REQUIREMENTS = [
     'torch==1.6.0',
     'tqdm==4.43.0',
     'transformers>=3.0.0,<3.1.0',
-    'urllib3==1.24.2',
+    'urllib3>=1.25.11',
 ]
 
 setup(
--
GitLab