From 74ab884353ddec4b7458c3a56097da4acd3b2bd6 Mon Sep 17 00:00:00 2001
From: Maja Jablonska <majajjablonska@gmail.com>
Date: Sat, 18 Nov 2023 22:36:31 +1100
Subject: [PATCH] Add token list splitting for longer sentences in
 pretrained_transformer_fixed_mismatched_indexer.py

---
 ...ed_transformer_fixed_mismatched_indexer.py | 24 ++++++++++++---
 combo/main.py                                 |  8 +++--
 ...ed_transformer_fixed_mismatched_indexer.py | 29 +++++++++++++++++++
 3 files changed, 54 insertions(+), 7 deletions(-)
 create mode 100644 tests/data/token_indexers/test_pretrained_transformer_fixed_mismatched_indexer.py

diff --git a/combo/data/token_indexers/pretrained_transformer_fixed_mismatched_indexer.py b/combo/data/token_indexers/pretrained_transformer_fixed_mismatched_indexer.py
index 6fb6a20..0cdf5c0 100644
--- a/combo/data/token_indexers/pretrained_transformer_fixed_mismatched_indexer.py
+++ b/combo/data/token_indexers/pretrained_transformer_fixed_mismatched_indexer.py
@@ -2,7 +2,7 @@
 Adapted from COMBO
 Authors: Mateusz Klimaszewski, Lukasz Pszenny
 """
-
+import warnings
 from typing import Optional, Dict, Any
 
 from overrides import overrides
@@ -33,6 +33,9 @@ class PretrainedTransformerFixedMismatchedIndexer(PretrainedTransformerMismatche
         self._tokenizer = self._matched_indexer._tokenizer
         self._num_added_start_tokens = self._matched_indexer._num_added_start_tokens
         self._num_added_end_tokens = self._matched_indexer._num_added_end_tokens
+        self._max_length = max_length or self._tokenizer.max_len_single_sentence
+        if self._max_length <= 0:
+            raise ValueError('Maximum sentence length must be larger than 0.')
 
     @overrides
     def tokens_to_indices(self,
@@ -47,12 +50,25 @@ class PretrainedTransformerFixedMismatchedIndexer(PretrainedTransformerMismatche
         wordpieces, offsets = self._allennlp_tokenizer.intra_word_tokenize(
             [t.ensure_text() for t in tokens])
 
-        if len(wordpieces) > self._tokenizer.max_len_single_sentence:
-            raise ValueError("Following sentence consists of more wordpiece tokens that the model can process:\n" +\
+        if len(wordpieces) > self._max_length:
+            warnings.warn("Following sentence consists of more wordpiece tokens that the model can process:\n" +\
                 " ".join([str(x) for x in tokens[:10]]) + " ... \n" + \
-                f"Maximal input: {self._tokenizer.max_len_single_sentence}\n"+ \
+                f"Maximal input: {self._max_length}\n"+ \
                 f"Current input: {len(wordpieces)}")
 
+            tokens_chunk_len = self._max_length
+            start_ind = 0
+            wordpieces = []
+            offsets = []
+
+            while start_ind < len(tokens):
+                tokens_chunk = tokens[start_ind:min(start_ind+tokens_chunk_len, len(tokens))]
+                _wordpieces, _offsets = self._allennlp_tokenizer.intra_word_tokenize([t.ensure_text() for t in tokens_chunk])
+                wordpieces += _wordpieces
+                offset_len = offsets[-1][1] if start_ind > 0 else 0
+                offsets += [(o[0]+offset_len, o[1]+offset_len) for o in _offsets]
+                start_ind += tokens_chunk_len
+
         offsets = [x if x is not None else (-1, -1) for x in offsets]
 
         output: IndexedTokenList = {
diff --git a/combo/main.py b/combo/main.py
index 974b70d..250ac5a 100755
--- a/combo/main.py
+++ b/combo/main.py
@@ -152,7 +152,7 @@ def get_defaults(dataset_reader: Optional[DatasetReader],
 
     if FLAGS.validation_data_path and not validation_data_loader:
         validation_data_loader = default_data_loader(dataset_reader, validation_data_path)
-    else:
+    elif FLAGS.validation_data_path and validation_data_loader:
         if validation_data_path:
             validation_data_loader.data_path = validation_data_path
         else:
@@ -297,8 +297,10 @@ def run(_):
 
         logger.info('Indexing training data loader', prefix=prefix)
         training_data_loader.index_with(model.vocab)
-        logger.info('Indexing validation data loader', prefix=prefix)
-        validation_data_loader.index_with(model.vocab)
+
+        if validation_data_loader:
+            logger.info('Indexing validation data loader', prefix=prefix)
+            validation_data_loader.index_with(model.vocab)
 
         nlp = TrainableCombo(model,
                              torch.optim.Adam,
diff --git a/tests/data/token_indexers/test_pretrained_transformer_fixed_mismatched_indexer.py b/tests/data/token_indexers/test_pretrained_transformer_fixed_mismatched_indexer.py
new file mode 100644
index 0000000..d2e3e05
--- /dev/null
+++ b/tests/data/token_indexers/test_pretrained_transformer_fixed_mismatched_indexer.py
@@ -0,0 +1,29 @@
+import unittest
+import os
+
+from combo.data.tokenizers import Token
+from combo.data.token_indexers import PretrainedTransformerFixedMismatchedIndexer
+from combo.data.vocabulary import Vocabulary
+
+
+class TokenFeatsIndexerTest(unittest.TestCase):
+    def setUp(self) -> None:
+        self.indexer = PretrainedTransformerFixedMismatchedIndexer("allegro/herbert-base-cased")
+        self.short_indexer = PretrainedTransformerFixedMismatchedIndexer("allegro/herbert-base-cased",
+                                                                         max_length=3)
+        self.vocabulary = Vocabulary.from_files(
+            os.path.join(os.getcwd(), '../../fixtures/train_vocabulary'),
+            oov_token='_',
+            padding_token='__PAD__'
+        )
+
+    def test_offsets(self):
+        output1 = self.indexer.tokens_to_indices([
+            Token('Hello'), Token(','), Token('my'), Token('friend'), Token('!'),
+            Token('What'), Token('a'), Token('nice'), Token('day'), Token('!')
+        ], self.vocabulary)
+        output2 = self.short_indexer.tokens_to_indices([
+            Token('Hello'), Token(','), Token('my'), Token('friend'), Token('!'),
+            Token('What'), Token('a'), Token('nice'), Token('day'), Token('!')
+        ], self.vocabulary)
+        self.assertListEqual(output1['offsets'], output2['offsets'])
-- 
GitLab