Skip to content
Snippets Groups Projects
Commit 9821e7ae authored by Maja Jablonska's avatar Maja Jablonska
Browse files

Minor fixes to CLI prediction

parent 109e0255
Branches
Tags
1 merge request!46Merge COMBO 3.0 into master
......@@ -142,6 +142,7 @@ class FromParameters:
def _to_params(self, pass_down_parameter_names: List[str] = None) -> Dict[str, str]:
parameters_to_serialize = self.constructed_args or {}
pass_down_parameter_names = pass_down_parameter_names or []
parameters_dict = {}
for pn, param_value in parameters_to_serialize.items():
if pn in pass_down_parameter_names:
......@@ -151,8 +152,6 @@ class FromParameters:
return parameters_dict
def serialize(self, pass_down_parameter_names: List[str] = None) -> Dict[str, Any]:
pass_down_parameter_names = pass_down_parameter_names or []
constructor_method = self.constructed_from if self.constructed_from else '__init__'
if not getattr(self, constructor_method):
raise ConfigurationError('Class ' + str(type(self)) + ' has no constructor method ' + constructor_method)
......
......@@ -4,7 +4,7 @@ import os
import pathlib
import tempfile
from itertools import chain
from typing import Dict, Optional, Any
from typing import Dict, Optional, Any, Tuple
import torch
from absl import app, flags
......@@ -24,6 +24,7 @@ from config import override_parameters
from config.from_parameters import override_or_add_parameters
from data import LamboTokenizer, Sentence, Vocabulary, DatasetReader
from data.dataset_loaders import DataLoader
from modules.model import Model
from utils import ConfigurationError
logging.setLoggerClass(ComboLogger)
......@@ -127,6 +128,42 @@ def build_vocabulary_from_instances(training_data_loader: DataLoader,
return vocabulary
def get_defaults(dataset_reader: Optional[DatasetReader],
                 training_data_loader: Optional[DataLoader],
                 validation_data_loader: Optional[DataLoader],
                 vocabulary: Optional[Vocabulary],
                 training_data_path: str,
                 validation_data_path: str,
                 prefix: str) -> Tuple[DatasetReader, DataLoader, DataLoader, Vocabulary]:
    """Fill in default components for any that were not supplied by the configuration.

    Any argument passed as ``None`` (or falsy) is replaced by a default built from
    the CLI flags; arguments that are already constructed are kept, with their
    data path overridden when an explicit path was given.

    :param dataset_reader: reader parsed from config, or ``None`` to build a default UD reader
    :param training_data_loader: loader parsed from config, or ``None`` to build a default
    :param validation_data_loader: loader parsed from config, or ``None``
    :param vocabulary: vocabulary parsed from config, or ``None`` to build one from instances
    :param training_data_path: path to training data (may differ from FLAGS when finetuning)
    :param validation_data_path: path to validation data (may differ from FLAGS when finetuning)
    :param prefix: logging prefix (e.g. 'Training' / 'Finetuning')
    :return: tuple of (dataset_reader, training_data_loader, validation_data_loader, vocabulary)
    """
    if not dataset_reader and (FLAGS.test_data_path
                               or not training_data_loader
                               or (FLAGS.validation_data_path and not validation_data_loader)):
        # Dataset reader is required to read training data and/or for training (and validation) data loader
        dataset_reader = default_ud_dataset_reader(FLAGS.pretrained_transformer_name)

    if not training_data_loader:
        training_data_loader = default_data_loader(dataset_reader, training_data_path)
    else:
        if training_data_path:
            # An explicit path (CLI flag) takes precedence over the configured one.
            training_data_loader.data_path = training_data_path
        else:
            logger.warning('No training data path provided - using the path from configuration: ' +
                           str(training_data_loader.data_path), prefix=prefix)

    if FLAGS.validation_data_path and not validation_data_loader:
        validation_data_loader = default_data_loader(dataset_reader, validation_data_path)
    elif validation_data_loader is not None:
        # BUGFIX: the original else-branch dereferenced validation_data_loader
        # unconditionally, raising AttributeError when the configuration defined
        # no validation loader and FLAGS.validation_data_path was unset (e.g.
        # finetuning with only finetuning_validation_data_path given). When no
        # loader exists and none can be built, leave it as None.
        if validation_data_path:
            validation_data_loader.data_path = validation_data_path
        else:
            logger.warning('No validation data path provided - using the path from configuration: ' +
                           str(validation_data_loader.data_path), prefix=prefix)

    if not vocabulary:
        # NOTE(review): assumes build_vocabulary_from_instances tolerates a None
        # validation loader - verify against its definition.
        vocabulary = build_vocabulary_from_instances(training_data_loader, validation_data_loader, prefix)

    return dataset_reader, training_data_loader, validation_data_loader, vocabulary
def _read_property_from_config(property_key: str,
params: Dict[str, Any],
logging_prefix: str) -> Optional[Any]:
......@@ -167,7 +204,7 @@ def read_vocabulary_from_config(params: Dict[str, Any],
return vocabulary
def read_model_from_config(logging_prefix: str):
def read_model_from_config(logging_prefix: str) -> Optional[Tuple[Model, DatasetReader, DataLoader, DataLoader, Vocabulary]]:
try:
checks.file_exists(FLAGS.config_path)
except ConfigurationError as e:
......@@ -187,20 +224,29 @@ def read_model_from_config(logging_prefix: str):
validation_data_loader = read_data_loader_from_config(params, logging_prefix, validation=True)
vocabulary = read_vocabulary_from_config(params, logging_prefix)
dataset_reader, training_data_loader, validation_data_loader, vocabulary = get_defaults(
dataset_reader,
training_data_loader,
validation_data_loader,
vocabulary,
FLAGS.training_data_path if not FLAGS.finetuning else FLAGS.finetuning_training_data_path,
FLAGS.validation_data_path if not FLAGS.finetuning else FLAGS.finetuning_validation_data_path,
logging_prefix
)
pass_down_parameters = {'vocabulary': vocabulary}
if not FLAGS.use_pure_config:
pass_down_parameters['model_name'] = FLAGS.pretrained_transformer_name
logger.info('Resolving the model from parameters.', prefix=logging_prefix)
model = resolve(params['model'],
pass_down_parameters=pass_down_parameters)
model = resolve(params['model'], pass_down_parameters=pass_down_parameters)
return model, vocabulary, training_data_loader, validation_data_loader, dataset_reader
return model, dataset_reader, training_data_loader, validation_data_loader, vocabulary
def run(_):
if FLAGS.mode == 'train':
model, vocabulary, training_data_loader, validation_data_loader, dataset_reader = None, None, None, None, None
model, dataset_reader, training_data_loader, validation_data_loader, vocabulary = None, None, None, None, None
if not FLAGS.finetuning:
prefix = 'Training'
......@@ -208,7 +254,7 @@ def run(_):
if FLAGS.config_path:
logger.info(f'Reading parameters from configuration path {FLAGS.config_path}', prefix=prefix)
model, vocabulary, training_data_loader, validation_data_loader, dataset_reader = read_model_from_config(prefix)
model, dataset_reader, training_data_loader, validation_data_loader, vocabulary = read_model_from_config(prefix)
if FLAGS.use_pure_config and model is None:
logger.error('Error in configuration - model could not be read from parameters. ' +
......@@ -218,9 +264,6 @@ def run(_):
serialization_dir = tempfile.mkdtemp(prefix='combo', dir=FLAGS.serialization_dir)
training_data_path = FLAGS.training_data_path
validation_data_path = FLAGS.validation_data_path
else:
prefix = 'Finetuning'
......@@ -242,38 +285,15 @@ def run(_):
serialization_dir = tempfile.mkdtemp(prefix='combo', suffix='-finetuning', dir=FLAGS.serialization_dir)
training_data_path = FLAGS.finetuning_training_data_path
validation_data_path = FLAGS.finetuning_validation_data_path
if not dataset_reader and (FLAGS.test_data_path
or not training_data_loader
or (FLAGS.validation_data_path and not validation_data_loader)):
# Dataset reader is required to read training data and/or for training (and validation) data loader
dataset_reader = default_ud_dataset_reader(FLAGS.pretrained_transformer_name)
if not training_data_loader:
training_data_loader = default_data_loader(dataset_reader, training_data_path)
else:
if training_data_path:
training_data_loader.data_path = training_data_path
else:
logger.warning(f'No training data path provided - using the path from configuration: ' +
str(training_data_loader.data_path), prefix=prefix)
if FLAGS.validation_data_path and not validation_data_loader:
validation_data_loader = default_data_loader(dataset_reader, validation_data_path)
else:
if validation_data_path:
validation_data_loader.data_path = validation_data_path
else:
logger.warning(f'No validation data path provided - using the path from configuration: ' +
str(validation_data_loader.data_path), prefix=prefix)
if not vocabulary:
vocabulary = build_vocabulary_from_instances(training_data_loader, validation_data_loader, prefix)
if not model:
model = default_model(FLAGS.pretrained_transformer_name, vocabulary)
dataset_reader, training_data_loader, validation_data_loader, vocabulary = get_defaults(
dataset_reader,
training_data_loader,
validation_data_loader,
vocabulary,
FLAGS.finetuning_training_data_path,
FLAGS.finetuning_validation_data_path,
prefix
)
logger.info('Indexing training data loader', prefix=prefix)
training_data_loader.index_with(model.vocab)
......
Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment