From ce2d7be688db2ce28be47b734238671467ab5e28 Mon Sep 17 00:00:00 2001 From: Mateusz Klimaszewski <mk.klimaszewski@gmail.com> Date: Tue, 11 May 2021 18:16:03 +0200 Subject: [PATCH] Add conllu-quick-fix call. Change model for Tamil. --- scripts/evaluate_iwpt21.py | 2 +- scripts/predict_iwpt21.py | 22 ++++++++++++++++++---- scripts/train_iwpt21.py | 4 ++-- scripts/utils.py | 11 +++++++++-- 4 files changed, 30 insertions(+), 9 deletions(-) diff --git a/scripts/evaluate_iwpt21.py b/scripts/evaluate_iwpt21.py index b67541f..bfd24eb 100644 --- a/scripts/evaluate_iwpt21.py +++ b/scripts/evaluate_iwpt21.py @@ -70,7 +70,7 @@ def run(_): utils.execute_command(command) output_collapsed = utils.path_to_str(output_pred).replace('.conllu', '.collapsed.conllu') - utils.collapse_nodes(pathlib.Path(FLAGS.data_dir), output_pred, output_collapsed) + utils.collapse_nodes(pathlib.Path(FLAGS.data_dir) / 'tools', output_pred, output_collapsed) command = f"""python {FLAGS.evaluate_script_path} -v {test_file} diff --git a/scripts/predict_iwpt21.py b/scripts/predict_iwpt21.py index dff3594..513ce3b 100644 --- a/scripts/predict_iwpt21.py +++ b/scripts/predict_iwpt21.py @@ -27,9 +27,11 @@ CODE2LANG = { FLAGS = flags.FLAGS flags.DEFINE_string(name="data_dir", default="", - help="Path to IWPT'21 data directory.") + help="Path to data directory.") flags.DEFINE_string(name="models_dir", default="/tmp/", help="Model serialization dir.") +flags.DEFINE_string(name="tools", default="", + help="UD tools path.") flags.DEFINE_integer(name="cuda_device", default=-1, help="Cuda device id (-1 for cpu).") flags.DEFINE_boolean(name="expect_prefix", default=True, @@ -51,9 +53,15 @@ def run(_): data_dir = pathlib.Path(FLAGS.data_dir) files = list(data_dir.iterdir()) - test_file = [f for f in files if f"{lang}.conllu" == f.name] - assert len(test_file) == 1, f"Couldn't find test file." - test_file = test_file[0] + test_file = [f for f in files if f"{lang}.mwt.conllu" == f.name] + # Try to use mwt file if it exists + if test_file: + assert len(test_file) == 1, f"Should be exactly one {lang}.mwt.conllu file." + test_file = test_file[0] + else: + test_file = [f for f in files if f"{lang}.conllu" == f.name] + assert len(test_file) == 1, f"Couldn't find test file." + test_file = test_file[0] output_pred = data_dir / f'{lang}_pred.conllu' command = f"""combo --mode predict --model_path {model_dir / 'model.tar.gz'} @@ -64,6 +72,12 @@ def run(_): """ utils.execute_command(command) + output_fixed = utils.path_to_str(output_pred).replace('.conllu', '.fixed.conllu') + utils.quick_fix(pathlib.Path(FLAGS.tools), output_pred, output_fixed) + + output_collapsed = output_fixed.replace('.fixed.conllu', '.collapsed.conllu') + utils.collapse_nodes(pathlib.Path(FLAGS.tools), pathlib.Path(output_fixed), output_collapsed) + def main(): app.run(run) diff --git a/scripts/train_iwpt21.py b/scripts/train_iwpt21.py index 39054e4..e1e427f 100644 --- a/scripts/train_iwpt21.py +++ b/scripts/train_iwpt21.py @@ -73,10 +73,10 @@ def run(_): if "conllu" in name and "fixed" not in name: output = utils.path_to_str(treebank_file).replace('.conllu', '.fixed.conllu') if "train" in name: - utils.collapse_nodes(data_dir, treebank_file, output) + utils.collapse_nodes(data_dir / 'tools', treebank_file, output) train_paths.append(output) elif "dev" in name: - utils.collapse_nodes(data_dir, treebank_file, output) + utils.collapse_nodes(data_dir / 'tools', treebank_file, output) dev_paths.append(output) # elif "test" in name: # collapse_nodes(data_dir, treebank_file, output) diff --git a/scripts/utils.py b/scripts/utils.py index 19808ad..2c66205 100644 --- a/scripts/utils.py +++ b/scripts/utils.py @@ -19,7 +19,7 @@ LANG2TRANSFORMER = { "ru": "blinoff/roberta-base-russian-v0", "sv": "KB/bert-base-swedish-cased", "uk": "/tmp/lustre_shared/mklimasz/transformers/wikibert-base-uk-cased/", - "ta": "/tmp/lustre_shared/mklimasz/transformers/wikibert-base-ta-cased/", + "ta": "xlm-roberta-large", "sk": "/tmp/lustre_shared/mklimasz/transformers/wikibert-base-sk-cased/", "lt": "/tmp/lustre_shared/mklimasz/transformers/wikibert-base-lt-cased/", "lv": "/tmp/lustre_shared/mklimasz/transformers/wikibert-base-lv-cased/", @@ -51,5 +51,12 @@ def path_to_str(path: pathlib.Path) -> str: def collapse_nodes(data_dir: pathlib.Path, treebank_file: pathlib.Path, output: str): output_path = pathlib.Path(output) if not output_path.exists(): - execute_command(f"perl {path_to_str(data_dir / 'tools' / 'enhanced_collapse_empty_nodes.pl')} " + execute_command(f"perl {path_to_str(data_dir / 'enhanced_collapse_empty_nodes.pl')} " + f"{path_to_str(treebank_file)}", output) + + +def quick_fix(data_dir: pathlib.Path, treebank_file: pathlib.Path, output: str): + output_path = pathlib.Path(output) + if not output_path.exists(): + execute_command(f"perl {path_to_str(data_dir / 'conllu-quick-fix.pl')} " f"{path_to_str(treebank_file)}", output) -- GitLab