From d0dc576f2a4d002d58cacc4e4016f236bd053cae Mon Sep 17 00:00:00 2001
From: Mateusz Klimaszewski <mk.klimaszewski@gmail.com>
Date: Tue, 22 Dec 2020 18:52:29 +0100
Subject: [PATCH] Refactor predictor name, speed-up dataset reader and graph
 config.

---
 README.md                     |  4 ++--
 combo/data/dataset.py         |  4 ++--
 combo/main.py                 |  2 +-
 combo/models/model.py         |  4 +++-
 combo/predict.py              |  2 +-
 config.graph.template.jsonnet |  2 ++
 docs/models.md                |  4 ++--
 docs/prediction.md            | 14 ++++++++++++--
 tests/test_predict.py         |  2 +-
 9 files changed, 26 insertions(+), 12 deletions(-)

diff --git a/README.md b/README.md
index 19847b8..c339bda 100644
--- a/README.md
+++ b/README.md
@@ -18,9 +18,9 @@ python setup.py develop
 ```
 Run the following lines in your Python console to make predictions with a pre-trained model:
 ```python
-import combo.predict as predict
+from combo.predict import COMBO
 
-nlp = predict.SemanticMultitaskPredictor.from_pretrained("polish-herbert-base")
+nlp = COMBO.from_pretrained("polish-herbert-base")
 sentence = nlp("Moje zdanie.")
 print(sentence.tokens)
 ```
diff --git a/combo/data/dataset.py b/combo/data/dataset.py
index bb56ac3..48b68b1 100644
--- a/combo/data/dataset.py
+++ b/combo/data/dataset.py
@@ -119,8 +119,8 @@ class UniversalDependenciesDatasetReader(allen_data.DatasetReader):
                                                                                label_namespace=target_name + "_labels")
                     elif target_name == "deps":
                         # Graphs require adding ROOT (AdjacencyField uses sequence length from TextField).
-                        text_field_deps = copy.deepcopy(text_field)
-                        text_field_deps.tokens.insert(0, _Token("ROOT"))
+                        text_field_deps = allen_fields.TextField([_Token("ROOT")] + copy.deepcopy(tokens),
+                                                                 self._token_indexers)
                         enhanced_heads: List[Tuple[int, int]] = []
                         enhanced_deprels: List[str] = []
                         for idx, t in enumerate(tree_tokens):
diff --git a/combo/main.py b/combo/main.py
index 374af69..17c960a 100644
--- a/combo/main.py
+++ b/combo/main.py
@@ -136,7 +136,7 @@ def run(_):
             params = common.Params.from_file(FLAGS.config_path, ext_vars=_get_ext_vars())["dataset_reader"]
             params.pop("type")
             dataset_reader = dataset.UniversalDependenciesDatasetReader.from_params(params)
-            predictor = predict.SemanticMultitaskPredictor(
+            predictor = predict.COMBO(
                 model=model,
                 dataset_reader=dataset_reader
             )
diff --git a/combo/models/model.py b/combo/models/model.py
index 710f72c..ad0df0e 100644
--- a/combo/models/model.py
+++ b/combo/models/model.py
@@ -126,10 +126,12 @@ class SemanticMultitaskModel(allen_models.Model):
             "deprel": relations_pred,
             "enhanced_head": enhanced_head_pred,
             "enhanced_deprel": enhanced_relations_pred,
-            "enhanced_deprel_prob": enhanced_parser_output["rel_probability"],
             "sentence_embedding": torch.max(encoder_emb[:, 1:], dim=1)[0],
         }
 
+        if "rel_probability" in enhanced_parser_output:
+            output["enhanced_deprel_prob"] = enhanced_parser_output["rel_probability"]
+
         if self._has_labels([upostag, xpostag, lemma, feats, head, deprel, semrel]):
 
             # Feats mapping
diff --git a/combo/predict.py b/combo/predict.py
index c58db25..8d3e2f9 100644
--- a/combo/predict.py
+++ b/combo/predict.py
@@ -19,7 +19,7 @@ logger = logging.getLogger(__name__)
 
 @predictor.Predictor.register("semantic-multitask-predictor")
 @predictor.Predictor.register("semantic-multitask-predictor-spacy", constructor="with_spacy_tokenizer")
-class SemanticMultitaskPredictor(predictor.Predictor):
+class COMBO(predictor.Predictor):
 
     def __init__(self,
                  model: models.Model,
diff --git a/config.graph.template.jsonnet b/config.graph.template.jsonnet
index 6975aba..bc8c465 100644
--- a/config.graph.template.jsonnet
+++ b/config.graph.template.jsonnet
@@ -204,6 +204,8 @@ assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't
                     type: "transformers_word_embeddings",
                     model_name: pretrained_transformer_name,
                     projection_dim: projected_embedding_dim,
+                    tokenizer_kwargs: if std.startsWith(pretrained_transformer_name, "allegro/herbert")
+                                      then {use_fast: false} else {},
                 } else {
                     type: "embeddings_projected",
                     embedding_dim: embedding_dim,
diff --git a/docs/models.md b/docs/models.md
index 485f761..d4346ff 100644
--- a/docs/models.md
+++ b/docs/models.md
@@ -5,9 +5,9 @@ Pre-trained models are available [here](http://mozart.ipipan.waw.pl/~mklimaszews
 ## Automatic download
 Python `from_pretrained` method will download the pre-trained model if the provided name (without the extension .tar.gz) matches one of the names in [here](http://mozart.ipipan.waw.pl/~mklimaszewski/models/).
 ```python
-import combo.predict as predict
+from combo.predict import COMBO
 
-nlp = predict.SemanticMultitaskPredictor.from_pretrained("polish-herbert-base")
+nlp = COMBO.from_pretrained("polish-herbert-base")
 ```
 Otherwise it looks for a model in local env.
 
diff --git a/docs/prediction.md b/docs/prediction.md
index 89cc74c..6de5d0e 100644
--- a/docs/prediction.md
+++ b/docs/prediction.md
@@ -32,9 +32,19 @@ Use either `--predictor_name semantic-multitask-predictor` or `--predictor_name
 
 ## Python
 ```python
-import combo.predict as predict
+from combo.predict import COMBO
 
 model_path = "your_model.tar.gz"
-nlp = predict.SemanticMultitaskPredictor.from_pretrained(model_path)
+nlp = COMBO.from_pretrained(model_path)
 sentence = nlp("Sentence to parse.")
 ```
+
+Using your own tokenization:
+```python
+from combo.predict import COMBO
+
+model_path = "your_model.tar.gz"
+nlp = COMBO.from_pretrained(model_path)
+tokenized_sentence = ["Sentence", "to", "parse", "."]
+sentence = nlp([tokenized_sentence])
+```
diff --git a/tests/test_predict.py b/tests/test_predict.py
index 2a56bd9..332ced3 100644
--- a/tests/test_predict.py
+++ b/tests/test_predict.py
@@ -22,7 +22,7 @@ class PredictionTest(unittest.TestCase):
             data.Token(id=2, token=".")
         ])]
         api_wrapped_tokenized_sentence = [data.conllu2sentence(data.tokens2conllu(["Test", "."]), [])]
-        nlp = predict.SemanticMultitaskPredictor.from_pretrained(os.path.join(self.FIXTURES_ROOT, "model.tar.gz"))
+        nlp = predict.COMBO.from_pretrained(os.path.join(self.FIXTURES_ROOT, "model.tar.gz"))
 
         # when
         results = [
-- 
GitLab