Skip to content
Snippets Groups Projects
Commit ea61ff93 authored by Maja Jablonska
Browse files

Some prediction fixes

parent cea27a0f
No related merge requests found
Pipeline #16696 passed with stage in 31 seconds
...@@ -85,8 +85,8 @@ flags.DEFINE_string(name="config_path", default="", ...@@ -85,8 +85,8 @@ flags.DEFINE_string(name="config_path", default="",
help="Config file path.") help="Config file path.")
flags.DEFINE_list(name="datasets_for_vocabulary", default=["train"], flags.DEFINE_list(name="datasets_for_vocabulary", default=["train"],
help="") help="")
flags.DEFINE_boolean(name="turns", default=False, flags.DEFINE_enum(name="split_level", default="sentence", enum_values=["none", "turn", "sentence"],
help="Segment into sentences on sentence break or on turn break.") help="Don\'t segment, or segment into sentences on sentence break or on turn break.")
flags.DEFINE_boolean(name="split_multiwords", default=False, flags.DEFINE_boolean(name="split_multiwords", default=False,
help="Split subwords (e.g. don\'t = do, n\'t) into separate tokens.") help="Split subwords (e.g. don\'t = do, n\'t) into separate tokens.")
flags.DEFINE_boolean(name="transformer_encoder", default=False, help="Use transformer encoder.") flags.DEFINE_boolean(name="transformer_encoder", default=False, help="Use transformer encoder.")
...@@ -424,12 +424,13 @@ def run(_): ...@@ -424,12 +424,13 @@ def run(_):
prefix=prefix) prefix=prefix)
dataset_reader = default_ud_dataset_reader(FLAGS.pretrained_transformer_name, dataset_reader = default_ud_dataset_reader(FLAGS.pretrained_transformer_name,
tokenizer=LamboTokenizer(tokenizer_language, tokenizer=LamboTokenizer(tokenizer_language,
default_split_level="TURNS" if FLAGS.turns else "SENTENCES", default_split_level=FLAGS.split_level,
default_split_multiwords=FLAGS.split_multiwords) default_split_multiwords=FLAGS.split_multiwords)
) )
predictor = COMBO(model, dataset_reader) predictor = COMBO(model, dataset_reader)
if FLAGS.input_file == '-': if FLAGS.input_file == '-':
print("Interactive mode.") print("Interactive mode.")
sentence = input("Sentence: ") sentence = input("Sentence: ")
...@@ -465,9 +466,9 @@ def run(_): ...@@ -465,9 +466,9 @@ def run(_):
else: else:
tokenizer = LamboTokenizer(tokenizer_language) tokenizer = LamboTokenizer(tokenizer_language)
with open(FLAGS.input_file, "r", encoding='utf-8') as file: with open(FLAGS.input_file, "r", encoding='utf-8') as file:
input_sentences = tokenizer.segment(file.read(), input_sentences = tokenizer.tokenize(file.read(),
turns=FLAGS.turns, split_level=FLAGS.split_level.upper(),
split_multiwords=FLAGS.split_multiwords) split_multiwords=FLAGS.split_multiwords)
predictions = predictor.predict(input_sentences) predictions = predictor.predict(input_sentences)
with open(FLAGS.output_file, "w") as file: with open(FLAGS.output_file, "w") as file:
for prediction in tqdm(predictions): for prediction in tqdm(predictions):
......
...@@ -3,7 +3,7 @@ requires = ["setuptools"] ...@@ -3,7 +3,7 @@ requires = ["setuptools"]
[project] [project]
name = "combo" name = "combo"
version = "3.1.5" version = "3.2.0"
authors = [ authors = [
{name = "Maja Jablonska", email = "maja.jablonska@ipipan.waw.pl"} {name = "Maja Jablonska", email = "maja.jablonska@ipipan.waw.pl"}
] ]
......
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment