From d2a17a768b4165a4a6017d1e71b3849132e7a0f1 Mon Sep 17 00:00:00 2001
From: Mateusz Klimaszewski <mk.klimaszewski@gmail.com>
Date: Wed, 13 Jan 2021 11:30:20 +0100
Subject: [PATCH] Add model license information, make dataset_reader a public
 attribute, and exclude treebanks without data from the training script.

---
 combo/predict.py | 11 +++----
 docs/models.md   |  9 +++++-
 scripts/train.py | 74 ++++++++++++++++++++----------------------------
 3 files changed, 44 insertions(+), 50 deletions(-)

diff --git a/combo/predict.py b/combo/predict.py
index bd9f5d4..e528a18 100644
--- a/combo/predict.py
+++ b/combo/predict.py
@@ -29,8 +29,9 @@ class COMBO(predictor.Predictor):
         super().__init__(model, dataset_reader)
         self.batch_size = batch_size
         self.vocab = model.vocab
-        self._dataset_reader.generate_labels = False
-        self._dataset_reader.lazy = True
+        self.dataset_reader = self._dataset_reader
+        self.dataset_reader.generate_labels = False
+        self.dataset_reader.lazy = True
         self._tokenizer = tokenizer
         self.without_sentence_embedding = False
         self.line_to_conllu = line_to_conllu
@@ -112,7 +113,7 @@ class COMBO(predictor.Predictor):
             tokens = sentence
         else:
             raise ValueError("Input must be either string or list of strings.")
-        return self._dataset_reader.text_to_instance(tokens2conllu(tokens))
+        return self.dataset_reader.text_to_instance(tokens2conllu(tokens))
 
     @overrides
     def load_line(self, line: str) -> common.JsonDict:
@@ -125,7 +126,7 @@ class COMBO(predictor.Predictor):
         if self.without_sentence_embedding:
             outputs.sentence_embedding = []
         if self.line_to_conllu:
-            return sentence2conllu(outputs, keep_semrel=self._dataset_reader.use_sem).serialize()
+            return sentence2conllu(outputs, keep_semrel=self.dataset_reader.use_sem).serialize()
         else:
             return outputs.to_json()
 
@@ -134,7 +135,7 @@ class COMBO(predictor.Predictor):
         return {"sentence": sentence}
 
     def _to_input_instance(self, sentence: data.Sentence) -> allen_data.Instance:
-        return self._dataset_reader.text_to_instance(sentence2conllu(sentence))
+        return self.dataset_reader.text_to_instance(sentence2conllu(sentence))
 
     def _predictions_as_tree(self, predictions: Dict[str, Any], instance: allen_data.Instance):
         tree = instance.fields["metadata"]["input"]
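
The hunks above only alias the underscored attribute, so existing behaviour is unchanged while downstream code gains a public handle on the reader. A minimal usage sketch, assuming a COMBO install and the `from_pretrained` entry point from the project README (the model name is illustrative):

```python
from combo import predict

# Build a predictor from a pre-trained model (name is illustrative).
nlp = predict.COMBO.from_pretrained("polish-herbert-base")

# dataset_reader is now public, so callers can inspect or tweak it
# without reaching into the underscored _dataset_reader.
nlp.dataset_reader.generate_labels = False

sentence = nlp("To jest przykładowe zdanie.")
print(sentence.tokens)
```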
diff --git a/docs/models.md b/docs/models.md
index 94eed03..96bd7e9 100644
--- a/docs/models.md
+++ b/docs/models.md
@@ -4,8 +4,15 @@ COMBO provides pre-trained models for:
 - morphosyntactic prediction (i.e. part-of-speech tagging, morphosyntactic analysis, lemmatisation and dependency parsing) trained on the treebanks from [Universal Dependencies repository](https://universaldependencies.org),
 - enhanced dependency parsing trained on IWPT 2020 shared task [data](https://universaldependencies.org/iwpt20/data.html).
 
-Pre-trained models list with the **evaluation results** is available in the [spreadsheet](https://docs.google.com/spreadsheets/d/1WFYc2aLRa1jw7le030HOacv9fc4zmtqiZtRQY6gl5mc/edit?usp=sharing)
+## Pre-trained models
+A list of **pre-trained models** with their **evaluation results** is available in the [spreadsheet](https://docs.google.com/spreadsheets/d/1WFYc2aLRa1jw7le030HOacv9fc4zmtqiZtRQY6gl5mc/edit?usp=sharing).
 Please note that the name in brackets matches the name used in [Automatic Download](models.md#Automatic download).
+
+### License
+Each model is distributed under the same license as the data it was trained on.
+
+See [Universal Dependencies v2.7 License Agreement](https://lindat.mff.cuni.cz/repository/xmlui/page/license-ud-2.7) and [Universal Dependencies v2.5 License Agreement](https://lindat.mff.cuni.cz/repository/xmlui/page/licence-UD-2.5) for details.
+
 ## Manual download
 
 The pre-trained models can be downloaded from [here](http://mozart.ipipan.waw.pl/~mklimaszewski/models/).
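
Manual download can also be scripted; a small sketch using only the standard library, under the assumption that model archives are named after their models (the file name below is hypothetical):

```python
import urllib.request

# Hypothetical archive name; check the directory listing for the real files.
url = "http://mozart.ipipan.waw.pl/~mklimaszewski/models/polish-herbert-base.tar.gz"
urllib.request.urlretrieve(url, "polish-herbert-base.tar.gz")
```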
diff --git a/scripts/train.py b/scripts/train.py
index dc75344..accca4a 100644
--- a/scripts/train.py
+++ b/scripts/train.py
@@ -10,24 +10,17 @@ from scripts import utils
 # UD 2.7
 TREEBANKS = [
     "UD_Afrikaans-AfriBooms",
-    "UD_Akkadian-PISANDUB",
-    "UD_Akkadian-RIAO",
-    "UD_Akuntsu-TuDeT",
-    "UD_Albanian-TSA",
-    "UD_Amharic-ATT",
-    "UD_Ancient_Greek-Perseus",
-    "UD_Ancient_Greek-PROIEL",
-    "UD_Apurina-UFPA",
+    # "UD_Albanian-TSA", No training data
+    # "UD_Amharic-ATT", No training data
     "UD_Arabic-NYUAD",
     "UD_Arabic-PADT",
     "UD_Arabic-PUD",
     "UD_Armenian-ArmTDP",
-    "UD_Assyrian-AS",
-    "UD_Bambara-CRB",
+    # "UD_Assyrian-AS", No training data
+    # "UD_Bambara-CRB", No training data
     "UD_Basque-BDT",
     "UD_Belarusian-HSE",
-    "UD_Bhojpuri-BHTB",
-    "UD_Breton-KEB",
+    # "UD_Breton-KEB", No training data
     "UD_Bulgarian-BTB",
     "UD_Buryat-BDT",
     "UD_Cantonese-HK",
@@ -48,17 +41,9 @@ TREEBANKS = [
     "UD_Czech-PUD",
     "UD_Danish-DDT",
     "UD_Dutch-Alpino",
-    "UD_Dutch-LassySmall",
-    "UD_English-ESL",
+    # End of first run
     "UD_English-EWT",
-    "UD_English-GUM",
-    "UD_English-GUMReddit",
-    "UD_English-LinES",
-    "UD_English-ParTUT",
-    "UD_English-Pronouns",
-    "UD_English-PUD",
-    "UD_Erzya-JR",
-    "UD_Estonian-EDT",
+    # "UD_Erzya-JR", No training data
     "UD_Estonian-EWT",
     "UD_Faroese-FarPaHC",
     "UD_Faroese-OFT",
@@ -98,7 +83,7 @@ TREEBANKS = [
     "UD_Italian-PUD",
     "UD_Italian-TWITTIRO",
     "UD_Italian-VIT",
-    # "UD_Japanese-BCCWJ", no data
+    # "UD_Japanese-BCCWJ", No public data
     "UD_Japanese-GSD",
     "UD_Japanese-Modern",
     "UD_Japanese-PUD",
@@ -119,9 +104,9 @@ TREEBANKS = [
     "UD_Latvian-LVTB",
     "UD_Lithuanian-ALKSNIS",
     "UD_Lithuanian-HSE",
-    "UD_Livvi-KKPP",
+    # End of batch 2
     "UD_Maltese-MUDT",
-    "UD_Manx-Cadhan",
+    # "UD_Manx-Cadhan", No training data
     "UD_Marathi-UFAL",
     "UD_Mbya_Guarani-Dooley",
     "UD_Mbya_Guarani-Thomas",
@@ -153,8 +138,7 @@ TREEBANKS = [
     "UD_Russian-PUD",
     "UD_Russian-SynTagRus",
     "UD_Russian-Taiga",
-    "UD_Sanskrit-UFAL",
-    "UD_Sanskrit-Vedic",
+    # "UD_Sanskrit-UFAL", No training data
     "UD_Scottish_Gaelic-ARCOSG",
     "UD_Serbian-SET",
     "UD_Skolt_Sami-Giellagas",
@@ -167,31 +151,22 @@ TREEBANKS = [
     "UD_Spanish-GSD",
     "UD_Spanish-PUD",
     "UD_Swedish-LinES",
-    "UD_Swedish-PUD",
-    "UD_Swedish_Sign_Language-SSLC",
-    "UD_Swedish-Talbanken",
-    "UD_Swiss_German-UZH",
-    "UD_Tagalog-TRG",
-    "UD_Tagalog-Ugnayan",
-    "UD_Tamil-MWTT",
-    "UD_Tamil-TTB",
+    # "UD_Tagalog-TRG", No training data
+    # "UD_Tamil-MWTT", No training data
     "UD_Telugu-MTG",
-    "UD_Thai-PUD",
-    "UD_Tupinamba-TuDeT",
+    # "UD_Thai-PUD", No training data
     "UD_Turkish-BOUN",
     "UD_Turkish-GB",
     "UD_Turkish_German-SAGT",
     "UD_Turkish-IMST",
     "UD_Turkish-PUD",
     "UD_Ukrainian-IU",
-    "UD_Upper_Sorbian-UFAL",
+    # "UD_Upper_Sorbian-UFAL", No validation data
     "UD_Urdu-UDTB",
     "UD_Uyghur-UDT",
     "UD_Vietnamese-VTB",
-    "UD_Warlpiri-UFAL",
-    "UD_Welsh-CCG",
-    "UD_Wolof-WTB",
-    "UD_Yoruba-YTB",
+    # "UD_Welsh-CCG", No validation data
+    # "UD_Yoruba-YTB", No training data
 ]
 
 FLAGS = flags.FLAGS
@@ -250,13 +225,24 @@ def run(_):
         """
 
         # Datasets without XPOS
-        if treebank in {"UD_Armenian-ArmTDP", "UD_Basque-BDT", "UD_Hungarian-Szeged"}:
+        if treebank in {"UD_Armenian-ArmTDP", "UD_Basque-BDT", "UD_Danish-DDT", "UD_Hungarian-Szeged", "UD_French-GSD",
+                        "UD_Marathi-UFAL", "UD_Norwegian-Bokmaal"}:
             command = command + " --targets deprel,head,upostag,lemma,feats"
 
+        # Datasets without LEMMA and FEATS
+        if treebank in {"UD_Maltese-MUDT"}:
+            command = command + " --targets deprel,head,upostag,xpostag"
+
+        # Datasets without XPOS and FEATS
+        if treebank in {"UD_Telugu-MTG"}:
+            command = command + " --targets deprel,head,upostag,lemma"
+
         # Reduce word_batch_size
         word_batch_size = 2500
-        if treebank in {"UD_German-HDT"}:
+        if treebank in {"UD_German-HDT", "UD_Marathi-UFAL"}:
             word_batch_size = 1000
+        elif treebank in {"UD_Telugu-MTG"}:
+            word_batch_size = 500
         command = command + f" --word_batch_size {word_batch_size}"
 
         utils.execute_command(command)
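
Taken together, the per-treebank overrides compose into the final training command. A condensed sketch of the logic added above (the base `command` is built earlier in run(); the target sets and batch sizes are copied from this hunk):

```python
def apply_overrides(treebank: str, command: str) -> str:
    """Mirror the per-treebank flag logic from run() above."""
    # Treebanks lacking XPOS drop xpostag from the prediction targets.
    if treebank in {"UD_Armenian-ArmTDP", "UD_Basque-BDT", "UD_Danish-DDT",
                    "UD_Hungarian-Szeged", "UD_French-GSD",
                    "UD_Marathi-UFAL", "UD_Norwegian-Bokmaal"}:
        command += " --targets deprel,head,upostag,lemma,feats"
    elif treebank == "UD_Maltese-MUDT":   # no LEMMA or FEATS
        command += " --targets deprel,head,upostag,xpostag"
    elif treebank == "UD_Telugu-MTG":     # no XPOS or FEATS
        command += " --targets deprel,head,upostag,lemma"

    # Smaller word batches where the default 2500 is too memory-hungry.
    if treebank in {"UD_German-HDT", "UD_Marathi-UFAL"}:
        word_batch_size = 1000
    elif treebank == "UD_Telugu-MTG":
        word_batch_size = 500
    else:
        word_batch_size = 2500
    return command + f" --word_batch_size {word_batch_size}"
```

For example, UD_Telugu-MTG ends up with `--targets deprel,head,upostag,lemma --word_batch_size 500` appended to its command.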
-- 
GitLab