diff --git a/combo/data/dataset.py b/combo/data/dataset.py
index dc7b9020851d10cf87ebe16cfc015069e651de46..3993a55c84b9d9486826f756144ef846bedde481 100644
--- a/combo/data/dataset.py
+++ b/combo/data/dataset.py
@@ -98,7 +98,7 @@ class UniversalDependenciesDatasetReader(DatasetReader):
 
         for conllu_file in file_path:
             file = pathlib.Path(conllu_file)
-            assert conllu_file and file.exists(), f"File with path '{conllu_file}' does not exists!"
+            assert conllu_file and file.exists(), f"File with path '{conllu_file}' does not exist!"
             with file.open("r", encoding="utf-8") as f:
                 for annotation in conllu.parse_incr(f, fields=self.fields, field_parsers=self.field_parsers):
                     yield self.text_to_instance(annotation)
diff --git a/combo/data/dataset_readers/universal_dependencies_dataset_reader.py b/combo/data/dataset_readers/universal_dependencies_dataset_reader.py
index 14866e7e8a0654f08b8b4328cc0681034fbad8ee..8117c45bb509c991699a9b11e5f188e29e2e08fc 100644
--- a/combo/data/dataset_readers/universal_dependencies_dataset_reader.py
+++ b/combo/data/dataset_readers/universal_dependencies_dataset_reader.py
@@ -125,7 +125,7 @@ class UniversalDependenciesDatasetReader(DatasetReader, ABC):
 
         for conllu_file in file_path:
             file = pathlib.Path(conllu_file)
-            assert conllu_file and file.exists(), f"File with path '{conllu_file}' does not exists!"
+            assert conllu_file and file.exists(), f"File with path '{conllu_file}' does not exist!"
             with file.open("r", encoding="utf-8") as f:
                 for annotation in conllu.parse_incr(f, fields=self.fields, field_parsers=self.field_parsers):
                     yield self.text_to_instance(annotation)
diff --git a/combo/data/vocabulary.py b/combo/data/vocabulary.py
index 0b3e0c6acb3549a53390aee1dc1b672f9c0b7ed7..3cf5fed37405635090686edbed331ae7e1ee554e 100644
--- a/combo/data/vocabulary.py
+++ b/combo/data/vocabulary.py
@@ -300,6 +300,7 @@
                 filename = os.path.join(directory, namespace_filename)
                 vocab.set_from_file(filename, is_padded, namespace=namespace, oov_token=oov_token)
 
+        get_slices_if_not_provided(vocab)
        vocab.constructed_from = 'from_files'
         return vocab
 
diff --git a/combo/main.py b/combo/main.py
index 6611a2d50df39f95c78d3c3ada7981382a7b8602..a426b10e88f1f801a64f7e1d83b192a834c832b5 100755
--- a/combo/main.py
+++ b/combo/main.py
@@ -1,6 +1,7 @@
 import json
 import logging
 import pathlib
+import tempfile
 from typing import Dict
 
 import torch
@@ -15,6 +16,7 @@ from combo.utils import checks
 from config import resolve
 from default_model import default_ud_dataset_reader, default_data_loader
 from models import ComboModel
+from modules.archival import load_archive, archive
 from predict import COMBO
 
 logger = logging.getLogger(__name__)
@@ -83,30 +85,13 @@ flags.DEFINE_integer(name="batch_size", default=1,
                      help="Prediction batch size.")
 flags.DEFINE_boolean(name="silent", default=True,
                      help="Silent prediction to file (without printing to console).")
+flags.DEFINE_boolean(name="finetuning", default=False,
+                     help="Finetuning mode for training.")
 flags.DEFINE_enum(name="predictor_name", default="combo-lambo",
                   enum_values=["combo", "combo-spacy", "combo-lambo"],
                   help="Use predictor with whitespace, spacy or lambo (recommended) tokenizer.")
 
 
-def get_model_for_training(config_path) -> TrainableCombo:
-    with open(config_path, 'rb') as f:
-        config_json = json.load(f)
-    train_data_loader = resolve(config_json['train_data_loader'])
-    val_data_loader = resolve(config_json['val_data_loader'])
-    model = ComboModel.from_parameters(config_json['model']['parameters'])
-
-    train_data_loader.index_with(model.vocab)
-    val_data_loader.index_with(model.vocab)
-    nlp = TrainableCombo(model, torch.optim.Adam,
-                         optimizer_kwargs={'betas': [0.9, 0.9], 'lr': 0.002},
-                         validation_metrics=['EM'])
-    trainer = pl.Trainer(max_epochs=FLAGS.num_epochs,
-                         default_root_dir=FLAGS.serialization_dir,
-                         gradient_clip_val=5)
-    trainer.fit(model=nlp, train_dataloaders=train_data_loader, val_dataloaders=val_data_loader)
-    return model
-
-
 def get_saved_model(parameters) -> ComboModel:
     return ComboModel.load(os.path.join(FLAGS.model_path),
                            config=parameters,
@@ -130,8 +115,59 @@
 
 def run(_):
     if FLAGS.mode == 'train':
-        trained_nlp = get_model_for_training(FLAGS.config_path)
-        trained_nlp.save(FLAGS.serialization_dir)
+        if not FLAGS.finetuning:
+            checks.file_exists(FLAGS.config_path)
+            with open(FLAGS.config_path, 'r') as f:
+                params = json.load(f)
+            params = {**params, **_get_ext_vars()}
+
+            serialization_dir = tempfile.mkdtemp(prefix='combo', dir=FLAGS.serialization_dir)
+
+            model = resolve(params['model'])
+
+            if 'data_loader' in params:
+                train_data_loader = resolve(params['data_loader'])
+            else:
+                checks.file_exists(FLAGS.training_data_path)
+                train_data_loader = default_data_loader(default_ud_dataset_reader(),
+                                                        FLAGS.training_data_path)
+            train_data_loader.index_with(model.vocab)
+
+            validation_data_loader = None
+            if 'validation_data_loader' in params:
+                validation_data_loader = resolve(params['validation_data_loader'])
+                validation_data_loader.index_with(model.vocab)
+            elif FLAGS.validation_data_path:
+                checks.file_exists(FLAGS.validation_data_path)
+                validation_data_loader = default_data_loader(default_ud_dataset_reader(),
+                                                             FLAGS.validation_data_path)
+                validation_data_loader.index_with(model.vocab)
+
+        else:
+            model, train_data_loader, validation_data_loader = load_archive(FLAGS.model_path)
+
+            serialization_dir = tempfile.mkdtemp(prefix='combo', suffix='-finetuning', dir=FLAGS.serialization_dir)
+
+            if not train_data_loader:
+                checks.file_exists(FLAGS.finetuning_training_data_path)
+                train_data_loader = default_data_loader(default_ud_dataset_reader(),
+                                                        FLAGS.finetuning_training_data_path)
+            if not validation_data_loader and FLAGS.finetuning_validation_data_path:
+                checks.file_exists(FLAGS.finetuning_validation_data_path)
+                validation_data_loader = default_data_loader(default_ud_dataset_reader(),
+                                                             FLAGS.finetuning_validation_data_path)
+
+        nlp = TrainableCombo(model, torch.optim.Adam,
+                             optimizer_kwargs={'betas': [0.9, 0.9], 'lr': 0.002},
+                             validation_metrics=['EM'])
+        trainer = pl.Trainer(max_epochs=FLAGS.num_epochs,
+                             default_root_dir=serialization_dir,
+                             gradient_clip_val=5)
+        trainer.fit(model=nlp, train_dataloaders=train_data_loader, val_dataloaders=validation_data_loader)
+
+        archive(model, serialization_dir)
+        logger.info(f"Trained model stored in: {serialization_dir}")
+
     elif FLAGS.mode == 'predict':
         predictor = get_predictor()
         sentence = input("Sentence:")
@@ -140,29 +176,6 @@
         for token in prediction.tokens:
             print("{:15} {:15} {:10} {:10} {:10}".format(token.text, token.lemma, token.upostag,
                                                          token.head, token.deprel))
-    elif FLAGS.mode == 'finetune':
-        checks.file_exists(FLAGS.model_path)
-        with open(os.path.join(FLAGS.model_path, 'params.json'), 'r') as f:
-            serialized = json.load(f)
-        if 'dataset_reader' in serialized:
-            dataset_reader = resolve(serialized['dataset_reader'])
-        else:
-            dataset_reader = default_ud_dataset_reader()
-        model = get_saved_model(serialized)
-        nlp = TrainableCombo(model, torch.optim.Adam,
-                             optimizer_kwargs={'betas': [0.9, 0.9], 'lr': 0.002},
-                             validation_metrics=['EM'])
-        trainer = pl.Trainer(max_epochs=FLAGS.num_epochs,
-                             default_root_dir=FLAGS.serialization_dir,
-                             gradient_clip_val=5)
-        train_data_loader = default_data_loader(dataset_reader,
-                                                FLAGS.finetuning_training_data_path)
-        val_data_loader = default_data_loader(dataset_reader,
-                                              FLAGS.finetuning_validation_data_path)
-        train_data_loader.index_with(model.vocab)
-        val_data_loader.index_with(model.vocab)
-        trainer.fit(model=nlp, train_dataloaders=train_data_loader, val_dataloaders=val_data_loader)
-        model.save(FLAGS.serialization_dir)
 
 
 def _get_ext_vars(finetuning: bool = False) -> Dict:
@@ -178,11 +191,11 @@ def _get_ext_vars(finetuning: bool = False) -> Dict:
         "features": " ".join(FLAGS.features),
         "targets": " ".join(FLAGS.targets),
         "type": "finetuning" if finetuning else "default",
-        "embedding_dim": str(FLAGS.embedding_dim),
-        "cuda_device": str(FLAGS.cuda_device),
-        "num_epochs": str(FLAGS.num_epochs),
-        "word_batch_size": str(FLAGS.word_batch_size),
-        "use_tensorboard": str(FLAGS.tensorboard),
+        "embedding_dim": int(FLAGS.embedding_dim),
+        "cuda_device": int(FLAGS.cuda_device),
+        "num_epochs": int(FLAGS.num_epochs),
+        "word_batch_size": int(FLAGS.word_batch_size),
+        "use_tensorboard": int(FLAGS.tensorboard),
     }
 
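Note on usage: with the separate 'finetune' mode removed, both flows now go through
--mode train, and the new --finetuning flag selects the archive-based branch. A
minimal sketch of the two invocations, assuming main.py is run directly as a
script (all paths below are placeholders):

    # fresh training, driven by a JSON config
    python combo/main.py --mode train --config_path config.json --serialization_dir /tmp/combo

    # finetuning a previously archived model on new CoNLL-U data
    python combo/main.py --mode train --finetuning --model_path /tmp/combo/model.tar.gz \
        --finetuning_training_data_path train.conllu --serialization_dir /tmp/combo
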
diff --git a/combo/models/__init__.py b/combo/models/__init__.py
index 94664c634c1899b8682e1b32242eee10c26fdf4f..0d3fb493f34e3636fff102ec54790f0a856e4ba5 100644
--- a/combo/models/__init__.py
+++ b/combo/models/__init__.py
@@ -1,2 +1,3 @@
 from .encoder import ComboStackedBidirectionalLSTM, ComboEncoder
 from .combo_model import ComboModel
+from .model import Model
diff --git a/combo/modules/archival.py b/combo/modules/archival.py
index c32d6e8a9e59abcf0359c2c20d0b005526ee213c..85586b3af4fbd4fd82dfa2ee70b1695657e843c6 100644
--- a/combo/modules/archival.py
+++ b/combo/modules/archival.py
@@ -11,7 +11,7 @@
 from io import BytesIO
 from tempfile import TemporaryDirectory
 
 from config import resolve
-from data import DatasetReader
+from data.dataset_loaders import DataLoader
 from modules.model import Model
 
@@ -21,8 +21,8 @@ CACHE_DIRECTORY = str(CACHE_ROOT / "cache")
 
 class Archive(NamedTuple):
     model: Model
-    dataset_reader: Optional[DatasetReader]
-    validation_dataset_reader: Optional[DatasetReader]
+    data_loader: Optional[DataLoader]
+    validation_data_loader: Optional[DataLoader]
 
 
 def add_to_tar(tar_file: tarfile.TarFile, out_stream: BytesIO, data: bytes, name: str):
@@ -34,7 +34,9 @@ def add_to_tar(tar_file: tarfile.TarFile, out_stream: BytesIO, data: bytes, name: str):
 
 
 def archive(model: Model,
-            serialization_dir: Union[PathLike, str]) -> str:
+            serialization_dir: Union[PathLike, str],
+            data_loader: Optional[DataLoader] = None,
+            validation_data_loader: Optional[DataLoader] = None) -> str:
     parameters = {'vocabulary': {
         'type': 'from_files_vocabulary',
         'parameters': {
@@ -44,6 +46,11 @@ def archive(model: Model,
         }
     }, 'model': model.serialize()}
 
+    if data_loader:
+        parameters['data_loader'] = data_loader.serialize()
+    if validation_data_loader:
+        parameters['validation_data_loader'] = validation_data_loader.serialize()
+
     with (TemporaryDirectory(os.path.join('tmp')) as t,
           BytesIO() as out_stream,
           tarfile.open(os.path.join(serialization_dir, 'model.tar.gz'), 'w|gz') as tar_file):
@@ -71,13 +78,13 @@ def load_archive(url_or_filename: Union[PathLike, str],
     with open(os.path.join(archive_file, 'config.json'), 'r') as f:
         config = json.load(f)
 
-    dataset_reader, validation_dataset_reader = None, None
+    data_loader, validation_data_loader = None, None
 
-    if 'dataset_reader' in config:
-        dataset_reader = resolve(config['dataset_reader'])
-    if 'validation_dataset_reader' in config:
-        validation_dataset_reader = resolve(config['validation_dataset_reader'])
+    if 'data_loader' in config:
+        data_loader = resolve(config['data_loader'])
+    if 'validation_data_loader' in config:
+        validation_data_loader = resolve(config['validation_data_loader'])
 
     return Archive(model=model,
-                   dataset_reader=dataset_reader,
-                   validation_dataset_reader=validation_dataset_reader)
+                   data_loader=data_loader,
+                   validation_data_loader=validation_data_loader)
diff --git a/combo/nn/regularizers/__init__.py b/combo/nn/regularizers/__init__.py
index 27f101d7bd31e0502eee7722c973b0d2629f62fe..bc98f43f1cf7b169c544223bf7a358a6d7dbb9fb 100644
--- a/combo/nn/regularizers/__init__.py
+++ b/combo/nn/regularizers/__init__.py
@@ -1,2 +1,3 @@
 from .regularizers import *
-from .regularizer import *
\ No newline at end of file
+from .regularizer import *
+from .regularizer_applicator import RegularizerApplicator
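The Archive tuple now carries data loaders instead of dataset readers, so a model
can be restored together with its data configuration. A minimal round-trip sketch
of the new API; the target directory is a placeholder and both loaders are optional:

    from modules.archival import archive, load_archive

    # writes <dir>/model.tar.gz with the vocabulary, the serialized model and,
    # when provided, the serialized (validation) data loader configs
    archive(model, '/tmp/combo-model', data_loader, validation_data_loader)

    # Archive is a NamedTuple, so it unpacks in field order;
    # main.py passes FLAGS.model_path here when --finetuning is set
    model, data_loader, validation_data_loader = load_archive('/tmp/combo-model/model.tar.gz')
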
diff --git a/combo/polish_model_training.ipynb b/combo/polish_model_training.ipynb
index 7efc01e13779e3690e2efd7a473c6785cd162af8..6672180436d81f0d3d0d3e491b9b8ee02df3f80e 100644
--- a/combo/polish_model_training.ipynb
+++ b/combo/polish_model_training.ipynb
@@ -2,27 +2,27 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count": 1,
+   "execution_count": 2,
    "outputs": [],
    "source": [
     "# The path where the training and validation datasets are stored\n",
-    "TRAINING_DATA_PATH: str = '/Users/majajablonska/Downloads/PDBUD-master-85167180bcbe0565a09269257456961365cf6ff3/PDB-UD/PDB-UD/PDBUD_train.conllu'\n",
-    "VALIDATION_DATA_PATH: str = '/Users/majajablonska/Downloads/PDBUD-master-85167180bcbe0565a09269257456961365cf6ff3/PDB-UD/PDB-UD/PDBUD_val.conllu'\n",
+    "TRAINING_DATA_PATH: str = '/Users/majajablonska/Documents/PDB/PDBUD_train.conllu'\n",
+    "VALIDATION_DATA_PATH: str = '/Users/majajablonska/Documents/PDB/PDBUD_val.conllu'\n",
     "# The path where the model can be saved to\n",
     "SERIALIZATION_DIR: str = \"/Users/majajablonska/Documents/Workspace/combotest\""
    ],
    "metadata": {
     "collapsed": false,
     "ExecuteTime": {
-     "end_time": "2023-10-11T08:36:28.266635Z",
-     "start_time": "2023-10-11T08:36:27.158291Z"
+     "end_time": "2023-10-15T12:35:09.621320Z",
+     "start_time": "2023-10-15T12:35:09.407839Z"
     }
    },
    "id": "b28c7d8bacb08d02"
   },
   {
    "cell_type": "code",
-   "execution_count": 2,
+   "execution_count": 3,
    "outputs": [],
    "source": [
     "import os\n",
@@ -52,15 +52,15 @@
    "metadata": {
     "collapsed": false,
     "ExecuteTime": {
-     "end_time": "2023-10-11T08:36:31.986641Z",
-     "start_time": "2023-10-11T08:36:27.382007Z"
+     "end_time": "2023-10-15T12:35:17.384934Z",
+     "start_time": "2023-10-15T12:35:09.418819Z"
     }
    },
    "id": "initial_id"
   },
   {
    "cell_type": "code",
-   "execution_count": 3,
+   "execution_count": 4,
    "outputs": [
     {
      "name": "stdout",
@@ -78,7 +78,7 @@
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
-      "model_id": "2ccd5ef8438f48418d1c2b29f2aa789b"
+      "model_id": "22fcc24a17304631b9ce8b5738210612"
      }
     },
     "metadata": {},
@@ -90,7 +90,7 @@
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
-      "model_id": "5cb100d63b854e0aad0f39db0679be5c"
+      "model_id": "a14b090b880f4d6a8314acedeadd8c1f"
      }
     },
     "metadata": {},
@@ -102,7 +102,19 @@
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
-      "model_id": "0d5396d8d87945a692d2d71fd1bf9b14"
+      "model_id": "4c47cc124c6f42db9fb7b3b9a07709de"
+     }
+    },
+    "metadata": {},
+    "output_type": "display_data"
+   },
+   {
+    "data": {
+     "text/plain": "building vocabulary: 0it [00:00, ?it/s]",
+     "application/vnd.jupyter.widget-view+json": {
+      "version_major": 2,
+      "version_minor": 0,
+      "model_id": "10cc5ff44d5f48b98c9c35de7eec0b1e"
      }
     },
     "metadata": {},
@@ -165,20 +177,30 @@
     "    only_include_pretrained_words=False,\n",
     "    oov_token='_',\n",
     "    padding_token='__PAD__'\n",
-    ")"
+    ")\n",
+    "\n",
+    "val_vocabulary = Vocabulary.from_data_loader_extended(\n",
+    "    val_data_loader,\n",
+    "    non_padded_namespaces=['head_labels'],\n",
+    "    only_include_pretrained_words=False,\n",
+    "    oov_token='_',\n",
+    "    padding_token='__PAD__'\n",
+    ")\n",
+    "\n",
+    "vocabulary.extend_from_vocab(val_vocabulary)"
    ],
    "metadata": {
     "collapsed": false,
     "ExecuteTime": {
-     "end_time": "2023-10-11T08:37:01.898078Z",
-     "start_time": "2023-10-11T08:36:31.982596Z"
+     "end_time": "2023-10-15T12:35:40.187804Z",
+     "start_time": "2023-10-15T12:35:17.370019Z"
     }
    },
    "id": "d74957f422f0b05b"
   },
   {
    "cell_type": "code",
-   "execution_count": 4,
+   "execution_count": 5,
    "outputs": [],
    "source": [
     "seq_encoder = ComboEncoder(layer_dropout_probability=0.33,\n",
@@ -193,15 +215,15 @@
    "metadata": {
     "collapsed": false,
     "ExecuteTime": {
-     "end_time": "2023-10-11T08:37:02.701990Z",
-     "start_time": "2023-10-11T08:37:01.927460Z"
+     "end_time": "2023-10-15T12:35:40.901949Z",
+     "start_time": "2023-10-15T12:35:40.192629Z"
     }
    },
    "id": "fa724d362fd6bd23"
   },
   {
    "cell_type": "code",
-   "execution_count": 5,
+   "execution_count": 6,
    "outputs": [
     {
      "name": "stdout",
@@ -212,9 +234,9 @@
     },
     {
      "data": {
-      "text/plain": "<generator object SimpleDataLoader.iter_instances at 0x7fc82dac9430>"
+      "text/plain": "<generator object SimpleDataLoader.iter_instances at 0x7fec3dd685f0>"
      },
-     "execution_count": 5,
+     "execution_count": 6,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -240,21 +262,21 @@
    "metadata": {
     "collapsed": false,
     "ExecuteTime": {
-     "end_time": "2023-10-11T08:37:02.769503Z",
-     "start_time": "2023-10-11T08:37:02.705688Z"
+     "end_time": "2023-10-15T12:35:40.965741Z",
+     "start_time": "2023-10-15T12:35:40.904199Z"
     }
    },
    "id": "f8a10f9892005fca"
   },
   {
    "cell_type": "code",
-   "execution_count": 6,
+   "execution_count": 7,
    "outputs": [
     {
      "name": "stderr",
      "output_type": "stream",
      "text": [
-      "Some weights of the model checkpoint at allegro/herbert-base-cased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.bias', 'cls.sso.sso_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.sso.sso_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias']\n",
+      "Some weights of the model checkpoint at allegro/herbert-base-cased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.sso.sso_relationship.bias', 'cls.predictions.decoder.weight', 'cls.sso.sso_relationship.weight', 'cls.predictions.decoder.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias']\n",
      "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
     ]
@@ -386,15 +408,15 @@
    "metadata": {
     "collapsed": false,
     "ExecuteTime": {
-     "end_time": "2023-10-11T08:37:06.802516Z",
-     "start_time": "2023-10-11T08:37:02.748561Z"
+     "end_time": "2023-10-15T12:35:47.716817Z",
+     "start_time": "2023-10-15T12:35:40.960064Z"
     }
    },
    "id": "437d12054baaffa1"
   },
   {
    "cell_type": "code",
-   "execution_count": 7,
+   "execution_count": 8,
    "outputs": [],
    "source": [
     "data_loader.index_with(vocabulary)\n",
@@ -405,15 +427,15 @@
    "metadata": {
     "collapsed": false,
     "ExecuteTime": {
-     "end_time": "2023-10-11T08:37:44.084551Z",
-     "start_time": "2023-10-11T08:37:06.818779Z"
+     "end_time": "2023-10-15T12:36:27.968154Z",
+     "start_time": "2023-10-15T12:35:47.794229Z"
     }
    },
    "id": "e131e0ec75dc6927"
   },
   {
    "cell_type": "code",
-   "execution_count": 8,
+   "execution_count": 9,
    "outputs": [],
    "source": [
     "val_data_loader.index_with(vocabulary)"
@@ -421,15 +443,15 @@
    "metadata": {
     "collapsed": false,
     "ExecuteTime": {
-     "end_time": "2023-10-11T08:38:29.223247Z",
-     "start_time": "2023-10-11T08:37:44.066248Z"
+     "end_time": "2023-10-15T12:36:31.913752Z",
+     "start_time": "2023-10-15T12:36:27.951466Z"
     }
    },
    "id": "195c71fcf8170ff"
   },
   {
    "cell_type": "code",
-   "execution_count": 9,
+   "execution_count": 10,
    "outputs": [
     {
      "name": "stderr",
@@ -456,15 +478,15 @@
    "metadata": {
     "collapsed": false,
     "ExecuteTime": {
-     "end_time": "2023-10-11T08:38:37.162449Z",
-     "start_time": "2023-10-11T08:38:29.186645Z"
+     "end_time": "2023-10-15T12:36:32.265366Z",
+     "start_time": "2023-10-15T12:36:31.923212Z"
     }
    },
    "id": "cefc5173154d1605"
   },
   {
    "cell_type": "code",
-   "execution_count": 10,
+   "execution_count": 11,
    "outputs": [
     {
      "name": "stderr",
@@ -475,10 +497,10 @@
      "-------------------------------------\n",
      "0 | model | ComboModel | 136 M \n",
      "-------------------------------------\n",
-      "12.1 M    Trainable params\n",
+      "12.2 M    Trainable params\n",
       "124 M     Non-trainable params\n",
       "136 M     Total params\n",
-      "546.107   Total estimated model params size (MB)\n"
+      "546.647   Total estimated model params size (MB)\n"
      ]
     },
     {
@@ -487,7 +509,7 @@
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
-      "model_id": "bf1714d7fc554b17a16d6f0d3ccb96b1"
+      "model_id": "c0fdd5ff1efa43f7b2317abcd0c00fb9"
      }
     },
     "metadata": {},
@@ -509,7 +531,7 @@
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
-      "model_id": "63ad53394ef8472e873a9cf9cb093eeb"
+      "model_id": "5c0b5cee0be2439588dc665a4bb4ba21"
      }
     },
     "metadata": {},
@@ -521,7 +543,7 @@
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
-      "model_id": "8a783c14f75245d9ac3b4339bd4e1a44"
+      "model_id": "32bc016d2f99452db3e6dbc9575e4cc8"
      }
     },
     "metadata": {},
@@ -541,15 +563,15 @@
    "metadata": {
     "collapsed": false,
     "ExecuteTime": {
-     "end_time": "2023-10-11T08:39:43.280379Z",
-     "start_time": "2023-10-11T08:38:29.495446Z"
+     "end_time": "2023-10-15T12:38:24.497905Z",
+     "start_time": "2023-10-15T12:36:32.262713Z"
     }
    },
    "id": "e5af131bae4b1a33"
   },
   {
    "cell_type": "code",
-   "execution_count": 11,
+   "execution_count": 12,
    "outputs": [],
    "source": [
     "predictor = COMBO(model, dataset_reader)"
@@ -557,15 +579,15 @@
    "metadata": {
     "collapsed": false,
     "ExecuteTime": {
"2023-10-11T08:39:43.296706Z", - "start_time": "2023-10-11T08:39:42.829494Z" + "end_time": "2023-10-15T12:38:32.892083Z", + "start_time": "2023-10-15T12:38:23.228734Z" } }, "id": "3e23413c86063183" }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 13, "outputs": [], "source": [ "a = predictor(\"Cześć, jestem psem.\")" @@ -573,24 +595,24 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-10-11T08:39:43.578465Z", - "start_time": "2023-10-11T08:39:42.842561Z" + "end_time": "2023-10-15T12:38:32.986864Z", + "start_time": "2023-10-15T12:38:25.706104Z" } }, "id": "d555d7f0223a624b" }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 14, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "TOKEN LEMMA UPOS HEAD DEPREL \n", - "Cześć, ?????a NOUN 0 root \n", - "jestem ?????a NOUN 1 case \n", - "psem. ????a PUNCT 1 punct \n" + "Cześć, ????? NOUN 0 root \n", + "jestem ????? NOUN 1 punct \n", + "psem. ???? NOUN 2 punct \n" ] } ], @@ -602,35 +624,65 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-10-11T08:39:43.615606Z", - "start_time": "2023-10-11T08:39:43.250741Z" + "end_time": "2023-10-15T12:38:33.022415Z", + "start_time": "2023-10-15T12:38:28.029459Z" } }, "id": "a68cd3861e1ceb67" }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 15, + "outputs": [], + "source": [ + "from modules.archival import archive" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-10-15T12:38:33.024270Z", + "start_time": "2023-10-15T12:38:28.177734Z" + } + }, + "id": "d0f43f4493218b5" + }, + { + "cell_type": "code", + "execution_count": 20, "outputs": [ { - "name": "stderr", - "output_type": "stream", - "text": [ - "Directory /Users/majajablonska/Documents/Workspace/combotest/saved_model/vocabulary is not empty\n" - ] + "data": { + "text/plain": "'/Users/majajablonska/Documents/combo'" + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ - "model.save(os.path.join(SERIALIZATION_DIR, 'saved_model'))" + "archive(model, '/Users/majajablonska/Documents/combo', data_loader, val_data_loader)" ], "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-10-11T08:39:44.927330Z", - "start_time": "2023-10-11T08:39:43.254017Z" + "end_time": "2023-10-15T13:46:22.801925Z", + "start_time": "2023-10-15T13:44:31.564518Z" } }, - "id": "d0f43f4493218b5" + "id": "ec92aa5bb5bb3605" + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "start_time": "2023-10-15T12:38:32.141265Z" + } + }, + "id": "953bd53cccd5f890" } ], "metadata": { diff --git a/combo/predict.py b/combo/predict.py index e3b285ecf54764acd8e8341b701acd151959e620..9a3a42f327ecbd17502c4c00de394bd78214ed1e 100644 --- a/combo/predict.py +++ b/combo/predict.py @@ -18,6 +18,7 @@ from combo.data.dataset_readers.dataset_reader import DatasetReader from combo.data.instance import JsonDict from combo.predictors import PredictorModule from combo.utils import download, graph +from modules.model import Model logger = logging.getLogger(__name__) @@ -26,7 +27,7 @@ logger = logging.getLogger(__name__) class COMBO(PredictorModule): @register_arguments def __init__(self, - model: models.Model, + model: Model, dataset_reader: DatasetReader, tokenizer: data.Tokenizer = tokenizers.WhitespaceTokenizer(), batch_size: int = 1024,