From 03bc35b22f488d643b0746e6fcf4bd4c2a0dd3a4 Mon Sep 17 00:00:00 2001
From: pszenny <pszenny@e-science.pl>
Date: Thu, 7 Oct 2021 09:44:07 +0200
Subject: [PATCH] Truncation while  text_to_instance is called.

---
 combo/data/dataset.py | 23 +++++++++++++++++------
 combo/predict.py      | 14 +++++++++++---
 2 files changed, 28 insertions(+), 9 deletions(-)

diff --git a/combo/data/dataset.py b/combo/data/dataset.py
index bdc8b20..4b0352a 100644
--- a/combo/data/dataset.py
+++ b/combo/data/dataset.py
@@ -8,7 +8,7 @@ import torch
 from allennlp import data as allen_data
 from allennlp.common import checks, util
 from allennlp.data import fields as allen_fields, vocabulary
-from conllu import parser
+from conllu import parser, TokenList
 from dataclasses import dataclass
 from overrides import overrides
 
@@ -27,6 +27,7 @@ class UniversalDependenciesDatasetReader(allen_data.DatasetReader):
             features: List[str] = None,
             targets: List[str] = None,
             use_sem: bool = False,
+            max_input_embedder: int = None,
             **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -48,6 +49,7 @@ class UniversalDependenciesDatasetReader(allen_data.DatasetReader):
                 "Remove {} from either features or targets.".format(intersection)
             )
         self.use_sem = use_sem
+        self.max_input_embedder = max_input_embedder
 
         # *.conllu readers configuration
         fields = list(parser.DEFAULT_FIELDS)
@@ -88,13 +90,16 @@ class UniversalDependenciesDatasetReader(allen_data.DatasetReader):
 
     @overrides
     def text_to_instance(self, tree: conllu.TokenList) -> allen_data.Instance:
+        if self.max_input_embedder:
+            tree = TokenList(tokens = tree.tokens[: self.max_input_embedder],
+                             metadata = tree.metadata)
         fields_: Dict[str, allen_data.Field] = {}
         tree_tokens = [t for t in tree if isinstance(t["id"], int)]
         tokens = [_Token(t["token"],
-                         pos_=t.get("upostag"),
-                         tag_=t.get("xpostag"),
-                         lemma_=t.get("lemma"),
-                         feats_=t.get("feats"))
+                  pos_=t.get("upostag"),
+                  tag_=t.get("xpostag"),
+                  lemma_=t.get("lemma"),
+                  feats_=t.get("feats"))
                   for t in tree_tokens]
 
         # features
@@ -117,7 +122,11 @@ class UniversalDependenciesDatasetReader(allen_data.DatasetReader):
                                                                               text_field,
                                                                               label_namespace="feats_labels")
                     elif target_name == "head":
-                        target_values = [0 if v == "_" else int(v) for v in target_values]
+                        if self.max_input_embedder:
+                            target_values = [0 if v == "_" else int(v) for v in target_values]
+                            target_values = [v for v in target_values if v < self.max_input_embedder]
+                        else:
+                            target_values = [0 if v == "_" else int(v) for v in target_values]
                         fields_[target_name] = allen_fields.SequenceLabelField(target_values, text_field,
                                                                                label_namespace=target_name + "_labels")
                     elif target_name == "deps":
@@ -130,6 +139,8 @@ class UniversalDependenciesDatasetReader(allen_data.DatasetReader):
                             t_deps = t["deps"]
                             if t_deps and t_deps != "_":
                                 for rel, head in t_deps:
+                                    if int(head) >= self.max_input_embedder:
+                                        continue
                                     # EmoryNLP skips the first edge, if there are two edges between the same
                                     # nodes. Thanks to that one is in a tree and another in a graph.
                                     # This snippet follows that approach.
diff --git a/combo/predict.py b/combo/predict.py
index 01a0837..b235389 100644
--- a/combo/predict.py
+++ b/combo/predict.py
@@ -228,7 +228,8 @@ class COMBO(predictor.Predictor):
     @classmethod
     def from_pretrained(cls, path: str, tokenizer=tokenizers.SpacyTokenizer(),
                         batch_size: int = 1024,
-                        cuda_device: int = -1):
+                        cuda_device: int = -1,
+                        max_input_embedder: int = None):
         util.import_module_and_submodules("combo.commands")
         util.import_module_and_submodules("combo.models")
         util.import_module_and_submodules("combo.training")
@@ -245,6 +246,13 @@ class COMBO(predictor.Predictor):
 
         archive = models.load_archive(model_path, cuda_device=cuda_device)
         model = archive.model
-        dataset_reader = allen_data.DatasetReader.from_params(
-            archive.config["dataset_reader"])
+        dataset_reader = allen_data.DatasetReader.from_params(archive.config["dataset_reader"],
+                                                              max_input_embedder = max_input_embedder)
+
+        logger.info("Using pretrained transformer embedder may require truncating tokenized sentences.")
+        if max_input_embedder:
+            logger.info(f"Currently they are truncated to {max_input_embedder} first tokens")
+        else:
+            logger.info("Currently they are not truncated")
+
         return cls(model, dataset_reader, tokenizer, batch_size)
-- 
GitLab