Skip to content
Snippets Groups Projects
Commit 03bc35b2 authored by Łukasz Pszenny's avatar Łukasz Pszenny
Browse files

Truncate input sentences when text_to_instance is called.

parent 35c211f7
1 merge request!39Develop
......@@ -8,7 +8,7 @@ import torch
from allennlp import data as allen_data
from allennlp.common import checks, util
from allennlp.data import fields as allen_fields, vocabulary
from conllu import parser
from conllu import parser, TokenList
from dataclasses import dataclass
from overrides import overrides
......@@ -27,6 +27,7 @@ class UniversalDependenciesDatasetReader(allen_data.DatasetReader):
features: List[str] = None,
targets: List[str] = None,
use_sem: bool = False,
max_input_embedder: int = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
......@@ -48,6 +49,7 @@ class UniversalDependenciesDatasetReader(allen_data.DatasetReader):
"Remove {} from either features or targets.".format(intersection)
)
self.use_sem = use_sem
self.max_input_embedder = max_input_embedder
# *.conllu readers configuration
fields = list(parser.DEFAULT_FIELDS)
......@@ -88,13 +90,16 @@ class UniversalDependenciesDatasetReader(allen_data.DatasetReader):
@overrides
def text_to_instance(self, tree: conllu.TokenList) -> allen_data.Instance:
if self.max_input_embedder:
tree = TokenList(tokens = tree.tokens[: self.max_input_embedder],
metadata = tree.metadata)
fields_: Dict[str, allen_data.Field] = {}
tree_tokens = [t for t in tree if isinstance(t["id"], int)]
tokens = [_Token(t["token"],
pos_=t.get("upostag"),
tag_=t.get("xpostag"),
lemma_=t.get("lemma"),
feats_=t.get("feats"))
pos_=t.get("upostag"),
tag_=t.get("xpostag"),
lemma_=t.get("lemma"),
feats_=t.get("feats"))
for t in tree_tokens]
# features
......@@ -117,7 +122,11 @@ class UniversalDependenciesDatasetReader(allen_data.DatasetReader):
text_field,
label_namespace="feats_labels")
elif target_name == "head":
target_values = [0 if v == "_" else int(v) for v in target_values]
if self.max_input_embedder:
target_values = [0 if v == "_" else int(v) for v in target_values]
target_values = [v for v in target_values if v < self.max_input_embedder]
else:
target_values = [0 if v == "_" else int(v) for v in target_values]
fields_[target_name] = allen_fields.SequenceLabelField(target_values, text_field,
label_namespace=target_name + "_labels")
elif target_name == "deps":
......@@ -130,6 +139,8 @@ class UniversalDependenciesDatasetReader(allen_data.DatasetReader):
t_deps = t["deps"]
if t_deps and t_deps != "_":
for rel, head in t_deps:
if int(head) >= self.max_input_embedder:
continue
# EmoryNLP skips the first edge, if there are two edges between the same
# nodes. Thanks to that one is in a tree and another in a graph.
# This snippet follows that approach.
......
......@@ -228,7 +228,8 @@ class COMBO(predictor.Predictor):
@classmethod
def from_pretrained(cls, path: str, tokenizer=tokenizers.SpacyTokenizer(),
batch_size: int = 1024,
cuda_device: int = -1):
cuda_device: int = -1,
max_input_embedder: int = None):
util.import_module_and_submodules("combo.commands")
util.import_module_and_submodules("combo.models")
util.import_module_and_submodules("combo.training")
......@@ -245,6 +246,13 @@ class COMBO(predictor.Predictor):
archive = models.load_archive(model_path, cuda_device=cuda_device)
model = archive.model
dataset_reader = allen_data.DatasetReader.from_params(
archive.config["dataset_reader"])
dataset_reader = allen_data.DatasetReader.from_params(archive.config["dataset_reader"],
max_input_embedder = max_input_embedder)
logger.info("Using pretrained transformer embedder may require truncating tokenized sentences.")
if max_input_embedder:
logger.info(f"Currently they are truncated to {max_input_embedder} first tokens")
else:
logger.info("Currently they are not truncated")
return cls(model, dataset_reader, tokenizer, batch_size)
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment