From c7327132fbaf1e34c5330771422e1ca3882e7ee2 Mon Sep 17 00:00:00 2001
From: Maja Jablonska <majajjablonska@gmail.com>
Date: Wed, 15 Nov 2023 17:22:25 +1100
Subject: [PATCH] Add CLI parameters

---
 combo/combo_model.py               |  16 ++-
 combo/config/__init__.py           |   2 +-
 combo/config/from_parameters.py    |  59 +++++++++-
 combo/main.py                      | 169 +++++++++++++++++++++++------
 combo/polish_model_training.ipynb  |  94 ++++++++--------
 docs/training.md                   |  26 +++++
 tests/config/test_configuration.py |  59 ++++++++++
 7 files changed, 334 insertions(+), 91 deletions(-)
 create mode 100644 docs/training.md

diff --git a/combo/combo_model.py b/combo/combo_model.py
index 4959b72..79c9f36 100644
--- a/combo/combo_model.py
+++ b/combo/combo_model.py
@@ -23,6 +23,7 @@ from combo.nn import utils
 from combo.nn.utils import get_text_field_mask
 from combo.predictors import Predictor
 from combo.utils import metrics
+from combo.utils import ConfigurationError
 
 
 @Registry.register("semantic_multitask")
@@ -165,7 +166,10 @@ class ComboModel(Model, FromParameters):
             if self.morphological_feat:
                 mapped_gold_labels = []
                 for _, cat_indices in self.morphological_feat.slices.items():
-                    mapped_gold_labels.append(feats[:, :, cat_indices].argmax(dim=-1))
+                    try:
+                        mapped_gold_labels.append(feats[:, :, cat_indices].argmax(dim=-1))
+                    except TypeError as e:
+                        raise ConfigurationError('Feats is None - if no feats are provided, the morphological_feat property should be set to None.') from e
 
                 feats = torch.stack(mapped_gold_labels, dim=-1)
 
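A note on the guard above: subscripting `feats` when it is `None` raises a bare `TypeError` ('NoneType' object is not subscriptable), which is hard to trace back to a configuration problem; the `ConfigurationError` turns it into an actionable message. A standalone sketch of the failure mode, with illustrative values:

```python
# Minimal reproduction of the error the guard converts into a ConfigurationError.
# The values are illustrative; in COMBO, feats is a tensor or None.
feats = None          # no morphological features were provided
cat_indices = [0, 1]  # an illustrative category slice

try:
    feats[:, :, cat_indices].argmax(dim=-1)
except TypeError as e:
    print(f"caught: {e}")  # 'NoneType' object is not subscriptable
```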
@@ -184,11 +188,11 @@ class ComboModel(Model, FromParameters):
             relations_loss, head_loss = parser_output["loss"]
             enhanced_relations_loss, enhanced_head_loss = enhanced_parser_output["loss"]
             losses = {
-                "upostag_loss": upos_output["loss"],
-                "xpostag_loss": xpos_output["loss"],
-                "semrel_loss": semrel_output["loss"],
-                "feats_loss": morpho_output["loss"],
-                "lemma_loss": lemma_output["loss"],
+                "upostag_loss": upos_output.get("loss"),
+                "xpostag_loss": xpos_output.get("loss"),
+                "semrel_loss": semrel_output.get("loss"),
+                "feats_loss": morpho_output.get("loss"),
+                "lemma_loss": lemma_output.get("loss"),
                 "head_loss": head_loss,
                 "deprel_loss": relations_loss,
                 "enhanced_head_loss": enhanced_head_loss,
diff --git a/combo/config/__init__.py b/combo/config/__init__.py
index 948f357..e956d66 100644
--- a/combo/config/__init__.py
+++ b/combo/config/__init__.py
@@ -1,2 +1,2 @@
-from .from_parameters import FromParameters, resolve
+from .from_parameters import FromParameters, override_parameters, resolve
 from .registry import Registry
diff --git a/combo/config/from_parameters.py b/combo/config/from_parameters.py
index e853042..da9dafb 100644
--- a/combo/config/from_parameters.py
+++ b/combo/config/from_parameters.py
@@ -19,7 +19,6 @@ def get_matching_arguments(args: Dict[str, Any], func: Callable) -> Dict[str, An
 
 
 def _resolve(values: typing.Union[Dict[str, Any], str], pass_down_parameters: Dict[str, Any] = None) -> Any:
-
     if isinstance(values, Params):
         values = values.as_dict()
 
@@ -148,7 +147,7 @@ class FromParameters:
             if pn in pass_down_parameter_names:
                 continue
             parameters_dict[pn] = serialize_single_value(param_value,
-                                                         pass_down_parameter_names+self.pass_down_parameter_names())
+                                                         pass_down_parameter_names + self.pass_down_parameter_names())
         return parameters_dict
 
     def serialize(self, pass_down_parameter_names: List[str] = None) -> Dict[str, Any]:
@@ -166,3 +165,59 @@ def resolve(parameters: Dict[str, Any], pass_down_parameters: Dict[str, Any] = N
     pass_down_parameters = pass_down_parameters or {}
     clz, clz_init = Registry.resolve(parameters['type'])
     return clz.from_parameters(parameters['parameters'], clz_init, pass_down_parameters)
+
+
+def flatten_dictionary(d, parent_key='', sep='/'):
+    """
+    Flatten a nested dictionary.
+
+    Parameters:
+        d (dict): The input dictionary.
+        parent_key (str): The parent key to use for recursion (default is an empty string).
+        sep (str): The separator to use when concatenating keys (default is '/').
+
+    Returns:
+        dict: A flattened dictionary.
+    """
+    items = []
+    for k, v in d.items():
+        new_key = f"{parent_key}{sep}{k}" if parent_key else k
+        if isinstance(v, dict):
+            items.extend(flatten_dictionary(v, new_key, sep=sep).items())
+        else:
+            items.append((new_key, v))
+    return dict(items)
+
+
+def unflatten_dictionary(flat_dict, sep='/'):
+    """
+    Unflatten a flattened dictionary.
+
+    Parameters:
+        flat_dict (dict): The flattened dictionary.
+        sep (str): The separator used in the flattened keys (default is '/').
+
+    Returns:
+        dict: The unflattened dictionary.
+    """
+    unflattened_dict = {}
+    for key, value in flat_dict.items():
+        keys = key.split(sep)
+        current_level = unflattened_dict
+
+        for k in keys[:-1]:
+            current_level = current_level.setdefault(k, {})
+
+        current_level[keys[-1]] = value
+
+    return unflattened_dict
+
+
+def override_parameters(parameters: Dict[str, Any], override_values: Dict[str, Any]) -> Dict[str, Any]:
+    overridden_parameters = flatten_dictionary(parameters)
+    override_values = flatten_dictionary(override_values)
+    for ko, vo in override_values.items():
+        if ko in overridden_parameters:
+            overridden_parameters[ko] = vo
+
+    return unflatten_dictionary(overridden_parameters)
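A usage sketch for the three helpers above (the configuration values are made up): `override_parameters` flattens both dictionaries into `/`-separated key paths, overwrites only the paths that already exist in the base, and rebuilds the nested structure:

```python
from combo.config.from_parameters import flatten_dictionary, override_parameters

base = {
    'type': 'base_vocabulary',
    'parameters': {'max_vocab_size': 10, 'min_count': {'tokens': 2}},
}
overrides = {'parameters': {'max_vocab_size': 15, 'unknown_key': 1}}

print(flatten_dictionary(base))
# {'type': 'base_vocabulary', 'parameters/max_vocab_size': 10,
#  'parameters/min_count/tokens': 2}

merged = override_parameters(base, overrides)
# 'parameters/max_vocab_size' is replaced; 'parameters/unknown_key' is
# dropped because that path does not exist in the base dictionary.
assert merged['parameters']['max_vocab_size'] == 15
assert 'unknown_key' not in merged['parameters']
```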
diff --git a/combo/main.py b/combo/main.py
index ecc455c..44a6a30 100755
--- a/combo/main.py
+++ b/combo/main.py
@@ -18,20 +18,28 @@ from combo.default_model import default_ud_dataset_reader, default_data_loader
 from combo.modules.archival import load_archive, archive
 from combo.predict import COMBO
 from combo.data import api
-from combo.data import DatasetReader
+from combo.config import override_parameters
+from combo.utils import ConfigurationError
 
 logging.setLoggerClass(ComboLogger)
 logger = logging.getLogger(__name__)
 _FEATURES = ["token", "char", "upostag", "xpostag", "lemma", "feats"]
 _TARGETS = ["deprel", "feats", "head", "lemma", "upostag", "xpostag", "semrel", "sent", "deps"]
 
+
+def handle_error(error: Exception):
+    msg = getattr(error, 'message', str(error))
+    logger.error(msg)
+    print(f'Error: {msg}')
+
+
 FLAGS = flags.FLAGS
 flags.DEFINE_enum(name="mode", default=None, enum_values=["train", "predict"],
                   help="Specify COMBO mode: train or predict")
 
 # Common flags
-flags.DEFINE_integer(name="cuda_device", default=-1,
-                     help="Cuda device idx (default -1 cpu)")
+flags.DEFINE_integer(name="n_cuda_devices", default=-1,
+                     help="Number of devices to train on (default -1 auto mode - train on as many as possible)")
 flags.DEFINE_string(name="output_file", default="output.log",
                     help="Predictions result file.")
 
@@ -42,8 +50,8 @@ flags.DEFINE_string(name="validation_data_path", default="", help="Validation da
 flags.DEFINE_alias(name="validation_data", original_name="validation_data_path")
 flags.DEFINE_string(name="pretrained_tokens", default="",
                     help="Pretrained tokens embeddings path")
-flags.DEFINE_integer(name="embedding_dim", default=300,
-                     help="Embeddings dim")
+flags.DEFINE_integer(name="lemmatizer_embedding_dim", default=300,
+                     help="Lemmatizer embeddings dim")
 flags.DEFINE_integer(name="num_epochs", default=400,
                      help="Epochs num")
 flags.DEFINE_integer(name="word_batch_size", default=2500,
@@ -72,10 +80,8 @@ flags.DEFINE_string(name="finetuning_validation_data_path", default="",
 flags.DEFINE_string(name="test_data_path", default=None,
                     help="Test path file.")
 flags.DEFINE_alias(name="test_data", original_name="test_data_path")
-
-# Experimental
 flags.DEFINE_boolean(name="use_pure_config", default=False,
-                     help="Ignore ext flags (experimental).")
+                     help="Ignore ext flags.")
 
 # Prediction flags
 flags.DEFINE_string(name="model_path", default=None,
@@ -99,37 +105,58 @@ def run(_):
         if not FLAGS.finetuning:
             prefix = 'Training'
             logger.info('Setting up the model for training', prefix=prefix)
-            checks.file_exists(FLAGS.config_path)
+            try:
+                checks.file_exists(FLAGS.config_path)
+            except ConfigurationError as e:
+                handle_error(e)
+                return
 
             logger.info(f'Reading parameters from configuration path {FLAGS.config_path}', prefix=prefix)
             with open(FLAGS.config_path, 'r') as f:
                 params = json.load(f)
-            params = {**params, **_get_ext_vars()}
+            params = override_parameters(params, _get_ext_vars(False))
+
+            if 'feats' not in FLAGS.features:
+                params['model']['parameters'].pop('morphological_feat', None)
 
             serialization_dir = tempfile.mkdtemp(prefix='combo', dir=FLAGS.serialization_dir)
 
             params['vocabulary']['parameters']['directory'] = os.path.join('/'.join(FLAGS.config_path.split('/')[:-1]),
-                                                                           params['vocabulary']['parameters']['directory'])
+                                                                           params['vocabulary']['parameters'][
+                                                                               'directory'])
 
             try:
                 vocabulary = resolve(params['vocabulary'])
-            except KeyError:
-                logger.error('No vocabulary in config.json!')
+            except Exception as e:
+                handle_error(e)
                 return
 
+            try:
+                model = resolve(override_parameters(params['model'], _get_ext_vars(False)),
+                                pass_down_parameters={'vocabulary': vocabulary})
+            except Exception as e:
+                handle_error(e)
+                return
 
-            model = resolve(params['model'], pass_down_parameters={'vocabulary': vocabulary})
             dataset_reader = None
 
             if 'data_loader' in params:
                 logger.info(f'Resolving the training data loader from parameters', prefix=prefix)
-                train_data_loader = resolve(params['data_loader'])
+                try:
+                    train_data_loader = resolve(params['data_loader'])
+                except Exception as e:
+                    handle_error(e)
+                    return
             else:
                 checks.file_exists(FLAGS.training_data_path)
                 logger.info(f'Using a default UD data loader with training data path {FLAGS.training_data_path}',
                             prefix=prefix)
-                train_data_loader = default_data_loader(default_ud_dataset_reader(),
-                                                        FLAGS.training_data_path)
+                try:
+                    train_data_loader = default_data_loader(default_ud_dataset_reader(),
+                                                            FLAGS.training_data_path)
+                except Exception as e:
+                    handle_error(e)
+                    return
 
             logger.info('Indexing training data loader')
             train_data_loader.index_with(model.vocab)
@@ -180,10 +207,18 @@ def run(_):
         nlp = TrainableCombo(model, torch.optim.Adam,
                              optimizer_kwargs={'betas': [0.9, 0.9], 'lr': 0.002},
                              validation_metrics=['EM'])
+
+        n_cuda_devices = "auto" if FLAGS.n_cuda_devices == -1 else FLAGS.n_cuda_devices
+
         trainer = pl.Trainer(max_epochs=FLAGS.num_epochs,
                              default_root_dir=serialization_dir,
-                             gradient_clip_val=5)
-        trainer.fit(model=nlp, train_dataloaders=train_data_loader, val_dataloaders=validation_data_loader)
+                             gradient_clip_val=5,
+                             devices=n_cuda_devices)
+        try:
+            trainer.fit(model=nlp, train_dataloaders=train_data_loader, val_dataloaders=validation_data_loader)
+        except Exception as e:
+            handle_error(e)
+            return
 
         logger.info(f'Archiving the model in {serialization_dir}', prefix=prefix)
         archive(model, serialization_dir, train_data_loader, validation_data_loader, dataset_reader)
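A note on the trainer setup in the hunk above: `--n_cuda_devices -1` is mapped to the string `"auto"`, which PyTorch Lightning's `devices` argument interprets as "use every available device". A reduced sketch of that mapping (assumes `pytorch_lightning` is installed; the values mirror the flags' defaults):

```python
import pytorch_lightning as pl

n_cuda_devices = -1  # the flag's default
devices = "auto" if n_cuda_devices == -1 else n_cuda_devices

trainer = pl.Trainer(max_epochs=400,       # FLAGS.num_epochs default
                     gradient_clip_val=5,  # as in combo/main.py
                     devices=devices)
```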
@@ -192,8 +227,9 @@ def run(_):
         if FLAGS.test_data_path and FLAGS.output_file:
             checks.file_exists(FLAGS.test_data_path)
             if not dataset_reader:
-                logger.info("No dataset reader in the configuration or archive file - using a default UD dataset reader",
-                            prefix=prefix)
+                logger.info(
+                    "No dataset reader in the configuration or archive file - using a default UD dataset reader",
+                    prefix=prefix)
                 dataset_reader = default_ud_dataset_reader()
             logger.info("Predicting test examples", prefix=prefix)
             test_trees = dataset_reader.read(FLAGS.test_data_path)
@@ -212,7 +248,7 @@ def run(_):
             logger.info("No dataset reader in the configuration or archive file - using a default UD dataset reader",
                         prefix=prefix)
             dataset_reader = default_ud_dataset_reader()
-        
+
         predictor = COMBO(model, dataset_reader)
 
         if FLAGS.input_file == '-':
@@ -242,23 +278,86 @@ def run(_):
 def _get_ext_vars(finetuning: bool = False) -> Dict:
     if FLAGS.use_pure_config:
         return {}
-    return {
-        "training_data_path": (
-            ",".join(FLAGS.training_data_path if not finetuning else FLAGS.finetuning_training_data_path)),
-        "validation_data_path": (
-            ",".join(FLAGS.validation_data_path if not finetuning else FLAGS.finetuning_validation_data_path)),
+
+    to_override = {
+        "model": {
+            "parameters": {
+                "lemmatizer": {
+                    "parameters": {
+                        "embedding_dim": FLAGS.lemmatizer_embedding_dim
+                    }
+                },
+                "text_field_embedder": {
+                    "parameters": {
+                        "token_embedders": {
+                            "parameters": {
+                                "token": {
+                                    "parameters": {
+                                        "model_name": FLAGS.pretrained_transformer_name
+                                    }
+                                }
+                            }
+                        }
+                    }
+                },
+                "serialization_dir": FLAGS.serialization_dir
+            }
+        },
+        "data_loader": {
+            "data_path": (",".join(FLAGS.training_data_path if not finetuning else FLAGS.finetuning_training_data_path)),
+            "parameters": {
+                "reader": {
+                    "parameters": {
+                        "features": FLAGS.features,
+                        "targets": FLAGS.targets,
+                        "token_indexers": {
+                            "token": {
+                                "parameters": {
+                                    "model_name": FLAGS.pretrained_transformer_name
+                                }
+                            }
+                        }
+                    }
+                }
+            }
+        },
+        "validation_data_loader": {
+            "data_path": (",".join(FLAGS.validation_data_path if not finetuning else FLAGS.finetuning_validation_data_path)),
+            "parameters": {
+                "reader": {
+                    "parameters": {
+                        "features": FLAGS.features,
+                        "targets": FLAGS.targets,
+                        "token_indexers": {
+                            "token": {
+                                "parameters": {
+                                    "model_name": FLAGS.pretrained_transformer_name
+                                }
+                            }
+                        }
+                    }
+                }
+            }
+        },
+        "dataset_reader": {
+            "parameters": {
+                "features": FLAGS.features,
+                "targets": FLAGS.targets,
+                "token_indexers": {
+                    "token": {
+                        "parameters": {
+                            "model_name": FLAGS.pretrained_transformer_name
+                        }
+                    }
+                }
+            }
+        },
         "pretrained_tokens": FLAGS.pretrained_tokens,
-        "pretrained_transformer_name": FLAGS.pretrained_transformer_name,
-        "features": " ".join(FLAGS.features),
-        "targets": " ".join(FLAGS.targets),
-        "type": "finetuning" if finetuning else "default",
-        "embedding_dim": int(FLAGS.embedding_dim),
-        "cuda_device": int(FLAGS.cuda_device),
-        "num_epochs": int(FLAGS.num_epochs),
         "word_batch_size": int(FLAGS.word_batch_size),
-        "use_tensorboard": int(FLAGS.tensorboard),
     }
 
+    return to_override
+
 
 def main():
     """Parse flags."""
diff --git a/combo/polish_model_training.ipynb b/combo/polish_model_training.ipynb
index bd2796f..005fe4b 100644
--- a/combo/polish_model_training.ipynb
+++ b/combo/polish_model_training.ipynb
@@ -6,16 +6,16 @@
    "outputs": [],
    "source": [
     "# The path where the training and validation datasets are stored\n",
-    "TRAINING_DATA_PATH: str = '/Users/majajablonska/Documents/PDB/PDBUD_train.conllu'\n",
-    "VALIDATION_DATA_PATH: str = '/Users/majajablonska/Documents/PDB/PDBUD_val.conllu'\n",
+    "TRAINING_DATA_PATH: str = '/Users/majajablonska/Documents/PDBUD/train.conllu'\n",
+    "VALIDATION_DATA_PATH: str = '/Users/majajablonska/Documents/PDBUD/val.conllu'\n",
     "# The path where the model can be saved to\n",
     "SERIALIZATION_DIR: str = \"/Users/majajablonska/Documents/Workspace/combotest\""
    ],
    "metadata": {
     "collapsed": false,
     "ExecuteTime": {
-     "end_time": "2023-11-13T08:18:36.376046Z",
-     "start_time": "2023-11-13T08:18:36.189836Z"
+     "end_time": "2023-11-13T12:15:21.197003Z",
+     "start_time": "2023-11-13T12:15:19.886422Z"
     }
    },
    "id": "b28c7d8bacb08d02"
@@ -51,8 +51,8 @@
    "metadata": {
     "collapsed": false,
     "ExecuteTime": {
-     "end_time": "2023-11-13T08:18:41.300316Z",
-     "start_time": "2023-11-13T08:18:36.197537Z"
+     "end_time": "2023-11-13T12:15:28.665585Z",
+     "start_time": "2023-11-13T12:15:19.907198Z"
     }
    },
    "id": "initial_id"
@@ -77,7 +77,7 @@
       "application/vnd.jupyter.widget-view+json": {
        "version_major": 2,
        "version_minor": 0,
-       "model_id": "e45b57a7047043a48ccfeacfb49312b5"
+       "model_id": "2179b1be2f484a33948a76d087002182"
       }
      },
      "metadata": {},
@@ -89,7 +89,7 @@
       "application/vnd.jupyter.widget-view+json": {
        "version_major": 2,
        "version_minor": 0,
-       "model_id": "a886ae4451474459b088659ebac076ae"
+       "model_id": "86762d681ee0467e8501de2b34061aad"
       }
      },
      "metadata": {},
@@ -101,7 +101,7 @@
       "application/vnd.jupyter.widget-view+json": {
        "version_major": 2,
        "version_minor": 0,
-       "model_id": "20ff98564c7c43a9971c25f82ceda997"
+       "model_id": "b9e631cb77594ea5aae60e6d15809885"
       }
      },
      "metadata": {},
@@ -169,8 +169,8 @@
    "metadata": {
     "collapsed": false,
     "ExecuteTime": {
-     "end_time": "2023-11-13T08:19:01.785477Z",
-     "start_time": "2023-11-13T08:18:41.119674Z"
+     "end_time": "2023-11-13T12:15:51.717065Z",
+     "start_time": "2023-11-13T12:15:28.442131Z"
     }
    },
    "id": "d74957f422f0b05b"
@@ -192,8 +192,8 @@
    "metadata": {
     "collapsed": false,
     "ExecuteTime": {
-     "end_time": "2023-11-13T08:19:03.816723Z",
-     "start_time": "2023-11-13T08:19:01.774666Z"
+     "end_time": "2023-11-13T12:15:52.574303Z",
+     "start_time": "2023-11-13T12:15:51.724469Z"
     }
    },
    "id": "fa724d362fd6bd23"
@@ -211,7 +211,7 @@
     },
     {
      "data": {
-      "text/plain": "<generator object SimpleDataLoader.iter_instances at 0x7fb2d3cdfc80>"
+      "text/plain": "<generator object SimpleDataLoader.iter_instances at 0x7fb512dc4f20>"
      },
      "execution_count": 5,
      "metadata": {},
@@ -239,8 +239,8 @@
    "metadata": {
     "collapsed": false,
     "ExecuteTime": {
-     "end_time": "2023-11-13T08:19:03.877868Z",
-     "start_time": "2023-11-13T08:19:03.826289Z"
+     "end_time": "2023-11-13T12:15:52.641199Z",
+     "start_time": "2023-11-13T12:15:52.583194Z"
     }
    },
    "id": "f8a10f9892005fca"
@@ -263,8 +263,8 @@
    "metadata": {
     "collapsed": false,
     "ExecuteTime": {
-     "end_time": "2023-11-13T08:19:03.887795Z",
-     "start_time": "2023-11-13T08:19:03.870640Z"
+     "end_time": "2023-11-13T12:15:52.659289Z",
+     "start_time": "2023-11-13T12:15:52.625700Z"
     }
    },
    "id": "14413692656b68ac"
@@ -277,7 +277,7 @@
      "name": "stderr",
      "output_type": "stream",
      "text": [
-      "Some weights of the model checkpoint at allegro/herbert-base-cased were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.sso.sso_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.sso.sso_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.decoder.bias', 'cls.predictions.decoder.weight']\n",
+      "Some weights of the model checkpoint at allegro/herbert-base-cased were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.sso.sso_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.sso.sso_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight']\n",
       "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
       "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
      ]
@@ -411,8 +411,8 @@
    "metadata": {
     "collapsed": false,
     "ExecuteTime": {
-     "end_time": "2023-11-13T08:19:10.006549Z",
-     "start_time": "2023-11-13T08:19:03.885912Z"
+     "end_time": "2023-11-13T12:15:56.509687Z",
+     "start_time": "2023-11-13T12:15:52.658879Z"
     }
    },
    "id": "437d12054baaffa1"
@@ -430,8 +430,8 @@
    "metadata": {
     "collapsed": false,
     "ExecuteTime": {
-     "end_time": "2023-11-13T08:19:47.953809Z",
-     "start_time": "2023-11-13T08:19:09.989582Z"
+     "end_time": "2023-11-13T12:16:30.663344Z",
+     "start_time": "2023-11-13T12:15:56.529656Z"
     }
    },
    "id": "e131e0ec75dc6927"
@@ -446,8 +446,8 @@
    "metadata": {
     "collapsed": false,
     "ExecuteTime": {
-     "end_time": "2023-11-13T08:19:53.878201Z",
-     "start_time": "2023-11-13T08:19:47.940147Z"
+     "end_time": "2023-11-13T12:16:45.453326Z",
+     "start_time": "2023-11-13T12:16:30.488388Z"
     }
    },
    "id": "195c71fcf8170ff"
@@ -481,8 +481,8 @@
    "metadata": {
     "collapsed": false,
     "ExecuteTime": {
-     "end_time": "2023-11-13T08:19:54.150333Z",
-     "start_time": "2023-11-13T08:19:53.874397Z"
+     "end_time": "2023-11-13T12:16:45.785538Z",
+     "start_time": "2023-11-13T12:16:45.365250Z"
     }
    },
    "id": "cefc5173154d1605"
@@ -503,7 +503,7 @@
       "12.1 M    Trainable params\n",
       "124 M     Non-trainable params\n",
       "136 M     Total params\n",
-      "546.115   Total estimated model params size (MB)\n"
+      "546.106   Total estimated model params size (MB)\n"
      ]
     },
     {
@@ -512,7 +512,7 @@
       "application/vnd.jupyter.widget-view+json": {
        "version_major": 2,
        "version_minor": 0,
-       "model_id": "9f6e199b0fd546f5833fdda238964165"
+       "model_id": "f2dd3228246843428b8fcb8ae932c1f1"
       }
      },
      "metadata": {},
@@ -534,7 +534,7 @@
       "application/vnd.jupyter.widget-view+json": {
        "version_major": 2,
        "version_minor": 0,
-       "model_id": "594c065e5d2441f48ba2b87c7a3f528f"
+       "model_id": "0bcdd388df664784ba19667c6a0593a1"
       }
      },
      "metadata": {},
@@ -546,7 +546,7 @@
       "application/vnd.jupyter.widget-view+json": {
        "version_major": 2,
        "version_minor": 0,
-       "model_id": "bd8bf330a90e4eddb52a3c87af6d2869"
+       "model_id": "a8203342bf454c22b292548d64f085a9"
       }
      },
      "metadata": {},
@@ -566,8 +566,8 @@
    "metadata": {
     "collapsed": false,
     "ExecuteTime": {
-     "end_time": "2023-11-13T08:20:39.048426Z",
-     "start_time": "2023-11-13T08:19:54.147748Z"
+     "end_time": "2023-11-13T12:17:47.659618Z",
+     "start_time": "2023-11-13T12:16:45.706948Z"
     }
    },
    "id": "e5af131bae4b1a33"
@@ -582,8 +582,8 @@
    "metadata": {
     "collapsed": false,
     "ExecuteTime": {
-     "end_time": "2023-11-13T08:20:39.152284Z",
-     "start_time": "2023-11-13T08:20:39.042845Z"
+     "end_time": "2023-11-13T12:17:47.975345Z",
+     "start_time": "2023-11-13T12:17:47.644327Z"
     }
    },
    "id": "3e23413c86063183"
@@ -598,8 +598,8 @@
    "metadata": {
     "collapsed": false,
     "ExecuteTime": {
-     "end_time": "2023-11-13T08:20:39.228735Z",
-     "start_time": "2023-11-13T08:20:39.052747Z"
+     "end_time": "2023-11-13T12:17:47.989681Z",
+     "start_time": "2023-11-13T12:17:47.665490Z"
     }
    },
    "id": "d555d7f0223a624b"
@@ -613,8 +613,8 @@
      "output_type": "stream",
      "text": [
       "TOKEN           LEMMA           UPOS       HEAD       DEPREL    \n",
-      "Cześć,          ?????           NOUN                0 root      \n",
-      "jestem          ?????           NOUN                1 punct     \n",
+      "Cześć,          ??????          NOUN                0 root      \n",
+      "jestem          ?????a          NOUN                1 punct     \n",
       "psem.           ?????           NOUN                1 punct     \n"
      ]
     }
@@ -627,8 +627,8 @@
    "metadata": {
     "collapsed": false,
     "ExecuteTime": {
-     "end_time": "2023-11-13T08:20:39.237630Z",
-     "start_time": "2023-11-13T08:20:39.227051Z"
+     "end_time": "2023-11-13T12:17:48.005229Z",
+     "start_time": "2023-11-13T12:17:47.923055Z"
     }
    },
    "id": "a68cd3861e1ceb67"
@@ -643,8 +643,8 @@
    "metadata": {
     "collapsed": false,
     "ExecuteTime": {
-     "end_time": "2023-11-13T08:20:39.248539Z",
-     "start_time": "2023-11-13T08:20:39.233003Z"
+     "end_time": "2023-11-13T12:17:48.008545Z",
+     "start_time": "2023-11-13T12:17:47.928808Z"
     }
    },
    "id": "d0f43f4493218b5"
@@ -668,8 +668,8 @@
    "metadata": {
     "collapsed": false,
     "ExecuteTime": {
-     "end_time": "2023-11-13T08:24:03.513738Z",
-     "start_time": "2023-11-13T08:20:39.250115Z"
+     "end_time": "2023-11-13T12:19:17.944519Z",
+     "start_time": "2023-11-13T12:17:47.965095Z"
     }
    },
    "id": "ec92aa5bb5bb3605"
@@ -684,8 +684,8 @@
    "metadata": {
     "collapsed": false,
     "ExecuteTime": {
-     "end_time": "2023-11-13T08:24:03.668958Z",
-     "start_time": "2023-11-13T08:24:02.256799Z"
+     "end_time": "2023-11-13T12:19:17.954324Z",
+     "start_time": "2023-11-13T12:19:17.920401Z"
     }
    },
    "id": "5ad8a827586f65e3"
diff --git a/docs/training.md b/docs/training.md
new file mode 100644
index 0000000..7ac726c
--- /dev/null
+++ b/docs/training.md
@@ -0,0 +1,26 @@
+# Training
+
+Basic command:
+
+```bash
+combo --mode train \
+      --training_data_path your_training_path \
+      --validation_data_path your_validation_path
+```
+
+Options:
+
+```bash
+combo --helpfull
+```
+
+## Examples
+
+For clarity, the training and validation data paths are omitted.
+
+Train on multiple accelerators (default: train on all available ones):
+```bash
+combo --mode train \
+      --n_cuda_devices 8
+```
+
diff --git a/tests/config/test_configuration.py b/tests/config/test_configuration.py
index 48083cf..123bfbc 100644
--- a/tests/config/test_configuration.py
+++ b/tests/config/test_configuration.py
@@ -2,6 +2,7 @@ import unittest
 import os
 
 from combo.config import Registry
+from combo.config.from_parameters import override_parameters
 from combo.data import WhitespaceTokenizer, UniversalDependenciesDatasetReader, Vocabulary
 from combo.data.token_indexers.token_characters_indexer import TokenCharactersIndexer
 
@@ -79,3 +80,61 @@ class ConfigurationTest(unittest.TestCase):
         self.assertEqual(type(reconstructed_vocab), Vocabulary)
         self.assertEqual(reconstructed_vocab.constructed_from, 'from_files')
         self.assertSetEqual(reconstructed_vocab.get_namespaces(), {'animals'})
+
+
+    def test_override_parameters(self):
+        parameters = {
+            'type': 'base_vocabulary',
+            'parameters': {
+                'counter': {'counter': {'test': 0}},
+                'max_vocab_size': 10
+            }
+        }
+
+        to_override = {'parameters': {'max_vocab_size': 15}}
+
+        self.assertDictEqual({
+            'type': 'base_vocabulary',
+            'parameters': {
+                'counter': {'counter': {'test': 0}},
+                'max_vocab_size': 15
+            }
+        }, override_parameters(parameters, to_override))
+
+    def test_override_nested_parameters(self):
+        parameters = {
+            'type': 'base_vocabulary',
+            'parameters': {
+                'counter': {'counter': {'test': 0}, 'another_property': 0},
+                'another_counter': {'counter': {'test': 0}, 'another_property': 0}
+            }
+        }
+
+        to_override = {'parameters': {'another_counter': {'counter': {'test': 1}}}}
+
+        self.assertDictEqual({
+            'type': 'base_vocabulary',
+            'parameters': {
+                'counter': {'counter': {'test': 0}, 'another_property': 0},
+                'another_counter': {'counter': {'test': 1}, 'another_property': 0}
+            }
+        }, override_parameters(parameters, to_override))
+
+    def test_override_parameters_no_change(self):
+        parameters = {
+            'type': 'base_vocabulary',
+            'parameters': {
+                'counter': {'counter': {'test': 0}},
+                'max_vocab_size': 10
+            }
+        }
+
+        to_override = {}
+
+        self.assertDictEqual({
+            'type': 'base_vocabulary',
+            'parameters': {
+                'counter': {'counter': {'test': 0}},
+                'max_vocab_size': 10
+            }
+        }, override_parameters(parameters, to_override))
-- 
GitLab