diff --git a/combo/combo_model.py b/combo/combo_model.py
index b87eb1c0719fdd1e8fc4e1e4695559deb81644cb..d87d514adeddb223bec672e43d664d347f97cf82 100644
--- a/combo/combo_model.py
+++ b/combo/combo_model.py
@@ -1,5 +1,5 @@
 """Main COMBO model."""
-from typing import Optional, Dict, Any, List
+from typing import Optional, Dict, Any, List, Union
 
 import numpy
 import torch
@@ -24,6 +24,7 @@ from combo.nn.utils import get_text_field_mask
 from combo.predictors import Predictor
 from combo.utils import metrics
 from combo.utils import ConfigurationError
+from combo.modules.seq2seq_encoders.transformer_encoder import TransformerEncoder
 
 
 @Registry.register("semantic_multitask")
@@ -39,7 +40,7 @@ class ComboModel(Model, FromParameters):
                  vocabulary: data.Vocabulary,
                  loss_weights: Dict[str, float],
                  text_field_embedder: TextFieldEmbedder,
-                 seq_encoder: Seq2SeqEncoder,
+                 seq_encoder: Union[Seq2SeqEncoder, TransformerEncoder],
                  use_sample_weight: bool = True,
                  lemmatizer: LemmatizerModel = None,
                  upos_tagger: MorphologicalFeatures = None,
diff --git a/combo/config.template.json b/combo/config.template.json
index 7baf6547e1e49c0d0f3d0e441546cd7893329506..93332c54d3cbc0c0cba97084e7853556258e1495 100644
--- a/combo/config.template.json
+++ b/combo/config.template.json
@@ -252,7 +252,6 @@
       },
       "batch_size": 1,
       "shuffle": true,
-      "batches_per_epoch": 64,
       "quiet": false
     }
   },
diff --git a/combo/config.template.transformer.json b/combo/config.template.transformer.json
new file mode 100644
index 0000000000000000000000000000000000000000..c6bf59762fe616134895398a2b6ef8f1c7424a78
--- /dev/null
+++ b/combo/config.template.transformer.json
@@ -0,0 +1,304 @@
+{
+  "model": {
+    "type": "semantic_multitask",
+    "parameters": {
+      "dependency_relation": {
+        "type": "combo_dependency_parsing_from_vocab",
+        "parameters": {
+          "dependency_projection_layer": {
+            "type": "linear_layer",
+            "parameters": {
+              "activation": { "type": "tanh", "parameters": {} },
+              "dropout_rate": 0.25,
+              "in_features": 164,
+              "out_features": 128
+            }
+          },
+          "head_predictor": {
+            "type": "head_prediction",
+            "parameters": {
+              "cycle_loss_n": 0,
+              "dependency_projection_layer": {
+                "type": "linear_layer",
+                "parameters": {
+                  "activation": { "type": "tanh", "parameters": {} },
+                  "in_features": 164,
+                  "out_features": 512
+                }
+              },
+              "head_projection_layer": {
+                "type": "linear_layer",
+                "parameters": {
+                  "activation": { "type": "tanh", "parameters": {} },
+                  "in_features": 164,
+                  "out_features": 512
+                }
+              }
+            }
+          },
+          "head_projection_layer": {
+            "type": "linear_layer",
+            "parameters": {
+              "activation": { "type": "tanh", "parameters": {} },
+              "dropout_rate": 0.25,
+              "in_features": 164,
+              "out_features": 128
+            }
+          },
+          "vocab_namespace": "deprel_labels"
+        }
+      },
+      "lemmatizer": {
+        "type": "combo_lemma_predictor_from_vocab",
+        "parameters": {
+          "activations": [
+            { "type": "gelu", "parameters": {} },
+            { "type": "gelu", "parameters": {} },
+            { "type": "gelu", "parameters": {} },
+            { "type": "linear", "parameters": {} }
+          ],
+          "char_vocab_namespace": "token_characters",
+          "dilation": [1, 2, 4, 1],
+          "embedding_dim": 300,
+          "filters": [256, 256, 256],
+          "input_projection_layer": {
+            "type": "linear_layer",
+            "parameters": {
+              "activation": { "type": "tanh", "parameters": {} },
+              "dropout_rate": 0.25,
+              "in_features": 164,
+              "out_features": 32
+            }
+          },
+          "kernel_size": [3, 3, 3, 1],
+          "lemma_vocab_namespace": "lemma_characters",
+          "padding": [1, 2, 4, 0],
+          "stride": [1, 1, 1, 1]
+        }
+      },
+      "loss_weights": {
+        "deprel": 0.8,
+        "feats": 0.2,
+        "head": 0.2,
+        "lemma": 0.05,
+        "semrel": 0.05,
+        "upostag": 0.05,
+        "xpostag": 0.05
+      },
+      "morphological_feat": {
+        "type": "combo_morpho_from_vocab",
+        "parameters": {
+          "activations": [
+            { "type": "tanh", "parameters": {} },
+            { "type": "linear", "parameters": {} }
+          ],
+          "dropout": [0.25, 0.0],
+          "hidden_dims": [128],
+          "input_dim": 164,
+          "num_layers": 2,
+          "vocab_namespace": "feats_labels"
+        }
+      },
+      "regularizer": {
+        "type": "base_regularizer",
+        "parameters": {
+          "regexes": [
+            [
+              ".*conv1d.*",
+              { "type": "l2_regularizer", "parameters": { "alpha": 1e-6 } }
+            ],
+            [
+              ".*forward.*",
+              { "type": "l2_regularizer", "parameters": { "alpha": 1e-6 } }
+            ],
+            [
+              ".*backward.*",
+              { "type": "l2_regularizer", "parameters": { "alpha": 1e-6 } }
+            ],
+            [
+              ".*char_embed.*",
+              { "type": "l2_regularizer", "parameters": { "alpha": 1e-5 } }
+            ]
+          ]
+        }
+      },
+      "seq_encoder": {
+        "type": "combo_transformer_encoder",
+        "parameters": {
+          "layer_dropout_probability": 0.33,
+          "input_dim": 164,
+          "num_layers": 2,
+          "feedforward_hidden_dim": 2048,
+          "num_attention_heads": 4,
+          "positional_encoding": null,
+          "positional_embedding_size": 512,
+          "dropout_prob": 0.1,
+          "activation": "relu"
+        }
+      },
+      "text_field_embedder": {
+        "type": "base_text_field_embedder",
+        "parameters": {
+          "token_embedders": {
+            "char": {
+              "type": "char_embeddings_token_embedder",
+              "parameters": {
+                "dilated_cnn_encoder": {
+                  "type": "dilated_cnn",
+                  "parameters": {
+                    "activations": [
+                      { "type": "gelu", "parameters": {} },
+                      { "type": "gelu", "parameters": {} },
+                      { "type": "linear", "parameters": {} }
+                    ],
+                    "dilation": [1, 2, 4],
+                    "filters": [512, 256, 64],
+                    "input_dim": 64,
+                    "kernel_size": [3, 3, 3],
+                    "padding": [1, 2, 4],
+                    "stride": [1, 1, 1]
+                  }
+                },
+                "embedding_dim": 64
+              }
+            },
+            "token": {
+              "type": "transformers_word_embedder",
+              "parameters": { "projection_dim": 100 }
+            }
+          }
+        }
+      },
+      "upos_tagger": {
+        "type": "feedforward_predictor_from_vocab",
+        "parameters": {
+          "vocab_namespace": "upostag_labels",
+          "input_dim": 164,
+          "num_layers": 2,
+          "hidden_dims": [64],
+          "activations": [
+            { "type": "tanh", "parameters": {} },
+            { "type": "linear", "parameters": {} }
+          ],
+          "dropout": [0.25, 0.0]
+        }
+      },
+      "xpos_tagger": {
+        "type": "feedforward_predictor_from_vocab",
+        "parameters": {
+          "vocab_namespace": "xpostag_labels",
+          "input_dim": 164,
+          "num_layers": 2,
+          "hidden_dims": [64],
+          "activations": [
+            { "type": "tanh", "parameters": {} },
+            { "type": "linear", "parameters": {} }
+          ],
+          "dropout": [0.25, 0.0]
+        }
+      }
+    }
+  },
+  "data_loader": {
+    "type": "simple_data_loader_from_dataset_reader",
+    "parameters": {
+      "reader": {
+        "type": "conllu_dataset_reader",
+        "parameters": {
+          "features": ["token", "char"],
+          "tokenizer": {
+            "type": "lambo_tokenizer"
+          },
+          "lemma_indexers": {
+            "char": {
+              "type": "characters_const_padding_token_indexer",
+              "parameters": {
+                "tokenizer": {
+                  "type": "character_tokenizer",
+                  "parameters": {
+                    "end_tokens": ["__END__"],
+                    "start_tokens": ["__START__"]
+                  }
+                },
+                "min_padding_length": 32,
+                "namespace": "lemma_characters"
+              }
+            }
+          },
+          "targets": ["deprel", "feats", "head", "lemma", "upostag", "xpostag"],
+          "token_indexers": {
+            "char": {
+              "type": "characters_const_padding_token_indexer",
+              "parameters": {
+                "tokenizer": {
+                  "type": "character_tokenizer",
+                  "parameters": {
+                    "end_tokens": ["__END__"],
+                    "start_tokens": ["__START__"]
+                  }
+                },
+                "min_padding_length": 32
+              }
+            },
+            "token": {
+              "type": "pretrained_transformer_mismatched_fixed_token_indexer",
+              "parameters": { "model_name": "allegro/herbert-base-cased" }
+            }
+          },
+          "use_sem": false
+        }
+      },
+      "batch_size": 1,
+      "shuffle": true,
+      "quiet": false
+    }
+  },
+  "dataset_reader": {
+    "type": "conllu_dataset_reader",
+    "parameters": {
+      "features": ["token", "char"],
+      "tokenizer": {
+        "type": "lambo_tokenizer"
+      },
+      "lemma_indexers": {
+        "char": {
+          "type": "characters_const_padding_token_indexer",
+          "parameters": {
+            "tokenizer": {
+              "type": "character_tokenizer",
+              "parameters": {
+                "end_tokens": ["__END__"],
+                "start_tokens": ["__START__"]
+              }
+            },
+            "min_padding_length": 32,
+            "namespace": "lemma_characters"
+          }
+        }
+      },
+      "targets": ["deprel", "feats", "head", "lemma", "upostag", "xpostag"],
+      "token_indexers": {
+        "char": {
+          "type": "characters_const_padding_token_indexer",
+          "parameters": {
+            "tokenizer": {
+              "type": "character_tokenizer",
+              "parameters": {
+                "end_tokens": ["__END__"],
+                "start_tokens": ["__START__"]
+              }
+            },
+            "min_padding_length": 32
+          }
+        },
+        "token": {
+          "type": "pretrained_transformer_mismatched_fixed_token_indexer",
+          "parameters": { "model_name": "allegro/herbert-base-cased" }
+        }
+      },
+      "use_sem": false
+    }
+  },
+  "training": {},
+  "model_name": "allegro/herbert-base-cased"
+}
diff --git a/combo/config/from_parameters.py b/combo/config/from_parameters.py
index 743c80ac18587ed741651388e5b0dff6815d7be6..8d624073f83449b68afdc078a65f1e9934b27b7c 100644
--- a/combo/config/from_parameters.py
+++ b/combo/config/from_parameters.py
@@ -58,6 +58,8 @@ def serialize_single_value(value: Any, pass_down_parameter_names: List[str] = No
         return {k: serialize_single_value(v, pass_down_parameter_names) for k, v in value.items()}
     elif isinstance(value, int) or isinstance(value, float) or isinstance(value, str):
         return value
+    elif value is None:
+        return None
     else:
         return str(value)
 
diff --git a/combo/data/tokenizers/lambo_tokenizer.py b/combo/data/tokenizers/lambo_tokenizer.py
index e0beb838ec2ec31779a53a8917d3bb8be4d8d882..abb4e33ad9275bb71f72c584e994935b1e6e288d 100644
--- a/combo/data/tokenizers/lambo_tokenizer.py
+++ b/combo/data/tokenizers/lambo_tokenizer.py
@@ -28,7 +28,7 @@ def _sentence_tokens(token: Token,
                      split_subwords: Optional[bool] = None) -> List[Token]:
     if split_subwords and len(token.subwords) > 0:
         subword_idxs = [next(_token_idx()) for _ in range(len(token.subwords))]
-        multiword = (token.text, (subword_idxs[0], subword_idxs[1]))
+        multiword = (token.text, (subword_idxs[0], subword_idxs[-1]))
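+        # CoNLL-U multiword token range (first_subword_id, last_subword_id); using
+        # [-1] instead of [1] also covers tokens split into more than two subwords.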
         tokens = [Token(idx=s_idx, text=subword, multiword=multiword) for (s_idx, subword)
                   in zip(subword_idxs, token.subwords)]
         return tokens
@@ -74,19 +74,31 @@ class LamboTokenizer(Tokenizer):
             for turn in document.turns:
                 sentence_tokens = []
                 for sentence in turn.sentences:
+                    _reset_idx()  # token ids restart at each sentence
                     for token in sentence.tokens:
                         sentence_tokens.extend(_sentence_tokens(token, split_subwords))
                 tokens.append(sentence_tokens)
         elif split_level.upper() == "SENTENCE":
             for turn in document.turns:
                 for sentence in turn.sentences:
+                    _reset_idx()
                     sentence_tokens = []
                     for token in sentence.tokens:
+                        if len(token.subwords) > 0 and split_subwords:
+                            # @TODO workaround for a LAMBO shortcoming: for longer words split into
+                            # multiwords, it can drop the final letters of the last subword. If the
+                            # subwords do not concatenate back to token.text, re-align them.
+                            if "".join(token.subwords) != token.text:
+                                token.subwords = fix_subwords(token)
                         sentence_tokens.extend(_sentence_tokens(token, split_subwords))
                     tokens.append(sentence_tokens)
         else:
             for turn in document.turns:
                 for sentence in turn.sentences:
+                    _reset_idx()
                     for token in sentence.tokens:
                         tokens.extend(_sentence_tokens(token, split_subwords))
             tokens = [tokens]
@@ -116,16 +128,40 @@ class LamboTokenizer(Tokenizer):
             if turns:
                 sentence_tokens = []
             for sentence in turn.sentences:
+                _reset_idx()
                 if not turns:
                     sentence_tokens = []
                 for token in sentence.tokens:
                     if len(token.subwords) > 0 and split_subwords:
-                        sentence_tokens.extend([s for s in token.subwords])
-                    else:
-                        sentence_tokens.append(token.text)
+                        # @TODO workaround for a LAMBO shortcoming: for longer words split into
+                        # multiwords, it can drop the final letters of the last subword. If the
+                        # subwords do not concatenate back to token.text, re-align them.
+                        if "".join(token.subwords) != token.text:
+                            token.subwords = fix_subwords(token)
+                    sentence_tokens.extend(_sentence_tokens(token, split_subwords))
                 if not turns:
                     sentences.append(sentence_tokens)
             if turns:
                 sentences.append(sentence_tokens)
 
         return sentences
+
+
+def fix_subwords(token: Token) -> List[str]:
+    """Re-align token.subwords with token.text.
+
+    Each subword is replaced by the slice of the original text it should cover,
+    and the last subword absorbs any trailing characters that LAMBO dropped.
+    """
+    fixed_subwords = []
+    text_it = 0
+    for i, subword in enumerate(token.subwords):
+        # Take the slice of the original text that this subword should cover.
+        fixed = token.text[text_it:text_it + len(subword)]
+        # The last subword absorbs any characters LAMBO dropped from the end.
+        if i == len(token.subwords) - 1 and text_it + len(fixed) < len(token.text):
+            fixed = token.text[text_it:]
+        fixed_subwords.append(fixed)
+        text_it += len(fixed)
+    return fixed_subwords
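+
+
+# Example with hypothetical LAMBO output: for token.text == "przykładowego" split
+# into subwords ["przykład", "owe"] (the trailing "go" dropped by the model),
+# fix_subwords returns ["przykład", "owego"], which concatenates back to the text.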
diff --git a/combo/default_model.py b/combo/default_model.py
index d074e6298ed856c748e63ed43d24ec4267a553e7..a7818459e512e38d39f9c77a8d0cd5702cd93070 100644
--- a/combo/default_model.py
+++ b/combo/default_model.py
@@ -12,14 +12,14 @@ from combo.data.token_indexers import TokenConstPaddingCharactersIndexer, \
 from combo.data.tokenizers import CharacterTokenizer
 from combo.data.vocabulary import Vocabulary
 from combo.combo_model import ComboModel
-from combo.models.encoder import ComboEncoder, ComboStackedBidirectionalLSTM
+from combo.models.encoder import ComboEncoder, ComboStackedBidirectionalLSTM, ComboTransformerEncoder
 from combo.models.dilated_cnn import DilatedCnnEncoder
 from combo.modules.lemma import LemmatizerModel
 from combo.modules.morpho import MorphologicalFeatures
 from combo.modules.parser import DependencyRelationModel, HeadPredictionModel
 from combo.modules.text_field_embedders import BasicTextFieldEmbedder
 from combo.modules.token_embedders import CharacterBasedWordEmbedder, TransformersWordEmbedder
-from combo.nn.activations import ReLUActivation, TanhActivation, LinearActivation
+from combo.nn.activations import ReLUActivation, TanhActivation, LinearActivation, GELUActivation
 from combo.modules import FeedForwardPredictor
 from combo.nn.base import Linear
 from combo.nn.regularizers.regularizers import L2Regularizer
@@ -55,33 +55,61 @@ def default_ud_dataset_reader(pretrained_transformer_name: str,
         targets=["deprel", "head", "upostag", "lemma", "feats", "xpostag"],
         token_indexers={
             "char": default_character_indexer(),
-            "feats": TokenFeatsIndexer(),
-            "lemma": default_character_indexer(),
+            # "feats": TokenFeatsIndexer(),
+            # "lemma": default_character_indexer(),
             "token": PretrainedTransformerFixedMismatchedIndexer(pretrained_transformer_name),
-            "upostag": SingleIdTokenIndexer(
-                feature_name="pos_",
-                namespace="upostag"
-            ),
-            "xpostag": SingleIdTokenIndexer(
-                feature_name="tag_",
-                namespace="xpostag"
-            )
+            # "upostag": SingleIdTokenIndexer(
+            #     feature_name="pos_",
+            #     namespace="upostag"
+            # ),
+            # "xpostag": SingleIdTokenIndexer(
+            #     feature_name="tag_",
+            #     namespace="xpostag"
+            # )
         },
         use_sem=False,
         tokenizer=tokenizer
     )
 
 
-def default_data_loader(dataset_reader: DatasetReader,
+def default_data_loader(dataset_reader: DatasetReader,
                         file_path: str,
-                        batch_size: int = 16,
-                        batches_per_epoch: int = 4) -> SimpleDataLoader:
-    return SimpleDataLoader.from_dataset_reader(dataset_reader,
-                                                data_path=file_path,
-                                                batch_size=batch_size,
-                                                batches_per_epoch=batches_per_epoch,
-                                                shuffle=True,
-                                                collate_fn=lambda instances: Batch(instances).as_tensor_dict())
+                        batch_size: int = 1,
+                        batches_per_epoch: int = 64) -> SimpleDataLoader:
+    return SimpleDataLoader.from_dataset_reader(
+        dataset_reader,
+        data_path=file_path,
+        batch_size=batch_size,
+        batches_per_epoch=batches_per_epoch,
+        shuffle=True,
+        quiet=False,
+        collate_fn=lambda instances: Batch(instances).as_tensor_dict())
 
 
 def default_vocabulary(data_loader: DataLoader) -> Vocabulary:
@@ -94,124 +122,247 @@ def default_vocabulary(data_loader: DataLoader) -> Vocabulary:
     )
 
 
-def default_model(pretrained_transformer_name: str, vocabulary: Vocabulary) -> ComboModel:
-    return ComboModel(
-        vocabulary=vocabulary,
-        dependency_relation=DependencyRelationModel(
+def default_model(pretrained_transformer_name: str, vocabulary: Vocabulary, use_transformer_encoder: bool = False) -> ComboModel:
+    if use_transformer_encoder:
+        return ComboModel(
             vocabulary=vocabulary,
-            dependency_projection_layer=Linear(
-                activation=TanhActivation(),
-                dropout_rate=0.25,
-                in_features=1024,
-                out_features=128
+            dependency_relation=DependencyRelationModel(
+                vocabulary=vocabulary,
+                dependency_projection_layer=Linear(
+                    activation=TanhActivation(),
+                    dropout_rate=0.25,
+                    in_features=164,
+                    out_features=128
+                ),
+                head_predictor=HeadPredictionModel(
+                    cycle_loss_n=0,
+                    dependency_projection_layer=Linear(
+                        activation=TanhActivation(),
+                        in_features=164,
+                        out_features=512
+                    ),
+                    head_projection_layer=Linear(
+                        activation=TanhActivation(),
+                        in_features=164,
+                        out_features=512
+                    )
+                ),
+                head_projection_layer=Linear(
+                    activation=TanhActivation(),
+                    dropout_rate=0.25,
+                    in_features=164,
+                    out_features=128
+                ),
+                vocab_namespace="deprel_labels"
             ),
-            head_predictor=HeadPredictionModel(
-                cycle_loss_n=0,
+            lemmatizer=LemmatizerModel(
+                vocabulary=vocabulary,
+                activations=[GELUActivation(), GELUActivation(), GELUActivation(), LinearActivation()],
+                char_vocab_namespace="token_characters",
+                dilation=[1, 2, 4, 1],
+                embedding_dim=300,
+                filters=[256, 256, 256],
+                input_projection_layer=Linear(
+                    activation=TanhActivation(),
+                    dropout_rate=0.25,
+                    in_features=164,
+                    out_features=32
+                ),
+                kernel_size=[3, 3, 3, 1],
+                lemma_vocab_namespace="lemma_characters",
+                padding=[1, 2, 4, 0],
+                stride=[1, 1, 1, 1]
+            ),
+            loss_weights={
+                "deprel": 0.8,
+                "feats": 0.2,
+                "head": 0.2,
+                "lemma": 0.05,
+                "semrel": 0.05,
+                "upostag": 0.05,
+                "xpostag": 0.05
+            },
+            morphological_feat=MorphologicalFeatures(
+                vocabulary=vocabulary,
+                activations=[TanhActivation(), LinearActivation()],
+                dropout=[0.25, 0.],
+                hidden_dims=[128],
+                input_dim=164,
+                num_layers=2,
+                vocab_namespace="feats_labels"
+            ),
+            regularizer=RegularizerApplicator([
+                (".*conv1d.*", L2Regularizer(1e-6)),
+                (".*forward.*", L2Regularizer(1e-6)),
+                (".*backward.*", L2Regularizer(1e-6)),
+                (".*char_embed.*", L2Regularizer(1e-5))
+            ]),
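+            # The transformer encoder preserves its input width (output dim equals
+            # input dim), so the prediction heads take in_features=164:
+            # 64 (char-CNN word embedding) + 100 (projected transformer embedding).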
+            seq_encoder=ComboTransformerEncoder(
+                layer_dropout_probability=0.33,
+                input_dim=164,
+                num_layers=2,
+                feedforward_hidden_dim=2048,
+                num_attention_heads=4,
+                positional_encoding=None,
+                positional_embedding_size=512,
+                dropout_prob=0.1,
+                activation="relu"
+            ),
+            text_field_embedder=BasicTextFieldEmbedder(
+                token_embedders={
+                    "char": CharacterBasedWordEmbedder(
+                        vocabulary=vocabulary,
+                        dilated_cnn_encoder=DilatedCnnEncoder(
+                            activations=[GELUActivation(), GELUActivation(), LinearActivation()],
+                            dilation=[1, 2, 4],
+                            filters=[512, 256, 64],
+                            input_dim=64,
+                            kernel_size=[3, 3, 3],
+                            padding=[1, 2, 4],
+                            stride=[1, 1, 1],
+                        ),
+                        embedding_dim=64
+                    ),
+                    "token": TransformersWordEmbedder(pretrained_transformer_name, projection_dim=100)
+                }
+            ),
+            upos_tagger=FeedForwardPredictor.from_vocab(
+                vocabulary=vocabulary,
+                activations=[TanhActivation(), LinearActivation()],
+                dropout=[0.25, 0.],
+                hidden_dims=[64],
+                input_dim=164,
+                num_layers=2,
+                vocab_namespace="upostag_labels"
+            ),
+            xpos_tagger=FeedForwardPredictor.from_vocab(
+                vocabulary=vocabulary,
+                activations=[TanhActivation(), LinearActivation()],
+                dropout=[0.25, 0.],
+                hidden_dims=[64],
+                input_dim=164,
+                num_layers=2,
+                vocab_namespace="xpostag_labels"
+            )
+        )
+    else:
+        return ComboModel(
+            vocabulary=vocabulary,
+            dependency_relation=DependencyRelationModel(
+                vocabulary=vocabulary,
                 dependency_projection_layer=Linear(
                     activation=TanhActivation(),
+                    dropout_rate=0.25,
                     in_features=1024,
-                    out_features=512
+                    out_features=128
+                ),
+                head_predictor=HeadPredictionModel(
+                    cycle_loss_n=0,
+                    dependency_projection_layer=Linear(
+                        activation=TanhActivation(),
+                        in_features=1024,
+                        out_features=512
+                    ),
+                    head_projection_layer=Linear(
+                        activation=TanhActivation(),
+                        in_features=1024,
+                        out_features=512
+                    )
                 ),
                 head_projection_layer=Linear(
                     activation=TanhActivation(),
+                    dropout_rate=0.25,
                     in_features=1024,
-                    out_features=512
-                )
+                    out_features=128
+                ),
+                vocab_namespace="deprel_labels"
             ),
-            head_projection_layer=Linear(
-                activation=TanhActivation(),
-                dropout_rate=0.25,
-                in_features=1024,
-                out_features=128
+            lemmatizer=LemmatizerModel(
+                vocabulary=vocabulary,
+                activations=[GELUActivation(), GELUActivation(), GELUActivation(), LinearActivation()],
+                char_vocab_namespace="token_characters",
+                dilation=[1, 2, 4, 1],
+                embedding_dim=300,
+                filters=[256, 256, 256],
+                input_projection_layer=Linear(
+                    activation=TanhActivation(),
+                    dropout_rate=0.25,
+                    in_features=1024,
+                    out_features=32
+                ),
+                kernel_size=[3, 3, 3, 1],
+                lemma_vocab_namespace="lemma_characters",
+                padding=[1, 2, 4, 0],
+                stride=[1, 1, 1, 1]
             ),
-            vocab_namespace="deprel_labels"
-        ),
-        lemmatizer=LemmatizerModel(
-            vocabulary=vocabulary,
-            activations=[ReLUActivation(), ReLUActivation(), ReLUActivation(), LinearActivation()],
-            char_vocab_namespace="token_characters",
-            dilation=[1, 2, 4, 1],
-            embedding_dim=256,
-            filters=[256, 256, 256],
-            input_projection_layer=Linear(
-                activation=TanhActivation(),
-                dropout_rate=0.25,
-                in_features=1024,
-                out_features=32
+            loss_weights={
+                "deprel": 0.8,
+                "feats": 0.2,
+                "head": 0.2,
+                "lemma": 0.05,
+                "semrel": 0.05,
+                "upostag": 0.05,
+                "xpostag": 0.05
+            },
+            morphological_feat=MorphologicalFeatures(
+                vocabulary=vocabulary,
+                activations=[TanhActivation(), LinearActivation()],
+                dropout=[0.25, 0.],
+                hidden_dims=[128],
+                input_dim=1024,
+                num_layers=2,
+                vocab_namespace="feats_labels"
             ),
-            kernel_size=[3, 3, 3, 1],
-            lemma_vocab_namespace="lemma_characters",
-            padding=[1, 2, 4, 0],
-            stride=[1, 1, 1, 1]
-        ),
-        loss_weights={
-            "deprel": 0.8,
-            "feats": 0.2,
-            "head": 0.2,
-            "lemma": 0.05,
-            "semrel": 0.05,
-            "upostag": 0.05,
-            "xpostag": 0.05
-        },
-        morphological_feat=MorphologicalFeatures(
-            vocabulary=vocabulary,
-            activations=[TanhActivation(), LinearActivation()],
-            dropout=[0.25, 0.],
-            hidden_dims=[128],
-            input_dim=1024,
-            num_layers=2,
-            vocab_namespace="feats_labels"
-        ),
-        regularizer=RegularizerApplicator([
-            (".*conv1d.*", L2Regularizer(1e-6)),
-            (".*forward.*", L2Regularizer(1e-6)),
-            (".*backward.*", L2Regularizer(1e-6)),
-            (".*char_embed.*", L2Regularizer(1e-5))
-        ]),
-        seq_encoder=ComboEncoder(
-            layer_dropout_probability=0.33,
-            stacked_bilstm=ComboStackedBidirectionalLSTM(
-                hidden_size=512,
-                input_size=164,
+            regularizer=RegularizerApplicator([
+                (".*conv1d.*", L2Regularizer(1e-6)),
+                (".*forward.*", L2Regularizer(1e-6)),
+                (".*backward.*", L2Regularizer(1e-6)),
+                (".*char_embed.*", L2Regularizer(1e-5))
+            ]),
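+            # The BiLSTM encoder outputs 2 * 512 = 1024 features (bidirectional with
+            # hidden_size=512), hence in_features=1024 in the prediction heads.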
+            seq_encoder=ComboEncoder(
                 layer_dropout_probability=0.33,
+                stacked_bilstm=ComboStackedBidirectionalLSTM(
+                    hidden_size=512,
+                    input_size=164,
+                    layer_dropout_probability=0.33,
+                    num_layers=2,
+                    recurrent_dropout_probability=0.33
+                )
+            ),
+            text_field_embedder=BasicTextFieldEmbedder(
+                token_embedders={
+                    "char": CharacterBasedWordEmbedder(
+                        vocabulary=vocabulary,
+                        dilated_cnn_encoder=DilatedCnnEncoder(
+                            activations=[GELUActivation(), GELUActivation(), LinearActivation()],
+                            dilation=[1, 2, 4],
+                            filters=[512, 256, 64],
+                            input_dim=64,
+                            kernel_size=[3, 3, 3],
+                            padding=[1, 2, 4],
+                            stride=[1, 1, 1],
+                        ),
+                        embedding_dim=64
+                    ),
+                    "token": TransformersWordEmbedder(pretrained_transformer_name, projection_dim=100)
+                }
+            ),
+            upos_tagger=FeedForwardPredictor.from_vocab(
+                vocabulary=vocabulary,
+                activations=[TanhActivation(), LinearActivation()],
+                dropout=[0.25, 0.],
+                hidden_dims=[64],
+                input_dim=1024,
+                num_layers=2,
+                vocab_namespace="upostag_labels"
+            ),
+            xpos_tagger=FeedForwardPredictor.from_vocab(
+                vocabulary=vocabulary,
+                activations=[TanhActivation(), LinearActivation()],
+                dropout=[0.25, 0.],
+                hidden_dims=[64],
+                input_dim=1024,
                 num_layers=2,
-                recurrent_dropout_probability=0.33
+                vocab_namespace="xpostag_labels"
             )
-        ),
-        text_field_embedder=BasicTextFieldEmbedder(
-            token_embedders={
-                "char": CharacterBasedWordEmbedder(
-                    vocabulary=vocabulary,
-                    dilated_cnn_encoder=DilatedCnnEncoder(
-                        activations=[ReLUActivation(), ReLUActivation(), LinearActivation()],
-                        dilation=[1, 2, 4],
-                        filters=[512, 256, 64],
-                        input_dim=64,
-                        kernel_size=[3, 3, 3],
-                        padding=[1, 2, 4],
-                        stride=[1, 1, 1],
-                    ),
-                    embedding_dim=64
-                ),
-                "token": TransformersWordEmbedder(pretrained_transformer_name, projection_dim=100)
-            }
-        ),
-        upos_tagger=FeedForwardPredictor.from_vocab(
-            vocabulary=vocabulary,
-            activations=[TanhActivation(), LinearActivation()],
-            dropout=[0.25, 0.],
-            hidden_dims=[64],
-            input_dim=1024,
-            num_layers=2,
-            vocab_namespace="upostag_labels"
-        ),
-        xpos_tagger=FeedForwardPredictor.from_vocab(
-            vocabulary=vocabulary,
-            activations=[TanhActivation(), LinearActivation()],
-            dropout=[0.25, 0.],
-            hidden_dims=[64],
-            input_dim=1024,
-            num_layers=2,
-            vocab_namespace="xpostag_labels"
         )
-    )
diff --git a/combo/main.py b/combo/main.py
index 84571612e87109e007274c8d264976210e57daae..79e98208ce36c59df835363480dfab6e402e0f10 100755
--- a/combo/main.py
+++ b/combo/main.py
@@ -63,6 +63,8 @@ flags.DEFINE_integer(name="batch_size", default=256,
                      help="Batch size")
 flags.DEFINE_integer(name="batches_per_epoch", default=16,
                      help="Number of batches per epoch")
+flags.DEFINE_integer(name="validation_batches_per_epoch", default=4,
+                     help="Number of batches per epoch")
 flags.DEFINE_string(name="pretrained_transformer_name", default="bert-base-cased",
                     help="Pretrained transformer model name (see transformers from HuggingFace library for list of "
                          "available models) for transformers based embeddings.")
@@ -84,6 +86,7 @@ flags.DEFINE_boolean(name="turns", default=False,
                      help="Segment into sentences on sentence break or on turn break.")
 flags.DEFINE_boolean(name="split_subwords", default=False,
                      help="Split subwords (e.g. don\'t = do, n\'t) into separate tokens.")
+flags.DEFINE_boolean(name="transformer_encoder", default=False, help="Use transformer encoder.")
 
 # Finetune after training flags
 flags.DEFINE_string(name="finetuning_training_data_path", default="",
@@ -156,12 +159,12 @@ def get_defaults(dataset_reader: Optional[DatasetReader],
         # Dataset reader is required to read training data and/or for training (and validation) data loader
         dataset_reader = default_ud_dataset_reader(FLAGS.pretrained_transformer_name,
                                                    tokenizer=LamboTokenizer(FLAGS.tokenizer_language,
-                                                                            default_turns=FLAGS.turns,
-                                                                            default_split_subwords=FLAGS.split_subwords)
+                                                                            default_split_level="TURNS" if FLAGS.turns else "SENTENCES",
+                                                                            default_split_subwords=FLAGS.split_subwords)
                                                    )
 
     if not training_data_loader:
-        training_data_loader = default_data_loader(dataset_reader, training_data_path)
+        training_data_loader = default_data_loader(dataset_reader, training_data_path, FLAGS.batch_size, FLAGS.batches_per_epoch)
     else:
         if training_data_path:
             training_data_loader.data_path = training_data_path
@@ -170,7 +173,7 @@ def get_defaults(dataset_reader: Optional[DatasetReader],
                            str(training_data_loader.data_path), prefix=prefix)
 
     if FLAGS.validation_data_path and not validation_data_loader:
-        validation_data_loader = default_data_loader(dataset_reader, validation_data_path)
+        validation_data_loader = default_data_loader(dataset_reader, validation_data_path, FLAGS.batch_size, FLAGS.validation_batches_per_epoch)
     elif FLAGS.validation_data_path and validation_data_loader:
         if validation_data_path:
             validation_data_loader.data_path = validation_data_path
@@ -312,7 +315,11 @@ def run(_):
                     FLAGS.validation_data_path,
                     prefix
                 )
-                model = default_model(FLAGS.pretrained_transformer_name, vocabulary)
+                model = default_model(FLAGS.pretrained_transformer_name, vocabulary,
+                                      use_transformer_encoder=FLAGS.transformer_encoder)
 
             if FLAGS.use_pure_config and model is None:
                 logger.error('Error in configuration - model could not be read from parameters. ' +
@@ -403,9 +410,9 @@ def run(_):
             logger.info("No dataset reader in the configuration or archive file - using a default UD dataset reader",
                         prefix=prefix)
             dataset_reader = default_ud_dataset_reader(FLAGS.pretrained_transformer_name,
-                                                       tokenizer=LamboTokenizer(tokenizer_language,
-                                                                                default_turns=FLAGS.turns,
-                                                                                default_split_subwords=FLAGS.split_subwords)
+                                                       tokenizer=LamboTokenizer(tokenizer_language,
+                                                                                default_split_level="TURNS" if FLAGS.turns else "SENTENCES",
+                                                                                default_split_subwords=FLAGS.split_subwords)
                                                        )
 
         predictor = COMBO(model, dataset_reader)
@@ -509,7 +516,7 @@ def _get_ext_vars(finetuning: bool = False) -> Dict:
             "parameters": {
                 "data_path": FLAGS.validation_data_path if not finetuning else FLAGS.finetuning_validation_data_path,
                 "batch_size": FLAGS.batch_size,
-                "batches_per_epoch": FLAGS.batches_per_epoch,
+                "batches_per_epoch": FLAGS.validation_batches_per_epoch,
                 "reader": {
                     "parameters": {
                         "features": FLAGS.features,
diff --git a/combo/models/__init__.py b/combo/models/__init__.py
index c2d36ae531e502632c5b33e4802fcc3419baa9dd..68d9046e3366ab8952e2c355f1d9f8721808b175 100644
--- a/combo/models/__init__.py
+++ b/combo/models/__init__.py
@@ -1 +1 @@
-from .encoder import ComboStackedBidirectionalLSTM, ComboEncoder
+from .encoder import ComboStackedBidirectionalLSTM, ComboEncoder, ComboTransformerEncoder
diff --git a/combo/models/encoder.py b/combo/models/encoder.py
index 8ed5aff127de187ff51a929ec5e42b345e62e498..c751836235bbfa09a8547d85fdc3e6862844d4d0 100644
--- a/combo/models/encoder.py
+++ b/combo/models/encoder.py
@@ -15,6 +15,7 @@ from combo.modules.augmented_lstm import AugmentedLstm
 from combo.modules.input_variational_dropout import InputVariationalDropout
 from combo.modules.seq2seq_encoders.seq2seq_encoder import Seq2SeqEncoder
 from combo.utils import ConfigurationError
+from combo.modules.seq2seq_encoders.transformer_encoder import TransformerEncoder
 
 TensorPair = Tuple[torch.Tensor, torch.Tensor]
 
@@ -247,3 +248,46 @@ class ComboEncoder(Seq2SeqEncoder, FromParameters):
         x = self.layer_dropout(inputs)
         x = super().forward(x, mask)
         return self.layer_dropout(x)
+
+
+@Registry.register('combo_transformer_encoder')
+class ComboTransformerEncoder(TransformerEncoder, FromParameters):
+    """COMBO encoder (https://www.aclweb.org/anthology/K18-2004.pdf).
+
+    This implementation uses Variational Dropout on the input and then outputs of each BiLSTM layer
+    (instead of used Gaussian Dropout and Gaussian Noise).
+    """
+
+    @register_arguments
+    def __init__(self,
+                 layer_dropout_probability: float,
+                 input_dim: int,
+                 num_layers: int,
+                 feedforward_hidden_dim: int = 2048,
+                 num_attention_heads: int = 4,
+                 positional_encoding: Optional[str] = None,
+                 positional_embedding_size: int = 512,
+                 dropout_prob: float = 0.1,
+                 activation: str = "relu"
+                ):
+        super().__init__(
+            input_dim,
+            num_layers,
+            feedforward_hidden_dim,
+            num_attention_heads,
+            positional_encoding,
+            positional_embedding_size,
+            dropout_prob,
+            activation
+        )
+
+        self.layer_dropout = InputVariationalDropout(p=layer_dropout_probability)
+
+    def forward(self,
+                inputs: torch.Tensor,
+                mask: torch.BoolTensor,
+                hidden_state: torch.Tensor = None) -> torch.Tensor:
+        x = self.layer_dropout(inputs)
+        x = super().forward(x, mask)
+        return self.layer_dropout(x)
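+
+
+# A minimal smoke test (hypothetical shapes; 164 matches the default COMBO configs):
+#   encoder = ComboTransformerEncoder(layer_dropout_probability=0.33,
+#                                     input_dim=164, num_layers=2)
+#   out = encoder(torch.zeros(2, 7, 164), torch.ones(2, 7, dtype=torch.bool))
+#   assert out.shape == (2, 7, 164)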
diff --git a/combo/modules/seq2seq_encoders/transformer_encoder.py b/combo/modules/seq2seq_encoders/transformer_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..a49100def51df028bd8427ea65e0ab8bd818905a
--- /dev/null
+++ b/combo/modules/seq2seq_encoders/transformer_encoder.py
@@ -0,0 +1,117 @@
+from typing import Optional
+
+from overrides import overrides
+import torch
+from torch import nn
+
+from combo.modules.encoder import _EncoderBase
+from combo.config.from_parameters import FromParameters, register_arguments
+
+from combo.nn.utils import add_positional_features
+
+
+class TransformerEncoder(_EncoderBase, FromParameters):
+    """
+    Implements a stacked self-attention encoder similar to the Transformer
+    architecture in [Attention is all you Need]
+    (https://www.semanticscholar.org/paper/Attention-Is-All-You-Need-Vaswani-Shazeer/0737da0767d77606169cbf4187b83e1ab62f6077).
+
+    This class (adapted from AllenNLP's "pytorch_transformer" Seq2SeqEncoder) wraps
+    torch.nn's TransformerEncoder for use in COMBO and can optionally add positional
+    encodings. In COMBO it is used through ComboTransformerEncoder, registered as
+    "combo_transformer_encoder".
+
+    # Parameters
+
+    input_dim : `int`, required.
+        The input dimension of the encoder.
+    num_layers : `int`, required.
+        The number of stacked self attention -> feedforward -> layer normalisation blocks.
+    feedforward_hidden_dim : `int`, optional, (default = `2048`)
+        The middle dimension of the FeedForward network. The input and output
+        dimensions are fixed to ensure sizes match up for the self attention layers.
+    num_attention_heads : `int`, optional, (default = `4`)
+        The number of attention heads to use per layer.
+    positional_encoding : `Optional[str]`, optional, (default = `None`)
+        One of `None`, `"sinusoidal"` or `"embedding"`. Adding positional information
+        is strongly recommended, as without it the self-attention layers have no
+        notion of absolute or relative position (they only compute pairwise
+        similarity between vectors of elements), which can be important for many tasks.
+    positional_embedding_size : `int`, optional, (default = `512`)
+        The number of learned positional embeddings; used only with `"embedding"`,
+        where it caps the maximum supported sequence length.
+    dropout_prob : `float`, optional, (default = `0.1`)
+        The dropout probability passed to each transformer layer.
+    """  # noqa
+
+    def __init__(
+        self,
+        input_dim: int,
+        num_layers: int,
+        feedforward_hidden_dim: int = 2048,
+        num_attention_heads: int = 4,
+        positional_encoding: Optional[str] = None,
+        positional_embedding_size: int = 512,
+        dropout_prob: float = 0.1,
+        activation: str = "relu",
+    ) -> None:
+        super().__init__()
+
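+        # NOTE: nn.TransformerEncoderLayer requires d_model (here input_dim) to be
+        # divisible by nhead; e.g. the default 164-dim input with 4 heads gives
+        # 41-dim attention heads.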
+        layer = nn.TransformerEncoderLayer(
+            d_model=input_dim,
+            nhead=num_attention_heads,
+            dim_feedforward=feedforward_hidden_dim,
+            dropout=dropout_prob,
+            activation=activation,
+        )
+        self._transformer = nn.TransformerEncoder(layer, num_layers)
+        self._input_dim = input_dim
+
+        # initialize parameters
+        # We do this before the embeddings are initialized so we get the default initialization for the embeddings.
+        for p in self.parameters():
+            if p.dim() > 1:
+                nn.init.xavier_uniform_(p)
+
+        if positional_encoding is None:
+            self._sinusoidal_positional_encoding = False
+            self._positional_embedding = None
+        elif positional_encoding == "sinusoidal":
+            self._sinusoidal_positional_encoding = True
+            self._positional_embedding = None
+        elif positional_encoding == "embedding":
+            self._sinusoidal_positional_encoding = False
+            self._positional_embedding = nn.Embedding(positional_embedding_size, input_dim)
+        else:
+            raise ValueError(
+                "positional_encoding must be one of None, 'sinusoidal', or 'embedding'"
+            )
+
+    def get_input_dim(self) -> int:
+        return self._input_dim
+
+    def get_output_dim(self) -> int:
+        return self._input_dim
+
+    def is_bidirectional(self):
+        return False
+
+    def forward(self, inputs: torch.Tensor, mask: torch.BoolTensor):
+        output = inputs
+        if self._sinusoidal_positional_encoding:
+            output = add_positional_features(output)
+        if self._positional_embedding is not None:
+            position_ids = torch.arange(inputs.size(1), dtype=torch.long, device=output.device)
+            position_ids = position_ids.unsqueeze(0).expand(inputs.shape[:-1])
+            output = output + self._positional_embedding(position_ids)
+
+        # torch's transformer (without batch_first) expects inputs of shape
+        # (sequence, batch, features) rather than the more familiar
+        # (batch, sequence, features), so permute before and after.
+        output = output.permute(1, 0, 2)
+        # torch's src_key_padding_mask marks *padded* positions with True, the
+        # opposite of the AllenNLP-style mask (True = real token), so invert it.
+        mask = ~mask
+        output = self._transformer(output, src_key_padding_mask=mask)
+        output = output.permute(1, 0, 2)
+
+        return output
\ No newline at end of file
diff --git a/combo/nn/activations.py b/combo/nn/activations.py
index f4e5d43e2a4d1ca70078d5f2bcea9f2a758ba8b7..60f64aa7304f7c701c120e1f9dcea3c52ee20cae 100644
--- a/combo/nn/activations.py
+++ b/combo/nn/activations.py
@@ -39,7 +39,7 @@ class ReLUActivation(Activation):
 
 
 @Registry.register('gelu')
-class ReLUActivation(Activation):
+class GELUActivation(Activation):
     def __init__(self):
         super().__init__()
         self.__torch_activation = torch.nn.GELU()
diff --git a/combo/nn/utils.py b/combo/nn/utils.py
index 4333be0a1fc5736c87a27efb18426eaee1c5d2e7..bdf4dc3c196d47975eed97c5bc57302441345bb8 100644
--- a/combo/nn/utils.py
+++ b/combo/nn/utils.py
@@ -2,6 +2,7 @@
 Adapted from AllenNLP
 https://github.com/allenai/allennlp/blob/80fb6061e568cb9d6ab5d45b661e86eb61b92c82/allennlp/nn/util.py
 """
+import math
 from typing import Union, Dict, Optional, List, Any, NamedTuple
 
 import torch
@@ -479,3 +480,61 @@ def batched_span_select(target: torch.Tensor, spans: torch.LongTensor) -> torch.
     span_embeddings = batched_index_select(target, span_indices)
 
     return span_embeddings, span_mask
+
+
+def add_positional_features(
+    tensor: torch.Tensor, min_timescale: float = 1.0, max_timescale: float = 1.0e4
+):
+    """
+    Implements the frequency-based positional encoding described
+    in [Attention is All you Need][0].
+
+    Adds sinusoids of different frequencies to a `Tensor`. A sinusoid of a
+    different frequency and phase is added to each dimension of the input `Tensor`.
+    This allows the attention heads to use absolute and relative positions.
+
+    The number of timescales is equal to hidden_dim / 2 within the range
+    (min_timescale, max_timescale). For each timescale, the two sinusoidal
+    signals sin(timestep / timescale) and cos(timestep / timescale) are
+    generated and concatenated along the hidden_dim dimension.
+
+    [0]: https://www.semanticscholar.org/paper/Attention-Is-All-You-Need-Vaswani-Shazeer/0737da0767d77606169cbf4187b83e1ab62f6077
+
+    # Parameters
+
+    tensor : `torch.Tensor`
+        a Tensor with shape (batch_size, timesteps, hidden_dim).
+    min_timescale : `float`, optional (default = `1.0`)
+        The smallest timescale to use.
+    max_timescale : `float`, optional (default = `1.0e4`)
+        The largest timescale to use.
+
+    # Returns
+
+    `torch.Tensor`
+        The input tensor augmented with the sinusoidal frequencies.
+    """  # noqa
+    _, timesteps, hidden_dim = tensor.size()
+
+    timestep_range = get_range_vector(timesteps, get_device_of(tensor)).data.float()
+    # We're generating both cos and sin frequencies,
+    # so half for each.
+    num_timescales = hidden_dim // 2
+    timescale_range = get_range_vector(num_timescales, get_device_of(tensor)).data.float()
+
+    log_timescale_increments = math.log(float(max_timescale) / float(min_timescale)) / float(
+        num_timescales - 1
+    )
+    inverse_timescales = min_timescale * torch.exp(timescale_range * -log_timescale_increments)
+
+    # Broadcasted multiplication - shape (timesteps, num_timescales)
+    scaled_time = timestep_range.unsqueeze(1) * inverse_timescales.unsqueeze(0)
+    # shape (timesteps, 2 * num_timescales)
+    sinusoids = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 1)
+    if hidden_dim % 2 != 0:
+        # if the number of dimensions is odd, the cos and sin
+        # timescales had size (hidden_dim - 1) / 2, so we need
+        # to add a row of zeros to make up the difference.
+        sinusoids = torch.cat([sinusoids, sinusoids.new_zeros(timesteps, 1)], 1)
+    return tensor + sinusoids.unsqueeze(0)
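+
+
+# Shape sketch (hypothetical sizes): for a (2, 5, 8) input, num_timescales is 4,
+# scaled_time is (5, 4), sinusoids is (5, 8), broadcast over the batch dimension.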
diff --git a/combo/predict.py b/combo/predict.py
index c4f507bb6a2aeb2ed1400a5abd24b7db773694ec..84e9aedaac9f7a7c91503fb1364b28b666c2130a 100644
--- a/combo/predict.py
+++ b/combo/predict.py
@@ -42,7 +42,7 @@ class COMBO(PredictorModule):
         self.without_sentence_embedding = False
         self.line_to_conllu = line_to_conllu
 
-    def __call__(self, sentence: Union[str, List[str], List[List[str]], List[data.Sentence]]):
+    def __call__(self, sentence: Union[str, List[str], List[List[str]], List[data.Sentence]], **kwargs):
         """Depending on the input uses (or ignores) tokenizer.
         When model isn't only text-based only List[data.Sentence] is possible input.
 
@@ -55,7 +55,7 @@ class COMBO(PredictorModule):
         :return: Sentence or List[Sentence] depending on the input
         """
         try:
-            return self.predict(sentence)
+            return self.predict(sentence, **kwargs)
         except Exception as e:
             logger.error(e)
             logger.error('Exiting.')