From b8c83784dfd541dbae8851a45c66a204d5cd82a7 Mon Sep 17 00:00:00 2001
From: Mateusz Klimaszewski <mk.klimaszewski@gmail.com>
Date: Tue, 5 Jan 2021 16:37:59 +0100
Subject: [PATCH 1/5] Extend installation and pre-trained models documentation,
 fix installation, update trainer to be in sync with AllenNLP 1.2.1.

---
 combo/data/api.py         |   3 ++
 combo/training/trainer.py |  38 ++++++++------
 docs/installation.md      |  18 ++++++-
 docs/models.md            |   2 +
 scripts/train.py          | 108 ++++++++++++++++++++++++++++++++++++--
 setup.py                  |   9 ++--
 6 files changed, 151 insertions(+), 27 deletions(-)

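[Reviewer note, outside the diff] The trainer change below saves a model-only
checkpoint on the master worker mid-epoch and then synchronises all workers.
A minimal sketch of that pattern, assuming `torch.distributed` is initialised
(`checkpoint_and_sync` and its arguments are illustrative, not part of the
patch):

```python
import torch.distributed as dist

def checkpoint_and_sync(trainer, epoch, is_master, is_distributed):
    # Only the master process writes the checkpoint to disk.
    if is_master and trainer._checkpointer is not None:
        trainer._checkpointer.save_checkpoint(epoch, trainer, save_model_only=True)
    # The other workers block here until the master finishes writing,
    # so nobody continues with a half-written checkpoint on disk.
    if is_distributed:
        dist.barrier()
```
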
diff --git a/combo/data/api.py b/combo/data/api.py
index 7d44917..4ab7f1a 100644
--- a/combo/data/api.py
+++ b/combo/data/api.py
@@ -36,6 +36,9 @@ class Sentence:
             "metadata": self.metadata,
         })
 
+    def __len__(self):
+        return len(self.tokens)
+
 
 class _TokenList(conllu.TokenList):
 
diff --git a/combo/training/trainer.py b/combo/training/trainer.py
index 26bd75f..f74873e 100644
--- a/combo/training/trainer.py
+++ b/combo/training/trainer.py
@@ -68,10 +68,7 @@ class GradientDescentTrainer(training.GradientDescentTrainer):
         self.validate_every_n = 5
 
     @overrides
-    def train(self) -> Dict[str, Any]:
-        """
-        Trains the supplied model with the supplied parameters.
-        """
+    def _try_train(self) -> Dict[str, Any]:
         try:
             epoch_counter = self._restore_checkpoint()
         except RuntimeError:
@@ -87,7 +84,7 @@ class GradientDescentTrainer(training.GradientDescentTrainer):
         logger.info("Beginning training.")
 
         val_metrics: Dict[str, float] = {}
-        this_epoch_val_metric: float = None
+        this_epoch_val_metric: float
         metrics: Dict[str, Any] = {}
         epochs_trained = 0
         training_start_time = time.time()
@@ -97,19 +94,24 @@ class GradientDescentTrainer(training.GradientDescentTrainer):
             metrics["best_validation_" + key] = value
 
         for callback in self._epoch_callbacks:
-            callback(self, metrics={}, epoch=-1, is_master=True)
+            callback(self, metrics={}, epoch=-1, is_master=self._master)
 
         for epoch in range(epoch_counter, self._num_epochs):
             epoch_start_time = time.time()
             train_metrics = self._train_epoch(epoch)
 
+            if self._master and self._checkpointer is not None:
+                self._checkpointer.save_checkpoint(epoch, self, save_model_only=True)
+
+            # Wait for the master to finish saving the model checkpoint
+            if self._distributed:
+                dist.barrier()
+
             # get peak of memory usage
-            if "cpu_memory_MB" in train_metrics:
-                metrics["peak_cpu_memory_MB"] = max(
-                    metrics.get("peak_cpu_memory_MB", 0), train_metrics["cpu_memory_MB"]
-                )
             for key, value in train_metrics.items():
-                if key.startswith("gpu_"):
+                if key.startswith("gpu_") and key.endswith("_memory_MB"):
+                    metrics["peak_" + key] = max(metrics.get("peak_" + key, 0), value)
+                elif key.startswith("worker_") and key.endswith("_memory_MB"):
                     metrics["peak_" + key] = max(metrics.get("peak_" + key, 0), value)
 
             if self._validation_data_loader is not None:
@@ -129,9 +131,9 @@ class GradientDescentTrainer(training.GradientDescentTrainer):
                             self.model,
                             val_loss,
                             val_reg_loss,
-                            num_batches=num_batches,
                             batch_loss=None,
                             batch_reg_loss=None,
+                            num_batches=num_batches,
                             reset=True,
                             world_size=self._world_size,
                             cuda_device=self.cuda_device,
@@ -139,7 +141,7 @@ class GradientDescentTrainer(training.GradientDescentTrainer):
 
                         # Check validation metric for early stopping
                         this_epoch_val_metric = val_metrics[self._validation_metric]
-                        # self._metric_tracker.add_metric(this_epoch_val_metric)
+                        self._metric_tracker.add_metric(this_epoch_val_metric)
 
                 train_metrics["patience"] = self._metric_tracker._patience
                 if self._metric_tracker.should_stop_early():
@@ -184,7 +186,7 @@ class GradientDescentTrainer(training.GradientDescentTrainer):
             if self._momentum_scheduler:
                 self._momentum_scheduler.step(this_epoch_val_metric)
 
-            if self._master:
+            if self._master and self._checkpointer is not None:
                 self._checkpointer.save_checkpoint(
                     epoch, self, is_best_so_far=self._metric_tracker.is_best_so_far()
                 )
@@ -209,11 +211,13 @@ class GradientDescentTrainer(training.GradientDescentTrainer):
 
             epochs_trained += 1
 
-        # make sure pending events are flushed to disk and files are closed properly
-        self._tensorboard.close()
+        for callback in self._end_callbacks:
+            callback(self, metrics=metrics, epoch=epoch, is_master=self._master)
 
         # Load the best model state before returning
-        best_model_state = self._checkpointer.best_model_state()
+        best_model_state = (
+            None if self._checkpointer is None else self._checkpointer.best_model_state()
+        )
         if best_model_state:
             self.model.load_state_dict(best_model_state)
 
diff --git a/docs/installation.md b/docs/installation.md
index 6371094..bf741f9 100644
--- a/docs/installation.md
+++ b/docs/installation.md
@@ -7,7 +7,23 @@ python setup.py develop
 combo --helpfull
 ```
 
+### Virtualenv example
+```bash
+python -m venv venv
+source venv/bin/activate
+pip install --upgrade pip
+python setup.py develop
+```
+
 ## Problems & solutions
 * **jsonnet** installation error
 
-use `conda install -c conda-forge jsonnet=0.15.0`
+Run `conda install -c conda-forge jsonnet=0.15.0` and re-run installation.
+
+* **sentencepiece** installation error (`No package 'sentencepiece' found`)
+
+Run `pip install sentencepiece` and re-run installation.
+
+* **Cython** missing error
+
+Run `pip install spacy==2.3.2` manually and re-run installation.
diff --git a/docs/models.md b/docs/models.md
index 25a7f70..94eed03 100644
--- a/docs/models.md
+++ b/docs/models.md
@@ -4,6 +4,8 @@ COMBO provides pre-trained models for:
 - morphosyntactic prediction (i.e. part-of-speech tagging, morphosyntactic analysis, lemmatisation and dependency parsing) trained on the treebanks from [Universal Dependencies repository](https://universaldependencies.org),
 - enhanced dependency parsing trained on IWPT 2020 shared task [data](https://universaldependencies.org/iwpt20/data.html).
 
+A list of the pre-trained models with their **evaluation results** is available in this [spreadsheet](https://docs.google.com/spreadsheets/d/1WFYc2aLRa1jw7le030HOacv9fc4zmtqiZtRQY6gl5mc/edit?usp=sharing).
+Please note that the name in brackets in the spreadsheet matches the name used for [Automatic Download](models.md#automatic-download).
 ## Manual download
 
 The pre-trained models can be downloaded from [here](http://mozart.ipipan.waw.pl/~mklimaszewski/models/).
diff --git a/scripts/train.py b/scripts/train.py
index 9390888..4bd342a 100644
--- a/scripts/train.py
+++ b/scripts/train.py
@@ -6,38 +6,71 @@ from absl import flags
 
 from scripts import utils
 
+# Treebank list generated with: ls -1 | xargs -i echo "\"{}\","
+# UD 2.7
 TREEBANKS = [
     "UD_Afrikaans-AfriBooms",
+    "UD_Akkadian-PISANDUB",
+    "UD_Akkadian-RIAO",
+    "UD_Akuntsu-TuDeT",
+    "UD_Albanian-TSA",
+    "UD_Amharic-ATT",
+    "UD_Ancient_Greek-Perseus",
+    "UD_Ancient_Greek-PROIEL",
+    "UD_Apurina-UFPA",
     "UD_Arabic-NYUAD",
     "UD_Arabic-PADT",
+    "UD_Arabic-PUD",
     "UD_Armenian-ArmTDP",
+    "UD_Assyrian-AS",
+    "UD_Bambara-CRB",
     "UD_Basque-BDT",
     "UD_Belarusian-HSE",
+    "UD_Bhojpuri-BHTB",
     "UD_Breton-KEB",
     "UD_Bulgarian-BTB",
+    "UD_Buryat-BDT",
+    "UD_Cantonese-HK",
     "UD_Catalan-AnCora",
+    "UD_Chinese-CFL",
+    "UD_Chinese-GSD",
+    "UD_Chinese-GSDSimp",
+    "UD_Chinese-HK",
+    "UD_Chinese-PUD",
+    "UD_Chukchi-HSE",
+    "UD_Classical_Chinese-Kyoto",
+    "UD_Coptic-Scriptorium",
     "UD_Croatian-SET",
     "UD_Czech-CAC",
     "UD_Czech-CLTT",
     "UD_Czech-FicTree",
     "UD_Czech-PDT",
+    "UD_Czech-PUD",
     "UD_Danish-DDT",
     "UD_Dutch-Alpino",
     "UD_Dutch-LassySmall",
     "UD_English-ESL",
     "UD_English-EWT",
     "UD_English-GUM",
+    "UD_English-GUMReddit",
     "UD_English-LinES",
     "UD_English-ParTUT",
     "UD_English-Pronouns",
+    "UD_English-PUD",
+    "UD_Erzya-JR",
     "UD_Estonian-EDT",
     "UD_Estonian-EWT",
+    "UD_Faroese-FarPaHC",
+    "UD_Faroese-OFT",
     "UD_Finnish-FTB",
+    "UD_Finnish-OOD",
+    "UD_Finnish-PUD",
     "UD_Finnish-TDT",
     "UD_French-FQB",
     "UD_French-FTB",
     "UD_French-GSD",
     "UD_French-ParTUT",
+    "UD_French-PUD",
     "UD_French-Sequoia",
     "UD_French-Spoken",
     "UD_Galician-CTG",
@@ -45,60 +78,120 @@ TREEBANKS = [
     "UD_German-GSD",
     "UD_German-HDT",
     "UD_German-LIT",
+    "UD_German-PUD",
+    "UD_Gothic-PROIEL",
     "UD_Greek-GDT",
     "UD_Hebrew-HTB",
     "UD_Hindi_English-HIENCS",
     "UD_Hindi-HDTB",
+    "UD_Hindi-PUD",
     "UD_Hungarian-Szeged",
+    "UD_Icelandic-IcePaHC",
+    "UD_Icelandic-PUD",
+    "UD_Indonesian-CSUI",
     "UD_Indonesian-GSD",
+    "UD_Indonesian-PUD",
     "UD_Irish-IDT",
     "UD_Italian-ISDT",
     "UD_Italian-ParTUT",
     "UD_Italian-PoSTWITA",
+    "UD_Italian-PUD",
     "UD_Italian-TWITTIRO",
     "UD_Italian-VIT",
-    "UD_Japanese-BCCWJ",
+    # "UD_Japanese-BCCWJ", no data
     "UD_Japanese-GSD",
     "UD_Japanese-Modern",
+    "UD_Japanese-PUD",
+    "UD_Karelian-KKPP",
     "UD_Kazakh-KTB",
+    "UD_Khunsari-AHA",
+    "UD_Komi_Permyak-UH",
+    "UD_Komi_Zyrian-IKDP",
+    "UD_Komi_Zyrian-Lattice",
     "UD_Korean-GSD",
     "UD_Korean-Kaist",
+    "UD_Korean-PUD",
+    "UD_Kurmanji-MG",
     "UD_Latin-ITTB",
+    "UD_Latin-LLCT",
     "UD_Latin-Perseus",
     "UD_Latin-PROIEL",
     "UD_Latvian-LVTB",
     "UD_Lithuanian-ALKSNIS",
     "UD_Lithuanian-HSE",
+    "UD_Livvi-KKPP",
     "UD_Maltese-MUDT",
+    "UD_Manx-Cadhan",
     "UD_Marathi-UFAL",
+    "UD_Mbya_Guarani-Dooley",
+    "UD_Mbya_Guarani-Thomas",
+    "UD_Moksha-JR",
+    "UD_Munduruku-TuDeT",
+    "UD_Naija-NSC",
+    "UD_Nayini-AHA",
+    "UD_North_Sami-Giella",
+    "UD_Norwegian-Bokmaal",
+    "UD_Norwegian-Nynorsk",
+    "UD_Norwegian-NynorskLIA",
+    "UD_Old_Church_Slavonic-PROIEL",
+    "UD_Old_French-SRCMF",
+    "UD_Old_Russian-RNC",
+    "UD_Old_Russian-TOROT",
+    "UD_Old_Turkish-Tonqq",
+    "UD_Persian-PerDT",
     "UD_Persian-Seraji",
     "UD_Polish-LFG",
     "UD_Polish-PDB",
+    "UD_Polish-PUD",
     "UD_Portuguese-Bosque",
     "UD_Portuguese-GSD",
+    "UD_Portuguese-PUD",
     "UD_Romanian-Nonstandard",
     "UD_Romanian-RRT",
     "UD_Romanian-SiMoNERo",
     "UD_Russian-GSD",
+    "UD_Russian-PUD",
     "UD_Russian-SynTagRus",
     "UD_Russian-Taiga",
+    "UD_Sanskrit-UFAL",
+    "UD_Sanskrit-Vedic",
+    "UD_Scottish_Gaelic-ARCOSG",
     "UD_Serbian-SET",
+    "UD_Skolt_Sami-Giellagas",
     "UD_Slovak-SNK",
     "UD_Slovenian-SSJ",
     "UD_Slovenian-SST",
+    "UD_Soi-AHA",
+    "UD_South_Levantine_Arabic-MADAR",
     "UD_Spanish-AnCora",
     "UD_Spanish-GSD",
+    "UD_Spanish-PUD",
     "UD_Swedish-LinES",
+    "UD_Swedish-PUD",
     "UD_Swedish_Sign_Language-SSLC",
     "UD_Swedish-Talbanken",
+    "UD_Swiss_German-UZH",
+    "UD_Tagalog-TRG",
+    "UD_Tagalog-Ugnayan",
+    "UD_Tamil-MWTT",
     "UD_Tamil-TTB",
     "UD_Telugu-MTG",
+    "UD_Thai-PUD",
+    "UD_Tupinamba-TuDeT",
+    "UD_Turkish-BOUN",
     "UD_Turkish-GB",
+    "UD_Turkish_German-SAGT",
     "UD_Turkish-IMST",
+    "UD_Turkish-PUD",
     "UD_Ukrainian-IU",
+    "UD_Upper_Sorbian-UFAL",
     "UD_Urdu-UDTB",
     "UD_Uyghur-UDT",
     "UD_Vietnamese-VTB",
+    "UD_Warlpiri-UFAL",
+    "UD_Welsh-CCG",
+    "UD_Wolof-WTB",
+    "UD_Yoruba-YTB",
 ]
 
 FLAGS = flags.FLAGS
@@ -136,7 +229,7 @@ def run(_):
         embeddings_file = None
         if embeddings_dir:
             embeddings_dir = pathlib.Path(embeddings_dir) / language
-            embeddings_file = [f for f in embeddings_dir.iterdir() if "vectors" in f.name and ".vec.gz" in f.name]
+            embeddings_file = [f for f in embeddings_dir.iterdir() if "vectors" in f.name and ".vec" in f.name]
             assert len(embeddings_file) == 1, f"Couldn't find embeddings file."
             embeddings_file = embeddings_file[0]
 
@@ -153,14 +246,19 @@ def run(_):
         else f"--pretrained_transformer_name {utils.LANG2TRANSFORMER[language]}"}
         --serialization_dir {serialization_dir}
         --config_path {pathlib.Path.cwd() / 'config.template.jsonnet'}
-        --word_batch_size 2500
         --notensorboard
         """
 
-        # no XPOS datasets
-        if treebank in ["UD_Hungarian-Szeged", "UD_Armenian-ArmTDP"]:
+        # Datasets without XPOS
+        if treebank in {"UD_Armenian-ArmTDP", "UD_Basque-BDT", "UD_Hungarian-Szeged"}:
             command = command + " --targets deprel,head,upostag,lemma,feats"
 
+        # Reduce word_batch_size
+        word_batch_size = 2500
+        if treebank in {"UD_German-HDT"}:
+            word_batch_size = 1000
+        command = command + f" --word_batch_size {word_batch_size}"
+
         utils.execute_command(command)
 
 
diff --git a/setup.py b/setup.py
index fdaa2be..8d4db9d 100644
--- a/setup.py
+++ b/setup.py
@@ -6,15 +6,16 @@ REQUIREMENTS = [
     'allennlp==1.2.1',
     'conllu==2.3.2',
     'dataclasses;python_version<"3.7"',
-    'joblib==0.14.1',
     'jsonnet==0.15.0',
-    'requests==2.23.0',
+    'numpy==1.19.4',
     'overrides==3.1.0',
-    'tensorboard==2.1.0',
+    'requests==2.23.0',
+    'spacy==2.3.2',
+    'scikit-learn<=0.23.2',
     'torch==1.6.0',
     'tqdm==4.43.0',
     'transformers>=3.4.0,<3.5',
-    'urllib3>=1.25.11',
+    'urllib3==1.25.11',
 ]
 
 setup(
-- 
GitLab


From c0835180c3092ba52b19ab8b98a488468915f0a0 Mon Sep 17 00:00:00 2001
From: Mateusz Klimaszewski <mk.klimaszewski@gmail.com>
Date: Wed, 6 Jan 2021 11:12:38 +0100
Subject: [PATCH 2/5] Fix training loops and metrics.

---
 combo/training/checkpointer.py | 1 +
 combo/training/trainer.py      | 4 ++--
 combo/utils/metrics.py         | 6 +++---
 3 files changed, 6 insertions(+), 5 deletions(-)

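[Reviewer note, outside the diff] The `SemanticMetrics` fix multiplies the
combined correctness indicator by the padding mask before computing the
exact-match (EM) score. A toy illustration of why (simplified: the real
`correct_indices` is a product of several per-task indicators):

```python
import torch

correct_indices = torch.tensor([1, 1, 0, 1, 1])        # per-token correctness
mask = torch.tensor([True, True, True, False, False])  # last two slots are padding

total = mask.sum()                                                   # 3 real tokens
em_inflated = (correct_indices.float().sum() / total).item()         # 4/3: padding counted
em_masked = ((correct_indices * mask).float().sum() / total).item()  # 2/3: correct
```
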
diff --git a/combo/training/checkpointer.py b/combo/training/checkpointer.py
index c148ed6..bae4403 100644
--- a/combo/training/checkpointer.py
+++ b/combo/training/checkpointer.py
@@ -16,6 +16,7 @@ class FinishingTrainingCheckpointer(training.Checkpointer):
             epoch: Union[int, str],
             trainer: "allen_trainer.Trainer",
             is_best_so_far: bool = False,
+            save_model_only: bool = False,
     ) -> None:
         if trainer._learning_rate_scheduler.decreases <= 1 or epoch == trainer._num_epochs - 1:
             super().save_checkpoint(epoch, trainer, is_best_so_far)
diff --git a/combo/training/trainer.py b/combo/training/trainer.py
index f74873e..3bee8fc 100644
--- a/combo/training/trainer.py
+++ b/combo/training/trainer.py
@@ -84,7 +84,7 @@ class GradientDescentTrainer(training.GradientDescentTrainer):
         logger.info("Beginning training.")
 
         val_metrics: Dict[str, float] = {}
-        this_epoch_val_metric: float
+        this_epoch_val_metric: float = None
         metrics: Dict[str, Any] = {}
         epochs_trained = 0
         training_start_time = time.time()
@@ -141,7 +141,7 @@ class GradientDescentTrainer(training.GradientDescentTrainer):
 
                         # Check validation metric for early stopping
                         this_epoch_val_metric = val_metrics[self._validation_metric]
-                        self._metric_tracker.add_metric(this_epoch_val_metric)
+                        # self._metric_tracker.add_metric(this_epoch_val_metric)
 
                 train_metrics["patience"] = self._metric_tracker._patience
                 if self._metric_tracker.should_stop_early():
diff --git a/combo/utils/metrics.py b/combo/utils/metrics.py
index 682e885..1a17540 100644
--- a/combo/utils/metrics.py
+++ b/combo/utils/metrics.py
@@ -241,10 +241,10 @@ class SemanticMetrics(metrics.Metric):
                            self.feats_score.correct_indices *
                            self.lemma_score.correct_indices *
                            self.attachment_scores.correct_indices *
-                           enhanced_indices)
+                           enhanced_indices) * mask.flatten()
 
-        total, correct_indices = self.detach_tensors(total, correct_indices)
-        self.em_score = (correct_indices.float().sum() / total).item()
+        total, correct_indices = self.detach_tensors(total, correct_indices.float().sum())
+        self.em_score = (correct_indices / total).item()
 
     def get_metric(self, reset: bool) -> Dict[str, float]:
         metrics_dict = {
-- 
GitLab


From 2aa41eba98c2af0857d2122be7c47549711f970e Mon Sep 17 00:00:00 2001
From: Mateusz Klimaszewski <mk.klimaszewski@gmail.com>
Date: Wed, 6 Jan 2021 23:35:38 +0100
Subject: [PATCH 3/5] Remove lemma padding from lemma loss and metric.

---
 combo/models/lemma.py       |  4 +--
 combo/utils/metrics.py      | 62 ++++++++++++++++++++++++++++++++++++-
 tests/utils/test_metrics.py | 55 ++++++++++++++++++++++++++++++++
 3 files changed, 118 insertions(+), 3 deletions(-)

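[Reviewer note, outside the diff] The new `LemmaAccuracy` counts a lemma as
correct only when predictions match gold on every non-padding character;
character id 0 is treated as padding via `gold_labels.gt(0)`. A toy check of
that rule:

```python
import torch

predictions = torch.tensor([[1, 1, 1000],   # differs only in the padded slot
                            [1, 2, 0]])     # differs in a real character
gold_labels = torch.tensor([[1, 1, 0],
                            [1, 1, 0]])

padding_mask = gold_labels.gt(0)
correct = predictions.eq(gold_labels) * padding_mask
per_word = correct.int().sum(-1) == padding_mask.int().sum(-1)
print(per_word)  # tensor([ True, False])
```
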
diff --git a/combo/models/lemma.py b/combo/models/lemma.py
index 4ff9c92..828ba3e 100644
--- a/combo/models/lemma.py
+++ b/combo/models/lemma.py
@@ -62,11 +62,11 @@ class LemmatizerModel(base.Predictor):
         BATCH_SIZE, SENTENCE_LENGTH, MAX_WORD_LENGTH, CHAR_CLASSES = pred.size()
         pred = pred.reshape(-1, CHAR_CLASSES)
 
-        valid_positions = mask.sum()
-        mask = mask.reshape(-1)
         true = true.reshape(-1)
+        mask = true.gt(0)
         loss = utils.masked_cross_entropy(pred, true, mask)
         loss = loss.reshape(BATCH_SIZE, -1) * sample_weights.unsqueeze(-1)
+        valid_positions = mask.sum()
         return loss.sum() / valid_positions
 
     @classmethod
diff --git a/combo/utils/metrics.py b/combo/utils/metrics.py
index 1a17540..c2a1202 100644
--- a/combo/utils/metrics.py
+++ b/combo/utils/metrics.py
@@ -6,6 +6,66 @@ from allennlp.training import metrics
 from overrides import overrides
 
 
+class LemmaAccuracy(metrics.Metric):
+
+    def __init__(self):
+        self._correct_count = 0.0
+        self._total_count = 0.0
+        self.correct_indices = torch.ones([])
+
+    @overrides
+    def __call__(self,
+                 predictions: torch.Tensor,
+                 gold_labels: torch.Tensor,
+                 mask: Optional[torch.BoolTensor] = None):
+        if gold_labels is None:
+            return
+        predictions, gold_labels, mask = self.detach_tensors(predictions,
+                                                             gold_labels,
+                                                             mask)
+
+        # Some sanity checks.
+        if gold_labels.size() != predictions.size():
+            raise ValueError(
+                f"gold_labels must have shape == predictions.size() but "
+                f"found tensor of shape: {gold_labels.size()}"
+            )
+        if mask is not None and mask.size() not in [predictions.size()[:-1], predictions.size()]:
+            raise ValueError(
+                f"mask must have shape in one of [predictions.size()[:-1], predictions.size()] but "
+                f"found tensor of shape: {mask.size()}"
+            )
+        if mask is None:
+            mask = predictions.new_ones(predictions.size()[:-1]).bool()
+        if mask.dim() < predictions.dim():
+            mask = mask.unsqueeze(-1)
+
+        padding_mask = gold_labels.gt(0)
+        correct = predictions.eq(gold_labels) * padding_mask
+        correct = (correct.int().sum(-1) == padding_mask.int().sum(-1)) * mask.squeeze(-1)
+        correct = correct.float()
+
+        self.correct_indices = correct.flatten().bool()
+        self._correct_count += correct.sum()
+        self._total_count += mask.sum()
+
+    @overrides
+    def get_metric(self, reset: bool) -> float:
+        if self._total_count > 0:
+            accuracy = float(self._correct_count) / float(self._total_count)
+        else:
+            accuracy = 0.0
+        if reset:
+            self.reset()
+        return accuracy
+
+    @overrides
+    def reset(self) -> None:
+        self._correct_count = 0.0
+        self._total_count = 0.0
+        self.correct_indices = torch.ones([])
+
+
 class SequenceBoolAccuracy(metrics.Metric):
     """BoolAccuracy implementation to handle sequences."""
 
@@ -202,7 +262,7 @@ class SemanticMetrics(metrics.Metric):
         self.xpos_score = SequenceBoolAccuracy()
         self.semrel_score = SequenceBoolAccuracy()
         self.feats_score = SequenceBoolAccuracy(prod_last_dim=True)
-        self.lemma_score = SequenceBoolAccuracy(prod_last_dim=True)
+        self.lemma_score = LemmaAccuracy()
         self.attachment_scores = AttachmentScores()
         # Ignore PADDING and OOV
         self.enhanced_attachment_scores = AttachmentScores(ignore_classes=[0, 1])
diff --git a/tests/utils/test_metrics.py b/tests/utils/test_metrics.py
index 242eaa3..bf7f619 100644
--- a/tests/utils/test_metrics.py
+++ b/tests/utils/test_metrics.py
@@ -154,3 +154,58 @@ class SequenceBoolAccuracyTest(unittest.TestCase):
         # then
         self.assertEqual(metric._correct_count.item(), 7)
         self.assertEqual(metric._total_count.item(), 10)
+
+
+class LemmaAccuracyTest(unittest.TestCase):
+
+    def setUp(self) -> None:
+        self.mask: torch.BoolTensor = torch.tensor([
+            [True, True, True, True],
+            [True, True, True, False],
+        ])
+
+    def test_prediction_has_error_in_not_padded_place(self):
+        # given
+        metric = metrics.LemmaAccuracy()
+        predictions = torch.tensor([
+            [[1, 1, 1], [1, 1, 1], [2, 2, 0], [1, 1, 4], ],
+            [[1, 1, 0], [1, 1000, 0], [1, 1, 0], [1, 1, 0], ],
+        ])
+        gold_labels = torch.tensor([
+            [[1, 1, 1], [1, 1, 1], [2, 2, 0], [1, 1, 4], ],
+            [[1, 1, 0], [1, 1, 0], [1, 1, 0], [1, 1, 0], ],
+        ])
+        expected_correct_count = 6
+        expected_total_count = 7
+        expected_correct_indices = torch.tensor([1, 1, 1, 1, 1, 0, 1, 0])
+
+        # when
+        metric(predictions, gold_labels, self.mask)
+
+        # then
+        self.assertEqual(metric._correct_count.item(), expected_correct_count)
+        self.assertEqual(metric._total_count.item(), expected_total_count)
+        self.assertTrue(torch.all(expected_correct_indices.eq(metric.correct_indices)))
+
+    def test_wrong_prediction_in_padding_should_be_ignored(self):
+        # given
+        metric = metrics.LemmaAccuracy()
+        predictions = torch.tensor([
+            [[1, 1, 1], [1, 1, 1], [2, 2, 0], [1, 1, 4], ],
+            [[1, 1, 1000], [1, 1, 0], [1, 1, 0], [1, 1, 0], ],
+        ])
+        gold_labels = torch.tensor([
+            [[1, 1, 1], [1, 1, 1], [2, 2, 0], [1, 1, 4], ],
+            [[1, 1, 0], [1, 1, 0], [1, 1, 0], [1, 1, 0], ],
+        ])
+        expected_correct_count = 7
+        expected_total_count = 7
+        expected_correct_indices = torch.tensor([1, 1, 1, 1, 1, 1, 1, 0])
+
+        # when
+        metric(predictions, gold_labels, self.mask)
+
+        # then
+        self.assertEqual(expected_correct_count, metric._correct_count.item())
+        self.assertEqual(expected_total_count, metric._total_count.item())
+        self.assertTrue(torch.all(expected_correct_indices.eq(metric.correct_indices)))
-- 
GitLab


From a7dfecab31718c864f827f71532ac83b28856add Mon Sep 17 00:00:00 2001
From: Mateusz Klimaszewski <mk.klimaszewski@gmail.com>
Date: Thu, 7 Jan 2021 09:30:43 +0100
Subject: [PATCH 4/5] Add seeds to configs.

---
 config.graph.template.jsonnet | 3 +++
 config.template.jsonnet       | 3 +++
 2 files changed, 6 insertions(+)

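[Reviewer note, outside the diff] In AllenNLP 1.x these three keys are
consumed by `allennlp.common.util.prepare_environment`, so pinning them in
both config templates should make training runs repeatable. A minimal sketch
of the mechanism:

```python
from allennlp.common.params import Params
from allennlp.common.util import prepare_environment

# Same keys and value as added to config.template.jsonnet and
# config.graph.template.jsonnet.
prepare_environment(Params({"random_seed": 8787, "pytorch_seed": 8787, "numpy_seed": 8787}))
```
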
diff --git a/config.graph.template.jsonnet b/config.graph.template.jsonnet
index bc8c465..c0c4696 100644
--- a/config.graph.template.jsonnet
+++ b/config.graph.template.jsonnet
@@ -419,4 +419,7 @@ assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't
         },
         validation_metric: "+EM",
     }),
+    random_seed: 8787,
+    pytorch_seed: 8787,
+    numpy_seed: 8787,
 }
diff --git a/config.template.jsonnet b/config.template.jsonnet
index f41ba62..c602d9c 100644
--- a/config.template.jsonnet
+++ b/config.template.jsonnet
@@ -386,4 +386,7 @@ assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't
         },
         validation_metric: "+EM",
     }),
+    random_seed: 8787,
+    pytorch_seed: 8787,
+    numpy_seed: 8787,
 }
-- 
GitLab


From bcce7338a5325c44c2b7a6a3e77ad97fc6d11aa5 Mon Sep 17 00:00:00 2001
From: Mateusz Klimaszewski <mk.klimaszewski@gmail.com>
Date: Thu, 7 Jan 2021 13:27:29 +0100
Subject: [PATCH 5/5] Fix language extraction in training script and add script
 for downloading fasttext embeddings.

---
 scripts/download_fasttext.py | 306 +++++++++++++++++++++++++++++++++++
 scripts/train.py             |   2 +-
 2 files changed, 307 insertions(+), 1 deletion(-)
 create mode 100644 scripts/download_fasttext.py

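[Reviewer note, outside the diff] The one-line train.py fix matters for
treebanks whose language name itself contains an underscore, and the new
download script extracts language codes from the fastText URLs. A toy
comparison of both:

```python
treebank = "UD_Ancient_Greek-Perseus"
old = treebank.split("_")[1].split("-")[0]  # "Ancient": drops "_Greek"
new = treebank[3:].split("-")[0]            # "Ancient_Greek": matches CODE_2_LANG values

link = "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.af.300.vec.gz"
code = link.split(".")[-4]                  # "af"
```
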
diff --git a/scripts/download_fasttext.py b/scripts/download_fasttext.py
new file mode 100644
index 0000000..1919fa2
--- /dev/null
+++ b/scripts/download_fasttext.py
@@ -0,0 +1,306 @@
+import pathlib
+
+from absl import app
+from absl import flags
+
+from scripts import utils
+
+# egrep -o 'https?://[^ ]+vec.gz' links.txt
+# https://github.com/facebookresearch/fastText/blob/master/docs/crawl-vectors.md
+LINKS = [
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.af.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.sq.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.als.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.am.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.ar.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.an.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.hy.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.as.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.ast.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.az.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.ba.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.eu.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.bar.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.be.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.bn.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.bh.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.bpy.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.bs.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.br.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.bg.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.my.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.ca.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.ceb.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.bcl.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.ce.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.zh.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.cv.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.co.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.hr.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.cs.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.da.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.dv.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.nl.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.pa.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.arz.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.eml.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.en.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.myv.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.eo.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.et.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.hif.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.fi.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.fr.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.gl.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.ka.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.de.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.gom.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.el.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.gu.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.ht.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.he.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.mrj.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.hi.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.hu.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.is.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.io.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.ilo.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.id.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.ia.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.ga.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.it.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.ja.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.jv.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.kn.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.pam.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.kk.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.km.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.ky.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.ko.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.ku.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.ckb.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.la.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.lv.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.li.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.lt.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.lmo.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.nds.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.lb.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.mk.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.mai.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.mg.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.ms.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.ml.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.mt.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.gv.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.mr.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.mzn.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.mhr.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.min.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.xmf.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.mwl.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.mn.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.nah.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.nap.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.ne.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.new.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.frr.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.nso.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.no.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.nn.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.oc.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.or.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.os.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.pfl.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.ps.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.fa.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.pms.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.pl.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.pt.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.qu.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.ro.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.rm.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.ru.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.sah.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.sa.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.sc.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.sco.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.gd.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.sr.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.sh.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.scn.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.sd.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.si.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.sk.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.sl.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.so.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.azb.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.es.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.su.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.sw.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.sv.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.tl.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.tg.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.ta.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.tt.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.te.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.th.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.bo.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.tr.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.tk.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.uk.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.hsb.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.ur.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.ug.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.uz.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.vec.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.vi.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.vo.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.wa.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.war.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.cy.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.vls.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.fy.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.pnb.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.yi.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.yo.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.diq.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.zea.300.vec.gz",
+]
+
+CODE_2_LANG = {
+    "af": "Afrikaans",
+    "aii": "Assyrian",
+    "ajp": "South_Levantine_Arabic",
+    "akk": "Akkadian",
+    "am": "Amharic",
+    "apu": "Apurina",
+    "aqz": "Akuntsu",
+    "ar": "Arabic",
+    "be": "Belarusian",
+    "bg": "Bulgarian",
+    "bho": "Bhojpuri",
+    "bm": "Bambara",
+    "br": "Breton",
+    "bxr": "Buryat",
+    "ca": "Catalan",
+    "ckt": "Chukchi",
+    "cop": "Coptic",
+    "cs": "Czech",
+    "cu": "Old_Church_Slavonic",
+    "cy": "Welsh",
+    "da": "Danish",
+    "de": "German",
+    "el": "Greek",
+    "en": "English",
+    "es": "Spanish",
+    "et": "Estonian",
+    "eu": "Basque",
+    "fa": "Persian",
+    "fi": "Finnish",
+    "fo": "Faroese",
+    "fr": "French",
+    "fro": "Old_French",
+    "ga": "Irish",
+    "gd": "Scottish_Gaelic",
+    "gl": "Galician",
+    "got": "Gothic",
+    "grc": "Ancient_Greek",
+    "gsw": "Swiss_German",
+    "gun": "Mbya_Guarani",
+    "gv": "Manx",
+    "he": "Hebrew",
+    "hi": "Hindi",
+    "hr": "Croatian",
+    "hsb": "Upper_Sorbian",
+    "hu": "Hungarian",
+    "hy": "Armenian",
+    "id": "Indonesian",
+    "is": "Icelandic",
+    "it": "Italian",
+    "ja": "Japanese",
+    "kfm": "Khunsari",
+    "kk": "Kazakh",
+    "kmr": "Kurmanji",
+    "ko": "Korean",
+    "koi": "Komi_Permyak",
+    "kpv": "Komi_Zyrian",
+    "krl": "Karelian",
+    "la": "Latin",
+    "lt": "Lithuanian",
+    "lv": "Latvian",
+    "lzh": "Classical_Chinese",
+    "mdf": "Moksha",
+    "mr": "Marathi",
+    "mt": "Maltese",
+    "myu": "Munduruku",
+    "myv": "Erzya",
+    "nl": "Dutch",
+    "no": "Norwegian",
+    "nyq": "Nayini",
+    "olo": "Livvi",
+    "orv": "Old_Russian",
+    "otk": "Old_Turkish",
+    "pcm": "Naija",
+    "pl": "Polish",
+    "pt": "Portuguese",
+    "qhe": "Hindi_English",
+    "qtd": "Turkish_German",
+    "ro": "Romanian",
+    "ru": "Russian",
+    "sa": "Sanskrit",
+    "sk": "Slovak",
+    "sl": "Slovenian",
+    "sme": "North_Sami",
+    "sms": "Skolt_Sami",
+    "soj": "Soi",
+    "sq": "Albanian",
+    "sr": "Serbian",
+    "sv": "Swedish",
+    "swl": "Swedish_Sign_Language",
+    "ta": "Tamil",
+    "te": "Telugu",
+    "th": "Thai",
+    "tl": "Tagalog",
+    "tpn": "Tupinamba",
+    "tr": "Turkish",
+    "ug": "Uyghur",
+    "uk": "Ukrainian",
+    "ur": "Urdu",
+    "vi": "Vietnamese",
+    "wbp": "Warlpiri",
+    "wo": "Wolof",
+    "yo": "Yoruba",
+    "yue": "Cantonese",
+    "zh": "Chinese",
+}
+
+FLAGS = flags.FLAGS
+flags.DEFINE_string(name="output_dir", default="",
+                    help="Path to store embeddings.")
+
+
+def run(_):
+    output_dir = pathlib.Path(FLAGS.output_dir)
+    for link in LINKS:
+        lang_code = link.split(".")[-4]
+
+        if lang_code not in CODE_2_LANG:
+            print(f"Unknown code {lang_code}.")
+            continue
+
+        lang_dir = output_dir / CODE_2_LANG[lang_code]
+        lang_dir.mkdir(exist_ok=True, parents=True)
+        if (lang_dir / 'vectors.vec.gz').exists():
+            print(f"Vectors for {CODE_2_LANG[lang_code]} already exist, skipping.")
+            continue
+
+        utils.execute_command(f"wget -O {lang_dir / 'vectors.vec.gz'} {link}")
+
+
+def main():
+    app.run(run)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/train.py b/scripts/train.py
index 4bd342a..dc75344 100644
--- a/scripts/train.py
+++ b/scripts/train.py
@@ -212,7 +212,7 @@ def run(_):
     for treebank in FLAGS.treebanks:
         assert treebank in TREEBANKS, f"Unknown treebank {treebank}."
         treebank_dir = treebanks_dir / treebank
-        treebank_parts = treebank.split("_")[1].split("-")
+        treebank_parts = treebank[3:].split("-")
         language = treebank_parts[0]
 
         files = list(treebank_dir.iterdir())
-- 
GitLab