diff --git a/combo/config/from_parameters.py b/combo/config/from_parameters.py index 5f5580b1f3da8bea2ccb137097df729af2966c79..743c80ac18587ed741651388e5b0dff6815d7be6 100644 --- a/combo/config/from_parameters.py +++ b/combo/config/from_parameters.py @@ -216,8 +216,8 @@ def override_parameters(parameters: Dict[str, Any], override_values: Dict[str, A overriden_parameters = flatten_dictionary(parameters) override_values = flatten_dictionary(override_values) for ko, vo in override_values.items(): - if ko in overriden_parameters: - overriden_parameters[ko] = vo + #if ko in overriden_parameters: + overriden_parameters[ko] = vo return unflatten_dictionary(overriden_parameters) diff --git a/combo/main.py b/combo/main.py index 4f0a84669c5322535b33232fa456172f6eddea57..84571612e87109e007274c8d264976210e57daae 100755 --- a/combo/main.py +++ b/combo/main.py @@ -154,7 +154,11 @@ def get_defaults(dataset_reader: Optional[DatasetReader], or not training_data_loader or (FLAGS.validation_data_path and not validation_data_loader)): # Dataset reader is required to read training data and/or for training (and validation) data loader - dataset_reader = default_ud_dataset_reader(FLAGS.pretrained_transformer_name) + dataset_reader = default_ud_dataset_reader(FLAGS.pretrained_transformer_name, + tokenizer=LamboTokenizer(FLAGS.tokenizer_language, + default_turns=FLAGS.turns, + default_split_subwords=FLAGS.split_subwords) + ) if not training_data_loader: training_data_loader = default_data_loader(dataset_reader, training_data_path) @@ -399,9 +403,9 @@ def run(_): logger.info("No dataset reader in the configuration or archive file - using a default UD dataset reader", prefix=prefix) dataset_reader = default_ud_dataset_reader(FLAGS.pretrained_transformer_name, - LamboTokenizer(tokenizer_language, - default_turns=FLAGS.turns, - default_split_subwords=FLAGS.split_subwords) + tokenizer=LamboTokenizer(tokenizer_language, + default_turns=FLAGS.turns, + default_split_subwords=FLAGS.split_subwords) ) predictor = COMBO(model, dataset_reader) @@ -491,7 +495,12 @@ def _get_ext_vars(finetuning: bool = False) -> Dict: "reader": { "parameters": { "features": FLAGS.features, - "targets": FLAGS.targets + "targets": FLAGS.targets, + "tokenizer": { + "parameters": { + "language": FLAGS.tokenizer_language + } + } } } } @@ -504,7 +513,12 @@ def _get_ext_vars(finetuning: bool = False) -> Dict: "reader": { "parameters": { "features": FLAGS.features, - "targets": FLAGS.targets + "targets": FLAGS.targets, + "tokenizer": { + "parameters": { + "language": FLAGS.tokenizer_language + } + } } } } @@ -512,7 +526,12 @@ def _get_ext_vars(finetuning: bool = False) -> Dict: "dataset_reader": { "parameters": { "features": FLAGS.features, - "targets": FLAGS.targets + "targets": FLAGS.targets, + "tokenizer": { + "parameters": { + "language": FLAGS.tokenizer_language + } + } } } } @@ -526,7 +545,6 @@ def _get_ext_vars(finetuning: bool = False) -> Dict: "oov_token": "_" } } - return to_override