diff --git a/combo/data/dataset.py b/combo/data/dataset.py index 459a755c7f71c40e449d0542bf9af21d05e1f2c9..0b53df3284cde5d11e95ecd496059b6f55921e35 100644 --- a/combo/data/dataset.py +++ b/combo/data/dataset.py @@ -1,9 +1,10 @@ import logging -from typing import Union, List, Dict, Iterable, Optional, Any +from typing import Union, List, Dict, Iterable, Optional, Any, Tuple import conllu +import torch from allennlp import data as allen_data -from allennlp.common import checks +from allennlp.common import checks, util from allennlp.data import fields as allen_fields, vocabulary from conllu import parser from dataclasses import dataclass @@ -35,6 +36,9 @@ class UniversalDependenciesDatasetReader(allen_data.DatasetReader): if "token" not in features and "char" not in features: raise checks.ConfigurationError("There must be at least one ('char' or 'token') text-based feature!") + if "deps" in targets and not ("head" in targets and "deprel" in targets): + raise checks.ConfigurationError("Add 'head' and 'deprel' to targets when using 'deps'!") + intersection = set(features).intersection(set(targets)) if len(intersection) != 0: raise checks.ConfigurationError( @@ -102,13 +106,40 @@ class UniversalDependenciesDatasetReader(allen_data.DatasetReader): elif target_name == "feats": target_values = self._feat_values(tree_tokens) fields_[target_name] = fields.SequenceMultiLabelField(target_values, - self._feats_to_index_multi_label, + self._feats_indexer, + self._feats_as_tensor_wrapper, text_field, label_namespace="feats_labels") elif target_name == "head": target_values = [0 if v == "_" else int(v) for v in target_values] fields_[target_name] = allen_fields.SequenceLabelField(target_values, text_field, label_namespace=target_name + "_labels") + elif target_name == "deps": + heads = [0 if t["head"] == "_" else int(t["head"]) for t in tree_tokens] + deprels = [t["deprel"] for t in tree_tokens] + enhanced_heads: List[Tuple[int, int]] = [] + enhanced_deprels: List[str] = [] + for 
idx, t in enumerate(tree_tokens): +                        enhanced_heads.append((idx, heads[idx])) +                        enhanced_deprels.append(deprels[idx]) +                        t_deps = t["deps"] +                        if t_deps and t_deps != "_": +                            t_heads, t_deprels = zip(*[tuple(d.split(":", 1)) for d in t_deps.split("|")]) +                            enhanced_heads.extend([(idx, int(t)) for t in t_heads]) +                            enhanced_deprels.extend(t_deprels) +                    fields_["enhanced_heads"] = allen_fields.AdjacencyField( +                        indices=enhanced_heads, +                        sequence_field=text_field, +                        label_namespace="enhanced_heads_labels", +                        padding_value=0, +                    ) +                    fields_["enhanced_deprels"] = allen_fields.AdjacencyField( +                        indices=enhanced_heads, +                        sequence_field=text_field, +                        labels=enhanced_deprels, +                        label_namespace="enhanced_deprels_labels", +                        padding_value=0, +                    )                 else:                     fields_[target_name] = allen_fields.SequenceLabelField(target_values, text_field,                                                                            label_namespace=target_name + "_labels") @@ -151,12 +182,26 @@ class UniversalDependenciesDatasetReader(allen_data.DatasetReader):         return features      @staticmethod -    def _feats_to_index_multi_label(vocab: allen_data.Vocabulary): +    def _feats_as_tensor_wrapper(field: fields.SequenceMultiLabelField): +        def as_tensor(padding_lengths): +            desired_num_tokens = padding_lengths["num_tokens"] +            assert len(field._indexed_multi_labels) > 0 +            classes_count = len(field._indexed_multi_labels[0]) +            default_value = [0.0] * classes_count +            padded_tags = util.pad_sequence_to_length(field._indexed_multi_labels, desired_num_tokens, +                                                      lambda: default_value) +            tensor = torch.LongTensor(padded_tags) +            return tensor + +        return as_tensor + +    @staticmethod +    def _feats_indexer(vocab: allen_data.Vocabulary):         label_namespace = "feats_labels"         vocab_size = vocab.get_vocab_size(label_namespace)         slices = get_slices_if_not_provided(vocab)  -        def _m_from_n_ones_encoding(multi_label: List[str]) -> List[int]: +        def _m_from_n_ones_encoding(multi_label: List[str], sentence_length: int) -> List[int]:             one_hot_encoding = [0] * vocab_size             for cat, cat_indices in slices.items():                 if cat not in ["__PAD__", "_"]: 
diff --git a/combo/data/fields/sequence_multilabel_field.py b/combo/data/fields/sequence_multilabel_field.py index 4e98a148aee35e42af0b4828a031368fe0eafc12..b200580cf1edfba1e710b42bc19c4c4efdb0db4f 100644 --- a/combo/data/fields/sequence_multilabel_field.py +++ b/combo/data/fields/sequence_multilabel_field.py @@ -5,7 +5,7 @@ from typing import Set, List, Callable, Iterator, Union, Dict import torch from allennlp import data -from allennlp.common import checks, util +from allennlp.common import checks from allennlp.data import fields from overrides import overrides @@ -17,15 +17,16 @@ class SequenceMultiLabelField(data.Field[torch.Tensor]): A `SequenceMultiLabelField` is an extension of the :class:`MultiLabelField` that allows for multiple labels while keeping sequence dimension. - This field will get converted into a sequence of vectors of length equal to the vocabulary size with - M from N encoding for the labels (all zeros, and ones for the labels). + To allow configuration to different circumstances, class takes few delegates functions. # Parameters multi_labels : `List[List[str]]` multi_label_indexer : `Callable[[data.Vocabulary], Callable[[List[str]], List[int]]]` - Nested callable which based on vocab creates mapper for multilabel field in the sequence from strings - to indexed, int values. + Nested callable which based on vocab and sequence length maps values of the fields in the sequence + from strings to indexed, int values. + as_tensor: `Callable[["SequenceMultiLabelField"], Callable[[Dict[str, int]], torch.Tensor]]` + Nested callable which based on the field itself, maps indexed data to a tensor. sequence_field : `SequenceField` A field containing the sequence that this `SequenceMultiLabelField` is labeling. Most often, this is a `TextField`, for tagging individual tokens in a sentence. 
@@ -43,7 +44,8 @@ class SequenceMultiLabelField(data.Field[torch.Tensor]): def __init__( self, multi_labels: List[List[str]], - multi_label_indexer: Callable[[data.Vocabulary], Callable[[List[str]], List[int]]], + multi_label_indexer: Callable[[data.Vocabulary], Callable[[List[str], int], List[int]]], + as_tensor: Callable[["SequenceMultiLabelField"], Callable[[Dict[str, int]], torch.Tensor]], sequence_field: fields.SequenceField, label_namespace: str = "labels", ) -> None: @@ -53,6 +55,7 @@ class SequenceMultiLabelField(data.Field[torch.Tensor]): self._label_namespace = label_namespace self._indexed_multi_labels = None self._maybe_warn_for_namespace(label_namespace) + self.as_tensor_wrapper = as_tensor(self) if len(multi_labels) != sequence_field.sequence_length(): raise checks.ConfigurationError( "Label length and sequence length " @@ -101,7 +104,7 @@ class SequenceMultiLabelField(data.Field[torch.Tensor]): indexed = [] for multi_label in self.multi_labels: - indexed.append(indexer(multi_label)) + indexed.append(indexer(multi_label, len(self.multi_labels))) self._indexed_multi_labels = indexed @overrides @@ -110,19 +113,13 @@ class SequenceMultiLabelField(data.Field[torch.Tensor]): @overrides def as_tensor(self, padding_lengths: Dict[str, int]) -> torch.Tensor: - desired_num_tokens = padding_lengths["num_tokens"] - assert len(self._indexed_multi_labels) > 0 - classes_count = len(self._indexed_multi_labels[0]) - default_value = [0.0] * classes_count - padded_tags = util.pad_sequence_to_length(self._indexed_multi_labels, desired_num_tokens, lambda: default_value) - tensor = torch.LongTensor(padded_tags) - return tensor + return self.as_tensor_wrapper(padding_lengths) @overrides def empty_field(self) -> "SequenceMultiLabelField": - # The empty_list here is needed for mypy empty_list: List[List[str]] = [[]] sequence_label_field = SequenceMultiLabelField(empty_list, lambda x: lambda y: y, + lambda x: lambda y: y, self.sequence_field.empty_field()) 
sequence_label_field._indexed_labels = empty_list return sequence_label_field diff --git a/combo/main.py b/combo/main.py index 44ad091f8c5004bf9f4e70323ec828bf44288447..3f7fad44a829b8d514c5c64acd486b99b644a739 100644 --- a/combo/main.py +++ b/combo/main.py @@ -18,7 +18,7 @@ from combo.utils import checks logger = logging.getLogger(__name__) _FEATURES = ["token", "char", "upostag", "xpostag", "lemma", "feats"] -_TARGETS = ["deprel", "feats", "head", "lemma", "upostag", "xpostag", "semrel", "sent"] +_TARGETS = ["deprel", "feats", "head", "lemma", "upostag", "xpostag", "semrel", "sent", "deps"] FLAGS = flags.FLAGS flags.DEFINE_enum(name="mode", default=None, enum_values=["train", "predict"], diff --git a/combo/models/model.py b/combo/models/model.py index 77b43e3c1a95e09b15c310409af0090f097d47fa..ec0f113be9a3c00868f5c9cce3d8f75f789bcb81 100644 --- a/combo/models/model.py +++ b/combo/models/model.py @@ -53,7 +53,9 @@ class SemanticMultitaskModel(allen_models.Model): feats: torch.Tensor = None, head: torch.Tensor = None, deprel: torch.Tensor = None, - semrel: torch.Tensor = None, ) -> Dict[str, torch.Tensor]: + semrel: torch.Tensor = None, + enhanced_heads: torch.Tensor = None, + enhanced_deprels: torch.Tensor = None) -> Dict[str, torch.Tensor]: # Prepare masks char_mask: torch.BoolTensor = sentence["char"]["token_characters"] > 0 diff --git a/tests/data/fields/test_sequence_multilabel_field.py b/tests/data/fields/test_sequence_multilabel_field.py index d2a1f8bc6b4d853e427d6a7214592f1b881db52c..fff8ff4ee285215aa4015d1c55609f3c2d28a3a6 100644 --- a/tests/data/fields/test_sequence_multilabel_field.py +++ b/tests/data/fields/test_sequence_multilabel_field.py @@ -4,6 +4,7 @@ from typing import List import torch from allennlp import data as allen_data +from allennlp.common import util from allennlp.data import fields as allen_fields from combo.data import fields @@ -22,7 +23,7 @@ class IndexingSequenceMultiLabelFieldTest(unittest.TestCase): def _indexer(vocab: 
allen_data.Vocabulary): vocab_size = vocab.get_vocab_size(self.namespace) - def _mapper(multi_label: List[str]) -> List[int]: + def _mapper(multi_label: List[str], _: int) -> List[int]: one_hot = [0] * vocab_size for label in multi_label: index = vocab.get_token_index(label, self.namespace) @@ -31,7 +32,21 @@ class IndexingSequenceMultiLabelFieldTest(unittest.TestCase): return _mapper + def _as_tensor(field: fields.SequenceMultiLabelField): + + def _wrapped(padding_lengths): + desired_num_tokens = padding_lengths["num_tokens"] + classes_count = len(field._indexed_multi_labels[0]) + default_value = [0.0] * classes_count + padded_tags = util.pad_sequence_to_length(field._indexed_multi_labels, desired_num_tokens, + lambda: default_value) + tensor = torch.LongTensor(padded_tags) + return tensor + + return _wrapped + self.indexer = _indexer + self.as_tensor = _as_tensor self.sequence_field = _SequenceFieldTestWrapper(self.vocab.get_vocab_size(self.namespace)) def test_indexing(self): @@ -39,6 +54,7 @@ class IndexingSequenceMultiLabelFieldTest(unittest.TestCase): field = fields.SequenceMultiLabelField( multi_labels=[["t1", "t2"], [], ["t0"]], multi_label_indexer=self.indexer, + as_tensor=self.as_tensor, sequence_field=self.sequence_field, label_namespace=self.namespace ) @@ -55,6 +71,7 @@ class IndexingSequenceMultiLabelFieldTest(unittest.TestCase): field = fields.SequenceMultiLabelField( multi_labels=[["t1", "t2"], [], ["t0"]], multi_label_indexer=self.indexer, + as_tensor=self.as_tensor, sequence_field=self.sequence_field, label_namespace=self.namespace ) @@ -72,6 +89,7 @@ class IndexingSequenceMultiLabelFieldTest(unittest.TestCase): field = fields.SequenceMultiLabelField( multi_labels=[["t1", "t2"], [], ["t0"]], multi_label_indexer=self.indexer, + as_tensor=self.as_tensor, sequence_field=self.sequence_field, label_namespace=self.namespace )