diff --git a/combo/data/__init__.py b/combo/data/__init__.py
index 42bda68428ff8ef536900af16950910353e1f360..8a64dafc5fd338069dc66102a2347c9333800c47 100644
--- a/combo/data/__init__.py
+++ b/combo/data/__init__.py
@@ -1,2 +1,4 @@
 from .api import Token
 from .vocabulary import Vocabulary
+from .samplers import TokenCountBatchSampler
+from .instance import Instance
\ No newline at end of file
diff --git a/combo/data/instance.py b/combo/data/instance.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4c008a0bf1b359cb14af83938b7d2e0d8aa0c44
--- /dev/null
+++ b/combo/data/instance.py
@@ -0,0 +1,124 @@
+"""
+Adapted from AllenNLP
+https://github.com/allenai/allennlp/blob/main/allennlp/data/instance.py
+"""
+
+from typing import Dict, MutableMapping, Mapping, Any
+from combo.data import Vocabulary
+from combo.data.fields import Field
+from combo.data.fields.field import DataArray
+
+
+JsonDict = Dict[str, Any]
+
+
+class Instance(Mapping[str, Field]):
+    """
+    An `Instance` is a collection of :class:`~allennlp.data.fields.field.Field` objects,
+    specifying the inputs and outputs to
+    some model. We don't make a distinction between inputs and outputs here, though - all
+    operations are done on all fields, and when we return arrays, we return them as dictionaries
+    keyed by field name. A model can then decide which fields it wants to use as inputs and which
+    as outputs.
+    The `Fields` in an `Instance` can start out either indexed or un-indexed. During the data
+    processing pipeline, all fields will be indexed, after which multiple instances can be combined
+    into a `Batch` and then converted into padded arrays.
+    # Parameters
+    fields : `Dict[str, Field]`
+        The `Field` objects that will be used to produce data arrays for this instance.
+    """
+
+    __slots__ = ["fields", "indexed"]
+
+    def __init__(self, fields: MutableMapping[str, Field]) -> None:
+        self.fields = fields
+        self.indexed = False
+
+    # Add methods for `Mapping`. Note, even though the fields are
+    # mutable, we don't implement `MutableMapping` because we want
+    # you to use `add_field` and supply a vocabulary.
+    def __getitem__(self, key: str) -> Field:
+        return self.fields[key]
+
+    def __iter__(self):
+        return iter(self.fields)
+
+    def __len__(self) -> int:
+        return len(self.fields)
+
+    def add_field(self, field_name: str, field: Field, vocab: Vocabulary = None) -> None:
+        """
+        Add the field to the existing fields mapping.
+        If we have already indexed the Instance, then we also index `field`, so
+        it is necessary to supply the vocab.
+        """
+        self.fields[field_name] = field
+        if self.indexed and vocab is not None:
+            field.index(vocab)
+
+    def count_vocab_items(self, counter: Dict[str, Dict[str, int]]):
+        """
+        Increments counts in the given `counter` for all of the vocabulary items in all of the
+        `Fields` in this `Instance`.
+        """
+        for field in self.fields.values():
+            field.count_vocab_items(counter)
+
+    def index_fields(self, vocab: Vocabulary) -> None:
+        """
+        Indexes all fields in this `Instance` using the provided `Vocabulary`.
+        This `mutates` the current object; it does not return a new `Instance`.
+        A `DataLoader` will call this on each pass through a dataset; we use the `indexed`
+        flag to make sure that indexing only happens once.
+        This means that if for some reason you modify your vocabulary after you've
+        indexed your instances, you might get unexpected behavior.
+        """
+        if not self.indexed:
+            for field in self.fields.values():
+                field.index(vocab)
+            self.indexed = True
+
+    def get_padding_lengths(self) -> Dict[str, Dict[str, int]]:
+        """
+        Returns a dictionary of padding lengths, keyed by field name. Each `Field` returns a
+        mapping from padding keys to actual lengths, and we just key that dictionary by field name.
+        """
+        lengths = {}
+        for field_name, field in self.fields.items():
+            lengths[field_name] = field.get_padding_lengths()
+        return lengths
+
+    def as_tensor_dict(
+        self, padding_lengths: Dict[str, Dict[str, int]] = None
+    ) -> Dict[str, DataArray]:
+        """
+        Pads each `Field` in this instance to the lengths given in `padding_lengths` (which is
+        keyed by field name, then by padding key, the same as the return value in
+        :func:`get_padding_lengths`), returning a dictionary of torch tensors keyed by field name.
+        If `padding_lengths` is omitted, we will call `self.get_padding_lengths()` to get the
+        sizes of the tensors to create.
+        """
+        padding_lengths = padding_lengths or self.get_padding_lengths()
+        tensors = {}
+        for field_name, field in self.fields.items():
+            tensors[field_name] = field.as_tensor(padding_lengths[field_name])
+        return tensors
+
+    def __str__(self) -> str:
+        base_string = "Instance with fields:\n"
+        return " ".join(
+            [base_string] + [f"\t {name}: {field} \n" for name, field in self.fields.items()]
+        )
+
+    def duplicate(self) -> "Instance":
+        new = Instance({k: field.duplicate() for k, field in self.fields.items()})
+        new.indexed = self.indexed
+        return new
+
+    def human_readable_dict(self) -> JsonDict:
+        """
+        This function helps output instances to JSON files or print them for human readability.
+        A use case is example-based explanation, where it's better to have an output file
+        rather than printing or logging.
+        """
+        return {key: field.human_readable_repr() for key, field in self.fields.items()}
\ No newline at end of file