diff --git a/combo/main.py b/combo/main.py
index 9eb0a03469f7b9d344cd2797cca8fc8874b259e9..1e281069011b7b94992359d44d81d5d55c94fbc9 100755
--- a/combo/main.py
+++ b/combo/main.py
@@ -27,7 +27,6 @@ from combo.modules.model import Model
 from combo.utils import ConfigurationError
 from combo.utils.matrices import extract_combo_matrices
 
-
 logging.setLoggerClass(ComboLogger)
 logger = logging.getLogger(__name__)
 _FEATURES = ["token", "char", "upostag", "xpostag", "lemma", "feats"]
@@ -81,6 +80,8 @@ flags.DEFINE_string(name="config_path", default="",
                     help="Config file path.")
 flags.DEFINE_boolean(name="save_matrices", default=True,
                      help="Save relation distribution matrices.")
+flags.DEFINE_list(name="datasets_for_vocabulary", default=["train"],
+                  help="Datasets to build the vocabulary from: train and/or valid.")
 
 # Finetune after training flags
 flags.DEFINE_string(name="finetuning_training_data_path", default="",
@@ -115,10 +116,19 @@ def build_vocabulary_from_instances(training_data_loader: DataLoader,
                                     validation_data_loader: Optional[DataLoader],
                                     logging_prefix: str) -> Vocabulary:
     logger.info('Building a vocabulary from instances.', prefix=logging_prefix)
-    instances = chain(training_data_loader.iter_instances(),
-                      validation_data_loader.iter_instances()) \
-        if validation_data_loader \
-        else training_data_loader.iter_instances()
+    if "train" in FLAGS.datasets_for_vocabulary and "valid" in FLAGS.datasets_for_vocabulary:
+        instances = chain(training_data_loader.iter_instances(),
+                          validation_data_loader.iter_instances()) \
+            if validation_data_loader \
+            else training_data_loader.iter_instances()
+    elif "train" in FLAGS.datasets_for_vocabulary:
+        instances = training_data_loader.iter_instances()
+    elif "valid" in FLAGS.datasets_for_vocabulary:
+        instances = validation_data_loader.iter_instances()
+    else:
+        logger.error("train and valid are the only allowed values for --datasets_for_vocabulary!",
+                     prefix=logging_prefix)
+        raise ValueError("train and valid are the only allowed values for --datasets_for_vocabulary!")
     vocabulary = Vocabulary.from_instances_extended(
         instances,
         non_padded_namespaces=['head_labels'],
@@ -165,6 +175,7 @@ def get_defaults(dataset_reader: Optional[DatasetReader],
 
     return dataset_reader, training_data_loader, validation_data_loader, vocabulary
 
+
 def _read_property_from_config(property_key: str,
                                params: Dict[str, Any],
                                logging_prefix: str,
@@ -212,13 +223,19 @@ def read_vocabulary_from_config(params: Dict[str, Any],
     return vocabulary
 
 
-def read_model_from_config(logging_prefix: str) -> Optional[Tuple[Model, DatasetReader, DataLoader, DataLoader, Vocabulary]]:
+def read_model_from_config(logging_prefix: str) \
+        -> Optional[Tuple[Model, DatasetReader, DataLoader, DataLoader, Vocabulary]]:
     try:
         checks.file_exists(FLAGS.config_path)
     except ConfigurationError as e:
         handle_error(e, logging_prefix)
         return
 
+    if FLAGS.serialization_dir is None:
+        logger.error('--serialization_dir was not passed as an argument!', prefix=logging_prefix)
+        print('--serialization_dir was not passed as an argument!')
+        return
+
     with open(FLAGS.config_path, 'r') as f:
         params = json.load(f)
 
@@ -235,12 +252,14 @@ def read_model_from_config(logging_prefix: str) -> Optional[Tuple[Model, Dataset
     dataset_reader = read_dataset_reader_from_config(params, logging_prefix, pass_down_parameters)
     training_data_loader = read_data_loader_from_config(params, logging_prefix,
                                                         validation=False, pass_down_parameters=pass_down_parameters)
-    if (not FLAGS.validation_data_path or not FLAGS.finetuning_validation_data_path) and 'validation_data_loader' in params:
+    if (not FLAGS.validation_data_path or not FLAGS.finetuning_validation_data_path) \
+            and 'validation_data_loader' in params:
         logger.warning('Validation data loader is in parameters, but no validation data path was provided!')
         validation_data_loader = None
     else:
         validation_data_loader = read_data_loader_from_config(params, logging_prefix,
-                                                              validation=True, pass_down_parameters=pass_down_parameters)
+                                                              validation=True,
+                                                              pass_down_parameters=pass_down_parameters)
     vocabulary = read_vocabulary_from_config(params, logging_prefix, pass_down_parameters)
 
     dataset_reader, training_data_loader, validation_data_loader, vocabulary = get_defaults(
@@ -273,7 +292,8 @@ def run(_):
 
             if FLAGS.config_path:
                 logger.info(f'Reading parameters from configuration path {FLAGS.config_path}', prefix=prefix)
-                model, dataset_reader, training_data_loader, validation_data_loader, vocabulary = read_model_from_config(prefix)
+                model, dataset_reader, training_data_loader, validation_data_loader, vocabulary = \
+                    read_model_from_config(prefix)
             else:
                 dataset_reader, training_data_loader, validation_data_loader, vocabulary = get_defaults(
                     dataset_reader,
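
For reference, a minimal standalone sketch (not part of the patch) of how the new absl list flag is parsed; the program name "combo" in the argv list and the help text are placeholders, only the flag name and default come from the hunk above:

    from absl import flags

    FLAGS = flags.FLAGS
    flags.DEFINE_list(name="datasets_for_vocabulary", default=["train"],
                      help="Datasets to build the vocabulary from: train and/or valid.")

    # absl splits the comma-separated value into a list of strings,
    # so both selections below reach build_vocabulary_from_instances.
    FLAGS(["combo", "--datasets_for_vocabulary=train,valid"])
    assert FLAGS.datasets_for_vocabulary == ["train", "valid"]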