From b8c83784dfd541dbae8851a45c66a204d5cd82a7 Mon Sep 17 00:00:00 2001 From: Mateusz Klimaszewski <mk.klimaszewski@gmail.com> Date: Tue, 5 Jan 2021 16:37:59 +0100 Subject: [PATCH 1/5] Extend installation and pre-trained models documentation, fix installation, update trainer to be in sync with AllenNLP 1.2.1. --- combo/data/api.py | 3 ++ combo/training/trainer.py | 38 ++++++++------ docs/installation.md | 18 ++++++- docs/models.md | 2 + scripts/train.py | 108 ++++++++++++++++++++++++++++++++++++-- setup.py | 9 ++-- 6 files changed, 151 insertions(+), 27 deletions(-) diff --git a/combo/data/api.py b/combo/data/api.py index 7d44917..4ab7f1a 100644 --- a/combo/data/api.py +++ b/combo/data/api.py @@ -36,6 +36,9 @@ class Sentence: "metadata": self.metadata, }) + def __len__(self): + return len(self.tokens) + class _TokenList(conllu.TokenList): diff --git a/combo/training/trainer.py b/combo/training/trainer.py index 26bd75f..f74873e 100644 --- a/combo/training/trainer.py +++ b/combo/training/trainer.py @@ -68,10 +68,7 @@ class GradientDescentTrainer(training.GradientDescentTrainer): self.validate_every_n = 5 @overrides - def train(self) -> Dict[str, Any]: - """ - Trains the supplied model with the supplied parameters. - """ + def _try_train(self) -> Dict[str, Any]: try: epoch_counter = self._restore_checkpoint() except RuntimeError: @@ -87,7 +84,7 @@ class GradientDescentTrainer(training.GradientDescentTrainer): logger.info("Beginning training.") val_metrics: Dict[str, float] = {} - this_epoch_val_metric: float = None + this_epoch_val_metric: float metrics: Dict[str, Any] = {} epochs_trained = 0 training_start_time = time.time() @@ -97,19 +94,24 @@ class GradientDescentTrainer(training.GradientDescentTrainer): metrics["best_validation_" + key] = value for callback in self._epoch_callbacks: - callback(self, metrics={}, epoch=-1, is_master=True) + callback(self, metrics={}, epoch=-1, is_master=self._master) for epoch in range(epoch_counter, self._num_epochs): epoch_start_time = time.time() train_metrics = self._train_epoch(epoch) + if self._master and self._checkpointer is not None: + self._checkpointer.save_checkpoint(epoch, self, save_model_only=True) + + # Wait for the master to finish saving the model checkpoint + if self._distributed: + dist.barrier() + # get peak of memory usage - if "cpu_memory_MB" in train_metrics: - metrics["peak_cpu_memory_MB"] = max( - metrics.get("peak_cpu_memory_MB", 0), train_metrics["cpu_memory_MB"] - ) for key, value in train_metrics.items(): - if key.startswith("gpu_"): + if key.startswith("gpu_") and key.endswith("_memory_MB"): + metrics["peak_" + key] = max(metrics.get("peak_" + key, 0), value) + elif key.startswith("worker_") and key.endswith("_memory_MB"): metrics["peak_" + key] = max(metrics.get("peak_" + key, 0), value) if self._validation_data_loader is not None: @@ -129,9 +131,9 @@ class GradientDescentTrainer(training.GradientDescentTrainer): self.model, val_loss, val_reg_loss, - num_batches=num_batches, batch_loss=None, batch_reg_loss=None, + num_batches=num_batches, reset=True, world_size=self._world_size, cuda_device=self.cuda_device, @@ -139,7 +141,7 @@ class GradientDescentTrainer(training.GradientDescentTrainer): # Check validation metric for early stopping this_epoch_val_metric = val_metrics[self._validation_metric] - # self._metric_tracker.add_metric(this_epoch_val_metric) + self._metric_tracker.add_metric(this_epoch_val_metric) train_metrics["patience"] = self._metric_tracker._patience if self._metric_tracker.should_stop_early(): @@ -184,7 +186,7 
@@ class GradientDescentTrainer(training.GradientDescentTrainer): if self._momentum_scheduler: self._momentum_scheduler.step(this_epoch_val_metric) - if self._master: + if self._master and self._checkpointer is not None: self._checkpointer.save_checkpoint( epoch, self, is_best_so_far=self._metric_tracker.is_best_so_far() ) @@ -209,11 +211,13 @@ class GradientDescentTrainer(training.GradientDescentTrainer): epochs_trained += 1 - # make sure pending events are flushed to disk and files are closed properly - self._tensorboard.close() + for callback in self._end_callbacks: + callback(self, metrics=metrics, epoch=epoch, is_master=self._master) # Load the best model state before returning - best_model_state = self._checkpointer.best_model_state() + best_model_state = ( + None if self._checkpointer is None else self._checkpointer.best_model_state() + ) if best_model_state: self.model.load_state_dict(best_model_state) diff --git a/docs/installation.md b/docs/installation.md index 6371094..bf741f9 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -7,7 +7,23 @@ python setup.py develop combo --helpfull ``` +### Virtualenv example +```bash +python -m venv venv +source venv/bin/activate +pip install --upgrade pip +python setup.py develop +``` + ## Problems & solutions * **jsonnet** installation error -use `conda install -c conda-forge jsonnet=0.15.0` +Run `conda install -c conda-forge jsonnet=0.15.0` and re-run installation. + +* No package 'sentencepiece' found + +Run `pip install sentencepiece` and re-run installation. + +* Missing Cython error + +Run `pip install spacy==2.3.2` manually and re-run installation. diff --git a/docs/models.md b/docs/models.md index 25a7f70..94eed03 100644 --- a/docs/models.md +++ b/docs/models.md @@ -4,6 +4,8 @@ COMBO provides pre-trained models for: - morphosyntactic prediction (i.e. part-of-speech tagging, morphosyntactic analysis, lemmatisation and dependency parsing) trained on the treebanks from [Universal Dependencies repository](https://universaldependencies.org), - enhanced dependency parsing trained on IWPT 2020 shared task [data](https://universaldependencies.org/iwpt20/data.html). +A list of pre-trained models with their **evaluation results** is available in the [spreadsheet](https://docs.google.com/spreadsheets/d/1WFYc2aLRa1jw7le030HOacv9fc4zmtqiZtRQY6gl5mc/edit?usp=sharing). +Please note that the name in brackets matches the name used in [Automatic Download](models.md#automatic-download). ## Manual download The pre-trained models can be downloaded from [here](http://mozart.ipipan.waw.pl/~mklimaszewski/models/).
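As a companion to the documentation changes above, here is a minimal usage sketch showing where the new `Sentence.__len__` from `combo/data/api.py` comes into play. The `COMBO.from_pretrained` loader and the `"polish-herbert-base"` model name are assumptions based on the package's `combo.predict` module and the model spreadsheet, not something this patch introduces:

```python
# Hedged sketch: loader and model name are assumed, verify against your checkout.
from combo.predict import COMBO

nlp = COMBO.from_pretrained("polish-herbert-base")  # automatic download by model name
sentence = nlp("To jest przykładowe zdanie.")       # parse one raw sentence
print(len(sentence))                                # token count via the new Sentence.__len__
```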
diff --git a/scripts/train.py b/scripts/train.py index 9390888..4bd342a 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -6,38 +6,71 @@ from absl import flags from scripts import utils +# # ls -1 | xargs -i echo "\"{}\"," +# UD 2.7 TREEBANKS = [ "UD_Afrikaans-AfriBooms", + "UD_Akkadian-PISANDUB", + "UD_Akkadian-RIAO", + "UD_Akuntsu-TuDeT", + "UD_Albanian-TSA", + "UD_Amharic-ATT", + "UD_Ancient_Greek-Perseus", + "UD_Ancient_Greek-PROIEL", + "UD_Apurina-UFPA", "UD_Arabic-NYUAD", "UD_Arabic-PADT", + "UD_Arabic-PUD", "UD_Armenian-ArmTDP", + "UD_Assyrian-AS", + "UD_Bambara-CRB", "UD_Basque-BDT", "UD_Belarusian-HSE", + "UD_Bhojpuri-BHTB", "UD_Breton-KEB", "UD_Bulgarian-BTB", + "UD_Buryat-BDT", + "UD_Cantonese-HK", "UD_Catalan-AnCora", + "UD_Chinese-CFL", + "UD_Chinese-GSD", + "UD_Chinese-GSDSimp", + "UD_Chinese-HK", + "UD_Chinese-PUD", + "UD_Chukchi-HSE", + "UD_Classical_Chinese-Kyoto", + "UD_Coptic-Scriptorium", "UD_Croatian-SET", "UD_Czech-CAC", "UD_Czech-CLTT", "UD_Czech-FicTree", "UD_Czech-PDT", + "UD_Czech-PUD", "UD_Danish-DDT", "UD_Dutch-Alpino", "UD_Dutch-LassySmall", "UD_English-ESL", "UD_English-EWT", "UD_English-GUM", + "UD_English-GUMReddit", "UD_English-LinES", "UD_English-ParTUT", "UD_English-Pronouns", + "UD_English-PUD", + "UD_Erzya-JR", "UD_Estonian-EDT", "UD_Estonian-EWT", + "UD_Faroese-FarPaHC", + "UD_Faroese-OFT", "UD_Finnish-FTB", + "UD_Finnish-OOD", + "UD_Finnish-PUD", "UD_Finnish-TDT", "UD_French-FQB", "UD_French-FTB", "UD_French-GSD", "UD_French-ParTUT", + "UD_French-PUD", "UD_French-Sequoia", "UD_French-Spoken", "UD_Galician-CTG", @@ -45,60 +78,120 @@ TREEBANKS = [ "UD_German-GSD", "UD_German-HDT", "UD_German-LIT", + "UD_German-PUD", + "UD_Gothic-PROIEL", "UD_Greek-GDT", "UD_Hebrew-HTB", "UD_Hindi_English-HIENCS", "UD_Hindi-HDTB", + "UD_Hindi-PUD", "UD_Hungarian-Szeged", + "UD_Icelandic-IcePaHC", + "UD_Icelandic-PUD", + "UD_Indonesian-CSUI", "UD_Indonesian-GSD", + "UD_Indonesian-PUD", "UD_Irish-IDT", "UD_Italian-ISDT", "UD_Italian-ParTUT", "UD_Italian-PoSTWITA", + "UD_Italian-PUD", "UD_Italian-TWITTIRO", "UD_Italian-VIT", - "UD_Japanese-BCCWJ", + # "UD_Japanese-BCCWJ", no data "UD_Japanese-GSD", "UD_Japanese-Modern", + "UD_Japanese-PUD", + "UD_Karelian-KKPP", "UD_Kazakh-KTB", + "UD_Khunsari-AHA", + "UD_Komi_Permyak-UH", + "UD_Komi_Zyrian-IKDP", + "UD_Komi_Zyrian-Lattice", "UD_Korean-GSD", "UD_Korean-Kaist", + "UD_Korean-PUD", + "UD_Kurmanji-MG", "UD_Latin-ITTB", + "UD_Latin-LLCT", "UD_Latin-Perseus", "UD_Latin-PROIEL", "UD_Latvian-LVTB", "UD_Lithuanian-ALKSNIS", "UD_Lithuanian-HSE", + "UD_Livvi-KKPP", "UD_Maltese-MUDT", + "UD_Manx-Cadhan", "UD_Marathi-UFAL", + "UD_Mbya_Guarani-Dooley", + "UD_Mbya_Guarani-Thomas", + "UD_Moksha-JR", + "UD_Munduruku-TuDeT", + "UD_Naija-NSC", + "UD_Nayini-AHA", + "UD_North_Sami-Giella", + "UD_Norwegian-Bokmaal", + "UD_Norwegian-Nynorsk", + "UD_Norwegian-NynorskLIA", + "UD_Old_Church_Slavonic-PROIEL", + "UD_Old_French-SRCMF", + "UD_Old_Russian-RNC", + "UD_Old_Russian-TOROT", + "UD_Old_Turkish-Tonqq", + "UD_Persian-PerDT", "UD_Persian-Seraji", "UD_Polish-LFG", "UD_Polish-PDB", + "UD_Polish-PUD", "UD_Portuguese-Bosque", "UD_Portuguese-GSD", + "UD_Portuguese-PUD", "UD_Romanian-Nonstandard", "UD_Romanian-RRT", "UD_Romanian-SiMoNERo", "UD_Russian-GSD", + "UD_Russian-PUD", "UD_Russian-SynTagRus", "UD_Russian-Taiga", + "UD_Sanskrit-UFAL", + "UD_Sanskrit-Vedic", + "UD_Scottish_Gaelic-ARCOSG", "UD_Serbian-SET", + "UD_Skolt_Sami-Giellagas", "UD_Slovak-SNK", "UD_Slovenian-SSJ", "UD_Slovenian-SST", + "UD_Soi-AHA", + 
"UD_South_Levantine_Arabic-MADAR", "UD_Spanish-AnCora", "UD_Spanish-GSD", + "UD_Spanish-PUD", "UD_Swedish-LinES", + "UD_Swedish-PUD", "UD_Swedish_Sign_Language-SSLC", "UD_Swedish-Talbanken", + "UD_Swiss_German-UZH", + "UD_Tagalog-TRG", + "UD_Tagalog-Ugnayan", + "UD_Tamil-MWTT", "UD_Tamil-TTB", "UD_Telugu-MTG", + "UD_Thai-PUD", + "UD_Tupinamba-TuDeT", + "UD_Turkish-BOUN", "UD_Turkish-GB", + "UD_Turkish_German-SAGT", "UD_Turkish-IMST", + "UD_Turkish-PUD", "UD_Ukrainian-IU", + "UD_Upper_Sorbian-UFAL", "UD_Urdu-UDTB", "UD_Uyghur-UDT", "UD_Vietnamese-VTB", + "UD_Warlpiri-UFAL", + "UD_Welsh-CCG", + "UD_Wolof-WTB", + "UD_Yoruba-YTB", ] FLAGS = flags.FLAGS @@ -136,7 +229,7 @@ def run(_): embeddings_file = None if embeddings_dir: embeddings_dir = pathlib.Path(embeddings_dir) / language - embeddings_file = [f for f in embeddings_dir.iterdir() if "vectors" in f.name and ".vec.gz" in f.name] + embeddings_file = [f for f in embeddings_dir.iterdir() if "vectors" in f.name and ".vec" in f.name] assert len(embeddings_file) == 1, f"Couldn't find embeddings file." embeddings_file = embeddings_file[0] @@ -153,14 +246,19 @@ def run(_): else f"--pretrained_transformer_name {utils.LANG2TRANSFORMER[language]}"} --serialization_dir {serialization_dir} --config_path {pathlib.Path.cwd() / 'config.template.jsonnet'} - --word_batch_size 2500 --notensorboard """ - # no XPOS datasets - if treebank in ["UD_Hungarian-Szeged", "UD_Armenian-ArmTDP"]: + # Datasets without XPOS + if treebank in {"UD_Armenian-ArmTDP", "UD_Basque-BDT", "UD_Hungarian-Szeged"}: command = command + " --targets deprel,head,upostag,lemma,feats" + # Reduce word_batch_size + word_batch_size = 2500 + if treebank in {"UD_German-HDT"}: + word_batch_size = 1000 + command = command + f" --word_batch_size {word_batch_size}" + utils.execute_command(command) diff --git a/setup.py b/setup.py index fdaa2be..8d4db9d 100644 --- a/setup.py +++ b/setup.py @@ -6,15 +6,16 @@ REQUIREMENTS = [ 'allennlp==1.2.1', 'conllu==2.3.2', 'dataclasses;python_version<"3.7"', - 'joblib==0.14.1', 'jsonnet==0.15.0', - 'requests==2.23.0', + 'numpy==1.19.4', 'overrides==3.1.0', - 'tensorboard==2.1.0', + 'requests==2.23.0', + 'spacy==2.3.2', + 'scikit-learn<=0.23.2', 'torch==1.6.0', 'tqdm==4.43.0', 'transformers>=3.4.0,<3.5', - 'urllib3>=1.25.11', + 'urllib3==1.25.11', ] setup( -- GitLab From c0835180c3092ba52b19ab8b98a488468915f0a0 Mon Sep 17 00:00:00 2001 From: Mateusz Klimaszewski <mk.klimaszewski@gmail.com> Date: Wed, 6 Jan 2021 11:12:38 +0100 Subject: [PATCH 2/5] Fix training loops and metrics. 
--- combo/training/checkpointer.py | 1 + combo/training/trainer.py | 4 ++-- combo/utils/metrics.py | 6 +++--- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/combo/training/checkpointer.py b/combo/training/checkpointer.py index c148ed6..bae4403 100644 --- a/combo/training/checkpointer.py +++ b/combo/training/checkpointer.py @@ -16,6 +16,7 @@ class FinishingTrainingCheckpointer(training.Checkpointer): epoch: Union[int, str], trainer: "allen_trainer.Trainer", is_best_so_far: bool = False, + save_model_only: bool = False, ) -> None: if trainer._learning_rate_scheduler.decreases <= 1 or epoch == trainer._num_epochs - 1: super().save_checkpoint(epoch, trainer, is_best_so_far) diff --git a/combo/training/trainer.py b/combo/training/trainer.py index f74873e..3bee8fc 100644 --- a/combo/training/trainer.py +++ b/combo/training/trainer.py @@ -84,7 +84,7 @@ class GradientDescentTrainer(training.GradientDescentTrainer): logger.info("Beginning training.") val_metrics: Dict[str, float] = {} - this_epoch_val_metric: float + this_epoch_val_metric: float = None metrics: Dict[str, Any] = {} epochs_trained = 0 training_start_time = time.time() @@ -141,7 +141,7 @@ class GradientDescentTrainer(training.GradientDescentTrainer): # Check validation metric for early stopping this_epoch_val_metric = val_metrics[self._validation_metric] - self._metric_tracker.add_metric(this_epoch_val_metric) + # self._metric_tracker.add_metric(this_epoch_val_metric) train_metrics["patience"] = self._metric_tracker._patience if self._metric_tracker.should_stop_early(): diff --git a/combo/utils/metrics.py b/combo/utils/metrics.py index 682e885..1a17540 100644 --- a/combo/utils/metrics.py +++ b/combo/utils/metrics.py @@ -241,10 +241,10 @@ class SemanticMetrics(metrics.Metric): self.feats_score.correct_indices * self.lemma_score.correct_indices * self.attachment_scores.correct_indices * - enhanced_indices) + enhanced_indices) * mask.flatten() - total, correct_indices = self.detach_tensors(total, correct_indices) - self.em_score = (correct_indices.float().sum() / total).item() + total, correct_indices = self.detach_tensors(total, correct_indices.float().sum()) + self.em_score = (correct_indices / total).item() def get_metric(self, reset: bool) -> Dict[str, float]: metrics_dict = { -- GitLab From 2aa41eba98c2af0857d2122be7c47549711f970e Mon Sep 17 00:00:00 2001 From: Mateusz Klimaszewski <mk.klimaszewski@gmail.com> Date: Wed, 6 Jan 2021 23:35:38 +0100 Subject: [PATCH 3/5] Remove lemma padding from lemma loss and metric. 
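The `lemma.py` hunk below derives the loss mask from the gold characters themselves (`true.gt(0)`), so padding characters (id 0) stop diluting the average loss. A self-contained sketch of the same idea, substituting plain `F.cross_entropy` for the project's `utils.masked_cross_entropy` helper:

```python
import torch
import torch.nn.functional as F

pred = torch.randn(6, 10)                # (char positions, char classes)
true = torch.tensor([3, 5, 2, 0, 0, 0])  # gold char ids; 0 marks padding

mask = true.gt(0)                        # keep only real characters
per_position = F.cross_entropy(pred, true, reduction="none")
loss = (per_position * mask).sum() / mask.sum()  # averaged over 3 real chars, not 6
```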
--- combo/models/lemma.py | 4 +-- combo/utils/metrics.py | 62 ++++++++++++++++++++++++++++++++++++- tests/utils/test_metrics.py | 55 ++++++++++++++++++++++++++++++++ 3 files changed, 118 insertions(+), 3 deletions(-) diff --git a/combo/models/lemma.py b/combo/models/lemma.py index 4ff9c92..828ba3e 100644 --- a/combo/models/lemma.py +++ b/combo/models/lemma.py @@ -62,11 +62,11 @@ class LemmatizerModel(base.Predictor): BATCH_SIZE, SENTENCE_LENGTH, MAX_WORD_LENGTH, CHAR_CLASSES = pred.size() pred = pred.reshape(-1, CHAR_CLASSES) - valid_positions = mask.sum() - mask = mask.reshape(-1) true = true.reshape(-1) + mask = true.gt(0) loss = utils.masked_cross_entropy(pred, true, mask) loss = loss.reshape(BATCH_SIZE, -1) * sample_weights.unsqueeze(-1) + valid_positions = mask.sum() return loss.sum() / valid_positions @classmethod diff --git a/combo/utils/metrics.py b/combo/utils/metrics.py index 1a17540..c2a1202 100644 --- a/combo/utils/metrics.py +++ b/combo/utils/metrics.py @@ -6,6 +6,66 @@ from allennlp.training import metrics from overrides import overrides +class LemmaAccuracy(metrics.Metric): + + def __init__(self): + self._correct_count = 0.0 + self._total_count = 0.0 + self.correct_indices = torch.ones([]) + + @overrides + def __call__(self, + predictions: torch.Tensor, + gold_labels: torch.Tensor, + mask: Optional[torch.BoolTensor] = None): + if gold_labels is None: + return + predictions, gold_labels, mask = self.detach_tensors(predictions, + gold_labels, + mask) + + # Some sanity checks. + if gold_labels.size() != predictions.size(): + raise ValueError( + f"gold_labels must have shape == predictions.size() but " + f"found tensor of shape: {gold_labels.size()}" + ) + if mask is not None and mask.size() not in [predictions.size()[:-1], predictions.size()]: + raise ValueError( + f"mask must have shape in one of [predictions.size()[:-1], predictions.size()] but " + f"found tensor of shape: {mask.size()}" + ) + if mask is None: + mask = predictions.new_ones(predictions.size()[:-1]).bool() + if mask.dim() < predictions.dim(): + mask = mask.unsqueeze(-1) + + padding_mask = gold_labels.gt(0) + correct = predictions.eq(gold_labels) * padding_mask + correct = (correct.int().sum(-1) == padding_mask.int().sum(-1)) * mask.squeeze(-1) + correct = correct.float() + + self.correct_indices = correct.flatten().bool() + self._correct_count += correct.sum() + self._total_count += mask.sum() + + @overrides + def get_metric(self, reset: bool) -> float: + if self._total_count > 0: + accuracy = float(self._correct_count) / float(self._total_count) + else: + accuracy = 0.0 + if reset: + self.reset() + return accuracy + + @overrides + def reset(self) -> None: + self._correct_count = 0.0 + self._total_count = 0.0 + self.correct_indices = torch.ones([]) + + class SequenceBoolAccuracy(metrics.Metric): """BoolAccuracy implementation to handle sequences.""" @@ -202,7 +262,7 @@ class SemanticMetrics(metrics.Metric): self.xpos_score = SequenceBoolAccuracy() self.semrel_score = SequenceBoolAccuracy() self.feats_score = SequenceBoolAccuracy(prod_last_dim=True) - self.lemma_score = SequenceBoolAccuracy(prod_last_dim=True) + self.lemma_score = LemmaAccuracy() self.attachment_scores = AttachmentScores() # Ignore PADDING and OOV self.enhanced_attachment_scores = AttachmentScores(ignore_classes=[0, 1]) diff --git a/tests/utils/test_metrics.py b/tests/utils/test_metrics.py index 242eaa3..bf7f619 100644 --- a/tests/utils/test_metrics.py +++ b/tests/utils/test_metrics.py @@ -154,3 +154,58 @@ class 
SequenceBoolAccuracyTest(unittest.TestCase): # then self.assertEqual(metric._correct_count.item(), 7) self.assertEqual(metric._total_count.item(), 10) + + +class LemmaAccuracyTest(unittest.TestCase): + + def setUp(self) -> None: + self.mask: torch.BoolTensor = torch.tensor([ + [True, True, True, True], + [True, True, True, False], + ]) + + def test_prediction_has_error_in_not_padded_place(self): + # given + metric = metrics.LemmaAccuracy() + predictions = torch.tensor([ + [[1, 1, 1], [1, 1, 1], [2, 2, 0], [1, 1, 4], ], + [[1, 1, 0], [1, 1000, 0], [1, 1, 0], [1, 1, 0], ], + ]) + gold_labels = torch.tensor([ + [[1, 1, 1], [1, 1, 1], [2, 2, 0], [1, 1, 4], ], + [[1, 1, 0], [1, 1, 0], [1, 1, 0], [1, 1, 0], ], + ]) + expected_correct_count = 6 + expected_total_count = 7 + expected_correct_indices = torch.tensor([1, 1, 1, 1, 1, 0, 1, 0]) + + # when + metric(predictions, gold_labels, self.mask) + + # then + self.assertEqual(metric._correct_count.item(), expected_correct_count) + self.assertEqual(metric._total_count.item(), expected_total_count) + self.assertTrue(torch.all(expected_correct_indices.eq(metric.correct_indices))) + + def test_prediction_wrong_prediction_in_padding_should_be_ignored(self): + # given + metric = metrics.LemmaAccuracy() + predictions = torch.tensor([ + [[1, 1, 1], [1, 1, 1], [2, 2, 0], [1, 1, 4], ], + [[1, 1, 1000], [1, 1, 0], [1, 1, 0], [1, 1, 0], ], + ]) + gold_labels = torch.tensor([ + [[1, 1, 1], [1, 1, 1], [2, 2, 0], [1, 1, 4], ], + [[1, 1, 0], [1, 1, 0], [1, 1, 0], [1, 1, 0], ], + ]) + expected_correct_count = 7 + expected_total_count = 7 + expected_correct_indices = torch.tensor([1, 1, 1, 1, 1, 1, 1, 0]) + + # when + metric(predictions, gold_labels, self.mask) + + # then + self.assertEqual(expected_correct_count, metric._correct_count.item()) + self.assertEqual(expected_total_count, metric._total_count.item()) + self.assertTrue(torch.all(expected_correct_indices.eq(metric.correct_indices))) -- GitLab From a7dfecab31718c864f827f71532ac83b28856add Mon Sep 17 00:00:00 2001 From: Mateusz Klimaszewski <mk.klimaszewski@gmail.com> Date: Thu, 7 Jan 2021 09:30:43 +0100 Subject: [PATCH 4/5] Add seeds to configs. --- config.graph.template.jsonnet | 3 +++ config.template.jsonnet | 3 +++ 2 files changed, 6 insertions(+) diff --git a/config.graph.template.jsonnet b/config.graph.template.jsonnet index bc8c465..c0c4696 100644 --- a/config.graph.template.jsonnet +++ b/config.graph.template.jsonnet @@ -419,4 +419,7 @@ assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't }, validation_metric: "+EM", }), + random_seed: 8787, + pytorch_seed: 8787, + numpy_seed: 8787, } diff --git a/config.template.jsonnet b/config.template.jsonnet index f41ba62..c602d9c 100644 --- a/config.template.jsonnet +++ b/config.template.jsonnet @@ -386,4 +386,7 @@ assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't }, validation_metric: "+EM", }), + random_seed: 8787, + pytorch_seed: 8787, + numpy_seed: 8787, } -- GitLab From bcce7338a5325c44c2b7a6a3e77ad97fc6d11aa5 Mon Sep 17 00:00:00 2001 From: Mateusz Klimaszewski <mk.klimaszewski@gmail.com> Date: Thu, 7 Jan 2021 13:27:29 +0100 Subject: [PATCH 5/5] Fix language extraction in training script and add script for downloading fasttext embeddings. 
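The one-line `scripts/train.py` change at the end of this patch matters for languages whose names contain underscores; a quick before/after illustration:

```python
treebank = "UD_Ancient_Greek-Perseus"

# Old extraction: the first "_" split truncates multi-word language names.
old_language = treebank.split("_")[1].split("-")[0]  # "Ancient" (wrong)

# Fixed extraction: drop the "UD_" prefix, then split on "-".
new_language = treebank[3:].split("-")[0]            # "Ancient_Greek" (correct)
```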
--- scripts/download_fasttext.py | 306 +++++++++++++++++++++++++++++++++++ scripts/train.py | 2 +- 2 files changed, 307 insertions(+), 1 deletion(-) create mode 100644 scripts/download_fasttext.py diff --git a/scripts/download_fasttext.py b/scripts/download_fasttext.py new file mode 100644 index 0000000..1919fa2 --- /dev/null +++ b/scripts/download_fasttext.py @@ -0,0 +1,306 @@ +import pathlib + +from absl import app +from absl import flags + +from scripts import utils + +# egrep -o 'https?://[^ ]+vec.gz' links.txt +# https://github.com/facebookresearch/fastText/blob/master/docs/crawl-vectors.md +LINKS = [ + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.af.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.sq.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.als.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.am.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.ar.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.an.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.hy.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.as.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.ast.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.az.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.ba.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.eu.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.bar.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.be.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.bn.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.bh.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.bpy.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.bs.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.br.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.bg.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.my.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.ca.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.ceb.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.bcl.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.ce.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.zh.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.cv.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.co.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.hr.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.cs.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.da.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.dv.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.nl.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.pa.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.arz.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.eml.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.en.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.myv.300.vec.gz", + 
"https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.eo.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.et.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.hif.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.fi.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.fr.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.gl.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.ka.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.de.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.gom.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.el.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.gu.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.ht.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.he.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.mrj.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.hi.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.hu.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.is.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.io.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.ilo.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.id.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.ia.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.ga.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.it.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.ja.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.jv.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.kn.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.pam.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.kk.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.km.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.ky.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.ko.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.ku.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.ckb.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.la.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.lv.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.li.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.lt.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.lmo.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.nds.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.lb.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.mk.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.mai.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.mg.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.ms.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.ml.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.mt.300.vec.gz", + 
"https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.gv.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.mr.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.mzn.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.mhr.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.min.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.xmf.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.mwl.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.mn.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.nah.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.nap.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.ne.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.new.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.frr.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.nso.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.no.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.nn.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.oc.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.or.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.os.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.pfl.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.ps.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.fa.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.pms.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.pl.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.pt.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.qu.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.ro.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.rm.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.ru.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.sah.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.sa.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.sc.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.sco.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.gd.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.sr.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.sh.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.scn.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.sd.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.si.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.sk.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.sl.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.so.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.azb.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.es.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.su.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.sw.300.vec.gz", + 
"https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.sv.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.tl.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.tg.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.ta.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.tt.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.te.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.th.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.bo.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.tr.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.tk.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.uk.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.hsb.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.ur.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.ug.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.uz.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.vec.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.vi.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.vo.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.wa.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.war.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.cy.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.vls.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.fy.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.pnb.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.yi.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.yo.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.diq.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.zea.300.vec.gz", +] + +CODE_2_LANG = { + "af": "Afrikaans", + "aii": "Assyrian", + "ajp": "South_Levantine_Arabic", + "akk": "Akkadian", + "am": "Amharic", + "apu": "Apurina", + "aqz": "Akuntsu", + "ar": "Arabic", + "be": "Belarusian", + "bg": "Bulgarian", + "bho": "Bhojpuri", + "bm": "Bambara", + "br": "Breton", + "bxr": "Buryat", + "ca": "Catalan", + "ckt": "Chukchi", + "cop": "Coptic", + "cs": "Czech", + "cu": "Old_Church_Slavonic", + "cy": "Welsh", + "da": "Danish", + "de": "German", + "el": "Greek", + "en": "English", + "es": "Spanish", + "et": "Estonian", + "eu": "Basque", + "fa": "Persian", + "fi": "Finnish", + "fo": "Faroese", + "fr": "French", + "fro": "Old_French", + "ga": "Irish", + "gd": "Scottish_Gaelic", + "gl": "Galician", + "got": "Gothic", + "grc": "Ancient_Greek", + "gsw": "Swiss_German", + "gun": "Mbya_Guarani", + "gv": "Manx", + "he": "Hebrew", + "hi": "Hindi", + "hr": "Croatian", + "hsb": "Upper_Sorbian", + "hu": "Hungarian", + "hy": "Armenian", + "id": "Indonesian", + "is": "Icelandic", + "it": "Italian", + "ja": "Japanese", + "kfm": "Khunsari", + "kk": "Kazakh", + "kmr": "Kurmanji", + "ko": "Korean", + "koi": "Komi_Permyak", + "kpv": "Komi_Zyrian", + "krl": "Karelian", + "la": "Latin", + "lt": "Lithuanian", + "lv": "Latvian", + "lzh": "Classical_Chinese", + "mdf": "Moksha", + "mr": "Marathi", + "mt": "Maltese", + "myu": "Munduruku", + "myv": "Erzya", + "nl": "Dutch", + "no": "Norwegian", + 
"nyq": "Nayini", + "olo": "Livvi", + "orv": "Old_Russian", + "otk": "Old_Turkish", + "pcm": "Naija", + "pl": "Polish", + "pt": "Portuguese", + "qhe": "Hindi_English", + "qtd": "Turkish_German", + "ro": "Romanian", + "ru": "Russian", + "sa": "Sanskrit", + "sk": "Slovak", + "sl": "Slovenian", + "sme": "North_Sami", + "sms": "Skolt_Sami", + "soj": "Soi", + "sq": "Albanian", + "sr": "Serbian", + "sv": "Swedish", + "swl": "Swedish_Sign_Language", + "ta": "Tamil", + "te": "Telugu", + "th": "Thai", + "tl": "Tagalog", + "tpn": "Tupinamba", + "tr": "Turkish", + "ug": "Uyghur", + "uk": "Ukrainian", + "ur": "Urdu", + "vi": "Vietnamese", + "wbp": "Warlpiri", + "wo": "Wolof", + "yo": "Yoruba", + "yue": "Cantonese", + "zh": "Chinese", +} + +FLAGS = flags.FLAGS +flags.DEFINE_string(name="output_dir", default="", + help="Path to store embeddings.") + + +def run(_): + output_dir = pathlib.Path(FLAGS.output_dir) + for link in LINKS: + lang_code = link.split(".")[-4] + + if lang_code not in CODE_2_LANG: + print(f"Unknown code {lang_code}.") + continue + + output_file = output_dir / CODE_2_LANG[lang_code] + output_file.mkdir(exist_ok=True, parents=True) + if (output_file / 'vectors.vec.gz').exists(): + print(f"Vectors for {CODE_2_LANG[lang_code]} already exists, skipping.") + continue + + utils.execute_command(f"wget -O {output_file / 'vectors.vec.gz'} {link}") + + +def main(): + app.run(run) + + +if __name__ == "__main__": + main() diff --git a/scripts/train.py b/scripts/train.py index 4bd342a..dc75344 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -212,7 +212,7 @@ def run(_): for treebank in FLAGS.treebanks: assert treebank in TREEBANKS, f"Unknown treebank {treebank}." treebank_dir = treebanks_dir / treebank - treebank_parts = treebank.split("_")[1].split("-") + treebank_parts = treebank[3:].split("-") language = treebank_parts[0] files = list(treebank_dir.iterdir()) -- GitLab