From ce0362528a0de5eea1157ba242dc024ca30f6b27 Mon Sep 17 00:00:00 2001
From: Mateusz Klimaszewski <mk.klimaszewski@gmail.com>
Date: Thu, 6 May 2021 10:57:09 +0200
Subject: [PATCH] Add IWPT'21 evaluation script.

---
 scripts/evaluate_iwpt21.py | 87 ++++++++++++++++++++++++++++++++++++++
 scripts/train_iwpt21.py    | 21 +++------
 scripts/utils.py           | 12 ++++++
 3 files changed, 104 insertions(+), 16 deletions(-)
 create mode 100644 scripts/evaluate_iwpt21.py

diff --git a/scripts/evaluate_iwpt21.py b/scripts/evaluate_iwpt21.py
new file mode 100644
index 0000000..b67541f
--- /dev/null
+++ b/scripts/evaluate_iwpt21.py
@@ -0,0 +1,87 @@
+import pathlib
+
+from absl import app
+from absl import flags
+
+from scripts import utils
+
+CODE2LANG = {
+    "ar": "Arabic",
+    "bg": "Bulgarian",
+    "cs": "Czech",
+    "nl": "Dutch",
+    "en": "English",
+    "et": "Estonian",
+    "fi": "Finnish",
+    "fr": "French",
+    "it": "Italian",
+    "lv": "Latvian",
+    "lt": "Lithuanian",
+    "pl": "Polish",
+    "ru": "Russian",
+    "sk": "Slovak",
+    "sv": "Swedish",
+    "ta": "Tamil",
+    "uk": "Ukrainian",
+}
+
+FLAGS = flags.FLAGS
+flags.DEFINE_string(name="data_dir", default="",
+                    help="Path to IWPT'21 data directory.")
+flags.DEFINE_string(name="models_dir", default="/tmp/",
+                    help="Model serialization dir.")
+flags.DEFINE_integer(name="cuda_device", default=-1,
+                     help="Cuda device id (-1 for cpu).")
+flags.DEFINE_string(name="evaluate_script_path", default="iwpt21_xud_eval.py",
+                    help="Path to 'iwpt21_xud_eval.py' eval script.")
+flags.DEFINE_boolean(name="expect_prefix", default=True,
+                     help="Whether to expect allennlp prefix.")
+
+
+def run(_):
+    models_dir = pathlib.Path(FLAGS.models_dir)
+    for model_dir in models_dir.iterdir():
+        if model_dir.name not in CODE2LANG:
+            print("Skipping unknown directory: ", model_dir.name)
+            continue
+
+        treebank_name = f"UD_{CODE2LANG[model_dir.name]}-IWPT"
+
+        if FLAGS.expect_prefix:
+            model_dir = list(model_dir.iterdir())
+            assert len(model_dir) == 1, f"There is incorrect count of models {model_dir}"
+            model_dir = model_dir[0]
+
+        treebank_dir = pathlib.Path(FLAGS.data_dir) / treebank_name
+        files = list(treebank_dir.iterdir())
+
+        test_file = [f for f in files if "dev" in f.name and ".conllu" in f.name]
+        assert len(test_file) == 1, f"Couldn't find test file."
+        test_file = test_file[0]
+
+        if not (model_dir / "results.txt").exists():
+            output_pred = model_dir / 'predictions.conllu'
+            command = f"""combo --mode predict --model_path {model_dir / 'model.tar.gz'}
+            --input_file {test_file}
+            --output_file {output_pred}
+            --cuda_device {FLAGS.cuda_device}
+            --silent
+            """
+            utils.execute_command(command)
+
+            output_collapsed = utils.path_to_str(output_pred).replace('.conllu', '.collapsed.conllu')
+            utils.collapse_nodes(pathlib.Path(FLAGS.data_dir), output_pred, output_collapsed)
+
+            command = f"""python {FLAGS.evaluate_script_path} -v
+            {test_file}
+            {output_collapsed}
+            """
+            utils.execute_command(command, output_file=model_dir / "results.txt")
+
+
+def main():
+    app.run(run)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/train_iwpt21.py b/scripts/train_iwpt21.py
index 17737c9..e4705f7 100644
--- a/scripts/train_iwpt21.py
+++ b/scripts/train_iwpt21.py
@@ -43,22 +43,11 @@ flags.DEFINE_integer(name="cuda_device", default=-1,
                      help="Cuda device id (-1 for cpu).")
 
 
-def path_to_str(path: pathlib.Path) -> str:
-    return str(path.resolve())
-
-
 def merge_files(files: List[str], output: pathlib.Path):
     if not output.exists():
         os.system(f"cat {' '.join(files)} > {output}")
 
 
-def collapse_nodes(data_dir: pathlib.Path, treebank_file: pathlib.Path, output: str):
-    output_path = pathlib.Path(output)
-    if not output_path.exists():
-        utils.execute_command(f"perl {path_to_str(data_dir / 'tools' / 'enhanced_collapse_empty_nodes.pl')} "
-                              f"{path_to_str(treebank_file)}", output)
-
-
 def run(_):
     languages = FLAGS.lang
     for lang in languages:
@@ -82,21 +71,21 @@ def run(_):
             for treebank_file in treebank_dir.iterdir():
                 name = treebank_file.name
                 if "conllu" in name and "fixed" not in name:
-                    output = path_to_str(treebank_file).replace('.conllu', '.fixed.conllu')
+                    output = utils.path_to_str(treebank_file).replace('.conllu', '.fixed.conllu')
                     if "train" in name:
-                        collapse_nodes(data_dir, treebank_file, output)
+                        utils.collapse_nodes(data_dir, treebank_file, output)
                         train_paths.append(output)
                     elif "dev" in name:
-                        collapse_nodes(data_dir, treebank_file, output)
+                        utils.collapse_nodes(data_dir, treebank_file, output)
                         dev_paths.append(output)
                     # elif "test" in name:
                     #     collapse_nodes(data_dir, treebank_file, output)
                     #     test_paths.append(output)
                 if ".txt" in name:
                     if "train" in name:
-                        train_raw_paths.append(path_to_str(treebank_file))
+                        train_raw_paths.append(utils.path_to_str(treebank_file))
                     elif "dev" in name:
-                        dev_raw_paths.append(path_to_str(treebank_file))
+                        dev_raw_paths.append(utils.path_to_str(treebank_file))
 
         merged_dataset_name = "IWPT"
         lang_data_dir = pathlib.Path(data_dir / f"UD_{full_language}-{merged_dataset_name}")
diff --git a/scripts/utils.py b/scripts/utils.py
index bbbe2fe..19808ad 100644
--- a/scripts/utils.py
+++ b/scripts/utils.py
@@ -1,4 +1,5 @@
 """Utils for scripts."""
+import pathlib
 import subprocess
 
 LANG2TRANSFORMER = {
@@ -41,3 +42,14 @@ def execute_command(command, output_file=None):
             subprocess.run(command, check=True, stdout=f)
     else:
         subprocess.run(command, check=True)
+
+
+def path_to_str(path: pathlib.Path) -> str:
+    return str(path.resolve())
+
+
+def collapse_nodes(data_dir: pathlib.Path, treebank_file: pathlib.Path, output: str):
+    output_path = pathlib.Path(output)
+    if not output_path.exists():
+        execute_command(f"perl {path_to_str(data_dir / 'tools' / 'enhanced_collapse_empty_nodes.pl')} "
+                        f"{path_to_str(treebank_file)}", output)
-- 
GitLab