Skip to content
Snippets Groups Projects
Commit f4eea1c8 authored by Maja Jabłońska's avatar Maja Jabłońska
Browse files

Add AllenNLP fields: AdjacencyField and SequenceLabelField

parent b74a0ebd
Branches
Tags
1 merge request: !46 "Merge COMBO 3.0 into master"
"""
Adapted from AllenNLP
https://github.com/allenai/allennlp/blob/main/allennlp/data/fields/adjacency_field.py
"""
from typing import Dict, List, Set, Tuple, Optional
import logging
import textwrap
import torch
from combo.data import Vocabulary
from combo.data.fields import Field
from combo.data.fields.sequence_field import SequenceField
from combo.utils import ConfigurationError
logger = logging.getLogger(__name__)
class AdjacencyField(Field[torch.Tensor]):
    """
    An `AdjacencyField` defines directed adjacency relations between elements
    in a :class:`~allennlp.data.fields.sequence_field.SequenceField`.

    Because it's a labeling of some other field, we take that field as input here
    and use it to determine our padding and other things.

    This field will get converted into an array of shape
    (sequence_field_length, sequence_field_length), where the (i, j)th array element
    is either a binary flag indicating there is an edge from i to j, or an integer
    label k, indicating there is a label from i to j of type k.

    # Parameters

    indices : `List[Tuple[int, int]]`
        The (i, j) positions of the edges to mark in the adjacency matrix.
        Indices must be unique and lie within the sequence length.
    sequence_field : `SequenceField`
        A field containing the sequence that this `AdjacencyField` is labeling. Most often,
        this is a `TextField`, for tagging edge relations between tokens in a sentence.
    labels : `Optional[List[str]]`, optional, (default = `None`)
        Optional labels for the edges of the adjacency matrix. If given, must be
        the same length as `indices`.
    label_namespace : `str`, optional (default=`'labels'`)
        The namespace to use for converting tag strings into integers. We convert tag strings to
        integers for you, and this parameter tells the `Vocabulary` object which mapping from
        strings to integers to use (so that "O" as a tag doesn't get the same id as "O" as a word).
    padding_value : `int`, optional (default = `-1`)
        The value to use as padding.
    """

    __slots__ = [
        "indices",
        "labels",
        "sequence_field",
        "_label_namespace",
        "_padding_value",
        "_indexed_labels",
    ]

    # It is possible that users want to use this field with a namespace which uses OOV/PAD tokens.
    # This warning will be repeated for every instantiation of this class (i.e. for every data
    # instance), spewing a lot of warnings, so this class variable is used to only log a single
    # warning per namespace.
    _already_warned_namespaces: Set[str] = set()

    def __init__(
        self,
        indices: List[Tuple[int, int]],
        sequence_field: SequenceField,
        labels: Optional[List[str]] = None,
        label_namespace: str = "labels",
        padding_value: int = -1,
    ) -> None:
        self.indices = indices
        self.labels = labels
        self.sequence_field = sequence_field
        self._label_namespace = label_namespace
        self._padding_value = padding_value
        # Filled in by `index()`; stays None for unlabeled (purely binary) edges.
        self._indexed_labels: Optional[List[int]] = None

        self._maybe_warn_for_namespace(label_namespace)
        field_length = sequence_field.sequence_length()

        if len(set(indices)) != len(indices):
            raise ConfigurationError(f"Indices must be unique, but found {indices}")

        if not all(
            0 <= index[1] < field_length and 0 <= index[0] < field_length for index in indices
        ):
            raise ConfigurationError(
                f"Label indices and sequence length "
                f"are incompatible: {indices} and {field_length}"
            )

        if labels is not None and len(indices) != len(labels):
            raise ConfigurationError(
                f"Labelled indices were passed, but their lengths do not match: "
                f" {labels}, {indices}"
            )

    def _maybe_warn_for_namespace(self, label_namespace: str) -> None:
        # Warn (once per namespace) when the namespace will get OOV/PAD tokens by default.
        if not (self._label_namespace.endswith("labels") or self._label_namespace.endswith("tags")):
            if label_namespace not in self._already_warned_namespaces:
                logger.warning(
                    "Your label namespace was '%s'. We recommend you use a namespace "
                    "ending with 'labels' or 'tags', so we don't add UNK and PAD tokens by "
                    "default to your vocabulary. See documentation for "
                    "`non_padded_namespaces` parameter in Vocabulary.",
                    self._label_namespace,
                )
                self._already_warned_namespaces.add(label_namespace)

    def count_vocab_items(self, counter: Dict[str, Dict[str, int]]):
        """Count label occurrences so the `Vocabulary` can build the namespace mapping."""
        if self._indexed_labels is None and self.labels is not None:
            for label in self.labels:
                counter[self._label_namespace][label] += 1  # type: ignore

    def index(self, vocab: Vocabulary):
        """Convert string labels into integer ids using `vocab`."""
        if self.labels is not None:
            self._indexed_labels = [
                vocab.get_token_index(label, self._label_namespace) for label in self.labels
            ]

    def get_padding_lengths(self) -> Dict[str, int]:
        return {"num_tokens": self.sequence_field.sequence_length()}

    def as_tensor(self, padding_lengths: Dict[str, int]) -> torch.Tensor:
        """
        Build the (num_tokens, num_tokens) adjacency tensor, filled with the
        padding value everywhere except at `indices`, which receive either the
        indexed label id or a binary `1` when no labels were supplied.
        """
        desired_num_tokens = padding_lengths["num_tokens"]
        tensor = torch.ones(desired_num_tokens, desired_num_tokens) * self._padding_value
        # When no labels are present (or `index()` was never called for a
        # label-less field), every listed edge is marked with a binary 1.
        labels = self._indexed_labels or [1 for _ in range(len(self.indices))]

        for index, label in zip(self.indices, labels):
            tensor[index] = label
        return tensor

    def empty_field(self) -> "AdjacencyField":
        # The empty_list here is needed for mypy
        empty_list: List[Tuple[int, int]] = []
        # Preserve the label namespace so the empty field stays consistent with
        # the field it was derived from (previously it reverted to the default
        # 'labels' namespace).
        adjacency_field = AdjacencyField(
            empty_list,
            self.sequence_field.empty_field(),
            label_namespace=self._label_namespace,
            padding_value=self._padding_value,
        )
        return adjacency_field

    def __str__(self) -> str:
        length = self.sequence_field.sequence_length()
        formatted_labels = "".join(
            "\t\t" + labels + "\n" for labels in textwrap.wrap(repr(self.labels), 100)
        )
        formatted_indices = "".join(
            "\t\t" + index + "\n" for index in textwrap.wrap(repr(self.indices), 100)
        )
        return (
            f"AdjacencyField of length {length}\n"
            f"\t\twith indices:\n {formatted_indices}\n"
            f"\t\tand labels:\n {formatted_labels} \t\tin namespace: '{self._label_namespace}'."
        )

    def __len__(self):
        return len(self.sequence_field)

    def human_readable_repr(self):
        """Return a plain-dict view of the field for human inspection."""
        ret = {"indices": self.indices}
        if self.labels is not None:
            ret["labels"] = self.labels
        return ret
"""
Adapted from AllenNLP
https://github.com/allenai/allennlp/blob/main/allennlp/data/fields/sequence_label_field.py
"""
from typing import Dict, List, Union, Set, Iterator
import logging
import textwrap
import torch
from combo.data import Vocabulary
from combo.data.fields import Field
from combo.data.fields.sequence_field import SequenceField
from combo.utils import ConfigurationError, pad_sequence_to_length
logger = logging.getLogger(__name__)
class SequenceLabelField(Field[torch.Tensor]):
    """
    A `SequenceLabelField` assigns a categorical label to each element in a
    :class:`~allennlp.data.fields.sequence_field.SequenceField`.

    Because it's a labeling of some other field, we take that field as input here, and we use it to
    determine our padding and other things.

    This field will get converted into a list of integer class ids, representing the correct class
    for each element in the sequence.

    # Parameters

    labels : `Union[List[str], List[int]]`
        A sequence of categorical labels, encoded as strings or integers. These could be POS tags like [NN,
        JJ, ...], BIO tags like [B-PERS, I-PERS, O, O, ...], or any other categorical tag sequence. If the
        labels are encoded as integers, they will not be indexed using a vocab.
    sequence_field : `SequenceField`
        A field containing the sequence that this `SequenceLabelField` is labeling. Most often, this is a
        `TextField`, for tagging individual tokens in a sentence.
    label_namespace : `str`, optional (default=`'labels'`)
        The namespace to use for converting tag strings into integers. We convert tag strings to
        integers for you, and this parameter tells the `Vocabulary` object which mapping from
        strings to integers to use (so that "O" as a tag doesn't get the same id as "O" as a word).
    """

    __slots__ = [
        "labels",
        "sequence_field",
        "_label_namespace",
        "_indexed_labels",
        "_skip_indexing",
    ]

    # It is possible that users want to use this field with a namespace which uses OOV/PAD tokens.
    # This warning will be repeated for every instantiation of this class (i.e. for every data
    # instance), spewing a lot of warnings, so this class variable is used to only log a single
    # warning per namespace.
    _already_warned_namespaces: Set[str] = set()

    def __init__(
        self,
        labels: Union[List[str], List[int]],
        sequence_field: SequenceField,
        label_namespace: str = "labels",
    ) -> None:
        self.labels = labels
        self.sequence_field = sequence_field
        self._label_namespace = label_namespace
        # Filled in by `index()` (or immediately below when labels are already ints).
        self._indexed_labels = None
        self._maybe_warn_for_namespace(label_namespace)
        if len(labels) != sequence_field.sequence_length():
            raise ConfigurationError(
                "Label length and sequence length "
                "don't match: %d and %d" % (len(labels), sequence_field.sequence_length())
            )

        self._skip_indexing = False
        if all(isinstance(x, int) for x in labels):
            # Integer labels are treated as already-indexed class ids.
            self._indexed_labels = labels
            self._skip_indexing = True
        elif not all(isinstance(x, str) for x in labels):
            raise ConfigurationError(
                "SequenceLabelFields must be passed either all "
                "strings or all ints. Found labels {} with "
                "types: {}.".format(labels, [type(x) for x in labels])
            )

    def _maybe_warn_for_namespace(self, label_namespace: str) -> None:
        # Warn (once per namespace) when the namespace will get OOV/PAD tokens by default.
        if not (self._label_namespace.endswith("labels") or self._label_namespace.endswith("tags")):
            if label_namespace not in self._already_warned_namespaces:
                logger.warning(
                    "Your label namespace was '%s'. We recommend you use a namespace "
                    "ending with 'labels' or 'tags', so we don't add UNK and PAD tokens by "
                    "default to your vocabulary. See documentation for "
                    "`non_padded_namespaces` parameter in Vocabulary.",
                    self._label_namespace,
                )
                self._already_warned_namespaces.add(label_namespace)

    # Sequence methods
    def __iter__(self) -> Iterator[Union[str, int]]:
        return iter(self.labels)

    def __getitem__(self, idx: int) -> Union[str, int]:
        return self.labels[idx]

    def __len__(self) -> int:
        return len(self.labels)

    def count_vocab_items(self, counter: Dict[str, Dict[str, int]]):
        """Count label occurrences so the `Vocabulary` can build the namespace mapping."""
        if self._indexed_labels is None:
            for label in self.labels:
                counter[self._label_namespace][label] += 1  # type: ignore

    def index(self, vocab: Vocabulary):
        """Convert string labels into integer ids (a no-op for integer labels)."""
        if not self._skip_indexing:
            self._indexed_labels = [
                vocab.get_token_index(label, self._label_namespace)  # type: ignore
                for label in self.labels
            ]

    def get_padding_lengths(self) -> Dict[str, int]:
        return {"num_tokens": self.sequence_field.sequence_length()}

    def as_tensor(self, padding_lengths: Dict[str, int]) -> torch.Tensor:
        """
        Return a `LongTensor` of indexed labels padded to `num_tokens`.

        Raises `ConfigurationError` if string labels were given but `index()`
        has not been called yet.
        """
        if self._indexed_labels is None:
            raise ConfigurationError(
                "You must call .index(vocabulary) on a field before calling .as_tensor()"
            )
        desired_num_tokens = padding_lengths["num_tokens"]
        padded_tags = pad_sequence_to_length(self._indexed_labels, desired_num_tokens)
        tensor = torch.LongTensor(padded_tags)
        return tensor

    def empty_field(self) -> "SequenceLabelField":
        # The empty_list here is needed for mypy
        empty_list: List[str] = []
        # Preserve the label namespace so the empty field stays consistent with
        # the field it was derived from (previously it reverted to the default
        # 'labels' namespace).
        sequence_label_field = SequenceLabelField(
            empty_list,
            self.sequence_field.empty_field(),
            label_namespace=self._label_namespace,
        )
        sequence_label_field._indexed_labels = empty_list
        return sequence_label_field

    def __str__(self) -> str:
        length = self.sequence_field.sequence_length()
        formatted_labels = "".join(
            "\t\t" + labels + "\n" for labels in textwrap.wrap(repr(self.labels), 100)
        )
        return (
            f"SequenceLabelField of length {length} with "
            f"labels:\n {formatted_labels} \t\tin namespace: '{self._label_namespace}'."
        )

    def human_readable_repr(self) -> Union[List[str], List[int]]:
        """Return the raw labels for human inspection."""
        return self.labels
\ No newline at end of file
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment