Skip to content
Snippets Groups Projects
Commit b8d64d27 authored by Maja Jablonska's avatar Maja Jablonska
Browse files

Fixes to serialization and registry

parent 3038c5c3
Branches
Tags
1 merge request: !46 "Merge COMBO 3.0 into master"
......@@ -167,6 +167,7 @@ class FromParameters:
json.dump(self._to_params(pass_down_parameters), f)
def resolve(parameters: Dict[str, Any]) -> Any:
def resolve(parameters: Dict[str, Any], pass_down_parameters: Dict[str, Any] = None) -> Any:
pass_down_parameters = pass_down_parameters or {}
clz, clz_init = Registry.resolve(parameters['type'])
return clz.from_parameters(parameters['parameters'], clz_init)
return clz.from_parameters(parameters['parameters'], clz_init, pass_down_parameters)
......@@ -20,6 +20,7 @@ from combo.data.dataset_readers import DatasetReader
from combo.modules.model import Model
from combo.utils import ConfigurationError
from combo.utils.file_utils import cached_path
from config import resolve
logger = logging.getLogger(__name__)
......@@ -220,12 +221,8 @@ def _load_dataset_readers(config, serialization_dir):
"validation_dataset_reader", dataset_reader_params.duplicate()
)
dataset_reader = Registry.resolve(DatasetReader, dataset_reader_params.get('type', 'base')).from_parameters(
dict(**dataset_reader_params, serialization_dir=serialization_dir)
)
validation_dataset_reader = Registry.resolve(DatasetReader, validation_dataset_reader_params.get('type', 'base')).from_parameters(
dict(**validation_dataset_reader_params, serialization_dir=serialization_dir)
)
dataset_reader = resolve(dataset_reader_params)
validation_dataset_reader = resolve(validation_dataset_reader_params)
return dataset_reader, validation_dataset_reader
......
......@@ -16,7 +16,7 @@ from overrides import overrides
from combo.common import util
from combo.common.params import remove_keys_from_params, Params
from combo.config import FromParameters, Registry
from combo.config.from_parameters import register_arguments
from combo.config.from_parameters import register_arguments, resolve
from combo.data import Vocabulary, Instance
from combo.data.batch import Batch
from combo.data.dataset_loaders.dataset_loader import TensorDict
......@@ -318,12 +318,8 @@ class Model(Module, FromParameters):
# Load vocabulary from file
vocab_dir = os.path.join(serialization_dir, "vocabulary")
# If the config specifies a vocabulary subclass, we need to use it.
vocab_params = config.get("vocabulary", Params({}))
vocab_choice = vocab_params.pop_choice("type", list(Registry.classes()[Vocabulary].keys()), True)
vocab_class = Registry.resolve(Vocabulary, vocab_choice)
vocab = vocab_class.from_files(
vocab_dir, vocab_params.get("padding_token"), vocab_params.get("oov_token")
)
vocab_params = config.get("vocabulary")
vocab = resolve(vocab_params)
model_params = config.get("model")
......@@ -332,10 +328,7 @@ class Model(Module, FromParameters):
# stored in our model. We don't need any pretrained weight file or initializers anymore,
# and we don't want the code to look for it, so we remove it from the parameters here.
remove_keys_from_params(model_params)
model_type = Registry.resolve(Model, model_params.get('type', 'semantic_multitask'))
model = model_type.from_parameters(
dict(**dict(model_params), vocabulary=vocab, serialization_dir=serialization_dir), {'vocabulary': vocab}
)
model = resolve(model_params)
# Force model to cpu or gpu, as appropriate, to make sure that the embeddings are
# in sync with the weights
......@@ -428,7 +421,7 @@ class Model(Module, FromParameters):
# Load using an overridable _load method.
# This allows subclasses of Model to override _load.
model_class: Type[Model] = Registry.resolve(Model, model_type) # type: ignore
model_class: Type[Model] = Registry.resolve(model_type) # type: ignore
if not isinstance(model_class, type):
# If you're using from_archive to specify your model (e.g., for fine tuning), then you
# can't currently override the behavior of _load; we just use the default Model._load.
......
......@@ -9,7 +9,7 @@ import torch
from combo.common.params import Params
from combo.config import Registry
from combo.config.from_parameters import register_arguments
from combo.config.from_parameters import register_arguments, resolve
from combo.data.fields.text_field import TextFieldTensors
from combo.modules.text_field_embedders.text_field_embedder import TextFieldEmbedder
from combo.modules.time_distributed import TimeDistributed
......@@ -50,9 +50,7 @@ class BasicTextFieldEmbedder(TextFieldEmbedder):
name = "token_embedder_%s" % key
if isinstance(embedder, Params):
embedder_params = dict(embedder)
embedder = Registry.resolve(
TokenEmbedder, embedder_params.get("type", "basic")
).from_parameters(embedder_params)
embedder = resolve(embedder_params)
self.add_module(name, embedder)
self._ordered_embedder_keys = sorted(self._token_embedders.keys())
......
......@@ -4,7 +4,7 @@ from typing import List, Tuple, Dict, Any
import torch
from combo.config import FromParameters, Registry
from combo.config.from_parameters import register_arguments
from combo.config.from_parameters import register_arguments, resolve
from combo.nn.regularizers import Regularizer
from overrides import overrides
......@@ -57,13 +57,8 @@ class Regularizer(FromParameters):
if isinstance(regularizer_dict, dict):
if 'type' not in regularizer_dict:
raise ConfigurationError('Regularizer dict does not have the type field')
resolved_regularizer, resolved_regularizer_constr = Registry.resolve(regularizer_dict['type'])
regexes_to_pass.append((regex,
resolved_regularizer.from_parameters(
regularizer_dict.get('parameters', {}),
resolved_regularizer_constr,
pass_down_parameters
)))
resolved_regularizer = resolve(regularizer_dict, pass_down_parameters)
regexes_to_pass.append((regex, resolved_regularizer))
else:
regexes_to_pass.append((regex, regularizer_dict))
return cls(regexes_to_pass)
......@@ -10,7 +10,7 @@ from overrides import overrides
from combo import data, models, common
from combo.common import util
from combo.config import Registry
from combo.config.from_parameters import register_arguments
from combo.config.from_parameters import register_arguments, resolve
from combo.data import tokenizers, Instance, conllu2sentence, tokens2conllu, sentence2conllu
from combo.data.dataset_loaders.dataset_loader import TensorDict
from combo.data.dataset_readers.dataset_reader import DatasetReader
......@@ -263,10 +263,7 @@ class COMBO(PredictorModule):
archive = models.load_archive(model_path, cuda_device=cuda_device)
model = archive.model
dataset_reader_class = archive.config["dataset_reader"].get(
"type", "conllu")
dataset_reader = Registry.resolve(
DatasetReader,
dataset_reader_class
).from_parameters(archive.config["dataset_reader"])
dataset_reader = resolve(
archive.config["dataset_reader"]
)
return cls(model, dataset_reader, tokenizer, batch_size)
......@@ -50,16 +50,6 @@ class ConfigurationTest(unittest.TestCase):
self.assertEqual(vocab.constructed_from, 'from_files')
self.assertSetEqual(vocab.get_namespaces(), {'animal_tags', 'animals'})
def test_save_from_files_vocabulary(self):
parameters = {'type': 'from_files_vocabulary',
'parameters': {
'directory': VOCABULARY_DIR
}}
vocab_type, constructor = Registry.resolve(parameters['type'])
vocab = vocab_type.from_parameters(parameters['parameters'], constructor)
vocab_params = vocab.serialize()
print(vocab_params)
def test_serialize(self):
vocab = Vocabulary({'counter': {'test': 0}}, max_vocab_size=10)
self.assertDictEqual(vocab.serialize(),
......
0% loaded — or try reloading the page.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.