Skip to content
Snippets Groups Projects
Commit ce036252 authored by Mateusz Klimaszewski
Browse files

Add IWPT'21 evaluation script.

parent 16f97a35
Branches
No related merge requests found
import pathlib
from absl import app
from absl import flags
from scripts import utils
# Maps a two-letter language code (used below as the model directory name)
# to the language name that appears in the "UD_<Language>-IWPT" treebank
# directory name.
CODE2LANG = {
    "ar": "Arabic",
    "bg": "Bulgarian",
    "cs": "Czech",
    "nl": "Dutch",
    "en": "English",
    "et": "Estonian",
    "fi": "Finnish",
    "fr": "French",
    "it": "Italian",
    "lv": "Latvian",
    "lt": "Lithuanian",
    "pl": "Polish",
    "ru": "Russian",
    "sk": "Slovak",
    "sv": "Swedish",
    "ta": "Tamil",
    "uk": "Ukrainian",
}
# Command-line flags of the evaluation script (absl).
FLAGS = flags.FLAGS
flags.DEFINE_string(
    name="data_dir",
    default="",
    help="Path to IWPT'21 data directory.",
)
flags.DEFINE_string(
    name="models_dir",
    default="/tmp/",
    help="Model serialization dir.",
)
flags.DEFINE_integer(
    name="cuda_device",
    default=-1,
    help="Cuda device id (-1 for cpu).",
)
flags.DEFINE_string(
    name="evaluate_script_path",
    default="iwpt21_xud_eval.py",
    help="Path to 'iwpt21_xud_eval.py' eval script.",
)
flags.DEFINE_boolean(
    name="expect_prefix",
    default=True,
    help="Whether to expect allennlp prefix.",
)
def run(_):
    """Predict and evaluate every language model found in ``FLAGS.models_dir``.

    Each sub-directory whose name is a key of ``CODE2LANG`` is treated as one
    model directory. For such a directory the function locates the single
    ``dev`` ``.conllu`` file of the matching ``UD_<Language>-IWPT`` treebank,
    runs ``combo`` prediction on it, collapses the predicted file with
    ``utils.collapse_nodes`` and writes the evaluation script's output to
    ``results.txt`` inside the model directory. Directories that already
    contain ``results.txt`` are left untouched.

    Args:
        _: Positional argument supplied by ``app.run``; unused.

    Raises:
        AssertionError: If ``FLAGS.expect_prefix`` is set and the model
            directory does not contain exactly one entry, or if the treebank
            directory does not contain exactly one matching dev file.
    """
    for model_dir in pathlib.Path(FLAGS.models_dir).iterdir():
        lang_code = model_dir.name
        if lang_code not in CODE2LANG:
            print("Skipping unknown directory: ", lang_code)
            continue
        treebank_name = f"UD_{CODE2LANG[lang_code]}-IWPT"

        if FLAGS.expect_prefix:
            # The actual model lives one level deeper, in the only child
            # of the language directory.
            nested_dirs = list(model_dir.iterdir())
            assert len(nested_dirs) == 1, f"There is incorrect count of models {nested_dirs}"
            model_dir = nested_dirs[0]

        treebank_dir = pathlib.Path(FLAGS.data_dir) / treebank_name
        candidates = [
            path
            for path in list(treebank_dir.iterdir())
            if "dev" in path.name and ".conllu" in path.name
        ]
        assert len(candidates) == 1, "Couldn't find test file."
        test_file = candidates[0]

        results_file = model_dir / "results.txt"
        if results_file.exists():
            # Already evaluated.
            continue

        output_pred = model_dir / 'predictions.conllu'
        model_path = model_dir / 'model.tar.gz'
        predict_command = f"""combo --mode predict --model_path {model_path}
--input_file {test_file}
--output_file {output_pred}
--cuda_device {FLAGS.cuda_device}
--silent
"""
        utils.execute_command(predict_command)

        output_collapsed = utils.path_to_str(output_pred).replace('.conllu', '.collapsed.conllu')
        utils.collapse_nodes(pathlib.Path(FLAGS.data_dir), output_pred, output_collapsed)

        evaluate_command = f"""python {FLAGS.evaluate_script_path} -v
{test_file}
{output_collapsed}
"""
        utils.execute_command(evaluate_command, output_file=results_file)
def main():
    """Start the script by passing ``run`` to ``app.run``."""
    app.run(run)


# Only start the evaluation when the file is executed as a script.
if __name__ == "__main__":
    main()
......@@ -43,22 +43,11 @@ flags.DEFINE_integer(name="cuda_device", default=-1,
help="Cuda device id (-1 for cpu).")
def path_to_str(path: pathlib.Path) -> str:
    """Return the resolved form of ``path`` as a plain string."""
    resolved = path.resolve()
    return str(resolved)
def merge_files(files: List[str], output: pathlib.Path):
    """Concatenate ``files``, in order, into ``output`` unless it already exists.

    The content is copied byte-for-byte in Python rather than through a
    shell ``cat`` command, so paths containing spaces or shell
    metacharacters are handled correctly and read/write failures are
    reported instead of being silently ignored.

    Args:
        files: Paths of the input files, concatenated in the given order.
            An empty list produces an empty ``output`` file.
        output: Destination path. When it already exists the function does
            nothing.

    Raises:
        OSError: If an input cannot be read or ``output`` cannot be written.
            A partially written ``output`` is removed before re-raising so
            that a later call does not mistake it for a finished merge.
    """
    import shutil  # local import: only this function needs it

    if output.exists():
        return
    try:
        with open(output, "wb") as merged:
            for file_path in files:
                with open(file_path, "rb") as part:
                    shutil.copyfileobj(part, merged)
    except OSError:
        if output.exists():
            output.unlink()
        raise
def collapse_nodes(data_dir: pathlib.Path, treebank_file: pathlib.Path, output: str):
    """Run ``enhanced_collapse_empty_nodes.pl`` on ``treebank_file``.

    The perl script is taken from ``data_dir / 'tools'`` and ``output`` is
    handed to ``utils.execute_command`` as its output file. Nothing is
    executed when ``output`` already exists.
    """
    if pathlib.Path(output).exists():
        return
    script = path_to_str(data_dir / 'tools' / 'enhanced_collapse_empty_nodes.pl')
    treebank = path_to_str(treebank_file)
    utils.execute_command(f"perl {script} {treebank}", output)
def run(_):
languages = FLAGS.lang
for lang in languages:
......@@ -82,21 +71,21 @@ def run(_):
for treebank_file in treebank_dir.iterdir():
name = treebank_file.name
if "conllu" in name and "fixed" not in name:
output = path_to_str(treebank_file).replace('.conllu', '.fixed.conllu')
output = utils.path_to_str(treebank_file).replace('.conllu', '.fixed.conllu')
if "train" in name:
collapse_nodes(data_dir, treebank_file, output)
utils.collapse_nodes(data_dir, treebank_file, output)
train_paths.append(output)
elif "dev" in name:
collapse_nodes(data_dir, treebank_file, output)
utils.collapse_nodes(data_dir, treebank_file, output)
dev_paths.append(output)
# elif "test" in name:
# collapse_nodes(data_dir, treebank_file, output)
# test_paths.append(output)
if ".txt" in name:
if "train" in name:
train_raw_paths.append(path_to_str(treebank_file))
train_raw_paths.append(utils.path_to_str(treebank_file))
elif "dev" in name:
dev_raw_paths.append(path_to_str(treebank_file))
dev_raw_paths.append(utils.path_to_str(treebank_file))
merged_dataset_name = "IWPT"
lang_data_dir = pathlib.Path(data_dir / f"UD_{full_language}-{merged_dataset_name}")
......
"""Utils for scripts."""
import pathlib
import subprocess
LANG2TRANSFORMER = {
......@@ -41,3 +42,14 @@ def execute_command(command, output_file=None):
subprocess.run(command, check=True, stdout=f)
else:
subprocess.run(command, check=True)
def path_to_str(path: pathlib.Path) -> str:
    """Resolve ``path`` and return it as a string."""
    resolved_path = path.resolve()
    return str(resolved_path)
def collapse_nodes(data_dir: pathlib.Path, treebank_file: pathlib.Path, output: str):
    """Run ``enhanced_collapse_empty_nodes.pl`` on ``treebank_file``.

    The perl script is taken from ``data_dir / 'tools'`` and ``output`` is
    handed to ``execute_command`` as its output file. Nothing is executed
    when ``output`` already exists.
    """
    if pathlib.Path(output).exists():
        return
    script = path_to_str(data_dir / 'tools' / 'enhanced_collapse_empty_nodes.pl')
    treebank = path_to_str(treebank_file)
    execute_command(f"perl {script} {treebank}", output)
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment