Skip to content
Snippets Groups Projects
Commit 6900d4f0 authored by Mateusz Klimaszewski
Browse files

Add IWPT'21 evaluation script.

parent 0f6faf2a
Branches
Tags
2 merge requests: !37 Release 1.0.4., !36 Release 1.0.4
This commit is part of merge request !36. Comments created here will be created in the context of that merge request.
import pathlib
from absl import app
from absl import flags
from scripts import utils
# Two-letter language code -> language name as it appears in the
# "UD_<Language>-IWPT" treebank directory names built in run().
CODE2LANG = {
    "ar": "Arabic",
    "bg": "Bulgarian",
    "cs": "Czech",
    "nl": "Dutch",
    "en": "English",
    "et": "Estonian",
    "fi": "Finnish",
    "fr": "French",
    "it": "Italian",
    "lv": "Latvian",
    "lt": "Lithuanian",
    "pl": "Polish",
    "ru": "Russian",
    "sk": "Slovak",
    "sv": "Swedish",
    "ta": "Tamil",
    "uk": "Ukrainian",
}
# Command-line flags of the evaluation script; run() reads them via FLAGS.
FLAGS = flags.FLAGS
flags.DEFINE_string(
    name="data_dir",
    default="",
    help="Path to IWPT'21 data directory.",
)
flags.DEFINE_string(
    name="models_dir",
    default="/tmp/",
    help="Model serialization dir.",
)
flags.DEFINE_integer(
    name="cuda_device",
    default=-1,
    help="Cuda device id (-1 for cpu).",
)
flags.DEFINE_string(
    name="evaluate_script_path",
    default="iwpt21_xud_eval.py",
    help="Path to 'iwpt21_xud_eval.py' eval script.",
)
flags.DEFINE_boolean(
    name="expect_prefix",
    default=True,
    help="Whether to expect allennlp prefix.",
)
def run(_):
    """Predict and evaluate every language model found in ``FLAGS.models_dir``.

    Each entry of ``FLAGS.models_dir`` whose name is a key of ``CODE2LANG`` is
    handled as follows:

    1. When ``FLAGS.expect_prefix`` is set, the entry must contain exactly one
       child, which is then used as the model directory.
    2. The single file whose name contains both ``dev`` and ``.conllu`` is
       looked up in ``<data_dir>/UD_<Language>-IWPT``.
    3. Unless ``results.txt`` already exists in the model directory, ``combo``
       is run in predict mode, the predictions are passed through
       ``utils.collapse_nodes`` and the evaluation script's output is stored
       in ``results.txt``.

    Args:
        _: Positional argument supplied by ``app.run``; unused.

    Raises:
        ValueError: If the model directory or the dev file cannot be
            determined unambiguously.
    """
    data_dir = pathlib.Path(FLAGS.data_dir)
    for lang_dir in pathlib.Path(FLAGS.models_dir).iterdir():
        lang_code = lang_dir.name
        if lang_code not in CODE2LANG:
            print("Skipping unknown directory: ", lang_code)
            continue

        model_dir = lang_dir
        if FLAGS.expect_prefix:
            candidates = list(lang_dir.iterdir())
            # Explicit check instead of `assert`, which is removed under -O.
            if len(candidates) != 1:
                raise ValueError(f"There is incorrect count of models {candidates}")
            model_dir = candidates[0]

        treebank_dir = data_dir / f"UD_{CODE2LANG[lang_code]}-IWPT"
        dev_files = [
            f for f in treebank_dir.iterdir()
            if "dev" in f.name and ".conllu" in f.name
        ]
        if len(dev_files) != 1:
            raise ValueError(
                f"Expected exactly one dev .conllu file in {treebank_dir}, "
                f"found {len(dev_files)}: {dev_files}"
            )
        test_file = dev_files[0]

        results_path = model_dir / "results.txt"
        if results_path.exists():
            # Already evaluated in a previous run.
            continue

        output_pred = model_dir / 'predictions.conllu'
        # NOTE(review): the multi-line command relies on execute_command
        # tokenizing its argument on whitespace -- confirm in scripts/utils.py.
        command = f"""combo --mode predict --model_path {model_dir / 'model.tar.gz'}
        --input_file {test_file}
        --output_file {output_pred}
        --cuda_device {FLAGS.cuda_device}
        --silent
        """
        utils.execute_command(command)

        output_collapsed = utils.path_to_str(output_pred).replace('.conllu', '.collapsed.conllu')
        utils.collapse_nodes(data_dir, output_pred, output_collapsed)

        command = f"""python {FLAGS.evaluate_script_path} -v
        {test_file}
        {output_collapsed}
        """
        # results.txt doubles as the "done" marker, so it must only appear
        # once the evaluation succeeded. Write to a sibling file first and
        # rename afterwards; otherwise a failed evaluation would leave a
        # partial results.txt and the model would be skipped on every rerun.
        partial_results = results_path.with_name(results_path.name + ".tmp")
        try:
            utils.execute_command(command, output_file=partial_results)
        except BaseException:
            # Also covers Ctrl-C: drop the incomplete file, then propagate.
            if partial_results.exists():
                partial_results.unlink()
            raise
        partial_results.replace(results_path)
def main():
    """Script entry point: hand ``run`` over to ``app.run``."""
    app.run(run)


if __name__ == "__main__":
    main()
......@@ -43,22 +43,11 @@ flags.DEFINE_integer(name="cuda_device", default=-1,
help="Cuda device id (-1 for cpu).")
def path_to_str(path: pathlib.Path) -> str:
    """Return the resolved (absolute) form of ``path`` as a plain string."""
    resolved = path.resolve()
    return str(resolved)
def merge_files(files: List[str], output: pathlib.Path):
    """Concatenate ``files``, in order, into ``output`` unless it already exists.

    The files are copied byte-for-byte in Python rather than through a shell
    ``cat`` command line, so paths containing spaces or shell metacharacters
    work, an empty ``files`` list yields an empty output instead of blocking
    on stdin, and a missing input raises instead of being silently ignored.

    Args:
        files: Paths of the files to concatenate.
        output: Destination file; left untouched when it already exists.

    Raises:
        OSError: If an input cannot be read or the output cannot be written.
            A partially written ``output`` is removed before re-raising so
            that a later call does not mistake it for a finished merge.
    """
    import shutil  # local import: only needed by this helper

    if output.exists():
        return

    try:
        with open(output, "wb") as merged:
            for file_name in files:
                with open(file_name, "rb") as part:
                    shutil.copyfileobj(part, merged)
    except BaseException:
        # Also covers Ctrl-C: never leave a truncated merge behind.
        if output.exists():
            output.unlink()
        raise
def collapse_nodes(data_dir: pathlib.Path, treebank_file: pathlib.Path, output: str):
    """Run ``enhanced_collapse_empty_nodes.pl`` on a treebank file.

    The perl script is taken from ``<data_dir>/tools`` and its standard output
    is captured into ``output``. Nothing is done when ``output`` already
    exists.

    Args:
        data_dir: Directory containing the ``tools`` sub-directory.
        treebank_file: File passed to the perl script.
        output: Path (as a string) of the file to create.
    """
    output_path = pathlib.Path(output)
    if output_path.exists():
        return

    command = (f"perl {path_to_str(data_dir / 'tools' / 'enhanced_collapse_empty_nodes.pl')} "
               f"{path_to_str(treebank_file)}")
    # The existence of `output` is what marks the work as done, so the file
    # must only appear after perl succeeded. Capture into a sibling file and
    # rename it afterwards; otherwise a failed run would leave a partial file
    # that every later call silently reuses.
    partial_path = output_path.with_name(output_path.name + ".tmp")
    try:
        utils.execute_command(command, str(partial_path))
    except BaseException:
        # Also covers Ctrl-C: drop the incomplete file, then propagate.
        if partial_path.exists():
            partial_path.unlink()
        raise
    partial_path.replace(output_path)
def run(_):
languages = FLAGS.lang
for lang in languages:
......@@ -82,21 +71,21 @@ def run(_):
for treebank_file in treebank_dir.iterdir():
name = treebank_file.name
if "conllu" in name and "fixed" not in name:
output = path_to_str(treebank_file).replace('.conllu', '.fixed.conllu')
output = utils.path_to_str(treebank_file).replace('.conllu', '.fixed.conllu')
if "train" in name:
collapse_nodes(data_dir, treebank_file, output)
utils.collapse_nodes(data_dir, treebank_file, output)
train_paths.append(output)
elif "dev" in name:
collapse_nodes(data_dir, treebank_file, output)
utils.collapse_nodes(data_dir, treebank_file, output)
dev_paths.append(output)
# elif "test" in name:
# collapse_nodes(data_dir, treebank_file, output)
# test_paths.append(output)
if ".txt" in name:
if "train" in name:
train_raw_paths.append(path_to_str(treebank_file))
train_raw_paths.append(utils.path_to_str(treebank_file))
elif "dev" in name:
dev_raw_paths.append(path_to_str(treebank_file))
dev_raw_paths.append(utils.path_to_str(treebank_file))
merged_dataset_name = "IWPT"
lang_data_dir = pathlib.Path(data_dir / f"UD_{full_language}-{merged_dataset_name}")
......
"""Utils for scripts."""
import pathlib
import subprocess
LANG2TRANSFORMER = {
......@@ -41,3 +42,14 @@ def execute_command(command, output_file=None):
subprocess.run(command, check=True, stdout=f)
else:
subprocess.run(command, check=True)
def path_to_str(path: pathlib.Path) -> str:
    """Resolve ``path`` and return the result as a string."""
    absolute = path.resolve()
    return str(absolute)
def collapse_nodes(data_dir: pathlib.Path, treebank_file: pathlib.Path, output: str):
    """Run ``enhanced_collapse_empty_nodes.pl`` on a treebank file.

    The perl script is taken from ``<data_dir>/tools`` and its standard output
    is captured into ``output``. Nothing is done when ``output`` already
    exists.

    Args:
        data_dir: Directory containing the ``tools`` sub-directory.
        treebank_file: File passed to the perl script.
        output: Path (as a string) of the file to create.
    """
    output_path = pathlib.Path(output)
    if output_path.exists():
        return

    command = (f"perl {path_to_str(data_dir / 'tools' / 'enhanced_collapse_empty_nodes.pl')} "
               f"{path_to_str(treebank_file)}")
    # The existence of `output` is what marks the work as done, so the file
    # must only appear after perl succeeded. Capture into a sibling file and
    # rename it afterwards; otherwise a failed run would leave a partial file
    # that every later call silently reuses.
    partial_path = output_path.with_name(output_path.name + ".tmp")
    try:
        execute_command(command, str(partial_path))
    except BaseException:
        # Also covers Ctrl-C: drop the incomplete file, then propagate.
        if partial_path.exists():
            partial_path.unlink()
        raise
    partial_path.replace(output_path)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment