
Add try/catch clause for sentences with large number of wordpieces.

Closed Martyna Wiącek requested to merge fix/try_catch_clause_for_long_wordpiece_list into develop
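For context: the limit being enforced is the tokenizer's `max_len_single_sentence`, i.e. the model's maximum input length minus the special tokens it always adds. A standalone sketch of that limit, assuming a plain HuggingFace tokenizer (the model name and sentence are only illustrative):

```python
# Standalone illustration of the limit this MR enforces (not part of the diff).
# max_len_single_sentence = model_max_length minus special tokens,
# e.g. 512 - 2 = 510 for a BERT-base checkpoint.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")  # illustrative model name
wordpieces = tokenizer.tokenize("a very long sentence " * 300)

if len(wordpieces) > tokenizer.max_len_single_sentence:
    raise ValueError(
        f"Sentence needs {len(wordpieces)} wordpieces, "
        f"but the model accepts at most {tokenizer.max_len_single_sentence}."
    )
```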
1 file changed: +25 −19
import logging
import sys
from typing import Optional, Dict, Any, List, Tuple
from allennlp import data
from allennlp.data import token_indexers, tokenizers, IndexedTokenList, vocabulary
from overrides import overrides
from typing import List
logger = logging.getLogger(__name__)
@data.TokenIndexer.register("pretrained_transformer_mismatched_fixed")
class PretrainedTransformerMismatchedIndexer(token_indexers.PretrainedTransformerMismatchedIndexer):
@@ -34,28 +35,33 @@ class PretrainedTransformerMismatchedIndexer(token_indexers.PretrainedTransforme
Method is overridden in order to raise an error when the number of wordpiece tokens needed to embed a sentence
exceeds the maximal input length of the model.
"""
try:
    self._matched_indexer._add_encoding_to_vocabulary_if_needed(vocabulary)

    wordpieces, offsets = self._allennlp_tokenizer.intra_word_tokenize(
        [t.ensure_text() for t in tokens])

    # Fail fast if the sentence needs more wordpieces than the model accepts.
    if len(wordpieces) > self._tokenizer.max_len_single_sentence:
        raise ValueError("Following sentence consists of more wordpiece tokens than the model can process:\n"
                         + " ".join([str(x) for x in tokens[:10]]) + " ... \n"
                         + f"Maximal input: {self._tokenizer.max_len_single_sentence}\n"
                         + f"Current input: {len(wordpieces)}")

    offsets = [x if x is not None else (-1, -1) for x in offsets]

    output: IndexedTokenList = {
        "token_ids": [t.text_id for t in wordpieces],
        "mask": [True] * len(tokens),  # for original tokens (i.e. word-level)
        "type_ids": [t.type_id for t in wordpieces],
        "offsets": offsets,
        "wordpiece_mask": [True] * len(wordpieces),  # for wordpieces (i.e. subword-level)
    }

    return self._matched_indexer._postprocess_output(output)
except ValueError as value_error:
    logger.error(value_error)
    sys.exit(1)
class PretrainedTransformerIndexer(token_indexers.PretrainedTransformerIndexer):
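A rough usage sketch of the expected behaviour after this change (untested; assumes the class above is importable, and the model name and inputs are placeholders): instead of an opaque failure deep inside the transformer, tokens_to_indices logs the offending sentence and exits.

```python
# Rough usage sketch (not part of the diff); model name and inputs are placeholders.
from allennlp.data import Token, Vocabulary

indexer = PretrainedTransformerMismatchedIndexer(model_name="bert-base-cased")
vocab = Vocabulary()

short = [Token("a"), Token("short"), Token("sentence")]
print(indexer.tokens_to_indices(short, vocab)["token_ids"])  # indexes normally

too_long = [Token("word")] * 5000  # far more wordpieces than max_len_single_sentence
indexer.tokens_to_indices(too_long, vocab)  # logs the ValueError, then sys.exit(1)
```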