diff --git a/.gitignore b/.gitignore
index 31cb7122cf476011b717f36c98500e8f9de01825..4f8a96d07519496b998b8b9ef000aeb691aa51f2 100644
--- a/.gitignore
+++ b/.gitignore
@@ -12,10 +12,11 @@ __pycache__
 /checkpoints
 .dvc
 .tox
-notebooks
+notebooks_work
 dvc.lock
 dask-worker-space
 test_data
 .env
 deploy
-service.log
\ No newline at end of file
+service.log
+lightning_logs
\ No newline at end of file
diff --git a/README.md b/README.md
index ffead895613e0694001356eac13142709621cb0b..abdb9c122e7668962290ad5d74c1273e448e9f3a 100644
--- a/README.md
+++ b/README.md
@@ -15,6 +15,8 @@ A service that automatically adds punctuation to raw word-stream (eg. from speec
 
-3. actions_mixed: A model based on the full transformer (encoder + decoder) architecture. It's much less performant, as it only predicts actions for one word at the time. However, it can model action probabilities conditioned on both the input and output predicted so far. Because of that, it's much less prone to not uppercasing letters in a new sentence or placing multiple punctuation signs in close proximity.
+3. actions_mixed: A model based on the full transformer (encoder + decoder) architecture. It's much slower, as it predicts actions for only one word at a time. However, it can condition action probabilities on both the input and the output predicted so far. Because of that, it's much less likely to miss uppercasing the first word of a new sentence or to place multiple punctuation marks in close proximity.
 
+4. actions_ensemble: An ensemble predictor built from multiple transformer encoders (see the sketch below).
+
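+A rough sketch of the ensembling idea (illustrative only; `ensemble_predict` and its arguments are hypothetical names, and the real implementation lives in `src/models`):
+
+```python
+import torch
+
+
+def ensemble_predict(models, input_ids, attention_mask):
+    # Average per-token action probabilities over all encoder models
+    probs = [m(input_ids, attention_mask=attention_mask).sigmoid() for m in models]
+    return torch.stack(probs).mean(dim=0)
+```
+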
 ### Translation
 2. translation (Deprecated): Full encoder-decoder stack that takes input (unpunctuated text) and the output produced so far to predict the next token. The main difference from the actions model is that it's a full text2text model without restriction on tokens. Because of that, in theory, it can represent more cases (eg. all upper, some upper, dashes, ellipsis, etc...), as opposed to only a few explicitly defined actions. However, the lack of constraints makes it much harder to train (both in performance and data size).
 
@@ -66,3 +68,9 @@ where model_name is one of models specified in models_enabled. If no model is pr
 
 ## Mountpoints
 Directory where the model will be downloaded (~500Mb) needs to be mounted at /punctuator/deploy
+
+## Code structure
+
+- `src/models`: Place where all models are defined.
+- `src/predictors`: Utility code that wraps models behind a simple `predict` function transforming text to text (see the usage sketch below).
+- `src/pipelines`: Scripts forming the whole pipeline.
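+
+A rough usage sketch (the checkpoint path and model key here are illustrative, not guaranteed to match your setup):
+
+```python
+import torch
+
+from src.models.model_factory import MODELS_MAP
+from src.predictors.factory import ACTION_MODEL_PREDICTORS
+
+# Load a trained checkpoint and wrap it in a text-to-text predictor
+model = MODELS_MAP["actions_base"].load_from_checkpoint(
+    "checkpoints/actions_base/final.ckpt", map_location=torch.device("cpu")
+)
+model.train(False)
+
+predictor = ACTION_MODEL_PREDICTORS["actions_base"](model)
+print(predictor.predict("to jest przykladowy tekst bez interpunkcji"))
+```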
\ No newline at end of file
diff --git a/dvc.yaml b/dvc.yaml
index 4970d16862ec36453545d39c8dd071e3382df7b0..e16c9b1dfafe3df3c592e03816d67efb2dcb1ed2 100644
--- a/dvc.yaml
+++ b/dvc.yaml
@@ -46,7 +46,7 @@ stages:
 
   # Base
   actions_base_training:
-    cmd: python3 -m src.pipelines.actions_based.train_base
+    cmd: python3 -m src.pipelines.actions_based.train -m actions_base
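+    # All action models now share one training entrypoint; -m selects the model-specific config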
     deps:
     - generated/actions/stage4_reindexing
     - generated/actions/stage5_stats
@@ -54,11 +54,10 @@ stages:
     params:
     - global.base_model
     - global.random_seed
-    - actions.training_base.max_training_time
-    - actions.training_base.learning_rate
-    - actions.training_base.num_epochs
-    - actions.training_base.batch_size
-    - actions.training_base.save_step
+    - actions.training.actions_base.learning_rate
+    - actions.training.actions_base.batch_size
+    - actions.training.actions_base.save_step
+    - actions.training.actions_base.start_checkpoint
     outs:
     - checkpoints/actions_base
 
@@ -69,13 +68,13 @@ stages:
     - generated/actions/stage4_reindexing
     - src
     params:
-    - actions.testing_base.limit
+    - actions.testing.actions_base.limit
     outs:
     - generated/actions/test_results_base
 
   # Restricted
   actions_restricted_training:
-    cmd: python3 -m src.pipelines.actions_based.train_restricted
+    cmd: python3 -m src.pipelines.actions_based.train -m actions_restricted
     deps:
     - generated/actions/stage4_reindexing
     - generated/actions/stage5_stats
@@ -83,11 +82,10 @@ stages:
     params:
     - global.base_model
     - global.random_seed
-    - actions.training_restricted.max_training_time
-    - actions.training_restricted.learning_rate
-    - actions.training_restricted.num_epochs
-    - actions.training_restricted.batch_size
-    - actions.training_restricted.save_step
+    - actions.training.actions_restricted.learning_rate
+    - actions.training.actions_restricted.batch_size
+    - actions.training.actions_restricted.save_step
+    - actions.training.actions_restricted.start_checkpoint
     outs:
     - checkpoints/actions_restricted
 
@@ -98,7 +96,7 @@ stages:
     - generated/actions/stage4_reindexing
     - src
     params:
-    - actions.testing_restricted.limit
+    - actions.testing.actions_restricted.limit
     outs:
     - generated/actions/test_results_restricted
 
@@ -112,16 +110,15 @@ stages:
     params:
     - global.base_model
     - global.random_seed
-    - actions.training_mixed.embedding_size
-    - actions.training_mixed.num_heads
-    - actions.training_mixed.num_layers
-    - actions.training_mixed.dropout
-    - actions.training_mixed.feedforward_neurons
-    - actions.training_mixed.max_training_time
-    - actions.training_mixed.learning_rate
-    - actions.training_mixed.num_epochs
-    - actions.training_mixed.batch_size
-    - actions.training_mixed.save_step
+    - actions.training.actions_mixed.embedding_size
+    - actions.training.actions_mixed.num_heads
+    - actions.training.actions_mixed.num_layers
+    - actions.training.actions_mixed.dropout
+    - actions.training.actions_mixed.feedforward_neurons
+    - actions.training.actions_mixed.learning_rate
+    - actions.training.actions_mixed.batch_size
+    - actions.training.actions_mixed.save_step
+    - actions.training.actions_mixed.start_checkpoint
     outs:
     - checkpoints/actions_mixed
 
@@ -132,7 +129,7 @@ stages:
     - generated/actions/stage4_reindexing
     - src
     params:
-    - actions.testing_mixed.limit
+    - actions.testing.actions_mixed.limit
     outs:
     - generated/actions/test_results_mixed
 
diff --git a/notebooks/tokens_in_dataset.ipynb b/notebooks/tokens_in_dataset.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..899c403203704643dacac3b9a885477b8c45d370
--- /dev/null
+++ b/notebooks/tokens_in_dataset.ipynb
@@ -0,0 +1,183 @@
+{
+ "metadata": {
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.8.2-final"
+  },
+  "orig_nbformat": 2,
+  "kernelspec": {
+   "name": "python38264bita7d7da14168440cb9836372958035d4a",
+   "display_name": "Python 3.8.2 64-bit"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2,
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "%load_ext autoreload\n",
+    "%autoreload 2\n",
+    "\n",
+    "import sys\n",
+    "sys.path.append(\"../\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import dask\n",
+    "import dask.dataframe as dd\n",
+    "import glob\n",
+    "import numpy as np\n",
+    "from collections import Counter\n",
+    "from src.pipelines.actions_based.processing import text_from_xml, input_preprocess, output_preprocess\n",
+    "import multiprocessing\n",
+    "from dask.distributed import Client\n",
+    "import seaborn as sns"
+   ]
+  },
+  {
+   "source": [
+    "## Getting list of interesting files"
+   ],
+   "cell_type": "markdown",
+   "metadata": {}
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [
+    {
+     "output_type": "stream",
+     "name": "stdout",
+     "text": "Number of files: 343368\n"
+    }
+   ],
+   "source": [
+    "FILE_SCHEMA = \"../data/**/text_structure.xml\"\n",
+    "files_paths = glob.glob(FILE_SCHEMA, recursive=True)\n",
+    "\n",
+    "print(f\"Number of files: {len(files_paths)}\")"
+   ]
+  },
+  {
+   "source": [
+    "## Shuffle the list and take only top n samples"
+   ],
+   "cell_type": "markdown",
+   "metadata": {}
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "SUBSAMPLE_SIZE = 20_000\n",
+    "np.random.shuffle(files_paths)\n",
+    "\n",
+    "files_paths_subsample = files_paths[:SUBSAMPLE_SIZE]"
+   ]
+  },
+  {
+   "source": [
+    "## Count characters that are deleted"
+   ],
+   "cell_type": "markdown",
+   "metadata": {}
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "def count_missing_characters(file_path):\n",
+    "    full_text = text_from_xml(file_path)\n",
+    "    text_cleaned = input_preprocess(full_text)\n",
+    "\n",
+    "    return Counter(full_text.lower()) - Counter(text_cleaned)\n",
+    "\n",
+    "counter = []\n",
+    "for file_path in files_paths_subsample:\n",
+    "    res = dask.delayed(count_missing_characters)(file_path)\n",
+    "    counter.append(res)\n",
+    "import seaborn as sns\n",
+    "\n",
+    "client = Client(n_workers=multiprocessing.cpu_count())\n",
+    "total = dask.delayed(sum)(counter, Counter()).compute(sheduler=\"processes\")"
+   ]
+  },
+  {
+   "source": [
+    "## Visualize results"
+   ],
+   "cell_type": "markdown",
+   "metadata": {}
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {},
+   "outputs": [
+    {
+     "output_type": "execute_result",
+     "data": {
+      "text/plain": "<matplotlib.axes._subplots.AxesSubplot at 0x7f82abf7f280>"
+     },
+     "metadata": {},
+     "execution_count": 9
+    },
+    {
+     "output_type": "display_data",
+     "data": {
+      "text/plain": "<Figure size 432x288 with 1 Axes>",
+      "image/svg+xml": "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n  \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n<!-- Created with matplotlib (https://matplotlib.org/) -->\n<svg height=\"248.518125pt\" version=\"1.1\" viewBox=\"0 0 372.103125 248.518125\" width=\"372.103125pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n <defs>\n  <style type=\"text/css\">\n*{stroke-linecap:butt;stroke-linejoin:round;}\n  </style>\n </defs>\n <g id=\"figure_1\">\n  <g id=\"patch_1\">\n   <path d=\"M 0 248.518125 \nL 372.103125 248.518125 \nL 372.103125 0 \nL 0 0 \nz\n\" style=\"fill:none;\"/>\n  </g>\n  <g id=\"axes_1\">\n   <g id=\"patch_2\">\n    <path d=\"M 30.103125 224.64 \nL 364.903125 224.64 \nL 364.903125 7.2 \nL 30.103125 7.2 \nz\n\" style=\"fill:#ffffff;\"/>\n   </g>\n   <g id=\"patch_3\">\n    <path clip-path=\"url(#pba0c2af0c7)\" d=\"M 31.777125 224.64 \nL 45.169125 224.64 \nL 45.169125 17.554286 \nL 31.777125 17.554286 \nz\n\" style=\"fill:#ea96a3;\"/>\n   </g>\n   <g id=\"patch_4\">\n    <path clip-path=\"url(#pba0c2af0c7)\" d=\"M 48.517125 224.64 \nL 61.909125 224.64 \nL 61.909125 59.29656 \nL 48.517125 59.29656 \nz\n\" style=\"fill:#e79683;\"/>\n   </g>\n   <g id=\"patch_5\">\n    <path clip-path=\"url(#pba0c2af0c7)\" d=\"M 65.257125 224.64 \nL 78.649125 224.64 \nL 78.649125 208.501335 \nL 65.257125 208.501335 \nz\n\" style=\"fill:#d7944e;\"/>\n   </g>\n   <g id=\"patch_6\">\n    <path clip-path=\"url(#pba0c2af0c7)\" d=\"M 81.997125 224.64 \nL 95.389125 224.64 \nL 95.389125 212.314053 \nL 81.997125 212.314053 \nz\n\" style=\"fill:#bf9a4a;\"/>\n   </g>\n   <g id=\"patch_7\">\n    <path clip-path=\"url(#pba0c2af0c7)\" d=\"M 98.737125 224.64 \nL 112.129125 224.64 \nL 112.129125 212.958542 \nL 98.737125 212.958542 \nz\n\" style=\"fill:#ab9e47;\"/>\n   </g>\n   <g id=\"patch_8\">\n    <path clip-path=\"url(#pba0c2af0c7)\" d=\"M 115.477125 224.64 \nL 128.869125 224.64 \nL 128.869125 219.159208 \nL 115.477125 219.159208 \nz\n\" style=\"fill:#98a246;\"/>\n   </g>\n   <g id=\"patch_9\">\n    <path clip-path=\"url(#pba0c2af0c7)\" d=\"M 132.217125 224.64 \nL 145.609125 224.64 \nL 145.609125 219.749564 \nL 132.217125 219.749564 \nz\n\" style=\"fill:#7fa946;\"/>\n   </g>\n   <g id=\"patch_10\">\n    <path clip-path=\"url(#pba0c2af0c7)\" d=\"M 148.957125 224.64 \nL 162.349125 224.64 \nL 162.349125 219.915309 \nL 148.957125 219.915309 \nz\n\" style=\"fill:#48b052;\"/>\n   </g>\n   <g id=\"patch_11\">\n    <path clip-path=\"url(#pba0c2af0c7)\" d=\"M 165.697125 224.64 \nL 179.089125 224.64 \nL 179.089125 220.512359 \nL 165.697125 220.512359 \nz\n\" style=\"fill:#49ae83;\"/>\n   </g>\n   <g id=\"patch_12\">\n    <path clip-path=\"url(#pba0c2af0c7)\" d=\"M 182.437125 224.64 \nL 195.829125 224.64 \nL 195.829125 221.326274 \nL 182.437125 221.326274 \nz\n\" style=\"fill:#4aad96;\"/>\n   </g>\n   <g id=\"patch_13\">\n    <path clip-path=\"url(#pba0c2af0c7)\" d=\"M 199.177125 224.64 \nL 212.569125 224.64 \nL 212.569125 221.709135 \nL 199.177125 221.709135 \nz\n\" style=\"fill:#4baba4;\"/>\n   </g>\n   <g id=\"patch_14\">\n    <path clip-path=\"url(#pba0c2af0c7)\" d=\"M 215.917125 224.64 \nL 229.309125 224.64 \nL 229.309125 221.85731 \nL 215.917125 221.85731 \nz\n\" style=\"fill:#4dabb2;\"/>\n   </g>\n   <g id=\"patch_15\">\n    <path clip-path=\"url(#pba0c2af0c7)\" d=\"M 232.657125 224.64 \nL 246.049125 224.64 \nL 246.049125 222.236657 \nL 232.657125 222.236657 \nz\n\" 
style=\"fill:#50acc3;\"/>\n   </g>\n   <g id=\"patch_16\">\n    <path clip-path=\"url(#pba0c2af0c7)\" d=\"M 249.397125 224.64 \nL 262.789125 224.64 \nL 262.789125 222.251718 \nL 249.397125 222.251718 \nz\n\" style=\"fill:#56addb;\"/>\n   </g>\n   <g id=\"patch_17\">\n    <path clip-path=\"url(#pba0c2af0c7)\" d=\"M 266.137125 224.64 \nL 279.529125 224.64 \nL 279.529125 222.862321 \nL 266.137125 222.862321 \nz\n\" style=\"fill:#94aee8;\"/>\n   </g>\n   <g id=\"patch_18\">\n    <path clip-path=\"url(#pba0c2af0c7)\" d=\"M 282.877125 224.64 \nL 296.269125 224.64 \nL 296.269125 222.880393 \nL 282.877125 222.880393 \nz\n\" style=\"fill:#b6a8eb;\"/>\n   </g>\n   <g id=\"patch_19\">\n    <path clip-path=\"url(#pba0c2af0c7)\" d=\"M 299.617125 224.64 \nL 313.009125 224.64 \nL 313.009125 222.944566 \nL 299.617125 222.944566 \nz\n\" style=\"fill:#ce9be9;\"/>\n   </g>\n   <g id=\"patch_20\">\n    <path clip-path=\"url(#pba0c2af0c7)\" d=\"M 316.357125 224.64 \nL 329.749125 224.64 \nL 329.749125 223.372441 \nL 316.357125 223.372441 \nz\n\" style=\"fill:#e689e4;\"/>\n   </g>\n   <g id=\"patch_21\">\n    <path clip-path=\"url(#pba0c2af0c7)\" d=\"M 333.097125 224.64 \nL 346.489125 224.64 \nL 346.489125 223.568975 \nL 333.097125 223.568975 \nz\n\" style=\"fill:#e88fcc;\"/>\n   </g>\n   <g id=\"patch_22\">\n    <path clip-path=\"url(#pba0c2af0c7)\" d=\"M 349.837125 224.64 \nL 363.229125 224.64 \nL 363.229125 224.340973 \nL 349.837125 224.340973 \nz\n\" style=\"fill:#e993b9;\"/>\n   </g>\n   <g id=\"matplotlib.axis_1\">\n    <g id=\"xtick_1\">\n     <g id=\"line2d_1\">\n      <defs>\n       <path d=\"M 0 0 \nL 0 3.5 \n\" id=\"mfd7c4472b7\" style=\"stroke:#000000;stroke-width:0.8;\"/>\n      </defs>\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"38.473125\" xlink:href=\"#mfd7c4472b7\" y=\"224.64\"/>\n      </g>\n     </g>\n     <g id=\"text_1\">\n      <!-- , -->\n      <defs>\n       <path d=\"M 11.71875 12.40625 \nL 22.015625 12.40625 \nL 22.015625 4 \nL 14.015625 -11.625 \nL 7.71875 -11.625 \nL 11.71875 4 \nz\n\" id=\"DejaVuSans-44\"/>\n      </defs>\n      <g transform=\"translate(36.884063 239.238437)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-44\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"xtick_2\">\n     <g id=\"line2d_2\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"55.213125\" xlink:href=\"#mfd7c4472b7\" y=\"224.64\"/>\n      </g>\n     </g>\n     <g id=\"text_2\">\n      <!-- . -->\n      <defs>\n       <path d=\"M 10.6875 12.40625 \nL 21 12.40625 \nL 21 0 \nL 10.6875 0 \nz\n\" id=\"DejaVuSans-46\"/>\n      </defs>\n      <g transform=\"translate(53.624063 239.238437)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-46\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"xtick_3\">\n     <g id=\"line2d_3\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"71.953125\" xlink:href=\"#mfd7c4472b7\" y=\"224.64\"/>\n      </g>\n     </g>\n     <g id=\"text_3\">\n      <!--   -->\n      <defs>\n       <path id=\"DejaVuSans-32\"/>\n      </defs>\n      <g transform=\"translate(70.364063 239.238437)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-32\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"xtick_4\">\n     <g id=\"line2d_4\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"88.693125\" xlink:href=\"#mfd7c4472b7\" y=\"224.64\"/>\n      </g>\n     </g>\n     <g id=\"text_4\">\n      <!-- ? 
-->\n      <defs>\n       <path d=\"M 19.09375 12.40625 \nL 29 12.40625 \nL 29 0 \nL 19.09375 0 \nz\nM 28.71875 19.578125 \nL 19.390625 19.578125 \nL 19.390625 27.09375 \nQ 19.390625 32.03125 20.75 35.203125 \nQ 22.125 38.375 26.515625 42.578125 \nL 30.90625 46.921875 \nQ 33.6875 49.515625 34.9375 51.8125 \nQ 36.1875 54.109375 36.1875 56.5 \nQ 36.1875 60.84375 32.984375 63.53125 \nQ 29.78125 66.21875 24.515625 66.21875 \nQ 20.65625 66.21875 16.28125 64.5 \nQ 11.921875 62.796875 7.171875 59.515625 \nL 7.171875 68.703125 \nQ 11.765625 71.484375 16.46875 72.84375 \nQ 21.1875 74.21875 26.21875 74.21875 \nQ 35.203125 74.21875 40.640625 69.484375 \nQ 46.09375 64.75 46.09375 56.984375 \nQ 46.09375 53.265625 44.328125 49.921875 \nQ 42.578125 46.578125 38.1875 42.390625 \nL 33.890625 38.1875 \nQ 31.59375 35.890625 30.640625 34.59375 \nQ 29.6875 33.296875 29.296875 32.078125 \nQ 29 31.0625 28.859375 29.59375 \nQ 28.71875 28.125 28.71875 25.59375 \nz\n\" id=\"DejaVuSans-63\"/>\n      </defs>\n      <g transform=\"translate(86.039219 239.238437)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-63\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"xtick_5\">\n     <g id=\"line2d_5\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"105.433125\" xlink:href=\"#mfd7c4472b7\" y=\"224.64\"/>\n      </g>\n     </g>\n     <g id=\"text_5\">\n      <!-- - -->\n      <defs>\n       <path d=\"M 4.890625 31.390625 \nL 31.203125 31.390625 \nL 31.203125 23.390625 \nL 4.890625 23.390625 \nz\n\" id=\"DejaVuSans-45\"/>\n      </defs>\n      <g transform=\"translate(103.629219 239.238437)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-45\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"xtick_6\">\n     <g id=\"line2d_6\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"122.173125\" xlink:href=\"#mfd7c4472b7\" y=\"224.64\"/>\n      </g>\n     </g>\n     <g id=\"text_6\">\n      <!-- : -->\n      <defs>\n       <path d=\"M 11.71875 12.40625 \nL 22.015625 12.40625 \nL 22.015625 0 \nL 11.71875 0 \nz\nM 11.71875 51.703125 \nL 22.015625 51.703125 \nL 22.015625 39.3125 \nL 11.71875 39.3125 \nz\n\" id=\"DejaVuSans-58\"/>\n      </defs>\n      <g transform=\"translate(120.48875 239.238437)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-58\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"xtick_7\">\n     <g id=\"line2d_7\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"138.913125\" xlink:href=\"#mfd7c4472b7\" y=\"224.64\"/>\n      </g>\n     </g>\n     <g id=\"text_7\">\n      <!-- – -->\n      <defs>\n       <path d=\"M 4.890625 30.90625 \nL 45.125 30.90625 \nL 45.125 23.875 \nL 4.890625 23.875 \nz\n\" id=\"DejaVuSans-8211\"/>\n      </defs>\n      <g transform=\"translate(136.413125 239.238437)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-8211\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"xtick_8\">\n     <g id=\"line2d_8\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"155.653125\" xlink:href=\"#mfd7c4472b7\" y=\"224.64\"/>\n      </g>\n     </g>\n     <g id=\"text_8\">\n      <!-- ! 
-->\n      <defs>\n       <path d=\"M 15.09375 12.40625 \nL 25 12.40625 \nL 25 0 \nL 15.09375 0 \nz\nM 15.09375 72.90625 \nL 25 72.90625 \nL 25 40.921875 \nL 24.03125 23.484375 \nL 16.109375 23.484375 \nL 15.09375 40.921875 \nz\n\" id=\"DejaVuSans-33\"/>\n      </defs>\n      <g transform=\"translate(153.648438 239.238437)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-33\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"xtick_9\">\n     <g id=\"line2d_9\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"172.393125\" xlink:href=\"#mfd7c4472b7\" y=\"224.64\"/>\n      </g>\n     </g>\n     <g id=\"text_9\">\n      <!-- \" -->\n      <defs>\n       <path d=\"M 17.921875 72.90625 \nL 17.921875 45.796875 \nL 9.625 45.796875 \nL 9.625 72.90625 \nz\nM 36.375 72.90625 \nL 36.375 45.796875 \nL 28.078125 45.796875 \nL 28.078125 72.90625 \nz\n\" id=\"DejaVuSans-34\"/>\n      </defs>\n      <g transform=\"translate(170.093125 239.238437)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-34\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"xtick_10\">\n     <g id=\"line2d_10\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"189.133125\" xlink:href=\"#mfd7c4472b7\" y=\"224.64\"/>\n      </g>\n     </g>\n     <g id=\"text_10\">\n      <!-- ) -->\n      <defs>\n       <path d=\"M 8.015625 75.875 \nL 15.828125 75.875 \nQ 23.140625 64.359375 26.78125 53.3125 \nQ 30.421875 42.28125 30.421875 31.390625 \nQ 30.421875 20.453125 26.78125 9.375 \nQ 23.140625 -1.703125 15.828125 -13.1875 \nL 8.015625 -13.1875 \nQ 14.5 -2 17.703125 9.0625 \nQ 20.90625 20.125 20.90625 31.390625 \nQ 20.90625 42.671875 17.703125 53.65625 \nQ 14.5 64.65625 8.015625 75.875 \nz\n\" id=\"DejaVuSans-41\"/>\n      </defs>\n      <g transform=\"translate(187.182344 239.238437)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-41\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"xtick_11\">\n     <g id=\"line2d_11\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"205.873125\" xlink:href=\"#mfd7c4472b7\" y=\"224.64\"/>\n      </g>\n     </g>\n     <g id=\"text_11\">\n      <!-- ( -->\n      <defs>\n       <path d=\"M 31 75.875 \nQ 24.46875 64.65625 21.28125 53.65625 \nQ 18.109375 42.671875 18.109375 31.390625 \nQ 18.109375 20.125 21.3125 9.0625 \nQ 24.515625 -2 31 -13.1875 \nL 23.1875 -13.1875 \nQ 15.875 -1.703125 12.234375 9.375 \nQ 8.59375 20.453125 8.59375 31.390625 \nQ 8.59375 42.28125 12.203125 53.3125 \nQ 15.828125 64.359375 23.1875 75.875 \nz\n\" id=\"DejaVuSans-40\"/>\n      </defs>\n      <g transform=\"translate(203.922344 239.238437)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-40\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"xtick_12\">\n     <g id=\"line2d_12\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"222.613125\" xlink:href=\"#mfd7c4472b7\" y=\"224.64\"/>\n      </g>\n     </g>\n     <g id=\"text_12\">\n      <!-- % -->\n      <defs>\n       <path d=\"M 72.703125 32.078125 \nQ 68.453125 32.078125 66.03125 28.46875 \nQ 63.625 24.859375 63.625 18.40625 \nQ 63.625 12.0625 66.03125 8.421875 \nQ 68.453125 4.78125 72.703125 4.78125 \nQ 76.859375 4.78125 79.265625 8.421875 \nQ 81.6875 12.0625 81.6875 18.40625 \nQ 81.6875 24.8125 79.265625 28.4375 \nQ 76.859375 32.078125 72.703125 32.078125 \nz\nM 72.703125 38.28125 \nQ 80.421875 38.28125 84.953125 32.90625 \nQ 89.5 27.546875 89.5 18.40625 \nQ 89.5 9.28125 84.9375 3.921875 \nQ 80.375 -1.421875 72.703125 -1.421875 \nQ 64.890625 -1.421875 60.34375 
3.921875 \nQ 55.8125 9.28125 55.8125 18.40625 \nQ 55.8125 27.59375 60.375 32.9375 \nQ 64.9375 38.28125 72.703125 38.28125 \nz\nM 22.3125 68.015625 \nQ 18.109375 68.015625 15.6875 64.375 \nQ 13.28125 60.75 13.28125 54.390625 \nQ 13.28125 47.953125 15.671875 44.328125 \nQ 18.0625 40.71875 22.3125 40.71875 \nQ 26.5625 40.71875 28.96875 44.328125 \nQ 31.390625 47.953125 31.390625 54.390625 \nQ 31.390625 60.6875 28.953125 64.34375 \nQ 26.515625 68.015625 22.3125 68.015625 \nz\nM 66.40625 74.21875 \nL 74.21875 74.21875 \nL 28.609375 -1.421875 \nL 20.796875 -1.421875 \nz\nM 22.3125 74.21875 \nQ 30.03125 74.21875 34.609375 68.875 \nQ 39.203125 63.53125 39.203125 54.390625 \nQ 39.203125 45.171875 34.640625 39.84375 \nQ 30.078125 34.515625 22.3125 34.515625 \nQ 14.546875 34.515625 10.03125 39.859375 \nQ 5.515625 45.21875 5.515625 54.390625 \nQ 5.515625 63.484375 10.046875 68.84375 \nQ 14.59375 74.21875 22.3125 74.21875 \nz\n\" id=\"DejaVuSans-37\"/>\n      </defs>\n      <g transform=\"translate(217.862344 239.238437)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-37\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"xtick_13\">\n     <g id=\"line2d_13\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"239.353125\" xlink:href=\"#mfd7c4472b7\" y=\"224.64\"/>\n      </g>\n     </g>\n     <g id=\"text_13\">\n      <!-- „ -->\n      <defs>\n       <path d=\"M 32.515625 12.40625 \nL 42.828125 12.40625 \nL 42.828125 4 \nL 34.8125 -11.625 \nL 28.515625 -11.625 \nL 32.515625 4 \nz\nM 12.5 12.40625 \nL 22.796875 12.40625 \nL 22.796875 4 \nL 14.796875 -11.625 \nL 8.5 -11.625 \nL 12.5 4 \nz\n\" id=\"DejaVuSans-8222\"/>\n      </defs>\n      <g transform=\"translate(236.7625 239.238437)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-8222\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"xtick_14\">\n     <g id=\"line2d_14\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"256.093125\" xlink:href=\"#mfd7c4472b7\" y=\"224.64\"/>\n      </g>\n     </g>\n     <g id=\"text_14\">\n      <!-- ” -->\n      <defs>\n       <path d=\"M 12.5 72.90625 \nL 22.796875 72.90625 \nL 22.796875 64.5 \nL 14.796875 48.875 \nL 8.5 48.875 \nL 12.5 64.5 \nz\nM 32.515625 72.90625 \nL 42.828125 72.90625 \nL 42.828125 64.5 \nL 34.8125 48.875 \nL 28.515625 48.875 \nL 32.515625 64.5 \nz\n\" id=\"DejaVuSans-8221\"/>\n      </defs>\n      <g transform=\"translate(253.5025 239.238437)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-8221\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"xtick_15\">\n     <g id=\"line2d_15\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"272.833125\" xlink:href=\"#mfd7c4472b7\" y=\"224.64\"/>\n      </g>\n     </g>\n     <g id=\"text_15\">\n      <!-- — -->\n      <defs>\n       <path d=\"M 4.890625 30.90625 \nL 95.125 30.90625 \nL 95.125 23.875 \nL 4.890625 23.875 \nz\n\" id=\"DejaVuSans-8212\"/>\n      </defs>\n      <g transform=\"translate(267.833125 239.238437)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-8212\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"xtick_16\">\n     <g id=\"line2d_16\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"289.573125\" xlink:href=\"#mfd7c4472b7\" y=\"224.64\"/>\n      </g>\n     </g>\n     <g id=\"text_16\">\n      <!-- … -->\n      <defs>\n       <path d=\"M 44.828125 12.40625 \nL 55.171875 12.40625 \nL 55.171875 0 \nL 44.828125 0 \nz\nM 78.078125 12.40625 \nL 88.484375 12.40625 \nL 88.484375 0 \nL 78.078125 0 \nz\nM 11.53125 12.40625 \nL 
21.921875 12.40625 \nL 21.921875 0 \nL 11.53125 0 \nz\n\" id=\"DejaVuSans-8230\"/>\n      </defs>\n      <g transform=\"translate(284.573125 239.238437)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-8230\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"xtick_17\">\n     <g id=\"line2d_17\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"306.313125\" xlink:href=\"#mfd7c4472b7\" y=\"224.64\"/>\n      </g>\n     </g>\n     <g id=\"text_17\">\n      <!-- ˝ -->\n      <defs>\n       <path d=\"M 37.3125 79.984375 \nL 46 79.984375 \nL 33.890625 61.625 \nL 27.296875 61.625 \nz\nM 21 79.984375 \nL 29.296875 79.984375 \nL 18.40625 61.625 \nL 11.71875 61.625 \nz\n\" id=\"DejaVuSans-733\"/>\n      </defs>\n      <g transform=\"translate(303.813125 239.238437)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-733\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"xtick_18\">\n     <g id=\"line2d_18\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"323.053125\" xlink:href=\"#mfd7c4472b7\" y=\"224.64\"/>\n      </g>\n     </g>\n     <g id=\"text_18\">\n      <!-- / -->\n      <defs>\n       <path d=\"M 25.390625 72.90625 \nL 33.6875 72.90625 \nL 8.296875 -9.28125 \nL 0 -9.28125 \nz\n\" id=\"DejaVuSans-47\"/>\n      </defs>\n      <g transform=\"translate(321.36875 239.238437)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-47\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"xtick_19\">\n     <g id=\"line2d_19\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"339.793125\" xlink:href=\"#mfd7c4472b7\" y=\"224.64\"/>\n      </g>\n     </g>\n     <g id=\"text_19\">\n      <!-- ; -->\n      <defs>\n       <path d=\"M 11.71875 51.703125 \nL 22.015625 51.703125 \nL 22.015625 39.3125 \nL 11.71875 39.3125 \nz\nM 11.71875 12.40625 \nL 22.015625 12.40625 \nL 22.015625 4 \nL 14.015625 -11.625 \nL 7.71875 -11.625 \nL 11.71875 4 \nz\n\" id=\"DejaVuSans-59\"/>\n      </defs>\n      <g transform=\"translate(338.10875 239.238437)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-59\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"xtick_20\">\n     <g id=\"line2d_20\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"356.533125\" xlink:href=\"#mfd7c4472b7\" y=\"224.64\"/>\n      </g>\n     </g>\n     <g id=\"text_20\">\n      <!-- § -->\n      <defs>\n       <path d=\"M 18.5 45.703125 \nQ 15.4375 43.453125 13.921875 41.25 \nQ 12.40625 39.0625 12.40625 36.8125 \nQ 12.40625 33.109375 15.796875 29.859375 \nQ 19.1875 26.609375 31.390625 20.015625 \nQ 34.46875 22.21875 35.984375 24.4375 \nQ 37.5 26.65625 37.5 28.90625 \nQ 37.5 32.5625 34 35.875 \nQ 30.515625 39.203125 18.5 45.703125 \nz\nM 40.484375 71.390625 \nL 40.484375 63.375 \nQ 36.421875 65.28125 32.921875 66.234375 \nQ 29.4375 67.1875 26.703125 67.1875 \nQ 21.96875 67.1875 19.328125 65.234375 \nQ 16.703125 63.28125 16.703125 59.8125 \nQ 16.703125 55.421875 26.765625 49.8125 \nQ 28.03125 49.078125 28.71875 48.6875 \nQ 39.015625 42.875 42.203125 39.109375 \nQ 45.40625 35.359375 45.40625 30.421875 \nQ 45.40625 26.03125 43.15625 22.609375 \nQ 40.921875 19.1875 36.375 16.609375 \nQ 39.40625 14.0625 40.796875 11.390625 \nQ 42.1875 8.734375 42.1875 5.609375 \nQ 42.1875 -1.3125 37.203125 -5.40625 \nQ 32.234375 -9.515625 23.78125 -9.515625 \nQ 20.21875 -9.515625 16.453125 -8.8125 \nQ 12.703125 -8.109375 8.40625 -6.6875 \nL 8.40625 1.3125 \nQ 12.640625 -0.59375 16.25 -1.53125 \nQ 19.875 -2.484375 22.703125 -2.484375 \nQ 27.6875 -2.484375 30.4375 -0.4375 \nQ 
33.203125 1.609375 33.203125 5.328125 \nQ 33.203125 10.296875 22.40625 16.3125 \nL 21.1875 17 \nQ 10.75 22.859375 7.625 26.59375 \nQ 4.5 30.328125 4.5 35.296875 \nQ 4.5 39.75 6.765625 43.234375 \nQ 9.03125 46.734375 13.484375 49.125 \nQ 10.59375 51.265625 9.15625 54 \nQ 7.71875 56.734375 7.71875 60.109375 \nQ 7.71875 66.453125 12.59375 70.328125 \nQ 17.484375 74.21875 25.59375 74.21875 \nQ 29.15625 74.21875 32.890625 73.5 \nQ 36.625 72.796875 40.484375 71.390625 \nz\n\" id=\"DejaVuSans-167\"/>\n      </defs>\n      <g transform=\"translate(354.033125 239.238437)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-167\"/>\n      </g>\n     </g>\n    </g>\n   </g>\n   <g id=\"matplotlib.axis_2\">\n    <g id=\"ytick_1\">\n     <g id=\"line2d_21\">\n      <defs>\n       <path d=\"M 0 0 \nL -3.5 0 \n\" id=\"m83dc1f6fd1\" style=\"stroke:#000000;stroke-width:0.8;\"/>\n      </defs>\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"30.103125\" xlink:href=\"#m83dc1f6fd1\" y=\"224.64\"/>\n      </g>\n     </g>\n     <g id=\"text_21\">\n      <!-- 0.0 -->\n      <defs>\n       <path d=\"M 31.78125 66.40625 \nQ 24.171875 66.40625 20.328125 58.90625 \nQ 16.5 51.421875 16.5 36.375 \nQ 16.5 21.390625 20.328125 13.890625 \nQ 24.171875 6.390625 31.78125 6.390625 \nQ 39.453125 6.390625 43.28125 13.890625 \nQ 47.125 21.390625 47.125 36.375 \nQ 47.125 51.421875 43.28125 58.90625 \nQ 39.453125 66.40625 31.78125 66.40625 \nz\nM 31.78125 74.21875 \nQ 44.046875 74.21875 50.515625 64.515625 \nQ 56.984375 54.828125 56.984375 36.375 \nQ 56.984375 17.96875 50.515625 8.265625 \nQ 44.046875 -1.421875 31.78125 -1.421875 \nQ 19.53125 -1.421875 13.0625 8.265625 \nQ 6.59375 17.96875 6.59375 36.375 \nQ 6.59375 54.828125 13.0625 64.515625 \nQ 19.53125 74.21875 31.78125 74.21875 \nz\n\" id=\"DejaVuSans-48\"/>\n      </defs>\n      <g transform=\"translate(7.2 228.439219)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-46\"/>\n       <use x=\"95.410156\" xlink:href=\"#DejaVuSans-48\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"ytick_2\">\n     <g id=\"line2d_22\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"30.103125\" xlink:href=\"#m83dc1f6fd1\" y=\"179.207547\"/>\n      </g>\n     </g>\n     <g id=\"text_22\">\n      <!-- 0.1 -->\n      <defs>\n       <path d=\"M 12.40625 8.296875 \nL 28.515625 8.296875 \nL 28.515625 63.921875 \nL 10.984375 60.40625 \nL 10.984375 69.390625 \nL 28.421875 72.90625 \nL 38.28125 72.90625 \nL 38.28125 8.296875 \nL 54.390625 8.296875 \nL 54.390625 0 \nL 12.40625 0 \nz\n\" id=\"DejaVuSans-49\"/>\n      </defs>\n      <g transform=\"translate(7.2 183.006766)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-46\"/>\n       <use x=\"95.410156\" xlink:href=\"#DejaVuSans-49\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"ytick_3\">\n     <g id=\"line2d_23\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"30.103125\" xlink:href=\"#m83dc1f6fd1\" y=\"133.775095\"/>\n      </g>\n     </g>\n     <g id=\"text_23\">\n      <!-- 0.2 -->\n      <defs>\n       <path d=\"M 19.1875 8.296875 \nL 53.609375 8.296875 \nL 53.609375 0 \nL 7.328125 0 \nL 7.328125 8.296875 \nQ 12.9375 14.109375 22.625 23.890625 \nQ 32.328125 33.6875 34.8125 36.53125 \nQ 39.546875 41.84375 41.421875 45.53125 \nQ 43.3125 49.21875 43.3125 52.78125 \nQ 43.3125 58.59375 39.234375 62.25 \nQ 35.15625 65.921875 28.609375 65.921875 
\nQ 23.96875 65.921875 18.8125 64.3125 \nQ 13.671875 62.703125 7.8125 59.421875 \nL 7.8125 69.390625 \nQ 13.765625 71.78125 18.9375 73 \nQ 24.125 74.21875 28.421875 74.21875 \nQ 39.75 74.21875 46.484375 68.546875 \nQ 53.21875 62.890625 53.21875 53.421875 \nQ 53.21875 48.921875 51.53125 44.890625 \nQ 49.859375 40.875 45.40625 35.40625 \nQ 44.1875 33.984375 37.640625 27.21875 \nQ 31.109375 20.453125 19.1875 8.296875 \nz\n\" id=\"DejaVuSans-50\"/>\n      </defs>\n      <g transform=\"translate(7.2 137.574314)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-46\"/>\n       <use x=\"95.410156\" xlink:href=\"#DejaVuSans-50\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"ytick_4\">\n     <g id=\"line2d_24\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"30.103125\" xlink:href=\"#m83dc1f6fd1\" y=\"88.342642\"/>\n      </g>\n     </g>\n     <g id=\"text_24\">\n      <!-- 0.3 -->\n      <defs>\n       <path d=\"M 40.578125 39.3125 \nQ 47.65625 37.796875 51.625 33 \nQ 55.609375 28.21875 55.609375 21.1875 \nQ 55.609375 10.40625 48.1875 4.484375 \nQ 40.765625 -1.421875 27.09375 -1.421875 \nQ 22.515625 -1.421875 17.65625 -0.515625 \nQ 12.796875 0.390625 7.625 2.203125 \nL 7.625 11.71875 \nQ 11.71875 9.328125 16.59375 8.109375 \nQ 21.484375 6.890625 26.8125 6.890625 \nQ 36.078125 6.890625 40.9375 10.546875 \nQ 45.796875 14.203125 45.796875 21.1875 \nQ 45.796875 27.640625 41.28125 31.265625 \nQ 36.765625 34.90625 28.71875 34.90625 \nL 20.21875 34.90625 \nL 20.21875 43.015625 \nL 29.109375 43.015625 \nQ 36.375 43.015625 40.234375 45.921875 \nQ 44.09375 48.828125 44.09375 54.296875 \nQ 44.09375 59.90625 40.109375 62.90625 \nQ 36.140625 65.921875 28.71875 65.921875 \nQ 24.65625 65.921875 20.015625 65.03125 \nQ 15.375 64.15625 9.8125 62.3125 \nL 9.8125 71.09375 \nQ 15.4375 72.65625 20.34375 73.4375 \nQ 25.25 74.21875 29.59375 74.21875 \nQ 40.828125 74.21875 47.359375 69.109375 \nQ 53.90625 64.015625 53.90625 55.328125 \nQ 53.90625 49.265625 50.4375 45.09375 \nQ 46.96875 40.921875 40.578125 39.3125 \nz\n\" id=\"DejaVuSans-51\"/>\n      </defs>\n      <g transform=\"translate(7.2 92.141861)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-46\"/>\n       <use x=\"95.410156\" xlink:href=\"#DejaVuSans-51\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"ytick_5\">\n     <g id=\"line2d_25\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"30.103125\" xlink:href=\"#m83dc1f6fd1\" y=\"42.91019\"/>\n      </g>\n     </g>\n     <g id=\"text_25\">\n      <!-- 0.4 -->\n      <defs>\n       <path d=\"M 37.796875 64.3125 \nL 12.890625 25.390625 \nL 37.796875 25.390625 \nz\nM 35.203125 72.90625 \nL 47.609375 72.90625 \nL 47.609375 25.390625 \nL 58.015625 25.390625 \nL 58.015625 17.1875 \nL 47.609375 17.1875 \nL 47.609375 0 \nL 37.796875 0 \nL 37.796875 17.1875 \nL 4.890625 17.1875 \nL 4.890625 26.703125 \nz\n\" id=\"DejaVuSans-52\"/>\n      </defs>\n      <g transform=\"translate(7.2 46.709408)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-46\"/>\n       <use x=\"95.410156\" xlink:href=\"#DejaVuSans-52\"/>\n      </g>\n     </g>\n    </g>\n   </g>\n   <g id=\"line2d_26\">\n    <path clip-path=\"url(#pba0c2af0c7)\" d=\"M 0 0 \n\" style=\"fill:none;stroke:#424242;stroke-linecap:square;stroke-width:2.7;\"/>\n   </g>\n   <g id=\"line2d_27\">\n    <path 
clip-path=\"url(#pba0c2af0c7)\" d=\"M 0 0 \n\" style=\"fill:none;stroke:#424242;stroke-linecap:square;stroke-width:2.7;\"/>\n   </g>\n   <g id=\"line2d_28\">\n    <path clip-path=\"url(#pba0c2af0c7)\" d=\"M 0 0 \n\" style=\"fill:none;stroke:#424242;stroke-linecap:square;stroke-width:2.7;\"/>\n   </g>\n   <g id=\"line2d_29\">\n    <path clip-path=\"url(#pba0c2af0c7)\" d=\"M 0 0 \n\" style=\"fill:none;stroke:#424242;stroke-linecap:square;stroke-width:2.7;\"/>\n   </g>\n   <g id=\"line2d_30\">\n    <path clip-path=\"url(#pba0c2af0c7)\" d=\"M 0 0 \n\" style=\"fill:none;stroke:#424242;stroke-linecap:square;stroke-width:2.7;\"/>\n   </g>\n   <g id=\"line2d_31\">\n    <path clip-path=\"url(#pba0c2af0c7)\" d=\"M 0 0 \n\" style=\"fill:none;stroke:#424242;stroke-linecap:square;stroke-width:2.7;\"/>\n   </g>\n   <g id=\"line2d_32\">\n    <path clip-path=\"url(#pba0c2af0c7)\" d=\"M 0 0 \n\" style=\"fill:none;stroke:#424242;stroke-linecap:square;stroke-width:2.7;\"/>\n   </g>\n   <g id=\"line2d_33\">\n    <path clip-path=\"url(#pba0c2af0c7)\" d=\"M 0 0 \n\" style=\"fill:none;stroke:#424242;stroke-linecap:square;stroke-width:2.7;\"/>\n   </g>\n   <g id=\"line2d_34\">\n    <path clip-path=\"url(#pba0c2af0c7)\" d=\"M 0 0 \n\" style=\"fill:none;stroke:#424242;stroke-linecap:square;stroke-width:2.7;\"/>\n   </g>\n   <g id=\"line2d_35\">\n    <path clip-path=\"url(#pba0c2af0c7)\" d=\"M 0 0 \n\" style=\"fill:none;stroke:#424242;stroke-linecap:square;stroke-width:2.7;\"/>\n   </g>\n   <g id=\"line2d_36\">\n    <path clip-path=\"url(#pba0c2af0c7)\" d=\"M 0 0 \n\" style=\"fill:none;stroke:#424242;stroke-linecap:square;stroke-width:2.7;\"/>\n   </g>\n   <g id=\"line2d_37\">\n    <path clip-path=\"url(#pba0c2af0c7)\" d=\"M 0 0 \n\" style=\"fill:none;stroke:#424242;stroke-linecap:square;stroke-width:2.7;\"/>\n   </g>\n   <g id=\"line2d_38\">\n    <path clip-path=\"url(#pba0c2af0c7)\" d=\"M 0 0 \n\" style=\"fill:none;stroke:#424242;stroke-linecap:square;stroke-width:2.7;\"/>\n   </g>\n   <g id=\"line2d_39\">\n    <path clip-path=\"url(#pba0c2af0c7)\" d=\"M 0 0 \n\" style=\"fill:none;stroke:#424242;stroke-linecap:square;stroke-width:2.7;\"/>\n   </g>\n   <g id=\"line2d_40\">\n    <path clip-path=\"url(#pba0c2af0c7)\" d=\"M 0 0 \n\" style=\"fill:none;stroke:#424242;stroke-linecap:square;stroke-width:2.7;\"/>\n   </g>\n   <g id=\"line2d_41\">\n    <path clip-path=\"url(#pba0c2af0c7)\" d=\"M 0 0 \n\" style=\"fill:none;stroke:#424242;stroke-linecap:square;stroke-width:2.7;\"/>\n   </g>\n   <g id=\"line2d_42\">\n    <path clip-path=\"url(#pba0c2af0c7)\" d=\"M 0 0 \n\" style=\"fill:none;stroke:#424242;stroke-linecap:square;stroke-width:2.7;\"/>\n   </g>\n   <g id=\"line2d_43\">\n    <path clip-path=\"url(#pba0c2af0c7)\" d=\"M 0 0 \n\" style=\"fill:none;stroke:#424242;stroke-linecap:square;stroke-width:2.7;\"/>\n   </g>\n   <g id=\"line2d_44\">\n    <path clip-path=\"url(#pba0c2af0c7)\" d=\"M 0 0 \n\" style=\"fill:none;stroke:#424242;stroke-linecap:square;stroke-width:2.7;\"/>\n   </g>\n   <g id=\"line2d_45\">\n    <path clip-path=\"url(#pba0c2af0c7)\" d=\"M 0 0 \n\" style=\"fill:none;stroke:#424242;stroke-linecap:square;stroke-width:2.7;\"/>\n   </g>\n   <g id=\"patch_23\">\n    <path d=\"M 30.103125 224.64 \nL 30.103125 7.2 \n\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\n   </g>\n   <g id=\"patch_24\">\n    <path d=\"M 364.903125 224.64 \nL 364.903125 7.2 \n\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\n   
</g>\n   <g id=\"patch_25\">\n    <path d=\"M 30.103125 224.64 \nL 364.903125 224.64 \n\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\n   </g>\n   <g id=\"patch_26\">\n    <path d=\"M 30.103125 7.2 \nL 364.903125 7.2 \n\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\n   </g>\n  </g>\n </g>\n <defs>\n  <clipPath id=\"pba0c2af0c7\">\n   <rect height=\"217.44\" width=\"334.8\" x=\"30.103125\" y=\"7.2\"/>\n  </clipPath>\n </defs>\n</svg>\n",
+      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAOW0lEQVR4nO3de7CcdX3H8feHZKJV0XbMqS2QMVQytRllqkawU2tph5ZESuIFBQbRVi5CyRRKb8xIGYZKp8DU0SkZFNFOoVWkAUsqQVptaYc61Bwuo6KikaElWOVQKf5hFUO//WM3cQknOXs72c3P9+ufs8/u8/z2O+Tw3s2zl6SqkCQd+A6a9ACSpPEw6JLUCIMuSY0w6JLUCIMuSY1YOqk7Xr58ea1cuXJSdy9JB6S77777saqame+2iQV95cqVzM7OTuruJemAlOQ/9nabp1wkqREGXZIaYdAlqREGXZIaYdAlqREGXZIaYdAlqREGXZIaYdAlqRET+6ToLnNX//XAx8yc87ZFmESSDmw+Q5ekRhh0SWqEQZekRhh0SWqEQZekRhh0SWqEQZekRhh0SWqEQZekRhh0SWqEQZekRhh0SWqEQZekRhh0SWqEQZekRhh0SWqEQZekRhh0SWqEQZekRhh0SWqEQZekRvQV9CRrkzyQZHuSC/ex35uTVJI14xtRktSPBYOeZAmwCVgHrAZOSbJ6nv0OBs4D/n3cQ0qSFtbPM/SjgO1V9WBVPQncAGyYZ78/AS4HvjfG+SRJfeon6IcCD/ds7+het1uSVwIrqurWfS2U5Kwks0lm5+bmBh5WkrR3I78omuQg4L3A7y20b1VdU1VrqmrNzMzMqHctSerRT9AfAVb0bB/WvW6Xg4GXAXckeQh4DbDFF0Ylaf/qJ+jbgFVJDk+yDDgZ2LLrxqp6oqqWV9XKqloJ3AWsr6rZRZlYkjSvBYNeVTuBjcDtwJeBG6vq/iSXJlm/2ANKkvqztJ+dqmorsHWP6y7ey77HjD6WJGlQflJUkhph0CWpEQZdkhph0CWpEQZdkhph0CWpEQZdkhph0CWpEQZdkhph0CWpEX199H+affPq9wx13E+dc9GYJ5GkyfIZuiQ1wqBLUiMMuiQ1wqBLUiMMuiQ1wqBLUiMMuiQ1wqBLUiMMuiQ1wqBLUiMMuiQ1wqBLUiMMuiQ1wqBLUiMMuiQ1wqBLUiMMuiQ1wqBLUiMMuiQ1wqBLUiMMuiQ1wqBLUiMMuiQ1wqBLUiMMuiQ1oq+gJ1mb5IEk25NcOM/tZyf5QpL7ktyZZPX4R5Uk7cuCQU+yBNgErANWA6fME+yPVtXLq+rngSuA9459UknSPvXzDP0oYHtVPVhVTwI3ABt6d6iq7/RsPheo8Y0oSerH0j72ORR4uGd7B3D0njslORe4AFgG/OpYppMk9W1sL4pW1aaqegnwR8BF8+2T5Kwks0lm5+bmxnXXkiT6C/ojwIqe7cO61+3NDcAb5ruhqq6pqjVVtWZmZqb/KSVJC+on6NuAVUkOT7IMOBnY0rtDklU9m8cDXxvfiJKkfix4Dr2qdibZCNwOLAE+UlX3J7kUmK2qLcDGJMcCPwAeB96xmENLkp6pnxdFqaqtwNY9rru45/J5Y55LkjQgPykqSY0w6JLUCIMuSY0w6JLUCIMuSY0w6JLUCIMuSY0w6JLUCIMuSY0w6JLUCIMuSY0w6JLUCIMuSY0w6JLUCIMuSY0w6JLUCIMuSY0w6JLUCIMuSY0w6JLUCIMuSY0w6JLUCIMuSY0w6JLUCIMuSY0w6JLUCIMuSY0w6JLUCIMuSY0w6JLUCIMuSY0w6JLUCIMuSY0w6JLUCIMuSY0w6JLUCIMuSY3oK+hJ1iZ5IMn2JBfOc/sFSb6U5PNJPpPkxeMfVZK0LwsGPckSYBOwDlgNnJJk9R673Qusqaojgc3AFeMeVJK0b/08Qz8K2F5VD1bVk8ANwIbeHarqn6vqu93Nu4DDxjumJGkh/QT9UODhnu0d3ev25nTgtvluSHJWktkks3Nzc/1PKUla0FhfFE3yNmANcOV8t1fVNVW1pqrWzMzMjPOuJelH3tI+9nkEWNGzfVj3uqdJcizwbuCXq+r74xlPktSvfp6hbwNWJTk8yTLgZGBL7w5JXgF8EFhfVY+Of0xJ0kIWDHpV7QQ2ArcDXwZurKr7k1yaZH13tyuB5wF/m+S+JFv2spwkaZH0c8qFqtoKbN3juot7Lh875rkkSQPyk6KS1AiDLkmNMOiS1AiDLkmNMOiS1AiDLkmNMOiS1AiDLkmNMOiS1AiDLkmNMOiS1AiDLkmNMOiS1AiDLkmNMOiS1AiDLkmNMOiS1AiDLkmNMOiS1AiDLkmNMOiS1AiDLkmNMOiS1AiDLkmNMOiS1AiDLkmNMOiS1AiDLkmNMOiS1AiDLkmNMOiS1AiDLkmNMOiS1AiDLkmNMOiS1AiDLkmN6CvoSdYmeSDJ9iQXznP765Lck2RnkhPHP6YkaSELBj3JEmATsA5YDZySZPUeu/0n8JvAR8c9oCSpP0v72OcoYHtVPQiQ5AZgA/ClXTtU1UPd2/5vEWaUJPWhn1MuhwIP92zv6F43sCRnJZlNMjs3NzfMEpKkvdivL4pW1TVVtaaq1szMzOzPu5ak5vUT9EeAFT3bh3WvkyRNkX6Cvg1YleTwJMuAk4EtizuWJGlQCwa9qnYCG4HbgS8DN1bV/UkuTbIeIMmrk+wA3gJ8MMn9izm0JOmZ+nmXC1W1Fdi6x3UX91zeRudUjCRpQvykqCQ1wqBLUiMMuiQ1wqBLUiMMuiQ1wqBLUiMMuiQ1wqBLUiMMuiQ1wqBLUiMMuiQ1wqBLUiMMuiQ1wqBLUiMMuiQ1wqBLUiMMuiQ1wqBLUiMMuiQ1wqBLUiMMuiQ1wqBLUiMMuiQ1wqBLUiOWTnqAafCVTRuGOu6l594y5kkkaXg+Q5ekRhh0SWqEp1zG5I4PHT/wMceceesiTCLpR5VBnyKb/3LtwMec+FufWoRJJB2IPOUiSY0w6JLUCIMuSY0w6JLUCIMuSY3wXS4N+eD1xw113LtOu33Mk0iaBIOup7nkxuEeFC556w8fFNbd8uaBj79tw01D3a+kHzLomkqv/8R7Bj5m6xsv2n35+JuvHup+b33TOUMdJ02DvoKeZC3wfmAJcG1V/dketz8LuA54FfDfwElV9dB4R5X2r9/Y/DdDHffJE0/dfXn95r8f+PgtJ57wtO033nTnwGt84s2v3X35pJu3D3w8wMffdMRQx2lyFgx6kiXAJuDXgB3AtiRbqupLPbudDjxeVUckORm4HDhpMQaWtP9t+sS3hjru3De+aPfl2z7+2MDHrztp+dO277320YHXeMUZPznwMQeqfp6hHwVsr6oHAZLcAGwAeoO+Abike3kzcFWSVFWNcVZJGsl/XfHIUMf99B8euvvyt95398DHv+j8Vw11v4PKQs1NciKwtqrO6G6fBhxdVRt79vlid58d3e2vd/d5bI+1zgLO6m7+LPDAPu56OTD4Q3qba0
zDDNOyxjTMMC1rTMMM07LGNMywv9Z4cVXNzHfDfn1RtKquAa7pZ98ks1W1ZpT7a2WNaZhhWtaYhhmmZY1pmGFa1piGGaZhjX4+WPQIsKJn+7DudfPuk2Qp8AI6L45KkvaTfoK+DViV5PAky4CTgS177LMFeEf38onAP3n+XJL2rwVPuVTVziQbgdvpvG3xI1V1f5JLgdmq2gJ8GLg+yXbg23SiP6q+Ts38iKwxDTNMyxrTMMO0rDENM0zLGtMww8TXWPBFUUnSgcEv55KkRvjRf0kHtCQvofM5mK8B26rqtslONDk+Q19ESV6a5LNJvpDkX5IsX/iodiW5JMnvT3iGZUn+tfturFHW+bHun+mSPvefSXJnki8meUPP9bckOWSUWYaV5I4kx+z6OYkZxuQ84C+AVaPEPMkHkvzi+Mba/wz64ntbVb0c+Cxw9qSHOZAleSjJyiR3DLtGVT0JfIbRv5rincDNVfVUn/ufAnyAzievzwdIcgJwb1V9Y8RZhvVC4PM9Pw9IVfU7VfW5qjptxKVeA9w1jpkGlWRJko8l2Z7kviSHLnzUMxn0RVRVX9n1lQnAs4DvTXIe7fZ3wKkL7rVvpwK3DLD/D4Dn0Pk9eKr7N4TzgStGnGMoSZ4PfArYCXyqqr49iTmmRZKfA746wAP03tbZOuTfuI4AjqyqI4CjGfLTpgZ9P0hyHLAOuHaCMwz7izZN5oCn6Lw1dhRfBF497MHdz2P8zIDfKPpROt959I/AnwK/DVxfVd8ddo5RVNV3quoPdv2cxAxTZh2dB7iRVNXrh/wb19eBpUn+HDioqr4/zP0b9EWW5CA679NfX1X/M6k5RvhFmxpV9eqqeriq3jTiOk8BTyY5eMgllgMD/VlW1RNVdXz3I933ACcAm5N8KMnmJL8w5Cwaj+MYQ9BH8Ot0nvDdBXw6yeHDLOK7XBbfIcATVfW1YQ5Oci5wZndzIlHeYwaAD1XVpkHXqapLxjbU6EY5Bfa/wLNHuO8/Bi6jc179TjrfUHoznahMtXl+F+Yz1O/HpCR5DvDjE37CsxL4iaq6MsmjdE7HnTfoIlP9waIknwHeXlXDfeflFEjyXOCXqmqSj/7qkeSFwL9V1UtHWONhOu+qGOhBIckq4LKqemuS8+icPrqJznns1w241nXAVcBG4Kqq+twgx6sjyfF0/h+9cAxrDdWs7j8SdBUwA3wT+GRVfXLQ+5/aUy7dUxVHMML50ik5b/wC4IwJzzAV/y2SnJ3k7ZOcoetXgFtHXOMfgNcuuNczXQa8u3v5Y8A5dL4v6f1DrHUk8I2enweM3t/H7u/F2d3LhyTZup/HGcv581GaVVXfr6oz6UR9HTD4l64zxc/Qk7wMeGdVXTDpWdSWJDcDF1bVV0dY45XA747hrXLD3v/z6bw2czrw4ap6yyTmaEGSe+j8+w0/GHGdoZuV5NnA84DH6Ty4XF5Vnx54nWkNurQYdn1jaFVdN4a13gn81ahvdZO6pwFvonPK5S7gXVW1c+B1DLoktWFqz6FLkgZj0CWpEQZdkhph0CWpEQZdkhrx/6i40ak7FeOpAAAAAElFTkSuQmCC\n"
+     },
+     "metadata": {
+      "needs_background": "light"
+     }
+    }
+   ],
+   "source": [
+    "NUM_MOST_COMMON = 20\n",
+    "\n",
+    "most_common = np.array(total.most_common())[:NUM_MOST_COMMON]\n",
+    "most_common_labels = most_common[:, 0]\n",
+    "most_common_values = np.array(most_common[:, 1], dtype=np.float) / sum(total.values())\n",
+    "\n",
+    "sns.barplot(most_common_labels, most_common_values)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ]
+}
\ No newline at end of file
diff --git a/params.yaml b/params.yaml
index f554c487ee7ad93e9f9b25a0efe10cdf0e7c2525..356c7058bac360dcbb2f9c561be8b7a94294003b 100644
--- a/params.yaml
+++ b/params.yaml
@@ -27,74 +27,67 @@ actions:
         num_workers: 24
         worker_memory_limit: "2GB"
 
-    training_base:
-        learning_rate: 0.0001
-        num_epochs: 5
-        batch_size: 2
-        batch_buffer_size: 100
-        save_step: 50
-        max_training_time: null
-        loss_averaging_span: 1000
-        fresh_start: false
-        device: "cuda:0"
-
-    testing_base:
-        limit: None
-        batch_size: 1
-        device: "cuda:0"
+    training:
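+        # Per-model training settings; keys match the -m argument of src.pipelines.actions_based.train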
+        actions_base:
+            learning_rate: null
+            batch_size: 2
+            save_step: 10000
+            start_checkpoint: null
+            num_workers: 4
+            gpu_devices: [0]
+
+        actions_ensemble:
+            learning_rate: 0.00001
+            num_models: 10
+            embedding_size: 200
+            num_heads: 2
+            dropout: 0.1
+            feedforward_neurons: 300
+            batch_size: 2
+            save_step: 10000
+            start_checkpoint: null
+            num_workers: 4
+            gpu_devices: [0]
+
+        actions_mixed:
+            embedding_size: 768
+            num_heads: 12
+            num_layers: 6
+            dropout: 0.1
+            feedforward_neurons: 1000
+            learning_rate: 0.0001
+            batch_size: 2
+            save_step: 10000
+            loss_averaging_span: 1000
+            start_checkpoint: null
+            num_workers: 4
+            gpu_devices: [0]
+
+        actions_restricted:
+            learning_rate: 0.0001
+            num_epochs: 5
+            batch_size: 2
+            save_step: 1000
+            start_checkpoint: null
+            num_workers: 4
+            gpu_devices: [0]
+
+    testing:
+        actions_base:
+            limit: None
+            batch_size: 1
+            device: "cuda:0"
+
+        actions_restricted:
+            limit: None
+            batch_size: 1
+            device: "cuda:0"
+
+        actions_mixed:
+            limit: None
+            batch_size: 1
+            device: "cuda:0"
 
-    training_restricted:
-        learning_rate: 0.0001
-        num_epochs: 5
-        batch_size: 2
-        batch_buffer_size: 100
-        save_step: 1000
-        max_training_time: null
-        loss_averaging_span: 1000
-        fresh_start: true
-        device: "cuda:0"
-
-    test_restricted:
-        limit: None
-        batch_size: 1
-        device: "cuda:0"
-
-    training_mixed:
-        embedding_size: 768
-        num_heads: 12
-        num_layers: 6
-        dropout: 0.1
-        feedforward_neurons: 1000
-        learning_rate: 0.0001
-        num_epochs: 5
-        batch_size: 2
-        batch_buffer_size: 1000
-        save_step: 10000
-        max_training_time: null
-        loss_averaging_span: 1000
-        fresh_start: true
-        device: "cuda:0"
-
-    test_mixed:
-        limit: None
-        batch_size: 1
-        device: "cuda:0"
-
-    training_ensemble:
-        learning_rate: 0.00001
-        num_models: 10
-        embedding_size: 200
-        num_heads: 2
-        dropout: 0.1
-        feedforward_neurons: 300
-        num_epochs: 5
-        batch_size: 2
-        batch_buffer_size: 1000
-        save_step: 10000
-        max_training_time: null
-        loss_averaging_span: 1000
-        fresh_start: true
-        device: "cuda:0"
 translations:
     extraction:
         num_partitions: 2_000
diff --git a/punctuate.py b/punctuate.py
index 8eb4bdc04369776859a6bf4e9540463ed28c8f51..4ae7c735911b693d1590f79ead803d7fcee26e1b 100755
--- a/punctuate.py
+++ b/punctuate.py
@@ -1,19 +1,17 @@
 import argparse
+from src.models.model_factory import MODELS_MAP
+from src.predictors.factory import ACTION_MODEL_PREDICTORS
+
 from src.pipelines.actions_based.utils import max_suppression
 from src.pipelines.actions_based.processing import (
     ACTIONS_KEYS,
     recover_text,
     token_labels_to_word_labels,
 )
-from src.models.interfaces import ActionsModel
-from typing import Dict
 
 import numpy as np
 import torch
 
-from src.models.actions_model_base import ActionsModelBase
-from src.models.actions_model_mixed import ActionsModelMixed
-from src.models.actions_model_restricted import ActionsModelRestricted
 from src.utils import (
     PROJECT_ROOT,
     input_preprocess,
@@ -21,12 +19,6 @@ from src.utils import (
 )
 import colored
 
-SUPPORTED_MODELS: Dict[str, ActionsModel] = {
-    "base": ActionsModelBase,
-    "restricted": ActionsModelRestricted,
-    "mixed": ActionsModelMixed,
-}
-
 
 def print_highlighted(text: str, word_labels: np.ndarray, action_name: str) -> None:
     label_id = np.argwhere(np.array(ACTIONS_KEYS) == action_name)
@@ -46,19 +38,18 @@ if __name__ == "__main__":
         "-a",
         "--architecture",
         required=True,
-        choices=SUPPORTED_MODELS.keys(),
+        choices=MODELS_MAP.keys(),
         help="Model architecture",
     )
     parser.add_argument(
-        "-d",
-        "--directory",
+        "-m",
+        "--model",
         required=True,
         help="Directory where trained model is located, relative to project root",
     )
     parser.add_argument(
         "-i", "--input", required=True, type=str, help="Input text file"
     )
-    parser.add_argument("-m", "--model", default="final", help="Pretrained model name")
     parser.add_argument(
         "-l",
         "--highlight",
@@ -80,22 +71,24 @@ if __name__ == "__main__":
 
     print(f"Loading model {args.model}...")
     device = torch.device(args.device)
-    model_location = f"{PROJECT_ROOT}/{args.directory}"
-    model_type = SUPPORTED_MODELS[args.architecture]
-    model = model_type.load(model_location, args.model, device)
+    model_location = f"{PROJECT_ROOT}/{args.model}"
+    model_type = MODELS_MAP[args.architecture]
+    model = model_type.load_from_checkpoint(model_location, map_location=device)
     model.train(False)
 
+    predictor = ACTION_MODEL_PREDICTORS[args.architecture](model)
+
     print("Loading text...")
     with open(args.input, "r") as f:
         text = f.read()
 
     print("Inferencing...")
-    tokenizer = model.tokenizer()
+    tokenizer = predictor.tokenizer()
     data = input_preprocess(output_preprocess(text))
     data_tokenized = tokenizer(data, return_tensors="pt")
 
     predictions = (
-        model.predict_raw(
+        predictor.predict_raw(
             data_tokenized["input_ids"].to(device),
             data_tokenized["attention_mask"].to(device),
         )
diff --git a/requirements.txt b/requirements.txt
index 068c99e062e0f3457aed03cb4be1c6338c17f29e..9e0e50ad1240ea6de4887fcac3dd5f2e8ac12c65 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -61,3 +61,5 @@ zict==2.0.0
 scikit-learn==0.23.2
 nlp_ws==0.6
 colored==1.4.2
+pytorch-lightning==0.9.0
+comet-ml==3.2.0
\ No newline at end of file
diff --git a/src/models/actions_model_base.py b/src/models/actions_model_base.py
index 96653fe1f41751b44812315018eb5336cbe2922b..48e3cf00d03cf96fbe069fd3fd9607eb310e8959 100644
--- a/src/models/actions_model_base.py
+++ b/src/models/actions_model_base.py
@@ -1,71 +1,30 @@
 from __future__ import annotations
 
-import os
-from dataclasses import dataclass
+from typing import Optional
 
 import numpy as np
+import pytorch_lightning as pl
 import torch
-import torch.nn as nn
-from torch.nn.modules.loss import BCEWithLogitsLoss
+import torch.nn.functional as F
+from pytorch_lightning.core.lightning import LightningModule
 from transformers.configuration_utils import PretrainedConfig
 from transformers.modeling_bert import BertForTokenClassification
 from transformers.tokenization_bert import BertTokenizerFast
 
-from src.models.interfaces import ActionsModel
-from src.pipelines.actions_based.processing import (
-    ACTIONS_KEYS,
-    action_vector,
-    last_stop_label,
-    recover_text,
-    token_labels_to_word_labels,
-)
-from src.pipelines.actions_based.utils import max_suppression
-from src.utils import (
-    get_device,
-    pickle_read,
-    pickle_save,
-    prepare_folder,
-    yaml_serializable,
-)
+from src.models.interfaces import ActionsModelCreator
+from src.pipelines.actions_based.processing import ACTIONS_KEYS
+from src.utils import ifnone, onecycle_shed
 
 
-@dataclass
-class ActionsModelBaseParams:
-    """
-    Parameters for ActionsModelBase initialization
-
-    Args:
-        base_model (str): Name of base model
-        num_labels (int): Length of action vector
-
-    """
-
-    base_model: str
-    num_labels: int = len(ACTIONS_KEYS)
-
-
-@yaml_serializable
-@dataclass
-class ActionsModelBaseRuntimeParams:
-    """
-    Parameters for ActionsModelBase during runtime interference
-
-    Args:
-        threshold (float): minimum confidence for applying action
-        chunksize (int): Maximum number of chunks to perform inference on
-    """
-
-    threshold: float = 0.9
-    chunksize: int = 500
-
-
-class ActionsModelBase(ActionsModel):
+class ActionsModelBase(LightningModule):
     """Model based on simple multilabel per-token classifiaction. Each token is binarly classified in n-dimensions"""
 
     def __init__(
         self,
-        params: ActionsModelBaseParams,
-        runtime: ActionsModelBaseRuntimeParams = ActionsModelBaseRuntimeParams(),
+        base_model: str,
+        num_labels: int = len(ACTIONS_KEYS),
+        learning_rate: float = 1e-5,
+        loss_weights: Optional[np.ndarray] = None,
     ) -> None:
         """Initializes actions model
 
@@ -74,12 +33,18 @@ class ActionsModelBase(ActionsModel):
-            runtime (ActionsModelBaseRuntimeParams): Params defining model's runtime inference
+            base_model (str): Name of the pretrained base model
+            num_labels (int): Length of the action vector
+            learning_rate (float): Initial learning rate for AdamW
+            loss_weights (Optional[np.ndarray]): Per-action pos_weight values for the BCE loss
         """
         super(ActionsModelBase, self).__init__()
-        self.params = params
-        self.runtime = runtime
 
-        self._tokenizer = BertTokenizerFast.from_pretrained(params.base_model)
-        config = PretrainedConfig.from_pretrained(params.base_model)
-        config.num_labels = params.num_labels
+        self.learning_rate = learning_rate
+
+        self.learning_rate_min = None
+        self.learning_rate_max = None
+
+        self.tokenizer = BertTokenizerFast.from_pretrained(base_model)
+        config = PretrainedConfig.from_pretrained(base_model)
+        config.num_labels = num_labels
+
+        self.save_hyperparameters()
+        if loss_weights is None:  # guard: torch.tensor(None) raises on the default
+            loss_weights = np.ones(num_labels)
+        self.register_buffer("weight_tensor", torch.tensor(loss_weights, dtype=torch.float))
 
         self.core = BertForTokenClassification(config)
 
@@ -99,118 +64,67 @@ class ActionsModelBase(ActionsModel):
 
         return y_pred
 
-    def predict_raw(
-        self, input_ids: torch.Tensor, attention_mask: torch.Tensor
-    ) -> torch.Tensor:
-        """Function that maps input_ids tensors into per-token labels
-
-        Args:
-            input_ids (torch.Tensor): Token ids of input. Shape BxL
-            attention_mask (torch.Tensor): Attention mask of tokens. Shape BxL
-
-        Returns:
-            torch.Tensor: Per-token action-vector labels. Shape BxLxA
-        """
-
-        return self.forward(input_ids, attention_mask=attention_mask).sigmoid()
-
-    def predict(self, text: str) -> str:
-        text = text.strip()
-
-        device = get_device(self)
-
-        tokenizer = self.tokenizer()
-        tokens = tokenizer(text, return_tensors="pt")["input_ids"].to(device)
-        output = None
-
-        index_start = 0
-        while index_start < len(tokens[0]):
-            index_end = min(index_start + self.runtime.chunksize, len(tokens[0]))
+    @onecycle_shed
+    def configure_optimizers(self):
+        optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate)
 
-            tokens_chunk = tokens[:, index_start:index_end]
-            attention_mask = torch.ones_like(tokens_chunk).to(device)
+        return optimizer
 
-            actions = (
-                self.predict_raw(tokens_chunk, attention_mask).detach().cpu().numpy()
-            )
-            actions_suppresed = max_suppression(actions, self.runtime.threshold)[0]
+    def training_step(self, batch, batch_idx):
+        loss = self._common_step(batch)
+        result = pl.TrainResult(minimize=loss)
 
-            offset = last_stop_label(actions_suppresed, action_vector(["dot"]))
+        result.log("train_loss", loss)
 
-            # Prevent infinite loop
-            if (offset is None) or (offset == 0):
-                offset = index_end - index_start
+        return result
 
-            if output is None:
-                output = actions[0, 0:offset]
-            else:
-                output = np.concatenate([output, actions[0, 0:offset]], axis=0)
+    def validation_step(self, batch, batch_idx):
+        loss = self._common_step(batch)
+        result = pl.EvalResult(loss)
+        result.valid_batch_loss = loss
+        result.log("valid_loss", loss)
 
-            index_start += offset
+        return result
 
-        assert len(output) == len(tokens[0])
+    def validation_epoch_end(self, outputs):
+        avg_loss = outputs.valid_batch_loss.mean()
+        result = pl.EvalResult(checkpoint_on=avg_loss)
+        result.log("valid_loss", avg_loss, on_epoch=True, prog_bar=True)
 
-        word_labels = token_labels_to_word_labels(text, output[1:-1], tokenizer)
-        actions = max_suppression(
-            np.expand_dims(word_labels, 0), self.runtime.threshold
-        )[0]
+        return result
 
-        return recover_text(text, actions)
+    def _common_step(self, batch):
+        inputs = batch["source"]
+        outputs = batch["target"].type(torch.float)  # parquet targets arrive as float64
+        attentions_mask = batch["attention_mask"]
 
-    def tokenizer(self) -> BertTokenizerFast:
-        return self._tokenizer
+        outputs_pred = self(input_ids=inputs, attention_mask=attentions_mask)
 
-    def save(self, dir: str, name: str, runtime: bool = True) -> None:
-        prepare_folder(dir)
-        torch.save(self.state_dict(), f"{dir}/{name}.model")
-        pickle_save(self.params, f"{dir}/{name}.config")
+        loss = F.binary_cross_entropy_with_logits(
+            outputs_pred, outputs, pos_weight=self.weight_tensor
+        )
+        return loss
 
-        if runtime:
-            self.runtime.save_yaml(f"{dir}/{name}.runtime.yaml")
+    def _test_step(self, batch, batch_idx):
+        raise NotImplementedError("Test step not yet implemented")
 
-    @staticmethod
-    def load(dir: str, name: str, device: torch.device) -> ActionsModelBase:
-        params = pickle_read(f"{dir}/{name}.config")
-        if os.path.exists(f"{dir}/{name}.runtime.yaml"):
-            runtime = ActionsModelBaseRuntimeParams.load_yaml(
-                f"{dir}/{name}.runtime.yaml"
-            )
-        else:
-            runtime = ActionsModelBaseRuntimeParams()
 
-        model = ActionsModelBase(params, runtime).to(device)
-        model.load_state_dict(torch.load(f"{dir}/{name}.model", map_location=device))
+class ActionsModelBaseCreator(ActionsModelCreator):
+    def setup(self, config: dict, base_model: str, seed: int) -> None:
+        self.base_model = base_model
+        self.seed = seed
 
-        return model
+        self.pos_weight = None
+        self.learning_rate = ifnone(config["learning_rate"], 1e-5)
 
+    def config_stats(self, pos_examples: np.ndarray, neg_examples: np.ndarray):
+        self.pos_weight = neg_examples / pos_examples
 
-class ActionsModelBaseLoss(nn.Module):
-    """Proposed loss for ActionsModelBase model"""
-
-    def __init__(self, prior_inverse_odds: torch.Tensor) -> None:
-        """Initializes ActionsModelBaseLoss
-
-        Args:
-            prior_odds (torch.Tensor): Negative to positive ratio of each action vector
-                entry in dataset. Shape A
-        """
-        super(ActionsModelBaseLoss, self).__init__()
-
-        self.core = BCEWithLogitsLoss(pos_weight=prior_inverse_odds)
-
-    def forward(
-        self,
-        predicted_action_vector_logits: torch.Tensor,
-        true_action_vector: torch.Tensor,
-    ) -> torch.Tensor:
-        """Computes ActionsModelBase loss
-
-        Args:
-            true_action_vector (torch.Tensor): Logits predicted by ActionsModelBase model. Shape MxBxLxA
-            predicted_action_vector_logits (torch.Tensor): Target labels. Shape BxLxA
-
-        Returns:
-            torch.Tensor: Computed loss.
-        """
+    def create_model(self) -> LightningModule:
+        assert self.pos_weight is not None
+        return ActionsModelBase(
+            self.base_model, len(ACTIONS_KEYS), self.learning_rate, self.pos_weight
+        )
 
-        return self.core(predicted_action_vector_logits, true_action_vector)
+    def load_model(self, checkpoint: str) -> LightningModule:
+        return ActionsModelBase.load_from_checkpoint(checkpoint)
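
Editor's note: the refactor keeps the loss semantics of the removed ActionsModelBaseLoss, binary cross-entropy with logits in which each action dimension is reweighted by its negative-to-positive ratio, so rare actions (question marks, for instance) still produce a useful gradient. A self-contained sketch with illustrative counts only:

    import torch
    import torch.nn.functional as F

    # Illustrative per-action counts: [uppercase, dot, question mark]
    pos_examples = torch.tensor([3000.0, 800.0, 40.0])
    neg_examples = torch.tensor([7000.0, 9200.0, 9960.0])
    pos_weight = neg_examples / pos_examples  # what config_stats computes, shape A

    logits = torch.randn(2, 5, 3)                     # BxLxA model outputs
    targets = torch.randint(0, 2, (2, 5, 3)).float()  # BxLxA binary action labels

    # pos_weight rescales the positive term of each action dimension
    loss = F.binary_cross_entropy_with_logits(logits, targets, pos_weight=pos_weight)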
diff --git a/src/models/actions_model_ensemble.py b/src/models/actions_model_ensemble.py
index 40e8bf276753dd76c13b9b28928cfc4d9fdff99b..bf15cab5b5a532f59dd3e411c087e1587b4a775c 100644
--- a/src/models/actions_model_ensemble.py
+++ b/src/models/actions_model_ensemble.py
@@ -1,80 +1,36 @@
-import os
-from dataclasses import dataclass
-from typing import Optional
-
 import numpy as np
+import pytorch_lightning as pl
 import torch
-from torch import embedding
 import torch.nn as nn
-from torch.nn.modules.loss import BCEWithLogitsLoss
+import torch.nn.functional as F
+from pytorch_lightning.core.lightning import LightningModule
 from transformers.tokenization_bert import BertTokenizerFast
 
-from src.models.common import PositionalEncoding, generate_square_subsequent_mask
-from src.models.interfaces import PunctuationModel
-from src.pipelines.actions_based.processing import (
-    ACTIONS_KEYS,
-    action_vector,
-    recover_text,
-    token_labels_to_word_labels,
-)
-from src.utils import pickle_read, pickle_save, prepare_folder, yaml_serializable
-
-
-@dataclass
-class ActionsModelEnsembleParams:
-    """
-    Parameters for initializing ActionsModelMixed
-
-    Params:
-        base_tokenizer (str): Name of pretrained tokenizer
-        vocab_size (int): Number of tokens in tokenizer dictionary
-        embedding_size (int, optional): Shape of word and punctuation embeddings. Defaults to 200.
-        num_heads (int, optional): Number of heads in multiheaded attention. Defaults to 4.
-        num_layers (int, optional): Number of both decoded and encoder layers. Defaults to 2.
-        feedforward_neurons (int, optional): Size of feed-forward neural network at the end of encoder/decoder. Defaults to 200.
-        num_labels (int, optional): Action-vector size. Defaults to len(ACTIONS_KEYS).
-        max_len (int, optional): Maxium length of sequence. Defaults to 500.
-        dropout (float, optional): Dropout ratio. Defaults to 0.1.
-    """
-
-    base_tokenizer: str
-    num_models: int = 10
-    embedding_size: int = 200
-    num_heads: int = 4
-    feedforward_neurons: int = 200
-    num_labels: int = len(ACTIONS_KEYS)
-    max_len: int = 500
-    dropout: float = 0.1
-
-
-@yaml_serializable
-@dataclass
-class ActionsModelEnsembleRuntimeParams:
-    """
-    Parameters for ActionsModelMixed during runtime interference
-
-    Args:
-        threshold (float): minimum confidence for applying action
-        chunksize (int): Maximum number of chunks to perform inference on
-    """
-
-    threshold: float = 0.9
-    max_cond_len: Optional[int] = 500
+from src.models.common import PositionalEncoding
+from src.models.interfaces import ActionsModelCreator
+from src.pipelines.actions_based.processing import ACTIONS_KEYS
+from src.utils import ifnone, onecycle_shed
 
 
 class ActionsModelEnsembleBase(nn.Module):
     """Encoder-decoder based model with unpunctuated token sequence as input and array of action-vectors as output"""
 
-    def __init__(self, num_labels: int, vocab_size: int, embedding_size: int, max_len: int, num_heads: int, feedforward_hidden: int, dropout: float) -> None:
+    def __init__(
+        self,
+        num_labels: int,
+        vocab_size: int,
+        embedding_size: int,
+        max_len: int,
+        num_heads: int,
+        feedforward_hidden: int,
+        dropout: float,
+    ) -> None:
         """Initializes mixed model
 
         Args:
             params (ActionsModelMixedParams): Parameters for model
         """
         super(ActionsModelEnsembleBase, self).__init__()
-
-        self._tokenizer = None
-
         self.num_labels = num_labels
 
         # Word embedder
@@ -87,18 +43,13 @@ class ActionsModelEnsembleBase(nn.Module):
 
         # Sentence encoder
         self.encoder = nn.TransformerEncoderLayer(
-            embedding_size,
-            num_heads,
-            feedforward_hidden,
-            dropout,
+            embedding_size, num_heads, feedforward_hidden, dropout,
         )
 
         self.to_labels = nn.Linear(embedding_size, num_labels)
 
     def forward(
-        self,
-        input_ids: torch.Tensor,
-        attention_mask: torch.Tensor,
+        self, input_ids: torch.Tensor, attention_mask: torch.Tensor,
     ) -> torch.Tensor:
         """Computes action vectors array from array of tokens
 
@@ -122,13 +73,21 @@ class ActionsModelEnsembleBase(nn.Module):
         return self.to_labels(y)
 
 
-class ActionsModelEnsemble(PunctuationModel):
+class ActionsModelEnsemble(LightningModule):
     """Encoder-decoder based model with unpunctuated token sequence as input and array of action-vectors as output"""
 
     def __init__(
         self,
-        params: ActionsModelEnsembleParams,
-        runtime: ActionsModelEnsembleRuntimeParams = ActionsModelEnsembleRuntimeParams(),
+        base_tokenizer: str,
+        pos_weight: np.ndarray,
+        learning_rate=1e-4,
+        num_models: int = 10,
+        embedding_size: int = 200,
+        num_heads: int = 4,
+        feedforward_neurons: int = 200,
+        num_labels: int = len(ACTIONS_KEYS),
+        max_len: int = 500,
+        dropout: float = 0.1,
     ) -> None:
         """Initializes mixed model
 
@@ -137,22 +96,41 @@ class ActionsModelEnsemble(PunctuationModel):
         """
         super(ActionsModelEnsemble, self).__init__()
 
-        self.params = params
-        self.runtime = runtime
-        self._tokenizer = BertTokenizerFast.from_pretrained(params.base_tokenizer)
-
-        self.models = torch.nn.ModuleList([
-            ActionsModelEnsembleBase(params.num_labels, self._tokenizer.vocab_size, params.embedding_size, params.max_len, params.num_heads, params.feedforward_neurons, params.dropout) for _ in range(params.num_models)
-        ])
+        self.num_models = num_models
+        self.learning_rate = learning_rate
+        self.register_buffer("weight_tensor", torch.tensor(pos_weight, dtype=torch.float))
+        self.tokenizer = BertTokenizerFast.from_pretrained(base_tokenizer)
+
+        self.learning_rate_min = None
+        self.learning_rate_max = None
+
+        self.models = torch.nn.ModuleList(
+            [
+                ActionsModelEnsembleBase(
+                    num_labels,
+                    self.tokenizer.vocab_size,
+                    embedding_size,
+                    max_len,
+                    num_heads,
+                    feedforward_neurons,
+                    dropout,
+                )
+                for _ in range(num_models)
+            ]
+        )
 
     def _model_subsample_indices(self, step_num: int):
-        return np.random.RandomState(step_num).choice(np.arange(self.params.num_models), self.params.num_models, True).astype(np.int)
+        if step_num < 0:
+            return np.arange(self.num_models)
+        else:
+            return (
+                np.random.RandomState(step_num)
+                .choice(np.arange(self.num_models), self.num_models, True)
+                .astype(np.int)
+            )
 
     def forward(
-        self,
-        input_ids: torch.Tensor,
-        attention_mask: torch.Tensor,
-        step: int
+        self, input_ids: torch.Tensor, attention_mask: torch.Tensor, step: int = -1
     ) -> torch.Tensor:
         """Computes action vectors array from array of tokens
 
@@ -168,98 +146,90 @@ class ActionsModelEnsemble(PunctuationModel):
 
         predictions = []
         for model_id in model_ids:
-            predictions.append(self.models[model_id](input_ids, attention_mask).unsqueeze(0))
+            predictions.append(
+                self.models[model_id](input_ids, attention_mask).unsqueeze(0)
+            )
 
         predictions = torch.cat(predictions, dim=0)
 
         return predictions
 
-    def tokenizer(self) -> BertTokenizerFast:
-        return self._tokenizer
+    @onecycle_shed
+    def configure_optimizers(self):
+        optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate)
+        return optimizer
 
-    def predict(self, text: str) -> str:
-        raise NotImplementedError("Not yet implemented")
+    def training_step(self, batch, batch_idx):
+        loss = self._common_step(batch, batch_idx)
+        result = pl.TrainResult(minimize=loss)
 
-    def predict_raw(
-        self, input_ids: torch.Tensor, attention_mask: torch.Tensor
-    ) -> torch.Tensor:
-        """Function that maps input_ids tensors into per-token labels
+        result.log("train_loss", loss)
 
-        Args:
-            input_ids (torch.Tensor): Token ids of input. Shape BxL
-            attention_mask (torch.Tensor): Attention mask of tokens. Shape BxL
+        return result
 
-        Returns:
-            torch.Tensor: Per-token action-vector labels. Shape BxLxA
-        """
-        predictions = self.forward(input_ids, attention_mask=attention_mask).sigmoid()
+    def validation_step(self, batch, batch_idx):
+        loss = self._common_step(batch, batch_idx)
+        result = pl.EvalResult(loss)
+        result.valid_batch_loss = loss
+        result.log("valid_loss", loss)
 
-        # TODO: Compare consensus types (mean / median / minimal votes etc.)
-        predictions_consensus = predictions.median(0)
+        return result
 
-        return predictions_consensus
+    def validation_epoch_end(self, outputs):
+        avg_loss = outputs.valid_batch_loss.mean()
+        result = pl.EvalResult(checkpoint_on=avg_loss)
+        result.log("valid_loss", avg_loss, on_epoch=True, prog_bar=True)
 
+        return result
 
-    def save(self, dir: str, name: str, runtime: bool = True) -> None:
-        prepare_folder(dir)
-        torch.save(self.state_dict(), f"{dir}/{name}.model")
-        pickle_save(self.params, f"{dir}/{name}.config")
+    def _common_step(self, batch, step):
+        inputs = batch["source"]
+        outputs = batch["target"].type(torch.float)  # parquet targets arrive as float64
+        attentions_mask = batch["attention_mask"] == 0
 
-        if runtime:
-            self.runtime.save_yaml(f"{dir}/{name}.runtime.yaml")
+        outputs_pred = self(input_ids=inputs, attention_mask=attentions_mask, step=step)
 
-    @staticmethod
-    def load(dir: str, name: str, device: torch.device) -> PunctuationModel:
-        params = pickle_read(f"{dir}/{name}.config")
-        if os.path.exists(f"{dir}/{name}.runtime.yaml"):
-            runtime = ActionsModelEnsembleRuntimeParams.load_yaml(
-                f"{dir}/{name}.runtime.yaml"
-            )
-        else:
-            runtime = ActionsModelEnsembleRuntimeParams()
-
-        model = ActionsModelEnsembleBase(params, runtime)
-        model.to(device)
+        loss = F.binary_cross_entropy_with_logits(
+            outputs_pred,
+            outputs.unsqueeze(0).expand(outputs_pred.shape[0], -1, -1, -1),
+            pos_weight=self.weight_tensor,
+        )
 
-        model.load_state_dict(torch.load(f"{dir}/{name}.model", map_location=device))
+        return loss
 
-        return model
+    def _test_step(self, batch, batch_idx):
+        raise NotImplementedError("Test step not yet implemented")
 
 
-class ActionsModelEnsembleLoss(nn.Module):
-    """Class representing proposed loss for training mixed actions model"""
+class ActionsModelEnsembleCreator(ActionsModelCreator):
+    def setup(self, config: dict, base_model: str, seed: int) -> None:
+        self.base_model = base_model
+        self.seed = seed
 
-    def __init__(self, prior_odds: torch.Tensor) -> None:
-        """Initializes ActionsModelMixedLoss
+        self.num_models = config["num_models"]
+        self.embedding_size = config["embedding_size"]
+        self.num_heads = config["num_heads"]
+        self.feedforward_neurons = config["feedforward_neurons"]
+        self.learning_rate = ifnone(config["learning_rate"], 1e-5)
 
-        Args:
-            prior_odds (torch.Tensor): Odds representing ratio of positive to negative examples for each label in action vector. Shape A
-        """
-        super(ActionsModelEnsembleLoss, self).__init__()
+        self.pos_weight = None
 
-        self.core = BCEWithLogitsLoss(pos_weight=prior_odds)
-
-    def forward(
-        self,
-        predicted_action_vector_logits: torch.Tensor,
-        true_action_vector: torch.Tensor,
-    ) -> torch.Tensor:
-        """Computes loss for training mixed actions model
+    def config_stats(self, pos_examples: np.ndarray, neg_examples: np.ndarray):
+        self.pos_weight = neg_examples / pos_examples
 
-        Args:
-            true_action_vector (torch.Tensor): Action vector that should be
-                predicted by ActionsModelMixed (shifted by 1 to the left in
-                regards to inputs). Shape MxBxLxA
-
-            predicted_action_vector_logits (torch.Tensor): Action vector that
-                was acttualy predicted by ActionsModelMixed (shifted by 1 to
-                the left in regards to inputs). Shape BxL-1xA
-
-
-        Returns:
-            torch.Tensor: Loss of predition in relation to ground truth
-        """
-
-        return self.core(predicted_action_vector_logits, true_action_vector.unsqueeze(0).expand(
-            predicted_action_vector_logits.shape[0], -1, -1, -1)
+    def create_model(self) -> LightningModule:
+        assert self.pos_weight is not None
+        return ActionsModelEnsemble(
+            self.base_model,
+            self.pos_weight,
+            self.learning_rate,
+            self.num_models,
+            self.embedding_size,
+            self.num_heads,
+            self.feedforward_neurons,
+            len(ACTIONS_KEYS),
         )
+
+    def load_model(self, checkpoint: str) -> LightningModule:
+        return ActionsModelEnsemble.load_from_checkpoint(checkpoint)
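
Editor's note: _model_subsample_indices gives the ensemble a light form of bagging. During training the step number seeds a RandomState that draws member indices with replacement, so each optimization step updates a bootstrap sample of members, while the step=-1 inference default runs every member. The removed predict_raw hinted at a median consensus over the member axis; both pieces in miniature (torch.median returns a (values, indices) pair, hence the .values):

    import numpy as np
    import torch

    num_models = 4

    def subsample(step: int) -> np.ndarray:
        if step < 0:  # inference: deterministic, all members
            return np.arange(num_models)
        # training: step-seeded bootstrap draw with replacement
        return np.random.RandomState(step).choice(np.arange(num_models), num_models, True)

    member_logits = torch.randn(num_models, 2, 5, 3)          # MxBxLxA stack
    consensus = member_logits.sigmoid().median(dim=0).values  # BxLxA probabilities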
diff --git a/src/models/actions_model_mixed.py b/src/models/actions_model_mixed.py
index e09c0fabe4ee895c22275e0741f4cf4dfea51433..b939c3f17e7b8f0f37631b107c07e003fee56ada 100644
--- a/src/models/actions_model_mixed.py
+++ b/src/models/actions_model_mixed.py
@@ -1,81 +1,32 @@
-import os
-from dataclasses import dataclass
-from typing import Optional
-
 import numpy as np
+import pytorch_lightning as pl
 import torch
 import torch.nn as nn
-from torch.nn.modules.loss import BCEWithLogitsLoss
+import torch.nn.functional as F
+from pytorch_lightning.core.lightning import LightningModule
 from transformers.tokenization_bert import BertTokenizerFast
 
 from src.models.common import PositionalEncoding, generate_square_subsequent_mask
-from src.models.interfaces import PunctuationModel
-from src.pipelines.actions_based.processing import (
-    ACTIONS_KEYS,
-    action_vector,
-    recover_text,
-    token_labels_to_word_labels,
-)
-from src.utils import (
-    get_device,
-    pickle_read,
-    pickle_save,
-    prepare_folder,
-    yaml_serializable,
-)
-
-
-@dataclass
-class ActionsModelMixedParams:
-    """
-    Parameters for initializing ActionsModelMixed
-
-    Params:
-        base_tokenizer (str): Name of pretrained tokenizer
-        vocab_size (int): Number of tokens in tokenizer dictionary
-        embedding_size (int, optional): Shape of word and punctuation embeddings. Defaults to 200.
-        num_heads (int, optional): Number of heads in multiheaded attention. Defaults to 4.
-        num_layers (int, optional): Number of both decoded and encoder layers. Defaults to 2.
-        feedforward_neurons (int, optional): Size of feed-forward neural network at the end of encoder/decoder. Defaults to 200.
-        num_labels (int, optional): Action-vector size. Defaults to len(ACTIONS_KEYS).
-        max_len (int, optional): Maxium length of sequence. Defaults to 500.
-        dropout (float, optional): Dropout ratio. Defaults to 0.1.
-    """
-
-    base_tokenizer: str
-    vocab_size: int
-    threshold: float = 0.9
-    embedding_size: int = 200
-    num_heads: int = 4
-    num_layers: int = 2
-    feedforward_neurons: int = 200
-    num_labels: int = len(ACTIONS_KEYS)
-    max_len: int = 500
-    dropout: float = 0.1
-
-
-@yaml_serializable
-@dataclass
-class ActionsModelMixedRuntimeParams:
-    """
-    Parameters for ActionsModelMixed during runtime interference
-
-    Args:
-        threshold (float): minimum confidence for applying action
-        chunksize (int): Maximum number of chunks to perform inference on
-    """
-
-    threshold: float = 0.9
-    max_cond_len: Optional[int] = 500
-
-
-class ActionsModelMixed(PunctuationModel):
+from src.models.interfaces import ActionsModelCreator
+from src.pipelines.actions_based.processing import ACTIONS_KEYS
+from src.utils import ifnone
+
+
+class ActionsModelMixed(LightningModule):
     """Encoder-decoder based model with unpunctuated token sequence as input and array of action-vectors as output"""
 
     def __init__(
         self,
-        params: ActionsModelMixedParams,
-        runtime: ActionsModelMixedRuntimeParams = ActionsModelMixedRuntimeParams(),
+        base_tokenizer: str,
+        pos_weights: np.ndarray,
+        learning_rate: float = 1e-4,
+        embedding_size: int = 200,
+        num_heads: int = 4,
+        num_layers: int = 2,
+        feedforward_neurons: int = 200,
+        num_labels: int = len(ACTIONS_KEYS),
+        max_len: int = 500,
+        dropout: float = 0.1,
     ) -> None:
         """Initializes mixed model
 
@@ -84,47 +35,44 @@ class ActionsModelMixed(PunctuationModel):
         """
         super(ActionsModelMixed, self).__init__()
 
-        self.params = params
-        self.runtime = runtime
-        self._tokenizer = None
+        self.tokenizer = BertTokenizerFast.from_pretrained(base_tokenizer)
 
-        self.num_labels = params.num_labels
+        self.register_buffer("weight_tensor", torch.tensor(pos_weights, dtype=torch.float))
+        self.learning_rate = learning_rate
+        self.num_labels = num_labels
+
+        self.learning_rate_min = None
+        self.learning_rate_max = None
 
         # Word embedder
-        self.word_embedding = nn.Embedding(params.vocab_size, params.embedding_size)
-        self.punctuation_embedding = nn.Linear(params.num_labels, params.embedding_size)
+        self.word_embedding = nn.Embedding(self.tokenizer.vocab_size, embedding_size)
+        self.punctuation_embedding = nn.Linear(num_labels, embedding_size)
 
         # Add positional encoding
         self.words_position_embedding = PositionalEncoding(
-            params.embedding_size, params.max_len, params.dropout
+            embedding_size, max_len, dropout
         )
         self.punctuation_position_embedding = PositionalEncoding(
-            params.embedding_size, params.max_len, params.dropout
+            embedding_size, max_len, dropout
         )
 
         # Sentence encoder
         sentence_encoder_layer = nn.TransformerEncoderLayer(
-            params.embedding_size,
-            params.num_heads,
-            params.feedforward_neurons,
-            params.dropout,
+            embedding_size, num_heads, feedforward_neurons, dropout,
         )
         self.sentence_encoder = nn.TransformerEncoder(
-            sentence_encoder_layer, num_layers=params.num_layers
+            sentence_encoder_layer, num_layers=num_layers
         )
 
         # Punctuation decoder
         punctuation_decoder_layer = nn.TransformerDecoderLayer(
-            params.embedding_size,
-            params.num_heads,
-            params.feedforward_neurons,
-            params.dropout,
+            embedding_size, num_heads, feedforward_neurons, dropout,
         )
         self.punctuation_decoder = nn.TransformerDecoder(
-            punctuation_decoder_layer, num_layers=params.num_layers
+            punctuation_decoder_layer, num_layers=num_layers
         )
 
-        self.to_labels = nn.Linear(params.embedding_size, params.num_labels)
+        self.to_labels = nn.Linear(embedding_size, num_labels)
 
     def forward(
         self,
@@ -153,7 +101,7 @@ class ActionsModelMixed(PunctuationModel):
         y = self.punctuation_embedding(y)
         y = self.punctuation_position_embedding(y)
 
-        tgt_mask = generate_square_subsequent_mask(y.shape[0]).to(y.device)
+        tgt_mask = generate_square_subsequent_mask(y.shape[0]).type_as(y)
 
         sentence_encoded = self.sentence_encoder(x, src_key_padding_mask=attention_mask)
 
@@ -165,140 +113,81 @@ class ActionsModelMixed(PunctuationModel):
 
         return self.to_labels(z)
 
-    def tokenizer(self) -> BertTokenizerFast:
-        if self._tokenizer is None:
-            self._tokenizer = BertTokenizerFast.from_pretrained(
-                self.params.base_tokenizer
-            )
-        return self._tokenizer
-
-    def predict(self, text: str) -> str:
-        # TODO: Optimize for speed
-
-        inputs = [action_vector(["upper_case"])]
+    def configure_optimizers(self):
+        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
+        return optimizer
 
-        tokenizer = self.tokenizer()
-        text_tokenized = tokenizer(text, return_tensors="pt")
+    def training_step(self, batch, batch_idx):
+        loss = self._common_step(batch)
+        result = pl.TrainResult(minimize=loss)
 
-        target_device = get_device(self)
+        result.log("train_loss", loss)
 
-        max_cond_len = self.runtime.max_cond_len
-        if max_cond_len is None:
-            max_cond_len = np.iinfo(np.int).max
+        return result
 
-        for _ in range(text_tokenized["input_ids"].shape[1] - 2):
-            input_start = max(0, len(inputs) - max_cond_len)
+    def validation_step(self, batch, batch_idx):
+        loss = self._common_step(batch)
+        result = pl.EvalResult(loss)
+        result.valid_batch_loss = loss
+        result.log("valid_loss", loss)
 
-            prediction_raw = self.forward(
-                text_tokenized["input_ids"][:, input_start:].to(target_device),
-                torch.tensor(inputs[input_start:], dtype=torch.float)
-                .reshape(1, -1, self.num_labels)
-                .to(target_device),
-                (text_tokenized["attention_mask"][:, input_start:] == 0).to(
-                    target_device
-                ),
-            ).sigmoid()
+        return result
 
-            inputs.append(
-                (
-                    prediction_raw.detach().cpu().numpy()[0, -1, :]
-                    > self.runtime.threshold
-                ).astype(np.float)
-            )
+    def validation_epoch_end(self, outputs):
+        avg_loss = outputs.valid_batch_loss.mean()
+        result = pl.EvalResult(checkpoint_on=avg_loss)
+        result.log("valid_loss", avg_loss, on_epoch=True, prog_bar=True)
 
-        word_labels = token_labels_to_word_labels(text, inputs[1:], tokenizer)
+        return result
 
-        prediction_binary = word_labels.astype(np.int)
-
-        return recover_text(text, prediction_binary)
-
-    def predict_raw(
-        self, input_ids: torch.Tensor, attention_mask: torch.Tensor
-    ) -> torch.Tensor:
-        """Function that maps input_ids tensors into per-token labels
-
-        Args:
-            input_ids (torch.Tensor): Token ids of input. Shape BxL
-            attention_mask (torch.Tensor): Attention mask of tokens. Shape BxL
+    def _common_step(self, batch):
+        inputs = batch["source"]
+        outputs = batch["target"].type(torch.float)
+        attentions_mask = batch["attention_mask"] == 0
 
-        Returns:
-            torch.Tensor: Per-token action-vector labels. Shape BxLxA
-        """
-        outputs = torch.tensor(action_vector(["upper_case"]), dtype=torch.float).to(
-            input_ids.device
+        outputs_pred = self(
+            input_ids=inputs, actions=outputs[:, :-1], attention_mask=attentions_mask
         )
-        outputs = outputs.unsqueeze(0).unsqueeze(0).repeat(input_ids.shape[0], 1, 1)
-
-        for _ in range(input_ids.shape[1] - 1):
-            prediction_raw = self.forward(
-                input_ids, outputs, (attention_mask == 0)
-            ).sigmoid()
-
-            prediction_raw = (prediction_raw[:, -1:, :] > self.runtime.threshold).type(
-                torch.float
-            )
-            outputs = torch.cat([outputs, prediction_raw], dim=1)
-
-        return outputs
-
-    def save(self, dir: str, name: str, runtime: bool = True) -> None:
-        prepare_folder(dir)
-        torch.save(self.state_dict(), f"{dir}/{name}.model")
-        pickle_save(self.params, f"{dir}/{name}.config")
-
-        if runtime:
-            self.runtime.save_yaml(f"{dir}/{name}.runtime.yaml")
-
-    @staticmethod
-    def load(dir: str, name: str, device: torch.device) -> PunctuationModel:
-        params = pickle_read(f"{dir}/{name}.config")
-        if os.path.exists(f"{dir}/{name}.runtime.yaml"):
-            runtime = ActionsModelMixedRuntimeParams.load_yaml(
-                f"{dir}/{name}.runtime.yaml"
-            )
-        else:
-            runtime = ActionsModelMixedRuntimeParams()
-
-        model = ActionsModelMixed(params, runtime)
-        model.to(device)
-
-        model.load_state_dict(torch.load(f"{dir}/{name}.model", map_location=device))
 
-        return model
-
-
-class ActionsModelMixedLoss(nn.Module):
-    """Class representing proposed loss for training mixed actions model"""
-
-    def __init__(self, prior_odds: torch.Tensor) -> None:
-        """Initializes ActionsModelMixedLoss
-
-        Args:
-            prior_odds (torch.Tensor): Odds representing ratio of positive to negative examples for each label in action vector. Shape A
-        """
-        super(ActionsModelMixedLoss, self).__init__()
-
-        self.core = BCEWithLogitsLoss(pos_weight=prior_odds)
-
-    def forward(
-        self,
-        true_action_vector: torch.Tensor,
-        predicted_action_vector_logits: torch.Tensor,
-    ) -> torch.Tensor:
-        """Computes loss for training mixed actions model
-
-        Args:
-            true_action_vector (torch.Tensor): Action vector that should be
-                predicted by ActionsModelMixed (shifted by 1 to the left in
-                regards to inputs). Shape BxL-1xA
-
-            predicted_action_vector_logits (torch.Tensor): Action vector that
-                was acttualy predicted by ActionsModelMixed (shifted by 1 to
-                the left in regards to inputs). Shape BxL-1xA
-
-
-        Returns:
-            torch.Tensor: Loss of predition in relation to ground truth
-        """
+        loss = F.binary_cross_entropy_with_logits(
+            outputs_pred, outputs[:, 1:], pos_weight=self.weight_tensor
+        )
+        return loss
+
+
+class ActionsModelMixedCreator(ActionsModelCreator):
+    def setup(self, config: dict, base_model: str, seed: int) -> None:
+        self.base_model = base_model
+        self.seed = seed
+
+        self.embedding_size = config["embedding_size"]
+        self.num_heads = config["num_heads"]
+        self.feedforward_neurons = config["feedforward_neurons"]
+        self.num_layers = config["num_layers"]
+        self.dropout = config["dropout"]
+        self.learning_rate = ifnone(config["learning_rate"], 1e-5)
+
+        self.pos_weight = None
+
+    def config_stats(self, pos_examples: np.ndarray, neg_examples: np.ndarray):
+        self.pos_weight = neg_examples / pos_examples
+
+    def create_model(self) -> LightningModule:
+        assert self.pos_weight is not None
+
+        return ActionsModelMixed(
+            self.base_model,
+            self.pos_weight,
+            self.learning_rate,
+            self.embedding_size,
+            self.num_heads,
+            self.num_layers,
+            self.feedforward_neurons,
+            len(ACTIONS_KEYS),
+            500,  # max_len, matching the constructor default
+            self.dropout,
+        )
 
-        return self.core(predicted_action_vector_logits, true_action_vector)
+    def load_model(self, checkpoint: str) -> LightningModule:
+        return ActionsModelMixed.load_from_checkpoint(checkpoint)
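
Editor's note: ActionsModelMixed trains with teacher forcing. _common_step feeds the ground-truth action vectors shifted right by one position into the decoder (outputs[:, :-1]) and scores the prediction against the next-step targets (outputs[:, 1:]), which is exactly why inference has to unroll one token at a time. The shift in miniature:

    import torch

    target = torch.randint(0, 2, (2, 6, 3)).float()  # BxLxA ground-truth actions

    decoder_input = target[:, :-1]   # actions 0..L-2, fed to the decoder
    decoder_target = target[:, 1:]   # actions 1..L-1, what must be predicted
    assert decoder_input.shape == decoder_target.shape == (2, 5, 3)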
diff --git a/src/models/actions_model_restricted.py b/src/models/actions_model_restricted.py
index eb7f859120d7b1f694a59df3547a17551d2c0ded..2bd59c746e98f196ba39d6b3ea2d1386ee47d24f 100644
--- a/src/models/actions_model_restricted.py
+++ b/src/models/actions_model_restricted.py
@@ -1,71 +1,31 @@
 from __future__ import annotations
 
-import os
-from dataclasses import dataclass
-
 import numpy as np
+import pytorch_lightning as pl
 import torch
-import torch.nn as nn
+import torch.nn.functional as F
+from pytorch_lightning.core.lightning import LightningModule
 from transformers.configuration_utils import PretrainedConfig
 from transformers.modeling_bert import BertForTokenClassification
 from transformers.tokenization_bert import BertTokenizerFast
 
-from src.models.actions_model_mixed import ActionsModelMixed
-from src.models.interfaces import ActionsModel, PunctuationModel
-from src.pipelines.actions_based.processing import (
-    action_vector,
-    last_stop_label,
-    recover_text,
-    token_labels_to_word_labels,
-)
-from src.pipelines.actions_based.utils import max_suppression
-from src.utils import (
-    get_device,
-    pickle_read,
-    pickle_save,
-    prepare_folder,
-    yaml_serializable,
-)
-
-
-@dataclass
-class ActionsModelRestrictedParams:
-    """
-    Parameters for ActionsModelRestricted
-
-    Params:
-        base_model (str): Name of base model
-        extended_action_vector_size (int): Action-vector size including additional no-punctuation logit
-    """
-
-    base_model: str
-    extended_action_vector_size: int
-
+from src.models.interfaces import ActionsModelCreator
+from src.pipelines.actions_based.processing import ACTIONS_KEYS
+from src.utils import ifnone, onecycle_shed
 
-@yaml_serializable
-@dataclass
-class ActionsModelRestrictedRuntimeParams:
-    """
-    Parameters for ActionsModelBase during runtime interference
 
-    Args:
-        threshold (float): minimum confidence for applying action
-        chunksize (int): Maximum number of chunks to perform inference on
-    """
-
-    threshold: float = 0.9
-    chunksize: int = 500
-
-
-class ActionsModelRestricted(ActionsModel):
+class ActionsModelRestricted(LightningModule):
     """Similar to ActionsModelBase, however no-punctuation class is added
     and punctuation-related entries are treaded as proper categorical distribution
     """
 
     def __init__(
         self,
-        params: ActionsModelRestrictedParams,
-        runtime: ActionsModelRestrictedRuntimeParams = ActionsModelRestrictedRuntimeParams(),
+        base_model: str,
+        extended_action_vector_size: int,
+        prior_uppercase_odds: np.ndarray,
+        punctuation_pos_weights: np.ndarray,
+        learning_rate: float = 1e-4,
     ) -> None:
         """Initializes restricted actions model
 
@@ -75,14 +35,23 @@ class ActionsModelRestricted(ActionsModel):
         """
         super(ActionsModelRestricted, self).__init__()
 
-        self.params = params
-        self.runtime = runtime
-        self._tokenizer = None
-
-        config = PretrainedConfig.from_pretrained(params.base_model)
+        self.tokenizer = BertTokenizerFast.from_pretrained(base_model)
+        self.register_buffer(
+            "prior_uppercase_odds",
+            torch.tensor(prior_uppercase_odds, dtype=torch.float).reshape(1),
+        )
+        self.register_buffer(
+            "punctuation_pos_weights",
+            torch.tensor(punctuation_pos_weights, dtype=torch.float),
+        )
+        self.learning_rate = learning_rate
+        self.num_labels = extended_action_vector_size
 
-        config.num_labels = params.extended_action_vector_size
+        self.learning_rate_min = None
+        self.learning_rate_max = None
 
+        config = PretrainedConfig.from_pretrained(base_model)
+        config.num_labels = extended_action_vector_size
         self.core = BertForTokenClassification(config)
 
     def forward(
@@ -101,173 +70,100 @@ class ActionsModelRestricted(ActionsModel):
 
         return y_pred
 
-    def predict_raw(
-        self, input_ids: torch.Tensor, attention_mask: torch.Tensor
-    ) -> torch.Tensor:
-        """Function that maps input_ids tensors into per-token labels
-
-        Args:
-            input_ids (torch.Tensor): Token ids of input. Shape BxL
-            attention_mask (torch.Tensor): Attention mask of tokens. Shape BxL
-
-        Returns:
-            torch.Tensor: Per-token action-vector labels. Shape BxLxA
-        """
-
-        logits = self.forward(input_ids, attention_mask=attention_mask)
-        prob_uppercase = logits[:, :, :1].sigmoid()
-        prob_punctuation = logits[:, :, 1:].softmax(dim=-1)
-
-        no_punctuation = prob_punctuation.argmax(-1) == (
-            self.params.extended_action_vector_size - 2
-        )
-        no_punctuation = (
-            no_punctuation.type(torch.float)
-            .unsqueeze(-1)
-            .repeat(1, 1, prob_punctuation.shape[-1] - 1)
-        )
-
-        prob_punctuation = prob_punctuation[:, :, :-1].softmax(-1) * (
-            1 - no_punctuation
-        )
-
-        return torch.cat([prob_uppercase, prob_punctuation], dim=-1)
-
-    def predict(self, text: str) -> str:
-        chunk_size = self.runtime.chunksize
-        threshold = self.runtime.threshold
+    @onecycle_shed
+    def configure_optimizers(self):
+        optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate)
+        return optimizer
 
-        device = get_device(self)
+    def training_step(self, batch, batch_idx):
+        loss = self._common_step(batch)
+        result = pl.TrainResult(minimize=loss)
 
-        text = text.strip()
+        result.log("train_loss", loss)
 
-        tokenizer = self.tokenizer()
-        tokens = tokenizer(text, return_tensors="pt")["input_ids"].to(device)
-        output = None
+        return result
 
-        index_start = 0
-        while index_start < len(tokens[0]):
-            index_end = min(index_start + chunk_size, len(tokens[0]))
+    def validation_step(self, batch, batch_idx):
+        loss = self._common_step(batch)
+        result = pl.EvalResult(loss)
+        result.valid_batch_loss = loss
+        result.log("valid_loss", loss)
 
-            tokens_chunk = tokens[:, index_start:index_end]
+        return result
 
-            attention_mask = torch.ones_like(tokens_chunk).to(device)
+    def validation_epoch_end(self, outputs):
+        avg_loss = outputs.valid_batch_loss.mean()
+        result = pl.EvalResult(checkpoint_on=avg_loss)
+        result.log("valid_loss", avg_loss, on_epoch=True, prog_bar=True)
 
-            actions = (
-                self.predict_raw(tokens_chunk, attention_mask).detach().cpu().numpy()
-            )
-            actions_suppresed = max_suppression(actions, threshold)[0]
+        return result
 
-            offset = last_stop_label(actions_suppresed, action_vector(["dot"]))
+    def _common_step(self, batch):
+        inputs = batch["source"]
+        outputs = batch["target"].type(torch.float)  # parquet targets arrive as float64
+        attentions_mask = batch["attention_mask"]
 
-            # Prevent infinite loop
-            if (offset is None) or (offset == 0):
-                offset = index_end - index_start
-
-            if output is None:
-                output = actions[0, 0:offset]
-            else:
-                output = np.concatenate([output, actions[0, 0:offset]], axis=0)
-
-            index_start += offset
-
-        assert len(output) == len(tokens[0])
-
-        word_labels = token_labels_to_word_labels(text, output[1:-1], tokenizer)
-        actions = max_suppression(np.expand_dims(word_labels, 0), threshold)[0]
-
-        return recover_text(text, actions)
-
-    @staticmethod
-    def _logit(x: torch.Tensor):
-        EPS = 1e-5
+        outputs = torch.cat(
+            [outputs, (1.0 - outputs[:, :, 1:].max(-1)[0]).unsqueeze(-1)], dim=-1
+        )
 
-        z = torch.clamp(x, EPS, 1.0 - EPS)
+        outputs_pred = self(input_ids=inputs, attention_mask=attentions_mask)
 
-        return torch.log(z / (1 - z))
+        predicted_punc = outputs_pred[:, :, 1:].transpose(1, 2)
+        target_punc_index = torch.argmax(outputs[:, :, 1:], dim=-1)
 
-    def tokenizer(self) -> BertTokenizerFast:
-        if self._tokenizer is None:
-            self._tokenizer = BertTokenizerFast.from_pretrained(self.params.base_model)
-        return self._tokenizer
+        punc_loss = F.cross_entropy(
+            predicted_punc, target_punc_index, weight=self.punctuation_pos_weights
+        )
 
-    def save(self, dir: str, name: str, runtime: bool = True) -> None:
-        prepare_folder(dir)
-        torch.save(self.state_dict(), f"{dir}/{name}.model")
-        pickle_save(self.params, f"{dir}/{name}.config")
+        predicted_uppercase = outputs_pred[:, :, 0]
+        target_uppercase = outputs[:, :, 0]
+        uppercase_loss = F.binary_cross_entropy_with_logits(
+            predicted_uppercase, target_uppercase, pos_weight=self.prior_uppercase_odds
+        )
 
-        if runtime:
-            self.runtime.save_yaml(f"{dir}/{name}.runtime.yaml")
+        return punc_loss + uppercase_loss
 
-    @staticmethod
-    def load(dir: str, name: str, device: torch.device) -> ActionsModelRestricted:
-        params = pickle_read(f"{dir}/{name}.config")
-        if os.path.exists(f"{dir}/{name}.runtime.yaml"):
-            runtime = ActionsModelRestrictedRuntimeParams.load_yaml(
-                f"{dir}/{name}.runtime.yaml"
-            )
-        else:
-            runtime = ActionsModelRestrictedRuntimeParams()
+    def _test_step(self, batch, batch_idx):
+        raise NotImplementedError("Test step not yet implemented")
 
-        model = ActionsModelRestricted(params, runtime).to(device)
-        model.load_state_dict(torch.load(f"{dir}/{name}.model", map_location=device,))
 
-        return model
+class ActionsModelRestrictedCreator(ActionsModelCreator):
+    def setup(self, config: dict, base_model: str, seed: int) -> None:
+        self.base_model = base_model
+        self.seed = seed
+        self.learning_rate = ifnone(config["learning_rate"], 1e-5)
 
+        self.uppercase_pos_odds = None
+        self.punctuation_class_weights = None
 
-class ActionsModelRestrictedLoss(nn.Module):
-    def __init__(
-        self, prior_uppercase_odds: torch.Tensor, punctuation_weights: torch.Tensor
-    ) -> None:
-        """Initializes ActionsModelRestrictedLoss
+    def config_stats(self, pos_examples: np.ndarray, neg_examples: np.ndarray):
+        uppercase_pos_examples = pos_examples[0]
+        uppercase_neg_examples = neg_examples[0]
+        self.uppercase_pos_odds = uppercase_neg_examples / uppercase_pos_examples
 
-        Args:
-            prior_uppercase_odds (torch.Tensor): Odds od positive to negative cases of uppercase in dataset
-            punctuation_weights (torch.Tensor): Weights for each class in loss function. Should be inversly proportional to number of
-                their occurances in dataset (Shape A+1)
-        """
-        super(ActionsModelRestrictedLoss, self).__init__()
+        has_punctuation_neg_examples = neg_examples[1:]
+        # Negatives of the synthetic no-punctuation class are the tokens that
+        # do carry punctuation, i.e. the summed positives of the real classes
+        has_no_punctuation_neg_examples = np.sum(pos_examples[1:])
 
-        self.binary_ce = nn.BCEWithLogitsLoss(
-            pos_weight=prior_uppercase_odds.reshape(1)
+        punctuation_neg_examples = np.concatenate(
+            [has_punctuation_neg_examples, has_no_punctuation_neg_examples.reshape(1)],
+            -1,
         )
-        self.cat_ce = nn.CrossEntropyLoss(punctuation_weights)
 
-    def forward(
-        self,
-        predicted_action_vector_logits: torch.Tensor,
-        true_extended_action_vector: torch.Tensor,
-    ) -> torch.Tensor:
-        """Loss for ActionsModelRestricted model
-
-        Args:
-            true_extended_action_vector (torch.Tensor): Ground-truth action vectors. Shape BxLxA+1
-            predicted_action_vector_logits (torch.Tensor): Action vector-s logits predicted by ActionsModelRestricted model. Shape BxLxA+1
-
-        Returns:
-            torch.Tensor: Loss value
-        """
-
-        predicted_punc = predicted_action_vector_logits[:, :, 1:].transpose(1, 2)
-        target_punc_index = torch.argmax(true_extended_action_vector[:, :, 1:], dim=-1)
-        punc_loss = self.cat_ce(predicted_punc, target_punc_index)
-
-        predicted_uppercase = predicted_action_vector_logits[:, :, 0]
-        target_uppercase = true_extended_action_vector[:, :, 0]
-        uppercase_loss = self.binary_ce(predicted_uppercase, target_uppercase)
-
-        return punc_loss + uppercase_loss
-
-    def save(self, dir: str, name: str) -> None:
-        prepare_folder(dir)
-        torch.save(self.state_dict(), f"{dir}/{name}.model")
-        pickle_save(self.params, f"{dir}/{name}.config")
+        self.punctuation_class_weights = punctuation_neg_examples / np.sum(
+            punctuation_neg_examples
+        )
 
-    @staticmethod
-    def load(dir: str, name: str, device: torch.device) -> PunctuationModel:
-        params = pickle_read(f"{dir}/{name}.config")
-        model = ActionsModelMixed(params)
+    def create_model(self) -> LightningModule:
+        assert self.uppercase_pos_odds is not None
+        assert self.punctuation_class_weights is not None
 
-        model.load_state_dict(torch.load(f"{dir}/{name}.model", map_location=device,))
+        return ActionsModelRestricted(
+            self.base_model,
+            len(ACTIONS_KEYS) + 1,
+            self.uppercase_pos_odds,
+            self.punctuation_class_weights,
+            self.learning_rate,
+        )
 
-        return model
+    def load_model(self, checkpoint: str) -> LightningModule:
+        return ActionsModelRestricted.load_from_checkpoint(checkpoint)
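
Editor's note: the restricted model rebuilds its extended target on the fly. _common_step appends a synthetic no-punctuation column, 1 - max over the punctuation entries, so the punctuation sub-vector becomes a proper one-hot categorical target for cross-entropy, while the uppercase bit keeps its own weighted BCE term. A worked example of the construction:

    import torch

    # 1x2x3 target: token 0 is "uppercase + dot", token 1 carries no action
    target = torch.tensor([[[1.0, 0.0, 1.0],
                            [0.0, 0.0, 0.0]]])
    no_punct = (1.0 - target[:, :, 1:].max(-1)[0]).unsqueeze(-1)
    extended = torch.cat([target, no_punct], dim=-1)
    # extended[..., 1:] is now one-hot per token: [0., 1., 0.] and [0., 0., 1.]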
diff --git a/src/models/interfaces.py b/src/models/interfaces.py
index 46271456f012fb5c862850771f8329957044456f..a15851c1b7d2bd3a56c444f001bbaa4465ff02b0 100644
--- a/src/models/interfaces.py
+++ b/src/models/interfaces.py
@@ -2,12 +2,13 @@ from __future__ import annotations
 
 from abc import ABC, abstractmethod
 
+import numpy as np
 import torch
-import torch.nn as nn
+from pytorch_lightning.core.lightning import LightningModule
 from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
 
 
-class PunctuationModel(nn.Module, ABC):
+class PunctuationModel(ABC):
     def __init__(self) -> None:
         super().__init__()
 
@@ -15,15 +16,6 @@ class PunctuationModel(nn.Module, ABC):
     def tokenizer(self) -> PreTrainedTokenizerFast:
         pass
 
-    @abstractmethod
-    def save(self, dir: str, name: str, runtime: bool = False) -> None:
-        pass
-
-    @staticmethod
-    @abstractmethod
-    def load(dir: str, name: str, device: torch.device) -> PunctuationModel:
-        pass
-
 
 class ActionsModel(PunctuationModel):
     def __init__(self) -> None:
@@ -42,8 +34,27 @@ class ActionsModel(PunctuationModel):
         Returns:
             torch.Tensor: Per-token action-vector labels. Shape BxLxA
         """
-        pass
 
     @abstractmethod
     def predict(self, text: str) -> str:
         pass
+
+
+class ModelCreator(ABC):
+    @abstractmethod
+    def setup(self, config: dict, base_model: str, seed: int) -> None:
+        pass
+
+    @abstractmethod
+    def create_model(self) -> LightningModule:
+        pass
+
+    @abstractmethod
+    def load_model(self, checkpoint: str) -> LightningModule:
+        pass
+
+
+class ActionsModelCreator(ModelCreator):
+    @abstractmethod
+    def config_stats(self, pos_examples: np.ndarray, neg_examples: np.ndarray):
+        pass
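
Editor's note: together the two ABCs define the contract the training pipeline relies on. setup() ingests the per-model config section, config_stats() receives the dataset's positive and negative counts, and create_model()/load_model() hand a LightningModule to the Trainer. A sketch of a hypothetical creator (MyModel is a placeholder, not part of this codebase):

    import numpy as np
    from pytorch_lightning.core.lightning import LightningModule

    from src.models.interfaces import ActionsModelCreator

    class MyModelCreator(ActionsModelCreator):
        def setup(self, config: dict, base_model: str, seed: int) -> None:
            self.base_model = base_model
            self.learning_rate = config["learning_rate"]
            self.pos_weight = None

        def config_stats(self, pos_examples: np.ndarray, neg_examples: np.ndarray):
            self.pos_weight = neg_examples / pos_examples

        def create_model(self) -> LightningModule:
            assert self.pos_weight is not None
            return MyModel(self.base_model, self.learning_rate, self.pos_weight)  # placeholder

        def load_model(self, checkpoint: str) -> LightningModule:
            return MyModel.load_from_checkpoint(checkpoint)  # placeholder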
diff --git a/src/models/model_factory.py b/src/models/model_factory.py
index 5e4a9fc554fff5c1d284a0ae47b02b684f554ff8..88b87cbefb69de8272b2b1e6b2a5250c9c43c133 100644
--- a/src/models/model_factory.py
+++ b/src/models/model_factory.py
@@ -1,9 +1,20 @@
-from src.models.actions_model_base import ActionsModelBase
-from src.models.actions_model_mixed import ActionsModelMixed
-from src.models.actions_model_restricted import ActionsModelRestricted
+from src.models.actions_model_base import ActionsModelBase, ActionsModelBaseCreator
+from src.models.actions_model_ensemble import ActionsModelEnsembleCreator
+from src.models.actions_model_mixed import ActionsModelMixed, ActionsModelMixedCreator
+from src.models.actions_model_restricted import (
+    ActionsModelRestricted,
+    ActionsModelRestrictedCreator,
+)
 
 MODELS_MAP = {
     "actions_base": ActionsModelBase,
     "actions_restricted": ActionsModelRestricted,
     "actions_mixed": ActionsModelMixed,
 }
+
+MODEL_CREATOR_MAP = {
+    "actions_base": ActionsModelBaseCreator,
+    "actions_ensemble": ActionsModelEnsembleCreator,
+    "actions_mixed": ActionsModelMixedCreator,
+    "actions_restricted": ActionsModelRestrictedCreator,
+}
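
Editor's note: the two maps serve different callers. MODELS_MAP resolves a class for checkpoint loading at inference time, while MODEL_CREATOR_MAP resolves a creator for the training pipeline (note that actions_ensemble is currently registered only for training). The lookup used by train.py, in outline:

    from src.models.model_factory import MODEL_CREATOR_MAP

    creator = MODEL_CREATOR_MAP["actions_base"]()
    # then, in order:
    #   creator.setup(config_section, base_model, seed)
    #   creator.config_stats(pos_examples, neg_examples)
    #   model = creator.create_model()   # or creator.load_model(checkpoint)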
diff --git a/src/pipelines/actions_based/data_source.py b/src/pipelines/actions_based/data_source.py
new file mode 100644
index 0000000000000000000000000000000000000000..a55ecf37667dea043ce08766996397a149be5ae6
--- /dev/null
+++ b/src/pipelines/actions_based/data_source.py
@@ -0,0 +1,82 @@
+from typing import List, Optional, Union
+
+import dask.dataframe as dd
+import numpy as np
+from pytorch_lightning.core.datamodule import LightningDataModule
+from torch.utils.data import Dataset, random_split
+from torch.utils.data.dataloader import DataLoader
+
+from src.utils import get_ordered_dataframe_len
+
+INPUT_PATH = "../generated/actions/stage4_reindexing"
+
+
+class ActionsDataset(Dataset):  # map-style: random_split and DataLoader need __getitem__/__len__
+    def __init__(self, path: str, seed: int = 44) -> None:
+        super().__init__()
+        self._df = dd.read_parquet(path, engine="pyarrow")
+        self._df_len = get_ordered_dataframe_len(self._df)
+
+        self._generator = np.random.RandomState(seed)
+
+    def __getitem__(self, index: int) -> dict:
+        entry = self._df.loc[index].compute(scheduler="synchronous")
+
+        _, entry = next(entry.iterrows())
+
+        return {
+            "source": np.array(
+                entry["source"].reshape(entry["source_shape"]).squeeze(-1),
+                dtype=np.long,
+            ),
+            "target": np.array(
+                entry["target"].reshape(entry["target_shape"]), dtype=np.float
+            ),
+            "attention_mask": np.array(
+                entry["attention_mask"].reshape(entry["attention_mask_shape"]),
+                dtype=np.float,
+            ),
+        }
+
+    def __len__(self) -> int:
+        return self._df_len
+
+
+class ActionsDataModule(LightningDataModule):
+    def __init__(
+        self,
+        path: str,
+        batch_size: int,
+        valid_frac: float = 0.01,
+        test_frac: float = 0.01,
+        num_workers=0,
+    ) -> None:
+        super().__init__()
+
+        self._path = path
+        self._valid_frac = valid_frac
+        self._test_frac = test_frac
+        self._batch_size = batch_size
+        self._num_workers = num_workers
+
+    def setup(self, stage: Optional[str] = None):
+        actions_dataset = ActionsDataset(self._path)
+
+        num_valid = int(len(actions_dataset) * self._valid_frac)
+        num_test = int(len(actions_dataset) * self._test_frac)
+        num_train = len(actions_dataset) - num_valid - num_test
+
+        self._train, self._valid, self._test = random_split(
+            actions_dataset, [num_train, num_valid, num_test]
+        )
+
+    def train_dataloader(self, *args, **kwargs) -> DataLoader:
+        return DataLoader(
+            self._train, batch_size=self._batch_size, num_workers=self._num_workers
+        )
+
+    def val_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
+        return DataLoader(self._valid, self._batch_size, num_workers=self._num_workers)
+
+    def test_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
+        return DataLoader(self._test, self._batch_size, num_workers=self._num_workers)
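
Editor's note: ActionsDataModule carves one parquet-backed dataset into train/valid/test subsets with random_split; the Trainer normally drives it, but it can be exercised by hand. A quick sketch (path and sizes are illustrative):

    from src.pipelines.actions_based.data_source import ActionsDataModule

    dm = ActionsDataModule("generated/actions/stage4_reindexing", batch_size=8, num_workers=2)
    dm.setup(None)  # builds the three splits

    batch = next(iter(dm.train_dataloader()))
    print(batch["source"].shape, batch["target"].shape, batch["attention_mask"].shape)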
diff --git a/src/pipelines/actions_based/test.py b/src/pipelines/actions_based/test.py
index 8cfcda93c2e04f18290269c4ae137418670e6108..79c6f1e4810df9a394ca712bf5814abcda16394b 100644
--- a/src/pipelines/actions_based/test.py
+++ b/src/pipelines/actions_based/test.py
@@ -54,9 +54,9 @@ if __name__ == "__main__":
     args = parser.parse_args()
 
     config = get_config()
-    limit = config["actions"][args.stage]["limit"]
-    batch_size = config["actions"][args.stage]["batch_size"]
-    device_name = config["actions"][args.stage]["device"]
+    limit = config["actions"]["testing"][args.stage]["limit"]
+    batch_size = config["actions"]["testing"][args.stage]["batch_size"]
+    device_name = config["actions"]["testing"][args.stage]["device"]
 
     test_dataset = f"{PROJECT_ROOT}/{args.dataset}"
 
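
Editor's note: the updated lookups assume the params file nests per-model testing settings under actions.testing. The shape the code above expects, written out as a Python literal with illustrative values only:

    config = {
        "actions": {
            "testing": {
                "actions_base": {"limit": 1000, "batch_size": 8, "device": "cuda:0"},
            }
        }
    }
    limit = config["actions"]["testing"]["actions_base"]["limit"]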
diff --git a/src/pipelines/actions_based/train.py b/src/pipelines/actions_based/train.py
new file mode 100755
index 0000000000000000000000000000000000000000..a621f8ffd1c9ea33650c6fdd80cb98fa0835a380
--- /dev/null
+++ b/src/pipelines/actions_based/train.py
@@ -0,0 +1,135 @@
+#!/usr/bin/python3
+
+import argparse
+import os
+import pickle
+
+import numpy as np
+from pytorch_lightning.callbacks import ModelCheckpoint
+from pytorch_lightning.callbacks.progress import ProgressBar
+from pytorch_lightning.trainer import Trainer
+
+from src.models.model_factory import MODEL_CREATOR_MAP
+from src.pipelines.actions_based.data_source import ActionsDataModule
+from src.utils import (
+    PROJECT_ROOT,
+    LRFinder,
+    StepCheckpoint,
+    get_config,
+    get_logger,
+    link_datamodule,
+    prepare_folder,
+)
+
+INPUT_PATH = f"{PROJECT_ROOT}/generated/actions/stage4_reindexing"
+INPUT_STATS_PATH = f"{PROJECT_ROOT}/generated/actions/stage5_stats"
+OUTPUT_PATH = f"{PROJECT_ROOT}/checkpoints/actions_base"
+
+VALID_FRAC = 0.1
+TEST_FRAC = 0.1
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Train actions model")
+    parser.add_argument(
+        "-m",
+        "--model",
+        type=str,
+        choices=list(MODEL_CREATOR_MAP.keys()),
+        help="Name of the model to train",
+    )
+
+    parser.add_argument(
+        "-n", "--name", type=str, required=False, help="Name of the experiment",
+    )
+
+    parser.add_argument(
+        "-t",
+        "--test",
+        required=False,
+        default=False,
+        action="store_true",
+        help="Only train on 1 batch of train, validation and test.",
+    )
+
+    args = parser.parse_args()
+
+    if args.name is not None:
+        OUTPUT_PATH = f"{OUTPUT_PATH}_{args.name}"
+
+    fast_dev_run = args.test
+    model_name = args.model
+    creator = MODEL_CREATOR_MAP[model_name]()
+
+    config = get_config()
+    learning_rate = config["actions"]["training"][model_name]["learning_rate"]
+    batch_size = config["actions"]["training"][model_name]["batch_size"]
+    save_step = config["actions"]["training"][model_name]["save_step"]
+    start_checkpoint = config["actions"]["training"][model_name]["start_checkpoint"]
+    gpu_devices = config["actions"]["training"][model_name]["gpu_devices"]
+    num_workers = config["actions"]["training"][model_name]["num_workers"]
+    base_model = config["global"]["base_model"]
+    seed = config["global"]["random_seed"]
+
+    np.random.seed(seed=seed)
+
+    # Load loss weights
+    with open(f"{INPUT_STATS_PATH}/stats.pickle", "rb") as f:
+        stats = pickle.load(f)
+        pos_examples = stats["class_number"]
+        neg_examples = stats["num_examples"] - stats["class_number"]
+
+    creator.setup(config["actions"]["training"][model_name], base_model, seed)
+    creator.config_stats(pos_examples, neg_examples)
+
+    logger = get_logger(f"train/{model_name}")
+    if len(gpu_devices) == 0:
+        gpu_devices = 0
+
+    data_module = ActionsDataModule(
+        INPUT_PATH,
+        batch_size,
+        num_workers=num_workers,
+        valid_frac=VALID_FRAC,
+        test_frac=TEST_FRAC,
+    )
+
+    output_path_steps = f"{OUTPUT_PATH}/steps"
+    output_path_epochs = f"{OUTPUT_PATH}/epochs"
+
+    prepare_folder(OUTPUT_PATH)
+    prepare_folder(output_path_steps)
+    prepare_folder(output_path_epochs)
+
+    checkpoint = ModelCheckpoint(
+        output_path_epochs, save_last=True, save_top_k=5, monitor="val_loss"
+    )
+
+    callbacks = [StepCheckpoint(output_path_steps, save_step), ProgressBar()]
+
+    if learning_rate is None:
+        callbacks.append(LRFinder())
+
+    if (start_checkpoint is None) or (
+        not os.path.isfile(f"{OUTPUT_PATH}/{start_checkpoint}.ckpt")
+    ):
+        model = creator.create_model()
+        link_datamodule(model, data_module)
+        checkpoint_path = None
+    else:
+        checkpoint_path = f"{OUTPUT_PATH}/{start_checkpoint}.ckpt"
+        model = creator.load_model(checkpoint_path)
+        link_datamodule(model, data_module)
+
+    trainer = Trainer(
+        resume_from_checkpoint=checkpoint_path,
+        callbacks=callbacks,
+        gradient_clip_val=1.0,
+        gpus=gpu_devices,
+        logger=logger,
+        default_root_dir=OUTPUT_PATH,
+        checkpoint_callback=checkpoint,
+        distributed_backend="dp",
+        fast_dev_run=fast_dev_run,
+    )
+    trainer.fit(model)
+    trainer.save_checkpoint(f"{OUTPUT_PATH}/final.ckpt")
diff --git a/src/pipelines/actions_based/train_base.py b/src/pipelines/actions_based/train_base.py
deleted file mode 100755
index c8597f017c1d53582822e8a93559b8fbf95fda27..0000000000000000000000000000000000000000
--- a/src/pipelines/actions_based/train_base.py
+++ /dev/null
@@ -1,120 +0,0 @@
-#!/usr/bin/python3
-
-import pickle
-
-import dask.dataframe as dd
-import numpy as np
-import torch
-from transformers import BertTokenizerFast
-
-from src.models.actions_model_base import (
-    ActionsModelBase,
-    ActionsModelBaseLoss,
-    ActionsModelBaseParams,
-)
-from src.pipelines.actions_based.processing import ACTIONS_KEYS
-from src.utils import (
-    PROJECT_ROOT,
-    Checkpoint,
-    Loader,
-    ProgressTracker,
-    Saver,
-    Timeout,
-    convert_to_timedelta,
-    get_config,
-    random_indexes,
-    training_loop,
-    unflattened_column,
-)
-
-INPUT_PATH = f"{PROJECT_ROOT}/generated/actions/stage4_reindexing"
-INPUT_STATS_PATH = f"{PROJECT_ROOT}/generated/actions/stage5_stats"
-OUTPUT_PATH = f"{PROJECT_ROOT}/checkpoints/actions_base"
-
-
-if __name__ == "__main__":
-    config = get_config()
-    learning_rate = config["actions"]["training_base"]["learning_rate"]
-    num_epochs = config["actions"]["training_base"]["num_epochs"]
-    batch_size = config["actions"]["training_base"]["batch_size"]
-    save_step = config["actions"]["training_base"]["save_step"]
-    batch_buffer_size = config["actions"]["training_base"]["batch_buffer_size"]
-    loss_averaging_span = config["actions"]["training_base"]["loss_averaging_span"]
-    fresh_start = config["actions"]["training_base"]["fresh_start"]
-    device_name = config["actions"]["training_base"]["device"]
-    max_train_time = convert_to_timedelta(
-        config["actions"]["training_base"]["max_training_time"]
-    )
-    base_model = config["global"]["base_model"]
-    seed = config["global"]["random_seed"]
-
-    np.random.seed(seed=seed)
-    df = dd.read_parquet(INPUT_PATH, engine="pyarrow")
-
-    device = torch.device(device_name if torch.cuda.is_available() else "cpu")
-    tokenizer = BertTokenizerFast.from_pretrained(base_model)
-
-    loader = Loader(OUTPUT_PATH, ActionsModelBase, torch.optim.AdamW, device)
-    if loader.has_checkpoints() and not fresh_start:
-        model, optimizer, epoch_start, sample_start = loader.load_latest()
-    else:
-        params = ActionsModelBaseParams(base_model, len(ACTIONS_KEYS))
-        model = ActionsModelBase(params)
-        model.to(device)
-
-        optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
-        epoch_start, sample_start = (0, 0)
-
-    model.train()
-
-    # Load loss weights
-    with open(f"{INPUT_STATS_PATH}/stats.pickle", "rb") as f:
-        stats = pickle.load(f)
-        pos_examples = stats["class_number"]
-        neg_examples = stats["num_examples"] - stats["class_number"]
-        pos_weight = torch.tensor(neg_examples / pos_examples)
-
-    criterion = ActionsModelBaseLoss(pos_weight).to(device)
-
-    random_index_shuffle = random_indexes(df)
-    training_stopped = False
-
-    saver = Saver(OUTPUT_PATH, model, optimizer)
-    checkpoint = Checkpoint(save_step, saver, epoch_start, sample_start)
-    timer = Timeout(max_train_time, saver)
-    tracker = ProgressTracker(device, loss_averaging_span)
-
-    timer.start()
-    for data_batch, epoch, i in training_loop(
-        epoch_start,
-        sample_start,
-        num_epochs,
-        df,
-        batch_size,
-        batch_buffer_size,
-        random_index_shuffle,
-    ):
-        inputs = unflattened_column(data_batch, "source")
-        outputs = unflattened_column(data_batch, "target")
-        attentions_mask = unflattened_column(data_batch, "attention_mask")
-
-        inputs = torch.tensor(inputs, dtype=torch.long).squeeze(dim=-1).to(device)
-        outputs = torch.tensor(outputs, dtype=torch.float).to(device)
-        attentions_mask = torch.tensor(attentions_mask).type(torch.long).to(device)
-
-        y_pred = model(input_ids=inputs, attention_mask=attentions_mask)
-
-        optimizer.zero_grad()
-        loss = criterion(y_pred, outputs)
-
-        tracker.step(epoch, i, loss)
-        checkpoint.step(epoch, i)
-        if timer.step(epoch, i):
-            training_stopped = True
-            break
-
-        loss.backward()
-        optimizer.step()
-
-    if not training_stopped:
-        saver.save("final")
diff --git a/src/pipelines/actions_based/train_ensemble.py b/src/pipelines/actions_based/train_ensemble.py
deleted file mode 100755
index 3e1c5ff0dcf01ccd6128feb957a1f3252320e886..0000000000000000000000000000000000000000
--- a/src/pipelines/actions_based/train_ensemble.py
+++ /dev/null
@@ -1,127 +0,0 @@
-#!/usr/bin/python3
-
-import pickle
-from src.models.actions_model_ensemble import ActionsModelEnsemble, ActionsModelEnsembleLoss, ActionsModelEnsembleParams
-
-import dask.dataframe as dd
-import numpy as np
-import torch
-from transformers import BertTokenizerFast
-
-from src.pipelines.actions_based.processing import ACTIONS_KEYS
-from src.utils import (
-    PROJECT_ROOT,
-    Checkpoint,
-    Loader,
-    ProgressTracker,
-    Saver,
-    Timeout,
-    convert_to_timedelta,
-    get_config,
-    random_indexes,
-    training_loop,
-    unflattened_column,
-)
-
-INPUT_PATH = f"{PROJECT_ROOT}/generated/actions/stage4_reindexing"
-INPUT_STATS_PATH = f"{PROJECT_ROOT}/generated/actions/stage5_stats"
-OUTPUT_PATH = f"{PROJECT_ROOT}/checkpoints/actions_ensemble"
-
-
-if __name__ == "__main__":
-    config = get_config()
-    num_models = config["actions"]["training_ensemble"]["num_models"]
-    embedding_size = config["actions"]["training_ensemble"]["embedding_size"]
-    num_heads = config["actions"]["training_ensemble"]["num_heads"]
-    feedforward_neurons = config["actions"]["training_ensemble"]["feedforward_neurons"]
-    num_labels = len(ACTIONS_KEYS)
-    max_len = 500
-    dropout = config["actions"]["training_ensemble"]["dropout"]
-
-    learning_rate = config["actions"]["training_ensemble"]["learning_rate"]
-    num_epochs = config["actions"]["training_ensemble"]["num_epochs"]
-    batch_size = config["actions"]["training_ensemble"]["batch_size"]
-    save_step = config["actions"]["training_ensemble"]["save_step"]
-    batch_buffer_size = config["actions"]["training_ensemble"]["batch_buffer_size"]
-    loss_averaging_span = config["actions"]["training_ensemble"]["loss_averaging_span"]
-    fresh_start = config["actions"]["training_ensemble"]["fresh_start"]
-    device_name = config["actions"]["training_ensemble"]["device"]
-    max_train_time = convert_to_timedelta(
-        config["actions"]["training_base"]["max_training_time"]
-    )
-    base_model = config["global"]["base_model"]
-    seed = config["global"]["random_seed"]
-
-    np.random.seed(seed=seed)
-    df = dd.read_parquet(INPUT_PATH, engine="pyarrow")
-
-    device = torch.device(device_name if torch.cuda.is_available() else "cpu")
-    tokenizer = BertTokenizerFast.from_pretrained(base_model)
-
-    loader = Loader(OUTPUT_PATH, ActionsModelEnsemble, torch.optim.AdamW, device)
-    if loader.has_checkpoints() and not fresh_start:
-        model, optimizer, epoch_start, sample_start = loader.load_latest()
-    else:
-        params = ActionsModelEnsembleParams(base_model, num_models, embedding_size, num_heads, feedforward_neurons, num_labels, max_len, dropout)
-        model = ActionsModelEnsemble(params)
-        model.to(device)
-
-        optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
-        epoch_start, sample_start = (0, 0)
-
-    model.train()
-
-    # Load loss weights
-    with open(f"{INPUT_STATS_PATH}/stats.pickle", "rb") as f:
-        stats = pickle.load(f)
-        pos_examples = stats["class_number"]
-        neg_examples = stats["num_examples"] - stats["class_number"]
-        pos_weight = torch.tensor(neg_examples / pos_examples)
-
-    criterion = ActionsModelEnsembleLoss(pos_weight).to(device)
-
-    random_index_shuffle = random_indexes(df)
-    training_stopped = False
-
-    saver = Saver(OUTPUT_PATH, model, optimizer)
-    checkpoint = Checkpoint(save_step, saver, epoch_start, sample_start)
-    timer = Timeout(max_train_time, saver)
-    tracker = ProgressTracker(device, loss_averaging_span)
-
-    timer.start()
-    for data_batch, epoch, i in training_loop(
-        epoch_start,
-        sample_start,
-        num_epochs,
-        df,
-        batch_size,
-        batch_buffer_size,
-        random_index_shuffle,
-    ):
-        inputs = unflattened_column(data_batch, "source")
-        outputs = unflattened_column(data_batch, "target")
-        attentions_mask = unflattened_column(data_batch, "attention_mask")
-
-        inputs = torch.tensor(inputs, dtype=torch.long).squeeze(dim=-1).to(device)
-        outputs = torch.tensor(outputs, dtype=torch.float).to(device)
-        attentions_mask = torch.tensor(attentions_mask).type(torch.long).to(device)
-
-        # Convert to boolean
-        attentions_mask = torch.tensor(attentions_mask == 0).to(device)
-
-        y_pred = model(input_ids=inputs, attention_mask=attentions_mask, step=i)
-
-        optimizer.zero_grad()
-        loss = criterion(y_pred, outputs)
-
-        tracker.step(epoch, i, loss)
-        checkpoint.step(epoch, i)
-        if timer.step(epoch, i):
-            training_stopped = True
-            break
-
-        loss.backward()
-        optimizer.step()
-
-    if not training_stopped:
-        saver.save("final")
diff --git a/src/pipelines/actions_based/train_mixed.py b/src/pipelines/actions_based/train_mixed.py
deleted file mode 100755
index fd44e27b0101501632366e2df949ef4b8a871f08..0000000000000000000000000000000000000000
--- a/src/pipelines/actions_based/train_mixed.py
+++ /dev/null
@@ -1,144 +0,0 @@
-#!/usr/bin/python3
-
-import pickle
-
-import dask.dataframe as dd
-import numpy as np
-import torch
-from transformers import BertTokenizerFast
-
-from src.models.actions_model_mixed import (
-    ActionsModelMixed,
-    ActionsModelMixedLoss,
-    ActionsModelMixedParams,
-)
-from src.pipelines.actions_based.processing import ACTIONS_KEYS
-from src.utils import (
-    PROJECT_ROOT,
-    Checkpoint,
-    Loader,
-    ProgressTracker,
-    Saver,
-    Timeout,
-    convert_to_timedelta,
-    get_config,
-    random_indexes,
-    training_loop,
-    unflattened_column,
-)
-
-INPUT_PATH = f"{PROJECT_ROOT}/generated/actions/stage4_reindexing"
-INPUT_STATS_PATH = f"{PROJECT_ROOT}/generated/actions/stage5_stats"
-OUTPUT_PATH = f"{PROJECT_ROOT}/checkpoints/actions_mixed"
-
-
-if __name__ == "__main__":
-    config = get_config()
-    threshold = config["actions"]["training_mixed"]["threshold"]
-    embedding_size = config["actions"]["training_mixed"]["embedding_size"]
-    num_heads = config["actions"]["training_mixed"]["num_heads"]
-    num_layers = config["actions"]["training_mixed"]["num_layers"]
-    dropout = config["actions"]["training_mixed"]["dropout"]
-    feedforward_neurons = config["actions"]["training_mixed"]["feedforward_neurons"]
-    learning_rate = config["actions"]["training_mixed"]["learning_rate"]
-    num_epochs = config["actions"]["training_mixed"]["num_epochs"]
-    batch_size = config["actions"]["training_mixed"]["batch_size"]
-    save_step = config["actions"]["training_mixed"]["save_step"]
-    batch_buffer_size = config["actions"]["training_mixed"]["batch_buffer_size"]
-    loss_averaging_span = config["actions"]["training_mixed"]["loss_averaging_span"]
-    fresh_start = config["actions"]["training_mixed"]["fresh_start"]
-    device_name = config["actions"]["training_mixed"]["device"]
-    max_train_time = convert_to_timedelta(
-        config["actions"]["training_mixed"]["max_training_time"]
-    )
-    base_model = config["global"]["base_model"]
-    seed = config["global"]["random_seed"]
-
-    np.random.seed(seed=seed)
-    df = dd.read_parquet(INPUT_PATH, engine="pyarrow")
-
-    device = torch.device(device_name if torch.cuda.is_available() else "cpu")
-    tokenizer = BertTokenizerFast.from_pretrained(base_model)
-
-    loader = Loader(OUTPUT_PATH, ActionsModelMixed, torch.optim.AdamW, device)
-
-    if loader.has_checkpoints() and not fresh_start:
-        model, optimizer, epoch_start, sample_start = loader.load_latest()
-    else:
-        params = ActionsModelMixedParams(
-            base_model,
-            tokenizer.vocab_size,
-            threshold,
-            embedding_size,
-            num_heads,
-            num_layers,
-            feedforward_neurons,
-            len(ACTIONS_KEYS),
-            500,
-            dropout,
-        )
-        model = ActionsModelMixed(params)
-        model.to(device)
-        optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
-        epoch_start, sample_start = (0, 0)
-
-    model.train()
-
-    # Load loss weights
-    with open(f"{INPUT_STATS_PATH}/stats.pickle", "rb") as f:
-        stats = pickle.load(f)
-        pos_examples = stats["class_number"]
-        neg_examples = stats["num_examples"] - stats["class_number"]
-        pos_weight = torch.tensor(neg_examples / pos_examples)
-
-    criterion = ActionsModelMixedLoss(pos_weight).to(device)
-
-    random_index_shuffle = random_indexes(df)
-    training_stopped = False
-
-    saver = Saver(OUTPUT_PATH, model, optimizer)
-    checkpoint = Checkpoint(save_step, saver, epoch_start, sample_start)
-    timer = Timeout(max_train_time, saver)
-    tracker = ProgressTracker(device, loss_averaging_span)
-
-    timer.start()
-    for data_batch, epoch, i in training_loop(
-        epoch_start,
-        sample_start,
-        num_epochs,
-        df,
-        batch_size,
-        batch_buffer_size,
-        random_index_shuffle,
-    ):
-        inputs = unflattened_column(data_batch, "source")
-        outputs = unflattened_column(data_batch, "target")
-        attentions_mask = unflattened_column(data_batch, "attention_mask")
-
-        inputs = torch.tensor(inputs, dtype=torch.long).to(device).squeeze(dim=2)
-
-        outputs = torch.tensor(outputs, dtype=torch.float).to(device)
-
-        # Convert to boolean
-        attentions_mask = torch.tensor(attentions_mask == 0).to(device)
-
-        y_pred = model(
-            input_ids=inputs,
-            actions=outputs[:, :-1, :],
-            attention_mask=attentions_mask,
-        )
-
-        loss = criterion(outputs[:, 1:, :], y_pred)
-        optimizer.zero_grad()
-
-        tracker.step(epoch, i, loss)
-        checkpoint.step(epoch, i)
-        if timer.step(epoch, i):
-            training_stopped = True
-            break
-
-        loss.backward()
-        optimizer.step()
-
-    if not training_stopped:
-        saver.save("final")
diff --git a/src/pipelines/actions_based/train_restricted.py b/src/pipelines/actions_based/train_restricted.py
deleted file mode 100755
index ed43789b0fd7cd7866a41baeccf901293cb0a171..0000000000000000000000000000000000000000
--- a/src/pipelines/actions_based/train_restricted.py
+++ /dev/null
@@ -1,146 +0,0 @@
-#!/usr/bin/python3
-
-import pickle
-
-import dask.dataframe as dd
-import numpy as np
-import torch
-from transformers import BertTokenizerFast
-
-from src.models.actions_model_restricted import (
-    ActionsModelRestricted,
-    ActionsModelRestrictedLoss,
-    ActionsModelRestrictedParams,
-)
-from src.pipelines.actions_based.processing import ACTIONS_KEYS
-from src.utils import (
-    PROJECT_ROOT,
-    Checkpoint,
-    Loader,
-    ProgressTracker,
-    Saver,
-    Timeout,
-    convert_to_timedelta,
-    get_config,
-    random_indexes,
-    training_loop,
-    unflattened_column,
-)
-
-INPUT_PATH = f"{PROJECT_ROOT}/generated/actions/stage4_reindexing"
-INPUT_STATS_PATH = f"{PROJECT_ROOT}/generated/actions/stage5_stats"
-OUTPUT_PATH = f"{PROJECT_ROOT}/checkpoints/actions_restricted"
-
-
-if __name__ == "__main__":
-
-    config = get_config()
-    learning_rate = config["actions"]["training_restricted"]["learning_rate"]
-    num_epochs = config["actions"]["training_restricted"]["num_epochs"]
-    batch_size = config["actions"]["training_restricted"]["batch_size"]
-    save_step = config["actions"]["training_restricted"]["save_step"]
-    batch_buffer_size = config["actions"]["training_restricted"]["batch_buffer_size"]
-    loss_averaging_span = config["actions"]["training_restricted"][
-        "loss_averaging_span"
-    ]
-    fresh_start = config["actions"]["training_restricted"]["fresh_start"]
-    device_name = config["actions"]["training_restricted"]["device"]
-    max_train_time = convert_to_timedelta(
-        config["actions"]["training_restricted"]["max_training_time"]
-    )
-    base_model = config["global"]["base_model"]
-    seed = config["global"]["random_seed"]
-
-    np.random.seed(seed=seed)
-    df = dd.read_parquet(INPUT_PATH, engine="pyarrow")
-
-    device = torch.device(device_name if torch.cuda.is_available() else "cpu")
-    tokenizer = BertTokenizerFast.from_pretrained(base_model)
-
-    loader = Loader(OUTPUT_PATH, ActionsModelRestricted, torch.optim.AdamW, device)
-    if loader.has_checkpoints() and not fresh_start:
-        model, optimizer, epoch_start, sample_start = loader.load_latest()
-    else:
-        params = ActionsModelRestrictedParams(base_model, len(ACTIONS_KEYS) + 1)
-        model = ActionsModelRestricted(params)
-        model.to(device)
-        optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
-        epoch_start, sample_start = (0, 0)
-
-    model.train()
-
-    # Load loss weights
-    with open(f"{INPUT_STATS_PATH}/stats.pickle", "rb") as f:
-        stats = pickle.load(f)
-        pos_examples = stats["class_number"]
-        neg_examples = stats["num_examples"] - stats["class_number"]
-
-        uppercase_pos_examples = pos_examples[0]
-        uppercase_neg_examples = neg_examples[0]
-        uppercase_pos_odds = torch.tensor(
-            uppercase_pos_examples / uppercase_neg_examples, dtype=torch.float
-        )
-
-        has_punctuation_neg_examples = neg_examples[1:]
-        has_no_punctuation_neg_examples = np.sum(pos_examples[1:])
-
-        punctuation_neg_examples = np.concatenate(
-            [has_punctuation_neg_examples, has_no_punctuation_neg_examples.reshape(1)],
-            -1,
-        )
-
-        punctuation_class_weights = torch.tensor(
-            (punctuation_neg_examples) / np.sum(punctuation_neg_examples),
-            dtype=torch.float,
-        )
-
-    criterion = ActionsModelRestrictedLoss(
-        uppercase_pos_odds, punctuation_class_weights
-    ).to(device)
-
-    random_index_shuffle = random_indexes(df)
-    training_stopped = False
-
-    saver = Saver(OUTPUT_PATH, model, optimizer)
-    checkpoint = Checkpoint(save_step, saver, epoch_start, sample_start)
-    timer = Timeout(max_train_time, saver)
-    tracker = ProgressTracker(device, loss_averaging_span)
-
-    timer.start()
-    for data_batch, epoch, i in training_loop(
-        epoch_start,
-        sample_start,
-        num_epochs,
-        df,
-        batch_size,
-        batch_buffer_size,
-        random_index_shuffle,
-    ):
-        inputs = unflattened_column(data_batch, "source")
-        outputs = unflattened_column(data_batch, "target")
-        attentions_mask = unflattened_column(data_batch, "attention_mask")
-
-        inputs = torch.tensor(inputs, dtype=torch.long).squeeze(dim=-1).to(device)
-        outputs = torch.tensor(outputs, dtype=torch.float).to(device)
-        attentions_mask = torch.tensor(attentions_mask).to(device)
-
-        y_pred = model(input_ids=inputs, attention_mask=attentions_mask)
-
-        outputs = torch.cat(
-            [outputs, (1.0 - outputs[:, :, 1:].max(-1)[0]).unsqueeze(-1)], axis=-1
-        )
-
-        loss = criterion(y_pred, outputs)
-        optimizer.zero_grad()
-
-        tracker.step(epoch, i, loss)
-        checkpoint.step(epoch, i)
-        if timer.step(epoch, i):
-            training_stopped = True
-            break
-
-        loss.backward()
-        optimizer.step()
-
-    if not training_stopped:
-        saver.save("final")
diff --git a/src/predictors/actions_model_base.py b/src/predictors/actions_model_base.py
new file mode 100644
index 0000000000000000000000000000000000000000..6bdb4a5f76065b39afe24d9635fc0602e4168fcb
--- /dev/null
+++ b/src/predictors/actions_model_base.py
@@ -0,0 +1,84 @@
+from __future__ import annotations
+
+import numpy as np
+import torch
+from transformers.tokenization_bert import BertTokenizerFast
+
+from src.models.actions_model_base import ActionsModelBase
+from src.models.interfaces import ActionsModel
+from src.pipelines.actions_based.processing import (
+    action_vector,
+    last_stop_label,
+    recover_text,
+    token_labels_to_word_labels,
+)
+from src.pipelines.actions_based.utils import max_suppression
+from src.utils import get_device
+
+
+class ActionsModelBasePredictor(ActionsModel):
+    def __init__(
+        self, model: ActionsModelBase, threshold: float = 0.9, max_len: int = 500
+    ) -> None:
+        self.model = model
+        self.threshold = threshold
+        self.max_len = max_len
+
+    def predict_raw(
+        self, input_ids: torch.Tensor, attention_mask: torch.Tensor
+    ) -> torch.Tensor:
+        """Function that maps input_ids tensors into per-token labels
+
+        Args:
+            input_ids (torch.Tensor): Token ids of input. Shape BxL
+            attention_mask (torch.Tensor): Attention mask of tokens. Shape BxL
+
+        Returns:
+            torch.Tensor: Per-token action-vector labels. Shape BxLxA
+        """
+
+        return self.model(input_ids, attention_mask=attention_mask).sigmoid()
+
+    def predict(self, text: str) -> str:
+        text = text.strip()
+
+        device = get_device(self.model)
+
+        tokenizer = self.tokenizer()
+        tokens = tokenizer(text, return_tensors="pt")["input_ids"].to(device)
+        output = None
+
+        index_start = 0
+        while index_start < len(tokens[0]):
+            index_end = min(index_start + self.max_len, len(tokens[0]))
+
+            tokens_chunk = tokens[:, index_start:index_end]
+            attention_mask = torch.ones_like(tokens_chunk).to(device)
+
+            actions = (
+                self.predict_raw(tokens_chunk, attention_mask).detach().cpu().numpy()
+            )
+            actions_suppressed = max_suppression(actions, self.threshold)[0]
+
+            offset = last_stop_label(actions_suppressed, action_vector(["dot"]))
+
+            # Prevent infinite loop
+            if (offset is None) or (offset == 0):
+                offset = index_end - index_start
+
+            if output is None:
+                output = actions[0, 0:offset]
+            else:
+                output = np.concatenate([output, actions[0, 0:offset]], axis=0)
+
+            index_start += offset
+
+        assert len(output) == len(tokens[0])
+
+        word_labels = token_labels_to_word_labels(text, output[1:-1], tokenizer)
+        actions = max_suppression(np.expand_dims(word_labels, 0), self.threshold)[0]
+
+        return recover_text(text, actions)
+
+    def tokenizer(self) -> BertTokenizerFast:
+        return self.model.tokenizer
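
A minimal usage sketch, mirroring the unit test added further below; the positional constructor arguments of `ActionsModelBase` are assumed from that test:

```python
import numpy as np

from src.models.actions_model_base import ActionsModelBase
from src.pipelines.actions_based.processing import ACTIONS_KEYS
from src.predictors.actions_model_base import ActionsModelBasePredictor

# Constructor arguments assumed from tests/predictors/test_actions_model_base.py:
# (base_model, num_labels, learning_rate, loss_weights)
num_labels = len(ACTIONS_KEYS)
model = ActionsModelBase(
    "dkleczek/bert-base-polish-cased-v1", num_labels, 1e-4, np.ones(num_labels)
)
predictor = ActionsModelBasePredictor(model, threshold=0.9)

print(predictor.predict("ala ma kota a kot ma ale"))
```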
diff --git a/src/predictors/actions_model_ensemble.py b/src/predictors/actions_model_ensemble.py
new file mode 100644
index 0000000000000000000000000000000000000000..aa41fc4298a3700dd97f47dfbfc88120e1018c29
--- /dev/null
+++ b/src/predictors/actions_model_ensemble.py
@@ -0,0 +1,88 @@
+import numpy as np
+import torch
+from transformers.tokenization_bert import BertTokenizerFast
+
+from src.models.actions_model_ensemble import ActionsModelEnsemble
+from src.models.interfaces import ActionsModel
+from src.pipelines.actions_based.processing import (
+    action_vector,
+    last_stop_label,
+    recover_text,
+    token_labels_to_word_labels,
+)
+from src.pipelines.actions_based.utils import max_suppression
+from src.utils import get_device
+
+
+class ActionsModelEnsemblePredictor(ActionsModel):
+    def __init__(
+        self, model: ActionsModelEnsemble, threshold: float = 0.9, max_len: int = 500
+    ) -> None:
+        self.model = model
+        self.threshold = threshold
+        self.max_len = max_len
+
+    def predict(self, text: str) -> str:
+        text = text.strip()
+
+        device = get_device(self.model)
+
+        tokenizer = self.tokenizer()
+        tokens = tokenizer(text, return_tensors="pt")["input_ids"].to(device)
+        output = None
+
+        index_start = 0
+        while index_start < len(tokens[0]):
+            index_end = min(index_start + self.max_len, len(tokens[0]))
+
+            tokens_chunk = tokens[:, index_start:index_end]
+            attention_mask = torch.ones_like(tokens_chunk).to(device)
+
+            actions = (
+                self.predict_raw(tokens_chunk, attention_mask).detach().cpu().numpy()
+            )
+            actions_suppressed = max_suppression(actions, self.threshold)[0]
+
+            offset = last_stop_label(actions_suppressed, action_vector(["dot"]))
+
+            # Prevent infinite loop
+            if (offset is None) or (offset == 0):
+                offset = index_end - index_start
+
+            if output is None:
+                output = actions[0, 0:offset]
+            else:
+                output = np.concatenate([output, actions[0, 0:offset]], axis=0)
+
+            index_start += offset
+
+        assert len(output) == len(tokens[0])
+
+        word_labels = token_labels_to_word_labels(text, output[1:-1], tokenizer)
+        actions = max_suppression(np.expand_dims(word_labels, 0), self.threshold)[0]
+
+        return recover_text(text, actions)
+
+    def tokenizer(self) -> BertTokenizerFast:
+        return self.model.tokenizer
+
+    def predict_raw(
+        self, input_ids: torch.Tensor, attention_mask: torch.Tensor
+    ) -> torch.Tensor:
+        """Function that maps input_ids tensors into per-token labels
+
+        Args:
+            input_ids (torch.Tensor): Token ids of input. Shape BxL
+            attention_mask (torch.Tensor): Attention mask of tokens. Shape BxL
+
+        Returns:
+            torch.Tensor: Per-token action-vector labels. Shape BxLxA
+        """
+        predictions = self.model(
+            input_ids, attention_mask=(attention_mask == 0)
+        ).sigmoid()
+
+        # TODO: Compare consensus types (mean / median / minimal votes etc.)
+        predictions_consensus = predictions.mean(0)
+
+        return predictions_consensus
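
The consensus step reduces the extra ensemble axis by averaging the sub-models' sigmoid outputs. A small self-contained illustration of that reduction, on random data:

```python
import torch

# predict_raw above receives NxBxLxA predictions from the ensemble
# (N = number of sub-models) and averages them over the model axis.
num_models, batch, seq_len, num_actions = 2, 1, 7, 5
predictions = torch.rand(num_models, batch, seq_len, num_actions)

consensus = predictions.mean(0)  # BxLxA per-token consensus
assert consensus.shape == (batch, seq_len, num_actions)
```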
diff --git a/src/predictors/actions_model_mixed.py b/src/predictors/actions_model_mixed.py
new file mode 100644
index 0000000000000000000000000000000000000000..fdad3255598f203b993862da0696f6c258076d37
--- /dev/null
+++ b/src/predictors/actions_model_mixed.py
@@ -0,0 +1,92 @@
+import numpy as np
+import torch
+from transformers.tokenization_bert import BertTokenizerFast
+
+from src.models.actions_model_mixed import ActionsModelMixed
+from src.models.interfaces import ActionsModel
+from src.pipelines.actions_based.processing import (
+    action_vector,
+    recover_text,
+    token_labels_to_word_labels,
+)
+from src.utils import get_device
+
+
+class ActionsModelMixedPredictor(ActionsModel):
+    def __init__(
+        self, model: ActionsModelMixed, threshold: float = 0.9, max_len: int = 500
+    ) -> None:
+        self.model = model
+        self.threshold = threshold
+        self.max_len = max_len
+
+    def predict(self, text: str) -> str:
+
+        # TODO: Optimize for speed
+        inputs = [action_vector(["upper_case"])]
+
+        tokenizer = self.tokenizer()
+        text_tokenized = tokenizer(text, return_tensors="pt")
+
+        target_device = get_device(self.model)
+
+        max_cond_len = self.max_len
+        if max_cond_len is None:
+            max_cond_len = np.iinfo(np.int64).max
+
+        for _ in range(text_tokenized["input_ids"].shape[1] - 2):
+            input_start = max(0, len(inputs) - max_cond_len)
+
+            prediction_raw = self.model(
+                text_tokenized["input_ids"][:, input_start:].to(target_device),
+                torch.tensor(inputs[input_start:], dtype=torch.float)
+                .reshape(1, -1, self.model.num_labels)
+                .to(target_device),
+                (text_tokenized["attention_mask"][:, input_start:] == 0).to(
+                    target_device
+                ),
+            ).sigmoid()
+
+            inputs.append(
+                (
+                    prediction_raw.detach().cpu().numpy()[0, -1, :] > self.threshold
+                ).astype(np.float64)
+            )
+
+        word_labels = token_labels_to_word_labels(text, inputs[1:], tokenizer)
+
+        prediction_binary = word_labels.astype(np.int64)
+
+        return recover_text(text, prediction_binary)
+
+    def predict_raw(
+        self, input_ids: torch.Tensor, attention_mask: torch.Tensor
+    ) -> torch.Tensor:
+        """Function that maps input_ids tensors into per-token labels
+
+        Args:
+            input_ids (torch.Tensor): Token ids of input. Shape BxL
+            attention_mask (torch.Tensor): Attention mask of tokens. Shape BxL
+
+        Returns:
+            torch.Tensor: Per-token action-vector labels. Shape BxLxA
+        """
+        outputs = torch.tensor(action_vector(["upper_case"]), dtype=torch.float).to(
+            input_ids.device
+        )
+        outputs = outputs.unsqueeze(0).unsqueeze(0).repeat(input_ids.shape[0], 1, 1)
+
+        for _ in range(input_ids.shape[1] - 1):
+            prediction_raw = self.model(
+                input_ids, outputs, (attention_mask == 0)
+            ).sigmoid()
+
+            prediction_raw = (prediction_raw[:, -1:, :] > self.threshold).type(
+                torch.float
+            )
+            outputs = torch.cat([outputs, prediction_raw], dim=1)
+
+        return outputs
+
+    def tokenizer(self) -> BertTokenizerFast:
+        return self.model.tokenizer
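
Usage mirrors the other predictors, but decoding is autoregressive (one action vector per step), so it is noticeably slower. A sketch, with the `ActionsModelMixed` constructor arguments assumed from the unit test added below:

```python
import numpy as np

from src.models.actions_model_mixed import ActionsModelMixed
from src.pipelines.actions_based.processing import ACTIONS_KEYS
from src.predictors.actions_model_mixed import ActionsModelMixedPredictor

# Constructor arguments assumed from tests/predictors/test_actions_model_mixed.py:
# (base_model, loss_weights, learning_rate, embedding_size, num_heads,
#  num_layers, feedforward_neurons, num_labels)
num_labels = len(ACTIONS_KEYS)
model = ActionsModelMixed(
    "dkleczek/bert-base-polish-cased-v1",
    np.ones(num_labels),
    1e-4,
    10,
    2,
    2,
    10,
    num_labels,
)
predictor = ActionsModelMixedPredictor(model, threshold=0.9)

print(predictor.predict("ala ma kota a kot ma ale"))
```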
diff --git a/src/predictors/actions_model_restricted.py b/src/predictors/actions_model_restricted.py
new file mode 100644
index 0000000000000000000000000000000000000000..343fdcaab46c1878e6ed66c0a33725274d5b2bb6
--- /dev/null
+++ b/src/predictors/actions_model_restricted.py
@@ -0,0 +1,103 @@
+from __future__ import annotations
+
+import numpy as np
+import torch
+from transformers.tokenization_bert import BertTokenizerFast
+
+from src.models.actions_model_restricted import ActionsModelRestricted
+from src.models.interfaces import ActionsModel
+from src.pipelines.actions_based.processing import (
+    action_vector,
+    last_stop_label,
+    recover_text,
+    token_labels_to_word_labels,
+)
+from src.pipelines.actions_based.utils import max_suppression
+from src.utils import get_device
+
+
+class ActionsModelRestrictedPredictor(ActionsModel):
+    def __init__(
+        self, model: ActionsModelRestricted, threshold: float = 0.9, max_len: int = 500
+    ) -> None:
+        self.model = model
+        self.threshold = threshold
+        self.max_len = max_len
+
+    def predict_raw(
+        self, input_ids: torch.Tensor, attention_mask: torch.Tensor
+    ) -> torch.Tensor:
+        """Function that maps input_ids tensors into per-token labels
+
+        Args:
+            input_ids (torch.Tensor): Token ids of input. Shape BxL
+            attention_mask (torch.Tensor): Attention mask of tokens. Shape BxL
+
+        Returns:
+            torch.Tensor: Per-token action-vector labels. Shape BxLxA
+        """
+
+        logits = self.model(input_ids, attention_mask=attention_mask)
+        prob_uppercase = logits[:, :, :1].sigmoid()
+        prob_punctuation = logits[:, :, 1:].softmax(dim=-1)
+
+        no_punctuation = prob_punctuation.argmax(-1) == (self.model.num_labels - 2)
+        no_punctuation = (
+            no_punctuation.type(torch.float)
+            .unsqueeze(-1)
+            .repeat(1, 1, prob_punctuation.shape[-1] - 1)
+        )
+
+        # Renormalize the remaining classes (a second softmax over probabilities
+        # would distort them) and zero them out where "no punctuation" wins
+        prob_rest = prob_punctuation[:, :, :-1]
+        prob_punctuation = prob_rest / prob_rest.sum(-1, keepdim=True) * (
+            1 - no_punctuation
+        )
+
+        return torch.cat([prob_uppercase, prob_punctuation], dim=-1)
+
+    def predict(self, text: str) -> str:
+        chunk_size = self.max_len
+        threshold = self.threshold
+
+        device = get_device(self.model)
+
+        text = text.strip()
+
+        tokenizer = self.tokenizer()
+        tokens = tokenizer(text, return_tensors="pt")["input_ids"].to(device)
+        output = None
+
+        index_start = 0
+        while index_start < len(tokens[0]):
+            index_end = min(index_start + chunk_size, len(tokens[0]))
+
+            tokens_chunk = tokens[:, index_start:index_end]
+
+            attention_mask = torch.ones_like(tokens_chunk).to(device)
+
+            actions = (
+                self.predict_raw(tokens_chunk, attention_mask).detach().cpu().numpy()
+            )
+            actions_suppressed = max_suppression(actions, threshold)[0]
+
+            offset = last_stop_label(actions_suppressed, action_vector(["dot"]))
+
+            # Prevent infinite loop
+            if (offset is None) or (offset == 0):
+                offset = index_end - index_start
+
+            if output is None:
+                output = actions[0, 0:offset]
+            else:
+                output = np.concatenate([output, actions[0, 0:offset]], axis=0)
+
+            index_start += offset
+
+        assert len(output) == len(tokens[0])
+
+        word_labels = token_labels_to_word_labels(text, output[1:-1], tokenizer)
+        actions = max_suppression(np.expand_dims(word_labels, 0), threshold)[0]
+
+        return recover_text(text, actions)
+
+    def tokenizer(self) -> BertTokenizerFast:
+        return self.model.tokenizer
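
The head decoding above treats the first logit as an independent uppercase bit and the remaining logits as one mutually exclusive punctuation choice. A toy illustration of that split:

```python
import torch

# Toy BxLx(1 + P) logits: column 0 is the uppercase bit, the remaining P
# columns are mutually exclusive punctuation classes (last = "no punctuation").
logits = torch.randn(1, 4, 5)

prob_uppercase = logits[:, :, :1].sigmoid()          # independent Bernoulli
prob_punctuation = logits[:, :, 1:].softmax(dim=-1)  # one-of-P categorical

# Each token's punctuation probabilities sum to one.
assert torch.allclose(prob_punctuation.sum(-1), torch.ones(1, 4))
```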
diff --git a/src/predictors/factory.py b/src/predictors/factory.py
new file mode 100644
index 0000000000000000000000000000000000000000..2a7d3c736533f88f598b1acfd5de07edad0a385d
--- /dev/null
+++ b/src/predictors/factory.py
@@ -0,0 +1,15 @@
+from typing import Dict, Type
+
+from src.models.interfaces import ActionsModel
+
+from src.predictors.actions_model_base import ActionsModelBasePredictor
+from src.predictors.actions_model_ensemble import ActionsModelEnsemblePredictor
+from src.predictors.actions_model_mixed import ActionsModelMixedPredictor
+from src.predictors.actions_model_restricted import ActionsModelRestrictedPredictor
+
+ACTION_MODEL_PREDICTORS: Dict[str, Type[ActionsModel]] = {
+    "actions_base": ActionsModelBasePredictor,
+    "actions_ensemble": ActionsModelEnsemblePredictor,
+    "actions_mixed": ActionsModelMixedPredictor,
+    "actions_restricted": ActionsModelRestrictedPredictor,
+}
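
A hedged lookup sketch; all predictor classes share the `(model, threshold, max_len)` constructor shown in the files above:

```python
from src.predictors.factory import ACTION_MODEL_PREDICTORS

# Hypothetical wiring: pick the predictor class matching a trained model type
# and wrap an already-loaded model instance in it.
predictor_cls = ACTION_MODEL_PREDICTORS["actions_base"]
# predictor = predictor_cls(loaded_model, threshold=0.9)
```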
diff --git a/src/utils.py b/src/utils.py
index 0bb292f6fa74e073ae5e8629171be2fbbd091c7f..2e50707319da67f6a7703a28198a2dbc79b9d559 100644
--- a/src/utils.py
+++ b/src/utils.py
@@ -1,12 +1,15 @@
 from __future__ import annotations
 
+import math
 import os
 import pickle
 import re
 import shutil
+import warnings
 from datetime import datetime, timedelta
 from glob import glob
-from typing import Generator, List, Optional, Tuple, Type
+from toolz.functoolz import juxt
+from typing import Any, Generator, List, Optional, Tuple, Type, Union
 
 import dask.dataframe as dd
 import numpy as np
@@ -14,6 +17,12 @@ import pandas as pd
 import torch
 import torch.nn as nn
 import yaml
+from pytorch_lightning.callbacks.base import Callback
+from pytorch_lightning.core.datamodule import LightningDataModule
+from pytorch_lightning.core.lightning import LightningModule
+from pytorch_lightning.loggers.base import LightningLoggerBase
+from pytorch_lightning.loggers.comet import CometLogger
+from pytorch_lightning.trainer.trainer import Trainer
 from torch.optim import Optimizer
 
 from src.batch_loading import get_batches, get_ordered_dataframe_len
@@ -22,6 +31,50 @@ from src.models.interfaces import PunctuationModel
 PROJECT_ROOT = os.path.dirname(os.path.realpath("/".join(__file__.split("/")) + "/.."))
 
 
+class LRFinder(Callback):
+    def on_fit_start(self, trainer: Trainer, pl_module: LightningModule):
+        try:
+            finder = trainer.lr_find(pl_module, trainer.datamodule)
+            lr_min, lr_max = get_cycling_loss_range(
+                finder.results["lr"], finder.results["loss"]
+            )
+
+            if hasattr(pl_module, "lr_find_results"):
+                setattr(pl_module, "lr_find_results", finder.results)
+            if hasattr(pl_module, "learning_rate"):
+                setattr(pl_module, "learning_rate", finder.suggestion())
+                print(f"Set learning_rate to {finder.suggestion()}")
+            if hasattr(pl_module, "learning_rate_min"):
+                setattr(pl_module, "learning_rate_min", lr_min)
+                print(f"Set learning_rate_min to {lr_min}")
+            if hasattr(pl_module, "learning_rate_max"):
+                setattr(pl_module, "learning_rate_max", lr_max)
+                print(f"Set learning_rate_max to {lr_max}")
+        except ValueError:
+            warnings.warn("Too litle datapoints in LR-sweep curve")
+
+
+class StepCheckpoint(Callback):
+    """Utility callback class for creating checkpoint every n-th global step
+
+    Args:
+        Callback ([type]): [description]
+    """
+
+    def __init__(self, save_path: str, step_interval: int) -> None:
+        self._step_interval = step_interval
+        self._save_path = save_path
+
+    def on_batch_end(self, trainer: Trainer, pl_module: LightningModule):
+        epoch = trainer.current_epoch
+        step = trainer.global_step
+
+        if (step != 0) and (step % self._step_interval == 0):
+            filename = f"epoch{epoch}-step{step}.ckpt"
+            path = os.path.join(self._save_path, filename)
+            trainer.save_checkpoint(path)
+
+
 class Saver:
     """Class that allows saving and loading mode-optimizer pairs"""
 
@@ -618,3 +671,114 @@ def get_device(model: nn.Module) -> torch.device:
     """
 
     return next(model.parameters()).device
+
+
+def get_logger(name: str) -> Union[LightningLoggerBase, bool]:
+    if "COMMET_API_KEY" in os.environ.keys():
+        return CometLogger(
+            api_key=os.environ["COMMET_API_KEY"], project_name=name, save_dir="."
+        )
+    else:
+        warnings.warn(
+            "Did not found COMMET_API_KEY environmental varialbe. Using default"
+            " logger..."
+        )
+        # In pytorch-lightning true means using default logger
+        return True
+
+
+def get_cycling_loss_range(
+    lr: np.ndarray,
+    loss: np.ndarray,
+    max_derv_l: float = -math.pi / 6,
+    max_derv_r: float = -math.pi / 12,
+) -> Tuple[Optional[float], Optional[float]]:
+    """Finds a range for learning loss for 1cycle or CLR
+
+    Args:
+        lr (np.ndarray): Array of learning rates used in LR test
+        loss (np.ndarray): Array of losses obtained in LR test
+        max_derv (float): Maximum angle of lr/loss graph for minimal lr
+
+    Returns:
+        [type]: [description]
+    """
+
+    loss_derv = np.gradient(loss).reshape(-1)
+
+    min_point = loss_derv.argmin()
+
+    min_lr = None
+    for i in range(min_point, -1, -1):
+        if loss_derv[i] < math.tan(max_derv_l):
+            min_lr = lr[i]
+        else:
+            break
+
+    max_lr = None
+    for i in range(min_point, len(loss)):
+        if loss_derv[i] < math.tan(max_derv_r):
+            max_lr = lr[i]
+        else:
+            break
+
+    return min_lr, max_lr
+
+
+def link_datamodule(model: LightningModule, datamodule: LightningDataModule):
+    """Assings all the function needed to fit() model without passing datamodule to fit
+
+    Args:
+        model (LightningModule): Target model
+        datamodule (LightningDataModule): Source datamodule
+    """
+
+    model.setup = juxt(model.setup, datamodule.setup)
+    model.train_dataloader = datamodule.train_dataloader
+    model.val_dataloader = datamodule.val_dataloader
+    model.test_dataloader = datamodule.test_dataloader
+
+
+def ifnone(x: Optional[Any], default: Any) -> Any:
+    """Returns a default value if x is None
+
+    Args:
+        x (Optional[Any]): Potentially None value
+        default (Any): Default value for the None case
+
+    Returns:
+        Any: x if it is not None, otherwise default
+    """
+    return default if x is None else x
+
+
+def onecycle_shed(method):
+    def configure_optimizers(self):
+        optimizer = method(self)
+
+        batches_per_epoch = self.trainer.num_training_batches
+
+        if (
+            batches_per_epoch > 0
+            and self.learning_rate_min is not None
+            and self.learning_rate_max is not None
+        ):
+            num_up = batches_per_epoch // 2
+            num_down = batches_per_epoch - num_up
+
+            scheduler = torch.optim.lr_scheduler.CyclicLR(
+                optimizer,
+                self.learning_rate_min,
+                self.learning_rate_max,
+                step_size_up=num_up,
+                step_size_down=num_down,
+                cycle_momentum=False,
+            )
+
+            print("Using OneCycle LR scheduler")
+
+            return [optimizer], [scheduler]
+        else:
+            return optimizer
+
+    return configure_optimizers
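
How the pieces above compose: `LRFinder` runs an LR sweep on fit start, `get_cycling_loss_range` extracts a cyclical-LR band from the sweep curve, and `onecycle_shed` wraps a module's `configure_optimizers` to turn that band into a `CyclicLR` schedule. A synthetic sketch of the range extraction (the loss curve is fabricated; the function operates on the per-sample gradient of the curve, so the drop must be steep in index space):

```python
import numpy as np

from src.utils import get_cycling_loss_range

# Fabricated LR-sweep results: loss falls steeply, flattens, then diverges.
lr = np.logspace(-5, -1, 10)
loss = np.array([5.0, 4.0, 3.0, 2.0, 1.0, 0.5, 0.4, 0.38, 0.39, 0.6])

lr_min, lr_max = get_cycling_loss_range(lr, loss)
print(lr_min, lr_max)  # band later consumed by the CyclicLR scheduler
```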
diff --git a/tests/models/test_actions_model_ensemble.py b/tests/models/test_actions_model_ensemble.py
new file mode 100644
index 0000000000000000000000000000000000000000..c49d1addbe0ddf1d132222e03e9c97e930bfb6a5
--- /dev/null
+++ b/tests/models/test_actions_model_ensemble.py
@@ -0,0 +1,37 @@
+import numpy as np
+from transformers.tokenization_bert import BertTokenizerFast
+
+from src.models.actions_model_ensemble import ActionsModelEnsemble
+
+
+def test_dimensions():
+    base_model = "dkleczek/bert-base-polish-cased-v1"
+    action_vector_size = 5
+
+    tokens = BertTokenizerFast.from_pretrained(base_model)(
+        "Ala ma kota", return_tensors="pt"
+    )
+
+    num_models = 2
+
+    loss_weights = np.ones(action_vector_size)
+    model = ActionsModelEnsemble(
+        base_model, loss_weights, 1e-4, num_models, 10, 2, 10, action_vector_size
+    )
+    result = model(tokens["input_ids"], tokens["attention_mask"] == 0, 0)
+
+    assert len(result.shape) == 4
+
+    assert result.shape[0] == num_models
+    assert result.shape[1] == tokens["input_ids"].shape[0]
+    assert result.shape[2] == tokens["input_ids"].shape[1]
+    assert result.shape[3] == action_vector_size
+
+    # Test negative step
+    result = model(tokens["input_ids"], tokens["attention_mask"] == 0, -1)
+    assert len(result.shape) == 4
+
+    assert result.shape[0] == num_models
+    assert result.shape[1] == tokens["input_ids"].shape[0]
+    assert result.shape[2] == tokens["input_ids"].shape[1]
+    assert result.shape[3] == action_vector_size
diff --git a/tests/models/test_actions_model_base.py b/tests/models/test_actions_model_base.py
index 900cf89a8d9918d64e46c60d19c89bee75e9e0d8..ccdd7760137b7a566df22755fbd409f556cede52 100644
--- a/tests/models/test_actions_model_base.py
+++ b/tests/models/test_actions_model_base.py
@@ -1,12 +1,7 @@
-import torch
+import numpy as np
 from transformers.tokenization_bert import BertTokenizerFast
 
-from src.models.actions_model_base import (
-    ActionsModelBase,
-    ActionsModelBaseLoss,
-    ActionsModelBaseParams,
-)
-from src.pipelines.actions_based.processing import ACTIONS_KEYS
+from src.models.actions_model_base import ActionsModelBase
 
 
 def test_dimensions():
@@ -17,8 +12,8 @@ def test_dimensions():
         "Ala ma kota", return_tensors="pt"
     )
 
-    params = ActionsModelBaseParams(base_model, action_vector_size)
-    model = ActionsModelBase(params)
+    loss_weights = np.ones(action_vector_size)
+    model = ActionsModelBase(base_model, action_vector_size, loss_weights=loss_weights)
 
     result = model(tokens["input_ids"], tokens["attention_mask"])
 
@@ -27,33 +22,3 @@ def test_dimensions():
     assert result.shape[0] == tokens["input_ids"].shape[0]
     assert result.shape[1] == tokens["input_ids"].shape[1]
     assert result.shape[2] == action_vector_size
-
-
-def test_loss_dimensions():
-    batch_size = 5
-    sequence_len = 10
-    actions_size = 3
-    weights = torch.zeros(actions_size) + 0.3
-    actions_vector_true = torch.zeros((batch_size, sequence_len, actions_size))
-    actions_vector_bad = torch.ones((batch_size, sequence_len, actions_size))
-    loss = ActionsModelBaseLoss(weights)
-
-    result = loss(actions_vector_bad, actions_vector_true)
-    assert len(result.shape) == 0
-
-    result_perfect = loss(actions_vector_true, actions_vector_true)
-    result_bad = loss(actions_vector_bad, actions_vector_true)
-
-    assert result_perfect < result_bad
-
-
-def test_predict():
-    params = ActionsModelBaseParams(
-        "dkleczek/bert-base-polish-cased-v1", len(ACTIONS_KEYS)
-    )
-    model = ActionsModelBase(params)
-
-    input_str = "testowy ciag znakow"
-    result = model.predict(input_str)
-
-    assert len(result) >= len(input_str)
diff --git a/tests/models/test_actions_model_mixed.py b/tests/models/test_actions_model_mixed.py
index 786136dbe6abdb0a09e1dfd3413df8fe485feac7..59f593fb213a3a9a1821fcfba18eff1978e16a53 100644
--- a/tests/models/test_actions_model_mixed.py
+++ b/tests/models/test_actions_model_mixed.py
@@ -1,13 +1,8 @@
+import numpy as np
 import torch
 from transformers.tokenization_bert import BertTokenizerFast
 
-from src.models.actions_model_mixed import (
-    ActionsModelMixed,
-    ActionsModelMixedLoss,
-    ActionsModelMixedParams,
-    ActionsModelMixedRuntimeParams,
-)
-from src.pipelines.actions_based.processing import ACTIONS_KEYS
+from src.models.actions_model_mixed import ActionsModelMixed
 
 
 def test_dimensions():
@@ -18,17 +13,18 @@ def test_dimensions():
     tokens = tokenizer("Ala ma kota", return_tensors="pt")
 
     embedding_size = 20
-    threshold = 0.9
     num_heads = 2
     num_layers = 2
     feedforward_neurons = 10
     max_len = 500
     dropout = 0.1
 
-    params = ActionsModelMixedParams(
+    weights = np.ones(action_vector_size)
+
+    model = ActionsModelMixed(
         base_model,
-        tokenizer.vocab_size,
-        threshold,
+        weights,
+        1e-4,
         embedding_size,
         num_heads,
         num_layers,
@@ -37,58 +33,16 @@ def test_dimensions():
         max_len,
         dropout,
     )
-    model = ActionsModelMixed(params)
 
     actions_len = 3
     actions = torch.distributions.Multinomial(
         1, torch.tensor([0.5] * action_vector_size)
     ).sample((tokens["input_ids"].shape[0], actions_len))
 
-    result = model(tokens["input_ids"], actions, tokens["attention_mask"])
+    result = model(tokens["input_ids"], actions, tokens["attention_mask"] == 0)
 
     assert len(result.shape) == 3
 
     assert result.shape[0] == tokens["input_ids"].shape[0]
     assert result.shape[1] == actions_len
     assert result.shape[2] == action_vector_size
-
-
-def test_loss_dimensions():
-    batch_size = 5
-    sequence_len = 10
-    actions_size = 3
-    prior_odds = torch.zeros(actions_size) + 0.3
-    actions_vector_true = torch.zeros((batch_size, sequence_len, actions_size))
-    actions_vector_bad = torch.ones((batch_size, sequence_len, actions_size))
-    loss = ActionsModelMixedLoss(prior_odds)
-
-    result = loss(actions_vector_true, actions_vector_bad)
-    assert len(result.shape) == 0
-
-    result_perfect = loss(actions_vector_true, actions_vector_true)
-    result_bad = loss(actions_vector_true, actions_vector_bad)
-
-    assert result_perfect < result_bad
-
-
-def test_predict():
-    tokenizer = BertTokenizerFast.from_pretrained("dkleczek/bert-base-polish-cased-v1")
-    params = ActionsModelMixedParams(
-        "dkleczek/bert-base-polish-cased-v1",
-        tokenizer.vocab_size,
-        0.9,
-        10,
-        2,
-        1,
-        10,
-        len(ACTIONS_KEYS),
-        500,
-        0.1,
-    )
-    runtime = ActionsModelMixedRuntimeParams(0.9, 100)
-    model = ActionsModelMixed(params, runtime)
-
-    input_str = "testowy ciag znakow"
-    result = model.predict(input_str)
-
-    assert len(result) >= len(input_str)
diff --git a/tests/models/test_actions_model_restricted.py b/tests/models/test_actions_model_restricted.py
index b659b271d12bd88f00f9ebfbb5b01ae868fcb735..8b46881e4e4584e130fd870e6ed8c4495ab91eb2 100644
--- a/tests/models/test_actions_model_restricted.py
+++ b/tests/models/test_actions_model_restricted.py
@@ -1,11 +1,7 @@
-import torch
+import numpy as np
 from transformers.tokenization_bert import BertTokenizerFast
 
-from src.models.actions_model_restricted import (
-    ActionsModelRestricted,
-    ActionsModelRestrictedLoss,
-    ActionsModelRestrictedParams,
-)
+from src.models.actions_model_restricted import ActionsModelRestricted
 from src.pipelines.actions_based.processing import ACTIONS_KEYS
 
 
@@ -17,8 +13,15 @@ def test_dimensions():
         "Ala ma kota", return_tensors="pt"
     )
 
-    params = ActionsModelRestrictedParams(base_model, action_vector_size)
-    model = ActionsModelRestricted(params)
+    num_labels = len(ACTIONS_KEYS) + 1
+    uppercase_weights = np.ones(1)
+    punc_weights = np.ones(num_labels - 1)
+    model = ActionsModelRestricted(
+        "dkleczek/bert-base-polish-cased-v1",
+        num_labels,
+        uppercase_weights,
+        punc_weights,
+    )
 
     result = model(tokens["input_ids"], tokens["attention_mask"])
 
@@ -27,48 +30,3 @@ def test_dimensions():
     assert result.shape[0] == tokens["input_ids"].shape[0]
     assert result.shape[1] == tokens["input_ids"].shape[1]
     assert result.shape[2] == action_vector_size
-
-
-def test_loss_dimensions():
-    batch_size = 5
-    sequence_len = 10
-    action_vector_size = 4
-    uppercase_odds = torch.tensor(0.3, dtype=torch.float)
-    punctuation_weights = torch.tensor([0.3, 0.3, 0.1], dtype=torch.float)
-    loss = ActionsModelRestrictedLoss(uppercase_odds, punctuation_weights)
-
-    actions_vector_true = torch.zeros(
-        (batch_size, sequence_len, action_vector_size), dtype=torch.float
-    )
-    actions_vector_true[:, :, -1] = 1.0
-
-    actions_vector_bad = torch.zeros(
-        (batch_size, sequence_len, action_vector_size), dtype=torch.float
-    )
-    actions_vector_bad[:, :, :2] = 1.0
-    actions_vector_bad[:, :, -1] = 0.0
-
-    result = loss(actions_vector_true, actions_vector_bad)
-    assert len(result.shape) == 0
-
-    result_perfect = loss(actions_vector_true, actions_vector_true)
-    result_bad = loss(actions_vector_true, actions_vector_bad)
-
-    print(result_perfect)
-    print(result_bad)
-
-    assert result_perfect < result_bad
-    assert result_perfect > 0
-    assert result_bad > 0
-
-
-def test_predict():
-    params = ActionsModelRestrictedParams(
-        "dkleczek/bert-base-polish-cased-v1", len(ACTIONS_KEYS) + 1
-    )
-    model = ActionsModelRestricted(params)
-
-    input_str = "testowy ciag znakow"
-    result = model.predict(input_str)
-
-    assert len(result) >= len(input_str)
diff --git a/tests/predictors/__init__.py b/tests/predictors/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/tests/predictors/test_actions_model_base.py b/tests/predictors/test_actions_model_base.py
new file mode 100644
index 0000000000000000000000000000000000000000..dfe7f353b2ad6616441b4e2d820ba41a5b7ed604
--- /dev/null
+++ b/tests/predictors/test_actions_model_base.py
@@ -0,0 +1,19 @@
+import numpy as np
+
+from src.models.actions_model_base import ActionsModelBase
+from src.pipelines.actions_based.processing import ACTIONS_KEYS
+from src.predictors.actions_model_base import ActionsModelBasePredictor
+
+
+def test_predict():
+    num_labels = len(ACTIONS_KEYS)
+    loss_weights = np.ones(num_labels)
+    model = ActionsModelBase(
+        "dkleczek/bert-base-polish-cased-v1", num_labels, 1e-4, loss_weights
+    )
+    predictor = ActionsModelBasePredictor(model)
+
+    input_txt = "ala ma kota a kot ma ale"
+    output_txt = predictor.predict(input_txt)
+
+    assert len(output_txt) >= len(input_txt)
diff --git a/tests/predictors/test_actions_model_ensemble.py b/tests/predictors/test_actions_model_ensemble.py
new file mode 100644
index 0000000000000000000000000000000000000000..a2628b5b5475d1d2ef0130f72d2f689fbb1c0766
--- /dev/null
+++ b/tests/predictors/test_actions_model_ensemble.py
@@ -0,0 +1,26 @@
+import numpy as np
+
+from src.models.actions_model_ensemble import ActionsModelEnsemble
+from src.pipelines.actions_based.processing import ACTIONS_KEYS
+from src.predictors.actions_model_ensemble import ActionsModelEnsemblePredictor
+
+
+def test_predict():
+    num_labels = len(ACTIONS_KEYS)
+    loss_weights = np.ones(num_labels)
+    model = ActionsModelEnsemble(
+        "dkleczek/bert-base-polish-cased-v1",
+        loss_weights,
+        1e-4,
+        2,
+        10,
+        2,
+        10,
+        num_labels,
+    )
+    predictor = ActionsModelEnsemblePredictor(model)
+
+    input_txt = "ala ma kota a kot ma ale"
+    output_txt = predictor.predict(input_txt)
+
+    assert len(output_txt) >= len(input_txt)
diff --git a/tests/predictors/test_actions_model_mixed.py b/tests/predictors/test_actions_model_mixed.py
new file mode 100644
index 0000000000000000000000000000000000000000..a611f6aa38c6c48410780f8a95821e489338dd78
--- /dev/null
+++ b/tests/predictors/test_actions_model_mixed.py
@@ -0,0 +1,27 @@
+import numpy as np
+
+from src.models.actions_model_mixed import ActionsModelMixed
+from src.pipelines.actions_based.processing import ACTIONS_KEYS
+from src.predictors.actions_model_mixed import ActionsModelMixedPredictor
+
+
+def test_predict():
+    num_labels = len(ACTIONS_KEYS)
+    loss_weights = np.ones(num_labels)
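+    # The small (presumed) transformer dimensions keep this smoke test fast on CPU.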
+    model = ActionsModelMixed(
+        "dkleczek/bert-base-polish-cased-v1",
+        loss_weights,
+        1e-4,
+        10,
+        2,
+        2,
+        10,
+        num_labels,
+    )
+    predictor = ActionsModelMixedPredictor(model)
+
+    input_txt = "ala ma kota a kot ma ale"
+    output_txt = predictor.predict(input_txt)
+
+    assert len(output_txt) >= len(input_txt)
diff --git a/tests/predictors/test_actions_model_restricted.py b/tests/predictors/test_actions_model_restricted.py
new file mode 100644
index 0000000000000000000000000000000000000000..faccc671004eaacd1de5eef756253660c65955d9
--- /dev/null
+++ b/tests/predictors/test_actions_model_restricted.py
@@ -0,0 +1,28 @@
+import numpy as np
+
+from src.models.actions_model_restricted import ActionsModelRestricted
+from src.pipelines.actions_based.processing import ACTIONS_KEYS
+from src.predictors.actions_model_restricted import ActionsModelRestrictedPredictor
+
+
+def test_predict():
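+    # The restricted model uses one class more than the base action set,
+    # presumably the "no punctuation" slot, hence the +1.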
+    num_labels = len(ACTIONS_KEYS) + 1
+
+    uppercase_weight = np.ones(1)
+    punctuation_weight = np.ones(num_labels - 1)
+
+    model = ActionsModelRestricted(
+        "dkleczek/bert-base-polish-cased-v1",
+        num_labels,
+        uppercase_weight,
+        punctuation_weight,
+        1e-4,
+    )
+    predictor = ActionsModelRestrictedPredictor(model)
+
+    input_txt = "ala ma kota a kot ma ale"
+    output_txt = predictor.predict(input_txt)
+
+    assert len(output_txt) >= len(input_txt)
diff --git a/tox.ini b/tox.ini
index 02ec5a04fd4446f894bc517a9ca760c2b5a46fc4..f4436b8ccf6cec9dd5267b016dad0a470ed67cc6 100644
--- a/tox.ini
+++ b/tox.ini
@@ -39,3 +39,8 @@ basepython = python
 commands =
     flake8 {posargs}
 
+
+[pytest]
+filterwarnings =
+    ignore::DeprecationWarning
+    ignore::PendingDeprecationWarning
diff --git a/train.sh b/train.sh
index 3d7da2d1b0ab2910fe3cffd72230bd74a748fcfe..d8a88122dcab1900d77c5ae5a8aae55e409b6e8e 100755
--- a/train.sh
+++ b/train.sh
@@ -6,4 +6,5 @@
 DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
 
 docker build . -f ./docker/training/Dockerfile -t clarinpl/punctuator_training --build-arg USERNAME=$(whoami) --build-arg USER_UID=$(id -u) --build-arg USER_GID=$(id -u) && \
-docker run -v $DIR:/punctuator --name $2 --gpus all -it --entrypoint python clarinpl/punctuator_training -m $1
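+# Forward the Comet API key from the host environment into the container.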
+docker run -v $DIR:/punctuator -e COMET_API_KEY --name $2 --gpus all -it --entrypoint python clarinpl/punctuator_training -m src.pipelines.actions_based.train -m $1 -n $2
diff --git a/worker.py b/worker.py
index 98e5a75385c110f3f236afe9f504b35c7a10616a..766a801d9cfe97c079db21472de5bf6ee4dd44e0 100755
--- a/worker.py
+++ b/worker.py
@@ -2,6 +2,7 @@
 
 import configparser
 import logging
+from src.predictors.factory import ACTION_MODEL_PREDICTORS
 from src.models.model_factory import MODELS_MAP
 from typing import List
 
@@ -31,10 +32,14 @@ class Worker(nlp_ws.NLPWorker):
 
     def _load_models(self, models_list: List[str]):
         for model_type in models_list:
-            self.models[model_type] = MODELS_MAP[model_type].load(
-                f"{self.models_dir}/{model_type}", "production", self.device
+            model = MODELS_MAP[model_type].load_from_checkpoint(
+                f"{self.models_dir}/{model_type}/production.ckpt",
+                map_location=self.device,
             )
-            self.models[model_type].train(False)
+            model.train(False)
+
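+            # Wrap the raw model in its predictor, which exposes the text-to-text predict().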
+            self.models[model_type] = ACTION_MODEL_PREDICTORS[model_type](model)
 
     def process(self, input_file: str, task_options: dict, output_file: str) -> None:
         """Implementation of example tasks that copies files."""