Commit c698ce09 authored by Maja Jabłońska

Re-implement COMBO's SequenceMultiLabelField

parent 69d85921
Merge request !46: Merge COMBO 3.0 into master
combo/data/fields/__init__.py

-from .base_field import Field
+from .field import Field
+from .sequence_multilabel_field import SequenceMultiLabelField

combo/data/fields/base_field.py (deleted)

-class Field:
-    pass
combo/data/fields/field.py (new file)

from copy import deepcopy
from typing import Dict, Generic, List, TypeVar, Any

import torch

from combo.data.vocabulary import Vocabulary

"""
Adapted from https://github.com/allenai/allennlp/blob/main/allennlp/data/fields/field.py
"""

DataArray = TypeVar(
    "DataArray", torch.Tensor, Dict[str, torch.Tensor], Dict[str, Dict[str, torch.Tensor]]
)


class Field(Generic[DataArray]):
    """
    A `Field` is some piece of a data instance that ends up as a tensor in a model (either as an
    input or an output). Data instances are just collections of fields.

    Fields go through up to two steps of processing: (1) tokenized fields are converted into token
    ids, (2) fields containing token ids (or any other numeric data) are padded (if necessary) and
    converted into tensors. The `Field` API has methods around both of these steps, though they
    may not be needed for some concrete `Field` classes - if your field doesn't have any strings
    that need indexing, you don't need to implement `count_vocab_items` or `index`. These
    methods `pass` by default.

    Once a vocabulary is computed and all fields are indexed, we will determine padding lengths,
    then intelligently batch together instances and pad them into actual tensors.
    """

    __slots__ = []  # type: ignore

    def count_vocab_items(self, counter: Dict[str, Dict[str, int]]):
        """
        If there are strings in this field that need to be converted into integers through a
        :class:`Vocabulary`, here is where we count them, to determine which tokens are in or out
        of the vocabulary.

        If your `Field` does not have any strings that need to be converted into indices, you do
        not need to implement this method.

        A note on this `counter`: because `Fields` can represent conceptually different things,
        we separate the vocabulary items by `namespaces`. This way, we can use a single shared
        mechanism to handle all mappings from strings to integers in all fields, while keeping
        words in a `TextField` from sharing the same ids with labels in a `LabelField` (e.g.,
        "entailment" or "contradiction" are labels in an entailment task).

        Additionally, a single `Field` might want to use multiple namespaces - `TextFields` can
        be represented as a combination of word ids and character ids, and you don't want words and
        characters to share the same vocabulary - "a" as a word should get a different id from "a"
        as a character, and the vocabulary sizes of words and characters are very different.

        Because of this, the first key in the `counter` object is a `namespace`, like "tokens",
        "token_characters", "tags", or "labels", and the second key is the actual vocabulary item.
        For example, a populated counter might look like
        `{"tokens": {"the": 12, "cat": 2}, "labels": {"entailment": 3}}`.
        """
        pass

    def human_readable_repr(self) -> Any:
        """
        This method should be implemented by subclasses to return a structured, yet human-readable
        representation of the field.

        !!! Note
            `human_readable_repr()` is not meant to be used as a method to serialize a `Field` since the return
            value does not necessarily contain all of the attributes of the `Field` instance. But the object
            returned should be JSON-serializable.
        """
        raise NotImplementedError

    def index(self, vocab: Vocabulary):
        """
        Given a :class:`Vocabulary`, converts all strings in this field into (typically) integers.
        This `modifies` the `Field` object, it does not return anything.

        If your `Field` does not have any strings that need to be converted into indices, you do
        not need to implement this method.
        """
        pass

    def get_padding_lengths(self) -> Dict[str, int]:
        """
        If there are things in this field that need padding, note them here. In order to pad a
        batch of instances, we get all of the lengths from the batch, take the max, and pad
        everything to that length (or use a pre-specified maximum length). The return value is a
        dictionary mapping keys to lengths, like `{'num_tokens': 13}`.

        This is always called after :func:`index`.
        """
        raise NotImplementedError

    def as_tensor(self, padding_lengths: Dict[str, int]) -> DataArray:
        """
        Given a set of specified padding lengths, actually pad the data in this field and return a
        torch Tensor (or a more complex data structure) of the correct shape. We also take a
        couple of parameters that are important when constructing torch Tensors.

        # Parameters

        padding_lengths : `Dict[str, int]`
            This dictionary will have the same keys that were produced in
            :func:`get_padding_lengths`. The values specify the lengths to use when padding each
            relevant dimension, aggregated across all instances in a batch.
        """
        raise NotImplementedError

    def empty_field(self) -> "Field":
        """
        So that `ListField` can pad the number of fields in a list (e.g., the number of answer
        option `TextFields`), we need a representation of an empty field of each type. This
        returns that. This will only ever be called when we're to the point of calling
        :func:`as_tensor`, so you don't need to worry about `get_padding_lengths`,
        `count_vocab_items`, etc., being called on this empty field.

        We make this an instance method instead of a static method so that if there is any state
        in the Field, we can copy it over (e.g., the token indexers in `TextField`).
        """
        raise NotImplementedError

    def batch_tensors(self, tensor_list: List[DataArray]) -> DataArray:  # type: ignore
        """
        Takes the output of `Field.as_tensor()` from a list of `Instances` and merges it into
        one batched tensor for this `Field`. The default implementation here in the base class
        handles cases where `as_tensor` returns a single torch tensor per instance. If your
        subclass returns something other than this, you need to override this method.

        This operation does not modify `self`, but in some cases we need the information
        contained in `self` in order to perform the batching, so this is an instance method, not
        a class method.
        """
        return torch.stack(tensor_list)

    def __eq__(self, other) -> bool:
        if isinstance(self, other.__class__):
            # With the way "slots" classes work, self.__slots__ only gives the slots defined
            # by the current class, but not any of its base classes. Therefore to truly
            # check for equality we have to check through all of the slots in all of the
            # base classes as well.
            for class_ in self.__class__.mro():
                for attr in getattr(class_, "__slots__", []):
                    if getattr(self, attr) != getattr(other, attr):
                        return False
            # It's possible that a subclass was not defined as a slots class, in which
            # case we'll need to check __dict__.
            if hasattr(self, "__dict__"):
                return self.__dict__ == other.__dict__
            return True
        return NotImplemented

    def __len__(self):
        raise NotImplementedError

    def duplicate(self):
        return deepcopy(self)
\ No newline at end of file
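
To make the two-step lifecycle described in the `Field` docstring concrete, here is a minimal sketch of a custom subclass. It is illustrative only and not part of this commit: the `TagsField` name and the "tags"/"num_tags" keys are invented, and the sketch reuses the imports already at the top of field.py.

class TagsField(Field[torch.Tensor]):
    # Stores a list of string tags, counts them into the "tags" namespace,
    # indexes them against the vocabulary, and pads to the batch-wide length.
    __slots__ = ["tags", "indexed_tags"]

    def __init__(self, tags: List[str]) -> None:
        self.tags = tags
        self.indexed_tags: List[int] = []

    def count_vocab_items(self, counter: Dict[str, Dict[str, int]]):
        # Step (1): register every string under the "tags" namespace.
        for tag in self.tags:
            counter["tags"][tag] += 1

    def index(self, vocab: Vocabulary):
        # Step (1), continued: strings become integer ids.
        self.indexed_tags = [vocab.get_token_index(tag, "tags") for tag in self.tags]

    def get_padding_lengths(self) -> Dict[str, int]:
        return {"num_tags": len(self.indexed_tags)}

    def as_tensor(self, padding_lengths: Dict[str, int]) -> torch.Tensor:
        # Step (2): pad up to the aggregated batch length, then tensorize.
        padding = [0] * (padding_lengths["num_tags"] - len(self.indexed_tags))
        return torch.LongTensor(self.indexed_tags + padding)

    def empty_field(self) -> "TagsField":
        return TagsField([])

    def __len__(self):
        return len(self.tags)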
combo/data/fields/sequence_field.py (new file)

from combo.data.fields.field import DataArray, Field

"""
Adapted from https://github.com/allenai/allennlp/blob/main/allennlp/data/fields/sequence_field.py
"""


class SequenceField(Field[DataArray]):
    """
    A `SequenceField` represents a sequence of things. This class just adds a method onto
    `Field`: :func:`sequence_length`. It exists so that `SequenceLabelField`, `IndexField` and other
    similar `Fields` can have a single type to require, with a consistent API, whether they are
    pointing to words in a `TextField`, items in a `ListField`, or something else.
    """

    __slots__ = []  # type: ignore

    def sequence_length(self) -> int:
        """
        How many elements are there in this sequence?
        """
        raise NotImplementedError

    def empty_field(self) -> "SequenceField":
        raise NotImplementedError
"""Sequence multilabel field implementation."""
"""
Sequence multilabel field implementation
Adapted from original COMBO.
Author: Mateusz Klimaszewski
"""
import logging
import textwrap
from typing import Set, List, Callable, Iterator, Union, Dict
import torch
from overrides import overrides
from combo.data import Vocabulary
from combo.data.fields import Field
from combo.data.fields.sequence_field import SequenceField
from combo.utils import ConfigurationError
logger = logging.getLogger(__name__)
class SequenceMultiLabelField(Field):
"""
Adapted from original COMBO
Author: Mateusz Klimaszewski
"""
class SequenceMultiLabelField(Field[torch.Tensor]):
"""
A `SequenceMultiLabelField` is an extension of the :class:`MultiLabelField` that allows for multiple labels
while keeping sequence dimension.
......@@ -42,9 +57,9 @@ class SequenceMultiLabelField(Field):
def __init__(
self,
multi_labels: List[List[str]],
multi_label_indexer: Callable[[data.Vocabulary], Callable[[List[str], int], List[int]]],
multi_label_indexer: Callable[[Vocabulary], Callable[[List[str], int], List[int]]],
as_tensor: Callable[["SequenceMultiLabelField"], Callable[[Dict[str, int]], torch.Tensor]],
sequence_field: fields.SequenceField,
sequence_field: SequenceField,
label_namespace: str = "labels",
) -> None:
self.multi_labels = multi_labels
......@@ -55,13 +70,13 @@ class SequenceMultiLabelField(Field):
self._maybe_warn_for_namespace(label_namespace)
self.as_tensor_wrapper = as_tensor(self)
if len(multi_labels) != sequence_field.sequence_length():
raise checks.ConfigurationError(
raise ConfigurationError(
"Label length and sequence length "
"don't match: %d and %d" % (len(multi_labels), sequence_field.sequence_length())
)
if not all([isinstance(x, str) for multi_label in multi_labels for x in multi_label]):
raise checks.ConfigurationError(
raise ConfigurationError(
"SequenceMultiLabelField must be passed either all "
"strings or all ints. Found labels {} with "
"types: {}.".format(multi_labels, [type(x) for multi_label in multi_labels for x in multi_label])
......@@ -91,23 +106,43 @@ class SequenceMultiLabelField(Field):
@overrides
def count_vocab_items(self, counter: Dict[str, Dict[str, int]]):
pass
if self._indexed_multi_labels is None:
for multi_label in self.multi_labels:
for label in multi_label:
counter[self._label_namespace][label] += 1 # type: ignore
@overrides
def index(self, vocab: data.Vocabulary):
pass
def index(self, vocab: Vocabulary):
indexer = self.multi_label_indexer(vocab)
indexed = []
for multi_label in self.multi_labels:
indexed.append(indexer(multi_label, len(self.multi_labels)))
self._indexed_multi_labels = indexed
@overrides
def get_padding_lengths(self) -> Dict[str, int]:
pass
return {"num_tokens": self.sequence_field.sequence_length()}
@overrides
def as_tensor(self, padding_lengths: Dict[str, int]) -> torch.Tensor:
pass
return self.as_tensor_wrapper(padding_lengths)
@overrides
def empty_field(self) -> "SequenceMultiLabelField":
pass
empty_list: List[List[str]] = [[]]
sequence_label_field = SequenceMultiLabelField(empty_list, lambda x: lambda y: y,
lambda x: lambda y: y,
self.sequence_field.empty_field())
sequence_label_field._indexed_labels = empty_list
return sequence_label_field
def __str__(self) -> str:
pass
length = self.sequence_field.sequence_length()
formatted_labels = "".join(
"\t\t" + labels + "\n" for labels in textwrap.wrap(repr(self.multi_labels), 100)
)
return (
f"SequenceMultiLabelField of length {length} with "
f"labels:\n {formatted_labels} \t\tin namespace: '{self._label_namespace}'."
)
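
A short sketch of the vocabulary-counting step, which the unit tests further down do not exercise. It is illustrative only and not part of the commit: `_FixedLengthSequence`, the label values, and the "upos_labels" namespace are invented, and a `defaultdict` stands in for the counter that the data pipeline would normally supply.

from collections import defaultdict

class _FixedLengthSequence(SequenceField):
    # Minimal stand-in sequence: it only needs to report its length.
    def __init__(self, length: int):
        self.length = length

    def sequence_length(self) -> int:
        return self.length

counter = defaultdict(lambda: defaultdict(int))
field = SequenceMultiLabelField(
    multi_labels=[["NOUN"], ["VERB", "AUX"]],
    multi_label_indexer=lambda vocab: lambda labels, _: [],   # placeholder indexer
    as_tensor=lambda field: lambda lengths: torch.empty(0),   # placeholder tensorizer
    sequence_field=_FixedLengthSequence(2),
    label_namespace="upos_labels",
)
field.count_vocab_items(counter)
# counter["upos_labels"] now holds {"NOUN": 1, "VERB": 1, "AUX": 1}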
combo/utils/__init__.py

from .checks import *
from .sequence import *
"""
Adapted from AllenNLP
https://github.com/allenai/allennlp/blob/main/allennlp/common/util.py
"""
from typing import Any, Callable, List, Sequence
def pad_sequence_to_length(
sequence: Sequence,
desired_length: int,
default_value: Callable[[], Any] = lambda: 0,
padding_on_right: bool = True,
) -> List:
"""
Take a list of objects and pads it to the desired length, returning the padded list. The
original list is not modified.
# Parameters
sequence : `List`
A list of objects to be padded.
desired_length : `int`
Maximum length of each sequence. Longer sequences are truncated to this length, and
shorter ones are padded to it.
default_value: `Callable`, optional (default=`lambda: 0`)
Callable that outputs a default value (of any type) to use as padding values. This is
a lambda to avoid using the same object when the default value is more complex, like a
list.
padding_on_right : `bool`, optional (default=`True`)
When we add padding tokens (or truncate the sequence), should we do it on the right or
the left?
# Returns
padded_sequence : `List`
"""
sequence = list(sequence)
# Truncates the sequence to the desired length.
if padding_on_right:
padded_sequence = sequence[:desired_length]
else:
padded_sequence = sequence[-desired_length:]
# Continues to pad with default_value() until we reach the desired length.
pad_length = desired_length - len(padded_sequence)
# This just creates the default value once, so if it's a list, and if it gets mutated
# later, it could cause subtle bugs. But the risk there is low, and this is much faster.
values_to_pad = [default_value()] * pad_length
if padding_on_right:
padded_sequence = padded_sequence + values_to_pad
else:
padded_sequence = values_to_pad + padded_sequence
return padded_sequence
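
Since `pad_sequence_to_length` both truncates and pads, a few illustrative calls (toy values, not taken from the repository) show the behaviour:

# Right-padding with the default value:
pad_sequence_to_length([1, 2, 3], 5)             # -> [1, 2, 3, 0, 0]
# Truncation keeps the leftmost items when padding_on_right=True:
pad_sequence_to_length([1, 2, 3, 4, 5, 6], 4)    # -> [1, 2, 3, 4]
# Left-padding; default_value is called once and its result repeated:
pad_sequence_to_length([[1], [2]], 4, default_value=lambda: [], padding_on_right=False)
# -> [[], [], [1], [2]]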
"""
Sequence multilabel field tests.
Adapted from original COMBO
Author: Mateusz Klimaszewski
"""
import unittest
from typing import List
import torch
from combo.data.fields.sequence_field import SequenceField
from combo.utils import pad_sequence_to_length
from combo.data import fields, Vocabulary
class IndexingSequenceMultiLabelFieldTest(unittest.TestCase):
def setUp(self) -> None:
self.namespace = "test_labels"
self.vocab = Vocabulary()
self.vocab.add_tokens_to_namespace(
tokens=["t" + str(idx) for idx in range(3)],
namespace=self.namespace
)
def _indexer(vocab: Vocabulary):
vocab_size = vocab.get_vocab_size(self.namespace)
def _mapper(multi_label: List[str], _: int) -> List[int]:
one_hot = [0] * vocab_size
for label in multi_label:
index = vocab.get_token_index(label, self.namespace)
one_hot[index] = 1
return one_hot
return _mapper
def _as_tensor(field: fields.SequenceMultiLabelField):
def _wrapped(padding_lengths):
desired_num_tokens = padding_lengths["num_tokens"]
classes_count = len(field._indexed_multi_labels[0])
default_value = [0.0] * classes_count
padded_tags = pad_sequence_to_length(field._indexed_multi_labels, desired_num_tokens,
lambda: default_value)
tensor = torch.LongTensor(padded_tags)
return tensor
return _wrapped
self.indexer = _indexer
self.as_tensor = _as_tensor
self.sequence_field = _SequenceFieldTestWrapper(self.vocab.get_vocab_size(self.namespace))
def test_indexing(self):
# given
field = fields.SequenceMultiLabelField(
multi_labels=[["t1", "t2"], [], ["t0"]],
multi_label_indexer=self.indexer,
as_tensor=self.as_tensor,
sequence_field=self.sequence_field,
label_namespace=self.namespace
)
expected = [[0, 1, 1], [0, 0, 0], [1, 0, 0]]
# when
field.index(self.vocab)
# then
self.assertEqual(expected, field._indexed_multi_labels)
def test_mapping_to_tensor(self):
# given
field = fields.SequenceMultiLabelField(
multi_labels=[["t1", "t2"], [], ["t0"]],
multi_label_indexer=self.indexer,
as_tensor=self.as_tensor,
sequence_field=self.sequence_field,
label_namespace=self.namespace
)
field.index(self.vocab)
expected = torch.tensor([[0, 1, 1], [0, 0, 0], [1, 0, 0]])
# when
actual = field.as_tensor(field.get_padding_lengths())
# then
self.assertTrue(torch.all(expected.eq(actual)))
def test_sequence_method(self):
# given
field = fields.SequenceMultiLabelField(
multi_labels=[["t1", "t2"], [], ["t0"]],
multi_label_indexer=self.indexer,
as_tensor=self.as_tensor,
sequence_field=self.sequence_field,
label_namespace=self.namespace
)
# when
length = len(field)
iter_length = len(list(iter(field)))
middle_value = field[1]
# then
self.assertEqual(3, length)
self.assertEqual(3, iter_length)
self.assertEqual([], middle_value)
class _SequenceFieldTestWrapper(SequenceField):
def __init__(self, length: int):
self.length = length
def sequence_length(self) -> int:
return self.length