diff --git a/combo/data/token_indexers/pretrained_transformer_mismatched_indexer.py b/combo/data/token_indexers/pretrained_transformer_mismatched_indexer.py
index 1b8610315c552c01612d0b6d728ee5a96b81d55b..fc29896a2ecbb5408c9367f2d8ad5b0c1d4a5d4c 100644
--- a/combo/data/token_indexers/pretrained_transformer_mismatched_indexer.py
+++ b/combo/data/token_indexers/pretrained_transformer_mismatched_indexer.py
@@ -1,12 +1,9 @@
-import logging
-import sys
 from typing import Optional, Dict, Any, List, Tuple
 
 from allennlp import data
 from allennlp.data import token_indexers, tokenizers, IndexedTokenList, vocabulary
 from overrides import overrides
 
-logger = logging.getLogger(__name__)
 
 @data.TokenIndexer.register("pretrained_transformer_mismatched_fixed")
 class PretrainedTransformerMismatchedIndexer(token_indexers.PretrainedTransformerMismatchedIndexer):
@@ -35,33 +32,28 @@ class PretrainedTransformerMismatchedIndexer(token_indexers.PretrainedTransforme
         Method is overridden in order to raise an error while the number of tokens needed
         to embed a sentence exceeds the maximal input of a model.
         """
-        try:
-            self._matched_indexer._add_encoding_to_vocabulary_if_needed(vocabulary)
+        self._matched_indexer._add_encoding_to_vocabulary_if_needed(vocabulary)
 
-            wordpieces, offsets = self._allennlp_tokenizer.intra_word_tokenize(
-                [t.ensure_text() for t in tokens])
+        wordpieces, offsets = self._allennlp_tokenizer.intra_word_tokenize(
+            [t.ensure_text() for t in tokens])
 
-            if len(wordpieces) > self._tokenizer.max_len_single_sentence:
-                raise ValueError("Following sentence consists of more wordpiece tokens that the model can process:\n" +\
-                                 " ".join([str(x) for x in tokens[:10]]) + " ... \n" + \
-                                 f"Maximal input: {self._tokenizer.max_len_single_sentence}\n"+ \
-                                 f"Current input: {len(wordpieces)}")
+        if len(wordpieces) > self._tokenizer.max_len_single_sentence:
+            raise ValueError("Following sentence consists of more wordpiece tokens that the model can process:\n" +\
+                             " ".join([str(x) for x in tokens[:10]]) + " ... \n" + \
+                             f"Maximal input: {self._tokenizer.max_len_single_sentence}\n"+ \
+                             f"Current input: {len(wordpieces)}")
 
-            offsets = [x if x is not None else (-1, -1) for x in offsets]
+        offsets = [x if x is not None else (-1, -1) for x in offsets]
 
-            output: IndexedTokenList = {
-                "token_ids": [t.text_id for t in wordpieces],
-                "mask": [True] * len(tokens),  # for original tokens (i.e. word-level)
-                "type_ids": [t.type_id for t in wordpieces],
-                "offsets": offsets,
-                "wordpiece_mask": [True] * len(wordpieces),  # for wordpieces (i.e. subword-level)
-            }
+        output: IndexedTokenList = {
+            "token_ids": [t.text_id for t in wordpieces],
+            "mask": [True] * len(tokens),  # for original tokens (i.e. word-level)
+            "type_ids": [t.type_id for t in wordpieces],
+            "offsets": offsets,
+            "wordpiece_mask": [True] * len(wordpieces),  # for wordpieces (i.e. subword-level)
+        }
 
-            return self._matched_indexer._postprocess_output(output)
-
-        except ValueError as value_error:
-            logger.error(value_error)
-            sys.exit(1)
+        return self._matched_indexer._postprocess_output(output)
 
 
 class PretrainedTransformerIndexer(token_indexers.PretrainedTransformerIndexer):
diff --git a/combo/predict.py b/combo/predict.py
index 01a083727e7768953ba952c53038cc5156adf612..83b030ff41a5026672a2e555115698980d00de77 100644
--- a/combo/predict.py
+++ b/combo/predict.py
@@ -1,5 +1,6 @@
 import logging
 import os
+import sys
 from typing import List, Union, Dict, Any
 
 import numpy as np
@@ -48,7 +49,12 @@ class COMBO(predictor.Predictor):
         :param sentence: sentence(s) representation
         :return: Sentence or List[Sentence] depending on the input
         """
-        return self.predict(sentence)
+        try:
+            return self.predict(sentence)
+        except Exception as e:
+            logger.error(e)
+            logger.error('Exiting.')
+            sys.exit(1)
 
     def predict(self, sentence: Union[str, List[str], List[List[str]], List[data.Sentence]]):
         if isinstance(sentence, str):
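
Net effect of the two hunks: the mismatched indexer now raises a plain ValueError when a sentence needs more wordpieces than the model accepts, and COMBO.__call__ becomes the single place that logs the error and exits, so library callers of predict() can handle the exception themselves. A minimal sketch of how this surfaces to a caller follows; COMBO.from_pretrained and the "polish-herbert-base" model name are assumed from the package's typical usage and are illustrative, not part of this patch.

# Sketch only: illustrates the error-handling split introduced by this diff.
# `COMBO.from_pretrained` and the model name are assumptions based on typical
# usage of the combo package; the patch itself does not add or change them.
from combo import predict

nlp = predict.COMBO.from_pretrained("polish-herbert-base")

over_long = "word " * 5000  # far more wordpieces than the encoder accepts

# predict() now lets the indexer's ValueError propagate to the caller ...
try:
    nlp.predict(over_long)
except ValueError as error:
    print(error)

# ... while invoking the predictor directly logs the error and terminates the
# process via sys.exit(1), as added in COMBO.__call__ above.
nlp(over_long)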