diff --git a/combo/data/vocabulary.py b/combo/data/vocabulary.py index 0b3e0c6acb3549a53390aee1dc1b672f9c0b7ed7..73cb5411def457728b939e3c9b99331bd14a709d 100644 --- a/combo/data/vocabulary.py +++ b/combo/data/vocabulary.py @@ -1,4 +1,5 @@ import codecs +import json import os import re import glob @@ -288,6 +289,8 @@ class Vocabulary(FromParameters): ) for namespace_filename in os.listdir(directory): + if namespace_filename == "slices.json": + continue if namespace_filename == NAMESPACE_PADDING_FILE: continue if namespace_filename.startswith("."): @@ -300,6 +303,11 @@ class Vocabulary(FromParameters): filename = os.path.join(directory, namespace_filename) vocab.set_from_file(filename, is_padded, namespace=namespace, oov_token=oov_token) + with codecs.open( + os.path.join(directory, "slices.json"), "r", "utf-8" + ) as slices_file: + vocab.slices = json.load(slices_file) + vocab.constructed_from = 'from_files' return vocab @@ -532,6 +540,11 @@ class Vocabulary(FromParameters): for i in range(start_index, num_tokens): print(mapping[i].replace("\n", "@@NEWLINE@@"), file=token_file) + with codecs.open( + os.path.join(directory, "slices.json"), "w", "utf-8" + ) as slices_file: + json.dump(get_slices_if_not_provided(self), slices_file) + def is_padded(self, namespace: str) -> bool: namespace_itos = self._vocab[namespace].get_itos() return len(namespace_itos) > 0 and namespace_itos[0] == self._padding_token diff --git a/combo/main.py b/combo/main.py index aeb1dc8ab7c5a834690cb579b0b8f163ed4b1e30..0426f4830f1c4beb3560edb60932633b5d598606 100755 --- a/combo/main.py +++ b/combo/main.py @@ -142,6 +142,11 @@ def run(_): checks.file_exists(FLAGS.finetuning_validation_data_path) validation_data_loader = default_data_loader(default_ud_dataset_reader(), FLAGS.finetuning_validation_data_path) + print("Indexing train loader") + train_data_loader.index_with(model.vocab) + print("Indexing validation loader") + validation_data_loader.index_with(model.vocab) + print("Indexed") nlp = TrainableCombo(model, torch.optim.Adam, optimizer_kwargs={'betas': [0.9, 0.9], 'lr': 0.002}, diff --git a/combo/modules/archival.py b/combo/modules/archival.py index 85586b3af4fbd4fd82dfa2ee70b1695657e843c6..97d0281c0de33793bbf581441595d4923bf13505 100644 --- a/combo/modules/archival.py +++ b/combo/modules/archival.py @@ -44,7 +44,7 @@ def archive(model: Model, 'padding_token': model.vocab._padding_token, 'oov_token': model.vocab._oov_token } - }, 'model': model.serialize()} + }, 'model': model.serialize(pass_down_parameter_names=['vocabulary'])} if data_loader: parameters['data_loader'] = data_loader.serialize() diff --git a/combo/modules/model.py b/combo/modules/model.py index c90c05aabe1f7685c0ae18ddcd398e9aa7c57f8b..d9ad468dc16cafd03c024b2964f9429c41a5c5f3 100644 --- a/combo/modules/model.py +++ b/combo/modules/model.py @@ -341,9 +341,10 @@ class Model(Module, FromParameters): vocab_params = config.get("vocabulary") if vocab_params['type'] == 'from_files_vocabulary': vocab_params['parameters']['directory'] = vocab_dir - vocab = resolve(vocab_params) model_params = config.get("model") + model_params['parameters']['vocabulary'] = vocab_params + print(vocab_params) # The experiment config tells us how to _train_ a model, including where to get pre-trained # embeddings/weights from. We're now _loading_ the model, so those weights will already be diff --git a/combo/polish_model_training.ipynb b/combo/polish_model_training.ipynb index 6672180436d81f0d3d0d3e491b9b8ee02df3f80e..5a12c9933db0bb556334d501fe13283106500ace 100644 --- a/combo/polish_model_training.ipynb +++ b/combo/polish_model_training.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 2, + "execution_count": 1, "outputs": [], "source": [ "# The path where the training and validation datasets are stored\n", @@ -14,15 +14,15 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-10-15T12:35:09.621320Z", - "start_time": "2023-10-15T12:35:09.407839Z" + "end_time": "2023-10-16T10:53:31.473874Z", + "start_time": "2023-10-16T10:53:31.267229Z" } }, "id": "b28c7d8bacb08d02" }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 2, "outputs": [], "source": [ "import os\n", @@ -52,15 +52,15 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-10-15T12:35:17.384934Z", - "start_time": "2023-10-15T12:35:09.418819Z" + "end_time": "2023-10-16T10:53:39.605036Z", + "start_time": "2023-10-16T10:53:31.279613Z" } }, "id": "initial_id" }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 3, "outputs": [ { "name": "stdout", @@ -78,7 +78,7 @@ "application/vnd.jupyter.widget-view+json": { "version_major": 2, "version_minor": 0, - "model_id": "22fcc24a17304631b9ce8b5738210612" + "model_id": "36082f25618145b8802a25ba18d61bd6" } }, "metadata": {}, @@ -90,19 +90,7 @@ "application/vnd.jupyter.widget-view+json": { "version_major": 2, "version_minor": 0, - "model_id": "a14b090b880f4d6a8314acedeadd8c1f" - } - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": "building vocabulary: 0it [00:00, ?it/s]", - "application/vnd.jupyter.widget-view+json": { - "version_major": 2, - "version_minor": 0, - "model_id": "4c47cc124c6f42db9fb7b3b9a07709de" + "model_id": "ee2ebebdbb164726ad182b555c2ff9f8" } }, "metadata": {}, @@ -114,7 +102,7 @@ "application/vnd.jupyter.widget-view+json": { "version_major": 2, "version_minor": 0, - "model_id": "10cc5ff44d5f48b98c9c35de7eec0b1e" + "model_id": "f1821195b4554b529dde49452eab2382" } }, "metadata": {}, @@ -192,15 +180,15 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-10-15T12:35:40.187804Z", - "start_time": "2023-10-15T12:35:17.370019Z" + "end_time": "2023-10-16T10:54:05.865349Z", + "start_time": "2023-10-16T10:53:39.584680Z" } }, "id": "d74957f422f0b05b" }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "outputs": [], "source": [ "seq_encoder = ComboEncoder(layer_dropout_probability=0.33,\n", @@ -215,15 +203,15 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-10-15T12:35:40.901949Z", - "start_time": "2023-10-15T12:35:40.192629Z" + "end_time": "2023-10-16T10:54:07.772833Z", + "start_time": "2023-10-16T10:54:05.877807Z" } }, "id": "fa724d362fd6bd23" }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 5, "outputs": [ { "name": "stdout", @@ -234,9 +222,9 @@ }, { "data": { - "text/plain": "<generator object SimpleDataLoader.iter_instances at 0x7fec3dd685f0>" + "text/plain": "<generator object SimpleDataLoader.iter_instances at 0x7fd5e206a970>" }, - "execution_count": 6, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" } @@ -262,21 +250,21 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-10-15T12:35:40.965741Z", - "start_time": "2023-10-15T12:35:40.904199Z" + "end_time": "2023-10-16T10:54:07.898735Z", + "start_time": "2023-10-16T10:54:07.789190Z" } }, "id": "f8a10f9892005fca" }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 6, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "Some weights of the model checkpoint at allegro/herbert-base-cased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.sso.sso_relationship.bias', 'cls.predictions.decoder.weight', 'cls.sso.sso_relationship.weight', 'cls.predictions.decoder.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias']\n", + "Some weights of the model checkpoint at allegro/herbert-base-cased were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.predictions.decoder.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.sso.sso_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.sso.sso_relationship.weight']\n", "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n" ] @@ -408,15 +396,15 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-10-15T12:35:47.716817Z", - "start_time": "2023-10-15T12:35:40.960064Z" + "end_time": "2023-10-16T10:54:13.223026Z", + "start_time": "2023-10-16T10:54:07.901139Z" } }, "id": "437d12054baaffa1" }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 7, "outputs": [], "source": [ "data_loader.index_with(vocabulary)\n", @@ -427,15 +415,15 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-10-15T12:36:27.968154Z", - "start_time": "2023-10-15T12:35:47.794229Z" + "end_time": "2023-10-16T10:55:14.429100Z", + "start_time": "2023-10-16T10:54:13.264834Z" } }, "id": "e131e0ec75dc6927" }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 8, "outputs": [], "source": [ "val_data_loader.index_with(vocabulary)" @@ -443,15 +431,15 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-10-15T12:36:31.913752Z", - "start_time": "2023-10-15T12:36:27.951466Z" + "end_time": "2023-10-16T10:55:19.201389Z", + "start_time": "2023-10-16T10:55:14.418315Z" } }, "id": "195c71fcf8170ff" }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 9, "outputs": [ { "name": "stderr", @@ -478,15 +466,15 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-10-15T12:36:32.265366Z", - "start_time": "2023-10-15T12:36:31.923212Z" + "end_time": "2023-10-16T10:55:19.731575Z", + "start_time": "2023-10-16T10:55:19.211212Z" } }, "id": "cefc5173154d1605" }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 10, "outputs": [ { "name": "stderr", @@ -497,10 +485,10 @@ "-------------------------------------\n", "0 | model | ComboModel | 136 M \n", "-------------------------------------\n", - "12.2 M Trainable params\n", + "12.1 M Trainable params\n", "124 M Non-trainable params\n", "136 M Total params\n", - "546.647 Total estimated model params size (MB)\n" + "546.115 Total estimated model params size (MB)\n" ] }, { @@ -509,7 +497,7 @@ "application/vnd.jupyter.widget-view+json": { "version_major": 2, "version_minor": 0, - "model_id": "c0fdd5ff1efa43f7b2317abcd0c00fb9" + "model_id": "3920f342a18943d2949c7a67bab4306c" } }, "metadata": {}, @@ -531,7 +519,7 @@ "application/vnd.jupyter.widget-view+json": { "version_major": 2, "version_minor": 0, - "model_id": "5c0b5cee0be2439588dc665a4bb4ba21" + "model_id": "7c2366213de4496e92dd8e39314cfff6" } }, "metadata": {}, @@ -543,7 +531,7 @@ "application/vnd.jupyter.widget-view+json": { "version_major": 2, "version_minor": 0, - "model_id": "32bc016d2f99452db3e6dbc9575e4cc8" + "model_id": "1e61130ff9d848c89af356efbaa4ad10" } }, "metadata": {}, @@ -563,15 +551,15 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-10-15T12:38:24.497905Z", - "start_time": "2023-10-15T12:36:32.262713Z" + "end_time": "2023-10-16T10:57:16.128968Z", + "start_time": "2023-10-16T10:55:19.739514Z" } }, "id": "e5af131bae4b1a33" }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 11, "outputs": [], "source": [ "predictor = COMBO(model, dataset_reader)" @@ -579,15 +567,15 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-10-15T12:38:32.892083Z", - "start_time": "2023-10-15T12:38:23.228734Z" + "end_time": "2023-10-16T10:57:16.312731Z", + "start_time": "2023-10-16T10:57:15.230642Z" } }, "id": "3e23413c86063183" }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 12, "outputs": [], "source": [ "a = predictor(\"Cześć, jestem psem.\")" @@ -595,24 +583,24 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-10-15T12:38:32.986864Z", - "start_time": "2023-10-15T12:38:25.706104Z" + "end_time": "2023-10-16T10:57:16.316587Z", + "start_time": "2023-10-16T10:57:15.255556Z" } }, "id": "d555d7f0223a624b" }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 13, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "TOKEN LEMMA UPOS HEAD DEPREL \n", - "Cześć, ????? NOUN 0 root \n", - "jestem ????? NOUN 1 punct \n", - "psem. ???? NOUN 2 punct \n" + "Cześć, ?????? NOUN 0 root \n", + "jestem ?????? NOUN 1 punct \n", + "psem. ????? NOUN 1 punct \n" ] } ], @@ -624,15 +612,15 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-10-15T12:38:33.022415Z", - "start_time": "2023-10-15T12:38:28.029459Z" + "end_time": "2023-10-16T10:57:16.391654Z", + "start_time": "2023-10-16T10:57:15.559808Z" } }, "id": "a68cd3861e1ceb67" }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 14, "outputs": [], "source": [ "from modules.archival import archive" @@ -640,21 +628,21 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-10-15T12:38:33.024270Z", - "start_time": "2023-10-15T12:38:28.177734Z" + "end_time": "2023-10-16T10:57:16.392494Z", + "start_time": "2023-10-16T10:57:15.584799Z" } }, "id": "d0f43f4493218b5" }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 15, "outputs": [ { "data": { "text/plain": "'/Users/majajablonska/Documents/combo'" }, - "execution_count": 20, + "execution_count": 15, "metadata": {}, "output_type": "execute_result" } @@ -665,21 +653,22 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-10-15T13:46:22.801925Z", - "start_time": "2023-10-15T13:44:31.564518Z" + "end_time": "2023-10-16T10:58:59.073597Z", + "start_time": "2023-10-16T10:57:15.594684Z" } }, "id": "ec92aa5bb5bb3605" }, { "cell_type": "code", - "execution_count": null, + "execution_count": 15, "outputs": [], "source": [], "metadata": { "collapsed": false, "ExecuteTime": { - "start_time": "2023-10-15T12:38:32.141265Z" + "end_time": "2023-10-16T10:58:59.087847Z", + "start_time": "2023-10-16T10:58:59.050723Z" } }, "id": "953bd53cccd5f890"