From 28f79649b1df5022a57f9e4cb4137bddb1c87ca2 Mon Sep 17 00:00:00 2001
From: Maja Jablonska <majajjablonska@gmail.com>
Date: Mon, 13 Nov 2023 19:26:44 +1100
Subject: [PATCH] Remove unnecessary dilated_cnn copy

---
 combo/models/dilated_cnn.py       | 14 +++--
 combo/polish_model_training.ipynb | 92 ++++++++++++++++---------------
 2 files changed, 55 insertions(+), 51 deletions(-)

diff --git a/combo/models/dilated_cnn.py b/combo/models/dilated_cnn.py
index 71b7df2..e694a10 100644
--- a/combo/models/dilated_cnn.py
+++ b/combo/models/dilated_cnn.py
@@ -6,13 +6,15 @@ Author: Mateusz Klimaszewski
 from typing import List
 
 import torch
-import torch.nn as nn
 
-from combo.nn import Activation
+from combo.config import FromParameters, Registry
+from combo.config.from_parameters import register_arguments
+from combo.nn.activations import Activation
 
 
-class DilatedCnnEncoder(nn.Module):
-
+@Registry.register('dilated_cnn')
+class DilatedCnnEncoder(torch.nn.Module, FromParameters):
+    @register_arguments
     def __init__(self,
                  input_dim: int,
                  filters: List[int],
@@ -26,14 +28,14 @@ class DilatedCnnEncoder(nn.Module):
         input_dims = [input_dim] + filters[:-1]
         output_dims = filters
         for idx in range(len(activations)):
-            conv1d_layers.append(nn.Conv1d(
+            conv1d_layers.append(torch.nn.Conv1d(
                 in_channels=input_dims[idx],
                 out_channels=output_dims[idx],
                 kernel_size=(kernel_size[idx],),
                 stride=(stride[idx],),
                 padding=padding[idx],
                 dilation=(dilation[idx],)))
-        self.conv1d_layers = nn.ModuleList(conv1d_layers)
+        self.conv1d_layers = torch.nn.ModuleList(conv1d_layers)
         self.activations = activations
         assert len(self.activations) == len(self.conv1d_layers)
 
diff --git a/combo/polish_model_training.ipynb b/combo/polish_model_training.ipynb
index 1d75abd..bd2796f 100644
--- a/combo/polish_model_training.ipynb
+++ b/combo/polish_model_training.ipynb
@@ -14,8 +14,8 @@
    "metadata": {
     "collapsed": false,
     "ExecuteTime": {
-     "end_time": "2023-11-13T07:47:15.954139Z",
-     "start_time": "2023-11-13T07:47:15.711912Z"
+     "end_time": "2023-11-13T08:18:36.376046Z",
+     "start_time": "2023-11-13T08:18:36.189836Z"
     }
    },
    "id": "b28c7d8bacb08d02"
@@ -51,8 +51,8 @@
    "metadata": {
     "collapsed": false,
     "ExecuteTime": {
-     "end_time": "2023-11-13T07:47:22.233317Z",
-     "start_time": "2023-11-13T07:47:15.766709Z"
+     "end_time": "2023-11-13T08:18:41.300316Z",
+     "start_time": "2023-11-13T08:18:36.197537Z"
     }
    },
    "id": "initial_id"
@@ -77,7 +77,7 @@
       "application/vnd.jupyter.widget-view+json": {
        "version_major": 2,
        "version_minor": 0,
-       "model_id": "d318d4f50da14b76a14eb20cb877ee67"
+       "model_id": "e45b57a7047043a48ccfeacfb49312b5"
       }
      },
      "metadata": {},
@@ -89,7 +89,7 @@
       "application/vnd.jupyter.widget-view+json": {
        "version_major": 2,
        "version_minor": 0,
-       "model_id": "dbab946d82ab4d0ead64fc02796c2a9f"
+       "model_id": "a886ae4451474459b088659ebac076ae"
       }
      },
      "metadata": {},
@@ -101,7 +101,7 @@
       "application/vnd.jupyter.widget-view+json": {
        "version_major": 2,
        "version_minor": 0,
-       "model_id": "2f3d9306cb2b463eb080c922fe775b02"
+       "model_id": "20ff98564c7c43a9971c25f82ceda997"
       }
      },
      "metadata": {},
@@ -169,8 +169,8 @@
    "metadata": {
     "collapsed": false,
     "ExecuteTime": {
-     "end_time": "2023-11-13T07:47:42.601537Z",
-     "start_time": "2023-11-13T07:47:22.243325Z"
+     "end_time": "2023-11-13T08:19:01.785477Z",
+     "start_time": "2023-11-13T08:18:41.119674Z"
     }
    },
    "id": "d74957f422f0b05b"
@@ -192,8 +192,8 @@
    "metadata": {
     "collapsed": false,
     "ExecuteTime": {
-     "end_time": "2023-11-13T07:47:44.068445Z",
-     "start_time": "2023-11-13T07:47:42.595098Z"
+     "end_time": "2023-11-13T08:19:03.816723Z",
+     "start_time": "2023-11-13T08:19:01.774666Z"
     }
    },
    "id": "fa724d362fd6bd23"
@@ -211,7 +211,7 @@
     },
     {
      "data": {
-      "text/plain": "<generator object SimpleDataLoader.iter_instances at 0x7fdd1e0a0c80>"
+      "text/plain": "<generator object SimpleDataLoader.iter_instances at 0x7fb2d3cdfc80>"
      },
      "execution_count": 5,
      "metadata": {},
@@ -239,8 +239,8 @@
    "metadata": {
     "collapsed": false,
     "ExecuteTime": {
-     "end_time": "2023-11-13T07:47:44.196484Z",
-     "start_time": "2023-11-13T07:47:44.034821Z"
+     "end_time": "2023-11-13T08:19:03.877868Z",
+     "start_time": "2023-11-13T08:19:03.826289Z"
     }
    },
    "id": "f8a10f9892005fca"
@@ -263,8 +263,8 @@
    "metadata": {
     "collapsed": false,
     "ExecuteTime": {
-     "end_time": "2023-11-13T07:47:44.197075Z",
-     "start_time": "2023-11-13T07:47:44.055240Z"
+     "end_time": "2023-11-13T08:19:03.887795Z",
+     "start_time": "2023-11-13T08:19:03.870640Z"
     }
    },
    "id": "14413692656b68ac"
@@ -277,7 +277,7 @@
      "name": "stderr",
      "output_type": "stream",
      "text": [
-      "Some weights of the model checkpoint at allegro/herbert-base-cased were not used when initializing BertModel: ['cls.predictions.decoder.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.sso.sso_relationship.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.sso.sso_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight']\n",
+      "Some weights of the model checkpoint at allegro/herbert-base-cased were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.sso.sso_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.sso.sso_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.decoder.bias', 'cls.predictions.decoder.weight']\n",
       "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
       "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
      ]
@@ -411,8 +411,8 @@
    "metadata": {
     "collapsed": false,
     "ExecuteTime": {
-     "end_time": "2023-11-13T07:47:48.599708Z",
-     "start_time": "2023-11-13T07:47:44.063606Z"
+     "end_time": "2023-11-13T08:19:10.006549Z",
+     "start_time": "2023-11-13T08:19:03.885912Z"
     }
    },
    "id": "437d12054baaffa1"
@@ -430,8 +430,8 @@
    "metadata": {
     "collapsed": false,
     "ExecuteTime": {
-     "end_time": "2023-11-13T07:48:26.090634Z",
-     "start_time": "2023-11-13T07:47:48.622684Z"
+     "end_time": "2023-11-13T08:19:47.953809Z",
+     "start_time": "2023-11-13T08:19:09.989582Z"
     }
    },
    "id": "e131e0ec75dc6927"
@@ -446,8 +446,8 @@
    "metadata": {
     "collapsed": false,
     "ExecuteTime": {
-     "end_time": "2023-11-13T07:48:32.052740Z",
-     "start_time": "2023-11-13T07:48:26.077694Z"
+     "end_time": "2023-11-13T08:19:53.878201Z",
+     "start_time": "2023-11-13T08:19:47.940147Z"
     }
    },
    "id": "195c71fcf8170ff"
@@ -481,8 +481,8 @@
    "metadata": {
     "collapsed": false,
     "ExecuteTime": {
-     "end_time": "2023-11-13T07:48:32.321842Z",
-     "start_time": "2023-11-13T07:48:32.056903Z"
+     "end_time": "2023-11-13T08:19:54.150333Z",
+     "start_time": "2023-11-13T08:19:53.874397Z"
     }
    },
    "id": "cefc5173154d1605"
@@ -512,7 +512,7 @@
       "application/vnd.jupyter.widget-view+json": {
        "version_major": 2,
        "version_minor": 0,
-       "model_id": "027b704c9899478bb71021e074ad29bf"
+       "model_id": "9f6e199b0fd546f5833fdda238964165"
       }
      },
      "metadata": {},
@@ -534,7 +534,7 @@
       "application/vnd.jupyter.widget-view+json": {
        "version_major": 2,
        "version_minor": 0,
-       "model_id": "4af02d76668645ae9213db79ae97d36f"
+       "model_id": "594c065e5d2441f48ba2b87c7a3f528f"
       }
      },
      "metadata": {},
@@ -546,7 +546,7 @@
       "application/vnd.jupyter.widget-view+json": {
        "version_major": 2,
        "version_minor": 0,
-       "model_id": "55eda5299f554849aba6bd2781608ed2"
+       "model_id": "bd8bf330a90e4eddb52a3c87af6d2869"
       }
      },
      "metadata": {},
@@ -566,8 +566,8 @@
    "metadata": {
     "collapsed": false,
     "ExecuteTime": {
-     "end_time": "2023-11-13T07:49:35.721377Z",
-     "start_time": "2023-11-13T07:48:32.278875Z"
+     "end_time": "2023-11-13T08:20:39.048426Z",
+     "start_time": "2023-11-13T08:19:54.147748Z"
     }
    },
    "id": "e5af131bae4b1a33"
@@ -582,8 +582,8 @@
    "metadata": {
     "collapsed": false,
     "ExecuteTime": {
-     "end_time": "2023-11-13T07:49:35.728679Z",
-     "start_time": "2023-11-13T07:49:35.696749Z"
+     "end_time": "2023-11-13T08:20:39.152284Z",
+     "start_time": "2023-11-13T08:20:39.042845Z"
     }
    },
    "id": "3e23413c86063183"
@@ -598,8 +598,8 @@
    "metadata": {
     "collapsed": false,
     "ExecuteTime": {
-     "end_time": "2023-11-13T07:49:35.972167Z",
-     "start_time": "2023-11-13T07:49:35.711714Z"
+     "end_time": "2023-11-13T08:20:39.228735Z",
+     "start_time": "2023-11-13T08:20:39.052747Z"
     }
    },
    "id": "d555d7f0223a624b"
@@ -613,8 +613,8 @@
      "output_type": "stream",
      "text": [
       "TOKEN           LEMMA           UPOS       HEAD       DEPREL    \n",
-      "Cześć,          ??????          NOUN                0 root      \n",
-      "jestem          ??????          NOUN                1 punct     \n",
+      "Cześć,          ?????           NOUN                0 root      \n",
+      "jestem          ?????           NOUN                1 punct     \n",
       "psem.           ?????           NOUN                1 punct     \n"
      ]
     }
@@ -627,8 +627,8 @@
    "metadata": {
     "collapsed": false,
     "ExecuteTime": {
-     "end_time": "2023-11-13T07:49:35.973153Z",
-     "start_time": "2023-11-13T07:49:35.929034Z"
+     "end_time": "2023-11-13T08:20:39.237630Z",
+     "start_time": "2023-11-13T08:20:39.227051Z"
     }
    },
    "id": "a68cd3861e1ceb67"
@@ -643,8 +643,8 @@
    "metadata": {
     "collapsed": false,
     "ExecuteTime": {
-     "end_time": "2023-11-13T07:49:35.973436Z",
-     "start_time": "2023-11-13T07:49:35.931941Z"
+     "end_time": "2023-11-13T08:20:39.248539Z",
+     "start_time": "2023-11-13T08:20:39.233003Z"
     }
    },
    "id": "d0f43f4493218b5"
@@ -668,8 +668,8 @@
    "metadata": {
     "collapsed": false,
     "ExecuteTime": {
-     "end_time": "2023-11-13T07:51:13.077831Z",
-     "start_time": "2023-11-13T07:49:35.950950Z"
+     "end_time": "2023-11-13T08:24:03.513738Z",
+     "start_time": "2023-11-13T08:20:39.250115Z"
     }
    },
    "id": "ec92aa5bb5bb3605"
@@ -678,12 +678,14 @@
    "cell_type": "code",
    "execution_count": 16,
    "outputs": [],
-   "source": [],
+   "source": [
+    "\n"
+   ],
    "metadata": {
     "collapsed": false,
     "ExecuteTime": {
-     "end_time": "2023-11-13T07:51:13.123575Z",
-     "start_time": "2023-11-13T07:51:13.067631Z"
+     "end_time": "2023-11-13T08:24:03.668958Z",
+     "start_time": "2023-11-13T08:24:02.256799Z"
     }
    },
    "id": "5ad8a827586f65e3"
-- 
GitLab