From e060a26d6109783e8e15ab1b2482026a54f8c930 Mon Sep 17 00:00:00 2001
From: martynawiacek <Hkkm6072>
Date: Thu, 20 Jan 2022 17:28:20 +0100
Subject: [PATCH] Add try/except block for sentences with a large number of
 wordpieces.

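With this change, tokens_to_indices() logs the offending sentence and exits
instead of failing later inside the transformer. A minimal sketch of how the
indexer is exercised (the model name and sentence below are only examples,
not part of this patch):

    from allennlp.data import Token, Vocabulary
    from combo.data.token_indexers.pretrained_transformer_mismatched_indexer import (
        PretrainedTransformerMismatchedIndexer,
    )

    # "bert-base-cased" is an example model name, not a project default.
    indexer = PretrainedTransformerMismatchedIndexer(model_name="bert-base-cased")
    tokens = [Token(w) for w in "An example sentence .".split()]
    # Logs an error and exits if the sentence needs more wordpieces than the
    # model's max_len_single_sentence; otherwise returns the indexed
    # wordpieces, offsets and masks.
    indexed = indexer.tokens_to_indices(tokens, Vocabulary())
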
---
 ...etrained_transformer_mismatched_indexer.py | 44 +++++++++++--------
 1 file changed, 25 insertions(+), 19 deletions(-)

diff --git a/combo/data/token_indexers/pretrained_transformer_mismatched_indexer.py b/combo/data/token_indexers/pretrained_transformer_mismatched_indexer.py
index b9a4e3c..1b86103 100644
--- a/combo/data/token_indexers/pretrained_transformer_mismatched_indexer.py
+++ b/combo/data/token_indexers/pretrained_transformer_mismatched_indexer.py
@@ -1,11 +1,12 @@
+import logging
+import sys
 from typing import Optional, Dict, Any, List, Tuple
 
 from allennlp import data
 from allennlp.data import token_indexers, tokenizers, IndexedTokenList, vocabulary
 from overrides import overrides
 
-from typing import List
-
+logger = logging.getLogger(__name__)
 
 @data.TokenIndexer.register("pretrained_transformer_mismatched_fixed")
 class PretrainedTransformerMismatchedIndexer(token_indexers.PretrainedTransformerMismatchedIndexer):
@@ -34,28 +35,33 @@ class PretrainedTransformerMismatchedIndexer(token_indexers.PretrainedTransforme
         Method is overridden in order to raise an error while the number of tokens needed to embed a sentence exceeds the
         maximal input of a model.
         """
-        self._matched_indexer._add_encoding_to_vocabulary_if_needed(vocabulary)
+        try:
+            self._matched_indexer._add_encoding_to_vocabulary_if_needed(vocabulary)
+
+            wordpieces, offsets = self._allennlp_tokenizer.intra_word_tokenize(
+                [t.ensure_text() for t in tokens])
 
-        wordpieces, offsets = self._allennlp_tokenizer.intra_word_tokenize(
-            [t.ensure_text() for t in tokens])
+            if len(wordpieces) > self._tokenizer.max_len_single_sentence:
+                raise ValueError("The following sentence consists of more wordpiece tokens than the model can process:\n" +
+                                 " ".join([str(x) for x in tokens[:10]]) + " ... \n" +
+                                 f"Maximal input: {self._tokenizer.max_len_single_sentence}\n" +
+                                 f"Current input: {len(wordpieces)}")
 
-        if len(wordpieces) > self._tokenizer.max_len_single_sentence:
-            raise ValueError("Following sentence consists of more wordpiece tokens that the model can process:\n" +\
-                             " ".join([str(x) for x in tokens[:10]]) + " ... \n" + \
-                             f"Maximal input: {self._tokenizer.max_len_single_sentence}\n"+ \
-                             f"Current input: {len(wordpieces)}")
+            offsets = [x if x is not None else (-1, -1) for x in offsets]
 
-        offsets = [x if x is not None else (-1, -1) for x in offsets]
+            output: IndexedTokenList = {
+                "token_ids": [t.text_id for t in wordpieces],
+                "mask": [True] * len(tokens),  # for original tokens (i.e. word-level)
+                "type_ids": [t.type_id for t in wordpieces],
+                "offsets": offsets,
+                "wordpiece_mask": [True] * len(wordpieces),  # for wordpieces (i.e. subword-level)
+            }
 
-        output: IndexedTokenList = {
-            "token_ids": [t.text_id for t in wordpieces],
-            "mask": [True] * len(tokens),  # for original tokens (i.e. word-level)
-            "type_ids": [t.type_id for t in wordpieces],
-            "offsets": offsets,
-            "wordpiece_mask": [True] * len(wordpieces),  # for wordpieces (i.e. subword-level)
-        }
+            return self._matched_indexer._postprocess_output(output)
 
-        return self._matched_indexer._postprocess_output(output)
+        except ValueError as value_error:
+            logger.error(value_error)
+            sys.exit(1)
 
 
 class PretrainedTransformerIndexer(token_indexers.PretrainedTransformerIndexer):
-- 
GitLab