From b8c83784dfd541dbae8851a45c66a204d5cd82a7 Mon Sep 17 00:00:00 2001
From: Mateusz Klimaszewski <mk.klimaszewski@gmail.com>
Date: Tue, 5 Jan 2021 16:37:59 +0100
Subject: [PATCH] Extend installation and pre-trained models documentation, fix
 installation, update trainer to be in sync with AllenNLP 1.2.1.

---
 combo/data/api.py         |   3 ++
 combo/training/trainer.py |  38 ++++++++------
 docs/installation.md      |  18 ++++++-
 docs/models.md            |   2 +
 scripts/train.py          | 108 ++++++++++++++++++++++++++++++++++++--
 setup.py                  |   9 ++--
 6 files changed, 151 insertions(+), 27 deletions(-)

diff --git a/combo/data/api.py b/combo/data/api.py
index 7d44917..4ab7f1a 100644
--- a/combo/data/api.py
+++ b/combo/data/api.py
@@ -36,6 +36,9 @@ class Sentence:
             "metadata": self.metadata,
         })
 
+    def __len__(self):
+        return len(self.tokens)
+
 
 class _TokenList(conllu.TokenList):
 
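The new `__len__` lets `Sentence` work with Python's built-in `len()` and makes empty sentences falsy. A minimal, self-contained sketch of the pattern (a toy stand-in, not the real `combo.data.api.Sentence`):

```python
from dataclasses import dataclass, field
from typing import List

@dataclass
class Sentence:
    # Toy stand-in for combo.data.api.Sentence; only the parts that
    # matter for __len__ are modelled here.
    tokens: List[str] = field(default_factory=list)
    metadata: dict = field(default_factory=dict)

    def __len__(self):
        return len(self.tokens)

s = Sentence(tokens=["Hello", ",", "world", "!"])
print(len(s))            # 4
print(bool(Sentence()))  # False: with __len__ defined, empty sentences are falsy
```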
diff --git a/combo/training/trainer.py b/combo/training/trainer.py
index 26bd75f..f74873e 100644
--- a/combo/training/trainer.py
+++ b/combo/training/trainer.py
@@ -68,10 +68,7 @@ class GradientDescentTrainer(training.GradientDescentTrainer):
         self.validate_every_n = 5
 
     @overrides
-    def train(self) -> Dict[str, Any]:
-        """
-        Trains the supplied model with the supplied parameters.
-        """
+    def _try_train(self) -> Dict[str, Any]:
         try:
             epoch_counter = self._restore_checkpoint()
         except RuntimeError:
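For context, the override target changes because AllenNLP 1.2.x split the training loop: the base `train()` is now a thin wrapper that delegates to `_try_train()` and guarantees cleanup. A hedged, approximate sketch of that upstream split (from memory of the 1.2.x sources), which also explains why the `self._tensorboard.close()` call disappears from this subclass further down in the patch:

```python
from typing import Any, Dict

class GradientDescentTrainer:
    # ...trainer internals elided...

    def train(self) -> Dict[str, Any]:
        """Trains the supplied model with the supplied parameters."""
        try:
            return self._try_train()
        finally:
            # upstream now owns the cleanup, e.g. closing tensorboard writers
            self._tensorboard.close()
```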
@@ -87,7 +84,7 @@ class GradientDescentTrainer(training.GradientDescentTrainer):
         logger.info("Beginning training.")
 
         val_metrics: Dict[str, float] = {}
-        this_epoch_val_metric: float = None
+        this_epoch_val_metric: float
         metrics: Dict[str, Any] = {}
         epochs_trained = 0
         training_start_time = time.time()
@@ -97,19 +94,24 @@ class GradientDescentTrainer(training.GradientDescentTrainer):
             metrics["best_validation_" + key] = value
 
         for callback in self._epoch_callbacks:
-            callback(self, metrics={}, epoch=-1, is_master=True)
+            callback(self, metrics={}, epoch=-1, is_master=self._master)
 
         for epoch in range(epoch_counter, self._num_epochs):
             epoch_start_time = time.time()
             train_metrics = self._train_epoch(epoch)
 
+            if self._master and self._checkpointer is not None:
+                self._checkpointer.save_checkpoint(epoch, self, save_model_only=True)
+
+            # Wait for the master to finish saving the model checkpoint
+            if self._distributed:
+                dist.barrier()
+
             # get peak of memory usage
-            if "cpu_memory_MB" in train_metrics:
-                metrics["peak_cpu_memory_MB"] = max(
-                    metrics.get("peak_cpu_memory_MB", 0), train_metrics["cpu_memory_MB"]
-                )
             for key, value in train_metrics.items():
-                if key.startswith("gpu_"):
+                if key.startswith("gpu_") and key.endswith("_memory_MB"):
+                    metrics["peak_" + key] = max(metrics.get("peak_" + key, 0), value)
+                elif key.startswith("worker_") and key.endswith("_memory_MB"):
                     metrics["peak_" + key] = max(metrics.get("peak_" + key, 0), value)
 
             if self._validation_data_loader is not None:
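The rewritten loop replaces the special-cased `cpu_memory_MB` entry with uniform peak tracking over every `gpu_*_memory_MB` and `worker_*_memory_MB` key. A standalone sketch of the fold, with made-up readings:

```python
# Fold per-epoch memory readings into running maxima keyed "peak_<metric>".
metrics = {}
for train_metrics in (
    {"worker_0_memory_MB": 812.5, "gpu_0_memory_MB": 2048.0},
    {"worker_0_memory_MB": 991.2, "gpu_0_memory_MB": 1536.0},
):
    for key, value in train_metrics.items():
        if key.endswith("_memory_MB") and key.startswith(("gpu_", "worker_")):
            metrics["peak_" + key] = max(metrics.get("peak_" + key, 0), value)

print(metrics)  # {'peak_worker_0_memory_MB': 991.2, 'peak_gpu_0_memory_MB': 2048.0}
```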
@@ -129,9 +131,9 @@ class GradientDescentTrainer(training.GradientDescentTrainer):
                             self.model,
                             val_loss,
                             val_reg_loss,
-                            num_batches=num_batches,
                             batch_loss=None,
                             batch_reg_loss=None,
+                            num_batches=num_batches,
                             reset=True,
                             world_size=self._world_size,
                             cuda_device=self.cuda_device,
@@ -139,7 +141,7 @@ class GradientDescentTrainer(training.GradientDescentTrainer):
 
                         # Check validation metric for early stopping
                         this_epoch_val_metric = val_metrics[self._validation_metric]
-                        # self._metric_tracker.add_metric(this_epoch_val_metric)
+                        self._metric_tracker.add_metric(this_epoch_val_metric)
 
                 train_metrics["patience"] = self._metric_tracker._patience
                 if self._metric_tracker.should_stop_early():
@@ -184,7 +186,7 @@ class GradientDescentTrainer(training.GradientDescentTrainer):
             if self._momentum_scheduler:
                 self._momentum_scheduler.step(this_epoch_val_metric)
 
-            if self._master:
+            if self._master and self._checkpointer is not None:
                 self._checkpointer.save_checkpoint(
                     epoch, self, is_best_so_far=self._metric_tracker.is_best_so_far()
                 )
@@ -209,11 +211,13 @@ class GradientDescentTrainer(training.GradientDescentTrainer):
 
             epochs_trained += 1
 
-        # make sure pending events are flushed to disk and files are closed properly
-        self._tensorboard.close()
+        for callback in self._end_callbacks:
+            callback(self, metrics=metrics, epoch=epoch, is_master=self._master)
 
         # Load the best model state before returning
-        best_model_state = self._checkpointer.best_model_state()
+        best_model_state = (
+            None if self._checkpointer is None else self._checkpointer.best_model_state()
+        )
         if best_model_state:
             self.model.load_state_dict(best_model_state)
 
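The save-then-barrier pattern introduced above is worth spelling out: the master writes a model-only checkpoint and all ranks block until the write completes, so no worker starts the next epoch against a half-written file. Extracted for illustration, assuming `trainer` is the `GradientDescentTrainer` from this patch:

```python
import torch.distributed as dist

def save_model_checkpoint_in_sync(trainer, epoch: int) -> None:
    # Only the master process writes the (model-only) checkpoint...
    if trainer._master and trainer._checkpointer is not None:
        trainer._checkpointer.save_checkpoint(epoch, trainer, save_model_only=True)
    # ...and every rank blocks here until that write has finished.
    if trainer._distributed:
        dist.barrier()
```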
diff --git a/docs/installation.md b/docs/installation.md
index 6371094..bf741f9 100644
--- a/docs/installation.md
+++ b/docs/installation.md
@@ -7,7 +7,23 @@ python setup.py develop
 combo --helpfull
 ```
 
+### Virtualenv example
+```bash
+python -m venv venv
+source venv/bin/activate
+pip install --upgrade pip
+python setup.py develop
+```
+
 ## Problems & solutions
 * **jsonnet** installation error
 
-use `conda install -c conda-forge jsonnet=0.15.0`
+Run `conda install -c conda-forge jsonnet=0.15.0` and re-run installation.
+
+* No package 'sentencepiece' found
+
+Run `pip install sentencepiece` and re-run installation.
+
+* Missing Cython error
+
+Run `pip install spacy==2.3.2` manually and re-run installation.
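After applying any of the fixes above, a quick import check confirms the pinned core dependencies resolved correctly. A sketch, with version pins taken from `setup.py` in this patch:

```python
import allennlp
import torch
import transformers

assert allennlp.__version__ == "1.2.1"
assert torch.__version__.startswith("1.6.0")
assert transformers.__version__.startswith("3.4")
print("Core COMBO dependencies resolved:", allennlp.__version__)
```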
diff --git a/docs/models.md b/docs/models.md
index 25a7f70..94eed03 100644
--- a/docs/models.md
+++ b/docs/models.md
@@ -4,6 +4,8 @@ COMBO provides pre-trained models for:
 - morphosyntactic prediction (i.e. part-of-speech tagging, morphosyntactic analysis, lemmatisation and dependency parsing) trained on the treebanks from [Universal Dependencies repository](https://universaldependencies.org),
 - enhanced dependency parsing trained on IWPT 2020 shared task [data](https://universaldependencies.org/iwpt20/data.html).
 
+The list of pre-trained models, together with their **evaluation results**, is available in this [spreadsheet](https://docs.google.com/spreadsheets/d/1WFYc2aLRa1jw7le030HOacv9fc4zmtqiZtRQY6gl5mc/edit?usp=sharing).
+Please note that the name in brackets matches the name used for [Automatic Download](models.md#automatic-download).
 ## Manual download
 
 The pre-trained models can be downloaded from [here](http://mozart.ipipan.waw.pl/~mklimaszewski/models/).
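For reference, a hedged usage sketch once a model has been downloaded; it assumes COMBO exposes a `from_pretrained` loading API in `combo.predict`, and the model name is illustrative:

```python
from combo.predict import COMBO

nlp = COMBO.from_pretrained("polish-herbert-base")  # name is illustrative
sentence = nlp("Witaj, świecie!")
print(len(sentence))    # token count, via the new Sentence.__len__
print(sentence.tokens)  # conllu-style tokens with predicted annotations
```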
diff --git a/scripts/train.py b/scripts/train.py
index 9390888..4bd342a 100644
--- a/scripts/train.py
+++ b/scripts/train.py
@@ -6,38 +6,71 @@ from absl import flags
 
 from scripts import utils
 
+# List generated with: ls -1 | xargs -i echo "\"{}\"," (see the Python equivalent below)
+# UD 2.7
 TREEBANKS = [
     "UD_Afrikaans-AfriBooms",
+    "UD_Akkadian-PISANDUB",
+    "UD_Akkadian-RIAO",
+    "UD_Akuntsu-TuDeT",
+    "UD_Albanian-TSA",
+    "UD_Amharic-ATT",
+    "UD_Ancient_Greek-Perseus",
+    "UD_Ancient_Greek-PROIEL",
+    "UD_Apurina-UFPA",
     "UD_Arabic-NYUAD",
     "UD_Arabic-PADT",
+    "UD_Arabic-PUD",
     "UD_Armenian-ArmTDP",
+    "UD_Assyrian-AS",
+    "UD_Bambara-CRB",
     "UD_Basque-BDT",
     "UD_Belarusian-HSE",
+    "UD_Bhojpuri-BHTB",
     "UD_Breton-KEB",
     "UD_Bulgarian-BTB",
+    "UD_Buryat-BDT",
+    "UD_Cantonese-HK",
     "UD_Catalan-AnCora",
+    "UD_Chinese-CFL",
+    "UD_Chinese-GSD",
+    "UD_Chinese-GSDSimp",
+    "UD_Chinese-HK",
+    "UD_Chinese-PUD",
+    "UD_Chukchi-HSE",
+    "UD_Classical_Chinese-Kyoto",
+    "UD_Coptic-Scriptorium",
     "UD_Croatian-SET",
     "UD_Czech-CAC",
     "UD_Czech-CLTT",
     "UD_Czech-FicTree",
     "UD_Czech-PDT",
+    "UD_Czech-PUD",
     "UD_Danish-DDT",
     "UD_Dutch-Alpino",
     "UD_Dutch-LassySmall",
     "UD_English-ESL",
     "UD_English-EWT",
     "UD_English-GUM",
+    "UD_English-GUMReddit",
     "UD_English-LinES",
     "UD_English-ParTUT",
     "UD_English-Pronouns",
+    "UD_English-PUD",
+    "UD_Erzya-JR",
     "UD_Estonian-EDT",
     "UD_Estonian-EWT",
+    "UD_Faroese-FarPaHC",
+    "UD_Faroese-OFT",
     "UD_Finnish-FTB",
+    "UD_Finnish-OOD",
+    "UD_Finnish-PUD",
     "UD_Finnish-TDT",
     "UD_French-FQB",
     "UD_French-FTB",
     "UD_French-GSD",
     "UD_French-ParTUT",
+    "UD_French-PUD",
     "UD_French-Sequoia",
     "UD_French-Spoken",
     "UD_Galician-CTG",
@@ -45,60 +78,120 @@ TREEBANKS = [
     "UD_German-GSD",
     "UD_German-HDT",
     "UD_German-LIT",
+    "UD_German-PUD",
+    "UD_Gothic-PROIEL",
     "UD_Greek-GDT",
     "UD_Hebrew-HTB",
     "UD_Hindi_English-HIENCS",
     "UD_Hindi-HDTB",
+    "UD_Hindi-PUD",
     "UD_Hungarian-Szeged",
+    "UD_Icelandic-IcePaHC",
+    "UD_Icelandic-PUD",
+    "UD_Indonesian-CSUI",
     "UD_Indonesian-GSD",
+    "UD_Indonesian-PUD",
     "UD_Irish-IDT",
     "UD_Italian-ISDT",
     "UD_Italian-ParTUT",
     "UD_Italian-PoSTWITA",
+    "UD_Italian-PUD",
     "UD_Italian-TWITTIRO",
     "UD_Italian-VIT",
-    "UD_Japanese-BCCWJ",
+    # "UD_Japanese-BCCWJ", no data
     "UD_Japanese-GSD",
     "UD_Japanese-Modern",
+    "UD_Japanese-PUD",
+    "UD_Karelian-KKPP",
     "UD_Kazakh-KTB",
+    "UD_Khunsari-AHA",
+    "UD_Komi_Permyak-UH",
+    "UD_Komi_Zyrian-IKDP",
+    "UD_Komi_Zyrian-Lattice",
     "UD_Korean-GSD",
     "UD_Korean-Kaist",
+    "UD_Korean-PUD",
+    "UD_Kurmanji-MG",
     "UD_Latin-ITTB",
+    "UD_Latin-LLCT",
     "UD_Latin-Perseus",
     "UD_Latin-PROIEL",
     "UD_Latvian-LVTB",
     "UD_Lithuanian-ALKSNIS",
     "UD_Lithuanian-HSE",
+    "UD_Livvi-KKPP",
     "UD_Maltese-MUDT",
+    "UD_Manx-Cadhan",
     "UD_Marathi-UFAL",
+    "UD_Mbya_Guarani-Dooley",
+    "UD_Mbya_Guarani-Thomas",
+    "UD_Moksha-JR",
+    "UD_Munduruku-TuDeT",
+    "UD_Naija-NSC",
+    "UD_Nayini-AHA",
+    "UD_North_Sami-Giella",
+    "UD_Norwegian-Bokmaal",
+    "UD_Norwegian-Nynorsk",
+    "UD_Norwegian-NynorskLIA",
+    "UD_Old_Church_Slavonic-PROIEL",
+    "UD_Old_French-SRCMF",
+    "UD_Old_Russian-RNC",
+    "UD_Old_Russian-TOROT",
+    "UD_Old_Turkish-Tonqq",
+    "UD_Persian-PerDT",
     "UD_Persian-Seraji",
     "UD_Polish-LFG",
     "UD_Polish-PDB",
+    "UD_Polish-PUD",
     "UD_Portuguese-Bosque",
     "UD_Portuguese-GSD",
+    "UD_Portuguese-PUD",
     "UD_Romanian-Nonstandard",
     "UD_Romanian-RRT",
     "UD_Romanian-SiMoNERo",
     "UD_Russian-GSD",
+    "UD_Russian-PUD",
     "UD_Russian-SynTagRus",
     "UD_Russian-Taiga",
+    "UD_Sanskrit-UFAL",
+    "UD_Sanskrit-Vedic",
+    "UD_Scottish_Gaelic-ARCOSG",
     "UD_Serbian-SET",
+    "UD_Skolt_Sami-Giellagas",
     "UD_Slovak-SNK",
     "UD_Slovenian-SSJ",
     "UD_Slovenian-SST",
+    "UD_Soi-AHA",
+    "UD_South_Levantine_Arabic-MADAR",
     "UD_Spanish-AnCora",
     "UD_Spanish-GSD",
+    "UD_Spanish-PUD",
     "UD_Swedish-LinES",
+    "UD_Swedish-PUD",
     "UD_Swedish_Sign_Language-SSLC",
     "UD_Swedish-Talbanken",
+    "UD_Swiss_German-UZH",
+    "UD_Tagalog-TRG",
+    "UD_Tagalog-Ugnayan",
+    "UD_Tamil-MWTT",
     "UD_Tamil-TTB",
     "UD_Telugu-MTG",
+    "UD_Thai-PUD",
+    "UD_Tupinamba-TuDeT",
+    "UD_Turkish-BOUN",
     "UD_Turkish-GB",
+    "UD_Turkish_German-SAGT",
     "UD_Turkish-IMST",
+    "UD_Turkish-PUD",
     "UD_Ukrainian-IU",
+    "UD_Upper_Sorbian-UFAL",
     "UD_Urdu-UDTB",
     "UD_Uyghur-UDT",
     "UD_Vietnamese-VTB",
+    "UD_Warlpiri-UFAL",
+    "UD_Welsh-CCG",
+    "UD_Wolof-WTB",
+    "UD_Yoruba-YTB",
 ]
 
 FLAGS = flags.FLAGS
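As referenced at the top of the treebank list, a Python equivalent of that `ls -1 | xargs` one-liner (the UD root path is illustrative):

```python
import pathlib

# Point this at an unpacked UD 2.7 release directory.
ud_root = pathlib.Path("ud-treebanks-v2.7")
for treebank in sorted(p.name for p in ud_root.iterdir() if p.is_dir()):
    print(f'    "{treebank}",')
```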
@@ -136,7 +229,7 @@ def run(_):
         embeddings_file = None
         if embeddings_dir:
             embeddings_dir = pathlib.Path(embeddings_dir) / language
-            embeddings_file = [f for f in embeddings_dir.iterdir() if "vectors" in f.name and ".vec.gz" in f.name]
+            embeddings_file = [f for f in embeddings_dir.iterdir() if "vectors" in f.name and ".vec" in f.name]
             assert len(embeddings_file) == 1, f"Couldn't find embeddings file."
             embeddings_file = embeddings_file[0]
 
@@ -153,14 +246,19 @@ def run(_):
         else f"--pretrained_transformer_name {utils.LANG2TRANSFORMER[language]}"}
         --serialization_dir {serialization_dir}
         --config_path {pathlib.Path.cwd() / 'config.template.jsonnet'}
-        --word_batch_size 2500
         --notensorboard
         """
 
-        # no XPOS datasets
-        if treebank in ["UD_Hungarian-Szeged", "UD_Armenian-ArmTDP"]:
+        # Datasets without XPOS
+        if treebank in {"UD_Armenian-ArmTDP", "UD_Basque-BDT", "UD_Hungarian-Szeged"}:
             command = command + " --targets deprel,head,upostag,lemma,feats"
 
+        # Reduce word_batch_size
+        word_batch_size = 2500
+        if treebank in {"UD_German-HDT"}:
+            word_batch_size = 1000
+        command = command + f" --word_batch_size {word_batch_size}"
+
         utils.execute_command(command)
 
 
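The two per-treebank overrides above can be read as one policy. A hypothetical helper that consolidates them (set contents and flag strings copied from the diff; the helper name is made up):

```python
NO_XPOS = {"UD_Armenian-ArmTDP", "UD_Basque-BDT", "UD_Hungarian-Szeged"}

def extra_flags(treebank: str) -> str:
    flags = ""
    if treebank in NO_XPOS:
        # these treebanks lack the XPOS column, so xpostag is dropped from targets
        flags += " --targets deprel,head,upostag,lemma,feats"
    # UD_German-HDT is very large; a smaller word batch keeps memory in check
    word_batch_size = 1000 if treebank == "UD_German-HDT" else 2500
    return flags + f" --word_batch_size {word_batch_size}"

print(extra_flags("UD_German-HDT"))  # " --word_batch_size 1000"
```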
diff --git a/setup.py b/setup.py
index fdaa2be..8d4db9d 100644
--- a/setup.py
+++ b/setup.py
@@ -6,15 +6,16 @@ REQUIREMENTS = [
     'allennlp==1.2.1',
     'conllu==2.3.2',
     'dataclasses;python_version<"3.7"',
-    'joblib==0.14.1',
     'jsonnet==0.15.0',
-    'requests==2.23.0',
+    'numpy==1.19.4',
     'overrides==3.1.0',
-    'tensorboard==2.1.0',
+    'requests==2.23.0',
+    'scikit-learn<=0.23.2',
+    'spacy==2.3.2',
     'torch==1.6.0',
     'tqdm==4.43.0',
     'transformers>=3.4.0,<3.5',
-    'urllib3>=1.25.11',
+    'urllib3==1.25.11',
 ]
 
 setup(
-- 
GitLab