diff --git a/combo/data/api.py b/combo/data/api.py index 7d44917ecc42a555be3c20e8500f595b8ee1edf1..4ab7f1a33de77ccf4b17c284777200e39c668081 100644 --- a/combo/data/api.py +++ b/combo/data/api.py @@ -36,6 +36,9 @@ class Sentence: "metadata": self.metadata, }) + def __len__(self): + return len(self.tokens) + class _TokenList(conllu.TokenList): diff --git a/combo/models/lemma.py b/combo/models/lemma.py index 4ff9c92d78d1f495c1e0d0df0992b368e735a47d..828ba3e9c0377b52c1bfe90dfb311d4f19689d3d 100644 --- a/combo/models/lemma.py +++ b/combo/models/lemma.py @@ -62,11 +62,11 @@ class LemmatizerModel(base.Predictor): BATCH_SIZE, SENTENCE_LENGTH, MAX_WORD_LENGTH, CHAR_CLASSES = pred.size() pred = pred.reshape(-1, CHAR_CLASSES) - valid_positions = mask.sum() - mask = mask.reshape(-1) true = true.reshape(-1) + mask = true.gt(0) loss = utils.masked_cross_entropy(pred, true, mask) loss = loss.reshape(BATCH_SIZE, -1) * sample_weights.unsqueeze(-1) + valid_positions = mask.sum() return loss.sum() / valid_positions @classmethod diff --git a/combo/training/checkpointer.py b/combo/training/checkpointer.py index c148ed664d4a14676bc516888d212da240bef3ea..bae4403f6a5cfa0bc3482f2670a221478cd2cae3 100644 --- a/combo/training/checkpointer.py +++ b/combo/training/checkpointer.py @@ -16,6 +16,7 @@ class FinishingTrainingCheckpointer(training.Checkpointer): epoch: Union[int, str], trainer: "allen_trainer.Trainer", is_best_so_far: bool = False, + save_model_only: bool = False, ) -> None: if trainer._learning_rate_scheduler.decreases <= 1 or epoch == trainer._num_epochs - 1: super().save_checkpoint(epoch, trainer, is_best_so_far) diff --git a/combo/training/trainer.py b/combo/training/trainer.py index 26bd75f7fbe6917f144b820bbbb1c7e14c3c8e9d..3bee8fcaaf1c29189583a551a1100ac4f0215a65 100644 --- a/combo/training/trainer.py +++ b/combo/training/trainer.py @@ -68,10 +68,7 @@ class GradientDescentTrainer(training.GradientDescentTrainer): self.validate_every_n = 5 @overrides - def train(self) -> Dict[str, Any]: - """ - Trains the supplied model with the supplied parameters. 
- """ + def _try_train(self) -> Dict[str, Any]: try: epoch_counter = self._restore_checkpoint() except RuntimeError: @@ -97,19 +94,24 @@ class GradientDescentTrainer(training.GradientDescentTrainer): metrics["best_validation_" + key] = value for callback in self._epoch_callbacks: - callback(self, metrics={}, epoch=-1, is_master=True) + callback(self, metrics={}, epoch=-1, is_master=self._master) for epoch in range(epoch_counter, self._num_epochs): epoch_start_time = time.time() train_metrics = self._train_epoch(epoch) + if self._master and self._checkpointer is not None: + self._checkpointer.save_checkpoint(epoch, self, save_model_only=True) + + # Wait for the master to finish saving the model checkpoint + if self._distributed: + dist.barrier() + # get peak of memory usage - if "cpu_memory_MB" in train_metrics: - metrics["peak_cpu_memory_MB"] = max( - metrics.get("peak_cpu_memory_MB", 0), train_metrics["cpu_memory_MB"] - ) for key, value in train_metrics.items(): - if key.startswith("gpu_"): + if key.startswith("gpu_") and key.endswith("_memory_MB"): + metrics["peak_" + key] = max(metrics.get("peak_" + key, 0), value) + elif key.startswith("worker_") and key.endswith("_memory_MB"): metrics["peak_" + key] = max(metrics.get("peak_" + key, 0), value) if self._validation_data_loader is not None: @@ -129,9 +131,9 @@ class GradientDescentTrainer(training.GradientDescentTrainer): self.model, val_loss, val_reg_loss, - num_batches=num_batches, batch_loss=None, batch_reg_loss=None, + num_batches=num_batches, reset=True, world_size=self._world_size, cuda_device=self.cuda_device, @@ -184,7 +186,7 @@ class GradientDescentTrainer(training.GradientDescentTrainer): if self._momentum_scheduler: self._momentum_scheduler.step(this_epoch_val_metric) - if self._master: + if self._master and self._checkpointer is not None: self._checkpointer.save_checkpoint( epoch, self, is_best_so_far=self._metric_tracker.is_best_so_far() ) @@ -209,11 +211,13 @@ class GradientDescentTrainer(training.GradientDescentTrainer): epochs_trained += 1 - # make sure pending events are flushed to disk and files are closed properly - self._tensorboard.close() + for callback in self._end_callbacks: + callback(self, metrics=metrics, epoch=epoch, is_master=self._master) # Load the best model state before returning - best_model_state = self._checkpointer.best_model_state() + best_model_state = ( + None if self._checkpointer is None else self._checkpointer.best_model_state() + ) if best_model_state: self.model.load_state_dict(best_model_state) diff --git a/combo/utils/metrics.py b/combo/utils/metrics.py index 682e8859264a3414bf86d3a1e408ce5b3588a6f3..c2a1202148e602d433d2d4a3ac2a280f80bb76fb 100644 --- a/combo/utils/metrics.py +++ b/combo/utils/metrics.py @@ -6,6 +6,66 @@ from allennlp.training import metrics from overrides import overrides +class LemmaAccuracy(metrics.Metric): + + def __init__(self): + self._correct_count = 0.0 + self._total_count = 0.0 + self.correct_indices = torch.ones([]) + + @overrides + def __call__(self, + predictions: torch.Tensor, + gold_labels: torch.Tensor, + mask: Optional[torch.BoolTensor] = None): + if gold_labels is None: + return + predictions, gold_labels, mask = self.detach_tensors(predictions, + gold_labels, + mask) + + # Some sanity checks. 
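+        # Shapes follow the lemmatizer head output: predictions and gold_labels
+        # are (BATCH_SIZE, SENTENCE_LENGTH, MAX_WORD_LENGTH) tensors of character
+        # ids, and mask (when provided) marks valid words, one flag per word.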
+        if gold_labels.size() != predictions.size():
+            raise ValueError(
+                f"gold_labels must have the same shape as predictions, but "
+                f"found tensor of shape: {gold_labels.size()}"
+            )
+        if mask is not None and mask.size() not in [predictions.size()[:-1], predictions.size()]:
+            raise ValueError(
+                f"mask must have shape predictions.size()[:-1] or predictions.size(), but "
+                f"found tensor of shape: {mask.size()}"
+            )
+        if mask is None:
+            mask = predictions.new_ones(predictions.size()[:-1]).bool()
+        if mask.dim() < predictions.dim():
+            mask = mask.unsqueeze(-1)
+
+        padding_mask = gold_labels.gt(0)
+        correct = predictions.eq(gold_labels) * padding_mask
+        correct = (correct.int().sum(-1) == padding_mask.int().sum(-1)) * mask.squeeze(-1)
+        correct = correct.float()
+
+        self.correct_indices = correct.flatten().bool()
+        self._correct_count += correct.sum()
+        self._total_count += mask.sum()
+
+    @overrides
+    def get_metric(self, reset: bool) -> float:
+        if self._total_count > 0:
+            accuracy = float(self._correct_count) / float(self._total_count)
+        else:
+            accuracy = 0.0
+        if reset:
+            self.reset()
+        return accuracy
+
+    @overrides
+    def reset(self) -> None:
+        self._correct_count = 0.0
+        self._total_count = 0.0
+        self.correct_indices = torch.ones([])
+
+
 class SequenceBoolAccuracy(metrics.Metric):
     """BoolAccuracy implementation to handle sequences."""
 
@@ -202,7 +262,7 @@ class SemanticMetrics(metrics.Metric):
         self.xpos_score = SequenceBoolAccuracy()
         self.semrel_score = SequenceBoolAccuracy()
         self.feats_score = SequenceBoolAccuracy(prod_last_dim=True)
-        self.lemma_score = SequenceBoolAccuracy(prod_last_dim=True)
+        self.lemma_score = LemmaAccuracy()
         self.attachment_scores = AttachmentScores()
         # Ignore PADDING and OOV
         self.enhanced_attachment_scores = AttachmentScores(ignore_classes=[0, 1])
@@ -241,10 +301,10 @@ class SemanticMetrics(metrics.Metric):
                            self.feats_score.correct_indices *
                            self.lemma_score.correct_indices *
                            self.attachment_scores.correct_indices *
-                           enhanced_indices)
+                           enhanced_indices) * mask.flatten()
 
-        total, correct_indices = self.detach_tensors(total, correct_indices)
-        self.em_score = (correct_indices.float().sum() / total).item()
+        total, correct_count = self.detach_tensors(total, correct_indices.float().sum())
+        self.em_score = (correct_count / total).item()
 
     def get_metric(self, reset: bool) -> Dict[str, float]:
         metrics_dict = {
diff --git a/config.graph.template.jsonnet b/config.graph.template.jsonnet
index bc8c46580f17f22924a9f68628d64ce7f1060d55..c0c469674f4a74a6d46953e65d9715e9eabf0a2f 100644
--- a/config.graph.template.jsonnet
+++ b/config.graph.template.jsonnet
@@ -419,4 +419,7 @@ assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't
     },
     validation_metric: "+EM",
   }),
+  random_seed: 8787,
+  pytorch_seed: 8787,
+  numpy_seed: 8787,
 }
diff --git a/config.template.jsonnet b/config.template.jsonnet
index f41ba62672eb4f93e130261ac85a5abc00e1efee..c602d9cdc25fbee465f8b0d5fd51f3368e0811c7 100644
--- a/config.template.jsonnet
+++ b/config.template.jsonnet
@@ -386,4 +386,7 @@ assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't
     },
     validation_metric: "+EM",
   }),
+  random_seed: 8787,
+  pytorch_seed: 8787,
+  numpy_seed: 8787,
 }
diff --git a/docs/installation.md b/docs/installation.md
index 6371094f8f409a5fbe3ec020caf6a93c77bb2325..bf741f9c9c0a1030814e772c166e824b99114213 100644
--- a/docs/installation.md
+++ b/docs/installation.md
@@ -7,7 +7,23 @@
 python setup.py develop
 combo --helpfull
 ```
 
+### Virtualenv example
+```bash
+python -m venv venv
+source venv/bin/activate
+pip install --upgrade pip
+python setup.py develop
+```
+
 ## Problems & solutions
 * **jsonnet** installation error
-use `conda install -c conda-forge jsonnet=0.15.0`
+Run `conda install -c conda-forge jsonnet=0.15.0` and re-run installation.
+
+* No package 'sentencepiece' found
+
+Run `pip install sentencepiece` and re-run installation.
+
+* Missing Cython error
+
+Run `pip install spacy==2.3.2` manually and re-run installation.
diff --git a/docs/models.md b/docs/models.md
index 25a7f7092ef295a0cf1ff2b3ba13e0adb05a5bc6..94eed0332521c64099ac9fab730e62d463e08c24 100644
--- a/docs/models.md
+++ b/docs/models.md
@@ -4,6 +4,8 @@ COMBO provides pre-trained models for:
 - morphosyntactic prediction (i.e. part-of-speech tagging, morphosyntactic analysis, lemmatisation and dependency parsing) trained on the treebanks from [Universal Dependencies repository](https://universaldependencies.org),
 - enhanced dependency parsing trained on IWPT 2020 shared task [data](https://universaldependencies.org/iwpt20/data.html).
 
+The list of pre-trained models, together with their **evaluation results**, is available in this [spreadsheet](https://docs.google.com/spreadsheets/d/1WFYc2aLRa1jw7le030HOacv9fc4zmtqiZtRQY6gl5mc/edit?usp=sharing).
+Note that the name in brackets matches the name used in [Automatic Download](models.md#automatic-download).
 ## Manual download
 
 The pre-trained models can be downloaded from [here](http://mozart.ipipan.waw.pl/~mklimaszewski/models/).
diff --git a/scripts/download_fasttext.py b/scripts/download_fasttext.py
new file mode 100644
index 0000000000000000000000000000000000000000..1919fa214672e0161d49b95bf2b7dffa2ff1002a
--- /dev/null
+++ b/scripts/download_fasttext.py
@@ -0,0 +1,308 @@
+import pathlib
+
+from absl import app
+from absl import flags
+
+from scripts import utils
+
+# egrep -o 'https?://[^ ]+vec.gz' links.txt
+# https://github.com/facebookresearch/fastText/blob/master/docs/crawl-vectors.md
+LINKS = [
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.af.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.sq.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.als.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.am.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.ar.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.an.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.hy.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.as.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.ast.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.az.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.ba.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.eu.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.bar.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.be.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.bn.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.bh.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.bpy.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.bs.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.br.300.vec.gz",
+    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.bg.300.vec.gz",
"https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.my.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.ca.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.ceb.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.bcl.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.ce.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.zh.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.cv.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.co.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.hr.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.cs.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.da.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.dv.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.nl.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.pa.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.arz.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.eml.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.en.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.myv.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.eo.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.et.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.hif.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.fi.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.fr.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.gl.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.ka.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.de.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.gom.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.el.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.gu.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.ht.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.he.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.mrj.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.hi.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.hu.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.is.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.io.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.ilo.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.id.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.ia.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.ga.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.it.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.ja.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.jv.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.kn.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.pam.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.kk.300.vec.gz", + 
"https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.km.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.ky.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.ko.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.ku.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.ckb.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.la.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.lv.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.li.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.lt.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.lmo.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.nds.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.lb.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.mk.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.mai.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.mg.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.ms.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.ml.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.mt.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.gv.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.mr.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.mzn.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.mhr.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.min.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.xmf.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.mwl.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.mn.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.nah.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.nap.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.ne.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.new.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.frr.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.nso.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.no.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.nn.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.oc.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.or.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.os.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.pfl.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.ps.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.fa.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.pms.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.pl.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.pt.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.qu.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.ro.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.rm.300.vec.gz", + 
"https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.ru.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.sah.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.sa.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.sc.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.sco.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.gd.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.sr.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.sh.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.scn.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.sd.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.si.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.sk.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.sl.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.so.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.azb.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.es.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.su.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.sw.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.sv.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.tl.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.tg.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.ta.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.tt.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.te.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.th.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.bo.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.tr.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.tk.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.uk.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.hsb.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.ur.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.ug.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.uz.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.vec.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.vi.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.vo.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.wa.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.war.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.cy.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.vls.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.fy.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.pnb.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.yi.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.yo.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.diq.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.zea.300.vec.gz", +] + +CODE_2_LANG = { + "af": "Afrikaans", + 
"aii": "Assyrian", + "ajp": "South_Levantine_Arabic", + "akk": "Akkadian", + "am": "Amharic", + "apu": "Apurina", + "aqz": "Akuntsu", + "ar": "Arabic", + "be": "Belarusian", + "bg": "Bulgarian", + "bho": "Bhojpuri", + "bm": "Bambara", + "br": "Breton", + "bxr": "Buryat", + "ca": "Catalan", + "ckt": "Chukchi", + "cop": "Coptic", + "cs": "Czech", + "cu": "Old_Church_Slavonic", + "cy": "Welsh", + "da": "Danish", + "de": "German", + "el": "Greek", + "en": "English", + "es": "Spanish", + "et": "Estonian", + "eu": "Basque", + "fa": "Persian", + "fi": "Finnish", + "fo": "Faroese", + "fr": "French", + "fro": "Old_French", + "ga": "Irish", + "gd": "Scottish_Gaelic", + "gl": "Galician", + "got": "Gothic", + "grc": "Ancient_Greek", + "gsw": "Swiss_German", + "gun": "Mbya_Guarani", + "gv": "Manx", + "he": "Hebrew", + "hi": "Hindi", + "hr": "Croatian", + "hsb": "Upper_Sorbian", + "hu": "Hungarian", + "hy": "Armenian", + "id": "Indonesian", + "is": "Icelandic", + "it": "Italian", + "ja": "Japanese", + "kfm": "Khunsari", + "kk": "Kazakh", + "kmr": "Kurmanji", + "ko": "Korean", + "koi": "Komi_Permyak", + "kpv": "Komi_Zyrian", + "krl": "Karelian", + "la": "Latin", + "lt": "Lithuanian", + "lv": "Latvian", + "lzh": "Classical_Chinese", + "mdf": "Moksha", + "mr": "Marathi", + "mt": "Maltese", + "myu": "Munduruku", + "myv": "Erzya", + "nl": "Dutch", + "no": "Norwegian", + "nyq": "Nayini", + "olo": "Livvi", + "orv": "Old_Russian", + "otk": "Old_Turkish", + "pcm": "Naija", + "pl": "Polish", + "pt": "Portuguese", + "qhe": "Hindi_English", + "qtd": "Turkish_German", + "ro": "Romanian", + "ru": "Russian", + "sa": "Sanskrit", + "sk": "Slovak", + "sl": "Slovenian", + "sme": "North_Sami", + "sms": "Skolt_Sami", + "soj": "Soi", + "sq": "Albanian", + "sr": "Serbian", + "sv": "Swedish", + "swl": "Swedish_Sign_Language", + "ta": "Tamil", + "te": "Telugu", + "th": "Thai", + "tl": "Tagalog", + "tpn": "Tupinamba", + "tr": "Turkish", + "ug": "Uyghur", + "uk": "Ukrainian", + "ur": "Urdu", + "vi": "Vietnamese", + "wbp": "Warlpiri", + "wo": "Wolof", + "yo": "Yoruba", + "yue": "Cantonese", + "zh": "Chinese", +} + +FLAGS = flags.FLAGS +flags.DEFINE_string(name="output_dir", default="", + help="Path to store embeddings.") + + +def run(_): + output_dir = pathlib.Path(FLAGS.output_dir) + for link in LINKS: + lang_code = link.split(".")[-4] + + if lang_code not in CODE_2_LANG: + print(f"Unknown code {lang_code}.") + continue + + output_file = output_dir / CODE_2_LANG[lang_code] + output_file.mkdir(exist_ok=True, parents=True) + if (output_file / 'vectors.vec.gz').exists(): + print(f"Vectors for {CODE_2_LANG[lang_code]} already exists, skipping.") + continue + + utils.execute_command(f"wget -O {output_file / 'vectors.vec.gz'} {link}") + + +def main(): + app.run(run) + + +if __name__ == "__main__": + main() diff --git a/scripts/train.py b/scripts/train.py index 939088800f772c113693eb6d0858304ed82f766d..dc75344432a52cfb64b4931582aaa7d963b6839b 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -6,38 +6,71 @@ from absl import flags from scripts import utils +# # ls -1 | xargs -i echo "\"{}\"," +# UD 2.7 TREEBANKS = [ "UD_Afrikaans-AfriBooms", + "UD_Akkadian-PISANDUB", + "UD_Akkadian-RIAO", + "UD_Akuntsu-TuDeT", + "UD_Albanian-TSA", + "UD_Amharic-ATT", + "UD_Ancient_Greek-Perseus", + "UD_Ancient_Greek-PROIEL", + "UD_Apurina-UFPA", "UD_Arabic-NYUAD", "UD_Arabic-PADT", + "UD_Arabic-PUD", "UD_Armenian-ArmTDP", + "UD_Assyrian-AS", + "UD_Bambara-CRB", "UD_Basque-BDT", "UD_Belarusian-HSE", + "UD_Bhojpuri-BHTB", "UD_Breton-KEB", 
"UD_Bulgarian-BTB", + "UD_Buryat-BDT", + "UD_Cantonese-HK", "UD_Catalan-AnCora", + "UD_Chinese-CFL", + "UD_Chinese-GSD", + "UD_Chinese-GSDSimp", + "UD_Chinese-HK", + "UD_Chinese-PUD", + "UD_Chukchi-HSE", + "UD_Classical_Chinese-Kyoto", + "UD_Coptic-Scriptorium", "UD_Croatian-SET", "UD_Czech-CAC", "UD_Czech-CLTT", "UD_Czech-FicTree", "UD_Czech-PDT", + "UD_Czech-PUD", "UD_Danish-DDT", "UD_Dutch-Alpino", "UD_Dutch-LassySmall", "UD_English-ESL", "UD_English-EWT", "UD_English-GUM", + "UD_English-GUMReddit", "UD_English-LinES", "UD_English-ParTUT", "UD_English-Pronouns", + "UD_English-PUD", + "UD_Erzya-JR", "UD_Estonian-EDT", "UD_Estonian-EWT", + "UD_Faroese-FarPaHC", + "UD_Faroese-OFT", "UD_Finnish-FTB", + "UD_Finnish-OOD", + "UD_Finnish-PUD", "UD_Finnish-TDT", "UD_French-FQB", "UD_French-FTB", "UD_French-GSD", "UD_French-ParTUT", + "UD_French-PUD", "UD_French-Sequoia", "UD_French-Spoken", "UD_Galician-CTG", @@ -45,60 +78,120 @@ TREEBANKS = [ "UD_German-GSD", "UD_German-HDT", "UD_German-LIT", + "UD_German-PUD", + "UD_Gothic-PROIEL", "UD_Greek-GDT", "UD_Hebrew-HTB", "UD_Hindi_English-HIENCS", "UD_Hindi-HDTB", + "UD_Hindi-PUD", "UD_Hungarian-Szeged", + "UD_Icelandic-IcePaHC", + "UD_Icelandic-PUD", + "UD_Indonesian-CSUI", "UD_Indonesian-GSD", + "UD_Indonesian-PUD", "UD_Irish-IDT", "UD_Italian-ISDT", "UD_Italian-ParTUT", "UD_Italian-PoSTWITA", + "UD_Italian-PUD", "UD_Italian-TWITTIRO", "UD_Italian-VIT", - "UD_Japanese-BCCWJ", + # "UD_Japanese-BCCWJ", no data "UD_Japanese-GSD", "UD_Japanese-Modern", + "UD_Japanese-PUD", + "UD_Karelian-KKPP", "UD_Kazakh-KTB", + "UD_Khunsari-AHA", + "UD_Komi_Permyak-UH", + "UD_Komi_Zyrian-IKDP", + "UD_Komi_Zyrian-Lattice", "UD_Korean-GSD", "UD_Korean-Kaist", + "UD_Korean-PUD", + "UD_Kurmanji-MG", "UD_Latin-ITTB", + "UD_Latin-LLCT", "UD_Latin-Perseus", "UD_Latin-PROIEL", "UD_Latvian-LVTB", "UD_Lithuanian-ALKSNIS", "UD_Lithuanian-HSE", + "UD_Livvi-KKPP", "UD_Maltese-MUDT", + "UD_Manx-Cadhan", "UD_Marathi-UFAL", + "UD_Mbya_Guarani-Dooley", + "UD_Mbya_Guarani-Thomas", + "UD_Moksha-JR", + "UD_Munduruku-TuDeT", + "UD_Naija-NSC", + "UD_Nayini-AHA", + "UD_North_Sami-Giella", + "UD_Norwegian-Bokmaal", + "UD_Norwegian-Nynorsk", + "UD_Norwegian-NynorskLIA", + "UD_Old_Church_Slavonic-PROIEL", + "UD_Old_French-SRCMF", + "UD_Old_Russian-RNC", + "UD_Old_Russian-TOROT", + "UD_Old_Turkish-Tonqq", + "UD_Persian-PerDT", "UD_Persian-Seraji", "UD_Polish-LFG", "UD_Polish-PDB", + "UD_Polish-PUD", "UD_Portuguese-Bosque", "UD_Portuguese-GSD", + "UD_Portuguese-PUD", "UD_Romanian-Nonstandard", "UD_Romanian-RRT", "UD_Romanian-SiMoNERo", "UD_Russian-GSD", + "UD_Russian-PUD", "UD_Russian-SynTagRus", "UD_Russian-Taiga", + "UD_Sanskrit-UFAL", + "UD_Sanskrit-Vedic", + "UD_Scottish_Gaelic-ARCOSG", "UD_Serbian-SET", + "UD_Skolt_Sami-Giellagas", "UD_Slovak-SNK", "UD_Slovenian-SSJ", "UD_Slovenian-SST", + "UD_Soi-AHA", + "UD_South_Levantine_Arabic-MADAR", "UD_Spanish-AnCora", "UD_Spanish-GSD", + "UD_Spanish-PUD", "UD_Swedish-LinES", + "UD_Swedish-PUD", "UD_Swedish_Sign_Language-SSLC", "UD_Swedish-Talbanken", + "UD_Swiss_German-UZH", + "UD_Tagalog-TRG", + "UD_Tagalog-Ugnayan", + "UD_Tamil-MWTT", "UD_Tamil-TTB", "UD_Telugu-MTG", + "UD_Thai-PUD", + "UD_Tupinamba-TuDeT", + "UD_Turkish-BOUN", "UD_Turkish-GB", + "UD_Turkish_German-SAGT", "UD_Turkish-IMST", + "UD_Turkish-PUD", "UD_Ukrainian-IU", + "UD_Upper_Sorbian-UFAL", "UD_Urdu-UDTB", "UD_Uyghur-UDT", "UD_Vietnamese-VTB", + "UD_Warlpiri-UFAL", + "UD_Welsh-CCG", + "UD_Wolof-WTB", + "UD_Yoruba-YTB", ] FLAGS = flags.FLAGS @@ -119,7 +212,7 @@ def run(_): for 
treebank in FLAGS.treebanks: assert treebank in TREEBANKS, f"Unknown treebank {treebank}." treebank_dir = treebanks_dir / treebank - treebank_parts = treebank.split("_")[1].split("-") + treebank_parts = treebank[3:].split("-") language = treebank_parts[0] files = list(treebank_dir.iterdir()) @@ -136,7 +229,7 @@ def run(_): embeddings_file = None if embeddings_dir: embeddings_dir = pathlib.Path(embeddings_dir) / language - embeddings_file = [f for f in embeddings_dir.iterdir() if "vectors" in f.name and ".vec.gz" in f.name] + embeddings_file = [f for f in embeddings_dir.iterdir() if "vectors" in f.name and ".vec" in f.name] assert len(embeddings_file) == 1, f"Couldn't find embeddings file." embeddings_file = embeddings_file[0] @@ -153,14 +246,19 @@ def run(_): else f"--pretrained_transformer_name {utils.LANG2TRANSFORMER[language]}"} --serialization_dir {serialization_dir} --config_path {pathlib.Path.cwd() / 'config.template.jsonnet'} - --word_batch_size 2500 --notensorboard """ - # no XPOS datasets - if treebank in ["UD_Hungarian-Szeged", "UD_Armenian-ArmTDP"]: + # Datasets without XPOS + if treebank in {"UD_Armenian-ArmTDP", "UD_Basque-BDT", "UD_Hungarian-Szeged"}: command = command + " --targets deprel,head,upostag,lemma,feats" + # Reduce word_batch_size + word_batch_size = 2500 + if treebank in {"UD_German-HDT"}: + word_batch_size = 1000 + command = command + f" --word_batch_size {word_batch_size}" + utils.execute_command(command) diff --git a/setup.py b/setup.py index fdaa2be51e5098e9aa2d2ddc3b691188f11c7b10..8d4db9dd982dd9c9f27ac42156e328ad85135b5b 100644 --- a/setup.py +++ b/setup.py @@ -6,15 +6,16 @@ REQUIREMENTS = [ 'allennlp==1.2.1', 'conllu==2.3.2', 'dataclasses;python_version<"3.7"', - 'joblib==0.14.1', 'jsonnet==0.15.0', - 'requests==2.23.0', + 'numpy==1.19.4', 'overrides==3.1.0', - 'tensorboard==2.1.0', + 'requests==2.23.0', + 'spacy==2.3.2', + 'scikit-learn<=0.23.2', 'torch==1.6.0', 'tqdm==4.43.0', 'transformers>=3.4.0,<3.5', - 'urllib3>=1.25.11', + 'urllib3==1.25.11', ] setup( diff --git a/tests/utils/test_metrics.py b/tests/utils/test_metrics.py index 242eaa3dcffaf452c19a52e2f56625de00cd0433..bf7f619d19853700803598a72201815f264d1c70 100644 --- a/tests/utils/test_metrics.py +++ b/tests/utils/test_metrics.py @@ -154,3 +154,58 @@ class SequenceBoolAccuracyTest(unittest.TestCase): # then self.assertEqual(metric._correct_count.item(), 7) self.assertEqual(metric._total_count.item(), 10) + + +class LemmaAccuracyTest(unittest.TestCase): + + def setUp(self) -> None: + self.mask: torch.BoolTensor = torch.tensor([ + [True, True, True, True], + [True, True, True, False], + ]) + + def test_prediction_has_error_in_not_padded_place(self): + # given + metric = metrics.LemmaAccuracy() + predictions = torch.tensor([ + [[1, 1, 1], [1, 1, 1], [2, 2, 0], [1, 1, 4], ], + [[1, 1, 0], [1, 1000, 0], [1, 1, 0], [1, 1, 0], ], + ]) + gold_labels = torch.tensor([ + [[1, 1, 1], [1, 1, 1], [2, 2, 0], [1, 1, 4], ], + [[1, 1, 0], [1, 1, 0], [1, 1, 0], [1, 1, 0], ], + ]) + expected_correct_count = 6 + expected_total_count = 7 + expected_correct_indices = torch.tensor([1, 1, 1, 1, 1, 0, 1, 0]) + + # when + metric(predictions, gold_labels, self.mask) + + # then + self.assertEqual(metric._correct_count.item(), expected_correct_count) + self.assertEqual(metric._total_count.item(), expected_total_count) + self.assertTrue(torch.all(expected_correct_indices.eq(metric.correct_indices))) + + def test_prediction_wrong_prediction_in_padding_should_be_ignored(self): + # given + metric = metrics.LemmaAccuracy() + 
predictions = torch.tensor([ + [[1, 1, 1], [1, 1, 1], [2, 2, 0], [1, 1, 4], ], + [[1, 1, 1000], [1, 1, 0], [1, 1, 0], [1, 1, 0], ], + ]) + gold_labels = torch.tensor([ + [[1, 1, 1], [1, 1, 1], [2, 2, 0], [1, 1, 4], ], + [[1, 1, 0], [1, 1, 0], [1, 1, 0], [1, 1, 0], ], + ]) + expected_correct_count = 7 + expected_total_count = 7 + expected_correct_indices = torch.tensor([1, 1, 1, 1, 1, 1, 1, 0]) + + # when + metric(predictions, gold_labels, self.mask) + + # then + self.assertEqual(expected_correct_count, metric._correct_count.item()) + self.assertEqual(expected_total_count, metric._total_count.item()) + self.assertTrue(torch.all(expected_correct_indices.eq(metric.correct_indices)))