From dfec6d56ca1f642c45150497d0fdf986b2c62c9d Mon Sep 17 00:00:00 2001
From: Mateusz Klimaszewski <mk.klimaszewski@gmail.com>
Date: Mon, 14 Sep 2020 14:28:37 +0200
Subject: [PATCH] Fix herberta training.

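Replace extension sniffing of the prediction input file with an
explicit --conllu_format flag (an absl boolean flag, so
--noconllu_format selects raw-text input) and document it in the
README. Predicting from raw text therefore needs the flag spelled
out, as in the updated README example:

    combo --mode predict --model_path your_model_tar_gz \
        --input_file your_text_file --output_file your_output_file \
        --silent --noconllu_format

Related changes bundled in this patch:

* Add a semrel field to the api.Token dataclass and thread the
  dataset reader's use_sem setting (now public) through prediction,
  so sentence2conllu only keeps the semrel column when the model was
  trained with semantic relations.
* Restore feats dicts to their CoNLL-U string form before building
  the metadata field, since parser.serialize_field doesn't handle
  keys without values.
* Pass batch_reg_loss=None in the trainer's validation metrics call.
* Drop the herberta transformers pre-install note from the README and
  relax the urllib3 pin to >=1.25.11.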
---
 README.md                 |  6 +-----
 combo/data/api.py         | 11 +++++++++--
 combo/data/dataset.py     | 19 +++++++++++++++++--
 combo/main.py             | 12 +++++++-----
 combo/predict.py          |  2 +-
 combo/training/trainer.py |  1 +
 setup.py                  |  2 +-
 7 files changed, 37 insertions(+), 16 deletions(-)

diff --git a/README.md b/README.md
index 680673c..c5d0fd0 100644
--- a/README.md
+++ b/README.md
@@ -1,9 +1,5 @@
 ## Installation
 
-### HERBERTA notes:
-
-Install herberta transformers package **before** running command below
-
 Clone this repository and run:
 ```bash
 python setup.py develop
@@ -86,7 +82,7 @@ Input: one sentence per line.
 Output: List of token jsons.
 
 ```bash
-combo --mode predict --model_path your_model_tar_gz --input_file your_text_file --output_file your_output_file --silent
+combo --mode predict --model_path your_model_tar_gz --input_file your_text_file --output_file your_output_file --silent --noconllu_format
 ```
 #### Advanced
 
diff --git a/combo/data/api.py b/combo/data/api.py
index b0763b6..10a3a72 100644
--- a/combo/data/api.py
+++ b/combo/data/api.py
@@ -20,6 +20,7 @@ class Token:
     deprel: Optional[str] = None
     deps: Optional[str] = None
     misc: Optional[str] = None
+    semrel: Optional[str] = None
 
 
 @dataclass_json
@@ -37,8 +38,14 @@ class _TokenList(conllu.TokenList):
         return 'TokenList<' + ', '.join(token['token'] for token in self) + '>'
 
 
-def sentence2conllu(sentence: Sentence) -> conllu.TokenList:
-    tokens = [collections.OrderedDict(t.to_dict()) for t in sentence.tokens]
+def sentence2conllu(sentence: Sentence, keep_semrel: bool = True) -> conllu.TokenList:
+    tokens = []
+    for token in sentence.tokens:
+        token_dict = collections.OrderedDict(token.to_dict())
+        # Drop "semrel" so the output matches the default CoNLL-U format.
+        if not keep_semrel:
+            del token_dict["semrel"]
+        tokens.append(token_dict)
     # Range tokens must be tuple not list, this is conllu library requirement
     for t in tokens:
         if type(t["id"]) == list:
diff --git a/combo/data/dataset.py b/combo/data/dataset.py
index b5f5c30..459a755 100644
--- a/combo/data/dataset.py
+++ b/combo/data/dataset.py
@@ -41,7 +41,7 @@ class UniversalDependenciesDatasetReader(allen_data.DatasetReader):
                 "Features and targets cannot share elements! "
                 "Remove {} from either features or targets.".format(intersection)
             )
-        self._use_sem = use_sem
+        self.use_sem = use_sem
 
         # *.conllu readers configuration
         fields = list(parser.DEFAULT_FIELDS)
@@ -49,7 +49,7 @@ class UniversalDependenciesDatasetReader(allen_data.DatasetReader):
         field_parsers = parser.DEFAULT_FIELD_PARSERS
         # Do not make it nullable
         field_parsers.pop("xpostag", None)
-        if self._use_sem:
+        if self.use_sem:
             fields = list(fields)
             fields.append("semrel")
             field_parsers["semrel"] = lambda line, i: line[i]
@@ -113,8 +113,23 @@ class UniversalDependenciesDatasetReader(allen_data.DatasetReader):
                         fields_[target_name] = allen_fields.SequenceLabelField(target_values, text_field,
                                                                                label_namespace=target_name + "_labels")
 
+        # Restore the feats field to its string representation;
+        # parser.serialize_field doesn't handle keys without values.
+        for token in tree.tokens:
+            if "feats" in token:
+                feats = token["feats"]
+                if feats:
+                    feats_values = []
+                    for k, v in feats.items():
+                        feats_values.append('='.join((k, v)) if v else k)
+                    field = "|".join(feats_values)
+                else:
+                    field = "_"
+                token["feats"] = field
+
         # metadata
         fields_["metadata"] = allen_fields.MetadataField({"input": tree, "field_names": self.fields})
+
         return allen_data.Instance(fields_)
 
     @staticmethod
diff --git a/combo/main.py b/combo/main.py
index 4dc0056..c7aac87 100644
--- a/combo/main.py
+++ b/combo/main.py
@@ -13,7 +13,7 @@ from allennlp.common import checks as allen_checks, util
 from allennlp.models import archival
 
 from combo import predict
-from combo.data import dataset
+from combo.data import api, dataset
 from combo.utils import checks
 
 logger = logging.getLogger(__name__)
@@ -76,6 +76,8 @@ flags.DEFINE_string(name="model_path", default=None,
                     help="Pretrained model path.")
 flags.DEFINE_string(name="input_file", default=None,
                     help="File to predict path")
+flags.DEFINE_boolean(name="conllu_format", default=True,
+                     help="Prediction based on conllu format (instead of raw text).")
 flags.DEFINE_integer(name="batch_size", default=1,
                      help="Prediction batch size.")
 flags.DEFINE_boolean(name="silent", default=True,
@@ -136,13 +138,13 @@ def run(_):
                 model=model,
                 dataset_reader=dataset_reader
             )
-            test_path = FLAGS.test_path
-            test_trees = dataset_reader.read(test_path)
+            test_trees = dataset_reader.read(FLAGS.test_path)
             with open(FLAGS.output_file, "w") as file:
                 for tree in test_trees:
-                    file.writelines(predictor.predict_instance_as_tree(tree).serialize())
+                    file.writelines(api.sentence2conllu(predictor.predict_instance(tree),
+                                                        keep_semrel=dataset_reader.use_sem).serialize())
     else:
-        use_dataset_reader = ".conllu" in FLAGS.input_file.lower()
+        use_dataset_reader = FLAGS.conllu_format
         predictor = _get_predictor()
         if use_dataset_reader:
             predictor.line_to_conllu = True
diff --git a/combo/predict.py b/combo/predict.py
index 0ee80a9..ebbb372 100644
--- a/combo/predict.py
+++ b/combo/predict.py
@@ -128,7 +128,7 @@ class SemanticMultitaskPredictor(predictor.Predictor):
         # Check whether serialized (str) tree or token's list
         # Serialized tree has already separators between lines
         if self.line_to_conllu:
-            return sentence2conllu(outputs).serialize()
+            return sentence2conllu(outputs, keep_semrel=self._dataset_reader.use_sem).serialize()
         else:
             return outputs.to_json()
 
diff --git a/combo/training/trainer.py b/combo/training/trainer.py
index 234bdd7..772b9b0 100644
--- a/combo/training/trainer.py
+++ b/combo/training/trainer.py
@@ -127,6 +127,7 @@ class GradientDescentTrainer(training.GradientDescentTrainer):
                             val_reg_loss,
                             num_batches=num_batches,
                             batch_loss=None,
+                            batch_reg_loss=None,
                             reset=True,
                             world_size=self._world_size,
                             cuda_device=self.cuda_device,
diff --git a/setup.py b/setup.py
index 228e025..dd21555 100644
--- a/setup.py
+++ b/setup.py
@@ -14,7 +14,7 @@ REQUIREMENTS = [
     'torch==1.6.0',
     'tqdm==4.43.0',
     'transformers>=3.0.0,<3.1.0',
-    'urllib3==1.24.2',
+    'urllib3>=1.25.11',
 ]
 
 setup(
-- 
GitLab