From ea61ff93993f1b47ed9482cc2622f293da8164e6 Mon Sep 17 00:00:00 2001 From: Maja Jablonska <majajjablonska@gmail.com> Date: Mon, 19 Feb 2024 21:23:41 +1100 Subject: [PATCH] Some prediction fixes --- combo/main.py | 13 +++++++------ pyproject.toml | 2 +- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/combo/main.py b/combo/main.py index 49a0fc7..b813074 100755 --- a/combo/main.py +++ b/combo/main.py @@ -85,8 +85,8 @@ flags.DEFINE_string(name="config_path", default="", help="Config file path.") flags.DEFINE_list(name="datasets_for_vocabulary", default=["train"], help="") -flags.DEFINE_boolean(name="turns", default=False, - help="Segment into sentences on sentence break or on turn break.") +flags.DEFINE_enum(name="split_level", default="sentence", enum_values=["none", "turn", "sentence"], + help="Don\'t segment, or segment into sentences on sentence break or on turn break.") flags.DEFINE_boolean(name="split_multiwords", default=False, help="Split subwords (e.g. don\'t = do, n\'t) into separate tokens.") flags.DEFINE_boolean(name="transformer_encoder", default=False, help="Use transformer encoder.") @@ -424,12 +424,13 @@ def run(_): prefix=prefix) dataset_reader = default_ud_dataset_reader(FLAGS.pretrained_transformer_name, tokenizer=LamboTokenizer(tokenizer_language, - default_split_level="TURNS" if FLAGS.turns else "SENTENCES", + default_split_level=FLAGS.split_level, default_split_multiwords=FLAGS.split_multiwords) ) predictor = COMBO(model, dataset_reader) + if FLAGS.input_file == '-': print("Interactive mode.") sentence = input("Sentence: ") @@ -465,9 +466,9 @@ def run(_): else: tokenizer = LamboTokenizer(tokenizer_language) with open(FLAGS.input_file, "r", encoding='utf-8') as file: - input_sentences = tokenizer.segment(file.read(), - turns=FLAGS.turns, - split_multiwords=FLAGS.split_multiwords) + input_sentences = tokenizer.tokenize(file.read(), + split_level=FLAGS.split_level.upper(), + split_multiwords=FLAGS.split_multiwords) predictions = predictor.predict(input_sentences) with open(FLAGS.output_file, "w") as file: for prediction in tqdm(predictions): diff --git a/pyproject.toml b/pyproject.toml index 457b9d8..da51df8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,7 +3,7 @@ requires = ["setuptools"] [project] name = "combo" -version = "3.1.5" +version = "3.2.0" authors = [ {name = "Maja Jablonska", email = "maja.jablonska@ipipan.waw.pl"} ] -- GitLab