From c7327132fbaf1e34c5330771422e1ca3882e7ee2 Mon Sep 17 00:00:00 2001 From: Maja Jablonska <majajjablonska@gmail.com> Date: Wed, 15 Nov 2023 17:22:25 +1100 Subject: [PATCH] Add CLI parameters --- combo/combo_model.py | 16 ++- combo/config/__init__.py | 2 +- combo/config/from_parameters.py | 59 +++++++++- combo/main.py | 169 +++++++++++++++++++++++------ combo/polish_model_training.ipynb | 94 ++++++++-------- docs/training.md | 26 +++++ tests/config/test_configuration.py | 59 ++++++++++ 7 files changed, 334 insertions(+), 91 deletions(-) create mode 100644 docs/training.md diff --git a/combo/combo_model.py b/combo/combo_model.py index 4959b72..79c9f36 100644 --- a/combo/combo_model.py +++ b/combo/combo_model.py @@ -23,6 +23,7 @@ from combo.nn import utils from combo.nn.utils import get_text_field_mask from combo.predictors import Predictor from combo.utils import metrics +from combo.utils import ConfigurationError @Registry.register("semantic_multitask") @@ -165,7 +166,10 @@ class ComboModel(Model, FromParameters): if self.morphological_feat: mapped_gold_labels = [] for _, cat_indices in self.morphological_feat.slices.items(): - mapped_gold_labels.append(feats[:, :, cat_indices].argmax(dim=-1)) + try: + mapped_gold_labels.append(feats[:, :, cat_indices].argmax(dim=-1)) + except TypeError: + raise ConfigurationError('Feats is None - if no feats are provided, the morphological_feat property should be set to None.') feats = torch.stack(mapped_gold_labels, dim=-1) @@ -184,11 +188,11 @@ class ComboModel(Model, FromParameters): relations_loss, head_loss = parser_output["loss"] enhanced_relations_loss, enhanced_head_loss = enhanced_parser_output["loss"] losses = { - "upostag_loss": upos_output["loss"], - "xpostag_loss": xpos_output["loss"], - "semrel_loss": semrel_output["loss"], - "feats_loss": morpho_output["loss"], - "lemma_loss": lemma_output["loss"], + "upostag_loss": upos_output.get("loss"), + "xpostag_loss": xpos_output.get("loss"), + "semrel_loss": semrel_output.get("loss"), + "feats_loss": morpho_output.get("loss"), + "lemma_loss": lemma_output.get("loss"), "head_loss": head_loss, "deprel_loss": relations_loss, "enhanced_head_loss": enhanced_head_loss, diff --git a/combo/config/__init__.py b/combo/config/__init__.py index 948f357..e956d66 100644 --- a/combo/config/__init__.py +++ b/combo/config/__init__.py @@ -1,2 +1,2 @@ -from .from_parameters import FromParameters, resolve +from .from_parameters import FromParameters, override_parameters, resolve from .registry import Registry diff --git a/combo/config/from_parameters.py b/combo/config/from_parameters.py index e853042..da9dafb 100644 --- a/combo/config/from_parameters.py +++ b/combo/config/from_parameters.py @@ -19,7 +19,6 @@ def get_matching_arguments(args: Dict[str, Any], func: Callable) -> Dict[str, An def _resolve(values: typing.Union[Dict[str, Any], str], pass_down_parameters: Dict[str, Any] = None) -> Any: - if isinstance(values, Params): values = Params.as_dict() @@ -148,7 +147,7 @@ class FromParameters: if pn in pass_down_parameter_names: continue parameters_dict[pn] = serialize_single_value(param_value, - pass_down_parameter_names+self.pass_down_parameter_names()) + pass_down_parameter_names + self.pass_down_parameter_names()) return parameters_dict def serialize(self, pass_down_parameter_names: List[str] = None) -> Dict[str, Any]: @@ -166,3 +165,59 @@ def resolve(parameters: Dict[str, Any], pass_down_parameters: Dict[str, Any] = N pass_down_parameters = pass_down_parameters or {} clz, clz_init = Registry.resolve(parameters['type'])
return clz.from_parameters(parameters['parameters'], clz_init, pass_down_parameters) + + +def flatten_dictionary(d, parent_key='', sep='/'): + """ + Flatten a nested dictionary. + + Parameters: + d (dict): The input dictionary. + parent_key (str): The parent key to use for recursion (default is an empty string). + sep (str): The separator to use when concatenating keys (default is '/'). + + Returns: + dict: A flattened dictionary. + """ + items = [] + for k, v in d.items(): + new_key = f"{parent_key}{sep}{k}" if parent_key else k + if isinstance(v, dict): + items.extend(flatten_dictionary(v, new_key, sep=sep).items()) + else: + items.append((new_key, v)) + return dict(items) + + +def unflatten_dictionary(flat_dict, sep='/'): + """ + Unflatten a flattened dictionary. + + Parameters: + flat_dict (dict): The flattened dictionary. + sep (str): The separator used in the flattened keys (default is '/'). + + Returns: + dict: The unflattened dictionary. + """ + unflattened_dict = {} + for key, value in flat_dict.items(): + keys = key.split(sep) + current_level = unflattened_dict + + for k in keys[:-1]: + current_level = current_level.setdefault(k, {}) + + current_level[keys[-1]] = value + + return unflattened_dict + + +def override_parameters(parameters: Dict[str, Any], override_values: Dict[str, Any]) -> Dict[str, Any]: + overridden_parameters = flatten_dictionary(parameters) + override_values = flatten_dictionary(override_values) + for ko, vo in override_values.items(): + if ko in overridden_parameters: + overridden_parameters[ko] = vo + + return unflatten_dictionary(overridden_parameters) diff --git a/combo/main.py b/combo/main.py index ecc455c..44a6a30 100755 --- a/combo/main.py +++ b/combo/main.py @@ -18,20 +18,28 @@ from combo.default_model import default_ud_dataset_reader, default_data_loader from combo.modules.archival import load_archive, archive from combo.predict import COMBO from combo.data import api -from combo.data import DatasetReader +from combo.config import override_parameters +from combo.utils import ConfigurationError logging.setLoggerClass(ComboLogger) logger = logging.getLogger(__name__) _FEATURES = ["token", "char", "upostag", "xpostag", "lemma", "feats"] _TARGETS = ["deprel", "feats", "head", "lemma", "upostag", "xpostag", "semrel", "sent", "deps"] + +def handle_error(error: Exception): + msg = getattr(error, 'message', str(error)) + logger.error(msg) + print(f'Error: {msg}') + + FLAGS = flags.FLAGS flags.DEFINE_enum(name="mode", default=None, enum_values=["train", "predict"], help="Specify COMBO mode: train or predict") # Common flags -flags.DEFINE_integer(name="cuda_device", default=-1, - help="Cuda device idx (default -1 cpu)") +flags.DEFINE_integer(name="n_cuda_devices", default=-1, + help="Number of devices to train on (default -1: auto mode, train on as many devices as are available)") flags.DEFINE_string(name="output_file", default="output.log", help="Predictions result file.") @@ -42,8 +50,8 @@ flags.DEFINE_string(name="validation_data_path", default="", help="Validation da flags.DEFINE_alias(name="validation_data", original_name="validation_data_path") flags.DEFINE_string(name="pretrained_tokens", default="", help="Pretrained tokens embeddings path") -flags.DEFINE_integer(name="embedding_dim", default=300, - help="Embeddings dim") +flags.DEFINE_integer(name="lemmatizer_embedding_dim", default=300, + help="Lemmatizer embedding dimension") flags.DEFINE_integer(name="num_epochs", default=400, help="Epochs num") flags.DEFINE_integer(name="word_batch_size", default=2500, @@ -72,10 +80,8 @@
flags.DEFINE_string(name="finetuning_validation_data_path", default="", flags.DEFINE_string(name="test_data_path", default=None, help="Test path file.") flags.DEFINE_alias(name="test_data", original_name="test_data_path") - -# Experimental flags.DEFINE_boolean(name="use_pure_config", default=False, - help="Ignore ext flags (experimental).") + help="Ignore ext flags.") # Prediction flags flags.DEFINE_string(name="model_path", default=None, @@ -99,37 +105,58 @@ def run(_): if not FLAGS.finetuning: prefix = 'Training' logger.info('Setting up the model for training', prefix=prefix) - checks.file_exists(FLAGS.config_path) + try: + checks.file_exists(FLAGS.config_path) + except ConfigurationError as e: + handle_error(e) + return logger.info(f'Reading parameters from configuration path {FLAGS.config_path}', prefix=prefix) with open(FLAGS.config_path, 'r') as f: params = json.load(f) - params = {**params, **_get_ext_vars()} + params = override_parameters(params, _get_ext_vars(True)) + + if 'feats' not in FLAGS.features: + del params['model']['parameters']['morphological_feat'] serialization_dir = tempfile.mkdtemp(prefix='combo', dir=FLAGS.serialization_dir) params['vocabulary']['parameters']['directory'] = os.path.join('/'.join(FLAGS.config_path.split('/')[:-1]), - params['vocabulary']['parameters']['directory']) + params['vocabulary']['parameters'][ + 'directory']) try: vocabulary = resolve(params['vocabulary']) - except KeyError: - logger.error('No vocabulary in config.json!') + except Exception as e: + handle_error(e) return + try: + model = resolve(override_parameters(params['model'], _get_ext_vars(False)), + pass_down_parameters={'vocabulary': vocabulary}) + except Exception as e: + handle_error(e) + return - model = resolve(params['model'], pass_down_parameters={'vocabulary': vocabulary}) dataset_reader = None if 'data_loader' in params: logger.info(f'Resolving the training data loader from parameters', prefix=prefix) - train_data_loader = resolve(params['data_loader']) + try: + train_data_loader = resolve(params['data_loader']) + except Exception as e: + handle_error(e) + return else: checks.file_exists(FLAGS.training_data_path) logger.info(f'Using a default UD data loader with training data path {FLAGS.training_data_path}', prefix=prefix) - train_data_loader = default_data_loader(default_ud_dataset_reader(), - FLAGS.training_data_path) + try: + train_data_loader = default_data_loader(default_ud_dataset_reader(), + FLAGS.training_data_path) + except Exception as e: + handle_error(e) + return logger.info('Indexing training data loader') train_data_loader.index_with(model.vocab) @@ -180,10 +207,18 @@ def run(_): nlp = TrainableCombo(model, torch.optim.Adam, optimizer_kwargs={'betas': [0.9, 0.9], 'lr': 0.002}, validation_metrics=['EM']) + + n_cuda_devices = "auto" if FLAGS.n_cuda_devices == -1 else FLAGS.n_cuda_devices + trainer = pl.Trainer(max_epochs=FLAGS.num_epochs, default_root_dir=serialization_dir, - gradient_clip_val=5) - trainer.fit(model=nlp, train_dataloaders=train_data_loader, val_dataloaders=validation_data_loader) + gradient_clip_val=5, + devices=n_cuda_devices) + try: + trainer.fit(model=nlp, train_dataloaders=train_data_loader, val_dataloaders=validation_data_loader) + except Exception as e: + handle_error(e) + return logger.info(f'Archiving the model in {serialization_dir}', prefix=prefix) archive(model, serialization_dir, train_data_loader, validation_data_loader, dataset_reader) @@ -192,8 +227,9 @@ def run(_): if FLAGS.test_data_path and FLAGS.output_file: 
checks.file_exists(FLAGS.test_data_path) if not dataset_reader: - logger.info("No dataset reader in the configuration or archive file - using a default UD dataset reader", - prefix=prefix) + logger.info( + "No dataset reader in the configuration or archive file - using a default UD dataset reader", + prefix=prefix) dataset_reader = default_ud_dataset_reader() logger.info("Predicting test examples", prefix=prefix) test_trees = dataset_reader.read(FLAGS.test_data_path) @@ -212,7 +248,7 @@ def run(_): logger.info("No dataset reader in the configuration or archive file - using a default UD dataset reader", prefix=prefix) dataset_reader = default_ud_dataset_reader() - + predictor = COMBO(model, dataset_reader) if FLAGS.input_file == '-': @@ -242,23 +278,86 @@ def run(_): def _get_ext_vars(finetuning: bool = False) -> Dict: if FLAGS.use_pure_config: return {} - return { - "training_data_path": ( - ",".join(FLAGS.training_data_path if not finetuning else FLAGS.finetuning_training_data_path)), - "validation_data_path": ( - ",".join(FLAGS.validation_data_path if not finetuning else FLAGS.finetuning_validation_data_path)), + + to_override = { + "model": { + "parameters": { + "lemmatizer": { + "parameters": { + "embedding_dim": FLAGS.lemmatizer_embedding_dim + } + }, + "text_field_embedder": { + "parameters": { + "token_embedders": { + "parameters": { + "token": { + "parameters": { + "model_name": FLAGS.pretrained_transformer_name + } + } + } + } + } + }, + "serialization_dir": FLAGS.serialization_dir + } + }, + "data_loader": { + "data_path": (",".join(FLAGS.training_data_path if not finetuning else FLAGS.finetuning_training_data_path)), + "parameters": { + "reader": { + "parameters": { + "features": FLAGS.features, + "targets": FLAGS.targets, + "token_indexers": { + "token": { + "parameters": { + "model_name": FLAGS.pretrained_transformer_name + } + } + } + } + } + } + }, + "validation_data_loader": { + "data_path": (",".join(FLAGS.validation_data_path if not finetuning else FLAGS.finetuning_validation_data_path)), + "parameters": { + "reader": { + "parameters": { + "features": FLAGS.features, + "targets": FLAGS.targets, + "token_indexers": { + "token": { + "parameters": { + "model_name": FLAGS.pretrained_transformer_name + } + } + } + } + } + } + }, + "dataset_reader": { + "parameters": { + "features": FLAGS.features, + "targets": FLAGS.targets, + "token_indexers": { + "token": { + "parameters": { + "model_name": FLAGS.pretrained_transformer_name + } + } + } + } + }, "pretrained_tokens": FLAGS.pretrained_tokens, - "pretrained_transformer_name": FLAGS.pretrained_transformer_name, - "features": " ".join(FLAGS.features), - "targets": " ".join(FLAGS.targets), - "type": "finetuning" if finetuning else "default", - "embedding_dim": int(FLAGS.embedding_dim), - "cuda_device": int(FLAGS.cuda_device), - "num_epochs": int(FLAGS.num_epochs), "word_batch_size": int(FLAGS.word_batch_size), - "use_tensorboard": int(FLAGS.tensorboard), } + return to_override + def main(): """Parse flags.""" diff --git a/combo/polish_model_training.ipynb b/combo/polish_model_training.ipynb index bd2796f..005fe4b 100644 --- a/combo/polish_model_training.ipynb +++ b/combo/polish_model_training.ipynb @@ -6,16 +6,16 @@ "outputs": [], "source": [ "# The path where the training and validation datasets are stored\n", - "TRAINING_DATA_PATH: str = '/Users/majajablonska/Documents/PDB/PDBUD_train.conllu'\n", - "VALIDATION_DATA_PATH: str = '/Users/majajablonska/Documents/PDB/PDBUD_val.conllu'\n", + "TRAINING_DATA_PATH: str = 
'/Users/majajablonska/Documents/PDBUD/train.conllu'\n", + "VALIDATION_DATA_PATH: str = '/Users/majajablonska/Documents/PDBUD/val.conllu'\n", "# The path where the model can be saved to\n", "SERIALIZATION_DIR: str = \"/Users/majajablonska/Documents/Workspace/combotest\"" ], "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-11-13T08:18:36.376046Z", - "start_time": "2023-11-13T08:18:36.189836Z" + "end_time": "2023-11-13T12:15:21.197003Z", + "start_time": "2023-11-13T12:15:19.886422Z" } }, "id": "b28c7d8bacb08d02" @@ -51,8 +51,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-11-13T08:18:41.300316Z", - "start_time": "2023-11-13T08:18:36.197537Z" + "end_time": "2023-11-13T12:15:28.665585Z", + "start_time": "2023-11-13T12:15:19.907198Z" } }, "id": "initial_id" @@ -77,7 +77,7 @@ "application/vnd.jupyter.widget-view+json": { "version_major": 2, "version_minor": 0, - "model_id": "e45b57a7047043a48ccfeacfb49312b5" + "model_id": "2179b1be2f484a33948a76d087002182" } }, "metadata": {}, @@ -89,7 +89,7 @@ "application/vnd.jupyter.widget-view+json": { "version_major": 2, "version_minor": 0, - "model_id": "a886ae4451474459b088659ebac076ae" + "model_id": "86762d681ee0467e8501de2b34061aad" } }, "metadata": {}, @@ -101,7 +101,7 @@ "application/vnd.jupyter.widget-view+json": { "version_major": 2, "version_minor": 0, - "model_id": "20ff98564c7c43a9971c25f82ceda997" + "model_id": "b9e631cb77594ea5aae60e6d15809885" } }, "metadata": {}, @@ -169,8 +169,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-11-13T08:19:01.785477Z", - "start_time": "2023-11-13T08:18:41.119674Z" + "end_time": "2023-11-13T12:15:51.717065Z", + "start_time": "2023-11-13T12:15:28.442131Z" } }, "id": "d74957f422f0b05b" @@ -192,8 +192,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-11-13T08:19:03.816723Z", - "start_time": "2023-11-13T08:19:01.774666Z" + "end_time": "2023-11-13T12:15:52.574303Z", + "start_time": "2023-11-13T12:15:51.724469Z" } }, "id": "fa724d362fd6bd23" @@ -211,7 +211,7 @@ }, { "data": { - "text/plain": "<generator object SimpleDataLoader.iter_instances at 0x7fb2d3cdfc80>" + "text/plain": "<generator object SimpleDataLoader.iter_instances at 0x7fb512dc4f20>" }, "execution_count": 5, "metadata": {}, @@ -239,8 +239,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-11-13T08:19:03.877868Z", - "start_time": "2023-11-13T08:19:03.826289Z" + "end_time": "2023-11-13T12:15:52.641199Z", + "start_time": "2023-11-13T12:15:52.583194Z" } }, "id": "f8a10f9892005fca" @@ -263,8 +263,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-11-13T08:19:03.887795Z", - "start_time": "2023-11-13T08:19:03.870640Z" + "end_time": "2023-11-13T12:15:52.659289Z", + "start_time": "2023-11-13T12:15:52.625700Z" } }, "id": "14413692656b68ac" @@ -277,7 +277,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "Some weights of the model checkpoint at allegro/herbert-base-cased were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.sso.sso_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.sso.sso_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.decoder.bias', 'cls.predictions.decoder.weight']\n", + "Some weights of the model checkpoint at allegro/herbert-base-cased were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 
'cls.predictions.bias', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.sso.sso_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.sso.sso_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight']\n", "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n" ] @@ -411,8 +411,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-11-13T08:19:10.006549Z", - "start_time": "2023-11-13T08:19:03.885912Z" + "end_time": "2023-11-13T12:15:56.509687Z", + "start_time": "2023-11-13T12:15:52.658879Z" } }, "id": "437d12054baaffa1" @@ -430,8 +430,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-11-13T08:19:47.953809Z", - "start_time": "2023-11-13T08:19:09.989582Z" + "end_time": "2023-11-13T12:16:30.663344Z", + "start_time": "2023-11-13T12:15:56.529656Z" } }, "id": "e131e0ec75dc6927" @@ -446,8 +446,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-11-13T08:19:53.878201Z", - "start_time": "2023-11-13T08:19:47.940147Z" + "end_time": "2023-11-13T12:16:45.453326Z", + "start_time": "2023-11-13T12:16:30.488388Z" } }, "id": "195c71fcf8170ff" @@ -481,8 +481,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-11-13T08:19:54.150333Z", - "start_time": "2023-11-13T08:19:53.874397Z" + "end_time": "2023-11-13T12:16:45.785538Z", + "start_time": "2023-11-13T12:16:45.365250Z" } }, "id": "cefc5173154d1605" @@ -503,7 +503,7 @@ "12.1 M Trainable params\n", "124 M Non-trainable params\n", "136 M Total params\n", - "546.115 Total estimated model params size (MB)\n" + "546.106 Total estimated model params size (MB)\n" ] }, { @@ -512,7 +512,7 @@ "application/vnd.jupyter.widget-view+json": { "version_major": 2, "version_minor": 0, - "model_id": "9f6e199b0fd546f5833fdda238964165" + "model_id": "f2dd3228246843428b8fcb8ae932c1f1" } }, "metadata": {}, @@ -534,7 +534,7 @@ "application/vnd.jupyter.widget-view+json": { "version_major": 2, "version_minor": 0, - "model_id": "594c065e5d2441f48ba2b87c7a3f528f" + "model_id": "0bcdd388df664784ba19667c6a0593a1" } }, "metadata": {}, @@ -546,7 +546,7 @@ "application/vnd.jupyter.widget-view+json": { "version_major": 2, "version_minor": 0, - "model_id": "bd8bf330a90e4eddb52a3c87af6d2869" + "model_id": "a8203342bf454c22b292548d64f085a9" } }, "metadata": {}, @@ -566,8 +566,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-11-13T08:20:39.048426Z", - "start_time": "2023-11-13T08:19:54.147748Z" + "end_time": "2023-11-13T12:17:47.659618Z", + "start_time": "2023-11-13T12:16:45.706948Z" } }, "id": "e5af131bae4b1a33" @@ -582,8 +582,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-11-13T08:20:39.152284Z", - "start_time": "2023-11-13T08:20:39.042845Z" + "end_time": "2023-11-13T12:17:47.975345Z", + "start_time": "2023-11-13T12:17:47.644327Z" } }, "id": "3e23413c86063183" @@ -598,8 +598,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-11-13T08:20:39.228735Z", - "start_time": "2023-11-13T08:20:39.052747Z" + "end_time": "2023-11-13T12:17:47.989681Z", + 
"start_time": "2023-11-13T12:17:47.665490Z" } }, "id": "d555d7f0223a624b" @@ -613,8 +613,8 @@ "output_type": "stream", "text": [ "TOKEN LEMMA UPOS HEAD DEPREL \n", - "Cześć, ????? NOUN 0 root \n", - "jestem ????? NOUN 1 punct \n", + "Cześć, ?????? NOUN 0 root \n", + "jestem ?????a NOUN 1 punct \n", "psem. ????? NOUN 1 punct \n" ] } @@ -627,8 +627,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-11-13T08:20:39.237630Z", - "start_time": "2023-11-13T08:20:39.227051Z" + "end_time": "2023-11-13T12:17:48.005229Z", + "start_time": "2023-11-13T12:17:47.923055Z" } }, "id": "a68cd3861e1ceb67" @@ -643,8 +643,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-11-13T08:20:39.248539Z", - "start_time": "2023-11-13T08:20:39.233003Z" + "end_time": "2023-11-13T12:17:48.008545Z", + "start_time": "2023-11-13T12:17:47.928808Z" } }, "id": "d0f43f4493218b5" @@ -668,8 +668,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-11-13T08:24:03.513738Z", - "start_time": "2023-11-13T08:20:39.250115Z" + "end_time": "2023-11-13T12:19:17.944519Z", + "start_time": "2023-11-13T12:17:47.965095Z" } }, "id": "ec92aa5bb5bb3605" @@ -684,8 +684,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-11-13T08:24:03.668958Z", - "start_time": "2023-11-13T08:24:02.256799Z" + "end_time": "2023-11-13T12:19:17.954324Z", + "start_time": "2023-11-13T12:19:17.920401Z" } }, "id": "5ad8a827586f65e3" diff --git a/docs/training.md b/docs/training.md new file mode 100644 index 0000000..7ac726c --- /dev/null +++ b/docs/training.md @@ -0,0 +1,26 @@ +# Training + +Basic command: + +```bash +combo --mode train \ + --training_data_path your_training_path \ + --validation_data_path your_validation_path +``` + +Options: + +```bash +combo --helpfull +``` + +## Examples + +For clarity, the training and validation data paths are omitted. 
+ +Train on multiple accelerators (default: train on all available ones): +```bash +combo --mode train \ + --n_cuda_devices 8 +``` + diff --git a/tests/config/test_configuration.py b/tests/config/test_configuration.py index 48083cf..123bfbc 100644 --- a/tests/config/test_configuration.py +++ b/tests/config/test_configuration.py @@ -2,6 +2,7 @@ import unittest import os from combo.config import Registry +from combo.config.from_parameters import override_parameters from combo.data import WhitespaceTokenizer, UniversalDependenciesDatasetReader, Vocabulary from combo.data.token_indexers.token_characters_indexer import TokenCharactersIndexer @@ -79,3 +80,61 @@ class ConfigurationTest(unittest.TestCase): self.assertEqual(type(reconstructed_vocab), Vocabulary) self.assertEqual(reconstructed_vocab.constructed_from, 'from_files') self.assertSetEqual(reconstructed_vocab.get_namespaces(), {'animals'}) + + + def test_override_parameters(self): + parameters = { + 'type': 'base_vocabulary', + 'parameters': { + 'counter': {'counter': {'test': 0}}, + 'max_vocab_size': 10 + } + } + + to_override = {'parameters': {'max_vocab_size': 15}} + + self.assertDictEqual({ + 'type': 'base_vocabulary', + 'parameters': { + 'counter': {'counter': {'test': 0}}, + 'max_vocab_size': 15 + } + }, override_parameters(parameters, to_override)) + + def test_override_nested_parameters(self): + parameters = { + 'type': 'base_vocabulary', + 'parameters': { + 'counter': {'counter': {'test': 0}, 'another_property': 0}, + 'another_counter': {'counter': {'test': 0}, 'another_property': 0} + } + } + + to_override = {'parameters': {'another_counter': {'counter': {'test': 1}}}} + + self.assertDictEqual({ + 'type': 'base_vocabulary', + 'parameters': { + 'counter': {'counter': {'test': 0}, 'another_property': 0}, + 'another_counter': {'counter': {'test': 1}, 'another_property': 0} + } + }, override_parameters(parameters, to_override)) + + def test_override_parameters_no_change(self): + parameters = { + 'type': 'base_vocabulary', + 'parameters': { + 'counter': {'counter': {'test': 0}}, + 'max_vocab_size': 10 + } + } + + to_override = {} + + self.assertDictEqual({ + 'type': 'base_vocabulary', + 'parameters': { + 'counter': {'counter': {'test': 0}}, + 'max_vocab_size': 10 + } + }, override_parameters(parameters, to_override)) -- GitLab
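Usage sketch (not part of the patch): the new `override_parameters` helper flattens both dictionaries into `/`-separated key paths, replaces only the paths that already exist in the base configuration, and unflattens the result. A consequence of the implementation is that override keys the base config lacks are dropped rather than added. A minimal example, assuming the package is importable as `combo` (`unknown_key` is a hypothetical path used only for illustration):

```python
from combo.config import override_parameters

base = {
    'type': 'base_vocabulary',
    'parameters': {
        'counter': {'counter': {'test': 0}},
        'max_vocab_size': 10,
    },
}

# 'parameters/max_vocab_size' exists in `base`, so its value is replaced;
# 'parameters/unknown_key' does not, so it is silently dropped.
merged = override_parameters(base, {'parameters': {'max_vocab_size': 15, 'unknown_key': 1}})

assert merged['parameters']['max_vocab_size'] == 15
assert 'unknown_key' not in merged['parameters']
```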