# Provenance (from the patch this script was delivered in):
#   commit:  c082b65055c8c728cf4279a7c19e9d0409c3000c
#   author:  Mateusz Klimaszewski <mk.klimaszewski@gmail.com>
#   date:    Fri, 7 May 2021 14:45:53 +0200
#   subject: Add prediction script for IWPT'21.
#   path:    scripts/predict_iwpt21.py
"""Run ``combo --mode predict`` for every language model found in a directory.

For each sub-directory of ``--models_dir`` whose name is a key of
``CODE2LANG``, the script looks up ``<code>.conllu`` in ``--data_dir`` and
writes the prediction to ``<code>_pred.conllu`` in the same directory.
"""
import pathlib

from absl import app
from absl import flags

from scripts import utils

# Language codes that are processed; any other entry in the models directory
# is skipped. Only the keys are consulted by this script.
CODE2LANG = {
    "ar": "Arabic",
    "bg": "Bulgarian",
    "cs": "Czech",
    "nl": "Dutch",
    "en": "English",
    "et": "Estonian",
    "fi": "Finnish",
    "fr": "French",
    "it": "Italian",
    "lv": "Latvian",
    "lt": "Lithuanian",
    "pl": "Polish",
    "ru": "Russian",
    "sk": "Slovak",
    "sv": "Swedish",
    "ta": "Tamil",
    "uk": "Ukrainian",
}

FLAGS = flags.FLAGS
flags.DEFINE_string(name="data_dir", default="",
                    help="Path to IWPT'21 data directory.")
flags.DEFINE_string(name="models_dir", default="/tmp/",
                    help="Model serialization dir.")
flags.DEFINE_integer(name="cuda_device", default=-1,
                     help="Cuda device id (-1 for cpu).")
flags.DEFINE_boolean(name="expect_prefix", default=True,
                     help="Whether to expect allennlp prefix.")


def _resolve_model_dir(lang_dir, expect_prefix):
    """Return the directory expected to contain ``model.tar.gz``.

    Args:
        lang_dir: ``pathlib.Path`` of the per-language directory.
        expect_prefix: When true, ``lang_dir`` must contain exactly one entry
            and that entry is returned; when false, ``lang_dir`` itself is
            returned.

    Raises:
        ValueError: If ``expect_prefix`` is true and ``lang_dir`` does not
            contain exactly one entry.
    """
    if not expect_prefix:
        return lang_dir

    entries = list(lang_dir.iterdir())
    # Explicit raise instead of ``assert`` so the check survives ``python -O``.
    if len(entries) != 1:
        raise ValueError(f"There is incorrect count of models {entries}")
    return entries[0]


def _find_test_file(data_dir, lang):
    """Return the entry of ``data_dir`` named exactly ``<lang>.conllu``.

    Args:
        data_dir: ``pathlib.Path`` of the directory to search (not recursive).
        lang: Language code used to build the expected file name.

    Raises:
        FileNotFoundError: If no such entry exists.
    """
    wanted = f"{lang}.conllu"
    matches = [entry for entry in data_dir.iterdir() if entry.name == wanted]
    # Explicit raise instead of ``assert`` so the check survives ``python -O``.
    if len(matches) != 1:
        raise FileNotFoundError(
            f"Couldn't find test file {wanted} in {data_dir}."
        )
    return matches[0]


def run(_):
    """Predict every known language found under ``--models_dir``.

    Args:
        _: Unused; the positional argument ``app.run`` hands to its callback.

    Raises:
        ValueError: If ``--expect_prefix`` is set and a language directory
            does not contain exactly one entry.
        FileNotFoundError: If the test file for a language is missing from
            ``--data_dir``.
    """
    models_dir = pathlib.Path(FLAGS.models_dir)
    # Loop-invariant, so built once rather than per language.
    data_dir = pathlib.Path(FLAGS.data_dir)

    for lang_dir in models_dir.iterdir():
        lang = lang_dir.name
        if lang not in CODE2LANG:
            print("Skipping unknown directory: ", lang)
            continue

        model_dir = _resolve_model_dir(lang_dir, FLAGS.expect_prefix)
        test_file = _find_test_file(data_dir, lang)

        output_pred = data_dir / f'{lang}_pred.conllu'
        # NOTE(review): the command is one multi-line string; presumably
        # utils.execute_command tokenises it on whitespace, in which case
        # paths containing spaces would break -- confirm against scripts.utils.
        command = f"""combo --mode predict --model_path {model_dir / 'model.tar.gz'}
        --input_file {test_file}
        --output_file {output_pred}
        --cuda_device {FLAGS.cuda_device}
        --silent
        """
        utils.execute_command(command)


def main():
    """Script entry point: hand ``run`` to ``app.run``."""
    app.run(run)


if __name__ == "__main__":
    main()