Skip to content
Snippets Groups Projects

Enhanced dependency parsing

Merged Mateusz Klimaszewski requested to merge enhanced_dependency_parsing into develop
Compare and
29 files
+ 1490
101
Compare changes
  • Side-by-side
  • Inline
Files
29
@@ -5,7 +5,7 @@ from typing import Set, List, Callable, Iterator, Union, Dict
import torch
from allennlp import data
from allennlp.common import checks, util
from allennlp.common import checks
from allennlp.data import fields
from overrides import overrides
@@ -17,15 +17,16 @@ class SequenceMultiLabelField(data.Field[torch.Tensor]):
A `SequenceMultiLabelField` is an extension of the :class:`MultiLabelField` that allows for multiple labels
while keeping sequence dimension.
This field will get converted into a sequence of vectors of length equal to the vocabulary size with
M from N encoding for the labels (all zeros, and ones for the labels).
To allow configuration to different circumstances, class takes few delegates functions.
# Parameters
multi_labels : `List[List[str]]`
multi_label_indexer : `Callable[[data.Vocabulary], Callable[[List[str]], List[int]]]`
Nested callable which based on vocab creates mapper for multilabel field in the sequence from strings
to indexed, int values.
Nested callable which based on vocab and sequence length maps values of the fields in the sequence
from strings to indexed, int values.
as_tensor: `Callable[["SequenceMultiLabelField"], Callable[[Dict[str, int]], torch.Tensor]]`
Nested callable which based on the field itself, maps indexed data to a tensor.
sequence_field : `SequenceField`
A field containing the sequence that this `SequenceMultiLabelField` is labeling. Most often, this is a
`TextField`, for tagging individual tokens in a sentence.
@@ -43,7 +44,8 @@ class SequenceMultiLabelField(data.Field[torch.Tensor]):
def __init__(
self,
multi_labels: List[List[str]],
multi_label_indexer: Callable[[data.Vocabulary], Callable[[List[str]], List[int]]],
multi_label_indexer: Callable[[data.Vocabulary], Callable[[List[str], int], List[int]]],
as_tensor: Callable[["SequenceMultiLabelField"], Callable[[Dict[str, int]], torch.Tensor]],
sequence_field: fields.SequenceField,
label_namespace: str = "labels",
) -> None:
@@ -53,6 +55,7 @@ class SequenceMultiLabelField(data.Field[torch.Tensor]):
self._label_namespace = label_namespace
self._indexed_multi_labels = None
self._maybe_warn_for_namespace(label_namespace)
self.as_tensor_wrapper = as_tensor(self)
if len(multi_labels) != sequence_field.sequence_length():
raise checks.ConfigurationError(
"Label length and sequence length "
@@ -101,7 +104,7 @@ class SequenceMultiLabelField(data.Field[torch.Tensor]):
indexed = []
for multi_label in self.multi_labels:
indexed.append(indexer(multi_label))
indexed.append(indexer(multi_label, len(self.multi_labels)))
self._indexed_multi_labels = indexed
@overrides
@@ -110,19 +113,13 @@ class SequenceMultiLabelField(data.Field[torch.Tensor]):
@overrides
def as_tensor(self, padding_lengths: Dict[str, int]) -> torch.Tensor:
desired_num_tokens = padding_lengths["num_tokens"]
assert len(self._indexed_multi_labels) > 0
classes_count = len(self._indexed_multi_labels[0])
default_value = [0.0] * classes_count
padded_tags = util.pad_sequence_to_length(self._indexed_multi_labels, desired_num_tokens, lambda: default_value)
tensor = torch.LongTensor(padded_tags)
return tensor
return self.as_tensor_wrapper(padding_lengths)
@overrides
def empty_field(self) -> "SequenceMultiLabelField":
# The empty_list here is needed for mypy
empty_list: List[List[str]] = [[]]
sequence_label_field = SequenceMultiLabelField(empty_list, lambda x: lambda y: y,
lambda x: lambda y: y,
self.sequence_field.empty_field())
sequence_label_field._indexed_labels = empty_list
return sequence_label_field