From e060a26d6109783e8e15ab1b2482026a54f8c930 Mon Sep 17 00:00:00 2001
From: martynawiacek <Hkkm6072>
Date: Thu, 20 Jan 2022 17:28:20 +0100
Subject: [PATCH] Add try/catch clause for sentences with large number of
 wordpieces.

---
 ...etrained_transformer_mismatched_indexer.py | 44 +++++++++++--------
 1 file changed, 25 insertions(+), 19 deletions(-)

diff --git a/combo/data/token_indexers/pretrained_transformer_mismatched_indexer.py b/combo/data/token_indexers/pretrained_transformer_mismatched_indexer.py
index b9a4e3c..1b86103 100644
--- a/combo/data/token_indexers/pretrained_transformer_mismatched_indexer.py
+++ b/combo/data/token_indexers/pretrained_transformer_mismatched_indexer.py
@@ -1,11 +1,12 @@
+import logging
+import sys
 from typing import Optional, Dict, Any, List, Tuple

 from allennlp import data
 from allennlp.data import token_indexers, tokenizers, IndexedTokenList, vocabulary
 from overrides import overrides

-from typing import List
-
+logger = logging.getLogger(__name__)

 @data.TokenIndexer.register("pretrained_transformer_mismatched_fixed")
 class PretrainedTransformerMismatchedIndexer(token_indexers.PretrainedTransformerMismatchedIndexer):
@@ -34,28 +35,33 @@ class PretrainedTransformerMismatchedIndexer(token_indexers.PretrainedTransforme
         Method is overridden in order to raise an error while the number of tokens needed to embed a sentence exceeds
         the maximal input of a model.
         """
-        self._matched_indexer._add_encoding_to_vocabulary_if_needed(vocabulary)
+        try:
+            self._matched_indexer._add_encoding_to_vocabulary_if_needed(vocabulary)
+
+            wordpieces, offsets = self._allennlp_tokenizer.intra_word_tokenize(
+                [t.ensure_text() for t in tokens])

-        wordpieces, offsets = self._allennlp_tokenizer.intra_word_tokenize(
-            [t.ensure_text() for t in tokens])
+            if len(wordpieces) > self._tokenizer.max_len_single_sentence:
+                raise ValueError("Following sentence consists of more wordpiece tokens that the model can process:\n" +\
+                                 " ".join([str(x) for x in tokens[:10]]) + " ... \n" + \
+                                 f"Maximal input: {self._tokenizer.max_len_single_sentence}\n"+ \
+                                 f"Current input: {len(wordpieces)}")

-        if len(wordpieces) > self._tokenizer.max_len_single_sentence:
-            raise ValueError("Following sentence consists of more wordpiece tokens that the model can process:\n" +\
-                             " ".join([str(x) for x in tokens[:10]]) + " ... \n" + \
-                             f"Maximal input: {self._tokenizer.max_len_single_sentence}\n"+ \
-                             f"Current input: {len(wordpieces)}")
+            offsets = [x if x is not None else (-1, -1) for x in offsets]

-        offsets = [x if x is not None else (-1, -1) for x in offsets]
+            output: IndexedTokenList = {
+                "token_ids": [t.text_id for t in wordpieces],
+                "mask": [True] * len(tokens),  # for original tokens (i.e. word-level)
+                "type_ids": [t.type_id for t in wordpieces],
+                "offsets": offsets,
+                "wordpiece_mask": [True] * len(wordpieces),  # for wordpieces (i.e. subword-level)
+            }

-        output: IndexedTokenList = {
-            "token_ids": [t.text_id for t in wordpieces],
-            "mask": [True] * len(tokens),  # for original tokens (i.e. word-level)
-            "type_ids": [t.type_id for t in wordpieces],
-            "offsets": offsets,
-            "wordpiece_mask": [True] * len(wordpieces),  # for wordpieces (i.e. subword-level)
-        }
+            return self._matched_indexer._postprocess_output(output)

-        return self._matched_indexer._postprocess_output(output)
+        except ValueError as value_error:
+            logger.error(value_error)
+            sys.exit(1)


 class PretrainedTransformerIndexer(token_indexers.PretrainedTransformerIndexer):
--
GitLab
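
For context, below is a minimal standalone sketch of the guard this patch adds, written against a Hugging Face tokenizer directly instead of the AllenNLP indexer. The helper name check_wordpiece_budget, the example model "bert-base-cased", and the command-line demo are illustrative assumptions, not part of the patch; only the comparison against max_len_single_sentence and the log-then-exit behaviour mirror the change above.

import logging
import sys

from transformers import AutoTokenizer  # assumed available; any HF tokenizer works

logger = logging.getLogger(__name__)


def check_wordpiece_budget(words, tokenizer):
    """Return the wordpieces for `words`, or log and exit if the sentence is too long.

    Mirrors the guard in the patch: a sentence is rejected once its wordpiece
    count exceeds the model's single-sentence input limit.
    """
    # Tokenize each word separately and flatten, roughly what intra-word
    # tokenization does (special tokens are ignored here for simplicity).
    wordpieces = [wp for word in words for wp in tokenizer.tokenize(word)]
    if len(wordpieces) > tokenizer.max_len_single_sentence:
        # Same net effect as the patched indexer: report the offending
        # sentence prefix and both counts, then stop processing.
        logger.error(
            "Sentence has more wordpiece tokens than the model can process:\n"
            "%s ...\nMaximal input: %d\nCurrent input: %d",
            " ".join(words[:10]),
            tokenizer.max_len_single_sentence,
            len(wordpieces),
        )
        sys.exit(1)
    return wordpieces


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    # "bert-base-cased" is only an example; the indexer in the patch uses
    # whichever pretrained transformer it was configured with.
    tok = AutoTokenizer.from_pretrained("bert-base-cased")
    print(check_wordpiece_budget("A short sentence fits easily .".split(), tok))

Exiting the whole process with sys.exit(1) is the behaviour this patch chooses; a library-friendly alternative would be to let the ValueError propagate so the caller can skip or truncate the offending sentence.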