From 6900d4f0fd51dd9e6718d7f8091b09146ba3dff8 Mon Sep 17 00:00:00 2001
From: Mateusz Klimaszewski <mk.klimaszewski@gmail.com>
Date: Thu, 6 May 2021 10:57:09 +0200
Subject: [PATCH] Add IWPT'21 evaluation script.

---
 scripts/evaluate_iwpt21.py | 87 ++++++++++++++++++++++++++++++++++++++
 scripts/train_iwpt21.py    | 21 +++------
 scripts/utils.py           | 12 ++++++
 3 files changed, 104 insertions(+), 16 deletions(-)
 create mode 100644 scripts/evaluate_iwpt21.py

diff --git a/scripts/evaluate_iwpt21.py b/scripts/evaluate_iwpt21.py
new file mode 100644
index 0000000..b67541f
--- /dev/null
+++ b/scripts/evaluate_iwpt21.py
@@ -0,0 +1,87 @@
+import pathlib
+
+from absl import app
+from absl import flags
+
+from scripts import utils
+
+CODE2LANG = {
+    "ar": "Arabic",
+    "bg": "Bulgarian",
+    "cs": "Czech",
+    "nl": "Dutch",
+    "en": "English",
+    "et": "Estonian",
+    "fi": "Finnish",
+    "fr": "French",
+    "it": "Italian",
+    "lv": "Latvian",
+    "lt": "Lithuanian",
+    "pl": "Polish",
+    "ru": "Russian",
+    "sk": "Slovak",
+    "sv": "Swedish",
+    "ta": "Tamil",
+    "uk": "Ukrainian",
+}
+
+FLAGS = flags.FLAGS
+flags.DEFINE_string(name="data_dir", default="",
+                    help="Path to IWPT'21 data directory.")
+flags.DEFINE_string(name="models_dir", default="/tmp/",
+                    help="Model serialization dir.")
+flags.DEFINE_integer(name="cuda_device", default=-1,
+                     help="Cuda device id (-1 for cpu).")
+flags.DEFINE_string(name="evaluate_script_path", default="iwpt21_xud_eval.py",
+                    help="Path to 'iwpt21_xud_eval.py' eval script.")
+flags.DEFINE_boolean(name="expect_prefix", default=True,
+                     help="Whether to expect allennlp prefix.")
+
+
+def run(_):
+    models_dir = pathlib.Path(FLAGS.models_dir)
+    for model_dir in models_dir.iterdir():
+        if model_dir.name not in CODE2LANG:
+            print("Skipping unknown directory: ", model_dir.name)
+            continue
+
+        treebank_name = f"UD_{CODE2LANG[model_dir.name]}-IWPT"
+
+        if FLAGS.expect_prefix:
+            model_dir = list(model_dir.iterdir())
+            assert len(model_dir) == 1, f"There is incorrect count of models {model_dir}"
+            model_dir = model_dir[0]
+
+        treebank_dir = pathlib.Path(FLAGS.data_dir) / treebank_name
+        files = list(treebank_dir.iterdir())
+
+        test_file = [f for f in files if "dev" in f.name and ".conllu" in f.name]
+        assert len(test_file) == 1, f"Couldn't find test file."
+        test_file = test_file[0]
+
+        if not (model_dir / "results.txt").exists():
+            output_pred = model_dir / 'predictions.conllu'
+            command = f"""combo --mode predict --model_path {model_dir / 'model.tar.gz'}
+            --input_file {test_file}
+            --output_file {output_pred}
+            --cuda_device {FLAGS.cuda_device}
+            --silent
+            """
+            utils.execute_command(command)
+
+            output_collapsed = utils.path_to_str(output_pred).replace('.conllu', '.collapsed.conllu')
+            utils.collapse_nodes(pathlib.Path(FLAGS.data_dir), output_pred, output_collapsed)
+
+            command = f"""python {FLAGS.evaluate_script_path} -v 
+            {test_file}
+            {output_collapsed} 
+            """
+            utils.execute_command(command, output_file=model_dir / "results.txt")
+
+
+def main():
+    app.run(run)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/train_iwpt21.py b/scripts/train_iwpt21.py
index 17737c9..e4705f7 100644
--- a/scripts/train_iwpt21.py
+++ b/scripts/train_iwpt21.py
@@ -43,22 +43,11 @@ flags.DEFINE_integer(name="cuda_device", default=-1,
                      help="Cuda device id (-1 for cpu).")
 
 
-def path_to_str(path: pathlib.Path) -> str:
-    return str(path.resolve())
-
-
 def merge_files(files: List[str], output: pathlib.Path):
     if not output.exists():
         os.system(f"cat {' '.join(files)} > {output}")
 
 
-def collapse_nodes(data_dir: pathlib.Path, treebank_file: pathlib.Path, output: str):
-    output_path = pathlib.Path(output)
-    if not output_path.exists():
-        utils.execute_command(f"perl {path_to_str(data_dir / 'tools' / 'enhanced_collapse_empty_nodes.pl')} "
-                              f"{path_to_str(treebank_file)}", output)
-
-
 def run(_):
     languages = FLAGS.lang
     for lang in languages:
@@ -82,21 +71,21 @@ def run(_):
             for treebank_file in treebank_dir.iterdir():
                 name = treebank_file.name
                 if "conllu" in name and "fixed" not in name:
-                    output = path_to_str(treebank_file).replace('.conllu', '.fixed.conllu')
+                    output = utils.path_to_str(treebank_file).replace('.conllu', '.fixed.conllu')
                     if "train" in name:
-                        collapse_nodes(data_dir, treebank_file, output)
+                        utils.collapse_nodes(data_dir, treebank_file, output)
                         train_paths.append(output)
                     elif "dev" in name:
-                        collapse_nodes(data_dir, treebank_file, output)
+                        utils.collapse_nodes(data_dir, treebank_file, output)
                         dev_paths.append(output)
                     # elif "test" in name:
                     #     collapse_nodes(data_dir, treebank_file, output)
                     #     test_paths.append(output)
                 if ".txt" in name:
                     if "train" in name:
-                        train_raw_paths.append(path_to_str(treebank_file))
+                        train_raw_paths.append(utils.path_to_str(treebank_file))
                     elif "dev" in name:
-                        dev_raw_paths.append(path_to_str(treebank_file))
+                        dev_raw_paths.append(utils.path_to_str(treebank_file))
 
         merged_dataset_name = "IWPT"
         lang_data_dir = pathlib.Path(data_dir / f"UD_{full_language}-{merged_dataset_name}")
diff --git a/scripts/utils.py b/scripts/utils.py
index bbbe2fe..19808ad 100644
--- a/scripts/utils.py
+++ b/scripts/utils.py
@@ -1,4 +1,5 @@
 """Utils for scripts."""
+import pathlib
 import subprocess
 
 LANG2TRANSFORMER = {
@@ -41,3 +42,14 @@ def execute_command(command, output_file=None):
             subprocess.run(command, check=True, stdout=f)
     else:
         subprocess.run(command, check=True)
+
+
+def path_to_str(path: pathlib.Path) -> str:
+    return str(path.resolve())
+
+
+def collapse_nodes(data_dir: pathlib.Path, treebank_file: pathlib.Path, output: str):
+    output_path = pathlib.Path(output)
+    if not output_path.exists():
+        execute_command(f"perl {path_to_str(data_dir / 'tools' / 'enhanced_collapse_empty_nodes.pl')} "
+                        f"{path_to_str(treebank_file)}", output)
-- 
GitLab