From 847f3a903a1709f573fc6604778b00b8d09373e8 Mon Sep 17 00:00:00 2001
From: Mateusz Klimaszewski <mk.klimaszewski@gmail.com>
Date: Sun, 3 Jan 2021 18:47:16 +0100
Subject: [PATCH] Refactor prediction loop.

---
 combo/predict.py | 100 ++++++++++++++++++++++++-----------------------
 1 file changed, 51 insertions(+), 49 deletions(-)

diff --git a/combo/predict.py b/combo/predict.py
index a5c99fd..bd9f5d4 100644
--- a/combo/predict.py
+++ b/combo/predict.py
@@ -24,7 +24,7 @@ class COMBO(predictor.Predictor):
                  model: models.Model,
                  dataset_reader: allen_data.DatasetReader,
                  tokenizer: allen_data.Tokenizer = tokenizers.WhitespaceTokenizer(),
-                 batch_size: int = 32,
+                 batch_size: int = 1024,
                  line_to_conllu: bool = True) -> None:
         super().__init__(model, dataset_reader)
         self.batch_size = batch_size
@@ -140,54 +140,56 @@ class COMBO(predictor.Predictor):
         tree = instance.fields["metadata"]["input"]
         field_names = instance.fields["metadata"]["field_names"]
         tree_tokens = [t for t in tree if isinstance(t["id"], int)]
-        for idx, token in enumerate(tree_tokens):
-            for field_name in field_names:
-                if field_name in predictions:
-                    if field_name in ["xpostag", "upostag", "semrel", "deprel"]:
-                        value = self.vocab.get_token_from_index(predictions[field_name][idx], field_name + "_labels")
-                        token[field_name] = value
-                    elif field_name in ["head"]:
-                        token[field_name] = int(predictions[field_name][idx])
-                    elif field_name == "deps":
-                        # Handled after every other decoding
-                        continue
-
-                    elif field_name in ["feats"]:
-                        slices = self._model.morphological_feat.slices
-                        features = []
-                        prediction = predictions[field_name][idx]
-                        for (cat, cat_indices), pred_idx in zip(slices.items(), prediction):
-                            if cat not in ["__PAD__", "_"]:
-                                value = self.vocab.get_token_from_index(cat_indices[pred_idx],
-                                                                        field_name + "_labels")
-                                # Exclude auxiliary values
-                                if "=None" not in value:
-                                    features.append(value)
-                        if len(features) == 0:
-                            field_value = "_"
-                        else:
-                            lowercase_features = [f.lower() for f in features]
-                            arg_indices = sorted(range(len(lowercase_features)), key=lowercase_features.__getitem__)
-                            field_value = "|".join(np.array(features)[arg_indices].tolist())
-
-                        token[field_name] = field_value
-                    elif field_name == "lemma":
-                        prediction = predictions[field_name][idx]
-                        word_chars = []
-                        for char_idx in prediction[1:-1]:
-                            pred_char = self.vocab.get_token_from_index(char_idx, "lemma_characters")
-
-                            if pred_char == "__END__":
-                                break
-                            elif pred_char == "__PAD__":
-                                continue
-                            elif "_" in pred_char:
-                                pred_char = "?"
-
-                            word_chars.append(pred_char)
-                        token[field_name] = "".join(word_chars)
+        for field_name in field_names:
+            if field_name not in predictions:
+                continue
+            field_predictions = predictions[field_name]
+            for idx, token in enumerate(tree_tokens):
+                if field_name in {"xpostag", "upostag", "semrel", "deprel"}:
+                    value = self.vocab.get_token_from_index(field_predictions[idx], field_name + "_labels")
+                    token[field_name] = value
+                elif field_name == "head":
+                    token[field_name] = int(field_predictions[idx])
+                elif field_name == "deps":
+                    # Handled after every other decoding
+                    continue
+
+                elif field_name == "feats":
+                    slices = self._model.morphological_feat.slices
+                    features = []
+                    prediction = field_predictions[idx]
+                    for (cat, cat_indices), pred_idx in zip(slices.items(), prediction):
+                        if cat not in ["__PAD__", "_"]:
+                            value = self.vocab.get_token_from_index(cat_indices[pred_idx],
+                                                                    field_name + "_labels")
+                            # Exclude auxiliary values
+                            if "=None" not in value:
+                                features.append(value)
+                    if len(features) == 0:
+                        field_value = "_"
                     else:
-                        raise NotImplementedError(f"Unknown field name {field_name}!")
+                        lowercase_features = [f.lower() for f in features]
+                        arg_indices = sorted(range(len(lowercase_features)), key=lowercase_features.__getitem__)
+                        field_value = "|".join(np.array(features)[arg_indices].tolist())
+
+                    token[field_name] = field_value
+                elif field_name == "lemma":
+                    prediction = field_predictions[idx]
+                    word_chars = []
+                    for char_idx in prediction[1:-1]:
+                        pred_char = self.vocab.get_token_from_index(char_idx, "lemma_characters")
+
+                        if pred_char == "__END__":
+                            break
+                        elif pred_char == "__PAD__":
+                            continue
+                        elif "_" in pred_char:
+                            pred_char = "?"
+
+                        word_chars.append(pred_char)
+                    token[field_name] = "".join(word_chars)
+                else:
+                    raise NotImplementedError(f"Unknown field name {field_name}!")
 
         if "enhanced_head" in predictions and predictions["enhanced_head"]:
             # TODO off-by-one hotfix, refactor
@@ -212,7 +214,7 @@ class COMBO(predictor.Predictor):
 
     @classmethod
     def from_pretrained(cls, path: str, tokenizer=tokenizers.SpacyTokenizer(),
-                        batch_size: int = 32,
+                        batch_size: int = 1024,
                         cuda_device: int = -1):
         util.import_module_and_submodules("combo.commands")
         util.import_module_and_submodules("combo.models")
-- 
GitLab