Skip to content
Snippets Groups Projects
Commit 31c595f8 authored by Mateusz Klimaszewski
Browse files

Handle multi-word tokens during dataset reading.

parent 0b7636f8
No related branches found
No related tags found
No related merge requests found
...@@ -79,12 +79,13 @@ class UniversalDependenciesDatasetReader(allen_data.DatasetReader): ...@@ -79,12 +79,13 @@ class UniversalDependenciesDatasetReader(allen_data.DatasetReader):
@overrides @overrides
def text_to_instance(self, tree: conllu.TokenList) -> allen_data.Instance: def text_to_instance(self, tree: conllu.TokenList) -> allen_data.Instance:
fields_: Dict[str, allen_data.Field] = {} fields_: Dict[str, allen_data.Field] = {}
tree_tokens = [t for t in tree if isinstance(t["id"], int)]
tokens = [_Token(t["token"], tokens = [_Token(t["token"],
pos_=t.get("upostag"), pos_=t.get("upostag"),
tag_=t.get("xpostag"), tag_=t.get("xpostag"),
lemma_=t.get("lemma"), lemma_=t.get("lemma"),
feats_=t.get("feats")) feats_=t.get("feats"))
for t in tree if isinstance(t["id"], int)] for t in tree_tokens]
# features # features
text_field = allen_fields.TextField(tokens, self._token_indexers) text_field = allen_fields.TextField(tokens, self._token_indexers)
...@@ -94,12 +95,12 @@ class UniversalDependenciesDatasetReader(allen_data.DatasetReader): ...@@ -94,12 +95,12 @@ class UniversalDependenciesDatasetReader(allen_data.DatasetReader):
if self.generate_labels: if self.generate_labels:
for target_name in self._targets: for target_name in self._targets:
if target_name != "sent": if target_name != "sent":
target_values = [t[target_name] for t in tree.tokens] target_values = [t[target_name] for t in tree_tokens]
if target_name == "lemma": if target_name == "lemma":
target_values = [allen_data.Token(v) for v in target_values] target_values = [allen_data.Token(v) for v in target_values]
fields_[target_name] = allen_fields.TextField(target_values, self._lemma_indexers) fields_[target_name] = allen_fields.TextField(target_values, self._lemma_indexers)
elif target_name == "feats": elif target_name == "feats":
target_values = self._feat_values(tree) target_values = self._feat_values(tree_tokens)
fields_[target_name] = fields.SequenceMultiLabelField(target_values, fields_[target_name] = fields.SequenceMultiLabelField(target_values,
self._feats_to_index_multi_label, self._feats_to_index_multi_label,
text_field, text_field,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment