Release 1.0.5

Merged Mateusz Klimaszewski requested to merge develop into master
6 files changed: +66 -15
from typing import Optional, Dict, Any, List, Tuple
from allennlp import data
from allennlp.data import token_indexers, tokenizers, IndexedTokenList, vocabulary
from overrides import overrides
@data.TokenIndexer.register("pretrained_transformer_mismatched_fixed")
class PretrainedTransformerMismatchedIndexer(token_indexers.PretrainedTransformerMismatchedIndexer):
"""TODO(mklimasz) Remove during next allennlp update, fixed on allennlp master."""
def __init__(self, model_name: str, namespace: str = "tags", max_length: int = None,
tokenizer_kwargs: Optional[Dict[str, Any]] = None, **kwargs) -> None:
@@ -24,6 +24,37 @@ class PretrainedTransformerMismatchedIndexer(token_indexers.PretrainedTransformerMismatchedIndexer):
        self._num_added_start_tokens = self._matched_indexer._num_added_start_tokens
        self._num_added_end_tokens = self._matched_indexer._num_added_end_tokens
    @overrides
    def tokens_to_indices(self,
                          tokens,
                          vocabulary: vocabulary.Vocabulary) -> IndexedTokenList:
"""
Method is overridden in order to raise an error while the number of tokens needed to embed a sentence exceeds the
maximal input of a model.
"""
        self._matched_indexer._add_encoding_to_vocabulary_if_needed(vocabulary)

        wordpieces, offsets = self._allennlp_tokenizer.intra_word_tokenize(
            [t.ensure_text() for t in tokens])

        if len(wordpieces) > self._tokenizer.max_len_single_sentence:
            raise ValueError("The following sentence consists of more wordpiece tokens than the model can process:\n"
                             + " ".join([str(x) for x in tokens[:10]]) + " ...\n"
                             + f"Maximum input length: {self._tokenizer.max_len_single_sentence}\n"
                             + f"Current input length: {len(wordpieces)}")
        offsets = [x if x is not None else (-1, -1) for x in offsets]

        output: IndexedTokenList = {
            "token_ids": [t.text_id for t in wordpieces],
            "mask": [True] * len(tokens),  # for original tokens (i.e. word-level)
            "type_ids": [t.type_id for t in wordpieces],
            "offsets": offsets,
            "wordpiece_mask": [True] * len(wordpieces),  # for wordpieces (i.e. subword-level)
        }

        return self._matched_indexer._postprocess_output(output)
class PretrainedTransformerIndexer(token_indexers.PretrainedTransformerIndexer):
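
For context, a minimal usage sketch of the fixed indexer (the import path and model name below are assumptions for illustration, not part of this change): a short sentence is indexed as before, while a sentence whose wordpiece count exceeds the transformer's single-sentence limit now raises a `ValueError` instead of being silently truncated.

```python
from allennlp.data import Token, Vocabulary

# Assumed import path; adjust to wherever this module lives in the repository.
from combo.data.token_indexers import PretrainedTransformerMismatchedIndexer

# Any HuggingFace checkpoint name works; "bert-base-cased" is just an example.
indexer = PretrainedTransformerMismatchedIndexer(model_name="bert-base-cased")
vocab = Vocabulary()

# A short sentence is indexed normally: word-level mask plus wordpiece offsets.
tokens = [Token(t) for t in "The quick brown fox jumps over the lazy dog .".split()]
output = indexer.tokens_to_indices(tokens, vocab)
print(len(output["token_ids"]), output["offsets"][:3])

# An overly long sentence now fails loudly once the wordpiece count exceeds
# the tokenizer's max_len_single_sentence.
too_long = [Token("word")] * 1000
try:
    indexer.tokens_to_indices(too_long, vocab)
except ValueError as e:
    print(e)
```

In a training config the indexer can also be selected by its registered name, `pretrained_transformer_mismatched_fixed`, in place of the stock `pretrained_transformer_mismatched` type.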