From 9cf28e9fa785ef9e3f71753051a1d779dc106357 Mon Sep 17 00:00:00 2001
From: Mateusz Klimaszewski <mk.klimaszewski@gmail.com>
Date: Tue, 22 Sep 2020 08:52:10 +0200
Subject: [PATCH 01/19] Add enhanced UD data preprocessing.

---
 combo/data/dataset.py                         | 55 +++++++++++++++++--
 .../data/fields/sequence_multilabel_field.py  | 27 ++++-----
 combo/main.py                                 |  2 +-
 combo/models/model.py                         |  4 +-
 .../fields/test_sequence_multilabel_field.py  | 20 ++++++-
 5 files changed, 85 insertions(+), 23 deletions(-)

diff --git a/combo/data/dataset.py b/combo/data/dataset.py
index 459a755..0b53df3 100644
--- a/combo/data/dataset.py
+++ b/combo/data/dataset.py
@@ -1,9 +1,10 @@
 import logging
-from typing import Union, List, Dict, Iterable, Optional, Any
+from typing import Union, List, Dict, Iterable, Optional, Any, Tuple
 
 import conllu
+import torch
 from allennlp import data as allen_data
-from allennlp.common import checks
+from allennlp.common import checks, util
 from allennlp.data import fields as allen_fields, vocabulary
 from conllu import parser
 from dataclasses import dataclass
@@ -35,6 +36,9 @@ class UniversalDependenciesDatasetReader(allen_data.DatasetReader):
         if "token" not in features and "char" not in features:
             raise checks.ConfigurationError("There must be at least one ('char' or 'token') text-based feature!")
 
+        if "deps" in targets and not ("head" in targets and "deprel" in targets):
+            raise checks.ConfigurationError("Add 'head' and 'deprel' to targets when using 'deps'!")
+
         intersection = set(features).intersection(set(targets))
         if len(intersection) != 0:
             raise checks.ConfigurationError(
@@ -102,13 +106,40 @@ class UniversalDependenciesDatasetReader(allen_data.DatasetReader):
                     elif target_name == "feats":
                         target_values = self._feat_values(tree_tokens)
                         fields_[target_name] = fields.SequenceMultiLabelField(target_values,
-                                                                              self._feats_to_index_multi_label,
+                                                                              self._feats_indexer,
+                                                                              self._feats_as_tensor_wrapper,
                                                                               text_field,
                                                                               label_namespace="feats_labels")
                     elif target_name == "head":
                         target_values = [0 if v == "_" else int(v) for v in target_values]
                         fields_[target_name] = allen_fields.SequenceLabelField(target_values, text_field,
                                                                                label_namespace=target_name + "_labels")
+                    elif target_name == "deps":
+                        heads = [0 if t["head"] == "_" else int(t["head"]) for t in tree_tokens]
+                        deprels = [t["deprel"] for t in tree_tokens]
+                        enhanced_heads: List[Tuple[int, int]] = []
+                        enhanced_deprels: List[str] = []
+                        for idx, t in enumerate(tree_tokens):
+                            enhanced_heads.append((idx, heads[idx]))
+                            enhanced_deprels.append(deprels[idx])
+                            t_deps = t["deps"]
+                            if t_deps and t_deps != "_":
+                                t_heads, t_deprels = zip(*[tuple(d.split(":")) for d in t_deps.split("|")])
+                                enhanced_heads.extend([(idx, t) for t in t_heads])
+                                enhanced_deprels.extend(t_deprels)
+                        fields_["enhanced_heads"] = allen_fields.AdjacencyField(
+                            indices=enhanced_heads,
+                            sequence_field=text_field,
+                            label_namespace="enhanced_heads_labels",
+                            padding_value=0,
+                        )
+                        fields_["enhanced_deprels"] = allen_fields.AdjacencyField(
+                            indices=enhanced_heads,
+                            sequence_field=text_field,
+                            labels=enhanced_deprels,
+                            label_namespace="enhanced_deprels_labels",
+                            padding_value=0,
+                        )
                     else:
                         fields_[target_name] = allen_fields.SequenceLabelField(target_values, text_field,
                                                                                label_namespace=target_name + "_labels")
@@ -151,12 +182,26 @@ class UniversalDependenciesDatasetReader(allen_data.DatasetReader):
         return features
 
     @staticmethod
-    def _feats_to_index_multi_label(vocab: allen_data.Vocabulary):
+    def _feats_as_tensor_wrapper(field: fields.SequenceMultiLabelField):
+        def as_tensor(padding_lengths):
+            desired_num_tokens = padding_lengths["num_tokens"]
+            assert len(field._indexed_multi_labels) > 0
+            classes_count = len(field._indexed_multi_labels[0])
+            default_value = [0.0] * classes_count
+            padded_tags = util.pad_sequence_to_length(field._indexed_multi_labels, desired_num_tokens,
+                                                      lambda: default_value)
+            tensor = torch.LongTensor(padded_tags)
+            return tensor
+
+        return as_tensor
+
+    @staticmethod
+    def _feats_indexer(vocab: allen_data.Vocabulary):
         label_namespace = "feats_labels"
         vocab_size = vocab.get_vocab_size(label_namespace)
         slices = get_slices_if_not_provided(vocab)
 
-        def _m_from_n_ones_encoding(multi_label: List[str]) -> List[int]:
+        def _m_from_n_ones_encoding(multi_label: List[str], sentence_length: int) -> List[int]:
             one_hot_encoding = [0] * vocab_size
             for cat, cat_indices in slices.items():
                 if cat not in ["__PAD__", "_"]:
diff --git a/combo/data/fields/sequence_multilabel_field.py b/combo/data/fields/sequence_multilabel_field.py
index 4e98a14..b200580 100644
--- a/combo/data/fields/sequence_multilabel_field.py
+++ b/combo/data/fields/sequence_multilabel_field.py
@@ -5,7 +5,7 @@ from typing import Set, List, Callable, Iterator, Union, Dict
 
 import torch
 from allennlp import data
-from allennlp.common import checks, util
+from allennlp.common import checks
 from allennlp.data import fields
 from overrides import overrides
 
@@ -17,15 +17,16 @@ class SequenceMultiLabelField(data.Field[torch.Tensor]):
     A `SequenceMultiLabelField` is an extension of the :class:`MultiLabelField` that allows for multiple labels
     while keeping sequence dimension.
 
-    This field will get converted into a sequence of vectors of length equal to the vocabulary size with
-    M from N encoding for the labels (all zeros, and ones for the labels).
+    To allow adaptation to different circumstances, the class takes a few delegate functions.
 
     # Parameters
 
     multi_labels : `List[List[str]]`
     multi_label_indexer : `Callable[[data.Vocabulary], Callable[[List[str]], List[int]]]`
-        Nested callable which based on vocab creates mapper for multilabel field in the sequence from strings
-        to indexed, int values.
+        Nested callable which, based on the vocabulary and sequence length, maps the field values in the
+        sequence from strings to indexed, int values.
+    as_tensor: `Callable[["SequenceMultiLabelField"], Callable[[Dict[str, int]], torch.Tensor]]`
+        Nested callable which, based on the field itself, maps indexed data to a tensor.
     sequence_field : `SequenceField`
         A field containing the sequence that this `SequenceMultiLabelField` is labeling.  Most often, this is a
         `TextField`, for tagging individual tokens in a sentence.
@@ -43,7 +44,8 @@ class SequenceMultiLabelField(data.Field[torch.Tensor]):
     def __init__(
             self,
             multi_labels: List[List[str]],
-            multi_label_indexer: Callable[[data.Vocabulary], Callable[[List[str]], List[int]]],
+            multi_label_indexer: Callable[[data.Vocabulary], Callable[[List[str], int], List[int]]],
+            as_tensor: Callable[["SequenceMultiLabelField"], Callable[[Dict[str, int]], torch.Tensor]],
             sequence_field: fields.SequenceField,
             label_namespace: str = "labels",
     ) -> None:
@@ -53,6 +55,7 @@ class SequenceMultiLabelField(data.Field[torch.Tensor]):
         self._label_namespace = label_namespace
         self._indexed_multi_labels = None
         self._maybe_warn_for_namespace(label_namespace)
+        self.as_tensor_wrapper = as_tensor(self)
         if len(multi_labels) != sequence_field.sequence_length():
             raise checks.ConfigurationError(
                 "Label length and sequence length "
@@ -101,7 +104,7 @@ class SequenceMultiLabelField(data.Field[torch.Tensor]):
 
         indexed = []
         for multi_label in self.multi_labels:
-            indexed.append(indexer(multi_label))
+            indexed.append(indexer(multi_label, len(self.multi_labels)))
         self._indexed_multi_labels = indexed
 
     @overrides
@@ -110,19 +113,13 @@ class SequenceMultiLabelField(data.Field[torch.Tensor]):
 
     @overrides
     def as_tensor(self, padding_lengths: Dict[str, int]) -> torch.Tensor:
-        desired_num_tokens = padding_lengths["num_tokens"]
-        assert len(self._indexed_multi_labels) > 0
-        classes_count = len(self._indexed_multi_labels[0])
-        default_value = [0.0] * classes_count
-        padded_tags = util.pad_sequence_to_length(self._indexed_multi_labels, desired_num_tokens, lambda: default_value)
-        tensor = torch.LongTensor(padded_tags)
-        return tensor
+        return self.as_tensor_wrapper(padding_lengths)
 
     @overrides
     def empty_field(self) -> "SequenceMultiLabelField":
-        # The empty_list here is needed for mypy
         empty_list: List[List[str]] = [[]]
         sequence_label_field = SequenceMultiLabelField(empty_list, lambda x: lambda y: y,
+                                                       lambda x: lambda y: y,
                                                        self.sequence_field.empty_field())
         sequence_label_field._indexed_labels = empty_list
         return sequence_label_field
diff --git a/combo/main.py b/combo/main.py
index 44ad091..3f7fad4 100644
--- a/combo/main.py
+++ b/combo/main.py
@@ -18,7 +18,7 @@ from combo.utils import checks
 
 logger = logging.getLogger(__name__)
 _FEATURES = ["token", "char", "upostag", "xpostag", "lemma", "feats"]
-_TARGETS = ["deprel", "feats", "head", "lemma", "upostag", "xpostag", "semrel", "sent"]
+_TARGETS = ["deprel", "feats", "head", "lemma", "upostag", "xpostag", "semrel", "sent", "deps"]
 
 FLAGS = flags.FLAGS
 flags.DEFINE_enum(name="mode", default=None, enum_values=["train", "predict"],
diff --git a/combo/models/model.py b/combo/models/model.py
index 77b43e3..ec0f113 100644
--- a/combo/models/model.py
+++ b/combo/models/model.py
@@ -53,7 +53,9 @@ class SemanticMultitaskModel(allen_models.Model):
                 feats: torch.Tensor = None,
                 head: torch.Tensor = None,
                 deprel: torch.Tensor = None,
-                semrel: torch.Tensor = None, ) -> Dict[str, torch.Tensor]:
+                semrel: torch.Tensor = None,
+                enhanced_heads: torch.Tensor = None,
+                enhanced_deprels: torch.Tensor = None) -> Dict[str, torch.Tensor]:
 
         # Prepare masks
         char_mask: torch.BoolTensor = sentence["char"]["token_characters"] > 0
diff --git a/tests/data/fields/test_sequence_multilabel_field.py b/tests/data/fields/test_sequence_multilabel_field.py
index d2a1f8b..fff8ff4 100644
--- a/tests/data/fields/test_sequence_multilabel_field.py
+++ b/tests/data/fields/test_sequence_multilabel_field.py
@@ -4,6 +4,7 @@ from typing import List
 
 import torch
 from allennlp import data as allen_data
+from allennlp.common import util
 from allennlp.data import fields as allen_fields
 
 from combo.data import fields
@@ -22,7 +23,7 @@ class IndexingSequenceMultiLabelFieldTest(unittest.TestCase):
         def _indexer(vocab: allen_data.Vocabulary):
             vocab_size = vocab.get_vocab_size(self.namespace)
 
-            def _mapper(multi_label: List[str]) -> List[int]:
+            def _mapper(multi_label: List[str], _: int) -> List[int]:
                 one_hot = [0] * vocab_size
                 for label in multi_label:
                     index = vocab.get_token_index(label, self.namespace)
@@ -31,7 +32,21 @@ class IndexingSequenceMultiLabelFieldTest(unittest.TestCase):
 
             return _mapper
 
+        def _as_tensor(field: fields.SequenceMultiLabelField):
+
+            def _wrapped(padding_lengths):
+                desired_num_tokens = padding_lengths["num_tokens"]
+                classes_count = len(field._indexed_multi_labels[0])
+                default_value = [0.0] * classes_count
+                padded_tags = util.pad_sequence_to_length(field._indexed_multi_labels, desired_num_tokens,
+                                                          lambda: default_value)
+                tensor = torch.LongTensor(padded_tags)
+                return tensor
+
+            return _wrapped
+
         self.indexer = _indexer
+        self.as_tensor = _as_tensor
         self.sequence_field = _SequenceFieldTestWrapper(self.vocab.get_vocab_size(self.namespace))
 
     def test_indexing(self):
@@ -39,6 +54,7 @@ class IndexingSequenceMultiLabelFieldTest(unittest.TestCase):
         field = fields.SequenceMultiLabelField(
             multi_labels=[["t1", "t2"], [], ["t0"]],
             multi_label_indexer=self.indexer,
+            as_tensor=self.as_tensor,
             sequence_field=self.sequence_field,
             label_namespace=self.namespace
         )
@@ -55,6 +71,7 @@ class IndexingSequenceMultiLabelFieldTest(unittest.TestCase):
         field = fields.SequenceMultiLabelField(
             multi_labels=[["t1", "t2"], [], ["t0"]],
             multi_label_indexer=self.indexer,
+            as_tensor=self.as_tensor,
             sequence_field=self.sequence_field,
             label_namespace=self.namespace
         )
@@ -72,6 +89,7 @@ class IndexingSequenceMultiLabelFieldTest(unittest.TestCase):
         field = fields.SequenceMultiLabelField(
             multi_labels=[["t1", "t2"], [], ["t0"]],
             multi_label_indexer=self.indexer,
+            as_tensor=self.as_tensor,
             sequence_field=self.sequence_field,
             label_namespace=self.namespace
         )
-- 
GitLab

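[Note on patch 01] The refactor above makes SequenceMultiLabelField delegate-driven: both indexing and
tensorization are injected as nested callables instead of being hard-coded in the field. A minimal sketch
of the two delegates, mirroring the test fixture added in this patch (the vocabulary and the
"feats_labels" namespace are assumed):

    import torch
    from allennlp.common import util

    def feats_indexer(vocab):
        # Indexer delegate: vocab -> (labels, sentence_length) -> multi-hot row.
        vocab_size = vocab.get_vocab_size("feats_labels")

        def mapper(multi_label, _sentence_length):
            row = [0] * vocab_size
            for label in multi_label:
                row[vocab.get_token_index(label, "feats_labels")] = 1
            return row

        return mapper

    def feats_as_tensor(field):
        # Tensorization delegate: field -> padding_lengths -> LongTensor
        # of shape (num_tokens, vocab_size).
        def wrapped(padding_lengths):
            classes_count = len(field._indexed_multi_labels[0])
            padded = util.pad_sequence_to_length(
                field._indexed_multi_labels,
                padding_lengths["num_tokens"],
                lambda: [0] * classes_count,
            )
            return torch.LongTensor(padded)

        return wrapped

Both are passed to the constructor (multi_label_indexer=feats_indexer, as_tensor=feats_as_tensor),
exactly as the updated tests do.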

From 07ffff19ad790c8c9a52e27b815320e781cc9eab Mon Sep 17 00:00:00 2001
From: Mateusz Klimaszewski <mk.klimaszewski@gmail.com>
Date: Wed, 23 Sep 2020 10:30:02 +0200
Subject: [PATCH 02/19] Fix adding enhanced labels.

---
 combo/data/dataset.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/combo/data/dataset.py b/combo/data/dataset.py
index 0b53df3..2877a65 100644
--- a/combo/data/dataset.py
+++ b/combo/data/dataset.py
@@ -124,7 +124,7 @@ class UniversalDependenciesDatasetReader(allen_data.DatasetReader):
                             enhanced_deprels.append(deprels[idx])
                             t_deps = t["deps"]
                             if t_deps and t_deps != "_":
-                                t_heads, t_deprels = zip(*[tuple(d.split(":")) for d in t_deps.split("|")])
+                                t_deprels, t_heads = zip(*t_deps)
                                 enhanced_heads.extend([(idx, t) for t in t_heads])
                                 enhanced_deprels.extend(t_deprels)
                         fields_["enhanced_heads"] = allen_fields.AdjacencyField(
@@ -137,7 +137,8 @@ class UniversalDependenciesDatasetReader(allen_data.DatasetReader):
                             indices=enhanced_heads,
                             sequence_field=text_field,
                             labels=enhanced_deprels,
-                            label_namespace="enhanced_deprels_labels",
+                            # Label namespace should match regular tree parsing.
+                            label_namespace="deprel_labels",
                             padding_value=0,
                         )
                     else:
-- 
GitLab

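[Note on patch 02] The fix works because the conllu library already parses the DEPS column into
(deprel, head) pairs, so the manual split on ':' and '|' was redundant and, worse, yielded the pair
in (head, deprel) order. A quick check, assuming a recent conllu version:

    import conllu

    sentence = conllu.parse(
        "1\tHe\the\tPRON\t_\t_\t2\tnsubj\t2:nsubj|4:nsubj\t_\n"
        "2\tcame\tcome\tVERB\t_\t_\t0\troot\t0:root\t_\n"
    )[0]
    print(sentence[0]["deps"])  # [('nsubj', 2), ('nsubj', 4)] -- (deprel, head) pairs

Hence `t_deprels, t_heads = zip(*t_deps)`, with the relations first.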

From 6d373257faa6c1f4dd51ae679d158b7c46512e47 Mon Sep 17 00:00:00 2001
From: Mateusz Klimaszewski <mk.klimaszewski@gmail.com>
Date: Wed, 23 Sep 2020 10:32:07 +0200
Subject: [PATCH 03/19] Make training pass for enhanced dependency parsing.

---
 combo/data/dataset.py         |  30 ++-
 combo/main.py                 |   2 +
 combo/models/__init__.py      |   1 +
 combo/models/graph_parser.py  | 191 +++++++++++++++
 combo/models/model.py         |  18 +-
 combo/predict.py              |  13 +-
 combo/utils/graph.py          |  82 +++++++
 combo/utils/metrics.py        |  21 +-
 config.graph.template.jsonnet | 421 ++++++++++++++++++++++++++++++++++
 setup.py                      |   2 +-
 tests/fixtures/example.conllu |   7 +
 tests/utils/test_graph.py     |  89 +++++++
 tests/utils/test_metrics.py   |   8 +-
 13 files changed, 866 insertions(+), 19 deletions(-)
 create mode 100644 combo/models/graph_parser.py
 create mode 100644 combo/utils/graph.py
 create mode 100644 config.graph.template.jsonnet
 create mode 100644 tests/utils/test_graph.py

diff --git a/combo/data/dataset.py b/combo/data/dataset.py
index 2877a65..fb770f6 100644
--- a/combo/data/dataset.py
+++ b/combo/data/dataset.py
@@ -1,3 +1,4 @@
+import copy
 import logging
 from typing import Union, List, Dict, Iterable, Optional, Any, Tuple
 
@@ -115,29 +116,34 @@ class UniversalDependenciesDatasetReader(allen_data.DatasetReader):
                         fields_[target_name] = allen_fields.SequenceLabelField(target_values, text_field,
                                                                                label_namespace=target_name + "_labels")
                     elif target_name == "deps":
-                        heads = [0 if t["head"] == "_" else int(t["head"]) for t in tree_tokens]
-                        deprels = [t["deprel"] for t in tree_tokens]
+                        # Graphs require adding ROOT (AdjacencyField uses sequence length from TextField).
+                        text_field_deps = copy.deepcopy(text_field)
+                        text_field_deps.tokens.insert(0, _Token("ROOT"))
                         enhanced_heads: List[Tuple[int, int]] = []
                         enhanced_deprels: List[str] = []
                         for idx, t in enumerate(tree_tokens):
-                            enhanced_heads.append((idx, heads[idx]))
-                            enhanced_deprels.append(deprels[idx])
                             t_deps = t["deps"]
                             if t_deps and t_deps != "_":
-                                t_deprels, t_heads = zip(*t_deps)
-                                enhanced_heads.extend([(idx, t) for t in t_heads])
-                                enhanced_deprels.extend(t_deprels)
+                                for rel, head in t_deps:
+                                    # EmoryNLP skips the first edge if there are two edges between the same
+                                    # nodes; this way one edge stays in the tree and the other in the graph.
+                                    # This snippet follows that approach.
+                                    if enhanced_heads and enhanced_heads[-1] == (idx, head):
+                                        enhanced_heads.pop()
+                                        enhanced_deprels.pop()
+                                    enhanced_heads.append((idx, head))
+                                    enhanced_deprels.append(rel)
                         fields_["enhanced_heads"] = allen_fields.AdjacencyField(
                             indices=enhanced_heads,
-                            sequence_field=text_field,
+                            sequence_field=text_field_deps,
                             label_namespace="enhanced_heads_labels",
                             padding_value=0,
                         )
                         fields_["enhanced_deprels"] = allen_fields.AdjacencyField(
                             indices=enhanced_heads,
-                            sequence_field=text_field,
+                            sequence_field=text_field_deps,
                             labels=enhanced_deprels,
-                            # Label namespace should match regular tree parsing.
+                            # Label namespace matches regular tree parsing.
                             label_namespace="deprel_labels",
                             padding_value=0,
                         )
@@ -160,7 +166,9 @@ class UniversalDependenciesDatasetReader(allen_data.DatasetReader):
                 token["feats"] = field
 
         # metadata
-        fields_["metadata"] = allen_fields.MetadataField({"input": tree, "field_names": self.fields})
+        fields_["metadata"] = allen_fields.MetadataField({"input": tree,
+                                                          "field_names": self.fields,
+                                                          "tokens": tokens})
 
         return allen_data.Instance(fields_)
 
diff --git a/combo/main.py b/combo/main.py
index 3f7fad4..374af69 100644
--- a/combo/main.py
+++ b/combo/main.py
@@ -33,8 +33,10 @@ flags.DEFINE_string(name="output_file", default="output.log",
 # Training flags
 flags.DEFINE_list(name="training_data_path", default="./tests/fixtures/example.conllu",
                   help="Training data path(s)")
+flags.DEFINE_alias(name="training_data", original_name="training_data_path")
 flags.DEFINE_list(name="validation_data_path", default="",
                   help="Validation data path(s)")
+flags.DEFINE_alias(name="validation_data", original_name="validation_data_path")
 flags.DEFINE_string(name="pretrained_tokens", default="",
                     help="Pretrained tokens embeddings path")
 flags.DEFINE_integer(name="embedding_dim", default=300,
diff --git a/combo/models/__init__.py b/combo/models/__init__.py
index 5aa7b28..ba8d617 100644
--- a/combo/models/__init__.py
+++ b/combo/models/__init__.py
@@ -1,5 +1,6 @@
 """Models module."""
 from .base import FeedForwardPredictor
+from .graph_parser import GraphDependencyRelationModel
 from .parser import DependencyRelationModel
 from .embeddings import CharacterBasedWordEmbeddings
 from .encoder import ComboEncoder
diff --git a/combo/models/graph_parser.py b/combo/models/graph_parser.py
new file mode 100644
index 0000000..a31e6d0
--- /dev/null
+++ b/combo/models/graph_parser.py
@@ -0,0 +1,191 @@
+"""Enhanced dependency parsing models."""
+from typing import Tuple, Dict, Optional, Union, List
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from allennlp import data
+from allennlp.nn import chu_liu_edmonds
+
+from combo.models import base, utils
+
+
+class GraphHeadPredictionModel(base.Predictor):
+    """Head prediction model."""
+
+    def __init__(self,
+                 head_projection_layer: base.Linear,
+                 dependency_projection_layer: base.Linear,
+                 cycle_loss_n: int = 0,
+                 graph_weighting: float = 0.2):
+        super().__init__()
+        self.head_projection_layer = head_projection_layer
+        self.dependency_projection_layer = dependency_projection_layer
+        self.cycle_loss_n = cycle_loss_n
+        self.graph_weighting = graph_weighting
+
+    def forward(self,
+                x: Union[torch.Tensor, List[torch.Tensor]],
+                labels: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
+                mask: Optional[torch.BoolTensor] = None,
+                sample_weights: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None) -> Dict[str, torch.Tensor]:
+        if mask is None:
+            mask = x.new_ones(x.size()[-1])
+        heads_labels = None
+        if labels is not None and labels[0] is not None:
+            heads_labels = labels
+
+        head_arc_emb = self.head_projection_layer(x)
+        dep_arc_emb = self.dependency_projection_layer(x)
+        x = dep_arc_emb.bmm(head_arc_emb.transpose(2, 1))
+        pred = x.sigmoid() > 0.5
+
+        output = {
+            "prediction": pred,
+            "probability": x
+        }
+
+        if heads_labels is not None:
+            if sample_weights is None:
+                sample_weights = heads_labels.new_ones([mask.size(0)])
+            output["loss"], output["cycle_loss"] = self._loss(x, heads_labels, mask, sample_weights)
+
+        return output
+
+    def _cycle_loss(self, pred: torch.Tensor):
+        BATCH_SIZE, _, _ = pred.size()
+        loss = pred.new_zeros(BATCH_SIZE)
+        # Index from 1 to skip the __ROOT__ token
+        pred = pred.softmax(-1)[:, 1:, 1:]
+        x = pred
+        for i in range(self.cycle_loss_n):
+            loss += self._batch_trace(x)
+
+            # Don't multiply on the last iteration
+            if i < self.cycle_loss_n - 1:
+                x = x.bmm(pred)
+
+        return loss
+
+    @staticmethod
+    def _batch_trace(x: torch.Tensor) -> torch.Tensor:
+        assert len(x.size()) == 3
+        BATCH_SIZE, N, M = x.size()
+        assert N == M
+        identity = x.new_tensor(torch.eye(N))
+        identity = identity.reshape((1, N, N))
+        batch_identity = identity.repeat(BATCH_SIZE, 1, 1)
+        return (x * batch_identity).sum((-1, -2))
+
+    def _loss(self, pred: torch.Tensor, labels: torch.Tensor,  mask: torch.BoolTensor,
+              sample_weights: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        BATCH_SIZE, N, M = pred.size()
+        assert N == M
+        SENTENCE_LENGTH = N
+
+        valid_positions = mask.sum()
+
+        result = []
+        true = labels
+        # Ignore the first pred dimension as it is the ROOT token prediction
+        for i in range(SENTENCE_LENGTH - 1):
+            pred_i = pred[:, i + 1, 1:].reshape(-1)
+            true_i = true[:, i + 1, 1:].reshape(-1)
+            mask_i = mask[:, i]
+            bce_loss = F.binary_cross_entropy_with_logits(pred_i, true_i, reduction="none").mean(-1) * mask_i
+            result.append(bce_loss)
+        cycle_loss = self._cycle_loss(pred)
+        loss = torch.stack(result).transpose(1, 0) * sample_weights.unsqueeze(-1)
+        return loss.sum() / valid_positions + cycle_loss.mean(), cycle_loss.mean()
+
+
+@base.Predictor.register("combo_graph_dependency_parsing_from_vocab", constructor="from_vocab")
+class GraphDependencyRelationModel(base.Predictor):
+    """Dependency relation parsing model."""
+
+    def __init__(self,
+                 head_predictor: GraphHeadPredictionModel,
+                 head_projection_layer: base.Linear,
+                 dependency_projection_layer: base.Linear,
+                 relation_prediction_layer: base.Linear):
+        super().__init__()
+        self.head_predictor = head_predictor
+        self.head_projection_layer = head_projection_layer
+        self.dependency_projection_layer = dependency_projection_layer
+        self.relation_prediction_layer = relation_prediction_layer
+
+    def forward(self,
+                x: Union[torch.Tensor, List[torch.Tensor]],
+                mask: Optional[torch.BoolTensor] = None,
+                labels: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
+                sample_weights: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None) -> Dict[str, torch.Tensor]:
+        # if mask is not None:
+        #     mask = mask[:, 1:]
+        relations_labels, head_labels, enhanced_heads_labels, enhanced_deprels_labels = None, None, None, None
+        if labels is not None and labels[0] is not None:
+            relations_labels, head_labels, enhanced_heads_labels = labels
+            # if mask is None:
+            #     mask = head_labels.new_ones(head_labels.size())
+
+        head_output = self.head_predictor(x, enhanced_heads_labels, mask, sample_weights)
+        head_pred = head_output["probability"]
+        BATCH_SIZE, LENGTH, _ = head_pred.size()
+
+        head_rel_emb = self.head_projection_layer(x)
+
+        dep_rel_emb = self.dependency_projection_layer(x)
+
+        # All possible edges combinations for each batch
+        # Repeat interleave to have [emb1, emb1 ... (length times) ... emb1, emb2 ... ]
+        head_rel_pred = head_rel_emb.repeat_interleave(LENGTH, -2)
+        # Regular repeat to have all combinations [deprel1, deprel2, ... deprelL, deprel1 ...]
+        dep_rel_pred = dep_rel_emb.repeat(1, LENGTH, 1)
+
+        # All possible edges combinations for each batch
+        dep_rel_pred = torch.cat((head_rel_pred, dep_rel_pred), dim=-1)
+
+        relation_prediction = self.relation_prediction_layer(dep_rel_pred).reshape(BATCH_SIZE, LENGTH, LENGTH, -1)
+        output = head_output
+
+        output["prediction"] = (relation_prediction.argmax(-1), head_output["prediction"])
+
+        if labels is not None and labels[0] is not None:
+            if sample_weights is None:
+                sample_weights = labels.new_ones([mask.size(0)])
+            loss = self._loss(relation_prediction, relations_labels, enhanced_heads_labels, mask, sample_weights)
+            output["loss"] = (loss, head_output["loss"])
+
+        return output
+
+    @staticmethod
+    def _loss(pred: torch.Tensor,
+              true: torch.Tensor,
+              heads_true: torch.Tensor,
+              mask: torch.BoolTensor,
+              sample_weights: torch.Tensor) -> torch.Tensor:
+
+        true = true[true.long() > 0]
+        pred = pred[heads_true.long() == 1]
+        loss = F.cross_entropy(pred, true.long())
+        return loss.sum() / pred.size(0)
+
+    @classmethod
+    def from_vocab(cls,
+                   vocab: data.Vocabulary,
+                   vocab_namespace: str,
+                   head_predictor: GraphHeadPredictionModel,
+                   head_projection_layer: base.Linear,
+                   dependency_projection_layer: base.Linear
+                   ):
+        """Creates parser combining model configuration and vocabulary data."""
+        assert vocab_namespace in vocab.get_namespaces()
+        relation_prediction_layer = base.Linear(
+            in_features=head_projection_layer.get_output_dim() + dependency_projection_layer.get_output_dim(),
+            out_features=vocab.get_vocab_size(vocab_namespace)
+        )
+        return cls(
+            head_predictor=head_predictor,
+            head_projection_layer=head_projection_layer,
+            dependency_projection_layer=dependency_projection_layer,
+            relation_prediction_layer=relation_prediction_layer
+        )
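[Note on graph_parser.py] The all-pairs construction in GraphDependencyRelationModel.forward is worth a
tiny illustration: repeat_interleave on the head embeddings and a plain repeat on the dependent embeddings
line the two tensors up so that concatenating them enumerates every (head, dependent) combination. A
minimal sketch with toy sizes (L and D are illustrative):

    import torch

    L, D = 3, 2  # sentence length, embedding dim
    head = torch.arange(L * D, dtype=torch.float).reshape(1, L, D)   # fake head embeddings
    dep = -torch.arange(L * D, dtype=torch.float).reshape(1, L, D)   # fake dependent embeddings

    pairs = torch.cat(
        (head.repeat_interleave(L, -2),  # h0,h0,h0, h1,h1,h1, h2,h2,h2
         dep.repeat(1, L, 1)),           # d0,d1,d2, d0,d1,d2, d0,d1,d2
        dim=-1,
    )
    print(pairs.shape)  # torch.Size([1, 9, 4]); reshape to (1, L, L, 4) to score every edge

Reshaping the scores to (BATCH_SIZE, LENGTH, LENGTH, -1), as the model does, then gives one relation
distribution per candidate edge.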
diff --git a/combo/models/model.py b/combo/models/model.py
index ec0f113..124f49a 100644
--- a/combo/models/model.py
+++ b/combo/models/model.py
@@ -27,6 +27,7 @@ class SemanticMultitaskModel(allen_models.Model):
                  semantic_relation: Optional[base.Predictor] = None,
                  morphological_feat: Optional[base.Predictor] = None,
                  dependency_relation: Optional[base.Predictor] = None,
+                 enhanced_dependency_relation: Optional[base.Predictor] = None,
                  regularizer: allen_nn.RegularizerApplicator = None) -> None:
         super().__init__(vocab, regularizer)
         self.text_field_embedder = text_field_embedder
@@ -39,6 +40,7 @@ class SemanticMultitaskModel(allen_models.Model):
         self.semantic_relation = semantic_relation
         self.morphological_feat = morphological_feat
         self.dependency_relation = dependency_relation
+        self.enhanced_dependency_relation = enhanced_dependency_relation
         self._head_sentinel = torch.nn.Parameter(torch.randn([1, 1, self.seq_encoder.get_output_dim()]))
         self.scores = metrics.SemanticMetrics()
         self._partial_losses = None
@@ -96,7 +98,7 @@ class SemanticMultitaskModel(allen_models.Model):
                                        sample_weights=sample_weights)
         lemma_output = self._optional(self.lemmatizer,
                                       (encoder_emb[:, 1:], sentence.get("char").get("token_characters")
-                                       if sentence.get("char") else None),
+                                      if sentence.get("char") else None),
                                       mask=word_mask[:, 1:],
                                       labels=lemma.get("char").get("token_characters") if lemma else None,
                                       sample_weights=sample_weights)
@@ -106,7 +108,14 @@ class SemanticMultitaskModel(allen_models.Model):
                                        mask=word_mask,
                                        labels=(deprel, head),
                                        sample_weights=sample_weights)
+        enhanced_parser_output = self._optional(self.enhanced_dependency_relation,
+                                                encoder_emb,
+                                                returns_tuple=True,
+                                                mask=word_mask,
+                                                labels=(enhanced_deprels, head, enhanced_heads),
+                                                sample_weights=sample_weights)
         relations_pred, head_pred = parser_output["prediction"]
+        enhanced_relations_pred, enhanced_head_pred = enhanced_parser_output["prediction"]
         output = {
             "upostag": upos_output["prediction"],
             "xpostag": xpos_output["prediction"],
@@ -115,6 +124,8 @@ class SemanticMultitaskModel(allen_models.Model):
             "lemma": lemma_output["prediction"],
             "head": head_pred,
             "deprel": relations_pred,
+            "enhanced_head": enhanced_head_pred,
+            "enhanced_deprel": enhanced_relations_pred,
             "sentence_embedding": torch.max(encoder_emb[:, 1:], dim=1)[0],
         }
 
@@ -136,9 +147,12 @@ class SemanticMultitaskModel(allen_models.Model):
                 "lemma": lemma.get("char").get("token_characters") if lemma else None,
                 "head": head,
                 "deprel": deprel,
+                "enhanced_head": enhanced_heads,
+                "enhanced_deprel": enhanced_deprels,
             }
             self.scores(output, labels, word_mask[:, 1:])
             relations_loss, head_loss = parser_output["loss"]
+            enhanced_relations_loss, enhanced_head_loss = enhanced_parser_output["loss"]
             losses = {
                 "upostag_loss": upos_output["loss"],
                 "xpostag_loss": xpos_output["loss"],
@@ -147,6 +161,8 @@ class SemanticMultitaskModel(allen_models.Model):
                 "lemma_loss": lemma_output["loss"],
                 "head_loss": head_loss,
                 "deprel_loss": relations_loss,
+                "enhanced_head_loss": enhanced_head_loss,
+                "enhanced_deprel_loss": enhanced_relations_loss,
                 # Cycle loss is only for the metrics purposes.
                 "cycle_loss": parser_output.get("cycle_loss")
             }
diff --git a/combo/predict.py b/combo/predict.py
index b6c7172..55f78c3 100644
--- a/combo/predict.py
+++ b/combo/predict.py
@@ -154,6 +154,10 @@ class SemanticMultitaskPredictor(predictor.Predictor):
                         token[field_name] = value
                     elif field_name in ["head"]:
                         token[field_name] = int(predictions[field_name][idx])
+                    elif field_name == "deps":
+                        # Handled after every other decoding
+                        continue
+
                     elif field_name in ["feats"]:
                         slices = self._model.morphological_feat.slices
                         features = []
@@ -171,8 +175,6 @@ class SemanticMultitaskPredictor(predictor.Predictor):
                             field_value = "|".join(sorted(features))
 
                         token[field_name] = field_value
-                    elif field_name == "head":
-                        pass
                     elif field_name == "lemma":
                         prediction = predictions[field_name][idx]
                         word_chars = []
@@ -191,6 +193,13 @@ class SemanticMultitaskPredictor(predictor.Predictor):
                     else:
                         raise NotImplementedError(f"Unknown field name {field_name}!")
 
+        if "enhanced_head" in predictions and predictions["enhanced_head"]:
+            import combo.utils.graph as graph
+            tree = graph.sdp_to_dag_deps(arc_scores=predictions["enhanced_head"],
+                                         rel_scores=predictions["enhanced_deprel"],
+                                         tree=tree,
+                                         root_label="ROOT")
+
         return tree, predictions["sentence_embedding"]
 
     @classmethod
diff --git a/combo/utils/graph.py b/combo/utils/graph.py
new file mode 100644
index 0000000..5970b19
--- /dev/null
+++ b/combo/utils/graph.py
@@ -0,0 +1,82 @@
+"""Based on https://github.com/emorynlp/iwpt-shared-task-2020."""
+import numpy as np
+from conllu import TokenList
+
+
+def sdp_to_dag_deps(arc_scores, rel_scores, tree: TokenList, root_label):
+    # adding ROOT
+    tree_tokens = tree.tokens
+    tree_heads = [0] + [t["head"] for t in tree_tokens]
+    graph = adjust_root_score_then_add_secondary_arcs(arc_scores, rel_scores, tree_heads,
+                                                      root_label)
+    for i, (t, g) in enumerate(zip(tree_heads, graph)):
+        if not i:
+            continue
+        rels = [x[1] for x in g]
+        heads = [x[0] for x in g]
+        head = tree_tokens[i - 1]["head"]
+        index = heads.index(head)
+        deprel = tree_tokens[i - 1]["deprel"]
+        deprel = deprel.split('>')[-1]
+        # TODO is this necessary?
+        if len(heads) >= 2:
+            heads.pop(index)
+            rels.pop(index)
+        deps = '|'.join(f'{h}:{r}' for h, r in zip(heads, rels))
+        tree_tokens[i - 1]["deps"] = deps
+        tree_tokens[i - 1]["deprel"] = deprel
+    return tree
+
+
+def adjust_root_score_then_add_secondary_arcs(arc_scores, rel_labels, tree_heads, root_idx):
+    if len(arc_scores) != len(tree_heads):
+        arc_scores = arc_scores[:len(tree_heads), :len(tree_heads)]
+        rel_labels = rel_labels[:len(tree_heads), :len(tree_heads)]
+    parse_preds = arc_scores > 0
+    parse_preds[:, 0] = False  # never predict ROOT as the head of a secondary arc
+    # rel_labels[:, :, root_idx] = -float('inf')
+    return add_secondary_arcs(arc_scores, rel_labels, tree_heads, root_idx, parse_preds)
+
+
+def add_secondary_arcs(arc_scores, rel_labels, tree_heads, root_label, parse_preds):
+    if not isinstance(tree_heads, np.ndarray):
+        tree_heads = np.array(tree_heads)
+    dh = np.argwhere(parse_preds)
+    sdh = sorted([(arc_scores[x[0], x[1]], list(x)) for x in dh], reverse=True)
+    graph = [[] for _ in range(len(tree_heads))]
+    for d, h in enumerate(tree_heads):
+        if d:
+            graph[h].append(d)
+    for s, (d, h) in sdh:
+        if not d or not h or d in graph[h]:
+            continue
+        try:
+            path = next(_dfs(graph, d, h))
+        except StopIteration:
+            # no path from d to h
+            graph[h].append(d)
+    parse_graph = [[] for _ in range(len(tree_heads))]
+    num_root = 0
+    for h in range(len(tree_heads)):
+        for d in graph[h]:
+            rel = rel_labels[d, h]
+            if h == 0:
+                rel = root_label
+                assert num_root == 0
+                num_root += 1
+            parse_graph[d].append((h, rel))
+        parse_graph[d] = sorted(parse_graph[d])
+    return parse_graph
+
+
+def _dfs(graph, start, end):
+    fringe = [(start, [])]
+    while fringe:
+        state, path = fringe.pop()
+        if path and state == end:
+            yield path
+            continue
+        for next_state in graph[state]:
+            if next_state in path:
+                continue
+            fringe.append((next_state, path + [next_state]))
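[Note on graph.py] The DEPS column that sdp_to_dag_deps writes back is the standard CoNLL-U serialization
of the per-token edge list: head:rel pairs joined with '|' (the CoNLL-U convention uses '_' when a token
has no enhanced edges). A self-contained sketch of that final step (the helper name is illustrative):

    def serialize_deps(edges):
        """Render a token's [(head, deprel), ...] edge list as a CoNLL-U DEPS value."""
        return "|".join(f"{h}:{r}" for h, r in edges) or "_"

    print(serialize_deps([(2, "nsubj"), (4, "nsubj")]))  # 2:nsubj|4:nsubj
    print(serialize_deps([]))                            # _ (no enhanced edges)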
diff --git a/combo/utils/metrics.py b/combo/utils/metrics.py
index 28f8efa..ae73db8 100644
--- a/combo/utils/metrics.py
+++ b/combo/utils/metrics.py
@@ -117,6 +117,8 @@ class AttachmentScores(metrics.Metric):
         mask : `torch.BoolTensor`, optional (default = None).
             A tensor of the same shape as `predicted_indices`.
         """
+        if gold_labels is None or gold_indices is None:
+            return
         detached = self.detach_tensors(
             predicted_indices, predicted_labels, gold_indices, gold_labels, mask
         )
@@ -198,6 +200,7 @@ class SemanticMetrics(metrics.Metric):
         self.feats_score = SequenceBoolAccuracy(prod_last_dim=True)
         self.lemma_score = SequenceBoolAccuracy(prod_last_dim=True)
         self.attachment_scores = AttachmentScores()
+        self.enhanced_attachment_scores = AttachmentScores()
         self.em_score = 0.0
 
     def __call__(  # type: ignore
@@ -215,14 +218,25 @@ class SemanticMetrics(metrics.Metric):
                                gold_labels["head"],
                                gold_labels["deprel"],
                                mask)
+        self.enhanced_attachment_scores(predictions["enhanced_head"],
+                                        predictions["enhanced_deprel"],
+                                        gold_labels["enhanced_head"],
+                                        gold_labels["enhanced_deprel"],
+                                        mask=None)
+        enhanced_indices = (
+            self.enhanced_attachment_scores.correct_indices.reshape(mask.size(0), mask.size(1) + 1, -1)[:, 1:, 1:].sum(
+                -1).reshape(-1).bool()
+            if len(self.enhanced_attachment_scores.correct_indices.size()) > 0
+            else self.enhanced_attachment_scores.correct_indices
+        )
         total = mask.sum()
         correct_indices = (self.upos_score.correct_indices *
                            self.xpos_score.correct_indices *
                            self.semrel_score.correct_indices *
                            self.feats_score.correct_indices *
                            self.lemma_score.correct_indices *
-                           self.attachment_scores.correct_indices
-                           )
+                           self.attachment_scores.correct_indices *
+                           enhanced_indices)
 
         total, correct_indices = self.detach_tensors(total, correct_indices)
         self.em_score = (correct_indices.float().sum() / total).item()
@@ -237,6 +251,8 @@ class SemanticMetrics(metrics.Metric):
             "EM": self.em_score
         }
         metrics_dict.update(self.attachment_scores.get_metric(reset))
+        enhanced_metrics = {f"E{k}": v for k, v in self.enhanced_attachment_scores.get_metric(reset).items()}
+        metrics_dict.update(enhanced_metrics)
         return metrics_dict
 
     def reset(self) -> None:
@@ -246,4 +262,5 @@ class SemanticMetrics(metrics.Metric):
         self.lemma_score.reset()
         self.feats_score.reset()
         self.attachment_scores.reset()
+        self.enhanced_attachment_scores.reset()
         self.em_score = 0.0
diff --git a/config.graph.template.jsonnet b/config.graph.template.jsonnet
new file mode 100644
index 0000000..bdb6f0b
--- /dev/null
+++ b/config.graph.template.jsonnet
@@ -0,0 +1,421 @@
+########################################################################################
+#                                 BASIC configuration                                  #
+########################################################################################
+# Training data path, str
+# Must be in CoNLL-U format (or its extended version with the semantic relation field).
+# Can accept multiple paths concatenated with ',', e.g. "path1,path2"
+local training_data_path = std.extVar("training_data_path");
+# Validation data path, str
+# Can accept multiple paths concatenated with ',', e.g. "path1,path2"
+local validation_data_path = if std.length(std.extVar("validation_data_path")) > 0 then std.extVar("validation_data_path");
+# Path to pretrained tokens, str or null
+local pretrained_tokens = if std.length(std.extVar("pretrained_tokens")) > 0 then std.extVar("pretrained_tokens");
+# Name of pretrained transformer model, str or null
+local pretrained_transformer_name = if std.length(std.extVar("pretrained_transformer_name")) > 0 then std.extVar("pretrained_transformer_name");
+# Learning rate value, float
+local learning_rate = 0.002;
+# Number of epochs, int
+local num_epochs = std.parseInt(std.extVar("num_epochs"));
+# Cuda device id, -1 for cpu, int
+local cuda_device = std.parseInt(std.extVar("cuda_device"));
+# Minimum number of words in batch, int
+local word_batch_size = std.parseInt(std.extVar("word_batch_size"));
+# Features used as input, list of str
+# Choice "upostag", "xpostag", "lemma"
+# Required "token", "char"
+local features = std.split(std.extVar("features"), " ");
+# Targets of the model, list of str
+# Choice "feats", "lemma", "upostag", "xpostag", "semrel". "sent"
+# Required "deprel", "head"
+local targets = std.split(std.extVar("targets"), " ");
+# Word embedding dimension, int
+# If pretrained_tokens is not null, it must match the provided dimensionality
+local embedding_dim = std.parseInt(std.extVar("embedding_dim"));
+# Dropout rate on predictors, float
+# All of the models on top of the encoder use this dropout
+local predictors_dropout = 0.25;
+# Xpostag embedding dimension, int
+# (discarded if xpostag not in features)
+local xpostag_dim = 32;
+# Upostag embedding dimension, int
+# (discarded if upostag not in features)
+local upostag_dim = 32;
+# Feats embedding dimension, int
+# (discarded if feats not in features)
+local feats_dim = 32;
+# Lemma embedding dimension, int
+# (discarded if lemma not in features)
+local lemma_char_dim = 64;
+# Character embedding dim, int
+local char_dim = 64;
+# Word embedding projection dim, int
+local projected_embedding_dim = 100;
+# Loss weights, dict[str, int]
+local loss_weights = {
+    xpostag: 0.05,
+    upostag: 0.05,
+    lemma: 0.05,
+    feats: 0.2,
+    deprel: 0.8,
+    head: 0.2,
+    semrel: 0.05,
+    enhanced_head: 0.2,
+    enhanced_deprel: 0.8,
+};
+# Encoder hidden size, int
+local hidden_size = 512;
+# Number of layers in the encoder, int
+local num_layers = 2;
+# Cycle loss iterations, int
+local cycle_loss_n = 0;
+# Maximum length of the word, int
+# Shorter words are padded, longer - truncated
+local word_length = 30;
+# Whether to use tensorboard, bool
+local use_tensorboard = if std.extVar("use_tensorboard") == "True" then true else false;
+# Path for tensorboard metrics, str
+local metrics_dir = "./runs";
+
+# Helper functions
+local in_features(name) = !(std.length(std.find(name, features)) == 0);
+local in_targets(name) = !(std.length(std.find(name, targets)) == 0);
+local use_transformer = pretrained_transformer_name != null;
+
+# Verify some configuration requirements
+assert in_features("token"): "Key 'token' must be in features!";
+assert in_features("char"): "Key 'char' must be in features!";
+
+assert in_targets("deprel"): "Key 'deprel' must be in targets!";
+assert in_targets("head"): "Key 'head' must be in targets!";
+
+assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't use pretrained tokens and pretrained transformer at the same time!";
+
+########################################################################################
+#                              ADVANCED configuration                                  #
+########################################################################################
+
+# Detailed dataset, training, vocabulary and model configuration.
+{
+    # Configuration type (default or finetuning), str
+    type: std.extVar('type'),
+    # Datasets used for vocab creation, list of str
+    # Choice "train", "valid"
+    datasets_for_vocab_creation: ['train'],
+    # Path to training data, str
+    train_data_path: training_data_path,
+    # Path to validation data, str
+    validation_data_path: validation_data_path,
+    # Dataset reader configuration (conllu format)
+    dataset_reader: {
+        type: "conllu",
+        features: features,
+        targets: targets,
+        # Whether data contains semantic relation field, bool
+        use_sem: if in_targets("semrel") then true else false,
+        token_indexers: {
+            token: if use_transformer then {
+                type: "pretrained_transformer_mismatched",
+                model_name: pretrained_transformer_name,
+            } else {
+                # SingleIdTokenIndexer, token as single int
+                type: "single_id",
+            },
+            upostag: {
+                type: "single_id",
+                namespace: "upostag",
+                feature_name: "pos_",
+            },
+            xpostag: {
+                type: "single_id",
+                namespace: "xpostag",
+                feature_name: "tag_",
+            },
+            lemma: {
+                type: "characters_const_padding",
+                character_tokenizer: {
+                    start_tokens: ["__START__"],
+                    end_tokens: ["__END__"],
+                },
+                # +2 for start and end token
+                min_padding_length: word_length + 2,
+            },
+            char: {
+                type: "characters_const_padding",
+                character_tokenizer: {
+                    start_tokens: ["__START__"],
+                    end_tokens: ["__END__"],
+                },
+                # +2 for start and end token
+                min_padding_length: word_length + 2,
+            },
+            feats: {
+                type: "feats_indexer",
+            },
+        },
+        lemma_indexers: {
+            char: {
+                type: "characters_const_padding",
+                namespace: "lemma_characters",
+                character_tokenizer: {
+                    start_tokens: ["__START__"],
+                    end_tokens: ["__END__"],
+                },
+                # +2 for start and end token
+                min_padding_length: word_length + 2,
+            },
+        },
+    },
+    # Data loader configuration
+    data_loader: {
+        batch_sampler: {
+            type: "token_count",
+            word_batch_size: word_batch_size,
+        },
+    },
+    # Vocabulary configuration
+    vocabulary: std.prune({
+        type: "from_instances_extended",
+        only_include_pretrained_words: true,
+        pretrained_files: {
+            tokens: pretrained_tokens,
+        },
+        oov_token: "_",
+        padding_token: "__PAD__",
+        non_padded_namespaces: ["head_labels"],
+    }),
+    model: std.prune({
+        type: "semantic_multitask",
+        text_field_embedder: {
+            type: "basic",
+            token_embedders: {
+                xpostag: if in_features("xpostag") then {
+                    type: "embedding",
+                    padding_index: 0,
+                    embedding_dim: xpostag_dim,
+                    vocab_namespace: "xpostag",
+                },
+                upostag: if in_features("upostag") then {
+                    type: "embedding",
+                    padding_index: 0,
+                    embedding_dim: upostag_dim,
+                    vocab_namespace: "upostag",
+                },
+                token: if use_transformer then {
+                    type: "transformers_word_embeddings",
+                    model_name: pretrained_transformer_name,
+                    projection_dim: projected_embedding_dim,
+                } else {
+                    type: "embeddings_projected",
+                    embedding_dim: embedding_dim,
+                    projection_layer: {
+                        in_features: embedding_dim,
+                        out_features: projected_embedding_dim,
+                        dropout_rate: 0.25,
+                        activation: "tanh"
+                    },
+                    vocab_namespace: "tokens",
+                    pretrained_file: pretrained_tokens,
+                    trainable: if pretrained_tokens == null then true else false,
+                },
+                char: {
+                    type: "char_embeddings_from_config",
+                    embedding_dim: char_dim,
+                    dilated_cnn_encoder: {
+                        input_dim: char_dim,
+                        filters: [512, 256, char_dim],
+                        kernel_size: [3, 3, 3],
+                        stride: [1, 1, 1],
+                        padding: [1, 2, 4],
+                        dilation: [1, 2, 4],
+                        activations: ["relu", "relu", "linear"],
+                    },
+                },
+                lemma: if in_features("lemma") then {
+                    type: "char_embeddings_from_config",
+                    embedding_dim: lemma_char_dim,
+                    dilated_cnn_encoder: {
+                        input_dim: lemma_char_dim,
+                        filters: [512, 256, lemma_char_dim],
+                        kernel_size: [3, 3, 3],
+                        stride: [1, 1, 1],
+                        padding: [1, 2, 4],
+                        dilation: [1, 2, 4],
+                        activations: ["relu", "relu", "linear"],
+                    },
+                },
+                feats: if in_features("feats") then {
+                    type: "feats_embedding",
+                    padding_index: 0,
+                    embedding_dim: feats_dim,
+                    vocab_namespace: "feats",
+                },
+            },
+        },
+        loss_weights: loss_weights,
+        seq_encoder: {
+            type: "combo_encoder",
+            layer_dropout_probability: 0.33,
+            stacked_bilstm: {
+                input_size:
+                (char_dim + projected_embedding_dim +
+                (if in_features('xpostag') then xpostag_dim else 0) +
+                (if in_features('lemma') then lemma_char_dim else 0) +
+                (if in_features('upostag') then upostag_dim else 0) +
+                (if in_features('feats') then feats_dim else 0)),
+                hidden_size: hidden_size,
+                num_layers: num_layers,
+                recurrent_dropout_probability: 0.33,
+                layer_dropout_probability: 0.33
+            },
+        },
+        dependency_relation: {
+            type: "combo_dependency_parsing_from_vocab",
+            vocab_namespace: 'deprel_labels',
+            head_predictor: {
+                local projection_dim = 512,
+                cycle_loss_n: cycle_loss_n,
+                head_projection_layer: {
+                    in_features: hidden_size * 2,
+                    out_features: projection_dim,
+                    activation: "tanh",
+                },
+                dependency_projection_layer: {
+                    in_features: hidden_size * 2,
+                    out_features: projection_dim,
+                    activation: "tanh",
+                },
+            },
+            local projection_dim = 128,
+            head_projection_layer: {
+                in_features: hidden_size * 2,
+                out_features: projection_dim,
+                dropout_rate: predictors_dropout,
+                activation: "tanh"
+            },
+            dependency_projection_layer: {
+                in_features: hidden_size * 2,
+                out_features: projection_dim,
+                dropout_rate: predictors_dropout,
+                activation: "tanh"
+            },
+        },
+        enhanced_dependency_relation: if in_targets("deps") then {
+            type: "combo_graph_dependency_parsing_from_vocab",
+            vocab_namespace: 'deprel_labels',
+            head_predictor: {
+                local projection_dim = 512,
+                cycle_loss_n: cycle_loss_n,
+                head_projection_layer: {
+                    in_features: hidden_size * 2,
+                    out_features: projection_dim,
+                    activation: "tanh",
+                },
+                dependency_projection_layer: {
+                    in_features: hidden_size * 2,
+                    out_features: projection_dim,
+                    activation: "tanh",
+                },
+            },
+            local projection_dim = 128,
+            head_projection_layer: {
+                in_features: hidden_size * 2,
+                out_features: projection_dim,
+                dropout_rate: predictors_dropout,
+                activation: "tanh"
+            },
+            dependency_projection_layer: {
+                in_features: hidden_size * 2,
+                out_features: projection_dim,
+                dropout_rate: predictors_dropout,
+                activation: "tanh"
+            },
+        },
+        morphological_feat: if in_targets("feats") then {
+            type: "combo_morpho_from_vocab",
+            vocab_namespace: "feats_labels",
+            input_dim: hidden_size * 2,
+            hidden_dims: [128],
+            activations: ["tanh", "linear"],
+            dropout: [predictors_dropout, 0.0],
+            num_layers: 2,
+        },
+        lemmatizer: if in_targets("lemma") then {
+            type: "combo_lemma_predictor_from_vocab",
+            char_vocab_namespace: "token_characters",
+            lemma_vocab_namespace: "lemma_characters",
+            embedding_dim: 256,
+            input_projection_layer: {
+                in_features: hidden_size * 2,
+                out_features: 32,
+                dropout_rate: predictors_dropout,
+                activation: "tanh"
+            },
+            filters: [256, 256, 256],
+            kernel_size: [3, 3, 3, 1],
+            stride: [1, 1, 1, 1],
+            padding: [1, 2, 4, 0],
+            dilation: [1, 2, 4, 1],
+            activations: ["relu", "relu", "relu", "linear"],
+        },
+        upos_tagger: if in_targets("upostag") then {
+            input_dim: hidden_size * 2,
+            hidden_dims: [64],
+            activations: ["tanh", "linear"],
+            dropout: [predictors_dropout, 0.0],
+            num_layers: 2,
+            vocab_namespace: "upostag_labels"
+        },
+        xpos_tagger: if in_targets("xpostag") then {
+            input_dim: hidden_size * 2,
+            hidden_dims: [128],
+            activations: ["tanh", "linear"],
+            dropout: [predictors_dropout, 0.0],
+            num_layers: 2,
+            vocab_namespace: "xpostag_labels"
+        },
+        semantic_relation: if in_targets("semrel") then {
+            input_dim: hidden_size * 2,
+            hidden_dims: [64],
+            activations: ["tanh", "linear"],
+            dropout: [predictors_dropout, 0.0],
+            num_layers: 2,
+            vocab_namespace: "semrel_labels"
+        },
+        regularizer: {
+            regexes: [
+                [".*conv1d.*", {type: "l2", alpha: 1e-6}],
+                [".*forward.*", {type: "l2", alpha: 1e-6}],
+                [".*backward.*", {type: "l2", alpha: 1e-6}],
+                [".*char_embed.*", {type: "l2", alpha: 1e-5}],
+            ],
+        },
+    }),
+    trainer: std.prune({
+        checkpointer: {
+            type: "finishing_only_checkpointer",
+        },
+        type: "gradient_descent_validate_n",
+        cuda_device: cuda_device,
+        grad_clipping: 5.0,
+        num_epochs: num_epochs,
+        optimizer: {
+            type: "adam",
+            lr: learning_rate,
+            betas: [0.9, 0.9],
+        },
+        patience: 1, # it will be overwritten by the callback
+        epoch_callbacks: [
+            { type: "transfer_patience" },
+        ],
+        learning_rate_scheduler: {
+            type: "combo_scheduler",
+        },
+        tensorboard_writer: if use_tensorboard then {
+            serialization_dir: metrics_dir,
+            should_log_learning_rate: false,
+            should_log_parameter_statistics: false,
+            summary_interval: 100,
+        },
+        validation_metric: "+EM",
+    }),
+}
diff --git a/setup.py b/setup.py
index 9529a0c..74540e2 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@ from setuptools import find_packages, setup
 
 REQUIREMENTS = [
     'absl-py==0.9.0',
-    'allennlp==1.2.0',
+    'allennlp==1.2.1',
     'conllu==2.3.2',
     'dataclasses;python_version<"3.7"',
     'dataclasses-json==0.5.2',
diff --git a/tests/fixtures/example.conllu b/tests/fixtures/example.conllu
index 1125392..32e0653 100644
--- a/tests/fixtures/example.conllu
+++ b/tests/fixtures/example.conllu
@@ -4,3 +4,10 @@
 2	Sentence	verylonglemmawhichmustbetruncatedbythesystemto30	NOUN	nom	Number=Sing	0	root	_	_
 3	.	.	PUNCT	.	_	1	punct	_	_
 
+# sent_id = test-s1
+# text = Easy sentence.
+1	Verylongwordwhichmustbetruncatedbythesystemto30	easy	ADJ	adj	AdpType=Prep|Adp	2	amod	_	_
+2	Sentence	verylonglemmawhichmustbetruncatedbythesystemto30	NOUN	nom	Number=Sing	0	root	_	_
+3	.	.	PUNCT	.	_	1	punct	2:mod	_
+4	.	.	PUNCT	.	_	1	punct	2:xmod	_
+
diff --git a/tests/utils/test_graph.py b/tests/utils/test_graph.py
new file mode 100644
index 0000000..4a65182
--- /dev/null
+++ b/tests/utils/test_graph.py
@@ -0,0 +1,89 @@
+import unittest
+import combo.utils.graph as graph
+
+import conllu
+import numpy as np
+
+
+class GraphTest(unittest.TestCase):
+
+    def test_adding_empty_graph_with_the_same_labels(self):
+        tree = conllu.TokenList(
+            tokens=[
+                {"head": 2, "deprel": "ROOT", "form": "word1"},
+                {"head": 3, "deprel": "yes", "form": "word2"},
+                {"head": 1, "deprel": "yes", "form": "word3"},
+            ]
+        )
+        empty_graph = np.zeros((4, 4))
+        graph_labels = np.array([
+            ["no", "no", "no", "no"],
+            ["no", "no", "ROOT", "no"],
+            ["no", "no", "no", "yes"],
+            ["no", "yes", "no", "no"],
+        ])
+        root_label = "ROOT"
+        expected_deps = ["2:ROOT", "3:yes", "1:yes"]
+
+        # when
+        tree = graph.sdp_to_dag_deps(empty_graph, graph_labels, tree, root_label)
+        actual_deps = [t["deps"] for t in tree.tokens]
+
+        # then
+        self.assertEqual(actual_deps, expected_deps)
+
+    def test_adding_empty_graph_with_different_labels(self):
+        tree = conllu.TokenList(
+            tokens=[
+                {"head": 2, "deprel": "ROOT", "form": "word1"},
+                {"head": 3, "deprel": "tree_label", "form": "word2"},
+                {"head": 1, "deprel": "tree_label", "form": "word3"},
+            ]
+        )
+        empty_graph = np.zeros((4, 4))
+        graph_labels = np.array([
+            ["no", "no", "no", "no"],
+            ["no", "no", "ROOT", "no"],
+            ["no", "no", "no", "graph_label"],
+            ["no", "graph_label", "no", "no"],
+        ])
+        root_label = "ROOT"
+        expected_deps = ["2:ROOT", "3:graph_label", "1:graph_label"]
+
+        # when
+        tree = graph.sdp_to_dag_deps(empty_graph, graph_labels, tree, root_label)
+        actual_deps = [t["deps"] for t in tree.tokens]
+
+        # then
+        self.assertEqual(actual_deps, expected_deps)
+
+    def test_extending_tree_with_graph(self):
+        # given
+        tree = conllu.TokenList(
+            tokens=[
+                {"head": 0, "deprel": "ROOT", "form": "word1"},
+                {"head": 1, "deprel": "tree_label", "form": "word2"},
+                {"head": 2, "deprel": "tree_label", "form": "word3"},
+            ]
+        )
+        arc_scores = np.array([
+            [0, 0, 0, 0],
+            [1, 0, 0, 0],
+            [0, 1, 0, 0],
+            [0, 1, 1, 0],
+        ])
+        graph_labels = np.array([
+            ["no", "no", "no", "no"],
+            ["ROOT", "no", "no", "no"],
+            ["no", "tree_label", "no", "no"],
+            ["no", "graph_label", "tree_label", "no"],
+        ])
+        root_label = "ROOT"
+        expected_deps = ["0:ROOT", "1:tree_label", "1:graph_label"]
+
+        # when
+        tree = graph.sdp_to_dag_deps(arc_scores, graph_labels, tree, root_label)
+        actual_deps = [t["deps"] for t in tree.tokens]
+
+        # then
+        self.assertEqual(actual_deps, expected_deps)
\ No newline at end of file
diff --git a/tests/utils/test_metrics.py b/tests/utils/test_metrics.py
index 5b8411b..1d1ad3b 100644
--- a/tests/utils/test_metrics.py
+++ b/tests/utils/test_metrics.py
@@ -27,12 +27,16 @@ class SemanticMetricsTest(unittest.TestCase):
         self.semrel, self.semrel_l = (("semrel", x) for x in [pred, gold])
         self.head, self.head_l = (("head", x) for x in [pred, gold])
         self.deprel, self.deprel_l = (("deprel", x) for x in [pred, gold])
+        # TODO(mklimasz) Add examples with correct dimension (with ROOT token)
+        self.enhanced_head, self.enhanced_head_l = (("enhanced_head", x) for x in [None, None])
+        self.enhanced_deprel, self.enhanced_deprel_l = (("enhanced_deprel", x) for x in [None, None])
         self.feats, self.feats_l = (("feats", x) for x in [pred_seq, gold_seq])
         self.lemma, self.lemma_l = (("lemma", x) for x in [pred_seq, gold_seq])
         self.predictions = dict(
-            [self.upostag, self.xpostag, self.semrel, self.feats, self.lemma, self.head, self.deprel])
+            [self.upostag, self.xpostag, self.semrel, self.feats, self.lemma, self.head, self.deprel,
+             self.enhanced_head, self.enhanced_deprel])
         self.gold_labels = dict([self.upostag_l, self.xpostag_l, self.semrel_l, self.feats_l, self.lemma_l, self.head_l,
-                                 self.deprel_l])
+                                 self.deprel_l, self.enhanced_head_l, self.enhanced_deprel_l])
         self.eps = 1e-6
 
     def test_every_prediction_correct(self):
-- 
GitLab

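The tests above exercise `sdp_to_dag_deps`, which serializes the predicted graph into the CoNLL-U DEPS column as a pipe-separated list of `head:relation` pairs. A minimal sketch of that encoding (assuming plain integer head indices; `parse_deps`/`serialize_deps` are illustrative helpers, not part of the codebase):

```python
from typing import List, Tuple


def parse_deps(deps: str) -> List[Tuple[int, str]]:
    """Split a DEPS value such as '2:ROOT|3:yes' into (head, relation) pairs."""
    pairs = []
    for item in deps.split("|"):
        head, rel = item.split(":", 1)  # split only on the first ':' so labels survive intact
        pairs.append((int(head), rel))
    return pairs


def serialize_deps(pairs: List[Tuple[int, str]]) -> str:
    """Join (head, relation) pairs back into the pipe-separated DEPS format."""
    return "|".join(f"{head}:{rel}" for head, rel in pairs)


assert parse_deps("2:ROOT|3:yes") == [(2, "ROOT"), (3, "yes")]
assert serialize_deps([(2, "ROOT"), (3, "yes")]) == "2:ROOT|3:yes"
```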

From 422d12c61e50f31d7ad8f075d1b21abe6ef94d04 Mon Sep 17 00:00:00 2001
From: Mateusz Klimaszewski <mk.klimaszewski@gmail.com>
Date: Thu, 19 Nov 2020 10:28:44 +0100
Subject: [PATCH 04/19] Working graph decoding.

---
 combo/data/dataset.py         |  2 ++
 combo/predict.py              |  8 +++++---
 combo/utils/graph.py          | 14 +++++++-------
 config.graph.template.jsonnet |  4 +++-
 4 files changed, 17 insertions(+), 11 deletions(-)

diff --git a/combo/data/dataset.py b/combo/data/dataset.py
index fb770f6..bb56ac3 100644
--- a/combo/data/dataset.py
+++ b/combo/data/dataset.py
@@ -54,6 +54,8 @@ class UniversalDependenciesDatasetReader(allen_data.DatasetReader):
         field_parsers = parser.DEFAULT_FIELD_PARSERS
         # Do not make it nullable
         field_parsers.pop("xpostag", None)
+        # Ignore parsing misc
+        field_parsers.pop("misc", None)
         if self.use_sem:
             fields = list(fields)
             fields.append("semrel")
diff --git a/combo/predict.py b/combo/predict.py
index 55f78c3..6d657f6 100644
--- a/combo/predict.py
+++ b/combo/predict.py
@@ -3,6 +3,7 @@ import os
 from typing import List, Union, Tuple
 
 import conllu
+import numpy as np
 from allennlp import data as allen_data, common, models
 from allennlp.common import util
 from allennlp.data import tokenizers
@@ -195,10 +196,11 @@ class SemanticMultitaskPredictor(predictor.Predictor):
 
         if "enhanced_head" in predictions and predictions["enhanced_head"]:
             import combo.utils.graph as graph
-            tree = graph.sdp_to_dag_deps(arc_scores=predictions["enhanced_head"],
-                                         rel_scores=predictions["enhanced_deprel"],
+            tree = graph.sdp_to_dag_deps(arc_scores=np.array(predictions["enhanced_head"]),
+                                         rel_scores=np.array(predictions["enhanced_deprel"]),
                                          tree=tree,
-                                         root_label="ROOT")
+                                         root_label="ROOT",
+                                         vocab_index=self.vocab.get_index_to_token_vocabulary("deprel_labels"))
 
         return tree, predictions["sentence_embedding"]
 
diff --git a/combo/utils/graph.py b/combo/utils/graph.py
index 5970b19..fd59e97 100644
--- a/combo/utils/graph.py
+++ b/combo/utils/graph.py
@@ -3,7 +3,7 @@ import numpy as np
 from conllu import TokenList
 
 
-def sdp_to_dag_deps(arc_scores, rel_scores, tree: TokenList, root_label):
+def sdp_to_dag_deps(arc_scores, rel_scores, tree: TokenList, root_label, vocab_index=None):
     # adding ROOT
     tree_tokens = tree.tokens
     tree_heads = [0] + [t["head"] for t in tree_tokens]
@@ -12,7 +12,7 @@ def sdp_to_dag_deps(arc_scores, rel_scores, tree: TokenList, root_label):
     for i, (t, g) in enumerate(zip(tree_heads, graph)):
         if not i:
             continue
-        rels = [x[1] for x in g]
+        rels = [vocab_index.get(x[1], "ROOT") if vocab_index else x[1] for x in g]
         heads = [x[0] for x in g]
         head = tree_tokens[i - 1]["head"]
         index = heads.index(head)
@@ -30,9 +30,9 @@ def sdp_to_dag_deps(arc_scores, rel_scores, tree: TokenList, root_label):
 
 def adjust_root_score_then_add_secondary_arcs(arc_scores, rel_labels, tree_heads, root_idx):
     if len(arc_scores) != tree_heads:
-        arc_scores = arc_scores[:len(tree_heads), :len(tree_heads)]
-        rel_labels = rel_labels[:len(tree_heads), :len(tree_heads)]
-    parse_preds = arc_scores > 0
+        arc_scores = arc_scores[:len(tree_heads)][:len(tree_heads)]
+        rel_labels = rel_labels[:len(tree_heads)][:len(tree_heads)]
+    parse_preds = np.array(arc_scores) > 0
     parse_preds[:, 0] = False  # set heads to False
     # rel_labels[:, :, root_idx] = -float('inf')
     return add_secondary_arcs(arc_scores, rel_labels, tree_heads, root_idx, parse_preds)
@@ -42,7 +42,7 @@ def add_secondary_arcs(arc_scores, rel_labels, tree_heads, root_label, parse_pre
     if not isinstance(tree_heads, np.ndarray):
         tree_heads = np.array(tree_heads)
     dh = np.argwhere(parse_preds)
-    sdh = sorted([(arc_scores[x[0], x[1]], list(x)) for x in dh], reverse=True)
+    sdh = sorted([(arc_scores[x[0]][x[1]], list(x)) for x in dh], reverse=True)
     graph = [[] for _ in range(len(tree_heads))]
     for d, h in enumerate(tree_heads):
         if d:
@@ -59,7 +59,7 @@ def add_secondary_arcs(arc_scores, rel_labels, tree_heads, root_label, parse_pre
     num_root = 0
     for h in range(len(tree_heads)):
         for d in graph[h]:
-            rel = rel_labels[d, h]
+            rel = rel_labels[d][h]
             if h == 0:
                 rel = root_label
                 assert num_root == 0
diff --git a/config.graph.template.jsonnet b/config.graph.template.jsonnet
index bdb6f0b..d55cb89 100644
--- a/config.graph.template.jsonnet
+++ b/config.graph.template.jsonnet
@@ -114,8 +114,10 @@ assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't
         use_sem: if in_targets("semrel") then true else false,
         token_indexers: {
             token: if use_transformer then {
-                type: "pretrained_transformer_mismatched",
+                type: "pretrained_transformer_mismatched_fixed",
                 model_name: pretrained_transformer_name,
+                tokenizer_kwargs: if std.startsWith(pretrained_transformer_name, "allegro/herbert")
+                                  then {use_fast: false} else {},
             } else {
                 # SingleIdTokenIndexer, token as single int
                 type: "single_id",
-- 
GitLab

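The switch from comma indexing (`arc_scores[x[0], x[1]]`) to chained indexing (`arc_scores[x[0]][x[1]]`) presumably lets the same code accept plain nested lists before the `np.array` conversion. A standalone sketch of the difference, including the caveat that for 2-D slicing the two spellings are not equivalent:

```python
import numpy as np

# Chained element indexing works for nested lists and ndarrays alike;
# comma indexing is ndarray-only.
scores_list = [[0.0, 1.0], [2.0, 3.0]]
scores_arr = np.array(scores_list)
assert scores_list[1][0] == scores_arr[1, 0] == 2.0

# For 2-D slicing, however, a[:n, :n] truncates rows AND columns,
# while a[:n][:n] re-slices the first axis twice and keeps all columns.
a = np.arange(16).reshape(4, 4)
assert a[:2, :2].shape == (2, 2)
assert a[:2][:2].shape == (2, 4)
```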

From a2137989f0a28537f75c193bc2e3915b20ff8bf2 Mon Sep 17 00:00:00 2001
From: Mateusz Klimaszewski <mk.klimaszewski@gmail.com>
Date: Mon, 23 Nov 2020 16:12:47 +0100
Subject: [PATCH 05/19] Fix deps prediction for MWEs (multi-word expressions).

---
 combo/predict.py          | 13 ++++++-------
 combo/utils/graph.py      |  8 ++++----
 tests/utils/test_graph.py |  8 ++++----
 3 files changed, 14 insertions(+), 15 deletions(-)

diff --git a/combo/predict.py b/combo/predict.py
index 6d657f6..42e8bed 100644
--- a/combo/predict.py
+++ b/combo/predict.py
@@ -12,7 +12,7 @@ from overrides import overrides
 
 from combo import data
 from combo.data import sentence2conllu, tokens2conllu, conllu2sentence
-from combo.utils import download
+from combo.utils import download, graph
 
 logger = logging.getLogger(__name__)
 
@@ -195,12 +195,11 @@ class SemanticMultitaskPredictor(predictor.Predictor):
                         raise NotImplementedError(f"Unknown field name {field_name}!")
 
         if "enhanced_head" in predictions and predictions["enhanced_head"]:
-            import combo.utils.graph as graph
-            tree = graph.sdp_to_dag_deps(arc_scores=np.array(predictions["enhanced_head"]),
-                                         rel_scores=np.array(predictions["enhanced_deprel"]),
-                                         tree=tree,
-                                         root_label="ROOT",
-                                         vocab_index=self.vocab.get_index_to_token_vocabulary("deprel_labels"))
+            graph.sdp_to_dag_deps(arc_scores=np.array(predictions["enhanced_head"]),
+                                  rel_scores=np.array(predictions["enhanced_deprel"]),
+                                  tree_tokens=tree_tokens,
+                                  root_label="ROOT",
+                                  vocab_index=self.vocab.get_index_to_token_vocabulary("deprel_labels"))
 
         return tree, predictions["sentence_embedding"]
 
diff --git a/combo/utils/graph.py b/combo/utils/graph.py
index fd59e97..814341b 100644
--- a/combo/utils/graph.py
+++ b/combo/utils/graph.py
@@ -1,11 +1,11 @@
 """Based on https://github.com/emorynlp/iwpt-shared-task-2020."""
+from typing import List
+
 import numpy as np
-from conllu import TokenList
 
 
-def sdp_to_dag_deps(arc_scores, rel_scores, tree: TokenList, root_label, vocab_index=None):
+def sdp_to_dag_deps(arc_scores, rel_scores, tree_tokens: List, root_label, vocab_index=None) -> None:
     # adding ROOT
-    tree_tokens = tree.tokens
     tree_heads = [0] + [t["head"] for t in tree_tokens]
     graph = adjust_root_score_then_add_secondary_arcs(arc_scores, rel_scores, tree_heads,
                                                       root_label)
@@ -25,7 +25,7 @@ def sdp_to_dag_deps(arc_scores, rel_scores, tree: TokenList, root_label, vocab_i
         deps = '|'.join(f'{h}:{r}' for h, r in zip(heads, rels))
         tree_tokens[i - 1]["deps"] = deps
         tree_tokens[i - 1]["deprel"] = deprel
-    return tree
+    return
 
 
 def adjust_root_score_then_add_secondary_arcs(arc_scores, rel_labels, tree_heads, root_idx):
diff --git a/tests/utils/test_graph.py b/tests/utils/test_graph.py
index 4a65182..4c5c7d3 100644
--- a/tests/utils/test_graph.py
+++ b/tests/utils/test_graph.py
@@ -26,7 +26,7 @@ class GraphTest(unittest.TestCase):
         expected_deps = ["2:ROOT", "3:yes", "1:yes"]
 
         # when
-        tree = graph.sdp_to_dag_deps(empty_graph, graph_labels, tree, root_label)
+        graph.sdp_to_dag_deps(empty_graph, graph_labels, tree.tokens, root_label)
         actual_deps = [t["deps"] for t in tree.tokens]
 
         # then
@@ -51,7 +51,7 @@ class GraphTest(unittest.TestCase):
         expected_deps = ["2:ROOT", "3:graph_label", "1:graph_label"]
 
         # when
-        tree = graph.sdp_to_dag_deps(empty_graph, graph_labels, tree, root_label)
+        graph.sdp_to_dag_deps(empty_graph, graph_labels, tree.tokens, root_label)
         actual_deps = [t["deps"] for t in tree.tokens]
 
         # then
@@ -82,8 +82,8 @@ class GraphTest(unittest.TestCase):
         expected_deps = ["0:ROOT", "1:tree_label", "1:graph_label"]
 
         # when
-        tree = graph.sdp_to_dag_deps(arc_scores, graph_labels, tree, root_label)
+        graph.sdp_to_dag_deps(arc_scores, graph_labels, tree.tokens, root_label)
         actual_deps = [t["deps"] for t in tree.tokens]
 
         # then
-        self.assertEqual(actual_deps, expected_deps)
\ No newline at end of file
+        self.assertEqual(actual_deps, expected_deps)
-- 
GitLab

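After this change `sdp_to_dag_deps` mutates the token dictionaries in place instead of returning a tree. A usage sketch mirroring the updated tests (toy scores and labels, no trained model involved):

```python
import numpy as np

import combo.utils.graph as graph

tokens = [
    {"head": 0, "deprel": "ROOT", "form": "word1"},
    {"head": 1, "deprel": "yes", "form": "word2"},
]
arc_scores = np.zeros((3, 3))        # ROOT position included, no extra arcs
rel_labels = np.full((3, 3), "yes")  # per-arc label matrix (pre-patch-07 format)

graph.sdp_to_dag_deps(arc_scores, rel_labels, tokens, root_label="ROOT")
print([t["deps"] for t in tokens])   # ['0:ROOT', '1:yes'] -- filled in place
```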

From 50ea5469b0fc6af570ef4ab4b22f30b27bb31c34 Mon Sep 17 00:00:00 2001
From: Mateusz Klimaszewski <mk.klimaszewski@gmail.com>
Date: Fri, 4 Dec 2020 11:49:15 +0100
Subject: [PATCH 06/19] Sort feats by lowercase to match the IWPT script.

---
 combo/predict.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/combo/predict.py b/combo/predict.py
index 42e8bed..e262b70 100644
--- a/combo/predict.py
+++ b/combo/predict.py
@@ -173,7 +173,9 @@ class SemanticMultitaskPredictor(predictor.Predictor):
                         if len(features) == 0:
                             field_value = "_"
                         else:
-                            field_value = "|".join(sorted(features))
+                            lowercase_features = [f.lower() for f in features]
+                            arg_indices = sorted(range(len(lowercase_features)), key=lowercase_features.__getitem__)
+                            field_value = "|".join(np.array(features)[arg_indices].tolist())
 
                         token[field_name] = field_value
                     elif field_name == "lemma":
-- 
GitLab

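The change replaces a plain lexicographic sort of FEATS with a case-insensitive one, matching the ordering used by the IWPT evaluation script. The same argsort trick in isolation (feature values are illustrative):

```python
features = ["Number=Sing", "Case=Nom", "aspect=Imp"]

# sorted(features) would put "aspect=Imp" last (ASCII order: upper < lower);
# sorting indices by the lowercased strings keeps the original casing in the output.
lowercase_features = [f.lower() for f in features]
arg_indices = sorted(range(len(lowercase_features)), key=lowercase_features.__getitem__)

assert "|".join(features[i] for i in arg_indices) == "aspect=Imp|Case=Nom|Number=Sing"
```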

From fd024dcc4d2b20daa03f908c89a45efb7aa1d910 Mon Sep 17 00:00:00 2001
From: Mateusz Klimaszewski <mk.klimaszewski@gmail.com>
Date: Mon, 7 Dec 2020 15:23:27 +0100
Subject: [PATCH 07/19] Pass relation probabilities in graph extraction.

---
 combo/models/graph_parser.py |  1 +
 combo/models/model.py        |  1 +
 combo/predict.py             |  4 +--
 combo/utils/graph.py         | 21 +++++++--------
 tests/utils/test_graph.py    | 50 ++++++++++++++----------------------
 5 files changed, 34 insertions(+), 43 deletions(-)

diff --git a/combo/models/graph_parser.py b/combo/models/graph_parser.py
index a31e6d0..6799d49 100644
--- a/combo/models/graph_parser.py
+++ b/combo/models/graph_parser.py
@@ -148,6 +148,7 @@ class GraphDependencyRelationModel(base.Predictor):
         output = head_output
 
         output["prediction"] = (relation_prediction.argmax(-1), head_output["prediction"])
+        output["rel_probability"] = relation_prediction
 
         if labels is not None and labels[0] is not None:
             if sample_weights is None:
diff --git a/combo/models/model.py b/combo/models/model.py
index 124f49a..710f72c 100644
--- a/combo/models/model.py
+++ b/combo/models/model.py
@@ -126,6 +126,7 @@ class SemanticMultitaskModel(allen_models.Model):
             "deprel": relations_pred,
             "enhanced_head": enhanced_head_pred,
             "enhanced_deprel": enhanced_relations_pred,
+            "enhanced_deprel_prob": enhanced_parser_output["rel_probability"],
             "sentence_embedding": torch.max(encoder_emb[:, 1:], dim=1)[0],
         }
 
diff --git a/combo/predict.py b/combo/predict.py
index e262b70..070975f 100644
--- a/combo/predict.py
+++ b/combo/predict.py
@@ -198,9 +198,9 @@ class SemanticMultitaskPredictor(predictor.Predictor):
 
         if "enhanced_head" in predictions and predictions["enhanced_head"]:
             graph.sdp_to_dag_deps(arc_scores=np.array(predictions["enhanced_head"]),
-                                  rel_scores=np.array(predictions["enhanced_deprel"]),
+                                  rel_scores=np.array(predictions["enhanced_deprel_prob"]),
                                   tree_tokens=tree_tokens,
-                                  root_label="ROOT",
+                                  root_idx=self.vocab.get_token_index("root", "deprel_labels"),
                                   vocab_index=self.vocab.get_index_to_token_vocabulary("deprel_labels"))
 
         return tree, predictions["sentence_embedding"]
diff --git a/combo/utils/graph.py b/combo/utils/graph.py
index 814341b..8a55cb9 100644
--- a/combo/utils/graph.py
+++ b/combo/utils/graph.py
@@ -4,15 +4,15 @@ from typing import List
 import numpy as np
 
 
-def sdp_to_dag_deps(arc_scores, rel_scores, tree_tokens: List, root_label, vocab_index=None) -> None:
+def sdp_to_dag_deps(arc_scores, rel_scores, tree_tokens: List, root_idx=0, vocab_index=None) -> None:
     # adding ROOT
     tree_heads = [0] + [t["head"] for t in tree_tokens]
     graph = adjust_root_score_then_add_secondary_arcs(arc_scores, rel_scores, tree_heads,
-                                                      root_label)
+                                                      root_idx)
     for i, (t, g) in enumerate(zip(tree_heads, graph)):
         if not i:
             continue
-        rels = [vocab_index.get(x[1], "ROOT") if vocab_index else x[1] for x in g]
+        rels = [vocab_index.get(x[1], "root") if vocab_index else x[1] for x in g]
         heads = [x[0] for x in g]
         head = tree_tokens[i - 1]["head"]
         index = heads.index(head)
@@ -28,22 +28,23 @@ def sdp_to_dag_deps(arc_scores, rel_scores, tree_tokens: List, root_label, vocab
     return
 
 
-def adjust_root_score_then_add_secondary_arcs(arc_scores, rel_labels, tree_heads, root_idx):
+def adjust_root_score_then_add_secondary_arcs(arc_scores, rel_scores, tree_heads, root_idx):
     if len(arc_scores) != tree_heads:
         arc_scores = arc_scores[:len(tree_heads)][:len(tree_heads)]
-        rel_labels = rel_labels[:len(tree_heads)][:len(tree_heads)]
+        rel_scores = rel_scores[:len(tree_heads)][:len(tree_heads)]
     parse_preds = np.array(arc_scores) > 0
     parse_preds[:, 0] = False  # set heads to False
-    # rel_labels[:, :, root_idx] = -float('inf')
-    return add_secondary_arcs(arc_scores, rel_labels, tree_heads, root_idx, parse_preds)
+    rel_scores[:, :, root_idx] = -float('inf')
+    return add_secondary_arcs(arc_scores, rel_scores, tree_heads, root_idx, parse_preds)
 
 
-def add_secondary_arcs(arc_scores, rel_labels, tree_heads, root_label, parse_preds):
+def add_secondary_arcs(arc_scores, rel_scores, tree_heads, root_idx, parse_preds):
     if not isinstance(tree_heads, np.ndarray):
         tree_heads = np.array(tree_heads)
     dh = np.argwhere(parse_preds)
     sdh = sorted([(arc_scores[x[0]][x[1]], list(x)) for x in dh], reverse=True)
     graph = [[] for _ in range(len(tree_heads))]
+    rel_pred = np.argmax(rel_scores, axis=-1)
     for d, h in enumerate(tree_heads):
         if d:
             graph[h].append(d)
@@ -59,9 +60,9 @@ def add_secondary_arcs(arc_scores, rel_labels, tree_heads, root_label, parse_pre
     num_root = 0
     for h in range(len(tree_heads)):
         for d in graph[h]:
-            rel = rel_labels[d][h]
+            rel = rel_pred[d][h]
             if h == 0:
-                rel = root_label
+                rel = root_idx
                 assert num_root == 0
                 num_root += 1
             parse_graph[d].append((h, rel))
diff --git a/tests/utils/test_graph.py b/tests/utils/test_graph.py
index 4c5c7d3..0a66212 100644
--- a/tests/utils/test_graph.py
+++ b/tests/utils/test_graph.py
@@ -10,48 +10,40 @@ class GraphTest(unittest.TestCase):
     def test_adding_empty_graph_with_the_same_labels(self):
         tree = conllu.TokenList(
             tokens=[
-                {"head": 2, "deprel": "ROOT", "form": "word1"},
+                {"head": 0, "deprel": "root", "form": "word1"},
                 {"head": 3, "deprel": "yes", "form": "word2"},
                 {"head": 1, "deprel": "yes", "form": "word3"},
             ]
         )
+        vocab_index = {0: "root", 1: "yes", 2: "yes", 3: "yes"}
         empty_graph = np.zeros((4, 4))
-        graph_labels = np.array([
-            ["no", "no", "no", "no"],
-            ["no", "no", "ROOT", "no"],
-            ["no", "no", "no", "yes"],
-            ["no", "yes", "no", "no"],
-        ])
-        root_label = "ROOT"
-        expected_deps = ["2:ROOT", "3:yes", "1:yes"]
+        graph_labels = np.zeros((4, 4, 4))
+        expected_deps = ["0:root", "3:yes", "1:yes"]
 
         # when
-        graph.sdp_to_dag_deps(empty_graph, graph_labels, tree.tokens, root_label)
+        graph.sdp_to_dag_deps(empty_graph, graph_labels, tree.tokens, root_idx=0, vocab_index=vocab_index)
         actual_deps = [t["deps"] for t in tree.tokens]
 
         # then
-        self.assertEqual(actual_deps, expected_deps)
+        self.assertEqual(expected_deps, actual_deps)
 
     def test_adding_empty_graph_with_different_labels(self):
         tree = conllu.TokenList(
             tokens=[
-                {"head": 2, "deprel": "ROOT", "form": "word1"},
+                {"head": 0, "deprel": "root", "form": "word1"},
                 {"head": 3, "deprel": "tree_label", "form": "word2"},
                 {"head": 1, "deprel": "tree_label", "form": "word3"},
             ]
         )
+        vocab_index = {0: "root", 1: "tree_label", 2: "graph_label"}
         empty_graph = np.zeros((4, 4))
-        graph_labels = np.array([
-            ["no", "no", "no", "no"],
-            ["no", "no", "ROOT", "no"],
-            ["no", "no", "no", "graph_label"],
-            ["no", "graph_label", "no", "no"],
-        ])
-        root_label = "ROOT"
-        expected_deps = ["2:ROOT", "3:graph_label", "1:graph_label"]
+        graph_labels = np.zeros((4, 4, 3))
+        graph_labels[2][3][2] = 10e10
+        graph_labels[3][1][2] = 10e10
+        expected_deps = ["0:root", "3:graph_label", "1:graph_label"]
 
         # when
-        graph.sdp_to_dag_deps(empty_graph, graph_labels, tree.tokens, root_label)
+        graph.sdp_to_dag_deps(empty_graph, graph_labels, tree.tokens, root_idx=0, vocab_index=vocab_index)
         actual_deps = [t["deps"] for t in tree.tokens]
 
         # then
@@ -61,28 +53,24 @@ class GraphTest(unittest.TestCase):
         # given
         tree = conllu.TokenList(
             tokens=[
-                {"head": 0, "deprel": "ROOT", "form": "word1"},
+                {"head": 0, "deprel": "root", "form": "word1"},
                 {"head": 1, "deprel": "tree_label", "form": "word2"},
                 {"head": 2, "deprel": "tree_label", "form": "word3"},
             ]
         )
+        vocab_index = {0: "root", 1: "tree_label", 2: "graph_label"}
         arc_scores = np.array([
             [0, 0, 0, 0],
             [1, 0, 0, 0],
             [0, 1, 0, 0],
             [0, 1, 1, 0],
         ])
-        graph_labels = np.array([
-            ["no", "no", "no", "no"],
-            ["ROOT", "no", "no", "no"],
-            ["no", "tree_label", "no", "no"],
-            ["no", "graph_label", "tree_label", "no"],
-        ])
-        root_label = "ROOT"
-        expected_deps = ["0:ROOT", "1:tree_label", "1:graph_label"]
+        graph_labels = np.zeros((4, 4, 3))
+        graph_labels[3][1][2] = 10e10
+        expected_deps = ["0:root", "1:tree_label", "1:graph_label"]
 
         # when
-        graph.sdp_to_dag_deps(arc_scores, graph_labels, tree.tokens, root_label)
+        graph.sdp_to_dag_deps(arc_scores, graph_labels, tree.tokens,  root_idx=0, vocab_index=vocab_index)
         actual_deps = [t["deps"] for t in tree.tokens]
 
         # then
-- 
GitLab

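With this patch `rel_scores` is a `(n, n, num_labels)` probability tensor rather than a pre-decoded label matrix; a label id per (dependent, head) arc is recovered with an argmax over the last axis and mapped to a string through the `deprel_labels` vocabulary, as in the updated tests. A shape sketch:

```python
import numpy as np

n, num_labels = 4, 3                     # n includes the ROOT position
rel_scores = np.zeros((n, n, num_labels))
rel_scores[3, 1, 2] = 10e10              # arc "token 3 <- head 1" favours label id 2

rel_pred = np.argmax(rel_scores, axis=-1)  # (n, n) matrix of label ids
vocab_index = {0: "root", 1: "tree_label", 2: "graph_label"}
assert vocab_index[rel_pred[3, 1]] == "graph_label"
```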

From e295e563cd154f48a9b1f180b91355c2e6426497 Mon Sep 17 00:00:00 2001
From: Mateusz Klimaszewski <mk.klimaszewski@gmail.com>
Date: Mon, 7 Dec 2020 16:00:09 +0100
Subject: [PATCH 08/19] Exclude self loops.

---
 combo/utils/graph.py      | 12 ++++++++----
 tests/utils/test_graph.py | 31 ++++++++++++++++++++++++++++++-
 2 files changed, 38 insertions(+), 5 deletions(-)

diff --git a/combo/utils/graph.py b/combo/utils/graph.py
index 8a55cb9..1785b4b 100644
--- a/combo/utils/graph.py
+++ b/combo/utils/graph.py
@@ -18,10 +18,12 @@ def sdp_to_dag_deps(arc_scores, rel_scores, tree_tokens: List, root_idx=0, vocab
         index = heads.index(head)
         deprel = tree_tokens[i - 1]["deprel"]
         deprel = deprel.split('>')[-1]
-        # TODO is this necessary?
-        if len(heads) >= 2:
-            heads.pop(index)
-            rels.pop(index)
+        # TODO: consider whether this condition is needed.
+        # It doesn't seem to make sense, as DEPS should contain DEPREL
+        # (although sometimes with a different/more detailed label).
+        # if len(heads) >= 2:
+        #     heads.pop(index)
+        #     rels.pop(index)
         deps = '|'.join(f'{h}:{r}' for h, r in zip(heads, rels))
         tree_tokens[i - 1]["deps"] = deps
         tree_tokens[i - 1]["deprel"] = deprel
@@ -32,6 +34,8 @@ def adjust_root_score_then_add_secondary_arcs(arc_scores, rel_scores, tree_heads
     if len(arc_scores) != tree_heads:
         arc_scores = arc_scores[:len(tree_heads)][:len(tree_heads)]
         rel_scores = rel_scores[:len(tree_heads)][:len(tree_heads)]
+    # Self-loops aren't allowed, mask with 0. This is an in-place operation.
+    np.fill_diagonal(arc_scores, 0)
     parse_preds = np.array(arc_scores) > 0
     parse_preds[:, 0] = False  # set heads to False
     rel_scores[:, :, root_idx] = -float('inf')
diff --git a/tests/utils/test_graph.py b/tests/utils/test_graph.py
index 0a66212..74e3744 100644
--- a/tests/utils/test_graph.py
+++ b/tests/utils/test_graph.py
@@ -67,7 +67,7 @@ class GraphTest(unittest.TestCase):
         ])
         graph_labels = np.zeros((4, 4, 3))
         graph_labels[3][1][2] = 10e10
-        expected_deps = ["0:root", "1:tree_label", "1:graph_label"]
+        expected_deps = ["0:root", "1:tree_label", "1:graph_label|2:tree_label"]
 
         # when
         graph.sdp_to_dag_deps(arc_scores, graph_labels, tree.tokens,  root_idx=0, vocab_index=vocab_index)
@@ -75,3 +75,32 @@ class GraphTest(unittest.TestCase):
 
         # then
         self.assertEqual(actual_deps, expected_deps)
+
+    def test_extending_tree_with_self_loop_edge_shouldnt_add_edge(self):
+        # given
+        tree = conllu.TokenList(
+            tokens=[
+                {"head": 0, "deprel": "root", "form": "word1"},
+                {"head": 1, "deprel": "tree_label", "form": "word2"},
+                {"head": 2, "deprel": "tree_label", "form": "word3"},
+            ]
+        )
+        vocab_index = {0: "root", 1: "tree_label", 2: "graph_label"}
+        arc_scores = np.array([
+            [0, 0, 0, 0],
+            [1, 0, 0, 0],
+            [0, 1, 0, 0],
+            [0, 0, 1, 1],
+        ])
+        graph_labels = np.zeros((4, 4, 3))
+        graph_labels[3][3][2] = 10e10
+        expected_deps = ["0:root", "1:tree_label", "2:tree_label"]
+        # TODO without the self-loop masking, the actual output adds a self-loop:
+        # actual_deps = ["0:root", "1:tree_label", "2:tree_label|3:graph_label"]
+
+        # when
+        graph.sdp_to_dag_deps(arc_scores, graph_labels, tree.tokens,  root_idx=0, vocab_index=vocab_index)
+        actual_deps = [t["deps"] for t in tree.tokens]
+
+        # then
+        self.assertEqual(expected_deps, actual_deps)
-- 
GitLab

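`np.fill_diagonal` zeroes the would-be self-loop scores in place before thresholding, so no token can pick itself as a secondary head. In isolation:

```python
import numpy as np

arc_scores = np.array([
    [0., 0., 0.],
    [1., 0., 0.],
    [0., 1., 5.],   # 5. at (2, 2) is a would-be self-loop
])
np.fill_diagonal(arc_scores, 0)  # in-place; returns None
parse_preds = arc_scores > 0
assert not parse_preds[2, 2]     # the self-loop no longer survives thresholding
```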

From 826e57a756c8ba96b26c06406c03ac57807b6e19 Mon Sep 17 00:00:00 2001
From: Mateusz Klimaszewski <mk.klimaszewski@gmail.com>
Date: Fri, 11 Dec 2020 10:19:12 +0100
Subject: [PATCH 09/19] Hotfix off-by-one in enhanced graphs.

---
 combo/models/graph_parser.py | 6 +++---
 combo/predict.py             | 9 +++++++--
 2 files changed, 10 insertions(+), 5 deletions(-)

diff --git a/combo/models/graph_parser.py b/combo/models/graph_parser.py
index 6799d49..2dc02dc 100644
--- a/combo/models/graph_parser.py
+++ b/combo/models/graph_parser.py
@@ -164,9 +164,9 @@ class GraphDependencyRelationModel(base.Predictor):
               heads_true: torch.Tensor,
               mask: torch.BoolTensor,
               sample_weights: torch.Tensor) -> torch.Tensor:
-
-        true = true[true.long() > 0]
-        pred = pred[heads_true.long() == 1]
+        correct_heads_mask = heads_true.long() == 1
+        true = true[correct_heads_mask]
+        pred = pred[correct_heads_mask]
         loss = F.cross_entropy(pred, true.long())
         return loss.sum() / pred.size(0)
 
diff --git a/combo/predict.py b/combo/predict.py
index 070975f..e52b42e 100644
--- a/combo/predict.py
+++ b/combo/predict.py
@@ -197,8 +197,13 @@ class SemanticMultitaskPredictor(predictor.Predictor):
                         raise NotImplementedError(f"Unknown field name {field_name}!")
 
         if "enhanced_head" in predictions and predictions["enhanced_head"]:
-            graph.sdp_to_dag_deps(arc_scores=np.array(predictions["enhanced_head"]),
-                                  rel_scores=np.array(predictions["enhanced_deprel_prob"]),
+            # TODO off-by-one hotfix, refactor
+            h = np.array(predictions["enhanced_head"])
+            h = np.concatenate((h[-1:], h[:-1]))
+            r = np.array(predictions["enhanced_deprel_prob"])
+            r = np.concatenate((r[-1:], r[:-1]))
+            graph.sdp_to_dag_deps(arc_scores=h,
+                                  rel_scores=r,
                                   tree_tokens=tree_tokens,
                                   root_idx=self.vocab.get_token_index("root", "deprel_labels"),
                                   vocab_index=self.vocab.get_index_to_token_vocabulary("deprel_labels"))
-- 
GitLab

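The hotfix rotates the score matrices by one position, which suggests the model emits the ROOT row last while the graph decoding expects it first (an assumption inferred from the concatenation; the TODO marks it for a proper refactor). The rotation itself:

```python
import numpy as np

h = np.array([[1, 1], [2, 2], [0, 0]])  # ROOT row at the end (assumed)
h = np.concatenate((h[-1:], h[:-1]))    # move the last row to the front
assert (h[0] == 0).all()
```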

From 8e09e3360bbcc24face4b076e78cc47b96691cdb Mon Sep 17 00:00:00 2001
From: Mateusz Klimaszewski <mk.klimaszewski@gmail.com>
Date: Fri, 11 Dec 2020 13:12:05 +0100
Subject: [PATCH 10/19] Add restoring of collapsed edges (gapping).

---
 combo/predict.py     |  2 ++
 combo/utils/graph.py | 27 +++++++++++++++++++++++++++
 2 files changed, 29 insertions(+)

diff --git a/combo/predict.py b/combo/predict.py
index e52b42e..c58db25 100644
--- a/combo/predict.py
+++ b/combo/predict.py
@@ -207,6 +207,8 @@ class SemanticMultitaskPredictor(predictor.Predictor):
                                   tree_tokens=tree_tokens,
                                   root_idx=self.vocab.get_token_index("root", "deprel_labels"),
                                   vocab_index=self.vocab.get_index_to_token_vocabulary("deprel_labels"))
+            empty_tokens = graph.restore_collapse_edges(tree_tokens)
+            tree.tokens.extend(empty_tokens)
 
         return tree, predictions["sentence_embedding"]
 
diff --git a/combo/utils/graph.py b/combo/utils/graph.py
index 1785b4b..32e7dd9 100644
--- a/combo/utils/graph.py
+++ b/combo/utils/graph.py
@@ -85,3 +85,30 @@ def _dfs(graph, start, end):
             if next_state in path:
                 continue
             fringe.append((next_state, path + [next_state]))
+
+
+def restore_collapse_edges(tree_tokens):
+    empty_tokens = []
+    for token in tree_tokens:
+        deps = token["deps"].split("|")
+        for i, d in enumerate(deps):
+            if ">" in d:
+                # {head}:{empty_node_relation}>{current_node_relation}
+                # should map to
+                # For new, empty node:
+                # {head}:{empty_node_relation}
+                # For current node:
+                # {new_empty_node_id}:{current_node_relation}
+                # TODO consider where to put new_empty_node_id (currently at the end)
+                head, relation = d.split(':', 1)
+                ehead = f"{len(tree_tokens)}.{len(empty_tokens) + 1}"
+                empty_node_relation, current_node_relation = relation.split(">", 1)
+                deps[i] = f"{ehead}:{current_node_relation}"
+                empty_tokens.append(
+                    {
+                        "id": ehead,
+                        "deps": f"{head}:{empty_node_relation}"
+                    }
+                )
+        token["deps"] = "|".join(deps)
+    return empty_tokens
-- 
GitLab

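A usage sketch of `restore_collapse_edges` following the mapping documented in the comments above: a collapsed `{head}:{rel1}>{rel2}` edge is split into a new empty node carrying `{head}:{rel1}` and a rewritten dep `{empty_id}:{rel2}` (token values are illustrative):

```python
from combo.utils.graph import restore_collapse_edges

tokens = [
    {"id": 1, "deps": "2:conj>nsubj"},  # edge collapsed through an empty node
    {"id": 2, "deps": "0:root"},
]
empty_tokens = restore_collapse_edges(tokens)

assert tokens[0]["deps"] == "2.1:nsubj"                   # points at the new empty node
assert empty_tokens == [{"id": "2.1", "deps": "2:conj"}]  # empty node keeps the first half
```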

From 51be66807bc514b22cddff95eb3adf9201030b91 Mon Sep 17 00:00:00 2001
From: Mateusz Klimaszewski <mk.klimaszewski@gmail.com>
Date: Fri, 11 Dec 2020 13:50:50 +0100
Subject: [PATCH 11/19] Add documentation for EUD model training.

---
 docs/training.md | 22 +++++++++++++++++++++-
 1 file changed, 21 insertions(+), 1 deletion(-)

diff --git a/docs/training.md b/docs/training.md
index 9dc430a..f4b6d38 100644
--- a/docs/training.md
+++ b/docs/training.md
@@ -43,7 +43,27 @@ Examples (for clarity without training/validation data paths):
     ```bash
     combo --mode train --targets head,deprel --features token,char,upostag
     ```
-  
+
+## Enhanced UD
+
+Training a model with Enhanced UD prediction **requires** data pre-processing.
+
+```bash
+combo --mode train \
+      --training_data_path your_preprocessed_training_path \
+      --validation_data_path your_preprocessed_validation_path \
+      --targets feats,upostag,xpostag,head,deprel,lemma,deps \
+      --config_path config.graph.template.jsonnet
+```
+### Data pre-processing
+Download the data from the [IWPT20 Shared Task](https://universaldependencies.org/iwpt20/data.html).
+It contains the `enhanced_collapse_empty_nodes.pl` script, which is required as a pre-processing step.
+Apply this script to both the training and validation data.
+
+```bash
+perl enhanced_collapse_empty_nodes.pl training.conllu > training.fixed.conllu
+```
+
 ## Configuration
 
 ### Advanced
-- 
GitLab

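A sketch of the documented pre-processing applied to both splits from Python instead of the shell (paths are placeholders; the perl script comes from the shared task data):

```python
import subprocess

for split in ("training", "validation"):
    # Equivalent to: perl enhanced_collapse_empty_nodes.pl {split}.conllu > {split}.fixed.conllu
    with open(f"{split}.fixed.conllu", "w") as out:
        subprocess.run(
            ["perl", "enhanced_collapse_empty_nodes.pl", f"{split}.conllu"],
            check=True,
            stdout=out,
        )
```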

From 7ffc1072953f43d6b35575fd4e145dc211901f21 Mon Sep 17 00:00:00 2001
From: Mateusz Klimaszewski <mk.klimaszewski@gmail.com>
Date: Mon, 21 Dec 2020 14:10:43 +0100
Subject: [PATCH 12/19] Fix enhanced dependency parsing metrics.

---
 combo/utils/metrics.py        | 7 ++++++-
 config.graph.template.jsonnet | 3 ---
 config.template.jsonnet       | 3 ---
 tests/utils/test_metrics.py   | 2 +-
 4 files changed, 7 insertions(+), 8 deletions(-)

diff --git a/combo/utils/metrics.py b/combo/utils/metrics.py
index ae73db8..682e885 100644
--- a/combo/utils/metrics.py
+++ b/combo/utils/metrics.py
@@ -140,10 +140,14 @@ class AttachmentScores(metrics.Metric):
 
         correct_indices = predicted_indices.eq(gold_indices).long() * mask
         unlabeled_exact_match = (correct_indices + ~mask).prod(dim=-1)
+        if len(correct_indices.size()) > 2:
+            unlabeled_exact_match = unlabeled_exact_match.prod(dim=-1)
         correct_labels = predicted_labels.eq(gold_labels).long() * mask
         correct_labels_and_indices = correct_indices * correct_labels
         self.correct_indices = correct_labels_and_indices.flatten()
         labeled_exact_match = (correct_labels_and_indices + ~mask).prod(dim=-1)
+        if len(correct_indices.size()) > 2:
+            labeled_exact_match = labeled_exact_match.prod(dim=-1)
 
         self._unlabeled_correct += correct_indices.sum()
         self._exact_unlabeled_correct += unlabeled_exact_match.sum()
@@ -200,7 +204,8 @@ class SemanticMetrics(metrics.Metric):
         self.feats_score = SequenceBoolAccuracy(prod_last_dim=True)
         self.lemma_score = SequenceBoolAccuracy(prod_last_dim=True)
         self.attachment_scores = AttachmentScores()
-        self.enhanced_attachment_scores = AttachmentScores()
+        # Ignore PADDING and OOV
+        self.enhanced_attachment_scores = AttachmentScores(ignore_classes=[0, 1])
         self.em_score = 0.0
 
     def __call__(  # type: ignore
diff --git a/config.graph.template.jsonnet b/config.graph.template.jsonnet
index d55cb89..6975aba 100644
--- a/config.graph.template.jsonnet
+++ b/config.graph.template.jsonnet
@@ -73,8 +73,6 @@ local cycle_loss_n = 0;
 local word_length = 30;
 # Whether to use tensorboard, bool
 local use_tensorboard = if std.extVar("use_tensorboard") == "True" then true else false;
-# Path for tensorboard metrics, str
-local metrics_dir = "./runs";
 
 # Helper functions
 local in_features(name) = !(std.length(std.find(name, features)) == 0);
@@ -413,7 +411,6 @@ assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't
             type: "combo_scheduler",
         },
         tensorboard_writer: if use_tensorboard then {
-            serialization_dir: metrics_dir,
             should_log_learning_rate: false,
             should_log_parameter_statistics: false,
             summary_interval: 100,
diff --git a/config.template.jsonnet b/config.template.jsonnet
index 8e5ddc9..f41ba62 100644
--- a/config.template.jsonnet
+++ b/config.template.jsonnet
@@ -71,8 +71,6 @@ local cycle_loss_n = 0;
 local word_length = 30;
 # Whether to use tensorboard, bool
 local use_tensorboard = if std.extVar("use_tensorboard") == "True" then true else false;
-# Path for tensorboard metrics, str
-local metrics_dir = "./runs";
 
 # Helper functions
 local in_features(name) = !(std.length(std.find(name, features)) == 0);
@@ -382,7 +380,6 @@ assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't
             type: "combo_scheduler",
         },
         tensorboard_writer: if use_tensorboard then {
-            serialization_dir: metrics_dir,
             should_log_learning_rate: false,
             should_log_parameter_statistics: false,
             summary_interval: 100,
diff --git a/tests/utils/test_metrics.py b/tests/utils/test_metrics.py
index 1d1ad3b..242eaa3 100644
--- a/tests/utils/test_metrics.py
+++ b/tests/utils/test_metrics.py
@@ -27,7 +27,7 @@ class SemanticMetricsTest(unittest.TestCase):
         self.semrel, self.semrel_l = (("semrel", x) for x in [pred, gold])
         self.head, self.head_l = (("head", x) for x in [pred, gold])
         self.deprel, self.deprel_l = (("deprel", x) for x in [pred, gold])
-        # TODO(mklimasz) Add examples with correct dimension (with ROOT token)
+        # TODO(mklimasz) Set up an example with size 3x5x5
         self.enhanced_head, self.enhanced_head_l = (("enhanced_head", x) for x in [None, None])
         self.enhanced_deprel, self.enhanced_deprel_l = (("enhanced_deprel", x) for x in [None, None])
         self.feats, self.feats_l = (("feats", x) for x in [pred_seq, gold_seq])
-- 
GitLab

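The metric fix: graph predictions are `(batch, n, n)` rather than `(batch, n)`, so the per-sentence exact match must collapse two token axes instead of one. A sketch of the added branch with toy tensors:

```python
import torch

mask = torch.ones(1, 3, 3, dtype=torch.bool)            # (batch, n, n)
correct_indices = torch.ones(1, 3, 3, dtype=torch.long)
correct_indices[0, 2, 1] = 0                            # one wrong arc

unlabeled_exact_match = (correct_indices + ~mask).prod(dim=-1)
if len(correct_indices.size()) > 2:                     # the branch added above
    unlabeled_exact_match = unlabeled_exact_match.prod(dim=-1)

assert unlabeled_exact_match.tolist() == [0]            # one bad arc kills exact match
```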

From d0dc576f2a4d002d58cacc4e4016f236bd053cae Mon Sep 17 00:00:00 2001
From: Mateusz Klimaszewski <mk.klimaszewski@gmail.com>
Date: Tue, 22 Dec 2020 18:52:29 +0100
Subject: [PATCH 13/19] Refactor predictor name, speed up dataset reader and
 graph config.

---
 README.md                     |  4 ++--
 combo/data/dataset.py         |  4 ++--
 combo/main.py                 |  2 +-
 combo/models/model.py         |  4 +++-
 combo/predict.py              |  2 +-
 config.graph.template.jsonnet |  2 ++
 docs/models.md                |  4 ++--
 docs/prediction.md            | 14 ++++++++++++--
 tests/test_predict.py         |  2 +-
 9 files changed, 26 insertions(+), 12 deletions(-)

diff --git a/README.md b/README.md
index 19847b8..c339bda 100644
--- a/README.md
+++ b/README.md
@@ -18,9 +18,9 @@ python setup.py develop
 ```
 Run the following lines in your Python console to make predictions with a pre-trained model:
 ```python
-import combo.predict as predict
+from combo.predict import COMBO
 
-nlp = predict.SemanticMultitaskPredictor.from_pretrained("polish-herbert-base")
+nlp = COMBO.from_pretrained("polish-herbert-base")
 sentence = nlp("Moje zdanie.")
 print(sentence.tokens)
 ```
diff --git a/combo/data/dataset.py b/combo/data/dataset.py
index bb56ac3..48b68b1 100644
--- a/combo/data/dataset.py
+++ b/combo/data/dataset.py
@@ -119,8 +119,8 @@ class UniversalDependenciesDatasetReader(allen_data.DatasetReader):
                                                                                label_namespace=target_name + "_labels")
                     elif target_name == "deps":
                         # Graphs require adding ROOT (AdjacencyField uses sequence length from TextField).
-                        text_field_deps = copy.deepcopy(text_field)
-                        text_field_deps.tokens.insert(0, _Token("ROOT"))
+                        text_field_deps = allen_fields.TextField([_Token("ROOT")] + copy.deepcopy(tokens),
+                                                                 self._token_indexers)
                         enhanced_heads: List[Tuple[int, int]] = []
                         enhanced_deprels: List[str] = []
                         for idx, t in enumerate(tree_tokens):
diff --git a/combo/main.py b/combo/main.py
index 374af69..17c960a 100644
--- a/combo/main.py
+++ b/combo/main.py
@@ -136,7 +136,7 @@ def run(_):
             params = common.Params.from_file(FLAGS.config_path, ext_vars=_get_ext_vars())["dataset_reader"]
             params.pop("type")
             dataset_reader = dataset.UniversalDependenciesDatasetReader.from_params(params)
-            predictor = predict.SemanticMultitaskPredictor(
+            predictor = predict.COMBO(
                 model=model,
                 dataset_reader=dataset_reader
             )
diff --git a/combo/models/model.py b/combo/models/model.py
index 710f72c..ad0df0e 100644
--- a/combo/models/model.py
+++ b/combo/models/model.py
@@ -126,10 +126,12 @@ class SemanticMultitaskModel(allen_models.Model):
             "deprel": relations_pred,
             "enhanced_head": enhanced_head_pred,
             "enhanced_deprel": enhanced_relations_pred,
-            "enhanced_deprel_prob": enhanced_parser_output["rel_probability"],
             "sentence_embedding": torch.max(encoder_emb[:, 1:], dim=1)[0],
         }
 
+        if "rel_probability" in enhanced_parser_output:
+            output["enhanced_deprel_prob"] = enhanced_parser_output["rel_probability"]
+
         if self._has_labels([upostag, xpostag, lemma, feats, head, deprel, semrel]):
 
             # Feats mapping
diff --git a/combo/predict.py b/combo/predict.py
index c58db25..8d3e2f9 100644
--- a/combo/predict.py
+++ b/combo/predict.py
@@ -19,7 +19,7 @@ logger = logging.getLogger(__name__)
 
 @predictor.Predictor.register("semantic-multitask-predictor")
 @predictor.Predictor.register("semantic-multitask-predictor-spacy", constructor="with_spacy_tokenizer")
-class SemanticMultitaskPredictor(predictor.Predictor):
+class COMBO(predictor.Predictor):
 
     def __init__(self,
                  model: models.Model,
diff --git a/config.graph.template.jsonnet b/config.graph.template.jsonnet
index 6975aba..bc8c465 100644
--- a/config.graph.template.jsonnet
+++ b/config.graph.template.jsonnet
@@ -204,6 +204,8 @@ assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't
                     type: "transformers_word_embeddings",
                     model_name: pretrained_transformer_name,
                     projection_dim: projected_embedding_dim,
+                    tokenizer_kwargs: if std.startsWith(pretrained_transformer_name, "allegro/herbert")
+                                      then {use_fast: false} else {},
                 } else {
                     type: "embeddings_projected",
                     embedding_dim: embedding_dim,
diff --git a/docs/models.md b/docs/models.md
index 485f761..d4346ff 100644
--- a/docs/models.md
+++ b/docs/models.md
@@ -5,9 +5,9 @@ Pre-trained models are available [here](http://mozart.ipipan.waw.pl/~mklimaszews
 ## Automatic download
 Python `from_pretrained` method will download the pre-trained model if the provided name (without the extension .tar.gz) matches one of the names in [here](http://mozart.ipipan.waw.pl/~mklimaszewski/models/).
 ```python
-import combo.predict as predict
+from combo.predict import COMBO
 
-nlp = predict.SemanticMultitaskPredictor.from_pretrained("polish-herbert-base")
+nlp = COMBO.from_pretrained("polish-herbert-base")
 ```
 Otherwise it looks for a model in local env.
 
diff --git a/docs/prediction.md b/docs/prediction.md
index 89cc74c..6de5d0e 100644
--- a/docs/prediction.md
+++ b/docs/prediction.md
@@ -32,9 +32,19 @@ Use either `--predictor_name semantic-multitask-predictor` or `--predictor_name
 
 ## Python
 ```python
-import combo.predict as predict
+from combo.predict import COMBO
 
 model_path = "your_model.tar.gz"
-nlp = predict.SemanticMultitaskPredictor.from_pretrained(model_path)
+nlp = COMBO.from_pretrained(model_path)
 sentence = nlp("Sentence to parse.")
 ```
+
+Using your own tokenization:
+```python
+from combo.predict import COMBO
+
+model_path = "your_model.tar.gz"
+nlp = COMBO.from_pretrained(model_path)
+tokenized_sentence = ["Sentence", "to", "parse", "."]
+sentence = nlp([tokenized_sentence])
+```
diff --git a/tests/test_predict.py b/tests/test_predict.py
index 2a56bd9..332ced3 100644
--- a/tests/test_predict.py
+++ b/tests/test_predict.py
@@ -22,7 +22,7 @@ class PredictionTest(unittest.TestCase):
             data.Token(id=2, token=".")
         ])]
         api_wrapped_tokenized_sentence = [data.conllu2sentence(data.tokens2conllu(["Test", "."]), [])]
-        nlp = predict.SemanticMultitaskPredictor.from_pretrained(os.path.join(self.FIXTURES_ROOT, "model.tar.gz"))
+        nlp = predict.COMBO.from_pretrained(os.path.join(self.FIXTURES_ROOT, "model.tar.gz"))
 
         # when
         results = [
-- 
GitLab

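The added `tokenizer_kwargs` amounts to forcing the slow tokenizer for HerBERT models. Roughly what this means at the HuggingFace level (a sketch, not code from the repo; fetching the tokenizer requires network access):

```python
from transformers import AutoTokenizer

model_name = "allegro/herbert-base-cased"
use_fast = not model_name.startswith("allegro/herbert")  # HerBERT needs the slow tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=use_fast)
```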

From 23e0c9ce45e63dbcf031ba8965f85590f331a92d Mon Sep 17 00:00:00 2001
From: Mateusz Klimaszewski <mk.klimaszewski@gmail.com>
Date: Wed, 23 Dec 2020 09:14:09 +0100
Subject: [PATCH 14/19] Sort deps when uncollapsing nodes; mask the root label
 when root isn't the head of a token.

---
 combo/models/parser.py | 25 +++++++++++++++++++++++--
 combo/utils/graph.py   |  3 ++-
 2 files changed, 25 insertions(+), 3 deletions(-)

diff --git a/combo/models/parser.py b/combo/models/parser.py
index 486b248..4b5b126 100644
--- a/combo/models/parser.py
+++ b/combo/models/parser.py
@@ -1,4 +1,5 @@
 """Dependency parsing models."""
+import math
 from typing import Tuple, Dict, Optional, Union, List
 
 import numpy as np
@@ -115,11 +116,13 @@ class DependencyRelationModel(base.Predictor):
     """Dependency relation parsing model."""
 
     def __init__(self,
+                 root_idx: int,
                  head_predictor: HeadPredictionModel,
                  head_projection_layer: base.Linear,
                  dependency_projection_layer: base.Linear,
                  relation_prediction_layer: base.Linear):
         super().__init__()
+        self.root_idx = root_idx
         self.head_predictor = head_predictor
         self.head_projection_layer = head_projection_layer
         self.dependency_projection_layer = dependency_projection_layer
@@ -130,6 +133,7 @@ class DependencyRelationModel(base.Predictor):
                 mask: Optional[torch.BoolTensor] = None,
                 labels: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
                 sample_weights: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None) -> Dict[str, torch.Tensor]:
+        device = x.device
         if mask is not None:
             mask = mask[:, 1:]
         relations_labels, head_labels = None, None
@@ -151,7 +155,23 @@ class DependencyRelationModel(base.Predictor):
         relation_prediction = self.relation_prediction_layer(dep_rel_pred)
         output = head_output
 
-        output["prediction"] = (relation_prediction.argmax(-1)[:, 1:], head_output["prediction"])
+        if self.training:
+            output["prediction"] = (relation_prediction.argmax(-1)[:, 1:], head_output["prediction"])
+        else:
+            # Mask root label whenever head is not 0.
+            relation_prediction_output = relation_prediction[:, 1:]
+            mask = (head_output["prediction"] == 0)
+            vocab_size = relation_prediction_output.size(-1)
+            root_idx = torch.tensor([self.root_idx], device=device)
+            relation_prediction_output[mask] = (relation_prediction_output
+                                                .masked_select(mask.unsqueeze(-1))
+                                                .reshape(-1, vocab_size)
+                                                .index_fill(-1, root_idx, 10e10))
+            relation_prediction_output[~mask] = (relation_prediction_output
+                                                 .masked_select(~(mask.unsqueeze(-1)))
+                                                 .reshape(-1, vocab_size)
+                                                 .index_fill(-1, root_idx, -10e10))
+            output["prediction"] = (relation_prediction_output.argmax(-1), head_output["prediction"])
 
         if labels is not None and labels[0] is not None:
             if sample_weights is None:
@@ -195,5 +215,6 @@ class DependencyRelationModel(base.Predictor):
             head_predictor=head_predictor,
             head_projection_layer=head_projection_layer,
             dependency_projection_layer=dependency_projection_layer,
-            relation_prediction_layer=relation_prediction_layer
+            relation_prediction_layer=relation_prediction_layer,
+            root_idx=vocab.get_token_index("root", vocab_namespace)
         )
diff --git a/combo/utils/graph.py b/combo/utils/graph.py
index 32e7dd9..651c14a 100644
--- a/combo/utils/graph.py
+++ b/combo/utils/graph.py
@@ -110,5 +110,6 @@ def restore_collapse_edges(tree_tokens):
                         "deps": f"{head}:{empty_node_relation}"
                     }
                 )
-        token["deps"] = "|".join(deps)
+        deps = sorted([d.split(":", 1) for d in deps], key=lambda x: float(x[0]))
+        token["deps"] = "|".join([f"{k}:{v}" for k, v in deps])
     return empty_tokens
-- 
GitLab

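Deps are now sorted with a float key so that empty node ids such as `8.1` order correctly between `8` and `9` (a plain string sort would put `10` before `8.1`). In isolation:

```python
deps = ["9:conj", "8.1:nsubj", "2:obj"]

# Parse the head id as float so integer and empty-node ids interleave correctly.
pairs = sorted([d.split(":", 1) for d in deps], key=lambda x: float(x[0]))
assert "|".join(f"{k}:{v}" for k, v in pairs) == "2:obj|8.1:nsubj|9:conj"
```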

From 426d24f18d93ca709ea02562fbe44e25dbe241bd Mon Sep 17 00:00:00 2001
From: Mateusz Klimaszewski <mk.klimaszewski@gmail.com>
Date: Wed, 23 Dec 2020 09:15:43 +0100
Subject: [PATCH 15/19] Add script for training enhanced dependency parsing
 models based on IWPT'20 Shared Task data.

---
 scripts/train_eud.py | 138 +++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 138 insertions(+)
 create mode 100644 scripts/train_eud.py

diff --git a/scripts/train_eud.py b/scripts/train_eud.py
new file mode 100644
index 0000000..7c50b61
--- /dev/null
+++ b/scripts/train_eud.py
@@ -0,0 +1,138 @@
+"""Script to train Enhanced Dependency Parsing models based on IWPT'20 Shared Task data.
+
+Running the perl preprocessing tools might require:
+conda install -c bioconda perl-list-moreutils
+conda install -c bioconda perl-namespace-autoclean
+conda install -c bioconda perl-moose
+conda install -c dan_blanchard perl-moosex-semiaffordanceaccessor
+"""
+
+import os
+import pathlib
+import subprocess
+from typing import List
+
+from absl import app
+from absl import flags
+
+FLAGS = flags.FLAGS
+LANG2TREEBANK = {
+    "ar": ["Arabic-PADT"],
+    "bg": ["Bulgarian-BTB"],
+    "cs": ["Czech-FicTree", "Czech-CAC", "Czech-PDT", "Czech-PUD"],
+    "nl": ["Dutch-Alpino", "Dutch-LassySmall"],
+    "en": ["English-EWT", "English-PUD"],
+    "et": ["Estonian-EDT", "Estonian-EWT"],
+    "fi": ["Finnish-TDT", "Finnish-PUD"],
+    "fr": ["French-Sequoia", "French-FQB"],
+    "it": ["Italian-ISDT"],
+    "lv": ["Latvian-LVTB"],
+    "lt": ["Lithuanian-ALKSNIS"],
+    "pl": ["Polish-LFG", "Polish-PDB", "Polish-PUD"],
+    "ru": ["Russian-SynTagRus"],
+    "sk": ["Slovak-SNK"],
+    "sv": ["Swedish-Talbanken", "Swedish-PUD"],
+    "ta": ["Tamil-TTB"],
+    "uk": ["Ukrainian-IU"],
+}
+
+LANG2TRANSFORMER = {
+    "en": "bert-base-cased",
+    "pl": "allegro/herbert-base-cased",
+}
+
+flags.DEFINE_list(name="lang", default=list(LANG2TREEBANK.keys()),
+                  help=f"Language of models to train. Possible values: {LANG2TREEBANK.keys()}.")
+flags.DEFINE_string(name="data_dir", default="",
+                    help="Path to 'iwpt2020stdata' directory.")
+flags.DEFINE_string(name="serialization_dir", default="/tmp/",
+                    help="Model serialization dir.")
+flags.DEFINE_integer(name="cuda_device", default=-1,
+                     help="Cuda device id (-1 for cpu).")
+
+
+def path_to_str(path: pathlib.Path) -> str:
+    return str(path.resolve())
+
+
+def merge_files(files: List[str], output: pathlib.Path):
+    if not output.exists():
+        os.system(f"cat {' '.join(files)} > {output}")
+
+
+def execute_command(command, output_file=None):
+    command = [c for c in command.split() if c.strip()]
+    if output_file:
+        with open(output_file, "w") as f:
+            subprocess.run(command, check=True, stdout=f)
+    else:
+        subprocess.run(command, check=True)
+
+
+def collapse_nodes(data_dir: pathlib.Path, treebank_file: pathlib.Path, output: str):
+    output_path = pathlib.Path(output)
+    if not output_path.exists():
+        execute_command(f"perl {path_to_str(data_dir / 'tools' / 'enhanced_collapse_empty_nodes.pl')} "
+                        f"{path_to_str(treebank_file)}", output)
+
+
+def run(_):
+    languages = FLAGS.lang
+    for lang in languages:
+        assert lang in LANG2TREEBANK, f"'{lang}' must be one of {list(LANG2TREEBANK.keys())}."
+        data_dir = pathlib.Path(FLAGS.data_dir)
+        assert data_dir.is_dir(), f"'{data_dir}' is not a directory!"
+
+        treebanks = LANG2TREEBANK[lang]
+        train_paths = []
+        dev_paths = []
+        test_paths = []
+        for treebank in treebanks:
+            treebank_dir = data_dir / f"UD_{treebank}"
+            assert treebank_dir.exists() and treebank_dir.is_dir(), f"'{treebank_dir}' directory doesn't exist."
+            for treebank_file in treebank_dir.iterdir():
+                name = treebank_file.name
+                if "conllu" in name and "fixed" not in name:
+                    output = path_to_str(treebank_file).replace('.conllu', '.fixed.conllu')
+                    if "train" in name:
+                        collapse_nodes(data_dir, treebank_file, output)
+                        train_paths.append(output)
+                    elif "dev" in name:
+                        collapse_nodes(data_dir, treebank_file, output)
+                        dev_paths.append(output)
+                    elif "test" in name:
+                        collapse_nodes(data_dir, treebank_file, output)
+                        test_paths.append(output)
+
+        lang_data_dir = data_dir / lang
+        lang_data_dir.mkdir(exist_ok=True)
+
+        train_path = lang_data_dir / "train.conllu"
+        dev_path = lang_data_dir / "dev.conllu"
+        test_path = lang_data_dir / "test.conllu"
+
+        merge_files(train_paths, output=train_path)
+        merge_files(dev_paths, output=dev_path)
+        merge_files(test_paths, output=test_path)
+
+        serialization_dir = pathlib.Path(FLAGS.serialization_dir) / lang
+        serialization_dir.mkdir(exist_ok=True)
+        execute_command("".join(f"""combo --mode train
+        --training_data {train_path}
+        --validation_data {dev_path}
+        --targets feats,upostag,xpostag,head,deprel,lemma,deps
+        --pretrained_transformer_name {LANG2TRANSFORMER[lang]}
+        --serialization_dir {serialization_dir}
+        --cuda_device {FLAGS.cuda_device}
+        --word_batch_size 2500
+        --config_path {pathlib.Path.cwd() / 'config.graph.template.jsonnet'}
+        --tensorboard
+        """.splitlines()))
+
+
+def main():
+    app.run(run)
+
+
+if __name__ == "__main__":
+    main()
-- 
GitLab


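A side note on `merge_files` in the script above: it shells out to `cat`. A pure-Python alternative — a sketch assuming plain byte-level concatenation of the CoNLL-U files is all that is needed — avoids shell interpolation of file names:

```python
import pathlib
import shutil
from typing import List

def merge_files(files: List[str], output: pathlib.Path) -> None:
    if not output.exists():
        with output.open("wb") as out:
            for name in files:
                with open(name, "rb") as src:
                    shutil.copyfileobj(src, out)
```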
From d711c60f4e9aef246f7a08d459d45b525b1a2a0b Mon Sep 17 00:00:00 2001
From: Mateusz Klimaszewski <mk.klimaszewski@gmail.com>
Date: Mon, 28 Dec 2020 10:27:47 +0100
Subject: [PATCH 16/19] Add script for training UD dependency parsing models
 and extend PyPI description.

---
 combo/models/__init__.py     |   2 +-
 combo/models/graph_parser.py |   4 -
 combo/models/model.py        |   2 +-
 scripts/train.py             | 172 +++++++++++++++++++++++++++++++++++
 scripts/train_eud.py         |  32 ++-----
 scripts/utils.py             |  16 ++++
 setup.cfg                    |   3 +
 setup.py                     |  14 ++-
 8 files changed, 216 insertions(+), 29 deletions(-)
 create mode 100644 scripts/train.py
 create mode 100644 scripts/utils.py

diff --git a/combo/models/__init__.py b/combo/models/__init__.py
index ba8d617..ec7a138 100644
--- a/combo/models/__init__.py
+++ b/combo/models/__init__.py
@@ -5,5 +5,5 @@ from .parser import DependencyRelationModel
 from .embeddings import CharacterBasedWordEmbeddings
 from .encoder import ComboEncoder
 from .lemma import LemmatizerModel
-from .model import SemanticMultitaskModel
+from .model import ComboModel
 from .morpho import MorphologicalFeatures
diff --git a/combo/models/graph_parser.py b/combo/models/graph_parser.py
index 2dc02dc..edcdc2d 100644
--- a/combo/models/graph_parser.py
+++ b/combo/models/graph_parser.py
@@ -119,13 +119,9 @@ class GraphDependencyRelationModel(base.Predictor):
                 mask: Optional[torch.BoolTensor] = None,
                 labels: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
                 sample_weights: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None) -> Dict[str, torch.Tensor]:
-        # if mask is not None:
-        #     mask = mask[:, 1:]
         relations_labels, head_labels, enhanced_heads_labels, enhanced_deprels_labels = None, None, None, None
         if labels is not None and labels[0] is not None:
             relations_labels, head_labels, enhanced_heads_labels = labels
-            # if mask is None:
-            #     mask = head_labels.new_ones(head_labels.size())
 
         head_output = self.head_predictor(x, enhanced_heads_labels, mask, sample_weights)
         head_pred = head_output["probability"]
diff --git a/combo/models/model.py b/combo/models/model.py
index ad0df0e..9d3f817 100644
--- a/combo/models/model.py
+++ b/combo/models/model.py
@@ -12,7 +12,7 @@ from combo.utils import metrics
 
 
 @allen_models.Model.register("semantic_multitask")
-class SemanticMultitaskModel(allen_models.Model):
+class ComboModel(allen_models.Model):
     """Main COMBO model."""
 
     def __init__(self,
diff --git a/scripts/train.py b/scripts/train.py
new file mode 100644
index 0000000..7ca0fce
--- /dev/null
+++ b/scripts/train.py
@@ -0,0 +1,172 @@
+"""Script to train Dependency Parsing models based on UD 2.x data."""
+import pathlib
+
+from absl import app
+from absl import flags
+
+from scripts import utils
+
+TREEBANKS = [
+    "UD_Afrikaans-AfriBooms",
+    "UD_Arabic-NYUAD",
+    "UD_Arabic-PADT",
+    "UD_Armenian-ArmTDP",
+    "UD_Basque-BDT",
+    "UD_Belarusian-HSE",
+    "UD_Breton-KEB",
+    "UD_Bulgarian-BTB",
+    "UD_Catalan-AnCora",
+    "UD_Croatian-SET",
+    "UD_Czech-CAC",
+    "UD_Czech-CLTT",
+    "UD_Czech-FicTree",
+    "UD_Czech-PDT",
+    "UD_Danish-DDT",
+    "UD_Dutch-Alpino",
+    "UD_Dutch-LassySmall",
+    "UD_English-ESL",
+    "UD_English-EWT",
+    "UD_English-GUM",
+    "UD_English-LinES",
+    "UD_English-ParTUT",
+    "UD_English-Pronouns",
+    "UD_Estonian-EDT",
+    "UD_Estonian-EWT",
+    "UD_Finnish-FTB",
+    "UD_Finnish-TDT",
+    "UD_French-FQB",
+    "UD_French-FTB",
+    "UD_French-GSD",
+    "UD_French-ParTUT",
+    "UD_French-Sequoia",
+    "UD_French-Spoken",
+    "UD_Galician-CTG",
+    "UD_Galician-TreeGal",
+    "UD_German-GSD",
+    "UD_German-HDT",
+    "UD_German-LIT",
+    "UD_Greek-GDT",
+    "UD_Hebrew-HTB",
+    "UD_Hindi_English-HIENCS",
+    "UD_Hindi-HDTB",
+    "UD_Hungarian-Szeged",
+    "UD_Indonesian-GSD",
+    "UD_Irish-IDT",
+    "UD_Italian-ISDT",
+    "UD_Italian-ParTUT",
+    "UD_Italian-PoSTWITA",
+    "UD_Italian-TWITTIRO",
+    "UD_Italian-VIT",
+    "UD_Japanese-BCCWJ",
+    "UD_Japanese-GSD",
+    "UD_Japanese-Modern",
+    "UD_Kazakh-KTB",
+    "UD_Korean-GSD",
+    "UD_Korean-Kaist",
+    "UD_Latin-ITTB",
+    "UD_Latin-Perseus",
+    "UD_Latin-PROIEL",
+    "UD_Latvian-LVTB",
+    "UD_Lithuanian-ALKSNIS",
+    "UD_Lithuanian-HSE",
+    "UD_Maltese-MUDT",
+    "UD_Marathi-UFAL",
+    "UD_Persian-Seraji",
+    "UD_Polish-LFG",
+    "UD_Polish-PDB",
+    "UD_Portuguese-Bosque",
+    "UD_Portuguese-GSD",
+    "UD_Romanian-Nonstandard",
+    "UD_Romanian-RRT",
+    "UD_Romanian-SiMoNERo",
+    "UD_Russian-GSD",
+    "UD_Russian-SynTagRus",
+    "UD_Russian-Taiga",
+    "UD_Serbian-SET",
+    "UD_Slovak-SNK",
+    "UD_Slovenian-SSJ",
+    "UD_Slovenian-SST",
+    "UD_Spanish-AnCora",
+    "UD_Spanish-GSD",
+    "UD_Swedish-LinES",
+    "UD_Swedish_Sign_Language-SSLC",
+    "UD_Swedish-Talbanken",
+    "UD_Tamil-TTB",
+    "UD_Telugu-MTG",
+    "UD_Turkish-GB",
+    "UD_Turkish-IMST",
+    "UD_Ukrainian-IU",
+    "UD_Urdu-UDTB",
+    "UD_Uyghur-UDT",
+    "UD_Vietnamese-VTB",
+]
+
+FLAGS = flags.FLAGS
+flags.DEFINE_list(name="treebanks", default=TREEBANKS,
+                  help=f"Treebanks to train. Possible values: {TREEBANKS}.")
+flags.DEFINE_string(name="data_dir", default="",
+                    help="Path to UD data directory.")
+flags.DEFINE_string(name="serialization_dir", default="/tmp/",
+                    help="Model serialization directory.")
+flags.DEFINE_string(name="embeddings_dir", default="",
+                    help="Path to embeddings directory (with languages as subdirectories).")
+flags.DEFINE_integer(name="cuda_device", default=-1,
+                     help="Cuda device id (-1 for cpu).")
+
+
+def run(_):
+    treebanks_dir = pathlib.Path(FLAGS.data_dir)
+    for treebank in FLAGS.treebanks:
+        assert treebank in TREEBANKS, f"Unknown treebank {treebank}."
+        treebank_dir = treebanks_dir / treebank
+        treebank_parts = treebank.split("_")[1].split("-")
+        language = treebank_parts[0]
+
+        files = list(treebank_dir.iterdir())
+
+        training_file = [f for f in files if "train" in f.name and ".conllu" in f.name]
+        assert len(training_file) == 1, f"Couldn't find a unique training file in {treebank_dir}."
+        training_file_path = training_file[0]
+
+        valid_file = [f for f in files if "dev" in f.name and ".conllu" in f.name]
+        assert len(valid_file) == 1, f"Couldn't find a unique validation file in {treebank_dir}."
+        valid_file_path = valid_file[0]
+
+        embeddings_dir = FLAGS.embeddings_dir
+        embeddings_file = None
+        if embeddings_dir:
+            embeddings_dir = embeddings_dir / language
+            embeddings_file = [f for f in embeddings_dir.iterdir() if "vectors" in f.name and ".vec.gz" in f.name]
+            assert len(embeddings_file) == 1, f"Couldn't find a unique embeddings file in {embeddings_dir}."
+            embeddings_file = embeddings_file[0]
+
+        language = training_file_path.name.split("_")[0]
+
+        serialization_dir = pathlib.Path(FLAGS.serialization_dir) / treebank
+        serialization_dir.mkdir(exist_ok=True, parents=True)
+
+        command = f"""time combo --mode train
+        --cuda_device {FLAGS.cuda_device}
+        --training_data_path {training_file_path}
+        --validation_data_path {valid_file_path}
+        {f"--pretrained_tokens {embeddings_file}" if embeddings_dir
+        else f"--pretrained_transformer_name {utils.LANG2TRANSFORMER[language]}"}
+        --serialization_dir {serialization_dir}
+        --config_path {pathlib.Path.cwd() / 'config.template.jsonnet'}
+        --word_batch_size 2500
+        --notensorboard
+        """
+
+        # Treebanks without XPOS annotation.
+        if treebank in ["UD_Hungarian-Szeged", "UD_Armenian-ArmTDP"]:
+            command = command + " --targets deprel,head,upostag,lemma,feats"
+
+        utils.execute_command(command)
+
+
+def main():
+    app.run(run)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/train_eud.py b/scripts/train_eud.py
index 7c50b61..4904e0b 100644
--- a/scripts/train_eud.py
+++ b/scripts/train_eud.py
@@ -9,13 +9,13 @@ conda install -c dan_blanchard perl-moosex-semiaffordanceaccessor
 
 import os
 import pathlib
-import subprocess
 from typing import List
 
 from absl import app
 from absl import flags
 
-FLAGS = flags.FLAGS
+from scripts import utils
+
 LANG2TREEBANK = {
     "ar": ["Arabic-PADT"],
     "bg": ["Bulgarian-BTB"],
@@ -36,11 +36,7 @@ LANG2TREEBANK = {
     "uk": ["Ukrainian-IU"],
 }
 
-LANG2TRANSFORMER = {
-    "en": "bert-base-cased",
-    "pl": "allegro/herbert-base-cased",
-}
-
+FLAGS = flags.FLAGS
 flags.DEFINE_list(name="lang", default=list(LANG2TREEBANK.keys()),
                   help=f"Language of models to train. Possible values: {LANG2TREEBANK.keys()}.")
 flags.DEFINE_string(name="data_dir", default="",
@@ -60,26 +56,18 @@ def merge_files(files: List[str], output: pathlib.Path):
         os.system(f"cat {' '.join(files)} > {output}")
 
 
-def execute_command(command, output_file=None):
-    command = [c for c in command.split() if c.strip()]
-    if output_file:
-        with open(output_file, "w") as f:
-            subprocess.run(command, check=True, stdout=f)
-    else:
-        subprocess.run(command, check=True)
-
-
 def collapse_nodes(data_dir: pathlib.Path, treebank_file: pathlib.Path, output: str):
     output_path = pathlib.Path(output)
     if not output_path.exists():
-        execute_command(f"perl {path_to_str(data_dir / 'tools' / 'enhanced_collapse_empty_nodes.pl')} "
-                        f"{path_to_str(treebank_file)}", output)
+        utils.execute_command(f"perl {path_to_str(data_dir / 'tools' / 'enhanced_collapse_empty_nodes.pl')} "
+                              f"{path_to_str(treebank_file)}", output)
 
 
 def run(_):
     languages = FLAGS.lang
     for lang in languages:
         assert lang in LANG2TREEBANK, f"'{lang}' must be one of {list(LANG2TREEBANK.keys())}."
+        assert lang in utils.LANG2TRANSFORMER, f"Transformer for '{lang}' isn't defined. See 'LANG2TRANSFORMER' dict."
         data_dir = pathlib.Path(FLAGS.data_dir)
         assert data_dir.is_dir(), f"'{data_dir}' is not a directory!"
 
@@ -116,17 +104,17 @@ def run(_):
         merge_files(test_paths, output=test_path)
 
         serialization_dir = pathlib.Path(FLAGS.serialization_dir) / lang
-        serialization_dir.mkdir(exist_ok=True)
-        execute_command("".join(f"""combo --mode train
+        serialization_dir.mkdir(exist_ok=True, parents=True)
+        utils.execute_command("".join(f"""combo --mode train
         --training_data {train_path}
         --validation_data {dev_path}
         --targets feats,upostag,xpostag,head,deprel,lemma,deps
-        --pretrained_transformer_name {LANG2TRANSFORMER[lang]}
+        --pretrained_transformer_name {utils.LANG2TRANSFORMER[lang]}
         --serialization_dir {serialization_dir}
         --cuda_device {FLAGS.cuda_device}
         --word_batch_size 2500
         --config_path {pathlib.Path.cwd() / 'config.graph.template.jsonnet'}
-        --tensorboard
+        --notensorboard
         """.splitlines()))
 
 
diff --git a/scripts/utils.py b/scripts/utils.py
new file mode 100644
index 0000000..5dda2b8
--- /dev/null
+++ b/scripts/utils.py
@@ -0,0 +1,16 @@
+"""Utils for scripts."""
+import subprocess
+
+LANG2TRANSFORMER = {
+    "en": "bert-base-cased",
+    "pl": "allegro/herbert-base-cased",
+}
+
+
+def execute_command(command, output_file=None):
+    command = [c for c in command.split() if c.strip()]
+    if output_file:
+        with open(output_file, "w") as f:
+            subprocess.run(command, check=True, stdout=f)
+    else:
+        subprocess.run(command, check=True)
diff --git a/setup.cfg b/setup.cfg
index b7e4789..6876d0d 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -1,2 +1,5 @@
 [aliases]
 test=pytest
+
+[metadata]
+description-file = README.md
diff --git a/setup.py b/setup.py
index 74540e2..5c833d5 100644
--- a/setup.py
+++ b/setup.py
@@ -20,11 +20,23 @@ REQUIREMENTS = [
 
 setup(
     name='COMBO',
-    version='0.0.1',
+    version='1.0.0b1',
     install_requires=REQUIREMENTS,
     packages=find_packages(exclude=['tests']),
+    license="GPL-3.0",
+    url='https://gitlab.clarin-pl.eu/syntactic-tools/combo',
+    keywords="nlp natural-language-processing dependency-parsing",
     setup_requires=['pytest-runner', 'pytest-pylint'],
     tests_require=['pytest', 'pylint'],
     python_requires='>=3.6',
     entry_points={'console_scripts': ['combo = combo.main:main']},
+    classifiers=[
+        'Development Status :: 4 - Beta',
+        'Intended Audience :: Science/Research',
+        'License :: OSI Approved :: GNU General Public License v3 (GPLv3)',
+        'Topic :: Scientific/Engineering :: Artificial Intelligence',
+        'Programming Language :: Python :: 3.6',
+        'Programming Language :: Python :: 3.7',
+        'Programming Language :: Python :: 3.8',
+    ]
 )
-- 
GitLab

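The `execute_command` helper (now in `scripts/utils.py`) splits commands on whitespace, which silently breaks any argument containing spaces. A hedged variant using the standard library's `shlex` — a sketch, not the repository's implementation — would respect quoting:

```python
import shlex
import subprocess
from typing import Optional

def execute_command(command: str, output_file: Optional[str] = None) -> None:
    args = shlex.split(command)  # respects quotes, unlike str.split()
    if output_file:
        with open(output_file, "w") as f:
            subprocess.run(args, check=True, stdout=f)
    else:
        subprocess.run(args, check=True)
```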

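The per-treebank `--targets` override in `scripts/train.py` is hard-coded for the two treebanks without XPOS annotation. A data-driven sketch of the same idea (the helper name is hypothetical; the target names follow the flags used above):

```python
# Treebanks known to lack XPOS annotation (taken from the script above).
NO_XPOS_TREEBANKS = {"UD_Hungarian-Szeged", "UD_Armenian-ArmTDP"}

def targets_flag(treebank: str) -> str:
    targets = ["deprel", "head", "upostag", "lemma", "feats"]
    if treebank not in NO_XPOS_TREEBANKS:
        targets.append("xpostag")
    return "--targets " + ",".join(targets)
```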
From bdd9d5bb60fb1975a49dd097a2f7bc98d5106e9d Mon Sep 17 00:00:00 2001
From: Mateusz Klimaszewski <mk.klimaszewski@gmail.com>
Date: Tue, 29 Dec 2020 14:50:13 +0100
Subject: [PATCH 17/19] Fix mapping and path.

---
 combo/data/api.py | 4 ++++
 scripts/train.py  | 2 +-
 2 files changed, 5 insertions(+), 1 deletion(-)

diff --git a/combo/data/api.py b/combo/data/api.py
index 10a3a72..ca8f75a 100644
--- a/combo/data/api.py
+++ b/combo/data/api.py
@@ -50,6 +50,10 @@ def sentence2conllu(sentence: Sentence, keep_semrel: bool = True) -> conllu.Toke
     for t in tokens:
         if type(t["id"]) == list:
             t["id"] = tuple(t["id"])
+        if t["deps"]:
+            for dep in t["deps"]:
+                if type(dep[1]) == list:
+                    dep[1] = tuple(dep[1])
     return _TokenList(tokens=tokens,
                       metadata=sentence.metadata)
 
diff --git a/scripts/train.py b/scripts/train.py
index 7ca0fce..9390888 100644
--- a/scripts/train.py
+++ b/scripts/train.py
@@ -135,7 +135,7 @@ def run(_):
         embeddings_dir = FLAGS.embeddings_dir
         embeddings_file = None
         if embeddings_dir:
-            embeddings_dir = embeddings_dir / language
+            embeddings_dir = pathlib.Path(embeddings_dir) / language
             embeddings_file = [f for f in embeddings_dir.iterdir() if "vectors" in f.name and ".vec.gz" in f.name]
             assert len(embeddings_file) == 1, f"Couldn't find a unique embeddings file in {embeddings_dir}."
             embeddings_file = embeddings_file[0]
-- 
GitLab

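The `combo/data/api.py` fix above converts list-valued heads in `deps` into tuples. A tiny illustration with made-up values, assuming conllu's convention of representing empty-node ids such as `5.1` as lists like `[5, '.', 1]`:

```python
deps = [["nsubj", 3], ["ref", [5, ".", 1]]]  # (relation, head) pairs
for dep in deps:
    if len(dep) > 1 and type(dep[1]) == list:
        dep[1] = tuple(dep[1])
print(deps)  # [['nsubj', 3], ['ref', (5, '.', 1)]]
```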

From 4c556fd975e90f8d01798b1ffef022c2f9ed2c4d Mon Sep 17 00:00:00 2001
From: Mateusz Klimaszewski <mk.klimaszewski@gmail.com>
Date: Sun, 3 Jan 2021 11:31:12 +0100
Subject: [PATCH 18/19] Refactor tensor indexing and prediction mapping.

---
 combo/data/api.py          | 29 ++++++++++++++++++++-------
 combo/models/base.py       | 10 +++++++---
 combo/models/embeddings.py |  4 ++--
 combo/models/model.py      | 40 ++++++++++++++++++++------------------
 combo/models/parser.py     |  1 -
 combo/predict.py           |  6 +++---
 setup.py                   |  1 -
 7 files changed, 55 insertions(+), 36 deletions(-)

diff --git a/combo/data/api.py b/combo/data/api.py
index ca8f75a..7d44917 100644
--- a/combo/data/api.py
+++ b/combo/data/api.py
@@ -1,13 +1,13 @@
 import collections
+import dataclasses
+import json
 from dataclasses import dataclass, field
 from typing import Optional, List, Dict, Any, Union, Tuple
 
 import conllu
-from dataclasses_json import dataclass_json
 from overrides import overrides
 
 
-@dataclass_json
 @dataclass
 class Token:
     id: Optional[Union[int, Tuple]] = None
@@ -23,13 +23,19 @@ class Token:
     semrel: Optional[str] = None
 
 
-@dataclass_json
 @dataclass
 class Sentence:
     tokens: List[Token] = field(default_factory=list)
     sentence_embedding: List[float] = field(default_factory=list)
     metadata: Dict[str, Any] = field(default_factory=collections.OrderedDict)
 
+    def to_json(self):
+        return json.dumps({
+            "tokens": [dataclasses.asdict(t) for t in self.tokens],
+            "sentence_embedding": self.sentence_embedding,
+            "metadata": self.metadata,
+        })
+
 
 class _TokenList(conllu.TokenList):
 
@@ -41,7 +47,7 @@ class _TokenList(conllu.TokenList):
 def sentence2conllu(sentence: Sentence, keep_semrel: bool = True) -> conllu.TokenList:
     tokens = []
     for token in sentence.tokens:
-        token_dict = collections.OrderedDict(token.to_dict())
+        token_dict = collections.OrderedDict(dataclasses.asdict(token))
         # Remove semrel to have default conllu format.
         if not keep_semrel:
             del token_dict["semrel"]
@@ -52,7 +58,7 @@ def sentence2conllu(sentence: Sentence, keep_semrel: bool = True) -> conllu.Toke
             t["id"] = tuple(t["id"])
         if t["deps"]:
             for dep in t["deps"]:
-                if type(dep[1]) == list:
+                if len(dep) > 1 and type(dep[1]) == list:
                     dep[1] = tuple(dep[1])
     return _TokenList(tokens=tokens,
                       metadata=sentence.metadata)
@@ -68,9 +74,12 @@ def tokens2conllu(tokens: List[str]) -> conllu.TokenList:
 
 
 def conllu2sentence(conllu_sentence: conllu.TokenList,
-                    sentence_embedding: List[float]) -> Sentence:
+                    sentence_embedding: Optional[List[float]] = None) -> Sentence:
+    if sentence_embedding is None:
+        sentence_embedding = []
+    tokens = [Token(**token) for token in conllu_sentence.tokens]
     return Sentence(
-        tokens=[Token.from_dict(t) for t in conllu_sentence.tokens],
+        tokens=tokens,
         sentence_embedding=sentence_embedding,
         metadata=conllu_sentence.metadata
     )
diff --git a/combo/models/base.py b/combo/models/base.py
index 10e9d37..a5cb5fe 100644
--- a/combo/models/base.py
+++ b/combo/models/base.py
@@ -27,11 +27,11 @@ class Linear(nn.Linear, common.FromParams):
     def __init__(self,
                  in_features: int,
                  out_features: int,
-                 activation: Optional[allen_nn.Activation] = lambda x: x,
+                 activation: Optional[allen_nn.Activation] = None,
                  dropout_rate: Optional[float] = 0.0):
         super().__init__(in_features, out_features)
-        self.activation = activation
-        self.dropout = nn.Dropout(p=dropout_rate) if dropout_rate else lambda x: x
+        self.activation = activation if activation else self.identity
+        self.dropout = nn.Dropout(p=dropout_rate) if dropout_rate else self.identity
 
     def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
         x = super().forward(x)
@@ -41,6 +41,10 @@ class Linear(nn.Linear, common.FromParams):
     def get_output_dim(self) -> int:
         return self.out_features
 
+    @staticmethod
+    def identity(x):
+        return x
+
 
 @Predictor.register("feedforward_predictor")
 @Predictor.register("feedforward_predictor_from_vocab", constructor="from_vocab")
diff --git a/combo/models/embeddings.py b/combo/models/embeddings.py
index 5cad959..6ad2559 100644
--- a/combo/models/embeddings.py
+++ b/combo/models/embeddings.py
@@ -196,10 +196,10 @@ class FeatsTokenEmbedder(token_embedders.Embedding):
 
     def forward(self, tokens: torch.Tensor) -> torch.Tensor:
         # (batch_size, sentence_length, features_vocab_length)
-        mask = (tokens > 0).float()
+        mask = tokens.gt(0)
         # (batch_size, sentence_length, features_vocab_length, embedding_dim)
         x = super().forward(tokens)
         # (batch_size, sentence_length, embedding_dim)
         return x.sum(dim=-2) / (
-            (mask.sum(dim=-1) + util.tiny_value_of_dtype(mask.dtype)).unsqueeze(dim=-1)
+            (mask.sum(dim=-1) + util.tiny_value_of_dtype(torch.float)).unsqueeze(dim=-1)
         )
diff --git a/combo/models/model.py b/combo/models/model.py
index 9d3f817..9866bcb 100644
--- a/combo/models/model.py
+++ b/combo/models/model.py
@@ -60,9 +60,11 @@ class ComboModel(allen_models.Model):
                 enhanced_deprels: torch.Tensor = None) -> Dict[str, torch.Tensor]:
 
         # Prepare masks
-        char_mask: torch.BoolTensor = sentence["char"]["token_characters"] > 0
+        char_mask = sentence["char"]["token_characters"].gt(0)
         word_mask = util.get_text_field_mask(sentence)
 
+        device = word_mask.device
+
         # If enabled weight samples loss by log(sentence_length)
         sample_weights = word_mask.sum(-1).float().log() if self.use_sample_weight else None
 
@@ -73,45 +75,45 @@ class ComboModel(allen_models.Model):
 
         # Concatenate the head sentinel (ROOT) onto the sentence representation.
         head_sentinel = self._head_sentinel.expand(batch_size, 1, encoding_dim)
-        encoder_emb = torch.cat([head_sentinel, encoder_emb], 1)
-        word_mask = torch.cat([word_mask.new_ones((batch_size, 1)), word_mask], 1)
+        encoder_emb_with_root = torch.cat([head_sentinel, encoder_emb], 1)
+        word_mask_with_root = torch.cat([torch.ones((batch_size, 1), dtype=word_mask.dtype, device=device), word_mask], 1)
 
         upos_output = self._optional(self.upos_tagger,
-                                     encoder_emb[:, 1:],
-                                     mask=word_mask[:, 1:],
+                                     encoder_emb,
+                                     mask=word_mask,
                                      labels=upostag,
                                      sample_weights=sample_weights)
         xpos_output = self._optional(self.xpos_tagger,
-                                     encoder_emb[:, 1:],
-                                     mask=word_mask[:, 1:],
+                                     encoder_emb,
+                                     mask=word_mask,
                                      labels=xpostag,
                                      sample_weights=sample_weights)
         semrel_output = self._optional(self.semantic_relation,
-                                       encoder_emb[:, 1:],
-                                       mask=word_mask[:, 1:],
+                                       encoder_emb,
+                                       mask=word_mask,
                                        labels=semrel,
                                        sample_weights=sample_weights)
         morpho_output = self._optional(self.morphological_feat,
-                                       encoder_emb[:, 1:],
-                                       mask=word_mask[:, 1:],
+                                       encoder_emb,
+                                       mask=word_mask,
                                        labels=feats,
                                        sample_weights=sample_weights)
         lemma_output = self._optional(self.lemmatizer,
-                                      (encoder_emb[:, 1:], sentence.get("char").get("token_characters")
+                                      (encoder_emb, sentence.get("char").get("token_characters")
                                       if sentence.get("char") else None),
-                                      mask=word_mask[:, 1:],
+                                      mask=word_mask,
                                       labels=lemma.get("char").get("token_characters") if lemma else None,
                                       sample_weights=sample_weights)
         parser_output = self._optional(self.dependency_relation,
-                                       encoder_emb,
+                                       encoder_emb_with_root,
                                        returns_tuple=True,
-                                       mask=word_mask,
+                                       mask=word_mask_with_root,
                                        labels=(deprel, head),
                                        sample_weights=sample_weights)
         enhanced_parser_output = self._optional(self.enhanced_dependency_relation,
-                                                encoder_emb,
+                                                encoder_emb_with_root,
                                                 returns_tuple=True,
-                                                mask=word_mask,
+                                                mask=word_mask_with_root,
                                                 labels=(enhanced_deprels, head, enhanced_heads),
                                                 sample_weights=sample_weights)
         relations_pred, head_pred = parser_output["prediction"]
@@ -126,7 +128,7 @@ class ComboModel(allen_models.Model):
             "deprel": relations_pred,
             "enhanced_head": enhanced_head_pred,
             "enhanced_deprel": enhanced_relations_pred,
-            "sentence_embedding": torch.max(encoder_emb[:, 1:], dim=1)[0],
+            "sentence_embedding": torch.max(encoder_emb, dim=1)[0],
         }
 
         if "rel_probability" in enhanced_parser_output:
@@ -153,7 +155,7 @@ class ComboModel(allen_models.Model):
                 "enhanced_head": enhanced_heads,
                 "enhanced_deprel": enhanced_deprels,
             }
-            self.scores(output, labels, word_mask[:, 1:])
+            self.scores(output, labels, word_mask)
             relations_loss, head_loss = parser_output["loss"]
             enhanced_relations_loss, enhanced_head_loss = enhanced_parser_output["loss"]
             losses = {
diff --git a/combo/models/parser.py b/combo/models/parser.py
index 4b5b126..dfb53ab 100644
--- a/combo/models/parser.py
+++ b/combo/models/parser.py
@@ -1,5 +1,4 @@
 """Dependency parsing models."""
-import math
 from typing import Tuple, Dict, Optional, Union, List
 
 import numpy as np
diff --git a/combo/predict.py b/combo/predict.py
index 8d3e2f9..21941d9 100644
--- a/combo/predict.py
+++ b/combo/predict.py
@@ -25,7 +25,7 @@ class COMBO(predictor.Predictor):
                  model: models.Model,
                  dataset_reader: allen_data.DatasetReader,
                  tokenizer: allen_data.Tokenizer = tokenizers.WhitespaceTokenizer(),
-                 batch_size: int = 500,
+                 batch_size: int = 32,
                  line_to_conllu: bool = False) -> None:
         super().__init__(model, dataset_reader)
         self.batch_size = batch_size
@@ -52,7 +52,7 @@ class COMBO(predictor.Predictor):
 
     def predict(self, sentence: Union[str, List[str], List[List[str]], List[data.Sentence]]):
         if isinstance(sentence, str):
-            return data.Sentence.from_dict(self.predict_json({"sentence": sentence}))
+            return self.predict_json({"sentence": sentence})
         elif isinstance(sentence, list):
             if len(sentence) == 0:
                 return []
@@ -219,7 +219,7 @@ class COMBO(predictor.Predictor):
 
     @classmethod
     def from_pretrained(cls, path: str, tokenizer=tokenizers.SpacyTokenizer(),
-                        batch_size: int = 500,
+                        batch_size: int = 32,
                         cuda_device: int = -1):
         util.import_module_and_submodules("combo.commands")
         util.import_module_and_submodules("combo.models")
diff --git a/setup.py b/setup.py
index 5c833d5..fdaa2be 100644
--- a/setup.py
+++ b/setup.py
@@ -6,7 +6,6 @@ REQUIREMENTS = [
     'allennlp==1.2.1',
     'conllu==2.3.2',
     'dataclasses;python_version<"3.7"',
-    'dataclasses-json==0.5.2',
     'joblib==0.14.1',
     'jsonnet==0.15.0',
     'requests==2.23.0',
-- 
GitLab

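Patch 18 drops the `dataclasses-json` dependency in favour of the standard library (`dataclasses.asdict` plus `json.dumps`). The pattern, reduced to a self-contained sketch:

```python
import dataclasses
import json
from typing import Optional

@dataclasses.dataclass
class Token:
    id: Optional[int] = None
    token: Optional[str] = None

token = Token(id=1, token="Hello")
print(json.dumps(dataclasses.asdict(token)))  # {"id": 1, "token": "Hello"}
```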

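The `ComboModel.forward` refactor keeps two views of the encoded sentence: one with a prepended ROOT sentinel (fed to the parsers) and one without (fed to the taggers). A shape-only sketch of that concatenation, with made-up dimensions:

```python
import torch

batch_size, seq_len, dim = 2, 5, 8
encoder_emb = torch.randn(batch_size, seq_len, dim)
word_mask = torch.ones(batch_size, seq_len, dtype=torch.bool)

# Learned ROOT embedding, expanded across the batch and prepended.
head_sentinel = torch.nn.Parameter(torch.zeros(1, 1, dim))
encoder_emb_with_root = torch.cat(
    [head_sentinel.expand(batch_size, -1, -1), encoder_emb], dim=1)
word_mask_with_root = torch.cat(
    [word_mask.new_ones(batch_size, 1), word_mask], dim=1)

assert encoder_emb_with_root.shape == (batch_size, seq_len + 1, dim)
assert word_mask_with_root.shape == (batch_size, seq_len + 1)
```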
From 160ee6ea3f05316abb2d91f2103a6d91009e819c Mon Sep 17 00:00:00 2001
From: Mateusz Klimaszewski <mk.klimaszewski@gmail.com>
Date: Sun, 3 Jan 2021 11:46:19 +0100
Subject: [PATCH 19/19] Extend documentation with better examples.

---
 README.md        | 14 +++++++++-----
 docs/models.md   | 27 +++++++++++++++++----------
 docs/training.md |  6 +++---
 3 files changed, 29 insertions(+), 18 deletions(-)

diff --git a/README.md b/README.md
index c339bda..a9c2113 100644
--- a/README.md
+++ b/README.md
@@ -10,19 +10,24 @@
 </p>
 
 ## Quick start
-Clone this repository and install COMBO (we suggest using virtualenv/conda with Python 3.6+):
+Clone this repository and install COMBO (we suggest creating a virtualenv/conda environment with Python 3.6+, as installation will pull in a number of required packages):
 ```bash
 git clone https://gitlab.clarin-pl.eu/syntactic-tools/clarinbiz/combo.git
 cd combo
 python setup.py develop
 ```
-Run the following lines in your Python console to make predictions with a pre-trained model:
+Run the following commands in your Python console to make predictions with a pre-trained model:
 ```python
 from combo.predict import COMBO
 
 nlp = COMBO.from_pretrained("polish-herbert-base")
-sentence = nlp("Moje zdanie.")
-print(sentence.tokens)
+sentence = nlp("COVID-19 to ostra choroba zakaźna układu oddechowego wywołana zakażeniem wirusem SARS-CoV-2.")
+```
+Predictions are accessible as a list of token attributes:
+```python
+print("{:5} {:15} {:15} {:10} {:10} {:10}".format('ID', 'TOKEN', 'LEMMA', 'UPOS', 'HEAD', 'DEPREL'))
+for token in sentence.tokens:
+    print("{:5} {:15} {:15} {:10} {:10} {:10}".format(str(token.id), token.token, token.lemma, token.upostag, str(token.head), token.deprel))
 ```
 
 ## Details
@@ -31,4 +36,3 @@ print(sentence.tokens)
 - [**Pre-trained models**](docs/models.md)
 - [**Training**](docs/training.md)
 - [**Prediction**](docs/prediction.md)
-
diff --git a/docs/models.md b/docs/models.md
index d4346ff..25a7f70 100644
--- a/docs/models.md
+++ b/docs/models.md
@@ -1,19 +1,26 @@
 # Models
 
-Pre-trained models are available [here](http://mozart.ipipan.waw.pl/~mklimaszewski/models/).
+COMBO provides pre-trained models for:
+- morphosyntactic prediction (i.e. part-of-speech tagging, morphosyntactic analysis, lemmatisation and dependency parsing) trained on the treebanks from the [Universal Dependencies repository](https://universaldependencies.org),
+- enhanced dependency parsing trained on the IWPT 2020 shared task [data](https://universaldependencies.org/iwpt20/data.html).
+
+## Manual download
+
+The pre-trained models can be downloaded from [here](http://mozart.ipipan.waw.pl/~mklimaszewski/models/).
+
+
+If you want to use the console version of COMBO, you need to download a pre-trained model manually:
+```bash
+wget http://mozart.ipipan.waw.pl/~mklimaszewski/models/polish-herbert-base.tar.gz
+```
+
+The downloaded model should be passed as a parameter to COMBO (see the [prediction doc](prediction.md)).
 
 ## Automatic download
-Python `from_pretrained` method will download the pre-trained model if the provided name (without the extension .tar.gz) matches one of the names in [here](http://mozart.ipipan.waw.pl/~mklimaszewski/models/).
+The pre-trained models can be downloaded automatically with the Python `from_pretrained` method. Select a model name (without the .tar.gz extension) from the list of [pre-trained models](http://mozart.ipipan.waw.pl/~mklimaszewski/models/) and pass it as the argument to the `from_pretrained` method:
 ```python
 from combo.predict import COMBO
 
 nlp = COMBO.from_pretrained("polish-herbert-base")
 ```
-Otherwise it looks for a model in local env.
-
-## Console prediction/Local model
-If you want to use the console version of COMBO, you need to download a pre-trained model manually
-```bash
-wget http://mozart.ipipan.waw.pl/~mklimaszewski/models/polish-herbert-base.tar.gz
-```
-and pass it as a parameter (see [prediction doc](prediction.md)).
+If the model name doesn't match any entry on the list of [pre-trained models](http://mozart.ipipan.waw.pl/~mklimaszewski/models/), COMBO tries to load the model from a local path instead.
diff --git a/docs/training.md b/docs/training.md
index f4b6d38..d3f69e0 100644
--- a/docs/training.md
+++ b/docs/training.md
@@ -1,6 +1,6 @@
 # Training
 
-Command:
+Basic command:
 ```bash
 combo --mode train \
       --training_data_path your_training_path \
@@ -32,13 +32,13 @@ Examples (for clarity without training/validation data paths):
     combo --mode train --pretrained_transformer_name your_choosen_pretrained_transformer
     ```
 
-* predict only dependency tree:
+* train only a dependency parser:
 
     ```bash
     combo --mode train --targets head,deprel
     ```
 
-* use part-of-speech tags for predicting only dependency tree
+* use additional features (e.g. part-of-speech tags) when training a dependency parser (`token` and `char` are the default features):
 
     ```bash
     combo --mode train --targets head,deprel --features token,char,upostag
-- 
GitLab