diff --git a/scripts/evaluate_iwpt21.py b/scripts/evaluate_iwpt21.py index b67541f108c7bd714acefdbc512f91e14727edf4..bfd24eb908178dbb32c1027054b514b4055e327e 100644 --- a/scripts/evaluate_iwpt21.py +++ b/scripts/evaluate_iwpt21.py @@ -70,7 +70,7 @@ def run(_): utils.execute_command(command) output_collapsed = utils.path_to_str(output_pred).replace('.conllu', '.collapsed.conllu') - utils.collapse_nodes(pathlib.Path(FLAGS.data_dir), output_pred, output_collapsed) + utils.collapse_nodes(pathlib.Path(FLAGS.data_dir) / 'tools', output_pred, output_collapsed) command = f"""python {FLAGS.evaluate_script_path} -v {test_file} diff --git a/scripts/predict_iwpt21.py b/scripts/predict_iwpt21.py index dff359470d6b9f74138f876bdd339009fbfa7f57..513ce3b7c1c97265c2a1ebebb2beb40a7a931905 100644 --- a/scripts/predict_iwpt21.py +++ b/scripts/predict_iwpt21.py @@ -27,9 +27,11 @@ CODE2LANG = { FLAGS = flags.FLAGS flags.DEFINE_string(name="data_dir", default="", - help="Path to IWPT'21 data directory.") + help="Path to data directory.") flags.DEFINE_string(name="models_dir", default="/tmp/", help="Model serialization dir.") +flags.DEFINE_string(name="tools", default="", + help="UD tools path.") flags.DEFINE_integer(name="cuda_device", default=-1, help="Cuda device id (-1 for cpu).") flags.DEFINE_boolean(name="expect_prefix", default=True, @@ -51,9 +53,15 @@ def run(_): data_dir = pathlib.Path(FLAGS.data_dir) files = list(data_dir.iterdir()) - test_file = [f for f in files if f"{lang}.conllu" == f.name] - assert len(test_file) == 1, f"Couldn't find test file." - test_file = test_file[0] + test_file = [f for f in files if f"{lang}.mwt.conllu" == f.name] + # Try to use mwt file if it exists + if test_file: + assert len(test_file) == 1, f"Should be exactly one {lang}.mwt.conllu file." + test_file = test_file[0] + else: + test_file = [f for f in files if f"{lang}.conllu" == f.name] + assert len(test_file) == 1, f"Couldn't find test file." + test_file = test_file[0] output_pred = data_dir / f'{lang}_pred.conllu' command = f"""combo --mode predict --model_path {model_dir / 'model.tar.gz'} @@ -64,6 +72,12 @@ def run(_): """ utils.execute_command(command) + output_fixed = utils.path_to_str(output_pred).replace('.conllu', '.fixed.conllu') + utils.quick_fix(pathlib.Path(FLAGS.tools), output_pred, output_fixed) + + output_collapsed = output_fixed.replace('.fixed.conllu', '.collapsed.conllu') + utils.collapse_nodes(pathlib.Path(FLAGS.tools), pathlib.Path(output_fixed), output_collapsed) + def main(): app.run(run) diff --git a/scripts/train_iwpt21.py b/scripts/train_iwpt21.py index 39054e4243783565086ada8b73f1ec173b2fce7d..e1e427fffcc2eaa34ba11989d365894c42c1f191 100644 --- a/scripts/train_iwpt21.py +++ b/scripts/train_iwpt21.py @@ -73,10 +73,10 @@ def run(_): if "conllu" in name and "fixed" not in name: output = utils.path_to_str(treebank_file).replace('.conllu', '.fixed.conllu') if "train" in name: - utils.collapse_nodes(data_dir, treebank_file, output) + utils.collapse_nodes(data_dir / 'tools', treebank_file, output) train_paths.append(output) elif "dev" in name: - utils.collapse_nodes(data_dir, treebank_file, output) + utils.collapse_nodes(data_dir / 'tools', treebank_file, output) dev_paths.append(output) # elif "test" in name: # collapse_nodes(data_dir, treebank_file, output) diff --git a/scripts/utils.py b/scripts/utils.py index 19808ad30f09017d58fc0ce0b3c78cc88c7e8b21..2c6620597a9e1314d261d96b5814169783cb86ce 100644 --- a/scripts/utils.py +++ b/scripts/utils.py @@ -19,7 +19,7 @@ LANG2TRANSFORMER = { "ru": "blinoff/roberta-base-russian-v0", "sv": "KB/bert-base-swedish-cased", "uk": "/tmp/lustre_shared/mklimasz/transformers/wikibert-base-uk-cased/", - "ta": "/tmp/lustre_shared/mklimasz/transformers/wikibert-base-ta-cased/", + "ta": "xlm-roberta-large", "sk": "/tmp/lustre_shared/mklimasz/transformers/wikibert-base-sk-cased/", "lt": "/tmp/lustre_shared/mklimasz/transformers/wikibert-base-lt-cased/", "lv": "/tmp/lustre_shared/mklimasz/transformers/wikibert-base-lv-cased/", @@ -51,5 +51,12 @@ def path_to_str(path: pathlib.Path) -> str: def collapse_nodes(data_dir: pathlib.Path, treebank_file: pathlib.Path, output: str): output_path = pathlib.Path(output) if not output_path.exists(): - execute_command(f"perl {path_to_str(data_dir / 'tools' / 'enhanced_collapse_empty_nodes.pl')} " + execute_command(f"perl {path_to_str(data_dir / 'enhanced_collapse_empty_nodes.pl')} " + f"{path_to_str(treebank_file)}", output) + + +def quick_fix(data_dir: pathlib.Path, treebank_file: pathlib.Path, output: str): + output_path = pathlib.Path(output) + if not output_path.exists(): + execute_command(f"perl {path_to_str(data_dir / 'conllu-quick-fix.pl')} " f"{path_to_str(treebank_file)}", output)