From 447006ce0fac9102201abef1bedea52507967e99 Mon Sep 17 00:00:00 2001
From: Maja Jablonska <majajjablonska@gmail.com>
Date: Mon, 19 Feb 2024 16:45:31 +1100
Subject: [PATCH] Fix tokens and TokenList serialization

---
 combo/combo_model.py                          |  2 +-
 combo/data/api.py                             |  4 ++
 combo/data/fields/text_field.py               |  1 +
 combo/data/tokenizers/lambo_tokenizer.py      | 50 +------------------
 .../seq2seq_encoders/transformer_encoder.py   |  2 +-
 combo/predict.py                              | 16 +++++-
 requirements.txt                              |  1 +
 tests/data/tokenizers/test_lambo_tokenizer.py | 18 +++----
 tox.ini                                       |  2 +-
 9 files changed, 31 insertions(+), 65 deletions(-)

diff --git a/combo/combo_model.py b/combo/combo_model.py
index d87d514..162a3b1 100644
--- a/combo/combo_model.py
+++ b/combo/combo_model.py
@@ -24,7 +24,7 @@ from combo.nn.utils import get_text_field_mask
 from combo.predictors import Predictor
 from combo.utils import metrics
 from combo.utils import ConfigurationError
-from modules.seq2seq_encoders.transformer_encoder import TransformerEncoder
+from combo.modules.seq2seq_encoders.transformer_encoder import TransformerEncoder
 
 
 @Registry.register("semantic_multitask")
diff --git a/combo/data/api.py b/combo/data/api.py
index 240a8b3..1e0f6c2 100644
--- a/combo/data/api.py
+++ b/combo/data/api.py
@@ -38,6 +38,10 @@ class _TokenList(conllu.models.TokenList):
     def __repr__(self):
         return 'TokenList<' + ', '.join(token['text'] for token in self) + '>'
 
+    @overrides
+    def serialize(self) -> str:
+        return serialize_token_list(self)
+
 
 def sentence2conllu(sentence: Sentence, keep_semrel: bool = True) -> conllu.models.TokenList:
     tokens = []
diff --git a/combo/data/fields/text_field.py b/combo/data/fields/text_field.py
index ddd1c01..874442f 100644
--- a/combo/data/fields/text_field.py
+++ b/combo/data/fields/text_field.py
@@ -10,6 +10,7 @@ from copy import deepcopy
 from typing import Dict, List, Optional, Iterator
 import textwrap
 
+from spacy.tokens import Token as SpacyToken
 import torch
 
 # There are two levels of dictionaries here: the top level is for the *key*, which aligns
diff --git a/combo/data/tokenizers/lambo_tokenizer.py b/combo/data/tokenizers/lambo_tokenizer.py
index c187233..635ab5d 100644
--- a/combo/data/tokenizers/lambo_tokenizer.py
+++ b/combo/data/tokenizers/lambo_tokenizer.py
@@ -54,8 +54,7 @@ class LamboTokenizer(Tokenizer):
     def tokenize(self,
                  text: str,
                  split_level: Optional[str] = None,
-                 split_multiwords: Optional[bool] = None,
-                 multiwords: Optional[bool] = None) -> List[List[Token]]:
+                 split_multiwords: Optional[bool] = None) -> List[List[Token]]:
         """
         Simple tokenization - ignoring the sentence splits
         :param text:
@@ -105,53 +104,6 @@ class LamboTokenizer(Tokenizer):
 
         return tokens
 
-    def segment(self,
-                text: str,
-                turns: Optional[bool] = None,
-                split_multiwords: Optional[bool] = None) -> List[List[str]]:
-        """
-        Full segmentation - segment into sentences and return a list of strings.
-        :param text:
-        :param turns: segment into sentences by splitting on sentences or on turns. Default: sentences.
-        :param split_multiwords: split subwords into separate tokens (e.g. can't into ca, n't)
-        :return:
-        """
-
-        turns = turns if turns is not None else self.__default_split_level.upper() == "TURNS"
-        split_multiwords = split_multiwords if split_multiwords is not None else self.__default_split_multiwords
-
-        document = self.__tokenizer.segment(text)
-        sentences = []
-        sentence_tokens = []
-
-        for turn in document.turns:
-            if turns:
-                sentence_tokens = []
-            for sentence in turn.sentences:
-                _reset_idx()
-                if not turns:
-                    sentence_tokens = []
-                for token in sentence.tokens:
-                    if len(token.subwords) > 0 and split_multiwords:
-                        # @TODO this is a very dirty fix for Lambo model's shortcomings
-                        # I noticed that for longer words with multiwords it tends to remove the last letter in the last multiword
-                        # so this is a quick workaround to fix it
-
-                        # check if subwords in token.subwords are consistent with token.text
-                        if "".join(token.subwords) != token.text:
-                            fixed_subwords = fix_subwords(token)
-                            token.subwords = fixed_subwords
-                        # sentence_tokens.extend(_sentence_tokens(token, split_multiwords))
-                        # else:
-                    sentence_tokens.extend(_sentence_tokens(token, split_multiwords))
-                if not turns:
-                    sentences.append(sentence_tokens)
-            if turns:
-                sentences.append(sentence_tokens)
-
-        return sentences
-
-
 def fix_subwords(token: Token):
     fixed_subwords = []
     text_it = 0
diff --git a/combo/modules/seq2seq_encoders/transformer_encoder.py b/combo/modules/seq2seq_encoders/transformer_encoder.py
index a49100d..fc389b8 100644
--- a/combo/modules/seq2seq_encoders/transformer_encoder.py
+++ b/combo/modules/seq2seq_encoders/transformer_encoder.py
@@ -8,7 +8,7 @@ from combo.modules.encoder import _EncoderBase
 from combo.config.from_parameters import FromParameters, register_arguments
 
 # from modules.seq2seq_encoders.seq2seq_encoder import Seq2SeqEncoder
-from nn.utils import add_positional_features
+from combo.nn.utils import add_positional_features
 
 
 # from allennlp.modules.seq2seq_encoders.seq2seq_encoder import Seq2SeqEncoder
diff --git a/combo/predict.py b/combo/predict.py
index 84e9aed..3b8b81a 100644
--- a/combo/predict.py
+++ b/combo/predict.py
@@ -78,13 +78,24 @@ class COMBO(PredictorModule):
                  **kwargs):
         if isinstance(sentence, str):
             sentence = self.dataset_reader.tokenizer.tokenize(sentence, **kwargs)
+        elif isinstance(sentence, list):
+            if isinstance(sentence[0], str):
+                sentence = [[Token(idx=idx+1, text=t) for idx, t in enumerate(sentence)]]
+            elif isinstance(sentence[0], list):
+                if isinstance(sentence[0][0], str):
+                    sentence = [[Token(idx=idx+1, text=t) for idx, t in enumerate(subsentence)] for subsentence in sentence]
+                elif not isinstance(sentence[0][0], Token):
+                    raise ValueError("Passed sentence must be a list (or list of lists) of strings or Token classes")
+            elif not isinstance(sentence[0], Token) and not isinstance(sentence[0], data.Sentence):
+                raise ValueError("Passed sentence must be a list (or list of lists) of strings or Token classes")
 
         if isinstance(sentence, list):
             if len(sentence) == 0:
                 return []
             example = sentence[0]
             sentences = sentence
-            if isinstance(example, Token) or (isinstance(example, list) and isinstance(example[0], Token)):
+            if (isinstance(example, str) or isinstance(example, Token) or
+                    (isinstance(example, list) and isinstance(example[0], Token))):
                 result = []
                 sentences = [self._to_input_json(s) for s in sentences]
                 for sentences_batch in util.lazy_groups_of(sentences, self.batch_size):
@@ -145,7 +156,8 @@ class COMBO(PredictorModule):
         sentence = json_dict["sentence"]
         # TODO: tokenize EVERYTHING, even if a list is passed?
         if isinstance(sentence, str):
-            tokens = [t.text for t in self.tokenizer.tokenize(json_dict["sentence"])]
+            tokens = [sentence]
+            #tokens = [t.text for t in self.tokenizer.tokenize(json_dict["sentence"])]
         elif isinstance(sentence, list):
             tokens = sentence
         else:
diff --git a/requirements.txt b/requirements.txt
index 129424b..ab9b7d6 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -17,6 +17,7 @@ filelock~=3.9.0
 pandas~=2.1.3
 pytest~=7.2.2
 transformers~=4.27.3
+typing_extensions==4.5.0
 sacremoses~=0.0.53
 spacy==3.7.2
 urllib3==1.26.6
\ No newline at end of file
diff --git a/tests/data/tokenizers/test_lambo_tokenizer.py b/tests/data/tokenizers/test_lambo_tokenizer.py
index 5726c2c..a2eab44 100644
--- a/tests/data/tokenizers/test_lambo_tokenizer.py
+++ b/tests/data/tokenizers/test_lambo_tokenizer.py
@@ -13,20 +13,16 @@ class LamboTokenizerTest(unittest.TestCase):
         self.assertListEqual([t.text for t in sentences[0] + sentences[1]],
                              ['Hello', 'cats', '.', 'I', 'love', 'you'])
 
-    def test_segment_text(self):
-        tokens = self.lambo_tokenizer.segment('Hello cats. I love you.\n\nHi.')
-        self.assertListEqual(tokens,
-                             [['Hello', 'cats', '.'], ['I', 'love', 'you', '.'], ['Hi', '.']])
-
     def test_segment_text_with_turns(self):
-        tokens = self.lambo_tokenizer.segment('Hello cats. I love you.\n\nHi.', turns=True)
-        self.assertListEqual(tokens,
-                             [['Hello', 'cats', '.', 'I', 'love', 'you', '.'], ['Hi', '.']])
+        tokens = self.lambo_tokenizer.tokenize('Hello cats. I love you.\n\nHi.', split_level="TURN")
+        self.assertEqual(len(tokens), 2)
+        self.assertListEqual([t.text for t in tokens[0]],
+                             ['Hello', 'cats', '.', 'I', 'love', 'you', '.'])
 
     def test_segment_text_with_multiwords(self):
-        tokens = self.lambo_tokenizer.segment('I don\'t want a pizza.', split_multiwords=True)
-        self.assertListEqual(tokens,
-                             [['I', 'do', 'n\'t', 'want', 'a', 'pizza', '.']])
+        tokens = self.lambo_tokenizer.tokenize('I don\'t want a pizza.', split_multiwords=True)
+        self.assertListEqual([t.text for t in tokens[0]],
+                             ['I', 'do', 'n\'t', 'want', 'a', 'pizza', '.'])
 
     def test_segment_text_with_multiwords_without_splitting(self):
         tokens = self.lambo_tokenizer.tokenize('I don\'t want a pizza.', split_multiwords=False)
diff --git a/tox.ini b/tox.ini
index 12e6382..e47d581 100644
--- a/tox.ini
+++ b/tox.ini
@@ -1,5 +1,5 @@
 [tox]
-envlist = py38
+envlist = py39
 
 [testenv]
 commands = pytest {posargs}
\ No newline at end of file
-- 
GitLab