From 67199a97dabd3580b2100987ba3864ba4a0a0f3c Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Thu, 9 Jul 2020 16:53:19 +0200
Subject: [PATCH 001/116] Initial commit

---
 .gitignore                                |   1 +
 dataset_generation/.gitignore             |   3 +
 dataset_generation/notebook_actions.ipynb | 150 +++++++++++++++++++++
 dataset_generation/notebook_simple.ipynb  | 152 ++++++++++++++++++++++
 dataset_generation/processing.py          | 149 +++++++++++++++++++++
 dataset_generation/utils.py               |  14 ++
 6 files changed, 469 insertions(+)
 create mode 100644 .gitignore
 create mode 100644 dataset_generation/.gitignore
 create mode 100644 dataset_generation/notebook_actions.ipynb
 create mode 100644 dataset_generation/notebook_simple.ipynb
 create mode 100644 dataset_generation/processing.py
 create mode 100644 dataset_generation/utils.py

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..fcff6f3
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1 @@
+dane/**
diff --git a/dataset_generation/.gitignore b/dataset_generation/.gitignore
new file mode 100644
index 0000000..b43b91b
--- /dev/null
+++ b/dataset_generation/.gitignore
@@ -0,0 +1,3 @@
+dataset_simple
+dataset_actions
+__pycache__
\ No newline at end of file
diff --git a/dataset_generation/notebook_actions.ipynb b/dataset_generation/notebook_actions.ipynb
new file mode 100644
index 0000000..aa2522b
--- /dev/null
+++ b/dataset_generation/notebook_actions.ipynb
@@ -0,0 +1,150 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "%load_ext autoreload\n",
+    "%autoreload 2"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import glob\n",
+    "import random\n",
+    "from lxml import etree\n",
+    "import uuid\n",
+    "import hashlib\n",
+    "import seaborn as sns\n",
+    "import re\n",
+    "import numpy as np\n",
+    "from tqdm import tqdm\n",
+    "\n",
+    "from processing import text_from_xml, create_model_input_output\n",
+    "from utils import remove_multiple_spaces, remove_punctuation"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "file_schema = \"../dane/**/text_structure.xml\"\n",
+    "files_paths = glob.glob(file_schema, recursive=True)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "files_subset = random.sample(files_paths, 1_000)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [
+    {
+     "output_type": "stream",
+     "name": "stderr",
+     "text": "100%|██████████| 1000/1000 [00:09<00:00, 109.49it/s]\n"
+    }
+   ],
+   "source": [
+    "num_exported = 0\n",
+    "document_lens = []\n",
+    "\n",
+    "if not os.path.exists(\"dataset_actions\"):\n",
+    "    os.mkdir(\"dataset_actions\")\n",
+    "\n",
+    "for file_path in tqdm(files_subset):\n",
+    "    full_text = text_from_xml(file_path)\n",
+    "    \n",
+    "    if len(full_text) > 0:\n",
+    "        output_file_input = f\"dataset_actions/{hashlib.md5(file_path.encode()).hexdigest()}_input.txt\"\n",
+    "        output_file_output = f\"dataset_actions/{hashlib.md5(file_path.encode()).hexdigest()}_output.txt\"\n",
+    "\n",
+    "        model_input, model_output = create_model_input_output(full_text)\n",
+    "\n",
+    "        with open(output_file_input, \"w\") as f:\n",
+    "            f.write(model_input)\n",
+    "            num_exported += 1\n",
+    "            document_lens.append(len(full_text))\n",
+    "\n",
+    "        with open(output_file_output, 'w') as f:\n",
+    "            f.write(str(model_output.tolist()))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [
+    {
+     "output_type": "execute_result",
+     "data": {
+      "text/plain": "<matplotlib.axes._subplots.AxesSubplot at 0x7f418d4751f0>"
+     },
+     "metadata": {},
+     "execution_count": 6
+    },
+    {
+     "output_type": "display_data",
+     "data": {
+      "text/plain": "<Figure size 432x288 with 1 Axes>",
+      "image/svg+xml": "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n  \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n<!-- Created with matplotlib (https://matplotlib.org/) -->\n<svg height=\"249.103945pt\" version=\"1.1\" viewBox=\"0 0 368.925 249.103945\" width=\"368.925pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n <defs>\n  <style type=\"text/css\">\n*{stroke-linecap:butt;stroke-linejoin:round;}\n  </style>\n </defs>\n <g id=\"figure_1\">\n  <g id=\"patch_1\">\n   <path d=\"M -0 249.103945 \nL 368.925 249.103945 \nL 368.925 0 \nL -0 0 \nz\n\" style=\"fill:none;\"/>\n  </g>\n  <g id=\"axes_1\">\n   <g id=\"patch_2\">\n    <path d=\"M 26.925 225.22582 \nL 361.725 225.22582 \nL 361.725 7.78582 \nL 26.925 7.78582 \nz\n\" style=\"fill:#ffffff;\"/>\n   </g>\n   <g id=\"patch_3\">\n    <path clip-path=\"url(#p1f2e03f3c1)\" d=\"M 42.143182 225.22582 \nL 53.849476 225.22582 \nL 53.849476 32.421879 \nL 42.143182 32.421879 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_4\">\n    <path clip-path=\"url(#p1f2e03f3c1)\" d=\"M 53.849476 225.22582 \nL 65.555769 225.22582 \nL 65.555769 18.140105 \nL 53.849476 18.140105 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_5\">\n    <path clip-path=\"url(#p1f2e03f3c1)\" d=\"M 65.555769 225.22582 \nL 77.262063 225.22582 \nL 77.262063 103.830746 \nL 65.555769 103.830746 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_6\">\n    <path clip-path=\"url(#p1f2e03f3c1)\" d=\"M 77.262063 225.22582 \nL 88.968357 225.22582 \nL 88.968357 118.112519 \nL 77.262063 118.112519 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_7\">\n    <path clip-path=\"url(#p1f2e03f3c1)\" d=\"M 88.968357 225.22582 \nL 100.67465 225.22582 \nL 100.67465 189.521386 \nL 88.968357 189.521386 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_8\">\n    <path clip-path=\"url(#p1f2e03f3c1)\" d=\"M 100.67465 225.22582 \nL 112.380944 225.22582 \nL 112.380944 203.80316 \nL 100.67465 203.80316 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_9\">\n    <path clip-path=\"url(#p1f2e03f3c1)\" d=\"M 112.380944 225.22582 \nL 124.087238 225.22582 \nL 124.087238 203.80316 \nL 112.380944 203.80316 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_10\">\n    <path clip-path=\"url(#p1f2e03f3c1)\" d=\"M 124.087238 225.22582 \nL 135.793531 225.22582 \nL 135.793531 218.084933 \nL 124.087238 218.084933 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_11\">\n    <path clip-path=\"url(#p1f2e03f3c1)\" d=\"M 135.793531 225.22582 \nL 147.499825 225.22582 \nL 147.499825 203.80316 \nL 135.793531 203.80316 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_12\">\n    <path clip-path=\"url(#p1f2e03f3c1)\" d=\"M 147.499825 225.22582 \nL 159.206119 225.22582 \nL 159.206119 225.22582 \nL 147.499825 225.22582 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_13\">\n    <path clip-path=\"url(#p1f2e03f3c1)\" d=\"M 159.206119 225.22582 \nL 170.912413 225.22582 \nL 170.912413 225.22582 \nL 159.206119 225.22582 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_14\">\n    <path clip-path=\"url(#p1f2e03f3c1)\" d=\"M 170.912413 225.22582 \nL 182.618706 225.22582 \nL 182.618706 225.22582 \nL 170.912413 225.22582 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n  
 <g id=\"patch_15\">\n    <path clip-path=\"url(#p1f2e03f3c1)\" d=\"M 182.618706 225.22582 \nL 194.325 225.22582 \nL 194.325 218.084933 \nL 182.618706 218.084933 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_16\">\n    <path clip-path=\"url(#p1f2e03f3c1)\" d=\"M 194.325 225.22582 \nL 206.031294 225.22582 \nL 206.031294 225.22582 \nL 194.325 225.22582 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_17\">\n    <path clip-path=\"url(#p1f2e03f3c1)\" d=\"M 206.031294 225.22582 \nL 217.737587 225.22582 \nL 217.737587 210.944046 \nL 206.031294 210.944046 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_18\">\n    <path clip-path=\"url(#p1f2e03f3c1)\" d=\"M 217.737587 225.22582 \nL 229.443881 225.22582 \nL 229.443881 225.22582 \nL 217.737587 225.22582 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_19\">\n    <path clip-path=\"url(#p1f2e03f3c1)\" d=\"M 229.443881 225.22582 \nL 241.150175 225.22582 \nL 241.150175 225.22582 \nL 229.443881 225.22582 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_20\">\n    <path clip-path=\"url(#p1f2e03f3c1)\" d=\"M 241.150175 225.22582 \nL 252.856469 225.22582 \nL 252.856469 225.22582 \nL 241.150175 225.22582 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_21\">\n    <path clip-path=\"url(#p1f2e03f3c1)\" d=\"M 252.856469 225.22582 \nL 264.562762 225.22582 \nL 264.562762 218.084933 \nL 252.856469 218.084933 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_22\">\n    <path clip-path=\"url(#p1f2e03f3c1)\" d=\"M 264.562762 225.22582 \nL 276.269056 225.22582 \nL 276.269056 218.084933 \nL 264.562762 218.084933 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_23\">\n    <path clip-path=\"url(#p1f2e03f3c1)\" d=\"M 276.269056 225.22582 \nL 287.97535 225.22582 \nL 287.97535 225.22582 \nL 276.269056 225.22582 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_24\">\n    <path clip-path=\"url(#p1f2e03f3c1)\" d=\"M 287.97535 225.22582 \nL 299.681643 225.22582 \nL 299.681643 218.084933 \nL 287.97535 218.084933 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_25\">\n    <path clip-path=\"url(#p1f2e03f3c1)\" d=\"M 299.681643 225.22582 \nL 311.387937 225.22582 \nL 311.387937 225.22582 \nL 299.681643 225.22582 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_26\">\n    <path clip-path=\"url(#p1f2e03f3c1)\" d=\"M 311.387937 225.22582 \nL 323.094231 225.22582 \nL 323.094231 225.22582 \nL 311.387937 225.22582 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_27\">\n    <path clip-path=\"url(#p1f2e03f3c1)\" d=\"M 323.094231 225.22582 \nL 334.800524 225.22582 \nL 334.800524 225.22582 \nL 323.094231 225.22582 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_28\">\n    <path clip-path=\"url(#p1f2e03f3c1)\" d=\"M 334.800524 225.22582 \nL 346.506818 225.22582 \nL 346.506818 218.084933 \nL 334.800524 218.084933 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"matplotlib.axis_1\">\n    <g id=\"xtick_1\">\n     <g id=\"line2d_1\">\n      <defs>\n       <path d=\"M 0 0 \nL 0 3.5 \n\" id=\"mcb448f674c\" style=\"stroke:#000000;stroke-width:0.8;\"/>\n      </defs>\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"41.035222\" xlink:href=\"#mcb448f674c\" y=\"225.22582\"/>\n      </g>\n     </g>\n     <g id=\"text_1\">\n      <!-- 0 -->\n      <defs>\n  
     <path d=\"M 31.78125 66.40625 \nQ 24.171875 66.40625 20.328125 58.90625 \nQ 16.5 51.421875 16.5 36.375 \nQ 16.5 21.390625 20.328125 13.890625 \nQ 24.171875 6.390625 31.78125 6.390625 \nQ 39.453125 6.390625 43.28125 13.890625 \nQ 47.125 21.390625 47.125 36.375 \nQ 47.125 51.421875 43.28125 58.90625 \nQ 39.453125 66.40625 31.78125 66.40625 \nz\nM 31.78125 74.21875 \nQ 44.046875 74.21875 50.515625 64.515625 \nQ 56.984375 54.828125 56.984375 36.375 \nQ 56.984375 17.96875 50.515625 8.265625 \nQ 44.046875 -1.421875 31.78125 -1.421875 \nQ 19.53125 -1.421875 13.0625 8.265625 \nQ 6.59375 17.96875 6.59375 36.375 \nQ 6.59375 54.828125 13.0625 64.515625 \nQ 19.53125 74.21875 31.78125 74.21875 \nz\n\" id=\"DejaVuSans-48\"/>\n      </defs>\n      <g transform=\"translate(37.853972 239.824257)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-48\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"xtick_2\">\n     <g id=\"line2d_2\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"90.853559\" xlink:href=\"#mcb448f674c\" y=\"225.22582\"/>\n      </g>\n     </g>\n     <g id=\"text_2\">\n      <!-- 100000 -->\n      <defs>\n       <path d=\"M 12.40625 8.296875 \nL 28.515625 8.296875 \nL 28.515625 63.921875 \nL 10.984375 60.40625 \nL 10.984375 69.390625 \nL 28.421875 72.90625 \nL 38.28125 72.90625 \nL 38.28125 8.296875 \nL 54.390625 8.296875 \nL 54.390625 0 \nL 12.40625 0 \nz\n\" id=\"DejaVuSans-49\"/>\n      </defs>\n      <g transform=\"translate(71.766059 239.824257)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-49\"/>\n       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"127.246094\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"190.869141\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"254.492188\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"318.115234\" xlink:href=\"#DejaVuSans-48\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"xtick_3\">\n     <g id=\"line2d_3\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"140.671896\" xlink:href=\"#mcb448f674c\" y=\"225.22582\"/>\n      </g>\n     </g>\n     <g id=\"text_3\">\n      <!-- 200000 -->\n      <defs>\n       <path d=\"M 19.1875 8.296875 \nL 53.609375 8.296875 \nL 53.609375 0 \nL 7.328125 0 \nL 7.328125 8.296875 \nQ 12.9375 14.109375 22.625 23.890625 \nQ 32.328125 33.6875 34.8125 36.53125 \nQ 39.546875 41.84375 41.421875 45.53125 \nQ 43.3125 49.21875 43.3125 52.78125 \nQ 43.3125 58.59375 39.234375 62.25 \nQ 35.15625 65.921875 28.609375 65.921875 \nQ 23.96875 65.921875 18.8125 64.3125 \nQ 13.671875 62.703125 7.8125 59.421875 \nL 7.8125 69.390625 \nQ 13.765625 71.78125 18.9375 73 \nQ 24.125 74.21875 28.421875 74.21875 \nQ 39.75 74.21875 46.484375 68.546875 \nQ 53.21875 62.890625 53.21875 53.421875 \nQ 53.21875 48.921875 51.53125 44.890625 \nQ 49.859375 40.875 45.40625 35.40625 \nQ 44.1875 33.984375 37.640625 27.21875 \nQ 31.109375 20.453125 19.1875 8.296875 \nz\n\" id=\"DejaVuSans-50\"/>\n      </defs>\n      <g transform=\"translate(121.584396 239.824257)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-50\"/>\n       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"127.246094\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"190.869141\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"254.492188\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"318.115234\" xlink:href=\"#DejaVuSans-48\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"xtick_4\">\n     <g id=\"line2d_4\">\n      <g>\n       <use 
style=\"stroke:#000000;stroke-width:0.8;\" x=\"190.490233\" xlink:href=\"#mcb448f674c\" y=\"225.22582\"/>\n      </g>\n     </g>\n     <g id=\"text_4\">\n      <!-- 300000 -->\n      <defs>\n       <path d=\"M 40.578125 39.3125 \nQ 47.65625 37.796875 51.625 33 \nQ 55.609375 28.21875 55.609375 21.1875 \nQ 55.609375 10.40625 48.1875 4.484375 \nQ 40.765625 -1.421875 27.09375 -1.421875 \nQ 22.515625 -1.421875 17.65625 -0.515625 \nQ 12.796875 0.390625 7.625 2.203125 \nL 7.625 11.71875 \nQ 11.71875 9.328125 16.59375 8.109375 \nQ 21.484375 6.890625 26.8125 6.890625 \nQ 36.078125 6.890625 40.9375 10.546875 \nQ 45.796875 14.203125 45.796875 21.1875 \nQ 45.796875 27.640625 41.28125 31.265625 \nQ 36.765625 34.90625 28.71875 34.90625 \nL 20.21875 34.90625 \nL 20.21875 43.015625 \nL 29.109375 43.015625 \nQ 36.375 43.015625 40.234375 45.921875 \nQ 44.09375 48.828125 44.09375 54.296875 \nQ 44.09375 59.90625 40.109375 62.90625 \nQ 36.140625 65.921875 28.71875 65.921875 \nQ 24.65625 65.921875 20.015625 65.03125 \nQ 15.375 64.15625 9.8125 62.3125 \nL 9.8125 71.09375 \nQ 15.4375 72.65625 20.34375 73.4375 \nQ 25.25 74.21875 29.59375 74.21875 \nQ 40.828125 74.21875 47.359375 69.109375 \nQ 53.90625 64.015625 53.90625 55.328125 \nQ 53.90625 49.265625 50.4375 45.09375 \nQ 46.96875 40.921875 40.578125 39.3125 \nz\n\" id=\"DejaVuSans-51\"/>\n      </defs>\n      <g transform=\"translate(171.402733 239.824257)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-51\"/>\n       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"127.246094\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"190.869141\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"254.492188\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"318.115234\" xlink:href=\"#DejaVuSans-48\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"xtick_5\">\n     <g id=\"line2d_5\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"240.308571\" xlink:href=\"#mcb448f674c\" y=\"225.22582\"/>\n      </g>\n     </g>\n     <g id=\"text_5\">\n      <!-- 400000 -->\n      <defs>\n       <path d=\"M 37.796875 64.3125 \nL 12.890625 25.390625 \nL 37.796875 25.390625 \nz\nM 35.203125 72.90625 \nL 47.609375 72.90625 \nL 47.609375 25.390625 \nL 58.015625 25.390625 \nL 58.015625 17.1875 \nL 47.609375 17.1875 \nL 47.609375 0 \nL 37.796875 0 \nL 37.796875 17.1875 \nL 4.890625 17.1875 \nL 4.890625 26.703125 \nz\n\" id=\"DejaVuSans-52\"/>\n      </defs>\n      <g transform=\"translate(221.221071 239.824257)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-52\"/>\n       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"127.246094\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"190.869141\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"254.492188\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"318.115234\" xlink:href=\"#DejaVuSans-48\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"xtick_6\">\n     <g id=\"line2d_6\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"290.126908\" xlink:href=\"#mcb448f674c\" y=\"225.22582\"/>\n      </g>\n     </g>\n     <g id=\"text_6\">\n      <!-- 500000 -->\n      <defs>\n       <path d=\"M 10.796875 72.90625 \nL 49.515625 72.90625 \nL 49.515625 64.59375 \nL 19.828125 64.59375 \nL 19.828125 46.734375 \nQ 21.96875 47.46875 24.109375 47.828125 \nQ 26.265625 48.1875 28.421875 48.1875 \nQ 40.625 48.1875 47.75 41.5 \nQ 54.890625 34.8125 54.890625 23.390625 \nQ 54.890625 11.625 47.5625 5.09375 \nQ 40.234375 -1.421875 26.90625 -1.421875 \nQ 22.3125 
-1.421875 17.546875 -0.640625 \nQ 12.796875 0.140625 7.71875 1.703125 \nL 7.71875 11.625 \nQ 12.109375 9.234375 16.796875 8.0625 \nQ 21.484375 6.890625 26.703125 6.890625 \nQ 35.15625 6.890625 40.078125 11.328125 \nQ 45.015625 15.765625 45.015625 23.390625 \nQ 45.015625 31 40.078125 35.4375 \nQ 35.15625 39.890625 26.703125 39.890625 \nQ 22.75 39.890625 18.8125 39.015625 \nQ 14.890625 38.140625 10.796875 36.28125 \nz\n\" id=\"DejaVuSans-53\"/>\n      </defs>\n      <g transform=\"translate(271.039408 239.824257)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-53\"/>\n       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"127.246094\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"190.869141\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"254.492188\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"318.115234\" xlink:href=\"#DejaVuSans-48\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"xtick_7\">\n     <g id=\"line2d_7\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"339.945245\" xlink:href=\"#mcb448f674c\" y=\"225.22582\"/>\n      </g>\n     </g>\n     <g id=\"text_7\">\n      <!-- 600000 -->\n      <defs>\n       <path d=\"M 33.015625 40.375 \nQ 26.375 40.375 22.484375 35.828125 \nQ 18.609375 31.296875 18.609375 23.390625 \nQ 18.609375 15.53125 22.484375 10.953125 \nQ 26.375 6.390625 33.015625 6.390625 \nQ 39.65625 6.390625 43.53125 10.953125 \nQ 47.40625 15.53125 47.40625 23.390625 \nQ 47.40625 31.296875 43.53125 35.828125 \nQ 39.65625 40.375 33.015625 40.375 \nz\nM 52.59375 71.296875 \nL 52.59375 62.3125 \nQ 48.875 64.0625 45.09375 64.984375 \nQ 41.3125 65.921875 37.59375 65.921875 \nQ 27.828125 65.921875 22.671875 59.328125 \nQ 17.53125 52.734375 16.796875 39.40625 \nQ 19.671875 43.65625 24.015625 45.921875 \nQ 28.375 48.1875 33.59375 48.1875 \nQ 44.578125 48.1875 50.953125 41.515625 \nQ 57.328125 34.859375 57.328125 23.390625 \nQ 57.328125 12.15625 50.6875 5.359375 \nQ 44.046875 -1.421875 33.015625 -1.421875 \nQ 20.359375 -1.421875 13.671875 8.265625 \nQ 6.984375 17.96875 6.984375 36.375 \nQ 6.984375 53.65625 15.1875 63.9375 \nQ 23.390625 74.21875 37.203125 74.21875 \nQ 40.921875 74.21875 44.703125 73.484375 \nQ 48.484375 72.75 52.59375 71.296875 \nz\n\" id=\"DejaVuSans-54\"/>\n      </defs>\n      <g transform=\"translate(320.857745 239.824257)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-54\"/>\n       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"127.246094\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"190.869141\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"254.492188\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"318.115234\" xlink:href=\"#DejaVuSans-48\"/>\n      </g>\n     </g>\n    </g>\n   </g>\n   <g id=\"matplotlib.axis_2\">\n    <g id=\"ytick_1\">\n     <g id=\"line2d_8\">\n      <defs>\n       <path d=\"M 0 0 \nL -3.5 0 \n\" id=\"m50745714e3\" style=\"stroke:#000000;stroke-width:0.8;\"/>\n      </defs>\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"26.925\" xlink:href=\"#m50745714e3\" y=\"225.22582\"/>\n      </g>\n     </g>\n     <g id=\"text_8\">\n      <!-- 0 -->\n      <g transform=\"translate(13.5625 229.025038)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-48\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"ytick_2\">\n     <g id=\"line2d_9\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"26.925\" xlink:href=\"#m50745714e3\" y=\"189.521386\"/>\n      </g>\n     </g>\n     <g id=\"text_9\">\n      
<!-- 5 -->\n      <g transform=\"translate(13.5625 193.320605)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-53\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"ytick_3\">\n     <g id=\"line2d_10\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"26.925\" xlink:href=\"#m50745714e3\" y=\"153.816953\"/>\n      </g>\n     </g>\n     <g id=\"text_10\">\n      <!-- 10 -->\n      <g transform=\"translate(7.2 157.616171)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-49\"/>\n       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-48\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"ytick_4\">\n     <g id=\"line2d_11\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"26.925\" xlink:href=\"#m50745714e3\" y=\"118.112519\"/>\n      </g>\n     </g>\n     <g id=\"text_11\">\n      <!-- 15 -->\n      <g transform=\"translate(7.2 121.911738)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-49\"/>\n       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-53\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"ytick_5\">\n     <g id=\"line2d_12\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"26.925\" xlink:href=\"#m50745714e3\" y=\"82.408086\"/>\n      </g>\n     </g>\n     <g id=\"text_12\">\n      <!-- 20 -->\n      <g transform=\"translate(7.2 86.207304)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-50\"/>\n       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-48\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"ytick_6\">\n     <g id=\"line2d_13\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"26.925\" xlink:href=\"#m50745714e3\" y=\"46.703652\"/>\n      </g>\n     </g>\n     <g id=\"text_13\">\n      <!-- 25 -->\n      <g transform=\"translate(7.2 50.502871)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-50\"/>\n       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-53\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"ytick_7\">\n     <g id=\"line2d_14\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"26.925\" xlink:href=\"#m50745714e3\" y=\"10.999219\"/>\n      </g>\n     </g>\n     <g id=\"text_14\">\n      <!-- 30 -->\n      <g transform=\"translate(7.2 14.798437)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-51\"/>\n       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-48\"/>\n      </g>\n     </g>\n    </g>\n   </g>\n   <g id=\"patch_29\">\n    <path d=\"M 26.925 225.22582 \nL 26.925 7.78582 \n\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\n   </g>\n   <g id=\"patch_30\">\n    <path d=\"M 361.725 225.22582 \nL 361.725 7.78582 \n\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\n   </g>\n   <g id=\"patch_31\">\n    <path d=\"M 26.925 225.22582 \nL 361.725 225.22582 \n\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\n   </g>\n   <g id=\"patch_32\">\n    <path d=\"M 26.925 7.78582 \nL 361.725 7.78582 \n\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\n   </g>\n  </g>\n </g>\n <defs>\n  <clipPath id=\"p1f2e03f3c1\">\n   <rect height=\"217.44\" width=\"334.8\" x=\"26.925\" y=\"7.78582\"/>\n  </clipPath>\n </defs>\n</svg>\n",
+      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXAAAAD5CAYAAAA+0W6bAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAPHElEQVR4nO3dbaxlVX3H8e+vgGiFyCA3kwkwHWyIhhftQG94iMZYrFZJUzUhDWh0UmnGtJpoNGlAk2rTvrBN1da0QcZCnRciWpRCCFYpkhiTZuwdHWFgRB466pCBGTWI9kUt+O+Lswaul/tw7nmYOWv4fpKTs8/a+5z1XzN7fnffdfbek6pCktSfXzvWBUiSRmOAS1KnDHBJ6pQBLkmdMsAlqVMGuCR16sS1NkjyQuDrwMlt+5ur6sNJzgFuAl4K7AbeXlW/WO2zzjjjjNqyZcvYRUvS88nu3bt/VFVzS9vXDHDgf4FLq+rnSU4CvpHky8D7gU9U1U1JPgVcBVy72gdt2bKFhYWFEcqXpOevJN9frn3NKZQa+Hl7eVJ7FHApcHNr3wm8eQJ1SpKGNNQceJITkuwBDgF3Ag8DT1TVU22TA8CZ0ylRkrScoQK8qp6uqq3AWcCFwCuG7SDJ9iQLSRYOHz48YpmSpKXWdRZKVT0B3A1cApyW5Mgc+lnAoyu8Z0dVzVfV/Nzcc+bgJUkjWjPAk8wlOa0tvwh4HbCPQZBf3jbbBtw6rSIlSc81zFkom4CdSU5gEPhfqKrbk9wP3JTkr4FvA9dPsU5J0hJrBnhV3QOcv0z7IwzmwyVJx4BXYkpSpwxwSerUMHPgXbpx1w/Wtf1bL9o8pUokaTo8ApekThngktQpA1ySOmWAS1KnDHBJ6pQBLkmdMsAlqVMGuCR1ygCXpE4Z4JLUKQNckjplgEtSpwxwSepUN3cjXO/dBSXpeOcRuCR1ygCXpE4Z4JLUKQNckjplgEtSpwxwSeqUAS5JnTLAJalTBrgkdcoAl6ROrRngSc5OcneS+5Pcl+S9rf0jSR5Nsqc9Lpt+uZKkI4a5F8pTwAeq6ltJTgV2J7mzrftEVf3d9MqTJK1kzQCvqoPAwbb8syT7gDOnXZgkaXXrmgNPsgU4H9jVmt6T5J4kNyTZsMJ7tidZSLJw+PDhsYqVJD1r6ABPcgrwReB9VfUkcC3wm8BWBkfoH1vufVW1o6rmq2p+bm5uAiVLkmDIAE9yEoPw/mxVfQmgqh6vqqer6pfAp4ELp1emJGmpYc5CCXA9sK+qPr6ofdOizd4C7J18eZKklQxzFsorgbcD9ybZ09o+CFyZZCtQwH7gXVOpUJK0rGHOQvkGkGVW3TH5ciRJw/JKTEnqlAEuSZ0ywCWpUwa4JHXKAJekThngktQpA1ySOmWAS1KnDHBJ6pQBLkmdMsAlqVMGuCR1ygCXpE4Z4JLUKQNckjplgEtSpwxwSeqUAS5JnTLAJalTBrgkdcoAl6ROGeCS1CkDXJI6ZYBLUqcMcEnqlAEuSZ0ywCWpU2sGeJKzk9yd5P4k9yV5b2s/PcmdSR5szxumX64k6YhhjsCfAj5QVecBFwPvTnIecDVwV1WdC9zVXkuSjpI1A7yqDlbVt9ryz4B9wJnAm4CdbbOdwJunVaQk6bnWNQeeZAtwPrAL2FhVB9uqx4CNK7xne5KFJAuHDx8eo1RJ0mJDB3iSU4AvAu+rqicXr6uqAmq591XVjqqar6r5ubm5sYqVJD1rqABPchKD8P5sVX2pNT+eZFNbvwk4NJ0SJUnLGeYslADXA/uq6uOLVt0GbGvL24BbJ1+eJGklJw6xzSuBtwP3JtnT2j4IfBT4QpKrgO8DfzSdEiVJy1kzwKvqG0BWWP3ayZYjSRqWV2JKUqcMcEnqlAEuSZ0ywCWpUwa4JHXKAJekThngktQpA1ySOmWAS1KnDHBJ6pQBLkmdMsAlqVMGuCR1ygCXpE4Z4JLUqWH+Q4fnhRt3/WDobd960eYpViJJw/EIXJI6ZYBLUqcMcEnqlAEuSZ0ywCWpUwa4JHXKAJekThngktQpA1ySOmWAS1Kn1gzwJDckOZRk76K2jyR5NMme9rhsumVKkpYa5gj8M8Ablmn/RFVtbY87JluWJGktawZ4VX0d+MlRqEWStA7jzIG/J8k9bYplw8QqkiQNZdTbyV4L/BVQ7fljwDuX2zDJdmA7wObNx8dtWNdz61nw9rOSpmOkI/Cqeryqnq6qXwKfBi5cZdsdVTVfVfNzc3Oj1ilJWmKkAE+yadHLtwB7V9pWkjQda06hJPkc8BrgjCQHgA8Dr0mylcEUyn7gXVOsUZK0jDUDvKquXKb5+inUIklaB6/ElKROGeCS1CkDXJI6ZYBLUqcMcEnqlAEuSZ0ywCWpUwa4JHXKAJekThngktQpA1ySOmWAS1KnDHBJ6pQBLkmdMsAlqVMGuCR1ygCXpE4Z4JLUKQNckjplgEtSpwxwSeqUAS5JnTLAJalTBrgkdcoAl6ROGeCS1CkDXJI6tWaAJ7khyaEkexe1nZ7kziQPtucN0y1TkrTUMEfgnwHesKTtauCuqjoXuKu9liQdRWsGeFV9HfjJkuY3ATvb8k7gzROuS5K0hlHnwDdW1cG2/BiwcaUNk2xPspBk4fDhwyN2J0laauwvMauqgFpl/Y6qmq+q+bm5uXG7kyQ1owb440k2AbTnQ5MrSZI0jFED/DZgW1veBtw6mXIkScMa5jTCzwH/Cbw8yYEkVwEfBV6X5EHg99prSdJRdOJaG1TVlSuseu2Ea5EkrYNXYkpSpwxwSeqUAS5JnTLAJalTBrgkdcoAl6ROGeCS1CkDXJI6ZYBLUqcMcEnqlAEuSZ0ywCWpUwa4JHXKAJekThngktQpA1ySOmWAS1KnDHBJ6pQBLkmdMsAlqVMGuCR1ygCXpE4Z4JLUKQNckjplgEtSpwxwSeqUAS5JnTpxnDcn2Q/8DHgaeKqq5idRlCRpbWMFePO7VfWjCXyOJGkdnEKRpE6NewRewFeTFHBdVe1YukGS7cB2gM2bN4/ZXZ9u3PWDdW3/1ouen39OktZn3CPwV1XVBcAbgXcnefXSDapqR1XNV9X83NzcmN1Jko4YK8Cr6tH2fAi4BbhwEkVJktY2coAneXGSU48sA68H9k6qMEnS6saZA98I3JLkyOfcWFX/PpGqJElrGjnAq+oR4LcnWIskaR08jVCSOmWAS1KnDHBJ6pQBLkmdMsAlqVMGuCR1ygCXpE4Z4JLUqUncD1wTtt67F67Heu906J0UpdnlEbgkdcoAl6ROGeCS1CkDXJI6ZYBLUqcMcEnqlAEuSZ3yPHBpBJ4fr1ngEbgkdcoAl6ROGeCS1CkDXJI6ZYBLUqcMcEnqlKcRPs9M81a1s2Y9Y30+nebn7YonYxZq9whckjplgEtSp8YK8CRvSPJAkoeSXD2poiRJaxs5wJOcAPwT8EbgPODKJOdNqjBJ0urGOQK/EHioqh6pql8ANwFvmkxZkqS1jBPgZwI/XPT6QGuTJB0FUz+NMM
l2YHt7+fMkD4zwMWcAP5pcVcfMcT+Otx3lQsb0zDimXfcUP3+m9qkxxjnUODrYv6b1b+M3lmscJ8AfBc5e9Pqs1vYrqmoHsGOMfkiyUFXz43zGLHAcs+V4GMfxMAZwHKMaZwrlv4Bzk5yT5AXAFcBtkylLkrSWkY/Aq+qpJO8BvgKcANxQVfdNrDJJ0qrGmgOvqjuAOyZUy2rGmoKZIY5jthwP4zgexgCOYySpqqPZnyRpQryUXpI6NfMBPguX6ye5IcmhJHsXtZ2e5M4kD7bnDa09ST7Z6r0nyQWL3rOtbf9gkm2L2n8nyb3tPZ9MktX6GGMcZye5O8n9Se5L8t7expLkhUm+meQ7bQx/2drPSbKr9fv59sU6SU5urx9q67cs+qxrWvsDSX5/Ufuy+9xKfYwjyQlJvp3k9l7HkWR/+zvfk2ShtXWzTy3q57QkNyf5bpJ9SS6Z+XFU1cw+GHw5+jDwMuAFwHeA845BHa8GLgD2Lmr7W+Dqtnw18Ddt+TLgy0CAi4Fdrf104JH2vKEtb2jrvtm2TXvvG1frY4xxbAIuaMunAt9jcBuEbsbSPveUtnwSsKv19wXgitb+KeBP2/KfAZ9qy1cAn2/L57X96WTgnLafnbDaPrdSH2P+nbwfuBG4fbU+ZnkcwH7gjCVt3exTi2reCfxJW34BcNqsj+OoBuEIf6CXAF9Z9Poa4JpjVMsWfjXAHwA2teVNwANt+TrgyqXbAVcC1y1qv661bQK+u6j9me1W6mOCY7oVeF2vYwF+HfgWcBGDiydOXLrfMDhL6pK2fGLbLkv3pSPbrbTPtfcs28cY9Z8F3AVcCty+Wh8zPo79PDfAu9qngJcA/037XrCXccz6FMosX66/saoOtuXHgI1teaWaV2s/sEz7an2Mrf0Kfj6DI9iuxtKmHfYAh4A7GRxpPlFVTy3T7zO1tvU/BV46wtheukofo/p74M+BX7bXq/Uxy+Mo4KtJdmdw5TV0tk8x+O3lMPAvbUrrn5O8eNbHMesB3oUa/Oic6uk8k+wjySnAF4H3VdWT0+pnJeP2UVVPV9VWBkewFwKvmFRtR0uSPwAOVdXuY13LBLyqqi5gcGfSdyd59eKVPexTDH6ruQC4tqrOB/6HwXTGJPtY03r7mPUAH+py/WPk8SSbANrzoda+Us2rtZ+1TPtqfYwsyUkMwvuzVfWlnsdSVU8AdzOYBjgtyZHrGhb3+0ytbf1LgB+PMLYfr9LHKF4J/GGS/Qzu5Hkp8A8djoOqerQ9HwJuYfBDtbd96gBwoKp2tdc3Mwj0mR7HrAf4LF+ufxtw5BvmbQzmk4+0v6N9S30x8NP269FXgNcn2dC+ZX49g7nHg8CTSS5u30q/Y8lnLdfHSNrnXw/sq6qP9ziWJHNJTmvLL2Iwh7+PQZBfvsIYjvR7OfC1dpRzG3BFBmd3nAOcy+BLpmX3ufaelfpYt6q6pqrOqqotrY+vVdXbehtHkhcnOfXIMoN9YS8d7VMAVfUY8MMkL29NrwXun/lxjPPlxdF4MPi293sM5jk/dIxq+BxwEPg/Bj+pr2Iwl3gX8CDwH8Dpbdsw+I8uHgbuBeYXfc47gYfa448Xtc8z2OkfBv6RZy+wWraPMcbxKga/nt0D7GmPy3oaC/BbwLfbGPYCf9HaX8YguB4C/hU4ubW/sL1+qK1/2aLP+lCr8wHaGQGr7XMr9TGB/es1PHsWSlfjaJ/1nfa470g/Pe1Ti/rZCiy0fevfGJxFMtPj8EpMSerUrE+hSJJWYIBLUqcMcEnqlAEuSZ0ywCWpUwa4JHXKAJekThngktSp/wdRBFKWSSQcDgAAAABJRU5ErkJggg==\n"
+     },
+     "metadata": {
+      "needs_background": "light"
+     }
+    }
+   ],
+   "source": [
+    "sns.distplot(document_lens, kde=False)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": 3
+  },
+  "orig_nbformat": 2,
+  "kernelspec": {
+   "name": "python_defaultSpec_1594306004152",
+   "display_name": "Python 3.8.2 64-bit"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
\ No newline at end of file
diff --git a/dataset_generation/notebook_simple.ipynb b/dataset_generation/notebook_simple.ipynb
new file mode 100644
index 0000000..3de5576
--- /dev/null
+++ b/dataset_generation/notebook_simple.ipynb
@@ -0,0 +1,152 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "%load_ext autoreload\n",
+    "%autoreload 2"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 76,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import glob\n",
+    "import random\n",
+    "from lxml import etree\n",
+    "import uuid\n",
+    "import hashlib\n",
+    "import seaborn as sns\n",
+    "import re\n",
+    "import os\n",
+    "from tqdm import tqdm\n",
+    "\n",
+    "from processing import text_from_xml\n",
+    "from utils import remove_multiple_spaces, remove_punctuation"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "file_schema = \"../dane/**/text_structure.xml\"\n",
+    "files_paths = glob.glob(file_schema, recursive=True)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 54,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "files_subset = random.sample(files_paths, 1_000)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 78,
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [
+    {
+     "output_type": "stream",
+     "name": "stderr",
+     "text": "100%|██████████| 10000/10000 [00:40<00:00, 248.07it/s]\n"
+    }
+   ],
+   "source": [
+    "num_exported = 0\n",
+    "document_lens = []\n",
+    "\n",
+    "if not os.path.exists(\"dataset_simple\"):\n",
+    "    os.mkdir(\"dataset_simple\")\n",
+    "\n",
+    "for file_path in tqdm(files_subset):\n",
+    "    full_text = text_from_xml(file_path)\n",
+    "    \n",
+    "    if len(full_text) > 0:\n",
+    "        output_file_input = f\"dataset_simple/{hashlib.md5(file_path.encode()).hexdigest()}_input.txt\"\n",
+    "        output_file_output = f\"dataset_simple/{hashlib.md5(file_path.encode()).hexdigest()}_output.txt\"\n",
+    "\n",
+    "        with open(output_file_input, \"w\") as f:\n",
+    "            f.write(full_text)\n",
+    "            num_exported += 1\n",
+    "            document_lens.append(len(full_text))\n",
+    "\n",
+    "        text_cleared = remove_punctuation(full_text)\n",
+    "        text_cleared = remove_multiple_spaces(text_cleared)\n",
+    "        text_cleared = text_cleared.lower()\n",
+    "\n",
+    "        with open(output_file_output, 'w') as f:\n",
+    "            f.write(text_cleared)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 79,
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [
+    {
+     "output_type": "execute_result",
+     "data": {
+      "text/plain": "<matplotlib.axes._subplots.AxesSubplot at 0x7fcd4958ab20>"
+     },
+     "metadata": {},
+     "execution_count": 79
+    },
+    {
+     "output_type": "display_data",
+     "data": {
+      "text/plain": "<Figure size 432x288 with 1 Axes>",
+      "image/svg+xml": "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n  \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n<!-- Created with matplotlib (https://matplotlib.org/) -->\n<svg height=\"248.518125pt\" version=\"1.1\" viewBox=\"0 0 375.2875 248.518125\" width=\"375.2875pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n <defs>\n  <style type=\"text/css\">\n*{stroke-linecap:butt;stroke-linejoin:round;}\n  </style>\n </defs>\n <g id=\"figure_1\">\n  <g id=\"patch_1\">\n   <path d=\"M 0 248.518125 \nL 375.2875 248.518125 \nL 375.2875 0 \nL 0 0 \nz\n\" style=\"fill:none;\"/>\n  </g>\n  <g id=\"axes_1\">\n   <g id=\"patch_2\">\n    <path d=\"M 33.2875 224.64 \nL 368.0875 224.64 \nL 368.0875 7.2 \nL 33.2875 7.2 \nz\n\" style=\"fill:#ffffff;\"/>\n   </g>\n   <g id=\"patch_3\">\n    <path clip-path=\"url(#p8459d3defc)\" d=\"M 48.505682 224.64 \nL 54.592955 224.64 \nL 54.592955 17.554286 \nL 48.505682 17.554286 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_4\">\n    <path clip-path=\"url(#p8459d3defc)\" d=\"M 54.592955 224.64 \nL 60.680227 224.64 \nL 60.680227 76.606071 \nL 54.592955 76.606071 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_5\">\n    <path clip-path=\"url(#p8459d3defc)\" d=\"M 60.680227 224.64 \nL 66.7675 224.64 \nL 66.7675 93.593571 \nL 60.680227 93.593571 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_6\">\n    <path clip-path=\"url(#p8459d3defc)\" d=\"M 66.7675 224.64 \nL 72.854773 224.64 \nL 72.854773 112.198929 \nL 66.7675 112.198929 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_7\">\n    <path clip-path=\"url(#p8459d3defc)\" d=\"M 72.854773 224.64 \nL 78.942045 224.64 \nL 78.942045 143.747143 \nL 72.854773 143.747143 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_8\">\n    <path clip-path=\"url(#p8459d3defc)\" d=\"M 78.942045 224.64 \nL 85.029318 224.64 \nL 85.029318 176.913214 \nL 78.942045 176.913214 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_9\">\n    <path clip-path=\"url(#p8459d3defc)\" d=\"M 85.029318 224.64 \nL 91.116591 224.64 \nL 91.116591 199.563214 \nL 85.029318 199.563214 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_10\">\n    <path clip-path=\"url(#p8459d3defc)\" d=\"M 91.116591 224.64 \nL 97.203864 224.64 \nL 97.203864 210.888214 \nL 91.116591 210.888214 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_11\">\n    <path clip-path=\"url(#p8459d3defc)\" d=\"M 97.203864 224.64 \nL 103.291136 224.64 \nL 103.291136 214.123929 \nL 97.203864 214.123929 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_12\">\n    <path clip-path=\"url(#p8459d3defc)\" d=\"M 103.291136 224.64 \nL 109.378409 224.64 \nL 109.378409 210.888214 \nL 103.291136 210.888214 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_13\">\n    <path clip-path=\"url(#p8459d3defc)\" d=\"M 109.378409 224.64 \nL 115.465682 224.64 \nL 115.465682 221.404286 \nL 109.378409 221.404286 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_14\">\n    <path clip-path=\"url(#p8459d3defc)\" d=\"M 115.465682 224.64 \nL 121.552955 224.64 \nL 121.552955 217.359643 \nL 115.465682 217.359643 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_15\">\n    <path clip-path=\"url(#p8459d3defc)\" d=\"M 121.552955 224.64 
\nL 127.640227 224.64 \nL 127.640227 216.550714 \nL 121.552955 216.550714 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_16\">\n    <path clip-path=\"url(#p8459d3defc)\" d=\"M 127.640227 224.64 \nL 133.7275 224.64 \nL 133.7275 219.786429 \nL 127.640227 219.786429 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_17\">\n    <path clip-path=\"url(#p8459d3defc)\" d=\"M 133.7275 224.64 \nL 139.814773 224.64 \nL 139.814773 216.550714 \nL 133.7275 216.550714 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_18\">\n    <path clip-path=\"url(#p8459d3defc)\" d=\"M 139.814773 224.64 \nL 145.902045 224.64 \nL 145.902045 220.595357 \nL 139.814773 220.595357 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_19\">\n    <path clip-path=\"url(#p8459d3defc)\" d=\"M 145.902045 224.64 \nL 151.989318 224.64 \nL 151.989318 220.595357 \nL 145.902045 220.595357 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_20\">\n    <path clip-path=\"url(#p8459d3defc)\" d=\"M 151.989318 224.64 \nL 158.076591 224.64 \nL 158.076591 223.022143 \nL 151.989318 223.022143 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_21\">\n    <path clip-path=\"url(#p8459d3defc)\" d=\"M 158.076591 224.64 \nL 164.163864 224.64 \nL 164.163864 223.022143 \nL 158.076591 223.022143 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_22\">\n    <path clip-path=\"url(#p8459d3defc)\" d=\"M 164.163864 224.64 \nL 170.251136 224.64 \nL 170.251136 216.550714 \nL 164.163864 216.550714 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_23\">\n    <path clip-path=\"url(#p8459d3defc)\" d=\"M 170.251136 224.64 \nL 176.338409 224.64 \nL 176.338409 220.595357 \nL 170.251136 220.595357 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_24\">\n    <path clip-path=\"url(#p8459d3defc)\" d=\"M 176.338409 224.64 \nL 182.425682 224.64 \nL 182.425682 223.022143 \nL 176.338409 223.022143 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_25\">\n    <path clip-path=\"url(#p8459d3defc)\" d=\"M 182.425682 224.64 \nL 188.512955 224.64 \nL 188.512955 223.022143 \nL 182.425682 223.022143 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_26\">\n    <path clip-path=\"url(#p8459d3defc)\" d=\"M 188.512955 224.64 \nL 194.600227 224.64 \nL 194.600227 222.213214 \nL 188.512955 222.213214 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_27\">\n    <path clip-path=\"url(#p8459d3defc)\" d=\"M 194.600227 224.64 \nL 200.6875 224.64 \nL 200.6875 223.831071 \nL 194.600227 223.831071 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_28\">\n    <path clip-path=\"url(#p8459d3defc)\" d=\"M 200.6875 224.64 \nL 206.774773 224.64 \nL 206.774773 222.213214 \nL 200.6875 222.213214 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_29\">\n    <path clip-path=\"url(#p8459d3defc)\" d=\"M 206.774773 224.64 \nL 212.862045 224.64 \nL 212.862045 224.64 \nL 206.774773 224.64 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_30\">\n    <path clip-path=\"url(#p8459d3defc)\" d=\"M 212.862045 224.64 \nL 218.949318 224.64 \nL 218.949318 220.595357 \nL 212.862045 220.595357 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_31\">\n    <path clip-path=\"url(#p8459d3defc)\" d=\"M 218.949318 224.64 \nL 225.036591 224.64 \nL 225.036591 223.831071 \nL 
218.949318 223.831071 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_32\">\n    <path clip-path=\"url(#p8459d3defc)\" d=\"M 225.036591 224.64 \nL 231.123864 224.64 \nL 231.123864 222.213214 \nL 225.036591 222.213214 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_33\">\n    <path clip-path=\"url(#p8459d3defc)\" d=\"M 231.123864 224.64 \nL 237.211136 224.64 \nL 237.211136 221.404286 \nL 231.123864 221.404286 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_34\">\n    <path clip-path=\"url(#p8459d3defc)\" d=\"M 237.211136 224.64 \nL 243.298409 224.64 \nL 243.298409 221.404286 \nL 237.211136 221.404286 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_35\">\n    <path clip-path=\"url(#p8459d3defc)\" d=\"M 243.298409 224.64 \nL 249.385682 224.64 \nL 249.385682 224.64 \nL 243.298409 224.64 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_36\">\n    <path clip-path=\"url(#p8459d3defc)\" d=\"M 249.385682 224.64 \nL 255.472955 224.64 \nL 255.472955 222.213214 \nL 249.385682 222.213214 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_37\">\n    <path clip-path=\"url(#p8459d3defc)\" d=\"M 255.472955 224.64 \nL 261.560227 224.64 \nL 261.560227 224.64 \nL 255.472955 224.64 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_38\">\n    <path clip-path=\"url(#p8459d3defc)\" d=\"M 261.560227 224.64 \nL 267.6475 224.64 \nL 267.6475 222.213214 \nL 261.560227 222.213214 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_39\">\n    <path clip-path=\"url(#p8459d3defc)\" d=\"M 267.6475 224.64 \nL 273.734773 224.64 \nL 273.734773 223.831071 \nL 267.6475 223.831071 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_40\">\n    <path clip-path=\"url(#p8459d3defc)\" d=\"M 273.734773 224.64 \nL 279.822045 224.64 \nL 279.822045 223.831071 \nL 273.734773 223.831071 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_41\">\n    <path clip-path=\"url(#p8459d3defc)\" d=\"M 279.822045 224.64 \nL 285.909318 224.64 \nL 285.909318 218.9775 \nL 279.822045 218.9775 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_42\">\n    <path clip-path=\"url(#p8459d3defc)\" d=\"M 285.909318 224.64 \nL 291.996591 224.64 \nL 291.996591 224.64 \nL 285.909318 224.64 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_43\">\n    <path clip-path=\"url(#p8459d3defc)\" d=\"M 291.996591 224.64 \nL 298.083864 224.64 \nL 298.083864 224.64 \nL 291.996591 224.64 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_44\">\n    <path clip-path=\"url(#p8459d3defc)\" d=\"M 298.083864 224.64 \nL 304.171136 224.64 \nL 304.171136 223.831071 \nL 298.083864 223.831071 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_45\">\n    <path clip-path=\"url(#p8459d3defc)\" d=\"M 304.171136 224.64 \nL 310.258409 224.64 \nL 310.258409 223.831071 \nL 304.171136 223.831071 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_46\">\n    <path clip-path=\"url(#p8459d3defc)\" d=\"M 310.258409 224.64 \nL 316.345682 224.64 \nL 316.345682 224.64 \nL 310.258409 224.64 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_47\">\n    <path clip-path=\"url(#p8459d3defc)\" d=\"M 316.345682 224.64 \nL 322.432955 224.64 \nL 322.432955 224.64 \nL 316.345682 224.64 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g 
id=\"patch_48\">\n    <path clip-path=\"url(#p8459d3defc)\" d=\"M 322.432955 224.64 \nL 328.520227 224.64 \nL 328.520227 224.64 \nL 322.432955 224.64 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_49\">\n    <path clip-path=\"url(#p8459d3defc)\" d=\"M 328.520227 224.64 \nL 334.6075 224.64 \nL 334.6075 224.64 \nL 328.520227 224.64 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_50\">\n    <path clip-path=\"url(#p8459d3defc)\" d=\"M 334.6075 224.64 \nL 340.694773 224.64 \nL 340.694773 224.64 \nL 334.6075 224.64 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_51\">\n    <path clip-path=\"url(#p8459d3defc)\" d=\"M 340.694773 224.64 \nL 346.782045 224.64 \nL 346.782045 224.64 \nL 340.694773 224.64 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_52\">\n    <path clip-path=\"url(#p8459d3defc)\" d=\"M 346.782045 224.64 \nL 352.869318 224.64 \nL 352.869318 223.831071 \nL 346.782045 223.831071 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"matplotlib.axis_1\">\n    <g id=\"xtick_1\">\n     <g id=\"line2d_1\">\n      <defs>\n       <path d=\"M 0 0 \nL 0 3.5 \n\" id=\"m6e779eec92\" style=\"stroke:#000000;stroke-width:0.8;\"/>\n      </defs>\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"48.27002\" xlink:href=\"#m6e779eec92\" y=\"224.64\"/>\n      </g>\n     </g>\n     <g id=\"text_1\">\n      <!-- 0 -->\n      <defs>\n       <path d=\"M 31.78125 66.40625 \nQ 24.171875 66.40625 20.328125 58.90625 \nQ 16.5 51.421875 16.5 36.375 \nQ 16.5 21.390625 20.328125 13.890625 \nQ 24.171875 6.390625 31.78125 6.390625 \nQ 39.453125 6.390625 43.28125 13.890625 \nQ 47.125 21.390625 47.125 36.375 \nQ 47.125 51.421875 43.28125 58.90625 \nQ 39.453125 66.40625 31.78125 66.40625 \nz\nM 31.78125 74.21875 \nQ 44.046875 74.21875 50.515625 64.515625 \nQ 56.984375 54.828125 56.984375 36.375 \nQ 56.984375 17.96875 50.515625 8.265625 \nQ 44.046875 -1.421875 31.78125 -1.421875 \nQ 19.53125 -1.421875 13.0625 8.265625 \nQ 6.59375 17.96875 6.59375 36.375 \nQ 6.59375 54.828125 13.0625 64.515625 \nQ 19.53125 74.21875 31.78125 74.21875 \nz\n\" id=\"DejaVuSans-48\"/>\n      </defs>\n      <g transform=\"translate(45.08877 239.238437)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-48\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"xtick_2\">\n     <g id=\"line2d_2\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"113.46009\" xlink:href=\"#m6e779eec92\" y=\"224.64\"/>\n      </g>\n     </g>\n     <g id=\"text_2\">\n      <!-- 200000 -->\n      <defs>\n       <path d=\"M 19.1875 8.296875 \nL 53.609375 8.296875 \nL 53.609375 0 \nL 7.328125 0 \nL 7.328125 8.296875 \nQ 12.9375 14.109375 22.625 23.890625 \nQ 32.328125 33.6875 34.8125 36.53125 \nQ 39.546875 41.84375 41.421875 45.53125 \nQ 43.3125 49.21875 43.3125 52.78125 \nQ 43.3125 58.59375 39.234375 62.25 \nQ 35.15625 65.921875 28.609375 65.921875 \nQ 23.96875 65.921875 18.8125 64.3125 \nQ 13.671875 62.703125 7.8125 59.421875 \nL 7.8125 69.390625 \nQ 13.765625 71.78125 18.9375 73 \nQ 24.125 74.21875 28.421875 74.21875 \nQ 39.75 74.21875 46.484375 68.546875 \nQ 53.21875 62.890625 53.21875 53.421875 \nQ 53.21875 48.921875 51.53125 44.890625 \nQ 49.859375 40.875 45.40625 35.40625 \nQ 44.1875 33.984375 37.640625 27.21875 \nQ 31.109375 20.453125 19.1875 8.296875 \nz\n\" id=\"DejaVuSans-50\"/>\n      </defs>\n      <g transform=\"translate(94.37259 239.238437)scale(0.1 -0.1)\">\n       <use 
xlink:href=\"#DejaVuSans-50\"/>\n       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"127.246094\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"190.869141\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"254.492188\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"318.115234\" xlink:href=\"#DejaVuSans-48\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"xtick_3\">\n     <g id=\"line2d_3\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"178.65016\" xlink:href=\"#m6e779eec92\" y=\"224.64\"/>\n      </g>\n     </g>\n     <g id=\"text_3\">\n      <!-- 400000 -->\n      <defs>\n       <path d=\"M 37.796875 64.3125 \nL 12.890625 25.390625 \nL 37.796875 25.390625 \nz\nM 35.203125 72.90625 \nL 47.609375 72.90625 \nL 47.609375 25.390625 \nL 58.015625 25.390625 \nL 58.015625 17.1875 \nL 47.609375 17.1875 \nL 47.609375 0 \nL 37.796875 0 \nL 37.796875 17.1875 \nL 4.890625 17.1875 \nL 4.890625 26.703125 \nz\n\" id=\"DejaVuSans-52\"/>\n      </defs>\n      <g transform=\"translate(159.56266 239.238437)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-52\"/>\n       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"127.246094\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"190.869141\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"254.492188\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"318.115234\" xlink:href=\"#DejaVuSans-48\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"xtick_4\">\n     <g id=\"line2d_4\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"243.84023\" xlink:href=\"#m6e779eec92\" y=\"224.64\"/>\n      </g>\n     </g>\n     <g id=\"text_4\">\n      <!-- 600000 -->\n      <defs>\n       <path d=\"M 33.015625 40.375 \nQ 26.375 40.375 22.484375 35.828125 \nQ 18.609375 31.296875 18.609375 23.390625 \nQ 18.609375 15.53125 22.484375 10.953125 \nQ 26.375 6.390625 33.015625 6.390625 \nQ 39.65625 6.390625 43.53125 10.953125 \nQ 47.40625 15.53125 47.40625 23.390625 \nQ 47.40625 31.296875 43.53125 35.828125 \nQ 39.65625 40.375 33.015625 40.375 \nz\nM 52.59375 71.296875 \nL 52.59375 62.3125 \nQ 48.875 64.0625 45.09375 64.984375 \nQ 41.3125 65.921875 37.59375 65.921875 \nQ 27.828125 65.921875 22.671875 59.328125 \nQ 17.53125 52.734375 16.796875 39.40625 \nQ 19.671875 43.65625 24.015625 45.921875 \nQ 28.375 48.1875 33.59375 48.1875 \nQ 44.578125 48.1875 50.953125 41.515625 \nQ 57.328125 34.859375 57.328125 23.390625 \nQ 57.328125 12.15625 50.6875 5.359375 \nQ 44.046875 -1.421875 33.015625 -1.421875 \nQ 20.359375 -1.421875 13.671875 8.265625 \nQ 6.984375 17.96875 6.984375 36.375 \nQ 6.984375 53.65625 15.1875 63.9375 \nQ 23.390625 74.21875 37.203125 74.21875 \nQ 40.921875 74.21875 44.703125 73.484375 \nQ 48.484375 72.75 52.59375 71.296875 \nz\n\" id=\"DejaVuSans-54\"/>\n      </defs>\n      <g transform=\"translate(224.75273 239.238437)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-54\"/>\n       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"127.246094\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"190.869141\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"254.492188\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"318.115234\" xlink:href=\"#DejaVuSans-48\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"xtick_5\">\n     <g id=\"line2d_5\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"309.0303\" xlink:href=\"#m6e779eec92\" y=\"224.64\"/>\n      </g>\n     </g>\n     <g id=\"text_5\">\n      <!-- 800000 -->\n      
<defs>\n       <path d=\"M 31.78125 34.625 \nQ 24.75 34.625 20.71875 30.859375 \nQ 16.703125 27.09375 16.703125 20.515625 \nQ 16.703125 13.921875 20.71875 10.15625 \nQ 24.75 6.390625 31.78125 6.390625 \nQ 38.8125 6.390625 42.859375 10.171875 \nQ 46.921875 13.96875 46.921875 20.515625 \nQ 46.921875 27.09375 42.890625 30.859375 \nQ 38.875 34.625 31.78125 34.625 \nz\nM 21.921875 38.8125 \nQ 15.578125 40.375 12.03125 44.71875 \nQ 8.5 49.078125 8.5 55.328125 \nQ 8.5 64.0625 14.71875 69.140625 \nQ 20.953125 74.21875 31.78125 74.21875 \nQ 42.671875 74.21875 48.875 69.140625 \nQ 55.078125 64.0625 55.078125 55.328125 \nQ 55.078125 49.078125 51.53125 44.71875 \nQ 48 40.375 41.703125 38.8125 \nQ 48.828125 37.15625 52.796875 32.3125 \nQ 56.78125 27.484375 56.78125 20.515625 \nQ 56.78125 9.90625 50.3125 4.234375 \nQ 43.84375 -1.421875 31.78125 -1.421875 \nQ 19.734375 -1.421875 13.25 4.234375 \nQ 6.78125 9.90625 6.78125 20.515625 \nQ 6.78125 27.484375 10.78125 32.3125 \nQ 14.796875 37.15625 21.921875 38.8125 \nz\nM 18.3125 54.390625 \nQ 18.3125 48.734375 21.84375 45.5625 \nQ 25.390625 42.390625 31.78125 42.390625 \nQ 38.140625 42.390625 41.71875 45.5625 \nQ 45.3125 48.734375 45.3125 54.390625 \nQ 45.3125 60.0625 41.71875 63.234375 \nQ 38.140625 66.40625 31.78125 66.40625 \nQ 25.390625 66.40625 21.84375 63.234375 \nQ 18.3125 60.0625 18.3125 54.390625 \nz\n\" id=\"DejaVuSans-56\"/>\n      </defs>\n      <g transform=\"translate(289.9428 239.238437)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-56\"/>\n       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"127.246094\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"190.869141\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"254.492188\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"318.115234\" xlink:href=\"#DejaVuSans-48\"/>\n      </g>\n     </g>\n    </g>\n   </g>\n   <g id=\"matplotlib.axis_2\">\n    <g id=\"ytick_1\">\n     <g id=\"line2d_6\">\n      <defs>\n       <path d=\"M 0 0 \nL -3.5 0 \n\" id=\"md0c86dbdd4\" style=\"stroke:#000000;stroke-width:0.8;\"/>\n      </defs>\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"33.2875\" xlink:href=\"#md0c86dbdd4\" y=\"224.64\"/>\n      </g>\n     </g>\n     <g id=\"text_6\">\n      <!-- 0 -->\n      <g transform=\"translate(19.925 228.439219)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-48\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"ytick_2\">\n     <g id=\"line2d_7\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"33.2875\" xlink:href=\"#md0c86dbdd4\" y=\"184.193571\"/>\n      </g>\n     </g>\n     <g id=\"text_7\">\n      <!-- 50 -->\n      <defs>\n       <path d=\"M 10.796875 72.90625 \nL 49.515625 72.90625 \nL 49.515625 64.59375 \nL 19.828125 64.59375 \nL 19.828125 46.734375 \nQ 21.96875 47.46875 24.109375 47.828125 \nQ 26.265625 48.1875 28.421875 48.1875 \nQ 40.625 48.1875 47.75 41.5 \nQ 54.890625 34.8125 54.890625 23.390625 \nQ 54.890625 11.625 47.5625 5.09375 \nQ 40.234375 -1.421875 26.90625 -1.421875 \nQ 22.3125 -1.421875 17.546875 -0.640625 \nQ 12.796875 0.140625 7.71875 1.703125 \nL 7.71875 11.625 \nQ 12.109375 9.234375 16.796875 8.0625 \nQ 21.484375 6.890625 26.703125 6.890625 \nQ 35.15625 6.890625 40.078125 11.328125 \nQ 45.015625 15.765625 45.015625 23.390625 \nQ 45.015625 31 40.078125 35.4375 \nQ 35.15625 39.890625 26.703125 39.890625 \nQ 22.75 39.890625 18.8125 39.015625 \nQ 14.890625 38.140625 10.796875 36.28125 \nz\n\" id=\"DejaVuSans-53\"/>\n      </defs>\n      <g 
transform=\"translate(13.5625 187.99279)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-53\"/>\n       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-48\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"ytick_3\">\n     <g id=\"line2d_8\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"33.2875\" xlink:href=\"#md0c86dbdd4\" y=\"143.747143\"/>\n      </g>\n     </g>\n     <g id=\"text_8\">\n      <!-- 100 -->\n      <defs>\n       <path d=\"M 12.40625 8.296875 \nL 28.515625 8.296875 \nL 28.515625 63.921875 \nL 10.984375 60.40625 \nL 10.984375 69.390625 \nL 28.421875 72.90625 \nL 38.28125 72.90625 \nL 38.28125 8.296875 \nL 54.390625 8.296875 \nL 54.390625 0 \nL 12.40625 0 \nz\n\" id=\"DejaVuSans-49\"/>\n      </defs>\n      <g transform=\"translate(7.2 147.546362)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-49\"/>\n       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"127.246094\" xlink:href=\"#DejaVuSans-48\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"ytick_4\">\n     <g id=\"line2d_9\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"33.2875\" xlink:href=\"#md0c86dbdd4\" y=\"103.300714\"/>\n      </g>\n     </g>\n     <g id=\"text_9\">\n      <!-- 150 -->\n      <g transform=\"translate(7.2 107.099933)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-49\"/>\n       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-53\"/>\n       <use x=\"127.246094\" xlink:href=\"#DejaVuSans-48\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"ytick_5\">\n     <g id=\"line2d_10\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"33.2875\" xlink:href=\"#md0c86dbdd4\" y=\"62.854286\"/>\n      </g>\n     </g>\n     <g id=\"text_10\">\n      <!-- 200 -->\n      <g transform=\"translate(7.2 66.653504)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-50\"/>\n       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"127.246094\" xlink:href=\"#DejaVuSans-48\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"ytick_6\">\n     <g id=\"line2d_11\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"33.2875\" xlink:href=\"#md0c86dbdd4\" y=\"22.407857\"/>\n      </g>\n     </g>\n     <g id=\"text_11\">\n      <!-- 250 -->\n      <g transform=\"translate(7.2 26.207076)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-50\"/>\n       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-53\"/>\n       <use x=\"127.246094\" xlink:href=\"#DejaVuSans-48\"/>\n      </g>\n     </g>\n    </g>\n   </g>\n   <g id=\"patch_53\">\n    <path d=\"M 33.2875 224.64 \nL 33.2875 7.2 \n\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\n   </g>\n   <g id=\"patch_54\">\n    <path d=\"M 368.0875 224.64 \nL 368.0875 7.2 \n\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\n   </g>\n   <g id=\"patch_55\">\n    <path d=\"M 33.2875 224.64 \nL 368.0875 224.64 \n\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\n   </g>\n   <g id=\"patch_56\">\n    <path d=\"M 33.2875 7.2 \nL 368.0875 7.2 \n\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\n   </g>\n  </g>\n </g>\n <defs>\n  <clipPath id=\"p8459d3defc\">\n   <rect height=\"217.44\" width=\"334.8\" x=\"33.2875\" y=\"7.2\"/>\n  </clipPath>\n </defs>\n</svg>\n",
+      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD4CAYAAAAXUaZHAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAPuklEQVR4nO3dXYxcZ33H8e+vMYQWEHEa1zKxXRvkIpmLJukqCYKLtGlJiCoCEoqcVMSlIKM2kaBFqhK4gF5EohUvJWobMCQlVJiQQmgslJYGNxLiAoNN07ybGPJmy4kNtAkqEmrCvxfzOBnMvs/ujvfZ70cazTn/c86cZ84+/u3xM2fOpqqQJPXlV8bdAEnSwjPcJalDhrskdchwl6QOGe6S1KFV424AwBlnnFGbNm0adzMkaVnZv3//D6tqzWTLTopw37RpE/v27Rt3MyRpWUny2FTLHJaRpA4Z7pLUIcNdkjpkuEtShwx3SeqQ4S5JHTLcJalDhrskdchwl6QOnRTfUB3Frr2PT7nsivM2LmFLJOnkMeOZe5INSe5K8kCS+5O8p9U/lORwkrvb45Khba5NcjDJgSQXLeYbkCT9stmcuT8LvK+qvpvk5cD+JHe2ZR+vqo8Mr5xkK7ANeC3wSuDrSX6rqp5byIZLkqY245l7VR2pqu+26Z8ADwJnTrPJpcAtVfWzqnoEOAicuxCNlSTNzpw+UE2yCTgb2NtKVye5J8lNSVa32pnAE0ObHWL6XwaSpAU263BP8jLgy8B7q+oZ4Abg1cBZwBHgo3PZcZIdSfYl2Xfs2LG5bCpJmsGswj3JixgE++er6jaAqnqqqp6rqp8Dn+aFoZfDwIahzde32i+oqp1VNVFVE2vWTHqveUnSPM3mapkANwIPVtXHhurrhlZ7K3Bfm94NbEtyapLNwBbg2wvXZEnSTGZztczrgbcD9ya5u9XeD1ye5CyggEeBdwNU1f1JbgUeYHClzVVeKSNJS2vGcK+qbwKZZNEd02xzHXDdCO2SJI3A2w9IUocMd0nqkOEuSR0y3CWpQ4a7JHXIcJekDhnuktQhw12SOmS4S1KHDHdJ6pDhLkkdMtwlqUOGuyR1yHCXpA4Z7pLUIcNdkjpkuEtShwx3SeqQ4S5JHTLcJalDhrskdchwl6QOGe6S1CHDXZI6ZLhLUocMd0nqkOEuSR0y3CWpQ4a7JHXIcJekDhnuktQhw12SOjRjuCfZkOSuJA8kuT/Je1r99CR3Jnm4Pa9u9SS5PsnBJPckOWex34Qk6RfN5sz9WeB9VbUVOB+4KslW4BpgT1VtAfa0eYA3AVvaYwdww4K3WpI0rRnDvaqOVNV32/RPgAeBM4FLgZvbajcDb2nTlwKfq4FvAaclWbfgLZckTWlOY+5JNgFnA3uBtVV1pC16Eljbps8Enhja7FCrnfhaO5LsS7Lv2LFjc2y2JGk6sw73JC8Dvgy8t6qeGV5WVQXUXHZcVTuraqKqJtasWTOXTSVJM5hVuCd5EYNg/3xV3dbKTx0fbmnPR1v9MLBhaPP1rSZJWiKzuVomwI3Ag1X1saFFu4HtbXo7cPtQ/cp21cz5wNNDwzeSpCWwahbrvB54O3Bvkrtb7f3Ah4Fbk7wTeAy4rC27A7gEOAj8FHjHgrZYkjSjGcO9qr4JZIrFF06yfgFXjdguSdII/IaqJHXIcJekDhnuktQhw12SOmS4S1KHDHdJ6pDhLkkdMtwlqUOz+YbqsrVr7+OT1q84b+MSt0SSlpZn7pLUIcNdkjpkuEtShwx3SeqQ4S5JHTLcJalDhrskdchwl6QOGe6S1CHDXZI6ZLhLUocMd0nqkOEuSR0y3CWpQ4a7JHXIcJekDnX9xzqm4h/xkNQ7z9wlqUOGuyR1yHCXpA4Z7pLUIcNdkjo0Y7gnuSnJ0ST3DdU+lORwkrvb45KhZdcmOZjkQJKLFqvhkqSpzebM/bPAxZPUP15VZ7XHHQBJtgLbgNe2bf4hySkL1VhJ0uzMGO5V9Q3gx7N8vUuBW6rqZ1X1CHAQOHeE9kmS5mGUMferk9zThm1Wt9qZwBND6xxqtV+SZEeSfUn2HTt2bIRmSJJONN9wvwF4NXAWcAT46FxfoKp2VtVEVU2sWbNmns2QJE1mXuFeVU9V1XNV9XPg07ww9HIY2DC06vpWkyQtoXmFe5J1Q7NvBY5fSbMb2Jbk1CSbgS3At0droiRprma8cViSLwAXAGckOQR8ELggyVlAAY8C7waoqvuT3Ao8ADwLXFVVzy1O0yVJU5kx3Kvq8knKN06z/nXAdaM0SpI0Gr+hKkkdMtwlqUOGuyR1aEX+Jaap+BeaJPXCM3dJ6pDhLkkdMtwlqUOGuyR1yHCXpA4Z7pLUIcNdkjpkuEtShwx3SeqQ4S5JHTLcJalDhrskdchwl6QOGe6S1CHDXZI6ZLhLUocMd0nqkOEuSR0y3CWpQ4a7JHXIcJekDhnuktQhw12SOmS4S1KHDHdJ6pDhLkkdMtwlqUOGuyR1aMZwT3JTkqNJ7huqnZ7kziQPt+fVrZ4k1yc5mOSeJOcsZuMlSZObzZn7Z4GLT6hdA+ypqi3AnjYP8CZgS3vsAG5YmGZKkuZixnCvqm8APz6hfClwc5u+GXjLUP1zNfAt4LQk6xaqsZKk2Vk1z+3WVtWRNv0ksLZNnwk8MbTeoVY7wgmS7GBwds/GjRvn2YylsWvv45PWrzjv5G63pJVr5A9Uq6qAmsd2O6tqoqom1qxZM2ozJElD5hvuTx0fbmnPR1v9MLBhaL31rSZJWkLzDffdwPY2vR24fah+Zbtq5nzg6aHhG0nSEplxzD3JF4ALgDOSHAI+CHwYuDXJO4HHgMva6ncAlwAHgZ8C71iENkuSZjBjuFfV5VMsunCSdQu4atRGSZJG4zdUJalDhrskdchwl6QOGe6S1CHDXZI6ZLhLUocMd0nqkOEuSR0y3CWpQ4a7JHXIcJekDhnuktQhw12SOmS4S1KHDHdJ6pDhLkkdMtwlqUOGuyR1yHCXpA4Z7pLUIcNdkjpkuEtShwx3SeqQ4S5JHTLcJalDhrskdchwl6QOrRp3A5azXXsfn7R+xXkbl7glkvSLPHOXpA4Z7pLUIcNdkjo00ph7kkeBnwDPAc9W1USS04EvApuAR4HLquq/R2umJGkuFuLM/Xer6qyqmmjz1wB7qmoLsKfNS5KW0GIMy1wK3Nymbwbesgj7kCRNY9RwL+Dfk+xPsqPV1lbVkTb9JLB2sg2T7EiyL8m+Y8eOjdgMSdKwUa9zf0NVHU7yG8CdSR4aXlhVlaQm27CqdgI7ASYmJiZdR5I0PyOduVfV4fZ8FPgKcC7wVJJ1AO356KiNlCTNzbzDPclLk7z8+DTwRuA+YDewva22Hbh91EZKkuZmlGGZtcBXkhx/nV1V9W9JvgPcmuSdwGPAZaM3U5I0F/MO96r6AfDbk9R/BFw4SqMkSaPxG6qS1CHDXZI6ZLhLUocMd0nqkOEuSR0y3CWpQ4a7JHXIcJekDhnuktQhw12SOjTqLX
81iV17H5+0fsV5G5e4JZJWKs/cJalDhrskdchwl6QOGe6S1CHDXZI6ZLhLUocMd0nqkOEuSR0y3CWpQ4a7JHXI2w8sobnelsDbGEiaL8/cJalDhrskdchhmZPAVMMvkjRfnrlLUocMd0nqkMMyK8B0wz4n25U6XiEkLQzDvSOO3Us6znCXtGj8n9j4GO4rnGf7Up8Md83JXH8ZLNQZ2nI6A1xObVW/Fi3ck1wMfAI4BfhMVX14sfa10ni2PX8n44fLC/XzXKj2z+cY6eSzKOGe5BTg74E/AA4B30myu6oeWIz96eS12L+I/EUnTW6xztzPBQ5W1Q8AktwCXAoY7jpp+YtoZuN6D+MaDlxISz1cl6pa+BdN3gZcXFXvavNvB86rqquH1tkB7GizrwEOzHN3ZwA/HKG5PfAYDHgcPAawso7Bb1bVmskWjO0D1araCewc9XWS7KuqiQVo0rLlMRjwOHgMwGNw3GLdfuAwsGFofn2rSZKWwGKF+3eALUk2J3kxsA3YvUj7kiSdYFGGZarq2SRXA19jcCnkTVV1/2LsiwUY2umAx2DA4+AxAI8BsEgfqEqSxstb/kpShwx3SerQsg73JBcnOZDkYJJrxt2euUqyIcldSR5Icn+S97T66UnuTPJwe17d6klyfXu/9yQ5Z+i1trf1H06yfaj+O0nubdtcnyTT7WNckpyS5D+TfLXNb06yt7X7i+2DeZKc2uYPtuWbhl7j2lY/kOSiofqk/WSqfYxLktOSfCnJQ0keTPK6ldYXkvx5+7dwX5IvJHnJSuwLC6KqluWDwQe13wdeBbwY+C9g67jbNcf3sA44p02/HPgesBX4G+CaVr8G+Os2fQnwr0CA84G9rX468IP2vLpNr27Lvt3WTdv2Ta0+6T7GeCz+AtgFfLXN3wpsa9OfBP60Tf8Z8Mk2vQ34Ypve2vrAqcDm1jdOma6fTLWPMR6Dm4F3tekXA6etpL4AnAk8Avzq0M/nj1diX1iQ4znuBozQEV4HfG1o/lrg2nG3a8T3dDuD+/EcANa12jrgQJv+FHD50PoH2vLLgU8N1T/VauuAh4bqz6831T7G9L7XA3uA3wO+2sLnh8CqE3/WDK7Ael2bXtXWy4k//+PrTdVPptvHmI7BK1qw5YT6iukLDML9CQa/mFa1vnDRSusLC/VYzsMyxzvCcYdabVlq/6U8G9gLrK2qI23Rk8DaNj3Ve56ufmiSOtPsYxz+FvhL4Odt/teB/6mqZ9v8cLuff69t+dNt/bkem+n2MQ6bgWPAP7bhqc8keSkrqC9U1WHgI8DjwBEGP9v9rLy+sCCWc7h3I8nLgC8D762qZ4aX1eBUYlGvV12KfUwlyR8CR6tq/zj2fxJZBZwD3FBVZwP/y2CI5HkroC+sZnCDwc3AK4GXAhePoy09WM7h3sUtDpK8iEGwf76qbmvlp5Ksa8vXAUdbfar3PF19/ST16fax1F4PvDnJo8AtDIZmPgGcluT4l+yG2/38e23LXwH8iLkfmx9Ns49xOAQcqqq9bf5LDMJ+JfWF3wceqapjVfV/wG0M+sdK6wsLYjmH+7K/xUG7WuFG4MGq+tjQot3A8asctjMYiz9ev7JdKXE+8HT77/TXgDcmWd3Oft7IYMzwCPBMkvPbvq484bUm28eSqqprq2p9VW1i8DP8j6r6I+Au4G2TtG+43W9r61erb2tXUGwGtjD4AHHSftK2mWofS66qngSeSPKaVrqQwS2yV0xfYDAcc36SX2ttPH4MVlRfWDDjHvQf5cHgioHvMfgE/APjbs882v8GBv8Fvge4uz0uYTAGuAd4GPg6cHpbPwz+CMr3gXuBiaHX+hPgYHu8Y6g+AdzXtvk7XvhW8qT7GPPxuIAXrpZ5FYN/kAeBfwZObfWXtPmDbfmrhrb/QHufB2hXgkzXT6baxxjf/1nAvtYf/oXB1S4rqi8AfwU81Nr5TwyueFlxfWEhHt5+QJI6tJyHZSRJUzDcJalDhrskdchwl6QOGe6S1CHDXZI6ZLhLUof+HwNE9h5uGuuvAAAAAElFTkSuQmCC\n"
+     },
+     "metadata": {
+      "needs_background": "light"
+     }
+    }
+   ],
+   "source": [
+    "sns.distplot(document_lens, kde=False)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": 3
+  },
+  "orig_nbformat": 2,
+  "kernelspec": {
+   "name": "python_defaultSpec_1594287746726",
+   "display_name": "Python 3.8.2 64-bit"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
\ No newline at end of file
diff --git a/dataset_generation/processing.py b/dataset_generation/processing.py
new file mode 100644
index 0000000..2ab5359
--- /dev/null
+++ b/dataset_generation/processing.py
@@ -0,0 +1,149 @@
+import glob
+import random
+from lxml import etree
+import uuid
+import hashlib
+import seaborn as sns
+import re
+from tqdm import tqdm
+from typing import Optional, Mapping, Tuple
+from utils import remove_punctuation
+import numpy as np
+from more_itertools import windowed
+
+# Key order must match the dict literal in detect_actions, since
+# encode_actions/decode_actions rely on this ordering.
+ACTIONS_KEYS = ['dot', 'upper_case', 'colon', 'semicolon', 'elipsis', 'dash']
+
+def text_from_xml(path: str) -> str:
+    """Extract spoken text from dataset's xml format
+
+    Args:
+        path (str): Path to xml
+
+    Returns:
+        str: Raw text
+    """
+    root = etree.parse(path)
+    full_text = ""
+
+    for node in root.iter('*'):
+        if len(node) == 0:
+            who = node.get("who")
+            text = node.text
+
+            if text is not None and who is not None and who != "#komentarz":
+                full_text = " ".join([full_text, text])
+
+    return full_text
+
+
+def detect_actions(word: str, next_word: Optional[str]) -> Mapping[str, bool]:
+    """Detect which actions the model should perform on a word
+
+    Args:
+        word (str): Word on which the action is decided
+        next_word (Optional[str]): Word that follows the considered word. Can be None if nothing follows the word
+
+    Returns:
+        Mapping[str, bool]: Mapping telling if each of the possible actions should be performed (True) or not (False)
+    """
+    word = word.replace('"', " ")  # No support for quotes; str.replace returns a new string
+
+    actions = {
+        'dot': word[-1] == '.',
+        'upper_case': word[0].isupper(),
+        'colon': word[-1] == ",",  # despite the key name, this marks a trailing comma
+        'semicolon': word[-1] == ";",
+        'elipsis': len(word) > 3 and word[-3:] == "...",
+        'dash': next_word == "-"
+    }
+
+    return actions
+
+def encode_actions(actions: Mapping[str, bool]) -> np.ndarray:
+    """Transforms actions into a vector
+
+    Args:
+        actions (Mapping[str, bool]): Map telling which actions should be made
+
+    Returns:
+        np.ndarray: 1-dimensional action vector
+    """
+    return np.array(list(actions.values())).astype(float)
+
+def decode_actions(encoded_actions: np.ndarray) -> Mapping[str, bool]:
+    """Decodes an action vector back into a mapping
+
+    Args:
+        encoded_actions (np.ndarray): 1-dimensional action vector
+
+    Returns:
+        Mapping[str, bool]: Map telling which actions should be made
+    """
+    return dict(zip(ACTIONS_KEYS, encoded_actions.astype(bool).tolist()))
+
+def create_model_input_output(text: str) -> Tuple[str, np.ndarray]:
+    """Returns a pair of input and desired output of the model
+
+    Args:
+        text (str): Correct text sample
+
+    Returns:
+        text_cleaned (str): Text without any punctuation, all lowercase
+        actions (np.ndarray): Two-dimensional array, where each row is the action vector for one word
+    """
+    words = text.split(" ")
+
+    words_output = []
+    actions_output = []
+
+    for i, word in enumerate(words):
+        next_word = words[i + 1] if i + 1 < len(words) else None
+
+        word_sanitized = remove_punctuation(word).lower()
+        if len(word_sanitized) > 0:
+            actions = detect_actions(word, next_word)
+            actions_encoded = encode_actions(actions)
+
+            words_output.append(word_sanitized)
+            actions_output.append(actions_encoded)
+
+    assert len(words_output) == len(actions_output)
+
+    return " ".join(words_output), np.array(actions_output)
+
+def recover_word(word: str, action: Mapping[str, bool]) -> str:
+    """Re-applies punctuation and casing actions to a single sanitized word"""
+    word_result = word
+
+    if action['dot']:
+        word_result += "."
+    if action['upper_case']:
+        # str is immutable, so rebuild instead of assigning to word_result[0]
+        word_result = word_result[0].upper() + word_result[1:]
+    if action['colon']:
+        word_result += ","
+    if action['semicolon']:
+        word_result += ";"
+    if action['elipsis']:
+        word_result += "..."
+    if action['dash']:
+        word_result += " -"
+
+    return word_result
+
+def recover_text(text: str, actions_encoded: np.ndarray) -> str:
+    words = text.split(" ")
+
+    words_output = []
+
+    for word, action_encoded in zip(words, actions_encoded.tolist()):
+        action_decoded = decode_actions(np.array(action_encoded))
+        
+        word_recovered = recover_word(word, action_decoded)
+        words_output.append(word_recovered)
+
+    return " ".join(words_output)
+
+
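A minimal round-trip sketch of the API added above (illustrative only; the sample sentence is made up, and we assume running from inside dataset_generation/):

    from processing import create_model_input_output, recover_text

    text = "Ala ma kota. Kot ma Ale."  # hypothetical sample
    model_input, model_output = create_model_input_output(text)
    # model_input  -> "ala ma kota kot ma ale" (lowercased, punctuation stripped)
    # model_output -> one 6-element action vector per word ('dot', 'upper_case', ...)
    print(recover_text(model_input, model_output))
    # -> "Ala ma kota. Kot ma Ale."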
diff --git a/dataset_generation/utils.py b/dataset_generation/utils.py
new file mode 100644
index 0000000..dc009d7
--- /dev/null
+++ b/dataset_generation/utils.py
@@ -0,0 +1,14 @@
+import glob
+import random
+from lxml import etree
+import uuid
+import hashlib
+import seaborn as sns
+import re
+from tqdm import tqdm
+
+def remove_multiple_spaces(x: str) -> str:
+    return re.sub(r"\s\s+", " ", x)
+
+def remove_punctuation(x: str) -> str:
+    return ''.join(filter(lambda c: c.isalnum() or c.isspace(), x))
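For reference, the intended behavior of these helpers (illustrative inputs, not taken from the patch):

    from utils import remove_multiple_spaces, remove_punctuation

    assert remove_multiple_spaces("ala  ma   kota") == "ala ma kota"
    assert remove_punctuation("kota,") == "kota"
    assert remove_punctuation("...") == ""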
-- 
GitLab


From 5f93306c1d32e573d4f25f5086ac67e41bd8db9d Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Thu, 9 Jul 2020 16:58:38 +0200
Subject: [PATCH 002/116] Add README

---
 README.md | 1 +
 1 file changed, 1 insertion(+)
 create mode 100644 README.md

diff --git a/README.md b/README.md
new file mode 100644
index 0000000..4a87d36
--- /dev/null
+++ b/README.md
@@ -0,0 +1 @@
+# Interpunkcja
-- 
GitLab


From ba1fdc613a41cf27ba6cbe9978c460e968e82533 Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Fri, 10 Jul 2020 09:43:55 +0200
Subject: [PATCH 003/116] Automatic dataset download

---
 .gitignore                                |  2 ++
 dane/download_dataset.sh                  |  4 +++
 dataset_generation/.gitignore             |  2 --
 dataset_generation/notebook_actions.ipynb | 32 +++++++++++------------
 dataset_generation/notebook_simple.ipynb  |  8 +++---
 5 files changed, 26 insertions(+), 22 deletions(-)
 create mode 100755 dane/download_dataset.sh

diff --git a/.gitignore b/.gitignore
index fcff6f3..e75c3e1 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1 +1,3 @@
 dane/**
+dataset_simple
+dataset_actions
\ No newline at end of file
diff --git a/dane/download_dataset.sh b/dane/download_dataset.sh
new file mode 100755
index 0000000..5c70a48
--- /dev/null
+++ b/dane/download_dataset.sh
@@ -0,0 +1,4 @@
+#!/bin/bash
+wget http://manage.legis.nlp.ipipan.waw.pl/download/ppc-nanno.tar.gz
+tar -xvf ppc-nanno.tar.gz
+rm ppc-nanno.tar.gz
diff --git a/dataset_generation/.gitignore b/dataset_generation/.gitignore
index b43b91b..ed8ebf5 100644
--- a/dataset_generation/.gitignore
+++ b/dataset_generation/.gitignore
@@ -1,3 +1 @@
-dataset_simple
-dataset_actions
 __pycache__
\ No newline at end of file
diff --git a/dataset_generation/notebook_actions.ipynb b/dataset_generation/notebook_actions.ipynb
index aa2522b..dedbf9d 100644
--- a/dataset_generation/notebook_actions.ipynb
+++ b/dataset_generation/notebook_actions.ipynb
@@ -12,7 +12,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 2,
+   "execution_count": 4,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -32,7 +32,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 3,
+   "execution_count": 5,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -42,7 +42,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 4,
+   "execution_count": 6,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -51,7 +51,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 5,
+   "execution_count": 9,
    "metadata": {
     "tags": []
    },
@@ -59,22 +59,22 @@
     {
      "output_type": "stream",
      "name": "stderr",
-     "text": "100%|██████████| 1000/1000 [00:09<00:00, 109.49it/s]\n"
+     "text": "100%|██████████| 1000/1000 [00:09<00:00, 103.87it/s]\n"
     }
    ],
    "source": [
     "num_exported = 0\n",
     "document_lens = []\n",
     "\n",
-    "if not os.path.exists(\"dataset_actions\"):\n",
-    "    os.mkdir(\"dataset_actions\")\n",
+    "if not os.path.exists(\"../dataset_actions\"):\n",
+    "    os.mkdir(\"../dataset_actions\")\n",
     "\n",
     "for file_path in tqdm(files_subset):\n",
     "    full_text = text_from_xml(file_path)\n",
     "    \n",
     "    if len(full_text) > 0:\n",
-    "        output_file_input = f\"dataset_actions/{hashlib.md5(file_path.encode()).hexdigest()}_input.txt\"\n",
-    "        output_file_output = f\"dataset_actions/{hashlib.md5(file_path.encode()).hexdigest()}_output.txt\"\n",
+    "        output_file_input = f\"../dataset_actions/{hashlib.md5(file_path.encode()).hexdigest()}_input.txt\"\n",
+    "        output_file_output = f\"../dataset_actions/{hashlib.md5(file_path.encode()).hexdigest()}_output.txt\"\n",
     "\n",
     "        model_input, model_output = create_model_input_output(full_text)\n",
     "\n",
@@ -89,7 +89,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 6,
+   "execution_count": 8,
    "metadata": {
     "tags": []
    },
@@ -97,17 +97,17 @@
     {
      "output_type": "execute_result",
      "data": {
-      "text/plain": "<matplotlib.axes._subplots.AxesSubplot at 0x7f418d4751f0>"
+      "text/plain": "<matplotlib.axes._subplots.AxesSubplot at 0x7f018b71d910>"
      },
      "metadata": {},
-     "execution_count": 6
+     "execution_count": 8
     },
     {
      "output_type": "display_data",
      "data": {
       "text/plain": "<Figure size 432x288 with 1 Axes>",
-      "image/svg+xml": "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n  \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n<!-- Created with matplotlib (https://matplotlib.org/) -->\n<svg height=\"249.103945pt\" version=\"1.1\" viewBox=\"0 0 368.925 249.103945\" width=\"368.925pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n <defs>\n  <style type=\"text/css\">\n*{stroke-linecap:butt;stroke-linejoin:round;}\n  </style>\n </defs>\n <g id=\"figure_1\">\n  <g id=\"patch_1\">\n   <path d=\"M -0 249.103945 \nL 368.925 249.103945 \nL 368.925 0 \nL -0 0 \nz\n\" style=\"fill:none;\"/>\n  </g>\n  <g id=\"axes_1\">\n   <g id=\"patch_2\">\n    <path d=\"M 26.925 225.22582 \nL 361.725 225.22582 \nL 361.725 7.78582 \nL 26.925 7.78582 \nz\n\" style=\"fill:#ffffff;\"/>\n   </g>\n   <g id=\"patch_3\">\n    <path clip-path=\"url(#p1f2e03f3c1)\" d=\"M 42.143182 225.22582 \nL 53.849476 225.22582 \nL 53.849476 32.421879 \nL 42.143182 32.421879 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_4\">\n    <path clip-path=\"url(#p1f2e03f3c1)\" d=\"M 53.849476 225.22582 \nL 65.555769 225.22582 \nL 65.555769 18.140105 \nL 53.849476 18.140105 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_5\">\n    <path clip-path=\"url(#p1f2e03f3c1)\" d=\"M 65.555769 225.22582 \nL 77.262063 225.22582 \nL 77.262063 103.830746 \nL 65.555769 103.830746 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_6\">\n    <path clip-path=\"url(#p1f2e03f3c1)\" d=\"M 77.262063 225.22582 \nL 88.968357 225.22582 \nL 88.968357 118.112519 \nL 77.262063 118.112519 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_7\">\n    <path clip-path=\"url(#p1f2e03f3c1)\" d=\"M 88.968357 225.22582 \nL 100.67465 225.22582 \nL 100.67465 189.521386 \nL 88.968357 189.521386 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_8\">\n    <path clip-path=\"url(#p1f2e03f3c1)\" d=\"M 100.67465 225.22582 \nL 112.380944 225.22582 \nL 112.380944 203.80316 \nL 100.67465 203.80316 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_9\">\n    <path clip-path=\"url(#p1f2e03f3c1)\" d=\"M 112.380944 225.22582 \nL 124.087238 225.22582 \nL 124.087238 203.80316 \nL 112.380944 203.80316 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_10\">\n    <path clip-path=\"url(#p1f2e03f3c1)\" d=\"M 124.087238 225.22582 \nL 135.793531 225.22582 \nL 135.793531 218.084933 \nL 124.087238 218.084933 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_11\">\n    <path clip-path=\"url(#p1f2e03f3c1)\" d=\"M 135.793531 225.22582 \nL 147.499825 225.22582 \nL 147.499825 203.80316 \nL 135.793531 203.80316 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_12\">\n    <path clip-path=\"url(#p1f2e03f3c1)\" d=\"M 147.499825 225.22582 \nL 159.206119 225.22582 \nL 159.206119 225.22582 \nL 147.499825 225.22582 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_13\">\n    <path clip-path=\"url(#p1f2e03f3c1)\" d=\"M 159.206119 225.22582 \nL 170.912413 225.22582 \nL 170.912413 225.22582 \nL 159.206119 225.22582 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_14\">\n    <path clip-path=\"url(#p1f2e03f3c1)\" d=\"M 170.912413 225.22582 \nL 182.618706 225.22582 \nL 182.618706 225.22582 \nL 170.912413 225.22582 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n  
 <g id=\"patch_15\">\n    <path clip-path=\"url(#p1f2e03f3c1)\" d=\"M 182.618706 225.22582 \nL 194.325 225.22582 \nL 194.325 218.084933 \nL 182.618706 218.084933 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_16\">\n    <path clip-path=\"url(#p1f2e03f3c1)\" d=\"M 194.325 225.22582 \nL 206.031294 225.22582 \nL 206.031294 225.22582 \nL 194.325 225.22582 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_17\">\n    <path clip-path=\"url(#p1f2e03f3c1)\" d=\"M 206.031294 225.22582 \nL 217.737587 225.22582 \nL 217.737587 210.944046 \nL 206.031294 210.944046 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_18\">\n    <path clip-path=\"url(#p1f2e03f3c1)\" d=\"M 217.737587 225.22582 \nL 229.443881 225.22582 \nL 229.443881 225.22582 \nL 217.737587 225.22582 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_19\">\n    <path clip-path=\"url(#p1f2e03f3c1)\" d=\"M 229.443881 225.22582 \nL 241.150175 225.22582 \nL 241.150175 225.22582 \nL 229.443881 225.22582 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_20\">\n    <path clip-path=\"url(#p1f2e03f3c1)\" d=\"M 241.150175 225.22582 \nL 252.856469 225.22582 \nL 252.856469 225.22582 \nL 241.150175 225.22582 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_21\">\n    <path clip-path=\"url(#p1f2e03f3c1)\" d=\"M 252.856469 225.22582 \nL 264.562762 225.22582 \nL 264.562762 218.084933 \nL 252.856469 218.084933 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_22\">\n    <path clip-path=\"url(#p1f2e03f3c1)\" d=\"M 264.562762 225.22582 \nL 276.269056 225.22582 \nL 276.269056 218.084933 \nL 264.562762 218.084933 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_23\">\n    <path clip-path=\"url(#p1f2e03f3c1)\" d=\"M 276.269056 225.22582 \nL 287.97535 225.22582 \nL 287.97535 225.22582 \nL 276.269056 225.22582 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_24\">\n    <path clip-path=\"url(#p1f2e03f3c1)\" d=\"M 287.97535 225.22582 \nL 299.681643 225.22582 \nL 299.681643 218.084933 \nL 287.97535 218.084933 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_25\">\n    <path clip-path=\"url(#p1f2e03f3c1)\" d=\"M 299.681643 225.22582 \nL 311.387937 225.22582 \nL 311.387937 225.22582 \nL 299.681643 225.22582 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_26\">\n    <path clip-path=\"url(#p1f2e03f3c1)\" d=\"M 311.387937 225.22582 \nL 323.094231 225.22582 \nL 323.094231 225.22582 \nL 311.387937 225.22582 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_27\">\n    <path clip-path=\"url(#p1f2e03f3c1)\" d=\"M 323.094231 225.22582 \nL 334.800524 225.22582 \nL 334.800524 225.22582 \nL 323.094231 225.22582 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_28\">\n    <path clip-path=\"url(#p1f2e03f3c1)\" d=\"M 334.800524 225.22582 \nL 346.506818 225.22582 \nL 346.506818 218.084933 \nL 334.800524 218.084933 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"matplotlib.axis_1\">\n    <g id=\"xtick_1\">\n     <g id=\"line2d_1\">\n      <defs>\n       <path d=\"M 0 0 \nL 0 3.5 \n\" id=\"mcb448f674c\" style=\"stroke:#000000;stroke-width:0.8;\"/>\n      </defs>\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"41.035222\" xlink:href=\"#mcb448f674c\" y=\"225.22582\"/>\n      </g>\n     </g>\n     <g id=\"text_1\">\n      <!-- 0 -->\n      <defs>\n  
     <path d=\"M 31.78125 66.40625 \nQ 24.171875 66.40625 20.328125 58.90625 \nQ 16.5 51.421875 16.5 36.375 \nQ 16.5 21.390625 20.328125 13.890625 \nQ 24.171875 6.390625 31.78125 6.390625 \nQ 39.453125 6.390625 43.28125 13.890625 \nQ 47.125 21.390625 47.125 36.375 \nQ 47.125 51.421875 43.28125 58.90625 \nQ 39.453125 66.40625 31.78125 66.40625 \nz\nM 31.78125 74.21875 \nQ 44.046875 74.21875 50.515625 64.515625 \nQ 56.984375 54.828125 56.984375 36.375 \nQ 56.984375 17.96875 50.515625 8.265625 \nQ 44.046875 -1.421875 31.78125 -1.421875 \nQ 19.53125 -1.421875 13.0625 8.265625 \nQ 6.59375 17.96875 6.59375 36.375 \nQ 6.59375 54.828125 13.0625 64.515625 \nQ 19.53125 74.21875 31.78125 74.21875 \nz\n\" id=\"DejaVuSans-48\"/>\n      </defs>\n      <g transform=\"translate(37.853972 239.824257)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-48\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"xtick_2\">\n     <g id=\"line2d_2\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"90.853559\" xlink:href=\"#mcb448f674c\" y=\"225.22582\"/>\n      </g>\n     </g>\n     <g id=\"text_2\">\n      <!-- 100000 -->\n      <defs>\n       <path d=\"M 12.40625 8.296875 \nL 28.515625 8.296875 \nL 28.515625 63.921875 \nL 10.984375 60.40625 \nL 10.984375 69.390625 \nL 28.421875 72.90625 \nL 38.28125 72.90625 \nL 38.28125 8.296875 \nL 54.390625 8.296875 \nL 54.390625 0 \nL 12.40625 0 \nz\n\" id=\"DejaVuSans-49\"/>\n      </defs>\n      <g transform=\"translate(71.766059 239.824257)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-49\"/>\n       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"127.246094\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"190.869141\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"254.492188\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"318.115234\" xlink:href=\"#DejaVuSans-48\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"xtick_3\">\n     <g id=\"line2d_3\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"140.671896\" xlink:href=\"#mcb448f674c\" y=\"225.22582\"/>\n      </g>\n     </g>\n     <g id=\"text_3\">\n      <!-- 200000 -->\n      <defs>\n       <path d=\"M 19.1875 8.296875 \nL 53.609375 8.296875 \nL 53.609375 0 \nL 7.328125 0 \nL 7.328125 8.296875 \nQ 12.9375 14.109375 22.625 23.890625 \nQ 32.328125 33.6875 34.8125 36.53125 \nQ 39.546875 41.84375 41.421875 45.53125 \nQ 43.3125 49.21875 43.3125 52.78125 \nQ 43.3125 58.59375 39.234375 62.25 \nQ 35.15625 65.921875 28.609375 65.921875 \nQ 23.96875 65.921875 18.8125 64.3125 \nQ 13.671875 62.703125 7.8125 59.421875 \nL 7.8125 69.390625 \nQ 13.765625 71.78125 18.9375 73 \nQ 24.125 74.21875 28.421875 74.21875 \nQ 39.75 74.21875 46.484375 68.546875 \nQ 53.21875 62.890625 53.21875 53.421875 \nQ 53.21875 48.921875 51.53125 44.890625 \nQ 49.859375 40.875 45.40625 35.40625 \nQ 44.1875 33.984375 37.640625 27.21875 \nQ 31.109375 20.453125 19.1875 8.296875 \nz\n\" id=\"DejaVuSans-50\"/>\n      </defs>\n      <g transform=\"translate(121.584396 239.824257)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-50\"/>\n       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"127.246094\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"190.869141\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"254.492188\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"318.115234\" xlink:href=\"#DejaVuSans-48\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"xtick_4\">\n     <g id=\"line2d_4\">\n      <g>\n       <use 
style=\"stroke:#000000;stroke-width:0.8;\" x=\"190.490233\" xlink:href=\"#mcb448f674c\" y=\"225.22582\"/>\n      </g>\n     </g>\n     <g id=\"text_4\">\n      <!-- 300000 -->\n      <defs>\n       <path d=\"M 40.578125 39.3125 \nQ 47.65625 37.796875 51.625 33 \nQ 55.609375 28.21875 55.609375 21.1875 \nQ 55.609375 10.40625 48.1875 4.484375 \nQ 40.765625 -1.421875 27.09375 -1.421875 \nQ 22.515625 -1.421875 17.65625 -0.515625 \nQ 12.796875 0.390625 7.625 2.203125 \nL 7.625 11.71875 \nQ 11.71875 9.328125 16.59375 8.109375 \nQ 21.484375 6.890625 26.8125 6.890625 \nQ 36.078125 6.890625 40.9375 10.546875 \nQ 45.796875 14.203125 45.796875 21.1875 \nQ 45.796875 27.640625 41.28125 31.265625 \nQ 36.765625 34.90625 28.71875 34.90625 \nL 20.21875 34.90625 \nL 20.21875 43.015625 \nL 29.109375 43.015625 \nQ 36.375 43.015625 40.234375 45.921875 \nQ 44.09375 48.828125 44.09375 54.296875 \nQ 44.09375 59.90625 40.109375 62.90625 \nQ 36.140625 65.921875 28.71875 65.921875 \nQ 24.65625 65.921875 20.015625 65.03125 \nQ 15.375 64.15625 9.8125 62.3125 \nL 9.8125 71.09375 \nQ 15.4375 72.65625 20.34375 73.4375 \nQ 25.25 74.21875 29.59375 74.21875 \nQ 40.828125 74.21875 47.359375 69.109375 \nQ 53.90625 64.015625 53.90625 55.328125 \nQ 53.90625 49.265625 50.4375 45.09375 \nQ 46.96875 40.921875 40.578125 39.3125 \nz\n\" id=\"DejaVuSans-51\"/>\n      </defs>\n      <g transform=\"translate(171.402733 239.824257)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-51\"/>\n       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"127.246094\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"190.869141\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"254.492188\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"318.115234\" xlink:href=\"#DejaVuSans-48\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"xtick_5\">\n     <g id=\"line2d_5\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"240.308571\" xlink:href=\"#mcb448f674c\" y=\"225.22582\"/>\n      </g>\n     </g>\n     <g id=\"text_5\">\n      <!-- 400000 -->\n      <defs>\n       <path d=\"M 37.796875 64.3125 \nL 12.890625 25.390625 \nL 37.796875 25.390625 \nz\nM 35.203125 72.90625 \nL 47.609375 72.90625 \nL 47.609375 25.390625 \nL 58.015625 25.390625 \nL 58.015625 17.1875 \nL 47.609375 17.1875 \nL 47.609375 0 \nL 37.796875 0 \nL 37.796875 17.1875 \nL 4.890625 17.1875 \nL 4.890625 26.703125 \nz\n\" id=\"DejaVuSans-52\"/>\n      </defs>\n      <g transform=\"translate(221.221071 239.824257)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-52\"/>\n       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"127.246094\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"190.869141\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"254.492188\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"318.115234\" xlink:href=\"#DejaVuSans-48\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"xtick_6\">\n     <g id=\"line2d_6\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"290.126908\" xlink:href=\"#mcb448f674c\" y=\"225.22582\"/>\n      </g>\n     </g>\n     <g id=\"text_6\">\n      <!-- 500000 -->\n      <defs>\n       <path d=\"M 10.796875 72.90625 \nL 49.515625 72.90625 \nL 49.515625 64.59375 \nL 19.828125 64.59375 \nL 19.828125 46.734375 \nQ 21.96875 47.46875 24.109375 47.828125 \nQ 26.265625 48.1875 28.421875 48.1875 \nQ 40.625 48.1875 47.75 41.5 \nQ 54.890625 34.8125 54.890625 23.390625 \nQ 54.890625 11.625 47.5625 5.09375 \nQ 40.234375 -1.421875 26.90625 -1.421875 \nQ 22.3125 
-1.421875 17.546875 -0.640625 \nQ 12.796875 0.140625 7.71875 1.703125 \nL 7.71875 11.625 \nQ 12.109375 9.234375 16.796875 8.0625 \nQ 21.484375 6.890625 26.703125 6.890625 \nQ 35.15625 6.890625 40.078125 11.328125 \nQ 45.015625 15.765625 45.015625 23.390625 \nQ 45.015625 31 40.078125 35.4375 \nQ 35.15625 39.890625 26.703125 39.890625 \nQ 22.75 39.890625 18.8125 39.015625 \nQ 14.890625 38.140625 10.796875 36.28125 \nz\n\" id=\"DejaVuSans-53\"/>\n      </defs>\n      <g transform=\"translate(271.039408 239.824257)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-53\"/>\n       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"127.246094\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"190.869141\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"254.492188\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"318.115234\" xlink:href=\"#DejaVuSans-48\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"xtick_7\">\n     <g id=\"line2d_7\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"339.945245\" xlink:href=\"#mcb448f674c\" y=\"225.22582\"/>\n      </g>\n     </g>\n     <g id=\"text_7\">\n      <!-- 600000 -->\n      <defs>\n       <path d=\"M 33.015625 40.375 \nQ 26.375 40.375 22.484375 35.828125 \nQ 18.609375 31.296875 18.609375 23.390625 \nQ 18.609375 15.53125 22.484375 10.953125 \nQ 26.375 6.390625 33.015625 6.390625 \nQ 39.65625 6.390625 43.53125 10.953125 \nQ 47.40625 15.53125 47.40625 23.390625 \nQ 47.40625 31.296875 43.53125 35.828125 \nQ 39.65625 40.375 33.015625 40.375 \nz\nM 52.59375 71.296875 \nL 52.59375 62.3125 \nQ 48.875 64.0625 45.09375 64.984375 \nQ 41.3125 65.921875 37.59375 65.921875 \nQ 27.828125 65.921875 22.671875 59.328125 \nQ 17.53125 52.734375 16.796875 39.40625 \nQ 19.671875 43.65625 24.015625 45.921875 \nQ 28.375 48.1875 33.59375 48.1875 \nQ 44.578125 48.1875 50.953125 41.515625 \nQ 57.328125 34.859375 57.328125 23.390625 \nQ 57.328125 12.15625 50.6875 5.359375 \nQ 44.046875 -1.421875 33.015625 -1.421875 \nQ 20.359375 -1.421875 13.671875 8.265625 \nQ 6.984375 17.96875 6.984375 36.375 \nQ 6.984375 53.65625 15.1875 63.9375 \nQ 23.390625 74.21875 37.203125 74.21875 \nQ 40.921875 74.21875 44.703125 73.484375 \nQ 48.484375 72.75 52.59375 71.296875 \nz\n\" id=\"DejaVuSans-54\"/>\n      </defs>\n      <g transform=\"translate(320.857745 239.824257)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-54\"/>\n       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"127.246094\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"190.869141\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"254.492188\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"318.115234\" xlink:href=\"#DejaVuSans-48\"/>\n      </g>\n     </g>\n    </g>\n   </g>\n   <g id=\"matplotlib.axis_2\">\n    <g id=\"ytick_1\">\n     <g id=\"line2d_8\">\n      <defs>\n       <path d=\"M 0 0 \nL -3.5 0 \n\" id=\"m50745714e3\" style=\"stroke:#000000;stroke-width:0.8;\"/>\n      </defs>\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"26.925\" xlink:href=\"#m50745714e3\" y=\"225.22582\"/>\n      </g>\n     </g>\n     <g id=\"text_8\">\n      <!-- 0 -->\n      <g transform=\"translate(13.5625 229.025038)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-48\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"ytick_2\">\n     <g id=\"line2d_9\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"26.925\" xlink:href=\"#m50745714e3\" y=\"189.521386\"/>\n      </g>\n     </g>\n     <g id=\"text_9\">\n      
<!-- 5 -->\n      <g transform=\"translate(13.5625 193.320605)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-53\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"ytick_3\">\n     <g id=\"line2d_10\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"26.925\" xlink:href=\"#m50745714e3\" y=\"153.816953\"/>\n      </g>\n     </g>\n     <g id=\"text_10\">\n      <!-- 10 -->\n      <g transform=\"translate(7.2 157.616171)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-49\"/>\n       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-48\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"ytick_4\">\n     <g id=\"line2d_11\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"26.925\" xlink:href=\"#m50745714e3\" y=\"118.112519\"/>\n      </g>\n     </g>\n     <g id=\"text_11\">\n      <!-- 15 -->\n      <g transform=\"translate(7.2 121.911738)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-49\"/>\n       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-53\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"ytick_5\">\n     <g id=\"line2d_12\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"26.925\" xlink:href=\"#m50745714e3\" y=\"82.408086\"/>\n      </g>\n     </g>\n     <g id=\"text_12\">\n      <!-- 20 -->\n      <g transform=\"translate(7.2 86.207304)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-50\"/>\n       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-48\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"ytick_6\">\n     <g id=\"line2d_13\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"26.925\" xlink:href=\"#m50745714e3\" y=\"46.703652\"/>\n      </g>\n     </g>\n     <g id=\"text_13\">\n      <!-- 25 -->\n      <g transform=\"translate(7.2 50.502871)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-50\"/>\n       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-53\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"ytick_7\">\n     <g id=\"line2d_14\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"26.925\" xlink:href=\"#m50745714e3\" y=\"10.999219\"/>\n      </g>\n     </g>\n     <g id=\"text_14\">\n      <!-- 30 -->\n      <g transform=\"translate(7.2 14.798437)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-51\"/>\n       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-48\"/>\n      </g>\n     </g>\n    </g>\n   </g>\n   <g id=\"patch_29\">\n    <path d=\"M 26.925 225.22582 \nL 26.925 7.78582 \n\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\n   </g>\n   <g id=\"patch_30\">\n    <path d=\"M 361.725 225.22582 \nL 361.725 7.78582 \n\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\n   </g>\n   <g id=\"patch_31\">\n    <path d=\"M 26.925 225.22582 \nL 361.725 225.22582 \n\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\n   </g>\n   <g id=\"patch_32\">\n    <path d=\"M 26.925 7.78582 \nL 361.725 7.78582 \n\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\n   </g>\n  </g>\n </g>\n <defs>\n  <clipPath id=\"p1f2e03f3c1\">\n   <rect height=\"217.44\" width=\"334.8\" x=\"26.925\" y=\"7.78582\"/>\n  </clipPath>\n </defs>\n</svg>\n",
-      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXAAAAD5CAYAAAA+0W6bAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAPHElEQVR4nO3dbaxlVX3H8e+vgGiFyCA3kwkwHWyIhhftQG94iMZYrFZJUzUhDWh0UmnGtJpoNGlAk2rTvrBN1da0QcZCnRciWpRCCFYpkhiTZuwdHWFgRB466pCBGTWI9kUt+O+Lswaul/tw7nmYOWv4fpKTs8/a+5z1XzN7fnffdfbek6pCktSfXzvWBUiSRmOAS1KnDHBJ6pQBLkmdMsAlqVMGuCR16sS1NkjyQuDrwMlt+5ur6sNJzgFuAl4K7AbeXlW/WO2zzjjjjNqyZcvYRUvS88nu3bt/VFVzS9vXDHDgf4FLq+rnSU4CvpHky8D7gU9U1U1JPgVcBVy72gdt2bKFhYWFEcqXpOevJN9frn3NKZQa+Hl7eVJ7FHApcHNr3wm8eQJ1SpKGNNQceJITkuwBDgF3Ag8DT1TVU22TA8CZ0ylRkrScoQK8qp6uqq3AWcCFwCuG7SDJ9iQLSRYOHz48YpmSpKXWdRZKVT0B3A1cApyW5Mgc+lnAoyu8Z0dVzVfV/Nzcc+bgJUkjWjPAk8wlOa0tvwh4HbCPQZBf3jbbBtw6rSIlSc81zFkom4CdSU5gEPhfqKrbk9wP3JTkr4FvA9dPsU5J0hJrBnhV3QOcv0z7IwzmwyVJx4BXYkpSpwxwSerUMHPgXbpx1w/Wtf1bL9o8pUokaTo8ApekThngktQpA1ySOmWAS1KnDHBJ6pQBLkmdMsAlqVMGuCR1ygCXpE4Z4JLUKQNckjplgEtSpwxwSepUN3cjXO/dBSXpeOcRuCR1ygCXpE4Z4JLUKQNckjplgEtSpwxwSeqUAS5JnTLAJalTBrgkdcoAl6ROrRngSc5OcneS+5Pcl+S9rf0jSR5Nsqc9Lpt+uZKkI4a5F8pTwAeq6ltJTgV2J7mzrftEVf3d9MqTJK1kzQCvqoPAwbb8syT7gDOnXZgkaXXrmgNPsgU4H9jVmt6T5J4kNyTZsMJ7tidZSLJw+PDhsYqVJD1r6ABPcgrwReB9VfUkcC3wm8BWBkfoH1vufVW1o6rmq2p+bm5uAiVLkmDIAE9yEoPw/mxVfQmgqh6vqqer6pfAp4ELp1emJGmpYc5CCXA9sK+qPr6ofdOizd4C7J18eZKklQxzFsorgbcD9ybZ09o+CFyZZCtQwH7gXVOpUJK0rGHOQvkGkGVW3TH5ciRJw/JKTEnqlAEuSZ0ywCWpUwa4JHXKAJekThngktQpA1ySOmWAS1KnDHBJ6pQBLkmdMsAlqVMGuCR1ygCXpE4Z4JLUKQNckjplgEtSpwxwSeqUAS5JnTLAJalTBrgkdcoAl6ROGeCS1CkDXJI6ZYBLUqcMcEnqlAEuSZ0ywCWpU2sGeJKzk9yd5P4k9yV5b2s/PcmdSR5szxumX64k6YhhjsCfAj5QVecBFwPvTnIecDVwV1WdC9zVXkuSjpI1A7yqDlbVt9ryz4B9wJnAm4CdbbOdwJunVaQk6bnWNQeeZAtwPrAL2FhVB9uqx4CNK7xne5KFJAuHDx8eo1RJ0mJDB3iSU4AvAu+rqicXr6uqAmq591XVjqqar6r5ubm5sYqVJD1rqABPchKD8P5sVX2pNT+eZFNbvwk4NJ0SJUnLGeYslADXA/uq6uOLVt0GbGvL24BbJ1+eJGklJw6xzSuBtwP3JtnT2j4IfBT4QpKrgO8DfzSdEiVJy1kzwKvqG0BWWP3ayZYjSRqWV2JKUqcMcEnqlAEuSZ0ywCWpUwa4JHXKAJekThngktQpA1ySOmWAS1KnDHBJ6pQBLkmdMsAlqVMGuCR1ygCXpE4Z4JLUqWH+Q4fnhRt3/WDobd960eYpViJJw/EIXJI6ZYBLUqcMcEnqlAEuSZ0ywCWpUwa4JHXKAJekThngktQpA1ySOmWAS1Kn1gzwJDckOZRk76K2jyR5NMme9rhsumVKkpYa5gj8M8Ablmn/RFVtbY87JluWJGktawZ4VX0d+MlRqEWStA7jzIG/J8k9bYplw8QqkiQNZdTbyV4L/BVQ7fljwDuX2zDJdmA7wObNx8dtWNdz61nw9rOSpmOkI/Cqeryqnq6qXwKfBi5cZdsdVTVfVfNzc3Oj1ilJWmKkAE+yadHLtwB7V9pWkjQda06hJPkc8BrgjCQHgA8Dr0mylcEUyn7gXVOsUZK0jDUDvKquXKb5+inUIklaB6/ElKROGeCS1CkDXJI6ZYBLUqcMcEnqlAEuSZ0ywCWpUwa4JHXKAJekThngktQpA1ySOmWAS1KnDHBJ6pQBLkmdMsAlqVMGuCR1ygCXpE4Z4JLUKQNckjplgEtSpwxwSeqUAS5JnTLAJalTBrgkdcoAl6ROGeCS1CkDXJI6tWaAJ7khyaEkexe1nZ7kziQPtucN0y1TkrTUMEfgnwHesKTtauCuqjoXuKu9liQdRWsGeFV9HfjJkuY3ATvb8k7gzROuS5K0hlHnwDdW1cG2/BiwcaUNk2xPspBk4fDhwyN2J0laauwvMauqgFpl/Y6qmq+q+bm5uXG7kyQ1owb440k2AbTnQ5MrSZI0jFED/DZgW1veBtw6mXIkScMa5jTCzwH/Cbw8yYEkVwEfBV6X5EHg99prSdJRdOJaG1TVlSuseu2Ea5EkrYNXYkpSpwxwSeqUAS5JnTLAJalTBrgkdcoAl6ROGeCS1CkDXJI6ZYBLUqcMcEnqlAEuSZ0ywCWpUwa4JHXKAJekThngktQpA1ySOmWAS1KnDHBJ6pQBLkmdMsAlqVMGuCR1ygCXpE4Z4JLUKQNckjplgEtSpwxwSeqUAS5JnTpxnDcn2Q/8DHgaeKqq5idRlCRpbWMFePO7VfWjCXyOJGkdnEKRpE6NewRewFeTFHBdVe1YukGS7cB2gM2bN4/ZXZ9u3PWDdW3/1ouen39OktZn3CPwV1XVBcAbgXcnefXSDapqR1XNV9X83NzcmN1Jko4YK8Cr6tH2fAi4BbhwEkVJktY2coAneXGSU48sA68H9k6qMEnS6saZA98I3JLkyOfcWFX/PpGqJElrGjnAq+oR4LcnWIskaR08jVCSOmWAS1KnDHBJ6pQBLkmdMsAlqVMGuCR1ygCXpE4Z4JLUqUncD1wTtt67F67Heu906J0UpdnlEbgkdcoAl6ROGeCS1CkDXJI6ZYBLUqcMcEnqlAEuSZ3yPHBpBJ4fr1ngEbgkdcoAl6ROGeCS1CkDXJI6ZYBLUqcMcEnqlKcRPs9M81a1s2Y9Y30+nebn7YonYxZq9whckjplgEtSp8YK8CRvSPJAkoeSXD2poiRJaxs5wJOcAPwT8EbgPODKJOdNqjBJ0urGOQK/EHioqh6pql8ANwFvmkxZkqS1jBPgZwI/XPT6QGuTJB0FUz+NMM
l2YHt7+fMkD4zwMWcAP5pcVcfMcT+Otx3lQsb0zDimXfcUP3+m9qkxxjnUODrYv6b1b+M3lmscJ8AfBc5e9Pqs1vYrqmoHsGOMfkiyUFXz43zGLHAcs+V4GMfxMAZwHKMaZwrlv4Bzk5yT5AXAFcBtkylLkrSWkY/Aq+qpJO8BvgKcANxQVfdNrDJJ0qrGmgOvqjuAOyZUy2rGmoKZIY5jthwP4zgexgCOYySpqqPZnyRpQryUXpI6NfMBPguX6ye5IcmhJHsXtZ2e5M4kD7bnDa09ST7Z6r0nyQWL3rOtbf9gkm2L2n8nyb3tPZ9MktX6GGMcZye5O8n9Se5L8t7expLkhUm+meQ7bQx/2drPSbKr9fv59sU6SU5urx9q67cs+qxrWvsDSX5/Ufuy+9xKfYwjyQlJvp3k9l7HkWR/+zvfk2ShtXWzTy3q57QkNyf5bpJ9SS6Z+XFU1cw+GHw5+jDwMuAFwHeA845BHa8GLgD2Lmr7W+Dqtnw18Ddt+TLgy0CAi4Fdrf104JH2vKEtb2jrvtm2TXvvG1frY4xxbAIuaMunAt9jcBuEbsbSPveUtnwSsKv19wXgitb+KeBP2/KfAZ9qy1cAn2/L57X96WTgnLafnbDaPrdSH2P+nbwfuBG4fbU+ZnkcwH7gjCVt3exTi2reCfxJW34BcNqsj+OoBuEIf6CXAF9Z9Poa4JpjVMsWfjXAHwA2teVNwANt+TrgyqXbAVcC1y1qv661bQK+u6j9me1W6mOCY7oVeF2vYwF+HfgWcBGDiydOXLrfMDhL6pK2fGLbLkv3pSPbrbTPtfcs28cY9Z8F3AVcCty+Wh8zPo79PDfAu9qngJcA/037XrCXccz6FMosX66/saoOtuXHgI1teaWaV2s/sEz7an2Mrf0Kfj6DI9iuxtKmHfYAh4A7GRxpPlFVTy3T7zO1tvU/BV46wtheukofo/p74M+BX7bXq/Uxy+Mo4KtJdmdw5TV0tk8x+O3lMPAvbUrrn5O8eNbHMesB3oUa/Oic6uk8k+wjySnAF4H3VdWT0+pnJeP2UVVPV9VWBkewFwKvmFRtR0uSPwAOVdXuY13LBLyqqi5gcGfSdyd59eKVPexTDH6ruQC4tqrOB/6HwXTGJPtY03r7mPUAH+py/WPk8SSbANrzoda+Us2rtZ+1TPtqfYwsyUkMwvuzVfWlnsdSVU8AdzOYBjgtyZHrGhb3+0ytbf1LgB+PMLYfr9LHKF4J/GGS/Qzu5Hkp8A8djoOqerQ9HwJuYfBDtbd96gBwoKp2tdc3Mwj0mR7HrAf4LF+ufxtw5BvmbQzmk4+0v6N9S30x8NP269FXgNcn2dC+ZX49g7nHg8CTSS5u30q/Y8lnLdfHSNrnXw/sq6qP9ziWJHNJTmvLL2Iwh7+PQZBfvsIYjvR7OfC1dpRzG3BFBmd3nAOcy+BLpmX3ufaelfpYt6q6pqrOqqotrY+vVdXbehtHkhcnOfXIMoN9YS8d7VMAVfUY8MMkL29NrwXun/lxjPPlxdF4MPi293sM5jk/dIxq+BxwEPg/Bj+pr2Iwl3gX8CDwH8Dpbdsw+I8uHgbuBeYXfc47gYfa448Xtc8z2OkfBv6RZy+wWraPMcbxKga/nt0D7GmPy3oaC/BbwLfbGPYCf9HaX8YguB4C/hU4ubW/sL1+qK1/2aLP+lCr8wHaGQGr7XMr9TGB/es1PHsWSlfjaJ/1nfa470g/Pe1Ti/rZCiy0fevfGJxFMtPj8EpMSerUrE+hSJJWYIBLUqcMcEnqlAEuSZ0ywCWpUwa4JHXKAJekThngktSp/wdRBFKWSSQcDgAAAABJRU5ErkJggg==\n"
+      "image/svg+xml": "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n  \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n<!-- Created with matplotlib (https://matplotlib.org/) -->\n<svg height=\"248.518125pt\" version=\"1.1\" viewBox=\"0 0 368.925 248.518125\" width=\"368.925pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n <defs>\n  <style type=\"text/css\">\n*{stroke-linecap:butt;stroke-linejoin:round;}\n  </style>\n </defs>\n <g id=\"figure_1\">\n  <g id=\"patch_1\">\n   <path d=\"M -0 248.518125 \nL 368.925 248.518125 \nL 368.925 0 \nL -0 0 \nz\n\" style=\"fill:none;\"/>\n  </g>\n  <g id=\"axes_1\">\n   <g id=\"patch_2\">\n    <path d=\"M 26.925 224.64 \nL 361.725 224.64 \nL 361.725 7.2 \nL 26.925 7.2 \nz\n\" style=\"fill:#ffffff;\"/>\n   </g>\n   <g id=\"patch_3\">\n    <path clip-path=\"url(#p45ea857fc5)\" d=\"M 42.143182 224.64 \nL 54.317727 224.64 \nL 54.317727 17.554286 \nL 42.143182 17.554286 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_4\">\n    <path clip-path=\"url(#p45ea857fc5)\" d=\"M 54.317727 224.64 \nL 66.492273 224.64 \nL 66.492273 112.701776 \nL 54.317727 112.701776 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_5\">\n    <path clip-path=\"url(#p45ea857fc5)\" d=\"M 66.492273 224.64 \nL 78.666818 224.64 \nL 78.666818 135.089421 \nL 66.492273 135.089421 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_6\">\n    <path clip-path=\"url(#p45ea857fc5)\" d=\"M 78.666818 224.64 \nL 90.841364 224.64 \nL 90.841364 179.86471 \nL 78.666818 179.86471 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_7\">\n    <path clip-path=\"url(#p45ea857fc5)\" d=\"M 90.841364 224.64 \nL 103.015909 224.64 \nL 103.015909 202.252355 \nL 90.841364 202.252355 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_8\">\n    <path clip-path=\"url(#p45ea857fc5)\" d=\"M 103.015909 224.64 \nL 115.190455 224.64 \nL 115.190455 219.043089 \nL 103.015909 219.043089 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_9\">\n    <path clip-path=\"url(#p45ea857fc5)\" d=\"M 115.190455 224.64 \nL 127.365 224.64 \nL 127.365 224.64 \nL 115.190455 224.64 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_10\">\n    <path clip-path=\"url(#p45ea857fc5)\" d=\"M 127.365 224.64 \nL 139.539545 224.64 \nL 139.539545 219.043089 \nL 127.365 219.043089 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_11\">\n    <path clip-path=\"url(#p45ea857fc5)\" d=\"M 139.539545 224.64 \nL 151.714091 224.64 \nL 151.714091 224.64 \nL 139.539545 224.64 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_12\">\n    <path clip-path=\"url(#p45ea857fc5)\" d=\"M 151.714091 224.64 \nL 163.888636 224.64 \nL 163.888636 207.849266 \nL 151.714091 207.849266 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_13\">\n    <path clip-path=\"url(#p45ea857fc5)\" d=\"M 163.888636 224.64 \nL 176.063182 224.64 \nL 176.063182 213.446178 \nL 163.888636 213.446178 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_14\">\n    <path clip-path=\"url(#p45ea857fc5)\" d=\"M 176.063182 224.64 \nL 188.237727 224.64 \nL 188.237727 219.043089 \nL 176.063182 219.043089 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_15\">\n    <path clip-path=\"url(#p45ea857fc5)\" d=\"M 188.237727 224.64 \nL 
200.412273 224.64 \nL 200.412273 224.64 \nL 188.237727 224.64 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_16\">\n    <path clip-path=\"url(#p45ea857fc5)\" d=\"M 200.412273 224.64 \nL 212.586818 224.64 \nL 212.586818 219.043089 \nL 200.412273 219.043089 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_17\">\n    <path clip-path=\"url(#p45ea857fc5)\" d=\"M 212.586818 224.64 \nL 224.761364 224.64 \nL 224.761364 219.043089 \nL 212.586818 219.043089 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_18\">\n    <path clip-path=\"url(#p45ea857fc5)\" d=\"M 224.761364 224.64 \nL 236.935909 224.64 \nL 236.935909 224.64 \nL 224.761364 224.64 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_19\">\n    <path clip-path=\"url(#p45ea857fc5)\" d=\"M 236.935909 224.64 \nL 249.110455 224.64 \nL 249.110455 224.64 \nL 236.935909 224.64 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_20\">\n    <path clip-path=\"url(#p45ea857fc5)\" d=\"M 249.110455 224.64 \nL 261.285 224.64 \nL 261.285 224.64 \nL 249.110455 224.64 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_21\">\n    <path clip-path=\"url(#p45ea857fc5)\" d=\"M 261.285 224.64 \nL 273.459545 224.64 \nL 273.459545 224.64 \nL 261.285 224.64 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_22\">\n    <path clip-path=\"url(#p45ea857fc5)\" d=\"M 273.459545 224.64 \nL 285.634091 224.64 \nL 285.634091 219.043089 \nL 273.459545 219.043089 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_23\">\n    <path clip-path=\"url(#p45ea857fc5)\" d=\"M 285.634091 224.64 \nL 297.808636 224.64 \nL 297.808636 224.64 \nL 285.634091 224.64 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_24\">\n    <path clip-path=\"url(#p45ea857fc5)\" d=\"M 297.808636 224.64 \nL 309.983182 224.64 \nL 309.983182 213.446178 \nL 297.808636 213.446178 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_25\">\n    <path clip-path=\"url(#p45ea857fc5)\" d=\"M 309.983182 224.64 \nL 322.157727 224.64 \nL 322.157727 224.64 \nL 309.983182 224.64 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_26\">\n    <path clip-path=\"url(#p45ea857fc5)\" d=\"M 322.157727 224.64 \nL 334.332273 224.64 \nL 334.332273 224.64 \nL 322.157727 224.64 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_27\">\n    <path clip-path=\"url(#p45ea857fc5)\" d=\"M 334.332273 224.64 \nL 346.506818 224.64 \nL 346.506818 219.043089 \nL 334.332273 219.043089 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"matplotlib.axis_1\">\n    <g id=\"xtick_1\">\n     <g id=\"line2d_1\">\n      <defs>\n       <path d=\"M 0 0 \nL 0 3.5 \n\" id=\"m2d56e675c0\" style=\"stroke:#000000;stroke-width:0.8;\"/>\n      </defs>\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"41.134974\" xlink:href=\"#m2d56e675c0\" y=\"224.64\"/>\n      </g>\n     </g>\n     <g id=\"text_1\">\n      <!-- 0 -->\n      <defs>\n       <path d=\"M 31.78125 66.40625 \nQ 24.171875 66.40625 20.328125 58.90625 \nQ 16.5 51.421875 16.5 36.375 \nQ 16.5 21.390625 20.328125 13.890625 \nQ 24.171875 6.390625 31.78125 6.390625 \nQ 39.453125 6.390625 43.28125 13.890625 \nQ 47.125 21.390625 47.125 36.375 \nQ 47.125 51.421875 43.28125 58.90625 \nQ 39.453125 66.40625 31.78125 66.40625 \nz\nM 31.78125 74.21875 \nQ 44.046875 74.21875 50.515625 64.515625 \nQ 56.984375 54.828125 
56.984375 36.375 \nQ 56.984375 17.96875 50.515625 8.265625 \nQ 44.046875 -1.421875 31.78125 -1.421875 \nQ 19.53125 -1.421875 13.0625 8.265625 \nQ 6.59375 17.96875 6.59375 36.375 \nQ 6.59375 54.828125 13.0625 64.515625 \nQ 19.53125 74.21875 31.78125 74.21875 \nz\n\" id=\"DejaVuSans-48\"/>\n      </defs>\n      <g transform=\"translate(37.953724 239.238437)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-48\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"xtick_2\">\n     <g id=\"line2d_2\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"114.140576\" xlink:href=\"#m2d56e675c0\" y=\"224.64\"/>\n      </g>\n     </g>\n     <g id=\"text_2\">\n      <!-- 200000 -->\n      <defs>\n       <path d=\"M 19.1875 8.296875 \nL 53.609375 8.296875 \nL 53.609375 0 \nL 7.328125 0 \nL 7.328125 8.296875 \nQ 12.9375 14.109375 22.625 23.890625 \nQ 32.328125 33.6875 34.8125 36.53125 \nQ 39.546875 41.84375 41.421875 45.53125 \nQ 43.3125 49.21875 43.3125 52.78125 \nQ 43.3125 58.59375 39.234375 62.25 \nQ 35.15625 65.921875 28.609375 65.921875 \nQ 23.96875 65.921875 18.8125 64.3125 \nQ 13.671875 62.703125 7.8125 59.421875 \nL 7.8125 69.390625 \nQ 13.765625 71.78125 18.9375 73 \nQ 24.125 74.21875 28.421875 74.21875 \nQ 39.75 74.21875 46.484375 68.546875 \nQ 53.21875 62.890625 53.21875 53.421875 \nQ 53.21875 48.921875 51.53125 44.890625 \nQ 49.859375 40.875 45.40625 35.40625 \nQ 44.1875 33.984375 37.640625 27.21875 \nQ 31.109375 20.453125 19.1875 8.296875 \nz\n\" id=\"DejaVuSans-50\"/>\n      </defs>\n      <g transform=\"translate(95.053076 239.238437)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-50\"/>\n       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"127.246094\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"190.869141\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"254.492188\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"318.115234\" xlink:href=\"#DejaVuSans-48\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"xtick_3\">\n     <g id=\"line2d_3\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"187.146177\" xlink:href=\"#m2d56e675c0\" y=\"224.64\"/>\n      </g>\n     </g>\n     <g id=\"text_3\">\n      <!-- 400000 -->\n      <defs>\n       <path d=\"M 37.796875 64.3125 \nL 12.890625 25.390625 \nL 37.796875 25.390625 \nz\nM 35.203125 72.90625 \nL 47.609375 72.90625 \nL 47.609375 25.390625 \nL 58.015625 25.390625 \nL 58.015625 17.1875 \nL 47.609375 17.1875 \nL 47.609375 0 \nL 37.796875 0 \nL 37.796875 17.1875 \nL 4.890625 17.1875 \nL 4.890625 26.703125 \nz\n\" id=\"DejaVuSans-52\"/>\n      </defs>\n      <g transform=\"translate(168.058677 239.238437)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-52\"/>\n       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"127.246094\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"190.869141\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"254.492188\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"318.115234\" xlink:href=\"#DejaVuSans-48\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"xtick_4\">\n     <g id=\"line2d_4\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"260.151778\" xlink:href=\"#m2d56e675c0\" y=\"224.64\"/>\n      </g>\n     </g>\n     <g id=\"text_4\">\n      <!-- 600000 -->\n      <defs>\n       <path d=\"M 33.015625 40.375 \nQ 26.375 40.375 22.484375 35.828125 \nQ 18.609375 31.296875 18.609375 23.390625 \nQ 18.609375 15.53125 22.484375 10.953125 \nQ 26.375 6.390625 33.015625 6.390625 \nQ 
39.65625 6.390625 43.53125 10.953125 \nQ 47.40625 15.53125 47.40625 23.390625 \nQ 47.40625 31.296875 43.53125 35.828125 \nQ 39.65625 40.375 33.015625 40.375 \nz\nM 52.59375 71.296875 \nL 52.59375 62.3125 \nQ 48.875 64.0625 45.09375 64.984375 \nQ 41.3125 65.921875 37.59375 65.921875 \nQ 27.828125 65.921875 22.671875 59.328125 \nQ 17.53125 52.734375 16.796875 39.40625 \nQ 19.671875 43.65625 24.015625 45.921875 \nQ 28.375 48.1875 33.59375 48.1875 \nQ 44.578125 48.1875 50.953125 41.515625 \nQ 57.328125 34.859375 57.328125 23.390625 \nQ 57.328125 12.15625 50.6875 5.359375 \nQ 44.046875 -1.421875 33.015625 -1.421875 \nQ 20.359375 -1.421875 13.671875 8.265625 \nQ 6.984375 17.96875 6.984375 36.375 \nQ 6.984375 53.65625 15.1875 63.9375 \nQ 23.390625 74.21875 37.203125 74.21875 \nQ 40.921875 74.21875 44.703125 73.484375 \nQ 48.484375 72.75 52.59375 71.296875 \nz\n\" id=\"DejaVuSans-54\"/>\n      </defs>\n      <g transform=\"translate(241.064278 239.238437)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-54\"/>\n       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"127.246094\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"190.869141\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"254.492188\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"318.115234\" xlink:href=\"#DejaVuSans-48\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"xtick_5\">\n     <g id=\"line2d_5\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"333.157379\" xlink:href=\"#m2d56e675c0\" y=\"224.64\"/>\n      </g>\n     </g>\n     <g id=\"text_5\">\n      <!-- 800000 -->\n      <defs>\n       <path d=\"M 31.78125 34.625 \nQ 24.75 34.625 20.71875 30.859375 \nQ 16.703125 27.09375 16.703125 20.515625 \nQ 16.703125 13.921875 20.71875 10.15625 \nQ 24.75 6.390625 31.78125 6.390625 \nQ 38.8125 6.390625 42.859375 10.171875 \nQ 46.921875 13.96875 46.921875 20.515625 \nQ 46.921875 27.09375 42.890625 30.859375 \nQ 38.875 34.625 31.78125 34.625 \nz\nM 21.921875 38.8125 \nQ 15.578125 40.375 12.03125 44.71875 \nQ 8.5 49.078125 8.5 55.328125 \nQ 8.5 64.0625 14.71875 69.140625 \nQ 20.953125 74.21875 31.78125 74.21875 \nQ 42.671875 74.21875 48.875 69.140625 \nQ 55.078125 64.0625 55.078125 55.328125 \nQ 55.078125 49.078125 51.53125 44.71875 \nQ 48 40.375 41.703125 38.8125 \nQ 48.828125 37.15625 52.796875 32.3125 \nQ 56.78125 27.484375 56.78125 20.515625 \nQ 56.78125 9.90625 50.3125 4.234375 \nQ 43.84375 -1.421875 31.78125 -1.421875 \nQ 19.734375 -1.421875 13.25 4.234375 \nQ 6.78125 9.90625 6.78125 20.515625 \nQ 6.78125 27.484375 10.78125 32.3125 \nQ 14.796875 37.15625 21.921875 38.8125 \nz\nM 18.3125 54.390625 \nQ 18.3125 48.734375 21.84375 45.5625 \nQ 25.390625 42.390625 31.78125 42.390625 \nQ 38.140625 42.390625 41.71875 45.5625 \nQ 45.3125 48.734375 45.3125 54.390625 \nQ 45.3125 60.0625 41.71875 63.234375 \nQ 38.140625 66.40625 31.78125 66.40625 \nQ 25.390625 66.40625 21.84375 63.234375 \nQ 18.3125 60.0625 18.3125 54.390625 \nz\n\" id=\"DejaVuSans-56\"/>\n      </defs>\n      <g transform=\"translate(314.069879 239.238437)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-56\"/>\n       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"127.246094\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"190.869141\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"254.492188\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"318.115234\" xlink:href=\"#DejaVuSans-48\"/>\n      </g>\n     </g>\n    </g>\n   </g>\n   <g id=\"matplotlib.axis_2\">\n    <g id=\"ytick_1\">\n   
  <g id=\"line2d_6\">\n      <defs>\n       <path d=\"M 0 0 \nL -3.5 0 \n\" id=\"me7cbd43d73\" style=\"stroke:#000000;stroke-width:0.8;\"/>\n      </defs>\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"26.925\" xlink:href=\"#me7cbd43d73\" y=\"224.64\"/>\n      </g>\n     </g>\n     <g id=\"text_6\">\n      <!-- 0 -->\n      <g transform=\"translate(13.5625 228.439219)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-48\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"ytick_2\">\n     <g id=\"line2d_7\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"26.925\" xlink:href=\"#me7cbd43d73\" y=\"196.655444\"/>\n      </g>\n     </g>\n     <g id=\"text_7\">\n      <!-- 5 -->\n      <defs>\n       <path d=\"M 10.796875 72.90625 \nL 49.515625 72.90625 \nL 49.515625 64.59375 \nL 19.828125 64.59375 \nL 19.828125 46.734375 \nQ 21.96875 47.46875 24.109375 47.828125 \nQ 26.265625 48.1875 28.421875 48.1875 \nQ 40.625 48.1875 47.75 41.5 \nQ 54.890625 34.8125 54.890625 23.390625 \nQ 54.890625 11.625 47.5625 5.09375 \nQ 40.234375 -1.421875 26.90625 -1.421875 \nQ 22.3125 -1.421875 17.546875 -0.640625 \nQ 12.796875 0.140625 7.71875 1.703125 \nL 7.71875 11.625 \nQ 12.109375 9.234375 16.796875 8.0625 \nQ 21.484375 6.890625 26.703125 6.890625 \nQ 35.15625 6.890625 40.078125 11.328125 \nQ 45.015625 15.765625 45.015625 23.390625 \nQ 45.015625 31 40.078125 35.4375 \nQ 35.15625 39.890625 26.703125 39.890625 \nQ 22.75 39.890625 18.8125 39.015625 \nQ 14.890625 38.140625 10.796875 36.28125 \nz\n\" id=\"DejaVuSans-53\"/>\n      </defs>\n      <g transform=\"translate(13.5625 200.454663)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-53\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"ytick_3\">\n     <g id=\"line2d_8\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"26.925\" xlink:href=\"#me7cbd43d73\" y=\"168.670888\"/>\n      </g>\n     </g>\n     <g id=\"text_8\">\n      <!-- 10 -->\n      <defs>\n       <path d=\"M 12.40625 8.296875 \nL 28.515625 8.296875 \nL 28.515625 63.921875 \nL 10.984375 60.40625 \nL 10.984375 69.390625 \nL 28.421875 72.90625 \nL 38.28125 72.90625 \nL 38.28125 8.296875 \nL 54.390625 8.296875 \nL 54.390625 0 \nL 12.40625 0 \nz\n\" id=\"DejaVuSans-49\"/>\n      </defs>\n      <g transform=\"translate(7.2 172.470107)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-49\"/>\n       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-48\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"ytick_4\">\n     <g id=\"line2d_9\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"26.925\" xlink:href=\"#me7cbd43d73\" y=\"140.686332\"/>\n      </g>\n     </g>\n     <g id=\"text_9\">\n      <!-- 15 -->\n      <g transform=\"translate(7.2 144.485551)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-49\"/>\n       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-53\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"ytick_5\">\n     <g id=\"line2d_10\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"26.925\" xlink:href=\"#me7cbd43d73\" y=\"112.701776\"/>\n      </g>\n     </g>\n     <g id=\"text_10\">\n      <!-- 20 -->\n      <g transform=\"translate(7.2 116.500995)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-50\"/>\n       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-48\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"ytick_6\">\n     <g id=\"line2d_11\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"26.925\" xlink:href=\"#me7cbd43d73\" 
y=\"84.71722\"/>\n      </g>\n     </g>\n     <g id=\"text_11\">\n      <!-- 25 -->\n      <g transform=\"translate(7.2 88.516439)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-50\"/>\n       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-53\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"ytick_7\">\n     <g id=\"line2d_12\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"26.925\" xlink:href=\"#me7cbd43d73\" y=\"56.732664\"/>\n      </g>\n     </g>\n     <g id=\"text_12\">\n      <!-- 30 -->\n      <defs>\n       <path d=\"M 40.578125 39.3125 \nQ 47.65625 37.796875 51.625 33 \nQ 55.609375 28.21875 55.609375 21.1875 \nQ 55.609375 10.40625 48.1875 4.484375 \nQ 40.765625 -1.421875 27.09375 -1.421875 \nQ 22.515625 -1.421875 17.65625 -0.515625 \nQ 12.796875 0.390625 7.625 2.203125 \nL 7.625 11.71875 \nQ 11.71875 9.328125 16.59375 8.109375 \nQ 21.484375 6.890625 26.8125 6.890625 \nQ 36.078125 6.890625 40.9375 10.546875 \nQ 45.796875 14.203125 45.796875 21.1875 \nQ 45.796875 27.640625 41.28125 31.265625 \nQ 36.765625 34.90625 28.71875 34.90625 \nL 20.21875 34.90625 \nL 20.21875 43.015625 \nL 29.109375 43.015625 \nQ 36.375 43.015625 40.234375 45.921875 \nQ 44.09375 48.828125 44.09375 54.296875 \nQ 44.09375 59.90625 40.109375 62.90625 \nQ 36.140625 65.921875 28.71875 65.921875 \nQ 24.65625 65.921875 20.015625 65.03125 \nQ 15.375 64.15625 9.8125 62.3125 \nL 9.8125 71.09375 \nQ 15.4375 72.65625 20.34375 73.4375 \nQ 25.25 74.21875 29.59375 74.21875 \nQ 40.828125 74.21875 47.359375 69.109375 \nQ 53.90625 64.015625 53.90625 55.328125 \nQ 53.90625 49.265625 50.4375 45.09375 \nQ 46.96875 40.921875 40.578125 39.3125 \nz\n\" id=\"DejaVuSans-51\"/>\n      </defs>\n      <g transform=\"translate(7.2 60.531883)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-51\"/>\n       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-48\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"ytick_8\">\n     <g id=\"line2d_13\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"26.925\" xlink:href=\"#me7cbd43d73\" y=\"28.748108\"/>\n      </g>\n     </g>\n     <g id=\"text_13\">\n      <!-- 35 -->\n      <g transform=\"translate(7.2 32.547327)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-51\"/>\n       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-53\"/>\n      </g>\n     </g>\n    </g>\n   </g>\n   <g id=\"patch_28\">\n    <path d=\"M 26.925 224.64 \nL 26.925 7.2 \n\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\n   </g>\n   <g id=\"patch_29\">\n    <path d=\"M 361.725 224.64 \nL 361.725 7.2 \n\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\n   </g>\n   <g id=\"patch_30\">\n    <path d=\"M 26.925 224.64 \nL 361.725 224.64 \n\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\n   </g>\n   <g id=\"patch_31\">\n    <path d=\"M 26.925 7.2 \nL 361.725 7.2 \n\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\n   </g>\n  </g>\n </g>\n <defs>\n  <clipPath id=\"p45ea857fc5\">\n   <rect height=\"217.44\" width=\"334.8\" x=\"26.925\" y=\"7.2\"/>\n  </clipPath>\n </defs>\n</svg>\n",
+      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXAAAAD4CAYAAAD1jb0+AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAPaklEQVR4nO3df4xlZX3H8fenu/ywamQpE7LlRxct0ZAmLnTKj2gai6JImooJacBGt61mbSuJtqYt6B/VpE20UWlNG3Ut6MaIShELIVpLkcSYNGsHXWH5VX6ICFnYQYto/7AFv/3jPgPDODP3zsy9u/eB9ys5mXOec+59vvfZZz9799xz5qaqkCT15xcOdQGSpPUxwCWpUwa4JHXKAJekThngktSpzQezs2OOOaa2bdt2MLuUpO7dfPPNj1bVzNL2gxrg27ZtY25u7mB2KUndS/K95do9hSJJnTLAJalTBrgkdcoAl6ROGeCS1CkDXJI6ZYBLUqcMcEnqlAEuSZ06qHdibsSVex5Y0/FvOuPECVUiSdPBd+CS1CkDXJI6ZYBLUqcMcEnqlAEuSZ0ywCWpUwa4JHVqaIAnOTLJN5N8J8ltSd7f2j+d5LtJ9rZl++TLlSQtGOVGnp8CZ1fVT5IcBnwjyVfavj+vqqsnV54kaSVDA7yqCvhJ2zysLTXJoiRJw410DjzJpiR7gQPADVW1p+36myS3JLksyRErPHZnkrkkc/Pz82MqW5I0UoBX1ZNVtR04Hjg9ya8BlwIvA34DOBr4yxUeu6uqZqtqdmZmZkxlS5LWdBVKVT0G3AScW1X7a+CnwKeA0ydRoCRpeaNchTKT5Ki2/jzgHODOJFtbW4DzgX2TLFSS9EyjXIWyFdidZBODwL+qqq5P8rUkM0CAvcAfTbBOSdISo1yFcgtw6jLtZ0+kIknSSLwTU5I6ZYBLUqcMcEnqlAEuSZ0ywCWpUwa4JHXKAJekThngktQpA1ySOmWAS1KnDHBJ6pQBLkmdMsAlqVMGuCR1ygCXpE4Z4JLUKQNckjplgEtSpwxwSerUKN9Kf2SSbyb5TpLbkry/tZ+UZE+Se5J8Icnhky9XkrRglHfgPwXOrqqXA9uBc5OcCXwQuKyqfhX4b+CtkytTkrTU0ACvgZ+0zcPaUsDZwNWtfTdw/kQqlCQta6Rz4Ek2JdkLHABuAO4FHquqJ9ohDwLHrfDYnUnmkszNz8+Po2ZJEiMGeFU9WVXbgeOB04GXjdpBVe2qqtmqmp2ZmVlnmZKkpdZ0FUpVPQbcBJwFHJVkc9t1PPDQmGuTJK1ilKtQZpIc1dafB5wD3MEgyC9oh+0Arp1UkZKkn7d5+CFsBXYn2cQg8K+qquuT3A58PslfA98GLp9gnZKkJYYGeFXdApy6TPt9DM6HS5IOAe/ElKROGeCS1CkDXJI6ZYBLUqcMcEnqlAEuSZ0ywCWpUwa4JHXKAJekThngktQpA1ySOmWAS1KnDHBJ6pQBLkmdMsAlqVMGuCR1ygCXpE4Z4JLUKQNckjo1yrfSn5DkpiS3J7ktyTtb+/uSPJRkb1vOm3y5kqQFo3wr/RPAu6vqW0leCNyc5Ia277Kq+tDkypMkrWSUb6XfD+xv6z9Ocgdw3KQLkyStbk3nwJNsA04F9rSmi5PckuSKJFtWeMzOJHNJ5ubn5zdUrCTpaSMHeJIXAF8E3lVVjwMfA14CbGfwDv3Dyz2uqnZV1WxVzc7MzIyhZEkSjBjgSQ5jEN6fraprAKrqkap6sqp+BnwSOH1yZUqSlhrlKpQAlwN3VNVHFrVvXXTYG4F94y9PkrSSUa5CeQXwZuDWJHtb23uAi5JsBwq4H3j7RCqUJC1rlKtQvgFkmV1fHn85kqRReSemJHXKAJekThngktQpA1ySOmWAS1KnDHBJ6pQBLkmdMsAlqVMGuCR1ygCXpE4Z4JLUKQNckjplgEtSpwxwSeqUAS5JnTLAJalTBrgkdcoAl6ROGeCS1KlRvpX+hCQ3Jbk9yW1J3tnaj05yQ5K7288tky9XkrRglHfgTwDvrqpTgDOBdyQ5BbgEuLGqTgZubNuSpINkaIBX1f6q+lZb/zFwB3Ac8AZgdztsN3D+pIqUJP28zWs5OMk24FRgD3BsVe1vux4Gjl3hMTuBnQAnnnjieutcsyv3PLDmx7zpjINXnyRt1MgfYiZ5AfBF4F1V9fjifVVVQC33uKraVVWzVTU7MzOzoWIlSU8bKcCTHMYgvD9bVde05keSbG37twIHJlOiJGk5o1yFEuBy4I6q+siiXdcBO9r6DuDa8ZcnSVrJKOfAXwG8Gbg1yd7W9h7gA8BVSd4KfA/43cmUKElaztAAr6pvAFlh96vHW44kaVTeiSlJnTLAJalTBrgkdcoAl6ROGeCS1CkDXJI6ZYBLUqcMcEnqlAEuSZ0ywCWpUwa4JHXKAJekThngktQpA1ySOmWAS1KnDHBJ6tSavpX+2W6t32Tvt9hLOpR8By5JnTLAJalTo3wr/RVJDiTZt6jtfUkeSrK3LedNtkxJ0lKjvAP/NHDuMu2XVdX2tnx5vGVJkoYZGuBV9XXghwehFknSGmzkHPjFSW5pp1i2rHRQkp1J5pLMzc/Pb6A7SdJi6w3wjwEvAbYD+4EPr3RgVe2qqtmqmp2ZmVlnd5KkpdYV4FX1SFU9WVU/Az4JnD7esiRJw6wrwJNsXbT5RmDfSsdKkiZj6J2YST4HvAo4JsmDwF8Br0qyHSjgfuDtE6xRkrSMoQFeVRct03z5BGqRJK2Bd2JKUqcMcEnqlAEuSZ0ywCWpUwa4JHXKAJekThngktQpA1ySOmWAS1KnDHBJ6pQBLkmdMsAlqVMGuCR1ygCXpE4Z4JLUKQNckjplgEtSpwxwSeqUAS5JnRoa4EmuSHIgyb5FbUcnuSHJ3e3nlsmWKUlaapR34J8Gzl3SdglwY1WdDNzYtiVJB9HQAK+qrwM/XNL8BmB3W98NnD/muiRJQ6z3HPixVbW/rT8MHLvSgUl2JplLMjc/P7/O7iRJS234Q8yqKqBW2b+rqmaranZmZmaj3UmSmvUG+CNJtgK0nwfGV5IkaRTrDfDrgB1tfQdw7XjKkSSNapTLCD8H/Afw0iQPJnkr8AHgnCR3A69p25Kkg2jzsAOq6qIVdr16zLVIktbAOzElqVMGuCR1ygCXpE4Z4JLUKQNckjo19CoUrezKPQ+s6fg3nXHihCqR9FzkO3BJ6pQBLkmdMsAlqVMGuCR1ygCXpE4Z4JLUKQNckjplgEtSpwxwSeqUAS5JnTLAJalTBrgkdcoAl6RObei3ESa5H/gx8CTwRFXNjqMoSdJw4/h1sr9VVY+O4XkkSWvgKRRJ6tRGA7yAf0tyc5Kdyx2QZGeSuSRz8/PzG+xOkrRgowH+yqo6DXg98I4kv7n0gKraVVWzVTU7MzOzwe4kSQs2FOBV9VD7eQD4EnD6OIqSJA237gBP8vwkL1xYB14L7BtXYZKk1W3kKpRjgS8lWXieK6vqX8dSlSRpqHUHeFXdB7x8jL
U86/kt9pLGycsIJalTBrgkdcoAl6ROGeCS1CkDXJI6NY5fZiWtyCtvpMnxHbgkdcoAl6ROGeCS1CkDXJI6ZYBLUqe8CkVTxatWtBznxfJ8By5JnTLAJalTBrgkdcoAl6RO+SHmFJv0Bzdrff719DFp0zhGazVtY7pWz4Z5tFbT8pp9By5JnTLAJalTGwrwJOcmuSvJPUkuGVdRkqTh1h3gSTYB/wi8HjgFuCjJKeMqTJK0uo28Az8duKeq7quq/wU+D7xhPGVJkoZJVa3vgckFwLlV9ba2/WbgjKq6eMlxO4GdbfOlwF3r6O4Y4NF1Ffrc4RgN5xgN5xgNdyjG6FeqamZp48QvI6yqXcCujTxHkrmqmh1TSc9KjtFwjtFwjtFw0zRGGzmF8hBwwqLt41ubJOkg2EiA/ydwcpKTkhwOXAhcN56yJEnDrPsUSlU9keRi4KvAJuCKqrptbJU904ZOwTxHOEbDOUbDOUbDTc0YrftDTEnSoeWdmJLUKQNckjo19QH+bL9dP8kJSW5KcnuS25K8s7UfneSGJHe3n1tae5J8tI3HLUlOW/RcO9rxdyfZsaj915Pc2h7z0SRZrY9plWRTkm8nub5tn5RkT3tdX2gfppPkiLZ9T9u/bdFzXNra70ryukXty86zlfqYRkmOSnJ1kjuT3JHkLOfRMyX50/b3bF+SzyU5sut5VFVTuzD4cPRe4MXA4cB3gFMOdV1jfo1bgdPa+guB/2Lwqwn+FriktV8CfLCtnwd8BQhwJrCntR8N3Nd+bmnrW9q+b7Zj0x77+ta+bB/TugB/BlwJXN+2rwIubOsfB/64rf8J8PG2fiHwhbZ+SptDRwAntbm1abV5tlIf07gAu4G3tfXDgaOcR88Yn+OA7wLPW/Rn+/s9z6NDPqhDBvws4KuLti8FLj3UdU34NV8LnMPgjtWtrW0rcFdb/wRw0aLj72r7LwI+saj9E61tK3Dnovanjlupj2lcGNxncCNwNnB9C5FHgc1L5wqDK6POauub23FZOn8Wjltpnq3Wx7QtwItaOGVJu/Po6ZqPA77P4B+nzW0eva7neTTtp1AWBnzBg63tWan9F+1UYA9wbFXtb7seBo5t6yuNyWrtDy7Tzip9TKO/A/4C+Fnb/iXgsap6om0vfl1PjUXb/6N2/FrHbrU+ps1JwDzwqXaa6Z+SPB/n0VOq6iHgQ8ADwH4G8+JmOp5H0x7gzxlJXgB8EXhXVT2+eF8N/tme6PWeB6OP9Ury28CBqrr5UNcyxTYDpwEfq6pTgf9hcDrjKc6jbGHwC/dOAn4ZeD5w7iEtaoOmPcCfE7frJzmMQXh/tqquac2PJNna9m8FDrT2lcZktfbjl2lfrY9p8wrgd5Lcz+C3Xp4N/D1wVJKFm9EWv66nxqLtfxHwA9Y+dj9YpY9p8yDwYFXtadtXMwh059HTXgN8t6rmq+r/gGsYzK1u59G0B/iz/nb99kn+5cAdVfWRRbuuAxauANjB4Nz4Qvtb2lUEZwI/av99/Srw2iRb2juN1zI4z7YfeDzJma2vtyx5ruX6mCpVdWlVHV9V2xjMga9V1e8BNwEXtMOWjtHC67qgHV+t/cJ2dcFJwMkMPphbdp61x6zUx1SpqoeB7yd5aWt6NXA7zqPFHgDOTPKL7TUsjFG/8+hQf7AwwgcP5zG4MuNe4L2Hup4JvL5XMvgv5y3A3racx+C82Y3A3cC/A0e348PgizTuBW4FZhc91x8C97TlDxa1zwL72mP+gafvwF22j2legFfx9FUoL25/ce4B/hk4orUf2bbvaftfvOjx723jcBftKorV5tlKfUzjAmwH5tpc+hcGV5E4j545Ru8H7myv4zMMriTpdh55K70kdWraT6FIklZggEtSpwxwSeqUAS5JnTLAJalTBrgkdcoAl6RO/T/sboSrn/bcSgAAAABJRU5ErkJggg==\n"
      },
      "metadata": {
       "needs_background": "light"
@@ -137,11 +137,11 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": 3
+   "version": "3.8.2-final"
   },
   "orig_nbformat": 2,
   "kernelspec": {
-   "name": "python_defaultSpec_1594306004152",
+   "name": "python38264bita7d7da14168440cb9836372958035d4a",
    "display_name": "Python 3.8.2 64-bit"
   }
  },
diff --git a/dataset_generation/notebook_simple.ipynb b/dataset_generation/notebook_simple.ipynb
index 3de5576..9b8a447 100644
--- a/dataset_generation/notebook_simple.ipynb
+++ b/dataset_generation/notebook_simple.ipynb
@@ -66,15 +66,15 @@
     "num_exported = 0\n",
     "document_lens = []\n",
     "\n",
-    "if not os.path.exists(\"dataset_simple\"):\n",
-    "    os.mkdir(\"dataset_simple\")\n",
+    "if not os.path.exists(\"../dataset_simple\"):\n",
+    "    os.mkdir(\"../dataset_simple\")\n",
     "\n",
     "for file_path in tqdm(files_subset):\n",
     "    full_text = text_from_xml(file_path)\n",
     "    \n",
     "    if len(full_text) > 0:\n",
-    "        output_file_input = f\"dataset_simple/{hashlib.md5(file_path.encode()).hexdigest()}_input.txt\"\n",
-    "        output_file_output = f\"dataset_simple/{hashlib.md5(file_path.encode()).hexdigest()}_output.txt\"\n",
+    "        output_file_input = f\"../dataset_simple/{hashlib.md5(file_path.encode()).hexdigest()}_input.txt\"\n",
+    "        output_file_output = f\"../dataset_simple/{hashlib.md5(file_path.encode()).hexdigest()}_output.txt\"\n",
     "\n",
     "        with open(output_file_input, \"w\") as f:\n",
     "            f.write(full_text)\n",
-- 
GitLab


From 26c4db7783822be545376074e9decf79a841e539 Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Sat, 11 Jul 2020 00:03:11 +0200
Subject: [PATCH 004/116] Added tokenizing & batching

---
 dataset_generation/processing.py      | 99 ++++++++++++++++++++++++++-
 dataset_generation/test_processing.py | 77 +++++++++++++++++++++
 2 files changed, 175 insertions(+), 1 deletion(-)
 create mode 100644 dataset_generation/test_processing.py

diff --git a/dataset_generation/processing.py b/dataset_generation/processing.py
index 2ab5359..3ec875c 100644
--- a/dataset_generation/processing.py
+++ b/dataset_generation/processing.py
@@ -10,6 +10,7 @@ from typing import Optional, Mapping
 from utils import remove_punctuation
 import numpy as np
 from more_itertools import windowed
+from transformers import PreTrainedTokenizerFast, BertTokenizerFast
 
 ACTIONS_KEYS = ['dot', 'upper_case', 'colon', 'semicolon', 'elipsis', 'dash']
 
@@ -47,6 +48,12 @@ def detect_actions(word: str, next_word: Optional[str]) -> Mapping[str, bool]:
         Mapping[str, bool]: Mapping telling whether each possible action should be performed (True) or not (False)
     """
     word.replace('"', " ") # No support for quotes
+    while len(word) > 0 and not word[0].isalnum():  # strip leading non-alphanumeric characters
+        word = word[1:]
+
+    if len(word) == 0:
+        return dict(zip(ACTIONS_KEYS, [False] * len(ACTIONS_KEYS)))
 
     actions = {
         'dot': word[0] == '.',
@@ -72,13 +79,15 @@ def encode_actions(actions: Mapping[str, bool]) -> np.ndarray:
 
 def decode_actions(encoded_actions: np.ndarray) -> Mapping[str, bool]:
     """Decodes actions
-
+    
     Args:
         encoded_actions (np.ndarray): 1 dimensional action vector
 
     Returns:
         Mapping[str, bool]: Map telling which actions should be made
     """
+    assert encoded_actions.shape[0] == len(ACTIONS_KEYS)
+
     return dict(zip(ACTIONS_KEYS, encoded_actions.astype(np.bool).tolist()))
 
 def create_model_input_output(text: str) -> (str, np.ndarray):
@@ -115,6 +124,26 @@ def create_model_input_output(text: str) -> (str, np.ndarray):
 
     return " ".join(words_output), np.array(actions_output)
 
+def tokenize_labeled_text(text: str, labels: np.ndarray, tokenizer: PreTrainedTokenizerFast) -> (np.ndarray, np.ndarray):
+    text_tokenized = tokenizer(text, return_offsets_mapping=True)
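+    # offset_mapping holds the (start, end) character span of each token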
+
+    # Create a map where each character is assigned the index of its word
+    words_mapping = []
+    actual_word = 0
+    for character in text:
+        words_mapping.append(actual_word)
+        if character == " ":
+            actual_word += 1
+
+    # Assign each token to a word
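+    # (special tokens such as [CLS]/[SEP] report offset (0, 0), so they inherit the first word's labels)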
+    token_mapping = [words_mapping[x[0]] for x in text_tokenized['offset_mapping']]
+    
+    # Expand word-based labels to token-based labels
+    labels_tokenized = [labels[i] for i in token_mapping]
+
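+    # shapes: (num_tokens, 1) token ids and (num_tokens, len(ACTIONS_KEYS)) labels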
+    return np.array(text_tokenized['input_ids']).reshape(-1, 1), np.array(labels_tokenized)
+
+
 def recover_word(word: str, action: Mapping[str, bool]) -> str:
     word_result = word
     
@@ -133,6 +162,74 @@ def recover_word(word: str, action: Mapping[str, bool]) -> str:
 
     return word
 
+def is_sentence_end(actions_encoded: np.ndarray) -> bool:
+    actions_decoded = decode_actions(actions_encoded)
+
+    return (actions_decoded['dot']
+            or actions_decoded['elipsis'])
+
+def nearest_sentence_l(labels: np.array, index_start: int) -> int:
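+    # Walk left from index_start to the start of the current sentence (the token
+    # right after the previous sentence end); falls back to 0 when no boundary exists.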
+    result_index = index_start
+
+    while result_index > 0:
+        if is_sentence_end(labels[result_index - 1, :]):
+            break
+        elif result_index == 1:
+            result_index = 0
+            break
+        else:
+            result_index -= 1
+
+    return result_index
+
+def nearest_sentence_r(labels: np.array, index_start: int) -> Optional[int]:
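+    # Walk right from index_start to the start of the next sentence; returns None
+    # when no sentence boundary exists before the end of the labels.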
+    result_index = index_start + 1
+
+    while result_index < len(labels):
+        if is_sentence_end(labels[result_index - 1]):
+            break
+        elif result_index == 1:
+            result_index = 0
+            break
+        else:
+            result_index += 1
+
+    if result_index >= len(labels):
+        return None
+    else:
+        return result_index
+
+def batchify_tokens(tokens: np.ndarray, labels: np.ndarray, max_tokens: int, min_tokens: int = 3) -> (np.ndarray, np.ndarray):
+
+    assert min_tokens >= 1
+
+    # drop the tokenizer's special start/end tokens ([CLS]/[SEP])
+    tokens = tokens[1:-1, :]
+    labels = labels[1:-1, :]
+
+    tokens_batches = []
+    labels_batches = []
+
+    index = 0
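+    # Greedily consume up to max_tokens tokens per batch, then re-align the
+    # next batch to a sentence boundary (see below).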
+    while index < (tokens.shape[0] - min_tokens):
+        num_consumed = min(max_tokens, tokens.shape[0] - index)
+
+        assert num_consumed >= min_tokens
+
+        tokens_batches.append(tokens[index:(index + num_consumed), :])
+        labels_batches.append(labels[index:(index + num_consumed), :])
+
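+        # Rewind to the start of the last complete sentence; if that would not
+        # advance the index, jump forward to the next sentence instead.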
+        new_index = nearest_sentence_l(labels, index + num_consumed)
+        if new_index == index:
+            new_index = nearest_sentence_r(labels, index + num_consumed)
+            if new_index is None:
+                break
+
+        index = new_index
+
+    return np.array(tokens_batches), np.array(labels_batches)
+
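+# Hypothetical usage sketch (names taken from this module and the tests below):
+#   text_clean, labels = create_model_input_output(full_text)
+#   tokens, token_labels = tokenize_labeled_text(text_clean, labels, tokenizer)
+#   batch_x, batch_y = batchify_tokens(tokens, token_labels, max_tokens=512)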
+
 def recover_text(text: str, actions_encoded: np.ndarray):
     words = text.split(" ")
 
diff --git a/dataset_generation/test_processing.py b/dataset_generation/test_processing.py
new file mode 100644
index 0000000..34064fe
--- /dev/null
+++ b/dataset_generation/test_processing.py
@@ -0,0 +1,77 @@
+import numpy as np
+from processing import *
+from transformers import PreTrainedTokenizerFast, BertTokenizerFast
+
+def test_encode_actions():
+    x = {
+        'dot': True,
+        'upper_case': False,
+        'colon': False,
+        'semicolon': True,
+        'elipsis': False,
+        'dash': True
+    }
+
+    assert np.all(encode_actions(x) == np.array([1, 0, 0, 1, 0, 1]))
+
+def test_decode_actions():
+    x = np.array([1, 0, 0, 1, 0, 1])
+
+    assert decode_actions(x) == {
+        'dot': True,
+        'upper_case': False,
+        'colon': False,
+        'semicolon': True,
+        'elipsis': False,
+        'dash': True
+    }
+
+def test_tokenize_labeled_text():
+    text = "Janek poszedł do ogrodu. Ogród był zwierzęcy. Spotkał tam Zosię..."
+    tokenizer = BertTokenizerFast.from_pretrained('bert-base-multilingual-cased')
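+    # from_pretrained downloads the tokenizer files on first use (network required)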
+
+    text_clean, labels = create_model_input_output(text)
+    tokens, token_labels = tokenize_labeled_text(text_clean, labels, tokenizer)
+
+    assert len(tokens.shape) == 2
+    assert len(token_labels.shape) == 2
+
+    assert tokens.shape[1] == 1
+    assert token_labels.shape[1] == len(ACTIONS_KEYS)
+
+    assert len(tokens) == len(token_labels)
+    assert tokens[0, 0] == tokenizer.cls_token_id
+    assert tokens[-1, 0] == tokenizer.sep_token_id
+
+    assert np.all(token_labels[0] == token_labels[1])
+    assert np.all(token_labels[-1] == token_labels[-2])
+
+def test_batchify_tokens():
+    text = "Janek poszedł do ogrodu. Ogród był zwierzęcy. Spotkał tam niedzwiedzia..."
+    tokenizer = BertTokenizerFast.from_pretrained('bert-base-multilingual-cased')
+
+    text_clean, labels = create_model_input_output(text)
+    tokens, token_labels = tokenize_labeled_text(text_clean, labels, tokenizer)
+
+    input_batch, output_batch = batchify_tokens(tokens, token_labels, 5)
+    
+    assert len(input_batch.shape) == 3
+    assert len(output_batch.shape) == 3
+
+    # First dimension should be batch size
+    assert input_batch.shape[0] == output_batch.shape[0]
+
+    # Second dimension should be sequence length
+    assert input_batch.shape[1] == 5
+    assert output_batch.shape[1] == 5
+
+    # Third dimension should be feature size
+    assert input_batch.shape[2] == 1
+    assert output_batch.shape[2] == len(ACTIONS_KEYS)
+
+    # Each batch should start at the beginning of a sentence
+    for i in range(input_batch.shape[0]):
+        assert decode_actions(output_batch[i, 0, :])['upper_case']
+        assert decode_actions(output_batch[i, 1, :])['upper_case']
+
+    
\ No newline at end of file
-- 
GitLab


From 7c58564f0878a164fd07c0906cb0da29a75cf8e7 Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Sat, 11 Jul 2020 14:25:57 +0200
Subject: [PATCH 005/116] Working batchifying core tests

---
 dataset_generation/notebook_actions.ipynb |  34 ++++---
 dataset_generation/processing.py          | 118 ++++++++++++++++------
 dataset_generation/test_processing.py     | 104 +++++++++++++++++--
 3 files changed, 205 insertions(+), 51 deletions(-)

diff --git a/dataset_generation/notebook_actions.ipynb b/dataset_generation/notebook_actions.ipynb
index dedbf9d..20a28ac 100644
--- a/dataset_generation/notebook_actions.ipynb
+++ b/dataset_generation/notebook_actions.ipynb
@@ -2,9 +2,17 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count": 1,
-   "metadata": {},
-   "outputs": [],
+   "execution_count": 14,
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [
+    {
+     "output_type": "stream",
+     "name": "stdout",
+     "text": "The autoreload extension is already loaded. To reload it, use:\n  %reload_ext autoreload\n"
+    }
+   ],
    "source": [
     "%load_ext autoreload\n",
     "%autoreload 2"
@@ -12,7 +20,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 4,
+   "execution_count": 15,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -32,7 +40,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 5,
+   "execution_count": 16,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -42,7 +50,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 6,
+   "execution_count": 17,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -51,7 +59,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 9,
+   "execution_count": 18,
    "metadata": {
     "tags": []
    },
@@ -59,7 +67,7 @@
     {
      "output_type": "stream",
      "name": "stderr",
-     "text": "100%|██████████| 1000/1000 [00:09<00:00, 103.87it/s]\n"
+     "text": "100%|██████████| 1000/1000 [00:07<00:00, 133.04it/s]\n"
     }
    ],
    "source": [
@@ -89,7 +97,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 8,
+   "execution_count": 19,
    "metadata": {
     "tags": []
    },
@@ -97,17 +105,17 @@
     {
      "output_type": "execute_result",
      "data": {
-      "text/plain": "<matplotlib.axes._subplots.AxesSubplot at 0x7f018b71d910>"
+      "text/plain": "<matplotlib.axes._subplots.AxesSubplot at 0x7f83c19294c0>"
      },
      "metadata": {},
-     "execution_count": 8
+     "execution_count": 19
     },
     {
      "output_type": "display_data",
      "data": {
       "text/plain": "<Figure size 432x288 with 1 Axes>",
-      "image/svg+xml": "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n  \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n<!-- Created with matplotlib (https://matplotlib.org/) -->\n<svg height=\"248.518125pt\" version=\"1.1\" viewBox=\"0 0 368.925 248.518125\" width=\"368.925pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n <defs>\n  <style type=\"text/css\">\n*{stroke-linecap:butt;stroke-linejoin:round;}\n  </style>\n </defs>\n <g id=\"figure_1\">\n  <g id=\"patch_1\">\n   <path d=\"M -0 248.518125 \nL 368.925 248.518125 \nL 368.925 0 \nL -0 0 \nz\n\" style=\"fill:none;\"/>\n  </g>\n  <g id=\"axes_1\">\n   <g id=\"patch_2\">\n    <path d=\"M 26.925 224.64 \nL 361.725 224.64 \nL 361.725 7.2 \nL 26.925 7.2 \nz\n\" style=\"fill:#ffffff;\"/>\n   </g>\n   <g id=\"patch_3\">\n    <path clip-path=\"url(#p45ea857fc5)\" d=\"M 42.143182 224.64 \nL 54.317727 224.64 \nL 54.317727 17.554286 \nL 42.143182 17.554286 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_4\">\n    <path clip-path=\"url(#p45ea857fc5)\" d=\"M 54.317727 224.64 \nL 66.492273 224.64 \nL 66.492273 112.701776 \nL 54.317727 112.701776 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_5\">\n    <path clip-path=\"url(#p45ea857fc5)\" d=\"M 66.492273 224.64 \nL 78.666818 224.64 \nL 78.666818 135.089421 \nL 66.492273 135.089421 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_6\">\n    <path clip-path=\"url(#p45ea857fc5)\" d=\"M 78.666818 224.64 \nL 90.841364 224.64 \nL 90.841364 179.86471 \nL 78.666818 179.86471 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_7\">\n    <path clip-path=\"url(#p45ea857fc5)\" d=\"M 90.841364 224.64 \nL 103.015909 224.64 \nL 103.015909 202.252355 \nL 90.841364 202.252355 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_8\">\n    <path clip-path=\"url(#p45ea857fc5)\" d=\"M 103.015909 224.64 \nL 115.190455 224.64 \nL 115.190455 219.043089 \nL 103.015909 219.043089 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_9\">\n    <path clip-path=\"url(#p45ea857fc5)\" d=\"M 115.190455 224.64 \nL 127.365 224.64 \nL 127.365 224.64 \nL 115.190455 224.64 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_10\">\n    <path clip-path=\"url(#p45ea857fc5)\" d=\"M 127.365 224.64 \nL 139.539545 224.64 \nL 139.539545 219.043089 \nL 127.365 219.043089 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_11\">\n    <path clip-path=\"url(#p45ea857fc5)\" d=\"M 139.539545 224.64 \nL 151.714091 224.64 \nL 151.714091 224.64 \nL 139.539545 224.64 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_12\">\n    <path clip-path=\"url(#p45ea857fc5)\" d=\"M 151.714091 224.64 \nL 163.888636 224.64 \nL 163.888636 207.849266 \nL 151.714091 207.849266 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_13\">\n    <path clip-path=\"url(#p45ea857fc5)\" d=\"M 163.888636 224.64 \nL 176.063182 224.64 \nL 176.063182 213.446178 \nL 163.888636 213.446178 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_14\">\n    <path clip-path=\"url(#p45ea857fc5)\" d=\"M 176.063182 224.64 \nL 188.237727 224.64 \nL 188.237727 219.043089 \nL 176.063182 219.043089 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_15\">\n    <path clip-path=\"url(#p45ea857fc5)\" d=\"M 188.237727 224.64 \nL 
200.412273 224.64 \nL 200.412273 224.64 \nL 188.237727 224.64 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_16\">\n    <path clip-path=\"url(#p45ea857fc5)\" d=\"M 200.412273 224.64 \nL 212.586818 224.64 \nL 212.586818 219.043089 \nL 200.412273 219.043089 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_17\">\n    <path clip-path=\"url(#p45ea857fc5)\" d=\"M 212.586818 224.64 \nL 224.761364 224.64 \nL 224.761364 219.043089 \nL 212.586818 219.043089 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_18\">\n    <path clip-path=\"url(#p45ea857fc5)\" d=\"M 224.761364 224.64 \nL 236.935909 224.64 \nL 236.935909 224.64 \nL 224.761364 224.64 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_19\">\n    <path clip-path=\"url(#p45ea857fc5)\" d=\"M 236.935909 224.64 \nL 249.110455 224.64 \nL 249.110455 224.64 \nL 236.935909 224.64 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_20\">\n    <path clip-path=\"url(#p45ea857fc5)\" d=\"M 249.110455 224.64 \nL 261.285 224.64 \nL 261.285 224.64 \nL 249.110455 224.64 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_21\">\n    <path clip-path=\"url(#p45ea857fc5)\" d=\"M 261.285 224.64 \nL 273.459545 224.64 \nL 273.459545 224.64 \nL 261.285 224.64 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_22\">\n    <path clip-path=\"url(#p45ea857fc5)\" d=\"M 273.459545 224.64 \nL 285.634091 224.64 \nL 285.634091 219.043089 \nL 273.459545 219.043089 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_23\">\n    <path clip-path=\"url(#p45ea857fc5)\" d=\"M 285.634091 224.64 \nL 297.808636 224.64 \nL 297.808636 224.64 \nL 285.634091 224.64 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_24\">\n    <path clip-path=\"url(#p45ea857fc5)\" d=\"M 297.808636 224.64 \nL 309.983182 224.64 \nL 309.983182 213.446178 \nL 297.808636 213.446178 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_25\">\n    <path clip-path=\"url(#p45ea857fc5)\" d=\"M 309.983182 224.64 \nL 322.157727 224.64 \nL 322.157727 224.64 \nL 309.983182 224.64 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_26\">\n    <path clip-path=\"url(#p45ea857fc5)\" d=\"M 322.157727 224.64 \nL 334.332273 224.64 \nL 334.332273 224.64 \nL 322.157727 224.64 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_27\">\n    <path clip-path=\"url(#p45ea857fc5)\" d=\"M 334.332273 224.64 \nL 346.506818 224.64 \nL 346.506818 219.043089 \nL 334.332273 219.043089 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"matplotlib.axis_1\">\n    <g id=\"xtick_1\">\n     <g id=\"line2d_1\">\n      <defs>\n       <path d=\"M 0 0 \nL 0 3.5 \n\" id=\"m2d56e675c0\" style=\"stroke:#000000;stroke-width:0.8;\"/>\n      </defs>\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"41.134974\" xlink:href=\"#m2d56e675c0\" y=\"224.64\"/>\n      </g>\n     </g>\n     <g id=\"text_1\">\n      <!-- 0 -->\n      <defs>\n       <path d=\"M 31.78125 66.40625 \nQ 24.171875 66.40625 20.328125 58.90625 \nQ 16.5 51.421875 16.5 36.375 \nQ 16.5 21.390625 20.328125 13.890625 \nQ 24.171875 6.390625 31.78125 6.390625 \nQ 39.453125 6.390625 43.28125 13.890625 \nQ 47.125 21.390625 47.125 36.375 \nQ 47.125 51.421875 43.28125 58.90625 \nQ 39.453125 66.40625 31.78125 66.40625 \nz\nM 31.78125 74.21875 \nQ 44.046875 74.21875 50.515625 64.515625 \nQ 56.984375 54.828125 
56.984375 36.375 \nQ 56.984375 17.96875 50.515625 8.265625 \nQ 44.046875 -1.421875 31.78125 -1.421875 \nQ 19.53125 -1.421875 13.0625 8.265625 \nQ 6.59375 17.96875 6.59375 36.375 \nQ 6.59375 54.828125 13.0625 64.515625 \nQ 19.53125 74.21875 31.78125 74.21875 \nz\n\" id=\"DejaVuSans-48\"/>\n      </defs>\n      <g transform=\"translate(37.953724 239.238437)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-48\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"xtick_2\">\n     <g id=\"line2d_2\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"114.140576\" xlink:href=\"#m2d56e675c0\" y=\"224.64\"/>\n      </g>\n     </g>\n     <g id=\"text_2\">\n      <!-- 200000 -->\n      <defs>\n       <path d=\"M 19.1875 8.296875 \nL 53.609375 8.296875 \nL 53.609375 0 \nL 7.328125 0 \nL 7.328125 8.296875 \nQ 12.9375 14.109375 22.625 23.890625 \nQ 32.328125 33.6875 34.8125 36.53125 \nQ 39.546875 41.84375 41.421875 45.53125 \nQ 43.3125 49.21875 43.3125 52.78125 \nQ 43.3125 58.59375 39.234375 62.25 \nQ 35.15625 65.921875 28.609375 65.921875 \nQ 23.96875 65.921875 18.8125 64.3125 \nQ 13.671875 62.703125 7.8125 59.421875 \nL 7.8125 69.390625 \nQ 13.765625 71.78125 18.9375 73 \nQ 24.125 74.21875 28.421875 74.21875 \nQ 39.75 74.21875 46.484375 68.546875 \nQ 53.21875 62.890625 53.21875 53.421875 \nQ 53.21875 48.921875 51.53125 44.890625 \nQ 49.859375 40.875 45.40625 35.40625 \nQ 44.1875 33.984375 37.640625 27.21875 \nQ 31.109375 20.453125 19.1875 8.296875 \nz\n\" id=\"DejaVuSans-50\"/>\n      </defs>\n      <g transform=\"translate(95.053076 239.238437)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-50\"/>\n       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"127.246094\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"190.869141\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"254.492188\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"318.115234\" xlink:href=\"#DejaVuSans-48\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"xtick_3\">\n     <g id=\"line2d_3\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"187.146177\" xlink:href=\"#m2d56e675c0\" y=\"224.64\"/>\n      </g>\n     </g>\n     <g id=\"text_3\">\n      <!-- 400000 -->\n      <defs>\n       <path d=\"M 37.796875 64.3125 \nL 12.890625 25.390625 \nL 37.796875 25.390625 \nz\nM 35.203125 72.90625 \nL 47.609375 72.90625 \nL 47.609375 25.390625 \nL 58.015625 25.390625 \nL 58.015625 17.1875 \nL 47.609375 17.1875 \nL 47.609375 0 \nL 37.796875 0 \nL 37.796875 17.1875 \nL 4.890625 17.1875 \nL 4.890625 26.703125 \nz\n\" id=\"DejaVuSans-52\"/>\n      </defs>\n      <g transform=\"translate(168.058677 239.238437)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-52\"/>\n       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"127.246094\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"190.869141\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"254.492188\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"318.115234\" xlink:href=\"#DejaVuSans-48\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"xtick_4\">\n     <g id=\"line2d_4\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"260.151778\" xlink:href=\"#m2d56e675c0\" y=\"224.64\"/>\n      </g>\n     </g>\n     <g id=\"text_4\">\n      <!-- 600000 -->\n      <defs>\n       <path d=\"M 33.015625 40.375 \nQ 26.375 40.375 22.484375 35.828125 \nQ 18.609375 31.296875 18.609375 23.390625 \nQ 18.609375 15.53125 22.484375 10.953125 \nQ 26.375 6.390625 33.015625 6.390625 \nQ 
39.65625 6.390625 43.53125 10.953125 \nQ 47.40625 15.53125 47.40625 23.390625 \nQ 47.40625 31.296875 43.53125 35.828125 \nQ 39.65625 40.375 33.015625 40.375 \nz\nM 52.59375 71.296875 \nL 52.59375 62.3125 \nQ 48.875 64.0625 45.09375 64.984375 \nQ 41.3125 65.921875 37.59375 65.921875 \nQ 27.828125 65.921875 22.671875 59.328125 \nQ 17.53125 52.734375 16.796875 39.40625 \nQ 19.671875 43.65625 24.015625 45.921875 \nQ 28.375 48.1875 33.59375 48.1875 \nQ 44.578125 48.1875 50.953125 41.515625 \nQ 57.328125 34.859375 57.328125 23.390625 \nQ 57.328125 12.15625 50.6875 5.359375 \nQ 44.046875 -1.421875 33.015625 -1.421875 \nQ 20.359375 -1.421875 13.671875 8.265625 \nQ 6.984375 17.96875 6.984375 36.375 \nQ 6.984375 53.65625 15.1875 63.9375 \nQ 23.390625 74.21875 37.203125 74.21875 \nQ 40.921875 74.21875 44.703125 73.484375 \nQ 48.484375 72.75 52.59375 71.296875 \nz\n\" id=\"DejaVuSans-54\"/>\n      </defs>\n      <g transform=\"translate(241.064278 239.238437)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-54\"/>\n       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"127.246094\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"190.869141\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"254.492188\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"318.115234\" xlink:href=\"#DejaVuSans-48\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"xtick_5\">\n     <g id=\"line2d_5\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"333.157379\" xlink:href=\"#m2d56e675c0\" y=\"224.64\"/>\n      </g>\n     </g>\n     <g id=\"text_5\">\n      <!-- 800000 -->\n      <defs>\n       <path d=\"M 31.78125 34.625 \nQ 24.75 34.625 20.71875 30.859375 \nQ 16.703125 27.09375 16.703125 20.515625 \nQ 16.703125 13.921875 20.71875 10.15625 \nQ 24.75 6.390625 31.78125 6.390625 \nQ 38.8125 6.390625 42.859375 10.171875 \nQ 46.921875 13.96875 46.921875 20.515625 \nQ 46.921875 27.09375 42.890625 30.859375 \nQ 38.875 34.625 31.78125 34.625 \nz\nM 21.921875 38.8125 \nQ 15.578125 40.375 12.03125 44.71875 \nQ 8.5 49.078125 8.5 55.328125 \nQ 8.5 64.0625 14.71875 69.140625 \nQ 20.953125 74.21875 31.78125 74.21875 \nQ 42.671875 74.21875 48.875 69.140625 \nQ 55.078125 64.0625 55.078125 55.328125 \nQ 55.078125 49.078125 51.53125 44.71875 \nQ 48 40.375 41.703125 38.8125 \nQ 48.828125 37.15625 52.796875 32.3125 \nQ 56.78125 27.484375 56.78125 20.515625 \nQ 56.78125 9.90625 50.3125 4.234375 \nQ 43.84375 -1.421875 31.78125 -1.421875 \nQ 19.734375 -1.421875 13.25 4.234375 \nQ 6.78125 9.90625 6.78125 20.515625 \nQ 6.78125 27.484375 10.78125 32.3125 \nQ 14.796875 37.15625 21.921875 38.8125 \nz\nM 18.3125 54.390625 \nQ 18.3125 48.734375 21.84375 45.5625 \nQ 25.390625 42.390625 31.78125 42.390625 \nQ 38.140625 42.390625 41.71875 45.5625 \nQ 45.3125 48.734375 45.3125 54.390625 \nQ 45.3125 60.0625 41.71875 63.234375 \nQ 38.140625 66.40625 31.78125 66.40625 \nQ 25.390625 66.40625 21.84375 63.234375 \nQ 18.3125 60.0625 18.3125 54.390625 \nz\n\" id=\"DejaVuSans-56\"/>\n      </defs>\n      <g transform=\"translate(314.069879 239.238437)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-56\"/>\n       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"127.246094\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"190.869141\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"254.492188\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"318.115234\" xlink:href=\"#DejaVuSans-48\"/>\n      </g>\n     </g>\n    </g>\n   </g>\n   <g id=\"matplotlib.axis_2\">\n    <g id=\"ytick_1\">\n   
  <g id=\"line2d_6\">\n      <defs>\n       <path d=\"M 0 0 \nL -3.5 0 \n\" id=\"me7cbd43d73\" style=\"stroke:#000000;stroke-width:0.8;\"/>\n      </defs>\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"26.925\" xlink:href=\"#me7cbd43d73\" y=\"224.64\"/>\n      </g>\n     </g>\n     <g id=\"text_6\">\n      <!-- 0 -->\n      <g transform=\"translate(13.5625 228.439219)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-48\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"ytick_2\">\n     <g id=\"line2d_7\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"26.925\" xlink:href=\"#me7cbd43d73\" y=\"196.655444\"/>\n      </g>\n     </g>\n     <g id=\"text_7\">\n      <!-- 5 -->\n      <defs>\n       <path d=\"M 10.796875 72.90625 \nL 49.515625 72.90625 \nL 49.515625 64.59375 \nL 19.828125 64.59375 \nL 19.828125 46.734375 \nQ 21.96875 47.46875 24.109375 47.828125 \nQ 26.265625 48.1875 28.421875 48.1875 \nQ 40.625 48.1875 47.75 41.5 \nQ 54.890625 34.8125 54.890625 23.390625 \nQ 54.890625 11.625 47.5625 5.09375 \nQ 40.234375 -1.421875 26.90625 -1.421875 \nQ 22.3125 -1.421875 17.546875 -0.640625 \nQ 12.796875 0.140625 7.71875 1.703125 \nL 7.71875 11.625 \nQ 12.109375 9.234375 16.796875 8.0625 \nQ 21.484375 6.890625 26.703125 6.890625 \nQ 35.15625 6.890625 40.078125 11.328125 \nQ 45.015625 15.765625 45.015625 23.390625 \nQ 45.015625 31 40.078125 35.4375 \nQ 35.15625 39.890625 26.703125 39.890625 \nQ 22.75 39.890625 18.8125 39.015625 \nQ 14.890625 38.140625 10.796875 36.28125 \nz\n\" id=\"DejaVuSans-53\"/>\n      </defs>\n      <g transform=\"translate(13.5625 200.454663)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-53\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"ytick_3\">\n     <g id=\"line2d_8\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"26.925\" xlink:href=\"#me7cbd43d73\" y=\"168.670888\"/>\n      </g>\n     </g>\n     <g id=\"text_8\">\n      <!-- 10 -->\n      <defs>\n       <path d=\"M 12.40625 8.296875 \nL 28.515625 8.296875 \nL 28.515625 63.921875 \nL 10.984375 60.40625 \nL 10.984375 69.390625 \nL 28.421875 72.90625 \nL 38.28125 72.90625 \nL 38.28125 8.296875 \nL 54.390625 8.296875 \nL 54.390625 0 \nL 12.40625 0 \nz\n\" id=\"DejaVuSans-49\"/>\n      </defs>\n      <g transform=\"translate(7.2 172.470107)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-49\"/>\n       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-48\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"ytick_4\">\n     <g id=\"line2d_9\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"26.925\" xlink:href=\"#me7cbd43d73\" y=\"140.686332\"/>\n      </g>\n     </g>\n     <g id=\"text_9\">\n      <!-- 15 -->\n      <g transform=\"translate(7.2 144.485551)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-49\"/>\n       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-53\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"ytick_5\">\n     <g id=\"line2d_10\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"26.925\" xlink:href=\"#me7cbd43d73\" y=\"112.701776\"/>\n      </g>\n     </g>\n     <g id=\"text_10\">\n      <!-- 20 -->\n      <g transform=\"translate(7.2 116.500995)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-50\"/>\n       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-48\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"ytick_6\">\n     <g id=\"line2d_11\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"26.925\" xlink:href=\"#me7cbd43d73\" 
y=\"84.71722\"/>\n      </g>\n     </g>\n     <g id=\"text_11\">\n      <!-- 25 -->\n      <g transform=\"translate(7.2 88.516439)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-50\"/>\n       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-53\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"ytick_7\">\n     <g id=\"line2d_12\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"26.925\" xlink:href=\"#me7cbd43d73\" y=\"56.732664\"/>\n      </g>\n     </g>\n     <g id=\"text_12\">\n      <!-- 30 -->\n      <defs>\n       <path d=\"M 40.578125 39.3125 \nQ 47.65625 37.796875 51.625 33 \nQ 55.609375 28.21875 55.609375 21.1875 \nQ 55.609375 10.40625 48.1875 4.484375 \nQ 40.765625 -1.421875 27.09375 -1.421875 \nQ 22.515625 -1.421875 17.65625 -0.515625 \nQ 12.796875 0.390625 7.625 2.203125 \nL 7.625 11.71875 \nQ 11.71875 9.328125 16.59375 8.109375 \nQ 21.484375 6.890625 26.8125 6.890625 \nQ 36.078125 6.890625 40.9375 10.546875 \nQ 45.796875 14.203125 45.796875 21.1875 \nQ 45.796875 27.640625 41.28125 31.265625 \nQ 36.765625 34.90625 28.71875 34.90625 \nL 20.21875 34.90625 \nL 20.21875 43.015625 \nL 29.109375 43.015625 \nQ 36.375 43.015625 40.234375 45.921875 \nQ 44.09375 48.828125 44.09375 54.296875 \nQ 44.09375 59.90625 40.109375 62.90625 \nQ 36.140625 65.921875 28.71875 65.921875 \nQ 24.65625 65.921875 20.015625 65.03125 \nQ 15.375 64.15625 9.8125 62.3125 \nL 9.8125 71.09375 \nQ 15.4375 72.65625 20.34375 73.4375 \nQ 25.25 74.21875 29.59375 74.21875 \nQ 40.828125 74.21875 47.359375 69.109375 \nQ 53.90625 64.015625 53.90625 55.328125 \nQ 53.90625 49.265625 50.4375 45.09375 \nQ 46.96875 40.921875 40.578125 39.3125 \nz\n\" id=\"DejaVuSans-51\"/>\n      </defs>\n      <g transform=\"translate(7.2 60.531883)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-51\"/>\n       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-48\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"ytick_8\">\n     <g id=\"line2d_13\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"26.925\" xlink:href=\"#me7cbd43d73\" y=\"28.748108\"/>\n      </g>\n     </g>\n     <g id=\"text_13\">\n      <!-- 35 -->\n      <g transform=\"translate(7.2 32.547327)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-51\"/>\n       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-53\"/>\n      </g>\n     </g>\n    </g>\n   </g>\n   <g id=\"patch_28\">\n    <path d=\"M 26.925 224.64 \nL 26.925 7.2 \n\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\n   </g>\n   <g id=\"patch_29\">\n    <path d=\"M 361.725 224.64 \nL 361.725 7.2 \n\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\n   </g>\n   <g id=\"patch_30\">\n    <path d=\"M 26.925 224.64 \nL 361.725 224.64 \n\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\n   </g>\n   <g id=\"patch_31\">\n    <path d=\"M 26.925 7.2 \nL 361.725 7.2 \n\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\n   </g>\n  </g>\n </g>\n <defs>\n  <clipPath id=\"p45ea857fc5\">\n   <rect height=\"217.44\" width=\"334.8\" x=\"26.925\" y=\"7.2\"/>\n  </clipPath>\n </defs>\n</svg>\n",
-      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXAAAAD4CAYAAAD1jb0+AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAPaklEQVR4nO3df4xlZX3H8fenu/ywamQpE7LlRxct0ZAmLnTKj2gai6JImooJacBGt61mbSuJtqYt6B/VpE20UWlNG3Ut6MaIShELIVpLkcSYNGsHXWH5VX6ICFnYQYto/7AFv/3jPgPDODP3zsy9u/eB9ys5mXOec+59vvfZZz9799xz5qaqkCT15xcOdQGSpPUxwCWpUwa4JHXKAJekThngktSpzQezs2OOOaa2bdt2MLuUpO7dfPPNj1bVzNL2gxrg27ZtY25u7mB2KUndS/K95do9hSJJnTLAJalTBrgkdcoAl6ROGeCS1CkDXJI6ZYBLUqcMcEnqlAEuSZ06qHdibsSVex5Y0/FvOuPECVUiSdPBd+CS1CkDXJI6ZYBLUqcMcEnqlAEuSZ0ywCWpUwa4JHVqaIAnOTLJN5N8J8ltSd7f2j+d5LtJ9rZl++TLlSQtGOVGnp8CZ1fVT5IcBnwjyVfavj+vqqsnV54kaSVDA7yqCvhJ2zysLTXJoiRJw410DjzJpiR7gQPADVW1p+36myS3JLksyRErPHZnkrkkc/Pz82MqW5I0UoBX1ZNVtR04Hjg9ya8BlwIvA34DOBr4yxUeu6uqZqtqdmZmZkxlS5LWdBVKVT0G3AScW1X7a+CnwKeA0ydRoCRpeaNchTKT5Ki2/jzgHODOJFtbW4DzgX2TLFSS9EyjXIWyFdidZBODwL+qqq5P8rUkM0CAvcAfTbBOSdISo1yFcgtw6jLtZ0+kIknSSLwTU5I6ZYBLUqcMcEnqlAEuSZ0ywCWpUwa4JHXKAJekThngktQpA1ySOmWAS1KnDHBJ6pQBLkmdMsAlqVMGuCR1ygCXpE4Z4JLUKQNckjplgEtSpwxwSerUKN9Kf2SSbyb5TpLbkry/tZ+UZE+Se5J8Icnhky9XkrRglHfgPwXOrqqXA9uBc5OcCXwQuKyqfhX4b+CtkytTkrTU0ACvgZ+0zcPaUsDZwNWtfTdw/kQqlCQta6Rz4Ek2JdkLHABuAO4FHquqJ9ohDwLHrfDYnUnmkszNz8+Po2ZJEiMGeFU9WVXbgeOB04GXjdpBVe2qqtmqmp2ZmVlnmZKkpdZ0FUpVPQbcBJwFHJVkc9t1PPDQmGuTJK1ilKtQZpIc1dafB5wD3MEgyC9oh+0Arp1UkZKkn7d5+CFsBXYn2cQg8K+qquuT3A58PslfA98GLp9gnZKkJYYGeFXdApy6TPt9DM6HS5IOAe/ElKROGeCS1CkDXJI6ZYBLUqcMcEnqlAEuSZ0ywCWpUwa4JHXKAJekThngktQpA1ySOmWAS1KnDHBJ6pQBLkmdMsAlqVMGuCR1ygCXpE4Z4JLUKQNckjo1yrfSn5DkpiS3J7ktyTtb+/uSPJRkb1vOm3y5kqQFo3wr/RPAu6vqW0leCNyc5Ia277Kq+tDkypMkrWSUb6XfD+xv6z9Ocgdw3KQLkyStbk3nwJNsA04F9rSmi5PckuSKJFtWeMzOJHNJ5ubn5zdUrCTpaSMHeJIXAF8E3lVVjwMfA14CbGfwDv3Dyz2uqnZV1WxVzc7MzIyhZEkSjBjgSQ5jEN6fraprAKrqkap6sqp+BnwSOH1yZUqSlhrlKpQAlwN3VNVHFrVvXXTYG4F94y9PkrSSUa5CeQXwZuDWJHtb23uAi5JsBwq4H3j7RCqUJC1rlKtQvgFkmV1fHn85kqRReSemJHXKAJekThngktQpA1ySOmWAS1KnDHBJ6pQBLkmdMsAlqVMGuCR1ygCXpE4Z4JLUKQNckjplgEtSpwxwSeqUAS5JnTLAJalTBrgkdcoAl6ROGeCS1KlRvpX+hCQ3Jbk9yW1J3tnaj05yQ5K7288tky9XkrRglHfgTwDvrqpTgDOBdyQ5BbgEuLGqTgZubNuSpINkaIBX1f6q+lZb/zFwB3Ac8AZgdztsN3D+pIqUJP28zWs5OMk24FRgD3BsVe1vux4Gjl3hMTuBnQAnnnjieutcsyv3PLDmx7zpjINXnyRt1MgfYiZ5AfBF4F1V9fjifVVVQC33uKraVVWzVTU7MzOzoWIlSU8bKcCTHMYgvD9bVde05keSbG37twIHJlOiJGk5o1yFEuBy4I6q+siiXdcBO9r6DuDa8ZcnSVrJKOfAXwG8Gbg1yd7W9h7gA8BVSd4KfA/43cmUKElaztAAr6pvAFlh96vHW44kaVTeiSlJnTLAJalTBrgkdcoAl6ROGeCS1CkDXJI6ZYBLUqcMcEnqlAEuSZ0ywCWpUwa4JHXKAJekThngktQpA1ySOmWAS1KnDHBJ6tSavpX+2W6t32Tvt9hLOpR8By5JnTLAJalTo3wr/RVJDiTZt6jtfUkeSrK3LedNtkxJ0lKjvAP/NHDuMu2XVdX2tnx5vGVJkoYZGuBV9XXghwehFknSGmzkHPjFSW5pp1i2rHRQkp1J5pLMzc/Pb6A7SdJi6w3wjwEvAbYD+4EPr3RgVe2qqtmqmp2ZmVlnd5KkpdYV4FX1SFU9WVU/Az4JnD7esiRJw6wrwJNsXbT5RmDfSsdKkiZj6J2YST4HvAo4JsmDwF8Br0qyHSjgfuDtE6xRkrSMoQFeVRct03z5BGqRJK2Bd2JKUqcMcEnqlAEuSZ0ywCWpUwa4JHXKAJekThngktQpA1ySOmWAS1KnDHBJ6pQBLkmdMsAlqVMGuCR1ygCXpE4Z4JLUKQNckjplgEtSpwxwSeqUAS5JnRoa4EmuSHIgyb5FbUcnuSHJ3e3nlsmWKUlaapR34J8Gzl3SdglwY1WdDNzYtiVJB9HQAK+qrwM/XNL8BmB3W98NnD/muiRJQ6z3HPixVbW/rT8MHLvSgUl2JplLMjc/P7/O7iRJS234Q8yqKqBW2b+rqmaranZmZmaj3UmSmvUG+CNJtgK0nwfGV5IkaRTrDfDrgB1tfQdw7XjKkSSNapTLCD8H/Afw0iQPJnkr8AHgnCR3A69p25Kkg2jzsAOq6qIVdr16zLVIktbAOzElqVMGuCR1ygCXpE4Z4JLUKQNckjo19CoUrezKPQ+s6fg3nXHihCqR9FzkO3BJ6pQBLkmdMsAlqVMGuCR1ygCXpE4Z4JLUKQNckjplgEtSpwxwSeqUAS5JnTLAJalTBrgkdcoAl6RObei3ESa5H/gx8CTwRFXNjqMoSdJw4/h1sr9VVY+O4XkkSWvgKRRJ6tRGA7yAf0tyc5Kdyx2QZGeSuSRz8/PzG+xOkrRgowH+yqo6DXg98I4kv7n0gKraVVWzVTU7MzOzwe4kSQs2FOBV9VD7eQD4EnD6OIqSJA237gBP8vwkL1xYB14L7BtXYZKk1W3kKpRjgS8lWXieK6vqX8dSlSRpqHUHeFXdB7x8jL
U86/kt9pLGycsIJalTBrgkdcoAl6ROGeCS1CkDXJI6NY5fZiWtyCtvpMnxHbgkdcoAl6ROGeCS1CkDXJI6ZYBLUqe8CkVTxatWtBznxfJ8By5JnTLAJalTBrgkdcoAl6RO+SHmFJv0Bzdrff719DFp0zhGazVtY7pWz4Z5tFbT8pp9By5JnTLAJalTGwrwJOcmuSvJPUkuGVdRkqTh1h3gSTYB/wi8HjgFuCjJKeMqTJK0uo28Az8duKeq7quq/wU+D7xhPGVJkoZJVa3vgckFwLlV9ba2/WbgjKq6eMlxO4GdbfOlwF3r6O4Y4NF1Ffrc4RgN5xgN5xgNdyjG6FeqamZp48QvI6yqXcCujTxHkrmqmh1TSc9KjtFwjtFwjtFw0zRGGzmF8hBwwqLt41ubJOkg2EiA/ydwcpKTkhwOXAhcN56yJEnDrPsUSlU9keRi4KvAJuCKqrptbJU904ZOwTxHOEbDOUbDOUbDTc0YrftDTEnSoeWdmJLUKQNckjo19QH+bL9dP8kJSW5KcnuS25K8s7UfneSGJHe3n1tae5J8tI3HLUlOW/RcO9rxdyfZsaj915Pc2h7z0SRZrY9plWRTkm8nub5tn5RkT3tdX2gfppPkiLZ9T9u/bdFzXNra70ryukXty86zlfqYRkmOSnJ1kjuT3JHkLOfRMyX50/b3bF+SzyU5sut5VFVTuzD4cPRe4MXA4cB3gFMOdV1jfo1bgdPa+guB/2Lwqwn+FriktV8CfLCtnwd8BQhwJrCntR8N3Nd+bmnrW9q+b7Zj0x77+ta+bB/TugB/BlwJXN+2rwIubOsfB/64rf8J8PG2fiHwhbZ+SptDRwAntbm1abV5tlIf07gAu4G3tfXDgaOcR88Yn+OA7wLPW/Rn+/s9z6NDPqhDBvws4KuLti8FLj3UdU34NV8LnMPgjtWtrW0rcFdb/wRw0aLj72r7LwI+saj9E61tK3Dnovanjlupj2lcGNxncCNwNnB9C5FHgc1L5wqDK6POauub23FZOn8Wjltpnq3Wx7QtwItaOGVJu/Po6ZqPA77P4B+nzW0eva7neTTtp1AWBnzBg63tWan9F+1UYA9wbFXtb7seBo5t6yuNyWrtDy7Tzip9TKO/A/4C+Fnb/iXgsap6om0vfl1PjUXb/6N2/FrHbrU+ps1JwDzwqXaa6Z+SPB/n0VOq6iHgQ8ADwH4G8+JmOp5H0x7gzxlJXgB8EXhXVT2+eF8N/tme6PWeB6OP9Ury28CBqrr5UNcyxTYDpwEfq6pTgf9hcDrjKc6jbGHwC/dOAn4ZeD5w7iEtaoOmPcCfE7frJzmMQXh/tqquac2PJNna9m8FDrT2lcZktfbjl2lfrY9p8wrgd5Lcz+C3Xp4N/D1wVJKFm9EWv66nxqLtfxHwA9Y+dj9YpY9p8yDwYFXtadtXMwh059HTXgN8t6rmq+r/gGsYzK1u59G0B/iz/nb99kn+5cAdVfWRRbuuAxauANjB4Nz4Qvtb2lUEZwI/av99/Srw2iRb2juN1zI4z7YfeDzJma2vtyx5ruX6mCpVdWlVHV9V2xjMga9V1e8BNwEXtMOWjtHC67qgHV+t/cJ2dcFJwMkMPphbdp61x6zUx1SpqoeB7yd5aWt6NXA7zqPFHgDOTPKL7TUsjFG/8+hQf7AwwgcP5zG4MuNe4L2Hup4JvL5XMvgv5y3A3racx+C82Y3A3cC/A0e348PgizTuBW4FZhc91x8C97TlDxa1zwL72mP+gafvwF22j2legFfx9FUoL25/ce4B/hk4orUf2bbvaftfvOjx723jcBftKorV5tlKfUzjAmwH5tpc+hcGV5E4j545Ru8H7myv4zMMriTpdh55K70kdWraT6FIklZggEtSpwxwSeqUAS5JnTLAJalTBrgkdcoAl6RO/T/sboSrn/bcSgAAAABJRU5ErkJggg==\n"
+      "image/svg+xml": "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n  \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n<!-- Created with matplotlib (https://matplotlib.org/) -->\n<svg height=\"248.518125pt\" version=\"1.1\" viewBox=\"0 0 368.925 248.518125\" width=\"368.925pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n <defs>\n  <style type=\"text/css\">\n*{stroke-linecap:butt;stroke-linejoin:round;}\n  </style>\n </defs>\n <g id=\"figure_1\">\n  <g id=\"patch_1\">\n   <path d=\"M -0 248.518125 \nL 368.925 248.518125 \nL 368.925 0 \nL -0 0 \nz\n\" style=\"fill:none;\"/>\n  </g>\n  <g id=\"axes_1\">\n   <g id=\"patch_2\">\n    <path d=\"M 26.925 224.64 \nL 361.725 224.64 \nL 361.725 7.2 \nL 26.925 7.2 \nz\n\" style=\"fill:#ffffff;\"/>\n   </g>\n   <g id=\"patch_3\">\n    <path clip-path=\"url(#p920417c74f)\" d=\"M 42.143182 224.64 \nL 55.376383 224.64 \nL 55.376383 17.554286 \nL 42.143182 17.554286 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_4\">\n    <path clip-path=\"url(#p920417c74f)\" d=\"M 55.376383 224.64 \nL 68.609585 224.64 \nL 68.609585 34.121143 \nL 55.376383 34.121143 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_5\">\n    <path clip-path=\"url(#p920417c74f)\" d=\"M 68.609585 224.64 \nL 81.842787 224.64 \nL 81.842787 83.821714 \nL 68.609585 83.821714 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_6\">\n    <path clip-path=\"url(#p920417c74f)\" d=\"M 81.842787 224.64 \nL 95.075988 224.64 \nL 95.075988 150.089143 \nL 81.842787 150.089143 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_7\">\n    <path clip-path=\"url(#p920417c74f)\" d=\"M 95.075988 224.64 \nL 108.30919 224.64 \nL 108.30919 216.356571 \nL 95.075988 216.356571 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_8\">\n    <path clip-path=\"url(#p920417c74f)\" d=\"M 108.30919 224.64 \nL 121.542391 224.64 \nL 121.542391 224.64 \nL 108.30919 224.64 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_9\">\n    <path clip-path=\"url(#p920417c74f)\" d=\"M 121.542391 224.64 \nL 134.775593 224.64 \nL 134.775593 191.506286 \nL 121.542391 191.506286 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_10\">\n    <path clip-path=\"url(#p920417c74f)\" d=\"M 134.775593 224.64 \nL 148.008794 224.64 \nL 148.008794 208.073143 \nL 134.775593 208.073143 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_11\">\n    <path clip-path=\"url(#p920417c74f)\" d=\"M 148.008794 224.64 \nL 161.241996 224.64 \nL 161.241996 224.64 \nL 148.008794 224.64 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_12\">\n    <path clip-path=\"url(#p920417c74f)\" d=\"M 161.241996 224.64 \nL 174.475198 224.64 \nL 174.475198 208.073143 \nL 161.241996 208.073143 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_13\">\n    <path clip-path=\"url(#p920417c74f)\" d=\"M 174.475198 224.64 \nL 187.708399 224.64 \nL 187.708399 216.356571 \nL 174.475198 216.356571 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_14\">\n    <path clip-path=\"url(#p920417c74f)\" d=\"M 187.708399 224.64 \nL 200.941601 224.64 \nL 200.941601 224.64 \nL 187.708399 224.64 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_15\">\n    <path clip-path=\"url(#p920417c74f)\" d=\"M 200.941601 224.64 \nL 
214.174802 224.64 \nL 214.174802 224.64 \nL 200.941601 224.64 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_16\">\n    <path clip-path=\"url(#p920417c74f)\" d=\"M 214.174802 224.64 \nL 227.408004 224.64 \nL 227.408004 216.356571 \nL 214.174802 216.356571 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_17\">\n    <path clip-path=\"url(#p920417c74f)\" d=\"M 227.408004 224.64 \nL 240.641206 224.64 \nL 240.641206 224.64 \nL 227.408004 224.64 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_18\">\n    <path clip-path=\"url(#p920417c74f)\" d=\"M 240.641206 224.64 \nL 253.874407 224.64 \nL 253.874407 224.64 \nL 240.641206 224.64 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_19\">\n    <path clip-path=\"url(#p920417c74f)\" d=\"M 253.874407 224.64 \nL 267.107609 224.64 \nL 267.107609 216.356571 \nL 253.874407 216.356571 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_20\">\n    <path clip-path=\"url(#p920417c74f)\" d=\"M 267.107609 224.64 \nL 280.34081 224.64 \nL 280.34081 224.64 \nL 267.107609 224.64 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_21\">\n    <path clip-path=\"url(#p920417c74f)\" d=\"M 280.34081 224.64 \nL 293.574012 224.64 \nL 293.574012 208.073143 \nL 280.34081 208.073143 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_22\">\n    <path clip-path=\"url(#p920417c74f)\" d=\"M 293.574012 224.64 \nL 306.807213 224.64 \nL 306.807213 224.64 \nL 293.574012 224.64 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_23\">\n    <path clip-path=\"url(#p920417c74f)\" d=\"M 306.807213 224.64 \nL 320.040415 224.64 \nL 320.040415 216.356571 \nL 306.807213 216.356571 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_24\">\n    <path clip-path=\"url(#p920417c74f)\" d=\"M 320.040415 224.64 \nL 333.273617 224.64 \nL 333.273617 224.64 \nL 320.040415 224.64 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_25\">\n    <path clip-path=\"url(#p920417c74f)\" d=\"M 333.273617 224.64 \nL 346.506818 224.64 \nL 346.506818 208.073143 \nL 333.273617 208.073143 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"matplotlib.axis_1\">\n    <g id=\"xtick_1\">\n     <g id=\"line2d_1\">\n      <defs>\n       <path d=\"M 0 0 \nL 0 3.5 \n\" id=\"mabeb11051b\" style=\"stroke:#000000;stroke-width:0.8;\"/>\n      </defs>\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"41.448148\" xlink:href=\"#mabeb11051b\" y=\"224.64\"/>\n      </g>\n     </g>\n     <g id=\"text_1\">\n      <!-- 0 -->\n      <defs>\n       <path d=\"M 31.78125 66.40625 \nQ 24.171875 66.40625 20.328125 58.90625 \nQ 16.5 51.421875 16.5 36.375 \nQ 16.5 21.390625 20.328125 13.890625 \nQ 24.171875 6.390625 31.78125 6.390625 \nQ 39.453125 6.390625 43.28125 13.890625 \nQ 47.125 21.390625 47.125 36.375 \nQ 47.125 51.421875 43.28125 58.90625 \nQ 39.453125 66.40625 31.78125 66.40625 \nz\nM 31.78125 74.21875 \nQ 44.046875 74.21875 50.515625 64.515625 \nQ 56.984375 54.828125 56.984375 36.375 \nQ 56.984375 17.96875 50.515625 8.265625 \nQ 44.046875 -1.421875 31.78125 -1.421875 \nQ 19.53125 -1.421875 13.0625 8.265625 \nQ 6.59375 17.96875 6.59375 36.375 \nQ 6.59375 54.828125 13.0625 64.515625 \nQ 19.53125 74.21875 31.78125 74.21875 \nz\n\" id=\"DejaVuSans-48\"/>\n      </defs>\n      <g transform=\"translate(38.266898 239.238438)scale(0.1 -0.1)\">\n       <use 
xlink:href=\"#DejaVuSans-48\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"xtick_2\">\n     <g id=\"line2d_2\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"89.118482\" xlink:href=\"#mabeb11051b\" y=\"224.64\"/>\n      </g>\n     </g>\n     <g id=\"text_2\">\n      <!-- 100000 -->\n      <defs>\n       <path d=\"M 12.40625 8.296875 \nL 28.515625 8.296875 \nL 28.515625 63.921875 \nL 10.984375 60.40625 \nL 10.984375 69.390625 \nL 28.421875 72.90625 \nL 38.28125 72.90625 \nL 38.28125 8.296875 \nL 54.390625 8.296875 \nL 54.390625 0 \nL 12.40625 0 \nz\n\" id=\"DejaVuSans-49\"/>\n      </defs>\n      <g transform=\"translate(70.030982 239.238438)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-49\"/>\n       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"127.246094\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"190.869141\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"254.492188\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"318.115234\" xlink:href=\"#DejaVuSans-48\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"xtick_3\">\n     <g id=\"line2d_3\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"136.788815\" xlink:href=\"#mabeb11051b\" y=\"224.64\"/>\n      </g>\n     </g>\n     <g id=\"text_3\">\n      <!-- 200000 -->\n      <defs>\n       <path d=\"M 19.1875 8.296875 \nL 53.609375 8.296875 \nL 53.609375 0 \nL 7.328125 0 \nL 7.328125 8.296875 \nQ 12.9375 14.109375 22.625 23.890625 \nQ 32.328125 33.6875 34.8125 36.53125 \nQ 39.546875 41.84375 41.421875 45.53125 \nQ 43.3125 49.21875 43.3125 52.78125 \nQ 43.3125 58.59375 39.234375 62.25 \nQ 35.15625 65.921875 28.609375 65.921875 \nQ 23.96875 65.921875 18.8125 64.3125 \nQ 13.671875 62.703125 7.8125 59.421875 \nL 7.8125 69.390625 \nQ 13.765625 71.78125 18.9375 73 \nQ 24.125 74.21875 28.421875 74.21875 \nQ 39.75 74.21875 46.484375 68.546875 \nQ 53.21875 62.890625 53.21875 53.421875 \nQ 53.21875 48.921875 51.53125 44.890625 \nQ 49.859375 40.875 45.40625 35.40625 \nQ 44.1875 33.984375 37.640625 27.21875 \nQ 31.109375 20.453125 19.1875 8.296875 \nz\n\" id=\"DejaVuSans-50\"/>\n      </defs>\n      <g transform=\"translate(117.701315 239.238438)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-50\"/>\n       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"127.246094\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"190.869141\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"254.492188\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"318.115234\" xlink:href=\"#DejaVuSans-48\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"xtick_4\">\n     <g id=\"line2d_4\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"184.459148\" xlink:href=\"#mabeb11051b\" y=\"224.64\"/>\n      </g>\n     </g>\n     <g id=\"text_4\">\n      <!-- 300000 -->\n      <defs>\n       <path d=\"M 40.578125 39.3125 \nQ 47.65625 37.796875 51.625 33 \nQ 55.609375 28.21875 55.609375 21.1875 \nQ 55.609375 10.40625 48.1875 4.484375 \nQ 40.765625 -1.421875 27.09375 -1.421875 \nQ 22.515625 -1.421875 17.65625 -0.515625 \nQ 12.796875 0.390625 7.625 2.203125 \nL 7.625 11.71875 \nQ 11.71875 9.328125 16.59375 8.109375 \nQ 21.484375 6.890625 26.8125 6.890625 \nQ 36.078125 6.890625 40.9375 10.546875 \nQ 45.796875 14.203125 45.796875 21.1875 \nQ 45.796875 27.640625 41.28125 31.265625 \nQ 36.765625 34.90625 28.71875 34.90625 \nL 20.21875 34.90625 \nL 20.21875 43.015625 \nL 29.109375 43.015625 \nQ 36.375 43.015625 40.234375 45.921875 \nQ 44.09375 48.828125 
44.09375 54.296875 \nQ 44.09375 59.90625 40.109375 62.90625 \nQ 36.140625 65.921875 28.71875 65.921875 \nQ 24.65625 65.921875 20.015625 65.03125 \nQ 15.375 64.15625 9.8125 62.3125 \nL 9.8125 71.09375 \nQ 15.4375 72.65625 20.34375 73.4375 \nQ 25.25 74.21875 29.59375 74.21875 \nQ 40.828125 74.21875 47.359375 69.109375 \nQ 53.90625 64.015625 53.90625 55.328125 \nQ 53.90625 49.265625 50.4375 45.09375 \nQ 46.96875 40.921875 40.578125 39.3125 \nz\n\" id=\"DejaVuSans-51\"/>\n      </defs>\n      <g transform=\"translate(165.371648 239.238438)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-51\"/>\n       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"127.246094\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"190.869141\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"254.492188\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"318.115234\" xlink:href=\"#DejaVuSans-48\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"xtick_5\">\n     <g id=\"line2d_5\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"232.129481\" xlink:href=\"#mabeb11051b\" y=\"224.64\"/>\n      </g>\n     </g>\n     <g id=\"text_5\">\n      <!-- 400000 -->\n      <defs>\n       <path d=\"M 37.796875 64.3125 \nL 12.890625 25.390625 \nL 37.796875 25.390625 \nz\nM 35.203125 72.90625 \nL 47.609375 72.90625 \nL 47.609375 25.390625 \nL 58.015625 25.390625 \nL 58.015625 17.1875 \nL 47.609375 17.1875 \nL 47.609375 0 \nL 37.796875 0 \nL 37.796875 17.1875 \nL 4.890625 17.1875 \nL 4.890625 26.703125 \nz\n\" id=\"DejaVuSans-52\"/>\n      </defs>\n      <g transform=\"translate(213.041981 239.238438)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-52\"/>\n       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"127.246094\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"190.869141\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"254.492188\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"318.115234\" xlink:href=\"#DejaVuSans-48\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"xtick_6\">\n     <g id=\"line2d_6\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"279.799814\" xlink:href=\"#mabeb11051b\" y=\"224.64\"/>\n      </g>\n     </g>\n     <g id=\"text_6\">\n      <!-- 500000 -->\n      <defs>\n       <path d=\"M 10.796875 72.90625 \nL 49.515625 72.90625 \nL 49.515625 64.59375 \nL 19.828125 64.59375 \nL 19.828125 46.734375 \nQ 21.96875 47.46875 24.109375 47.828125 \nQ 26.265625 48.1875 28.421875 48.1875 \nQ 40.625 48.1875 47.75 41.5 \nQ 54.890625 34.8125 54.890625 23.390625 \nQ 54.890625 11.625 47.5625 5.09375 \nQ 40.234375 -1.421875 26.90625 -1.421875 \nQ 22.3125 -1.421875 17.546875 -0.640625 \nQ 12.796875 0.140625 7.71875 1.703125 \nL 7.71875 11.625 \nQ 12.109375 9.234375 16.796875 8.0625 \nQ 21.484375 6.890625 26.703125 6.890625 \nQ 35.15625 6.890625 40.078125 11.328125 \nQ 45.015625 15.765625 45.015625 23.390625 \nQ 45.015625 31 40.078125 35.4375 \nQ 35.15625 39.890625 26.703125 39.890625 \nQ 22.75 39.890625 18.8125 39.015625 \nQ 14.890625 38.140625 10.796875 36.28125 \nz\n\" id=\"DejaVuSans-53\"/>\n      </defs>\n      <g transform=\"translate(260.712314 239.238438)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-53\"/>\n       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"127.246094\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"190.869141\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"254.492188\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"318.115234\" 
xlink:href=\"#DejaVuSans-48\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"xtick_7\">\n     <g id=\"line2d_7\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"327.470147\" xlink:href=\"#mabeb11051b\" y=\"224.64\"/>\n      </g>\n     </g>\n     <g id=\"text_7\">\n      <!-- 600000 -->\n      <defs>\n       <path d=\"M 33.015625 40.375 \nQ 26.375 40.375 22.484375 35.828125 \nQ 18.609375 31.296875 18.609375 23.390625 \nQ 18.609375 15.53125 22.484375 10.953125 \nQ 26.375 6.390625 33.015625 6.390625 \nQ 39.65625 6.390625 43.53125 10.953125 \nQ 47.40625 15.53125 47.40625 23.390625 \nQ 47.40625 31.296875 43.53125 35.828125 \nQ 39.65625 40.375 33.015625 40.375 \nz\nM 52.59375 71.296875 \nL 52.59375 62.3125 \nQ 48.875 64.0625 45.09375 64.984375 \nQ 41.3125 65.921875 37.59375 65.921875 \nQ 27.828125 65.921875 22.671875 59.328125 \nQ 17.53125 52.734375 16.796875 39.40625 \nQ 19.671875 43.65625 24.015625 45.921875 \nQ 28.375 48.1875 33.59375 48.1875 \nQ 44.578125 48.1875 50.953125 41.515625 \nQ 57.328125 34.859375 57.328125 23.390625 \nQ 57.328125 12.15625 50.6875 5.359375 \nQ 44.046875 -1.421875 33.015625 -1.421875 \nQ 20.359375 -1.421875 13.671875 8.265625 \nQ 6.984375 17.96875 6.984375 36.375 \nQ 6.984375 53.65625 15.1875 63.9375 \nQ 23.390625 74.21875 37.203125 74.21875 \nQ 40.921875 74.21875 44.703125 73.484375 \nQ 48.484375 72.75 52.59375 71.296875 \nz\n\" id=\"DejaVuSans-54\"/>\n      </defs>\n      <g transform=\"translate(308.382647 239.238438)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-54\"/>\n       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"127.246094\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"190.869141\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"254.492188\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"318.115234\" xlink:href=\"#DejaVuSans-48\"/>\n      </g>\n     </g>\n    </g>\n   </g>\n   <g id=\"matplotlib.axis_2\">\n    <g id=\"ytick_1\">\n     <g id=\"line2d_8\">\n      <defs>\n       <path d=\"M 0 0 \nL -3.5 0 \n\" id=\"m5ee8100b6b\" style=\"stroke:#000000;stroke-width:0.8;\"/>\n      </defs>\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"26.925\" xlink:href=\"#m5ee8100b6b\" y=\"224.64\"/>\n      </g>\n     </g>\n     <g id=\"text_8\">\n      <!-- 0 -->\n      <g transform=\"translate(13.5625 228.439219)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-48\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"ytick_2\">\n     <g id=\"line2d_9\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"26.925\" xlink:href=\"#m5ee8100b6b\" y=\"183.222857\"/>\n      </g>\n     </g>\n     <g id=\"text_9\">\n      <!-- 5 -->\n      <g transform=\"translate(13.5625 187.022076)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-53\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"ytick_3\">\n     <g id=\"line2d_10\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"26.925\" xlink:href=\"#m5ee8100b6b\" y=\"141.805714\"/>\n      </g>\n     </g>\n     <g id=\"text_10\">\n      <!-- 10 -->\n      <g transform=\"translate(7.2 145.604933)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-49\"/>\n       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-48\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"ytick_4\">\n     <g id=\"line2d_11\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"26.925\" xlink:href=\"#m5ee8100b6b\" y=\"100.388571\"/>\n      </g>\n     </g>\n     <g id=\"text_11\">\n      <!-- 15 -->\n      <g 
transform=\"translate(7.2 104.18779)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-49\"/>\n       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-53\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"ytick_5\">\n     <g id=\"line2d_12\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"26.925\" xlink:href=\"#m5ee8100b6b\" y=\"58.971429\"/>\n      </g>\n     </g>\n     <g id=\"text_12\">\n      <!-- 20 -->\n      <g transform=\"translate(7.2 62.770647)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-50\"/>\n       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-48\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"ytick_6\">\n     <g id=\"line2d_13\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"26.925\" xlink:href=\"#m5ee8100b6b\" y=\"17.554286\"/>\n      </g>\n     </g>\n     <g id=\"text_13\">\n      <!-- 25 -->\n      <g transform=\"translate(7.2 21.353504)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-50\"/>\n       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-53\"/>\n      </g>\n     </g>\n    </g>\n   </g>\n   <g id=\"patch_26\">\n    <path d=\"M 26.925 224.64 \nL 26.925 7.2 \n\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\n   </g>\n   <g id=\"patch_27\">\n    <path d=\"M 361.725 224.64 \nL 361.725 7.2 \n\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\n   </g>\n   <g id=\"patch_28\">\n    <path d=\"M 26.925 224.64 \nL 361.725 224.64 \n\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\n   </g>\n   <g id=\"patch_29\">\n    <path d=\"M 26.925 7.2 \nL 361.725 7.2 \n\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\n   </g>\n  </g>\n </g>\n <defs>\n  <clipPath id=\"p920417c74f\">\n   <rect height=\"217.44\" width=\"334.8\" x=\"26.925\" y=\"7.2\"/>\n  </clipPath>\n </defs>\n</svg>\n",
+      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXAAAAD4CAYAAAD1jb0+AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAOKklEQVR4nO3dX4xc5XnH8e9TDKQFFEy8sizAXRIhKl+0hq4wCIRoaCigqiQSqjAVWC2RoxYkUCNVkEhtepdWDWmjVgSn0HCBaf4ABaG0hDpIUaTK6ZoYMDiODXWIkcFL0gTUmxZ4ejHvwmTZ3dn5s7vzkO9HGs057/nzPoc9/PbsO+eMIzORJNXzS6tdgCRpMAa4JBVlgEtSUQa4JBVlgEtSUWtWsrN169bl5OTkSnYpSeXt2bPn1cycmNu+ogE+OTnJ9PT0SnYpSeVFxA/na3cIRZKKMsAlqSgDXJKKMsAlqSgDXJKKMsAlqaieAR4RZ0bEExHxXEQ8GxG3tPbPRMRLEbG3va5a/nIlSbOWch/4G8AnM/PJiDgF2BMRj7dln8/Mv1m+8iRJC+kZ4Jl5FDjapl+PiP3A6ctdmCRpcX09iRkRk8C5wG7gIuDmiLgBmKZzlf7f82yzHdgOsHHjxoEL3bn7xb63uW7L4P1J0rhb8oeYEXEy8ABwa2a+BtwJfAjYTOcK/XPzbZeZOzJzKjOnJibe9Si/JGlASwrwiDieTnjfl5kPAmTmK5n5Zma+BXwJOH/5ypQkzbWUu1ACuBvYn5l3dLVv6FrtY8C+0ZcnSVrIUsbALwKuB56JiL2t7VPA1ojYDCRwGPjEslQoSZrXUu5C+Q4Q8yz6xujLkSQtlU9iSlJRBrgkFWWAS1JRBrgkFWWAS1JRBrgkFWWAS1JRBrgkFWWAS1JRBrgkFdXX94FXM8h3iIPfIy6pBq/AJakoA1ySijLAJakoA1ySijLAJakoA1ySijLAJakoA1ySijLAJakoA1ySijLAJakoA1ySijLAJakoA1ySijLAJakoA1ySijLAJakoA1ySijLAJakoA1ySijLAJamongEeEWdGxBMR8VxEPBsRt7T20yLi8Yg42N7XLn+5kqRZS7kCfwP4ZGZuAi4AboqITcBtwK7MPBvY1eYlSSukZ4Bn5tHMfLJNvw7sB04HrgbubavdC3x0uYqUJL1bX2PgETEJnAvsBtZn5tG26GVg/QLbbI+I6YiYnpmZGaJUSVK3JQd4RJwMPADcmpmvdS/LzARyvu0yc0dmTmXm1MTExFDFSpLesaQAj4jj6YT3fZn5YGt+JSI2tOUbgGPLU6IkaT5LuQslgLuB/Zl5R9eiR4BtbXob8PDoy5MkLWTNEta5CLgeeCYi9ra2TwGfBb4aETcCPwR+f3lKlCTNp2eAZ+Z3gFhg8WWjLUeStFQ+iSlJRRngklSUAS5JRRngklSUAS5JRRngklSUAS5JRRngklSUAS5JRRngklSUAS5JRRngklSUAS5JRRngklSUAS5JRRngklSUAS5JRRngklSUAS5JRS3lHzX+hbNz94t9b3Pdlo3LUIkkLcwrcEkqygCXpKIMcEkqygCXpKIMcEkqygCXpKIMcEkqygCXpKIMcEkqygCXpKIMcEkqygCXpKJ6BnhE3BMRxyJiX1fbZyLipYjY215XLW+ZkqS5lnIF/mXginnaP5+Zm9vrG6MtS5LUS88Az8xvAz9ZgVokSX0YZgz85oh4ug2xrB1ZRZKkJRk0wO8EPgRsBo4Cn1toxYjYHhHTETE9MzMzYHeSpLkGCvDMfCUz38zMt4AvAecvsu6OzJzKzKmJiYlB65QkzTFQgEfEhq7ZjwH7FlpXkrQ8ev6bmBFxP3ApsC4ijgB/AVwaEZuBBA4Dn1jGGiVJ8+gZ4Jm5dZ7mu5ehFklSH3wSU5KKMsAlqSgDXJKKMsAlqSgDXJKKMsAlqSgDXJKKMsAlqSgDXJKKMsAlqSgDXJKKMsAlqSgDXJKKMsAlqSgDXJKKMsAlqSgDXJKKMsAlqSgDXJKKMsAlqSgDXJKKMsAlqSgDXJKKMsAlqSgDXJKKMsAlqSgDXJKKMsAlqSgDXJKKMsAlqSgDXJKKMsAlqSgDXJKK6hngEXFPRByLiH1dbadFxOMRcbC9r13eMiVJcy3lCvzLwBVz2m4DdmXm2cCuNi9JWkE9Azwzvw38ZE7z1cC9bfpe4KMjrkuS1MOgY+DrM/Nom34ZWL/QihGxPSKmI2J6ZmZmwO4kSXMN/SFmZiaQiyzfkZlTmTk1MTExbHeSpGbQAH8lIjYAtPdjoytJkrQUgwb4I8C2Nr0NeHg05UiSlmoptxHeD/wHcE5EHImIG4HPAh+JiIPAb7d5SdIKWtNrhczcusCiy0ZciySpDz6JKUlFGeCSVFTPIRQtzc7dL/a9zXVbNi5DJZJ+UXgFLklFGeCSVJQBLklFGeCSVJQBLklFGeCSVJQBLklFGeCSVJQBLklFGeCSVJQBLklFGeCSVJQBLklFGeCSVJQBLklFGeCSVJQBLklFGeCSVJQBLklFGeCSVJQBLklFGeCSVJQBLklFGeCSVJQBLklFGeCSVJQBLklFGeCSVJQBLklFrRlm44g4DLwOvAm8kZlToyhKktTbUAHe/FZmvjqC/UiS+uAQiiQVNWyAJ/DNiNgTEdvnWyEitkfEdERMz8zMDNmdJGnWsAF+cWaeB1wJ3BQRl8xdITN3ZOZUZk5NTEwM2Z0kadZQAZ6ZL7X3Y8BDwPmjKEqS1NvAAR4RJ0XEKbPTwOXAvlEVJkla3DB3oawHHoqI2f3szMx/G0lVkqSeBg7wzHwB+I0R1iJJ6oO3EUpSUQa4JBVlgEtSUaN4lF4F7Nz9Yt/bXLdl4zJUImlUvAKXpKIMcEkqygCXpKIMcEkqygCXpKIMcEkqygCXpKIMcEkqygCXpKIMcEkqygCXpKIMcEkqygCXpKIMcEkqygCXpKL8PnAtaKW+Q3yQfgbtSyvrvfizHadj8gpckooywCWpKANckooywCWpKANckooywCWpKANckooywCWpKB/kWUUr9aDMe9U4//cb59pg/OsbxHvxmHrxClySijLAJakoA1ySijLAJamooQI8Iq6IiAMRcSgibhtVUZKk3gYO8Ig4DvgH4EpgE7A1IjaNqjBJ0uKGuQI/HziUmS9k5v8C/wxcPZqyJEm9RGYOtmHENcAVmfnxNn89sCUzb56z3nZge5s9BzgwQHfrgFcHKnQ8WP/qqVw7WP9qGqfafzUzJ+Y2LvuDPJm5A9gxzD4iYjozp0ZU0oqz/tVTuXaw/tVUofZhhlBeAs7smj+jtUmSVsAwAf6fwNkRcVZEnABcCzwymrIkSb0MPISSmW9ExM3AY8BxwD2Z+ezIKvt5Qw3BjAHrXz2VawfrX01jX/vAH2JKklaXT2JKUlEGuCQVNfYBvpqP60fEPRFxLCL2dbWdFhGPR8TB9r62tUdEfKHV+XREnNe1zba2/sGI2NbV/psR8Uzb5gsREYv1MUD9Z0bEEx
HxXEQ8GxG3VDmGiHhfRHw3Ip5qtf9laz8rIna3/r7SPkAnIk5s84fa8smufd3e2g9ExO90tc97bi3UxyAi4riI+F5EPFqt/og43H62eyNiurWN/bnT9nFqRHw9Ir4fEfsj4sIqtfclM8f2RefD0eeBDwInAE8Bm1aw/0uA84B9XW1/DdzWpm8D/qpNXwX8KxDABcDu1n4a8EJ7X9um17Zl323rRtv2ysX6GKD+DcB5bfoU4Ad0vvZg7I+h7e/kNn08sLv181Xg2tb+ReCP2/SfAF9s09cCX2nTm9p5cyJwVjufjlvs3FqojwF/Bn8K7AQeXWzf41g/cBhYN6dt7M+dtt29wMfb9AnAqVVq7+s4l3PnQxcHFwKPdc3fDty+wjVM8vMBfgDY0KY3AAfa9F3A1rnrAVuBu7ra72ptG4Dvd7W/vd5CfYzgWB4GPlLtGIBfAZ4EttB5Mm7N3PODzt1QF7bpNW29mHvOzK630LnVtpm3jwHqPgPYBXwYeHSxfY9p/Yd5d4CP/bkDvB/4L9pNGpVq7/c17kMopwM/6po/0tpW0/rMPNqmXwbWt+mFal2s/cg87Yv1MbD2J/m5dK5kSxxDG37YCxwDHqdzxfnTzHxjnv7errEt/xnwgQGO6QOL9NGvvwX+DHirzS+273GsP4FvRsSe6HwlBtQ4d84CZoB/asNX/xgRJxWpvS/jHuBjLTu/Zpf1PsxR9BERJwMPALdm5muj3n8vg/aRmW9m5mY6V7LnA7826tqWS0T8LnAsM/esdi1DuDgzz6PzjaM3RcQl3QvH+NxZQ2fo887MPBf4HzrDGcPuty8r0ce4B/g4Pq7/SkRsAGjvx1r7QrUu1n7GPO2L9dG3iDieTnjfl5kPVjyGzPwp8ASd4YBTI2L2AbTu/t6usS1/P/DjAY7px4v00Y+LgN+LiMN0vqnzw8DfFaqfzHypvR8DHqLzS7TCuXMEOJKZu9v81+kEeoXa+zLuAT6Oj+s/Asx+Gr2NzrjybPsN7RPtC4CftT+lHgMuj4i17RPpy+mMSR4FXouIC9on2DfM2dd8ffSl7fduYH9m3lHpGCJiIiJObdO/TGfsfj+dIL9mgdpn+7sG+Fa7AnoEuDY6d3mcBZxN5wOoec+tts1CfSxZZt6emWdk5mTb97cy8w+q1B8RJ0XEKbPTdH7m+yhw7mTmy8CPIuKc1nQZ8FyF2vu2nAPso3jR+YT4B3TGPz+9wn3fDxwF/o/Ob/Ub6Ywx7gIOAv8OnNbWDTr/wMXzwDPAVNd+/gg41F5/2NU+Red/iueBv+edJ2Pn7WOA+i+m8yfc08De9rqqwjEAvw58r9W+D/jz1v5BOgF2CPgacGJrf1+bP9SWf7BrX59u9R2g3S2w2Lm1UB9DnEeX8s5dKCXqb/t4qr2end1/hXOn7WMzMN3On3+hcxdJidr7efkovSQVNe5DKJKkBRjgklSUAS5JRRngklSUAS5JRRngklSUAS5JRf0/SKUBmaC1OaEAAAAASUVORK5CYII=\n"
      },
      "metadata": {
       "needs_background": "light"
diff --git a/dataset_generation/processing.py b/dataset_generation/processing.py
index 3ec875c..fe4e23a 100644
--- a/dataset_generation/processing.py
+++ b/dataset_generation/processing.py
@@ -14,6 +14,9 @@ from transformers import PreTrainedTokenizerFast, BertTokenizerFast
 
 ACTIONS_KEYS = ['dot', 'upper_case', 'colon', 'semicolon', 'elipsis', 'dash']
 
+def empty_action_vector() -> np.ndarray:
+    return np.zeros(len(ACTIONS_KEYS))
+
 def text_from_xml(path: str) -> str:
     """Extract spoken text from dataset's xml format
 
@@ -47,20 +50,25 @@ def detect_actions(word: str, next_word: Optional[str]) -> Mapping[str, bool]:
     Returns:
         Mapping[str, bool]: Mapping telling whether each of the possible actions should be performed (True) or not (False)
     """
-    word.replace('"', " ") # No support for quotes
+    # Unsupported characters
+    word = word.replace('"', " ")
+    word = word.replace('(', " ")
+    word = word.replace(')', " ")
+
     while len(word) > 0 and not word[0].isalnum(): # remove preceding characters
-        print(word)
         word = word[1:]
 
     if len(word) == 0:
         return zip(ACTIONS_KEYS, [False] * len(ACTIONS_KEYS))
 
+    has_ellipsis = len(word) > 3 and word[-3:] == "..."
+
     actions = {
-        'dot': word[0] == '.',
+        'dot': word[-1] == '.' and not has_ellipsis,
         'upper_case': word[0].isupper(),
         'colon': word[-1] == ",",
         'semicolon': word[-1] == ";",
-        'elipsis': len(word) > 3 and word[:-3] == "...",
+        'elipsis': has_ellipsis,
         'dash': next_word is not None and next_word == "-"
     }
 
@@ -127,6 +135,9 @@ def create_model_input_output(text: str) -> (str, np.ndarray):
 def tokenize_labeled_text(text: str, labels: np.ndarray, tokenizer: PreTrainedTokenizerFast) -> (np.ndarray, np.ndarray):
     text_tokenized = tokenizer(text, return_offsets_mapping=True)
 
+    offset_mappings = text_tokenized['offset_mapping'][1:-1]
+    input_ids = text_tokenized['input_ids'][1:-1]
+
     # Create a map where each character is assigned the index of its word
     words_mapping = []
     actual_word = 0
@@ -136,12 +147,12 @@ def tokenize_labeled_text(text: str, labels: np.ndarray, tokenizer: PreTrainedTo
             actual_word += 1
 
     # Assign each token to a word
-    token_mapping = [words_mapping[x[0]] for x in text_tokenized['offset_mapping']]
+    token_mapping = [words_mapping[x[0]] for x in offset_mappings]
     
     # Expand word-based labels to token-based labels
     labels_tokenized = [labels[i] for i in token_mapping]
 
-    return np.array(text_tokenized['input_ids']).reshape(-1, 1), np.array(labels_tokenized)
+    return np.array(input_ids).reshape(-1, 1), np.array(labels_tokenized)
 
 
 def recover_word(word: str, action: Mapping[str, bool]) -> str:
@@ -172,7 +183,10 @@ def nearest_sentence_l(labels: np.array, index_start: int) -> int:
     result_index = index_start
 
     while result_index > 0:
-        if is_sentence_end(labels[result_index - 1, :]):
+        if is_sentence_end(labels[result_index, :]): 
+            # prevent being in the middle of a token
+            result_index -= 1
+        elif is_sentence_end(labels[result_index - 1, :]):
             break
         elif result_index == 1:
             result_index = 0
@@ -183,13 +197,10 @@ def nearest_sentence_l(labels: np.array, index_start: int) -> int:
     return result_index
 
 def nearest_sentence_r(labels: np.array, index_start: int) -> Optional[int]:
-    result_index = index_start + 1
+    result_index = index_start
 
     while result_index < len(labels):
-        if is_sentence_end(labels[result_index - 1]):
-            break
-        elif result_index == 1:
-            result_index = 0
+        if is_sentence_end(labels[result_index - 1]) and not is_sentence_end(labels[result_index]):
             break
         else:
             result_index += 1
@@ -199,36 +210,83 @@ def nearest_sentence_r(labels: np.array, index_start: int) -> Optional[int]:
     else:
         return result_index
 
-def batchify_tokens(tokens: np.ndarray, labels: np.ndarray, max_tokens: int, min_tokens: int = 3) -> (np.ndarray, np.ndarray):
-
+def batchify_labels(labels: np.ndarray, max_tokens: int, min_tokens: int = 3) -> [np.ndarray]:
     assert min_tokens >= 1
+    assert max_tokens >= 1
 
-    # remove start & end tokens
-    tokens = tokens[1:-1, :]
-    labels = labels[1:-1, :]
-
-    tokens_batches = []
     labels_batches = []
 
     index = 0
-    while index < (tokens.shape[0] - min_tokens):
-        num_consumed = min(max_tokens, tokens.shape[0] - index)
+    new_index = 0
+    while index < (labels.shape[0] - min_tokens):
+        num_consumed = min(max_tokens, labels.shape[0] - index)
 
         assert num_consumed >= min_tokens
 
-        tokens_batches.append(tokens[index:(index + num_consumed), :])
-        labels_batches.append(labels[index:(index + num_consumed), :])
+        if index + num_consumed < (labels.shape[0] - min_tokens):
+            new_index = nearest_sentence_l(labels, index + num_consumed)
+            if new_index == index:
+                new_index = nearest_sentence_r(labels, index + num_consumed)
+                if new_index is None:
+                    labels_batches.append(np.array(list(range(index, index + num_consumed))))
+                    break
+        else:
+            labels_batches.append(np.array(list(range(index, index + num_consumed))))
+            break
 
-        new_index = nearest_sentence_l(labels, index + num_consumed)
-        if new_index == index:
-            new_index = nearest_sentence_r(labels, index + num_consumed)
-            if new_index is None:
-                break
+        labels_batches.append(np.array(list(range(index, index + num_consumed))))
 
         index = new_index
 
-    return np.array(tokens_batches), np.array(labels_batches)
+    return labels_batches
+
+def add_cls_sep(tokens: np.ndarray, labels: np.ndarray, tokenizer: PreTrainedTokenizerFast) -> (np.ndarray, np.ndarray):
+    
+    tokens = np.concatenate([[[tokenizer.cls_token_id]], tokens, [[tokenizer.sep_token_id]]])
+    labels = np.concatenate([labels[:1, :], labels, labels[-1:, :]])
+
+    return tokens, labels
 
+def add_padding(tokens: np.ndarray, labels: np.ndarray, length: int, tokenizer: PreTrainedTokenizerFast) -> (np.ndarray, np.ndarray, np.ndarray):
+
+    pad_length = length - tokens.shape[0]
+    assert pad_length >= 0
+
+    if pad_length > 0:
+        tokens = np.concatenate([tokens, [[tokenizer.pad_token_id]] * pad_length])
+        labels = np.concatenate([labels, [empty_action_vector()] * pad_length])
+
+    mask = np.ones(len(tokens)).astype(np.bool)
+
+    if pad_length > 0:
+        mask[-pad_length:] = False
+
+    return tokens, labels, mask
+
+def batchify_data(tokens: np.ndarray, labels: np.ndarray, max_tokens: int,
+                    tokenizer: PreTrainedTokenizerFast, min_tokens: int = 3) -> (np.ndarray, np.ndarray, np.ndarray):
+
+    assert max_tokens >= min_tokens + 2
+    assert min_tokens >= 1
+
+    tokens_batch = []
+    labels_batch = []
+    mask_batch = []
+
+    idxs = batchify_labels(labels, max_tokens - 2, min_tokens)
+
+    for ids in idxs:
+        tokens_sample = tokens[ids, :]
+        labels_sample = labels[ids, :]
+
+        tokens_sample, labels_sample = add_cls_sep(tokens_sample, labels_sample, tokenizer)
+        tokens_sample, labels_sample, mask = add_padding(tokens_sample, labels_sample, max_tokens, tokenizer)
+
+        tokens_batch.append(tokens_sample)
+        labels_batch.append(labels_sample)
+        mask_batch.append(mask)
+
+    return np.array(tokens_batch), np.array(labels_batch), np.array(mask_batch)
 
 def recover_text(text: str, actions_encoded: np.ndarray):
     words = text.split(" ")
@@ -242,5 +300,3 @@ def recover_text(text: str, actions_encoded: np.ndarray):
         words_output.append(word_recovered)
 
     return " ".join(words_output)
-
-
diff --git a/dataset_generation/test_processing.py b/dataset_generation/test_processing.py
index 34064fe..573c6ca 100644
--- a/dataset_generation/test_processing.py
+++ b/dataset_generation/test_processing.py
@@ -2,6 +2,38 @@ import numpy
 from processing import *
 from transformers import PreTrainedTokenizerFast, BertTokenizerFast
 
+def test_detect_actions():
+    actions = detect_actions("Janek...", None)
+    assert actions == {
+        'dot': False,
+        'upper_case': True,
+        'colon': False,
+        'semicolon': False,
+        'elipsis': True,
+        'dash': False
+    }
+
+    actions = detect_actions("ewka.", None)
+    assert actions == {
+        'dot': True,
+        'upper_case': False,
+        'colon': False,
+        'semicolon': False,
+        'elipsis': False,
+        'dash': False
+    }
+
+    actions = detect_actions("Test", "-")
+    assert actions == {
+        'dot': False,
+        'upper_case': True,
+        'colon': False,
+        'semicolon': False,
+        'elipsis': False,
+        'dash': True
+    }
+
+
 def test_encode_actions():
     x = {
         'dot': True,
@@ -40,11 +72,52 @@ def test_tokenize_labeled_text():
     assert token_labels.shape[1] == len(ACTIONS_KEYS)
 
     assert len(tokens) == len(token_labels)
-    assert tokens[0, 0] == tokenizer.cls_token_id
-    assert tokens[-1, 0] == tokenizer.sep_token_id
+    assert tokens[0, 0] != tokenizer.cls_token_id
+    assert tokens[-1, 0] != tokenizer.sep_token_id
+
+def test_nearest_sentence_l():
+    end = create_dummy_action(True)
+    word = create_dummy_action(False)
 
-    assert np.all(token_labels[0] == token_labels[1])
-    assert np.all(token_labels[-1] == token_labels[-2])
+    entry = np.array([word, word, word, end, end, word, word, end])
+
+    assert nearest_sentence_l(entry, 3) == 0
+    assert nearest_sentence_l(entry, 4) == 0
+    assert nearest_sentence_l(entry, 5) == 5
+    assert nearest_sentence_l(entry, 7) == 5
+
+def create_dummy_action(end_sentence: bool) -> np.array:
+    return encode_actions({
+        'dot': end_sentence,
+        'upper_case': False,
+        'colon': False,
+        'semicolon': False,
+        'elipsis': False,
+        'dash': False
+    })
+
+def test_nearest_sentence_r():
+    end = create_dummy_action(True)
+    word = create_dummy_action(False)
+
+    entry = np.array([word, word, word, end, end, word, word, end])
+
+    assert nearest_sentence_r(entry, 0) == 0
+    assert nearest_sentence_r(entry, 4) == 5
+    assert nearest_sentence_r(entry, 5) == 5
+    assert nearest_sentence_r(entry, 6) is None
+    assert nearest_sentence_r(entry, 7) is None
+
+def test_batchify_labels():
+    end = create_dummy_action(True)
+    word = create_dummy_action(False)
+    entry = np.array([word, word, word, end, end, word, word, end])
+
+    batches = batchify_labels(entry, 3, 1)
+
+    assert len(batches) == 2
+    assert np.all(batches[0] == range(0, 3))
+    assert np.all(batches[1] == range(5, 8))
 
 def test_batchify_tokens():
     text = "Janek poszedł do ogrodu. Ogród był zwierzęcy. Spotkał tam niedzwiedzia..."
@@ -53,25 +126,42 @@ def test_batchify_tokens():
     text_clean, labels = create_model_input_output(text)
     tokens, token_labels = tokenize_labeled_text(text_clean, labels, tokenizer)
 
-    input_batch, output_batch = batchify_tokens(tokens, token_labels, 5)
+    # print(tokenizer.convert_ids_to_tokens(tokens.reshape(-1).astype(int)))
+    #print(token_labels)
+
+    input_batch, output_batch, mask_batch = batchify_data(tokens, token_labels, 5, tokenizer)
     
     assert len(input_batch.shape) == 3
     assert len(output_batch.shape) == 3
+    assert len(mask_batch.shape) == 2
 
     # First dimension should be batch size
     assert input_batch.shape[0] == output_batch.shape[0]
+    assert input_batch.shape[0] == mask_batch.shape[0]
+    assert input_batch.shape[0] > 1
 
     # Second dimension should be sequence length
     assert input_batch.shape[1] == 5
     assert output_batch.shape[1] == 5
+    assert mask_batch.shape[1] == 5
 
     # Third dimension should be feature size
     assert input_batch.shape[2] == 1
     assert output_batch.shape[2] == len(ACTIONS_KEYS)
+    
+    # Mask should be boolean (True - leave, False - mask)
+    assert mask_batch.dtype == np.bool
+    
+    # Should never be fully masked
+    assert np.all(mask_batch[:, 0] == True)
 
-    # Should always start from beginning of the sentence
     for i in range(input_batch.shape[0]):
+        # Should always start from beginning of the sentence
         assert decode_actions(output_batch[i, 0, :])['upper_case'] == True
         assert decode_actions(output_batch[i, 1, :])['upper_case'] == True
 
-    
\ No newline at end of file
+        # Should always end with sep and padding
+
+
+def generate_batches(files: (str, str), batch_size: int, max_tokens: int):
+    pass
\ No newline at end of file
-- 
GitLab
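
For reference, the batching helpers introduced in the patch above can be exercised in isolation. A minimal sketch, assuming only that the tokenizer object exposes the three special-token ids consumed by add_cls_sep and add_padding (the SimpleNamespace stand-in and the concrete id values are hypothetical, not taken from the repository):

    import numpy as np
    from types import SimpleNamespace
    from processing import add_cls_sep, add_padding

    # Hypothetical stand-in for PreTrainedTokenizerFast: only the special-token
    # ids used by add_cls_sep / add_padding are provided.
    toy_tokenizer = SimpleNamespace(cls_token_id=101, sep_token_id=102,
                                    pad_token_id=0)

    tokens = np.array([[11], [12], [13]])  # three tokens, one feature each
    labels = np.zeros((3, 6))              # one action vector per token

    tokens, labels = add_cls_sep(tokens, labels, toy_tokenizer)
    # tokens is now [[101], [11], [12], [13], [102]]; the first and last label
    # rows are duplicated so the cls and sep tokens also carry action vectors.

    tokens, labels, mask = add_padding(tokens, labels, 8, toy_tokenizer)
    # tokens and labels are padded to length 8 with pad ids / empty action
    # vectors; mask is [True]*5 + [False]*3, separating tokens from padding.

batchify_data applies exactly these two steps to every index window produced by batchify_labels, so each batch entry starts at a sentence boundary whenever one is reachable.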


From 30c7400df6bc027a94e7bbf60bcc0ec7f83dd256 Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Sat, 11 Jul 2020 16:47:44 +0200
Subject: [PATCH 006/116] Task for generating whole dataset

---
 dataset_generation/generate_actions.py | 50 ++++++++++++++
 dataset_generation/processing.py       | 96 +++++++++++++++++++++++++-
 2 files changed, 145 insertions(+), 1 deletion(-)
 create mode 100644 dataset_generation/generate_actions.py

diff --git a/dataset_generation/generate_actions.py b/dataset_generation/generate_actions.py
new file mode 100644
index 0000000..5a39d2b
--- /dev/null
+++ b/dataset_generation/generate_actions.py
@@ -0,0 +1,50 @@
+#!/usr/bin/python3
+import os
+import glob
+import random
+from lxml import etree
+import uuid
+import hashlib
+import seaborn as sns
+import re
+import numpy as np
+from tqdm import tqdm
+from processing import text_from_xml, create_model_input_output
+from utils import remove_multiple_spaces, remove_punctuation
+import dask
+from dask.diagnostics import ProgressBar
+
+file_schema = "../dane/**/text_structure.xml"
+files_paths = glob.glob(file_schema, recursive=True)
+
+files_subset = files_paths #random.sample(files_paths, 1_000)
+
+document_lens = []
+
+if not os.path.exists("../dataset_actions"):
+    os.mkdir("../dataset_actions")
+
+def process_file(file_path):
+    full_text = text_from_xml(file_path)
+    
+    if len(full_text) > 0:
+        output_file_input = f"../dataset_actions/{hashlib.md5(file_path.encode()).hexdigest()}_input.txt"
+        output_file_output = f"../dataset_actions/{hashlib.md5(file_path.encode()).hexdigest()}_output.txt"
+
+        model_input, model_output = create_model_input_output(full_text)
+
+        with open(output_file_input, "w") as f:
+            f.write(model_input)
+            document_lens.append(len(full_text))
+
+        with open(output_file_output, 'w') as f:
+            f.write(str(model_output.tolist()))
+
+tasks = []
+for file_path in tqdm(files_subset):
+    tasks.append(dask.delayed(process_file)(file_path))
+
+with ProgressBar():
+    results = dask.compute(*tasks)
+
+
diff --git a/dataset_generation/processing.py b/dataset_generation/processing.py
index fe4e23a..9a9f44d 100644
--- a/dataset_generation/processing.py
+++ b/dataset_generation/processing.py
@@ -15,6 +15,11 @@ from transformers import PreTrainedTokenizerFast, BertTokenizerFast
 ACTIONS_KEYS = ['dot', 'upper_case', 'colon', 'semicolon', 'elipsis', 'dash']
 
 def empty_action_vector() -> np.ndarray:
+    """Returns a do-nothing actions vector
+
+    Returns:
+        np.ndarray: Vector with all zeroes and length of ACTION_KEYS
+    """
     return np.zeros(len(ACTIONS_KEYS))
 
 def text_from_xml(path: str) -> str:
@@ -59,7 +64,7 @@ def detect_actions(word: str, next_word: Optional[str]) -> Mapping[str, bool]:
         word = word[1:]
 
     if len(word) == 0:
-        return zip(ACTIONS_KEYS, [False] * len(ACTIONS_KEYS))
+        return dict(zip(ACTIONS_KEYS, [False] * len(ACTIONS_KEYS)))
 
     has_ellipsis = len(word) > 3 and word[-3:] == "..."
 
@@ -133,6 +138,17 @@ def create_model_input_output(text: str) -> (str, np.ndarray):
     return " ".join(words_output), np.array(actions_output)
 
 def tokenize_labeled_text(text: str, labels: np.ndarray, tokenizer: PreTrainedTokenizerFast) -> (np.ndarray, np.ndarray):
+    """Transforms text into numerical tokens. Also expand word-level labels into token-level labels
+
+    Args:
+        text (str): Text that will be tokenized (TODO: Change to array)
+        labels (np.ndarray): Word-level labels for text to be tokenized. Word is defined via space spearation 
+        tokenizer (PreTrainedTokenizerFast): Tokenizer that will be used for tokenization
+
+    Returns:
+        np.ndarray: 2-dimensional array with tokens (without cls and sep tokens!)
+        np.ndarray 2-dimensional array with token-level labels
+    """
     text_tokenized = tokenizer(text, return_offsets_mapping=True)
 
     offset_mappings = text_tokenized['offset_mapping'][1:-1]
@@ -156,6 +172,15 @@ def tokenize_labeled_text(text: str, labels: np.ndarray, tokenizer: PreTrainedTo
 
 
 def recover_word(word: str, action: Mapping[str, bool]) -> str:
+    """Applies action to a word
+
+    Args:
+        word (str): word on which action will be applied
+        action (Mapping[str, bool]): Action to be applied
+
+    Returns:
+        str: transformed word
+    """
     word_result = word
     
     if action['dot']:
@@ -174,12 +199,30 @@ def recover_word(word: str, action: Mapping[str, bool]) -> str:
     return word
 
 def is_sentence_end(actions_encoded: np.ndarray) -> bool:
+    """Returns if given action would end a sentence
+
+    Args:
+        actions_encoded (np.ndarray): Action vector
+
+    Returns:
+        bool: True if action would end a sentence, False otherwise
+    """
     actions_decoded = decode_actions(actions_encoded)
 
     return (actions_decoded['dot']
             or actions_decoded['elipsis'])
 
 def nearest_sentence_l(labels: np.array, index_start: int) -> int:
+    """Find nearest word that begins a sentence that has lower or equal index to index_start
+
+    Args:
+        labels (np.array): 2-dimensonal array of action-vectors
+        index_start (int): Index from which search will be started
+
+    Returns:
+        int: Index of nearest left-oriented start of the sentence. If no sentence is found, first index is assumed to
+             start a sentence
+    """
     result_index = index_start
 
     while result_index > 0:
@@ -197,6 +240,15 @@ def nearest_sentence_l(labels: np.array, index_start: int) -> int:
     return result_index
 
 def nearest_sentence_r(labels: np.array, index_start: int) -> Optional[int]:
+    """Find nearest word that begins a sentence that has higher or equal index to index_start
+
+    Args:
+        labels (np.array): 2-dimensonal array of action-vectors
+        index_start (int): Index from which search will be started
+
+    Returns:
+        int: Index of nearest right-oriented start of the sentence. None if no later sentence is found
+    """
     result_index = index_start
 
     while result_index < len(labels):
@@ -211,6 +263,16 @@ def nearest_sentence_r(labels: np.array, index_start: int) -> Optional[int]:
         return result_index
 
 def batchify_labels(labels: np.ndarray, max_tokens: int, min_tokens: int = 3) -> [np.ndarray]:
+    """Splits long labels array into batches of desired size
+
+    Args:
+        labels (np.ndarray): 2-dimensional array of action-vectors
+        max_tokens (int): Maximum number of labels in a single batch
+        min_tokens (int, optional): Minimum number of labels in a single batch. Defaults to 3.
+
+    Returns:
+        [np.ndarray]: List of arrays with the indices composing each batch
+    """
     assert min_tokens >= 1
     assert max_tokens >= 1
 
@@ -241,6 +303,16 @@ def batchify_labels(labels: np.ndarray, max_tokens: int, min_tokens: int = 3) ->
     return labels_batches
 
 def add_cls_sep(tokens: np.ndarray, labels: np.ndarray, tokenizer: PreTrainedTokenizerFast) -> (np.ndarray, np.ndarray):
+    """Adds staring cls and ending sep token ids into tokens & labels
+
+    Args:
+        tokens (np.ndarray): 2-dimensional array (with 1 feature!) of tokens
+        labels (np.ndarray): 2-dimensional array of action vectors
+
+    Returns:
+        np.ndarray: tokens with added cls & sep tokens ids
+        np.ndarray: labels with first and last item duplicated to accomodate for cls & sep
+    """
     
     tokens = np.concatenate([[[tokenizer.cls_token_id]], tokens, [[tokenizer.sep_token_id]]])
     labels = np.concatenate([labels[:1, :], labels, labels[-1:, :]])
@@ -248,6 +320,19 @@ def add_cls_sep(tokens: np.ndarray, labels: np.ndarray, tokenizer: PreTrainedTok
     return tokens, labels
 
 def add_padding(tokens: np.ndarray, labels: np.ndarray, length: int, tokenizer: PreTrainedTokenizerFast) -> (np.ndarray, np.ndarray, np.ndarray):
+    """Appends padding to tokens and labels to match desired length
+
+    Args:
+        tokens (np.ndarray): Lx1 array of token ids
+        labels (np.ndarray): LxA array of action vectors
+        length (int): Desired output length. Must be greater than or equal to L
+        tokenizer (PreTrainedTokenizerFast): Tokenizer that was used for tokenization
+
+    Returns:
+        np.ndarray: (L+P)x1 array of token ids with added padding
+        np.ndarray: (L+P)xA array of action vectors with added padding
+        np.ndarray: (L+P)-length boolean array where True marks a real token and False marks padding
+    """
 
     pad_length = length - tokens.shape[0]
     assert pad_length >= 0
@@ -265,6 +350,15 @@ def add_padding(tokens: np.ndarray, labels: np.ndarray, length: int, tokenizer:
 
 def batchify_data(tokens: np.ndarray, labels: np.ndarray, max_tokens: int,
                     tokenizer: PreTrainedTokenizerFast, min_tokens: int = 3) -> (np.ndarray, np.ndarray, np.ndarray):
+    """Transforms tokens and labels into a batch
+
+    Args:
+        np ([type]): [description]
+        tokens (np.ndarray, labels, optional): [description]. Defaults to 3)->(np.ndarray.
+
+    Returns:
+        [type]: [description]
+    """
 
     assert max_tokens >= min_tokens + 2
     assert min_tokens >= 1
-- 
GitLab
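
Taken together, the functions documented in this patch form a pipeline from raw text to fixed-size training batches. A minimal end-to-end sketch, assuming a fast BERT tokenizer is available (the multilingual checkpoint name is illustrative only; the patches so far do not pin a specific model):

    from transformers import BertTokenizerFast
    from processing import (ACTIONS_KEYS, create_model_input_output,
                            tokenize_labeled_text, batchify_data)

    # Illustrative checkpoint; any fast BERT-style tokenizer should work here.
    tokenizer = BertTokenizerFast.from_pretrained("bert-base-multilingual-cased")

    text = "Janek poszedł do ogrodu. Spotkał tam niedzwiedzia..."

    # Strip punctuation and casing, recording the actions needed to restore them.
    text_clean, labels = create_model_input_output(text)

    # Expand the word-level action vectors to token-level ones (no cls/sep yet).
    tokens, token_labels = tokenize_labeled_text(text_clean, labels, tokenizer)

    # Cut into sentence-aligned windows, add cls/sep, pad, and build the masks.
    inputs, outputs, masks = batchify_data(tokens, token_labels, 32, tokenizer)
    # inputs: B x 32 x 1 token ids, outputs: B x 32 x len(ACTIONS_KEYS),
    # masks: B x 32 boolean padding mask.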


From 3d4cd65659cb7e85c7c312e884760cc5d9e7f4ec Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Thu, 16 Jul 2020 14:40:25 +0200
Subject: [PATCH 007/116] Added pipeline final processing

---
 .gitignore                                |   6 +-
 {dane => data}/download_dataset.sh        |   0
 dataset_generation/generate_actions.py    |  50 -------
 dataset_generation/notebook_actions.ipynb | 158 ----------------------
 dataset_generation/notebook_simple.ipynb  | 152 ---------------------
 dataset_generation/processing.py          |  10 +-
 dataset_generation/stage1_extraction.py   |  35 +++++
 dataset_generation/stage2_tokenization.py |  58 ++++++++
 dataset_generation/stage3_spliting.py     |  48 +++++++
 dataset_generation/stage4_exploding.py    |  50 +++++++
 dataset_generation/test_processing.py     |  45 +++---
 generated/.gitignore                      |   2 +
 12 files changed, 232 insertions(+), 382 deletions(-)
 rename {dane => data}/download_dataset.sh (100%)
 delete mode 100644 dataset_generation/generate_actions.py
 delete mode 100644 dataset_generation/notebook_actions.ipynb
 delete mode 100644 dataset_generation/notebook_simple.ipynb
 create mode 100644 dataset_generation/stage1_extraction.py
 create mode 100644 dataset_generation/stage2_tokenization.py
 create mode 100644 dataset_generation/stage3_spliting.py
 create mode 100644 dataset_generation/stage4_exploding.py
 create mode 100644 generated/.gitignore

diff --git a/.gitignore b/.gitignore
index e75c3e1..1633793 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,7 @@
 dane/**
 dataset_simple
-dataset_actions
\ No newline at end of file
+dataset_actions
+**/dask-worker-space
+.vscode
+.idea
+test.ipynb
\ No newline at end of file
diff --git a/dane/download_dataset.sh b/data/download_dataset.sh
similarity index 100%
rename from dane/download_dataset.sh
rename to data/download_dataset.sh
diff --git a/dataset_generation/generate_actions.py b/dataset_generation/generate_actions.py
deleted file mode 100644
index 5a39d2b..0000000
--- a/dataset_generation/generate_actions.py
+++ /dev/null
@@ -1,50 +0,0 @@
-#!/usr/bin/python3
-import os
-import glob
-import random
-from lxml import etree
-import uuid
-import hashlib
-import seaborn as sns
-import re
-import numpy as np
-from tqdm import tqdm
-from processing import text_from_xml, create_model_input_output
-from utils import remove_multiple_spaces, remove_punctuation
-import dask
-from dask.diagnostics import ProgressBar
-
-file_schema = "../dane/**/text_structure.xml"
-files_paths = glob.glob(file_schema, recursive=True)
-
-files_subset = files_paths #random.sample(files_paths, 1_000)
-
-document_lens = []
-
-if not os.path.exists("../dataset_actions"):
-    os.mkdir("../dataset_actions")
-
-def process_file(file_path):
-    full_text = text_from_xml(file_path)
-    
-    if len(full_text) > 0:
-        output_file_input = f"../dataset_actions/{hashlib.md5(file_path.encode()).hexdigest()}_input.txt"
-        output_file_output = f"../dataset_actions/{hashlib.md5(file_path.encode()).hexdigest()}_output.txt"
-
-        model_input, model_output = create_model_input_output(full_text)
-
-        with open(output_file_input, "w") as f:
-            f.write(model_input)
-            document_lens.append(len(full_text))
-
-        with open(output_file_output, 'w') as f:
-            f.write(str(model_output.tolist()))
-
-tasks = []
-for file_path in tqdm(files_subset):
-    tasks.append(dask.delayed(process_file)(file_path))
-
-with ProgressBar():
-    results = dask.compute(*tasks)
-
-
diff --git a/dataset_generation/notebook_actions.ipynb b/dataset_generation/notebook_actions.ipynb
deleted file mode 100644
index 20a28ac..0000000
--- a/dataset_generation/notebook_actions.ipynb
+++ /dev/null
@@ -1,158 +0,0 @@
-{
- "cells": [
-  {
-   "cell_type": "code",
-   "execution_count": 14,
-   "metadata": {
-    "tags": []
-   },
-   "outputs": [
-    {
-     "output_type": "stream",
-     "name": "stdout",
-     "text": "The autoreload extension is already loaded. To reload it, use:\n  %reload_ext autoreload\n"
-    }
-   ],
-   "source": [
-    "%load_ext autoreload\n",
-    "%autoreload 2"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 15,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "import glob\n",
-    "import random\n",
-    "from lxml import etree\n",
-    "import uuid\n",
-    "import hashlib\n",
-    "import seaborn as sns\n",
-    "import re\n",
-    "import numpy as np\n",
-    "from tqdm import tqdm\n",
-    "\n",
-    "from processing import text_from_xml, create_model_input_output\n",
-    "from utils import remove_multiple_spaces, remove_punctuation"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 16,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "file_schema = \"../dane/**/text_structure.xml\"\n",
-    "files_paths = glob.glob(file_schema, recursive=True)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 17,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "files_subset = random.sample(files_paths, 1_000)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 18,
-   "metadata": {
-    "tags": []
-   },
-   "outputs": [
-    {
-     "output_type": "stream",
-     "name": "stderr",
-     "text": "100%|██████████| 1000/1000 [00:07<00:00, 133.04it/s]\n"
-    }
-   ],
-   "source": [
-    "num_exported = 0\n",
-    "document_lens = []\n",
-    "\n",
-    "if not os.path.exists(\"../dataset_actions\"):\n",
-    "    os.mkdir(\"../dataset_actions\")\n",
-    "\n",
-    "for file_path in tqdm(files_subset):\n",
-    "    full_text = text_from_xml(file_path)\n",
-    "    \n",
-    "    if len(full_text) > 0:\n",
-    "        output_file_input = f\"../dataset_actions/{hashlib.md5(file_path.encode()).hexdigest()}_input.txt\"\n",
-    "        output_file_output = f\"../dataset_actions/{hashlib.md5(file_path.encode()).hexdigest()}_output.txt\"\n",
-    "\n",
-    "        model_input, model_output = create_model_input_output(full_text)\n",
-    "\n",
-    "        with open(output_file_input, \"w\") as f:\n",
-    "            f.write(model_input)\n",
-    "            num_exported += 1\n",
-    "            document_lens.append(len(full_text))\n",
-    "\n",
-    "        with open(output_file_output, 'w') as f:\n",
-    "            f.write(str(model_output.tolist()))"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 19,
-   "metadata": {
-    "tags": []
-   },
-   "outputs": [
-    {
-     "output_type": "execute_result",
-     "data": {
-      "text/plain": "<matplotlib.axes._subplots.AxesSubplot at 0x7f83c19294c0>"
-     },
-     "metadata": {},
-     "execution_count": 19
-    },
-    {
-     "output_type": "display_data",
-     "data": {
-      "text/plain": "<Figure size 432x288 with 1 Axes>",
-      "image/svg+xml": "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n  \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n<!-- Created with matplotlib (https://matplotlib.org/) -->\n<svg height=\"248.518125pt\" version=\"1.1\" viewBox=\"0 0 368.925 248.518125\" width=\"368.925pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n <defs>\n  <style type=\"text/css\">\n*{stroke-linecap:butt;stroke-linejoin:round;}\n  </style>\n </defs>\n <g id=\"figure_1\">\n  <g id=\"patch_1\">\n   <path d=\"M -0 248.518125 \nL 368.925 248.518125 \nL 368.925 0 \nL -0 0 \nz\n\" style=\"fill:none;\"/>\n  </g>\n  <g id=\"axes_1\">\n   <g id=\"patch_2\">\n    <path d=\"M 26.925 224.64 \nL 361.725 224.64 \nL 361.725 7.2 \nL 26.925 7.2 \nz\n\" style=\"fill:#ffffff;\"/>\n   </g>\n   <g id=\"patch_3\">\n    <path clip-path=\"url(#p920417c74f)\" d=\"M 42.143182 224.64 \nL 55.376383 224.64 \nL 55.376383 17.554286 \nL 42.143182 17.554286 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_4\">\n    <path clip-path=\"url(#p920417c74f)\" d=\"M 55.376383 224.64 \nL 68.609585 224.64 \nL 68.609585 34.121143 \nL 55.376383 34.121143 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_5\">\n    <path clip-path=\"url(#p920417c74f)\" d=\"M 68.609585 224.64 \nL 81.842787 224.64 \nL 81.842787 83.821714 \nL 68.609585 83.821714 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_6\">\n    <path clip-path=\"url(#p920417c74f)\" d=\"M 81.842787 224.64 \nL 95.075988 224.64 \nL 95.075988 150.089143 \nL 81.842787 150.089143 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_7\">\n    <path clip-path=\"url(#p920417c74f)\" d=\"M 95.075988 224.64 \nL 108.30919 224.64 \nL 108.30919 216.356571 \nL 95.075988 216.356571 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_8\">\n    <path clip-path=\"url(#p920417c74f)\" d=\"M 108.30919 224.64 \nL 121.542391 224.64 \nL 121.542391 224.64 \nL 108.30919 224.64 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_9\">\n    <path clip-path=\"url(#p920417c74f)\" d=\"M 121.542391 224.64 \nL 134.775593 224.64 \nL 134.775593 191.506286 \nL 121.542391 191.506286 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_10\">\n    <path clip-path=\"url(#p920417c74f)\" d=\"M 134.775593 224.64 \nL 148.008794 224.64 \nL 148.008794 208.073143 \nL 134.775593 208.073143 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_11\">\n    <path clip-path=\"url(#p920417c74f)\" d=\"M 148.008794 224.64 \nL 161.241996 224.64 \nL 161.241996 224.64 \nL 148.008794 224.64 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_12\">\n    <path clip-path=\"url(#p920417c74f)\" d=\"M 161.241996 224.64 \nL 174.475198 224.64 \nL 174.475198 208.073143 \nL 161.241996 208.073143 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_13\">\n    <path clip-path=\"url(#p920417c74f)\" d=\"M 174.475198 224.64 \nL 187.708399 224.64 \nL 187.708399 216.356571 \nL 174.475198 216.356571 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_14\">\n    <path clip-path=\"url(#p920417c74f)\" d=\"M 187.708399 224.64 \nL 200.941601 224.64 \nL 200.941601 224.64 \nL 187.708399 224.64 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_15\">\n    <path clip-path=\"url(#p920417c74f)\" d=\"M 200.941601 224.64 \nL 
214.174802 224.64 \nL 214.174802 224.64 \nL 200.941601 224.64 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_16\">\n    <path clip-path=\"url(#p920417c74f)\" d=\"M 214.174802 224.64 \nL 227.408004 224.64 \nL 227.408004 216.356571 \nL 214.174802 216.356571 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_17\">\n    <path clip-path=\"url(#p920417c74f)\" d=\"M 227.408004 224.64 \nL 240.641206 224.64 \nL 240.641206 224.64 \nL 227.408004 224.64 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_18\">\n    <path clip-path=\"url(#p920417c74f)\" d=\"M 240.641206 224.64 \nL 253.874407 224.64 \nL 253.874407 224.64 \nL 240.641206 224.64 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_19\">\n    <path clip-path=\"url(#p920417c74f)\" d=\"M 253.874407 224.64 \nL 267.107609 224.64 \nL 267.107609 216.356571 \nL 253.874407 216.356571 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_20\">\n    <path clip-path=\"url(#p920417c74f)\" d=\"M 267.107609 224.64 \nL 280.34081 224.64 \nL 280.34081 224.64 \nL 267.107609 224.64 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_21\">\n    <path clip-path=\"url(#p920417c74f)\" d=\"M 280.34081 224.64 \nL 293.574012 224.64 \nL 293.574012 208.073143 \nL 280.34081 208.073143 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_22\">\n    <path clip-path=\"url(#p920417c74f)\" d=\"M 293.574012 224.64 \nL 306.807213 224.64 \nL 306.807213 224.64 \nL 293.574012 224.64 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_23\">\n    <path clip-path=\"url(#p920417c74f)\" d=\"M 306.807213 224.64 \nL 320.040415 224.64 \nL 320.040415 216.356571 \nL 306.807213 216.356571 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_24\">\n    <path clip-path=\"url(#p920417c74f)\" d=\"M 320.040415 224.64 \nL 333.273617 224.64 \nL 333.273617 224.64 \nL 320.040415 224.64 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_25\">\n    <path clip-path=\"url(#p920417c74f)\" d=\"M 333.273617 224.64 \nL 346.506818 224.64 \nL 346.506818 208.073143 \nL 333.273617 208.073143 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"matplotlib.axis_1\">\n    <g id=\"xtick_1\">\n     <g id=\"line2d_1\">\n      <defs>\n       <path d=\"M 0 0 \nL 0 3.5 \n\" id=\"mabeb11051b\" style=\"stroke:#000000;stroke-width:0.8;\"/>\n      </defs>\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"41.448148\" xlink:href=\"#mabeb11051b\" y=\"224.64\"/>\n      </g>\n     </g>\n     <g id=\"text_1\">\n      <!-- 0 -->\n      <defs>\n       <path d=\"M 31.78125 66.40625 \nQ 24.171875 66.40625 20.328125 58.90625 \nQ 16.5 51.421875 16.5 36.375 \nQ 16.5 21.390625 20.328125 13.890625 \nQ 24.171875 6.390625 31.78125 6.390625 \nQ 39.453125 6.390625 43.28125 13.890625 \nQ 47.125 21.390625 47.125 36.375 \nQ 47.125 51.421875 43.28125 58.90625 \nQ 39.453125 66.40625 31.78125 66.40625 \nz\nM 31.78125 74.21875 \nQ 44.046875 74.21875 50.515625 64.515625 \nQ 56.984375 54.828125 56.984375 36.375 \nQ 56.984375 17.96875 50.515625 8.265625 \nQ 44.046875 -1.421875 31.78125 -1.421875 \nQ 19.53125 -1.421875 13.0625 8.265625 \nQ 6.59375 17.96875 6.59375 36.375 \nQ 6.59375 54.828125 13.0625 64.515625 \nQ 19.53125 74.21875 31.78125 74.21875 \nz\n\" id=\"DejaVuSans-48\"/>\n      </defs>\n      <g transform=\"translate(38.266898 239.238438)scale(0.1 -0.1)\">\n       <use 
xlink:href=\"#DejaVuSans-48\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"xtick_2\">\n     <g id=\"line2d_2\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"89.118482\" xlink:href=\"#mabeb11051b\" y=\"224.64\"/>\n      </g>\n     </g>\n     <g id=\"text_2\">\n      <!-- 100000 -->\n      <defs>\n       <path d=\"M 12.40625 8.296875 \nL 28.515625 8.296875 \nL 28.515625 63.921875 \nL 10.984375 60.40625 \nL 10.984375 69.390625 \nL 28.421875 72.90625 \nL 38.28125 72.90625 \nL 38.28125 8.296875 \nL 54.390625 8.296875 \nL 54.390625 0 \nL 12.40625 0 \nz\n\" id=\"DejaVuSans-49\"/>\n      </defs>\n      <g transform=\"translate(70.030982 239.238438)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-49\"/>\n       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"127.246094\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"190.869141\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"254.492188\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"318.115234\" xlink:href=\"#DejaVuSans-48\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"xtick_3\">\n     <g id=\"line2d_3\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"136.788815\" xlink:href=\"#mabeb11051b\" y=\"224.64\"/>\n      </g>\n     </g>\n     <g id=\"text_3\">\n      <!-- 200000 -->\n      <defs>\n       <path d=\"M 19.1875 8.296875 \nL 53.609375 8.296875 \nL 53.609375 0 \nL 7.328125 0 \nL 7.328125 8.296875 \nQ 12.9375 14.109375 22.625 23.890625 \nQ 32.328125 33.6875 34.8125 36.53125 \nQ 39.546875 41.84375 41.421875 45.53125 \nQ 43.3125 49.21875 43.3125 52.78125 \nQ 43.3125 58.59375 39.234375 62.25 \nQ 35.15625 65.921875 28.609375 65.921875 \nQ 23.96875 65.921875 18.8125 64.3125 \nQ 13.671875 62.703125 7.8125 59.421875 \nL 7.8125 69.390625 \nQ 13.765625 71.78125 18.9375 73 \nQ 24.125 74.21875 28.421875 74.21875 \nQ 39.75 74.21875 46.484375 68.546875 \nQ 53.21875 62.890625 53.21875 53.421875 \nQ 53.21875 48.921875 51.53125 44.890625 \nQ 49.859375 40.875 45.40625 35.40625 \nQ 44.1875 33.984375 37.640625 27.21875 \nQ 31.109375 20.453125 19.1875 8.296875 \nz\n\" id=\"DejaVuSans-50\"/>\n      </defs>\n      <g transform=\"translate(117.701315 239.238438)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-50\"/>\n       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"127.246094\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"190.869141\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"254.492188\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"318.115234\" xlink:href=\"#DejaVuSans-48\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"xtick_4\">\n     <g id=\"line2d_4\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"184.459148\" xlink:href=\"#mabeb11051b\" y=\"224.64\"/>\n      </g>\n     </g>\n     <g id=\"text_4\">\n      <!-- 300000 -->\n      <defs>\n       <path d=\"M 40.578125 39.3125 \nQ 47.65625 37.796875 51.625 33 \nQ 55.609375 28.21875 55.609375 21.1875 \nQ 55.609375 10.40625 48.1875 4.484375 \nQ 40.765625 -1.421875 27.09375 -1.421875 \nQ 22.515625 -1.421875 17.65625 -0.515625 \nQ 12.796875 0.390625 7.625 2.203125 \nL 7.625 11.71875 \nQ 11.71875 9.328125 16.59375 8.109375 \nQ 21.484375 6.890625 26.8125 6.890625 \nQ 36.078125 6.890625 40.9375 10.546875 \nQ 45.796875 14.203125 45.796875 21.1875 \nQ 45.796875 27.640625 41.28125 31.265625 \nQ 36.765625 34.90625 28.71875 34.90625 \nL 20.21875 34.90625 \nL 20.21875 43.015625 \nL 29.109375 43.015625 \nQ 36.375 43.015625 40.234375 45.921875 \nQ 44.09375 48.828125 
44.09375 54.296875 \nQ 44.09375 59.90625 40.109375 62.90625 \nQ 36.140625 65.921875 28.71875 65.921875 \nQ 24.65625 65.921875 20.015625 65.03125 \nQ 15.375 64.15625 9.8125 62.3125 \nL 9.8125 71.09375 \nQ 15.4375 72.65625 20.34375 73.4375 \nQ 25.25 74.21875 29.59375 74.21875 \nQ 40.828125 74.21875 47.359375 69.109375 \nQ 53.90625 64.015625 53.90625 55.328125 \nQ 53.90625 49.265625 50.4375 45.09375 \nQ 46.96875 40.921875 40.578125 39.3125 \nz\n\" id=\"DejaVuSans-51\"/>\n      </defs>\n      <g transform=\"translate(165.371648 239.238438)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-51\"/>\n       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"127.246094\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"190.869141\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"254.492188\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"318.115234\" xlink:href=\"#DejaVuSans-48\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"xtick_5\">\n     <g id=\"line2d_5\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"232.129481\" xlink:href=\"#mabeb11051b\" y=\"224.64\"/>\n      </g>\n     </g>\n     <g id=\"text_5\">\n      <!-- 400000 -->\n      <defs>\n       <path d=\"M 37.796875 64.3125 \nL 12.890625 25.390625 \nL 37.796875 25.390625 \nz\nM 35.203125 72.90625 \nL 47.609375 72.90625 \nL 47.609375 25.390625 \nL 58.015625 25.390625 \nL 58.015625 17.1875 \nL 47.609375 17.1875 \nL 47.609375 0 \nL 37.796875 0 \nL 37.796875 17.1875 \nL 4.890625 17.1875 \nL 4.890625 26.703125 \nz\n\" id=\"DejaVuSans-52\"/>\n      </defs>\n      <g transform=\"translate(213.041981 239.238438)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-52\"/>\n       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"127.246094\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"190.869141\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"254.492188\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"318.115234\" xlink:href=\"#DejaVuSans-48\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"xtick_6\">\n     <g id=\"line2d_6\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"279.799814\" xlink:href=\"#mabeb11051b\" y=\"224.64\"/>\n      </g>\n     </g>\n     <g id=\"text_6\">\n      <!-- 500000 -->\n      <defs>\n       <path d=\"M 10.796875 72.90625 \nL 49.515625 72.90625 \nL 49.515625 64.59375 \nL 19.828125 64.59375 \nL 19.828125 46.734375 \nQ 21.96875 47.46875 24.109375 47.828125 \nQ 26.265625 48.1875 28.421875 48.1875 \nQ 40.625 48.1875 47.75 41.5 \nQ 54.890625 34.8125 54.890625 23.390625 \nQ 54.890625 11.625 47.5625 5.09375 \nQ 40.234375 -1.421875 26.90625 -1.421875 \nQ 22.3125 -1.421875 17.546875 -0.640625 \nQ 12.796875 0.140625 7.71875 1.703125 \nL 7.71875 11.625 \nQ 12.109375 9.234375 16.796875 8.0625 \nQ 21.484375 6.890625 26.703125 6.890625 \nQ 35.15625 6.890625 40.078125 11.328125 \nQ 45.015625 15.765625 45.015625 23.390625 \nQ 45.015625 31 40.078125 35.4375 \nQ 35.15625 39.890625 26.703125 39.890625 \nQ 22.75 39.890625 18.8125 39.015625 \nQ 14.890625 38.140625 10.796875 36.28125 \nz\n\" id=\"DejaVuSans-53\"/>\n      </defs>\n      <g transform=\"translate(260.712314 239.238438)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-53\"/>\n       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"127.246094\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"190.869141\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"254.492188\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"318.115234\" 
xlink:href=\"#DejaVuSans-48\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"xtick_7\">\n     <g id=\"line2d_7\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"327.470147\" xlink:href=\"#mabeb11051b\" y=\"224.64\"/>\n      </g>\n     </g>\n     <g id=\"text_7\">\n      <!-- 600000 -->\n      <defs>\n       <path d=\"M 33.015625 40.375 \nQ 26.375 40.375 22.484375 35.828125 \nQ 18.609375 31.296875 18.609375 23.390625 \nQ 18.609375 15.53125 22.484375 10.953125 \nQ 26.375 6.390625 33.015625 6.390625 \nQ 39.65625 6.390625 43.53125 10.953125 \nQ 47.40625 15.53125 47.40625 23.390625 \nQ 47.40625 31.296875 43.53125 35.828125 \nQ 39.65625 40.375 33.015625 40.375 \nz\nM 52.59375 71.296875 \nL 52.59375 62.3125 \nQ 48.875 64.0625 45.09375 64.984375 \nQ 41.3125 65.921875 37.59375 65.921875 \nQ 27.828125 65.921875 22.671875 59.328125 \nQ 17.53125 52.734375 16.796875 39.40625 \nQ 19.671875 43.65625 24.015625 45.921875 \nQ 28.375 48.1875 33.59375 48.1875 \nQ 44.578125 48.1875 50.953125 41.515625 \nQ 57.328125 34.859375 57.328125 23.390625 \nQ 57.328125 12.15625 50.6875 5.359375 \nQ 44.046875 -1.421875 33.015625 -1.421875 \nQ 20.359375 -1.421875 13.671875 8.265625 \nQ 6.984375 17.96875 6.984375 36.375 \nQ 6.984375 53.65625 15.1875 63.9375 \nQ 23.390625 74.21875 37.203125 74.21875 \nQ 40.921875 74.21875 44.703125 73.484375 \nQ 48.484375 72.75 52.59375 71.296875 \nz\n\" id=\"DejaVuSans-54\"/>\n      </defs>\n      <g transform=\"translate(308.382647 239.238438)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-54\"/>\n       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"127.246094\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"190.869141\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"254.492188\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"318.115234\" xlink:href=\"#DejaVuSans-48\"/>\n      </g>\n     </g>\n    </g>\n   </g>\n   <g id=\"matplotlib.axis_2\">\n    <g id=\"ytick_1\">\n     <g id=\"line2d_8\">\n      <defs>\n       <path d=\"M 0 0 \nL -3.5 0 \n\" id=\"m5ee8100b6b\" style=\"stroke:#000000;stroke-width:0.8;\"/>\n      </defs>\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"26.925\" xlink:href=\"#m5ee8100b6b\" y=\"224.64\"/>\n      </g>\n     </g>\n     <g id=\"text_8\">\n      <!-- 0 -->\n      <g transform=\"translate(13.5625 228.439219)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-48\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"ytick_2\">\n     <g id=\"line2d_9\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"26.925\" xlink:href=\"#m5ee8100b6b\" y=\"183.222857\"/>\n      </g>\n     </g>\n     <g id=\"text_9\">\n      <!-- 5 -->\n      <g transform=\"translate(13.5625 187.022076)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-53\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"ytick_3\">\n     <g id=\"line2d_10\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"26.925\" xlink:href=\"#m5ee8100b6b\" y=\"141.805714\"/>\n      </g>\n     </g>\n     <g id=\"text_10\">\n      <!-- 10 -->\n      <g transform=\"translate(7.2 145.604933)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-49\"/>\n       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-48\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"ytick_4\">\n     <g id=\"line2d_11\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"26.925\" xlink:href=\"#m5ee8100b6b\" y=\"100.388571\"/>\n      </g>\n     </g>\n     <g id=\"text_11\">\n      <!-- 15 -->\n      <g 
transform=\"translate(7.2 104.18779)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-49\"/>\n       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-53\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"ytick_5\">\n     <g id=\"line2d_12\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"26.925\" xlink:href=\"#m5ee8100b6b\" y=\"58.971429\"/>\n      </g>\n     </g>\n     <g id=\"text_12\">\n      <!-- 20 -->\n      <g transform=\"translate(7.2 62.770647)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-50\"/>\n       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-48\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"ytick_6\">\n     <g id=\"line2d_13\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"26.925\" xlink:href=\"#m5ee8100b6b\" y=\"17.554286\"/>\n      </g>\n     </g>\n     <g id=\"text_13\">\n      <!-- 25 -->\n      <g transform=\"translate(7.2 21.353504)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-50\"/>\n       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-53\"/>\n      </g>\n     </g>\n    </g>\n   </g>\n   <g id=\"patch_26\">\n    <path d=\"M 26.925 224.64 \nL 26.925 7.2 \n\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\n   </g>\n   <g id=\"patch_27\">\n    <path d=\"M 361.725 224.64 \nL 361.725 7.2 \n\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\n   </g>\n   <g id=\"patch_28\">\n    <path d=\"M 26.925 224.64 \nL 361.725 224.64 \n\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\n   </g>\n   <g id=\"patch_29\">\n    <path d=\"M 26.925 7.2 \nL 361.725 7.2 \n\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\n   </g>\n  </g>\n </g>\n <defs>\n  <clipPath id=\"p920417c74f\">\n   <rect height=\"217.44\" width=\"334.8\" x=\"26.925\" y=\"7.2\"/>\n  </clipPath>\n </defs>\n</svg>\n",
-      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXAAAAD4CAYAAAD1jb0+AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAOKklEQVR4nO3dX4xc5XnH8e9TDKQFFEy8sizAXRIhKl+0hq4wCIRoaCigqiQSqjAVWC2RoxYkUCNVkEhtepdWDWmjVgSn0HCBaf4ABaG0hDpIUaTK6ZoYMDiODXWIkcFL0gTUmxZ4ejHvwmTZ3dn5s7vzkO9HGs057/nzPoc9/PbsO+eMIzORJNXzS6tdgCRpMAa4JBVlgEtSUQa4JBVlgEtSUWtWsrN169bl5OTkSnYpSeXt2bPn1cycmNu+ogE+OTnJ9PT0SnYpSeVFxA/na3cIRZKKMsAlqSgDXJKKMsAlqSgDXJKKMsAlqaieAR4RZ0bEExHxXEQ8GxG3tPbPRMRLEbG3va5a/nIlSbOWch/4G8AnM/PJiDgF2BMRj7dln8/Mv1m+8iRJC+kZ4Jl5FDjapl+PiP3A6ctdmCRpcX09iRkRk8C5wG7gIuDmiLgBmKZzlf7f82yzHdgOsHHjxoEL3bn7xb63uW7L4P1J0rhb8oeYEXEy8ABwa2a+BtwJfAjYTOcK/XPzbZeZOzJzKjOnJibe9Si/JGlASwrwiDieTnjfl5kPAmTmK5n5Zma+BXwJOH/5ypQkzbWUu1ACuBvYn5l3dLVv6FrtY8C+0ZcnSVrIUsbALwKuB56JiL2t7VPA1ojYDCRwGPjEslQoSZrXUu5C+Q4Q8yz6xujLkSQtlU9iSlJRBrgkFWWAS1JRBrgkFWWAS1JRBrgkFWWAS1JRBrgkFWWAS1JRBrgkFdXX94FXM8h3iIPfIy6pBq/AJakoA1ySijLAJakoA1ySijLAJakoA1ySijLAJakoA1ySijLAJakoA1ySijLAJakoA1ySijLAJakoA1ySijLAJakoA1ySijLAJakoA1ySijLAJakoA1ySijLAJamongEeEWdGxBMR8VxEPBsRt7T20yLi8Yg42N7XLn+5kqRZS7kCfwP4ZGZuAi4AboqITcBtwK7MPBvY1eYlSSukZ4Bn5tHMfLJNvw7sB04HrgbubavdC3x0uYqUJL1bX2PgETEJnAvsBtZn5tG26GVg/QLbbI+I6YiYnpmZGaJUSVK3JQd4RJwMPADcmpmvdS/LzARyvu0yc0dmTmXm1MTExFDFSpLesaQAj4jj6YT3fZn5YGt+JSI2tOUbgGPLU6IkaT5LuQslgLuB/Zl5R9eiR4BtbXob8PDoy5MkLWTNEta5CLgeeCYi9ra2TwGfBb4aETcCPwR+f3lKlCTNp2eAZ+Z3gFhg8WWjLUeStFQ+iSlJRRngklSUAS5JRRngklSUAS5JRRngklSUAS5JRRngklSUAS5JRRngklSUAS5JRRngklSUAS5JRRngklSUAS5JRRngklSUAS5JRRngklSUAS5JRS3lHzX+hbNz94t9b3Pdlo3LUIkkLcwrcEkqygCXpKIMcEkqygCXpKIMcEkqygCXpKIMcEkqygCXpKIMcEkqygCXpKIMcEkqygCXpKJ6BnhE3BMRxyJiX1fbZyLipYjY215XLW+ZkqS5lnIF/mXginnaP5+Zm9vrG6MtS5LUS88Az8xvAz9ZgVokSX0YZgz85oh4ug2xrB1ZRZKkJRk0wO8EPgRsBo4Cn1toxYjYHhHTETE9MzMzYHeSpLkGCvDMfCUz38zMt4AvAecvsu6OzJzKzKmJiYlB65QkzTFQgEfEhq7ZjwH7FlpXkrQ8ev6bmBFxP3ApsC4ijgB/AVwaEZuBBA4Dn1jGGiVJ8+gZ4Jm5dZ7mu5ehFklSH3wSU5KKMsAlqSgDXJKKMsAlqSgDXJKKMsAlqSgDXJKKMsAlqSgDXJKKMsAlqSgDXJKKMsAlqSgDXJKKMsAlqSgDXJKKMsAlqSgDXJKKMsAlqSgDXJKKMsAlqSgDXJKKMsAlqSgDXJKKMsAlqSgDXJKKMsAlqSgDXJKKMsAlqSgDXJKKMsAlqSgDXJKKMsAlqSgDXJKK6hngEXFPRByLiH1dbadFxOMRcbC9r13eMiVJcy3lCvzLwBVz2m4DdmXm2cCuNi9JWkE9Azwzvw38ZE7z1cC9bfpe4KMjrkuS1MOgY+DrM/Nom34ZWL/QihGxPSKmI2J6ZmZmwO4kSXMN/SFmZiaQiyzfkZlTmTk1MTExbHeSpGbQAH8lIjYAtPdjoytJkrQUgwb4I8C2Nr0NeHg05UiSlmoptxHeD/wHcE5EHImIG4HPAh+JiIPAb7d5SdIKWtNrhczcusCiy0ZciySpDz6JKUlFGeCSVFTPIRQtzc7dL/a9zXVbNi5DJZJ+UXgFLklFGeCSVJQBLklFGeCSVJQBLklFGeCSVJQBLklFGeCSVJQBLklFGeCSVJQBLklFGeCSVJQBLklFGeCSVJQBLklFGeCSVJQBLklFGeCSVJQBLklFGeCSVJQBLklFGeCSVJQBLklFGeCSVJQBLklFGeCSVJQBLklFGeCSVJQBLklFrRlm44g4DLwOvAm8kZlToyhKktTbUAHe/FZmvjqC/UiS+uAQiiQVNWyAJ/DNiNgTEdvnWyEitkfEdERMz8zMDNmdJGnWsAF+cWaeB1wJ3BQRl8xdITN3ZOZUZk5NTEwM2Z0kadZQAZ6ZL7X3Y8BDwPmjKEqS1NvAAR4RJ0XEKbPTwOXAvlEVJkla3DB3oawHHoqI2f3szMx/G0lVkqSeBg7wzHwB+I0R1iJJ6oO3EUpSUQa4JBVlgEtSUaN4lF4F7Nz9Yt/bXLdl4zJUImlUvAKXpKIMcEkqygCXpKIMcEkqygCXpKIMcEkqygCXpKIMcEkqygCXpKIMcEkqygCXpKIMcEkqygCXpKIMcEkqygCXpKL8PnAtaKW+Q3yQfgbtSyvrvfizHadj8gpckooywCWpKANckooywCWpKANckooywCWpKANckooywCWpKB/kWUUr9aDMe9U4//cb59pg/OsbxHvxmHrxClySijLAJakoA1ySijLAJamooQI8Iq6IiAMRcSgibhtVUZKk3gYO8Ig4DvgH4EpgE7A1IjaNqjBJ0uKGuQI/HziUmS9k5v8C/wxcPZqyJEm9RGYOtmHENcAVmfnxNn89sCUzb56z3nZge5s9BzgwQHfrgFcHKnQ8WP/qqVw7WP9qGqfafzUzJ+Y2LvuDPJm5A9gxzD4iYjozp0ZU0oqz/tVTuXaw/tVUofZhhlBeAs7smj+jtUmSVsAwAf6fwNkRcVZEnABcCzwymrIkSb0MPISSmW9ExM3AY8BxwD2Z+ezIKvt5Qw3BjAHrXz2VawfrX01jX/vAH2JKklaXT2JKUlEGuCQVNfYBvpqP60fEPRFxLCL2dbWdFhGPR8TB9r62tUdEfKHV+XREnNe1zba2/sGI2NbV/psR8Uzb5gsREYv1MUD9Z0bEEx
HxXEQ8GxG3VDmGiHhfRHw3Ip5qtf9laz8rIna3/r7SPkAnIk5s84fa8smufd3e2g9ExO90tc97bi3UxyAi4riI+F5EPFqt/og43H62eyNiurWN/bnT9nFqRHw9Ir4fEfsj4sIqtfclM8f2RefD0eeBDwInAE8Bm1aw/0uA84B9XW1/DdzWpm8D/qpNXwX8KxDABcDu1n4a8EJ7X9um17Zl323rRtv2ysX6GKD+DcB5bfoU4Ad0vvZg7I+h7e/kNn08sLv181Xg2tb+ReCP2/SfAF9s09cCX2nTm9p5cyJwVjufjlvs3FqojwF/Bn8K7AQeXWzf41g/cBhYN6dt7M+dtt29wMfb9AnAqVVq7+s4l3PnQxcHFwKPdc3fDty+wjVM8vMBfgDY0KY3AAfa9F3A1rnrAVuBu7ra72ptG4Dvd7W/vd5CfYzgWB4GPlLtGIBfAZ4EttB5Mm7N3PODzt1QF7bpNW29mHvOzK630LnVtpm3jwHqPgPYBXwYeHSxfY9p/Yd5d4CP/bkDvB/4L9pNGpVq7/c17kMopwM/6po/0tpW0/rMPNqmXwbWt+mFal2s/cg87Yv1MbD2J/m5dK5kSxxDG37YCxwDHqdzxfnTzHxjnv7errEt/xnwgQGO6QOL9NGvvwX+DHirzS+273GsP4FvRsSe6HwlBtQ4d84CZoB/asNX/xgRJxWpvS/jHuBjLTu/Zpf1PsxR9BERJwMPALdm5muj3n8vg/aRmW9m5mY6V7LnA7826tqWS0T8LnAsM/esdi1DuDgzz6PzjaM3RcQl3QvH+NxZQ2fo887MPBf4HzrDGcPuty8r0ce4B/g4Pq7/SkRsAGjvx1r7QrUu1n7GPO2L9dG3iDieTnjfl5kPVjyGzPwp8ASd4YBTI2L2AbTu/t6usS1/P/DjAY7px4v00Y+LgN+LiMN0vqnzw8DfFaqfzHypvR8DHqLzS7TCuXMEOJKZu9v81+kEeoXa+zLuAT6Oj+s/Asx+Gr2NzrjybPsN7RPtC4CftT+lHgMuj4i17RPpy+mMSR4FXouIC9on2DfM2dd8ffSl7fduYH9m3lHpGCJiIiJObdO/TGfsfj+dIL9mgdpn+7sG+Fa7AnoEuDY6d3mcBZxN5wOoec+tts1CfSxZZt6emWdk5mTb97cy8w+q1B8RJ0XEKbPTdH7m+yhw7mTmy8CPIuKc1nQZ8FyF2vu2nAPso3jR+YT4B3TGPz+9wn3fDxwF/o/Ob/Ub6Ywx7gIOAv8OnNbWDTr/wMXzwDPAVNd+/gg41F5/2NU+Red/iueBv+edJ2Pn7WOA+i+m8yfc08De9rqqwjEAvw58r9W+D/jz1v5BOgF2CPgacGJrf1+bP9SWf7BrX59u9R2g3S2w2Lm1UB9DnEeX8s5dKCXqb/t4qr2end1/hXOn7WMzMN3On3+hcxdJidr7efkovSQVNe5DKJKkBRjgklSUAS5JRRngklSUAS5JRRngklSUAS5JRf0/SKUBmaC1OaEAAAAASUVORK5CYII=\n"
-     },
-     "metadata": {
-      "needs_background": "light"
-     }
-    }
-   ],
-   "source": [
-    "sns.distplot(document_lens, kde=False)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": []
-  }
- ],
- "metadata": {
-  "language_info": {
-   "codemirror_mode": {
-    "name": "ipython",
-    "version": 3
-   },
-   "file_extension": ".py",
-   "mimetype": "text/x-python",
-   "name": "python",
-   "nbconvert_exporter": "python",
-   "pygments_lexer": "ipython3",
-   "version": "3.8.2-final"
-  },
-  "orig_nbformat": 2,
-  "kernelspec": {
-   "name": "python38264bita7d7da14168440cb9836372958035d4a",
-   "display_name": "Python 3.8.2 64-bit"
-  }
- },
- "nbformat": 4,
- "nbformat_minor": 2
-}
\ No newline at end of file
diff --git a/dataset_generation/notebook_simple.ipynb b/dataset_generation/notebook_simple.ipynb
deleted file mode 100644
index 9b8a447..0000000
--- a/dataset_generation/notebook_simple.ipynb
+++ /dev/null
@@ -1,152 +0,0 @@
-{
- "cells": [
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "%load_ext autoreload\n",
-    "%autoreload 2"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 76,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "import glob\n",
-    "import random\n",
-    "from lxml import etree\n",
-    "import uuid\n",
-    "import hashlib\n",
-    "import seaborn as sns\n",
-    "import re\n",
-    "import os\n",
-    "from tqdm import tqdm\n",
-    "\n",
-    "from processing import text_from_xml\n",
-    "from utils import remove_multiple_spaces, remove_punctuation"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 6,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "file_schema = \"../dane/**/text_structure.xml\"\n",
-    "files_paths = glob.glob(file_schema, recursive=True)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 54,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "files_subset = random.sample(files_paths, 1_000)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 78,
-   "metadata": {
-    "tags": []
-   },
-   "outputs": [
-    {
-     "output_type": "stream",
-     "name": "stderr",
-     "text": "100%|██████████| 10000/10000 [00:40<00:00, 248.07it/s]\n"
-    }
-   ],
-   "source": [
-    "num_exported = 0\n",
-    "document_lens = []\n",
-    "\n",
-    "if not os.path.exists(\"../dataset_simple\"):\n",
-    "    os.mkdir(\"../dataset_simple\")\n",
-    "\n",
-    "for file_path in tqdm(files_subset):\n",
-    "    full_text = text_from_xml(file_path)\n",
-    "    \n",
-    "    if len(full_text) > 0:\n",
-    "        output_file_input = f\"../dataset_simple/{hashlib.md5(file_path.encode()).hexdigest()}_input.txt\"\n",
-    "        output_file_output = f\"../dataset_simple/{hashlib.md5(file_path.encode()).hexdigest()}_output.txt\"\n",
-    "\n",
-    "        with open(output_file_input, \"w\") as f:\n",
-    "            f.write(full_text)\n",
-    "            num_exported += 1\n",
-    "            document_lens.append(len(full_text))\n",
-    "\n",
-    "        text_cleared = remove_punctuation(full_text)\n",
-    "        text_cleared = remove_multiple_spaces(text_cleared)\n",
-    "        text_cleared = text_cleared.lower()\n",
-    "\n",
-    "        with open(output_file_output, 'w') as f:\n",
-    "            f.write(text_cleared)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 79,
-   "metadata": {
-    "tags": []
-   },
-   "outputs": [
-    {
-     "output_type": "execute_result",
-     "data": {
-      "text/plain": "<matplotlib.axes._subplots.AxesSubplot at 0x7fcd4958ab20>"
-     },
-     "metadata": {},
-     "execution_count": 79
-    },
-    {
-     "output_type": "display_data",
-     "data": {
-      "text/plain": "<Figure size 432x288 with 1 Axes>",
-      "image/svg+xml": "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n  \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n<!-- Created with matplotlib (https://matplotlib.org/) -->\n<svg height=\"248.518125pt\" version=\"1.1\" viewBox=\"0 0 375.2875 248.518125\" width=\"375.2875pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n <defs>\n  <style type=\"text/css\">\n*{stroke-linecap:butt;stroke-linejoin:round;}\n  </style>\n </defs>\n <g id=\"figure_1\">\n  <g id=\"patch_1\">\n   <path d=\"M 0 248.518125 \nL 375.2875 248.518125 \nL 375.2875 0 \nL 0 0 \nz\n\" style=\"fill:none;\"/>\n  </g>\n  <g id=\"axes_1\">\n   <g id=\"patch_2\">\n    <path d=\"M 33.2875 224.64 \nL 368.0875 224.64 \nL 368.0875 7.2 \nL 33.2875 7.2 \nz\n\" style=\"fill:#ffffff;\"/>\n   </g>\n   <g id=\"patch_3\">\n    <path clip-path=\"url(#p8459d3defc)\" d=\"M 48.505682 224.64 \nL 54.592955 224.64 \nL 54.592955 17.554286 \nL 48.505682 17.554286 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_4\">\n    <path clip-path=\"url(#p8459d3defc)\" d=\"M 54.592955 224.64 \nL 60.680227 224.64 \nL 60.680227 76.606071 \nL 54.592955 76.606071 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_5\">\n    <path clip-path=\"url(#p8459d3defc)\" d=\"M 60.680227 224.64 \nL 66.7675 224.64 \nL 66.7675 93.593571 \nL 60.680227 93.593571 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_6\">\n    <path clip-path=\"url(#p8459d3defc)\" d=\"M 66.7675 224.64 \nL 72.854773 224.64 \nL 72.854773 112.198929 \nL 66.7675 112.198929 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_7\">\n    <path clip-path=\"url(#p8459d3defc)\" d=\"M 72.854773 224.64 \nL 78.942045 224.64 \nL 78.942045 143.747143 \nL 72.854773 143.747143 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_8\">\n    <path clip-path=\"url(#p8459d3defc)\" d=\"M 78.942045 224.64 \nL 85.029318 224.64 \nL 85.029318 176.913214 \nL 78.942045 176.913214 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_9\">\n    <path clip-path=\"url(#p8459d3defc)\" d=\"M 85.029318 224.64 \nL 91.116591 224.64 \nL 91.116591 199.563214 \nL 85.029318 199.563214 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_10\">\n    <path clip-path=\"url(#p8459d3defc)\" d=\"M 91.116591 224.64 \nL 97.203864 224.64 \nL 97.203864 210.888214 \nL 91.116591 210.888214 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_11\">\n    <path clip-path=\"url(#p8459d3defc)\" d=\"M 97.203864 224.64 \nL 103.291136 224.64 \nL 103.291136 214.123929 \nL 97.203864 214.123929 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_12\">\n    <path clip-path=\"url(#p8459d3defc)\" d=\"M 103.291136 224.64 \nL 109.378409 224.64 \nL 109.378409 210.888214 \nL 103.291136 210.888214 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_13\">\n    <path clip-path=\"url(#p8459d3defc)\" d=\"M 109.378409 224.64 \nL 115.465682 224.64 \nL 115.465682 221.404286 \nL 109.378409 221.404286 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_14\">\n    <path clip-path=\"url(#p8459d3defc)\" d=\"M 115.465682 224.64 \nL 121.552955 224.64 \nL 121.552955 217.359643 \nL 115.465682 217.359643 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_15\">\n    <path clip-path=\"url(#p8459d3defc)\" d=\"M 121.552955 224.64 
\nL 127.640227 224.64 \nL 127.640227 216.550714 \nL 121.552955 216.550714 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_16\">\n    <path clip-path=\"url(#p8459d3defc)\" d=\"M 127.640227 224.64 \nL 133.7275 224.64 \nL 133.7275 219.786429 \nL 127.640227 219.786429 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_17\">\n    <path clip-path=\"url(#p8459d3defc)\" d=\"M 133.7275 224.64 \nL 139.814773 224.64 \nL 139.814773 216.550714 \nL 133.7275 216.550714 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_18\">\n    <path clip-path=\"url(#p8459d3defc)\" d=\"M 139.814773 224.64 \nL 145.902045 224.64 \nL 145.902045 220.595357 \nL 139.814773 220.595357 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_19\">\n    <path clip-path=\"url(#p8459d3defc)\" d=\"M 145.902045 224.64 \nL 151.989318 224.64 \nL 151.989318 220.595357 \nL 145.902045 220.595357 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_20\">\n    <path clip-path=\"url(#p8459d3defc)\" d=\"M 151.989318 224.64 \nL 158.076591 224.64 \nL 158.076591 223.022143 \nL 151.989318 223.022143 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_21\">\n    <path clip-path=\"url(#p8459d3defc)\" d=\"M 158.076591 224.64 \nL 164.163864 224.64 \nL 164.163864 223.022143 \nL 158.076591 223.022143 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_22\">\n    <path clip-path=\"url(#p8459d3defc)\" d=\"M 164.163864 224.64 \nL 170.251136 224.64 \nL 170.251136 216.550714 \nL 164.163864 216.550714 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_23\">\n    <path clip-path=\"url(#p8459d3defc)\" d=\"M 170.251136 224.64 \nL 176.338409 224.64 \nL 176.338409 220.595357 \nL 170.251136 220.595357 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_24\">\n    <path clip-path=\"url(#p8459d3defc)\" d=\"M 176.338409 224.64 \nL 182.425682 224.64 \nL 182.425682 223.022143 \nL 176.338409 223.022143 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_25\">\n    <path clip-path=\"url(#p8459d3defc)\" d=\"M 182.425682 224.64 \nL 188.512955 224.64 \nL 188.512955 223.022143 \nL 182.425682 223.022143 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_26\">\n    <path clip-path=\"url(#p8459d3defc)\" d=\"M 188.512955 224.64 \nL 194.600227 224.64 \nL 194.600227 222.213214 \nL 188.512955 222.213214 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_27\">\n    <path clip-path=\"url(#p8459d3defc)\" d=\"M 194.600227 224.64 \nL 200.6875 224.64 \nL 200.6875 223.831071 \nL 194.600227 223.831071 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_28\">\n    <path clip-path=\"url(#p8459d3defc)\" d=\"M 200.6875 224.64 \nL 206.774773 224.64 \nL 206.774773 222.213214 \nL 200.6875 222.213214 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_29\">\n    <path clip-path=\"url(#p8459d3defc)\" d=\"M 206.774773 224.64 \nL 212.862045 224.64 \nL 212.862045 224.64 \nL 206.774773 224.64 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_30\">\n    <path clip-path=\"url(#p8459d3defc)\" d=\"M 212.862045 224.64 \nL 218.949318 224.64 \nL 218.949318 220.595357 \nL 212.862045 220.595357 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_31\">\n    <path clip-path=\"url(#p8459d3defc)\" d=\"M 218.949318 224.64 \nL 225.036591 224.64 \nL 225.036591 223.831071 \nL 
218.949318 223.831071 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_32\">\n    <path clip-path=\"url(#p8459d3defc)\" d=\"M 225.036591 224.64 \nL 231.123864 224.64 \nL 231.123864 222.213214 \nL 225.036591 222.213214 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_33\">\n    <path clip-path=\"url(#p8459d3defc)\" d=\"M 231.123864 224.64 \nL 237.211136 224.64 \nL 237.211136 221.404286 \nL 231.123864 221.404286 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_34\">\n    <path clip-path=\"url(#p8459d3defc)\" d=\"M 237.211136 224.64 \nL 243.298409 224.64 \nL 243.298409 221.404286 \nL 237.211136 221.404286 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_35\">\n    <path clip-path=\"url(#p8459d3defc)\" d=\"M 243.298409 224.64 \nL 249.385682 224.64 \nL 249.385682 224.64 \nL 243.298409 224.64 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_36\">\n    <path clip-path=\"url(#p8459d3defc)\" d=\"M 249.385682 224.64 \nL 255.472955 224.64 \nL 255.472955 222.213214 \nL 249.385682 222.213214 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_37\">\n    <path clip-path=\"url(#p8459d3defc)\" d=\"M 255.472955 224.64 \nL 261.560227 224.64 \nL 261.560227 224.64 \nL 255.472955 224.64 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_38\">\n    <path clip-path=\"url(#p8459d3defc)\" d=\"M 261.560227 224.64 \nL 267.6475 224.64 \nL 267.6475 222.213214 \nL 261.560227 222.213214 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_39\">\n    <path clip-path=\"url(#p8459d3defc)\" d=\"M 267.6475 224.64 \nL 273.734773 224.64 \nL 273.734773 223.831071 \nL 267.6475 223.831071 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_40\">\n    <path clip-path=\"url(#p8459d3defc)\" d=\"M 273.734773 224.64 \nL 279.822045 224.64 \nL 279.822045 223.831071 \nL 273.734773 223.831071 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_41\">\n    <path clip-path=\"url(#p8459d3defc)\" d=\"M 279.822045 224.64 \nL 285.909318 224.64 \nL 285.909318 218.9775 \nL 279.822045 218.9775 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_42\">\n    <path clip-path=\"url(#p8459d3defc)\" d=\"M 285.909318 224.64 \nL 291.996591 224.64 \nL 291.996591 224.64 \nL 285.909318 224.64 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_43\">\n    <path clip-path=\"url(#p8459d3defc)\" d=\"M 291.996591 224.64 \nL 298.083864 224.64 \nL 298.083864 224.64 \nL 291.996591 224.64 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_44\">\n    <path clip-path=\"url(#p8459d3defc)\" d=\"M 298.083864 224.64 \nL 304.171136 224.64 \nL 304.171136 223.831071 \nL 298.083864 223.831071 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_45\">\n    <path clip-path=\"url(#p8459d3defc)\" d=\"M 304.171136 224.64 \nL 310.258409 224.64 \nL 310.258409 223.831071 \nL 304.171136 223.831071 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_46\">\n    <path clip-path=\"url(#p8459d3defc)\" d=\"M 310.258409 224.64 \nL 316.345682 224.64 \nL 316.345682 224.64 \nL 310.258409 224.64 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_47\">\n    <path clip-path=\"url(#p8459d3defc)\" d=\"M 316.345682 224.64 \nL 322.432955 224.64 \nL 322.432955 224.64 \nL 316.345682 224.64 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g 
id=\"patch_48\">\n    <path clip-path=\"url(#p8459d3defc)\" d=\"M 322.432955 224.64 \nL 328.520227 224.64 \nL 328.520227 224.64 \nL 322.432955 224.64 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_49\">\n    <path clip-path=\"url(#p8459d3defc)\" d=\"M 328.520227 224.64 \nL 334.6075 224.64 \nL 334.6075 224.64 \nL 328.520227 224.64 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_50\">\n    <path clip-path=\"url(#p8459d3defc)\" d=\"M 334.6075 224.64 \nL 340.694773 224.64 \nL 340.694773 224.64 \nL 334.6075 224.64 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_51\">\n    <path clip-path=\"url(#p8459d3defc)\" d=\"M 340.694773 224.64 \nL 346.782045 224.64 \nL 346.782045 224.64 \nL 340.694773 224.64 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"patch_52\">\n    <path clip-path=\"url(#p8459d3defc)\" d=\"M 346.782045 224.64 \nL 352.869318 224.64 \nL 352.869318 223.831071 \nL 346.782045 223.831071 \nz\n\" style=\"fill:#1f77b4;opacity:0.4;\"/>\n   </g>\n   <g id=\"matplotlib.axis_1\">\n    <g id=\"xtick_1\">\n     <g id=\"line2d_1\">\n      <defs>\n       <path d=\"M 0 0 \nL 0 3.5 \n\" id=\"m6e779eec92\" style=\"stroke:#000000;stroke-width:0.8;\"/>\n      </defs>\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"48.27002\" xlink:href=\"#m6e779eec92\" y=\"224.64\"/>\n      </g>\n     </g>\n     <g id=\"text_1\">\n      <!-- 0 -->\n      <defs>\n       <path d=\"M 31.78125 66.40625 \nQ 24.171875 66.40625 20.328125 58.90625 \nQ 16.5 51.421875 16.5 36.375 \nQ 16.5 21.390625 20.328125 13.890625 \nQ 24.171875 6.390625 31.78125 6.390625 \nQ 39.453125 6.390625 43.28125 13.890625 \nQ 47.125 21.390625 47.125 36.375 \nQ 47.125 51.421875 43.28125 58.90625 \nQ 39.453125 66.40625 31.78125 66.40625 \nz\nM 31.78125 74.21875 \nQ 44.046875 74.21875 50.515625 64.515625 \nQ 56.984375 54.828125 56.984375 36.375 \nQ 56.984375 17.96875 50.515625 8.265625 \nQ 44.046875 -1.421875 31.78125 -1.421875 \nQ 19.53125 -1.421875 13.0625 8.265625 \nQ 6.59375 17.96875 6.59375 36.375 \nQ 6.59375 54.828125 13.0625 64.515625 \nQ 19.53125 74.21875 31.78125 74.21875 \nz\n\" id=\"DejaVuSans-48\"/>\n      </defs>\n      <g transform=\"translate(45.08877 239.238437)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-48\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"xtick_2\">\n     <g id=\"line2d_2\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"113.46009\" xlink:href=\"#m6e779eec92\" y=\"224.64\"/>\n      </g>\n     </g>\n     <g id=\"text_2\">\n      <!-- 200000 -->\n      <defs>\n       <path d=\"M 19.1875 8.296875 \nL 53.609375 8.296875 \nL 53.609375 0 \nL 7.328125 0 \nL 7.328125 8.296875 \nQ 12.9375 14.109375 22.625 23.890625 \nQ 32.328125 33.6875 34.8125 36.53125 \nQ 39.546875 41.84375 41.421875 45.53125 \nQ 43.3125 49.21875 43.3125 52.78125 \nQ 43.3125 58.59375 39.234375 62.25 \nQ 35.15625 65.921875 28.609375 65.921875 \nQ 23.96875 65.921875 18.8125 64.3125 \nQ 13.671875 62.703125 7.8125 59.421875 \nL 7.8125 69.390625 \nQ 13.765625 71.78125 18.9375 73 \nQ 24.125 74.21875 28.421875 74.21875 \nQ 39.75 74.21875 46.484375 68.546875 \nQ 53.21875 62.890625 53.21875 53.421875 \nQ 53.21875 48.921875 51.53125 44.890625 \nQ 49.859375 40.875 45.40625 35.40625 \nQ 44.1875 33.984375 37.640625 27.21875 \nQ 31.109375 20.453125 19.1875 8.296875 \nz\n\" id=\"DejaVuSans-50\"/>\n      </defs>\n      <g transform=\"translate(94.37259 239.238437)scale(0.1 -0.1)\">\n       <use 
xlink:href=\"#DejaVuSans-50\"/>\n       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"127.246094\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"190.869141\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"254.492188\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"318.115234\" xlink:href=\"#DejaVuSans-48\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"xtick_3\">\n     <g id=\"line2d_3\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"178.65016\" xlink:href=\"#m6e779eec92\" y=\"224.64\"/>\n      </g>\n     </g>\n     <g id=\"text_3\">\n      <!-- 400000 -->\n      <defs>\n       <path d=\"M 37.796875 64.3125 \nL 12.890625 25.390625 \nL 37.796875 25.390625 \nz\nM 35.203125 72.90625 \nL 47.609375 72.90625 \nL 47.609375 25.390625 \nL 58.015625 25.390625 \nL 58.015625 17.1875 \nL 47.609375 17.1875 \nL 47.609375 0 \nL 37.796875 0 \nL 37.796875 17.1875 \nL 4.890625 17.1875 \nL 4.890625 26.703125 \nz\n\" id=\"DejaVuSans-52\"/>\n      </defs>\n      <g transform=\"translate(159.56266 239.238437)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-52\"/>\n       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"127.246094\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"190.869141\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"254.492188\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"318.115234\" xlink:href=\"#DejaVuSans-48\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"xtick_4\">\n     <g id=\"line2d_4\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"243.84023\" xlink:href=\"#m6e779eec92\" y=\"224.64\"/>\n      </g>\n     </g>\n     <g id=\"text_4\">\n      <!-- 600000 -->\n      <defs>\n       <path d=\"M 33.015625 40.375 \nQ 26.375 40.375 22.484375 35.828125 \nQ 18.609375 31.296875 18.609375 23.390625 \nQ 18.609375 15.53125 22.484375 10.953125 \nQ 26.375 6.390625 33.015625 6.390625 \nQ 39.65625 6.390625 43.53125 10.953125 \nQ 47.40625 15.53125 47.40625 23.390625 \nQ 47.40625 31.296875 43.53125 35.828125 \nQ 39.65625 40.375 33.015625 40.375 \nz\nM 52.59375 71.296875 \nL 52.59375 62.3125 \nQ 48.875 64.0625 45.09375 64.984375 \nQ 41.3125 65.921875 37.59375 65.921875 \nQ 27.828125 65.921875 22.671875 59.328125 \nQ 17.53125 52.734375 16.796875 39.40625 \nQ 19.671875 43.65625 24.015625 45.921875 \nQ 28.375 48.1875 33.59375 48.1875 \nQ 44.578125 48.1875 50.953125 41.515625 \nQ 57.328125 34.859375 57.328125 23.390625 \nQ 57.328125 12.15625 50.6875 5.359375 \nQ 44.046875 -1.421875 33.015625 -1.421875 \nQ 20.359375 -1.421875 13.671875 8.265625 \nQ 6.984375 17.96875 6.984375 36.375 \nQ 6.984375 53.65625 15.1875 63.9375 \nQ 23.390625 74.21875 37.203125 74.21875 \nQ 40.921875 74.21875 44.703125 73.484375 \nQ 48.484375 72.75 52.59375 71.296875 \nz\n\" id=\"DejaVuSans-54\"/>\n      </defs>\n      <g transform=\"translate(224.75273 239.238437)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-54\"/>\n       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"127.246094\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"190.869141\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"254.492188\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"318.115234\" xlink:href=\"#DejaVuSans-48\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"xtick_5\">\n     <g id=\"line2d_5\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"309.0303\" xlink:href=\"#m6e779eec92\" y=\"224.64\"/>\n      </g>\n     </g>\n     <g id=\"text_5\">\n      <!-- 800000 -->\n      
<defs>\n       <path d=\"M 31.78125 34.625 \nQ 24.75 34.625 20.71875 30.859375 \nQ 16.703125 27.09375 16.703125 20.515625 \nQ 16.703125 13.921875 20.71875 10.15625 \nQ 24.75 6.390625 31.78125 6.390625 \nQ 38.8125 6.390625 42.859375 10.171875 \nQ 46.921875 13.96875 46.921875 20.515625 \nQ 46.921875 27.09375 42.890625 30.859375 \nQ 38.875 34.625 31.78125 34.625 \nz\nM 21.921875 38.8125 \nQ 15.578125 40.375 12.03125 44.71875 \nQ 8.5 49.078125 8.5 55.328125 \nQ 8.5 64.0625 14.71875 69.140625 \nQ 20.953125 74.21875 31.78125 74.21875 \nQ 42.671875 74.21875 48.875 69.140625 \nQ 55.078125 64.0625 55.078125 55.328125 \nQ 55.078125 49.078125 51.53125 44.71875 \nQ 48 40.375 41.703125 38.8125 \nQ 48.828125 37.15625 52.796875 32.3125 \nQ 56.78125 27.484375 56.78125 20.515625 \nQ 56.78125 9.90625 50.3125 4.234375 \nQ 43.84375 -1.421875 31.78125 -1.421875 \nQ 19.734375 -1.421875 13.25 4.234375 \nQ 6.78125 9.90625 6.78125 20.515625 \nQ 6.78125 27.484375 10.78125 32.3125 \nQ 14.796875 37.15625 21.921875 38.8125 \nz\nM 18.3125 54.390625 \nQ 18.3125 48.734375 21.84375 45.5625 \nQ 25.390625 42.390625 31.78125 42.390625 \nQ 38.140625 42.390625 41.71875 45.5625 \nQ 45.3125 48.734375 45.3125 54.390625 \nQ 45.3125 60.0625 41.71875 63.234375 \nQ 38.140625 66.40625 31.78125 66.40625 \nQ 25.390625 66.40625 21.84375 63.234375 \nQ 18.3125 60.0625 18.3125 54.390625 \nz\n\" id=\"DejaVuSans-56\"/>\n      </defs>\n      <g transform=\"translate(289.9428 239.238437)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-56\"/>\n       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"127.246094\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"190.869141\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"254.492188\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"318.115234\" xlink:href=\"#DejaVuSans-48\"/>\n      </g>\n     </g>\n    </g>\n   </g>\n   <g id=\"matplotlib.axis_2\">\n    <g id=\"ytick_1\">\n     <g id=\"line2d_6\">\n      <defs>\n       <path d=\"M 0 0 \nL -3.5 0 \n\" id=\"md0c86dbdd4\" style=\"stroke:#000000;stroke-width:0.8;\"/>\n      </defs>\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"33.2875\" xlink:href=\"#md0c86dbdd4\" y=\"224.64\"/>\n      </g>\n     </g>\n     <g id=\"text_6\">\n      <!-- 0 -->\n      <g transform=\"translate(19.925 228.439219)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-48\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"ytick_2\">\n     <g id=\"line2d_7\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"33.2875\" xlink:href=\"#md0c86dbdd4\" y=\"184.193571\"/>\n      </g>\n     </g>\n     <g id=\"text_7\">\n      <!-- 50 -->\n      <defs>\n       <path d=\"M 10.796875 72.90625 \nL 49.515625 72.90625 \nL 49.515625 64.59375 \nL 19.828125 64.59375 \nL 19.828125 46.734375 \nQ 21.96875 47.46875 24.109375 47.828125 \nQ 26.265625 48.1875 28.421875 48.1875 \nQ 40.625 48.1875 47.75 41.5 \nQ 54.890625 34.8125 54.890625 23.390625 \nQ 54.890625 11.625 47.5625 5.09375 \nQ 40.234375 -1.421875 26.90625 -1.421875 \nQ 22.3125 -1.421875 17.546875 -0.640625 \nQ 12.796875 0.140625 7.71875 1.703125 \nL 7.71875 11.625 \nQ 12.109375 9.234375 16.796875 8.0625 \nQ 21.484375 6.890625 26.703125 6.890625 \nQ 35.15625 6.890625 40.078125 11.328125 \nQ 45.015625 15.765625 45.015625 23.390625 \nQ 45.015625 31 40.078125 35.4375 \nQ 35.15625 39.890625 26.703125 39.890625 \nQ 22.75 39.890625 18.8125 39.015625 \nQ 14.890625 38.140625 10.796875 36.28125 \nz\n\" id=\"DejaVuSans-53\"/>\n      </defs>\n      <g 
transform=\"translate(13.5625 187.99279)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-53\"/>\n       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-48\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"ytick_3\">\n     <g id=\"line2d_8\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"33.2875\" xlink:href=\"#md0c86dbdd4\" y=\"143.747143\"/>\n      </g>\n     </g>\n     <g id=\"text_8\">\n      <!-- 100 -->\n      <defs>\n       <path d=\"M 12.40625 8.296875 \nL 28.515625 8.296875 \nL 28.515625 63.921875 \nL 10.984375 60.40625 \nL 10.984375 69.390625 \nL 28.421875 72.90625 \nL 38.28125 72.90625 \nL 38.28125 8.296875 \nL 54.390625 8.296875 \nL 54.390625 0 \nL 12.40625 0 \nz\n\" id=\"DejaVuSans-49\"/>\n      </defs>\n      <g transform=\"translate(7.2 147.546362)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-49\"/>\n       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"127.246094\" xlink:href=\"#DejaVuSans-48\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"ytick_4\">\n     <g id=\"line2d_9\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"33.2875\" xlink:href=\"#md0c86dbdd4\" y=\"103.300714\"/>\n      </g>\n     </g>\n     <g id=\"text_9\">\n      <!-- 150 -->\n      <g transform=\"translate(7.2 107.099933)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-49\"/>\n       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-53\"/>\n       <use x=\"127.246094\" xlink:href=\"#DejaVuSans-48\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"ytick_5\">\n     <g id=\"line2d_10\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"33.2875\" xlink:href=\"#md0c86dbdd4\" y=\"62.854286\"/>\n      </g>\n     </g>\n     <g id=\"text_10\">\n      <!-- 200 -->\n      <g transform=\"translate(7.2 66.653504)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-50\"/>\n       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-48\"/>\n       <use x=\"127.246094\" xlink:href=\"#DejaVuSans-48\"/>\n      </g>\n     </g>\n    </g>\n    <g id=\"ytick_6\">\n     <g id=\"line2d_11\">\n      <g>\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"33.2875\" xlink:href=\"#md0c86dbdd4\" y=\"22.407857\"/>\n      </g>\n     </g>\n     <g id=\"text_11\">\n      <!-- 250 -->\n      <g transform=\"translate(7.2 26.207076)scale(0.1 -0.1)\">\n       <use xlink:href=\"#DejaVuSans-50\"/>\n       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-53\"/>\n       <use x=\"127.246094\" xlink:href=\"#DejaVuSans-48\"/>\n      </g>\n     </g>\n    </g>\n   </g>\n   <g id=\"patch_53\">\n    <path d=\"M 33.2875 224.64 \nL 33.2875 7.2 \n\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\n   </g>\n   <g id=\"patch_54\">\n    <path d=\"M 368.0875 224.64 \nL 368.0875 7.2 \n\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\n   </g>\n   <g id=\"patch_55\">\n    <path d=\"M 33.2875 224.64 \nL 368.0875 224.64 \n\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\n   </g>\n   <g id=\"patch_56\">\n    <path d=\"M 33.2875 7.2 \nL 368.0875 7.2 \n\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\n   </g>\n  </g>\n </g>\n <defs>\n  <clipPath id=\"p8459d3defc\">\n   <rect height=\"217.44\" width=\"334.8\" x=\"33.2875\" y=\"7.2\"/>\n  </clipPath>\n </defs>\n</svg>\n",
-      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD4CAYAAAAXUaZHAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAPuklEQVR4nO3dXYxcZ33H8e+vMYQWEHEa1zKxXRvkIpmLJukqCYKLtGlJiCoCEoqcVMSlIKM2kaBFqhK4gF5EohUvJWobMCQlVJiQQmgslJYGNxLiAoNN07ybGPJmy4kNtAkqEmrCvxfzOBnMvs/ujvfZ70cazTn/c86cZ84+/u3xM2fOpqqQJPXlV8bdAEnSwjPcJalDhrskdchwl6QOGe6S1KFV424AwBlnnFGbNm0adzMkaVnZv3//D6tqzWTLTopw37RpE/v27Rt3MyRpWUny2FTLHJaRpA4Z7pLUIcNdkjpkuEtShwx3SeqQ4S5JHTLcJalDhrskdchwl6QOnRTfUB3Frr2PT7nsivM2LmFLJOnkMeOZe5INSe5K8kCS+5O8p9U/lORwkrvb45Khba5NcjDJgSQXLeYbkCT9stmcuT8LvK+qvpvk5cD+JHe2ZR+vqo8Mr5xkK7ANeC3wSuDrSX6rqp5byIZLkqY245l7VR2pqu+26Z8ADwJnTrPJpcAtVfWzqnoEOAicuxCNlSTNzpw+UE2yCTgb2NtKVye5J8lNSVa32pnAE0ObHWL6XwaSpAU263BP8jLgy8B7q+oZ4Abg1cBZwBHgo3PZcZIdSfYl2Xfs2LG5bCpJmsGswj3JixgE++er6jaAqnqqqp6rqp8Dn+aFoZfDwIahzde32i+oqp1VNVFVE2vWTHqveUnSPM3mapkANwIPVtXHhurrhlZ7K3Bfm94NbEtyapLNwBbg2wvXZEnSTGZztczrgbcD9ya5u9XeD1ye5CyggEeBdwNU1f1JbgUeYHClzVVeKSNJS2vGcK+qbwKZZNEd02xzHXDdCO2SJI3A2w9IUocMd0nqkOEuSR0y3CWpQ4a7JHXIcJekDhnuktQhw12SOmS4S1KHDHdJ6pDhLkkdMtwlqUOGuyR1yHCXpA4Z7pLUIcNdkjpkuEtShwx3SeqQ4S5JHTLcJalDhrskdchwl6QOGe6S1CHDXZI6ZLhLUocMd0nqkOEuSR0y3CWpQ4a7JHXIcJekDhnuktQhw12SOjRjuCfZkOSuJA8kuT/Je1r99CR3Jnm4Pa9u9SS5PsnBJPckOWex34Qk6RfN5sz9WeB9VbUVOB+4KslW4BpgT1VtAfa0eYA3AVvaYwdww4K3WpI0rRnDvaqOVNV32/RPgAeBM4FLgZvbajcDb2nTlwKfq4FvAaclWbfgLZckTWlOY+5JNgFnA3uBtVV1pC16Eljbps8Enhja7FCrnfhaO5LsS7Lv2LFjc2y2JGk6sw73JC8Dvgy8t6qeGV5WVQXUXHZcVTuraqKqJtasWTOXTSVJM5hVuCd5EYNg/3xV3dbKTx0fbmnPR1v9MLBhaPP1rSZJWiKzuVomwI3Ag1X1saFFu4HtbXo7cPtQ/cp21cz5wNNDwzeSpCWwahbrvB54O3Bvkrtb7f3Ah4Fbk7wTeAy4rC27A7gEOAj8FHjHgrZYkjSjGcO9qr4JZIrFF06yfgFXjdguSdII/IaqJHXIcJekDhnuktQhw12SOmS4S1KHDHdJ6pDhLkkdMtwlqUOz+YbqsrVr7+OT1q84b+MSt0SSlpZn7pLUIcNdkjpkuEtShwx3SeqQ4S5JHTLcJalDhrskdchwl6QOGe6S1CHDXZI6ZLhLUocMd0nqkOEuSR0y3CWpQ4a7JHXIcJekDnX9xzqm4h/xkNQ7z9wlqUOGuyR1yHCXpA4Z7pLUIcNdkjo0Y7gnuSnJ0ST3DdU+lORwkrvb45KhZdcmOZjkQJKLFqvhkqSpzebM/bPAxZPUP15VZ7XHHQBJtgLbgNe2bf4hySkL1VhJ0uzMGO5V9Q3gx7N8vUuBW6rqZ1X1CHAQOHeE9kmS5mGUMferk9zThm1Wt9qZwBND6xxqtV+SZEeSfUn2HTt2bIRmSJJONN9wvwF4NXAWcAT46FxfoKp2VtVEVU2sWbNmns2QJE1mXuFeVU9V1XNV9XPg07ww9HIY2DC06vpWkyQtoXmFe5J1Q7NvBY5fSbMb2Jbk1CSbgS3At0droiRprma8cViSLwAXAGckOQR8ELggyVlAAY8C7waoqvuT3Ao8ADwLXFVVzy1O0yVJU5kx3Kvq8knKN06z/nXAdaM0SpI0Gr+hKkkdMtwlqUOGuyR1aEX+Jaap+BeaJPXCM3dJ6pDhLkkdMtwlqUOGuyR1yHCXpA4Z7pLUIcNdkjpkuEtShwx3SeqQ4S5JHTLcJalDhrskdchwl6QOGe6S1CHDXZI6ZLhLUocMd0nqkOEuSR0y3CWpQ4a7JHXIcJekDhnuktQhw12SOmS4S1KHDHdJ6pDhLkkdMtwlqUOGuyR1aMZwT3JTkqNJ7huqnZ7kziQPt+fVrZ4k1yc5mOSeJOcsZuMlSZObzZn7Z4GLT6hdA+ypqi3AnjYP8CZgS3vsAG5YmGZKkuZixnCvqm8APz6hfClwc5u+GXjLUP1zNfAt4LQk6xaqsZKk2Vk1z+3WVtWRNv0ksLZNnwk8MbTeoVY7wgmS7GBwds/GjRvn2YylsWvv45PWrzjv5G63pJVr5A9Uq6qAmsd2O6tqoqom1qxZM2ozJElD5hvuTx0fbmnPR1v9MLBhaL31rSZJWkLzDffdwPY2vR24fah+Zbtq5nzg6aHhG0nSEplxzD3JF4ALgDOSHAI+CHwYuDXJO4HHgMva6ncAlwAHgZ8C71iENkuSZjBjuFfV5VMsunCSdQu4atRGSZJG4zdUJalDhrskdchwl6QOGe6S1CHDXZI6ZLhLUocMd0nqkOEuSR0y3CWpQ4a7JHXIcJekDhnuktQhw12SOmS4S1KHDHdJ6pDhLkkdMtwlqUOGuyR1yHCXpA4Z7pLUIcNdkjpkuEtShwx3SeqQ4S5JHTLcJalDhrskdchwl6QOrRp3A5azXXsfn7R+xXkbl7glkvSLPHOXpA4Z7pLUIcNdkjo00ph7kkeBnwDPAc9W1USS04EvApuAR4HLquq/R2umJGkuFuLM/Xer6qyqmmjz1wB7qmoLsKfNS5KW0GIMy1wK3Nymbwbesgj7kCRNY9RwL+Dfk+xPsqPV1lbVkTb9JLB2sg2T7EiyL8m+Y8eOjdgMSdKwUa9zf0NVHU7yG8CdSR4aXlhVlaQm27CqdgI7ASYmJiZdR5I0PyOduVfV4fZ8FPgKcC7wVJJ1AO356KiNlCTNzbzDPclLk7z8+DTwRuA+YDewva22Hbh91EZKkuZmlGGZtcBXkhx/nV1V9W9JvgPcmuSdwGPAZaM3U5I0F/MO96r6AfDbk9R/BFw4SqMkSaPxG6qS1CHDXZI6ZLhLUocMd0nqkOEuSR0y3CWpQ4a7JHXIcJekDhnuktQhw12SOjTqLX
81iV17H5+0fsV5G5e4JZJWKs/cJalDhrskdchwl6QOGe6S1CHDXZI6ZLhLUocMd0nqkOEuSR0y3CWpQ4a7JHXI2w8sobnelsDbGEiaL8/cJalDhrskdchhmZPAVMMvkjRfnrlLUocMd0nqkMMyK8B0wz4n25U6XiEkLQzDvSOO3Us6znCXtGj8n9j4GO4rnGf7Up8Md83JXH8ZLNQZ2nI6A1xObVW/Fi3ck1wMfAI4BfhMVX14sfa10ni2PX8n44fLC/XzXKj2z+cY6eSzKOGe5BTg74E/AA4B30myu6oeWIz96eS12L+I/EUnTW6xztzPBQ5W1Q8AktwCXAoY7jpp+YtoZuN6D+MaDlxISz1cl6pa+BdN3gZcXFXvavNvB86rqquH1tkB7GizrwEOzHN3ZwA/HKG5PfAYDHgcPAawso7Bb1bVmskWjO0D1araCewc9XWS7KuqiQVo0rLlMRjwOHgMwGNw3GLdfuAwsGFofn2rSZKWwGKF+3eALUk2J3kxsA3YvUj7kiSdYFGGZarq2SRXA19jcCnkTVV1/2LsiwUY2umAx2DA4+AxAI8BsEgfqEqSxstb/kpShwx3SerQsg73JBcnOZDkYJJrxt2euUqyIcldSR5Icn+S97T66UnuTPJwe17d6klyfXu/9yQ5Z+i1trf1H06yfaj+O0nubdtcnyTT7WNckpyS5D+TfLXNb06yt7X7i+2DeZKc2uYPtuWbhl7j2lY/kOSiofqk/WSqfYxLktOSfCnJQ0keTPK6ldYXkvx5+7dwX5IvJHnJSuwLC6KqluWDwQe13wdeBbwY+C9g67jbNcf3sA44p02/HPgesBX4G+CaVr8G+Os2fQnwr0CA84G9rX468IP2vLpNr27Lvt3WTdv2Ta0+6T7GeCz+AtgFfLXN3wpsa9OfBP60Tf8Z8Mk2vQ34Ypve2vrAqcDm1jdOma6fTLWPMR6Dm4F3tekXA6etpL4AnAk8Avzq0M/nj1diX1iQ4znuBozQEV4HfG1o/lrg2nG3a8T3dDuD+/EcANa12jrgQJv+FHD50PoH2vLLgU8N1T/VauuAh4bqz6831T7G9L7XA3uA3wO+2sLnh8CqE3/WDK7Ael2bXtXWy4k//+PrTdVPptvHmI7BK1qw5YT6iukLDML9CQa/mFa1vnDRSusLC/VYzsMyxzvCcYdabVlq/6U8G9gLrK2qI23Rk8DaNj3Ve56ufmiSOtPsYxz+FvhL4Odt/teB/6mqZ9v8cLuff69t+dNt/bkem+n2MQ6bgWPAP7bhqc8keSkrqC9U1WHgI8DjwBEGP9v9rLy+sCCWc7h3I8nLgC8D762qZ4aX1eBUYlGvV12KfUwlyR8CR6tq/zj2fxJZBZwD3FBVZwP/y2CI5HkroC+sZnCDwc3AK4GXAhePoy09WM7h3sUtDpK8iEGwf76qbmvlp5Ksa8vXAUdbfar3PF19/ST16fax1F4PvDnJo8AtDIZmPgGcluT4l+yG2/38e23LXwH8iLkfmx9Ns49xOAQcqqq9bf5LDMJ+JfWF3wceqapjVfV/wG0M+sdK6wsLYjmH+7K/xUG7WuFG4MGq+tjQot3A8asctjMYiz9ev7JdKXE+8HT77/TXgDcmWd3Oft7IYMzwCPBMkvPbvq484bUm28eSqqprq2p9VW1i8DP8j6r6I+Au4G2TtG+43W9r61erb2tXUGwGtjD4AHHSftK2mWofS66qngSeSPKaVrqQwS2yV0xfYDAcc36SX2ttPH4MVlRfWDDjHvQf5cHgioHvMfgE/APjbs882v8GBv8Fvge4uz0uYTAGuAd4GPg6cHpbPwz+CMr3gXuBiaHX+hPgYHu8Y6g+AdzXtvk7XvhW8qT7GPPxuIAXrpZ5FYN/kAeBfwZObfWXtPmDbfmrhrb/QHufB2hXgkzXT6baxxjf/1nAvtYf/oXB1S4rqi8AfwU81Nr5TwyueFlxfWEhHt5+QJI6tJyHZSRJUzDcJalDhrskdchwl6QOGe6S1CHDXZI6ZLhLUof+HwNE9h5uGuuvAAAAAElFTkSuQmCC\n"
-     },
-     "metadata": {
-      "needs_background": "light"
-     }
-    }
-   ],
-   "source": [
-    "sns.distplot(document_lens, kde=False)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": []
-  }
- ],
- "metadata": {
-  "language_info": {
-   "codemirror_mode": {
-    "name": "ipython",
-    "version": 3
-   },
-   "file_extension": ".py",
-   "mimetype": "text/x-python",
-   "name": "python",
-   "nbconvert_exporter": "python",
-   "pygments_lexer": "ipython3",
-   "version": 3
-  },
-  "orig_nbformat": 2,
-  "kernelspec": {
-   "name": "python_defaultSpec_1594287746726",
-   "display_name": "Python 3.8.2 64-bit"
-  }
- },
- "nbformat": 4,
- "nbformat_minor": 2
-}
\ No newline at end of file
diff --git a/dataset_generation/processing.py b/dataset_generation/processing.py
index 9a9f44d..6543b7e 100644
--- a/dataset_generation/processing.py
+++ b/dataset_generation/processing.py
@@ -334,14 +334,14 @@ def add_padding(tokens: np.ndarray, labels: np.ndarray, length: int, tokenizer:
         np.ndarray: (L+P)-length mask array (1 marks a real token, 0 marks padding)
     """
 
-    pad_length = tokens.shape[0] - length
+    pad_length = length - tokens.shape[0]
     assert pad_length >= 0
 
     if pad_length > 0:
         tokens = np.concatenate([tokens, [[tokenizer.pad_token_id]] * pad_length])
         labels = np.concatenate([labels, [empty_action_vector()] * pad_length])
 
-    mask = np.ones(len(tokens)).astype(np.bool)
+    mask = np.ones(len(tokens)).astype(np.int)
 
     if pad_length > 0:
         mask[-pad_length:] = False
@@ -373,7 +373,13 @@ def batchify_data(tokens: np.ndarray, labels: np.ndarray, max_tokens: int,
         tokens_sample = tokens[ids, :]
         labels_sample = labels[ids, :]
 
+        assert len(ids) >= min_tokens
+        assert len(ids) <= max_tokens - 2
+
         tokens_sample, labels_sample = add_cls_sep(tokens_sample, labels_sample, tokenizer)
+
+        assert len(tokens_sample) <= max_tokens
+
         tokens_sample, labels_sample, mask = add_padding(tokens_sample, labels_sample, max_tokens, tokenizer)
 
         tokens_batch.append(tokens_sample)
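For reference, a minimal sketch (not part of the patch) of the corrected padding arithmetic, with made-up token ids and target length:

    import numpy as np

    tokens = np.array([[101], [2023], [102]])  # (L, 1) array of token ids, L = 3
    length = 5                                 # target length after padding

    pad_length = length - tokens.shape[0]      # 2; the old code computed the negation
    assert pad_length >= 0

    mask = np.ones(length).astype(int)         # 1 marks a real token
    if pad_length > 0:
        mask[-pad_length:] = 0                 # 0 marks padding
    # mask is now [1, 1, 1, 0, 0]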
diff --git a/dataset_generation/stage1_extraction.py b/dataset_generation/stage1_extraction.py
new file mode 100644
index 0000000..f14adcd
--- /dev/null
+++ b/dataset_generation/stage1_extraction.py
@@ -0,0 +1,35 @@
+#!/usr/bin/python3
+import glob
+import numpy as np
+from processing import text_from_xml, create_model_input_output
+from dask.diagnostics import ProgressBar
+import dask.dataframe as dd
+import pandas as pd
+
+OUTPUT_FOLDER = "../generated/stage1_extraction"
+
+if __name__ == "__main__":
+    file_schema = "../dane/**/text_structure.xml"
+    files_paths = glob.glob(file_schema, recursive=True)
+
+    def process_file(file_path):
+        full_text = text_from_xml(file_path)
+
+        if len(full_text) > 0:
+            model_input, model_output = create_model_input_output(full_text)
+
+            output_shape = np.array(model_output.shape, dtype=np.int)
+
+            return {'input': model_input, 'output': model_output.reshape(-1), 'output_shape': output_shape}
+        else:
+            return {'input': None, 'output': None, 'output_shape': None}
+
+    df = dd.from_pandas(pd.DataFrame({'file': files_paths}), npartitions=24)
+    res = df.apply(lambda x: process_file(x.file), result_type='expand', axis=1, meta={
+                   'input': str, 'output': object, 'output_shape': object})
+    res = res.dropna()
+
+    file_save = res.to_parquet(OUTPUT_FOLDER, compute=False)
+
+    with ProgressBar():
+        file_save.compute(scheduler='processes')
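A hedged sketch of reading the stage-1 result back; the folder name comes from OUTPUT_FOLDER above, and it assumes the object columns deserialize to NumPy arrays:

    import dask.dataframe as dd

    df = dd.read_parquet("../generated/stage1_extraction")
    row = df.head(1).iloc[0]                             # first extracted document
    labels = row["output"].reshape(row["output_shape"])  # restore the 2-D label array
    text = row["input"]                                  # cleaned input text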
diff --git a/dataset_generation/stage2_tokenization.py b/dataset_generation/stage2_tokenization.py
new file mode 100644
index 0000000..45b1827
--- /dev/null
+++ b/dataset_generation/stage2_tokenization.py
@@ -0,0 +1,58 @@
+#!/usr/bin/python3
+import numpy as np
+from processing import tokenize_labeled_text
+from dask.diagnostics import ProgressBar
+import dask.dataframe as dd
+from transformers import BertTokenizerFast
+from dask.distributed import Client  # only needed for the commented-out cluster setup below
+
+INPUT_FOLDER = "../generated/stage1_extraction"
+OUTPUT_FOLDER = "../generated/stage2_tokenization"
+
+if __name__ == "__main__":
+    df = dd.read_parquet(INPUT_FOLDER)
+
+    # set up cluster and workers
+    #client = Client(n_workers=24)
+    #print(client.dashboard_link)
+
+    df = df.repartition(npartitions=24 * 10)  # repartition is not in-place; rebind the result
+
+    tokenizer = BertTokenizerFast.from_pretrained(
+        'bert-base-multilingual-cased')
+
+    def apply_tokenization(text_clean: str, labels: np.ndarray, shape: np.ndarray):
+        global tokenizer
+
+        tokens, token_labels = tokenize_labeled_text(
+            text_clean, labels.reshape(shape), tokenizer)
+
+        new_input_shape = np.array(tokens.shape)
+        new_output_shape = np.array(token_labels.shape)
+
+        return {
+            'tokens': tokens.reshape(-1),
+            'labels': token_labels.reshape(-1),
+            'input_shape': new_input_shape,
+            "output_shape": new_output_shape
+        }
+
+    res = df.apply(lambda x: apply_tokenization(x.input, x.output, x.output_shape),
+                   result_type='expand', axis=1, meta={'tokens': object, 'labels': object, 'input_shape': object, 'output_shape': object})
+
+    file_save = res.to_parquet(OUTPUT_FOLDER, compute=False)
+
+    with ProgressBar():
+        file_save.compute(num_workers=1)
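Stage 2 stores the same flat-array-plus-shape pairs; a hedged round-trip check (the convert_ids_to_tokens call mirrors the commented-out line in test_processing.py):

    import dask.dataframe as dd
    from transformers import BertTokenizerFast

    tokenizer = BertTokenizerFast.from_pretrained('bert-base-multilingual-cased')
    row = dd.read_parquet("../generated/stage2_tokenization").head(1).iloc[0]
    tokens = row["tokens"].reshape(row["input_shape"])   # back to the (T, 1) layout
    print(tokenizer.convert_ids_to_tokens(tokens.reshape(-1).astype(int)))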
diff --git a/dataset_generation/stage3_spliting.py b/dataset_generation/stage3_spliting.py
new file mode 100644
index 0000000..435306b
--- /dev/null
+++ b/dataset_generation/stage3_spliting.py
@@ -0,0 +1,48 @@
+#!/usr/bin/python3
+from processing import batchify_data
+from dask.diagnostics import ProgressBar
+import dask.dataframe as dd
+from transformers import BertTokenizerFast
+import numpy as np
+
+INPUT_FOLDER = "../generated/stage2_tokenization"
+OUTPUT_FOLDER = "../generated/stage3_spliting"
+
+MIN_TOKENS = 10
+MAX_TOKENS = 50
+
+if __name__ == "__main__":
+    df = dd.read_parquet(INPUT_FOLDER)
+
+    tokenizer = BertTokenizerFast.from_pretrained(
+        'bert-base-multilingual-cased')
+
+    def apply_splitting(entry):
+        global tokenizer
+
+        tokens = entry.tokens.reshape(entry.input_shape)
+        labels = entry.labels.reshape(entry.output_shape)
+
+        tokens_split, labels_split, masks_split = batchify_data(
+            tokens, labels, MAX_TOKENS, tokenizer, MIN_TOKENS)
+
+        tokens_shape = np.array(tokens_split.shape)
+        labels_shape = np.array(labels_split.shape)
+        masks_shape = np.array(masks_split.shape)
+
+        return {
+            'input': tokens_split.reshape(-1),
+            'output': labels_split.reshape(-1),
+            "masks": masks_split.reshape(-1),
+            "input_shape": tokens_shape,
+            'output_shape': labels_shape,
+            "mask_shape": masks_shape
+        }
+
+    res = df.apply(apply_splitting,
+                   result_type='expand', axis=1, meta={'input': object, 'output': object, 'masks': object, 'input_shape': object, "output_shape": object, "mask_shape": object})
+
+    file_save = res.to_parquet(OUTPUT_FOLDER, compute=False)
+
+    with ProgressBar():
+        file_save.compute(num_workers=1)
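A hedged check of what this stage guarantees, with the expected shapes taken from the tests in this patch (samples are padded to MAX_TOKENS, and padding only ever appears at the end):

    import dask.dataframe as dd

    row = dd.read_parquet("../generated/stage3_spliting").head(1).iloc[0]
    inputs = row["input"].reshape(row["input_shape"])  # (N, MAX_TOKENS, 1)
    masks = row["masks"].reshape(row["mask_shape"])    # (N, MAX_TOKENS)
    assert inputs.shape[1] == 50                       # MAX_TOKENS above
    assert (masks[:, 0] == 1).all()                    # a sample never starts with padding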
diff --git a/dataset_generation/stage4_exploding.py b/dataset_generation/stage4_exploding.py
new file mode 100644
index 0000000..addead2
--- /dev/null
+++ b/dataset_generation/stage4_exploding.py
@@ -0,0 +1,50 @@
+#!/usr/bin/python3
+from dask.diagnostics import ProgressBar
+from dask.distributed import Client
+import dask.dataframe as dd
+from transformers import BertTokenizerFast
+import numpy as np
+
+INPUT_FOLDER = "../generated/stage3_spliting"
+OUTPUT_FOLDER = "../generated/stage4_exploding"
+
+MIN_TOKENS = 10
+MAX_TOKENS = 50
+
+if __name__ == "__main__":
+    client = Client()
+
+    df = dd.read_parquet(INPUT_FOLDER)
+
+    tokenizer = BertTokenizerFast.from_pretrained(
+        'bert-base-multilingual-cased')
+
+    def expand_dims(entry):
+        global tokenizer
+
+        inputs = entry.input.reshape(entry.input_shape)
+        outputs = entry.output.reshape(entry.output_shape)
+        masks = entry.masks.reshape(entry.mask_shape)
+
+        return {
+            'inputs': inputs,
+            'outputs': outputs,
+            "masks": masks,
+        }
+
+    res = df.repartition(partition_size="100MB")
+
+    res = res.apply(expand_dims, result_type='expand', axis=1, meta={'inputs': object, 'outputs': object, 'masks': object})
+    res = res.map_partitions(lambda x: x.apply(lambda y: y.explode(), axis=0), meta={'inputs': object, 'outputs': object, 'masks': object})
+
+    res.visualize('log.svg')
+
+    file_save = res.to_parquet(OUTPUT_FOLDER, compute=False)
+
+    with ProgressBar():
+        file_save.compute(num_workers=24, scheduler="processes")
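For intuition, the explode step turns one row holding a batch of N samples into N rows holding one sample each; a toy illustration with invented data:

    import numpy as np
    import pandas as pd

    df = pd.DataFrame({
        "inputs":  [[np.zeros((50, 1)), np.ones((50, 1))]],  # one row, batch of 2
        "outputs": [[np.zeros((50, 6)), np.ones((50, 6))]],
        "masks":   [[np.zeros(50), np.ones(50)]],
    })
    exploded = df.apply(lambda col: col.explode(), axis=0)   # two rows, one sample each
    assert len(exploded) == 2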
diff --git a/dataset_generation/test_processing.py b/dataset_generation/test_processing.py
index 573c6ca..13f5b72 100644
--- a/dataset_generation/test_processing.py
+++ b/dataset_generation/test_processing.py
@@ -1,7 +1,7 @@
-import numpy
 from processing import *
 from transformers import PreTrainedTokenizerFast, BertTokenizerFast
 
+
 def test_detect_actions():
     actions = detect_actions("Janek...", None)
     assert actions == {
@@ -46,6 +46,7 @@ def test_encode_actions():
 
     assert np.all(encode_actions(x) == np.array([1, 0, 0, 1, 0, 1]))
 
+
 def test_decode_actions():
     x = np.array([1, 0, 0, 1, 0, 1])
 
@@ -58,9 +59,11 @@ def test_decode_actions():
         'dash': True
     }
 
+
 def test_tokenize_labeled_text():
     text = "Janek poszedł do ogrodu. Ogród był zwierzęcy. Spotkał tam Zosię..."
-    tokenizer = BertTokenizerFast.from_pretrained('bert-base-multilingual-cased')
+    tokenizer = BertTokenizerFast.from_pretrained(
+        'bert-base-multilingual-cased')
 
     text_clean, labels = create_model_input_output(text)
     tokens, token_labels = tokenize_labeled_text(text_clean, labels, tokenizer)
@@ -75,6 +78,7 @@ def test_tokenize_labeled_text():
     assert tokens[0, 0] != tokenizer.cls_token_id
     assert tokens[-1, 0] != tokenizer.sep_token_id
 
+
 def test_nearest_sentence_l():
     end = create_dummy_action(True)
     word = create_dummy_action(False)
@@ -86,6 +90,7 @@ def test_nearest_sentence_l():
     assert nearest_sentence_l(entry, 5) == 5
     assert nearest_sentence_l(entry, 7) == 5
 
+
 def create_dummy_action(end_sentence: bool) -> np.array:
     return encode_actions({
         'dot': end_sentence,
@@ -96,6 +101,7 @@ def create_dummy_action(end_sentence: bool) -> np.array:
         'dash': False
     })
 
+
 def test_nearest_sentence_r():
     end = create_dummy_action(True)
     word = create_dummy_action(False)
@@ -108,6 +114,7 @@ def test_nearest_sentence_r():
     assert nearest_sentence_r(entry, 6) is None
     assert nearest_sentence_r(entry, 7) is None
 
+
 def test_batchify_labels():
     end = create_dummy_action(True)
     word = create_dummy_action(False)
@@ -119,18 +126,21 @@ def test_batchify_labels():
     assert np.all(batches[0] == range(0, 3))
     assert np.all(batches[1] == range(5, 8))
 
-def test_batchify_tokens():
+
+def test_batchify_data():
     text = "Janek poszedł do ogrodu. Ogród był zwierzęcy. Spotkał tam niedzwiedzia..."
-    tokenizer = BertTokenizerFast.from_pretrained('bert-base-multilingual-cased')
+    tokenizer = BertTokenizerFast.from_pretrained(
+        'bert-base-multilingual-cased')
 
     text_clean, labels = create_model_input_output(text)
     tokens, token_labels = tokenize_labeled_text(text_clean, labels, tokenizer)
 
     # print(tokenizer.convert_ids_to_tokens(tokens.reshape(-1).astype(int)))
-    #print(token_labels)
+    # print(token_labels)
+
+    input_batch, output_batch, mask_batch = batchify_data(
+        tokens, token_labels, 5, tokenizer)
 
-    input_batch, output_batch, mask_batch = batchify_data(tokens, token_labels, 5, tokenizer)
-    
     assert len(input_batch.shape) == 3
     assert len(output_batch.shape) == 3
     assert len(mask_batch.shape) == 2
@@ -148,20 +158,17 @@ def test_batchify_tokens():
     # Third dimension should be feature size
     assert input_batch.shape[2] == 1
     assert output_batch.shape[2] == len(ACTIONS_KEYS)
-    
-    # Mask should be boolean (True - leave, False - mask)
-    assert mask_batch.dtype == np.bool
-    
+
+    # Mask should be integer (1 - leave, 0 - mask out)
+    assert mask_batch.dtype == np.int
+
     # Should never be fully masked
-    assert np.all(mask_batch[:, 0] == True)
+    assert not np.all(mask_batch[:, 0] == 0)
 
     for i in range(input_batch.shape[0]):
         # Should always start from beginning of the sentence
-        assert decode_actions(output_batch[i, 0, :])['upper_case'] == True
-        assert decode_actions(output_batch[i, 1, :])['upper_case'] == True
-
-        # Should always end with sep and padding
-
+        assert decode_actions(output_batch[i, 0, :])['upper_case']
+        assert decode_actions(output_batch[i, 1, :])['upper_case']
 
-def generate_batches(files: (str, str), batch_size: int, max_tokens: int):
-    pass
\ No newline at end of file
+        # Should always end with sep and padding
+        # TODO: Test it
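
For orientation, these tests pin down the action encoding used throughout the pipeline: encode_actions packs a per-word action dict into a fixed-order binary vector over ACTIONS_KEYS, and decode_actions inverts it. A minimal pair consistent with the assertions above (a sketch; the repository's actual implementation may differ in details):

    import numpy as np

    ACTIONS_KEYS = ['dot', 'upper_case', 'colon', 'semicolon', 'elipsis', 'dash']

    def encode_actions(actions: dict) -> np.ndarray:
        # One binary slot per action, in fixed ACTIONS_KEYS order.
        return np.array([int(actions[key]) for key in ACTIONS_KEYS])

    def decode_actions(encoded: np.ndarray) -> dict:
        # Inverse mapping back to the per-word action dict.
        return {key: bool(encoded[i]) for i, key in enumerate(ACTIONS_KEYS)}
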
diff --git a/generated/.gitignore b/generated/.gitignore
new file mode 100644
index 0000000..c96a04f
--- /dev/null
+++ b/generated/.gitignore
@@ -0,0 +1,2 @@
+*
+!.gitignore
\ No newline at end of file
-- 
GitLab


From e02b6c591f1d5b43317b551df6982d49d6a19278 Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Thu, 16 Jul 2020 14:51:47 +0200
Subject: [PATCH 008/116] Readme update

---
 README.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/README.md b/README.md
index 4a87d36..80bac4a 100644
--- a/README.md
+++ b/README.md
@@ -1 +1,2 @@
 # Interpunkcja
+Restoring the original form of text that has been lowercased and stripped of punctuation.
\ No newline at end of file
-- 
GitLab


From fc6a2f217fac82555ad2cfdb7aa6828556154ba5 Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Fri, 17 Jul 2020 22:22:02 +0200
Subject: [PATCH 009/116] Fixed all memory leaks

---
 .gitignore                                |  3 +-
 dataset_generation/processing.py          |  8 ++-
 dataset_generation/stage1_extraction.py   | 50 ++++++++++------
 dataset_generation/stage2_tokenization.py | 70 ++++++++++++++---------
 dataset_generation/stage3_exploding.py    | 62 ++++++++++++++++++++
 dataset_generation/stage3_spliting.py     | 48 ----------------
 dataset_generation/stage4_exploding.py    | 50 ----------------
 7 files changed, 147 insertions(+), 144 deletions(-)
 create mode 100644 dataset_generation/stage3_exploding.py
 delete mode 100644 dataset_generation/stage3_spliting.py
 delete mode 100644 dataset_generation/stage4_exploding.py

diff --git a/.gitignore b/.gitignore
index 1633793..0bdd65c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -4,4 +4,5 @@ dataset_actions
 **/dask-worker-space
 .vscode
 .idea
-test.ipynb
\ No newline at end of file
+test.ipynb
+.metals
\ No newline at end of file
diff --git a/dataset_generation/processing.py b/dataset_generation/processing.py
index 6543b7e..0b78970 100644
--- a/dataset_generation/processing.py
+++ b/dataset_generation/processing.py
@@ -1,6 +1,6 @@
 import glob
 import random
-from lxml import etree
+from xml.etree import ElementTree as ET
 import uuid
 import hashlib
 import seaborn as sns
@@ -11,6 +11,7 @@ from utils import remove_punctuation
 import numpy as np
 from more_itertools import windowed
 from transformers import PreTrainedTokenizerFast, BertTokenizerFast
+from memory_profiler import profile
 
 ACTIONS_KEYS = ['dot', 'upper_case', 'colon', 'semicolon', 'elipsis', 'dash']
 
@@ -31,7 +32,8 @@ def text_from_xml(path: str) -> str:
     Returns:
         str: Raw text
     """
-    root = etree.parse(path)
+    root = ET.parse(path).getroot()
+
     full_text = ""
 
     for node in root.iter('*'):
@@ -42,6 +44,8 @@ def text_from_xml(path: str) -> str:
             if text is not None and who is not None and who != "#komentarz":
                 full_text = " ".join([full_text, text])
 
+    del root
+
     return full_text
 
 
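
The fix above swaps lxml for the stdlib ElementTree and explicitly deletes the parsed tree, so each document can be garbage-collected before the next one is read. If memory were still tight, a streaming iterparse variant would avoid building the whole tree at all; a minimal sketch, assuming `who` is a node attribute and the same "#komentarz" filter as in the original:

    from xml.etree import ElementTree as ET

    def text_from_xml_streaming(path: str) -> str:
        parts = []
        for _, node in ET.iterparse(path, events=("end",)):
            text = node.text
            who = node.attrib.get("who")
            if text is not None and who is not None and who != "#komentarz":
                parts.append(text)
            node.clear()  # free the subtree we just consumed
        return " ".join(parts)
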
diff --git a/dataset_generation/stage1_extraction.py b/dataset_generation/stage1_extraction.py
index f14adcd..924e46f 100644
--- a/dataset_generation/stage1_extraction.py
+++ b/dataset_generation/stage1_extraction.py
@@ -4,32 +4,50 @@ import numpy as np
 from processing import text_from_xml, create_model_input_output
 from dask.diagnostics import ProgressBar
 import dask.dataframe as dd
+import dask
 import pandas as pd
+from dask.distributed import Client
+import gc
+from memory_profiler import profile
+from pympler import muppy, summary
+import stackimpact
+import lorem
 
 OUTPUT_FOLDER = "../generated/stage1_extraction"
+NUM_PARTITIONS = 2_000
+NUM_WORKERS=24
+
+def process_file(x):
+    full_text = text_from_xml(x.file)
+
+    if len(full_text) > 0:
+        model_input, model_output = create_model_input_output(full_text)
+
+        output_shape = np.array(model_output.shape, dtype=np.int)
+
+        return {'input': model_input, 'output': model_output.reshape(-1), 'output_shape': output_shape}
+    else:
+        return {'input': None, 'output': None, 'output_shape': None}
 
 if __name__ == "__main__":
-    file_schema = "../dane/**/text_structure.xml"
+    file_schema = "../data/**/text_structure.xml"
     files_paths = glob.glob(file_schema, recursive=True)
 
-    def process_file(file_path):
-        full_text = text_from_xml(file_path)
-
-        if len(full_text) > 0:
-            model_input, model_output = create_model_input_output(full_text)
+    # Make sure python memory fragmentation won't go insane
+    np.random.shuffle(files_paths)
 
-            output_shape = np.array(model_output.shape, dtype=np.int)
+    client = Client(n_workers=NUM_WORKERS)
+    print(f"Dashboard: {client.dashboard_link}")
 
-            return {'input': model_input, 'output': model_output.reshape(-1), 'output_shape': output_shape}
-        else:
-            return {'input': None, 'output': None, 'output_shape': None}
+    # Processing pipeline
+    df = dd.from_pandas(pd.DataFrame({'file': files_paths}), npartitions=NUM_PARTITIONS)
 
-    df = dd.from_pandas(pd.DataFrame({'file': files_paths}), npartitions=24)
-    res = df.apply(lambda x: process_file(x.file), result_type='expand', axis=1, meta={
+    df = df.apply(process_file, result_type='expand', axis=1, meta={
                    'input': str, 'output': object, 'output_shape': object})
-    res = res.dropna()
+    df = df.dropna()
 
-    file_save = res.to_parquet(OUTPUT_FOLDER, compute=False)
+    # Even out sizes
+    # df = df.repartition(partition_size="100MB")
 
-    with ProgressBar():
-        file_save.compute(scheduler='processes')
+    # Export
+    df.to_parquet(OUTPUT_FOLDER, engine="pyarrow")
\ No newline at end of file
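
The refactored stage1 establishes the pattern every later stage reuses: build a dask DataFrame, `apply` a row-wise function with an explicit `meta` describing the output columns, drop failures, write parquet. The same pattern in miniature, on toy data (names here are illustrative only):

    import dask.dataframe as dd
    import pandas as pd

    def process(row):
        # One dict per row; its keys must match the meta below.
        return {"input": row.file.upper(), "output": len(row.file)}

    df = dd.from_pandas(pd.DataFrame({"file": ["a.xml", "b.xml"]}), npartitions=2)
    out = df.apply(process, result_type="expand", axis=1,
                   meta={"input": str, "output": int})
    print(out.compute())
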
diff --git a/dataset_generation/stage2_tokenization.py b/dataset_generation/stage2_tokenization.py
index 45b1827..8482c76 100644
--- a/dataset_generation/stage2_tokenization.py
+++ b/dataset_generation/stage2_tokenization.py
@@ -9,50 +9,66 @@ import seaborn as sns
 import re
 import numpy as np
 from tqdm import tqdm
-from processing import text_from_xml, create_model_input_output, tokenize_labeled_text
+from processing import tokenize_labeled_text, batchify_data
 from utils import remove_multiple_spaces, remove_punctuation
 import dask
 from dask.diagnostics import ProgressBar
 import dask.dataframe as dd
-import pandas as pd
 from transformers import BertTokenizerFast
 from dask.distributed import Client
 
 INPUT_FOLDER = "../generated/stage1_extraction"
 OUTPUT_FOLDER = "../generated/stage2_tokenization"
 
-if __name__ == "__main__":
-    df = dd.read_parquet(INPUT_FOLDER)
+MIN_TOKENS = 10
+MAX_TOKENS = 50
 
-    # set up cluster and workers
-    #client = Client(n_workers=24)
-    #print(client.dashboard_link)
 
-    df.repartition(npartitions=(24 * 10))
+def apply_tokenization(df, tokenizer: BertTokenizerFast):
+    text_clean = df.input
+    labels = df.output
+    shape = df.output_shape
 
-    tokenizer = BertTokenizerFast.from_pretrained(
-        'bert-base-multilingual-cased')
+    tokens, token_labels = tokenize_labeled_text(
+        text_clean, labels.reshape(shape), tokenizer)
+
+    inputs, outputs, attentions = batchify_data(
+        tokens, token_labels, MAX_TOKENS, tokenizer, MIN_TOKENS)
+
+    inputs_shape = np.array(inputs.shape)
+    outputs_shape = np.array(outputs.shape)
+    attentions_shape = np.array(attentions.shape)
 
-    def apply_tokenization(text_clean: str, labels: np.ndarray, shape: np.ndarray):
-        global tokenizer
+    return {
+        'inputs': inputs.reshape(-1),
+        'outputs': outputs.reshape(-1),
+        'attentions': attentions.reshape(-1),
+        'input_shape': inputs_shape,
+        'output_shape': outputs_shape,
+        'attentions_shape': attentions_shape
+    }
 
-        tokens, token_labels = tokenize_labeled_text(
-            text_clean, labels.reshape(shape), tokenizer)
 
-        new_input_shape = np.array(tokens.shape)
-        new_output_shape = np.array(token_labels.shape)
+RESULT_META = {
+    'inputs': object,
+    'outputs': object,
+    'attentions': object,
+    'inputs_shape': object,
+    'outputs_shape': object,
+    'attentions_shape': object
+}
 
-        return {
-            'tokens': tokens.reshape(-1),
-            'labels': token_labels.reshape(-1),
-            'input_shape': new_input_shape,
-            "output_shape": new_output_shape
-        }
+if __name__ == "__main__":
+    client = Client(n_workers=24)
+    print(client.dashboard_link)
+
+    tokenizer = BertTokenizerFast.from_pretrained(
+        'bert-base-multilingual-cased')
 
-    res = df.apply(lambda x: apply_tokenization(x.input, x.output, x.output_shape),
-                   result_type='expand', axis=1, meta={'tokens': object, 'labels': object, 'input_shape': object, 'output_shape': object})
+    tokenizer = dask.delayed(tokenizer)
 
-    file_save = res.to_parquet(OUTPUT_FOLDER, compute=False)
+    df = dd.read_parquet(INPUT_FOLDER, engine="pyarrow")
+    df = df.apply(apply_tokenization, args=(tokenizer,),
+                  result_type='expand', axis=1, meta=RESULT_META)
 
-    with ProgressBar():
-        file_save.compute(num_workers=1)
+    df.to_parquet(OUTPUT_FOLDER, engine="pyarrow")
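
Note the `dask.delayed(tokenizer)` wrapper: instead of capturing the large tokenizer in every task's closure (and re-serializing it per task), the object is embedded in the task graph once and handed to `apply` through `args`. A minimal sketch of the same trick, with a stand-in for the tokenizer:

    import dask
    import dask.dataframe as dd
    import pandas as pd

    big_object = {"vocab": list(range(100_000))}  # stand-in for a heavy tokenizer
    big = dask.delayed(big_object)                # serialized into the graph once

    def fn(row, obj):
        return {"n": row.x + len(obj["vocab"])}

    df = dd.from_pandas(pd.DataFrame({"x": [1, 2, 3]}), npartitions=2)
    out = df.apply(fn, args=(big,), result_type="expand", axis=1, meta={"n": int})
    print(out.compute())
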
diff --git a/dataset_generation/stage3_exploding.py b/dataset_generation/stage3_exploding.py
new file mode 100644
index 0000000..bfcb088
--- /dev/null
+++ b/dataset_generation/stage3_exploding.py
@@ -0,0 +1,62 @@
+#!/usr/bin/python3
+from processing import batchify_data
+from dask.diagnostics import ProgressBar
+import dask.dataframe as dd
+from transformers import BertTokenizerFast
+import numpy as np
+import dask
+from dask.distributed import Client
+
+INPUT_FOLDER = "../generated/stage2_tokenization"
+OUTPUT_FOLDER = "../generated/stage3_exploding"
+
+def expand_dims(entry):
+    inputs = entry.inputs.reshape(entry.input_shape)
+    outputs = entry.outputs.reshape(entry.output_shape)
+    masks = entry.attentions.reshape(entry.attentions_shape)
+
+    return {
+        'inputs': inputs,
+        'outputs': outputs,
+        "attentions": masks,
+    }
+
+def flatten_dims(entry):
+    inputs_shape = np.array(entry.inputs.shape)
+    outputs_shape = np.array(entry.outputs.shape)
+    attentions_shape = np.array(entry.attentions.shape)
+
+    inputs = entry.inputs.reshape(-1)
+    outputs = entry.outputs.reshape(-1)
+    attentions = entry.attentions.reshape(-1)
+
+    return {
+        'inputs': inputs,
+        'outputs': outputs,
+        'attentions': attentions,
+        'inputs_shape': inputs_shape,
+        'outputs_shape': outputs_shape,
+        'attentions_shape': attentions_shape
+    }
+
+
+RESULT_META = {
+    'inputs': object,
+    'outputs': object,
+    'attentions': object,
+    'inputs_shape': object,
+    'outputs_shape': object,
+    'attentions_shape': object
+}
+
+if __name__ == "__main__":
+    client = Client(n_workers=24)
+    print(client.dashboard_link)
+
+    df = dd.read_parquet(INPUT_FOLDER, engine='pyarrow')
+
+    df = df.apply(expand_dims, result_type='expand', axis=1, meta={'inputs': object, 'outputs': object, 'attentions': object})
+    df = df.map_partitions(lambda x: x.apply(lambda y: y.explode(), axis=0), meta={'inputs': object, 'outputs': object, 'attentions': object})
+    df = df.apply(flatten_dims, result_type='expand', axis=1, meta=RESULT_META)
+
+    df.to_parquet(OUTPUT_FOLDER, engine='pyarrow')
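
The expand, explode, flatten sequence is easiest to see on a toy frame: after expand_dims each cell holds a whole batch array, and exploding every column in lockstep turns one row with N batch entries into N aligned rows of single entries, which flatten_dims then re-serializes for parquet. The core step in plain pandas:

    import numpy as np
    import pandas as pd

    df = pd.DataFrame({
        "inputs": [np.arange(6).reshape(2, 3)],  # one row holding a batch of 2
        "outputs": [np.array([10, 20])],
    })

    # Explode each column in lockstep: 1 row of batch arrays -> 2 aligned rows.
    exploded = df.apply(lambda col: col.explode(), axis=0)
    print(exploded)
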
diff --git a/dataset_generation/stage3_spliting.py b/dataset_generation/stage3_spliting.py
deleted file mode 100644
index 435306b..0000000
--- a/dataset_generation/stage3_spliting.py
+++ /dev/null
@@ -1,48 +0,0 @@
-# /usr/bin/python3
-from processing import batchify_data
-from dask.diagnostics import ProgressBar
-import dask.dataframe as dd
-from transformers import BertTokenizerFast
-import numpy as np
-
-INPUT_FOLDER = "../generated/stage2_tokenization"
-OUTPUT_FOLDER = "../generated/stage3_spliting"
-
-MIN_TOKENS = 10
-MAX_TOKENS = 50
-
-if __name__ == "__main__":
-    df = dd.read_parquet(INPUT_FOLDER)
-
-    tokenizer = BertTokenizerFast.from_pretrained(
-        'bert-base-multilingual-cased')
-
-    def apply_splitting(entry):
-        global tokenizer
-
-        tokens = entry.tokens.reshape(entry.input_shape)
-        labels = entry.labels.reshape(entry.output_shape)
-
-        tokens_split, labels_split, masks_split = batchify_data(
-            tokens, labels, MAX_TOKENS, tokenizer, MIN_TOKENS)
-
-        tokens_shape = np.array(tokens_split.shape)
-        labels_shape = np.array(labels_split.shape)
-        masks_shape = np.array(masks_split.shape)
-
-        return {
-            'input': tokens_split.reshape(-1),
-            'output': labels_split.reshape(-1),
-            "masks": masks_split.reshape(-1),
-            "input_shape": tokens_shape,
-            'output_shape': labels_shape,
-            "mask_shape": masks_shape
-        }
-
-    res = df.apply(apply_splitting,
-                   result_type='expand', axis=1, meta={'input': object, 'output': object, 'masks': object, 'input_shape': object, "output_shape": object, "mask_shape": object})
-
-    file_save = res.to_parquet(OUTPUT_FOLDER, compute=False)
-
-    with ProgressBar():
-        file_save.compute(num_workers=1)
diff --git a/dataset_generation/stage4_exploding.py b/dataset_generation/stage4_exploding.py
deleted file mode 100644
index addead2..0000000
--- a/dataset_generation/stage4_exploding.py
+++ /dev/null
@@ -1,50 +0,0 @@
-# /usr/bin/python3
-from processing import batchify_data
-from dask.diagnostics import ProgressBar
-import dask.dataframe as dd
-from transformers import BertTokenizerFast
-import numpy as np
-
-INPUT_FOLDER = "../generated/stage3_spliting"
-OUTPUT_FOLDER = "../generated/stage4_exploding"
-from dask.distributed import Client
-
-MIN_TOKENS = 10
-MAX_TOKENS = 50
-
-if __name__ == "__main__":
-    client = Client()
-
-    df = dd.read_parquet(INPUT_FOLDER)
-
-    tokenizer = BertTokenizerFast.from_pretrained(
-        'bert-base-multilingual-cased')
-
-    def expand_dims(entry):
-        global tokenizer
-
-        inputs = entry.input.reshape(entry.input_shape)
-        outputs = entry.output.reshape(entry.output_shape)
-        masks = entry.masks.reshape(entry.mask_shape)
-
-        return {
-            'inputs': inputs,
-            'outputs': outputs,
-            "masks": masks,
-        }
-
-    res = df.repartition(partition_size="100MB")
-    file_save = res.to_parquet(OUTPUT_FOLDER, compute=False)
-
-    with ProgressBar():
-        file_save.persist(num_workers=2)
-
-    res = res.apply(expand_dims, result_type='expand', axis=1, meta={'inputs': object, 'outputs': object, 'masks': object})
-    res = res.map_partitions(lambda x: x.apply(lambda y: y.explode(), axis=0), meta={'inputs': object, 'outputs': object, 'masks': object})
-
-    res.visualize('log.svg')
-
-    file_save = res.to_parquet(OUTPUT_FOLDER, compute=False)
-
-    with ProgressBar():
-        file_save.compute(num_workers=24, scheduler="processes")
-- 
GitLab


From d993b794468ad967db76a92f27e0364b071c377d Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Thu, 23 Jul 2020 12:17:30 +0200
Subject: [PATCH 010/116] More memory leak fixes, first training procedure

---
 dataset_generation/stage2_tokenization.py |  4 +-
 dataset_generation/stage3_exploding.py    |  5 +-
 dataset_generation/stage4_reindexing.py   | 27 +++++++++
 modeling/train.py                         | 69 +++++++++++++++++++++++
 4 files changed, 101 insertions(+), 4 deletions(-)
 create mode 100644 dataset_generation/stage4_reindexing.py
 create mode 100755 modeling/train.py

diff --git a/dataset_generation/stage2_tokenization.py b/dataset_generation/stage2_tokenization.py
index 8482c76..88ea96d 100644
--- a/dataset_generation/stage2_tokenization.py
+++ b/dataset_generation/stage2_tokenization.py
@@ -53,8 +53,8 @@ RESULT_META = {
     'inputs': object,
     'outputs': object,
     'attentions': object,
-    'inputs_shape': object,
-    'outputs_shape': object,
+    'input_shape': object,
+    'output_shape': object,
     'attentions_shape': object
 }
 
diff --git a/dataset_generation/stage3_exploding.py b/dataset_generation/stage3_exploding.py
index bfcb088..bfd46bb 100644
--- a/dataset_generation/stage3_exploding.py
+++ b/dataset_generation/stage3_exploding.py
@@ -6,6 +6,7 @@ from transformers import BertTokenizerFast
 import numpy as np
 import dask
 from dask.distributed import Client
+import pandas as pd
 
 INPUT_FOLDER = "../generated/stage2_tokenization"
 OUTPUT_FOLDER = "../generated/stage3_exploding"
@@ -50,7 +51,7 @@ RESULT_META = {
 }
 
 if __name__ == "__main__":
-    client = Client(n_workers=24)
+    client = Client(n_workers=24, memory_limit='2GB')
     print(client.dashboard_link)
 
     df = dd.read_parquet(INPUT_FOLDER, engine='pyarrow')
@@ -58,5 +59,5 @@ if __name__ == "__main__":
     df = df.apply(expand_dims, result_type='expand', axis=1, meta={'inputs': object, 'outputs': object, 'attentions': object})
     df = df.map_partitions(lambda x: x.apply(lambda y: y.explode(), axis=0), meta={'inputs': object, 'outputs': object, 'attentions': object})
     df = df.apply(flatten_dims, result_type='expand', axis=1, meta=RESULT_META)
-
+    
     df.to_parquet(OUTPUT_FOLDER, engine='pyarrow')
diff --git a/dataset_generation/stage4_reindexing.py b/dataset_generation/stage4_reindexing.py
new file mode 100644
index 0000000..920d743
--- /dev/null
+++ b/dataset_generation/stage4_reindexing.py
@@ -0,0 +1,27 @@
+#!/usr/bin/python3
+from processing import batchify_data
+from dask.diagnostics import ProgressBar
+import dask.dataframe as dd
+from transformers import BertTokenizerFast
+import numpy as np
+import dask
+from dask.distributed import Client
+import pandas as pd
+
+INPUT_FOLDER = "../generated/stage3_exploding"
+OUTPUT_FOLDER = "../generated/stage4_reindexing"
+
+if __name__ == "__main__":
+    client = Client(n_workers=1, memory_limit='60GB')
+    print(client.dashboard_link)
+
+    df = dd.read_parquet(INPUT_FOLDER, engine='pyarrow')
+
+    df = df.assign(ones=1)
+    df = df.reset_index(drop=True)
+    idx = df.ones.cumsum().persist()
+    df = df.assign(ones=idx)
+    #df = df.assign(idx=df.idx - 1)
+    df = df.set_index('ones')
+ 
+    df.to_parquet(OUTPUT_FOLDER, engine='pyarrow')
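
The reindexing trick: a column of ones, cumulatively summed, gives every row a dense 1-based position across all partitions, and `set_index` turns that into a sorted index that `.loc` can slice cheaply, which is exactly what the training loop below depends on for batching. The same idea in miniature:

    import dask.dataframe as dd
    import pandas as pd

    df = dd.from_pandas(pd.DataFrame({"x": list("abcde")}), npartitions=2)
    df = df.assign(ones=1)
    df = df.assign(ones=df.ones.cumsum())  # 1..5, monotone across partitions
    df = df.set_index("ones")
    print(df.loc[2:4].compute())           # contiguous slice by position
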
diff --git a/modeling/train.py b/modeling/train.py
new file mode 100755
index 0000000..72ce861
--- /dev/null
+++ b/modeling/train.py
@@ -0,0 +1,69 @@
+#!/usr/bin/python3
+
+from transformers import BertTokenizer, BertForTokenClassification
+import torch
+from torch.nn import BCEWithLogitsLoss
+import pandas as pd
+from joblib import Memory
+import numpy as np
+import dask.dataframe as dd
+
+INPUT_PATH="../generated/stage4_reindexing"
+MODEL_BASE = "bert-base-multilingual-cased"
+LR = 1e-3
+
+BATCH_SIZE=8
+NUM_EPOCH=5
+
+if __name__ == "__main__":
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+    print(f"Training on {device}")
+
+    df = dd.read_parquet(INPUT_PATH, engine="pyarrow")
+    
+    tokenizer = BertTokenizer.from_pretrained(MODEL_BASE)
+
+    # TODO: Change num labels
+    model = BertForTokenClassification.from_pretrained(MODEL_BASE, num_labels=6).to(device)
+    
+    criterion = BCEWithLogitsLoss().to(device)
+    optimizer = torch.optim.Adam(model.parameters(), lr=LR)
+
+    for epoch in range(NUM_EPOCH):
+        i = 0
+        while True:
+            # TODO: Change to 0-indexed...
+            data_batch_indexes = list(range(i*BATCH_SIZE+1, i*BATCH_SIZE + BATCH_SIZE +1))
+            
+            # Precomputing the total number of samples takes very long, so
+            # keep requesting the next batch until the index lookup fails
+            try:
+                data_batch = df.loc[data_batch_indexes].compute()
+            except:
+                # TODO: Specify exception type
+                break
+
+            inputs = data_batch.apply(lambda x: x['inputs'].reshape(x['inputs_shape']), axis=1).values
+            outputs = data_batch.apply(lambda x: x['outputs'].reshape(x['outputs_shape']), axis=1).values
+            attentions_mask = data_batch.apply(lambda x: x['attentions'].reshape(x['attentions_shape']), axis=1).values
+
+            inputs = torch.tensor(np.stack(inputs).squeeze()).to(device)
+            outputs = torch.tensor(np.stack(outputs)).to(device)
+            attentions_mask = torch.tensor(np.stack(attentions_mask)).to(device)
+
+            # Forward pass: Compute predicted y by passing x to the model
+            y_pred = model(input_ids=inputs, attention_mask=attentions_mask)[0]
+
+            # Compute and print loss
+            loss = criterion(y_pred, outputs)
+            print('epoch: ', epoch,' loss: ', loss.item())
+
+            # Zero gradients, perform a backward pass, and update the weights.
+            optimizer.zero_grad()
+
+            # perform a backward pass (backpropagation)
+            loss.backward()
+
+            # Update the parameters
+            optimizer.step()
+
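
One modeling detail worth spelling out: a token can need several actions at once (the encoding tests above combine e.g. dot, semicolon and dash flags), so the targets are six independent binaries per token rather than one softmax class. That is why the loss is BCEWithLogitsLoss over float targets shaped like the logits. A minimal shape-check sketch:

    import torch
    from torch.nn import BCEWithLogitsLoss

    logits = torch.randn(8, 50, 6)                     # (batch, seq_len, num_actions)
    targets = torch.randint(0, 2, (8, 50, 6)).float()  # independent 0/1 per action
    loss = BCEWithLogitsLoss()(logits, targets)
    print(loss.item())
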
-- 
GitLab


From 8d56401ef8eca6af8d1465dd3d5471ef391efe36 Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Thu, 23 Jul 2020 12:29:42 +0200
Subject: [PATCH 011/116] Saving model

---
 modeling/train.py | 18 +++++++++++-------
 1 file changed, 11 insertions(+), 7 deletions(-)

diff --git a/modeling/train.py b/modeling/train.py
index 72ce861..764913a 100755
--- a/modeling/train.py
+++ b/modeling/train.py
@@ -10,10 +10,12 @@ import dask.dataframe as dd
 
 INPUT_PATH="../generated/stage4_reindexing"
 MODEL_BASE = "bert-base-multilingual-cased"
-LR = 1e-3
+MODEL_NAME = "actionv1"
+LR = 1e-4
 
 BATCH_SIZE=8
 NUM_EPOCH=5
+SAVE_STEP=5_000
 
 if __name__ == "__main__":
     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
@@ -51,19 +53,21 @@ if __name__ == "__main__":
             outputs = torch.tensor(np.stack(outputs)).to(device)
             attentions_mask = torch.tensor(np.stack(attentions_mask)).to(device)
 
-            # Forward pass: Compute predicted y by passing x to the model
             y_pred = model(input_ids=inputs, attention_mask=attentions_mask)[0]
 
-            # Compute and print loss
             loss = criterion(y_pred, outputs)
-            print('epoch: ', epoch,' loss: ', loss.item())
+            print(f'epoch: {epoch} | step: {i} | loss: {loss.item()}')
 
-            # Zero gradients, perform a backward pass, and update the weights.
             optimizer.zero_grad()
 
-            # perform a backward pass (backpropagation)
+            if i % SAVE_STEP == 0:
+                print(f"Saving: Epoch {epoch}, step {i}")
+                torch.save(model.state_dict(), f"models/{MODEL_NAME}-{epoch}-{i}.model")
+                torch.save(optimizer.state_dict(), f"models/{MODEL_NAME}-{epoch}-{i}.optimizer")
+
             loss.backward()
 
-            # Update the parameters
             optimizer.step()
 
+            i += 1
+
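
The checkpoints written here are plain state_dicts, so restoring one later means rebuilding the same architecture first and then loading the weights, as the evaluation notebook in a later commit does. A sketch, with a hypothetical checkpoint path:

    import torch
    from transformers import BertForTokenClassification

    model = BertForTokenClassification.from_pretrained(
        "bert-base-multilingual-cased", num_labels=6)
    model.load_state_dict(torch.load("models/actionv1-0-5000.model"))
    model.eval()  # switch off dropout for inference
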
-- 
GitLab


From 3af9e8369499d5e6cba6ac6877c318297576f76d Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Thu, 23 Jul 2020 12:30:51 +0200
Subject: [PATCH 012/116] Organization files

---
 dataset_generation/__init__.py | 0
 modeling/models/.gitignore     | 2 ++
 2 files changed, 2 insertions(+)
 create mode 100644 dataset_generation/__init__.py
 create mode 100644 modeling/models/.gitignore

diff --git a/dataset_generation/__init__.py b/dataset_generation/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/modeling/models/.gitignore b/modeling/models/.gitignore
new file mode 100644
index 0000000..c96a04f
--- /dev/null
+++ b/modeling/models/.gitignore
@@ -0,0 +1,2 @@
+*
+!.gitignore
\ No newline at end of file
-- 
GitLab


From e5e3b515fbb1593f1e08fe4f8c262e7f50e05622 Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Thu, 23 Jul 2020 12:34:48 +0200
Subject: [PATCH 013/116] Added dvc

---
 .dvc/.gitignore           |  3 +++
 .dvc/config               |  0
 .dvc/plots/confusion.json | 30 ++++++++++++++++++++++++++++++
 .dvc/plots/default.json   | 29 +++++++++++++++++++++++++++++
 .dvc/plots/scatter.json   | 27 +++++++++++++++++++++++++++
 .dvc/plots/smooth.json    | 39 +++++++++++++++++++++++++++++++++++++++
 6 files changed, 128 insertions(+)
 create mode 100644 .dvc/.gitignore
 create mode 100644 .dvc/config
 create mode 100644 .dvc/plots/confusion.json
 create mode 100644 .dvc/plots/default.json
 create mode 100644 .dvc/plots/scatter.json
 create mode 100644 .dvc/plots/smooth.json

diff --git a/.dvc/.gitignore b/.dvc/.gitignore
new file mode 100644
index 0000000..528f30c
--- /dev/null
+++ b/.dvc/.gitignore
@@ -0,0 +1,3 @@
+/config.local
+/tmp
+/cache
diff --git a/.dvc/config b/.dvc/config
new file mode 100644
index 0000000..e69de29
diff --git a/.dvc/plots/confusion.json b/.dvc/plots/confusion.json
new file mode 100644
index 0000000..0d9a333
--- /dev/null
+++ b/.dvc/plots/confusion.json
@@ -0,0 +1,30 @@
+{
+    "$schema": "https://vega.github.io/schema/vega-lite/v4.json",
+    "data": {
+        "values": "<DVC_METRIC_DATA>"
+    },
+    "title": "<DVC_METRIC_TITLE>",
+    "mark": "rect",
+    "encoding": {
+        "x": {
+            "field": "<DVC_METRIC_X>",
+            "type": "nominal",
+            "sort": "ascending",
+            "title": "<DVC_METRIC_X_LABEL>"
+        },
+        "y": {
+            "field": "<DVC_METRIC_Y>",
+            "type": "nominal",
+            "sort": "ascending",
+            "title": "<DVC_METRIC_Y_LABEL>"
+        },
+        "color": {
+            "aggregate": "count",
+            "type": "quantitative"
+        },
+        "facet": {
+            "field": "rev",
+            "type": "nominal"
+        }
+    }
+}
diff --git a/.dvc/plots/default.json b/.dvc/plots/default.json
new file mode 100644
index 0000000..d00782a
--- /dev/null
+++ b/.dvc/plots/default.json
@@ -0,0 +1,29 @@
+{
+    "$schema": "https://vega.github.io/schema/vega-lite/v4.json",
+    "data": {
+        "values": "<DVC_METRIC_DATA>"
+    },
+    "title": "<DVC_METRIC_TITLE>",
+    "mark": {
+        "type": "line"
+    },
+    "encoding": {
+        "x": {
+            "field": "<DVC_METRIC_X>",
+            "type": "quantitative",
+            "title": "<DVC_METRIC_X_LABEL>"
+        },
+        "y": {
+            "field": "<DVC_METRIC_Y>",
+            "type": "quantitative",
+            "title": "<DVC_METRIC_Y_LABEL>",
+            "scale": {
+                "zero": false
+            }
+        },
+        "color": {
+            "field": "rev",
+            "type": "nominal"
+        }
+    }
+}
diff --git a/.dvc/plots/scatter.json b/.dvc/plots/scatter.json
new file mode 100644
index 0000000..90165d4
--- /dev/null
+++ b/.dvc/plots/scatter.json
@@ -0,0 +1,27 @@
+{
+    "$schema": "https://vega.github.io/schema/vega-lite/v4.json",
+    "data": {
+        "values": "<DVC_METRIC_DATA>"
+    },
+    "title": "<DVC_METRIC_TITLE>",
+    "mark": "point",
+    "encoding": {
+        "x": {
+            "field": "<DVC_METRIC_X>",
+            "type": "quantitative",
+            "title": "<DVC_METRIC_X_LABEL>"
+        },
+        "y": {
+            "field": "<DVC_METRIC_Y>",
+            "type": "quantitative",
+            "title": "<DVC_METRIC_Y_LABEL>",
+            "scale": {
+                "zero": false
+            }
+        },
+        "color": {
+            "field": "rev",
+            "type": "nominal"
+        }
+    }
+}
diff --git a/.dvc/plots/smooth.json b/.dvc/plots/smooth.json
new file mode 100644
index 0000000..d497ce7
--- /dev/null
+++ b/.dvc/plots/smooth.json
@@ -0,0 +1,39 @@
+{
+    "$schema": "https://vega.github.io/schema/vega-lite/v4.json",
+    "data": {
+        "values": "<DVC_METRIC_DATA>"
+    },
+    "title": "<DVC_METRIC_TITLE>",
+    "mark": {
+        "type": "line"
+    },
+    "encoding": {
+        "x": {
+            "field": "<DVC_METRIC_X>",
+            "type": "quantitative",
+            "title": "<DVC_METRIC_X_LABEL>"
+        },
+        "y": {
+            "field": "<DVC_METRIC_Y>",
+            "type": "quantitative",
+            "title": "<DVC_METRIC_Y_LABEL>",
+            "scale": {
+                "zero": false
+            }
+        },
+        "color": {
+            "field": "rev",
+            "type": "nominal"
+        }
+    },
+    "transform": [
+        {
+            "loess": "<DVC_METRIC_Y>",
+            "on": "<DVC_METRIC_X>",
+            "groupby": [
+                "rev"
+            ],
+            "bandwidth": 0.3
+        }
+    ]
+}
-- 
GitLab


From 36f0c657a84cb0cd30c397ea316c4f4f1bf1b805 Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Thu, 23 Jul 2020 12:57:50 +0200
Subject: [PATCH 014/116] stop tracking data

---
 data/download_dataset.sh | 4 ----
 1 file changed, 4 deletions(-)
 delete mode 100755 data/download_dataset.sh

diff --git a/data/download_dataset.sh b/data/download_dataset.sh
deleted file mode 100755
index 5c70a48..0000000
--- a/data/download_dataset.sh
+++ /dev/null
@@ -1,4 +0,0 @@
-#!/bin/bash
-wget http://manage.legis.nlp.ipipan.waw.pl/download/ppc-nanno.tar.gz
-tar -xvf ppc-nanno.tar.gz
-rm ppc-nanno.tar.gz
-- 
GitLab


From 5a031d45240732c606df844f10b20749009f41f3 Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Fri, 24 Jul 2020 14:24:21 +0200
Subject: [PATCH 015/116] Cleanup, dvc support

---
 .dvc/config                                   |   4 +
 .gitignore                                    |   6 +-
 Dockerfile                                    |   4 +
 dataset_generation/__init__.py => __init__.py |   0
 data.dvc                                      |   3 +
 dataset_generation/utils.py                   |  14 --
 download_dataset.sh                           |   4 +
 dvc.lock                                      |  47 ++++++
 dvc.yaml                                      |  34 ++++
 generated/.gitignore                          |   6 +-
 modeling/BertTokenMultilabel.py               |  12 ++
 notebooks/test.ipynb                          | 156 ++++++++++++++++++
 notebooks/test_training_files.ipynb           | 137 +++++++++++++++
 params.yaml                                   |  21 +++
 scripts/__init__.py                           |   0
 .../dataset_generation}/.gitignore            |   0
 scripts/dataset_generation/__init__.py        |   0
 .../dataset_generation}/stage1_extraction.py  |  23 +--
 .../stage2_tokenization.py                    |  27 +--
 .../dataset_generation}/stage3_exploding.py   |  13 +-
 .../dataset_generation}/stage4_reindexing.py  |  13 +-
 {modeling => scripts}/train.py                |  33 +++-
 src/__init__.py                               |   0
 {dataset_generation => src}/processing.py     |  58 +++++--
 .../test_processing.py                        |  49 +++++-
 src/utils.py                                  |  41 +++++
 26 files changed, 636 insertions(+), 69 deletions(-)
 create mode 100644 Dockerfile
 rename dataset_generation/__init__.py => __init__.py (100%)
 create mode 100644 data.dvc
 delete mode 100644 dataset_generation/utils.py
 create mode 100755 download_dataset.sh
 create mode 100644 dvc.lock
 create mode 100644 dvc.yaml
 create mode 100644 modeling/BertTokenMultilabel.py
 create mode 100644 notebooks/test.ipynb
 create mode 100644 notebooks/test_training_files.ipynb
 create mode 100644 params.yaml
 create mode 100644 scripts/__init__.py
 rename {dataset_generation => scripts/dataset_generation}/.gitignore (100%)
 create mode 100644 scripts/dataset_generation/__init__.py
 rename {dataset_generation => scripts/dataset_generation}/stage1_extraction.py (69%)
 rename {dataset_generation => scripts/dataset_generation}/stage2_tokenization.py (64%)
 rename {dataset_generation => scripts/dataset_generation}/stage3_exploding.py (80%)
 rename {dataset_generation => scripts/dataset_generation}/stage4_reindexing.py (57%)
 rename {modeling => scripts}/train.py (75%)
 create mode 100644 src/__init__.py
 rename {dataset_generation => src}/processing.py (88%)
 rename {dataset_generation => src}/test_processing.py (78%)
 create mode 100644 src/utils.py

diff --git a/.dvc/config b/.dvc/config
index e69de29..c226a2a 100644
--- a/.dvc/config
+++ b/.dvc/config
@@ -0,0 +1,4 @@
+['remote "newremote"']
+    url = s3://punctuation/actions
+    endpointurl = https://minio.clarin-pl.eu/minio
+    profile = clarinpl
\ No newline at end of file
diff --git a/.gitignore b/.gitignore
index 0bdd65c..11deaf7 100644
--- a/.gitignore
+++ b/.gitignore
@@ -4,5 +4,7 @@ dataset_actions
 **/dask-worker-space
 .vscode
 .idea
-test.ipynb
-.metals
\ No newline at end of file
+.metals
+/data
+__pycache__
+.pytest_cache
\ No newline at end of file
diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 0000000..d31b532
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,4 @@
+FROM ubuntu:20.04
+
+RUN DEBIAN_FRONTEND=noninteractive apt-get update && apt-get install -y python3 python3-pip
+RUN pip3 install numpy pandas "dask[complete]" torch transformers
\ No newline at end of file
diff --git a/dataset_generation/__init__.py b/__init__.py
similarity index 100%
rename from dataset_generation/__init__.py
rename to __init__.py
diff --git a/data.dvc b/data.dvc
new file mode 100644
index 0000000..eb543e5
--- /dev/null
+++ b/data.dvc
@@ -0,0 +1,3 @@
+outs:
+- md5: 1fa175e752af1638dc896838e82a9d7d.dir
+  path: data
diff --git a/dataset_generation/utils.py b/dataset_generation/utils.py
deleted file mode 100644
index dc009d7..0000000
--- a/dataset_generation/utils.py
+++ /dev/null
@@ -1,14 +0,0 @@
-import glob
-import random
-from lxml import etree
-import uuid
-import hashlib
-import seaborn as sns
-import re
-from tqdm import tqdm
-
-def remove_multiple_spaces(x: str) -> str:
-    return re.sub("\s\s+", " ", x)
-
-def remove_punctuation(x: str) -> str:
-    return ''.join(filter(lambda x: x.isalnum() or x.isspace(), x))
diff --git a/download_dataset.sh b/download_dataset.sh
new file mode 100755
index 0000000..5c70a48
--- /dev/null
+++ b/download_dataset.sh
@@ -0,0 +1,4 @@
+#!/bin/bash
+wget http://manage.legis.nlp.ipipan.waw.pl/download/ppc-nanno.tar.gz
+tar -xvf ppc-nanno.tar.gz
+rm ppc-nanno.tar.gz
diff --git a/dvc.lock b/dvc.lock
new file mode 100644
index 0000000..23b393d
--- /dev/null
+++ b/dvc.lock
@@ -0,0 +1,47 @@
+extraction:
+  cmd: python3 -m scripts.dataset_generation.stage1_extraction
+  deps:
+  - path: data
+    md5: 1fa175e752af1638dc896838e82a9d7d.dir
+  - path: scripts/dataset_generation/stage1_extraction.py
+    md5: b5256e47e54f55fd406f23889a9cbca9
+  params:
+    params.yaml:
+      extraction.num_partitions: 2000
+  outs:
+  - path: generated/stage1_extraction
+    md5: c33e5a857a8de3bce69bfc8636f64854.dir
+tokenization:
+  cmd: python3 -m scripts.dataset_generation.stage2_tokenization
+  deps:
+  - path: generated/stage1_extraction
+    md5: c33e5a857a8de3bce69bfc8636f64854.dir
+  - path: scripts/dataset_generation/stage2_tokenization.py
+    md5: 1afe768315b818e3a051c976cf19d2f3
+  params:
+    params.yaml:
+      tokenization.max_tokens: 500
+      tokenization.min_tokens: 10
+  outs:
+  - path: generated/stage2_tokenization
+    md5: 1b98b64ecf98ec74c446721256182539.dir
+exploding:
+  cmd: python3 -m scripts.dataset_generation.stage3_exploding
+  deps:
+  - path: generated/stage2_tokenization
+    md5: 1b98b64ecf98ec74c446721256182539.dir
+  - path: scripts/dataset_generation/stage3_exploding.py
+    md5: 490843d650534f09003480d26fde2390
+  outs:
+  - path: generated/stage3_exploding
+    md5: 688ce8926016ed49154b088850be6cff.dir
+reindexing:
+  cmd: python3 -m scripts.dataset_generation.stage4_reindexing
+  deps:
+  - path: generated/stage3_exploding
+    md5: 688ce8926016ed49154b088850be6cff.dir
+  - path: scripts/dataset_generation/stage4_reindexing.py
+    md5: 342a0fd49f45c3d9ff9b3a701f6ebb7d
+  outs:
+  - path: generated/stage4_reindexing
+    md5: 9e797430fe072a60e778e191606a9952.dir
diff --git a/dvc.yaml b/dvc.yaml
new file mode 100644
index 0000000..ec9d4e9
--- /dev/null
+++ b/dvc.yaml
@@ -0,0 +1,34 @@
+stages:
+  extraction:
+    cmd: python3 -m scripts.dataset_generation.stage1_extraction
+    deps:
+    - data
+    - scripts/dataset_generation/stage1_extraction.py
+    params:
+    - extraction.num_partitions
+    outs:
+    - generated/stage1_extraction
+  tokenization:
+    cmd: python3 -m scripts.dataset_generation.stage2_tokenization
+    deps:
+    - generated/stage1_extraction
+    - scripts/dataset_generation/stage2_tokenization.py
+    params:
+    - tokenization.max_tokens
+    - tokenization.min_tokens
+    outs:
+    - generated/stage2_tokenization
+  exploding:
+    cmd: python3 -m scripts.dataset_generation.stage3_exploding
+    deps:
+    - generated/stage2_tokenization
+    - scripts/dataset_generation/stage3_exploding.py
+    outs:
+    - generated/stage3_exploding
+  reindexing:
+    cmd: python3 -m scripts.dataset_generation.stage4_reindexing
+    deps:
+    - generated/stage3_exploding
+    - scripts/dataset_generation/stage4_reindexing.py
+    outs:
+    - generated/stage4_reindexing
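
With dvc.yaml describing the stage graph and dvc.lock pinning the md5 of every dependency and output, the whole dataset pipeline can now be reproduced with `dvc repro`, which re-runs only the stages whose deps or params actually changed.
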
diff --git a/generated/.gitignore b/generated/.gitignore
index c96a04f..959c3a4 100644
--- a/generated/.gitignore
+++ b/generated/.gitignore
@@ -1,2 +1,4 @@
-*
-!.gitignore
\ No newline at end of file
+/stage1_extraction
+/stage2_tokenization
+/stage3_exploding
+/stage4_reindexing
diff --git a/modeling/BertTokenMultilabel.py b/modeling/BertTokenMultilabel.py
new file mode 100644
index 0000000..f57e47f
--- /dev/null
+++ b/modeling/BertTokenMultilabel.py
@@ -0,0 +1,12 @@
+import torch.nn as nn
+from transformers import BertForTokenClassification
+
+class BertTokenMultilabel(nn.Module):
+    def __init__(self, base_model: str, num_labels: int):
+        super(BertTokenMultilabel, self).__init__()
+        self.base_model = BertForTokenClassification.from_pretrained(base_model, num_labels=num_labels)
+
+        # The assignment above already registers base_model as a submodule.
+
+    def forward(self, *args, **kwargs):
+        return self.base_model(*args, **kwargs)
diff --git a/notebooks/test.ipynb b/notebooks/test.ipynb
new file mode 100644
index 0000000..9e81e62
--- /dev/null
+++ b/notebooks/test.ipynb
@@ -0,0 +1,156 @@
+{
+ "metadata": {
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.8.2-final"
+  },
+  "orig_nbformat": 2,
+  "kernelspec": {
+   "name": "python38264bita7d7da14168440cb9836372958035d4a",
+   "display_name": "Python 3.8.2 64-bit"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2,
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 113,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import sys\n",
+    "sys.path.append(\"../\")\n",
+    "\n",
+    "from transformers import BertTokenizerFast, BertForTokenClassification\n",
+    "import torch\n",
+    "from torch.nn import BCEWithLogitsLoss\n",
+    "import pandas as pd\n",
+    "import numpy as np\n",
+    "import dask.dataframe as dd\n",
+    "\n",
+    "from src.processing import create_model_input_output"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 114,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "INPUT_PATH=\"../generated/stage4_reindexing\"\n",
+    "MODEL_BASE = \"bert-base-multilingual-cased\"\n",
+    "MODEL_NAME = \"actionv1\"\n",
+    "LR = 1e-4\n",
+    "\n",
+    "BATCH_SIZE=8\n",
+    "NUM_EPOCH=5\n",
+    "SAVE_STEP=5_000"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 115,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "tokenizer = BertTokenizerFast.from_pretrained(MODEL_BASE)\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 116,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "expected = \"To nie amunicja. To krew przygotowana do wysyłki na front. W Lublinie uruchomiono pierwszy ośrodek przetaczania krwi. Dobrowolna ofiara krwi tych dziewcząt ratuje życie tysiącom rannych żołnierzy.\"\n",
+    "text_clean = create_model_input_output(expected)[0]\n",
+    "\n",
+    "inputs = tokenizer(text_clean, return_tensors=\"pt\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 117,
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [
+    {
+     "output_type": "stream",
+     "name": "stderr",
+     "text": "Some weights of the model checkpoint at bert-base-multilingual-cased were not used when initializing BertForTokenClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']\n- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).\n- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\nSome weights of BertForTokenClassification were not initialized from the model checkpoint at bert-base-multilingual-cased and are newly initialized: ['classifier.weight', 'classifier.bias']\nYou should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
+    },
+    {
+     "output_type": "execute_result",
+     "data": {
+      "text/plain": "<All keys matched successfully>"
+     },
+     "metadata": {},
+     "execution_count": 117
+    }
+   ],
+   "source": [
+    "model = BertForTokenClassification.from_pretrained(MODEL_BASE, num_labels=6)\n",
+    "model.load_state_dict(torch.load(\"models/actionv1-0-50000.model\"))\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 118,
+   "metadata": {},
+   "outputs": [
+    {
+     "output_type": "execute_result",
+     "data": {
+      "text/plain": "{'input_ids': tensor([[   101,  10114,  11058,  10392,  23124,  14083,  10114,  50302,  26127,\n          14052,  47163,  26277,  10149,    191,  12682,  72855,  10506,  10132,\n          14589,    191,  13455,  31998,  10399,  20591,  32013,  12507,  19458,\n            183,  78190,  20157,  17931,  18306,  46136,  50302,  15926,  37908,\n          16828,  24626,  10108,  85286,  50302,  15926,  23453,    172,  14548,\n          10874,  55239,  10123,  45840,  10381,  49934,  34720,  28702,  22530,\n          17044,  12837,  40938,  19626,  81422,  10400,  23311,  50302,  15926,\n          20157,  17931,  18306,  50851,  50302,  15926,  10973,  95230,  11297,\n          13863,  22729,  10171,  10238,    177,  82413,  11058,  17272,  80852,\n          10390,  12577,  12197,    194,  82322,  10361,  50302,  13362,  10284,\n            348,  12097,  83978,  74780,  10149,  58133,  11530,  12741,  15183,\n          14052,  54609,  11624,  84269,  10621,  73121,  13050,  10132,  82992,\n            191,  35327,  42041,  10112,    191,  12211,  62187,  10419,  68037,\n          44227,  35090,  13717,  27828,  10449,  87042,  11133,  77029,  22578,\n          67405,  11717,  78098,  12294,  35779,  10108,  85286,  50302,  15926,\n          14052, 100963,  10280,  10424,  10149,    194,  12524,  10598,  73243,\n          14916,  73837,  22555,  89484,  10108,  21501,  21907,  34582,  24203,\n          64199,  21838,  10418,  92822, 110206,  10113,    191,  27652,  11679,\n          25175,  11877,  19495,    191,  19888,  10514,  33619,  10730,  17249,\n          10229,  10220,    186, 102075,  15050,  23090,  42155,  46845,  10149,\n          27648,  10514,  56423,  17249,    183,  10418,  17771,  60519,    102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}"
+     },
+     "metadata": {},
+     "execution_count": 118
+    }
+   ],
+   "source": [
+    "inputs"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 125,
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [
+    {
+     "output_type": "stream",
+     "name": "stdout",
+     "text": "To nie amunicja. To krew przygotowana do wysyłki na front. W Lublinie uruchomiono pierwszy ośrodek przetaczania krwi. Dobrowolna ofiara krwi tych dziewcząt ratuje życie tysiącom rannych żołnierzy. Ustalanie grupy krwi. Przetaczanie krwi jest absolutnie bezbolesne i niemal nieszkodliwe dla zdrowia. Te krople ściekające do kolby tak samo przyspieszają zwycięstwo, jak naboje wyrabiane w fabryce broni. Niemiec poprzysiągł narodowi polskiemu zagładę. Ofiara krwi przyczynia się do zwycięstwa, zmniejsza liczbę ofiar katastrofy wojennej, przeciwdziała wyludnieniu Polski, w której po wojnie każda para rąk potrzebna będzie do pracy. Pomóż ojczyźnie!\nTo nie amunicja, to krew przygotowana do wysyłki na Front w Lublinie, uruchomiono pierwszy Ośrodek Przetaczania Krwi, dobrowolna, ofiara, krwi tych dziewcząt, ratuje życie Tysiącom, rannych, żołnierzy, ustalanie grupy, krwi przetaczanie krwi, jest absolutnie bezbolesne i niemal Nieszkodliwe dla Zdrowia, Te krople, ściekające do kolby, tak, samo, przyspieszają zwycięstwo, jak naboje wyrabiane w fabryce Broni, Niemiec, poprzysiągł narodowi polskiemu, Zagładę, Ofiara Krwi, przyczynia się do zwycięstwa zmniejsza liczbę Ofiar Katastrofy, wojennej, przeciwdziała wyludnieniu Polski, w której po wojnie, każda para, Rąk, potrzebna będzie do, pracy, pomóż Ojczyźnie,\n"
+    }
+   ],
+   "source": [
+    "from src.processing import token_labels_to_word_labels, recover_text\n",
+    "\n",
+    "y_pred = model(**inputs)[0].sigmoid()\n",
+    "labels_pred = token_labels_to_word_labels(text_clean, y_pred.detach().numpy()[0, 1:-1, :], tokenizer)\n",
+    "\n",
+    "actions = labels_pred > 0.2\n",
+    "print(expected)\n",
+    "print(recover_text(text_clean, actions))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ]
+}
\ No newline at end of file
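
The decoding rule the notebook uses deserves a note: the model's logits are sigmoided into independent per-action probabilities and thresholded (0.2 here), and the resulting boolean action matrix is handed to recover_text. The rule in isolation, as a sketch:

    import numpy as np

    def decode_logits(logits: np.ndarray, threshold: float = 0.2) -> np.ndarray:
        # logits: (num_words, num_actions) raw outputs for one document.
        probs = 1.0 / (1.0 + np.exp(-logits))  # sigmoid, per action independently
        return probs > threshold               # boolean matrix for recover_text
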
diff --git a/notebooks/test_training_files.ipynb b/notebooks/test_training_files.ipynb
new file mode 100644
index 0000000..c8ee4da
--- /dev/null
+++ b/notebooks/test_training_files.ipynb
@@ -0,0 +1,137 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 16,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from src.processing import text_from_xml, create_model_input_output, token_labels_to_word_labels, recover_text\n",
+    "import glob\n",
+    "import numpy as np\n",
+    "from dask.diagnostics import ProgressBar\n",
+    "import dask.dataframe as dd\n",
+    "import dask\n",
+    "import pandas as pd\n",
+    "from dask.distributed import Client\n",
+    "import gc\n",
+    "from memory_profiler import profile\n",
+    "import pyspark\n",
+    "from pyspark.sql import SparkSession, Row, udf\n",
+    "from pyspark.sql.types import ArrayType, IntegerType\n",
+    "from pyspark.sql.types import StructType, StructField\n",
+    "from pyspark.mllib.linalg import Vectors\n",
+    "from transformers import BertTokenizerFast, BertForTokenClassification\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 11,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "INPUT_FOLDER=\"generated/stage4_reindexing\"\n",
+    "MODEL_BASE = \"bert-base-multilingual-cased\"\n",
+    "\n",
+    "tokenizer = BertTokenizerFast.from_pretrained(MODEL_BASE)\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "df = dd.read_parquet(INPUT_FOLDER, engine='pyarrow')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {},
+   "outputs": [
+    {
+     "output_type": "execute_result",
+     "data": {
+      "text/plain": "                                                  inputs  \\\nones                                                       \n92132  [101, 12644, 82233, 10451, 13863, 48616, 10797...   \n\n                                                 outputs  \\\nones                                                       \n92132  [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, ...   \n\n                                              attentions inputs_shape  \\\nones                                                                    \n92132  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...      [50, 1]   \n\n      outputs_shape attentions_shape  \nones                                  \n92132       [50, 6]             [50]  ",
+      "text/html": "<div>\n<style scoped>\n    .dataframe tbody tr th:only-of-type {\n        vertical-align: middle;\n    }\n\n    .dataframe tbody tr th {\n        vertical-align: top;\n    }\n\n    .dataframe thead th {\n        text-align: right;\n    }\n</style>\n<table border=\"1\" class=\"dataframe\">\n  <thead>\n    <tr style=\"text-align: right;\">\n      <th></th>\n      <th>inputs</th>\n      <th>outputs</th>\n      <th>attentions</th>\n      <th>inputs_shape</th>\n      <th>outputs_shape</th>\n      <th>attentions_shape</th>\n    </tr>\n    <tr>\n      <th>ones</th>\n      <th></th>\n      <th></th>\n      <th></th>\n      <th></th>\n      <th></th>\n      <th></th>\n    </tr>\n  </thead>\n  <tbody>\n    <tr>\n      <th>92132</th>\n      <td>[101, 12644, 82233, 10451, 13863, 48616, 10797...</td>\n      <td>[0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, ...</td>\n      <td>[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...</td>\n      <td>[50, 1]</td>\n      <td>[50, 6]</td>\n      <td>[50]</td>\n    </tr>\n  </tbody>\n</table>\n</div>"
+     },
+     "metadata": {},
+     "execution_count": 7
+    }
+   ],
+   "source": [
+    "sample = df.loc[92132].compute()\n",
+    "sample"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 20,
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [
+    {
+     "output_type": "stream",
+     "name": "stdout",
+     "text": "(50, 6)\n"
+    },
+    {
+     "output_type": "error",
+     "ename": "AssertionError",
+     "evalue": "",
+     "traceback": [
+      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+      "\u001b[0;31mAssertionError\u001b[0m                            Traceback (most recent call last)",
+      "\u001b[0;32m<ipython-input-20-b333c374ea16>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      6\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msample_outputs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      7\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 8\u001b[0;31m \u001b[0mlabels_pred\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtoken_labels_to_word_labels\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtext_clean\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msample_outputs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtokenizer\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      9\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     10\u001b[0m \u001b[0mactions\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlabels_pred\u001b[0m \u001b[0;34m>\u001b[0m \u001b[0;36m0.1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+      "\u001b[0;32m~/projekty/clarin/interpunkcja/src/processing.py\u001b[0m in \u001b[0;36mtoken_labels_to_word_labels\u001b[0;34m(text, token_labels, tokenizer)\u001b[0m\n\u001b[1;32m    165\u001b[0m     \u001b[0mmapping\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtoken_word_mapping\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtext\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtokenizer\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    166\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 167\u001b[0;31m     \u001b[0;32massert\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmapping\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtoken_labels\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    168\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    169\u001b[0m     \u001b[0mlabels\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdefaultdict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlist\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+      "\u001b[0;31mAssertionError\u001b[0m: "
+     ]
+    }
+   ],
+   "source": [
+    "sample_inputs = sample['inputs'].values[0].reshape((50))\n",
+    "sample_outputs = sample['outputs'].values[0].reshape((50, 6))\n",
+    "sample_attentions = sample['outputs'].values[0].reshape((50))\n",
+    "\n",
+    "length = np.sum(sample_attentions)\n",
+    "\n",
+    "text_clean = tokenizer.decode(sample_inputs)\n",
+    "\n",
+    "labels_pred = token_labels_to_word_labels(text_clean, sample_outputs[1:-1, :], tokenizer)\n",
+    "\n",
+    "actions = labels_pred > 0.1\n",
+    "recover_text(text_clean, actions)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.8.2-final"
+  },
+  "orig_nbformat": 2,
+  "kernelspec": {
+   "name": "python38264bita7d7da14168440cb9836372958035d4a",
+   "display_name": "Python 3.8.2 64-bit"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
\ No newline at end of file
diff --git a/params.yaml b/params.yaml
new file mode 100644
index 0000000..d8aefcb
--- /dev/null
+++ b/params.yaml
@@ -0,0 +1,21 @@
+global:
+    dashboard_port: 8787
+
+extraction:
+    num_partitions: 2_000
+    num_workers: 24
+    worker_memory_limit: "2GB"
+
+tokenization:
+    min_tokens: 10
+    max_tokens: 500
+    num_workers: 24
+    worker_memory_limit: "2GB"
+
+exploding:
+    num_workers: 24
+    worker_memory_limit: "2GB"
+
+reindexing:
+    num_workers: 1
+    worker_memory_limit: "60GB"
\ No newline at end of file
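
A note on how these values come back out: the pipeline reads params.yaml through PyYAML (see get_config in src/utils.py later in this patch), and PyYAML implements YAML 1.1, where `2_000` parses as the integer 2000 rather than a string. A minimal sketch, assuming the file sits in the current working directory (safe_load behaves the same as the FullLoader call used in src/utils.py for this file):

    import yaml

    with open("params.yaml") as f:
        config = yaml.safe_load(f)  # YAML 1.1: "2_000" loads as int 2000

    assert config["extraction"]["num_partitions"] == 2000
    assert config["extraction"]["worker_memory_limit"] == "2GB"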
diff --git a/scripts/__init__.py b/scripts/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/dataset_generation/.gitignore b/scripts/dataset_generation/.gitignore
similarity index 100%
rename from dataset_generation/.gitignore
rename to scripts/dataset_generation/.gitignore
diff --git a/scripts/dataset_generation/__init__.py b/scripts/dataset_generation/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/dataset_generation/stage1_extraction.py b/scripts/dataset_generation/stage1_extraction.py
similarity index 69%
rename from dataset_generation/stage1_extraction.py
rename to scripts/dataset_generation/stage1_extraction.py
index 924e46f..79db141 100644
--- a/dataset_generation/stage1_extraction.py
+++ b/scripts/dataset_generation/stage1_extraction.py
@@ -1,7 +1,7 @@
 # /usr/bin/python3
 import glob
 import numpy as np
-from processing import text_from_xml, create_model_input_output
+from src.processing import text_from_xml, create_model_input_output
 from dask.diagnostics import ProgressBar
 import dask.dataframe as dd
 import dask
@@ -12,10 +12,10 @@ from memory_profiler import profile
 from pympler import muppy, summary
 import stackimpact
 import lorem
+from src.utils import get_config, PROJECT_ROOT
 
-OUTPUT_FOLDER = "../generated/stage1_extraction"
-NUM_PARTITIONS = 2_000
-NUM_WORKERS=24
+GENERATED_FOLDER = "generated"
+OUTPUT_FOLDER = f"{PROJECT_ROOT}/{GENERATED_FOLDER}/stage1_extraction"
 
 def process_file(x):
     full_text = text_from_xml(x.file)
@@ -30,24 +30,27 @@ def process_file(x):
         return {'input': None, 'output': None, 'output_shape': None}
 
 if __name__ == "__main__":
-    file_schema = "../data/**/text_structure.xml"
+
+    config = get_config()
+    num_partitions = config['extraction']['num_partitions']
+    num_workers = config['extraction']['num_workers']
+    memory_limit = config['extraction']['worker_memory_limit']
+
+    file_schema = "data/**/text_structure.xml"
     files_paths = glob.glob(file_schema, recursive=True)
 
     # Make sure python memory fragmentation won't go insane
     np.random.shuffle(files_paths)
 
-    client = Client(n_workers=NUM_WORKERS)
+    client = Client(n_workers=num_workers, memory_limit=memory_limit)
     print(f"Dashboard: {client.dashboard_link}")
 
     # Processing pipeline
-    df = dd.from_pandas(pd.DataFrame({'file': files_paths}), npartitions=NUM_PARTITIONS)
+    df = dd.from_pandas(pd.DataFrame({'file': files_paths}), npartitions=num_partitions)
 
     df = df.apply(process_file, result_type='expand', axis=1, meta={
                    'input': str, 'output': object, 'output_shape': object})
     df = df.dropna()
 
-    # Even out sizes
-    # df = df.repartition(partition_size="100MB")
-
     # Export
     df.to_parquet(OUTPUT_FOLDER, engine="pyarrow")
\ No newline at end of file
diff --git a/dataset_generation/stage2_tokenization.py b/scripts/dataset_generation/stage2_tokenization.py
similarity index 64%
rename from dataset_generation/stage2_tokenization.py
rename to scripts/dataset_generation/stage2_tokenization.py
index 88ea96d..27482e1 100644
--- a/dataset_generation/stage2_tokenization.py
+++ b/scripts/dataset_generation/stage2_tokenization.py
@@ -9,22 +9,18 @@ import seaborn as sns
 import re
 import numpy as np
 from tqdm import tqdm
-from processing import tokenize_labeled_text, batchify_data
-from utils import remove_multiple_spaces, remove_punctuation
+from src.processing import tokenize_labeled_text, batchify_data
+from src.utils import remove_multiple_spaces, remove_punctuation, PROJECT_ROOT, get_config
 import dask
 from dask.diagnostics import ProgressBar
 import dask.dataframe as dd
 from transformers import BertTokenizerFast
 from dask.distributed import Client
 
-INPUT_FOLDER = "../generated/stage1_extraction"
-OUTPUT_FOLDER = "../generated/stage2_tokenization"
+INPUT_FOLDER = f"{PROJECT_ROOT}/generated/stage1_extraction"
+OUTPUT_FOLDER = f"{PROJECT_ROOT}/generated/stage2_tokenization"
 
-MIN_TOKENS = 10
-MAX_TOKENS = 50
-
-
-def apply_tokenization(df, tokenizer: BertTokenizerFast):
+def apply_tokenization(df, min_tokens: int, max_tokens: int, tokenizer: BertTokenizerFast):
     text_clean = df.input
     labels = df.output
     shape = df.output_shape
@@ -33,7 +29,7 @@ def apply_tokenization(df, tokenizer: BertTokenizerFast):
         text_clean, labels.reshape(shape), tokenizer)
 
     inputs, outputs, attentions = batchify_data(
-        tokens, token_labels, MAX_TOKENS, tokenizer, MIN_TOKENS)
+        tokens, token_labels, max_tokens, tokenizer, min_tokens)
 
     inputs_shape = np.array(inputs.shape)
     outputs_shape = np.array(outputs.shape)
@@ -59,7 +55,14 @@ RESULT_META = {
 }
 
 if __name__ == "__main__":
-    client = Client(n_workers=24)
+
+    config = get_config()
+    max_tokens = config['tokenization']['max_tokens']
+    min_tokens = config['tokenization']['min_tokens']
+    num_workers = config['tokenization']['num_workers']
+    memory_limit = config['tokenization']['worker_memory_limit']
+
+    client = Client(n_workers=num_workers, memory_limit=memory_limit)
     print(client.dashboard_link)
 
     tokenizer = BertTokenizerFast.from_pretrained(
@@ -68,7 +71,7 @@ if __name__ == "__main__":
     tokenizer = dask.delayed(tokenizer)
 
     df = dd.read_parquet(INPUT_FOLDER, engine="pyarrow")
-    df = df.apply(apply_tokenization, args=(tokenizer,),
+    df = df.apply(apply_tokenization, args=(min_tokens, max_tokens, tokenizer),
                   result_type='expand', axis=1, meta=RESULT_META)
 
     df.to_parquet(OUTPUT_FOLDER, engine="pyarrow")
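
One detail worth calling out above: the tokenizer is wrapped in dask.delayed before being passed to df.apply. Dask accepts Delayed objects among task arguments and materializes them once inside the graph, so the large tokenizer object is not pickled separately into every partition's task. A small sketch of the same pattern with a stand-in object (not the project's code):

    import dask
    import dask.dataframe as dd
    import pandas as pd

    expensive = dask.delayed({"a": 1, "b": 2})  # stand-in for a heavy object

    df = dd.from_pandas(pd.DataFrame({"x": range(10)}), npartitions=2)
    out = df.map_partitions(lambda part, obj: part.assign(n=len(obj)),
                            expensive, meta={"x": int, "n": int})
    print(out.compute())  # every row gets n == 2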
diff --git a/dataset_generation/stage3_exploding.py b/scripts/dataset_generation/stage3_exploding.py
similarity index 80%
rename from dataset_generation/stage3_exploding.py
rename to scripts/dataset_generation/stage3_exploding.py
index bfd46bb..bd6d0af 100644
--- a/dataset_generation/stage3_exploding.py
+++ b/scripts/dataset_generation/stage3_exploding.py
@@ -1,5 +1,5 @@
 # /usr/bin/python3
-from processing import batchify_data
+from src.processing import batchify_data
 from dask.diagnostics import ProgressBar
 import dask.dataframe as dd
 from transformers import BertTokenizerFast
@@ -7,9 +7,10 @@ import numpy as np
 import dask
 from dask.distributed import Client
 import pandas as pd
+from src.utils import PROJECT_ROOT, get_config
 
-INPUT_FOLDER = "../generated/stage2_tokenization"
-OUTPUT_FOLDER = "../generated/stage3_exploding"
+INPUT_FOLDER = f"{PROJECT_ROOT}/generated/stage2_tokenization"
+OUTPUT_FOLDER = f"{PROJECT_ROOT}/generated/stage3_exploding"
 
 def expand_dims(entry):
     inputs = entry.inputs.reshape(entry.input_shape)
@@ -51,7 +52,11 @@ RESULT_META = {
 }
 
 if __name__ == "__main__":
-    client = Client(n_workers=24, memory_limit='2GB')
+    config = get_config()
+    num_workers = config['tokenization']['num_workers']
+    memory_limit = config['tokenization']['worker_memory_limit']
+
+    client = Client(n_workers=num_workers, memory_limit=memory_limit)
     print(client.dashboard_link)
 
     df = dd.read_parquet(INPUT_FOLDER, engine='pyarrow')
diff --git a/dataset_generation/stage4_reindexing.py b/scripts/dataset_generation/stage4_reindexing.py
similarity index 57%
rename from dataset_generation/stage4_reindexing.py
rename to scripts/dataset_generation/stage4_reindexing.py
index 920d743..57182f3 100644
--- a/dataset_generation/stage4_reindexing.py
+++ b/scripts/dataset_generation/stage4_reindexing.py
@@ -1,5 +1,5 @@
 # /usr/bin/python3
-from processing import batchify_data
+from src.processing import batchify_data
 from dask.diagnostics import ProgressBar
 import dask.dataframe as dd
 from transformers import BertTokenizerFast
@@ -7,12 +7,17 @@ import numpy as np
 import dask
 from dask.distributed import Client
 import pandas as pd
+from src.utils import PROJECT_ROOT, get_config
 
-INPUT_FOLDER = "../generated/stage3_exploding"
-OUTPUT_FOLDER = "../generated/stage4_reindexing"
+INPUT_FOLDER = f"{PROJECT_ROOT}/generated/stage3_exploding"
+OUTPUT_FOLDER = f"{PROJECT_ROOT}/generated/stage4_reindexing"
 
 if __name__ == "__main__":
-    client = Client(n_workers=1, memory_limit='60GB')
+    config = get_config()
+    num_workers = config['tokenization']['num_workers']
+    memory_limit = config['tokenization']['worker_memory_limit']
+
+    client = Client(n_workers=num_workers, memory_limit=memory_limit)
     print(client.dashboard_link)
 
     df = dd.read_parquet(INPUT_FOLDER, engine='pyarrow')
diff --git a/modeling/train.py b/scripts/train.py
similarity index 75%
rename from modeling/train.py
rename to scripts/train.py
index 764913a..5ced9ee 100755
--- a/modeling/train.py
+++ b/scripts/train.py
@@ -4,7 +4,6 @@ from transformers import BertTokenizer, BertForTokenClassification
 import torch
 from torch.nn import BCEWithLogitsLoss
 import pandas as pd
-from joblib import Memory
 import numpy as np
 import dask.dataframe as dd
 
@@ -16,6 +15,9 @@ LR = 1e-4
 BATCH_SIZE=8
 NUM_EPOCH=5
 SAVE_STEP=5_000
+AVERAGE_SPAN = 1_000
+
+LOAD = "0-50000"
 
 if __name__ == "__main__":
     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
@@ -27,12 +29,28 @@ if __name__ == "__main__":
 
     # TODO: Change num labels
     model = BertForTokenClassification.from_pretrained(MODEL_BASE, num_labels=6).to(device)
-    
     criterion = BCEWithLogitsLoss().to(device)
     optimizer = torch.optim.Adam(model.parameters(), lr=LR)
 
-    for epoch in range(NUM_EPOCH):
-        i = 0
+
+    epoch_start = 0
+    sample_start = 0
+    if LOAD is not None:
+        model.load_state_dict(torch.load(f"models/{MODEL_NAME}-{LOAD}.model"))
+        #optimizer.load_state_dict(torch.load(f"models/{MODEL_NAME}-{LOAD}.optimizer"))
+
+        epoch_start, sample_start = LOAD.split("-")
+        epoch_start = int(epoch_start)
+        sample_start = int(sample_start)
+
+        print(f"Loaded {MODEL_NAME}-{LOAD}")
+
+    model.train()
+
+    losses = []
+
+    for epoch in range(epoch_start, NUM_EPOCH):
+        i = sample_start
         while True:
             # TODO: Change to 0-indexed...
             data_batch_indexes = list(range(i*BATCH_SIZE+1, i*BATCH_SIZE + BATCH_SIZE +1))
@@ -56,7 +74,12 @@ if __name__ == "__main__":
             y_pred = model(input_ids=inputs, attention_mask=attentions_mask)[0]
 
             loss = criterion(y_pred, outputs)
-            print(f'epoch: {epoch} | step: {i} | loss: {loss.item()})
+
+            losses.append(loss.item())
+            if len(losses) > AVERAGE_SPAN:
+                losses = losses[-AVERAGE_SPAN:]
+
+            print(f'epoch: {epoch} | step: {i} | loss: {np.mean(losses)}')
 
             optimizer.zero_grad()
 
diff --git a/src/__init__.py b/src/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/dataset_generation/processing.py b/src/processing.py
similarity index 88%
rename from dataset_generation/processing.py
rename to src/processing.py
index 0b78970..e7b302a 100644
--- a/dataset_generation/processing.py
+++ b/src/processing.py
@@ -1,17 +1,10 @@
 import glob
-import random
 from xml.etree import ElementTree as ET
-import uuid
-import hashlib
-import seaborn as sns
-import re
-from tqdm import tqdm
 from typing import Optional, Mapping
-from utils import remove_punctuation
+from src.utils import remove_punctuation
 import numpy as np
-from more_itertools import windowed
-from transformers import PreTrainedTokenizerFast, BertTokenizerFast
-from memory_profiler import profile
+from transformers import PreTrainedTokenizerFast
+from collections import defaultdict
 
 ACTIONS_KEYS = ['dot', 'upper_case', 'colon', 'semicolon', 'elipsis', 'dash']
 
@@ -141,6 +134,47 @@ def create_model_input_output(text: str) -> (str, np.ndarray):
 
     return " ".join(words_output), np.array(actions_output)
 
+def token_word_mapping(text: str, tokenizer: PreTrainedTokenizerFast) -> np.ndarray:
+    """Returns mapping where each token is labeled with index of word it's part of
+
+    Args:
+        text (str): Input text
+        tokenizer (PreTrainedTokenizerFast): Tokenizer used to tokenize text
+
+    Returns:
+        np.ndarray: Array of length L (number of tokens) where each entry is the index of the word the token belongs to (CLS and SEP tokens are not counted).
+    """
+    text_tokenized = tokenizer(text, return_offsets_mapping=True)
+    # Drop the CLS/SEP offsets; they do not correspond to any word
+    offset_mappings = text_tokenized['offset_mapping'][1:-1]
+
+    # Create a map where each character is assigned the index of its word
+    words_mapping = []
+    current_word = 0
+    for character in text:
+        words_mapping.append(current_word)
+        if character == " ":
+            current_word += 1
+
+    token_mapping = [words_mapping[x[0]] for x in offset_mappings]
+
+    return np.array(token_mapping)
+
+def token_labels_to_word_labels(text: str, token_labels: np.ndarray, tokenizer: PreTrainedTokenizerFast) -> np.ndarray:
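+    """Averages token-level label scores into a single score vector per word."""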
+    mapping = token_word_mapping(text, tokenizer)
+
+    assert len(mapping) == len(token_labels)
+
+    labels = defaultdict(list)
+
+    for i in range(len(mapping)):
+        labels[mapping[i]].append(token_labels[i])
+
+    return np.array([
+        np.mean(labels[x], axis=0) for x in sorted(labels)
+    ])
+
 def tokenize_labeled_text(text: str, labels: np.ndarray, tokenizer: PreTrainedTokenizerFast) -> (np.ndarray, np.ndarray):
     """Transforms text into numerical tokens. Also expand word-level labels into token-level labels
 
@@ -190,7 +224,7 @@ def recover_word(word: str, action: Mapping[str, bool]) -> str:
     if action['dot']:
         word_result += "."
     if action['upper_case']:
-        word_result[0] = word_result[0].upper()
+        word_result = word_result.capitalize()
     if action['colon']:
         word_result += ","
     if action['semicolon']:
@@ -200,7 +234,7 @@ def recover_word(word: str, action: Mapping[str, bool]) -> str:
     if action['dash']:
         word_result += " -"
 
-    return word
+    return word_result
 
 def is_sentence_end(actions_encoded: np.ndarray) -> bool:
     """Returns if given action would end a sentence
diff --git a/dataset_generation/test_processing.py b/src/test_processing.py
similarity index 78%
rename from dataset_generation/test_processing.py
rename to src/test_processing.py
index 13f5b72..64ef6d7 100644
--- a/dataset_generation/test_processing.py
+++ b/src/test_processing.py
@@ -1,6 +1,6 @@
-from processing import *
+from src.processing import *
 from transformers import PreTrainedTokenizerFast, BertTokenizerFast
-
+import pytest
 
 def test_detect_actions():
     actions = detect_actions("Janek...", None)
@@ -59,6 +59,37 @@ def test_decode_actions():
         'dash': True
     }
 
+def test_token_word_mapping():
+    text = "janek poszedł do ogrodu"
+    tokenizer = BertTokenizerFast.from_pretrained(
+        'bert-base-multilingual-cased')
+
+    text_tokenized = tokenizer(text)
+
+    mapping = token_word_mapping(text, tokenizer)
+
+    assert len(mapping) == (len(text_tokenized['input_ids']) - 2)
+    assert min(mapping) == 0
+    assert max(mapping) == 3
+
+def test_token_labels_to_word_labels():
+    text = "janek poszedł do ogrodu"
+    labels = np.array([
+        [0, 0, 0],
+        [1, 0, 0],
+        [0, 1, 0],
+        [0, 0, 1]
+    ])
+    tokenizer = BertTokenizerFast.from_pretrained(
+        'bert-base-multilingual-cased')
+
+    tokens, token_labels = tokenize_labeled_text(text, labels, tokenizer)
+
+    mapping = token_word_mapping(text, tokenizer)
+    word_labels = token_labels_to_word_labels(text, token_labels, tokenizer)
+
+    assert word_labels == pytest.approx(labels)
+
 
 def test_tokenize_labeled_text():
     text = "Janek poszedł do ogrodu. Ogród był zwierzęcy. Spotkał tam Zosię..."
@@ -78,6 +109,17 @@ def test_tokenize_labeled_text():
     assert tokens[0, 0] != tokenizer.cls_token_id
     assert tokens[-1, 0] != tokenizer.sep_token_id
 
+def test_recover_text():
+    text = "Janek poszedł do ogrodu. Ogród był zwierzęcy. Spotkał tam Zosię..."
+    tokenizer = BertTokenizerFast.from_pretrained(
+        'bert-base-multilingual-cased')
+
+    text_clean, word_labels = create_model_input_output(text)
+
+    result_text = recover_text(text_clean, word_labels)
+
+    assert result_text == text
+
 
 def test_nearest_sentence_l():
     end = create_dummy_action(True)
@@ -145,8 +187,6 @@ def test_batchify_data():
     assert len(output_batch.shape) == 3
     assert len(mask_batch.shape) == 2
 
-    # First dimension should be batch size
-    assert input_batch.shape[0] == output_batch.shape[0]
     assert input_batch.shape[0] == mask_batch.shape[0]
     assert input_batch.shape[0] > 1
 
@@ -165,6 +205,7 @@ def test_batchify_data():
     # Should never be fully masked
     assert np.all(mask_batch[:, 0] == 0) == False
 
     for i in range(input_batch.shape[0]):
         # Should always start from beginning of the sentence
         assert decode_actions(output_batch[i, 0, :])['upper_case']
diff --git a/src/utils.py b/src/utils.py
new file mode 100644
index 0000000..f9b145e
--- /dev/null
+++ b/src/utils.py
@@ -0,0 +1,41 @@
+import yaml
+import re
+import os
+
+PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
+
+def get_config() -> dict:
+    """Returns dict with config values
+
+    Returns:
+        dict: Dict with config values
+    """
+
+    with open(f"{PROJECT_ROOT}/params.yaml", "r")  as file:
+        config = yaml.load(file, Loader=yaml.FullLoader)
+
+    return config
+
+def remove_multiple_spaces(text: str) -> str:
+    """Replaces multiple spaces by a single one
+
+    Args:
+        text (str): Text potentially containing multiple spaces
+
+    Returns:
+        str: Text with all multiple spaces replaced by one
+    """
+    return re.sub(r"\s\s+", " ", text)
+
+def remove_punctuation(text: str) -> str:
+    """Removes all non-alphanumeric characters from the text.  
+    Might result in multiple spaces while chracters like `-` 
+    are used
+
+    Args:
+        text (str): Text containing punctuation
+
+    Returns:
+        str: Text with all punctuation removed
+    """
+    return ''.join(filter(lambda x: x.isalnum() or x.isspace(), text))
-- 
GitLab
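
Taken together, the two helpers above implement the cleanup applied to text before tokenization; a small illustration (the strings are examples, not dataset content):

    from src.utils import remove_punctuation, remove_multiple_spaces

    raw = "Janek poszedł do ogrodu - i co dalej?"
    no_punct = remove_punctuation(raw)        # 'Janek poszedł do ogrodu  i co dalej'
    clean = remove_multiple_spaces(no_punct)  # 'Janek poszedł do ogrodu i co dalej'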


From fd130986ad00e9deddb4fa48ba4db82bc3d84cae Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Fri, 24 Jul 2020 15:35:22 +0200
Subject: [PATCH 016/116] First working model

---
 .dvc/config                            |  6 ++++--
 {modeling/models => models}/.gitignore |  0
 notebooks/test.ipynb                   | 29 +++++++++++++-------------
 scripts/train.py                       | 21 ++++++++++++-------
 4 files changed, 32 insertions(+), 24 deletions(-)
 rename {modeling/models => models}/.gitignore (100%)

diff --git a/.dvc/config b/.dvc/config
index c226a2a..481cc15 100644
--- a/.dvc/config
+++ b/.dvc/config
@@ -1,4 +1,6 @@
+[core]
+    remote = newremote
 ['remote "newremote"']
-    url = s3://punctuation/actions
+    url = s3://punctuation
     endpointurl = https://minio.clarin-pl.eu/minio
-    profile = clarinpl
\ No newline at end of file
+    profile = clarinpl
diff --git a/modeling/models/.gitignore b/models/.gitignore
similarity index 100%
rename from modeling/models/.gitignore
rename to models/.gitignore
diff --git a/notebooks/test.ipynb b/notebooks/test.ipynb
index 9e81e62..7219b70 100644
--- a/notebooks/test.ipynb
+++ b/notebooks/test.ipynb
@@ -23,7 +23,7 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count": 113,
+   "execution_count": 15,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -42,13 +42,13 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 114,
+   "execution_count": 16,
    "metadata": {},
    "outputs": [],
    "source": [
     "INPUT_PATH=\"../generated/stage4_reindexing\"\n",
     "MODEL_BASE = \"bert-base-multilingual-cased\"\n",
-    "MODEL_NAME = \"actionv1\"\n",
+    "MODEL_NAME = \"actionv1_500\"\n",
     "LR = 1e-4\n",
     "\n",
     "BATCH_SIZE=8\n",
@@ -58,7 +58,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 115,
+   "execution_count": 17,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -67,11 +67,11 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 116,
+   "execution_count": 25,
    "metadata": {},
    "outputs": [],
    "source": [
-    "expected = \"To nie amunicja. To krew przygotowana do wysyłki na front. W Lublinie uruchomiono pierwszy ośrodek przetaczania krwi. Dobrowolna ofiara krwi tych dziewcząt ratuje życie tysiącom rannych żołnierzy.\"\n",
+    "expected = \"Dekretem polskiej władzy państwowej stworzono na wyzwolonym obszarze Rzeczypospolitej specjalny sąd karny, w którego kompetencje wchodzą sprawy o zdrady narodu polskiego. Przed sądem stanęli renegaci, którzy nie tylko wyrzekli się polskości, ale postępowaniem swym czynnie pomagali Niemcom w ich zbrodniach. Podajemy fragmenty z rozprawy przeciwko folksdojczowi Musialskiemu, który jako kierownik niemieckiego obozu pracy znęcał się nad obywatelami polskimi. Oskarżony Musielski. Świadek Jankowska opowiedziała, jak bił on Polaków po twarzy, kopał i groził Majdankiem. Świadek Stankiewicz stwierdził, że Musielski przewyższył swym okrucieństwem poprzednich kierowników obozu, Niemców. Prokurator doktor Sawicki zażądał dla oskarżonego kary śmierci. Po naradzie sąd skazał Musielskiego na karę śmierci przez powieszenie.\"\n",
     "text_clean = create_model_input_output(expected)[0]\n",
     "\n",
     "inputs = tokenizer(text_clean, return_tensors=\"pt\")"
@@ -79,7 +79,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 117,
+   "execution_count": 19,
    "metadata": {
     "tags": []
    },
@@ -95,26 +95,26 @@
       "text/plain": "<All keys matched successfully>"
      },
      "metadata": {},
-     "execution_count": 117
+     "execution_count": 19
     }
    ],
    "source": [
     "model = BertForTokenClassification.from_pretrained(MODEL_BASE, num_labels=6)\n",
-    "model.load_state_dict(torch.load(\"models/actionv1-0-50000.model\"))\n"
+    "model.load_state_dict(torch.load(\"../models/actionv1_500-0-5000.model\"))\n"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 118,
+   "execution_count": 20,
    "metadata": {},
    "outputs": [
     {
      "output_type": "execute_result",
      "data": {
-      "text/plain": "{'input_ids': tensor([[   101,  10114,  11058,  10392,  23124,  14083,  10114,  50302,  26127,\n          14052,  47163,  26277,  10149,    191,  12682,  72855,  10506,  10132,\n          14589,    191,  13455,  31998,  10399,  20591,  32013,  12507,  19458,\n            183,  78190,  20157,  17931,  18306,  46136,  50302,  15926,  37908,\n          16828,  24626,  10108,  85286,  50302,  15926,  23453,    172,  14548,\n          10874,  55239,  10123,  45840,  10381,  49934,  34720,  28702,  22530,\n          17044,  12837,  40938,  19626,  81422,  10400,  23311,  50302,  15926,\n          20157,  17931,  18306,  50851,  50302,  15926,  10973,  95230,  11297,\n          13863,  22729,  10171,  10238,    177,  82413,  11058,  17272,  80852,\n          10390,  12577,  12197,    194,  82322,  10361,  50302,  13362,  10284,\n            348,  12097,  83978,  74780,  10149,  58133,  11530,  12741,  15183,\n          14052,  54609,  11624,  84269,  10621,  73121,  13050,  10132,  82992,\n            191,  35327,  42041,  10112,    191,  12211,  62187,  10419,  68037,\n          44227,  35090,  13717,  27828,  10449,  87042,  11133,  77029,  22578,\n          67405,  11717,  78098,  12294,  35779,  10108,  85286,  50302,  15926,\n          14052, 100963,  10280,  10424,  10149,    194,  12524,  10598,  73243,\n          14916,  73837,  22555,  89484,  10108,  21501,  21907,  34582,  24203,\n          64199,  21838,  10418,  92822, 110206,  10113,    191,  27652,  11679,\n          25175,  11877,  19495,    191,  19888,  10514,  33619,  10730,  17249,\n          10229,  10220,    186, 102075,  15050,  23090,  42155,  46845,  10149,\n          27648,  10514,  56423,  17249,    183,  10418,  17771,  60519,    102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}"
+      "text/plain": "{'input_ids': tensor([[   101,  25553,  10311,  92058,  42750,  49200,  25470,  10305,  13520,\n          20868,  74422,    191,  61337,  53873,  10877,  24356,  11157,  10797,\n          53873,  41257,  10797,  10395,  11486,  28404,  52321,  84639,  10147,\n          26356,  14196,  31839,  60347,  11499,  45637,  10156,  25750,  10238,\n          10339,  35636,  11234,  21273,  78098,  31110,  20802,    194,  13455,\n          21282,  14052,  10381,  12964,  12294,  10132,  69857,  10240,  10963,\n            191,  10157,  12097,  53837,    172, 107494,  10157,  29230,  10269,\n          10488,  19428,  22510,    177,  35374,  10269,  10493,  10253,  26477,\n          14552,  10418,  11033,  18809,  11048,  58531,  68756,  10637,  12524,\n          98033,  91176,  32650,  14950,  20162,  18761,    172, 107494,  10112,\n            194, 107452,  10390,  30518,  63497,  40938,  11415,  43172,    177,\n          40398,  88238,  39855,  11951,    191,  54609,  20315,  12402,  86420,\n            183,  69857,  10240,  10963,  38339,  11058,  11044,  10305,  19237,\n          33207,  46972,  10963,  80711,  17339,  53142,  34105,  10138,  10132,\n          10148,  19063,    179,  91496,  10400,  81160,  10354,  10390,  10149,\n          48025,  11478, 103904,  14152,  49854,  49200,  10132,  23040,  84921,\n          13055,  40590,  11783,  54609,  18220,  28768,  52665,  69997,  13717,\n          30057,  11401,    172, 107494,  10112,  10311, 109245, 105084,  25470,\n          37661,  10644,    194,  12947,  37863,  10116,    370,  10133,  11133,\n          18521,  69886,  20056,  10418,  63184,  55485,    175,  15357,  19665,\n            191,  78517,  10342,  10149,  48242,  66083,  30214,  18795,  26986,\n          10390,  10424,  11048,  16512,  98413,    191,  25470,  37661,  28612,\n            194,  35374,  10147,  33705,  10171,  27119,  49015,  10147,  24960,\n          10451,  46113,  16818,  13274,  10149, 107452,  17641,  10424,    183,\n          92718,  59685,  96716,  18936,  10269,    191,  12979,    183,  10418,\n          17771,  60519,    102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n         1, 1, 1, 1, 1, 1, 1, 1, 1, 
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}"
      },
      "metadata": {},
-     "execution_count": 118
+     "execution_count": 20
     }
    ],
    "source": [
@@ -123,7 +123,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 125,
+   "execution_count": 26,
    "metadata": {
     "tags": []
    },
@@ -131,7 +131,7 @@
     {
      "output_type": "stream",
      "name": "stdout",
-     "text": "To nie amunicja. To krew przygotowana do wysyłki na front. W Lublinie uruchomiono pierwszy ośrodek przetaczania krwi. Dobrowolna ofiara krwi tych dziewcząt ratuje życie tysiącom rannych żołnierzy. Ustalanie grupy krwi. Przetaczanie krwi jest absolutnie bezbolesne i niemal nieszkodliwe dla zdrowia. Te krople ściekające do kolby tak samo przyspieszają zwycięstwo, jak naboje wyrabiane w fabryce broni. Niemiec poprzysiągł narodowi polskiemu zagładę. Ofiara krwi przyczynia się do zwycięstwa, zmniejsza liczbę ofiar katastrofy wojennej, przeciwdziała wyludnieniu Polski, w której po wojnie każda para rąk potrzebna będzie do pracy. Pomóż ojczyźnie!\nTo nie amunicja, to krew przygotowana do wysyłki na Front w Lublinie, uruchomiono pierwszy Ośrodek Przetaczania Krwi, dobrowolna, ofiara, krwi tych dziewcząt, ratuje życie Tysiącom, rannych, żołnierzy, ustalanie grupy, krwi przetaczanie krwi, jest absolutnie bezbolesne i niemal Nieszkodliwe dla Zdrowia, Te krople, ściekające do kolby, tak, samo, przyspieszają zwycięstwo, jak naboje wyrabiane w fabryce Broni, Niemiec, poprzysiągł narodowi polskiemu, Zagładę, Ofiara Krwi, przyczynia się do zwycięstwa zmniejsza liczbę Ofiar Katastrofy, wojennej, przeciwdziała wyludnieniu Polski, w której po wojnie, każda para, Rąk, potrzebna będzie do, pracy, pomóż Ojczyźnie,\n"
+     "text": "Dekretem polskiej władzy państwowej stworzono na wyzwolonym obszarze Rzeczypospolitej specjalny sąd karny, w którego kompetencje wchodzą sprawy o zdrady narodu polskiego. Przed sądem stanęli renegaci, którzy nie tylko wyrzekli się polskości, ale postępowaniem swym czynnie pomagali Niemcom w ich zbrodniach. Podajemy fragmenty z rozprawy przeciwko folksdojczowi Musialskiemu, który jako kierownik niemieckiego obozu pracy znęcał się nad obywatelami polskimi. Oskarżony Musielski. Świadek Jankowska opowiedziała, jak bił on Polaków po twarzy, kopał i groził Majdankiem. Świadek Stankiewicz stwierdził, że Musielski przewyższył swym okrucieństwem poprzednich kierowników obozu, Niemców. Prokurator doktor Sawicki zażądał dla oskarżonego kary śmierci. Po naradzie sąd skazał Musielskiego na karę śmierci przez powieszenie.\nDekretem polskiej władzy państwowej stworzono na wyzwolonym obszarze Rzeczypospolitej Specjalny Sąd karny, w którego kompetencje wchodzą sprawy o zdrady narodu polskiego. Przed sądem stanęli renegaci, którzy nie tylko wyrzekli się polskości, ale postępowaniem swym czynnie pomagali Niemcom. W ich zbrodniach. Podajemy fragmenty z rozprawy przeciwko Folksdojczowi Musialskiemu, który jako kierownik Niemieckiego Obozu Pracy znęcał się nad obywatelami polskimi. Oskarżony Musielski świadek. Jankowska opowiedziała, jak bił on Polaków, po twarzy kopał i groził majdankiem świadek. Stankiewicz stwierdził, że Musielski przewyższył swym okrucieństwem poprzednich kierowników obozu Niemców. Prokurator doktor Sawicki zażądał dla oskarżonego kary śmierci. Po naradzie sąd skazał Musielskiego na karę śmierci przez powieszenie.\n"
     }
    ],
    "source": [
@@ -142,6 +142,7 @@
     "\n",
     "actions = labels_pred > 0.2\n",
     "print(expected)\n",
+    "print(\"------\")\n",
     "print(recover_text(text_clean, actions))"
    ]
   },
diff --git a/scripts/train.py b/scripts/train.py
index 5ced9ee..d2d1b60 100755
--- a/scripts/train.py
+++ b/scripts/train.py
@@ -6,24 +6,30 @@ from torch.nn import BCEWithLogitsLoss
 import pandas as pd
 import numpy as np
 import dask.dataframe as dd
+from src.utils import PROJECT_ROOT
+import os
 
-INPUT_PATH="../generated/stage4_reindexing"
+MODEL_DIR = f"{PROJECT_ROOT}/models"
+INPUT_PATH=f"{PROJECT_ROOT}/generated/stage4_reindexing"
 MODEL_BASE = "bert-base-multilingual-cased"
-MODEL_NAME = "actionv1"
+MODEL_NAME = "actionv1_500"
 LR = 1e-4
 
-BATCH_SIZE=8
+BATCH_SIZE=2
 NUM_EPOCH=5
 SAVE_STEP=5_000
 AVERAGE_SPAN = 1_000
 
-LOAD = "0-50000"
+LOAD = None #"0-50000"
 
 if __name__ == "__main__":
     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
     print(f"Training on {device}")
 
     df = dd.read_parquet(INPUT_PATH, engine="pyarrow")
+
+    if not os.path.exists(f"{MODEL_DIR}/{MODEL_NAME}"):
+        os.makedirs(f"{MODEL_DIR}/{MODEL_NAME}", exist_ok=True)
     
     tokenizer = BertTokenizer.from_pretrained(MODEL_BASE)
 
@@ -32,11 +38,10 @@ if __name__ == "__main__":
     criterion = BCEWithLogitsLoss().to(device)
     optimizer = torch.optim.Adam(model.parameters(), lr=LR)
 
-
     epoch_start = 0
     sample_start = 0
     if LOAD is not None:
-        model.load_state_dict(torch.load(f"models/{MODEL_NAME}-{LOAD}.model"))
+        model.load_state_dict(torch.load(f"{MODEL_DIR}/{MODEL_NAME}-{LOAD}.model"))
         #optimizer.load_state_dict(torch.load(f"models/{MODEL_NAME}-{LOAD}.optimizer"))
 
         epoch_start, sample_start = LOAD.split("-")
@@ -85,8 +90,8 @@ if __name__ == "__main__":
 
             if i % SAVE_STEP == 0:
                 print(f"Saving: Epoch {epoch}, step {i}")
-                torch.save(model.state_dict(), f"models/{MODEL_NAME}-{epoch}-{i}.model")
-                torch.save(optimizer.state_dict(), f"models/{MODEL_NAME}-{epoch}-{i}.optimizer")
+                torch.save(model.state_dict(), f"{MODEL_DIR}/{MODEL_NAME}-{epoch}-{i}.model")
+                torch.save(optimizer.state_dict(), f"{MODEL_DIR}/{MODEL_NAME}-{epoch}-{i}.optimizer")
 
             loss.backward()
 
-- 
GitLab
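
End to end, the inference path exercised in notebooks/test.ipynb above is: strip punctuation and case from a reference text, run the token classifier, pool token scores into word scores, threshold, and re-punctuate. A condensed sketch (the checkpoint path and 0.2 threshold are the notebook's; the reference string is a placeholder):

    import torch
    from transformers import BertTokenizerFast, BertForTokenClassification
    from src.processing import (create_model_input_output,
                                token_labels_to_word_labels, recover_text)

    expected = "To jest zdanie. A to jest drugie."  # placeholder reference text

    tokenizer = BertTokenizerFast.from_pretrained("bert-base-multilingual-cased")
    model = BertForTokenClassification.from_pretrained(
        "bert-base-multilingual-cased", num_labels=6)
    model.load_state_dict(torch.load("../models/actionv1_500-0-5000.model"))
    model.eval()

    text_clean = create_model_input_output(expected)[0]  # punctuation-free input
    inputs = tokenizer(text_clean, return_tensors="pt")
    with torch.no_grad():
        y_pred = model(**inputs)[0].sigmoid()
    labels_pred = token_labels_to_word_labels(
        text_clean, y_pred.numpy()[0, 1:-1, :], tokenizer)

    actions = labels_pred > 0.2  # per-word boolean action decisions
    print(recover_text(text_clean, actions))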


From 94bbcc3cb3583e64415da8939d17268b2e0fa8b1 Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Thu, 30 Jul 2020 10:07:51 +0200
Subject: [PATCH 017/116] Tokenizer fixup

---
 scripts/train.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/scripts/train.py b/scripts/train.py
index d2d1b60..c26b8e8 100755
--- a/scripts/train.py
+++ b/scripts/train.py
@@ -1,6 +1,6 @@
 #!/usr/bin/python3
 
-from transformers import BertTokenizer, BertForTokenClassification
+from transformers import BertTokenizerFast, BertForTokenClassification
 import torch
 from torch.nn import BCEWithLogitsLoss
 import pandas as pd
@@ -20,7 +20,7 @@ NUM_EPOCH=5
 SAVE_STEP=5_000
 AVERAGE_SPAN = 1_000
 
-LOAD = None #"0-50000"
+LOAD = "0-60000"
 
 if __name__ == "__main__":
     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
@@ -31,7 +31,7 @@ if __name__ == "__main__":
     if not os.path.exists(f"{MODEL_DIR}/{MODEL_NAME}"):
         os.makedirs(f"{MODEL_DIR}/{MODEL_NAME}", exist_ok=True)
     
-    tokenizer = BertTokenizer.from_pretrained(MODEL_BASE)
+    tokenizer = BertTokenizerFast.from_pretrained(MODEL_BASE)
 
     # TODO: Change num labels
     model = BertForTokenClassification.from_pretrained(MODEL_BASE, num_labels=6).to(device)
-- 
GitLab
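
The tokenizer fixup matters beyond speed: token_word_mapping in src/processing.py requests return_offsets_mapping, which only the Rust-backed "fast" tokenizers provide; the pure-Python BertTokenizer raises NotImplementedError for it. A quick check:

    from transformers import BertTokenizer, BertTokenizerFast

    fast = BertTokenizerFast.from_pretrained("bert-base-multilingual-cased")
    enc = fast("janek poszedł do ogrodu", return_offsets_mapping=True)
    print(enc["offset_mapping"][:3])   # e.g. [(0, 0), (0, 5), (6, 13)]

    slow = BertTokenizer.from_pretrained("bert-base-multilingual-cased")
    try:
        slow("janek poszedł do ogrodu", return_offsets_mapping=True)
    except NotImplementedError:
        print("offset mappings require the fast tokenizer")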


From 67c78771ad2aa14fd95b7843123147d4034e536e Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Thu, 30 Jul 2020 11:51:25 +0200
Subject: [PATCH 018/116] Parametrizing all constants

---
 .dvc/config                                   |  4 +-
 .gitignore                                    |  3 +-
 Dockerfile                                    |  4 -
 dvc.yaml                                      |  9 ++
 modeling/BertTokenMultilabel.py               | 12 ---
 models/.gitignore                             |  2 -
 notebooks/test.ipynb                          | 26 +++---
 params.yaml                                   | 12 ++-
 .../dataset_generation/stage2_tokenization.py |  4 +-
 .../dataset_generation/stage3_exploding.py    |  1 -
 scripts/train.py                              | 90 +++++++++++--------
 11 files changed, 90 insertions(+), 77 deletions(-)
 delete mode 100644 Dockerfile
 delete mode 100644 modeling/BertTokenMultilabel.py
 delete mode 100644 models/.gitignore

diff --git a/.dvc/config b/.dvc/config
index 481cc15..c30b54e 100644
--- a/.dvc/config
+++ b/.dvc/config
@@ -1,6 +1,6 @@
 [core]
     remote = newremote
 ['remote "newremote"']
-    url = s3://punctuation
-    endpointurl = https://minio.clarin-pl.eu/minio
+    url = s3://punctuation/action_based
+    endpointurl = https://minio.clarin-pl.eu
     profile = clarinpl
diff --git a/.gitignore b/.gitignore
index 11deaf7..1272d38 100644
--- a/.gitignore
+++ b/.gitignore
@@ -7,4 +7,5 @@ dataset_actions
 .metals
 /data
 __pycache__
-.pytest_cache
\ No newline at end of file
+.pytest_cache
+/checkpoints
diff --git a/Dockerfile b/Dockerfile
deleted file mode 100644
index d31b532..0000000
--- a/Dockerfile
+++ /dev/null
@@ -1,4 +0,0 @@
-FROM ubuntu:20.04
-
-RUN DEBIAN_FRONTEND=noninteractive apt-get update && apt-get install python3 -y python3-pip
-RUN pip3 install numpy pandas "dask[complete]" torch transformers
\ No newline at end of file
diff --git a/dvc.yaml b/dvc.yaml
index ec9d4e9..e700ca4 100644
--- a/dvc.yaml
+++ b/dvc.yaml
@@ -16,6 +16,7 @@ stages:
     params:
     - tokenization.max_tokens
     - tokenization.min_tokens
+    - global.base_model
     outs:
     - generated/stage2_tokenization
   exploding:
@@ -32,3 +33,11 @@ stages:
     - scripts/dataset_generation/stage4_reindexing.py
     outs:
     - generated/stage4_reindexing
+  training:
+    cmd: python3 -m scripts.train
+    deps:
+    - generated/stage4_reindexing
+    - scripts/train.py
+    - global.base_model
+    outs:
+    - checkpoints
diff --git a/modeling/BertTokenMultilabel.py b/modeling/BertTokenMultilabel.py
deleted file mode 100644
index f57e47f..0000000
--- a/modeling/BertTokenMultilabel.py
+++ /dev/null
@@ -1,12 +0,0 @@
-import torch.nn as nn
-from transformers import BertForTokenClassification
-
-class BertTokenMultilabel(nn.Module):
-    def __init__(self, base_model: str, num_labels: int):
-        super(BertTokenMultilabel, self).__init__()
-        self.base_model = BertForTokenClassification.from_pretrained(base_model, num_labels=num_labels)
-
-        self.add_module(self.base_model)
-
-    def forward():
-        
\ No newline at end of file
diff --git a/models/.gitignore b/models/.gitignore
deleted file mode 100644
index c96a04f..0000000
--- a/models/.gitignore
+++ /dev/null
@@ -1,2 +0,0 @@
-*
-!.gitignore
\ No newline at end of file
diff --git a/notebooks/test.ipynb b/notebooks/test.ipynb
index 7219b70..22779f2 100644
--- a/notebooks/test.ipynb
+++ b/notebooks/test.ipynb
@@ -23,7 +23,7 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count": 15,
+   "execution_count": 45,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -42,7 +42,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 16,
+   "execution_count": 46,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -53,12 +53,12 @@
     "\n",
     "BATCH_SIZE=8\n",
     "NUM_EPOCH=5\n",
-    "SAVE_STEP=5_000"
+    "SAVE_STEP=60_000"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 17,
+   "execution_count": 47,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -67,7 +67,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 25,
+   "execution_count": 48,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -79,7 +79,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 19,
+   "execution_count": 49,
    "metadata": {
     "tags": []
    },
@@ -95,7 +95,7 @@
       "text/plain": "<All keys matched successfully>"
      },
      "metadata": {},
-     "execution_count": 19
+     "execution_count": 49
     }
    ],
    "source": [
@@ -105,16 +105,16 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 20,
+   "execution_count": 50,
    "metadata": {},
    "outputs": [
     {
      "output_type": "execute_result",
      "data": {
-      "text/plain": "{'input_ids': tensor([[   101,  25553,  10311,  92058,  42750,  49200,  25470,  10305,  13520,\n          20868,  74422,    191,  61337,  53873,  10877,  24356,  11157,  10797,\n          53873,  41257,  10797,  10395,  11486,  28404,  52321,  84639,  10147,\n          26356,  14196,  31839,  60347,  11499,  45637,  10156,  25750,  10238,\n          10339,  35636,  11234,  21273,  78098,  31110,  20802,    194,  13455,\n          21282,  14052,  10381,  12964,  12294,  10132,  69857,  10240,  10963,\n            191,  10157,  12097,  53837,    172, 107494,  10157,  29230,  10269,\n          10488,  19428,  22510,    177,  35374,  10269,  10493,  10253,  26477,\n          14552,  10418,  11033,  18809,  11048,  58531,  68756,  10637,  12524,\n          98033,  91176,  32650,  14950,  20162,  18761,    172, 107494,  10112,\n            194, 107452,  10390,  30518,  63497,  40938,  11415,  43172,    177,\n          40398,  88238,  39855,  11951,    191,  54609,  20315,  12402,  86420,\n            183,  69857,  10240,  10963,  38339,  11058,  11044,  10305,  19237,\n          33207,  46972,  10963,  80711,  17339,  53142,  34105,  10138,  10132,\n          10148,  19063,    179,  91496,  10400,  81160,  10354,  10390,  10149,\n          48025,  11478, 103904,  14152,  49854,  49200,  10132,  23040,  84921,\n          13055,  40590,  11783,  54609,  18220,  28768,  52665,  69997,  13717,\n          30057,  11401,    172, 107494,  10112,  10311, 109245, 105084,  25470,\n          37661,  10644,    194,  12947,  37863,  10116,    370,  10133,  11133,\n          18521,  69886,  20056,  10418,  63184,  55485,    175,  15357,  19665,\n            191,  78517,  10342,  10149,  48242,  66083,  30214,  18795,  26986,\n          10390,  10424,  11048,  16512,  98413,    191,  25470,  37661,  28612,\n            194,  35374,  10147,  33705,  10171,  27119,  49015,  10147,  24960,\n          10451,  46113,  16818,  13274,  10149, 107452,  17641,  10424,    183,\n          92718,  59685,  96716,  18936,  10269,    191,  12979,    183,  10418,\n          17771,  60519,    102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n         1, 1, 1, 1, 1, 1, 1, 1, 1, 
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}"
+      "text/plain": "{'input_ids': tensor([[   101,  18910,  34942,  10147,  41458,  49264,  86412,  14196,  28780,\n          16828,  19435,  12507,  10132,    191,  10157,  10305,  16828,  47472,\n          10147,  51043,  53549,  10157,  22815,  29378,  42606,  32650,  15953,\n          24011,  10756,  13130,  10162,  25085,  10756,    191,  21619,  12240,\n          19094,  10136,  17528,    191,  82689,  59763,  10157,    183,    194,\n          14951,  12355,  77029,  10138,  49200,  18795,  13130,  16050,  45244,\n          58814,  63256,  19172,  10598,  30214,  11058,  18933,    191,  20728,\n          25983,  10390,  10424,  72058,  15083,  11372,  11841,  91496,  85421,\n            187,  21633,  22427,  57753,  10514,  96760,  10390,  44227,  22530,\n            191,  12979,    194,  20923,  63168,  10269,  11202,  19234,  14996,\n          39144,  10157,    194,  25470,  10305,  91865,  39371,  18694,  82836,\n          21799,  20868,  22578,  87985,  20162,  21791,  10138,  13680,  10701,\n         107626,  80711,  92859,  27648,    194,  31223,  10425,  11133,  10424,\n          12060,  17339,  70500,  87907,  10500,  72325,  10116,  10427,  15190,\n          36701,  87985,  18338,  10506,  18996, 103412,  10174,  63923,  72275,\n          11485,  10303,  28612, 110206,  10113,  13050,  11342,  11133,  10135,\n          14151,  16036,  10514,  37975,  27828,  39268,  16251,    177,  30518,\n          20129,  21617,  14991,  30022,  13711,  18996, 103412,  10174,  45244,\n          62872,  28780,  79534,  12537,  87985,  18338,  10506,  20157,  82351,\n          10157,  61610,  11133,    187,  21633,  14302,  11680,  12097,  14194,\n          82775,  13717,  23090, 108605, 107626,  10644,  92859,  44227,  22064,\n          11284,  96858,  11813,  43307,  17112,  84280,  10339,  67464,  40384,\n          12197,  10427,  15190,  61187,  10797,  25085,  10157,  26584,  10514,\n          90086, 102984,  13130,  10162,  31569,  34105,  87985,  18338,  18761,\n          10132,  25085,  10963,  26584,  11048,  10514,  52784,  21620,    102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n         0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
1,\n         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n         1, 1, 1, 1, 1, 1, 1, 1, 1]])}"
      },
      "metadata": {},
-     "execution_count": 20
+     "execution_count": 50
     }
    ],
    "source": [
@@ -123,7 +123,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 26,
+   "execution_count": 51,
    "metadata": {
     "tags": []
    },
@@ -131,7 +131,7 @@
     {
      "output_type": "stream",
      "name": "stdout",
-     "text": "Dekretem polskiej władzy państwowej stworzono na wyzwolonym obszarze Rzeczypospolitej specjalny sąd karny, w którego kompetencje wchodzą sprawy o zdrady narodu polskiego. Przed sądem stanęli renegaci, którzy nie tylko wyrzekli się polskości, ale postępowaniem swym czynnie pomagali Niemcom w ich zbrodniach. Podajemy fragmenty z rozprawy przeciwko folksdojczowi Musialskiemu, który jako kierownik niemieckiego obozu pracy znęcał się nad obywatelami polskimi. Oskarżony Musielski. Świadek Jankowska opowiedziała, jak bił on Polaków po twarzy, kopał i groził Majdankiem. Świadek Stankiewicz stwierdził, że Musielski przewyższył swym okrucieństwem poprzednich kierowników obozu, Niemców. Prokurator doktor Sawicki zażądał dla oskarżonego kary śmierci. Po naradzie sąd skazał Musielskiego na karę śmierci przez powieszenie.\nDekretem polskiej władzy państwowej stworzono na wyzwolonym obszarze Rzeczypospolitej Specjalny Sąd karny, w którego kompetencje wchodzą sprawy o zdrady narodu polskiego. Przed sądem stanęli renegaci, którzy nie tylko wyrzekli się polskości, ale postępowaniem swym czynnie pomagali Niemcom. W ich zbrodniach. Podajemy fragmenty z rozprawy przeciwko Folksdojczowi Musialskiemu, który jako kierownik Niemieckiego Obozu Pracy znęcał się nad obywatelami polskimi. Oskarżony Musielski świadek. Jankowska opowiedziała, jak bił on Polaków, po twarzy kopał i groził majdankiem świadek. Stankiewicz stwierdził, że Musielski przewyższył swym okrucieństwem poprzednich kierowników obozu Niemców. Prokurator doktor Sawicki zażądał dla oskarżonego kary śmierci. Po naradzie sąd skazał Musielskiego na karę śmierci przez powieszenie.\n"
+     "text": "Dekretem polskiej władzy państwowej stworzono na wyzwolonym obszarze Rzeczypospolitej specjalny sąd karny, w którego kompetencje wchodzą sprawy o zdrady narodu polskiego. Przed sądem stanęli renegaci, którzy nie tylko wyrzekli się polskości, ale postępowaniem swym czynnie pomagali Niemcom w ich zbrodniach. Podajemy fragmenty z rozprawy przeciwko folksdojczowi Musialskiemu, który jako kierownik niemieckiego obozu pracy znęcał się nad obywatelami polskimi. Oskarżony Musielski. Świadek Jankowska opowiedziała, jak bił on Polaków po twarzy, kopał i groził Majdankiem. Świadek Stankiewicz stwierdził, że Musielski przewyższył swym okrucieństwem poprzednich kierowników obozu, Niemców. Prokurator doktor Sawicki zażądał dla oskarżonego kary śmierci. Po naradzie sąd skazał Musielskiego na karę śmierci przez powieszenie.\n------\nDekretem polskiej władzy państwowej stworzono na wyzwolonym obszarze Rzeczypospolitej specjalny sąd karny, w którego kompetencje wchodzą sprawy o zdrady narodu polskiego. Przed sądem stanęli renegaci, którzy nie tylko wyrzekli się polskości, ale postępowaniem swym czynnie pomagali Niemcom w ich zbrodniach. Podajemy fragmenty z rozprawy przeciwko folksdojczowi Musialskiemu, który jako kierownik Niemieckiego Obozu pracy znęcał się nad obywatelami polskimi. oskarżony musielski świadek Jankowska opowiedziała, jak bił on Polaków po twarzy kopał i groził majdankiem świadek. Stankiewicz stwierdził, że Musielski przewyższył swym okrucieństwem poprzednich kierowników obozu Niemców. Prokurator doktor Sawicki zażądał dla oskarżonego kary śmierci. Po naradzie sąd skazał Musielskiego na karę śmierci przez powieszenie.\n"
     }
    ],
    "source": [
@@ -140,7 +140,7 @@
     "y_pred = model(**inputs)[0].sigmoid()\n",
     "labels_pred = token_labels_to_word_labels(text_clean, y_pred.detach().numpy()[0, 1:-1, :], tokenizer)\n",
     "\n",
-    "actions = labels_pred > 0.2\n",
+    "actions = labels_pred > 0.5\n",
     "print(expected)\n",
     "print(\"------\")\n",
     "print(recover_text(text_clean, actions))"
diff --git a/params.yaml b/params.yaml
index d8aefcb..36bc41c 100644
--- a/params.yaml
+++ b/params.yaml
@@ -1,5 +1,6 @@
 global:
     dashboard_port: 8787
+    base_model: "bert-base-multilingual-cased"
 
 extraction:
     num_partitions: 2_000
@@ -18,4 +19,13 @@ exploding:
 
 reindexing:
     num_workers: 1
-    worker_memory_limit: "60GB"
\ No newline at end of file
+    worker_memory_limit: "60GB"
+
+training:
+    learning_rate: 0.0001
+    num_epochs: 5
+    batch_size: 2
+    save_step: 20
+    loss_averaging_span: 1000
+    fresh_start: false
+    device: "cuda:0"
\ No newline at end of file
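
The new fresh_start flag drives the checkpoint-resume scan added to scripts/train.py below: checkpoints are saved as "<epoch>-<step>.model", and when fresh_start is false the newest one is located and loaded. A compact restatement of that scan (the helper name is ours; tuple comparison selects the latest epoch/step pair):

    import glob

    def latest_checkpoint(path: str):
        """Return (epoch, step) of the newest '<epoch>-<step>.model' file, or None."""
        best = None
        for f in glob.glob(f"{path}/*.model"):
            epoch, step = map(int, f.split("/")[-1].split(".")[0].split("-"))
            if best is None or (epoch, step) > best:
                best = (epoch, step)
        return best  # e.g. (0, 60000)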
diff --git a/scripts/dataset_generation/stage2_tokenization.py b/scripts/dataset_generation/stage2_tokenization.py
index 27482e1..ade0e48 100644
--- a/scripts/dataset_generation/stage2_tokenization.py
+++ b/scripts/dataset_generation/stage2_tokenization.py
@@ -61,12 +61,12 @@ if __name__ == "__main__":
     min_tokens = config['tokenization']['min_tokens']
     num_workers = config['tokenization']['num_workers']
     memory_limit = config['tokenization']['worker_memory_limit']
+    base_model = config['global']['base_model']
 
     client = Client(n_workers=num_workers, memory_limit=memory_limit)
     print(client.dashboard_link)
 
-    tokenizer = BertTokenizerFast.from_pretrained(
-        'bert-base-multilingual-cased')
+    tokenizer = BertTokenizerFast.from_pretrained(base_model)
 
     tokenizer = dask.delayed(tokenizer)
 
diff --git a/scripts/dataset_generation/stage3_exploding.py b/scripts/dataset_generation/stage3_exploding.py
index bd6d0af..53c49d8 100644
--- a/scripts/dataset_generation/stage3_exploding.py
+++ b/scripts/dataset_generation/stage3_exploding.py
@@ -2,7 +2,6 @@
 from src.processing import batchify_data
 from dask.diagnostics import ProgressBar
 import dask.dataframe as dd
-from transformers import BertTokenizerFast
 import numpy as np
 import dask
 from dask.distributed import Client
diff --git a/scripts/train.py b/scripts/train.py
index c26b8e8..392318d 100755
--- a/scripts/train.py
+++ b/scripts/train.py
@@ -6,61 +6,71 @@ from torch.nn import BCEWithLogitsLoss
 import pandas as pd
 import numpy as np
 import dask.dataframe as dd
-from src.utils import PROJECT_ROOT
 import os
+import glob
+from src.utils import PROJECT_ROOT, get_config
+from src.processing import ACTIONS_KEYS
 
-MODEL_DIR = f"{PROJECT_ROOT}/models"
-INPUT_PATH=f"{PROJECT_ROOT}/generated/stage4_reindexing"
-MODEL_BASE = "bert-base-multilingual-cased"
-MODEL_NAME = "actionv1_500"
-LR = 1e-4
-
-BATCH_SIZE=2
-NUM_EPOCH=5
-SAVE_STEP=5_000
-AVERAGE_SPAN = 1_000
-
-LOAD = "0-60000"
+INPUT_PATH = f"{PROJECT_ROOT}/generated/stage4_reindexing"
+OUTPUT_PATH = f"{PROJECT_ROOT}/checkpoints"
 
 if __name__ == "__main__":
-    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+    config = get_config()
+    learning_rate = config['training']['learning_rate']
+    num_epochs = config['training']['num_epochs']
+    batch_size = config['training']['batch_size']
+    save_step = config['training']['save_step']
+    loss_averaging_span = config['training']['loss_averaging_span']
+    fresh_start = config['training']['fresh_start']
+    device_name = config['training']['device']
+    base_model = config['global']['base_model']
+
+    device = torch.device(device_name if torch.cuda.is_available() else "cpu")
     print(f"Training on {device}")
 
     df = dd.read_parquet(INPUT_PATH, engine="pyarrow")
 
-    if not os.path.exists(f"{MODEL_DIR}/{MODEL_NAME}"):
-        os.makedirs(f"{MODEL_DIR}/{MODEL_NAME}", exist_ok=True)
+    if not os.path.exists(f"{OUTPUT_PATH}"):
+        os.makedirs(f"{OUTPUT_PATH}", exist_ok=True)
     
-    tokenizer = BertTokenizerFast.from_pretrained(MODEL_BASE)
+    tokenizer = BertTokenizerFast.from_pretrained(base_model)
 
-    # TODO: Change num labels
-    model = BertForTokenClassification.from_pretrained(MODEL_BASE, num_labels=6).to(device)
+    model = BertForTokenClassification.from_pretrained(base_model, num_labels=len(ACTIONS_KEYS)).to(device)
     criterion = BCEWithLogitsLoss().to(device)
-    optimizer = torch.optim.Adam(model.parameters(), lr=LR)
+    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
 
     epoch_start = 0
     sample_start = 0
-    if LOAD is not None:
-        model.load_state_dict(torch.load(f"{MODEL_DIR}/{MODEL_NAME}-{LOAD}.model"))
-        #optimizer.load_state_dict(torch.load(f"models/{MODEL_NAME}-{LOAD}.optimizer"))
-
-        epoch_start, sample_start = LOAD.split("-")
-        epoch_start = int(epoch_start)
-        sample_start = int(sample_start)
-
-        print(f"Loaded {MODEL_NAME}-{LOAD}")
+    if not fresh_start:
+        checkpoint_files = glob.glob(f"{OUTPUT_PATH}/*.model")
+        furthest_epoch = -1
+        furthest_batch_num = -1
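+        # Checkpoint files are named "<epoch>-<iteration>.model"; find the latest one to resume from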
+        for checkpoint_file in checkpoint_files:
+            filename = checkpoint_file.split("/")[-1].split(".")[0]
+            epoch, iteration = filename.split("-")
+            epoch, iteration = int(epoch), int(iteration)
+
+            if epoch > furthest_epoch:
+                furthest_epoch = epoch
+                furthest_batch_num = iteration
+            elif epoch == furthest_epoch:
+                furthest_batch_num = max(iteration, furthest_batch_num)
+
+        if furthest_epoch > -1 and furthest_batch_num > -1:
+            model.load_state_dict(torch.load(f"{OUTPUT_PATH}/{furthest_epoch}-{furthest_batch_num}.model"))
+            optimizer.load_state_dict(torch.load(f"{OUTPUT_PATH}/{furthest_epoch}-{furthest_batch_num}.optimizer"))
+
+            epoch_start, sample_start = furthest_epoch, furthest_batch_num
+            print(f"Loaded {furthest_epoch}-{furthest_batch_num}")
 
     model.train()
-
     losses = []
 
-    for epoch in range(epoch_start, NUM_EPOCH):
+    for epoch in range(epoch_start, num_epochs):
         i = sample_start
         while True:
             # TODO: Change to 0-indexed...
-            data_batch_indexes = list(range(i*BATCH_SIZE+1, i*BATCH_SIZE + BATCH_SIZE +1))
+            data_batch_indexes = list(range(i*batch_size + 1, i*batch_size + batch_size + 1))
             
-            # Precomputing total number of samples very long, so lets
+            # Precomputing the total number of samples takes a very long time, so let's
             # try to get next batch until fail :)
             try:
                 data_batch = df.loc[data_batch_indexes].compute()
@@ -81,21 +91,23 @@ if __name__ == "__main__":
             loss = criterion(y_pred, outputs)
 
             losses.append(loss.item())
-            if len(losses) > AVERAGE_SPAN:
-                losses = losses[-AVERAGE_SPAN:]
+            if len(losses) > loss_averaging_span:
+                losses = losses[-loss_averaging_span:]
 
             print(f'epoch: {epoch} | step: {i} | loss: {np.mean(losses)}')
 
             optimizer.zero_grad()
 
-            if i % SAVE_STEP == 0:
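+            # Skip the save that would fire immediately on resuming at (epoch_start, sample_start)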
+            if i % save_step == 0 and (i != sample_start or epoch != epoch_start):
                 print(f"Saving: Epoch {epoch}, step {i}")
-                torch.save(model.state_dict(), f"{MODEL_DIR}/{MODEL_NAME}-{epoch}-{i}.model")
-                torch.save(optimizer.state_dict(), f"{MODEL_DIR}/{MODEL_NAME}-{epoch}-{i}.optimizer")
+                torch.save(model.state_dict(), f"{OUTPUT_PATH}/{epoch}-{i}.model")
+                torch.save(optimizer.state_dict(), f"{OUTPUT_PATH}/{epoch}-{i}.optimizer")
 
             loss.backward()
-
             optimizer.step()
 
             i += 1
 
+    torch.save(model.state_dict(), f"{OUTPUT_PATH}/final.model")
+    torch.save(optimizer.state_dict(), f"{OUTPUT_PATH}/final.optimizer")
+
-- 
GitLab


From d96f23f1db4effd26285e9269fc1951c15b40e94 Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Thu, 30 Jul 2020 12:24:39 +0200
Subject: [PATCH 019/116] Docker development environment now works with train
 script

---
 .devcontainer/devcontainer.json | 38 ++++++++++++++++++++++++++++++
 docker/Dockerfile               | 41 +++++++++++++++++++++++++++++++++
 2 files changed, 79 insertions(+)
 create mode 100644 .devcontainer/devcontainer.json
 create mode 100644 docker/Dockerfile

diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json
new file mode 100644
index 0000000..9981ce0
--- /dev/null
+++ b/.devcontainer/devcontainer.json
@@ -0,0 +1,38 @@
+// For format details, see https://aka.ms/vscode-remote/devcontainer.json or this file's README at:
+// https://github.com/microsoft/vscode-dev-containers/tree/v0.128.0/containers/docker-existing-dockerfile
+{
+	"name": "Development Container",
+
+	// Sets the run context to one level up instead of the .devcontainer folder.
+	"context": "../docker",
+
+	// Update the 'dockerFile' property if you aren't using the standard 'Dockerfile' filename.
+	"dockerFile": "../docker/Dockerfile",
+
+	// Set *default* container specific settings.json values on container create.
+	"settings": { 
+		"terminal.integrated.shell.linux": null
+	},
+
+	// Add the IDs of extensions you want installed when the container is created.
+	"extensions": [],
+
+	// Use 'forwardPorts' to make a list of ports inside the container available locally.
+	"forwardPorts": [8787],
+
+	"runArgs": [
+		"--gpus", "all"
+	]
+
+	// Uncomment the next line to run commands after the container is created - for example installing curl.
+	// "postCreateCommand": "apt-get update && apt-get install -y curl",
+
+	// Uncomment when using a ptrace-based debugger like C++, Go, and Rust
+	// "runArgs": [ "--cap-add=SYS_PTRACE", "--security-opt", "seccomp=unconfined" ],
+
+	// Uncomment to use the Docker CLI from inside the container. See https://aka.ms/vscode-remote/samples/docker-from-docker.
+	// "mounts": [ "source=/var/run/docker.sock,target=/var/run/docker.sock,type=bind" ],
+
+	// Uncomment to connect as a non-root user. See https://aka.ms/vscode-remote/containers/non-root.
+	// "remoteUser": "vscode"
+}
diff --git a/docker/Dockerfile b/docker/Dockerfile
new file mode 100644
index 0000000..a4e025a
--- /dev/null
+++ b/docker/Dockerfile
@@ -0,0 +1,41 @@
+FROM ubuntu:20.04
+
+RUN apt update && apt install -y python3 python3-pip
+RUN apt update && apt install -y git
+RUN pip3 install ipywidgets
+
+#### CUDA Installation
+RUN apt-get update && apt-get install -y --no-install-recommends \
+gnupg2 curl ca-certificates && \
+    curl -fsSL https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/7fa2af80.pub | apt-key add - && \
+    echo "deb https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64 /" > /etc/apt/sources.list.d/cuda.list && \
+    echo "deb https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64 /" > /etc/apt/sources.list.d/nvidia-ml.list && \
+rm -rf /var/lib/apt/lists/*
+
+ENV CUDA_VERSION 10.2.89
+
+ENV CUDA_PKG_VERSION 10-2=$CUDA_VERSION-1
+
+# For libraries in the cuda-compat-* package: https://docs.nvidia.com/cuda/eula/index.html#attachment-a
+RUN apt-get update && apt-get install -y --no-install-recommends \
+        cuda-cudart-$CUDA_PKG_VERSION \
+cuda-compat-10-2 && \
+ln -s cuda-10.2 /usr/local/cuda && \
+    rm -rf /var/lib/apt/lists/*
+
+# Required for nvidia-docker v1
+RUN echo "/usr/local/nvidia/lib" >> /etc/ld.so.conf.d/nvidia.conf && \
+    echo "/usr/local/nvidia/lib64" >> /etc/ld.so.conf.d/nvidia.conf
+
+ENV PATH /usr/local/nvidia/bin:/usr/local/cuda/bin:${PATH}
+ENV LD_LIBRARY_PATH /usr/local/nvidia/lib:/usr/local/nvidia/lib64
+
+# nvidia-container-runtime
+ENV NVIDIA_VISIBLE_DEVICES all
+ENV NVIDIA_DRIVER_CAPABILITIES compute,utility
+ENV NVIDIA_REQUIRE_CUDA "cuda>=10.2 brand=tesla,driver>=384,driver<385 brand=tesla,driver>=396,driver<397 brand=tesla,driver>=410,driver<411 brand=tesla,driver>=418,driver<419"
+
+### END CUDA Installation
+
+RUN pip3 install numpy pandas tqdm seaborn torch dask[complete] transformers pyarrow
+RUN ln -s /usr/bin/pip3 /usr/bin/pip
\ No newline at end of file
-- 
GitLab


From 151e4584862bfecc0a9b46606863bd612fb0026f Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Thu, 30 Jul 2020 12:16:19 +0000
Subject: [PATCH 020/116] Added option for maximum training time

---
 .devcontainer/devcontainer.json |  9 +++++++-
 docker/Dockerfile               |  2 +-
 dvc.yaml                        |  6 +++++
 params.yaml                     |  3 ++-
 scripts/train.py                | 29 ++++++++++++++++++++---
 src/utils.py                    | 41 +++++++++++++++++++++++++++++++++
 6 files changed, 84 insertions(+), 6 deletions(-)

diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json
index 9981ce0..73ee0b8 100644
--- a/.devcontainer/devcontainer.json
+++ b/.devcontainer/devcontainer.json
@@ -15,13 +15,20 @@
 	},
 
 	// Add the IDs of extensions you want installed when the container is created.
-	"extensions": [],
+	"extensions": [
+		"ms-python.python"
+	],
 
 	// Use 'forwardPorts' to make a list of ports inside the container available locally.
 	"forwardPorts": [8787],
 
 	"runArgs": [
 		"--gpus", "all"
+	],
+
+	"mounts": [
+		"source=${localEnv:HOME}/.gitconfig,target=/root/.gitconfig,type=bind",
+		"source=${localEnv:HOME}/.aws,target=/root/.aws,type=bind"
 	]
 
 	// Uncomment the next line to run commands after the container is created - for example installing curl.
diff --git a/docker/Dockerfile b/docker/Dockerfile
index a4e025a..c702b69 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -37,5 +37,5 @@ ENV NVIDIA_REQUIRE_CUDA "cuda>=10.2 brand=tesla,driver>=384,driver<385 brand=tes
 
 ### END CUDA Installation
 
-RUN pip3 install numpy pandas tqdm seaborn torch dask[complete] transformers pyarrow
+RUN pip3 install numpy pandas tqdm seaborn torch dask[complete] transformers pyarrow pytest
 RUN ln -s /usr/bin/pip3 /usr/bin/pip
\ No newline at end of file
diff --git a/dvc.yaml b/dvc.yaml
index e700ca4..2b6c435 100644
--- a/dvc.yaml
+++ b/dvc.yaml
@@ -38,6 +38,12 @@ stages:
     deps:
     - generated/stage4_reindexing
     - scripts/train.py
+    params:
     - global.base_model
+    - training.max_training_time
+    - training.learning_rate
+    - training.num_epochs
+    - training.batch_size
+    - training.save_step
     outs:
     - checkpoints
diff --git a/params.yaml b/params.yaml
index 36bc41c..0dbe4e4 100644
--- a/params.yaml
+++ b/params.yaml
@@ -25,7 +25,8 @@ training:
     learning_rate: 0.0001
     num_epochs: 5
     batch_size: 2
-    save_step: 20
+    save_step: 1000
+    max_training_time: "30s"
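+    # max_training_time format: "<num><s|m|h|d>", e.g. "12h" (parsed by convert_to_timedelta in src/utils.py)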
     loss_averaging_span: 1000
     fresh_start: false
     device: "cuda:0"
\ No newline at end of file
diff --git a/scripts/train.py b/scripts/train.py
index 392318d..d3216f8 100755
--- a/scripts/train.py
+++ b/scripts/train.py
@@ -8,8 +8,9 @@ import numpy as np
 import dask.dataframe as dd
 import os
 import glob
-from src.utils import PROJECT_ROOT, get_config
+from src.utils import PROJECT_ROOT, get_config, convert_to_timedelta
 from src.processing import ACTIONS_KEYS
+from datetime import datetime
 
 INPUT_PATH = f"{PROJECT_ROOT}/generated/stage4_reindexing"
 OUTPUT_PATH = f"{PROJECT_ROOT}/checkpoints"
@@ -24,6 +25,10 @@ if __name__ == "__main__":
     fresh_start = config['training']['fresh_start']
     device_name = config['training']['device']
     base_model = config['global']['base_model']
+    max_train_time = config['training']['max_training_time']
+
+    if max_train_time is not None:
+        max_train_time = convert_to_timedelta(max_train_time)
 
     device = torch.device(device_name if torch.cuda.is_available() else "cpu")
     print(f"Training on {device}")
@@ -64,7 +69,16 @@ if __name__ == "__main__":
     model.train()
     losses = []
 
+    training_stopped = False
+
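+    # datetime.max means "no time limit" when max_training_time is not set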
+    time_max = datetime.max
+    if max_train_time is not None:
+        time_max = datetime.now() + max_train_time
+
     for epoch in range(epoch_start, num_epochs):
+        if training_stopped:
+            break
+
         i = sample_start
         while True:
             # TODO: Change to 0-indexed...
@@ -103,11 +117,20 @@ if __name__ == "__main__":
                 torch.save(model.state_dict(), f"{OUTPUT_PATH}/{epoch}-{i}.model")
                 torch.save(optimizer.state_dict(), f"{OUTPUT_PATH}/{epoch}-{i}.optimizer")
 
+            if datetime.now() > time_max:
+                print(f"Max time reached, saving: Epoch {epoch}, step {i}")
+                torch.save(model.state_dict(), f"{OUTPUT_PATH}/{epoch}-{i}.model")
+                torch.save(optimizer.state_dict(), f"{OUTPUT_PATH}/{epoch}-{i}.optimizer")
+                training_stopped = True
+                break
+
             loss.backward()
             optimizer.step()
 
             i += 1
 
-    torch.save(model.state_dict(), f"{OUTPUT_PATH}/final.model")
-    torch.save(optimizer.state_dict(), f"{OUTPUT_PATH}/final.optimizer")
+    if not training_stopped:
+        torch.save(model.state_dict(), f"{OUTPUT_PATH}/final.model")
+        torch.save(optimizer.state_dict(), f"{OUTPUT_PATH}/final.optimizer")
 
diff --git a/src/utils.py b/src/utils.py
index f9b145e..5dd9fc7 100644
--- a/src/utils.py
+++ b/src/utils.py
@@ -1,6 +1,8 @@
 import yaml
 import re
 import os
+from datetime import timedelta
+from typing import Optional
 
 PROJECT_ROOT=os.path.dirname(os.path.realpath("/".join(__file__.split("/")) + "/.."))
 
@@ -39,3 +41,42 @@ def remove_punctuation(text: str) -> str:
         str: Text with all punctuactions removed
     """
     return ''.join(filter(lambda x: x.isalnum() or x.isspace(), text))
+
+def convert_to_timedelta(time_val: str) -> Optional[timedelta]:
+    """
+    src: https://code.activestate.com/recipes/577894-convert-strings-like-5d-and-60s-to-timedelta-objec/
+    Given a *time_val* (string) such as '5d', returns a timedelta object
+    representing the given value (e.g. timedelta(days=5)).  Accepts the
+    following '<num><char>' formats:
+    
+    =========   ======= ===================
+    Character   Meaning Example
+    =========   ======= ===================
+    s           Seconds '60s' -> 60 Seconds
+    m           Minutes '5m'  -> 5 Minutes
+    h           Hours   '24h' -> 24 Hours
+    d           Days    '7d'  -> 7 Days
+    =========   ======= ===================
+    
+    Examples::
+    
+        >>> convert_to_timedelta('7d')
+        datetime.timedelta(7)
+        >>> convert_to_timedelta('24h')
+        datetime.timedelta(1)
+        >>> convert_to_timedelta('60m')
+        datetime.timedelta(0, 3600)
+        >>> convert_to_timedelta('120s')
+        datetime.timedelta(0, 120)
+    """
+    num = int(time_val[:-1])
+    if time_val.endswith('s'):
+        return timedelta(seconds=num)
+    elif time_val.endswith('m'):
+        return timedelta(minutes=num)
+    elif time_val.endswith('h'):
+        return timedelta(hours=num)
+    elif time_val.endswith('d'):
+        return timedelta(days=num)
+    else:
+        return None
\ No newline at end of file
-- 
GitLab


From 36fbaf941de99092d10b19ec49b56932e1466c26 Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Thu, 30 Jul 2020 13:22:38 +0000
Subject: [PATCH 021/116] Changed folder structure for easier multiple-model
 experiment development

---
 dvc.yaml                                      | 64 +++++++++----------
 params.yaml                                   | 49 +++++++-------
 .../__init__.py                               |  0
 .../stage1_extraction.py                      |  8 +--
 .../stage2_tokenization.py                    | 12 ++--
 .../stage3_exploding.py                       |  8 +--
 .../stage4_reindexing.py                      |  8 +--
 scripts/{ => actions_based}/train.py          | 20 +++---
 scripts/dataset_generation/.gitignore         |  1 -
 scripts/translation_based/__init__.py         |  0
 10 files changed, 85 insertions(+), 85 deletions(-)
 rename scripts/{dataset_generation => actions_based}/__init__.py (100%)
 rename scripts/{dataset_generation => actions_based}/stage1_extraction.py (86%)
 rename scripts/{dataset_generation => actions_based}/stage2_tokenization.py (83%)
 rename scripts/{dataset_generation => actions_based}/stage3_exploding.py (87%)
 rename scripts/{dataset_generation => actions_based}/stage4_reindexing.py (73%)
 rename scripts/{ => actions_based}/train.py (88%)
 delete mode 100644 scripts/dataset_generation/.gitignore
 create mode 100644 scripts/translation_based/__init__.py

diff --git a/dvc.yaml b/dvc.yaml
index 2b6c435..32ee641 100644
--- a/dvc.yaml
+++ b/dvc.yaml
@@ -1,49 +1,49 @@
 stages:
-  extraction:
-    cmd: python3 -m scripts.dataset_generation.stage1_extraction
+  actions_extraction:
+    cmd: python3 -m scripts.actions_based.stage1_extraction
     deps:
     - data
-    - scripts/dataset_generation/stage1_extraction.py
+    - scripts/actions_based/stage1_extraction.py
     params:
-    - extraction.num_partitions
+    - actions.extraction.num_partitions
     outs:
-    - generated/stage1_extraction
-  tokenization:
-    cmd: python3 -m scripts.dataset_generation.stage2_tokenization
+    - generated/actions/stage1_extraction
+  actions_tokenization:
+    cmd: python3 -m scripts.actions_based.stage2_tokenization
     deps:
-    - generated/stage1_extraction
-    - scripts/dataset_generation/stage2_tokenization.py
+    - generated/actions/stage1_extraction
+    - scripts/actions_based/stage2_tokenization.py
     params:
-    - tokenization.max_tokens
-    - tokenization.min_tokens
+    - actions.tokenization.max_tokens
+    - actions.tokenization.min_tokens
     - global.base_model
     outs:
-    - generated/stage2_tokenization
-  exploding:
-    cmd: python3 -m scripts.dataset_generation.stage3_exploding
+    - generated/actions/stage2_tokenization
+  actions_exploding:
+    cmd: python3 -m scripts.actions_based.stage3_exploding
     deps:
-    - generated/stage2_tokenization
-    - scripts/dataset_generation/stage3_exploding.py
+    - generated/actions/stage2_tokenization
+    - scripts/actions_based/stage3_exploding.py
     outs:
-    - generated/stage3_exploding
-  reindexing:
-    cmd: python3 -m scripts.dataset_generation.stage4_reindexing
+    - generated/actions/stage3_exploding
+  actions_reindexing:
+    cmd: python3 -m scripts.actions_based.stage4_reindexing
     deps:
-    - generated/stage3_exploding
-    - scripts/dataset_generation/stage4_reindexing.py
+    - generated/actions/stage3_exploding
+    - scripts/actions_based/stage4_reindexing.py
     outs:
-    - generated/stage4_reindexing
-  training:
-    cmd: python3 -m scripts.train
+    - generated/actions/stage4_reindexing
+  actions_training:
+    cmd: python3 -m scripts.actions_based.train
     deps:
-    - generated/stage4_reindexing
-    - scripts/train.py
+    - generated/actions/stage4_reindexing
+    - scripts/actions_based/train.py
     params:
     - global.base_model
-    - training.max_training_time
-    - training.learning_rate
-    - training.num_epochs
-    - training.batch_size
-    - training.save_step
+    - actions.training.max_training_time
+    - actions.training.learning_rate
+    - actions.training.num_epochs
+    - actions.training.batch_size
+    - actions.training.save_step
     outs:
-    - checkpoints
+    - checkpoints/actions
diff --git a/params.yaml b/params.yaml
index 0dbe4e4..f27696d 100644
--- a/params.yaml
+++ b/params.yaml
@@ -2,31 +2,32 @@ global:
     dashboard_port: 8787
     base_model: "bert-base-multilingual-cased"
 
-extraction:
-    num_partitions: 2_000
-    num_workers: 24
-    worker_memory_limit: "2GB"
+actions:
+    extraction:
+        num_partitions: 2_000
+        num_workers: 24
+        worker_memory_limit: "2GB"
 
-tokenization:
-    min_tokens: 10
-    max_tokens: 500
-    num_workers: 24
-    worker_memory_limit: "2GB"
+    tokenization:
+        min_tokens: 10
+        max_tokens: 500
+        num_workers: 24
+        worker_memory_limit: "2GB"
 
-exploding:
-    num_workers: 24
-    worker_memory_limit: "2GB"
+    exploding:
+        num_workers: 24
+        worker_memory_limit: "2GB"
 
-reindexing:
-    num_workers: 1
-    worker_memory_limit: "60GB"
+    reindexing:
+        num_workers: 1
+        worker_memory_limit: "60GB"
 
-training:
-    learning_rate: 0.0001
-    num_epochs: 5
-    batch_size: 2
-    save_step: 1000
-    max_training_time: "30s"
-    loss_averaging_span: 1000
-    fresh_start: false
-    device: "cuda:0"
\ No newline at end of file
+    training:
+        learning_rate: 0.0001
+        num_epochs: 5
+        batch_size: 2
+        save_step: 1000
+        max_training_time: "30s"
+        loss_averaging_span: 1000
+        fresh_start: false
+        device: "cuda:0"
\ No newline at end of file
diff --git a/scripts/dataset_generation/__init__.py b/scripts/actions_based/__init__.py
similarity index 100%
rename from scripts/dataset_generation/__init__.py
rename to scripts/actions_based/__init__.py
diff --git a/scripts/dataset_generation/stage1_extraction.py b/scripts/actions_based/stage1_extraction.py
similarity index 86%
rename from scripts/dataset_generation/stage1_extraction.py
rename to scripts/actions_based/stage1_extraction.py
index 79db141..0c44f14 100644
--- a/scripts/dataset_generation/stage1_extraction.py
+++ b/scripts/actions_based/stage1_extraction.py
@@ -14,7 +14,7 @@ import stackimpact
 import lorem
 from src.utils import get_config, PROJECT_ROOT
 
-GENERATED_FOLDER = "generated"
+GENERATED_FOLDER = "generated/actions"
 OUTPUT_FOLDER = f"{PROJECT_ROOT}/{GENERATED_FOLDER}/stage1_extraction"
 
 def process_file(x):
@@ -32,9 +32,9 @@ def process_file(x):
 if __name__ == "__main__":
 
     config = get_config()
-    num_partitions = config['extraction']['num_partitions']
-    num_workers = config['extraction']['num_workers']
-    memory_limit = config['extraction']['worker_memory_limit']
+    num_partitions = config['actions']['extraction']['num_partitions']
+    num_workers = config['actions']['extraction']['num_workers']
+    memory_limit = config['actions']['extraction']['worker_memory_limit']
 
     file_schema = "data/**/text_structure.xml"
     files_paths = glob.glob(file_schema, recursive=True)
diff --git a/scripts/dataset_generation/stage2_tokenization.py b/scripts/actions_based/stage2_tokenization.py
similarity index 83%
rename from scripts/dataset_generation/stage2_tokenization.py
rename to scripts/actions_based/stage2_tokenization.py
index ade0e48..ad9d64d 100644
--- a/scripts/dataset_generation/stage2_tokenization.py
+++ b/scripts/actions_based/stage2_tokenization.py
@@ -17,8 +17,8 @@ import dask.dataframe as dd
 from transformers import BertTokenizerFast
 from dask.distributed import Client
 
-INPUT_FOLDER = f"{PROJECT_ROOT}/generated/stage1_extraction"
-OUTPUT_FOLDER = f"{PROJECT_ROOT}/generated/stage2_tokenization"
+INPUT_FOLDER = f"{PROJECT_ROOT}/generated/actions/stage1_extraction"
+OUTPUT_FOLDER = f"{PROJECT_ROOT}/generated/actions/stage2_tokenization"
 
 def apply_tokenization(df, min_tokens: int, max_tokens: int, tokenizer: BertTokenizerFast):
     text_clean = df.input
@@ -57,10 +57,10 @@ RESULT_META = {
 if __name__ == "__main__":
 
     config = get_config()
-    max_tokens = config['tokenization']['max_tokens']
-    min_tokens = config['tokenization']['min_tokens']
-    num_workers = config['tokenization']['num_workers']
-    memory_limit = config['tokenization']['worker_memory_limit']
+    max_tokens = config['actions']['tokenization']['max_tokens']
+    min_tokens = config['actions']['tokenization']['min_tokens']
+    num_workers = config['actions']['tokenization']['num_workers']
+    memory_limit = config['actions']['tokenization']['worker_memory_limit']
     base_model = config['global']['base_model']
 
     client = Client(n_workers=num_workers, memory_limit=memory_limit)
diff --git a/scripts/dataset_generation/stage3_exploding.py b/scripts/actions_based/stage3_exploding.py
similarity index 87%
rename from scripts/dataset_generation/stage3_exploding.py
rename to scripts/actions_based/stage3_exploding.py
index 53c49d8..2aa1637 100644
--- a/scripts/dataset_generation/stage3_exploding.py
+++ b/scripts/actions_based/stage3_exploding.py
@@ -8,8 +8,8 @@ from dask.distributed import Client
 import pandas as pd
 from src.utils import PROJECT_ROOT, get_config
 
-INPUT_FOLDER = f"{PROJECT_ROOT}/generated/stage2_tokenization"
-OUTPUT_FOLDER = f"{PROJECT_ROOT}/generated/stage3_exploding"
+INPUT_FOLDER = f"{PROJECT_ROOT}/generated/actions/stage2_tokenization"
+OUTPUT_FOLDER = f"{PROJECT_ROOT}/generated/actions/stage3_exploding"
 
 def expand_dims(entry):
     inputs = entry.inputs.reshape(entry.input_shape)
@@ -52,8 +52,8 @@ RESULT_META = {
 
 if __name__ == "__main__":
     config = get_config()
-    num_workers = config['tokenization']['num_workers']
-    memory_limit = config['tokenization']['worker_memory_limit']
+    num_workers = config['actions']['exploding']['num_workers']
+    memory_limit = config['actions']['exploding']['worker_memory_limit']
 
     client = Client(n_workers=num_workers, memory_limit=memory_limit)
     print(client.dashboard_link)
diff --git a/scripts/dataset_generation/stage4_reindexing.py b/scripts/actions_based/stage4_reindexing.py
similarity index 73%
rename from scripts/dataset_generation/stage4_reindexing.py
rename to scripts/actions_based/stage4_reindexing.py
index 57182f3..fb4e6d4 100644
--- a/scripts/dataset_generation/stage4_reindexing.py
+++ b/scripts/actions_based/stage4_reindexing.py
@@ -9,13 +9,13 @@ from dask.distributed import Client
 import pandas as pd
 from src.utils import PROJECT_ROOT, get_config
 
-INPUT_FOLDER = f"{PROJECT_ROOT}/generated/stage3_exploding"
-OUTPUT_FOLDER = f"{PROJECT_ROOT}/generated/stage4_reindexing"
+INPUT_FOLDER = f"{PROJECT_ROOT}/generated/actions/stage3_exploding"
+OUTPUT_FOLDER = f"{PROJECT_ROOT}/generated/actions/stage4_reindexing"
 
 if __name__ == "__main__":
     config = get_config()
-    num_workers = config['tokenization']['num_workers']
-    memory_limit = config['tokenization']['worker_memory_limit']
+    num_workers = config['actions']['reindexing']['num_workers']
+    memory_limit = config['actions']['reindexing']['worker_memory_limit']
 
     client = Client(n_workers=num_workers, memory_limit=memory_limit)
     print(client.dashboard_link)
diff --git a/scripts/train.py b/scripts/actions_based/train.py
similarity index 88%
rename from scripts/train.py
rename to scripts/actions_based/train.py
index d3216f8..67a4f4e 100755
--- a/scripts/train.py
+++ b/scripts/actions_based/train.py
@@ -12,20 +12,20 @@ from src.utils import PROJECT_ROOT, get_config, convert_to_timedelta
 from src.processing import ACTIONS_KEYS
 from datetime import datetime
 
-INPUT_PATH = f"{PROJECT_ROOT}/generated/stage4_reindexing"
-OUTPUT_PATH = f"{PROJECT_ROOT}/checkpoints"
+INPUT_PATH = f"{PROJECT_ROOT}/generated/actions/stage4_reindexing"
+OUTPUT_PATH = f"{PROJECT_ROOT}/checkpoints/actions"
 
 if __name__ == "__main__":
     config = get_config()
-    learning_rate = config['training']['learning_rate']
-    num_epochs = config['training']['num_epochs']
-    batch_size = config['training']['batch_size']
-    save_step = config['training']['save_step']
-    loss_averaging_span = config['training']['loss_averaging_span']
-    fresh_start = config['training']['fresh_start']
-    device_name = config['training']['device']
+    learning_rate = config['actions']['training']['learning_rate']
+    num_epochs = config['actions']['training']['num_epochs']
+    batch_size = config['actions']['training']['batch_size']
+    save_step = config['actions']['training']['save_step']
+    loss_averaging_span = config['actions']['training']['loss_averaging_span']
+    fresh_start = config['actions']['training']['fresh_start']
+    device_name = config['actions']['training']['device']
+    max_train_time = config['actions']['training']['max_training_time']
     base_model = config['global']['base_model']
-    max_train_time = config['training']['max_training_time']
 
     if max_train_time is not None:
         max_train_time = convert_to_timedelta(max_train_time)
diff --git a/scripts/dataset_generation/.gitignore b/scripts/dataset_generation/.gitignore
deleted file mode 100644
index ed8ebf5..0000000
--- a/scripts/dataset_generation/.gitignore
+++ /dev/null
@@ -1 +0,0 @@
-__pycache__
\ No newline at end of file
diff --git a/scripts/translation_based/__init__.py b/scripts/translation_based/__init__.py
new file mode 100644
index 0000000..e69de29
-- 
GitLab


From 373616d5dab018b800cb1bf14f492d6cd2c047da Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Thu, 30 Jul 2020 15:32:28 +0200
Subject: [PATCH 022/116] Made sure all folders will exist when needed

---
 .devcontainer/devcontainer.json              |  9 ++++++---
 scripts/actions_based/stage1_extraction.py   |  8 +++++---
 scripts/actions_based/stage2_tokenization.py |  4 +++-
 scripts/actions_based/stage3_exploding.py    |  4 +++-
 scripts/actions_based/stage4_reindexing.py   |  4 +++-
 scripts/actions_based/train.py               |  7 +++----
 src/utils.py                                 | 15 +++++++++++++++
 7 files changed, 38 insertions(+), 13 deletions(-)

diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json
index 73ee0b8..c0f5fbc 100644
--- a/.devcontainer/devcontainer.json
+++ b/.devcontainer/devcontainer.json
@@ -27,9 +27,12 @@
 	],
 
 	"mounts": [
-		"source=${localEnv:HOME}/.gitconfig,target=/root/.gitconfig,type=bind",
-		"source=${localEnv:HOME}/.aws,target=/root/.aws,type=bind"
-	]
+		"source=/home/mpogoda/.gitconfig,target=/root/.gitconfig,type=bind",
+		"source=/home/mpogoda/.aws,target=/root/.aws,type=bind"
+	],
+
+	"workspaceMount": "source=/home/mpogoda/mnt/interpunkcja,target=/workspace,type=bind,consistency=cached",
+	"workspaceFolder": "/workspace",
 
 	// Uncomment the next line to run commands after the container is created - for example installing curl.
 	// "postCreateCommand": "apt-get update && apt-get install -y curl",
diff --git a/scripts/actions_based/stage1_extraction.py b/scripts/actions_based/stage1_extraction.py
index 0c44f14..171051d 100644
--- a/scripts/actions_based/stage1_extraction.py
+++ b/scripts/actions_based/stage1_extraction.py
@@ -12,10 +12,10 @@ from memory_profiler import profile
 from pympler import muppy, summary
 import stackimpact
 import lorem
-from src.utils import get_config, PROJECT_ROOT
+from src.utils import get_config, PROJECT_ROOT, prepare_folder
 
-GENERATED_FOLDER = "generated/actions"
-OUTPUT_FOLDER = f"{PROJECT_ROOT}/{GENERATED_FOLDER}/stage1_extraction"
+INPUT_FOLDER = f"{PROJECT_ROOT}/data"
+OUTPUT_FOLDER = f"{PROJECT_ROOT}/generated/actions/stage1_extraction"
 
 def process_file(x):
     full_text = text_from_xml(x.file)
@@ -36,6 +36,8 @@ if __name__ == "__main__":
     num_workers = config['actions']['extraction']['num_workers']
     memory_limit = config['actions']['extraction']['worker_memory_limit']
 
+    prepare_folder(OUTPUT_FOLDER)
+
     file_schema = "data/**/text_structure.xml"
     files_paths = glob.glob(file_schema, recursive=True)
 
diff --git a/scripts/actions_based/stage2_tokenization.py b/scripts/actions_based/stage2_tokenization.py
index ad9d64d..a6cb763 100644
--- a/scripts/actions_based/stage2_tokenization.py
+++ b/scripts/actions_based/stage2_tokenization.py
@@ -10,7 +10,7 @@ import re
 import numpy as np
 from tqdm import tqdm
 from src.processing import tokenize_labeled_text, batchify_data
-from src.utils import remove_multiple_spaces, remove_punctuation, PROJECT_ROOT, get_config
+from src.utils import remove_multiple_spaces, remove_punctuation, PROJECT_ROOT, get_config, prepare_folder
 import dask
 from dask.diagnostics import ProgressBar
 import dask.dataframe as dd
@@ -63,6 +63,8 @@ if __name__ == "__main__":
     memory_limit = config['actions']['tokenization']['worker_memory_limit']
     base_model = config['global']['base_model']
 
+    prepare_folder(OUTPUT_FOLDER)
+
     client = Client(n_workers=num_workers, memory_limit=memory_limit)
     print(client.dashboard_link)
 
diff --git a/scripts/actions_based/stage3_exploding.py b/scripts/actions_based/stage3_exploding.py
index 2aa1637..7c0b0d4 100644
--- a/scripts/actions_based/stage3_exploding.py
+++ b/scripts/actions_based/stage3_exploding.py
@@ -6,7 +6,7 @@ import numpy as np
 import dask
 from dask.distributed import Client
 import pandas as pd
-from src.utils import PROJECT_ROOT, get_config
+from src.utils import PROJECT_ROOT, get_config, prepare_folder
 
 INPUT_FOLDER = f"{PROJECT_ROOT}/generated/actions/stage2_tokenization"
 OUTPUT_FOLDER = f"{PROJECT_ROOT}/generated/actions/stage3_exploding"
@@ -55,6 +55,8 @@ if __name__ == "__main__":
     num_workers = config['actions']['exploding']['num_workers']
     memory_limit = config['actions']['exploding']['worker_memory_limit']
 
+    prepare_folder(OUTPUT_FOLDER)
+
     client = Client(n_workers=num_workers, memory_limit=memory_limit)
     print(client.dashboard_link)
 
diff --git a/scripts/actions_based/stage4_reindexing.py b/scripts/actions_based/stage4_reindexing.py
index fb4e6d4..340b480 100644
--- a/scripts/actions_based/stage4_reindexing.py
+++ b/scripts/actions_based/stage4_reindexing.py
@@ -7,7 +7,7 @@ import numpy as np
 import dask
 from dask.distributed import Client
 import pandas as pd
-from src.utils import PROJECT_ROOT, get_config
+from src.utils import PROJECT_ROOT, get_config, prepare_folder
 
 INPUT_FOLDER = f"{PROJECT_ROOT}/generated/actions/stage3_exploding"
 OUTPUT_FOLDER = f"{PROJECT_ROOT}/generated/actions/stage4_reindexing"
@@ -17,6 +17,8 @@ if __name__ == "__main__":
     num_workers = config['actions']['reindexing']['num_workers']
     memory_limit = config['actions']['reindexing']['worker_memory_limit']
 
+    prepare_folder(OUTPUT_FOLDER)
+
     client = Client(n_workers=num_workers, memory_limit=memory_limit)
     print(client.dashboard_link)
 
diff --git a/scripts/actions_based/train.py b/scripts/actions_based/train.py
index 67a4f4e..249641c 100755
--- a/scripts/actions_based/train.py
+++ b/scripts/actions_based/train.py
@@ -8,7 +8,7 @@ import numpy as np
 import dask.dataframe as dd
 import os
 import glob
-from src.utils import PROJECT_ROOT, get_config, convert_to_timedelta
+from src.utils import PROJECT_ROOT, get_config, convert_to_timedelta, prepare_folder
 from src.processing import ACTIONS_KEYS
 from datetime import datetime
 
@@ -27,6 +27,8 @@ if __name__ == "__main__":
     max_train_time = config['actions']['training']['max_training_time']
     base_model = config['global']['base_model']
 
+    prepare_folder(OUTPUT_PATH)
+
     if max_train_time is not None:
         max_train_time = convert_to_timedelta(max_train_time)
 
@@ -34,9 +36,6 @@ if __name__ == "__main__":
     print(f"Training on {device}")
 
     df = dd.read_parquet(INPUT_PATH, engine="pyarrow")
-
-    if not os.path.exists(f"{OUTPUT_PATH}"):
-        os.makedirs(f"{OUTPUT_PATH}", exist_ok=True)
     
     tokenizer = BertTokenizerFast.from_pretrained(base_model)
 
diff --git a/src/utils.py b/src/utils.py
index 5dd9fc7..bf36f03 100644
--- a/src/utils.py
+++ b/src/utils.py
@@ -3,6 +3,7 @@ import re
 import os
 from datetime import timedelta
 from typing import Optional
+import shutil
 
 PROJECT_ROOT=os.path.dirname(os.path.realpath("/".join(__file__.split("/")) + "/.."))
 
@@ -42,6 +43,20 @@ def remove_punctuation(text: str) -> str:
     """
     return ''.join(filter(lambda x: x.isalnum() or x.isspace(), text))
 
+def prepare_folder(path: str, wipe: bool = False) -> None:
+    """Function make sure that provided path exists. Can aditionaly
+    remove all files from the path.
+
+    Args:
+        path (str): Full directory path
+        wipe (bool): Wheter to remove all files in folder
+    """
+
+    if wipe:
+        shutil.rmtree(path, ignore_errors=True)  # don't fail if the folder does not exist yet
+
+    os.makedirs(path, exist_ok=True)
+
 def convert_to_timedelta(time_val: str) -> Optional[timedelta]:
     """
     src: https://code.activestate.com/recipes/577894-convert-strings-like-5d-and-60s-to-timedelta-objec/
-- 
GitLab


From 093d3a76e3c4ec198c690dc980fbfc0576e0cfec Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Thu, 30 Jul 2020 16:02:51 +0200
Subject: [PATCH 023/116] Extraction stage for translation-based model

---
 dvc.lock                          | 11 +++++++++++
 dvc.yaml                          |  8 ++++++++
 generated/translations/.gitignore |  1 +
 3 files changed, 20 insertions(+)
 create mode 100644 generated/translations/.gitignore

diff --git a/dvc.lock b/dvc.lock
index 23b393d..f53e202 100644
--- a/dvc.lock
+++ b/dvc.lock
@@ -45,3 +45,14 @@ reindexing:
   outs:
   - path: generated/stage4_reindexing
     md5: 9e797430fe072a60e778e191606a9952.dir
+translations_extraction:
+  cmd: python3 -m scripts.translation_based.stage1_extraction
+  deps:
+  - path: data
+    md5: 1fa175e752af1638dc896838e82a9d7d.dir
+  params:
+    params.yaml:
+      translations.extraction.num_partitions: 2000
+  outs:
+  - path: generated/translations/stage1_extraction
+    md5: 61a1a88c672e485fd9b0dc0ef22817a9.dir
diff --git a/dvc.yaml b/dvc.yaml
index 32ee641..a9ea952 100644
--- a/dvc.yaml
+++ b/dvc.yaml
@@ -47,3 +47,11 @@ stages:
     - actions.training.save_step
     outs:
     - checkpoints/actions
+  translations_extraction:
+    cmd: python3 -m scripts.translation_based.stage1_extraction
+    deps:
+    - data
+    params:
+    - translations.extraction.num_partitions
+    outs:
+    - generated/translations/stage1_extraction
diff --git a/generated/translations/.gitignore b/generated/translations/.gitignore
new file mode 100644
index 0000000..f94a83f
--- /dev/null
+++ b/generated/translations/.gitignore
@@ -0,0 +1 @@
+/stage1_extraction
-- 
GitLab


From 7c87f3a37aa826322c57ec84a0cc70767c0cd7b8 Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Thu, 30 Jul 2020 16:05:27 +0200
Subject: [PATCH 024/116] Extraction stage for translation-based model

---
 generated/actions/.gitignore                  |  4 ++
 params.yaml                                   |  7 +++-
 scripts/translation_based/processing.py       | 24 ++++++++++++
 .../translation_based/stage1_extraction.py    | 38 +++++++++++++++++++
 4 files changed, 72 insertions(+), 1 deletion(-)
 create mode 100644 generated/actions/.gitignore
 create mode 100644 scripts/translation_based/processing.py
 create mode 100644 scripts/translation_based/stage1_extraction.py

diff --git a/generated/actions/.gitignore b/generated/actions/.gitignore
new file mode 100644
index 0000000..959c3a4
--- /dev/null
+++ b/generated/actions/.gitignore
@@ -0,0 +1,4 @@
+/stage1_extraction
+/stage2_tokenization
+/stage3_exploding
+/stage4_reindexing
diff --git a/params.yaml b/params.yaml
index f27696d..40db37f 100644
--- a/params.yaml
+++ b/params.yaml
@@ -30,4 +30,9 @@ actions:
         max_training_time: "30s"
         loss_averaging_span: 1000
         fresh_start: false
-        device: "cuda:0"
\ No newline at end of file
+        device: "cuda:0"
+translations:
+    extraction:
+        num_partitions: 2_000
+        num_workers: 24
+        worker_memory_limit: "2GB"
\ No newline at end of file
diff --git a/scripts/translation_based/processing.py b/scripts/translation_based/processing.py
new file mode 100644
index 0000000..e4527b9
--- /dev/null
+++ b/scripts/translation_based/processing.py
@@ -0,0 +1,24 @@
+import dask.dataframe as dd
+from src.processing import text_from_xml
+
+def raw_to_dataframe(x: dd.DataFrame):
+    """Converts dask datarfame containing files paths into
+    dataframe with content of that files (text only)
+
+    Args:
+        x (DataFrame): Dask dataframe with one column (file)
+
+    Returns:
+        DataFrame: Dask dataframe with format {input: str}. Can have null entries
+    """
+    full_text = text_from_xml(x.file)
+
+    if len(full_text) > 0:
+        return {'input': full_text}
+    else:
+        return {'input': None}
+
+RAW_TO_DATAFRAME_META = {
+    'input': str
+}
+
diff --git a/scripts/translation_based/stage1_extraction.py b/scripts/translation_based/stage1_extraction.py
new file mode 100644
index 0000000..379b959
--- /dev/null
+++ b/scripts/translation_based/stage1_extraction.py
@@ -0,0 +1,38 @@
+#!/usr/bin/python3
+from scripts.translation_based.processing import raw_to_dataframe, RAW_TO_DATAFRAME_META
+from src.utils import PROJECT_ROOT, prepare_folder, get_config
+from glob import glob
+import numpy as np
+from dask.distributed import Client
+import dask.dataframe as dd
+import pandas as pd
+
+INPUT_FOLDER = f"{PROJECT_ROOT}/data"
+OUTPUT_FOLDER = f"{PROJECT_ROOT}/generated/translations/stage1_extraction"
+
+if __name__ == "__main__":
+
+    config = get_config()
+    num_partitions = config['translations']['extraction']['num_partitions']
+    num_workers = config['translations']['extraction']['num_workers']
+    memory_limit = config['translations']['extraction']['worker_memory_limit']
+
+    prepare_folder(OUTPUT_FOLDER)
+
+    file_schema = "data/**/text_structure.xml"
+    files_paths = glob(file_schema, recursive=True)
+
+    # Make sure python memory fragmentation won't go insane
+    np.random.shuffle(files_paths)
+
+    client = Client(n_workers=num_workers, memory_limit=memory_limit)
+    print(f"Dashboard: {client.dashboard_link}")
+
+    # Processing pipeline
+    df = dd.from_pandas(pd.DataFrame({'file': files_paths}), npartitions=num_partitions)
+
+    df = df.apply(raw_to_dataframe, result_type='expand', axis=1, meta=RAW_TO_DATAFRAME_META)
+    df = df.dropna()
+
+    # Export
+    df.to_parquet(OUTPUT_FOLDER, engine="pyarrow")
\ No newline at end of file
-- 
GitLab


From 3ea80e9348b8ddf15576af807112d9fe49a689b0 Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Sat, 1 Aug 2020 15:50:33 +0200
Subject: [PATCH 025/116] Added randomization, first stages of translation
 pipeline

---
 dvc.lock                                      | 11 +++
 dvc.yaml                                      |  8 ++
 generated/translations/.gitignore             |  1 +
 notebooks/dask_dataframe_exploration.ipynb    | 69 +++++++++++++
 notebooks/dask_functionality_test.ipynb       | 81 ++++++++++++++++
 notebooks/tokenizer_testing.ipynb             | 96 +++++++++++++++++++
 params.yaml                                   |  4 +
 scripts/actions_based/stage4_reindexing.py    | 12 ++-
 scripts/translation_based/processing.py       | 50 +++++++++-
 .../translation_based/stage1_extraction.py    |  4 +-
 .../translation_based/stage2_tokenization.py  | 32 +++++++
 .../translation_based/stage3_batchifying.py   | 32 +++++++
 12 files changed, 391 insertions(+), 9 deletions(-)
 create mode 100644 notebooks/dask_dataframe_exploration.ipynb
 create mode 100644 notebooks/dask_functionality_test.ipynb
 create mode 100644 notebooks/tokenizer_testing.ipynb
 create mode 100644 scripts/translation_based/stage2_tokenization.py
 create mode 100644 scripts/translation_based/stage3_batchifying.py

diff --git a/dvc.lock b/dvc.lock
index f53e202..acd48b5 100644
--- a/dvc.lock
+++ b/dvc.lock
@@ -56,3 +56,14 @@ translations_extraction:
   outs:
   - path: generated/translations/stage1_extraction
     md5: 61a1a88c672e485fd9b0dc0ef22817a9.dir
+translations_tokenization:
+  cmd: python3 -m scripts.translation_based.stage2_tokenization
+  deps:
+  - path: generated/translations/stage1_extraction
+    md5: 61a1a88c672e485fd9b0dc0ef22817a9.dir
+  params:
+    params.yaml:
+      global.base_model: bert-base-multilingual-cased
+  outs:
+  - path: generated/translations/stage2_tokenization
+    md5: b4132fb48d63c09ee5fd5e017f5c279c.dir
diff --git a/dvc.yaml b/dvc.yaml
index a9ea952..3d51214 100644
--- a/dvc.yaml
+++ b/dvc.yaml
@@ -55,3 +55,11 @@ stages:
     - translations.extraction.num_partitions
     outs:
     - generated/translations/stage1_extraction
+  translations_tokenization:
+    cmd: python3 -m scripts.translation_based.stage2_tokenization
+    deps:
+    - generated/translations/stage1_extraction
+    params:
+    - global.base_model
+    outs:
+    - generated/translations/stage2_tokenization
diff --git a/generated/translations/.gitignore b/generated/translations/.gitignore
index f94a83f..0da25ba 100644
--- a/generated/translations/.gitignore
+++ b/generated/translations/.gitignore
@@ -1 +1,2 @@
 /stage1_extraction
+/stage2_tokenization
diff --git a/notebooks/dask_dataframe_exploration.ipynb b/notebooks/dask_dataframe_exploration.ipynb
new file mode 100644
index 0000000..568f632
--- /dev/null
+++ b/notebooks/dask_dataframe_exploration.ipynb
@@ -0,0 +1,69 @@
+{
+ "metadata": {
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.8.2-final"
+  },
+  "orig_nbformat": 2,
+  "kernelspec": {
+   "name": "python38264bita7d7da14168440cb9836372958035d4a",
+   "display_name": "Python 3.8.2 64-bit"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2,
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import dask.dataframe as dd"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "df = dd.read_parquet(\"../generated/translations/stage1_extraction\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [
+    {
+     "output_type": "execute_result",
+     "data": {
+      "text/plain": "                                                input\n15   Otwieram posiedzenie Komisji. Dziś odbyło się...\n18   Wznawiam posiedzenie. Na sekretarzy powołuję ...\n21   Otwieram posiedzenie. Protokół 88 posiedzenia...\n25   Proszę państwa, otwieram wspólne posiedzenie ...\n29   Otwieram posiedzenie Komisji Budżetu i Finans...",
+      "text/html": "<div>\n<style scoped>\n    .dataframe tbody tr th:only-of-type {\n        vertical-align: middle;\n    }\n\n    .dataframe tbody tr th {\n        vertical-align: top;\n    }\n\n    .dataframe thead th {\n        text-align: right;\n    }\n</style>\n<table border=\"1\" class=\"dataframe\">\n  <thead>\n    <tr style=\"text-align: right;\">\n      <th></th>\n      <th>input</th>\n    </tr>\n  </thead>\n  <tbody>\n    <tr>\n      <th>15</th>\n      <td>Otwieram posiedzenie Komisji. Dziś odbyło się...</td>\n    </tr>\n    <tr>\n      <th>18</th>\n      <td>Wznawiam posiedzenie. Na sekretarzy powołuję ...</td>\n    </tr>\n    <tr>\n      <th>21</th>\n      <td>Otwieram posiedzenie. Protokół 88 posiedzenia...</td>\n    </tr>\n    <tr>\n      <th>25</th>\n      <td>Proszę państwa, otwieram wspólne posiedzenie ...</td>\n    </tr>\n    <tr>\n      <th>29</th>\n      <td>Otwieram posiedzenie Komisji Budżetu i Finans...</td>\n    </tr>\n  </tbody>\n</table>\n</div>"
+     },
+     "metadata": {},
+     "execution_count": 3
+    }
+   ],
+   "source": [
+    "df.head()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ]
+}
\ No newline at end of file
diff --git a/notebooks/dask_functionality_test.ipynb b/notebooks/dask_functionality_test.ipynb
new file mode 100644
index 0000000..ee74d31
--- /dev/null
+++ b/notebooks/dask_functionality_test.ipynb
@@ -0,0 +1,81 @@
+{
+ "metadata": {
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.8.2-final"
+  },
+  "orig_nbformat": 2,
+  "kernelspec": {
+   "name": "python38264bita7d7da14168440cb9836372958035d4a",
+   "display_name": "Python 3.8.2 64-bit"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2,
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 19,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import dask\n",
+    "import dask.dataframe as dd\n",
+    "import pandas as pd\n",
+    "import numpy as np"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 24,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "pdf = pd.DataFrame({'x': [1,2,3, 4, 5], 'y': ['a', 'b', 'c', 'd', 'e']})"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 25,
+   "metadata": {},
+   "outputs": [
+    {
+     "output_type": "execute_result",
+     "data": {
+      "text/plain": "   x  y  ones\n0  1  a     1\n1  2  b     2\n2  3  c     3\n3  4  d     4\n4  5  e     5",
+      "text/html": "<div>\n<style scoped>\n    .dataframe tbody tr th:only-of-type {\n        vertical-align: middle;\n    }\n\n    .dataframe tbody tr th {\n        vertical-align: top;\n    }\n\n    .dataframe thead th {\n        text-align: right;\n    }\n</style>\n<table border=\"1\" class=\"dataframe\">\n  <thead>\n    <tr style=\"text-align: right;\">\n      <th></th>\n      <th>x</th>\n      <th>y</th>\n      <th>ones</th>\n    </tr>\n  </thead>\n  <tbody>\n    <tr>\n      <th>0</th>\n      <td>1</td>\n      <td>a</td>\n      <td>1</td>\n    </tr>\n    <tr>\n      <th>1</th>\n      <td>2</td>\n      <td>b</td>\n      <td>2</td>\n    </tr>\n    <tr>\n      <th>2</th>\n      <td>3</td>\n      <td>c</td>\n      <td>3</td>\n    </tr>\n    <tr>\n      <th>3</th>\n      <td>4</td>\n      <td>d</td>\n      <td>4</td>\n    </tr>\n    <tr>\n      <th>4</th>\n      <td>5</td>\n      <td>e</td>\n      <td>5</td>\n    </tr>\n  </tbody>\n</table>\n</div>"
+     },
+     "metadata": {},
+     "execution_count": 25
+    }
+   ],
+   "source": [
+    "df = dd.from_pandas(pdf, npartitions=2)\n",
+    "df = df.assign(ones=1)\n",
+    "df.ones = df.ones.cumsum()\n",
+    "\n",
+    "order_indexes == df.ones.compute()\n",
+    "random_indexes = df.ones.compute()\n",
+    "np.random.shuffle(random_indexes)\n",
+    "mapping = \n",
+    "\n",
+    "df.compute()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ]
+}
\ No newline at end of file
diff --git a/notebooks/tokenizer_testing.ipynb b/notebooks/tokenizer_testing.ipynb
new file mode 100644
index 0000000..c4b66f6
--- /dev/null
+++ b/notebooks/tokenizer_testing.ipynb
@@ -0,0 +1,96 @@
+{
+ "metadata": {
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.8.2-final"
+  },
+  "orig_nbformat": 2,
+  "kernelspec": {
+   "name": "python38264bita7d7da14168440cb9836372958035d4a",
+   "display_name": "Python 3.8.2 64-bit"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2,
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from transformers import BertTokenizerFast"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "base_model = \"bert-base-multilingual-cased\"\n",
+    "tokenizer = BertTokenizerFast.from_pretrained(base_model)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {},
+   "outputs": [
+    {
+     "output_type": "execute_result",
+     "data": {
+      "text/plain": "{'input_ids': [101, 56500, 10824, 16469, 177, 39327, 59726, 10132, 348, 11335, 68497, 119, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}"
+     },
+     "metadata": {},
+     "execution_count": 7
+    }
+   ],
+   "source": [
+    "tokenizer(\"Ala ma kota i poszła na śniadanie.\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "dot_token = tokenizer(\".\")['input_ids'][1]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "metadata": {},
+   "outputs": [
+    {
+     "output_type": "execute_result",
+     "data": {
+      "text/plain": "119"
+     },
+     "metadata": {},
+     "execution_count": 10
+    }
+   ],
+   "source": [
+    "dot_token"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ]
+}
\ No newline at end of file
diff --git a/params.yaml b/params.yaml
index 40db37f..d64c249 100644
--- a/params.yaml
+++ b/params.yaml
@@ -34,5 +34,9 @@ actions:
 translations:
     extraction:
         num_partitions: 2_000
+        num_workers: 24
+        worker_memory_limit: "2GB"
+
+    tokenization:
         num_workers: 24
         worker_memory_limit: "2GB"
\ No newline at end of file
diff --git a/scripts/actions_based/stage4_reindexing.py b/scripts/actions_based/stage4_reindexing.py
index 340b480..b436ca2 100644
--- a/scripts/actions_based/stage4_reindexing.py
+++ b/scripts/actions_based/stage4_reindexing.py
@@ -24,11 +24,17 @@ if __name__ == "__main__":
 
     df = dd.read_parquet(INPUT_FOLDER, engine='pyarrow')
 
+    # Add ordered indexes
     df = df.assign(ones=1)
     df = df.reset_index(drop=True)
-    idx = df.ones.cumsum().persist()
+    idx = (df.ones.cumsum() - 1).persist()
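+    # cumsum starts at 1, so subtract 1 to get 0-based ordered indexes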
     df = df.assign(ones=idx)
-    #df = df.assign(idx=df.idx - 1)
+
+    # Shuffle: randomly permute the ordered indexes
+    shuffled_idx = np.random.permutation(idx.compute().values)
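+    # Scatter the index mapping so workers share a single copy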
+    shuffled_idx = client.scatter(shuffled_idx)
+    mapped_ones = df.ones.apply(lambda x, idx: idx[x], args=(shuffled_idx,), meta=('ones', 'int64')).persist()
+    df = df.assign(ones=mapped_ones)
+
     df = df.set_index('ones')
- 
     df.to_parquet(OUTPUT_FOLDER, engine='pyarrow')
diff --git a/scripts/translation_based/processing.py b/scripts/translation_based/processing.py
index e4527b9..633e902 100644
--- a/scripts/translation_based/processing.py
+++ b/scripts/translation_based/processing.py
@@ -1,17 +1,19 @@
 import dask.dataframe as dd
 from src.processing import text_from_xml
+from transformers import BertTokenizerFast
+import numpy as np
 
-def raw_to_dataframe(x: dd.DataFrame):
+def raw_to_dataframe(entry: dict) -> dict:
     """Converts dask datarfame containing files paths into
     dataframe with content of that files (text only)
 
     Args:
-        x (DataFrame): Dask dataframe with one column (file)
+        entry (dict): Dask dataframe entry with one column ('file')
 
     Returns:
-        DataFrame: Dask dataframe with format {input: str}. Can have null entries
+        dict: Dask dataframe entry with format {'input': str}. Can have null entries
     """
-    full_text = text_from_xml(x.file)
+    full_text = text_from_xml(entry.file)
 
     if len(full_text) > 0:
         return {'input': full_text}
@@ -22,3 +24,43 @@ RAW_TO_DATAFRAME_META = {
     'input': str
 }
 
+def apply_tokenization(entry: dict, tokenizer: BertTokenizerFast) -> dict:
+    """Converts raw text entries into list of tokens
+
+    Args:
+        x (dict): Dask dataframe entry with one column ['input'] containing text
+        tokenizer (BertTokenizerFast): Tokenizer used to tokenize. Must be a deleayed object to prevent memory leak!
+
+    Returns:
+        dict: Dask dataset entry with one column ('tokens') containing an np.array of token ids
+    """
+    text_tokenized = tokenizer(entry.input)['input_ids'][1:-1]
+
+    return {
+        'tokens': np.array(text_tokenized)
+    }
+
+APPLY_TOKENIZATION_META = {
+    'tokens': object
+}
+
+def split_into_batches(entry: dict, stopping_token: int) -> dict:
+    """Splits a tokenized document into batches that end on a stopping token
+
+    Args:
+        entry (dict): Dask dataframe entry with one column ('tokens') containing an np.array of token ids
+        stopping_token (int): Id of the token on which a batch may end (eg. a dot)
+
+    Returns:
+        dict: Dask dataset entry with one column ('batches') containing a list of np.array token batches
+    """
+    tokens = entry.tokens
+
+    # Split right after every occurrence of the stopping token, dropping empty leftovers
+    split_points = np.where(tokens == stopping_token)[0] + 1
+    batches = [batch for batch in np.split(tokens, split_points) if len(batch) > 0]
+
+    return {
+        'batches': batches
+    }
+
+SPLIT_INTO_BATCHES_META = {
+    'batches': object
+}
+
diff --git a/scripts/translation_based/stage1_extraction.py b/scripts/translation_based/stage1_extraction.py
index 379b959..159c7c8 100644
--- a/scripts/translation_based/stage1_extraction.py
+++ b/scripts/translation_based/stage1_extraction.py
@@ -19,8 +19,8 @@ if __name__ == "__main__":
 
     prepare_folder(OUTPUT_FOLDER)
 
-    file_schema = "data/**/text_structure.xml"
-    files_paths = glob(file_schema, recursive=True)
+    file_schema = f"{INPUT_FOLDER}/**/text_structure.xml"
+    files_paths = glob(file_schema, recursive=True) 
 
     # Make sure python memory fragmentation won't go insane
     np.random.shuffle(files_paths)
diff --git a/scripts/translation_based/stage2_tokenization.py b/scripts/translation_based/stage2_tokenization.py
new file mode 100644
index 0000000..2768d2c
--- /dev/null
+++ b/scripts/translation_based/stage2_tokenization.py
@@ -0,0 +1,32 @@
+#!/usr/bin/python3
+from scripts.translation_based.processing import apply_tokenization, APPLY_TOKENIZATION_META
+from src.utils import PROJECT_ROOT, prepare_folder, get_config
+import numpy as np
+from dask.distributed import Client
+from transformers import BertTokenizerFast
+import dask.dataframe as dd
+from dask import delayed
+
+INPUT_FOLDER = f"{PROJECT_ROOT}/generated/translations/stage1_extraction"
+OUTPUT_FOLDER = f"{PROJECT_ROOT}/generated/translations/stage2_tokenization"
+
+if __name__ == "__main__":
+
+    config = get_config()
+    num_workers = config['translations']['tokenization']['num_workers']
+    memory_limit = config['translations']['tokenization']['worker_memory_limit']
+    base_model = config['global']['base_model']
+
+    prepare_folder(OUTPUT_FOLDER)
+
+    client = Client(n_workers=num_workers, memory_limit=memory_limit)
+    print(f"Dashboard: {client.dashboard_link}")
+
+    tokenizer = BertTokenizerFast.from_pretrained(base_model)
+    tokenizer = delayed(tokenizer)
+
+    df = dd.read_parquet(INPUT_FOLDER, engine="pyarrow")
+    df = df.apply(apply_tokenization, result_type='expand', axis=1, meta=APPLY_TOKENIZATION_META, args=(tokenizer,))
+
+    # Export
+    df.to_parquet(OUTPUT_FOLDER, engine="pyarrow")
\ No newline at end of file
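
As a quick sanity check of what each row holds after this stage -- a sketch using the
multilingual base model referenced elsewhere in the repo (the configured base_model behaves
the same way):

# Sketch: apply_tokenization stores token ids with [CLS]/[SEP] stripped.
from transformers import BertTokenizerFast
import numpy as np

tokenizer = BertTokenizerFast.from_pretrained("bert-base-multilingual-cased")

ids = tokenizer("Ala ma kota.")["input_ids"]   # [CLS] ... [SEP]
tokens = np.array(ids[1:-1])                   # same slicing as apply_tokenization
print(tokens)
print(tokenizer.decode(tokens))                # decodes back to the input text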
diff --git a/scripts/translation_based/stage3_batchifying.py b/scripts/translation_based/stage3_batchifying.py
new file mode 100644
index 0000000..2768d2c
--- /dev/null
+++ b/scripts/translation_based/stage3_batchifying.py
@@ -0,0 +1,32 @@
+#!/usr/bin/python3
+from scripts.translation_based.processing import apply_tokenization, APPLY_TOKENIZATION_META
+from src.utils import PROJECT_ROOT, prepare_folder, get_config
+import numpy as np
+from dask.distributed import Client
+from transformers import BertTokenizerFast
+import dask.dataframe as dd
+from dask import delayed
+
+INPUT_FOLDER = f"{PROJECT_ROOT}/generated/translations/stage1_extraction"
+OUTPUT_FOLDER = f"{PROJECT_ROOT}/generated/translations/stage2_tokenization"
+
+if __name__ == "__main__":
+
+    config = get_config()
+    num_workers = config['translations']['tokenization']['num_workers']
+    memory_limit = config['translations']['tokenization']['worker_memory_limit']
+    base_model = config['global']['base_model']
+
+    prepare_folder(OUTPUT_FOLDER)
+
+    client = Client(n_workers=num_workers, memory_limit=memory_limit)
+    print(f"Dashboard: {client.dashboard_link}")
+
+    tokenizer = BertTokenizerFast.from_pretrained(base_model)
+    tokenizer = delayed(tokenizer)
+
+    df = dd.read_parquet(INPUT_FOLDER, engine="pyarrow")
+    df = df.apply(apply_tokenization, result_type='expand', axis=1, meta=APPLY_TOKENIZATION_META, args=(tokenizer,))
+
+    # Export
+    df.to_parquet(OUTPUT_FOLDER, engine="pyarrow")
\ No newline at end of file
-- 
GitLab


From b759bfc86a5001a1c85266b0737c0f1629dc509e Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Sat, 1 Aug 2020 18:09:23 +0000
Subject: [PATCH 026/116] Added missing lxml dependency in dockerfile

---
 docker/Dockerfile | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docker/Dockerfile b/docker/Dockerfile
index c702b69..5672adf 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -37,5 +37,5 @@ ENV NVIDIA_REQUIRE_CUDA "cuda>=10.2 brand=tesla,driver>=384,driver<385 brand=tes
 
 ### END CUDA Installation
 
-RUN pip3 install numpy pandas tqdm seaborn torch dask[complete] transformers pyarrow pytest
+RUN pip3 install numpy pandas tqdm seaborn torch dask[complete] transformers pyarrow pytest lxml
 RUN ln -s /usr/bin/pip3 /usr/bin/pip
\ No newline at end of file
-- 
GitLab


From 33d70ff2f35518bb5bcd0703ebaac5cc055b4e57 Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Mon, 3 Aug 2020 18:16:55 +0200
Subject: [PATCH 027/116] Actions pipeline, a few bug fixes

---
 __init__.py                                   |   0
 dvc.lock                                      |  66 +++++
 dvc.yaml                                      |   8 +-
 generated/translations/.gitignore             |   2 +-
 notebooks/dask_dataframe_exploration.ipynb    |  32 ++-
 notebooks/test.ipynb                          |   8 +-
 notebooks/test_bert_dimensions.ipynb          | 134 ++++++++++
 notebooks/test_training_files.ipynb           |   2 +-
 notebooks/tokenizer_testing.ipynb             |  74 ++++--
 params.yaml                                   |  10 +-
 scripts/actions_based/stage5_loss_weights.py  |  40 +++
 scripts/actions_based/train.py                |   3 +-
 scripts/translation_based/processing.py       | 230 +++++++++++++++++-
 ...kenization.py => stage2_create_batches.py} |  14 +-
 .../translation_based/stage3_batchifying.py   |  32 ---
 scripts/translation_based/stage3_exploding.py |  69 ++++++
 .../translation_based/stage4_reindexing.py    |  40 +++
 scripts/translation_based/test_processing.py  | 136 +++++++++++
 src/processing.py                             |  11 +-
 19 files changed, 821 insertions(+), 90 deletions(-)
 delete mode 100644 __init__.py
 create mode 100644 notebooks/test_bert_dimensions.ipynb
 create mode 100644 scripts/actions_based/stage5_loss_weights.py
 rename scripts/translation_based/{stage2_tokenization.py => stage2_create_batches.py} (53%)
 delete mode 100644 scripts/translation_based/stage3_batchifying.py
 create mode 100644 scripts/translation_based/stage3_exploding.py
 create mode 100644 scripts/translation_based/stage4_reindexing.py
 create mode 100644 scripts/translation_based/test_processing.py

diff --git a/__init__.py b/__init__.py
deleted file mode 100644
index e69de29..0000000
diff --git a/dvc.lock b/dvc.lock
index acd48b5..54498b9 100644
--- a/dvc.lock
+++ b/dvc.lock
@@ -67,3 +67,69 @@ translations_tokenization:
   outs:
   - path: generated/translations/stage2_tokenization
     md5: b4132fb48d63c09ee5fd5e017f5c279c.dir
+actions_extraction:
+  cmd: python3 -m scripts.actions_based.stage1_extraction
+  deps:
+  - path: data
+    md5: 1fa175e752af1638dc896838e82a9d7d.dir
+  - path: scripts/actions_based/stage1_extraction.py
+    md5: a01f6ee74e165e7c3d6b21c648482d45
+  params:
+    params.yaml:
+      actions.extraction.num_partitions: 2000
+  outs:
+  - path: generated/actions/stage1_extraction
+    md5: 8c9d822cc101faf137bd54932c94f922.dir
+actions_tokenization:
+  cmd: python3 -m scripts.actions_based.stage2_tokenization
+  deps:
+  - path: generated/actions/stage1_extraction
+    md5: 8c9d822cc101faf137bd54932c94f922.dir
+  - path: scripts/actions_based/stage2_tokenization.py
+    md5: 6360e7facd4af85d2deb0deabb4cc448
+  params:
+    params.yaml:
+      actions.tokenization.max_tokens: 500
+      actions.tokenization.min_tokens: 10
+      global.base_model: dkleczek/bert-base-polish-cased-v1
+  outs:
+  - path: generated/actions/stage2_tokenization
+    md5: a1a31dc4baa92b775e44c335e3f75a9c.dir
+actions_exploding:
+  cmd: python3 -m scripts.actions_based.stage3_exploding
+  deps:
+  - path: generated/actions/stage2_tokenization
+    md5: a1a31dc4baa92b775e44c335e3f75a9c.dir
+  - path: scripts/actions_based/stage3_exploding.py
+    md5: f65f552b17d012c53b5a42406cb88bcd
+  outs:
+  - path: generated/actions/stage3_exploding
+    md5: 6db856e40b88769840799232b23c2058.dir
+actions_reindexing:
+  cmd: python3 -m scripts.actions_based.stage4_reindexing
+  deps:
+  - path: generated/actions/stage3_exploding
+    md5: 6db856e40b88769840799232b23c2058.dir
+  - path: scripts/actions_based/stage4_reindexing.py
+    md5: 7841f8c3acdc12a5dc0adef12b8b8cbc
+  outs:
+  - path: generated/actions/stage4_reindexing
+    md5: 446e8e2b2011af28fcfa63557c2b5808.dir
+actions_training:
+  cmd: python3 -m scripts.actions_based.train
+  deps:
+  - path: generated/actions/stage4_reindexing
+    md5: 446e8e2b2011af28fcfa63557c2b5808.dir
+  - path: scripts/actions_based/train.py
+    md5: ef61cad42a6be6f862051530bbc6965b
+  params:
+    params.yaml:
+      actions.training.batch_size: 2
+      actions.training.learning_rate: 0.0001
+      actions.training.max_training_time: 2m
+      actions.training.num_epochs: 5
+      actions.training.save_step: 1000
+      global.base_model: dkleczek/bert-base-polish-cased-v1
+  outs:
+  - path: checkpoints/actions
+    md5: 6116b19bae31f503a350b635125a6daf.dir
diff --git a/dvc.yaml b/dvc.yaml
index 3d51214..3275740 100644
--- a/dvc.yaml
+++ b/dvc.yaml
@@ -37,7 +37,7 @@ stages:
     cmd: python3 -m scripts.actions_based.train
     deps:
     - generated/actions/stage4_reindexing
-    - scripts/actions/train.py
+    - scripts/actions_based/train.py
     params:
     - global.base_model
     - actions.training.max_training_time
@@ -55,11 +55,11 @@ stages:
     - translations.extraction.num_partitions
     outs:
     - generated/translations/stage1_extraction
-  translations_tokenization:
-    cmd: python3 -m scripts.translation_based.stage2_tokenization
+  translations_create_batches:
+    cmd: python3 -m scripts.translation_based.stage2_create_batches
     deps:
     - generated/translations/stage1_extraction
     params:
     - global.base_model
     outs:
-    - generated/translations/stage2_tokenization
+    - generated/translations/create_batches
diff --git a/generated/translations/.gitignore b/generated/translations/.gitignore
index 0da25ba..19240d7 100644
--- a/generated/translations/.gitignore
+++ b/generated/translations/.gitignore
@@ -1,2 +1,2 @@
 /stage1_extraction
-/stage2_tokenization
+/stage2_create_batches
diff --git a/notebooks/dask_dataframe_exploration.ipynb b/notebooks/dask_dataframe_exploration.ipynb
index 568f632..59f6caf 100644
--- a/notebooks/dask_dataframe_exploration.ipynb
+++ b/notebooks/dask_dataframe_exploration.ipynb
@@ -23,35 +23,49 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count": 1,
+   "execution_count": 41,
    "metadata": {},
    "outputs": [],
    "source": [
-    "import dask.dataframe as dd"
+    "import dask.dataframe as dd\n",
+    "import numpy as np"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 2,
+   "execution_count": 42,
    "metadata": {},
    "outputs": [],
    "source": [
-    "df = dd.read_parquet(\"../generated/translations/stage1_extraction\")"
+    "df = dd.read_parquet(\"../generated/translations/stage2_create_batches\", engine=\"pyarrow\")"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 3,
-   "metadata": {},
+   "execution_count": 43,
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "shapes = df.source_shape.compute()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 44,
+   "metadata": {
+    "tags": []
+   },
    "outputs": [
     {
      "output_type": "execute_result",
      "data": {
-      "text/plain": "                                                input\n15   Otwieram posiedzenie Komisji. Dziś odbyło się...\n18   Wznawiam posiedzenie. Na sekretarzy powołuję ...\n21   Otwieram posiedzenie. Protokół 88 posiedzenia...\n25   Proszę państwa, otwieram wspólne posiedzenie ...\n29   Otwieram posiedzenie Komisji Budżetu i Finans...",
-      "text/html": "<div>\n<style scoped>\n    .dataframe tbody tr th:only-of-type {\n        vertical-align: middle;\n    }\n\n    .dataframe tbody tr th {\n        vertical-align: top;\n    }\n\n    .dataframe thead th {\n        text-align: right;\n    }\n</style>\n<table border=\"1\" class=\"dataframe\">\n  <thead>\n    <tr style=\"text-align: right;\">\n      <th></th>\n      <th>input</th>\n    </tr>\n  </thead>\n  <tbody>\n    <tr>\n      <th>15</th>\n      <td>Otwieram posiedzenie Komisji. Dziś odbyło się...</td>\n    </tr>\n    <tr>\n      <th>18</th>\n      <td>Wznawiam posiedzenie. Na sekretarzy powołuję ...</td>\n    </tr>\n    <tr>\n      <th>21</th>\n      <td>Otwieram posiedzenie. Protokół 88 posiedzenia...</td>\n    </tr>\n    <tr>\n      <th>25</th>\n      <td>Proszę państwa, otwieram wspólne posiedzenie ...</td>\n    </tr>\n    <tr>\n      <th>29</th>\n      <td>Otwieram posiedzenie Komisji Budżetu i Finans...</td>\n    </tr>\n  </tbody>\n</table>\n</div>"
+      "text/plain": "                                               source  \\\n15  [2, 27476, 7835, 4677, 2822, 11226, 781, 77, 4...   \n18  [2, 1178, 4607, 766, 7835, 752, 58008, 4419, 5...   \n21  [2, 27476, 7835, 25104, 9712, 6901, 8698, 778,...   \n25  [2, 1645, 2160, 27476, 11811, 7835, 4677, 5657...   \n29  [2, 27476, 7835, 4677, 5529, 77, 10814, 4994, ...   \n\n                                               target source_shape  \\\n15  [2, 15482, 7835, 2931, 18, 7331, 11226, 781, 1...     [2, 500]   \n18  [2, 56453, 7835, 18, 922, 58008, 4419, 5482, 4...   [169, 500]   \n21  [2, 15482, 7835, 18, 38648, 9712, 6901, 8698, ...    [94, 500]   \n25  [2, 1513, 2160, 16, 27476, 11811, 7835, 2931, ...    [25, 500]   \n29  [2, 15482, 7835, 2931, 53234, 77, 6789, 17353,...     [3, 500]   \n\n   target_shape  \n15     [2, 500]  \n18   [169, 500]  \n21    [94, 500]  \n25    [25, 500]  \n29     [3, 500]  ",
+      "text/html": "<div>\n<style scoped>\n    .dataframe tbody tr th:only-of-type {\n        vertical-align: middle;\n    }\n\n    .dataframe tbody tr th {\n        vertical-align: top;\n    }\n\n    .dataframe thead th {\n        text-align: right;\n    }\n</style>\n<table border=\"1\" class=\"dataframe\">\n  <thead>\n    <tr style=\"text-align: right;\">\n      <th></th>\n      <th>source</th>\n      <th>target</th>\n      <th>source_shape</th>\n      <th>target_shape</th>\n    </tr>\n  </thead>\n  <tbody>\n    <tr>\n      <th>15</th>\n      <td>[2, 27476, 7835, 4677, 2822, 11226, 781, 77, 4...</td>\n      <td>[2, 15482, 7835, 2931, 18, 7331, 11226, 781, 1...</td>\n      <td>[2, 500]</td>\n      <td>[2, 500]</td>\n    </tr>\n    <tr>\n      <th>18</th>\n      <td>[2, 1178, 4607, 766, 7835, 752, 58008, 4419, 5...</td>\n      <td>[2, 56453, 7835, 18, 922, 58008, 4419, 5482, 4...</td>\n      <td>[169, 500]</td>\n      <td>[169, 500]</td>\n    </tr>\n    <tr>\n      <th>21</th>\n      <td>[2, 27476, 7835, 25104, 9712, 6901, 8698, 778,...</td>\n      <td>[2, 15482, 7835, 18, 38648, 9712, 6901, 8698, ...</td>\n      <td>[94, 500]</td>\n      <td>[94, 500]</td>\n    </tr>\n    <tr>\n      <th>25</th>\n      <td>[2, 1645, 2160, 27476, 11811, 7835, 4677, 5657...</td>\n      <td>[2, 1513, 2160, 16, 27476, 11811, 7835, 2931, ...</td>\n      <td>[25, 500]</td>\n      <td>[25, 500]</td>\n    </tr>\n    <tr>\n      <th>29</th>\n      <td>[2, 27476, 7835, 4677, 5529, 77, 10814, 4994, ...</td>\n      <td>[2, 15482, 7835, 2931, 53234, 77, 6789, 17353,...</td>\n      <td>[3, 500]</td>\n      <td>[3, 500]</td>\n    </tr>\n  </tbody>\n</table>\n</div>"
      },
      "metadata": {},
-     "execution_count": 3
+     "execution_count": 44
     }
    ],
    "source": [
diff --git a/notebooks/test.ipynb b/notebooks/test.ipynb
index 22779f2..0318f59 100644
--- a/notebooks/test.ipynb
+++ b/notebooks/test.ipynb
@@ -47,13 +47,7 @@
    "outputs": [],
    "source": [
     "INPUT_PATH=\"../generated/stage4_reindexing\"\n",
-    "MODEL_BASE = \"bert-base-multilingual-cased\"\n",
-    "MODEL_NAME = \"actionv1_500\"\n",
-    "LR = 1e-4\n",
-    "\n",
-    "BATCH_SIZE=8\n",
-    "NUM_EPOCH=5\n",
-    "SAVE_STEP=60_000"
+    "MODEL_BASE = \"dkleczek/bert-base-polish-cased-v1\""
    ]
   },
   {
diff --git a/notebooks/test_bert_dimensions.ipynb b/notebooks/test_bert_dimensions.ipynb
new file mode 100644
index 0000000..8bc4617
--- /dev/null
+++ b/notebooks/test_bert_dimensions.ipynb
@@ -0,0 +1,134 @@
+{
+ "metadata": {
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.8.2-final"
+  },
+  "orig_nbformat": 2,
+  "kernelspec": {
+   "name": "python38264bit17f10e31b7e440e591cfca7d4c2c2274",
+   "display_name": "Python 3.8.2 64-bit"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2,
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 41,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from transformers import BertForMaskedLM, BertTokenizerFast\n",
+    "import torch"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 42,
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [
+    {
+     "output_type": "stream",
+     "name": "stderr",
+     "text": "Some weights of the model checkpoint at dkleczek/bert-base-polish-cased-v1 were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']\n- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).\n- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
+    }
+   ],
+   "source": [
+    "model = BertForMaskedLM.from_pretrained(\"dkleczek/bert-base-polish-cased-v1\")\n",
+    "tokenizer = BertTokenizerFast.from_pretrained(\"dkleczek/bert-base-polish-cased-v1\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 85,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "text = \"Dziwny <mask> ten świat!\"\n",
+    "kwgs = tokenizer(text, return_tensors='pt', max_length=30, padding='max_length')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 86,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "res = model(**kwgs)[0]\n",
+    "output = [x.argmax() for x in res[0]]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 87,
+   "metadata": {},
+   "outputs": [
+    {
+     "output_type": "execute_result",
+     "data": {
+      "text/plain": "'! Dziwny < br > ten świat!!! \" Dzi! i! < … > nowy świat!! świat! \" \"!! jest'"
+     },
+     "metadata": {},
+     "execution_count": 87
+    }
+   ],
+   "source": [
+    "tokenizer.decode(output)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 88,
+   "metadata": {},
+   "outputs": [
+    {
+     "output_type": "execute_result",
+     "data": {
+      "text/plain": "[tensor(5),\n tensor(4642),\n tensor(12407),\n tensor(32),\n tensor(6858),\n tensor(34),\n tensor(1216),\n tensor(1994),\n tensor(5),\n tensor(5),\n tensor(5),\n tensor(6),\n tensor(4642),\n tensor(5),\n tensor(77),\n tensor(5),\n tensor(32),\n tensor(372),\n tensor(34),\n tensor(3905),\n tensor(1994),\n tensor(5),\n tensor(5),\n tensor(1994),\n tensor(5),\n tensor(6),\n tensor(6),\n tensor(5),\n tensor(5),\n tensor(800)]"
+     },
+     "metadata": {},
+     "execution_count": 88
+    }
+   ],
+   "source": [
+    "output"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 89,
+   "metadata": {},
+   "outputs": [
+    {
+     "output_type": "execute_result",
+     "data": {
+      "text/plain": "{'input_ids': tensor([[    2,  4642, 12407,    32, 45933,    34,  1216,  1994,     5,     4,\n             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n             0,     0,     0,     0,     0,     0,     0,     0,     0,     0]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n         0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n         0, 0, 0, 0, 0, 0]])}"
+     },
+     "metadata": {},
+     "execution_count": 89
+    }
+   ],
+   "source": [
+    "kwgs"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ]
+}
\ No newline at end of file
diff --git a/notebooks/test_training_files.ipynb b/notebooks/test_training_files.ipynb
index c8ee4da..0296c9e 100644
--- a/notebooks/test_training_files.ipynb
+++ b/notebooks/test_training_files.ipynb
@@ -30,7 +30,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "INPUT_FOLDER=\"generated/stage4_reindexing\"\n",
+    "INPUT_FOLDER=\"generated/translations/stage2_create_batches\"\n",
     "MODEL_BASE = \"bert-base-multilingual-cased\"\n",
     "\n",
     "tokenizer = BertTokenizerFast.from_pretrained(MODEL_BASE)\n"
diff --git a/notebooks/tokenizer_testing.ipynb b/notebooks/tokenizer_testing.ipynb
index c4b66f6..d6e83cc 100644
--- a/notebooks/tokenizer_testing.ipynb
+++ b/notebooks/tokenizer_testing.ipynb
@@ -23,16 +23,17 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count": 2,
+   "execution_count": 14,
    "metadata": {},
    "outputs": [],
    "source": [
-    "from transformers import BertTokenizerFast"
+    "from transformers import BertTokenizerFast\n",
+    "import numpy as np"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 3,
+   "execution_count": 2,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -42,47 +43,92 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 7,
+   "execution_count": 31,
    "metadata": {},
    "outputs": [
     {
      "output_type": "execute_result",
      "data": {
-      "text/plain": "{'input_ids': [101, 56500, 10824, 16469, 177, 39327, 59726, 10132, 348, 11335, 68497, 119, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}"
+      "text/plain": "[56500, 117, 10824, 30186, 11090, 10113, 119, 138, 13400, 11058, 106]"
      },
      "metadata": {},
-     "execution_count": 7
+     "execution_count": 31
     }
    ],
    "source": [
-    "tokenizer(\"Ala ma kota i poszła na śniadanie.\")"
+    "tokenizer(\"Ala, ma KoTa. A kot nie!\")['input_ids'][1:-1]"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 9,
+   "execution_count": 26,
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "output_type": "execute_result",
+     "data": {
+      "text/plain": "'Ala, ma KoTa.'"
+     },
+     "metadata": {},
+     "execution_count": 26
+    }
+   ],
+   "source": [
+    "tokenizer.decode(np.array(tokenizer(\"Ala, ma KoTa. A kot nie!\")['input_ids'][1:-1])[[0, 1, 2, 3, 4, 5, 6]])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 29,
+   "metadata": {},
+   "outputs": [
+    {
+     "output_type": "execute_result",
+     "data": {
+      "text/plain": "'A kot nie!'"
+     },
+     "metadata": {},
+     "execution_count": 29
+    }
+   ],
+   "source": [
+    "tokenizer.decode(np.array(tokenizer(\"Ala, ma KoTa. A kot nie!\")['input_ids'][1:-1])[[7, 8, 9, 10]])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 20,
+   "metadata": {},
+   "outputs": [
+    {
+     "output_type": "execute_result",
+     "data": {
+      "text/plain": "'ala ma kota'"
+     },
+     "metadata": {},
+     "execution_count": 20
+    }
+   ],
    "source": [
-    "dot_token = tokenizer(\".\")['input_ids'][1]"
+    "tokenizer.decode(np.array(tokenizer(\"ala ma kota a kot nie\")['input_ids'][1:-1])[[0, 1, 2]])"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 10,
+   "execution_count": 22,
    "metadata": {},
    "outputs": [
     {
      "output_type": "execute_result",
      "data": {
-      "text/plain": "119"
+      "text/plain": "'a kot nie'"
      },
      "metadata": {},
-     "execution_count": 10
+     "execution_count": 22
     }
    ],
    "source": [
-    "dot_token"
+    "tokenizer.decode(np.array(tokenizer(\"ala ma kota a kot nie\")['input_ids'][1:-1])[[3, 4, 5]])"
    ]
   },
   {
diff --git a/params.yaml b/params.yaml
index d64c249..2de9ed6 100644
--- a/params.yaml
+++ b/params.yaml
@@ -1,6 +1,6 @@
 global:
     dashboard_port: 8787
-    base_model: "bert-base-multilingual-cased"
+    base_model: "dkleczek/bert-base-polish-cased-v1"
 
 actions:
     extraction:
@@ -27,7 +27,7 @@ actions:
         num_epochs: 5
         batch_size: 2
         save_step: 1000
-        max_training_time: "30s"
+        max_training_time: null
         loss_averaging_span: 1000
         fresh_start: false
         device: "cuda:0"
@@ -37,6 +37,8 @@ translations:
         num_workers: 24
         worker_memory_limit: "2GB"
 
-    tokenization:
+    create_batches:
         num_workers: 24
-        worker_memory_limit: "2GB"
\ No newline at end of file
+        worker_memory_limit: "2GB"
+        min_tokens: 10
+        max_tokens: 500
\ No newline at end of file
diff --git a/scripts/actions_based/stage5_loss_weights.py b/scripts/actions_based/stage5_loss_weights.py
new file mode 100644
index 0000000..72498b7
--- /dev/null
+++ b/scripts/actions_based/stage5_loss_weights.py
@@ -0,0 +1,40 @@
+#!/usr/bin/python3
+from src.processing import batchify_data
+from dask.diagnostics import ProgressBar
+import dask.dataframe as dd
+from transformers import BertTokenizerFast
+import numpy as np
+import dask
+from dask.distributed import Client
+import pandas as pd
+from src.utils import PROJECT_ROOT, get_config, prepare_folder
+
+INPUT_FOLDER = f"{PROJECT_ROOT}/generated/actions/stage4_reindexing"
+OUTPUT_FOLDER = f"{PROJECT_ROOT}/generated/actions/stage5_loss_weights"
+
+if __name__ == "__main__":
+    config = get_config()
+    num_workers = config['actions']['reindexing']['num_workers']
+    memory_limit = config['actions']['reindexing']['worker_memory_limit']
+
+    prepare_folder(OUTPUT_FOLDER)
+
+    client = Client(n_workers=num_workers, memory_limit=memory_limit)
+    print(client.dashboard_link)
+
+    df = dd.read_parquet(INPUT_FOLDER, engine='pyarrow')
+
+    # Add ordered indexes
+    df = df.assign(ones=1)
+    df = df.reset_index(drop=True)
+    idx = (df.ones.cumsum() - 1).persist()
+    df = df.assign(ones=idx)
+
+    # Shuffle: map each ordered index through a random permutation
+    shuffled_idx = idx.compute().sample(frac=1).values
+    shuffled_idx = client.scatter(shuffled_idx)
+    mapped_ones = df.ones.apply(lambda x, idx: idx[x], args=(shuffled_idx,), meta=('ones', 'int64')).persist()
+    df = df.assign(ones=mapped_ones)
+
+    df = df.set_index('ones')
+    df.to_parquet(OUTPUT_FOLDER, engine='pyarrow')
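
The body above mirrors the reindexing stage; for the weights the stage name promises, a purely
hypothetical sketch of deriving per-action pos_weight values for the BCEWithLogitsLoss used in
train.py (the 0/1 label matrix ordered like ACTIONS_KEYS is an assumption, not this stage's
actual schema):

# Hypothetical sketch: positive-class weights from label frequencies.
import numpy as np
import torch
from torch.nn import BCEWithLogitsLoss

# Stand-in for the real (N, len(ACTIONS_KEYS)) 0/1 label matrix
outputs = np.random.randint(0, 2, size=(1000, 6)).astype(np.float32)

positive = outputs.sum(axis=0)                  # count of 1s per action
negative = len(outputs) - positive              # count of 0s per action
pos_weight = torch.tensor(negative / np.maximum(positive, 1.0), dtype=torch.float32)

criterion = BCEWithLogitsLoss(pos_weight=pos_weight)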
diff --git a/scripts/actions_based/train.py b/scripts/actions_based/train.py
index 249641c..46abed9 100755
--- a/scripts/actions_based/train.py
+++ b/scripts/actions_based/train.py
@@ -41,7 +41,7 @@ if __name__ == "__main__":
 
     model = BertForTokenClassification.from_pretrained(base_model, num_labels=len(ACTIONS_KEYS)).to(device)
     criterion = BCEWithLogitsLoss().to(device)
-    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
+    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
 
     epoch_start = 0
     sample_start = 0
@@ -123,7 +123,6 @@ if __name__ == "__main__":
                 training_stopped = True
                 break
 
-
             loss.backward()
             optimizer.step()
 
diff --git a/scripts/translation_based/processing.py b/scripts/translation_based/processing.py
index 633e902..6d00546 100644
--- a/scripts/translation_based/processing.py
+++ b/scripts/translation_based/processing.py
@@ -1,8 +1,9 @@
 import dask.dataframe as dd
-from src.processing import text_from_xml
+from src.processing import text_from_xml, remove_punctuation
 from transformers import BertTokenizerFast
 import numpy as np
 
+
 def raw_to_dataframe(entry: dict) -> dict:
     """Converts dask datarfame containing files paths into
     dataframe with content of that files (text only)
@@ -20,10 +21,12 @@ def raw_to_dataframe(entry: dict) -> dict:
     else:
         return {'input': None}
 
+
 RAW_TO_DATAFRAME_META = {
     'input': str
 }
 
+
 def apply_tokenization(entry: dict, tokenizer: BertTokenizerFast) -> dict:
     """Converts raw text entries into list of tokens
 
@@ -40,11 +43,13 @@ def apply_tokenization(entry: dict, tokenizer: BertTokenizerFast) -> dict:
         'tokens': np.array(text_tokenized)
     }
 
+
 APPLY_TOKENIZATION_META = {
     'tokens': object
 }
 
-def split_into_batches(entry: dict, stopping_token: int) -> dict:
+
+def generate_batches(entry: dict, min_len: int, max_len: int, separating_token: int, tokenizer: BertTokenizerFast) -> dict:
     """Converts raw text entries into list of tokens
 
     Args:
@@ -54,13 +59,226 @@ def split_into_batches(entry: dict, stopping_token: int) -> dict:
     Returns:
         dict: Dask dataset entry with one column ('tokens') containing an np.array of tokens
     """
-    text_tokenized = tokenizer(entry.input)['input_ids'][1:-1]
+    tokens = np.array(tokenizer(entry.input)['input_ids'][1:-1])
+
+    tokens_ending = (tokens == separating_token).astype(np.int)
+    batch_indices = get_batch_indexes(tokens_ending, min_len, max_len - 2)
+    
+    source_batch, target_batch = crete_input_output_batch(tokens, batch_indices, max_len, tokenizer)
+
+    source_batch_shape = np.array(source_batch.shape)
+    target_batch_shape = np.array(target_batch.shape)
+
+    source_batch = source_batch.reshape(-1)
+    target_batch = target_batch.reshape(-1)
 
     return {
-        'tokens': np.array(text_tokenized)
+        'source': source_batch,
+        'target': target_batch,
+        'source_shape': source_batch_shape,
+        'target_shape': target_batch_shape
     }
 
-APPLY_TOKENIZATION_META = {
-    'tokens': object
+
+GENERATE_BATCHES_META = {
+    'source': object,
+    'target': object,
+    'source_shape': object,
+    'target_shape': object
 }
 
+
+def find_new_sentence_left(seq: np.array, pos: int) -> int:
+    """Finds nerest sentence on the left of the current position (including current position)
+
+    Args:
+        seq (np.array): Array of 0s and 1s of length equal to sequence. 1 means end of sentence (dot, semicolon etc.) and 0 - every other token
+        pos (int): Starting position
+
+    Returns:
+        int: Position of the nearest new sentence on the left. Start of the sequence always counts as a start of sentence
+    """
+    assert pos < len(seq)
+    assert pos >= 0
+
+    while pos > 0:
+        if seq[pos - 1] == 1:
+            return pos
+        else:
+            pos = pos - 1
+
+    return 0
+
+
+def find_new_sentence_right(seq: np.array, pos: int) -> int:
+    """Finds nerest sentence on the right of the current position (including current position)
+
+    Args:
+        seq (np.array): Array of 0s and 1s of length equal to sequence. 1 means end of sentence (dot, semicolon etc.) and 0 - every other token
+        pos (int): [description]
+
+    Returns:
+        int: Position of the nearest new sentence on the right. Returns none if no new sentence is found on the right
+    """
+    assert pos < len(seq)
+    assert pos >= 0
+
+    while pos < len(seq):
+        if seq[pos - 1] == 1:
+            return pos
+        else:
+            pos = pos + 1
+
+    return None
+
+
+def get_batch_indexes(seq: np.array, min_length: int, max_length: int) -> [np.array]:
+    """Turns long sequence into array of indices, composing a single batch file. 
+
+    Args:
+        seq (np.array): Input sequence of 1s and 0s, where 1 means end of sequence token (dot, semicolon etc.)
+        min_length (int): Minimum length of sample in a batch
+        max_length (int): Maximum length of sample in a batch
+
+    Returns:
+        [np.array]: List of index arrays, where each entry has length between <min_length, max_length>
+    """
+    pos = 0
+    batch = []
+
+    assert min_length <= max_length
+
+    while pos < len(seq):
+        pos_delta = min(max_length, len(seq) - pos)
+        assert pos + pos_delta <= len(seq)
+
+        if pos_delta >= min_length:
+            new_entry = np.array(list(range(pos, pos + pos_delta)))
+            assert len(new_entry) <= max_length
+
+            batch.append(new_entry)
+
+        if pos + pos_delta >= len(seq):
+            break
+
+        new_pos = find_new_sentence_left(seq, pos + pos_delta)
+        if new_pos == pos:
+            new_pos = find_new_sentence_right(seq, pos + pos_delta)
+            if new_pos is None:
+                break
+
+        pos = new_pos
+
+    return batch
+
+
+def add_padding(seq: np.ndarray, total_length: int, padding_symbol: any) -> np.ndarray:
+    """Pads a sequence with provided symbol, to get array of length total_length in the end
+
+    Args:
+        seq (np.ndarray): Input sequence
+        total_length (int): Desired length of a sequence
+        padding_symbol (any): Symbol that will be inserted at the end (total_length - len(seq)) times
+
+    Returns:
+        np.ndarray: N-dimensional array where first dimension is of length total_length
+    """
+    num_padding = total_length - len(seq)
+    assert num_padding >= 0
+
+    if num_padding > 0:
+        return np.concatenate([seq, np.array([padding_symbol] * num_padding)], axis=0)
+    else:
+        return np.copy(seq)
+
+
+def add_begin_end_tokens(seq: np.ndarray, begin_token: any, end_token: any) -> np.ndarray:
+    """Adds preceding and ending special tokens to the sequence
+
+    Args:
+        seq (np.ndarray): Sequence of len L
+        begin_token (any): Token that will be added at the beginning of the sequence
+        end_token (any): Token that will be added at the end of the sequence
+
+    Returns:
+        np.ndarray: Sequence of len L+2
+    """
+
+    return np.concatenate(
+        [
+            [begin_token],
+            seq,
+            [end_token]
+        ]
+    )
+
+
+def standarize_translation_sample(seq: np.ndarray, total_length: int, padding_symbol: any, begin_token: any, end_token: any) -> np.ndarray:
+    """Adds special tokens and padding so that every sample has identical shape
+
+    Args:
+        seq (np.ndarray): Input sequence of len L
+        total_length (int): Desired sequence length
+        padding_symbol (any): Token that will be used for padding
+        begin_token (any): Token that will be used as starting token
+        end_token (any): Token that will be used as ending token
+
+    Returns:
+        np.ndarray: Output sequence of length total_length
+    """
+    return add_padding(add_begin_end_tokens(seq, begin_token, end_token), total_length, padding_symbol)
+
+
+def create_input_output(tokens: np.ndarray, length: int, tokenizer: BertTokenizerFast) -> (np.ndarray, np.ndarray):
+    """Transforms a sequence of tokens into "translation" input and output
+
+    Args:
+        tokens (np.ndarray): Input sequence
+        length (int): Maximum output length. Will add padding to match it
+        tokenizer (BertTokenizerFast): Tokenizer that was used to obtain tokens
+
+    Returns:
+        np.ndarray: Single sample that will serve as input to the model
+        np.ndarray: Single sample that will serve as expected output from the model
+    """
+    decoded_str = tokenizer.decode(tokens)
+    cleaned_str = remove_punctuation(decoded_str).lower()
+    source_batch_entry = tokenizer(cleaned_str)['input_ids'][1:-1]
+    target_batch_entry = tokens
+
+    # In rare cases (because of encoding) the unpunctuated lowercase input might be longer than the output and exceed length limits
+    # We need to trim it in such cases
+    if len(source_batch_entry) > length - 2:
+        source_batch_entry = source_batch_entry[:(length-2)]
+
+    source_batch_entry = standarize_translation_sample(
+        source_batch_entry, length, tokenizer.pad_token_id, tokenizer.cls_token_id, tokenizer.sep_token_id)
+    target_batch_entry = standarize_translation_sample(
+        target_batch_entry, length, tokenizer.pad_token_id, tokenizer.cls_token_id, tokenizer.sep_token_id)
+
+    return source_batch_entry, target_batch_entry
+
+def crete_input_output_batch(seq: np.ndarray, batch_indexes: [np.ndarray], length: int, tokenizer: BertTokenizerFast) -> (np.ndarray, np.ndarray):
+    """Transforms a sequence of tokens into "translation" input and output batch
+
+    Args:
+        seq (np.ndarray): Input sequence of tokens
+        batch_indexes ([np.ndarray]): List where every entry is an array of indices representing a batch sample from the seq array.
+        length (int): Maximum output length. Will add padding to match it
+        tokenizer (BertTokenizerFast): Tokenizer that was used to obtain tokens
+
+    Returns:
+        np.ndarray: Batch of samples that will serve as input to the model
+        np.ndarray: Batch of samples that will serve as expected output from the model
+    """
+    base_batch = [seq[indexes] for indexes in batch_indexes]
+
+    source_batch = []
+    target_batch = []
+    for entry in base_batch:
+        source_entry, target_entry = create_input_output(entry, length, tokenizer)
+
+        source_batch.append(source_entry)
+        target_batch.append(target_entry)
+
+    return np.array(source_batch), np.array(target_batch)
\ No newline at end of file
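
A small worked example of the windowing helper above, the same case the unit tests added later
in this patch exercise:

# Sketch: get_batch_indexes windows a token stream on sentence boundaries.
import numpy as np
from scripts.translation_based.processing import get_batch_indexes

# 1 marks a sentence-separating token (e.g. the id of "."), 0 every other token
seq = np.array([0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0])

for window in get_batch_indexes(seq, min_length=3, max_length=5):
    print(window)
# -> [0 1 2 3 4] then [ 6  7  8  9 10]; the trailing fragment is shorter
#    than min_length and is dropped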
diff --git a/scripts/translation_based/stage2_tokenization.py b/scripts/translation_based/stage2_create_batches.py
similarity index 53%
rename from scripts/translation_based/stage2_tokenization.py
rename to scripts/translation_based/stage2_create_batches.py
index 2768d2c..a7f110e 100644
--- a/scripts/translation_based/stage2_tokenization.py
+++ b/scripts/translation_based/stage2_create_batches.py
@@ -1,5 +1,5 @@
 #!/usr/bin/python3
-from scripts.translation_based.processing import apply_tokenization, APPLY_TOKENIZATION_META
+from scripts.translation_based.processing import generate_batches, GENERATE_BATCHES_META
 from src.utils import PROJECT_ROOT, prepare_folder, get_config
 import numpy as np
 from dask.distributed import Client
@@ -8,13 +8,15 @@ import dask.dataframe as dd
 from dask import delayed
 
 INPUT_FOLDER = f"{PROJECT_ROOT}/generated/translations/stage1_extraction"
-OUTPUT_FOLDER = f"{PROJECT_ROOT}/generated/translations/stage2_tokenization"
+OUTPUT_FOLDER = f"{PROJECT_ROOT}/generated/translations/stage2_create_batches"
 
 if __name__ == "__main__":
 
     config = get_config()
-    num_workers = config['translations']['tokenization']['num_workers']
-    memory_limit = config['translations']['tokenization']['worker_memory_limit']
+    num_workers = config['translations']['create_batches']['num_workers']
+    memory_limit = config['translations']['create_batches']['worker_memory_limit']
+    min_tokens = config['translations']['create_batches']['min_tokens']
+    max_tokens = config['translations']['create_batches']['max_tokens']
     base_model = config['global']['base_model']
 
     prepare_folder(OUTPUT_FOLDER)
@@ -25,8 +27,10 @@ if __name__ == "__main__":
     tokenizer = BertTokenizerFast.from_pretrained(base_model)
     tokenizer = delayed(tokenizer)
 
+    token_separating = tokenizer(".")["input_ids"][1]
+
     df = dd.read_parquet(INPUT_FOLDER, engine="pyarrow")
-    df = df.apply(apply_tokenization, result_type='expand', axis=1, meta=APPLY_TOKENIZATION_META, args=(tokenizer,))
+    df = df.apply(generate_batches, result_type='expand', axis=1, meta=GENERATE_BATCHES_META, args=(min_tokens, max_tokens, token_separating, tokenizer))
 
     # Export
     df.to_parquet(OUTPUT_FOLDER, engine="pyarrow")
\ No newline at end of file
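
The separating token is looked up from the tokenizer itself; a sketch of the same derivation,
assuming the base model currently set in params.yaml:

# Sketch: tokenizing "." alone yields [CLS] "." [SEP]; index 1 is the dot's id.
from transformers import BertTokenizerFast

tokenizer = BertTokenizerFast.from_pretrained("dkleczek/bert-base-polish-cased-v1")
token_separating = tokenizer(".")["input_ids"][1]
print(token_separating)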
diff --git a/scripts/translation_based/stage3_batchifying.py b/scripts/translation_based/stage3_batchifying.py
deleted file mode 100644
index 2768d2c..0000000
--- a/scripts/translation_based/stage3_batchifying.py
+++ /dev/null
@@ -1,32 +0,0 @@
-#!/usr/bin/python3
-from scripts.translation_based.processing import apply_tokenization, APPLY_TOKENIZATION_META
-from src.utils import PROJECT_ROOT, prepare_folder, get_config
-import numpy as np
-from dask.distributed import Client
-from transformers import BertTokenizerFast
-import dask.dataframe as dd
-from dask import delayed
-
-INPUT_FOLDER = f"{PROJECT_ROOT}/generated/translations/stage1_extraction"
-OUTPUT_FOLDER = f"{PROJECT_ROOT}/generated/translations/stage2_tokenization"
-
-if __name__ == "__main__":
-
-    config = get_config()
-    num_workers = config['translations']['tokenization']['num_workers']
-    memory_limit = config['translations']['tokenization']['worker_memory_limit']
-    base_model = config['global']['base_model']
-
-    prepare_folder(OUTPUT_FOLDER)
-
-    client = Client(n_workers=num_workers, memory_limit=memory_limit)
-    print(f"Dashboard: {client.dashboard_link}")
-
-    tokenizer = BertTokenizerFast.from_pretrained(base_model)
-    tokenizer = delayed(tokenizer)
-
-    df = dd.read_parquet(INPUT_FOLDER, engine="pyarrow")
-    df = df.apply(apply_tokenization, result_type='expand', axis=1, meta=APPLY_TOKENIZATION_META, args=(tokenizer,))
-
-    # Export
-    df.to_parquet(OUTPUT_FOLDER, engine="pyarrow")
\ No newline at end of file
diff --git a/scripts/translation_based/stage3_exploding.py b/scripts/translation_based/stage3_exploding.py
new file mode 100644
index 0000000..7c0b0d4
--- /dev/null
+++ b/scripts/translation_based/stage3_exploding.py
@@ -0,0 +1,69 @@
+#!/usr/bin/python3
+from src.processing import batchify_data
+from dask.diagnostics import ProgressBar
+import dask.dataframe as dd
+import numpy as np
+import dask
+from dask.distributed import Client
+import pandas as pd
+from src.utils import PROJECT_ROOT, get_config, prepare_folder
+
+INPUT_FOLDER = f"{PROJECT_ROOT}/generated/actions/stage2_tokenization"
+OUTPUT_FOLDER = f"{PROJECT_ROOT}/generated/actions/stage3_exploding"
+
+def expand_dims(entry):
+    inputs = entry.inputs.reshape(entry.input_shape)
+    outputs = entry.outputs.reshape(entry.output_shape)
+    masks = entry.attentions.reshape(entry.attentions_shape)
+
+    return {
+        'inputs': inputs,
+        'outputs': outputs,
+        "attentions": masks,
+    }
+
+def flatten_dims(entry):
+    inputs_shape = np.array(entry.inputs.shape)
+    outputs_shape = np.array(entry.outputs.shape)
+    attentions_shape = np.array(entry.attentions.shape)
+
+    inputs = entry.inputs.reshape(-1)
+    outputs = entry.outputs.reshape(-1)
+    attentions = entry.attentions.reshape(-1)
+
+    return {
+        'inputs': inputs,
+        'outputs': outputs,
+        'attentions': attentions,
+        'inputs_shape': inputs_shape,
+        'outputs_shape': outputs_shape,
+        'attentions_shape': attentions_shape
+    }
+
+
+RESULT_META = {
+    'inputs': object,
+    'outputs': object,
+    'attentions': object,
+    'inputs_shape': object,
+    'outputs_shape': object,
+    'attentions_shape': object
+}
+
+if __name__ == "__main__":
+    config = get_config()
+    num_workers = config['actions']['exploding']['num_workers']
+    memory_limit = config['actions']['exploding']['worker_memory_limit']
+
+    prepare_folder(OUTPUT_FOLDER)
+
+    client = Client(n_workers=num_workers, memory_limit=memory_limit)
+    print(client.dashboard_link)
+
+    df = dd.read_parquet(INPUT_FOLDER, engine='pyarrow')
+
+    df = df.apply(expand_dims, result_type='expand', axis=1, meta={'inputs': object, 'outputs': object, 'attentions': object})
+    df = df.map_partitions(lambda x: x.apply(lambda y: y.explode(), axis=0), meta={'inputs': object, 'outputs': object, 'attentions': object})
+    df = df.apply(flatten_dims, result_type='expand', axis=1, meta=RESULT_META)
+    
+    df.to_parquet(OUTPUT_FOLDER, engine='pyarrow')
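
The exploding step leans on pandas' Series.explode applied column by column; a toy sketch of the
same round trip in plain pandas (toy column names, and two columns whose matrices match so the
exploded lengths line up):

# Sketch: one row holding (batch, seq) matrices becomes one row per sample.
import numpy as np
import pandas as pd

row = pd.DataFrame({
    "inputs":  [np.arange(6).reshape(2, 3)],       # 2 samples of length 3
    "outputs": [np.arange(6).reshape(2, 3) * 10],
})

# Same per-partition lambda as above, minus dask
exploded = row.apply(lambda col: col.explode(), axis=0)
print(exploded)   # two rows, each cell a length-3 array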
diff --git a/scripts/translation_based/stage4_reindexing.py b/scripts/translation_based/stage4_reindexing.py
new file mode 100644
index 0000000..b436ca2
--- /dev/null
+++ b/scripts/translation_based/stage4_reindexing.py
@@ -0,0 +1,40 @@
+#!/usr/bin/python3
+from src.processing import batchify_data
+from dask.diagnostics import ProgressBar
+import dask.dataframe as dd
+from transformers import BertTokenizerFast
+import numpy as np
+import dask
+from dask.distributed import Client
+import pandas as pd
+from src.utils import PROJECT_ROOT, get_config, prepare_folder
+
+INPUT_FOLDER = f"{PROJECT_ROOT}/generated/actions/stage3_exploding"
+OUTPUT_FOLDER = f"{PROJECT_ROOT}/generated/actions/stage4_reindexing"
+
+if __name__ == "__main__":
+    config = get_config()
+    num_workers = config['actions']['reindexing']['num_workers']
+    memory_limit = config['actions']['reindexing']['worker_memory_limit']
+
+    prepare_folder(OUTPUT_FOLDER)
+
+    client = Client(n_workers=num_workers, memory_limit=memory_limit)
+    print(client.dashboard_link)
+
+    df = dd.read_parquet(INPUT_FOLDER, engine='pyarrow')
+
+    # Add ordered indexes
+    df = df.assign(ones=1)
+    df = df.reset_index(drop=True)
+    idx = (df.ones.cumsum() - 1).persist()
+    df = df.assign(ones=idx)
+
+    # Shuffle: map each ordered index through a random permutation
+    shuffled_idx = idx.compute().sample(frac=1).values
+    shuffled_idx = client.scatter(shuffled_idx)
+    mapped_ones = df.ones.apply(lambda x, idx: idx[x], args=(shuffled_idx,), meta=('ones', 'int64')).persist()
+    df = df.assign(ones=mapped_ones)
+
+    df = df.set_index('ones')
+    df.to_parquet(OUTPUT_FOLDER, engine='pyarrow')
diff --git a/scripts/translation_based/test_processing.py b/scripts/translation_based/test_processing.py
new file mode 100644
index 0000000..b340204
--- /dev/null
+++ b/scripts/translation_based/test_processing.py
@@ -0,0 +1,136 @@
+import numpy as np
+from scripts.translation_based.processing import (
+    find_new_sentence_left, find_new_sentence_right, get_batch_indexes, add_padding, add_begin_end_tokens, standarize_translation_sample, create_input_output, crete_input_output_batch)
+from transformers import BertTokenizerFast
+
+
+def test_find_new_sentence_left():
+    test_input = np.array([0, 0, 1, 0, 1, 0])
+    assert find_new_sentence_left(test_input, 0) == 0
+    assert find_new_sentence_left(test_input, 1) == 0
+    assert find_new_sentence_left(test_input, 2) == 0
+    assert find_new_sentence_left(test_input, 3) == 3
+    assert find_new_sentence_left(test_input, 4) == 3
+    assert find_new_sentence_left(test_input, 5) == 5
+
+
+def test_find_new_sentence_right():
+    test_input = np.array([0, 0, 1, 0, 1, 0, 0])
+    assert find_new_sentence_right(test_input, 0) == 3
+    assert find_new_sentence_right(test_input, 1) == 3
+    assert find_new_sentence_right(test_input, 2) == 3
+    assert find_new_sentence_right(test_input, 3) == 3
+    assert find_new_sentence_right(test_input, 4) == 5
+    assert find_new_sentence_right(test_input, 5) == 5
+    assert find_new_sentence_right(test_input, 6) is None
+
+
+def test_split_to_samples():
+    min_len = 3
+    max_len = 5
+    test_input = np.array([0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0])
+    expected_output = [
+        np.array([0, 1, 2, 3, 4]),
+        np.array([6, 7, 8, 9, 10])
+    ]
+
+    result = get_batch_indexes(test_input, min_len, max_len)
+    assert len(result) == len(expected_output)
+
+    for got, expected in zip(result, expected_output):
+        assert np.all(got == expected)
+
+
+def test_add_padding():
+    input_sequence = np.array([1, 2, 3, 4])
+
+    # Works with 0 padding
+    result = add_padding(input_sequence, 4, 9)
+    assert len(result) == 4
+    assert np.all(result == input_sequence)
+
+    # Normal use case
+    result = add_padding(input_sequence, 6, 9)
+    assert len(result) == 6
+    assert np.all(result == [1, 2, 3, 4, 9, 9])
+
+    # multidimensional use-case
+    input_sequence = np.array([
+        [1, 2, 3],
+        [4, 5, 6]
+    ])
+    padd = np.array([9, 9, 9])
+    result = add_padding(input_sequence, 4, padd)
+    assert len(result) == 4
+    assert np.all(result == [[1, 2, 3], [4, 5, 6], [9, 9, 9], [9, 9, 9]])
+
+
+def test_add_begin_end_tokens():
+    input_sequence = np.array([1])
+    result = add_begin_end_tokens(input_sequence, 9, 8)
+
+    assert len(result) == 3
+    assert np.all(result == [9, 1, 8])
+
+
+def test_standarize_translation_sample():
+    input_sequence = np.array([1])
+
+    result = standarize_translation_sample(input_sequence, 5, 5, 9, 8)
+
+    assert len(result) == 5
+    assert np.all(result == [9, 1, 8, 5, 5])
+
+
+def test_create_input_output():
+    sequence = [56500, 117, 10824, 30186, 11090, 10113, 119]
+    tokenizer = BertTokenizerFast.from_pretrained(
+        "bert-base-multilingual-cased")
+
+    expected_output_sequence = [tokenizer.cls_token_id, 56500, 117, 10824, 30186, 11090,
+                               10113, 119, tokenizer.sep_token_id, tokenizer.pad_token_id, tokenizer.pad_token_id]
+    expected_input_sequence = [tokenizer.cls_token_id, 21739, 10824, 16469, tokenizer.sep_token_id, tokenizer.pad_token_id,
+                                tokenizer.pad_token_id, tokenizer.pad_token_id, tokenizer.pad_token_id, tokenizer.pad_token_id, tokenizer.pad_token_id]
+
+    result_input, result_output = create_input_output(sequence, 11, tokenizer)
+
+    assert len(result_input) == len(expected_input_sequence)
+    assert len(result_output) == len(expected_output_sequence)
+    assert np.all(expected_input_sequence == result_input)
+    assert np.all(expected_output_sequence == result_output)
+
+
+def test_create_input_output_batch():
+    tokenizer = BertTokenizerFast.from_pretrained(
+        "bert-base-multilingual-cased")
+
+    expected_output_1 = np.array(tokenizer("Ala, ma KoTa.")['input_ids'])[1:-1]
+    expected_output_2 = np.array(tokenizer("A kOt nie!")['input_ids'])[1:-1]
+
+    expected_input_1 = np.array(tokenizer("ala ma kota")['input_ids'])[1:-1]
+    expected_input_2 = np.array(tokenizer("a kot nie")['input_ids'])[1:-1]
+
+    input_sequence = np.concatenate([expected_output_1, expected_output_2])
+    batch_ids = [
+        np.array(list(range(len(expected_output_1)))),
+        np.array(list(range(len(expected_output_2)))) + len(expected_output_1)
+    ]
+
+    expected_input_1 = standarize_translation_sample(expected_input_1, 20, tokenizer.pad_token_id, tokenizer.cls_token_id, tokenizer.sep_token_id)
+    expected_input_2 = standarize_translation_sample(expected_input_2, 20, tokenizer.pad_token_id, tokenizer.cls_token_id, tokenizer.sep_token_id)
+    expected_output_1 = standarize_translation_sample(expected_output_1, 20, tokenizer.pad_token_id, tokenizer.cls_token_id, tokenizer.sep_token_id)
+    expected_output_2 = standarize_translation_sample(expected_output_2, 20, tokenizer.pad_token_id, tokenizer.cls_token_id, tokenizer.sep_token_id)
+
+    result_input, result_output = crete_input_output_batch(input_sequence, batch_ids, 20, tokenizer)
+
+    assert result_input.shape[0] == 2
+    assert result_input.shape[1] == 20
+
+    assert result_output.shape[0] == 2
+    assert result_output.shape[1] == 20
+
+    assert np.all(result_input[0] == expected_input_1)
+    assert np.all(result_input[1] == expected_input_2)
+
+    assert np.all(result_output[0] == expected_output_1)
+    assert np.all(result_output[1] == expected_output_2)
\ No newline at end of file
diff --git a/src/processing.py b/src/processing.py
index e7b302a..8e3501e 100644
--- a/src/processing.py
+++ b/src/processing.py
@@ -6,7 +6,7 @@ import numpy as np
 from transformers import PreTrainedTokenizerFast
 from collections import defaultdict
 
-ACTIONS_KEYS = ['dot', 'upper_case', 'colon', 'semicolon', 'elipsis', 'dash']
+ACTIONS_KEYS = ['dot', 'upper_case', 'colon', 'elipsis', 'dash', 'question_mark']
 
 def empty_action_vector() -> np.ndarray:
     """Returns a do-nothing actions vector
@@ -53,6 +53,7 @@ def detect_actions(word: str, next_word: Optional[str]) -> Mapping[str, bool]:
         Mapping[str, bool]: Mapping telling if each of possible actions should be performed (True) or not (False) 
     """
     # Unsupported characters
+    word = word.replace(";", ".")
     word.replace('"', " ")
     word.replace('(', " ")
     word.replace(')', " ")
@@ -69,9 +70,9 @@ def detect_actions(word: str, next_word: Optional[str]) -> Mapping[str, bool]:
         'dot': word[-1] == '.' and not has_colon,
         'upper_case': word[0].isupper(),
         'colon': word[-1] == ",",
-        'semicolon': word[-1] == ";",
         'elipsis': has_colon,
-        'dash': next_word is not None and next_word == "-"
+        'dash': next_word is not None and next_word == "-",
+        'question_mark': word[-1] == "?"
     }
 
     return actions
@@ -227,12 +228,12 @@ def recover_word(word: str, action: Mapping[str, bool]) -> str:
         word_result = word_result.capitalize()
     if action['colon']:
         word_result += ","
-    if action['semicolon']:
-        word_result += ";"
     if action['elipsis']:
         word_result += "..."
     if action['dash']:
         word_result += " -"
+    if action['question_mark']:
+        word_result += " -"
 
     return word_result
 
-- 
GitLab
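
A quick round-trip sketch for the new question_mark action; the example word is hypothetical and
assumes the remaining actions come out False for it:

# Sketch: detect_actions marks the action, recover_word re-applies it.
from src.processing import detect_actions, recover_word

actions = detect_actions("Naprawdę?", next_word=None)
assert actions["question_mark"] and actions["upper_case"]

print(recover_word("naprawdę", actions))   # -> "Naprawdę?"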


From 0c51e9a8d3d3a5d656e1fd73670308c687bd5e78 Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Tue, 4 Aug 2020 01:40:52 +0200
Subject: [PATCH 028/116] Finished translation pipeline

---
 dvc.yaml                                      |  14 +-
 generated/translations/.gitignore             |   2 +
 notebooks/dask_dataframe_exploration.ipynb    |  14 +-
 params.yaml                                   |  10 +-
 scripts/translation_based/processing.py       |  78 ++++++----
 scripts/translation_based/stage3_exploding.py |  55 ++-----
 .../translation_based/stage4_reindexing.py    |   8 +-
 scripts/translation_based/train.py            | 134 ++++++++++++++++++
 8 files changed, 230 insertions(+), 85 deletions(-)
 create mode 100755 scripts/translation_based/train.py

diff --git a/dvc.yaml b/dvc.yaml
index 3275740..4a20714 100644
--- a/dvc.yaml
+++ b/dvc.yaml
@@ -62,4 +62,16 @@ stages:
     params:
     - global.base_model
     outs:
-    - generated/translations/create_batches
+    - generated/translations/stage2_create_batches
+  translations_exploding:
+    cmd: python3 -m scripts.translation_based.stage3_exploding
+    deps:
+    - generated/translations/stage2_create_batches
+    outs:
+    - generated/translations/stage3_exploding
+  translations_reindexing:
+    cmd: python3 -m scripts.translation_based.stage4_reindexing
+    deps:
+    - generated/translations/stage3_exploding
+    outs:
+    - generated/translations/stage4_reindexing
\ No newline at end of file
diff --git a/generated/translations/.gitignore b/generated/translations/.gitignore
index 19240d7..c31dad5 100644
--- a/generated/translations/.gitignore
+++ b/generated/translations/.gitignore
@@ -1,2 +1,4 @@
 /stage1_extraction
 /stage2_create_batches
+/stage3_exploding
+/stage4_reindexing
diff --git a/notebooks/dask_dataframe_exploration.ipynb b/notebooks/dask_dataframe_exploration.ipynb
index 59f6caf..810287e 100644
--- a/notebooks/dask_dataframe_exploration.ipynb
+++ b/notebooks/dask_dataframe_exploration.ipynb
@@ -23,7 +23,7 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count": 41,
+   "execution_count": 1,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -33,7 +33,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 42,
+   "execution_count": 2,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -42,7 +42,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 43,
+   "execution_count": 3,
    "metadata": {
     "tags": []
    },
@@ -53,7 +53,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 44,
+   "execution_count": 4,
    "metadata": {
     "tags": []
    },
@@ -61,11 +61,11 @@
     {
      "output_type": "execute_result",
      "data": {
-      "text/plain": "                                               source  \\\n15  [2, 27476, 7835, 4677, 2822, 11226, 781, 77, 4...   \n18  [2, 1178, 4607, 766, 7835, 752, 58008, 4419, 5...   \n21  [2, 27476, 7835, 25104, 9712, 6901, 8698, 778,...   \n25  [2, 1645, 2160, 27476, 11811, 7835, 4677, 5657...   \n29  [2, 27476, 7835, 4677, 5529, 77, 10814, 4994, ...   \n\n                                               target source_shape  \\\n15  [2, 15482, 7835, 2931, 18, 7331, 11226, 781, 1...     [2, 500]   \n18  [2, 56453, 7835, 18, 922, 58008, 4419, 5482, 4...   [169, 500]   \n21  [2, 15482, 7835, 18, 38648, 9712, 6901, 8698, ...    [94, 500]   \n25  [2, 1513, 2160, 16, 27476, 11811, 7835, 2931, ...    [25, 500]   \n29  [2, 15482, 7835, 2931, 53234, 77, 6789, 17353,...     [3, 500]   \n\n   target_shape  \n15     [2, 500]  \n18   [169, 500]  \n21    [94, 500]  \n25    [25, 500]  \n29     [3, 500]  ",
-      "text/html": "<div>\n<style scoped>\n    .dataframe tbody tr th:only-of-type {\n        vertical-align: middle;\n    }\n\n    .dataframe tbody tr th {\n        vertical-align: top;\n    }\n\n    .dataframe thead th {\n        text-align: right;\n    }\n</style>\n<table border=\"1\" class=\"dataframe\">\n  <thead>\n    <tr style=\"text-align: right;\">\n      <th></th>\n      <th>source</th>\n      <th>target</th>\n      <th>source_shape</th>\n      <th>target_shape</th>\n    </tr>\n  </thead>\n  <tbody>\n    <tr>\n      <th>15</th>\n      <td>[2, 27476, 7835, 4677, 2822, 11226, 781, 77, 4...</td>\n      <td>[2, 15482, 7835, 2931, 18, 7331, 11226, 781, 1...</td>\n      <td>[2, 500]</td>\n      <td>[2, 500]</td>\n    </tr>\n    <tr>\n      <th>18</th>\n      <td>[2, 1178, 4607, 766, 7835, 752, 58008, 4419, 5...</td>\n      <td>[2, 56453, 7835, 18, 922, 58008, 4419, 5482, 4...</td>\n      <td>[169, 500]</td>\n      <td>[169, 500]</td>\n    </tr>\n    <tr>\n      <th>21</th>\n      <td>[2, 27476, 7835, 25104, 9712, 6901, 8698, 778,...</td>\n      <td>[2, 15482, 7835, 18, 38648, 9712, 6901, 8698, ...</td>\n      <td>[94, 500]</td>\n      <td>[94, 500]</td>\n    </tr>\n    <tr>\n      <th>25</th>\n      <td>[2, 1645, 2160, 27476, 11811, 7835, 4677, 5657...</td>\n      <td>[2, 1513, 2160, 16, 27476, 11811, 7835, 2931, ...</td>\n      <td>[25, 500]</td>\n      <td>[25, 500]</td>\n    </tr>\n    <tr>\n      <th>29</th>\n      <td>[2, 27476, 7835, 4677, 5529, 77, 10814, 4994, ...</td>\n      <td>[2, 15482, 7835, 2931, 53234, 77, 6789, 17353,...</td>\n      <td>[3, 500]</td>\n      <td>[3, 500]</td>\n    </tr>\n  </tbody>\n</table>\n</div>"
+      "text/plain": "                                               source  \\\n15  [2, 27476, 7835, 4677, 2822, 11226, 781, 77, 4...   \n18  [2, 1178, 4607, 766, 7835, 752, 58008, 4419, 5...   \n21  [2, 27476, 7835, 25104, 9712, 6901, 8698, 778,...   \n25  [2, 1645, 2160, 27476, 11811, 7835, 4677, 5657...   \n29  [2, 27476, 7835, 4677, 5529, 77, 10814, 4994, ...   \n\n                                               target  \\\n15  [2, 15482, 7835, 2931, 18, 7331, 11226, 781, 1...   \n18  [2, 56453, 7835, 18, 922, 58008, 4419, 5482, 4...   \n21  [2, 15482, 7835, 18, 38648, 9712, 6901, 8698, ...   \n25  [2, 1513, 2160, 16, 27476, 11811, 7835, 2931, ...   \n29  [2, 15482, 7835, 2931, 53234, 77, 6789, 17353,...   \n\n                                                 mask source_shape  \\\n15  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...     [2, 500]   \n18  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...   [169, 500]   \n21  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...    [94, 500]   \n25  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...    [25, 500]   \n29  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...     [3, 500]   \n\n   target_shape  mask_shape  \n15     [2, 500]    [2, 500]  \n18   [169, 500]  [169, 500]  \n21    [94, 500]   [94, 500]  \n25    [25, 500]   [25, 500]  \n29     [3, 500]    [3, 500]  ",
+      "text/html": "<div>\n<style scoped>\n    .dataframe tbody tr th:only-of-type {\n        vertical-align: middle;\n    }\n\n    .dataframe tbody tr th {\n        vertical-align: top;\n    }\n\n    .dataframe thead th {\n        text-align: right;\n    }\n</style>\n<table border=\"1\" class=\"dataframe\">\n  <thead>\n    <tr style=\"text-align: right;\">\n      <th></th>\n      <th>source</th>\n      <th>target</th>\n      <th>mask</th>\n      <th>source_shape</th>\n      <th>target_shape</th>\n      <th>mask_shape</th>\n    </tr>\n  </thead>\n  <tbody>\n    <tr>\n      <th>15</th>\n      <td>[2, 27476, 7835, 4677, 2822, 11226, 781, 77, 4...</td>\n      <td>[2, 15482, 7835, 2931, 18, 7331, 11226, 781, 1...</td>\n      <td>[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...</td>\n      <td>[2, 500]</td>\n      <td>[2, 500]</td>\n      <td>[2, 500]</td>\n    </tr>\n    <tr>\n      <th>18</th>\n      <td>[2, 1178, 4607, 766, 7835, 752, 58008, 4419, 5...</td>\n      <td>[2, 56453, 7835, 18, 922, 58008, 4419, 5482, 4...</td>\n      <td>[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...</td>\n      <td>[169, 500]</td>\n      <td>[169, 500]</td>\n      <td>[169, 500]</td>\n    </tr>\n    <tr>\n      <th>21</th>\n      <td>[2, 27476, 7835, 25104, 9712, 6901, 8698, 778,...</td>\n      <td>[2, 15482, 7835, 18, 38648, 9712, 6901, 8698, ...</td>\n      <td>[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...</td>\n      <td>[94, 500]</td>\n      <td>[94, 500]</td>\n      <td>[94, 500]</td>\n    </tr>\n    <tr>\n      <th>25</th>\n      <td>[2, 1645, 2160, 27476, 11811, 7835, 4677, 5657...</td>\n      <td>[2, 1513, 2160, 16, 27476, 11811, 7835, 2931, ...</td>\n      <td>[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...</td>\n      <td>[25, 500]</td>\n      <td>[25, 500]</td>\n      <td>[25, 500]</td>\n    </tr>\n    <tr>\n      <th>29</th>\n      <td>[2, 27476, 7835, 4677, 5529, 77, 10814, 4994, ...</td>\n      <td>[2, 15482, 7835, 2931, 53234, 77, 6789, 17353,...</td>\n      <td>[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...</td>\n      <td>[3, 500]</td>\n      <td>[3, 500]</td>\n      <td>[3, 500]</td>\n    </tr>\n  </tbody>\n</table>\n</div>"
      },
      "metadata": {},
-     "execution_count": 44
+     "execution_count": 4
     }
    ],
    "source": [
diff --git a/params.yaml b/params.yaml
index 2de9ed6..c11f4e5 100644
--- a/params.yaml
+++ b/params.yaml
@@ -41,4 +41,12 @@ translations:
         num_workers: 24
         worker_memory_limit: "2GB"
         min_tokens: 10
-        max_tokens: 500
\ No newline at end of file
+        max_tokens: 500
+
+    exploding:
+        num_workers: 24
+        worker_memory_limit: "2GB"
+
+    reindexing:
+        num_workers: 1
+        worker_memory_limit: "60GB"
\ No newline at end of file
diff --git a/scripts/translation_based/processing.py b/scripts/translation_based/processing.py
index 6d00546..941d0ec 100644
--- a/scripts/translation_based/processing.py
+++ b/scripts/translation_based/processing.py
@@ -26,29 +26,6 @@ RAW_TO_DATAFRAME_META = {
     'input': str
 }
 
-
-def apply_tokenization(entry: dict, tokenizer: BertTokenizerFast) -> dict:
-    """Converts raw text entries into list of tokens
-
-    Args:
-        x (dict): Dask dataframe entry with one column ['input'] containing text
-        tokenizer (BertTokenizerFast): Tokenizer used to tokenize. Must be a deleayed object to prevent memory leak!
-
-    Returns:
-        dict: Dask dataset entry with one column ('tokens') containing np.array list of tokens
-    """
-    text_tokenized = tokenizer(entry.input)['input_ids'][1:-1]
-
-    return {
-        'tokens': np.array(text_tokenized)
-    }
-
-
-APPLY_TOKENIZATION_META = {
-    'tokens': object
-}
-
-
 def generate_batches(entry: dict, min_len: int, max_len: int, separating_token: int, tokenizer: BertTokenizerFast) -> dict:
     """Converts raw text entries into list of tokens
 
@@ -65,26 +42,77 @@ def generate_batches(entry: dict, min_len: int, max_len: int, separating_token:
     batch_indices = get_batch_indexes(tokens_ending, min_len, max_len - 2)
     
     source_batch, target_batch = crete_input_output_batch(tokens, batch_indices, max_len, tokenizer)
+    mask_batch = (source_batch != tokenizer.pad_token_id).astype(np.int)
 
     source_batch_shape = np.array(source_batch.shape)
     target_batch_shape = np.array(target_batch.shape)
+    mask_batch_shape = np.array(mask_batch.shape)
 
     source_batch = source_batch.reshape(-1)
     target_batch = target_batch.reshape(-1)
+    mask_batch = mask_batch.reshape(-1)
 
     return {
         'source': source_batch,
         'target': target_batch,
+        'attention_mask': mask_batch,
         'source_shape': source_batch_shape,
-        'target_shape': target_batch_shape
+        'target_shape': target_batch_shape,
+        'attention_mask_shape': mask_batch_shape
     }
 
 
 GENERATE_BATCHES_META = {
     'source': object,
     'target': object,
+    'attention_mask': object,
+    'source_shape': object,
+    'target_shape': object,
+    'attention_mask_shape': object
+}
+
+def expand_dims(entry):
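+    """Restore flattened source/target/mask arrays to their stored 2-D shapes."""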
+    source = entry.source.reshape(entry.source_shape)
+    target = entry.target.reshape(entry.target_shape)
+    mask = entry.attention_mask.reshape(entry.attention_mask_shape)
+
+    return {
+        'source': source,
+        'target': target,
+        "attention_mask": mask,
+    }
+
+EXPAND_DIMS_META = {
+    'source': object, 
+    'target': object,
+    'attention_mask': object
+}
+
+def flatten_dims(entry):
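+    """Flatten arrays back to 1-D, recording each original shape in a companion column."""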
+    source_shape = np.array(entry.source.shape)
+    target_shape = np.array(entry.target.shape)
+    mask_shape = np.array(entry.attention_mask.shape)
+
+    source = entry.source.reshape(-1)
+    target = entry.target.reshape(-1)
+    mask = entry.attention_mask.reshape(-1)
+
+    return {
+        'source': source,
+        'target': target,
+        'attention_mask': mask,
+        'source_shape': source_shape,
+        'target_shape': target_shape,
+        'attention_mask_shape': mask_shape
+    }
+
+FLATTEN_DIMS_META = {
+    'source': object,
+    'target': object,
+    'attention_mask': object,
     'source_shape': object,
-    'target_shape': object
+    'target_shape': object,
+    'attention_mask_shape': object
 }
 
 
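A note on the new attention_mask column above: it marks which positions of the
padded source batch hold real tokens (1) as opposed to the tokenizer's pad
token (0), which is the format BERT's attention_mask expects. A minimal sketch,
assuming a HuggingFace tokenizer and toy token ids (np.int64 stands in for the
now-deprecated np.int alias used in the patch):

    import numpy as np
    from transformers import BertTokenizerFast

    tokenizer = BertTokenizerFast.from_pretrained("bert-base-multilingual-cased")

    # A toy padded batch of two sequences (BERT's pad_token_id is 0).
    source_batch = np.array([
        [101, 2023, 2003,  102,    0,    0],
        [101, 2061, 2003, 2009,  102,    0],
    ])

    # 1 on real tokens, 0 on padding.
    mask_batch = (source_batch != tokenizer.pad_token_id).astype(np.int64)
    print(mask_batch)
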
diff --git a/scripts/translation_based/stage3_exploding.py b/scripts/translation_based/stage3_exploding.py
index 7c0b0d4..f357dcd 100644
--- a/scripts/translation_based/stage3_exploding.py
+++ b/scripts/translation_based/stage3_exploding.py
@@ -1,5 +1,5 @@
 # /usr/bin/python3
-from src.processing import batchify_data
+from scripts.translation_based.processing import flatten_dims, expand_dims, FLATTEN_DIMS_META, EXPAND_DIMS_META
 from dask.diagnostics import ProgressBar
 import dask.dataframe as dd
 import numpy as np
@@ -8,52 +8,13 @@ from dask.distributed import Client
 import pandas as pd
 from src.utils import PROJECT_ROOT, get_config, prepare_folder
 
-INPUT_FOLDER = f"{PROJECT_ROOT}/generated/actions/stage2_tokenization"
-OUTPUT_FOLDER = f"{PROJECT_ROOT}/generated/actions/stage3_exploding"
-
-def expand_dims(entry):
-    inputs = entry.inputs.reshape(entry.input_shape)
-    outputs = entry.outputs.reshape(entry.output_shape)
-    masks = entry.attentions.reshape(entry.attentions_shape)
-
-    return {
-        'inputs': inputs,
-        'outputs': outputs,
-        "attentions": masks,
-    }
-
-def flatten_dims(entry):
-    inputs_shape = np.array(entry.inputs.shape)
-    outputs_shape = np.array(entry.outputs.shape)
-    attentions_shape = np.array(entry.attentions.shape)
-
-    inputs = entry.inputs.reshape(-1)
-    outputs = entry.outputs.reshape(-1)
-    attentions = entry.attentions.reshape(-1)
-
-    return {
-        'inputs': inputs,
-        'outputs': outputs,
-        'attentions': attentions,
-        'inputs_shape': inputs_shape,
-        'outputs_shape': outputs_shape,
-        'attentions_shape': attentions_shape
-    }
-
-
-RESULT_META = {
-    'inputs': object,
-    'outputs': object,
-    'attentions': object,
-    'inputs_shape': object,
-    'outputs_shape': object,
-    'attentions_shape': object
-}
+INPUT_FOLDER = f"{PROJECT_ROOT}/generated/translations/stage2_create_batches"
+OUTPUT_FOLDER = f"{PROJECT_ROOT}/generated/translations/stage3_exploding"
 
 if __name__ == "__main__":
     config = get_config()
-    num_workers = config['actions']['exploding']['num_workers']
-    memory_limit = config['actions']['exploding']['worker_memory_limit']
+    num_workers = config['translations']['exploding']['num_workers']
+    memory_limit = config['translations']['exploding']['worker_memory_limit']
 
     prepare_folder(OUTPUT_FOLDER)
 
@@ -62,8 +23,8 @@ if __name__ == "__main__":
 
     df = dd.read_parquet(INPUT_FOLDER, engine='pyarrow')
 
-    df = df.apply(expand_dims, result_type='expand', axis=1, meta={'inputs': object, 'outputs': object, 'attentions': object})
-    df = df.map_partitions(lambda x: x.apply(lambda y: y.explode(), axis=0), meta={'inputs': object, 'outputs': object, 'attentions': object})
-    df = df.apply(flatten_dims, result_type='expand', axis=1, meta=RESULT_META)
+    df = df.apply(expand_dims, result_type='expand', axis=1, meta=EXPAND_DIMS_META)
+    df = df.map_partitions(lambda x: x.apply(lambda y: y.explode(), axis=0), meta=EXPAND_DIMS_META)
+    df = df.apply(flatten_dims, result_type='expand', axis=1, meta=FLATTEN_DIMS_META)
     
     df.to_parquet(OUTPUT_FOLDER, engine='pyarrow')
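
The expand/explode/flatten round-trip above exists because a Parquet row can
only hold 1-D arrays: each batch is stored flat next to its shape, rehydrated
to 2-D, split into one row per example, then flattened again. A simplified
pandas/NumPy sketch of the idea (the pipeline performs the explode per column
via map_partitions):

    import numpy as np
    import pandas as pd

    # One stored row: a flattened (3, 4) batch next to its shape.
    stored = pd.Series({"source": np.arange(12), "source_shape": np.array([3, 4])})

    # expand_dims: recover the 2-D batch from the flat array.
    batch = stored["source"].reshape(stored["source_shape"])

    # explode: one row per example in the batch.
    df = pd.DataFrame({"source": list(batch)})

    # flatten_dims: flatten each example again, keeping its shape alongside.
    df["source_shape"] = df["source"].apply(lambda a: np.array(a.shape))
    df["source"] = df["source"].apply(lambda a: a.reshape(-1))
    print(df)
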
diff --git a/scripts/translation_based/stage4_reindexing.py b/scripts/translation_based/stage4_reindexing.py
index b436ca2..1974753 100644
--- a/scripts/translation_based/stage4_reindexing.py
+++ b/scripts/translation_based/stage4_reindexing.py
@@ -9,13 +9,13 @@ from dask.distributed import Client
 import pandas as pd
 from src.utils import PROJECT_ROOT, get_config, prepare_folder
 
-INPUT_FOLDER = f"{PROJECT_ROOT}/generated/actions/stage3_exploding"
-OUTPUT_FOLDER = f"{PROJECT_ROOT}/generated/actions/stage4_reindexing"
+INPUT_FOLDER = f"{PROJECT_ROOT}/generated/translations/stage3_exploding"
+OUTPUT_FOLDER = f"{PROJECT_ROOT}/generated/translations/stage4_reindexing"
 
 if __name__ == "__main__":
     config = get_config()
-    num_workers = config['actions']['reindexing']['num_workers']
-    memory_limit = config['actions']['reindexing']['worker_memory_limit']
+    num_workers = config['translations']['reindexing']['num_workers']
+    memory_limit = config['translations']['reindexing']['worker_memory_limit']
 
     prepare_folder(OUTPUT_FOLDER)
 
diff --git a/scripts/translation_based/train.py b/scripts/translation_based/train.py
new file mode 100755
index 0000000..35b17ee
--- /dev/null
+++ b/scripts/translation_based/train.py
@@ -0,0 +1,134 @@
+#!/usr/bin/python3
+
+from transformers import BertTokenizerFast, BertForMaskedLM
+import torch
+from torch.nn import BCEWithLogitsLoss
+import pandas as pd
+import numpy as np
+import dask.dataframe as dd
+import os
+import glob
+from src.utils import PROJECT_ROOT, get_config, convert_to_timedelta, prepare_folder
+from src.processing import ACTIONS_KEYS
+from datetime import datetime
+
+INPUT_PATH = f"{PROJECT_ROOT}/generated/actions/stage4_reindexing"
+OUTPUT_PATH = f"{PROJECT_ROOT}/checkpoints/actions"
+
+if __name__ == "__main__":
+    config = get_config()
+    learning_rate = config['translations']['training']['learning_rate']
+    num_epochs = config['translations']['training']['num_epochs']
+    batch_size = config['translations']['training']['batch_size']
+    save_step = config['translations']['training']['save_step']
+    loss_averaging_span = config['translations']['training']['loss_averaging_span']
+    fresh_start = config['translations']['training']['fresh_start']
+    device_name = config['translations']['training']['device']
+    max_train_time = config['translations']['training']['max_training_time']
+    base_model = config['translations']['base_model']
+
+    prepare_folder(OUTPUT_PATH)
+
+    if max_train_time is not None:
+        max_train_time = convert_to_timedelta(max_train_time)
+
+    device = torch.device(device_name if torch.cuda.is_available() else "cpu")
+    print(f"Training on {device}")
+
+    df = dd.read_parquet(INPUT_PATH, engine="pyarrow")
+    
+    tokenizer = BertTokenizerFast.from_pretrained(base_model)
+
+    model = BertForMaskedLM.from_pretrained(base_model).to(device)
+    criterion = BCEWithLogitsLoss().to(device)
+    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
+
+    epoch_start = 0
+    sample_start = 0
+    if not fresh_start:
+        checkpoint_files = glob.glob(f"{OUTPUT_PATH}/*.model")
+        furthest_epoch = -1
+        furthest_batch_num = -1
+        for checkpoint_file in checkpoint_files:
+            filename = checkpoint_file.split("/")[-1].split(".")[0]
+            if "-" not in filename:
+                continue  # skip non-checkpoint files such as 'final.model'
+            epoch, iteration = filename.split("-")
+            epoch, iteration = int(epoch), int(iteration)
+
+            # Compare (epoch, iteration) pairs lexicographically, so an iteration
+            # count from an older epoch is never paired with a newer epoch
+            if (epoch, iteration) > (furthest_epoch, furthest_batch_num):
+                furthest_epoch = epoch
+                furthest_batch_num = iteration
+
+        if furthest_epoch > -1 and furthest_batch_num > -1:
+            model.load_state_dict(torch.load(f"{OUTPUT_PATH}/{furthest_epoch}-{furthest_batch_num}.model"))
+            optimizer.load_state_dict(torch.load(f"{OUTPUT_PATH}/{furthest_epoch}-{furthest_batch_num}.optimizer"))
+
+            epoch_start, sample_start = furthest_epoch, furthest_batch_num
+            print(f"Loaded {furthest_epoch}-{furthest_batch_num}")
+
+    model.train()
+    losses = []
+
+    training_stopped = False
+
+    time_max = datetime.max
+    if max_train_time is not None:
+        time_max = datetime.now() + max_train_time
+
+    for epoch in range(epoch_start, num_epochs):
+        if training_stopped:
+            break
+
+        i = sample_start
+        while True:
+            # TODO: Change to 0-indexed...
+            data_batch_indexes = list(range(i*batch_size+1, i*batch_size + batch_size +1))
+            
+            # Precomputing the total number of samples takes very long, so let's
+            # try to get the next batch until it fails :)
+            try:
+                data_batch = df.loc[data_batch_indexes].compute()
+            except:
+                # TODO: Specify exception type
+                break
+
+            inputs = data_batch.apply(lambda x: x['inputs'].reshape(x['inputs_shape']), axis=1).values
+            outputs = data_batch.apply(lambda x: x['outputs'].reshape(x['outputs_shape']), axis=1).values
+            attentions_mask = data_batch.apply(lambda x: x['attentions'].reshape(x['attentions_shape']), axis=1).values
+
+            inputs = torch.tensor(np.stack(inputs).squeeze()).to(device)
+            outputs = torch.tensor(np.stack(outputs)).to(device)
+            attentions_mask = torch.tensor(np.stack(attentions_mask)).to(device)
+
+            y_pred = model(input_ids=inputs, attention_mask=attentions_mask)[0]
+
+            loss = criterion(y_pred, outputs)
+
+            losses.append(loss.item())
+            if len(losses) > loss_averaging_span:
+                losses = losses[-loss_averaging_span:]
+
+            print(f'epoch: {epoch} | step: {i} | loss: {np.mean(losses)}')
+
+            optimizer.zero_grad()
+
+            if i % save_step == 0 and (i != sample_start or epoch != epoch_start):
+                print(f"Saving: Epoch {epoch}, step {i}")
+                torch.save(model.state_dict(), f"{OUTPUT_PATH}/{epoch}-{i}.model")
+                torch.save(optimizer.state_dict(), f"{OUTPUT_PATH}/{epoch}-{i}.optimizer")
+
+            if datetime.now() > time_max:
+                print(f"Max time reached, saving: Epoch {epoch}, step {i}")
+                torch.save(model.state_dict(), f"{OUTPUT_PATH}/{epoch}-{i}.model")
+                torch.save(optimizer.state_dict(), f"{OUTPUT_PATH}/{epoch}-{i}.optimizer")
+                training_stopped = True
+                break
+
+            loss.backward()
+            optimizer.step()
+
+            i += 1
+
+    if not training_stopped:
+        torch.save(model.state_dict(), f"{OUTPUT_PATH}/final.model")
+        torch.save(optimizer.state_dict(), f"{OUTPUT_PATH}/final.optimizer")
+
-- 
GitLab
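
A note on the training loop added in this patch: it never precomputes the
dataset length, but keeps requesting the next range of ids from the reindexed
dataframe until a lookup fails. A minimal sketch of that pattern (0-indexed, as
the TODO in the patch suggests), assuming a Dask dataframe with a contiguous
integer index as produced by stage4_reindexing; KeyError is a plausible failure
mode, though the TODO in train.py notes the exact exception type is still
unpinned:

    import dask.dataframe as dd

    def iter_batches(df: dd.DataFrame, batch_size: int, start: int = 0):
        """Yield consecutive batches by integer id until a lookup fails."""
        i = start
        while True:
            ids = list(range(i * batch_size, (i + 1) * batch_size))
            try:
                yield df.loc[ids].compute()
            except KeyError:  # the ids ran past the end of the data
                return
            i += 1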


From f0ce7302918ddcce348a1da9ef7f4de162cb5221 Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Tue, 4 Aug 2020 12:13:12 +0200
Subject: [PATCH 029/116] Translation model v1

---
 .../{test.ipynb => test_actions_model.ipynb}  |   0
 notebooks/test_training_files.ipynb           | 137 ----------------
 notebooks/test_translations_model.ipynb       | 150 ++++++++++++++++++
 notebooks/torch_exploration.ipynb             | 118 ++++++++++++++
 params.yaml                                   |  12 +-
 scripts/translation_based/train.py            |  27 ++--
 6 files changed, 293 insertions(+), 151 deletions(-)
 rename notebooks/{test.ipynb => test_actions_model.ipynb} (100%)
 delete mode 100644 notebooks/test_training_files.ipynb
 create mode 100644 notebooks/test_translations_model.ipynb
 create mode 100644 notebooks/torch_exploration.ipynb

diff --git a/notebooks/test.ipynb b/notebooks/test_actions_model.ipynb
similarity index 100%
rename from notebooks/test.ipynb
rename to notebooks/test_actions_model.ipynb
diff --git a/notebooks/test_training_files.ipynb b/notebooks/test_training_files.ipynb
deleted file mode 100644
index 0296c9e..0000000
--- a/notebooks/test_training_files.ipynb
+++ /dev/null
@@ -1,137 +0,0 @@
-{
- "cells": [
-  {
-   "cell_type": "code",
-   "execution_count": 16,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "from src.processing import text_from_xml, create_model_input_output, token_labels_to_word_labels, recover_text\n",
-    "import glob\n",
-    "import numpy as np\n",
-    "from dask.diagnostics import ProgressBar\n",
-    "import dask.dataframe as dd\n",
-    "import dask\n",
-    "import pandas as pd\n",
-    "from dask.distributed import Client\n",
-    "import gc\n",
-    "from memory_profiler import profile\n",
-    "import pyspark\n",
-    "from pyspark.sql import SparkSession, Row, udf\n",
-    "from pyspark.sql.types import ArrayType, IntegerType\n",
-    "from pyspark.sql.types import StructType, StructField\n",
-    "from pyspark.mllib.linalg import Vectors\n",
-    "from transformers import BertTokenizerFast, BertForTokenClassification\n"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 11,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "INPUT_FOLDER=\"generated/translations/stage2_create_batches\"\n",
-    "MODEL_BASE = \"bert-base-multilingual-cased\"\n",
-    "\n",
-    "tokenizer = BertTokenizerFast.from_pretrained(MODEL_BASE)\n"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 6,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "df = dd.read_parquet(INPUT_FOLDER, engine='pyarrow')"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 7,
-   "metadata": {},
-   "outputs": [
-    {
-     "output_type": "execute_result",
-     "data": {
-      "text/plain": "                                                  inputs  \\\nones                                                       \n92132  [101, 12644, 82233, 10451, 13863, 48616, 10797...   \n\n                                                 outputs  \\\nones                                                       \n92132  [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, ...   \n\n                                              attentions inputs_shape  \\\nones                                                                    \n92132  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...      [50, 1]   \n\n      outputs_shape attentions_shape  \nones                                  \n92132       [50, 6]             [50]  ",
-      "text/html": "<div>\n<style scoped>\n    .dataframe tbody tr th:only-of-type {\n        vertical-align: middle;\n    }\n\n    .dataframe tbody tr th {\n        vertical-align: top;\n    }\n\n    .dataframe thead th {\n        text-align: right;\n    }\n</style>\n<table border=\"1\" class=\"dataframe\">\n  <thead>\n    <tr style=\"text-align: right;\">\n      <th></th>\n      <th>inputs</th>\n      <th>outputs</th>\n      <th>attentions</th>\n      <th>inputs_shape</th>\n      <th>outputs_shape</th>\n      <th>attentions_shape</th>\n    </tr>\n    <tr>\n      <th>ones</th>\n      <th></th>\n      <th></th>\n      <th></th>\n      <th></th>\n      <th></th>\n      <th></th>\n    </tr>\n  </thead>\n  <tbody>\n    <tr>\n      <th>92132</th>\n      <td>[101, 12644, 82233, 10451, 13863, 48616, 10797...</td>\n      <td>[0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, ...</td>\n      <td>[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...</td>\n      <td>[50, 1]</td>\n      <td>[50, 6]</td>\n      <td>[50]</td>\n    </tr>\n  </tbody>\n</table>\n</div>"
-     },
-     "metadata": {},
-     "execution_count": 7
-    }
-   ],
-   "source": [
-    "sample = df.loc[92132].compute()\n",
-    "sample"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 20,
-   "metadata": {
-    "tags": []
-   },
-   "outputs": [
-    {
-     "output_type": "stream",
-     "name": "stdout",
-     "text": "(50, 6)\n"
-    },
-    {
-     "output_type": "error",
-     "ename": "AssertionError",
-     "evalue": "",
-     "traceback": [
-      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
-      "\u001b[0;31mAssertionError\u001b[0m                            Traceback (most recent call last)",
-      "\u001b[0;32m<ipython-input-20-b333c374ea16>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      6\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msample_outputs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      7\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 8\u001b[0;31m \u001b[0mlabels_pred\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtoken_labels_to_word_labels\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtext_clean\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msample_outputs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtokenizer\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      9\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     10\u001b[0m \u001b[0mactions\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlabels_pred\u001b[0m \u001b[0;34m>\u001b[0m \u001b[0;36m0.1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
-      "\u001b[0;32m~/projekty/clarin/interpunkcja/src/processing.py\u001b[0m in \u001b[0;36mtoken_labels_to_word_labels\u001b[0;34m(text, token_labels, tokenizer)\u001b[0m\n\u001b[1;32m    165\u001b[0m     \u001b[0mmapping\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtoken_word_mapping\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtext\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtokenizer\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    166\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 167\u001b[0;31m     \u001b[0;32massert\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmapping\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtoken_labels\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    168\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    169\u001b[0m     \u001b[0mlabels\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdefaultdict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlist\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
-      "\u001b[0;31mAssertionError\u001b[0m: "
-     ]
-    }
-   ],
-   "source": [
-    "sample_inputs = sample['inputs'].values[0].reshape((50))\n",
-    "sample_outputs = sample['outputs'].values[0].reshape((50, 6))\n",
-    "sample_attentions = sample['outputs'].values[0].reshape((50))\n",
-    "\n",
-    "length = np.sum(sample_attentions)\n",
-    "\n",
-    "text_clean = tokenizer.decode(sample_inputs)\n",
-    "\n",
-    "labels_pred = token_labels_to_word_labels(text_clean, sample_outputs[1:-1, :], tokenizer)\n",
-    "\n",
-    "actions = labels_pred > 0.1\n",
-    "recover_text(text_clean, actions)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": []
-  }
- ],
- "metadata": {
-  "language_info": {
-   "codemirror_mode": {
-    "name": "ipython",
-    "version": 3
-   },
-   "file_extension": ".py",
-   "mimetype": "text/x-python",
-   "name": "python",
-   "nbconvert_exporter": "python",
-   "pygments_lexer": "ipython3",
-   "version": "3.8.2-final"
-  },
-  "orig_nbformat": 2,
-  "kernelspec": {
-   "name": "python38264bita7d7da14168440cb9836372958035d4a",
-   "display_name": "Python 3.8.2 64-bit"
-  }
- },
- "nbformat": 4,
- "nbformat_minor": 2
-}
\ No newline at end of file
diff --git a/notebooks/test_translations_model.ipynb b/notebooks/test_translations_model.ipynb
new file mode 100644
index 0000000..618327d
--- /dev/null
+++ b/notebooks/test_translations_model.ipynb
@@ -0,0 +1,150 @@
+{
+ "metadata": {
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.8.2-final"
+  },
+  "orig_nbformat": 2,
+  "kernelspec": {
+   "name": "python38264bita7d7da14168440cb9836372958035d4a",
+   "display_name": "Python 3.8.2 64-bit"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2,
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 45,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import sys\n",
+    "sys.path.append(\"../\")\n",
+    "\n",
+    "from transformers import BertTokenizerFast, BertForMaskedLM\n",
+    "import torch\n",
+    "from torch.nn import BCEWithLogitsLoss\n",
+    "import pandas as pd\n",
+    "import numpy as np\n",
+    "import dask.dataframe as dd\n",
+    "\n",
+    "from src.processing import create_model_input_output"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 46,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "MODEL_BASE = \"dkleczek/bert-base-polish-cased-v1\""
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 47,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "tokenizer = BertTokenizerFast.from_pretrained(MODEL_BASE)\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 48,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "expected = \"Dekretem polskiej władzy państwowej stworzono na wyzwolonym obszarze Rzeczypospolitej specjalny sąd karny, w którego kompetencje wchodzą sprawy o zdrady narodu polskiego. Przed sądem stanęli renegaci, którzy nie tylko wyrzekli się polskości, ale postępowaniem swym czynnie pomagali Niemcom w ich zbrodniach. Podajemy fragmenty z rozprawy przeciwko folksdojczowi Musialskiemu, który jako kierownik niemieckiego obozu pracy znęcał się nad obywatelami polskimi. Oskarżony Musielski. Świadek Jankowska opowiedziała, jak bił on Polaków po twarzy, kopał i groził Majdankiem. Świadek Stankiewicz stwierdził, że Musielski przewyższył swym okrucieństwem poprzednich kierowników obozu, Niemców. Prokurator doktor Sawicki zażądał dla oskarżonego kary śmierci. Po naradzie sąd skazał Musielskiego na karę śmierci przez powieszenie.\"\n",
+    "text_clean = create_model_input_output(expected)[0]\n",
+    "\n",
+    "inputs = tokenizer(text_clean, return_tensors=\"pt\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 49,
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [
+    {
+     "output_type": "stream",
+     "name": "stderr",
+     "text": "Some weights of the model checkpoint at bert-base-multilingual-cased were not used when initializing BertForTokenClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']\n- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).\n- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\nSome weights of BertForTokenClassification were not initialized from the model checkpoint at bert-base-multilingual-cased and are newly initialized: ['classifier.weight', 'classifier.bias']\nYou should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
+    },
+    {
+     "output_type": "execute_result",
+     "data": {
+      "text/plain": "<All keys matched successfully>"
+     },
+     "metadata": {},
+     "execution_count": 49
+    }
+   ],
+   "source": [
+    "model = BertForTokenClassification.from_pretrained(MODEL_BASE, num_labels=6)\n",
+    "model.load_state_dict(torch.load(\"../models/actionv1_500-0-5000.model\", map_location={'cuda:0': 'cpu'}))\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 50,
+   "metadata": {},
+   "outputs": [
+    {
+     "output_type": "execute_result",
+     "data": {
+      "text/plain": "{'input_ids': tensor([[   101,  18910,  34942,  10147,  41458,  49264,  86412,  14196,  28780,\n          16828,  19435,  12507,  10132,    191,  10157,  10305,  16828,  47472,\n          10147,  51043,  53549,  10157,  22815,  29378,  42606,  32650,  15953,\n          24011,  10756,  13130,  10162,  25085,  10756,    191,  21619,  12240,\n          19094,  10136,  17528,    191,  82689,  59763,  10157,    183,    194,\n          14951,  12355,  77029,  10138,  49200,  18795,  13130,  16050,  45244,\n          58814,  63256,  19172,  10598,  30214,  11058,  18933,    191,  20728,\n          25983,  10390,  10424,  72058,  15083,  11372,  11841,  91496,  85421,\n            187,  21633,  22427,  57753,  10514,  96760,  10390,  44227,  22530,\n            191,  12979,    194,  20923,  63168,  10269,  11202,  19234,  14996,\n          39144,  10157,    194,  25470,  10305,  91865,  39371,  18694,  82836,\n          21799,  20868,  22578,  87985,  20162,  21791,  10138,  13680,  10701,\n         107626,  80711,  92859,  27648,    194,  31223,  10425,  11133,  10424,\n          12060,  17339,  70500,  87907,  10500,  72325,  10116,  10427,  15190,\n          36701,  87985,  18338,  10506,  18996, 103412,  10174,  63923,  72275,\n          11485,  10303,  28612, 110206,  10113,  13050,  11342,  11133,  10135,\n          14151,  16036,  10514,  37975,  27828,  39268,  16251,    177,  30518,\n          20129,  21617,  14991,  30022,  13711,  18996, 103412,  10174,  45244,\n          62872,  28780,  79534,  12537,  87985,  18338,  10506,  20157,  82351,\n          10157,  61610,  11133,    187,  21633,  14302,  11680,  12097,  14194,\n          82775,  13717,  23090, 108605, 107626,  10644,  92859,  44227,  22064,\n          11284,  96858,  11813,  43307,  17112,  84280,  10339,  67464,  40384,\n          12197,  10427,  15190,  61187,  10797,  25085,  10157,  26584,  10514,\n          90086, 102984,  13130,  10162,  31569,  34105,  87985,  18338,  18761,\n          10132,  25085,  10963,  26584,  11048,  10514,  52784,  21620,    102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n         0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
1,\n         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n         1, 1, 1, 1, 1, 1, 1, 1, 1]])}"
+     },
+     "metadata": {},
+     "execution_count": 50
+    }
+   ],
+   "source": [
+    "inputs"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 51,
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [
+    {
+     "output_type": "stream",
+     "name": "stdout",
+     "text": "Dekretem polskiej władzy państwowej stworzono na wyzwolonym obszarze Rzeczypospolitej specjalny sąd karny, w którego kompetencje wchodzą sprawy o zdrady narodu polskiego. Przed sądem stanęli renegaci, którzy nie tylko wyrzekli się polskości, ale postępowaniem swym czynnie pomagali Niemcom w ich zbrodniach. Podajemy fragmenty z rozprawy przeciwko folksdojczowi Musialskiemu, który jako kierownik niemieckiego obozu pracy znęcał się nad obywatelami polskimi. Oskarżony Musielski. Świadek Jankowska opowiedziała, jak bił on Polaków po twarzy, kopał i groził Majdankiem. Świadek Stankiewicz stwierdził, że Musielski przewyższył swym okrucieństwem poprzednich kierowników obozu, Niemców. Prokurator doktor Sawicki zażądał dla oskarżonego kary śmierci. Po naradzie sąd skazał Musielskiego na karę śmierci przez powieszenie.\n------\nDekretem polskiej władzy państwowej stworzono na wyzwolonym obszarze Rzeczypospolitej specjalny sąd karny, w którego kompetencje wchodzą sprawy o zdrady narodu polskiego. Przed sądem stanęli renegaci, którzy nie tylko wyrzekli się polskości, ale postępowaniem swym czynnie pomagali Niemcom w ich zbrodniach. Podajemy fragmenty z rozprawy przeciwko folksdojczowi Musialskiemu, który jako kierownik Niemieckiego Obozu pracy znęcał się nad obywatelami polskimi. oskarżony musielski świadek Jankowska opowiedziała, jak bił on Polaków po twarzy kopał i groził majdankiem świadek. Stankiewicz stwierdził, że Musielski przewyższył swym okrucieństwem poprzednich kierowników obozu Niemców. Prokurator doktor Sawicki zażądał dla oskarżonego kary śmierci. Po naradzie sąd skazał Musielskiego na karę śmierci przez powieszenie.\n"
+    }
+   ],
+   "source": [
+    "from src.processing import token_labels_to_word_labels, recover_text\n",
+    "\n",
+    "y_pred = model(**inputs)[0].sigmoid()\n",
+    "labels_pred = token_labels_to_word_labels(text_clean, y_pred.detach().numpy()[0, 1:-1, :], tokenizer)\n",
+    "\n",
+    "actions = labels_pred > 0.5\n",
+    "print(expected)\n",
+    "print(\"------\")\n",
+    "print(recover_text(text_clean, actions))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ]
+}
\ No newline at end of file
diff --git a/notebooks/torch_exploration.ipynb b/notebooks/torch_exploration.ipynb
new file mode 100644
index 0000000..58f68ee
--- /dev/null
+++ b/notebooks/torch_exploration.ipynb
@@ -0,0 +1,118 @@
+{
+ "metadata": {
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.8.2-final"
+  },
+  "orig_nbformat": 2,
+  "kernelspec": {
+   "name": "python38264bita7d7da14168440cb9836372958035d4a",
+   "display_name": "Python 3.8.2 64-bit"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2,
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import torch"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 62,
+   "metadata": {},
+   "outputs": [
+    {
+     "output_type": "error",
+     "ename": "RuntimeError",
+     "evalue": "Could not run 'aten::scatter_.value' with arguments from the 'SparseCPUTensorId' backend. 'aten::scatter_.value' is only available for these backends: [CPUTensorId, CUDATensorId, VariableTensorId].",
+     "traceback": [
+      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+      "\u001b[0;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
+      "\u001b[0;32m<ipython-input-62-9bd81c403b2e>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      5\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      6\u001b[0m \u001b[0monehot\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msparse\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mFloatTensor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m3\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m5\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzero_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 7\u001b[0;31m \u001b[0monehot\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mscatter_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      8\u001b[0m \u001b[0monehot\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+      "\u001b[0;31mRuntimeError\u001b[0m: Could not run 'aten::scatter_.value' with arguments from the 'SparseCPUTensorId' backend. 'aten::scatter_.value' is only available for these backends: [CPUTensorId, CUDATensorId, VariableTensorId]."
+     ]
+    }
+   ],
+   "source": [
+    "i = torch.LongTensor([\n",
+    "    [1, 2, 3],\n",
+    "    [4, 4, 4]\n",
+    "]).unsqueeze(-1)\n",
+    "\n",
+    "onehot = torch.sparse.FloatTensor(2, 3, 5).zero_()\n",
+    "onehot.scatter_(2, i, 1)\n",
+    "onehot"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 77,
+   "metadata": {},
+   "outputs": [
+    {
+     "output_type": "execute_result",
+     "data": {
+      "text/plain": "tensor([[1.3133, 1.3133, 1.3133],\n        [0.3133, 0.3133, 0.3133]])"
+     },
+     "metadata": {},
+     "execution_count": 77
+    }
+   ],
+   "source": [
+    "got = torch.tensor([\n",
+    "    [[0, 1], [0, 1], [1, 0]],\n",
+    "    [[0, 1], [0, 1], [1, 0]]\n",
+    "], dtype=torch.float)\n",
+    "\n",
+    "target = torch.tensor([\n",
+    "    [0, 0, 1],\n",
+    "    [1, 1, 0]\n",
+    "])\n",
+    "\n",
+    "got.transpose_(1, 2)\n",
+    "\n",
+    "loss = torch.nn.CrossEntropyLoss(reduction='none')\n",
+    "loss(got, target)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 79,
+   "metadata": {},
+   "outputs": [
+    {
+     "output_type": "execute_result",
+     "data": {
+      "text/plain": "tensor([[0.2689, 0.7311],\n        [0.5000, 0.5000]])"
+     },
+     "metadata": {},
+     "execution_count": 79
+    }
+   ],
+   "source": [
+    "torch.tensor([[1.0, 2.0], [5.0, 5.0]]).softmax(-1)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ]
+}
\ No newline at end of file
diff --git a/params.yaml b/params.yaml
index c11f4e5..ae33122 100644
--- a/params.yaml
+++ b/params.yaml
@@ -49,4 +49,14 @@ translations:
 
     reindexing:
         num_workers: 1
-        worker_memory_limit: "60GB"
\ No newline at end of file
+        worker_memory_limit: "60GB"
+
+    training:
+        learning_rate: 0.001
+        num_epochs: 5
+        batch_size: 1
+        save_step: 1000
+        max_training_time: null
+        loss_averaging_span: 1000
+        fresh_start: false
+        device: "cuda:0"
\ No newline at end of file
diff --git a/scripts/translation_based/train.py b/scripts/translation_based/train.py
index 35b17ee..235d995 100755
--- a/scripts/translation_based/train.py
+++ b/scripts/translation_based/train.py
@@ -12,8 +12,8 @@ from src.utils import PROJECT_ROOT, get_config, convert_to_timedelta, prepare_fo
 from src.processing import ACTIONS_KEYS
 from datetime import datetime
 
-INPUT_PATH = f"{PROJECT_ROOT}/generated/actions/stage4_reindexing"
-OUTPUT_PATH = f"{PROJECT_ROOT}/checkpoints/actions"
+INPUT_PATH = f"{PROJECT_ROOT}/generated/translations/stage4_reindexing"
+OUTPUT_PATH = f"{PROJECT_ROOT}/checkpoints/translations"
 
 if __name__ == "__main__":
     config = get_config()
@@ -25,7 +25,7 @@ if __name__ == "__main__":
     fresh_start = config['translations']['training']['fresh_start']
     device_name = config['translations']['training']['device']
     max_train_time = config['translations']['training']['max_training_time']
-    base_model = config['translations']['base_model']
+    base_model = config['global']['base_model']
 
     prepare_folder(OUTPUT_PATH)
 
@@ -40,7 +40,7 @@ if __name__ == "__main__":
     tokenizer = BertTokenizerFast.from_pretrained(base_model)
 
     model = BertForMaskedLM.from_pretrained(base_model).to(device)
-    criterion = BCEWithLogitsLoss().to(device)
+    criterion = torch.nn.CrossEntropyLoss(reduction="mean").to(device)
     optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
 
     epoch_start = 0
@@ -79,9 +79,9 @@ if __name__ == "__main__":
             break
 
         i = sample_start
+        
         while True:
-            # TODO: Change to 0-indexed...
-            data_batch_indexes = list(range(i*batch_size+1, i*batch_size + batch_size +1))
+            data_batch_indexes = list(range(i*batch_size, i*batch_size + batch_size))
             
             # Precomputing the total number of samples takes very long, so let's
             # try to get the next batch until it fails :)
@@ -91,17 +91,18 @@ if __name__ == "__main__":
                 # TODO: Specify exception type
                 break
 
-            inputs = data_batch.apply(lambda x: x['inputs'].reshape(x['inputs_shape']), axis=1).values
-            outputs = data_batch.apply(lambda x: x['outputs'].reshape(x['outputs_shape']), axis=1).values
-            attentions_mask = data_batch.apply(lambda x: x['attentions'].reshape(x['attentions_shape']), axis=1).values
+            inputs = data_batch.apply(lambda x: x['source'].reshape(x['source_shape']), axis=1).values
+            outputs = data_batch.apply(lambda x: x['target'].reshape(x['target_shape']), axis=1).values
+            attentions_mask = data_batch.apply(lambda x: x['attention_mask'].reshape(x['attention_mask_shape']), axis=1).values
 
-            inputs = torch.tensor(np.stack(inputs).squeeze()).to(device)
-            outputs = torch.tensor(np.stack(outputs)).to(device)
-            attentions_mask = torch.tensor(np.stack(attentions_mask)).to(device)
+            inputs = torch.tensor(np.stack(inputs, axis=0)).to(device)
+            attentions_mask = torch.tensor(np.stack(attentions_mask, axis=0)).to(device)
+            output_indices = torch.tensor(np.stack(outputs, axis=0)).to(device)
 
             y_pred = model(input_ids=inputs, attention_mask=attentions_mask)[0]
+            y_pred = y_pred.transpose(1, 2)
 
-            loss = criterion(y_pred, outputs)
+            loss = criterion(y_pred, output_indices)
 
             losses.append(loss.item())
             if len(losses) > loss_averaging_span:
-- 
GitLab
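
The core of the loss change in this patch: for sequence outputs,
torch.nn.CrossEntropyLoss expects logits of shape (batch, num_classes, seq_len)
and integer targets of shape (batch, seq_len), which is why train.py now
transposes the model output before computing the loss. A minimal sketch with
toy dimensions:

    import torch
    from torch import nn

    batch, seq_len, vocab = 2, 7, 11

    # Model heads usually emit (batch, seq_len, vocab)...
    logits = torch.randn(batch, seq_len, vocab)
    # ...but CrossEntropyLoss wants the class axis second.
    targets = torch.randint(0, vocab, (batch, seq_len))

    criterion = nn.CrossEntropyLoss(reduction="mean")
    loss = criterion(logits.transpose(1, 2), targets)
    print(loss.item())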


From 84f5f3179d49ee00db4349e4a4e8e686cb80978c Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Tue, 4 Aug 2020 12:15:49 +0200
Subject: [PATCH 030/116] Added translation training to dag

---
 dvc.yaml | 16 +++++++++++++++-
 1 file changed, 15 insertions(+), 1 deletion(-)

diff --git a/dvc.yaml b/dvc.yaml
index 4a20714..51001a5 100644
--- a/dvc.yaml
+++ b/dvc.yaml
@@ -74,4 +74,18 @@ stages:
     deps:
     - generated/translations/stage3_exploding
     outs:
-    - generated/translations/stage4_reindexing
\ No newline at end of file
+    - generated/translations/stage4_reindexing
+  translations_training:
+    cmd: python3 -m scripts.translation_based.train
+    deps:
+    - generated/translations/stage4_reindexing
+    - scripts/translation_based/train.py
+    params:
+    - global.base_model
+    - translations.training.max_training_time
+    - translations.training.learning_rate
+    - translations.training.num_epochs
+    - translations.training.batch_size
+    - translations.training.save_step
+    outs:
+    - checkpoints/translations
\ No newline at end of file
-- 
GitLab
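
The checkpoints/translations output declared above is also what train.py scans
when fresh_start is false, looking for the newest '{epoch}-{step}.model' file.
A minimal sketch of such a scan, assuming that naming scheme (latest_checkpoint
is an illustrative helper, not part of the repo):

    import glob
    import os

    def latest_checkpoint(folder: str):
        """Return (epoch, step) of the newest '{epoch}-{step}.model' file, or None."""
        best = None
        for path in glob.glob(f"{folder}/*.model"):
            name = os.path.basename(path).split(".")[0]
            if "-" not in name:
                continue  # e.g. 'final.model'
            epoch, step = (int(part) for part in name.split("-"))
            if best is None or (epoch, step) > best:
                best = (epoch, step)
        return best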


From a9472e19c2c63d2e68a5977e4e831758a1b3a0df Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Tue, 4 Aug 2020 16:33:49 +0000
Subject: [PATCH 031/116] Encoder-only model didn't work. Implemented full
 encoder-decoder

---
 .gitignore                                    |   1 +
 dvc.lock                                      |  29 +++-
 notebooks/torch_transformer.ipynb             | 133 ++++++++++++++++++
 params.yaml                                   |  14 +-
 scripts/actions_based/train.py                |   1 +
 .../stage2_create_batches.py                  |   1 +
 scripts/translation_based/train.py            |  10 +-
 src/models/TransformerSeq2Seq.py              |  92 ++++++++++++
 src/models/__init__.py                        |   0
 9 files changed, 269 insertions(+), 12 deletions(-)
 create mode 100644 notebooks/torch_transformer.ipynb
 create mode 100644 src/models/TransformerSeq2Seq.py
 create mode 100644 src/models/__init__.py

diff --git a/.gitignore b/.gitignore
index 1272d38..fec9ca9 100644
--- a/.gitignore
+++ b/.gitignore
@@ -9,3 +9,4 @@ dataset_actions
 __pycache__
 .pytest_cache
 /checkpoints
+.dvc
diff --git a/dvc.lock b/dvc.lock
index 54498b9..8780ddc 100644
--- a/dvc.lock
+++ b/dvc.lock
@@ -55,7 +55,7 @@ translations_extraction:
       translations.extraction.num_partitions: 2000
   outs:
   - path: generated/translations/stage1_extraction
-    md5: 61a1a88c672e485fd9b0dc0ef22817a9.dir
+    md5: c7f5bb265082fdd21b8936ddca14a8ab.dir
 translations_tokenization:
   cmd: python3 -m scripts.translation_based.stage2_tokenization
   deps:
@@ -133,3 +133,30 @@ actions_training:
   outs:
   - path: checkpoints/actions
     md5: 6116b19bae31f503a350b635125a6daf.dir
+translations_create_batches:
+  cmd: python3 -m scripts.translation_based.stage2_create_batches
+  deps:
+  - path: generated/translations/stage1_extraction
+    md5: c7f5bb265082fdd21b8936ddca14a8ab.dir
+  params:
+    params.yaml:
+      global.base_model: dkleczek/bert-base-polish-cased-v1
+  outs:
+  - path: generated/translations/stage2_create_batches
+    md5: 730e90598dac106a9088eb0906caa227.dir
+translations_exploding:
+  cmd: python3 -m scripts.translation_based.stage3_exploding
+  deps:
+  - path: generated/translations/stage2_create_batches
+    md5: 730e90598dac106a9088eb0906caa227.dir
+  outs:
+  - path: generated/translations/stage3_exploding
+    md5: 918ba496477757257e702b12da0ef21e.dir
+translations_reindexing:
+  cmd: python3 -m scripts.translation_based.stage4_reindexing
+  deps:
+  - path: generated/translations/stage3_exploding
+    md5: 918ba496477757257e702b12da0ef21e.dir
+  outs:
+  - path: generated/translations/stage4_reindexing
+    md5: caa09e33b141187800d330ab131a45e0.dir
diff --git a/notebooks/torch_transformer.ipynb b/notebooks/torch_transformer.ipynb
new file mode 100644
index 0000000..e8d69dd
--- /dev/null
+++ b/notebooks/torch_transformer.ipynb
@@ -0,0 +1,133 @@
+{
+ "metadata": {
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": 3
+  },
+  "orig_nbformat": 2,
+  "kernelspec": {
+   "name": "python_defaultSpec_1596544773362",
+   "display_name": "Python 3.8.2 64-bit"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2,
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import torch\n",
+    "import torch.nn as nn\n",
+    "import math"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 14,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class PositionalEncoding(nn.Module):\n",
+    "    def __init__(self, d_model, dropout=0.1, max_len=5000):\n",
+    "        super(PositionalEncoding, self).__init__()\n",
+    "        self.dropout = nn.Dropout(p=dropout)\n",
+    "\n",
+    "        pe = torch.zeros(max_len, d_model)\n",
+    "        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)\n",
+    "        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))\n",
+    "        pe[:, 0::2] = torch.sin(position * div_term)\n",
+    "        pe[:, 1::2] = torch.cos(position * div_term)\n",
+    "        pe = pe.unsqueeze(0).transpose(0, 1)\n",
+    "        self.register_buffer('pe', pe)\n",
+    "\n",
+    "    def forward(self, x):\n",
+    "        x = x + self.pe[:x.size(0), :]\n",
+    "        return self.dropout(x)\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 63,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class Transformer(nn.Module):\n",
+    "    def __init__(self, vocab_size: int = 5, embedding_size: int = 16, num_heads: int = 4, encoder_layers: int = 1, decoder_layers: int = 1, feedforward_neurons: int = 100, dropout: float = 0.1, max_len: int= 10):\n",
+    "        super(Transformer, self).__init__()\n",
+    "\n",
+    "        self.word_embedding = nn.Embedding(vocab_size, embedding_size)\n",
+    "        self.position_embedding = PositionalEncoding(embedding_size, dropout, max_len)\n",
+    "        self.core = nn.Transformer(embedding_size, num_heads, encoder_layers, decoder_layers, feedforward_neurons, dropout)\n",
+    "        self.embedding_to_words = nn.Linear(embedding_size, vocab_size)\n",
+    "\n",
+    "    def forward(self, source, target, source_mask):\n",
+    "        x = source.transpose(0, 1)\n",
+    "        x = self.word_embedding(x)\n",
+    "        x = self.position_embedding(x)\n",
+    "\n",
+    "        y = target.transpose(0, 1)\n",
+    "        y = self.word_embedding(y)\n",
+    "        y = self.position_embedding(y)\n",
+    "\n",
+    "        tgt_mask = self.core.generate_square_subsequent_mask(y.shape[0])\n",
+    "\n",
+    "        print(tgt_mask.shape)\n",
+    "\n",
+    "        z = self.core(x, y, src_key_padding_mask=source_mask, tgt_mask=tgt_mask).transpose(1, 0)\n",
+    "        z = self.embedding_to_words(z)\n",
+    "\n",
+    "        return z"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 66,
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [
+    {
+     "output_type": "stream",
+     "name": "stdout",
+     "text": "torch.Size([3, 3])\n"
+    },
+    {
+     "output_type": "execute_result",
+     "data": {
+      "text/plain": "torch.Size([2, 3, 5])"
+     },
+     "metadata": {},
+     "execution_count": 66
+    }
+   ],
+   "source": [
+    "transformer = Transformer()\n",
+    "\n",
+    "example_batch = torch.randint(0, 5, (2, 4))\n",
+    "example_target = torch.randint(0, 5, (2, 4))\n",
+    "\n",
+    "source_mask = torch.ones_like(example_batch, dtype=torch.uint8) == 0\n",
+    "\n",
+    "transformer(example_batch, example_target[:, :-1], source_mask).shape\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ]
+}
\ No newline at end of file
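
The tgt_mask probed in the notebook above comes from
nn.Transformer.generate_square_subsequent_mask, which stops each decoder
position from attending to later positions. A minimal sketch of the equivalent
mask, assuming PyTorch's additive -inf masking convention:

    import torch

    def subsequent_mask(size: int) -> torch.Tensor:
        # 0.0 on and below the diagonal (may attend), -inf strictly above (blocked).
        return torch.triu(torch.full((size, size), float("-inf")), diagonal=1)

    print(subsequent_mask(3))
    # tensor([[0., -inf, -inf],
    #         [0., 0., -inf],
    #         [0., 0., 0.]])
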
diff --git a/params.yaml b/params.yaml
index ae33122..d399494 100644
--- a/params.yaml
+++ b/params.yaml
@@ -40,8 +40,8 @@ translations:
     create_batches:
         num_workers: 24
         worker_memory_limit: "2GB"
-        min_tokens: 10
-        max_tokens: 500
+        min_tokens: 5
+        max_tokens: 300
 
     exploding:
         num_workers: 24
@@ -52,11 +52,11 @@ translations:
         worker_memory_limit: "60GB"
 
     training:
-        learning_rate: 0.001
+        learning_rate: 0.0001
         num_epochs: 5
-        batch_size: 1
+        batch_size: 10
         save_step: 1000
-        max_training_time: null
+        max_training_time: "4h"
         loss_averaging_span: 1000
-        fresh_start: false
-        device: "cuda:0"
\ No newline at end of file
+        fresh_start: true
+        device: "cuda:1"
\ No newline at end of file
diff --git a/scripts/actions_based/train.py b/scripts/actions_based/train.py
index 46abed9..1530f49 100755
--- a/scripts/actions_based/train.py
+++ b/scripts/actions_based/train.py
@@ -66,6 +66,7 @@ if __name__ == "__main__":
             print(f"Loaded {furthest_epoch}-{furthest_batch_num}")
 
     model.train()
+    model.base_model.train()
     losses = []
 
     training_stopped = False
diff --git a/scripts/translation_based/stage2_create_batches.py b/scripts/translation_based/stage2_create_batches.py
index a7f110e..c7719f3 100644
--- a/scripts/translation_based/stage2_create_batches.py
+++ b/scripts/translation_based/stage2_create_batches.py
@@ -31,6 +31,7 @@ if __name__ == "__main__":
 
     df = dd.read_parquet(INPUT_FOLDER, engine="pyarrow")
     df = df.apply(generate_batches, result_type='expand', axis=1, meta=GENERATE_BATCHES_META, args=(min_tokens, max_tokens, token_separating, tokenizer))
+    df = df.dropna()
 
     # Export
     df.to_parquet(OUTPUT_FOLDER, engine="pyarrow")
\ No newline at end of file
diff --git a/scripts/translation_based/train.py b/scripts/translation_based/train.py
index 235d995..0ba36f2 100755
--- a/scripts/translation_based/train.py
+++ b/scripts/translation_based/train.py
@@ -11,6 +11,7 @@ import glob
 from src.utils import PROJECT_ROOT, get_config, convert_to_timedelta, prepare_folder
 from src.processing import ACTIONS_KEYS
 from datetime import datetime
+from src.models.TransformerSeq2Seq import TransformerSeq2Seq
 
 INPUT_PATH = f"{PROJECT_ROOT}/generated/translations/stage4_reindexing"
 OUTPUT_PATH = f"{PROJECT_ROOT}/checkpoints/translations"
@@ -18,6 +19,7 @@ OUTPUT_PATH = f"{PROJECT_ROOT}/checkpoints/translations"
 if __name__ == "__main__":
     config = get_config()
     learning_rate = config['translations']['training']['learning_rate']
+    max_len = config['translations']['create_batches']['max_tokens']
     num_epochs = config['translations']['training']['num_epochs']
     batch_size = config['translations']['training']['batch_size']
     save_step = config['translations']['training']['save_step']
@@ -39,7 +41,7 @@ if __name__ == "__main__":
     
     tokenizer = BertTokenizerFast.from_pretrained(base_model)
 
-    model = BertForMaskedLM.from_pretrained(base_model).to(device)
+    model = TransformerSeq2Seq(tokenizer.vocab_size, 200, max_len, 1, 2, 2, ).to(device)
     criterion = torch.nn.CrossEntropyLoss(reduction="mean").to(device)
     optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
 
@@ -96,13 +98,13 @@ if __name__ == "__main__":
             attentions_mask = data_batch.apply(lambda x: x['attention_mask'].reshape(x['attention_mask_shape']), axis=1).values
 
             inputs = torch.tensor(np.stack(inputs, axis=0)).to(device)
-            attentions_mask = torch.tensor(np.stack(attentions_mask, axis=0)).to(device)
+            attentions_mask = torch.tensor(np.stack(attentions_mask, axis=0) == 0).to(device)
             output_indices = torch.tensor(np.stack(outputs, axis=0)).to(device)
 
-            y_pred = model(input_ids=inputs, attention_mask=attentions_mask)[0]
+            y_pred = model(inputs, output_indices[:, :-1], attentions_mask)
             y_pred = y_pred.transpose(1, 2)
 
-            loss = criterion(y_pred, output_indices)
+            loss = criterion(y_pred, output_indices[:, 1:])
 
             losses.append(loss.item())
             if len(losses) > loss_averaging_span:
diff --git a/src/models/TransformerSeq2Seq.py b/src/models/TransformerSeq2Seq.py
new file mode 100644
index 0000000..2466348
--- /dev/null
+++ b/src/models/TransformerSeq2Seq.py
@@ -0,0 +1,92 @@
+import torch
+import torch.nn as nn
+import math
+
+
+class PositionalEncoding(nn.Module):
+    """Adds sinsusoidal positional encoding (as in original AIAYN paper)
+    src: https://pytorch.org/tutorials/beginner/transformer_tutorial.html
+
+    """
+
+    def __init__(self, d_model: int, max_len: int, dropout=0.1):
+        """Sinusidal positional encodings
+
+        Args:
+            d_model (int): Embedding dimension
+            max_len (int): Maximum length of sequence
+            dropout (float, optional): Dropout ratio. Defaults to 0.1.
+        """
+        super(PositionalEncoding, self).__init__()
+        self.dropout = nn.Dropout(p=dropout)
+
+        pe = torch.zeros(max_len, d_model)
+        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
+        div_term = torch.exp(torch.arange(
+            0, d_model, 2).float() * (-math.log(10000.0) / d_model))
+        pe[:, 0::2] = torch.sin(position * div_term)
+        pe[:, 1::2] = torch.cos(position * div_term)
+        pe = pe.unsqueeze(0).transpose(0, 1)
+        self.register_buffer('pe', pe)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Applies positional encoding
+
+        Args:
+            x (torch.Tensor): Word embeddings tensor
+
+        Returns:
+            torch.Tensor: Word embeddings with added positional encodings
+        """
+        x = x + self.pe[:x.size(0), :]
+        return self.dropout(x)
+
+
+class TransformerSeq2Seq(nn.Module):
+    """Class representing a sequence to sequence transformer, based on original "Attention is all you need" paper.
+    """
+
+    def __init__(self, vocab_size: int,
+                 embedding_size: int,
+                 max_len: int,
+                 num_heads: int = 8,
+                 encoder_layers: int = 6,
+                 decoder_layers: int = 6,
+                 feedforward_neurons: int = 2048,
+                 dropout: float = 0.1):
+
+        super(TransformerSeq2Seq, self).__init__()
+
+        self.word_embedding = nn.Embedding(vocab_size, embedding_size)
+        self.position_embedding = PositionalEncoding(embedding_size, max_len, dropout)
+        self.core = nn.Transformer(
+            embedding_size, num_heads, encoder_layers, decoder_layers, feedforward_neurons, dropout)
+        self.embedding_to_words = nn.Linear(embedding_size, vocab_size)
+
+    def forward(self, source: torch.Tensor, target: torch.Tensor, source_mask: torch.Tensor) -> torch.Tensor:
+        """Full encoder-decoder pass
+
+        Args:
+            source (torch.Tensor): Batch of source sentence tokens (BxL)
+            target (torch.Tensor): Batch of target sentence tokens (BxL-1)
+            source_mask (torch.Tensor): Padding mask for the source (True where an element is padding) (BxL)
+
+        Returns:
+            torch.Tensor: Predicted target token logits (BxL-1xV)
+        """
+
+        x = source.transpose(0, 1)
+        x = self.word_embedding(x)
+        x = self.position_embedding(x)
+
+        y = target.transpose(0, 1)
+        y = self.word_embedding(y)
+        y = self.position_embedding(y)
+
+        tgt_mask = self.core.generate_square_subsequent_mask(y.shape[0]).to(y.device)
+
+        z = self.core(x, y, src_key_padding_mask=source_mask,
+                      tgt_mask=tgt_mask).transpose(1, 0)
+        z = self.embedding_to_words(z)
+
+        return z
diff --git a/src/models/__init__.py b/src/models/__init__.py
new file mode 100644
index 0000000..e69de29
-- 
GitLab
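
The training change above is standard teacher forcing: the decoder receives the
right-shifted target (output_indices[:, :-1]) and the loss compares its logits
against the left-shifted target (output_indices[:, 1:]). A minimal sketch of that
shift, with toy sizes and random logits standing in for the model:

    import torch

    vocab_size, batch, length = 10, 2, 5
    target = torch.randint(0, vocab_size, (batch, length))

    decoder_input = target[:, :-1]  # decoder sees tokens 0..L-2
    gold_next = target[:, 1:]       # and is trained to predict tokens 1..L-1

    # A model consuming decoder_input would emit logits of shape (B, L-1, V);
    # CrossEntropyLoss wants the class dimension second, hence the transpose.
    logits = torch.randn(batch, length - 1, vocab_size)
    loss = torch.nn.CrossEntropyLoss(reduction="mean")(logits.transpose(1, 2), gold_next)
    print(decoder_input.shape, loss.item())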


From 19c56782b3b51451ad3ada62745a328e718da07e Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Wed, 5 Aug 2020 11:49:39 +0200
Subject: [PATCH 032/116] Newest version of pyarrow has a bug when serializing
 numpy arrays

---
 docker/Dockerfile | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docker/Dockerfile b/docker/Dockerfile
index 5672adf..1535758 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -37,5 +37,5 @@ ENV NVIDIA_REQUIRE_CUDA "cuda>=10.2 brand=tesla,driver>=384,driver<385 brand=tes
 
 ### END CUDA Installation
 
-RUN pip3 install numpy pandas tqdm seaborn torch dask[complete] transformers pyarrow pytest lxml
+RUN pip3 install numpy pandas tqdm seaborn torch dask[complete] transformers pyarrow==0.17.1 pytest lxml
 RUN ln -s /usr/bin/pip3 /usr/bin/pip
\ No newline at end of file
-- 
GitLab
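
The pin works around a serialization regression in then-current pyarrow. A quick
round-trip check of the kind that exposes it (hypothetical temp path; the pipeline
stores flattened numpy arrays in object columns, as in the stages above):

    import numpy as np
    import pandas as pd

    df = pd.DataFrame({"input": [np.arange(3), np.arange(4)]})
    df.to_parquet("/tmp/roundtrip_check.parquet", engine="pyarrow")
    print(pd.read_parquet("/tmp/roundtrip_check.parquet", engine="pyarrow")["input"][0])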


From e23f247c4c9d154f1f17e5b5a960c4e4fd80718c Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Wed, 5 Aug 2020 14:23:33 +0200
Subject: [PATCH 033/116] Added weighting to the actions pipeline

---
 generated/actions/.gitignore                 |  1 +
 scripts/actions_based/processing.py          | 82 ++++++++++++++++++++
 scripts/actions_based/stage2_tokenization.py | 37 +--------
 scripts/actions_based/stage3_exploding.py    | 46 +----------
 scripts/actions_based/stage5_loss_weights.py | 30 +++----
 5 files changed, 105 insertions(+), 91 deletions(-)
 create mode 100644 scripts/actions_based/processing.py

diff --git a/generated/actions/.gitignore b/generated/actions/.gitignore
index 959c3a4..828b48b 100644
--- a/generated/actions/.gitignore
+++ b/generated/actions/.gitignore
@@ -2,3 +2,4 @@
 /stage2_tokenization
 /stage3_exploding
 /stage4_reindexing
+/stage5_loss_weights
\ No newline at end of file
diff --git a/scripts/actions_based/processing.py b/scripts/actions_based/processing.py
new file mode 100644
index 0000000..8e8db74
--- /dev/null
+++ b/scripts/actions_based/processing.py
@@ -0,0 +1,82 @@
+from transformers import BertTokenizerFast
+from src.processing import tokenize_labeled_text, batchify_data
+import numpy as np
+
+def expand_dims(entry: dict):
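+    """Reshapes flattened input/output/attention arrays back to their stored shapes"""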
+    inputs = entry.input.reshape(entry.input_shape)
+    outputs = entry.output.reshape(entry.output_shape)
+    masks = entry.attention_mask.reshape(entry.attention_mask_shape)
+
+    return {
+        'input': inputs,
+        'output': outputs,
+        "attention_mask": masks,
+    }
+
+EXPAND_DIMS_META = {
+    'input': object,
+    'output': object,
+    'attention_mask': object
+}
+
+def apply_tokenization(df, min_tokens: int, max_tokens: int, tokenizer: BertTokenizerFast):
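+    """Tokenizes one labeled row and batches it, returning flattened arrays together with their shapes"""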
+    text_clean = df.input
+    labels = df.output
+    shape = df.output_shape
+
+    tokens, token_labels = tokenize_labeled_text(
+        text_clean, labels.reshape(shape), tokenizer)
+
+    inputs, outputs, attentions = batchify_data(
+        tokens, token_labels, max_tokens, tokenizer, min_tokens)
+
+    inputs_shape = np.array(inputs.shape)
+    outputs_shape = np.array(outputs.shape)
+    attentions_shape = np.array(attentions.shape)
+
+    return {
+        'input': inputs.reshape(-1),
+        'output': outputs.reshape(-1),
+        'attention_mask': attentions.reshape(-1),
+        'input_shape': inputs_shape,
+        'output_shape': outputs_shape,
+        'attention_mask_shape': attentions_shape
+    }
+
+
+APPLY_TOKENIZATION_META = {
+    'input': object,
+    'output': object,
+    'attention_mask': object,
+    'input_shape': object,
+    'output_shape': object,
+    'attention_mask_shape': object
+}
+
+def flatten_dims(entry):
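+    """Records array shapes and flattens input/output/attention arrays for storage"""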
+    inputs_shape = np.array(entry.input.shape)
+    outputs_shape = np.array(entry.output.shape)
+    attentions_shape = np.array(entry.attention_mask.shape)
+
+    inputs = entry.input.reshape(-1)
+    outputs = entry.output.reshape(-1)
+    attentions = entry.attention_mask.reshape(-1)
+
+    return {
+        'input': inputs,
+        'output': outputs,
+        'attention_mask': attentions,
+        'input_shape': inputs_shape,
+        'output_shape': outputs_shape,
+        'attention_mask_shape': attentions_shape
+    }
+
+
+FLATTEN_DIMS_META = {
+    'input': object,
+    'output': object,
+    'attention_mask': object,
+    'input_shape': object,
+    'output_shape': object,
+    'attention_mask_shape': object
+}
diff --git a/scripts/actions_based/stage2_tokenization.py b/scripts/actions_based/stage2_tokenization.py
index a6cb763..fae7d86 100644
--- a/scripts/actions_based/stage2_tokenization.py
+++ b/scripts/actions_based/stage2_tokenization.py
@@ -16,44 +16,11 @@ from dask.diagnostics import ProgressBar
 import dask.dataframe as dd
 from transformers import BertTokenizerFast
 from dask.distributed import Client
+from scripts.actions_based.processing import apply_tokenization, APPLY_TOKENIZATION_META
 
 INPUT_FOLDER = f"{PROJECT_ROOT}/generated/actions/stage1_extraction"
 OUTPUT_FOLDER = f"{PROJECT_ROOT}/generated/actions/stage2_tokenization"
 
-def apply_tokenization(df, min_tokens: int, max_tokens: int, tokenizer: BertTokenizerFast):
-    text_clean = df.input
-    labels = df.output
-    shape = df.output_shape
-
-    tokens, token_labels = tokenize_labeled_text(
-        text_clean, labels.reshape(shape), tokenizer)
-
-    inputs, outputs, attentions = batchify_data(
-        tokens, token_labels, max_tokens, tokenizer, min_tokens)
-
-    inputs_shape = np.array(inputs.shape)
-    outputs_shape = np.array(outputs.shape)
-    attentions_shape = np.array(attentions.shape)
-
-    return {
-        'inputs': inputs.reshape(-1),
-        'outputs': outputs.reshape(-1),
-        'attentions': attentions.reshape(-1),
-        'input_shape': inputs_shape,
-        'output_shape': outputs_shape,
-        'attentions_shape': attentions_shape
-    }
-
-
-RESULT_META = {
-    'inputs': object,
-    'outputs': object,
-    'attentions': object,
-    'input_shape': object,
-    'output_shape': object,
-    'attentions_shape': object
-}
-
 if __name__ == "__main__":
 
     config = get_config()
@@ -74,6 +41,6 @@ if __name__ == "__main__":
 
     df = dd.read_parquet(INPUT_FOLDER, engine="pyarrow")
     df = df.apply(apply_tokenization, args=(min_tokens, max_tokens, tokenizer),
-                  result_type='expand', axis=1, meta=RESULT_META)
+                  result_type='expand', axis=1, meta=APPLY_TOKENIZATION_META)
 
     df.to_parquet(OUTPUT_FOLDER, engine="pyarrow")
diff --git a/scripts/actions_based/stage3_exploding.py b/scripts/actions_based/stage3_exploding.py
index 7c0b0d4..2978324 100644
--- a/scripts/actions_based/stage3_exploding.py
+++ b/scripts/actions_based/stage3_exploding.py
@@ -7,49 +7,11 @@ import dask
 from dask.distributed import Client
 import pandas as pd
 from src.utils import PROJECT_ROOT, get_config, prepare_folder
+from scripts.actions_based.processing import expand_dims, EXPAND_DIMS_META, flatten_dims, FLATTEN_DIMS_META
 
 INPUT_FOLDER = f"{PROJECT_ROOT}/generated/actions/stage2_tokenization"
 OUTPUT_FOLDER = f"{PROJECT_ROOT}/generated/actions/stage3_exploding"
 
-def expand_dims(entry):
-    inputs = entry.inputs.reshape(entry.input_shape)
-    outputs = entry.outputs.reshape(entry.output_shape)
-    masks = entry.attentions.reshape(entry.attentions_shape)
-
-    return {
-        'inputs': inputs,
-        'outputs': outputs,
-        "attentions": masks,
-    }
-
-def flatten_dims(entry):
-    inputs_shape = np.array(entry.inputs.shape)
-    outputs_shape = np.array(entry.outputs.shape)
-    attentions_shape = np.array(entry.attentions.shape)
-
-    inputs = entry.inputs.reshape(-1)
-    outputs = entry.outputs.reshape(-1)
-    attentions = entry.attentions.reshape(-1)
-
-    return {
-        'inputs': inputs,
-        'outputs': outputs,
-        'attentions': attentions,
-        'inputs_shape': inputs_shape,
-        'outputs_shape': outputs_shape,
-        'attentions_shape': attentions_shape
-    }
-
-
-RESULT_META = {
-    'inputs': object,
-    'outputs': object,
-    'attentions': object,
-    'inputs_shape': object,
-    'outputs_shape': object,
-    'attentions_shape': object
-}
-
 if __name__ == "__main__":
     config = get_config()
     num_workers = config['actions']['exploding']['num_workers']
@@ -62,8 +24,8 @@ if __name__ == "__main__":
 
     df = dd.read_parquet(INPUT_FOLDER, engine='pyarrow')
 
-    df = df.apply(expand_dims, result_type='expand', axis=1, meta={'inputs': object, 'outputs': object, 'attentions': object})
-    df = df.map_partitions(lambda x: x.apply(lambda y: y.explode(), axis=0), meta={'inputs': object, 'outputs': object, 'attentions': object})
-    df = df.apply(flatten_dims, result_type='expand', axis=1, meta=RESULT_META)
+    df = df.apply(expand_dims, result_type='expand', axis=1, meta=EXPAND_DIMS_META)
+    df = df.map_partitions(lambda x: x.apply(lambda y: y.explode(), axis=0), meta=EXPAND_DIMS_META)
+    df = df.apply(flatten_dims, result_type='expand', axis=1, meta=FLATTEN_DIMS_META)
     
     df.to_parquet(OUTPUT_FOLDER, engine='pyarrow')
diff --git a/scripts/actions_based/stage5_loss_weights.py b/scripts/actions_based/stage5_loss_weights.py
index 72498b7..5d72fcb 100644
--- a/scripts/actions_based/stage5_loss_weights.py
+++ b/scripts/actions_based/stage5_loss_weights.py
@@ -1,5 +1,5 @@
 # /usr/bin/python3
-from src.processing import batchify_data
+from src.processing import batchify_data, ACTIONS_KEYS
 from dask.diagnostics import ProgressBar
 import dask.dataframe as dd
 from transformers import BertTokenizerFast
@@ -8,10 +8,20 @@ import dask
 from dask.distributed import Client
 import pandas as pd
 from src.utils import PROJECT_ROOT, get_config, prepare_folder
+import pickle
+from scripts.actions_based.processing import expand_dims, EXPAND_DIMS_META
 
 INPUT_FOLDER = f"{PROJECT_ROOT}/generated/actions/stage4_reindexing"
 OUTPUT_FOLDER = f"{PROJECT_ROOT}/generated/actions/stage5_loss_weights"
 
+def reduce(x, y):
+    if len(x.shape) == 2:
+        x = x.sum(axis=0)
+    if len(y.shape) == 2:
+        y = y.sum(axis=0)
+
+    return x + y
+
 if __name__ == "__main__":
     config = get_config()
     num_workers = config['actions']['reindexing']['num_workers']
@@ -23,18 +33,10 @@ if __name__ == "__main__":
     print(client.dashboard_link)
 
     df = dd.read_parquet(INPUT_FOLDER, engine='pyarrow')
+    df = df.apply(expand_dims, result_type='expand', axis=1, meta=EXPAND_DIMS_META)
 
-    # Add ordered indexes
-    df = df.assign(ones=1)
-    df = df.reset_index(drop=True)
-    idx = (df.ones.cumsum() - 1).persist()
-    df = df.assign(ones=idx)
-
-    # Shuffle 
-    shuffled_idx = idx.compute().values
-    shuffled_idx = client.scatter(shuffled_idx)
-    mapped_ones = df.ones.apply(lambda x, idx: idx[x], args=(shuffled_idx,), meta=('ones', 'int64')).persist()
-    df = df.assign(ones=mapped_ones)
+    outputs_bag = df['output'].to_bag()
+    result = outputs_bag.fold(reduce, initial=np.array([0] * len(ACTIONS_KEYS))).compute()
 
-    df = df.set_index('ones')
-    df.to_parquet(OUTPUT_FOLDER, engine='pyarrow')
+    with open(f"{OUTPUT_FOLDER}/stats.pickle", 'wb') as f:
+        pickle.dump(result, f)
\ No newline at end of file
-- 
GitLab
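
stage5 reduces the expanded output column with a Dask bag fold. The same pattern in
isolation, assuming toy label arrays with two action classes:

    import numpy as np
    import dask.bag as db

    # Each element mimics one row of labels: shape (sequence_length, num_actions)
    rows = [np.array([[1, 0], [0, 1]]), np.array([[1, 1], [1, 0]])]

    counts = db.from_sequence(rows, npartitions=2).fold(
        lambda acc, arr: acc + arr.sum(axis=0),  # fold one row into the running sum
        lambda a, b: a + b,                      # combine partial sums of partitions
        initial=np.zeros(2, dtype=np.int64),
    ).compute()

    print(counts)  # per-class positive counts, here [3 2]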


From a1944883737c1c5f35963dfc5ef3c22000ddf9ff Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Wed, 5 Aug 2020 16:51:13 +0200
Subject: [PATCH 034/116] Added weighting to the actions training script

---
 dvc.yaml                                      |  8 +++++
 generated/actions/.gitignore                  |  2 +-
 notebooks/dask_dataframe_exploration.ipynb    | 25 ++++++++++++++++
 notebooks/test_actions_model.ipynb            | 23 ++------------
 params.yaml                                   |  4 +++
 ...stage5_loss_weights.py => stage5_stats.py} | 30 ++++++++++++-------
 scripts/actions_based/train.py                | 21 ++++++++-----
 7 files changed, 74 insertions(+), 39 deletions(-)
 rename scripts/actions_based/{stage5_loss_weights.py => stage5_stats.py} (55%)

diff --git a/dvc.yaml b/dvc.yaml
index 51001a5..7875e8d 100644
--- a/dvc.yaml
+++ b/dvc.yaml
@@ -33,10 +33,18 @@ stages:
     - scripts/actions_based/stage4_reindexing.py
     outs:
     - generated/actions/stage4_reindexing
+  actions_stats:
+    cmd: python3 -m scripts.actions_based.stage5_stats
+    deps:
+    - generated/actions/stage4_reindexing
+    - scripts/actions_based/stage5_stats.py
+    outs:
+    - generated/actions/stage5_stats
   actions_training:
     cmd: python3 -m scripts.actions_based.train
     deps:
     - generated/actions/stage4_reindexing
+    - generated/actions/stage5_stats
     - scripts/actions_based/train.py
     params:
     - global.base_model
diff --git a/generated/actions/.gitignore b/generated/actions/.gitignore
index 828b48b..49854ca 100644
--- a/generated/actions/.gitignore
+++ b/generated/actions/.gitignore
@@ -2,4 +2,4 @@
 /stage2_tokenization
 /stage3_exploding
 /stage4_reindexing
-/stage5_loss_weights
\ No newline at end of file
+/stage5_stats
\ No newline at end of file
diff --git a/notebooks/dask_dataframe_exploration.ipynb b/notebooks/dask_dataframe_exploration.ipynb
index 810287e..0e48640 100644
--- a/notebooks/dask_dataframe_exploration.ipynb
+++ b/notebooks/dask_dataframe_exploration.ipynb
@@ -72,6 +72,31 @@
     "df.head()"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": 17,
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [
+    {
+     "output_type": "stream",
+     "name": "stdout",
+     "text": "[  13.41055592    6.22695335    9.50220389 1603.42842607  290.42591041\n  182.60031948]\n"
+    }
+   ],
+   "source": [
+    "import pickle\n",
+    "import numpy as np\n",
+    "with open(\"../generated/actions/stage5_stats/stats.pickle\", 'rb') as f:\n",
+    "    stats = pickle.load(f)\n",
+    "    pos_examples = stats['class_number']\n",
+    "    neg_examples = stats['num_examples'] - stats['class_number']\n",
+    "    ratio = neg_examples / pos_examples\n",
+    "\n",
+    "    print(ratio)"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": null,
diff --git a/notebooks/test_actions_model.ipynb b/notebooks/test_actions_model.ipynb
index 0318f59..d1bae4e 100644
--- a/notebooks/test_actions_model.ipynb
+++ b/notebooks/test_actions_model.ipynb
@@ -46,7 +46,6 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "INPUT_PATH=\"../generated/stage4_reindexing\"\n",
     "MODEL_BASE = \"dkleczek/bert-base-polish-cased-v1\""
    ]
   },
@@ -65,7 +64,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "expected = \"Dekretem polskiej władzy państwowej stworzono na wyzwolonym obszarze Rzeczypospolitej specjalny sąd karny, w którego kompetencje wchodzą sprawy o zdrady narodu polskiego. Przed sądem stanęli renegaci, którzy nie tylko wyrzekli się polskości, ale postępowaniem swym czynnie pomagali Niemcom w ich zbrodniach. Podajemy fragmenty z rozprawy przeciwko folksdojczowi Musialskiemu, który jako kierownik niemieckiego obozu pracy znęcał się nad obywatelami polskimi. Oskarżony Musielski. Świadek Jankowska opowiedziała, jak bił on Polaków po twarzy, kopał i groził Majdankiem. Świadek Stankiewicz stwierdził, że Musielski przewyższył swym okrucieństwem poprzednich kierowników obozu, Niemców. Prokurator doktor Sawicki zażądał dla oskarżonego kary śmierci. Po naradzie sąd skazał Musielskiego na karę śmierci przez powieszenie.\"\n",
+    "expected = \"W porcie w stolicy Libanu doszło we wtorek po południu do potężnej eksplozji. Zginęło co najmniej sto osób, ok. 4 tys. zostało rannych. Z dotychczasowych informacji wynika, że przyczyną wybuchu mogły być chemikalia, w tym saletra amonowa. Trop wiedzie do statku pod mołdawską banderą.\"\n",
     "text_clean = create_model_input_output(expected)[0]\n",
     "\n",
     "inputs = tokenizer(text_clean, return_tensors=\"pt\")"
@@ -94,25 +93,7 @@
    ],
    "source": [
     "model = BertForTokenClassification.from_pretrained(MODEL_BASE, num_labels=6)\n",
-    "model.load_state_dict(torch.load(\"../models/actionv1_500-0-5000.model\"))\n"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 50,
-   "metadata": {},
-   "outputs": [
-    {
-     "output_type": "execute_result",
-     "data": {
-      "text/plain": "{'input_ids': tensor([[   101,  18910,  34942,  10147,  41458,  49264,  86412,  14196,  28780,\n          16828,  19435,  12507,  10132,    191,  10157,  10305,  16828,  47472,\n          10147,  51043,  53549,  10157,  22815,  29378,  42606,  32650,  15953,\n          24011,  10756,  13130,  10162,  25085,  10756,    191,  21619,  12240,\n          19094,  10136,  17528,    191,  82689,  59763,  10157,    183,    194,\n          14951,  12355,  77029,  10138,  49200,  18795,  13130,  16050,  45244,\n          58814,  63256,  19172,  10598,  30214,  11058,  18933,    191,  20728,\n          25983,  10390,  10424,  72058,  15083,  11372,  11841,  91496,  85421,\n            187,  21633,  22427,  57753,  10514,  96760,  10390,  44227,  22530,\n            191,  12979,    194,  20923,  63168,  10269,  11202,  19234,  14996,\n          39144,  10157,    194,  25470,  10305,  91865,  39371,  18694,  82836,\n          21799,  20868,  22578,  87985,  20162,  21791,  10138,  13680,  10701,\n         107626,  80711,  92859,  27648,    194,  31223,  10425,  11133,  10424,\n          12060,  17339,  70500,  87907,  10500,  72325,  10116,  10427,  15190,\n          36701,  87985,  18338,  10506,  18996, 103412,  10174,  63923,  72275,\n          11485,  10303,  28612, 110206,  10113,  13050,  11342,  11133,  10135,\n          14151,  16036,  10514,  37975,  27828,  39268,  16251,    177,  30518,\n          20129,  21617,  14991,  30022,  13711,  18996, 103412,  10174,  45244,\n          62872,  28780,  79534,  12537,  87985,  18338,  10506,  20157,  82351,\n          10157,  61610,  11133,    187,  21633,  14302,  11680,  12097,  14194,\n          82775,  13717,  23090, 108605, 107626,  10644,  92859,  44227,  22064,\n          11284,  96858,  11813,  43307,  17112,  84280,  10339,  67464,  40384,\n          12197,  10427,  15190,  61187,  10797,  25085,  10157,  26584,  10514,\n          90086, 102984,  13130,  10162,  31569,  34105,  87985,  18338,  18761,\n          10132,  25085,  10963,  26584,  11048,  10514,  52784,  21620,    102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n         0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
1,\n         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n         1, 1, 1, 1, 1, 1, 1, 1, 1]])}"
-     },
-     "metadata": {},
-     "execution_count": 50
-    }
-   ],
-   "source": [
-    "inputs"
+    "model.load_state_dict(torch.load(\"../checkpoints/actions/0-183000.model\", map_location={'cuda:0': 'cpu'}))\n"
    ]
   },
   {
diff --git a/params.yaml b/params.yaml
index d399494..e084c7a 100644
--- a/params.yaml
+++ b/params.yaml
@@ -22,6 +22,10 @@ actions:
         num_workers: 1
         worker_memory_limit: "60GB"
 
+    stats:
+        num_workers: 24
+        worker_memory_limit: "2GB"
+
     training:
         learning_rate: 0.0001
         num_epochs: 5
diff --git a/scripts/actions_based/stage5_loss_weights.py b/scripts/actions_based/stage5_stats.py
similarity index 55%
rename from scripts/actions_based/stage5_loss_weights.py
rename to scripts/actions_based/stage5_stats.py
index 5d72fcb..7038008 100644
--- a/scripts/actions_based/stage5_loss_weights.py
+++ b/scripts/actions_based/stage5_stats.py
@@ -12,20 +12,24 @@ import pickle
 from scripts.actions_based.processing import expand_dims, EXPAND_DIMS_META
 
 INPUT_FOLDER = f"{PROJECT_ROOT}/generated/actions/stage4_reindexing"
-OUTPUT_FOLDER = f"{PROJECT_ROOT}/generated/actions/stage5_loss_weights"
+OUTPUT_FOLDER = f"{PROJECT_ROOT}/generated/actions/stage5_stats"
 
-def reduce(x, y):
-    if len(x.shape) == 2:
-        x = x.sum(axis=0)
-    if len(y.shape) == 2:
-        y = y.sum(axis=0)
+def reduce_fold(fold_value, new_value):
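+    """Folds one label array into running per-class positive counts and an example total"""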
+    return {
+        'class_number': fold_value["class_number"] + np.sum(new_value, axis=0),
+        'num_examples': fold_value["num_examples"] + new_value.shape[0]
+    }
 
-    return x + y
+def reduce_partitions(x, y):
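+    """Combines partial per-class / example counts from two partitions"""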
+    return {
+        'class_number': x["class_number"] + y["class_number"],
+        'num_examples': x["num_examples"] + y["num_examples"]
+    }
 
 if __name__ == "__main__":
     config = get_config()
-    num_workers = config['actions']['reindexing']['num_workers']
-    memory_limit = config['actions']['reindexing']['worker_memory_limit']
+    num_workers = config['actions']['stats']['num_workers']
+    memory_limit = config['actions']['stats']['worker_memory_limit']
 
     prepare_folder(OUTPUT_FOLDER)
 
@@ -36,7 +40,13 @@ if __name__ == "__main__":
     df = df.apply(expand_dims, result_type='expand', axis=1, meta=EXPAND_DIMS_META)
 
     outputs_bag = df['output'].to_bag()
-    result = outputs_bag.fold(reduce, initial=np.array([0] * len(ACTIONS_KEYS))).compute()
+
+    initial_values = {
+        "class_number": np.array([0] * len(ACTIONS_KEYS)),
+        'num_examples': 0
+    }
+
+    result = outputs_bag.fold(reduce_fold, reduce_partitions, initial=initial_values).compute()
 
     with open(f"{OUTPUT_FOLDER}/stats.pickle", 'wb') as f:
         pickle.dump(result, f)
\ No newline at end of file
diff --git a/scripts/actions_based/train.py b/scripts/actions_based/train.py
index 1530f49..3a5bb3c 100755
--- a/scripts/actions_based/train.py
+++ b/scripts/actions_based/train.py
@@ -11,8 +11,10 @@ import glob
 from src.utils import PROJECT_ROOT, get_config, convert_to_timedelta, prepare_folder
 from src.processing import ACTIONS_KEYS
 from datetime import datetime
+import pickle
 
 INPUT_PATH = f"{PROJECT_ROOT}/generated/actions/stage4_reindexing"
+INPUT_STATS_PATH = f"{PROJECT_ROOT}/generated/actions/stage5_stats"
 OUTPUT_PATH = f"{PROJECT_ROOT}/checkpoints/actions"
 
 if __name__ == "__main__":
@@ -35,12 +37,18 @@ if __name__ == "__main__":
     device = torch.device(device_name if torch.cuda.is_available() else "cpu")
     print(f"Training on {device}")
 
+    # Load loss weights
+    with open(f"{INPUT_STATS_PATH}/stats.pickle", 'rb') as f:
+        stats = pickle.load(f)
+        pos_examples = stats['class_number']
+        neg_examples = stats['num_examples'] - stats['class_number']
+        pos_weight = torch.tensor(neg_examples / pos_examples)
+
     df = dd.read_parquet(INPUT_PATH, engine="pyarrow")
-    
     tokenizer = BertTokenizerFast.from_pretrained(base_model)
 
     model = BertForTokenClassification.from_pretrained(base_model, num_labels=len(ACTIONS_KEYS)).to(device)
-    criterion = BCEWithLogitsLoss().to(device)
+    criterion = BCEWithLogitsLoss(pos_weight=pos_weight).to(device)
     optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
 
     epoch_start = 0
@@ -81,8 +89,7 @@ if __name__ == "__main__":
 
         i = sample_start
         while True:
-            # TODO: Change to 0-indexed...
-            data_batch_indexes = list(range(i*batch_size+1, i*batch_size + batch_size +1))
+            data_batch_indexes = list(range(i*batch_size, i*batch_size + batch_size))
             
             # Precomputing total number of samples takes very long, so lets
             # try to get next batch until fail :)
@@ -92,9 +99,9 @@ if __name__ == "__main__":
                 # TODO: Specify exception type
                 break
 
-            inputs = data_batch.apply(lambda x: x['inputs'].reshape(x['inputs_shape']), axis=1).values
-            outputs = data_batch.apply(lambda x: x['outputs'].reshape(x['outputs_shape']), axis=1).values
-            attentions_mask = data_batch.apply(lambda x: x['attentions'].reshape(x['attentions_shape']), axis=1).values
+            inputs = data_batch.apply(lambda x: x['input'].reshape(x['input_shape']), axis=1).values
+            outputs = data_batch.apply(lambda x: x['output'].reshape(x['output_shape']), axis=1).values
+            attentions_mask = data_batch.apply(lambda x: x['attention_mask'].reshape(x['attention_mask_shape']), axis=1).values
 
             inputs = torch.tensor(np.stack(inputs).squeeze()).to(device)
             outputs = torch.tensor(np.stack(outputs)).to(device)
-- 
GitLab
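
The pos_weight handed to BCEWithLogitsLoss above is the per-class ratio of negative
to positive examples taken from the stage5 stats. A small sketch with made-up counts
for three classes:

    import torch

    num_examples = 1000.0
    class_number = torch.tensor([900.0, 50.0, 10.0])  # positive examples per class

    pos_weight = (num_examples - class_number) / class_number  # neg/pos per class
    criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)

    logits = torch.randn(4, 3)
    labels = torch.randint(0, 2, (4, 3)).float()
    print(pos_weight, criterion(logits, labels).item())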


From d5cb850f24d91cb44fbb403de634edc8881dffdd Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Wed, 5 Aug 2020 19:34:41 +0200
Subject: [PATCH 035/116] Changes in shuffling

---
 scripts/actions_based/stage4_reindexing.py | 13 ++++++++-----
 scripts/actions_based/train.py             |  7 +++++--
 2 files changed, 13 insertions(+), 7 deletions(-)

diff --git a/scripts/actions_based/stage4_reindexing.py b/scripts/actions_based/stage4_reindexing.py
index b436ca2..c50d62f 100644
--- a/scripts/actions_based/stage4_reindexing.py
+++ b/scripts/actions_based/stage4_reindexing.py
@@ -31,10 +31,13 @@ if __name__ == "__main__":
     df = df.assign(ones=idx)
 
     # Shuffle 
-    shuffled_idx = idx.compute().values
-    shuffled_idx = client.scatter(shuffled_idx)
-    mapped_ones = df.ones.apply(lambda x, idx: idx[x], args=(shuffled_idx,), meta=('ones', 'int64')).persist()
-    df = df.assign(ones=mapped_ones)
+    #shuffled_idx = idx.compute().values
+    #np.random.shuffle(shuffled_idx)
+    #shuffled_idx = client.scatter(shuffled_idx)    
+    #mapped_ones = df.ones.apply(lambda x, idx: idx[x], args=(shuffled_idx,), meta=('ones', 'int64'))
+    #df = df.assign(ones=mapped_ones)
 
-    df = df.set_index('ones')
+    #df = df.persist()
+
+    df = df.set_index('ones', shuffle='tasks')
     df.to_parquet(OUTPUT_FOLDER, engine='pyarrow')
diff --git a/scripts/actions_based/train.py b/scripts/actions_based/train.py
index 3a5bb3c..f6419f9 100755
--- a/scripts/actions_based/train.py
+++ b/scripts/actions_based/train.py
@@ -77,6 +77,9 @@ if __name__ == "__main__":
     model.base_model.train()
     losses = []
 
+    num_samples = df.tail(1).index.values[0] + 1
+    random_index_shuffle = np.random.permutation(range(num_samples))
+
     training_stopped = False
 
     time_max = datetime.max
@@ -88,8 +91,8 @@ if __name__ == "__main__":
             break
 
         i = sample_start
-        while True:
-            data_batch_indexes = list(range(i*batch_size, i*batch_size + batch_size))
+        while i + batch_size < num_samples:
+            data_batch_indexes = random_index_shuffle[list(range(i*batch_size, i*batch_size + batch_size))]
             
             # Precomputing total number of samples takes very long, so lets
             # try to get next batch until fail :)
-- 
GitLab
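
The loop now walks a precomputed permutation instead of probing df.loc until it
fails. The intended indexing, in isolation (toy sizes; note the bound keeps the
batch counter and the sample counter apart):

    import numpy as np

    num_samples, batch_size = 10, 3
    random_index_shuffle = np.random.permutation(num_samples)

    i = 0
    while (i + 1) * batch_size <= num_samples:
        # window i of the permutation addresses one batch of dataframe rows
        batch_ids = random_index_shuffle[i * batch_size:(i + 1) * batch_size]
        print(i, batch_ids)
        i += 1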


From ecd1b676c581dde15baac39604622251de0381da Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Thu, 6 Aug 2020 12:18:41 +0000
Subject: [PATCH 036/116] Performance fixes for actions

---
 .gitignore                                 |  2 +-
 notebooks/test_translations_model.ipynb    | 74 +++++++++++++-------
 params.yaml                                |  6 +-
 scripts/actions_based/stage1_extraction.py |  5 --
 scripts/actions_based/train.py             | 17 ++---
 scripts/translation_based/train.py         |  6 +-
 src/batch_loading.py                       | 79 ++++++++++++++++++++++
 src/processing.py                          | 24 +++----
 src/test_batch_loading.py                  | 47 +++++++++++++
 9 files changed, 195 insertions(+), 65 deletions(-)
 create mode 100644 src/batch_loading.py
 create mode 100644 src/test_batch_loading.py

diff --git a/.gitignore b/.gitignore
index fec9ca9..7db4ff0 100644
--- a/.gitignore
+++ b/.gitignore
@@ -2,7 +2,7 @@ dane/**
 dataset_simple
 dataset_actions
 **/dask-worker-space
-.vscode
+.vscode`
 .idea
 .metals
 /data
diff --git a/notebooks/test_translations_model.ipynb b/notebooks/test_translations_model.ipynb
index 618327d..d418b29 100644
--- a/notebooks/test_translations_model.ipynb
+++ b/notebooks/test_translations_model.ipynb
@@ -14,7 +14,7 @@
   },
   "orig_nbformat": 2,
   "kernelspec": {
-   "name": "python38264bita7d7da14168440cb9836372958035d4a",
+   "name": "python_defaultSpec_1596573889994",
    "display_name": "Python 3.8.2 64-bit"
   }
  },
@@ -23,14 +23,15 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count": 45,
+   "execution_count": 1,
    "metadata": {},
    "outputs": [],
    "source": [
     "import sys\n",
     "sys.path.append(\"../\")\n",
     "\n",
-    "from transformers import BertTokenizerFast, BertForMaskedLM\n",
+    "from transformers import BertTokenizerFast\n",
+    "from src.models.TransformerSeq2Seq import TransformerSeq2Seq\n",
     "import torch\n",
     "from torch.nn import BCEWithLogitsLoss\n",
     "import pandas as pd\n",
@@ -42,7 +43,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 46,
+   "execution_count": 2,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -51,7 +52,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 47,
+   "execution_count": 3,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -60,11 +61,20 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 48,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
-    "expected = \"Dekretem polskiej władzy państwowej stworzono na wyzwolonym obszarze Rzeczypospolitej specjalny sąd karny, w którego kompetencje wchodzą sprawy o zdrady narodu polskiego. Przed sądem stanęli renegaci, którzy nie tylko wyrzekli się polskości, ale postępowaniem swym czynnie pomagali Niemcom w ich zbrodniach. Podajemy fragmenty z rozprawy przeciwko folksdojczowi Musialskiemu, który jako kierownik niemieckiego obozu pracy znęcał się nad obywatelami polskimi. Oskarżony Musielski. Świadek Jankowska opowiedziała, jak bił on Polaków po twarzy, kopał i groził Majdankiem. Świadek Stankiewicz stwierdził, że Musielski przewyższył swym okrucieństwem poprzednich kierowników obozu, Niemców. Prokurator doktor Sawicki zażądał dla oskarżonego kary śmierci. Po naradzie sąd skazał Musielskiego na karę śmierci przez powieszenie.\"\n",
+    "tokenizer.em"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 36,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "expected = \"Ogromny wybuch w Bejrucie zrównał z ziemią prawie całą dzielnicę.\"\n",
     "text_clean = create_model_input_output(expected)[0]\n",
     "\n",
     "inputs = tokenizer(text_clean, return_tensors=\"pt\")"
@@ -72,66 +82,78 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 49,
+   "execution_count": 41,
    "metadata": {
     "tags": []
    },
    "outputs": [
-    {
-     "output_type": "stream",
-     "name": "stderr",
-     "text": "Some weights of the model checkpoint at bert-base-multilingual-cased were not used when initializing BertForTokenClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']\n- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).\n- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\nSome weights of BertForTokenClassification were not initialized from the model checkpoint at bert-base-multilingual-cased and are newly initialized: ['classifier.weight', 'classifier.bias']\nYou should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
-    },
     {
      "output_type": "execute_result",
      "data": {
       "text/plain": "<All keys matched successfully>"
      },
      "metadata": {},
-     "execution_count": 49
+     "execution_count": 41
     }
    ],
    "source": [
-    "model = BertForTokenClassification.from_pretrained(MODEL_BASE, num_labels=6)\n",
-    "model.load_state_dict(torch.load(\"../models/actionv1_500-0-5000.model\", map_location={'cuda:0': 'cpu'}))\n"
+    "model = TransformerSeq2Seq(tokenizer.vocab_size, 200, 300, 1, 2, 2)\n",
+    "model.load_state_dict(torch.load(\"../checkpoints/translations/0-43000.model\", map_location={'cuda:2': 'cpu'}))\n"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 50,
+   "execution_count": 42,
    "metadata": {},
    "outputs": [
     {
      "output_type": "execute_result",
      "data": {
-      "text/plain": "{'input_ids': tensor([[   101,  18910,  34942,  10147,  41458,  49264,  86412,  14196,  28780,\n          16828,  19435,  12507,  10132,    191,  10157,  10305,  16828,  47472,\n          10147,  51043,  53549,  10157,  22815,  29378,  42606,  32650,  15953,\n          24011,  10756,  13130,  10162,  25085,  10756,    191,  21619,  12240,\n          19094,  10136,  17528,    191,  82689,  59763,  10157,    183,    194,\n          14951,  12355,  77029,  10138,  49200,  18795,  13130,  16050,  45244,\n          58814,  63256,  19172,  10598,  30214,  11058,  18933,    191,  20728,\n          25983,  10390,  10424,  72058,  15083,  11372,  11841,  91496,  85421,\n            187,  21633,  22427,  57753,  10514,  96760,  10390,  44227,  22530,\n            191,  12979,    194,  20923,  63168,  10269,  11202,  19234,  14996,\n          39144,  10157,    194,  25470,  10305,  91865,  39371,  18694,  82836,\n          21799,  20868,  22578,  87985,  20162,  21791,  10138,  13680,  10701,\n         107626,  80711,  92859,  27648,    194,  31223,  10425,  11133,  10424,\n          12060,  17339,  70500,  87907,  10500,  72325,  10116,  10427,  15190,\n          36701,  87985,  18338,  10506,  18996, 103412,  10174,  63923,  72275,\n          11485,  10303,  28612, 110206,  10113,  13050,  11342,  11133,  10135,\n          14151,  16036,  10514,  37975,  27828,  39268,  16251,    177,  30518,\n          20129,  21617,  14991,  30022,  13711,  18996, 103412,  10174,  45244,\n          62872,  28780,  79534,  12537,  87985,  18338,  10506,  20157,  82351,\n          10157,  61610,  11133,    187,  21633,  14302,  11680,  12097,  14194,\n          82775,  13717,  23090, 108605, 107626,  10644,  92859,  44227,  22064,\n          11284,  96858,  11813,  43307,  17112,  84280,  10339,  67464,  40384,\n          12197,  10427,  15190,  61187,  10797,  25085,  10157,  26584,  10514,\n          90086, 102984,  13130,  10162,  31569,  34105,  87985,  18338,  18761,\n          10132,  25085,  10963,  26584,  11048,  10514,  52784,  21620,    102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n         0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
1,\n         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n         1, 1, 1, 1, 1, 1, 1, 1, 1]])}"
+      "text/plain": "'[CLS] Zapis w Bej, zrównał z ziemią prawie całą dodatki, całą, całą, dzielnicę, teatr, dzielnicę,, rozpro, a całą dzielnicę, z Londynu, budownictwie, prawie całą dzielnicę, niezarzymy, cuk z ziemią, prawie całą dzielnicę, rozpro, bory, nieufności, daną, z daleko, autorytetem, krowybina, w Argentynie wybuch, w dzielnicę, wybuch w dzielnicę, płyt, partia, doniesie,cie, wybuch w walce, kardi, ogromny, wybuchwany, przeszło, w wybuch,'"
      },
      "metadata": {},
-     "execution_count": 50
+     "execution_count": 42
     }
    ],
    "source": [
-    "inputs"
+    "input_tokens = inputs['input_ids']\n",
+    "outputs = [[tokenizer.cls_token_id]]\n",
+    "\n",
+    "for j in range(100):\n",
+    "    preds = model(input_tokens, torch.tensor(outputs, dtype=torch.long), torch.zeros_like(input_tokens).bool()).softmax(-1)\n",
+    "    outputs[0].append(preds[0, -1].argmax().detach().tolist())\n",
+    "\n",
+    "tokenizer.decode(outputs[0])\n"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 51,
+   "execution_count": 43,
    "metadata": {
     "tags": []
    },
    "outputs": [
     {
-     "output_type": "stream",
-     "name": "stdout",
-     "text": "Dekretem polskiej władzy państwowej stworzono na wyzwolonym obszarze Rzeczypospolitej specjalny sąd karny, w którego kompetencje wchodzą sprawy o zdrady narodu polskiego. Przed sądem stanęli renegaci, którzy nie tylko wyrzekli się polskości, ale postępowaniem swym czynnie pomagali Niemcom w ich zbrodniach. Podajemy fragmenty z rozprawy przeciwko folksdojczowi Musialskiemu, który jako kierownik niemieckiego obozu pracy znęcał się nad obywatelami polskimi. Oskarżony Musielski. Świadek Jankowska opowiedziała, jak bił on Polaków po twarzy, kopał i groził Majdankiem. Świadek Stankiewicz stwierdził, że Musielski przewyższył swym okrucieństwem poprzednich kierowników obozu, Niemców. Prokurator doktor Sawicki zażądał dla oskarżonego kary śmierci. Po naradzie sąd skazał Musielskiego na karę śmierci przez powieszenie.\n------\nDekretem polskiej władzy państwowej stworzono na wyzwolonym obszarze Rzeczypospolitej specjalny sąd karny, w którego kompetencje wchodzą sprawy o zdrady narodu polskiego. Przed sądem stanęli renegaci, którzy nie tylko wyrzekli się polskości, ale postępowaniem swym czynnie pomagali Niemcom w ich zbrodniach. Podajemy fragmenty z rozprawy przeciwko folksdojczowi Musialskiemu, który jako kierownik Niemieckiego Obozu pracy znęcał się nad obywatelami polskimi. oskarżony musielski świadek Jankowska opowiedziała, jak bił on Polaków po twarzy kopał i groził majdankiem świadek. Stankiewicz stwierdził, że Musielski przewyższył swym okrucieństwem poprzednich kierowników obozu Niemców. Prokurator doktor Sawicki zażądał dla oskarżonego kary śmierci. Po naradzie sąd skazał Musielskiego na karę śmierci przez powieszenie.\n"
+     "output_type": "error",
+     "ename": "TypeError",
+     "evalue": "TransformerSeq2Seq object argument after ** must be a mapping, not Tensor",
+     "traceback": [
+      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+      "\u001b[0;31mTypeError\u001b[0m                                 Traceback (most recent call last)",
+      "\u001b[0;32m<ipython-input-43-ed4e46be8788>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      4\u001b[0m \u001b[0moutputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtokenizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcls_token_id\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      5\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 6\u001b[0;31m \u001b[0my_pred\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m**\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      7\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      8\u001b[0m \u001b[0mtokens_predictions\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0margmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mx\u001b[0m \u001b[0;32min\u001b[0m \u001b[0my_pred\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+      "\u001b[0;31mTypeError\u001b[0m: TransformerSeq2Seq object argument after ** must be a mapping, not Tensor"
+     ]
     }
    ],
    "source": [
     "from src.processing import token_labels_to_word_labels, recover_text\n",
     "\n",
-    "y_pred = model(**inputs)[0].sigmoid()\n",
-    "labels_pred = token_labels_to_word_labels(text_clean, y_pred.detach().numpy()[0, 1:-1, :], tokenizer)\n",
+    "inputs = inputs['input_ids']\n",
+    "outputs = np.array([[tokenizer.cls_token_id]])\n",
+    "\n",
+    "y_pred = model(**inputs)\n",
+    "\n",
+    "tokens_predictions = [np.argmax(x) for x in y_pred[0, :]]\n",
     "\n",
     "actions = labels_pred > 0.5\n",
     "print(expected)\n",
diff --git a/params.yaml b/params.yaml
index e084c7a..9b3fb5b 100644
--- a/params.yaml
+++ b/params.yaml
@@ -29,12 +29,12 @@ actions:
     training:
         learning_rate: 0.0001
         num_epochs: 5
-        batch_size: 2
+        batch_size: 8
         save_step: 1000
         max_training_time: null
         loss_averaging_span: 1000
         fresh_start: false
-        device: "cuda:0"
+        device: "cuda:1"
 translations:
     extraction:
         num_partitions: 2_000
@@ -62,5 +62,5 @@ translations:
         save_step: 1000
         max_training_time: "4h"
         loss_averaging_span: 1000
-        fresh_start: true
+        fresh_start: false
         device: "cuda:1"
\ No newline at end of file
diff --git a/scripts/actions_based/stage1_extraction.py b/scripts/actions_based/stage1_extraction.py
index 171051d..7ccfa7b 100644
--- a/scripts/actions_based/stage1_extraction.py
+++ b/scripts/actions_based/stage1_extraction.py
@@ -7,11 +7,6 @@ import dask.dataframe as dd
 import dask
 import pandas as pd
 from dask.distributed import Client
-import gc
-from memory_profiler import profile
-from pympler import muppy, summary
-import stackimpact
-import lorem
 from src.utils import get_config, PROJECT_ROOT, prepare_folder
 
 INPUT_FOLDER = f"{PROJECT_ROOT}/data"
diff --git a/scripts/actions_based/train.py b/scripts/actions_based/train.py
index f6419f9..40b8939 100755
--- a/scripts/actions_based/train.py
+++ b/scripts/actions_based/train.py
@@ -12,6 +12,7 @@ from src.utils import PROJECT_ROOT, get_config, convert_to_timedelta, prepare_fo
 from src.processing import ACTIONS_KEYS
 from datetime import datetime
 import pickle
+from src.batch_loading import get_batches
 
 INPUT_PATH = f"{PROJECT_ROOT}/generated/actions/stage4_reindexing"
 INPUT_STATS_PATH = f"{PROJECT_ROOT}/generated/actions/stage5_stats"
@@ -67,8 +68,8 @@ if __name__ == "__main__":
                 furthest_batch_num = max(iteration, furthest_batch_num)
 
         if furthest_epoch > -1 and furthest_batch_num > -1:
-            model.load_state_dict(torch.load(f"{OUTPUT_PATH}/{furthest_epoch}-{furthest_batch_num}.model"))
-            optimizer.load_state_dict(torch.load(f"{OUTPUT_PATH}/{furthest_epoch}-{furthest_batch_num}.optimizer"))
+            model.load_state_dict(torch.load(f"{OUTPUT_PATH}/{furthest_epoch}-{furthest_batch_num}.model", map_location=device))
+            #optimizer.load_state_dict(torch.load(f"{OUTPUT_PATH}/{furthest_epoch}-{furthest_batch_num}.optimizer", map_location=device))
 
             epoch_start, sample_start = furthest_epoch, furthest_batch_num
             print(f"Loaded {furthest_epoch}-{furthest_batch_num}")
@@ -91,17 +92,7 @@ if __name__ == "__main__":
             break
 
         i = sample_start
-        while i + batch_size < num_samples:
-            data_batch_indexes = random_index_shuffle[list(range(i*batch_size, i*batch_size + batch_size))]
-            
-            # Precomputing total number of samples takes very long, so lets
-            # try to get next batch until fail :)
-            try:
-                data_batch = df.loc[data_batch_indexes].compute()
-            except:
-                # TODO: Specify exception type
-                break
-
+        for data_batch in get_batches(df, batch_size, 100, random_index_shuffle, i):
             inputs = data_batch.apply(lambda x: x['input'].reshape(x['input_shape']), axis=1).values
             outputs = data_batch.apply(lambda x: x['output'].reshape(x['output_shape']), axis=1).values
             attentions_mask = data_batch.apply(lambda x: x['attention_mask'].reshape(x['attention_mask_shape']), axis=1).values
diff --git a/scripts/translation_based/train.py b/scripts/translation_based/train.py
index 0ba36f2..1a17599 100755
--- a/scripts/translation_based/train.py
+++ b/scripts/translation_based/train.py
@@ -41,7 +41,7 @@ if __name__ == "__main__":
     
     tokenizer = BertTokenizerFast.from_pretrained(base_model)
 
-    model = TransformerSeq2Seq(tokenizer.vocab_size, 200, max_len, 1, 2, 2, ).to(device)
+    model = TransformerSeq2Seq(tokenizer.vocab_size, 256, max_len, 4, 4, 4, ).to(device)
     criterion = torch.nn.CrossEntropyLoss(reduction="mean").to(device)
     optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
 
@@ -97,9 +97,9 @@ if __name__ == "__main__":
             outputs = data_batch.apply(lambda x: x['target'].reshape(x['target_shape']), axis=1).values
             attentions_mask = data_batch.apply(lambda x: x['attention_mask'].reshape(x['attention_mask_shape']), axis=1).values
 
-            inputs = torch.tensor(np.stack(inputs, axis=0)).to(device)
+            inputs = torch.tensor(np.stack(inputs, axis=0), dtype=torch.long).to(device)
             attentions_mask = torch.tensor(np.stack(attentions_mask, axis=0) == 0).to(device)
-            output_indices = torch.tensor(np.stack(outputs, axis=0)).to(device)
+            output_indices = torch.tensor(np.stack(outputs, axis=0), dtype=torch.long).to(device)
 
             y_pred = model(inputs, output_indices[:, :-1], attentions_mask)
             y_pred = y_pred.transpose(1, 2)
diff --git a/src/batch_loading.py b/src/batch_loading.py
new file mode 100644
index 0000000..414bf93
--- /dev/null
+++ b/src/batch_loading.py
@@ -0,0 +1,79 @@
+import numpy as np
+import pandas as pd
+import dask.dataframe as dd
+from typing import Union
+
+def calculate_batch_buffer_id(batch_id: int, buffer_batch_num: int) -> int:
+    """Calculate which buffer should be loaded into memory for a given batch
+
+    Args:
+        batch_id (int): Id of the batch, counted from the start
+        buffer_batch_num (int): Number of batches that are loaded at once into memory
+
+    Returns:
+        int: Batch buffer id that needs to be in memory for a given batch
+    """
+    return batch_id // buffer_batch_num
+
+def yield_batch_buffer_span(batch_size: int, batch_buffer_len: int, num_samples: int) -> np.array:
+    """Calculates which samples should be loaded in a given batch buffer
+
+    Args:
+        batch_buffer_id (int): Id of the buffer, counting from beggining
+        batch_buffer_size (int): Size of batch buffer (in number of batches)
+        num_samples (int): Number of samples in a dataset
+
+    Returns:
+        np.array: Contignous ids that should be loaded to memory for a given buffer 
+    """
+    batch_buffer_size = batch_size * batch_buffer_len
+
+    batch_buffer_id = 0
+
+    while batch_buffer_id < (num_samples / batch_buffer_size):
+        buffer_start = batch_buffer_size * batch_buffer_id
+        buffer_end = min(num_samples, buffer_start + batch_buffer_size)
+
+        yield np.arange(buffer_start, buffer_end, 1, np.long)
+        batch_buffer_id += 1
+
+def get_ordered_dataframe_len(df: Union[pd.DataFrame, dd.DataFrame]) -> int:
+    """Gets length of a dataframe, which ids are ORDERED CONTINUOUSLY from 0 to N
+    without counting all the elements
+
+    Args:
+        df (Union[pd.DataFrame, dd.DataFrame]): Dataframe
+
+    Returns:
+        int: Length of the dataframe
+    """
+    return df.tail(1).index.values[0] + 1
+
+def get_batches(df: dd.DataFrame, batch_size: int, batch_buffer_len: int, shuffled_ids: np.array, batch_start: int = 0) -> pd.DataFrame:
+    """Generator for getting batches from large Dask dataframe with implemented buffering
+
+    Args:
+        df (dd.DataFrame): Source dask dataframe
+        batch_size (int): Desired size of a batch
+        batch_buffer_len (int): Number of batches to load to memory at once
+        shuffled_ids (np.array): Shuffled order of samples
+
+    Returns:
+        pd.DataFrame: [description]
+
+    Yields:
+        Iterator[pd.DataFrame]: [description]
+    """
+    length = get_ordered_dataframe_len(df)
+
+    batch_id = batch_start
+
+    for batch_buffer_span in yield_batch_buffer_span(batch_size, batch_buffer_len, length):
+        buffer_ids = shuffled_ids[batch_buffer_span]
+        buffer = df.loc[buffer_ids].compute()
+
+        for i in range(batch_buffer_len):
+            batch_ids = buffer_ids[range(i * batch_size, min((i+1) * batch_size, len(buffer_ids)))]
+
+            yield buffer.loc[batch_ids]
+            batch_id += 1
diff --git a/src/processing.py b/src/processing.py
index 8e3501e..04589e3 100644
--- a/src/processing.py
+++ b/src/processing.py
@@ -6,7 +6,11 @@ import numpy as np
 from transformers import PreTrainedTokenizerFast
 from collections import defaultdict
 
-ACTIONS_KEYS = ['dot', 'upper_case', 'colon', 'elipsis', 'dash', 'question_mark']
+ACTIONS_KEYS = [
+    'dot',
+    'upper_case',
+    'colon',
+    'question_mark'
+]
 
 def empty_action_vector() -> np.ndarray:
     """Returns a do-nothing actions vector
@@ -54,9 +58,9 @@ def detect_actions(word: str, next_word: Optional[str]) -> Mapping[str, bool]:
     """
     # Unsuported characters
     word.replace(";", ".") 
-    word.replace('"', " ")
-    word.replace('(', " ")
-    word.replace(')', " ")
+    word.replace('"', "")
+    word.replace('(', "")
+    word.replace(')', "")
 
     while len(word) > 0 and not word[0].isalnum(): # remove proceding characters
         word = word[1:]
@@ -64,14 +68,10 @@ def detect_actions(word: str, next_word: Optional[str]) -> Mapping[str, bool]:
     if len(word) == 0:
         return dict(zip(ACTIONS_KEYS, [False] * len(ACTIONS_KEYS)))
 
-    has_colon = len(word) > 3 and word[-3:] == "..."
-
     actions = {
-        'dot': word[-1] == '.' and not has_colon,
+        'dot': word[-1] == '.',
         'upper_case': word[0].isupper(),
         'colon': word[-1] == ",",
-        'elipsis': has_colon,
-        'dash': next_word is not None and next_word == "-",
         'question_mark': word[-1] == "?"
     }
 
@@ -228,12 +228,8 @@ def recover_word(word: str, action: Mapping[str, bool]) -> str:
         word_result = word_result.capitalize()
     if action['colon']:
         word_result += ","
-    if action['elipsis']:
-        word_result += "..."
-    if action['dash']:
-        word_result += " -"
     if action['question_mark']:
-        word_result += " -"
+        word_result += "?"
 
     return word_result
 
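A round-trip sketch for the reduced action set (behaviour assumed from the
modified detect_actions/recover_word above; the sample word is hypothetical):

    from src.processing import detect_actions, recover_word

    actions = detect_actions("Hello,", next_word=None)
    # {'dot': False, 'upper_case': True, 'colon': True, 'question_mark': False}

    recover_word("hello", actions)
    # 'Hello,'
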
diff --git a/src/test_batch_loading.py b/src/test_batch_loading.py
new file mode 100644
index 0000000..613c1bb
--- /dev/null
+++ b/src/test_batch_loading.py
@@ -0,0 +1,47 @@
+import numpy as np
+import pandas as pd
+import dask.dataframe as dd
+from src.batch_loading import *
+
+def test_calculate_batch_buffer_id():
+    assert calculate_batch_buffer_id(0, 3) == 0
+    assert calculate_batch_buffer_id(1, 3) == 0
+    assert calculate_batch_buffer_id(2, 3) == 0
+    assert calculate_batch_buffer_id(3, 3) == 1
+    assert calculate_batch_buffer_id(4, 3) == 1
+    assert calculate_batch_buffer_id(5, 3) == 1
+    assert calculate_batch_buffer_id(6, 3) == 2
+
+def test_yield_batch_buffer_span():
+    ids = [0, 1, 2, 3, 4, 5, 6]
+
+    result = list(yield_batch_buffer_span(2, 2, len(ids)))
+
+    assert np.all(result[0] == [0, 1, 2, 3])
+    assert np.all(result[1] == [4, 5, 6])
+
+def test_get_ordered_dataframe_len():
+    df = pd.DataFrame({'a': [1, 2, 3, 4, 5, 6, 7]})
+
+    assert get_ordered_dataframe_len(df) == 7
+
+def test_get_batches():
+    batch_size = 2
+    batch_buffer_len = 2
+    pdf = pd.DataFrame({'a': [1,0,2,3,4,5,6]})
+    shuffled_ids = np.array([1, 0, 2, 3, 4, 5, 6])
+    df = dd.from_pandas(pdf, npartitions=2)
+
+    batches = list(get_batches(df, batch_size, batch_buffer_len, shuffled_ids))
+
+    assert np.all(batches[0]['a'].values == [0, 1])
+    assert np.all(batches[1]['a'].values == [2, 3])
+    assert np.all(batches[2]['a'].values == [4, 5])
+    assert np.all(batches[3]['a'].values == [6])
+
+    batches = list(get_batches(df, batch_size, batch_buffer_len, shuffled_ids, 1))
+
+    assert np.all(batches[1]['a'].values == [2, 3])
+    assert np.all(batches[2]['a'].values == [4, 5])
+    assert np.all(batches[3]['a'].values == [6])
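
The new tests can be run with pytest from the repository root (assuming the
src package is importable from there):

    python -m pytest src/test_batch_loading.py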
-- 
GitLab


From a10f7b9c4ea98cd1987d11b35a2067fb37009533 Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Thu, 6 Aug 2020 14:48:11 +0200
Subject: [PATCH 037/116] Actions reduction fix

---
 src/processing.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/src/processing.py b/src/processing.py
index 04589e3..eea9f5e 100644
--- a/src/processing.py
+++ b/src/processing.py
@@ -244,8 +244,7 @@ def is_sentence_end(actions_encoded: np.ndarray) -> bool:
     """
     actions_decoded = decode_actions(actions_encoded)
 
-    return (actions_decoded['dot']
-            or actions_decoded['elipsis'])
+    return actions_decoded['dot']
 
 def nearest_sentence_l(labels: np.array, index_start: int) -> int:
     """Find nearest word that begins a sentence that has lower or equal index to index_start
-- 
GitLab


From 2d9b3d9861158ccb7da0bb512e12c27d15d39bcf Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Thu, 6 Aug 2020 13:40:23 +0000
Subject: [PATCH 038/116] Started service coding

---
 notebooks/test_actions_model.ipynb | 29 +++++++++++++++--------------
 punctuator.py                      |  6 ++++++
 2 files changed, 21 insertions(+), 14 deletions(-)
 create mode 100644 punctuator.py

diff --git a/notebooks/test_actions_model.ipynb b/notebooks/test_actions_model.ipynb
index d1bae4e..4f2484e 100644
--- a/notebooks/test_actions_model.ipynb
+++ b/notebooks/test_actions_model.ipynb
@@ -14,7 +14,7 @@
   },
   "orig_nbformat": 2,
   "kernelspec": {
-   "name": "python38264bita7d7da14168440cb9836372958035d4a",
+   "name": "python_defaultSpec_1596719587498",
    "display_name": "Python 3.8.2 64-bit"
   }
  },
@@ -23,7 +23,7 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count": 45,
+   "execution_count": 1,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -42,7 +42,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 46,
+   "execution_count": 2,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -51,7 +51,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 47,
+   "execution_count": 3,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -60,11 +60,11 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 48,
+   "execution_count": 16,
    "metadata": {},
    "outputs": [],
    "source": [
-    "expected = \"W porcie w stolicy Libanu doszło we wtorek po południu do potężnej eksplozji. Zginęło co najmniej sto osób, ok. 4 tys. zostało rannych. Z dotychczasowych informacji wynika, że przyczyną wybuchu mogły być chemikalia, w tym saletra amonowa. Trop wiedzie do statku pod mołdawską banderą.\"\n",
+    "expected = \"w niedzielę kiedy prażanie są precz i moczą nogi wieku bereuce na wzgórze zamkowe królów czeskich napierają watahy cudzoziemców podnieconych obecnością w miejscu ważnym dla historii od dziesięciu wieków i ponownie od czasu kiedy na hrad wyniesiono vaclava havla zamek chwilowo nie ma gospodarza ale wszystko jest pod parą i gwardia gotowa zawsze do stania u boku katedra świętego wita jest dla czechów tym czy dla polaków katedra na wawelu żelazny szlak turystyczny prowadzi złotą uliczką gdzie zaczyna boleć głowa od nadmiaru okazji do pamiątkowego zdjęcia w tym małym domu mieszkał wielki kafka zanim się dostaniemy na lewy brzeg wełtawy wypada spędzić dłuższą chwilę na samym moście bujnie opiewanym i jeśli nie najpiękniejszym to z pewnością jedynym w roli tak ciągle twórczej dla folkloru wielkiego miasta najpierw śmierć uderza młoteczkiem w dzwonki a potem defiluje i przedstawia się turystom dwanaście apostołów zegar orloja jest najwyższą z wiekowych atrakcji na rynku starego miasta nie budujemy drugich czech choć bliżej nam do pragi niż do tokio i już tysiąc lat temu zaczął święty wojciech ciekawe tylko czy jan hus przetrwałby na pomniku w drugich czechach i czy miałby tyle szans co święty wacław przy wiślanym szlaku od mostu poniatowskiego w dół rzeki trzyma się bar pod rurą i może jeszcze ze dwie budy gdzie po dawnemu sprzedaje się piwo marki piwo temperatura zależy od słoneczka szkło nie gra roli a niewiastom wstęp niewskazany w gazetach do niedawna pisano że piwo degraduje polaków w rzeczywistości było dokładnie na odwrót amerykański desant nie uratował zagrożonej cnoty polaków bo ciepłe piwo jest mimo wszystko mniej obrzydliwe od ciepłej coca coli miejsce właściwe obu napojom w kulturze i obyczaju wyznaczył dopiero wolny rynek \"\n",
     "text_clean = create_model_input_output(expected)[0]\n",
     "\n",
     "inputs = tokenizer(text_clean, return_tensors=\"pt\")"
@@ -72,7 +72,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 49,
+   "execution_count": 17,
    "metadata": {
     "tags": []
    },
@@ -80,7 +80,7 @@
     {
      "output_type": "stream",
      "name": "stderr",
-     "text": "Some weights of the model checkpoint at bert-base-multilingual-cased were not used when initializing BertForTokenClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']\n- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).\n- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\nSome weights of BertForTokenClassification were not initialized from the model checkpoint at bert-base-multilingual-cased and are newly initialized: ['classifier.weight', 'classifier.bias']\nYou should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
+     "text": "Some weights of the model checkpoint at dkleczek/bert-base-polish-cased-v1 were not used when initializing BertForTokenClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.decoder.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias']\n- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).\n- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\nSome weights of BertForTokenClassification were not initialized from the model checkpoint at dkleczek/bert-base-polish-cased-v1 and are newly initialized: ['classifier.weight', 'classifier.bias']\nYou should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     },
     {
      "output_type": "execute_result",
@@ -88,17 +88,18 @@
       "text/plain": "<All keys matched successfully>"
      },
      "metadata": {},
-     "execution_count": 49
+     "execution_count": 17
     }
    ],
    "source": [
-    "model = BertForTokenClassification.from_pretrained(MODEL_BASE, num_labels=6)\n",
-    "model.load_state_dict(torch.load(\"../checkpoints/actions/0-183000.model\", map_location={'cuda:0': 'cpu'}))\n"
+    "model = BertForTokenClassification.from_pretrained(MODEL_BASE, num_labels=4)\n",
+    "device = torch.device(\"cpu\")\n",
+    "model.load_state_dict(torch.load(\"../checkpoints/actions/0-2000.model\", map_location=device))\n"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 51,
+   "execution_count": 21,
    "metadata": {
     "tags": []
    },
@@ -106,7 +107,7 @@
     {
      "output_type": "stream",
      "name": "stdout",
-     "text": "Dekretem polskiej władzy państwowej stworzono na wyzwolonym obszarze Rzeczypospolitej specjalny sąd karny, w którego kompetencje wchodzą sprawy o zdrady narodu polskiego. Przed sądem stanęli renegaci, którzy nie tylko wyrzekli się polskości, ale postępowaniem swym czynnie pomagali Niemcom w ich zbrodniach. Podajemy fragmenty z rozprawy przeciwko folksdojczowi Musialskiemu, który jako kierownik niemieckiego obozu pracy znęcał się nad obywatelami polskimi. Oskarżony Musielski. Świadek Jankowska opowiedziała, jak bił on Polaków po twarzy, kopał i groził Majdankiem. Świadek Stankiewicz stwierdził, że Musielski przewyższył swym okrucieństwem poprzednich kierowników obozu, Niemców. Prokurator doktor Sawicki zażądał dla oskarżonego kary śmierci. Po naradzie sąd skazał Musielskiego na karę śmierci przez powieszenie.\n------\nDekretem polskiej władzy państwowej stworzono na wyzwolonym obszarze Rzeczypospolitej specjalny sąd karny, w którego kompetencje wchodzą sprawy o zdrady narodu polskiego. Przed sądem stanęli renegaci, którzy nie tylko wyrzekli się polskości, ale postępowaniem swym czynnie pomagali Niemcom w ich zbrodniach. Podajemy fragmenty z rozprawy przeciwko folksdojczowi Musialskiemu, który jako kierownik Niemieckiego Obozu pracy znęcał się nad obywatelami polskimi. oskarżony musielski świadek Jankowska opowiedziała, jak bił on Polaków po twarzy kopał i groził majdankiem świadek. Stankiewicz stwierdził, że Musielski przewyższył swym okrucieństwem poprzednich kierowników obozu Niemców. Prokurator doktor Sawicki zażądał dla oskarżonego kary śmierci. Po naradzie sąd skazał Musielskiego na karę śmierci przez powieszenie.\n"
+     "text": "w niedzielę kiedy prażanie są precz i moczą nogi wieku bereuce na wzgórze zamkowe królów czeskich napierają watahy cudzoziemców podnieconych obecnością w miejscu ważnym dla historii od dziesięciu wieków i ponownie od czasu kiedy na hrad wyniesiono vaclava havla zamek chwilowo nie ma gospodarza ale wszystko jest pod parą i gwardia gotowa zawsze do stania u boku katedra świętego wita jest dla czechów tym czy dla polaków katedra na wawelu żelazny szlak turystyczny prowadzi złotą uliczką gdzie zaczyna boleć głowa od nadmiaru okazji do pamiątkowego zdjęcia w tym małym domu mieszkał wielki kafka zanim się dostaniemy na lewy brzeg wełtawy wypada spędzić dłuższą chwilę na samym moście bujnie opiewanym i jeśli nie najpiękniejszym to z pewnością jedynym w roli tak ciągle twórczej dla folkloru wielkiego miasta najpierw śmierć uderza młoteczkiem w dzwonki a potem defiluje i przedstawia się turystom dwanaście apostołów zegar orloja jest najwyższą z wiekowych atrakcji na rynku starego miasta nie budujemy drugich czech choć bliżej nam do pragi niż do tokio i już tysiąc lat temu zaczął święty wojciech ciekawe tylko czy jan hus przetrwałby na pomniku w drugich czechach i czy miałby tyle szans co święty wacław przy wiślanym szlaku od mostu poniatowskiego w dół rzeki trzyma się bar pod rurą i może jeszcze ze dwie budy gdzie po dawnemu sprzedaje się piwo marki piwo temperatura zależy od słoneczka szkło nie gra roli a niewiastom wstęp niewskazany w gazetach do niedawna pisano że piwo degraduje polaków w rzeczywistości było dokładnie na odwrót amerykański desant nie uratował zagrożonej cnoty polaków bo ciepłe piwo jest mimo wszystko mniej obrzydliwe od ciepłej coca coli miejsce właściwe obu napojom w kulturze i obyczaju wyznaczył dopiero wolny rynek \n------\nW niedzielę, kiedy prażanie są precz I moczą nogi, wieku Bereuce na Wzgórze zamkowe Królów czeskich napierają watahy cudzoziemców, podnieconych obecnością W miejscu, ważnym, Dla historii Od dziesięciu wieków I ponownie, Od czasu, kiedy Na Hrad wyniesiono vaclava Havla Zamek chwilowo Nie ma gospodarza, Ale wszystko jest Pod parą I gwardia gotowa zawsze Do stania U boku. Katedra Świętego wita jest Dla Czechów tym, czy Dla Polaków Katedra Na Wawelu Żelazny Szlak turystyczny prowadzi Złotą uliczką, Gdzie zaczyna boleć głowa Od nadmiaru, okazji, Do pamiątkowego zdjęcia. W tym Małym domu mieszkał Wielki kafka Zanim się dostaniemy na lewy brzeg, wełtawy wypada spędzić dłuższą chwilę Na samym moście, bujnie opiewanym i, jeśli nie najpiękniejszym, to z pewnością, jedynym W roli tak ciągle, twórczej Dla folkloru, Wielkiego miasta. Najpierw śmierć uderza Młoteczkiem w dzwonki, A potem defiluje I przedstawia się turystom Dwanaście apostołów. Zegar Orloja jest najwyższą z wiekowych atrakcji. Na rynku starego miasta. Nie budujemy drugich czech Choć bliżej nam, do Pragi Niż do Tokio I już tysiąc lat temu, zaczął Święty wojciech Ciekawe tylko, czy Jan Hus przetrwałby Na pomniku W Drugich Czechach I Czy miałby tyle szans, co Święty Wacław Przy Wiślanym szlaku Od Mostu Poniatowskiego W dół rzeki Trzyma się bar Pod rurą I może jeszcze ze dwie budy, gdzie, po dawnemu, sprzedaje się piwo, marki piwo. Temperatura zależy od słoneczka Szkło nie gra roli, a niewiastom wstęp niewskazany. W gazetach Do niedawna pisano, że piwo degraduje Polaków W rzeczywistości, było, dokładnie Na odwrót. 
Amerykański desant Nie uratował zagrożonej cnoty polaków Bo ciepłe piwo, jest, Mimo wszystko, mniej obrzydliwe, Od ciepłej, coca coli miejsce właściwe Obu napojom, W kulturze, i obyczaju, Wyznaczył dopiero wolny rynek\n"
     }
    ],
    "source": [
@@ -115,7 +116,7 @@
     "y_pred = model(**inputs)[0].sigmoid()\n",
     "labels_pred = token_labels_to_word_labels(text_clean, y_pred.detach().numpy()[0, 1:-1, :], tokenizer)\n",
     "\n",
-    "actions = labels_pred > 0.5\n",
+    "actions = labels_pred > 0.8\n",
     "print(expected)\n",
     "print(\"------\")\n",
     "print(recover_text(text_clean, actions))"
diff --git a/punctuator.py b/punctuator.py
new file mode 100644
index 0000000..0141a19
--- /dev/null
+++ b/punctuator.py
@@ -0,0 +1,6 @@
+import argparse
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Adds punctuaiton in to raw text stream.")
+    parser.add_argument('-i', '--input', type=str, required=True, help="Path to input text file")
+    parser.add_argument('-o', '--output', type=str, required=True, help="Path to input text file")
\ No newline at end of file
-- 
GitLab


From acbd2b764c34fda33d0f36b44bf69fc6e4740f28 Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Fri, 7 Aug 2020 14:31:11 +0200
Subject: [PATCH 039/116] Most of code now conforms to the Clarin & PEP8
 standards

---
 .gitignore                                    |   1 +
 .vscode/.ropeproject/config.py                | 123 +++++++++++++
 .vscode/.ropeproject/objectdb                 | Bin 0 -> 6 bytes
 .vscode/launch.json                           |  22 +++
 .vscode/settings.json                         |  25 +++
 Dockerfile                                    |   1 +
 docker/{ => development}/Dockerfile           |   0
 dvc.yaml                                      |  36 ++--
 generated/.gitignore                          |   6 +-
 main.py                                       |  37 ++++
 params.yaml                                   |   7 +-
 punctuator.py                                 |   6 -
 scripts/actions_based/processing.py           |  82 ---------
 scripts/actions_based/stage2_tokenization.py  |  46 -----
 scripts/actions_based/stage3_exploding.py     |  31 ----
 scripts/actions_based/stage4_reindexing.py    |  43 -----
 scripts/actions_based/stage5_stats.py         |  52 ------
 scripts/translation_based/stage3_exploding.py |  30 ----
 src/batch_loading.py                          |  28 ++-
 src/models/TransformerSeq2Seq.py              |  82 ++++++---
 {scripts => src/pipelines}/__init__.py        |   0
 .../pipelines}/actions_based/__init__.py      |   0
 src/pipelines/actions_based/processing.py     |  90 ++++++++++
 .../actions_based/stage1_extraction.py        |  34 ++--
 .../actions_based/stage2_tokenization.py      |  46 +++++
 .../actions_based/stage3_exploding.py         |  37 ++++
 .../actions_based/stage4_reindexing.py        |  37 ++++
 src/pipelines/actions_based/stage5_stats.py   |  58 ++++++
 .../pipelines}/actions_based/train.py         | 101 +++++++----
 .../pipelines}/translation_based/__init__.py  |   0
 .../translation_based/processing.py           | 166 +++++++++++-------
 .../translation_based/stage1_extraction.py    |  28 ++-
 .../stage2_create_batches.py                  |  28 ++-
 .../translation_based/stage3_exploding.py     |  37 ++++
 .../translation_based/stage4_reindexing.py    |  22 +--
 .../pipelines}/translation_based/train.py     | 119 +++++++++----
 src/processing.py                             | 159 +++++++++++------
 src/test_batch_loading.py                     |  47 -----
 src/utils.py                                  |  36 ++--
 tests/__init__.py                             |   0
 tests/pipelines/__init__.py                   |   0
 tests/pipelines/translation_based/__init__.py |   0
 .../translation_based/test_processing.py      | 110 +++++++++---
 tests/test_batch_loading.py                   |  58 ++++++
 {src => tests}/test_processing.py             | 152 ++++++++--------
 tox.ini                                       |  48 +++++
 46 files changed, 1330 insertions(+), 741 deletions(-)
 create mode 100644 .vscode/.ropeproject/config.py
 create mode 100644 .vscode/.ropeproject/objectdb
 create mode 100644 .vscode/launch.json
 create mode 100644 .vscode/settings.json
 create mode 100644 Dockerfile
 rename docker/{ => development}/Dockerfile (100%)
 create mode 100755 main.py
 delete mode 100644 punctuator.py
 delete mode 100644 scripts/actions_based/processing.py
 delete mode 100644 scripts/actions_based/stage2_tokenization.py
 delete mode 100644 scripts/actions_based/stage3_exploding.py
 delete mode 100644 scripts/actions_based/stage4_reindexing.py
 delete mode 100644 scripts/actions_based/stage5_stats.py
 delete mode 100644 scripts/translation_based/stage3_exploding.py
 rename {scripts => src/pipelines}/__init__.py (100%)
 rename {scripts => src/pipelines}/actions_based/__init__.py (100%)
 create mode 100644 src/pipelines/actions_based/processing.py
 rename {scripts => src/pipelines}/actions_based/stage1_extraction.py (59%)
 create mode 100644 src/pipelines/actions_based/stage2_tokenization.py
 create mode 100644 src/pipelines/actions_based/stage3_exploding.py
 create mode 100644 src/pipelines/actions_based/stage4_reindexing.py
 create mode 100644 src/pipelines/actions_based/stage5_stats.py
 rename {scripts => src/pipelines}/actions_based/train.py (54%)
 rename {scripts => src/pipelines}/translation_based/__init__.py (100%)
 rename {scripts => src/pipelines}/translation_based/processing.py (71%)
 rename {scripts => src/pipelines}/translation_based/stage1_extraction.py (52%)
 rename {scripts => src/pipelines}/translation_based/stage2_create_batches.py (53%)
 create mode 100644 src/pipelines/translation_based/stage3_exploding.py
 rename {scripts => src/pipelines}/translation_based/stage4_reindexing.py (56%)
 rename {scripts => src/pipelines}/translation_based/train.py (50%)
 delete mode 100644 src/test_batch_loading.py
 create mode 100644 tests/__init__.py
 create mode 100644 tests/pipelines/__init__.py
 create mode 100644 tests/pipelines/translation_based/__init__.py
 rename {scripts => tests/pipelines}/translation_based/test_processing.py (58%)
 create mode 100644 tests/test_batch_loading.py
 rename {src => tests}/test_processing.py (63%)
 create mode 100644 tox.ini

diff --git a/.gitignore b/.gitignore
index 7db4ff0..1414e6a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -10,3 +10,4 @@ __pycache__
 .pytest_cache
 /checkpoints
 .dvc
+.tox
\ No newline at end of file
diff --git a/.vscode/.ropeproject/config.py b/.vscode/.ropeproject/config.py
new file mode 100644
index 0000000..c339bc7
--- /dev/null
+++ b/.vscode/.ropeproject/config.py
@@ -0,0 +1,123 @@
+# The default ``config.py``
+# flake8: noqa
+
+
+def set_prefs(prefs):
+    """This function is called before opening the project"""
+
+    # Specify which files and folders to ignore in the project.
+    # Changes to ignored resources are not added to the history and
+    # VCSs.  Also they are not returned in `Project.get_files()`.
+    # Note that ``?`` and ``*`` match all characters but slashes.
+    # '*.pyc': matches 'test.pyc' and 'pkg/test.pyc'
+    # 'mod*.pyc': matches 'test/mod1.pyc' but not 'mod/1.pyc'
+    # '.svn': matches 'pkg/.svn' and all of its children
+    # 'build/*.o': matches 'build/lib.o' but not 'build/sub/lib.o'
+    # 'build//*.o': matches 'build/lib.o' and 'build/sub/lib.o'
+    prefs["ignored_resources"] = [
+        "*.pyc",
+        "*~",
+        ".ropeproject",
+        ".hg",
+        ".svn",
+        "_svn",
+        ".git",
+        ".tox",
+    ]
+
+    # Specifies which files should be considered python files.  It is
+    # useful when you have scripts inside your project.  Only files
+    # ending with ``.py`` are considered to be python files by
+    # default.
+    # prefs['python_files'] = ['*.py']
+
+    # Custom source folders:  By default rope searches the project
+    # for finding source folders (folders that should be searched
+    # for finding modules).  You can add paths to that list.  Note
+    # that rope guesses project source folders correctly most of the
+    # time; use this if you have any problems.
+    # The folders should be relative to project root and use '/' for
+    # separating folders regardless of the platform rope is running on.
+    # 'src/my_source_folder' for instance.
+    # prefs.add('source_folders', 'src')
+
+    # You can extend python path for looking up modules
+    # prefs.add('python_path', '~/python/')
+
+    # Should rope save object information or not.
+    prefs["save_objectdb"] = True
+    prefs["compress_objectdb"] = False
+
+    # If `True`, rope analyzes each module when it is being saved.
+    prefs["automatic_soa"] = True
+    # The depth of calls to follow in static object analysis
+    prefs["soa_followed_calls"] = 0
+
+    # If `False` when running modules or unit tests "dynamic object
+    # analysis" is turned off.  This makes them much faster.
+    prefs["perform_doa"] = True
+
+    # Rope can check the validity of its object DB when running.
+    prefs["validate_objectdb"] = True
+
+    # How many undos to hold?
+    prefs["max_history_items"] = 32
+
+    # Shows whether to save history across sessions.
+    prefs["save_history"] = True
+    prefs["compress_history"] = False
+
+    # Set the number spaces used for indenting.  According to
+    # :PEP:`8`, it is best to use 4 spaces.  Since most of rope's
+    # unit-tests use 4 spaces it is more reliable, too.
+    prefs["indent_size"] = 4
+
+    # Builtin and c-extension modules that are allowed to be imported
+    # and inspected by rope.
+    prefs["extension_modules"] = []
+
+    # Add all standard c-extensions to extension_modules list.
+    prefs["import_dynload_stdmods"] = True
+
+    # If `True` modules with syntax errors are considered to be empty.
+    # The default value is `False`; When `False` syntax errors raise
+    # `rope.base.exceptions.ModuleSyntaxError` exception.
+    prefs["ignore_syntax_errors"] = False
+
+    # If `True`, rope ignores unresolvable imports.  Otherwise, they
+    # appear in the importing namespace.
+    prefs["ignore_bad_imports"] = False
+
+    # If `True`, rope will insert new module imports as
+    # `from <package> import <module>` by default.
+    prefs["prefer_module_from_imports"] = False
+
+    # If `True`, rope will transform a comma list of imports into
+    # multiple separate import statements when organizing
+    # imports.
+    prefs["split_imports"] = False
+
+    # If `True`, rope will remove all top-level import statements and
+    # reinsert them at the top of the module when making changes.
+    prefs["pull_imports_to_top"] = True
+
+    # If `True`, rope will sort imports alphabetically by module name instead
+    # of alphabetically by import statement, with from imports after normal
+    # imports.
+    prefs["sort_imports_alphabetically"] = False
+
+    # Location of implementation of
+    # rope.base.oi.type_hinting.interfaces.ITypeHintingFactory In general
+    # case, you don't have to change this value, unless you're an rope expert.
+    # Change this value to inject you own implementations of interfaces
+    # listed in module rope.base.oi.type_hinting.providers.interfaces
+    # For example, you can add you own providers for Django Models, or disable
+    # the search type-hinting in a class hierarchy, etc.
+    prefs["type_hinting_factory"] = (
+        "rope.base.oi.type_hinting.factory.default_type_hinting_factory"
+    )
+
+
+def project_opened(project):
+    """This function is called after opening the project"""
+    # Do whatever you like here!
diff --git a/.vscode/.ropeproject/objectdb b/.vscode/.ropeproject/objectdb
new file mode 100644
index 0000000000000000000000000000000000000000..0a47446c0ad231c193bdd44ff327ba2ab28bf3d8
GIT binary patch
literal 6
NcmZo*sx4&D0{{kv0iOT>

literal 0
HcmV?d00001

diff --git a/.vscode/launch.json b/.vscode/launch.json
new file mode 100644
index 0000000..6e10b9f
--- /dev/null
+++ b/.vscode/launch.json
@@ -0,0 +1,22 @@
+{
+    // Use IntelliSense to learn about possible attributes.
+    // Hover to view descriptions of existing attributes.
+    // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
+    "version": "0.2.0",
+    "configurations": [
+        {
+            "name": "Python: Current File",
+            "type": "python",
+            "request": "launch",
+            "program": "${file}",
+            "console": "integratedTerminal",
+            "cwd": "${workspaceFolder}/dataset_generation"
+        },
+        {
+            "name": "Python: Attach using Process Id",
+            "type": "python",
+            "request": "attach",
+            "processId": "${command:pickProcess}"
+        }
+    ]
+}
\ No newline at end of file
diff --git a/.vscode/settings.json b/.vscode/settings.json
new file mode 100644
index 0000000..f55b7bf
--- /dev/null
+++ b/.vscode/settings.json
@@ -0,0 +1,25 @@
+{
+    "python.testing.pytestArgs": [
+        "tests",
+    ],
+    "python.testing.unittestEnabled": false,
+    "python.testing.nosetestsEnabled": false,
+    "python.testing.pytestEnabled": true,
+    "python.testing.unittestArgs": [
+        "-v",
+        "-s",
+        "./src",
+        "-p",
+        "test_*.py"
+    ],
+    "files.watcherExclude": {
+        "**/.git": true,
+        "**/.svn": true,
+        "**/.hg": true,
+        "**/CVS": true,
+        "**/.DS_Store": true,
+        "data/*": true
+    },
+    "python.testing.cwd": "${workspaceFolder}",
+    "docker.host": "ssh://mpogoda@156.17.135.51"
+}
\ No newline at end of file
diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 0000000..ead19b1
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1 @@
+FROM clarinpl/python:3.8
\ No newline at end of file
diff --git a/docker/Dockerfile b/docker/development/Dockerfile
similarity index 100%
rename from docker/Dockerfile
rename to docker/development/Dockerfile
diff --git a/dvc.yaml b/dvc.yaml
index 7875e8d..bc29243 100644
--- a/dvc.yaml
+++ b/dvc.yaml
@@ -1,18 +1,18 @@
 stages:
   actions_extraction:
-    cmd: python3 -m scripts.actions_based.stage1_extraction
+    cmd: python3 -m src.pipelines.actions_based.stage1_extraction
     deps:
     - data
-    - scripts/actions_based/stage1_extraction.py
+    - src/pipelines/actions_based/stage1_extraction.py
     params:
     - actions.extraction.num_partitions
     outs:
     - generated/actions/stage1_extraction
   actions_tokenization:
-    cmd: python3 -m scripts.actions_based.stage2_tokenization
+    cmd: python3 -m src.pipelines.actions_based.stage2_tokenization
     deps:
     - generated/actions/stage1_extraction
-    - scripts/actions_based/stage2_tokenization.py
+    - src/pipelines/actions_based/stage2_tokenization.py
     params:
     - actions.tokenization.max_tokens
     - actions.tokenization.min_tokens
@@ -20,32 +20,32 @@ stages:
     outs:
     - generated/actions/stage2_tokenization
   actions_exploding:
-    cmd: python3 -m scripts.actions_based.stage3_exploding
+    cmd: python3 -m src.pipelines.actions_based.stage3_exploding
     deps:
     - generated/actions/stage2_tokenization
-    - scripts/actions_based/stage3_exploding.py
+    - src/pipelines/actions_based/stage3_exploding.py
     outs:
     - generated/actions/stage3_exploding
   actions_reindexing:
-    cmd: python3 -m scripts.actions_based.stage4_reindexing
+    cmd: python3 -m src.pipelines.actions_based.stage4_reindexing
     deps:
     - generated/actions/stage3_exploding
-    - scripts/actions_based/stage4_reindexing.py
+    - src/pipelines/actions_based/stage4_reindexing.py
     outs:
     - generated/actions/stage4_reindexing
   actions_stats:
-    cmd: python3 -m scripts.actions_based.stage5_stats
+    cmd: python3 -m src.pipelines.actions_based.stage5_stats
     deps:
     - generated/actions/stage4_reindexing
-    - scripts/actions_based/stage5_stats.py
+    - src/pipelines/actions_based/stage5_stats.py
     outs:
     - generated/actions/stage5_stats
   actions_training:
-    cmd: python3 -m scripts.actions_based.train
+    cmd: python3 -m src.pipelines.actions_based.train
     deps:
     - generated/actions/stage4_reindexing
     - generated/actions/stage5_stats
-    - scripts/actions_based/train.py
+    - src/pipelines/actions_based/train.py
     params:
     - global.base_model
     - actions.training.max_training_time
@@ -56,7 +56,7 @@ stages:
     outs:
     - checkpoints/actions
   translations_extraction:
-    cmd: python3 -m scripts.translation_based.stage1_extraction
+    cmd: python3 -m src.pipelines.translation_based.stage1_extraction
     deps:
     - data
     params:
@@ -64,7 +64,7 @@ stages:
     outs:
     - generated/translations/stage1_extraction
   translations_create_batches:
-    cmd: python3 -m scripts.translation_based.stage2_create_batches
+    cmd: python3 -m src.pipelines.translation_based.stage2_create_batches
     deps:
     - generated/translations/stage1_extraction
     params:
@@ -72,22 +72,22 @@ stages:
     outs:
     - generated/translations/stage2_create_batches
   translations_exploding:
-    cmd: python3 -m scripts.translation_based.stage3_exploding
+    cmd: python3 -m src.pipelines.translation_based.stage3_exploding
     deps:
     - generated/translations/stage2_create_batches
     outs:
     - generated/translations/stage3_exploding
   translations_reindexing:
-    cmd: python3 -m scripts.translation_based.stage4_reindexing
+    cmd: python3 -m src.pipelines.translation_based.stage4_reindexing
     deps:
     - generated/translations/stage3_exploding
     outs:
     - generated/translations/stage4_reindexing
   translations_training:
-    cmd: python3 -m scripts.translation_based.train
+    cmd: python3 -m src.pipelines.translation_based.train
     deps:
     - generated/translations/stage4_reindexing
-    - scripts/translation_based/train.py
+    - src/pipelines/translation_based/train.py
     params:
     - global.base_model
     - translations.training.max_training_time
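
With the stages renamed, the pipeline is still driven by DVC as before; a
single stage can be reproduced with, e.g.:

    dvc repro actions_training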
diff --git a/generated/.gitignore b/generated/.gitignore
index 959c3a4..c96a04f 100644
--- a/generated/.gitignore
+++ b/generated/.gitignore
@@ -1,4 +1,2 @@
-/stage1_extraction
-/stage2_tokenization
-/stage3_exploding
-/stage4_reindexing
+*
+!.gitignore
\ No newline at end of file
diff --git a/main.py b/main.py
new file mode 100755
index 0000000..e01ced9
--- /dev/null
+++ b/main.py
@@ -0,0 +1,37 @@
+#!/usr/bin/python3
+
+import argparse
+import os
+import sys
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description="Adds punctuaiton in to raw text stream."
+    )
+    parser.add_argument(
+        "-i",
+        "--input",
+        type=str,
+        required=True,
+        help="Path to input text file",
+    )
+    parser.add_argument(
+        "-o",
+        "--output",
+        type=str,
+        required=True,
+        help="Path to input text file",
+    )
+    parser.add_argument(
+        "-m",
+        "--model",
+        type=str,
+        choices=["actions", "translation"],
+        default="actions",
+        help="Selects which model will be used. Defaults to actions",
+    )
+
+    args = parser.parse_args()
+
+    if not os.path.exists(args.input):
+        print(f"Error: File '{args.input}' does not exists", file=sys.stderr)
diff --git a/params.yaml b/params.yaml
index 9b3fb5b..93c1563 100644
--- a/params.yaml
+++ b/params.yaml
@@ -1,6 +1,7 @@
 global:
     dashboard_port: 8787
     base_model: "dkleczek/bert-base-polish-cased-v1"
+    random_seed: 44
 
 actions:
     extraction:
@@ -29,12 +30,12 @@ actions:
     training:
         learning_rate: 0.0001
         num_epochs: 5
-        batch_size: 8
+        batch_size: 2
         save_step: 1000
         max_training_time: null
         loss_averaging_span: 1000
-        fresh_start: false
-        device: "cuda:1"
+        fresh_start: true
+        device: "cuda:0"
 translations:
     extraction:
         num_partitions: 2_000
diff --git a/punctuator.py b/punctuator.py
deleted file mode 100644
index 0141a19..0000000
--- a/punctuator.py
+++ /dev/null
@@ -1,6 +0,0 @@
-import argparse
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Adds punctuation into raw text stream.")
-    parser.add_argument('-i', '--input', type=str, required=True, help="Path to input text file")
-    parser.add_argument('-o', '--output', type=str, required=True, help="Path to output text file")
\ No newline at end of file
diff --git a/scripts/actions_based/processing.py b/scripts/actions_based/processing.py
deleted file mode 100644
index 8e8db74..0000000
--- a/scripts/actions_based/processing.py
+++ /dev/null
@@ -1,82 +0,0 @@
-from transformers import BertTokenizerFast
-from src.processing import tokenize_labeled_text, batchify_data
-import numpy as np
-
-def expand_dims(entry: dict):
-    inputs = entry.input.reshape(entry.input_shape)
-    outputs = entry.output.reshape(entry.output_shape)
-    masks = entry.attention_mask.reshape(entry.attention_mask_shape)
-
-    return {
-        'input': inputs,
-        'output': outputs,
-        "attention_mask": masks,
-    }
-
-EXPAND_DIMS_META = {
-    'input': object,
-    'output': object,
-    'attention_mask': object
-}
-
-def apply_tokenization(df, min_tokens: int, max_tokens: int, tokenizer: BertTokenizerFast):
-    text_clean = df.input
-    labels = df.output
-    shape = df.output_shape
-
-    tokens, token_labels = tokenize_labeled_text(
-        text_clean, labels.reshape(shape), tokenizer)
-
-    inputs, outputs, attentions = batchify_data(
-        tokens, token_labels, max_tokens, tokenizer, min_tokens)
-
-    inputs_shape = np.array(inputs.shape)
-    outputs_shape = np.array(outputs.shape)
-    attentions_shape = np.array(attentions.shape)
-
-    return {
-        'input': inputs.reshape(-1),
-        'output': outputs.reshape(-1),
-        'attention_mask': attentions.reshape(-1),
-        'input_shape': inputs_shape,
-        'output_shape': outputs_shape,
-        'attention_mask_shape': attentions_shape
-    }
-
-
-APPLY_TOKENIZATION_META = {
-    'input': object,
-    'output': object,
-    'attention_mask': object,
-    'input_shape': object,
-    'output_shape': object,
-    'attention_mask_shape': object
-}
-
-def flatten_dims(entry):
-    inputs_shape = np.array(entry.input.shape)
-    outputs_shape = np.array(entry.output.shape)
-    attentions_shape = np.array(entry.attention_mask.shape)
-
-    inputs = entry.input.reshape(-1)
-    outputs = entry.output.reshape(-1)
-    attentions = entry.attention_mask.reshape(-1)
-
-    return {
-        'input': inputs,
-        'output': outputs,
-        'attention_mask': attentions,
-        'input_shape': inputs_shape,
-        'output_shape': outputs_shape,
-        'attention_mask_shape': attentions_shape
-    }
-
-
-FLATTEN_DIMS_META = {
-    'input': object,
-    'output': object,
-    'attention_mask': object,
-    'input_shape': object,
-    'output_shape': object,
-    'attention_mask_shape': object
-}
diff --git a/scripts/actions_based/stage2_tokenization.py b/scripts/actions_based/stage2_tokenization.py
deleted file mode 100644
index fae7d86..0000000
--- a/scripts/actions_based/stage2_tokenization.py
+++ /dev/null
@@ -1,46 +0,0 @@
-# /usr/bin/python3
-import os
-import glob
-import random
-from lxml import etree
-import uuid
-import hashlib
-import seaborn as sns
-import re
-import numpy as np
-from tqdm import tqdm
-from src.processing import tokenize_labeled_text, batchify_data
-from src.utils import remove_multiple_spaces, remove_punctuation, PROJECT_ROOT, get_config, prepare_folder
-import dask
-from dask.diagnostics import ProgressBar
-import dask.dataframe as dd
-from transformers import BertTokenizerFast
-from dask.distributed import Client
-from scripts.actions_based.processing import apply_tokenization, APPLY_TOKENIZATION_META
-
-INPUT_FOLDER = f"{PROJECT_ROOT}/generated/actions/stage1_extraction"
-OUTPUT_FOLDER = f"{PROJECT_ROOT}/generated/actions/stage2_tokenization"
-
-if __name__ == "__main__":
-
-    config = get_config()
-    max_tokens = config['actions']['tokenization']['max_tokens']
-    min_tokens = config['actions']['tokenization']['min_tokens']
-    num_workers = config['actions']['tokenization']['num_workers']
-    memory_limit = config['actions']['tokenization']['worker_memory_limit']
-    base_model = config['global']['base_model']
-
-    prepare_folder(OUTPUT_FOLDER)
-
-    client = Client(n_workers=num_workers, memory_limit=memory_limit)
-    print(client.dashboard_link)
-
-    tokenizer = BertTokenizerFast.from_pretrained(base_model)
-
-    tokenizer = dask.delayed(tokenizer)
-
-    df = dd.read_parquet(INPUT_FOLDER, engine="pyarrow")
-    df = df.apply(apply_tokenization, args=(min_tokens, max_tokens, tokenizer),
-                  result_type='expand', axis=1, meta=APPLY_TOKENIZATION_META)
-
-    df.to_parquet(OUTPUT_FOLDER, engine="pyarrow")
diff --git a/scripts/actions_based/stage3_exploding.py b/scripts/actions_based/stage3_exploding.py
deleted file mode 100644
index 2978324..0000000
--- a/scripts/actions_based/stage3_exploding.py
+++ /dev/null
@@ -1,31 +0,0 @@
-# /usr/bin/python3
-from src.processing import batchify_data
-from dask.diagnostics import ProgressBar
-import dask.dataframe as dd
-import numpy as np
-import dask
-from dask.distributed import Client
-import pandas as pd
-from src.utils import PROJECT_ROOT, get_config, prepare_folder
-from scripts.actions_based.processing import expand_dims, EXPAND_DIMS_META, flatten_dims, FLATTEN_DIMS_META
-
-INPUT_FOLDER = f"{PROJECT_ROOT}/generated/actions/stage2_tokenization"
-OUTPUT_FOLDER = f"{PROJECT_ROOT}/generated/actions/stage3_exploding"
-
-if __name__ == "__main__":
-    config = get_config()
-    num_workers = config['actions']['exploding']['num_workers']
-    memory_limit = config['actions']['exploding']['worker_memory_limit']
-
-    prepare_folder(OUTPUT_FOLDER)
-
-    client = Client(n_workers=num_workers, memory_limit=memory_limit)
-    print(client.dashboard_link)
-
-    df = dd.read_parquet(INPUT_FOLDER, engine='pyarrow')
-
-    df = df.apply(expand_dims, result_type='expand', axis=1, meta=EXPAND_DIMS_META)
-    df = df.map_partitions(lambda x: x.apply(lambda y: y.explode(), axis=0), meta=EXPAND_DIMS_META)
-    df = df.apply(flatten_dims, result_type='expand', axis=1, meta=FLATTEN_DIMS_META)
-    
-    df.to_parquet(OUTPUT_FOLDER, engine='pyarrow')
diff --git a/scripts/actions_based/stage4_reindexing.py b/scripts/actions_based/stage4_reindexing.py
deleted file mode 100644
index c50d62f..0000000
--- a/scripts/actions_based/stage4_reindexing.py
+++ /dev/null
@@ -1,43 +0,0 @@
-# /usr/bin/python3
-from src.processing import batchify_data
-from dask.diagnostics import ProgressBar
-import dask.dataframe as dd
-from transformers import BertTokenizerFast
-import numpy as np
-import dask
-from dask.distributed import Client
-import pandas as pd
-from src.utils import PROJECT_ROOT, get_config, prepare_folder
-
-INPUT_FOLDER = f"{PROJECT_ROOT}/generated/actions/stage3_exploding"
-OUTPUT_FOLDER = f"{PROJECT_ROOT}/generated/actions/stage4_reindexing"
-
-if __name__ == "__main__":
-    config = get_config()
-    num_workers = config['actions']['reindexing']['num_workers']
-    memory_limit = config['actions']['reindexing']['worker_memory_limit']
-
-    prepare_folder(OUTPUT_FOLDER)
-
-    client = Client(n_workers=num_workers, memory_limit=memory_limit)
-    print(client.dashboard_link)
-
-    df = dd.read_parquet(INPUT_FOLDER, engine='pyarrow')
-
-    # Add ordered indexes
-    df = df.assign(ones=1)
-    df = df.reset_index(drop=True)
-    idx = (df.ones.cumsum() - 1).persist()
-    df = df.assign(ones=idx)
-
-    # Shuffle 
-    #shuffled_idx = idx.compute().values
-    #np.random.shuffle(shuffled_idx)
-    #shuffled_idx = client.scatter(shuffled_idx)    
-    #mapped_ones = df.ones.apply(lambda x, idx: idx[x], args=(shuffled_idx,), meta=('ones', 'int64'))
-    #df = df.assign(ones=mapped_ones)
-
-    #df = df.persist()
-
-    df = df.set_index('ones', shuffle='tasks')
-    df.to_parquet(OUTPUT_FOLDER, engine='pyarrow')
diff --git a/scripts/actions_based/stage5_stats.py b/scripts/actions_based/stage5_stats.py
deleted file mode 100644
index 7038008..0000000
--- a/scripts/actions_based/stage5_stats.py
+++ /dev/null
@@ -1,52 +0,0 @@
-# /usr/bin/python3
-from src.processing import batchify_data, ACTIONS_KEYS
-from dask.diagnostics import ProgressBar
-import dask.dataframe as dd
-from transformers import BertTokenizerFast
-import numpy as np
-import dask
-from dask.distributed import Client
-import pandas as pd
-from src.utils import PROJECT_ROOT, get_config, prepare_folder
-import pickle
-from scripts.actions_based.processing import expand_dims, EXPAND_DIMS_META
-
-INPUT_FOLDER = f"{PROJECT_ROOT}/generated/actions/stage4_reindexing"
-OUTPUT_FOLDER = f"{PROJECT_ROOT}/generated/actions/stage5_stats"
-
-def reduce_fold(fold_value, new_value):
-    return {
-        'class_number': fold_value["class_number"] + np.sum(new_value, axis=0),
-        'num_examples': fold_value["num_examples"] + new_value.shape[0]
-    }
-
-def reduce_partitions(x, y):
-    return {
-        'class_number': x["class_number"] + y["class_number"],
-        'num_examples': x["num_examples"] + y["num_examples"]
-    }
-
-if __name__ == "__main__":
-    config = get_config()
-    num_workers = config['actions']['stats']['num_workers']
-    memory_limit = config['actions']['stats']['worker_memory_limit']
-
-    prepare_folder(OUTPUT_FOLDER)
-
-    client = Client(n_workers=num_workers, memory_limit=memory_limit)
-    print(client.dashboard_link)
-
-    df = dd.read_parquet(INPUT_FOLDER, engine='pyarrow')
-    df = df.apply(expand_dims, result_type='expand', axis=1, meta=EXPAND_DIMS_META)
-
-    outputs_bag = df['output'].to_bag()
-
-    inital_values = {
-        "class_number": np.array([0] * len(ACTIONS_KEYS)),
-        'num_examples': 0
-    }
-
-    result = outputs_bag.fold(reduce_fold, reduce_partitions, initial=inital_values).compute()
-
-    with open(f"{OUTPUT_FOLDER}/stats.pickle", 'wb') as f:
-        pickle.dump(result, f)
\ No newline at end of file
diff --git a/scripts/translation_based/stage3_exploding.py b/scripts/translation_based/stage3_exploding.py
deleted file mode 100644
index f357dcd..0000000
--- a/scripts/translation_based/stage3_exploding.py
+++ /dev/null
@@ -1,30 +0,0 @@
-# /usr/bin/python3
-from scripts.translation_based.processing import flatten_dims, expand_dims, FLATTEN_DIMS_META, EXPAND_DIMS_META
-from dask.diagnostics import ProgressBar
-import dask.dataframe as dd
-import numpy as np
-import dask
-from dask.distributed import Client
-import pandas as pd
-from src.utils import PROJECT_ROOT, get_config, prepare_folder
-
-INPUT_FOLDER = f"{PROJECT_ROOT}/generated/translations/stage2_create_batches"
-OUTPUT_FOLDER = f"{PROJECT_ROOT}/generated/translations/stage3_exploding"
-
-if __name__ == "__main__":
-    config = get_config()
-    num_workers = config['translations']['exploding']['num_workers']
-    memory_limit = config['translations']['exploding']['worker_memory_limit']
-
-    prepare_folder(OUTPUT_FOLDER)
-
-    client = Client(n_workers=num_workers, memory_limit=memory_limit)
-    print(client.dashboard_link)
-
-    df = dd.read_parquet(INPUT_FOLDER, engine='pyarrow')
-
-    df = df.apply(expand_dims, result_type='expand', axis=1, meta=EXPAND_DIMS_META)
-    df = df.map_partitions(lambda x: x.apply(lambda y: y.explode(), axis=0), meta=EXPAND_DIMS_META)
-    df = df.apply(flatten_dims, result_type='expand', axis=1, meta=FLATTEN_DIMS_META)
-    
-    df.to_parquet(OUTPUT_FOLDER, engine='pyarrow')
diff --git a/src/batch_loading.py b/src/batch_loading.py
index 414bf93..7821cc4 100644
--- a/src/batch_loading.py
+++ b/src/batch_loading.py
@@ -3,6 +3,7 @@ import pandas as pd
 import dask.dataframe as dd
 from typing import Union
 
+
 def calculate_batch_buffer_id(batch_id: int, buffer_batch_num: int) -> int:
     """Calculate which buffer should be loaded into memory for a given batch
 
@@ -15,7 +16,10 @@ def calculate_batch_buffer_id(batch_id: int, buffer_batch_num: int) -> int:
     """
     return batch_id // buffer_batch_num
 
-def yield_batch_buffer_span(batch_size: int, batch_buffer_len: int, num_samples: int) -> np.array:
+
+def yield_batch_buffer_span(
+    batch_size: int, batch_buffer_len: int, num_samples: int
+) -> np.array:
     """Calculates which samples should be loaded in a given batch buffer
 
     Args:
@@ -24,7 +28,7 @@ def yield_batch_buffer_span(batch_size: int, batch_buffer_len: int, num_samples:
         num_samples (int): Number of samples in a dataset
 
     Returns:
-        np.array: Contiguous ids that should be loaded to memory for a given buffer 
+        np.array: Contiguous ids that should be loaded to memory for a given buffer
     """
     batch_buffer_size = batch_size * batch_buffer_len
 
@@ -37,6 +41,7 @@ def yield_batch_buffer_span(batch_size: int, batch_buffer_len: int, num_samples:
         yield np.arange(buffer_start, buffer_end, 1, np.long)
         batch_buffer_id += 1
 
+
 def get_ordered_dataframe_len(df: Union[pd.DataFrame, dd.DataFrame]) -> int:
     """Gets length of a dataframe, which ids are ORDERED CONTINUOUSLY from 0 to N
     without counting all the elements
@@ -49,7 +54,14 @@ def get_ordered_dataframe_len(df: Union[pd.DataFrame, dd.DataFrame]) -> int:
     """
     return df.tail(1).index.values[0] + 1
 
-def get_batches(df: dd.DataFrame, batch_size: int, batch_buffer_len: int, shuffled_ids: np.array, batch_start: int = 0) -> pd.DataFrame:
+
+def get_batches(
+    df: dd.DataFrame,
+    batch_size: int,
+    batch_buffer_len: int,
+    shuffled_ids: np.array,
+    batch_start: int = 0,
+) -> pd.DataFrame:
     """Generator for getting batches from large Dask dataframe with implemented buffering
 
     Args:
@@ -68,12 +80,18 @@ def get_batches(df: dd.DataFrame, batch_size: int, batch_buffer_len: int, shuffl
 
     batch_id = batch_start
 
-    for batch_buffer_span in yield_batch_buffer_span(batch_size, batch_buffer_len, length):
+    for batch_buffer_span in yield_batch_buffer_span(
+        batch_size, batch_buffer_len, length
+    ):
         buffer_ids = shuffled_ids[batch_buffer_span]
         buffer = df.loc[buffer_ids].compute()
 
         for i in range(batch_buffer_len):
-            batch_ids = buffer_ids[range(i * batch_size, min((i+1) * batch_size, len(buffer_ids)))]
+            batch_ids = buffer_ids[
+                range(
+                    i * batch_size, min((i + 1) * batch_size, len(buffer_ids))
+                )
+            ]
 
             yield buffer.loc[batch_ids]
             batch_id += 1
diff --git a/src/models/TransformerSeq2Seq.py b/src/models/TransformerSeq2Seq.py
index 2466348..5c5cefe 100644
--- a/src/models/TransformerSeq2Seq.py
+++ b/src/models/TransformerSeq2Seq.py
@@ -4,7 +4,7 @@ import math
 
 
 class PositionalEncoding(nn.Module):
-    """Adds sinsusoidal positional encoding (as in original AIAYN paper)
+    """Adds sinsusoidal positional encoding (as in original "Attention is all you need" paper.)
     src: https://pytorch.org/tutorials/beginner/transformer_tutorial.html
 
     """
@@ -22,12 +22,14 @@ class PositionalEncoding(nn.Module):
 
         pe = torch.zeros(max_len, d_model)
         position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
-        div_term = torch.exp(torch.arange(
-            0, d_model, 2).float() * (-math.log(10000.0) / d_model))
+        div_term = torch.exp(
+            torch.arange(0, d_model, 2).float()
+            * (-math.log(10000.0) / d_model)
+        )
         pe[:, 0::2] = torch.sin(position * div_term)
         pe[:, 1::2] = torch.cos(position * div_term)
         pe = pe.unsqueeze(0).transpose(0, 1)
-        self.register_buffer('pe', pe)
+        self.register_buffer("pe", pe)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Applies positional encoding
@@ -38,55 +40,81 @@ class PositionalEncoding(nn.Module):
         Returns:
             torch.Tensor: Word embeddings with added positional encodings
         """
-        x = x + self.pe[:x.size(0), :]
+        x = x + self.pe[: x.size(0), :]
         return self.dropout(x)
 
 
 class TransformerSeq2Seq(nn.Module):
-    """Class representing a sequence to sequence transformer, based on original "Attention is all you need" paper.
-    """
-
-    def __init__(self, vocab_size: int
-        , embedding_size: int
-        , max_len: int
-        , num_heads: int = 8
-        , encoder_layers: int = 6
-        , decoder_layers: int = 6
-        , feedforward_neurons: int = 2048
-        , dropout: float = 0.1):
+    """Class representing a sequence to sequence transformer, based on original "Attention is all you need" paper."""
+
+    def __init__(
+        self,
+        vocab_size: int,
+        embedding_size: int,
+        max_len: int,
+        num_heads: int = 8,
+        encoder_layers: int = 6,
+        decoder_layers: int = 6,
+        feedforward_neurons: int = 2048,
+        dropout: float = 0.1,
+    ):
 
         super(TransformerSeq2Seq, self).__init__()
 
+        # Embed tokens into the embedding vector space
         self.word_embedding = nn.Embedding(vocab_size, embedding_size)
-        self.position_embedding = PositionalEncoding(embedding_size, max_len, dropout)
+
+        # Add positional encoding
+        self.position_embedding = PositionalEncoding(
+            embedding_size, max_len, dropout
+        )
+
+        # Combined encoder-decoder step
         self.core = nn.Transformer(
-            embedding_size, num_heads, encoder_layers, decoder_layers, feedforward_neurons, dropout)
+            embedding_size,
+            num_heads,
+            encoder_layers,
+            decoder_layers,
+            feedforward_neurons,
+            dropout,
+        )
+
+        # Map embedding to word
         self.embedding_to_words = nn.Linear(embedding_size, vocab_size)
 
-    def forward(self, source: torch.Tensor, target: torch.Tensor, source_mask: torch.Tensor) -> torch.Tensor:
+    def forward(
+        self,
+        source: torch.Tensor,
+        target: torch.Tensor,
+        source_mask: torch.Tensor,
+    ) -> torch.Tensor:
         """Full encoder-decoder pass
 
         Args:
-            source (torch.Tensor): Tensor with batch of source sentences tokens (BxL)
-            target (torch.Tensor): Tensor with batch of target sentences tokens (BxL-1)
-            source_mask (torch.Tensor): Mask applied to source (True if element is padding, False otherwise) (BxL)
+            source (torch.Tensor): Tensor with a batch of source sentence tokens [BxL shape]
+            target (torch.Tensor): Tensor with a batch of target sentence tokens [Bx(L-1) shape]
+            source_mask (torch.Tensor): Mask applied to source (True if element is padding, False otherwise) [BxL shape]
 
         Returns:
-            torch.Tensor: Tensor with predicted target sentences tokens (BxL-1xV)
+            torch.Tensor: Tensor with predicted target sentence tokens [Bx(L-1)xV shape]
         """
-
+        # Input to encoder
         x = source.transpose(0, 1)
         x = self.word_embedding(x)
         x = self.position_embedding(x)
 
+        # Input to decoder
         y = target.transpose(0, 1)
         y = self.word_embedding(y)
         y = self.position_embedding(y)
 
-        tgt_mask = self.core.generate_square_subsequent_mask(y.shape[0]).to(y.device)
+        tgt_mask = self.core.generate_square_subsequent_mask(y.shape[0]).to(
+            y.device
+        )
 
-        z = self.core(x, y, src_key_padding_mask=source_mask,
-                      tgt_mask=tgt_mask).transpose(1, 0)
+        z = self.core(
+            x, y, src_key_padding_mask=source_mask, tgt_mask=tgt_mask
+        ).transpose(1, 0)
         z = self.embedding_to_words(z)
 
         return z
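
A toy forward pass through the reformatted model (sizes are illustrative;
the constructor signature is the one defined in this file):

    import torch
    from src.models.TransformerSeq2Seq import TransformerSeq2Seq

    model = TransformerSeq2Seq(vocab_size=100, embedding_size=32, max_len=50, num_heads=4)
    source = torch.randint(0, 100, (2, 10))             # B x L
    target = torch.randint(0, 100, (2, 9))              # B x (L-1)
    source_mask = torch.zeros(2, 10, dtype=torch.bool)  # no padding
    logits = model(source, target, source_mask)         # B x (L-1) x V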
diff --git a/scripts/__init__.py b/src/pipelines/__init__.py
similarity index 100%
rename from scripts/__init__.py
rename to src/pipelines/__init__.py
diff --git a/scripts/actions_based/__init__.py b/src/pipelines/actions_based/__init__.py
similarity index 100%
rename from scripts/actions_based/__init__.py
rename to src/pipelines/actions_based/__init__.py
diff --git a/src/pipelines/actions_based/processing.py b/src/pipelines/actions_based/processing.py
new file mode 100644
index 0000000..7d4a67a
--- /dev/null
+++ b/src/pipelines/actions_based/processing.py
@@ -0,0 +1,90 @@
+from transformers import BertTokenizerFast
+from src.processing import tokenize_labeled_text, batchify_data
+import numpy as np
+
+
+def expand_dims(entry):  # entry is a dataframe row (attribute access)
+    inputs = entry.input.reshape(entry.input_shape)
+    outputs = entry.output.reshape(entry.output_shape)
+    masks = entry.attention_mask.reshape(entry.attention_mask_shape)
+
+    return {
+        "input": inputs,
+        "output": outputs,
+        "attention_mask": masks,
+    }
+
+
+EXPAND_DIMS_META = {
+    "input": object,
+    "output": object,
+    "attention_mask": object,
+}
+
+
+def apply_tokenization(
+    df, min_tokens: int, max_tokens: int, tokenizer: BertTokenizerFast
+):
+    text_clean = df.input
+    labels = df.output
+    shape = df.output_shape
+
+    tokens, token_labels = tokenize_labeled_text(
+        text_clean, labels.reshape(shape), tokenizer
+    )
+
+    inputs, outputs, attentions = batchify_data(
+        tokens, token_labels, max_tokens, tokenizer, min_tokens
+    )
+
+    inputs_shape = np.array(inputs.shape)
+    outputs_shape = np.array(outputs.shape)
+    attentions_shape = np.array(attentions.shape)
+
+    return {
+        "input": inputs.reshape(-1),
+        "output": outputs.reshape(-1),
+        "attention_mask": attentions.reshape(-1),
+        "input_shape": inputs_shape,
+        "output_shape": outputs_shape,
+        "attention_mask_shape": attentions_shape,
+    }
+
+
+APPLY_TOKENIZATION_META = {
+    "input": object,
+    "output": object,
+    "attention_mask": object,
+    "input_shape": object,
+    "output_shape": object,
+    "attention_mask_shape": object,
+}
+
+
+def flatten_dims(entry):
+    inputs_shape = np.array(entry.input.shape)
+    outputs_shape = np.array(entry.output.shape)
+    attentions_shape = np.array(entry.attention_mask.shape)
+
+    inputs = entry.input.reshape(-1)
+    outputs = entry.output.reshape(-1)
+    attentions = entry.attention_mask.reshape(-1)
+
+    return {
+        "input": inputs,
+        "output": outputs,
+        "attention_mask": attentions,
+        "input_shape": inputs_shape,
+        "output_shape": outputs_shape,
+        "attention_mask_shape": attentions_shape,
+    }
+
+
+FLATTEN_DIMS_META = {
+    "input": object,
+    "output": object,
+    "attention_mask": object,
+    "input_shape": object,
+    "output_shape": object,
+    "attention_mask_shape": object,
+}
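The flatten/expand pair exists because parquet cells can only hold flat arrays: flatten_dims stores each tensor as a 1-D array next to its shape, and expand_dims restores the original dimensionality. A minimal round-trip check on an invented row (real rows come from the tokenization stage):

    import numpy as np
    import pandas as pd
    from src.pipelines.actions_based.processing import expand_dims, flatten_dims

    # One row holding a batch of 2 examples, 3 tokens each, 2 labels per token
    row = pd.Series({
        "input": np.arange(6), "input_shape": np.array([2, 3]),
        "output": np.arange(12), "output_shape": np.array([2, 3, 2]),
        "attention_mask": np.ones(6), "attention_mask_shape": np.array([2, 3]),
    })

    expanded = expand_dims(row)                   # arrays back in (2, 3) / (2, 3, 2) shape
    restored = flatten_dims(pd.Series(expanded))  # flat arrays plus shape columns again
    assert np.array_equal(restored["input"], row["input"])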
diff --git a/scripts/actions_based/stage1_extraction.py b/src/pipelines/actions_based/stage1_extraction.py
similarity index 59%
rename from scripts/actions_based/stage1_extraction.py
rename to src/pipelines/actions_based/stage1_extraction.py
index 7ccfa7b..f6735d0 100644
--- a/scripts/actions_based/stage1_extraction.py
+++ b/src/pipelines/actions_based/stage1_extraction.py
@@ -2,9 +2,7 @@
 import glob
 import numpy as np
 from src.processing import text_from_xml, create_model_input_output
-from dask.diagnostics import ProgressBar
 import dask.dataframe as dd
-import dask
 import pandas as pd
 from dask.distributed import Client
 from src.utils import get_config, PROJECT_ROOT, prepare_folder
@@ -12,6 +10,7 @@ from src.utils import get_config, PROJECT_ROOT, prepare_folder
 INPUT_FOLDER = f"{PROJECT_ROOT}/data"
 OUTPUT_FOLDER = f"{PROJECT_ROOT}/generated/actions/stage1_extraction"
 
+
 def process_file(x):
     full_text = text_from_xml(x.file)
 
@@ -20,16 +19,21 @@ def process_file(x):
 
         output_shape = np.array(model_output.shape, dtype=np.int)
 
-        return {'input': model_input, 'output': model_output.reshape(-1), 'output_shape': output_shape}
+        return {
+            "input": model_input,
+            "output": model_output.reshape(-1),
+            "output_shape": output_shape,
+        }
     else:
-        return {'input': None, 'output': None, 'output_shape': None}
+        return {"input": None, "output": None, "output_shape": None}
+
 
 if __name__ == "__main__":
 
     config = get_config()
-    num_partitions = config['actions']['extraction']['num_partitions']
-    num_workers = config['actions']['extraction']['num_workers']
-    memory_limit = config['actions']['extraction']['worker_memory_limit']
+    num_partitions = config["actions"]["extraction"]["num_partitions"]
+    num_workers = config["actions"]["extraction"]["num_workers"]
+    memory_limit = config["actions"]["extraction"]["worker_memory_limit"]
 
     prepare_folder(OUTPUT_FOLDER)
 
@@ -43,11 +47,17 @@ if __name__ == "__main__":
     print(f"Dashboard: {client.dashboard_link}")
 
     # Processing pipeline
-    df = dd.from_pandas(pd.DataFrame({'file': files_paths}), npartitions=num_partitions)
-
-    df = df.apply(process_file, result_type='expand', axis=1, meta={
-                   'input': str, 'output': object, 'output_shape': object})
+    df = dd.from_pandas(
+        pd.DataFrame({"file": files_paths}), npartitions=num_partitions
+    )
+
+    df = df.apply(
+        process_file,
+        result_type="expand",
+        axis=1,
+        meta={"input": str, "output": object, "output_shape": object},
+    )
     df = df.dropna()
 
     # Export
-    df.to_parquet(OUTPUT_FOLDER, engine="pyarrow")
\ No newline at end of file
+    df.to_parquet(OUTPUT_FOLDER, engine="pyarrow")
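process_file returns an all-None row when a document yields no text, so the dropna() above is what filters out failed extractions. A toy illustration with made-up rows:

    import pandas as pd

    df = pd.DataFrame([
        {"input": "ala ma kota", "output": [1, 0], "output_shape": [2]},
        {"input": None, "output": None, "output_shape": None},  # failed extraction
    ])
    assert len(df.dropna()) == 1  # only the successful row survives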
diff --git a/src/pipelines/actions_based/stage2_tokenization.py b/src/pipelines/actions_based/stage2_tokenization.py
new file mode 100644
index 0000000..affa6e9
--- /dev/null
+++ b/src/pipelines/actions_based/stage2_tokenization.py
@@ -0,0 +1,46 @@
+#!/usr/bin/python3
+from src.utils import (
+    PROJECT_ROOT,
+    get_config,
+    prepare_folder,
+)
+import dask
+import dask.dataframe as dd
+from transformers import BertTokenizerFast
+from dask.distributed import Client
+from src.pipelines.actions_based.processing import (
+    apply_tokenization,
+    APPLY_TOKENIZATION_META,
+)
+
+INPUT_FOLDER = f"{PROJECT_ROOT}/generated/actions/stage1_extraction"
+OUTPUT_FOLDER = f"{PROJECT_ROOT}/generated/actions/stage2_tokenization"
+
+if __name__ == "__main__":
+
+    config = get_config()
+    max_tokens = config["actions"]["tokenization"]["max_tokens"]
+    min_tokens = config["actions"]["tokenization"]["min_tokens"]
+    num_workers = config["actions"]["tokenization"]["num_workers"]
+    memory_limit = config["actions"]["tokenization"]["worker_memory_limit"]
+    base_model = config["global"]["base_model"]
+
+    prepare_folder(OUTPUT_FOLDER)
+
+    client = Client(n_workers=num_workers, memory_limit=memory_limit)
+    print(client.dashboard_link)
+
+    tokenizer = BertTokenizerFast.from_pretrained(base_model)
+
+    tokenizer = dask.delayed(tokenizer)
+
+    df = dd.read_parquet(INPUT_FOLDER, engine="pyarrow")
+    df = df.apply(
+        apply_tokenization,
+        args=(min_tokens, max_tokens, tokenizer),
+        result_type="expand",
+        axis=1,
+        meta=APPLY_TOKENIZATION_META,
+    )
+
+    df.to_parquet(OUTPUT_FOLDER, engine="pyarrow")
diff --git a/src/pipelines/actions_based/stage3_exploding.py b/src/pipelines/actions_based/stage3_exploding.py
new file mode 100644
index 0000000..9cff62b
--- /dev/null
+++ b/src/pipelines/actions_based/stage3_exploding.py
@@ -0,0 +1,37 @@
+#!/usr/bin/python3
+import dask.dataframe as dd
+from dask.distributed import Client
+from src.utils import PROJECT_ROOT, get_config, prepare_folder
+from src.pipelines.actions_based.processing import (
+    expand_dims,
+    EXPAND_DIMS_META,
+    flatten_dims,
+    FLATTEN_DIMS_META,
+)
+
+INPUT_FOLDER = f"{PROJECT_ROOT}/generated/actions/stage2_tokenization"
+OUTPUT_FOLDER = f"{PROJECT_ROOT}/generated/actions/stage3_exploding"
+
+if __name__ == "__main__":
+    config = get_config()
+    num_workers = config["actions"]["exploding"]["num_workers"]
+    memory_limit = config["actions"]["exploding"]["worker_memory_limit"]
+
+    prepare_folder(OUTPUT_FOLDER)
+
+    client = Client(n_workers=num_workers, memory_limit=memory_limit)
+    print(client.dashboard_link)
+
+    df = dd.read_parquet(INPUT_FOLDER, engine="pyarrow")
+
+    df = df.apply(
+        expand_dims, result_type="expand", axis=1, meta=EXPAND_DIMS_META
+    )
+    df = df.map_partitions(
+        lambda x: x.apply(lambda y: y.explode(), axis=0), meta=EXPAND_DIMS_META
+    )
+    df = df.apply(
+        flatten_dims, result_type="expand", axis=1, meta=FLATTEN_DIMS_META
+    )
+
+    df.to_parquet(OUTPUT_FOLDER, engine="pyarrow")
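The middle map_partitions step does the real work here: after expand_dims each cell holds a whole batch (first axis = examples), and pandas' column-wise explode turns that into one row per example. A toy pandas sketch of the same mechanics (invented data):

    import numpy as np
    import pandas as pd

    # One row per document; each cell is an object array with 2 examples
    part = pd.DataFrame({
        "input": [np.array([np.array([1, 2]), np.array([3, 4])], dtype=object)],
        "output": [np.array([np.array([0]), np.array([1])], dtype=object)],
    })

    exploded = part.apply(lambda col: col.explode(), axis=0)
    assert len(exploded) == 2  # one row per example, original index repeated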
diff --git a/src/pipelines/actions_based/stage4_reindexing.py b/src/pipelines/actions_based/stage4_reindexing.py
new file mode 100644
index 0000000..3fcaa76
--- /dev/null
+++ b/src/pipelines/actions_based/stage4_reindexing.py
@@ -0,0 +1,37 @@
+#!/usr/bin/python3
+import dask.dataframe as dd
+from dask.distributed import Client
+from src.utils import PROJECT_ROOT, get_config, prepare_folder
+
+INPUT_FOLDER = f"{PROJECT_ROOT}/generated/actions/stage3_exploding"
+OUTPUT_FOLDER = f"{PROJECT_ROOT}/generated/actions/stage4_reindexing"
+
+if __name__ == "__main__":
+    config = get_config()
+    num_workers = config["actions"]["reindexing"]["num_workers"]
+    memory_limit = config["actions"]["reindexing"]["worker_memory_limit"]
+
+    prepare_folder(OUTPUT_FOLDER)
+
+    client = Client(n_workers=num_workers, memory_limit=memory_limit)
+    print(client.dashboard_link)
+
+    df = dd.read_parquet(INPUT_FOLDER, engine="pyarrow")
+
+    # Add ordered indexes
+    df = df.assign(ones=1)
+    df = df.reset_index(drop=True)
+    idx = (df.ones.cumsum() - 1).persist()
+    df = df.assign(ones=idx)
+
+    # Shuffle
+    # shuffled_idx = idx.compute().values
+    # np.random.shuffle(shuffled_idx)
+    # shuffled_idx = client.scatter(shuffled_idx)
+    # mapped_ones = df.ones.apply(lambda x, idx: idx[x], args=(shuffled_idx,), meta=('ones', 'int64'))
+    # df = df.assign(ones=mapped_ones)
+
+    # df = df.persist()
+
+    df = df.set_index("ones", shuffle="tasks")
+    df.to_parquet(OUTPUT_FOLDER, engine="pyarrow")
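The cumsum-of-ones trick builds a dense 0..N-1 index across partitions without counting rows up front, and set_index then orders the frame by it with a task-based shuffle. The same pattern on a tiny frame:

    import dask.dataframe as dd
    import pandas as pd

    df = dd.from_pandas(pd.DataFrame({"value": list("abcdef")}), npartitions=3)
    df = df.assign(ones=1).reset_index(drop=True)
    df = df.assign(ones=df.ones.cumsum() - 1)  # 0..5, consistent across partitions
    df = df.set_index("ones", shuffle="tasks")
    assert df.compute().index.tolist() == [0, 1, 2, 3, 4, 5]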
diff --git a/src/pipelines/actions_based/stage5_stats.py b/src/pipelines/actions_based/stage5_stats.py
new file mode 100644
index 0000000..5a69d01
--- /dev/null
+++ b/src/pipelines/actions_based/stage5_stats.py
@@ -0,0 +1,58 @@
+#!/usr/bin/python3
+from src.processing import ACTIONS_KEYS
+import dask.dataframe as dd
+import numpy as np
+from dask.distributed import Client
+from src.utils import PROJECT_ROOT, get_config, prepare_folder
+import pickle
+from src.pipelines.actions_based.processing import (
+    expand_dims,
+    EXPAND_DIMS_META,
+)
+
+INPUT_FOLDER = f"{PROJECT_ROOT}/generated/actions/stage4_reindexing"
+OUTPUT_FOLDER = f"{PROJECT_ROOT}/generated/actions/stage5_stats"
+
+
+def reduce_fold(fold_value, new_value):
+    return {
+        "class_number": fold_value["class_number"] + np.sum(new_value, axis=0),
+        "num_examples": fold_value["num_examples"] + new_value.shape[0],
+    }
+
+
+def reduce_partitions(x, y):
+    return {
+        "class_number": x["class_number"] + y["class_number"],
+        "num_examples": x["num_examples"] + y["num_examples"],
+    }
+
+
+if __name__ == "__main__":
+    config = get_config()
+    num_workers = config["actions"]["stats"]["num_workers"]
+    memory_limit = config["actions"]["stats"]["worker_memory_limit"]
+
+    prepare_folder(OUTPUT_FOLDER)
+
+    client = Client(n_workers=num_workers, memory_limit=memory_limit)
+    print(client.dashboard_link)
+
+    df = dd.read_parquet(INPUT_FOLDER, engine="pyarrow")
+    df = df.apply(
+        expand_dims, result_type="expand", axis=1, meta=EXPAND_DIMS_META
+    )
+
+    outputs_bag = df["output"].to_bag()
+
+    initial_values = {
+        "class_number": np.array([0] * len(ACTIONS_KEYS)),
+        "num_examples": 0,
+    }
+
+    result = outputs_bag.fold(
+        reduce_fold, reduce_partitions, initial=initial_values
+    ).compute()
+
+    with open(f"{OUTPUT_FOLDER}/stats.pickle", "wb") as f:
+        pickle.dump(result, f)
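The pickled dict feeds the class weighting in the actions train script below: per-action positive counts plus the total example count are enough to derive pos_weight for BCEWithLogitsLoss. With invented numbers:

    import numpy as np
    import torch

    stats = {  # same shape as the pickle written above; values are made up
        "class_number": np.array([120, 340, 55, 20]),  # positives per action
        "num_examples": 1000,                          # total examples folded
    }
    pos_examples = stats["class_number"]
    neg_examples = stats["num_examples"] - pos_examples
    pos_weight = torch.tensor(neg_examples / pos_examples)  # rare actions weigh more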
diff --git a/scripts/actions_based/train.py b/src/pipelines/actions_based/train.py
similarity index 54%
rename from scripts/actions_based/train.py
rename to src/pipelines/actions_based/train.py
index 40b8939..266d422 100755
--- a/scripts/actions_based/train.py
+++ b/src/pipelines/actions_based/train.py
@@ -3,12 +3,15 @@
 from transformers import BertTokenizerFast, BertForTokenClassification
 import torch
 from torch.nn import BCEWithLogitsLoss
-import pandas as pd
 import numpy as np
 import dask.dataframe as dd
-import os
 import glob
-from src.utils import PROJECT_ROOT, get_config, convert_to_timedelta, prepare_folder
+from src.utils import (
+    PROJECT_ROOT,
+    get_config,
+    convert_to_timedelta,
+    prepare_folder,
+)
 from src.processing import ACTIONS_KEYS
 from datetime import datetime
 import pickle
@@ -20,15 +23,15 @@ OUTPUT_PATH = f"{PROJECT_ROOT}/checkpoints/actions"
 
 if __name__ == "__main__":
     config = get_config()
-    learning_rate = config['actions']['training']['learning_rate']
-    num_epochs = config['actions']['training']['num_epochs']
-    batch_size = config['actions']['training']['batch_size']
-    save_step = config['actions']['training']['save_step']
-    loss_averaging_span = config['actions']['training']['loss_averaging_span']
-    fresh_start = config['actions']['training']['fresh_start']
-    device_name = config['actions']['training']['device']
-    max_train_time = config['actions']['training']['max_training_time']
-    base_model = config['global']['base_model']
+    learning_rate = config["actions"]["training"]["learning_rate"]
+    num_epochs = config["actions"]["training"]["num_epochs"]
+    batch_size = config["actions"]["training"]["batch_size"]
+    save_step = config["actions"]["training"]["save_step"]
+    loss_averaging_span = config["actions"]["training"]["loss_averaging_span"]
+    fresh_start = config["actions"]["training"]["fresh_start"]
+    device_name = config["actions"]["training"]["device"]
+    max_train_time = config["actions"]["training"]["max_training_time"]
+    base_model = config["global"]["base_model"]
 
     prepare_folder(OUTPUT_PATH)
 
@@ -39,22 +42,24 @@ if __name__ == "__main__":
     print(f"Training on {device}")
 
     # Load loss weights
-    with open(f"{INPUT_STATS_PATH}/stats.pickle", 'rb') as f:
+    with open(f"{INPUT_STATS_PATH}/stats.pickle", "rb") as f:
         stats = pickle.load(f)
-        pos_examples = stats['class_number']
-        neg_examples = stats['num_examples'] - stats['class_number']
+        pos_examples = stats["class_number"]
+        neg_examples = stats["num_examples"] - stats["class_number"]
         pos_weight = torch.tensor(neg_examples / pos_examples)
 
     df = dd.read_parquet(INPUT_PATH, engine="pyarrow")
     tokenizer = BertTokenizerFast.from_pretrained(base_model)
 
-    model = BertForTokenClassification.from_pretrained(base_model, num_labels=len(ACTIONS_KEYS)).to(device)
+    model = BertForTokenClassification.from_pretrained(
+        base_model, num_labels=len(ACTIONS_KEYS)
+    ).to(device)
     criterion = BCEWithLogitsLoss(pos_weight=pos_weight).to(device)
     optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
 
     epoch_start = 0
     sample_start = 0
-    if fresh_start == False:
+    if fresh_start is False:
         checkpoint_files = glob.glob(f"{OUTPUT_PATH}/*.model")
         furthest_epoch = -1
         furthest_batch_num = -1
@@ -68,8 +73,18 @@ if __name__ == "__main__":
                 furthest_batch_num = max(iteration, furthest_batch_num)
 
         if furthest_epoch > -1 and furthest_batch_num > -1:
-            model.load_state_dict(torch.load(f"{OUTPUT_PATH}/{furthest_epoch}-{furthest_batch_num}.model", map_location=device))
-            #optimizer.load_state_dict(torch.load(f"{OUTPUT_PATH}/{furthest_epoch}-{furthest_batch_num}.optimizer", map_location=device))
+            model.load_state_dict(
+                torch.load(
+                    f"{OUTPUT_PATH}/{furthest_epoch}-{furthest_batch_num}.model",
+                    map_location=device,
+                )
+            )
+            optimizer.load_state_dict(
+                torch.load(
+                    f"{OUTPUT_PATH}/{furthest_epoch}-{furthest_batch_num}.optimizer",
+                    map_location=device,
+                )
+            )
 
             epoch_start, sample_start = furthest_epoch, furthest_batch_num
             print(f"Loaded {furthest_epoch}-{furthest_batch_num}")
@@ -92,14 +107,27 @@ if __name__ == "__main__":
             break
 
         i = sample_start
-        for data_batch in get_batches(df, batch_size, 100, random_index_shuffle, i):
-            inputs = data_batch.apply(lambda x: x['input'].reshape(x['input_shape']), axis=1).values
-            outputs = data_batch.apply(lambda x: x['output'].reshape(x['output_shape']), axis=1).values
-            attentions_mask = data_batch.apply(lambda x: x['attention_mask'].reshape(x['attention_mask_shape']), axis=1).values
+        for data_batch in get_batches(
+            df, batch_size, 100, random_index_shuffle, i
+        ):
+            inputs = data_batch.apply(
+                lambda x: x["input"].reshape(x["input_shape"]), axis=1
+            ).values
+            outputs = data_batch.apply(
+                lambda x: x["output"].reshape(x["output_shape"]), axis=1
+            ).values
+            attentions_mask = data_batch.apply(
+                lambda x: x["attention_mask"].reshape(
+                    x["attention_mask_shape"]
+                ),
+                axis=1,
+            ).values
 
             inputs = torch.tensor(np.stack(inputs).squeeze()).to(device)
             outputs = torch.tensor(np.stack(outputs)).to(device)
-            attentions_mask = torch.tensor(np.stack(attentions_mask)).to(device)
+            attentions_mask = torch.tensor(np.stack(attentions_mask)).to(
+                device
+            )
 
             y_pred = model(input_ids=inputs, attention_mask=attentions_mask)[0]
 
@@ -109,19 +137,31 @@ if __name__ == "__main__":
             if len(losses) > loss_averaging_span:
                 losses = losses[-loss_averaging_span:]
 
-            print(f'epoch: {epoch} | step: {i} | loss: {np.mean(losses)}')
+            print(f"epoch: {epoch} | step: {i} | loss: {np.mean(losses)}")
 
             optimizer.zero_grad()
 
-            if i % save_step == 0 and (i != sample_start or epoch != epoch_start):
+            if i % save_step == 0 and (
+                i != sample_start or epoch != epoch_start
+            ):
                 print(f"Saving: Epoch {epoch}, step {i}")
-                torch.save(model.state_dict(), f"{OUTPUT_PATH}/{epoch}-{i}.model")
-                torch.save(optimizer.state_dict(), f"{OUTPUT_PATH}/{epoch}-{i}.optimizer")
+                torch.save(
+                    model.state_dict(), f"{OUTPUT_PATH}/{epoch}-{i}.model"
+                )
+                torch.save(
+                    optimizer.state_dict(),
+                    f"{OUTPUT_PATH}/{epoch}-{i}.optimizer",
+                )
 
             if datetime.now() > time_max:
                 print(f"Max time reached, saving: Epoch {epoch}, step {i}")
-                torch.save(model.state_dict(), f"{OUTPUT_PATH}/{epoch}-{i}.model")
-                torch.save(optimizer.state_dict(), f"{OUTPUT_PATH}/{epoch}-{i}.optimizer")
+                torch.save(
+                    model.state_dict(), f"{OUTPUT_PATH}/{epoch}-{i}.model"
+                )
+                torch.save(
+                    optimizer.state_dict(),
+                    f"{OUTPUT_PATH}/{epoch}-{i}.optimizer",
+                )
                 training_stopped = True
                 break
 
@@ -133,4 +173,3 @@ if __name__ == "__main__":
     if not training_stopped:
         torch.save(model.state_dict(), f"{OUTPUT_PATH}/final.model")
         torch.save(optimizer.state_dict(), f"{OUTPUT_PATH}/final.optimizer")
-
diff --git a/scripts/translation_based/__init__.py b/src/pipelines/translation_based/__init__.py
similarity index 100%
rename from scripts/translation_based/__init__.py
rename to src/pipelines/translation_based/__init__.py
diff --git a/scripts/translation_based/processing.py b/src/pipelines/translation_based/processing.py
similarity index 71%
rename from scripts/translation_based/processing.py
rename to src/pipelines/translation_based/processing.py
index 941d0ec..3b02a4d 100644
--- a/scripts/translation_based/processing.py
+++ b/src/pipelines/translation_based/processing.py
@@ -1,4 +1,3 @@
-import dask.dataframe as dd
 from src.processing import text_from_xml, remove_punctuation
 from transformers import BertTokenizerFast
 import numpy as np
@@ -17,16 +16,21 @@ def raw_to_dataframe(entry: dict) -> dict:
     full_text = text_from_xml(entry.file)
 
     if len(full_text) > 0:
-        return {'input': full_text}
+        return {"input": full_text}
     else:
-        return {'input': None}
+        return {"input": None}
 
 
-RAW_TO_DATAFRAME_META = {
-    'input': str
-}
+RAW_TO_DATAFRAME_META = {"input": str}
+
 
-def generate_batches(entry: dict, min_len: int, max_len: int, separating_token: int, tokenizer: BertTokenizerFast) -> dict:
+def generate_batches(
+    entry: dict,
+    min_len: int,
+    max_len: int,
+    separating_token: int,
+    tokenizer: BertTokenizerFast,
+) -> dict:
     """Converts raw text entries into list of tokens
 
     Args:
@@ -36,12 +40,14 @@ def generate_batches(entry: dict, min_len: int, max_len: int, separating_token:
     Returns:
         dict: Dask dataset entry with flattened source/target token batches, attention masks, and their shapes
     """
-    tokens = np.array(tokenizer(entry.input)['input_ids'][1:-1])
+    tokens = np.array(tokenizer(entry.input)["input_ids"][1:-1])
 
     tokens_ending = (tokens == separating_token).astype(np.int)
     batch_indices = get_batch_indexes(tokens_ending, min_len, max_len - 2)
-    
-    source_batch, target_batch = crete_input_output_batch(tokens, batch_indices, max_len, tokenizer)
+
+    source_batch, target_batch = crete_input_output_batch(
+        tokens, batch_indices, max_len, tokenizer
+    )
     mask_batch = (source_batch != tokenizer.pad_token_id).astype(np.int)
 
     source_batch_shape = np.array(source_batch.shape)
@@ -53,41 +59,44 @@ def generate_batches(entry: dict, min_len: int, max_len: int, separating_token:
     mask_batch = mask_batch.reshape(-1)
 
     return {
-        'source': source_batch,
-        'target': target_batch,
-        'attention_mask': mask_batch,
-        'source_shape': source_batch_shape,
-        'target_shape': target_batch_shape,
-        'attention_mask_shape': mask_batch_shape
+        "source": source_batch,
+        "target": target_batch,
+        "attention_mask": mask_batch,
+        "source_shape": source_batch_shape,
+        "target_shape": target_batch_shape,
+        "attention_mask_shape": mask_batch_shape,
     }
 
 
 GENERATE_BATCHES_META = {
-    'source': object,
-    'target': object,
-    'attention_mask': object,
-    'source_shape': object,
-    'target_shape': object,
-    'attention_mask_shape': object
+    "source": object,
+    "target": object,
+    "attention_mask": object,
+    "source_shape": object,
+    "target_shape": object,
+    "attention_mask_shape": object,
 }
 
+
 def expand_dims(entry):
     source = entry.source.reshape(entry.source_shape)
     target = entry.target.reshape(entry.target_shape)
     mask = entry.attention_mask.reshape(entry.attention_mask_shape)
 
     return {
-        'source': source,
-        'target': target,
+        "source": source,
+        "target": target,
         "attention_mask": mask,
     }
 
+
 EXPAND_DIMS_META = {
-    'source': object, 
-    'target': object,
-    'attention_mask': object
+    "source": object,
+    "target": object,
+    "attention_mask": object,
 }
 
+
 def flatten_dims(entry):
     source_shape = np.array(entry.source.shape)
     target_shape = np.array(entry.target.shape)
@@ -98,21 +107,22 @@ def flatten_dims(entry):
     mask = entry.attention_mask.reshape(-1)
 
     return {
-        'source': source,
-        'target': target,
-        'attention_mask': mask,
-        'source_shape': source_shape,
-        'target_shape': target_shape,
-        'attention_mask_shape': mask_shape
+        "source": source,
+        "target": target,
+        "attention_mask": mask,
+        "source_shape": source_shape,
+        "target_shape": target_shape,
+        "attention_mask_shape": mask_shape,
     }
 
+
 FLATTEN_DIMS_META = {
-    'source': object,
-    'target': object,
-    'attention_mask': object,
-    'source_shape': object,
-    'target_shape': object,
-    'attention_mask_shape': object
+    "source": object,
+    "target": object,
+    "attention_mask": object,
+    "source_shape": object,
+    "target_shape": object,
+    "attention_mask_shape": object,
 }
 
 
@@ -160,8 +170,10 @@ def find_new_sentence_right(seq: np.array, pos: int) -> int:
     return None
 
 
-def get_batch_indexes(seq: np.array, min_length: int, max_length: int) -> [np.array]:
-    """Turns long sequence into array of indices, composing a single batch file. 
+def get_batch_indexes(
+    seq: np.array, min_length: int, max_length: int
+) -> [np.array]:
+    """Turns long sequence into array of indices, composing a single batch file.
 
     Args:
         seq (np.array): Input sequence of 1s and 0s, where 1 means end of sequence token (dot, semicolon etc.)
@@ -200,7 +212,9 @@ def get_batch_indexes(seq: np.array, min_length: int, max_length: int) -> [np.ar
     return batch
 
 
-def add_padding(seq: np.ndarray, total_length: int, padding_symbol: any) -> np.ndarray:
+def add_padding(
+    seq: np.ndarray, total_length: int, padding_symbol: any
+) -> np.ndarray:
     """Pads a sequence with provided symbol, to get array of length total_length in the end
 
     Args:
@@ -215,12 +229,16 @@ def add_padding(seq: np.ndarray, total_length: int, padding_symbol: any) -> np.n
     assert num_padding >= 0
 
     if num_padding > 0:
-        return np.concatenate([seq, np.array([padding_symbol] * num_padding)], axis=0)
+        return np.concatenate(
+            [seq, np.array([padding_symbol] * num_padding)], axis=0
+        )
     else:
         return np.copy(seq)
 
 
-def add_begin_end_tokens(seq: np.ndarray, begin_token: any, end_token: any) -> np.ndarray:
+def add_begin_end_tokens(
+    seq: np.ndarray, begin_token: any, end_token: any
+) -> np.ndarray:
     """Adds preceding and ending special tokens to the sequence
 
     Args:
@@ -232,16 +250,16 @@ def add_begin_end_tokens(seq: np.ndarray, begin_token: any, end_token: any) -> n
         np.ndarray: Sequence of len L+2
     """
 
-    return np.concatenate(
-        [
-            [begin_token],
-            seq,
-            [end_token]
-        ]
-    )
+    return np.concatenate([[begin_token], seq, [end_token]])
 
 
-def standarize_translation_sample(seq: np.ndarray, total_length: int, padding_symbol: any, begin_token: any, end_token: any) -> np.ndarray:
+def standarize_translation_sample(
+    seq: np.ndarray,
+    total_length: int,
+    padding_symbol: any,
+    begin_token: any,
+    end_token: any,
+) -> np.ndarray:
     """Adds special tokens and padding so that every sample has identical shape
 
     Args:
@@ -254,10 +272,16 @@ def standarize_translation_sample(seq: np.ndarray, total_length: int, padding_sy
     Returns:
         np.ndarray: Output sequence of length total_length
     """
-    return add_padding(add_begin_end_tokens(seq, begin_token, end_token), total_length, padding_symbol)
+    return add_padding(
+        add_begin_end_tokens(seq, begin_token, end_token),
+        total_length,
+        padding_symbol,
+    )
 
 
-def create_input_output(tokens: np.ndarray, length: int, tokenizer: BertTokenizerFast) -> (np.ndarray, np.ndarray):
+def create_input_output(
+    tokens: np.ndarray, length: int, tokenizer: BertTokenizerFast
+) -> (np.ndarray, np.ndarray):
     """Transforms a sequence of tokens into "translation" input and output
 
     Args:
@@ -271,22 +295,38 @@ def create_input_output(tokens: np.ndarray, length: int, tokenizer: BertTokenize
     """
     decoded_str = tokenizer.decode(tokens)
     cleaned_str = remove_punctuation(decoded_str).lower()
-    source_batch_entry = tokenizer(cleaned_str)['input_ids'][1:-1]
+    source_batch_entry = tokenizer(cleaned_str)["input_ids"][1:-1]
     target_batch_entry = tokens
 
     # In rare cases (because of encoding) unpunctuated lowercase input might be longer than the output and exceed limits
-    # We need to trim in such cases 
+    # We need to trim in such cases
     if len(source_batch_entry) > length - 2:
-        source_batch_entry = source_batch_entry[:(length-2)]
+        source_batch_entry = source_batch_entry[: (length - 2)]
 
     source_batch_entry = standarize_translation_sample(
-        source_batch_entry, length, tokenizer.pad_token_id, tokenizer.cls_token_id, tokenizer.sep_token_id)
+        source_batch_entry,
+        length,
+        tokenizer.pad_token_id,
+        tokenizer.cls_token_id,
+        tokenizer.sep_token_id,
+    )
     target_batch_entry = standarize_translation_sample(
-        target_batch_entry, length, tokenizer.pad_token_id, tokenizer.cls_token_id, tokenizer.sep_token_id)
+        target_batch_entry,
+        length,
+        tokenizer.pad_token_id,
+        tokenizer.cls_token_id,
+        tokenizer.sep_token_id,
+    )
 
     return source_batch_entry, target_batch_entry
 
-def crete_input_output_batch(seq: np.ndarray, batch_indexes: [np.ndarray], length: int, tokenizer: BertTokenizerFast) -> (np.ndarray, np.ndarray):
+
+def crete_input_output_batch(
+    seq: np.ndarray,
+    batch_indexes: [np.ndarray],
+    length: int,
+    tokenizer: BertTokenizerFast,
+) -> (np.ndarray, np.ndarray):
     """Transforms a sequence of tokens into "translation" input and output batch
 
     Args:
@@ -304,9 +344,11 @@ def crete_input_output_batch(seq: np.ndarray, batch_indexes: [np.ndarray], lengt
     source_batch = []
     target_batch = []
     for entry in base_batch:
-        source_entry, target_entry = create_input_output(entry, length, tokenizer)
+        source_entry, target_entry = create_input_output(
+            entry, length, tokenizer
+        )
 
         source_batch.append(source_entry)
         target_batch.append(target_entry)
 
-    return np.array(source_batch), np.array(target_batch)
\ No newline at end of file
+    return np.array(source_batch), np.array(target_batch)
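standarize_translation_sample is the piece that gives every sample its fixed shape: wrap the sequence with begin/end tokens, then pad. A toy call, assuming the usual BERT ids (0 = pad, 101 = cls, 102 = sep):

    import numpy as np
    from src.pipelines.translation_based.processing import (
        standarize_translation_sample,
    )

    sample = standarize_translation_sample(
        np.array([7, 8, 9]), total_length=8,
        padding_symbol=0, begin_token=101, end_token=102,
    )
    print(sample)  # [101   7   8   9 102   0   0   0]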
diff --git a/scripts/translation_based/stage1_extraction.py b/src/pipelines/translation_based/stage1_extraction.py
similarity index 52%
rename from scripts/translation_based/stage1_extraction.py
rename to src/pipelines/translation_based/stage1_extraction.py
index 159c7c8..5f2758f 100644
--- a/scripts/translation_based/stage1_extraction.py
+++ b/src/pipelines/translation_based/stage1_extraction.py
@@ -1,5 +1,8 @@
 # /usr/bin/python3
-from scripts.translation_based.processing import raw_to_dataframe, RAW_TO_DATAFRAME_META
+from src.pipelines.translation_based.processing import (
+    raw_to_dataframe,
+    RAW_TO_DATAFRAME_META,
+)
 from src.utils import PROJECT_ROOT, prepare_folder, get_config
 from glob import glob
 import numpy as np
@@ -13,14 +16,14 @@ OUTPUT_FOLDER = f"{PROJECT_ROOT}/generated/translations/stage1_extraction"
 if __name__ == "__main__":
 
     config = get_config()
-    num_partitions = config['translations']['extraction']['num_partitions']
-    num_workers = config['translations']['extraction']['num_workers']
-    memory_limit = config['translations']['extraction']['worker_memory_limit']
+    num_partitions = config["translations"]["extraction"]["num_partitions"]
+    num_workers = config["translations"]["extraction"]["num_workers"]
+    memory_limit = config["translations"]["extraction"]["worker_memory_limit"]
 
     prepare_folder(OUTPUT_FOLDER)
 
     file_schema = f"{INPUT_FOLDER}/**/text_structure.xml"
-    files_paths = glob(file_schema, recursive=True) 
+    files_paths = glob(file_schema, recursive=True)
 
     # Make sure python memory fragmentation won't go insane
     np.random.shuffle(files_paths)
@@ -29,10 +32,17 @@ if __name__ == "__main__":
     print(f"Dashboard: {client.dashboard_link}")
 
     # Processing pipeline
-    df = dd.from_pandas(pd.DataFrame({'file': files_paths}), npartitions=num_partitions)
-
-    df = df.apply(raw_to_dataframe, result_type='expand', axis=1, meta=RAW_TO_DATAFRAME_META)
+    df = dd.from_pandas(
+        pd.DataFrame({"file": files_paths}), npartitions=num_partitions
+    )
+
+    df = df.apply(
+        raw_to_dataframe,
+        result_type="expand",
+        axis=1,
+        meta=RAW_TO_DATAFRAME_META,
+    )
     df = df.dropna()
 
     # Export
-    df.to_parquet(OUTPUT_FOLDER, engine="pyarrow")
\ No newline at end of file
+    df.to_parquet(OUTPUT_FOLDER, engine="pyarrow")
diff --git a/scripts/translation_based/stage2_create_batches.py b/src/pipelines/translation_based/stage2_create_batches.py
similarity index 53%
rename from scripts/translation_based/stage2_create_batches.py
rename to src/pipelines/translation_based/stage2_create_batches.py
index c7719f3..9b65e08 100644
--- a/scripts/translation_based/stage2_create_batches.py
+++ b/src/pipelines/translation_based/stage2_create_batches.py
@@ -1,7 +1,9 @@
 # /usr/bin/python3
-from scripts.translation_based.processing import generate_batches, GENERATE_BATCHES_META
+from src.pipelines.translation_based.processing import (
+    generate_batches,
+    GENERATE_BATCHES_META,
+)
 from src.utils import PROJECT_ROOT, prepare_folder, get_config
-import numpy as np
 from dask.distributed import Client
 from transformers import BertTokenizerFast
 import dask.dataframe as dd
@@ -13,11 +15,13 @@ OUTPUT_FOLDER = f"{PROJECT_ROOT}/generated/translations/stage2_create_batches"
 if __name__ == "__main__":
 
     config = get_config()
-    num_workers = config['translations']['create_batches']['num_workers']
-    memory_limit = config['translations']['create_batches']['worker_memory_limit']
-    min_tokens = config['translations']['create_batches']['min_tokens']
-    max_tokens = config['translations']['create_batches']['max_tokens']
-    base_model = config['global']['base_model']
+    num_workers = config["translations"]["create_batches"]["num_workers"]
+    memory_limit = config["translations"]["create_batches"][
+        "worker_memory_limit"
+    ]
+    min_tokens = config["translations"]["create_batches"]["min_tokens"]
+    max_tokens = config["translations"]["create_batches"]["max_tokens"]
+    base_model = config["global"]["base_model"]
 
     prepare_folder(OUTPUT_FOLDER)
 
@@ -30,8 +34,14 @@ if __name__ == "__main__":
     token_separating = tokenizer(".")["input_ids"][1]
 
     df = dd.read_parquet(INPUT_FOLDER, engine="pyarrow")
-    df = df.apply(generate_batches, result_type='expand', axis=1, meta=GENERATE_BATCHES_META, args=(min_tokens, max_tokens, token_separating, tokenizer))
+    df = df.apply(
+        generate_batches,
+        result_type="expand",
+        axis=1,
+        meta=GENERATE_BATCHES_META,
+        args=(min_tokens, max_tokens, token_separating, tokenizer),
+    )
     df = df.dropna()
 
     # Export
-    df.to_parquet(OUTPUT_FOLDER, engine="pyarrow")
\ No newline at end of file
+    df.to_parquet(OUTPUT_FOLDER, engine="pyarrow")
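The separating-token lookup above works because tokenizing "." alone yields [CLS, ., SEP], so index 1 picks the id of the period. A sketch (the model name is a placeholder; the script takes it from config's base_model):

    from transformers import BertTokenizerFast

    tokenizer = BertTokenizerFast.from_pretrained("bert-base-cased")  # placeholder
    ids = tokenizer(".")["input_ids"]  # [cls_id, period_id, sep_id]
    token_separating = ids[1]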
diff --git a/src/pipelines/translation_based/stage3_exploding.py b/src/pipelines/translation_based/stage3_exploding.py
new file mode 100644
index 0000000..4123abd
--- /dev/null
+++ b/src/pipelines/translation_based/stage3_exploding.py
@@ -0,0 +1,37 @@
+#!/usr/bin/python3
+from src.pipelines.translation_based.processing import (
+    flatten_dims,
+    expand_dims,
+    FLATTEN_DIMS_META,
+    EXPAND_DIMS_META,
+)
+import dask.dataframe as dd
+from dask.distributed import Client
+from src.utils import PROJECT_ROOT, get_config, prepare_folder
+
+INPUT_FOLDER = f"{PROJECT_ROOT}/generated/translations/stage2_create_batches"
+OUTPUT_FOLDER = f"{PROJECT_ROOT}/generated/translations/stage3_exploding"
+
+if __name__ == "__main__":
+    config = get_config()
+    num_workers = config["translations"]["exploding"]["num_workers"]
+    memory_limit = config["translations"]["exploding"]["worker_memory_limit"]
+
+    prepare_folder(OUTPUT_FOLDER)
+
+    client = Client(n_workers=num_workers, memory_limit=memory_limit)
+    print(client.dashboard_link)
+
+    df = dd.read_parquet(INPUT_FOLDER, engine="pyarrow")
+
+    df = df.apply(
+        expand_dims, result_type="expand", axis=1, meta=EXPAND_DIMS_META
+    )
+    df = df.map_partitions(
+        lambda x: x.apply(lambda y: y.explode(), axis=0), meta=EXPAND_DIMS_META
+    )
+    df = df.apply(
+        flatten_dims, result_type="expand", axis=1, meta=FLATTEN_DIMS_META
+    )
+
+    df.to_parquet(OUTPUT_FOLDER, engine="pyarrow")
diff --git a/scripts/translation_based/stage4_reindexing.py b/src/pipelines/translation_based/stage4_reindexing.py
similarity index 56%
rename from scripts/translation_based/stage4_reindexing.py
rename to src/pipelines/translation_based/stage4_reindexing.py
index 1974753..ffad285 100644
--- a/scripts/translation_based/stage4_reindexing.py
+++ b/src/pipelines/translation_based/stage4_reindexing.py
@@ -1,12 +1,6 @@
 # /usr/bin/python3
-from src.processing import batchify_data
-from dask.diagnostics import ProgressBar
 import dask.dataframe as dd
-from transformers import BertTokenizerFast
-import numpy as np
-import dask
 from dask.distributed import Client
-import pandas as pd
 from src.utils import PROJECT_ROOT, get_config, prepare_folder
 
 INPUT_FOLDER = f"{PROJECT_ROOT}/generated/translations/stage3_exploding"
@@ -14,15 +8,15 @@ OUTPUT_FOLDER = f"{PROJECT_ROOT}/generated/translations/stage4_reindexing"
 
 if __name__ == "__main__":
     config = get_config()
-    num_workers = config['translations']['reindexing']['num_workers']
-    memory_limit = config['translations']['reindexing']['worker_memory_limit']
+    num_workers = config["translations"]["reindexing"]["num_workers"]
+    memory_limit = config["translations"]["reindexing"]["worker_memory_limit"]
 
     prepare_folder(OUTPUT_FOLDER)
 
     client = Client(n_workers=num_workers, memory_limit=memory_limit)
     print(client.dashboard_link)
 
-    df = dd.read_parquet(INPUT_FOLDER, engine='pyarrow')
+    df = dd.read_parquet(INPUT_FOLDER, engine="pyarrow")
 
     # Add ordered indexes
     df = df.assign(ones=1)
@@ -30,11 +24,13 @@ if __name__ == "__main__":
     idx = (df.ones.cumsum() - 1).persist()
     df = df.assign(ones=idx)
 
-    # Shuffle 
+    # Shuffle
     shuffled_idx = idx.compute().values
     shuffled_idx = client.scatter(shuffled_idx)
-    mapped_ones = df.ones.apply(lambda x, idx: idx[x], args=(shuffled_idx,), meta=('ones', 'int64')).persist()
+    mapped_ones = df.ones.apply(
+        lambda x, idx: idx[x], args=(shuffled_idx,), meta=("ones", "int64")
+    ).persist()
     df = df.assign(ones=mapped_ones)
 
-    df = df.set_index('ones')
-    df.to_parquet(OUTPUT_FOLDER, engine='pyarrow')
+    df = df.set_index("ones")
+    df.to_parquet(OUTPUT_FOLDER, engine="pyarrow")
diff --git a/scripts/translation_based/train.py b/src/pipelines/translation_based/train.py
similarity index 50%
rename from scripts/translation_based/train.py
rename to src/pipelines/translation_based/train.py
index 1a17599..fa7f838 100755
--- a/scripts/translation_based/train.py
+++ b/src/pipelines/translation_based/train.py
@@ -1,15 +1,16 @@
 #!/usr/bin/python3
 
-from transformers import BertTokenizerFast, BertForMaskedLM
 import torch
-from torch.nn import BCEWithLogitsLoss
-import pandas as pd
 import numpy as np
 import dask.dataframe as dd
-import os
 import glob
-from src.utils import PROJECT_ROOT, get_config, convert_to_timedelta, prepare_folder
-from src.processing import ACTIONS_KEYS
+from transformers import BertTokenizerFast
+from src.utils import (
+    PROJECT_ROOT,
+    get_config,
+    convert_to_timedelta,
+    prepare_folder,
+)
 from datetime import datetime
 from src.models.TransformerSeq2Seq import TransformerSeq2Seq
 
@@ -18,16 +19,18 @@ OUTPUT_PATH = f"{PROJECT_ROOT}/checkpoints/translations"
 
 if __name__ == "__main__":
     config = get_config()
-    learning_rate = config['translations']['training']['learning_rate']
-    max_len = config['translations']['create_batches']['max_tokens']
-    num_epochs = config['translations']['training']['num_epochs']
-    batch_size = config['translations']['training']['batch_size']
-    save_step = config['translations']['training']['save_step']
-    loss_averaging_span = config['translations']['training']['loss_averaging_span']
-    fresh_start = config['translations']['training']['fresh_start']
-    device_name = config['translations']['training']['device']
-    max_train_time = config['translations']['training']['max_training_time']
-    base_model = config['global']['base_model']
+    learning_rate = config["translations"]["training"]["learning_rate"]
+    max_len = config["translations"]["create_batches"]["max_tokens"]
+    num_epochs = config["translations"]["training"]["num_epochs"]
+    batch_size = config["translations"]["training"]["batch_size"]
+    save_step = config["translations"]["training"]["save_step"]
+    loss_averaging_span = config["translations"]["training"][
+        "loss_averaging_span"
+    ]
+    fresh_start = config["translations"]["training"]["fresh_start"]
+    device_name = config["translations"]["training"]["device"]
+    max_train_time = config["translations"]["training"]["max_training_time"]
+    base_model = config["global"]["base_model"]
 
     prepare_folder(OUTPUT_PATH)
 
@@ -38,16 +41,18 @@ if __name__ == "__main__":
     print(f"Training on {device}")
 
     df = dd.read_parquet(INPUT_PATH, engine="pyarrow")
-    
+
     tokenizer = BertTokenizerFast.from_pretrained(base_model)
 
-    model = TransformerSeq2Seq(tokenizer.vocab_size, 256, max_len, 4, 4, 4, ).to(device)
+    model = TransformerSeq2Seq(
+        tokenizer.vocab_size, 256, max_len, 4, 4, 4,
+    ).to(device)
     criterion = torch.nn.CrossEntropyLoss(reduction="mean").to(device)
     optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
 
     epoch_start = 0
     sample_start = 0
-    if fresh_start == False:
+    if fresh_start is False:
         checkpoint_files = glob.glob(f"{OUTPUT_PATH}/*.model")
         furthest_epoch = -1
         furthest_batch_num = -1
@@ -61,8 +66,16 @@ if __name__ == "__main__":
                 furthest_batch_num = max(iteration, furthest_batch_num)
 
         if furthest_epoch > -1 and furthest_batch_num > -1:
-            model.load_state_dict(torch.load(f"{OUTPUT_PATH}/{furthest_epoch}-{furthest_batch_num}.model"))
-            optimizer.load_state_dict(torch.load(f"{OUTPUT_PATH}/{furthest_epoch}-{furthest_batch_num}.optimizer"))
+            model.load_state_dict(
+                torch.load(
+                    f"{OUTPUT_PATH}/{furthest_epoch}-{furthest_batch_num}.model"
+                )
+            )
+            optimizer.load_state_dict(
+                torch.load(
+                    f"{OUTPUT_PATH}/{furthest_epoch}-{furthest_batch_num}.optimizer"
+                )
+            )
 
             epoch_start, sample_start = furthest_epoch, furthest_batch_num
             print(f"Loaded {furthest_epoch}-{furthest_batch_num}")
@@ -81,25 +94,42 @@ if __name__ == "__main__":
             break
 
         i = sample_start
-        
+
         while True:
-            data_batch_indexes = list(range(i*batch_size, i*batch_size + batch_size))
-            
+            data_batch_indexes = list(
+                range(i * batch_size, i * batch_size + batch_size)
+            )
+
             # Precomputing the total number of samples takes very long, so let's
             # try to get the next batch until it fails :)
             try:
                 data_batch = df.loc[data_batch_indexes].compute()
-            except:
+            except Exception:
                 # TODO: Specify exception type
                 break
 
-            inputs = data_batch.apply(lambda x: x['source'].reshape(x['source_shape']), axis=1).values
-            outputs = data_batch.apply(lambda x: x['target'].reshape(x['target_shape']), axis=1).values
-            attentions_mask = data_batch.apply(lambda x: x['attention_mask'].reshape(x['attention_mask_shape']), axis=1).values
-
-            inputs = torch.tensor(np.stack(inputs, axis=0), dtype=torch.long).to(device)
-            attentions_mask = torch.tensor(np.stack(attentions_mask, axis=0) == 0).to(device)
-            output_indices = torch.tensor(np.stack(outputs, axis=0), dtype=torch.long).to(device)
+            inputs = data_batch.apply(
+                lambda x: x["source"].reshape(x["source_shape"]), axis=1
+            ).values
+            outputs = data_batch.apply(
+                lambda x: x["target"].reshape(x["target_shape"]), axis=1
+            ).values
+            attentions_mask = data_batch.apply(
+                lambda x: x["attention_mask"].reshape(
+                    x["attention_mask_shape"]
+                ),
+                axis=1,
+            ).values
+
+            inputs = torch.tensor(
+                np.stack(inputs, axis=0), dtype=torch.long
+            ).to(device)
+            attentions_mask = torch.tensor(
+                np.stack(attentions_mask, axis=0) == 0
+            ).to(device)
+            output_indices = torch.tensor(
+                np.stack(outputs, axis=0), dtype=torch.long
+            ).to(device)
 
             y_pred = model(inputs, output_indices[:, :-1], attentions_mask)
             y_pred = y_pred.transpose(1, 2)
@@ -110,19 +140,31 @@ if __name__ == "__main__":
             if len(losses) > loss_averaging_span:
                 losses = losses[-loss_averaging_span:]
 
-            print(f'epoch: {epoch} | step: {i} | loss: {np.mean(losses)}')
+            print(f"epoch: {epoch} | step: {i} | loss: {np.mean(losses)}")
 
             optimizer.zero_grad()
 
-            if i % save_step == 0 and (i != sample_start or epoch != epoch_start):
+            if i % save_step == 0 and (
+                i != sample_start or epoch != epoch_start
+            ):
                 print(f"Saving: Epoch {epoch}, step {i}")
-                torch.save(model.state_dict(), f"{OUTPUT_PATH}/{epoch}-{i}.model")
-                torch.save(optimizer.state_dict(), f"{OUTPUT_PATH}/{epoch}-{i}.optimizer")
+                torch.save(
+                    model.state_dict(), f"{OUTPUT_PATH}/{epoch}-{i}.model"
+                )
+                torch.save(
+                    optimizer.state_dict(),
+                    f"{OUTPUT_PATH}/{epoch}-{i}.optimizer",
+                )
 
             if datetime.now() > time_max:
                 print(f"Max time reached, saving: Epoch {epoch}, step {i}")
-                torch.save(model.state_dict(), f"{OUTPUT_PATH}/{epoch}-{i}.model")
-                torch.save(optimizer.state_dict(), f"{OUTPUT_PATH}/{epoch}-{i}.optimizer")
+                torch.save(
+                    model.state_dict(), f"{OUTPUT_PATH}/{epoch}-{i}.model"
+                )
+                torch.save(
+                    optimizer.state_dict(),
+                    f"{OUTPUT_PATH}/{epoch}-{i}.optimizer",
+                )
                 training_stopped = True
                 break
 
@@ -134,4 +176,3 @@ if __name__ == "__main__":
     if not training_stopped:
         torch.save(model.state_dict(), f"{OUTPUT_PATH}/final.model")
         torch.save(optimizer.state_dict(), f"{OUTPUT_PATH}/final.optimizer")
-
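The output_indices[:, :-1] slice fed to the model above is standard teacher forcing: the decoder sees tokens 0..L-2 and each position is scored against the following token (the loss target itself sits in context lines not shown in this hunk). In miniature:

    import torch

    target = torch.tensor([[101, 5, 6, 7, 102]])  # toy padded target sequence
    decoder_input = target[:, :-1]  # what the model receives
    expected_next = target[:, 1:]   # what each position should predict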
diff --git a/src/processing.py b/src/processing.py
index eea9f5e..7b14493 100644
--- a/src/processing.py
+++ b/src/processing.py
@@ -1,4 +1,3 @@
-import glob
 from xml.etree import ElementTree as ET
 from typing import Optional, Mapping
 from src.utils import remove_punctuation
@@ -6,11 +5,8 @@ import numpy as np
 from transformers import PreTrainedTokenizerFast
 from collections import defaultdict
 
-ACTIONS_KEYS = [
-    'dot', 
-    'upper_case', 
-    'colon',  
-    'question_mark']
+ACTIONS_KEYS = ["dot", "upper_case", "colon", "question_mark"]
+
 
 def empty_action_vector() -> np.ndarray:
     """Returns a do-nothing actions vector
@@ -20,6 +16,7 @@ def empty_action_vector() -> np.ndarray:
     """
     return np.zeros(len(ACTIONS_KEYS))
 
+
 def text_from_xml(path: str) -> str:
     """Extract spoken text from dataset's xml format
 
@@ -33,7 +30,7 @@ def text_from_xml(path: str) -> str:
 
     full_text = ""
 
-    for node in root.iter('*'):
+    for node in root.iter("*"):
         if len(node) == 0:
             who = node.get("who")
             text = node.text
@@ -47,36 +44,41 @@ def text_from_xml(path: str) -> str:
 
 
 def detect_actions(word: str, next_word: Optional[str]) -> Mapping[str, bool]:
-    """Detect what actions should model perform on a word and returns encoded action vector
+    """Detect what actions should model perform on a word and returns encoded
+       action vector
 
     Args:
         word (str): Word on wich action is decided
-        next_word (Optional[str]): Word that follows considered word. Can be None if nothing follows a word
+        next_word (Optional[str]): Word that follows considered word. Can be
+            None if nothing follows a word
 
     Returns:
-        Mapping[str, bool]: Mapping telling if each of possible actions should be performed (True) or not (False) 
+        Mapping[str, bool]: Mapping telling if each of possible actions should be performed (True) or not (False)
     """
     # Unsupported characters (str.replace returns a new string, so assign it)
-    word.replace(";", ".")
-    word.replace('"', "")
-    word.replace('(', "")
-    word.replace(')', "")
+    word = word.replace(";", ".")
+    word = word.replace('"', "")
+    word = word.replace("(", "")
+    word = word.replace(")", "")
 
-    while len(word) > 0 and not word[0].isalnum(): # remove proceding characters
+    while (
+        len(word) > 0 and not word[0].isalnum()
+    ):  # remove preceding characters
         word = word[1:]
 
     if len(word) == 0:
         return dict(zip(ACTIONS_KEYS, [False] * len(ACTIONS_KEYS)))
 
     actions = {
-        'dot': word[-1] == '.',
-        'upper_case': word[0].isupper(),
-        'colon': word[-1] == ",",
-        'question_mark': word[-1] == "?"
+        "dot": word[-1] == ".",
+        "upper_case": word[0].isupper(),
+        "colon": word[-1] == ",",
+        "question_mark": word[-1] == "?",
     }
 
     return actions
 
+
 def encode_actions(actions: Mapping[str, bool]) -> np.ndarray:
     """Transforms actions into vector
 
@@ -88,9 +90,10 @@ def encode_actions(actions: Mapping[str, bool]) -> np.ndarray:
     """
     return np.array(list(actions.values())).astype(float)
 
+
 def decode_actions(encoded_actions: np.ndarray) -> Mapping[str, bool]:
     """Decodes actions
-    
+
     Args:
         encoded_actions (np.ndarray): 1 dimensional action vector
 
@@ -101,6 +104,7 @@ def decode_actions(encoded_actions: np.ndarray) -> Mapping[str, bool]:
 
     return dict(zip(ACTIONS_KEYS, encoded_actions.astype(np.bool).tolist()))
 
+
 def create_model_input_output(text: str) -> (str, np.ndarray):
     """Returns a pair of input and desired output of the model
 
@@ -119,7 +123,7 @@ def create_model_input_output(text: str) -> (str, np.ndarray):
     i = 0
     while i < len(words):
         word = words[i]
-        next_word = words[i+1] if len(words) > i+1 else None
+        next_word = words[i + 1] if len(words) > i + 1 else None
 
         word_sanitized = remove_punctuation(word).lower()
         if len(word_sanitized) > 0:
@@ -135,7 +139,10 @@ def create_model_input_output(text: str) -> (str, np.ndarray):
 
     return " ".join(words_output), np.array(actions_output)
 
-def token_word_mapping(text: str, tokenizer: PreTrainedTokenizerFast) -> np.ndarray:
+
+def token_word_mapping(
+    text: str, tokenizer: PreTrainedTokenizerFast
+) -> np.ndarray:
     """Returns mapping where each token is labeled with index of word it's part of
 
     Args:
@@ -146,9 +153,9 @@ def token_word_mapping(text: str, tokenizer: PreTrainedTokenizerFast) -> np.ndar
         np.ndarray: Array of length L (number of tokens) where each entry is index of word (cls and sep labels are not counted).
     """
     text_tokenized = tokenizer(text, return_offsets_mapping=True)
-    offset_mappings = text_tokenized['offset_mapping'][1:-1]
+    offset_mappings = text_tokenized["offset_mapping"][1:-1]
 
-    offset_mappings = text_tokenized['offset_mapping'][1:-1]
 
     # Create a map where each character is assigned the index of its word
     words_mapping = []
@@ -162,7 +169,10 @@ def token_word_mapping(text: str, tokenizer: PreTrainedTokenizerFast) -> np.ndar
 
     return np.array(token_mapping)
 
-def token_labels_to_word_labels(text: str, token_labels: np.ndarray, tokenizer: PreTrainedTokenizerFast) -> np.ndarray:
+
+def token_labels_to_word_labels(
+    text: str, token_labels: np.ndarray, tokenizer: PreTrainedTokenizerFast
+) -> np.ndarray:
     mapping = token_word_mapping(text, tokenizer)
 
     assert len(mapping) == len(token_labels)
@@ -172,16 +182,17 @@ def token_labels_to_word_labels(text: str, token_labels: np.ndarray, tokenizer:
     for i in range(len(mapping)):
         labels[mapping[i]].append(token_labels[i])
 
-    return np.array([
-        np.mean(labels[x], axis=0) for x in sorted(labels)
-    ])
+    return np.array([np.mean(labels[x], axis=0) for x in sorted(labels)])
+
 
-def tokenize_labeled_text(text: str, labels: np.ndarray, tokenizer: PreTrainedTokenizerFast) -> (np.ndarray, np.ndarray):
+def tokenize_labeled_text(
+    text: str, labels: np.ndarray, tokenizer: PreTrainedTokenizerFast
+) -> (np.ndarray, np.ndarray):
     """Transforms text into numerical tokens. Also expand word-level labels into token-level labels
 
     Args:
         text (str): Text that will be tokenized (TODO: Change to array)
-        labels (np.ndarray): Word-level labels for text to be tokenized. Word is defined via space spearation 
+        labels (np.ndarray): Word-level labels for text to be tokenized. Word is defined via space separation
         tokenizer (PreTrainedTokenizerFast): Tokenizer that will be used for tokenization
 
     Returns:
@@ -190,8 +201,8 @@ def tokenize_labeled_text(text: str, labels: np.ndarray, tokenizer: PreTrainedTo
     """
     text_tokenized = tokenizer(text, return_offsets_mapping=True)
 
-    offset_mappings = text_tokenized['offset_mapping'][1:-1]
-    input_ids = text_tokenized['input_ids'][1:-1]
+    offset_mappings = text_tokenized["offset_mapping"][1:-1]
+    input_ids = text_tokenized["input_ids"][1:-1]
 
     # Create a map where each character is assigned the index of its word
     words_mapping = []
@@ -203,7 +214,7 @@ def tokenize_labeled_text(text: str, labels: np.ndarray, tokenizer: PreTrainedTo
 
     # Assign each token to a word
     token_mapping = [words_mapping[x[0]] for x in offset_mappings]
-    
+
     # Expand word-based labels to token-based labels
     labels_tokenized = [labels[i] for i in token_mapping]
 
@@ -221,18 +232,19 @@ def recover_word(word: str, action: Mapping[str, bool]) -> str:
         str: transformed word
     """
     word_result = word
-    
-    if action['dot']:
+
+    if action["dot"]:
         word_result += "."
-    if action['upper_case']:
+    if action["upper_case"]:
         word_result = word_result.capitalize()
-    if action['colon']:
+    if action["colon"]:
         word_result += ","
-    if action['question_mark']:
+    if action["question_mark"]:
         word_result += "?"
 
     return word_result
 
+
 def is_sentence_end(actions_encoded: np.ndarray) -> bool:
     """Returns if given action would end a sentence
 
@@ -244,7 +256,8 @@ def is_sentence_end(actions_encoded: np.ndarray) -> bool:
     """
     actions_decoded = decode_actions(actions_encoded)
 
-    return actions_decoded['dot'] == True
+    return actions_decoded["dot"] is True
+
 
 def nearest_sentence_l(labels: np.array, index_start: int) -> int:
     """Find nearest word that begins a sentence that has lower or equal index to index_start
@@ -260,7 +273,7 @@ def nearest_sentence_l(labels: np.array, index_start: int) -> int:
     result_index = index_start
 
     while result_index > 0:
-        if is_sentence_end(labels[result_index, :]): 
+        if is_sentence_end(labels[result_index, :]):
             # prevent being in the middle of a token
             result_index -= 1
         elif is_sentence_end(labels[result_index - 1, :]):
@@ -273,6 +286,7 @@ def nearest_sentence_l(labels: np.array, index_start: int) -> int:
 
     return result_index
 
+
 def nearest_sentence_r(labels: np.array, index_start: int) -> Optional[int]:
     """Find nearest word that begins a sentence that has higher or equal index to index_start
 
@@ -286,7 +300,9 @@ def nearest_sentence_r(labels: np.array, index_start: int) -> Optional[int]:
     result_index = index_start
 
     while result_index < len(labels):
-        if is_sentence_end(labels[result_index - 1]) and not is_sentence_end(labels[result_index]):
+        if is_sentence_end(labels[result_index - 1]) and not is_sentence_end(
+            labels[result_index]
+        ):
             break
         else:
             result_index += 1
@@ -296,7 +312,10 @@ def nearest_sentence_r(labels: np.array, index_start: int) -> Optional[int]:
     else:
         return result_index
 
-def batchify_labels(labels: np.ndarray, max_tokens: int, min_tokens: int = 3) -> [np.ndarray]:
+
+def batchify_labels(
+    labels: np.ndarray, max_tokens: int, min_tokens: int = 3
+) -> [np.ndarray]:
     """Splits long labels array into batches of desired size
 
     Args:
@@ -324,19 +343,28 @@ def batchify_labels(labels: np.ndarray, max_tokens: int, min_tokens: int = 3) ->
             if new_index == index:
                 new_index = nearest_sentence_r(labels, index + num_consumed)
                 if new_index is None:
-                    labels_batches.append(np.array(list(range(index, index + num_consumed))))
+                    labels_batches.append(
+                        np.array(list(range(index, index + num_consumed)))
+                    )
                     break
         else:
-            labels_batches.append(np.array(list(range(index, index + num_consumed))))
+            labels_batches.append(
+                np.array(list(range(index, index + num_consumed)))
+            )
             break
 
-        labels_batches.append(np.array(list(range(index, index + num_consumed))))
+        labels_batches.append(
+            np.array(list(range(index, index + num_consumed)))
+        )
 
         index = new_index
 
     return labels_batches
 
-def add_cls_sep(tokens: np.ndarray, labels: np.ndarray, tokenizer: PreTrainedTokenizerFast) -> (np.ndarray, np.ndarray):
+
+def add_cls_sep(
+    tokens: np.ndarray, labels: np.ndarray, tokenizer: PreTrainedTokenizerFast
+) -> (np.ndarray, np.ndarray):
     """Adds staring cls and ending sep token ids into tokens & labels
 
     Args:
@@ -347,19 +375,27 @@ def add_cls_sep(tokens: np.ndarray, labels: np.ndarray, tokenizer: PreTrainedTok
         np.ndarray: tokens with added cls & sep tokens ids
         np.ndarray: labels with first and last item duplicated to accommodate cls & sep
     """
-    
-    tokens = np.concatenate([[[tokenizer.cls_token_id]], tokens, [[tokenizer.sep_token_id]]])
+
+    tokens = np.concatenate(
+        [[[tokenizer.cls_token_id]], tokens, [[tokenizer.sep_token_id]]]
+    )
     labels = np.concatenate([labels[:1, :], labels, labels[-1:, :]])
 
     return tokens, labels
 
-def add_padding(tokens: np.ndarray, labels: np.ndarray, length: int, tokenizer: PreTrainedTokenizerFast) -> (np.ndarray, np.ndarray, np.ndarray):
+
+def add_padding(
+    tokens: np.ndarray,
+    labels: np.ndarray,
+    length: int,
+    tokenizer: PreTrainedTokenizerFast,
+) -> (np.ndarray, np.ndarray, np.ndarray):
     """Appends padding to tokens and labels to match desired length
 
     Args:
         tokens (np.ndarray): Lx1 array of token ids
         labels (np.ndarray): LxA array of action vectors
-        length (int): Desired length of a vector. Must be higher than L  
+        length (int): Desired length of a vector. Must be at least L
         tokenizer (PreTrainedTokenizerFast): Tokenizer that was used for tokenization
 
     Returns:
@@ -372,7 +408,9 @@ def add_padding(tokens: np.ndarray, labels: np.ndarray, length: int, tokenizer:
     assert pad_length >= 0
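+    # e.g. for L == 3 and length == 5, two pad token ids and two empty
+    # action vectors are appended below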
 
     if pad_length > 0:
-        tokens = np.concatenate([tokens, [[tokenizer.pad_token_id]] * pad_length])
+        tokens = np.concatenate(
+            [tokens, [[tokenizer.pad_token_id]] * pad_length]
+        )
         labels = np.concatenate([labels, [empty_action_vector()] * pad_length])
 
     mask = np.ones(len(tokens)).astype(np.int)
@@ -382,8 +420,14 @@ def add_padding(tokens: np.ndarray, labels: np.ndarray, length: int, tokenizer:
 
     return tokens, labels, mask
 
-def batchify_data(tokens: np.ndarray, labels: np.ndarray, max_tokens: int,
-                    tokenizer: PreTrainedTokenizerFast, min_tokens: int = 3) -> (np.ndarray, np.ndarray):
+
+def batchify_data(
+    tokens: np.ndarray,
+    labels: np.ndarray,
+    max_tokens: int,
+    tokenizer: PreTrainedTokenizerFast,
+    min_tokens: int = 3,
+) -> (np.ndarray, np.ndarray):
     """Transforms tokens and labels into a batch
 
     Args:
@@ -410,11 +454,15 @@ def batchify_data(tokens: np.ndarray, labels: np.ndarray, max_tokens: int,
         assert len(ids) >= min_tokens
         assert len(ids) <= max_tokens - 2
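+        # the "- 2" leaves room for the cls & sep tokens added below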
 
-        tokens_sample, labels_sample = add_cls_sep(tokens_sample, labels_sample, tokenizer)
+        tokens_sample, labels_sample = add_cls_sep(
+            tokens_sample, labels_sample, tokenizer
+        )
 
         assert len(tokens_sample) <= max_tokens
 
-        tokens_sample, labels_sample, mask = add_padding(tokens_sample, labels_sample, max_tokens, tokenizer)
+        tokens_sample, labels_sample, mask = add_padding(
+            tokens_sample, labels_sample, max_tokens, tokenizer
+        )
 
         tokens_batch.append(tokens_sample)
         labels_batch.append(labels_sample)
@@ -422,6 +470,7 @@ def batchify_data(tokens: np.ndarray, labels: np.ndarray, max_tokens: int,
 
     return np.array(tokens_batch), np.array(labels_batch), np.array(mask_batch)
 
+
 def recover_text(text: str, actions_encoded: np.ndarray):
     words = text.split(" ")
 
@@ -429,7 +478,7 @@ def recover_text(text: str, actions_encoded: np.ndarray):
 
     for word, action_encoded in zip(words, actions_encoded.tolist()):
         action_decoded = decode_actions(np.array(action_encoded))
-        
+
         word_recovered = recover_word(word, action_decoded)
         words_output.append(word_recovered)
 
diff --git a/src/test_batch_loading.py b/src/test_batch_loading.py
deleted file mode 100644
index 613c1bb..0000000
--- a/src/test_batch_loading.py
+++ /dev/null
@@ -1,47 +0,0 @@
-import numpy as np
-import pandas as pd
-import dask.dataframe as dd
-from src.batch_loading import *
-
-def test_calculate_batch_buffer_id():
-    ids = [0, 1, 2, 3, 4, 5, 6]
-    assert calculate_batch_buffer_id(0, 3)  == 0
-    assert calculate_batch_buffer_id(1, 3)  == 0
-    assert calculate_batch_buffer_id(2, 3)  == 0
-    assert calculate_batch_buffer_id(3, 3)  == 1
-    assert calculate_batch_buffer_id(4, 3)  == 1
-    assert calculate_batch_buffer_id(5, 3)  == 1
-    assert calculate_batch_buffer_id(6, 3)  == 2
-
-def test_yield_batch_buffer_span():
-    ids = [0, 1, 2, 3, 4, 5, 6]
-
-    result = list(yield_batch_buffer_span(2, 2, len(ids)))
-
-    assert np.all(result[0] == [0, 1, 2, 3])
-    assert np.all(result[1] == [4, 5, 6])
-
-def test_get_ordered_dataframe_len():
-    df = pd.DataFrame({'a': [1, 2, 3, 4, 5, 6, 7]})
-
-    assert get_ordered_dataframe_len(df) == 7
-
-def test_get_batches():
-    batch_size = 2
-    batch_buffer_len = 2
-    pdf = pd.DataFrame({'a': [1,0,2,3,4,5,6]})
-    shuffled_ids = np.array([1, 0, 2, 3, 4, 5, 6])
-    df = dd.from_pandas(pdf, npartitions=2)
-
-    batches = list(get_batches(df, batch_size, batch_buffer_len, shuffled_ids))
-
-    assert np.all(batches[0]['a'].values == [0, 1])
-    assert np.all(batches[1]['a'].values == [2, 3])
-    assert np.all(batches[2]['a'].values == [4, 5])
-    assert np.all(batches[3]['a'].values == [6])
-
-    batches = list(get_batches(df, batch_size, batch_buffer_len, shuffled_ids, 1))
-
-    assert np.all(batches[1]['a'].values == [2, 3])
-    assert np.all(batches[2]['a'].values == [4, 5])
-    assert np.all(batches[3]['a'].values == [6])
diff --git a/src/utils.py b/src/utils.py
index bf36f03..2217a36 100644
--- a/src/utils.py
+++ b/src/utils.py
@@ -5,7 +5,10 @@ from datetime import timedelta
 from typing import Optional
 import shutil
 
-PROJECT_ROOT=os.path.dirname(os.path.realpath("/".join(__file__.split("/")) + "/.."))
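+# Absolute path to the repository root (one level above src/)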
+PROJECT_ROOT = os.path.dirname(
+    os.path.realpath("/".join(__file__.split("/")) + "/..")
+)
+
 
 def get_config() -> dict:
     """Returns dict with config values
@@ -14,11 +17,12 @@ def get_config() -> dict:
         dict: Dict with config values
     """
 
-    with open(f"{PROJECT_ROOT}/params.yaml", "r")  as file:
+    with open(f"{PROJECT_ROOT}/params.yaml", "r") as file:
         config = yaml.load(file, Loader=yaml.FullLoader)
 
     return config
 
+
 def remove_multiple_spaces(text: str) -> str:
     """Replaces multiple spaces by a single one
 
@@ -30,9 +34,10 @@ def remove_multiple_spaces(text: str) -> str:
     """
     return re.sub(r"\s\s+", " ", text)
 
+
 def remove_punctuation(text: str) -> str:
-    """Removes all non-alphanumeric characters from the text.  
-    Might result in multiple spaces while chracters like `-` 
+    """Removes all non-alphanumeric characters from the text.
+    Might result in multiple spaces when characters like `-`
     are used
 
     Args:
@@ -41,7 +46,8 @@ def remove_punctuation(text: str) -> str:
     Returns:
         str: Text with all punctuation removed
     """
-    return ''.join(filter(lambda x: x.isalnum() or x.isspace(), text))
+    return "".join(filter(lambda x: x.isalnum() or x.isspace(), text))
+
 
 def prepare_folder(path: str, wipe: bool = False) -> None:
     """Function make sure that provided path exists. Can aditionaly
@@ -57,13 +63,13 @@ def prepare_folder(path: str, wipe: bool = False) -> None:
 
     os.makedirs(path, exist_ok=True)
 
+
 def convert_to_timedelta(time_val: str) -> Optional[timedelta]:
     """
     src: https://code.activestate.com/recipes/577894-convert-strings-like-5d-and-60s-to-timedelta-objec/
     Given a *time_val* (string) such as '5d', returns a timedelta object
-    representing the given value (e.g. timedelta(days=5)).  Accepts the
-    following '<num><char>' formats:
-    
+    representing the given value (e.g. timedelta(days=5)).
+
     =========   ======= ===================
     Character   Meaning Example
     =========   ======= ===================
@@ -72,9 +78,9 @@ def convert_to_timedelta(time_val: str) -> Optional[timedelta]:
     h           Hours   '24h' -> 24 Hours
     d           Days    '7d'  -> 7 Days
     =========   ======= ===================
-    
+
     Examples::
-    
+
         >>> convert_to_timedelta('7d')
         datetime.timedelta(7)
         >>> convert_to_timedelta('24h')
@@ -85,13 +91,13 @@ def convert_to_timedelta(time_val: str) -> Optional[timedelta]:
         datetime.timedelta(0, 120)
     """
     num = int(time_val[:-1])
-    if time_val.endswith('s'):
+    if time_val.endswith("s"):
         return timedelta(seconds=num)
-    elif time_val.endswith('m'):
+    elif time_val.endswith("m"):
         return timedelta(minutes=num)
-    elif time_val.endswith('h'):
+    elif time_val.endswith("h"):
         return timedelta(hours=num)
-    elif time_val.endswith('d'):
+    elif time_val.endswith("d"):
         return timedelta(days=num)
     else:
-        return None
\ No newline at end of file
+        return None
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/pipelines/__init__.py b/tests/pipelines/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/pipelines/translation_based/__init__.py b/tests/pipelines/translation_based/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/scripts/translation_based/test_processing.py b/tests/pipelines/translation_based/test_processing.py
similarity index 58%
rename from scripts/translation_based/test_processing.py
rename to tests/pipelines/translation_based/test_processing.py
index b340204..b83e01b 100644
--- a/scripts/translation_based/test_processing.py
+++ b/tests/pipelines/translation_based/test_processing.py
@@ -1,6 +1,14 @@
 import numpy as np
-from scripts.translation_based.processing import (
-    find_new_sentence_left, find_new_sentence_right, get_batch_indexes, add_padding, add_begin_end_tokens, standarize_translation_sample, create_input_output, crete_input_output_batch)
+from src.pipelines.translation_based.processing import (
+    find_new_sentence_left,
+    find_new_sentence_right,
+    get_batch_indexes,
+    add_padding,
+    add_begin_end_tokens,
+    standarize_translation_sample,
+    create_input_output,
+    crete_input_output_batch,
+)
 from transformers import BertTokenizerFast
 
 
@@ -29,10 +37,7 @@ def test_split_to_samples():
     min_len = 3
     max_len = 5
     test_input = np.array([0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0])
-    expeted_output = [
-        np.array([0, 1, 2, 3, 4]),
-        np.array([6, 7, 8, 9, 10])
-    ]
+    expeted_output = [np.array([0, 1, 2, 3, 4]), np.array([6, 7, 8, 9, 10])]
 
     result = get_batch_indexes(test_input, min_len, max_len)
     assert len(result) == len(expeted_output)
@@ -55,10 +60,7 @@ def test_add_padding():
     assert np.all(result == [1, 2, 3, 4, 9, 9])
 
     # multidimensional use-case
-    input_sequence = np.array([
-        [1, 2, 3],
-        [4, 5, 6]
-    ])
+    input_sequence = np.array([[1, 2, 3], [4, 5, 6]])
     padd = np.array([9, 9, 9])
     result = add_padding(input_sequence, 4, padd)
     assert len(result) == 4
@@ -85,12 +87,35 @@ def test_standarize_translation_sample():
 def test_create_input_output():
     sequence = [56500, 117, 10824, 30186, 11090, 10113, 119]
     tokenizer = BertTokenizerFast.from_pretrained(
-        "bert-base-multilingual-cased")
-
-    expected_output_sequence = [tokenizer.cls_token_id, 56500, 117, 10824, 30186, 11090,
-                               10113, 119, tokenizer.sep_token_id, tokenizer.pad_token_id, tokenizer.pad_token_id]
-    expected_input_sequence = [tokenizer.cls_token_id, 21739, 10824, 16469, tokenizer.sep_token_id, tokenizer.pad_token_id,
-                                tokenizer.pad_token_id, tokenizer.pad_token_id, tokenizer.pad_token_id, tokenizer.pad_token_id, tokenizer.pad_token_id]
+        "bert-base-multilingual-cased"
+    )
+
+    expected_output_sequence = [
+        tokenizer.cls_token_id,
+        56500,
+        117,
+        10824,
+        30186,
+        11090,
+        10113,
+        119,
+        tokenizer.sep_token_id,
+        tokenizer.pad_token_id,
+        tokenizer.pad_token_id,
+    ]
+    expected_input_sequence = [
+        tokenizer.cls_token_id,
+        21739,
+        10824,
+        16469,
+        tokenizer.sep_token_id,
+        tokenizer.pad_token_id,
+        tokenizer.pad_token_id,
+        tokenizer.pad_token_id,
+        tokenizer.pad_token_id,
+        tokenizer.pad_token_id,
+        tokenizer.pad_token_id,
+    ]
 
     result_input, result_output = create_input_output(sequence, 11, tokenizer)
 
@@ -102,26 +127,53 @@ def test_create_input_output():
 
 def test_create_input_output_batch():
     tokenizer = BertTokenizerFast.from_pretrained(
-        "bert-base-multilingual-cased")
+        "bert-base-multilingual-cased"
+    )
 
-    expected_output_1 = np.array(tokenizer("Ala, ma KoTa.")['input_ids'])[1:-1]
-    expected_output_2 = np.array(tokenizer("A kOt nie!")['input_ids'])[1:-1]
+    expected_output_1 = np.array(tokenizer("Ala, ma KoTa.")["input_ids"])[1:-1]
+    expected_output_2 = np.array(tokenizer("A kOt nie!")["input_ids"])[1:-1]
 
-    expected_input_1 = np.array(tokenizer("ala ma kota")['input_ids'])[1:-1]
-    expected_input_2 = np.array(tokenizer("a kot nie")['input_ids'])[1:-1]
+    expected_input_1 = np.array(tokenizer("ala ma kota")["input_ids"])[1:-1]
+    expected_input_2 = np.array(tokenizer("a kot nie")["input_ids"])[1:-1]
 
     input_sequence = np.concatenate([expected_output_1, expected_output_2])
     batch_ids = [
         np.array(list(range(len(expected_output_1)))),
-        np.array(list(range(len(expected_output_2)))) + len(expected_output_1)
+        np.array(list(range(len(expected_output_2)))) + len(expected_output_1),
     ]
 
-    expected_input_1 = standarize_translation_sample(expected_input_1, 20, tokenizer.pad_token_id, tokenizer.cls_token_id, tokenizer.sep_token_id)
-    expected_input_2 = standarize_translation_sample(expected_input_2, 20, tokenizer.pad_token_id, tokenizer.cls_token_id, tokenizer.sep_token_id)
-    expected_output_1 = standarize_translation_sample(expected_output_1, 20, tokenizer.pad_token_id, tokenizer.cls_token_id, tokenizer.sep_token_id)
-    expected_output_2 = standarize_translation_sample(expected_output_2, 20, tokenizer.pad_token_id, tokenizer.cls_token_id, tokenizer.sep_token_id)
-
-    result_input, result_output = crete_input_output_batch(input_sequence, batch_ids, 20, tokenizer)
+    expected_input_1 = standarize_translation_sample(
+        expected_input_1,
+        20,
+        tokenizer.pad_token_id,
+        tokenizer.cls_token_id,
+        tokenizer.sep_token_id,
+    )
+    expected_input_2 = standarize_translation_sample(
+        expected_input_2,
+        20,
+        tokenizer.pad_token_id,
+        tokenizer.cls_token_id,
+        tokenizer.sep_token_id,
+    )
+    expected_output_1 = standarize_translation_sample(
+        expected_output_1,
+        20,
+        tokenizer.pad_token_id,
+        tokenizer.cls_token_id,
+        tokenizer.sep_token_id,
+    )
+    expected_output_2 = standarize_translation_sample(
+        expected_output_2,
+        20,
+        tokenizer.pad_token_id,
+        tokenizer.cls_token_id,
+        tokenizer.sep_token_id,
+    )
+
+    result_input, result_output = crete_input_output_batch(
+        input_sequence, batch_ids, 20, tokenizer
+    )
 
     assert result_input.shape[0] == 2
     assert result_input.shape[1] == 20
@@ -133,4 +185,4 @@ def test_create_input_output_batch():
     assert np.all(result_input[1] == expected_input_2)
 
     assert np.all(result_output[0] == expected_output_1)
-    assert np.all(result_output[1] == expected_output_2)
\ No newline at end of file
+    assert np.all(result_output[1] == expected_output_2)
diff --git a/tests/test_batch_loading.py b/tests/test_batch_loading.py
new file mode 100644
index 0000000..3ccc510
--- /dev/null
+++ b/tests/test_batch_loading.py
@@ -0,0 +1,58 @@
+import numpy as np
+import pandas as pd
+import dask.dataframe as dd
+from src.batch_loading import (
+    calculate_batch_buffer_id,
+    yield_batch_buffer_span,
+    get_ordered_dataframe_len,
+    get_batches
+)
+
+
+def test_calculate_batch_buffer_id():
+    # ids = [0, 1, 2, 3, 4, 5, 6]
+    assert calculate_batch_buffer_id(0, 3) == 0
+    assert calculate_batch_buffer_id(1, 3) == 0
+    assert calculate_batch_buffer_id(2, 3) == 0
+    assert calculate_batch_buffer_id(3, 3) == 1
+    assert calculate_batch_buffer_id(4, 3) == 1
+    assert calculate_batch_buffer_id(5, 3) == 1
+    assert calculate_batch_buffer_id(6, 3) == 2
+
+
+def test_yield_batch_buffer_span():
+    ids = [0, 1, 2, 3, 4, 5, 6]
+
+    result = list(yield_batch_buffer_span(2, 2, len(ids)))
+
+    assert np.all(result[0] == [0, 1, 2, 3])
+    assert np.all(result[1] == [4, 5, 6])
+
+
+def test_get_ordered_dataframe_len():
+    df = pd.DataFrame({"a": [1, 2, 3, 4, 5, 6, 7]})
+
+    assert get_ordered_dataframe_len(df) == 7
+
+
+def test_get_batches():
+    batch_size = 2
+    batch_buffer_len = 2
+    pdf = pd.DataFrame({"a": [1, 0, 2, 3, 4, 5, 6]})
+    shuffled_ids = np.array([1, 0, 2, 3, 4, 5, 6])
+    df = dd.from_pandas(pdf, npartitions=2)
+
+    batches = list(get_batches(df, batch_size, batch_buffer_len, shuffled_ids))
+
+    assert np.all(batches[0]["a"].values == [0, 1])
+    assert np.all(batches[1]["a"].values == [2, 3])
+    assert np.all(batches[2]["a"].values == [4, 5])
+    assert np.all(batches[3]["a"].values == [6])
+
+    batches = list(
+        get_batches(df, batch_size, batch_buffer_len, shuffled_ids, 1)
+    )
+
+    assert np.all(batches[1]["a"].values == [2, 3])
+    assert np.all(batches[2]["a"].values == [4, 5])
+    assert np.all(batches[3]["a"].values == [6])
diff --git a/src/test_processing.py b/tests/test_processing.py
similarity index 63%
rename from src/test_processing.py
rename to tests/test_processing.py
index 64ef6d7..45ac686 100644
--- a/src/test_processing.py
+++ b/tests/test_processing.py
@@ -1,100 +1,105 @@
-from src.processing import *
-from transformers import PreTrainedTokenizerFast, BertTokenizerFast
+from src.processing import (
+    detect_actions,
+    encode_actions,
+    token_word_mapping,
+    tokenize_labeled_text,
+    token_labels_to_word_labels,
+    create_model_input_output,
+    recover_text,
+    nearest_sentence_l,
+    nearest_sentence_r,
+    batchify_labels,
+    batchify_data,
+    ACTIONS_KEYS,
+    decode_actions
+)
+from transformers import BertTokenizerFast
 import pytest
+import numpy as np
+
 
 def test_detect_actions():
-    actions = detect_actions("Janek...", None)
+    actions = detect_actions("Janek.", None)
     assert actions == {
-        'dot': False,
-        'upper_case': True,
-        'colon': False,
-        'semicolon': False,
-        'elipsis': True,
-        'dash': False
+        "dot": True,
+        "upper_case": True,
+        "colon": False,
+        "question_mark": False,
     }
 
-    actions = detect_actions("ewka.", None)
+    actions = detect_actions("ewka?", None)
     assert actions == {
-        'dot': True,
-        'upper_case': False,
-        'colon': False,
-        'semicolon': False,
-        'elipsis': False,
-        'dash': False
+        "dot": False,
+        "upper_case": False,
+        "colon": False,
+        "question_mark": True,
     }
 
-    actions = detect_actions("Test", "-")
+    actions = detect_actions("Test", None)
     assert actions == {
-        'dot': False,
-        'upper_case': True,
-        'colon': False,
-        'semicolon': False,
-        'elipsis': False,
-        'dash': True
+        "dot": False,
+        "upper_case": True,
+        "colon": False,
+        "question_mark": False,
     }
 
 
 def test_encode_actions():
     x = {
-        'dot': True,
-        'upper_case': False,
-        'colon': False,
-        'semicolon': True,
-        'elipsis': False,
-        'dash': True
+        "dot": True,
+        "upper_case": False,
+        "colon": False,
+        "question_mark": True,
     }
 
-    assert np.all(encode_actions(x) == np.array([1, 0, 0, 1, 0, 1]))
+    assert np.all(encode_actions(x) == np.array([1, 0, 0, 1]))
 
 
 def test_decode_actions():
-    x = np.array([1, 0, 0, 1, 0, 1])
+    x = np.array([1, 0, 0, 1])
 
     assert decode_actions(x) == {
-        'dot': True,
-        'upper_case': False,
-        'colon': False,
-        'semicolon': True,
-        'elipsis': False,
-        'dash': True
+        "dot": True,
+        "upper_case": False,
+        "colon": False,
+        "question_mark": True,
     }
 
+
 def test_token_word_mapping():
     text = "janek poszedł do ogrodu"
     tokenizer = BertTokenizerFast.from_pretrained(
-        'bert-base-multilingual-cased')
+        "bert-base-multilingual-cased"
+    )
 
     text_tokenized = tokenizer(text)
 
     mapping = token_word_mapping(text, tokenizer)
 
-    assert len(mapping) == (len(text_tokenized['input_ids']) - 2)
+    assert len(mapping) == (len(text_tokenized["input_ids"]) - 2)
     assert min(mapping) == 0
     assert max(mapping) == 3
 
+
 def test_token_labels_to_word_labels():
     text = "janek poszedł do ogrodu"
-    labels = np.array([
-        [0, 0, 0],
-        [1, 0, 0],
-        [0, 1, 0],
-        [0, 0, 1]
-    ])
+    labels = np.array([[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]])
     tokenizer = BertTokenizerFast.from_pretrained(
-        'bert-base-multilingual-cased')
+        "bert-base-multilingual-cased"
+    )
 
-    tokens, token_labels = tokenize_labeled_text(text, labels, tokenizer)
+    _, token_labels = tokenize_labeled_text(text, labels, tokenizer)
 
-    mapping = token_word_mapping(text, tokenizer)
     word_labels = token_labels_to_word_labels(text, token_labels, tokenizer)
 
-    assert np.all(np.vectorize(pytest.approx)(word_labels, labels)) == True
-    
+    assert np.all(np.vectorize(pytest.approx)(word_labels, labels))
+
 
 def test_tokenize_labeled_text():
-    text = "Janek poszedł do ogrodu. Ogród był zwierzęcy. Spotkał tam Zosię..."
+    text = "Janek poszedł do ogrodu. Ogród był zwierzęcy. Spotkał tam Zosię?"
     tokenizer = BertTokenizerFast.from_pretrained(
-        'bert-base-multilingual-cased')
+        "bert-base-multilingual-cased"
+    )
 
     text_clean, labels = create_model_input_output(text)
     tokens, token_labels = tokenize_labeled_text(text_clean, labels, tokenizer)
@@ -109,11 +114,9 @@ def test_tokenize_labeled_text():
     assert tokens[0, 0] != tokenizer.cls_token_id
     assert tokens[-1, 0] != tokenizer.sep_token_id
 
-def test_recover_text():
-    text = "Janek poszedł do ogrodu. Ogród był zwierzęcy. Spotkał tam Zosię..."
-    tokenizer = BertTokenizerFast.from_pretrained(
-        'bert-base-multilingual-cased')
 
+def test_recover_text():
+    text = "Janek poszedł do ogrodu. Ogród był zwierzęcy. Spotkał tam Zosię?"
     text_clean, word_labels = create_model_input_output(text)
 
     result_text = recover_text(text_clean, word_labels)
@@ -134,14 +137,14 @@ def test_nearest_sentence_l():
 
 
 def create_dummy_action(end_sentence: bool) -> np.array:
-    return encode_actions({
-        'dot': end_sentence,
-        'upper_case': False,
-        'colon': False,
-        'semicolon': False,
-        'elipsis': False,
-        'dash': False
-    })
+    return encode_actions(
+        {
+            "dot": end_sentence,
+            "upper_case": False,
+            "colon": False,
+            "question_mark": False,
+        }
+    )
 
 
 def test_nearest_sentence_r():
@@ -170,18 +173,20 @@ def test_batchify_labels():
 
 
 def test_batchify_data():
-    text = "Janek poszedł do ogrodu. Ogród był zwierzęcy. Spotkał tam niedzwiedzia..."
+    text = (
+        "Janek poszedł do ogrodu. Ogród był zwierzęcy. Spotkał tam"
+        " niedzwiedzia?"
+    )
     tokenizer = BertTokenizerFast.from_pretrained(
-        'bert-base-multilingual-cased')
+        "bert-base-multilingual-cased"
+    )
 
     text_clean, labels = create_model_input_output(text)
     tokens, token_labels = tokenize_labeled_text(text_clean, labels, tokenizer)
 
-    # print(tokenizer.convert_ids_to_tokens(tokens.reshape(-1).astype(int)))
-    # print(token_labels)
-
     input_batch, output_batch, mask_batch = batchify_data(
-        tokens, token_labels, 5, tokenizer)
+        tokens, token_labels, 5, tokenizer
+    )
 
     assert len(input_batch.shape) == 3
     assert len(output_batch.shape) == 3
@@ -203,13 +208,14 @@ def test_batchify_data():
     assert mask_batch.dtype == np.int
 
     # Should never be fully masked
-    assert np.all(mask_batch[:, 0] == 0) == False
+    # TODO: Make sure the correct convention is used
+    assert np.all(mask_batch[:, 0] == 1)
 
     # Each batch sample should start at a sentence boundary
     for i in range(input_batch.shape[0]):
         # Should always start from beginning of the sentence
-        assert decode_actions(output_batch[i, 0, :])['upper_case']
-        assert decode_actions(output_batch[i, 1, :])['upper_case']
+        assert decode_actions(output_batch[i, 0, :])["upper_case"]
+        assert decode_actions(output_batch[i, 1, :])["upper_case"]
 
         # Should always end with sep and padding
         # TODO: Test it
diff --git a/tox.ini b/tox.ini
new file mode 100644
index 0000000..43a673d
--- /dev/null
+++ b/tox.ini
@@ -0,0 +1,48 @@
+[tox]
+envlist = py38,flake8,pep8
+skipsdist = True
+
+[testenv]
+deps =
+    pytest
+    numpy
+    pyyaml
+    pandas
+    tqdm
+    torch
+    dask[complete]
+    transformers
+    pyarrow==0.17.1
+    lxml
+
+[testenv:py38]
+commands = pytest --ignore data --ignore generated
+
+[flake8]
+exclude =
+    .tox,
+    .git,
+    __pycache__,
+    docs/source/conf.py,
+    build,
+    dist,
+    tests/fixtures/*,
+    *.pyc,
+    *.egg-info,
+    .cache,
+    .eggs,
+    data,
+    generated
+max-complexity = 10
+max-line-length = 80
+select = C,E,F,W,B,B950
+ignore = E203, E501, W503
+
+
+[testenv:pep8]
+deps =
+    flake8
+basepython = python3
+commands =
+    flake8 {posargs}
+
-- 
GitLab


From 3bac4718ef083f2a0908ec13947f7d108d4021c2 Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Mon, 10 Aug 2020 12:16:33 +0200
Subject: [PATCH 040/116] Most of deployment code in POC state

---
 .dockerignore                                 | 15 +++
 Dockerfile                                    | 17 +++-
 config.ini                                    | 21 ++++
 notebooks/test_actions_model.ipynb            | 57 +++++++++--
 params.yaml                                   |  2 +-
 main.py => punctuate.py                       | 12 +--
 src/batch_loading.py                          |  4 +-
 src/models/TransformerSeq2Seq.py              | 16 +---
 src/pipelines/actions_based/processing.py     | 95 ++++++++++++++++++-
 .../actions_based/stage1_extraction.py        |  4 +-
 .../actions_based/stage3_exploding.py         |  8 +-
 src/pipelines/actions_based/stage5_stats.py   |  4 +-
 src/pipelines/actions_based/train.py          | 62 ++++--------
 src/pipelines/translation_based/processing.py | 20 +---
 .../translation_based/stage1_extraction.py    |  9 +-
 .../stage2_create_batches.py                  |  4 +-
 .../translation_based/stage3_exploding.py     |  8 +-
 src/pipelines/translation_based/train.py      | 94 ++++++------------
 src/processing.py                             | 30 +++---
 src/training.py                               | 58 +++++++++++
 src/utils.py                                  |  4 +-
 tests/pipelines/actions_based/__init__.py     |  0
 .../actions_based/test_processing.py          | 24 +++++
 .../translation_based/test_processing.py      |  8 +-
 tests/test_batch_loading.py                   |  6 +-
 tests/test_processing.py                      | 23 ++---
 tests/test_training.py                        | 21 ++++
 tox.ini                                       |  2 +-
 worker.py                                     | 42 ++++++++
 29 files changed, 436 insertions(+), 234 deletions(-)
 create mode 100644 .dockerignore
 create mode 100644 config.ini
 rename main.py => punctuate.py (73%)
 create mode 100644 src/training.py
 create mode 100644 tests/pipelines/actions_based/__init__.py
 create mode 100644 tests/pipelines/actions_based/test_processing.py
 create mode 100644 tests/test_training.py
 create mode 100755 worker.py

diff --git a/.dockerignore b/.dockerignore
new file mode 100644
index 0000000..19d3acf
--- /dev/null
+++ b/.dockerignore
@@ -0,0 +1,15 @@
+data
+__pycache__
+.devcontainer
+.dvc
+.idea
+.metals
+.pytest_cache
+.tox
+.vscode
+checkpoints
+dask-worker-space
+generated
+notebooks
+tests
\ No newline at end of file
diff --git a/Dockerfile b/Dockerfile
index ead19b1..f21915e 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1 +1,16 @@
-FROM clarinpl/python:3.8
\ No newline at end of file
+FROM clarinpl/python:3.8
+
+RUN DEBIAN_FRONTEND=noninteractive apt-get update && apt-get install -y gcc python3-dev
+RUN pip3 install numpy pandas tqdm seaborn torch dask[complete] transformers pyarrow==0.17.1 pytest lxml
+RUN mkdir /punctuator
+WORKDIR /punctuator
+
+COPY src ./src
+COPY config.ini .
+COPY worker.py .
+
+RUN mkdir ./deploy && \
+    wget https://minio.clarin-pl.eu/public/models/punctuation/0-190000.model -O deploy/model
+
+RUN ls ./deploy
+CMD ["./worker.py"]
\ No newline at end of file
diff --git a/config.ini b/config.ini
new file mode 100644
index 0000000..9e58547
--- /dev/null
+++ b/config.ini
@@ -0,0 +1,21 @@
+[service]
+tool = Punctuator
+
+root = /samba/requests/
+rabbit_host = addr
+rabbit_user = test
+rabbit_password = test
+
+[tool]
+workers_number = 1
+
+[logging]
+port = 9981
+local_log_level = INFO
+
+[deployment]
+device = "cpu"
+chunk_size = 500
+threshold = 0.9
+model = "deploy/model"
+base_model = "dkleczek/bert-base-polish-cased-v1"
\ No newline at end of file
diff --git a/notebooks/test_actions_model.ipynb b/notebooks/test_actions_model.ipynb
index 4f2484e..98877b8 100644
--- a/notebooks/test_actions_model.ipynb
+++ b/notebooks/test_actions_model.ipynb
@@ -14,7 +14,7 @@
   },
   "orig_nbformat": 2,
   "kernelspec": {
-   "name": "python_defaultSpec_1596719587498",
+   "name": "python38264bita7d7da14168440cb9836372958035d4a",
    "display_name": "Python 3.8.2 64-bit"
   }
  },
@@ -27,6 +27,9 @@
    "metadata": {},
    "outputs": [],
    "source": [
+    "%load_ext autoreload\n",
+    "%autoreload 2\n",
+    "\n",
     "import sys\n",
     "sys.path.append(\"../\")\n",
     "\n",
@@ -60,11 +63,11 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 16,
+   "execution_count": 4,
    "metadata": {},
    "outputs": [],
    "source": [
-    "expected = \"w niedzielę kiedy prażanie są precz i moczą nogi wieku bereuce na wzgórze zamkowe królów czeskich napierają watahy cudzoziemców podnieconych obecnością w miejscu ważnym dla historii od dziesięciu wieków i ponownie od czasu kiedy na hrad wyniesiono vaclava havla zamek chwilowo nie ma gospodarza ale wszystko jest pod parą i gwardia gotowa zawsze do stania u boku katedra świętego wita jest dla czechów tym czy dla polaków katedra na wawelu żelazny szlak turystyczny prowadzi złotą uliczką gdzie zaczyna boleć głowa od nadmiaru okazji do pamiątkowego zdjęcia w tym małym domu mieszkał wielki kafka zanim się dostaniemy na lewy brzeg wełtawy wypada spędzić dłuższą chwilę na samym moście bujnie opiewanym i jeśli nie najpiękniejszym to z pewnością jedynym w roli tak ciągle twórczej dla folkloru wielkiego miasta najpierw śmierć uderza młoteczkiem w dzwonki a potem defiluje i przedstawia się turystom dwanaście apostołów zegar orloja jest najwyższą z wiekowych atrakcji na rynku starego miasta nie budujemy drugich czech choć bliżej nam do pragi niż do tokio i już tysiąc lat temu zaczął święty wojciech ciekawe tylko czy jan hus przetrwałby na pomniku w drugich czechach i czy miałby tyle szans co święty wacław przy wiślanym szlaku od mostu poniatowskiego w dół rzeki trzyma się bar pod rurą i może jeszcze ze dwie budy gdzie po dawnemu sprzedaje się piwo marki piwo temperatura zależy od słoneczka szkło nie gra roli a niewiastom wstęp niewskazany w gazetach do niedawna pisano że piwo degraduje polaków w rzeczywistości było dokładnie na odwrót amerykański desant nie uratował zagrożonej cnoty polaków bo ciepłe piwo jest mimo wszystko mniej obrzydliwe od ciepłej coca coli miejsce właściwe obu napojom w kulturze i obyczaju wyznaczył dopiero wolny rynek \"\n",
+    "expected = \"w niedzielę kiedy prażanie są precz i moczą nogi wieku bereuce na wzgórze zamkowe królów czeskich napierają watahy cudzoziemców podnieconych obecnością w miejscu ważnym dla historii od dziesięciu wieków i ponownie od czasu kiedy na hrad wyniesiono vaclava havla zamek chwilowo nie ma gospodarza ale wszystko jest pod parą i gwardia gotowa zawsze do stania u boku katedra świętego wita jest dla czechów tym czy dla polaków katedra na wawelu żelazny szlak turystyczny prowadzi złotą uliczką gdzie zaczyna boleć głowa od nadmiaru okazji do pamiątkowego zdjęcia w tym małym domu mieszkał wielki kafka zanim się dostaniemy na lewy brzeg wełtawy wypada spędzić dłuższą chwilę na samym moście bujnie opiewanym i jeśli nie najpiękniejszym to z pewnością jedynym w roli tak ciągle twórczej dla folkloru wielkiego miasta najpierw śmierć uderza młoteczkiem w dzwonki a potem defiluje i przedstawia się turystom dwanaście apostołów zegar orloja jest najwyższą z wiekowych atrakcji na rynku starego miasta nie budujemy drugich czech choć bliżej nam do pragi niż do tokio i już tysiąc lat temu zaczął święty wojciech ciekawe tylko czy jan hus przetrwałby na pomniku w drugich czechach i czy miałby tyle szans co święty wacław przy wiślanym szlaku od mostu poniatowskiego w dół rzeki trzyma się bar pod rurą i może jeszcze ze dwie budy gdzie po dawnemu sprzedaje się piwo marki piwo temperatura zależy od słoneczka szkło nie gra roli a niewiastom wstęp niewskazany w gazetach do niedawna pisano że piwo degraduje polaków w rzeczywistości było dokładnie na odwrót amerykański desant nie uratował zagrożonej cnoty polaków bo ciepłe piwo jest mimo wszystko mniej obrzydliwe od ciepłej coca coli miejsce właściwe obu napojom w kulturze i obyczaju wyznaczył dopiero wolny rynek\"\n",
     "text_clean = create_model_input_output(expected)[0]\n",
     "\n",
     "inputs = tokenizer(text_clean, return_tensors=\"pt\")"
@@ -72,7 +75,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 17,
+   "execution_count": 5,
    "metadata": {
     "tags": []
    },
@@ -88,18 +91,40 @@
       "text/plain": "<All keys matched successfully>"
      },
      "metadata": {},
-     "execution_count": 17
+     "execution_count": 5
     }
    ],
    "source": [
     "model = BertForTokenClassification.from_pretrained(MODEL_BASE, num_labels=4)\n",
     "device = torch.device(\"cpu\")\n",
-    "model.load_state_dict(torch.load(\"../checkpoints/actions/0-2000.model\", map_location=device))\n"
+    "model.load_state_dict(torch.load(\"../checkpoints/actions/0-100.model\", map_location=device))\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 57,
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [
+    {
+     "output_type": "execute_result",
+     "data": {
+      "text/plain": "'w niedzielę, kiedy prażanie są precz i moczą nogi wieku Bereuce na wzgórze zamkowe Królów Czeskich napierają watahy cudzoziemców podnieconych obecnością w miejscu ważnym dla historii od dziesięciu wieków i ponownie od czasu, kiedy na Hrad wyniesiono Vaclava Havla Zamek chwilowo nie ma gospodarza ale wszystko jest pod parą i Gwardia gotowa zawsze do stania u boku Katedra Świętego Wita jest dla Czechów tym, czy dla polaków Katedra na Wawelu Żelazny szlak turystyczny prowadzi złotą uliczką gdzie zaczyna boleć głowa od nadmiaru okazji do Pamiątkowego zdjęcia w tym małym domu mieszkał wielki Kafka zanim się dostaniemy na lewy brzeg wełtawy wypada spędzić dłuższą chwilę na samym moście bujnie opiewanym, i jeśli nie najpiękniejszym, to z pewnością jedynym w roli tak ciągle twórczej dla folkloru wielkiego miasta najpierw śmierć uderza młoteczkiem w dzwonki, a potem defiluje i przedstawia się turystom dwanaście apostołów zegar Orloja jest najwyższą z wiekowych atrakcji na rynku starego miasta Nie budujemy drugich czech choć bliżej nam do Pragi niż do Tokio i już tysiąc lat temu zaczął święty wojciech ciekawe tylko czy Jan Hus przetrwałby na pomniku w drugich Czechach i czy miałby tyle szans, co święty Wacław przy wiślanym szlaku od mostu Poniatowskiego w dół rzeki trzyma się bar pod rurą i może jeszcze ze dwie budy, gdzie po dawnemu sprzedaje się piwo marki piwo temperatura zależy od słoneczka szkło nie gra roli, a niewiastom wstęp niewskazany w gazetach do niedawna pisano, że piwo degraduje polaków w rzeczywistości było dokładnie na odwrót Amerykański Desant Nie uratował zagrożonej cnoty Polaków, bo ciepłe piwo jest mimo wszystko mniej obrzydliwe od ciepłej coca coli miejsce właściwe obu napojom w kulturze i obyczaju wyznaczył dopiero wolny rynek'"
+     },
+     "metadata": {},
+     "execution_count": 57
+    }
+   ],
+   "source": [
+    "from src.pipelines.actions_based.processing import apply_actions_punctuation\n",
+    "\n",
+    "apply_actions_punctuation(text_clean, 10, tokenizer, model, 0.9)"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 21,
+   "execution_count": 47,
    "metadata": {
     "tags": []
    },
@@ -107,16 +132,30 @@
     {
      "output_type": "stream",
      "name": "stdout",
-     "text": "w niedzielę kiedy prażanie są precz i moczą nogi wieku bereuce na wzgórze zamkowe królów czeskich napierają watahy cudzoziemców podnieconych obecnością w miejscu ważnym dla historii od dziesięciu wieków i ponownie od czasu kiedy na hrad wyniesiono vaclava havla zamek chwilowo nie ma gospodarza ale wszystko jest pod parą i gwardia gotowa zawsze do stania u boku katedra świętego wita jest dla czechów tym czy dla polaków katedra na wawelu żelazny szlak turystyczny prowadzi złotą uliczką gdzie zaczyna boleć głowa od nadmiaru okazji do pamiątkowego zdjęcia w tym małym domu mieszkał wielki kafka zanim się dostaniemy na lewy brzeg wełtawy wypada spędzić dłuższą chwilę na samym moście bujnie opiewanym i jeśli nie najpiękniejszym to z pewnością jedynym w roli tak ciągle twórczej dla folkloru wielkiego miasta najpierw śmierć uderza młoteczkiem w dzwonki a potem defiluje i przedstawia się turystom dwanaście apostołów zegar orloja jest najwyższą z wiekowych atrakcji na rynku starego miasta nie budujemy drugich czech choć bliżej nam do pragi niż do tokio i już tysiąc lat temu zaczął święty wojciech ciekawe tylko czy jan hus przetrwałby na pomniku w drugich czechach i czy miałby tyle szans co święty wacław przy wiślanym szlaku od mostu poniatowskiego w dół rzeki trzyma się bar pod rurą i może jeszcze ze dwie budy gdzie po dawnemu sprzedaje się piwo marki piwo temperatura zależy od słoneczka szkło nie gra roli a niewiastom wstęp niewskazany w gazetach do niedawna pisano że piwo degraduje polaków w rzeczywistości było dokładnie na odwrót amerykański desant nie uratował zagrożonej cnoty polaków bo ciepłe piwo jest mimo wszystko mniej obrzydliwe od ciepłej coca coli miejsce właściwe obu napojom w kulturze i obyczaju wyznaczył dopiero wolny rynek \n------\nW niedzielę, kiedy prażanie są precz I moczą nogi, wieku Bereuce na Wzgórze zamkowe Królów czeskich napierają watahy cudzoziemców, podnieconych obecnością W miejscu, ważnym, Dla historii Od dziesięciu wieków I ponownie, Od czasu, kiedy Na Hrad wyniesiono vaclava Havla Zamek chwilowo Nie ma gospodarza, Ale wszystko jest Pod parą I gwardia gotowa zawsze Do stania U boku. Katedra Świętego wita jest Dla Czechów tym, czy Dla Polaków Katedra Na Wawelu Żelazny Szlak turystyczny prowadzi Złotą uliczką, Gdzie zaczyna boleć głowa Od nadmiaru, okazji, Do pamiątkowego zdjęcia. W tym Małym domu mieszkał Wielki kafka Zanim się dostaniemy na lewy brzeg, wełtawy wypada spędzić dłuższą chwilę Na samym moście, bujnie opiewanym i, jeśli nie najpiękniejszym, to z pewnością, jedynym W roli tak ciągle, twórczej Dla folkloru, Wielkiego miasta. Najpierw śmierć uderza Młoteczkiem w dzwonki, A potem defiluje I przedstawia się turystom Dwanaście apostołów. Zegar Orloja jest najwyższą z wiekowych atrakcji. Na rynku starego miasta. Nie budujemy drugich czech Choć bliżej nam, do Pragi Niż do Tokio I już tysiąc lat temu, zaczął Święty wojciech Ciekawe tylko, czy Jan Hus przetrwałby Na pomniku W Drugich Czechach I Czy miałby tyle szans, co Święty Wacław Przy Wiślanym szlaku Od Mostu Poniatowskiego W dół rzeki Trzyma się bar Pod rurą I może jeszcze ze dwie budy, gdzie, po dawnemu, sprzedaje się piwo, marki piwo. Temperatura zależy od słoneczka Szkło nie gra roli, a niewiastom wstęp niewskazany. W gazetach Do niedawna pisano, że piwo degraduje Polaków W rzeczywistości, było, dokładnie Na odwrót. 
Amerykański desant Nie uratował zagrożonej cnoty polaków Bo ciepłe piwo, jest, Mimo wszystko, mniej obrzydliwe, Od ciepłej, coca coli miejsce właściwe Obu napojom, W kulturze, i obyczaju, Wyznaczył dopiero wolny rynek\n"
+     "text": "(277, 4)\nw niedzielę kiedy prażanie są precz i moczą nogi wieku bereuce na wzgórze zamkowe królów czeskich napierają watahy cudzoziemców podnieconych obecnością w miejscu ważnym dla historii od dziesięciu wieków i ponownie od czasu kiedy na hrad wyniesiono vaclava havla zamek chwilowo nie ma gospodarza ale wszystko jest pod parą i gwardia gotowa zawsze do stania u boku katedra świętego wita jest dla czechów tym czy dla polaków katedra na wawelu żelazny szlak turystyczny prowadzi złotą uliczką gdzie zaczyna boleć głowa od nadmiaru okazji do pamiątkowego zdjęcia w tym małym domu mieszkał wielki kafka zanim się dostaniemy na lewy brzeg wełtawy wypada spędzić dłuższą chwilę na samym moście bujnie opiewanym i jeśli nie najpiękniejszym to z pewnością jedynym w roli tak ciągle twórczej dla folkloru wielkiego miasta najpierw śmierć uderza młoteczkiem w dzwonki a potem defiluje i przedstawia się turystom dwanaście apostołów zegar orloja jest najwyższą z wiekowych atrakcji na rynku starego miasta nie budujemy drugich czech choć bliżej nam do pragi niż do tokio i już tysiąc lat temu zaczął święty wojciech ciekawe tylko czy jan hus przetrwałby na pomniku w drugich czechach i czy miałby tyle szans co święty wacław przy wiślanym szlaku od mostu poniatowskiego w dół rzeki trzyma się bar pod rurą i może jeszcze ze dwie budy gdzie po dawnemu sprzedaje się piwo marki piwo temperatura zależy od słoneczka szkło nie gra roli a niewiastom wstęp niewskazany w gazetach do niedawna pisano że piwo degraduje polaków w rzeczywistości było dokładnie na odwrót amerykański desant nie uratował zagrożonej cnoty polaków bo ciepłe piwo jest mimo wszystko mniej obrzydliwe od ciepłej coca coli miejsce właściwe obu napojom w kulturze i obyczaju wyznaczył dopiero wolny rynek \n------\n"
+    },
+    {
+     "output_type": "error",
+     "ename": "NameError",
+     "evalue": "name 'actions' is not defined",
+     "traceback": [
+      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
+      "\u001b[0;32m<ipython-input-47-43658d364fa3>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     10\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mexpected\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     11\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"------\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 12\u001b[0;31m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrecover_text\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtext_clean\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mactions\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
+      "\u001b[0;31mNameError\u001b[0m: name 'actions' is not defined"
+     ]
     }
    ],
    "source": [
     "from src.processing import token_labels_to_word_labels, recover_text\n",
     "\n",
     "y_pred = model(**inputs)[0].sigmoid()\n",
+    "y_pred = y_pred > 0.9\n",
+    "\n",
     "labels_pred = token_labels_to_word_labels(text_clean, y_pred.detach().numpy()[0, 1:-1, :], tokenizer)\n",
     "\n",
-    "actions = labels_pred > 0.8\n",
+    "print(labels_pred.shape)\n",
+    "\n",
     "print(expected)\n",
     "print(\"------\")\n",
     "print(recover_text(text_clean, actions))"
diff --git a/params.yaml b/params.yaml
index 93c1563..3b9977d 100644
--- a/params.yaml
+++ b/params.yaml
@@ -31,7 +31,7 @@ actions:
         learning_rate: 0.0001
         num_epochs: 5
         batch_size: 2
-        save_step: 1000
+        save_step: 100
         max_training_time: null
         loss_averaging_span: 1000
         fresh_start: true
diff --git a/main.py b/punctuate.py
similarity index 73%
rename from main.py
rename to punctuate.py
index e01ced9..f46c216 100755
--- a/main.py
+++ b/punctuate.py
@@ -9,18 +9,10 @@ if __name__ == "__main__":
         description="Adds punctuaiton in to raw text stream."
     )
     parser.add_argument(
-        "-i",
-        "--input",
-        type=str,
-        required=True,
-        help="Path to input text file",
+        "-i", "--input", type=str, required=True, help="Path to input text file",
     )
     parser.add_argument(
-        "-o",
-        "--output",
-        type=str,
-        required=True,
-        help="Path to input text file",
+        "-o", "--output", type=str, required=True, help="Path to input text file",
     )
     parser.add_argument(
         "-m",
diff --git a/src/batch_loading.py b/src/batch_loading.py
index 7821cc4..cb569f1 100644
--- a/src/batch_loading.py
+++ b/src/batch_loading.py
@@ -88,9 +88,7 @@ def get_batches(
 
         for i in range(batch_buffer_len):
             batch_ids = buffer_ids[
-                range(
-                    i * batch_size, min((i + 1) * batch_size, len(buffer_ids))
-                )
+                range(i * batch_size, min((i + 1) * batch_size, len(buffer_ids)))
             ]
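+            # min() clips the final slice so a partially filled buffer still
+            # yields a (smaller) batch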
 
             yield buffer.loc[batch_ids]
diff --git a/src/models/TransformerSeq2Seq.py b/src/models/TransformerSeq2Seq.py
index 5c5cefe..00cedf9 100644
--- a/src/models/TransformerSeq2Seq.py
+++ b/src/models/TransformerSeq2Seq.py
@@ -23,8 +23,7 @@ class PositionalEncoding(nn.Module):
         pe = torch.zeros(max_len, d_model)
         position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
         div_term = torch.exp(
-            torch.arange(0, d_model, 2).float()
-            * (-math.log(10000.0) / d_model)
+            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
         )
         pe[:, 0::2] = torch.sin(position * div_term)
         pe[:, 1::2] = torch.cos(position * div_term)
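+        # standard sinusoidal positional encoding: sin on even dimensions,
+        # cos on odd ones, with geometrically decreasing frequencies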
@@ -65,9 +64,7 @@ class TransformerSeq2Seq(nn.Module):
         self.word_embedding = nn.Embedding(vocab_size, embedding_size)
 
         # Add positional encoding
-        self.position_embedding = PositionalEncoding(
-            embedding_size, max_len, dropout
-        )
+        self.position_embedding = PositionalEncoding(embedding_size, max_len, dropout)
 
         # Combined encoder-decoder step
         self.core = nn.Transformer(
@@ -83,10 +80,7 @@ class TransformerSeq2Seq(nn.Module):
         self.embedding_to_words = nn.Linear(embedding_size, vocab_size)
 
     def forward(
-        self,
-        source: torch.Tensor,
-        target: torch.Tensor,
-        source_mask: torch.Tensor,
+        self, source: torch.Tensor, target: torch.Tensor, source_mask: torch.Tensor,
     ) -> torch.Tensor:
         """Full encoder-decoder pass
 
@@ -108,9 +102,7 @@ class TransformerSeq2Seq(nn.Module):
         y = self.word_embedding(y)
         y = self.position_embedding(y)
 
-        tgt_mask = self.core.generate_square_subsequent_mask(y.shape[0]).to(
-            y.device
-        )
+        tgt_mask = self.core.generate_square_subsequent_mask(y.shape[0]).to(y.device)
 
         z = self.core(
             x, y, src_key_padding_mask=source_mask, tgt_mask=tgt_mask
diff --git a/src/pipelines/actions_based/processing.py b/src/pipelines/actions_based/processing.py
index 7d4a67a..17ef803 100644
--- a/src/pipelines/actions_based/processing.py
+++ b/src/pipelines/actions_based/processing.py
@@ -1,6 +1,10 @@
 from transformers import BertTokenizerFast
-from src.processing import tokenize_labeled_text, batchify_data
+from src.processing import (
+    tokenize_labeled_text,
+    batchify_data,
+    empty_action_dict,
+    encode_actions,
+    token_labels_to_word_labels,
+    recover_text,
+)
 import numpy as np
+from typing import Optional
+import torch.nn as nn
+import torch
 
 
 def expand_dims(entry: dict):
@@ -88,3 +92,92 @@ FLATTEN_DIMS_META = {
     "output_shape": object,
     "attention_mask_shape": object,
 }
+
+
+def action_vector(actions: [str]) -> np.ndarray:
+    """Transforms array of label names into an action vector.
+
+    Args:
+        actions ([str]): Actions that should be in the action vector (e.g. ["dot", "upper_case"])
+
+    Returns:
+        np.ndarray: Action vector with provided actions
+    """
+    return encode_actions({
+        "dot": "dot" in actions,
+        "upper_case": "upper_case" in actions, 
+        "colon": "colon" in actions,
+        "question_mark": "question_mark" in actions,
+    })
+
+
+def last_stop_label(labels: np.array, stop_action: np.array) -> Optional[int]:
+    """Finds the position of the last sentence ending token
+
+    Args:
+        labels (np.array): Array of token labels in the form of action vectors (LxA shape)
+        stop_action (np.array): Action vector that marks a stop token (A shape)
+
+    Returns:
+        int: Index of the last found stop token in a sentence. None if no stop token is found
+    """
+
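+    # e.g. for labels [[1,0,0,0], [0,0,0,0], [1,0,0,0]] and stop_action
+    # [1,0,0,0] the result is 2 (the index of the last matching row)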
+    assert len(labels.shape) == 2
+    assert len(stop_action.shape) == 1
+    stop_labels = np.argwhere(np.all(labels == stop_action, axis=1))
+
+    if len(stop_labels) == 0:
+        return None
+
+    return stop_labels[-1][0]
+
+
+def apply_actions_punctuation(
+    text: str,
+    chunk_size: int,
+    tokenizer: BertTokenizerFast,
+    model: nn.Module,
+    threshold: float = 0.9,
+) -> str:
+    """Adds punctuation to text using actions model
+
+    Args:
+        text (str): Raw, unpunctuated text
+        chunk_size (int): Maximum number of tokens to process at once (both memory & compute scale as ~O(n^2))
+        tokenizer (BertTokenizerFast): Tokenizer to use
+        model (nn.Module): Trained actions model
+        threshold (float, optional): Threshold above which an action is applied. Defaults to 0.9.
+
+    Returns:
+        str: Input text with punctuation and capitalization restored
+    """
+
+    text = text.strip()
+
+    tokens = tokenizer(text, return_tensors='pt')['input_ids']
+    output = None
+
+    index_start = 0
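+    # process in chunks of at most chunk_size tokens; advance only up to the
+    # last detected sentence end so sentences are not split between chunks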
+    while index_start < len(tokens[0]):
+        index_end = min(index_start + chunk_size, len(tokens[0]))
+
+        tokens_chunk = tokens[:, index_start:index_end]
+
+        raw_output = model(
+            input_ids=tokens_chunk,
+            token_type_ids=torch.zeros_like(tokens_chunk),
+            attention_mask=torch.ones_like(tokens_chunk),
+        )[0].sigmoid()
+        raw_output = raw_output[0].detach().numpy()
+
+        actions = raw_output > threshold
+        offset = last_stop_label(actions, action_vector("dot"))
+
+        # Prevent infinite loop
+        if (offset is None) or (offset == 0):
+            offset = (index_end - index_start)
+
+        if output is None:
+            output = raw_output[0:offset]
+        else:
+            output = np.concatenate([output, raw_output[0:offset]], axis=0)
+
+        index_start += offset
+
+    assert len(output) == len(tokens[0])
+
+    word_labels = token_labels_to_word_labels(text, output[1:-1], tokenizer)
+    actions = word_labels > threshold
+
+    return recover_text(text, actions)
diff --git a/src/pipelines/actions_based/stage1_extraction.py b/src/pipelines/actions_based/stage1_extraction.py
index f6735d0..cf31c06 100644
--- a/src/pipelines/actions_based/stage1_extraction.py
+++ b/src/pipelines/actions_based/stage1_extraction.py
@@ -47,9 +47,7 @@ if __name__ == "__main__":
     print(f"Dashboard: {client.dashboard_link}")
 
     # Processing pipeline
-    df = dd.from_pandas(
-        pd.DataFrame({"file": files_paths}), npartitions=num_partitions
-    )
+    df = dd.from_pandas(pd.DataFrame({"file": files_paths}), npartitions=num_partitions)
 
     df = df.apply(
         process_file,
diff --git a/src/pipelines/actions_based/stage3_exploding.py b/src/pipelines/actions_based/stage3_exploding.py
index 9cff62b..2c270ce 100644
--- a/src/pipelines/actions_based/stage3_exploding.py
+++ b/src/pipelines/actions_based/stage3_exploding.py
@@ -24,14 +24,10 @@ if __name__ == "__main__":
 
     df = dd.read_parquet(INPUT_FOLDER, engine="pyarrow")
 
-    df = df.apply(
-        expand_dims, result_type="expand", axis=1, meta=EXPAND_DIMS_META
-    )
+    df = df.apply(expand_dims, result_type="expand", axis=1, meta=EXPAND_DIMS_META)
     df = df.map_partitions(
         lambda x: x.apply(lambda y: y.explode(), axis=0), meta=EXPAND_DIMS_META
     )
-    df = df.apply(
-        flatten_dims, result_type="expand", axis=1, meta=FLATTEN_DIMS_META
-    )
+    df = df.apply(flatten_dims, result_type="expand", axis=1, meta=FLATTEN_DIMS_META)
 
     df.to_parquet(OUTPUT_FOLDER, engine="pyarrow")
diff --git a/src/pipelines/actions_based/stage5_stats.py b/src/pipelines/actions_based/stage5_stats.py
index 5a69d01..3dd895b 100644
--- a/src/pipelines/actions_based/stage5_stats.py
+++ b/src/pipelines/actions_based/stage5_stats.py
@@ -39,9 +39,7 @@ if __name__ == "__main__":
     print(client.dashboard_link)
 
     df = dd.read_parquet(INPUT_FOLDER, engine="pyarrow")
-    df = df.apply(
-        expand_dims, result_type="expand", axis=1, meta=EXPAND_DIMS_META
-    )
+    df = df.apply(expand_dims, result_type="expand", axis=1, meta=EXPAND_DIMS_META)
 
     outputs_bag = df["output"].to_bag()
 
diff --git a/src/pipelines/actions_based/train.py b/src/pipelines/actions_based/train.py
index 266d422..ac54cef 100755
--- a/src/pipelines/actions_based/train.py
+++ b/src/pipelines/actions_based/train.py
@@ -12,6 +12,7 @@ from src.utils import (
     convert_to_timedelta,
     prepare_folder,
 )
+from src.training import latest_model, save_training_step
 from src.processing import ACTIONS_KEYS
 from datetime import datetime
 import pickle
@@ -32,8 +33,10 @@ if __name__ == "__main__":
     device_name = config["actions"]["training"]["device"]
     max_train_time = config["actions"]["training"]["max_training_time"]
     base_model = config["global"]["base_model"]
+    seed = config["global"]["random_seed"]
 
     prepare_folder(OUTPUT_PATH)
+    np.random.seed(seed=seed)
 
     if max_train_time is not None:
         max_train_time = convert_to_timedelta(max_train_time)
@@ -61,33 +64,25 @@ if __name__ == "__main__":
     sample_start = 0
     if fresh_start is False:
         checkpoint_files = glob.glob(f"{OUTPUT_PATH}/*.model")
-        furthest_epoch = -1
-        furthest_batch_num = -1
-        for checkpoint_file in checkpoint_files:
-            filename = checkpoint_file.split("/")[-1].split(".")[0]
-            epoch, iteration = filename.split("-")
-            epoch, iteration = int(epoch), int(iteration)
-
-            if epoch >= furthest_epoch:
-                furthest_epoch = epoch
-                furthest_batch_num = max(iteration, furthest_batch_num)
-
-        if furthest_epoch > -1 and furthest_batch_num > -1:
+        latest = latest_model(checkpoint_files)
+
+        if latest is not None:
+            epoch, batch = latest
             model.load_state_dict(
                 torch.load(
-                    f"{OUTPUT_PATH}/{furthest_epoch}-{furthest_batch_num}.model",
+                    f"{OUTPUT_PATH}/{epoch}-{batch}.model",
                     map_location=device,
                 )
             )
             optimizer.load_state_dict(
                 torch.load(
-                    f"{OUTPUT_PATH}/{furthest_epoch}-{furthest_batch_num}.optimizer",
+                    f"{OUTPUT_PATH}/{epoch}-{batch}.optimizer",
                     map_location=device,
                 )
             )
 
-            epoch_start, sample_start = furthest_epoch, furthest_batch_num
-            print(f"Loaded {furthest_epoch}-{furthest_batch_num}")
+            epoch_start, sample_start = epoch, batch
+            print(f"Loaded {epoch}-{batch}")
 
     model.train()
     model.base_model.train()
@@ -107,9 +102,7 @@ if __name__ == "__main__":
             break
 
         i = sample_start
-        for data_batch in get_batches(
-            df, batch_size, 100, random_index_shuffle, i
-        ):
+        for data_batch in get_batches(df, batch_size, 100, random_index_shuffle, i):
             inputs = data_batch.apply(
                 lambda x: x["input"].reshape(x["input_shape"]), axis=1
             ).values
@@ -117,17 +110,13 @@ if __name__ == "__main__":
                 lambda x: x["output"].reshape(x["output_shape"]), axis=1
             ).values
             attentions_mask = data_batch.apply(
-                lambda x: x["attention_mask"].reshape(
-                    x["attention_mask_shape"]
-                ),
+                lambda x: x["attention_mask"].reshape(x["attention_mask_shape"]),
                 axis=1,
             ).values
 
             inputs = torch.tensor(np.stack(inputs).squeeze()).to(device)
             outputs = torch.tensor(np.stack(outputs)).to(device)
-            attentions_mask = torch.tensor(np.stack(attentions_mask)).to(
-                device
-            )
+            attentions_mask = torch.tensor(np.stack(attentions_mask)).to(device)
 
             y_pred = model(input_ids=inputs, attention_mask=attentions_mask)[0]
 
@@ -141,27 +130,13 @@ if __name__ == "__main__":
 
             optimizer.zero_grad()
 
-            if i % save_step == 0 and (
-                i != sample_start or epoch != epoch_start
-            ):
+            if i % save_step == 0 and (i != sample_start or epoch != epoch_start):
                 print(f"Saving: Epoch {epoch}, step {i}")
-                torch.save(
-                    model.state_dict(), f"{OUTPUT_PATH}/{epoch}-{i}.model"
-                )
-                torch.save(
-                    optimizer.state_dict(),
-                    f"{OUTPUT_PATH}/{epoch}-{i}.optimizer",
-                )
+                save_training_step(OUTPUT_PATH, f"{epoch}-{i}", model, optimizer)
 
             if datetime.now() > time_max:
                 print(f"Max time reached, saving: Epoch {epoch}, step {i}")
-                torch.save(
-                    model.state_dict(), f"{OUTPUT_PATH}/{epoch}-{i}.model"
-                )
-                torch.save(
-                    optimizer.state_dict(),
-                    f"{OUTPUT_PATH}/{epoch}-{i}.optimizer",
-                )
+                save_training_step(OUTPUT_PATH, f"{epoch}-{i}", model, optimizer)
                 training_stopped = True
                 break
 
@@ -171,5 +146,4 @@ if __name__ == "__main__":
             i += 1
 
     if not training_stopped:
-        torch.save(model.state_dict(), f"{OUTPUT_PATH}/final.model")
-        torch.save(optimizer.state_dict(), f"{OUTPUT_PATH}/final.optimizer")
+        save_training_step(OUTPUT_PATH, "final", model, optimizer)
diff --git a/src/pipelines/translation_based/processing.py b/src/pipelines/translation_based/processing.py
index 3b02a4d..4dfddc1 100644
--- a/src/pipelines/translation_based/processing.py
+++ b/src/pipelines/translation_based/processing.py
@@ -170,9 +170,7 @@ def find_new_sentence_right(seq: np.array, pos: int) -> int:
     return None
 
 
-def get_batch_indexes(
-    seq: np.array, min_length: int, max_length: int
-) -> [np.array]:
+def get_batch_indexes(seq: np.array, min_length: int, max_length: int) -> [np.array]:
     """Turns long sequence into array of indices, composing a single batch file.
 
     Args:
@@ -212,9 +210,7 @@ def get_batch_indexes(
     return batch
 
 
-def add_padding(
-    seq: np.ndarray, total_length: int, padding_symbol: any
-) -> np.ndarray:
+def add_padding(seq: np.ndarray, total_length: int, padding_symbol: any) -> np.ndarray:
     """Pads a sequence with provided symbol, to get array of length total_length in the end
 
     Args:
@@ -229,9 +225,7 @@ def add_padding(
     assert num_padding >= 0
 
     if num_padding > 0:
-        return np.concatenate(
-            [seq, np.array([padding_symbol] * num_padding)], axis=0
-        )
+        return np.concatenate([seq, np.array([padding_symbol] * num_padding)], axis=0)
     else:
         return np.copy(seq)
 
@@ -273,9 +267,7 @@ def standarize_translation_sample(
         np.ndarray: Output sequence of length total_length
     """
     return add_padding(
-        add_begin_end_tokens(seq, begin_token, end_token),
-        total_length,
-        padding_symbol,
+        add_begin_end_tokens(seq, begin_token, end_token), total_length, padding_symbol,
     )
 
 
@@ -344,9 +336,7 @@ def crete_input_output_batch(
     source_batch = []
     target_batch = []
     for entry in base_batch:
-        source_entry, target_entry = create_input_output(
-            entry, length, tokenizer
-        )
+        source_entry, target_entry = create_input_output(entry, length, tokenizer)
 
         source_batch.append(source_entry)
         target_batch.append(target_entry)
diff --git a/src/pipelines/translation_based/stage1_extraction.py b/src/pipelines/translation_based/stage1_extraction.py
index 5f2758f..8581407 100644
--- a/src/pipelines/translation_based/stage1_extraction.py
+++ b/src/pipelines/translation_based/stage1_extraction.py
@@ -32,15 +32,10 @@ if __name__ == "__main__":
     print(f"Dashboard: {client.dashboard_link}")
 
     # Processing pipeline
-    df = dd.from_pandas(
-        pd.DataFrame({"file": files_paths}), npartitions=num_partitions
-    )
+    df = dd.from_pandas(pd.DataFrame({"file": files_paths}), npartitions=num_partitions)
 
     df = df.apply(
-        raw_to_dataframe,
-        result_type="expand",
-        axis=1,
-        meta=RAW_TO_DATAFRAME_META,
+        raw_to_dataframe, result_type="expand", axis=1, meta=RAW_TO_DATAFRAME_META,
     )
     df = df.dropna()
 
diff --git a/src/pipelines/translation_based/stage2_create_batches.py b/src/pipelines/translation_based/stage2_create_batches.py
index 9b65e08..c1b7b0d 100644
--- a/src/pipelines/translation_based/stage2_create_batches.py
+++ b/src/pipelines/translation_based/stage2_create_batches.py
@@ -16,9 +16,7 @@ if __name__ == "__main__":
 
     config = get_config()
     num_workers = config["translations"]["create_batches"]["num_workers"]
-    memory_limit = config["translations"]["create_batches"][
-        "worker_memory_limit"
-    ]
+    memory_limit = config["translations"]["create_batches"]["worker_memory_limit"]
     min_tokens = config["translations"]["create_batches"]["min_tokens"]
     max_tokens = config["translations"]["create_batches"]["max_tokens"]
     base_model = config["global"]["base_model"]
diff --git a/src/pipelines/translation_based/stage3_exploding.py b/src/pipelines/translation_based/stage3_exploding.py
index 4123abd..74646a5 100644
--- a/src/pipelines/translation_based/stage3_exploding.py
+++ b/src/pipelines/translation_based/stage3_exploding.py
@@ -24,14 +24,10 @@ if __name__ == "__main__":
 
     df = dd.read_parquet(INPUT_FOLDER, engine="pyarrow")
 
-    df = df.apply(
-        expand_dims, result_type="expand", axis=1, meta=EXPAND_DIMS_META
-    )
+    df = df.apply(expand_dims, result_type="expand", axis=1, meta=EXPAND_DIMS_META)
     df = df.map_partitions(
         lambda x: x.apply(lambda y: y.explode(), axis=0), meta=EXPAND_DIMS_META
     )
-    df = df.apply(
-        flatten_dims, result_type="expand", axis=1, meta=FLATTEN_DIMS_META
-    )
+    df = df.apply(flatten_dims, result_type="expand", axis=1, meta=FLATTEN_DIMS_META)
 
     df.to_parquet(OUTPUT_FOLDER, engine="pyarrow")
diff --git a/src/pipelines/translation_based/train.py b/src/pipelines/translation_based/train.py
index fa7f838..8965bda 100755
--- a/src/pipelines/translation_based/train.py
+++ b/src/pipelines/translation_based/train.py
@@ -11,8 +11,10 @@ from src.utils import (
     convert_to_timedelta,
     prepare_folder,
 )
+from src.training import latest_model, save_training_step
 from datetime import datetime
 from src.models.TransformerSeq2Seq import TransformerSeq2Seq
+from src.batch_loading import get_batches, get_ordered_dataframe_len
 
 INPUT_PATH = f"{PROJECT_ROOT}/generated/translations/stage4_reindexing"
 OUTPUT_PATH = f"{PROJECT_ROOT}/checkpoints/translations"
@@ -24,15 +26,15 @@ if __name__ == "__main__":
     num_epochs = config["translations"]["training"]["num_epochs"]
     batch_size = config["translations"]["training"]["batch_size"]
     save_step = config["translations"]["training"]["save_step"]
-    loss_averaging_span = config["translations"]["training"][
-        "loss_averaging_span"
-    ]
+    loss_averaging_span = config["translations"]["training"]["loss_averaging_span"]
     fresh_start = config["translations"]["training"]["fresh_start"]
     device_name = config["translations"]["training"]["device"]
     max_train_time = config["translations"]["training"]["max_training_time"]
     base_model = config["global"]["base_model"]
+    seed = config["global"]["random_seed"]
 
     prepare_folder(OUTPUT_PATH)
+    np.random.seed(seed=seed)
 
     if max_train_time is not None:
         max_train_time = convert_to_timedelta(max_train_time)
@@ -44,9 +46,7 @@ if __name__ == "__main__":
 
     tokenizer = BertTokenizerFast.from_pretrained(base_model)
 
-    model = TransformerSeq2Seq(
-        tokenizer.vocab_size, 256, max_len, 4, 4, 4,
-    ).to(device)
+    model = TransformerSeq2Seq(tokenizer.vocab_size, 256, max_len, 4, 4, 4,).to(device)
     criterion = torch.nn.CrossEntropyLoss(reduction="mean").to(device)
     optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
 
@@ -54,35 +54,33 @@ if __name__ == "__main__":
     sample_start = 0
     if fresh_start is False:
         checkpoint_files = glob.glob(f"{OUTPUT_PATH}/*.model")
-        furthest_epoch = -1
-        furthest_batch_num = -1
-        for checkpoint_file in checkpoint_files:
-            filename = checkpoint_file.split("/")[-1].split(".")[0]
-            epoch, iteration = filename.split("-")
-            epoch, iteration = int(epoch), int(iteration)
-
-            if epoch >= furthest_epoch:
-                furthest_epoch = epoch
-                furthest_batch_num = max(iteration, furthest_batch_num)
-
-        if furthest_epoch > -1 and furthest_batch_num > -1:
+        latest = latest_model(checkpoint_files)
+
+        if latest is not None:
+            epoch, batch = latest
             model.load_state_dict(
                 torch.load(
-                    f"{OUTPUT_PATH}/{furthest_epoch}-{furthest_batch_num}.model"
+                    f"{OUTPUT_PATH}/{epoch}-{batch}.model",
+                    map_location=device,
                 )
             )
             optimizer.load_state_dict(
                 torch.load(
-                    f"{OUTPUT_PATH}/{furthest_epoch}-{furthest_batch_num}.optimizer"
+                    f"{OUTPUT_PATH}/{epoch}-{batch}.optimizer",
+                    map_location=device,
                 )
             )
 
-            epoch_start, sample_start = furthest_epoch, furthest_batch_num
-            print(f"Loaded {furthest_epoch}-{furthest_batch_num}")
+            epoch_start, sample_start = epoch, batch
+            print(f"Loaded {epoch}-{batch}")
 
     model.train()
+    model.base_model.train()
     losses = []
 
+    num_samples = get_ordered_dataframe_len(df)
+    random_index_shuffle = np.random.permutation(range(num_samples))
+
     training_stopped = False
 
     time_max = datetime.max
@@ -94,20 +92,7 @@ if __name__ == "__main__":
             break
 
         i = sample_start
-
-        while True:
-            data_batch_indexes = list(
-                range(i * batch_size, i * batch_size + batch_size)
-            )
-
-            # Precomputing total number of samples takes very long, so lets
-            # try to get next batch until fail :)
-            try:
-                data_batch = df.loc[data_batch_indexes].compute()
-            except Exception:
-                # TODO: Specify exception type
-                break
-
+        for data_batch in get_batches(df, batch_size, 100, random_index_shuffle, i):
             inputs = data_batch.apply(
                 lambda x: x["source"].reshape(x["source_shape"]), axis=1
             ).values
@@ -115,18 +100,14 @@ if __name__ == "__main__":
                 lambda x: x["target"].reshape(x["target_shape"]), axis=1
             ).values
             attentions_mask = data_batch.apply(
-                lambda x: x["attention_mask"].reshape(
-                    x["attention_mask_shape"]
-                ),
+                lambda x: x["attention_mask"].reshape(x["attention_mask_shape"]),
                 axis=1,
             ).values
 
-            inputs = torch.tensor(
-                np.stack(inputs, axis=0), dtype=torch.long
-            ).to(device)
-            attentions_mask = torch.tensor(
-                np.stack(attentions_mask, axis=0) == 0
-            ).to(device)
+            inputs = torch.tensor(np.stack(inputs, axis=0), dtype=torch.long).to(device)
+            attentions_mask = torch.tensor(np.stack(attentions_mask, axis=0) == 0).to(
+                device
+            )
             output_indices = torch.tensor(
                 np.stack(outputs, axis=0), dtype=torch.long
             ).to(device)
@@ -144,27 +125,13 @@ if __name__ == "__main__":
 
             optimizer.zero_grad()
 
-            if i % save_step == 0 and (
-                i != sample_start or epoch != epoch_start
-            ):
+            if i % save_step == 0 and (i != sample_start or epoch != epoch_start):
                 print(f"Saving: Epoch {epoch}, step {i}")
-                torch.save(
-                    model.state_dict(), f"{OUTPUT_PATH}/{epoch}-{i}.model"
-                )
-                torch.save(
-                    optimizer.state_dict(),
-                    f"{OUTPUT_PATH}/{epoch}-{i}.optimizer",
-                )
+                save_training_step(OUTPUT_PATH, f"{epoch}-{i}", model, optimizer)
 
             if datetime.now() > time_max:
                 print(f"Max time reached, saving: Epoch {epoch}, step {i}")
-                torch.save(
-                    model.state_dict(), f"{OUTPUT_PATH}/{epoch}-{i}.model"
-                )
-                torch.save(
-                    optimizer.state_dict(),
-                    f"{OUTPUT_PATH}/{epoch}-{i}.optimizer",
-                )
+                save_training_step(OUTPUT_PATH, f"{epoch}-{i}", model, optimizer)
                 training_stopped = True
                 break
 
@@ -174,5 +141,4 @@ if __name__ == "__main__":
             i += 1
 
     if not training_stopped:
-        torch.save(model.state_dict(), f"{OUTPUT_PATH}/final.model")
-        torch.save(optimizer.state_dict(), f"{OUTPUT_PATH}/final.optimizer")
+        save_training_step(OUTPUT_PATH, "final", model, optimizer)
diff --git a/src/processing.py b/src/processing.py
index 7b14493..d61be86 100644
--- a/src/processing.py
+++ b/src/processing.py
@@ -17,6 +17,16 @@ def empty_action_vector() -> np.ndarray:
     return np.zeros(len(ACTIONS_KEYS))
 
 
+def empty_action_dict() -> dict:
+    """Returns a do-noting unencoded action dict
+
+    Returns:
+        dict: Action dict with all actions set to False
+    """
+
+    return decode_actions(empty_action_vector())
+
+
 def text_from_xml(path: str) -> str:
     """Extract spoken text from dataset's xml format
 
@@ -61,9 +71,7 @@ def detect_actions(word: str, next_word: Optional[str]) -> Mapping[str, bool]:
     word.replace("(", "")
     word.replace(")", "")
 
-    while (
-        len(word) > 0 and not word[0].isalnum()
-    ):  # remove proceding characters
+    while len(word) > 0 and not word[0].isalnum():  # remove leading non-alphanumeric characters
         word = word[1:]
 
     if len(word) == 0:
@@ -140,9 +148,7 @@ def create_model_input_output(text: str) -> (str, np.ndarray):
     return " ".join(words_output), np.array(actions_output)
 
 
-def token_word_mapping(
-    text: str, tokenizer: PreTrainedTokenizerFast
-) -> np.ndarray:
+def token_word_mapping(text: str, tokenizer: PreTrainedTokenizerFast) -> np.ndarray:
     """Returns mapping where each token is labeled with index of word it's part of
 
     Args:
@@ -348,14 +354,10 @@ def batchify_labels(
                     )
                     break
         else:
-            labels_batches.append(
-                np.array(list(range(index, index + num_consumed)))
-            )
+            labels_batches.append(np.array(list(range(index, index + num_consumed))))
             break
 
-        labels_batches.append(
-            np.array(list(range(index, index + num_consumed)))
-        )
+        labels_batches.append(np.array(list(range(index, index + num_consumed))))
 
         index = new_index
 
@@ -408,9 +410,7 @@ def add_padding(
     assert pad_length >= 0
 
     if pad_length > 0:
-        tokens = np.concatenate(
-            [tokens, [[tokenizer.pad_token_id]] * pad_length]
-        )
+        tokens = np.concatenate([tokens, [[tokenizer.pad_token_id]] * pad_length])
         labels = np.concatenate([labels, [empty_action_vector()] * pad_length])
 
     mask = np.ones(len(tokens)).astype(np.int)
diff --git a/src/training.py b/src/training.py
new file mode 100644
index 0000000..0ef4b9c
--- /dev/null
+++ b/src/training.py
@@ -0,0 +1,58 @@
+from typing import Tuple, Optional
+import re
+import torch.nn as nn
+import torch.optim as optim
+import torch
+from src.utils import prepare_folder
+
+
+def latest_model(file_paths: [str]) -> Optional[Tuple[int, int]]:
+    """Finds newest model in directory
+
+    Args:
+        file_paths ([str]): List of all file paths that will be considered. File extensions are
+                            discarded; file names must be in the format epoch_num-batch_num.extension
+
+    Returns:
+        Optional[Tuple[int, int]]: (latest_epoch, latest_batch) of the newest model, or None if none is found
+    """
+
+    furthest_epoch = -1
+    furthest_batch_num = -1
+    for checkpoint_file in file_paths:
+        filename = checkpoint_file.split("/")[-1].split(".")[0]
+
+        result = re.search(r"^(\d+)-(\d+)$", filename)
+        if result is not None:
+            epoch, batch = [int(x) for x in result.groups()]
+
+            if epoch > furthest_epoch:
+                furthest_epoch = epoch
+                furthest_batch_num = batch
+            elif epoch == furthest_epoch:
+                furthest_batch_num = max(batch, furthest_batch_num)
+
+    if (furthest_epoch == -1) or (furthest_batch_num == -1):
+        return None
+
+    return furthest_epoch, furthest_batch_num
+
+
+def save_training_step(dir: str, name: str, model: nn.Module, optimizer: Optional[optim.Optimizer] = None, create_dir: bool = False) -> None:
+    """Saves a trainig step to a directory
+
+    Args:
+        dir (str): Directory where the step will be saved
+        name (str): Name of the step (e.g. "0-1000")
+        model (nn.Module): Model that will be saved
+        optimizer (optim.Optimizer, optional): Optimizer that will be saved. May be None
+        create_dir (bool, optional): Whether to create dir if it does not exist. Defaults to False
+    """
+    if create_dir:
+        prepare_folder(dir, wipe=False)
+
+    torch.save(model.state_dict(), f"{dir}/{name}.model")
+
+    if optimizer is not None:
+        torch.save(
+            optimizer.state_dict(), f"{dir}/{name}.optimizer",
+        )
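Taken together, the two helpers define a simple checkpoint contract: names follow the epoch-batch convention that latest_model parses back out. A minimal usage sketch, assuming a throwaway directory name and a stand-in model:

    import glob
    import torch
    import torch.nn as nn
    from src.training import latest_model, save_training_step

    model = nn.Linear(4, 4)  # stand-in model for illustration
    optimizer = torch.optim.AdamW(model.parameters())

    # Save two steps under the epoch-batch naming convention.
    save_training_step("checkpoints_demo", "0-2000", model, optimizer, create_dir=True)
    save_training_step("checkpoints_demo", "1-500", model, optimizer)

    latest = latest_model(glob.glob("checkpoints_demo/*.model"))
    assert latest == (1, 500)  # highest epoch wins, then highest batch within it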
diff --git a/src/utils.py b/src/utils.py
index 2217a36..6ee5337 100644
--- a/src/utils.py
+++ b/src/utils.py
@@ -5,9 +5,7 @@ from datetime import timedelta
 from typing import Optional
 import shutil
 
-PROJECT_ROOT = os.path.dirname(
-    os.path.realpath("/".join(__file__.split("/")) + "/..")
-)
+PROJECT_ROOT = os.path.dirname(os.path.realpath("/".join(__file__.split("/")) + "/.."))
 
 
 def get_config() -> dict:
diff --git a/tests/pipelines/actions_based/__init__.py b/tests/pipelines/actions_based/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/pipelines/actions_based/test_processing.py b/tests/pipelines/actions_based/test_processing.py
new file mode 100644
index 0000000..c7e2767
--- /dev/null
+++ b/tests/pipelines/actions_based/test_processing.py
@@ -0,0 +1,24 @@
+from src.pipelines.actions_based.processing import last_stop_label, action_vector
+from src.processing import encode_actions, empty_action_vector
+import numpy as np
+
+
+def test_action_vector():
+    expected = encode_actions({
+        "dot": True,
+        "upper_case": True,
+        "colon": False,
+        "question_mark": False,
+    })
+
+    assert np.all(action_vector(["dot", "upper_case"]) == expected)
+
+def test_last_stop_label():
+    stop_action = action_vector(["Dot"])
+    not_stop_action = action_vector(["upper_case"])
+
+    labels = np.array([not_stop_action, not_stop_action, stop_action, not_stop_action])
+
+    res = last_stop_label(labels, stop_action)
+
+    assert last_stop_label(labels, stop_action) == 2
\ No newline at end of file
diff --git a/tests/pipelines/translation_based/test_processing.py b/tests/pipelines/translation_based/test_processing.py
index b83e01b..8ef4c69 100644
--- a/tests/pipelines/translation_based/test_processing.py
+++ b/tests/pipelines/translation_based/test_processing.py
@@ -86,9 +86,7 @@ def test_standarize_translation_sample():
 
 def test_create_input_output():
     sequence = [56500, 117, 10824, 30186, 11090, 10113, 119]
-    tokenizer = BertTokenizerFast.from_pretrained(
-        "bert-base-multilingual-cased"
-    )
+    tokenizer = BertTokenizerFast.from_pretrained("bert-base-multilingual-cased")
 
     expected_output_sequence = [
         tokenizer.cls_token_id,
@@ -126,9 +124,7 @@ def test_create_input_output():
 
 
 def test_create_input_output_batch():
-    tokenizer = BertTokenizerFast.from_pretrained(
-        "bert-base-multilingual-cased"
-    )
+    tokenizer = BertTokenizerFast.from_pretrained("bert-base-multilingual-cased")
 
     expected_output_1 = np.array(tokenizer("Ala, ma KoTa.")["input_ids"])[1:-1]
     expected_output_2 = np.array(tokenizer("A kOt nie!")["input_ids"])[1:-1]
diff --git a/tests/test_batch_loading.py b/tests/test_batch_loading.py
index 3ccc510..cde65ad 100644
--- a/tests/test_batch_loading.py
+++ b/tests/test_batch_loading.py
@@ -5,7 +5,7 @@ from src.batch_loading import (
     calculate_batch_buffer_id,
     yield_batch_buffer_span,
     get_ordered_dataframe_len,
-    get_batches
+    get_batches,
 )
 
 
@@ -49,9 +49,7 @@ def test_get_batches():
     assert np.all(batches[2]["a"].values == [4, 5])
     assert np.all(batches[3]["a"].values == [6])
 
-    batches = list(
-        get_batches(df, batch_size, batch_buffer_len, shuffled_ids, 1)
-    )
+    batches = list(get_batches(df, batch_size, batch_buffer_len, shuffled_ids, 1))
 
     assert np.all(batches[1]["a"].values == [2, 3])
     assert np.all(batches[2]["a"].values == [4, 5])
diff --git a/tests/test_processing.py b/tests/test_processing.py
index 45ac686..2aeed6e 100644
--- a/tests/test_processing.py
+++ b/tests/test_processing.py
@@ -11,7 +11,7 @@ from src.processing import (
     batchify_labels,
     batchify_data,
     ACTIONS_KEYS,
-    decode_actions
+    decode_actions,
 )
 from transformers import BertTokenizerFast
 import pytest
@@ -68,9 +68,7 @@ def test_decode_actions():
 
 def test_token_word_mapping():
     text = "janek poszedł do ogrodu"
-    tokenizer = BertTokenizerFast.from_pretrained(
-        "bert-base-multilingual-cased"
-    )
+    tokenizer = BertTokenizerFast.from_pretrained("bert-base-multilingual-cased")
 
     text_tokenized = tokenizer(text)
 
@@ -84,9 +82,7 @@ def test_token_word_mapping():
 def test_token_labels_to_word_labels():
     text = "janek poszedł do ogrodu"
     labels = np.array([[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]])
-    tokenizer = BertTokenizerFast.from_pretrained(
-        "bert-base-multilingual-cased"
-    )
+    tokenizer = BertTokenizerFast.from_pretrained("bert-base-multilingual-cased")
 
     _, token_labels = tokenize_labeled_text(text, labels, tokenizer)
 
@@ -97,9 +93,7 @@ def test_token_labels_to_word_labels():
 
 def test_tokenize_labeled_text():
     text = "Janek poszedł do ogrodu. Ogród był zwierzęcy. Spotkał tam Zosię?"
-    tokenizer = BertTokenizerFast.from_pretrained(
-        "bert-base-multilingual-cased"
-    )
+    tokenizer = BertTokenizerFast.from_pretrained("bert-base-multilingual-cased")
 
     text_clean, labels = create_model_input_output(text)
     tokens, token_labels = tokenize_labeled_text(text_clean, labels, tokenizer)
@@ -173,13 +167,8 @@ def test_batchify_labels():
 
 
 def test_batchify_data():
-    text = (
-        "Janek poszedł do ogrodu. Ogród był zwierzęcy. Spotkał tam"
-        " niedzwiedzia?"
-    )
-    tokenizer = BertTokenizerFast.from_pretrained(
-        "bert-base-multilingual-cased"
-    )
+    text = "Janek poszedł do ogrodu. Ogród był zwierzęcy. Spotkał tam niedzwiedzia?"
+    tokenizer = BertTokenizerFast.from_pretrained("bert-base-multilingual-cased")
 
     text_clean, labels = create_model_input_output(text)
     tokens, token_labels = tokenize_labeled_text(text_clean, labels, tokenizer)
diff --git a/tests/test_training.py b/tests/test_training.py
new file mode 100644
index 0000000..2aa5d6a
--- /dev/null
+++ b/tests/test_training.py
@@ -0,0 +1,21 @@
+from src.training import latest_model
+
+
+def test_latest_model():
+    files = []
+    assert latest_model(files) is None
+
+    files.append("/path/tam/pam/Wrongformat.b")
+    assert latest_model(files) is None
+
+    files.append("/path/tam/pam/0-2000.b")
+    assert latest_model(files) == (0, 2000)
+
+    files.append("/path/tam/pam/0-3000.c")
+    assert latest_model(files) == (0, 3000)
+
+    files.append("/path/tam/pam/1-1000.a")
+    assert latest_model(files) == (1, 1000)
+
+    files.append("/path/tam/pam/1-500.a")
+    assert latest_model(files) == (1, 1000)
diff --git a/tox.ini b/tox.ini
index 43a673d..c7d1dde 100644
--- a/tox.ini
+++ b/tox.ini
@@ -36,7 +36,7 @@ exclude =
 max-complexity = 10
 max-line-length = 80
 select = C,E,F,W,B,B950
-ignore = E203, E501, W503
+ignore = E203, E501, W503, C901
 
 
 [testenv:pep8]
diff --git a/worker.py b/worker.py
new file mode 100755
index 0000000..49fd847
--- /dev/null
+++ b/worker.py
@@ -0,0 +1,42 @@
+#!/bin/bash
+
+import nlp_ws
+import shutil
+from src.pipelines.actions_based.processing import apply_actions_punctuation
+from transformers import BertTokenizerFast, BertForTokenClassification, PretrainedConfig
+import configparser
+from src.processing import ACTIONS_KEYS
+import torch
+
+class Worker(nlp_ws.NLPWorker):
+    """Class that implements example worker."""
+
+    def init(self):
+        self.config = configparser.ConfigParser()
+        self.config.read("config.ini")
+
+        config = PretrainedConfig.from_pretrained(self.config['deployment']['base_model'])
+        config.num_labels = len(ACTIONS_KEYS)
+
+        device = torch.device("cpu")
+        self.threshold = self.config['deployment']['threshold']
+        self.chunk_size = self.config['deployment']['chunk_size']
+        self.tokenizer = BertTokenizerFast.from_pretrained(self.config['deployment']['base_model'])
+        self.model = BertForTokenClassification(config)
+        self.model.load_state_dict(torch.load(self.config['deployment']['model'], map_location=device))
+
+    def process(self, input_file: str, task_options: dict, output_file: str) -> None:
+        """Implementation of example tasks that copies files."""
+
+        with open(input_file, 'r') as f:
+            text = f.read()
+            text_processed = apply_actions_punctuation(text, self.chunk_size, self.tokenizer, self.model, self.threshold)
+            
+        with  open(output_file, 'w') as f:
+            f.write(text_processed)
+
+        shutil.copy(input_file, output_file)
+
+
+if __name__ == '__main__':
+    nlp_ws.NLPService.main(Worker)
-- 
GitLab


From 4887892ebcda364577a9afa5360101552821d875 Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Mon, 10 Aug 2020 13:12:46 +0200
Subject: [PATCH 041/116] In theory, worker should now work

---
 Dockerfile                                    |  3 +-
 punctuate.py                                  | 47 +++++++++++++++++--
 src/pipelines/actions_based/processing.py     | 46 ++++++++++++------
 src/pipelines/actions_based/train.py          |  8 +---
 src/pipelines/actions_based/utils.py          | 30 ++++++++++++
 src/pipelines/translation_based/train.py      |  8 +---
 src/training.py                               |  8 +++-
 .../actions_based/test_processing.py          | 12 ++---
 worker.py                                     | 34 +++++++-------
 9 files changed, 137 insertions(+), 59 deletions(-)
 create mode 100644 src/pipelines/actions_based/utils.py

diff --git a/Dockerfile b/Dockerfile
index f21915e..b467b0f 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -12,5 +12,4 @@ COPY worker.py .
 RUN mkdir ./deploy && \
     wget https://minio.clarin-pl.eu/public/models/punctuation/0-190000.model -O deploy/model
     
-RUN ls ./deploy
-CMD ["./worker.py"]
\ No newline at end of file
+CMD [ "./worker.py" ]
\ No newline at end of file
diff --git a/punctuate.py b/punctuate.py
index f46c216..d411cd0 100755
--- a/punctuate.py
+++ b/punctuate.py
@@ -3,6 +3,8 @@
 import argparse
 import os
 import sys
+from src.pipelines.actions_based.utils import load_model
+from src.pipelines.actions_based.processing import apply_actions_punctuation
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
@@ -17,13 +19,48 @@ if __name__ == "__main__":
     parser.add_argument(
         "-m",
         "--model",
+        required=True,
         type=str,
-        choices=["actions", "translation"],
-        default="actions",
-        help="Selects which model will be used. Defaults to actions",
+        help="Path to the pretrained model",
+    )
+    parser.add_argument(
+        "-b",
+        "--base",
+        required=True,
+        type=str,
+        help="Name of base model",
+    )
+    parser.add_argument(
+        '-c',
+        '--chunk_size',
+        default=500,
+        type=int,
+        help="Maximum chunk size"
+    )
+    parser.add_argument(
+        '-t',
+        '--threshold',
+        default=0.9,
+        type=float,
+        help="Threshold"
     )
 
     args = parser.parse_args()
 
-    if not os.path.exists(args.input):
-        print(f"Error: File '{args.input}' does not exists", file=sys.stderr)
+    #if not os.path.exists(args.input):
+    #    print(f"Error: File '{args.input}' does not exists", file=sys.stderr)
+
+    tokenizer, model = load_model(
+        args.model,
+        args.base,
+        "cpu"
+    )
+
+    with open(args.input, "r") as f:
+        text = f.read()
+        text_processed = apply_actions_punctuation(
+            text, args.chunk_size, tokenizer, model, args.threshold
+        )
+
+    with open(args.output, "w") as f:
+        f.write(text_processed)
\ No newline at end of file
diff --git a/src/pipelines/actions_based/processing.py b/src/pipelines/actions_based/processing.py
index 17ef803..8f043e6 100644
--- a/src/pipelines/actions_based/processing.py
+++ b/src/pipelines/actions_based/processing.py
@@ -1,5 +1,10 @@
 from transformers import BertTokenizerFast
-from src.processing import tokenize_labeled_text, batchify_data, empty_action_dict, encode_actions
+from src.processing import (
+    tokenize_labeled_text,
+    batchify_data,
+    empty_action_dict,
+    encode_actions,
+)
 import numpy as np
 from typing import Optional
 import torch.nn as nn
@@ -103,12 +108,14 @@ def action_vector(actions: [str]) -> np.ndarray:
     Returns:
         np.ndarray: Action vector with provided actions
     """
-    return encode_actions({
-        "dot": "dot" in actions,
-        "upper_case": "upper_case" in actions, 
-        "colon": "colon" in actions,
-        "question_mark": "question_mark" in actions,
-    })
+    return encode_actions(
+        {
+            "dot": "dot" in actions,
+            "upper_case": "upper_case" in actions,
+            "colon": "colon" in actions,
+            "question_mark": "question_mark" in actions,
+        }
+    )
 
 
 def last_stop_label(labels: np.array, stop_action: np.array) -> Optional[int]:
@@ -122,8 +129,8 @@ def last_stop_label(labels: np.array, stop_action: np.array) -> Optional[int]:
         int: Index of the last found stop token in a sentence. None if no stop token is found
     """
 
-    assert len(labels.shape) == 2   
-    assert len(stop_action.shape) == 1 
+    assert len(labels.shape) == 2
+    assert len(stop_action.shape) == 1
     stop_labels = np.argwhere(np.all(labels == stop_action, axis=1))
 
     if len(stop_labels) == 0:
@@ -131,7 +138,14 @@ def last_stop_label(labels: np.array, stop_action: np.array) -> Optional[int]:
 
     return stop_labels[-1][0]
 
-def apply_actions_punctuation(text: str, chunk_size: int, tokenizer: BertTokenizerFast, model: nn.Module, threshold: float = 0.9) -> str:
+
+def apply_actions_punctuation(
+    text: str,
+    chunk_size: int,
+    tokenizer: BertTokenizerFast,
+    model: nn.Module,
+    threshold: float = 0.9,
+) -> str:
     """Adds punctuation to text using actions model
 
     Args:
@@ -147,7 +161,7 @@ def apply_actions_punctuation(text: str, chunk_size: int, tokenizer: BertTokeniz
 
     text = text.strip()
 
-    tokens = tokenizer(text, return_tensors='pt')['input_ids']
+    tokens = tokenizer(text, return_tensors="pt")["input_ids"]
     output = None
 
     index_start = 0
@@ -156,16 +170,19 @@ def apply_actions_punctuation(text: str, chunk_size: int, tokenizer: BertTokeniz
 
         tokens_chunk = tokens[:, index_start:index_end]
 
-        raw_output = model(input_ids=tokens_chunk, token_type_ids=torch.zeros_like(tokens_chunk), attention_mask=torch.ones_like(tokens_chunk))[0].sigmoid()
+        raw_output = model(
+            input_ids=tokens_chunk,
+            token_type_ids=torch.zeros_like(tokens_chunk),
+            attention_mask=torch.ones_like(tokens_chunk),
+        )[0].sigmoid()
         raw_output = raw_output[0].detach().numpy()
 
         actions = raw_output > threshold
         offset = last_stop_label(actions, action_vector("dot"))
 
-
         # Prevent infinite loop
         if (offset is None) or (offset == 0):
-            offset = (index_end - index_start)
+            offset = index_end - index_start
 
         if output is None:
             output = raw_output[0:offset]
@@ -174,7 +191,6 @@ def apply_actions_punctuation(text: str, chunk_size: int, tokenizer: BertTokeniz
 
         index_start += offset
 
-
     assert len(output) == len(tokens[0])
 
     word_labels = token_labels_to_word_labels(text, output[1:-1], tokenizer)
diff --git a/src/pipelines/actions_based/train.py b/src/pipelines/actions_based/train.py
index ac54cef..45c300c 100755
--- a/src/pipelines/actions_based/train.py
+++ b/src/pipelines/actions_based/train.py
@@ -69,15 +69,11 @@ if __name__ == "__main__":
         if latest is not None:
             epoch, batch = latest
             model.load_state_dict(
-                torch.load(
-                    f"{OUTPUT_PATH}/{epoch}-{batch}.model",
-                    map_location=device,
-                )
+                torch.load(f"{OUTPUT_PATH}/{epoch}-{batch}.model", map_location=device,)
             )
             optimizer.load_state_dict(
                 torch.load(
-                    f"{OUTPUT_PATH}/{epoch}-{batch}.optimizer",
-                    map_location=device,
+                    f"{OUTPUT_PATH}/{epoch}-{batch}.optimizer", map_location=device,
                 )
             )
 
diff --git a/src/pipelines/actions_based/utils.py b/src/pipelines/actions_based/utils.py
new file mode 100644
index 0000000..d94a4e6
--- /dev/null
+++ b/src/pipelines/actions_based/utils.py
@@ -0,0 +1,30 @@
+from transformers import BertTokenizerFast, BertForTokenClassification, PretrainedConfig
+from src.processing import ACTIONS_KEYS
+import torch
+import torch.nn as nn
+from typing import Tuple
+
+
+def load_model(
+    model_path: str, base_model: str, device: str = "cpu"
+) -> Tuple[BertTokenizerFast, nn.Module]:
+    """Load pretrained model and it's tokenizer
+
+    Args:
+        model_path (str): Path to pretrained model
+        base_model (str): Name of base model
+        device (str, optional): Device on which model will be loaded. Defaults to "cpu".
+
+    Returns:
+        (BertTokenizerFast, nn.Module): Tokenizer & model
+    """
+
+    config = PretrainedConfig.from_pretrained(base_model)
+    config.num_labels = len(ACTIONS_KEYS)
+
+    device = torch.device(device)
+    tokenizer = BertTokenizerFast.from_pretrained(base_model)
+    model = BertForTokenClassification(config)
+    model.load_state_dict(torch.load(model_path, map_location=device))
+
+    return tokenizer, model
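A hedged end-to-end sketch of how load_model pairs with apply_actions_punctuation; the checkpoint path and base model name below are placeholders for whatever the deployment actually ships:

    from src.pipelines.actions_based.utils import load_model
    from src.pipelines.actions_based.processing import apply_actions_punctuation

    # Placeholder paths/names -- substitute the real checkpoint and base model.
    tokenizer, model = load_model("deploy/model", "bert-base-multilingual-cased", "cpu")

    text = "ala ma kota kot ma ale"
    print(apply_actions_punctuation(text, 500, tokenizer, model, threshold=0.9))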
diff --git a/src/pipelines/translation_based/train.py b/src/pipelines/translation_based/train.py
index 8965bda..f6f5f4d 100755
--- a/src/pipelines/translation_based/train.py
+++ b/src/pipelines/translation_based/train.py
@@ -59,15 +59,11 @@ if __name__ == "__main__":
         if latest is not None:
             epoch, batch = latest
             model.load_state_dict(
-                torch.load(
-                    f"{OUTPUT_PATH}/{epoch}-{batch}.model",
-                    map_location=device,
-                )
+                torch.load(f"{OUTPUT_PATH}/{epoch}-{batch}.model", map_location=device,)
             )
             optimizer.load_state_dict(
                 torch.load(
-                    f"{OUTPUT_PATH}/{epoch}-{batch}.optimizer",
-                    map_location=device,
+                    f"{OUTPUT_PATH}/{epoch}-{batch}.optimizer", map_location=device,
                 )
             )
 
diff --git a/src/training.py b/src/training.py
index 0ef4b9c..76b21d5 100644
--- a/src/training.py
+++ b/src/training.py
@@ -38,7 +38,13 @@ def latest_model(file_paths: [str]) -> Optional[Tuple[int, int]]:
     return furthest_epoch, furthest_batch_num
 
 
-def save_training_step(dir: str, name: str, model: nn.Module, optimizer: Optional[optim.Optimizer] = None, create_dir: bool = False) -> None:
+def save_training_step(
+    dir: str,
+    name: str,
+    model: nn.Module,
+    optimizer: Optional[optim.Optimizer] = None,
+    create_dir: bool = False,
+) -> None:
     """Saves a trainig step to a directory
 
     Args:
diff --git a/tests/pipelines/actions_based/test_processing.py b/tests/pipelines/actions_based/test_processing.py
index c7e2767..1305960 100644
--- a/tests/pipelines/actions_based/test_processing.py
+++ b/tests/pipelines/actions_based/test_processing.py
@@ -4,15 +4,13 @@ import numpy as np
 
 
 def test_action_vector():
-    expected = encode_actions({
-        "dot": True,
-        "upper_case": True,
-        "colon": False,
-        "question_mark": False,
-    })
+    expected = encode_actions(
+        {"dot": True, "upper_case": True, "colon": False, "question_mark": False,}
+    )
 
     assert np.all(action_vector(["dot", "upper_case"]) == expected)
 
+
 def test_last_stop_label():
     stop_action = action_vector(["Dot"])
     not_stop_action = action_vector(["upper_case"])
@@ -21,4 +19,4 @@ def test_last_stop_label():
 
     res = last_stop_label(labels, stop_action)
 
-    assert last_stop_label(labels, stop_action) == 2
\ No newline at end of file
+    assert last_stop_label(labels, stop_action) == 2
diff --git a/worker.py b/worker.py
index 49fd847..c241998 100755
--- a/worker.py
+++ b/worker.py
@@ -1,4 +1,4 @@
-#!/bin/bash
+#!/usr/bin/python
 
 import nlp_ws
 import shutil
@@ -7,6 +7,8 @@ from transformers import BertTokenizerFast, BertForTokenClassification, Pretrain
 import configparser
 from src.processing import ACTIONS_KEYS
 import torch
+from src.pipelines.actions_based.utils import load_model
+
 
 class Worker(nlp_ws.NLPWorker):
     """Class that implements example worker."""
@@ -15,28 +17,26 @@ class Worker(nlp_ws.NLPWorker):
         self.config = configparser.ConfigParser()
         self.config.read("config.ini")
 
-        config = PretrainedConfig.from_pretrained(self.config['deployment']['base_model'])
-        config.num_labels = len(ACTIONS_KEYS)
-
-        device = torch.device("cpu")
-        self.threshold = self.config['deployment']['threshold']
-        self.chunk_size = self.config['deployment']['chunk_size']
-        self.tokenizer = BertTokenizerFast.from_pretrained(self.config['deployment']['base_model'])
-        self.model = BertForTokenClassification(config)
-        self.model.load_state_dict(torch.load(self.config['deployment']['model'], map_location=device))
+        # configparser returns strings; cast so downstream comparisons and arithmetic work
+        self.threshold = float(self.config["deployment"]["threshold"])
+        self.chunk_size = int(self.config["deployment"]["chunk_size"])
+        self.tokenizer, self.model = load_model(
+            self.config["deployment"]["model"],
+            self.config["deployment"]["base_model"],
+            self.config["deployment"]["device"],
+        )
 
     def process(self, input_file: str, task_options: dict, output_file: str) -> None:
         """Implementation of example tasks that copies files."""
 
-        with open(input_file, 'r') as f:
+        with open(input_file, "r") as f:
             text = f.read()
-            text_processed = apply_actions_punctuation(text, self.chunk_size, self.tokenizer, self.model, self.threshold)
-            
-        with  open(output_file, 'w') as f:
-            f.write(text_processed)
+            text_processed = apply_actions_punctuation(
+                text, self.chunk_size, self.tokenizer, self.model, self.threshold
+            )
 
-        shutil.copy(input_file, output_file)
+        with open(output_file, "w") as f:
+            f.write(text_processed)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     nlp_ws.NLPService.main(Worker)
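The worker reads everything from a [deployment] section of config.ini. The values below are hypothetical, but the keys match what init() looks up; note that configparser returns strings, hence the float/int casts above:

    import configparser

    EXAMPLE_CONFIG = """
    [deployment]
    model = deploy/model
    base_model = bert-base-multilingual-cased
    device = cpu
    threshold = 0.9
    chunk_size = 500
    """

    cfg = configparser.ConfigParser()
    cfg.read_string(EXAMPLE_CONFIG)
    assert float(cfg["deployment"]["threshold"]) == 0.9
    assert int(cfg["deployment"]["chunk_size"]) == 500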
-- 
GitLab


From a74c6f289ab24bfde38f0dc6c85e257cad0150db Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Mon, 10 Aug 2020 13:17:07 +0200
Subject: [PATCH 042/116] Style fixes

---
 punctuate.py                                     | 8 ++++----
 src/pipelines/actions_based/processing.py        | 1 -
 tests/pipelines/actions_based/test_processing.py | 6 ++----
 worker.py                                        | 4 ----
 4 files changed, 6 insertions(+), 13 deletions(-)

diff --git a/punctuate.py b/punctuate.py
index d411cd0..05d4251 100755
--- a/punctuate.py
+++ b/punctuate.py
@@ -2,7 +2,6 @@
 
 import argparse
 import os
-import sys
 from src.pipelines.actions_based.utils import load_model
 from src.pipelines.actions_based.processing import apply_actions_punctuation
 
@@ -47,8 +46,9 @@ if __name__ == "__main__":
 
     args = parser.parse_args()
 
-    #if not os.path.exists(args.input):
-    #    print(f"Error: File '{args.input}' does not exists", file=sys.stderr)
+    if not os.path.exists(args.input):
+        print(f"Error: File '{args.input}' does not exists")
+        exit(-1)
 
     tokenizer, model = load_model(
         args.model,
@@ -63,4 +63,4 @@ if __name__ == "__main__":
         )
 
     with open(args.output, "w") as f:
-        f.write(text_processed)
\ No newline at end of file
+        f.write(text_processed)
diff --git a/src/pipelines/actions_based/processing.py b/src/pipelines/actions_based/processing.py
index 8f043e6..84d7ed8 100644
--- a/src/pipelines/actions_based/processing.py
+++ b/src/pipelines/actions_based/processing.py
@@ -2,7 +2,6 @@ from transformers import BertTokenizerFast
 from src.processing import (
     tokenize_labeled_text,
     batchify_data,
-    empty_action_dict,
     encode_actions,
 )
 import numpy as np
diff --git a/tests/pipelines/actions_based/test_processing.py b/tests/pipelines/actions_based/test_processing.py
index 1305960..9956d22 100644
--- a/tests/pipelines/actions_based/test_processing.py
+++ b/tests/pipelines/actions_based/test_processing.py
@@ -1,11 +1,11 @@
 from src.pipelines.actions_based.processing import last_stop_label, action_vector
-from src.processing import encode_actions, empty_action_vector
+from src.processing import encode_actions
 import numpy as np
 
 
 def test_action_vector():
     expected = encode_actions(
-        {"dot": True, "upper_case": True, "colon": False, "question_mark": False,}
+        {"dot": True, "upper_case": True, "colon": False, "question_mark": False}
     )
 
     assert np.all(action_vector(["dot", "upper_case"]) == expected)
@@ -17,6 +17,4 @@ def test_last_stop_label():
 
     labels = np.array([not_stop_action, not_stop_action, stop_action, not_stop_action])
 
-    res = last_stop_label(labels, stop_action)
-
     assert last_stop_label(labels, stop_action) == 2
diff --git a/worker.py b/worker.py
index c241998..059a831 100755
--- a/worker.py
+++ b/worker.py
@@ -1,12 +1,8 @@
 #!/usr/bin/python
 
 import nlp_ws
-import shutil
 from src.pipelines.actions_based.processing import apply_actions_punctuation
-from transformers import BertTokenizerFast, BertForTokenClassification, PretrainedConfig
 import configparser
-from src.processing import ACTIONS_KEYS
-import torch
 from src.pipelines.actions_based.utils import load_model
 
 
-- 
GitLab


From 87016be32a860549772c1358d9b030546c7cddea Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Mon, 10 Aug 2020 13:24:19 +0200
Subject: [PATCH 043/116] Added CI

---
 .gitlab-ci.yml | 40 ++++++++++++++++++++++++++++++++++++++++
 tox.ini        |  4 ++--
 2 files changed, 42 insertions(+), 2 deletions(-)
 create mode 100644 .gitlab-ci.yml

diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
new file mode 100644
index 0000000..53720f5
--- /dev/null
+++ b/.gitlab-ci.yml
@@ -0,0 +1,40 @@
+image: clarinpl/python:3.8
+
+cache:
+  paths:
+    - .tox
+
+stages:
+  - check_style
+  - testing
+  - build
+
+before_script:
+  - pip install tox==2.9.1
+
+pep8:
+  stage: check_style
+  script:
+    - tox -v -e pep8
+
+unittest:
+  stage: testing
+  script:
+    - tox -v -e unittest
+
+build_image:
+  stage: build
+  image: docker:18.09.7
+  only:
+    - master
+  services:
+    - docker:18.09.7-dind
+  before_script:
+    - ''
+  script:
+    - docker build -t clarinpl/punctuator .
+    - echo $DOCKER_PASSWORD > pass.txt
+    - cat pass.txt | docker login --username $DOCKER_USERNAME --password-stdin
+    - rm pass.txt
+    - docker push clarinpl/punctuator
diff --git a/tox.ini b/tox.ini
index c7d1dde..d7149d3 100644
--- a/tox.ini
+++ b/tox.ini
@@ -1,5 +1,5 @@
 [tox]
-envlist = py38,flake8,pep8
+envlist = unittest,pep8
 skipsdist = True
 
 [testenv]
@@ -15,7 +15,7 @@ deps =
     pyarrow==0.17.1
     lxml
 
-[testenv:py38]
+[testenv:unittest]
 commands = pytest --ignore data --ignore generated
 
 [flake8]
-- 
GitLab


From 280965adc243fac7fdd0d26e7d065088db5ded16 Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Mon, 10 Aug 2020 13:25:15 +0200
Subject: [PATCH 044/116] Removed notebooks

---
 notebooks/dask_dataframe_exploration.ipynb | 108 -------------
 notebooks/dask_functionality_test.ipynb    |  81 ----------
 notebooks/test_actions_model.ipynb         | 172 ---------------------
 notebooks/test_bert_dimensions.ipynb       | 134 ----------------
 notebooks/test_translations_model.ipynb    | 172 ---------------------
 notebooks/tokenizer_testing.ipynb          | 142 -----------------
 notebooks/torch_exploration.ipynb          | 118 --------------
 notebooks/torch_transformer.ipynb          | 133 ----------------
 8 files changed, 1060 deletions(-)
 delete mode 100644 notebooks/dask_dataframe_exploration.ipynb
 delete mode 100644 notebooks/dask_functionality_test.ipynb
 delete mode 100644 notebooks/test_actions_model.ipynb
 delete mode 100644 notebooks/test_bert_dimensions.ipynb
 delete mode 100644 notebooks/test_translations_model.ipynb
 delete mode 100644 notebooks/tokenizer_testing.ipynb
 delete mode 100644 notebooks/torch_exploration.ipynb
 delete mode 100644 notebooks/torch_transformer.ipynb

diff --git a/notebooks/dask_dataframe_exploration.ipynb b/notebooks/dask_dataframe_exploration.ipynb
deleted file mode 100644
index 0e48640..0000000
--- a/notebooks/dask_dataframe_exploration.ipynb
+++ /dev/null
@@ -1,108 +0,0 @@
-{
- "metadata": {
-  "language_info": {
-   "codemirror_mode": {
-    "name": "ipython",
-    "version": 3
-   },
-   "file_extension": ".py",
-   "mimetype": "text/x-python",
-   "name": "python",
-   "nbconvert_exporter": "python",
-   "pygments_lexer": "ipython3",
-   "version": "3.8.2-final"
-  },
-  "orig_nbformat": 2,
-  "kernelspec": {
-   "name": "python38264bita7d7da14168440cb9836372958035d4a",
-   "display_name": "Python 3.8.2 64-bit"
-  }
- },
- "nbformat": 4,
- "nbformat_minor": 2,
- "cells": [
-  {
-   "cell_type": "code",
-   "execution_count": 1,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "import dask.dataframe as dd\n",
-    "import numpy as np"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 2,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "df = dd.read_parquet(\"../generated/translations/stage2_create_batches\", engine=\"pyarrow\")"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 3,
-   "metadata": {
-    "tags": []
-   },
-   "outputs": [],
-   "source": [
-    "shapes = df.source_shape.compute()"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 4,
-   "metadata": {
-    "tags": []
-   },
-   "outputs": [
-    {
-     "output_type": "execute_result",
-     "data": {
-      "text/plain": "                                               source  \\\n15  [2, 27476, 7835, 4677, 2822, 11226, 781, 77, 4...   \n18  [2, 1178, 4607, 766, 7835, 752, 58008, 4419, 5...   \n21  [2, 27476, 7835, 25104, 9712, 6901, 8698, 778,...   \n25  [2, 1645, 2160, 27476, 11811, 7835, 4677, 5657...   \n29  [2, 27476, 7835, 4677, 5529, 77, 10814, 4994, ...   \n\n                                               target  \\\n15  [2, 15482, 7835, 2931, 18, 7331, 11226, 781, 1...   \n18  [2, 56453, 7835, 18, 922, 58008, 4419, 5482, 4...   \n21  [2, 15482, 7835, 18, 38648, 9712, 6901, 8698, ...   \n25  [2, 1513, 2160, 16, 27476, 11811, 7835, 2931, ...   \n29  [2, 15482, 7835, 2931, 53234, 77, 6789, 17353,...   \n\n                                                 mask source_shape  \\\n15  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...     [2, 500]   \n18  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...   [169, 500]   \n21  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...    [94, 500]   \n25  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...    [25, 500]   \n29  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...     [3, 500]   \n\n   target_shape  mask_shape  \n15     [2, 500]    [2, 500]  \n18   [169, 500]  [169, 500]  \n21    [94, 500]   [94, 500]  \n25    [25, 500]   [25, 500]  \n29     [3, 500]    [3, 500]  ",
-      "text/html": "<div>\n<style scoped>\n    .dataframe tbody tr th:only-of-type {\n        vertical-align: middle;\n    }\n\n    .dataframe tbody tr th {\n        vertical-align: top;\n    }\n\n    .dataframe thead th {\n        text-align: right;\n    }\n</style>\n<table border=\"1\" class=\"dataframe\">\n  <thead>\n    <tr style=\"text-align: right;\">\n      <th></th>\n      <th>source</th>\n      <th>target</th>\n      <th>mask</th>\n      <th>source_shape</th>\n      <th>target_shape</th>\n      <th>mask_shape</th>\n    </tr>\n  </thead>\n  <tbody>\n    <tr>\n      <th>15</th>\n      <td>[2, 27476, 7835, 4677, 2822, 11226, 781, 77, 4...</td>\n      <td>[2, 15482, 7835, 2931, 18, 7331, 11226, 781, 1...</td>\n      <td>[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...</td>\n      <td>[2, 500]</td>\n      <td>[2, 500]</td>\n      <td>[2, 500]</td>\n    </tr>\n    <tr>\n      <th>18</th>\n      <td>[2, 1178, 4607, 766, 7835, 752, 58008, 4419, 5...</td>\n      <td>[2, 56453, 7835, 18, 922, 58008, 4419, 5482, 4...</td>\n      <td>[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...</td>\n      <td>[169, 500]</td>\n      <td>[169, 500]</td>\n      <td>[169, 500]</td>\n    </tr>\n    <tr>\n      <th>21</th>\n      <td>[2, 27476, 7835, 25104, 9712, 6901, 8698, 778,...</td>\n      <td>[2, 15482, 7835, 18, 38648, 9712, 6901, 8698, ...</td>\n      <td>[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...</td>\n      <td>[94, 500]</td>\n      <td>[94, 500]</td>\n      <td>[94, 500]</td>\n    </tr>\n    <tr>\n      <th>25</th>\n      <td>[2, 1645, 2160, 27476, 11811, 7835, 4677, 5657...</td>\n      <td>[2, 1513, 2160, 16, 27476, 11811, 7835, 2931, ...</td>\n      <td>[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...</td>\n      <td>[25, 500]</td>\n      <td>[25, 500]</td>\n      <td>[25, 500]</td>\n    </tr>\n    <tr>\n      <th>29</th>\n      <td>[2, 27476, 7835, 4677, 5529, 77, 10814, 4994, ...</td>\n      <td>[2, 15482, 7835, 2931, 53234, 77, 6789, 17353,...</td>\n      <td>[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...</td>\n      <td>[3, 500]</td>\n      <td>[3, 500]</td>\n      <td>[3, 500]</td>\n    </tr>\n  </tbody>\n</table>\n</div>"
-     },
-     "metadata": {},
-     "execution_count": 4
-    }
-   ],
-   "source": [
-    "df.head()"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 17,
-   "metadata": {
-    "tags": []
-   },
-   "outputs": [
-    {
-     "output_type": "stream",
-     "name": "stdout",
-     "text": "[  13.41055592    6.22695335    9.50220389 1603.42842607  290.42591041\n  182.60031948]\n"
-    }
-   ],
-   "source": [
-    "import pickle\n",
-    "import numpy as np\n",
-    "with open(\"../generated/actions/stage5_stats/stats.pickle\", 'rb') as f:\n",
-    "    stats = pickle.load(f)\n",
-    "    pos_examples = stats['class_number']\n",
-    "    neg_examples = stats['num_examples'] - stats['class_number']\n",
-    "    ratio = neg_examples / pos_examples\n",
-    "\n",
-    "    print(ratio)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": []
-  }
- ]
-}
\ No newline at end of file
diff --git a/notebooks/dask_functionality_test.ipynb b/notebooks/dask_functionality_test.ipynb
deleted file mode 100644
index ee74d31..0000000
--- a/notebooks/dask_functionality_test.ipynb
+++ /dev/null
@@ -1,81 +0,0 @@
-{
- "metadata": {
-  "language_info": {
-   "codemirror_mode": {
-    "name": "ipython",
-    "version": 3
-   },
-   "file_extension": ".py",
-   "mimetype": "text/x-python",
-   "name": "python",
-   "nbconvert_exporter": "python",
-   "pygments_lexer": "ipython3",
-   "version": "3.8.2-final"
-  },
-  "orig_nbformat": 2,
-  "kernelspec": {
-   "name": "python38264bita7d7da14168440cb9836372958035d4a",
-   "display_name": "Python 3.8.2 64-bit"
-  }
- },
- "nbformat": 4,
- "nbformat_minor": 2,
- "cells": [
-  {
-   "cell_type": "code",
-   "execution_count": 19,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "import dask\n",
-    "import dask.dataframe as dd\n",
-    "import pandas as pd\n",
-    "import numpy as np"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 24,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "pdf = pd.DataFrame({'x': [1,2,3, 4, 5], 'y': ['a', 'b', 'c', 'd', 'e']})"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 25,
-   "metadata": {},
-   "outputs": [
-    {
-     "output_type": "execute_result",
-     "data": {
-      "text/plain": "   x  y  ones\n0  1  a     1\n1  2  b     2\n2  3  c     3\n3  4  d     4\n4  5  e     5",
-      "text/html": "<div>\n<style scoped>\n    .dataframe tbody tr th:only-of-type {\n        vertical-align: middle;\n    }\n\n    .dataframe tbody tr th {\n        vertical-align: top;\n    }\n\n    .dataframe thead th {\n        text-align: right;\n    }\n</style>\n<table border=\"1\" class=\"dataframe\">\n  <thead>\n    <tr style=\"text-align: right;\">\n      <th></th>\n      <th>x</th>\n      <th>y</th>\n      <th>ones</th>\n    </tr>\n  </thead>\n  <tbody>\n    <tr>\n      <th>0</th>\n      <td>1</td>\n      <td>a</td>\n      <td>1</td>\n    </tr>\n    <tr>\n      <th>1</th>\n      <td>2</td>\n      <td>b</td>\n      <td>2</td>\n    </tr>\n    <tr>\n      <th>2</th>\n      <td>3</td>\n      <td>c</td>\n      <td>3</td>\n    </tr>\n    <tr>\n      <th>3</th>\n      <td>4</td>\n      <td>d</td>\n      <td>4</td>\n    </tr>\n    <tr>\n      <th>4</th>\n      <td>5</td>\n      <td>e</td>\n      <td>5</td>\n    </tr>\n  </tbody>\n</table>\n</div>"
-     },
-     "metadata": {},
-     "execution_count": 25
-    }
-   ],
-   "source": [
-    "df = dd.from_pandas(pdf, npartitions=2)\n",
-    "df = df.assign(ones=1)\n",
-    "df.ones = df.ones.cumsum()\n",
-    "\n",
-    "order_indexes == df.ones.compute()\n",
-    "random_indexes = df.ones.compute()\n",
-    "np.random.shuffle(random_indexes)\n",
-    "mapping = \n",
-    "\n",
-    "df.compute()"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": []
-  }
- ]
-}
\ No newline at end of file
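
The deleted notebook above was sketching a row-shuffling idea for dask: give
every row a global position via cumsum, then permute those positions. A
runnable sketch of where that appears to have been headed (the final remap is
an assumption; the original cell stopped mid-line):

    import numpy as np
    import pandas as pd
    import dask.dataframe as dd

    pdf = pd.DataFrame({"x": [1, 2, 3, 4, 5], "y": list("abcde")})
    df = dd.from_pandas(pdf, npartitions=2)

    # Assign a global 1..n position to every row across partitions.
    df = df.assign(ones=1)
    df.ones = df.ones.cumsum()

    # Permute the positions, then send every row to its shuffled position.
    order = df.ones.compute().to_numpy()
    shuffled = order.copy()
    np.random.shuffle(shuffled)
    mapping = dict(zip(order, shuffled))

    df = df.assign(pos=df.ones.map(mapping))
    print(df.compute().sort_values("pos"))
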
diff --git a/notebooks/test_actions_model.ipynb b/notebooks/test_actions_model.ipynb
deleted file mode 100644
index 98877b8..0000000
--- a/notebooks/test_actions_model.ipynb
+++ /dev/null
@@ -1,172 +0,0 @@
-{
- "metadata": {
-  "language_info": {
-   "codemirror_mode": {
-    "name": "ipython",
-    "version": 3
-   },
-   "file_extension": ".py",
-   "mimetype": "text/x-python",
-   "name": "python",
-   "nbconvert_exporter": "python",
-   "pygments_lexer": "ipython3",
-   "version": "3.8.2-final"
-  },
-  "orig_nbformat": 2,
-  "kernelspec": {
-   "name": "python38264bita7d7da14168440cb9836372958035d4a",
-   "display_name": "Python 3.8.2 64-bit"
-  }
- },
- "nbformat": 4,
- "nbformat_minor": 2,
- "cells": [
-  {
-   "cell_type": "code",
-   "execution_count": 1,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "%load_ext autoreload\n",
-    "%autoreload 2\n",
-    "\n",
-    "import sys\n",
-    "sys.path.append(\"../\")\n",
-    "\n",
-    "from transformers import BertTokenizerFast, BertForTokenClassification\n",
-    "import torch\n",
-    "from torch.nn import BCEWithLogitsLoss\n",
-    "import pandas as pd\n",
-    "import numpy as np\n",
-    "import dask.dataframe as dd\n",
-    "\n",
-    "from src.processing import create_model_input_output"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 2,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "MODEL_BASE = \"dkleczek/bert-base-polish-cased-v1\""
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 3,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "tokenizer = BertTokenizerFast.from_pretrained(MODEL_BASE)\n"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 4,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "expected = \"w niedzielę kiedy prażanie są precz i moczą nogi wieku bereuce na wzgórze zamkowe królów czeskich napierają watahy cudzoziemców podnieconych obecnością w miejscu ważnym dla historii od dziesięciu wieków i ponownie od czasu kiedy na hrad wyniesiono vaclava havla zamek chwilowo nie ma gospodarza ale wszystko jest pod parą i gwardia gotowa zawsze do stania u boku katedra świętego wita jest dla czechów tym czy dla polaków katedra na wawelu żelazny szlak turystyczny prowadzi złotą uliczką gdzie zaczyna boleć głowa od nadmiaru okazji do pamiątkowego zdjęcia w tym małym domu mieszkał wielki kafka zanim się dostaniemy na lewy brzeg wełtawy wypada spędzić dłuższą chwilę na samym moście bujnie opiewanym i jeśli nie najpiękniejszym to z pewnością jedynym w roli tak ciągle twórczej dla folkloru wielkiego miasta najpierw śmierć uderza młoteczkiem w dzwonki a potem defiluje i przedstawia się turystom dwanaście apostołów zegar orloja jest najwyższą z wiekowych atrakcji na rynku starego miasta nie budujemy drugich czech choć bliżej nam do pragi niż do tokio i już tysiąc lat temu zaczął święty wojciech ciekawe tylko czy jan hus przetrwałby na pomniku w drugich czechach i czy miałby tyle szans co święty wacław przy wiślanym szlaku od mostu poniatowskiego w dół rzeki trzyma się bar pod rurą i może jeszcze ze dwie budy gdzie po dawnemu sprzedaje się piwo marki piwo temperatura zależy od słoneczka szkło nie gra roli a niewiastom wstęp niewskazany w gazetach do niedawna pisano że piwo degraduje polaków w rzeczywistości było dokładnie na odwrót amerykański desant nie uratował zagrożonej cnoty polaków bo ciepłe piwo jest mimo wszystko mniej obrzydliwe od ciepłej coca coli miejsce właściwe obu napojom w kulturze i obyczaju wyznaczył dopiero wolny rynek\"\n",
-    "text_clean = create_model_input_output(expected)[0]\n",
-    "\n",
-    "inputs = tokenizer(text_clean, return_tensors=\"pt\")"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 5,
-   "metadata": {
-    "tags": []
-   },
-   "outputs": [
-    {
-     "output_type": "stream",
-     "name": "stderr",
-     "text": "Some weights of the model checkpoint at dkleczek/bert-base-polish-cased-v1 were not used when initializing BertForTokenClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.decoder.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias']\n- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).\n- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\nSome weights of BertForTokenClassification were not initialized from the model checkpoint at dkleczek/bert-base-polish-cased-v1 and are newly initialized: ['classifier.weight', 'classifier.bias']\nYou should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
-    },
-    {
-     "output_type": "execute_result",
-     "data": {
-      "text/plain": "<All keys matched successfully>"
-     },
-     "metadata": {},
-     "execution_count": 5
-    }
-   ],
-   "source": [
-    "model = BertForTokenClassification.from_pretrained(MODEL_BASE, num_labels=4)\n",
-    "device = torch.device(\"cpu\")\n",
-    "model.load_state_dict(torch.load(\"../checkpoints/actions/0-100.model\", map_location=device))\n"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 57,
-   "metadata": {
-    "tags": []
-   },
-   "outputs": [
-    {
-     "output_type": "execute_result",
-     "data": {
-      "text/plain": "'w niedzielę, kiedy prażanie są precz i moczą nogi wieku Bereuce na wzgórze zamkowe Królów Czeskich napierają watahy cudzoziemców podnieconych obecnością w miejscu ważnym dla historii od dziesięciu wieków i ponownie od czasu, kiedy na Hrad wyniesiono Vaclava Havla Zamek chwilowo nie ma gospodarza ale wszystko jest pod parą i Gwardia gotowa zawsze do stania u boku Katedra Świętego Wita jest dla Czechów tym, czy dla polaków Katedra na Wawelu Żelazny szlak turystyczny prowadzi złotą uliczką gdzie zaczyna boleć głowa od nadmiaru okazji do Pamiątkowego zdjęcia w tym małym domu mieszkał wielki Kafka zanim się dostaniemy na lewy brzeg wełtawy wypada spędzić dłuższą chwilę na samym moście bujnie opiewanym, i jeśli nie najpiękniejszym, to z pewnością jedynym w roli tak ciągle twórczej dla folkloru wielkiego miasta najpierw śmierć uderza młoteczkiem w dzwonki, a potem defiluje i przedstawia się turystom dwanaście apostołów zegar Orloja jest najwyższą z wiekowych atrakcji na rynku starego miasta Nie budujemy drugich czech choć bliżej nam do Pragi niż do Tokio i już tysiąc lat temu zaczął święty wojciech ciekawe tylko czy Jan Hus przetrwałby na pomniku w drugich Czechach i czy miałby tyle szans, co święty Wacław przy wiślanym szlaku od mostu Poniatowskiego w dół rzeki trzyma się bar pod rurą i może jeszcze ze dwie budy, gdzie po dawnemu sprzedaje się piwo marki piwo temperatura zależy od słoneczka szkło nie gra roli, a niewiastom wstęp niewskazany w gazetach do niedawna pisano, że piwo degraduje polaków w rzeczywistości było dokładnie na odwrót Amerykański Desant Nie uratował zagrożonej cnoty Polaków, bo ciepłe piwo jest mimo wszystko mniej obrzydliwe od ciepłej coca coli miejsce właściwe obu napojom w kulturze i obyczaju wyznaczył dopiero wolny rynek'"
-     },
-     "metadata": {},
-     "execution_count": 57
-    }
-   ],
-   "source": [
-    "from src.pipelines.actions_based.processing import apply_actions_punctuation\n",
-    "\n",
-    "apply_actions_punctuation(text_clean, 10, tokenizer, model, 0.9)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 47,
-   "metadata": {
-    "tags": []
-   },
-   "outputs": [
-    {
-     "output_type": "stream",
-     "name": "stdout",
-     "text": "(277, 4)\nw niedzielę kiedy prażanie są precz i moczą nogi wieku bereuce na wzgórze zamkowe królów czeskich napierają watahy cudzoziemców podnieconych obecnością w miejscu ważnym dla historii od dziesięciu wieków i ponownie od czasu kiedy na hrad wyniesiono vaclava havla zamek chwilowo nie ma gospodarza ale wszystko jest pod parą i gwardia gotowa zawsze do stania u boku katedra świętego wita jest dla czechów tym czy dla polaków katedra na wawelu żelazny szlak turystyczny prowadzi złotą uliczką gdzie zaczyna boleć głowa od nadmiaru okazji do pamiątkowego zdjęcia w tym małym domu mieszkał wielki kafka zanim się dostaniemy na lewy brzeg wełtawy wypada spędzić dłuższą chwilę na samym moście bujnie opiewanym i jeśli nie najpiękniejszym to z pewnością jedynym w roli tak ciągle twórczej dla folkloru wielkiego miasta najpierw śmierć uderza młoteczkiem w dzwonki a potem defiluje i przedstawia się turystom dwanaście apostołów zegar orloja jest najwyższą z wiekowych atrakcji na rynku starego miasta nie budujemy drugich czech choć bliżej nam do pragi niż do tokio i już tysiąc lat temu zaczął święty wojciech ciekawe tylko czy jan hus przetrwałby na pomniku w drugich czechach i czy miałby tyle szans co święty wacław przy wiślanym szlaku od mostu poniatowskiego w dół rzeki trzyma się bar pod rurą i może jeszcze ze dwie budy gdzie po dawnemu sprzedaje się piwo marki piwo temperatura zależy od słoneczka szkło nie gra roli a niewiastom wstęp niewskazany w gazetach do niedawna pisano że piwo degraduje polaków w rzeczywistości było dokładnie na odwrót amerykański desant nie uratował zagrożonej cnoty polaków bo ciepłe piwo jest mimo wszystko mniej obrzydliwe od ciepłej coca coli miejsce właściwe obu napojom w kulturze i obyczaju wyznaczył dopiero wolny rynek \n------\n"
-    }
-   ],
-   "source": [
-    "from src.processing import token_labels_to_word_labels, recover_text\n",
-    "\n",
-    "y_pred = model(**inputs)[0].sigmoid()\n",
-    "y_pred = y_pred > 0.9\n",
-    "\n",
-    "labels_pred = token_labels_to_word_labels(text_clean, y_pred.detach().numpy()[0, 1:-1, :], tokenizer)\n",
-    "\n",
-    "print(labels_pred.shape)\n",
-    "\n",
-    "print(expected)\n",
-    "print(\"------\")\n",
-    "print(recover_text(text_clean, actions))"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": []
-  }
- ]
-}
\ No newline at end of file
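
The actual action semantics of recover_text are not visible in this diff;
purely as an illustration of what applying per-word "actions" means, here is
a sketch that assumes four hypothetical flags per word (uppercase, period,
comma, question mark); the real layout may differ:

    import numpy as np

    # Hypothetical layout: [uppercase, period, comma, question mark].
    PUNCT = {1: ".", 2: ",", 3: "?"}

    def apply_word_actions(text: str, actions: np.ndarray) -> str:
        words = text.split()
        out = []
        for word, flags in zip(words, actions):
            if flags[0]:
                word = word.capitalize()
            for idx, mark in PUNCT.items():
                if flags[idx]:
                    word += mark
                    break
            out.append(word)
        return " ".join(out)

    actions = np.zeros((3, 4), dtype=bool)
    actions[0, 0] = True  # uppercase the first word
    actions[2, 1] = True  # period after the last word
    print(apply_word_actions("ala ma kota", actions))  # -> "Ala ma kota."
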
diff --git a/notebooks/test_bert_dimensions.ipynb b/notebooks/test_bert_dimensions.ipynb
deleted file mode 100644
index 8bc4617..0000000
--- a/notebooks/test_bert_dimensions.ipynb
+++ /dev/null
@@ -1,134 +0,0 @@
-{
- "metadata": {
-  "language_info": {
-   "codemirror_mode": {
-    "name": "ipython",
-    "version": 3
-   },
-   "file_extension": ".py",
-   "mimetype": "text/x-python",
-   "name": "python",
-   "nbconvert_exporter": "python",
-   "pygments_lexer": "ipython3",
-   "version": "3.8.2-final"
-  },
-  "orig_nbformat": 2,
-  "kernelspec": {
-   "name": "python38264bit17f10e31b7e440e591cfca7d4c2c2274",
-   "display_name": "Python 3.8.2 64-bit"
-  }
- },
- "nbformat": 4,
- "nbformat_minor": 2,
- "cells": [
-  {
-   "cell_type": "code",
-   "execution_count": 41,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "from transformers import BertForMaskedLM, BertTokenizerFast\n",
-    "import torch"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 42,
-   "metadata": {
-    "tags": []
-   },
-   "outputs": [
-    {
-     "output_type": "stream",
-     "name": "stderr",
-     "text": "Some weights of the model checkpoint at dkleczek/bert-base-polish-cased-v1 were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']\n- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).\n- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
-    }
-   ],
-   "source": [
-    "model = BertForMaskedLM.from_pretrained(\"dkleczek/bert-base-polish-cased-v1\")\n",
-    "tokenizer = BertTokenizerFast.from_pretrained(\"dkleczek/bert-base-polish-cased-v1\")"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 85,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "text = \"Dziwny <mask> ten świat!\"\n",
-    "kwgs = tokenizer(text, return_tensors='pt', max_length=30, padding='max_length')"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 86,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "res = model(**kwgs)[0]\n",
-    "output = [x.argmax() for x in res[0]]"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 87,
-   "metadata": {},
-   "outputs": [
-    {
-     "output_type": "execute_result",
-     "data": {
-      "text/plain": "'! Dziwny < br > ten świat!!! \" Dzi! i! < … > nowy świat!! świat! \" \"!! jest'"
-     },
-     "metadata": {},
-     "execution_count": 87
-    }
-   ],
-   "source": [
-    "tokenizer.decode(output)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 88,
-   "metadata": {},
-   "outputs": [
-    {
-     "output_type": "execute_result",
-     "data": {
-      "text/plain": "[tensor(5),\n tensor(4642),\n tensor(12407),\n tensor(32),\n tensor(6858),\n tensor(34),\n tensor(1216),\n tensor(1994),\n tensor(5),\n tensor(5),\n tensor(5),\n tensor(6),\n tensor(4642),\n tensor(5),\n tensor(77),\n tensor(5),\n tensor(32),\n tensor(372),\n tensor(34),\n tensor(3905),\n tensor(1994),\n tensor(5),\n tensor(5),\n tensor(1994),\n tensor(5),\n tensor(6),\n tensor(6),\n tensor(5),\n tensor(5),\n tensor(800)]"
-     },
-     "metadata": {},
-     "execution_count": 88
-    }
-   ],
-   "source": [
-    "output"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 89,
-   "metadata": {},
-   "outputs": [
-    {
-     "output_type": "execute_result",
-     "data": {
-      "text/plain": "{'input_ids': tensor([[    2,  4642, 12407,    32, 45933,    34,  1216,  1994,     5,     4,\n             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n             0,     0,     0,     0,     0,     0,     0,     0,     0,     0]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n         0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n         0, 0, 0, 0, 0, 0]])}"
-     },
-     "metadata": {},
-     "execution_count": 89
-    }
-   ],
-   "source": [
-    "kwgs"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": []
-  }
- ]
-}
\ No newline at end of file
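
The garbled predictions above stem from the literal "<mask>" string: BERT-style
models use "[MASK]", so "<mask>" is tokenized as ordinary text. A minimal
sketch of the same experiment using the tokenizer's own mask token:

    import torch
    from transformers import BertForMaskedLM, BertTokenizerFast

    name = "dkleczek/bert-base-polish-cased-v1"
    model = BertForMaskedLM.from_pretrained(name)
    tokenizer = BertTokenizerFast.from_pretrained(name)

    # Use the tokenizer's mask token instead of a hard-coded literal.
    text = f"Dziwny {tokenizer.mask_token} ten świat!"
    inputs = tokenizer(text, return_tensors="pt")

    with torch.no_grad():
        logits = model(**inputs)[0]

    # Predict only at the masked position.
    mask_pos = (inputs["input_ids"][0] == tokenizer.mask_token_id).nonzero()[0]
    print(tokenizer.decode(logits[0, mask_pos].argmax(-1)))
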
diff --git a/notebooks/test_translations_model.ipynb b/notebooks/test_translations_model.ipynb
deleted file mode 100644
index d418b29..0000000
--- a/notebooks/test_translations_model.ipynb
+++ /dev/null
@@ -1,172 +0,0 @@
-{
- "metadata": {
-  "language_info": {
-   "codemirror_mode": {
-    "name": "ipython",
-    "version": 3
-   },
-   "file_extension": ".py",
-   "mimetype": "text/x-python",
-   "name": "python",
-   "nbconvert_exporter": "python",
-   "pygments_lexer": "ipython3",
-   "version": "3.8.2-final"
-  },
-  "orig_nbformat": 2,
-  "kernelspec": {
-   "name": "python_defaultSpec_1596573889994",
-   "display_name": "Python 3.8.2 64-bit"
-  }
- },
- "nbformat": 4,
- "nbformat_minor": 2,
- "cells": [
-  {
-   "cell_type": "code",
-   "execution_count": 1,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "import sys\n",
-    "sys.path.append(\"../\")\n",
-    "\n",
-    "from transformers import BertTokenizerFast\n",
-    "from src.models.TransformerSeq2Seq import TransformerSeq2Seq\n",
-    "import torch\n",
-    "from torch.nn import BCEWithLogitsLoss\n",
-    "import pandas as pd\n",
-    "import numpy as np\n",
-    "import dask.dataframe as dd\n",
-    "\n",
-    "from src.processing import create_model_input_output"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 2,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "MODEL_BASE = \"dkleczek/bert-base-polish-cased-v1\""
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 3,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "tokenizer = BertTokenizerFast.from_pretrained(MODEL_BASE)\n"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "tokenizer.em"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 36,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "expected = \"Ogromny wybuch w Bejrucie zrównał z ziemią prawie całą dzielnicę.\"\n",
-    "text_clean = create_model_input_output(expected)[0]\n",
-    "\n",
-    "inputs = tokenizer(text_clean, return_tensors=\"pt\")"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 41,
-   "metadata": {
-    "tags": []
-   },
-   "outputs": [
-    {
-     "output_type": "execute_result",
-     "data": {
-      "text/plain": "<All keys matched successfully>"
-     },
-     "metadata": {},
-     "execution_count": 41
-    }
-   ],
-   "source": [
-    "model = TransformerSeq2Seq(tokenizer.vocab_size, 200, 300, 1, 2, 2)\n",
-    "model.load_state_dict(torch.load(\"../checkpoints/translations/0-43000.model\", map_location={'cuda:2': 'cpu'}))\n"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 42,
-   "metadata": {},
-   "outputs": [
-    {
-     "output_type": "execute_result",
-     "data": {
-      "text/plain": "'[CLS] Zapis w Bej, zrównał z ziemią prawie całą dodatki, całą, całą, dzielnicę, teatr, dzielnicę,, rozpro, a całą dzielnicę, z Londynu, budownictwie, prawie całą dzielnicę, niezarzymy, cuk z ziemią, prawie całą dzielnicę, rozpro, bory, nieufności, daną, z daleko, autorytetem, krowybina, w Argentynie wybuch, w dzielnicę, wybuch w dzielnicę, płyt, partia, doniesie,cie, wybuch w walce, kardi, ogromny, wybuchwany, przeszło, w wybuch,'"
-     },
-     "metadata": {},
-     "execution_count": 42
-    }
-   ],
-   "source": [
-    "input_tokens = inputs['input_ids']\n",
-    "outputs = [[tokenizer.cls_token_id]]\n",
-    "\n",
-    "for j in range(100):\n",
-    "    preds = model(input_tokens, torch.tensor(outputs, dtype=torch.long), torch.zeros_like(input_tokens).bool()).softmax(-1)\n",
-    "    outputs[0].append(preds[0, -1].argmax().detach().tolist())\n",
-    "\n",
-    "tokenizer.decode(outputs[0])\n"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 43,
-   "metadata": {
-    "tags": []
-   },
-   "outputs": [
-    {
-     "output_type": "error",
-     "ename": "TypeError",
-     "evalue": "TransformerSeq2Seq object argument after ** must be a mapping, not Tensor",
-     "traceback": [
-      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
-      "\u001b[0;31mTypeError\u001b[0m                                 Traceback (most recent call last)",
-      "\u001b[0;32m<ipython-input-43-ed4e46be8788>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      4\u001b[0m \u001b[0moutputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtokenizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcls_token_id\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      5\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 6\u001b[0;31m \u001b[0my_pred\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m**\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      7\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      8\u001b[0m \u001b[0mtokens_predictions\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0margmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mx\u001b[0m \u001b[0;32min\u001b[0m \u001b[0my_pred\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
-      "\u001b[0;31mTypeError\u001b[0m: TransformerSeq2Seq object argument after ** must be a mapping, not Tensor"
-     ]
-    }
-   ],
-   "source": [
-    "from src.processing import token_labels_to_word_labels, recover_text\n",
-    "\n",
-    "inputs = inputs['input_ids']\n",
-    "outputs = np.array([[tokenizer.cls_token_id]])\n",
-    "\n",
-    "y_pred = model(**inputs)\n",
-    "\n",
-    "tokens_predictions = [np.argmax(x) for x in y_pred[0, :]]\n",
-    "\n",
-    "actions = labels_pred > 0.5\n",
-    "print(expected)\n",
-    "print(\"------\")\n",
-    "print(recover_text(text_clean, actions))"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": []
-  }
- ]
-}
\ No newline at end of file
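
The greedy-decoding loop above always runs 100 steps; a small refinement,
using the same (source, target, source_mask) call the loop already makes, is
to stop as soon as the model emits the separator token:

    import torch

    def greedy_decode(model, tokenizer, input_tokens, max_len=100):
        # Start from [CLS] and append the argmax token until [SEP] or max_len.
        decoded = [tokenizer.cls_token_id]
        src_mask = torch.zeros_like(input_tokens).bool()
        for _ in range(max_len):
            target = torch.tensor([decoded], dtype=torch.long)
            preds = model(input_tokens, target, src_mask).softmax(-1)
            next_id = int(preds[0, -1].argmax())
            decoded.append(next_id)
            if next_id == tokenizer.sep_token_id:
                break
        return tokenizer.decode(decoded)

    # Usage, with `inputs` built by the tokenizer as in the cells above:
    # print(greedy_decode(model, tokenizer, inputs["input_ids"]))
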
diff --git a/notebooks/tokenizer_testing.ipynb b/notebooks/tokenizer_testing.ipynb
deleted file mode 100644
index d6e83cc..0000000
--- a/notebooks/tokenizer_testing.ipynb
+++ /dev/null
@@ -1,142 +0,0 @@
-{
- "metadata": {
-  "language_info": {
-   "codemirror_mode": {
-    "name": "ipython",
-    "version": 3
-   },
-   "file_extension": ".py",
-   "mimetype": "text/x-python",
-   "name": "python",
-   "nbconvert_exporter": "python",
-   "pygments_lexer": "ipython3",
-   "version": "3.8.2-final"
-  },
-  "orig_nbformat": 2,
-  "kernelspec": {
-   "name": "python38264bita7d7da14168440cb9836372958035d4a",
-   "display_name": "Python 3.8.2 64-bit"
-  }
- },
- "nbformat": 4,
- "nbformat_minor": 2,
- "cells": [
-  {
-   "cell_type": "code",
-   "execution_count": 14,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "from transformers import BertTokenizerFast\n",
-    "import numpy as np"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 2,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "base_model = \"bert-base-multilingual-cased\"\n",
-    "tokenizer = BertTokenizerFast.from_pretrained(base_model)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 31,
-   "metadata": {},
-   "outputs": [
-    {
-     "output_type": "execute_result",
-     "data": {
-      "text/plain": "[56500, 117, 10824, 30186, 11090, 10113, 119, 138, 13400, 11058, 106]"
-     },
-     "metadata": {},
-     "execution_count": 31
-    }
-   ],
-   "source": [
-    "tokenizer(\"Ala, ma KoTa. A kot nie!\")['input_ids'][1:-1]"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 26,
-   "metadata": {},
-   "outputs": [
-    {
-     "output_type": "execute_result",
-     "data": {
-      "text/plain": "'Ala, ma KoTa.'"
-     },
-     "metadata": {},
-     "execution_count": 26
-    }
-   ],
-   "source": [
-    "tokenizer.decode(np.array(tokenizer(\"Ala, ma KoTa. A kot nie!\")['input_ids'][1:-1])[[0, 1, 2, 3, 4, 5, 6]])"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 29,
-   "metadata": {},
-   "outputs": [
-    {
-     "output_type": "execute_result",
-     "data": {
-      "text/plain": "'A kot nie!'"
-     },
-     "metadata": {},
-     "execution_count": 29
-    }
-   ],
-   "source": [
-    "tokenizer.decode(np.array(tokenizer(\"Ala, ma KoTa. A kot nie!\")['input_ids'][1:-1])[[7, 8, 9, 10]])"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 20,
-   "metadata": {},
-   "outputs": [
-    {
-     "output_type": "execute_result",
-     "data": {
-      "text/plain": "'ala ma kota'"
-     },
-     "metadata": {},
-     "execution_count": 20
-    }
-   ],
-   "source": [
-    "tokenizer.decode(np.array(tokenizer(\"ala ma kota a kot nie\")['input_ids'][1:-1])[[0, 1, 2]])"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 22,
-   "metadata": {},
-   "outputs": [
-    {
-     "output_type": "execute_result",
-     "data": {
-      "text/plain": "'a kot nie'"
-     },
-     "metadata": {},
-     "execution_count": 22
-    }
-   ],
-   "source": [
-    "tokenizer.decode(np.array(tokenizer(\"ala ma kota a kot nie\")['input_ids'][1:-1])[[3, 4, 5]])"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": []
-  }
- ]
-}
\ No newline at end of file
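
The cells above recover substrings by decoding hand-picked slices of token
ids; with a fast tokenizer the same alignment is available directly through
offset mappings, which is both simpler and lossless:

    from transformers import BertTokenizerFast

    tokenizer = BertTokenizerFast.from_pretrained("bert-base-multilingual-cased")

    text = "Ala, ma KoTa. A kot nie!"
    enc = tokenizer(text, return_offsets_mapping=True)

    # Each token maps to a (start, end) character span in the raw text;
    # special tokens such as [CLS] and [SEP] get the empty span (0, 0).
    for tok_id, (start, end) in zip(enc["input_ids"], enc["offset_mapping"]):
        print(tok_id, repr(text[start:end]))
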
diff --git a/notebooks/torch_exploration.ipynb b/notebooks/torch_exploration.ipynb
deleted file mode 100644
index 58f68ee..0000000
--- a/notebooks/torch_exploration.ipynb
+++ /dev/null
@@ -1,118 +0,0 @@
-{
- "metadata": {
-  "language_info": {
-   "codemirror_mode": {
-    "name": "ipython",
-    "version": 3
-   },
-   "file_extension": ".py",
-   "mimetype": "text/x-python",
-   "name": "python",
-   "nbconvert_exporter": "python",
-   "pygments_lexer": "ipython3",
-   "version": "3.8.2-final"
-  },
-  "orig_nbformat": 2,
-  "kernelspec": {
-   "name": "python38264bita7d7da14168440cb9836372958035d4a",
-   "display_name": "Python 3.8.2 64-bit"
-  }
- },
- "nbformat": 4,
- "nbformat_minor": 2,
- "cells": [
-  {
-   "cell_type": "code",
-   "execution_count": 1,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "import torch"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 62,
-   "metadata": {},
-   "outputs": [
-    {
-     "output_type": "error",
-     "ename": "RuntimeError",
-     "evalue": "Could not run 'aten::scatter_.value' with arguments from the 'SparseCPUTensorId' backend. 'aten::scatter_.value' is only available for these backends: [CPUTensorId, CUDATensorId, VariableTensorId].",
-     "traceback": [
-      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
-      "\u001b[0;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
-      "\u001b[0;32m<ipython-input-62-9bd81c403b2e>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      5\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      6\u001b[0m \u001b[0monehot\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msparse\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mFloatTensor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m3\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m5\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzero_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 7\u001b[0;31m \u001b[0monehot\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mscatter_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      8\u001b[0m \u001b[0monehot\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
-      "\u001b[0;31mRuntimeError\u001b[0m: Could not run 'aten::scatter_.value' with arguments from the 'SparseCPUTensorId' backend. 'aten::scatter_.value' is only available for these backends: [CPUTensorId, CUDATensorId, VariableTensorId]."
-     ]
-    }
-   ],
-   "source": [
-    "i = torch.LongTensor([\n",
-    "    [1, 2, 3],\n",
-    "    [4, 4, 4]\n",
-    "]).unsqueeze(-1)\n",
-    "\n",
-    "onehot = torch.sparse.FloatTensor(2, 3, 5).zero_()\n",
-    "onehot.scatter_(2, i, 1)\n",
-    "onehot"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 77,
-   "metadata": {},
-   "outputs": [
-    {
-     "output_type": "execute_result",
-     "data": {
-      "text/plain": "tensor([[1.3133, 1.3133, 1.3133],\n        [0.3133, 0.3133, 0.3133]])"
-     },
-     "metadata": {},
-     "execution_count": 77
-    }
-   ],
-   "source": [
-    "got = torch.tensor([\n",
-    "    [[0, 1], [0, 1], [1, 0]],\n",
-    "    [[0, 1], [0, 1], [1, 0]]\n",
-    "], dtype=torch.float)\n",
-    "\n",
-    "target = torch.tensor([\n",
-    "    [0, 0, 1],\n",
-    "    [1, 1, 0]\n",
-    "])\n",
-    "\n",
-    "got.transpose_(1, 2)\n",
-    "\n",
-    "loss = torch.nn.CrossEntropyLoss(reduction='none')\n",
-    "loss(got, target)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 79,
-   "metadata": {},
-   "outputs": [
-    {
-     "output_type": "execute_result",
-     "data": {
-      "text/plain": "tensor([[0.2689, 0.7311],\n        [0.5000, 0.5000]])"
-     },
-     "metadata": {},
-     "execution_count": 79
-    }
-   ],
-   "source": [
-    "torch.tensor([[1.0, 2.0], [5.0, 5.0]]).softmax(-1)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": []
-  }
- ]
-}
\ No newline at end of file
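
Incidentally, the scatter_-based construction in the first cell has a one-line
equivalent in torch.nn.functional, which sidesteps the dense/sparse pitfall
entirely:

    import torch
    import torch.nn.functional as F

    i = torch.tensor([[1, 2, 3],
                      [4, 4, 4]])

    # one_hot appends the class dimension last, matching the scatter_ layout.
    onehot = F.one_hot(i, num_classes=5).float()
    print(onehot.shape)  # torch.Size([2, 3, 5])
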
diff --git a/notebooks/torch_transformer.ipynb b/notebooks/torch_transformer.ipynb
deleted file mode 100644
index e8d69dd..0000000
--- a/notebooks/torch_transformer.ipynb
+++ /dev/null
@@ -1,133 +0,0 @@
-{
- "metadata": {
-  "language_info": {
-   "codemirror_mode": {
-    "name": "ipython",
-    "version": 3
-   },
-   "file_extension": ".py",
-   "mimetype": "text/x-python",
-   "name": "python",
-   "nbconvert_exporter": "python",
-   "pygments_lexer": "ipython3",
-   "version": 3
-  },
-  "orig_nbformat": 2,
-  "kernelspec": {
-   "name": "python_defaultSpec_1596544773362",
-   "display_name": "Python 3.8.2 64-bit"
-  }
- },
- "nbformat": 4,
- "nbformat_minor": 2,
- "cells": [
-  {
-   "cell_type": "code",
-   "execution_count": 10,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "import torch\n",
-    "import torch.nn as nn\n",
-    "import math"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 14,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "class PositionalEncoding(nn.Module):\n",
-    "    def __init__(self, d_model, dropout=0.1, max_len=5000):\n",
-    "        super(PositionalEncoding, self).__init__()\n",
-    "        self.dropout = nn.Dropout(p=dropout)\n",
-    "\n",
-    "        pe = torch.zeros(max_len, d_model)\n",
-    "        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)\n",
-    "        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))\n",
-    "        pe[:, 0::2] = torch.sin(position * div_term)\n",
-    "        pe[:, 1::2] = torch.cos(position * div_term)\n",
-    "        pe = pe.unsqueeze(0).transpose(0, 1)\n",
-    "        self.register_buffer('pe', pe)\n",
-    "\n",
-    "    def forward(self, x):\n",
-    "        x = x + self.pe[:x.size(0), :]\n",
-    "        return self.dropout(x)\n"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 63,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "class Transformer(nn.Module):\n",
-    "    def __init__(self, vocab_size: int = 5, embedding_size: int = 16, num_heads: int = 4, encoder_layers: int = 1, decoder_layers: int = 1, feedforward_neurons: int = 100, dropout: float = 0.1, max_len: int= 10):\n",
-    "        super(Transformer, self).__init__()\n",
-    "\n",
-    "        self.word_embedding = nn.Embedding(vocab_size, embedding_size)\n",
-    "        self.position_embedding = PositionalEncoding(embedding_size, dropout, max_len)\n",
-    "        self.core = nn.Transformer(embedding_size, num_heads, encoder_layers, decoder_layers, feedforward_neurons, dropout)\n",
-    "        self.embedding_to_words = nn.Linear(embedding_size, vocab_size)\n",
-    "\n",
-    "    def forward(self, source, target, source_mask):\n",
-    "        x = source.transpose(0, 1)\n",
-    "        x = self.word_embedding(x)\n",
-    "        x = self.position_embedding(x)\n",
-    "\n",
-    "        y = target.transpose(0, 1)\n",
-    "        y = self.word_embedding(y)\n",
-    "        y = self.position_embedding(y)\n",
-    "\n",
-    "        tgt_mask = self.core.generate_square_subsequent_mask(y.shape[0])\n",
-    "\n",
-    "        print(tgt_mask.shape)\n",
-    "\n",
-    "        z = self.core(x, y, src_key_padding_mask=source_mask, tgt_mask=tgt_mask).transpose(1, 0)\n",
-    "        z = self.embedding_to_words(z)\n",
-    "\n",
-    "        return z"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 66,
-   "metadata": {
-    "tags": []
-   },
-   "outputs": [
-    {
-     "output_type": "stream",
-     "name": "stdout",
-     "text": "torch.Size([3, 3])\n"
-    },
-    {
-     "output_type": "execute_result",
-     "data": {
-      "text/plain": "torch.Size([2, 3, 5])"
-     },
-     "metadata": {},
-     "execution_count": 66
-    }
-   ],
-   "source": [
-    "transformer = Transformer()\n",
-    "\n",
-    "example_batch = torch.randint(0, 5, (2, 4))\n",
-    "example_target = torch.randint(0, 5, (2, 4))\n",
-    "\n",
-    "source_mask = torch.ones_like(example_batch, dtype=torch.uint8) == 0\n",
-    "\n",
-    "transformer(example_batch, example_target[:, :-1], source_mask).shape\n"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": []
-  }
- ]
-}
\ No newline at end of file
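
The tgt_mask printed inside forward() is the causal (subsequent) mask that
keeps each decoder position from attending to later ones; the
torch.Size([3, 3]) in the output is that mask for the 3-token target. A tiny
sketch of its contents:

    import torch.nn as nn

    core = nn.Transformer(d_model=16, nhead=4)
    # 0.0 marks allowed positions; -inf masks out future positions.
    print(core.generate_square_subsequent_mask(3))
    # tensor([[0., -inf, -inf],
    #         [0., 0., -inf],
    #         [0., 0., 0.]])
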
-- 
GitLab


From 97a7a48850bc5c01a55eb163a5dea1696fb38787 Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Mon, 10 Aug 2020 13:26:36 +0200
Subject: [PATCH 045/116] Small fixups

---
 .gitignore     | 3 ++-
 .gitlab-ci.yml | 3 +--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/.gitignore b/.gitignore
index 1414e6a..419d8f2 100644
--- a/.gitignore
+++ b/.gitignore
@@ -10,4 +10,5 @@ __pycache__
 .pytest_cache
 /checkpoints
 .dvc
-.tox
\ No newline at end of file
+.tox
+notebooks
\ No newline at end of file
diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
index 53720f5..6504c49 100644
--- a/.gitlab-ci.yml
+++ b/.gitlab-ci.yml
@@ -36,5 +36,4 @@ build_image:
     - echo $DOCKER_PASSWORD > pass.txt
     - cat pass.txt | docker login --username $DOCKER_USERNAME --password-stdin
     - rm pass.txt
-    - docker push clarinpl/punctuator
-C
\ No newline at end of file
+    - docker push clarinpl/punctuator
\ No newline at end of file
-- 
GitLab


From baaa44fa7a8bd6daa8050c13243835a306a12006 Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Mon, 10 Aug 2020 13:29:47 +0200
Subject: [PATCH 046/116] More cleanup

---
 .devcontainer/devcontainer.json |  48 -------------
 .gitignore                      |   3 +-
 .vscode/.ropeproject/config.py  | 123 --------------------------------
 .vscode/.ropeproject/objectdb   | Bin 6 -> 0 bytes
 .vscode/launch.json             |  22 ------
 .vscode/settings.json           |  25 -------
 6 files changed, 2 insertions(+), 219 deletions(-)
 delete mode 100644 .devcontainer/devcontainer.json
 delete mode 100644 .vscode/.ropeproject/config.py
 delete mode 100644 .vscode/.ropeproject/objectdb
 delete mode 100644 .vscode/launch.json
 delete mode 100644 .vscode/settings.json

diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json
deleted file mode 100644
index c0f5fbc..0000000
--- a/.devcontainer/devcontainer.json
+++ /dev/null
@@ -1,48 +0,0 @@
-// For format details, see https://aka.ms/vscode-remote/devcontainer.json or this file's README at:
-// https://github.com/microsoft/vscode-dev-containers/tree/v0.128.0/containers/docker-existing-dockerfile
-{
-	"name": "Development Container",
-
-	// Sets the run context to one level up instead of the .devcontainer folder.
-	"context": "../docker",
-
-	// Update the 'dockerFile' property if you aren't using the standard 'Dockerfile' filename.
-	"dockerFile": "../docker/Dockerfile",
-
-	// Set *default* container specific settings.json values on container create.
-	"settings": { 
-		"terminal.integrated.shell.linux": null
-	},
-
-	// Add the IDs of extensions you want installed when the container is created.
-	"extensions": [
-		"ms-python.python"
-	],
-
-	// Use 'forwardPorts' to make a list of ports inside the container available locally.
-	"forwardPorts": [8787],
-
-	"runArgs": [
-		"--gpus", "all"
-	],
-
-	"mounts": [
-		"source=/home/mpogoda/.gitconfig,target=/root/.gitconfig,type=bind",
-		"source=/home/mpogoda/.aws,target=/root/.aws,type=bind"
-	],
-
-	"workspaceMount": "source=/home/mpogoda/mnt/interpunkcja,target=/workspace,type=bind,consistency=cached",
-	"workspaceFolder": "/workspace",
-
-	// Uncomment the next line to run commands after the container is created - for example installing curl.
-	// "postCreateCommand": "apt-get update && apt-get install -y curl",
-
-	// Uncomment when using a ptrace-based debugger like C++, Go, and Rust
-	// "runArgs": [ "--cap-add=SYS_PTRACE", "--security-opt", "seccomp=unconfined" ],
-
-	// Uncomment to use the Docker CLI from inside the container. See https://aka.ms/vscode-remote/samples/docker-from-docker.
-	// "mounts": [ "source=/var/run/docker.sock,target=/var/run/docker.sock,type=bind" ],
-
-	// Uncomment to connect as a non-root user. See https://aka.ms/vscode-remote/containers/non-root.
-	// "remoteUser": "vscode"
-}
diff --git a/.gitignore b/.gitignore
index 419d8f2..491e41d 100644
--- a/.gitignore
+++ b/.gitignore
@@ -2,7 +2,8 @@ dane/**
 dataset_simple
 dataset_actions
 **/dask-worker-space
-.vscode`
+.vscode
+.devcontainer
 .idea
 .metals
 /data
diff --git a/.vscode/.ropeproject/config.py b/.vscode/.ropeproject/config.py
deleted file mode 100644
index c339bc7..0000000
--- a/.vscode/.ropeproject/config.py
+++ /dev/null
@@ -1,123 +0,0 @@
-# The default ``config.py``
-# flake8: noqa
-
-
-def set_prefs(prefs):
-    """This function is called before opening the project"""
-
-    # Specify which files and folders to ignore in the project.
-    # Changes to ignored resources are not added to the history and
-    # VCSs.  Also they are not returned in `Project.get_files()`.
-    # Note that ``?`` and ``*`` match all characters but slashes.
-    # '*.pyc': matches 'test.pyc' and 'pkg/test.pyc'
-    # 'mod*.pyc': matches 'test/mod1.pyc' but not 'mod/1.pyc'
-    # '.svn': matches 'pkg/.svn' and all of its children
-    # 'build/*.o': matches 'build/lib.o' but not 'build/sub/lib.o'
-    # 'build//*.o': matches 'build/lib.o' and 'build/sub/lib.o'
-    prefs["ignored_resources"] = [
-        "*.pyc",
-        "*~",
-        ".ropeproject",
-        ".hg",
-        ".svn",
-        "_svn",
-        ".git",
-        ".tox",
-    ]
-
-    # Specifies which files should be considered python files.  It is
-    # useful when you have scripts inside your project.  Only files
-    # ending with ``.py`` are considered to be python files by
-    # default.
-    # prefs['python_files'] = ['*.py']
-
-    # Custom source folders:  By default rope searches the project
-    # for finding source folders (folders that should be searched
-    # for finding modules).  You can add paths to that list.  Note
-    # that rope guesses project source folders correctly most of the
-    # time; use this if you have any problems.
-    # The folders should be relative to project root and use '/' for
-    # separating folders regardless of the platform rope is running on.
-    # 'src/my_source_folder' for instance.
-    # prefs.add('source_folders', 'src')
-
-    # You can extend python path for looking up modules
-    # prefs.add('python_path', '~/python/')
-
-    # Should rope save object information or not.
-    prefs["save_objectdb"] = True
-    prefs["compress_objectdb"] = False
-
-    # If `True`, rope analyzes each module when it is being saved.
-    prefs["automatic_soa"] = True
-    # The depth of calls to follow in static object analysis
-    prefs["soa_followed_calls"] = 0
-
-    # If `False` when running modules or unit tests "dynamic object
-    # analysis" is turned off.  This makes them much faster.
-    prefs["perform_doa"] = True
-
-    # Rope can check the validity of its object DB when running.
-    prefs["validate_objectdb"] = True
-
-    # How many undos to hold?
-    prefs["max_history_items"] = 32
-
-    # Shows whether to save history across sessions.
-    prefs["save_history"] = True
-    prefs["compress_history"] = False
-
-    # Set the number spaces used for indenting.  According to
-    # :PEP:`8`, it is best to use 4 spaces.  Since most of rope's
-    # unit-tests use 4 spaces it is more reliable, too.
-    prefs["indent_size"] = 4
-
-    # Builtin and c-extension modules that are allowed to be imported
-    # and inspected by rope.
-    prefs["extension_modules"] = []
-
-    # Add all standard c-extensions to extension_modules list.
-    prefs["import_dynload_stdmods"] = True
-
-    # If `True` modules with syntax errors are considered to be empty.
-    # The default value is `False`; When `False` syntax errors raise
-    # `rope.base.exceptions.ModuleSyntaxError` exception.
-    prefs["ignore_syntax_errors"] = False
-
-    # If `True`, rope ignores unresolvable imports.  Otherwise, they
-    # appear in the importing namespace.
-    prefs["ignore_bad_imports"] = False
-
-    # If `True`, rope will insert new module imports as
-    # `from <package> import <module>` by default.
-    prefs["prefer_module_from_imports"] = False
-
-    # If `True`, rope will transform a comma list of imports into
-    # multiple separate import statements when organizing
-    # imports.
-    prefs["split_imports"] = False
-
-    # If `True`, rope will remove all top-level import statements and
-    # reinsert them at the top of the module when making changes.
-    prefs["pull_imports_to_top"] = True
-
-    # If `True`, rope will sort imports alphabetically by module name instead
-    # of alphabetically by import statement, with from imports after normal
-    # imports.
-    prefs["sort_imports_alphabetically"] = False
-
-    # Location of implementation of
-    # rope.base.oi.type_hinting.interfaces.ITypeHintingFactory In general
-    # case, you don't have to change this value, unless you're a rope expert.
-    # Change this value to inject your own implementations of interfaces
-    # listed in module rope.base.oi.type_hinting.providers.interfaces
-    # For example, you can add your own providers for Django Models, or disable
-    # the search type-hinting in a class hierarchy, etc.
-    prefs["type_hinting_factory"] = (
-        "rope.base.oi.type_hinting.factory.default_type_hinting_factory"
-    )
-
-
-def project_opened(project):
-    """This function is called after opening the project"""
-    # Do whatever you like here!
diff --git a/.vscode/.ropeproject/objectdb b/.vscode/.ropeproject/objectdb
deleted file mode 100644
index 0a47446c0ad231c193bdd44ff327ba2ab28bf3d8..0000000000000000000000000000000000000000
GIT binary patch
literal 0
HcmV?d00001

literal 6
NcmZo*sx4&D0{{kv0iOT>

diff --git a/.vscode/launch.json b/.vscode/launch.json
deleted file mode 100644
index 6e10b9f..0000000
--- a/.vscode/launch.json
+++ /dev/null
@@ -1,22 +0,0 @@
-{
-    // Use IntelliSense to learn about possible attributes.
-    // Hover to view descriptions of existing attributes.
-    // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
-    "version": "0.2.0",
-    "configurations": [
-        {
-            "name": "Python: Current File",
-            "type": "python",
-            "request": "launch",
-            "program": "${file}",
-            "console": "integratedTerminal",
-            "cwd": "${workspaceFolder}/dataset_generation"
-        },
-        {
-            "name": "Python: Attach using Process Id",
-            "type": "python",
-            "request": "attach",
-            "processId": "${command:pickProcess}"
-        }
-    ]
-}
\ No newline at end of file
diff --git a/.vscode/settings.json b/.vscode/settings.json
deleted file mode 100644
index f55b7bf..0000000
--- a/.vscode/settings.json
+++ /dev/null
@@ -1,25 +0,0 @@
-{
-    "python.testing.pytestArgs": [
-        "tests",
-    ],
-    "python.testing.unittestEnabled": false,
-    "python.testing.nosetestsEnabled": false,
-    "python.testing.pytestEnabled": true,
-    "python.testing.unittestArgs": [
-        "-v",
-        "-s",
-        "./src",
-        "-p",
-        "test_*.py"
-    ],
-    "files.watcherExclude": {
-        "**/.git": true,
-        "**/.svn": true,
-        "**/.hg": true,
-        "**/CVS": true,
-        "**/.DS_Store": true,
-        "data/*": true
-    },
-    "python.testing.cwd": "${workspaceFolder}",
-    "docker.host": "ssh://mpogoda@156.17.135.51"
-}
\ No newline at end of file
-- 
GitLab


From a5151068248b717811fef919958456f7cf7a36d3 Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Mon, 10 Aug 2020 13:31:22 +0200
Subject: [PATCH 047/116] More cleanup

---
 .gitignore |   3 +-
 dvc.lock   | 162 -----------------------------------------------------
 2 files changed, 2 insertions(+), 163 deletions(-)
 delete mode 100644 dvc.lock

diff --git a/.gitignore b/.gitignore
index 491e41d..c929ac9 100644
--- a/.gitignore
+++ b/.gitignore
@@ -12,4 +12,5 @@ __pycache__
 /checkpoints
 .dvc
 .tox
-notebooks
\ No newline at end of file
+notebooks
+dvc.lock
\ No newline at end of file
diff --git a/dvc.lock b/dvc.lock
deleted file mode 100644
index 8780ddc..0000000
--- a/dvc.lock
+++ /dev/null
@@ -1,162 +0,0 @@
-extraction:
-  cmd: python3 -m scripts.dataset_generation.stage1_extraction
-  deps:
-  - path: data
-    md5: 1fa175e752af1638dc896838e82a9d7d.dir
-  - path: scripts/dataset_generation/stage1_extraction.py
-    md5: b5256e47e54f55fd406f23889a9cbca9
-  params:
-    params.yaml:
-      extraction.num_partitions: 2000
-  outs:
-  - path: generated/stage1_extraction
-    md5: c33e5a857a8de3bce69bfc8636f64854.dir
-tokenization:
-  cmd: python3 -m scripts.dataset_generation.stage2_tokenization
-  deps:
-  - path: generated/stage1_extraction
-    md5: c33e5a857a8de3bce69bfc8636f64854.dir
-  - path: scripts/dataset_generation/stage2_tokenization.py
-    md5: 1afe768315b818e3a051c976cf19d2f3
-  params:
-    params.yaml:
-      tokenization.max_tokens: 500
-      tokenization.min_tokens: 10
-  outs:
-  - path: generated/stage2_tokenization
-    md5: 1b98b64ecf98ec74c446721256182539.dir
-exploding:
-  cmd: python3 -m scripts.dataset_generation.stage3_exploding
-  deps:
-  - path: generated/stage2_tokenization
-    md5: 1b98b64ecf98ec74c446721256182539.dir
-  - path: scripts/dataset_generation/stage3_exploding.py
-    md5: 490843d650534f09003480d26fde2390
-  outs:
-  - path: generated/stage3_exploding
-    md5: 688ce8926016ed49154b088850be6cff.dir
-reindexing:
-  cmd: python3 -m scripts.dataset_generation.stage4_reindexing
-  deps:
-  - path: generated/stage3_exploding
-    md5: 688ce8926016ed49154b088850be6cff.dir
-  - path: scripts/dataset_generation/stage4_reindexing.py
-    md5: 342a0fd49f45c3d9ff9b3a701f6ebb7d
-  outs:
-  - path: generated/stage4_reindexing
-    md5: 9e797430fe072a60e778e191606a9952.dir
-translations_extraction:
-  cmd: python3 -m scripts.translation_based.stage1_extraction
-  deps:
-  - path: data
-    md5: 1fa175e752af1638dc896838e82a9d7d.dir
-  params:
-    params.yaml:
-      translations.extraction.num_partitions: 2000
-  outs:
-  - path: generated/translations/stage1_extraction
-    md5: c7f5bb265082fdd21b8936ddca14a8ab.dir
-translations_tokenization:
-  cmd: python3 -m scripts.translation_based.stage2_tokenization
-  deps:
-  - path: generated/translations/stage1_extraction
-    md5: 61a1a88c672e485fd9b0dc0ef22817a9.dir
-  params:
-    params.yaml:
-      global.base_model: bert-base-multilingual-cased
-  outs:
-  - path: generated/translations/stage2_tokenization
-    md5: b4132fb48d63c09ee5fd5e017f5c279c.dir
-actions_extraction:
-  cmd: python3 -m scripts.actions_based.stage1_extraction
-  deps:
-  - path: data
-    md5: 1fa175e752af1638dc896838e82a9d7d.dir
-  - path: scripts/actions_based/stage1_extraction.py
-    md5: a01f6ee74e165e7c3d6b21c648482d45
-  params:
-    params.yaml:
-      actions.extraction.num_partitions: 2000
-  outs:
-  - path: generated/actions/stage1_extraction
-    md5: 8c9d822cc101faf137bd54932c94f922.dir
-actions_tokenization:
-  cmd: python3 -m scripts.actions_based.stage2_tokenization
-  deps:
-  - path: generated/actions/stage1_extraction
-    md5: 8c9d822cc101faf137bd54932c94f922.dir
-  - path: scripts/actions_based/stage2_tokenization.py
-    md5: 6360e7facd4af85d2deb0deabb4cc448
-  params:
-    params.yaml:
-      actions.tokenization.max_tokens: 500
-      actions.tokenization.min_tokens: 10
-      global.base_model: dkleczek/bert-base-polish-cased-v1
-  outs:
-  - path: generated/actions/stage2_tokenization
-    md5: a1a31dc4baa92b775e44c335e3f75a9c.dir
-actions_exploding:
-  cmd: python3 -m scripts.actions_based.stage3_exploding
-  deps:
-  - path: generated/actions/stage2_tokenization
-    md5: a1a31dc4baa92b775e44c335e3f75a9c.dir
-  - path: scripts/actions_based/stage3_exploding.py
-    md5: f65f552b17d012c53b5a42406cb88bcd
-  outs:
-  - path: generated/actions/stage3_exploding
-    md5: 6db856e40b88769840799232b23c2058.dir
-actions_reindexing:
-  cmd: python3 -m scripts.actions_based.stage4_reindexing
-  deps:
-  - path: generated/actions/stage3_exploding
-    md5: 6db856e40b88769840799232b23c2058.dir
-  - path: scripts/actions_based/stage4_reindexing.py
-    md5: 7841f8c3acdc12a5dc0adef12b8b8cbc
-  outs:
-  - path: generated/actions/stage4_reindexing
-    md5: 446e8e2b2011af28fcfa63557c2b5808.dir
-actions_training:
-  cmd: python3 -m scripts.actions_based.train
-  deps:
-  - path: generated/actions/stage4_reindexing
-    md5: 446e8e2b2011af28fcfa63557c2b5808.dir
-  - path: scripts/actions_based/train.py
-    md5: ef61cad42a6be6f862051530bbc6965b
-  params:
-    params.yaml:
-      actions.training.batch_size: 2
-      actions.training.learning_rate: 0.0001
-      actions.training.max_training_time: 2m
-      actions.training.num_epochs: 5
-      actions.training.save_step: 1000
-      global.base_model: dkleczek/bert-base-polish-cased-v1
-  outs:
-  - path: checkpoints/actions
-    md5: 6116b19bae31f503a350b635125a6daf.dir
-translations_create_batches:
-  cmd: python3 -m scripts.translation_based.stage2_create_batches
-  deps:
-  - path: generated/translations/stage1_extraction
-    md5: c7f5bb265082fdd21b8936ddca14a8ab.dir
-  params:
-    params.yaml:
-      global.base_model: dkleczek/bert-base-polish-cased-v1
-  outs:
-  - path: generated/translations/stage2_create_batches
-    md5: 730e90598dac106a9088eb0906caa227.dir
-translations_exploding:
-  cmd: python3 -m scripts.translation_based.stage3_exploding
-  deps:
-  - path: generated/translations/stage2_create_batches
-    md5: 730e90598dac106a9088eb0906caa227.dir
-  outs:
-  - path: generated/translations/stage3_exploding
-    md5: 918ba496477757257e702b12da0ef21e.dir
-translations_reindexing:
-  cmd: python3 -m scripts.translation_based.stage4_reindexing
-  deps:
-  - path: generated/translations/stage3_exploding
-    md5: 918ba496477757257e702b12da0ef21e.dir
-  outs:
-  - path: generated/translations/stage4_reindexing
-    md5: caa09e33b141187800d330ab131a45e0.dir
-- 
GitLab


From 587682a7de6d25d7b048eb5106a9d1e5d5716c55 Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Mon, 10 Aug 2020 13:33:29 +0200
Subject: [PATCH 048/116] Bumped tox version on CI

---
 .gitlab-ci.yml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
index 6504c49..c74b4c8 100644
--- a/.gitlab-ci.yml
+++ b/.gitlab-ci.yml
@@ -10,7 +10,7 @@ stages:
   - build
 
 before_script:
-  - pip install tox==2.9.1
+  - pip install tox==3.19.0
 
 pep8:
   stage: check_style
-- 
GitLab


From 70a3cdb76bb71d6321418e9a729c04a26e57f0a6 Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Mon, 10 Aug 2020 13:35:14 +0200
Subject: [PATCH 049/116] Workaround of wrong symlink in base python3.8 image

---
 tox.ini | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tox.ini b/tox.ini
index d7149d3..230f9c4 100644
--- a/tox.ini
+++ b/tox.ini
@@ -42,7 +42,7 @@ ignore = E203, E501, W503, C901
 [testenv:pep8]
 deps =
     flake8
-basepython = python3
+basepython = python
 commands =
     flake8 {posargs}
 
-- 
GitLab


From 9abc1c79fbcd1a559061e676e3471fee6ed7bd73 Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Mon, 10 Aug 2020 13:39:54 +0200
Subject: [PATCH 050/116] Workaround of wrong symlink in base python3.8 image

---
 .gitlab-ci.yml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
index c74b4c8..4d60851 100644
--- a/.gitlab-ci.yml
+++ b/.gitlab-ci.yml
@@ -1,4 +1,4 @@
-image: clarinpl/python:3.8
+image: python:3.8.5
 
 cache:
   paths:
-- 
GitLab


From 6e778bc8ef98e74fb3643e9f8d4461518c7bd194 Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Mon, 10 Aug 2020 17:25:42 +0200
Subject: [PATCH 051/116] Added missing docs, style fixes

---
 .isort.cfg                                    |   3 +
 punctuate.py                                  |  45 +-
 src/batch_loading.py                          |   5 +-
 src/models/TransformerSeq2Seq.py              |   3 +-
 src/pipelines/actions_based/processing.py     | 640 +++++++++++++++---
 .../actions_based/stage1_extraction.py        |  29 +-
 .../actions_based/stage2_tokenization.py      |  14 +-
 .../actions_based/stage3_exploding.py         |   8 +-
 .../actions_based/stage4_reindexing.py        |  10 +-
 src/pipelines/actions_based/stage5_stats.py   |  11 +-
 src/pipelines/actions_based/train.py          |  25 +-
 src/pipelines/actions_based/utils.py          |  75 +-
 src/pipelines/translation_based/processing.py |  57 +-
 .../translation_based/stage1_extraction.py    |  13 +-
 .../stage2_create_batches.py                  |  12 +-
 .../translation_based/stage3_exploding.py     |  11 +-
 .../translation_based/stage4_reindexing.py    |   1 +
 src/pipelines/translation_based/train.py      |  21 +-
 src/processing.py                             | 504 ++------------
 src/training.py                               |   6 +-
 src/utils.py                                  |   7 +-
 .../actions_based/test_processing.py          | 214 +++++-
 .../translation_based/test_processing.py      |  11 +-
 tests/test_batch_loading.py                   |   7 +-
 tests/test_processing.py                      | 210 ------
 tox.ini                                       |   6 +-
 worker.py                                     |   4 +-
 27 files changed, 971 insertions(+), 981 deletions(-)
 create mode 100644 .isort.cfg
 delete mode 100644 tests/test_processing.py

diff --git a/.isort.cfg b/.isort.cfg
new file mode 100644
index 0000000..9e5a06c
--- /dev/null
+++ b/.isort.cfg
@@ -0,0 +1,3 @@
+[settings]
+profile=hug
+src_paths=src,test
diff --git a/punctuate.py b/punctuate.py
index 05d4251..05ba1ba 100755
--- a/punctuate.py
+++ b/punctuate.py
@@ -1,11 +1,14 @@
 #!/usr/bin/python3
 
 import argparse
+from argparse import Namespace
 import os
-from src.pipelines.actions_based.utils import load_model
+
 from src.pipelines.actions_based.processing import apply_actions_punctuation
+from src.pipelines.actions_based.utils import load_model
 
-if __name__ == "__main__":
+
+def get_args() -> Namespace:
     parser = argparse.ArgumentParser(
         description="Adds punctuaiton in to raw text stream."
     )
@@ -16,45 +19,27 @@ if __name__ == "__main__":
         "-o", "--output", type=str, required=True, help="Path to input text file",
     )
     parser.add_argument(
-        "-m",
-        "--model",
-        required=True,
-        type=str,
-        help="Path to the pretrained model",
+        "-m", "--model", required=True, type=str, help="Path to the pretrained model",
     )
     parser.add_argument(
-        "-b",
-        "--base",
-        required=True,
-        type=str,
-        help="Name of base model",
+        "-b", "--base", required=True, type=str, help="Name of base model",
     )
     parser.add_argument(
-        '-c',
-        '--chunk_size',
-        default=500,
-        type=int,
-        help="Maximum chunk size"
-    )
-    parser.add_argument(
-        '-t',
-        '--threshold',
-        default=0.9,
-        type=float,
-        help="Threshold"
+        "-c", "--chunk_size", default=500, type=int, help="Maximum chunk size"
     )
+    parser.add_argument("-t", "--threshold", default=0.9, type=float, help="Threshold")
+
+    return parser.parse_args()
 
-    args = parser.parse_args()
+
+if __name__ == "__main__":
+    args = get_args()
 
     if not os.path.exists(args.input):
         print(f"Error: File '{args.input}' does not exists")
         exit(-1)
 
-    tokenizer, model = load_model(
-        args.model,
-        args.base,
-        "cpu"
-    )
+    tokenizer, model = load_model(args.model, args.base, "cpu")
 
     with open(args.input, "r") as f:
         text = f.read()
diff --git a/src/batch_loading.py b/src/batch_loading.py
index cb569f1..e0527e6 100644
--- a/src/batch_loading.py
+++ b/src/batch_loading.py
@@ -1,7 +1,8 @@
+from typing import Union
+
+import dask.dataframe as dd
 import numpy as np
 import pandas as pd
-import dask.dataframe as dd
-from typing import Union
 
 
 def calculate_batch_buffer_id(batch_id: int, buffer_batch_num: int) -> int:
diff --git a/src/models/TransformerSeq2Seq.py b/src/models/TransformerSeq2Seq.py
index 00cedf9..753df1e 100644
--- a/src/models/TransformerSeq2Seq.py
+++ b/src/models/TransformerSeq2Seq.py
@@ -1,6 +1,7 @@
+import math
+
 import torch
 import torch.nn as nn
-import math
 
 
 class PositionalEncoding(nn.Module):
diff --git a/src/pipelines/actions_based/processing.py b/src/pipelines/actions_based/processing.py
index 84d7ed8..fc5ac1f 100644
--- a/src/pipelines/actions_based/processing.py
+++ b/src/pipelines/actions_based/processing.py
@@ -1,41 +1,66 @@
-from transformers import BertTokenizerFast
-from src.processing import (
-    tokenize_labeled_text,
-    batchify_data,
-    encode_actions,
-)
+from collections import defaultdict
+from typing import List, Mapping, Optional, Tuple
+from xml.etree import ElementTree as ET
+
 import numpy as np
-from typing import Optional
-import torch.nn as nn
-import torch
-from src.processing import token_labels_to_word_labels, recover_text
+from transformers import BertTokenizerFast
+from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
 
+from src.utils import remove_punctuation
 
-def expand_dims(entry: dict):
-    inputs = entry.input.reshape(entry.input_shape)
-    outputs = entry.output.reshape(entry.output_shape)
-    masks = entry.attention_mask.reshape(entry.attention_mask_shape)
+ACTIONS_KEYS = ["dot", "upper_case", "colon", "question_mark"]
 
-    return {
-        "input": inputs,
-        "output": outputs,
-        "attention_mask": masks,
-    }
 
+def apply_file_processing(x: dict) -> dict:
+    """Creates input-output pairs from xml file from dataset
 
-EXPAND_DIMS_META = {
-    "input": object,
-    "output": object,
-    "attention_mask": object,
+    Args:
+        x (dict): Dask dataframe row with columns: file
+
+    Returns:
+        dict: Dask dataframe row with columns: source, target, target_shape
+    """
+    full_text = text_from_xml(x.file)
+
+    if len(full_text) > 0:
+        model_input, model_output = create_model_input_output(full_text)
+
+        output_shape = np.array(model_output.shape, dtype=np.int)
+
+        return {
+            "source": model_input,
+            "target": model_output.reshape(-1),
+            "target_shape": output_shape,
+        }
+    else:
+        return {"source": None, "target": None, "target_shape": None}
+
+
+APPLY_FILE_PROCESSING_META = {
+    "source": object,
+    "target": object,
+    "target_shape": object,
 }
 
 
 def apply_tokenization(
-    df, min_tokens: int, max_tokens: int, tokenizer: BertTokenizerFast
-):
-    text_clean = df.input
-    labels = df.output
-    shape = df.output_shape
+    df: dict, min_tokens: int, max_tokens: int, tokenizer: BertTokenizerFast
+) -> dict:
+    """Applies tokenization and chunking
+
+    Args:
+        df (dict): Dataframe entry with columns: source, target, target_shape
+        min_tokens (int): Minimum number of tokens in a single training example
+        max_tokens (int): Maximum number of tokens in a single training example
+        tokenizer (BertTokenizerFast): Tokenizer that will be used for tokenization
+
+    Returns:
+        dict: Dataframe entry with columns: source, target, attention_mask,
+            source_shape, target_shape, attention_mask_shape
+    """
+    text_clean = df.source
+    labels = df.target
+    shape = df.target_shape
 
     tokens, token_labels = tokenize_labeled_text(
         text_clean, labels.reshape(shape), tokenizer
@@ -50,55 +75,26 @@ def apply_tokenization(
     attentions_shape = np.array(attentions.shape)
 
     return {
-        "input": inputs.reshape(-1),
-        "output": outputs.reshape(-1),
+        "source": inputs.reshape(-1),
+        "target": outputs.reshape(-1),
         "attention_mask": attentions.reshape(-1),
-        "input_shape": inputs_shape,
-        "output_shape": outputs_shape,
+        "source_shape": inputs_shape,
+        "target_shape": outputs_shape,
         "attention_mask_shape": attentions_shape,
     }
 
 
 APPLY_TOKENIZATION_META = {
-    "input": object,
-    "output": object,
+    "source": object,
+    "target": object,
     "attention_mask": object,
-    "input_shape": object,
-    "output_shape": object,
+    "source_shape": object,
+    "target_shape": object,
     "attention_mask_shape": object,
 }
 
 
-def flatten_dims(entry):
-    inputs_shape = np.array(entry.input.shape)
-    outputs_shape = np.array(entry.output.shape)
-    attentions_shape = np.array(entry.attention_mask.shape)
-
-    inputs = entry.input.reshape(-1)
-    outputs = entry.output.reshape(-1)
-    attentions = entry.attention_mask.reshape(-1)
-
-    return {
-        "input": inputs,
-        "output": outputs,
-        "attention_mask": attentions,
-        "input_shape": inputs_shape,
-        "output_shape": outputs_shape,
-        "attention_mask_shape": attentions_shape,
-    }
-
-
-FLATTEN_DIMS_META = {
-    "input": object,
-    "output": object,
-    "attention_mask": object,
-    "input_shape": object,
-    "output_shape": object,
-    "attention_mask_shape": object,
-}
-
-
-def action_vector(actions: [str]) -> np.ndarray:
+def action_vector(actions: List[str]) -> np.ndarray:
     """Transforms array of label names into an action vector.
 
     Args:
@@ -138,61 +134,493 @@ def last_stop_label(labels: np.array, stop_action: np.array) -> Optional[int]:
     return stop_labels[-1][0]
 
 
-def apply_actions_punctuation(
-    text: str,
-    chunk_size: int,
-    tokenizer: BertTokenizerFast,
-    model: nn.Module,
-    threshold: float = 0.9,
-) -> str:
-    """Adds punctuation to text using actions model
+def empty_action_vector() -> np.ndarray:
+    """Returns a do-nothing actions vector
+
+    Returns:
+        np.ndarray: Vector with all zeroes and length of ACTION_KEYS
+    """
+    return np.zeros(len(ACTIONS_KEYS))
+
+
+def empty_action_dict() -> dict:
+    """Returns a do-noting unencoded action dict
+
+    Returns:
+        dict: Action dict with all actions set to False
+    """
+
+    return decode_actions(empty_action_vector())
+
+
+def text_from_xml(path: str) -> str:
+    """Extract spoken text from dataset's xml format
+
+    Args:
+        path (str): Path to xml
+
+    Returns:
+        str: Raw text
+    """
+    root = ET.parse(path).getroot()
+
+    full_text = ""
+
+    for node in root.iter("*"):
+        if len(node) == 0:
+            who = node.get("who")
+            text = node.text
+
+            if text is not None and who is not None and who != "#komentarz":
+                full_text = " ".join([full_text, text])
+
+    del root
+
+    return full_text
+
+
+def detect_actions(word: str, next_word: Optional[str]) -> Mapping[str, bool]:
+    """Detect what actions should model perform on a word and returns encoded
+       action vector
+
+    Args:
+        word (str): Word on wich action is decided
+        next_word (Optional[str]): Word that follows considered word. Can be
+            None if nothing follows a word
+
+    Returns:
+        Mapping[str, bool]: Mapping telling if each of possible actions should be performed (True) or not (False)
+    """
+    # Unsupported characters (note: str.replace returns a new string, so the
+    # result must be assigned back)
+    word = word.replace(";", ".")
+    word = word.replace('"', "")
+    word = word.replace("(", "")
+    word = word.replace(")", "")
+
+    while len(word) > 0 and not word[0].isalnum():  # remove preceding non-alphanumeric characters
+        word = word[1:]
+
+    if len(word) == 0:
+        return dict(zip(ACTIONS_KEYS, [False] * len(ACTIONS_KEYS)))
+
+    actions = {
+        "dot": word[-1] == ".",
+        "upper_case": word[0].isupper(),
+        "colon": word[-1] == ",",
+        "question_mark": word[-1] == "?",
+    }
+
+    return actions
+
+
+def encode_actions(actions: Mapping[str, bool]) -> np.ndarray:
+    """Transforms actions into vector
+
+    Args:
+        actions (Mapping[str, bool]): Map telling which actions should be made
+
+    Returns:
+        np.ndarray: 1 dimensional action vector
+    """
+    return np.array(list(actions.values())).astype(float)
+
+
+def decode_actions(encoded_actions: np.ndarray) -> Mapping[str, bool]:
+    """Decodes actions
+
+    Args:
+        encoded_actions (np.ndarray): 1 dimensional action vector
+
+    Returns:
+        Mapping[str, bool]: Map telling which actions should be made
+    """
+    assert encoded_actions.shape[0] == len(ACTIONS_KEYS)
+
+    return dict(zip(ACTIONS_KEYS, encoded_actions.astype(np.bool).tolist()))
+
+
+def create_model_input_output(text: str) -> Tuple[str, np.ndarray]:
+    """Returns a pair of input and desired output of the model
+
+    Args:
+        text (str): Correct text sample
+
+    Returns:
+        text_cleaned (str): Text without any punctuation and all lowercase
+        actions (np.ndarray): Two-dimensional array, where each row is the action vector for the corresponding word
+    """
+    words = text.split(" ")
+
+    words_output = []
+    actions_output = []
+
+    i = 0
+    while i < len(words):
+        word = words[i]
+        next_word = words[i + 1] if len(words) > i + 1 else None
+
+        word_sanitized = remove_punctuation(word).lower()
+        if len(word_sanitized) > 0:
+            actions = detect_actions(word, next_word)
+            actions_encoded = encode_actions(actions)
+
+            words_output.append(word_sanitized)
+            actions_output.append(actions_encoded)
+
+        i += 1
+
+    assert len(words_output) == len(actions_output)
+
+    return " ".join(words_output), np.array(actions_output)
+
+
+def token_word_mapping(text: str, tokenizer: PreTrainedTokenizerFast) -> np.ndarray:
+    """Returns mapping where each token is labeled with index of word it's part of
+
+    Args:
+        text (str): Input text
+        tokenizer (PreTrainedTokenizerFast): Tokenizer used to tokenize text
+
+    Returns:
+        np.ndarray: Array of length L (number of tokens) where each entry is index of word (cls and sep labels are not counted).
+    """
+    text_tokenized = tokenizer(text, return_offsets_mapping=True)
+    offset_mappings = text_tokenized["offset_mapping"][1:-1]
+
+    # Create a map where each character is assigned the index of its word
+    words_mapping = []
+    actual_word = 0
+    for character in text:
+        words_mapping.append(actual_word)
+        if character == " ":
+            actual_word += 1
+
+    token_mapping = [words_mapping[x[0]] for x in offset_mappings]
+
+    return np.array(token_mapping)
+
+
+def token_labels_to_word_labels(
+    text: str, token_labels: np.ndarray, tokenizer: PreTrainedTokenizerFast
+) -> np.ndarray:
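+    """Converts token-level labels back to word-level labels by averaging the
+       labels of all tokens that belong to the same word
+
+    Args:
+        text (str): Input text
+        token_labels (np.ndarray): Per-token label array (without cls and sep entries)
+        tokenizer (PreTrainedTokenizerFast): Tokenizer used to tokenize text
+
+    Returns:
+        np.ndarray: Per-word label array
+    """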
+    mapping = token_word_mapping(text, tokenizer)
+
+    assert len(mapping) == len(token_labels)
+
+    labels = defaultdict(list)
+
+    for i in range(len(mapping)):
+        labels[mapping[i]].append(token_labels[i])
+
+    return np.array([np.mean(labels[x], axis=0) for x in sorted(labels)])
+
+
+def tokenize_labeled_text(
+    text: str, labels: np.ndarray, tokenizer: PreTrainedTokenizerFast
+) -> Tuple[np.ndarray, np.ndarray]:
+    """Transforms text into numerical tokens. Also expand word-level labels into token-level labels
+
+    Args:
+        text (str): Text that will be tokenized (TODO: Change to array)
+        labels (np.ndarray): Word-level labels for text to be tokenized. Word is defined via space spearation
+        tokenizer (PreTrainedTokenizerFast): Tokenizer that will be used for tokenization
+
+    Returns:
+        np.ndarray: 2-dimensional array with tokens (without cls and sep tokens!)
+        np.ndarray 2-dimensional array with token-level labels
+    """
+    text_tokenized = tokenizer(text, return_offsets_mapping=True)
+
+    offset_mappings = text_tokenized["offset_mapping"][1:-1]
+    input_ids = text_tokenized["input_ids"][1:-1]
+
+    # Create a map where each character is assigned the index of its word
+    words_mapping = []
+    actual_word = 0
+    for character in text:
+        words_mapping.append(actual_word)
+        if character == " ":
+            actual_word += 1
+
+    # Assign each token to a word
+    token_mapping = [words_mapping[x[0]] for x in offset_mappings]
+
+    # Expand word-based labels to token-based labels
+    labels_tokenized = [labels[i] for i in token_mapping]
+
+    return np.array(input_ids).reshape(-1, 1), np.array(labels_tokenized)
+
+
+def recover_word(word: str, action: Mapping[str, bool]) -> str:
+    """Applies action to a word
 
     Args:
-        text (str): Raw, unpuctuated text
-        chunk_size (int): Maxium number of tokens to precess at once (both memory & computing scales ~O(n^2))
-        tokenizer (BertTokenizerFast): Tokenizer to use
-        model (nn.Module): Trained actions model
-        threshold (float, optional): Threshold after which action will be applied. Defaults to 0.9.
+        word (str): Word to which the action will be applied
+        action (Mapping[str, bool]): Action to be applied
 
     Returns:
-        str: [description]
+        str: Transformed word
     """
+    word_result = word
+
+    if action["dot"]:
+        word_result += "."
+    if action["upper_case"]:
+        word_result = word_result.capitalize()
+    if action["colon"]:
+        word_result += ","
+    if action["question_mark"]:
+        word_result += "?"
 
-    text = text.strip()
+    return word_result
 
-    tokens = tokenizer(text, return_tensors="pt")["input_ids"]
-    output = None
 
-    index_start = 0
-    while index_start < len(tokens[0]):
-        index_end = min(index_start + chunk_size, len(tokens[0]))
+def is_sentence_end(actions_encoded: np.ndarray) -> bool:
+    """Returns if given action would end a sentence
 
-        tokens_chunk = tokens[:, index_start:index_end]
+    Args:
+        actions_encoded (np.ndarray): Action vector
 
-        raw_output = model(
-            input_ids=tokens_chunk,
-            token_type_ids=torch.zeros_like(tokens_chunk),
-            attention_mask=torch.ones_like(tokens_chunk),
-        )[0].sigmoid()
-        raw_output = raw_output[0].detach().numpy()
+    Returns:
+        bool: True if action would end a sentence, False otherwise
+    """
+    actions_decoded = decode_actions(actions_encoded)
 
-        actions = raw_output > threshold
-        offset = last_stop_label(actions, action_vector("dot"))
+    return actions_decoded["dot"] is True
 
-        # Prevent infinite loop
-        if (offset is None) or (offset == 0):
-            offset = index_end - index_start
 
-        if output is None:
-            output = raw_output[0:offset]
+def nearest_sentence_l(labels: np.array, index_start: int) -> int:
+    """Find nearest word that begins a sentence that has lower or equal index to index_start
+
+    Args:
+        labels (np.array): 2-dimensonal array of action-vectors
+        index_start (int): Index from which search will be started
+
+    Returns:
+        int: Index of nearest left-oriented start of the sentence. If no sentence is found, first index is assumed to
+             start a sentence
+    """
+    result_index = index_start
+
+    while result_index > 0:
+        if is_sentence_end(labels[result_index, :]):
+            # prevent being in the middle of a token
+            result_index -= 1
+        elif is_sentence_end(labels[result_index - 1, :]):
+            break
+        elif result_index == 1:
+            result_index = 0
+            break
+        else:
+            result_index -= 1
+
+    return result_index
+
+
+def nearest_sentence_r(labels: np.array, index_start: int) -> Optional[int]:
+    """Find nearest word that begins a sentence that has higher or equal index to index_start
+
+    Args:
+        labels (np.array): 2-dimensonal array of action-vectors
+        index_start (int): Index from which search will be started
+
+    Returns:
+        int: Index of nearest right-oriented start of the sentence. None if no later sentence is found
+    """
+    result_index = index_start
+
+    while result_index < len(labels):
+        if is_sentence_end(labels[result_index - 1]) and not is_sentence_end(
+            labels[result_index]
+        ):
+            break
+        else:
+            result_index += 1
+
+    if result_index >= len(labels):
+        return None
+    else:
+        return result_index
+
+
+def batchify_labels(
+    labels: np.ndarray, max_tokens: int, min_tokens: int = 3
+) -> List[np.ndarray]:
+    """Splits long labels array into batches of desired size
+
+    Args:
+        labels (np.ndarray): 2-dimensional array of action-vectors
+        max_tokens (int): Maximum number of labels in a single batch
+        min_tokens (int, optional): Minimum number of labels in a single batch. Defaults to 3.
+
+    Returns:
+        List[np.ndarray]: List of arrays with indexes composing each batch
+    """
+    assert min_tokens >= 1
+    assert max_tokens >= 1
+
+    labels_batches = []
+
+    index = 0
+    new_index = 0
+    while index < (labels.shape[0] - min_tokens):
+        num_consumed = min(max_tokens, labels.shape[0] - index)
+
+        assert num_consumed >= min_tokens
+
+        if index + num_consumed < (labels.shape[0] - min_tokens):
+            new_index = nearest_sentence_l(labels, index + num_consumed)
+            if new_index == index:
+                new_index = nearest_sentence_r(labels, index + num_consumed)
+                if new_index is None:
+                    labels_batches.append(
+                        np.array(list(range(index, index + num_consumed)))
+                    )
+                    break
         else:
-            output = np.concatenate([output, raw_output[0:offset]], axis=0)
+            labels_batches.append(np.array(list(range(index, index + num_consumed))))
+            break
+
+        labels_batches.append(np.array(list(range(index, index + num_consumed))))
+
+        index = new_index
+
+    return labels_batches
+
+
+def add_cls_sep(
+    tokens: np.ndarray, labels: np.ndarray, tokenizer: PreTrainedTokenizerFast
+) -> Tuple[np.ndarray, np.ndarray]:
+    """Adds staring cls and ending sep token ids into tokens & labels
+
+    Args:
+        tokens (np.ndarray): 2-dimensional array (with 1 feature!) of tokens
+        labels (np.ndarray): 2-dimensional array of action vectors
+
+    Returns:
+        np.ndarray: tokens with added cls & sep tokens ids
+        np.ndarray: labels with first and last item duplicated to accommodate cls & sep
+    """
+
+    tokens = np.concatenate(
+        [[[tokenizer.cls_token_id]], tokens, [[tokenizer.sep_token_id]]]
+    )
+    labels = np.concatenate([labels[:1, :], labels, labels[-1:, :]])
+
+    return tokens, labels
+
+
+def add_padding(
+    tokens: np.ndarray,
+    labels: np.ndarray,
+    length: int,
+    tokenizer: PreTrainedTokenizerFast,
+) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+    """Appends padding to tokens and labels to match desired length
+
+    Args:
+        tokens (np.ndarray): Lx1 array of token ids
+        labels (np.ndarray): LxA array of action vectors
+        length (int): Desired length of the output. Must not be lower than L
+        tokenizer (PreTrainedTokenizerFast): Tokenizer that was used for tokenization
+
+    Returns:
+        np.ndarray: (L+P)x1 array of token ids with added padding
+        np.ndarray: (L+P)xA array of action vectors with added padding
+        np.ndarray: (L+P)-length mask array where True marks a real token and False marks padding
+    """
+
+    pad_length = length - tokens.shape[0]
+    assert pad_length >= 0
+
+    if pad_length > 0:
+        tokens = np.concatenate([tokens, [[tokenizer.pad_token_id]] * pad_length])
+        labels = np.concatenate([labels, [empty_action_vector()] * pad_length])
+
+    mask = np.ones(len(tokens)).astype(np.int)
+
+    if pad_length > 0:
+        mask[-pad_length:] = False
+
+    return tokens, labels, mask
+
+
+def batchify_data(
+    tokens: np.ndarray,
+    labels: np.ndarray,
+    max_tokens: int,
+    tokenizer: PreTrainedTokenizerFast,
+    min_tokens: int = 3,
+) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+    """Chops a long tokens-labels pair into smaller chunks of equal length (with added padding)
+
+    Args:
+        tokens (np.ndarray): Tokens representing long, unpunctuated text (Shape L)
+        labels (np.ndarray): Action-labels to transform provided text into punctuated one (Shape LxA)
+        max_tokens (int): Maximum number of tokens in a single entry
+        tokenizer (PreTrainedTokenizerFast): Tokenizer used to tokenize sentence into tokens
+        min_tokens (int, optional): Minimum number of tokens in a sentence. Defaults to 3.
+
+    Returns:
+        Tuple[np.ndarray, np.ndarray, np.ndarray]: (B = number of chunks)
+            tokens_batch - Tokens array split into smaller chunks (shape: B x max_tokens x 1)
+            labels_batch - Labels array split into smaller chunks (shape: B x max_tokens x A)
+            mask_batch - Attention mask for each chunk (shape: B x max_tokens)
+
+    """
+
+    assert max_tokens >= min_tokens + 2
+    assert min_tokens >= 1
+
+    tokens_batch = []
+    labels_batch = []
+    mask_batch = []
+
+    idxs = batchify_labels(labels, max_tokens - 2, min_tokens)
+
+    for ids in idxs:
+        tokens_sample = tokens[ids, :]
+        labels_sample = labels[ids, :]
+
+        assert len(ids) >= min_tokens
+        assert len(ids) <= max_tokens - 2
+
+        tokens_sample, labels_sample = add_cls_sep(
+            tokens_sample, labels_sample, tokenizer
+        )
+
+        assert len(tokens_sample) <= max_tokens
+
+        tokens_sample, labels_sample, mask = add_padding(
+            tokens_sample, labels_sample, max_tokens, tokenizer
+        )
+
+        tokens_batch.append(tokens_sample)
+        labels_batch.append(labels_sample)
+        mask_batch.append(mask)
+
+    return np.array(tokens_batch), np.array(labels_batch), np.array(mask_batch)
+
+
+def recover_text(text: str, actions_encoded: np.ndarray) -> str:
+    """Applies per-word actions to unpunctuated text
+
+    Args:
+        text (str): lowercase, unpunctuated text
+        actions_encoded (np.ndarray): Array of per-word action vectors (Shape LxA)
+
+    Returns:
+        str: Punctuated version of the text
+    """
+    words = text.split(" ")
 
-        index_start += offset
+    words_output = []
 
-    assert len(output) == len(tokens[0])
+    for word, action_encoded in zip(words, actions_encoded.tolist()):
+        action_decoded = decode_actions(np.array(action_encoded))
 
-    word_labels = token_labels_to_word_labels(text, output[1:-1], tokenizer)
-    actions = word_labels > threshold
+        word_recovered = recover_word(word, action_decoded)
+        words_output.append(word_recovered)
 
-    return recover_text(text, actions)
+    return " ".join(words_output)
diff --git a/src/pipelines/actions_based/stage1_extraction.py b/src/pipelines/actions_based/stage1_extraction.py
index cf31c06..5a058a9 100644
--- a/src/pipelines/actions_based/stage1_extraction.py
+++ b/src/pipelines/actions_based/stage1_extraction.py
@@ -1,33 +1,18 @@
 # /usr/bin/python3
 import glob
-import numpy as np
-from src.processing import text_from_xml, create_model_input_output
+
 import dask.dataframe as dd
+import numpy as np
 import pandas as pd
 from dask.distributed import Client
-from src.utils import get_config, PROJECT_ROOT, prepare_folder
+
+from src.pipelines.actions_based.processing import APPLY_FILE_PROCESSING_META, apply_file_processing
+from src.utils import PROJECT_ROOT, get_config, prepare_folder
 
 INPUT_FOLDER = f"{PROJECT_ROOT}/data"
 OUTPUT_FOLDER = f"{PROJECT_ROOT}/generated/actions/stage1_extraction"
 
 
-def process_file(x):
-    full_text = text_from_xml(x.file)
-
-    if len(full_text) > 0:
-        model_input, model_output = create_model_input_output(full_text)
-
-        output_shape = np.array(model_output.shape, dtype=np.int)
-
-        return {
-            "input": model_input,
-            "output": model_output.reshape(-1),
-            "output_shape": output_shape,
-        }
-    else:
-        return {"input": None, "output": None, "output_shape": None}
-
-
 if __name__ == "__main__":
 
     config = get_config()
@@ -50,10 +35,10 @@ if __name__ == "__main__":
     df = dd.from_pandas(pd.DataFrame({"file": files_paths}), npartitions=num_partitions)
 
     df = df.apply(
-        process_file,
+        apply_file_processing,
         result_type="expand",
         axis=1,
-        meta={"input": str, "output": object, "output_shape": object},
+        meta=APPLY_FILE_PROCESSING_META,
     )
     df = df.dropna()
 
diff --git a/src/pipelines/actions_based/stage2_tokenization.py b/src/pipelines/actions_based/stage2_tokenization.py
index affa6e9..0ea3586 100644
--- a/src/pipelines/actions_based/stage2_tokenization.py
+++ b/src/pipelines/actions_based/stage2_tokenization.py
@@ -1,17 +1,11 @@
 # /usr/bin/python3
-from src.utils import (
-    PROJECT_ROOT,
-    get_config,
-    prepare_folder,
-)
 import dask
 import dask.dataframe as dd
-from transformers import BertTokenizerFast
 from dask.distributed import Client
-from src.pipelines.actions_based.processing import (
-    apply_tokenization,
-    APPLY_TOKENIZATION_META,
-)
+from transformers import BertTokenizerFast
+
+from src.pipelines.actions_based.processing import APPLY_TOKENIZATION_META, apply_tokenization
+from src.utils import PROJECT_ROOT, get_config, prepare_folder
 
 INPUT_FOLDER = f"{PROJECT_ROOT}/generated/actions/stage1_extraction"
 OUTPUT_FOLDER = f"{PROJECT_ROOT}/generated/actions/stage2_tokenization"
diff --git a/src/pipelines/actions_based/stage3_exploding.py b/src/pipelines/actions_based/stage3_exploding.py
index 2c270ce..81dc965 100644
--- a/src/pipelines/actions_based/stage3_exploding.py
+++ b/src/pipelines/actions_based/stage3_exploding.py
@@ -1,13 +1,9 @@
 # /usr/bin/python3
 import dask.dataframe as dd
 from dask.distributed import Client
+
+from src.processing import EXPAND_DIMS_META, FLATTEN_DIMS_META, expand_dims, flatten_dims
 from src.utils import PROJECT_ROOT, get_config, prepare_folder
-from src.pipelines.actions_based.processing import (
-    expand_dims,
-    EXPAND_DIMS_META,
-    flatten_dims,
-    FLATTEN_DIMS_META,
-)
 
 INPUT_FOLDER = f"{PROJECT_ROOT}/generated/actions/stage2_tokenization"
 OUTPUT_FOLDER = f"{PROJECT_ROOT}/generated/actions/stage3_exploding"
diff --git a/src/pipelines/actions_based/stage4_reindexing.py b/src/pipelines/actions_based/stage4_reindexing.py
index 3fcaa76..fade725 100644
--- a/src/pipelines/actions_based/stage4_reindexing.py
+++ b/src/pipelines/actions_based/stage4_reindexing.py
@@ -1,6 +1,7 @@
 # /usr/bin/python3
 import dask.dataframe as dd
 from dask.distributed import Client
+
 from src.utils import PROJECT_ROOT, get_config, prepare_folder
 
 INPUT_FOLDER = f"{PROJECT_ROOT}/generated/actions/stage3_exploding"
@@ -24,14 +25,5 @@ if __name__ == "__main__":
     idx = (df.ones.cumsum() - 1).persist()
     df = df.assign(ones=idx)
 
-    # Shuffle
-    # shuffled_idx = idx.compute().values
-    # np.random.shuffle(shuffled_idx)
-    # shuffled_idx = client.scatter(shuffled_idx)
-    # mapped_ones = df.ones.apply(lambda x, idx: idx[x], args=(shuffled_idx,), meta=('ones', 'int64'))
-    # df = df.assign(ones=mapped_ones)
-
-    # df = df.persist()
-
     df = df.set_index("ones", shuffle="tasks")
     df.to_parquet(OUTPUT_FOLDER, engine="pyarrow")
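
A minimal sketch of the reindexing idiom kept above (a column of ones is turned
into a dense 0..N-1 index via cumulative sum, then set_index redistributes the
rows with the task-based shuffle; column names are illustrative):

    import dask.dataframe as dd
    import pandas as pd

    df = dd.from_pandas(
        pd.DataFrame({"value": list("abcd"), "ones": [1] * 4}), npartitions=2
    )
    # Consecutive ids 0..N-1 from the cumulative sum of ones
    df = df.assign(ones=df.ones.cumsum() - 1)
    # Sort/partition by the new dense index
    df = df.set_index("ones", shuffle="tasks")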
diff --git a/src/pipelines/actions_based/stage5_stats.py b/src/pipelines/actions_based/stage5_stats.py
index 3dd895b..d7f3ea2 100644
--- a/src/pipelines/actions_based/stage5_stats.py
+++ b/src/pipelines/actions_based/stage5_stats.py
@@ -1,14 +1,13 @@
 # /usr/bin/python3
-from src.processing import ACTIONS_KEYS
+import pickle
+
 import dask.dataframe as dd
 import numpy as np
 from dask.distributed import Client
+
+from src.pipelines.actions_based.processing import ACTIONS_KEYS
+from src.processing import EXPAND_DIMS_META, expand_dims
 from src.utils import PROJECT_ROOT, get_config, prepare_folder
-import pickle
-from src.pipelines.actions_based.processing import (
-    expand_dims,
-    EXPAND_DIMS_META,
-)
 
 INPUT_FOLDER = f"{PROJECT_ROOT}/generated/actions/stage4_reindexing"
 OUTPUT_FOLDER = f"{PROJECT_ROOT}/generated/actions/stage5_stats"
diff --git a/src/pipelines/actions_based/train.py b/src/pipelines/actions_based/train.py
index 45c300c..0a3ac69 100755
--- a/src/pipelines/actions_based/train.py
+++ b/src/pipelines/actions_based/train.py
@@ -1,22 +1,19 @@
 #!/usr/bin/python3
 
-from transformers import BertTokenizerFast, BertForTokenClassification
-import torch
-from torch.nn import BCEWithLogitsLoss
-import numpy as np
-import dask.dataframe as dd
 import glob
-from src.utils import (
-    PROJECT_ROOT,
-    get_config,
-    convert_to_timedelta,
-    prepare_folder,
-)
-from src.training import latest_model, save_training_step
-from src.processing import ACTIONS_KEYS
-from datetime import datetime
 import pickle
+from datetime import datetime
+
+import dask.dataframe as dd
+import numpy as np
+import torch
+from torch.nn import BCEWithLogitsLoss
+from transformers import BertForTokenClassification, BertTokenizerFast
+
 from src.batch_loading import get_batches
+from src.pipelines.actions_based.processing import ACTIONS_KEYS
+from src.training import latest_model, save_training_step
+from src.utils import PROJECT_ROOT, convert_to_timedelta, get_config, prepare_folder
 
 INPUT_PATH = f"{PROJECT_ROOT}/generated/actions/stage4_reindexing"
 INPUT_STATS_PATH = f"{PROJECT_ROOT}/generated/actions/stage5_stats"
diff --git a/src/pipelines/actions_based/utils.py b/src/pipelines/actions_based/utils.py
index d94a4e6..4f62b61 100644
--- a/src/pipelines/actions_based/utils.py
+++ b/src/pipelines/actions_based/utils.py
@@ -1,8 +1,17 @@
-from transformers import BertTokenizerFast, BertForTokenClassification, PretrainedConfig
-from src.processing import ACTIONS_KEYS
+from typing import Tuple
+
+import numpy as np
 import torch
 import torch.nn as nn
-from typing import Tuple
+from transformers import BertForTokenClassification, BertTokenizerFast, PretrainedConfig
+
+from src.pipelines.actions_based.processing import (
+    action_vector,
+    last_stop_label,
+    recover_text,
+    token_labels_to_word_labels,
+)
+from src.processing import ACTIONS_KEYS
 
 
 def load_model(
@@ -28,3 +37,63 @@ def load_model(
     model.load_state_dict(torch.load(model_path, map_location=device))
 
     return tokenizer, model
+
+
+def apply_actions_punctuation(
+    text: str,
+    chunk_size: int,
+    tokenizer: BertTokenizerFast,
+    model: nn.Module,
+    threshold: float = 0.9,
+) -> str:
+    """Adds punctuation to text using actions model
+
+    Args:
+        text (str): Raw, unpunctuated text
+        chunk_size (int): Maximum number of tokens to process at once (both memory & compute scale ~O(n^2))
+        tokenizer (BertTokenizerFast): Tokenizer to use
+        model (nn.Module): Trained actions model
+        threshold (float, optional): Threshold above which an action will be applied. Defaults to 0.9.
+
+    Returns:
+        str: Punctuated version of the text
+    """
+
+    text = text.strip()
+
+    tokens = tokenizer(text, return_tensors="pt")["input_ids"]
+    output = None
+
+    index_start = 0
+    while index_start < len(tokens[0]):
+        index_end = min(index_start + chunk_size, len(tokens[0]))
+
+        tokens_chunk = tokens[:, index_start:index_end]
+
+        raw_output = model(
+            input_ids=tokens_chunk,
+            token_type_ids=torch.zeros_like(tokens_chunk),
+            attention_mask=torch.ones_like(tokens_chunk),
+        )[0].sigmoid()
+        raw_output = raw_output[0].detach().numpy()
+
+        actions = raw_output > threshold
+        offset = last_stop_label(actions, action_vector("dot"))
+
+        # Prevent infinite loop
+        if (offset is None) or (offset == 0):
+            offset = index_end - index_start
+
+        if output is None:
+            output = raw_output[0:offset]
+        else:
+            output = np.concatenate([output, raw_output[0:offset]], axis=0)
+
+        index_start += offset
+
+    assert len(output) == len(tokens[0])
+
+    word_labels = token_labels_to_word_labels(text, output[1:-1], tokenizer)
+    actions = word_labels > threshold
+
+    return recover_text(text, actions)
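
A usage sketch for the two helpers above (the checkpoint path is hypothetical;
the base-model name matches the one used elsewhere in this repo):

    from src.pipelines.actions_based.utils import (
        apply_actions_punctuation,
        load_model,
    )

    # Hypothetical checkpoint path
    tokenizer, model = load_model(
        "checkpoints/actions/model.ckpt",
        "dkleczek/bert-base-polish-cased-v1",
        "cpu",
    )

    raw = "ala ma kota czy na pewno"
    punctuated = apply_actions_punctuation(
        raw, chunk_size=500, tokenizer=tokenizer, model=model, threshold=0.9
    )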
diff --git a/src/pipelines/translation_based/processing.py b/src/pipelines/translation_based/processing.py
index 4dfddc1..41962da 100644
--- a/src/pipelines/translation_based/processing.py
+++ b/src/pipelines/translation_based/processing.py
@@ -1,6 +1,9 @@
-from src.processing import text_from_xml, remove_punctuation
-from transformers import BertTokenizerFast
+from typing import Tuple
+
 import numpy as np
+from transformers import BertTokenizerFast
+
+from src.pipelines.actions_based.processing import remove_punctuation, text_from_xml
 
 
 def raw_to_dataframe(entry: dict) -> dict:
@@ -78,54 +81,6 @@ GENERATE_BATCHES_META = {
 }
 
 
-def expand_dims(entry):
-    source = entry.source.reshape(entry.source_shape)
-    target = entry.target.reshape(entry.target_shape)
-    mask = entry.attention_mask.reshape(entry.attention_mask_shape)
-
-    return {
-        "source": source,
-        "target": target,
-        "attention_mask": mask,
-    }
-
-
-EXPAND_DIMS_META = {
-    "source": object,
-    "target": object,
-    "attention_mask": object,
-}
-
-
-def flatten_dims(entry):
-    source_shape = np.array(entry.source.shape)
-    target_shape = np.array(entry.target.shape)
-    mask_shape = np.array(entry.attention_mask.shape)
-
-    source = entry.source.reshape(-1)
-    target = entry.target.reshape(-1)
-    mask = entry.attention_mask.reshape(-1)
-
-    return {
-        "source": source,
-        "target": target,
-        "attention_mask": mask,
-        "source_shape": source_shape,
-        "target_shape": target_shape,
-        "attention_mask_shape": mask_shape,
-    }
-
-
-FLATTEN_DIMS_META = {
-    "source": object,
-    "target": object,
-    "attention_mask": object,
-    "source_shape": object,
-    "target_shape": object,
-    "attention_mask_shape": object,
-}
-
-
 def find_new_sentence_left(seq: np.array, pos: int) -> int:
     """Finds nerest sentence on the left of the current position (including current position)
 
@@ -273,7 +228,7 @@ def standarize_translation_sample(
 
 def create_input_output(
     tokens: np.ndarray, length: int, tokenizer: BertTokenizerFast
-) -> (np.ndarray, np.ndarray):
+) -> Tuple[np.ndarray, np.ndarray]:
     """Transforms a sequence of tokens into "translation" input and output
 
     Args:
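
A side note on the signature change above: a parenthesized tuple of types is
legal Python syntax but carries no meaning for type checkers, whereas
typing.Tuple expresses the intent. A minimal illustration (the function below
is illustrative, not from this repo):

    from typing import Tuple

    import numpy as np

    def halves(xs: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        # Split an array into its first and second half
        mid = len(xs) // 2
        return xs[:mid], xs[mid:]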
diff --git a/src/pipelines/translation_based/stage1_extraction.py b/src/pipelines/translation_based/stage1_extraction.py
index 8581407..386211d 100644
--- a/src/pipelines/translation_based/stage1_extraction.py
+++ b/src/pipelines/translation_based/stage1_extraction.py
@@ -1,14 +1,13 @@
 # /usr/bin/python3
-from src.pipelines.translation_based.processing import (
-    raw_to_dataframe,
-    RAW_TO_DATAFRAME_META,
-)
-from src.utils import PROJECT_ROOT, prepare_folder, get_config
 from glob import glob
-import numpy as np
-from dask.distributed import Client
+
 import dask.dataframe as dd
+import numpy as np
 import pandas as pd
+from dask.distributed import Client
+
+from src.pipelines.translation_based.processing import RAW_TO_DATAFRAME_META, raw_to_dataframe
+from src.utils import PROJECT_ROOT, get_config, prepare_folder
 
 INPUT_FOLDER = f"{PROJECT_ROOT}/data"
 OUTPUT_FOLDER = f"{PROJECT_ROOT}/generated/translations/stage1_extraction"
diff --git a/src/pipelines/translation_based/stage2_create_batches.py b/src/pipelines/translation_based/stage2_create_batches.py
index c1b7b0d..ade8bf2 100644
--- a/src/pipelines/translation_based/stage2_create_batches.py
+++ b/src/pipelines/translation_based/stage2_create_batches.py
@@ -1,13 +1,11 @@
 # /usr/bin/python3
-from src.pipelines.translation_based.processing import (
-    generate_batches,
-    GENERATE_BATCHES_META,
-)
-from src.utils import PROJECT_ROOT, prepare_folder, get_config
-from dask.distributed import Client
-from transformers import BertTokenizerFast
 import dask.dataframe as dd
 from dask import delayed
+from dask.distributed import Client
+from transformers import BertTokenizerFast
+
+from src.pipelines.translation_based.processing import GENERATE_BATCHES_META, generate_batches
+from src.utils import PROJECT_ROOT, get_config, prepare_folder
 
 INPUT_FOLDER = f"{PROJECT_ROOT}/generated/translations/stage1_extraction"
 OUTPUT_FOLDER = f"{PROJECT_ROOT}/generated/translations/stage2_create_batches"
diff --git a/src/pipelines/translation_based/stage3_exploding.py b/src/pipelines/translation_based/stage3_exploding.py
index 74646a5..d969dd1 100644
--- a/src/pipelines/translation_based/stage3_exploding.py
+++ b/src/pipelines/translation_based/stage3_exploding.py
@@ -1,12 +1,13 @@
 # /usr/bin/python3
+import dask.dataframe as dd
+from dask.distributed import Client
+
 from src.pipelines.translation_based.processing import (
-    flatten_dims,
-    expand_dims,
-    FLATTEN_DIMS_META,
     EXPAND_DIMS_META,
+    FLATTEN_DIMS_META,
+    expand_dims,
+    flatten_dims,
 )
-import dask.dataframe as dd
-from dask.distributed import Client
 from src.utils import PROJECT_ROOT, get_config, prepare_folder
 
 INPUT_FOLDER = f"{PROJECT_ROOT}/generated/translations/stage2_create_batches"
diff --git a/src/pipelines/translation_based/stage4_reindexing.py b/src/pipelines/translation_based/stage4_reindexing.py
index ffad285..6bbb541 100644
--- a/src/pipelines/translation_based/stage4_reindexing.py
+++ b/src/pipelines/translation_based/stage4_reindexing.py
@@ -1,6 +1,7 @@
 # /usr/bin/python3
 import dask.dataframe as dd
 from dask.distributed import Client
+
 from src.utils import PROJECT_ROOT, get_config, prepare_folder
 
 INPUT_FOLDER = f"{PROJECT_ROOT}/generated/translations/stage3_exploding"
diff --git a/src/pipelines/translation_based/train.py b/src/pipelines/translation_based/train.py
index f6f5f4d..6e39ecc 100755
--- a/src/pipelines/translation_based/train.py
+++ b/src/pipelines/translation_based/train.py
@@ -1,20 +1,17 @@
 #!/usr/bin/python3
 
-import torch
-import numpy as np
-import dask.dataframe as dd
 import glob
-from transformers import BertTokenizerFast
-from src.utils import (
-    PROJECT_ROOT,
-    get_config,
-    convert_to_timedelta,
-    prepare_folder,
-)
-from src.training import latest_model, save_training_step
 from datetime import datetime
-from src.models.TransformerSeq2Seq import TransformerSeq2Seq
+
+import dask.dataframe as dd
+import numpy as np
+import torch
+from transformers import BertTokenizerFast
+
 from src.batch_loading import get_batches, get_ordered_dataframe_len
+from src.models.TransformerSeq2Seq import TransformerSeq2Seq
+from src.training import latest_model, save_training_step
+from src.utils import PROJECT_ROOT, convert_to_timedelta, get_config, prepare_folder
 
 INPUT_PATH = f"{PROJECT_ROOT}/generated/translations/stage4_reindexing"
 OUTPUT_PATH = f"{PROJECT_ROOT}/checkpoints/translations"
diff --git a/src/processing.py b/src/processing.py
index d61be86..2777416 100644
--- a/src/processing.py
+++ b/src/processing.py
@@ -1,485 +1,65 @@
-from xml.etree import ElementTree as ET
-from typing import Optional, Mapping
-from src.utils import remove_punctuation
 import numpy as np
-from transformers import PreTrainedTokenizerFast
-from collections import defaultdict
 
-ACTIONS_KEYS = ["dot", "upper_case", "colon", "question_mark"]
 
-
-def empty_action_vector() -> np.ndarray:
-    """Returns a do-nothing actions vector
-
-    Returns:
-        np.ndarray: Vector with all zeroes and length of ACTION_KEYS
-    """
-    return np.zeros(len(ACTIONS_KEYS))
-
-
-def empty_action_dict() -> dict:
-    """Returns a do-noting unencoded action dict
-
-    Returns:
-        dict: Action dict with all actions set to False
-    """
-
-    return decode_actions(empty_action_vector())
-
-
-def text_from_xml(path: str) -> str:
-    """Extract spoken text from dataset's xml format
-
-    Args:
-        path (str): Path to xml
-
-    Returns:
-        str: Raw text
-    """
-    root = ET.parse(path).getroot()
-
-    full_text = ""
-
-    for node in root.iter("*"):
-        if len(node) == 0:
-            who = node.get("who")
-            text = node.text
-
-            if text is not None and who is not None and who != "#komentarz":
-                full_text = " ".join([full_text, text])
-
-    del root
-
-    return full_text
-
-
-def detect_actions(word: str, next_word: Optional[str]) -> Mapping[str, bool]:
-    """Detect what actions should model perform on a word and returns encoded
-       action vector
+def expand_dims(entry) -> dict:
+    """Reshapes flat source, target, mask arrays into corresponding shapes
 
     Args:
-        word (str): Word on wich action is decided
-        next_word (Optional[str]): Word that follows considered word. Can be
-            None if nothing follows a word
+        entry (dict): Dask dataframe row with columns: source, target, attention_mask, source_shape, target_shape, attention_mask_shape
 
     Returns:
-        Mapping[str, bool]: Mapping telling if each of possible actions should be performed (True) or not (False)
+        dict: Dask dataframe row with columns: source, target, attention_mask
     """
-    # Unsuported characters
-    word.replace(";", ".")
-    word.replace('"', "")
-    word.replace("(", "")
-    word.replace(")", "")
+    source = entry.source.reshape(entry.source_shape)
+    target = entry.target.reshape(entry.target_shape)
+    mask = entry.attention_mask.reshape(entry.attention_mask_shape)
 
-    while len(word) > 0 and not word[0].isalnum():  # remove proceding characters
-        word = word[1:]
-
-    if len(word) == 0:
-        return dict(zip(ACTIONS_KEYS, [False] * len(ACTIONS_KEYS)))
-
-    actions = {
-        "dot": word[-1] == ".",
-        "upper_case": word[0].isupper(),
-        "colon": word[-1] == ",",
-        "question_mark": word[-1] == "?",
+    return {
+        "source": source,
+        "target": target,
+        "attention_mask": mask,
     }
 
-    return actions
-
-
-def encode_actions(actions: Mapping[str, bool]) -> np.ndarray:
-    """Transforms actions into vector
-
-    Args:
-        actions (Mapping[str, bool]): Map telling which actions should be made
-
-    Returns:
-        np.ndarray: 1 dimensional action vector
-    """
-    return np.array(list(actions.values())).astype(float)
-
-
-def decode_actions(encoded_actions: np.ndarray) -> Mapping[str, bool]:
-    """Decodes actions
-
-    Args:
-        encoded_actions (np.ndarray): 1 dimensional action vector
-
-    Returns:
-        Mapping[str, bool]: Map telling which actions should be made
-    """
-    assert encoded_actions.shape[0] == len(ACTIONS_KEYS)
-
-    return dict(zip(ACTIONS_KEYS, encoded_actions.astype(np.bool).tolist()))
-
-
-def create_model_input_output(text: str) -> (str, np.ndarray):
-    """Returns a pair of input and desired output of the model
-
-    Args:
-        text (str): Correct text sample
-
-    Returns:
-        text_cleaned (str): Text without any interpuction and all lowercase
-        actions (np.ndarray): To dimensional array, where each row is aciton vector for each word (columns)
-    """
-    words = text.split(" ")
-
-    words_output = []
-    actions_output = []
-
-    i = 0
-    while i < len(words):
-        word = words[i]
-        next_word = words[i + 1] if len(words) > i + 1 else None
-
-        word_sanitized = remove_punctuation(word).lower()
-        if len(word_sanitized) > 0:
-            actions = detect_actions(word, next_word)
-            actions_encoded = encode_actions(actions)
-
-            words_output.append(word_sanitized)
-            actions_output.append(actions_encoded)
-
-        i += 1
-
-    assert len(words_output) == len(actions_output)
-
-    return " ".join(words_output), np.array(actions_output)
-
-
-def token_word_mapping(text: str, tokenizer: PreTrainedTokenizerFast) -> np.ndarray:
-    """Returns mapping where each token is labeled with index of word it's part of
-
-    Args:
-        text (str): Input text
-        tokenizer (PreTrainedTokenizerFast): Tokenizer used to tokenize text
-
-    Returns:
-        np.ndarray: Array of length L (number of tokens) where each entry is index of word (cls and sep labels are not counted).
-    """
-    text_tokenized = tokenizer(text, return_offsets_mapping=True)
-    offset_mappings = text_tokenized["offset_mapping"][1:-1]
-
-    offset_mappings = text_tokenized["offset_mapping"][1:-1]
-
-    # Create a map where each character is assigned index of it's word
-    words_mapping = []
-    actual_word = 0
-    for character in text:
-        words_mapping.append(actual_word)
-        if character == " ":
-            actual_word += 1
-
-    token_mapping = [words_mapping[x[0]] for x in offset_mappings]
-
-    return np.array(token_mapping)
-
-
-def token_labels_to_word_labels(
-    text: str, token_labels: np.ndarray, tokenizer: PreTrainedTokenizerFast
-) -> np.ndarray:
-    mapping = token_word_mapping(text, tokenizer)
-
-    assert len(mapping) == len(token_labels)
-
-    labels = defaultdict(list)
-
-    for i in range(len(mapping)):
-        labels[mapping[i]].append(token_labels[i])
-
-    return np.array([np.mean(labels[x], axis=0) for x in sorted(labels)])
-
-
-def tokenize_labeled_text(
-    text: str, labels: np.ndarray, tokenizer: PreTrainedTokenizerFast
-) -> (np.ndarray, np.ndarray):
-    """Transforms text into numerical tokens. Also expand word-level labels into token-level labels
-
-    Args:
-        text (str): Text that will be tokenized (TODO: Change to array)
-        labels (np.ndarray): Word-level labels for text to be tokenized. Word is defined via space spearation
-        tokenizer (PreTrainedTokenizerFast): Tokenizer that will be used for tokenization
-
-    Returns:
-        np.ndarray: 2-dimensional array with tokens (without cls and sep tokens!)
-        np.ndarray 2-dimensional array with token-level labels
-    """
-    text_tokenized = tokenizer(text, return_offsets_mapping=True)
-
-    offset_mappings = text_tokenized["offset_mapping"][1:-1]
-    input_ids = text_tokenized["input_ids"][1:-1]
-
-    # Create a map where each character is assigned index of it's word
-    words_mapping = []
-    actual_word = 0
-    for character in text:
-        words_mapping.append(actual_word)
-        if character == " ":
-            actual_word += 1
-
-    # Assign each token to a word
-    token_mapping = [words_mapping[x[0]] for x in offset_mappings]
-
-    # Expand word-based labels to token-based labels
-    labels_tokenized = [labels[i] for i in token_mapping]
 
-    return np.array(input_ids).reshape(-1, 1), np.array(labels_tokenized)
+EXPAND_DIMS_META = {
+    "source": object,
+    "target": object,
+    "attention_mask": object,
+}
 
 
-def recover_word(word: str, action: Mapping[str, bool]) -> str:
-    """Applies action to a word
+def flatten_dims(entry: dict) -> dict:
+    """Flattens arrays in dataframe rows into 1D and saves shapes into separate columns
 
     Args:
-        word (str): word on which action will be applied
-        action (Mapping[str, bool]): Action to be applied
+        entry (dict): Dask dataframe row with columns: source, target, attention_mask
 
     Returns:
-        str: transfomed word
+        dict: Dask dataframe row with columns: source, target, attention_mask, source_shape, target_shape, attention_mask_shape
     """
-    word_result = word
+    source_shape = np.array(entry.source.shape)
+    target_shape = np.array(entry.target.shape)
+    mask_shape = np.array(entry.attention_mask.shape)
 
-    if action["dot"]:
-        word_result += "."
-    if action["upper_case"]:
-        word_result = word_result.capitalize()
-    if action["colon"]:
-        word_result += ","
-    if action["question_mark"]:
-        word_result += "?"
+    source = entry.source.reshape(-1)
+    target = entry.target.reshape(-1)
+    mask = entry.attention_mask.reshape(-1)
 
-    return word_result
-
-
-def is_sentence_end(actions_encoded: np.ndarray) -> bool:
-    """Returns if given action would end a sentence
-
-    Args:
-        actions_encoded (np.ndarray): Action vector
-
-    Returns:
-        bool: True if action would end a sentence, False otherwise
-    """
-    actions_decoded = decode_actions(actions_encoded)
-
-    return actions_decoded["dot"] is True
-
-
-def nearest_sentence_l(labels: np.array, index_start: int) -> int:
-    """Find nearest word that begins a sentence that has lower or equal index to index_start
-
-    Args:
-        labels (np.array): 2-dimensonal array of action-vectors
-        index_start (int): Index from which search will be started
-
-    Returns:
-        int: Index of nearest left-oriented start of the sentence. If no sentence is found, first index is assumed to
-             start a sentence
-    """
-    result_index = index_start
-
-    while result_index > 0:
-        if is_sentence_end(labels[result_index, :]):
-            # prevent beeing in the middle of token
-            result_index -= 1
-        elif is_sentence_end(labels[result_index - 1, :]):
-            break
-        elif result_index == 1:
-            result_index = 0
-            break
-        else:
-            result_index -= 1
-
-    return result_index
-
-
-def nearest_sentence_r(labels: np.array, index_start: int) -> Optional[int]:
-    """Find nearest word that begins a sentence that has higher or equal index to index_start
-
-    Args:
-        labels (np.array): 2-dimensonal array of action-vectors
-        index_start (int): Index from which search will be started
-
-    Returns:
-        int: Index of nearest right-oriented start of the sentence. None if no later sentence is found
-    """
-    result_index = index_start
-
-    while result_index < len(labels):
-        if is_sentence_end(labels[result_index - 1]) and not is_sentence_end(
-            labels[result_index]
-        ):
-            break
-        else:
-            result_index += 1
-
-    if result_index >= len(labels):
-        return None
-    else:
-        return result_index
-
-
-def batchify_labels(
-    labels: np.ndarray, max_tokens: int, min_tokens: int = 3
-) -> [np.ndarray]:
-    """Splits long labels array into batches of desired size
-
-    Args:
-        labels (np.ndarray): 2-dimensional array of action-vectors
-        max_tokens (int): Maximum number of labels in a single batch
-        min_tokens (int, optional): Minimum number of labels in a single batch. Defaults to 3.
-
-    Returns:
-        [np.ndarray]: List of arrays with indexes composing each batch
-    """
-    assert min_tokens >= 1
-    assert max_tokens >= 1
-
-    labels_batches = []
-
-    index = 0
-    new_index = 0
-    while index < (labels.shape[0] - min_tokens):
-        num_consumed = min(max_tokens, labels.shape[0] - index)
-
-        assert num_consumed >= min_tokens
-
-        if index + num_consumed < (labels.shape[0] - min_tokens):
-            new_index = nearest_sentence_l(labels, index + num_consumed)
-            if new_index == index:
-                new_index = nearest_sentence_r(labels, index + num_consumed)
-                if new_index is None:
-                    labels_batches.append(
-                        np.array(list(range(index, index + num_consumed)))
-                    )
-                    break
-        else:
-            labels_batches.append(np.array(list(range(index, index + num_consumed))))
-            break
-
-        labels_batches.append(np.array(list(range(index, index + num_consumed))))
-
-        index = new_index
-
-    return labels_batches
-
-
-def add_cls_sep(
-    tokens: np.ndarray, labels: np.ndarray, tokenizer: PreTrainedTokenizerFast
-) -> (np.ndarray, np.ndarray):
-    """Adds staring cls and ending sep token ids into tokens & labels
-
-    Args:
-        tokens (np.ndarray): 2-dimensional array (with 1 feature!) of tokens
-        labels (np.ndarray): 2-dimensional array of action vectors
-
-    Returns:
-        np.ndarray: tokens with added cls & sep tokens ids
-        np.ndarray: labels with first and last item duplicated to accomodate for cls & sep
-    """
-
-    tokens = np.concatenate(
-        [[[tokenizer.cls_token_id]], tokens, [[tokenizer.sep_token_id]]]
-    )
-    labels = np.concatenate([labels[:1, :], labels, labels[-1:, :]])
-
-    return tokens, labels
-
-
-def add_padding(
-    tokens: np.ndarray,
-    labels: np.ndarray,
-    length: int,
-    tokenizer: PreTrainedTokenizerFast,
-) -> (np.ndarray, np.ndarray, np.ndarray):
-    """Appends padding to tokens and labels to match desired length
-
-    Args:
-        tokens (np.ndarray): Lx1 array of token ids
-        labels (np.ndarray): LxA array of action vectors
-        length (int): Desired length of a vector. Must be higher than L
-        tokenizer (PreTrainedTokenizerFast): Tokenizer that was used for tokenization
-
-    Returns:
-        np.ndarray: (L+P)x1 array of token ids with added padding
-        np.ndarray: (L+P)xA array of action vectors with added padding
-        np.ndarray: (L+P)-length array of masks where True marks a token and False marks padding
-    """
-
-    pad_length = length - tokens.shape[0]
-    assert pad_length >= 0
-
-    if pad_length > 0:
-        tokens = np.concatenate([tokens, [[tokenizer.pad_token_id]] * pad_length])
-        labels = np.concatenate([labels, [empty_action_vector()] * pad_length])
-
-    mask = np.ones(len(tokens)).astype(np.int)
-
-    if pad_length > 0:
-        mask[-pad_length:] = False
-
-    return tokens, labels, mask
-
-
-def batchify_data(
-    tokens: np.ndarray,
-    labels: np.ndarray,
-    max_tokens: int,
-    tokenizer: PreTrainedTokenizerFast,
-    min_tokens: int = 3,
-) -> (np.ndarray, np.ndarray):
-    """Transforms tokens and labels into a batch
-
-    Args:
-        tokens (np.ndarray): Lx1 array of token ids
-        labels (np.ndarray): LxA array of action vectors
-        max_tokens (int): Maximum number of tokens in a single batch entry
-        tokenizer (PreTrainedTokenizerFast): Tokenizer that was used for tokenization
-        min_tokens (int, optional): Minimum number of tokens in a batch entry. Defaults to 3.
-
-    Returns:
-        (np.ndarray, np.ndarray, np.ndarray): Batched tokens, labels and attention masks
-    """
-
-    assert max_tokens >= min_tokens + 2
-    assert min_tokens >= 1
-
-    tokens_batch = []
-    labels_batch = []
-    mask_batch = []
-
-    idxs = batchify_labels(labels, max_tokens - 2, min_tokens)
-
-    for ids in idxs:
-        tokens_sample = tokens[ids, :]
-        labels_sample = labels[ids, :]
-
-        assert len(ids) >= min_tokens
-        assert len(ids) <= max_tokens - 2
-
-        tokens_sample, labels_sample = add_cls_sep(
-            tokens_sample, labels_sample, tokenizer
-        )
-
-        assert len(tokens_sample) <= max_tokens
-
-        tokens_sample, labels_sample, mask = add_padding(
-            tokens_sample, labels_sample, max_tokens, tokenizer
-        )
-
-        tokens_batch.append(tokens_sample)
-        labels_batch.append(labels_sample)
-        mask_batch.append(mask)
-
-    return np.array(tokens_batch), np.array(labels_batch), np.array(mask_batch)
-
-
-def recover_text(text: str, actions_encoded: np.ndarray):
-    words = text.split(" ")
-
-    words_output = []
-
-    for word, action_encoded in zip(words, actions_encoded.tolist()):
-        action_decoded = decode_actions(np.array(action_encoded))
+    return {
+        "source": source,
+        "target": target,
+        "attention_mask": mask,
+        "source_shape": source_shape,
+        "target_shape": target_shape,
+        "attention_mask_shape": mask_shape,
+    }
 
-        word_recovered = recover_word(word, action_decoded)
-        words_output.append(word_recovered)
 
-    return " ".join(words_output)
+FLATTEN_DIMS_META = {
+    "source": object,
+    "target": object,
+    "attention_mask": object,
+    "source_shape": object,
+    "target_shape": object,
+    "attention_mask_shape": object,
+}
diff --git a/src/training.py b/src/training.py
index 76b21d5..f9ffd92 100644
--- a/src/training.py
+++ b/src/training.py
@@ -1,8 +1,10 @@
-from typing import Tuple, Optional
 import re
+from typing import Optional, Tuple
+
+import torch
 import torch.nn as nn
 import torch.optim as optim
-import torch
+
 from src.utils import prepare_folder
 
 
diff --git a/src/utils.py b/src/utils.py
index 6ee5337..d03d216 100644
--- a/src/utils.py
+++ b/src/utils.py
@@ -1,9 +1,10 @@
-import yaml
-import re
 import os
+import re
+import shutil
 from datetime import timedelta
 from typing import Optional
-import shutil
+
+import yaml
 
 PROJECT_ROOT = os.path.dirname(os.path.realpath("/".join(__file__.split("/")) + "/.."))
 
diff --git a/tests/pipelines/actions_based/test_processing.py b/tests/pipelines/actions_based/test_processing.py
index 9956d22..c626ff2 100644
--- a/tests/pipelines/actions_based/test_processing.py
+++ b/tests/pipelines/actions_based/test_processing.py
@@ -1,6 +1,216 @@
-from src.pipelines.actions_based.processing import last_stop_label, action_vector
-from src.processing import encode_actions
 import numpy as np
+import pytest
+from transformers import BertTokenizerFast
+
+from src.pipelines.actions_based.processing import (
+    ACTIONS_KEYS,
+    action_vector,
+    batchify_data,
+    batchify_labels,
+    create_model_input_output,
+    decode_actions,
+    detect_actions,
+    encode_actions,
+    last_stop_label,
+    nearest_sentence_l,
+    nearest_sentence_r,
+    recover_text,
+    token_labels_to_word_labels,
+    token_word_mapping,
+    tokenize_labeled_text,
+)
+
+
+def test_detect_actions():
+    actions = detect_actions("Janek.", None)
+    assert actions == {
+        "dot": True,
+        "upper_case": True,
+        "colon": False,
+        "question_mark": False,
+    }
+
+    actions = detect_actions("ewka?", None)
+    assert actions == {
+        "dot": False,
+        "upper_case": False,
+        "colon": False,
+        "question_mark": True,
+    }
+
+    actions = detect_actions("Test", None)
+    assert actions == {
+        "dot": False,
+        "upper_case": True,
+        "colon": False,
+        "question_mark": False,
+    }
+
+
+def test_encode_actions():
+    x = {
+        "dot": True,
+        "upper_case": False,
+        "colon": False,
+        "question_mark": True,
+    }
+
+    assert np.all(encode_actions(x) == np.array([1, 0, 0, 1]))
+
+
+def test_decode_actions():
+    x = np.array([1, 0, 0, 1])
+
+    assert decode_actions(x) == {
+        "dot": True,
+        "upper_case": False,
+        "colon": False,
+        "question_mark": True,
+    }
+
+
+def test_token_word_mapping():
+    text = "janek poszedł do ogrodu"
+    tokenizer = BertTokenizerFast.from_pretrained("bert-base-multilingual-cased")
+
+    text_tokenized = tokenizer(text)
+
+    mapping = token_word_mapping(text, tokenizer)
+
+    assert len(mapping) == (len(text_tokenized["input_ids"]) - 2)
+    assert min(mapping) == 0
+    assert max(mapping) == 3
+
+
+def test_token_labels_to_word_labels():
+    text = "janek poszedł do ogrodu"
+    labels = np.array([[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]])
+    tokenizer = BertTokenizerFast.from_pretrained("bert-base-multilingual-cased")
+
+    _, token_labels = tokenize_labeled_text(text, labels, tokenizer)
+
+    word_labels = token_labels_to_word_labels(text, token_labels, tokenizer)
+
+    assert np.all(np.vectorize(pytest.approx)(word_labels, labels))
+
+
+def test_tokenize_labeled_text():
+    text = "Janek poszedł do ogrodu. Ogród był zwierzęcy. Spotkał tam Zosię?"
+    tokenizer = BertTokenizerFast.from_pretrained("bert-base-multilingual-cased")
+
+    text_clean, labels = create_model_input_output(text)
+    tokens, token_labels = tokenize_labeled_text(text_clean, labels, tokenizer)
+
+    assert len(tokens.shape) == 2
+    assert len(token_labels.shape) == 2
+
+    assert tokens.shape[1] == 1
+    assert token_labels.shape[1] == len(ACTIONS_KEYS)
+
+    assert len(tokens) == len(token_labels)
+    assert tokens[0, 0] != tokenizer.cls_token_id
+    assert tokens[-1, 0] != tokenizer.sep_token_id
+
+
+def test_recover_text():
+    text = "Janek poszedł do ogrodu. Ogród był zwierzęcy. Spotkał tam Zosię?"
+    text_clean, word_labels = create_model_input_output(text)
+
+    result_text = recover_text(text_clean, word_labels)
+
+    assert result_text == text
+
+
+def test_nearest_sentence_l():
+    end = create_dummy_action(True)
+    word = create_dummy_action(False)
+
+    entry = np.array([word, word, word, end, end, word, word, end])
+
+    assert nearest_sentence_l(entry, 3) == 0
+    assert nearest_sentence_l(entry, 4) == 0
+    assert nearest_sentence_l(entry, 5) == 5
+    assert nearest_sentence_l(entry, 7) == 5
+
+
+def create_dummy_action(end_sentence: bool) -> np.array:
+    return encode_actions(
+        {
+            "dot": end_sentence,
+            "upper_case": False,
+            "colon": False,
+            "question_mark": False,
+        }
+    )
+
+
+def test_nearest_sentence_r():
+    end = create_dummy_action(True)
+    word = create_dummy_action(False)
+
+    entry = np.array([word, word, word, end, end, word, word, end])
+
+    assert nearest_sentence_r(entry, 0) == 0
+    assert nearest_sentence_r(entry, 4) == 5
+    assert nearest_sentence_r(entry, 5) == 5
+    assert nearest_sentence_r(entry, 6) is None
+    assert nearest_sentence_r(entry, 7) is None
+
+
+def test_batchify_labels():
+    end = create_dummy_action(True)
+    word = create_dummy_action(False)
+    entry = np.array([word, word, word, end, end, word, word, end])
+
+    batches = batchify_labels(entry, 3, 1)
+
+    assert len(batches) == 2
+    assert np.all(batches[0] == range(0, 3))
+    assert np.all(batches[1] == range(5, 8))
+
+
+def test_batchify_data():
+    text = "Janek poszedł do ogrodu. Ogród był zwierzęcy. Spotkał tam niedzwiedzia?"
+    tokenizer = BertTokenizerFast.from_pretrained("bert-base-multilingual-cased")
+
+    text_clean, labels = create_model_input_output(text)
+    tokens, token_labels = tokenize_labeled_text(text_clean, labels, tokenizer)
+
+    input_batch, output_batch, mask_batch = batchify_data(
+        tokens, token_labels, 5, tokenizer
+    )
+
+    assert len(input_batch.shape) == 3
+    assert len(output_batch.shape) == 3
+    assert len(mask_batch.shape) == 2
+
+    assert input_batch.shape[0] == mask_batch.shape[0]
+    assert input_batch.shape[0] > 1
+
+    # Second dimension should be sequence length
+    assert input_batch.shape[1] == 5
+    assert output_batch.shape[1] == 5
+    assert mask_batch.shape[1] == 5
+
+    # Third dimension should be feature size
+    assert input_batch.shape[2] == 1
+    assert output_batch.shape[2] == len(ACTIONS_KEYS)
+
+    # Mask should be integer (1 - leave, 0 - mask out)
+    assert mask_batch.dtype == np.int
+
+    # Should never be fully masked
+    # TODO: Make sure the correct convention is used
+    assert np.all(mask_batch[:, 0] == 1)
+
+    # Should never be fully masked
+    for i in range(input_batch.shape[0]):
+        # Should always start from beginning of the sentence
+        assert decode_actions(output_batch[i, 0, :])["upper_case"]
+        assert decode_actions(output_batch[i, 1, :])["upper_case"]
+
+        # Should always end with sep and padding
+        # TODO: Test it
 
 
 def test_action_vector():
diff --git a/tests/pipelines/translation_based/test_processing.py b/tests/pipelines/translation_based/test_processing.py
index 8ef4c69..f7806a1 100644
--- a/tests/pipelines/translation_based/test_processing.py
+++ b/tests/pipelines/translation_based/test_processing.py
@@ -1,15 +1,16 @@
 import numpy as np
+from transformers import BertTokenizerFast
+
 from src.pipelines.translation_based.processing import (
+    add_begin_end_tokens,
+    add_padding,
+    create_input_output,
+    crete_input_output_batch,
     find_new_sentence_left,
     find_new_sentence_right,
     get_batch_indexes,
-    add_padding,
-    add_begin_end_tokens,
     standarize_translation_sample,
-    create_input_output,
-    crete_input_output_batch,
 )
-from transformers import BertTokenizerFast
 
 
 def test_find_new_sentence_left():
diff --git a/tests/test_batch_loading.py b/tests/test_batch_loading.py
index cde65ad..641aaa7 100644
--- a/tests/test_batch_loading.py
+++ b/tests/test_batch_loading.py
@@ -1,11 +1,12 @@
+import dask.dataframe as dd
 import numpy as np
 import pandas as pd
-import dask.dataframe as dd
+
 from src.batch_loading import (
     calculate_batch_buffer_id,
-    yield_batch_buffer_span,
-    get_ordered_dataframe_len,
     get_batches,
+    get_ordered_dataframe_len,
+    yield_batch_buffer_span,
 )
 
 
diff --git a/tests/test_processing.py b/tests/test_processing.py
deleted file mode 100644
index 2aeed6e..0000000
--- a/tests/test_processing.py
+++ /dev/null
@@ -1,210 +0,0 @@
-from src.processing import (
-    detect_actions,
-    encode_actions,
-    token_word_mapping,
-    tokenize_labeled_text,
-    token_labels_to_word_labels,
-    create_model_input_output,
-    recover_text,
-    nearest_sentence_l,
-    nearest_sentence_r,
-    batchify_labels,
-    batchify_data,
-    ACTIONS_KEYS,
-    decode_actions,
-)
-from transformers import BertTokenizerFast
-import pytest
-import numpy as np
-
-
-def test_detect_actions():
-    actions = detect_actions("Janek.", None)
-    assert actions == {
-        "dot": True,
-        "upper_case": True,
-        "colon": False,
-        "question_mark": False,
-    }
-
-    actions = detect_actions("ewka?", None)
-    assert actions == {
-        "dot": False,
-        "upper_case": False,
-        "colon": False,
-        "question_mark": True,
-    }
-
-    actions = detect_actions("Test", None)
-    assert actions == {
-        "dot": False,
-        "upper_case": True,
-        "colon": False,
-        "question_mark": False,
-    }
-
-
-def test_encode_actions():
-    x = {
-        "dot": True,
-        "upper_case": False,
-        "colon": False,
-        "question_mark": True,
-    }
-
-    assert np.all(encode_actions(x) == np.array([1, 0, 0, 1]))
-
-
-def test_decode_actions():
-    x = np.array([1, 0, 0, 1])
-
-    assert decode_actions(x) == {
-        "dot": True,
-        "upper_case": False,
-        "colon": False,
-        "question_mark": True,
-    }
-
-
-def test_token_word_mapping():
-    text = "janek poszedł do ogrodu"
-    tokenizer = BertTokenizerFast.from_pretrained("bert-base-multilingual-cased")
-
-    text_tokenized = tokenizer(text)
-
-    mapping = token_word_mapping(text, tokenizer)
-
-    assert len(mapping) == (len(text_tokenized["input_ids"]) - 2)
-    assert min(mapping) == 0
-    assert max(mapping) == 3
-
-
-def test_token_labels_to_word_labels():
-    text = "janek poszedł do ogrodu"
-    labels = np.array([[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]])
-    tokenizer = BertTokenizerFast.from_pretrained("bert-base-multilingual-cased")
-
-    _, token_labels = tokenize_labeled_text(text, labels, tokenizer)
-
-    word_labels = token_labels_to_word_labels(text, token_labels, tokenizer)
-
-    assert np.all(np.vectorize(pytest.approx)(word_labels, labels))
-
-
-def test_tokenize_labeled_text():
-    text = "Janek poszedł do ogrodu. Ogród był zwierzęcy. Spotkał tam Zosię?"
-    tokenizer = BertTokenizerFast.from_pretrained("bert-base-multilingual-cased")
-
-    text_clean, labels = create_model_input_output(text)
-    tokens, token_labels = tokenize_labeled_text(text_clean, labels, tokenizer)
-
-    assert len(tokens.shape) == 2
-    assert len(token_labels.shape) == 2
-
-    assert tokens.shape[1] == 1
-    assert token_labels.shape[1] == len(ACTIONS_KEYS)
-
-    assert len(tokens) == len(token_labels)
-    assert tokens[0, 0] != tokenizer.cls_token_id
-    assert tokens[-1, 0] != tokenizer.sep_token_id
-
-
-def test_recover_text():
-    text = "Janek poszedł do ogrodu. Ogród był zwierzęcy. Spotkał tam Zosię?"
-    text_clean, word_labels = create_model_input_output(text)
-
-    result_text = recover_text(text_clean, word_labels)
-
-    assert result_text == text
-
-
-def test_nearest_sentence_l():
-    end = create_dummy_action(True)
-    word = create_dummy_action(False)
-
-    entry = np.array([word, word, word, end, end, word, word, end])
-
-    assert nearest_sentence_l(entry, 3) == 0
-    assert nearest_sentence_l(entry, 4) == 0
-    assert nearest_sentence_l(entry, 5) == 5
-    assert nearest_sentence_l(entry, 7) == 5
-
-
-def create_dummy_action(end_sentence: bool) -> np.array:
-    return encode_actions(
-        {
-            "dot": end_sentence,
-            "upper_case": False,
-            "colon": False,
-            "question_mark": False,
-        }
-    )
-
-
-def test_nearest_sentence_r():
-    end = create_dummy_action(True)
-    word = create_dummy_action(False)
-
-    entry = np.array([word, word, word, end, end, word, word, end])
-
-    assert nearest_sentence_r(entry, 0) == 0
-    assert nearest_sentence_r(entry, 4) == 5
-    assert nearest_sentence_r(entry, 5) == 5
-    assert nearest_sentence_r(entry, 6) is None
-    assert nearest_sentence_r(entry, 7) is None
-
-
-def test_batchify_labels():
-    end = create_dummy_action(True)
-    word = create_dummy_action(False)
-    entry = np.array([word, word, word, end, end, word, word, end])
-
-    batches = batchify_labels(entry, 3, 1)
-
-    assert len(batches) == 2
-    assert np.all(batches[0] == range(0, 3))
-    assert np.all(batches[1] == range(5, 8))
-
-
-def test_batchify_data():
-    text = "Janek poszedł do ogrodu. Ogród był zwierzęcy. Spotkał tam niedzwiedzia?"
-    tokenizer = BertTokenizerFast.from_pretrained("bert-base-multilingual-cased")
-
-    text_clean, labels = create_model_input_output(text)
-    tokens, token_labels = tokenize_labeled_text(text_clean, labels, tokenizer)
-
-    input_batch, output_batch, mask_batch = batchify_data(
-        tokens, token_labels, 5, tokenizer
-    )
-
-    assert len(input_batch.shape) == 3
-    assert len(output_batch.shape) == 3
-    assert len(mask_batch.shape) == 2
-
-    assert input_batch.shape[0] == mask_batch.shape[0]
-    assert input_batch.shape[0] > 1
-
-    # Second dimension should be sequence length
-    assert input_batch.shape[1] == 5
-    assert output_batch.shape[1] == 5
-    assert mask_batch.shape[1] == 5
-
-    # Third dimension should be feature size
-    assert input_batch.shape[2] == 1
-    assert output_batch.shape[2] == len(ACTIONS_KEYS)
-
-    # Mask should be integer (1 - leave, 0 - mask out)
-    assert mask_batch.dtype == np.int
-
-    # Should never be fully masked
-    # TODO: Make sure correct convetions is used
-    assert np.all(mask_batch[:, 0] == 1)
-
-    # Should never be fully masked0
-    for i in range(input_batch.shape[0]):
-        # Should always start from beginning of the sentence
-        assert decode_actions(output_batch[i, 0, :])["upper_case"]
-        assert decode_actions(output_batch[i, 1, :])["upper_case"]
-
-        # Should always end with sep and padding#
-        # TODO: Test it
diff --git a/tox.ini b/tox.ini
index 230f9c4..4326963 100644
--- a/tox.ini
+++ b/tox.ini
@@ -35,13 +35,15 @@ exclude =
     generated
 max-complexity = 10
 max-line-length = 80
-select = C,E,F,W,B,B950
-ignore = E203, E501, W503, C901
+select = I,C,E,F,W,B,B950,TYP,T
+ignore = E501, C901, I201
 
 
 [testenv:pep8]
 deps =
     flake8
+    flake8-type-annotations
+    flake8-typing-imports
 basepython = python
 commands =
     flake8 {posargs}
diff --git a/worker.py b/worker.py
index 059a831..15bce3d 100755
--- a/worker.py
+++ b/worker.py
@@ -1,8 +1,10 @@
 #!/usr/bin/python
 
+import configparser
+
 import nlp_ws
+
 from src.pipelines.actions_based.processing import apply_actions_punctuation
-import configparser
 from src.pipelines.actions_based.utils import load_model
 
 
-- 
GitLab
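
For reference, the source/target arrays introduced above are stored flattened next to their shapes and reconstructed with a reshape at read time. A minimal numpy sketch of this round trip (the helper names are illustrative, not from the repository):

import numpy as np

# Store a 2-D array as a flat column plus its shape, mirroring the
# source/source_shape convention used in flatten_dims above.
def flatten_entry(source: np.ndarray) -> dict:
    return {"source": source.reshape(-1), "source_shape": np.array(source.shape)}

def expand_entry(entry: dict) -> np.ndarray:
    return entry["source"].reshape(entry["source_shape"])

tokens = np.arange(12).reshape(6, 2)
assert np.all(expand_entry(flatten_entry(tokens)) == tokens)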


From 9d17a6a48325bcb461d37217af2dbc74245d41f4 Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Mon, 10 Aug 2020 17:32:19 +0200
Subject: [PATCH 052/116] Moved model from docker image into script

---
 Dockerfile    | 2 +-
 entrypoint.sh | 8 ++++++++
 2 files changed, 9 insertions(+), 1 deletion(-)
 create mode 100755 entrypoint.sh

diff --git a/Dockerfile b/Dockerfile
index b467b0f..2948379 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -12,4 +12,4 @@ COPY worker.py .
 RUN mkdir ./deploy && \
     wget https://minio.clarin-pl.eu/public/models/punctuation/0-190000.model -O deploy/model
     
-CMD [ "./worker.py" ]
\ No newline at end of file
+ENTRYPOINT [ "./worker.py" ]
\ No newline at end of file
diff --git a/entrypoint.sh b/entrypoint.sh
new file mode 100755
index 0000000..e548dca
--- /dev/null
+++ b/entrypoint.sh
@@ -0,0 +1,8 @@
+#!/bin/bash
+
+if test -f "./deploy/model"; then
+    mkdir -p ./deploy
+    wget https://minio.clarin-pl.eu/public/models/punctuation/0-190000.model -O deploy/model
+fi
+
+python3 worker.py
\ No newline at end of file
-- 
GitLab
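
The same download-if-missing logic, sketched in Python (model URL as in the Dockerfile above):

import os
import urllib.request

MODEL_URL = "https://minio.clarin-pl.eu/public/models/punctuation/0-190000.model"
MODEL_PATH = "./deploy/model"

# Download the model only when it is not already present in the mounted volume.
if not os.path.exists(MODEL_PATH):
    os.makedirs(os.path.dirname(MODEL_PATH), exist_ok=True)
    urllib.request.urlretrieve(MODEL_URL, MODEL_PATH)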


From c39bee2b6f2854400336cde3d2c5ec4ace69bc42 Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Mon, 10 Aug 2020 17:57:21 +0200
Subject: [PATCH 053/116] Added readme

---
 README.md | 11 ++++++++++-
 1 file changed, 10 insertions(+), 1 deletion(-)

diff --git a/README.md b/README.md
index 7a65c4a..0b99108 100644
--- a/README.md
+++ b/README.md
@@ -1 +1,10 @@
-# punctuator
+# Punctuator
+A service that automatically adds punctuation to a raw word stream (e.g. from speech-to-text).  
+
+## Approaches
+1. Token classification (actions): Each token is classified with 4 labels: uppercase, dot, colon, question mark. The model is based on the stacked encoder part of the transformer architecture (BERT), followed by an FC layer that transforms the output into per-token multilabel binary classifications. For now, there is no restriction against predicting the dot, question_mark and colon labels simultaneously, so that is an area for improvement (hierarchical multilabel classification).
+
+2. Sequence-to-Sequence (translations): A full encoder-decoder stack that takes the input (unpunctuated text) and the output produced so far to predict the next token. In theory, this model should be able to represent many more cases (e.g. all-uppercase words, partially uppercase words, dashes, ellipses) without defining them explicitly. However, the lack of constraints makes it much harder to train.
+
+## Mountpoints
+The directory where the model will be downloaded (~500 MB) needs to be mounted at /punctuator/deploy
-- 
GitLab
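
To make the first approach concrete, here is a minimal sketch of turning per-word action vectors back into punctuated text, following the [dot, upper_case, colon, question_mark] convention visible in the tests (the repository's recover_text/recover_word may differ in details):

import numpy as np

ACTIONS = ["dot", "upper_case", "colon", "question_mark"]
SUFFIX = {"dot": ".", "colon": ":", "question_mark": "?"}

def apply_word_actions(word: str, vector: np.ndarray) -> str:
    # Threshold each label independently (multilabel binary classification)
    actions = {name: value > 0.5 for name, value in zip(ACTIONS, vector)}
    if actions["upper_case"]:
        word = word[0].upper() + word[1:]
    for name, mark in SUFFIX.items():
        if actions[name]:
            word += mark
            break  # at most one punctuation mark per word
    return word

words = ["ala", "ma", "kota"]
labels = np.array([[0, 1, 0, 0], [0, 0, 0, 0], [1, 0, 0, 0]])
print(" ".join(apply_word_actions(w, v) for w, v in zip(words, labels)))  # "Ala ma kota."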


From b9d01fc8d6b7a5741e85ad293034c89e7de10e0c Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Mon, 10 Aug 2020 18:14:34 +0200
Subject: [PATCH 054/116] Updated dockerfile

---
 Dockerfile       |  9 +++----
 requirements.txt | 61 ++++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 66 insertions(+), 4 deletions(-)
 create mode 100644 requirements.txt

diff --git a/Dockerfile b/Dockerfile
index 2948379..2b3b6a3 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,15 +1,16 @@
 FROM clarinpl/python:3.8
 
 RUN DEBIAN_FRONTEND=noninteractive apt-get update && apt-get install -y gcc python3-dev
-RUN pip3 install numpy pandas tqdm seaborn torch dask[complete] transformers pyarrow==0.17.1 pytest lxml
 RUN mkdir /punctuator
 WORKDIR /punctuator
 
+COPY requirements.txt requirements.txt
+RUN pip3 install -r requirements.txt && rm requirements.txt
+
 COPY src ./src
 COPY config.ini .
 COPY worker.py .
+  
+RUN pip3 freeze
 
-RUN mkdir ./deploy && \
-    wget https://minio.clarin-pl.eu/public/models/punctuation/0-190000.model -O deploy/model
-    
 ENTRYPOINT [ "./worker.py" ]
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..17154b9
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,61 @@
+attrs==19.3.0
+bokeh==2.1.1
+certifi==2020.6.20
+chardet==3.0.4
+click==7.1.2
+cloudpickle==1.5.0
+cycler==0.10.0
+dask==2.22.0
+distributed==2.22.0
+filelock==3.0.12
+fsspec==0.8.0
+future==0.18.2
+HeapDict==1.0.1
+idna==2.10
+iniconfig==1.0.1
+Jinja2==2.11.2
+joblib==0.16.0
+kiwisolver==1.2.0
+locket==0.2.0
+lxml==4.5.2
+MarkupSafe==1.1.1
+matplotlib==3.3.0
+more-itertools==8.4.0
+msgpack==1.0.0
+numpy==1.19.1
+packaging==20.4
+pandas==1.1.0
+partd==1.1.0
+Pillow==7.2.0
+pluggy==0.13.1
+psutil==5.7.2
+py==1.9.0
+pyarrow==0.17.1
+pycurl==7.43.0
+pygobject==3.20.0
+pyparsing==2.4.7
+pytest==6.0.1
+python-apt==1.1.0b1+ubuntu0.16.4.9
+python-dateutil==2.8.1
+pytz==2020.1
+PyYAML==5.3.1
+regex==2020.7.14
+requests==2.24.0
+sacremoses==0.0.43
+scipy==1.5.2
+seaborn==0.10.1
+sentencepiece==0.1.91
+six==1.15.0
+sortedcontainers==2.2.2
+tblib==1.7.0
+tokenizers==0.8.1rc1
+toml==0.10.1
+toolz==0.10.0
+torch==1.6.0
+tornado==6.0.4
+tqdm==4.48.2
+transformers==3.0.2
+typing-extensions==3.7.4.2
+unattended-upgrades==0.1
+urllib3==1.25.10
+zict==2.0.0
\ No newline at end of file
-- 
GitLab


From 841d0546761c92df3facff303a90ea2bb28360df Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Mon, 10 Aug 2020 18:26:07 +0200
Subject: [PATCH 055/116] Refactoring fixups

---
 src/pipelines/actions_based/stage5_stats.py | 2 +-
 src/pipelines/actions_based/train.py        | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/pipelines/actions_based/stage5_stats.py b/src/pipelines/actions_based/stage5_stats.py
index d7f3ea2..a91ae2f 100644
--- a/src/pipelines/actions_based/stage5_stats.py
+++ b/src/pipelines/actions_based/stage5_stats.py
@@ -40,7 +40,7 @@ if __name__ == "__main__":
     df = dd.read_parquet(INPUT_FOLDER, engine="pyarrow")
     df = df.apply(expand_dims, result_type="expand", axis=1, meta=EXPAND_DIMS_META)
 
-    outputs_bag = df["output"].to_bag()
+    outputs_bag = df["target"].to_bag()
 
     inital_values = {
         "class_number": np.array([0] * len(ACTIONS_KEYS)),
diff --git a/src/pipelines/actions_based/train.py b/src/pipelines/actions_based/train.py
index 0a3ac69..e6ed38e 100755
--- a/src/pipelines/actions_based/train.py
+++ b/src/pipelines/actions_based/train.py
@@ -97,10 +97,10 @@ if __name__ == "__main__":
         i = sample_start
         for data_batch in get_batches(df, batch_size, 100, random_index_shuffle, i):
             inputs = data_batch.apply(
-                lambda x: x["input"].reshape(x["input_shape"]), axis=1
+                lambda x: x["source"].reshape(x["source_shape"]), axis=1
             ).values
             outputs = data_batch.apply(
-                lambda x: x["output"].reshape(x["output_shape"]), axis=1
+                lambda x: x["target"].reshape(x["target_shape"]), axis=1
             ).values
             attentions_mask = data_batch.apply(
                 lambda x: x["attention_mask"].reshape(x["attention_mask_shape"]),
-- 
GitLab
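
The training loop above draws batches through get_batches with a pre-shuffled index permutation. The core idea, sketched without the dask buffering of src.batch_loading:

import numpy as np

def iter_batches(num_samples: int, batch_size: int, seed: int = 0):
    # Shuffle once, then walk the permutation in fixed-size windows,
    # as the training loop does with random_index_shuffle.
    order = np.random.default_rng(seed).permutation(num_samples)
    for start in range(0, num_samples, batch_size):
        yield order[start:start + batch_size]

for batch_ids in iter_batches(10, 4):
    print(batch_ids)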


From 29efda6defc1ebeaf56f9e5328fffca7291f3303 Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Tue, 11 Aug 2020 10:36:25 +0200
Subject: [PATCH 056/116] More refactoring, added less error-prone
 preprocessing

---
 punctuate.py                                  |  5 ++-
 .../actions_based/stage1_extraction.py        |  5 ++-
 .../actions_based/stage2_tokenization.py      |  5 ++-
 .../actions_based/stage3_exploding.py         |  7 +++-
 .../translation_based/stage1_extraction.py    |  5 ++-
 .../stage2_create_batches.py                  |  5 ++-
 src/utils.py                                  | 22 ++++++++++
 tests/test_utils.py                           | 40 +++++++++++++++++++
 worker.py                                     |  3 +-
 9 files changed, 89 insertions(+), 8 deletions(-)
 create mode 100644 tests/test_utils.py

diff --git a/punctuate.py b/punctuate.py
index 05ba1ba..510eaad 100755
--- a/punctuate.py
+++ b/punctuate.py
@@ -1,11 +1,12 @@
 #!/usr/bin/python3
 
 import argparse
-from argparse import Namespace
 import os
+from argparse import Namespace
 
 from src.pipelines.actions_based.processing import apply_actions_punctuation
 from src.pipelines.actions_based.utils import load_model
+from src.utils import preprocess
 
 
 def get_args() -> Namespace:
@@ -42,7 +43,7 @@ if __name__ == "__main__":
     tokenizer, model = load_model(args.model, args.base, "cpu")
 
     with open(args.input, "r") as f:
-        text = f.read()
+        text = preprocess(f.read())
         text_processed = apply_actions_punctuation(
             text, args.chunk_size, tokenizer, model, args.threshold
         )
diff --git a/src/pipelines/actions_based/stage1_extraction.py b/src/pipelines/actions_based/stage1_extraction.py
index 5a058a9..94dc26c 100644
--- a/src/pipelines/actions_based/stage1_extraction.py
+++ b/src/pipelines/actions_based/stage1_extraction.py
@@ -6,7 +6,10 @@ import numpy as np
 import pandas as pd
 from dask.distributed import Client
 
-from src.pipelines.actions_based.processing import APPLY_FILE_PROCESSING_META, apply_file_processing
+from src.pipelines.actions_based.processing import (
+    APPLY_FILE_PROCESSING_META,
+    apply_file_processing,
+)
 from src.utils import PROJECT_ROOT, get_config, prepare_folder
 
 INPUT_FOLDER = f"{PROJECT_ROOT}/data"
diff --git a/src/pipelines/actions_based/stage2_tokenization.py b/src/pipelines/actions_based/stage2_tokenization.py
index 0ea3586..b30445f 100644
--- a/src/pipelines/actions_based/stage2_tokenization.py
+++ b/src/pipelines/actions_based/stage2_tokenization.py
@@ -4,7 +4,10 @@ import dask.dataframe as dd
 from dask.distributed import Client
 from transformers import BertTokenizerFast
 
-from src.pipelines.actions_based.processing import APPLY_TOKENIZATION_META, apply_tokenization
+from src.pipelines.actions_based.processing import (
+    APPLY_TOKENIZATION_META,
+    apply_tokenization,
+)
 from src.utils import PROJECT_ROOT, get_config, prepare_folder
 
 INPUT_FOLDER = f"{PROJECT_ROOT}/generated/actions/stage1_extraction"
diff --git a/src/pipelines/actions_based/stage3_exploding.py b/src/pipelines/actions_based/stage3_exploding.py
index 81dc965..72ec128 100644
--- a/src/pipelines/actions_based/stage3_exploding.py
+++ b/src/pipelines/actions_based/stage3_exploding.py
@@ -2,7 +2,12 @@
 import dask.dataframe as dd
 from dask.distributed import Client
 
-from src.processing import EXPAND_DIMS_META, FLATTEN_DIMS_META, expand_dims, flatten_dims
+from src.processing import (
+    EXPAND_DIMS_META,
+    FLATTEN_DIMS_META,
+    expand_dims,
+    flatten_dims,
+)
 from src.utils import PROJECT_ROOT, get_config, prepare_folder
 
 INPUT_FOLDER = f"{PROJECT_ROOT}/generated/actions/stage2_tokenization"
diff --git a/src/pipelines/translation_based/stage1_extraction.py b/src/pipelines/translation_based/stage1_extraction.py
index 386211d..6ffdbf7 100644
--- a/src/pipelines/translation_based/stage1_extraction.py
+++ b/src/pipelines/translation_based/stage1_extraction.py
@@ -6,7 +6,10 @@ import numpy as np
 import pandas as pd
 from dask.distributed import Client
 
-from src.pipelines.translation_based.processing import RAW_TO_DATAFRAME_META, raw_to_dataframe
+from src.pipelines.translation_based.processing import (
+    RAW_TO_DATAFRAME_META,
+    raw_to_dataframe,
+)
 from src.utils import PROJECT_ROOT, get_config, prepare_folder
 
 INPUT_FOLDER = f"{PROJECT_ROOT}/data"
diff --git a/src/pipelines/translation_based/stage2_create_batches.py b/src/pipelines/translation_based/stage2_create_batches.py
index ade8bf2..83a2edc 100644
--- a/src/pipelines/translation_based/stage2_create_batches.py
+++ b/src/pipelines/translation_based/stage2_create_batches.py
@@ -4,7 +4,10 @@ from dask import delayed
 from dask.distributed import Client
 from transformers import BertTokenizerFast
 
-from src.pipelines.translation_based.processing import GENERATE_BATCHES_META, generate_batches
+from src.pipelines.translation_based.processing import (
+    GENERATE_BATCHES_META,
+    generate_batches,
+)
 from src.utils import PROJECT_ROOT, get_config, prepare_folder
 
 INPUT_FOLDER = f"{PROJECT_ROOT}/generated/translations/stage1_extraction"
diff --git a/src/utils.py b/src/utils.py
index d03d216..906de65 100644
--- a/src/utils.py
+++ b/src/utils.py
@@ -45,9 +45,31 @@ def remove_punctuation(text: str) -> str:
     Returns:
         str: Text with all punctuactions removed
     """
+
+    # Replace separating characters with spaces
+    text = text.replace("-", " ").replace("/", " ").replace("+", " ")
+
     return "".join(filter(lambda x: x.isalnum() or x.isspace(), text))
 
 
+def preprocess(text: str) -> str:
+    """Makes sure that input is in the same format as training data (no non-alphanum chars, no double spaces,
+        all lowercase etc.)
+
+    Args:
+        text (str): Text to be processed
+
+    Returns:
+        str: Text in training-data format
+    """
+    text = remove_punctuation(text)
+    text = remove_multiple_spaces(text)
+    text = text.lower()
+    text = text.strip()
+
+    return text
+
+
 def prepare_folder(path: str, wipe: bool = False) -> None:
     """Function make sure that provided path exists. Can aditionaly
     remove all files from the path.
diff --git a/tests/test_utils.py b/tests/test_utils.py
new file mode 100644
index 0000000..9887354
--- /dev/null
+++ b/tests/test_utils.py
@@ -0,0 +1,40 @@
+from src.utils import convert_to_timedelta, preprocess, remove_multiple_spaces, remove_punctuation
+
+
+def test_remove_multiple_spaces():
+    provided = "Ala   ma Kota.      Kot ma Ale "
+    expected = "Ala ma Kota. Kot ma Ale "
+
+    assert remove_multiple_spaces(provided) == expected
+
+
+def test_remove_punctuation():
+    provided = "Ala..  ma-Kota!?.@@$ Kot ma Ale ()*"
+    expected = "Ala  ma Kota Kot ma Ale "
+
+    assert remove_punctuation(provided) == expected
+
+
+def test_preprocess():
+    provided = "Ala  ma-Kota!?.@@$ Kot ma Ale ()*"
+    expected = "ala ma kota kot ma ale"
+
+    assert preprocess(provided) == expected
+
+
+def test_convert_to_timedelta():
+    assert convert_to_timedelta("5d").days == 5
+    assert convert_to_timedelta("5d").seconds == 0
+    assert convert_to_timedelta("5d").microseconds == 0
+
+    assert convert_to_timedelta("4h").days == 0
+    assert convert_to_timedelta("4h").seconds == 4 * 60 * 60
+    assert convert_to_timedelta("4h").microseconds == 0
+
+    assert convert_to_timedelta("3m").days == 0
+    assert convert_to_timedelta("3m").seconds == 3 * 60
+    assert convert_to_timedelta("3m").microseconds == 0
+
+    assert convert_to_timedelta("2s").days == 0
+    assert convert_to_timedelta("2s").seconds == 2
+    assert convert_to_timedelta("2s").microseconds == 0
diff --git a/worker.py b/worker.py
index 15bce3d..5bf6e0c 100755
--- a/worker.py
+++ b/worker.py
@@ -6,6 +6,7 @@ import nlp_ws
 
 from src.pipelines.actions_based.processing import apply_actions_punctuation
 from src.pipelines.actions_based.utils import load_model
+from src.utils import preprocess
 
 
 class Worker(nlp_ws.NLPWorker):
@@ -27,7 +28,7 @@ class Worker(nlp_ws.NLPWorker):
         """Implementation of example tasks that copies files."""
 
         with open(input_file, "r") as f:
-            text = f.read()
+            text = preprocess(f.read())
             text_processed = apply_actions_punctuation(
                 text, self.chunk_size, self.tokenizer, self.model, self.threshold
             )
-- 
GitLab
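
A usage sketch of the new preprocessing step, wired together as in punctuate.py; the model path, chunk size and threshold below are placeholder values:

from src.pipelines.actions_based.processing import apply_actions_punctuation
from src.pipelines.actions_based.utils import load_model
from src.utils import preprocess

# Placeholder model path and hyperparameters; see punctuate.py for the real CLI.
tokenizer, model = load_model("deploy/model", "dkleczek/bert-base-polish-cased-v1", "cpu")

text = preprocess("Ala, ma - Kota!")  # -> "ala ma kota"
print(apply_actions_punctuation(text, 500, tokenizer, model, 0.9))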


From 89c32149e537337ca9bde6dc69477a32f987cf5f Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Thu, 13 Aug 2020 14:27:18 +0200
Subject: [PATCH 057/116] Added model with restricted, probabilistic loss

---
 .gitignore                                    |   3 +-
 src/models/actions_model_base.py              |  57 +++++++
 src/models/actions_model_restricted.py        | 137 +++++++++++++++++
 .../actions_based/train restricted.py         | 142 ++++++++++++++++++
 src/pipelines/actions_based/train.py          |  10 +-
 tests/models/__init__.py                      |   0
 tests/models/test_actions_model_base.py       |  40 +++++
 tests/models/test_actions_model_restricted.py |  50 ++++++
 8 files changed, 432 insertions(+), 7 deletions(-)
 create mode 100644 src/models/actions_model_base.py
 create mode 100644 src/models/actions_model_restricted.py
 create mode 100755 src/pipelines/actions_based/train restricted.py
 create mode 100644 tests/models/__init__.py
 create mode 100644 tests/models/test_actions_model_base.py
 create mode 100644 tests/models/test_actions_model_restricted.py

diff --git a/.gitignore b/.gitignore
index c929ac9..11789fe 100644
--- a/.gitignore
+++ b/.gitignore
@@ -13,4 +13,5 @@ __pycache__
 .dvc
 .tox
 notebooks
-dvc.lock
\ No newline at end of file
+dvc.lock
+dask-worker-space
\ No newline at end of file
diff --git a/src/models/actions_model_base.py b/src/models/actions_model_base.py
new file mode 100644
index 0000000..4fd334b
--- /dev/null
+++ b/src/models/actions_model_base.py
@@ -0,0 +1,57 @@
+from typing import Callable, Tuple
+import torch.nn as nn
+from torch.nn.modules.loss import BCEWithLogitsLoss
+from transformers.configuration_utils import PretrainedConfig
+from transformers.modeling_bert import BertForTokenClassification
+from src.pipelines.actions_based.processing import ACTIONS_KEYS
+import torch
+
+
+class ActionsModelBase(nn.Module):
+    """Model based on simple multilabel per-token classifiaction. Each token is binarly classified in n-dimensions"""
+
+    def __init__(self, base_model: str, num_labels: int = len(ACTIONS_KEYS)) -> None:
+        """Initializes actions model
+
+        Args:
+            base_model (str): Name of base model
+            num_labels (int): Length of action vector
+        """
+        super(ActionsModelBase, self).__init__()
+
+        config = PretrainedConfig.from_pretrained(base_model)
+        config.num_labels = num_labels
+
+        self.criterion = None
+        self.core = BertForTokenClassification(config)
+
+    def forward(
+        self, input_ids: torch.Tensor, attention_mask: torch.Tensor
+    ) -> torch.Tensor:
+        """Computes logits for uppercasing and adding punctuation to a word
+
+        Args:
+            input_ids (torch.Tensor): Array of ids of tokens. Shape BxL
+            attention_mask (torch.Tensor): Mask telling if a token should be masked out (i.e. padding). Shape BxL
+
+        Returns:
+            torch.Tensor: Predicted actions vector
+        """
+        y_pred = self.core(input_ids=input_ids, attention_mask=attention_mask)[0]
+
+        return y_pred
+
+
+class ActionsModelBaseLoss(nn.Module):
+    def __init__(self, prior_odds: torch.Tensor) -> None:
+        super(ActionsModelBaseLoss, self).__init__()
+
+        self.core = BCEWithLogitsLoss(pos_weight=prior_odds)
+
+    def forward(
+        self,
+        true_action_vector: torch.Tensor,
+        predicted_action_vector_logits: torch.Tensor,
+    ) -> torch.Tensor:
+
+        return self.core(predicted_action_vector_logits, true_action_vector)
diff --git a/src/models/actions_model_restricted.py b/src/models/actions_model_restricted.py
new file mode 100644
index 0000000..1876422
--- /dev/null
+++ b/src/models/actions_model_restricted.py
@@ -0,0 +1,137 @@
+from typing import Tuple
+import torch.nn as nn
+from transformers.configuration_utils import PretrainedConfig
+from transformers.modeling_bert import BertForTokenClassification
+from src.pipelines.actions_based.processing import ACTIONS_KEYS
+import torch
+
+
+class ActionsModelRestricted(nn.Module):
+    """Similar to ActionsModelBase, however no-punctuation class is added
+    internally, and punctuation-related entries are treaded as proper binomial
+    probability distribution
+    """
+
+    def __init__(self, base_model: str, extended_action_vector_size: int) -> None:
+        """Initializes actions model
+
+        Args:
+            base_model (str): Name of base model
+        """
+        super(ActionsModelRestricted, self).__init__()
+
+        config = PretrainedConfig.from_pretrained(base_model)
+
+        config.num_labels = extended_action_vector_size
+
+        self.core = BertForTokenClassification(config)
+
+    def forward(
+        self, input_ids: torch.Tensor, attentions_mask: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Computes logits for uppercasing and adding punctuation to a word
+
+        Args:
+            input_ids (torch.Tensor): Array of ids of tokens. Shape BxL
+            attentions_mask (torch.Tensor): Mask telling if a token should be masked out (i.e. padding). Shape BxL
+
+        Returns:
+            Tuple[torch.Tensor, torch.Tensor]: Logits for making each word uppercase and for adding a punctuation mark to each word
+        """
+        y_pred = self.core(input_ids=input_ids, attention_mask=attentions_mask)[0]
+
+        # Force punctuation entries to form a proper categorical log-odds distribution
+        y_pred[:, :, 1:] = self._logit(torch.softmax(y_pred[:, :, 1:], -1))
+
+        return y_pred
+
+    def loss_func(
+        true_extended_action_vector: torch.Tensor,
+        predicted_action_vector_logits: torch.Tensor,
+        prior_log_probs: torch.Tensor,
+    ) -> torch.Tensor:
+        """Loss function for actions model
+
+        Args:
+            true_extended_action_vector (torch.Tensor): Ground-truth extended action vectors (BxLxA)
+            predicted_action_vector_logits (torch.Tensor): Extended action-vector logits predicted by the model (BxLxA)
+            prior_log_probs (torch.Tensor): Prior log-probabilities of the action classes
+
+        Returns:
+            torch.Tensor: Loss value
+        """
+        uppercase_cond_log_prob = torch.distributions.Bernoulli(
+            logits=predicted_action_vector_logits[:, :, 0]
+        ).log_prob(true_extended_action_vector[:, :, 0])
+
+        punctuation_cond_log_prob = torch.distributions.OneHotCategorical(
+            logits=predicted_action_vector_logits[:, :, 1:]
+        ).log_prob(true_extended_action_vector[:, :, 1:])
+
+        model_uppercase_cond_prob = uppercase_cond_log_prob - prior_log_probs[0]
+
+        prior_index = true_extended_action_vector[:, :, 1:].argmax(-1)
+
+        model_punctuation_cond_prob = (
+            punctuation_cond_log_prob - prior_log_probs[1 + prior_index]
+        )
+
+        # We assume Uppercase ⊥ Punctuation. This is not a terrible assumption, as we
+        # only assume that uppercasing of the current word is independent of the current
+        # word's punctuation, which should usually hold. We also use mean instead of sum
+        # to make the loss independent of batch size.
+        return -torch.mean(model_uppercase_cond_prob + model_punctuation_cond_prob)
+
+    @staticmethod
+    def _logit(x: torch.Tensor):
+        EPS = 1e-5
+
+        z = torch.clamp(x, EPS, 1.0 - EPS)
+
+        return torch.log(z / (1 - z))
+
+
+class ActionsModelRestrictedLoss(nn.Module):
+    def __init__(self, prior_log_probs: torch.Tensor) -> None:
+        super(ActionsModelRestrictedLoss, self).__init__()
+
+        self.prior_log_probs = prior_log_probs
+
+    def forward(
+        self,
+        true_extended_action_vector: torch.Tensor,
+        predicted_action_vector_logits: torch.Tensor,
+    ) -> torch.Tensor:
+        """Loss function for actions model
+
+        Args:
+            true_extended_action_vector (torch.Tensor): Ground-truth extended action vectors (BxLxA)
+            predicted_action_vector_logits (torch.Tensor): Extended action-vector logits predicted by the model (BxLxA)
+
+        Returns:
+            torch.Tensor: Loss value
+        """
+        uppercase_cond_log_prob = torch.distributions.Bernoulli(
+            logits=predicted_action_vector_logits[:, :, 0]
+        ).log_prob(true_extended_action_vector[:, :, 0])
+
+        punctuation_cond_log_prob = torch.distributions.OneHotCategorical(
+            logits=predicted_action_vector_logits[:, :, 1:]
+        ).log_prob(true_extended_action_vector[:, :, 1:])
+
+        model_uppercase_cond_prob = uppercase_cond_log_prob - self.prior_log_probs[0]
+
+        prior_index = true_extended_action_vector[:, :, 1:].argmax(-1)
+
+        model_punctuation_cond_prob = (
+            punctuation_cond_log_prob - self.prior_log_probs[1 + prior_index]
+        )
+
+        # We assume Uppercase ⊥ Punctuation. This is not a terrible assumption, as we
+        # only assume that uppercasing of the current word is independent of the current
+        # word's punctuation, which should usually hold. We also use mean instead of sum
+        # to make the loss independent of batch size.
+        return -torch.mean(model_uppercase_cond_prob + model_punctuation_cond_prob)
diff --git a/src/pipelines/actions_based/train restricted.py b/src/pipelines/actions_based/train restricted.py
new file mode 100755
index 0000000..e6ed38e
--- /dev/null
+++ b/src/pipelines/actions_based/train restricted.py	
@@ -0,0 +1,142 @@
+#!/usr/bin/python3
+
+import glob
+import pickle
+from datetime import datetime
+
+import dask.dataframe as dd
+import numpy as np
+import torch
+from torch.nn import BCEWithLogitsLoss
+from transformers import BertForTokenClassification, BertTokenizerFast
+
+from src.batch_loading import get_batches
+from src.pipelines.actions_based.processing import ACTIONS_KEYS
+from src.training import latest_model, save_training_step
+from src.utils import PROJECT_ROOT, convert_to_timedelta, get_config, prepare_folder
+
+INPUT_PATH = f"{PROJECT_ROOT}/generated/actions/stage4_reindexing"
+INPUT_STATS_PATH = f"{PROJECT_ROOT}/generated/actions/stage5_stats"
+OUTPUT_PATH = f"{PROJECT_ROOT}/checkpoints/actions"
+
+if __name__ == "__main__":
+    config = get_config()
+    learning_rate = config["actions"]["training"]["learning_rate"]
+    num_epochs = config["actions"]["training"]["num_epochs"]
+    batch_size = config["actions"]["training"]["batch_size"]
+    save_step = config["actions"]["training"]["save_step"]
+    loss_averaging_span = config["actions"]["training"]["loss_averaging_span"]
+    fresh_start = config["actions"]["training"]["fresh_start"]
+    device_name = config["actions"]["training"]["device"]
+    max_train_time = config["actions"]["training"]["max_training_time"]
+    base_model = config["global"]["base_model"]
+    seed = config["global"]["random_seed"]
+
+    prepare_folder(OUTPUT_PATH)
+    np.random.seed(seed=seed)
+
+    if max_train_time is not None:
+        max_train_time = convert_to_timedelta(max_train_time)
+
+    device = torch.device(device_name if torch.cuda.is_available() else "cpu")
+    print(f"Training on {device}")
+
+    # Load loss weights
+    with open(f"{INPUT_STATS_PATH}/stats.pickle", "rb") as f:
+        stats = pickle.load(f)
+        pos_examples = stats["class_number"]
+        neg_examples = stats["num_examples"] - stats["class_number"]
+        pos_weight = torch.tensor(neg_examples / pos_examples)
+
+    df = dd.read_parquet(INPUT_PATH, engine="pyarrow")
+    tokenizer = BertTokenizerFast.from_pretrained(base_model)
+
+    model = BertForTokenClassification.from_pretrained(
+        base_model, num_labels=len(ACTIONS_KEYS)
+    ).to(device)
+    criterion = BCEWithLogitsLoss(pos_weight=pos_weight).to(device)
+    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
+
+    epoch_start = 0
+    sample_start = 0
+    if fresh_start is False:
+        checkpoint_files = glob.glob(f"{OUTPUT_PATH}/*.model")
+        latest = latest_model(checkpoint_files)
+
+        if latest is not None:
+            epoch, batch = latest
+            model.load_state_dict(
+                torch.load(f"{OUTPUT_PATH}/{epoch}-{batch}.model", map_location=device,)
+            )
+            optimizer.load_state_dict(
+                torch.load(
+                    f"{OUTPUT_PATH}/{epoch}-{batch}.optimizer", map_location=device,
+                )
+            )
+
+            epoch_start, sample_start = epoch, batch
+            print(f"Loaded {epoch}-{batch}")
+
+    model.train()
+    model.base_model.train()
+    losses = []
+
+    num_samples = df.tail(1).index.values[0] + 1
+    random_index_shuffle = np.random.permutation(range(num_samples))
+
+    training_stopped = False
+
+    time_max = datetime.max
+    if max_train_time is not None:
+        time_max = datetime.now() + max_train_time
+
+    for epoch in range(epoch_start, num_epochs):
+        if training_stopped:
+            break
+
+        i = sample_start
+        for data_batch in get_batches(df, batch_size, 100, random_index_shuffle, i):
+            inputs = data_batch.apply(
+                lambda x: x["source"].reshape(x["source_shape"]), axis=1
+            ).values
+            outputs = data_batch.apply(
+                lambda x: x["target"].reshape(x["target_shape"]), axis=1
+            ).values
+            attentions_mask = data_batch.apply(
+                lambda x: x["attention_mask"].reshape(x["attention_mask_shape"]),
+                axis=1,
+            ).values
+
+            inputs = torch.tensor(np.stack(inputs).squeeze()).to(device)
+            outputs = torch.tensor(np.stack(outputs)).to(device)
+            attentions_mask = torch.tensor(np.stack(attentions_mask)).to(device)
+
+            y_pred = model(input_ids=inputs, attention_mask=attentions_mask)[0]
+
+            loss = criterion(y_pred, outputs)
+
+            losses.append(loss.item())
+            if len(losses) > loss_averaging_span:
+                losses = losses[-loss_averaging_span:]
+
+            print(f"epoch: {epoch} | step: {i} | loss: {np.mean(losses)}")
+
+            optimizer.zero_grad()
+
+            if i % save_step == 0 and (i != sample_start or epoch != epoch_start):
+                print(f"Saving: Epoch {epoch}, step {i}")
+                save_training_step(OUTPUT_PATH, f"{epoch}-{i}", model, optimizer)
+
+            if datetime.now() > time_max:
+                print(f"Max time reached, saving: Epoch {epoch}, step {i}")
+                save_training_step(OUTPUT_PATH, f"{epoch}-{i}", model, optimizer)
+                training_stopped = True
+                break
+
+            loss.backward()
+            optimizer.step()
+
+            i += 1
+
+    if not training_stopped:
+        save_training_step(OUTPUT_PATH, "final", model, optimizer)
diff --git a/src/pipelines/actions_based/train.py b/src/pipelines/actions_based/train.py
index e6ed38e..6973c53 100755
--- a/src/pipelines/actions_based/train.py
+++ b/src/pipelines/actions_based/train.py
@@ -3,6 +3,7 @@
 import glob
 import pickle
 from datetime import datetime
+from src.models.actions_model_base import ActionsModelBase, ActionsModelBaseLoss
 
 import dask.dataframe as dd
 import numpy as np
@@ -51,10 +52,8 @@ if __name__ == "__main__":
     df = dd.read_parquet(INPUT_PATH, engine="pyarrow")
     tokenizer = BertTokenizerFast.from_pretrained(base_model)
 
-    model = BertForTokenClassification.from_pretrained(
-        base_model, num_labels=len(ACTIONS_KEYS)
-    ).to(device)
-    criterion = BCEWithLogitsLoss(pos_weight=pos_weight).to(device)
+    model = ActionsModelBase(base_model, len(ACTIONS_KEYS)).to(device)
+    criterion = ActionsModelBaseLoss(pos_weight).to(device)
     optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
 
     epoch_start = 0
@@ -78,7 +77,6 @@ if __name__ == "__main__":
             print(f"Loaded {epoch}-{batch}")
 
     model.train()
-    model.base_model.train()
     losses = []
 
     num_samples = df.tail(1).index.values[0] + 1
@@ -111,7 +109,7 @@ if __name__ == "__main__":
             outputs = torch.tensor(np.stack(outputs)).to(device)
             attentions_mask = torch.tensor(np.stack(attentions_mask)).to(device)
 
-            y_pred = model(input_ids=inputs, attention_mask=attentions_mask)[0]
+            y_pred = model(input_ids=inputs, attention_mask=attentions_mask)
 
             loss = criterion(y_pred, outputs)
 
diff --git a/tests/models/__init__.py b/tests/models/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/models/test_actions_model_base.py b/tests/models/test_actions_model_base.py
new file mode 100644
index 0000000..ee3459b
--- /dev/null
+++ b/tests/models/test_actions_model_base.py
@@ -0,0 +1,40 @@
+import torch
+import torch.distributions as dist
+from transformers.tokenization_bert import BertTokenizerFast
+from src.models.actions_model_base import ActionsModelBase, ActionsModelBaseLoss
+
+
+def test_dimensions():
+    base_model = "dkleczek/bert-base-polish-cased-v1"
+    action_vector_size = 5
+
+    tokens = BertTokenizerFast.from_pretrained(base_model)(
+        "Ala ma kota", return_tensors="pt"
+    )
+    model = ActionsModelBase(base_model, action_vector_size)
+
+    result = model(tokens["input_ids"], tokens["attention_mask"])
+
+    assert len(result.shape) == 3
+
+    assert result.shape[0] == tokens["input_ids"].shape[0]
+    assert result.shape[1] == tokens["input_ids"].shape[1]
+    assert result.shape[2] == action_vector_size
+
+
+def test_loss_dimensions():
+    batch_size = 5
+    sequence_len = 10
+    actions_size = 3
+    weights = torch.zeros(actions_size) + 0.3
+    actions_vector_true = torch.zeros((batch_size, sequence_len, actions_size))
+    actions_vector_bad = torch.ones((batch_size, sequence_len, actions_size))
+    loss = ActionsModelBaseLoss(weights)
+
+    result = loss(actions_vector_true, actions_vector_bad)
+    assert len(result.shape) == 0
+
+    result_perfect = loss(actions_vector_true, actions_vector_true)
+    result_bad = loss(actions_vector_true, actions_vector_bad)
+
+    assert result_perfect < result_bad
\ No newline at end of file
diff --git a/tests/models/test_actions_model_restricted.py b/tests/models/test_actions_model_restricted.py
new file mode 100644
index 0000000..679ab58
--- /dev/null
+++ b/tests/models/test_actions_model_restricted.py
@@ -0,0 +1,50 @@
+from src.models.actions_model_restricted import (
+    ActionsModelRestricted,
+    ActionsModelRestrictedLoss,
+)
+import torch
+import torch.distributions as dist
+from transformers.tokenization_bert import BertTokenizerFast
+
+
+def test_dimensions():
+    base_model = "dkleczek/bert-base-polish-cased-v1"
+    action_vector_size = 5
+
+    tokens = BertTokenizerFast.from_pretrained(base_model)(
+        "Ala ma kota", return_tensors="pt"
+    )
+    model = ActionsModelRestricted(base_model, action_vector_size)
+
+    result = model(tokens["input_ids"], tokens["attention_mask"])
+
+    assert len(result.shape) == 3
+
+    assert result.shape[0] == tokens["input_ids"].shape[0]
+    assert result.shape[1] == tokens["input_ids"].shape[1]
+    assert result.shape[2] == action_vector_size
+
+
+def test_loss_dimensions():
+    batch_size = 5
+    sequence_len = 10
+    action_vector_size = 4
+    prior_probs = torch.tensor([0.3, 0.2, 0.3, 0.5]).log()
+    loss = ActionsModelRestrictedLoss(prior_probs)
+
+    actions_vector_true = torch.zeros((batch_size, sequence_len, action_vector_size))
+    actions_vector_true[:, :, -1] = 1.0
+
+    actions_vector_bad = torch.ones((batch_size, sequence_len, action_vector_size))
+    actions_vector_bad[:, :, -1] = 0.0
+
+    result = loss(actions_vector_true, actions_vector_bad)
+    assert len(result.shape) == 0
+
+    result_perfect = loss(actions_vector_true, actions_vector_true)
+    result_bad = loss(actions_vector_true, actions_vector_bad)
+
+    print(result_perfect)
+    print(result_bad)
+
+    assert result_perfect < result_bad
\ No newline at end of file
-- 
GitLab
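
A small numeric sketch of the restricted loss above: a Bernoulli log-likelihood for the uppercase entry plus a OneHotCategorical log-likelihood for the punctuation block, each corrected by prior log-probabilities (all values here are made-up dummies):

import torch

logits = torch.tensor([[[2.0, 3.0, -1.0, -1.0, 0.5]]])  # B x L x (1 + 4)
target = torch.tensor([[[1.0, 1.0, 0.0, 0.0, 0.0]]])    # uppercase bit + first punctuation class
prior_log_probs = torch.tensor([0.5, 0.3, 0.05, 0.05, 0.6]).log()  # assumed priors

upper_ll = torch.distributions.Bernoulli(logits=logits[:, :, 0]).log_prob(target[:, :, 0])
punct_ll = torch.distributions.OneHotCategorical(logits=logits[:, :, 1:]).log_prob(target[:, :, 1:])

# Subtract the prior of the true class, as in ActionsModelRestrictedLoss
prior_index = target[:, :, 1:].argmax(-1)
loss = -torch.mean(
    (upper_ll - prior_log_probs[0]) + (punct_ll - prior_log_probs[1 + prior_index])
)
print(loss.item())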


From 18646d815971546ce8634277c13d1f25697b3c79 Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Thu, 13 Aug 2020 14:48:09 +0200
Subject: [PATCH 058/116] Pipeline for restricted version of actions model

---
 src/models/actions_model_restricted.py        | 12 ++++----
 ...rain restricted.py => train_restricted.py} | 29 ++++++++++++++-----
 2 files changed, 28 insertions(+), 13 deletions(-)
 rename src/pipelines/actions_based/{train restricted.py => train_restricted.py} (87%)

diff --git a/src/models/actions_model_restricted.py b/src/models/actions_model_restricted.py
index 1876422..1172ae9 100644
--- a/src/models/actions_model_restricted.py
+++ b/src/models/actions_model_restricted.py
@@ -27,23 +27,25 @@ class ActionsModelRestricted(nn.Module):
         self.core = BertForTokenClassification(config)
 
     def forward(
-        self, input_ids: torch.Tensor, attentions_mask: torch.Tensor
+        self, input_ids: torch.Tensor, attention_mask: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Computes logits for uppercasing and adding punctuation to a word
 
         Args:
             input_ids (torch.Tensor): Array of ids of tokens. Shape BxL
-            attentions_mask (torch.Tensor): Mask telling if a token should be masked out (i.e. padding). Shape BxL
+            attention_mask (torch.Tensor): Mask telling if a token should be masked out (i.e. padding). Shape BxL
 
         Returns:
             Tuple[torch.Tensor, torch.Tensor]: Logit for makeing each word uppercase and for adding a punctuation mark to each word
         """
-        y_pred = self.core(input_ids=input_ids, attention_mask=attentions_mask)[0]
+        y_pred = self.core(input_ids=input_ids, attention_mask=attention_mask)[0]
+
+        pred_uppercase = y_pred[:, :, :1]
 
         # Force punctuation entries to form a proper categorical log-odds distribution
-        y_pred[:, :, 1:] = self._logit(torch.softmax(y_pred[:, :, 1:], -1))
+        pred_punctuation = self._logit(torch.softmax(y_pred[:, :, 1:], -1))
 
-        return y_pred
+        return torch.cat([pred_uppercase, pred_punctuation], -1)
 
     def loss_func(
         true_extended_action_vector: torch.Tensor,
diff --git a/src/pipelines/actions_based/train restricted.py b/src/pipelines/actions_based/train_restricted.py
similarity index 87%
rename from src/pipelines/actions_based/train restricted.py
rename to src/pipelines/actions_based/train_restricted.py
index e6ed38e..8d862b0 100755
--- a/src/pipelines/actions_based/train restricted.py	
+++ b/src/pipelines/actions_based/train_restricted.py
@@ -4,6 +4,12 @@ import glob
 import pickle
 from datetime import datetime
 
+from torch import dtype
+from src.models.actions_model_restricted import (
+    ActionsModelRestricted,
+    ActionsModelRestrictedLoss,
+)
+
 import dask.dataframe as dd
 import numpy as np
 import torch
@@ -45,16 +51,20 @@ if __name__ == "__main__":
     with open(f"{INPUT_STATS_PATH}/stats.pickle", "rb") as f:
         stats = pickle.load(f)
         pos_examples = stats["class_number"]
-        neg_examples = stats["num_examples"] - stats["class_number"]
-        pos_weight = torch.tensor(neg_examples / pos_examples)
+
+        probs = pos_examples / stats["num_examples"]
+
+        # Add no-punctuation prob
+        probs = np.concatenate([probs, np.array([1.0 - np.sum(probs[1:])])], -1)
+        assert probs[-1] >= 0.0 and probs[-1] <= 1.0
+
+        probs = torch.tensor(probs).to(device)
 
     df = dd.read_parquet(INPUT_PATH, engine="pyarrow")
     tokenizer = BertTokenizerFast.from_pretrained(base_model)
 
-    model = BertForTokenClassification.from_pretrained(
-        base_model, num_labels=len(ACTIONS_KEYS)
-    ).to(device)
-    criterion = BCEWithLogitsLoss(pos_weight=pos_weight).to(device)
+    model = ActionsModelRestricted(base_model, len(ACTIONS_KEYS) + 1).to(device)
+    criterion = ActionsModelRestrictedLoss(probs).to(device)
     optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
 
     epoch_start = 0
@@ -78,7 +88,6 @@ if __name__ == "__main__":
             print(f"Loaded {epoch}-{batch}")
 
     model.train()
-    model.base_model.train()
     losses = []
 
     num_samples = df.tail(1).index.values[0] + 1
@@ -111,7 +120,11 @@ if __name__ == "__main__":
             outputs = torch.tensor(np.stack(outputs)).to(device)
             attentions_mask = torch.tensor(np.stack(attentions_mask)).to(device)
 
-            y_pred = model(input_ids=inputs, attention_mask=attentions_mask)[0]
+            y_pred = model(input_ids=inputs, attention_mask=attentions_mask)
+
+            outputs = torch.cat(
+                [outputs, (1.0 - outputs[:, :, 1:].max(-1)[0]).unsqueeze(-1)], axis=-1
+            )
 
             loss = criterion(y_pred, outputs)
 
-- 
GitLab
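
The target extension introduced in this patch derives the no-punctuation
column directly from the existing punctuation columns. A toy illustration of
that torch.cat line; the action layout [upper_case, dot, question_mark] is
shortened here for brevity:

    import torch

    # B=1, L=2, A=3: [upper_case, dot, question_mark]
    outputs = torch.tensor([[[1.0, 1.0, 0.0],    # uppercased word, then a dot
                             [0.0, 0.0, 0.0]]])  # plain word, no punctuation

    # No-punctuation column: 1 - max over the punctuation entries.
    no_punct = (1.0 - outputs[:, :, 1:].max(-1)[0]).unsqueeze(-1)
    extended = torch.cat([outputs, no_punct], dim=-1)

    print(extended)
    # tensor([[[1., 1., 0., 0.],
    #          [0., 0., 0., 1.]]])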


From b5d380848922d3a6a3bffd71b5cc422aff2999e1 Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Mon, 17 Aug 2020 20:25:49 +0200
Subject: [PATCH 059/116] Restricted actions model based on reverse
 probability

---
 src/models/actions_model_restricted.py | 82 +++++++++++---------------
 1 file changed, 34 insertions(+), 48 deletions(-)

diff --git a/src/models/actions_model_restricted.py b/src/models/actions_model_restricted.py
index 1172ae9..48a2c8e 100644
--- a/src/models/actions_model_restricted.py
+++ b/src/models/actions_model_restricted.py
@@ -6,6 +6,22 @@ from src.pipelines.actions_based.processing import ACTIONS_KEYS
 import torch
 
 
+def logit(x: torch.Tensor) -> torch.Tensor:
+    EPS = torch.tensor(1e-5)
+
+    z = torch.clamp(x, EPS, 1.0 - EPS)
+
+    return torch.log(z / (1 - z))
+
+
+class LogProb2Logit(nn.Module):
+    def __init__(self) -> None:
+        super(LogProb2Logit, self).__init__()
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return logit(x.exp())
+
+
 class ActionsModelRestricted(nn.Module):
     """Similar to ActionsModelBase, however no-punctuation class is added
     internally, and punctuation-related entries are treaded as proper binomial
@@ -47,44 +63,6 @@ class ActionsModelRestricted(nn.Module):
 
         return torch.cat([pred_uppercase, pred_punctuation], -1)
 
-    def loss_func(
-        true_extended_action_vector: torch.Tensor,
-        predicted_action_vector_logits: torch.Tensor,
-        prior_log_probs: torch.Tensor,
-    ) -> torch.Tensor:
-        """Loss function for actions model
-
-        Args:
-            true_uppercase (torch.Tensor): Tensor teling if token should be uppercased (1.0) or not (0.0) (BxL)
-            true_punctuation_type (torch.Tensor): If punctuation sign should be added after word (BxLxP)
-            uppercase_logit (torch.Tensor): Logit of uppercasing probability predited by model
-            punctuation_type_logit (torch.Tensor): Logits of punctuation probability predicted by model
-
-        Returns:
-            torch.Tensor: Loss value
-        """
-        uppercase_cond_log_prob = torch.distributions.Bernoulli(
-            logits=predicted_action_vector_logits[:, :, 0]
-        ).log_prob(true_extended_action_vector[:, :, 0])
-
-        punctuation_cond_log_prob = torch.distributions.OneHotCategorical(
-            logits=predicted_action_vector_logits[:, :, 1:]
-        ).log_prob(true_extended_action_vector[0, 0, 1:])
-
-        model_uppercase_cond_prob = uppercase_cond_log_prob - prior_log_probs[0]
-
-        prior_index = true_extended_action_vector[:, :, 1:].argmax(-1)
-
-        model_punctuation_cond_prob = (
-            punctuation_cond_log_prob - prior_log_probs[1 + prior_index]
-        )
-
-        # An assumption that Uppercase ⊥ Punctuation, however it's not terrible as we only
-        # assume that uppercasing of current word is independent of current-word punctuation.
-        # In most cases it should be rather true. Also use mean instead of sum to get batch
-        # size indepnedence.
-        return -torch.mean(model_uppercase_cond_prob + model_punctuation_cond_prob)
-
     @staticmethod
     def _logit(x: torch.Tensor):
         EPS = 1e-5
@@ -99,6 +77,7 @@ class ActionsModelRestrictedLoss(nn.Module):
         super(ActionsModelRestrictedLoss, self).__init__()
 
         self.prior_log_probs = prior_log_probs
+        self.logprob2logit = LogProb2Logit()
 
     def forward(
         self,
@@ -116,24 +95,31 @@ class ActionsModelRestrictedLoss(nn.Module):
         Returns:
             torch.Tensor: Loss value
         """
+
+        # P(Uppercase, Words) = P(Words | Uppercase) * P(Uppercase)
         uppercase_cond_log_prob = torch.distributions.Bernoulli(
-            logits=predicted_action_vector_logits[:, :, 0]
-        ).log_prob(true_extended_action_vector[:, :, 0])
+            logits=predicted_action_vector_logits[:, :, 0].reshape(-1)
+        ).log_prob(true_extended_action_vector[:, :, 0].reshape(-1))
+
+        uppercase_prior_log_prob = torch.distributions.Bernoulli(
+            logits=self.logprob2logit(self.prior_log_probs[0])
+        ).log_prob(true_extended_action_vector[:, :, 0].reshape(-1))
 
+        uppercase_log_prob = uppercase_cond_log_prob + uppercase_prior_log_prob
+
+        # P(Punctuation, Words) = P(Words | Punctuation) * P(Punctuation)
         punctuation_cond_log_prob = torch.distributions.OneHotCategorical(
             logits=predicted_action_vector_logits[:, :, 1:]
-        ).log_prob(true_extended_action_vector[0, 0, 1:])
-
-        model_uppercase_cond_prob = uppercase_cond_log_prob - self.prior_log_probs[0]
+        ).log_prob(true_extended_action_vector[:, :, 1:])
 
-        prior_index = true_extended_action_vector[:, :, 1:].argmax(-1)
+        punctuation_prior_log_prob = torch.distributions.OneHotCategorical(
+            logits=self.logprob2logit(self.prior_log_probs[1:])
+        ).log_prob(true_extended_action_vector[:, :, 1:])
 
-        model_punctuation_cond_prob = (
-            punctuation_cond_log_prob - self.prior_log_probs[1 + prior_index]
-        )
+        punctuation_log_prob = punctuation_cond_log_prob + punctuation_prior_log_prob
 
         # We assume Uppercase ⊥ Punctuation. This is not a terrible assumption,
         # as we only require that uppercasing of the current word is independent
         # of the current word's punctuation, which should usually hold. Use mean
         # instead of sum to get batch-size independence.
-        return -torch.mean(model_uppercase_cond_prob + model_punctuation_cond_prob)
+        return -torch.mean(uppercase_log_prob) - torch.mean(punctuation_log_prob)
-- 
GitLab
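
The logit/LogProb2Logit helpers added here convert stored log-probabilities
into the logits that the torch distributions expect; both parameterizations
yield identical log-likelihoods. A quick self-contained check with an
arbitrary probability:

    import torch

    def logit(x: torch.Tensor) -> torch.Tensor:
        EPS = torch.tensor(1e-5)
        z = torch.clamp(x, EPS, 1.0 - EPS)
        return torch.log(z / (1 - z))

    p = torch.tensor(0.25)
    log_p = p.log()  # what LogProb2Logit receives
    assert torch.isclose(logit(log_p.exp()), torch.log(p / (1 - p)))

    # Bernoulli(logits=...) and Bernoulli(probs=...) agree on the likelihood.
    one = torch.tensor(1.0)
    assert torch.isclose(
        torch.distributions.Bernoulli(logits=logit(p)).log_prob(one),
        torch.distributions.Bernoulli(probs=p).log_prob(one),
    )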


From 5f42c111ce790ecd3f91ea78a5c3e380b2f11572 Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Tue, 18 Aug 2020 08:33:07 +0200
Subject: [PATCH 060/116] Added mixed model and train script generalization

---
 dvc.yaml                                      |  47 +++-
 params.yaml                                   |  32 ++-
 src/models/TransformerSeq2Seq.py              | 110 +--------
 src/models/actions_model_mixed.py             | 143 ++++++++++++
 src/models/actions_model_restricted.py        |  69 +-----
 src/models/common.py                          | 129 +++++++++++
 src/pipelines/actions_based/processing.py     |  12 +-
 src/pipelines/actions_based/train.py          | 140 ------------
 src/pipelines/actions_based/train_base.py     |  90 ++++++++
 src/pipelines/actions_based/train_mixed.py    | 116 ++++++++++
 .../actions_based/train_restricted.py         | 215 ++++++++----------
 src/pipelines/train.py                        | 176 ++++++++++++++
 src/pipelines/translation_based/processing.py |   5 +-
 src/utils.py                                  |  29 ++-
 .../actions_based/test_processing.py          |  16 +-
 tests/test_utils.py                           |  28 ++-
 16 files changed, 895 insertions(+), 462 deletions(-)
 create mode 100644 src/models/actions_model_mixed.py
 create mode 100644 src/models/common.py
 delete mode 100755 src/pipelines/actions_based/train.py
 create mode 100755 src/pipelines/actions_based/train_base.py
 create mode 100755 src/pipelines/actions_based/train_mixed.py
 create mode 100644 src/pipelines/train.py

diff --git a/dvc.yaml b/dvc.yaml
index bc29243..a73590c 100644
--- a/dvc.yaml
+++ b/dvc.yaml
@@ -40,21 +40,61 @@ stages:
     - src/pipelines/actions_based/stage5_stats.py
     outs:
     - generated/actions/stage5_stats
-  actions_training:
-    cmd: python3 -m src.pipelines.actions_based.train
+  actions_base_training:
+    cmd: python3 -m src.pipelines.actions_based.train_base
     deps:
     - generated/actions/stage4_reindexing
     - generated/actions/stage5_stats
     - src/pipelines/actions_based/train.py
     params:
     - global.base_model
+    - global.random_seed
     - actions.training.max_training_time
     - actions.training.learning_rate
     - actions.training.num_epochs
     - actions.training.batch_size
     - actions.training.save_step
     outs:
-    - checkpoints/actions
+    - checkpoints/actions_base
+
+  actions_restricted_training:
+    cmd: python3 -m src.pipelines.actions_based.train_restricted
+    deps:
+    - generated/actions/stage4_reindexing
+    - generated/actions/stage5_stats
+    - src/pipelines/actions_based/train_restricted.py
+    params:
+    - global.base_model
+    - global.random_seed
+    - actions.training_restricted.max_training_time
+    - actions.training_restricted.learning_rate
+    - actions.training_restricted.num_epochs
+    - actions.training_restricted.batch_size
+    - actions.training_restricted.save_step
+    outs:
+    - checkpoints/actions_restricted
+
+  actions_mixed_training:
+    cmd: python3 -m src.pipelines.actions_based.train_mixed
+    deps:
+    - generated/actions/stage4_reindexing
+    - generated/actions/stage5_stats
+    - src/pipelines/actions_based/train_mixed.py
+    params:
+    - global.base_model
+    - global.random_seed
+    - actions.training_mixed.embedding_size
+    - actions.training_mixed.num_heads
+    - actions.training_mixed.num_layers
+    - actions.training_mixed.dropout
+    - actions.training_mixed.feedforward_neurons
+    - actions.training_mixed.max_training_time
+    - actions.training_mixed.learning_rate
+    - actions.training_mixed.num_epochs
+    - actions.training_mixed.batch_size
+    - actions.training_mixed.save_step
+    outs:
+    - checkpoints/actions_mixed
   translations_extraction:
     cmd: python3 -m src.pipelines.translation_based.stage1_extraction
     deps:
@@ -90,6 +130,7 @@ stages:
     - src/pipelines/translation_based/train.py
     params:
     - global.base_model
+    - global.random_seed
     - translations.training.max_training_time
     - translations.training.learning_rate
     - translations.training.num_epochs
diff --git a/params.yaml b/params.yaml
index 3b9977d..39c7acf 100644
--- a/params.yaml
+++ b/params.yaml
@@ -27,11 +27,39 @@ actions:
         num_workers: 24
         worker_memory_limit: "2GB"
 
-    training:
+    training_base:
+        learning_rate: 0.0001
+        num_epochs: 5
+        batch_size: 2
+        batch_buffer_size: 100
+        save_step: 1000
+        max_training_time: null
+        loss_averaging_span: 1000
+        fresh_start: true
+        device: "cuda:0"
+
+    training_restricted:
         learning_rate: 0.0001
         num_epochs: 5
         batch_size: 2
-        save_step: 100
+        batch_buffer_size: 100
+        save_step: 1000
+        max_training_time: null
+        loss_averaging_span: 1000
+        fresh_start: true
+        device: "cuda:0"
+
+    training_mixed:
+        embedding_size: 200
+        num_heads: 4
+        num_layers: 2
+        dropout: 0.1
+        feedforward_neurons: 500
+        learning_rate: 0.0001
+        num_epochs: 5
+        batch_size: 2
+        batch_buffer_size: 1000
+        save_step: 1000
         max_training_time: null
         loss_averaging_span: 1000
         fresh_start: true
diff --git a/src/models/TransformerSeq2Seq.py b/src/models/TransformerSeq2Seq.py
index 753df1e..057a124 100644
--- a/src/models/TransformerSeq2Seq.py
+++ b/src/models/TransformerSeq2Seq.py
@@ -3,111 +3,5 @@ import math
 import torch
 import torch.nn as nn
 
-
-class PositionalEncoding(nn.Module):
-    """Adds sinsusoidal positional encoding (as in original "Attention is all you need" paper.)
-    src: https://pytorch.org/tutorials/beginner/transformer_tutorial.html
-
-    """
-
-    def __init__(self, d_model: int, max_len: int, dropout=0.1):
-        """Sinusidal positional encodings
-
-        Args:
-            d_model (int): Embedding dimension
-            max_len (int): Maximum length of sequence
-            dropout (float, optional): Dropout ratio. Defaults to 0.1.
-        """
-        super(PositionalEncoding, self).__init__()
-        self.dropout = nn.Dropout(p=dropout)
-
-        pe = torch.zeros(max_len, d_model)
-        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
-        div_term = torch.exp(
-            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
-        )
-        pe[:, 0::2] = torch.sin(position * div_term)
-        pe[:, 1::2] = torch.cos(position * div_term)
-        pe = pe.unsqueeze(0).transpose(0, 1)
-        self.register_buffer("pe", pe)
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """Applies positional encoding
-
-        Args:
-            x (torch.Tensor): Word embeddings tensor
-
-        Returns:
-            torch.Tensor: Word embeddings with added positional encodings
-        """
-        x = x + self.pe[: x.size(0), :]
-        return self.dropout(x)
-
-
-class TransformerSeq2Seq(nn.Module):
-    """Class representing a sequence to sequence transformer, based on original "Attention is all you need" paper."""
-
-    def __init__(
-        self,
-        vocab_size: int,
-        embedding_size: int,
-        max_len: int,
-        num_heads: int = 8,
-        encoder_layers: int = 6,
-        decoder_layers: int = 6,
-        feedforward_neurons: int = 2048,
-        dropout: float = 0.1,
-    ):
-
-        super(TransformerSeq2Seq, self).__init__()
-
-        # Embedd from token to vec space
-        self.word_embedding = nn.Embedding(vocab_size, embedding_size)
-
-        # Add positional encoding
-        self.position_embedding = PositionalEncoding(embedding_size, max_len, dropout)
-
-        # Combined encoder-decoder step
-        self.core = nn.Transformer(
-            embedding_size,
-            num_heads,
-            encoder_layers,
-            decoder_layers,
-            feedforward_neurons,
-            dropout,
-        )
-
-        # Map embedding to word
-        self.embedding_to_words = nn.Linear(embedding_size, vocab_size)
-
-    def forward(
-        self, source: torch.Tensor, target: torch.Tensor, source_mask: torch.Tensor,
-    ) -> torch.Tensor:
-        """Full encoder-decoder pass
-
-        Args:
-            source (torch.Tensor): Tensor with batch of source sentences tokens [BxL shape]
-            target (torch.Tensor): Tensor with batch of target sentences tokens [BxL-1 shape]
-            source_mask (torch.Tensor): Mask applied to source (True if element is padding, False otherwise) [BxL shape]
-
-        Returns:
-            torch.Tensor: Tensor with predicted target sentences tokens [Bx(L-1)xV]
-        """
-        # Input to encoder
-        x = source.transpose(0, 1)
-        x = self.word_embedding(x)
-        x = self.position_embedding(x)
-
-        # Input to decoder
-        y = target.transpose(0, 1)
-        y = self.word_embedding(y)
-        y = self.position_embedding(y)
-
-        tgt_mask = self.core.generate_square_subsequent_mask(y.shape[0]).to(y.device)
-
-        z = self.core(
-            x, y, src_key_padding_mask=source_mask, tgt_mask=tgt_mask
-        ).transpose(1, 0)
-        z = self.embedding_to_words(z)
-
-        return z
+from src.models.common import PositionalEncoding
+from src.models.common import TransformerSeq2Seq
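
TransformerSeq2Seq.py is thereby reduced to a re-export shim: existing import
sites keep working while the implementation moves to src/models/common.py. A
quick sanity check, assuming the package is on the path:

    from src.models.TransformerSeq2Seq import TransformerSeq2Seq as OldPath
    from src.models.common import TransformerSeq2Seq as NewPath

    # Both names refer to the same class object after the move.
    assert OldPath is NewPath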
diff --git a/src/models/actions_model_mixed.py b/src/models/actions_model_mixed.py
new file mode 100644
index 0000000..acd7f47
--- /dev/null
+++ b/src/models/actions_model_mixed.py
@@ -0,0 +1,143 @@
+from typing import Callable, Tuple
+import torch.nn as nn
+from torch.nn.modules.loss import BCEWithLogitsLoss
+from transformers.configuration_utils import PretrainedConfig
+from transformers.modeling_bert import BertForTokenClassification
+from src.pipelines.actions_based.processing import ACTIONS_KEYS
+import torch
+from src.models.common import PositionalEncoding, generate_square_subsequent_mask
+
+
+class ActionsModelMixed(nn.Module):
+    """Encoder-decoder based model with unpunctuated token sequence as input and array of action-vectors as output"""
+
+    def __init__(
+        self,
+        vocab_size: int,
+        embedding_size: int = 200,
+        num_heads: int = 4,
+        num_layers: int = 2,
+        feedforward_neurons: int = 200,
+        num_labels: int = len(ACTIONS_KEYS),
+        max_len: int = 500,
+        dropout: float = 0.1,
+    ) -> None:
+        """Initializes mixed model
+
+        Args:
+            vocab_size (int): Number of tokens in tokenizer dictionary
+            embedding_size (int, optional): Size of word and punctuation embeddings. Defaults to 200.
+            num_heads (int, optional): Number of heads in multiheaded attention. Defaults to 4.
+            num_layers (int, optional): Number of both encoder and decoder layers. Defaults to 2.
+            feedforward_neurons (int, optional): Size of feed-forward neural network at the end of encoder/decoder. Defaults to 200.
+            num_labels (int, optional): Action-vector size. Defaults to len(ACTIONS_KEYS).
+            max_len (int, optional): Maximum length of sequence. Defaults to 500.
+            dropout (float, optional): Dropout ratio. Defaults to 0.1.
+        """
+        super(ActionsModelMixed, self).__init__()
+
+        # Word embedder
+        self.word_embedding = nn.Embedding(vocab_size, embedding_size)
+        self.punctuation_embedding = nn.Linear(num_labels, embedding_size)
+
+        # Add positional encoding
+        self.words_position_embedding = PositionalEncoding(
+            embedding_size, max_len, dropout
+        )
+        self.punctuation_position_embedding = PositionalEncoding(
+            embedding_size, max_len, dropout
+        )
+
+        # Sentence encoder
+        sentence_encoder_layer = nn.TransformerEncoderLayer(
+            embedding_size, num_heads, feedforward_neurons, dropout
+        )
+        self.sentence_encoder = nn.TransformerEncoder(
+            sentence_encoder_layer, num_layers=num_layers
+        )
+
+        # Punctuation decoder
+        punctuation_decoder_layer = nn.TransformerDecoderLayer(
+            embedding_size, num_heads, feedforward_neurons, dropout
+        )
+        self.punctuation_decoder = nn.TransformerDecoder(
+            punctuation_decoder_layer, num_layers=num_layers
+        )
+
+        self.to_labels = nn.Linear(embedding_size, num_labels)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        actions: torch.Tensor,
+        attention_mask: torch.Tensor,
+    ) -> torch.Tensor:
+        """Computes action vectors array from array of tokens
+
+        Args:
+            input_ids (torch.Tensor): Tokens representing unpunctuated text. Shape BxL
+            actions (torch.Tensor): Action vectors predicted up till now. Shape Bx(L-1)xA
+            attention_mask (torch.Tensor): Mask telling if a token is padding (True) or not (False). Shape BxL
+
+        Returns:
+            torch.Tensor: Predicted actions shifted one step to the left. Shape Bx(L-1)xA
+        """
+
+        # Input to encoder
+        x = input_ids.transpose(0, 1)
+        x = self.word_embedding(x)
+        x = self.words_position_embedding(x)
+
+        # Input to decoder
+        y = actions.transpose(0, 1)
+        y = self.punctuation_embedding(y)
+        y = self.punctuation_position_embedding(y)
+
+        tgt_mask = generate_square_subsequent_mask(y.shape[0]).to(y.device)
+
+        sentence_encoded = self.sentence_encoder(x, src_key_padding_mask=attention_mask)
+
+        actions_decoded = self.punctuation_decoder(
+            y, sentence_encoded, tgt_mask=tgt_mask
+        )
+
+        z = actions_decoded.transpose(1, 0)
+
+        return self.to_labels(z)
+
+
+class ActionsModelMixedLoss(nn.Module):
+    """Class representing proposed loss for training mixed actions model"""
+
+    def __init__(self, prior_odds: torch.Tensor) -> None:
+        """Initializes ActionsModelMixedLoss
+
+        Args:
+            prior_odds (torch.Tensor): Ratio of negative to positive examples for each label in the action vector, applied as BCE pos_weight. Shape A
+        """
+        super(ActionsModelMixedLoss, self).__init__()
+
+        self.core = BCEWithLogitsLoss(pos_weight=prior_odds)
+
+    def forward(
+        self,
+        true_action_vector: torch.Tensor,
+        predicted_action_vector_logits: torch.Tensor,
+    ) -> torch.Tensor:
+        """Computes loss for training mixed actions model
+
+        Args:
+            true_action_vector (torch.Tensor): Action vector that should be
+                predicted by ActionsModelMixed (shifted by 1 to the left with
+                respect to the inputs). Shape Bx(L-1)xA
+
+            predicted_action_vector_logits (torch.Tensor): Action vector that
+                was actually predicted by ActionsModelMixed (shifted by 1 to
+                the left with respect to the inputs). Shape Bx(L-1)xA
+
+        Returns:
+            torch.Tensor: Loss of the prediction with respect to the ground truth
+        """
+
+        return self.core(predicted_action_vector_logits, true_action_vector)
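
The mixed model is autoregressive over action vectors: it consumes token ids
of shape BxL plus the actions predicted so far, shaped Bx(L-1)xA, and returns
logits of the same Bx(L-1)xA shape. A smoke test of that contract, with all
sizes hypothetical:

    import torch

    from src.models.actions_model_mixed import ActionsModelMixed

    model = ActionsModelMixed(
        vocab_size=1000, embedding_size=64, num_heads=4, num_layers=1,
        feedforward_neurons=128, num_labels=4, max_len=50, dropout=0.1,
    )

    input_ids = torch.randint(0, 1000, (2, 10))          # BxL token ids
    actions = torch.zeros(2, 9, 4)                       # Bx(L-1)xA (teacher forcing)
    padding_mask = torch.zeros(2, 10, dtype=torch.bool)  # True marks padding

    assert model(input_ids, actions, padding_mask).shape == (2, 9, 4)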
diff --git a/src/models/actions_model_restricted.py b/src/models/actions_model_restricted.py
index 48a2c8e..e73df2f 100644
--- a/src/models/actions_model_restricted.py
+++ b/src/models/actions_model_restricted.py
@@ -6,33 +6,17 @@ from src.pipelines.actions_based.processing import ACTIONS_KEYS
 import torch
 
 
-def logit(x: torch.Tensor) -> torch.Tensor:
-    EPS = torch.tensor(1e-5)
-
-    z = torch.clamp(x, EPS, 1.0 - EPS)
-
-    return torch.log(z / (1 - z))
-
-
-class LogProb2Logit(nn.Module):
-    def __init__(self) -> None:
-        super(LogProb2Logit, self).__init__()
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return logit(x.exp())
-
-
 class ActionsModelRestricted(nn.Module):
     """Similar to ActionsModelBase, however no-punctuation class is added
-    internally, and punctuation-related entries are treaded as proper binomial
-    probability distribution
+    and punctuation-related entries are treated as a proper categorical distribution
     """
 
     def __init__(self, base_model: str, extended_action_vector_size: int) -> None:
-        """Initializes actions model
+        """Initializes restricted actions model
 
         Args:
             base_model (str): Name of base model
+            extended_action_vector_size (int): Action-vector size including additional no-punctuation logit
         """
         super(ActionsModelRestricted, self).__init__()
 
@@ -44,7 +28,7 @@ class ActionsModelRestricted(nn.Module):
 
     def forward(
         self, input_ids: torch.Tensor, attention_mask: torch.Tensor
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    ) -> torch.Tensor:
         """Computes logits for uppercasing and adding punctuation to a word
 
         Args:
@@ -52,13 +36,13 @@ class ActionsModelRestricted(nn.Module):
             attention_mask (torch.Tensor): Mask telling if a token should be masked out (i.e. padding). Shape BxL
 
         Returns:
-            Tuple[torch.Tensor, torch.Tensor]: Logit for makeing each word uppercase and for adding a punctuation mark to each word
+            torch.Tensor: Logits for making each word uppercase and for adding a punctuation mark after each word. Shape BxLxA
         """
         y_pred = self.core(input_ids=input_ids, attention_mask=attention_mask)[0]
 
         pred_uppercase = y_pred[:, :, :1]
 
-        # Force punctuations to be proper categorical logodds-distribution
+        # Force punctuations to be proper categorical distribution logits
         pred_punctuation = self._logit(torch.softmax(y_pred[:, :, 1:], -1))
 
         return torch.cat([pred_uppercase, pred_punctuation], -1)
@@ -73,53 +57,24 @@ class ActionsModelRestricted(nn.Module):
 
 
 class ActionsModelRestrictedLoss(nn.Module):
-    def __init__(self, prior_log_probs: torch.Tensor) -> None:
+    def __init__(self, prior_odds: torch.Tensor) -> None:
         super(ActionsModelRestrictedLoss, self).__init__()
 
-        self.prior_log_probs = prior_log_probs
-        self.logprob2logit = LogProb2Logit()
+        self.core = nn.BCEWithLogitsLoss(pos_weight=prior_odds)
 
     def forward(
         self,
         true_extended_action_vector: torch.Tensor,
         predicted_action_vector_logits: torch.Tensor,
     ) -> torch.Tensor:
-        """Loss function for actions model
+        """Loss for ActionsModelRestricted model
 
         Args:
-            true_uppercase (torch.Tensor): Tensor teling if token should be uppercased (1.0) or not (0.0) (BxL)
-            true_punctuation_type (torch.Tensor): If punctuation sign should be added after word (BxLxP)
-            uppercase_logit (torch.Tensor): Logit of uppercasing probability predited by model
-            punctuation_type_logit (torch.Tensor): Logits of punctuation probability predicted by model
+            true_extended_action_vector (torch.Tensor): Ground-truth extended action vectors. Shape BxLxA
+            predicted_action_vector_logits (torch.Tensor): Action-vector logits predicted by the ActionsModelRestricted model. Shape BxLxA
 
         Returns:
             torch.Tensor: Loss value
         """
 
-        # P(Uppercase, Words) = P(Words | Uppercase) * P(Uppercase)
-        uppercase_cond_log_prob = torch.distributions.Bernoulli(
-            logits=predicted_action_vector_logits[:, :, 0].reshape(-1)
-        ).log_prob(true_extended_action_vector[:, :, 0].reshape(-1))
-
-        uppercase_prior_log_prob = torch.distributions.Bernoulli(
-            logits=self.logprob2logit(self.prior_log_probs[0])
-        ).log_prob(true_extended_action_vector[:, :, 0].reshape(-1))
-
-        uppercase_log_prob = uppercase_cond_log_prob + uppercase_prior_log_prob
-
-        # P(Punctuation, Words) = P(Words | Punctuation) * P(Punctuation)
-        punctuation_cond_log_prob = torch.distributions.OneHotCategorical(
-            logits=predicted_action_vector_logits[:, :, 1:]
-        ).log_prob(true_extended_action_vector[:, :, 1:])
-
-        punctuation_prior_log_prob = torch.distributions.OneHotCategorical(
-            logits=self.logprob2logit(self.prior_log_probs[1:])
-        ).log_prob(true_extended_action_vector[:, :, 1:])
-
-        punctuation_log_prob = punctuation_cond_log_prob + punctuation_prior_log_prob
-
-        # We assume Uppercase ⊥ Punctuation. This is not a terrible assumption,
-        # as we only require that uppercasing of the current word is independent
-        # of the current word's punctuation, which should usually hold. Use mean
-        # instead of sum to get batch-size independence.
-        return -torch.mean(uppercase_log_prob) - torch.mean(punctuation_log_prob)
+        return self.core(predicted_action_vector_logits, true_extended_action_vector)
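
With both losses reduced to BCEWithLogitsLoss, the prior_odds tensor acts as
pos_weight: it scales only the positive-class term of the loss, compensating
for how rare each action is. A small illustration with invented class weights:

    import torch
    from torch import nn

    pos_weight = torch.tensor([3.0, 12.0, 40.0, 90.0, 0.2])  # neg/pos per class
    criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

    logits = torch.zeros(2, 7, 5)  # BxLxA, A = 4 actions + no-punctuation
    target = torch.zeros(2, 7, 5)
    target[:, :, -1] = 1.0         # every token: no punctuation

    print(criterion(logits, target))  # scalar loss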
diff --git a/src/models/common.py b/src/models/common.py
new file mode 100644
index 0000000..bcf4bb2
--- /dev/null
+++ b/src/models/common.py
@@ -0,0 +1,129 @@
+import math
+
+import torch
+import torch.nn as nn
+
+
+def generate_square_subsequent_mask(sz):
+    r"""
+    Generate a square mask for the sequence. The masked positions are filled with float('-inf').
+    Unmasked positions are filled with float(0.0).
+
+    Source: torch Transformer class
+    """
+    mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
+    mask = (
+        mask.float()
+        .masked_fill(mask == 0, float("-inf"))
+        .masked_fill(mask == 1, float(0.0))
+    )
+    return mask
+
+
+class PositionalEncoding(nn.Module):
+    """Adds sinsusoidal positional encoding (as in original "Attention is all you need" paper.)
+    src: https://pytorch.org/tutorials/beginner/transformer_tutorial.html
+
+    """
+
+    def __init__(self, d_model: int, max_len: int, dropout=0.1):
+        """Sinusidal positional encodings
+
+        Args:
+            d_model (int): Embedding dimension
+            max_len (int): Maximum length of sequence
+            dropout (float, optional): Dropout ratio. Defaults to 0.1.
+        """
+        super(PositionalEncoding, self).__init__()
+        self.dropout = nn.Dropout(p=dropout)
+
+        pe = torch.zeros(max_len, d_model)
+        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
+        div_term = torch.exp(
+            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
+        )
+        pe[:, 0::2] = torch.sin(position * div_term)
+        pe[:, 1::2] = torch.cos(position * div_term)
+        pe = pe.unsqueeze(0).transpose(0, 1)
+        self.register_buffer("pe", pe)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Applies positional encoding
+
+        Args:
+            x (torch.Tensor): Word embeddings tensor
+
+        Returns:
+            torch.Tensor: Word embeddings with added positional encodings
+        """
+        x = x + self.pe[: x.size(0), :]
+        return self.dropout(x)
+
+
+class TransformerSeq2Seq(nn.Module):
+    """Class representing a sequence to sequence transformer, based on original "Attention is all you need" paper."""
+
+    def __init__(
+        self,
+        vocab_size: int,
+        embedding_size: int,
+        max_len: int,
+        num_heads: int = 8,
+        encoder_layers: int = 6,
+        decoder_layers: int = 6,
+        feedforward_neurons: int = 2048,
+        dropout: float = 0.1,
+    ):
+
+        super(TransformerSeq2Seq, self).__init__()
+
+        # Embed tokens into vector space
+        self.word_embedding = nn.Embedding(vocab_size, embedding_size)
+
+        # Add positional encoding
+        self.position_embedding = PositionalEncoding(embedding_size, max_len, dropout)
+
+        # Combined encoder-decoder step
+        self.core = nn.Transformer(
+            embedding_size,
+            num_heads,
+            encoder_layers,
+            decoder_layers,
+            feedforward_neurons,
+            dropout,
+        )
+
+        # Map embedding to word
+        self.embedding_to_words = nn.Linear(embedding_size, vocab_size)
+
+    def forward(
+        self, source: torch.Tensor, target: torch.Tensor, source_mask: torch.Tensor,
+    ) -> torch.Tensor:
+        """Full encoder-decoder pass
+
+        Args:
+            source (torch.Tensor): Tensor with batch of source sentences tokens [BxL shape]
+            target (torch.Tensor): Tensor with batch of target sentences tokens [BxL-1 shape]
+            source_mask (torch.Tensor): Mask applied to source (True if element is padding, False otherwise) [BxL shape]
+
+        Returns:
+            torch.Tensor: Tensor with predicted target sentences tokens [Bx(L-1)xV]
+        """
+        # Input to encoder
+        x = source.transpose(0, 1)
+        x = self.word_embedding(x)
+        x = self.position_embedding(x)
+
+        # Input to decoder
+        y = target.transpose(0, 1)
+        y = self.word_embedding(y)
+        y = self.position_embedding(y)
+
+        tgt_mask = self.core.generate_square_subsequent_mask(y.shape[0]).to(y.device)
+
+        z = self.core(
+            x, y, src_key_padding_mask=source_mask, tgt_mask=tgt_mask
+        ).transpose(1, 0)
+        z = self.embedding_to_words(z)
+
+        return z
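
generate_square_subsequent_mask builds the additive causal mask used by the
decoder: position i may attend only to positions up to i. For size 3:

    from src.models.common import generate_square_subsequent_mask

    print(generate_square_subsequent_mask(3))
    # tensor([[0., -inf, -inf],
    #         [0., 0., -inf],
    #         [0., 0., 0.]])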
diff --git a/src/pipelines/actions_based/processing.py b/src/pipelines/actions_based/processing.py
index fc5ac1f..7057c0b 100644
--- a/src/pipelines/actions_based/processing.py
+++ b/src/pipelines/actions_based/processing.py
@@ -6,9 +6,9 @@ import numpy as np
 from transformers import BertTokenizerFast
 from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
 
-from src.utils import remove_punctuation
+from src.utils import input_preprocess, output_preprocess
 
-ACTIONS_KEYS = ["dot", "upper_case", "colon", "question_mark"]
+ACTIONS_KEYS = ["upper_case", "dot", "colon", "question_mark"]
 
 
 def apply_file_processing(x: dict) -> dict:
@@ -105,8 +105,8 @@ def action_vector(actions: List[str]) -> np.ndarray:
     """
     return encode_actions(
         {
-            "dot": "dot" in actions,
             "upper_case": "upper_case" in actions,
+            "dot": "dot" in actions,
             "colon": "colon" in actions,
             "question_mark": "question_mark" in actions,
         }
@@ -204,8 +204,8 @@ def detect_actions(word: str, next_word: Optional[str]) -> Mapping[str, bool]:
         return dict(zip(ACTIONS_KEYS, [False] * len(ACTIONS_KEYS)))
 
     actions = {
-        "dot": word[-1] == ".",
         "upper_case": word[0].isupper(),
+        "dot": word[-1] == ".",
         "colon": word[-1] == ",",
         "question_mark": word[-1] == "?",
     }
@@ -249,7 +249,7 @@ def create_model_input_output(text: str) -> Tuple[str, np.ndarray]:
         text_cleaned (str): Text without any punctuation and all lowercase
         actions (np.ndarray): Two-dimensional array where each row is the action vector of the corresponding word
     """
-    words = text.split(" ")
+    words = output_preprocess(text).split(" ")
 
     words_output = []
     actions_output = []
@@ -259,7 +259,7 @@ def create_model_input_output(text: str) -> Tuple[str, np.ndarray]:
         word = words[i]
         next_word = words[i + 1] if len(words) > i + 1 else None
 
-        word_sanitized = remove_punctuation(word).lower()
+        word_sanitized = input_preprocess(word)
         if len(word_sanitized) > 0:
             actions = detect_actions(word, next_word)
             actions_encoded = encode_actions(actions)
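
After the reorder, index 0 of every action vector is upper_case and the
remaining indices are punctuation classes, which is exactly the split the
restricted model relies on. A hypothetical round trip through the helpers
defined in this file:

    from src.pipelines.actions_based.processing import (
        detect_actions,
        encode_actions,
    )

    actions = detect_actions("Ala", next_word="ma")
    # {'upper_case': True, 'dot': False, 'colon': False, 'question_mark': False}

    vector = encode_actions(actions)
    assert vector[0] == 1  # upper_case now leads the vector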
diff --git a/src/pipelines/actions_based/train.py b/src/pipelines/actions_based/train.py
deleted file mode 100755
index 6973c53..0000000
--- a/src/pipelines/actions_based/train.py
+++ /dev/null
@@ -1,140 +0,0 @@
-#!/usr/bin/python3
-
-import glob
-import pickle
-from datetime import datetime
-from src.models.actions_model_base import ActionsModelBase, ActionsModelBaseLoss
-
-import dask.dataframe as dd
-import numpy as np
-import torch
-from torch.nn import BCEWithLogitsLoss
-from transformers import BertForTokenClassification, BertTokenizerFast
-
-from src.batch_loading import get_batches
-from src.pipelines.actions_based.processing import ACTIONS_KEYS
-from src.training import latest_model, save_training_step
-from src.utils import PROJECT_ROOT, convert_to_timedelta, get_config, prepare_folder
-
-INPUT_PATH = f"{PROJECT_ROOT}/generated/actions/stage4_reindexing"
-INPUT_STATS_PATH = f"{PROJECT_ROOT}/generated/actions/stage5_stats"
-OUTPUT_PATH = f"{PROJECT_ROOT}/checkpoints/actions"
-
-if __name__ == "__main__":
-    config = get_config()
-    learning_rate = config["actions"]["training"]["learning_rate"]
-    num_epochs = config["actions"]["training"]["num_epochs"]
-    batch_size = config["actions"]["training"]["batch_size"]
-    save_step = config["actions"]["training"]["save_step"]
-    loss_averaging_span = config["actions"]["training"]["loss_averaging_span"]
-    fresh_start = config["actions"]["training"]["fresh_start"]
-    device_name = config["actions"]["training"]["device"]
-    max_train_time = config["actions"]["training"]["max_training_time"]
-    base_model = config["global"]["base_model"]
-    seed = config["global"]["random_seed"]
-
-    prepare_folder(OUTPUT_PATH)
-    np.random.seed(seed=seed)
-
-    if max_train_time is not None:
-        max_train_time = convert_to_timedelta(max_train_time)
-
-    device = torch.device(device_name if torch.cuda.is_available() else "cpu")
-    print(f"Training on {device}")
-
-    # Load loss weights
-    with open(f"{INPUT_STATS_PATH}/stats.pickle", "rb") as f:
-        stats = pickle.load(f)
-        pos_examples = stats["class_number"]
-        neg_examples = stats["num_examples"] - stats["class_number"]
-        pos_weight = torch.tensor(neg_examples / pos_examples)
-
-    df = dd.read_parquet(INPUT_PATH, engine="pyarrow")
-    tokenizer = BertTokenizerFast.from_pretrained(base_model)
-
-    model = ActionsModelBase(base_model, len(ACTIONS_KEYS)).to(device)
-    criterion = ActionsModelBaseLoss(pos_weight).to(device)
-    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
-
-    epoch_start = 0
-    sample_start = 0
-    if fresh_start is False:
-        checkpoint_files = glob.glob(f"{OUTPUT_PATH}/*.model")
-        latest = latest_model(checkpoint_files)
-
-        if latest is not None:
-            epoch, batch = latest
-            model.load_state_dict(
-                torch.load(f"{OUTPUT_PATH}/{epoch}-{batch}.model", map_location=device,)
-            )
-            optimizer.load_state_dict(
-                torch.load(
-                    f"{OUTPUT_PATH}/{epoch}-{batch}.optimizer", map_location=device,
-                )
-            )
-
-            epoch_start, sample_start = epoch, batch
-            print(f"Loaded {epoch}-{batch}")
-
-    model.train()
-    losses = []
-
-    num_samples = df.tail(1).index.values[0] + 1
-    random_index_shuffle = np.random.permutation(range(num_samples))
-
-    training_stopped = False
-
-    time_max = datetime.max
-    if max_train_time is not None:
-        time_max = datetime.now() + max_train_time
-
-    for epoch in range(epoch_start, num_epochs):
-        if training_stopped:
-            break
-
-        i = sample_start
-        for data_batch in get_batches(df, batch_size, 100, random_index_shuffle, i):
-            inputs = data_batch.apply(
-                lambda x: x["source"].reshape(x["source_shape"]), axis=1
-            ).values
-            outputs = data_batch.apply(
-                lambda x: x["target"].reshape(x["target_shape"]), axis=1
-            ).values
-            attentions_mask = data_batch.apply(
-                lambda x: x["attention_mask"].reshape(x["attention_mask_shape"]),
-                axis=1,
-            ).values
-
-            inputs = torch.tensor(np.stack(inputs).squeeze()).to(device)
-            outputs = torch.tensor(np.stack(outputs)).to(device)
-            attentions_mask = torch.tensor(np.stack(attentions_mask)).to(device)
-
-            y_pred = model(input_ids=inputs, attention_mask=attentions_mask)
-
-            loss = criterion(y_pred, outputs)
-
-            losses.append(loss.item())
-            if len(losses) > loss_averaging_span:
-                losses = losses[-loss_averaging_span:]
-
-            print(f"epoch: {epoch} | step: {i} | loss: {np.mean(losses)}")
-
-            optimizer.zero_grad()
-
-            if i % save_step == 0 and (i != sample_start or epoch != epoch_start):
-                print(f"Saving: Epoch {epoch}, step {i}")
-                save_training_step(OUTPUT_PATH, f"{epoch}-{i}", model, optimizer)
-
-            if datetime.now() > time_max:
-                print(f"Max time reached, saving: Epoch {epoch}, step {i}")
-                save_training_step(OUTPUT_PATH, f"{epoch}-{i}", model, optimizer)
-                training_stopped = True
-                break
-
-            loss.backward()
-            optimizer.step()
-
-            i += 1
-
-    if not training_stopped:
-        save_training_step(OUTPUT_PATH, "final", model, optimizer)
diff --git a/src/pipelines/actions_based/train_base.py b/src/pipelines/actions_based/train_base.py
new file mode 100755
index 0000000..1e4c7f9
--- /dev/null
+++ b/src/pipelines/actions_based/train_base.py
@@ -0,0 +1,90 @@
+#!/usr/bin/python3
+
+import pickle
+from src.pipelines.train import TrainerBase
+from src.models.actions_model_base import ActionsModelBase, ActionsModelBaseLoss
+
+import numpy as np
+import torch
+import pandas as pd
+from transformers import BertTokenizerFast
+
+from src.pipelines.actions_based.processing import ACTIONS_KEYS
+from src.utils import PROJECT_ROOT, convert_to_timedelta, get_config
+
+INPUT_PATH = f"{PROJECT_ROOT}/generated/actions/stage4_reindexing"
+INPUT_STATS_PATH = f"{PROJECT_ROOT}/generated/actions/stage5_stats"
+OUTPUT_PATH = f"{PROJECT_ROOT}/checkpoints/actions_base"
+
+
+class TrainerActions(TrainerBase):
+    def __init__(self) -> None:
+
+        config = get_config()
+        learning_rate = config["actions"]["training_base"]["learning_rate"]
+        num_epochs = config["actions"]["training_base"]["num_epochs"]
+        batch_size = config["actions"]["training_base"]["batch_size"]
+        save_step = config["actions"]["training_base"]["save_step"]
+        batch_buffer_size = config["actions"]["training_base"]["batch_buffer_size"]
+        loss_averaging_span = config["actions"]["training_base"]["loss_averaging_span"]
+        fresh_start = config["actions"]["training_base"]["fresh_start"]
+        device_name = config["actions"]["training_base"]["device"]
+        max_train_time = config["actions"]["training_base"]["max_training_time"]
+        base_model = config["global"]["base_model"]
+        seed = config["global"]["random_seed"]
+
+        if max_train_time is not None:
+            max_train_time = convert_to_timedelta(max_train_time)
+
+        # Load loss weights
+        with open(f"{INPUT_STATS_PATH}/stats.pickle", "rb") as f:
+            stats = pickle.load(f)
+            pos_examples = stats["class_number"]
+            neg_examples = stats["num_examples"] - stats["class_number"]
+            pos_weight = torch.tensor(neg_examples / pos_examples)
+
+        np.random.seed(seed=seed)
+
+        device = torch.device(device_name if torch.cuda.is_available() else "cpu")
+
+        model = ActionsModelBase(base_model, len(ACTIONS_KEYS)).to(device)
+        self.criterion = ActionsModelBaseLoss(pos_weight).to(device)
+        optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
+
+        super(TrainerActions, self).__init__(
+            model,
+            device,
+            optimizer,
+            INPUT_PATH,
+            OUTPUT_PATH,
+            max_train_time,
+            fresh_start,
+            num_epochs,
+            batch_size,
+            batch_buffer_size,
+            loss_averaging_span,
+            save_step,
+        )
+
+    def calc_loss(self, data_batch: pd.DataFrame,) -> torch.Tensor:
+        inputs = data_batch.apply(
+            lambda x: x["source"].reshape(x["source_shape"]), axis=1
+        ).values
+        outputs = data_batch.apply(
+            lambda x: x["target"].reshape(x["target_shape"]), axis=1
+        ).values
+        attentions_mask = data_batch.apply(
+            lambda x: x["attention_mask"].reshape(x["attention_mask_shape"]), axis=1,
+        ).values
+
+        inputs = torch.tensor(np.stack(inputs).squeeze()).to(self.device)
+        outputs = torch.tensor(np.stack(outputs)).to(self.device)
+        attentions_mask = torch.tensor(np.stack(attentions_mask)).to(self.device)
+
+        y_pred = self.model(input_ids=inputs, attention_mask=attentions_mask)
+
+        return self.criterion(y_pred, outputs)
+
+
+if __name__ == "__main__":
+    TrainerActions().train()
\ No newline at end of file
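
The loop that used to live in each script (checkpoint resume, batch shuffling,
time-capped saving) now sits in TrainerBase in src/pipelines/train.py below;
a pipeline only builds its model, criterion and optimizer, then implements
calc_loss. A skeletal, hypothetical subclass showing that contract:

    import pandas as pd
    import torch

    from src.pipelines.train import TrainerBase

    class TrainerSketch(TrainerBase):
        def calc_loss(self, data_batch: pd.DataFrame) -> torch.Tensor:
            # A real trainer reshapes the batch columns into tensors and
            # applies self.model and self.criterion, as train_base.py does.
            return torch.tensor(0.0, requires_grad=True)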
diff --git a/src/pipelines/actions_based/train_mixed.py b/src/pipelines/actions_based/train_mixed.py
new file mode 100755
index 0000000..cafb383
--- /dev/null
+++ b/src/pipelines/actions_based/train_mixed.py
@@ -0,0 +1,116 @@
+#!/usr/bin/python3
+
+import pickle
+from src.models.actions_model_mixed import ActionsModelMixed, ActionsModelMixedLoss
+from src.pipelines.train import TrainerBase
+
+import numpy as np
+import torch
+import pandas as pd
+from transformers import BertTokenizerFast
+
+from src.pipelines.actions_based.processing import ACTIONS_KEYS
+from src.utils import PROJECT_ROOT, convert_to_timedelta, get_config
+
+INPUT_PATH = f"{PROJECT_ROOT}/generated/actions/stage4_reindexing"
+INPUT_STATS_PATH = f"{PROJECT_ROOT}/generated/actions/stage5_stats"
+OUTPUT_PATH = f"{PROJECT_ROOT}/checkpoints/actions_mixed"
+
+
+class TrainerActions(TrainerBase):
+    def __init__(self) -> None:
+
+        config = get_config()
+        embedding_size = config["actions"]["training_mixed"]["embedding_size"]
+        num_heads = config["actions"]["training_mixed"]["num_heads"]
+        num_layers = config["actions"]["training_mixed"]["num_layers"]
+        dropout = config["actions"]["training_mixed"]["dropout"]
+        feedforward_neurons = config["actions"]["training_mixed"]["feedforward_neurons"]
+        learning_rate = config["actions"]["training_mixed"]["learning_rate"]
+        num_epochs = config["actions"]["training_mixed"]["num_epochs"]
+        batch_size = config["actions"]["training_mixed"]["batch_size"]
+        save_step = config["actions"]["training_mixed"]["save_step"]
+        batch_buffer_size = config["actions"]["training_mixed"]["batch_buffer_size"]
+        loss_averaging_span = config["actions"]["training_mixed"]["loss_averaging_span"]
+        fresh_start = config["actions"]["training_mixed"]["fresh_start"]
+        device_name = config["actions"]["training_mixed"]["device"]
+        max_train_time = config["actions"]["training_mixed"]["max_training_time"]
+        base_model = config["global"]["base_model"]
+        seed = config["global"]["random_seed"]
+
+        if max_train_time is not None:
+            max_train_time = convert_to_timedelta(max_train_time)
+
+        # Load loss weights
+        with open(f"{INPUT_STATS_PATH}/stats.pickle", "rb") as f:
+            stats = pickle.load(f)
+            pos_examples = stats["class_number"]
+            neg_examples = stats["num_examples"] - stats["class_number"]
+            pos_weight = torch.tensor(neg_examples / pos_examples)
+
+        np.random.seed(seed=seed)
+
+        device = torch.device(device_name if torch.cuda.is_available() else "cpu")
+
+        tokenizer = BertTokenizerFast.from_pretrained(base_model)
+        model = ActionsModelMixed(
+            tokenizer.vocab_size,
+            embedding_size,
+            num_heads,
+            num_layers,
+            feedforward_neurons,
+            len(ACTIONS_KEYS),
+            500,
+            dropout,
+        ).to(device)
+        self.criterion = ActionsModelMixedLoss(pos_weight).to(device)
+        optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
+
+        super(TrainerActions, self).__init__(
+            model,
+            device,
+            optimizer,
+            INPUT_PATH,
+            OUTPUT_PATH,
+            max_train_time,
+            fresh_start,
+            num_epochs,
+            batch_size,
+            batch_buffer_size,
+            loss_averaging_span,
+            save_step,
+        )
+
+    def calc_loss(self, data_batch: pd.DataFrame,) -> torch.Tensor:
+        inputs = data_batch.apply(
+            lambda x: x["source"].reshape(x["source_shape"]), axis=1
+        ).values
+        outputs = data_batch.apply(
+            lambda x: x["target"].reshape(x["target_shape"]), axis=1
+        ).values
+        attentions_mask = data_batch.apply(
+            lambda x: x["attention_mask"].reshape(x["attention_mask_shape"]), axis=1,
+        ).values
+
+        inputs = torch.tensor(np.stack(inputs).squeeze(), dtype=torch.long).to(
+            self.device
+        )
+        outputs = torch.tensor(np.stack(outputs), dtype=torch.float).to(self.device)
+
+        # Convert to a boolean padding mask (True marks padding positions),
+        # as expected by the encoder's src_key_padding_mask
+        attentions_mask = torch.tensor(np.stack(attentions_mask) == 0).to(
+            self.device
+        )
+
+        y_pred = self.model(
+            input_ids=inputs,
+            actions=outputs[:, :-1, :],
+            attention_mask=attentions_mask,
+        )
+
+        return self.criterion(outputs[:, 1:, :], y_pred)
+
+
+if __name__ == "__main__":
+    TrainerActions().train()
\ No newline at end of file
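
Note the one-step shift in calc_loss above: the decoder is fed
outputs[:, :-1, :] and scored against outputs[:, 1:, :], i.e. it predicts the
action at position t from the actions up to t-1. In isolation:

    import torch

    outputs = torch.arange(24, dtype=torch.float).reshape(2, 4, 3)  # BxLxA

    decoder_input = outputs[:, :-1, :]  # actions the decoder sees
    target = outputs[:, 1:, :]          # actions it must predict

    assert decoder_input.shape == target.shape == (2, 3, 3)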
diff --git a/src/pipelines/actions_based/train_restricted.py b/src/pipelines/actions_based/train_restricted.py
index 8d862b0..048901e 100755
--- a/src/pipelines/actions_based/train_restricted.py
+++ b/src/pipelines/actions_based/train_restricted.py
@@ -1,155 +1,118 @@
 #!/usr/bin/python3
 
-import glob
 import pickle
-from datetime import datetime
-
-from torch import dtype
 from src.models.actions_model_restricted import (
     ActionsModelRestricted,
     ActionsModelRestrictedLoss,
 )
+from src.pipelines.train import TrainerBase
 
-import dask.dataframe as dd
 import numpy as np
 import torch
-from torch.nn import BCEWithLogitsLoss
-from transformers import BertForTokenClassification, BertTokenizerFast
+import pandas as pd
 
-from src.batch_loading import get_batches
 from src.pipelines.actions_based.processing import ACTIONS_KEYS
-from src.training import latest_model, save_training_step
-from src.utils import PROJECT_ROOT, convert_to_timedelta, get_config, prepare_folder
+from src.utils import PROJECT_ROOT, convert_to_timedelta, get_config
 
 INPUT_PATH = f"{PROJECT_ROOT}/generated/actions/stage4_reindexing"
 INPUT_STATS_PATH = f"{PROJECT_ROOT}/generated/actions/stage5_stats"
-OUTPUT_PATH = f"{PROJECT_ROOT}/checkpoints/actions"
-
-if __name__ == "__main__":
-    config = get_config()
-    learning_rate = config["actions"]["training"]["learning_rate"]
-    num_epochs = config["actions"]["training"]["num_epochs"]
-    batch_size = config["actions"]["training"]["batch_size"]
-    save_step = config["actions"]["training"]["save_step"]
-    loss_averaging_span = config["actions"]["training"]["loss_averaging_span"]
-    fresh_start = config["actions"]["training"]["fresh_start"]
-    device_name = config["actions"]["training"]["device"]
-    max_train_time = config["actions"]["training"]["max_training_time"]
-    base_model = config["global"]["base_model"]
-    seed = config["global"]["random_seed"]
-
-    prepare_folder(OUTPUT_PATH)
-    np.random.seed(seed=seed)
-
-    if max_train_time is not None:
-        max_train_time = convert_to_timedelta(max_train_time)
-
-    device = torch.device(device_name if torch.cuda.is_available() else "cpu")
-    print(f"Training on {device}")
-
-    # Load loss weights
-    with open(f"{INPUT_STATS_PATH}/stats.pickle", "rb") as f:
-        stats = pickle.load(f)
-        pos_examples = stats["class_number"]
-
-        probs = pos_examples / stats["num_examples"]
-
-        # Add no-punctuation prob
-        probs = np.concatenate([probs, np.array([1.0 - np.sum(probs[1:])])], -1)
-        assert probs[-1] >= 0.0 and probs[-1] <= 1.0
-
-        probs = torch.tensor(probs).to(device)
-
-    df = dd.read_parquet(INPUT_PATH, engine="pyarrow")
-    tokenizer = BertTokenizerFast.from_pretrained(base_model)
-
-    model = ActionsModelRestricted(base_model, len(ACTIONS_KEYS) + 1).to(device)
-    criterion = ActionsModelRestrictedLoss(probs).to(device)
-    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
-
-    epoch_start = 0
-    sample_start = 0
-    if fresh_start is False:
-        checkpoint_files = glob.glob(f"{OUTPUT_PATH}/*.model")
-        latest = latest_model(checkpoint_files)
-
-        if latest is not None:
-            epoch, batch = latest
-            model.load_state_dict(
-                torch.load(f"{OUTPUT_PATH}/{epoch}-{batch}.model", map_location=device,)
+OUTPUT_PATH = f"{PROJECT_ROOT}/checkpoints/actions_restricted"
+
+
+class TrainerActions(TrainerBase):
+    def __init__(self) -> None:
+
+        config = get_config()
+        learning_rate = config["actions"]["training_restricted"]["learning_rate"]
+        num_epochs = config["actions"]["training_restricted"]["num_epochs"]
+        batch_size = config["actions"]["training_restricted"]["batch_size"]
+        save_step = config["actions"]["training_restricted"]["save_step"]
+        batch_buffer_size = config["actions"]["training_restricted"][
+            "batch_buffer_size"
+        ]
+        loss_averaging_span = config["actions"]["training_restricted"][
+            "loss_averaging_span"
+        ]
+        fresh_start = config["actions"]["training_restricted"]["fresh_start"]
+        device_name = config["actions"]["training_restricted"]["device"]
+        max_train_time = config["actions"]["training_restricted"]["max_training_time"]
+        base_model = config["global"]["base_model"]
+        seed = config["global"]["random_seed"]
+
+        if max_train_time is not None:
+            max_train_time = convert_to_timedelta(max_train_time)
+
+        # Load loss weights, extending the counts with a derived
+        # no-punctuation class below
+        with open(f"{INPUT_STATS_PATH}/stats.pickle", "rb") as f:
+            stats = pickle.load(f)
+            pos_examples = stats["class_number"]
+            neg_examples = stats["num_examples"] - stats["class_number"]
+
+            no_punctuation_pos_examples = np.sum(neg_examples[1:])
+            no_punctuation_neg_examples = np.sum(pos_examples[1:])
+
+            pos_examples = np.concatenate(
+                [pos_examples, no_punctuation_pos_examples.reshape(1)], -1
             )
-            optimizer.load_state_dict(
-                torch.load(
-                    f"{OUTPUT_PATH}/{epoch}-{batch}.optimizer", map_location=device,
-                )
+            neg_examples = np.concatenate(
+                [neg_examples, no_punctuation_neg_examples.reshape(1)], -1
             )
 
-            epoch_start, sample_start = epoch, batch
-            print(f"Loaded {epoch}-{batch}")
-
-    model.train()
-    losses = []
-
-    num_samples = df.tail(1).index.values[0] + 1
-    random_index_shuffle = np.random.permutation(range(num_samples))
-
-    training_stopped = False
+            pos_weight = torch.tensor(neg_examples / pos_examples)
 
-    time_max = datetime.max
-    if max_train_time is not None:
-        time_max = datetime.now() + max_train_time
+        np.random.seed(seed=seed)
 
-    for epoch in range(epoch_start, num_epochs):
-        if training_stopped:
-            break
+        device = torch.device(device_name if torch.cuda.is_available() else "cpu")
 
-        i = sample_start
-        for data_batch in get_batches(df, batch_size, 100, random_index_shuffle, i):
-            inputs = data_batch.apply(
-                lambda x: x["source"].reshape(x["source_shape"]), axis=1
-            ).values
-            outputs = data_batch.apply(
-                lambda x: x["target"].reshape(x["target_shape"]), axis=1
-            ).values
-            attentions_mask = data_batch.apply(
-                lambda x: x["attention_mask"].reshape(x["attention_mask_shape"]),
-                axis=1,
-            ).values
+        model = ActionsModelRestricted(base_model, len(ACTIONS_KEYS) + 1).to(device)
+        self.criterion = ActionsModelRestrictedLoss(pos_weight).to(device)
+        optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
 
-            inputs = torch.tensor(np.stack(inputs).squeeze()).to(device)
-            outputs = torch.tensor(np.stack(outputs)).to(device)
-            attentions_mask = torch.tensor(np.stack(attentions_mask)).to(device)
+        super(TrainerActions, self).__init__(
+            model,
+            device,
+            optimizer,
+            INPUT_PATH,
+            OUTPUT_PATH,
+            max_train_time,
+            fresh_start,
+            num_epochs,
+            batch_size,
+            batch_buffer_size,
+            loss_averaging_span,
+            save_step,
+        )
 
-            y_pred = model(input_ids=inputs, attention_mask=attentions_mask)
+    def calc_loss(self, data_batch: pd.DataFrame,) -> torch.Tensor:
+        inputs = data_batch.apply(
+            lambda x: x["source"].reshape(x["source_shape"]), axis=1
+        ).values
+        outputs = data_batch.apply(
+            lambda x: x["target"].reshape(x["target_shape"]), axis=1
+        ).values
+        attentions_mask = data_batch.apply(
+            lambda x: x["attention_mask"].reshape(x["attention_mask_shape"]), axis=1,
+        ).values
 
-            outputs = torch.cat(
-                [outputs, (1.0 - outputs[:, :, 1:].max(-1)[0]).unsqueeze(-1)], axis=-1
-            )
-
-            loss = criterion(y_pred, outputs)
-
-            losses.append(loss.item())
-            if len(losses) > loss_averaging_span:
-                losses = losses[-loss_averaging_span:]
+        inputs = torch.tensor(np.stack(inputs).squeeze()).to(self.device)
+        outputs = torch.tensor(np.stack(outputs)).to(self.device)
+        attentions_mask = torch.tensor(np.stack(attentions_mask)).to(self.device)
 
-            print(f"epoch: {epoch} | step: {i} | loss: {np.mean(losses)}")
+        y_pred = self.model(input_ids=inputs, attention_mask=attentions_mask)
 
-            optimizer.zero_grad()
+        outputs = torch.cat(
+            [outputs, (1.0 - outputs[:, :, 1:].max(-1)[0]).unsqueeze(-1)], axis=-1
+        )
 
-            if i % save_step == 0 and (i != sample_start or epoch != epoch_start):
-                print(f"Saving: Epoch {epoch}, step {i}")
-                save_training_step(OUTPUT_PATH, f"{epoch}-{i}", model, optimizer)
+        return self.criterion(outputs, y_pred)
 
-            if datetime.now() > time_max:
-                print(f"Max time reached, saving: Epoch {epoch}, step {i}")
-                save_training_step(OUTPUT_PATH, f"{epoch}-{i}", model, optimizer)
-                training_stopped = True
-                break
 
-            loss.backward()
-            optimizer.step()
-
-            i += 1
-
-    if not training_stopped:
-        save_training_step(OUTPUT_PATH, "final", model, optimizer)
+if __name__ == "__main__":
+    TrainerActions().train()
\ No newline at end of file
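
The weighting logic above turns dataset statistics into BCEWithLogitsLoss pos_weight values and appends a synthetic "no punctuation" class. A minimal numeric sketch of that computation, with made-up class counts:

    import numpy as np
    import torch

    # Hypothetical per-class counts: [upper_case, dot, colon, question_mark]
    pos_examples = np.array([2000.0, 900.0, 50.0, 30.0])
    num_examples = 10000.0
    neg_examples = num_examples - pos_examples

    # Append the synthetic "no punctuation" class, mirroring the code above.
    no_punct_pos = np.sum(neg_examples[1:])
    no_punct_neg = np.sum(pos_examples[1:])
    pos_examples = np.concatenate([pos_examples, no_punct_pos.reshape(1)], -1)
    neg_examples = np.concatenate([neg_examples, no_punct_neg.reshape(1)], -1)

    pos_weight = torch.tensor(neg_examples / pos_examples)
    print(pos_weight)  # rare classes (colon, question_mark) get the largest weights
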
diff --git a/src/pipelines/train.py b/src/pipelines/train.py
new file mode 100644
index 0000000..8e988cc
--- /dev/null
+++ b/src/pipelines/train.py
@@ -0,0 +1,176 @@
+#!/usr/bin/python3
+
+import glob
+from datetime import datetime, timedelta
+from typing import Optional
+
+from torch.optim.optimizer import Optimizer
+
+import dask.dataframe as dd
+import numpy as np
+import torch
+import torch.nn as nn
+import pandas as pd
+
+from src.batch_loading import get_batches
+from src.training import latest_model, save_training_step
+from src.utils import convert_to_timedelta, prepare_folder
+from abc import ABC, abstractmethod
+
+
+class TrainerBase(ABC):
+    """[summary]
+
+    Args:
+        ABC ([type]): [description]
+    """
+
+    def __init__(
+        self,
+        model: nn.Module,
+        device: torch.device,
+        optimizer: Optimizer,
+        input_path: str,
+        output_path: str,
+        max_train_time: Optional[timedelta],
+        fresh_start: bool,
+        num_epochs: int,
+        batch_size: int,
+        batch_buffer_size: int = 100,
+        loss_averaging_span: int = 1000,
+        save_step: int = 1000,
+    ) -> None:
+        """Initializes base trainer
+
+        Args:
+            model (nn.Module): Model that will be trained
+            device (torch.device): Device on which model will be loaded & trained
+            optimizer (Optimizer): Optimizer used for gradient descent
+            input_path (str): Path to parquet folder with input dataset
+            output_path (str): Path where model checkpoints will be stored
+            max_train_time (Optional[timedelta]): Maximum training time
+            fresh_start (bool): If set to true, the last checkpoint will not be loaded and training will start from scratch
+            num_epochs (int): Number of epochs to train
+            batch_size (int): Batch size to use
+            batch_buffer_size (int, optional): How many batches to load into RAM at once. Defaults to 100.
+            loss_averaging_span (int, optional): How many losses to average over in logs. Defaults to 1000.
+            save_step (int, optional): Interval (in steps) at which model checkpoints are saved. Defaults to 1000.
+        """
+
+        self.model = model
+        self.device = device
+        self.optimizer = optimizer
+
+        self.input_path = input_path
+        self.output_path = output_path
+        self.fresh_start = fresh_start
+        self.num_epochs = num_epochs
+        self.batch_size = batch_size
+        self.batch_buffer_size = batch_buffer_size
+        self.loss_averaging_span = loss_averaging_span
+        self.save_step = save_step
+
+        self.max_train_time = max_train_time
+        if self.max_train_time is not None:
+            self.max_train_time = convert_to_timedelta(self.max_train_time)
+
+    @abstractmethod
+    def calc_loss(self, data_batch: pd.DataFrame) -> torch.Tensor:
+        """User-provided function that will return loss on which backprob and
+        optimization will be made
+
+        Args:
+            data_batch (pd.DataFrame): Pandas dataframe with a single batch of data
+
+        Returns:
+            torch.Tensor: Loss tensor
+        """
+        pass
+
+    def _load_model(self):
+        checkpoint_files = glob.glob(f"{self.output_path}/*.model")
+        latest = latest_model(checkpoint_files)
+
+        if latest is not None:
+            epoch, batch = latest
+            self.model.load_state_dict(
+                torch.load(
+                    f"{self.output_path}/{epoch}-{batch}.model",
+                    map_location=self.device,
+                )
+            )
+            self.optimizer.load_state_dict(
+                torch.load(
+                    f"{self.output_path}/{epoch}-{batch}.optimizer",
+                    map_location=self.device,
+                )
+            )
+
+            return epoch, batch
+
+    def train(self):
+        """Preforms full training of the model"""
+        prepare_folder(self.output_path)
+        print(f"Training on {self.device}")
+
+        df = dd.read_parquet(self.input_path, engine="pyarrow")
+
+        epoch_start = 0
+        sample_start = 0
+        if self.fresh_start is False:
+            epoch_start, sample_start = self._load_model()
+            print(f"Loaded {epoch_start}-{sample_start}")
+
+        self.model.train()
+        losses = []
+
+        num_samples = df.tail(1).index.values[0] + 1
+        random_index_shuffle = np.random.permutation(range(num_samples))
+
+        training_stopped = False
+
+        time_max = datetime.max
+        if self.max_train_time is not None:
+            time_max = datetime.now() + self.max_train_time
+
+        for epoch in range(epoch_start, self.num_epochs):
+            if training_stopped:
+                break
+
+            i = sample_start
+            for data_batch in get_batches(
+                df, self.batch_size, self.batch_buffer_size, random_index_shuffle, i
+            ):
+                loss = self.calc_loss(data_batch)
+
+                losses.append(loss.item())
+                if len(losses) > self.loss_averaging_span:
+                    losses = losses[-self.loss_averaging_span :]
+
+                print(f"epoch: {epoch} | step: {i} | loss: {np.mean(losses)}")
+
+                self.optimizer.zero_grad()
+
+                if i % self.save_step == 0 and (
+                    i != sample_start or epoch != epoch_start
+                ):
+                    print(f"Saving: Epoch {epoch}, step {i}")
+                    save_training_step(
+                        self.output_path, f"{epoch}-{i}", self.model, self.optimizer
+                    )
+
+                if datetime.now() > time_max:
+                    print(f"Max time reached, saving: Epoch {epoch}, step {i}")
+                    save_training_step(
+                        self.output_path, f"{epoch}-{i}", self.model, self.optimizer
+                    )
+                    training_stopped = True
+                    break
+
+                loss.backward()
+                self.optimizer.step()
+
+                i += 1
+
+        if not training_stopped:
+            save_training_step(self.output_path, "final", self.model, self.optimizer)
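
TrainerBase factors the loop mechanics (shuffled buffered batching, loss averaging, periodic and final checkpoints, the optional wall-clock cutoff) away from model specifics; a concrete trainer only wires up a model, device and optimizer and implements calc_loss, as TrainerActions does above. A minimal sketch of such a subclass; the model and paths are placeholders, while the source/target columns follow the dataset layout used above:

    import numpy as np
    import pandas as pd
    import torch
    import torch.nn as nn

    from src.pipelines.train import TrainerBase


    class ToyTrainer(TrainerBase):
        def __init__(self) -> None:
            device = torch.device("cpu")
            model = nn.Linear(10, 1).to(device)  # stand-in model
            optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
            self.criterion = nn.MSELoss()
            super().__init__(
                model,
                device,
                optimizer,
                "generated/toy/input",  # hypothetical parquet folder
                "checkpoints/toy",      # hypothetical checkpoint folder
                max_train_time=None,
                fresh_start=True,
                num_epochs=1,
                batch_size=8,
            )

        def calc_loss(self, data_batch: pd.DataFrame) -> torch.Tensor:
            x = torch.tensor(np.stack(data_batch["source"].values), dtype=torch.float)
            y = torch.tensor(np.stack(data_batch["target"].values), dtype=torch.float)
            return self.criterion(self.model(x.to(self.device)), y.to(self.device))


    if __name__ == "__main__":
        ToyTrainer().train()
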
diff --git a/src/pipelines/translation_based/processing.py b/src/pipelines/translation_based/processing.py
index 41962da..b139bce 100644
--- a/src/pipelines/translation_based/processing.py
+++ b/src/pipelines/translation_based/processing.py
@@ -1,9 +1,10 @@
+from src.utils import input_preprocess
 from typing import Tuple
 
 import numpy as np
 from transformers import BertTokenizerFast
 
-from src.pipelines.actions_based.processing import remove_punctuation, text_from_xml
+from src.pipelines.actions_based.processing import text_from_xml
 
 
 def raw_to_dataframe(entry: dict) -> dict:
@@ -241,7 +242,7 @@ def create_input_output(
         np.ndarray: Single sample that will serve as expected output from the model
     """
     decoded_str = tokenizer.decode(tokens)
-    cleaned_str = remove_punctuation(decoded_str).lower()
+    cleaned_str = input_preprocess(decoded_str).lower()
     source_batch_entry = tokenizer(cleaned_str)["input_ids"][1:-1]
     target_batch_entry = tokens
 
diff --git a/src/utils.py b/src/utils.py
index 906de65..06bbaf4 100644
--- a/src/utils.py
+++ b/src/utils.py
@@ -2,7 +2,7 @@ import os
 import re
 import shutil
 from datetime import timedelta
-from typing import Optional
+from typing import List, Optional
 
 import yaml
 
@@ -34,7 +34,7 @@ def remove_multiple_spaces(text: str) -> str:
     return re.sub(r"\s\s+", " ", text)
 
 
-def remove_punctuation(text: str) -> str:
+def remove_punctuation(text: str, whitelist: List[str] = []) -> str:
     """Removes all non-alphanumeric characters from the text.
     Might result in multiple spaces when characters like `-`
     are removed
@@ -46,13 +46,32 @@ def remove_punctuation(text: str) -> str:
         str: Text with all punctuation removed
     """
 
-    # Separating characters
+    return "".join(filter(lambda x: x.isalnum() or x.isspace() or x in whitelist, text))
+
+
+def output_preprocess(text: str) -> str:
+    """Cleans the text out of bad formating and removes or replaces symbols that will not be predicted by a model
+
+    Args:
+        text (str): Arbitrary text
+
+    Returns:
+        str: Text that could be a direct output of punctuation prediction algorithm
+    """
+    # Whitespace-like characters
     text = text.replace("-", " ").replace("/", " ").replace("+", " ")
 
-    return "".join(filter(lambda x: x.isalnum() or x.isspace(), text))
+    # Punctuation-like characters
+    text = text.replace(";", ".").replace("!", ".")
+
+    text = remove_punctuation(text, [".", ",", "?"])
+    text = remove_multiple_spaces(text)
+    text = text.strip()
+
+    return text
 
 
-def preprocess(text: str) -> str:
+def input_preprocess(text: str) -> str:
     """Makes sure that input is in the same format as training data (no non-alphanum chars, no double spaces,
         all lowercase etc.)
 
diff --git a/tests/pipelines/actions_based/test_processing.py b/tests/pipelines/actions_based/test_processing.py
index c626ff2..8e4caaa 100644
--- a/tests/pipelines/actions_based/test_processing.py
+++ b/tests/pipelines/actions_based/test_processing.py
@@ -24,24 +24,24 @@ from src.pipelines.actions_based.processing import (
 def test_detect_actions():
     actions = detect_actions("Janek.", None)
     assert actions == {
-        "dot": True,
         "upper_case": True,
+        "dot": True,
         "colon": False,
         "question_mark": False,
     }
 
     actions = detect_actions("ewka?", None)
     assert actions == {
-        "dot": False,
         "upper_case": False,
+        "dot": False,
         "colon": False,
         "question_mark": True,
     }
 
     actions = detect_actions("Test", None)
     assert actions == {
-        "dot": False,
         "upper_case": True,
+        "dot": False,
         "colon": False,
         "question_mark": False,
     }
@@ -49,21 +49,21 @@ def test_detect_actions():
 
 def test_encode_actions():
     x = {
-        "dot": True,
         "upper_case": False,
+        "dot": True,
         "colon": False,
         "question_mark": True,
     }
 
-    assert np.all(encode_actions(x) == np.array([1, 0, 0, 1]))
+    assert np.all(encode_actions(x) == np.array([0, 1, 0, 1]))
 
 
 def test_decode_actions():
-    x = np.array([1, 0, 0, 1])
+    x = np.array([0, 1, 0, 1])
 
     assert decode_actions(x) == {
-        "dot": True,
         "upper_case": False,
+        "dot": True,
         "colon": False,
         "question_mark": True,
     }
@@ -136,8 +136,8 @@ def test_nearest_sentence_l():
 def create_dummy_action(end_sentence: bool) -> np.array:
     return encode_actions(
         {
-            "dot": end_sentence,
             "upper_case": False,
+            "dot": end_sentence,
             "colon": False,
             "question_mark": False,
         }
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 9887354..0702891 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -1,4 +1,10 @@
-from src.utils import convert_to_timedelta, preprocess, remove_multiple_spaces, remove_punctuation
+from src.utils import (
+    convert_to_timedelta,
+    input_preprocess,
+    output_preprocess,
+    remove_multiple_spaces,
+    remove_punctuation,
+)
 
 
 def test_remove_multiple_spaces():
@@ -10,16 +16,28 @@ def test_remove_multiple_spaces():
 
 def test_remove_punctuation():
     provided = "Ala..  ma-Kota!?.@@$ Kot ma Ale ()*"
-    expected = "Ala  ma Kota Kot ma Ale "
+    expected = "Ala  maKota Kot ma Ale "
 
     assert remove_punctuation(provided) == expected
 
+    whitelist = [".", "?"]
+    expected_whitelist = "Ala..  maKota?. Kot ma Ale "
 
-def test_preprocess():
+    assert remove_punctuation(provided, whitelist) == expected_whitelist
+
+
+def test_input_preprocess():
+    provided = "Ala  ma-Kota!?.@@$ Kot ma Ale ()*"
+    expected = "ala makota kot ma ale"
+
+    assert input_preprocess(provided) == expected
+
+
+def test_output_preprocess():
     provided = "Ala  ma-Kota!?.@@$ Kot ma Ale ()*"
-    expected = "ala ma kota kot ma ale"
+    expected = "Ala ma Kota.?. Kot ma Ale"
 
-    assert preprocess(provided) == expected
+    assert output_preprocess(provided) == expected
 
 
 def test_convert_to_timedelta():
-- 
GitLab
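
The utils split above gives the two preprocessors distinct roles: input_preprocess normalises text into the model's input space (lowercase, punctuation stripped), while output_preprocess canonicalises text into something the model could emit (keeps ".", "," and "?", maps ";" and "!" to ".", preserves case). A quick illustration, with expected values taken verbatim from the tests above:

    from src.utils import input_preprocess, output_preprocess

    text = "Ala  ma-Kota!?.@@$ Kot ma Ale ()*"

    print(input_preprocess(text))   # ala makota kot ma ale
    print(output_preprocess(text))  # Ala ma Kota.?. Kot ma Ale
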


From e2e5586d9353f7612a99fe4bfac27991f46e52f2 Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Tue, 18 Aug 2020 09:19:36 +0200
Subject: [PATCH 061/116] Hotfixes

---
 Dockerfile                           | 5 ++---
 config.ini                           | 6 +++---
 entrypoint.sh                        | 4 ++--
 requirements.txt                     | 3 ++-
 src/pipelines/actions_based/utils.py | 2 +-
 worker.py                            | 3 +--
 6 files changed, 11 insertions(+), 12 deletions(-)

diff --git a/Dockerfile b/Dockerfile
index 2b3b6a3..ce5211b 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -10,7 +10,6 @@ RUN pip3 install -r requirements.txt && rm requirements.txt
 COPY src ./src
 COPY config.ini .
 COPY worker.py .
-  
-RUN pip3 freeze
+COPY entrypoint.sh .
 
-ENTRYPOINT [ "./worker.py" ]
\ No newline at end of file
+ENTRYPOINT [ "./entrypoint.sh" ]
\ No newline at end of file
diff --git a/config.ini b/config.ini
index 9e58547..93819e5 100644
--- a/config.ini
+++ b/config.ini
@@ -14,8 +14,8 @@ port = 9981
 local_log_level = INFO
 
 [deployment]
-device = "cpu"
+device = cpu
 chunk_size = 500
 threshold = 0.9
-model = "deploy/model"
-base_model = "dkleczek/bert-base-polish-cased-v1"
\ No newline at end of file
+model = deploy/model
+base_model = dkleczek/bert-base-polish-cased-v1
\ No newline at end of file
diff --git a/entrypoint.sh b/entrypoint.sh
index e548dca..a6e06ed 100755
--- a/entrypoint.sh
+++ b/entrypoint.sh
@@ -1,8 +1,8 @@
 #!/bin/bash
 
-if test -f "./deploy/model"; then
+if ! test -f "./deploy/model"; then
     mkdir -p ./deploy
     wget https://minio.clarin-pl.eu/public/models/punctuation/0-190000.model -O deploy/model
 fi
 
-python3 worker.py
\ No newline at end of file
+python worker.py
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index 17154b9..4e2d19a 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -58,4 +58,5 @@ transformers==3.0.2
 typing-extensions==3.7.4.2
 unattended-upgrades==0.1
 urllib3==1.25.10
-zict==2.0.0
\ No newline at end of file
+zict==2.0.0
+git+https://gitlab.clarin-pl.eu/nlpworkers/nlp_ws.git@fa5f09a2f1447cac2c411c9d9e3d927ecd815ddc#egg=nlp_ws
\ No newline at end of file
diff --git a/src/pipelines/actions_based/utils.py b/src/pipelines/actions_based/utils.py
index 4f62b61..a8728e1 100644
--- a/src/pipelines/actions_based/utils.py
+++ b/src/pipelines/actions_based/utils.py
@@ -6,12 +6,12 @@ import torch.nn as nn
 from transformers import BertForTokenClassification, BertTokenizerFast, PretrainedConfig
 
 from src.pipelines.actions_based.processing import (
+    ACTIONS_KEYS,
     action_vector,
     last_stop_label,
     recover_text,
     token_labels_to_word_labels,
 )
-from src.processing import ACTIONS_KEYS
 
 
 def load_model(
diff --git a/worker.py b/worker.py
index 5bf6e0c..2d91eb5 100755
--- a/worker.py
+++ b/worker.py
@@ -4,8 +4,7 @@ import configparser
 
 import nlp_ws
 
-from src.pipelines.actions_based.processing import apply_actions_punctuation
-from src.pipelines.actions_based.utils import load_model
+from src.pipelines.actions_based.utils import apply_actions_punctuation, load_model
 from src.utils import preprocess
 
 
-- 
GitLab
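
The config.ini change in this hotfix matters because configparser performs no unquoting: a value written as "cpu" is read back with the quotes included, so it never compares equal to cpu and torch.device() rejects it. A short sketch of the behaviour:

    import configparser

    config = configparser.ConfigParser()
    config.read_string('[deployment]\ndevice = "cpu"\n')

    value = config["deployment"]["device"]
    print(value)           # "cpu"  -- the surrounding quotes are part of the value
    print(value == "cpu")  # False -- which is why torch.device(value) would break
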


From c6bf26dc880b48c181bb904c8c04d8273603a7ed Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Tue, 18 Aug 2020 09:37:32 +0200
Subject: [PATCH 062/116] Rollback tox - newest one is unsupported in CI

---
 .gitlab-ci.yml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
index 4d60851..a858b46 100644
--- a/.gitlab-ci.yml
+++ b/.gitlab-ci.yml
@@ -10,7 +10,7 @@ stages:
   - build
 
 before_script:
-  - pip install tox==3.19.0
+  - pip install tox==3.18.1
 
 pep8:
   stage: check_style
-- 
GitLab


From 2680ca5805572531814a46e1b1031aa72fa6113b Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Tue, 18 Aug 2020 10:58:05 +0200
Subject: [PATCH 063/116] Added training dockerfile and script

---
 docker/training/Dockerfile       | 8 ++++++++
 docker/training/requirements.txt | 1 +
 train.sh                         | 6 ++++++
 3 files changed, 15 insertions(+)
 create mode 100644 docker/training/Dockerfile
 create mode 120000 docker/training/requirements.txt
 create mode 100755 train.sh

diff --git a/docker/training/Dockerfile b/docker/training/Dockerfile
new file mode 100644
index 0000000..de4f169
--- /dev/null
+++ b/docker/training/Dockerfile
@@ -0,0 +1,8 @@
+FROM clarinpl/cuda-python:3.7
+
+RUN DEBIAN_FRONTEND=noninteractive apt-get update && apt-get install -y gcc python3-dev
+RUN mkdir /punctuator
+WORKDIR /punctuator
+
+COPY requirements.txt requirements.txt
+RUN pip3 install -r requirements.txt && rm requirements.txt
\ No newline at end of file
diff --git a/docker/training/requirements.txt b/docker/training/requirements.txt
new file mode 120000
index 0000000..fd1efae
--- /dev/null
+++ b/docker/training/requirements.txt
@@ -0,0 +1 @@
+../../requirements.txt
\ No newline at end of file
diff --git a/train.sh b/train.sh
new file mode 100755
index 0000000..a217f60
--- /dev/null
+++ b/train.sh
@@ -0,0 +1,6 @@
+#!/bin/bash
+
+DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
+
+docker build . -f ./docker/training/Dockerfile -t clarinpl/punctuator_training
+docker run -v $DIR:/punctuator --gpus all -it -d --entrypoint python clarinpl/punctuator_training -m $1
\ No newline at end of file
-- 
GitLab


From f3dbe1960d6703551f748714a69866388a0a7263 Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Tue, 18 Aug 2020 11:14:23 +0200
Subject: [PATCH 064/116] Added description to train script

---
 train.sh | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/train.sh b/train.sh
index a217f60..a3c6a13 100755
--- a/train.sh
+++ b/train.sh
@@ -1,6 +1,9 @@
 #!/bin/bash
 
+# Usage: ./train.sh [module_to_run] [container_name]
+# E.g.: ./train.sh src.pipelines.actions_based.train_base base_training
+
 DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
 
 docker build . -f ./docker/training/Dockerfile -t clarinpl/punctuator_training
-docker run -v $DIR:/punctuator --gpus all -it -d --entrypoint python clarinpl/punctuator_training -m $1
\ No newline at end of file
+docker run -v $DIR:/punctuator --name $2 --gpus all -it -d --entrypoint python clarinpl/punctuator_training -m $1
\ No newline at end of file
-- 
GitLab


From 0017f5348ecc624832201fab62ccc2d351e802d5 Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Tue, 18 Aug 2020 14:08:37 +0200
Subject: [PATCH 065/116] Import order refactoring

---
 config.ini                                    |  4 ++--
 docker/development/Dockerfile                 | 23 ++++++++++++++++++-
 src/models/TransformerSeq2Seq.py              |  3 +--
 src/models/actions_model_base.py              |  4 +++-
 src/models/actions_model_mixed.py             |  6 +++--
 src/models/actions_model_restricted.py        |  4 +++-
 .../actions_based/stage1_extraction.py        |  5 +---
 .../actions_based/stage2_tokenization.py      |  5 +---
 .../actions_based/stage3_exploding.py         |  7 +-----
 src/pipelines/actions_based/train_base.py     |  8 +++----
 src/pipelines/actions_based/train_mixed.py    |  8 +++----
 .../actions_based/train_restricted.py         | 11 ++++-----
 src/pipelines/train.py                        |  7 +++---
 src/pipelines/translation_based/processing.py |  2 +-
 .../translation_based/stage1_extraction.py    |  5 +---
 .../stage2_create_batches.py                  |  5 +---
 tests/models/test_actions_model_base.py       |  3 ++-
 tests/models/test_actions_model_restricted.py |  8 +++----
 worker.py                                     |  8 +++----
 19 files changed, 65 insertions(+), 61 deletions(-)

diff --git a/config.ini b/config.ini
index 93819e5..9e02bf5 100644
--- a/config.ini
+++ b/config.ini
@@ -1,8 +1,8 @@
 [service]
-tool = Punctuator
+tool = punctuator_test
 
 root = /samba/requests/
-rabbit_host = addr
+rabbit_host = test
 rabbit_user = test
 rabbit_password = test
 
diff --git a/docker/development/Dockerfile b/docker/development/Dockerfile
index 1535758..570c2be 100644
--- a/docker/development/Dockerfile
+++ b/docker/development/Dockerfile
@@ -38,4 +38,25 @@ ENV NVIDIA_REQUIRE_CUDA "cuda>=10.2 brand=tesla,driver>=384,driver<385 brand=tes
 ### END CUDA Installation
 
 RUN pip3 install numpy pandas tqdm seaborn torch dask[complete] transformers pyarrow==0.17.1 pytest lxml
-RUN ln -s /usr/bin/pip3 /usr/bin/pip
\ No newline at end of file
+RUN ln -s /usr/bin/pip3 /usr/bin/pip
+
+ARG USERNAME=mpogoda
+ARG USER_UID=1000
+ARG USER_GID=$USER_UID
+
+# Create the user
+RUN groupadd --gid $USER_GID $USERNAME \
+    && useradd --uid $USER_UID --gid $USER_GID -m $USERNAME \
+    #
+    # [Optional] Add sudo support. Omit if you don't need to install software after connecting.
+    && apt-get update \
+    && apt-get install -y sudo \
+    && echo $USERNAME ALL=\(root\) NOPASSWD:ALL > /etc/sudoers.d/$USERNAME \
+    && chmod 0440 /etc/sudoers.d/$USERNAME
+
+# ********************************************************
+# * Anything else you want to do like clean up goes here *
+# ********************************************************
+
+# [Optional] Set the default user. Omit if you want to keep the default as root.
+USER $USERNAME
\ No newline at end of file
diff --git a/src/models/TransformerSeq2Seq.py b/src/models/TransformerSeq2Seq.py
index 057a124..696c78a 100644
--- a/src/models/TransformerSeq2Seq.py
+++ b/src/models/TransformerSeq2Seq.py
@@ -3,5 +3,4 @@ import math
 import torch
 import torch.nn as nn
 
-from src.models.common import PositionalEncoding
-from src.models.common import TransformerSeq2Seq
+from src.models.common import PositionalEncoding, TransformerSeq2Seq
diff --git a/src/models/actions_model_base.py b/src/models/actions_model_base.py
index 4fd334b..eb24a22 100644
--- a/src/models/actions_model_base.py
+++ b/src/models/actions_model_base.py
@@ -1,10 +1,12 @@
 from typing import Callable, Tuple
+
+import torch
 import torch.nn as nn
 from torch.nn.modules.loss import BCEWithLogitsLoss
 from transformers.configuration_utils import PretrainedConfig
 from transformers.modeling_bert import BertForTokenClassification
+
 from src.pipelines.actions_based.processing import ACTIONS_KEYS
-import torch
 
 
 class ActionsModelBase(nn.Module):
diff --git a/src/models/actions_model_mixed.py b/src/models/actions_model_mixed.py
index acd7f47..c9df6f7 100644
--- a/src/models/actions_model_mixed.py
+++ b/src/models/actions_model_mixed.py
@@ -1,11 +1,13 @@
 from typing import Callable, Tuple
+
+import torch
 import torch.nn as nn
 from torch.nn.modules.loss import BCEWithLogitsLoss
 from transformers.configuration_utils import PretrainedConfig
 from transformers.modeling_bert import BertForTokenClassification
-from src.pipelines.actions_based.processing import ACTIONS_KEYS
-import torch
+
 from src.models.common import PositionalEncoding, generate_square_subsequent_mask
+from src.pipelines.actions_based.processing import ACTIONS_KEYS
 
 
 class ActionsModelMixed(nn.Module):
diff --git a/src/models/actions_model_restricted.py b/src/models/actions_model_restricted.py
index e73df2f..32d156e 100644
--- a/src/models/actions_model_restricted.py
+++ b/src/models/actions_model_restricted.py
@@ -1,9 +1,11 @@
 from typing import Tuple
+
+import torch
 import torch.nn as nn
 from transformers.configuration_utils import PretrainedConfig
 from transformers.modeling_bert import BertForTokenClassification
+
 from src.pipelines.actions_based.processing import ACTIONS_KEYS
-import torch
 
 
 class ActionsModelRestricted(nn.Module):
diff --git a/src/pipelines/actions_based/stage1_extraction.py b/src/pipelines/actions_based/stage1_extraction.py
index 94dc26c..5a058a9 100644
--- a/src/pipelines/actions_based/stage1_extraction.py
+++ b/src/pipelines/actions_based/stage1_extraction.py
@@ -6,10 +6,7 @@ import numpy as np
 import pandas as pd
 from dask.distributed import Client
 
-from src.pipelines.actions_based.processing import (
-    APPLY_FILE_PROCESSING_META,
-    apply_file_processing,
-)
+from src.pipelines.actions_based.processing import APPLY_FILE_PROCESSING_META, apply_file_processing
 from src.utils import PROJECT_ROOT, get_config, prepare_folder
 
 INPUT_FOLDER = f"{PROJECT_ROOT}/data"
diff --git a/src/pipelines/actions_based/stage2_tokenization.py b/src/pipelines/actions_based/stage2_tokenization.py
index b30445f..0ea3586 100644
--- a/src/pipelines/actions_based/stage2_tokenization.py
+++ b/src/pipelines/actions_based/stage2_tokenization.py
@@ -4,10 +4,7 @@ import dask.dataframe as dd
 from dask.distributed import Client
 from transformers import BertTokenizerFast
 
-from src.pipelines.actions_based.processing import (
-    APPLY_TOKENIZATION_META,
-    apply_tokenization,
-)
+from src.pipelines.actions_based.processing import APPLY_TOKENIZATION_META, apply_tokenization
 from src.utils import PROJECT_ROOT, get_config, prepare_folder
 
 INPUT_FOLDER = f"{PROJECT_ROOT}/generated/actions/stage1_extraction"
diff --git a/src/pipelines/actions_based/stage3_exploding.py b/src/pipelines/actions_based/stage3_exploding.py
index 72ec128..81dc965 100644
--- a/src/pipelines/actions_based/stage3_exploding.py
+++ b/src/pipelines/actions_based/stage3_exploding.py
@@ -2,12 +2,7 @@
 import dask.dataframe as dd
 from dask.distributed import Client
 
-from src.processing import (
-    EXPAND_DIMS_META,
-    FLATTEN_DIMS_META,
-    expand_dims,
-    flatten_dims,
-)
+from src.processing import EXPAND_DIMS_META, FLATTEN_DIMS_META, expand_dims, flatten_dims
 from src.utils import PROJECT_ROOT, get_config, prepare_folder
 
 INPUT_FOLDER = f"{PROJECT_ROOT}/generated/actions/stage2_tokenization"
diff --git a/src/pipelines/actions_based/train_base.py b/src/pipelines/actions_based/train_base.py
index 1e4c7f9..03f9899 100755
--- a/src/pipelines/actions_based/train_base.py
+++ b/src/pipelines/actions_based/train_base.py
@@ -1,15 +1,15 @@
 #!/usr/bin/python3
 
 import pickle
-from src.pipelines.train import TrainerBase
-from src.models.actions_model_base import ActionsModelBase, ActionsModelBaseLoss
 
 import numpy as np
-import torch
 import pandas as pd
+import torch
 from transformers import BertTokenizerFast
 
+from src.models.actions_model_base import ActionsModelBase, ActionsModelBaseLoss
 from src.pipelines.actions_based.processing import ACTIONS_KEYS
+from src.pipelines.train import TrainerBase
 from src.utils import PROJECT_ROOT, convert_to_timedelta, get_config
 
 INPUT_PATH = f"{PROJECT_ROOT}/generated/actions/stage4_reindexing"
@@ -87,4 +87,4 @@ class TrainerActions(TrainerBase):
 
 
 if __name__ == "__main__":
-    TrainerActions().train()
\ No newline at end of file
+    TrainerActions().train()
diff --git a/src/pipelines/actions_based/train_mixed.py b/src/pipelines/actions_based/train_mixed.py
index cafb383..20f87ca 100755
--- a/src/pipelines/actions_based/train_mixed.py
+++ b/src/pipelines/actions_based/train_mixed.py
@@ -1,15 +1,15 @@
 #!/usr/bin/python3
 
 import pickle
-from src.models.actions_model_mixed import ActionsModelMixed, ActionsModelMixedLoss
-from src.pipelines.train import TrainerBase
 
 import numpy as np
-import torch
 import pandas as pd
+import torch
 from transformers import BertTokenizerFast
 
+from src.models.actions_model_mixed import ActionsModelMixed, ActionsModelMixedLoss
 from src.pipelines.actions_based.processing import ACTIONS_KEYS
+from src.pipelines.train import TrainerBase
 from src.utils import PROJECT_ROOT, convert_to_timedelta, get_config
 
 INPUT_PATH = f"{PROJECT_ROOT}/generated/actions/stage4_reindexing"
@@ -113,4 +113,4 @@ class TrainerActions(TrainerBase):
 
 
 if __name__ == "__main__":
-    TrainerActions().train()
\ No newline at end of file
+    TrainerActions().train()
diff --git a/src/pipelines/actions_based/train_restricted.py b/src/pipelines/actions_based/train_restricted.py
index 048901e..193c75e 100755
--- a/src/pipelines/actions_based/train_restricted.py
+++ b/src/pipelines/actions_based/train_restricted.py
@@ -1,17 +1,14 @@
 #!/usr/bin/python3
 
 import pickle
-from src.models.actions_model_restricted import (
-    ActionsModelRestricted,
-    ActionsModelRestrictedLoss,
-)
-from src.pipelines.train import TrainerBase
 
 import numpy as np
-import torch
 import pandas as pd
+import torch
 
+from src.models.actions_model_restricted import ActionsModelRestricted, ActionsModelRestrictedLoss
 from src.pipelines.actions_based.processing import ACTIONS_KEYS
+from src.pipelines.train import TrainerBase
 from src.utils import PROJECT_ROOT, convert_to_timedelta, get_config
 
 INPUT_PATH = f"{PROJECT_ROOT}/generated/actions/stage4_reindexing"
@@ -115,4 +112,4 @@ class TrainerActions(TrainerBase):
 
 
 if __name__ == "__main__":
-    TrainerActions().train()
\ No newline at end of file
+    TrainerActions().train()
diff --git a/src/pipelines/train.py b/src/pipelines/train.py
index 8e988cc..087c506 100644
--- a/src/pipelines/train.py
+++ b/src/pipelines/train.py
@@ -1,21 +1,20 @@
 #!/usr/bin/python3
 
 import glob
+from abc import ABC, abstractmethod
 from datetime import datetime, timedelta
 from typing import Optional
 
-from torch.optim.optimizer import Optimizer
-
 import dask.dataframe as dd
 import numpy as np
+import pandas as pd
 import torch
 import torch.nn as nn
-import pandas as pd
+from torch.optim.optimizer import Optimizer
 
 from src.batch_loading import get_batches
 from src.training import latest_model, save_training_step
 from src.utils import convert_to_timedelta, prepare_folder
-from abc import ABC, abstractmethod
 
 
 class TrainerBase(ABC):
diff --git a/src/pipelines/translation_based/processing.py b/src/pipelines/translation_based/processing.py
index b9901e0..608cf43 100644
--- a/src/pipelines/translation_based/processing.py
+++ b/src/pipelines/translation_based/processing.py
@@ -1,10 +1,10 @@
-from src.utils import input_preprocess
 from typing import List, Tuple
 
 import numpy as np
 from transformers import BertTokenizerFast
 
 from src.pipelines.actions_based.processing import text_from_xml
+from src.utils import input_preprocess
 
 
 def raw_to_dataframe(entry: dict) -> dict:
diff --git a/src/pipelines/translation_based/stage1_extraction.py b/src/pipelines/translation_based/stage1_extraction.py
index 6ffdbf7..386211d 100644
--- a/src/pipelines/translation_based/stage1_extraction.py
+++ b/src/pipelines/translation_based/stage1_extraction.py
@@ -6,10 +6,7 @@ import numpy as np
 import pandas as pd
 from dask.distributed import Client
 
-from src.pipelines.translation_based.processing import (
-    RAW_TO_DATAFRAME_META,
-    raw_to_dataframe,
-)
+from src.pipelines.translation_based.processing import RAW_TO_DATAFRAME_META, raw_to_dataframe
 from src.utils import PROJECT_ROOT, get_config, prepare_folder
 
 INPUT_FOLDER = f"{PROJECT_ROOT}/data"
diff --git a/src/pipelines/translation_based/stage2_create_batches.py b/src/pipelines/translation_based/stage2_create_batches.py
index 83a2edc..ade8bf2 100644
--- a/src/pipelines/translation_based/stage2_create_batches.py
+++ b/src/pipelines/translation_based/stage2_create_batches.py
@@ -4,10 +4,7 @@ from dask import delayed
 from dask.distributed import Client
 from transformers import BertTokenizerFast
 
-from src.pipelines.translation_based.processing import (
-    GENERATE_BATCHES_META,
-    generate_batches,
-)
+from src.pipelines.translation_based.processing import GENERATE_BATCHES_META, generate_batches
 from src.utils import PROJECT_ROOT, get_config, prepare_folder
 
 INPUT_FOLDER = f"{PROJECT_ROOT}/generated/translations/stage1_extraction"
diff --git a/tests/models/test_actions_model_base.py b/tests/models/test_actions_model_base.py
index ee3459b..29a7ae0 100644
--- a/tests/models/test_actions_model_base.py
+++ b/tests/models/test_actions_model_base.py
@@ -1,6 +1,7 @@
 import torch
 import torch.distributions as dist
 from transformers.tokenization_bert import BertTokenizerFast
+
 from src.models.actions_model_base import ActionsModelBase, ActionsModelBaseLoss
 
 
@@ -37,4 +38,4 @@ def test_loss_dimensions():
     result_perfect = loss(actions_vector_true, actions_vector_true)
     result_bad = loss(actions_vector_true, actions_vector_bad)
 
-    assert result_perfect < result_bad
\ No newline at end of file
+    assert result_perfect < result_bad
diff --git a/tests/models/test_actions_model_restricted.py b/tests/models/test_actions_model_restricted.py
index 679ab58..9633876 100644
--- a/tests/models/test_actions_model_restricted.py
+++ b/tests/models/test_actions_model_restricted.py
@@ -1,11 +1,9 @@
-from src.models.actions_model_restricted import (
-    ActionsModelRestricted,
-    ActionsModelRestrictedLoss,
-)
 import torch
 import torch.distributions as dist
 from transformers.tokenization_bert import BertTokenizerFast
 
+from src.models.actions_model_restricted import ActionsModelRestricted, ActionsModelRestrictedLoss
+
 
 def test_dimensions():
     base_model = "dkleczek/bert-base-polish-cased-v1"
@@ -47,4 +45,4 @@ def test_loss_dimensions():
     print(result_perfect)
     print(result_bad)
 
-    assert result_perfect < result_bad
\ No newline at end of file
+    assert result_perfect < result_bad
diff --git a/worker.py b/worker.py
index 2d91eb5..cc5790a 100755
--- a/worker.py
+++ b/worker.py
@@ -5,7 +5,7 @@ import configparser
 import nlp_ws
 
 from src.pipelines.actions_based.utils import apply_actions_punctuation, load_model
-from src.utils import preprocess
+from src.utils import input_preprocess
 
 
 class Worker(nlp_ws.NLPWorker):
@@ -15,8 +15,8 @@ class Worker(nlp_ws.NLPWorker):
         self.config = configparser.ConfigParser()
         self.config.read("config.ini")
 
-        self.threshold = self.config["deployment"]["threshold"]
-        self.chunk_size = self.config["deployment"]["chunk_size"]
+        self.threshold = float(self.config["deployment"]["threshold"])
+        self.chunk_size = float(self.config["deployment"]["chunk_size"])
         self.tokenizer, self.model = load_model(
             self.config["deployment"]["model"],
             self.config["deployment"]["base_model"],
@@ -27,7 +27,7 @@ class Worker(nlp_ws.NLPWorker):
         """Implementation of example tasks that copies files."""
 
         with open(input_file, "r") as f:
-            text = preprocess(f.read())
+            text = input_preprocess(f.read())
             text_processed = apply_actions_punctuation(
                 text, self.chunk_size, self.tokenizer, self.model, self.threshold
             )
-- 
GitLab
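
The worker.py change above adds explicit casts because, complementing the quoting issue fixed earlier, configparser always returns values as strings. A minimal sketch:

    import configparser

    config = configparser.ConfigParser()
    config.read_string("[deployment]\nthreshold = 0.9\nchunk_size = 500\n")

    threshold = float(config["deployment"]["threshold"])    # 0.9 as float
    chunk_size = float(config["deployment"]["chunk_size"])  # cast as in worker.py
    # configparser also ships typed getters: config["deployment"].getfloat("threshold")
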


From a6773978f6a3a60524a393c9cbb916025c84fc08 Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Wed, 19 Aug 2020 12:39:44 +0200
Subject: [PATCH 066/116] Training updates

---
 docker/development/Dockerfile              | 16 ++++++----
 dvc.yaml                                   | 14 ++++-----
 params.yaml                                | 12 ++++----
 src/models/actions_model_mixed.py          | 36 +++++++++++++++++++++-
 src/pipelines/actions_based/train_mixed.py |  2 ++
 train.sh                                   |  2 +-
 6 files changed, 61 insertions(+), 21 deletions(-)

diff --git a/docker/development/Dockerfile b/docker/development/Dockerfile
index 570c2be..29a9224 100644
--- a/docker/development/Dockerfile
+++ b/docker/development/Dockerfile
@@ -41,12 +41,12 @@ RUN pip3 install numpy pandas tqdm seaborn torch dask[complete] transformers pya
 RUN ln -s /usr/bin/pip3 /usr/bin/pip
 
 ARG USERNAME=mpogoda
-ARG USER_UID=1000
-ARG USER_GID=$USER_UID
+ARG USER_UID=1030
+ARG USER_GID=1032
 
 # Create the user
-RUN groupadd --gid $USER_GID $USERNAME \
-    && useradd --uid $USER_UID --gid $USER_GID -m $USERNAME \
+RUN groupadd --gid 1032 $USERNAME \
+    && useradd --uid 1030 --gid 1032 -m $USERNAME \
     #
     # [Optional] Add sudo support. Omit if you don't need to install software after connecting.
     && apt-get update \
@@ -54,9 +54,13 @@ RUN groupadd --gid $USER_GID $USERNAME \
     && echo $USERNAME ALL=\(root\) NOPASSWD:ALL > /etc/sudoers.d/$USERNAME \
     && chmod 0440 /etc/sudoers.d/$USERNAME
 
+
+RUN groupmod --gid $USER_GID $USERNAME \
+    && usermod --uid $USER_UID --gid $USER_GID $USERNAME \
+    && chown -R $USER_UID:$USER_GID /home/$USERNAME
+
 # ********************************************************
 # * Anything else you want to do like clean up goes here *
 # ********************************************************
 
-# [Optional] Set the default user. Omit if you want to keep the default as root.
-USER $USERNAME
\ No newline at end of file
+USER ${USERNAME}
\ No newline at end of file
diff --git a/dvc.yaml b/dvc.yaml
index a73590c..e4b7cba 100644
--- a/dvc.yaml
+++ b/dvc.yaml
@@ -12,7 +12,7 @@ stages:
     cmd: python3 -m src.pipelines.actions_based.stage2_tokenization
     deps:
     - generated/actions/stage1_extraction
-    - src/pipelines/actions_based/stage2_tokenization.py
+    - src
     params:
     - actions.tokenization.max_tokens
     - actions.tokenization.min_tokens
@@ -23,21 +23,21 @@ stages:
     cmd: python3 -m src.pipelines.actions_based.stage3_exploding
     deps:
     - generated/actions/stage2_tokenization
-    - src/pipelines/actions_based/stage3_exploding.py
+    - src
     outs:
     - generated/actions/stage3_exploding
   actions_reindexing:
     cmd: python3 -m src.pipelines.actions_based.stage4_reindexing
     deps:
     - generated/actions/stage3_exploding
-    - src/pipelines/actions_based/stage4_reindexing.py
+    - src
     outs:
     - generated/actions/stage4_reindexing
   actions_stats:
     cmd: python3 -m src.pipelines.actions_based.stage5_stats
     deps:
     - generated/actions/stage4_reindexing
-    - src/pipelines/actions_based/stage5_stats.py
+    - src
     outs:
     - generated/actions/stage5_stats
   actions_base_training:
@@ -45,7 +45,7 @@ stages:
     deps:
     - generated/actions/stage4_reindexing
     - generated/actions/stage5_stats
-    - src/pipelines/actions_based/train.py
+    - src
     params:
     - global.base_model
     - global.random_seed
@@ -62,7 +62,7 @@ stages:
     deps:
     - generated/actions/stage4_reindexing
     - generated/actions/stage5_stats
-    - src/pipelines/actions_based/train.py
+    - src
     params:
     - global.base_model
     - global.random_seed
@@ -79,7 +79,7 @@ stages:
     deps:
     - generated/actions/stage4_reindexing
     - generated/actions/stage5_stats
-    - src/pipelines/actions_based/train.py
+    - src
     params:
     - global.base_model
     - global.random_seed
diff --git a/params.yaml b/params.yaml
index 39c7acf..c965809 100644
--- a/params.yaml
+++ b/params.yaml
@@ -50,16 +50,16 @@ actions:
         device: "cuda:0"
 
     training_mixed:
-        embedding_size: 200
-        num_heads: 4
-        num_layers: 2
+        embedding_size: 768
+        num_heads: 12
+        num_layers: 6
         dropout: 0.1
-        feedforward_neurons: 500
+        feedforward_neurons: 1000
         learning_rate: 0.0001
         num_epochs: 5
         batch_size: 2
         batch_buffer_size: 1000
-        save_step: 1000
+        save_step: 10000
         max_training_time: null
         loss_averaging_span: 1000
         fresh_start: true
@@ -92,4 +92,4 @@ translations:
         max_training_time: "4h"
         loss_averaging_span: 1000
         fresh_start: false
-        device: "cuda:1"
\ No newline at end of file
+        device: "cuda:1"
diff --git a/src/models/actions_model_mixed.py b/src/models/actions_model_mixed.py
index c9df6f7..f524c23 100644
--- a/src/models/actions_model_mixed.py
+++ b/src/models/actions_model_mixed.py
@@ -1,13 +1,21 @@
 from typing import Callable, Tuple
 
 import torch
+from torch import device
 import torch.nn as nn
 from torch.nn.modules.loss import BCEWithLogitsLoss
 from transformers.configuration_utils import PretrainedConfig
 from transformers.modeling_bert import BertForTokenClassification
+from transformers.tokenization_bert import BertTokenizerFast
+import numpy as np
 
 from src.models.common import PositionalEncoding, generate_square_subsequent_mask
-from src.pipelines.actions_based.processing import ACTIONS_KEYS
+from src.pipelines.actions_based.processing import (
+    ACTIONS_KEYS,
+    action_vector,
+    empty_action_vector,
+    recover_text,
+)
 
 
 class ActionsModelMixed(nn.Module):
@@ -38,6 +46,9 @@ class ActionsModelMixed(nn.Module):
         """
         super(ActionsModelMixed, self).__init__()
 
+        self.num_labels = num_labels
+        self.device = device
+
         # Word embedder
         self.word_embedding = nn.Embedding(vocab_size, embedding_size)
         self.punctuation_embedding = nn.Linear(num_labels, embedding_size)
@@ -107,6 +118,29 @@ class ActionsModelMixed(nn.Module):
 
         return self.to_labels(z)
 
+    def predict(
+        self, text: str, tokenizer: BertTokenizerFast, threshold: float = 0.9
+    ) -> str:
+        inputs = [action_vector(["upper_case"])]
+
+        text_tokenized = tokenizer(text, return_tensors="pt")
+
+        for _ in range(text_tokenized["input_ids"].shape[1] - 1):
+            prediction_raw = self.forward(
+                text_tokenized["input_ids"],
+                torch.tensor(inputs, dtype=torch.float).reshape(1, -1, self.num_labels),
+                text_tokenized["attention_mask"] == 0,
+            ).sigmoid()
+
+            inputs.append(
+                (prediction_raw.detach().numpy()[0, -1, :] > threshold).astype(np.float)
+            )
+
+        inputs = np.array(inputs)[1:]
+        prediction_binary = inputs.astype(np.int)
+
+        return recover_text(text, prediction_binary)
+
 
 class ActionsModelMixedLoss(nn.Module):
     """Class representing proposed loss for training mixed actions model"""
diff --git a/src/pipelines/actions_based/train_mixed.py b/src/pipelines/actions_based/train_mixed.py
index 20f87ca..0c6c0b4 100755
--- a/src/pipelines/actions_based/train_mixed.py
+++ b/src/pipelines/actions_based/train_mixed.py
@@ -38,6 +38,8 @@ class TrainerActions(TrainerBase):
         base_model = config["global"]["base_model"]
         seed = config["global"]["random_seed"]
 
+        print(f"Layers: {num_layers}")
+
         if max_train_time is not None:
             max_train_time = convert_to_timedelta(max_train_time)
 
diff --git a/train.sh b/train.sh
index a3c6a13..e0fb415 100755
--- a/train.sh
+++ b/train.sh
@@ -6,4 +6,4 @@
 DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
 
 docker build . -f ./docker/training/Dockerfile -t clarinpl/punctuator_training
-docker run -v $DIR:/punctuator --name $2 --gpus all -it -d --entrypoint python clarinpl/punctuator_training -m $1
\ No newline at end of file
+docker run -v $DIR:/punctuator --name $2 --gpus all -it --entrypoint python clarinpl/punctuator_training -m $1
-- 
GitLab
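
The new ActionsModelMixed.predict decodes autoregressively: it seeds the decoder with a single upper_case action vector and, token by token, feeds every action predicted so far back in, thresholding the sigmoid outputs (0.9 by default). The loop reduced to a standalone sketch, with a stand-in predictor instead of the real model:

    import numpy as np

    def autoregressive_decode(num_steps, predict_step, num_labels, threshold=0.9):
        """Greedy thresholded decoding, shaped like ActionsModelMixed.predict."""
        inputs = [np.ones(num_labels)]  # stand-in for action_vector(["upper_case"])
        for _ in range(num_steps):
            probs = predict_step(np.array(inputs))  # [len(inputs), num_labels]
            inputs.append((probs[-1] > threshold).astype(float))
        return np.array(inputs)[1:].astype(int)

    # Toy predictor that always fires "dot" (index 1) with probability 0.95.
    toy = lambda history: np.tile([0.1, 0.95, 0.0, 0.0], (len(history), 1))
    print(autoregressive_decode(3, toy, num_labels=4))  # three rows of [0 1 0 0]
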


From 92e32d06f8ff7f33450ee7bd01c923d4917d6c02 Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Thu, 20 Aug 2020 09:03:42 +0000
Subject: [PATCH 067/116] Fixed batch dim squeeze when batch_size=1

---
 src/pipelines/actions_based/train_mixed.py      | 5 +++--
 src/pipelines/actions_based/train_restricted.py | 4 ++--
 2 files changed, 5 insertions(+), 4 deletions(-)

diff --git a/src/pipelines/actions_based/train_mixed.py b/src/pipelines/actions_based/train_mixed.py
index 0c6c0b4..8a6daba 100755
--- a/src/pipelines/actions_based/train_mixed.py
+++ b/src/pipelines/actions_based/train_mixed.py
@@ -94,9 +94,10 @@ class TrainerActions(TrainerBase):
             lambda x: x["attention_mask"].reshape(x["attention_mask_shape"]), axis=1,
         ).values
 
-        inputs = torch.tensor(np.stack(inputs).squeeze(), dtype=torch.long).to(
+        inputs = torch.tensor(np.stack(inputs), dtype=torch.long).to(
             self.device
-        )
+        ).squeeze(dim=2)
+
         outputs = torch.tensor(np.stack(outputs), dtype=torch.float).to(self.device)
 
         # Convert to boolean
diff --git a/src/pipelines/actions_based/train_restricted.py b/src/pipelines/actions_based/train_restricted.py
index 193c75e..06d4932 100755
--- a/src/pipelines/actions_based/train_restricted.py
+++ b/src/pipelines/actions_based/train_restricted.py
@@ -98,7 +98,7 @@ class TrainerActions(TrainerBase):
             lambda x: x["attention_mask"].reshape(x["attention_mask_shape"]), axis=1,
         ).values
 
-        inputs = torch.tensor(np.stack(inputs).squeeze()).to(self.device)
+        inputs = torch.tensor(np.stack(inputs)).squeeze(dim=2).to(self.device)
         outputs = torch.tensor(np.stack(outputs)).to(self.device)
         attentions_mask = torch.tensor(np.stack(attentions_mask)).to(self.device)
 
@@ -107,7 +107,7 @@ class TrainerActions(TrainerBase):
         outputs = torch.cat(
             [outputs, (1.0 - outputs[:, :, 1:].max(-1)[0]).unsqueeze(-1)], axis=-1
         )
-
+        
         return self.criterion(outputs, y_pred)
 
 
-- 
GitLab
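
The bug fixed here: an unqualified .squeeze() drops every size-1 dimension, so with batch_size=1 it also removes the batch axis and downstream code sees a tensor with one dimension too few. Restricting it to a specific dim removes only the intended singleton. A minimal reproduction (the shape is illustrative; the repo's layout squeezes dim=2):

    import torch

    x = torch.zeros(1, 500, 1)      # batch_size=1 plus a trailing singleton dim

    print(x.squeeze().shape)        # torch.Size([500])    -- batch axis silently dropped
    print(x.squeeze(dim=2).shape)   # torch.Size([1, 500]) -- only the intended axis removed
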


From ea95ed9a92a31a2e90b2cf7b7f5e0825ada401e8 Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Thu, 20 Aug 2020 14:51:22 +0200
Subject: [PATCH 068/116] Fixed loss calculation & unbalanced weighting

---
 src/models/actions_model_mixed.py             |  2 +-
 src/models/actions_model_restricted.py        | 24 +++++++++-------
 .../actions_based/train_restricted.py         | 28 ++++++++-----------
 3 files changed, 27 insertions(+), 27 deletions(-)

diff --git a/src/models/actions_model_mixed.py b/src/models/actions_model_mixed.py
index f524c23..eae23e8 100644
--- a/src/models/actions_model_mixed.py
+++ b/src/models/actions_model_mixed.py
@@ -153,7 +153,7 @@ class ActionsModelMixedLoss(nn.Module):
         """
         super(ActionsModelMixedLoss, self).__init__()
 
-        self.core = BCEWithLogitsLoss(prior_odds)
+        self.core = BCEWithLogitsLoss(pos_weight=prior_odds)
 
     def forward(
         self,
diff --git a/src/models/actions_model_restricted.py b/src/models/actions_model_restricted.py
index 32d156e..f453e94 100644
--- a/src/models/actions_model_restricted.py
+++ b/src/models/actions_model_restricted.py
@@ -42,12 +42,7 @@ class ActionsModelRestricted(nn.Module):
         """
         y_pred = self.core(input_ids=input_ids, attention_mask=attention_mask)[0]
 
-        pred_uppercase = y_pred[:, :, :1]
-
-        # Force punctuations to be proper categorical distribution logits
-        pred_punctuation = self._logit(torch.softmax(y_pred[:, :, 1:], -1))
-
-        return torch.cat([pred_uppercase, pred_punctuation], -1)
+        return y_pred
 
     @staticmethod
     def _logit(x: torch.Tensor):
@@ -59,15 +54,16 @@ class ActionsModelRestricted(nn.Module):
 
 
 class ActionsModelRestrictedLoss(nn.Module):
-    def __init__(self, prior_odds: torch.Tensor) -> None:
+    def __init__(self, prior_uppercase_odds: torch.Tensor, punctuation_weights: torch.Tensor) -> None:
         super(ActionsModelRestrictedLoss, self).__init__()
 
-        self.core = nn.BCEWithLogitsLoss(prior_odds)
+        self.binary_ce = nn.BCEWithLogitsLoss(pos_weight=prior_uppercase_odds)
+        self.cat_ce = nn.CrossEntropyLoss(punctuation_weights)
 
     def forward(
         self,
-        true_extended_action_vector: torch.Tensor,
         predicted_action_vector_logits: torch.Tensor,
+        true_extended_action_vector: torch.Tensor,
     ) -> torch.Tensor:
         """Loss for ActionsModelRestricted model
 
@@ -78,5 +74,13 @@ class ActionsModelRestrictedLoss(nn.Module):
         Returns:
             torch.Tensor: Loss value
         """
+        
+        predicted_punc = predicted_action_vector_logits[:, :, 1:].transpose(1, 2)
+        target_punc_index = torch.argmax(true_extended_action_vector[:, :, 1:], dim=-1)
+        punc_loss = self.cat_ce(predicted_punc, target_punc_index)
+
+        predicted_uppercase = predicted_action_vector_logits[:, :, 0]
+        target_uppercase = true_extended_action_vector[:, :, 0]
+        uppercase_loss = self.binary_ce(predicted_uppercase, target_uppercase)
 
-        return self.core(predicted_action_vector_logits, true_extended_action_vector)
+        return punc_loss + uppercase_loss
diff --git a/src/pipelines/actions_based/train_restricted.py b/src/pipelines/actions_based/train_restricted.py
index 06d4932..a9da8fe 100755
--- a/src/pipelines/actions_based/train_restricted.py
+++ b/src/pipelines/actions_based/train_restricted.py
@@ -5,6 +5,7 @@ import pickle
 import numpy as np
 import pandas as pd
 import torch
+from torch._C import dtype
 
 from src.models.actions_model_restricted import ActionsModelRestricted, ActionsModelRestrictedLoss
 from src.pipelines.actions_based.processing import ACTIONS_KEYS
@@ -44,32 +45,27 @@ class TrainerActions(TrainerBase):
             stats = pickle.load(f)
             pos_examples = stats["class_number"]
             neg_examples = stats["num_examples"] - stats["class_number"]
-            pos_weight = torch.tensor(neg_examples / pos_examples)
 
-            # Load loss weights
-        with open(f"{INPUT_STATS_PATH}/stats.pickle", "rb") as f:
-            stats = pickle.load(f)
-            pos_examples = stats["class_number"]
-            neg_examples = stats["num_examples"] - stats["class_number"]
+            uppercase_pos_examples = pos_examples[0]
+            uppercase_neg_examples = neg_examples[0]
+            uppercase_pos_odds = torch.tensor(uppercase_pos_examples / uppercase_neg_examples, dtype=torch.float)
 
-            no_punctuation_pos_examples = np.sum(neg_examples[1:])
-            no_punctuation_neg_examples = np.sum(pos_examples[1:])
+            has_punctuation_neg_examples = neg_examples[1:]
+            has_no_punctuation_neg_examples = np.sum(pos_examples[1:])
 
-            pos_examples = np.concatenate(
-                [pos_examples, no_punctuation_pos_examples.reshape(1)], -1
-            )
-            neg_examples = np.concatenate(
-                [neg_examples, no_punctuation_neg_examples.reshape(1)], -1
+            punctuation_neg_examples = np.concatenate(
+                [has_punctuation_neg_examples, 
+                has_no_punctuation_neg_examples.reshape(1)], -1
             )
 
-            pos_weight = torch.tensor(neg_examples / pos_examples)
+            punctuation_class_weights = torch.tensor((punctuation_neg_examples) / np.sum(punctuation_neg_examples), dtype=torch.float)
 
         np.random.seed(seed=seed)
 
         device = torch.device(device_name if torch.cuda.is_available() else "cpu")
 
         model = ActionsModelRestricted(base_model, len(ACTIONS_KEYS) + 1).to(device)
-        self.criterion = ActionsModelRestrictedLoss(pos_weight).to(device)
+        self.criterion = ActionsModelRestrictedLoss(uppercase_pos_odds, punctuation_class_weights).to(device)
         optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
 
         super(TrainerActions, self).__init__(
@@ -108,7 +104,7 @@ class TrainerActions(TrainerBase):
             [outputs, (1.0 - outputs[:, :, 1:].max(-1)[0]).unsqueeze(-1)], axis=-1
         )
         
-        return self.criterion(outputs, y_pred)
+        return self.criterion(y_pred, outputs)
 
 
 if __name__ == "__main__":
-- 
GitLab
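
After this fix the restricted loss treats the two heads separately: binary cross-entropy, weighted by the positive odds, on the upper_case logit, and class-weighted categorical cross-entropy over the punctuation slots including the appended "no punctuation" one. A self-contained sketch with dummy tensors; shapes follow the docstrings (B batches, L tokens, A+1 actions) and the weight values are made up:

    import torch
    import torch.nn as nn

    B, L = 2, 8                            # batch, sequence length
    A = 4                                  # upper_case, dot, colon, question_mark
    logits = torch.randn(B, L, A + 1)      # +1: appended "no punctuation" slot
    target = torch.zeros(B, L, A + 1)
    target[:, :, -1] = 1.0                 # dummy target: lowercase, no punctuation

    # Binary CE on the upper_case head, weighted by positive odds.
    uppercase_loss = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(3.0))(
        logits[:, :, 0], target[:, :, 0]
    )

    # Class-weighted categorical CE over the punctuation slots.
    # CrossEntropyLoss expects [B, C, L] logits and [B, L] class indices.
    punctuation_weights = torch.ones(A)    # dot, colon, question_mark, none
    punc_loss = nn.CrossEntropyLoss(punctuation_weights)(
        logits[:, :, 1:].transpose(1, 2),
        torch.argmax(target[:, :, 1:], dim=-1),
    )

    print(punc_loss + uppercase_loss)
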


From 1a813f22447b6212ae46985408acee2d2f95d8a6 Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Thu, 20 Aug 2020 16:08:30 +0200
Subject: [PATCH 069/116] Made all unit tests pass

---
 src/models/TransformerSeq2Seq.py              | 73 ++++++++++++++++++-
 src/models/actions_model_base.py              | 19 ++++-
 src/models/actions_model_mixed.py             |  9 +--
 src/models/actions_model_restricted.py        | 25 ++++---
 src/models/common.py                          | 69 ------------------
 src/pipelines/actions_based/train_base.py     |  1 -
 src/pipelines/actions_based/train_mixed.py    |  8 +-
 .../actions_based/train_restricted.py         | 28 +++++--
 src/pipelines/train.py                        |  3 +-
 tests/models/test_actions_model_base.py       |  1 -
 tests/models/test_actions_model_restricted.py | 22 ++++--
 11 files changed, 148 insertions(+), 110 deletions(-)

diff --git a/src/models/TransformerSeq2Seq.py b/src/models/TransformerSeq2Seq.py
index 696c78a..3009fae 100644
--- a/src/models/TransformerSeq2Seq.py
+++ b/src/models/TransformerSeq2Seq.py
@@ -1,6 +1,73 @@
-import math
-
 import torch
 import torch.nn as nn
 
-from src.models.common import PositionalEncoding, TransformerSeq2Seq
+from src.models.common import PositionalEncoding
+
+
+class TransformerSeq2Seq(nn.Module):
+    """Class representing a sequence to sequence transformer, based on original "Attention is all you need" paper."""
+
+    def __init__(
+        self,
+        vocab_size: int,
+        embedding_size: int,
+        max_len: int,
+        num_heads: int = 8,
+        encoder_layers: int = 6,
+        decoder_layers: int = 6,
+        feedforward_neurons: int = 2048,
+        dropout: float = 0.1,
+    ):
+
+        super(TransformerSeq2Seq, self).__init__()
+
+        # Embedd from token to vec space
+        self.word_embedding = nn.Embedding(vocab_size, embedding_size)
+
+        # Add positional encoding
+        self.position_embedding = PositionalEncoding(embedding_size, max_len, dropout)
+
+        # Combined encoder-decoder step
+        self.core = nn.Transformer(
+            embedding_size,
+            num_heads,
+            encoder_layers,
+            decoder_layers,
+            feedforward_neurons,
+            dropout,
+        )
+
+        # Map embedding to word
+        self.embedding_to_words = nn.Linear(embedding_size, vocab_size)
+
+    def forward(
+        self, source: torch.Tensor, target: torch.Tensor, source_mask: torch.Tensor,
+    ) -> torch.Tensor:
+        """Full encoder-decoder pass
+
+        Args:
+            source (torch.Tensor): Tensor with batch of source sentences tokens [BxL shape]
+            target (torch.Tensor): Tensor with batch of target sentences tokens [BxL-1 shape]
+            source_mask (torch.Tensor): Mask applied to source (True if element is padding, False otherwise) [BxL shape]
+
+        Returns:
+            torch.Tensor: Tensor with predicted target sentences tokens [Bx(L-1)xV]
+        """
+        # Input to encoder
+        x = source.transpose(0, 1)
+        x = self.word_embedding(x)
+        x = self.position_embedding(x)
+
+        # Input to decoder
+        y = target.transpose(0, 1)
+        y = self.word_embedding(y)
+        y = self.position_embedding(y)
+
+        tgt_mask = self.core.generate_square_subsequent_mask(y.shape[0]).to(y.device)
+
+        z = self.core(
+            x, y, src_key_padding_mask=source_mask, tgt_mask=tgt_mask
+        ).transpose(1, 0)
+        z = self.embedding_to_words(z)
+
+        return z
diff --git a/src/models/actions_model_base.py b/src/models/actions_model_base.py
index eb24a22..064cc91 100644
--- a/src/models/actions_model_base.py
+++ b/src/models/actions_model_base.py
@@ -1,5 +1,3 @@
-from typing import Callable, Tuple
-
 import torch
 import torch.nn as nn
 from torch.nn.modules.loss import BCEWithLogitsLoss
@@ -45,7 +43,15 @@ class ActionsModelBase(nn.Module):
 
 
 class ActionsModelBaseLoss(nn.Module):
+    """Proposed loss for ActionsModelBase model"""
+
     def __init__(self, prior_odds: torch.Tensor) -> None:
+        """Initializes ActionsModelBaseLoss
+
+        Args:
+            prior_odds (torch.Tensor): Positive to negative ratio of each action vector
+                entry in the dataset. Shape A
+        """
         super(ActionsModelBaseLoss, self).__init__()
 
         self.core = BCEWithLogitsLoss(prior_odds)
@@ -55,5 +61,14 @@ class ActionsModelBaseLoss(nn.Module):
         true_action_vector: torch.Tensor,
         predicted_action_vector_logits: torch.Tensor,
     ) -> torch.Tensor:
+        """Computes ActionsModelBase loss
+
+        Args:
+            true_action_vector (torch.Tensor): Target labels. Shape BxLxA
+            predicted_action_vector_logits (torch.Tensor): Logits predicted by ActionsModelBase model. Shape BxLxA
+
+        Returns:
+            torch.Tensor: Computed loss.
+        """
 
         return self.core(predicted_action_vector_logits, true_action_vector)
diff --git a/src/models/actions_model_mixed.py b/src/models/actions_model_mixed.py
index eae23e8..8eee0e7 100644
--- a/src/models/actions_model_mixed.py
+++ b/src/models/actions_model_mixed.py
@@ -1,19 +1,14 @@
-from typing import Callable, Tuple
-
+import numpy as np
 import torch
-from torch import device
 import torch.nn as nn
+from torch import device
 from torch.nn.modules.loss import BCEWithLogitsLoss
-from transformers.configuration_utils import PretrainedConfig
-from transformers.modeling_bert import BertForTokenClassification
 from transformers.tokenization_bert import BertTokenizerFast
-import numpy as np
 
 from src.models.common import PositionalEncoding, generate_square_subsequent_mask
 from src.pipelines.actions_based.processing import (
     ACTIONS_KEYS,
     action_vector,
-    empty_action_vector,
     recover_text,
 )
 
diff --git a/src/models/actions_model_restricted.py b/src/models/actions_model_restricted.py
index f453e94..43dae15 100644
--- a/src/models/actions_model_restricted.py
+++ b/src/models/actions_model_restricted.py
@@ -1,12 +1,8 @@
-from typing import Tuple
-
 import torch
 import torch.nn as nn
 from transformers.configuration_utils import PretrainedConfig
 from transformers.modeling_bert import BertForTokenClassification
 
-from src.pipelines.actions_based.processing import ACTIONS_KEYS
-
 
 class ActionsModelRestricted(nn.Module):
     """Similar to ActionsModelBase, however no-punctuation class is added
@@ -54,10 +50,21 @@ class ActionsModelRestricted(nn.Module):
 
 
 class ActionsModelRestrictedLoss(nn.Module):
-    def __init__(self, prior_uppercase_odds: torch.Tensor, punctuation_weights: torch.Tensor) -> None:
+    def __init__(
+        self, prior_uppercase_odds: torch.Tensor, punctuation_weights: torch.Tensor
+    ) -> None:
+        """Initializes ActionsModelRestrictedLoss
+
+        Args:
+            prior_uppercase_odds (torch.Tensor): Odds of positive to negative cases of uppercase in the dataset
+            punctuation_weights (torch.Tensor): Weights for each class in the loss function. Should be inversely proportional to the number of
+                their occurrences in the dataset (Shape A+1)
+        """
         super(ActionsModelRestrictedLoss, self).__init__()
 
-        self.binary_ce = nn.BCEWithLogitsLoss(pos_weight=prior_uppercase_odds)
+        self.binary_ce = nn.BCEWithLogitsLoss(
+            pos_weight=prior_uppercase_odds.reshape(1)
+        )
         self.cat_ce = nn.CrossEntropyLoss(punctuation_weights)
 
     def forward(
@@ -68,13 +75,13 @@ class ActionsModelRestrictedLoss(nn.Module):
         """Loss for ActionsModelRestricted model
 
         Args:
-            true_extended_action_vector (torch.Tensor): Ground-truth action vectors. Shape BxLxA
-            predicted_action_vector_logits (torch.Tensor): Action vector-s logits predicted by ActionsModelRestricted model. Shape BxLxA
+            true_extended_action_vector (torch.Tensor): Ground-truth action vectors. Shape BxLxA+1
+            predicted_action_vector_logits (torch.Tensor): Action vector logits predicted by ActionsModelRestricted model. Shape BxLxA+1
 
         Returns:
             torch.Tensor: Loss value
         """
-        
+
         predicted_punc = predicted_action_vector_logits[:, :, 1:].transpose(1, 2)
         target_punc_index = torch.argmax(true_extended_action_vector[:, :, 1:], dim=-1)
         punc_loss = self.cat_ce(predicted_punc, target_punc_index)
diff --git a/src/models/common.py b/src/models/common.py
index bcf4bb2..012999f 100644
--- a/src/models/common.py
+++ b/src/models/common.py
@@ -58,72 +58,3 @@ class PositionalEncoding(nn.Module):
         """
         x = x + self.pe[: x.size(0), :]
         return self.dropout(x)
-
-
-class TransformerSeq2Seq(nn.Module):
-    """Class representing a sequence to sequence transformer, based on original "Attention is all you need" paper."""
-
-    def __init__(
-        self,
-        vocab_size: int,
-        embedding_size: int,
-        max_len: int,
-        num_heads: int = 8,
-        encoder_layers: int = 6,
-        decoder_layers: int = 6,
-        feedforward_neurons: int = 2048,
-        dropout: float = 0.1,
-    ):
-
-        super(TransformerSeq2Seq, self).__init__()
-
-        # Embedd from token to vec space
-        self.word_embedding = nn.Embedding(vocab_size, embedding_size)
-
-        # Add positional encoding
-        self.position_embedding = PositionalEncoding(embedding_size, max_len, dropout)
-
-        # Combined encoder-decoder step
-        self.core = nn.Transformer(
-            embedding_size,
-            num_heads,
-            encoder_layers,
-            decoder_layers,
-            feedforward_neurons,
-            dropout,
-        )
-
-        # Map embedding to word
-        self.embedding_to_words = nn.Linear(embedding_size, vocab_size)
-
-    def forward(
-        self, source: torch.Tensor, target: torch.Tensor, source_mask: torch.Tensor,
-    ) -> torch.Tensor:
-        """Full encoder-decoder pass
-
-        Args:
-            source (torch.Tensor): Tensor with batch of source sentences tokens [BxL shape]
-            target (torch.Tensor): Tensor with batch of target sentences tokens [BxL-1 shape]
-            source_mask (torch.Tensor): Mask applied to source (True if element is padding, False otherwise) [BxL shape]
-
-        Returns:
-            torch.Tensor: Tensor with predicted target sentences tokens [Bx(L-1)xV]
-        """
-        # Input to encoder
-        x = source.transpose(0, 1)
-        x = self.word_embedding(x)
-        x = self.position_embedding(x)
-
-        # Input to decoder
-        y = target.transpose(0, 1)
-        y = self.word_embedding(y)
-        y = self.position_embedding(y)
-
-        tgt_mask = self.core.generate_square_subsequent_mask(y.shape[0]).to(y.device)
-
-        z = self.core(
-            x, y, src_key_padding_mask=source_mask, tgt_mask=tgt_mask
-        ).transpose(1, 0)
-        z = self.embedding_to_words(z)
-
-        return z
diff --git a/src/pipelines/actions_based/train_base.py b/src/pipelines/actions_based/train_base.py
index 03f9899..93500f9 100755
--- a/src/pipelines/actions_based/train_base.py
+++ b/src/pipelines/actions_based/train_base.py
@@ -5,7 +5,6 @@ import pickle
 import numpy as np
 import pandas as pd
 import torch
-from transformers import BertTokenizerFast
 
 from src.models.actions_model_base import ActionsModelBase, ActionsModelBaseLoss
 from src.pipelines.actions_based.processing import ACTIONS_KEYS
diff --git a/src/pipelines/actions_based/train_mixed.py b/src/pipelines/actions_based/train_mixed.py
index 8a6daba..be9b636 100755
--- a/src/pipelines/actions_based/train_mixed.py
+++ b/src/pipelines/actions_based/train_mixed.py
@@ -94,9 +94,11 @@ class TrainerActions(TrainerBase):
             lambda x: x["attention_mask"].reshape(x["attention_mask_shape"]), axis=1,
         ).values
 
-        inputs = torch.tensor(np.stack(inputs), dtype=torch.long).to(
-            self.device
-        ).squeeze(dim=2)
+        inputs = (
+            torch.tensor(np.stack(inputs), dtype=torch.long)
+            .to(self.device)
+            .squeeze(dim=2)
+        )
 
         outputs = torch.tensor(np.stack(outputs), dtype=torch.float).to(self.device)
 
diff --git a/src/pipelines/actions_based/train_restricted.py b/src/pipelines/actions_based/train_restricted.py
index a9da8fe..5e7d8f3 100755
--- a/src/pipelines/actions_based/train_restricted.py
+++ b/src/pipelines/actions_based/train_restricted.py
@@ -5,9 +5,11 @@ import pickle
 import numpy as np
 import pandas as pd
 import torch
-from torch._C import dtype
 
-from src.models.actions_model_restricted import ActionsModelRestricted, ActionsModelRestrictedLoss
+from src.models.actions_model_restricted import (
+    ActionsModelRestricted,
+    ActionsModelRestrictedLoss,
+)
 from src.pipelines.actions_based.processing import ACTIONS_KEYS
 from src.pipelines.train import TrainerBase
 from src.utils import PROJECT_ROOT, convert_to_timedelta, get_config
@@ -48,24 +50,34 @@ class TrainerActions(TrainerBase):
 
             uppercase_pos_examples = pos_examples[0]
             uppercase_neg_examples = neg_examples[0]
-            uppercase_pos_odds = torch.tensor(uppercase_pos_examples / uppercase_neg_examples, dtype=torch.float)
+            uppercase_pos_odds = torch.tensor(
+                uppercase_pos_examples / uppercase_neg_examples, dtype=torch.float
+            )
 
             has_punctuation_neg_examples = neg_examples[1:]
             has_no_punctuation_neg_examples = np.sum(pos_examples[1:])
 
             punctuation_neg_examples = np.concatenate(
-                [has_punctuation_neg_examples, 
-                has_no_punctuation_neg_examples.reshape(1)], -1
+                [
+                    has_punctuation_neg_examples,
+                    has_no_punctuation_neg_examples.reshape(1),
+                ],
+                -1,
             )
 
-            punctuation_class_weights = torch.tensor((punctuation_neg_examples) / np.sum(punctuation_neg_examples), dtype=torch.float)
+            punctuation_class_weights = torch.tensor(
+                (punctuation_neg_examples) / np.sum(punctuation_neg_examples),
+                dtype=torch.float,
+            )
 
         np.random.seed(seed=seed)
 
         device = torch.device(device_name if torch.cuda.is_available() else "cpu")
 
         model = ActionsModelRestricted(base_model, len(ACTIONS_KEYS) + 1).to(device)
-        self.criterion = ActionsModelRestrictedLoss(uppercase_pos_odds, punctuation_class_weights).to(device)
+        self.criterion = ActionsModelRestrictedLoss(
+            uppercase_pos_odds, punctuation_class_weights
+        ).to(device)
         optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
 
         super(TrainerActions, self).__init__(
@@ -103,7 +115,7 @@ class TrainerActions(TrainerBase):
         outputs = torch.cat(
             [outputs, (1.0 - outputs[:, :, 1:].max(-1)[0]).unsqueeze(-1)], axis=-1
         )
-        
+
         return self.criterion(y_pred, outputs)
 
 
diff --git a/src/pipelines/train.py b/src/pipelines/train.py
index 087c506..f282a8e 100644
--- a/src/pipelines/train.py
+++ b/src/pipelines/train.py
@@ -144,7 +144,8 @@ class TrainerBase(ABC):
 
                 losses.append(loss.item())
                 if len(losses) > self.loss_averaging_span:
-                    losses = losses[-self.loss_averaging_span :]
+                    first_to_keep = -self.loss_averaging_span
+                    losses = losses[first_to_keep:]
 
                 print(f"epoch: {epoch} | step: {i} | loss: {np.mean(losses)}")
 
diff --git a/tests/models/test_actions_model_base.py b/tests/models/test_actions_model_base.py
index 29a7ae0..ecca181 100644
--- a/tests/models/test_actions_model_base.py
+++ b/tests/models/test_actions_model_base.py
@@ -1,5 +1,4 @@
 import torch
-import torch.distributions as dist
 from transformers.tokenization_bert import BertTokenizerFast
 
 from src.models.actions_model_base import ActionsModelBase, ActionsModelBaseLoss
diff --git a/tests/models/test_actions_model_restricted.py b/tests/models/test_actions_model_restricted.py
index 9633876..7446675 100644
--- a/tests/models/test_actions_model_restricted.py
+++ b/tests/models/test_actions_model_restricted.py
@@ -1,8 +1,10 @@
 import torch
-import torch.distributions as dist
 from transformers.tokenization_bert import BertTokenizerFast
 
-from src.models.actions_model_restricted import ActionsModelRestricted, ActionsModelRestrictedLoss
+from src.models.actions_model_restricted import (
+    ActionsModelRestricted,
+    ActionsModelRestrictedLoss,
+)
 
 
 def test_dimensions():
@@ -27,13 +29,19 @@ def test_loss_dimensions():
     batch_size = 5
     sequence_len = 10
     action_vector_size = 4
-    prior_probs = torch.tensor([0.3, 0.2, 0.3, 0.5]).log()
-    loss = ActionsModelRestrictedLoss(prior_probs)
+    uppercase_odds = torch.tensor(0.3, dtype=torch.float)
+    punctuation_weights = torch.tensor([0.3, 0.3, 0.1], dtype=torch.float)
+    loss = ActionsModelRestrictedLoss(uppercase_odds, punctuation_weights)
 
-    actions_vector_true = torch.zeros((batch_size, sequence_len, action_vector_size))
+    actions_vector_true = torch.zeros(
+        (batch_size, sequence_len, action_vector_size), dtype=torch.float
+    )
     actions_vector_true[:, :, -1] = 1.0
 
-    actions_vector_bad = torch.ones((batch_size, sequence_len, action_vector_size))
+    actions_vector_bad = torch.zeros(
+        (batch_size, sequence_len, action_vector_size), dtype=torch.float
+    )
+    actions_vector_bad[:, :, :2] = 1.0
     actions_vector_bad[:, :, -1] = 0.0
 
     result = loss(actions_vector_true, actions_vector_bad)
@@ -46,3 +54,5 @@ def test_loss_dimensions():
     print(result_bad)
 
     assert result_perfect < result_bad
+    assert result_perfect > 0
+    assert result_bad > 0
-- 
GitLab
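
In the reworked ActionsModelRestrictedLoss the uppercase channel (index 0) is
scored with BCEWithLogitsLoss using the prior odds as pos_weight, and the
remaining channels (punctuation actions plus the appended no-punctuation class)
with a weighted CrossEntropyLoss against the argmax of the one-hot target. A
standalone sketch of that decomposition; the uppercase term is inferred from
the surrounding code and may differ in detail from the repository:

import torch
import torch.nn as nn

B, L, A = 2, 7, 3  # batch, sequence length, punctuation classes (incl. no-punct)
uppercase_odds = torch.tensor(0.3)
punct_weights = torch.tensor([0.3, 0.3, 0.4])

binary_ce = nn.BCEWithLogitsLoss(pos_weight=uppercase_odds.reshape(1))
cat_ce = nn.CrossEntropyLoss(punct_weights)

pred_logits = torch.randn(B, L, A + 1)  # channel 0 is uppercase
true_vector = torch.zeros(B, L, A + 1)
true_vector[:, :, -1] = 1.0             # "no punctuation" everywhere

# Punctuation: CrossEntropyLoss expects class logits on dim 1.
pred_punc = pred_logits[:, :, 1:].transpose(1, 2)          # B x A x L
target_punc = torch.argmax(true_vector[:, :, 1:], dim=-1)  # B x L
punc_loss = cat_ce(pred_punc, target_punc)

# Uppercase: plain binary cross-entropy on channel 0.
upper_loss = binary_ce(pred_logits[:, :, 0], true_vector[:, :, 0])

loss = punc_loss + upper_loss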


From f433eaa2a40c46268418fec2fbfe7fd28e950a1f Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Thu, 20 Aug 2020 16:17:42 +0200
Subject: [PATCH 070/116] Added unit test for mixed model dimensions

---
 tests/models/test_actions_model_mixed.py      | 61 +++++++++++++++++++
 tests/models/test_actions_model_restricted.py |  5 +-
 2 files changed, 62 insertions(+), 4 deletions(-)
 create mode 100644 tests/models/test_actions_model_mixed.py

diff --git a/tests/models/test_actions_model_mixed.py b/tests/models/test_actions_model_mixed.py
new file mode 100644
index 0000000..3e81134
--- /dev/null
+++ b/tests/models/test_actions_model_mixed.py
@@ -0,0 +1,61 @@
+import torch
+from transformers.tokenization_bert import BertTokenizerFast
+
+from src.models.actions_model_base import ActionsModelBase, ActionsModelBaseLoss
+from src.models.actions_model_mixed import ActionsModelMixed, ActionsModelMixedLoss
+
+
+def test_dimensions():
+    base_model = "dkleczek/bert-base-polish-cased-v1"
+    action_vector_size = 5
+
+    tokenizer = BertTokenizerFast.from_pretrained(base_model)
+    tokens = tokenizer("Ala ma kota", return_tensors="pt")
+
+    embedding_size = 200
+    num_heads = 4
+    num_layers = 2
+    feedforward_neurons = 100
+    max_len = 500
+    dropout = 0.1
+    model = ActionsModelMixed(
+        tokenizer.vocab_size,
+        embedding_size,
+        num_heads,
+        num_layers,
+        feedforward_neurons,
+        action_vector_size,
+        max_len,
+        dropout,
+    )
+
+    actions_len = 3
+    actions = torch.distributions.Multinomial(
+        1, torch.tensor([0.5] * action_vector_size)
+    ).sample((tokens["input_ids"].shape[0], actions_len))
+
+    result = model(tokens["input_ids"], actions, tokens["attention_mask"])
+
+    assert len(result.shape) == 3
+
+    assert result.shape[0] == tokens["input_ids"].shape[0]
+    assert result.shape[1] == actions_len
+    assert result.shape[2] == action_vector_size
+
+
+def test_loss_dimensions():
+    batch_size = 5
+    sequence_len = 10
+    actions_size = 3
+    prior_odds = torch.zeros(actions_size) + 0.3
+    actions_vector_true = torch.zeros((batch_size, sequence_len, actions_size))
+    actions_vector_bad = torch.ones((batch_size, sequence_len, actions_size))
+    loss = ActionsModelMixedLoss(prior_odds)
+
+    result = loss(actions_vector_true, actions_vector_bad)
+    assert len(result.shape) == 0
+
+    result_perfect = loss(actions_vector_true, actions_vector_true)
+    result_bad = loss(actions_vector_true, actions_vector_bad)
+
+    assert result_perfect < result_bad
diff --git a/tests/models/test_actions_model_restricted.py b/tests/models/test_actions_model_restricted.py
index 7446675..eeae5ec 100644
--- a/tests/models/test_actions_model_restricted.py
+++ b/tests/models/test_actions_model_restricted.py
@@ -1,10 +1,7 @@
 import torch
 from transformers.tokenization_bert import BertTokenizerFast
 
-from src.models.actions_model_restricted import (
-    ActionsModelRestricted,
-    ActionsModelRestrictedLoss,
-)
+from src.models.actions_model_restricted import ActionsModelRestricted, ActionsModelRestrictedLoss
 
 
 def test_dimensions():
-- 
GitLab
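
The new test drives the decoder with one-hot action vectors sampled from a
Multinomial, so the model receives B sequences of T action vectors of size A
and must return B x T x A logits. A self-contained illustration of just the
sampling and shape bookkeeping (no repository code involved):

import torch

batch_size, actions_len, action_vector_size = 1, 3, 5

actions = torch.distributions.Multinomial(
    1, torch.tensor([0.5] * action_vector_size)
).sample((batch_size, actions_len))

print(actions.shape)    # torch.Size([1, 3, 5])
print(actions.sum(-1))  # total_count=1, so every sampled vector is one-hot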


From 2fa77e42d5cc24f8096f38e1fd5331c42ba10083 Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Thu, 20 Aug 2020 16:25:47 +0200
Subject: [PATCH 071/116] Removed unused imports

---
 tests/models/test_actions_model_mixed.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/tests/models/test_actions_model_mixed.py b/tests/models/test_actions_model_mixed.py
index 3e81134..591c738 100644
--- a/tests/models/test_actions_model_mixed.py
+++ b/tests/models/test_actions_model_mixed.py
@@ -1,7 +1,6 @@
 import torch
 from transformers.tokenization_bert import BertTokenizerFast
 
-from src.models.actions_model_base import ActionsModelBase, ActionsModelBaseLoss
 from src.models.actions_model_mixed import ActionsModelMixed, ActionsModelMixedLoss
 
 
-- 
GitLab


From 91f70867c5479e53ffc7edfa0a982840659f1f4c Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Thu, 20 Aug 2020 16:39:53 +0200
Subject: [PATCH 072/116] Might fix python crash on CI

---
 src/models/actions_model_mixed.py        | 2 --
 tests/models/test_actions_model_mixed.py | 6 +++---
 2 files changed, 3 insertions(+), 5 deletions(-)

diff --git a/src/models/actions_model_mixed.py b/src/models/actions_model_mixed.py
index 8eee0e7..00da61d 100644
--- a/src/models/actions_model_mixed.py
+++ b/src/models/actions_model_mixed.py
@@ -1,7 +1,6 @@
 import numpy as np
 import torch
 import torch.nn as nn
-from torch import device
 from torch.nn.modules.loss import BCEWithLogitsLoss
 from transformers.tokenization_bert import BertTokenizerFast
 
@@ -42,7 +41,6 @@ class ActionsModelMixed(nn.Module):
         super(ActionsModelMixed, self).__init__()
 
         self.num_labels = num_labels
-        self.device = device
 
         # Word embedder
         self.word_embedding = nn.Embedding(vocab_size, embedding_size)
diff --git a/tests/models/test_actions_model_mixed.py b/tests/models/test_actions_model_mixed.py
index 591c738..6e05bdd 100644
--- a/tests/models/test_actions_model_mixed.py
+++ b/tests/models/test_actions_model_mixed.py
@@ -11,10 +11,10 @@ def test_dimensions():
     tokenizer = BertTokenizerFast.from_pretrained(base_model)
     tokens = tokenizer("Ala ma kota", return_tensors="pt")
 
-    embedding_size = 200
-    num_heads = 4
+    embedding_size = 20
+    num_heads = 2
     num_layers = 2
-    feedforward_neurons = 100
+    feedforward_neurons = 10
     max_len = 500
     dropout = 0.1
     model = ActionsModelMixed(
-- 
GitLab
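
The CI fix above simply shrinks the test model. The memory effect is easy to
see by counting parameters of a bare nn.Transformer at the old and new test
sizes; this sketch ignores the embedding and output layers of
ActionsModelMixed, so the numbers are only indicative:

import torch.nn as nn

def n_params(d_model, nhead, layers, feedforward):
    core = nn.Transformer(d_model, nhead, layers, layers, feedforward)
    return sum(p.numel() for p in core.parameters())

print(n_params(200, 4, 2, 100))  # old test config, roughly 1.1M parameters
print(n_params(20, 2, 2, 10))    # reduced config, roughly 12K parameters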


From bf7ea33e2ec19c49f2516b3bca2f6c7dc1b4ae1d Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Thu, 20 Aug 2020 17:00:10 +0200
Subject: [PATCH 073/116] Made requirements more strict in tox environment

---
 requirements.txt |  3 ---
 tox.ini          | 12 +-----------
 2 files changed, 1 insertion(+), 14 deletions(-)

diff --git a/requirements.txt b/requirements.txt
index 4e2d19a..645f1a5 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -32,10 +32,8 @@ psutil==5.7.2
 py==1.9.0
 pyarrow==0.17.1
 pycurl==7.43.0
-pygobject==3.20.0
 pyparsing==2.4.7
 pytest==6.0.1
-python-apt==1.1.0b1+ubuntu0.16.4.9
 python-dateutil==2.8.1
 pytz==2020.1
 PyYAML==5.3.1
@@ -56,7 +54,6 @@ tornado==6.0.4
 tqdm==4.48.2
 transformers==3.0.2
 typing-extensions==3.7.4.2
-unattended-upgrades==0.1
 urllib3==1.25.10
 zict==2.0.0
 git+https://gitlab.clarin-pl.eu/nlpworkers/nlp_ws.git@fa5f09a2f1447cac2c411c9d9e3d927ecd815ddc#egg=nlp_ws
\ No newline at end of file
diff --git a/tox.ini b/tox.ini
index 4326963..cd8520e 100644
--- a/tox.ini
+++ b/tox.ini
@@ -3,17 +3,7 @@ envlist = unittest,pep8
 skipsdist = True
 
 [testenv]
-deps = 
-    pytest
-    numpy
-    pyyaml
-    pandas 
-    tqdm 
-    torch 
-    dask[complete] 
-    transformers 
-    pyarrow==0.17.1
-    lxml
+deps = -rrequirements.txt
 
 [testenv:unittest]
 commands = pytest --ignore data --ignore generated
-- 
GitLab


From 3486197f2f8895b620a6f025e24dfff7da74cbb7 Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Thu, 20 Aug 2020 17:11:13 +0200
Subject: [PATCH 074/116] Use older torch version

---
 requirements.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/requirements.txt b/requirements.txt
index 645f1a5..820114d 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -49,7 +49,7 @@ tblib==1.7.0
 tokenizers==0.8.1rc1
 toml==0.10.1
 toolz==0.10.0
-torch==1.6.0
+torch==1.5.1
 tornado==6.0.4
 tqdm==4.48.2
 transformers==3.0.2
-- 
GitLab


From 16d77f32ba11e6eec451b260633693d80c699810 Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Mon, 24 Aug 2020 11:44:44 +0200
Subject: [PATCH 075/116] Metrics calculation, small fixes to base

---
 src/models/actions_model_base.py              |  5 +-
 src/models/actions_model_mixed.py             | 52 +++++++++++--
 src/pipelines/actions_based/processing.py     |  2 -
 src/pipelines/actions_based/scoring.py        | 45 +++++++++++
 tests/pipelines/actions_based/test_scoring.py | 76 +++++++++++++++++++
 worker.py                                     |  3 +
 6 files changed, 172 insertions(+), 11 deletions(-)
 create mode 100644 src/pipelines/actions_based/scoring.py
 create mode 100644 tests/pipelines/actions_based/test_scoring.py

diff --git a/src/models/actions_model_base.py b/src/models/actions_model_base.py
index 064cc91..f79c59c 100644
--- a/src/models/actions_model_base.py
+++ b/src/models/actions_model_base.py
@@ -37,7 +37,8 @@ class ActionsModelBase(nn.Module):
         Returns:
             torch.Tensor: Predicted actions vector
         """
-        y_pred = self.core(input_ids=input_ids, attention_mask=attention_mask)[0]
+        y_pred = self.core(input_ids=input_ids,
+                           attention_mask=attention_mask)[0]
 
         return y_pred
 
@@ -54,7 +55,7 @@ class ActionsModelBaseLoss(nn.Module):
         """
         super(ActionsModelBaseLoss, self).__init__()
 
-        self.core = BCEWithLogitsLoss(prior_odds)
+        self.core = BCEWithLogitsLoss(pos_weight=prior_odds)
 
     def forward(
         self,
diff --git a/src/models/actions_model_mixed.py b/src/models/actions_model_mixed.py
index 00da61d..d2f51cc 100644
--- a/src/models/actions_model_mixed.py
+++ b/src/models/actions_model_mixed.py
@@ -9,6 +9,7 @@ from src.pipelines.actions_based.processing import (
     ACTIONS_KEYS,
     action_vector,
     recover_text,
+    token_labels_to_word_labels,
 )
 
 
@@ -41,6 +42,7 @@ class ActionsModelMixed(nn.Module):
         super(ActionsModelMixed, self).__init__()
 
         self.num_labels = num_labels
+        self.device = "cpu"
 
         # Word embedder
         self.word_embedding = nn.Embedding(vocab_size, embedding_size)
@@ -111,6 +113,35 @@ class ActionsModelMixed(nn.Module):
 
         return self.to_labels(z)
 
+    def to(self, device):
+        self.device = device
+
+        super(ActionsModelMixed, self).to(device)
+
+    def predict_raw(
+        self,
+        input_ids: torch.Tensor,
+        attention_mask: torch.Tensor,
+        threshold: float = 0.9,
+    ):
+        target_device = self.device
+
+        input_ids = input_ids.to(target_device)
+        attention_mask = (attention_mask == 0).to(target_device)
+        inputs = torch.tensor([action_vector(["upper_case"])], dtype=torch.float).to(
+            target_device
+        )
+
+        for _ in range(input_ids.shape[1] - 2):
+            prediction_raw = self.forward(
+                input_ids, inputs.reshape(1, -1, self.num_labels), attention_mask
+            ).sigmoid()
+
+            new_output = (prediction_raw[0, -1:, :] > threshold).astype(torch.float)
+            inputs = torch.stack([inputs, new_output], dim=1)
+
+        return inputs
+
     def predict(
         self, text: str, tokenizer: BertTokenizerFast, threshold: float = 0.9
     ) -> str:
@@ -118,19 +149,26 @@ class ActionsModelMixed(nn.Module):
 
         text_tokenized = tokenizer(text, return_tensors="pt")
 
-        for _ in range(text_tokenized["input_ids"].shape[1] - 1):
+        target_device = self.device
+
+        for _ in range(text_tokenized["input_ids"].shape[1] - 2):
             prediction_raw = self.forward(
-                text_tokenized["input_ids"],
-                torch.tensor(inputs, dtype=torch.float).reshape(1, -1, self.num_labels),
-                text_tokenized["attention_mask"] == 0,
+                text_tokenized["input_ids"].to(target_device),
+                torch.tensor(inputs, dtype=torch.float)
+                .reshape(1, -1, self.num_labels)
+                .to(target_device),
+                (text_tokenized["attention_mask"] == 0).to(target_device),
             ).sigmoid()
 
             inputs.append(
-                (prediction_raw.detach().numpy()[0, -1, :] > threshold).astype(np.float)
+                (prediction_raw.detach().cpu().numpy()[0, -1, :] > threshold).astype(
+                    np.float
+                )
             )
 
-        inputs = np.array(inputs)[1:]
-        prediction_binary = inputs.astype(np.int)
+        word_labels = token_labels_to_word_labels(text, inputs[1:], tokenizer)
+
+        prediction_binary = word_labels.astype(np.int)
 
         return recover_text(text, prediction_binary)
 
diff --git a/src/pipelines/actions_based/processing.py b/src/pipelines/actions_based/processing.py
index 7057c0b..f65b9b4 100644
--- a/src/pipelines/actions_based/processing.py
+++ b/src/pipelines/actions_based/processing.py
@@ -287,8 +287,6 @@ def token_word_mapping(text: str, tokenizer: PreTrainedTokenizerFast) -> np.ndar
     text_tokenized = tokenizer(text, return_offsets_mapping=True)
     offset_mappings = text_tokenized["offset_mapping"][1:-1]
 
-    offset_mappings = text_tokenized["offset_mapping"][1:-1]
-
     # Create a map where each character is assigned index of it's word
     words_mapping = []
     actual_word = 0
diff --git a/src/pipelines/actions_based/scoring.py b/src/pipelines/actions_based/scoring.py
new file mode 100644
index 0000000..9e543c1
--- /dev/null
+++ b/src/pipelines/actions_based/scoring.py
@@ -0,0 +1,45 @@
+from typing import List
+import numpy as np
+from sklearn.metrics import roc_curve, auc
+
+
+def predictions_threshold(
+    predictions: np.ndarray, threshold: float = 0.9
+) -> np.ndarray:
+    return (predictions > threshold).astype(np.float)
+
+
+def compute_accuracy(predictions: np.ndarray, targets: np.ndarray) -> np.ndarray:
+    return (
+        np.sum((predictions == targets).astype(
+            np.int), axis=0) / predictions.shape[0]
+    )
+
+
+def multiclass_roc_curve(target: np.ndarray, predictions: np.ndarray) -> List[np.ndarray]:
+    class_fprs = []
+    class_tprs = []
+    class_thresholds = []
+
+    for index in range(predictions.shape[-1]):
+        fpr, tpr, thresholds = roc_curve(
+            target[:, index], predictions[:, index])
+
+        class_fprs.append(fpr)
+        class_tprs.append(tpr)
+        class_thresholds.append(thresholds)
+
+    return class_fprs, class_tprs, class_thresholds
+
+
+def multiclass_auc(false_positive_rate: List[np.ndarray], true_positive_rate: List[np.ndarray]) -> np.ndarray:
+
+    assert len(false_positive_rate) == len(true_positive_rate)
+
+    num_classes = len(false_positive_rate)
+    auc_list = np.zeros(num_classes)
+
+    for i in range(num_classes):
+        auc_list[i] = auc(false_positive_rate[i], true_positive_rate[i])
+
+    return auc_list
diff --git a/tests/pipelines/actions_based/test_scoring.py b/tests/pipelines/actions_based/test_scoring.py
new file mode 100644
index 0000000..f394e58
--- /dev/null
+++ b/tests/pipelines/actions_based/test_scoring.py
@@ -0,0 +1,76 @@
+import numpy as np
+from numpy.testing import assert_allclose, assert_array_equal
+from sklearn.metrics import accuracy_score, auc
+
+from src.pipelines.actions_based.scoring import multiclass_auc, multiclass_roc_curve, predictions_threshold
+
+
+def test_predictions_threshold():
+    threshold = 0.5
+    predictions = np.array(
+        [[[0.3, 0.6, 0.1, 0.2, 0.9], [0.3, 0.6, 0.1, 0.2, 0.9]]])
+    expected = np.array(
+        [[[0.0, 1.0, 0.0, 0.0, 1.0], [0.0, 1.0, 0.0, 0.0, 1.0]]])
+
+    got = predictions_threshold(predictions, threshold)
+
+    assert np.all(got == expected)
+
+
+def test_compute_accuracy():
+    predictions = np.array(
+        [[0.0, 1.0, 0.0, 0.0, 1.0], [0.0, 1.0, 0.0, 0.0, 1.0]])
+    ideal = np.array([[0.0, 1.0, 0.0, 0.0, 1.0], [0.0, 1.0, 0.0, 0.0, 1.0]])
+    half = np.array([[1.0, 0.0, 1.0, 1.0, 0.0], [0.0, 1.0, 0.0, 0.0, 1.0]])
+    zero = np.array([[1.0, 0.0, 1.0, 1.0, 0.0], [1.0, 0.0, 1.0, 1.0, 0.0]])
+
+    assert_allclose(accuracy_score(predictions, ideal),
+                    [1.0, 1.0, 1.0, 1.0, 1.0])
+    assert_allclose(accuracy_score(predictions, half),
+                    [0.5, 0.5, 0.5, 0.5, 0.5])
+    assert_allclose(accuracy_score(predictions, zero),
+                    [0.0, 0.0, 0.0, 0.0, 0.0])
+
+
+def test_multiclass_roc_curve():
+    predictions = np.array(
+        [[0.3, 0.2, 0.1, 0.3, 0.1], [0.7, 0.5, 0.1, 0.2, 0.9]])
+    expected = np.array([[0.0, 1.0, 0.0, 0.0, 1.0], [1.0, 0.0, 1.0, 1.0, 0.0]])
+
+    fpr, tpr, thresholds = multiclass_roc_curve(expected, predictions)
+
+    assert len(thresholds) == expected.shape[1]
+
+    # Thresholds
+    assert_allclose(thresholds[0], [1.7, 0.7, 0.3])
+    assert_allclose(thresholds[1], [1.5, 0.5, 0.2])
+    assert_allclose(thresholds[2], [1.1, 0.1])
+    assert_allclose(thresholds[3], [1.3, 0.3, 0.2])
+    assert_allclose(thresholds[4], [1.9, 0.9, 0.1])
+
+    # False positive rate
+    assert_array_equal(fpr[0], [0.0, 0.0, 1.0])
+    assert_array_equal(fpr[1], [0.0, 1.0, 1.0])
+    assert_array_equal(fpr[2], [0.0, 1.0])
+    assert_array_equal(fpr[3], [0.0, 1.0, 1.0])
+    assert_array_equal(fpr[4], [0.0, 1.0, 1.0])
+
+    # True positive rate
+    assert_array_equal(tpr[0], [0.0, 1.0, 1.0])
+    assert_array_equal(tpr[1], [0.0, 0.0, 1.0])
+    assert_array_equal(tpr[2], [0.0, 1.0])
+    assert_array_equal(tpr[3], [0.0, 0.0, 1.0])
+    assert_array_equal(tpr[4], [0.0, 0.0, 1.0])
+
+
+def test_multiclass_auc():
+    predictions = np.array(
+        [[0.3, 0.2, 0.1, 0.3, 0.1], [0.7, 0.5, 0.1, 0.2, 0.9]])
+    expected = np.array([[0.0, 1.0, 0.0, 0.0, 1.0], [1.0, 0.0, 1.0, 1.0, 0.0]])
+
+    fpr, tpr, _ = multiclass_roc_curve(expected, predictions)
+    result = multiclass_auc(np.array(fpr), np.array(tpr))
+
+    assert len(result) == 5
+    assert np.all(result >= 0)
+    assert np.all(result <= 1)
diff --git a/worker.py b/worker.py
index cc5790a..0b4beb5 100755
--- a/worker.py
+++ b/worker.py
@@ -26,6 +26,9 @@ class Worker(nlp_ws.NLPWorker):
     def process(self, input_file: str, task_options: dict, output_file: str) -> None:
         """Implementation of example tasks that copies files."""
 
+        model_type = 'action_base'
+        model_version = '1-0-80000'
+
         with open(input_file, "r") as f:
             text = input_preprocess(f.read())
             text_processed = apply_actions_punctuation(
-- 
GitLab
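
The new scoring module computes one ROC curve per action channel and aggregates
each with AUC. An equivalent standalone computation using sklearn directly,
mirroring multiclass_roc_curve and multiclass_auc on a toy 3-sample, 2-channel
input:

import numpy as np
from sklearn.metrics import auc, roc_curve

predictions = np.array([[0.3, 0.9], [0.7, 0.1], [0.8, 0.4]])
targets = np.array([[0.0, 1.0], [1.0, 0.0], [1.0, 1.0]])

aucs = []
for i in range(predictions.shape[-1]):
    fpr, tpr, _ = roc_curve(targets[:, i], predictions[:, i])
    aucs.append(auc(fpr, tpr))

print(aucs)  # one AUC per action channel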


From d6c8f6f38bb3311d413f5459ea935b3e5a0e7288 Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Mon, 24 Aug 2020 11:46:54 +0200
Subject: [PATCH 076/116] Training stability fixes

---
 src/pipelines/train.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/src/pipelines/train.py b/src/pipelines/train.py
index 087c506..4d97671 100644
--- a/src/pipelines/train.py
+++ b/src/pipelines/train.py
@@ -140,6 +140,9 @@ class TrainerBase(ABC):
             for data_batch in get_batches(
                 df, self.batch_size, self.batch_buffer_size, random_index_shuffle, i
             ):
+                if len(data_batch) == 0:
+                    continue
+
                 loss = self.calc_loss(data_batch)
 
                 losses.append(loss.item())
@@ -171,5 +174,7 @@ class TrainerBase(ABC):
 
                 i += 1
 
+            sample_start = 0
+
         if not training_stopped:
             save_training_step(self.output_path, "final", self.model, self.optimizer)
-- 
GitLab
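
Two small stability measures land here: empty batches are skipped before the
loss is computed, and sample_start is reset after each epoch so later epochs
iterate the full dataset again. A compressed sketch of the resulting loop
shape, with generic names rather than the repository's TrainerBase:

import numpy as np

def train(epochs, batches_for_epoch, calc_loss, loss_span=100):
    losses = []
    sample_start = 0
    for epoch in range(epochs):
        for i, batch in enumerate(batches_for_epoch(epoch, sample_start)):
            if len(batch) == 0:
                continue  # guard: an empty batch would crash the loss

            losses.append(calc_loss(batch))
            if len(losses) > loss_span:
                losses = losses[-loss_span:]  # rolling average window

            print(f"epoch: {epoch} | step: {i} | loss: {np.mean(losses)}")

        sample_start = 0  # resume from the beginning on the next epoch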


From e4716d12400fa2864ddd115a7863f36de5a8af77 Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Mon, 24 Aug 2020 12:06:57 +0200
Subject: [PATCH 077/116] Fixed to(device) resulting in None

---
 src/pipelines/actions_based/train_mixed.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/src/pipelines/actions_based/train_mixed.py b/src/pipelines/actions_based/train_mixed.py
index be9b636..1266133 100755
--- a/src/pipelines/actions_based/train_mixed.py
+++ b/src/pipelines/actions_based/train_mixed.py
@@ -52,7 +52,8 @@ class TrainerActions(TrainerBase):
 
         np.random.seed(seed=seed)
 
-        device = torch.device(device_name if torch.cuda.is_available() else "cpu")
+        device = torch.device(
+            device_name if torch.cuda.is_available() else "cpu")
 
         tokenizer = BertTokenizerFast.from_pretrained(base_model)
         model = ActionsModelMixed(
@@ -64,7 +65,8 @@ class TrainerActions(TrainerBase):
             len(ACTIONS_KEYS),
             500,
             dropout,
-        ).to(device)
+        )
+        model.to(device)
         self.criterion = ActionsModelMixedLoss(pos_weight).to(device)
         optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
 
@@ -100,7 +102,8 @@ class TrainerActions(TrainerBase):
             .squeeze(dim=2)
         )
 
-        outputs = torch.tensor(np.stack(outputs), dtype=torch.float).to(self.device)
+        outputs = torch.tensor(
+            np.stack(outputs), dtype=torch.float).to(self.device)
 
         # Convert to boolean
         attentions_mask = torch.tensor(np.stack(attentions_mask))
-- 
GitLab
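
The root cause: nn.Module.to returns self so calls can be chained, but the
custom override added to ActionsModelMixed in patch 075 stores the device and
implicitly returns None, so model = ActionsModelMixed(...).to(device) bound
None. This patch works around it at the call site; the override itself could
equally be fixed by preserving the return value, as in this sketch:

import torch.nn as nn

class Example(nn.Module):
    def __init__(self):
        super().__init__()
        self.device = "cpu"
        self.linear = nn.Linear(4, 4)

    def to(self, device):
        # Remember the target device, then delegate and keep
        # nn.Module's contract of returning self.
        self.device = device
        return super().to(device)

model = Example().to("cpu")
print(model is not None)  # True; chaining works again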


From 9870ebcce3bbc0279d4229a4b4da826e69a17590 Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Mon, 24 Aug 2020 13:45:03 +0200
Subject: [PATCH 078/116] Base training fix

---
 .dockerignore                                 |  3 +-
 requirements.txt                              |  1 +
 src/models/actions_model_base.py              |  5 ++-
 src/models/actions_model_mixed.py             | 31 ++++++++++++++---
 src/pipelines/actions_based/scoring.py        | 17 ++++++----
 .../actions_based/stage1_extraction.py        |  5 ++-
 .../actions_based/stage2_tokenization.py      |  5 ++-
 .../actions_based/stage3_exploding.py         |  7 +++-
 src/pipelines/actions_based/train_mixed.py    |  6 ++--
 .../translation_based/stage1_extraction.py    |  5 ++-
 .../stage2_create_batches.py                  |  5 ++-
 tests/models/test_actions_model_base.py       |  4 +--
 tests/models/test_actions_model_restricted.py |  5 ++-
 tests/pipelines/actions_based/test_scoring.py | 34 ++++++++-----------
 worker.py                                     |  3 --
 15 files changed, 86 insertions(+), 50 deletions(-)

diff --git a/.dockerignore b/.dockerignore
index 19d3acf..06d8c56 100644
--- a/.dockerignore
+++ b/.dockerignore
@@ -12,4 +12,5 @@ dask-worker-space
 data
 generated
 notebooks
-tests
\ No newline at end of file
+tests
+deploy
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index 820114d..700a3a8 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -56,4 +56,5 @@ transformers==3.0.2
 typing-extensions==3.7.4.2
 urllib3==1.25.10
 zict==2.0.0
+scikit-learn==0.23.2
 git+https://gitlab.clarin-pl.eu/nlpworkers/nlp_ws.git@fa5f09a2f1447cac2c411c9d9e3d927ecd815ddc#egg=nlp_ws
\ No newline at end of file
diff --git a/src/models/actions_model_base.py b/src/models/actions_model_base.py
index f79c59c..81cabe1 100644
--- a/src/models/actions_model_base.py
+++ b/src/models/actions_model_base.py
@@ -37,8 +37,7 @@ class ActionsModelBase(nn.Module):
         Returns:
             torch.Tensor: Predicted actions vector
         """
-        y_pred = self.core(input_ids=input_ids,
-                           attention_mask=attention_mask)[0]
+        y_pred = self.core(input_ids=input_ids, attention_mask=attention_mask)[0]
 
         return y_pred
 
@@ -59,8 +58,8 @@ class ActionsModelBaseLoss(nn.Module):
 
     def forward(
         self,
-        true_action_vector: torch.Tensor,
         predicted_action_vector_logits: torch.Tensor,
+        true_action_vector: torch.Tensor,
     ) -> torch.Tensor:
         """Computes ActionsModelBase loss
 
diff --git a/src/models/actions_model_mixed.py b/src/models/actions_model_mixed.py
index d2f51cc..8ac456e 100644
--- a/src/models/actions_model_mixed.py
+++ b/src/models/actions_model_mixed.py
@@ -1,3 +1,5 @@
+from typing import Optional
+
 import numpy as np
 import torch
 import torch.nn as nn
@@ -123,6 +125,7 @@ class ActionsModelMixed(nn.Module):
         input_ids: torch.Tensor,
         attention_mask: torch.Tensor,
         threshold: float = 0.9,
+        max_cond_len: Optional[int] = None,
     ):
         target_device = self.device
 
@@ -132,9 +135,16 @@ class ActionsModelMixed(nn.Module):
             target_device
         )
 
+        if max_cond_len is None:
+            max_cond_len = np.iinfo(np.int).max
+
         for _ in range(input_ids.shape[1] - 2):
+            input_start = max(0, len(inputs) - max_cond_len)
+
             prediction_raw = self.forward(
-                input_ids, inputs.reshape(1, -1, self.num_labels), attention_mask
+                input_ids[:, input_start:],
+                inputs[:, input_start:].reshape(1, -1, self.num_labels),
+                attention_mask,
             ).sigmoid()
 
             new_output = (prediction_raw[0, -1:, :] > threshold).astype(torch.float)
@@ -143,7 +153,11 @@ class ActionsModelMixed(nn.Module):
         return inputs
 
     def predict(
-        self, text: str, tokenizer: BertTokenizerFast, threshold: float = 0.9
+        self,
+        text: str,
+        tokenizer: BertTokenizerFast,
+        threshold: float = 0.9,
+        max_cond_len: Optional[int] = None,
     ) -> str:
         inputs = [action_vector(["upper_case"])]
 
@@ -151,13 +165,20 @@ class ActionsModelMixed(nn.Module):
 
         target_device = self.device
 
+        if max_cond_len is None:
+            max_cond_len = np.iinfo(np.int).max
+
         for _ in range(text_tokenized["input_ids"].shape[1] - 2):
+            input_start = max(0, len(inputs) - max_cond_len)
+
             prediction_raw = self.forward(
-                text_tokenized["input_ids"].to(target_device),
-                torch.tensor(inputs, dtype=torch.float)
+                text_tokenized["input_ids"][:, input_start:].to(target_device),
+                torch.tensor(inputs[input_start:], dtype=torch.float)
                 .reshape(1, -1, self.num_labels)
                 .to(target_device),
-                (text_tokenized["attention_mask"] == 0).to(target_device),
+                (text_tokenized["attention_mask"][:, input_start:] == 0).to(
+                    target_device
+                ),
             ).sigmoid()
 
             inputs.append(
diff --git a/src/pipelines/actions_based/scoring.py b/src/pipelines/actions_based/scoring.py
index 9e543c1..254d557 100644
--- a/src/pipelines/actions_based/scoring.py
+++ b/src/pipelines/actions_based/scoring.py
@@ -1,6 +1,7 @@
 from typing import List
+
 import numpy as np
-from sklearn.metrics import roc_curve, auc
+from sklearn.metrics import auc, roc_curve
 
 
 def predictions_threshold(
@@ -11,19 +12,19 @@ def predictions_threshold(
 
 def compute_accuracy(predictions: np.ndarray, targets: np.ndarray) -> np.ndarray:
     return (
-        np.sum((predictions == targets).astype(
-            np.int), axis=0) / predictions.shape[0]
+        np.sum((predictions == targets).astype(np.int), axis=0) / predictions.shape[0]
     )
 
 
-def multiclass_roc_curve(target: np.ndarray, predictions: np.ndarray) -> List[np.ndarray]:
+def multiclass_roc_curve(
+    target: np.ndarray, predictions: np.ndarray
+) -> List[np.ndarray]:
     class_fprs = []
     class_tprs = []
     class_thresholds = []
 
     for index in range(predictions.shape[-1]):
-        fpr, tpr, thresholds = roc_curve(
-            target[:, index], predictions[:, index])
+        fpr, tpr, thresholds = roc_curve(target[:, index], predictions[:, index])
 
         class_fprs.append(fpr)
         class_tprs.append(tpr)
@@ -32,7 +33,9 @@ def multiclass_roc_curve(target: np.ndarray, predictions: np.ndarray) -> List[np
     return class_fprs, class_tprs, class_thresholds
 
 
-def multiclass_auc(false_positive_rate: List[np.ndarray], true_positive_rate: List[np.ndarray]) -> np.ndarray:
+def multiclass_auc(
+    false_positive_rate: List[np.ndarray], true_positive_rate: List[np.ndarray]
+) -> np.ndarray:
 
     assert len(false_positive_rate) == len(true_positive_rate)
 
diff --git a/src/pipelines/actions_based/stage1_extraction.py b/src/pipelines/actions_based/stage1_extraction.py
index 5a058a9..94dc26c 100644
--- a/src/pipelines/actions_based/stage1_extraction.py
+++ b/src/pipelines/actions_based/stage1_extraction.py
@@ -6,7 +6,10 @@ import numpy as np
 import pandas as pd
 from dask.distributed import Client
 
-from src.pipelines.actions_based.processing import APPLY_FILE_PROCESSING_META, apply_file_processing
+from src.pipelines.actions_based.processing import (
+    APPLY_FILE_PROCESSING_META,
+    apply_file_processing,
+)
 from src.utils import PROJECT_ROOT, get_config, prepare_folder
 
 INPUT_FOLDER = f"{PROJECT_ROOT}/data"
diff --git a/src/pipelines/actions_based/stage2_tokenization.py b/src/pipelines/actions_based/stage2_tokenization.py
index 0ea3586..b30445f 100644
--- a/src/pipelines/actions_based/stage2_tokenization.py
+++ b/src/pipelines/actions_based/stage2_tokenization.py
@@ -4,7 +4,10 @@ import dask.dataframe as dd
 from dask.distributed import Client
 from transformers import BertTokenizerFast
 
-from src.pipelines.actions_based.processing import APPLY_TOKENIZATION_META, apply_tokenization
+from src.pipelines.actions_based.processing import (
+    APPLY_TOKENIZATION_META,
+    apply_tokenization,
+)
 from src.utils import PROJECT_ROOT, get_config, prepare_folder
 
 INPUT_FOLDER = f"{PROJECT_ROOT}/generated/actions/stage1_extraction"
diff --git a/src/pipelines/actions_based/stage3_exploding.py b/src/pipelines/actions_based/stage3_exploding.py
index 81dc965..72ec128 100644
--- a/src/pipelines/actions_based/stage3_exploding.py
+++ b/src/pipelines/actions_based/stage3_exploding.py
@@ -2,7 +2,12 @@
 import dask.dataframe as dd
 from dask.distributed import Client
 
-from src.processing import EXPAND_DIMS_META, FLATTEN_DIMS_META, expand_dims, flatten_dims
+from src.processing import (
+    EXPAND_DIMS_META,
+    FLATTEN_DIMS_META,
+    expand_dims,
+    flatten_dims,
+)
 from src.utils import PROJECT_ROOT, get_config, prepare_folder
 
 INPUT_FOLDER = f"{PROJECT_ROOT}/generated/actions/stage2_tokenization"
diff --git a/src/pipelines/actions_based/train_mixed.py b/src/pipelines/actions_based/train_mixed.py
index 1266133..666fa51 100755
--- a/src/pipelines/actions_based/train_mixed.py
+++ b/src/pipelines/actions_based/train_mixed.py
@@ -52,8 +52,7 @@ class TrainerActions(TrainerBase):
 
         np.random.seed(seed=seed)
 
-        device = torch.device(
-            device_name if torch.cuda.is_available() else "cpu")
+        device = torch.device(device_name if torch.cuda.is_available() else "cpu")
 
         tokenizer = BertTokenizerFast.from_pretrained(base_model)
         model = ActionsModelMixed(
@@ -102,8 +101,7 @@ class TrainerActions(TrainerBase):
             .squeeze(dim=2)
         )
 
-        outputs = torch.tensor(
-            np.stack(outputs), dtype=torch.float).to(self.device)
+        outputs = torch.tensor(np.stack(outputs), dtype=torch.float).to(self.device)
 
         # Convert to boolean
         attentions_mask = torch.tensor(np.stack(attentions_mask))
diff --git a/src/pipelines/translation_based/stage1_extraction.py b/src/pipelines/translation_based/stage1_extraction.py
index 386211d..6ffdbf7 100644
--- a/src/pipelines/translation_based/stage1_extraction.py
+++ b/src/pipelines/translation_based/stage1_extraction.py
@@ -6,7 +6,10 @@ import numpy as np
 import pandas as pd
 from dask.distributed import Client
 
-from src.pipelines.translation_based.processing import RAW_TO_DATAFRAME_META, raw_to_dataframe
+from src.pipelines.translation_based.processing import (
+    RAW_TO_DATAFRAME_META,
+    raw_to_dataframe,
+)
 from src.utils import PROJECT_ROOT, get_config, prepare_folder
 
 INPUT_FOLDER = f"{PROJECT_ROOT}/data"
diff --git a/src/pipelines/translation_based/stage2_create_batches.py b/src/pipelines/translation_based/stage2_create_batches.py
index ade8bf2..83a2edc 100644
--- a/src/pipelines/translation_based/stage2_create_batches.py
+++ b/src/pipelines/translation_based/stage2_create_batches.py
@@ -4,7 +4,10 @@ from dask import delayed
 from dask.distributed import Client
 from transformers import BertTokenizerFast
 
-from src.pipelines.translation_based.processing import GENERATE_BATCHES_META, generate_batches
+from src.pipelines.translation_based.processing import (
+    GENERATE_BATCHES_META,
+    generate_batches,
+)
 from src.utils import PROJECT_ROOT, get_config, prepare_folder
 
 INPUT_FOLDER = f"{PROJECT_ROOT}/generated/translations/stage1_extraction"
diff --git a/tests/models/test_actions_model_base.py b/tests/models/test_actions_model_base.py
index ecca181..5fd17b3 100644
--- a/tests/models/test_actions_model_base.py
+++ b/tests/models/test_actions_model_base.py
@@ -31,10 +31,10 @@ def test_loss_dimensions():
     actions_vector_bad = torch.ones((batch_size, sequence_len, actions_size))
     loss = ActionsModelBaseLoss(weights)
 
-    result = loss(actions_vector_true, actions_vector_bad)
+    result = loss(actions_vector_bad, actions_vector_true)
     assert len(result.shape) == 0
 
     result_perfect = loss(actions_vector_true, actions_vector_true)
-    result_bad = loss(actions_vector_true, actions_vector_bad)
+    result_bad = loss(actions_vector_bad, actions_vector_true)
 
     assert result_perfect < result_bad
diff --git a/tests/models/test_actions_model_restricted.py b/tests/models/test_actions_model_restricted.py
index eeae5ec..7446675 100644
--- a/tests/models/test_actions_model_restricted.py
+++ b/tests/models/test_actions_model_restricted.py
@@ -1,7 +1,10 @@
 import torch
 from transformers.tokenization_bert import BertTokenizerFast
 
-from src.models.actions_model_restricted import ActionsModelRestricted, ActionsModelRestrictedLoss
+from src.models.actions_model_restricted import (
+    ActionsModelRestricted,
+    ActionsModelRestrictedLoss,
+)
 
 
 def test_dimensions():
diff --git a/tests/pipelines/actions_based/test_scoring.py b/tests/pipelines/actions_based/test_scoring.py
index f394e58..8d48f5c 100644
--- a/tests/pipelines/actions_based/test_scoring.py
+++ b/tests/pipelines/actions_based/test_scoring.py
@@ -1,16 +1,18 @@
 import numpy as np
 from numpy.testing import assert_allclose, assert_array_equal
-from sklearn.metrics import accuracy_score, auc
+from sklearn.metrics import accuracy_score
 
-from src.pipelines.actions_based.scoring import multiclass_auc, multiclass_roc_curve, predictions_threshold
+from src.pipelines.actions_based.scoring import (
+    multiclass_auc,
+    multiclass_roc_curve,
+    predictions_threshold,
+)
 
 
 def test_predictions_threshold():
     threshold = 0.5
-    predictions = np.array(
-        [[[0.3, 0.6, 0.1, 0.2, 0.9], [0.3, 0.6, 0.1, 0.2, 0.9]]])
-    expected = np.array(
-        [[[0.0, 1.0, 0.0, 0.0, 1.0], [0.0, 1.0, 0.0, 0.0, 1.0]]])
+    predictions = np.array([[[0.3, 0.6, 0.1, 0.2, 0.9], [0.3, 0.6, 0.1, 0.2, 0.9]]])
+    expected = np.array([[[0.0, 1.0, 0.0, 0.0, 1.0], [0.0, 1.0, 0.0, 0.0, 1.0]]])
 
     got = predictions_threshold(predictions, threshold)
 
@@ -18,23 +20,18 @@ def test_predictions_threshold():
 
 
 def test_compute_accuracy():
-    predictions = np.array(
-        [[0.0, 1.0, 0.0, 0.0, 1.0], [0.0, 1.0, 0.0, 0.0, 1.0]])
+    predictions = np.array([[0.0, 1.0, 0.0, 0.0, 1.0], [0.0, 1.0, 0.0, 0.0, 1.0]])
     ideal = np.array([[0.0, 1.0, 0.0, 0.0, 1.0], [0.0, 1.0, 0.0, 0.0, 1.0]])
     half = np.array([[1.0, 0.0, 1.0, 1.0, 0.0], [0.0, 1.0, 0.0, 0.0, 1.0]])
     zero = np.array([[1.0, 0.0, 1.0, 1.0, 0.0], [1.0, 0.0, 1.0, 1.0, 0.0]])
 
-    assert_allclose(accuracy_score(predictions, ideal),
-                    [1.0, 1.0, 1.0, 1.0, 1.0])
-    assert_allclose(accuracy_score(predictions, half),
-                    [0.5, 0.5, 0.5, 0.5, 0.5])
-    assert_allclose(accuracy_score(predictions, zero),
-                    [0.0, 0.0, 0.0, 0.0, 0.0])
+    assert_allclose(accuracy_score(predictions, ideal), [1.0, 1.0, 1.0, 1.0, 1.0])
+    assert_allclose(accuracy_score(predictions, half), [0.5, 0.5, 0.5, 0.5, 0.5])
+    assert_allclose(accuracy_score(predictions, zero), [0.0, 0.0, 0.0, 0.0, 0.0])
 
 
 def test_multiclass_roc_curve():
-    predictions = np.array(
-        [[0.3, 0.2, 0.1, 0.3, 0.1], [0.7, 0.5, 0.1, 0.2, 0.9]])
+    predictions = np.array([[0.3, 0.2, 0.1, 0.3, 0.1], [0.7, 0.5, 0.1, 0.2, 0.9]])
     expected = np.array([[0.0, 1.0, 0.0, 0.0, 1.0], [1.0, 0.0, 1.0, 1.0, 0.0]])
 
     fpr, tpr, thresholds = multiclass_roc_curve(expected, predictions)
@@ -64,12 +61,11 @@ def test_multiclass_roc_curve():
 
 
 def test_multiclass_auc():
-    predictions = np.array(
-        [[0.3, 0.2, 0.1, 0.3, 0.1], [0.7, 0.5, 0.1, 0.2, 0.9]])
+    predictions = np.array([[0.3, 0.2, 0.1, 0.3, 0.1], [0.7, 0.5, 0.1, 0.2, 0.9]])
     expected = np.array([[0.0, 1.0, 0.0, 0.0, 1.0], [1.0, 0.0, 1.0, 1.0, 0.0]])
 
     fpr, tpr, _ = multiclass_roc_curve(expected, predictions)
-    result = multiclass_auc(np.array(fpr), np.array(tpr))
+    result = multiclass_auc(fpr, tpr)
 
     assert len(result) == 5
     assert np.all(result >= 0)
diff --git a/worker.py b/worker.py
index 0b4beb5..cc5790a 100755
--- a/worker.py
+++ b/worker.py
@@ -26,9 +26,6 @@ class Worker(nlp_ws.NLPWorker):
     def process(self, input_file: str, task_options: dict, output_file: str) -> None:
         """Implementation of example tasks that copies files."""
 
-        model_type = 'action_base'
-        model_version = '1-0-80000'
-
         with open(input_file, "r") as f:
             text = input_preprocess(f.read())
             text_processed = apply_actions_punctuation(
-- 
GitLab
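
Among the fixes above, ActionsModelBaseLoss.forward and its tests now pass the
prediction first and the target second, matching the (input, target) convention
of the torch loss modules it wraps. A short reminder of that convention:

import torch
import torch.nn as nn

criterion = nn.BCEWithLogitsLoss()
logits = torch.randn(2, 3)  # model output (the "input" to the loss)
target = torch.ones(2, 3)   # ground truth

print(criterion(logits, target))  # torch losses take (input, target)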


From d40c920e81a97463c5eb9fafd793fb4ebefaf5b3 Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Tue, 25 Aug 2020 13:17:10 +0200
Subject: [PATCH 079/116] Added utils

---
 src/utils.py                                  | 22 +++++++++++++++++++
 tests/pipelines/actions_based/test_scoring.py | 12 ----------
 2 files changed, 22 insertions(+), 12 deletions(-)

diff --git a/src/utils.py b/src/utils.py
index 06bbaf4..d8acdff 100644
--- a/src/utils.py
+++ b/src/utils.py
@@ -3,6 +3,8 @@ import re
 import shutil
 from datetime import timedelta
 from typing import List, Optional
+import pandas as pd
+import numpy as np
 
 import yaml
 
@@ -104,6 +106,26 @@ def prepare_folder(path: str, wipe: bool = False) -> None:
     os.makedirs(path, exist_ok=True)
 
 
+def unflattened_column(df: pd.DataFrame, name: str) -> np.ndarray:
+    """Get column from the dataframe that was flattened. Dataframe must have columns
+    "name" and "name_shape", where name is 1D numpy array and name_shape is target 
+    shape of this numpy array.
+
+    Args:
+        df (pd.DataFrame): Dataframe from which to extract array
+        name (str): Name of the column
+
+    Returns:
+        np.ndarray: Unflattened multidimensional column of shape Lx*(name_shape)
+    """
+
+    values = df.apply(
+        lambda x: x[name].reshape(x[f"{name}_shape"]), axis=1
+    ).values
+
+    return np.stack(values)
+
+
 def convert_to_timedelta(time_val: str) -> Optional[timedelta]:
     """
     src: https://code.activestate.com/recipes/577894-convert-strings-like-5d-and-60s-to-timedelta-objec/
diff --git a/tests/pipelines/actions_based/test_scoring.py b/tests/pipelines/actions_based/test_scoring.py
index 8d48f5c..62c524c 100644
--- a/tests/pipelines/actions_based/test_scoring.py
+++ b/tests/pipelines/actions_based/test_scoring.py
@@ -1,6 +1,5 @@
 import numpy as np
 from numpy.testing import assert_allclose, assert_array_equal
-from sklearn.metrics import accuracy_score
 
 from src.pipelines.actions_based.scoring import (
     multiclass_auc,
@@ -19,17 +18,6 @@ def test_predictions_threshold():
     assert np.all(got == expected)
 
 
-def test_compute_accuracy():
-    predictions = np.array([[0.0, 1.0, 0.0, 0.0, 1.0], [0.0, 1.0, 0.0, 0.0, 1.0]])
-    ideal = np.array([[0.0, 1.0, 0.0, 0.0, 1.0], [0.0, 1.0, 0.0, 0.0, 1.0]])
-    half = np.array([[1.0, 0.0, 1.0, 1.0, 0.0], [0.0, 1.0, 0.0, 0.0, 1.0]])
-    zero = np.array([[1.0, 0.0, 1.0, 1.0, 0.0], [1.0, 0.0, 1.0, 1.0, 0.0]])
-
-    assert_allclose(accuracy_score(predictions, ideal), [1.0, 1.0, 1.0, 1.0, 1.0])
-    assert_allclose(accuracy_score(predictions, half), [0.5, 0.5, 0.5, 0.5, 0.5])
-    assert_allclose(accuracy_score(predictions, zero), [0.0, 0.0, 0.0, 0.0, 0.0])
-
-
 def test_multiclass_roc_curve():
     predictions = np.array([[0.3, 0.2, 0.1, 0.3, 0.1], [0.7, 0.5, 0.1, 0.2, 0.9]])
     expected = np.array([[0.0, 1.0, 0.0, 0.0, 1.0], [1.0, 0.0, 1.0, 1.0, 0.0]])
-- 
GitLab
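
A usage sketch for the new helper: the dataframes store each row's array
flattened next to its shape, and unflattened_column restores the stacked
multidimensional view. The helper is inlined here so the example runs without
the repository; the column names are illustrative:

import numpy as np
import pandas as pd

def unflattened_column(df, name):
    values = df.apply(lambda x: x[name].reshape(x[f"{name}_shape"]), axis=1).values
    return np.stack(values)

df = pd.DataFrame(
    {
        "input_ids": [np.arange(12), np.arange(12) + 100],
        "input_ids_shape": [(3, 4), (3, 4)],
    }
)

print(unflattened_column(df, "input_ids").shape)  # (2, 3, 4)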


From 20b29ee650c03adafefb7352312ab3f79731fad2 Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Wed, 26 Aug 2020 03:28:57 +0200
Subject: [PATCH 080/116] Work on fully-contained checkpoint files

---
 src/checkpoints.py                         | 116 +++++++++++++++
 src/models/actions_model_mixed.py          |  90 ++++++++----
 src/models/interfaces.py                   |  17 +++
 src/pipelines/actions_based/train_mixed.py | 161 +++++++++++----------
 src/training.py                            |   4 +-
 src/utils.py                               | 109 +++++++++++++-
 6 files changed, 386 insertions(+), 111 deletions(-)
 create mode 100644 src/checkpoints.py
 create mode 100644 src/models/interfaces.py

diff --git a/src/checkpoints.py b/src/checkpoints.py
new file mode 100644
index 0000000..4d97dac
--- /dev/null
+++ b/src/checkpoints.py
@@ -0,0 +1,116 @@
+from __future__ import annotations
+from datetime import date, datetime, timedelta
+from glob import glob
+from src.utils import moving_average, prepare_folder
+from src.training import latest_model
+from typing import Optional, Tuple, Type
+import torch
+import torch.nn as nn
+from torch.optim.optimizer import Optimizer
+from src.models.interfaces import PunctuationModel
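+
+# Typical usage (a sketch; SomePunctuationModel and save_dir are placeholders):
+#
+#   loader = Loader(save_dir, SomePunctuationModel, torch.optim.AdamW, device)
+#   if loader.has_checkpoints():
+#       model, optimizer, epoch, step = loader.load_latest()
+#   saver = Saver(save_dir, model, optimizer)
+#   checkpoint = Checkpoint(save_step, saver, step, epoch)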
+
+
+class Saver():
+    def __init__(self, save_dir: str, model: PunctuationModel, optimizer: Optimizer) -> None:
+        self.save_dir = save_dir
+        self.model = model
+        self.optimizer = optimizer
+
+        prepare_folder(self.save_dir)
+    
+    def save(self, name):
+        self.model.save(self.save_dir, name)
+        torch.save(self.optimizer.state_dict(), f"{self.save_dir}/{name}.optimizer")
+
+
+class Loader:
+    def __init__(self, save_dir: str, model_type: Type[PunctuationModel], optimizer_type: Type[Optimizer], device: str) -> None:
+        self.save_dir = save_dir
+        self.device = device
+
+        self.model_type = model_type
+        self.optimizer_type = optimizer_type
+
+    def has_checkpoints(self) -> bool:
+        files = glob(f"{self.save_dir}/*.model")
+
+        return (latest_model(files) is not None)
+
+
+    def load(self, name) -> Tuple[PunctuationModel, Optimizer]:
+        model = self.model_type.load(self.save_dir, name, self.device)
+
+        optimizer = self.optimizer_type(model.parameters())
+        optimizer.load_state_dict(
+                torch.load(
+                    f"{self.save_dir}/{name}.optimizer",
+                    map_location=self.device
+                )
+            )
+
+        return model, optimizer
+
+    def load_latest(self) -> Tuple[PunctuationModel, Optimizer, int, int]:
+        files = glob(f"{self.save_dir}/*.model")
+
+        model_id = latest_model(files)    
+        if model_id is None:
+            return None
+
+        epoch, step = model_id
+        model, optimizer = self.load(f"{epoch}-{step}")
+        return model, optimizer, epoch, step
+
+
+class Checkpoint:
+    def __init__(self, save_step, saver: Saver, start_step: int, start_epoch: int) -> None:
+        self.start_step = start_step
+        self.start_epoch = start_epoch
+        self.save_step = save_step
+
+        self.saver = saver
+
+    def step(self, epoch, step) -> None:
+        if step % self.save_step == 0 and (
+            step != self.start_step or epoch != self.start_epoch
+            ):
+            print(f"Saving: Epoch {epoch}, step {step}")
+            self.saver.save(f"{epoch}-{step}")
+
+
+class Timeout:
+    def __init__(self, duration: timedelta, saver: Optional[Saver]) -> None:
+        self.saver = saver
+        self.duration = duration
+        self.time_max = None
+
+    def start(self, time_now: Optional[datetime] = None):
+        self.time_max = datetime.max
+        if self.duration is not None:
+            self.time_max = (time_now or datetime.now()) + self.duration
+
+    def step(self, epoch, step, time = None) -> bool:
+        assert self.time_max is not None
+
+        if time is None:
+            time = datetime.now()
+
+        if time > self.time_max:
+            if self.saver is not None:
+                print(f"Max time reached, saving: Epoch {epoch}, step {step}")
+                self.saver.save(f"{epoch}-{step}")
+
+            return True
+
+        return False
+
+class ProgressTracker:
+    def __init__(self, device: str, loss_averaging_span) -> None:
+        print(f"Training on {device}")
+        self.loss_averaging_span = loss_averaging_span
+        self.losses = []
+
+    def step(self, epoch, step, loss) -> None:
+        self.losses.append(loss.item())
+        loss_mean, self.losses = moving_average(self.losses, self.loss_averaging_span)
+
+        print(f"epoch: {epoch} | step: {step} | loss: {loss_mean}")
diff --git a/src/models/actions_model_mixed.py b/src/models/actions_model_mixed.py
index 8ac456e..12ef3ed 100644
--- a/src/models/actions_model_mixed.py
+++ b/src/models/actions_model_mixed.py
@@ -1,3 +1,5 @@
+from src.utils import pickle_read, pickle_save, prepare_folder
+from src.models.interfaces import PunctuationModel
 from typing import Optional
 
 import numpy as np
@@ -15,66 +17,82 @@ from src.pipelines.actions_based.processing import (
 )
 
 
-class ActionsModelMixed(nn.Module):
+from dataclasses import dataclass
+
+@dataclass
+class ActionsModelMixedParams:
+    """
+    Parameters for initializing ActionsModelMixed
+
+    Params:
+        vocab_size (int): Number of tokens in tokenizer dictionary
+        embedding_size (int, optional): Shape of word and punctuation embeddings. Defaults to 200.
+        num_heads (int, optional): Number of heads in multiheaded attention. Defaults to 4.
+        num_layers (int, optional): Number of both decoder and encoder layers. Defaults to 2.
+        feedforward_neurons (int, optional): Size of feed-forward neural network at the end of encoder/decoder. Defaults to 200.
+        num_labels (int, optional): Action-vector size. Defaults to len(ACTIONS_KEYS).
+        max_len (int, optional): Maximum length of sequence. Defaults to 500.
+        dropout (float, optional): Dropout ratio. Defaults to 0.1.
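+
+    Example (a minimal sketch; 30522 is the BERT-base vocab size, used here
+    only for illustration):
+        >>> params = ActionsModelMixedParams(vocab_size=30522, num_layers=4)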
+    """
+    
+    vocab_size: int
+    embedding_size: int = 200
+    num_heads: int = 4
+    num_layers: int = 2
+    feedforward_neurons: int = 200
+    num_labels: int = len(ACTIONS_KEYS)
+    max_len: int = 500
+    dropout: float = 0.1
+
+
+class ActionsModelMixed(PunctuationModel):
     """Encoder-decoder based model with unpunctuated token sequence as input and array of action-vectors as output"""
 
     def __init__(
         self,
-        vocab_size: int,
-        embedding_size: int = 200,
-        num_heads: int = 4,
-        num_layers: int = 2,
-        feedforward_neurons: int = 200,
-        num_labels: int = len(ACTIONS_KEYS),
-        max_len: int = 500,
-        dropout: float = 0.1,
+        params: ActionsModelMixedParams
     ) -> None:
         """Initializes mixed model
 
         Args:
-            vocab_size (int): Number of tokens in tokenizer dictionary
-            embedding_size (int, optional): Shape of word and punctuation embeddings. Defaults to 200.
-            num_heads (int, optional): Number of heads in multiheaded attention. Defaults to 4.
-            num_layers (int, optional): Number of both decoded and encoder layers. Defaults to 2.
-            feedforward_neurons (int, optional): Size of feed-forward neural network at the end of encoder/decoder. Defaults to 200.
-            num_labels (int, optional): Action-vector size. Defaults to len(ACTIONS_KEYS).
-            max_len (int, optional): Maxium length of sequence. Defaults to 500.
-            dropout (float, optional): Dropout ratio. Defaults to 0.1.
+            params (ActionsModelMixedParams): Parameters for model
         """
         super(ActionsModelMixed, self).__init__()
 
-        self.num_labels = num_labels
+        self.params = params
+
+        self.num_labels = params.num_labels
         self.device = "cpu"
 
         # Word embedder
-        self.word_embedding = nn.Embedding(vocab_size, embedding_size)
-        self.punctuation_embedding = nn.Linear(num_labels, embedding_size)
+        self.word_embedding = nn.Embedding(params.vocab_size, params.embedding_size)
+        self.punctuation_embedding = nn.Linear(params.num_labels, params.embedding_size)
 
         # Add positional encoding
         self.words_position_embedding = PositionalEncoding(
-            embedding_size, max_len, dropout
+            params.embedding_size, params.max_len, params.dropout
         )
         self.punctuation_position_embedding = PositionalEncoding(
-            embedding_size, max_len, dropout
+            params.embedding_size, params.max_len, params.dropout
         )
 
         # Sentence encoder
         sentence_encoder_layer = nn.TransformerEncoderLayer(
-            embedding_size, num_heads, feedforward_neurons, dropout
+            params.embedding_size, params.num_heads, params.feedforward_neurons, params.dropout
         )
         self.sentence_encoder = nn.TransformerEncoder(
-            sentence_encoder_layer, num_layers=num_layers
+            sentence_encoder_layer, num_layers=params.num_layers
         )
 
         # Punctuation decoder
         punctuation_decoder_layer = nn.TransformerDecoderLayer(
-            embedding_size, num_heads, feedforward_neurons, dropout
+            params.embedding_size, params.num_heads, params.feedforward_neurons, params.dropout
         )
         self.punctuation_decoder = nn.TransformerDecoder(
-            punctuation_decoder_layer, num_layers=num_layers
+            punctuation_decoder_layer, num_layers=params.num_layers
         )
 
-        self.to_labels = nn.Linear(embedding_size, num_labels)
+        self.to_labels = nn.Linear(params.embedding_size, params.num_labels)
 
     def forward(
         self,
@@ -193,6 +211,24 @@ class ActionsModelMixed(nn.Module):
 
         return recover_text(text, prediction_binary)
 
+    def save(self, dir: str, name: str) -> None:
+        prepare_folder(dir)
+        torch.save(self.state_dict(), f"{dir}/{name}.model")
+        pickle_save(self.params, f"{dir}/{name}.config")
+
+    @staticmethod
+    def load(dir: str, name: str, device: str) -> PunctuationModel:
+        params = pickle_read(f"{dir}/{name}.config")
+        model = ActionsModelMixed(params)
+
+        model.load_state_dict(
+                torch.load(
+                    f"{dir}/{name}.model",
+                    map_location=device,
+                )
+            )
+
+        return model
 
 class ActionsModelMixedLoss(nn.Module):
     """Class representing proposed loss for training mixed actions model"""
diff --git a/src/models/interfaces.py b/src/models/interfaces.py
new file mode 100644
index 0000000..5ec408f
--- /dev/null
+++ b/src/models/interfaces.py
@@ -0,0 +1,17 @@
+from __future__ import annotations
+import torch.nn as nn
+
+from abc import ABC, abstractmethod
+
+class PunctuationModel(nn.Module, ABC):
+    def __init__(self) -> None:
+        super().__init__()
+    
+    @abstractmethod
+    def save(self, dir: str, name: str) -> None:
+        pass
+
+    @staticmethod
+    @abstractmethod
+    def load(dir: str, name: str, device: str) -> PunctuationModel:
+        pass
\ No newline at end of file
diff --git a/src/pipelines/actions_based/train_mixed.py b/src/pipelines/actions_based/train_mixed.py
index 666fa51..b33bc0f 100755
--- a/src/pipelines/actions_based/train_mixed.py
+++ b/src/pipelines/actions_based/train_mixed.py
@@ -1,61 +1,54 @@
 #!/usr/bin/python3
 
 import pickle
+from src.checkpoints import Checkpoint, Loader, ProgressTracker, Saver, Timeout
 
 import numpy as np
 import pandas as pd
 import torch
 from transformers import BertTokenizerFast
+import dask.dataframe as dd
 
-from src.models.actions_model_mixed import ActionsModelMixed, ActionsModelMixedLoss
+from src.models.actions_model_mixed import ActionsModelMixed, ActionsModelMixedLoss, ActionsModelMixedParams
 from src.pipelines.actions_based.processing import ACTIONS_KEYS
-from src.pipelines.train import TrainerBase
-from src.utils import PROJECT_ROOT, convert_to_timedelta, get_config
+from src.utils import PROJECT_ROOT, convert_to_timedelta, get_config, random_indexes, training_loop, unflattened_column
 
 INPUT_PATH = f"{PROJECT_ROOT}/generated/actions/stage4_reindexing"
 INPUT_STATS_PATH = f"{PROJECT_ROOT}/generated/actions/stage5_stats"
 OUTPUT_PATH = f"{PROJECT_ROOT}/checkpoints/actions_mixed"
 
 
-class TrainerActions(TrainerBase):
-    def __init__(self) -> None:
-
-        config = get_config()
-        embedding_size = config["actions"]["training_mixed"]["embedding_size"]
-        num_heads = config["actions"]["training_mixed"]["num_heads"]
-        num_layers = config["actions"]["training_mixed"]["num_layers"]
-        dropout = config["actions"]["training_mixed"]["dropout"]
-        feedforward_neurons = config["actions"]["training_mixed"]["feedforward_neurons"]
-        learning_rate = config["actions"]["training_mixed"]["learning_rate"]
-        num_epochs = config["actions"]["training_mixed"]["num_epochs"]
-        batch_size = config["actions"]["training_mixed"]["batch_size"]
-        save_step = config["actions"]["training_mixed"]["save_step"]
-        batch_buffer_size = config["actions"]["training_mixed"]["batch_buffer_size"]
-        loss_averaging_span = config["actions"]["training_mixed"]["loss_averaging_span"]
-        fresh_start = config["actions"]["training_mixed"]["fresh_start"]
-        device_name = config["actions"]["training_mixed"]["device"]
-        max_train_time = config["actions"]["training_mixed"]["max_training_time"]
-        base_model = config["global"]["base_model"]
-        seed = config["global"]["random_seed"]
-
-        print(f"Layers: {num_layers}")
-
-        if max_train_time is not None:
-            max_train_time = convert_to_timedelta(max_train_time)
-
-        # Load loss weights
-        with open(f"{INPUT_STATS_PATH}/stats.pickle", "rb") as f:
-            stats = pickle.load(f)
-            pos_examples = stats["class_number"]
-            neg_examples = stats["num_examples"] - stats["class_number"]
-            pos_weight = torch.tensor(neg_examples / pos_examples)
-
-        np.random.seed(seed=seed)
-
-        device = torch.device(device_name if torch.cuda.is_available() else "cpu")
-
-        tokenizer = BertTokenizerFast.from_pretrained(base_model)
-        model = ActionsModelMixed(
+if __name__ == "__main__":
+    config = get_config()
+    embedding_size = config["actions"]["training_mixed"]["embedding_size"]
+    num_heads = config["actions"]["training_mixed"]["num_heads"]
+    num_layers = config["actions"]["training_mixed"]["num_layers"]
+    dropout = config["actions"]["training_mixed"]["dropout"]
+    feedforward_neurons = config["actions"]["training_mixed"]["feedforward_neurons"]
+    learning_rate = config["actions"]["training_mixed"]["learning_rate"]
+    num_epochs = config["actions"]["training_mixed"]["num_epochs"]
+    batch_size = config["actions"]["training_mixed"]["batch_size"]
+    save_step = config["actions"]["training_mixed"]["save_step"]
+    batch_buffer_size = config["actions"]["training_mixed"]["batch_buffer_size"]
+    loss_averaging_span = config["actions"]["training_mixed"]["loss_averaging_span"]
+    fresh_start = config["actions"]["training_mixed"]["fresh_start"]
+    device_name = config["actions"]["training_mixed"]["device"]
+    max_train_time = convert_to_timedelta(config["actions"]["training_mixed"]["max_training_time"])
+    base_model = config["global"]["base_model"]
+    seed = config["global"]["random_seed"]
+
+    np.random.seed(seed=seed)
+    df = dd.read_parquet(INPUT_PATH, engine="pyarrow")
+    
+    device = torch.device(device_name if torch.cuda.is_available() else "cpu")
+    tokenizer = BertTokenizerFast.from_pretrained(base_model)
+
+    loader = Loader(OUTPUT_PATH, ActionsModelMixed, torch.optim.AdamW, device)
+
+    if loader.has_checkpoints() and not fresh_start:
+        model, optimizer, epoch_start, sample_start = loader.load_latest()
+    else:
+        params = ActionsModelMixedParams(
             tokenizer.vocab_size,
             embedding_size,
             num_heads,
@@ -63,60 +56,68 @@ class TrainerActions(TrainerBase):
             feedforward_neurons,
             len(ACTIONS_KEYS),
             500,
-            dropout,
-        )
-        model.to(device)
-        self.criterion = ActionsModelMixedLoss(pos_weight).to(device)
+            dropout)
+        model = ActionsModelMixed(params)
+
         optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
+        epoch_start, sample_start = (0, 0)
 
-        super(TrainerActions, self).__init__(
-            model,
-            device,
-            optimizer,
-            INPUT_PATH,
-            OUTPUT_PATH,
-            max_train_time,
-            fresh_start,
-            num_epochs,
-            batch_size,
-            batch_buffer_size,
-            loss_averaging_span,
-            save_step,
-        )
+    model.train()
+    model.to(device)
+
+    # Load loss weights
+    with open(f"{INPUT_STATS_PATH}/stats.pickle", "rb") as f:
+        stats = pickle.load(f)
+        pos_examples = stats["class_number"]
+        neg_examples = stats["num_examples"] - stats["class_number"]
+        pos_weight = torch.tensor(neg_examples / pos_examples)
 
-    def calc_loss(self, data_batch: pd.DataFrame,) -> torch.Tensor:
-        inputs = data_batch.apply(
-            lambda x: x["source"].reshape(x["source_shape"]), axis=1
-        ).values
-        outputs = data_batch.apply(
-            lambda x: x["target"].reshape(x["target_shape"]), axis=1
-        ).values
-        attentions_mask = data_batch.apply(
-            lambda x: x["attention_mask"].reshape(x["attention_mask_shape"]), axis=1,
-        ).values
+    criterion = ActionsModelMixedLoss(pos_weight).to(device)
+
+    random_index_shuffle = random_indexes(df)
+    training_stopped = False
+
+    saver = Saver(OUTPUT_PATH, model, optimizer)
+    checkpoint = Checkpoint(save_step, saver, sample_start, epoch_start)
+    timer = Timeout(max_train_time, saver)
+    tracker = ProgressTracker(device, loss_averaging_span)
+
+    timer.start()
+    for data_batch, epoch, i in training_loop(epoch_start, sample_start, num_epochs, df, batch_size, batch_buffer_size, random_index_shuffle):
+        inputs = unflattened_column(data_batch, 'source')
+        outputs = unflattened_column(data_batch, 'target')
+        attentions_mask = unflattened_column(data_batch, 'attention_mask')
 
         inputs = (
-            torch.tensor(np.stack(inputs), dtype=torch.long)
-            .to(self.device)
+            torch.tensor(inputs, dtype=torch.long)
+            .to(device)
             .squeeze(dim=2)
         )
 
-        outputs = torch.tensor(np.stack(outputs), dtype=torch.float).to(self.device)
+        outputs = torch.tensor(outputs, dtype=torch.float).to(device)
 
         # Convert to boolean
-        attentions_mask = torch.tensor(np.stack(attentions_mask))
-        attentions_mask = torch.tensor(np.stack(attentions_mask, axis=0) == 0).to(
-            self.device
+        attentions_mask = torch.tensor(attentions_mask == 0).to(
+            device
         )
 
-        y_pred = self.model(
+        y_pred = model(
             input_ids=inputs,
             actions=outputs[:, :-1, :],
             attention_mask=attentions_mask,
         )
 
-        return self.criterion(outputs[:, 1:, :], y_pred)
+        loss = criterion(outputs[:, 1:, :], y_pred)
+        optimizer.zero_grad()
 
+        tracker.step(epoch, i, loss)
+        checkpoint.step(epoch, i)
+        if timer.step(epoch, i):
+            training_stopped = True
+            break
 
-if __name__ == "__main__":
-    TrainerActions().train()
+        loss.backward()
+        optimizer.step()
+
+    if not training_stopped:
+        saver.save("final")
\ No newline at end of file
diff --git a/src/training.py b/src/training.py
index f9ffd92..d74d322 100644
--- a/src/training.py
+++ b/src/training.py
@@ -1,5 +1,5 @@
 import re
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple
 
 import torch
 import torch.nn as nn
@@ -8,7 +8,7 @@ import torch.optim as optim
 from src.utils import prepare_folder
 
 
-def latest_model(file_paths: [str]) -> Optional[Tuple[int, int]]:
+def latest_model(file_paths: List[str]) -> Optional[Tuple[int, int]]:
     """Finds newest model in directory
 
     Args:
diff --git a/src/utils.py b/src/utils.py
index d8acdff..4585f9b 100644
--- a/src/utils.py
+++ b/src/utils.py
@@ -1,10 +1,16 @@
 import os
+import pickle
 import re
 import shutil
 from datetime import timedelta
-from typing import List, Optional
+
+from typing import Any, Generator, List, Optional, Tuple
+
+from src.batch_loading import get_batches, get_ordered_dataframe_len
 import pandas as pd
 import numpy as np
+import dask.dataframe as dd
+import torch
 
 import yaml
 
@@ -126,7 +132,103 @@ def unflattened_column(df: pd.DataFrame, name: str) -> np.ndarray:
     return np.stack(values)
 
 
-def convert_to_timedelta(time_val: str) -> Optional[timedelta]:
+def moving_average(values: List[float], average_span: int) -> Tuple[float, List[float]]:
+    """Computes a moving average and keeps only the latest records
+
+    Args:
+        values (List[float]): List of values over which to compute the moving average
+        average_span (int): Maximum span over which to average
+
+    Returns:
+        Tuple[float, List[float]]: computed average, values list trimmed to the last "average_span" entries
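+
+    Example (illustrative):
+        >>> moving_average([1.0, 2.0, 3.0, 4.0], 2)
+        (3.5, [3.0, 4.0])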
+    """
+
+    if len(values) > average_span:
+        values = values[-average_span:]
+
+    return np.mean(values), values
+
+
+def optimizer_step(loss: torch.Tensor, optimizer: torch.optim.Optimizer) -> None:
+    """Computes and applies a single step of optimization
+
+    Args:
+        loss (torch.Tensor): Loss that is optimized
+        optimizer (torch.optim.optimizer.Optimizer): Optimizer used to optimize loss
+    """
+    optimizer.zero_grad()
+    loss.backward()
+    optimizer.step()
+
+
+def training_loop(epoch_start: int, sample_start: int, num_epochs: int, df: dd.DataFrame, batch_size: int, batch_buffer_size: int, random_index_shuffle: np.ndarray) -> Generator[Tuple[pd.DataFrame, int, int], None, None]:
+    """Generator providing all data necessary to perform a training step. This function handles epoch/step management
+
+    Args:
+        epoch_start (int): Epoch from which to start training
+        sample_start (int): Batch in epoch from which to start training
+        num_epochs (int): Number of epochs to train
+        df (dd.DataFrame): Dask dataframe with training dataset. Indices must be continuous from 0 to len
+        batch_size (int): Batch size
+        batch_buffer_size (int): Number of batches to load at once to memory
+        random_index_shuffle (np.ndarray): Shuffled indices of dataset
+
+    Yields:
+        Tuple[pd.DataFrame, int, int]: data batch, epoch number, step number
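+
+    Example (a sketch; df, buffer_size and shuffle are assumed to be prepared
+    as in the training scripts):
+        for batch, epoch, step in training_loop(
+            0, 0, num_epochs, df, batch_size, buffer_size, shuffle
+        ):
+            ...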
+    """
+    for epoch in range(epoch_start, num_epochs):
+        i = sample_start
+        for data_batch in get_batches(
+            df, batch_size, batch_buffer_size, random_index_shuffle, i
+        ):
+            if len(data_batch) == 0:
+                continue
+
+            yield data_batch, epoch, i
+
+            i += 1
+
+        sample_start = 0
+
+
+def random_indexes(df: dd.DataFrame) -> np.ndarray:
+    """Provides array of randomly shuffled indices for dataset
+
+    Args:
+        df (dd.DataFrame): Dask dataframe with training dataset. Indices must be continuous from 0 to len
+
+    Returns:
+        np.ndarray: Shuffled indices
+    """
+    num_samples = get_ordered_dataframe_len(df)
+    return np.random.permutation(range(num_samples))
+
+
+def pickle_save(obj: Any, path: str) -> None:
+    """Pickles and saves an object to a file
+
+    Args:
+        obj (Any): Object to pickle
+        path (str): Path to output file
+    """
+    with open(path, 'wb') as f:
+        pickle.dump(obj, f)
+
+
+def pickle_read(path: str) -> Any:
+    """Reads a pickled object from a file
+
+    Args:
+        path (str): Path to input file
+
+    Returns:
+        Any: Unpickled object
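+
+    Example (round-trips with pickle_save; the path is illustrative):
+        >>> pickle_save({"a": 1}, "/tmp/obj.pickle")
+        >>> pickle_read("/tmp/obj.pickle")
+        {'a': 1}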
+    """
+    with open(path, 'rb') as f:
+        return pickle.load(f)
+
+
+def convert_to_timedelta(time_val: Optional[str]) -> Optional[timedelta]:
     """
     src: https://code.activestate.com/recipes/577894-convert-strings-like-5d-and-60s-to-timedelta-objec/
     Given a *time_val* (string) such as '5d', returns a timedelta object
@@ -152,6 +254,9 @@ def convert_to_timedelta(time_val: str) -> Optional[timedelta]:
         >>> convert_to_timedelta('120s')
         datetime.timedelta(0, 120)
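+        A missing value now simply propagates (so configs may omit the limit):
+        >>> convert_to_timedelta(None) is None
+        True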
     """
+    if time_val is None:
+        return None
+
     num = int(time_val[:-1])
     if time_val.endswith("s"):
         return timedelta(seconds=num)
-- 
GitLab


From c84402eb5d5caa17cb4fef72bd441d9e28e01e73 Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Wed, 26 Aug 2020 12:17:05 +0200
Subject: [PATCH 081/116] All tests pass on new API

---
 src/checkpoints.py                            |  47 ++--
 src/models/actions_model_base.py              |  44 +++-
 src/models/actions_model_mixed.py             |  32 ++-
 src/models/actions_model_restricted.py        |  60 +++++-
 src/models/interfaces.py                      |   8 +-
 .../actions_based/stage1_extraction.py        |   5 +-
 .../actions_based/stage2_tokenization.py      |   5 +-
 .../actions_based/stage3_exploding.py         |   7 +-
 src/pipelines/actions_based/train.py          | 142 ------------
 src/pipelines/actions_based/train_base.py     | 175 ++++++++-------
 src/pipelines/actions_based/train_mixed.py    |  59 +++--
 .../actions_based/train_restricted.py         | 204 ++++++++++--------
 src/pipelines/train.py                        | 181 ----------------
 .../translation_based/stage1_extraction.py    |   5 +-
 .../stage2_create_batches.py                  |   5 +-
 src/utils.py                                  |  55 ++---
 tests/models/test_actions_model_base.py       |  10 +-
 tests/models/test_actions_model_mixed.py      |  10 +-
 tests/models/test_actions_model_restricted.py |   5 +-
 tox.ini                                       |   1 +
 20 files changed, 451 insertions(+), 609 deletions(-)
 delete mode 100755 src/pipelines/actions_based/train.py
 delete mode 100644 src/pipelines/train.py

diff --git a/src/checkpoints.py b/src/checkpoints.py
index 4d97dac..a6394ad 100644
--- a/src/checkpoints.py
+++ b/src/checkpoints.py
@@ -1,30 +1,40 @@
 from __future__ import annotations
-from datetime import date, datetime, timedelta
+
+from datetime import datetime, timedelta
 from glob import glob
-from src.utils import moving_average, prepare_folder
-from src.training import latest_model
 from typing import Optional, Tuple, Type
+
 import torch
-import torch.nn as nn
 from torch.optim.optimizer import Optimizer
+
 from src.models.interfaces import PunctuationModel
+from src.training import latest_model
+from src.utils import moving_average, prepare_folder
 
 
-class Saver():
-    def __init__(self, save_dir: str, model: PunctuationModel, optimizer: Optimizer) -> None:
+class Saver:
+    def __init__(
+        self, save_dir: str, model: PunctuationModel, optimizer: Optimizer
+    ) -> None:
         self.save_dir = save_dir
         self.model = model
         self.optimizer = optimizer
 
         prepare_folder(self.save_dir)
-    
+
     def save(self, name):
         self.model.save(self.save_dir, name)
         torch.save(self.optimizer.state_dict(), f"{self.save_dir}/{name}.optimizer")
 
 
 class Loader:
-    def __init__(self, save_dir: str, model_type: Type[PunctuationModel], optimizer_type: Type[Optimizer], device: str) -> None:
+    def __init__(
+        self,
+        save_dir: str,
+        model_type: Type[PunctuationModel],
+        optimizer_type: Type[Optimizer],
+        device: str,
+    ) -> None:
         self.save_dir = save_dir
         self.device = device
 
@@ -34,26 +44,22 @@ class Loader:
     def has_checkpoints(self) -> bool:
         files = glob(f"{self.save_dir}/*.model")
 
-        return (latest_model(files) is not None)
-
+        return latest_model(files) is not None
 
     def load(self, name) -> Tuple[PunctuationModel, Optimizer]:
         model = self.model_type.load(self.save_dir, name, self.device)
 
         optimizer = self.optimizer_type(model.parameters())
         optimizer.load_state_dict(
-                torch.load(
-                    f"{self.save_dir}/{name}.optimizer",
-                    map_location=self.device
-                )
-            )
+            torch.load(f"{self.save_dir}/{name}.optimizer", map_location=self.device)
+        )
 
         return model, optimizer
 
     def load_latest(self) -> Tuple[PunctuationModel, Optimizer, int, int]:
         files = glob(f"{self.save_dir}/*.model")
 
-        model_id = latest_model(files)    
+        model_id = latest_model(files)
         if model_id is None:
             return None
 
@@ -62,7 +68,9 @@ class Loader:
 
 
 class Checkpoint:
-    def __init__(self, save_step, saver: Saver, start_step: int, start_epoch: int) -> None:
+    def __init__(
+        self, save_step, saver: Saver, start_step: int, start_epoch: int
+    ) -> None:
         self.start_step = start_step
         self.start_epoch = start_epoch
         self.save_step = save_step
@@ -72,7 +80,7 @@ class Checkpoint:
     def step(self, epoch, step) -> None:
         if step % self.save_step == 0 and (
             step != self.start_step or epoch != self.start_epoch
-            ):
+        ):
             print(f"Saving: Epoch {epoch}, step {step}")
             self.saver.save(f"{epoch}-{step}")
 
@@ -88,7 +96,7 @@ class Timeout:
         if self.duration is not None:
             self.time_max = (time_now or datetime.now()) + self.duration
 
-    def step(self, epoch, step, time = None) -> bool:
+    def step(self, epoch, step, time=None) -> bool:
         assert self.time_max is not None
 
         if time is None:
@@ -103,6 +111,7 @@ class Timeout:
 
         return False
 
+
 class ProgressTracker:
     def __init__(self, device: str, loss_averaging_span) -> None:
         print(f"Training on {device}")
diff --git a/src/models/actions_model_base.py b/src/models/actions_model_base.py
index 81cabe1..d0c4a6f 100644
--- a/src/models/actions_model_base.py
+++ b/src/models/actions_model_base.py
@@ -1,16 +1,37 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+
 import torch
 import torch.nn as nn
 from torch.nn.modules.loss import BCEWithLogitsLoss
 from transformers.configuration_utils import PretrainedConfig
 from transformers.modeling_bert import BertForTokenClassification
 
+from src.models.interfaces import PunctuationModel
 from src.pipelines.actions_based.processing import ACTIONS_KEYS
+from src.utils import pickle_read, pickle_save, prepare_folder
+
+
+@dataclass
+class ActionsModelBaseParams:
+    """
+    Parameters for ActionsModelBase initialization
+
+    Args:
+        base_model (str): Name of base model
+        num_labels (int): Length of action vector
+
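+    Example (illustrative; the model name is a placeholder for whatever
+    checkpoint the config selects):
+        >>> params = ActionsModelBaseParams("bert-base-cased")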
+    """
 
+    base_model: str
+    num_labels: int = len(ACTIONS_KEYS)
 
-class ActionsModelBase(nn.Module):
+
+class ActionsModelBase(PunctuationModel):
     """Model based on simple multilabel per-token classifiaction. Each token is binarly classified in n-dimensions"""
 
-    def __init__(self, base_model: str, num_labels: int = len(ACTIONS_KEYS)) -> None:
+    def __init__(self, params: ActionsModelBaseParams) -> None:
         """Initializes actions model
 
         Args:
@@ -18,9 +39,10 @@ class ActionsModelBase(nn.Module):
             num_labels (int): Length of action vector
         """
         super(ActionsModelBase, self).__init__()
+        self.params = params
 
-        config = PretrainedConfig.from_pretrained(base_model)
-        config.num_labels = num_labels
+        config = PretrainedConfig.from_pretrained(params.base_model)
+        config.num_labels = params.num_labels
 
         self.criterion = None
         self.core = BertForTokenClassification(config)
@@ -41,6 +63,20 @@ class ActionsModelBase(nn.Module):
 
         return y_pred
 
+    def save(self, dir: str, name: str) -> None:
+        prepare_folder(dir)
+        torch.save(self.state_dict(), f"{dir}/{name}.model")
+        pickle_save(self.params, f"{dir}/{name}.config")
+
+    @staticmethod
+    def load(dir: str, name: str, device: str) -> ActionsModelBase:
+        params = pickle_read(f"{dir}/{name}.config")
+        model = ActionsModelBase(params)
+
+        model.load_state_dict(torch.load(f"{dir}/{name}.model", map_location=device,))
+
+        return model
+
 
 class ActionsModelBaseLoss(nn.Module):
     """Proposed loss for ActionsModelBase model"""
diff --git a/src/models/actions_model_mixed.py b/src/models/actions_model_mixed.py
index 12ef3ed..ced75a7 100644
--- a/src/models/actions_model_mixed.py
+++ b/src/models/actions_model_mixed.py
@@ -1,5 +1,4 @@
-from src.utils import pickle_read, pickle_save, prepare_folder
-from src.models.interfaces import PunctuationModel
+from dataclasses import dataclass
 from typing import Optional
 
 import numpy as np
@@ -9,16 +8,16 @@ from torch.nn.modules.loss import BCEWithLogitsLoss
 from transformers.tokenization_bert import BertTokenizerFast
 
 from src.models.common import PositionalEncoding, generate_square_subsequent_mask
+from src.models.interfaces import PunctuationModel
 from src.pipelines.actions_based.processing import (
     ACTIONS_KEYS,
     action_vector,
     recover_text,
     token_labels_to_word_labels,
 )
+from src.utils import pickle_read, pickle_save, prepare_folder
 
 
-from dataclasses import dataclass
-
 @dataclass
 class ActionsModelMixedParams:
     """
@@ -34,7 +33,7 @@ class ActionsModelMixedParams:
         max_len (int, optional): Maximum length of sequence. Defaults to 500.
         dropout (float, optional): Dropout ratio. Defaults to 0.1.
     """
-    
+
     vocab_size: int
     embedding_size: int = 200
     num_heads: int = 4
@@ -48,10 +47,7 @@ class ActionsModelMixedParams:
 class ActionsModelMixed(PunctuationModel):
     """Encoder-decoder based model with unpunctuated token sequence as input and array of action-vectors as output"""
 
-    def __init__(
-        self,
-        params: ActionsModelMixedParams
-    ) -> None:
+    def __init__(self, params: ActionsModelMixedParams) -> None:
         """Initializes mixed model
 
         Args:
@@ -78,7 +74,10 @@ class ActionsModelMixed(PunctuationModel):
 
         # Sentence encoder
         sentence_encoder_layer = nn.TransformerEncoderLayer(
-            params.embedding_size, params.num_heads, params.feedforward_neurons, params.dropout
+            params.embedding_size,
+            params.num_heads,
+            params.feedforward_neurons,
+            params.dropout,
         )
         self.sentence_encoder = nn.TransformerEncoder(
             sentence_encoder_layer, num_layers=params.num_layers
@@ -86,7 +85,10 @@ class ActionsModelMixed(PunctuationModel):
 
         # Punctuation decoder
         punctuation_decoder_layer = nn.TransformerDecoderLayer(
-            params.embedding_size, params.num_heads, params.feedforward_neurons, params.dropout
+            params.embedding_size,
+            params.num_heads,
+            params.feedforward_neurons,
+            params.dropout,
         )
         self.punctuation_decoder = nn.TransformerDecoder(
             punctuation_decoder_layer, num_layers=params.num_layers
@@ -221,15 +223,11 @@ class ActionsModelMixed(PunctuationModel):
         params = pickle_read(f"{dir}/{name}.config")
         model = ActionsModelMixed(params)
 
-        model.load_state_dict(
-                torch.load(
-                    f"{dir}/{name}.model",
-                    map_location=device,
-                )
-            )
+        model.load_state_dict(torch.load(f"{dir}/{name}.model", map_location=device,))
 
         return model
 
+
 class ActionsModelMixedLoss(nn.Module):
     """Class representing proposed loss for training mixed actions model"""
 
diff --git a/src/models/actions_model_restricted.py b/src/models/actions_model_restricted.py
index 43dae15..de8d338 100644
--- a/src/models/actions_model_restricted.py
+++ b/src/models/actions_model_restricted.py
@@ -1,15 +1,37 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+
 import torch
 import torch.nn as nn
 from transformers.configuration_utils import PretrainedConfig
 from transformers.modeling_bert import BertForTokenClassification
 
+from src.models.actions_model_mixed import ActionsModelMixed
+from src.models.interfaces import PunctuationModel
+from src.utils import pickle_read, pickle_save, prepare_folder
+
+
+@dataclass
+class ActionsModelRestrictedParams:
+    """
+    Parameters for ActionsModelRestricted
+
+    Params:
+        base_model (str): Name of base model
+        extended_action_vector_size (int): Action-vector size including additional no-punctuation logit
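+
+    Example (illustrative; assumes five punctuation actions plus the extra
+    no-punctuation slot):
+        >>> params = ActionsModelRestrictedParams("bert-base-cased", 6)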
+    """
+
+    base_model: str
+    extended_action_vector_size: int
 
-class ActionsModelRestricted(nn.Module):
+
+class ActionsModelRestricted(PunctuationModel):
     """Similar to ActionsModelBase, however no-punctuation class is added
     and punctuation-related entries are treaded as proper categorical distribution
     """
 
-    def __init__(self, base_model: str, extended_action_vector_size: int) -> None:
+    def __init__(self, params: ActionsModelRestrictedParams) -> None:
         """Initializes restricted actions model
 
         Args:
@@ -18,9 +40,11 @@ class ActionsModelRestricted(nn.Module):
         """
         super(ActionsModelRestricted, self).__init__()
 
-        config = PretrainedConfig.from_pretrained(base_model)
+        self.params = params
+
+        config = PretrainedConfig.from_pretrained(params.base_model)
 
-        config.num_labels = extended_action_vector_size
+        config.num_labels = params.extended_action_vector_size
 
         self.core = BertForTokenClassification(config)
 
@@ -48,6 +72,20 @@ class ActionsModelRestricted(nn.Module):
 
         return torch.log(z / (1 - z))
 
+    def save(self, dir: str, name: str) -> None:
+        prepare_folder(dir)
+        torch.save(self.state_dict(), f"{dir}/{name}.model")
+        pickle_save(self.params, f"{dir}/{name}.config")
+
+    @staticmethod
+    def load(dir: str, name: str, device: str) -> ActionsModelRestricted:
+        params = pickle_read(f"{dir}/{name}.config")
+        model = ActionsModelRestricted(params)
+
+        model.load_state_dict(torch.load(f"{dir}/{name}.model", map_location=device,))
+
+        return model
+
 
 class ActionsModelRestrictedLoss(nn.Module):
     def __init__(
@@ -91,3 +129,17 @@ class ActionsModelRestrictedLoss(nn.Module):
         uppercase_loss = self.binary_ce(predicted_uppercase, target_uppercase)
 
         return punc_loss + uppercase_loss
+
+    def save(self, dir: str, name: str) -> None:
+        prepare_folder(dir)
+        torch.save(self.state_dict(), f"{dir}/{name}.model")
+        pickle_save(self.params, f"{dir}/{name}.config")
+
+    @staticmethod
+    def load(dir: str, name: str, device: str) -> PunctuationModel:
+        params = pickle_read(f"{dir}/{name}.config")
+        model = ActionsModelMixed(params)
+
+        model.load_state_dict(torch.load(f"{dir}/{name}.model", map_location=device,))
+
+        return model
diff --git a/src/models/interfaces.py b/src/models/interfaces.py
index 5ec408f..f54df51 100644
--- a/src/models/interfaces.py
+++ b/src/models/interfaces.py
@@ -1,12 +1,14 @@
 from __future__ import annotations
-import torch.nn as nn
 
 from abc import ABC, abstractmethod
 
+import torch.nn as nn
+
+
 class PunctuationModel(nn.Module, ABC):
     def __init__(self) -> None:
         super().__init__()
-    
+
     @abstractmethod
     def save(self, dir: str, name: str) -> None:
         pass
@@ -14,4 +16,4 @@ class PunctuationModel(nn.Module, ABC):
     @staticmethod
     @abstractmethod
     def load(dir: str, name: str, device: str) -> PunctuationModel:
-        pass
\ No newline at end of file
+        pass
diff --git a/src/pipelines/actions_based/stage1_extraction.py b/src/pipelines/actions_based/stage1_extraction.py
index 94dc26c..5a058a9 100644
--- a/src/pipelines/actions_based/stage1_extraction.py
+++ b/src/pipelines/actions_based/stage1_extraction.py
@@ -6,10 +6,7 @@ import numpy as np
 import pandas as pd
 from dask.distributed import Client
 
-from src.pipelines.actions_based.processing import (
-    APPLY_FILE_PROCESSING_META,
-    apply_file_processing,
-)
+from src.pipelines.actions_based.processing import APPLY_FILE_PROCESSING_META, apply_file_processing
 from src.utils import PROJECT_ROOT, get_config, prepare_folder
 
 INPUT_FOLDER = f"{PROJECT_ROOT}/data"
diff --git a/src/pipelines/actions_based/stage2_tokenization.py b/src/pipelines/actions_based/stage2_tokenization.py
index b30445f..0ea3586 100644
--- a/src/pipelines/actions_based/stage2_tokenization.py
+++ b/src/pipelines/actions_based/stage2_tokenization.py
@@ -4,10 +4,7 @@ import dask.dataframe as dd
 from dask.distributed import Client
 from transformers import BertTokenizerFast
 
-from src.pipelines.actions_based.processing import (
-    APPLY_TOKENIZATION_META,
-    apply_tokenization,
-)
+from src.pipelines.actions_based.processing import APPLY_TOKENIZATION_META, apply_tokenization
 from src.utils import PROJECT_ROOT, get_config, prepare_folder
 
 INPUT_FOLDER = f"{PROJECT_ROOT}/generated/actions/stage1_extraction"
diff --git a/src/pipelines/actions_based/stage3_exploding.py b/src/pipelines/actions_based/stage3_exploding.py
index 72ec128..81dc965 100644
--- a/src/pipelines/actions_based/stage3_exploding.py
+++ b/src/pipelines/actions_based/stage3_exploding.py
@@ -2,12 +2,7 @@
 import dask.dataframe as dd
 from dask.distributed import Client
 
-from src.processing import (
-    EXPAND_DIMS_META,
-    FLATTEN_DIMS_META,
-    expand_dims,
-    flatten_dims,
-)
+from src.processing import EXPAND_DIMS_META, FLATTEN_DIMS_META, expand_dims, flatten_dims
 from src.utils import PROJECT_ROOT, get_config, prepare_folder
 
 INPUT_FOLDER = f"{PROJECT_ROOT}/generated/actions/stage2_tokenization"
diff --git a/src/pipelines/actions_based/train.py b/src/pipelines/actions_based/train.py
deleted file mode 100755
index e6ed38e..0000000
--- a/src/pipelines/actions_based/train.py
+++ /dev/null
@@ -1,142 +0,0 @@
-#!/usr/bin/python3
-
-import glob
-import pickle
-from datetime import datetime
-
-import dask.dataframe as dd
-import numpy as np
-import torch
-from torch.nn import BCEWithLogitsLoss
-from transformers import BertForTokenClassification, BertTokenizerFast
-
-from src.batch_loading import get_batches
-from src.pipelines.actions_based.processing import ACTIONS_KEYS
-from src.training import latest_model, save_training_step
-from src.utils import PROJECT_ROOT, convert_to_timedelta, get_config, prepare_folder
-
-INPUT_PATH = f"{PROJECT_ROOT}/generated/actions/stage4_reindexing"
-INPUT_STATS_PATH = f"{PROJECT_ROOT}/generated/actions/stage5_stats"
-OUTPUT_PATH = f"{PROJECT_ROOT}/checkpoints/actions"
-
-if __name__ == "__main__":
-    config = get_config()
-    learning_rate = config["actions"]["training"]["learning_rate"]
-    num_epochs = config["actions"]["training"]["num_epochs"]
-    batch_size = config["actions"]["training"]["batch_size"]
-    save_step = config["actions"]["training"]["save_step"]
-    loss_averaging_span = config["actions"]["training"]["loss_averaging_span"]
-    fresh_start = config["actions"]["training"]["fresh_start"]
-    device_name = config["actions"]["training"]["device"]
-    max_train_time = config["actions"]["training"]["max_training_time"]
-    base_model = config["global"]["base_model"]
-    seed = config["global"]["random_seed"]
-
-    prepare_folder(OUTPUT_PATH)
-    np.random.seed(seed=seed)
-
-    if max_train_time is not None:
-        max_train_time = convert_to_timedelta(max_train_time)
-
-    device = torch.device(device_name if torch.cuda.is_available() else "cpu")
-    print(f"Training on {device}")
-
-    # Load loss weights
-    with open(f"{INPUT_STATS_PATH}/stats.pickle", "rb") as f:
-        stats = pickle.load(f)
-        pos_examples = stats["class_number"]
-        neg_examples = stats["num_examples"] - stats["class_number"]
-        pos_weight = torch.tensor(neg_examples / pos_examples)
-
-    df = dd.read_parquet(INPUT_PATH, engine="pyarrow")
-    tokenizer = BertTokenizerFast.from_pretrained(base_model)
-
-    model = BertForTokenClassification.from_pretrained(
-        base_model, num_labels=len(ACTIONS_KEYS)
-    ).to(device)
-    criterion = BCEWithLogitsLoss(pos_weight=pos_weight).to(device)
-    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
-
-    epoch_start = 0
-    sample_start = 0
-    if fresh_start is False:
-        checkpoint_files = glob.glob(f"{OUTPUT_PATH}/*.model")
-        latest = latest_model(checkpoint_files)
-
-        if latest is not None:
-            epoch, batch = latest
-            model.load_state_dict(
-                torch.load(f"{OUTPUT_PATH}/{epoch}-{batch}.model", map_location=device,)
-            )
-            optimizer.load_state_dict(
-                torch.load(
-                    f"{OUTPUT_PATH}/{epoch}-{batch}.optimizer", map_location=device,
-                )
-            )
-
-            epoch_start, sample_start = epoch, batch
-            print(f"Loaded {epoch}-{batch}")
-
-    model.train()
-    model.base_model.train()
-    losses = []
-
-    num_samples = df.tail(1).index.values[0] + 1
-    random_index_shuffle = np.random.permutation(range(num_samples))
-
-    training_stopped = False
-
-    time_max = datetime.max
-    if max_train_time is not None:
-        time_max = datetime.now() + max_train_time
-
-    for epoch in range(epoch_start, num_epochs):
-        if training_stopped:
-            break
-
-        i = sample_start
-        for data_batch in get_batches(df, batch_size, 100, random_index_shuffle, i):
-            inputs = data_batch.apply(
-                lambda x: x["source"].reshape(x["source_shape"]), axis=1
-            ).values
-            outputs = data_batch.apply(
-                lambda x: x["target"].reshape(x["target_shape"]), axis=1
-            ).values
-            attentions_mask = data_batch.apply(
-                lambda x: x["attention_mask"].reshape(x["attention_mask_shape"]),
-                axis=1,
-            ).values
-
-            inputs = torch.tensor(np.stack(inputs).squeeze()).to(device)
-            outputs = torch.tensor(np.stack(outputs)).to(device)
-            attentions_mask = torch.tensor(np.stack(attentions_mask)).to(device)
-
-            y_pred = model(input_ids=inputs, attention_mask=attentions_mask)[0]
-
-            loss = criterion(y_pred, outputs)
-
-            losses.append(loss.item())
-            if len(losses) > loss_averaging_span:
-                losses = losses[-loss_averaging_span:]
-
-            print(f"epoch: {epoch} | step: {i} | loss: {np.mean(losses)}")
-
-            optimizer.zero_grad()
-
-            if i % save_step == 0 and (i != sample_start or epoch != epoch_start):
-                print(f"Saving: Epoch {epoch}, step {i}")
-                save_training_step(OUTPUT_PATH, f"{epoch}-{i}", model, optimizer)
-
-            if datetime.now() > time_max:
-                print(f"Max time reached, saving: Epoch {epoch}, step {i}")
-                save_training_step(OUTPUT_PATH, f"{epoch}-{i}", model, optimizer)
-                training_stopped = True
-                break
-
-            loss.backward()
-            optimizer.step()
-
-            i += 1
-
-    if not training_stopped:
-        save_training_step(OUTPUT_PATH, "final", model, optimizer)
diff --git a/src/pipelines/actions_based/train_base.py b/src/pipelines/actions_based/train_base.py
index 93500f9..15da074 100755
--- a/src/pipelines/actions_based/train_base.py
+++ b/src/pipelines/actions_based/train_base.py
@@ -2,88 +2,115 @@
 
 import pickle
 
+import dask.dataframe as dd
 import numpy as np
-import pandas as pd
 import torch
-
-from src.models.actions_model_base import ActionsModelBase, ActionsModelBaseLoss
+from transformers import BertTokenizerFast
+
+from src.checkpoints import Checkpoint, Loader, ProgressTracker, Saver, Timeout
+from src.models.actions_model_base import (
+    ActionsModelBase,
+    ActionsModelBaseLoss,
+    ActionsModelBaseParams,
+)
 from src.pipelines.actions_based.processing import ACTIONS_KEYS
-from src.pipelines.train import TrainerBase
-from src.utils import PROJECT_ROOT, convert_to_timedelta, get_config
+from src.utils import (
+    PROJECT_ROOT,
+    convert_to_timedelta,
+    get_config,
+    random_indexes,
+    training_loop,
+    unflattened_column,
+)
 
 INPUT_PATH = f"{PROJECT_ROOT}/generated/actions/stage4_reindexing"
 INPUT_STATS_PATH = f"{PROJECT_ROOT}/generated/actions/stage5_stats"
 OUTPUT_PATH = f"{PROJECT_ROOT}/checkpoints/actions_base"
 
 
-class TrainerActions(TrainerBase):
-    def __init__(self) -> None:
-
-        config = get_config()
-        learning_rate = config["actions"]["training_base"]["learning_rate"]
-        num_epochs = config["actions"]["training_base"]["num_epochs"]
-        batch_size = config["actions"]["training_base"]["batch_size"]
-        save_step = config["actions"]["training_base"]["save_step"]
-        batch_buffer_size = config["actions"]["training_base"]["batch_buffer_size"]
-        loss_averaging_span = config["actions"]["training_base"]["loss_averaging_span"]
-        fresh_start = config["actions"]["training_base"]["fresh_start"]
-        device_name = config["actions"]["training_base"]["device"]
-        max_train_time = config["actions"]["training_base"]["max_training_time"]
-        base_model = config["global"]["base_model"]
-        seed = config["global"]["random_seed"]
-
-        if max_train_time is not None:
-            max_train_time = convert_to_timedelta(max_train_time)
-
-        # Load loss weights
-        with open(f"{INPUT_STATS_PATH}/stats.pickle", "rb") as f:
-            stats = pickle.load(f)
-            pos_examples = stats["class_number"]
-            neg_examples = stats["num_examples"] - stats["class_number"]
-            pos_weight = torch.tensor(neg_examples / pos_examples)
-
-        np.random.seed(seed=seed)
-
-        device = torch.device(device_name if torch.cuda.is_available() else "cpu")
+if __name__ == "__main__":
+    config = get_config()
+    learning_rate = config["actions"]["training_base"]["learning_rate"]
+    num_epochs = config["actions"]["training_base"]["num_epochs"]
+    batch_size = config["actions"]["training_base"]["batch_size"]
+    save_step = config["actions"]["training_base"]["save_step"]
+    batch_buffer_size = config["actions"]["training_base"]["batch_buffer_size"]
+    loss_averaging_span = config["actions"]["training_base"]["loss_averaging_span"]
+    fresh_start = config["actions"]["training_base"]["fresh_start"]
+    device_name = config["actions"]["training_base"]["device"]
+    max_train_time = convert_to_timedelta(
+        config["actions"]["training_base"]["max_training_time"]
+    )
+    base_model = config["global"]["base_model"]
+    seed = config["global"]["random_seed"]
+
+    np.random.seed(seed=seed)
+    df = dd.read_parquet(INPUT_PATH, engine="pyarrow")
+
+    device = torch.device(device_name if torch.cuda.is_available() else "cpu")
+    tokenizer = BertTokenizerFast.from_pretrained(base_model)
+
+    loader = Loader(OUTPUT_PATH, ActionsModelBase, torch.optim.AdamW, device)
+    if loader.has_checkpoints() and not fresh_start:
+        model, optimizer, epoch_start, sample_start = loader.load_latest()
+    else:
+        params = ActionsModelBaseParams(base_model, len(ACTIONS_KEYS))
+        model = ActionsModelBase(params)
 
-        model = ActionsModelBase(base_model, len(ACTIONS_KEYS)).to(device)
-        self.criterion = ActionsModelBaseLoss(pos_weight).to(device)
         optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
-
-        super(TrainerActions, self).__init__(
-            model,
-            device,
-            optimizer,
-            INPUT_PATH,
-            OUTPUT_PATH,
-            max_train_time,
-            fresh_start,
-            num_epochs,
-            batch_size,
-            batch_buffer_size,
-            loss_averaging_span,
-            save_step,
-        )
-
-    def calc_loss(self, data_batch: pd.DataFrame,) -> torch.Tensor:
-        inputs = data_batch.apply(
-            lambda x: x["source"].reshape(x["source_shape"]), axis=1
-        ).values
-        outputs = data_batch.apply(
-            lambda x: x["target"].reshape(x["target_shape"]), axis=1
-        ).values
-        attentions_mask = data_batch.apply(
-            lambda x: x["attention_mask"].reshape(x["attention_mask_shape"]), axis=1,
-        ).values
-
-        inputs = torch.tensor(np.stack(inputs).squeeze()).to(self.device)
-        outputs = torch.tensor(np.stack(outputs)).to(self.device)
-        attentions_mask = torch.tensor(np.stack(attentions_mask)).to(self.device)
-
-        y_pred = self.model(input_ids=inputs, attention_mask=attentions_mask)
-
-        return self.criterion(y_pred, outputs)
-
-
-if __name__ == "__main__":
-    TrainerActions().train()
+        epoch_start, sample_start = (0, 0)
+
+    model.train()
+    model.to(device)
+
+    # Load loss weights
+    with open(f"{INPUT_STATS_PATH}/stats.pickle", "rb") as f:
+        stats = pickle.load(f)
+        pos_examples = stats["class_number"]
+        neg_examples = stats["num_examples"] - stats["class_number"]
+        pos_weight = torch.tensor(neg_examples / pos_examples)
+
+    criterion = ActionsModelBaseLoss(pos_weight).to(device)
+
+    random_index_shuffle = random_indexes(df)
+    training_stopped = False
+
+    saver = Saver(OUTPUT_PATH, model, optimizer)
+    checkpoint = Checkpoint(save_step, saver, sample_start, epoch_start)
+    timer = Timeout(max_train_time, saver)
+    tracker = ProgressTracker(device, loss_averaging_span)
+
+    timer.start()
+    for data_batch, epoch, i in training_loop(
+        epoch_start,
+        sample_start,
+        num_epochs,
+        df,
+        batch_size,
+        batch_buffer_size,
+        random_index_shuffle,
+    ):
+        inputs = unflattened_column(data_batch, "source")
+        outputs = unflattened_column(data_batch, "target")
+        attentions_mask = unflattened_column(data_batch, "attention_mask")
+
+        inputs = torch.tensor(inputs, dtype=torch.long).squeeze(dim=-1).to(device)
+        outputs = torch.tensor(outputs, dtype=torch.float).to(device)
+        attentions_mask = torch.tensor(attentions_mask).to(device)
+
+        y_pred = model(input_ids=inputs, attention_mask=attentions_mask)
+
+        loss = criterion(y_pred, outputs)
+        optimizer.zero_grad()
+
+        tracker.step(epoch, i, loss)
+        checkpoint.step(epoch, i)
+        if timer.step(epoch, i):
+            training_stopped = True
+            break
+
+        loss.backward()
+        optimizer.step()
+
+    if not training_stopped:
+        saver.save("final")
diff --git a/src/pipelines/actions_based/train_mixed.py b/src/pipelines/actions_based/train_mixed.py
index b33bc0f..806cdfd 100755
--- a/src/pipelines/actions_based/train_mixed.py
+++ b/src/pipelines/actions_based/train_mixed.py
@@ -1,17 +1,27 @@
 #!/usr/bin/python3
 
 import pickle
-from src.checkpoints import Checkpoint, Loader, ProgressTracker, Saver, Timeout
 
+import dask.dataframe as dd
 import numpy as np
-import pandas as pd
 import torch
 from transformers import BertTokenizerFast
-import dask.dataframe as dd
 
-from src.models.actions_model_mixed import ActionsModelMixed, ActionsModelMixedLoss, ActionsModelMixedParams
+from src.checkpoints import Checkpoint, Loader, ProgressTracker, Saver, Timeout
+from src.models.actions_model_mixed import (
+    ActionsModelMixed,
+    ActionsModelMixedLoss,
+    ActionsModelMixedParams,
+)
 from src.pipelines.actions_based.processing import ACTIONS_KEYS
-from src.utils import PROJECT_ROOT, convert_to_timedelta, get_config, random_indexes, training_loop, unflattened_column
+from src.utils import (
+    PROJECT_ROOT,
+    convert_to_timedelta,
+    get_config,
+    random_indexes,
+    training_loop,
+    unflattened_column,
+)
 
 INPUT_PATH = f"{PROJECT_ROOT}/generated/actions/stage4_reindexing"
 INPUT_STATS_PATH = f"{PROJECT_ROOT}/generated/actions/stage5_stats"
@@ -33,13 +43,15 @@ if __name__ == "__main__":
     loss_averaging_span = config["actions"]["training_mixed"]["loss_averaging_span"]
     fresh_start = config["actions"]["training_mixed"]["fresh_start"]
     device_name = config["actions"]["training_mixed"]["device"]
-    max_train_time = convert_to_timedelta(config["actions"]["training_mixed"]["max_training_time"])
+    max_train_time = convert_to_timedelta(
+        config["actions"]["training_mixed"]["max_training_time"]
+    )
     base_model = config["global"]["base_model"]
     seed = config["global"]["random_seed"]
 
     np.random.seed(seed=seed)
     df = dd.read_parquet(INPUT_PATH, engine="pyarrow")
-    
+
     device = torch.device(device_name if torch.cuda.is_available() else "cpu")
     tokenizer = BertTokenizerFast.from_pretrained(base_model)
 
@@ -56,7 +68,8 @@ if __name__ == "__main__":
             feedforward_neurons,
             len(ACTIONS_KEYS),
             500,
-            dropout)
+            dropout,
+        )
         model = ActionsModelMixed(params)
 
         optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
@@ -83,23 +96,25 @@ if __name__ == "__main__":
     tracker = ProgressTracker(device, loss_averaging_span)
 
     timer.start()
-    for data_batch, epoch, i in training_loop(epoch_start, sample_start, num_epochs, df, batch_size, batch_buffer_size, random_index_shuffle):
-        inputs = unflattened_column(data_batch, 'source')
-        outputs = unflattened_column(data_batch, 'target')
-        attentions_mask = unflattened_column(data_batch, 'attention_mask')
-
-        inputs = (
-            torch.tensor(inputs, dtype=torch.long)
-            .to(device)
-            .squeeze(dim=2)
-        )
+    for data_batch, epoch, i in training_loop(
+        epoch_start,
+        sample_start,
+        num_epochs,
+        df,
+        batch_size,
+        batch_buffer_size,
+        random_index_shuffle,
+    ):
+        inputs = unflattened_column(data_batch, "source")
+        outputs = unflattened_column(data_batch, "target")
+        attentions_mask = unflattened_column(data_batch, "attention_mask")
+
+        inputs = torch.tensor(inputs, dtype=torch.long).to(device).squeeze(dim=2)
 
         outputs = torch.tensor(outputs, dtype=torch.float).to(device)
 
         # Convert to boolean
-        attentions_mask = torch.tensor(attentions_mask == 0).to(
-            device
-        )
+        attentions_mask = torch.tensor(attentions_mask == 0).to(device)
 
         y_pred = model(
             input_ids=inputs,
@@ -120,4 +135,4 @@ if __name__ == "__main__":
         optimizer.step()
 
     if not training_stopped:
-        saver.save("final")
\ No newline at end of file
+        saver.save("final")
diff --git a/src/pipelines/actions_based/train_restricted.py b/src/pipelines/actions_based/train_restricted.py
index 5e7d8f3..7856695 100755
--- a/src/pipelines/actions_based/train_restricted.py
+++ b/src/pipelines/actions_based/train_restricted.py
@@ -2,122 +2,142 @@
 
 import pickle
 
+import dask.dataframe as dd
 import numpy as np
-import pandas as pd
 import torch
+from transformers import BertTokenizerFast
 
+from src.checkpoints import Checkpoint, Loader, ProgressTracker, Saver, Timeout
 from src.models.actions_model_restricted import (
     ActionsModelRestricted,
     ActionsModelRestrictedLoss,
+    ActionsModelRestrictedParams,
 )
 from src.pipelines.actions_based.processing import ACTIONS_KEYS
-from src.pipelines.train import TrainerBase
-from src.utils import PROJECT_ROOT, convert_to_timedelta, get_config
+from src.utils import (
+    PROJECT_ROOT,
+    convert_to_timedelta,
+    get_config,
+    random_indexes,
+    training_loop,
+    unflattened_column,
+)
 
 INPUT_PATH = f"{PROJECT_ROOT}/generated/actions/stage4_reindexing"
 INPUT_STATS_PATH = f"{PROJECT_ROOT}/generated/actions/stage5_stats"
 OUTPUT_PATH = f"{PROJECT_ROOT}/checkpoints/actions_restricted"
 
 
-class TrainerActions(TrainerBase):
-    def __init__(self) -> None:
-
-        config = get_config()
-        learning_rate = config["actions"]["training_restricted"]["learning_rate"]
-        num_epochs = config["actions"]["training_restricted"]["num_epochs"]
-        batch_size = config["actions"]["training_restricted"]["batch_size"]
-        save_step = config["actions"]["training_restricted"]["save_step"]
-        batch_buffer_size = config["actions"]["training_restricted"][
-            "batch_buffer_size"
-        ]
-        loss_averaging_span = config["actions"]["training_restricted"][
-            "loss_averaging_span"
-        ]
-        fresh_start = config["actions"]["training_restricted"]["fresh_start"]
-        device_name = config["actions"]["training_restricted"]["device"]
-        max_train_time = config["actions"]["training_restricted"]["max_training_time"]
-        base_model = config["global"]["base_model"]
-        seed = config["global"]["random_seed"]
-
-        if max_train_time is not None:
-            max_train_time = convert_to_timedelta(max_train_time)
-
-        # Load loss weights
-        with open(f"{INPUT_STATS_PATH}/stats.pickle", "rb") as f:
-            stats = pickle.load(f)
-            pos_examples = stats["class_number"]
-            neg_examples = stats["num_examples"] - stats["class_number"]
-
-            uppercase_pos_examples = pos_examples[0]
-            uppercase_neg_examples = neg_examples[0]
-            uppercase_pos_odds = torch.tensor(
-                uppercase_pos_examples / uppercase_neg_examples, dtype=torch.float
-            )
-
-            has_punctuation_neg_examples = neg_examples[1:]
-            has_no_punctuation_neg_examples = np.sum(pos_examples[1:])
-
-            punctuation_neg_examples = np.concatenate(
-                [
-                    has_punctuation_neg_examples,
-                    has_no_punctuation_neg_examples.reshape(1),
-                ],
-                -1,
-            )
-
-            punctuation_class_weights = torch.tensor(
-                (punctuation_neg_examples) / np.sum(punctuation_neg_examples),
-                dtype=torch.float,
-            )
-
-        np.random.seed(seed=seed)
-
-        device = torch.device(device_name if torch.cuda.is_available() else "cpu")
-
-        model = ActionsModelRestricted(base_model, len(ACTIONS_KEYS) + 1).to(device)
-        self.criterion = ActionsModelRestrictedLoss(
-            uppercase_pos_odds, punctuation_class_weights
-        ).to(device)
+if __name__ == "__main__":
+
+    config = get_config()
+    learning_rate = config["actions"]["training_restricted"]["learning_rate"]
+    num_epochs = config["actions"]["training_restricted"]["num_epochs"]
+    batch_size = config["actions"]["training_restricted"]["batch_size"]
+    save_step = config["actions"]["training_restricted"]["save_step"]
+    batch_buffer_size = config["actions"]["training_restricted"]["batch_buffer_size"]
+    loss_averaging_span = config["actions"]["training_restricted"][
+        "loss_averaging_span"
+    ]
+    fresh_start = config["actions"]["training_restricted"]["fresh_start"]
+    device_name = config["actions"]["training_restricted"]["device"]
+    max_train_time = convert_to_timedelta(
+        config["actions"]["training_restricted"]["max_training_time"]
+    )
+    base_model = config["global"]["base_model"]
+    seed = config["global"]["random_seed"]
+
+    np.random.seed(seed=seed)
+    df = dd.read_parquet(INPUT_PATH, engine="pyarrow")
+
+    device = torch.device(device_name if torch.cuda.is_available() else "cpu")
+    tokenizer = BertTokenizerFast.from_pretrained(base_model)
+
+    loader = Loader(OUTPUT_PATH, ActionsModelRestricted, torch.optim.AdamW, device)
+    if loader.has_checkpoints() and not fresh_start:
+        model, optimizer, epoch_start, sample_start = loader.load_latest()
+    else:
+        params = ActionsModelRestrictedParams(base_model, len(ACTIONS_KEYS) + 1)
+        model = ActionsModelRestricted(params)
+
         optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
+        epoch_start, sample_start = (0, 0)
+
+    model.train()
+    model.to(device)
+
+    # Load loss weights
+    with open(f"{INPUT_STATS_PATH}/stats.pickle", "rb") as f:
+        stats = pickle.load(f)
+        pos_examples = stats["class_number"]
+        neg_examples = stats["num_examples"] - stats["class_number"]
 
-        super(TrainerActions, self).__init__(
-            model,
-            device,
-            optimizer,
-            INPUT_PATH,
-            OUTPUT_PATH,
-            max_train_time,
-            fresh_start,
-            num_epochs,
-            batch_size,
-            batch_buffer_size,
-            loss_averaging_span,
-            save_step,
+        uppercase_pos_examples = pos_examples[0]
+        uppercase_neg_examples = neg_examples[0]
+        uppercase_pos_odds = torch.tensor(
+            uppercase_pos_examples / uppercase_neg_examples, dtype=torch.float
         )
 
-    def calc_loss(self, data_batch: pd.DataFrame,) -> torch.Tensor:
-        inputs = data_batch.apply(
-            lambda x: x["source"].reshape(x["source_shape"]), axis=1
-        ).values
-        outputs = data_batch.apply(
-            lambda x: x["target"].reshape(x["target_shape"]), axis=1
-        ).values
-        attentions_mask = data_batch.apply(
-            lambda x: x["attention_mask"].reshape(x["attention_mask_shape"]), axis=1,
-        ).values
+        has_punctuation_neg_examples = neg_examples[1:]
+        has_no_punctuation_neg_examples = np.sum(pos_examples[1:])
 
-        inputs = torch.tensor(np.stack(inputs)).squeeze(dim=2).to(self.device)
-        outputs = torch.tensor(np.stack(outputs)).to(self.device)
-        attentions_mask = torch.tensor(np.stack(attentions_mask)).to(self.device)
+        punctuation_neg_examples = np.concatenate(
+            [has_punctuation_neg_examples, has_no_punctuation_neg_examples.reshape(1)],
+            -1,
+        )
+
+        punctuation_class_weights = torch.tensor(
+            (punctuation_neg_examples) / np.sum(punctuation_neg_examples),
+            dtype=torch.float,
+        )
 
-        y_pred = self.model(input_ids=inputs, attention_mask=attentions_mask)
+    criterion = ActionsModelRestrictedLoss(
+        uppercase_pos_odds, punctuation_class_weights
+    ).to(device)
+
+    random_index_shuffle = random_indexes(df)
+    training_stopped = False
+
+    saver = Saver(OUTPUT_PATH, model, optimizer)
+    checkpoint = Checkpoint(save_step, saver, epoch_start, sample_start)
+    timer = Timeout(max_train_time, saver)
+    tracker = ProgressTracker(device, loss_averaging_span)
+
+    timer.start()
+    for data_batch, epoch, i in training_loop(
+        epoch_start,
+        sample_start,
+        num_epochs,
+        df,
+        batch_size,
+        batch_buffer_size,
+        random_index_shuffle,
+    ):
+        inputs = unflattened_column(data_batch, "source")
+        outputs = unflattened_column(data_batch, "target")
+        attentions_mask = unflattened_column(data_batch, "attention_mask")
+
+        inputs = torch.tensor(inputs, dtype=torch.long).squeeze(dim=-1).to(device)
+        outputs = torch.tensor(outputs, dtype=torch.float).to(device)
+        attentions_mask = torch.tensor(attentions_mask).to(device)
+
+        y_pred = model(input_ids=inputs, attention_mask=attentions_mask)
 
         outputs = torch.cat(
             [outputs, (1.0 - outputs[:, :, 1:].max(-1)[0]).unsqueeze(-1)], axis=-1
         )
 
-        return self.criterion(y_pred, outputs)
+        loss = criterion(y_pred, outputs)
+        optimizer.zero_grad()
 
+        tracker.step(epoch, i, loss)
+        checkpoint.step(epoch, i)
+        if timer.step(epoch, i):
+            training_stopped = True
+            break
 
-if __name__ == "__main__":
-    TrainerActions().train()
+        loss.backward()
+        optimizer.step()
+
+    if not training_stopped:
+        saver.save("final")
diff --git a/src/pipelines/train.py b/src/pipelines/train.py
deleted file mode 100644
index 3c224db..0000000
--- a/src/pipelines/train.py
+++ /dev/null
@@ -1,181 +0,0 @@
-#!/usr/bin/python3
-
-import glob
-from abc import ABC, abstractmethod
-from datetime import datetime, timedelta
-from typing import Optional
-
-import dask.dataframe as dd
-import numpy as np
-import pandas as pd
-import torch
-import torch.nn as nn
-from torch.optim.optimizer import Optimizer
-
-from src.batch_loading import get_batches
-from src.training import latest_model, save_training_step
-from src.utils import convert_to_timedelta, prepare_folder
-
-
-class TrainerBase(ABC):
-    """[summary]
-
-    Args:
-        ABC ([type]): [description]
-    """
-
-    def __init__(
-        self,
-        model: nn.Module,
-        device: torch.device,
-        optimizer: Optimizer,
-        input_path: str,
-        output_path: str,
-        max_train_time: Optional[timedelta],
-        fresh_start: bool,
-        num_epochs: int,
-        batch_size: int,
-        batch_buffer_size: int = 100,
-        loss_averaging_span: int = 1000,
-        save_step: int = 1000,
-    ) -> None:
-        """Initializes base trainer
-
-        Args:
-            model (nn.Module): Model that will be trained
-            device (torch.device): Device on which model will be loaded & trained
-            optimizer (Optimizer): Optimizer used for gradient descent
-            input_path (str): Path to parquet folder with input dataset
-            output_path (str): Path where model checkpoints will be stored
-            max_train_time (Optional[timedelta]): Maximum training time
-            fresh_start (bool): If set to true, last checkpoint will not be loaded and training will start from scratch
-            num_epochs (int): Number of epochs to train
-            batch_size (int): Batch size to use
-            batch_buffer_size (int, optional): How many batches to load to ram at once. Defaults to 100.
-            loss_averaging_span (int, optional): How many losses to average ovet in logs. Defaults to 1000.
-            save_step (int, optional): Step at which model checkpoints will be saved. Defaults to 1000.
-        """
-
-        self.model = model
-        self.device = device
-        self.optimizer = optimizer
-
-        self.input_path = input_path
-        self.output_path = output_path
-        self.fresh_start = fresh_start
-        self.num_epochs = num_epochs
-        self.batch_size = batch_size
-        self.batch_buffer_size = batch_buffer_size
-        self.loss_averaging_span = loss_averaging_span
-        self.save_step = save_step
-
-        self.max_train_time = max_train_time
-        if self.max_train_time is not None:
-            self.max_train_time = convert_to_timedelta(self.max_train_time)
-
-    @abstractmethod
-    def calc_loss(self, data_batch: pd.DataFrame) -> torch.Tensor:
-        """User-provided function that will return loss on which backprob and
-        optimization will be made
-
-        Args:
-            data_batch (pd.DataFrame): Pandas dataframe with a single batch of data
-
-        Returns:
-            torch.Tensor: Loss tensor
-        """
-        pass
-
-    def _load_model(self):
-        checkpoint_files = glob.glob(f"{self.output_path}/*.model")
-        latest = latest_model(checkpoint_files)
-
-        if latest is not None:
-            epoch, batch = latest
-            self.model.load_state_dict(
-                torch.load(
-                    f"{self.output_path}/{epoch}-{batch}.model",
-                    map_location=self.device,
-                )
-            )
-            self.optimizer.load_state_dict(
-                torch.load(
-                    f"{self.output_path}/{epoch}-{batch}.optimizer",
-                    map_location=self.device,
-                )
-            )
-
-            return epoch, batch
-
-    def train(self):
-        """Preforms full training of the model"""
-        prepare_folder(self.output_path)
-        print(f"Training on {self.device}")
-
-        df = dd.read_parquet(self.input_path, engine="pyarrow")
-
-        epoch_start = 0
-        sample_start = 0
-        if self.fresh_start is False:
-            epoch_start, sample_start = self._load_model()
-            print(f"Loaded {epoch_start}-{sample_start}")
-
-        self.model.train()
-        losses = []
-
-        num_samples = df.tail(1).index.values[0] + 1
-        random_index_shuffle = np.random.permutation(range(num_samples))
-
-        training_stopped = False
-
-        time_max = datetime.max
-        if self.max_train_time is not None:
-            time_max = datetime.now() + self.max_train_time
-
-        for epoch in range(epoch_start, self.num_epochs):
-            if training_stopped:
-                break
-
-            i = sample_start
-            for data_batch in get_batches(
-                df, self.batch_size, self.batch_buffer_size, random_index_shuffle, i
-            ):
-                if len(data_batch) == 0:
-                    continue
-
-                loss = self.calc_loss(data_batch)
-
-                losses.append(loss.item())
-                if len(losses) > self.loss_averaging_span:
-                    fist_to_keep = -self.loss_averaging_span
-                    losses = losses[fist_to_keep:]
-
-                print(f"epoch: {epoch} | step: {i} | loss: {np.mean(losses)}")
-
-                self.optimizer.zero_grad()
-
-                if i % self.save_step == 0 and (
-                    i != sample_start or epoch != epoch_start
-                ):
-                    print(f"Saving: Epoch {epoch}, step {i}")
-                    save_training_step(
-                        self.output_path, f"{epoch}-{i}", self.model, self.optimizer
-                    )
-
-                if datetime.now() > time_max:
-                    print(f"Max time reached, saving: Epoch {epoch}, step {i}")
-                    save_training_step(
-                        self.output_path, f"{epoch}-{i}", self.model, self.optimizer
-                    )
-                    training_stopped = True
-                    break
-
-                loss.backward()
-                self.optimizer.step()
-
-                i += 1
-
-            sample_start = 0
-
-        if not training_stopped:
-            save_training_step(self.output_path, "final", self.model, self.optimizer)
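
Note: the deleted TrainerBase is superseded by the generator-based training_loop added to src/utils.py later in this same patch, which hands (batch, epoch, step) tuples to the caller instead of owning the optimization step. A toy stand-in showing the resume semantics (the real helper pulls batches from a Dask dataframe via get_batches):

    def training_loop(epoch_start, sample_start, num_epochs, batches):
        # Toy version: the real helper resumes mid-epoch through get_batches
        # and a shuffled index, then restarts every later epoch from sample 0
        for epoch in range(epoch_start, num_epochs):
            i = sample_start
            for batch in batches[sample_start:]:
                yield batch, epoch, i
                i += 1
            sample_start = 0

    for batch, epoch, i in training_loop(0, 1, 2, ["a", "b", "c"]):
        print(epoch, i, batch)  # the caller owns loss/backward/step logic
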
diff --git a/src/pipelines/translation_based/stage1_extraction.py b/src/pipelines/translation_based/stage1_extraction.py
index 6ffdbf7..386211d 100644
--- a/src/pipelines/translation_based/stage1_extraction.py
+++ b/src/pipelines/translation_based/stage1_extraction.py
@@ -6,10 +6,7 @@ import numpy as np
 import pandas as pd
 from dask.distributed import Client
 
-from src.pipelines.translation_based.processing import (
-    RAW_TO_DATAFRAME_META,
-    raw_to_dataframe,
-)
+from src.pipelines.translation_based.processing import RAW_TO_DATAFRAME_META, raw_to_dataframe
 from src.utils import PROJECT_ROOT, get_config, prepare_folder
 
 INPUT_FOLDER = f"{PROJECT_ROOT}/data"
diff --git a/src/pipelines/translation_based/stage2_create_batches.py b/src/pipelines/translation_based/stage2_create_batches.py
index 83a2edc..ade8bf2 100644
--- a/src/pipelines/translation_based/stage2_create_batches.py
+++ b/src/pipelines/translation_based/stage2_create_batches.py
@@ -4,10 +4,7 @@ from dask import delayed
 from dask.distributed import Client
 from transformers import BertTokenizerFast
 
-from src.pipelines.translation_based.processing import (
-    GENERATE_BATCHES_META,
-    generate_batches,
-)
+from src.pipelines.translation_based.processing import GENERATE_BATCHES_META, generate_batches
 from src.utils import PROJECT_ROOT, get_config, prepare_folder
 
 INPUT_FOLDER = f"{PROJECT_ROOT}/generated/translations/stage1_extraction"
diff --git a/src/utils.py b/src/utils.py
index 4585f9b..8a9a09e 100644
--- a/src/utils.py
+++ b/src/utils.py
@@ -3,17 +3,16 @@ import pickle
 import re
 import shutil
 from datetime import timedelta
+from typing import Generator, List, Optional, Tuple
 
-from typing import Generator
-from src.batch_loading import get_batches, get_ordered_dataframe_len
-from typing import List, Optional, Tuple
-import pandas as pd
-import numpy as np
 import dask.dataframe as dd
+import numpy as np
+import pandas as pd
 import torch
-
 import yaml
 
+from src.batch_loading import get_batches, get_ordered_dataframe_len
+
 PROJECT_ROOT = os.path.dirname(os.path.realpath("/".join(__file__.split("/")) + "/.."))
 
 
@@ -114,7 +113,7 @@ def prepare_folder(path: str, wipe: bool = False) -> None:
 
 def unflattened_column(df: pd.DataFrame, name: str) -> np.ndarray:
     """Get column from the dataframe that was flattened. Dataframe must have columns
-    "name" and "name_shape", where name is 1D numpy array and name_shape is target 
+    "name" and "name_shape", where name is 1D numpy array and name_shape is target
     shape of this numpy array.
 
     Args:
@@ -125,14 +124,14 @@ def unflattened_column(df: pd.DataFrame, name: str) -> np.ndarray:
         np.ndarray: Unflattened multidimensional column of shape Lx*(name_shape)
     """
 
-    values = df.apply(
-        lambda x: x[name].reshape(x[f"{name}_shape"]), axis=1
-    ).values
+    values = df.apply(lambda x: x[name].reshape(x[f"{name}_shape"]), axis=1).values
 
     return np.stack(values)
 
 
-def moving_average(values: List[np.ndarray], average_span: int) -> Tuple[float, np.ndarray]:
+def moving_average(
+    values: List[np.ndarray], average_span: int
+) -> Tuple[float, np.ndarray]:
     """Computes moving average and keeps only latests records
 
     Args:
@@ -145,7 +144,7 @@ def moving_average(values: List[np.ndarray], average_span: int) -> Tuple[float,
 
     if len(values) > average_span:
         values = values[-average_span:]
-    
+
     return np.mean(values), values
 
 
@@ -161,7 +160,15 @@ def optimizer_step(loss: torch.Tensor, optimizer: torch.optim.Optimizer) -> None
     optimizer.step()
 
 
-def training_loop(epoch_start: int, sample_start: int, num_epochs: int , df: dd.DataFrame, batch_size: int, batch_buffer_size: int, random_index_shuffle: np.ndarray) -> Generator[pd.DataFrame, int, int]:
+def training_loop(
+    epoch_start: int,
+    sample_start: int,
+    num_epochs: int,
+    df: dd.DataFrame,
+    batch_size: int,
+    batch_buffer_size: int,
+    random_index_shuffle: np.ndarray,
+) -> Generator[Tuple[pd.DataFrame, int, int], None, None]:
     """Generator providing all data necessary to perform a training steps. This function handels epochs/steps management
 
     Args:
@@ -177,18 +184,18 @@ def training_loop(epoch_start: int, sample_start: int, num_epochs: int , df: dd.
         Generator: batch, epoch_num, step_num
     """
     for epoch in range(epoch_start, num_epochs):
-            i = sample_start
-            for data_batch in get_batches(
-                df, batch_size, batch_buffer_size, random_index_shuffle, i
-                ):
-                if len(data_batch) == 0:
-                    continue
+        i = sample_start
+        for data_batch in get_batches(
+            df, batch_size, batch_buffer_size, random_index_shuffle, i
+        ):
+            if len(data_batch) == 0:
+                continue
 
-                yield data_batch, epoch, i
+            yield data_batch, epoch, i
 
-                i += 1
+            i += 1
 
-            sample_start = 0
+        sample_start = 0
 
 
 def random_indexes(df: dd.DataFrame) -> np.ndarray:
@@ -211,7 +218,7 @@ def pickle_save(obj: any, path: str) -> None:
         obj (any): Object to pickle
         path (str): Path to output file
     """
-    with open(path, 'wb') as f:
+    with open(path, "wb") as f:
         pickle.dump(obj, f)
 
 
@@ -224,7 +231,7 @@ def pickle_read(path: str) -> any:
     Returns:
         any: Unpickled object
     """
-    with open(path, 'rb') as f:
+    with open(path, "rb") as f:
         return pickle.load(f)
 
 
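Note: unflattened_column above undoes the flattening applied when array columns were written to parquet, each row stores a 1D payload plus its original shape. A self-contained round trip with a pandas stand-in for the Dask dataframe:

    import numpy as np
    import pandas as pd

    arr = np.arange(6).reshape(2, 3)
    df = pd.DataFrame({
        "source": [arr.reshape(-1)],             # flattened 1D payload
        "source_shape": [np.array(arr.shape)],   # original shape stored alongside
    })

    values = df.apply(lambda x: x["source"].reshape(x["source_shape"]), axis=1).values
    restored = np.stack(values)  # shape (1, 2, 3) == L x *(source_shape)
    assert (restored[0] == arr).all()
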
diff --git a/tests/models/test_actions_model_base.py b/tests/models/test_actions_model_base.py
index 5fd17b3..cee3952 100644
--- a/tests/models/test_actions_model_base.py
+++ b/tests/models/test_actions_model_base.py
@@ -1,7 +1,11 @@
 import torch
 from transformers.tokenization_bert import BertTokenizerFast
 
-from src.models.actions_model_base import ActionsModelBase, ActionsModelBaseLoss
+from src.models.actions_model_base import (
+    ActionsModelBase,
+    ActionsModelBaseLoss,
+    ActionsModelBaseParams,
+)
 
 
 def test_dimensions():
@@ -11,7 +15,9 @@ def test_dimensions():
     tokens = BertTokenizerFast.from_pretrained(base_model)(
         "Ala ma kota", return_tensors="pt"
     )
-    model = ActionsModelBase(base_model, action_vector_size)
+
+    params = ActionsModelBaseParams(base_model, action_vector_size)
+    model = ActionsModelBase(params)
 
     result = model(tokens["input_ids"], tokens["attention_mask"])
 
diff --git a/tests/models/test_actions_model_mixed.py b/tests/models/test_actions_model_mixed.py
index 6e05bdd..c2ba214 100644
--- a/tests/models/test_actions_model_mixed.py
+++ b/tests/models/test_actions_model_mixed.py
@@ -1,7 +1,11 @@
 import torch
 from transformers.tokenization_bert import BertTokenizerFast
 
-from src.models.actions_model_mixed import ActionsModelMixed, ActionsModelMixedLoss
+from src.models.actions_model_mixed import (
+    ActionsModelMixed,
+    ActionsModelMixedLoss,
+    ActionsModelMixedParams,
+)
 
 
 def test_dimensions():
@@ -17,7 +21,8 @@ def test_dimensions():
     feedforward_neurons = 10
     max_len = 500
     dropout = 0.1
-    model = ActionsModelMixed(
+
+    params = ActionsModelMixedParams(
         tokenizer.vocab_size,
         embedding_size,
         num_heads,
@@ -27,6 +32,7 @@ def test_dimensions():
         max_len,
         dropout,
     )
+    model = ActionsModelMixed(params)
 
     actions_len = 3
     actions = torch.distributions.Multinomial(
diff --git a/tests/models/test_actions_model_restricted.py b/tests/models/test_actions_model_restricted.py
index 7446675..2ea18ef 100644
--- a/tests/models/test_actions_model_restricted.py
+++ b/tests/models/test_actions_model_restricted.py
@@ -4,6 +4,7 @@ from transformers.tokenization_bert import BertTokenizerFast
 from src.models.actions_model_restricted import (
     ActionsModelRestricted,
     ActionsModelRestrictedLoss,
+    ActionsModelRestrictedParams,
 )
 
 
@@ -14,7 +15,9 @@ def test_dimensions():
     tokens = BertTokenizerFast.from_pretrained(base_model)(
         "Ala ma kota", return_tensors="pt"
     )
-    model = ActionsModelRestricted(base_model, action_vector_size)
+
+    params = ActionsModelRestrictedParams(base_model, action_vector_size)
+    model = ActionsModelRestricted(params)
 
     result = model(tokens["input_ids"], tokens["attention_mask"])
 
diff --git a/tox.ini b/tox.ini
index cd8520e..eab80b1 100644
--- a/tox.ini
+++ b/tox.ini
@@ -24,6 +24,7 @@ exclude =
     data
     generated
 max-complexity = 10
+min_python_version = 3.8
 max-line-length = 80
 select = I,C,E,F,W,B,B950,TYP,T
 ignore = E501, C901, I201
-- 
GitLab


From 8ba4d132df9085dea89b63fdbe0364f34841ce47 Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Wed, 26 Aug 2020 12:27:58 +0200
Subject: [PATCH 082/116] Moved checkpoints to utility

---
 src/checkpoints.py                            | 125 ------------
 src/pipelines/actions_based/train_base.py     |   6 +-
 src/pipelines/actions_based/train_mixed.py    |   6 +-
 .../actions_based/train_restricted.py         |   6 +-
 src/pipelines/translation_based/train.py      |  10 +-
 src/training.py                               |  66 -------
 src/utils.py                                  | 181 +++++++++++++++++-
 tests/test_training.py                        |  21 --
 tests/test_utils.py                           |  21 ++
 9 files changed, 223 insertions(+), 219 deletions(-)
 delete mode 100644 src/checkpoints.py
 delete mode 100644 src/training.py
 delete mode 100644 tests/test_training.py

diff --git a/src/checkpoints.py b/src/checkpoints.py
deleted file mode 100644
index a6394ad..0000000
--- a/src/checkpoints.py
+++ /dev/null
@@ -1,125 +0,0 @@
-from __future__ import annotations
-
-from datetime import datetime, timedelta
-from glob import glob
-from typing import Optional, Tuple, Type
-
-import torch
-from torch.optim.optimizer import Optimizer
-
-from src.models.interfaces import PunctuationModel
-from src.training import latest_model
-from src.utils import moving_average, prepare_folder
-
-
-class Saver:
-    def __init__(
-        self, save_dir: str, model: PunctuationModel, optimizer: Optimizer
-    ) -> None:
-        self.save_dir = save_dir
-        self.model = model
-        self.optimizer = optimizer
-
-        prepare_folder(self.save_dir)
-
-    def save(self, name):
-        self.model.save(self.save_dir, name)
-        torch.save(self.optimizer.state_dict(), f"{self.save_dir}/{name}.optimizer")
-
-
-class Loader:
-    def __init__(
-        self,
-        save_dir: str,
-        model_type: Type[PunctuationModel],
-        optimizer_type: Type[Optimizer],
-        device: str,
-    ) -> None:
-        self.save_dir = save_dir
-        self.device = device
-
-        self.model_type = model_type
-        self.optimizer_type = optimizer_type
-
-    def has_checkpoints(self) -> bool:
-        files = glob(f"{self.save_dir}/*.model")
-
-        return latest_model(files) is not None
-
-    def load(self, name) -> Tuple[PunctuationModel, Optimizer]:
-        model = self.model_type.load(self.save_dir, name)
-
-        optimizer = self.optimizer_type(model.parameters())
-        optimizer.load_state_dict(
-            torch.load(f"{self.save_dir}/{name}.optimizer", map_location=self.device)
-        )
-
-        return model, optimizer
-
-    def load_latest(self) -> Tuple[PunctuationModel, Optimizer, int, int]:
-        files = glob(f"{self.save_dir}/*.model")
-
-        model_id = latest_model(files)
-        if model_id is None:
-            return None
-
-        epoch, step = model_id
-        return self.load(f"{epoch}-{step}"), epoch, step
-
-
-class Checkpoint:
-    def __init__(
-        self, save_step, saver: Saver, start_step: int, start_epoch: int
-    ) -> None:
-        self.start_step = start_step
-        self.start_epoch = start_epoch
-        self.save_step = save_step
-
-        self.saver = saver
-
-    def step(self, epoch, step) -> None:
-        if step % self.save_step == 0 and (
-            step != self.start_step or epoch != self.start_epoch
-        ):
-            print(f"Saving: Epoch {epoch}, step {step}")
-            self.saver.save(f"{epoch}-{step}")
-
-
-class Timeout:
-    def __init__(self, duration: timedelta, saver: Optional[Saver]) -> None:
-        self.saver = saver
-        self.duration = duration
-        self.time_max = None
-
-    def start(self, time_now: datetime = datetime.now()):
-        self.time_max = datetime.max
-        if self.duration is not None:
-            self.time_max = datetime.now() + self.max_train_time
-
-    def step(self, epoch, step, time=None) -> bool:
-        assert self.time_max is not None
-
-        if time is None:
-            time = datetime.now()
-
-        if time > self.time_max:
-            if self.checkpoint is not None:
-                print(f"Max time reached, saving: Epoch {epoch}, step {step}")
-                self.saver.save(f"{epoch}-{step}")
-
-            return True
-
-        return False
-
-
-class ProgressTracker:
-    def __init__(self, device: str, loss_averaging_span) -> None:
-        print(f"Training on {device}")
-        self.loss_averaging_span = loss_averaging_span
-        self.losses = []
-
-    def step(self, epoch, step, loss) -> None:
-        self.losses.append(loss.item())
-        loss_mean, self.losses = moving_average(self.losses, self.loss_averaging_span)
-
-        print(f"epoch: {epoch} | step: {step} | loss: {loss_mean}")
diff --git a/src/pipelines/actions_based/train_base.py b/src/pipelines/actions_based/train_base.py
index 15da074..0b45bac 100755
--- a/src/pipelines/actions_based/train_base.py
+++ b/src/pipelines/actions_based/train_base.py
@@ -7,7 +7,6 @@ import numpy as np
 import torch
 from transformers import BertTokenizerFast
 
-from src.checkpoints import Checkpoint, Loader, ProgressTracker, Saver, Timeout
 from src.models.actions_model_base import (
     ActionsModelBase,
     ActionsModelBaseLoss,
@@ -16,6 +15,11 @@ from src.models.actions_model_base import (
 from src.pipelines.actions_based.processing import ACTIONS_KEYS
 from src.utils import (
     PROJECT_ROOT,
+    Checkpoint,
+    Loader,
+    ProgressTracker,
+    Saver,
+    Timeout,
     convert_to_timedelta,
     get_config,
     random_indexes,
diff --git a/src/pipelines/actions_based/train_mixed.py b/src/pipelines/actions_based/train_mixed.py
index 806cdfd..1405d64 100755
--- a/src/pipelines/actions_based/train_mixed.py
+++ b/src/pipelines/actions_based/train_mixed.py
@@ -7,7 +7,6 @@ import numpy as np
 import torch
 from transformers import BertTokenizerFast
 
-from src.checkpoints import Checkpoint, Loader, ProgressTracker, Saver, Timeout
 from src.models.actions_model_mixed import (
     ActionsModelMixed,
     ActionsModelMixedLoss,
@@ -16,6 +15,11 @@ from src.models.actions_model_mixed import (
 from src.pipelines.actions_based.processing import ACTIONS_KEYS
 from src.utils import (
     PROJECT_ROOT,
+    Checkpoint,
+    Loader,
+    ProgressTracker,
+    Saver,
+    Timeout,
     convert_to_timedelta,
     get_config,
     random_indexes,
diff --git a/src/pipelines/actions_based/train_restricted.py b/src/pipelines/actions_based/train_restricted.py
index 7856695..36edb69 100755
--- a/src/pipelines/actions_based/train_restricted.py
+++ b/src/pipelines/actions_based/train_restricted.py
@@ -7,7 +7,6 @@ import numpy as np
 import torch
 from transformers import BertTokenizerFast
 
-from src.checkpoints import Checkpoint, Loader, ProgressTracker, Saver, Timeout
 from src.models.actions_model_restricted import (
     ActionsModelRestricted,
     ActionsModelRestrictedLoss,
@@ -16,6 +15,11 @@ from src.models.actions_model_restricted import (
 from src.pipelines.actions_based.processing import ACTIONS_KEYS
 from src.utils import (
     PROJECT_ROOT,
+    Checkpoint,
+    Loader,
+    ProgressTracker,
+    Saver,
+    Timeout,
     convert_to_timedelta,
     get_config,
     random_indexes,
diff --git a/src/pipelines/translation_based/train.py b/src/pipelines/translation_based/train.py
index 6e39ecc..67fcf7a 100755
--- a/src/pipelines/translation_based/train.py
+++ b/src/pipelines/translation_based/train.py
@@ -10,8 +10,14 @@ from transformers import BertTokenizerFast
 
 from src.batch_loading import get_batches, get_ordered_dataframe_len
 from src.models.TransformerSeq2Seq import TransformerSeq2Seq
-from src.training import latest_model, save_training_step
-from src.utils import PROJECT_ROOT, convert_to_timedelta, get_config, prepare_folder
+from src.utils import (
+    PROJECT_ROOT,
+    convert_to_timedelta,
+    get_config,
+    latest_model,
+    prepare_folder,
+    save_training_step,
+)
 
 INPUT_PATH = f"{PROJECT_ROOT}/generated/translations/stage4_reindexing"
 OUTPUT_PATH = f"{PROJECT_ROOT}/checkpoints/translations"
diff --git a/src/training.py b/src/training.py
deleted file mode 100644
index d74d322..0000000
--- a/src/training.py
+++ /dev/null
@@ -1,66 +0,0 @@
-import re
-from typing import List, Optional, Tuple
-
-import torch
-import torch.nn as nn
-import torch.optim as optim
-
-from src.utils import prepare_folder
-
-
-def latest_model(file_paths: List[str]) -> Optional[Tuple[int, int]]:
-    """Finds newest model in directory
-
-    Args:
-        files ([str]): List of all file paths that will be considered. File extension is discarded
-                       File names must be in format epoch_num-batch_num.extension
-
-    Returns:
-        (int, int): Tuple of (latest_batch, latest_step) for latest model
-    """
-
-    furthest_epoch = -1
-    furthest_batch_num = -1
-    for checkpoint_file in file_paths:
-        filename = checkpoint_file.split("/")[-1].split(".")[0]
-
-        result = re.search(r"^(\d+)-(\d+)$", filename)
-        if result is not None:
-            epoch, batch = [int(x) for x in result.groups()]
-
-            if epoch > furthest_epoch:
-                furthest_epoch = epoch
-                furthest_batch_num = batch
-            elif epoch == furthest_epoch:
-                furthest_batch_num = max(batch, furthest_batch_num)
-
-    if (furthest_epoch == -1) or (furthest_batch_num == -1):
-        return None
-
-    return furthest_epoch, furthest_batch_num
-
-
-def save_training_step(
-    dir: str,
-    name: str,
-    model: nn.Module,
-    optimizer: Optional[optim.Optimizer] = None,
-    create_dir: bool = False,
-) -> None:
-    """Saves a trainig step to a directory
-
-    Args:
-        dir (str): Directory where step will be saved
-        name (str): Name of the step (eg. "0-1000")
-        model (nn.Module): model that will be saved
-        optimizer (optim.Optimizer): optimizer that will be saved. Might be None
-    """
-    if create_dir:
-        prepare_folder(dir, wipe=False)
-
-    torch.save(model.state_dict(), f"{dir}/{name}.model")
-
-    if optimizer is not None:
-        torch.save(
-            optimizer.state_dict(), f"{dir}/{name}.optimizer",
-        )
diff --git a/src/utils.py b/src/utils.py
index 8a9a09e..27e1367 100644
--- a/src/utils.py
+++ b/src/utils.py
@@ -1,21 +1,140 @@
+from __future__ import annotations
+
 import os
 import pickle
 import re
 import shutil
-from datetime import timedelta
-from typing import Generator, List, Optional, Tuple
+from datetime import datetime, timedelta
+from glob import glob
+from typing import Generator, List, Optional, Tuple, Type
 
 import dask.dataframe as dd
 import numpy as np
 import pandas as pd
 import torch
+import torch.nn as nn
 import yaml
+from torch.optim import Optimizer
 
 from src.batch_loading import get_batches, get_ordered_dataframe_len
+from src.models.interfaces import PunctuationModel
 
 PROJECT_ROOT = os.path.dirname(os.path.realpath("/".join(__file__.split("/")) + "/.."))
 
 
+class Saver:
+    def __init__(
+        self, save_dir: str, model: PunctuationModel, optimizer: Optimizer
+    ) -> None:
+        self.save_dir = save_dir
+        self.model = model
+        self.optimizer = optimizer
+
+        prepare_folder(self.save_dir)
+
+    def save(self, name):
+        self.model.save(self.save_dir, name)
+        torch.save(self.optimizer.state_dict(), f"{self.save_dir}/{name}.optimizer")
+
+
+class Loader:
+    def __init__(
+        self,
+        save_dir: str,
+        model_type: Type[PunctuationModel],
+        optimizer_type: Type[Optimizer],
+        device: str,
+    ) -> None:
+        self.save_dir = save_dir
+        self.device = device
+
+        self.model_type = model_type
+        self.optimizer_type = optimizer_type
+
+    def has_checkpoints(self) -> bool:
+        files = glob(f"{self.save_dir}/*.model")
+
+        return latest_model(files) is not None
+
+    def load(self, name) -> Tuple[PunctuationModel, Optimizer]:
+        model = self.model_type.load(self.save_dir, name)
+
+        optimizer = self.optimizer_type(model.parameters())
+        optimizer.load_state_dict(
+            torch.load(f"{self.save_dir}/{name}.optimizer", map_location=self.device)
+        )
+
+        return model, optimizer
+
+    def load_latest(self) -> Tuple[PunctuationModel, Optimizer, int, int]:
+        files = glob(f"{self.save_dir}/*.model")
+
+        model_id = latest_model(files)
+        if model_id is None:
+            return None
+
+        epoch, step = model_id
+        return self.load(f"{epoch}-{step}"), epoch, step
+
+
+class Checkpoint:
+    def __init__(
+        self, save_step, saver: Saver, start_step: int, start_epoch: int
+    ) -> None:
+        self.start_step = start_step
+        self.start_epoch = start_epoch
+        self.save_step = save_step
+
+        self.saver = saver
+
+    def step(self, epoch, step) -> None:
+        if step % self.save_step == 0 and (
+            step != self.start_step or epoch != self.start_epoch
+        ):
+            print(f"Saving: Epoch {epoch}, step {step}")
+            self.saver.save(f"{epoch}-{step}")
+
+
+class Timeout:
+    def __init__(self, duration: timedelta, saver: Optional[Saver]) -> None:
+        self.saver = saver
+        self.duration = duration
+        self.time_max = None
+
+    def start(self, time_now: Optional[datetime] = None):
+        # Default resolved at call time, not at function definition
+        if time_now is None:
+            time_now = datetime.now()
+        self.time_max = datetime.max
+        if self.duration is not None:
+            self.time_max = time_now + self.duration
+
+    def step(self, epoch, step, time=None) -> bool:
+        assert self.time_max is not None
+
+        if time is None:
+            time = datetime.now()
+
+        if time > self.time_max:
+            if self.saver is not None:
+                print(f"Max time reached, saving: Epoch {epoch}, step {step}")
+                self.saver.save(f"{epoch}-{step}")
+
+            return True
+
+        return False
+
+
+class ProgressTracker:
+    def __init__(self, device: str, loss_averaging_span) -> None:
+        print(f"Training on {device}")
+        self.loss_averaging_span = loss_averaging_span
+        self.losses = []
+
+    def step(self, epoch, step, loss) -> None:
+        self.losses.append(loss.item())
+        loss_mean, self.losses = moving_average(self.losses, self.loss_averaging_span)
+
+        print(f"epoch: {epoch} | step: {step} | loss: {loss_mean}")
+
+
 def get_config() -> dict:
     """Returns dict with config values
 
@@ -275,3 +394,61 @@ def convert_to_timedelta(time_val: Optional[str]) -> Optional[timedelta]:
         return timedelta(days=num)
     else:
         return None
+
+
+def latest_model(file_paths: List[str]) -> Optional[Tuple[int, int]]:
+    """Finds newest model in directory
+
+    Args:
+        file_paths (List[str]): List of all file paths that will be considered.
+                                The file extension is discarded; file names must
+                                be in the format epoch_num-batch_num.extension
+
+    Returns:
+        (int, int): Tuple of (latest_epoch, latest_batch) for the latest model
+    """
+
+    furthest_epoch = -1
+    furthest_batch_num = -1
+    for checkpoint_file in file_paths:
+        filename = checkpoint_file.split("/")[-1].split(".")[0]
+
+        result = re.search(r"^(\d+)-(\d+)$", filename)
+        if result is not None:
+            epoch, batch = [int(x) for x in result.groups()]
+
+            if epoch > furthest_epoch:
+                furthest_epoch = epoch
+                furthest_batch_num = batch
+            elif epoch == furthest_epoch:
+                furthest_batch_num = max(batch, furthest_batch_num)
+
+    if (furthest_epoch == -1) or (furthest_batch_num == -1):
+        return None
+
+    return furthest_epoch, furthest_batch_num
+
+
+def save_training_step(
+    dir: str,
+    name: str,
+    model: nn.Module,
+    optimizer: Optional[Optimizer] = None,
+    create_dir: bool = False,
+) -> None:
+    """Saves a trainig step to a directory
+
+    Args:
+        dir (str): Directory where step will be saved
+        name (str): Name of the step (eg. "0-1000")
+        model (nn.Module): model that will be saved
+        optimizer (optim.Optimizer): optimizer that will be saved. Might be None
+    """
+    if create_dir:
+        prepare_folder(dir, wipe=False)
+
+    torch.save(model.state_dict(), f"{dir}/{name}.model")
+
+    if optimizer is not None:
+        torch.save(
+            optimizer.state_dict(), f"{dir}/{name}.optimizer",
+        )
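
Note: Saver and Loader standardize on the {epoch}-{step} file naming that latest_model parses. A minimal stand-alone sketch of the same save/load round trip with plain torch (tempfile used only to keep the example self-contained):

    import tempfile

    import torch

    model = torch.nn.Linear(2, 2)
    optimizer = torch.optim.AdamW(model.parameters())

    with tempfile.TemporaryDirectory() as d:
        name = "0-1000"  # epoch-step, the format latest_model expects
        torch.save(model.state_dict(), f"{d}/{name}.model")
        torch.save(optimizer.state_dict(), f"{d}/{name}.optimizer")

        model.load_state_dict(torch.load(f"{d}/{name}.model", map_location="cpu"))
        optimizer.load_state_dict(torch.load(f"{d}/{name}.optimizer", map_location="cpu"))
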
diff --git a/tests/test_training.py b/tests/test_training.py
deleted file mode 100644
index 2aa5d6a..0000000
--- a/tests/test_training.py
+++ /dev/null
@@ -1,21 +0,0 @@
-from src.training import latest_model
-
-
-def test_latest_model():
-    files = []
-    assert latest_model(files) is None
-
-    files.append("/path/tam/pam/Wrongformat.b")
-    assert latest_model(files) is None
-
-    files.append("/path/tam/pam/0-2000.b")
-    assert latest_model(files) == (0, 2000)
-
-    files.append("/path/tam/pam/0-3000.c")
-    assert latest_model(files) == (0, 3000)
-
-    files.append("/path/tam/pam/1-1000.a")
-    assert latest_model(files) == (1, 1000)
-
-    files.append("/path/tam/pam/1-500.a")
-    assert latest_model(files) == (1, 1000)
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 0702891..49fbc4a 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -1,6 +1,7 @@
 from src.utils import (
     convert_to_timedelta,
     input_preprocess,
+    latest_model,
     output_preprocess,
     remove_multiple_spaces,
     remove_punctuation,
@@ -56,3 +57,23 @@ def test_convert_to_timedelta():
     assert convert_to_timedelta("2s").days == 0
     assert convert_to_timedelta("2s").seconds == 2
     assert convert_to_timedelta("2s").microseconds == 0
+
+
+def test_latest_model():
+    files = []
+    assert latest_model(files) is None
+
+    files.append("/path/tam/pam/Wrongformat.b")
+    assert latest_model(files) is None
+
+    files.append("/path/tam/pam/0-2000.b")
+    assert latest_model(files) == (0, 2000)
+
+    files.append("/path/tam/pam/0-3000.c")
+    assert latest_model(files) == (0, 3000)
+
+    files.append("/path/tam/pam/1-1000.a")
+    assert latest_model(files) == (1, 1000)
+
+    files.append("/path/tam/pam/1-500.a")
+    assert latest_model(files) == (1, 1000)
-- 
GitLab


From 04b65b601adc7fb5fb5b4a4953da828cf5c9f2a2 Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Wed, 26 Aug 2020 13:14:01 +0200
Subject: [PATCH 083/116] Fixed model loading

---
 params.yaml                                     | 4 ++--
 src/models/actions_model_base.py                | 5 ++---
 src/models/actions_model_mixed.py               | 5 ++---
 src/models/actions_model_restricted.py          | 3 +--
 src/models/interfaces.py                        | 2 +-
 src/pipelines/actions_based/train_base.py       | 5 ++---
 src/pipelines/actions_based/train_mixed.py      | 6 ++----
 src/pipelines/actions_based/train_restricted.py | 6 ++----
 src/utils.py                                    | 8 +++++---
 9 files changed, 19 insertions(+), 25 deletions(-)

diff --git a/params.yaml b/params.yaml
index c965809..500eeff 100644
--- a/params.yaml
+++ b/params.yaml
@@ -32,10 +32,10 @@ actions:
         num_epochs: 5
         batch_size: 2
         batch_buffer_size: 100
-        save_step: 1000
+        save_step: 50
         max_training_time: null
         loss_averaging_span: 1000
-        fresh_start: true
+        fresh_start: false
         device: "cuda:0"
 
     training_restricted:
diff --git a/src/models/actions_model_base.py b/src/models/actions_model_base.py
index d0c4a6f..4c45125 100644
--- a/src/models/actions_model_base.py
+++ b/src/models/actions_model_base.py
@@ -71,9 +71,8 @@ class ActionsModelBase(PunctuationModel):
     @staticmethod
     def load(dir: str, name: str, device: str) -> ActionsModelBase:
         params = pickle_read(f"{dir}/{name}.config")
-        model = ActionsModelBase(params)
-
-        model.load_state_dict(torch.load(f"{dir}/{name}.model", map_location=device,))
+        model = ActionsModelBase(params).to(device)
+        model.load_state_dict(torch.load(f"{dir}/{name}.model", map_location=device))
 
         return model
 
diff --git a/src/models/actions_model_mixed.py b/src/models/actions_model_mixed.py
index ced75a7..b489b7f 100644
--- a/src/models/actions_model_mixed.py
+++ b/src/models/actions_model_mixed.py
@@ -221,9 +221,8 @@ class ActionsModelMixed(PunctuationModel):
     @staticmethod
     def load(dir: str, name: str, device: str) -> PunctuationModel:
         params = pickle_read(f"{dir}/{name}.config")
-        model = ActionsModelMixed(params)
-
-        model.load_state_dict(torch.load(f"{dir}/{name}.model", map_location=device,))
+        model = ActionsModelMixed(params).to(device)
+        model.load_state_dict(torch.load(f"{dir}/{name}.model", map_location=device))
 
         return model
 
diff --git a/src/models/actions_model_restricted.py b/src/models/actions_model_restricted.py
index de8d338..4b1097f 100644
--- a/src/models/actions_model_restricted.py
+++ b/src/models/actions_model_restricted.py
@@ -80,8 +80,7 @@ class ActionsModelRestricted(PunctuationModel):
     @staticmethod
     def load(dir: str, name: str, device: str) -> ActionsModelRestricted:
         params = pickle_read(f"{dir}/{name}.config")
-        model = ActionsModelRestricted(params)
-
+        model = ActionsModelRestricted(params).to(device)
         model.load_state_dict(torch.load(f"{dir}/{name}.model", map_location=device,))
 
         return model
diff --git a/src/models/interfaces.py b/src/models/interfaces.py
index f54df51..09e26d6 100644
--- a/src/models/interfaces.py
+++ b/src/models/interfaces.py
@@ -7,7 +7,7 @@ import torch.nn as nn
 
 class PunctuationModel(nn.Module, ABC):
     def __init__(self) -> None:
-        super().__init__()
+        super(PunctuationModel, self).__init__()
 
     @abstractmethod
     def save(self, dir: str, name: str) -> None:
diff --git a/src/pipelines/actions_based/train_base.py b/src/pipelines/actions_based/train_base.py
index 0b45bac..b8c8bce 100755
--- a/src/pipelines/actions_based/train_base.py
+++ b/src/pipelines/actions_based/train_base.py
@@ -59,13 +59,12 @@ if __name__ == "__main__":
         model, optimizer, epoch_start, sample_start = loader.load_latest()
     else:
         params = ActionsModelBaseParams(base_model, len(ACTIONS_KEYS))
-        model = ActionsModelBase(params)
+        model = ActionsModelBase(params).to(device)
 
-        optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
+        optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)  # params already on device; optimizers have no .to()
         epoch_start, sample_start = (0, 0)
 
     model.train()
-    model.to(device)
 
     # Load loss weights
     with open(f"{INPUT_STATS_PATH}/stats.pickle", "rb") as f:
diff --git a/src/pipelines/actions_based/train_mixed.py b/src/pipelines/actions_based/train_mixed.py
index 1405d64..64a82d8 100755
--- a/src/pipelines/actions_based/train_mixed.py
+++ b/src/pipelines/actions_based/train_mixed.py
@@ -74,13 +74,11 @@ if __name__ == "__main__":
             500,
             dropout,
         )
-        model = ActionsModelMixed(params)
-
-        optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
+        model = ActionsModelMixed(params).to(device)
+        optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)  # params already on device; optimizers have no .to()
         epoch_start, sample_start = (0, 0)
 
     model.train()
-    model.to(device)
 
     # Load loss weights
     with open(f"{INPUT_STATS_PATH}/stats.pickle", "rb") as f:
diff --git a/src/pipelines/actions_based/train_restricted.py b/src/pipelines/actions_based/train_restricted.py
index 36edb69..29118dc 100755
--- a/src/pipelines/actions_based/train_restricted.py
+++ b/src/pipelines/actions_based/train_restricted.py
@@ -62,13 +62,11 @@ if __name__ == "__main__":
         model, optimizer, epoch_start, sample_start = loader.load_latest()
     else:
         params = ActionsModelRestrictedParams(base_model, len(ACTIONS_KEYS) + 1)
-        model = ActionsModelRestricted(params)
-
-        optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
+        model = ActionsModelRestricted(params).to(device)
+        optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)  # params already on device; optimizers have no .to()
         epoch_start, sample_start = (0, 0)
 
     model.train()
-    model.to(device)
 
     # Load loss weights
     with open(f"{INPUT_STATS_PATH}/stats.pickle", "rb") as f:
diff --git a/src/utils.py b/src/utils.py
index 27e1367..50d7007 100644
--- a/src/utils.py
+++ b/src/utils.py
@@ -57,13 +57,15 @@ class Loader:
         return latest_model(files) is not None
 
     def load(self, name) -> Tuple[PunctuationModel, Optimizer]:
-        model = self.model_type.load(self.save_dir, name)
+        model = self.model_type.load(self.save_dir, name, self.device)
 
         optimizer = self.optimizer_type(model.parameters())
         optimizer.load_state_dict(
             torch.load(f"{self.save_dir}/{name}.optimizer", map_location=self.device)
         )
 
+        print(f"Loaded model {name}")
+
         return model, optimizer
 
     def load_latest(self) -> Tuple[PunctuationModel, Optimizer, int, int]:
@@ -74,12 +76,12 @@ class Loader:
             return None
 
         epoch, step = model_id
-        return self.load(f"{epoch}-{step}"), epoch, step
+        return *(self.load(f"{epoch}-{step}")), epoch, step
 
 
 class Checkpoint:
     def __init__(
-        self, save_step, saver: Saver, start_step: int, start_epoch: int
+        self, save_step, saver: Saver, start_epoch: int, start_step: int
     ) -> None:
         self.start_step = start_step
         self.start_epoch = start_epoch
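
Note: the star added in load_latest matters because Loader.load returns a (model, optimizer) pair; without unpacking, callers that do `model, optimizer, epoch_start, sample_start = loader.load_latest()` receive a nested tuple of the wrong arity. A toy illustration:

    def load():
        return "model", "optimizer"

    nested = load(), 3, 500     # (("model", "optimizer"), 3, 500) - wrong arity
    flat = (*load(), 3, 500)    # ("model", "optimizer", 3, 500)

    model, optimizer, epoch, step = flat  # unpacks cleanly
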
-- 
GitLab


From 162085b4fe8b22765ff62ee597ae9d2b264ddb0c Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Wed, 26 Aug 2020 15:57:10 +0200
Subject: [PATCH 084/116] Testing now works for all actions models

---
 src/models/actions_model_base.py         | 21 +++++-
 src/models/actions_model_mixed.py        | 62 ++++++++---------
 src/models/actions_model_restricted.py   | 23 ++++++-
 src/models/interfaces.py                 | 23 ++++++-
 src/pipelines/actions_based/test.py      | 88 ++++++++++++++++++++++++
 tests/models/test_actions_model_mixed.py |  2 +
 6 files changed, 182 insertions(+), 37 deletions(-)
 create mode 100644 src/pipelines/actions_based/test.py

diff --git a/src/models/actions_model_base.py b/src/models/actions_model_base.py
index 4c45125..f23ecc0 100644
--- a/src/models/actions_model_base.py
+++ b/src/models/actions_model_base.py
@@ -7,8 +7,9 @@ import torch.nn as nn
 from torch.nn.modules.loss import BCEWithLogitsLoss
 from transformers.configuration_utils import PretrainedConfig
 from transformers.modeling_bert import BertForTokenClassification
+from transformers.tokenization_bert import BertTokenizerFast
 
-from src.models.interfaces import PunctuationModel
+from src.models.interfaces import ActionsModel
 from src.pipelines.actions_based.processing import ACTIONS_KEYS
 from src.utils import pickle_read, pickle_save, prepare_folder
 
@@ -28,7 +29,7 @@ class ActionsModelBaseParams:
     num_labels: int = len(ACTIONS_KEYS)
 
 
-class ActionsModelBase(PunctuationModel):
+class ActionsModelBase(ActionsModel):
     """Model based on simple multilabel per-token classifiaction. Each token is binarly classified in n-dimensions"""
 
     def __init__(self, params: ActionsModelBaseParams) -> None:
@@ -41,6 +42,7 @@ class ActionsModelBase(PunctuationModel):
         super(ActionsModelBase, self).__init__()
         self.params = params
 
+        self.tokenizer = BertTokenizerFast.from_pretrained(params.base_model)
         config = PretrainedConfig.from_pretrained(params.base_model)
         config.num_labels = params.num_labels
 
@@ -63,6 +65,21 @@ class ActionsModelBase(PunctuationModel):
 
         return y_pred
 
+    def predict_raw(
+        self, input_ids: torch.Tensor, attention_mask: torch.Tensor
+    ) -> torch.Tensor:
+        """Function that maps input_ids tensors into per-token labels
+
+        Args:
+            input_ids (torch.Tensor): Token ids of input. Shape BxL
+            attention_mask (torch.Tensor): Attention mask of tokens. Shape BxL
+
+        Returns:
+            torch.Tensor: Per-token action-vector labels. Shape BxLxA
+        """
+
+        return self.forward(input_ids, attention_mask=attention_mask).sigmoid()
+
     def save(self, dir: str, name: str) -> None:
         prepare_folder(dir)
         torch.save(self.state_dict(), f"{dir}/{name}.model")
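
Note: a usage sketch of the new predict_raw, mirroring the repo's tests. The checkpoint name is a stand-in (the project reads base_model from its config), and running this downloads the pretrained weights:

    from transformers import BertTokenizerFast

    from src.models.actions_model_base import ActionsModelBase, ActionsModelBaseParams
    from src.pipelines.actions_based.processing import ACTIONS_KEYS

    base_model = "bert-base-multilingual-cased"  # stand-in checkpoint name
    tokenizer = BertTokenizerFast.from_pretrained(base_model)
    tokens = tokenizer("Ala ma kota", return_tensors="pt")

    model = ActionsModelBase(ActionsModelBaseParams(base_model, len(ACTIONS_KEYS)))
    probs = model.predict_raw(tokens["input_ids"], tokens["attention_mask"])
    print(probs.shape)  # BxLxA: one independent sigmoid probability per action
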
diff --git a/src/models/actions_model_mixed.py b/src/models/actions_model_mixed.py
index b489b7f..5e1fc5f 100644
--- a/src/models/actions_model_mixed.py
+++ b/src/models/actions_model_mixed.py
@@ -35,6 +35,7 @@ class ActionsModelMixedParams:
     """
 
     vocab_size: int
+    threshold: float = 0.9
     embedding_size: int = 200
     num_heads: int = 4
     num_layers: int = 2
@@ -140,38 +141,6 @@ class ActionsModelMixed(PunctuationModel):
 
         super(ActionsModelMixed, self).to(device)
 
-    def predict_raw(
-        self,
-        input_ids: torch.Tensor,
-        attention_mask: torch.Tensor,
-        threshold: float = 0.9,
-        max_cond_len: Optional[int] = None,
-    ):
-        target_device = self.device
-
-        input_ids = input_ids.to(target_device)
-        attention_mask = (attention_mask == 0).to(target_device)
-        inputs = torch.tensor([action_vector(["upper_case"])], dtype=torch.float).to(
-            target_device
-        )
-
-        if max_cond_len is None:
-            max_cond_len = np.iinfo(np.int).max
-
-        for _ in range(input_ids.shape[1] - 2):
-            input_start = max(0, len(inputs) - max_cond_len)
-
-            prediction_raw = self.forward(
-                input_ids[:, input_start:],
-                inputs[:, input_start:].reshape(1, -1, self.num_labels),
-                attention_mask,
-            ).sigmoid()
-
-            new_output = (prediction_raw[0, -1:, :] > threshold).astype(torch.float)
-            inputs = torch.stack([inputs, new_output], dim=1)
-
-        return inputs
-
     def predict(
         self,
         text: str,
@@ -213,6 +182,35 @@ class ActionsModelMixed(PunctuationModel):
 
         return recover_text(text, prediction_binary)
 
+    def predict_raw(
+        self, input_ids: torch.Tensor, attention_mask: torch.Tensor
+    ) -> torch.Tensor:
+        """Function that maps input_ids tensors into per-token labels
+
+        Args:
+            input_ids (torch.Tensor): Token ids of input. Shape BxL
+            attention_mask (torch.Tensor): Attention mask of tokens. Shape BxL
+
+        Returns:
+            torch.Tensor: Per-token action-vector labels. Shape BxLxA
+        """
+        outputs = torch.tensor(action_vector(["upper_case"]), dtype=torch.float).to(
+            input_ids.device
+        )
+        outputs = outputs.unsqueeze(0).unsqueeze(0).repeat(input_ids.shape[0], 1, 1)
+
+        for _ in range(input_ids.shape[1] - 1):
+            prediction_raw = self.forward(
+                input_ids, outputs, (attention_mask == 0)
+            ).sigmoid()
+
+            prediction_raw = (prediction_raw[:, -1:, :] > self.params.threshold).type(
+                torch.float
+            )
+            outputs = torch.cat([outputs, prediction_raw], dim=1)
+
+        return outputs
+
     def save(self, dir: str, name: str) -> None:
         prepare_folder(dir)
         torch.save(self.state_dict(), f"{dir}/{name}.model")
diff --git a/src/models/actions_model_restricted.py b/src/models/actions_model_restricted.py
index 4b1097f..12b7f16 100644
--- a/src/models/actions_model_restricted.py
+++ b/src/models/actions_model_restricted.py
@@ -8,7 +8,7 @@ from transformers.configuration_utils import PretrainedConfig
 from transformers.modeling_bert import BertForTokenClassification
 
 from src.models.actions_model_mixed import ActionsModelMixed
-from src.models.interfaces import PunctuationModel
+from src.models.interfaces import ActionsModel, PunctuationModel
 from src.utils import pickle_read, pickle_save, prepare_folder
 
 
@@ -26,7 +26,7 @@ class ActionsModelRestrictedParams:
     extended_action_vector_size: int
 
 
-class ActionsModelRestricted(PunctuationModel):
+class ActionsModelRestricted(ActionsModel):
     """Similar to ActionsModelBase, however no-punctuation class is added
     and punctuation-related entries are treaded as proper categorical distribution
     """
@@ -64,6 +64,25 @@ class ActionsModelRestricted(PunctuationModel):
 
         return y_pred
 
+    def predict_raw(
+        self, input_ids: torch.Tensor, attention_mask: torch.Tensor
+    ) -> torch.Tensor:
+        """Function that maps input_ids tensors into per-token labels
+
+        Args:
+            input_ids (torch.Tensor): Token ids of input. Shape BxL
+            attention_mask (torch.Tensor): Attention mask of tokens. Shape BxL
+
+        Returns:
+            torch.Tensor: Per-token action-vector labels. Shape BxLxA
+        """
+
+        logits = self.forward(input_ids, attention_mask=attention_mask)
+        prob_uppercase = logits[:, :, :1].sigmoid()
+        prob_punctuation = logits[:, :, 1:].softmax(dim=-1)[:, :, :-1]
+
+        return torch.cat([prob_uppercase, prob_punctuation], dim=-1)
+
     @staticmethod
     def _logit(x: torch.Tensor):
         EPS = 1e-5
diff --git a/src/models/interfaces.py b/src/models/interfaces.py
index 09e26d6..9fb15a7 100644
--- a/src/models/interfaces.py
+++ b/src/models/interfaces.py
@@ -2,12 +2,13 @@ from __future__ import annotations
 
 from abc import ABC, abstractmethod
 
+import torch
 import torch.nn as nn
 
 
 class PunctuationModel(nn.Module, ABC):
     def __init__(self) -> None:
-        super(PunctuationModel, self).__init__()
+        super().__init__()
 
     @abstractmethod
     def save(self, dir: str, name: str) -> None:
@@ -17,3 +18,23 @@ class PunctuationModel(nn.Module, ABC):
     @abstractmethod
     def load(dir: str, name: str, device: str) -> PunctuationModel:
         pass
+
+
+class ActionsModel(PunctuationModel):
+    def __init__(self) -> None:
+        super().__init__()
+
+    @abstractmethod
+    def predict_raw(
+        self, input_ids: torch.Tensor, attention_mask: torch.Tensor
+    ) -> torch.Tensor:
+        """Function that maps input_ids tensors into per-token labels
+
+        Args:
+            input_ids (torch.Tensor): Token ids of input. Shape BxL
+            attention_mask (torch.Tensor): Attention mask of tokens. Shape BxL
+
+        Returns:
+            torch.Tensor: Per-token action-vector labels. Shape BxLxA
+        """
+        pass
diff --git a/src/pipelines/actions_based/test.py b/src/pipelines/actions_based/test.py
new file mode 100644
index 0000000..d8250b6
--- /dev/null
+++ b/src/pipelines/actions_based/test.py
@@ -0,0 +1,88 @@
+import dask.dataframe as dd
+import numpy as np
+import torch
+from sklearn.metrics import f1_score
+from tqdm import trange
+from transformers.tokenization_bert import BertTokenizerFast
+
+from src.batch_loading import get_ordered_dataframe_len
+from src.models.actions_model_mixed import ActionsModelMixed, ActionsModelMixedParams
+from src.pipelines.actions_based.processing import ACTIONS_KEYS
+from src.pipelines.actions_based.scoring import predictions_threshold
+from src.utils import PROJECT_ROOT, unflattened_column
+
+NUM_CHECK = 1
+BATCH_SIZE = 1
+TEST_DATASET = f"{PROJECT_ROOT}/generated/actions/stage4_reindexing"
+
+
+class Metrics:
+    def __init__(self, name: str) -> None:
+        self.name = name
+
+    def compute_metrics(self, predictions: np.ndarray, targets: np.ndarray):
+        f1_scores = self._f1_scores(predictions, targets)
+
+        print(f"Model {self.name} | F1 scores")
+        print("----------------------")
+        print(dict(zip(ACTIONS_KEYS, f1_scores)))
+        print("----------------------")
+
+    def _f1_scores(self, predictions: np.ndarray, targets: np.ndarray) -> dict:
+        predictions = predictions_threshold(predictions, 0.0)
+        return f1_score(predictions, targets, average=None)
+
+
+if __name__ == "__main__":
+    print("Getting dataset info...")
+    df = dd.read_parquet(TEST_DATASET, engine="pyarrow")
+
+    print("Loading dataset to memory...")
+    df_len = get_ordered_dataframe_len(df)
+
+    data_start = df_len - NUM_CHECK
+    data_end = df_len
+    pdf = df.loc[data_start:data_end].compute().reset_index()
+
+    device = torch.device("cpu")
+
+    tokenizer = BertTokenizerFast.from_pretrained("dkleczek/bert-base-polish-cased-v1")
+
+    params = ActionsModelMixedParams(tokenizer.vocab_size)
+    model = ActionsModelMixed(params)
+
+    true_batches = []
+    prediction_batches = []
+
+    print("Computing...")
+    num_batches = len(pdf) // BATCH_SIZE
+    for batch in trange(num_batches):
+
+        batch_start = batch * BATCH_SIZE
+        batch_end = (batch + 1) * BATCH_SIZE
+        batch_pdf = pdf.iloc[batch_start:batch_end]
+
+        inputs = unflattened_column(batch_pdf, "source")
+        outputs = unflattened_column(batch_pdf, "target")
+        attentions_mask = unflattened_column(batch_pdf, "attention_mask")
+
+        inputs = torch.tensor(inputs, dtype=torch.long).squeeze(dim=-1).to(device)
+        outputs = torch.tensor(outputs, dtype=torch.float).to(device)
+        attentions_mask = torch.tensor(attentions_mask).to(device)
+
+        prediction_batch = (
+            model.predict_raw(inputs, attentions_mask).detach().cpu().numpy()
+        )
+        prediction_batches.append(prediction_batch)
+
+        true_batches.append(outputs.cpu().numpy())
+
+    predictions = np.concatenate(prediction_batches, axis=0).reshape(
+        -1, prediction_batches[0].shape[-1]
+    )
+    trues = np.concatenate(true_batches, axis=0).reshape(-1, true_batches[0].shape[-1])
+
+    metrics = Metrics("actions-base")
+
+    print("Calculating metrics...")
+    metrics.compute_metrics(predictions, trues)
diff --git a/tests/models/test_actions_model_mixed.py b/tests/models/test_actions_model_mixed.py
index c2ba214..79d17e2 100644
--- a/tests/models/test_actions_model_mixed.py
+++ b/tests/models/test_actions_model_mixed.py
@@ -16,6 +16,7 @@ def test_dimensions():
     tokens = tokenizer("Ala ma kota", return_tensors="pt")
 
     embedding_size = 20
+    threshold = 0.9
     num_heads = 2
     num_layers = 2
     feedforward_neurons = 10
@@ -24,6 +25,7 @@ def test_dimensions():
 
     params = ActionsModelMixedParams(
         tokenizer.vocab_size,
+        threshold,
         embedding_size,
         num_heads,
         num_layers,
-- 
GitLab


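A minimal sketch (not part of the patch) of how the unified predict_raw interface added above might be exercised end-to-end; the checkpoint directory and model name are assumptions:

from transformers.tokenization_bert import BertTokenizerFast

from src.models.actions_model_base import ActionsModelBase

# Assumed checkpoint location and name
model = ActionsModelBase.load("checkpoints/actions_base", "final", "cpu")
tokenizer = BertTokenizerFast.from_pretrained("dkleczek/bert-base-polish-cased-v1")

tokens = tokenizer("ala ma kota", return_tensors="pt")
# BxLxA tensor of per-token action probabilities
probs = model.predict_raw(tokens["input_ids"], tokens["attention_mask"])
actions = (probs > 0.9).float()  # binarize with an assumed threshold
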
From eb1a42d316efd123ca8d85bb1ed88bffb5fadc46 Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Wed, 26 Aug 2020 16:27:43 +0200
Subject: [PATCH 085/116] Testing script is now parametrized

---
 src/pipelines/actions_based/scoring.py | 21 +++++-
 src/pipelines/actions_based/test.py    | 94 ++++++++++++++++----------
 2 files changed, 80 insertions(+), 35 deletions(-)

diff --git a/src/pipelines/actions_based/scoring.py b/src/pipelines/actions_based/scoring.py
index 254d557..1187052 100644
--- a/src/pipelines/actions_based/scoring.py
+++ b/src/pipelines/actions_based/scoring.py
@@ -1,7 +1,26 @@
 from typing import List
 
 import numpy as np
-from sklearn.metrics import auc, roc_curve
+from sklearn.metrics import auc, f1_score, roc_curve
+
+from src.pipelines.actions_based.processing import ACTIONS_KEYS
+
+
+class Metrics:
+    def __init__(self, name: str) -> None:
+        self.name = name
+
+    def compute_metrics(self, predictions: np.ndarray, targets: np.ndarray):
+        f1_scores = self._f1_scores(predictions, targets)
+
+        print(f"Model {self.name} | F1 scores")
+        print("----------------------")
+        print(dict(zip(ACTIONS_KEYS, f1_scores)))
+        print("----------------------")
+
+    def _f1_scores(self, predictions: np.ndarray, targets: np.ndarray) -> dict:
+        predictions = predictions_threshold(predictions, 0.0)
+        return f1_score(predictions, targets, average=None)
 
 
 def predictions_threshold(
diff --git a/src/pipelines/actions_based/test.py b/src/pipelines/actions_based/test.py
index d8250b6..c982aa0 100644
--- a/src/pipelines/actions_based/test.py
+++ b/src/pipelines/actions_based/test.py
@@ -1,65 +1,91 @@
+import argparse
+
 import dask.dataframe as dd
 import numpy as np
 import torch
-from sklearn.metrics import f1_score
 from tqdm import trange
-from transformers.tokenization_bert import BertTokenizerFast
 
 from src.batch_loading import get_ordered_dataframe_len
-from src.models.actions_model_mixed import ActionsModelMixed, ActionsModelMixedParams
-from src.pipelines.actions_based.processing import ACTIONS_KEYS
-from src.pipelines.actions_based.scoring import predictions_threshold
+from src.models.actions_model_base import ActionsModelBase
+from src.models.actions_model_mixed import ActionsModelMixed
+from src.models.actions_model_restricted import ActionsModelRestricted
+from src.pipelines.actions_based.scoring import Metrics
 from src.utils import PROJECT_ROOT, unflattened_column
 
-NUM_CHECK = 1
-BATCH_SIZE = 1
-TEST_DATASET = f"{PROJECT_ROOT}/generated/actions/stage4_reindexing"
-
-
-class Metrics:
-    def __init__(self, name: str) -> None:
-        self.name = name
-
-    def compute_metrics(self, predictions: np.ndarray, targets: np.ndarray):
-        f1_scores = self._f1_scores(predictions, targets)
+SUPPORTED_MODELS = {
+    "base": ActionsModelBase,
+    "restricted": ActionsModelRestricted,
+    "mixed": ActionsModelMixed,
+}
 
-        print(f"Model {self.name} | F1 scores")
-        print("----------------------")
-        print(dict(zip(ACTIONS_KEYS, f1_scores)))
-        print("----------------------")
-
-    def _f1_scores(self, predictions: np.ndarray, targets: np.ndarray) -> dict:
-        predictions = predictions_threshold(predictions, 0.0)
-        return f1_score(predictions, targets, average=None)
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Evaluate actions model")
+    parser.add_argument(
+        "-a",
+        "--architecture",
+        required=True,
+        choices=SUPPORTED_MODELS.keys(),
+        help="Model architecture",
+    )
+    parser.add_argument(
+        "-d",
+        "--directory",
+        required=True,
+        help="Directory where trained model is located, relative to project root",
+    )
+    parser.add_argument("-m", "--model", default="final", help="Pretrained model name")
+    parser.add_argument("-b", "--batch", type=int, default=1, help="Batch size")
+    parser.add_argument(
+        "-l",
+        "--limit",
+        type=int,
+        default=0,
+        help="Limits how much samples from test set will be evaluated",
+    )
+    parser.add_argument(
+        "-s",
+        "--dataset",
+        type=str,
+        required=True,
+        help="Directory where test dataset is located, relative to project root",
+    )
+    parser.add_argument(
+        "-e",
+        "--device",
+        type=str,
+        default="cpu",
+        help="Device on which inference will be done",
+    )
+    args = parser.parse_args()
 
+    test_dataset = f"{PROJECT_ROOT}/{args.dataset}"
 
-if __name__ == "__main__":
     print("Getting dataset info...")
-    df = dd.read_parquet(TEST_DATASET, engine="pyarrow")
+    df = dd.read_parquet(test_dataset, engine="pyarrow")
 
     print("Loading dataset to memory...")
     df_len = get_ordered_dataframe_len(df)
 
-    data_start = df_len - NUM_CHECK
+    data_start = max(df_len - args.limit, 0)
     data_end = df_len
     pdf = df.loc[data_start:data_end].compute().reset_index()
 
     device = torch.device("cpu")
 
-    tokenizer = BertTokenizerFast.from_pretrained("dkleczek/bert-base-polish-cased-v1")
-
-    params = ActionsModelMixedParams(tokenizer.vocab_size)
-    model = ActionsModelMixed(params)
+    print(f"Loading model {args.model}")
+    model_location = f"{PROJECT_ROOT}/{args.directory}"
+    model_type = SUPPORTED_MODELS[args.architecture]
+    model = model_type.load(model_location, args.model, args.device)
 
     true_batches = []
     prediction_batches = []
 
     print("Computing...")
-    num_batches = len(pdf) // BATCH_SIZE
+    num_batches = len(pdf) // args.batch
     for batch in trange(num_batches):
 
-        batch_start = batch * BATCH_SIZE
-        batch_end = (batch + 1) * BATCH_SIZE
+        batch_start = batch * args.batch
+        batch_end = (batch + 1) * args.batch
         batch_pdf = pdf.iloc[batch_start:batch_end]
 
         inputs = unflattened_column(batch_pdf, "source")
-- 
GitLab


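The parametrized script dispatches on a small class registry; a sketch of the pattern (paths are assumptions):

from src.models.actions_model_base import ActionsModelBase
from src.models.actions_model_mixed import ActionsModelMixed
from src.models.actions_model_restricted import ActionsModelRestricted

SUPPORTED_MODELS = {
    "base": ActionsModelBase,
    "restricted": ActionsModelRestricted,
    "mixed": ActionsModelMixed,
}

# --architecture picks the class, --directory/--model pick the weights
model_cls = SUPPORTED_MODELS["base"]
model = model_cls.load("checkpoints/actions_base", "final", "cpu")
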
From 4a9582f5e1944676f3f7048785195668f27b1edc Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Wed, 26 Aug 2020 16:38:46 +0200
Subject: [PATCH 086/116] Test results can now be saved to a file

---
 src/pipelines/actions_based/scoring.py | 33 ++++++++++++++++++++------
 src/pipelines/actions_based/test.py    |  9 ++++++-
 2 files changed, 34 insertions(+), 8 deletions(-)

diff --git a/src/pipelines/actions_based/scoring.py b/src/pipelines/actions_based/scoring.py
index 1187052..12c448b 100644
--- a/src/pipelines/actions_based/scoring.py
+++ b/src/pipelines/actions_based/scoring.py
@@ -1,26 +1,45 @@
-from typing import List
+from typing import List, Optional
 
 import numpy as np
 from sklearn.metrics import auc, f1_score, roc_curve
 
 from src.pipelines.actions_based.processing import ACTIONS_KEYS
+from src.utils import prepare_folder
 
 
 class Metrics:
-    def __init__(self, name: str) -> None:
+    def __init__(self, name: str, output_dir: Optional[str]) -> None:
         self.name = name
+        self.message = ""
+        self.output_dir = output_dir
 
     def compute_metrics(self, predictions: np.ndarray, targets: np.ndarray):
         f1_scores = self._f1_scores(predictions, targets)
 
-        print(f"Model {self.name} | F1 scores")
-        print("----------------------")
-        print(dict(zip(ACTIONS_KEYS, f1_scores)))
-        print("----------------------")
+        self._log_text(f"Model {self.name} | F1 scores")
+        self._log_text("----------------------")
+        self._log_text(f1_scores)
+        self._log_text("----------------------")
+
+        self._output_message()
 
     def _f1_scores(self, predictions: np.ndarray, targets: np.ndarray) -> dict:
         predictions = predictions_threshold(predictions, 0.0)
-        return f1_score(predictions, targets, average=None)
+        f1_scores = f1_score(predictions, targets, average=None)
+
+        return dict(zip(ACTIONS_KEYS, f1_scores))
+
+    def _output_message(self):
+        print(self.message)
+
+        if self.output_dir is not None:
+            prepare_folder(self.output_dir)
+
+            with open(f"{self.output_dir}/{self.name}.txt", "w") as f:
+                f.write(self.message)
+
+    def _log_text(self, text: str):
+        self.message += f"{text}\n"
 
 
 def predictions_threshold(
diff --git a/src/pipelines/actions_based/test.py b/src/pipelines/actions_based/test.py
index c982aa0..3dbc61f 100644
--- a/src/pipelines/actions_based/test.py
+++ b/src/pipelines/actions_based/test.py
@@ -56,6 +56,13 @@ if __name__ == "__main__":
         default="cpu",
         help="Device on which inference will be done",
     )
+    parser.add_argument(
+        "-o",
+        "--output",
+        type=str,
+        required=True,
+        help="Directory where output will be stored",
+    )
     args = parser.parse_args()
 
     test_dataset = f"{PROJECT_ROOT}/{args.dataset}"
@@ -108,7 +115,7 @@ if __name__ == "__main__":
     )
     trues = np.concatenate(true_batches, axis=0).reshape(-1, true_batches[0].shape[-1])
 
-    metrics = Metrics("actions-base")
+    metrics = Metrics("actions-base", args.output)
 
     print("Calculating metrics...")
     metrics.compute_metrics(predictions, trues)
-- 
GitLab


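A small sketch of the extended Metrics contract, using fake arrays; the output directory is an assumption:

import numpy as np

from src.pipelines.actions_based.scoring import Metrics

# Fake per-class probabilities and binary ground truth (rows are tokens)
predictions = np.array([[0.9, 0.1], [0.2, 0.8]])
targets = np.array([[1.0, 0.0], [0.0, 1.0]])

metrics = Metrics("actions-base", "generated/actions/test_results_demo")
# Prints the report and also writes generated/actions/test_results_demo/actions-base.txt
metrics.compute_metrics(predictions, targets)
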
From d7da711259e560075f7211dd08900c70b42484ce Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Wed, 26 Aug 2020 17:07:05 +0200
Subject: [PATCH 087/116] Some testing parameters have been moved to params.yaml

---
 params.yaml                         |  5 ++++
 src/pipelines/actions_based/test.py | 39 ++++++++++++-----------------
 2 files changed, 21 insertions(+), 23 deletions(-)

diff --git a/params.yaml b/params.yaml
index 500eeff..3c8305f 100644
--- a/params.yaml
+++ b/params.yaml
@@ -38,6 +38,11 @@ actions:
         fresh_start: false
         device: "cuda:0"
 
+    test_base:
+        limit: 1
+        batch_size: 1
+        device: "cuda:0"
+
     training_restricted:
         learning_rate: 0.0001
         num_epochs: 5
diff --git a/src/pipelines/actions_based/test.py b/src/pipelines/actions_based/test.py
index 3dbc61f..8cfcda9 100644
--- a/src/pipelines/actions_based/test.py
+++ b/src/pipelines/actions_based/test.py
@@ -10,7 +10,7 @@ from src.models.actions_model_base import ActionsModelBase
 from src.models.actions_model_mixed import ActionsModelMixed
 from src.models.actions_model_restricted import ActionsModelRestricted
 from src.pipelines.actions_based.scoring import Metrics
-from src.utils import PROJECT_ROOT, unflattened_column
+from src.utils import PROJECT_ROOT, get_config, unflattened_column
 
 SUPPORTED_MODELS = {
     "base": ActionsModelBase,
@@ -34,28 +34,13 @@ if __name__ == "__main__":
         help="Directory where trained model is located, relative to project root",
     )
     parser.add_argument("-m", "--model", default="final", help="Pretrained model name")
-    parser.add_argument("-b", "--batch", type=int, default=1, help="Batch size")
     parser.add_argument(
-        "-l",
-        "--limit",
-        type=int,
-        default=0,
-        help="Limits how much samples from test set will be evaluated",
-    )
-    parser.add_argument(
-        "-s",
+        "-ds",
         "--dataset",
         type=str,
         required=True,
         help="Directory where test dataset is located, relative to project root",
     )
-    parser.add_argument(
-        "-e",
-        "--device",
-        type=str,
-        default="cpu",
-        help="Device on which inference will be done",
-    )
     parser.add_argument(
         "-o",
         "--output",
@@ -63,8 +48,16 @@ if __name__ == "__main__":
         required=True,
         help="Directory where output will be stored",
     )
+    parser.add_argument(
+        "-s", "--stage", type=str, required=True, help="Stage name in params.yaml"
+    )
     args = parser.parse_args()
 
+    config = get_config()
+    limit = config["actions"][args.stage]["limit"]
+    batch_size = config["actions"][args.stage]["batch_size"]
+    device_name = config["actions"][args.stage]["device"]
+
     test_dataset = f"{PROJECT_ROOT}/{args.dataset}"
 
     print("Getting dataset info...")
@@ -73,26 +66,26 @@ if __name__ == "__main__":
     print("Loading dataset to memory...")
     df_len = get_ordered_dataframe_len(df)
 
-    data_start = max(df_len - args.limit, 0)
+    data_start = max(df_len - limit, 0)
     data_end = df_len
     pdf = df.loc[data_start:data_end].compute().reset_index()
 
-    device = torch.device("cpu")
+    device = torch.device(device_name)
 
     print(f"Loading model {args.model}")
     model_location = f"{PROJECT_ROOT}/{args.directory}"
     model_type = SUPPORTED_MODELS[args.architecture]
-    model = model_type.load(model_location, args.model, args.device)
+    model = model_type.load(model_location, args.model, device)
 
     true_batches = []
     prediction_batches = []
 
     print("Computing...")
-    num_batches = len(pdf) // args.batch
+    num_batches = len(pdf) // batch_size
     for batch in trange(num_batches):
 
-        batch_start = batch * args.batch
-        batch_end = (batch + 1) * args.batch
+        batch_start = batch * batch_size
+        batch_end = (batch + 1) * batch_size
         batch_pdf = pdf.iloc[batch_start:batch_end]
 
         inputs = unflattened_column(batch_pdf, "source")
-- 
GitLab


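The stage name passed via -s is simply a key into params.yaml; a minimal sketch of the lookup the script performs:

from src.utils import get_config

config = get_config()  # assumed to parse params.yaml at the project root
stage = config["actions"]["testing_base"]
limit, batch_size, device_name = stage["limit"], stage["batch_size"], stage["device"]
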
From 1beaacac28bd5269440da9d09ee2a3972e9defe3 Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Wed, 26 Aug 2020 17:16:18 +0200
Subject: [PATCH 088/116] Added testing to DAG

---
 dvc.yaml    | 55 ++++++++++++++++++++++++++++++++++++++++++++++++-----
 params.yaml | 14 ++++++++++++--
 2 files changed, 62 insertions(+), 7 deletions(-)

diff --git a/dvc.yaml b/dvc.yaml
index e4b7cba..4970d16 100644
--- a/dvc.yaml
+++ b/dvc.yaml
@@ -1,4 +1,7 @@
 stages:
+  ######################
+  #       Action       #
+  ######################
   actions_extraction:
     cmd: python3 -m src.pipelines.actions_based.stage1_extraction
     deps:
@@ -40,6 +43,8 @@ stages:
     - src
     outs:
     - generated/actions/stage5_stats
+
+  # Base
   actions_base_training:
     cmd: python3 -m src.pipelines.actions_based.train_base
     deps:
@@ -49,14 +54,26 @@ stages:
     params:
     - global.base_model
     - global.random_seed
-    - actions.training.max_training_time
-    - actions.training.learning_rate
-    - actions.training.num_epochs
-    - actions.training.batch_size
-    - actions.training.save_step
+    - actions.training_base.max_training_time
+    - actions.training_base.learning_rate
+    - actions.training_base.num_epochs
+    - actions.training_base.batch_size
+    - actions.training_base.save_step
     outs:
     - checkpoints/actions_base
 
+  actions_base_testing:
+    cmd: python3 -m src.pipelines.actions_based.test -a base -d checkpoints/actions_base/ -m "final" -ds generated/actions/stage4_reindexing/ -o generated/actions/test_results_base -s testing_base
+    deps:
+    - checkpoints/actions_base
+    - generated/actions/stage4_reindexing
+    - src
+    params:
+    - actions.testing_base.limit
+    outs:
+    - generated/actions/test_results_base
+
+  # Restricted
   actions_restricted_training:
     cmd: python3 -m src.pipelines.actions_based.train_restricted
     deps:
@@ -74,6 +91,18 @@ stages:
     outs:
     - checkpoints/actions_restricted
 
+  actions_restricted_testing:
+    cmd: python3 -m src.pipelines.actions_based.test -a restricted -d checkpoints/actions_restricted/ -m "final" -ds generated/actions/stage4_reindexing/ -o generated/actions/test_results_restricted -s testing_restricted
+    deps:
+    - checkpoints/actions_restricted
+    - generated/actions/stage4_reindexing
+    - src
+    params:
+    - actions.testing_restricted.limit
+    outs:
+    - generated/actions/test_results_restricted
+
+  # Mixed
   actions_mixed_training:
     cmd: python3 -m src.pipelines.actions_based.train_mixed
     deps:
@@ -95,6 +124,21 @@ stages:
     - actions.training_mixed.save_step
     outs:
     - checkpoints/actions_mixed
+
+  actions_mixed_testing:
+    cmd: python3 -m src.pipelines.actions_based.test -a mixed -d checkpoints/actions_mixed/ -m "final" -ds generated/actions/stage4_reindexing/ -o generated/actions/test_results_mixed -s testing_mixed
+    deps:
+    - checkpoints/actions_mixed
+    - generated/actions/stage4_reindexing
+    - src
+    params:
+    - actions.testing_mixed.limit
+    outs:
+    - generated/actions/test_results_mixed
+
+  ######################
+  #    Translation     #
+  ######################
   translations_extraction:
     cmd: python3 -m src.pipelines.translation_based.stage1_extraction
     deps:
@@ -103,6 +147,7 @@ stages:
     - translations.extraction.num_partitions
     outs:
     - generated/translations/stage1_extraction
+
   translations_create_batches:
     cmd: python3 -m src.pipelines.translation_based.stage2_create_batches
     deps:
diff --git a/params.yaml b/params.yaml
index 3c8305f..fd62fe1 100644
--- a/params.yaml
+++ b/params.yaml
@@ -38,8 +38,8 @@ actions:
         fresh_start: false
         device: "cuda:0"
 
-    test_base:
-        limit: 1
+    testing_base:
+        limit: None
         batch_size: 1
         device: "cuda:0"
 
@@ -54,6 +54,11 @@ actions:
         fresh_start: true
         device: "cuda:0"
 
+    testing_restricted:
+        limit: None
+        batch_size: 1
+        device: "cuda:0"
+
     training_mixed:
         embedding_size: 768
         num_heads: 12
@@ -69,6 +74,11 @@ actions:
         loss_averaging_span: 1000
         fresh_start: true
         device: "cuda:0"
+
+    testing_mixed:
+        limit: None
+        batch_size: 1
+        device: "cuda:0"
 translations:
     extraction:
         num_partitions: 2_000
-- 
GitLab


From 71140e4f1ed1bd2bd2835f04bd0719bb18b2b632 Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Wed, 26 Aug 2020 17:52:43 +0200
Subject: [PATCH 089/116] Added comments

---
 src/pipelines/actions_based/scoring.py        | 50 ++++++++++++++++---
 .../actions_based/stage1_extraction.py        |  5 +-
 .../actions_based/stage2_tokenization.py      |  5 +-
 .../actions_based/stage3_exploding.py         |  7 ++-
 .../translation_based/stage1_extraction.py    |  5 +-
 .../stage2_create_batches.py                  |  5 +-
 6 files changed, 65 insertions(+), 12 deletions(-)

diff --git a/src/pipelines/actions_based/scoring.py b/src/pipelines/actions_based/scoring.py
index 12c448b..66a267f 100644
--- a/src/pipelines/actions_based/scoring.py
+++ b/src/pipelines/actions_based/scoring.py
@@ -1,4 +1,4 @@
-from typing import List, Optional
+from typing import List, Optional, Tuple
 
 import numpy as np
 from sklearn.metrics import auc, f1_score, roc_curve
@@ -8,12 +8,26 @@ from src.utils import prepare_folder
 
 
 class Metrics:
+    """Class for model metrics calcuation and presenting"""
+
     def __init__(self, name: str, output_dir: Optional[str]) -> None:
+        """Initializes Metrics
+
+        Args:
+            name (str): Name of the model that is measured
+            output_dir (Optional[str]): Directory where measurements will be saved. Can be None if saving is not required
+        """
         self.name = name
         self.message = ""
         self.output_dir = output_dir
 
     def compute_metrics(self, predictions: np.ndarray, targets: np.ndarray):
+        """Performs metrics calculation on model predictions relative to ground truth
+
+        Args:
+            predictions (np.ndarray): Predicted, non-thresholded values
+            targets (np.ndarray): Ground truth values
+        """
         f1_scores = self._f1_scores(predictions, targets)
 
         self._log_text(f"Model {self.name} | F1 scores")
@@ -45,18 +59,31 @@ class Metrics:
 def predictions_threshold(
     predictions: np.ndarray, threshold: float = 0.9
 ) -> np.ndarray:
-    return (predictions > threshold).astype(np.float)
+    """Applies thresholding above which all values will be assigned 1.0, otherwsie 0.0
 
+    Args:
+        predictions (np.ndarray): Unthresholded predictions
+        threshold (float, optional): Threshold. Defaults to 0.9.
 
-def compute_accuracy(predictions: np.ndarray, targets: np.ndarray) -> np.ndarray:
-    return (
-        np.sum((predictions == targets).astype(np.int), axis=0) / predictions.shape[0]
-    )
+    Returns:
+        np.ndarray: Binarized predictions
+    """
+    return (predictions > threshold).astype(np.float)
 
 
 def multiclass_roc_curve(
     target: np.ndarray, predictions: np.ndarray
-) -> List[np.ndarray]:
+) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]]:
+    """Computes ROC-curve points for multiclass/mutlilabel case
+
+    Args:
+        target (np.ndarray): Ground-truth values
+        predictions (np.ndarray): Unthresholded predictions
+
+    Returns:
+        Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]]: False positive rates, true positive rates and thresholds. All
+        values are returned as lists, where each entry corresponds to a single class
+    """
     class_fprs = []
     class_tprs = []
     class_thresholds = []
@@ -74,6 +101,15 @@ def multiclass_roc_curve(
 def multiclass_auc(
     false_positive_rate: List[np.ndarray], true_positive_rate: List[np.ndarray]
 ) -> np.ndarray:
+    """Computes area under curve for each class in multilabel/multiclass case
+
+    Args:
+        false_positive_rate (List[np.ndarray]): False positive rates, where each entry in the list corresponds to a single class
+        true_positive_rate (List[np.ndarray]): True positive rates, where each entry in the list corresponds to a single class
+
+    Returns:
+        np.ndarray: AUC values, one per class
+    """
 
     assert len(false_positive_rate) == len(true_positive_rate)
 
diff --git a/src/pipelines/actions_based/stage1_extraction.py b/src/pipelines/actions_based/stage1_extraction.py
index 5a058a9..94dc26c 100644
--- a/src/pipelines/actions_based/stage1_extraction.py
+++ b/src/pipelines/actions_based/stage1_extraction.py
@@ -6,7 +6,10 @@ import numpy as np
 import pandas as pd
 from dask.distributed import Client
 
-from src.pipelines.actions_based.processing import APPLY_FILE_PROCESSING_META, apply_file_processing
+from src.pipelines.actions_based.processing import (
+    APPLY_FILE_PROCESSING_META,
+    apply_file_processing,
+)
 from src.utils import PROJECT_ROOT, get_config, prepare_folder
 
 INPUT_FOLDER = f"{PROJECT_ROOT}/data"
diff --git a/src/pipelines/actions_based/stage2_tokenization.py b/src/pipelines/actions_based/stage2_tokenization.py
index 0ea3586..b30445f 100644
--- a/src/pipelines/actions_based/stage2_tokenization.py
+++ b/src/pipelines/actions_based/stage2_tokenization.py
@@ -4,7 +4,10 @@ import dask.dataframe as dd
 from dask.distributed import Client
 from transformers import BertTokenizerFast
 
-from src.pipelines.actions_based.processing import APPLY_TOKENIZATION_META, apply_tokenization
+from src.pipelines.actions_based.processing import (
+    APPLY_TOKENIZATION_META,
+    apply_tokenization,
+)
 from src.utils import PROJECT_ROOT, get_config, prepare_folder
 
 INPUT_FOLDER = f"{PROJECT_ROOT}/generated/actions/stage1_extraction"
diff --git a/src/pipelines/actions_based/stage3_exploding.py b/src/pipelines/actions_based/stage3_exploding.py
index 81dc965..72ec128 100644
--- a/src/pipelines/actions_based/stage3_exploding.py
+++ b/src/pipelines/actions_based/stage3_exploding.py
@@ -2,7 +2,12 @@
 import dask.dataframe as dd
 from dask.distributed import Client
 
-from src.processing import EXPAND_DIMS_META, FLATTEN_DIMS_META, expand_dims, flatten_dims
+from src.processing import (
+    EXPAND_DIMS_META,
+    FLATTEN_DIMS_META,
+    expand_dims,
+    flatten_dims,
+)
 from src.utils import PROJECT_ROOT, get_config, prepare_folder
 
 INPUT_FOLDER = f"{PROJECT_ROOT}/generated/actions/stage2_tokenization"
diff --git a/src/pipelines/translation_based/stage1_extraction.py b/src/pipelines/translation_based/stage1_extraction.py
index 386211d..6ffdbf7 100644
--- a/src/pipelines/translation_based/stage1_extraction.py
+++ b/src/pipelines/translation_based/stage1_extraction.py
@@ -6,7 +6,10 @@ import numpy as np
 import pandas as pd
 from dask.distributed import Client
 
-from src.pipelines.translation_based.processing import RAW_TO_DATAFRAME_META, raw_to_dataframe
+from src.pipelines.translation_based.processing import (
+    RAW_TO_DATAFRAME_META,
+    raw_to_dataframe,
+)
 from src.utils import PROJECT_ROOT, get_config, prepare_folder
 
 INPUT_FOLDER = f"{PROJECT_ROOT}/data"
diff --git a/src/pipelines/translation_based/stage2_create_batches.py b/src/pipelines/translation_based/stage2_create_batches.py
index ade8bf2..83a2edc 100644
--- a/src/pipelines/translation_based/stage2_create_batches.py
+++ b/src/pipelines/translation_based/stage2_create_batches.py
@@ -4,7 +4,10 @@ from dask import delayed
 from dask.distributed import Client
 from transformers import BertTokenizerFast
 
-from src.pipelines.translation_based.processing import GENERATE_BATCHES_META, generate_batches
+from src.pipelines.translation_based.processing import (
+    GENERATE_BATCHES_META,
+    generate_batches,
+)
 from src.utils import PROJECT_ROOT, get_config, prepare_folder
 
 INPUT_FOLDER = f"{PROJECT_ROOT}/generated/translations/stage1_extraction"
-- 
GitLab


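A minimal sketch, with fake two-class scores, of how the documented ROC helpers compose:

import numpy as np

from src.pipelines.actions_based.scoring import multiclass_auc, multiclass_roc_curve

# Rows are samples, columns are classes
target = np.array([[1, 0], [0, 1], [1, 1], [0, 0]])
scores = np.array([[0.9, 0.2], [0.3, 0.8], [0.7, 0.6], [0.1, 0.4]])

fprs, tprs, thresholds = multiclass_roc_curve(target, scores)
aucs = multiclass_auc(fprs, tprs)  # one AUC value per class
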
From e0102e718acfa326b122826285d0ca5ba108e549 Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Wed, 26 Aug 2020 18:10:35 +0200
Subject: [PATCH 090/116] Added comments to utility classes

---
 .../actions_based/stage1_extraction.py        |   5 +-
 .../actions_based/stage2_tokenization.py      |   5 +-
 .../actions_based/stage3_exploding.py         |   7 +-
 .../translation_based/stage1_extraction.py    |   5 +-
 .../stage2_create_batches.py                  |   5 +-
 src/utils.py                                  | 122 ++++++++++++++++--
 6 files changed, 118 insertions(+), 31 deletions(-)

diff --git a/src/pipelines/actions_based/stage1_extraction.py b/src/pipelines/actions_based/stage1_extraction.py
index 94dc26c..5a058a9 100644
--- a/src/pipelines/actions_based/stage1_extraction.py
+++ b/src/pipelines/actions_based/stage1_extraction.py
@@ -6,10 +6,7 @@ import numpy as np
 import pandas as pd
 from dask.distributed import Client
 
-from src.pipelines.actions_based.processing import (
-    APPLY_FILE_PROCESSING_META,
-    apply_file_processing,
-)
+from src.pipelines.actions_based.processing import APPLY_FILE_PROCESSING_META, apply_file_processing
 from src.utils import PROJECT_ROOT, get_config, prepare_folder
 
 INPUT_FOLDER = f"{PROJECT_ROOT}/data"
diff --git a/src/pipelines/actions_based/stage2_tokenization.py b/src/pipelines/actions_based/stage2_tokenization.py
index b30445f..0ea3586 100644
--- a/src/pipelines/actions_based/stage2_tokenization.py
+++ b/src/pipelines/actions_based/stage2_tokenization.py
@@ -4,10 +4,7 @@ import dask.dataframe as dd
 from dask.distributed import Client
 from transformers import BertTokenizerFast
 
-from src.pipelines.actions_based.processing import (
-    APPLY_TOKENIZATION_META,
-    apply_tokenization,
-)
+from src.pipelines.actions_based.processing import APPLY_TOKENIZATION_META, apply_tokenization
 from src.utils import PROJECT_ROOT, get_config, prepare_folder
 
 INPUT_FOLDER = f"{PROJECT_ROOT}/generated/actions/stage1_extraction"
diff --git a/src/pipelines/actions_based/stage3_exploding.py b/src/pipelines/actions_based/stage3_exploding.py
index 72ec128..81dc965 100644
--- a/src/pipelines/actions_based/stage3_exploding.py
+++ b/src/pipelines/actions_based/stage3_exploding.py
@@ -2,12 +2,7 @@
 import dask.dataframe as dd
 from dask.distributed import Client
 
-from src.processing import (
-    EXPAND_DIMS_META,
-    FLATTEN_DIMS_META,
-    expand_dims,
-    flatten_dims,
-)
+from src.processing import EXPAND_DIMS_META, FLATTEN_DIMS_META, expand_dims, flatten_dims
 from src.utils import PROJECT_ROOT, get_config, prepare_folder
 
 INPUT_FOLDER = f"{PROJECT_ROOT}/generated/actions/stage2_tokenization"
diff --git a/src/pipelines/translation_based/stage1_extraction.py b/src/pipelines/translation_based/stage1_extraction.py
index 6ffdbf7..386211d 100644
--- a/src/pipelines/translation_based/stage1_extraction.py
+++ b/src/pipelines/translation_based/stage1_extraction.py
@@ -6,10 +6,7 @@ import numpy as np
 import pandas as pd
 from dask.distributed import Client
 
-from src.pipelines.translation_based.processing import (
-    RAW_TO_DATAFRAME_META,
-    raw_to_dataframe,
-)
+from src.pipelines.translation_based.processing import RAW_TO_DATAFRAME_META, raw_to_dataframe
 from src.utils import PROJECT_ROOT, get_config, prepare_folder
 
 INPUT_FOLDER = f"{PROJECT_ROOT}/data"
diff --git a/src/pipelines/translation_based/stage2_create_batches.py b/src/pipelines/translation_based/stage2_create_batches.py
index 83a2edc..ade8bf2 100644
--- a/src/pipelines/translation_based/stage2_create_batches.py
+++ b/src/pipelines/translation_based/stage2_create_batches.py
@@ -4,10 +4,7 @@ from dask import delayed
 from dask.distributed import Client
 from transformers import BertTokenizerFast
 
-from src.pipelines.translation_based.processing import (
-    GENERATE_BATCHES_META,
-    generate_batches,
-)
+from src.pipelines.translation_based.processing import GENERATE_BATCHES_META, generate_batches
 from src.utils import PROJECT_ROOT, get_config, prepare_folder
 
 INPUT_FOLDER = f"{PROJECT_ROOT}/generated/translations/stage1_extraction"
diff --git a/src/utils.py b/src/utils.py
index 50d7007..27fdfb1 100644
--- a/src/utils.py
+++ b/src/utils.py
@@ -23,21 +23,37 @@ PROJECT_ROOT = os.path.dirname(os.path.realpath("/".join(__file__.split("/")) +
 
 
 class Saver:
+    """Class that allows saving and loading mode-optimizer pairs"""
+
     def __init__(
         self, save_dir: str, model: PunctuationModel, optimizer: Optimizer
     ) -> None:
+        """Initializes Saver
+
+        Args:
+            save_dir (str): Directory where model and optimizer will be saved
+            model (PunctuationModel): Model to save
+            optimizer (Optimizer): Optimizer to save
+        """
         self.save_dir = save_dir
         self.model = model
         self.optimizer = optimizer
 
         prepare_folder(self.save_dir)
 
-    def save(self, name):
+    def save(self, name: str):
+        """Saves model and optimizer
+
+        Args:
+            name (str): Name under which the model will be saved
+        """
         self.model.save(self.save_dir, name)
         torch.save(self.optimizer.state_dict(), f"{self.save_dir}/{name}.optimizer")
 
 
 class Loader:
+    """Class for loading model and it's optimizer from checkpoint"""
+
     def __init__(
         self,
         save_dir: str,
@@ -45,6 +61,14 @@ class Loader:
         optimizer_type: Type[Optimizer],
         device: str,
     ) -> None:
+        """Initializes Loader
+
+        Args:
+            save_dir (str): Directory where to search for models
+            model_type (Type[PunctuationModel]): Model class that should be loaded
+            optimizer_type (Type[Optimizer]): Optimizer class that should be loaded
+            device (str): Device on which the loaded model/optimizer will live
+        """
         self.save_dir = save_dir
         self.device = device
 
@@ -52,11 +76,24 @@ class Loader:
         self.optimizer_type = optimizer_type
 
     def has_checkpoints(self) -> bool:
+        """Checks if there are any saved checkpoints in model's directory
+
+        Returns:
+            bool: True if checkpoints were found, False otherwise
+        """
         files = glob(f"{self.save_dir}/*.model")
 
         return latest_model(files) is not None
 
-    def load(self, name) -> Tuple[PunctuationModel, Optimizer]:
+    def load(self, name: str) -> Tuple[PunctuationModel, Optimizer]:
+        """Loads a model and optimizer from file
+
+        Args:
+            name (str): Name of the model that will be loaded
+
+        Returns:
+            Tuple[PunctuationModel, Optimizer]: Model and optimizer
+        """
         model = self.model_type.load(self.save_dir, name, self.device)
 
         optimizer = self.optimizer_type(model.parameters())
@@ -69,6 +106,12 @@ class Loader:
         return model, optimizer
 
     def load_latest(self) -> Tuple[PunctuationModel, Optimizer, int, int]:
+        """Loads latest checkpoint in directory
+
+        Returns:
+            Tuple[PunctuationModel, Optimizer, int, int]: Model, Optimizer, Epoch at
+            which checkpoint was made, step at which checkpoint was made
+        """
         files = glob(f"{self.save_dir}/*.model")
 
         model_id = latest_model(files)
@@ -80,16 +123,34 @@ class Loader:
 
 
 class Checkpoint:
+    """Utility class to make checkpoints every constant ammount of steps"""
+
     def __init__(
-        self, save_step, saver: Saver, start_epoch: int, start_step: int
+        self, save_step: int, saver: Saver, start_epoch: int, start_step: int
     ) -> None:
+        """Initializes Checkpoint.
+        The starting epoch and step are provided so that a checkpoint is not made right after
+        loading the model.
+
+        Args:
+            save_step (int): Number of steps after which checkpoints will be saved
+            saver (Saver): Saver used to save model/optimizer state
+            start_epoch (int): Epoch at which training was started
+            start_step (int): Step at which training was started
+        """
         self.start_step = start_step
         self.start_epoch = start_epoch
         self.save_step = save_step
 
         self.saver = saver
 
-    def step(self, epoch, step) -> None:
+    def step(self, epoch: int, step: int) -> None:
+        """Check if checkpoint should be made, and save it if necessary
+
+        Args:
+            epoch (int): Epoch num
+            step (int): Step num
+        """
         if step % self.save_step == 0 and (
             step != self.start_step or epoch != self.start_epoch
         ):
@@ -98,17 +159,45 @@ class Checkpoint:
 
 
 class Timeout:
+    """Utility class that prevent training from surpassing maximum ammount of time"""
+
     def __init__(self, duration: timedelta, saver: Optional[Saver]) -> None:
+        """Initializes Timeout
+
+        Args:
+            duration (timedelta): Maximum duration of training
+            saver (Optional[Saver]): Saver used to save a checkpoint if training time is
+            exceeded
+        """
         self.saver = saver
         self.duration = duration
         self.time_max = None
 
-    def start(self, time_now: datetime = datetime.now()):
+    def start(self, time_now: Optional[datetime] = None):
+        """Starts counting time from the start of training
+
+        Args:
+            time_now (Optional[datetime], optional): Point from which time will be measured.
+            Use current time if None. Defaults to None.
+        """
+        if time_now is None:
+            time_now = datetime.now()
+
         self.time_max = datetime.max
         if self.duration is not None:
-            self.time_max = datetime.now() + self.max_train_time
+            self.time_max = time_now + self.duration
+
+    def step(self, epoch: int, step: int, time: Optional[datetime] = None) -> bool:
+        """Check if timeout was not exceeded. Saved checkpoint if time is exceeded
 
-    def step(self, epoch, step, time=None) -> bool:
+        Args:
+            epoch (int): Epoch number
+            step (int): Step number
+            time (Optional[datetime], optional): Current time. Use current time if None. Defaults to None.
+
+        Returns:
+            bool: True if time was exceeded, False otherwise
+        """
         assert self.time_max is not None
 
         if time is None:
@@ -125,12 +214,27 @@ class Timeout:
 
 
 class ProgressTracker:
-    def __init__(self, device: str, loss_averaging_span) -> None:
+    """Utility class used to tracking loss and displaying it to user"""
+
+    def __init__(self, device: str, loss_averaging_span: int) -> None:
+        """Initializes ProgressTracker
+
+        Args:
+            device (str): Device on which training is performed
+            loss_averaging_span (int): Number of latest samples used to calculate average loss
+        """
         print(f"Training on {device}")
         self.loss_averaging_span = loss_averaging_span
         self.losses = []
 
-    def step(self, epoch, step, loss) -> None:
+    def step(self, epoch: int, step: int, loss: torch.Tensor) -> None:
+        """Records a newly calculated loss and reports progress to the user
+
+        Args:
+            epoch (int): Epoch number
+            step (int): Step number
+            loss (torch.Tensor): Scalar loss value at the provided epoch and step
+        """
         self.losses.append(loss.item())
         loss_mean, self.losses = moving_average(self.losses, self.loss_averaging_span)
 
-- 
GitLab


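A sketch of how the documented utilities cooperate during training; the model, optimizer, batch iteration and train_step below are assumed stand-ins, not repository code:

from datetime import timedelta

from src.utils import Checkpoint, ProgressTracker, Saver, Timeout

saver = Saver("checkpoints/demo", model, optimizer)
checkpoint = Checkpoint(save_step=1000, saver=saver, start_epoch=0, start_step=0)
timeout = Timeout(timedelta(hours=8), saver)
tracker = ProgressTracker("cuda:0", loss_averaging_span=1000)

timeout.start()
for epoch in range(num_epochs):
    for step, batch in enumerate(batches):
        loss = train_step(batch)        # assumed: returns a scalar torch.Tensor
        tracker.step(epoch, step, loss)
        checkpoint.step(epoch, step)
        if timeout.step(epoch, step):   # True once the time budget is exceeded
            break
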
From dab8a69e0ec6ad0ce934c46104bb789f5d8ca149 Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Thu, 27 Aug 2020 10:22:23 +0200
Subject: [PATCH 091/116] Removed to() being called on the optimizer object

---
 src/models/actions_model_base.py                | 7 +++----
 src/pipelines/actions_based/train_base.py       | 6 +++---
 src/pipelines/actions_based/train_mixed.py      | 2 +-
 src/pipelines/actions_based/train_restricted.py | 2 +-
 4 files changed, 8 insertions(+), 9 deletions(-)

diff --git a/src/models/actions_model_base.py b/src/models/actions_model_base.py
index f23ecc0..67b58e7 100644
--- a/src/models/actions_model_base.py
+++ b/src/models/actions_model_base.py
@@ -46,7 +46,6 @@ class ActionsModelBase(ActionsModel):
         config = PretrainedConfig.from_pretrained(params.base_model)
         config.num_labels = params.num_labels
 
-        self.criterion = None
         self.core = BertForTokenClassification(config)
 
     def forward(
@@ -97,16 +96,16 @@ class ActionsModelBase(ActionsModel):
 class ActionsModelBaseLoss(nn.Module):
     """Proposed loss for ActionsModelBase model"""
 
-    def __init__(self, prior_odds: torch.Tensor) -> None:
+    def __init__(self, prior_inverse_odds: torch.Tensor) -> None:
         """Initializes ActionsModelBaseLoss
 
         Args:
-            prior_odds (torch.Tensor): Positive to negative ratio of each action vector
+            prior_inverse_odds (torch.Tensor): Negative to positive ratio of each action vector
                 entry in dataset. Shape A
         """
         super(ActionsModelBaseLoss, self).__init__()
 
-        self.core = BCEWithLogitsLoss(pos_weight=prior_odds)
+        self.core = BCEWithLogitsLoss(pos_weight=prior_inverse_odds)
 
     def forward(
         self,
diff --git a/src/pipelines/actions_based/train_base.py b/src/pipelines/actions_based/train_base.py
index b8c8bce..1b0bd13 100755
--- a/src/pipelines/actions_based/train_base.py
+++ b/src/pipelines/actions_based/train_base.py
@@ -61,7 +61,7 @@ if __name__ == "__main__":
         params = ActionsModelBaseParams(base_model, len(ACTIONS_KEYS))
         model = ActionsModelBase(params).to(device)
 
-        optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate).to(device)
+        optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
         epoch_start, sample_start = (0, 0)
 
     model.train()
@@ -99,12 +99,12 @@ if __name__ == "__main__":
 
         inputs = torch.tensor(inputs, dtype=torch.long).squeeze(dim=-1).to(device)
         outputs = torch.tensor(outputs, dtype=torch.float).to(device)
-        attentions_mask = torch.tensor(attentions_mask).to(device)
+        attentions_mask = torch.tensor(attentions_mask).type(torch.long).to(device)
 
         y_pred = model(input_ids=inputs, attention_mask=attentions_mask)
 
-        loss = criterion(y_pred, outputs)
         optimizer.zero_grad()
+        loss = criterion(y_pred, outputs)
 
         tracker.step(epoch, i, loss)
         checkpoint.step(epoch, i)
diff --git a/src/pipelines/actions_based/train_mixed.py b/src/pipelines/actions_based/train_mixed.py
index 64a82d8..6186c10 100755
--- a/src/pipelines/actions_based/train_mixed.py
+++ b/src/pipelines/actions_based/train_mixed.py
@@ -75,7 +75,7 @@ if __name__ == "__main__":
             dropout,
         )
         model = ActionsModelMixed(params).to(device)
-        optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate).to(device)
+        optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
         epoch_start, sample_start = (0, 0)
 
     model.train()
diff --git a/src/pipelines/actions_based/train_restricted.py b/src/pipelines/actions_based/train_restricted.py
index 29118dc..06e5c20 100755
--- a/src/pipelines/actions_based/train_restricted.py
+++ b/src/pipelines/actions_based/train_restricted.py
@@ -63,7 +63,7 @@ if __name__ == "__main__":
     else:
         params = ActionsModelRestrictedParams(base_model, len(ACTIONS_KEYS) + 1)
         model = ActionsModelRestricted(params).to(device)
-        optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate).to(device)
+        optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
         epoch_start, sample_start = (0, 0)
 
     model.train()
-- 
GitLab


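The rename matters because BCEWithLogitsLoss treats pos_weight as the weight applied to positive examples, conventionally the negative-to-positive count ratio; a small self-contained illustration:

import torch
from torch.nn import BCEWithLogitsLoss

# If an action fires for 10% of tokens, its weight is 0.9 / 0.1 = 9.0
prior_inverse_odds = torch.tensor([9.0])
criterion = BCEWithLogitsLoss(pos_weight=prior_inverse_odds)

logits = torch.tensor([[1.5], [-0.5]])
targets = torch.tensor([[1.0], [0.0]])
loss = criterion(logits, targets)  # errors on positives now cost 9x more
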
From ff7779fb1b85d46a67237b9070d001166af85805 Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Thu, 27 Aug 2020 10:23:24 +0200
Subject: [PATCH 092/116] Added more entries to .gitignore

---
 .gitignore | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/.gitignore b/.gitignore
index 1e28a8e..31cb712 100644
--- a/.gitignore
+++ b/.gitignore
@@ -15,3 +15,7 @@ __pycache__
 notebooks
 dvc.lock
 dask-worker-space
+test_data
+.env
+deploy
+service.log
\ No newline at end of file
-- 
GitLab


From 09c20ba4c4c491bb9637bae11ee5fe408beb1981 Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Thu, 27 Aug 2020 10:31:51 +0200
Subject: [PATCH 093/116] User inside container can now be tweaked

---
 Dockerfile                    | 20 ++++++++++++++++++++
 docker/development/Dockerfile | 19 +++++++------------
 src/utils.py                  |  4 ++--
 3 files changed, 29 insertions(+), 14 deletions(-)

diff --git a/Dockerfile b/Dockerfile
index ce5211b..f32bcb4 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -12,4 +12,24 @@ COPY config.ini .
 COPY worker.py .
 COPY entrypoint.sh .
 
+ARG USERNAME=clarin
+ARG USER_UID=1000
+ARG USER_GID=1000
+
+# Create the user
+RUN groupadd --gid $USER_GID $USERNAME \
+    && useradd --uid $USER_UID --gid $USER_GID -m $USERNAME \
+    #
+    # [Optional] Add sudo support. Omit if you don't need to install software after connecting.
+    && apt-get update \
+    && apt-get install -y sudo \
+    && echo $USERNAME ALL=\(root\) NOPASSWD:ALL > /etc/sudoers.d/$USERNAME \
+    && chmod 0440 /etc/sudoers.d/$USERNAME
+
+# ********************************************************
+# * Anything else you want to do like clean up goes here *
+# ********************************************************
+
+USER ${USERNAME}
+
 ENTRYPOINT [ "./entrypoint.sh" ]
\ No newline at end of file
diff --git a/docker/development/Dockerfile b/docker/development/Dockerfile
index 29a9224..cba06d6 100644
--- a/docker/development/Dockerfile
+++ b/docker/development/Dockerfile
@@ -6,11 +6,11 @@ RUN pip3 install ipywidgets
 
 #### CUDA Installation
 RUN apt-get update && apt-get install -y --no-install-recommends \
-gnupg2 curl ca-certificates && \
+    gnupg2 curl ca-certificates && \
     curl -fsSL https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/7fa2af80.pub | apt-key add - && \
     echo "deb https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64 /" > /etc/apt/sources.list.d/cuda.list && \
     echo "deb https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64 /" > /etc/apt/sources.list.d/nvidia-ml.list && \
-rm -rf /var/lib/apt/lists/*
+    rm -rf /var/lib/apt/lists/*
 
 ENV CUDA_VERSION 10.2.89
 
@@ -18,9 +18,9 @@ ENV CUDA_PKG_VERSION 10-2=$CUDA_VERSION-1
 
 # For libraries in the cuda-compat-* package: https://docs.nvidia.com/cuda/eula/index.html#attachment-a
 RUN apt-get update && apt-get install -y --no-install-recommends \
-        cuda-cudart-$CUDA_PKG_VERSION \
-cuda-compat-10-2 && \
-ln -s cuda-10.2 /usr/local/cuda && \
+    cuda-cudart-$CUDA_PKG_VERSION \
+    cuda-compat-10-2 && \
+    ln -s cuda-10.2 /usr/local/cuda && \
     rm -rf /var/lib/apt/lists/*
 
 # Required for nvidia-docker v1
@@ -45,8 +45,8 @@ ARG USER_UID=1030
 ARG USER_GID=1032
 
 # Create the user
-RUN groupadd --gid 1032 $USERNAME \
-    && useradd --uid 1030 --gid 1032 -m $USERNAME \
+RUN groupadd --gid $USER_GID $USERNAME \
+    && useradd --uid $USER_UID --gid $USER_GID -m $USERNAME \
     #
     # [Optional] Add sudo support. Omit if you don't need to install software after connecting.
     && apt-get update \
@@ -54,11 +54,6 @@ RUN groupadd --gid 1032 $USERNAME \
     && echo $USERNAME ALL=\(root\) NOPASSWD:ALL > /etc/sudoers.d/$USERNAME \
     && chmod 0440 /etc/sudoers.d/$USERNAME
 
-
-RUN groupmod --gid $USER_GID $USERNAME \
-    && usermod --uid $USER_UID --gid $USER_GID $USERNAME \
-    && chown -R $USER_UID:$USER_GID /home/$USERNAME
-
 # ********************************************************
 # * Anything else you want to do like clean up goes here *
 # ********************************************************
diff --git a/src/utils.py b/src/utils.py
index 27fdfb1..ee201d7 100644
--- a/src/utils.py
+++ b/src/utils.py
@@ -408,8 +408,8 @@ def training_loop(
     Yields:
         Generator: batch, epoch_num, step_num
     """
+    i = sample_start
     for epoch in range(epoch_start, num_epochs):
-        i = sample_start
         for data_batch in get_batches(
             df, batch_size, batch_buffer_size, random_index_shuffle, i
         ):
@@ -420,7 +420,7 @@ def training_loop(
 
             i += 1
 
-        sample_start = 0
+        i = 0
 
 
 def random_indexes(df: dd.DataFrame) -> np.ndarray:
-- 
GitLab


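The training_loop change fixes resume semantics: the sample offset should apply only to the epoch being resumed, while later epochs must start from zero. A simplified, self-contained sketch of the corrected counting:

def resume_positions(num_epochs, steps_per_epoch, epoch_start, sample_start):
    """Yields (epoch, step) pairs the way the fixed training_loop counts them."""
    i = sample_start                  # resume offset is consumed once
    for epoch in range(epoch_start, num_epochs):
        while i < steps_per_epoch:
            yield epoch, i
            i += 1
        i = 0                         # subsequent epochs start at the beginning

# Resuming epoch 1 at step 3 with 5 steps per epoch yields:
# (1, 3), (1, 4), (2, 0), (2, 1), ...
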
From a4c3f73bcf5959c6f99d03d8afd8b0dcd783096e Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Thu, 27 Aug 2020 10:40:44 +0200
Subject: [PATCH 094/116] Python 3.7 support in container

---
 src/utils.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/utils.py b/src/utils.py
index ee201d7..b14a173 100644
--- a/src/utils.py
+++ b/src/utils.py
@@ -119,7 +119,8 @@
             return None
 
         epoch, step = model_id
-        return *(self.load(f"{epoch}-{step}")), epoch, step
+        return self.load(f"{epoch}-{step}"), epoch, step
 
 
 class Checkpoint:
-- 
GitLab


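The fixed line tripped over a syntax constraint: a starred expression in a return statement only parses from Python 3.8 onwards, and the replacement must keep the flat four-tuple that load_latest() documents. A tiny demonstration (the first function is the 3.8+ form):

def flat():
    pair = ("model", "optimizer")
    return *pair, 3, 7              # SyntaxError on Python 3.7

def flat_37():
    model, optimizer = ("model", "optimizer")
    return model, optimizer, 3, 7   # same flat 4-tuple, parses on 3.7
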
From def645dcb767ea5825f785b7ce7bab5f9289301b Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Thu, 27 Aug 2020 10:41:34 +0200
Subject: [PATCH 095/116] Train script now automatically assigns host user inside
 docker

---
 train.sh | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/train.sh b/train.sh
index e0fb415..34d160f 100755
--- a/train.sh
+++ b/train.sh
@@ -5,5 +5,5 @@
 
 DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
 
-docker build . -f ./docker/training/Dockerfile -t clarinpl/punctuator_training
+docker build . -f ./docker/training/Dockerfile -t clarinpl/punctuator_training --build-arg USERNAME=$(whoami) --build-arg USER_UID=$(id -u) --build-arg USER_GID=$(id -g)
 docker run -v $DIR:/punctuator --name $2 --gpus all -it --entrypoint python clarinpl/punctuator_training -m $1
-- 
GitLab


From e9e9575627b568ff36a58b4cff5ac8c6ba867695 Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Thu, 27 Aug 2020 10:45:27 +0200
Subject: [PATCH 096/116] User config was set in the wrong Dockerfile

---
 Dockerfile                 | 16 ----------------
 docker/training/Dockerfile | 16 ++++++++++++++++
 2 files changed, 16 insertions(+), 16 deletions(-)

diff --git a/Dockerfile b/Dockerfile
index f32bcb4..a7acc60 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -16,20 +16,4 @@ ARG USERNAME=clarin
 ARG USER_UID=1000
 ARG USER_GID=1000
 
-# Create the user
-RUN groupadd --gid $USER_GID $USERNAME \
-    && useradd --uid $USER_UID --gid $USER_GID -m $USERNAME \
-    #
-    # [Optional] Add sudo support. Omit if you don't need to install software after connecting.
-    && apt-get update \
-    && apt-get install -y sudo \
-    && echo $USERNAME ALL=\(root\) NOPASSWD:ALL > /etc/sudoers.d/$USERNAME \
-    && chmod 0440 /etc/sudoers.d/$USERNAME
-
-# ********************************************************
-# * Anything else you want to do like clean up goes here *
-# ********************************************************
-
-USER ${USERNAME}
-
 ENTRYPOINT [ "./entrypoint.sh" ]
\ No newline at end of file
diff --git a/docker/training/Dockerfile b/docker/training/Dockerfile
index de4f169..cf834c4 100644
--- a/docker/training/Dockerfile
+++ b/docker/training/Dockerfile
@@ -1,5 +1,21 @@
 FROM clarinpl/cuda-python:3.7
 
+# Create the user
+RUN groupadd --gid $USER_GID $USERNAME \
+    && useradd --uid $USER_UID --gid $USER_GID -m $USERNAME \
+    #
+    # [Optional] Add sudo support. Omit if you don't need to install software after connecting.
+    && apt-get update \
+    && apt-get install -y sudo \
+    && echo $USERNAME ALL=\(root\) NOPASSWD:ALL > /etc/sudoers.d/$USERNAME \
+    && chmod 0440 /etc/sudoers.d/$USERNAME
+
+# ********************************************************
+# * Anything else you want to do like clean up goes here *
+# ********************************************************
+
+USER ${USERNAME}
+
 RUN DEBIAN_FRONTEND=noninteractive apt-get update && apt-get install -y gcc python3-dev
 RUN mkdir /punctuator
 WORKDIR /punctuator
-- 
GitLab


From 922992841d1b68d4630ef1ec5ef901c1f99fb98d Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Thu, 27 Aug 2020 10:47:26 +0200
Subject: [PATCH 097/116] Dockerfiles fixup

---
 Dockerfile                 |  4 ----
 docker/training/Dockerfile | 10 ++++------
 2 files changed, 4 insertions(+), 10 deletions(-)

diff --git a/Dockerfile b/Dockerfile
index a7acc60..ce5211b 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -12,8 +12,4 @@ COPY config.ini .
 COPY worker.py .
 COPY entrypoint.sh .
 
-ARG USERNAME=clarin
-ARG USER_UID=1000
-ARG USER_GID=1000
-
 ENTRYPOINT [ "./entrypoint.sh" ]
\ No newline at end of file
diff --git a/docker/training/Dockerfile b/docker/training/Dockerfile
index cf834c4..a1d264b 100644
--- a/docker/training/Dockerfile
+++ b/docker/training/Dockerfile
@@ -1,19 +1,17 @@
 FROM clarinpl/cuda-python:3.7
 
+ARG USERNAME=clarin
+ARG USER_UID=1000
+ARG USER_GID=1000
+
 # Create the user
 RUN groupadd --gid $USER_GID $USERNAME \
     && useradd --uid $USER_UID --gid $USER_GID -m $USERNAME \
-    #
-    # [Optional] Add sudo support. Omit if you don't need to install software after connecting.
     && apt-get update \
     && apt-get install -y sudo \
     && echo $USERNAME ALL=\(root\) NOPASSWD:ALL > /etc/sudoers.d/$USERNAME \
     && chmod 0440 /etc/sudoers.d/$USERNAME
 
-# ********************************************************
-# * Anything else you want to do like clean up goes here *
-# ********************************************************
-
 USER ${USERNAME}
 
 RUN DEBIAN_FRONTEND=noninteractive apt-get update && apt-get install -y gcc python3-dev
-- 
GitLab


From f47306e722058c19d926ec5815cb7e30ea846452 Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Thu, 27 Aug 2020 10:51:40 +0200
Subject: [PATCH 098/116] Fixup in training dockerfile

---
 docker/training/Dockerfile | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/docker/training/Dockerfile b/docker/training/Dockerfile
index a1d264b..d3c2355 100644
--- a/docker/training/Dockerfile
+++ b/docker/training/Dockerfile
@@ -1,5 +1,9 @@
 FROM clarinpl/cuda-python:3.7
 
+RUN DEBIAN_FRONTEND=noninteractive apt-get update && apt-get install -y gcc python3-dev
+RUN mkdir /punctuator
+WORKDIR /punctuator
+
 ARG USERNAME=clarin
 ARG USER_UID=1000
 ARG USER_GID=1000
@@ -14,9 +18,5 @@ RUN groupadd --gid $USER_GID $USERNAME \
 
 USER ${USERNAME}
 
-RUN DEBIAN_FRONTEND=noninteractive apt-get update && apt-get install -y gcc python3-dev
-RUN mkdir /punctuator
-WORKDIR /punctuator
-
 COPY requirements.txt requirements.txt
 RUN pip3 install -r requirements.txt && rm requirements.txt
\ No newline at end of file
-- 
GitLab
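
The move matters because RUN executes as whatever the most recent USER instruction set: once USER ${USERNAME} takes effect, apt-get and mkdir in / lack root privileges and the build fails, so system-level setup has to precede the switch. Schematically (an illustrative fragment, not the project file):

    FROM clarinpl/cuda-python:3.7
    RUN apt-get update && apt-get install -y gcc python3-dev  # as root: ok
    # ... user creation as in the patch ...
    USER ${USERNAME}
    RUN apt-get install -y curl  # as ${USERNAME}: fails without root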


From 29ad181c6cb2e808a8716f56cac9423c16b1103e Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Thu, 27 Aug 2020 10:56:08 +0200
Subject: [PATCH 099/116] Local bin is added to path in training dockerfile

---
 docker/training/Dockerfile | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/docker/training/Dockerfile b/docker/training/Dockerfile
index d3c2355..bad0d2e 100644
--- a/docker/training/Dockerfile
+++ b/docker/training/Dockerfile
@@ -14,7 +14,8 @@ RUN groupadd --gid $USER_GID $USERNAME \
     && apt-get update \
     && apt-get install -y sudo \
     && echo $USERNAME ALL=\(root\) NOPASSWD:ALL > /etc/sudoers.d/$USERNAME \
-    && chmod 0440 /etc/sudoers.d/$USERNAME
+    && chmod 0440 /etc/sudoers.d/$USERNAME \
+    && export PATH=PATH:/home/${USERNAME}/.local/bin
 
 USER ${USERNAME}
 
-- 
GitLab


From c1085123fd6d6a6b5868edadde32bc7ee13bb5a8 Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Thu, 27 Aug 2020 10:57:04 +0200
Subject: [PATCH 100/116] (fixup) Local bin is added to path in training
 dockerfile

---
 docker/training/Dockerfile | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docker/training/Dockerfile b/docker/training/Dockerfile
index bad0d2e..9047f9a 100644
--- a/docker/training/Dockerfile
+++ b/docker/training/Dockerfile
@@ -15,7 +15,7 @@ RUN groupadd --gid $USER_GID $USERNAME \
     && apt-get install -y sudo \
     && echo $USERNAME ALL=\(root\) NOPASSWD:ALL > /etc/sudoers.d/$USERNAME \
     && chmod 0440 /etc/sudoers.d/$USERNAME \
-    && export PATH=PATH:/home/${USERNAME}/.local/bin
+    && export PATH=$PATH:/home/${USERNAME}/.local/bin
 
 USER ${USERNAME}
 
-- 
GitLab


From 3f1ca956efc7e6cc6a7f692e0adf81d5057120be Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Thu, 27 Aug 2020 11:04:15 +0200
Subject: [PATCH 101/116] (fixup 2) Local bin is added to path in training
 dockerfile

---
 docker/training/Dockerfile | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/docker/training/Dockerfile b/docker/training/Dockerfile
index 9047f9a..b653588 100644
--- a/docker/training/Dockerfile
+++ b/docker/training/Dockerfile
@@ -14,8 +14,9 @@ RUN groupadd --gid $USER_GID $USERNAME \
     && apt-get update \
     && apt-get install -y sudo \
     && echo $USERNAME ALL=\(root\) NOPASSWD:ALL > /etc/sudoers.d/$USERNAME \
-    && chmod 0440 /etc/sudoers.d/$USERNAME \
-    && export PATH=$PATH:/home/${USERNAME}/.local/bin
+    && chmod 0440 /etc/sudoers.d/$USERNAME
+
+ENV PATH="/home/${USERNAME}/.local/bin:${PATH}"
 
 USER ${USERNAME}
 
-- 
GitLab
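
PATCHES 099 and 100 tried to extend PATH with export inside the RUN chain, but each RUN executes in its own shell, so the export dies when that shell exits. ENV, used here instead, is recorded in the image metadata and persists into later build steps and the running container:

    # Gone as soon as this RUN's shell exits:
    RUN export PATH=$PATH:/home/${USERNAME}/.local/bin
    # Persists into subsequent layers and into `docker run`:
    ENV PATH="/home/${USERNAME}/.local/bin:${PATH}"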


From b9e62bfa52b90d2b9d3576a8679154fa14820b84 Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Thu, 27 Aug 2020 11:38:01 +0000
Subject: [PATCH 102/116] Fixed missing threshold param

---
 src/pipelines/actions_based/train_mixed.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/src/pipelines/actions_based/train_mixed.py b/src/pipelines/actions_based/train_mixed.py
index 6186c10..69873ef 100755
--- a/src/pipelines/actions_based/train_mixed.py
+++ b/src/pipelines/actions_based/train_mixed.py
@@ -34,6 +34,7 @@ OUTPUT_PATH = f"{PROJECT_ROOT}/checkpoints/actions_mixed"
 
 if __name__ == "__main__":
     config = get_config()
+    threshold = config["actions"]["training_mixed"]["threshold"]
     embedding_size = config["actions"]["training_mixed"]["embedding_size"]
     num_heads = config["actions"]["training_mixed"]["num_heads"]
     num_layers = config["actions"]["training_mixed"]["num_layers"]
@@ -66,6 +67,7 @@ if __name__ == "__main__":
     else:
         params = ActionsModelMixedParams(
             tokenizer.vocab_size,
+            threshold,
             embedding_size,
             num_heads,
             num_layers,
-- 
GitLab
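
For orientation, the lookups above imply a nested configuration section roughly like the sketch below. The actual file format behind get_config() is not shown in this series, so both the shape and the values here are hypothetical (0.9 merely mirrors the ActionsModelMixedParams default visible later in the series):

    actions:
      training_mixed:
        threshold: 0.9
        embedding_size: 200
        num_heads: 4
        num_layers: 2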


From 5aaa7ac11d40f106012bc9dce9e3ac35bb03e403 Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Thu, 27 Aug 2020 13:37:31 +0200
Subject: [PATCH 103/116] Made Docker builds faster and more versatile

---
 docker/development/Dockerfile | 12 ++++--------
 docker/training/Dockerfile    |  8 ++++----
 2 files changed, 8 insertions(+), 12 deletions(-)

diff --git a/docker/development/Dockerfile b/docker/development/Dockerfile
index cba06d6..d6356a7 100644
--- a/docker/development/Dockerfile
+++ b/docker/development/Dockerfile
@@ -40,22 +40,18 @@ ENV NVIDIA_REQUIRE_CUDA "cuda>=10.2 brand=tesla,driver>=384,driver<385 brand=tes
 RUN pip3 install numpy pandas tqdm seaborn torch dask[complete] transformers pyarrow==0.17.1 pytest lxml
 RUN ln -s /usr/bin/pip3 /usr/bin/pip
 
-ARG USERNAME=mpogoda
-ARG USER_UID=1030
-ARG USER_GID=1032
+ARG USERNAME=clarin
+ARG USER_UID=1000
+ARG USER_GID=1000
 
 # Create the user
 RUN groupadd --gid $USER_GID $USERNAME \
     && useradd --uid $USER_UID --gid $USER_GID -m $USERNAME \
-    #
-    # [Optional] Add sudo support. Omit if you don't need to install software after connecting.
     && apt-get update \
     && apt-get install -y sudo \
     && echo $USERNAME ALL=\(root\) NOPASSWD:ALL > /etc/sudoers.d/$USERNAME \
     && chmod 0440 /etc/sudoers.d/$USERNAME
 
-# ********************************************************
-# * Anything else you want to do like clean up goes here *
-# ********************************************************
+ENV PATH="/home/${USERNAME}/.local/bin:${PATH}"
 
 USER ${USERNAME}
\ No newline at end of file
diff --git a/docker/training/Dockerfile b/docker/training/Dockerfile
index b653588..b3a7fba 100644
--- a/docker/training/Dockerfile
+++ b/docker/training/Dockerfile
@@ -4,6 +4,9 @@ RUN DEBIAN_FRONTEND=noninteractive apt-get update && apt-get install -y gcc pyth
 RUN mkdir /punctuator
 WORKDIR /punctuator
 
+COPY requirements.txt requirements.txt
+RUN pip3 install -r requirements.txt && rm requirements.txt
+
 ARG USERNAME=clarin
 ARG USER_UID=1000
 ARG USER_GID=1000
@@ -18,7 +21,4 @@ RUN groupadd --gid $USER_GID $USERNAME \
 
 ENV PATH="/home/${USERNAME}/.local/bin:${PATH}"
 
-USER ${USERNAME}
-
-COPY requirements.txt requirements.txt
-RUN pip3 install -r requirements.txt && rm requirements.txt
\ No newline at end of file
+USER ${USERNAME}
\ No newline at end of file
-- 
GitLab
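
The reordering here and in PATCHES 097-098 plays to Docker's layer cache: a build reuses every layer above the first changed instruction, and --build-arg values differ per host user, so the expensive apt and pip layers are kept user-independent and placed first. A condensed sketch of the resulting structure (not the verbatim file):

    FROM clarinpl/cuda-python:3.7

    # User-independent, expensive layers first: cached across users.
    RUN DEBIAN_FRONTEND=noninteractive apt-get update && apt-get install -y gcc python3-dev
    WORKDIR /punctuator
    COPY requirements.txt requirements.txt
    RUN pip3 install -r requirements.txt && rm requirements.txt

    # Per-user ARGs last: a new --build-arg only invalidates from here down.
    ARG USERNAME=clarin
    ARG USER_UID=1000
    ARG USER_GID=1000
    RUN groupadd --gid $USER_GID $USERNAME \
        && useradd --uid $USER_UID --gid $USER_GID -m $USERNAME
    ENV PATH="/home/${USERNAME}/.local/bin:${PATH}"
    USER ${USERNAME}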


From 6788089155f2c6f3aaf71d5bb320c443ab14721d Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Thu, 27 Aug 2020 11:40:19 +0000
Subject: [PATCH 104/116] Training script now depends on image build success

---
 train.sh | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/train.sh b/train.sh
index 34d160f..3d7da2d 100755
--- a/train.sh
+++ b/train.sh
@@ -5,5 +5,5 @@
 
 DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
 
-docker build . -f ./docker/training/Dockerfile -t clarinpl/punctuator_training --build-arg USERNAME=$(whoami) --build-arg USER_UID=$(id -u) --build-arg USER_GID=$(id -u)
+docker build . -f ./docker/training/Dockerfile -t clarinpl/punctuator_training --build-arg USERNAME=$(whoami) --build-arg USER_UID=$(id -u) --build-arg USER_GID=$(id -u) && \
 docker run -v $DIR:/punctuator --name $2 --gpus all -it --entrypoint python clarinpl/punctuator_training -m $1
-- 
GitLab
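
The fix relies on shell short-circuiting: with &&, the docker run line executes only when the build exits with status 0, whereas the previous bare newline would run a stale image even after a failed build. In miniature:

    false && echo "skipped: left-hand command failed"
    true && echo "runs: left-hand command succeeded"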


From f7655ba8858aea4dc14c468f16543de15def9180 Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Thu, 27 Aug 2020 12:06:08 +0000
Subject: [PATCH 105/116] Compatibility fixes

---
 src/models/actions_model_mixed.py               | 5 ++++-
 src/pipelines/actions_based/train_base.py       | 3 ++-
 src/pipelines/actions_based/train_mixed.py      | 3 ++-
 src/pipelines/actions_based/train_restricted.py | 3 ++-
 src/utils.py                                    | 4 +++-
 5 files changed, 13 insertions(+), 5 deletions(-)

diff --git a/src/models/actions_model_mixed.py b/src/models/actions_model_mixed.py
index 5e1fc5f..212477f 100644
--- a/src/models/actions_model_mixed.py
+++ b/src/models/actions_model_mixed.py
@@ -219,7 +219,10 @@ class ActionsModelMixed(PunctuationModel):
     @staticmethod
     def load(dir: str, name: str, device: str) -> PunctuationModel:
         params = pickle_read(f"{dir}/{name}.config")
-        model = ActionsModelMixed(params).to(device)
+        
+        model = ActionsModelMixed(params)
+        model.to(device)
+
         model.load_state_dict(torch.load(f"{dir}/{name}.model", map_location=device))
 
         return model
diff --git a/src/pipelines/actions_based/train_base.py b/src/pipelines/actions_based/train_base.py
index 1b0bd13..c8597f0 100755
--- a/src/pipelines/actions_based/train_base.py
+++ b/src/pipelines/actions_based/train_base.py
@@ -59,7 +59,8 @@ if __name__ == "__main__":
         model, optimizer, epoch_start, sample_start = loader.load_latest()
     else:
         params = ActionsModelBaseParams(base_model, len(ACTIONS_KEYS))
-        model = ActionsModelBase(params).to(device)
+        model = ActionsModelBase(params)
+        model.to(device)
 
         optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
         epoch_start, sample_start = (0, 0)
diff --git a/src/pipelines/actions_based/train_mixed.py b/src/pipelines/actions_based/train_mixed.py
index 69873ef..3f2e661 100755
--- a/src/pipelines/actions_based/train_mixed.py
+++ b/src/pipelines/actions_based/train_mixed.py
@@ -76,7 +76,8 @@ if __name__ == "__main__":
             500,
             dropout,
         )
-        model = ActionsModelMixed(params).to(device)
+        model = ActionsModelMixed(params)
+        model.to(device)
         optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
         epoch_start, sample_start = (0, 0)
 
diff --git a/src/pipelines/actions_based/train_restricted.py b/src/pipelines/actions_based/train_restricted.py
index 06e5c20..ed43789 100755
--- a/src/pipelines/actions_based/train_restricted.py
+++ b/src/pipelines/actions_based/train_restricted.py
@@ -62,7 +62,8 @@ if __name__ == "__main__":
         model, optimizer, epoch_start, sample_start = loader.load_latest()
     else:
         params = ActionsModelRestrictedParams(base_model, len(ACTIONS_KEYS) + 1)
-        model = ActionsModelRestricted(params).to(device)
+        model = ActionsModelRestricted(params)
+        model.to(device)
         optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
         epoch_start, sample_start = (0, 0)
 
diff --git a/src/utils.py b/src/utils.py
index b14a173..a7cb1c4 100644
--- a/src/utils.py
+++ b/src/utils.py
@@ -119,7 +119,9 @@ class Loader:
             return None
 
         epoch, step = model_id
-        return self.load(f"{epoch}-{step}"), epoch, step
+        model, optimizer = self.load(f"{epoch}-{step}")
+
+        return model, optimizer, epoch, step
 
 
 class Checkpoint:
-- 
GitLab
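
The Model(params).to(device) one-liners are split up because chaining only works when to() returns the module. nn.Module.to does, but ActionsModelMixed overrides to() (visible in the context above) and, as far as this series shows, returns nothing, so the chained form binds model to None. A minimal reproduction with a hypothetical Broken module:

    import torch.nn as nn

    class Broken(nn.Module):
        def to(self, device):  # override without `return self`
            super().to(device)

    model = Broken().to("cpu")
    print(model)  # None -- chaining silently loses the module

    model = Broken()  # the pattern adopted in this patch:
    model.to("cpu")   # correct regardless of what to() returns
    print(model)      # Broken()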


From b39161b9c1f70a5ec30f33f1b3ba1a07aad91257 Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Mon, 31 Aug 2020 09:08:43 +0000
Subject: [PATCH 106/116] Created generic script for applying a model from the command line

---
 punctuate.py                                  | 131 +++++++++++++-----
 src/models/TransformerSeq2Seq.py              |   5 +-
 src/models/actions_model_base.py              |   5 +-
 src/models/actions_model_mixed.py             |  12 +-
 src/models/actions_model_restricted.py        |  20 ++-
 src/models/interfaces.py                      |   5 +
 src/pipelines/actions_based/processing.py     |   2 +
 .../actions_based/stage1_extraction.py        |   5 +-
 .../actions_based/stage2_tokenization.py      |   5 +-
 .../actions_based/stage3_exploding.py         |   7 +-
 src/pipelines/actions_based/train_mixed.py    |   1 +
 src/pipelines/actions_based/utils.py          |  32 +++++
 src/pipelines/translation_based/processing.py |   4 +-
 .../translation_based/stage1_extraction.py    |  10 +-
 .../stage2_create_batches.py                  |   5 +-
 src/pipelines/translation_based/train.py      |  17 ++-
 src/utils.py                                  |   3 +-
 tests/models/test_actions_model_mixed.py      |   1 +
 18 files changed, 221 insertions(+), 49 deletions(-)

diff --git a/punctuate.py b/punctuate.py
index 510eaad..3cf0fe6 100755
--- a/punctuate.py
+++ b/punctuate.py
@@ -1,52 +1,117 @@
-#!/usr/bin/python3
-
 import argparse
-import os
-from argparse import Namespace
+from src.pipelines.actions_based.utils import max_suppression
+from src.pipelines.actions_based.processing import (
+    ACTIONS_KEYS,
+    action_vector,
+    recover_text,
+    token_labels_to_word_labels,
+)
+from src.models.interfaces import ActionsModel
+from typing import Dict
 
-from src.pipelines.actions_based.processing import apply_actions_punctuation
-from src.pipelines.actions_based.utils import load_model
-from src.utils import preprocess
+import dask.dataframe as dd
+import numpy as np
+import torch
+from tqdm import trange
 
+from src.batch_loading import get_ordered_dataframe_len
+from src.models.actions_model_base import ActionsModelBase
+from src.models.actions_model_mixed import ActionsModelMixed
+from src.models.actions_model_restricted import ActionsModelRestricted
+from src.pipelines.actions_based.scoring import Metrics
+from src.utils import (
+    PROJECT_ROOT,
+    get_config,
+    input_preprocess,
+    output_preprocess,
+    unflattened_column,
+)
+from transformers import BertTokenizerFast
+import colored
 
-def get_args() -> Namespace:
-    parser = argparse.ArgumentParser(
-        description="Adds punctuaiton in to raw text stream."
-    )
+SUPPORTED_MODELS: Dict[str, ActionsModel] = {
+    "base": ActionsModelBase,
+    "restricted": ActionsModelRestricted,
+    "mixed": ActionsModelMixed,
+}
+
+
+def print_highlighted(text: str, word_labels: np.ndarray, action_name: str) -> None:
+    label_id = np.argwhere(np.array(ACTIONS_KEYS) == action_name)
+
+    text = text.split(" ")
+    for label, word in zip(word_labels, text):
+        SPAN = 255 - 232
+
+        bg_color = int(label[label_id] * (SPAN - 1) + 232)
+        print(colored.bg(bg_color) + colored.fg(2) + word, end=" ")
+    print("")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Evaluate actions model")
     parser.add_argument(
-        "-i", "--input", type=str, required=True, help="Path to input text file",
+        "-a",
+        "--architecture",
+        required=True,
+        choices=SUPPORTED_MODELS.keys(),
+        help="Model architecture",
     )
     parser.add_argument(
-        "-o", "--output", type=str, required=True, help="Path to input text file",
+        "-d",
+        "--directory",
+        required=True,
+        help="Directory where trained model is located, relative to project root",
     )
     parser.add_argument(
-        "-m", "--model", required=True, type=str, help="Path to the pretrained model",
+        "-i", "--input", required=True, type=str, help="Input text file"
     )
+    parser.add_argument("-m", "--model", default="final", help="Pretrained model name")
     parser.add_argument(
-        "-b", "--base", required=True, type=str, help="Name of base model",
+        "-l",
+        "--highlight",
+        type=str,
+        required=False,
+        choices=ACTIONS_KEYS + ["none"],
+        default="none",
+        help="Highlight prediction confidence of selected action per-word",
     )
     parser.add_argument(
-        "-c", "--chunk_size", default=500, type=int, help="Maximum chunk size"
+        "-dv",
+        "--device",
+        type=str,
+        required=False,
+        default="cpu",
+        help="Device on which inference will be made",
     )
-    parser.add_argument("-t", "--threshold", default=0.9, type=float, help="Threshold")
-
-    return parser.parse_args()
+    args = parser.parse_args()
 
+    print(f"Loading model {args.model}...")
+    device = torch.device(args.device)
+    model_location = f"{PROJECT_ROOT}/{args.directory}"
+    model_type = SUPPORTED_MODELS[args.architecture]
+    model = model_type.load(model_location, args.model, device)
 
-if __name__ == "__main__":
-    args = get_args()
-
-    if not os.path.exists(args.input):
-        print(f"Error: File '{args.input}' does not exists")
-        exit(-1)
+    print("Loading text...")
+    with open(args.input, "r") as f:
+        text = f.read()
 
-    tokenizer, model = load_model(args.model, args.base, "cpu")
+    print(f"Inferencing...")
+    tokenizer = model.tokenizer()
+    data = input_preprocess(output_preprocess(text))
+    data_tokenized = tokenizer(data, return_tensors="pt")
 
-    with open(args.input, "r") as f:
-        text = preprocess(f.read())
-        text_processed = apply_actions_punctuation(
-            text, args.chunk_size, tokenizer, model, args.threshold
-        )
+    predictions = (
+        model.predict_raw(data_tokenized["input_ids"], data_tokenized["attention_mask"])
+        .detach()
+        .cpu()
+        .numpy()
+    )
+    word_labels = token_labels_to_word_labels(data, predictions[0, 1:-1], tokenizer)
+    word_labels_suppresed = max_suppression(np.expand_dims(word_labels, axis=0), 0.9)[0]
+    text_recovered = recover_text(data, word_labels_suppresed)
 
-    with open(args.output, "w") as f:
-        f.write(text_processed)
+    if args.highlight != "none":
+        print_highlighted(text_recovered, word_labels, args.highlight)
+    else:
+        print(text_recovered)
diff --git a/src/models/TransformerSeq2Seq.py b/src/models/TransformerSeq2Seq.py
index 3009fae..2e4b5ef 100644
--- a/src/models/TransformerSeq2Seq.py
+++ b/src/models/TransformerSeq2Seq.py
@@ -41,7 +41,10 @@ class TransformerSeq2Seq(nn.Module):
         self.embedding_to_words = nn.Linear(embedding_size, vocab_size)
 
     def forward(
-        self, source: torch.Tensor, target: torch.Tensor, source_mask: torch.Tensor,
+        self,
+        source: torch.Tensor,
+        target: torch.Tensor,
+        source_mask: torch.Tensor,
     ) -> torch.Tensor:
         """Full encoder-decoder pass
 
diff --git a/src/models/actions_model_base.py b/src/models/actions_model_base.py
index 67b58e7..e6df927 100644
--- a/src/models/actions_model_base.py
+++ b/src/models/actions_model_base.py
@@ -42,7 +42,7 @@ class ActionsModelBase(ActionsModel):
         super(ActionsModelBase, self).__init__()
         self.params = params
 
-        self.tokenizer = BertTokenizerFast.from_pretrained(params.base_model)
+        self._tokenizer = BertTokenizerFast.from_pretrained(params.base_model)
         config = PretrainedConfig.from_pretrained(params.base_model)
         config.num_labels = params.num_labels
 
@@ -79,6 +79,9 @@ class ActionsModelBase(ActionsModel):
 
         return self.forward(input_ids, attention_mask=attention_mask).sigmoid()
 
+    def tokenizer(self) -> BertTokenizerFast:
+        return self._tokenizer
+
     def save(self, dir: str, name: str) -> None:
         prepare_folder(dir)
         torch.save(self.state_dict(), f"{dir}/{name}.model")
diff --git a/src/models/actions_model_mixed.py b/src/models/actions_model_mixed.py
index 212477f..0216410 100644
--- a/src/models/actions_model_mixed.py
+++ b/src/models/actions_model_mixed.py
@@ -24,6 +24,7 @@ class ActionsModelMixedParams:
     Parameters for initializing ActionsModelMixed
 
     Params:
+        base_tokenizer (str): Name of pretrained tokenizer
         vocab_size (int): Number of tokens in tokenizer dictionary
         embedding_size (int, optional): Shape of word and punctuation embeddings. Defaults to 200.
         num_heads (int, optional): Number of heads in multiheaded attention. Defaults to 4.
@@ -34,6 +35,7 @@ class ActionsModelMixedParams:
         dropout (float, optional): Dropout ratio. Defaults to 0.1.
     """
 
+    base_tokenizer: str
     vocab_size: int
     threshold: float = 0.9
     embedding_size: int = 200
@@ -57,6 +59,7 @@ class ActionsModelMixed(PunctuationModel):
         super(ActionsModelMixed, self).__init__()
 
         self.params = params
+        self._tokenizer = None
 
         self.num_labels = params.num_labels
         self.device = "cpu"
@@ -141,6 +144,13 @@ class ActionsModelMixed(PunctuationModel):
 
         super(ActionsModelMixed, self).to(device)
 
+    def tokenizer(self) -> BertTokenizerFast:
+        if self._tokenizer is None:
+            self._tokenizer = BertTokenizerFast.from_pretrained(
+                self.config.base_tokenizer
+            )
+        return self._tokenizer
+
     def predict(
         self,
         text: str,
@@ -219,7 +229,7 @@ class ActionsModelMixed(PunctuationModel):
     @staticmethod
     def load(dir: str, name: str, device: str) -> PunctuationModel:
         params = pickle_read(f"{dir}/{name}.config")
-        
+
         model = ActionsModelMixed(params)
         model.to(device)
 
diff --git a/src/models/actions_model_restricted.py b/src/models/actions_model_restricted.py
index 12b7f16..b2ce596 100644
--- a/src/models/actions_model_restricted.py
+++ b/src/models/actions_model_restricted.py
@@ -6,6 +6,7 @@ import torch
 import torch.nn as nn
 from transformers.configuration_utils import PretrainedConfig
 from transformers.modeling_bert import BertForTokenClassification
+from transformers.tokenization_bert import BertTokenizerFast
 
 from src.models.actions_model_mixed import ActionsModelMixed
 from src.models.interfaces import ActionsModel, PunctuationModel
@@ -91,6 +92,11 @@ class ActionsModelRestricted(ActionsModel):
 
         return torch.log(z / (1 - z))
 
+    def tokenizer(self) -> BertTokenizerFast:
+        if self._tokenizer is None:
+            self._tokenizer = BertTokenizerFast.from_pretrained(self.config.base_model)
+        return self._tokenizer
+
     def save(self, dir: str, name: str) -> None:
         prepare_folder(dir)
         torch.save(self.state_dict(), f"{dir}/{name}.model")
@@ -100,7 +106,12 @@ class ActionsModelRestricted(ActionsModel):
     def load(dir: str, name: str, device: str) -> ActionsModelRestricted:
         params = pickle_read(f"{dir}/{name}.config")
         model = ActionsModelRestricted(params).to(device)
-        model.load_state_dict(torch.load(f"{dir}/{name}.model", map_location=device,))
+        model.load_state_dict(
+            torch.load(
+                f"{dir}/{name}.model",
+                map_location=device,
+            )
+        )
 
         return model
 
@@ -158,6 +169,11 @@ class ActionsModelRestrictedLoss(nn.Module):
         params = pickle_read(f"{dir}/{name}.config")
         model = ActionsModelMixed(params)
 
-        model.load_state_dict(torch.load(f"{dir}/{name}.model", map_location=device,))
+        model.load_state_dict(
+            torch.load(
+                f"{dir}/{name}.model",
+                map_location=device,
+            )
+        )
 
         return model
diff --git a/src/models/interfaces.py b/src/models/interfaces.py
index 9fb15a7..ed74090 100644
--- a/src/models/interfaces.py
+++ b/src/models/interfaces.py
@@ -4,12 +4,17 @@ from abc import ABC, abstractmethod
 
 import torch
 import torch.nn as nn
+from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
 
 
 class PunctuationModel(nn.Module, ABC):
     def __init__(self) -> None:
         super().__init__()
 
+    @abstractmethod
+    def tokenizer(self) -> PreTrainedTokenizerFast:
+        pass
+
     @abstractmethod
     def save(self, dir: str, name: str) -> None:
         pass
diff --git a/src/pipelines/actions_based/processing.py b/src/pipelines/actions_based/processing.py
index f65b9b4..aaf3b7e 100644
--- a/src/pipelines/actions_based/processing.py
+++ b/src/pipelines/actions_based/processing.py
@@ -9,6 +9,8 @@ from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
 from src.utils import input_preprocess, output_preprocess
 
 ACTIONS_KEYS = ["upper_case", "dot", "colon", "question_mark"]
+UPPERCASE_INDEX = 0
+PUNCTUATION_INDEXES = [1, 2, 3]
 
 
 def apply_file_processing(x: dict) -> dict:
diff --git a/src/pipelines/actions_based/stage1_extraction.py b/src/pipelines/actions_based/stage1_extraction.py
index 5a058a9..94dc26c 100644
--- a/src/pipelines/actions_based/stage1_extraction.py
+++ b/src/pipelines/actions_based/stage1_extraction.py
@@ -6,7 +6,10 @@ import numpy as np
 import pandas as pd
 from dask.distributed import Client
 
-from src.pipelines.actions_based.processing import APPLY_FILE_PROCESSING_META, apply_file_processing
+from src.pipelines.actions_based.processing import (
+    APPLY_FILE_PROCESSING_META,
+    apply_file_processing,
+)
 from src.utils import PROJECT_ROOT, get_config, prepare_folder
 
 INPUT_FOLDER = f"{PROJECT_ROOT}/data"
diff --git a/src/pipelines/actions_based/stage2_tokenization.py b/src/pipelines/actions_based/stage2_tokenization.py
index 0ea3586..b30445f 100644
--- a/src/pipelines/actions_based/stage2_tokenization.py
+++ b/src/pipelines/actions_based/stage2_tokenization.py
@@ -4,7 +4,10 @@ import dask.dataframe as dd
 from dask.distributed import Client
 from transformers import BertTokenizerFast
 
-from src.pipelines.actions_based.processing import APPLY_TOKENIZATION_META, apply_tokenization
+from src.pipelines.actions_based.processing import (
+    APPLY_TOKENIZATION_META,
+    apply_tokenization,
+)
 from src.utils import PROJECT_ROOT, get_config, prepare_folder
 
 INPUT_FOLDER = f"{PROJECT_ROOT}/generated/actions/stage1_extraction"
diff --git a/src/pipelines/actions_based/stage3_exploding.py b/src/pipelines/actions_based/stage3_exploding.py
index 81dc965..72ec128 100644
--- a/src/pipelines/actions_based/stage3_exploding.py
+++ b/src/pipelines/actions_based/stage3_exploding.py
@@ -2,7 +2,12 @@
 import dask.dataframe as dd
 from dask.distributed import Client
 
-from src.processing import EXPAND_DIMS_META, FLATTEN_DIMS_META, expand_dims, flatten_dims
+from src.processing import (
+    EXPAND_DIMS_META,
+    FLATTEN_DIMS_META,
+    expand_dims,
+    flatten_dims,
+)
 from src.utils import PROJECT_ROOT, get_config, prepare_folder
 
 INPUT_FOLDER = f"{PROJECT_ROOT}/generated/actions/stage2_tokenization"
diff --git a/src/pipelines/actions_based/train_mixed.py b/src/pipelines/actions_based/train_mixed.py
index 3f2e661..fd44e27 100755
--- a/src/pipelines/actions_based/train_mixed.py
+++ b/src/pipelines/actions_based/train_mixed.py
@@ -66,6 +66,7 @@ if __name__ == "__main__":
         model, optimizer, epoch_start, sample_start = loader.load_latest()
     else:
         params = ActionsModelMixedParams(
+            base_model,
             tokenizer.vocab_size,
             threshold,
             embedding_size,
diff --git a/src/pipelines/actions_based/utils.py b/src/pipelines/actions_based/utils.py
index a8728e1..bfe7cfc 100644
--- a/src/pipelines/actions_based/utils.py
+++ b/src/pipelines/actions_based/utils.py
@@ -7,6 +7,8 @@ from transformers import BertForTokenClassification, BertTokenizerFast, Pretrain
 
 from src.pipelines.actions_based.processing import (
     ACTIONS_KEYS,
+    PUNCTUATION_INDEXES,
+    UPPERCASE_INDEX,
     action_vector,
     last_stop_label,
     recover_text,
@@ -39,6 +41,36 @@ def load_model(
     return tokenizer, model
 
 
+def max_suppression(predictions: np.ndarray, threshold: float) -> np.ndarray:
+    """Converts raw prediction into action-vector with punctuation
+    limited to one sign.
+
+    Args:
+        predictions (np.ndarray): Raw predictions from the model
+        threshold (float): Thresholding value
+
+    Returns:
+        np.ndarray: Suppressed, thresholded action-vector
+    """
+    output = np.zeros_like(predictions)
+
+    output[:, :, 0] = (predictions[:, :, UPPERCASE_INDEX] >= threshold).astype(np.int)
+
+    def assign_most_probable(x):
+        res = np.zeros_like(x)
+
+        if x.max() > threshold:
+            res[x.argmax()] = 1
+
+        return res
+
+    output[:, :, PUNCTUATION_INDEXES] = np.apply_along_axis(
+        assign_most_probable, -1, predictions[:, :, PUNCTUATION_INDEXES]
+    )
+
+    return output
+
+
 def apply_actions_punctuation(
     text: str,
     chunk_size: int,
diff --git a/src/pipelines/translation_based/processing.py b/src/pipelines/translation_based/processing.py
index 608cf43..d0d360d 100644
--- a/src/pipelines/translation_based/processing.py
+++ b/src/pipelines/translation_based/processing.py
@@ -225,7 +225,9 @@ def standarize_translation_sample(
         np.ndarray: Output sequence of length total_length
     """
     return add_padding(
-        add_begin_end_tokens(seq, begin_token, end_token), total_length, padding_symbol,
+        add_begin_end_tokens(seq, begin_token, end_token),
+        total_length,
+        padding_symbol,
     )
 
 
diff --git a/src/pipelines/translation_based/stage1_extraction.py b/src/pipelines/translation_based/stage1_extraction.py
index 386211d..8bc91a5 100644
--- a/src/pipelines/translation_based/stage1_extraction.py
+++ b/src/pipelines/translation_based/stage1_extraction.py
@@ -6,7 +6,10 @@ import numpy as np
 import pandas as pd
 from dask.distributed import Client
 
-from src.pipelines.translation_based.processing import RAW_TO_DATAFRAME_META, raw_to_dataframe
+from src.pipelines.translation_based.processing import (
+    RAW_TO_DATAFRAME_META,
+    raw_to_dataframe,
+)
 from src.utils import PROJECT_ROOT, get_config, prepare_folder
 
 INPUT_FOLDER = f"{PROJECT_ROOT}/data"
@@ -34,7 +37,10 @@ if __name__ == "__main__":
     df = dd.from_pandas(pd.DataFrame({"file": files_paths}), npartitions=num_partitions)
 
     df = df.apply(
-        raw_to_dataframe, result_type="expand", axis=1, meta=RAW_TO_DATAFRAME_META,
+        raw_to_dataframe,
+        result_type="expand",
+        axis=1,
+        meta=RAW_TO_DATAFRAME_META,
     )
     df = df.dropna()
 
diff --git a/src/pipelines/translation_based/stage2_create_batches.py b/src/pipelines/translation_based/stage2_create_batches.py
index ade8bf2..83a2edc 100644
--- a/src/pipelines/translation_based/stage2_create_batches.py
+++ b/src/pipelines/translation_based/stage2_create_batches.py
@@ -4,7 +4,10 @@ from dask import delayed
 from dask.distributed import Client
 from transformers import BertTokenizerFast
 
-from src.pipelines.translation_based.processing import GENERATE_BATCHES_META, generate_batches
+from src.pipelines.translation_based.processing import (
+    GENERATE_BATCHES_META,
+    generate_batches,
+)
 from src.utils import PROJECT_ROOT, get_config, prepare_folder
 
 INPUT_FOLDER = f"{PROJECT_ROOT}/generated/translations/stage1_extraction"
diff --git a/src/pipelines/translation_based/train.py b/src/pipelines/translation_based/train.py
index 67fcf7a..6ba2d54 100755
--- a/src/pipelines/translation_based/train.py
+++ b/src/pipelines/translation_based/train.py
@@ -49,7 +49,14 @@ if __name__ == "__main__":
 
     tokenizer = BertTokenizerFast.from_pretrained(base_model)
 
-    model = TransformerSeq2Seq(tokenizer.vocab_size, 256, max_len, 4, 4, 4,).to(device)
+    model = TransformerSeq2Seq(
+        tokenizer.vocab_size,
+        256,
+        max_len,
+        4,
+        4,
+        4,
+    ).to(device)
     criterion = torch.nn.CrossEntropyLoss(reduction="mean").to(device)
     optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
 
@@ -62,11 +69,15 @@ if __name__ == "__main__":
         if latest is not None:
             epoch, batch = latest
             model.load_state_dict(
-                torch.load(f"{OUTPUT_PATH}/{epoch}-{batch}.model", map_location=device,)
+                torch.load(
+                    f"{OUTPUT_PATH}/{epoch}-{batch}.model",
+                    map_location=device,
+                )
             )
             optimizer.load_state_dict(
                 torch.load(
-                    f"{OUTPUT_PATH}/{epoch}-{batch}.optimizer", map_location=device,
+                    f"{OUTPUT_PATH}/{epoch}-{batch}.optimizer",
+                    map_location=device,
                 )
             )
 
diff --git a/src/utils.py b/src/utils.py
index a7cb1c4..54cea84 100644
--- a/src/utils.py
+++ b/src/utils.py
@@ -558,5 +558,6 @@ def save_training_step(
 
     if optimizer is not None:
         torch.save(
-            optimizer.state_dict(), f"{dir}/{name}.optimizer",
+            optimizer.state_dict(),
+            f"{dir}/{name}.optimizer",
         )
diff --git a/tests/models/test_actions_model_mixed.py b/tests/models/test_actions_model_mixed.py
index 79d17e2..7bd8549 100644
--- a/tests/models/test_actions_model_mixed.py
+++ b/tests/models/test_actions_model_mixed.py
@@ -24,6 +24,7 @@ def test_dimensions():
     dropout = 0.1
 
     params = ActionsModelMixedParams(
+        base_model,
         tokenizer.vocab_size,
         threshold,
         embedding_size,
-- 
GitLab
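
For reference, an assumed invocation of the rewritten punctuate.py, matching the argparse definition above (the checkpoint directory and file names are illustrative):

    # -a: base | restricted | mixed; -d is relative to the project root;
    # -m defaults to "final", -l to "none" and -dv to "cpu".
    python punctuate.py -a base -d checkpoints/actions_base \
        -i input.txt -m final -l upper_case -dv cpu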


From dbf4c13c4ef603f41e54e0d6794b101087f48946 Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Mon, 31 Aug 2020 11:45:40 +0200
Subject: [PATCH 107/116] Style fixes, getting nlp_ws from clarin pypi

---
 punctuate.py     | 10 +---------
 requirements.txt |  4 +++-
 tox.ini          |  2 +-
 3 files changed, 5 insertions(+), 11 deletions(-)

diff --git a/punctuate.py b/punctuate.py
index 3cf0fe6..bc6a023 100755
--- a/punctuate.py
+++ b/punctuate.py
@@ -2,31 +2,23 @@ import argparse
 from src.pipelines.actions_based.utils import max_suppression
 from src.pipelines.actions_based.processing import (
     ACTIONS_KEYS,
-    action_vector,
     recover_text,
     token_labels_to_word_labels,
 )
 from src.models.interfaces import ActionsModel
 from typing import Dict
 
-import dask.dataframe as dd
 import numpy as np
 import torch
-from tqdm import trange
 
-from src.batch_loading import get_ordered_dataframe_len
 from src.models.actions_model_base import ActionsModelBase
 from src.models.actions_model_mixed import ActionsModelMixed
 from src.models.actions_model_restricted import ActionsModelRestricted
-from src.pipelines.actions_based.scoring import Metrics
 from src.utils import (
     PROJECT_ROOT,
-    get_config,
     input_preprocess,
     output_preprocess,
-    unflattened_column,
 )
-from transformers import BertTokenizerFast
 import colored
 
 SUPPORTED_MODELS: Dict[str, ActionsModel] = {
@@ -96,7 +88,7 @@ if __name__ == "__main__":
     with open(args.input, "r") as f:
         text = f.read()
 
-    print(f"Inferencing...")
+    print("Inferencing...")
     tokenizer = model.tokenizer()
     data = input_preprocess(output_preprocess(text))
     data_tokenized = tokenizer(data, return_tensors="pt")
diff --git a/requirements.txt b/requirements.txt
index 700a3a8..e4d8dd7 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,3 +1,4 @@
+--index-url https://pypi.clarin-pl.eu/simple/
 attrs==19.3.0
 bokeh==2.1.1
 certifi==2020.6.20
@@ -57,4 +58,5 @@ typing-extensions==3.7.4.2
 urllib3==1.25.10
 zict==2.0.0
 scikit-learn==0.23.2
-git+https://gitlab.clarin-pl.eu/nlpworkers/nlp_ws.git@fa5f09a2f1447cac2c411c9d9e3d927ecd815ddc#egg=nlp_ws
\ No newline at end of file
+nlp_ws==0.6
+colored==1.4.2
\ No newline at end of file
diff --git a/tox.ini b/tox.ini
index eab80b1..7f8448b 100644
--- a/tox.ini
+++ b/tox.ini
@@ -3,7 +3,7 @@ envlist = unittest,pep8
 skipsdist = True
 
 [testenv]
-deps = -rrequirements.txt
+deps =  -rrequirements.txt
 
 [testenv:unittest]
 commands = pytest --ignore data --ignore generated
-- 
GitLab
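
One hedged caveat on the new requirements header: --index-url replaces PyPI as the sole package index, so every pinned package must be resolvable from the CLARIN mirror. If the intent was to consult the internal index in addition to PyPI, the alternative directive would be:

    --extra-index-url https://pypi.clarin-pl.eu/simple/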


From 0a9f7adab75415f542885ce6abb67d41725b1474 Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Mon, 31 Aug 2020 12:03:52 +0000
Subject: [PATCH 108/116] Hotfixes

---
 punctuate.py                      | 1 +
 src/models/actions_model_mixed.py | 2 +-
 2 files changed, 2 insertions(+), 1 deletion(-)

diff --git a/punctuate.py b/punctuate.py
index bc6a023..29e79bf 100755
--- a/punctuate.py
+++ b/punctuate.py
@@ -83,6 +83,7 @@ if __name__ == "__main__":
     model_location = f"{PROJECT_ROOT}/{args.directory}"
     model_type = SUPPORTED_MODELS[args.architecture]
     model = model_type.load(model_location, args.model, device)
+    model.train(False)
 
     print("Loading text...")
     with open(args.input, "r") as f:
diff --git a/src/models/actions_model_mixed.py b/src/models/actions_model_mixed.py
index 0216410..e8bd584 100644
--- a/src/models/actions_model_mixed.py
+++ b/src/models/actions_model_mixed.py
@@ -147,7 +147,7 @@ class ActionsModelMixed(PunctuationModel):
     def tokenizer(self) -> BertTokenizerFast:
         if self._tokenizer is None:
             self._tokenizer = BertTokenizerFast.from_pretrained(
-                self.config.base_tokenizer
+                self.params.base_tokenizer
             )
         return self._tokenizer
 
-- 
GitLab


From 8e62064d16a5ba34c1e1eb398b2acfa7fad436ac Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Mon, 31 Aug 2020 12:09:11 +0000
Subject: [PATCH 109/116] Typing - changed device from str to torch.Device

---
 src/models/actions_model_base.py       | 2 +-
 src/models/actions_model_mixed.py      | 2 +-
 src/models/actions_model_restricted.py | 4 ++--
 src/models/interfaces.py               | 2 +-
 src/utils.py                           | 6 +++---
 5 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/src/models/actions_model_base.py b/src/models/actions_model_base.py
index e6df927..834bf26 100644
--- a/src/models/actions_model_base.py
+++ b/src/models/actions_model_base.py
@@ -88,7 +88,7 @@ class ActionsModelBase(ActionsModel):
         pickle_save(self.params, f"{dir}/{name}.config")
 
     @staticmethod
-    def load(dir: str, name: str, device: str) -> ActionsModelBase:
+    def load(dir: str, name: str, device: torch.Device) -> ActionsModelBase:
         params = pickle_read(f"{dir}/{name}.config")
         model = ActionsModelBase(params).to(device)
         model.load_state_dict(torch.load(f"{dir}/{name}.model", map_location=device))
diff --git a/src/models/actions_model_mixed.py b/src/models/actions_model_mixed.py
index e8bd584..66b42ac 100644
--- a/src/models/actions_model_mixed.py
+++ b/src/models/actions_model_mixed.py
@@ -227,7 +227,7 @@ class ActionsModelMixed(PunctuationModel):
         pickle_save(self.params, f"{dir}/{name}.config")
 
     @staticmethod
-    def load(dir: str, name: str, device: str) -> PunctuationModel:
+    def load(dir: str, name: str, device: torch.Device) -> PunctuationModel:
         params = pickle_read(f"{dir}/{name}.config")
 
         model = ActionsModelMixed(params)
diff --git a/src/models/actions_model_restricted.py b/src/models/actions_model_restricted.py
index b2ce596..b78ab74 100644
--- a/src/models/actions_model_restricted.py
+++ b/src/models/actions_model_restricted.py
@@ -103,7 +103,7 @@ class ActionsModelRestricted(ActionsModel):
         pickle_save(self.params, f"{dir}/{name}.config")
 
     @staticmethod
-    def load(dir: str, name: str, device: str) -> ActionsModelRestricted:
+    def load(dir: str, name: str, device: torch.Device) -> ActionsModelRestricted:
         params = pickle_read(f"{dir}/{name}.config")
         model = ActionsModelRestricted(params).to(device)
         model.load_state_dict(
@@ -165,7 +165,7 @@ class ActionsModelRestrictedLoss(nn.Module):
         pickle_save(self.params, f"{dir}/{name}.config")
 
     @staticmethod
-    def load(dir: str, name: str, device: str) -> PunctuationModel:
+    def load(dir: str, name: str, device: torch.Device) -> PunctuationModel:
         params = pickle_read(f"{dir}/{name}.config")
         model = ActionsModelMixed(params)
 
diff --git a/src/models/interfaces.py b/src/models/interfaces.py
index ed74090..0f41381 100644
--- a/src/models/interfaces.py
+++ b/src/models/interfaces.py
@@ -21,7 +21,7 @@ class PunctuationModel(nn.Module, ABC):
 
     @staticmethod
     @abstractmethod
-    def load(dir: str, name: str, device: str) -> PunctuationModel:
+    def load(dir: str, name: str, device: torch.Device) -> PunctuationModel:
         pass
 
 
diff --git a/src/utils.py b/src/utils.py
index 54cea84..3ec73a3 100644
--- a/src/utils.py
+++ b/src/utils.py
@@ -59,7 +59,7 @@ class Loader:
         save_dir: str,
         model_type: Type[PunctuationModel],
         optimizer_type: Type[Optimizer],
-        device: str,
+        device: torch.Device,
     ) -> None:
         """Initializes Loader
 
@@ -67,7 +67,7 @@ class Loader:
             save_dir (str): Directory where to search for models
             model_type (Type[PunctuationModel]): Model class that should be loaded
             optimizer_type (Type[Optimizer]): Optimizer class that should be loaded
-            device (str): Device on which loaded model/optimizer will exists
+            device (torch.Device): Device on which loaded model/optimizer will exists
         """
         self.save_dir = save_dir
         self.device = device
@@ -218,7 +218,7 @@ class Timeout:
 class ProgressTracker:
     """Utility class used to tracking loss and displaying it to user"""
 
-    def __init__(self, device: str, loss_averaging_span: int) -> None:
+    def __init__(self, device: torch.Device, loss_averaging_span: int) -> None:
         """Initializes ProgressTracker
 
         Args:
-- 
GitLab


From 8f50b3b11ee4d0323ba07a9e48264538dcd3d40b Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Mon, 31 Aug 2020 14:13:15 +0200
Subject: [PATCH 110/116] Fixed torch.Device -> torch.device

---
 src/models/actions_model_base.py       | 2 +-
 src/models/actions_model_mixed.py      | 2 +-
 src/models/actions_model_restricted.py | 4 ++--
 src/models/interfaces.py               | 2 +-
 src/utils.py                           | 8 ++++----
 5 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/src/models/actions_model_base.py b/src/models/actions_model_base.py
index 834bf26..70da0c4 100644
--- a/src/models/actions_model_base.py
+++ b/src/models/actions_model_base.py
@@ -88,7 +88,7 @@ class ActionsModelBase(ActionsModel):
         pickle_save(self.params, f"{dir}/{name}.config")
 
     @staticmethod
-    def load(dir: str, name: str, device: torch.Device) -> ActionsModelBase:
+    def load(dir: str, name: str, device: torch.device) -> ActionsModelBase:
         params = pickle_read(f"{dir}/{name}.config")
         model = ActionsModelBase(params).to(device)
         model.load_state_dict(torch.load(f"{dir}/{name}.model", map_location=device))
diff --git a/src/models/actions_model_mixed.py b/src/models/actions_model_mixed.py
index 66b42ac..ef003cc 100644
--- a/src/models/actions_model_mixed.py
+++ b/src/models/actions_model_mixed.py
@@ -227,7 +227,7 @@ class ActionsModelMixed(PunctuationModel):
         pickle_save(self.params, f"{dir}/{name}.config")
 
     @staticmethod
-    def load(dir: str, name: str, device: torch.Device) -> PunctuationModel:
+    def load(dir: str, name: str, device: torch.device) -> PunctuationModel:
         params = pickle_read(f"{dir}/{name}.config")
 
         model = ActionsModelMixed(params)
diff --git a/src/models/actions_model_restricted.py b/src/models/actions_model_restricted.py
index b78ab74..963e355 100644
--- a/src/models/actions_model_restricted.py
+++ b/src/models/actions_model_restricted.py
@@ -103,7 +103,7 @@ class ActionsModelRestricted(ActionsModel):
         pickle_save(self.params, f"{dir}/{name}.config")
 
     @staticmethod
-    def load(dir: str, name: str, device: torch.Device) -> ActionsModelRestricted:
+    def load(dir: str, name: str, device: torch.device) -> ActionsModelRestricted:
         params = pickle_read(f"{dir}/{name}.config")
         model = ActionsModelRestricted(params).to(device)
         model.load_state_dict(
@@ -165,7 +165,7 @@ class ActionsModelRestrictedLoss(nn.Module):
         pickle_save(self.params, f"{dir}/{name}.config")
 
     @staticmethod
-    def load(dir: str, name: str, device: torch.Device) -> PunctuationModel:
+    def load(dir: str, name: str, device: torch.device) -> PunctuationModel:
         params = pickle_read(f"{dir}/{name}.config")
         model = ActionsModelMixed(params)
 
diff --git a/src/models/interfaces.py b/src/models/interfaces.py
index 0f41381..7abab33 100644
--- a/src/models/interfaces.py
+++ b/src/models/interfaces.py
@@ -21,7 +21,7 @@ class PunctuationModel(nn.Module, ABC):
 
     @staticmethod
     @abstractmethod
-    def load(dir: str, name: str, device: torch.Device) -> PunctuationModel:
+    def load(dir: str, name: str, device: torch.device) -> PunctuationModel:
         pass
 
 
diff --git a/src/utils.py b/src/utils.py
index 3ec73a3..f804a56 100644
--- a/src/utils.py
+++ b/src/utils.py
@@ -59,7 +59,7 @@ class Loader:
         save_dir: str,
         model_type: Type[PunctuationModel],
         optimizer_type: Type[Optimizer],
-        device: torch.Device,
+        device: torch.device,
     ) -> None:
         """Initializes Loader
 
@@ -67,7 +67,7 @@ class Loader:
             save_dir (str): Directory where to search for models
             model_type (Type[PunctuationModel]): Model class that should be loaded
             optimizer_type (Type[Optimizer]): Optimizer class that should be loaded
-            device (torch.Device): Device on which loaded model/optimizer will exists
+            device (torch.device): Device on which loaded model/optimizer will exists
         """
         self.save_dir = save_dir
         self.device = device
@@ -218,11 +218,11 @@ class Timeout:
 class ProgressTracker:
     """Utility class used to tracking loss and displaying it to user"""
 
-    def __init__(self, device: torch.Device, loss_averaging_span: int) -> None:
+    def __init__(self, device: torch.device, loss_averaging_span: int) -> None:
         """Initializes ProgressTracker
 
         Args:
-            device (str): Device on which training is performed
+            device (torch.device): Device on which training is performed
             loss_averaging_span (int): Number of latest samples used to calculate average loss
         """
         print(f"Training on {device}")
-- 
GitLab
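
For the record, torch has no torch.Device: torch.device (lowercase) is simultaneously the type and its constructor, so it serves for construction, isinstance checks and annotations alike:

    import torch

    dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    assert isinstance(dev, torch.device)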


From 470164e142aeb43b526bd1c13f2fd39059f3fb9b Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Tue, 1 Sep 2020 11:15:38 +0200
Subject: [PATCH 111/116] All models now follow single inference API

Also allowed simple YAML-based runtime model configuration files (thresholds etc.)
---
 src/models/TransformerSeq2Seq.py              |   5 +-
 src/models/actions_model_base.py              |  97 ++++++++++++++--
 src/models/actions_model_mixed.py             |  56 ++++++---
 src/models/actions_model_restricted.py        | 107 +++++++++++++++---
 src/models/interfaces.py                      |   6 +-
 src/pipelines/actions_based/processing.py     |   2 +
 src/pipelines/translation_based/processing.py |   4 +-
 .../translation_based/stage1_extraction.py    |   5 +-
 src/pipelines/translation_based/train.py      |  17 +--
 src/utils.py                                  |  25 +++-
 tests/models/test_actions_model_base.py       |  13 +++
 tests/models/test_actions_model_mixed.py      |  25 ++++
 tests/models/test_actions_model_restricted.py |  13 +++
 tests/test_utils.py                           |  26 +++++
 14 files changed, 336 insertions(+), 65 deletions(-)

diff --git a/src/models/TransformerSeq2Seq.py b/src/models/TransformerSeq2Seq.py
index 2e4b5ef..3009fae 100644
--- a/src/models/TransformerSeq2Seq.py
+++ b/src/models/TransformerSeq2Seq.py
@@ -41,10 +41,7 @@ class TransformerSeq2Seq(nn.Module):
         self.embedding_to_words = nn.Linear(embedding_size, vocab_size)
 
     def forward(
-        self,
-        source: torch.Tensor,
-        target: torch.Tensor,
-        source_mask: torch.Tensor,
+        self, source: torch.Tensor, target: torch.Tensor, source_mask: torch.Tensor,
     ) -> torch.Tensor:
         """Full encoder-decoder pass
 
diff --git a/src/models/actions_model_base.py b/src/models/actions_model_base.py
index 70da0c4..b06eb6a 100644
--- a/src/models/actions_model_base.py
+++ b/src/models/actions_model_base.py
@@ -1,17 +1,27 @@
 from __future__ import annotations
 
+import os
 from dataclasses import dataclass
 
+import numpy as np
 import torch
 import torch.nn as nn
+from torch import threshold
 from torch.nn.modules.loss import BCEWithLogitsLoss
 from transformers.configuration_utils import PretrainedConfig
 from transformers.modeling_bert import BertForTokenClassification
 from transformers.tokenization_bert import BertTokenizerFast
 
 from src.models.interfaces import ActionsModel
-from src.pipelines.actions_based.processing import ACTIONS_KEYS
-from src.utils import pickle_read, pickle_save, prepare_folder
+from src.pipelines.actions_based.processing import (
+    ACTIONS_KEYS,
+    action_vector,
+    last_stop_label,
+    recover_text,
+    token_labels_to_word_labels,
+)
+from src.pipelines.actions_based.utils import max_suppression
+from src.utils import pickle_read, pickle_save, prepare_folder, yaml_serializable
 
 
 @dataclass
@@ -29,18 +39,38 @@ class ActionsModelBaseParams:
     num_labels: int = len(ACTIONS_KEYS)
 
 
+@yaml_serializable
+@dataclass
+class ActionsModelBaseRuntimeParams:
+    """
+    Parameters for ActionsModelBase during runtime inference
+
+    Args:
+        threshold (float): minimum confidence for applying action
+        chunksize (int): Maximum number of tokens processed per inference chunk
+    """
+
+    threshold: float = 0.9
+    chunksize: int = 500
+
+
 class ActionsModelBase(ActionsModel):
     """Model based on simple multilabel per-token classifiaction. Each token is binarly classified in n-dimensions"""
 
-    def __init__(self, params: ActionsModelBaseParams) -> None:
+    def __init__(
+        self,
+        params: ActionsModelBaseParams,
+        runtime: ActionsModelBaseRuntimeParams = ActionsModelBaseRuntimeParams(),
+    ) -> None:
         """Initializes actions model
 
         Args:
-            base_model (str): Name of base model
-            num_labels (int): Length of action vector
+            params (ActionsModelBaseParams): Params defining model's structure
+            runtime (ActionsModelBaseRuntimeParams): Params defining model's runtime inference
         """
         super(ActionsModelBase, self).__init__()
         self.params = params
+        self.runtime = runtime
 
         self._tokenizer = BertTokenizerFast.from_pretrained(params.base_model)
         config = PretrainedConfig.from_pretrained(params.base_model)
@@ -79,18 +109,71 @@ class ActionsModelBase(ActionsModel):
 
         return self.forward(input_ids, attention_mask=attention_mask).sigmoid()
 
+    def predict(self, text: str) -> str:
+        text = text.strip()
+
+        tokenizer = self.tokenizer()
+        tokens = tokenizer(text, return_tensors="pt")["input_ids"]
+        output = None
+
+        index_start = 0
+        while index_start < len(tokens[0]):
+            index_end = min(index_start + self.runtime.chunksize, len(tokens[0]))
+
+            tokens_chunk = tokens[:, index_start:index_end]
+
+            actions = (
+                self.predict_raw(tokens_chunk, torch.ones_like(tokens_chunk))
+                .detach()
+                .cpu()
+                .numpy()
+            )
+            actions_suppresed = max_suppression(actions, self.runtime.threshold)[0]
+
+            offset = last_stop_label(actions_suppresed, action_vector(["dot"]))
+
+            # Prevent infinite loop
+            if (offset is None) or (offset == 0):
+                offset = index_end - index_start
+
+            if output is None:
+                output = actions[0, 0:offset]
+            else:
+                output = np.concatenate([output, actions[0, 0:offset]], axis=0)
+
+            index_start += offset
+
+        assert len(output) == len(tokens[0])
+
+        word_labels = token_labels_to_word_labels(text, output[1:-1], tokenizer)
+        actions = max_suppression(
+            np.expand_dims(word_labels, 0), self.runtime.threshold
+        )[0]
+
+        return recover_text(text, actions)
+
     def tokenizer(self) -> BertTokenizerFast:
         return self._tokenizer
 
-    def save(self, dir: str, name: str) -> None:
+    def save(self, dir: str, name: str, runtime: bool = True) -> None:
         prepare_folder(dir)
         torch.save(self.state_dict(), f"{dir}/{name}.model")
         pickle_save(self.params, f"{dir}/{name}.config")
 
+        if runtime:
+            self.runtime.save_yaml(f"{dir}/{name}.runtime.yaml")
+
     @staticmethod
     def load(dir: str, name: str, device: torch.device) -> ActionsModelBase:
         params = pickle_read(f"{dir}/{name}.config")
-        model = ActionsModelBase(params).to(device)
+        if os.path.exists(f"{dir}/{name}.runtime.yaml"):
+            runtime = ActionsModelBaseRuntimeParams.load_yaml(
+                f"{dir}/{name}.runtime.yaml"
+            )
+        else:
+            runtime = ActionsModelBaseRuntimeParams()
+
+        model = ActionsModelBase(params, runtime).to(device)
         model.load_state_dict(torch.load(f"{dir}/{name}.model", map_location=device))
 
         return model
diff --git a/src/models/actions_model_mixed.py b/src/models/actions_model_mixed.py
index ef003cc..ac0162a 100644
--- a/src/models/actions_model_mixed.py
+++ b/src/models/actions_model_mixed.py
@@ -1,3 +1,4 @@
+import os
 from dataclasses import dataclass
 from typing import Optional
 
@@ -12,10 +13,11 @@ from src.models.interfaces import PunctuationModel
 from src.pipelines.actions_based.processing import (
     ACTIONS_KEYS,
     action_vector,
+    last_stop_label,
     recover_text,
     token_labels_to_word_labels,
 )
-from src.utils import pickle_read, pickle_save, prepare_folder
+from src.utils import pickle_read, pickle_save, prepare_folder, yaml_serializable
 
 
 @dataclass
@@ -47,10 +49,29 @@ class ActionsModelMixedParams:
     dropout: float = 0.1
 
 
+@yaml_serializable
+@dataclass
+class ActionsModelMixedRuntimeParams:
+    """
+    Parameters for ActionsModelMixed during runtime inference
+
+    Args:
+        threshold (float): Minimum confidence for applying an action
+        max_cond_len (Optional[int]): Maximum number of previously predicted action vectors used as decoder conditioning
+    """
+
+    threshold: float = 0.9
+    max_cond_len: Optional[int] = 500
+
+
 class ActionsModelMixed(PunctuationModel):
     """Encoder-decoder based model with unpunctuated token sequence as input and array of action-vectors as output"""
 
-    def __init__(self, params: ActionsModelMixedParams) -> None:
+    def __init__(
+        self,
+        params: ActionsModelMixedParams,
+        runtime: ActionsModelMixedRuntimeParams = ActionsModelMixedRuntimeParams(),
+    ) -> None:
         """Initializes mixed model
 
         Args:
@@ -59,6 +80,7 @@ class ActionsModelMixed(PunctuationModel):
         super(ActionsModelMixed, self).__init__()
 
         self.params = params
+        self.runtime = runtime
         self._tokenizer = None
 
         self.num_labels = params.num_labels
@@ -151,19 +173,15 @@ class ActionsModelMixed(PunctuationModel):
             )
         return self._tokenizer
 
-    def predict(
-        self,
-        text: str,
-        tokenizer: BertTokenizerFast,
-        threshold: float = 0.9,
-        max_cond_len: Optional[int] = None,
-    ) -> str:
+    def predict(self, text: str) -> str:
         inputs = [action_vector(["upper_case"])]
 
+        tokenizer = self.tokenizer()
         text_tokenized = tokenizer(text, return_tensors="pt")
 
         target_device = self.device
 
+        max_cond_len = self.runtime.max_cond_len
         if max_cond_len is None:
             max_cond_len = np.iinfo(np.int).max
 
@@ -181,9 +199,10 @@ class ActionsModelMixed(PunctuationModel):
             ).sigmoid()
 
             inputs.append(
-                (prediction_raw.detach().cpu().numpy()[0, -1, :] > threshold).astype(
-                    np.float
-                )
+                (
+                    prediction_raw.detach().cpu().numpy()[0, -1, :]
+                    > self.runtime.threshold
+                ).astype(np.float)
             )
 
         word_labels = token_labels_to_word_labels(text, inputs[1:], tokenizer)
@@ -221,16 +240,25 @@ class ActionsModelMixed(PunctuationModel):
 
         return outputs
 
-    def save(self, dir: str, name: str) -> None:
+    def save(self, dir: str, name: str, runtime: bool = True) -> None:
         prepare_folder(dir)
         torch.save(self.state_dict(), f"{dir}/{name}.model")
         pickle_save(self.params, f"{dir}/{name}.config")
 
+        if runtime:
+            self.runtime.save_yaml(f"{dir}/{name}.runtime.yaml")
+
     @staticmethod
     def load(dir: str, name: str, device: torch.device) -> PunctuationModel:
         params = pickle_read(f"{dir}/{name}.config")
+        if os.path.exists(f"{dir}/{name}.runtime.yaml"):
+            runtime = ActionsModelMixedRuntimeParams.load_yaml(
+                f"{dir}/{name}.runtime.yaml"
+            )
+        else:
+            runtime = ActionsModelMixedRuntimeParams()
 
-        model = ActionsModelMixed(params)
+        model = ActionsModelMixed(params, runtime)
         model.to(device)
 
         model.load_state_dict(torch.load(f"{dir}/{name}.model", map_location=device))
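The simplified `ActionsModelMixed.predict` above is autoregressive: it re-runs the decoder once per word, feeding back the thresholded action vectors predicted so far (seeded with an `upper_case` action vector), and `runtime.max_cond_len` caps how much of that history is used as conditioning. A toy sketch of the loop shape under those assumptions, where `step_fn` stands in for the real encoder-decoder call:

```python
import numpy as np


def autoregressive_decode(step_fn, num_steps: int, seed_action: np.ndarray,
                          threshold: float = 0.9, max_cond_len: int = 500) -> np.ndarray:
    actions = [seed_action]  # decoding starts from a seed action vector
    for _ in range(num_steps):
        window = np.stack(actions[-max_cond_len:])  # bounded conditioning context
        next_probs = step_fn(window)                # probabilities for the next position
        actions.append((next_probs > threshold).astype(np.float32))
    return np.stack(actions[1:])                    # drop the seed


# Usage with a dummy step function over 4 action slots:
rng = np.random.default_rng(0)
labels = autoregressive_decode(lambda w: rng.random(4), num_steps=10,
                               seed_action=np.ones(4, dtype=np.float32))
assert labels.shape == (10, 4)
```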
diff --git a/src/models/actions_model_restricted.py b/src/models/actions_model_restricted.py
index 963e355..0ec7735 100644
--- a/src/models/actions_model_restricted.py
+++ b/src/models/actions_model_restricted.py
@@ -1,7 +1,9 @@
 from __future__ import annotations
 
+import os
 from dataclasses import dataclass
 
+import numpy as np
 import torch
 import torch.nn as nn
 from transformers.configuration_utils import PretrainedConfig
@@ -10,7 +12,14 @@ from transformers.tokenization_bert import BertTokenizerFast
 
 from src.models.actions_model_mixed import ActionsModelMixed
 from src.models.interfaces import ActionsModel, PunctuationModel
-from src.utils import pickle_read, pickle_save, prepare_folder
+from src.pipelines.actions_based.processing import (
+    action_vector,
+    last_stop_label,
+    recover_text,
+    token_labels_to_word_labels,
+)
+from src.pipelines.actions_based.utils import max_suppression
+from src.utils import pickle_read, pickle_save, prepare_folder, yaml_serializable
 
 
 @dataclass
@@ -27,12 +36,31 @@ class ActionsModelRestrictedParams:
     extended_action_vector_size: int
 
 
+@yaml_serializable
+@dataclass
+class ActionsModelRestrictedRuntimeParams:
+    """
+    Parameters for ActionsModelRestricted during runtime inference
+
+    Args:
+        threshold (float): Minimum confidence for applying an action
+        chunksize (int): Maximum number of tokens processed in a single inference chunk
+    """
+
+    threshold: float = 0.9
+    chunksize: int = 500
+
+
 class ActionsModelRestricted(ActionsModel):
     """Similar to ActionsModelBase, however no-punctuation class is added
     and punctuation-related entries are treaded as proper categorical distribution
     """
 
-    def __init__(self, params: ActionsModelRestrictedParams) -> None:
+    def __init__(
+        self,
+        params: ActionsModelRestrictedParams,
+        runtime: ActionsModelRestrictedRuntimeParams = ActionsModelRestrictedRuntimeParams(),
+    ) -> None:
         """Initializes restricted actions model
 
         Args:
@@ -42,6 +70,8 @@ class ActionsModelRestricted(ActionsModel):
         super(ActionsModelRestricted, self).__init__()
 
         self.params = params
+        self.runtime = runtime
+        self._tokenizer = None
 
         config = PretrainedConfig.from_pretrained(params.base_model)
 
@@ -84,6 +114,51 @@ class ActionsModelRestricted(ActionsModel):
 
         return torch.cat([prob_uppercase, prob_punctuation], dim=-1)
 
+    def predict(self, text: str) -> str:
+        # TODO: Generalize
+        chunk_size = 500
+        threshold = 0.9
+
+        text = text.strip()
+
+        tokenizer = self.tokenizer()
+        tokens = tokenizer(text, return_tensors="pt")["input_ids"]
+        output = None
+
+        index_start = 0
+        while index_start < len(tokens[0]):
+            index_end = min(index_start + chunk_size, len(tokens[0]))
+
+            tokens_chunk = tokens[:, index_start:index_end]
+
+            actions = (
+                self.predict_raw(tokens_chunk, torch.ones_like(tokens_chunk))
+                .detach()
+                .cpu()
+                .numpy()
+            )
+            actions_suppressed = max_suppression(actions, threshold)[0]
+
+            offset = last_stop_label(actions_suppressed, action_vector(["dot"]))
+
+            # Prevent infinite loop
+            if (offset is None) or (offset == 0):
+                offset = index_end - index_start
+
+            if output is None:
+                output = actions[0, 0:offset]
+            else:
+                output = np.concatenate([output, actions[0, 0:offset]], axis=0)
+
+            index_start += offset
+
+        assert len(output) == len(tokens[0])
+
+        word_labels = token_labels_to_word_labels(text, output[1:-1], tokenizer)
+        actions = max_suppression(np.expand_dims(word_labels, 0), threshold)[0]
+
+        return recover_text(text, actions)
+
     @staticmethod
     def _logit(x: torch.Tensor):
         EPS = 1e-5
@@ -94,24 +169,29 @@ class ActionsModelRestricted(ActionsModel):
 
     def tokenizer(self) -> BertTokenizerFast:
         if self._tokenizer is None:
-            self._tokenizer = BertTokenizerFast.from_pretrained(self.config.base_model)
+            self._tokenizer = BertTokenizerFast.from_pretrained(self.params.base_model)
         return self._tokenizer
 
-    def save(self, dir: str, name: str) -> None:
+    def save(self, dir: str, name: str, runtime: bool = True) -> None:
         prepare_folder(dir)
         torch.save(self.state_dict(), f"{dir}/{name}.model")
         pickle_save(self.params, f"{dir}/{name}.config")
 
+        if runtime:
+            self.runtime.save_yaml(f"{dir}/{name}.runtime.yaml")
+
     @staticmethod
     def load(dir: str, name: str, device: torch.device) -> ActionsModelRestricted:
         params = pickle_read(f"{dir}/{name}.config")
-        model = ActionsModelRestricted(params).to(device)
-        model.load_state_dict(
-            torch.load(
-                f"{dir}/{name}.model",
-                map_location=device,
+        if os.path.exists(f"{dir}/{name}.runtime.yaml"):
+            runtime = ActionsModelRestrictedRuntimeParams.load_yaml(
+                f"{dir}/{name}.runtime.yaml"
             )
-        )
+        else:
+            runtime = ActionsModelRestrictedRuntimeParams()
+
+        model = ActionsModelRestricted(params, runtime).to(device)
+        model.load_state_dict(torch.load(f"{dir}/{name}.model", map_location=device,))
 
         return model
 
@@ -169,11 +249,6 @@ class ActionsModelRestrictedLoss(nn.Module):
         params = pickle_read(f"{dir}/{name}.config")
         model = ActionsModelMixed(params)
 
-        model.load_state_dict(
-            torch.load(
-                f"{dir}/{name}.model",
-                map_location=device,
-            )
-        )
+        model.load_state_dict(torch.load(f"{dir}/{name}.model", map_location=device,))
 
         return model
diff --git a/src/models/interfaces.py b/src/models/interfaces.py
index 7abab33..4627145 100644
--- a/src/models/interfaces.py
+++ b/src/models/interfaces.py
@@ -16,7 +16,7 @@ class PunctuationModel(nn.Module, ABC):
         pass
 
     @abstractmethod
-    def save(self, dir: str, name: str) -> None:
+    def save(self, dir: str, name: str, runtime: bool = True) -> None:
         pass
 
     @staticmethod
@@ -43,3 +43,7 @@ class ActionsModel(PunctuationModel):
             torch.Tensor: Per-token action-vector labels. Shape BxLxA
         """
         pass
+
+    @abstractmethod
+    def predict(self, text: str) -> str:
+        pass
diff --git a/src/pipelines/actions_based/processing.py b/src/pipelines/actions_based/processing.py
index aaf3b7e..4f19a1a 100644
--- a/src/pipelines/actions_based/processing.py
+++ b/src/pipelines/actions_based/processing.py
@@ -128,6 +128,8 @@ def last_stop_label(labels: np.array, stop_action: np.array) -> Optional[int]:
 
     assert len(labels.shape) == 2
     assert len(stop_action.shape) == 1
+    assert stop_action.shape[0] == labels.shape[-1]
+
     stop_labels = np.argwhere(np.all(labels == stop_action, axis=1))
 
     if len(stop_labels) == 0:
diff --git a/src/pipelines/translation_based/processing.py b/src/pipelines/translation_based/processing.py
index d0d360d..608cf43 100644
--- a/src/pipelines/translation_based/processing.py
+++ b/src/pipelines/translation_based/processing.py
@@ -225,9 +225,7 @@ def standarize_translation_sample(
         np.ndarray: Output sequence of length total_length
     """
     return add_padding(
-        add_begin_end_tokens(seq, begin_token, end_token),
-        total_length,
-        padding_symbol,
+        add_begin_end_tokens(seq, begin_token, end_token), total_length, padding_symbol,
     )
 
 
diff --git a/src/pipelines/translation_based/stage1_extraction.py b/src/pipelines/translation_based/stage1_extraction.py
index 8bc91a5..6ffdbf7 100644
--- a/src/pipelines/translation_based/stage1_extraction.py
+++ b/src/pipelines/translation_based/stage1_extraction.py
@@ -37,10 +37,7 @@ if __name__ == "__main__":
     df = dd.from_pandas(pd.DataFrame({"file": files_paths}), npartitions=num_partitions)
 
     df = df.apply(
-        raw_to_dataframe,
-        result_type="expand",
-        axis=1,
-        meta=RAW_TO_DATAFRAME_META,
+        raw_to_dataframe, result_type="expand", axis=1, meta=RAW_TO_DATAFRAME_META,
     )
     df = df.dropna()
 
diff --git a/src/pipelines/translation_based/train.py b/src/pipelines/translation_based/train.py
index 6ba2d54..67fcf7a 100755
--- a/src/pipelines/translation_based/train.py
+++ b/src/pipelines/translation_based/train.py
@@ -49,14 +49,7 @@ if __name__ == "__main__":
 
     tokenizer = BertTokenizerFast.from_pretrained(base_model)
 
-    model = TransformerSeq2Seq(
-        tokenizer.vocab_size,
-        256,
-        max_len,
-        4,
-        4,
-        4,
-    ).to(device)
+    model = TransformerSeq2Seq(tokenizer.vocab_size, 256, max_len, 4, 4, 4,).to(device)
     criterion = torch.nn.CrossEntropyLoss(reduction="mean").to(device)
     optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
 
@@ -69,15 +62,11 @@ if __name__ == "__main__":
         if latest is not None:
             epoch, batch = latest
             model.load_state_dict(
-                torch.load(
-                    f"{OUTPUT_PATH}/{epoch}-{batch}.model",
-                    map_location=device,
-                )
+                torch.load(f"{OUTPUT_PATH}/{epoch}-{batch}.model", map_location=device,)
             )
             optimizer.load_state_dict(
                 torch.load(
-                    f"{OUTPUT_PATH}/{epoch}-{batch}.optimizer",
-                    map_location=device,
+                    f"{OUTPUT_PATH}/{epoch}-{batch}.optimizer", map_location=device,
                 )
             )
 
diff --git a/src/utils.py b/src/utils.py
index f804a56..07e9090 100644
--- a/src/utils.py
+++ b/src/utils.py
@@ -558,6 +558,27 @@ def save_training_step(
 
     if optimizer is not None:
         torch.save(
-            optimizer.state_dict(),
-            f"{dir}/{name}.optimizer",
+            optimizer.state_dict(), f"{dir}/{name}.optimizer",
         )
+
+
+def yaml_serializable(cls):
+    def save_yaml(self, path: str) -> None:
+        yml = yaml.dump(self.__dict__)
+        with open(path, "w") as f:
+            f.write(yml)
+
+    @staticmethod
+    def load_yaml(path: str) -> cls:
+        with open(path, "r") as f:
+            yml = f.read()
+
+        obj = cls()
+        obj.__dict__ = yaml.load(yml, Loader=yaml.FullLoader)
+
+        return obj
+
+    setattr(cls, "save_yaml", save_yaml)
+    setattr(cls, "load_yaml", load_yaml)
+
+    return cls
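A quick round-trip of the decorator (assuming `yaml` is imported in `src/utils.py`, and using a hypothetical `Point` dataclass). Note that `load_yaml` constructs the object with no arguments before overwriting `__dict__`, so the decorated dataclass needs defaults for every field:

```python
from dataclasses import dataclass


@yaml_serializable
@dataclass
class Point:
    x: int = 0
    y: int = 0


p = Point(x=2, y=5)
p.save_yaml("/tmp/point.yaml")
restored = Point.load_yaml("/tmp/point.yaml")
assert restored == p  # dataclass field-wise equality
```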
diff --git a/tests/models/test_actions_model_base.py b/tests/models/test_actions_model_base.py
index cee3952..900cf89 100644
--- a/tests/models/test_actions_model_base.py
+++ b/tests/models/test_actions_model_base.py
@@ -6,6 +6,7 @@ from src.models.actions_model_base import (
     ActionsModelBaseLoss,
     ActionsModelBaseParams,
 )
+from src.pipelines.actions_based.processing import ACTIONS_KEYS
 
 
 def test_dimensions():
@@ -44,3 +45,15 @@ def test_loss_dimensions():
     result_bad = loss(actions_vector_bad, actions_vector_true)
 
     assert result_perfect < result_bad
+
+
+def test_predict():
+    params = ActionsModelBaseParams(
+        "dkleczek/bert-base-polish-cased-v1", len(ACTIONS_KEYS)
+    )
+    model = ActionsModelBase(params)
+
+    input_str = "testowy ciag znakow"
+    result = model.predict(input_str)
+
+    assert len(result) >= len(input_str)
diff --git a/tests/models/test_actions_model_mixed.py b/tests/models/test_actions_model_mixed.py
index 7bd8549..786136d 100644
--- a/tests/models/test_actions_model_mixed.py
+++ b/tests/models/test_actions_model_mixed.py
@@ -5,7 +5,9 @@ from src.models.actions_model_mixed import (
     ActionsModelMixed,
     ActionsModelMixedLoss,
     ActionsModelMixedParams,
+    ActionsModelMixedRuntimeParams,
 )
+from src.pipelines.actions_based.processing import ACTIONS_KEYS
 
 
 def test_dimensions():
@@ -67,3 +69,26 @@ def test_loss_dimensions():
     result_bad = loss(actions_vector_true, actions_vector_bad)
 
     assert result_perfect < result_bad
+
+
+def test_predict():
+    tokenizer = BertTokenizerFast.from_pretrained("dkleczek/bert-base-polish-cased-v1")
+    params = ActionsModelMixedParams(
+        "dkleczek/bert-base-polish-cased-v1",
+        tokenizer.vocab_size,
+        0.9,
+        10,
+        2,
+        1,
+        10,
+        len(ACTIONS_KEYS),
+        500,
+        0.1,
+    )
+    runtime = ActionsModelMixedRuntimeParams(0.9, 100)
+    model = ActionsModelMixed(params, runtime)
+
+    input_str = "testowy ciag znakow"
+    result = model.predict(input_str)
+
+    assert len(result) >= len(input_str)
diff --git a/tests/models/test_actions_model_restricted.py b/tests/models/test_actions_model_restricted.py
index 2ea18ef..b659b27 100644
--- a/tests/models/test_actions_model_restricted.py
+++ b/tests/models/test_actions_model_restricted.py
@@ -6,6 +6,7 @@ from src.models.actions_model_restricted import (
     ActionsModelRestrictedLoss,
     ActionsModelRestrictedParams,
 )
+from src.pipelines.actions_based.processing import ACTIONS_KEYS
 
 
 def test_dimensions():
@@ -59,3 +60,15 @@ def test_loss_dimensions():
     assert result_perfect < result_bad
     assert result_perfect > 0
     assert result_bad > 0
+
+
+def test_predict():
+    params = ActionsModelRestrictedParams(
+        "dkleczek/bert-base-polish-cased-v1", len(ACTIONS_KEYS) + 1
+    )
+    model = ActionsModelRestricted(params)
+
+    input_str = "testowy ciag znakow"
+    result = model.predict(input_str)
+
+    assert len(result) >= len(input_str)
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 49fbc4a..c4d35b1 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -1,3 +1,6 @@
+import os
+from dataclasses import dataclass
+
 from src.utils import (
     convert_to_timedelta,
     input_preprocess,
@@ -5,6 +8,7 @@ from src.utils import (
     output_preprocess,
     remove_multiple_spaces,
     remove_punctuation,
+    yaml_serializable,
 )
 
 
@@ -77,3 +81,25 @@ def test_latest_model():
 
     files.append("/path/tam/pam/1-500.a")
     assert latest_model(files) == (1, 1000)
+
+
+def test_yaml_serializable(fs):
+    fs.create_dir("/var")
+
+    @yaml_serializable
+    @dataclass
+    class Test:
+        x: int = 3
+        y: str = "test1"
+
+    x = Test()
+    x.x = -1
+    x.y = "test2"
+    x.save_yaml("/var/test.yaml")
+
+    assert os.path.exists("/var/test.yaml")
+
+    y = Test.load_yaml("/var/test.yaml")
+
+    assert y.x == -1
+    assert y.y == "test2"
-- 
GitLab


From 59bb49b2d51c02454fae30989db4fae64dbbc437 Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Tue, 1 Sep 2020 18:15:08 +0200
Subject: [PATCH 112/116] All actions models are ready

---
 config.ini                        |  7 ++----
 punctuate.py                      |  2 +-
 requirements.txt                  |  2 +-
 src/models/actions_model_base.py  |  1 -
 src/models/actions_model_mixed.py |  7 +++---
 src/models/model_factory.py       | 10 ++++++++
 src/utils.py                      | 22 +++++++++++++++++
 tests/test_utils.py               | 10 ++++----
 worker.py                         | 39 ++++++++++++++++++++-----------
 9 files changed, 69 insertions(+), 31 deletions(-)
 create mode 100644 src/models/model_factory.py

diff --git a/config.ini b/config.ini
index 9e02bf5..07e1cde 100644
--- a/config.ini
+++ b/config.ini
@@ -1,6 +1,5 @@
 [service]
 tool = punctuator_test
-
 root = /samba/requests/
 rabbit_host = test
 rabbit_user = test
@@ -15,7 +14,5 @@ local_log_level = INFO
 
 [deployment]
 device = cpu
-chunk_size = 500
-threshold = 0.9
-model = deploy/model
-base_model = dkleczek/bert-base-polish-cased-v1
\ No newline at end of file
+models_dir = deploy
+models_enabled = actions_base,actions_mixed,actions_restricted
\ No newline at end of file
diff --git a/punctuate.py b/punctuate.py
index 29e79bf..e8eac14 100755
--- a/punctuate.py
+++ b/punctuate.py
@@ -95,7 +95,7 @@ if __name__ == "__main__":
     data_tokenized = tokenizer(data, return_tensors="pt")
 
     predictions = (
-        model.predict_raw(data_tokenized["input_ids"], data_tokenized["attention_mask"])
+        model.predict_raw(data_tokenized["input_ids"].to(device), data_tokenized["attention_mask"].to(device))
         .detach()
         .cpu()
         .numpy()
diff --git a/requirements.txt b/requirements.txt
index e4d8dd7..4a63c40 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -59,4 +59,4 @@ urllib3==1.25.10
 zict==2.0.0
 scikit-learn==0.23.2
 nlp_ws==0.6
-colored==1.4.2
\ No newline at end of file
+colored==1.4.2
diff --git a/src/models/actions_model_base.py b/src/models/actions_model_base.py
index b06eb6a..d503f08 100644
--- a/src/models/actions_model_base.py
+++ b/src/models/actions_model_base.py
@@ -6,7 +6,6 @@ from dataclasses import dataclass
 import numpy as np
 import torch
 import torch.nn as nn
-from torch import threshold
 from torch.nn.modules.loss import BCEWithLogitsLoss
 from transformers.configuration_utils import PretrainedConfig
 from transformers.modeling_bert import BertForTokenClassification
diff --git a/src/models/actions_model_mixed.py b/src/models/actions_model_mixed.py
index ac0162a..6982efa 100644
--- a/src/models/actions_model_mixed.py
+++ b/src/models/actions_model_mixed.py
@@ -13,7 +13,6 @@ from src.models.interfaces import PunctuationModel
 from src.pipelines.actions_based.processing import (
     ACTIONS_KEYS,
     action_vector,
-    last_stop_label,
     recover_text,
     token_labels_to_word_labels,
 )
@@ -200,8 +199,8 @@ class ActionsModelMixed(PunctuationModel):
 
             inputs.append(
                 (
-                    prediction_raw.detach().cpu().numpy()[0, -1, :]
-                    > self.runtime.threshold
+                    prediction_raw.detach().cpu()
+                    .numpy()[0, -1, :] > self.runtime.threshold
                 ).astype(np.float)
             )
 
@@ -233,7 +232,7 @@ class ActionsModelMixed(PunctuationModel):
                 input_ids, outputs, (attention_mask == 0)
             ).sigmoid()
 
-            prediction_raw = (prediction_raw[:, -1:, :] > self.params.threshold).type(
+            prediction_raw = (prediction_raw[:, -1:, :] > self.runtime.threshold).type(
                 torch.float
             )
             outputs = torch.cat([outputs, prediction_raw], dim=1)
diff --git a/src/models/model_factory.py b/src/models/model_factory.py
new file mode 100644
index 0000000..cfa51d2
--- /dev/null
+++ b/src/models/model_factory.py
@@ -0,0 +1,10 @@
+from src.models.actions_model_mixed import ActionsModelMixed
+from src.models.actions_model_restricted import ActionsModelRestricted
+from src.models.actions_model_base import ActionsModelBase
+
+
+MODELS_MAP = {
+    "actions_base": ActionsModelBase,
+    "actions_restricted": ActionsModelRestricted,
+    "actions_mixed": ActionsModelMixed
+}
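With the map in place, model selection stays data-driven: a config string resolves to a class, and every class exposes the same `load`/`predict` surface. A hedged usage sketch (the path and device are illustrative; `deploy/actions_base` matches the layout created by `entrypoint.sh`):

```python
import torch

from src.models.model_factory import MODELS_MAP

device = torch.device("cpu")
model = MODELS_MAP["actions_base"].load("deploy/actions_base", "production", device)
model.train(False)  # inference mode, as the worker sets after loading

print(model.predict("testowy ciag znakow"))
```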
diff --git a/src/utils.py b/src/utils.py
index 07e9090..90a69f5 100644
--- a/src/utils.py
+++ b/src/utils.py
@@ -283,6 +283,26 @@ def remove_punctuation(text: str, whitelist: List[str] = []) -> str:
     return "".join(filter(lambda x: x.isalnum() or x.isspace() or x in whitelist, text))
 
 
+def unify_whitespaces(text: str) -> str:
+    """Maps all whitespace characters into a simple ' '
+
+    Args:
+        text (str): Text containing multiple forms of whitespace
+
+    Returns:
+        str: Text with a single form of whitespace
+    """
+    result = ""
+
+    for c in text:
+        if c.isspace():
+            result += " "
+        else:
+            result += c
+
+    return result
+
+
 def output_preprocess(text: str) -> str:
     """Cleans the text out of bad formating and removes or replaces symbols that will not be predicted by a model
 
@@ -299,6 +319,7 @@ def output_preprocess(text: str) -> str:
     text = text.replace(";", ".").replace("!", ".")
 
     text = remove_punctuation(text, [".", ",", "?"])
+    text = unify_whitespaces(text)
     text = remove_multiple_spaces(text)
     text = text.strip()
 
@@ -316,6 +337,7 @@ def input_preprocess(text: str) -> str:
         str: Text in training-data format
     """
     text = remove_punctuation(text)
+    text = unify_whitespaces(text)
     text = remove_multiple_spaces(text)
     text = text.lower()
     text = text.strip()
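Taken together with `remove_punctuation` and `remove_multiple_spaces`, the new `unify_whitespaces` step means tabs, non-breaking spaces, and newlines all collapse into single plain spaces before inference. A small illustration, assuming the helpers behave as defined in `src/utils.py`:

```python
from src.utils import input_preprocess, unify_whitespaces

raw = "Ala\tma\u00a0kota,\n\nczy na pewno?"
print(unify_whitespaces(raw))  # "Ala ma kota,  czy na pewno?"
print(input_preprocess(raw))   # "ala ma kota czy na pewno"
```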
diff --git a/tests/test_utils.py b/tests/test_utils.py
index c4d35b1..7e52db2 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -83,8 +83,8 @@ def test_latest_model():
     assert latest_model(files) == (1, 1000)
 
 
-def test_yaml_serializable(fs):
-    fs.create_dir("/var")
+def test_yaml_serializable(tmp_path):
+    PATH = tmp_path / "test.yaml"
 
     @yaml_serializable
     @dataclass
@@ -95,11 +95,11 @@ def test_yaml_serializable(fs):
     x = Test()
     x.x = -1
     x.y = "test2"
-    x.save_yaml("/var/test.yaml")
+    x.save_yaml(PATH)
 
-    assert os.path.exists("/var/test.yaml")
+    assert os.path.exists(PATH)
 
-    y = Test.load_yaml("/var/test.yaml")
+    y = Test.load_yaml(PATH)
 
     assert y.x == -1
     assert y.y == "test2"
diff --git a/worker.py b/worker.py
index cc5790a..f3f9967 100755
--- a/worker.py
+++ b/worker.py
@@ -1,11 +1,12 @@
 #!/usr/bin/python
 
 import configparser
+from typing import List
+
+from src.models.model_factory import MODELS_MAP
 
 import nlp_ws
 
-from src.pipelines.actions_based.utils import apply_actions_punctuation, load_model
-from src.utils import input_preprocess
+from src.utils import input_preprocess, output_preprocess
 
 
 class Worker(nlp_ws.NLPWorker):
@@ -15,25 +16,35 @@ class Worker(nlp_ws.NLPWorker):
         self.config = configparser.ConfigParser()
         self.config.read("config.ini")
 
-        self.threshold = float(self.config["deployment"]["threshold"])
-        self.chunk_size = float(self.config["deployment"]["chunk_size"])
-        self.tokenizer, self.model = load_model(
-            self.config["deployment"]["model"],
-            self.config["deployment"]["base_model"],
-            self.config["deployment"]["device"],
-        )
+        self.device = self.config["deployment"]["device"]
+        self.models_dir = self.config["deployment"]["models_dir"]
+        self.models = {}
+
+        models_enabled = self.config["deployment"]["models_enabled"]
+        models_enabled = models_enabled.split(",")
+
+        self._load_models(models_enabled)
+
+    def _load_models(self, models_list: List[str]):
+        for model_type in models_list:
+            self.models[model_type] = MODELS_MAP[model_type].load(f"{self.models_dir}/{model_type}", "production", self.device)
+            self.models[model_type].train(False)
 
     def process(self, input_file: str, task_options: dict, output_file: str) -> None:
         """Implementation of example tasks that copies files."""
 
+        if "model" in task_options.keys() and task_options['model'] in MODELS_MAP.keys():
+            model_type = task_options['model']
+        else:
+            model_type = "actions_base"
+
         with open(input_file, "r") as f:
-            text = input_preprocess(f.read())
-            text_processed = apply_actions_punctuation(
-                text, self.chunk_size, self.tokenizer, self.model, self.threshold
-            )
+            text = input_preprocess(output_preprocess(f.read()))
+
+        result = self.models[model_type].predict(text)
 
         with open(output_file, "w") as f:
-            f.write(text_processed)
+            f.write(result)
 
 
 if __name__ == "__main__":
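With the refactored `process`, clients can select an architecture per request through task options; anything missing or unknown falls back to `actions_base`. For example (hypothetical payload):

```python
task_options = {"model": "actions_mixed"}  # any key from MODELS_MAP; otherwise "actions_base" is used
```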
-- 
GitLab


From a318b44df87a87a0e1ca710fb38014fadba26a38 Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Tue, 1 Sep 2020 19:00:30 +0200
Subject: [PATCH 113/116] Updated models

---
 entrypoint.sh | 21 +++++++++++++++++----
 1 file changed, 17 insertions(+), 4 deletions(-)

diff --git a/entrypoint.sh b/entrypoint.sh
index a6e06ed..6de80f1 100755
--- a/entrypoint.sh
+++ b/entrypoint.sh
@@ -1,8 +1,21 @@
 #!/bin/bash
 
-if ! test -f "./deploy/model"; then
-    mkdir -p ./deploy
-    wget https://minio.clarin-pl.eu/public/models/punctuation/0-190000.model -O deploy/model
+if ! test -d "./deploy/actions_base"; then
+    mkdir -p ./deploy/actions_base
+    wget https://minio.clarin-pl.eu/public/models/punctuation/actions_base/production.model -O deploy/actions_base/production.model
+    wget https://minio.clarin-pl.eu/public/models/punctuation/actions_base/production.config -O deploy/actions_base/production.config
 fi
 
-python worker.py
\ No newline at end of file
+if ! test -d "./deploy/actions_mixed"; then
+    mkdir -p ./deploy/actions_mixed
+    wget https://minio.clarin-pl.eu/public/models/punctuation/actions_mixed/production.model -O deploy/actions_mixed/production.model
+    wget https://minio.clarin-pl.eu/public/models/punctuation/actions_mixed/production.config -O deploy/actions_mixed/production.config
+fi
+
+if ! test -d "./deploy/actions_restricted"; then
+    mkdir -p ./deploy/actions_restricted
+    wget https://minio.clarin-pl.eu/public/models/punctuation/actions_restricted/production.model -O deploy/actions_restricted/production.model
+    wget https://minio.clarin-pl.eu/public/models/punctuation/actions_restricted/production.config -O deploy/actions_restricted/production.config
+fi
+
+#python worker.py
\ No newline at end of file
-- 
GitLab


From c2355eb3b1c808c34d1203476fc87bda0b4c3f2e Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Tue, 1 Sep 2020 19:52:37 +0200
Subject: [PATCH 114/116] Removed development docker file

---
 docker/development/Dockerfile | 57 -----------------------------------
 1 file changed, 57 deletions(-)
 delete mode 100644 docker/development/Dockerfile

diff --git a/docker/development/Dockerfile b/docker/development/Dockerfile
deleted file mode 100644
index d6356a7..0000000
--- a/docker/development/Dockerfile
+++ /dev/null
@@ -1,57 +0,0 @@
-from ubuntu:20.04
-
-RUN apt update && apt install -y python3 python3-pip
-RUN apt update && apt install -y git
-RUN pip3 install ipywidgets
-
-#### CUDA Installation
-RUN apt-get update && apt-get install -y --no-install-recommends \
-    gnupg2 curl ca-certificates && \
-    curl -fsSL https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/7fa2af80.pub | apt-key add - && \
-    echo "deb https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64 /" > /etc/apt/sources.list.d/cuda.list && \
-    echo "deb https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64 /" > /etc/apt/sources.list.d/nvidia-ml.list && \
-    rm -rf /var/lib/apt/lists/*
-
-ENV CUDA_VERSION 10.2.89
-
-ENV CUDA_PKG_VERSION 10-2=$CUDA_VERSION-1
-
-# For libraries in the cuda-compat-* package: https://docs.nvidia.com/cuda/eula/index.html#attachment-a
-RUN apt-get update && apt-get install -y --no-install-recommends \
-    cuda-cudart-$CUDA_PKG_VERSION \
-    cuda-compat-10-2 && \
-    ln -s cuda-10.2 /usr/local/cuda && \
-    rm -rf /var/lib/apt/lists/*
-
-# Required for nvidia-docker v1
-RUN echo "/usr/local/nvidia/lib" >> /etc/ld.so.conf.d/nvidia.conf && \
-    echo "/usr/local/nvidia/lib64" >> /etc/ld.so.conf.d/nvidia.conf
-
-ENV PATH /usr/local/nvidia/bin:/usr/local/cuda/bin:${PATH}
-ENV LD_LIBRARY_PATH /usr/local/nvidia/lib:/usr/local/nvidia/lib64
-
-# nvidia-container-runtime
-ENV NVIDIA_VISIBLE_DEVICES all
-ENV NVIDIA_DRIVER_CAPABILITIES compute,utility
-ENV NVIDIA_REQUIRE_CUDA "cuda>=10.2 brand=tesla,driver>=384,driver<385 brand=tesla,driver>=396,driver<397 brand=tesla,driver>=410,driver<411 brand=tesla,driver>=418,driver<419"
-
-### END CUDA Installation
-
-RUN pip3 install numpy pandas tqdm seaborn torch dask[complete] transformers pyarrow==0.17.1 pytest lxml
-RUN ln -s /usr/bin/pip3 /usr/bin/pip
-
-ARG USERNAME=clarin
-ARG USER_UID=1000
-ARG USER_GID=1000
-
-# Create the user
-RUN groupadd --gid $USER_GID $USERNAME \
-    && useradd --uid $USER_UID --gid $USER_GID -m $USERNAME \
-    && apt-get update \
-    && apt-get install -y sudo \
-    && echo $USERNAME ALL=\(root\) NOPASSWD:ALL > /etc/sudoers.d/$USERNAME \
-    && chmod 0440 /etc/sudoers.d/$USERNAME
-
-ENV PATH="/home/${USERNAME}/.local/bin:${PATH}"
-
-USER ${USERNAME}
\ No newline at end of file
-- 
GitLab


From 617f22430bacf08ae038a2278751708eb64cf906 Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Wed, 2 Sep 2020 12:19:58 +0200
Subject: [PATCH 115/116] Restricted model fixes

---
 README.md                              | 50 +++++++++++++++++++++++---
 entrypoint.sh                          |  5 ++-
 src/models/actions_model_restricted.py | 12 ++++---
 3 files changed, 57 insertions(+), 10 deletions(-)

diff --git a/README.md b/README.md
index 0b99108..e8f9b06 100644
--- a/README.md
+++ b/README.md
@@ -1,10 +1,50 @@
 # Punctuator
-A service that automatically adds punctuation to raw word-stream (eg. from speech2text).  
+A service that automatically adds punctuation to a raw word stream (e.g. from speech2text) for the Polish language.
 
-## Approaches
-1. Token classification (actions): Each token is classified with 4 labels: Uppercase, dot, colon, question mark. The model is based on the stacked encoder part of transformer architecture (Bert), followed by FC-layer that transforms the output into per-token multilabel binary classifications. For now, there is no restriction for taking dot, question_mark and colon labels simultaneously, so that's the are of improvement  (hierarchical, multilabel classification)
+**Example input**:
+> według webometrycznego rankingu uniwersytetów świata ze stycznia 2019 pokazującego zaangażowanie instytucji akademickich w internecie uczelnia zajmuje 5 miejsce w polsce wśród uczelni technicznych a na świecie 964 wśród wszystkich typów uczelni w rankingu szkół wyższych perspektyw politechnika wrocławska zajęła w 2019 roku 3 miejsce wśród uczelni technicznych oraz 6 miejsce spośród wszystkich uczelni akademickich w polsce
 
-2. Sequence-to-Sequence (translations): Full encoder-decoder stack that takes input (unpunctuated text) and the output produced so far to predict the next token. In theory, this model should be able to represent many more cases (eg. all upper, some upper, dashes, ellipsis etc...) without explicit defines. However, the lack of constraints makes it much harder to train. 
+**Output**:
+> Według webometrycznego rankingu uniwersytetów świata ze stycznia 2019, pokazującego zaangażowanie instytucji akademickich w Internecie, uczelnia zajmuje 5. miejsce w Polsce wśród uczelni technicznych, a na świecie 964. Wśród wszystkich typów uczelni w rankingu szkół wyższych perspektyw Politechnika Wrocławska zajęła w 2019 roku 3. miejsce wśród uczelni technicznych oraz 6. miejsce spośród wszystkich uczelni akademickich w Polsce
+
+## Models
+### Action-Based
+1. actions_base: A simple model, architecturally based on BERT. It is trained to predict an "action" for each token in the sentence, where an action is either uppercasing the token or adding a punctuation sign after it.
+
+2. actions_restricted: A model nearly identical to actions_base; however, it predicts punctuation as a categorical distribution (so that punctuation marks are mutually exclusive at training time). The idea is to better differentiate between punctuation marks.
+
+3. actions_mixed: A model based on the full transformer (encoder + decoder) architecture. It's much less performant, as it predicts actions for only one word at a time. However, it can model action probabilities conditioned on both the input and the output predicted so far. Because of that, it is much less prone to leaving the first word of a new sentence lowercase or to placing multiple punctuation signs in close proximity.
+
+### Translation
+1. translation (Deprecated): Full encoder-decoder stack that takes the input (unpunctuated text) and the output produced so far to predict the next token. The main difference from the actions models is that it's a full text2text model without restrictions on tokens. Because of that, in theory, it can represent more cases (e.g. all upper, some upper, dashes, ellipsis, etc.), as opposed to only a few explicitly defined actions. However, the lack of constraints makes it much harder to train (both in training time and data requirements).
+
+## Usage
+To test a model locally you can use the `punctuate.py` script.
+```bash
+punctuate.py [-h] -a {base,restricted,mixed} -d DIRECTORY -i INPUT [-m MODEL] [-l {upper_case,dot,colon,question_mark,none}] [-dv DEVICE]
+
+Evaluate actions model
+
+optional arguments:
+  -h, --help            show this help message and exit
+  -a {base,restricted,mixed}, --architecture {base,restricted,mixed}
+                        Model architecture
+  -d DIRECTORY, --directory DIRECTORY
+                        Directory where trained model is located, relative to project root
+  -i INPUT, --input INPUT
+                        Input text file
+  -m MODEL, --model MODEL
+                        Pretrained model name
+  -l {upper_case,dot,colon,question_mark,none}, --highlight {upper_case,dot,colon,question_mark,none}
+                        Highlight prediction confidence of selected action per-word
+  -dv DEVICE, --device DEVICE
+                        Device on which inference will be made
+```
+E.g. if you place a model named "production" in the directory passed via `-d` (here `/deploy/actions_mixed`) and an example unpunctuated text file at `punctuator/test_data/text.txt`, you can call
+
+```bash
+python3 punctuate.py -a mixed -d /deploy/actions_mixed -i test_data/text.txt -m production -dv cuda:0
+```
 
 ## Mountpoints
-Directory where model will be downloaded (~500Mb) needs to be mounted at /punctuator/deploy
+Directory where the models will be downloaded (~500 MB) needs to be mounted at /punctuator/deploy
diff --git a/entrypoint.sh b/entrypoint.sh
index 1bf1aeb..5608c38 100755
--- a/entrypoint.sh
+++ b/entrypoint.sh
@@ -4,18 +4,21 @@ if ! test -d "./deploy/actions_base"; then
     mkdir -p ./deploy/actions_base
     wget https://minio.clarin-pl.eu/public/models/punctuation/actions_base/production.model -O deploy/actions_base/production.model
     wget https://minio.clarin-pl.eu/public/models/punctuation/actions_base/production.config -O deploy/actions_base/production.config
+    wget https://minio.clarin-pl.eu/public/models/punctuation/actions_base/production.runtime.yaml -O deploy/actions_base/production.runtime.yaml
 fi
 
 if ! test -d "./deploy/actions_mixed"; then
     mkdir -p ./deploy/actions_mixed
     wget https://minio.clarin-pl.eu/public/models/punctuation/actions_mixed/production.model -O deploy/actions_mixed/production.model
     wget https://minio.clarin-pl.eu/public/models/punctuation/actions_mixed/production.config -O deploy/actions_mixed/production.config
+    wget https://minio.clarin-pl.eu/public/models/punctuation/actions_mixed/production.runtime.yaml -O deploy/actions_mixed/production.runtime.yaml
 fi
 
 if ! test -d "./deploy/actions_restricted"; then
     mkdir -p ./deploy/actions_restricted
     wget https://minio.clarin-pl.eu/public/models/punctuation/actions_restricted/production.model -O deploy/actions_restricted/production.model
     wget https://minio.clarin-pl.eu/public/models/punctuation/actions_restricted/production.config -O deploy/actions_restricted/production.config
+    wget https://minio.clarin-pl.eu/public/models/punctuation/actions_restricted/production.runtime.yaml -O deploy/actions_restricted/production.runtime.yaml
 fi
 
-#python worker.py
+python worker.py
diff --git a/src/models/actions_model_restricted.py b/src/models/actions_model_restricted.py
index 0ec7735..6f628b8 100644
--- a/src/models/actions_model_restricted.py
+++ b/src/models/actions_model_restricted.py
@@ -110,14 +110,18 @@ class ActionsModelRestricted(ActionsModel):
 
         logits = self.forward(input_ids, attention_mask=attention_mask)
         prob_uppercase = logits[:, :, :1].sigmoid()
-        prob_punctuation = logits[:, :, 1:].softmax(dim=-1)[:, :, :-1]
+        prob_punctuation = logits[:, :, 1:].softmax(dim=-1)
+
+        no_punctuation = prob_punctuation.argmax(-1) == (self.params.extended_action_vector_size-2)
+        no_punctuation = no_punctuation.type(torch.float).unsqueeze(-1).repeat(1, 1, prob_punctuation.shape[-1]-1)
+
+        prob_punctuation = prob_punctuation[:, :, :-1].softmax(-1) * (1 - no_punctuation)
 
         return torch.cat([prob_uppercase, prob_punctuation], dim=-1)
 
     def predict(self, text: str) -> str:
-        # TODO: Generalize
-        chunk_size = 500
-        threshold = 0.9
+        chunk_size = self.runtime.chunksize
+        threshold = self.runtime.threshold
 
         text = text.strip()
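The reworked `predict_raw` above treats the extra softmax slot as an explicit "no punctuation" class: wherever that class wins the argmax, the per-punctuation probabilities are zeroed so `max_suppression` cannot select a punctuation mark there. A simplified sketch of the masking idea (dropping the second softmax renormalization of the patch for brevity):

```python
import torch


def combine_probs(logits: torch.Tensor) -> torch.Tensor:
    # logits: (batch, seq, 1 + num_punct + 1); the last slot is "no punctuation"
    prob_upper = logits[:, :, :1].sigmoid()
    punct = logits[:, :, 1:].softmax(dim=-1)
    no_punct = (punct.argmax(-1) == punct.shape[-1] - 1).float().unsqueeze(-1)
    prob_punct = punct[:, :, :-1] * (1.0 - no_punct)  # zero where "none" wins
    return torch.cat([prob_upper, prob_punct], dim=-1)


# e.g. uppercase slot + 4 punctuation classes + no-punctuation class = 6 logits
probs = combine_probs(torch.randn(1, 7, 6))
assert probs.shape == (1, 7, 5)
```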
 
-- 
GitLab


From 81f40372dd5d2f354db320644fc99e51bb167bc5 Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Wed, 2 Sep 2020 12:30:37 +0200
Subject: [PATCH 116/116] Style fixes

---
 punctuate.py                           |  5 ++++-
 src/models/actions_model_mixed.py      |  4 ++--
 src/models/actions_model_restricted.py | 14 +++++++++++---
 src/models/model_factory.py            |  2 +-
 tox.ini                                |  2 +-
 worker.py                              | 11 ++++++++---
 6 files changed, 27 insertions(+), 11 deletions(-)

diff --git a/punctuate.py b/punctuate.py
index e8eac14..8eb4bdc 100755
--- a/punctuate.py
+++ b/punctuate.py
@@ -95,7 +95,10 @@ if __name__ == "__main__":
     data_tokenized = tokenizer(data, return_tensors="pt")
 
     predictions = (
-        model.predict_raw(data_tokenized["input_ids"].to(device), data_tokenized["attention_mask"].to(device))
+        model.predict_raw(
+            data_tokenized["input_ids"].to(device),
+            data_tokenized["attention_mask"].to(device),
+        )
         .detach()
         .cpu()
         .numpy()
diff --git a/src/models/actions_model_mixed.py b/src/models/actions_model_mixed.py
index 6982efa..e8f9a50 100644
--- a/src/models/actions_model_mixed.py
+++ b/src/models/actions_model_mixed.py
@@ -199,8 +199,8 @@ class ActionsModelMixed(PunctuationModel):
 
             inputs.append(
                 (
-                    prediction_raw.detach().cpu()
-                    .numpy()[0, -1, :] > self.runtime.threshold
+                    prediction_raw.detach().cpu().numpy()[0, -1, :]
+                    > self.runtime.threshold
                 ).astype(np.float)
             )
 
diff --git a/src/models/actions_model_restricted.py b/src/models/actions_model_restricted.py
index 6f628b8..9239e66 100644
--- a/src/models/actions_model_restricted.py
+++ b/src/models/actions_model_restricted.py
@@ -112,10 +112,18 @@ class ActionsModelRestricted(ActionsModel):
         prob_uppercase = logits[:, :, :1].sigmoid()
         prob_punctuation = logits[:, :, 1:].softmax(dim=-1)
 
-        no_punctuation = prob_punctuation.argmax(-1) == (self.params.extended_action_vector_size-2)
-        no_punctuation = no_punctuation.type(torch.float).unsqueeze(-1).repeat(1, 1, prob_punctuation.shape[-1]-1)
+        no_punctuation = prob_punctuation.argmax(-1) == (
+            self.params.extended_action_vector_size - 2
+        )
+        no_punctuation = (
+            no_punctuation.type(torch.float)
+            .unsqueeze(-1)
+            .repeat(1, 1, prob_punctuation.shape[-1] - 1)
+        )
 
-        prob_punctuation = prob_punctuation[:, :, :-1].softmax(-1) * (1 - no_punctuation)
+        prob_punctuation = prob_punctuation[:, :, :-1].softmax(-1) * (
+            1 - no_punctuation
+        )
 
         return torch.cat([prob_uppercase, prob_punctuation], dim=-1)
 
diff --git a/src/models/model_factory.py b/src/models/model_factory.py
index cfa51d2..3d4abcc 100644
--- a/src/models/model_factory.py
+++ b/src/models/model_factory.py
@@ -6,5 +6,5 @@ from src.models.actions_model_base import ActionsModelBase
 MODELS_MAP = {
     "actions_base": ActionsModelBase,
     "actions_restricted": ActionsModelRestricted,
-    "actions_mixed": ActionsModelMixed
+    "actions_mixed": ActionsModelMixed,
 }
diff --git a/tox.ini b/tox.ini
index 7f8448b..02ec5a0 100644
--- a/tox.ini
+++ b/tox.ini
@@ -27,7 +27,7 @@ max-complexity = 10
 min_python_version = 3.8
 max-line-length = 80
 select = I,C,E,F,W,B,B950,TYP,T
-ignore = E501, C901, I201
+ignore = E501, C901, I201, W503
 
 
 [testenv:pep8]
diff --git a/worker.py b/worker.py
index 91a5d31..7c8011b 100755
--- a/worker.py
+++ b/worker.py
@@ -27,7 +27,9 @@ class Worker(nlp_ws.NLPWorker):
 
     def _load_models(self, models_list: List[str]):
         for model_type in models_list:
-            self.models[model_type] = MODELS_MAP[model_type].load(f"{self.models_dir}/{model_type}", "production", self.device)
+            self.models[model_type] = MODELS_MAP[model_type].load(
+                f"{self.models_dir}/{model_type}", "production", self.device
+            )
             self.models[model_type].train(False)
 
         self.model.train(False)
@@ -35,8 +37,11 @@ class Worker(nlp_ws.NLPWorker):
     def process(self, input_file: str, task_options: dict, output_file: str) -> None:
         """Implementation of example tasks that copies files."""
 
-        if "model" in task_options.keys() and task_options['model'] in MODELS_MAP.keys():
-            model_type = task_options['model']
+        if (
+            "model" in task_options.keys()
+            and task_options["model"] in MODELS_MAP.keys()
+        ):
+            model_type = task_options["model"]
         else:
             model_type = "actions_base"
 
-- 
GitLab