From b8c83784dfd541dbae8851a45c66a204d5cd82a7 Mon Sep 17 00:00:00 2001
From: Mateusz Klimaszewski <mk.klimaszewski@gmail.com>
Date: Tue, 5 Jan 2021 16:37:59 +0100
Subject: [PATCH] Extend installation and pre-trained models documentation, fix
 installation, update trainer to be in sync with AllenNLP 1.2.1.

---
 combo/data/api.py         |   3 ++
 combo/training/trainer.py |  38 ++++++++------
 docs/installation.md      |  18 ++++++-
 docs/models.md            |   2 +
 scripts/train.py          | 108 ++++++++++++++++++++++++++++++++++++--
 setup.py                  |   9 ++--
 6 files changed, 151 insertions(+), 27 deletions(-)

diff --git a/combo/data/api.py b/combo/data/api.py
index 7d44917..4ab7f1a 100644
--- a/combo/data/api.py
+++ b/combo/data/api.py
@@ -36,6 +36,9 @@ class Sentence:
             "metadata": self.metadata,
         })
 
+    def __len__(self):
+        return len(self.tokens)
+
 
 class _TokenList(conllu.TokenList):
 
diff --git a/combo/training/trainer.py b/combo/training/trainer.py
index 26bd75f..f74873e 100644
--- a/combo/training/trainer.py
+++ b/combo/training/trainer.py
@@ -68,10 +68,7 @@ class GradientDescentTrainer(training.GradientDescentTrainer):
         self.validate_every_n = 5
 
     @overrides
-    def train(self) -> Dict[str, Any]:
-        """
-        Trains the supplied model with the supplied parameters.
-        """
+    def _try_train(self) -> Dict[str, Any]:
         try:
             epoch_counter = self._restore_checkpoint()
         except RuntimeError:
@@ -87,7 +84,7 @@ class GradientDescentTrainer(training.GradientDescentTrainer):
         logger.info("Beginning training.")
 
         val_metrics: Dict[str, float] = {}
-        this_epoch_val_metric: float = None
+        this_epoch_val_metric: float
         metrics: Dict[str, Any] = {}
         epochs_trained = 0
         training_start_time = time.time()
@@ -97,19 +94,24 @@ class GradientDescentTrainer(training.GradientDescentTrainer):
             metrics["best_validation_" + key] = value
 
         for callback in self._epoch_callbacks:
-            callback(self, metrics={}, epoch=-1, is_master=True)
+            callback(self, metrics={}, epoch=-1, is_master=self._master)
 
         for epoch in range(epoch_counter, self._num_epochs):
             epoch_start_time = time.time()
             train_metrics = self._train_epoch(epoch)
 
+            if self._master and self._checkpointer is not None:
+                self._checkpointer.save_checkpoint(epoch, self, save_model_only=True)
+
+            # Wait for the master to finish saving the model checkpoint
+            if self._distributed:
+                dist.barrier()
+
             # get peak of memory usage
-            if "cpu_memory_MB" in train_metrics:
-                metrics["peak_cpu_memory_MB"] = max(
-                    metrics.get("peak_cpu_memory_MB", 0), train_metrics["cpu_memory_MB"]
-                )
             for key, value in train_metrics.items():
-                if key.startswith("gpu_"):
+                if key.startswith("gpu_") and key.endswith("_memory_MB"):
+                    metrics["peak_" + key] = max(metrics.get("peak_" + key, 0), value)
+                elif key.startswith("worker_") and key.endswith("_memory_MB"):
                     metrics["peak_" + key] = max(metrics.get("peak_" + key, 0), value)
 
             if self._validation_data_loader is not None:
@@ -129,9 +131,9 @@ class GradientDescentTrainer(training.GradientDescentTrainer):
                         self.model,
                         val_loss,
                         val_reg_loss,
-                        num_batches=num_batches,
                         batch_loss=None,
                         batch_reg_loss=None,
+                        num_batches=num_batches,
                         reset=True,
                         world_size=self._world_size,
                         cuda_device=self.cuda_device,
@@ -139,7 +141,7 @@ class GradientDescentTrainer(training.GradientDescentTrainer):
 
                     # Check validation metric for early stopping
                     this_epoch_val_metric = val_metrics[self._validation_metric]
-                    # self._metric_tracker.add_metric(this_epoch_val_metric)
+                    self._metric_tracker.add_metric(this_epoch_val_metric)
                     train_metrics["patience"] = self._metric_tracker._patience
 
                     if self._metric_tracker.should_stop_early():
@@ -184,7 +186,7 @@ class GradientDescentTrainer(training.GradientDescentTrainer):
             if self._momentum_scheduler:
                 self._momentum_scheduler.step(this_epoch_val_metric)
 
-            if self._master:
+            if self._master and self._checkpointer is not None:
                 self._checkpointer.save_checkpoint(
                     epoch, self, is_best_so_far=self._metric_tracker.is_best_so_far()
                 )
@@ -209,11 +211,13 @@ class GradientDescentTrainer(training.GradientDescentTrainer):
 
             epochs_trained += 1
 
-        # make sure pending events are flushed to disk and files are closed properly
-        self._tensorboard.close()
+        for callback in self._end_callbacks:
+            callback(self, metrics=metrics, epoch=epoch, is_master=self._master)
 
         # Load the best model state before returning
-        best_model_state = self._checkpointer.best_model_state()
+        best_model_state = (
+            None if self._checkpointer is None else self._checkpointer.best_model_state()
+        )
         if best_model_state:
             self.model.load_state_dict(best_model_state)
 
diff --git a/docs/installation.md b/docs/installation.md
index 6371094..bf741f9 100644
--- a/docs/installation.md
+++ b/docs/installation.md
@@ -7,7 +7,23 @@ python setup.py develop
 combo --helpfull
 ```
 
+### Virtualenv example
+```bash
+python -m venv venv
+source venv/bin/activate
+pip install --upgrade pip
+python setup.py develop
+```
+
 ## Problems & solutions
 * **jsonnet** installation error
 
-use `conda install -c conda-forge jsonnet=0.15.0`
+Run `conda install -c conda-forge jsonnet=0.15.0` and re-run the installation.
+
+* No package 'sentencepiece' found
+
+Run `pip install sentencepiece` and re-run the installation.
+
+* Missing Cython error
+
+Run `pip install spacy==2.3.2` manually and re-run the installation.
diff --git a/docs/models.md b/docs/models.md
index 25a7f70..94eed03 100644
--- a/docs/models.md
+++ b/docs/models.md
@@ -4,6 +4,8 @@ COMBO provides pre-trained models for:
 - morphosyntactic prediction (i.e. part-of-speech tagging, morphosyntactic analysis, lemmatisation and dependency parsing) trained on the treebanks from [Universal Dependencies repository](https://universaldependencies.org),
 - enhanced dependency parsing trained on IWPT 2020 shared task [data](https://universaldependencies.org/iwpt20/data.html).
 
+The list of pre-trained models with their **evaluation results** is available in this [spreadsheet](https://docs.google.com/spreadsheets/d/1WFYc2aLRa1jw7le030HOacv9fc4zmtqiZtRQY6gl5mc/edit?usp=sharing).
+Please note that the name in brackets matches the name used in [Automatic Download](models.md#automatic-download).
 ## Manual download
 
 The pre-trained models can be downloaded from [here](http://mozart.ipipan.waw.pl/~mklimaszewski/models/).
diff --git a/scripts/train.py b/scripts/train.py
index 9390888..4bd342a 100644
--- a/scripts/train.py
+++ b/scripts/train.py
@@ -6,38 +6,71 @@ from absl import flags
 
 from scripts import utils
 
+# # ls -1 | xargs -i echo "\"{}\","
+# UD 2.7
 TREEBANKS = [
     "UD_Afrikaans-AfriBooms",
+    "UD_Akkadian-PISANDUB",
+    "UD_Akkadian-RIAO",
+    "UD_Akuntsu-TuDeT",
+    "UD_Albanian-TSA",
+    "UD_Amharic-ATT",
+    "UD_Ancient_Greek-Perseus",
+    "UD_Ancient_Greek-PROIEL",
+    "UD_Apurina-UFPA",
     "UD_Arabic-NYUAD",
     "UD_Arabic-PADT",
+    "UD_Arabic-PUD",
     "UD_Armenian-ArmTDP",
+    "UD_Assyrian-AS",
+    "UD_Bambara-CRB",
     "UD_Basque-BDT",
     "UD_Belarusian-HSE",
+    "UD_Bhojpuri-BHTB",
     "UD_Breton-KEB",
     "UD_Bulgarian-BTB",
+    "UD_Buryat-BDT",
+    "UD_Cantonese-HK",
     "UD_Catalan-AnCora",
+    "UD_Chinese-CFL",
+    "UD_Chinese-GSD",
+    "UD_Chinese-GSDSimp",
+    "UD_Chinese-HK",
+    "UD_Chinese-PUD",
+    "UD_Chukchi-HSE",
+    "UD_Classical_Chinese-Kyoto",
+    "UD_Coptic-Scriptorium",
     "UD_Croatian-SET",
     "UD_Czech-CAC",
     "UD_Czech-CLTT",
     "UD_Czech-FicTree",
     "UD_Czech-PDT",
+    "UD_Czech-PUD",
     "UD_Danish-DDT",
     "UD_Dutch-Alpino",
     "UD_Dutch-LassySmall",
     "UD_English-ESL",
     "UD_English-EWT",
     "UD_English-GUM",
+    "UD_English-GUMReddit",
     "UD_English-LinES",
     "UD_English-ParTUT",
     "UD_English-Pronouns",
+    "UD_English-PUD",
+    "UD_Erzya-JR",
     "UD_Estonian-EDT",
     "UD_Estonian-EWT",
+    "UD_Faroese-FarPaHC",
+    "UD_Faroese-OFT",
     "UD_Finnish-FTB",
+    "UD_Finnish-OOD",
+    "UD_Finnish-PUD",
     "UD_Finnish-TDT",
     "UD_French-FQB",
     "UD_French-FTB",
     "UD_French-GSD",
     "UD_French-ParTUT",
+    "UD_French-PUD",
     "UD_French-Sequoia",
     "UD_French-Spoken",
     "UD_Galician-CTG",
@@ -45,60 +78,120 @@ TREEBANKS = [
     "UD_German-GSD",
     "UD_German-HDT",
     "UD_German-LIT",
+    "UD_German-PUD",
+    "UD_Gothic-PROIEL",
     "UD_Greek-GDT",
     "UD_Hebrew-HTB",
     "UD_Hindi_English-HIENCS",
     "UD_Hindi-HDTB",
+    "UD_Hindi-PUD",
     "UD_Hungarian-Szeged",
+    "UD_Icelandic-IcePaHC",
+    "UD_Icelandic-PUD",
+    "UD_Indonesian-CSUI",
     "UD_Indonesian-GSD",
+    "UD_Indonesian-PUD",
     "UD_Irish-IDT",
     "UD_Italian-ISDT",
     "UD_Italian-ParTUT",
     "UD_Italian-PoSTWITA",
+    "UD_Italian-PUD",
     "UD_Italian-TWITTIRO",
     "UD_Italian-VIT",
-    "UD_Japanese-BCCWJ",
+    # "UD_Japanese-BCCWJ", no data
     "UD_Japanese-GSD",
     "UD_Japanese-Modern",
+    "UD_Japanese-PUD",
+    "UD_Karelian-KKPP",
     "UD_Kazakh-KTB",
+    "UD_Khunsari-AHA",
+    "UD_Komi_Permyak-UH",
+    "UD_Komi_Zyrian-IKDP",
+    "UD_Komi_Zyrian-Lattice",
     "UD_Korean-GSD",
     "UD_Korean-Kaist",
+    "UD_Korean-PUD",
+    "UD_Kurmanji-MG",
     "UD_Latin-ITTB",
+    "UD_Latin-LLCT",
     "UD_Latin-Perseus",
     "UD_Latin-PROIEL",
     "UD_Latvian-LVTB",
     "UD_Lithuanian-ALKSNIS",
     "UD_Lithuanian-HSE",
+    "UD_Livvi-KKPP",
     "UD_Maltese-MUDT",
+    "UD_Manx-Cadhan",
     "UD_Marathi-UFAL",
+    "UD_Mbya_Guarani-Dooley",
+    "UD_Mbya_Guarani-Thomas",
+    "UD_Moksha-JR",
+    "UD_Munduruku-TuDeT",
+    "UD_Naija-NSC",
+    "UD_Nayini-AHA",
+    "UD_North_Sami-Giella",
+    "UD_Norwegian-Bokmaal",
+    "UD_Norwegian-Nynorsk",
+    "UD_Norwegian-NynorskLIA",
+    "UD_Old_Church_Slavonic-PROIEL",
+    "UD_Old_French-SRCMF",
+    "UD_Old_Russian-RNC",
+    "UD_Old_Russian-TOROT",
+    "UD_Old_Turkish-Tonqq",
+    "UD_Persian-PerDT",
     "UD_Persian-Seraji",
     "UD_Polish-LFG",
     "UD_Polish-PDB",
+    "UD_Polish-PUD",
     "UD_Portuguese-Bosque",
     "UD_Portuguese-GSD",
+    "UD_Portuguese-PUD",
     "UD_Romanian-Nonstandard",
     "UD_Romanian-RRT",
     "UD_Romanian-SiMoNERo",
     "UD_Russian-GSD",
+    "UD_Russian-PUD",
     "UD_Russian-SynTagRus",
     "UD_Russian-Taiga",
+    "UD_Sanskrit-UFAL",
+    "UD_Sanskrit-Vedic",
+    "UD_Scottish_Gaelic-ARCOSG",
     "UD_Serbian-SET",
+    "UD_Skolt_Sami-Giellagas",
     "UD_Slovak-SNK",
     "UD_Slovenian-SSJ",
     "UD_Slovenian-SST",
+    "UD_Soi-AHA",
"UD_South_Levantine_Arabic-MADAR", "UD_Spanish-AnCora", "UD_Spanish-GSD", + "UD_Spanish-PUD", "UD_Swedish-LinES", + "UD_Swedish-PUD", "UD_Swedish_Sign_Language-SSLC", "UD_Swedish-Talbanken", + "UD_Swiss_German-UZH", + "UD_Tagalog-TRG", + "UD_Tagalog-Ugnayan", + "UD_Tamil-MWTT", "UD_Tamil-TTB", "UD_Telugu-MTG", + "UD_Thai-PUD", + "UD_Tupinamba-TuDeT", + "UD_Turkish-BOUN", "UD_Turkish-GB", + "UD_Turkish_German-SAGT", "UD_Turkish-IMST", + "UD_Turkish-PUD", "UD_Ukrainian-IU", + "UD_Upper_Sorbian-UFAL", "UD_Urdu-UDTB", "UD_Uyghur-UDT", "UD_Vietnamese-VTB", + "UD_Warlpiri-UFAL", + "UD_Welsh-CCG", + "UD_Wolof-WTB", + "UD_Yoruba-YTB", ] FLAGS = flags.FLAGS @@ -136,7 +229,7 @@ def run(_): embeddings_file = None if embeddings_dir: embeddings_dir = pathlib.Path(embeddings_dir) / language - embeddings_file = [f for f in embeddings_dir.iterdir() if "vectors" in f.name and ".vec.gz" in f.name] + embeddings_file = [f for f in embeddings_dir.iterdir() if "vectors" in f.name and ".vec" in f.name] assert len(embeddings_file) == 1, f"Couldn't find embeddings file." embeddings_file = embeddings_file[0] @@ -153,14 +246,19 @@ def run(_): else f"--pretrained_transformer_name {utils.LANG2TRANSFORMER[language]}"} --serialization_dir {serialization_dir} --config_path {pathlib.Path.cwd() / 'config.template.jsonnet'} - --word_batch_size 2500 --notensorboard """ - # no XPOS datasets - if treebank in ["UD_Hungarian-Szeged", "UD_Armenian-ArmTDP"]: + # Datasets without XPOS + if treebank in {"UD_Armenian-ArmTDP", "UD_Basque-BDT", "UD_Hungarian-Szeged"}: command = command + " --targets deprel,head,upostag,lemma,feats" + # Reduce word_batch_size + word_batch_size = 2500 + if treebank in {"UD_German-HDT"}: + word_batch_size = 1000 + command = command + f" --word_batch_size {word_batch_size}" + utils.execute_command(command) diff --git a/setup.py b/setup.py index fdaa2be..8d4db9d 100644 --- a/setup.py +++ b/setup.py @@ -6,15 +6,16 @@ REQUIREMENTS = [ 'allennlp==1.2.1', 'conllu==2.3.2', 'dataclasses;python_version<"3.7"', - 'joblib==0.14.1', 'jsonnet==0.15.0', - 'requests==2.23.0', + 'numpy==1.19.4', 'overrides==3.1.0', - 'tensorboard==2.1.0', + 'requests==2.23.0', + 'spacy==2.3.2', + 'scikit-learn<=0.23.2', 'torch==1.6.0', 'tqdm==4.43.0', 'transformers>=3.4.0,<3.5', - 'urllib3>=1.25.11', + 'urllib3==1.25.11', ] setup( -- GitLab