From fed57c5488d8abd06480a785c7b96fa6064a7cad Mon Sep 17 00:00:00 2001
From: Maja Jablonska <majajjablonska@gmail.com>
Date: Wed, 22 Nov 2023 01:12:56 +1100
Subject: [PATCH] Correct LAMBO segmentation

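Segment input with LAMBO before prediction instead of predicting on
whitespace-joined sentences: add a segment() method to the Tokenizer
base class, have COMBO.predict() segment raw strings with the dataset
reader's tokenizer, run file-mode prediction as a single batched
predict() call, print tokens from every segmented sentence in
interactive mode, and read input files as UTF-8.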
---
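Notes for reviewers: a minimal sketch of the call flow after this patch.
The language value, the example text, and the already-constructed
predictor are placeholders, not part of the patch; model loading is
omitted.

    from combo.data.tokenizers.lambo_tokenizer import LamboTokenizer

    # Tokenizer.segment() returns sentences as lists of token strings.
    tokenizer = LamboTokenizer("English")  # illustrative language value
    sentences = tokenizer.segment("One sentence. Another sentence.")
    # e.g. [['One', 'sentence', '.'], ['Another', 'sentence', '.']]

    # COMBO.predict() now segments raw strings itself and returns one
    # prediction per sentence (predictor construction omitted here).
    predictions = predictor.predict("One sentence. Another sentence.")
    for prediction in predictions:
        for token in prediction.tokens:
            print(token.text, token.lemma, token.upostag)
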
 combo/data/tokenizers/lambo_tokenizer.py |  1 -
 combo/data/tokenizers/tokenizer.py       |  8 ++++++++
 combo/main.py                            | 15 +++++++++------
 combo/predict.py                         |  5 +++--
 4 files changed, 20 insertions(+), 9 deletions(-)

diff --git a/combo/data/tokenizers/lambo_tokenizer.py b/combo/data/tokenizers/lambo_tokenizer.py
index f37b076..c5f4451 100644
--- a/combo/data/tokenizers/lambo_tokenizer.py
+++ b/combo/data/tokenizers/lambo_tokenizer.py
@@ -49,4 +49,3 @@ class LamboTokenizer(Tokenizer):
                 sentences.append([t.text for t in sentence.tokens])
 
         return sentences
-
diff --git a/combo/data/tokenizers/tokenizer.py b/combo/data/tokenizers/tokenizer.py
index 7b9269c..163036e 100644
--- a/combo/data/tokenizers/tokenizer.py
+++ b/combo/data/tokenizers/tokenizer.py
@@ -72,3 +72,11 @@ class Tokenizer(FromParameters):
         Returns the number of special tokens added for a pair of sequences.
         """
         return 0
+
+    def segment(self, text: str) -> List[List[str]]:
+        """
+        Full segmentation: split raw text into sentences of tokens.
+        :param text: raw input text
+        :return: a list of sentences, each a list of token strings
+        """
+        return [[]]
diff --git a/combo/main.py b/combo/main.py
index dd4bdd2..a76db04 100755
--- a/combo/main.py
+++ b/combo/main.py
@@ -383,9 +383,13 @@ def run(_):
         if FLAGS.input_file == '-':
             print("Interactive mode.")
             sentence = input("Sentence: ")
-            prediction = predictor(sentence)
+            prediction = [p.tokens for p in predictor(sentence)]
+            # Flatten per-sentence token lists into one list of tokens
+            flattened_prediction = []
+            for p in prediction:
+                flattened_prediction.extend(p)
             print("{:15} {:15} {:10} {:10} {:10}".format('TOKEN', 'LEMMA', 'UPOS', 'HEAD', 'DEPREL'))
-            for token in prediction.tokens:
+            for token in flattened_prediction:
                 print("{:15} {:15} {:10} {:10} {:10}".format(token.text, token.lemma, token.upostag, token.head,
                                                              token.deprel))
         elif FLAGS.output_file:
@@ -410,14 +414,13 @@
 
             else:
                 tokenizer = LamboTokenizer(tokenizer_language)
-                with open(FLAGS.input_file, "r") as file:
+                with open(FLAGS.input_file, "r", encoding='utf-8') as file:
                     input_sentences = tokenizer.segment(file.read())
+                predictions = predictor.predict(input_sentences)
                 with open(FLAGS.output_file, "w") as file:
-                    for sentence in tqdm(input_sentences):
-                        prediction = predictor.predict(' '.join(sentence))
+                    for prediction in tqdm(predictions):
                         file.writelines(api.sentence2conllu(prediction,
                                                             keep_semrel=dataset_reader.use_sem).serialize())
-                        predictions.append(prediction)
 
             if FLAGS.save_matrices:
                 logger.info("Saving matrices", prefix=prefix)
diff --git a/combo/predict.py b/combo/predict.py
index 4450d2f..8363e50 100644
--- a/combo/predict.py
+++ b/combo/predict.py
@@ -74,8 +74,9 @@ class COMBO(PredictorModule):
 
     def predict(self, sentence: Union[str, List[str], List[List[str]], List[data.Sentence]]):
         if isinstance(sentence, str):
-            return self.predict_json({"sentence": sentence})
-        elif isinstance(sentence, list):
+            sentence = self.dataset_reader.tokenizer.segment(sentence)
+
+        if isinstance(sentence, list):
             if len(sentence) == 0:
                 return []
             example = sentence[0]
-- 
GitLab