From 447006ce0fac9102201abef1bedea52507967e99 Mon Sep 17 00:00:00 2001
From: Maja Jablonska <majajjablonska@gmail.com>
Date: Mon, 19 Feb 2024 16:45:31 +1100
Subject: [PATCH] Fix tokens and TokenList serialization
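
Override serialize() on _TokenList so it delegates to
serialize_token_list, and fix the absolute imports in combo_model.py
and transformer_encoder.py (modules.* / nn.* -> combo.modules.* /
combo.nn.*). LamboTokenizer drops the unused `multiwords` argument and
the separate segment() method; sentence/turn segmentation now goes
through tokenize() with split_level, as the updated tests show. The
COMBO predictor additionally accepts raw strings, lists of strings and
lists of lists of strings, wrapping them into Token objects before
batching. typing_extensions is pinned and tox moves from py38 to py39.

A minimal usage sketch of the extended predictor input handling. It is
illustrative only: it assumes `nlp` is an already-loaded COMBO
predictor and that calling it invokes the prediction entry point
touched here; neither the loading step nor the example sentences are
part of this patch.

    from combo.predict import COMBO

    nlp = ...  # load a COMBO predictor (loading is not covered here)

    # plain string: run through the dataset reader's tokenizer
    nlp("Hello cats. I love you.")
    # pre-tokenized sentence: wrapped into Token objects with 1-based idx
    nlp(["Hello", "cats", "."])
    # several pre-tokenized sentences: one Token list per sub-sentence
    nlp([["Hello", "cats", "."], ["I", "love", "you", "."]])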

---
 combo/combo_model.py                          |  2 +-
 combo/data/api.py                             |  4 ++
 combo/data/fields/text_field.py               |  1 +
 combo/data/tokenizers/lambo_tokenizer.py      | 50 +------------------
 .../seq2seq_encoders/transformer_encoder.py   |  2 +-
 combo/predict.py                              | 16 +++++-
 requirements.txt                              |  1 +
 tests/data/tokenizers/test_lambo_tokenizer.py | 18 +++----
 tox.ini                                       |  2 +-
 9 files changed, 31 insertions(+), 65 deletions(-)

diff --git a/combo/combo_model.py b/combo/combo_model.py
index d87d514..162a3b1 100644
--- a/combo/combo_model.py
+++ b/combo/combo_model.py
@@ -24,7 +24,7 @@ from combo.nn.utils import get_text_field_mask
 from combo.predictors import Predictor
 from combo.utils import metrics
 from combo.utils import ConfigurationError
-from modules.seq2seq_encoders.transformer_encoder import TransformerEncoder
+from combo.modules.seq2seq_encoders.transformer_encoder import TransformerEncoder
 
 
 @Registry.register("semantic_multitask")
diff --git a/combo/data/api.py b/combo/data/api.py
index 240a8b3..1e0f6c2 100644
--- a/combo/data/api.py
+++ b/combo/data/api.py
@@ -38,6 +38,10 @@ class _TokenList(conllu.models.TokenList):
     def __repr__(self):
         return 'TokenList<' + ', '.join(token['text'] for token in self) + '>'
 
+    @overrides
+    def serialize(self) -> str:
+        return serialize_token_list(self)
+
 
 def sentence2conllu(sentence: Sentence, keep_semrel: bool = True) -> conllu.models.TokenList:
     tokens = []
diff --git a/combo/data/fields/text_field.py b/combo/data/fields/text_field.py
index ddd1c01..874442f 100644
--- a/combo/data/fields/text_field.py
+++ b/combo/data/fields/text_field.py
@@ -10,6 +10,7 @@ from copy import deepcopy
 from typing import Dict, List, Optional, Iterator
 import textwrap
 
+from spacy.tokens import Token as SpacyToken
 import torch
 
 # There are two levels of dictionaries here: the top level is for the *key*, which aligns
diff --git a/combo/data/tokenizers/lambo_tokenizer.py b/combo/data/tokenizers/lambo_tokenizer.py
index c187233..635ab5d 100644
--- a/combo/data/tokenizers/lambo_tokenizer.py
+++ b/combo/data/tokenizers/lambo_tokenizer.py
@@ -54,8 +54,7 @@ class LamboTokenizer(Tokenizer):
     def tokenize(self,
                  text: str,
                  split_level: Optional[str] = None,
-                 split_multiwords: Optional[bool] = None,
-                 multiwords: Optional[bool] = None) -> List[List[Token]]:
+                 split_multiwords: Optional[bool] = None) -> List[List[Token]]:
         """
         Simple tokenization - ignoring the sentence splits
         :param text:
@@ -105,53 +104,6 @@ class LamboTokenizer(Tokenizer):
 
         return tokens
 
-    def segment(self,
-                text: str,
-                turns: Optional[bool] = None,
-                split_multiwords: Optional[bool] = None) -> List[List[str]]:
-        """
-        Full segmentation - segment into sentences and return a list of strings.
-        :param text:
-        :param turns: segment into sentences by splitting on sentences or on turns. Default: sentences.
-        :param split_multiwords: split subwords into separate tokens (e.g. can't into ca, n't)
-        :return:
-        """
-
-        turns = turns if turns is not None else self.__default_split_level.upper() == "TURNS"
-        split_multiwords = split_multiwords if split_multiwords is not None else self.__default_split_multiwords
-
-        document = self.__tokenizer.segment(text)
-        sentences = []
-        sentence_tokens = []
-
-        for turn in document.turns:
-            if turns:
-                sentence_tokens = []
-            for sentence in turn.sentences:
-                _reset_idx()
-                if not turns:
-                    sentence_tokens = []
-                for token in sentence.tokens:
-                    if len(token.subwords) > 0 and split_multiwords:
-                        # @TODO this is a very dirty fix for Lambo model's shortcomings
-                        # I noticed that for longer words with multiwords it tends to remove the last letter in the last multiword
-                        # so this is a quick workaround to fix it
-
-                        # check if subwords in token.subwords are consistent with token.text
-                        if "".join(token.subwords) != token.text:
-                            fixed_subwords = fix_subwords(token)
-                            token.subwords = fixed_subwords
-                        # sentence_tokens.extend(_sentence_tokens(token, split_multiwords))
-                    # else:
-                    sentence_tokens.extend(_sentence_tokens(token, split_multiwords))
-                if not turns:
-                    sentences.append(sentence_tokens)
-            if turns:
-                sentences.append(sentence_tokens)
-
-        return sentences
-
-
 def fix_subwords(token: Token):
     fixed_subwords = []
     text_it = 0
diff --git a/combo/modules/seq2seq_encoders/transformer_encoder.py b/combo/modules/seq2seq_encoders/transformer_encoder.py
index a49100d..fc389b8 100644
--- a/combo/modules/seq2seq_encoders/transformer_encoder.py
+++ b/combo/modules/seq2seq_encoders/transformer_encoder.py
@@ -8,7 +8,7 @@ from combo.modules.encoder import _EncoderBase
 from combo.config.from_parameters import FromParameters, register_arguments
 
 # from modules.seq2seq_encoders.seq2seq_encoder import Seq2SeqEncoder
-from nn.utils import add_positional_features
+from combo.nn.utils import add_positional_features
 
 
 # from allennlp.modules.seq2seq_encoders.seq2seq_encoder import Seq2SeqEncoder
diff --git a/combo/predict.py b/combo/predict.py
index 84e9aed..3b8b81a 100644
--- a/combo/predict.py
+++ b/combo/predict.py
@@ -78,13 +78,24 @@ class COMBO(PredictorModule):
                 **kwargs):
         if isinstance(sentence, str):
             sentence = self.dataset_reader.tokenizer.tokenize(sentence, **kwargs)
+        elif isinstance(sentence, list) and len(sentence) > 0:
+            if isinstance(sentence[0], str):
+                sentence = [[Token(idx=idx+1, text=t) for idx, t in enumerate(sentence)]]
+            elif isinstance(sentence[0], list):
+                if isinstance(sentence[0][0], str):
+                    sentence = [[Token(idx=idx+1, text=t) for idx, t in enumerate(subsentence)] for subsentence in sentence]
+                elif not isinstance(sentence[0][0], Token):
+                    raise ValueError("Passed sentence must be a list (or list of lists) of strings or Token objects")
+            elif not isinstance(sentence[0], Token) and not isinstance(sentence[0], data.Sentence):
+                raise ValueError("Passed sentence must be a list (or list of lists) of strings or Token objects")
 
         if isinstance(sentence, list):
             if len(sentence) == 0:
                 return []
             example = sentence[0]
             sentences = sentence
-            if isinstance(example, Token) or (isinstance(example, list) and isinstance(example[0], Token)):
+            if (isinstance(example, str) or isinstance(example, Token) or
+                    (isinstance(example, list) and isinstance(example[0], Token))):
                 result = []
                 sentences = [self._to_input_json(s) for s in sentences]
                 for sentences_batch in util.lazy_groups_of(sentences, self.batch_size):
@@ -145,7 +156,8 @@ class COMBO(PredictorModule):
         sentence = json_dict["sentence"]
         # TODO: tokenize EVERYTHING, even if a list is passed?
         if isinstance(sentence, str):
-            tokens = [t.text for t in self.tokenizer.tokenize(json_dict["sentence"])]
+            tokens = [sentence]
+            #tokens = [t.text for t in self.tokenizer.tokenize(json_dict["sentence"])]
         elif isinstance(sentence, list):
             tokens = sentence
         else:
diff --git a/requirements.txt b/requirements.txt
index 129424b..ab9b7d6 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -17,6 +17,7 @@ filelock~=3.9.0
 pandas~=2.1.3
 pytest~=7.2.2
 transformers~=4.27.3
+typing_extensions==4.5.0
 sacremoses~=0.0.53
 spacy==3.7.2
 urllib3==1.26.6
\ No newline at end of file
diff --git a/tests/data/tokenizers/test_lambo_tokenizer.py b/tests/data/tokenizers/test_lambo_tokenizer.py
index 5726c2c..a2eab44 100644
--- a/tests/data/tokenizers/test_lambo_tokenizer.py
+++ b/tests/data/tokenizers/test_lambo_tokenizer.py
@@ -13,20 +13,16 @@ class LamboTokenizerTest(unittest.TestCase):
         self.assertListEqual([t.text for t in sentences[0] + sentences[1]],
                              ['Hello', 'cats', '.', 'I', 'love', 'you'])
 
-    def test_segment_text(self):
-        tokens = self.lambo_tokenizer.segment('Hello cats. I love you.\n\nHi.')
-        self.assertListEqual(tokens,
-                             [['Hello', 'cats', '.'], ['I', 'love', 'you', '.'], ['Hi', '.']])
-
     def test_segment_text_with_turns(self):
-        tokens = self.lambo_tokenizer.segment('Hello cats. I love you.\n\nHi.', turns=True)
-        self.assertListEqual(tokens,
-                             [['Hello', 'cats', '.', 'I', 'love', 'you', '.'], ['Hi', '.']])
+        tokens = self.lambo_tokenizer.tokenize('Hello cats. I love you.\n\nHi.', split_level="TURN")
+        self.assertEqual(len(tokens), 2)
+        self.assertListEqual([t.text for t in tokens[0]],
+                             ['Hello', 'cats', '.', 'I', 'love', 'you', '.'])
 
     def test_segment_text_with_multiwords(self):
-        tokens = self.lambo_tokenizer.segment('I don\'t want a pizza.', split_multiwords=True)
-        self.assertListEqual(tokens,
-                             [['I', 'do', 'n\'t', 'want', 'a', 'pizza', '.']])
+        tokens = self.lambo_tokenizer.tokenize('I don\'t want a pizza.', split_multiwords=True)
+        self.assertListEqual([t.text for t in tokens[0]],
+                             ['I', 'do', 'n\'t', 'want', 'a', 'pizza', '.'])
 
     def test_segment_text_with_multiwords_without_splitting(self):
         tokens = self.lambo_tokenizer.tokenize('I don\'t want a pizza.', split_multiwords=False)
diff --git a/tox.ini b/tox.ini
index 12e6382..e47d581 100644
--- a/tox.ini
+++ b/tox.ini
@@ -1,5 +1,5 @@
 [tox]
-envlist = py38
+envlist = py39
 
 [testenv]
 commands = pytest {posargs}
\ No newline at end of file
-- 
GitLab