From 520801e2a0741ceea6d66ee1ae171a77aaf205e7 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Maja=20Jab=C5=82o=C5=84ska?= <majajjablonska@gmail.com>
Date: Tue, 7 Mar 2023 19:38:16 +0100
Subject: [PATCH] General structure

---
 .idea/combolightning.iml                      |   2 +-
 .idea/misc.xml                                |   2 +-
 .idea/vcs.xml                                 |   6 +
 combo/checks.py                               |   4 -
 combo/commands/__init__.py                    |   1 +
 combo/commands/train.py                       |  10 +
 combo/data/__init__.py                        |   1 +
 combo/data/dataset.py                         | 314 +-----------------
 combo/data/fields/__init__.py                 |   1 +
 combo/data/fields/base_field.py               |   5 +
 .../data/fields/sequence_multilabel_field.py  |  40 +--
 combo/data/samplers/__init__.py               |   1 +
 combo/data/samplers/base_sampler.py           |   5 +
 combo/data/samplers/samplers.py               |   5 +-
 combo/data/token.py                           |   2 +
 combo/data/token_indexers/__init__.py         |   1 +
 combo/data/token_indexers/base_indexer.py     |   9 +
 ...etrained_transformer_mismatched_indexer.py | 118 +------
 .../token_characters_indexer.py               |  62 +---
 .../token_indexers/token_features_indexer.py  |  74 +----
 combo/main.py                                 | 150 +++++++++
 combo/models/__init__.py                      |   8 +
 combo/models/base.py                          |   6 +-
 combo/models/embeddings.py                    |  24 ++
 combo/models/encoder.py                       |  10 +
 combo/models/graph_parser.py                  |   9 +
 combo/models/lemma.py                         |   5 +
 combo/models/model.py                         |   5 +
 combo/models/morpho.py                        |   5 +
 combo/models/parser.py                        |   9 +
 combo/predict.py                              |  15 +
 combo/training/__init__.py                    |   3 +
 combo/training/checkpointer.py                |   6 +
 combo/training/scheduler.py                   |   2 +
 combo/training/tensorboard_writer.py          |   2 +
 combo/training/trainer.py                     |  13 +
 combo/utils/__init__.py                       |   0
 combo/utils/checks.py                         |  15 +
 combo/utils/download.py                       |  78 +++++
 combo/utils/graph.py                          | 149 +++++++++
 combo/utils/metrics.py                        |  17 +
 requirements.txt                              |   6 +-
 42 files changed, 607 insertions(+), 593 deletions(-)
 create mode 100644 .idea/vcs.xml
 delete mode 100644 combo/checks.py
 create mode 100644 combo/commands/__init__.py
 create mode 100644 combo/commands/train.py
 create mode 100644 combo/data/fields/base_field.py
 create mode 100644 combo/data/samplers/base_sampler.py
 create mode 100644 combo/data/token.py
 create mode 100644 combo/data/token_indexers/base_indexer.py
 create mode 100644 combo/main.py
 create mode 100644 combo/models/encoder.py
 create mode 100644 combo/models/graph_parser.py
 create mode 100644 combo/models/lemma.py
 create mode 100644 combo/models/model.py
 create mode 100644 combo/models/morpho.py
 create mode 100644 combo/models/parser.py
 create mode 100644 combo/predict.py
 create mode 100644 combo/training/__init__.py
 create mode 100644 combo/training/checkpointer.py
 create mode 100644 combo/training/scheduler.py
 create mode 100644 combo/training/tensorboard_writer.py
 create mode 100644 combo/training/trainer.py
 create mode 100644 combo/utils/__init__.py
 create mode 100644 combo/utils/checks.py
 create mode 100644 combo/utils/download.py
 create mode 100644 combo/utils/graph.py
 create mode 100644 combo/utils/metrics.py

diff --git a/.idea/combolightning.iml b/.idea/combolightning.iml
index 4154233..332f523 100644
--- a/.idea/combolightning.iml
+++ b/.idea/combolightning.iml
@@ -2,7 +2,7 @@
 <module type="PYTHON_MODULE" version="4">
   <component name="NewModuleRootManager">
     <content url="file://$MODULE_DIR$" />
-    <orderEntry type="jdk" jdkName="Python 3.9 (combolightning)" jdkType="Python SDK" />
+    <orderEntry type="jdk" jdkName="combo" jdkType="Python SDK" />
     <orderEntry type="sourceFolder" forTests="false" />
   </component>
 </module>
\ No newline at end of file
diff --git a/.idea/misc.xml b/.idea/misc.xml
index fcbd09d..d1d59e6 100644
--- a/.idea/misc.xml
+++ b/.idea/misc.xml
@@ -1,4 +1,4 @@
 <?xml version="1.0" encoding="UTF-8"?>
 <project version="4">
-  <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.9 (combolightning)" project-jdk-type="Python SDK" />
+  <component name="ProjectRootManager" version="2" project-jdk-name="combo" project-jdk-type="Python SDK" />
 </project>
\ No newline at end of file
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
new file mode 100644
index 0000000..35eb1dd
--- /dev/null
+++ b/.idea/vcs.xml
@@ -0,0 +1,6 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="VcsDirectoryMappings">
+    <mapping directory="" vcs="Git" />
+  </component>
+</project>
\ No newline at end of file
diff --git a/combo/checks.py b/combo/checks.py
deleted file mode 100644
index 7736e96..0000000
--- a/combo/checks.py
+++ /dev/null
@@ -1,4 +0,0 @@
-class ConfigurationError(Exception):
-    def __init__(self, message: str):
-        super().__init__()
-        self.message = message
diff --git a/combo/commands/__init__.py b/combo/commands/__init__.py
new file mode 100644
index 0000000..83ff4f4
--- /dev/null
+++ b/combo/commands/__init__.py
@@ -0,0 +1 @@
+from .train import FinetuningTrainModel
\ No newline at end of file
diff --git a/combo/commands/train.py b/combo/commands/train.py
new file mode 100644
index 0000000..f4bb807
--- /dev/null
+++ b/combo/commands/train.py
@@ -0,0 +1,10 @@
+from pytorch_lightning import Trainer
+
+
+class FinetuningTrainModel(Trainer):
+    """
+    Class used only for finetuning;
+    the only difference is that it saves the vocabulary built
+    from the concatenated (archive and current) datasets.
+    """
+    pass
\ No newline at end of file
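
A minimal sketch of the finetuning behaviour described in the docstring above,
assuming an AllenNLP-style vocabulary API (`from_instances`, `save_to_files`)
that this patch has not ported yet; names here are illustrative only:

    from itertools import chain

    from pytorch_lightning import Trainer


    class FinetuningTrainModel(Trainer):

        def save_vocabulary(self, archive_instances, current_instances,
                            vocabulary_cls, directory):
            # Build the vocabulary over both datasets (archive and current),
            # so labels seen during pretraining survive finetuning.
            vocab = vocabulary_cls.from_instances(
                chain(archive_instances, current_instances))
            vocab.save_to_files(directory)
            return vocab
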
diff --git a/combo/data/__init__.py b/combo/data/__init__.py
index 91426e3..2d69cca 100644
--- a/combo/data/__init__.py
+++ b/combo/data/__init__.py
@@ -1,3 +1,4 @@
 from .samplers import TokenCountBatchSampler
+from .token import Token
 from .token_indexers import *
 from .api import *
diff --git a/combo/data/dataset.py b/combo/data/dataset.py
index 0c34d3f..62ca30f 100644
--- a/combo/data/dataset.py
+++ b/combo/data/dataset.py
@@ -1,317 +1,11 @@
-import copy
 import logging
-import pathlib
-from typing import Union, List, Dict, Iterable, Optional, Any, Tuple
 
-import conllu
-import torch
-from allennlp import data as allen_data
-from allennlp.common import checks, util
-from allennlp.data import fields as allen_fields, vocabulary
-from conllu import parser
-from dataclasses import dataclass
-from overrides import overrides
-
-from combo.data import fields
 
 logger = logging.getLogger(__name__)
 
 
-@allen_data.DatasetReader.register("conllu")
-class UniversalDependenciesDatasetReader(allen_data.DatasetReader):
-
-    def __init__(
-            self,
-            token_indexers: Dict[str, allen_data.TokenIndexer] = None,
-            lemma_indexers: Dict[str, allen_data.TokenIndexer] = None,
-            features: List[str] = None,
-            targets: List[str] = None,
-            use_sem: bool = False,
-            **kwargs,
-    ) -> None:
-        super().__init__(**kwargs)
-        if features is None:
-            features = ["token", "char"]
-        if targets is None:
-            targets = ["head", "deprel", "upostag", "xpostag", "lemma", "feats"]
-
-        if "token" not in features and "char" not in features:
-            raise checks.ConfigurationError("There must be at least one ('char' or 'token') text-based feature!")
-
-        if "deps" in targets and not ("head" in targets and "deprel" in targets):
-            raise checks.ConfigurationError("Add 'head' and 'deprel' to targets when using 'deps'!")
-
-        intersection = set(features).intersection(set(targets))
-        if len(intersection) != 0:
-            raise checks.ConfigurationError(
-                "Features and targets cannot share elements! "
-                "Remove {} from either features or targets.".format(intersection)
-            )
-        self.use_sem = use_sem
-
-        # *.conllu readers config
-        fields = list(parser.DEFAULT_FIELDS)
-        fields[1] = "token"  # use 'token' instead of 'form'
-        field_parsers = parser.DEFAULT_FIELD_PARSERS
-        # Do not make it nullable
-        field_parsers.pop("xpostag", None)
-        # Ignore parsing misc
-        field_parsers.pop("misc", None)
-        if self.use_sem:
-            fields = list(fields)
-            fields.append("semrel")
-            field_parsers["semrel"] = lambda line, i: line[i]
-        self.field_parsers = field_parsers
-        self.fields = tuple(fields)
-
-        self._token_indexers = token_indexers
-        self._lemma_indexers = lemma_indexers
-        self._targets = targets
-        self._features = features
-        self.generate_labels = True
-        # Filter out not required token indexers to avoid
-        # Mismatched token keys ConfigurationError
-        for indexer_name in list(self._token_indexers.keys()):
-            if indexer_name not in self._features:
-                del self._token_indexers[indexer_name]
-
-    @overrides
-    def _read(self, file_path: str) -> Iterable[allen_data.Instance]:
-        file_path = [file_path] if len(file_path.split(",")) == 0 else file_path.split(",")
-
-        for conllu_file in file_path:
-            file = pathlib.Path(conllu_file)
-            assert conllu_file and file.exists(), f"File with path '{conllu_file}' does not exists!"
-            with file.open("r", encoding="utf-8") as f:
-                for annotation in conllu.parse_incr(f, fields=self.fields, field_parsers=self.field_parsers):
-                    yield self.text_to_instance(annotation)
-
-    @overrides
-    def text_to_instance(self, tree: conllu.TokenList) -> allen_data.Instance:
-        fields_: Dict[str, allen_data.Field] = {}
-        tree_tokens = [t for t in tree if isinstance(t["id"], int)]
-        tokens = [_Token(t["token"],
-                         pos_=t.get("upostag"),
-                         tag_=t.get("xpostag"),
-                         lemma_=t.get("lemma"),
-                         feats_=t.get("feats"))
-                  for t in tree_tokens]
-
-        # features
-        text_field = allen_fields.TextField(tokens, self._token_indexers)
-        fields_["sentence"] = text_field
-
-        # targets
-        if self.generate_labels:
-            for target_name in self._targets:
-                if target_name != "sent":
-                    target_values = [t[target_name] for t in tree_tokens]
-                    if target_name == "lemma":
-                        target_values = [allen_data.Token(v) for v in target_values]
-                        fields_[target_name] = allen_fields.TextField(target_values, self._lemma_indexers)
-                    elif target_name == "feats":
-                        target_values = self._feat_values(tree_tokens)
-                        fields_[target_name] = fields.SequenceMultiLabelField(target_values,
-                                                                              self._feats_indexer,
-                                                                              self._feats_as_tensor_wrapper,
-                                                                              text_field,
-                                                                              label_namespace="feats_labels")
-                    elif target_name == "head":
-                        target_values = [0 if v == "_" else int(v) for v in target_values]
-                        fields_[target_name] = allen_fields.SequenceLabelField(target_values, text_field,
-                                                                               label_namespace=target_name + "_labels")
-                    elif target_name == "deps":
-                        # Graphs require adding ROOT (AdjacencyField uses sequence length from TextField).
-                        text_field_deps = allen_fields.TextField([_Token("ROOT")] + copy.deepcopy(tokens),
-                                                                 self._token_indexers)
-                        enhanced_heads: List[Tuple[int, int]] = []
-                        enhanced_deprels: List[str] = []
-                        for idx, t in enumerate(tree_tokens):
-                            t_deps = t["deps"]
-                            if t_deps and t_deps != "_":
-                                for rel, head in t_deps:
-                                    # EmoryNLP skips the first edge, if there are two edges between the same
-                                    # nodes. Thanks to that one is in a tree and another in a graph.
-                                    # This snippet follows that approach.
-                                    if enhanced_heads and enhanced_heads[-1] == (idx, head):
-                                        enhanced_heads.pop()
-                                        enhanced_deprels.pop()
-                                    enhanced_heads.append((idx, head))
-                                    enhanced_deprels.append(rel)
-                        fields_["enhanced_heads"] = allen_fields.AdjacencyField(
-                            indices=enhanced_heads,
-                            sequence_field=text_field_deps,
-                            label_namespace="enhanced_heads_labels",
-                            padding_value=0,
-                        )
-                        fields_["enhanced_deprels"] = allen_fields.AdjacencyField(
-                            indices=enhanced_heads,
-                            sequence_field=text_field_deps,
-                            labels=enhanced_deprels,
-                            # Label namespace matches regular tree parsing.
-                            label_namespace="enhanced_deprel_labels",
-                            padding_value=0,
-                        )
-                    else:
-                        fields_[target_name] = allen_fields.SequenceLabelField(target_values, text_field,
-                                                                               label_namespace=target_name + "_labels")
-
-        # Restore feats fields to string representation
-        # parser.serialize_field doesn't handle key without value
-        for token in tree.tokens:
-            if "feats" in token:
-                feats = token["feats"]
-                if feats:
-                    feats_values = []
-                    for k, v in feats.items():
-                        feats_values.append('='.join((k, v)) if v else k)
-                    field = "|".join(feats_values)
-                else:
-                    field = "_"
-                token["feats"] = field
-
-        # metadata
-        fields_["metadata"] = allen_fields.MetadataField({"input": tree,
-                                                          "field_names": self.fields,
-                                                          "tokens": tokens})
-
-        return allen_data.Instance(fields_)
-
-    @staticmethod
-    def _feat_values(tree: List[Dict[str, Any]]):
-        features = []
-        for token in tree:
-            token_features = []
-            if token["feats"] is not None:
-                for feat, value in token["feats"].items():
-                    if feat in ["_", "__ROOT__"]:
-                        pass
-                    else:
-                        # Handle case where feature is binary (doesn't have associated value)
-                        if value:
-                            token_features.append(feat + "=" + value)
-                        else:
-                            token_features.append(feat)
-            features.append(token_features)
-        return features
-
-    @staticmethod
-    def _feats_as_tensor_wrapper(field: fields.SequenceMultiLabelField):
-        def as_tensor(padding_lengths):
-            desired_num_tokens = padding_lengths["num_tokens"]
-            assert len(field._indexed_multi_labels) > 0
-            classes_count = len(field._indexed_multi_labels[0])
-            default_value = [0.0] * classes_count
-            padded_tags = util.pad_sequence_to_length(field._indexed_multi_labels, desired_num_tokens,
-                                                      lambda: default_value)
-            tensor = torch.tensor(padded_tags, dtype=torch.long)
-            return tensor
-
-        return as_tensor
-
-    @staticmethod
-    def _feats_indexer(vocab: allen_data.Vocabulary):
-        label_namespace = "feats_labels"
-        vocab_size = vocab.get_vocab_size(label_namespace)
-        slices = get_slices_if_not_provided(vocab)
-
-        def _m_from_n_ones_encoding(multi_label: List[str], sentence_length: int) -> List[int]:
-            one_hot_encoding = [0] * vocab_size
-            for cat, cat_indices in slices.items():
-                if cat not in ["__PAD__", "_"]:
-                    label_from_cat = [label for label in multi_label if cat == label.split("=")[0]]
-                    if label_from_cat:
-                        label_from_cat = label_from_cat[0]
-                        index = vocab.get_token_index(label_from_cat, label_namespace)
-                    else:
-                        # Get Cat=None index
-                        index = vocab.get_token_index(cat + "=None", label_namespace)
-                    one_hot_encoding[index] = 1
-            return one_hot_encoding
-
-        return _m_from_n_ones_encoding
-
-
-@allen_data.Vocabulary.register("from_instances_extended", constructor="from_instances_extended")
-class Vocabulary(allen_data.Vocabulary):
-
-    @classmethod
-    def from_instances_extended(
-            cls,
-            instances: Iterable[allen_data.Instance],
-            min_count: Dict[str, int] = None,
-            max_vocab_size: Union[int, Dict[str, int]] = None,
-            non_padded_namespaces: Iterable[str] = vocabulary.DEFAULT_NON_PADDED_NAMESPACES,
-            pretrained_files: Optional[Dict[str, str]] = None,
-            only_include_pretrained_words: bool = False,
-            min_pretrained_embeddings: Dict[str, int] = None,
-            padding_token: Optional[str] = vocabulary.DEFAULT_PADDING_TOKEN,
-            oov_token: Optional[str] = vocabulary.DEFAULT_OOV_TOKEN,
-    ) -> "Vocabulary":
-        """
-        Extension to manually fill gaps in missing 'feats_labels'.
-        """
-        # Load manually tokens from pretrained file (using different strategy
-        # - only words add all embedding file, without checking if were seen
-        # in any dataset.
-        tokens_to_add = None
-        if pretrained_files and "tokens" in pretrained_files:
-            pretrained_set = set(vocabulary._read_pretrained_tokens(pretrained_files["tokens"]))
-            tokens_to_add = {"tokens": list(pretrained_set)}
-            pretrained_files = None
-
-        vocab = super().from_instances(
-            instances=instances,
-            min_count=min_count,
-            max_vocab_size=max_vocab_size,
-            non_padded_namespaces=non_padded_namespaces,
-            pretrained_files=pretrained_files,
-            only_include_pretrained_words=only_include_pretrained_words,
-            tokens_to_add=tokens_to_add,
-            min_pretrained_embeddings=min_pretrained_embeddings,
-            padding_token=padding_token,
-            oov_token=oov_token
-        )
-        # Extending vocab with features that does not show up explicitly.
-        # To know all features we need to read full dataset first.
-        # Adding auxiliary '=None' feature for each category is needed
-        # to perform classification.
-        get_slices_if_not_provided(vocab)
-        return vocab
-
-
-def get_slices_if_not_provided(vocab: allen_data.Vocabulary):
-    if hasattr(vocab, "slices"):
-        return vocab.slices
-
-    if "feats_labels" in vocab.get_namespaces():
-        idx2token = vocab.get_index_to_token_vocabulary("feats_labels")
-        for _, v in dict(idx2token).items():
-            if v not in ["_", "__PAD__"]:
-                empty_value = v.split("=")[0] + "=None"
-                vocab.add_token_to_namespace(empty_value, "feats_labels")
-
-        slices = {}
-        for idx, name in vocab.get_index_to_token_vocabulary("feats_labels").items():
-            # There are 2 types features: with (Case=Acc) or without assigment (None).
-            # Here we group their indices by name (before assigment sign).
-            name = name.split("=")[0]
-            if name in slices:
-                slices[name].append(idx)
-            else:
-                slices[name] = [idx]
-        vocab.slices = slices
-        return vocab.slices
-
-
-@dataclass(init=False, repr=False)
-class _Token(allen_data.Token):
-    __slots__ = allen_data.Token.__slots__ + ['feats_']
-
-    feats_: Optional[str]
+class DatasetReader:
+    pass
 
-    def __init__(self, text: str = None, idx: int = None, idx_end: int = None, lemma_: str = None, pos_: str = None,
-                 tag_: str = None, dep_: str = None, ent_type_: str = None, text_id: int = None, type_id: int = None,
-                 feats_: str = None) -> None:
-        super().__init__(text, idx, idx_end, lemma_, pos_, tag_, dep_, ent_type_, text_id, type_id)
-        self.feats_ = feats_
+class UniversalDependenciesDatasetReader(DatasetReader):
+    pass
\ No newline at end of file
diff --git a/combo/data/fields/__init__.py b/combo/data/fields/__init__.py
index 4f8d3a2..57de344 100644
--- a/combo/data/fields/__init__.py
+++ b/combo/data/fields/__init__.py
@@ -1 +1,2 @@
+from .base_field import Field
 from .sequence_multilabel_field import SequenceMultiLabelField
diff --git a/combo/data/fields/base_field.py b/combo/data/fields/base_field.py
new file mode 100644
index 0000000..83ea563
--- /dev/null
+++ b/combo/data/fields/base_field.py
@@ -0,0 +1,5 @@
+from abc import ABCMeta
+
+
+class Field(metaclass=ABCMeta):
+    pass
diff --git a/combo/data/fields/sequence_multilabel_field.py b/combo/data/fields/sequence_multilabel_field.py
index c31f78e..b7a6bbe 100644
--- a/combo/data/fields/sequence_multilabel_field.py
+++ b/combo/data/fields/sequence_multilabel_field.py
@@ -1,18 +1,16 @@
 """Sequence multilabel field implementation."""
 import logging
-import textwrap
 from typing import Set, List, Callable, Iterator, Union, Dict
 
 import torch
-from allennlp import data
-from allennlp.common import checks
-from allennlp.data import fields
 from overrides import overrides
 
+from combo.data.fields import Field
+
 logger = logging.getLogger(__name__)
 
 
-class SequenceMultiLabelField(data.Field[torch.Tensor]):
+class SequenceMultiLabelField(Field):
     """
     A `SequenceMultiLabelField` is an extension of the :class:`MultiLabelField` that allows for multiple labels
     while keeping sequence dimension.
@@ -93,43 +91,23 @@ class SequenceMultiLabelField(data.Field[torch.Tensor]):
 
     @overrides
     def count_vocab_items(self, counter: Dict[str, Dict[str, int]]):
-        if self._indexed_multi_labels is None:
-            for multi_label in self.multi_labels:
-                for label in multi_label:
-                    counter[self._label_namespace][label] += 1  # type: ignore
+        pass
 
     @overrides
-    def index(self, vocab: data.Vocabulary):
+    def index(self, vocab):
-        indexer = self.multi_label_indexer(vocab)
-
-        indexed = []
-        for multi_label in self.multi_labels:
-            indexed.append(indexer(multi_label, len(self.multi_labels)))
-        self._indexed_multi_labels = indexed
+        pass
 
     @overrides
     def get_padding_lengths(self) -> Dict[str, int]:
-        return {"num_tokens": self.sequence_field.sequence_length()}
+        pass
 
     @overrides
     def as_tensor(self, padding_lengths: Dict[str, int]) -> torch.Tensor:
-        return self.as_tensor_wrapper(padding_lengths)
+        pass
 
     @overrides
     def empty_field(self) -> "SequenceMultiLabelField":
-        empty_list: List[List[str]] = [[]]
-        sequence_label_field = SequenceMultiLabelField(empty_list, lambda x: lambda y: y,
-                                                       lambda x: lambda y: y,
-                                                       self.sequence_field.empty_field())
-        sequence_label_field._indexed_labels = empty_list
-        return sequence_label_field
+        pass
 
     def __str__(self) -> str:
-        length = self.sequence_field.sequence_length()
-        formatted_labels = "".join(
-            "\t\t" + labels + "\n" for labels in textwrap.wrap(repr(self.multi_labels), 100)
-        )
-        return (
-            f"SequenceMultiLabelField of length {length} with "
-            f"labels:\n {formatted_labels} \t\tin namespace: '{self._label_namespace}'."
-        )
+        pass
diff --git a/combo/data/samplers/__init__.py b/combo/data/samplers/__init__.py
index ab003fe..1db58a9 100644
--- a/combo/data/samplers/__init__.py
+++ b/combo/data/samplers/__init__.py
@@ -1 +1,2 @@
+from .base_sampler import Sampler
 from .samplers import TokenCountBatchSampler
diff --git a/combo/data/samplers/base_sampler.py b/combo/data/samplers/base_sampler.py
new file mode 100644
index 0000000..6e5cd40
--- /dev/null
+++ b/combo/data/samplers/base_sampler.py
@@ -0,0 +1,5 @@
+from abc import ABCMeta
+
+
+class Sampler(metaclass=ABCMeta):
+    pass
diff --git a/combo/data/samplers/samplers.py b/combo/data/samplers/samplers.py
index 5db74d9..ee32a5e 100644
--- a/combo/data/samplers/samplers.py
+++ b/combo/data/samplers/samplers.py
@@ -2,11 +2,10 @@ from typing import List
 
 import numpy as np
 
-from allennlp import data as allen_data
+from combo.data.samplers import Sampler
 
 
-@allen_data.BatchSampler.register("token_count")
-class TokenCountBatchSampler(allen_data.BatchSampler):
+class TokenCountBatchSampler(Sampler):
 
     def __init__(self, dataset, word_batch_size: int = 2500, shuffle_dataset: bool = True):
         self._index = 0
diff --git a/combo/data/token.py b/combo/data/token.py
new file mode 100644
index 0000000..048fab0
--- /dev/null
+++ b/combo/data/token.py
@@ -0,0 +1,2 @@
+class Token:
+    pass
diff --git a/combo/data/token_indexers/__init__.py b/combo/data/token_indexers/__init__.py
index 550a80b..d179540 100644
--- a/combo/data/token_indexers/__init__.py
+++ b/combo/data/token_indexers/__init__.py
@@ -1,3 +1,4 @@
+from .base_indexer import TokenIndexer
 from .pretrained_transformer_mismatched_indexer import PretrainedTransformerMismatchedIndexer
 from .token_characters_indexer import TokenCharactersIndexer
 from .token_features_indexer import TokenFeatsIndexer
diff --git a/combo/data/token_indexers/base_indexer.py b/combo/data/token_indexers/base_indexer.py
new file mode 100644
index 0000000..2fb48c0
--- /dev/null
+++ b/combo/data/token_indexers/base_indexer.py
@@ -0,0 +1,9 @@
+from abc import ABCMeta
+
+
+class TokenIndexer(metaclass=ABCMeta):
+    pass
+
+
+class PretrainedTransformerMismatchedIndexer(TokenIndexer):
+    pass
diff --git a/combo/data/token_indexers/pretrained_transformer_mismatched_indexer.py b/combo/data/token_indexers/pretrained_transformer_mismatched_indexer.py
index fc29896..9aa3616 100644
--- a/combo/data/token_indexers/pretrained_transformer_mismatched_indexer.py
+++ b/combo/data/token_indexers/pretrained_transformer_mismatched_indexer.py
@@ -1,117 +1,13 @@
-from typing import Optional, Dict, Any, List, Tuple
+from combo.data.token_indexers.base_indexer import TokenIndexer
 
-from allennlp import data
-from allennlp.data import token_indexers, tokenizers, IndexedTokenList, vocabulary
-from overrides import overrides
 
+class PretrainedTransformerMismatchedIndexer(TokenIndexer):
+    pass
 
-@data.TokenIndexer.register("pretrained_transformer_mismatched_fixed")
-class PretrainedTransformerMismatchedIndexer(token_indexers.PretrainedTransformerMismatchedIndexer):
 
-    def __init__(self, model_name: str, namespace: str = "tags", max_length: int = None,
-                 tokenizer_kwargs: Optional[Dict[str, Any]] = None, **kwargs) -> None:
-        # The matched version v.s. mismatched
-        super().__init__(model_name, namespace, max_length, tokenizer_kwargs, **kwargs)
-        self._matched_indexer = PretrainedTransformerIndexer(
-            model_name,
-            namespace=namespace,
-            max_length=max_length,
-            tokenizer_kwargs=tokenizer_kwargs,
-            **kwargs,
-        )
-        self._allennlp_tokenizer = self._matched_indexer._allennlp_tokenizer
-        self._tokenizer = self._matched_indexer._tokenizer
-        self._num_added_start_tokens = self._matched_indexer._num_added_start_tokens
-        self._num_added_end_tokens = self._matched_indexer._num_added_end_tokens
+class PretrainedTransformerIndexer(TokenIndexer):
+    pass
 
-    @overrides
-    def tokens_to_indices(self,
-                          tokens,
-                          vocabulary: vocabulary ) -> IndexedTokenList:
-        """
-        Method is overridden in order to raise an error while the number of tokens needed to embed a sentence exceeds the
-        maximal input of a model.
-        """
-        self._matched_indexer._add_encoding_to_vocabulary_if_needed(vocabulary)
 
-        wordpieces, offsets = self._allennlp_tokenizer.intra_word_tokenize(
-            [t.ensure_text() for t in tokens])
-
-        if len(wordpieces) > self._tokenizer.max_len_single_sentence:
-            raise ValueError("Following sentence consists of more wordpiece tokens that the model can process:\n" +\
-                             " ".join([str(x) for x in tokens[:10]]) + " ... \n" + \
-                             f"Maximal input: {self._tokenizer.max_len_single_sentence}\n"+ \
-                             f"Current input: {len(wordpieces)}")
-
-        offsets = [x if x is not None else (-1, -1) for x in offsets]
-
-        output: IndexedTokenList = {
-            "token_ids": [t.text_id for t in wordpieces],
-            "mask": [True] * len(tokens),  # for original tokens (i.e. word-level)
-            "type_ids": [t.type_id for t in wordpieces],
-            "offsets": offsets,
-            "wordpiece_mask": [True] * len(wordpieces),  # for wordpieces (i.e. subword-level)
-        }
-
-        return self._matched_indexer._postprocess_output(output)
-
-
-class PretrainedTransformerIndexer(token_indexers.PretrainedTransformerIndexer):
-
-    def __init__(
-            self,
-            model_name: str,
-            namespace: str = "tags",
-            max_length: int = None,
-            tokenizer_kwargs: Optional[Dict[str, Any]] = None,
-            **kwargs,
-    ) -> None:
-        super().__init__(model_name, namespace, max_length, tokenizer_kwargs, **kwargs)
-        self._namespace = namespace
-        self._allennlp_tokenizer = PretrainedTransformerTokenizer(
-            model_name, tokenizer_kwargs=tokenizer_kwargs
-        )
-        self._tokenizer = self._allennlp_tokenizer.tokenizer
-        self._added_to_vocabulary = False
-
-        self._num_added_start_tokens = len(self._allennlp_tokenizer.single_sequence_start_tokens)
-        self._num_added_end_tokens = len(self._allennlp_tokenizer.single_sequence_end_tokens)
-
-        self._max_length = max_length
-        if self._max_length is not None:
-            num_added_tokens = len(self._allennlp_tokenizer.tokenize("a")) - 1
-            self._effective_max_length = (  # we need to take into account special tokens
-                    self._max_length - num_added_tokens
-            )
-            if self._effective_max_length <= 0:
-                raise ValueError(
-                    "max_length needs to be greater than the number of special tokens inserted."
-                )
-
-
-class PretrainedTransformerTokenizer(tokenizers.PretrainedTransformerTokenizer):
-
-    def _intra_word_tokenize(
-            self, string_tokens: List[str]
-    ) -> Tuple[List[data.Token], List[Optional[Tuple[int, int]]]]:
-        tokens: List[data.Token] = []
-        offsets: List[Optional[Tuple[int, int]]] = []
-        for token_string in string_tokens:
-            wordpieces = self.tokenizer.encode_plus(
-                token_string,
-                add_special_tokens=False,
-                return_tensors=None,
-                return_offsets_mapping=False,
-                return_attention_mask=False,
-            )
-            wp_ids = wordpieces["input_ids"]
-
-            if len(wp_ids) > 0:
-                offsets.append((len(tokens), len(tokens) + len(wp_ids) - 1))
-                tokens.extend(
-                    data.Token(text=wp_text, text_id=wp_id)
-                    for wp_id, wp_text in zip(wp_ids, self.tokenizer.convert_ids_to_tokens(wp_ids))
-                )
-            else:
-                offsets.append(None)
-        return tokens, offsets
+class PretrainedTransformerTokenizer(TokenIndexer):
+    pass
diff --git a/combo/data/token_indexers/token_characters_indexer.py b/combo/data/token_indexers/token_characters_indexer.py
index ea7a3ea..6f5dbf0 100644
--- a/combo/data/token_indexers/token_characters_indexer.py
+++ b/combo/data/token_indexers/token_characters_indexer.py
@@ -1,62 +1,6 @@
-"""Custom character token indexer."""
-import itertools
-from typing import List, Dict
+from combo.data.token_indexers.base_indexer import TokenIndexer
 
-import torch
-from allennlp import data
-from allennlp.common import util
-from allennlp.data import tokenizers
-from allennlp.data.token_indexers import token_characters_indexer
-from overrides import overrides
 
-
-@data.TokenIndexer.register("characters_const_padding")
-class TokenCharactersIndexer(token_characters_indexer.TokenCharactersIndexer):
+class TokenCharactersIndexer(TokenIndexer):
     """Wrapper around allennlp token indexer with const padding."""
-
-    def __init__(self,
-                 namespace: str = "token_characters",
-                 character_tokenizer: tokenizers.CharacterTokenizer = tokenizers.CharacterTokenizer(),
-                 start_tokens: List[str] = None,
-                 end_tokens: List[str] = None,
-                 min_padding_length: int = 0,
-                 token_min_padding_length: int = 0):
-        super().__init__(namespace, character_tokenizer, start_tokens, end_tokens, min_padding_length,
-                         token_min_padding_length)
-
-    @overrides
-    def get_padding_lengths(self, indexed_tokens: data.IndexedTokenList) -> Dict[str, int]:
-        padding_lengths = {"token_characters": len(indexed_tokens["token_characters"]),
-                           "num_token_characters": self._min_padding_length}
-        return padding_lengths
-
-    @overrides
-    def as_padded_tensor_dict(
-            self, tokens: data.IndexedTokenList, padding_lengths: Dict[str, int]
-    ) -> Dict[str, torch.Tensor]:
-        # Pad the tokens.
-        padded_tokens = util.pad_sequence_to_length(
-            tokens["token_characters"],
-            padding_lengths["token_characters"],
-            default_value=lambda: [],
-        )
-
-        # Pad the characters within the tokens.
-        desired_token_length = padding_lengths["num_token_characters"]
-        longest_token: List[int] = max(tokens["token_characters"], key=len, default=[])  # type: ignore
-        padding_value = 0
-        if desired_token_length > len(longest_token):
-            # Since we want to pad to greater than the longest token, we add a
-            # "dummy token" so we can take advantage of the fast implementation of itertools.zip_longest.
-            padded_tokens.append([padding_value] * desired_token_length)
-        # pad the list of lists to the longest sublist, appending 0's
-        padded_tokens = list(zip(*itertools.zip_longest(*padded_tokens, fillvalue=padding_value)))
-        if desired_token_length > len(longest_token):
-            # Removes the "dummy token".
-            padded_tokens.pop()
-        # Truncates all the tokens to the desired length, and return the result.
-        return {
-            "token_characters": torch.LongTensor(
-                [list(token[:desired_token_length]) for token in padded_tokens]
-            )
-        }
+    pass
diff --git a/combo/data/token_indexers/token_features_indexer.py b/combo/data/token_indexers/token_features_indexer.py
index 7c59124..b6267a4 100644
--- a/combo/data/token_indexers/token_features_indexer.py
+++ b/combo/data/token_indexers/token_features_indexer.py
@@ -1,75 +1,7 @@
 """Features indexer."""
-import collections
-from typing import List, Dict
 
-import torch
-from allennlp import data
-from allennlp.common import util
-from overrides import overrides
+from combo.data.token_indexers.base_indexer import TokenIndexer
 
 
-@data.TokenIndexer.register("feats_indexer")
-class TokenFeatsIndexer(data.TokenIndexer):
-
-    def __init__(
-            self,
-            namespace: str = "feats",
-            feature_name: str = "feats_",
-            token_min_padding_length: int = 0,
-    ) -> None:
-        super().__init__(token_min_padding_length)
-        self.namespace = namespace
-        self._feature_name = feature_name
-
-    @overrides
-    def count_vocab_items(self, token: data.Token, counter: Dict[str, Dict[str, int]]):
-        feats = self._feat_values(token)
-        for feat in feats:
-            counter[self.namespace][feat] += 1
-
-    @overrides
-    def tokens_to_indices(self, tokens: List[data.Token], vocabulary: data.Vocabulary) -> data.IndexedTokenList:
-        indices: List[List[int]] = []
-        vocab_size = vocabulary.get_vocab_size(self.namespace)
-        for token in tokens:
-            token_indices = []
-            feats = self._feat_values(token)
-            for feat in feats:
-                token_indices.append(vocabulary.get_token_index(feat, self.namespace))
-            indices.append(util.pad_sequence_to_length(token_indices, vocab_size))
-        return {"tokens": indices}
-
-    @overrides
-    def get_empty_token_list(self) -> data.IndexedTokenList:
-        return {"tokens": [[]]}
-
-    def _feat_values(self, token):
-        feats = getattr(token, self._feature_name)
-        if feats is None:
-            feats = collections.OrderedDict()
-        features = []
-        for feat, value in feats.items():
-            if feat in ["_", "__ROOT__"]:
-                pass
-            else:
-                # Handle case where feature is binary (doesn't have associated value)
-                if value:
-                    features.append(feat + "=" + value)
-                else:
-                    features.append(feat)
-        return features
-
-    @overrides
-    def as_padded_tensor_dict(
-            self, tokens: data.IndexedTokenList, padding_lengths: Dict[str, int]
-    ) -> Dict[str, torch.Tensor]:
-        tensor_dict = {}
-        for key, val in tokens.items():
-            vocab_size = len(val[0])
-            tensor = torch.tensor(util.pad_sequence_to_length(val,
-                                                              padding_lengths[key],
-                                                              default_value=lambda: [0] * vocab_size,
-                                                              )
-                                  )
-            tensor_dict[key] = tensor
-        return tensor_dict
+class TokenFeatsIndexer(TokenIndexer):
+    pass
diff --git a/combo/main.py b/combo/main.py
new file mode 100644
index 0000000..1d4d945
--- /dev/null
+++ b/combo/main.py
@@ -0,0 +1,150 @@
+import logging
+import os
+import pathlib
+import tempfile
+from typing import Dict
+
+import torch
+from absl import app
+from absl import flags
+
+from combo import models
+from combo.models.base import Predictor
+from combo.utils import checks
+
+logger = logging.getLogger(__name__)
+_FEATURES = ["token", "char", "upostag", "xpostag", "lemma", "feats"]
+_TARGETS = ["deprel", "feats", "head", "lemma", "upostag", "xpostag", "semrel", "sent", "deps"]
+
+FLAGS = flags.FLAGS
+flags.DEFINE_enum(name="mode", default=None, enum_values=["train", "predict"],
+                  help="Specify COMBO mode: train or predict")
+
+# Common flags
+flags.DEFINE_integer(name="cuda_device", default=-1,
+                     help="Cuda device id (default -1 cpu)")
+flags.DEFINE_string(name="output_file", default="output.log",
+                    help="Predictions result file.")
+
+# Training flags
+flags.DEFINE_list(name="training_data_path", default=[],
+                  help="Training data path(s)")
+flags.DEFINE_alias(name="training_data", original_name="training_data_path")
+flags.DEFINE_list(name="validation_data_path", default="",
+                  help="Validation data path(s)")
+flags.DEFINE_alias(name="validation_data", original_name="validation_data_path")
+flags.DEFINE_string(name="pretrained_tokens", default="",
+                    help="Pretrained tokens embeddings path")
+flags.DEFINE_integer(name="embedding_dim", default=300,
+                     help="Embeddings dim")
+flags.DEFINE_integer(name="num_epochs", default=400,
+                     help="Epochs num")
+flags.DEFINE_integer(name="word_batch_size", default=2500,
+                     help="Minimum words in batch")
+flags.DEFINE_string(name="pretrained_transformer_name", default="",
+                    help="Pretrained transformer model name (see transformers from HuggingFace library for list of "
+                         "available models) for transformers based embeddings.")
+flags.DEFINE_list(name="features", default=["token", "char"],
+                  help=f"Features used to train model (required 'token' and 'char'). Possible values: {_FEATURES}.")
+flags.DEFINE_list(name="targets", default=["deprel", "feats", "head", "lemma", "upostag", "xpostag"],
+                  help=f"Targets of the model (required `deprel` and `head`). Possible values: {_TARGETS}.")
+flags.DEFINE_string(name="serialization_dir", default=None,
+                    help="Model serialization directory (default - system temp dir).")
+flags.DEFINE_boolean(name="tensorboard", default=False,
+                     help="When provided model will log tensorboard metrics.")
+
+# Finetune after training flags
+flags.DEFINE_list(name="finetuning_training_data_path", default="",
+                  help="Training data path(s)")
+flags.DEFINE_list(name="finetuning_validation_data_path", default="",
+                  help="Validation data path(s)")
+flags.DEFINE_string(name="config_path", default=str(pathlib.Path(__file__).parent / "config.template.jsonnet"),
+                    help="Config file path.")
+
+# Test after training flags
+flags.DEFINE_string(name="test_path", default=None,
+                    help="Test path file.")
+
+# Experimental
+flags.DEFINE_boolean(name="use_pure_config", default=False,
+                     help="Ignore ext flags (experimental).")
+
+# Prediction flags
+flags.DEFINE_string(name="model_path", default=None,
+                    help="Pretrained model path.")
+flags.DEFINE_string(name="input_file", default=None,
+                    help="File to predict path")
+flags.DEFINE_boolean(name="conllu_format", default=True,
+                     help="Prediction based on conllu format (instead of raw text).")
+flags.DEFINE_integer(name="batch_size", default=1,
+                     help="Prediction batch size.")
+flags.DEFINE_boolean(name="silent", default=True,
+                     help="Silent prediction to file (without printing to console).")
+flags.DEFINE_enum(name="predictor_name", default="combo-spacy",
+                  enum_values=["combo", "combo-spacy"],
+                  help="Use predictor with whitespace or spacy tokenizer.")
+
+
+def run(_):
+    pass
+
+
+def _get_predictor() -> Predictor:
+    # Check for GPU
+    # allen_checks.check_for_gpu(FLAGS.cuda_device)
+    checks.file_exists(FLAGS.model_path)
+    # load model from archive
+    # archive = models.load_archive(
+    #     FLAGS.model_path,
+    #     cuda_device=FLAGS.cuda_device,
+    # )
+    # return predictors.Predictor.from_archive(
+    #     archive, FLAGS.predictor_name
+    # )
+    return Predictor()
+
+
+def _get_ext_vars(finetuning: bool = False) -> Dict:
+    if FLAGS.use_pure_config:
+        return {}
+    return {
+        "training_data_path": (
+            ",".join(FLAGS.training_data_path if not finetuning else FLAGS.finetuning_training_data_path)),
+        "validation_data_path": (
+            ",".join(FLAGS.validation_data_path if not finetuning else FLAGS.finetuning_validation_data_path)),
+        "pretrained_tokens": FLAGS.pretrained_tokens,
+        "pretrained_transformer_name": FLAGS.pretrained_transformer_name,
+        "features": " ".join(FLAGS.features),
+        "targets": " ".join(FLAGS.targets),
+        "type": "finetuning" if finetuning else "default",
+        "embedding_dim": str(FLAGS.embedding_dim),
+        "cuda_device": str(FLAGS.cuda_device),
+        "num_epochs": str(FLAGS.num_epochs),
+        "word_batch_size": str(FLAGS.word_batch_size),
+        "use_tensorboard": str(FLAGS.tensorboard),
+    }
+
+
+def main():
+    """Parse flags."""
+    flags.register_validator(
+        "features",
+        lambda values: all(
+            value in _FEATURES for value in values),
+        message="Flags --features contains unknown value(s)."
+    )
+    flags.register_validator(
+        "mode",
+        lambda value: value is not None,
+        message="Flag --mode must be set with either `predict` or `train` value")
+    flags.register_validator(
+        "targets",
+        lambda values: all(
+            value in _TARGETS for value in values),
+        message="Flag --targets contains unknown value(s)."
+    )
+    app.run(run)
+
+
+if __name__ == "__main__":
+    main()
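
For reference, a typical invocation of the CLI defined above — `run` is still a
stub, so these commands only exercise flag parsing for now, and the paths are
placeholders:

    python -m combo.main --mode train \
        --training_data_path train.conllu \
        --validation_data_path valid.conllu \
        --num_epochs 5

    python -m combo.main --mode predict \
        --model_path model.tar.gz \
        --input_file input.conllu \
        --output_file output.log
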
diff --git a/combo/models/__init__.py b/combo/models/__init__.py
index e69de29..9ad341a 100644
--- a/combo/models/__init__.py
+++ b/combo/models/__init__.py
@@ -0,0 +1,8 @@
+from .base import FeedForwardPredictor
+from .graph_parser import GraphDependencyRelationModel
+from .parser import DependencyRelationModel
+from .embeddings import CharacterBasedWordEmbeddings
+from .encoder import ComboEncoder
+from .lemma import LemmatizerModel
+from .model import ComboModel
+from .morpho import MorphologicalFeatures
diff --git a/combo/models/base.py b/combo/models/base.py
index 11df601..016ef85 100644
--- a/combo/models/base.py
+++ b/combo/models/base.py
@@ -4,7 +4,11 @@ import torch
 import torch.nn as nn
 import utils
 import combo.models.combo_nn as combo_nn
-import combo.checks as checks
+import combo.utils.checks as checks
+
+
+class Model:
+    pass
 
 
 class Predictor(nn.Module):
diff --git a/combo/models/embeddings.py b/combo/models/embeddings.py
index e69de29..e5ebff8 100644
--- a/combo/models/embeddings.py
+++ b/combo/models/embeddings.py
@@ -0,0 +1,24 @@
+class Embedding:
+    pass
+
+class TokenEmbedder:
+    pass
+
+class CharacterBasedWordEmbeddings(TokenEmbedder):
+    pass
+
+
+class ProjectedWordEmbedder(TokenEmbedder):
+    pass
+
+
+class PretrainedTransformerMismatchedEmbedder(TokenEmbedder):
+    pass
+
+
+class TransformersWordEmbedder(PretrainedTransformerMismatchedEmbedder):
+    pass
+
+
+class FeatsTokenEmbedder(TokenEmbedder):
+    pass
\ No newline at end of file
diff --git a/combo/models/encoder.py b/combo/models/encoder.py
new file mode 100644
index 0000000..f25f4a1
--- /dev/null
+++ b/combo/models/encoder.py
@@ -0,0 +1,10 @@
+class Encoder:
+    pass
+
+
+class StackedBidirectionalLstm(Encoder):
+    pass
+
+
+class ComboEncoder(Encoder):
+    pass
diff --git a/combo/models/graph_parser.py b/combo/models/graph_parser.py
new file mode 100644
index 0000000..82eefde
--- /dev/null
+++ b/combo/models/graph_parser.py
@@ -0,0 +1,9 @@
+from combo.models.base import Predictor
+
+
+class GraphHeadPredictionModel(Predictor):
+    pass
+
+
+class GraphDependencyRelationModel(Predictor):
+    pass
diff --git a/combo/models/lemma.py b/combo/models/lemma.py
new file mode 100644
index 0000000..77254f6
--- /dev/null
+++ b/combo/models/lemma.py
@@ -0,0 +1,5 @@
+from combo.models.base import Predictor
+
+
+class LemmatizerModel(Predictor):
+    pass
diff --git a/combo/models/model.py b/combo/models/model.py
new file mode 100644
index 0000000..134a606
--- /dev/null
+++ b/combo/models/model.py
@@ -0,0 +1,5 @@
+from combo.models.base import Model
+
+
+class ComboModel(Model):
+    pass
\ No newline at end of file
diff --git a/combo/models/morpho.py b/combo/models/morpho.py
new file mode 100644
index 0000000..4d65686
--- /dev/null
+++ b/combo/models/morpho.py
@@ -0,0 +1,5 @@
+from combo.models.base import Predictor
+
+
+class MorphologicalFeatures(Predictor):
+    pass
diff --git a/combo/models/parser.py b/combo/models/parser.py
new file mode 100644
index 0000000..46e3766
--- /dev/null
+++ b/combo/models/parser.py
@@ -0,0 +1,9 @@
+from combo.models.base import Predictor
+
+
+class HeadPredictionModel(Predictor):
+    pass
+
+
+class DependencyRelationModel(Predictor):
+    pass
diff --git a/combo/predict.py b/combo/predict.py
new file mode 100644
index 0000000..bdf9797
--- /dev/null
+++ b/combo/predict.py
@@ -0,0 +1,15 @@
+import logging
+import os
+import sys
+from typing import List, Union, Dict, Any
+
+from combo import data
+from combo.data import sentence2conllu, tokens2conllu, conllu2sentence
+from combo.models.base import Predictor
+from combo.utils import download, graph
+
+logger = logging.getLogger(__name__)
+
+
+class COMBO(Predictor):
+    pass
diff --git a/combo/training/__init__.py b/combo/training/__init__.py
new file mode 100644
index 0000000..94da4e6
--- /dev/null
+++ b/combo/training/__init__.py
@@ -0,0 +1,3 @@
+from .checkpointer import FinishingTrainingCheckpointer
+from .scheduler import Scheduler
+from .trainer import GradientDescentTrainer
diff --git a/combo/training/checkpointer.py b/combo/training/checkpointer.py
new file mode 100644
index 0000000..9823f14
--- /dev/null
+++ b/combo/training/checkpointer.py
@@ -0,0 +1,6 @@
+class Checkpointer:
+    pass
+
+
+class FinishingTrainingCheckpointer:
+    pass
diff --git a/combo/training/scheduler.py b/combo/training/scheduler.py
new file mode 100644
index 0000000..eff71b9
--- /dev/null
+++ b/combo/training/scheduler.py
@@ -0,0 +1,2 @@
+class Scheduler:
+    pass
diff --git a/combo/training/tensorboard_writer.py b/combo/training/tensorboard_writer.py
new file mode 100644
index 0000000..f56966c
--- /dev/null
+++ b/combo/training/tensorboard_writer.py
@@ -0,0 +1,2 @@
+class NullTensorboardWriter:
+    pass
\ No newline at end of file
diff --git a/combo/training/trainer.py b/combo/training/trainer.py
new file mode 100644
index 0000000..ebb8ff0
--- /dev/null
+++ b/combo/training/trainer.py
@@ -0,0 +1,13 @@
+from pytorch_lightning import Trainer
+
+
+class Callback:
+    pass
+
+
+class TransferPatienceEpochCallback:
+    pass
+
+
+class GradientDescentTrainer(Trainer):
+    pass
diff --git a/combo/utils/__init__.py b/combo/utils/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/combo/utils/checks.py b/combo/utils/checks.py
new file mode 100644
index 0000000..ea76ac1
--- /dev/null
+++ b/combo/utils/checks.py
@@ -0,0 +1,15 @@
+import torch
+
+
+class ConfigurationError(Exception):
+    def __init__(self, message: str):
+        super().__init__()
+        self.message = message
+
+
+def file_exists(*paths):
+    pass
+
+
+def check_size_match(size_1: torch.Size, size_2: torch.Size, tensor_1_name: str, tensor_2_name: str):
+    pass
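
Both checks are stubs for now. One plausible way to fill them in, reusing
ConfigurationError in the same way the deleted combo/checks.py did (a sketch,
not the final implementation):

    import os

    def file_exists(*paths):
        for path in paths:
            if not (path and os.path.exists(path)):
                raise ConfigurationError(f"File with path '{path}' does not exist!")

    def check_size_match(size_1: torch.Size, size_2: torch.Size,
                         tensor_1_name: str, tensor_2_name: str):
        if size_1 != size_2:
            raise ConfigurationError(f"{tensor_1_name} size {size_1} does not match "
                                     f"{tensor_2_name} size {size_2}.")
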
diff --git a/combo/utils/download.py b/combo/utils/download.py
new file mode 100644
index 0000000..5c7ce6f
--- /dev/null
+++ b/combo/utils/download.py
@@ -0,0 +1,78 @@
+import errno
+import logging
+import os
+
+import requests
+import tqdm
+import urllib3
+from requests import adapters, exceptions
+
+logger = logging.getLogger(__name__)
+
+DATA_TO_PATH = {
+    "enhanced" : "iwpt_2020",
+    "iwpt2021" : "iwpt_2021",
+    "ud25" : "ud_25",
+    "ud27" : "ud_27",
+    "ud29" : "ud_29"}
+_URL = "http://s3.clarin-pl.eu/dspace/combo/{data}/{model}.tar.gz"
+_HOME_DIR = os.getenv("HOME", os.curdir)
+_CACHE_DIR = os.getenv("COMBO_DIR", os.path.join(_HOME_DIR, ".combo"))
+
+
+def download_file(model_name, force=False):
+    _make_cache_dir()
+    data = model_name.split("-")[-1]
+    url = _URL.format(model=model_name, data=DATA_TO_PATH[data])
+    local_filename = url.split("/")[-1]
+    location = os.path.join(_CACHE_DIR, local_filename)
+    if os.path.exists(location) and not force:
+        logger.debug("Using cached model.")
+        return location
+    chunk_size = 1024
+    logger.info(url)
+    try:
+        with _requests_retry_session(retries=2).get(url, stream=True) as r:
+            pbar = tqdm.tqdm(unit="B", total=int(r.headers.get("content-length")),
+                             unit_divisor=chunk_size, unit_scale=True)
+            with open(location, "wb") as f:
+                with pbar:
+                    for chunk in r.iter_content(chunk_size):
+                        if chunk:
+                            f.write(chunk)
+                            pbar.update(len(chunk))
+    except exceptions.RetryError:
+        raise ConnectionError(f"Couldn't find or download model {model_name}.tar.gz. "
+                              "Check if model name is correct or try again later!")
+
+    return location
+
+
+def _make_cache_dir():
+    try:
+        os.makedirs(_CACHE_DIR)
+        logger.info(f"Making cache dir {_CACHE_DIR}")
+    except OSError as e:
+        if e.errno != errno.EEXIST:
+            raise
+
+
+def _requests_retry_session(
+    retries=3,
+    backoff_factor=0.3,
+    status_forcelist=(404, 500, 502, 504),
+    session=None,
+):
+    """Source: https://www.peterbe.com/plog/best-practice-with-retries-with-requests"""
+    session = session or requests.Session()
+    retry = urllib3.Retry(
+        total=retries,
+        read=retries,
+        connect=retries,
+        backoff_factor=backoff_factor,
+        status_forcelist=status_forcelist,
+    )
+    adapter = adapters.HTTPAdapter(max_retries=retry)
+    session.mount("http://", adapter)
+    session.mount("https://", adapter)
+    return session
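
download_file derives the data segment from the final "-"-separated chunk of
the model name; the model name below is an illustration only, not a guaranteed
published model:

    # "ud29" maps to "ud_29", so this downloads (and caches under ~/.combo):
    # http://s3.clarin-pl.eu/dspace/combo/ud_29/<model>-ud29.tar.gz
    location = download_file("polish-herbert-base-ud29")

The retry session is also usable on its own for plain GETs:

    session = _requests_retry_session(retries=2)
    response = session.get("http://s3.clarin-pl.eu/dspace/combo/", timeout=10)
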
diff --git a/combo/utils/graph.py b/combo/utils/graph.py
new file mode 100644
index 0000000..f61a68e
--- /dev/null
+++ b/combo/utils/graph.py
@@ -0,0 +1,149 @@
+"""Based on https://github.com/emorynlp/iwpt-shared-task-2020."""
+
+import numpy as np
+
+_ACL_REL_CL = "acl:relcl"
+
+
+def graph_and_tree_merge(tree_arc_scores,
+                         tree_rel_scores,
+                         graph_arc_scores,
+                         graph_rel_scores,
+                         label2idx,
+                         idx2label,
+                         graph_label2idx,
+                         graph_idx2label,
+                         tokens):
+    graph_arc_scores = np.copy(graph_arc_scores)
+    # Exclude self-loops, in-place operation.
+    np.fill_diagonal(graph_arc_scores, 0)
+    # Connections to the root will be handled by the tree.
+    graph_arc_scores[:, 0] = False
+    # The same with labels.
+    root_idx = graph_label2idx["root"]
+    graph_rel_scores[:, :, root_idx] = -float('inf')
+    graph_rel_pred = graph_rel_scores.argmax(-1)
+
+    # Add tree edges to graph
+    tree_heads = [0] + tree_arc_scores
+    graph = [[] for _ in range(len(tree_heads))]
+    labeled_graph = [[] for _ in range(len(tree_heads))]
+    for d, h in enumerate(tree_heads):
+        if not d:
+            continue
+        label = idx2label[tree_rel_scores[d - 1]]
+        # graph_label = graph_idx2label[graph_rel_pred[d - 1][h - 1]]
+        # if ">" in graph_label and label in graph_label:
+        #     print("Using graph label instead of tree.")
+        #     label = graph_label
+        if label != _ACL_REL_CL:
+            graph[h].append(d)
+            labeled_graph[h].append((d, label))
+
+    # Debug only
+    # Extract graph edges
+    graph_edges = np.argwhere(graph_arc_scores)
+
+    # Add graph edges that don't create a cycle.
+    for (d, h) in graph_edges:
+        if not d or not h or d in graph[h]:
+            continue
+        try:
+            path = next(_dfs(graph, d, h))
+        except StopIteration:
+            # There is no path from d to h.
+            label = graph_idx2label[graph_rel_pred[d][h]]
+            if label != _ACL_REL_CL:
+                graph[h].append(d)
+                labeled_graph[h].append((d, label))
+
+    # Add 'acl:relcl' without checking for cycles.
+    for d, h in enumerate(tree_heads):
+        if not d:
+            continue
+        label = idx2label[tree_rel_scores[d - 1]]
+        if label == _ACL_REL_CL:
+            graph[h].append(d)
+            labeled_graph[h].append((d, label))
+
+    assert len(labeled_graph[0]) == 1
+    d = graph[0][0]
+    graph[d].append(0)
+    labeled_graph[d].append((0, "root"))
+
+    parse_graph = [[] for _ in range(len(tree_heads))]
+    for h in range(len(tree_heads)):
+        for d, label in labeled_graph[h]:
+            parse_graph[d].append((h, label))
+        parse_graph[d] = sorted(parse_graph[d])
+
+    for i, g in enumerate(parse_graph):
+        heads = np.array([x[0] for x in g])
+        rels = np.array([x[1] for x in g])
+        indices = rels.argsort()
+        heads = heads[indices].tolist()
+        rels = rels[indices].tolist()
+        deps = '|'.join(f'{h}:{r}' for h, r in zip(heads, rels))
+        tokens[i - 1]["deps"] = deps
+    return
+
+
+def _dfs(graph, start, end):
+    fringe = [(start, [])]
+    while fringe:
+        state, path = fringe.pop()
+        if path and state == end:
+            yield path
+            continue
+        for next_state in graph[state]:
+            if next_state in path:
+                continue
+            fringe.append((next_state, path + [next_state]))
+
+
+def restore_collapse_edges(tree_tokens):
+    # https://gist.github.com/hankcs/776e7d95c19e5ff5da8469fe4e9ab050
+    empty_tokens = []
+    for token in tree_tokens:
+        deps = token["deps"].split("|")
+        for i, d in enumerate(deps):
+            if ">" in d:
+                # {head}:{empty_node_relation}>{current_node_relation}
+                # should map to
+                # For new, empty node:
+                # {head}:{empty_node_relation}
+                # For current node:
+                # {new_empty_node_id}:{current_node_relation}
+                # TODO consider where to put new_empty_node_id (currently at the end)
+                head, relation = d.split(':', 1)
+                ehead = f"{len(tree_tokens)}.{len(empty_tokens) + 1}"
+                empty_node_relation, current_node_relation = relation.split(">", 1)
+                # Edge case, double >
+                if ">" in current_node_relation:
+                    second_empty_node_relation, current_node_relation = current_node_relation.split(">")
+                    deps[i] = f"{ehead}:{current_node_relation}"
+                    second_ehead = f"{len(tree_tokens)}.{len(empty_tokens) + 2}"
+                    empty_tokens.append(
+                        {
+                            "id": ehead,
+                            "deps": f"{second_ehead}:{empty_node_relation}"
+                        }
+                    )
+                    empty_tokens.append(
+                        {
+                            "id": second_ehead,
+                            "deps": f"{head}:{second_empty_node_relation}"
+                        }
+                    )
+
+                else:
+                    deps[i] = f"{ehead}:{current_node_relation}"
+                    empty_tokens.append(
+                        {
+                            "id": ehead,
+                            "deps": f"{head}:{empty_node_relation}"
+                        }
+                    )
+        deps = sorted([d.split(":", 1) for d in deps], key=lambda x: float(x[0]))
+        token["deps"] = "|".join([f"{k}:{v}" for k, v in deps])
+    return empty_tokens
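
A worked example of the collapsed-edge format that restore_collapse_edges
expands (input constructed for illustration only):

    tokens = [
        {"id": 1, "deps": "2:obl>conj"},
        {"id": 2, "deps": "0:root"},
    ]
    empty_nodes = restore_collapse_edges(tokens)
    # The current node is re-attached to a new empty node "2.1"
    # (len(tokens) == 2, first empty node), keeping the second relation:
    #   tokens[0]["deps"] == "2.1:conj"
    # and the empty node inherits the original head with the first relation:
    #   empty_nodes == [{"id": "2.1", "deps": "2:obl"}]
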
diff --git a/combo/utils/metrics.py b/combo/utils/metrics.py
new file mode 100644
index 0000000..9181040
--- /dev/null
+++ b/combo/utils/metrics.py
@@ -0,0 +1,17 @@
+class Metric:
+    pass
+
+class LemmaAccuracy(Metric):
+    pass
+
+
+class SequenceBoolAccuracy(Metric):
+    pass
+
+
+class AttachmentScores(Metric):
+    pass
+
+
+class SemanticMetrics(Metric):
+    pass
diff --git a/requirements.txt b/requirements.txt
index 61f0590..718d417 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,7 +1,11 @@
+absl-py~=1.4.0
 conllu~=4.4.1
 dependency-injector~=4.41.0
 overrides~=7.3.1
 torch~=1.13.1
 torchtext~=0.14.1
 numpy~=1.24.1
-pytorch-lightning~=1.9.0
\ No newline at end of file
+pytorch-lightning~=1.9.0
+requests~=2.28.2
+tqdm~=4.64.1
+urllib3~=1.26.14
\ No newline at end of file
-- 
GitLab