diff --git a/combo/data/token_indexers/__init__.py b/combo/data/token_indexers/__init__.py
index 1b918b3ad66692a761b564c08d0c270745a263cb..550a80ba6947f51cc5e0ef9a11b095e8488805cb 100644
--- a/combo/data/token_indexers/__init__.py
+++ b/combo/data/token_indexers/__init__.py
@@ -1,2 +1,3 @@
+from .pretrained_transformer_mismatched_indexer import PretrainedTransformerMismatchedIndexer
 from .token_characters_indexer import TokenCharactersIndexer
 from .token_features_indexer import TokenFeatsIndexer
diff --git a/combo/data/token_indexers/pretrained_transformer_mismatched_indexer.py b/combo/data/token_indexers/pretrained_transformer_mismatched_indexer.py
new file mode 100644
index 0000000000000000000000000000000000000000..fbe2368b08493865ee55252e366e80c1d7b759c2
--- /dev/null
+++ b/combo/data/token_indexers/pretrained_transformer_mismatched_indexer.py
@@ -0,0 +1,85 @@
+from typing import Optional, Dict, Any, List, Tuple
+
+from allennlp import data
+from allennlp.data import token_indexers, tokenizers
+
+
+@data.TokenIndexer.register("pretrained_transformer_mismatched_fixed")
+class PretrainedTransformerMismatchedIndexer(token_indexers.PretrainedTransformerMismatchedIndexer):
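+    """
+    A thin wrapper around allennlp's ``PretrainedTransformerMismatchedIndexer`` that rebuilds
+    the matched indexer and tokenizer using the patched classes defined below.
+    """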
+
+    def __init__(self, model_name: str, namespace: str = "tags", max_length: Optional[int] = None,
+                 tokenizer_kwargs: Optional[Dict[str, Any]] = None, **kwargs) -> None:
+        # The matched indexer (vs. the mismatched one), rebuilt with the patched classes below.
+        super().__init__(model_name, namespace, max_length, tokenizer_kwargs, **kwargs)
+        self._matched_indexer = PretrainedTransformerIndexer(
+            model_name,
+            namespace=namespace,
+            max_length=max_length,
+            tokenizer_kwargs=tokenizer_kwargs,
+            **kwargs,
+        )
+        self._allennlp_tokenizer = self._matched_indexer._allennlp_tokenizer
+        self._tokenizer = self._matched_indexer._tokenizer
+        self._num_added_start_tokens = self._matched_indexer._num_added_start_tokens
+        self._num_added_end_tokens = self._matched_indexer._num_added_end_tokens
+
+
+class PretrainedTransformerIndexer(token_indexers.PretrainedTransformerIndexer):
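+    """
+    Mirrors allennlp's ``PretrainedTransformerIndexer`` but builds its tokenizer from the
+    patched ``PretrainedTransformerTokenizer`` defined below.
+    """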
+
+    def __init__(
+            self,
+            model_name: str,
+            namespace: str = "tags",
+            max_length: Optional[int] = None,
+            tokenizer_kwargs: Optional[Dict[str, Any]] = None,
+            **kwargs,
+    ) -> None:
+        super().__init__(model_name, namespace, max_length, tokenizer_kwargs, **kwargs)
+        self._namespace = namespace
+        self._allennlp_tokenizer = PretrainedTransformerTokenizer(
+            model_name, tokenizer_kwargs=tokenizer_kwargs
+        )
+        self._tokenizer = self._allennlp_tokenizer.tokenizer
+        self._added_to_vocabulary = False
+
+        self._num_added_start_tokens = len(self._allennlp_tokenizer.single_sequence_start_tokens)
+        self._num_added_end_tokens = len(self._allennlp_tokenizer.single_sequence_end_tokens)
+
+        self._max_length = max_length
+        if self._max_length is not None:
+            num_added_tokens = len(self._allennlp_tokenizer.tokenize("a")) - 1
+            self._effective_max_length = (  # we need to take into account special tokens
+                    self._max_length - num_added_tokens
+            )
+            if self._effective_max_length <= 0:
+                raise ValueError(
+                    "max_length needs to be greater than the number of special tokens inserted."
+                )
+
+
+class PretrainedTransformerTokenizer(tokenizers.PretrainedTransformerTokenizer):
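+    """
+    Overrides ``_intra_word_tokenize`` to tokenize each word with ``encode_plus`` (no special
+    tokens) and to record inclusive ``(start, end)`` wordpiece offsets per input token.
+    """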
+
+    def _intra_word_tokenize(
+            self, string_tokens: List[str]
+    ) -> Tuple[List[data.Token], List[Optional[Tuple[int, int]]]]:
+        tokens: List[data.Token] = []
+        offsets: List[Optional[Tuple[int, int]]] = []
+        for token_string in string_tokens:
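+            # Tokenize each word separately, without adding special tokens.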
+            wordpieces = self.tokenizer.encode_plus(
+                token_string,
+                add_special_tokens=False,
+                return_tensors=None,
+                return_offsets_mapping=False,
+                return_attention_mask=False,
+            )
+            wp_ids = wordpieces["input_ids"]
+
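+            # Record inclusive (start, end) indices into ``tokens``; words that yield no
+            # wordpieces get ``None``.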
+            if len(wp_ids) > 0:
+                offsets.append((len(tokens), len(tokens) + len(wp_ids) - 1))
+                tokens.extend(
+                    data.Token(text=wp_text, text_id=wp_id)
+                    for wp_id, wp_text in zip(wp_ids, self.tokenizer.convert_ids_to_tokens(wp_ids))
+                )
+            else:
+                offsets.append(None)
+        return tokens, offsets
diff --git a/combo/models/embeddings.py b/combo/models/embeddings.py
index edb37a174adb2a2f5fd8cd3edcc4c21c6bc4fa75..5cad95928dab03d8e5046eb1a281c07e9ffe33ff 100644
--- a/combo/models/embeddings.py
+++ b/combo/models/embeddings.py
@@ -1,5 +1,5 @@
 """Embeddings."""
-from typing import Optional
+from typing import Optional, Dict, Any
 
 import torch
 import torch.nn as nn
@@ -110,8 +110,10 @@ class TransformersWordEmbedder(token_embedders.PretrainedTransformerMismatchedEm
                  projection_dim: int,
                  projection_activation: Optional[allen_nn.Activation] = lambda x: x,
                  projection_dropout_rate: Optional[float] = 0.0,
-                 freeze_transformer: bool = True):
-        super().__init__(model_name)
+                 freeze_transformer: bool = True,
+                 tokenizer_kwargs: Optional[Dict[str, Any]] = None,
+                 transformer_kwargs: Optional[Dict[str, Any]] = None):
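+        # Forward tokenizer/transformer kwargs (e.g. ``use_fast``) to the underlying
+        # mismatched embedder.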
+        super().__init__(model_name, tokenizer_kwargs=tokenizer_kwargs, transformer_kwargs=transformer_kwargs)
         self.freeze_transformer = freeze_transformer
         if self.freeze_transformer:
             self._matched_embedder.eval()
diff --git a/combo/training/trainer.py b/combo/training/trainer.py
index 772b9b08c68e3ca9865cadf222dc2cd47b139640..aeb9f097369f515b3860d961f638870ec17f6786 100644
--- a/combo/training/trainer.py
+++ b/combo/training/trainer.py
@@ -3,7 +3,7 @@ import logging
 import os
 import time
 import traceback
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Union
 
 import torch
 import torch.distributed as dist
@@ -30,7 +30,8 @@ logger = logging.getLogger(__name__)
 @training.EpochCallback.register("transfer_patience")
 class TransferPatienceEpochCallback(training.EpochCallback):
 
-    def __call__(self, trainer: "training.GradientDescentTrainer", metrics: Dict[str, Any], epoch: int) -> None:
+    def __call__(self, trainer: "training.GradientDescentTrainer", metrics: Dict[str, Any], epoch: int,
+                 is_master: bool) -> None:
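+        # ``is_master`` matches the ``EpochCallback`` signature used by allennlp 1.2.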
         if trainer._learning_rate_scheduler and trainer._learning_rate_scheduler.patience is not None:
             trainer._metric_tracker._patience = trainer._learning_rate_scheduler.patience
             trainer._metric_tracker._epochs_with_no_improvement = 0
@@ -45,20 +46,23 @@ class GradientDescentTrainer(training.GradientDescentTrainer):
                  patience: Optional[int] = None, validation_metric: str = "-loss",
                  validation_data_loader: data.DataLoader = None, num_epochs: int = 20,
                  serialization_dir: Optional[str] = None, checkpointer: checkpointer.Checkpointer = None,
-                 cuda_device: int = -1,
+                 cuda_device: Optional[Union[int, torch.device]] = -1,
                  grad_norm: Optional[float] = None, grad_clipping: Optional[float] = None,
                  learning_rate_scheduler: Optional[learning_rate_schedulers.LearningRateScheduler] = None,
                  momentum_scheduler: Optional[momentum_schedulers.MomentumScheduler] = None,
                  tensorboard_writer: allen_tensorboard_writer.TensorboardWriter = None,
                  moving_average: Optional[moving_average.MovingAverage] = None,
                  batch_callbacks: List[training.BatchCallback] = None,
-                 epoch_callbacks: List[training.EpochCallback] = None, distributed: bool = False, local_rank: int = 0,
+                 epoch_callbacks: List[training.EpochCallback] = None,
+                 end_callbacks: List[training.EpochCallback] = None,
+                 trainer_callbacks: List[training.TrainerCallback] = None,
+                 distributed: bool = False, local_rank: int = 0,
                  world_size: int = 1, num_gradient_accumulation_steps: int = 1,
                  use_amp: bool = False) -> None:
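+        # ``end_callbacks`` and ``trainer_callbacks`` are forwarded to the allennlp 1.2 trainer.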
         super().__init__(model, optimizer, data_loader, patience, validation_metric, validation_data_loader, num_epochs,
                          serialization_dir, checkpointer, cuda_device, grad_norm, grad_clipping,
                          learning_rate_scheduler, momentum_scheduler, tensorboard_writer, moving_average,
-                         batch_callbacks, epoch_callbacks, distributed, local_rank, world_size,
+                         batch_callbacks, epoch_callbacks, end_callbacks, trainer_callbacks, distributed, local_rank, world_size,
                          num_gradient_accumulation_steps, use_amp)
         # TODO extract param to constructor (+ constructor method?)
         self.validate_every_n = 5
@@ -93,7 +97,7 @@ class GradientDescentTrainer(training.GradientDescentTrainer):
             metrics["best_validation_" + key] = value
 
         for callback in self._epoch_callbacks:
-            callback(self, metrics={}, epoch=-1)
+            callback(self, metrics={}, epoch=-1, is_master=True)
 
         for epoch in range(epoch_counter, self._num_epochs):
             epoch_start_time = time.time()
@@ -190,7 +194,7 @@ class GradientDescentTrainer(training.GradientDescentTrainer):
                 dist.barrier()
 
             for callback in self._epoch_callbacks:
-                callback(self, metrics=metrics, epoch=epoch)
+                callback(self, metrics=metrics, epoch=epoch, is_master=self._master)
 
             epoch_elapsed_time = time.time() - epoch_start_time
             logger.info("Epoch duration: %s", datetime.timedelta(seconds=epoch_elapsed_time))
@@ -243,7 +247,7 @@ class GradientDescentTrainer(training.GradientDescentTrainer):
             batch_callbacks: List[training.BatchCallback] = None,
             epoch_callbacks: List[training.EpochCallback] = None,
     ) -> "training.Trainer":
-        if tensorboard_writer.construct() is None:
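+        # ``tensorboard_writer`` comes in as a ``Lazy`` wrapper; check for ``None`` directly
+        # instead of constructing it first.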
+        if tensorboard_writer is None:
             tensorboard_writer = common.Lazy(combo_tensorboard_writer.NullTensorboardWriter)
         return super().from_partial_objects(
             model=model,
diff --git a/config.template.jsonnet b/config.template.jsonnet
index 57f02ae3fcaadbad5402f40add0b6a2b5d3a874c..8e5ddc9f3d120156a4d00b1d231a54d90a66b631 100644
--- a/config.template.jsonnet
+++ b/config.template.jsonnet
@@ -112,8 +112,10 @@ assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't
         use_sem: if in_targets("semrel") then true else false,
         token_indexers: {
             token: if use_transformer then {
-                type: "pretrained_transformer_mismatched",
+                type: "pretrained_transformer_mismatched_fixed",
                 model_name: pretrained_transformer_name,
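+                # allegro/herbert models fall back to the slow tokenizer (use_fast: false).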
+                tokenizer_kwargs: if std.startsWith(pretrained_transformer_name, "allegro/herbert")
+                                  then {use_fast: false} else {},
             } else {
                 # SingleIdTokenIndexer, token as single int
                 type: "single_id",
@@ -202,6 +204,8 @@ assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't
                     type: "transformers_word_embeddings",
                     model_name: pretrained_transformer_name,
                     projection_dim: projected_embedding_dim,
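+                    # Keep tokenizer_kwargs in sync with the token indexer above.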
+                    tokenizer_kwargs: if std.startsWith(pretrained_transformer_name, "allegro/herbert")
+                                      then {use_fast: false} else {},
                 } else {
                     type: "embeddings_projected",
                     embedding_dim: embedding_dim,
diff --git a/setup.py b/setup.py
index dd21555edf3cda7b60f6ebe148c5d3ad28f36e91..6ce7e3cd9a3e1fd43d4e75115ffad9724b5cd5bc 100644
--- a/setup.py
+++ b/setup.py
@@ -3,8 +3,9 @@ from setuptools import find_packages, setup
 
 REQUIREMENTS = [
     'absl-py==0.9.0',
-    'allennlp==1.1.0',
+    'allennlp==1.2.0',
     'conllu==2.3.2',
+    'dataclasses==0.5',
     'dataclasses-json==0.5.2',
     'joblib==0.14.1',
     'jsonnet==0.15.0',
@@ -13,7 +14,7 @@ REQUIREMENTS = [
     'tensorboard==2.1.0',
     'torch==1.6.0',
     'tqdm==4.43.0',
-    'transformers>=3.0.0,<3.1.0',
+    'transformers>=3.4.0,<3.5',
     'urllib3>=1.25.11',
 ]