From 74ab884353ddec4b7458c3a56097da4acd3b2bd6 Mon Sep 17 00:00:00 2001
From: Maja Jablonska <majajjablonska@gmail.com>
Date: Sat, 18 Nov 2023 22:36:31 +1100
Subject: [PATCH] Add token list splitting for longer sentences in
 pretrained_transformer_fixed_mismatched_indexer.py
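
When a sentence yields more wordpieces than the model accepts, the
indexer now warns instead of raising and splits the token list into
chunks of at most max_length tokens, re-tokenizing each chunk and
shifting its offsets so they continue from the previous chunk.

A minimal usage sketch (hypothetical sentence; the model name, the small
max_length and the fixture path only mirror the new test and assume the
repository root as working directory, they are not recommended settings):

    from combo.data.tokenizers import Token
    from combo.data.token_indexers import PretrainedTransformerFixedMismatchedIndexer
    from combo.data.vocabulary import Vocabulary

    # an indexer with a deliberately tiny limit, so chunking kicks in
    indexer = PretrainedTransformerFixedMismatchedIndexer(
        "allegro/herbert-base-cased", max_length=3)
    vocab = Vocabulary.from_files("tests/fixtures/train_vocabulary",
                                  oov_token="_", padding_token="__PAD__")
    tokens = [Token(t) for t in "Hello , my friend ! What a nice day !".split()]
    # sentences longer than max_length tokens are split into chunks internally
    output = indexer.tokens_to_indices(tokens, vocab)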

---
 ...ed_transformer_fixed_mismatched_indexer.py | 24 ++++++++++++---
 combo/main.py                                 |  8 +++--
 ...ed_transformer_fixed_mismatched_indexer.py | 29 +++++++++++++++++++
 3 files changed, 54 insertions(+), 7 deletions(-)
 create mode 100644 tests/data/token_indexers/test_pretrained_transformer_fixed_mismatched_indexer.py

diff --git a/combo/data/token_indexers/pretrained_transformer_fixed_mismatched_indexer.py b/combo/data/token_indexers/pretrained_transformer_fixed_mismatched_indexer.py
index 6fb6a20..0cdf5c0 100644
--- a/combo/data/token_indexers/pretrained_transformer_fixed_mismatched_indexer.py
+++ b/combo/data/token_indexers/pretrained_transformer_fixed_mismatched_indexer.py
@@ -2,7 +2,7 @@
 Adapted from COMBO
 Authors: Mateusz Klimaszewski, Lukasz Pszenny
 """
-
+import warnings
 from typing import Optional, Dict, Any
 
 from overrides import overrides
@@ -33,6 +33,9 @@ class PretrainedTransformerFixedMismatchedIndexer(PretrainedTransformerMismatche
         self._tokenizer = self._matched_indexer._tokenizer
         self._num_added_start_tokens = self._matched_indexer._num_added_start_tokens
         self._num_added_end_tokens = self._matched_indexer._num_added_end_tokens
+        self._max_length = max_length or self._tokenizer.max_len_single_sentence
+        if self._max_length <= 0:
+            raise ValueError('Maximum sentence length must be larger than 0.')
 
     @overrides
     def tokens_to_indices(self,
@@ -47,12 +50,25 @@ class PretrainedTransformerFixedMismatchedIndexer(PretrainedTransformerMismatche
         wordpieces, offsets = self._allennlp_tokenizer.intra_word_tokenize(
             [t.ensure_text() for t in tokens])
 
-        if len(wordpieces) > self._tokenizer.max_len_single_sentence:
-            raise ValueError("Following sentence consists of more wordpiece tokens that the model can process:\n" +\
+        if len(wordpieces) > self._max_length:
+            warnings.warn("Following sentence consists of more wordpiece tokens that the model can process:\n" +\
                              " ".join([str(x) for x in tokens[:10]]) + " ... \n" + \
-                             f"Maximal input: {self._tokenizer.max_len_single_sentence}\n"+ \
+                             f"Maximal input: {self._max_length}\n"+ \
                              f"Current input: {len(wordpieces)}")
 
+            tokens_chunk_len = self._max_length  # maximum number of tokens per chunk
+            start_ind = 0
+            wordpieces = []
+            offsets = []
+
+            while start_ind < len(tokens):
+                tokens_chunk = tokens[start_ind:start_ind + tokens_chunk_len]
+                _wordpieces, _offsets = self._allennlp_tokenizer.intra_word_tokenize([t.ensure_text() for t in tokens_chunk])
+                wordpieces += _wordpieces
+                offset_len = offsets[-1][1] if start_ind > 0 else 0  # shift so offsets continue from the previous chunk
+                offsets += [(o[0]+offset_len, o[1]+offset_len) for o in _offsets]
+                start_ind += tokens_chunk_len
+
         offsets = [x if x is not None else (-1, -1) for x in offsets]
 
         output: IndexedTokenList = {
diff --git a/combo/main.py b/combo/main.py
index 974b70d..250ac5a 100755
--- a/combo/main.py
+++ b/combo/main.py
@@ -152,7 +152,7 @@ def get_defaults(dataset_reader: Optional[DatasetReader],
 
     if FLAGS.validation_data_path and not validation_data_loader:
         validation_data_loader = default_data_loader(dataset_reader, validation_data_path)
-    else:
+    elif FLAGS.validation_data_path and validation_data_loader:
         if validation_data_path:
             validation_data_loader.data_path = validation_data_path
         else:
@@ -297,8 +297,10 @@ def run(_):
 
         logger.info('Indexing training data loader', prefix=prefix)
         training_data_loader.index_with(model.vocab)
-        logger.info('Indexing validation data loader', prefix=prefix)
-        validation_data_loader.index_with(model.vocab)
+
+        if validation_data_loader:
+            logger.info('Indexing validation data loader', prefix=prefix)
+            validation_data_loader.index_with(model.vocab)
 
         nlp = TrainableCombo(model,
                              torch.optim.Adam,
diff --git a/tests/data/token_indexers/test_pretrained_transformer_fixed_mismatched_indexer.py b/tests/data/token_indexers/test_pretrained_transformer_fixed_mismatched_indexer.py
new file mode 100644
index 0000000..d2e3e05
--- /dev/null
+++ b/tests/data/token_indexers/test_pretrained_transformer_fixed_mismatched_indexer.py
@@ -0,0 +1,29 @@
+import unittest
+import os
+
+from combo.data.tokenizers import Token
+from combo.data.token_indexers import PretrainedTransformerFixedMismatchedIndexer
+from combo.data.vocabulary import Vocabulary
+
+
+class PretrainedTransformerFixedMismatchedIndexerTest(unittest.TestCase):
+    def setUp(self) -> None:
+        self.indexer = PretrainedTransformerFixedMismatchedIndexer("allegro/herbert-base-cased")
+        self.short_indexer = PretrainedTransformerFixedMismatchedIndexer("allegro/herbert-base-cased",
+                                                                         max_length=3)
+        self.vocabulary = Vocabulary.from_files(
+            os.path.join(os.getcwd(), '../../fixtures/train_vocabulary'),
+            oov_token='_',
+            padding_token='__PAD__'
+        )
+
+    def test_offsets(self):
+        output1 = self.indexer.tokens_to_indices([
+            Token('Hello'), Token(','), Token('my'), Token('friend'), Token('!'),
+            Token('What'), Token('a'), Token('nice'), Token('day'), Token('!')
+        ], self.vocabulary)
+        output2 = self.short_indexer.tokens_to_indices([
+            Token('Hello'), Token(','), Token('my'), Token('friend'), Token('!'),
+            Token('What'), Token('a'), Token('nice'), Token('day'), Token('!')
+        ], self.vocabulary)
+        self.assertListEqual(output1['offsets'], output2['offsets'])
-- 
GitLab