From ce2d7be688db2ce28be47b734238671467ab5e28 Mon Sep 17 00:00:00 2001
From: Mateusz Klimaszewski <mk.klimaszewski@gmail.com>
Date: Tue, 11 May 2021 18:16:03 +0200
Subject: [PATCH] Add conllu-quick-fix call. Change model for Tamil.

---
 scripts/evaluate_iwpt21.py |  2 +-
 scripts/predict_iwpt21.py  | 22 ++++++++++++++++++----
 scripts/train_iwpt21.py    |  4 ++--
 scripts/utils.py           | 11 +++++++++--
 4 files changed, 30 insertions(+), 9 deletions(-)

diff --git a/scripts/evaluate_iwpt21.py b/scripts/evaluate_iwpt21.py
index b67541f..bfd24eb 100644
--- a/scripts/evaluate_iwpt21.py
+++ b/scripts/evaluate_iwpt21.py
@@ -70,7 +70,7 @@ def run(_):
             utils.execute_command(command)
 
             output_collapsed = utils.path_to_str(output_pred).replace('.conllu', '.collapsed.conllu')
-            utils.collapse_nodes(pathlib.Path(FLAGS.data_dir), output_pred, output_collapsed)
+            utils.collapse_nodes(pathlib.Path(FLAGS.data_dir) / 'tools', output_pred, output_collapsed)
 
             command = f"""python {FLAGS.evaluate_script_path} -v 
             {test_file}
diff --git a/scripts/predict_iwpt21.py b/scripts/predict_iwpt21.py
index dff3594..513ce3b 100644
--- a/scripts/predict_iwpt21.py
+++ b/scripts/predict_iwpt21.py
@@ -27,9 +27,11 @@ CODE2LANG = {
 
 FLAGS = flags.FLAGS
 flags.DEFINE_string(name="data_dir", default="",
-                    help="Path to IWPT'21 data directory.")
+                    help="Path to data directory.")
 flags.DEFINE_string(name="models_dir", default="/tmp/",
                     help="Model serialization dir.")
+flags.DEFINE_string(name="tools", default="",
+                    help="UD tools path.")
 flags.DEFINE_integer(name="cuda_device", default=-1,
                      help="Cuda device id (-1 for cpu).")
 flags.DEFINE_boolean(name="expect_prefix", default=True,
@@ -51,9 +53,15 @@ def run(_):
 
         data_dir = pathlib.Path(FLAGS.data_dir)
         files = list(data_dir.iterdir())
-        test_file = [f for f in files if f"{lang}.conllu" == f.name]
-        assert len(test_file) == 1, f"Couldn't find test file."
-        test_file = test_file[0]
+        test_file = [f for f in files if f"{lang}.mwt.conllu" == f.name]
+        # Prefer the {lang}.mwt.conllu file when it exists; otherwise fall back to {lang}.conllu.
+        if test_file:
+            assert len(test_file) == 1, f"Should be exactly one {lang}.mwt.conllu file."
+            test_file = test_file[0]
+        else:
+            test_file = [f for f in files if f"{lang}.conllu" == f.name]
+            assert len(test_file) == 1, f"Couldn't find test file."
+            test_file = test_file[0]
 
         output_pred = data_dir / f'{lang}_pred.conllu'
         command = f"""combo --mode predict --model_path {model_dir / 'model.tar.gz'}
@@ -64,6 +72,12 @@ def run(_):
         """
         utils.execute_command(command)
 
+        output_fixed = utils.path_to_str(output_pred).replace('.conllu', '.fixed.conllu')
+        utils.quick_fix(pathlib.Path(FLAGS.tools), output_pred, output_fixed)
+
+        output_collapsed = output_fixed.replace('.fixed.conllu', '.collapsed.conllu')
+        utils.collapse_nodes(pathlib.Path(FLAGS.tools), pathlib.Path(output_fixed), output_collapsed)
+
 
 def main():
     app.run(run)
diff --git a/scripts/train_iwpt21.py b/scripts/train_iwpt21.py
index 39054e4..e1e427f 100644
--- a/scripts/train_iwpt21.py
+++ b/scripts/train_iwpt21.py
@@ -73,10 +73,10 @@ def run(_):
                 if "conllu" in name and "fixed" not in name:
                     output = utils.path_to_str(treebank_file).replace('.conllu', '.fixed.conllu')
                     if "train" in name:
-                        utils.collapse_nodes(data_dir, treebank_file, output)
+                        utils.collapse_nodes(data_dir / 'tools', treebank_file, output)
                         train_paths.append(output)
                     elif "dev" in name:
-                        utils.collapse_nodes(data_dir, treebank_file, output)
+                        utils.collapse_nodes(data_dir / 'tools', treebank_file, output)
                         dev_paths.append(output)
                     # elif "test" in name:
                     #     collapse_nodes(data_dir, treebank_file, output)
diff --git a/scripts/utils.py b/scripts/utils.py
index 19808ad..2c66205 100644
--- a/scripts/utils.py
+++ b/scripts/utils.py
@@ -19,7 +19,7 @@ LANG2TRANSFORMER = {
     "ru": "blinoff/roberta-base-russian-v0",
     "sv": "KB/bert-base-swedish-cased",
     "uk": "/tmp/lustre_shared/mklimasz/transformers/wikibert-base-uk-cased/",
-    "ta": "/tmp/lustre_shared/mklimasz/transformers/wikibert-base-ta-cased/",
+    "ta": "xlm-roberta-large",
     "sk": "/tmp/lustre_shared/mklimasz/transformers/wikibert-base-sk-cased/",
     "lt": "/tmp/lustre_shared/mklimasz/transformers/wikibert-base-lt-cased/",
     "lv": "/tmp/lustre_shared/mklimasz/transformers/wikibert-base-lv-cased/",
@@ -51,5 +51,12 @@ def path_to_str(path: pathlib.Path) -> str:
 def collapse_nodes(data_dir: pathlib.Path, treebank_file: pathlib.Path, output: str):
     output_path = pathlib.Path(output)
     if not output_path.exists():
-        execute_command(f"perl {path_to_str(data_dir / 'tools' / 'enhanced_collapse_empty_nodes.pl')} "
+        execute_command(f"perl {path_to_str(data_dir / 'enhanced_collapse_empty_nodes.pl')} "
+                        f"{path_to_str(treebank_file)}", output)
+
+
+def quick_fix(data_dir: pathlib.Path, treebank_file: pathlib.Path, output: str):
+    output_path = pathlib.Path(output)
+    if not output_path.exists():
+        execute_command(f"perl {path_to_str(data_dir / 'conllu-quick-fix.pl')} "
                         f"{path_to_str(treebank_file)}", output)
-- 
GitLab