From 16988605fe3645fc8caef5facc9ae7afbc0acb3f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Maja=20Jab=C5=82o=C5=84ska?= <majajjablonska@gmail.com>
Date: Sun, 5 Mar 2023 20:07:49 +0100
Subject: [PATCH] First commit

---
 .idea/.gitignore                              |   8 +
 .idea/combolightning.iml                      |   8 +
 .idea/inspectionProfiles/Project_Default.xml  |  32 ++
 .../inspectionProfiles/profiles_settings.xml  |   6 +
 .idea/misc.xml                                |   4 +
 .idea/modules.xml                             |   8 +
 combo/__init__.py                             |   0
 combo/checks.py                               |   4 +
 combo/config/__init__.py                      |   0
 combo/data/__init__.py                        |   3 +
 combo/data/api.py                             |  98 ++++++
 combo/data/dataset.py                         | 317 ++++++++++++++++++
 combo/data/fields/__init__.py                 |   1 +
 .../data/fields/sequence_multilabel_field.py  | 135 ++++++++
 combo/data/samplers/__init__.py               |   1 +
 combo/data/samplers/samplers.py               |  54 +++
 combo/data/token_indexers/__init__.py         |   3 +
 ...etrained_transformer_mismatched_indexer.py | 117 +++++++
 .../token_characters_indexer.py               |  62 ++++
 .../token_indexers/token_features_indexer.py  |  75 +++++
 combo/data/vocabulary.py                      |  98 ++++++
 combo/models/__init__.py                      |   0
 combo/models/base.py                          | 215 ++++++++++++
 combo/models/combo_nn.py                      |   7 +
 combo/models/embeddings.py                    |   0
 combo/models/utils.py                         |   7 +
 docs/Configuration.md                         |   4 +
 example.conllu                                |  13 +
 main.py                                       |  16 +
 requirements.txt                              |   7 +
 30 files changed, 1303 insertions(+)
 create mode 100644 .idea/.gitignore
 create mode 100644 .idea/combolightning.iml
 create mode 100644 .idea/inspectionProfiles/Project_Default.xml
 create mode 100644 .idea/inspectionProfiles/profiles_settings.xml
 create mode 100644 .idea/misc.xml
 create mode 100644 .idea/modules.xml
 create mode 100644 combo/__init__.py
 create mode 100644 combo/checks.py
 create mode 100644 combo/config/__init__.py
 create mode 100644 combo/data/__init__.py
 create mode 100644 combo/data/api.py
 create mode 100644 combo/data/dataset.py
 create mode 100644 combo/data/fields/__init__.py
 create mode 100644 combo/data/fields/sequence_multilabel_field.py
 create mode 100644 combo/data/samplers/__init__.py
 create mode 100644 combo/data/samplers/samplers.py
 create mode 100644 combo/data/token_indexers/__init__.py
 create mode 100644 combo/data/token_indexers/pretrained_transformer_mismatched_indexer.py
 create mode 100644 combo/data/token_indexers/token_characters_indexer.py
 create mode 100644 combo/data/token_indexers/token_features_indexer.py
 create mode 100644 combo/data/vocabulary.py
 create mode 100644 combo/models/__init__.py
 create mode 100644 combo/models/base.py
 create mode 100644 combo/models/combo_nn.py
 create mode 100644 combo/models/embeddings.py
 create mode 100644 combo/models/utils.py
 create mode 100644 docs/Configuration.md
 create mode 100644 example.conllu
 create mode 100644 main.py
 create mode 100644 requirements.txt

diff --git a/.idea/.gitignore b/.idea/.gitignore
new file mode 100644
index 0000000..73f69e0
--- /dev/null
+++ b/.idea/.gitignore
@@ -0,0 +1,8 @@
+# Default ignored files
+/shelf/
+/workspace.xml
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml
+# Editor-based HTTP Client requests
+/httpRequests/
diff --git a/.idea/combolightning.iml b/.idea/combolightning.iml
new file mode 100644
index 0000000..4154233
--- /dev/null
+++ b/.idea/combolightning.iml
@@ -0,0 +1,8 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<module type="PYTHON_MODULE" version="4">
+  <component name="NewModuleRootManager">
+    <content url="file://$MODULE_DIR$" />
+    <orderEntry type="jdk" jdkName="Python 3.9 (combolightning)" jdkType="Python SDK" />
+    <orderEntry type="sourceFolder" forTests="false" />
+  </component>
+</module>
\ No newline at end of file
diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml
new file mode 100644
index 0000000..11a85e3
--- /dev/null
+++ b/.idea/inspectionProfiles/Project_Default.xml
@@ -0,0 +1,32 @@
+<component name="InspectionProjectProfileManager">
+  <profile version="1.0">
+    <option name="myName" value="Project Default" />
+    <inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
+      <option name="ignoredPackages">
+        <value>
+          <list size="19">
+            <item index="0" class="java.lang.String" itemvalue="Click" />
+            <item index="1" class="java.lang.String" itemvalue="numpy" />
+            <item index="2" class="java.lang.String" itemvalue="importlib-metadata" />
+            <item index="3" class="java.lang.String" itemvalue="absl-py" />
+            <item index="4" class="java.lang.String" itemvalue="scipy" />
+            <item index="5" class="java.lang.String" itemvalue="transformers" />
+            <item index="6" class="java.lang.String" itemvalue="pytest" />
+            <item index="7" class="java.lang.String" itemvalue="scikit-learn" />
+            <item index="8" class="java.lang.String" itemvalue="allennlp" />
+            <item index="9" class="java.lang.String" itemvalue="torch" />
+            <item index="10" class="java.lang.String" itemvalue="requests" />
+            <item index="11" class="java.lang.String" itemvalue="conllu" />
+            <item index="12" class="java.lang.String" itemvalue="tqdm" />
+            <item index="13" class="java.lang.String" itemvalue="jsonnet" />
+            <item index="14" class="java.lang.String" itemvalue="filelock" />
+            <item index="15" class="java.lang.String" itemvalue="pylint" />
+            <item index="16" class="java.lang.String" itemvalue="urllib3" />
+            <item index="17" class="java.lang.String" itemvalue="pylint-quotes" />
+            <item index="18" class="java.lang.String" itemvalue="jsonnet-binary" />
+          </list>
+        </value>
+      </option>
+    </inspection_tool>
+  </profile>
+</component>
\ No newline at end of file
diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml
new file mode 100644
index 0000000..105ce2d
--- /dev/null
+++ b/.idea/inspectionProfiles/profiles_settings.xml
@@ -0,0 +1,6 @@
+<component name="InspectionProjectProfileManager">
+  <settings>
+    <option name="USE_PROJECT_PROFILE" value="false" />
+    <version value="1.0" />
+  </settings>
+</component>
\ No newline at end of file
diff --git a/.idea/misc.xml b/.idea/misc.xml
new file mode 100644
index 0000000..fcbd09d
--- /dev/null
+++ b/.idea/misc.xml
@@ -0,0 +1,4 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.9 (combolightning)" project-jdk-type="Python SDK" />
+</project>
\ No newline at end of file
diff --git a/.idea/modules.xml b/.idea/modules.xml
new file mode 100644
index 0000000..89bf708
--- /dev/null
+++ b/.idea/modules.xml
@@ -0,0 +1,8 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="ProjectModuleManager">
+    <modules>
+      <module fileurl="file://$PROJECT_DIR$/.idea/combolightning.iml" filepath="$PROJECT_DIR$/.idea/combolightning.iml" />
+    </modules>
+  </component>
+</project>
\ No newline at end of file
diff --git a/combo/__init__.py b/combo/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/combo/checks.py b/combo/checks.py
new file mode 100644
index 0000000..7736e96
--- /dev/null
+++ b/combo/checks.py
@@ -0,0 +1,4 @@
+class ConfigurationError(Exception):
+    def __init__(self, message: str):
+        super().__init__()
+        self.message = message
diff --git a/combo/config/__init__.py b/combo/config/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/combo/data/__init__.py b/combo/data/__init__.py
new file mode 100644
index 0000000..91426e3
--- /dev/null
+++ b/combo/data/__init__.py
@@ -0,0 +1,3 @@
+from .samplers import TokenCountBatchSampler
+from .token_indexers import *
+from .api import *
diff --git a/combo/data/api.py b/combo/data/api.py
new file mode 100644
index 0000000..308e9e4
--- /dev/null
+++ b/combo/data/api.py
@@ -0,0 +1,98 @@
+import collections
+import dataclasses
+import json
+from dataclasses import dataclass, field
+from typing import Optional, List, Dict, Any, Union, Tuple
+
+import conllu
+from overrides import overrides
+
+
+@dataclass
+class Token:
+    id: Optional[Union[int, Tuple]] = None
+    token: Optional[str] = None
+    lemma: Optional[str] = None
+    upostag: Optional[str] = None
+    xpostag: Optional[str] = None
+    feats: Optional[str] = None
+    head: Optional[int] = None
+    deprel: Optional[str] = None
+    deps: Optional[str] = None
+    misc: Optional[str] = None
+    semrel: Optional[str] = None
+    embeddings: Dict[str, List[float]] = field(default_factory=dict, repr=False)
+
+
+@dataclass
+class Sentence:
+    tokens: List[Token] = field(default_factory=list)
+    sentence_embedding: List[float] = field(default_factory=list, repr=False)
+    metadata: Dict[str, Any] = field(default_factory=collections.OrderedDict)
+
+    def to_json(self):
+        return json.dumps({
+            "tokens": [dataclasses.asdict(t) for t in self.tokens],
+            "sentence_embedding": self.sentence_embedding,
+            "metadata": self.metadata,
+        })
+
+    def __len__(self):
+        return len(self.tokens)
+
+
+class _TokenList(conllu.TokenList):
+
+    @overrides
+    def __repr__(self):
+        return 'TokenList<' + ', '.join(token['token'] for token in self) + '>'
+
+
+def sentence2conllu(sentence: Sentence, keep_semrel: bool = True) -> conllu.TokenList:
+    tokens = []
+    for token in sentence.tokens:
+        token_dict = collections.OrderedDict(dataclasses.asdict(token))
+        # Remove semrel to have default conllu format.
+        if not keep_semrel:
+            del token_dict["semrel"]
+        del token_dict["embeddings"]
+        tokens.append(token_dict)
+    # Range token ids must be tuples, not lists; this is a conllu library requirement.
+    for t in tokens:
+        if type(t["id"]) == list:
+            t["id"] = tuple(t["id"])
+        if t["deps"]:
+            for dep in t["deps"]:
+                if len(dep) > 1 and type(dep[1]) == list:
+                    dep[1] = tuple(dep[1])
+    return _TokenList(tokens=tokens,
+                      metadata=sentence.metadata)
+
+
+def tokens2conllu(tokens: List[str]) -> conllu.TokenList:
+    return _TokenList(
+        [collections.OrderedDict({"id": idx, "token": token}) for
+         idx, token
+         in enumerate(tokens, start=1)],
+        metadata=collections.OrderedDict()
+    )
+
+
+def conllu2sentence(conllu_sentence: conllu.TokenList,
+                    sentence_embedding=None, embeddings=None) -> Sentence:
+    if embeddings is None:
+        embeddings = {}
+    if sentence_embedding is None:
+        sentence_embedding = []
+    tokens = []
+    for token in conllu_sentence.tokens:
+        tokens.append(
+            Token(
+                **token, embeddings=embeddings.get(token["id"], {})
+            )
+        )
+    return Sentence(
+        tokens=tokens,
+        sentence_embedding=sentence_embedding,
+        metadata=conllu_sentence.metadata
+    )
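+
+
+# Illustrative usage sketch (hedged example; the sentence below is made up):
+#
+#     token_list = tokens2conllu(["John", "loves", "Mary"])
+#     sentence = conllu2sentence(token_list)
+#     assert len(sentence) == 3
+#     print(sentence.to_json())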
diff --git a/combo/data/dataset.py b/combo/data/dataset.py
new file mode 100644
index 0000000..0c34d3f
--- /dev/null
+++ b/combo/data/dataset.py
@@ -0,0 +1,317 @@
+import copy
+import logging
+import pathlib
+from typing import Union, List, Dict, Iterable, Optional, Any, Tuple
+
+import conllu
+import torch
+from allennlp import data as allen_data
+from allennlp.common import checks, util
+from allennlp.data import fields as allen_fields, vocabulary
+from conllu import parser
+from dataclasses import dataclass
+from overrides import overrides
+
+from combo.data import fields
+
+logger = logging.getLogger(__name__)
+
+
+@allen_data.DatasetReader.register("conllu")
+class UniversalDependenciesDatasetReader(allen_data.DatasetReader):
+
+    def __init__(
+            self,
+            token_indexers: Dict[str, allen_data.TokenIndexer] = None,
+            lemma_indexers: Dict[str, allen_data.TokenIndexer] = None,
+            features: List[str] = None,
+            targets: List[str] = None,
+            use_sem: bool = False,
+            **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        if features is None:
+            features = ["token", "char"]
+        if targets is None:
+            targets = ["head", "deprel", "upostag", "xpostag", "lemma", "feats"]
+
+        if "token" not in features and "char" not in features:
+            raise checks.ConfigurationError("There must be at least one ('char' or 'token') text-based feature!")
+
+        if "deps" in targets and not ("head" in targets and "deprel" in targets):
+            raise checks.ConfigurationError("Add 'head' and 'deprel' to targets when using 'deps'!")
+
+        intersection = set(features).intersection(set(targets))
+        if len(intersection) != 0:
+            raise checks.ConfigurationError(
+                "Features and targets cannot share elements! "
+                "Remove {} from either features or targets.".format(intersection)
+            )
+        self.use_sem = use_sem
+
+        # *.conllu readers config
+        fields = list(parser.DEFAULT_FIELDS)
+        fields[1] = "token"  # use 'token' instead of 'form'
+        # Copy so that the shared module-level defaults are not mutated below.
+        field_parsers = dict(parser.DEFAULT_FIELD_PARSERS)
+        # Do not make it nullable
+        field_parsers.pop("xpostag", None)
+        # Ignore parsing misc
+        field_parsers.pop("misc", None)
+        if self.use_sem:
+            fields = list(fields)
+            fields.append("semrel")
+            field_parsers["semrel"] = lambda line, i: line[i]
+        self.field_parsers = field_parsers
+        self.fields = tuple(fields)
+
+        self._token_indexers = token_indexers
+        self._lemma_indexers = lemma_indexers
+        self._targets = targets
+        self._features = features
+        self.generate_labels = True
+        # Filter out token indexers that are not required, to avoid a
+        # "Mismatched token keys" ConfigurationError.
+        for indexer_name in list(self._token_indexers.keys()):
+            if indexer_name not in self._features:
+                del self._token_indexers[indexer_name]
+
+    @overrides
+    def _read(self, file_path: str) -> Iterable[allen_data.Instance]:
+        file_path = file_path.split(",")  # support a comma-separated list of files
+
+        for conllu_file in file_path:
+            file = pathlib.Path(conllu_file)
+            assert conllu_file and file.exists(), f"File with path '{conllu_file}' does not exist!"
+            with file.open("r", encoding="utf-8") as f:
+                for annotation in conllu.parse_incr(f, fields=self.fields, field_parsers=self.field_parsers):
+                    yield self.text_to_instance(annotation)
+
+    @overrides
+    def text_to_instance(self, tree: conllu.TokenList) -> allen_data.Instance:
+        fields_: Dict[str, allen_data.Field] = {}
+        tree_tokens = [t for t in tree if isinstance(t["id"], int)]
+        tokens = [_Token(t["token"],
+                         pos_=t.get("upostag"),
+                         tag_=t.get("xpostag"),
+                         lemma_=t.get("lemma"),
+                         feats_=t.get("feats"))
+                  for t in tree_tokens]
+
+        # features
+        text_field = allen_fields.TextField(tokens, self._token_indexers)
+        fields_["sentence"] = text_field
+
+        # targets
+        if self.generate_labels:
+            for target_name in self._targets:
+                if target_name != "sent":
+                    target_values = [t[target_name] for t in tree_tokens]
+                    if target_name == "lemma":
+                        target_values = [allen_data.Token(v) for v in target_values]
+                        fields_[target_name] = allen_fields.TextField(target_values, self._lemma_indexers)
+                    elif target_name == "feats":
+                        target_values = self._feat_values(tree_tokens)
+                        fields_[target_name] = fields.SequenceMultiLabelField(target_values,
+                                                                              self._feats_indexer,
+                                                                              self._feats_as_tensor_wrapper,
+                                                                              text_field,
+                                                                              label_namespace="feats_labels")
+                    elif target_name == "head":
+                        target_values = [0 if v == "_" else int(v) for v in target_values]
+                        fields_[target_name] = allen_fields.SequenceLabelField(target_values, text_field,
+                                                                               label_namespace=target_name + "_labels")
+                    elif target_name == "deps":
+                        # Graphs require adding ROOT (AdjacencyField uses sequence length from TextField).
+                        text_field_deps = allen_fields.TextField([_Token("ROOT")] + copy.deepcopy(tokens),
+                                                                 self._token_indexers)
+                        enhanced_heads: List[Tuple[int, int]] = []
+                        enhanced_deprels: List[str] = []
+                        for idx, t in enumerate(tree_tokens):
+                            t_deps = t["deps"]
+                            if t_deps and t_deps != "_":
+                                for rel, head in t_deps:
+                                    # If there are two edges between the same pair of nodes, EmoryNLP
+                                    # skips the first one, so that one edge ends up in the tree and the
+                                    # other in the graph. This snippet follows that approach.
+                                    if enhanced_heads and enhanced_heads[-1] == (idx, head):
+                                        enhanced_heads.pop()
+                                        enhanced_deprels.pop()
+                                    enhanced_heads.append((idx, head))
+                                    enhanced_deprels.append(rel)
+                        fields_["enhanced_heads"] = allen_fields.AdjacencyField(
+                            indices=enhanced_heads,
+                            sequence_field=text_field_deps,
+                            label_namespace="enhanced_heads_labels",
+                            padding_value=0,
+                        )
+                        fields_["enhanced_deprels"] = allen_fields.AdjacencyField(
+                            indices=enhanced_heads,
+                            sequence_field=text_field_deps,
+                            labels=enhanced_deprels,
+                            # Label namespace matches regular tree parsing.
+                            label_namespace="enhanced_deprel_labels",
+                            padding_value=0,
+                        )
+                    else:
+                        fields_[target_name] = allen_fields.SequenceLabelField(target_values, text_field,
+                                                                               label_namespace=target_name + "_labels")
+
+        # Restore the feats field to its string representation;
+        # parser.serialize_field doesn't handle a key without a value.
+        for token in tree.tokens:
+            if "feats" in token:
+                feats = token["feats"]
+                if feats:
+                    feats_values = []
+                    for k, v in feats.items():
+                        feats_values.append('='.join((k, v)) if v else k)
+                    field = "|".join(feats_values)
+                else:
+                    field = "_"
+                token["feats"] = field
+
+        # metadata
+        fields_["metadata"] = allen_fields.MetadataField({"input": tree,
+                                                          "field_names": self.fields,
+                                                          "tokens": tokens})
+
+        return allen_data.Instance(fields_)
+
+    @staticmethod
+    def _feat_values(tree: List[Dict[str, Any]]):
+        features = []
+        for token in tree:
+            token_features = []
+            if token["feats"] is not None:
+                for feat, value in token["feats"].items():
+                    if feat in ["_", "__ROOT__"]:
+                        pass
+                    else:
+                        # Handle case where feature is binary (doesn't have associated value)
+                        if value:
+                            token_features.append(feat + "=" + value)
+                        else:
+                            token_features.append(feat)
+            features.append(token_features)
+        return features
+
+    @staticmethod
+    def _feats_as_tensor_wrapper(field: fields.SequenceMultiLabelField):
+        def as_tensor(padding_lengths):
+            desired_num_tokens = padding_lengths["num_tokens"]
+            assert len(field._indexed_multi_labels) > 0
+            classes_count = len(field._indexed_multi_labels[0])
+            default_value = [0.0] * classes_count
+            padded_tags = util.pad_sequence_to_length(field._indexed_multi_labels, desired_num_tokens,
+                                                      lambda: default_value)
+            tensor = torch.tensor(padded_tags, dtype=torch.long)
+            return tensor
+
+        return as_tensor
+
+    @staticmethod
+    def _feats_indexer(vocab: allen_data.Vocabulary):
+        label_namespace = "feats_labels"
+        vocab_size = vocab.get_vocab_size(label_namespace)
+        slices = get_slices_if_not_provided(vocab)
+
+        def _m_from_n_ones_encoding(multi_label: List[str], sentence_length: int) -> List[int]:
+            one_hot_encoding = [0] * vocab_size
+            for cat, cat_indices in slices.items():
+                if cat not in ["__PAD__", "_"]:
+                    label_from_cat = [label for label in multi_label if cat == label.split("=")[0]]
+                    if label_from_cat:
+                        label_from_cat = label_from_cat[0]
+                        index = vocab.get_token_index(label_from_cat, label_namespace)
+                    else:
+                        # Get Cat=None index
+                        index = vocab.get_token_index(cat + "=None", label_namespace)
+                    one_hot_encoding[index] = 1
+            return one_hot_encoding
+
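+        # Illustrative example: with slices for the categories "Case" and "Number",
+        # multi_label ["Case=Acc"] sets the index of "Case=Acc" and of "Number=None"
+        # to 1, i.e. exactly one label per category is active in the encoding.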
+        return _m_from_n_ones_encoding
+
+
+@allen_data.Vocabulary.register("from_instances_extended", constructor="from_instances_extended")
+class Vocabulary(allen_data.Vocabulary):
+
+    @classmethod
+    def from_instances_extended(
+            cls,
+            instances: Iterable[allen_data.Instance],
+            min_count: Dict[str, int] = None,
+            max_vocab_size: Union[int, Dict[str, int]] = None,
+            non_padded_namespaces: Iterable[str] = vocabulary.DEFAULT_NON_PADDED_NAMESPACES,
+            pretrained_files: Optional[Dict[str, str]] = None,
+            only_include_pretrained_words: bool = False,
+            min_pretrained_embeddings: Dict[str, int] = None,
+            padding_token: Optional[str] = vocabulary.DEFAULT_PADDING_TOKEN,
+            oov_token: Optional[str] = vocabulary.DEFAULT_OOV_TOKEN,
+    ) -> "Vocabulary":
+        """
+        Extension to manually fill gaps in missing 'feats_labels'.
+        """
+        # Load tokens from the pretrained file manually (a different strategy:
+        # add all words from the embedding file, without checking whether they
+        # were seen in any dataset).
+        tokens_to_add = None
+        if pretrained_files and "tokens" in pretrained_files:
+            pretrained_set = set(vocabulary._read_pretrained_tokens(pretrained_files["tokens"]))
+            tokens_to_add = {"tokens": list(pretrained_set)}
+            pretrained_files = None
+
+        vocab = super().from_instances(
+            instances=instances,
+            min_count=min_count,
+            max_vocab_size=max_vocab_size,
+            non_padded_namespaces=non_padded_namespaces,
+            pretrained_files=pretrained_files,
+            only_include_pretrained_words=only_include_pretrained_words,
+            tokens_to_add=tokens_to_add,
+            min_pretrained_embeddings=min_pretrained_embeddings,
+            padding_token=padding_token,
+            oov_token=oov_token
+        )
+        # Extend the vocab with features that do not show up explicitly.
+        # To know all features, the full dataset has to be read first.
+        # An auxiliary '=None' feature is added for each category, as
+        # required to perform classification.
+        get_slices_if_not_provided(vocab)
+        return vocab
+
+
+def get_slices_if_not_provided(vocab: allen_data.Vocabulary):
+    if hasattr(vocab, "slices"):
+        return vocab.slices
+
+    if "feats_labels" in vocab.get_namespaces():
+        idx2token = vocab.get_index_to_token_vocabulary("feats_labels")
+        for _, v in dict(idx2token).items():
+            if v not in ["_", "__PAD__"]:
+                empty_value = v.split("=")[0] + "=None"
+                vocab.add_token_to_namespace(empty_value, "feats_labels")
+
+        slices = {}
+        for idx, name in vocab.get_index_to_token_vocabulary("feats_labels").items():
+            # There are two types of features: with an assignment (Case=Acc) or without one (None).
+            # Here we group their indices by name (the part before the assignment sign).
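+            # For example, "Case=Acc" and "Case=Gen" both land in slices["Case"],
+            # together with the auxiliary "Case=None" added above.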
+            name = name.split("=")[0]
+            if name in slices:
+                slices[name].append(idx)
+            else:
+                slices[name] = [idx]
+        vocab.slices = slices
+        return vocab.slices
+
+
+@dataclass(init=False, repr=False)
+class _Token(allen_data.Token):
+    __slots__ = allen_data.Token.__slots__ + ['feats_']
+
+    feats_: Optional[str]
+
+    def __init__(self, text: str = None, idx: int = None, idx_end: int = None, lemma_: str = None, pos_: str = None,
+                 tag_: str = None, dep_: str = None, ent_type_: str = None, text_id: int = None, type_id: int = None,
+                 feats_: str = None) -> None:
+        super().__init__(text, idx, idx_end, lemma_, pos_, tag_, dep_, ent_type_, text_id, type_id)
+        self.feats_ = feats_
diff --git a/combo/data/fields/__init__.py b/combo/data/fields/__init__.py
new file mode 100644
index 0000000..4f8d3a2
--- /dev/null
+++ b/combo/data/fields/__init__.py
@@ -0,0 +1 @@
+from .sequence_multilabel_field import SequenceMultiLabelField
diff --git a/combo/data/fields/sequence_multilabel_field.py b/combo/data/fields/sequence_multilabel_field.py
new file mode 100644
index 0000000..c31f78e
--- /dev/null
+++ b/combo/data/fields/sequence_multilabel_field.py
@@ -0,0 +1,135 @@
+"""Sequence multilabel field implementation."""
+import logging
+import textwrap
+from typing import Set, List, Callable, Iterator, Union, Dict
+
+import torch
+from allennlp import data
+from allennlp.common import checks
+from allennlp.data import fields
+from overrides import overrides
+
+logger = logging.getLogger(__name__)
+
+
+class SequenceMultiLabelField(data.Field[torch.Tensor]):
+    """
+    A `SequenceMultiLabelField` is an extension of the :class:`MultiLabelField` that allows multiple labels
+    per position while keeping the sequence dimension.
+
+    To adapt to different circumstances, the class takes a few delegate functions.
+
+    # Parameters
+
+    multi_labels : `List[List[str]]`
+    multi_label_indexer : `Callable[[data.Vocabulary], Callable[[List[str]], List[int]]]`
+        Nested callable which, given the vocabulary and the sequence length, maps the values of the field
+        in the sequence from strings to indexed int values.
+    as_tensor: `Callable[["SequenceMultiLabelField"], Callable[[Dict[str, int]], torch.Tensor]]`
+        Nested callable which, given the field itself, maps the indexed data to a tensor.
+    sequence_field : `SequenceField`
+        A field containing the sequence that this `SequenceMultiLabelField` is labeling.  Most often, this is a
+        `TextField`, for tagging individual tokens in a sentence.
+    label_namespace : `str`, optional (default="labels")
+        The namespace to use for converting label strings into integers.  We map label strings to
+        integers for you (e.g., "entailment" and "contradiction" get converted to 0, 1, ...),
+        and this namespace tells the `Vocabulary` object which mapping from strings to integers
+        to use (so "entailment" as a label doesn't get the same integer id as "entailment" as a
+        word).  If you have multiple different label fields in your data, you should make sure you
+        use different namespaces for each one, always using the suffix "labels" (e.g.,
+        "passage_labels" and "question_labels").
+    """
+    _already_warned_namespaces: Set[str] = set()
+
+    def __init__(
+            self,
+            multi_labels: List[List[str]],
+            multi_label_indexer: Callable[[data.Vocabulary], Callable[[List[str], int], List[int]]],
+            as_tensor: Callable[["SequenceMultiLabelField"], Callable[[Dict[str, int]], torch.Tensor]],
+            sequence_field: fields.SequenceField,
+            label_namespace: str = "labels",
+    ) -> None:
+        self.multi_labels = multi_labels
+        self.sequence_field = sequence_field
+        self.multi_label_indexer = multi_label_indexer
+        self._label_namespace = label_namespace
+        self._indexed_multi_labels = None
+        self._maybe_warn_for_namespace(label_namespace)
+        self.as_tensor_wrapper = as_tensor(self)
+        if len(multi_labels) != sequence_field.sequence_length():
+            raise checks.ConfigurationError(
+                "Label length and sequence length "
+                "don't match: %d and %d" % (len(multi_labels), sequence_field.sequence_length())
+            )
+
+        if not all([isinstance(x, str) for multi_label in multi_labels for x in multi_label]):
+            raise checks.ConfigurationError(
+                "SequenceMultiLabelField must be passed either all "
+                "strings or all ints. Found labels {} with "
+                "types: {}.".format(multi_labels, [type(x) for multi_label in multi_labels for x in multi_label])
+            )
+
+    def _maybe_warn_for_namespace(self, label_namespace: str) -> None:
+        if not (self._label_namespace.endswith("labels") or self._label_namespace.endswith("tags")):
+            if label_namespace not in self._already_warned_namespaces:
+                logger.warning(
+                    "Your label namespace was '%s'. We recommend you use a namespace "
+                    "ending with 'labels' or 'tags', so we don't add UNK and PAD tokens by "
+                    "default to your vocabulary.  See documentation for "
+                    "`non_padded_namespaces` parameter in Vocabulary.",
+                    self._label_namespace,
+                )
+                self._already_warned_namespaces.add(label_namespace)
+
+    # Sequence methods
+    def __iter__(self) -> Iterator[Union[List[str], int]]:
+        return iter(self.multi_labels)
+
+    def __getitem__(self, idx: int) -> Union[List[str], int]:
+        return self.multi_labels[idx]
+
+    def __len__(self) -> int:
+        return len(self.multi_labels)
+
+    @overrides
+    def count_vocab_items(self, counter: Dict[str, Dict[str, int]]):
+        if self._indexed_multi_labels is None:
+            for multi_label in self.multi_labels:
+                for label in multi_label:
+                    counter[self._label_namespace][label] += 1  # type: ignore
+
+    @overrides
+    def index(self, vocab: data.Vocabulary):
+        indexer = self.multi_label_indexer(vocab)
+
+        indexed = []
+        for multi_label in self.multi_labels:
+            indexed.append(indexer(multi_label, len(self.multi_labels)))
+        self._indexed_multi_labels = indexed
+
+    @overrides
+    def get_padding_lengths(self) -> Dict[str, int]:
+        return {"num_tokens": self.sequence_field.sequence_length()}
+
+    @overrides
+    def as_tensor(self, padding_lengths: Dict[str, int]) -> torch.Tensor:
+        return self.as_tensor_wrapper(padding_lengths)
+
+    @overrides
+    def empty_field(self) -> "SequenceMultiLabelField":
+        empty_list: List[List[str]] = [[]]
+        sequence_label_field = SequenceMultiLabelField(empty_list, lambda x: lambda y: y,
+                                                       lambda x: lambda y: y,
+                                                       self.sequence_field.empty_field())
+        sequence_label_field._indexed_multi_labels = empty_list
+        return sequence_label_field
+
+    def __str__(self) -> str:
+        length = self.sequence_field.sequence_length()
+        formatted_labels = "".join(
+            "\t\t" + labels + "\n" for labels in textwrap.wrap(repr(self.multi_labels), 100)
+        )
+        return (
+            f"SequenceMultiLabelField of length {length} with "
+            f"labels:\n {formatted_labels} \t\tin namespace: '{self._label_namespace}'."
+        )
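+
+
+# Illustrative construction sketch (the delegate names below are hypothetical):
+#
+#     field = SequenceMultiLabelField(
+#         multi_labels=[["Case=Acc", "Number=Sing"], []],
+#         multi_label_indexer=feats_indexer_factory,  # Callable[[Vocabulary], Callable[[List[str], int], List[int]]]
+#         as_tensor=feats_as_tensor_factory,          # Callable[[SequenceMultiLabelField], Callable[[Dict[str, int]], torch.Tensor]]
+#         sequence_field=text_field,                  # a TextField with the same number of tokens (2 here)
+#         label_namespace="feats_labels",
+#     )
+#
+# See UniversalDependenciesDatasetReader.text_to_instance in combo/data/dataset.py
+# for the concrete delegates used for morphological features.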
diff --git a/combo/data/samplers/__init__.py b/combo/data/samplers/__init__.py
new file mode 100644
index 0000000..ab003fe
--- /dev/null
+++ b/combo/data/samplers/__init__.py
@@ -0,0 +1 @@
+from .samplers import TokenCountBatchSampler
diff --git a/combo/data/samplers/samplers.py b/combo/data/samplers/samplers.py
new file mode 100644
index 0000000..5db74d9
--- /dev/null
+++ b/combo/data/samplers/samplers.py
@@ -0,0 +1,54 @@
+from typing import List
+
+import numpy as np
+
+from allennlp import data as allen_data
+
+
+@allen_data.BatchSampler.register("token_count")
+class TokenCountBatchSampler(allen_data.BatchSampler):
+
+    def __init__(self, dataset, word_batch_size: int = 2500, shuffle_dataset: bool = True):
+        self._index = 0
+        self.shuffle_dataset = shuffle_dataset
+        self.batch_dataset = self._batchify(dataset, word_batch_size)
+        if shuffle_dataset:
+            self._shuffle()
+
+    @staticmethod
+    def _batchify(dataset, word_batch_size) -> List[List[int]]:
+        dataset = list(dataset)
+        batches = []
+        batch = []
+        words_count = 0
+        lengths = [len(instance.fields["sentence"].tokens) for instance in dataset]
+        argsorted_lengths = np.argsort(lengths)
+        for idx in argsorted_lengths:
+            words_count += lengths[idx]
+            batch.append(idx)
+            if words_count > word_batch_size:
+                batches.append(batch)
+                words_count = 0
+                batch = []
+        return batches
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        if self._index >= len(self.batch_dataset):
+            if self.shuffle_dataset:
+                self._index = 0
+                self._shuffle()
+            raise StopIteration()
+
+        batch = self.batch_dataset[self._index]
+        self._index += 1
+        return batch
+
+    def _shuffle(self):
+        indices = np.random.permutation(range(len(self.batch_dataset)))
+        self.batch_dataset = np.array(self.batch_dataset)[indices].tolist()
+
+    def __len__(self):
+        return len(self.batch_dataset)
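+
+
+# Illustrative example (made-up numbers): with word_batch_size=10 and sentence
+# lengths [3, 7, 2, 6], instances are sorted by length and grouped greedily until
+# the running token count exceeds the limit, yielding batches of instance indices
+# such as [[2, 0, 3], [1]]; the trailing batch is kept (see _batchify above).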
diff --git a/combo/data/token_indexers/__init__.py b/combo/data/token_indexers/__init__.py
new file mode 100644
index 0000000..550a80b
--- /dev/null
+++ b/combo/data/token_indexers/__init__.py
@@ -0,0 +1,3 @@
+from .pretrained_transformer_mismatched_indexer import PretrainedTransformerMismatchedIndexer
+from .token_characters_indexer import TokenCharactersIndexer
+from .token_features_indexer import TokenFeatsIndexer
diff --git a/combo/data/token_indexers/pretrained_transformer_mismatched_indexer.py b/combo/data/token_indexers/pretrained_transformer_mismatched_indexer.py
new file mode 100644
index 0000000..fc29896
--- /dev/null
+++ b/combo/data/token_indexers/pretrained_transformer_mismatched_indexer.py
@@ -0,0 +1,117 @@
+from typing import Optional, Dict, Any, List, Tuple
+
+from allennlp import data
+from allennlp.data import token_indexers, tokenizers, IndexedTokenList, vocabulary
+from overrides import overrides
+
+
+@data.TokenIndexer.register("pretrained_transformer_mismatched_fixed")
+class PretrainedTransformerMismatchedIndexer(token_indexers.PretrainedTransformerMismatchedIndexer):
+
+    def __init__(self, model_name: str, namespace: str = "tags", max_length: int = None,
+                 tokenizer_kwargs: Optional[Dict[str, Any]] = None, **kwargs) -> None:
+        # Uses the 'matched' PretrainedTransformerIndexer internally (vs. the mismatched one).
+        super().__init__(model_name, namespace, max_length, tokenizer_kwargs, **kwargs)
+        self._matched_indexer = PretrainedTransformerIndexer(
+            model_name,
+            namespace=namespace,
+            max_length=max_length,
+            tokenizer_kwargs=tokenizer_kwargs,
+            **kwargs,
+        )
+        self._allennlp_tokenizer = self._matched_indexer._allennlp_tokenizer
+        self._tokenizer = self._matched_indexer._tokenizer
+        self._num_added_start_tokens = self._matched_indexer._num_added_start_tokens
+        self._num_added_end_tokens = self._matched_indexer._num_added_end_tokens
+
+    @overrides
+    def tokens_to_indices(self,
+                          tokens,
+                          vocabulary: data.Vocabulary) -> IndexedTokenList:
+        """
+        The method is overridden in order to raise an error when the number of wordpiece tokens
+        needed to embed a sentence exceeds the maximal input length of the model.
+        """
+        self._matched_indexer._add_encoding_to_vocabulary_if_needed(vocabulary)
+
+        wordpieces, offsets = self._allennlp_tokenizer.intra_word_tokenize(
+            [t.ensure_text() for t in tokens])
+
+        if len(wordpieces) > self._tokenizer.max_len_single_sentence:
+            raise ValueError("Following sentence consists of more wordpiece tokens that the model can process:\n" +\
+                             " ".join([str(x) for x in tokens[:10]]) + " ... \n" + \
+                             f"Maximal input: {self._tokenizer.max_len_single_sentence}\n"+ \
+                             f"Current input: {len(wordpieces)}")
+
+        offsets = [x if x is not None else (-1, -1) for x in offsets]
+
+        output: IndexedTokenList = {
+            "token_ids": [t.text_id for t in wordpieces],
+            "mask": [True] * len(tokens),  # for original tokens (i.e. word-level)
+            "type_ids": [t.type_id for t in wordpieces],
+            "offsets": offsets,
+            "wordpiece_mask": [True] * len(wordpieces),  # for wordpieces (i.e. subword-level)
+        }
+
+        return self._matched_indexer._postprocess_output(output)
+
+
+class PretrainedTransformerIndexer(token_indexers.PretrainedTransformerIndexer):
+
+    def __init__(
+            self,
+            model_name: str,
+            namespace: str = "tags",
+            max_length: int = None,
+            tokenizer_kwargs: Optional[Dict[str, Any]] = None,
+            **kwargs,
+    ) -> None:
+        super().__init__(model_name, namespace, max_length, tokenizer_kwargs, **kwargs)
+        self._namespace = namespace
+        self._allennlp_tokenizer = PretrainedTransformerTokenizer(
+            model_name, tokenizer_kwargs=tokenizer_kwargs
+        )
+        self._tokenizer = self._allennlp_tokenizer.tokenizer
+        self._added_to_vocabulary = False
+
+        self._num_added_start_tokens = len(self._allennlp_tokenizer.single_sequence_start_tokens)
+        self._num_added_end_tokens = len(self._allennlp_tokenizer.single_sequence_end_tokens)
+
+        self._max_length = max_length
+        if self._max_length is not None:
+            num_added_tokens = len(self._allennlp_tokenizer.tokenize("a")) - 1
+            self._effective_max_length = (  # we need to take into account special tokens
+                    self._max_length - num_added_tokens
+            )
+            if self._effective_max_length <= 0:
+                raise ValueError(
+                    "max_length needs to be greater than the number of special tokens inserted."
+                )
+
+
+class PretrainedTransformerTokenizer(tokenizers.PretrainedTransformerTokenizer):
+
+    def _intra_word_tokenize(
+            self, string_tokens: List[str]
+    ) -> Tuple[List[data.Token], List[Optional[Tuple[int, int]]]]:
+        tokens: List[data.Token] = []
+        offsets: List[Optional[Tuple[int, int]]] = []
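+        # Offsets are inclusive wordpiece spans per input token, e.g. (assuming a
+        # tokenizer that splits "Hello" into ["He", "##llo"]) the tokens
+        # ["Hello", "world"] would yield offsets [(0, 1), (2, 2)].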
+        for token_string in string_tokens:
+            wordpieces = self.tokenizer.encode_plus(
+                token_string,
+                add_special_tokens=False,
+                return_tensors=None,
+                return_offsets_mapping=False,
+                return_attention_mask=False,
+            )
+            wp_ids = wordpieces["input_ids"]
+
+            if len(wp_ids) > 0:
+                offsets.append((len(tokens), len(tokens) + len(wp_ids) - 1))
+                tokens.extend(
+                    data.Token(text=wp_text, text_id=wp_id)
+                    for wp_id, wp_text in zip(wp_ids, self.tokenizer.convert_ids_to_tokens(wp_ids))
+                )
+            else:
+                offsets.append(None)
+        return tokens, offsets
diff --git a/combo/data/token_indexers/token_characters_indexer.py b/combo/data/token_indexers/token_characters_indexer.py
new file mode 100644
index 0000000..ea7a3ea
--- /dev/null
+++ b/combo/data/token_indexers/token_characters_indexer.py
@@ -0,0 +1,62 @@
+"""Custom character token indexer."""
+import itertools
+from typing import List, Dict
+
+import torch
+from allennlp import data
+from allennlp.common import util
+from allennlp.data import tokenizers
+from allennlp.data.token_indexers import token_characters_indexer
+from overrides import overrides
+
+
+@data.TokenIndexer.register("characters_const_padding")
+class TokenCharactersIndexer(token_characters_indexer.TokenCharactersIndexer):
+    """Wrapper around allennlp token indexer with const padding."""
+
+    def __init__(self,
+                 namespace: str = "token_characters",
+                 character_tokenizer: tokenizers.CharacterTokenizer = tokenizers.CharacterTokenizer(),
+                 start_tokens: List[str] = None,
+                 end_tokens: List[str] = None,
+                 min_padding_length: int = 0,
+                 token_min_padding_length: int = 0):
+        super().__init__(namespace, character_tokenizer, start_tokens, end_tokens, min_padding_length,
+                         token_min_padding_length)
+
+    @overrides
+    def get_padding_lengths(self, indexed_tokens: data.IndexedTokenList) -> Dict[str, int]:
+        padding_lengths = {"token_characters": len(indexed_tokens["token_characters"]),
+                           "num_token_characters": self._min_padding_length}
+        return padding_lengths
+
+    @overrides
+    def as_padded_tensor_dict(
+            self, tokens: data.IndexedTokenList, padding_lengths: Dict[str, int]
+    ) -> Dict[str, torch.Tensor]:
+        # Pad the tokens.
+        padded_tokens = util.pad_sequence_to_length(
+            tokens["token_characters"],
+            padding_lengths["token_characters"],
+            default_value=lambda: [],
+        )
+
+        # Pad the characters within the tokens.
+        desired_token_length = padding_lengths["num_token_characters"]
+        longest_token: List[int] = max(tokens["token_characters"], key=len, default=[])  # type: ignore
+        padding_value = 0
+        if desired_token_length > len(longest_token):
+            # Since we want to pad to greater than the longest token, we add a
+            # "dummy token" so we can take advantage of the fast implementation of itertools.zip_longest.
+            padded_tokens.append([padding_value] * desired_token_length)
+        # pad the list of lists to the longest sublist, appending 0's
+        padded_tokens = list(zip(*itertools.zip_longest(*padded_tokens, fillvalue=padding_value)))
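+        # e.g. with token character ids [[5, 3], [7]] and desired_token_length 4, the net
+        # result (after the dummy token is removed below) is [[5, 3, 0, 0], [7, 0, 0, 0]].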
+        if desired_token_length > len(longest_token):
+            # Removes the "dummy token".
+            padded_tokens.pop()
+        # Truncates all the tokens to the desired length, and return the result.
+        return {
+            "token_characters": torch.LongTensor(
+                [list(token[:desired_token_length]) for token in padded_tokens]
+            )
+        }
diff --git a/combo/data/token_indexers/token_features_indexer.py b/combo/data/token_indexers/token_features_indexer.py
new file mode 100644
index 0000000..7c59124
--- /dev/null
+++ b/combo/data/token_indexers/token_features_indexer.py
@@ -0,0 +1,75 @@
+"""Features indexer."""
+import collections
+from typing import List, Dict
+
+import torch
+from allennlp import data
+from allennlp.common import util
+from overrides import overrides
+
+
+@data.TokenIndexer.register("feats_indexer")
+class TokenFeatsIndexer(data.TokenIndexer):
+
+    def __init__(
+            self,
+            namespace: str = "feats",
+            feature_name: str = "feats_",
+            token_min_padding_length: int = 0,
+    ) -> None:
+        super().__init__(token_min_padding_length)
+        self.namespace = namespace
+        self._feature_name = feature_name
+
+    @overrides
+    def count_vocab_items(self, token: data.Token, counter: Dict[str, Dict[str, int]]):
+        feats = self._feat_values(token)
+        for feat in feats:
+            counter[self.namespace][feat] += 1
+
+    @overrides
+    def tokens_to_indices(self, tokens: List[data.Token], vocabulary: data.Vocabulary) -> data.IndexedTokenList:
+        indices: List[List[int]] = []
+        vocab_size = vocabulary.get_vocab_size(self.namespace)
+        for token in tokens:
+            token_indices = []
+            feats = self._feat_values(token)
+            for feat in feats:
+                token_indices.append(vocabulary.get_token_index(feat, self.namespace))
+            indices.append(util.pad_sequence_to_length(token_indices, vocab_size))
+        return {"tokens": indices}
+
+    @overrides
+    def get_empty_token_list(self) -> data.IndexedTokenList:
+        return {"tokens": [[]]}
+
+    def _feat_values(self, token):
+        feats = getattr(token, self._feature_name)
+        if feats is None:
+            feats = collections.OrderedDict()
+        features = []
+        for feat, value in feats.items():
+            if feat in ["_", "__ROOT__"]:
+                pass
+            else:
+                # Handle case where feature is binary (doesn't have associated value)
+                if value:
+                    features.append(feat + "=" + value)
+                else:
+                    features.append(feat)
+        return features
+
+    @overrides
+    def as_padded_tensor_dict(
+            self, tokens: data.IndexedTokenList, padding_lengths: Dict[str, int]
+    ) -> Dict[str, torch.Tensor]:
+        tensor_dict = {}
+        for key, val in tokens.items():
+            vocab_size = len(val[0])
+            tensor = torch.tensor(util.pad_sequence_to_length(val,
+                                                              padding_lengths[key],
+                                                              default_value=lambda: [0] * vocab_size,
+                                                              )
+                                  )
+            tensor_dict[key] = tensor
+        return tensor_dict
diff --git a/combo/data/vocabulary.py b/combo/data/vocabulary.py
new file mode 100644
index 0000000..d482b6c
--- /dev/null
+++ b/combo/data/vocabulary.py
@@ -0,0 +1,98 @@
+from collections import defaultdict, OrderedDict
+from typing import Dict, Union, Optional, Iterable, Callable, Any, Set
+
+from torchtext.vocab import Vocab as TorchtextVocab
+from torchtext.vocab import vocab as torchtext_vocab
+
+DEFAULT_NON_PADDED_NAMESPACES = ("*tags", "*labels")
+DEFAULT_PADDING_TOKEN = "@@PADDING@@"
+DEFAULT_OOV_TOKEN = "@@UNKNOWN@@"
+NAMESPACE_PADDING_FILE = "non_padded_namespaces.txt"
+DEFAULT_NAMESPACE = "tokens"
+
+
+def match_namespace(pattern: str, namespace: str):
+    if not isinstance(pattern, str) or not isinstance(namespace, str):
+        raise ValueError("Pattern and namespace must be string types, got %s and %s." %
+                         (type(pattern), type(namespace)))
+    if pattern == namespace:
+        return True
+    if len(pattern) > 2 and pattern[0] == '*' and namespace.endswith(pattern[1:]):
+        return True
+        return True
+    return False
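+
+# Examples: match_namespace("*labels", "feats_labels") -> True,
+#           match_namespace("tokens", "tokens")        -> True,
+#           match_namespace("*labels", "tokens")       -> False.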
+
+
+class _NamespaceDependentDefaultDict(defaultdict[str, TorchtextVocab]):
+    def __init__(self,
+                 non_padded_namespaces: Iterable[str],
+                 padding_token: str,
+                 oov_token: str):
+        self._non_padded_namespaces = set(non_padded_namespaces)
+        self._padding_token = padding_token
+        self._oov_token = oov_token
+        super().__init__()
+
+    def __missing__(self, namespace: str):
+        # Padded namespaces get the padding and OOV tokens; non-padded ones start empty.
+        if any([match_namespace(npn, namespace) for npn in self._non_padded_namespaces]):
+            value = torchtext_vocab(OrderedDict([]))
+        else:
+            value = torchtext_vocab(
+                OrderedDict([
+                    (self._padding_token, 1),
+                    (self._oov_token, 1)])
+            )
+        dict.__setitem__(self, namespace, value)
+        return value
+
+    def add_non_padded_namespaces(self, non_padded_namespaces: Set[str]):
+        self._non_padded_namespaces.update(non_padded_namespaces)
+
+
+class Vocabulary:
+    def __init__(self,
+                 counter: Dict[str, Dict[str, int]] = None,
+                 min_count: Dict[str, int] = None,
+                 max_vocab_size: Union[int, Dict[str, int]] = None,
+                 non_padded_namespaces: Iterable[str] = DEFAULT_NON_PADDED_NAMESPACES,
+                 padding_token: Optional[str] = DEFAULT_PADDING_TOKEN,
+                 oov_token: Optional[str] = DEFAULT_OOV_TOKEN):
+
+        # Notes on the underlying torchtext vocab factory:
+        #   ordered_dict - ordered mapping from tokens to their occurrence frequencies,
+        #   min_freq - the minimum frequency needed to include a token in the vocabulary,
+        #   specials - special symbols to add (the supplied order is preserved),
+        #   special_first - whether to insert the special symbols at the beginning or at the end.
+        self._padding_token = padding_token if padding_token is not None else DEFAULT_PADDING_TOKEN
+        self._oov_token = oov_token if oov_token is not None else DEFAULT_OOV_TOKEN
+        self._non_padded_namespaces = set(non_padded_namespaces)
+        self._vocab = _NamespaceDependentDefaultDict(self._non_padded_namespaces,
+                                                     self._padding_token,
+                                                     self._oov_token)
+
+    def _extend(self,
+                tokens_to_add: Dict[str, Dict[str, int]]):
+        for namespace, tokens in tokens_to_add.items():
+            for token in tokens:
+                self._vocab[namespace].append_token(token)
+
+    # def add_token_to_namespace(self, token: str, namespace: str = DEFAULT_NAMESPACE):
+    #     """
+    #     Add the token if not present and return the index even if token was already in the namespace.
+    #
+    #     :param token: token to be added
+    #     :param namespace: namespace to add the token to
+    #     :return: index of the token in the namespace
+    #     """
+    #
+    #     if not isinstance(token, str):
+    #         raise ValueError("Vocabulary tokens must be strings. Got %s with type %s" % (repr(token), type(token)))
+    #
+
+
+    @classmethod
+    def empty(cls):
+        return cls()
+
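+
+# Illustrative behaviour: the first access to self._vocab["tokens"] creates a
+# torchtext Vocab with the padding token at index 0 and the OOV token at index 1,
+# while self._vocab["upostag_labels"] (matching the "*labels" non-padded pattern)
+# starts as an empty Vocab (see _NamespaceDependentDefaultDict.__missing__).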
diff --git a/combo/models/__init__.py b/combo/models/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/combo/models/base.py b/combo/models/base.py
new file mode 100644
index 0000000..11df601
--- /dev/null
+++ b/combo/models/base.py
@@ -0,0 +1,215 @@
+from typing import Dict, Optional, List, Union, Tuple
+
+import torch
+import torch.nn as nn
+from combo.models import utils
+import combo.models.combo_nn as combo_nn
+import combo.checks as checks
+from allennlp import data  # for the data.Vocabulary annotation in FeedForwardPredictor.from_vocab
+
+
+class Predictor(nn.Module):
+    def forward(self,
+                x: Union[torch.Tensor, List[torch.Tensor]],
+                mask: Optional[torch.BoolTensor] = None,
+                labels: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
+                sample_weights: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None) -> Dict[str, torch.Tensor]:
+        raise NotImplementedError()
+
+
+class Linear(nn.Linear):
+
+    def __init__(self,
+                 in_features: int,
+                 out_features: int,
+                 activation: Optional[combo_nn.Activation] = None,
+                 dropout_rate: Optional[float] = 0.0):
+        super().__init__(in_features, out_features)
+        self.activation = activation if activation else self.identity
+        self.dropout = nn.Dropout(p=dropout_rate) if dropout_rate else self.identity
+
+    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
+        x = super().forward(x)
+        x = self.activation(x)
+        return self.dropout(x)
+
+    def get_output_dim(self) -> int:
+        return self.out_features
+
+    @staticmethod
+    def identity(x):
+        return x
+
+
+class FeedForward(torch.nn.Module):
+    """
+    Modified copy of allennlp.modules.feedforward.FeedForward
+
+    This `Module` is a feed-forward neural network, just a sequence of `Linear` layers with
+    activation functions in between.
+
+    # Parameters
+
+    input_dim : `int`, required
+        The dimensionality of the input.  We assume the input has shape `(batch_size, input_dim)`.
+    num_layers : `int`, required
+        The number of `Linear` layers to apply to the input.
+    hidden_dims : `Union[int, List[int]]`, required
+        The output dimension of each of the `Linear` layers.  If this is a single `int`, we use
+        it for all `Linear` layers.  If it is a `List[int]`, `len(hidden_dims)` must be
+        `num_layers`.
+    activations : `Union[Activation, List[Activation]]`, required
+        The activation function to use after each `Linear` layer.  If this is a single function,
+        we use it after all `Linear` layers.  If it is a `List[Activation]`,
+        `len(activations)` must be `num_layers`. Each activation must be a torch.nn.Module instance.
+    dropout : `Union[float, List[float]]`, optional (default = `0.0`)
+        If given, we will apply this amount of dropout after each layer.  Semantics of `float`
+        versus `List[float]` is the same as with other parameters.
+
+    # Examples
+
+    ```python
+    FeedForward(124, 2, [64, 32], torch.nn.ReLU(), 0.2)
+    #> FeedForward(
+    #>   (_activations): ModuleList(
+    #>     (0): ReLU()
+    #>     (1): ReLU()
+    #>   )
+    #>   (_linear_layers): ModuleList(
+    #>     (0): Linear(in_features=124, out_features=64, bias=True)
+    #>     (1): Linear(in_features=64, out_features=32, bias=True)
+    #>   )
+    #>   (_dropout): ModuleList(
+    #>     (0): Dropout(p=0.2, inplace=False)
+    #>     (1): Dropout(p=0.2, inplace=False)
+    #>   )
+    #> )
+    ```
+    """
+
+    def __init__(
+        self,
+        input_dim: int,
+        num_layers: int,
+        hidden_dims: Union[int, List[int]],
+        activations: Union[combo_nn.Activation, List[combo_nn.Activation]],
+        dropout: Union[float, List[float]] = 0.0,
+    ) -> None:
+
+        super().__init__()
+        if not isinstance(hidden_dims, list):
+            hidden_dims = [hidden_dims] * num_layers  # type: ignore
+        if not isinstance(activations, list):
+            activations = [activations] * num_layers  # type: ignore
+        if not isinstance(dropout, list):
+            dropout = [dropout] * num_layers  # type: ignore
+        if len(hidden_dims) != num_layers:
+            raise checks.ConfigurationError(
+                "len(hidden_dims) (%d) != num_layers (%d)" % (len(hidden_dims), num_layers)
+            )
+        if len(activations) != num_layers:
+            raise checks.ConfigurationError(
+                "len(activations) (%d) != num_layers (%d)" % (len(activations), num_layers)
+            )
+        if len(dropout) != num_layers:
+            raise checks.ConfigurationError(
+                "len(dropout) (%d) != num_layers (%d)" % (len(dropout), num_layers)
+            )
+        self._activations = torch.nn.ModuleList(activations)
+        input_dims = [input_dim] + hidden_dims[:-1]
+        linear_layers = []
+        for layer_input_dim, layer_output_dim in zip(input_dims, hidden_dims):
+            linear_layers.append(torch.nn.Linear(layer_input_dim, layer_output_dim))
+        self._linear_layers = torch.nn.ModuleList(linear_layers)
+        dropout_layers = [torch.nn.Dropout(p=value) for value in dropout]
+        self._dropout = torch.nn.ModuleList(dropout_layers)
+        self._output_dim = hidden_dims[-1]
+        self.input_dim = input_dim
+
+    def get_output_dim(self):
+        return self._output_dim
+
+    def get_input_dim(self):
+        return self.input_dim
+
+    def forward(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
+
+        output = inputs
+        feature_maps = []
+        for layer, activation, dropout in zip(
+            self._linear_layers, self._activations, self._dropout
+        ):
+            feature_maps.append(output)
+            output = dropout(activation(layer(output)))
+        return output, feature_maps
+
+
+class FeedForwardPredictor(Predictor):
+    """Feedforward predictor. Should be used on top of Seq2Seq encoder."""
+
+    def __init__(self, feedforward_network: "FeedForward"):
+        super().__init__()
+        self.feedforward_network = feedforward_network
+
+    def forward(self,
+                x: Union[torch.Tensor, List[torch.Tensor]],
+                mask: Optional[torch.BoolTensor] = None,
+                labels: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
+                sample_weights: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None) -> Dict[str, torch.Tensor]:
+        if mask is None:
+            mask = x.new_ones(x.size()[:-1])
+
+        x, feature_maps = self.feedforward_network(x)
+        output = {
+            "prediction": x.argmax(-1),
+            "probability": x,
+            "embedding": feature_maps[-1],
+        }
+
+        if labels is not None:
+            if sample_weights is None:
+                sample_weights = labels.new_ones([mask.size(0)])
+            output["loss"] = self._loss(x, labels, mask, sample_weights)
+
+        return output
+
+    def _loss(self,
+              pred: torch.Tensor,
+              true: torch.Tensor,
+              mask: torch.BoolTensor,
+              sample_weights: torch.Tensor) -> torch.Tensor:
+        BATCH_SIZE, _, CLASSES = pred.size()
+        valid_positions = mask.sum()
+        pred = pred.reshape(-1, CLASSES)
+        true = true.reshape(-1)
+        mask = mask.reshape(-1)
+        loss = utils.masked_cross_entropy(pred, true, mask)
+        loss = loss.reshape(BATCH_SIZE, -1) * sample_weights.unsqueeze(-1)
+        return loss.sum() / valid_positions
+
+    @classmethod
+    def from_vocab(cls,
+                   vocab: data.Vocabulary,
+                   vocab_namespace: str,
+                   input_dim: int,
+                   num_layers: int,
+                   hidden_dims: List[int],
+                   activations: Union[combo_nn.Activation, List[combo_nn.Activation]],
+                   dropout: Union[float, List[float]] = 0.0,
+                   ):
+        if len(hidden_dims) + 1 != num_layers:
+            raise checks.ConfigurationError(
+                f"len(hidden_dims) ({len(hidden_dims):d}) + 1 != num_layers ({num_layers:d})"
+            )
+
+        assert vocab_namespace in vocab.get_namespaces(), \
+            f"There is not {vocab_namespace} in created vocabs, check if this field has any values to predict!"
+        hidden_dims = hidden_dims + [vocab.get_vocab_size(vocab_namespace)]
+
+        return cls(FeedForward(
+            input_dim=input_dim,
+            num_layers=num_layers,
+            hidden_dims=hidden_dims,
+            activations=activations,
+            dropout=dropout))
+
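For orientation, here is a minimal usage sketch of the predictor defined above, assuming the module paths introduced by this patch (`combo.models.base`, `combo.models.combo_nn`). The `TanhActivation` subclass is hypothetical, since the patch only ships the abstract `Activation` base, and all shapes and values are illustrative:

```python
import torch

from combo.models.base import FeedForward, FeedForwardPredictor
from combo.models.combo_nn import Activation


# Hypothetical concrete activation; only the abstract Activation base exists in this patch.
class TanhActivation(Activation):
    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        return torch.tanh(tensor)


num_labels = 5
predictor = FeedForwardPredictor(FeedForward(
    input_dim=16,
    num_layers=2,
    hidden_dims=[8, num_labels],  # the last hidden dim is the label-set size
    activations=TanhActivation(),
    dropout=0.1,
))

encoded = torch.randn(2, 7, 16)                # (batch, tokens, encoder dim)
mask = torch.ones(2, 7, dtype=torch.bool)
labels = torch.randint(0, num_labels, (2, 7))

out = predictor(encoded, mask=mask, labels=labels)
print(out["prediction"].shape)  # torch.Size([2, 7])
print(out["loss"].item())       # scalar masked cross-entropy
```

`from_vocab` builds the same network but appends the vocabulary size of the target namespace as the final hidden dimension.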
diff --git a/combo/models/combo_nn.py b/combo/models/combo_nn.py
new file mode 100644
index 0000000..2866670
--- /dev/null
+++ b/combo/models/combo_nn.py
@@ -0,0 +1,7 @@
+import torch
+import torch.nn as nn
+
+
+class Activation(nn.Module):
+    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
+        raise NotImplementedError
diff --git a/combo/models/embeddings.py b/combo/models/embeddings.py
new file mode 100644
index 0000000..e69de29
diff --git a/combo/models/utils.py b/combo/models/utils.py
new file mode 100644
index 0000000..0a447f9
--- /dev/null
+++ b/combo/models/utils.py
@@ -0,0 +1,7 @@
+import torch
+import torch.nn.functional as F
+
+
+def masked_cross_entropy(pred: torch.Tensor, true: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor:
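+    # (mask + 1e-45).log() adds ~0 to the logits of kept positions and a large negative
+    # constant to masked ones; the trailing "* mask" zeroes the loss at masked positions.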
+    pred = pred + (mask.float().unsqueeze(-1) + 1e-45).log()
+    return F.cross_entropy(pred, true, reduction="none") * mask
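A minimal sketch of how `masked_cross_entropy` behaves, assuming the `combo.models.utils` path from this patch; values are illustrative. The function returns one loss value per position, and masked positions contribute zero:

```python
import torch

from combo.models.utils import masked_cross_entropy

pred = torch.randn(4, 3)                        # (positions, classes) logits
true = torch.tensor([0, 2, 1, 1])               # gold class index per position
mask = torch.tensor([True, True, True, False])  # last position is padding

loss = masked_cross_entropy(pred, true, mask)
print(loss.shape)       # torch.Size([4])
print(loss[-1].item())  # 0.0 -- the masked position adds no loss
```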
diff --git a/docs/Configuration.md b/docs/Configuration.md
new file mode 100644
index 0000000..7fb14f5
--- /dev/null
+++ b/docs/Configuration.md
@@ -0,0 +1,4 @@
+# Configuration
+
+Configuration is handled through dependency injection, using the `dependency_injector` package.
+Configuration files can be in JSON or INI format.
\ No newline at end of file
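As a rough sketch of what such wiring could look like with `dependency_injector` (the file name and configuration keys below are hypothetical; only the `containers`/`providers` API comes from the package itself):

```python
from dependency_injector import containers, providers


class Container(containers.DeclarativeContainer):
    config = providers.Configuration()


container = Container()
container.config.from_ini("config.ini")  # hypothetical INI file
# container.config.from_dict({...}) works as well for in-memory configuration

hidden_size = container.config.model.hidden_size()  # hypothetical key
```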
diff --git a/example.conllu b/example.conllu
new file mode 100644
index 0000000..32e0653
--- /dev/null
+++ b/example.conllu
@@ -0,0 +1,13 @@
+# sent_id = test-s1
+# text = Easy sentence.
+1	Verylongwordwhichmustbetruncatedbythesystemto30	easy	ADJ	adj	AdpType=Prep|Adp	2	amod	_	_
+2	Sentence	verylonglemmawhichmustbetruncatedbythesystemto30	NOUN	nom	Number=Sing	0	root	_	_
+3	.	.	PUNCT	.	_	1	punct	_	_
+
+# sent_id = test-s1
+# text = Easy sentence.
+1	Verylongwordwhichmustbetruncatedbythesystemto30	easy	ADJ	adj	AdpType=Prep|Adp	2	amod	_	_
+2	Sentence	verylonglemmawhichmustbetruncatedbythesystemto30	NOUN	nom	Number=Sing	0	root	_	_
+3	.	.	PUNCT	.	_	1	punct	2:mod	_
+4	.	.	PUNCT	.	_	1	punct	2:xmod	_
+
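For reference, a small sketch of reading this fixture with the `conllu` package pinned in requirements.txt (the path and the printed fields are illustrative):

```python
from conllu import parse

with open("example.conllu", encoding="utf-8") as f:
    sentences = parse(f.read())

print(len(sentences))                    # 2 sentences in the fixture
print(sentences[0].metadata["sent_id"])  # test-s1
print(sentences[0][0]["form"])           # the deliberately over-long first token
```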
diff --git a/main.py b/main.py
new file mode 100644
index 0000000..5596b44
--- /dev/null
+++ b/main.py
@@ -0,0 +1,16 @@
+# This is a sample Python script.
+
+# Press Shift+F10 to execute it or replace it with your code.
+# Press Double Shift to search everywhere for classes, files, tool windows, actions, and settings.
+
+
+def print_hi(name):
+    # Use a breakpoint in the code line below to debug your script.
+    print(f'Hi, {name}')  # Press Ctrl+F8 to toggle the breakpoint.
+
+
+# Press the green button in the gutter to run the script.
+if __name__ == '__main__':
+    print_hi('PyCharm')
+
+# See PyCharm help at https://www.jetbrains.com/help/pycharm/
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..61f0590
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,7 @@
+conllu~=4.4.1
+dependency-injector~=4.41.0
+overrides~=7.3.1
+torch~=1.13.1
+torchtext~=0.14.1
+numpy~=1.24.1
+pytorch-lightning~=1.9.0
\ No newline at end of file
-- 
GitLab