Skip to content
Snippets Groups Projects
Commit 9cf28e9f authored by Mateusz Klimaszewski's avatar Mateusz Klimaszewski Committed by Mateusz Klimaszewski
Browse files

Add enhanced UD data preprocessing.

parent 89776c08
Branches
Tags
2 merge requests!9Enhanced dependency parsing develop to master,!8Enhanced dependency parsing
import logging import logging
from typing import Union, List, Dict, Iterable, Optional, Any from typing import Union, List, Dict, Iterable, Optional, Any, Tuple
import conllu import conllu
import torch
from allennlp import data as allen_data from allennlp import data as allen_data
from allennlp.common import checks from allennlp.common import checks, util
from allennlp.data import fields as allen_fields, vocabulary from allennlp.data import fields as allen_fields, vocabulary
from conllu import parser from conllu import parser
from dataclasses import dataclass from dataclasses import dataclass
...@@ -35,6 +36,9 @@ class UniversalDependenciesDatasetReader(allen_data.DatasetReader): ...@@ -35,6 +36,9 @@ class UniversalDependenciesDatasetReader(allen_data.DatasetReader):
if "token" not in features and "char" not in features: if "token" not in features and "char" not in features:
raise checks.ConfigurationError("There must be at least one ('char' or 'token') text-based feature!") raise checks.ConfigurationError("There must be at least one ('char' or 'token') text-based feature!")
# Enhanced dependencies ('deps') are decoded on top of the basic tree,
# so the basic 'head' and 'deprel' targets must be requested as well.
if "deps" in targets and not ("head" in targets and "deprel" in targets):
    raise checks.ConfigurationError("Add 'head' and 'deprel' to targets when using 'deps'!")
intersection = set(features).intersection(set(targets)) intersection = set(features).intersection(set(targets))
if len(intersection) != 0: if len(intersection) != 0:
raise checks.ConfigurationError( raise checks.ConfigurationError(
...@@ -102,13 +106,40 @@ class UniversalDependenciesDatasetReader(allen_data.DatasetReader): ...@@ -102,13 +106,40 @@ class UniversalDependenciesDatasetReader(allen_data.DatasetReader):
elif target_name == "feats": elif target_name == "feats":
target_values = self._feat_values(tree_tokens) target_values = self._feat_values(tree_tokens)
fields_[target_name] = fields.SequenceMultiLabelField(target_values, fields_[target_name] = fields.SequenceMultiLabelField(target_values,
self._feats_to_index_multi_label, self._feats_indexer,
self._feats_as_tensor_wrapper,
text_field, text_field,
label_namespace="feats_labels") label_namespace="feats_labels")
elif target_name == "head": elif target_name == "head":
target_values = [0 if v == "_" else int(v) for v in target_values] target_values = [0 if v == "_" else int(v) for v in target_values]
fields_[target_name] = allen_fields.SequenceLabelField(target_values, text_field, fields_[target_name] = allen_fields.SequenceLabelField(target_values, text_field,
label_namespace=target_name + "_labels") label_namespace=target_name + "_labels")
elif target_name == "deps":
    # Enhanced dependencies (CoNLL-U DEPS column). The basic tree
    # (head/deprel) is always part of the enhanced graph; any extra
    # head:deprel pairs from the DEPS field are appended after it.
    heads = [0 if t["head"] == "_" else int(t["head"]) for t in tree_tokens]
    deprels = [t["deprel"] for t in tree_tokens]
    enhanced_heads: List[Tuple[int, int]] = []
    enhanced_deprels: List[str] = []
    for idx, t in enumerate(tree_tokens):
        # Basic arc for this token (child index, head id).
        enhanced_heads.append((idx, heads[idx]))
        enhanced_deprels.append(deprels[idx])
        t_deps = t["deps"]
        if t_deps and t_deps != "_":
            # NOTE(review): d.split(":") yields more than two parts for
            # subtyped relations such as "obl:arg" or "nmod:poss", and
            # zip(*...) would then unpack into >2 sequences and raise —
            # confirm labels here are colon-free, else split(":", 1).
            t_heads, t_deprels = zip(*[tuple(d.split(":")) for d in t_deps.split("|")])
            # NOTE(review): these head ids stay *strings* (and may be
            # empty-node ids like "5.1"), unlike the int heads above —
            # verify AdjacencyField / downstream code tolerates that.
            enhanced_heads.extend([(idx, t) for t in t_heads])
            enhanced_deprels.extend(t_deprels)
    # Unlabelled arc indicator field over (child, head) pairs.
    fields_["enhanced_heads"] = allen_fields.AdjacencyField(
        indices=enhanced_heads,
        sequence_field=text_field,
        label_namespace="enhanced_heads_labels",
        padding_value=0,
    )
    # Same arc indices, but labelled with the dependency relation.
    fields_["enhanced_deprels"] = allen_fields.AdjacencyField(
        indices=enhanced_heads,
        sequence_field=text_field,
        labels=enhanced_deprels,
        label_namespace="enhanced_deprels_labels",
        padding_value=0,
    )
else: else:
fields_[target_name] = allen_fields.SequenceLabelField(target_values, text_field, fields_[target_name] = allen_fields.SequenceLabelField(target_values, text_field,
label_namespace=target_name + "_labels") label_namespace=target_name + "_labels")
...@@ -151,12 +182,26 @@ class UniversalDependenciesDatasetReader(allen_data.DatasetReader): ...@@ -151,12 +182,26 @@ class UniversalDependenciesDatasetReader(allen_data.DatasetReader):
return features return features
@staticmethod @staticmethod
def _feats_to_index_multi_label(vocab: allen_data.Vocabulary): def _feats_as_tensor_wrapper(field: fields.SequenceMultiLabelField):
def as_tensor(padding_lengths):
    # Pad the per-token multi-hot label vectors of `field` up to the
    # padded sentence length, filling with all-zeros rows of the same
    # width, then stack into a (num_tokens, num_classes) tensor.
    desired_num_tokens = padding_lengths["num_tokens"]
    assert len(field._indexed_multi_labels) > 0
    classes_count = len(field._indexed_multi_labels[0])
    default_value = [0.0] * classes_count
    padded_tags = util.pad_sequence_to_length(field._indexed_multi_labels, desired_num_tokens,
                                              lambda: default_value)
    # NOTE(review): padding rows hold floats but the tensor is Long —
    # values are truncated to int64; confirm that is intended.
    tensor = torch.LongTensor(padded_tags)
    return tensor
return as_tensor
@staticmethod
def _feats_indexer(vocab: allen_data.Vocabulary):
label_namespace = "feats_labels" label_namespace = "feats_labels"
vocab_size = vocab.get_vocab_size(label_namespace) vocab_size = vocab.get_vocab_size(label_namespace)
slices = get_slices_if_not_provided(vocab) slices = get_slices_if_not_provided(vocab)
def _m_from_n_ones_encoding(multi_label: List[str]) -> List[int]: def _m_from_n_ones_encoding(multi_label: List[str], sentence_length: int) -> List[int]:
one_hot_encoding = [0] * vocab_size one_hot_encoding = [0] * vocab_size
for cat, cat_indices in slices.items(): for cat, cat_indices in slices.items():
if cat not in ["__PAD__", "_"]: if cat not in ["__PAD__", "_"]:
......
...@@ -5,7 +5,7 @@ from typing import Set, List, Callable, Iterator, Union, Dict ...@@ -5,7 +5,7 @@ from typing import Set, List, Callable, Iterator, Union, Dict
import torch import torch
from allennlp import data from allennlp import data
from allennlp.common import checks, util from allennlp.common import checks
from allennlp.data import fields from allennlp.data import fields
from overrides import overrides from overrides import overrides
...@@ -17,15 +17,16 @@ class SequenceMultiLabelField(data.Field[torch.Tensor]): ...@@ -17,15 +17,16 @@ class SequenceMultiLabelField(data.Field[torch.Tensor]):
A `SequenceMultiLabelField` is an extension of the :class:`MultiLabelField` that allows for multiple labels A `SequenceMultiLabelField` is an extension of the :class:`MultiLabelField` that allows for multiple labels
while keeping sequence dimension. while keeping sequence dimension.
This field will get converted into a sequence of vectors of length equal to the vocabulary size with To allow configuration to different circumstances, class takes few delegates functions.
M from N encoding for the labels (all zeros, and ones for the labels).
# Parameters # Parameters
multi_labels : `List[List[str]]` multi_labels : `List[List[str]]`
multi_label_indexer : `Callable[[data.Vocabulary], Callable[[List[str]], List[int]]]` multi_label_indexer : `Callable[[data.Vocabulary], Callable[[List[str]], List[int]]]`
Nested callable which based on vocab creates mapper for multilabel field in the sequence from strings Nested callable which based on vocab and sequence length maps values of the fields in the sequence
to indexed, int values. from strings to indexed, int values.
as_tensor: `Callable[["SequenceMultiLabelField"], Callable[[Dict[str, int]], torch.Tensor]]`
Nested callable which based on the field itself, maps indexed data to a tensor.
sequence_field : `SequenceField` sequence_field : `SequenceField`
A field containing the sequence that this `SequenceMultiLabelField` is labeling. Most often, this is a A field containing the sequence that this `SequenceMultiLabelField` is labeling. Most often, this is a
`TextField`, for tagging individual tokens in a sentence. `TextField`, for tagging individual tokens in a sentence.
...@@ -43,7 +44,8 @@ class SequenceMultiLabelField(data.Field[torch.Tensor]): ...@@ -43,7 +44,8 @@ class SequenceMultiLabelField(data.Field[torch.Tensor]):
def __init__( def __init__(
self, self,
multi_labels: List[List[str]], multi_labels: List[List[str]],
multi_label_indexer: Callable[[data.Vocabulary], Callable[[List[str]], List[int]]], multi_label_indexer: Callable[[data.Vocabulary], Callable[[List[str], int], List[int]]],
as_tensor: Callable[["SequenceMultiLabelField"], Callable[[Dict[str, int]], torch.Tensor]],
sequence_field: fields.SequenceField, sequence_field: fields.SequenceField,
label_namespace: str = "labels", label_namespace: str = "labels",
) -> None: ) -> None:
...@@ -53,6 +55,7 @@ class SequenceMultiLabelField(data.Field[torch.Tensor]): ...@@ -53,6 +55,7 @@ class SequenceMultiLabelField(data.Field[torch.Tensor]):
self._label_namespace = label_namespace self._label_namespace = label_namespace
self._indexed_multi_labels = None self._indexed_multi_labels = None
self._maybe_warn_for_namespace(label_namespace) self._maybe_warn_for_namespace(label_namespace)
self.as_tensor_wrapper = as_tensor(self)
if len(multi_labels) != sequence_field.sequence_length(): if len(multi_labels) != sequence_field.sequence_length():
raise checks.ConfigurationError( raise checks.ConfigurationError(
"Label length and sequence length " "Label length and sequence length "
...@@ -101,7 +104,7 @@ class SequenceMultiLabelField(data.Field[torch.Tensor]): ...@@ -101,7 +104,7 @@ class SequenceMultiLabelField(data.Field[torch.Tensor]):
indexed = [] indexed = []
for multi_label in self.multi_labels: for multi_label in self.multi_labels:
indexed.append(indexer(multi_label)) indexed.append(indexer(multi_label, len(self.multi_labels)))
self._indexed_multi_labels = indexed self._indexed_multi_labels = indexed
@overrides @overrides
...@@ -110,19 +113,13 @@ class SequenceMultiLabelField(data.Field[torch.Tensor]): ...@@ -110,19 +113,13 @@ class SequenceMultiLabelField(data.Field[torch.Tensor]):
@overrides @overrides
def as_tensor(self, padding_lengths: Dict[str, int]) -> torch.Tensor: def as_tensor(self, padding_lengths: Dict[str, int]) -> torch.Tensor:
desired_num_tokens = padding_lengths["num_tokens"] return self.as_tensor_wrapper(padding_lengths)
assert len(self._indexed_multi_labels) > 0
classes_count = len(self._indexed_multi_labels[0])
default_value = [0.0] * classes_count
padded_tags = util.pad_sequence_to_length(self._indexed_multi_labels, desired_num_tokens, lambda: default_value)
tensor = torch.LongTensor(padded_tags)
return tensor
@overrides @overrides
def empty_field(self) -> "SequenceMultiLabelField": def empty_field(self) -> "SequenceMultiLabelField":
# The empty_list here is needed for mypy
empty_list: List[List[str]] = [[]] empty_list: List[List[str]] = [[]]
sequence_label_field = SequenceMultiLabelField(empty_list, lambda x: lambda y: y, sequence_label_field = SequenceMultiLabelField(empty_list, lambda x: lambda y: y,
lambda x: lambda y: y,
self.sequence_field.empty_field()) self.sequence_field.empty_field())
sequence_label_field._indexed_labels = empty_list sequence_label_field._indexed_labels = empty_list
return sequence_label_field return sequence_label_field
......
...@@ -18,7 +18,7 @@ from combo.utils import checks ...@@ -18,7 +18,7 @@ from combo.utils import checks
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_FEATURES = ["token", "char", "upostag", "xpostag", "lemma", "feats"] _FEATURES = ["token", "char", "upostag", "xpostag", "lemma", "feats"]
_TARGETS = ["deprel", "feats", "head", "lemma", "upostag", "xpostag", "semrel", "sent"] _TARGETS = ["deprel", "feats", "head", "lemma", "upostag", "xpostag", "semrel", "sent", "deps"]
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
flags.DEFINE_enum(name="mode", default=None, enum_values=["train", "predict"], flags.DEFINE_enum(name="mode", default=None, enum_values=["train", "predict"],
......
...@@ -53,7 +53,9 @@ class SemanticMultitaskModel(allen_models.Model): ...@@ -53,7 +53,9 @@ class SemanticMultitaskModel(allen_models.Model):
feats: torch.Tensor = None, feats: torch.Tensor = None,
head: torch.Tensor = None, head: torch.Tensor = None,
deprel: torch.Tensor = None, deprel: torch.Tensor = None,
semrel: torch.Tensor = None, ) -> Dict[str, torch.Tensor]: semrel: torch.Tensor = None,
enhanced_heads: torch.Tensor = None,
enhanced_deprels: torch.Tensor = None) -> Dict[str, torch.Tensor]:
# Prepare masks # Prepare masks
char_mask: torch.BoolTensor = sentence["char"]["token_characters"] > 0 char_mask: torch.BoolTensor = sentence["char"]["token_characters"] > 0
......
...@@ -4,6 +4,7 @@ from typing import List ...@@ -4,6 +4,7 @@ from typing import List
import torch import torch
from allennlp import data as allen_data from allennlp import data as allen_data
from allennlp.common import util
from allennlp.data import fields as allen_fields from allennlp.data import fields as allen_fields
from combo.data import fields from combo.data import fields
...@@ -22,7 +23,7 @@ class IndexingSequenceMultiLabelFieldTest(unittest.TestCase): ...@@ -22,7 +23,7 @@ class IndexingSequenceMultiLabelFieldTest(unittest.TestCase):
def _indexer(vocab: allen_data.Vocabulary): def _indexer(vocab: allen_data.Vocabulary):
vocab_size = vocab.get_vocab_size(self.namespace) vocab_size = vocab.get_vocab_size(self.namespace)
def _mapper(multi_label: List[str]) -> List[int]: def _mapper(multi_label: List[str], _: int) -> List[int]:
one_hot = [0] * vocab_size one_hot = [0] * vocab_size
for label in multi_label: for label in multi_label:
index = vocab.get_token_index(label, self.namespace) index = vocab.get_token_index(label, self.namespace)
...@@ -31,7 +32,21 @@ class IndexingSequenceMultiLabelFieldTest(unittest.TestCase): ...@@ -31,7 +32,21 @@ class IndexingSequenceMultiLabelFieldTest(unittest.TestCase):
return _mapper return _mapper
def _as_tensor(field: fields.SequenceMultiLabelField):
    # Test double mirroring the reader's as_tensor delegate: pads the
    # field's indexed multi-label rows to the padded sentence length
    # and converts them to a LongTensor.
    def _wrapped(padding_lengths):
        desired_num_tokens = padding_lengths["num_tokens"]
        classes_count = len(field._indexed_multi_labels[0])
        default_value = [0.0] * classes_count
        padded_tags = util.pad_sequence_to_length(field._indexed_multi_labels, desired_num_tokens,
                                                  lambda: default_value)
        tensor = torch.LongTensor(padded_tags)
        return tensor
    return _wrapped
self.indexer = _indexer self.indexer = _indexer
self.as_tensor = _as_tensor
self.sequence_field = _SequenceFieldTestWrapper(self.vocab.get_vocab_size(self.namespace)) self.sequence_field = _SequenceFieldTestWrapper(self.vocab.get_vocab_size(self.namespace))
def test_indexing(self): def test_indexing(self):
...@@ -39,6 +54,7 @@ class IndexingSequenceMultiLabelFieldTest(unittest.TestCase): ...@@ -39,6 +54,7 @@ class IndexingSequenceMultiLabelFieldTest(unittest.TestCase):
field = fields.SequenceMultiLabelField( field = fields.SequenceMultiLabelField(
multi_labels=[["t1", "t2"], [], ["t0"]], multi_labels=[["t1", "t2"], [], ["t0"]],
multi_label_indexer=self.indexer, multi_label_indexer=self.indexer,
as_tensor=self.as_tensor,
sequence_field=self.sequence_field, sequence_field=self.sequence_field,
label_namespace=self.namespace label_namespace=self.namespace
) )
...@@ -55,6 +71,7 @@ class IndexingSequenceMultiLabelFieldTest(unittest.TestCase): ...@@ -55,6 +71,7 @@ class IndexingSequenceMultiLabelFieldTest(unittest.TestCase):
field = fields.SequenceMultiLabelField( field = fields.SequenceMultiLabelField(
multi_labels=[["t1", "t2"], [], ["t0"]], multi_labels=[["t1", "t2"], [], ["t0"]],
multi_label_indexer=self.indexer, multi_label_indexer=self.indexer,
as_tensor=self.as_tensor,
sequence_field=self.sequence_field, sequence_field=self.sequence_field,
label_namespace=self.namespace label_namespace=self.namespace
) )
...@@ -72,6 +89,7 @@ class IndexingSequenceMultiLabelFieldTest(unittest.TestCase): ...@@ -72,6 +89,7 @@ class IndexingSequenceMultiLabelFieldTest(unittest.TestCase):
field = fields.SequenceMultiLabelField( field = fields.SequenceMultiLabelField(
multi_labels=[["t1", "t2"], [], ["t0"]], multi_labels=[["t1", "t2"], [], ["t0"]],
multi_label_indexer=self.indexer, multi_label_indexer=self.indexer,
as_tensor=self.as_tensor,
sequence_field=self.sequence_field, sequence_field=self.sequence_field,
label_namespace=self.namespace label_namespace=self.namespace
) )
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment