Skip to content
Snippets Groups Projects
Commit 1e1002f5 authored by Maja Jablonska's avatar Maja Jablonska
Browse files

Add an option to build a vocabulary from either training, validation sets, or both

parent a8613d6c
1 merge request!46Merge COMBO 3.0 into master
......@@ -27,7 +27,6 @@ from combo.modules.model import Model
from combo.utils import ConfigurationError
from combo.utils.matrices import extract_combo_matrices
logging.setLoggerClass(ComboLogger)
logger = logging.getLogger(__name__)
_FEATURES = ["token", "char", "upostag", "xpostag", "lemma", "feats"]
......@@ -81,6 +80,8 @@ flags.DEFINE_string(name="config_path", default="",
help="Config file path.")
flags.DEFINE_boolean(name="save_matrices", default=True,
help="Save relation distribution matrices.")
flags.DEFINE_list(name="datasets_for_vocabulary", default=["train"],
help="")
# Finetune after training flags
flags.DEFINE_string(name="finetuning_training_data_path", default="",
......@@ -115,10 +116,19 @@ def build_vocabulary_from_instances(training_data_loader: DataLoader,
validation_data_loader: Optional[DataLoader],
logging_prefix: str) -> Vocabulary:
logger.info('Building a vocabulary from instances.', prefix=logging_prefix)
instances = chain(training_data_loader.iter_instances(),
validation_data_loader.iter_instances()) \
if validation_data_loader \
else training_data_loader.iter_instances()
if "train" in FLAGS.datasets_for_vocabulary and "valid" in FLAGS.datasets_for_vocabulary:
instances = chain(training_data_loader.iter_instances(),
validation_data_loader.iter_instances()) \
if validation_data_loader \
else training_data_loader.iter_instances()
elif "train" in FLAGS.datasets_for_vocabulary:
instances = training_data_loader.iter_instances()
elif "valid" in FLAGS.datasets_for_vocabulary:
instances = validation_data_loader.iter_instances()
else:
logger.error("train and valid are the only allowed values for --datasets_for_vocabulary!",
prefix=logging_prefix)
raise ValueError("train and valid are the only allowed values for --datasets_for_vocabulary!")
vocabulary = Vocabulary.from_instances_extended(
instances,
non_padded_namespaces=['head_labels'],
......@@ -165,6 +175,7 @@ def get_defaults(dataset_reader: Optional[DatasetReader],
return dataset_reader, training_data_loader, validation_data_loader, vocabulary
def _read_property_from_config(property_key: str,
params: Dict[str, Any],
logging_prefix: str,
......@@ -212,13 +223,19 @@ def read_vocabulary_from_config(params: Dict[str, Any],
return vocabulary
def read_model_from_config(logging_prefix: str) -> Optional[Tuple[Model, DatasetReader, DataLoader, DataLoader, Vocabulary]]:
def read_model_from_config(logging_prefix: str) -> Optional[
Tuple[Model, DatasetReader, DataLoader, DataLoader, Vocabulary]]:
try:
checks.file_exists(FLAGS.config_path)
except ConfigurationError as e:
handle_error(e, logging_prefix)
return
if FLAGS.serialization_dir is None:
logger.error(f'--serialization_dir was not passed as an argument!')
print(f'--serialization_dir was not passed as an argument!')
return
with open(FLAGS.config_path, 'r') as f:
params = json.load(f)
......@@ -235,12 +252,14 @@ def read_model_from_config(logging_prefix: str) -> Optional[Tuple[Model, Dataset
dataset_reader = read_dataset_reader_from_config(params, logging_prefix, pass_down_parameters)
training_data_loader = read_data_loader_from_config(params, logging_prefix,
validation=False, pass_down_parameters=pass_down_parameters)
if (not FLAGS.validation_data_path or not FLAGS.finetuning_validation_data_path) and 'validation_data_loader' in params:
if (
not FLAGS.validation_data_path or not FLAGS.finetuning_validation_data_path) and 'validation_data_loader' in params:
logger.warning('Validation data loader is in parameters, but no validation data path was provided!')
validation_data_loader = None
else:
validation_data_loader = read_data_loader_from_config(params, logging_prefix,
validation=True, pass_down_parameters=pass_down_parameters)
validation=True,
pass_down_parameters=pass_down_parameters)
vocabulary = read_vocabulary_from_config(params, logging_prefix, pass_down_parameters)
dataset_reader, training_data_loader, validation_data_loader, vocabulary = get_defaults(
......@@ -273,7 +292,8 @@ def run(_):
if FLAGS.config_path:
logger.info(f'Reading parameters from configuration path {FLAGS.config_path}', prefix=prefix)
model, dataset_reader, training_data_loader, validation_data_loader, vocabulary = read_model_from_config(prefix)
model, dataset_reader, training_data_loader, validation_data_loader, vocabulary = read_model_from_config(
prefix)
else:
dataset_reader, training_data_loader, validation_data_loader, vocabulary = get_defaults(
dataset_reader,
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment