Skip to content
Snippets Groups Projects
Commit 6900d4f0 authored by Mateusz Klimaszewski's avatar Mateusz Klimaszewski
Browse files

Add IWPT'21 evaluation script.

parent 0f6faf2a
Branches
Tags
2 merge requests!37Release 1.0.4.,!36Release 1.0.4
import pathlib
from absl import app
from absl import flags
from scripts import utils
CODE2LANG = {
"ar": "Arabic",
"bg": "Bulgarian",
"cs": "Czech",
"nl": "Dutch",
"en": "English",
"et": "Estonian",
"fi": "Finnish",
"fr": "French",
"it": "Italian",
"lv": "Latvian",
"lt": "Lithuanian",
"pl": "Polish",
"ru": "Russian",
"sk": "Slovak",
"sv": "Swedish",
"ta": "Tamil",
"uk": "Ukrainian",
}
FLAGS = flags.FLAGS
flags.DEFINE_string(name="data_dir", default="",
help="Path to IWPT'21 data directory.")
flags.DEFINE_string(name="models_dir", default="/tmp/",
help="Model serialization dir.")
flags.DEFINE_integer(name="cuda_device", default=-1,
help="Cuda device id (-1 for cpu).")
flags.DEFINE_string(name="evaluate_script_path", default="iwpt21_xud_eval.py",
help="Path to 'iwpt21_xud_eval.py' eval script.")
flags.DEFINE_boolean(name="expect_prefix", default=True,
help="Whether to expect allennlp prefix.")
def run(_):
models_dir = pathlib.Path(FLAGS.models_dir)
for model_dir in models_dir.iterdir():
if model_dir.name not in CODE2LANG:
print("Skipping unknown directory: ", model_dir.name)
continue
treebank_name = f"UD_{CODE2LANG[model_dir.name]}-IWPT"
if FLAGS.expect_prefix:
model_dir = list(model_dir.iterdir())
assert len(model_dir) == 1, f"There is incorrect count of models {model_dir}"
model_dir = model_dir[0]
treebank_dir = pathlib.Path(FLAGS.data_dir) / treebank_name
files = list(treebank_dir.iterdir())
test_file = [f for f in files if "dev" in f.name and ".conllu" in f.name]
assert len(test_file) == 1, f"Couldn't find test file."
test_file = test_file[0]
if not (model_dir / "results.txt").exists():
output_pred = model_dir / 'predictions.conllu'
command = f"""combo --mode predict --model_path {model_dir / 'model.tar.gz'}
--input_file {test_file}
--output_file {output_pred}
--cuda_device {FLAGS.cuda_device}
--silent
"""
utils.execute_command(command)
output_collapsed = utils.path_to_str(output_pred).replace('.conllu', '.collapsed.conllu')
utils.collapse_nodes(pathlib.Path(FLAGS.data_dir), output_pred, output_collapsed)
command = f"""python {FLAGS.evaluate_script_path} -v
{test_file}
{output_collapsed}
"""
utils.execute_command(command, output_file=model_dir / "results.txt")
def main():
app.run(run)
if __name__ == "__main__":
main()
......@@ -43,22 +43,11 @@ flags.DEFINE_integer(name="cuda_device", default=-1,
help="Cuda device id (-1 for cpu).")
def path_to_str(path: pathlib.Path) -> str:
return str(path.resolve())
def merge_files(files: List[str], output: pathlib.Path):
if not output.exists():
os.system(f"cat {' '.join(files)} > {output}")
def collapse_nodes(data_dir: pathlib.Path, treebank_file: pathlib.Path, output: str):
output_path = pathlib.Path(output)
if not output_path.exists():
utils.execute_command(f"perl {path_to_str(data_dir / 'tools' / 'enhanced_collapse_empty_nodes.pl')} "
f"{path_to_str(treebank_file)}", output)
def run(_):
languages = FLAGS.lang
for lang in languages:
......@@ -82,21 +71,21 @@ def run(_):
for treebank_file in treebank_dir.iterdir():
name = treebank_file.name
if "conllu" in name and "fixed" not in name:
output = path_to_str(treebank_file).replace('.conllu', '.fixed.conllu')
output = utils.path_to_str(treebank_file).replace('.conllu', '.fixed.conllu')
if "train" in name:
collapse_nodes(data_dir, treebank_file, output)
utils.collapse_nodes(data_dir, treebank_file, output)
train_paths.append(output)
elif "dev" in name:
collapse_nodes(data_dir, treebank_file, output)
utils.collapse_nodes(data_dir, treebank_file, output)
dev_paths.append(output)
# elif "test" in name:
# collapse_nodes(data_dir, treebank_file, output)
# test_paths.append(output)
if ".txt" in name:
if "train" in name:
train_raw_paths.append(path_to_str(treebank_file))
train_raw_paths.append(utils.path_to_str(treebank_file))
elif "dev" in name:
dev_raw_paths.append(path_to_str(treebank_file))
dev_raw_paths.append(utils.path_to_str(treebank_file))
merged_dataset_name = "IWPT"
lang_data_dir = pathlib.Path(data_dir / f"UD_{full_language}-{merged_dataset_name}")
......
"""Utils for scripts."""
import pathlib
import subprocess
LANG2TRANSFORMER = {
......@@ -41,3 +42,14 @@ def execute_command(command, output_file=None):
subprocess.run(command, check=True, stdout=f)
else:
subprocess.run(command, check=True)
def path_to_str(path: pathlib.Path) -> str:
return str(path.resolve())
def collapse_nodes(data_dir: pathlib.Path, treebank_file: pathlib.Path, output: str):
output_path = pathlib.Path(output)
if not output_path.exists():
execute_command(f"perl {path_to_str(data_dir / 'tools' / 'enhanced_collapse_empty_nodes.pl')} "
f"{path_to_str(treebank_file)}", output)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment