From 28f79649b1df5022a57f9e4cb4137bddb1c87ca2 Mon Sep 17 00:00:00 2001 From: Maja Jablonska <majajjablonska@gmail.com> Date: Mon, 13 Nov 2023 19:26:44 +1100 Subject: [PATCH] Remove unnecessary dilated_cnn copy --- combo/models/dilated_cnn.py | 14 +++-- combo/polish_model_training.ipynb | 92 ++++++++++++++++--------------- 2 files changed, 55 insertions(+), 51 deletions(-) diff --git a/combo/models/dilated_cnn.py b/combo/models/dilated_cnn.py index 71b7df2..e694a10 100644 --- a/combo/models/dilated_cnn.py +++ b/combo/models/dilated_cnn.py @@ -6,13 +6,15 @@ Author: Mateusz Klimaszewski from typing import List import torch -import torch.nn as nn -from combo.nn import Activation +from combo.config import FromParameters, Registry +from combo.config.from_parameters import register_arguments +from combo.nn.activations import Activation -class DilatedCnnEncoder(nn.Module): - +@Registry.register('dilated_cnn') +class DilatedCnnEncoder(torch.nn.Module, FromParameters): + @register_arguments def __init__(self, input_dim: int, filters: List[int], @@ -26,14 +28,14 @@ class DilatedCnnEncoder(nn.Module): input_dims = [input_dim] + filters[:-1] output_dims = filters for idx in range(len(activations)): - conv1d_layers.append(nn.Conv1d( + conv1d_layers.append(torch.nn.Conv1d( in_channels=input_dims[idx], out_channels=output_dims[idx], kernel_size=(kernel_size[idx],), stride=(stride[idx],), padding=padding[idx], dilation=(dilation[idx],))) - self.conv1d_layers = nn.ModuleList(conv1d_layers) + self.conv1d_layers = torch.nn.ModuleList(conv1d_layers) self.activations = activations assert len(self.activations) == len(self.conv1d_layers) diff --git a/combo/polish_model_training.ipynb b/combo/polish_model_training.ipynb index 1d75abd..bd2796f 100644 --- a/combo/polish_model_training.ipynb +++ b/combo/polish_model_training.ipynb @@ -14,8 +14,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-11-13T07:47:15.954139Z", - "start_time": "2023-11-13T07:47:15.711912Z" + "end_time": "2023-11-13T08:18:36.376046Z", + "start_time": "2023-11-13T08:18:36.189836Z" } }, "id": "b28c7d8bacb08d02" @@ -51,8 +51,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-11-13T07:47:22.233317Z", - "start_time": "2023-11-13T07:47:15.766709Z" + "end_time": "2023-11-13T08:18:41.300316Z", + "start_time": "2023-11-13T08:18:36.197537Z" } }, "id": "initial_id" @@ -77,7 +77,7 @@ "application/vnd.jupyter.widget-view+json": { "version_major": 2, "version_minor": 0, - "model_id": "d318d4f50da14b76a14eb20cb877ee67" + "model_id": "e45b57a7047043a48ccfeacfb49312b5" } }, "metadata": {}, @@ -89,7 +89,7 @@ "application/vnd.jupyter.widget-view+json": { "version_major": 2, "version_minor": 0, - "model_id": "dbab946d82ab4d0ead64fc02796c2a9f" + "model_id": "a886ae4451474459b088659ebac076ae" } }, "metadata": {}, @@ -101,7 +101,7 @@ "application/vnd.jupyter.widget-view+json": { "version_major": 2, "version_minor": 0, - "model_id": "2f3d9306cb2b463eb080c922fe775b02" + "model_id": "20ff98564c7c43a9971c25f82ceda997" } }, "metadata": {}, @@ -169,8 +169,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-11-13T07:47:42.601537Z", - "start_time": "2023-11-13T07:47:22.243325Z" + "end_time": "2023-11-13T08:19:01.785477Z", + "start_time": "2023-11-13T08:18:41.119674Z" } }, "id": "d74957f422f0b05b" @@ -192,8 +192,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-11-13T07:47:44.068445Z", - "start_time": "2023-11-13T07:47:42.595098Z" + "end_time": "2023-11-13T08:19:03.816723Z", + "start_time": "2023-11-13T08:19:01.774666Z" } }, "id": "fa724d362fd6bd23" @@ -211,7 +211,7 @@ }, { "data": { - "text/plain": "<generator object SimpleDataLoader.iter_instances at 0x7fdd1e0a0c80>" + "text/plain": "<generator object SimpleDataLoader.iter_instances at 0x7fb2d3cdfc80>" }, "execution_count": 5, "metadata": {}, @@ -239,8 +239,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-11-13T07:47:44.196484Z", - "start_time": "2023-11-13T07:47:44.034821Z" + "end_time": "2023-11-13T08:19:03.877868Z", + "start_time": "2023-11-13T08:19:03.826289Z" } }, "id": "f8a10f9892005fca" @@ -263,8 +263,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-11-13T07:47:44.197075Z", - "start_time": "2023-11-13T07:47:44.055240Z" + "end_time": "2023-11-13T08:19:03.887795Z", + "start_time": "2023-11-13T08:19:03.870640Z" } }, "id": "14413692656b68ac" @@ -277,7 +277,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "Some weights of the model checkpoint at allegro/herbert-base-cased were not used when initializing BertModel: ['cls.predictions.decoder.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.sso.sso_relationship.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.sso.sso_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight']\n", + "Some weights of the model checkpoint at allegro/herbert-base-cased were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.sso.sso_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.sso.sso_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.decoder.bias', 'cls.predictions.decoder.weight']\n", "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n" ] @@ -411,8 +411,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-11-13T07:47:48.599708Z", - "start_time": "2023-11-13T07:47:44.063606Z" + "end_time": "2023-11-13T08:19:10.006549Z", + "start_time": "2023-11-13T08:19:03.885912Z" } }, "id": "437d12054baaffa1" @@ -430,8 +430,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-11-13T07:48:26.090634Z", - "start_time": "2023-11-13T07:47:48.622684Z" + "end_time": "2023-11-13T08:19:47.953809Z", + "start_time": "2023-11-13T08:19:09.989582Z" } }, "id": "e131e0ec75dc6927" @@ -446,8 +446,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-11-13T07:48:32.052740Z", - "start_time": "2023-11-13T07:48:26.077694Z" + "end_time": "2023-11-13T08:19:53.878201Z", + "start_time": "2023-11-13T08:19:47.940147Z" } }, "id": "195c71fcf8170ff" @@ -481,8 +481,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-11-13T07:48:32.321842Z", - "start_time": "2023-11-13T07:48:32.056903Z" + "end_time": "2023-11-13T08:19:54.150333Z", + "start_time": "2023-11-13T08:19:53.874397Z" } }, "id": "cefc5173154d1605" @@ -512,7 +512,7 @@ "application/vnd.jupyter.widget-view+json": { "version_major": 2, "version_minor": 0, - "model_id": "027b704c9899478bb71021e074ad29bf" + "model_id": "9f6e199b0fd546f5833fdda238964165" } }, "metadata": {}, @@ -534,7 +534,7 @@ "application/vnd.jupyter.widget-view+json": { "version_major": 2, "version_minor": 0, - "model_id": "4af02d76668645ae9213db79ae97d36f" + "model_id": "594c065e5d2441f48ba2b87c7a3f528f" } }, "metadata": {}, @@ -546,7 +546,7 @@ "application/vnd.jupyter.widget-view+json": { "version_major": 2, "version_minor": 0, - "model_id": "55eda5299f554849aba6bd2781608ed2" + "model_id": "bd8bf330a90e4eddb52a3c87af6d2869" } }, "metadata": {}, @@ -566,8 +566,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-11-13T07:49:35.721377Z", - "start_time": "2023-11-13T07:48:32.278875Z" + "end_time": "2023-11-13T08:20:39.048426Z", + "start_time": "2023-11-13T08:19:54.147748Z" } }, "id": "e5af131bae4b1a33" @@ -582,8 +582,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-11-13T07:49:35.728679Z", - "start_time": "2023-11-13T07:49:35.696749Z" + "end_time": "2023-11-13T08:20:39.152284Z", + "start_time": "2023-11-13T08:20:39.042845Z" } }, "id": "3e23413c86063183" @@ -598,8 +598,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-11-13T07:49:35.972167Z", - "start_time": "2023-11-13T07:49:35.711714Z" + "end_time": "2023-11-13T08:20:39.228735Z", + "start_time": "2023-11-13T08:20:39.052747Z" } }, "id": "d555d7f0223a624b" @@ -613,8 +613,8 @@ "output_type": "stream", "text": [ "TOKEN LEMMA UPOS HEAD DEPREL \n", - "Cześć, ?????? NOUN 0 root \n", - "jestem ?????? NOUN 1 punct \n", + "Cześć, ????? NOUN 0 root \n", + "jestem ????? NOUN 1 punct \n", "psem. ????? NOUN 1 punct \n" ] } @@ -627,8 +627,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-11-13T07:49:35.973153Z", - "start_time": "2023-11-13T07:49:35.929034Z" + "end_time": "2023-11-13T08:20:39.237630Z", + "start_time": "2023-11-13T08:20:39.227051Z" } }, "id": "a68cd3861e1ceb67" @@ -643,8 +643,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-11-13T07:49:35.973436Z", - "start_time": "2023-11-13T07:49:35.931941Z" + "end_time": "2023-11-13T08:20:39.248539Z", + "start_time": "2023-11-13T08:20:39.233003Z" } }, "id": "d0f43f4493218b5" @@ -668,8 +668,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-11-13T07:51:13.077831Z", - "start_time": "2023-11-13T07:49:35.950950Z" + "end_time": "2023-11-13T08:24:03.513738Z", + "start_time": "2023-11-13T08:20:39.250115Z" } }, "id": "ec92aa5bb5bb3605" @@ -678,12 +678,14 @@ "cell_type": "code", "execution_count": 16, "outputs": [], - "source": [], + "source": [ + "\n" + ], "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-11-13T07:51:13.123575Z", - "start_time": "2023-11-13T07:51:13.067631Z" + "end_time": "2023-11-13T08:24:03.668958Z", + "start_time": "2023-11-13T08:24:02.256799Z" } }, "id": "5ad8a827586f65e3" -- GitLab