Commit 1abcab10 authored by martynawiacek

moved try/except to combo predictor

parent e060a26d
Merge request !41: Add try/catch clause for sentences with large number of wordpieces.
Pipeline #4333 passed in 6 minutes and 51 seconds
import logging
-import sys
from typing import Optional, Dict, Any, List, Tuple
from allennlp import data
from allennlp.data import token_indexers, tokenizers, IndexedTokenList, vocabulary
from overrides import overrides
logger = logging.getLogger(__name__)
@data.TokenIndexer.register("pretrained_transformer_mismatched_fixed")
class PretrainedTransformerMismatchedIndexer(token_indexers.PretrainedTransformerMismatchedIndexer):
@@ -35,33 +32,28 @@ class PretrainedTransformerMismatchedIndexer(token_indexers.PretrainedTransformerMismatchedIndexer):
        Method is overridden in order to raise an error when the number of tokens needed to embed a sentence
        exceeds the maximal input length of the model.
        """
-        try:
-            self._matched_indexer._add_encoding_to_vocabulary_if_needed(vocabulary)
+        self._matched_indexer._add_encoding_to_vocabulary_if_needed(vocabulary)
-            wordpieces, offsets = self._allennlp_tokenizer.intra_word_tokenize(
-                [t.ensure_text() for t in tokens])
+        wordpieces, offsets = self._allennlp_tokenizer.intra_word_tokenize(
+            [t.ensure_text() for t in tokens])
-            if len(wordpieces) > self._tokenizer.max_len_single_sentence:
-                raise ValueError("The following sentence consists of more wordpiece tokens than the model can process:\n" +
-                                 " ".join([str(x) for x in tokens[:10]]) + " ... \n" +
-                                 f"Maximal input: {self._tokenizer.max_len_single_sentence}\n" +
-                                 f"Current input: {len(wordpieces)}")
+        if len(wordpieces) > self._tokenizer.max_len_single_sentence:
+            raise ValueError("The following sentence consists of more wordpiece tokens than the model can process:\n" +
+                             " ".join([str(x) for x in tokens[:10]]) + " ... \n" +
+                             f"Maximal input: {self._tokenizer.max_len_single_sentence}\n" +
+                             f"Current input: {len(wordpieces)}")
-            offsets = [x if x is not None else (-1, -1) for x in offsets]
+        offsets = [x if x is not None else (-1, -1) for x in offsets]
-            output: IndexedTokenList = {
-                "token_ids": [t.text_id for t in wordpieces],
-                "mask": [True] * len(tokens),  # for original tokens (i.e. word-level)
-                "type_ids": [t.type_id for t in wordpieces],
-                "offsets": offsets,
-                "wordpiece_mask": [True] * len(wordpieces),  # for wordpieces (i.e. subword-level)
-            }
+        output: IndexedTokenList = {
+            "token_ids": [t.text_id for t in wordpieces],
+            "mask": [True] * len(tokens),  # for original tokens (i.e. word-level)
+            "type_ids": [t.type_id for t in wordpieces],
+            "offsets": offsets,
+            "wordpiece_mask": [True] * len(wordpieces),  # for wordpieces (i.e. subword-level)
+        }
-            return self._matched_indexer._postprocess_output(output)
-        except ValueError as value_error:
-            logger.error(value_error)
-            sys.exit(1)
+        return self._matched_indexer._postprocess_output(output)
class PretrainedTransformerIndexer(token_indexers.PretrainedTransformerIndexer):
...
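
For context on the guard above: it compares the sentence's wordpiece count against the backbone tokenizer's max_len_single_sentence before indexing, instead of letting the model fail later on an oversized input. A minimal standalone sketch of the same check, assuming a HuggingFace transformers tokenizer (the checkpoint name is illustrative, not the one COMBO is configured with):

from transformers import AutoTokenizer

# Illustrative checkpoint; any HuggingFace tokenizer exposes the same limit.
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

# A sentence that splits into far more wordpieces than the model accepts.
wordpieces = tokenizer.tokenize(" ".join(["electroencephalography"] * 200))

# max_len_single_sentence is model_max_length minus the special tokens
# (e.g. [CLS] and [SEP]) wrapped around a single sentence.
if len(wordpieces) > tokenizer.max_len_single_sentence:
    raise ValueError(
        f"Sentence consists of {len(wordpieces)} wordpiece tokens; "
        f"maximal input is {tokenizer.max_len_single_sentence}."
    )
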
import logging
import os
+import sys
from typing import List, Union, Dict, Any
import numpy as np
@@ -48,7 +49,12 @@ class COMBO(predictor.Predictor):
        :param sentence: sentence(s) representation
        :return: Sentence or List[Sentence] depending on the input
        """
-        return self.predict(sentence)
+        try:
+            return self.predict(sentence)
+        except Exception as e:
+            logger.error(e)
+            logger.error('Exiting.')
+            sys.exit(1)
    def predict(self, sentence: Union[str, List[str], List[List[str]], List[data.Sentence]]):
        if isinstance(sentence, str):
...
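
With the handler moved up to the predictor, any exception raised during prediction, including the wordpiece-limit ValueError from the indexer above, is logged once followed by 'Exiting.', and the process terminates with status 1 rather than crashing with a traceback inside the data pipeline. A rough usage sketch of how this surfaces to a caller; the from_pretrained helper and model name follow COMBO's README and are assumptions here, not part of this diff:

from combo.predict import COMBO

# Pre-trained model name as listed in COMBO's docs; any released model works.
nlp = COMBO.from_pretrained("polish-herbert-base")

# Long enough to overflow the transformer's wordpiece budget.
too_long = " ".join(["token"] * 5000)

# __call__ delegates to predict(); after this commit a failure here is
# logged (error message, then 'Exiting.') and the process exits with 1.
sentence = nlp(too_long)
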