diff --git a/combo/data/tokenizers/lambo_tokenizer.py b/combo/data/tokenizers/lambo_tokenizer.py
index f37b07640e8cf7f22add0cc56ba2a9d049c3d201..fc3bbb4639052d7b8fb132456339b64da7738dc8 100644
--- a/combo/data/tokenizers/lambo_tokenizer.py
+++ b/combo/data/tokenizers/lambo_tokenizer.py
@@ -6,6 +6,7 @@ from combo.config import Registry
 from combo.config.from_parameters import register_arguments
 from combo.data.tokenizers.token import Token
 from combo.data.tokenizers.tokenizer import Tokenizer
+from combo.data.api import Sentence
 
 
 @Registry.register('lambo_tokenizer')
diff --git a/combo/main.py b/combo/main.py
index 5174a7491069ec8f83bc78b201554600c6139210..1ad700c4d8382cf79ff0ea8562565e6c621ed673 100755
--- a/combo/main.py
+++ b/combo/main.py
@@ -20,6 +20,7 @@ from combo.modules.archival import load_archive, archive
 from combo.predict import COMBO
 from combo.data import api
 from config import override_parameters
+from data import LamboTokenizer, Sentence
 from utils import ConfigurationError
 
 logging.setLoggerClass(ComboLogger)
@@ -97,6 +98,10 @@ flags.DEFINE_boolean(name="silent", default=True,
                      help="Silent prediction to file (without printing to console).")
 flags.DEFINE_boolean(name="finetuning", default=False,
                      help="Finetuning mode for training.")
+flags.DEFINE_string(name="tokenizer_language", default="English",
+                    help="Tokenizer language (used to segment plain text input).")
+flags.DEFINE_boolean(name="conllu_format", default=True,
+                     help="Input file is in CoNLL-U format (if not, it is segmented as plain text).")
 flags.DEFINE_enum(name="predictor_name", default="combo-lambo",
                   enum_values=["combo", "combo-spacy", "combo-lambo"],
                   help="Use predictor with whitespace, spacy or lambo (recommended) tokenizer.")
@@ -267,12 +270,23 @@ def run(_):
     elif FLAGS.output_file:
         checks.file_exists(FLAGS.input_file)
         logger.info("Predicting examples from file", prefix=prefix)
-        test_trees = dataset_reader.read(FLAGS.input_file)
-        predictor = COMBO(model, dataset_reader)
-        with open(FLAGS.output_file, "w") as file:
-            for tree in tqdm(test_trees):
-                file.writelines(api.sentence2conllu(predictor.predict_instance(tree),
-                                                    keep_semrel=dataset_reader.use_sem).serialize())
+        predictor = COMBO(model, dataset_reader)
+
+        if FLAGS.conllu_format:
+            test_trees = dataset_reader.read(FLAGS.input_file)
+            with open(FLAGS.output_file, "w") as file:
+                for tree in tqdm(test_trees):
+                    file.writelines(api.sentence2conllu(predictor.predict_instance(tree),
+                                                        keep_semrel=dataset_reader.use_sem).serialize())
+        else:
+            # Plain text input: segment it with LAMBO, then predict sentence by sentence.
+            tokenizer = LamboTokenizer(FLAGS.tokenizer_language)
+            with open(FLAGS.input_file, "r") as file:
+                input_sentences = tokenizer.segment(file.read())
+            with open(FLAGS.output_file, "w") as file:
+                for sentence in tqdm(input_sentences):
+                    file.writelines(api.sentence2conllu(predictor.predict(' '.join(sentence)),
+                                                        keep_semrel=dataset_reader.use_sem).serialize())
     else:
         msg = 'No output file for input file {input_file} specified.'.format(input_file=FLAGS.input_file)
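
A quick illustration of the new plain-text path: the else-branch above assumes LamboTokenizer.segment(text) returns sentences as lists of token strings, which is what ' '.join(sentence) implies. Under that assumption it reduces to the minimal sketch below; the input file name and the "English" language value are placeholders, and the import assumes combo/data/__init__.py re-exports LamboTokenizer (main.py imports it via "from data import ...").

# Minimal sketch of the plain-text segmentation step, not part of the patch.
# Assumes segment() returns a list of sentences, each a list of token strings.
from combo.data import LamboTokenizer

tokenizer = LamboTokenizer("English")        # value of --tokenizer_language
with open("input.txt") as f:                 # placeholder input file
    sentences = tokenizer.segment(f.read())

for sentence in sentences:
    # The same space-joined string that main.py feeds to predictor.predict().
    print(" ".join(sentence))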