Commit 03bc35b2 authored by Łukasz Pszenny

Truncate input sentences when text_to_instance is called.

parent 35c211f7
Merge request !39 Develop
@@ -8,7 +8,7 @@ import torch
 from allennlp import data as allen_data
 from allennlp.common import checks, util
 from allennlp.data import fields as allen_fields, vocabulary
-from conllu import parser
+from conllu import parser, TokenList
 from dataclasses import dataclass
 from overrides import overrides
@@ -27,6 +27,7 @@ class UniversalDependenciesDatasetReader(allen_data.DatasetReader):
             features: List[str] = None,
             targets: List[str] = None,
             use_sem: bool = False,
+            max_input_embedder: int = None,
             **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -48,6 +49,7 @@ class UniversalDependenciesDatasetReader(allen_data.DatasetReader):
                 "Remove {} from either features or targets.".format(intersection)
             )
         self.use_sem = use_sem
+        self.max_input_embedder = max_input_embedder
         # *.conllu readers configuration
         fields = list(parser.DEFAULT_FIELDS)
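
Taken together, the two hunks above just thread the new cap through the constructor onto the instance. A minimal sketch of building the reader directly with the new argument (the target list is made up for illustration; only max_input_embedder comes from this commit):

    reader = UniversalDependenciesDatasetReader(
        targets=["head", "deprel"],   # illustrative target set
        max_input_embedder=512,       # keep at most the first 512 tokens per sentence
    )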
@@ -88,13 +90,16 @@ class UniversalDependenciesDatasetReader(allen_data.DatasetReader):
     @overrides
     def text_to_instance(self, tree: conllu.TokenList) -> allen_data.Instance:
+        if self.max_input_embedder:
+            tree = TokenList(tokens=tree.tokens[:self.max_input_embedder],
+                             metadata=tree.metadata)
         fields_: Dict[str, allen_data.Field] = {}
         tree_tokens = [t for t in tree if isinstance(t["id"], int)]
         tokens = [_Token(t["token"],
                          pos_=t.get("upostag"),
                          tag_=t.get("xpostag"),
                          lemma_=t.get("lemma"),
                          feats_=t.get("feats"))
                   for t in tree_tokens]
         # features
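
Rebuilding the TokenList keeps the sentence metadata attached to the truncated token slice. A self-contained sketch of the same slicing against the conllu library, using a made-up three-token sentence:

    from conllu import TokenList, parse

    # Parse one CoNLL-U sentence: ID, FORM, LEMMA, UPOS, XPOS, FEATS, HEAD, DEPREL, DEPS, MISC.
    tree = parse(
        "1\tCats\tcat\tNOUN\t_\t_\t2\tnsubj\t_\t_\n"
        "2\tsleep\tsleep\tVERB\t_\t_\t0\troot\t_\t_\n"
        "3\tsoundly\tsoundly\tADV\t_\t_\t2\tadvmod\t_\t_\n"
    )[0]

    max_input_embedder = 2
    truncated = TokenList(tokens=tree.tokens[:max_input_embedder],
                          metadata=tree.metadata)
    print(len(truncated))  # 2 -- tokens past the cap are dropped, metadata survives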
@@ -117,7 +122,11 @@ class UniversalDependenciesDatasetReader(allen_data.DatasetReader):
                                                                  text_field,
                                                                  label_namespace="feats_labels")
             elif target_name == "head":
-                target_values = [0 if v == "_" else int(v) for v in target_values]
+                if self.max_input_embedder:
+                    target_values = [0 if v == "_" else int(v) for v in target_values]
+                    target_values = [v for v in target_values if v < self.max_input_embedder]
+                else:
+                    target_values = [0 if v == "_" else int(v) for v in target_values]
                 fields_[target_name] = allen_fields.SequenceLabelField(target_values, text_field,
                                                                        label_namespace=target_name + "_labels")
             elif target_name == "deps":
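
Note that the filtering drops an out-of-window head from the label list instead of remapping it, so it relies on the tokens having been truncated to the same window first. The effect in isolation (HEAD values are made up):

    max_input_embedder = 3
    raw = ["2", "0", "4", "_"]                              # CoNLL-U HEAD column
    values = [0 if v == "_" else int(v) for v in raw]       # [2, 0, 4, 0]
    values = [v for v in values if v < max_input_embedder]  # [2, 0, 0]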
@@ -130,6 +139,8 @@ class UniversalDependenciesDatasetReader(allen_data.DatasetReader):
                     t_deps = t["deps"]
                     if t_deps and t_deps != "_":
                         for rel, head in t_deps:
+                            if self.max_input_embedder is not None and int(head) >= self.max_input_embedder:
+                                continue
                             # EmoryNLP skips the first edge, if there are two edges between the same
                             # nodes. Thanks to that one is in a tree and another in a graph.
                             # This snippet follows that approach.
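
The same window is applied to enhanced-dependency (deps) edges: an arc whose head lies past the cap is skipped before the graph fields are built. A sketch with made-up edges:

    max_input_embedder = 3
    t_deps = [("nsubj", 2), ("conj", 5)]
    kept = [(rel, head) for rel, head in t_deps
            if int(head) < max_input_embedder]
    # kept == [("nsubj", 2)]; the arc pointing at head 5 is discarded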
@@ -228,7 +228,8 @@ class COMBO(predictor.Predictor):
     @classmethod
     def from_pretrained(cls, path: str, tokenizer=tokenizers.SpacyTokenizer(),
                         batch_size: int = 1024,
-                        cuda_device: int = -1):
+                        cuda_device: int = -1,
+                        max_input_embedder: int = None):
         util.import_module_and_submodules("combo.commands")
         util.import_module_and_submodules("combo.models")
         util.import_module_and_submodules("combo.training")
@@ -245,6 +246,13 @@ class COMBO(predictor.Predictor):
         archive = models.load_archive(model_path, cuda_device=cuda_device)
         model = archive.model
-        dataset_reader = allen_data.DatasetReader.from_params(
-            archive.config["dataset_reader"])
+        dataset_reader = allen_data.DatasetReader.from_params(archive.config["dataset_reader"],
+                                                              max_input_embedder=max_input_embedder)
+        logger.info("Using a pretrained transformer embedder may require truncating tokenized sentences.")
+        if max_input_embedder:
+            logger.info(f"Currently they are truncated to the first {max_input_embedder} tokens.")
+        else:
+            logger.info("Currently they are not truncated.")
         return cls(model, dataset_reader, tokenizer, batch_size)
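
End to end, the new argument flows from COMBO.from_pretrained into the dataset reader via from_params. A usage sketch ("english-bert-base" is a placeholder model name, not something this commit ships):

    from combo import predict

    # Truncate every tokenized sentence to its first 512 tokens before embedding.
    nlp = predict.COMBO.from_pretrained("english-bert-base",
                                        max_input_embedder=512)
    prediction = nlp("A sentence that would otherwise overflow the embedder's input limit.")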