Commit 9821e7ae authored by Maja Jablonska

Minor fixes to CLI prediction

parent 109e0255
1 merge request: !46 Merge COMBO 3.0 into master
@@ -142,6 +142,7 @@ class FromParameters:
     def _to_params(self, pass_down_parameter_names: List[str] = None) -> Dict[str, str]:
         parameters_to_serialize = self.constructed_args or {}
         pass_down_parameter_names = pass_down_parameter_names or []
+        parameters_dict = {}
         for pn, param_value in parameters_to_serialize.items():
             if pn in pass_down_parameter_names:
@@ -151,8 +152,6 @@ class FromParameters:
         return parameters_dict

     def serialize(self, pass_down_parameter_names: List[str] = None) -> Dict[str, Any]:
-        pass_down_parameter_names = pass_down_parameter_names or []
         constructor_method = self.constructed_from if self.constructed_from else '__init__'
         if not getattr(self, constructor_method):
             raise ConfigurationError('Class ' + str(type(self)) + ' has no constructor method ' + constructor_method)
...
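For orientation, here is a minimal, self-contained sketch of the serialization pattern the two hunks above touch: constructor arguments are collected into a freshly initialised dictionary, and any names the caller passes down itself are skipped. The class and attribute names below are illustrative stand-ins, not the repository's API. The remaining hunks belong to the CLI training entry point (the module that defines run() and the absl FLAGS).

    from typing import Any, Dict, List, Optional

    class Configurable:
        """Illustrative stand-in for a FromParameters-style object (not repository code)."""

        def __init__(self, constructed_args: Optional[Dict[str, Any]] = None):
            self.constructed_args = constructed_args

        def to_params(self, pass_down_parameter_names: Optional[List[str]] = None) -> Dict[str, Any]:
            parameters_to_serialize = self.constructed_args or {}
            pass_down_parameter_names = pass_down_parameter_names or []
            parameters_dict = {}  # initialised up front, before the loop fills it
            for name, value in parameters_to_serialize.items():
                if name in pass_down_parameter_names:
                    continue  # passed-down objects (e.g. a vocabulary) are supplied by the caller
                parameters_dict[name] = value
            return parameters_dict

    print(Configurable({'hidden_size': 128, 'vocabulary': object()}).to_params(['vocabulary']))
    # {'hidden_size': 128}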
@@ -4,7 +4,7 @@ import os
 import pathlib
 import tempfile
 from itertools import chain
-from typing import Dict, Optional, Any
+from typing import Dict, Optional, Any, Tuple

 import torch
 from absl import app, flags
@@ -24,6 +24,7 @@ from config import override_parameters
 from config.from_parameters import override_or_add_parameters
 from data import LamboTokenizer, Sentence, Vocabulary, DatasetReader
 from data.dataset_loaders import DataLoader
+from modules.model import Model
 from utils import ConfigurationError

 logging.setLoggerClass(ComboLogger)
@@ -127,6 +128,42 @@ def build_vocabulary_from_instances(training_data_loader: DataLoader,
     return vocabulary


+def get_defaults(dataset_reader: Optional[DatasetReader],
+                 training_data_loader: Optional[DataLoader],
+                 validation_data_loader: Optional[DataLoader],
+                 vocabulary: Optional[Vocabulary],
+                 training_data_path: str,
+                 validation_data_path: str,
+                 prefix: str) -> Tuple[DatasetReader, DataLoader, DataLoader, Vocabulary]:
+    if not dataset_reader and (FLAGS.test_data_path
+                               or not training_data_loader
+                               or (FLAGS.validation_data_path and not validation_data_loader)):
+        # Dataset reader is required to read training data and/or for training (and validation) data loader
+        dataset_reader = default_ud_dataset_reader(FLAGS.pretrained_transformer_name)
+
+    if not training_data_loader:
+        training_data_loader = default_data_loader(dataset_reader, training_data_path)
+    else:
+        if training_data_path:
+            training_data_loader.data_path = training_data_path
+        else:
+            logger.warning(f'No training data path provided - using the path from configuration: ' +
+                           str(training_data_loader.data_path), prefix=prefix)
+
+    if FLAGS.validation_data_path and not validation_data_loader:
+        validation_data_loader = default_data_loader(dataset_reader, validation_data_path)
+    else:
+        if validation_data_path:
+            validation_data_loader.data_path = validation_data_path
+        else:
+            logger.warning(f'No validation data path provided - using the path from configuration: ' +
+                           str(validation_data_loader.data_path), prefix=prefix)
+
+    if not vocabulary:
+        vocabulary = build_vocabulary_from_instances(training_data_loader, validation_data_loader, prefix)
+
+    return dataset_reader, training_data_loader, validation_data_loader, vocabulary
+
+
 def _read_property_from_config(property_key: str,
                                params: Dict[str, Any],
                                logging_prefix: str) -> Optional[Any]:
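The get_defaults helper added above centralises fall-back logic that previously lived inline in run(): build a default reader or loader when the configuration provides none, and prefer an explicitly supplied data path over the configured one (logging a warning when only the configured path is available). A self-contained sketch of that defaulting pattern, with illustrative stand-in names rather than the repository's classes:

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class Loader:
        data_path: Optional[str] = None

    def loader_with_defaults(loader: Optional[Loader], explicit_path: Optional[str]) -> Loader:
        if loader is None:
            # Nothing configured: build a default loader around the explicit path.
            return Loader(data_path=explicit_path)
        if explicit_path:
            # An explicitly provided path overrides whatever the configuration had.
            loader.data_path = explicit_path
        # Otherwise keep the configured path (get_defaults logs a warning in this case).
        return loader

    print(loader_with_defaults(None, 'train.conllu'))
    print(loader_with_defaults(Loader('configured.conllu'), None))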
@@ -167,7 +204,7 @@ def read_vocabulary_from_config(params: Dict[str, Any],
     return vocabulary


-def read_model_from_config(logging_prefix: str):
+def read_model_from_config(logging_prefix: str) -> Optional[Tuple[Model, DatasetReader, DataLoader, DataLoader, Vocabulary]]:
     try:
         checks.file_exists(FLAGS.config_path)
     except ConfigurationError as e:
@@ -187,20 +224,29 @@
     validation_data_loader = read_data_loader_from_config(params, logging_prefix, validation=True)
     vocabulary = read_vocabulary_from_config(params, logging_prefix)

+    dataset_reader, training_data_loader, validation_data_loader, vocabulary = get_defaults(
+        dataset_reader,
+        training_data_loader,
+        validation_data_loader,
+        vocabulary,
+        FLAGS.training_data_path if not FLAGS.finetuning else FLAGS.finetuning_training_data_path,
+        FLAGS.validation_data_path if not FLAGS.finetuning else FLAGS.finetuning_validation_data_path,
+        logging_prefix
+    )
+
     pass_down_parameters = {'vocabulary': vocabulary}
     if not FLAGS.use_pure_config:
         pass_down_parameters['model_name'] = FLAGS.pretrained_transformer_name

     logger.info('Resolving the model from parameters.', prefix=logging_prefix)
-    model = resolve(params['model'],
-                    pass_down_parameters=pass_down_parameters)
+    model = resolve(params['model'], pass_down_parameters=pass_down_parameters)

-    return model, vocabulary, training_data_loader, validation_data_loader, dataset_reader
+    return model, dataset_reader, training_data_loader, validation_data_loader, vocabulary


 def run(_):
     if FLAGS.mode == 'train':
-        model, vocabulary, training_data_loader, validation_data_loader, dataset_reader = None, None, None, None, None
+        model, dataset_reader, training_data_loader, validation_data_loader, vocabulary = None, None, None, None, None

         if not FLAGS.finetuning:
             prefix = 'Training'
@@ -208,7 +254,7 @@ def run(_):
             if FLAGS.config_path:
                 logger.info(f'Reading parameters from configuration path {FLAGS.config_path}', prefix=prefix)
-                model, vocabulary, training_data_loader, validation_data_loader, dataset_reader = read_model_from_config(prefix)
+                model, dataset_reader, training_data_loader, validation_data_loader, vocabulary = read_model_from_config(prefix)
                 if FLAGS.use_pure_config and model is None:
                     logger.error('Error in configuration - model could not be read from parameters. ' +
@@ -218,9 +264,6 @@
             serialization_dir = tempfile.mkdtemp(prefix='combo', dir=FLAGS.serialization_dir)

-            training_data_path = FLAGS.training_data_path
-            validation_data_path = FLAGS.validation_data_path
-
         else:
             prefix = 'Finetuning'
@@ -242,38 +285,15 @@
             serialization_dir = tempfile.mkdtemp(prefix='combo', suffix='-finetuning', dir=FLAGS.serialization_dir)

-            training_data_path = FLAGS.finetuning_training_data_path
-            validation_data_path = FLAGS.finetuning_validation_data_path
-
-        if not dataset_reader and (FLAGS.test_data_path
-                                   or not training_data_loader
-                                   or (FLAGS.validation_data_path and not validation_data_loader)):
-            # Dataset reader is required to read training data and/or for training (and validation) data loader
-            dataset_reader = default_ud_dataset_reader(FLAGS.pretrained_transformer_name)
-
-        if not training_data_loader:
-            training_data_loader = default_data_loader(dataset_reader, training_data_path)
-        else:
-            if training_data_path:
-                training_data_loader.data_path = training_data_path
-            else:
-                logger.warning(f'No training data path provided - using the path from configuration: ' +
-                               str(training_data_loader.data_path), prefix=prefix)
-
-        if FLAGS.validation_data_path and not validation_data_loader:
-            validation_data_loader = default_data_loader(dataset_reader, validation_data_path)
-        else:
-            if validation_data_path:
-                validation_data_loader.data_path = validation_data_path
-            else:
-                logger.warning(f'No validation data path provided - using the path from configuration: ' +
-                               str(validation_data_loader.data_path), prefix=prefix)
-
-        if not vocabulary:
-            vocabulary = build_vocabulary_from_instances(training_data_loader, validation_data_loader, prefix)
+            dataset_reader, training_data_loader, validation_data_loader, vocabulary = get_defaults(
+                dataset_reader,
+                training_data_loader,
+                validation_data_loader,
+                vocabulary,
+                FLAGS.finetuning_training_data_path,
+                FLAGS.finetuning_validation_data_path,
+                prefix
+            )

         if not model:
             model = default_model(FLAGS.pretrained_transformer_name, vocabulary)

         logger.info('Indexing training data loader', prefix=prefix)
         training_data_loader.index_with(model.vocab)
...
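One more note on the reordered return tuples in read_model_from_config and run: positional unpacking has to mirror the producer's order exactly, otherwise objects are silently bound to the wrong names. A tiny self-contained illustration of the failure mode the consistent ordering avoids (all names are illustrative):

    def produce():
        model, vocabulary, reader = 'model', 'vocab', 'reader'
        return model, vocabulary, reader  # producer's order

    # Unpacking in a different order silently swaps the values:
    model, reader, vocabulary = produce()
    assert reader == 'vocab' and vocabulary == 'reader'  # wrong objects, no exception raised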