diff --git a/combo/config/from_parameters.py b/combo/config/from_parameters.py index 84f66ef8f50e32e3f0b475a6c0b3ca91dbe2653a..4fb1b6b0ddfea399304f0507d4b0fb558c8f7788 100644 --- a/combo/config/from_parameters.py +++ b/combo/config/from_parameters.py @@ -167,6 +167,7 @@ class FromParameters: json.dump(self._to_params(pass_down_parameters), f) -def resolve(parameters: Dict[str, Any]) -> Any: +def resolve(parameters: Dict[str, Any], pass_down_parameters: Dict[str, Any] = None) -> Any: + pass_down_parameters = pass_down_parameters or {} clz, clz_init = Registry.resolve(parameters['type']) - return clz.from_parameters(parameters['parameters'], clz_init) + return clz.from_parameters(parameters['parameters'], clz_init, pass_down_parameters) diff --git a/combo/models/archival.py b/combo/models/archival.py index ee43a3a5e184b057e4aef97ab4fc20ba413204ee..6e337763ca3a84b25992a9074a72ed5e28c8add1 100644 --- a/combo/models/archival.py +++ b/combo/models/archival.py @@ -20,6 +20,7 @@ from combo.data.dataset_readers import DatasetReader from combo.modules.model import Model from combo.utils import ConfigurationError from combo.utils.file_utils import cached_path +from config import resolve logger = logging.getLogger(__name__) @@ -220,12 +221,8 @@ def _load_dataset_readers(config, serialization_dir): "validation_dataset_reader", dataset_reader_params.duplicate() ) - dataset_reader = Registry.resolve(DatasetReader, dataset_reader_params.get('type', 'base')).from_parameters( - dict(**dataset_reader_params, serialization_dir=serialization_dir) - ) - validation_dataset_reader = Registry.resolve(DatasetReader, validation_dataset_reader_params.get('type', 'base')).from_parameters( - dict(**validation_dataset_reader_params, serialization_dir=serialization_dir) - ) + dataset_reader = resolve(dataset_reader_params) + validation_dataset_reader = resolve(validation_dataset_reader_params) return dataset_reader, validation_dataset_reader diff --git a/combo/modules/model.py b/combo/modules/model.py index 1c898e37fa6000237119458b2055b9cbe0675fa0..e62bead90ced69abdc505baa8140bb67fd175afe 100644 --- a/combo/modules/model.py +++ b/combo/modules/model.py @@ -16,7 +16,7 @@ from overrides import overrides from combo.common import util from combo.common.params import remove_keys_from_params, Params from combo.config import FromParameters, Registry -from combo.config.from_parameters import register_arguments +from combo.config.from_parameters import register_arguments, resolve from combo.data import Vocabulary, Instance from combo.data.batch import Batch from combo.data.dataset_loaders.dataset_loader import TensorDict @@ -318,12 +318,8 @@ class Model(Module, FromParameters): # Load vocabulary from file vocab_dir = os.path.join(serialization_dir, "vocabulary") # If the config specifies a vocabulary subclass, we need to use it. - vocab_params = config.get("vocabulary", Params({})) - vocab_choice = vocab_params.pop_choice("type", list(Registry.classes()[Vocabulary].keys()), True) - vocab_class = Registry.resolve(Vocabulary, vocab_choice) - vocab = vocab_class.from_files( - vocab_dir, vocab_params.get("padding_token"), vocab_params.get("oov_token") - ) + vocab_params = config.get("vocabulary") + vocab = resolve(vocab_params) model_params = config.get("model") @@ -332,10 +328,7 @@ class Model(Module, FromParameters): # stored in our model. We don't need any pretrained weight file or initializers anymore, # and we don't want the code to look for it, so we remove it from the parameters here. remove_keys_from_params(model_params) - model_type = Registry.resolve(Model, model_params.get('type', 'semantic_multitask')) - model = model_type.from_parameters( - dict(**dict(model_params), vocabulary=vocab, serialization_dir=serialization_dir), {'vocabulary': vocab} - ) + model = resolve(model_params) # Force model to cpu or gpu, as appropriate, to make sure that the embeddings are # in sync with the weights @@ -428,7 +421,7 @@ class Model(Module, FromParameters): # Load using an overridable _load method. # This allows subclasses of Model to override _load. - model_class: Type[Model] = Registry.resolve(Model, model_type) # type: ignore + model_class: Type[Model] = Registry.resolve(model_type) # type: ignore if not isinstance(model_class, type): # If you're using from_archive to specify your model (e.g., for fine tuning), then you # can't currently override the behavior of _load; we just use the default Model._load. diff --git a/combo/modules/text_field_embedders/basic_text_field_embedder.py b/combo/modules/text_field_embedders/basic_text_field_embedder.py index 7587a101b8f7d42aa02feffe6ec6a64d07667a80..95e98e7d938bc937972ea0585bff7a8f17d6c57e 100644 --- a/combo/modules/text_field_embedders/basic_text_field_embedder.py +++ b/combo/modules/text_field_embedders/basic_text_field_embedder.py @@ -9,7 +9,7 @@ import torch from combo.common.params import Params from combo.config import Registry -from combo.config.from_parameters import register_arguments +from combo.config.from_parameters import register_arguments, resolve from combo.data.fields.text_field import TextFieldTensors from combo.modules.text_field_embedders.text_field_embedder import TextFieldEmbedder from combo.modules.time_distributed import TimeDistributed @@ -50,9 +50,7 @@ class BasicTextFieldEmbedder(TextFieldEmbedder): name = "token_embedder_%s" % key if isinstance(embedder, Params): embedder_params = dict(embedder) - embedder = Registry.resolve( - TokenEmbedder, embedder_params.get("type", "basic") - ).from_parameters(embedder_params) + embedder = resolve(embedder_params) self.add_module(name, embedder) self._ordered_embedder_keys = sorted(self._token_embedders.keys()) diff --git a/combo/nn/regularizers/regularizer.py b/combo/nn/regularizers/regularizer.py index b0ee895ab9aa9835e9147ea46f3585fc2d732487..5e6243730e1c7dec0c9b70febfbc578b3a5dda67 100644 --- a/combo/nn/regularizers/regularizer.py +++ b/combo/nn/regularizers/regularizer.py @@ -4,7 +4,7 @@ from typing import List, Tuple, Dict, Any import torch from combo.config import FromParameters, Registry -from combo.config.from_parameters import register_arguments +from combo.config.from_parameters import register_arguments, resolve from combo.nn.regularizers import Regularizer from overrides import overrides @@ -57,13 +57,8 @@ class Regularizer(FromParameters): if isinstance(regularizer_dict, dict): if 'type' not in regularizer_dict: raise ConfigurationError('Regularizer dict does not have the type field') - resolved_regularizer, resolved_regularizer_constr = Registry.resolve(regularizer_dict['type']) - regexes_to_pass.append((regex, - resolved_regularizer.from_parameters( - regularizer_dict.get('parameters', {}), - resolved_regularizer_constr, - pass_down_parameters - ))) + resolved_regularizer = resolve(regularizer_dict, pass_down_parameters) + regexes_to_pass.append((regex, resolved_regularizer)) else: regexes_to_pass.append((regex, regularizer_dict)) return cls(regexes_to_pass) diff --git a/combo/predict.py b/combo/predict.py index 361a096365c2deed5b0669ea7778048e4b718aba..e3b285ecf54764acd8e8341b701acd151959e620 100644 --- a/combo/predict.py +++ b/combo/predict.py @@ -10,7 +10,7 @@ from overrides import overrides from combo import data, models, common from combo.common import util from combo.config import Registry -from combo.config.from_parameters import register_arguments +from combo.config.from_parameters import register_arguments, resolve from combo.data import tokenizers, Instance, conllu2sentence, tokens2conllu, sentence2conllu from combo.data.dataset_loaders.dataset_loader import TensorDict from combo.data.dataset_readers.dataset_reader import DatasetReader @@ -263,10 +263,7 @@ class COMBO(PredictorModule): archive = models.load_archive(model_path, cuda_device=cuda_device) model = archive.model - dataset_reader_class = archive.config["dataset_reader"].get( - "type", "conllu") - dataset_reader = Registry.resolve( - DatasetReader, - dataset_reader_class - ).from_parameters(archive.config["dataset_reader"]) + dataset_reader = resolve( + archive.config["dataset_reader"] + ) return cls(model, dataset_reader, tokenizer, batch_size) diff --git a/tests/config/test_configuration.py b/tests/config/test_configuration.py index 30922c2935f77fef0c77b5d62ed71b2654473798..be5d9a090031684260b5cea1b21805abc6ae706b 100644 --- a/tests/config/test_configuration.py +++ b/tests/config/test_configuration.py @@ -50,16 +50,6 @@ class ConfigurationTest(unittest.TestCase): self.assertEqual(vocab.constructed_from, 'from_files') self.assertSetEqual(vocab.get_namespaces(), {'animal_tags', 'animals'}) - def test_save_from_files_vocabulary(self): - parameters = {'type': 'from_files_vocabulary', - 'parameters': { - 'directory': VOCABULARY_DIR - }} - vocab_type, constructor = Registry.resolve(parameters['type']) - vocab = vocab_type.from_parameters(parameters['parameters'], constructor) - vocab_params = vocab.serialize() - print(vocab_params) - def test_serialize(self): vocab = Vocabulary({'counter': {'test': 0}}, max_vocab_size=10) self.assertDictEqual(vocab.serialize(),