From ad53db1209f3351850030e42efd3bf019f98184f Mon Sep 17 00:00:00 2001
From: Maja Jablonska <majajjablonska@gmail.com>
Date: Thu, 16 Nov 2023 22:00:23 +1100
Subject: [PATCH] Change default tokenizer to LAMBO

---
 combo/data/tokenizers/lambo_tokenizer.py |  1 +
 combo/main.py                            | 24 ++++++++++++++++++------
 2 files changed, 19 insertions(+), 6 deletions(-)

diff --git a/combo/data/tokenizers/lambo_tokenizer.py b/combo/data/tokenizers/lambo_tokenizer.py
index f37b076..fc3bbb4 100644
--- a/combo/data/tokenizers/lambo_tokenizer.py
+++ b/combo/data/tokenizers/lambo_tokenizer.py
@@ -6,6 +6,7 @@ from combo.config import Registry
 from combo.config.from_parameters import register_arguments
 from combo.data.tokenizers.token import Token
 from combo.data.tokenizers.tokenizer import Tokenizer
+from combo.data.api import Sentence
 
 
 @Registry.register('lambo_tokenizer')
diff --git a/combo/main.py b/combo/main.py
index 5174a74..1ad700c 100755
--- a/combo/main.py
+++ b/combo/main.py
@@ -20,6 +20,7 @@ from combo.modules.archival import load_archive, archive
 from combo.predict import COMBO
 from combo.data import api
 from config import override_parameters
+from data import LamboTokenizer, Sentence
 from utils import ConfigurationError
 
 logging.setLoggerClass(ComboLogger)
@@ -97,6 +98,8 @@ flags.DEFINE_boolean(name="silent", default=True,
                      help="Silent prediction to file (without printing to console).")
 flags.DEFINE_boolean(name="finetuning", default=False,
                      help="Finetuning mode for training.")
+flags.DEFINE_string(name="tokenizer_language", default="English",
+                    help="Tokenizer language.")
 flags.DEFINE_enum(name="predictor_name", default="combo-lambo",
                   enum_values=["combo", "combo-spacy", "combo-lambo"],
                   help="Use predictor with whitespace, spacy or lambo (recommended) tokenizer.")
@@ -267,12 +270,21 @@ def run(_):
         elif FLAGS.output_file:
             checks.file_exists(FLAGS.input_file)
             logger.info("Predicting examples from file", prefix=prefix)
-            test_trees = dataset_reader.read(FLAGS.input_file)
-            predictor = COMBO(model, dataset_reader)
-            with open(FLAGS.output_file, "w") as file:
-                for tree in tqdm(test_trees):
-                    file.writelines(api.sentence2conllu(predictor.predict_instance(tree),
-                                                        keep_semrel=dataset_reader.use_sem).serialize())
+
+            if FLAGS.conllu_format:
+                test_trees = dataset_reader.read(FLAGS.input_file)
+                with open(FLAGS.output_file, "w") as file:
+                    for tree in tqdm(test_trees):
+                        file.writelines(api.sentence2conllu(predictor.predict_instance(tree),
+                                                            keep_semrel=dataset_reader.use_sem).serialize())
+            else:
+                tokenizer = LamboTokenizer(FLAGS.tokenizer_language)
+                with open(FLAGS.input_file, "r") as file:
+                    input_sentences = tokenizer.segment(file.read())
+                with open(FLAGS.output_file, "w") as file:
+                    for sentence in tqdm(input_sentences):
+                        file.writelines(api.sentence2conllu(predictor.predict(' '.join(sentence)),
+                                                            keep_semrel=dataset_reader.use_sem).serialize())
 
         else:
             msg = 'No output file for input file {input_file} specified.'.format(input_file=FLAGS.input_file)
-- 
GitLab