Skip to content
Snippets Groups Projects
Commit 497db05f authored by Maja Jablonska's avatar Maja Jablonska
Browse files

Add ListField

parent 7a2f7973
1 merge request!46Merge COMBO 3.0 into master
from .adjacency_field import AdjacencyField
from .field import Field
from .label_field import LabelField
from .list_field import ListField
from .metadata_field import MetadataField
from .sequence_field import SequenceField
from .sequence_label_field import SequenceLabelField
from .sequence_multilabel_field import SequenceMultiLabelField
from .text_field import TextField
"""
Adapted from AllenNLP
https://github.com/allenai/allennlp/blob/main/allennlp/data/fields/list_field.py
"""
from typing import Dict, List, Iterator, Sequence, Any
from combo.data import Vocabulary
from combo.data.fields.field import DataArray, Field
from combo.data.fields.sequence_field import SequenceField
from combo.utils import pad_sequence_to_length
class ListField(SequenceField[DataArray]):
"""
A `ListField` is a list of other fields. You would use this to represent, e.g., a list of
answer options that are themselves `TextFields`.
This field will get converted into a tensor that has one more mode than the items in the list.
If this is a list of `TextFields` that have shape (num_words, num_characters), this
`ListField` will output a tensor of shape (num_sentences, num_words, num_characters).
# Parameters
field_list : `List[Field]`
A list of `Field` objects to be concatenated into a single input tensor. All of the
contained `Field` objects must be of the same type.
"""
__slots__ = ["field_list"]
def __init__(self, field_list: Sequence[Field]) -> None:
field_class_set = {field.__class__ for field in field_list}
assert (
len(field_class_set) == 1
), "ListFields must contain a single field type, found " + str(field_class_set)
# Not sure why mypy has a hard time with this type...
self.field_list = field_list
# Sequence[Field] methods
def __iter__(self) -> Iterator[Field]:
return iter(self.field_list)
def __getitem__(self, idx: int) -> Field:
return self.field_list[idx]
def __len__(self) -> int:
return len(self.field_list)
def count_vocab_items(self, counter: Dict[str, Dict[str, int]]):
for field in self.field_list:
field.count_vocab_items(counter)
def index(self, vocab: Vocabulary):
for field in self.field_list:
field.index(vocab)
def get_padding_lengths(self) -> Dict[str, int]:
field_lengths = [field.get_padding_lengths() for field in self.field_list]
padding_lengths = {"num_fields": len(self.field_list)}
# We take the set of all possible padding keys for all fields, rather
# than just a random key, because it is possible for fields to be empty
# when we pad ListFields.
possible_padding_keys = [
key for field_length in field_lengths for key in list(field_length.keys())
]
for key in set(possible_padding_keys):
# In order to be able to nest ListFields, we need to scope the padding length keys
# appropriately, so that nested ListFields don't all use the same "num_fields" key. So
# when we construct the dictionary from the list of fields, we add something to the
# name, and we remove it when padding the list of fields.
padding_lengths["list_" + key] = max(x[key] if key in x else 0 for x in field_lengths)
# Set minimum padding length to handle empty list fields.
for padding_key in padding_lengths:
padding_lengths[padding_key] = max(padding_lengths[padding_key], 1)
return padding_lengths
def sequence_length(self) -> int:
return len(self.field_list)
def as_tensor(self, padding_lengths: Dict[str, int]) -> DataArray:
padded_field_list = pad_sequence_to_length(
self.field_list, padding_lengths["num_fields"], self.field_list[0].empty_field
)
# Here we're removing the scoping on the padding length keys that we added in
# `get_padding_lengths`; see the note there for more detail.
child_padding_lengths = {
key.replace("list_", "", 1): value
for key, value in padding_lengths.items()
if key.startswith("list_")
}
padded_fields = [field.as_tensor(child_padding_lengths) for field in padded_field_list]
return self.field_list[0].batch_tensors(padded_fields)
def empty_field(self):
# Our "empty" list field will actually have a single field in the list, so that we can
# correctly construct nested lists. For example, if we have a type that is
# `ListField[ListField[LabelField]]`, we need the top-level `ListField` to know to
# construct a `ListField[LabelField]` when it's padding, and the nested `ListField` needs
# to know that it's empty objects are `LabelFields`. Having an "empty" list actually have
# length one makes this all work out, and we'll always be padding to at least length 1,
# anyway.
return ListField([self.field_list[0].empty_field()])
def batch_tensors(self, tensor_list: List[DataArray]) -> DataArray:
# We defer to the class we're wrapping in a list to handle the batching.
return self.field_list[0].batch_tensors(tensor_list)
def __str__(self) -> str:
field_class = self.field_list[0].__class__.__name__
base_string = f"ListField of {len(self.field_list)} {field_class}s : \n"
return " ".join([base_string] + [f"\t {field} \n" for field in self.field_list])
def human_readable_repr(self) -> List[Any]:
return [f.human_readable_repr() for f in self.field_list]
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment