From 901d94ed8c328aa87e4dc3c3d6f9e65201bdb727 Mon Sep 17 00:00:00 2001 From: Maja Jablonska <majajjablonska@gmail.com> Date: Sat, 11 Nov 2023 18:38:49 +1100 Subject: [PATCH] Add dataset reader serialization --- combo/modules/archival.py | 14 +++- combo/polish_model_training.ipynb | 117 ++++++++++++++---------------- combo/predict.py | 2 +- 3 files changed, 68 insertions(+), 65 deletions(-) diff --git a/combo/modules/archival.py b/combo/modules/archival.py index ecf22f7..aae88dd 100644 --- a/combo/modules/archival.py +++ b/combo/modules/archival.py @@ -12,6 +12,7 @@ from tempfile import TemporaryDirectory from combo.config import resolve from combo.data.dataset_loaders import DataLoader +from combo.data.dataset_readers import DatasetReader from combo.modules.model import Model @@ -24,6 +25,7 @@ class Archive(NamedTuple): config: Optional[Dict[str, Any]] data_loader: Optional[DataLoader] validation_data_loader: Optional[DataLoader] + dataset_reader: Optional[DatasetReader] def add_to_tar(tar_file: tarfile.TarFile, out_stream: BytesIO, data: bytes, name: str): @@ -37,7 +39,8 @@ def add_to_tar(tar_file: tarfile.TarFile, out_stream: BytesIO, data: bytes, name def archive(model: Model, serialization_dir: Union[PathLike, str], data_loader: Optional[DataLoader] = None, - validation_data_loader: Optional[DataLoader] = None) -> str: + validation_data_loader: Optional[DataLoader] = None, + dataset_reader: Optional[DatasetReader] = None) -> str: parameters = {'vocabulary': { 'type': 'from_files_vocabulary', 'parameters': { @@ -51,6 +54,8 @@ def archive(model: Model, parameters['data_loader'] = data_loader.serialize() if validation_data_loader: parameters['validation_data_loader'] = validation_data_loader.serialize() + if dataset_reader: + parameters['dataset_reader'] = dataset_reader.serialize() parameters['training'] = {} @@ -87,14 +92,17 @@ def load_archive(url_or_filename: Union[PathLike, str], with open(os.path.join(archive_file, 'config.json'), 'r') as f: config = json.load(f) - data_loader, validation_data_loader = None, None + data_loader, validation_data_loader, dataset_reader = None, None, None if 'data_loader' in config: data_loader = resolve(config['data_loader']) if 'validation_data_loader' in config: validation_data_loader = resolve(config['validation_data_loader']) + if 'dataset_reader' in config: + dataset_reader = resolve(config['dataset_reader']) return Archive(model=model, config=config, data_loader=data_loader, - validation_data_loader=validation_data_loader) + validation_data_loader=validation_data_loader, + dataset_reader=dataset_reader) diff --git a/combo/polish_model_training.ipynb b/combo/polish_model_training.ipynb index 78ef15d..f9787ee 100644 --- a/combo/polish_model_training.ipynb +++ b/combo/polish_model_training.ipynb @@ -14,15 +14,15 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-11-04T18:03:10.139381Z", - "start_time": "2023-11-04T18:03:09.293292Z" + "end_time": "2023-11-11T07:28:53.129601Z", + "start_time": "2023-11-11T07:28:52.947282Z" } }, "id": "b28c7d8bacb08d02" }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 4, "outputs": [], "source": [ "from combo.predict import COMBO\n", @@ -43,7 +43,6 @@ "from combo.modules.parser import DependencyRelationModel, HeadPredictionModel\n", "from combo.modules.lemma import LemmatizerModel\n", "from combo.modules.morpho import MorphologicalFeatures\n", - "from combo.nn.regularizers import Regularizer\n", "from combo.nn.regularizers.regularizers import L2Regularizer\n", "import pytorch_lightning as pl\n", "from combo.training.trainable_combo import TrainableCombo\n", @@ -52,15 +51,15 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-11-04T18:03:15.961674Z", - "start_time": "2023-11-04T18:03:09.430012Z" + "end_time": "2023-11-11T07:29:25.986145Z", + "start_time": "2023-11-11T07:29:25.671527Z" } }, "id": "initial_id" }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 5, "outputs": [ { "name": "stdout", @@ -78,7 +77,7 @@ "application/vnd.jupyter.widget-view+json": { "version_major": 2, "version_minor": 0, - "model_id": "89c66e96f406495298ca072425a787e3" + "model_id": "0df1c8c352a14e9993691edf5626968f" } }, "metadata": {}, @@ -90,7 +89,7 @@ "application/vnd.jupyter.widget-view+json": { "version_major": 2, "version_minor": 0, - "model_id": "13dc9a1d9d224454b102db9dc40da776" + "model_id": "8ac307b142074e38afe3d55e00b3c203" } }, "metadata": {}, @@ -102,7 +101,7 @@ "application/vnd.jupyter.widget-view+json": { "version_major": 2, "version_minor": 0, - "model_id": "ae2380f39ca54f07b095288e67aa0fcf" + "model_id": "f60a4d838fc849c3aff01901f531d91c" } }, "metadata": {}, @@ -170,15 +169,15 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-11-04T18:03:36.466620Z", - "start_time": "2023-11-04T18:03:15.931675Z" + "end_time": "2023-11-11T07:29:54.208157Z", + "start_time": "2023-11-11T07:29:25.685934Z" } }, "id": "d74957f422f0b05b" }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 6, "outputs": [], "source": [ "seq_encoder = ComboEncoder(layer_dropout_probability=0.33,\n", @@ -193,15 +192,15 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-11-04T18:03:37.132462Z", - "start_time": "2023-11-04T18:03:36.471496Z" + "end_time": "2023-11-11T07:29:55.768681Z", + "start_time": "2023-11-11T07:29:54.231728Z" } }, "id": "fa724d362fd6bd23" }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 7, "outputs": [ { "name": "stdout", @@ -212,9 +211,9 @@ }, { "data": { - "text/plain": "<generator object SimpleDataLoader.iter_instances at 0x7fd086156e40>" + "text/plain": "<generator object SimpleDataLoader.iter_instances at 0x7faf9b7f6820>" }, - "execution_count": 5, + "execution_count": 7, "metadata": {}, "output_type": "execute_result" } @@ -240,15 +239,15 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-11-04T18:03:37.199203Z", - "start_time": "2023-11-04T18:03:37.132934Z" + "end_time": "2023-11-11T07:29:55.840836Z", + "start_time": "2023-11-11T07:29:55.773085Z" } }, "id": "f8a10f9892005fca" }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 8, "outputs": [ { "name": "stderr", @@ -264,21 +263,21 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-11-04T18:03:37.200310Z", - "start_time": "2023-11-04T18:03:37.172593Z" + "end_time": "2023-11-11T07:29:55.859643Z", + "start_time": "2023-11-11T07:29:55.837113Z" } }, "id": "14413692656b68ac" }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 9, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "Some weights of the model checkpoint at allegro/herbert-base-cased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.sso.sso_relationship.bias', 'cls.sso.sso_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight']\n", + "Some weights of the model checkpoint at allegro/herbert-base-cased were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.sso.sso_relationship.bias', 'cls.predictions.decoder.weight', 'cls.sso.sso_relationship.weight']\n", "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n" ] @@ -412,8 +411,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-11-04T18:10:46.777158Z", - "start_time": "2023-11-04T18:10:37.675411Z" + "end_time": "2023-11-11T07:30:01.875464Z", + "start_time": "2023-11-11T07:29:55.851182Z" } }, "id": "437d12054baaffa1" @@ -431,8 +430,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-11-04T18:13:21.449314Z", - "start_time": "2023-11-04T18:13:21.175555Z" + "end_time": "2023-11-11T07:30:46.834983Z", + "start_time": "2023-11-11T07:30:01.904158Z" } }, "id": "e131e0ec75dc6927" @@ -447,8 +446,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-11-04T18:13:25.997985Z", - "start_time": "2023-11-04T18:13:21.300949Z" + "end_time": "2023-11-11T07:30:51.486798Z", + "start_time": "2023-11-11T07:30:46.839866Z" } }, "id": "195c71fcf8170ff" @@ -482,8 +481,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-11-04T18:13:26.531259Z", - "start_time": "2023-11-04T18:13:25.980771Z" + "end_time": "2023-11-11T07:30:51.879326Z", + "start_time": "2023-11-11T07:30:51.500543Z" } }, "id": "cefc5173154d1605" @@ -513,7 +512,7 @@ "application/vnd.jupyter.widget-view+json": { "version_major": 2, "version_minor": 0, - "model_id": "42c0f91fed83494bb0459b62e7a1fa7e" + "model_id": "7b4d5f7d5cec41b98aaf4c5fa3fff0d8" } }, "metadata": {}, @@ -535,7 +534,7 @@ "application/vnd.jupyter.widget-view+json": { "version_major": 2, "version_minor": 0, - "model_id": "007aa4f3ab2944e79ba186b382721d3c" + "model_id": "71a7264ff03d4f17828db6f8a7893e39" } }, "metadata": {}, @@ -547,7 +546,7 @@ "application/vnd.jupyter.widget-view+json": { "version_major": 2, "version_minor": 0, - "model_id": "bb7220511bb5416bb00bd634b6b27ca8" + "model_id": "b93d9100d1bc4a04ab706a9a3840644c" } }, "metadata": {}, @@ -567,8 +566,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-11-04T18:14:18.581168Z", - "start_time": "2023-11-04T18:13:26.516910Z" + "end_time": "2023-11-11T07:32:36.809443Z", + "start_time": "2023-11-11T07:30:51.816554Z" } }, "id": "e5af131bae4b1a33" @@ -583,8 +582,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-11-04T18:14:18.600390Z", - "start_time": "2023-11-04T18:14:18.592964Z" + "end_time": "2023-11-11T07:32:37.095367Z", + "start_time": "2023-11-11T07:32:32.550627Z" } }, "id": "3e23413c86063183" @@ -599,8 +598,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-11-04T18:14:18.974570Z", - "start_time": "2023-11-04T18:14:18.594341Z" + "end_time": "2023-11-11T07:32:37.333871Z", + "start_time": "2023-11-11T07:32:32.625348Z" } }, "id": "d555d7f0223a624b" @@ -614,9 +613,9 @@ "output_type": "stream", "text": [ "TOKEN LEMMA UPOS HEAD DEPREL \n", - "Cześć, ?????a NOUN 0 root \n", - "jestem ?????a NOUN 1 punct \n", - "psem. ????? NOUN 2 punct \n" + "Cześć, ????? NOUN 2 punct \n", + "jestem ????? NOUN 0 root \n", + "psem. ???? NOUN 2 punct \n" ] } ], @@ -628,8 +627,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-11-04T18:14:18.994941Z", - "start_time": "2023-11-04T18:14:18.958029Z" + "end_time": "2023-11-11T07:32:38.854144Z", + "start_time": "2023-11-11T07:32:35.324424Z" } }, "id": "a68cd3861e1ceb67" @@ -644,8 +643,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-11-04T18:14:18.996620Z", - "start_time": "2023-11-04T18:14:18.970791Z" + "end_time": "2023-11-11T07:32:41.112617Z", + "start_time": "2023-11-11T07:32:35.502093Z" } }, "id": "d0f43f4493218b5" @@ -664,37 +663,33 @@ } ], "source": [ - "archive(model, '/Users/majajablonska/Documents/combo', data_loader, val_data_loader)" + "archive(model, '/Users/majajablonska/Documents/combo', data_loader, val_data_loader, dataset_reader)" ], "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-11-04T18:15:54.256906Z", - "start_time": "2023-11-04T18:14:18.986766Z" + "end_time": "2023-11-11T07:34:18.278208Z", + "start_time": "2023-11-11T07:32:35.783931Z" } }, "id": "ec92aa5bb5bb3605" }, { "cell_type": "code", - "execution_count": 18, + "execution_count": null, "outputs": [], "source": [], "metadata": { - "collapsed": false, - "ExecuteTime": { - "end_time": "2023-11-04T18:15:54.339234Z", - "start_time": "2023-11-04T18:15:54.217490Z" - } + "collapsed": false }, - "id": "953bd53cccd5f890" + "id": "5ad8a827586f65e3" } ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "name": "python3", "language": "python", - "name": "python3" + "display_name": "Python 3 (ipykernel)" }, "language_info": { "codemirror_mode": { diff --git a/combo/predict.py b/combo/predict.py index 42c55a2..8d6b51f 100644 --- a/combo/predict.py +++ b/combo/predict.py @@ -265,5 +265,5 @@ class COMBO(PredictorModule): archive = load_archive(model_path, cuda_device=cuda_device) model = archive.model - dataset_reader = default_ud_dataset_reader() + dataset_reader = archive.dataset_reader or default_ud_dataset_reader() return cls(model, dataset_reader, tokenizer, batch_size) -- GitLab