Commit fd0f2ea1 authored by Mateusz Klimaszewski

Update allennlp to 1.0.0.

parent fa0c1555
@@ -230,6 +230,14 @@ def get_slices_if_not_provided(vocab: allen_data.Vocabulary):
     return vocab.slices


-@dataclass
+@dataclass(init=False, repr=False)
 class _Token(allen_data.Token):
-    feats_: Optional[str] = None
+    __slots__ = allen_data.Token.__slots__ + ['feats_']
+    feats_: Optional[str]
+
+    def __init__(self, text: str = None, idx: int = None, idx_end: int = None, lemma_: str = None, pos_: str = None,
+                 tag_: str = None, dep_: str = None, ent_type_: str = None, text_id: int = None, type_id: int = None,
+                 feats_: str = None) -> None:
+        super().__init__(text, idx, idx_end, lemma_, pos_, tag_, dep_, ent_type_, text_id, type_id)
+        self.feats_ = feats_
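For context, a minimal usage sketch of the extended token (the import path below is hypothetical, not part of this commit): because allennlp 1.0.0's Token defines __slots__, the extra feats_ attribute has to be declared as a slot and assigned in an explicit __init__ rather than as a plain dataclass field.

# Usage sketch; the import location is hypothetical.
from combo.data import _Token  # hypothetical path to the class above

token = _Token(text="domu", lemma_="dom", pos_="NOUN", feats_="Case=Gen|Number=Sing")
print(token.text, token.feats_)  # domu Case=Gen|Number=Sing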
 from .token_characters_indexer import TokenCharactersIndexer
 from .token_features_indexer import TokenFeatsIndexer
-from .pretrained_transformer_mismatched_indexer import PretrainedTransformerMismatchedIndexer
-from typing import List
-
-from allennlp import data
-from allennlp.data import token_indexers
-from overrides import overrides
-
-
-@token_indexers.TokenIndexer.register("pretrained_transformer_mismatched_tmp_fix")
-class PretrainedTransformerMismatchedIndexer(token_indexers.PretrainedTransformerMismatchedIndexer):
-    """TODO remove when fixed in AllenNLP (fc47bf6ae5c0df6d473103d459b75fa7edbdd979)"""
-
-    @overrides
-    def tokens_to_indices(self, tokens: List[data.Token], vocabulary: data.Vocabulary) -> data.IndexedTokenList:
-        self._matched_indexer._add_encoding_to_vocabulary_if_needed(vocabulary)
-
-        wordpieces, offsets = self._allennlp_tokenizer.intra_word_tokenize([t.text for t in tokens])
-
-        # For tokens that don't correspond to any word pieces, we put (-1, -1) into the offsets.
-        # That results in the embedding for the token to be all zeros.
-        offsets = [x if x is not None else (-1, -1) for x in offsets]
-
-        output: data.IndexedTokenList = {
-            "token_ids": [t.text_id for t in wordpieces],
-            "mask": [True] * len(tokens),  # for original tokens (i.e. word-level)
-            "type_ids": [t.type_id for t in wordpieces],
-            "offsets": offsets,
-            "wordpiece_mask": [True] * len(wordpieces),  # for wordpieces (i.e. subword-level)
-        }
-
-        return self._matched_indexer._postprocess_output(output)
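As a standalone illustration of what this temporary indexer worked around (plain Python, not project code): tokens that produce no wordpieces get a sentinel offset of (-1, -1), which downstream yields an all-zero embedding for that token.

# Core idea of the offsets workaround, isolated from AllenNLP.
offsets = [(0, 0), None, (1, 2)]  # the middle token produced no wordpieces
offsets = [x if x is not None else (-1, -1) for x in offsets]
assert offsets == [(0, 0), (-1, -1), (1, 2)]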
@@ -137,31 +137,10 @@ class TransformersWordEmbedder(token_embedders.PretrainedTransformerMismatchedEmbedder):
         type_ids: Optional[torch.LongTensor] = None,
         segment_concat_mask: Optional[torch.BoolTensor] = None,
     ) -> torch.Tensor:  # type: ignore
-        """TODO remove (and call super) when fixed in AllenNLP (fc47bf6ae5c0df6d473103d459b75fa7edbdd979)"""
-        # Shape: [batch_size, num_wordpieces, embedding_size].
-        embeddings = self._matched_embedder(
-            token_ids, wordpiece_mask, type_ids=type_ids, segment_concat_mask=segment_concat_mask
-        )
-
-        # span_embeddings: (batch_size, num_orig_tokens, max_span_length, embedding_size)
-        # span_mask: (batch_size, num_orig_tokens, max_span_length)
-        span_embeddings, span_mask = util.batched_span_select(embeddings.contiguous(), offsets)
-        span_mask = span_mask.unsqueeze(-1)
-        span_embeddings *= span_mask  # zero out paddings
-
-        span_embeddings_sum = span_embeddings.sum(2)
-        span_embeddings_len = span_mask.sum(2)
-        # Shape: (batch_size, num_orig_tokens, embedding_size)
-        orig_embeddings = span_embeddings_sum / span_embeddings_len
-
-        # All the places where the span length is zero, write in zeros.
-        orig_embeddings[(span_embeddings_len == 0).expand(orig_embeddings.shape)] = 0
-        # TODO end remove
+        x = super().forward(token_ids, mask, offsets, wordpiece_mask, type_ids, segment_concat_mask)
         if self.projection_layer:
-            orig_embeddings = self.projection_layer(orig_embeddings)
-        return orig_embeddings
+            x = self.projection_layer(x)
+        return x

     @overrides
     def get_output_dim(self):
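For reference, a standalone sketch (not project code, assuming allennlp 1.0.0 and PyTorch are installed) of the span pooling the deleted block performed and that PretrainedTransformerMismatchedEmbedder.forward now handles itself: wordpiece vectors are averaged back into one vector per original token.

import torch
from allennlp.nn import util

# (batch, num_wordpieces, dim) wordpiece embeddings and (batch, num_orig_tokens, 2) offsets.
embeddings = torch.randn(2, 7, 4)
offsets = torch.tensor([[[0, 1], [2, 2], [3, 4]],
                        [[0, 0], [1, 3], [4, 6]]])

# Gather each token's wordpiece span, mask out padding, and mean-pool over the span.
span_embeddings, span_mask = util.batched_span_select(embeddings.contiguous(), offsets)
span_mask = span_mask.unsqueeze(-1)
pooled = (span_embeddings * span_mask).sum(2) / span_mask.sum(2).clamp(min=1)
print(pooled.shape)  # torch.Size([2, 3, 4]) -> one vector per original token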
@@ -67,11 +67,15 @@ class HeadPredictionModel(base.Predictor):

     def _cycle_loss(self, pred: torch.Tensor):
         BATCH_SIZE, _, _ = pred.size()
         loss = pred.new_zeros(BATCH_SIZE)
-        # 1: as using non __ROOT__ tokens
-        x = pred[:, 1:, 1:]
-        for _ in range(self.cycle_loss_n):
-            loss += self._batch_trace(x) / BATCH_SIZE
-            x = x.bmm(pred[:, 1:, 1:])
+        # Index from 1: as using non __ROOT__ tokens
+        pred = pred.softmax(-1)[:, 1:, 1:]
+        x = pred
+        for i in range(self.cycle_loss_n):
+            loss += self._batch_trace(x)
+            # Don't multiply on the last iteration
+            if i < self.cycle_loss_n - 1:
+                x = x.bmm(pred)

         return loss
@@ -83,7 +87,7 @@ class HeadPredictionModel(base.Predictor):
         identity = x.new_tensor(torch.eye(N))
         identity = identity.reshape((1, N, N))
         batch_identity = identity.repeat(BATCH_SIZE, 1, 1)
-        return (x * batch_identity).sum()
+        return (x * batch_identity).sum((-1, -2))

     def _loss(self, pred: torch.Tensor, true: torch.Tensor, mask: torch.BoolTensor,
               sample_weights: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
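A small numeric sketch of the idea behind the revised _cycle_loss and _batch_trace pair (plain PyTorch, assuming cycle_loss_n = 2 and using a hard 0/1 adjacency matrix instead of softmax outputs for readability): trace(A) counts self-loops, trace(A @ A) counts length-2 cycles, and the per-element trace keeps one penalty value per sentence in the batch.

import torch

A = torch.tensor([[[0.0, 1.0],
                   [1.0, 0.0]]])              # one 2-token "graph" containing a 2-cycle
identity = torch.eye(2).unsqueeze(0)          # batched identity, shape (1, 2, 2)

loss = (A * identity).sum((-1, -2))           # trace(A)   -> tensor([0.])
loss = loss + (A.bmm(A) * identity).sum((-1, -2))  # trace(A^2) -> adds 2
print(loss)                                   # tensor([2.])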
@@ -29,6 +29,7 @@ class SemanticMultitaskPredictor(predictor.Predictor):
         self.batch_size = batch_size
         self.vocab = model.vocab
         self._dataset_reader.generate_labels = False
+        self._dataset_reader.lazy = True
         self._tokenizer = tokenizer

     def __call__(self, sentence: Union[str, List[str], List[List[str]], List[data.Sentence]]):

@@ -71,8 +72,8 @@ class SemanticMultitaskPredictor(predictor.Predictor):
             raise ValueError("Input must be either string or list of strings.")

     @overrides
-    def predict_batch_instance(self, instances: List[allen_data.Instance], serialize: bool = True) -> List[
-            common.JsonDict]:
+    def predict_batch_instance(self, instances: List[allen_data.Instance], serialize: bool = True
+                               ) -> List[common.JsonDict]:
         trees = []
         predictions = super().predict_batch_instance(instances)
         for prediction, instance in zip(predictions, instances):
@@ -128,7 +128,7 @@ class GradientDescentTrainer(training.GradientDescentTrainer):
             num_batches,
             reset=True,
             world_size=self._world_size,
-            cuda_device=[self.cuda_device],
+            cuda_device=self.cuda_device,
         )

         # Check validation metric for early stopping
@@ -112,7 +112,7 @@ assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't
     use_sem: if in_targets("semrel") then true else false,
     token_indexers: {
         token: if use_transformer then {
-            type: "pretrained_transformer_mismatched_tmp_fix",
+            type: "pretrained_transformer_mismatched",
             model_name: pretrained_transformer_name,
         } else {
             # SingleIdTokenIndexer, token as single int
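A quick sanity-check sketch (plain Python, not part of the commit) of why the config change works: the jsonnet "type" key resolves through AllenNLP's registry, so after the upgrade the stock mismatched indexer is constructed instead of the local tmp_fix subclass.

from allennlp.data.token_indexers import TokenIndexer

indexer_cls = TokenIndexer.by_name("pretrained_transformer_mismatched")
print(indexer_cls.__name__)  # PretrainedTransformerMismatchedIndexer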
@@ -3,7 +3,7 @@ from setuptools import find_packages, setup
 REQUIREMENTS = [
     'absl-py==0.9.0',
-    'allennlp==1.0.0rc5',
+    'allennlp==1.0.0',
     'conllu==2.3.2',
     'joblib==0.14.1',
     'jsonnet==0.15.0',