Skip to content
Snippets Groups Projects

Develop

9 files
+ 44
237
Compare changes
  • Side-by-side
  • Inline

Files

+ 6
17
@@ -8,7 +8,7 @@ import torch
from allennlp import data as allen_data
from allennlp.common import checks, util
from allennlp.data import fields as allen_fields, vocabulary
from conllu import parser, TokenList
from conllu import parser
from dataclasses import dataclass
from overrides import overrides
@@ -27,7 +27,6 @@ class UniversalDependenciesDatasetReader(allen_data.DatasetReader):
features: List[str] = None,
targets: List[str] = None,
use_sem: bool = False,
max_input_embedder: int = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -49,7 +48,6 @@ class UniversalDependenciesDatasetReader(allen_data.DatasetReader):
"Remove {} from either features or targets.".format(intersection)
)
self.use_sem = use_sem
self.max_input_embedder = max_input_embedder
# *.conllu readers configuration
fields = list(parser.DEFAULT_FIELDS)
@@ -90,16 +88,13 @@ class UniversalDependenciesDatasetReader(allen_data.DatasetReader):
@overrides
def text_to_instance(self, tree: conllu.TokenList) -> allen_data.Instance:
if self.max_input_embedder:
tree = TokenList(tokens = tree.tokens[: self.max_input_embedder],
metadata = tree.metadata)
fields_: Dict[str, allen_data.Field] = {}
tree_tokens = [t for t in tree if isinstance(t["id"], int)]
tokens = [_Token(t["token"],
pos_=t.get("upostag"),
tag_=t.get("xpostag"),
lemma_=t.get("lemma"),
feats_=t.get("feats"))
pos_=t.get("upostag"),
tag_=t.get("xpostag"),
lemma_=t.get("lemma"),
feats_=t.get("feats"))
for t in tree_tokens]
# features
@@ -122,11 +117,7 @@ class UniversalDependenciesDatasetReader(allen_data.DatasetReader):
text_field,
label_namespace="feats_labels")
elif target_name == "head":
if self.max_input_embedder:
target_values = [0 if v == "_" else int(v) for v in target_values]
target_values = [v for v in target_values if v < self.max_input_embedder]
else:
target_values = [0 if v == "_" else int(v) for v in target_values]
target_values = [0 if v == "_" else int(v) for v in target_values]
fields_[target_name] = allen_fields.SequenceLabelField(target_values, text_field,
label_namespace=target_name + "_labels")
elif target_name == "deps":
@@ -139,8 +130,6 @@ class UniversalDependenciesDatasetReader(allen_data.DatasetReader):
t_deps = t["deps"]
if t_deps and t_deps != "_":
for rel, head in t_deps:
if int(head) >= self.max_input_embedder:
continue
# EmoryNLP skips the first edge when there are two edges between the same
# pair of nodes, so that one edge ends up in the tree and the other in the
# graph. This snippet follows the same approach.
Loading