From 03bc35b22f488d643b0746e6fcf4bd4c2a0dd3a4 Mon Sep 17 00:00:00 2001
From: pszenny <pszenny@e-science.pl>
Date: Thu, 7 Oct 2021 09:44:07 +0200
Subject: [PATCH] Truncate sentences when text_to_instance is called.

---
 combo/data/dataset.py | 23 +++++++++++++++++------
 combo/predict.py      | 14 +++++++++++---
 2 files changed, 28 insertions(+), 9 deletions(-)

diff --git a/combo/data/dataset.py b/combo/data/dataset.py
index bdc8b20..4b0352a 100644
--- a/combo/data/dataset.py
+++ b/combo/data/dataset.py
@@ -8,7 +8,7 @@ import torch
 from allennlp import data as allen_data
 from allennlp.common import checks, util
 from allennlp.data import fields as allen_fields, vocabulary
-from conllu import parser
+from conllu import parser, TokenList
 from dataclasses import dataclass
 from overrides import overrides
 
@@ -27,6 +27,7 @@ class UniversalDependenciesDatasetReader(allen_data.DatasetReader):
             features: List[str] = None,
             targets: List[str] = None,
             use_sem: bool = False,
+            max_input_embedder: int = None,
             **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -48,6 +49,7 @@ class UniversalDependenciesDatasetReader(allen_data.DatasetReader):
                 "Remove {} from either features or targets.".format(intersection)
             )
         self.use_sem = use_sem
+        self.max_input_embedder = max_input_embedder
 
         # *.conllu readers configuration
         fields = list(parser.DEFAULT_FIELDS)
@@ -88,13 +90,16 @@ class UniversalDependenciesDatasetReader(allen_data.DatasetReader):
 
     @overrides
     def text_to_instance(self, tree: conllu.TokenList) -> allen_data.Instance:
+        if self.max_input_embedder:
+            tree = TokenList(tokens=tree.tokens[:self.max_input_embedder],
+                             metadata=tree.metadata)
         fields_: Dict[str, allen_data.Field] = {}
         tree_tokens = [t for t in tree if isinstance(t["id"], int)]
         tokens = [_Token(t["token"],
-                  pos_=t.get("upostag"),
-                  tag_=t.get("xpostag"),
-                  lemma_=t.get("lemma"),
-                  feats_=t.get("feats"))
+                         pos_=t.get("upostag"),
+                         tag_=t.get("xpostag"),
+                         lemma_=t.get("lemma"),
+                         feats_=t.get("feats"))
                   for t in tree_tokens]
 
         # features
@@ -117,7 +122,10 @@ class UniversalDependenciesDatasetReader(allen_data.DatasetReader):
                                                                  text_field,
                                                                  label_namespace="feats_labels")
             elif target_name == "head":
-                target_values = [0 if v == "_" else int(v) for v in target_values]
+                target_values = [0 if v == "_" else int(v) for v in target_values]
+                if self.max_input_embedder:
+                    # Reattach out-of-range heads to root (0) to keep label/text alignment.
+                    target_values = [v if v <= self.max_input_embedder else 0 for v in target_values]
                 fields_[target_name] = allen_fields.SequenceLabelField(target_values, text_field,
                                                                        label_namespace=target_name + "_labels")
             elif target_name == "deps":
@@ -130,6 +138,9 @@ class UniversalDependenciesDatasetReader(allen_data.DatasetReader):
                     t_deps = t["deps"]
                     if t_deps and t_deps != "_":
                         for rel, head in t_deps:
+                            # Skip edges whose head lies beyond the truncation point.
+                            if self.max_input_embedder is not None and int(head) > self.max_input_embedder:
+                                continue
                             # EmoryNLP skips the first edge, if there are two edges between the same
                             # nodes. Thanks to that one is in a tree and another in a graph.
                             # This snippet follows that approach.
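Note: a minimal sketch of the truncation that text_to_instance now performs, assuming
the conllu version used by COMBO, where TokenList (a list subclass) exposes .tokens and
accepts tokens/metadata keyword arguments, as the hunk above relies on. The sentence and
the limit of 5 are made up for illustration:

    from conllu import TokenList

    # Made-up 7-token sentence; plain dicts stand in for parsed conllu tokens.
    words = ["The", "cat", "sat", "on", "the", "mat", "."]
    tree = TokenList(tokens=[{"id": i, "token": w} for i, w in enumerate(words, start=1)],
                     metadata={"text": " ".join(words)})

    max_input_embedder = 5  # illustrative limit; transformer embedders often cap at 512
    truncated = TokenList(tokens=tree.tokens[:max_input_embedder],
                          metadata=tree.metadata)

    print(len(truncated))      # 5 -- only the first five tokens survive
    print(truncated.metadata)  # sentence metadata is preserved
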
diff --git a/combo/predict.py b/combo/predict.py
index 01a0837..b235389 100644
--- a/combo/predict.py
+++ b/combo/predict.py
@@ -228,7 +228,8 @@ class COMBO(predictor.Predictor):
 
     @classmethod
     def from_pretrained(cls, path: str, tokenizer=tokenizers.SpacyTokenizer(), batch_size: int = 1024,
-                        cuda_device: int = -1):
+                        cuda_device: int = -1,
+                        max_input_embedder: int = None):
         util.import_module_and_submodules("combo.commands")
         util.import_module_and_submodules("combo.models")
         util.import_module_and_submodules("combo.training")
@@ -245,6 +246,13 @@ class COMBO(predictor.Predictor):
         archive = models.load_archive(model_path, cuda_device=cuda_device)
         model = archive.model
-        dataset_reader = allen_data.DatasetReader.from_params(
-            archive.config["dataset_reader"])
+        dataset_reader = allen_data.DatasetReader.from_params(archive.config["dataset_reader"],
+                                                              max_input_embedder=max_input_embedder)
+
+        logger.info("Using a pretrained transformer embedder may require truncating tokenized sentences.")
+        if max_input_embedder:
+            logger.info(f"Sentences are truncated to their first {max_input_embedder} tokens.")
+        else:
+            logger.info("Sentences are not truncated.")
+
         return cls(model, dataset_reader, tokenizer, batch_size)
 
-- 
GitLab
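Note: with both hunks applied, truncation can be switched on when loading a pretrained
model. A usage sketch, assuming COMBO's documented loading API; the model name
"polish-herbert-base" and the limit of 512 are illustrative, not taken from this patch:

    from combo.predict import COMBO

    # max_input_embedder=512 makes the dataset reader keep only the first 512
    # tokens of every sentence before embedding; None disables truncation.
    nlp = COMBO.from_pretrained("polish-herbert-base",
                                batch_size=1024,
                                max_input_embedder=512)

    sentence = nlp("A very long sentence that would otherwise overflow the embedder.")
    print(sentence.tokens)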