Commit 26044b45 authored by Maja Jabłońska's avatar Maja Jabłońska

Add COMBO Model and required dependencies

parent 6343230c
1 merge request: !46 Merge COMBO 3.0 into master
"""Main COMBO model."""
from typing import Optional, Dict, Any, List
import torch
from overrides import overrides
from combo import data
from combo.models import base
from combo.models.embeddings import TokenEmbedder
from combo.models.model import Model
from combo.modules.seq2seq_encoder import Seq2SeqEncoder
from combo.nn import RegularizerApplicator
from combo.nn.util import get_text_field_mask
from combo.utils import metrics
class ComboModel(Model):
"""Main COMBO model."""
def __init__(self,
vocab: data.Vocabulary,
loss_weights: Dict[str, float],
text_field_embedder: TokenEmbedder,
seq_encoder: Seq2SeqEncoder,
use_sample_weight: bool = True,
lemmatizer: Optional[base.Predictor] = None,
upos_tagger: Optional[base.Predictor] = None,
xpos_tagger: Optional[base.Predictor] = None,
semantic_relation: Optional[base.Predictor] = None,
morphological_feat: Optional[base.Predictor] = None,
dependency_relation: Optional[base.Predictor] = None,
enhanced_dependency_relation: Optional[base.Predictor] = None,
regularizer: RegularizerApplicator = None) -> None:
super().__init__(vocab, regularizer)
self.text_field_embedder = text_field_embedder
self.loss_weights = loss_weights
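# Note: `loss_weights` maps task names (e.g. "upostag", "head", "deprel") to scalar
# weights; _calculate_loss multiplies each corresponding "<name>_loss" by its weight
# and sums the results into the final training loss.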
self.use_sample_weight = use_sample_weight
self.seq_encoder = seq_encoder
self.lemmatizer = lemmatizer
self.upos_tagger = upos_tagger
self.xpos_tagger = xpos_tagger
self.semantic_relation = semantic_relation
self.morphological_feat = morphological_feat
self.dependency_relation = dependency_relation
self.enhanced_dependency_relation = enhanced_dependency_relation
self._head_sentinel = torch.nn.Parameter(torch.randn([1, 1, self.seq_encoder.get_output_dim()]))
self.scores = metrics.SemanticMetrics()
self._partial_losses = None
@overrides
def forward(self,
sentence: Dict[str, Dict[str, torch.Tensor]],
metadata: List[Dict[str, Any]],
upostag: torch.Tensor = None,
xpostag: torch.Tensor = None,
lemma: Dict[str, Dict[str, torch.Tensor]] = None,
feats: torch.Tensor = None,
head: torch.Tensor = None,
deprel: torch.Tensor = None,
semrel: torch.Tensor = None,
enhanced_heads: torch.Tensor = None,
enhanced_deprels: torch.Tensor = None) -> Dict[str, torch.Tensor]:
# Prepare masks
char_mask = sentence["char"]["token_characters"].gt(0)
word_mask = get_text_field_mask(sentence)
device = word_mask.device
# If enabled, weight each sample's loss by log(sentence_length)
sample_weights = word_mask.sum(-1).float().log() if self.use_sample_weight else None
encoder_input = self.text_field_embedder(sentence, char_mask=char_mask)
encoder_emb = self.seq_encoder(encoder_input, word_mask)
batch_size, _, encoding_dim = encoder_emb.size()
# Concatenate the head sentinel (ROOT) onto the sentence representation.
head_sentinel = self._head_sentinel.expand(batch_size, 1, encoding_dim)
encoder_emb_with_root = torch.cat([head_sentinel, encoder_emb], 1)
word_mask_with_root = torch.cat([torch.ones((batch_size, 1), device=device), word_mask], 1)
upos_output = self._optional(self.upos_tagger,
encoder_emb,
mask=word_mask,
labels=upostag,
sample_weights=sample_weights)
xpos_output = self._optional(self.xpos_tagger,
encoder_emb,
mask=word_mask,
labels=xpostag,
sample_weights=sample_weights)
semrel_output = self._optional(self.semantic_relation,
encoder_emb,
mask=word_mask,
labels=semrel,
sample_weights=sample_weights)
morpho_output = self._optional(self.morphological_feat,
encoder_emb,
mask=word_mask,
labels=feats,
sample_weights=sample_weights)
lemma_output = self._optional(self.lemmatizer,
(encoder_emb, sentence.get("char").get("token_characters")
if sentence.get("char") else None),
mask=word_mask,
labels=lemma.get("char").get("token_characters") if lemma else None,
sample_weights=sample_weights)
parser_output = self._optional(self.dependency_relation,
encoder_emb_with_root,
returns_tuple=True,
mask=word_mask_with_root,
labels=(deprel, head),
sample_weights=sample_weights)
enhanced_parser_output = self._optional(self.enhanced_dependency_relation,
encoder_emb_with_root,
returns_tuple=True,
mask=word_mask_with_root,
labels=(enhanced_deprels, head, enhanced_heads),
sample_weights=sample_weights)
relations_pred, head_pred = parser_output["prediction"]
enhanced_relations_pred, enhanced_head_pred = enhanced_parser_output["prediction"]
output = {
"upostag": upos_output["prediction"],
"xpostag": xpos_output["prediction"],
"semrel": semrel_output["prediction"],
"feats": morpho_output["prediction"],
"lemma": lemma_output["prediction"],
"head": head_pred,
"deprel": relations_pred,
"enhanced_head": enhanced_head_pred,
"enhanced_deprel": enhanced_relations_pred,
"sentence_embedding": torch.max(encoder_emb, dim=1)[0],
"upostag_token_embedding": upos_output["embedding"],
"xpostag_token_embedding": xpos_output["embedding"],
"semrel_token_embedding": semrel_output["embedding"],
"feats_token_embedding": morpho_output["embedding"],
"deprel_token_embedding": parser_output["embedding"],
}
if "rel_probability" in enhanced_parser_output:
output["enhanced_deprel_prob"] = enhanced_parser_output["rel_probability"]
if self._has_labels([upostag, xpostag, lemma, feats, head, deprel, semrel]):
# Feats mapping
if self.morphological_feat:
mapped_gold_labels = []
for _, cat_indices in self.morphological_feat.slices.items():
mapped_gold_labels.append(feats[:, :, cat_indices].argmax(dim=-1))
feats = torch.stack(mapped_gold_labels, dim=-1)
labels = {
"upostag": upostag,
"xpostag": xpostag,
"semrel": semrel,
"feats": feats,
"lemma": lemma.get("char").get("token_characters") if lemma else None,
"head": head,
"deprel": deprel,
"enhanced_head": enhanced_heads,
"enhanced_deprel": enhanced_deprels,
}
self.scores(output, labels, word_mask)
relations_loss, head_loss = parser_output["loss"]
enhanced_relations_loss, enhanced_head_loss = enhanced_parser_output["loss"]
losses = {
"upostag_loss": upos_output["loss"],
"xpostag_loss": xpos_output["loss"],
"semrel_loss": semrel_output["loss"],
"feats_loss": morpho_output["loss"],
"lemma_loss": lemma_output["loss"],
"head_loss": head_loss,
"deprel_loss": relations_loss,
"enhanced_head_loss": enhanced_head_loss,
"enhanced_deprel_loss": enhanced_relations_loss,
# Cycle loss is reported for metrics purposes only.
"cycle_loss": parser_output.get("cycle_loss")
}
self._partial_losses = losses.copy()
losses["loss"] = self._calculate_loss(losses)
output.update(losses)
return self._clean(output)
@staticmethod
def _has_labels(labels):
return any(x is not None for x in labels)
def _calculate_loss(self, output):
losses = []
for name, value in self.loss_weights.items():
if output.get(f"{name}_loss"):
losses.append(output[f"{name}_loss"] * value)
return torch.stack(losses).sum()
@staticmethod
def _optional(callable_model: Optional[torch.nn.Module],
*args,
returns_tuple: bool = False,
**kwargs):
if callable_model:
return callable_model(*args, **kwargs)
if returns_tuple:
return {"prediction": (None, None), "loss": (None, None), "embedding": (None, None)}
return {"prediction": None, "loss": None, "embedding": None}
@staticmethod
def _clean(output):
for k, v in dict(output).items():
if v is None:
del output[k]
return output
@overrides
def get_metrics(self, reset: bool = False) -> Dict[str, float]:
metrics = self.scores.get_metric(reset)
if self._partial_losses:
losses = self._clean(self._partial_losses)
losses = {f"partial_loss/{k}": v.detach().item() for k, v in losses.items()}
metrics.update(losses)
return metrics
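# A minimal usage sketch (hypothetical `model` and `batch`; only illustrates the
# forward() contract sketched above):
#
#     output = model(sentence=batch["sentence"], metadata=batch["metadata"])
#     heads = output["head"]                    # predicted dependency heads
#     deprels = output["deprel"]                # predicted dependency relations
#     sent_emb = output["sentence_embedding"]   # max-pooled encoder states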
class Encoder:
    pass


class StackedBidirectionalLstm(Encoder):
    pass


class ComboEncoder(Encoder):
    pass
"""
Adapted from AllenNLP
https://github.com/allenai/allennlp/blob/80fb6061e568cb9d6ab5d45b661e86eb61b92c82/allennlp/modules/encoder_base.py
"""
from typing import Tuple, Union, Optional, Callable, Any
import torch
from torch.nn.utils.rnn import pack_padded_sequence, PackedSequence
from combo.nn.util import get_lengths_from_binary_sequence_mask, sort_batch_by_length
# We have two types here for the state, because storing the state in something
# which is Iterable (like a tuple, below), is helpful for internal manipulation
# - however, the states are consumed as either Tensors or a Tuple of Tensors, so
# returning them in this format is unhelpful.
RnnState = Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
RnnStateStorage = Tuple[torch.Tensor, ...]
class _EncoderBase(torch.nn.Module):
"""
This abstract class serves as a base for the `Encoder` abstractions in AllenNLP:
- [`Seq2SeqEncoders`](./seq2seq_encoders/seq2seq_encoder.md)
- [`Seq2VecEncoders`](./seq2vec_encoders/seq2vec_encoder.md)
Additionally, this class provides functionality for sorting sequences by length
so they can be consumed by Pytorch RNN classes, which require their inputs to be
sorted by length. Finally, it also provides optional statefulness to all of its
subclasses by allowing the caching and retrieving of the hidden states of RNNs.
"""
def __init__(self, stateful: bool = False) -> None:
super().__init__()
self.stateful = stateful
self._states: Optional[RnnStateStorage] = None
def sort_and_run_forward(
self,
module: Callable[
[PackedSequence, Optional[RnnState]],
Tuple[Union[PackedSequence, torch.Tensor], RnnState],
],
inputs: torch.Tensor,
mask: torch.BoolTensor,
hidden_state: Optional[RnnState] = None,
):
"""
This function exists because Pytorch RNNs require that their inputs be sorted
before being passed as input. As all of our Seq2xxxEncoders use this functionality,
it is provided in a base class. This method can be called on any module which
takes as input a `PackedSequence` and some `hidden_state`, which can either be a
tuple of tensors or a tensor.
As all of our Seq2xxxEncoders have different return types, we return `sorted`
outputs from the module, which is called directly. Additionally, we return the
indices into the batch dimension required to restore the tensor to its correct,
unsorted order and the number of valid batch elements (i.e. the number of elements
in the batch which are not completely masked). This un-sorting and re-padding
of the module outputs is left to the subclasses because their outputs have different
types and handling them smoothly here is difficult.
# Parameters
module : `Callable[RnnInputs, RnnOutputs]`
A function to run on the inputs, where
`RnnInputs: [PackedSequence, Optional[RnnState]]` and
`RnnOutputs: Tuple[Union[PackedSequence, torch.Tensor], RnnState]`.
In most cases, this is a `torch.nn.Module`.
inputs : `torch.Tensor`, required.
A tensor of shape `(batch_size, sequence_length, embedding_size)` representing
the inputs to the Encoder.
mask : `torch.BoolTensor`, required.
A tensor of shape `(batch_size, sequence_length)`, representing masked and
non-masked elements of the sequence for each element in the batch.
hidden_state : `Optional[RnnState]`, (default = `None`).
A single tensor of shape (num_layers, batch_size, hidden_size) representing the
state of an RNN, or a tuple of two tensors of shapes
(num_layers, batch_size, hidden_size) and (num_layers, batch_size, memory_size),
representing the hidden state and memory state of an LSTM-like RNN.
# Returns
module_output : `Union[torch.Tensor, PackedSequence]`.
A Tensor or PackedSequence representing the output of the Pytorch Module.
The batch size dimension will be equal to `num_valid`, as sequences of zero
length are clipped off before the module is called, as Pytorch cannot handle
zero length sequences.
final_states : `Optional[RnnState]`
A Tensor representing the hidden state of the Pytorch Module. This can either
be a single tensor of shape (num_layers, num_valid, hidden_size), for instance in
the case of a GRU, or a tuple of tensors, such as those required for an LSTM.
restoration_indices : `torch.LongTensor`
A tensor of shape `(batch_size,)`, describing the re-indexing required to transform
the outputs back to their original batch order.
"""
# In some circumstances you may have sequences of zero length. `pack_padded_sequence`
# requires all sequence lengths to be > 0, so remove sequences of zero length before
# calling the module, then fill with zeros.
# First count how many sequences are empty.
batch_size = mask.size(0)
num_valid = torch.sum(mask[:, 0]).int().item()
sequence_lengths = get_lengths_from_binary_sequence_mask(mask)
(
sorted_inputs,
sorted_sequence_lengths,
restoration_indices,
sorting_indices,
) = sort_batch_by_length(inputs, sequence_lengths)
# Now create a PackedSequence with only the non-empty, sorted sequences.
packed_sequence_input = pack_padded_sequence(
sorted_inputs[:num_valid, :, :],
sorted_sequence_lengths[:num_valid].data.tolist(),
batch_first=True,
)
# Prepare the initial states.
if not self.stateful:
if hidden_state is None:
initial_states: Any = hidden_state
elif isinstance(hidden_state, tuple):
initial_states = [
state.index_select(1, sorting_indices)[:, :num_valid, :].contiguous()
for state in hidden_state
]
else:
initial_states = hidden_state.index_select(1, sorting_indices)[
:, :num_valid, :
].contiguous()
else:
initial_states = self._get_initial_states(batch_size, num_valid, sorting_indices)
# Actually call the module on the sorted PackedSequence.
module_output, final_states = module(packed_sequence_input, initial_states)
return module_output, final_states, restoration_indices
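    # A rough usage sketch (assuming a hypothetical subclass whose `self._module` is a
    # batch-first torch.nn.LSTM):
    #
    #     packed_output, final_states, restoration_indices = self.sort_and_run_forward(
    #         self._module, inputs, mask
    #     )
    #     output, _ = torch.nn.utils.rnn.pad_packed_sequence(packed_output, batch_first=True)
    #     # Re-padding to the full batch size and applying
    #     # output.index_select(0, restoration_indices) are left to the subclass.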
def _get_initial_states(
self, batch_size: int, num_valid: int, sorting_indices: torch.LongTensor
) -> Optional[RnnState]:
"""
Returns an initial state for use in an RNN. Additionally, this method handles
the batch size changing across calls by mutating the state to append initial states
for new elements in the batch. Finally, it also handles sorting the states
with respect to the sequence lengths of elements in the batch and removing rows
which are completely padded. Importantly, this `mutates` the state if the
current batch size is larger than when it was previously called.
# Parameters
batch_size : `int`, required.
The batch size can change size across calls to stateful RNNs, so we need
to know if we need to expand or shrink the states before returning them.
Expanded states will be set to zero.
num_valid : `int`, required.
The batch may contain completely padded sequences which get removed before
the sequence is passed through the encoder. We also need to clip these off
of the state too.
sorting_indices : `torch.LongTensor`, required.
Pytorch RNNs take sequences sorted by length. When we return the states to be
used for a given call to `module.forward`, we need the states to match up to
the sorted sequences, so before returning them, we sort the states using the
same indices used to sort the sequences.
# Returns
This method has a complex return type because it has to deal with the first time it
is called, when it has no state, and the fact that different types of RNN have heterogeneous
states.
If it is the first time the module has been called, it returns `None`, regardless
of the type of the `Module`.
Otherwise, for LSTMs, it returns a tuple of `torch.Tensors` with shape
`(num_layers, num_valid, state_size)` and `(num_layers, num_valid, memory_size)`
respectively, or for GRUs, it returns a single `torch.Tensor` of shape
`(num_layers, num_valid, state_size)`.
"""
# We don't know the state sizes the first time forward is called,
# so we let the module define what its initial hidden state looks like.
if self._states is None:
return None
# Otherwise, we have some previous states.
if batch_size > self._states[0].size(1):
# This batch is larger than all previous states.
# If so, resize the states.
num_states_to_concat = batch_size - self._states[0].size(1)
resized_states = []
# state has shape (num_layers, batch_size, hidden_size)
for state in self._states:
# This _must_ be inside the loop because some
# RNNs have states with different last dimension sizes.
zeros = state.new_zeros(state.size(0), num_states_to_concat, state.size(2))
resized_states.append(torch.cat([state, zeros], 1))
self._states = tuple(resized_states)
correctly_shaped_states = self._states
elif batch_size < self._states[0].size(1):
# This batch is smaller than the previous one.
correctly_shaped_states = tuple(state[:, :batch_size, :] for state in self._states)
else:
correctly_shaped_states = self._states
# At this point, our states are of shape (num_layers, batch_size, hidden_size).
# However, the encoder uses sorted sequences and additionally removes elements
# of the batch which are fully padded. We need the states to match up to these
# sorted and filtered sequences, so we do that in the next two blocks before
# returning the state/s.
if len(self._states) == 1:
# GRUs only have a single state. This `unpacks` it from the
# tuple and returns the tensor directly.
correctly_shaped_state = correctly_shaped_states[0]
sorted_state = correctly_shaped_state.index_select(1, sorting_indices)
return sorted_state[:, :num_valid, :].contiguous()
else:
# LSTMs have a state tuple of (state, memory).
sorted_states = [
state.index_select(1, sorting_indices) for state in correctly_shaped_states
]
return tuple(state[:, :num_valid, :].contiguous() for state in sorted_states)
def _update_states(
self, final_states: RnnStateStorage, restoration_indices: torch.LongTensor
) -> None:
"""
After the RNN has run forward, the states need to be updated.
This method just sets the state to the updated new state, performing
several pieces of book-keeping along the way - namely, unsorting the
states and ensuring that the states of completely padded sequences are
not updated. Finally, it also detaches the state variable from the
computational graph, such that the graph can be garbage collected after
each batch iteration.
# Parameters
final_states : `RnnStateStorage`, required.
The hidden states returned as output from the RNN.
restoration_indices : `torch.LongTensor`, required.
The indices that invert the sorting used in `sort_and_run_forward`
to order the states with respect to the lengths of the sequences in
the batch.
"""
# TODO(Mark): seems weird to sort here, but append zeros in the subclasses.
# which way around is best?
new_unsorted_states = [state.index_select(1, restoration_indices) for state in final_states]
if self._states is None:
# We don't already have states, so just set the
# ones we receive to be the current state.
self._states = tuple(state.data for state in new_unsorted_states)
else:
# Now we've sorted the states back so that they correspond to the original
# indices, we need to figure out what states we need to update, because if we
# didn't use a state for a particular row, we want to preserve its state.
# Thankfully, the rows which are all zero in the state correspond exactly
# to those which aren't used, so we create masks of shape (new_batch_size,),
# denoting which states were used in the RNN computation.
current_state_batch_size = self._states[0].size(1)
new_state_batch_size = final_states[0].size(1)
# Masks for the unused states of shape (1, new_batch_size, 1)
used_new_rows_mask = [
(state[0, :, :].sum(-1) != 0.0).float().view(1, new_state_batch_size, 1)
for state in new_unsorted_states
]
new_states = []
if current_state_batch_size > new_state_batch_size:
# The new state is smaller than the old one,
# so just update the indices which we used.
for old_state, new_state, used_mask in zip(
self._states, new_unsorted_states, used_new_rows_mask
):
# zero out all rows in the previous state
# which _were_ used in the current state.
masked_old_state = old_state[:, :new_state_batch_size, :] * (1 - used_mask)
# The old state is larger, so update the relevant parts of it.
old_state[:, :new_state_batch_size, :] = new_state + masked_old_state
new_states.append(old_state.detach())
else:
# The states are the same size, so we just have to
# deal with the possibility that some rows weren't used.
new_states = []
for old_state, new_state, used_mask in zip(
self._states, new_unsorted_states, used_new_rows_mask
):
# zero out all rows which _were_ used in the current state.
masked_old_state = old_state * (1 - used_mask)
# The states are the same size, so keep the unused rows from the old state.
new_state += masked_old_state
new_states.append(new_state.detach())
# It looks like there should be another case handled here - when
# the current_state_batch_size < new_state_batch_size. However,
# this never happens, because the states themselves are mutated
# by appending zeros when calling _get_initial_states, meaning that
# the new states are either of equal size, or smaller, in the case
# that there are some unused elements (zero-length) for the RNN computation.
self._states = tuple(new_states)
def reset_states(self, mask: torch.BoolTensor = None) -> None:
"""
Resets the internal states of a stateful encoder.
# Parameters
mask : `torch.BoolTensor`, optional.
A tensor of shape `(batch_size,)` indicating which states should
be reset. If not provided, all states will be reset.
"""
if mask is None:
self._states = None
else:
# state has shape (num_layers, batch_size, hidden_size). We reshape
# mask to have shape (1, batch_size, 1) so that operations
# broadcast properly.
mask_batch_size = mask.size(0)
mask = mask.view(1, mask_batch_size, 1)
new_states = []
assert self._states is not None
for old_state in self._states:
old_state_batch_size = old_state.size(1)
if old_state_batch_size != mask_batch_size:
raise ValueError(
f"Trying to reset states using mask with incorrect batch size. "
f"Expected batch size: {old_state_batch_size}. "
f"Provided batch size: {mask_batch_size}."
)
new_state = ~mask * old_state
new_states.append(new_state.detach())
self._states = tuple(new_states)
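    # Illustrative example (hypothetical call site): resetting only the first sequence's
    # cached state in a batch of three, e.g. at a document boundary:
    #
    #     encoder.reset_states(mask=torch.tensor([True, False, False]))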
from combo.models.encoder import Encoder
class Seq2SeqEncoder(Encoder):
"""
A `Seq2SeqEncoder` is a `Module` that takes as input a sequence of vectors and returns a
modified sequence of vectors. Input shape : `(batch_size, sequence_length, input_dim)`; output
shape : `(batch_size, sequence_length, output_dim)`.
We add three methods to the basic `Module` API: `get_input_dim()`, `get_output_dim()`
and `is_bidirectional()`. You might need these if you want to construct a `Linear` layer
using the output of this encoder, or to raise sensible errors for mismatched input dimensions.
"""
def get_input_dim(self) -> int:
"""
Returns the dimension of the vector input for each element in the sequence input
to a `Seq2SeqEncoder`. This is `not` the shape of the input tensor, but the
last element of that shape.
"""
raise NotImplementedError
def get_output_dim(self) -> int:
"""
Returns the dimension of each vector in the sequence output by this `Seq2SeqEncoder`.
This is `not` the shape of the returned tensor, but the last element of that shape.
"""
raise NotImplementedError
def is_bidirectional(self) -> bool:
"""
Returns `True` if this encoder is bidirectional. If so, we assume the forward direction
of the encoder is the first half of the final dimension, and the backward direction is the
second half.
"""
raise NotImplementedError
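# A minimal pass-through implementation sketch (illustrative only; the class name and
# behaviour below are assumptions, not part of the COMBO API), showing how a subclass
# typically fills in the three abstract methods.
import torch


class PassThroughSeq2SeqEncoder(Seq2SeqEncoder):
    """Returns its input sequence unchanged; useful mainly as an API illustration."""

    def __init__(self, dim: int) -> None:
        super().__init__()
        self._dim = dim

    def get_input_dim(self) -> int:
        return self._dim

    def get_output_dim(self) -> int:
        return self._dim

    def is_bidirectional(self) -> bool:
        return False

    def forward(self, inputs: torch.Tensor, mask: torch.BoolTensor = None) -> torch.Tensor:
        # Shape: (batch_size, sequence_length, dim) in and out; the mask is ignored here.
        return inputs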
@@ -7,6 +7,7 @@ from typing import Union
import torch
from combo.common.util import int_to_device
+from combo.utils import ConfigurationError

def move_to_device(obj, device: Union[torch.device, int]):
@@ -53,4 +54,108 @@ def device_mapping(cuda_device: int):
    else:
        return storage
-    return
+    return inner_device_mapping
def get_lengths_from_binary_sequence_mask(mask: torch.BoolTensor) -> torch.LongTensor:
"""
Compute sequence lengths for each batch element in a tensor using a
binary mask.
# Parameters
mask : `torch.BoolTensor`, required.
A 2D binary mask of shape (batch_size, sequence_length) to
calculate the per-batch sequence lengths from.
# Returns
`torch.LongTensor`
A torch.LongTensor of shape (batch_size,) representing the lengths
of the sequences in the batch.
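# Example
A minimal illustrative example:
>>> mask = torch.tensor([[True, True, False], [True, False, False]])
>>> get_lengths_from_binary_sequence_mask(mask)
tensor([2, 1])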
"""
return mask.sum(-1)
def sort_batch_by_length(tensor: torch.Tensor, sequence_lengths: torch.Tensor):
"""
Sort a batch first tensor by some specified lengths.
# Parameters
tensor : `torch.FloatTensor`, required.
A batch first Pytorch tensor.
sequence_lengths : `torch.LongTensor`, required.
A tensor representing the lengths of some dimension of the tensor which
we want to sort by.
# Returns
sorted_tensor : `torch.FloatTensor`
The original tensor sorted along the batch dimension with respect to sequence_lengths.
sorted_sequence_lengths : `torch.LongTensor`
The original sequence_lengths sorted by decreasing size.
restoration_indices : `torch.LongTensor`
Indices into the sorted_tensor such that
`sorted_tensor.index_select(0, restoration_indices) == original_tensor`
permutation_index : `torch.LongTensor`
The indices used to sort the tensor. This is useful if you want to sort many
tensors using the same ordering.
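# Example
A minimal illustrative example:
>>> tensor = torch.arange(12, dtype=torch.float).view(3, 2, 2)
>>> lengths = torch.tensor([2, 3, 1])
>>> sorted_tensor, sorted_lengths, restoration, permutation = sort_batch_by_length(tensor, lengths)
>>> sorted_lengths
tensor([3, 2, 1])
>>> torch.equal(sorted_tensor.index_select(0, restoration), tensor)
True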
"""
if not isinstance(tensor, torch.Tensor) or not isinstance(sequence_lengths, torch.Tensor):
raise ConfigurationError("Both the tensor and sequence lengths must be torch.Tensors.")
sorted_sequence_lengths, permutation_index = sequence_lengths.sort(0, descending=True)
sorted_tensor = tensor.index_select(0, permutation_index)
index_range = torch.arange(0, len(sequence_lengths), device=sequence_lengths.device)
# This is the equivalent of zipping with index, sorting by the original
# sequence lengths and returning the now sorted indices.
_, reverse_mapping = permutation_index.sort(0, descending=False)
restoration_indices = index_range.index_select(0, reverse_mapping)
return sorted_tensor, sorted_sequence_lengths, restoration_indices, permutation_index
def get_text_field_mask(
text_field_tensors: Dict[str, Dict[str, torch.Tensor]],
num_wrapping_dims: int = 0,
padding_id: int = 0,
) -> torch.BoolTensor:
"""
Takes the dictionary of tensors produced by a `TextField` and returns a mask
with 0 where the tokens are padding, and 1 otherwise. `padding_id` specifies the id of padding tokens.
We also handle `TextFields` wrapped by an arbitrary number of `ListFields`, where the number of wrapping
`ListFields` is given by `num_wrapping_dims`.
If `num_wrapping_dims == 0`, the returned mask has shape `(batch_size, num_tokens)`.
If `num_wrapping_dims > 0` then the returned mask has `num_wrapping_dims` extra
dimensions, so the shape will be `(batch_size, ..., num_tokens)`.
There could be several entries in the tensor dictionary with different shapes (e.g., one for
word ids, one for character ids). In order to get a token mask, we use the tensor in
the dictionary with the lowest number of dimensions. After subtracting `num_wrapping_dims`,
if this tensor has two dimensions we assume it has shape `(batch_size, ..., num_tokens)`,
and use it for the mask. If instead it has three dimensions, we assume it has shape
`(batch_size, ..., num_tokens, num_features)`, and sum over the last dimension to produce
the mask. Most frequently this will be a character id tensor, but it could also be a
featurized representation of each token, etc.
If the input `text_field_tensors` contains the "mask" key, this is returned instead of inferring the mask.
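# Example
A minimal illustrative example with a single hypothetical "tokens" indexer:
>>> tensors = {"tokens": {"tokens": torch.tensor([[4, 7, 0], [5, 0, 0]])}}
>>> get_text_field_mask(tensors)
tensor([[ True,  True, False],
        [ True, False, False]])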
"""
masks = []
for indexer_name, indexer_tensors in text_field_tensors.items():
if "mask" in indexer_tensors:
masks.append(indexer_tensors["mask"].bool())
if len(masks) == 1:
return masks[0]
elif len(masks) > 1:
# TODO(mattg): My guess is this will basically never happen, so I'm not writing logic to
# handle it. Should be straightforward to handle, though. If you see this error in
# practice, open an issue on github.
raise ValueError("found two mask outputs; not sure which to use!")
tensor_dims = [
(tensor.dim(), tensor)
for indexer_output in text_field_tensors.values()
for tensor in indexer_output.values()
]
tensor_dims.sort(key=lambda x: x[0])
smallest_dim = tensor_dims[0][0] - num_wrapping_dims
if smallest_dim == 2:
token_tensor = tensor_dims[0][1]
return token_tensor != padding_id
elif smallest_dim == 3:
character_tensor = tensor_dims[0][1]
return (character_tensor != padding_id).any(dim=-1)
else:
raise ValueError("Expected a tensor with dimension 2 or 3, found {}".format(smallest_dim))