From ea61ff93993f1b47ed9482cc2622f293da8164e6 Mon Sep 17 00:00:00 2001
From: Maja Jablonska <majajjablonska@gmail.com>
Date: Mon, 19 Feb 2024 21:23:41 +1100
Subject: [PATCH] Some prediction fixes

---
 combo/main.py  | 13 +++++++------
 pyproject.toml |  2 +-
 2 files changed, 8 insertions(+), 7 deletions(-)

diff --git a/combo/main.py b/combo/main.py
index 49a0fc7..b813074 100755
--- a/combo/main.py
+++ b/combo/main.py
@@ -85,8 +85,8 @@ flags.DEFINE_string(name="config_path", default="",
                     help="Config file path.")
 flags.DEFINE_list(name="datasets_for_vocabulary", default=["train"],
                   help="")
-flags.DEFINE_boolean(name="turns", default=False,
-                     help="Segment into sentences on sentence break or on turn break.")
+flags.DEFINE_enum(name="split_level", default="sentence", enum_values=["none", "turn", "sentence"],
+                  help="Don\'t segment, or segment into sentences on sentence break or on turn break.")
 flags.DEFINE_boolean(name="split_multiwords", default=False,
                      help="Split subwords (e.g. don\'t = do, n\'t) into separate tokens.")
 flags.DEFINE_boolean(name="transformer_encoder", default=False, help="Use transformer encoder.")
@@ -424,12 +424,13 @@ def run(_):
                         prefix=prefix)
             dataset_reader = default_ud_dataset_reader(FLAGS.pretrained_transformer_name,
                                                         tokenizer=LamboTokenizer(tokenizer_language,
-                                                        default_split_level="TURNS" if FLAGS.turns else "SENTENCES",
+                                                        default_split_level=FLAGS.split_level,
                                                         default_split_multiwords=FLAGS.split_multiwords)
                                                        )
 
         predictor = COMBO(model, dataset_reader)
 
+
         if FLAGS.input_file == '-':
             print("Interactive mode.")
             sentence = input("Sentence: ")
@@ -465,9 +466,9 @@ def run(_):
             else:
                 tokenizer = LamboTokenizer(tokenizer_language)
                 with open(FLAGS.input_file, "r", encoding='utf-8') as file:
-                    input_sentences = tokenizer.segment(file.read(),
-                                                        turns=FLAGS.turns,
-                                                        split_multiwords=FLAGS.split_multiwords)
+                    input_sentences = tokenizer.tokenize(file.read(),
+                                                         split_level=FLAGS.split_level.upper(),
+                                                         split_multiwords=FLAGS.split_multiwords)
                 predictions = predictor.predict(input_sentences)
                 with open(FLAGS.output_file, "w") as file:
                     for prediction in tqdm(predictions):
diff --git a/pyproject.toml b/pyproject.toml
index 457b9d8..da51df8 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -3,7 +3,7 @@ requires = ["setuptools"]
 
 [project]
 name = "combo"
-version = "3.1.5"
+version = "3.2.0"
 authors = [
     {name = "Maja Jablonska", email = "maja.jablonska@ipipan.waw.pl"}
 ]
-- 
GitLab