Commit baa4b609 authored by Maja Jablonska

Serialization

parent 0fcda040
1 merge request: !46 Merge COMBO 3.0 into master
import inspect
from typing import Any, Callable, Dict, List, Optional
import typing
import functools
import json
@@ -24,9 +25,6 @@ def _resolve(values: typing.Union[Dict[str, Any], str], pass_down_parameters: Di
    if isinstance(values, Params):
        values = values.as_dict()
-    tt = values.get("type", "?") if isinstance(values, dict) else values
-    print(f'Resolving {tt} with pass_down_parameters {pass_down_parameters}')
    if isinstance(values, list):
        return [_resolve(v, pass_down_parameters) for v in values]
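For orientation, _resolve walks a nested configuration structure recursively so that every {'type': ..., 'parameters': ...} node can later be instantiated through the registry. A toy sketch of just the recursion pattern (registry lookup elided; this is an illustration, not the COMBO implementation):

from typing import Any

def resolve_config(values: Any) -> Any:
    # Lists resolve element-wise, dicts value-wise; leaves pass through.
    if isinstance(values, list):
        return [resolve_config(v) for v in values]
    if isinstance(values, dict):
        return {k: resolve_config(v) for k, v in values.items()}
    return values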
@@ -68,6 +66,7 @@ def serialize_single_value(value: Any, pass_down_parameter_names: List[str] = No
def register_arguments(func: callable):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        self_arg = args[0]
        if self_arg.constructed_args is None:
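The @register_arguments decorator is what makes the later serialization possible: it records the arguments a constructor was called with on the instance itself. A minimal sketch of the idea, assuming a constructed_args attribute as in the snippet above (an illustration, not the COMBO implementation):

import functools
import inspect

def register_arguments(func: callable):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        self_arg = args[0]
        # Bind positional and keyword arguments to parameter names (minus
        # `self`) so they can be dumped back out as a config dict later.
        bound = inspect.signature(func).bind(*args, **kwargs)
        bound.apply_defaults()
        self_arg.constructed_args = {
            k: v for k, v in bound.arguments.items() if k != 'self'
        }
        return func(*args, **kwargs)
    return wrapper

class Example:
    @register_arguments
    def __init__(self, hidden_size: int = 100):
        pass

print(Example(hidden_size=32).constructed_args)  # {'hidden_size': 32}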
@@ -131,7 +131,7 @@ class UniversalDependenciesDatasetReader(DatasetReader, ABC):
            yield self.text_to_instance(annotation)

-    def text_to_instance(self, tree: conllu.TokenList) -> Instance:
+    def text_to_instance(self, tree: conllu.models.TokenList) -> Instance:
        fields_: Dict[str, Field] = {}
        tokens = [Token.from_conllu_token(t) for t in tree if isinstance(t["id"], int)]
@@ -8,7 +8,6 @@ from typing import Dict, Optional, Iterable, Set, List, Union, Any
import logging
from filelock import FileLock
from overrides import overrides
-from transformers import PreTrainedTokenizer
from combo.common import Tqdm
@@ -118,9 +117,9 @@ class _NamespaceDependentDefaultDict(defaultdict[str, NamespaceVocabulary]):
@Registry.register("base_vocabulary")
@Registry.register("from_files_vocabulary", "from_files")
@Registry.register("from_pretrained_transformer_vocabulary", "from_pretrained_transformer")
@Registry.register("from_instances_vocabulary", "from_instances")
@Registry.register("from_data_loader_vocabulary", "from_data_loader")
@Registry.register("from_pretrained_transformer_and_instances_vocabulary", "from_pretrained_transformer_and_instances")
@Registry.register("from_instances_extended_vocabulary", "from_instances_extended")
@Registry.register("from_data_loader_extended_vocabulary", "from_data_loader_extended")
class Vocabulary(FromParameters):
    @register_arguments
    def __init__(self,
@@ -306,6 +305,37 @@ class Vocabulary(FromParameters):
+    @classmethod
+    @register_arguments
+    def from_data_loader(
+            cls,
+            data_loader: "DataLoader",
+            min_count: Dict[str, int] = None,
+            max_vocab_size: Union[int, Dict[str, int]] = None,
+            non_padded_namespaces: Iterable[str] = DEFAULT_NON_PADDED_NAMESPACES,
+            pretrained_files: Optional[Dict[str, str]] = None,
+            only_include_pretrained_words: bool = False,
+            tokens_to_add: Dict[str, List[str]] = None,
+            min_pretrained_embeddings: Dict[str, int] = None,
+            padding_token: Optional[str] = DEFAULT_PADDING_TOKEN,
+            oov_token: Optional[str] = DEFAULT_OOV_TOKEN,
+            serialization_dir: Optional[str] = None
+    ) -> "Vocabulary":
+        vocab = cls.from_instances(
+            instances=data_loader.iter_instances(),
+            min_count=min_count,
+            max_vocab_size=max_vocab_size,
+            non_padded_namespaces=non_padded_namespaces,
+            pretrained_files=pretrained_files,
+            only_include_pretrained_words=only_include_pretrained_words,
+            tokens_to_add=tokens_to_add,
+            min_pretrained_embeddings=min_pretrained_embeddings,
+            padding_token=padding_token,
+            oov_token=oov_token,
+            serialization_dir=serialization_dir
+        )
+        vocab.constructed_from = 'from_data_loader'
+        return vocab
    @classmethod
    def from_instances(
            cls,
            instances: Iterable["Instance"],
@@ -349,7 +379,6 @@ class Vocabulary(FromParameters):
            oov_token=oov_token,
            serialization_dir=serialization_dir
        )
-        vocab.constructed_from = 'from_instances'
        return vocab
    @classmethod
@@ -389,8 +418,7 @@ class Vocabulary(FromParameters):
            pretrained_files=pretrained_files,
            only_include_pretrained_words=only_include_pretrained_words,
            tokens_to_add=tokens_to_add,
-            min_pretrained_embeddings=min_pretrained_embeddings,
-            serialization_dir=serialization_dir
+            min_pretrained_embeddings=min_pretrained_embeddings
        )
        vocab.constructed_from = 'from_files_and_instances'
        return vocab
@@ -649,6 +677,36 @@ class Vocabulary(FromParameters):
+    @classmethod
+    @register_arguments
+    def from_data_loader_extended(
+            cls,
+            data_loader: "DataLoader",
+            min_count: Dict[str, int] = None,
+            max_vocab_size: Union[int, Dict[str, int]] = None,
+            non_padded_namespaces: Iterable[str] = DEFAULT_NON_PADDED_NAMESPACES,
+            pretrained_files: Optional[Dict[str, str]] = None,
+            only_include_pretrained_words: bool = False,
+            min_pretrained_embeddings: Dict[str, int] = None,
+            padding_token: Optional[str] = DEFAULT_PADDING_TOKEN,
+            oov_token: Optional[str] = DEFAULT_OOV_TOKEN,
+            serialization_dir: Optional[str] = None
+    ) -> "Vocabulary":
+        vocab = cls.from_instances_extended(
+            instances=data_loader.iter_instances(),
+            min_count=min_count,
+            max_vocab_size=max_vocab_size,
+            non_padded_namespaces=non_padded_namespaces,
+            pretrained_files=pretrained_files,
+            only_include_pretrained_words=only_include_pretrained_words,
+            min_pretrained_embeddings=min_pretrained_embeddings,
+            padding_token=padding_token,
+            oov_token=oov_token,
+            serialization_dir=serialization_dir
+        )
+        vocab.constructed_from = 'from_data_loader_extended'
+        return vocab
    @classmethod
    def from_instances_extended(
            cls,
            instances: Iterable["Instance"],
@@ -695,17 +753,6 @@ class Vocabulary(FromParameters):
        vocab.constructed_from = 'from_instances_extended'
        return vocab
-    @overrides
-    def serialize(self, pass_down_parameter_names: List[str] = None) -> Dict[str, Any]:
-        if self._serialization_dir is None:
-            raise ConfigurationError("To serialize a vocabulary, serialization_dir needs to be provided")
-        self.save_to_files(self._serialization_dir)
-        return {'type': 'from_files_vocabulary',
-                'parameters': {
-                    'directory': self._serialization_dir,
-                    'padding_token': self._padding_token,
-                    'oov_token': self._oov_token
-                }}
def get_slices_if_not_provided(vocab: Vocabulary):
    if hasattr(vocab, "slices"):
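The two new from_data_loader* constructors mirror the existing from_instances* variants but pull instances from a data loader. A hedged usage sketch; the in-memory loader below is a hypothetical stand-in, since the method body above only requires an iter_instances() method:

from combo.data import Vocabulary

class InMemoryLoader:
    # Hypothetical stand-in for a combo DataLoader; from_data_loader only
    # calls iter_instances() on it.
    def __init__(self, instances):
        self._instances = instances

    def iter_instances(self):
        yield from self._instances

loader = InMemoryLoader([])  # real Instance objects would go here
vocab = Vocabulary.from_data_loader(
    data_loader=loader,
    min_count={"tokens": 2},  # drop tokens seen fewer than twice
)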
@@ -17,7 +17,7 @@ from combo.modules.token_embedders import CharacterBasedWordEmbedder, Transforme
from combo.nn.activations import ReLUActivation, TanhActivation, LinearActivation
from combo.modules import FeedForwardPredictor
from combo.nn.base import Linear
-from combo.nn.regularizers import RegularizerApplicator
+from combo.nn.regularizers import Regularizer
from combo.nn.regularizers.regularizers import L2Regularizer
@@ -153,7 +153,7 @@ def default_model(vocabulary: Vocabulary) -> ComboModel:
            num_layers=2,
            vocab_namespace="feats_labels"
        ),
-        regularizer=RegularizerApplicator([
+        regularizer=Regularizer([
            (".*conv1d.*", L2Regularizer(1e-6)),
            (".*forward.*", L2Regularizer(1e-6)),
            (".*backward.*", L2Regularizer(1e-6)),
@@ -14,7 +14,7 @@ from combo.modules.parser import DependencyRelationModel
from combo.modules import TextFieldEmbedder
from combo.modules.model import Model
from combo.modules.seq2seq_encoders.seq2seq_encoder import Seq2SeqEncoder
-from combo.nn import RegularizerApplicator, base
+from combo.nn import Regularizer, base
from combo.nn.utils import get_text_field_mask
from combo.utils import metrics
@@ -40,7 +40,7 @@ class ComboModel(Model, FromParameters):
                 morphological_feat: MorphologicalFeatures = None,
                 dependency_relation: DependencyRelationModel = None,
                 enhanced_dependency_relation: DependencyRelationModel = None,
-                 regularizer: RegularizerApplicator = None,
+                 regularizer: Regularizer = None,
                 serialization_dir: Optional[str] = None) -> None:
        super().__init__(vocabulary, regularizer, serialization_dir)
@@ -21,7 +21,7 @@ from combo.data import Vocabulary, Instance
from combo.data.batch import Batch
from combo.data.dataset_loaders.dataset_loader import TensorDict
from combo.modules.module import Module
-from combo.nn import utils, RegularizerApplicator
+from combo.nn import utils, Regularizer
from combo.nn.utils import device_mapping
from combo.utils import ConfigurationError
@@ -69,7 +69,7 @@ class Model(Module, FromParameters):
    In a typical AllenNLP configuration file, this parameter does not get an entry under the
    "model", it gets specified as a top-level parameter, then is passed in to the model
    separately.
-    regularizer: `RegularizerApplicator`, optional
+    regularizer: `Regularizer`, optional
        If given, the `Trainer` will use this to regularize model parameters.
    serialization_dir: `str`, optional
        The directory the training output is saved to, or the directory the model is loaded from.
@@ -82,7 +82,7 @@ class Model(Module, FromParameters):
    def __init__(
            self,
            vocabulary: Vocabulary,
-            regularizer: RegularizerApplicator = None,
+            regularizer: Regularizer = None,
            serialization_dir: Optional[str] = None,
    ) -> None:
        super(Model, self).__init__()
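Throughout these files the change is a mechanical rename: RegularizerApplicator becomes Regularizer, but the object is still called on a module to produce a scalar penalty that the Model adds to its loss. A rough, self-contained sketch of how such a regex-matching applicator behaves (simplified; not the COMBO implementation):

import re
from typing import Callable, List, Tuple

import torch

class RegexRegularizer:
    # Simplified stand-in for combo.nn.Regularizer.
    def __init__(self, regexes: List[Tuple[str, Callable[[torch.Tensor], torch.Tensor]]]):
        self._regexes = regexes

    def __call__(self, module: torch.nn.Module) -> torch.Tensor:
        accumulator = torch.tensor(0.0)
        for name, parameter in module.named_parameters():
            # The first regex that matches a parameter name wins.
            for regex, regularizer in self._regexes:
                if re.search(regex, name):
                    accumulator = accumulator + regularizer(parameter)
                    break
        return accumulator

l2 = lambda w: 1e-6 * torch.sum(w ** 2)  # roughly what L2Regularizer(1e-6) does
penalty = RegexRegularizer([(".*weight.*", l2)])(torch.nn.Linear(4, 4))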
from .regularizers import *
-from .regularizer_applicator import *
\ No newline at end of file
+from .regularizer import *
\ No newline at end of file
import re
-from typing import List, Tuple
+from typing import List, Tuple, Dict, Any
import torch
@@ -7,9 +7,13 @@ from combo.config import FromParameters, Registry
from combo.config.from_parameters import register_arguments
from combo.nn.regularizers import Regularizer
from overrides import overrides
+from combo.utils import ConfigurationError

@Registry.register('base_regularizer')
-class RegularizerApplicator(FromParameters):
+class Regularizer(FromParameters):
    """
    Applies regularizers to the parameters of a Module based on regex matches.
    """
@@ -41,3 +45,22 @@ class RegularizerApplicator(FromParameters):
                        accumulator = accumulator + penalty
                        break
        return accumulator
+    @classmethod
+    def from_parameters(cls,
+                        parameters: Dict[str, Any] = None,
+                        constructor_method_name: str = None,
+                        pass_down_parameters: Dict[str, Any] = None):
+        regexes = parameters.get('regexes', [])
+        regexes_to_pass = []
+        for regex, regularizer_dict in regexes:
+            if 'type' not in regularizer_dict:
+                raise ConfigurationError('Regularizer dict does not have the type field')
+            resolved_regularizer, resolved_regularizer_constr = Registry.resolve(regularizer_dict['type'])
+            regexes_to_pass.append((regex,
+                                    resolved_regularizer.from_parameters(
+                                        regularizer_dict.get('parameters', {}),
+                                        resolved_regularizer_constr,
+                                        pass_down_parameters
+                                    )))
+        return cls(regexes_to_pass)
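With from_parameters in place, a whole regularizer stack can be described in config rather than code. A hedged example of the config shape this method consumes; the inner 'type' name and the 'alpha' parameter are assumptions about what is registered, not confirmed by this diff:

# Each entry pairs a parameter-name regex with a nested regularizer config;
# the nested 'type' must be a name known to the Registry.
config = {
    'regexes': [
        ('.*conv1d.*', {'type': 'l2_regularizer', 'parameters': {'alpha': 1e-6}}),
        ('.*forward.*', {'type': 'l2_regularizer', 'parameters': {'alpha': 1e-6}}),
    ]
}
# regularizer = Regularizer.from_parameters(config)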