From 901d94ed8c328aa87e4dc3c3d6f9e65201bdb727 Mon Sep 17 00:00:00 2001
From: Maja Jablonska <majajjablonska@gmail.com>
Date: Sat, 11 Nov 2023 18:38:49 +1100
Subject: [PATCH] Add dataset reader serialization

---
 combo/modules/archival.py         |  14 +++-
 combo/polish_model_training.ipynb | 117 ++++++++++++++----------------
 combo/predict.py                  |   2 +-
 3 files changed, 68 insertions(+), 65 deletions(-)

diff --git a/combo/modules/archival.py b/combo/modules/archival.py
index ecf22f7..aae88dd 100644
--- a/combo/modules/archival.py
+++ b/combo/modules/archival.py
@@ -12,6 +12,7 @@ from tempfile import TemporaryDirectory
 
 from combo.config import resolve
 from combo.data.dataset_loaders import DataLoader
+from combo.data.dataset_readers import DatasetReader
 from combo.modules.model import Model
 
 
@@ -24,6 +25,7 @@ class Archive(NamedTuple):
     config: Optional[Dict[str, Any]]
     data_loader: Optional[DataLoader]
     validation_data_loader: Optional[DataLoader]
+    dataset_reader: Optional[DatasetReader]
 
 
 def add_to_tar(tar_file: tarfile.TarFile, out_stream: BytesIO, data: bytes, name: str):
@@ -37,7 +39,8 @@ def add_to_tar(tar_file: tarfile.TarFile, out_stream: BytesIO, data: bytes, name
 def archive(model: Model,
             serialization_dir: Union[PathLike, str],
             data_loader: Optional[DataLoader] = None,
-            validation_data_loader: Optional[DataLoader] = None) -> str:
+            validation_data_loader: Optional[DataLoader] = None,
+            dataset_reader: Optional[DatasetReader] = None) -> str:
     parameters = {'vocabulary': {
         'type': 'from_files_vocabulary',
         'parameters': {
@@ -51,6 +54,8 @@ def archive(model: Model,
         parameters['data_loader'] = data_loader.serialize()
     if validation_data_loader:
         parameters['validation_data_loader'] = validation_data_loader.serialize()
+    if dataset_reader:
+        parameters['dataset_reader'] = dataset_reader.serialize()
 
     parameters['training'] = {}
 
@@ -87,14 +92,17 @@ def load_archive(url_or_filename: Union[PathLike, str],
     with open(os.path.join(archive_file, 'config.json'), 'r') as f:
         config = json.load(f)
 
-    data_loader, validation_data_loader = None, None
+    data_loader, validation_data_loader, dataset_reader = None, None, None
 
     if 'data_loader' in config:
         data_loader = resolve(config['data_loader'])
     if 'validation_data_loader' in config:
         validation_data_loader = resolve(config['validation_data_loader'])
+    if 'dataset_reader' in config:
+        dataset_reader = resolve(config['dataset_reader'])
     
     return Archive(model=model,
                    config=config,
                    data_loader=data_loader,
-                   validation_data_loader=validation_data_loader)
+                   validation_data_loader=validation_data_loader,
+                   dataset_reader=dataset_reader)
diff --git a/combo/polish_model_training.ipynb b/combo/polish_model_training.ipynb
index 78ef15d..f9787ee 100644
--- a/combo/polish_model_training.ipynb
+++ b/combo/polish_model_training.ipynb
@@ -14,15 +14,15 @@
    "metadata": {
     "collapsed": false,
     "ExecuteTime": {
-     "end_time": "2023-11-04T18:03:10.139381Z",
-     "start_time": "2023-11-04T18:03:09.293292Z"
+     "end_time": "2023-11-11T07:28:53.129601Z",
+     "start_time": "2023-11-11T07:28:52.947282Z"
     }
    },
    "id": "b28c7d8bacb08d02"
   },
   {
    "cell_type": "code",
-   "execution_count": 2,
+   "execution_count": 4,
    "outputs": [],
    "source": [
     "from combo.predict import COMBO\n",
@@ -43,7 +43,6 @@
     "from combo.modules.parser import DependencyRelationModel, HeadPredictionModel\n",
     "from combo.modules.lemma import LemmatizerModel\n",
     "from combo.modules.morpho import MorphologicalFeatures\n",
-    "from combo.nn.regularizers import Regularizer\n",
     "from combo.nn.regularizers.regularizers import L2Regularizer\n",
     "import pytorch_lightning as pl\n",
     "from combo.training.trainable_combo import TrainableCombo\n",
@@ -52,15 +51,15 @@
    "metadata": {
     "collapsed": false,
     "ExecuteTime": {
-     "end_time": "2023-11-04T18:03:15.961674Z",
-     "start_time": "2023-11-04T18:03:09.430012Z"
+     "end_time": "2023-11-11T07:29:25.986145Z",
+     "start_time": "2023-11-11T07:29:25.671527Z"
     }
    },
    "id": "initial_id"
   },
   {
    "cell_type": "code",
-   "execution_count": 3,
+   "execution_count": 5,
    "outputs": [
     {
      "name": "stdout",
@@ -78,7 +77,7 @@
       "application/vnd.jupyter.widget-view+json": {
        "version_major": 2,
        "version_minor": 0,
-       "model_id": "89c66e96f406495298ca072425a787e3"
+       "model_id": "0df1c8c352a14e9993691edf5626968f"
       }
      },
      "metadata": {},
@@ -90,7 +89,7 @@
       "application/vnd.jupyter.widget-view+json": {
        "version_major": 2,
        "version_minor": 0,
-       "model_id": "13dc9a1d9d224454b102db9dc40da776"
+       "model_id": "8ac307b142074e38afe3d55e00b3c203"
       }
      },
      "metadata": {},
@@ -102,7 +101,7 @@
       "application/vnd.jupyter.widget-view+json": {
        "version_major": 2,
        "version_minor": 0,
-       "model_id": "ae2380f39ca54f07b095288e67aa0fcf"
+       "model_id": "f60a4d838fc849c3aff01901f531d91c"
       }
      },
      "metadata": {},
@@ -170,15 +169,15 @@
    "metadata": {
     "collapsed": false,
     "ExecuteTime": {
-     "end_time": "2023-11-04T18:03:36.466620Z",
-     "start_time": "2023-11-04T18:03:15.931675Z"
+     "end_time": "2023-11-11T07:29:54.208157Z",
+     "start_time": "2023-11-11T07:29:25.685934Z"
     }
    },
    "id": "d74957f422f0b05b"
   },
   {
    "cell_type": "code",
-   "execution_count": 4,
+   "execution_count": 6,
    "outputs": [],
    "source": [
     "seq_encoder = ComboEncoder(layer_dropout_probability=0.33,\n",
@@ -193,15 +192,15 @@
    "metadata": {
     "collapsed": false,
     "ExecuteTime": {
-     "end_time": "2023-11-04T18:03:37.132462Z",
-     "start_time": "2023-11-04T18:03:36.471496Z"
+     "end_time": "2023-11-11T07:29:55.768681Z",
+     "start_time": "2023-11-11T07:29:54.231728Z"
     }
    },
    "id": "fa724d362fd6bd23"
   },
   {
    "cell_type": "code",
-   "execution_count": 5,
+   "execution_count": 7,
    "outputs": [
     {
      "name": "stdout",
@@ -212,9 +211,9 @@
     },
     {
      "data": {
-      "text/plain": "<generator object SimpleDataLoader.iter_instances at 0x7fd086156e40>"
+      "text/plain": "<generator object SimpleDataLoader.iter_instances at 0x7faf9b7f6820>"
      },
-     "execution_count": 5,
+     "execution_count": 7,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -240,15 +239,15 @@
    "metadata": {
     "collapsed": false,
     "ExecuteTime": {
-     "end_time": "2023-11-04T18:03:37.199203Z",
-     "start_time": "2023-11-04T18:03:37.132934Z"
+     "end_time": "2023-11-11T07:29:55.840836Z",
+     "start_time": "2023-11-11T07:29:55.773085Z"
     }
    },
    "id": "f8a10f9892005fca"
   },
   {
    "cell_type": "code",
-   "execution_count": 6,
+   "execution_count": 8,
    "outputs": [
     {
      "name": "stderr",
@@ -264,21 +263,21 @@
    "metadata": {
     "collapsed": false,
     "ExecuteTime": {
-     "end_time": "2023-11-04T18:03:37.200310Z",
-     "start_time": "2023-11-04T18:03:37.172593Z"
+     "end_time": "2023-11-11T07:29:55.859643Z",
+     "start_time": "2023-11-11T07:29:55.837113Z"
     }
    },
    "id": "14413692656b68ac"
   },
   {
    "cell_type": "code",
-   "execution_count": 8,
+   "execution_count": 9,
    "outputs": [
     {
      "name": "stderr",
      "output_type": "stream",
      "text": [
-      "Some weights of the model checkpoint at allegro/herbert-base-cased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.sso.sso_relationship.bias', 'cls.sso.sso_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight']\n",
+      "Some weights of the model checkpoint at allegro/herbert-base-cased were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.sso.sso_relationship.bias', 'cls.predictions.decoder.weight', 'cls.sso.sso_relationship.weight']\n",
       "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
       "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
      ]
@@ -412,8 +411,8 @@
    "metadata": {
     "collapsed": false,
     "ExecuteTime": {
-     "end_time": "2023-11-04T18:10:46.777158Z",
-     "start_time": "2023-11-04T18:10:37.675411Z"
+     "end_time": "2023-11-11T07:30:01.875464Z",
+     "start_time": "2023-11-11T07:29:55.851182Z"
     }
    },
    "id": "437d12054baaffa1"
@@ -431,8 +430,8 @@
    "metadata": {
     "collapsed": false,
     "ExecuteTime": {
-     "end_time": "2023-11-04T18:13:21.449314Z",
-     "start_time": "2023-11-04T18:13:21.175555Z"
+     "end_time": "2023-11-11T07:30:46.834983Z",
+     "start_time": "2023-11-11T07:30:01.904158Z"
     }
    },
    "id": "e131e0ec75dc6927"
@@ -447,8 +446,8 @@
    "metadata": {
     "collapsed": false,
     "ExecuteTime": {
-     "end_time": "2023-11-04T18:13:25.997985Z",
-     "start_time": "2023-11-04T18:13:21.300949Z"
+     "end_time": "2023-11-11T07:30:51.486798Z",
+     "start_time": "2023-11-11T07:30:46.839866Z"
     }
    },
    "id": "195c71fcf8170ff"
@@ -482,8 +481,8 @@
    "metadata": {
     "collapsed": false,
     "ExecuteTime": {
-     "end_time": "2023-11-04T18:13:26.531259Z",
-     "start_time": "2023-11-04T18:13:25.980771Z"
+     "end_time": "2023-11-11T07:30:51.879326Z",
+     "start_time": "2023-11-11T07:30:51.500543Z"
     }
    },
    "id": "cefc5173154d1605"
@@ -513,7 +512,7 @@
       "application/vnd.jupyter.widget-view+json": {
        "version_major": 2,
        "version_minor": 0,
-       "model_id": "42c0f91fed83494bb0459b62e7a1fa7e"
+       "model_id": "7b4d5f7d5cec41b98aaf4c5fa3fff0d8"
       }
      },
      "metadata": {},
@@ -535,7 +534,7 @@
       "application/vnd.jupyter.widget-view+json": {
        "version_major": 2,
        "version_minor": 0,
-       "model_id": "007aa4f3ab2944e79ba186b382721d3c"
+       "model_id": "71a7264ff03d4f17828db6f8a7893e39"
       }
      },
      "metadata": {},
@@ -547,7 +546,7 @@
       "application/vnd.jupyter.widget-view+json": {
        "version_major": 2,
        "version_minor": 0,
-       "model_id": "bb7220511bb5416bb00bd634b6b27ca8"
+       "model_id": "b93d9100d1bc4a04ab706a9a3840644c"
       }
      },
      "metadata": {},
@@ -567,8 +566,8 @@
    "metadata": {
     "collapsed": false,
     "ExecuteTime": {
-     "end_time": "2023-11-04T18:14:18.581168Z",
-     "start_time": "2023-11-04T18:13:26.516910Z"
+     "end_time": "2023-11-11T07:32:36.809443Z",
+     "start_time": "2023-11-11T07:30:51.816554Z"
     }
    },
    "id": "e5af131bae4b1a33"
@@ -583,8 +582,8 @@
    "metadata": {
     "collapsed": false,
     "ExecuteTime": {
-     "end_time": "2023-11-04T18:14:18.600390Z",
-     "start_time": "2023-11-04T18:14:18.592964Z"
+     "end_time": "2023-11-11T07:32:37.095367Z",
+     "start_time": "2023-11-11T07:32:32.550627Z"
     }
    },
    "id": "3e23413c86063183"
@@ -599,8 +598,8 @@
    "metadata": {
     "collapsed": false,
     "ExecuteTime": {
-     "end_time": "2023-11-04T18:14:18.974570Z",
-     "start_time": "2023-11-04T18:14:18.594341Z"
+     "end_time": "2023-11-11T07:32:37.333871Z",
+     "start_time": "2023-11-11T07:32:32.625348Z"
     }
    },
    "id": "d555d7f0223a624b"
@@ -614,9 +613,9 @@
      "output_type": "stream",
      "text": [
       "TOKEN           LEMMA           UPOS       HEAD       DEPREL    \n",
-      "Cześć,          ?????a          NOUN                0 root      \n",
-      "jestem          ?????a          NOUN                1 punct     \n",
-      "psem.           ?????           NOUN                2 punct     \n"
+      "Cześć,          ?????           NOUN                2 punct     \n",
+      "jestem          ?????           NOUN                0 root      \n",
+      "psem.           ????            NOUN                2 punct     \n"
      ]
     }
    ],
@@ -628,8 +627,8 @@
    "metadata": {
     "collapsed": false,
     "ExecuteTime": {
-     "end_time": "2023-11-04T18:14:18.994941Z",
-     "start_time": "2023-11-04T18:14:18.958029Z"
+     "end_time": "2023-11-11T07:32:38.854144Z",
+     "start_time": "2023-11-11T07:32:35.324424Z"
     }
    },
    "id": "a68cd3861e1ceb67"
@@ -644,8 +643,8 @@
    "metadata": {
     "collapsed": false,
     "ExecuteTime": {
-     "end_time": "2023-11-04T18:14:18.996620Z",
-     "start_time": "2023-11-04T18:14:18.970791Z"
+     "end_time": "2023-11-11T07:32:41.112617Z",
+     "start_time": "2023-11-11T07:32:35.502093Z"
     }
    },
    "id": "d0f43f4493218b5"
@@ -664,37 +663,33 @@
     }
    ],
    "source": [
-    "archive(model, '/Users/majajablonska/Documents/combo', data_loader, val_data_loader)"
+    "archive(model, '/Users/majajablonska/Documents/combo', data_loader, val_data_loader, dataset_reader)"
    ],
    "metadata": {
     "collapsed": false,
     "ExecuteTime": {
-     "end_time": "2023-11-04T18:15:54.256906Z",
-     "start_time": "2023-11-04T18:14:18.986766Z"
+     "end_time": "2023-11-11T07:34:18.278208Z",
+     "start_time": "2023-11-11T07:32:35.783931Z"
     }
    },
    "id": "ec92aa5bb5bb3605"
   },
   {
    "cell_type": "code",
-   "execution_count": 18,
+   "execution_count": null,
    "outputs": [],
    "source": [],
    "metadata": {
-    "collapsed": false,
-    "ExecuteTime": {
-     "end_time": "2023-11-04T18:15:54.339234Z",
-     "start_time": "2023-11-04T18:15:54.217490Z"
-    }
+    "collapsed": false
    },
-   "id": "953bd53cccd5f890"
+   "id": "5ad8a827586f65e3"
   }
  ],
  "metadata": {
   "kernelspec": {
-   "display_name": "Python 3",
+   "name": "python3",
    "language": "python",
-   "name": "python3"
+   "display_name": "Python 3 (ipykernel)"
   },
   "language_info": {
    "codemirror_mode": {
diff --git a/combo/predict.py b/combo/predict.py
index 42c55a2..8d6b51f 100644
--- a/combo/predict.py
+++ b/combo/predict.py
@@ -265,5 +265,5 @@ class COMBO(PredictorModule):
 
         archive = load_archive(model_path, cuda_device=cuda_device)
         model = archive.model
-        dataset_reader = default_ud_dataset_reader()
+        dataset_reader = archive.dataset_reader or default_ud_dataset_reader()
         return cls(model, dataset_reader, tokenizer, batch_size)
-- 
GitLab