From e060a26d6109783e8e15ab1b2482026a54f8c930 Mon Sep 17 00:00:00 2001
From: martynawiacek <Hkkm6072>
Date: Thu, 20 Jan 2022 17:28:20 +0100
Subject: [PATCH 1/2] Add try/except block for sentences with a large number
 of wordpieces.

---
 ...etrained_transformer_mismatched_indexer.py | 44 +++++++++++--------
 1 file changed, 25 insertions(+), 19 deletions(-)
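
Note (commentary below the "---" cut, not part of the applied diff): a minimal
standalone sketch of the length guard this patch introduces, assuming the
HuggingFace transformers package is available; the function and model names
here are illustrative only, and the real indexer applies the same check to the
output of AllenNLP's intra_word_tokenize.

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

    def check_wordpiece_limit(tokens):
        # Tokenize each word into wordpieces, roughly as the mismatched
        # indexer does internally.
        wordpieces = [wp for token in tokens for wp in tokenizer.tokenize(token)]
        # max_len_single_sentence is the model limit minus special tokens.
        if len(wordpieces) > tokenizer.max_len_single_sentence:
            raise ValueError(
                "The following sentence consists of more wordpiece tokens "
                "than the model can process:\n"
                + " ".join(str(x) for x in tokens[:10]) + " ...\n"
                + f"Maximum input: {tokenizer.max_len_single_sentence}\n"
                + f"Current input: {len(wordpieces)}"
            )
        return wordpieces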

diff --git a/combo/data/token_indexers/pretrained_transformer_mismatched_indexer.py b/combo/data/token_indexers/pretrained_transformer_mismatched_indexer.py
index b9a4e3c..1b86103 100644
--- a/combo/data/token_indexers/pretrained_transformer_mismatched_indexer.py
+++ b/combo/data/token_indexers/pretrained_transformer_mismatched_indexer.py
@@ -1,11 +1,12 @@
+import logging
+import sys
 from typing import Optional, Dict, Any, List, Tuple
 
 from allennlp import data
 from allennlp.data import token_indexers, tokenizers, IndexedTokenList, vocabulary
 from overrides import overrides
 
-from typing import List
-
+logger = logging.getLogger(__name__)
 
 @data.TokenIndexer.register("pretrained_transformer_mismatched_fixed")
 class PretrainedTransformerMismatchedIndexer(token_indexers.PretrainedTransformerMismatchedIndexer):
@@ -34,28 +35,33 @@ class PretrainedTransformerMismatchedIndexer(token_indexers.PretrainedTransforme
         Method is overridden in order to raise an error while the number of tokens needed to embed a sentence exceeds the
         maximal input of a model.
         """
-        self._matched_indexer._add_encoding_to_vocabulary_if_needed(vocabulary)
+        try:
+            self._matched_indexer._add_encoding_to_vocabulary_if_needed(vocabulary)
+
+            wordpieces, offsets = self._allennlp_tokenizer.intra_word_tokenize(
+                [t.ensure_text() for t in tokens])
 
-        wordpieces, offsets = self._allennlp_tokenizer.intra_word_tokenize(
-            [t.ensure_text() for t in tokens])
+            if len(wordpieces) > self._tokenizer.max_len_single_sentence:
+                raise ValueError("The following sentence consists of more wordpiece tokens than the model can process:\n" +\
+                                 " ".join([str(x) for x in tokens[:10]]) + " ... \n" + \
+                                 f"Maximal input: {self._tokenizer.max_len_single_sentence}\n"+ \
+                                 f"Current input: {len(wordpieces)}")
 
-        if len(wordpieces) > self._tokenizer.max_len_single_sentence:
-            raise ValueError("Following sentence consists of more wordpiece tokens that the model can process:\n" +\
-                             " ".join([str(x) for x in tokens[:10]]) + " ... \n" + \
-                             f"Maximal input: {self._tokenizer.max_len_single_sentence}\n"+ \
-                             f"Current input: {len(wordpieces)}")
+            offsets = [x if x is not None else (-1, -1) for x in offsets]
 
-        offsets = [x if x is not None else (-1, -1) for x in offsets]
+            output: IndexedTokenList = {
+                "token_ids": [t.text_id for t in wordpieces],
+                "mask": [True] * len(tokens),  # for original tokens (i.e. word-level)
+                "type_ids": [t.type_id for t in wordpieces],
+                "offsets": offsets,
+                "wordpiece_mask": [True] * len(wordpieces),  # for wordpieces (i.e. subword-level)
+            }
 
-        output: IndexedTokenList = {
-            "token_ids": [t.text_id for t in wordpieces],
-            "mask": [True] * len(tokens),  # for original tokens (i.e. word-level)
-            "type_ids": [t.type_id for t in wordpieces],
-            "offsets": offsets,
-            "wordpiece_mask": [True] * len(wordpieces),  # for wordpieces (i.e. subword-level)
-        }
+            return self._matched_indexer._postprocess_output(output)
 
-        return self._matched_indexer._postprocess_output(output)
+        except ValueError as value_error:
+            logger.error(value_error)
+            sys.exit(1)
 
 
 class PretrainedTransformerIndexer(token_indexers.PretrainedTransformerIndexer):
-- 
GitLab


From 1abcab109ac5549ce8508e9a7a30d9b448bae603 Mon Sep 17 00:00:00 2001
From: martynawiacek <Hkkm6072>
Date: Fri, 21 Jan 2022 09:58:14 +0100
Subject: [PATCH 2/2] Move try/except to the COMBO predictor

---
 ...etrained_transformer_mismatched_indexer.py | 42 ++++++++-----------
 combo/predict.py                              |  8 +++-
 2 files changed, 24 insertions(+), 26 deletions(-)
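
Note (commentary below the "---" cut, not part of the applied diff): a small
sketch of the error-handling flow this patch moves into the predictor,
assuming only the standard library; the wrapper name is illustrative, and in
the actual change the try/except lives inside COMBO.__call__ around
self.predict(sentence).

    import logging
    import sys

    logger = logging.getLogger(__name__)

    def call_predictor(predict_fn, sentence):
        # Mirror of the new COMBO.__call__ behaviour: any exception raised
        # while predicting (e.g. the ValueError for over-long sentences) is
        # logged and the process exits with a non-zero status instead of
        # surfacing a raw traceback.
        try:
            return predict_fn(sentence)
        except Exception as error:
            logger.error(error)
            logger.error("Exiting.")
            sys.exit(1)

    # Example (hypothetical call): call_predictor(combo.predict, "A very long sentence ...")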

diff --git a/combo/data/token_indexers/pretrained_transformer_mismatched_indexer.py b/combo/data/token_indexers/pretrained_transformer_mismatched_indexer.py
index 1b86103..fc29896 100644
--- a/combo/data/token_indexers/pretrained_transformer_mismatched_indexer.py
+++ b/combo/data/token_indexers/pretrained_transformer_mismatched_indexer.py
@@ -1,12 +1,9 @@
-import logging
-import sys
 from typing import Optional, Dict, Any, List, Tuple
 
 from allennlp import data
 from allennlp.data import token_indexers, tokenizers, IndexedTokenList, vocabulary
 from overrides import overrides
 
-logger = logging.getLogger(__name__)
 
 @data.TokenIndexer.register("pretrained_transformer_mismatched_fixed")
 class PretrainedTransformerMismatchedIndexer(token_indexers.PretrainedTransformerMismatchedIndexer):
@@ -35,33 +32,28 @@ class PretrainedTransformerMismatchedIndexer(token_indexers.PretrainedTransforme
         Method is overridden in order to raise an error while the number of tokens needed to embed a sentence exceeds the
         maximal input of a model.
         """
-        try:
-            self._matched_indexer._add_encoding_to_vocabulary_if_needed(vocabulary)
+        self._matched_indexer._add_encoding_to_vocabulary_if_needed(vocabulary)
 
-            wordpieces, offsets = self._allennlp_tokenizer.intra_word_tokenize(
-                [t.ensure_text() for t in tokens])
+        wordpieces, offsets = self._allennlp_tokenizer.intra_word_tokenize(
+            [t.ensure_text() for t in tokens])
 
-            if len(wordpieces) > self._tokenizer.max_len_single_sentence:
-                raise ValueError("The following sentence consists of more wordpiece tokens than the model can process:\n" +\
-                                 " ".join([str(x) for x in tokens[:10]]) + " ... \n" + \
-                                 f"Maximal input: {self._tokenizer.max_len_single_sentence}\n"+ \
-                                 f"Current input: {len(wordpieces)}")
+        if len(wordpieces) > self._tokenizer.max_len_single_sentence:
+            raise ValueError("The following sentence consists of more wordpiece tokens than the model can process:\n" +\
+                             " ".join([str(x) for x in tokens[:10]]) + " ... \n" + \
+                             f"Maximal input: {self._tokenizer.max_len_single_sentence}\n"+ \
+                             f"Current input: {len(wordpieces)}")
 
-            offsets = [x if x is not None else (-1, -1) for x in offsets]
+        offsets = [x if x is not None else (-1, -1) for x in offsets]
 
-            output: IndexedTokenList = {
-                "token_ids": [t.text_id for t in wordpieces],
-                "mask": [True] * len(tokens),  # for original tokens (i.e. word-level)
-                "type_ids": [t.type_id for t in wordpieces],
-                "offsets": offsets,
-                "wordpiece_mask": [True] * len(wordpieces),  # for wordpieces (i.e. subword-level)
-            }
+        output: IndexedTokenList = {
+            "token_ids": [t.text_id for t in wordpieces],
+            "mask": [True] * len(tokens),  # for original tokens (i.e. word-level)
+            "type_ids": [t.type_id for t in wordpieces],
+            "offsets": offsets,
+            "wordpiece_mask": [True] * len(wordpieces),  # for wordpieces (i.e. subword-level)
+        }
 
-            return self._matched_indexer._postprocess_output(output)
-
-        except ValueError as value_error:
-            logger.error(value_error)
-            sys.exit(1)
+        return self._matched_indexer._postprocess_output(output)
 
 
 class PretrainedTransformerIndexer(token_indexers.PretrainedTransformerIndexer):
diff --git a/combo/predict.py b/combo/predict.py
index 01a0837..83b030f 100644
--- a/combo/predict.py
+++ b/combo/predict.py
@@ -1,5 +1,6 @@
 import logging
 import os
+import sys
 from typing import List, Union, Dict, Any
 
 import numpy as np
@@ -48,7 +49,12 @@ class COMBO(predictor.Predictor):
         :param sentence: sentence(s) representation
         :return: Sentence or List[Sentence] depending on the input
         """
-        return self.predict(sentence)
+        try:
+            return self.predict(sentence)
+        except Exception as e:
+            logger.error(e)
+            logger.error('Exiting.')
+            sys.exit(1)
 
     def predict(self, sentence: Union[str, List[str], List[List[str]], List[data.Sentence]]):
         if isinstance(sentence, str):
-- 
GitLab