Commit 62ff9bbf authored by Maja Jabłońska

Add TextField from AllenNLP

parent ebd2bd32
1 merge request: !46 Merge COMBO 3.0 into master
"""
Adapted from AllenNLP.
https://github.com/allenai/allennlp/blob/main/allennlp/data/fields/text_field.py
A `TextField` represents a string of text, the kind that you might want to represent with
standard word vectors, or pass through an LSTM.
"""
from collections import defaultdict
from copy import deepcopy
from typing import Dict, List, Optional, Iterator
import textwrap
from spacy.tokens import Token as SpacyToken
import torch
from combo.data import Vocabulary
from combo.data.fields.sequence_field import SequenceField
from combo.data.token_indexers import TokenIndexer, IndexedTokenList
from combo.data.tokenizers import TokenizerToken
from combo.utils import ConfigurationError

# There are two levels of dictionaries here: the top level is for the *key*, which aligns
# TokenIndexers with their corresponding TokenEmbedders. The bottom level is for the *objects*
# produced by a given TokenIndexer, which will be input to a particular TokenEmbedder's forward()
# method. We label these as tensors, because that's what they typically are, though they could in
# reality have arbitrary type.
TextFieldTensors = Dict[str, Dict[str, torch.Tensor]]
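
# Illustrative sketch (not used anywhere in this file): with a single-ID style indexer
# registered under the name "tokens", one indexed instance typically ends up as something like
#     {"tokens": {"tokens": torch.LongTensor of shape (num_tokens,)}}
# and, after batching, the inner tensors gain a leading batch dimension.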

def batch_tensor_dicts(
    tensor_dicts: List[Dict[str, torch.Tensor]], remove_trailing_dimension: bool = False
) -> Dict[str, torch.Tensor]:
    """
    Takes a list of tensor dictionaries, where each dictionary is assumed to have matching keys,
    and returns a single dictionary with all tensors with the same key batched together.

    # Parameters

    tensor_dicts : `List[Dict[str, torch.Tensor]]`
        The list of tensor dictionaries to batch.
    remove_trailing_dimension : `bool`
        If `True`, we will check for a trailing dimension of size 1 on the tensors that are being
        batched, and remove it if we find it.
    """
    key_to_tensors: Dict[str, List[torch.Tensor]] = defaultdict(list)
    for tensor_dict in tensor_dicts:
        for key, tensor in tensor_dict.items():
            key_to_tensors[key].append(tensor)
    batched_tensors = {}
    for key, tensor_list in key_to_tensors.items():
        batched_tensor = torch.stack(tensor_list)
        if remove_trailing_dimension and all(tensor.size(-1) == 1 for tensor in tensor_list):
            batched_tensor = batched_tensor.squeeze(-1)
        batched_tensors[key] = batched_tensor
    return batched_tensors
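
# A minimal usage sketch for `batch_tensor_dicts` (illustrative only; the tensors are
# placeholders and every dictionary is assumed to share the same keys and shapes):
#
#     dicts = [{"tokens": torch.tensor([1, 2, 3])}, {"tokens": torch.tensor([4, 5, 6])}]
#     batched = batch_tensor_dicts(dicts)
#     batched["tokens"].shape  # torch.Size([2, 3]) -- per-instance tensors stacked along
#                              # a new leading batch dimension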

class TextField(SequenceField[TextFieldTensors]):
    """
    This `Field` represents a list of string tokens. Before constructing this object, you need
    to tokenize raw strings using a :class:`~allennlp.data.tokenizers.tokenizer.Tokenizer`.

    Because string tokens can be represented as indexed arrays in a number of ways, we also take a
    dictionary of :class:`~allennlp.data.token_indexers.token_indexer.TokenIndexer`
    objects that will be used to convert the tokens into indices.
    Each `TokenIndexer` could represent each token as a single ID, or a list of character IDs, or
    something else.

    This field will get converted into a dictionary of arrays, one for each `TokenIndexer`. A
    `SingleIdTokenIndexer` produces an array of shape (num_tokens,), while a
    `TokenCharactersIndexer` produces an array of shape (num_tokens, num_characters).
    """

    __slots__ = ["tokens", "_token_indexers", "_indexed_tokens"]

    def __init__(
        self, tokens: List[TokenizerToken], token_indexers: Optional[Dict[str, TokenIndexer]] = None
    ) -> None:
        self.tokens = tokens
        self._token_indexers = token_indexers
        self._indexed_tokens: Optional[Dict[str, IndexedTokenList]] = None

        if not all(isinstance(x, (TokenizerToken, SpacyToken)) for x in tokens):
            raise ConfigurationError(
                "TextFields must be passed Tokens. "
                "Found: {} with types {}.".format(tokens, [type(x) for x in tokens])
            )

    @property
    def token_indexers(self) -> Dict[str, TokenIndexer]:
        if self._token_indexers is None:
            raise ValueError(
                "TextField's token_indexers have not been set.\n"
                "Did you forget to call DatasetReader.apply_token_indexers(instance) "
                "on your instance?\n"
                "If apply_token_indexers() is being called but "
                "you're still seeing this error, it may not be implemented correctly."
            )
        return self._token_indexers

    @token_indexers.setter
    def token_indexers(self, token_indexers: Dict[str, TokenIndexer]) -> None:
        self._token_indexers = token_indexers

    def count_vocab_items(self, counter: Dict[str, Dict[str, int]]):
        for indexer in self.token_indexers.values():
            for token in self.tokens:
                indexer.count_vocab_items(token, counter)

    def index(self, vocab: Vocabulary):
        self._indexed_tokens = {}
        for indexer_name, indexer in self.token_indexers.items():
            self._indexed_tokens[indexer_name] = indexer.tokens_to_indices(self.tokens, vocab)

    def get_padding_lengths(self) -> Dict[str, int]:
        """
        The `TextField` has a list of `Tokens`, and each `Token` gets converted into arrays by
        (potentially) several `TokenIndexers`. This method gets the max length (over tokens)
        associated with each of these arrays.
        """
        if self._indexed_tokens is None:
            raise ConfigurationError(
                "You must call .index(vocabulary) on a field before determining padding lengths."
            )
        padding_lengths = {}
        for indexer_name, indexer in self.token_indexers.items():
            indexer_lengths = indexer.get_padding_lengths(self._indexed_tokens[indexer_name])
            for key, length in indexer_lengths.items():
                padding_lengths[f"{indexer_name}___{key}"] = length
        return padding_lengths
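
    # Illustrative sketch: with a single-ID style indexer registered as "tokens", a seven-token
    # sentence would typically yield {"tokens___tokens": 7}. The exact keys depend on the
    # indexers in use; `as_tensor` below splits them back apart on the "___" separator.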

    def sequence_length(self) -> int:
        return len(self.tokens)

    def as_tensor(self, padding_lengths: Dict[str, int]) -> TextFieldTensors:
        if self._indexed_tokens is None:
            raise ConfigurationError(
                "You must call .index(vocabulary) on a field before calling .as_tensor()"
            )
        tensors = {}
        indexer_lengths: Dict[str, Dict[str, int]] = defaultdict(dict)
        for key, value in padding_lengths.items():
            # We want this to crash if the split fails. Should never happen, so I'm not
            # putting in a check, but if you fail on this line, open a github issue.
            indexer_name, padding_key = key.split("___")
            indexer_lengths[indexer_name][padding_key] = value
        for indexer_name, indexer in self.token_indexers.items():
            tensors[indexer_name] = indexer.as_padded_tensor_dict(
                self._indexed_tokens[indexer_name], indexer_lengths[indexer_name]
            )
        return tensors

    def empty_field(self):
        text_field = TextField([], self._token_indexers)
        text_field._indexed_tokens = {}
        if self._token_indexers is not None:
            for indexer_name, indexer in self.token_indexers.items():
                text_field._indexed_tokens[indexer_name] = indexer.get_empty_token_list()
        return text_field

    def batch_tensors(self, tensor_list: List[TextFieldTensors]) -> TextFieldTensors:
        # This is creating a dict of {token_indexer_name: {token_indexer_outputs: batched_tensor}}
        # for each token indexer used to index this field.
        indexer_lists: Dict[str, List[Dict[str, torch.Tensor]]] = defaultdict(list)
        for tensor_dict in tensor_list:
            for indexer_name, indexer_output in tensor_dict.items():
                indexer_lists[indexer_name].append(indexer_output)
        batched_tensors = {
            # NOTE(mattg): if an indexer has its own nested structure, rather than one tensor per
            # argument, then this will break. If that ever happens, we should move this to an
            # `indexer.batch_tensors` method, with this logic as the default implementation in the
            # base class.
            indexer_name: batch_tensor_dicts(indexer_outputs)
            for indexer_name, indexer_outputs in indexer_lists.items()
        }
        return batched_tensors

    def __str__(self) -> str:
        # Double tab to indent under the header.
        formatted_text = "".join(
            "\t\t" + text + "\n" for text in textwrap.wrap(repr(self.tokens), 100)
        )
        if self._token_indexers is not None:
            indexers = {
                name: indexer.__class__.__name__ for name, indexer in self._token_indexers.items()
            }
            return (
                f"TextField of length {self.sequence_length()} with "
                f"text: \n {formatted_text} \t\tand TokenIndexers : {indexers}"
            )
        else:
            return f"TextField of length {self.sequence_length()} with text: \n {formatted_text}"

    # Sequence[Token] methods
    def __iter__(self) -> Iterator[TokenizerToken]:
        return iter(self.tokens)

    def __getitem__(self, idx: int) -> TokenizerToken:
        return self.tokens[idx]

    def __len__(self) -> int:
        return len(self.tokens)

    def duplicate(self):
        """
        Overrides the behavior of `duplicate` so that `self._token_indexers` won't
        actually be deep-copied.

        Not only would it be extremely inefficient to deep-copy the token indexers,
        but it also fails in many cases since some tokenizers (like those used in
        the 'transformers' lib) cannot actually be deep-copied.
        """
        if self._token_indexers is not None:
            new = TextField(deepcopy(self.tokens), {k: v for k, v in self._token_indexers.items()})
        else:
            new = TextField(deepcopy(self.tokens))
        new._indexed_tokens = deepcopy(self._indexed_tokens)
        return new

    def human_readable_repr(self) -> List[str]:
        return [str(t) for t in self.tokens]
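
# A rough end-to-end sketch of how this field is typically used (hypothetical names: it assumes
# a single-id style indexer class and an already-built `vocab` Vocabulary from the surrounding
# pipeline; neither is defined in this file):
#
#     tokens = [TokenizerToken("a"), TokenizerToken("sentence")]
#     field = TextField(tokens, {"tokens": SingleIdTokenIndexer()})
#     field.count_vocab_items(counter)        # optional: accumulate counts for vocab building
#     field.index(vocab)                      # look up token ids in the vocabulary
#     lengths = field.get_padding_lengths()   # e.g. {"tokens___tokens": 2}
#     tensors = field.as_tensor(lengths)      # {"tokens": {"tokens": tensor([...])}}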