diff --git a/.idea/combolightning.iml b/.idea/combolightning.iml index 332f5233511b9e9c4e7140693f371f2c7a6382c6..d4c73d1453b1ff1f26ddcd3f4f3d30c25e5096cc 100644 --- a/.idea/combolightning.iml +++ b/.idea/combolightning.iml @@ -5,4 +5,7 @@ <orderEntry type="jdk" jdkName="combo" jdkType="Python SDK" /> <orderEntry type="sourceFolder" forTests="false" /> </component> + <component name="TestRunnerService"> + <option name="PROJECT_TEST_RUNNER" value="Unittests" /> + </component> </module> \ No newline at end of file diff --git a/combo/commands/train.py b/combo/commands/train.py index f4bb80729a4875e48574ac94d43b36660559fe39..79b7f72b642f89f86a94384895bcf7b28a912723 100644 --- a/combo/commands/train.py +++ b/combo/commands/train.py @@ -4,7 +4,7 @@ from pytorch_lightning import Trainer class FinetuningTrainModel(Trainer): """ Class made only for finetuning, - the only difference is saving vocab from concatenated + the only difference is saving vocabulary from concatenated (archive and current) datasets """ pass \ No newline at end of file diff --git a/combo/common/cached_transformers.py b/combo/common/cached_transformers.py index 627879c9acb4391557930de212505649ab74adff..24db221c901517f81be8658a2f565b729c1095e9 100644 --- a/combo/common/cached_transformers.py +++ b/combo/common/cached_transformers.py @@ -3,15 +3,198 @@ Adapted from AllenNLP https://github.com/allenai/allennlp/blob/main/allennlp/common/cached_transformers.py """ import logging +import re +import warnings + import transformers -from typing import Dict, Tuple +from typing import Dict, Optional, Union, Tuple, NamedTuple -from combo.common.util import hash_object +from transformers import AutoModel, AutoConfig +from combo.common.util import hash_object +from combo.utils import ConfigurationError, cast logger = logging.getLogger(__name__) + +class TransformerSpec(NamedTuple): + model_name: str + override_weights_file: Optional[str] = None + override_weights_strip_prefix: Optional[str] = None + reinit_modules: Optional[Union[int, Tuple[int, ...], Tuple[str, ...]]] = None + + _tokenizer_cache: Dict[Tuple[str, str], transformers.PreTrainedTokenizer] = {} +_model_cache: Dict[TransformerSpec, transformers.PreTrainedModel] = {} + +def get( + model_name: str, + make_copy: bool, + override_weights_file: Optional[str] = None, + override_weights_strip_prefix: Optional[str] = None, + reinit_modules: Optional[Union[int, Tuple[int, ...], Tuple[str, ...]]] = None, + load_weights: bool = True, + **kwargs, +) -> transformers.PreTrainedModel: + """ + Returns a transformer model from the cache. + + # Parameters + + model_name : `str` + The name of the transformer, for example `"bert-base-cased"` + make_copy : `bool` + If this is `True`, return a copy of the model instead of the cached model itself. If you want to modify the + parameters of the model, set this to `True`. If you want only part of the model, set this to `False`, but + make sure to `copy.deepcopy()` the bits you are keeping. + override_weights_file : `str`, optional (default = `None`) + If set, this specifies a file from which to load alternate weights that override the + weights from huggingface. The file is expected to contain a PyTorch `state_dict`, created + with `torch.save()`. + override_weights_strip_prefix : `str`, optional (default = `None`) + If set, strip the given prefix from the state dict when loading it. 
+ reinit_modules: `Optional[Union[int, Tuple[int, ...], Tuple[str, ...]]]`, optional (default = `None`) + If this is an integer, the last `reinit_modules` layers of the transformer will be + re-initialized. If this is a tuple of integers, the layers indexed by `reinit_modules` will + be re-initialized. Note, because the module structure of the transformer `model_name` can + differ, we cannot guarantee that providing an integer or tuple of integers will work. If + this fails, you can instead provide a tuple of strings, which will be treated as regexes and + any module with a name matching the regex will be re-initialized. Re-initializing the last + few layers of a pretrained transformer can reduce the instability of fine-tuning on small + datasets and may improve performance (https://arxiv.org/abs/2006.05987v3). Has no effect + if `load_weights` is `False` or `override_weights_file` is not `None`. + load_weights : `bool`, optional (default = `True`) + If set to `False`, no weights will be loaded. This is helpful when you only + want to initialize the architecture, like when you've already fine-tuned a model + and are going to load the weights from a state dict elsewhere. + """ + global _model_cache + spec = TransformerSpec( + model_name, + override_weights_file, + override_weights_strip_prefix, + reinit_modules, + ) + transformer = _model_cache.get(spec, None) + if transformer is None: + if not load_weights: + if override_weights_file is not None: + warnings.warn( + "You specified an 'override_weights_file' in allennlp.common.cached_transformers.get(), " + "but 'load_weights' is set to False, so 'override_weights_file' will be ignored.", + UserWarning, + ) + if reinit_modules is not None: + warnings.warn( + "You specified 'reinit_modules' in allennlp.common.cached_transformers.get(), " + "but 'load_weights' is set to False, so 'reinit_modules' will be ignored.", + UserWarning, + ) + transformer = AutoModel.from_config( + AutoConfig.from_pretrained( + model_name, + **kwargs, + ) + ) + elif override_weights_file is not None: + if reinit_modules is not None: + warnings.warn( + "You specified 'reinit_modules' in allennlp.common.cached_transformers.get(), " + "but 'override_weights_file' is not None, so 'reinit_modules' will be ignored.", + UserWarning, + ) + import torch + from combo.utils.file_utils import cached_path + + override_weights_file = cached_path(override_weights_file) + override_weights = torch.load(override_weights_file) + if override_weights_strip_prefix is not None: + + def strip_prefix(s): + if s.startswith(override_weights_strip_prefix): + return s[len(override_weights_strip_prefix) :] + else: + return s + + valid_keys = { + k + for k in override_weights.keys() + if k.startswith(override_weights_strip_prefix) + } + if len(valid_keys) > 0: + logger.info( + "Loading %d tensors from %s", len(valid_keys), override_weights_file + ) + else: + raise ValueError( + f"Specified prefix of '{override_weights_strip_prefix}' means no tensors " + f"will be loaded from {override_weights_file}." + ) + override_weights = {strip_prefix(k): override_weights[k] for k in valid_keys} + + transformer = AutoModel.from_config( + AutoConfig.from_pretrained( + model_name, + **kwargs, + ) + ) + # When DistributedDataParallel or DataParallel is used, the state dict of the + # DistributedDataParallel/DataParallel wrapper prepends "module." to all parameters + # of the actual model, since the actual model is stored within the module field. 
+ # This accounts for if a pretained model was saved without removing the + # DistributedDataParallel/DataParallel wrapper. + if hasattr(transformer, "module"): + transformer.module.load_state_dict(override_weights) + else: + transformer.load_state_dict(override_weights) + elif reinit_modules is not None: + transformer = AutoModel.from_pretrained( + model_name, + **kwargs, + ) + num_layers = transformer.config.num_hidden_layers + if isinstance(reinit_modules, int): + reinit_modules = tuple(range(num_layers - reinit_modules, num_layers)) + if all(isinstance(x, int) for x in reinit_modules): + # This type cast is neccessary to avoid a mypy error. + reinit_modules = cast(Tuple[int], reinit_modules) + if any(layer_idx < 0 or layer_idx > num_layers for layer_idx in reinit_modules): + raise ValueError( + f"A layer index in reinit_modules ({reinit_modules}) is invalid." + f" Must be between 0 and the maximum layer index ({num_layers - 1}.)" + ) + # Some transformer models organize their modules differently, so if this fails, + # raise an error with a helpful message. + try: + for layer_idx in reinit_modules: + transformer.encoder.layer[layer_idx].apply(transformer._init_weights) + except AttributeError: + raise ConfigurationError( + f"Unable to re-initialize the layers of transformer model" + f" {model_name} using layer indices. Please provide a tuple of" + " strings corresponding to the names of the layers to re-initialize." + ) + elif all(isinstance(x, str) for x in reinit_modules): + for regex in reinit_modules: + for name, module in transformer.named_modules(): + if re.search(str(regex), name): + module.apply(transformer._init_weights) + else: + raise ValueError( + "reinit_modules must be either an integer, a tuple of strings, or a tuple of integers." + ) + else: + transformer = AutoModel.from_pretrained( + model_name, + **kwargs, + ) + _model_cache[spec] = transformer + if make_copy: + import copy + + return copy.deepcopy(transformer) + else: + return transformer def get_tokenizer(model_name: str, **kwargs) -> transformers.PreTrainedTokenizer: diff --git a/combo/config/__init__.py b/combo/config/__init__.py index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..778e8732286e11c4c7711ed6cf0a25267ba8acd5 100644 --- a/combo/config/__init__.py +++ b/combo/config/__init__.py @@ -0,0 +1,2 @@ +from .from_parameters import FromParameters +from .registry import Registry diff --git a/combo/config/exceptions.py b/combo/config/exceptions.py new file mode 100644 index 0000000000000000000000000000000000000000..4e0222fbed21850ec74a19aff7d13adcd0618a61 --- /dev/null +++ b/combo/config/exceptions.py @@ -0,0 +1,3 @@ +class RegistryException(Exception): + def __init__(self, message): + super().__init__(message) diff --git a/combo/config/from_parameters.py b/combo/config/from_parameters.py new file mode 100644 index 0000000000000000000000000000000000000000..7d7d5074f7b98c3d239953dca9ce76f74ed9714a --- /dev/null +++ b/combo/config/from_parameters.py @@ -0,0 +1,139 @@ +import inspect +import warnings +from typing import Any, Dict, List, Optional, Type, Union +import typing + +import torch + +from combo.common.params import Params +from combo.config.exceptions import RegistryException +from combo.config.registry import (Registry) +from combo.nn import Activation +from combo.utils import ConfigurationError + + +def resolve_optional(type_annotation): + arg1, arg2 = typing.get_args(type_annotation) + # Optional + if (arg1 is None and arg2 is not None) or (arg1 is not None and arg2 is None): + type_annotation = arg1 or arg2 + 
return type_annotation + + +def resolve_union(type_annotation, v): + arg1, arg2 = typing.get_args(type_annotation) + if arg1 is type(v) or typing.get_origin(arg1) is type(v): + type_annotation = arg1 + else: + type_annotation = arg2 + return type_annotation + + +def _resolve(type: Type[object], + values: typing.Union[Dict[str, Any], str], + pass_to_subclasses: Dict[str, Any], + default_name: str = 'base'): + if isinstance(values, Params): + values = values.as_dict() + if isinstance(values, dict): + try: + if 'type' in values.keys(): + return Registry.resolve( + type, + values['type'] + ).from_parameters({**values, **pass_to_subclasses}, pass_to_subclasses) + else: + return Registry.get_default(type).from_parameters({**values, **pass_to_subclasses}, pass_to_subclasses) + except RegistryException: + if type is torch.nn.Module: + return torch.nn.Module(**values) + if type.__base__ is not object: + return _resolve(type.__base__, values, pass_to_subclasses, default_name) + warnings.warn(f'No classes of type {type} (values: {values}) in Registry!') + return values + elif type is Activation and isinstance(values, str): + try: + return Registry.resolve(type, values)() + except RegistryException: + warnings.warn(f'No Activation of class name {values} in Registry!') + return values + else: + return values + + +def _init_from_list(init_func_type_annotation: Type, + value: List[Any], + pass_to_subclasses: Dict[str, Any] = None, + default_name: str = 'base') -> List[Any]: + # Check the typing of the arguments + list_argument_annotation = typing.get_args(init_func_type_annotation)[0] + if isinstance(value, Params): + value = value.as_dict() + if typing.get_origin(list_argument_annotation) is list: + return [_init_from_list(list_argument_annotation, v, pass_to_subclasses, default_name) for v in value] + elif typing.get_origin(list_argument_annotation) is dict: + return [_init_from_dict(list_argument_annotation, v, pass_to_subclasses, default_name) for v in value] + else: + # A single object: + return [_resolve(list_argument_annotation, v, pass_to_subclasses, default_name) for v in value] + + +def _init_from_dict(init_func_type_annotation: Type, + value: Dict[Any, Any], + pass_to_subclasses: Dict[str, Any] = None, + default_name: str = 'base') -> Dict[Any, Any]: + key_annotation, value_annotation = typing.get_args(init_func_type_annotation) + if isinstance(value, Params): + value = value.as_dict() + if typing.get_origin(value_annotation) is list: + return {k: _init_from_list(value_annotation, v, pass_to_subclasses, default_name) for k, v in value.items()} + elif typing.get_origin(value_annotation) is dict: + return {k: _init_from_dict(value_annotation, v, pass_to_subclasses, default_name) for k, v in value.items()} + else: + return {k: _resolve(value_annotation, v, pass_to_subclasses, default_name) for k, v in value.items()} + + +class FromParameters: + @classmethod + def from_parameters(cls, + parameters: Dict[str, Any], + pass_to_subclasses: Dict[str, Any] = None, + default_name: str = 'base'): + parameters_to_pass = {} + pass_to_subclasses = pass_to_subclasses or {} + arguments = inspect.signature(cls.__init__).parameters + argument_names = list(arguments.keys()) + for k, v in parameters.items(): + if isinstance(v, Params): + v = v.as_dict() + if k not in argument_names: + # Only pass configuration items whose names match the + # __init__ argument names + continue + # Check the type + type_annotation = arguments[k].annotation + + if typing.get_origin(type_annotation) is Union: +
type_annotation = resolve_union(resolve_optional(type_annotation), v) + + # Check if list + if typing.get_origin(type_annotation) is list: + parameters_to_pass[k] = _init_from_list(type_annotation, + v, pass_to_subclasses, default_name) + elif typing.get_origin(type_annotation) is dict: + parameters_to_pass[k] = _init_from_dict(type_annotation, + v, pass_to_subclasses, default_name) + elif typing.get_origin(type_annotation) is None and type_annotation.__base__ is not object: + try: + parameters_to_pass[k] = _resolve(type_annotation, + v, pass_to_subclasses, default_name) + except ConfigurationError: + warnings.warn(f'An object of type {type_annotation} is not in the registry!') + parameters_to_pass[k] = v + else: + parameters_to_pass[k] = v + + return cls(**parameters_to_pass) + + def _to_params(self) -> Dict[str, str]: + return {} diff --git a/combo/config/registry.py b/combo/config/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..c7bfce81de308b7c9cf51172c7f94211cd6f3846 --- /dev/null +++ b/combo/config/registry.py @@ -0,0 +1,61 @@ +from collections import defaultdict +from typing import Any, Optional, Type, Union, Dict, List + +from combo.config.exceptions import RegistryException + + +class Registry: + __classes = defaultdict(dict) + __defaults: Dict[Type[object], Type[object]] = {} + + @classmethod + def classes(cls) -> Dict[Type[object], Dict[str, Type]]: + return cls.__classes + + @classmethod + def defaults(cls) -> Dict[Type[object], Type[object]]: + return cls.__defaults + + @classmethod + def register_base_class(cls, registry_name: str, default: bool = False): + def decorator(clz): + nonlocal registry_name, default + # if category not in REGISTRY_CLASSES: + # warnings.warn(f'Category {str(category)} is not in REGISTRY_CLASSES') + cls.__classes[clz][registry_name.lower()] = clz + if default: + cls.__defaults[clz] = clz + return clz + + return decorator + + @classmethod + def register(cls, category: Type[object], registry_name: str, default: bool = False): + def decorator(clz): + nonlocal category, registry_name, default + # if category not in REGISTRY_CLASSES: + # warnings.warn(f'Category {str(category)} is not in REGISTRY_CLASSES') + cls.__classes[category][registry_name.lower()] = clz + if default: + cls.__defaults[category] = clz + return clz + + return decorator + + @classmethod + def resolve(cls, category: Union[Type[object], str], class_name: str) -> Optional[Type]: + try: + return cls.__classes[category][class_name.lower()] + except KeyError: + try: + return cls.__defaults[category] + except KeyError: + raise RegistryException(f'No {str(category)} with registered name {class_name.lower()}, no default value either') + + @classmethod + def get_default(cls, category: Type[object]) -> Optional[Type]: + try: + return cls.__defaults[category] + except KeyError: + raise RegistryException( + f'No default value for {str(category)}') diff --git a/combo/data/__init__.py b/combo/data/__init__.py index f75b7d83b0467f0c8082f8ad66649509189db5bd..3331bb62bce02c70dd6529408cc2523302578695 100644 --- a/combo/data/__init__.py +++ b/combo/data/__init__.py @@ -1,4 +1,3 @@ -from .api import (Token, Sentence, sentence2conllu, tokens2conllu, conllu2sentence) from .vocabulary import Vocabulary from .samplers import TokenCountBatchSampler from .instance import Instance @@ -7,3 +6,4 @@ from .tokenizers import (Tokenizer, Token, CharacterTokenizer, PretrainedTransfo SpacyTokenizer, WhitespaceTokenizer, LamboTokenizer) from .dataset_readers import (ConllDatasetReader, 
DatasetReader, TextClassificationJSONReader, UniversalDependenciesDatasetReader) +from .api import (Sentence, tokens2conllu, conllu2sentence, sentence2conllu) diff --git a/combo/data/dataset_readers/__init__.py b/combo/data/dataset_readers/__init__.py index 2ec1049784efa5e94a186aaa8ec1a505ae5d7a7b..b8088abc3af8261c18e8ce6ecd93ed2adaad58f9 100644 --- a/combo/data/dataset_readers/__init__.py +++ b/combo/data/dataset_readers/__init__.py @@ -1,4 +1,4 @@ from .dataset_reader import DatasetReader from .text_classification_json_reader import TextClassificationJSONReader from .universal_dependencies_dataset_reader import UniversalDependenciesDatasetReader -from .conll import ConllDatasetReader +from .conllu import ConllDatasetReader diff --git a/combo/data/dataset_readers/conll.py b/combo/data/dataset_readers/conllu.py similarity index 96% rename from combo/data/dataset_readers/conll.py rename to combo/data/dataset_readers/conllu.py index d8c2c5df947d918f9a925fcf3e4e60ff3e697de2..18b3eaafeb7bbf93c0d278793c911daaa21240c3 100644 --- a/combo/data/dataset_readers/conll.py +++ b/combo/data/dataset_readers/conllu.py @@ -7,6 +7,8 @@ import itertools import logging from typing import Dict, List, Optional, Sequence, Iterable +from overrides import overrides + from combo.data.token_indexers.single_id_token_indexer import SingleIdTokenIndexer from combo.data.token_indexers.token_indexer import TokenIndexer, Token from combo.utils import ConfigurationError @@ -14,6 +16,8 @@ from .dataset_reader import DatasetReader from .dataset_utils.span_utils import to_bioul from .. import Instance from ..fields import MetadataField, TextField, Field, SequenceLabelField +from ...config import Registry +from ...config.from_parameters import FromParameters from ...utils.file_utils import cached_path logger = logging.getLogger(__name__) @@ -32,7 +36,8 @@ def _is_divider(line: str) -> bool: # TODO: maybe one should note whether the format is IOB1 or IOB2 in the processed dataset? 
-class ConllDatasetReader(DatasetReader): +@Registry.register(DatasetReader, 'conll2003') +class ConllDatasetReader(DatasetReader, FromParameters): """ Reads instances from a pretokenised file where each line is in the following format: ``` @@ -93,7 +98,7 @@ class ConllDatasetReader(DatasetReader): super().__init__(**kwargs) - self._token_indexers = token_indexers or {"tokens": SingleIdTokenIndexer()} + self.token_indexers = token_indexers or {"tokens": SingleIdTokenIndexer()} if tag_label is not None and tag_label not in self._VALID_LABELS: raise ConfigurationError("unknown tag label type: {}".format(tag_label)) for label in feature_labels: @@ -219,4 +224,4 @@ class ConllDatasetReader(DatasetReader): return self def apply_token_indexers(self, instance: Instance) -> None: - instance.fields["tokens"]._token_indexers = self._token_indexers # type: ignore + instance.fields["tokens"]._token_indexers = self.token_indexers # type: ignore diff --git a/combo/data/dataset_readers/dataset_reader.py b/combo/data/dataset_readers/dataset_reader.py index 092c7a07a0419fa72fb7421aff695e175a074af8..cad9f4b9142669472fbf6b8e1f8e18738d22553e 100644 --- a/combo/data/dataset_readers/dataset_reader.py +++ b/combo/data/dataset_readers/dataset_reader.py @@ -9,6 +9,7 @@ from typing import Iterable, Iterator, Optional, Union, TypeVar, Dict, List from overrides import overrides from torch.utils.data import IterableDataset +from combo.config import FromParameters from combo.data import SpacyTokenizer, SingleIdTokenIndexer from combo.data.instance import Instance from combo.data.tokenizers import Tokenizer @@ -21,7 +22,7 @@ PathOrStr = Union[PathLike, str] DatasetReaderInput = Union[PathOrStr, List[PathOrStr], Dict[str, PathOrStr]] -class DatasetReader(IterableDataset): +class DatasetReader(IterableDataset, FromParameters): """ A `DatasetReader` knows how to turn a file containing a dataset into a collection of `Instance`s. @@ -50,6 +51,10 @@ class DatasetReader(IterableDataset): def token_indexers(self) -> Optional[Dict[str, TokenIndexer]]: return self.__token_indexers + @token_indexers.setter + def token_indexers(self, token_indexers: Dict[str, TokenIndexer]): + self.__token_indexers = token_indexers + @overrides def __getitem__(self, item, **kwargs) -> Instance: raise NotImplementedError @@ -69,3 +74,22 @@ class DatasetReader(IterableDataset): def apply_token_indexers(self, instance: Instance) -> None: pass + + def text_to_instance(self, *inputs) -> Instance: + """ + Does whatever tokenization or processing is necessary to go from textual input to an + `Instance`. The primary intended use for this is with a + :class:`~allennlp.predictors.predictor.Predictor`, which gets text input as a JSON + object and needs to process it to be input to a model. + + The intent here is to share code between :func:`_read` and what happens at + model serving time, or any other time you want to make a prediction from new data. We need + to process the data in the same way it was done at training time. Allowing the + `DatasetReader` to process new text lets us accomplish this, as we can just call + `DatasetReader.text_to_instance` when serving predictions. + + The input type here is rather vaguely specified, unfortunately. The `Predictor` will + have to make some assumptions about the kind of `DatasetReader` that it's using, in order + to pass it the right information. 
+ """ + raise NotImplementedError diff --git a/combo/data/dataset_readers/text_classification_json_reader.py b/combo/data/dataset_readers/text_classification_json_reader.py index 95455deb44176b1a7aa54cd710058228192fe071..781d310258a890cfefaf1e8d0566dce355dc3ecc 100644 --- a/combo/data/dataset_readers/text_classification_json_reader.py +++ b/combo/data/dataset_readers/text_classification_json_reader.py @@ -14,6 +14,7 @@ from .. import Instance, Tokenizer, TokenIndexer from ..fields import Field, ListField from ..fields.label_field import LabelField from ..fields.text_field import TextField +from ...config import Registry from ...utils import ConfigurationError @@ -22,6 +23,7 @@ def _is_sentence_segmenter(sentence_segmenter: Optional[Tokenizer]) -> bool: return callable(split_sentences_method) +@Registry.register(DatasetReader, 'text_classification_json') class TextClassificationJSONReader(DatasetReader): def __init__(self, tokenizer: Optional[Tokenizer] = None, @@ -110,6 +112,7 @@ class TextClassificationJSONReader(DatasetReader): label = str(label) yield self.text_to_instance(text, label) + def text_to_instance(self, text: str, label: Optional[Union[str, int]] = None) -> Instance: diff --git a/combo/data/dataset_readers/universal_dependencies_dataset_reader.py b/combo/data/dataset_readers/universal_dependencies_dataset_reader.py index 33bad6db815fcc5a568a753192b223afe9384a56..d06852cd56c700972249a132d32653ad9a6f8b83 100644 --- a/combo/data/dataset_readers/universal_dependencies_dataset_reader.py +++ b/combo/data/dataset_readers/universal_dependencies_dataset_reader.py @@ -12,6 +12,7 @@ import torch from overrides import overrides from combo import data +from combo.config import Registry from combo.data import Vocabulary, fields, Instance, Token from combo.data.dataset_readers.dataset_reader import DatasetReader from combo.data.fields import Field @@ -24,6 +25,7 @@ from conllu import parser from combo.utils import checks, pad_sequence_to_length +@Registry.register(DatasetReader, 'conllu') class UniversalDependenciesDatasetReader(DatasetReader, ABC): def __init__( self, @@ -120,7 +122,7 @@ class UniversalDependenciesDatasetReader(DatasetReader, ABC): def text_to_instance(self, tree: conllu.TokenList) -> Instance: fields_: Dict[str, Field] = {} - tree_tokens = [t for t in tree if isinstance(t["id"], int)] + tree_tokens = [t for t in tree if isinstance(t["idx"], int)] tokens = [Token(text=t["token"], upostag=t.get("upostag"), diff --git a/combo/data/fields/sequence_label_field.py b/combo/data/fields/sequence_label_field.py index 60d30c6f3959eb276c949aacf0deb596853d4dbe..b20bb6f3ba54aaf329951989010adff5e0280f63 100644 --- a/combo/data/fields/sequence_label_field.py +++ b/combo/data/fields/sequence_label_field.py @@ -30,7 +30,7 @@ class SequenceLabelField(Field[torch.Tensor]): labels : `Union[List[str], List[int]]` A sequence of categorical labels, encoded as strings or integers. These could be POS tags like [NN, JJ, ...], BIO tags like [B-PERS, I-PERS, O, O, ...], or any other categorical tag sequence. If the - labels are encoded as integers, they will not be indexed using a vocab. + labels are encoded as integers, they will not be indexed using a vocabulary. sequence_field : `SequenceField` A field containing the sequence that this `SequenceLabelField` is labeling. Most often, this is a `TextField`, for tagging individual tokens in a sentence. 
diff --git a/combo/data/fields/sequence_multilabel_field.py b/combo/data/fields/sequence_multilabel_field.py index 4938d438a0bf14ea369dae000a34df8e9a855a97..55b579afb8d37980dd7c57eb668275a819dcf3e2 100644 --- a/combo/data/fields/sequence_multilabel_field.py +++ b/combo/data/fields/sequence_multilabel_field.py @@ -36,7 +36,7 @@ class SequenceMultiLabelField(Field[torch.Tensor]): multi_labels : `List[List[str]]` multi_label_indexer : `Callable[[data.Vocabulary], Callable[[List[str]], List[int]]]` - Nested callable which based on vocab and sequence length maps values of the fields in the sequence + Nested callable which based on vocabulary and sequence length maps values of the fields in the sequence from strings to indexed, int values. as_tensor: `Callable[["SequenceMultiLabelField"], Callable[[Dict[str, int]], torch.Tensor]]` Nested callable which based on the field itself, maps indexed data to a tensor. diff --git a/combo/data/fields/text_field.py b/combo/data/fields/text_field.py index c4de2733c1338c684e1592d86a5345d87028b111..da42252b60a04d2ab784d9762f4695e8f1810e3f 100644 --- a/combo/data/fields/text_field.py +++ b/combo/data/fields/text_field.py @@ -19,7 +19,6 @@ import torch # produced by a given TokenIndexer, which will be input to a particular TokenEmbedder's forward() # method. We label these as tensors, because that's what they typically are, though they could in # reality have arbitrary type. -from combo.data import Vocabulary from combo.data.fields.sequence_field import SequenceField from combo.data.token_indexers import TokenIndexer, IndexedTokenList from combo.data.tokenizers import Token @@ -103,7 +102,7 @@ class TextField(SequenceField[TextFieldTensors]): for token in self.tokens: indexer.count_vocab_items(token, counter) - def index(self, vocab: Vocabulary): + def index(self, vocab): self._indexed_tokens = {} for indexer_name, indexer in self.token_indexers.items(): self._indexed_tokens[indexer_name] = indexer.tokens_to_indices(self.tokens, vocab) diff --git a/combo/data/instance.py b/combo/data/instance.py index 34771aec12bb13a7a53a047dddba4d09b3da939b..1c0ab268e04bbe3f5aacc84bb25695cdac5a2d03 100644 --- a/combo/data/instance.py +++ b/combo/data/instance.py @@ -50,7 +50,7 @@ class Instance(Mapping[str, Field]): """ Add the field to the existing fields mapping. If we have already indexed the Instance, then we also index `field`, so - it is necessary to supply the vocab. + it is necessary to supply the vocabulary. 
""" self.fields[field_name] = field if self.indexed and vocab is not None: diff --git a/combo/data/token_indexers/__init__.py b/combo/data/token_indexers/__init__.py index 75df1b396d7e25b8b71c47f4704fcfc31f3cf947..d3edf27d0849e84e346eae5a65f11c355555c609 100644 --- a/combo/data/token_indexers/__init__.py +++ b/combo/data/token_indexers/__init__.py @@ -4,3 +4,4 @@ from .single_id_token_indexer import SingleIdTokenIndexer from .pretrained_transformer_indexer import PretrainedTransformerIndexer from .pretrained_transformer_mismatched_indexer import PretrainedTransformerMismatchedIndexer from .pretrained_transformer_fixed_mismatched_indexer import PretrainedTransformerFixedMismatchedIndexer +from .token_const_padding_characters_indexer import TokenConstPaddingCharactersIndexer diff --git a/combo/data/token_indexers/pretrained_transformer_fixed_mismatched_indexer.py b/combo/data/token_indexers/pretrained_transformer_fixed_mismatched_indexer.py index af7bb0a9afcb01a94d77853fa303c929e3571f5b..71e366761935e5527092e8b6c5c9c4ce42757c37 100644 --- a/combo/data/token_indexers/pretrained_transformer_fixed_mismatched_indexer.py +++ b/combo/data/token_indexers/pretrained_transformer_fixed_mismatched_indexer.py @@ -3,16 +3,19 @@ Adapted from COMBO Authors: Mateusz Klimaszewski, Lukasz Pszenny """ -from typing import Optional, Dict, Any, List, Tuple +from typing import Optional, Dict, Any from overrides import overrides +from combo.config import Registry from combo.data import Vocabulary +from combo.data.token_indexers.token_indexer import TokenIndexer from combo.data.token_indexers import IndexedTokenList from combo.data.token_indexers.pretrained_transformer_indexer import PretrainedTransformerIndexer from combo.data.token_indexers.pretrained_transformer_mismatched_indexer import PretrainedTransformerMismatchedIndexer +@Registry.register(TokenIndexer, 'pretrained_transformer_mismatched_fixed') class PretrainedTransformerFixedMismatchedIndexer(PretrainedTransformerMismatchedIndexer): def __init__(self, model_name: str, namespace: str = "tags", max_length: int = None, diff --git a/combo/data/token_indexers/pretrained_transformer_indexer.py b/combo/data/token_indexers/pretrained_transformer_indexer.py index 6580f4b5bd48a11b79d9b510c05becabaac0c16a..d2a8a62fc4e580211ad67dfd1d4ca6d68804b600 100644 --- a/combo/data/token_indexers/pretrained_transformer_indexer.py +++ b/combo/data/token_indexers/pretrained_transformer_indexer.py @@ -6,7 +6,9 @@ https://github.com/allenai/allennlp/blob/main/allennlp/data/token_indexers/pretr from typing import Dict, List, Optional, Tuple, Any import logging import torch +from overrides import overrides +from combo.config import Registry from combo.data import Vocabulary from combo.data.tokenizers import Token from combo.data.token_indexers import TokenIndexer, IndexedTokenList @@ -16,6 +18,7 @@ from combo.utils import pad_sequence_to_length logger = logging.getLogger(__name__) +@Registry.register(TokenIndexer, 'pretrained_transformer') class PretrainedTransformerIndexer(TokenIndexer): """ This `TokenIndexer` assumes that Tokens already have their indexes in them (see `text_id` field). @@ -79,7 +82,7 @@ class PretrainedTransformerIndexer(TokenIndexer): def _add_encoding_to_vocabulary_if_needed(self, vocab: Vocabulary) -> None: """ - Copies tokens from ```transformers``` model's vocab to the specified _namespace. + Copies tokens from ```transformers``` model's vocabulary to the specified _namespace. 
""" if self._added_to_vocabulary: return @@ -88,10 +91,12 @@ class PretrainedTransformerIndexer(TokenIndexer): self._added_to_vocabulary = True + @overrides def count_vocab_items(self, token: Token, counter: Dict[str, Dict[str, int]]): # If we only use pretrained models, we don't need to do anything here. pass + @overrides def tokens_to_indices(self, tokens: List[Token], vocabulary: Vocabulary) -> IndexedTokenList: self._add_encoding_to_vocabulary_if_needed(vocabulary) @@ -105,6 +110,7 @@ class PretrainedTransformerIndexer(TokenIndexer): return self._postprocess_output(output) + @overrides def indices_to_tokens( self, indexed_tokens: IndexedTokenList, vocabulary: Vocabulary ) -> List[Token]: @@ -203,12 +209,14 @@ class PretrainedTransformerIndexer(TokenIndexer): return output + @overrides def get_empty_token_list(self) -> IndexedTokenList: output: IndexedTokenList = {"token_ids": [], "mask": [], "type_ids": []} if self._max_length is not None: output["segment_concat_mask"] = [] return output + @overrides def as_padded_tensor_dict( self, tokens: IndexedTokenList, padding_lengths: Dict[str, int] ) -> Dict[str, torch.Tensor]: diff --git a/combo/data/token_indexers/pretrained_transformer_mismatched_indexer.py b/combo/data/token_indexers/pretrained_transformer_mismatched_indexer.py index 6d660e7cb40c78e98bae173641c91a28ac6a2ce4..e41407b82b3c4ac725038331cbe0702caf5db2af 100644 --- a/combo/data/token_indexers/pretrained_transformer_mismatched_indexer.py +++ b/combo/data/token_indexers/pretrained_transformer_mismatched_indexer.py @@ -7,7 +7,9 @@ from typing import Dict, List, Any, Optional import logging import torch +from overrides import overrides +from combo.config import Registry, FromParameters from combo.data import Vocabulary from combo.data.tokenizers import Token from combo.data.token_indexers import TokenIndexer, IndexedTokenList @@ -17,6 +19,7 @@ from combo.utils import pad_sequence_to_length logger = logging.getLogger(__name__) +@Registry.register(TokenIndexer, 'pretrained_transformer_mismatched') class PretrainedTransformerMismatchedIndexer(TokenIndexer): """ Use this indexer when (for whatever reason) you are not using a corresponding @@ -67,9 +70,11 @@ class PretrainedTransformerMismatchedIndexer(TokenIndexer): self._num_added_start_tokens = self._matched_indexer._num_added_start_tokens self._num_added_end_tokens = self._matched_indexer._num_added_end_tokens + @overrides def count_vocab_items(self, token: Token, counter: Dict[str, Dict[str, int]]): return self._matched_indexer.count_vocab_items(token, counter) + @overrides def tokens_to_indices(self, tokens: List[Token], vocabulary: Vocabulary) -> IndexedTokenList: self._matched_indexer._add_encoding_to_vocabulary_if_needed(vocabulary) @@ -91,12 +96,14 @@ class PretrainedTransformerMismatchedIndexer(TokenIndexer): return self._matched_indexer._postprocess_output(output) + @overrides def get_empty_token_list(self) -> IndexedTokenList: output = self._matched_indexer.get_empty_token_list() output["offsets"] = [] output["wordpiece_mask"] = [] return output + @overrides def as_padded_tensor_dict( self, tokens: IndexedTokenList, padding_lengths: Dict[str, int] ) -> Dict[str, torch.Tensor]: diff --git a/combo/data/token_indexers/single_id_token_indexer.py b/combo/data/token_indexers/single_id_token_indexer.py index 143c7861fbf824eb216f1f9526a489a50f13c25d..4f49bd1a52c1d3da0d8ab45a206e41c06de8e2e5 100644 --- a/combo/data/token_indexers/single_id_token_indexer.py +++ b/combo/data/token_indexers/single_id_token_indexer.py @@ -6,6 +6,9 @@ 
https://github.com/allenai/allennlp/blob/80fb6061e568cb9d6ab5d45b661e86eb61b92c8 from typing import Dict, List, Optional, Any import itertools +from overrides import overrides + +from combo.config import FromParameters, Registry from combo.data import Vocabulary from combo.data.tokenizers import Token from combo.data.token_indexers import TokenIndexer, IndexedTokenList @@ -13,6 +16,7 @@ from combo.data.token_indexers import TokenIndexer, IndexedTokenList _DEFAULT_VALUE = "THIS IS A REALLY UNLIKELY VALUE THAT HAS TO BE A STRING" +@Registry.register(TokenIndexer, 'single_id') class SingleIdTokenIndexer(TokenIndexer): """ This :class:`TokenIndexer` represents tokens as single integers. @@ -65,6 +69,7 @@ class SingleIdTokenIndexer(TokenIndexer): self._feature_name = feature_name self._default_value = default_value + @overrides def count_vocab_items(self, token: Token, counter: Dict[str, Dict[str, int]]): if self._namespace is not None: text = self._get_feature_value(token) @@ -72,9 +77,10 @@ class SingleIdTokenIndexer(TokenIndexer): text = text.lower() counter[self._namespace][text] += 1 + @overrides def tokens_to_indices( self, tokens: List[Token], vocabulary: Vocabulary - ) -> Dict[str, List[int]]: + ) -> IndexedTokenList: indices: List[int] = [] for token in itertools.chain(self._start_tokens, tokens, self._end_tokens): @@ -89,6 +95,7 @@ class SingleIdTokenIndexer(TokenIndexer): return {"tokens": indices} + @overrides def get_empty_token_list(self) -> IndexedTokenList: return {"tokens": []} diff --git a/combo/data/token_indexers/token_characters_indexer.py b/combo/data/token_indexers/token_characters_indexer.py index 1227f2e5a8b433a9a17b11c80a1d379b1a883f93..56f84f68e60073ab1a39def3797e30ba6ab3eba3 100644 --- a/combo/data/token_indexers/token_characters_indexer.py +++ b/combo/data/token_indexers/token_characters_indexer.py @@ -10,12 +10,14 @@ import warnings from overrides import overrides import torch +from combo.config import Registry from combo.data import Vocabulary from combo.data.token_indexers import TokenIndexer, IndexedTokenList from combo.data.tokenizers import Token, CharacterTokenizer from combo.utils import ConfigurationError, pad_sequence_to_length +@Registry.register(TokenIndexer, 'token_characters') class TokenCharactersIndexer(TokenIndexer): """ This :class:`TokenIndexer` represents tokens as lists of character indices. @@ -27,7 +29,7 @@ class TokenCharactersIndexer(TokenIndexer): _namespace : `str`, optional (default=`token_characters`) We will use this _namespace in the :class:`Vocabulary` to map the characters in each token to indices. - character_tokenizer : `CharacterTokenizer`, optional (default=`CharacterTokenizer()`) + tokenizer : `CharacterTokenizer`, optional (default=`CharacterTokenizer()`) We use a :class:`CharacterTokenizer` to handle splitting tokens into characters, as it has options for byte encoding and other things. 
The default here is to instantiate a `CharacterTokenizer` with its default parameters, which uses unicode characters and @@ -46,7 +48,7 @@ class TokenCharactersIndexer(TokenIndexer): def __init__( self, namespace: str = "token_characters", - character_tokenizer: CharacterTokenizer = CharacterTokenizer(), + tokenizer: CharacterTokenizer = CharacterTokenizer(), start_tokens: List[str] = None, end_tokens: List[str] = None, min_padding_length: int = 0, @@ -64,7 +66,7 @@ class TokenCharactersIndexer(TokenIndexer): ) self._min_padding_length = min_padding_length self._namespace = namespace - self._character_tokenizer = character_tokenizer + self._character_tokenizer = tokenizer self._start_tokens = [Token(st) for st in (start_tokens or [])] self._end_tokens = [Token(et) for et in (end_tokens or [])] @@ -75,14 +77,14 @@ class TokenCharactersIndexer(TokenIndexer): raise ConfigurationError("TokenCharactersIndexer needs a tokenizer that retains text") for character in self._character_tokenizer.tokenize(token.text): # If `text_id` is set on the character token (e.g., if we're using byte encoding), we - # will not be using the vocab for this character. + # will not be using the vocabulary for this character. if getattr(character, "text_id", None) is None: counter[self._namespace][character.text] += 1 @overrides def tokens_to_indices( self, tokens: List[Token], vocabulary: Vocabulary - ) -> Dict[str, List[List[int]]]: + ) -> IndexedTokenList: indices: List[List[int]] = [] for token in itertools.chain(self._start_tokens, tokens, self._end_tokens): token_indices: List[int] = [] @@ -92,7 +94,7 @@ class TokenCharactersIndexer(TokenIndexer): ) for character in self._character_tokenizer.tokenize(token.text): if getattr(character, "text_id", None) is not None: - # `text_id` being set on the token means that we aren't using the vocab, we just + # `text_id` being set on the token means that we aren't using the vocabulary, we just # use this idx instead. 
index = character.text_id else: @@ -147,3 +149,14 @@ class TokenCharactersIndexer(TokenIndexer): @overrides def get_empty_token_list(self) -> IndexedTokenList: return {"token_characters": []} + + @overrides + def _to_params(self) -> Dict[str, str]: + return { + 'token_min_padding_length': str(self._token_min_padding_length), + 'min_padding_length': str(self._min_padding_length), + 'namespace': str(self._namespace), + 'character_tokenizer': str(self._character_tokenizer), + 'start_tokens': str(self._start_tokens), + 'end_tokens': str(self._end_tokens) + } diff --git a/combo/data/token_indexers/token_const_padding_characters_indexer.py b/combo/data/token_indexers/token_const_padding_characters_indexer.py index 6b7440d79db3cacc84ae4cc5ca54915778880756..d193258d3173c2bc4da25f2e43763676a59f1986 100644 --- a/combo/data/token_indexers/token_const_padding_characters_indexer.py +++ b/combo/data/token_indexers/token_const_padding_characters_indexer.py @@ -7,7 +7,10 @@ import itertools from typing import List, Dict import torch + +from combo.config import Registry from combo.data.token_indexers import IndexedTokenList +from combo.data.token_indexers.token_indexer import TokenIndexer from overrides import overrides from combo.data.token_indexers.token_characters_indexer import TokenCharactersIndexer @@ -15,17 +18,18 @@ from combo.data.tokenizers import CharacterTokenizer from combo.utils import pad_sequence_to_length +@Registry.register(TokenIndexer, 'characters_const_padding') class TokenConstPaddingCharactersIndexer(TokenCharactersIndexer): """Wrapper around allennlp token indexer with const padding.""" def __init__(self, namespace: str = "token_characters", - character_tokenizer: CharacterTokenizer = CharacterTokenizer(), + tokenizer: CharacterTokenizer = CharacterTokenizer(), start_tokens: List[str] = None, end_tokens: List[str] = None, min_padding_length: int = 0, token_min_padding_length: int = 0): - super().__init__(namespace, character_tokenizer, start_tokens, end_tokens, min_padding_length, + super().__init__(namespace, tokenizer, start_tokens, end_tokens, min_padding_length, token_min_padding_length) @overrides diff --git a/combo/data/token_indexers/token_features_indexer.py b/combo/data/token_indexers/token_features_indexer.py index c0c1f21915f7bef134e0f5ad5458cfe876fe1e63..cd88be608398f90efcdb3603ddddbea0248b20b1 100644 --- a/combo/data/token_indexers/token_features_indexer.py +++ b/combo/data/token_indexers/token_features_indexer.py @@ -3,19 +3,19 @@ Adapted from COMBO. 
Author: Mateusz Klimaszewski """ import collections -from abc import ABC from typing import List, Dict import torch from overrides import overrides -from combo.data import Vocabulary +from combo.config import Registry from combo.data.tokenizers.tokenizer import Token from combo.data.token_indexers.token_indexer import TokenIndexer, IndexedTokenList from combo.utils import pad_sequence_to_length -class TokenFeatsIndexer(TokenIndexer, ABC): +@Registry.register(TokenIndexer, 'feats_indexer') +class TokenFeatsIndexer(TokenIndexer): def __init__( self, @@ -34,7 +34,7 @@ class TokenFeatsIndexer(TokenIndexer, ABC): counter[self.namespace][feat] += 1 @overrides - def tokens_to_indices(self, tokens: List[Token], vocabulary: Vocabulary) -> IndexedTokenList: + def tokens_to_indices(self, tokens: List[Token], vocabulary) -> IndexedTokenList: indices: List[List[int]] = [] vocab_size = vocabulary.get_vocab_size(self.namespace) for token in tokens: @@ -79,3 +79,11 @@ class TokenFeatsIndexer(TokenIndexer, ABC): ) tensor_dict[key] = tensor return tensor_dict + + @overrides + def _to_params(self) -> Dict[str, str]: + return { + 'token_min_padding_length': str(self._token_min_padding_length), + 'namespace': str(self.namespace), + 'feature_name': str(self._feature_name) + } diff --git a/combo/data/token_indexers/token_indexer.py b/combo/data/token_indexers/token_indexer.py index 31835daf3a8aa97242369466b2625a69db3e7efd..2c701ff97b85860f12179771bbe791a9b17761be 100644 --- a/combo/data/token_indexers/token_indexer.py +++ b/combo/data/token_indexers/token_indexer.py @@ -6,7 +6,9 @@ https://github.com/allenai/allennlp/blob/main/allennlp/data/token_indexers/token from typing import Any, Dict, List import torch +from overrides import overrides +from combo.config import FromParameters from combo.data.tokenizers.tokenizer import Token from combo.data.vocabulary import Vocabulary from combo.utils import pad_sequence_to_length @@ -19,7 +21,7 @@ from combo.utils import pad_sequence_to_length IndexedTokenList = Dict[str, List[Any]] -class TokenIndexer: +class TokenIndexer(FromParameters): """ A `TokenIndexer` determines how string tokens get represented as arrays of indices in a model. This class both converts strings into numerical values, with the help of a @@ -123,3 +125,9 @@ class TokenIndexer: if isinstance(self, other.__class__): return self.__dict__ == other.__dict__ return NotImplemented + + @overrides + def _to_params(self) -> Dict[str, str]: + return { + 'token_min_padding_length': str(self._token_min_padding_length) + } diff --git a/combo/data/tokenizers/character_tokenizer.py b/combo/data/tokenizers/character_tokenizer.py index 1c873da9198887e02bfe6d655b82be6afb48dd66..51dc5577a9590e8be3454bc47e458390aba81bfc 100644 --- a/combo/data/tokenizers/character_tokenizer.py +++ b/combo/data/tokenizers/character_tokenizer.py @@ -5,10 +5,12 @@ https://github.com/allenai/allennlp/blob/main/allennlp/data/tokenizers/character from typing import List, Union, Dict, Any +from combo.config import Registry from combo.data.tokenizers.token import Token from combo.data.tokenizers.tokenizer import Tokenizer +@Registry.register(Tokenizer, 'character') class CharacterTokenizer(Tokenizer): """ A `CharacterTokenizer` splits strings into character tokens. 
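As above, a hedged sketch rather than part of the change set: building a character-level indexer from a nested configuration, using the names registered in this diff ('characters_const_padding', 'character') and the renamed `tokenizer` argument; the parameter values are invented for illustration.

from combo.config import Registry
from combo.data.token_indexers import TokenIndexer

# The nested {"type": "character"} entry is resolved to a CharacterTokenizer instance
# before TokenConstPaddingCharactersIndexer.__init__ is called.
config = {
    "type": "characters_const_padding",
    "tokenizer": {"type": "character"},
    "min_padding_length": 32,  # invented value, only for illustration
}
indexer = Registry.resolve(TokenIndexer, config["type"]).from_parameters(config)
params = indexer._to_params()  # string-valued dict mirroring the constructor arguments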
diff --git a/combo/data/tokenizers/lambo_tokenizer.py b/combo/data/tokenizers/lambo_tokenizer.py index 43a5ac60baf1cfdb68eb982487235f2eeccda4c1..b3cec7c96c6c2d8cf893db2c63ccdabfd40096d6 100644 --- a/combo/data/tokenizers/lambo_tokenizer.py +++ b/combo/data/tokenizers/lambo_tokenizer.py @@ -1,10 +1,12 @@ from typing import List, Dict, Any +from combo.config import Registry from combo.data.tokenizers.token import Token from combo.data.tokenizers.tokenizer import Tokenizer from lambo.segmenter.lambo import Lambo +@Registry.register(Tokenizer, 'lambo') class LamboTokenizer(Tokenizer): def __init__( diff --git a/combo/data/tokenizers/pretrained_transformer_tokenizer.py b/combo/data/tokenizers/pretrained_transformer_tokenizer.py index 667246037178dbdcbc8fa6602e47b0139aed3624..2db074abaf3e904588d232ecb477d021f306857a 100644 --- a/combo/data/tokenizers/pretrained_transformer_tokenizer.py +++ b/combo/data/tokenizers/pretrained_transformer_tokenizer.py @@ -10,6 +10,7 @@ from typing import Any, Dict, List, Optional, Tuple, Iterable from transformers import PreTrainedTokenizer, AutoTokenizer +from combo.config import Registry from combo.data.tokenizers.token import Token from combo.data.tokenizers.tokenizer import Tokenizer from combo.utils import sanitize_wordpiece @@ -17,6 +18,7 @@ from combo.utils import sanitize_wordpiece logger = logging.getLogger(__name__) +@Registry.register(Tokenizer, 'pretrained_transformer') class PretrainedTransformerTokenizer(Tokenizer): """ A `PretrainedTransformerTokenizer` uses a model from HuggingFace's diff --git a/combo/data/tokenizers/spacy_tokenizer.py b/combo/data/tokenizers/spacy_tokenizer.py index 927da8abb67400b4b4ac3d72da09b517026830f9..6844aa5a4bb8e28cedfbd58660767921ec4fb73b 100644 --- a/combo/data/tokenizers/spacy_tokenizer.py +++ b/combo/data/tokenizers/spacy_tokenizer.py @@ -8,11 +8,13 @@ from typing import List, Optional import spacy from spacy.tokens import Doc +from combo.config import Registry from combo.data.tokenizers.token import Token from combo.data.tokenizers.tokenizer import Tokenizer from combo.utils.spacy import get_spacy_model +@Registry.register(Tokenizer, 'spacy') class SpacyTokenizer(Tokenizer): """ A `Tokenizer` that uses spaCy's tokenizer. It's fast and reasonable - this is the @@ -155,7 +157,7 @@ class _WhitespaceSpacyTokenizer: as follows: nlp = spacy.load("en_core_web_md") # hack to replace tokenizer with a whitespace tokenizer - nlp.tokenizer = _WhitespaceSpacyTokenizer(nlp.vocab) + nlp.tokenizer = _WhitespaceSpacyTokenizer(nlp.vocabulary) ... use nlp("here is some text") as normal. """ diff --git a/combo/data/tokenizers/tokenizer.py b/combo/data/tokenizers/tokenizer.py index b1645015e2db08306ed2bc9a2e975266b205a359..5c4d66bb203ef80423834b57d16fce5588081e32 100644 --- a/combo/data/tokenizers/tokenizer.py +++ b/combo/data/tokenizers/tokenizer.py @@ -7,10 +7,12 @@ import logging from .token import Token from typing import List, Optional +from ...config import FromParameters + logger = logging.getLogger(__name__) -class Tokenizer: +class Tokenizer(FromParameters): """ A `Tokenizer` splits strings of text into tokens. 
Typically, this either splits text into word tokens or character tokens, and those are the two tokenizer subclasses we have implemented diff --git a/combo/data/tokenizers/whitespace_tokenizer.py b/combo/data/tokenizers/whitespace_tokenizer.py index a4802ded17e7f8da2ca7b3deecd7be00b55e7c1f..fe3413f812f7c8d9e12bb007fd7f4b5003fddbf1 100644 --- a/combo/data/tokenizers/whitespace_tokenizer.py +++ b/combo/data/tokenizers/whitespace_tokenizer.py @@ -1,9 +1,11 @@ from typing import List, Dict, Any +from combo.config import Registry from combo.data.tokenizers.token import Token from combo.data.tokenizers.tokenizer import Tokenizer +@Registry.register(Tokenizer, 'whitespace') class WhitespaceTokenizer(Tokenizer): """ A `Tokenizer` that assumes you've already done your own tokenization somehow and have diff --git a/combo/data/vocabulary.py b/combo/data/vocabulary.py index 8f455a92993113a5e9392719c1ecd82b7906405e..aea54bfb7d9afe21f1b7124d161ea7dc28a2f80a 100644 --- a/combo/data/vocabulary.py +++ b/combo/data/vocabulary.py @@ -3,13 +3,15 @@ import os import re import glob from collections import defaultdict -from typing import Dict, Optional, Iterable, Set, List +from typing import Dict, Optional, Iterable, Set, List, Union import logging from filelock import FileLock from transformers import PreTrainedTokenizer +from combo.common import Tqdm +from combo.config import FromParameters, Registry from combo.utils import ConfigurationError from combo.utils.file_utils import cached_path @@ -34,6 +36,24 @@ def match_namespace(pattern: str, namespace: str): return False +def _read_pretrained_tokens(embeddings_file_uri: str) -> List[str]: + # Moving this import to the top breaks everything (cycling import, I guess) + from combo.data.token_embedders.embedding import EmbeddingsTextFile + + logger.info("Reading pretrained tokens from: %s", embeddings_file_uri) + tokens: List[str] = [] + with EmbeddingsTextFile(embeddings_file_uri) as embeddings_file: + for line_number, line in enumerate(Tqdm.tqdm(embeddings_file), start=1): + token_end = line.find(" ") + if token_end >= 0: + token = line[:token_end] + tokens.append(token) + else: + line_begin = line[:20] + "..." 
if len(line) > 20 else line + logger.warning("Skipping line number %d: %s", line_number, line_begin) + return tokens + + class NamespaceVocabulary: def __init__(self, padding_token: Optional[str] = None, @@ -93,9 +113,16 @@ class _NamespaceDependentDefaultDict(defaultdict[str, NamespaceVocabulary]): return value -class Vocabulary: +class Vocabulary(FromParameters): def __init__(self, + counter: Dict[str, Dict[str, int]] = None, + min_count: Dict[str, int] = None, + max_vocab_size: Union[int, Dict[str, int]] = None, non_padded_namespaces: Iterable[str] = DEFAULT_NON_PADDED_NAMESPACES, + pretrained_files: Optional[Dict[str, str]] = None, + only_include_pretrained_words: bool = False, + tokens_to_add: Dict[str, List[str]] = None, + min_pretrained_embeddings: Dict[str, int] = None, padding_token: Optional[str] = DEFAULT_PADDING_TOKEN, oov_token: Optional[str] = DEFAULT_OOV_TOKEN): @@ -112,12 +139,309 @@ class Vocabulary: self._vocab = _NamespaceDependentDefaultDict(self._non_padded_namespaces, self._padding_token, self._oov_token) + self._retained_counter: Optional[Dict[str, Dict[str, int]]] = None def _extend(self, - tokens_to_add: Dict[str, Dict[str, int]]): + counter: Dict[str, Dict[str, int]] = None, + min_count: Dict[str, int] = None, + max_vocab_size: Union[int, Dict[str, int]] = None, + non_padded_namespaces: Iterable[str] = DEFAULT_NON_PADDED_NAMESPACES, + pretrained_files: Optional[Dict[str, str]] = None, + only_include_pretrained_words: bool = False, + tokens_to_add: Dict[str, Dict[str, int]] = None, + min_pretrained_embeddings: Dict[str, int] = None): + if min_count is not None: + for key in min_count: + if counter is not None and key not in counter or counter is None: + raise ConfigurationError( + f"The key '{key}' is present in min_count but not in counter" + ) + if not isinstance(max_vocab_size, dict): + int_max_vocab_size = max_vocab_size + max_vocab_size = defaultdict(lambda: int_max_vocab_size) # type: ignore + + min_count = min_count or {} + pretrained_files = pretrained_files or {} + min_pretrained_embeddings = min_pretrained_embeddings or {} + non_padded_namespaces = set(non_padded_namespaces) + counter = counter or {} + tokens_to_add = tokens_to_add or {} + self._retained_counter = counter + # Make sure vocabulary extension is safe. + current_namespaces = {*self._vocab} + extension_namespaces = {*counter, *tokens_to_add} + + for namespace in current_namespaces & extension_namespaces: + # if new namespace was already present + # Either both should be padded or none should be. + original_padded = not any( + match_namespace(pattern, namespace) for pattern in self._non_padded_namespaces + ) + extension_padded = not any( + match_namespace(pattern, namespace) for pattern in non_padded_namespaces + ) + if original_padded != extension_padded: + raise ConfigurationError( + "Common namespace {} has conflicting ".format(namespace) + + "setting of padded = True/False. " + + "Hence extension cannot be done." 
+ ) + + self._vocab.add_non_padded_namespaces(non_padded_namespaces) + self._non_padded_namespaces.update(non_padded_namespaces) + + for namespace in counter: + pretrained_set: Optional[Set] = None + if namespace in pretrained_files: + pretrained_list = _read_pretrained_tokens(pretrained_files[namespace]) + min_embeddings = min_pretrained_embeddings.get(namespace, 0) + if min_embeddings > 0 or min_embeddings == -1: + tokens_old = tokens_to_add.get(namespace, []) + tokens_new = ( + pretrained_list + if min_embeddings == -1 + else pretrained_list[:min_embeddings] + ) + tokens_to_add[namespace] = tokens_old + tokens_new + pretrained_set = set(pretrained_list) + token_counts = list(counter[namespace].items()) + token_counts.sort(key=lambda x: x[1], reverse=True) + max_vocab: Optional[int] + try: + max_vocab = max_vocab_size[namespace] + except KeyError: + max_vocab = None + if max_vocab: + token_counts = token_counts[:max_vocab] + for token, count in token_counts: + if pretrained_set is not None: + if only_include_pretrained_words: + if token in pretrained_set and count >= min_count.get(namespace, 1): + self.add_token_to_namespace(token, namespace) + elif token in pretrained_set or count >= min_count.get(namespace, 1): + self.add_token_to_namespace(token, namespace) + elif count >= min_count.get(namespace, 1): + self.add_token_to_namespace(token, namespace) + for namespace, tokens in tokens_to_add.items(): self._vocab[namespace].append_tokens(tokens) + @classmethod + def from_files(cls, + directory: Union[str, os.PathLike], + padding_token: Optional[str] = DEFAULT_PADDING_TOKEN, + oov_token: Optional[str] = DEFAULT_OOV_TOKEN, + ) -> "Vocabulary": + """ + Loads a `Vocabulary` that was serialized either using `save_to_files` or inside + a model archive file. + + # Parameters + + directory : `str` + The directory or archive file containing the serialized vocabulary. + """ + logger.info("Loading token dictionary from %s.", directory) + padding_token = padding_token if padding_token is not None else DEFAULT_PADDING_TOKEN + oov_token = oov_token if oov_token is not None else DEFAULT_OOV_TOKEN + + if not os.path.isdir(directory): + base_directory = cached_path(directory, extract_archive=True) + # For convenience we'll check for a 'vocabulary' subdirectory of the archive. + # That way you can use model archives directly. 
+ vocab_subdir = os.path.join(base_directory, "vocabulary") + if os.path.isdir(vocab_subdir): + directory = vocab_subdir + elif os.path.isdir(base_directory): + directory = base_directory + else: + raise ConfigurationError(f"{directory} is neither a directory nor an archive") + + files = [file for file in glob.glob(os.path.join(directory, '*.txt'))] + + if len(files) == 0: + logger.warning(f'Directory %s is empty' % directory) + + with FileLock(os.path.join(directory, ".lock")): + with codecs.open( + os.path.join(directory, NAMESPACE_PADDING_FILE), "r", "utf-8" + ) as namespace_file: + non_padded_namespaces = [namespace_str.strip() for namespace_str in namespace_file] + + vocab = cls( + non_padded_namespaces=non_padded_namespaces, + padding_token=padding_token, + oov_token=oov_token, + ) + + for namespace_filename in os.listdir(directory): + if namespace_filename == NAMESPACE_PADDING_FILE: + continue + if namespace_filename.startswith("."): + continue + namespace = namespace_filename.replace(".txt", "") + if any(match_namespace(pattern, namespace) for pattern in non_padded_namespaces): + is_padded = False + else: + is_padded = True + filename = os.path.join(directory, namespace_filename) + vocab.set_from_file(filename, is_padded, namespace=namespace, oov_token=oov_token) + + return vocab + + @classmethod + def from_instances( + cls, + instances: Iterable["Instance"], + min_count: Dict[str, int] = None, + max_vocab_size: Union[int, Dict[str, int]] = None, + non_padded_namespaces: Iterable[str] = DEFAULT_NON_PADDED_NAMESPACES, + pretrained_files: Optional[Dict[str, str]] = None, + only_include_pretrained_words: bool = False, + tokens_to_add: Dict[str, List[str]] = None, + min_pretrained_embeddings: Dict[str, int] = None, + padding_token: Optional[str] = DEFAULT_PADDING_TOKEN, + oov_token: Optional[str] = DEFAULT_OOV_TOKEN, + ) -> "Vocabulary": + """ + Constructs a vocabulary given a collection of `Instances` and some parameters. + We count all of the vocabulary items in the instances, then pass those counts + and the other parameters, to :func:`__init__`. See that method for a description + of what the other parameters do. + + The `instances` parameter does not get an entry in a typical AllenNLP configuration file, + but the other parameters do (if you want non-default parameters). 
+ """ + logger.info("Fitting token dictionary from dataset.") + padding_token = padding_token if padding_token is not None else DEFAULT_PADDING_TOKEN + oov_token = oov_token if oov_token is not None else DEFAULT_OOV_TOKEN + namespace_token_counts: Dict[str, Dict[str, int]] = defaultdict(lambda: defaultdict(int)) + for instance in Tqdm.tqdm(instances, desc="building vocabulary"): + instance.count_vocab_items(namespace_token_counts) + + return cls( + counter=namespace_token_counts, + min_count=min_count, + max_vocab_size=max_vocab_size, + non_padded_namespaces=non_padded_namespaces, + pretrained_files=pretrained_files, + only_include_pretrained_words=only_include_pretrained_words, + tokens_to_add=tokens_to_add, + min_pretrained_embeddings=min_pretrained_embeddings, + padding_token=padding_token, + oov_token=oov_token, + ) + + @classmethod + def from_files_and_instances( + cls, + instances: Iterable["Instance"], + directory: str, + padding_token: Optional[str] = DEFAULT_PADDING_TOKEN, + oov_token: Optional[str] = DEFAULT_OOV_TOKEN, + min_count: Dict[str, int] = None, + max_vocab_size: Union[int, Dict[str, int]] = None, + non_padded_namespaces: Iterable[str] = DEFAULT_NON_PADDED_NAMESPACES, + pretrained_files: Optional[Dict[str, str]] = None, + only_include_pretrained_words: bool = False, + tokens_to_add: Dict[str, List[str]] = None, + min_pretrained_embeddings: Dict[str, int] = None, + ) -> "Vocabulary": + """ + Extends an already generated vocabulary using a collection of instances. + + The `instances` parameter does not get an entry in a typical AllenNLP configuration file, + but the other parameters do (if you want non-default parameters). See `__init__` for a + description of what the other parameters mean. + """ + vocab = cls.from_files(directory, padding_token, oov_token) + logger.info("Fitting token dictionary from dataset.") + namespace_token_counts: Dict[str, Dict[str, int]] = defaultdict(lambda: defaultdict(int)) + for instance in Tqdm.tqdm(instances): + instance.count_vocab_items(namespace_token_counts) + vocab._extend( + counter=namespace_token_counts, + min_count=min_count, + max_vocab_size=max_vocab_size, + non_padded_namespaces=non_padded_namespaces, + pretrained_files=pretrained_files, + only_include_pretrained_words=only_include_pretrained_words, + tokens_to_add=tokens_to_add, + min_pretrained_embeddings=min_pretrained_embeddings, + ) + return vocab + + @classmethod + def from_pretrained_transformer_and_instances( + cls, + instances: Iterable["Instance"], + transformers: Dict[str, str], + min_count: Dict[str, int] = None, + max_vocab_size: Union[int, Dict[str, int]] = None, + non_padded_namespaces: Iterable[str] = DEFAULT_NON_PADDED_NAMESPACES, + pretrained_files: Optional[Dict[str, str]] = None, + only_include_pretrained_words: bool = False, + tokens_to_add: Dict[str, List[str]] = None, + min_pretrained_embeddings: Dict[str, int] = None, + padding_token: Optional[str] = DEFAULT_PADDING_TOKEN, + oov_token: Optional[str] = DEFAULT_OOV_TOKEN, + ) -> "Vocabulary": + """ + Construct a vocabulary given a collection of `Instance`'s and some parameters. Then extends + it with generated vocabularies from pretrained transformers. + + Vocabulary from instances is constructed by passing parameters to :func:`from_instances`, + and then updated by including merging in vocabularies from + :func:`from_pretrained_transformer`. See other methods for full descriptions for what the + other parameters do. 
+ + The `instances` parameters does not get an entry in a typical AllenNLP configuration file, + other parameters do (if you want non-default parameters). + + # Parameters + + transformers : `Dict[str, str]` + Dictionary mapping the vocabulary namespaces (keys) to a transformer model name (value). + Namespaces not included will be ignored. + + # Examples + + You can use this constructor by modifying the following example within your training + configuration. + + ```jsonnet + { + vocabulary: { + type: 'from_pretrained_transformer_and_instances', + transformers: { + 'namespace1': 'bert-base-cased', + 'namespace2': 'roberta-base', + }, + } + } + ``` + """ + vocab = cls.from_instances( + instances=instances, + min_count=min_count, + max_vocab_size=max_vocab_size, + non_padded_namespaces=non_padded_namespaces, + pretrained_files=pretrained_files, + only_include_pretrained_words=only_include_pretrained_words, + tokens_to_add=tokens_to_add, + min_pretrained_embeddings=min_pretrained_embeddings, + padding_token=padding_token, + oov_token=oov_token, + ) + + for namespace, model_name in transformers.items(): + transformer_vocab = cls.from_pretrained_transformer( + model_name=model_name, namespace=namespace + ) + vocab.extend_from_vocab(transformer_vocab) + + return vocab + def save_to_files(self, directory: str) -> None: """ Persist this Vocabulary to files, so it can be reloaded later. @@ -133,7 +457,7 @@ class Vocabulary: logger.warning("Directory %s is not empty", directory) # We use a lock file to avoid race conditions where multiple processes - # might be reading/writing from/to the same vocab files at once. + # might be reading/writing from/to the same vocabulary files at once. with FileLock(os.path.join(directory, ".lock")): with codecs.open( os.path.join(directory, NAMESPACE_PADDING_FILE), "w", "utf-8" @@ -190,7 +514,7 @@ class Vocabulary: self, tokenizer: PreTrainedTokenizer, namespace: str = "tokens" ) -> None: """ - Copies tokens from a transformer tokenizer's vocab into the given namespace. + Copies tokens from a transformer tokenizer's vocabulary into the given namespace. """ try: vocab_items = tokenizer.get_vocab().items() @@ -269,87 +593,75 @@ class Vocabulary: self._vocab[namespace].insert_token(token, index) if is_padded: - assert self._oov_token in self._vocab[namespace].get_itos(), "OOV token not found!" - - -class PretrainedTransformerVocabulary(Vocabulary): - def __init__(self, - model_name: str, - namespace: str = "tokens", - oov_token: Optional[str] = None): - """ - Initialize a vocabulary from the vocabulary of a pretrained transformer model. - If `oov_token` is not given, we will try to infer it from the transformer tokenizer. - """ - from combo.common import cached_transformers - - tokenizer = cached_transformers.get_tokenizer(model_name) - if oov_token is None: - if hasattr(tokenizer, "_unk_token"): - oov_token = tokenizer._unk_token - elif hasattr(tokenizer, "special_tokens_map"): - oov_token = tokenizer.special_tokens_map.get("unk_token") - - super().__init__(non_padded_namespaces=[namespace], - oov_token=oov_token) - self.add_transformer_vocab(tokenizer, namespace) - - -class FromFilesVocabulary(Vocabulary): - def __init__(self, - directory: str, - padding_token: Optional[str] = DEFAULT_PADDING_TOKEN, - oov_token: Optional[str] = DEFAULT_OOV_TOKEN) -> None: + assert self._oov_token in self._vocab[namespace].get_itos().values(), "OOV token not found!" 
+ + +def get_slices_if_not_provided(vocab: Vocabulary): + if hasattr(vocab, "slices"): + return vocab.slices + + if "feats_labels" in vocab.get_namespaces(): + idx2token = vocab.get_index_to_token_vocabulary("feats_labels") + for _, v in dict(idx2token).items(): + if v not in ["_", "__PAD__"]: + empty_value = v.split("=")[0] + "=None" + vocab.add_token_to_namespace(empty_value, "feats_labels") + + slices = {} + for idx, name in vocab.get_index_to_token_vocabulary("feats_labels").items(): + # There are 2 types features: with (Case=Acc) or without assigment (None). + # Here we group their indices by name (before assigment sign). + name = name.split("=")[0] + if name in slices: + slices[name].append(idx) + else: + slices[name] = [idx] + vocab.slices = slices + return vocab.slices + + +@Registry.register(Vocabulary, "from_instances_extended") +class FromInstancesVocabulary(Vocabulary, FromParameters): + @classmethod + def from_instances_extended( + cls, + instances: Iterable["Instance"], + min_count: Dict[str, int] = None, + max_vocab_size: Union[int, Dict[str, int]] = None, + non_padded_namespaces: Iterable[str] = DEFAULT_NON_PADDED_NAMESPACES, + pretrained_files: Optional[Dict[str, str]] = None, + only_include_pretrained_words: bool = False, + min_pretrained_embeddings: Dict[str, int] = None, + padding_token: Optional[str] = DEFAULT_PADDING_TOKEN, + oov_token: Optional[str] = DEFAULT_OOV_TOKEN, + ) -> "Vocabulary": """ - Adapted from https://github.com/allenai/allennlp/blob/main/allennlp/data/vocabulary.py - - :param directory: - :param padding_token: - :param oov_token: - :return: + Extension to manually fill gaps in missing 'feats_labels'. """ - logger.info("Loading token dictionary from %s.", directory) - padding_token = padding_token if padding_token is not None else DEFAULT_PADDING_TOKEN - oov_token = oov_token if oov_token is not None else DEFAULT_OOV_TOKEN - - if not os.path.isdir(directory): - base_directory = cached_path(directory, extract_archive=True) - # For convenience we'll check for a 'vocabulary' subdirectory of the archive. - # That way you can use model archives directly. 
- vocab_subdir = os.path.join(base_directory, "vocabulary") - if os.path.isdir(vocab_subdir): - directory = vocab_subdir - elif os.path.isdir(base_directory): - directory = base_directory - else: - raise ConfigurationError(f"{directory} is neither a directory nor an archive") - - files = [file for file in glob.glob(os.path.join(directory, '*.txt'))] - - if len(files) == 0: - logger.warning(f'Directory %s is empty' % directory) - - with FileLock(os.path.join(directory, ".lock")): - with codecs.open( - os.path.join(directory, NAMESPACE_PADDING_FILE), "r", "utf-8" - ) as namespace_file: - non_padded_namespaces = [namespace_str.strip() for namespace_str in namespace_file] - - super().__init__( - non_padded_namespaces=non_padded_namespaces, - padding_token=padding_token, - oov_token=oov_token, - ) - - for namespace_filename in os.listdir(directory): - if namespace_filename == NAMESPACE_PADDING_FILE: - continue - if namespace_filename.startswith("."): - continue - namespace = namespace_filename.replace(".txt", "") - if any(match_namespace(pattern, namespace) for pattern in non_padded_namespaces): - is_padded = False - else: - is_padded = True - filename = os.path.join(directory, namespace_filename) - self.set_from_file(filename, is_padded, namespace=namespace, oov_token=oov_token) + # Load manually tokens from pretrained file (using different strategy + # - only words add all embedding file, without checking if were seen + # in any dataset. + tokens_to_add = None + if pretrained_files and "tokens" in pretrained_files: + pretrained_set = set(_read_pretrained_tokens(pretrained_files["tokens"])) + tokens_to_add = {"tokens": list(pretrained_set)} + pretrained_files = None + + vocab = super().from_instances( + instances=instances, + min_count=min_count, + max_vocab_size=max_vocab_size, + non_padded_namespaces=non_padded_namespaces, + pretrained_files=pretrained_files, + only_include_pretrained_words=only_include_pretrained_words, + tokens_to_add=tokens_to_add, + min_pretrained_embeddings=min_pretrained_embeddings, + padding_token=padding_token, + oov_token=oov_token + ) + # Extending vocabulary with features that does not show up explicitly. + # To know all features we need to read full dataset first. + # Adding auxiliary '=None' feature for each category is needed + # to perform classification. 
+ get_slices_if_not_provided(vocab) + return vocab diff --git a/combo/example.ipynb b/combo/example.ipynb index 9597084963bd621d8ab5bab95c584bed04913131..02b203df488e957c0a3c82550ca7c19fa5a1810d 100644 --- a/combo/example.ipynb +++ b/combo/example.ipynb @@ -7,8 +7,8 @@ "metadata": { "collapsed": true, "ExecuteTime": { - "end_time": "2023-09-04T14:36:56.659578Z", - "start_time": "2023-09-04T14:36:45.221897Z" + "end_time": "2023-09-22T07:28:27.168062Z", + "start_time": "2023-09-22T07:28:23.442292Z" } }, "outputs": [], @@ -16,15 +16,88 @@ "from combo.predict import COMBO" ] }, + { + "cell_type": "code", + "execution_count": 2, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "error loading _jsonnet (this is expected on Windows), treating /var/folders/n6/5g0cjy9s6j9cgp2xbw_6c2v40000gn/T/tmpzyupif28/config.json as plain json\n" + ] + }, + { + "ename": "TypeError", + "evalue": "__init__() missing 2 required positional arguments: 'model' and 'dataset_reader'", + "output_type": "error", + "traceback": [ + "\u001B[0;31m---------------------------------------------------------------------------\u001B[0m", + "\u001B[0;31mKeyError\u001B[0m Traceback (most recent call last)", + "File \u001B[0;32m~/PycharmProjects/combo-lightning/combo/config/registry.py:48\u001B[0m, in \u001B[0;36mRegistry.resolve\u001B[0;34m(cls, category, class_name)\u001B[0m\n\u001B[1;32m 47\u001B[0m \u001B[38;5;28;01mtry\u001B[39;00m:\n\u001B[0;32m---> 48\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28;43mcls\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43m__classes\u001B[49m\u001B[43m[\u001B[49m\u001B[43mcategory\u001B[49m\u001B[43m]\u001B[49m\u001B[43m[\u001B[49m\u001B[43mclass_name\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mlower\u001B[49m\u001B[43m(\u001B[49m\u001B[43m)\u001B[49m\u001B[43m]\u001B[49m\n\u001B[1;32m 49\u001B[0m \u001B[38;5;28;01mexcept\u001B[39;00m \u001B[38;5;167;01mKeyError\u001B[39;00m:\n", + "\u001B[0;31mKeyError\u001B[0m: 'combo_dependency_parsing_from_vocab'", + "\nDuring handling of the above exception, another exception occurred:\n", + "\u001B[0;31mKeyError\u001B[0m Traceback (most recent call last)", + "File \u001B[0;32m~/PycharmProjects/combo-lightning/combo/config/registry.py:51\u001B[0m, in \u001B[0;36mRegistry.resolve\u001B[0;34m(cls, category, class_name)\u001B[0m\n\u001B[1;32m 50\u001B[0m \u001B[38;5;28;01mtry\u001B[39;00m:\n\u001B[0;32m---> 51\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28;43mcls\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43m__defaults\u001B[49m\u001B[43m[\u001B[49m\u001B[43mcategory\u001B[49m\u001B[43m]\u001B[49m\n\u001B[1;32m 52\u001B[0m \u001B[38;5;28;01mexcept\u001B[39;00m \u001B[38;5;167;01mKeyError\u001B[39;00m:\n", + "\u001B[0;31mKeyError\u001B[0m: <class 'combo.modules.parser.DependencyRelationModel'>", + "\nDuring handling of the above exception, another exception occurred:\n", + "\u001B[0;31mRegistryException\u001B[0m Traceback (most recent call last)", + "File \u001B[0;32m~/PycharmProjects/combo-lightning/combo/config/from_parameters.py:41\u001B[0m, in \u001B[0;36m_resolve\u001B[0;34m(type, values, pass_to_subclasses, default_name)\u001B[0m\n\u001B[1;32m 40\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;124m'\u001B[39m\u001B[38;5;124mtype\u001B[39m\u001B[38;5;124m'\u001B[39m \u001B[38;5;129;01min\u001B[39;00m values\u001B[38;5;241m.\u001B[39mkeys():\n\u001B[0;32m---> 41\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m 
\u001B[43mRegistry\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mresolve\u001B[49m\u001B[43m(\u001B[49m\n\u001B[1;32m 42\u001B[0m \u001B[43m \u001B[49m\u001B[38;5;28;43mtype\u001B[39;49m\u001B[43m,\u001B[49m\n\u001B[1;32m 43\u001B[0m \u001B[43m \u001B[49m\u001B[43mvalues\u001B[49m\u001B[43m[\u001B[49m\u001B[38;5;124;43m'\u001B[39;49m\u001B[38;5;124;43mtype\u001B[39;49m\u001B[38;5;124;43m'\u001B[39;49m\u001B[43m]\u001B[49m\n\u001B[1;32m 44\u001B[0m \u001B[43m \u001B[49m\u001B[43m)\u001B[49m\u001B[38;5;241m.\u001B[39mfrom_parameters({\u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mvalues, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mpass_to_subclasses}, pass_to_subclasses)\n\u001B[1;32m 45\u001B[0m \u001B[38;5;28;01melse\u001B[39;00m:\n", + "File \u001B[0;32m~/PycharmProjects/combo-lightning/combo/config/registry.py:53\u001B[0m, in \u001B[0;36mRegistry.resolve\u001B[0;34m(cls, category, class_name)\u001B[0m\n\u001B[1;32m 52\u001B[0m \u001B[38;5;28;01mexcept\u001B[39;00m \u001B[38;5;167;01mKeyError\u001B[39;00m:\n\u001B[0;32m---> 53\u001B[0m \u001B[38;5;28;01mraise\u001B[39;00m RegistryException(\u001B[38;5;124mf\u001B[39m\u001B[38;5;124m'\u001B[39m\u001B[38;5;124mNo \u001B[39m\u001B[38;5;132;01m{\u001B[39;00m\u001B[38;5;28mstr\u001B[39m(category)\u001B[38;5;132;01m}\u001B[39;00m\u001B[38;5;124m with registered name \u001B[39m\u001B[38;5;132;01m{\u001B[39;00mclass_name\u001B[38;5;241m.\u001B[39mlower()\u001B[38;5;132;01m}\u001B[39;00m\u001B[38;5;124m, no default value either\u001B[39m\u001B[38;5;124m'\u001B[39m)\n", + "\u001B[0;31mRegistryException\u001B[0m: No <class 'combo.modules.parser.DependencyRelationModel'> with registered name combo_dependency_parsing_from_vocab, no default value either", + "\nDuring handling of the above exception, another exception occurred:\n", + "\u001B[0;31mTypeError\u001B[0m Traceback (most recent call last)", + "Cell \u001B[0;32mIn[2], line 1\u001B[0m\n\u001B[0;32m----> 1\u001B[0m nlp \u001B[38;5;241m=\u001B[39m \u001B[43mCOMBO\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mfrom_pretrained\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;124;43m\"\u001B[39;49m\u001B[38;5;124;43mpolish-pdb-ud29\u001B[39;49m\u001B[38;5;124;43m\"\u001B[39;49m\u001B[43m)\u001B[49m\n", + "File \u001B[0;32m~/PycharmProjects/combo-lightning/combo/predict.py:249\u001B[0m, in \u001B[0;36mCOMBO.from_pretrained\u001B[0;34m(cls, path, tokenizer, batch_size, cuda_device)\u001B[0m\n\u001B[1;32m 246\u001B[0m logger\u001B[38;5;241m.\u001B[39merror(e)\n\u001B[1;32m 247\u001B[0m \u001B[38;5;28;01mraise\u001B[39;00m e\n\u001B[0;32m--> 249\u001B[0m archive \u001B[38;5;241m=\u001B[39m \u001B[43mmodels\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mload_archive\u001B[49m\u001B[43m(\u001B[49m\u001B[43mmodel_path\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mcuda_device\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mcuda_device\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 250\u001B[0m model \u001B[38;5;241m=\u001B[39m archive\u001B[38;5;241m.\u001B[39mmodel\n\u001B[1;32m 251\u001B[0m dataset_reader_class \u001B[38;5;241m=\u001B[39m archive\u001B[38;5;241m.\u001B[39mconfig[\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mdataset_reader\u001B[39m\u001B[38;5;124m\"\u001B[39m]\u001B[38;5;241m.\u001B[39mget(\n\u001B[1;32m 252\u001B[0m \u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mtype\u001B[39m\u001B[38;5;124m\"\u001B[39m, \u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mconllu\u001B[39m\u001B[38;5;124m\"\u001B[39m)\n", + "File 
\u001B[0;32m~/PycharmProjects/combo-lightning/combo/models/archival.py:199\u001B[0m, in \u001B[0;36mload_archive\u001B[0;34m(archive_file, cuda_device, overrides, weights_file)\u001B[0m\n\u001B[1;32m 195\u001B[0m \u001B[38;5;66;03m# Instantiate model and dataset readers. Use a duplicate of the config, as it will get consumed.\u001B[39;00m\n\u001B[1;32m 196\u001B[0m dataset_reader, validation_dataset_reader \u001B[38;5;241m=\u001B[39m _load_dataset_readers(\n\u001B[1;32m 197\u001B[0m config\u001B[38;5;241m.\u001B[39mduplicate(), serialization_dir\n\u001B[1;32m 198\u001B[0m )\n\u001B[0;32m--> 199\u001B[0m model \u001B[38;5;241m=\u001B[39m \u001B[43m_load_model\u001B[49m\u001B[43m(\u001B[49m\u001B[43mconfig\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mduplicate\u001B[49m\u001B[43m(\u001B[49m\u001B[43m)\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mweights_path\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mserialization_dir\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mcuda_device\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 201\u001B[0m \u001B[38;5;28;01mfinally\u001B[39;00m:\n\u001B[1;32m 202\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m tempdir \u001B[38;5;129;01mis\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m \u001B[38;5;28;01mNone\u001B[39;00m:\n", + "File \u001B[0;32m~/PycharmProjects/combo-lightning/combo/models/archival.py:234\u001B[0m, in \u001B[0;36m_load_model\u001B[0;34m(config, weights_path, serialization_dir, cuda_device)\u001B[0m\n\u001B[1;32m 233\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21m_load_model\u001B[39m(config, weights_path, serialization_dir, cuda_device):\n\u001B[0;32m--> 234\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[43mModel\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mload\u001B[49m\u001B[43m(\u001B[49m\n\u001B[1;32m 235\u001B[0m \u001B[43m \u001B[49m\u001B[43mconfig\u001B[49m\u001B[43m,\u001B[49m\n\u001B[1;32m 236\u001B[0m \u001B[43m \u001B[49m\u001B[43mweights_file\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mweights_path\u001B[49m\u001B[43m,\u001B[49m\n\u001B[1;32m 237\u001B[0m \u001B[43m \u001B[49m\u001B[43mserialization_dir\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mserialization_dir\u001B[49m\u001B[43m,\u001B[49m\n\u001B[1;32m 238\u001B[0m \u001B[43m \u001B[49m\u001B[43mcuda_device\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mcuda_device\u001B[49m\u001B[43m,\u001B[49m\n\u001B[1;32m 239\u001B[0m \u001B[43m \u001B[49m\u001B[43m)\u001B[49m\n", + "File \u001B[0;32m~/PycharmProjects/combo-lightning/combo/modules/model.py:407\u001B[0m, in \u001B[0;36mModel.load\u001B[0;34m(cls, config, serialization_dir, weights_file, cuda_device)\u001B[0m\n\u001B[1;32m 401\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m \u001B[38;5;28misinstance\u001B[39m(model_class, \u001B[38;5;28mtype\u001B[39m):\n\u001B[1;32m 402\u001B[0m \u001B[38;5;66;03m# If you're using from_archive to specify your model (e.g., for fine tuning), then you\u001B[39;00m\n\u001B[1;32m 403\u001B[0m \u001B[38;5;66;03m# can't currently override the behavior of _load; we just use the default Model._load.\u001B[39;00m\n\u001B[1;32m 404\u001B[0m \u001B[38;5;66;03m# If we really need to change this, we would need to implement a recursive\u001B[39;00m\n\u001B[1;32m 405\u001B[0m \u001B[38;5;66;03m# get_model_class method, that recurses whenever it finds a from_archive model type.\u001B[39;00m\n\u001B[1;32m 406\u001B[0m model_class \u001B[38;5;241m=\u001B[39m Model\n\u001B[0;32m--> 
407\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[43mmodel_class\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43m_load\u001B[49m\u001B[43m(\u001B[49m\u001B[43mconfig\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mserialization_dir\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mweights_file\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mcuda_device\u001B[49m\u001B[43m)\u001B[49m\n", + "File \u001B[0;32m~/PycharmProjects/combo-lightning/combo/modules/model.py:305\u001B[0m, in \u001B[0;36mModel._load\u001B[0;34m(cls, config, serialization_dir, weights_file, cuda_device)\u001B[0m\n\u001B[1;32m 303\u001B[0m remove_keys_from_params(model_params)\n\u001B[1;32m 304\u001B[0m model_type \u001B[38;5;241m=\u001B[39m Registry\u001B[38;5;241m.\u001B[39mresolve(Model, model_params\u001B[38;5;241m.\u001B[39mget(\u001B[38;5;124m'\u001B[39m\u001B[38;5;124mtype\u001B[39m\u001B[38;5;124m'\u001B[39m, \u001B[38;5;124m'\u001B[39m\u001B[38;5;124msemantic_multitask\u001B[39m\u001B[38;5;124m'\u001B[39m))\n\u001B[0;32m--> 305\u001B[0m model \u001B[38;5;241m=\u001B[39m \u001B[43mmodel_type\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mfrom_parameters\u001B[49m\u001B[43m(\u001B[49m\n\u001B[1;32m 306\u001B[0m \u001B[43m \u001B[49m\u001B[38;5;28;43mdict\u001B[39;49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;28;43mdict\u001B[39;49m\u001B[43m(\u001B[49m\u001B[43mmodel_params\u001B[49m\u001B[43m)\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mvocabulary\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mvocab\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mserialization_dir\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mserialization_dir\u001B[49m\u001B[43m)\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43m{\u001B[49m\u001B[38;5;124;43m'\u001B[39;49m\u001B[38;5;124;43mvocabulary\u001B[39;49m\u001B[38;5;124;43m'\u001B[39;49m\u001B[43m:\u001B[49m\u001B[43m \u001B[49m\u001B[43mvocab\u001B[49m\u001B[43m}\u001B[49m\n\u001B[1;32m 307\u001B[0m \u001B[43m\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 309\u001B[0m \u001B[38;5;66;03m# Force model to cpu or gpu, as appropriate, to make sure that the embeddings are\u001B[39;00m\n\u001B[1;32m 310\u001B[0m \u001B[38;5;66;03m# in sync with the weights\u001B[39;00m\n\u001B[1;32m 311\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m cuda_device \u001B[38;5;241m>\u001B[39m\u001B[38;5;241m=\u001B[39m \u001B[38;5;241m0\u001B[39m:\n", + "File \u001B[0;32m~/PycharmProjects/combo-lightning/combo/config/from_parameters.py:128\u001B[0m, in \u001B[0;36mFromParameters.from_parameters\u001B[0;34m(cls, parameters, pass_to_subclasses, default_name)\u001B[0m\n\u001B[1;32m 126\u001B[0m \u001B[38;5;28;01melif\u001B[39;00m typing\u001B[38;5;241m.\u001B[39mget_origin(type_annotation) \u001B[38;5;129;01mis\u001B[39;00m \u001B[38;5;28;01mNone\u001B[39;00m \u001B[38;5;129;01mand\u001B[39;00m type_annotation\u001B[38;5;241m.\u001B[39m__base__ \u001B[38;5;129;01mis\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m \u001B[38;5;28mobject\u001B[39m:\n\u001B[1;32m 127\u001B[0m \u001B[38;5;28;01mtry\u001B[39;00m:\n\u001B[0;32m--> 128\u001B[0m parameters_to_pass[k] \u001B[38;5;241m=\u001B[39m \u001B[43m_resolve\u001B[49m\u001B[43m(\u001B[49m\u001B[43mtype_annotation\u001B[49m\u001B[43m,\u001B[49m\n\u001B[1;32m 129\u001B[0m \u001B[43m \u001B[49m\u001B[43mv\u001B[49m\u001B[43m,\u001B[49m\u001B[43m 
\u001B[49m\u001B[43mpass_to_subclasses\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mdefault_name\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 130\u001B[0m \u001B[38;5;28;01mexcept\u001B[39;00m ConfigurationError:\n\u001B[1;32m 131\u001B[0m warnings\u001B[38;5;241m.\u001B[39mwarn(\u001B[38;5;124mf\u001B[39m\u001B[38;5;124m'\u001B[39m\u001B[38;5;124mAn object of type \u001B[39m\u001B[38;5;132;01m{\u001B[39;00mtype_annotation\u001B[38;5;132;01m}\u001B[39;00m\u001B[38;5;124m is not in the registry!\u001B[39m\u001B[38;5;124m'\u001B[39m)\n", + "File \u001B[0;32m~/PycharmProjects/combo-lightning/combo/config/from_parameters.py:51\u001B[0m, in \u001B[0;36m_resolve\u001B[0;34m(type, values, pass_to_subclasses, default_name)\u001B[0m\n\u001B[1;32m 49\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m torch\u001B[38;5;241m.\u001B[39mnn\u001B[38;5;241m.\u001B[39mModule(\u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mvalues)\n\u001B[1;32m 50\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;28mtype\u001B[39m\u001B[38;5;241m.\u001B[39m__base__ \u001B[38;5;129;01mis\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m \u001B[38;5;28mobject\u001B[39m:\n\u001B[0;32m---> 51\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[43m_resolve\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;28;43mtype\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43m__base__\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mvalues\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mpass_to_subclasses\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mdefault_name\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 52\u001B[0m warnings\u001B[38;5;241m.\u001B[39mwarn(\u001B[38;5;124mf\u001B[39m\u001B[38;5;124m'\u001B[39m\u001B[38;5;124mNo classes of type \u001B[39m\u001B[38;5;132;01m{\u001B[39;00m\u001B[38;5;28mtype\u001B[39m\u001B[38;5;132;01m}\u001B[39;00m\u001B[38;5;124m (values: \u001B[39m\u001B[38;5;132;01m{\u001B[39;00mvalues\u001B[38;5;132;01m}\u001B[39;00m\u001B[38;5;124m) in Registry!\u001B[39m\u001B[38;5;124m'\u001B[39m)\n\u001B[1;32m 53\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m values\n", + "File \u001B[0;32m~/PycharmProjects/combo-lightning/combo/config/from_parameters.py:41\u001B[0m, in \u001B[0;36m_resolve\u001B[0;34m(type, values, pass_to_subclasses, default_name)\u001B[0m\n\u001B[1;32m 39\u001B[0m \u001B[38;5;28;01mtry\u001B[39;00m:\n\u001B[1;32m 40\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;124m'\u001B[39m\u001B[38;5;124mtype\u001B[39m\u001B[38;5;124m'\u001B[39m \u001B[38;5;129;01min\u001B[39;00m values\u001B[38;5;241m.\u001B[39mkeys():\n\u001B[0;32m---> 41\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[43mRegistry\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mresolve\u001B[49m\u001B[43m(\u001B[49m\n\u001B[1;32m 42\u001B[0m \u001B[43m \u001B[49m\u001B[38;5;28;43mtype\u001B[39;49m\u001B[43m,\u001B[49m\n\u001B[1;32m 43\u001B[0m \u001B[43m \u001B[49m\u001B[43mvalues\u001B[49m\u001B[43m[\u001B[49m\u001B[38;5;124;43m'\u001B[39;49m\u001B[38;5;124;43mtype\u001B[39;49m\u001B[38;5;124;43m'\u001B[39;49m\u001B[43m]\u001B[49m\n\u001B[1;32m 44\u001B[0m \u001B[43m \u001B[49m\u001B[43m)\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mfrom_parameters\u001B[49m\u001B[43m(\u001B[49m\u001B[43m{\u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mvalues\u001B[49m\u001B[43m,\u001B[49m\u001B[43m 
\u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mpass_to_subclasses\u001B[49m\u001B[43m}\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mpass_to_subclasses\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 45\u001B[0m \u001B[38;5;28;01melse\u001B[39;00m:\n\u001B[1;32m 46\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m Registry\u001B[38;5;241m.\u001B[39mget_default(\u001B[38;5;28mtype\u001B[39m)\u001B[38;5;241m.\u001B[39mfrom_parameters({\u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mvalues, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mpass_to_subclasses}, pass_to_subclasses)\n", + "File \u001B[0;32m~/PycharmProjects/combo-lightning/combo/config/from_parameters.py:128\u001B[0m, in \u001B[0;36mFromParameters.from_parameters\u001B[0;34m(cls, parameters, pass_to_subclasses, default_name)\u001B[0m\n\u001B[1;32m 126\u001B[0m \u001B[38;5;28;01melif\u001B[39;00m typing\u001B[38;5;241m.\u001B[39mget_origin(type_annotation) \u001B[38;5;129;01mis\u001B[39;00m \u001B[38;5;28;01mNone\u001B[39;00m \u001B[38;5;129;01mand\u001B[39;00m type_annotation\u001B[38;5;241m.\u001B[39m__base__ \u001B[38;5;129;01mis\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m \u001B[38;5;28mobject\u001B[39m:\n\u001B[1;32m 127\u001B[0m \u001B[38;5;28;01mtry\u001B[39;00m:\n\u001B[0;32m--> 128\u001B[0m parameters_to_pass[k] \u001B[38;5;241m=\u001B[39m \u001B[43m_resolve\u001B[49m\u001B[43m(\u001B[49m\u001B[43mtype_annotation\u001B[49m\u001B[43m,\u001B[49m\n\u001B[1;32m 129\u001B[0m \u001B[43m \u001B[49m\u001B[43mv\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mpass_to_subclasses\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mdefault_name\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 130\u001B[0m \u001B[38;5;28;01mexcept\u001B[39;00m ConfigurationError:\n\u001B[1;32m 131\u001B[0m warnings\u001B[38;5;241m.\u001B[39mwarn(\u001B[38;5;124mf\u001B[39m\u001B[38;5;124m'\u001B[39m\u001B[38;5;124mAn object of type \u001B[39m\u001B[38;5;132;01m{\u001B[39;00mtype_annotation\u001B[38;5;132;01m}\u001B[39;00m\u001B[38;5;124m is not in the registry!\u001B[39m\u001B[38;5;124m'\u001B[39m)\n", + "File \u001B[0;32m~/PycharmProjects/combo-lightning/combo/config/from_parameters.py:46\u001B[0m, in \u001B[0;36m_resolve\u001B[0;34m(type, values, pass_to_subclasses, default_name)\u001B[0m\n\u001B[1;32m 41\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m Registry\u001B[38;5;241m.\u001B[39mresolve(\n\u001B[1;32m 42\u001B[0m \u001B[38;5;28mtype\u001B[39m,\n\u001B[1;32m 43\u001B[0m values[\u001B[38;5;124m'\u001B[39m\u001B[38;5;124mtype\u001B[39m\u001B[38;5;124m'\u001B[39m]\n\u001B[1;32m 44\u001B[0m )\u001B[38;5;241m.\u001B[39mfrom_parameters({\u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mvalues, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mpass_to_subclasses}, pass_to_subclasses)\n\u001B[1;32m 45\u001B[0m \u001B[38;5;28;01melse\u001B[39;00m:\n\u001B[0;32m---> 46\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[43mRegistry\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mget_default\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;28;43mtype\u001B[39;49m\u001B[43m)\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mfrom_parameters\u001B[49m\u001B[43m(\u001B[49m\u001B[43m{\u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mvalues\u001B[49m\u001B[43m,\u001B[49m\u001B[43m 
\u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mpass_to_subclasses\u001B[49m\u001B[43m}\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mpass_to_subclasses\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 47\u001B[0m \u001B[38;5;28;01mexcept\u001B[39;00m RegistryException:\n\u001B[1;32m 48\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;28mtype\u001B[39m \u001B[38;5;129;01mis\u001B[39;00m torch\u001B[38;5;241m.\u001B[39mnn\u001B[38;5;241m.\u001B[39mModule:\n", + "File \u001B[0;32m~/PycharmProjects/combo-lightning/combo/config/from_parameters.py:136\u001B[0m, in \u001B[0;36mFromParameters.from_parameters\u001B[0;34m(cls, parameters, pass_to_subclasses, default_name)\u001B[0m\n\u001B[1;32m 133\u001B[0m \u001B[38;5;28;01melse\u001B[39;00m:\n\u001B[1;32m 134\u001B[0m parameters_to_pass[k] \u001B[38;5;241m=\u001B[39m v\n\u001B[0;32m--> 136\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28;43mcls\u001B[39;49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mparameters_to_pass\u001B[49m\u001B[43m)\u001B[49m\n", + "File \u001B[0;32m~/PycharmProjects/combo-lightning/combo/modules/parser.py:26\u001B[0m, in \u001B[0;36mHeadPredictionModel.__init__\u001B[0;34m(self, head_projection_layer, dependency_projection_layer, cycle_loss_n)\u001B[0m\n\u001B[1;32m 22\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21m__init__\u001B[39m(\u001B[38;5;28mself\u001B[39m,\n\u001B[1;32m 23\u001B[0m head_projection_layer: base\u001B[38;5;241m.\u001B[39mLinear,\n\u001B[1;32m 24\u001B[0m dependency_projection_layer: base\u001B[38;5;241m.\u001B[39mLinear,\n\u001B[1;32m 25\u001B[0m cycle_loss_n: \u001B[38;5;28mint\u001B[39m \u001B[38;5;241m=\u001B[39m \u001B[38;5;241m0\u001B[39m):\n\u001B[0;32m---> 26\u001B[0m \u001B[38;5;28;43msuper\u001B[39;49m\u001B[43m(\u001B[49m\u001B[43m)\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[38;5;21;43m__init__\u001B[39;49m\u001B[43m(\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 27\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mhead_projection_layer \u001B[38;5;241m=\u001B[39m head_projection_layer\n\u001B[1;32m 28\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mdependency_projection_layer \u001B[38;5;241m=\u001B[39m dependency_projection_layer\n", + "\u001B[0;31mTypeError\u001B[0m: __init__() missing 2 required positional arguments: 'model' and 'dataset_reader'" + ] + } + ], + "source": [ + "nlp = COMBO.from_pretrained(\"polish-pdb-ud29\")" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-09-22T07:28:45.877954Z", + "start_time": "2023-09-22T07:28:27.167654Z" + } + }, + "id": "f1d81f8aba12630b" + }, { "cell_type": "code", "execution_count": null, "outputs": [], + "source": [ + "nlp.predict('Cześć')" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "start_time": "2023-09-22T07:26:16.004782Z" + } + }, + "id": "7302b7d49ac2fc38" + }, + { + "cell_type": "markdown", "source": [], "metadata": { "collapsed": false }, - "id": "f1d81f8aba12630b" + "id": "fa5abf02a486db75" } ], "metadata": { diff --git a/combo/main.py b/combo/main.py index 10f26482eed722ab4603b2c2d4347184c228c2a0..207b9fd6d9ec0dd0c021b59f62273ad9b09fabb5 100644 --- a/combo/main.py +++ b/combo/main.py @@ -1,15 +1,11 @@ import logging -import os import pathlib -import tempfile from typing import Dict -import torch from absl import app from absl import flags -from combo import models -from combo.models.base import Predictor +from 
combo.nn.base import Predictor from combo.utils import checks logger = logging.getLogger(__name__) diff --git a/combo/models/__init__.py b/combo/models/__init__.py index 8cc3df5c3039bee577217dc8f78eb5c9f21bcfe7..83092c413bf5e7fa815a93707403985e2cc8a027 100644 --- a/combo/models/__init__.py +++ b/combo/models/__init__.py @@ -1,11 +1,3 @@ -from .base import FeedForwardPredictor -from .graph_parser import GraphDependencyRelationModel -from .parser import DependencyRelationModel -from .embeddings import (GloVe6BEmbedder, GloVe840BEmbedder, GloVeTwitter27BEmbedder, - GloVe42BEmbedder, FastTextEmbedder, CharNGramEmbedder) -from .encoder import ComboEncoder -from .lemma import LemmatizerModel +from .encoder import ComboStackedBidirectionalLSTM, ComboEncoder from .combo_model import ComboModel -from .morpho import MorphologicalFeatures -from .model import Model -from .archival import * \ No newline at end of file +from .archival import * diff --git a/combo/models/archival.py b/combo/models/archival.py index 11e73146c9c7427b5dc6ad160f0b6bd59e1cffde..ee43a3a5e184b057e4aef97ab4fc20ba413204ee 100644 --- a/combo/models/archival.py +++ b/combo/models/archival.py @@ -15,8 +15,9 @@ import glob from torch.nn import Module from combo.common.params import Params +from combo.config import Registry from combo.data.dataset_readers import DatasetReader -from combo.models.model import Model +from combo.modules.model import Model from combo.utils import ConfigurationError from combo.utils.file_utils import cached_path @@ -219,11 +220,11 @@ def _load_dataset_readers(config, serialization_dir): "validation_dataset_reader", dataset_reader_params.duplicate() ) - dataset_reader = DatasetReader.from_params( - dataset_reader_params, serialization_dir=serialization_dir + dataset_reader = Registry.resolve(DatasetReader, dataset_reader_params.get('type', 'base')).from_parameters( + dict(**dataset_reader_params, serialization_dir=serialization_dir) ) - validation_dataset_reader = DatasetReader.from_params( - validation_dataset_reader_params, serialization_dir=serialization_dir + validation_dataset_reader = Registry.resolve(DatasetReader, validation_dataset_reader_params.get('type', 'base')).from_parameters( + dict(**validation_dataset_reader_params, serialization_dir=serialization_dir) ) return dataset_reader, validation_dataset_reader diff --git a/combo/models/combo_model.py b/combo/models/combo_model.py index dc94077a12e2e4c26e3fcb636d43fed61a3ac4f8..cdda10c237310cbd3ada1d5cc1fa0c36b2127e32 100644 --- a/combo/models/combo_model.py +++ b/combo/models/combo_model.py @@ -5,33 +5,44 @@ import torch from overrides import overrides from combo import data -from combo.models import base -from combo.models.embeddings import TokenEmbedder -from combo.models.model import Model -from combo.modules.seq2seq_encoder import Seq2SeqEncoder -from combo.nn import RegularizerApplicator +from combo.config import FromParameters, Registry +from combo.modules.morpho import MorphologicalFeatures +from combo.modules.lemma import LemmatizerModel +from combo.modules.parser import DependencyRelationModel +from combo.modules import TextFieldEmbedder +from combo.modules.model import Model +from combo.modules.seq2seq_encoders.seq2seq_encoder import Seq2SeqEncoder +from combo.nn import RegularizerApplicator, base from combo.nn.util import get_text_field_mask from combo.utils import metrics - - -class ComboModel(Model): +# +# lemmatizer: Optional[LemmatizerModel] = None, +# upos_tagger: Optional[MorphologicalFeatures] = None, +# xpos_tagger: 
Optional[MorphologicalFeatures] = None, +# semantic_relation: Optional[base.Predictor] = None, +# morphological_feat: Optional[MorphologicalFeatures] = None, +# dependency_relation: Optional[DependencyRelationModel] = None, +# enhanced_dependency_relation: Optional[DependencyRelationModel] = None, + +@Registry.register(Model, "semantic_multitask") +class ComboModel(Model, FromParameters): """Main COMBO model.""" def __init__(self, - vocab: data.Vocabulary, + vocabulary: data.Vocabulary, loss_weights: Dict[str, float], - text_field_embedder: TokenEmbedder, + text_field_embedder: TextFieldEmbedder, seq_encoder: Seq2SeqEncoder, use_sample_weight: bool = True, - lemmatizer: Optional[base.Predictor] = None, - upos_tagger: Optional[base.Predictor] = None, - xpos_tagger: Optional[base.Predictor] = None, - semantic_relation: Optional[base.Predictor] = None, - morphological_feat: Optional[base.Predictor] = None, - dependency_relation: Optional[base.Predictor] = None, - enhanced_dependency_relation: Optional[base.Predictor] = None, + lemmatizer: LemmatizerModel = None, + upos_tagger: MorphologicalFeatures = None, + xpos_tagger: MorphologicalFeatures = None, + semantic_relation: base.Predictor = None, + morphological_feat: MorphologicalFeatures = None, + dependency_relation: DependencyRelationModel = None, + enhanced_dependency_relation: DependencyRelationModel = None, regularizer: RegularizerApplicator = None) -> None: - super().__init__(vocab, regularizer) + super().__init__(vocabulary, regularizer) self.text_field_embedder = text_field_embedder self.loss_weights = loss_weights self.use_sample_weight = use_sample_weight diff --git a/combo/models/combo_nn.py b/combo/models/combo_nn.py deleted file mode 100644 index 822c1cd665e7aba7b8b94ce05ead907412ced553..0000000000000000000000000000000000000000 --- a/combo/models/combo_nn.py +++ /dev/null @@ -1,14 +0,0 @@ -import torch -import torch.nn as nn -from overrides import overrides - - -class Activation(nn.Module): - def forward(self, x: torch.Tensor) -> torch.Tensor: - raise NotImplementedError - - -class LinearActivation(Activation): - @overrides - def forward(self, x: torch.Tensor) -> torch.Tensor: - return x diff --git a/combo/models/encoder.py b/combo/models/encoder.py index 1904cb82a02850b95764894e1bba25066af18322..12932adf579f2392b7c2ff09b2ae7482255e7ba1 100644 --- a/combo/models/encoder.py +++ b/combo/models/encoder.py @@ -9,15 +9,19 @@ import torch.nn.utils.rnn as rnn from overrides import overrides from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence +from combo.config import FromParameters, Registry from combo.modules import input_variational_dropout from combo.modules.augmented_lstm import AugmentedLstm from combo.modules.input_variational_dropout import InputVariationalDropout +from combo.modules.module import Module +from combo.modules.seq2seq_encoders.seq2seq_encoder import Seq2SeqEncoder from combo.utils import ConfigurationError TensorPair = Tuple[torch.Tensor, torch.Tensor] -class StackedBidirectionalLstm(torch.nn.Module): +@Registry.register_base_class('stacked_bilstm', default=True) +class StackedBidirectionalLstm(torch.nn.Module, FromParameters): """ A standard stacked Bidirectional LSTM where the LSTM layers are concatenated between each layer. 
The only difference between @@ -53,26 +57,25 @@ class StackedBidirectionalLstm(torch.nn.Module): """ def __init__( - self, - input_size: int, - hidden_size: int, - num_layers: int, - recurrent_dropout_probability: float = 0.0, - layer_dropout_probability: float = 0.0, - use_highway: bool = True, + self, + input_size: int, + hidden_size: int, + num_layers: int, + recurrent_dropout_probability: float = 0.0, + layer_dropout_probability: float = 0.0, + use_highway: bool = True, ) -> None: super().__init__() # Required to be wrapped with a `PytorchSeq2SeqWrapper`. - self.input_size = input_size - self.hidden_size = hidden_size - self.num_layers = num_layers - self.bidirectional = True + self.__input_size = input_size + self.__hidden_size = hidden_size + self.__num_layers = num_layers + self.__bidirectional = True layers = [] lstm_input_size = input_size for layer_index in range(num_layers): - forward_layer = AugmentedLstm( lstm_input_size, hidden_size, @@ -97,8 +100,24 @@ class StackedBidirectionalLstm(torch.nn.Module): self.lstm_layers = layers self.layer_dropout = InputVariationalDropout(layer_dropout_probability) + @property + def input_size(self) -> int: + return self.__input_size + + @property + def hidden_size(self) -> int: + return self.__hidden_size + + @property + def num_layers(self) -> int: + return self.__num_layers + + @property + def bidirectional(self) -> int: + return self.__bidirectional + def forward( - self, inputs: PackedSequence, initial_state: Optional[TensorPair] = None + self, inputs: PackedSequence, initial_state: Optional[TensorPair] = None ) -> Tuple[PackedSequence, TensorPair]: """ # Parameters @@ -156,8 +175,11 @@ class StackedBidirectionalLstm(torch.nn.Module): return output_sequence, final_state_tuple + + # TODO: merge into one -class ComboStackedBidirectionalLSTM(StackedBidirectionalLstm): +@Registry.register_base_class('stacked_bilstm', default=True) +class ComboStackedBidirectionalLSTM(StackedBidirectionalLstm, FromParameters): def __init__(self, input_size: int, hidden_size: int, num_layers: int, recurrent_dropout_probability: float, layer_dropout_probability: float, use_highway: bool = False): @@ -168,7 +190,7 @@ class ComboStackedBidirectionalLSTM(StackedBidirectionalLstm): layer_dropout_probability=layer_dropout_probability, use_highway=use_highway) - @overrides + # @overrides def forward(self, inputs: rnn.PackedSequence, initial_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None @@ -203,7 +225,8 @@ class ComboStackedBidirectionalLSTM(StackedBidirectionalLstm): return output_sequence, (state_fwd, state_bwd) -class ComboEncoder: +@Registry.register(Seq2SeqEncoder, 'combo_encoder', default=True) +class ComboEncoder(Seq2SeqEncoder, FromParameters): """COMBO encoder (https://www.aclweb.org/anthology/K18-2004.pdf). 
This implementation uses Variational Dropout on the input and then outputs of each BiLSTM layer diff --git a/combo/modules/__init__.py b/combo/modules/__init__.py index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..77b73d53dcff8839504482201348385343e20c2c 100644 --- a/combo/modules/__init__.py +++ b/combo/modules/__init__.py @@ -0,0 +1,2 @@ +from .text_field_embedders import TextFieldEmbedder +from .token_embedders import * diff --git a/combo/modules/augmented_lstm.py b/combo/modules/augmented_lstm.py index 9e33d869d5713f63209f8a14ca1f05fddaed18d2..2131c95fbd842cf1ffed382a3f02ba4224bab8f2 100644 --- a/combo/modules/augmented_lstm.py +++ b/combo/modules/augmented_lstm.py @@ -7,13 +7,16 @@ from typing import Optional, Tuple import torch from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence +from combo.config import FromParameters, Registry +from combo.modules.module import Module from combo.nn.util import get_dropout_mask from combo.nn.initializers import block_orthogonal from combo.utils import ConfigurationError -class AugmentedLSTMCell(torch.nn.Module): +@Registry.register(torch.nn.Module, 'augumented_lstm_cell') +class AugmentedLSTMCell(Module, FromParameters): """ `AugmentedLSTMCell` implements a AugmentedLSTM cell. @@ -144,7 +147,8 @@ class AugmentedLSTMCell(torch.nn.Module): return timestep_output, memory -class AugmentedLstm(torch.nn.Module): +@Registry.register(torch.nn.Module, 'augumented_lstm') +class AugmentedLstm(Module, FromParameters): """ `AugmentedLstm` implements a one-layer single directional AugmentedLSTM layer. AugmentedLSTM is an LSTM which optionally diff --git a/combo/models/dilated_cnn.py b/combo/modules/dilated_cnn.py similarity index 77% rename from combo/models/dilated_cnn.py rename to combo/modules/dilated_cnn.py index 79ca6d9a952e1da104150b896696eed75aec4901..2e89942efd24a8ae8b06d09c58f3731efd6531d0 100644 --- a/combo/models/dilated_cnn.py +++ b/combo/modules/dilated_cnn.py @@ -6,12 +6,13 @@ Author: Mateusz Klimaszewski from typing import List import torch -import torch.nn as nn -from combo.models.combo_nn import Activation +from combo.config import FromParameters, Registry +from combo.nn.activations import Activation -class DilatedCnnEncoder(nn.Module): +@Registry.register_base_class('dilated_cnn', default=True) +class DilatedCnnEncoder(torch.nn.Module, FromParameters): def __init__(self, input_dim: int, @@ -26,14 +27,14 @@ class DilatedCnnEncoder(nn.Module): input_dims = [input_dim] + filters[:-1] output_dims = filters for idx in range(len(activations)): - conv1d_layers.append(nn.Conv1d( + conv1d_layers.append(torch.nn.Conv1d( in_channels=input_dims[idx], out_channels=output_dims[idx], kernel_size=(kernel_size[idx],), stride=(stride[idx],), padding=padding[idx], dilation=(dilation[idx],))) - self.conv1d_layers = nn.ModuleList(conv1d_layers) + self.conv1d_layers = torch.nn.ModuleList(conv1d_layers) self.activations = activations assert len(self.activations) == len(self.conv1d_layers) diff --git a/combo/modules/encoder.py b/combo/modules/encoder.py index bf258565bad0d442a1775a17706fb329caa27920..4d053a8e95bd7f8a7898183a9ec3b214f2d2fb07 100644 --- a/combo/modules/encoder.py +++ b/combo/modules/encoder.py @@ -7,6 +7,7 @@ from typing import Tuple, Union, Optional, Callable, Any import torch from torch.nn.utils.rnn import pack_padded_sequence, PackedSequence +from combo.modules.module import Module from combo.nn.util import get_lengths_from_binary_sequence_mask, sort_batch_by_length # We have two types here for the state, because 
storing the state in something diff --git a/combo/models/graph_parser.py b/combo/modules/graph_parser.py similarity index 89% rename from combo/models/graph_parser.py rename to combo/modules/graph_parser.py index 6dffef52a943e4671aea7ac911b5c25981c3051f..49aefeef9f0ebc7fa714f1507f891fe1e6e91415 100644 --- a/combo/models/graph_parser.py +++ b/combo/modules/graph_parser.py @@ -6,8 +6,9 @@ Author: Mateusz Klimaszewski from typing import List, Optional, Union, Tuple, Dict from combo import data -from combo.models import base -from combo.models.base import Predictor +from combo.config import Registry +from combo.nn import base +from combo.nn.base import Predictor import torch import torch.nn.functional as F @@ -102,15 +103,24 @@ class GraphHeadPredictionModel(Predictor): return loss.sum() / valid_positions + cycle_loss.mean(), cycle_loss.mean() +@Registry.register(Predictor, 'combo_graph_dependency_parsing_from_vocab') class GraphDependencyRelationModel(Predictor): """Dependency relation parsing model.""" def __init__(self, + vocab: data.Vocabulary, + vocab_namespace: str, head_predictor: GraphHeadPredictionModel, head_projection_layer: base.Linear, - dependency_projection_layer: base.Linear, - relation_prediction_layer: base.Linear): + dependency_projection_layer: base.Linear): + """Creates parser combining model configuration and vocabulary data.""" super().__init__() + assert vocab_namespace in vocab.get_namespaces() + relation_prediction_layer = base.Linear( + in_features=head_projection_layer.get_output_dim() + dependency_projection_layer.get_output_dim(), + out_features=vocab.get_vocab_size(vocab_namespace) + ) + self.head_predictor = head_predictor self.head_projection_layer = head_projection_layer self.dependency_projection_layer = dependency_projection_layer @@ -167,24 +177,3 @@ class GraphDependencyRelationModel(Predictor): pred = pred[correct_heads_mask] loss = F.cross_entropy(pred, true.long()) return loss.sum() / pred.size(0) - - @classmethod - def from_vocab(cls, - vocab: data.Vocabulary, - vocab_namespace: str, - head_predictor: GraphHeadPredictionModel, - head_projection_layer: base.Linear, - dependency_projection_layer: base.Linear - ): - """Creates parser combining model configuration and vocabulary data.""" - assert vocab_namespace in vocab.get_namespaces() - relation_prediction_layer = base.Linear( - in_features=head_projection_layer.get_output_dim() + dependency_projection_layer.get_output_dim(), - out_features=vocab.get_vocab_size(vocab_namespace) - ) - return cls( - head_predictor=head_predictor, - head_projection_layer=head_projection_layer, - dependency_projection_layer=dependency_projection_layer, - relation_prediction_layer=relation_prediction_layer - ) diff --git a/combo/modules/input_variational_dropout.py b/combo/modules/input_variational_dropout.py index 5744642be710cc136cf82eaff443b96cb4e66352..4433c1e37f0db07e5b738d5921c3c1a1100a60a9 100644 --- a/combo/modules/input_variational_dropout.py +++ b/combo/modules/input_variational_dropout.py @@ -5,8 +5,11 @@ https://github.com/allenai/allennlp/blob/main/allennlp/modules/input_variational import torch +from combo.config import FromParameters, Registry -class InputVariationalDropout(torch.nn.Dropout): + +@Registry.register(torch.nn.Module, 'input_variational_dropout') +class InputVariationalDropout(torch.nn.Dropout, FromParameters): """ Apply the dropout technique in Gal and Ghahramani, [Dropout as a Bayesian Approximation: Representing Model Uncertainty in Deep Learning](https://arxiv.org/abs/1506.02142) to a diff 
--git a/combo/models/lemma.py b/combo/modules/lemma.py similarity index 72% rename from combo/models/lemma.py rename to combo/modules/lemma.py index d724a1ecb9c22610fc6ac56493929178d7a6cd5a..05e0cbf090c686a885df42796b6898e4740251f3 100644 --- a/combo/models/lemma.py +++ b/combo/modules/lemma.py @@ -4,23 +4,54 @@ import torch import torch.nn as nn from combo import data -from combo.models import dilated_cnn, base, utils -from combo.models.base import Predictor, TimeDistributed -from combo.models.combo_nn import Activation +from combo.config import Registry +from combo.models import utils +from combo.modules import dilated_cnn +from combo.nn import base +from combo.nn.base import Predictor +from combo.modules.time_distributed import TimeDistributed +from combo.nn.activations import Activation from combo.utils import ConfigurationError +@Registry.register(Predictor, 'combo_lemma_predictor_from_vocab') class LemmatizerModel(Predictor): """Lemmatizer model.""" def __init__(self, - num_embeddings: int, + vocab: data.Vocabulary, + char_vocab_namespace: str, + lemma_vocab_namespace: str, embedding_dim: int, - dilated_cnn_encoder: dilated_cnn.DilatedCnnEncoder, - input_projection_layer: base.Linear): + input_projection_layer: base.Linear, + filters: List[int], + kernel_size: List[int], + stride: List[int], + padding: List[int], + dilation: List[int], + activations: List[Activation], + ): + assert char_vocab_namespace in vocab.get_namespaces() + assert lemma_vocab_namespace in vocab.get_namespaces() super().__init__() + + if len(filters) + 1 != len(kernel_size): + raise ConfigurationError( + f"len(filters) ({len(filters):d}) + 1 != kernel_size ({len(kernel_size):d})" + ) + filters = filters + [vocab.get_vocab_size(lemma_vocab_namespace)] + + dilated_cnn_encoder = dilated_cnn.DilatedCnnEncoder( + input_dim=embedding_dim + input_projection_layer.get_output_dim(), + filters=filters, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + activations=activations, + ) self.char_embed = nn.Embedding( - num_embeddings=num_embeddings, + num_embeddings=vocab.get_vocab_size(char_vocab_namespace), embedding_dim=embedding_dim, ) self.dilated_cnn_encoder = TimeDistributed(dilated_cnn_encoder) @@ -68,40 +99,3 @@ class LemmatizerModel(Predictor): loss = loss.reshape(BATCH_SIZE, -1) * sample_weights.unsqueeze(-1) valid_positions = mask.sum() return loss.sum() / valid_positions - - @classmethod - def from_vocab(cls, - vocab: data.Vocabulary, - char_vocab_namespace: str, - lemma_vocab_namespace: str, - embedding_dim: int, - input_projection_layer: base.Linear, - filters: List[int], - kernel_size: List[int], - stride: List[int], - padding: List[int], - dilation: List[int], - activations: List[Activation], - ): - assert char_vocab_namespace in vocab.get_namespaces() - assert lemma_vocab_namespace in vocab.get_namespaces() - - if len(filters) + 1 != len(kernel_size): - raise ConfigurationError( - f"len(filters) ({len(filters):d}) + 1 != kernel_size ({len(kernel_size):d})" - ) - filters = filters + [vocab.get_vocab_size(lemma_vocab_namespace)] - - dilated_cnn_encoder = dilated_cnn.DilatedCnnEncoder( - input_dim=embedding_dim + input_projection_layer.get_output_dim(), - filters=filters, - kernel_size=kernel_size, - stride=stride, - padding=padding, - dilation=dilation, - activations=activations, - ) - return cls(num_embeddings=vocab.get_vocab_size(char_vocab_namespace), - embedding_dim=embedding_dim, - dilated_cnn_encoder=dilated_cnn_encoder, - input_projection_layer=input_projection_layer) 
diff --git a/combo/models/model.py b/combo/modules/model.py similarity index 93% rename from combo/models/model.py rename to combo/modules/model.py index 2dbc00a4cf80f8f4e5d9109646b9f11c8f3a0a7d..1bdf52cf0aad67cd85a00b726679a6fa75fa993e 100644 --- a/combo/models/model.py +++ b/combo/modules/model.py @@ -13,8 +13,10 @@ import numpy import torch from combo.common.params import remove_keys_from_params, Params +from combo.config import FromParameters, Registry from combo.data import Vocabulary, Instance from combo.data.batch import Batch +from combo.modules.module import Module from combo.nn import util, RegularizerApplicator from combo.utils import ConfigurationError @@ -25,7 +27,7 @@ logger = logging.getLogger(__name__) _DEFAULT_WEIGHTS = "best.th" -class Model(torch.nn.Module): +class Model(Module, FromParameters): """ This abstract class represents a model to be trained. Rather than relying completely on the Pytorch Module, we modify the output spec of `forward` to be a dictionary. @@ -53,7 +55,7 @@ class Model(torch.nn.Module): # Parameters - vocab: `Vocabulary` + vocabulary: `Vocabulary` There are two typical use-cases for the `Vocabulary` in a `Model`: getting vocabulary sizes when constructing embedding matrices or output classifiers (as the vocabulary holds the number of classes in your output, also), and translating model output into human-readable @@ -73,12 +75,12 @@ class Model(torch.nn.Module): def __init__( self, - vocab: Vocabulary, + vocabulary: Vocabulary, regularizer: RegularizerApplicator = None, serialization_dir: Optional[str] = None, ) -> None: - super().__init__() - self.vocab = vocab + super(Model, self).__init__() + self.vocab = vocabulary self._regularizer = regularizer self.serialization_dir = serialization_dir @@ -286,8 +288,8 @@ class Model(torch.nn.Module): vocab_dir = os.path.join(serialization_dir, "vocabulary") # If the config specifies a vocabulary subclass, we need to use it. vocab_params = config.get("vocabulary", Params({})) - vocab_choice = vocab_params.pop_choice("type", Vocabulary.list_available(), True) - vocab_class, _ = Vocabulary.resolve_class_name(vocab_choice) + vocab_choice = vocab_params.pop_choice("type", list(Registry.classes()[Vocabulary].keys()), True) + vocab_class = Registry.resolve(Vocabulary, vocab_choice) vocab = vocab_class.from_files( vocab_dir, vocab_params.get("padding_token"), vocab_params.get("oov_token") ) @@ -299,8 +301,9 @@ class Model(torch.nn.Module): # stored in our model. We don't need any pretrained weight file or initializers anymore, # and we don't want the code to look for it, so we remove it from the parameters here. remove_keys_from_params(model_params) - model = Model.from_params( - vocab=vocab, params=model_params, serialization_dir=serialization_dir + model_type = Registry.resolve(Model, model_params.get('type', 'semantic_multitask')) + model = model_type.from_parameters( + dict(**dict(model_params), vocabulary=vocab, serialization_dir=serialization_dir), {'vocabulary': vocab} ) # Force model to cpu or gpu, as appropriate, to make sure that the embeddings are @@ -310,12 +313,12 @@ class Model(torch.nn.Module): else: model.cpu() - # If vocab+embedding extension was done, the model initialized from from_params + # If vocabulary+embedding extension was done, the model initialized from from_params # and one defined by state dict in weights_file might not have same embedding shapes. - # Eg. when model embedder module was transferred along with vocab extension, the + # Eg. 
when model embedder module was transferred along with vocabulary extension, the # initialized embedding weight shape would be smaller than one in the state_dict. # So calling model embedding extension is required before load_state_dict. - # If vocab and model embeddings are in sync, following would be just a no-op. + # If vocabulary and model embeddings are in sync, following would be just a no-op. model.extend_embedder_vocab() # Load state dict. We pass `strict=False` so PyTorch doesn't raise a RuntimeError @@ -342,12 +345,12 @@ class Model(torch.nn.Module): filter_out_authorized_missing_keys(model) - if unexpected_keys or missing_keys: - raise RuntimeError( - f"Error loading state dict for {model.__class__.__name__}\n\t" - f"Missing keys: {missing_keys}\n\t" - f"Unexpected keys: {unexpected_keys}" - ) + # if unexpected_keys or missing_keys: + # raise RuntimeError( + # f"Error loading state dict for {model.__class__.__name__}\n\t" + # f"Missing keys: {missing_keys}\n\t" + # f"Unexpected keys: {unexpected_keys}" + # ) return model @@ -394,7 +397,7 @@ class Model(torch.nn.Module): # Load using an overridable _load method. # This allows subclasses of Model to override _load. - model_class: Type[Model] = cls.by_name(model_type) # type: ignore + model_class: Type[Model] = Registry.resolve(Model, model_type) # type: ignore if not isinstance(model_class, type): # If you're using from_archive to specify your model (e.g., for fine tuning), then you # can't currently override the behavior of _load; we just use the default Model._load. @@ -406,7 +409,7 @@ class Model(torch.nn.Module): def extend_embedder_vocab(self, embedding_sources_mapping: Dict[str, str] = None) -> None: """ Iterates through all embedding modules in the model and assures it can embed - with the extended vocab. This is required in fine-tuning or transfer learning + with the extended vocabulary. This is required in fine-tuning or transfer learning scenarios where model was trained with original vocabulary but during fine-tuning/transfer-learning, it will have it work with extended vocabulary (original + new-data vocabulary). @@ -440,7 +443,7 @@ class Model(torch.nn.Module): convenience, and so that we can register it for easy use for fine tuning an existing model from a config file. - If `vocab` is given, we will extend the loaded model's vocabulary using the passed vocab + If `vocabulary` is given, we will extend the loaded model's vocabulary using the passed vocabulary object (including calling `extend_embedder_vocab`, which extends embedding layers). """ from combo.models.archival import load_archive # here to avoid circular imports diff --git a/combo/modules/module.py b/combo/modules/module.py new file mode 100644 index 0000000000000000000000000000000000000000..6681257908e9606ce33a28ac87fc2bcdc90bf2c4 --- /dev/null +++ b/combo/modules/module.py @@ -0,0 +1,49 @@ +""" +Adapted from AllenNLP +https://github.com/allenai/allennlp/blob/main/allennlp/nn/module.py#L14 +""" + +from typing import List, Optional, Tuple + + +import torch + +from combo.nn.util import ( + _check_incompatible_keys, + _IncompatibleKeys, + StateDictType, +) + + +class Module(torch.nn.Module): + """ + This is just `torch.nn.Module` with some extra functionality. + """ + + def _post_load_state_dict( + self, missing_keys: List[str], unexpected_keys: List[str] + ) -> Tuple[List[str], List[str]]: + """ + Subclasses can override this and potentially modify `missing_keys` or `unexpected_keys`. 
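# A sketch of the `_post_load_state_dict` hook described above: a subclass can
# filter out keys it knows are safe to ignore before the strictness check runs.
# `TolerantModule` and the "position_ids" suffix are illustrative, not part of
# this patch.
from typing import List, Tuple

from combo.modules.module import Module


class TolerantModule(Module):
    def _post_load_state_dict(
        self, missing_keys: List[str], unexpected_keys: List[str]
    ) -> Tuple[List[str], List[str]]:
        # Buffers we deliberately do not persist should not fail `strict=True`.
        missing_keys = [k for k in missing_keys if not k.endswith("position_ids")]
        return missing_keys, unexpected_keys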
+ """ + return missing_keys, unexpected_keys + + def load_state_dict(self, state_dict: StateDictType, strict: bool = True) -> _IncompatibleKeys: + """ + Same as [`torch.nn.Module.load_state_dict()`] + (https://pytorch.org/docs/master/generated/torch.nn.Module.html#torch.nn.Module.load_state_dict) + except we also run the [`_post_load_state_dict`](#_post_load_state_dict) method before returning, + which can be implemented by subclasses to customize the behavior. + """ + missing_keys, unexpected_keys = super().load_state_dict(state_dict, strict=False) # type: ignore[arg-type] + missing_keys, unexpected_keys = self._post_load_state_dict(missing_keys, unexpected_keys) + _check_incompatible_keys(self, missing_keys, unexpected_keys, strict) + return _IncompatibleKeys(missing_keys, unexpected_keys) + + # def load_state_dict_distributed( + # self, state_dict: Optional[StateDictType], strict: bool = True + # ) -> _IncompatibleKeys: + # missing_keys, unexpected_keys = load_state_dict_distributed(self, state_dict, strict=strict) + # missing_keys, unexpected_keys = self._post_load_state_dict(missing_keys, unexpected_keys) + # _check_incompatible_keys(self, missing_keys, unexpected_keys, strict) + # return _IncompatibleKeys(missing_keys, unexpected_keys) \ No newline at end of file diff --git a/combo/models/morpho.py b/combo/modules/morpho.py similarity index 74% rename from combo/models/morpho.py rename to combo/modules/morpho.py index 5fb9545eeec5bc0049a51abb69ba479817410484..0ab78974c146cb3774968810a6e102617c3cf2f2 100644 --- a/combo/models/morpho.py +++ b/combo/modules/morpho.py @@ -6,18 +6,44 @@ from typing import Dict, List, Optional, Union import torch from combo import data +from combo.config import Registry from combo.data import dataset -from combo.models import base, utils -from combo.models.combo_nn import Activation +from combo.models import utils +from combo.nn import base +from combo.nn.activations import Activation +from combo.predictors.predictor import Predictor from combo.utils import ConfigurationError -class MorphologicalFeatures(base.Predictor): +@Registry.register(Predictor, 'combo_morpho_from_vocab') +class MorphologicalFeatures(Predictor): """Morphological features predicting model.""" - def __init__(self, feedforward_network: base.FeedForward, slices: Dict[str, List[int]]): + def __init__(self, + vocab: data.Vocabulary, + vocab_namespace: str, + input_dim: int, + num_layers: int, + hidden_dims: List[int], + activations: List[Activation], # change to Union[Activation, List[Activation]] + dropout: Union[float, List[float]] = 0.0, + ): super().__init__() - self.feedforward_network = feedforward_network + if len(hidden_dims) + 1 != num_layers: + raise ConfigurationError( + f"len(hidden_dims) ({len(hidden_dims):d}) + 1 != num_layers ({num_layers:d})" + ) + + assert vocab_namespace in vocab.get_namespaces() + hidden_dims = hidden_dims + [vocab.get_vocab_size(vocab_namespace)] + + slices = dataset.get_slices_if_not_provided(vocab) + self.feedforward_network = base.FeedForward( + input_dim=input_dim, + num_layers=num_layers, + hidden_dims=hidden_dims, + activations=activations, + dropout=dropout) self.slices = slices def forward(self, @@ -71,33 +97,3 @@ class MorphologicalFeatures(base.Predictor): mask) loss = loss.reshape(BATCH_SIZE, -1) * sample_weights.unsqueeze(-1) return loss.sum() / valid_positions - - @classmethod - def from_vocab(cls, - vocab: data.Vocabulary, - vocab_namespace: str, - input_dim: int, - num_layers: int, - hidden_dims: List[int], - activations: 
Union[Activation, List[Activation]], - dropout: Union[float, List[float]] = 0.0, - ): - if len(hidden_dims) + 1 != num_layers: - raise ConfigurationError( - f"len(hidden_dims) ({len(hidden_dims):d}) + 1 != num_layers ({num_layers:d})" - ) - - assert vocab_namespace in vocab.get_namespaces() - hidden_dims = hidden_dims + [vocab.get_vocab_size(vocab_namespace)] - - slices = dataset.get_slices_if_not_provided(vocab) - - return cls( - feedforward_network=base.FeedForward( - input_dim=input_dim, - num_layers=num_layers, - hidden_dims=hidden_dims, - activations=activations, - dropout=dropout), - slices=slices - ) diff --git a/combo/models/parser.py b/combo/modules/parser.py similarity index 89% rename from combo/models/parser.py rename to combo/modules/parser.py index 42d2efb3944a3e81188d6cf8f673dcb2cfb75002..2d8b470663f94b3a58b33dc6445273e7e6df0600 100644 --- a/combo/models/parser.py +++ b/combo/modules/parser.py @@ -9,11 +9,14 @@ import torch import torch.nn.functional as F from combo import data -from combo.models import base, utils -from combo.nn import chu_liu_edmonds +from combo.config import Registry +from combo.models import utils +from combo.nn import chu_liu_edmonds, base +from combo.predictors.predictor import Predictor -class HeadPredictionModel(base.Predictor): +@Registry.register_base_class( 'head_prediction', default=True) +class HeadPredictionModel(Predictor): """Head prediction model.""" def __init__(self, @@ -113,18 +116,26 @@ class HeadPredictionModel(base.Predictor): return loss.sum() / valid_positions + cycle_loss.mean(), cycle_loss.mean() - -class DependencyRelationModel(base.Predictor): +@Registry.register(Predictor, 'combo_dependency_parsing_from_vocab') +class DependencyRelationModel(Predictor): """Dependency relation parsing model.""" def __init__(self, - root_idx: int, + vocab: data.Vocabulary, + vocab_namespace: str, head_predictor: HeadPredictionModel, head_projection_layer: base.Linear, - dependency_projection_layer: base.Linear, - relation_prediction_layer: base.Linear): + dependency_projection_layer: base.Linear + ): + """Creates parser combining model configuration and vocabulary data.""" super().__init__() - self.root_idx = root_idx + assert vocab_namespace in vocab.get_namespaces() + relation_prediction_layer = base.Linear( + in_features=head_projection_layer.get_output_dim() + dependency_projection_layer.get_output_dim(), + out_features=vocab.get_vocab_size(vocab_namespace) + ) + + self.root_idx = vocab.get_token_index("root", vocab_namespace) self.head_predictor = head_predictor self.head_projection_layer = head_projection_layer self.dependency_projection_layer = dependency_projection_layer @@ -200,24 +211,6 @@ class DependencyRelationModel(base.Predictor): loss = loss.reshape(BATCH_SIZE, -1) * sample_weights.unsqueeze(-1) return loss.sum() / valid_positions - @classmethod - def from_vocab(cls, - vocab: data.Vocabulary, - vocab_namespace: str, - head_predictor: HeadPredictionModel, - head_projection_layer: base.Linear, - dependency_projection_layer: base.Linear - ): - """Creates parser combining model configuration and vocabulary data.""" - assert vocab_namespace in vocab.get_namespaces() - relation_prediction_layer = base.Linear( - in_features=head_projection_layer.get_output_dim() + dependency_projection_layer.get_output_dim(), - out_features=vocab.get_vocab_size(vocab_namespace) - ) - return cls( - head_predictor=head_predictor, - head_projection_layer=head_projection_layer, - dependency_projection_layer=dependency_projection_layer, - 
relation_prediction_layer=relation_prediction_layer, - root_idx=vocab.get_token_index("root", vocab_namespace) - ) + + +# combo_lemma_predictor_from_vocab \ No newline at end of file diff --git a/combo/modules/scalar_mix.py b/combo/modules/scalar_mix.py new file mode 100644 index 0000000000000000000000000000000000000000..b4e894e2b688deed4b68c47461f1dd9237913c9a --- /dev/null +++ b/combo/modules/scalar_mix.py @@ -0,0 +1,101 @@ +""" +Adapted from AllenNLP +https://github.com/allenai/allennlp/blob/main/allennlp/modules/scalar_mix.py +""" + +from typing import List + +import torch +from torch.nn import ParameterList, Parameter + +from combo.nn import util +from combo.utils import ConfigurationError + + +class ScalarMix(torch.nn.Module): + """ + Computes a parameterised scalar mixture of N tensors, `mixture = gamma * sum(s_k * tensor_k)` + where `s = softmax(w)`, with `w` and `gamma` scalar parameters. + + In addition, if `do_layer_norm=True` then apply layer normalization to each tensor + before weighting. + """ + + def __init__( + self, + mixture_size: int, + do_layer_norm: bool = False, + initial_scalar_parameters: List[float] = None, + trainable: bool = True, + ) -> None: + super().__init__() + self.mixture_size = mixture_size + self.do_layer_norm = do_layer_norm + + if initial_scalar_parameters is None: + initial_scalar_parameters = [0.0] * mixture_size + elif len(initial_scalar_parameters) != mixture_size: + raise ConfigurationError( + "Length of initial_scalar_parameters {} differs " + "from mixture_size {}".format(initial_scalar_parameters, mixture_size) + ) + + self.scalar_parameters = ParameterList( + [ + Parameter( + torch.FloatTensor([initial_scalar_parameters[i]]), requires_grad=trainable + ) + for i in range(mixture_size) + ] + ) + self.gamma = Parameter(torch.FloatTensor([1.0]), requires_grad=trainable) + + def forward(self, tensors: List[torch.Tensor], mask: torch.BoolTensor = None) -> torch.Tensor: + """ + Compute a weighted average of the `tensors`. The input tensors an be any shape + with at least two dimensions, but must all be the same shape. + + When `do_layer_norm=True`, the `mask` is required input. If the `tensors` are + dimensioned `(dim_0, ..., dim_{n-1}, dim_n)`, then the `mask` is dimensioned + `(dim_0, ..., dim_{n-1})`, as in the typical case with `tensors` of shape + `(batch_size, timesteps, dim)` and `mask` of shape `(batch_size, timesteps)`. + + When `do_layer_norm=False` the `mask` is ignored. 
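# A usage sketch for ScalarMix, assuming three equally-shaped layer outputs;
# all shapes below are illustrative.
import torch

from combo.modules.scalar_mix import ScalarMix

layers = [torch.randn(2, 5, 16) for _ in range(3)]     # e.g. three encoder layers
mix = ScalarMix(mixture_size=3)
mixed = mix(layers)                                     # shape: (2, 5, 16)

# With layer normalization a boolean mask over (batch, timesteps) is required.
mix_ln = ScalarMix(mixture_size=3, do_layer_norm=True)
mixed_ln = mix_ln(layers, mask=torch.ones(2, 5, dtype=torch.bool))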
+ """ + if len(tensors) != self.mixture_size: + raise ConfigurationError( + "{} tensors were passed, but the module was initialized to " + "mix {} tensors.".format(len(tensors), self.mixture_size) + ) + + def _do_layer_norm(tensor, broadcast_mask, num_elements_not_masked): + tensor_masked = tensor * broadcast_mask + mean = torch.sum(tensor_masked) / num_elements_not_masked + variance = ( + torch.sum(((tensor_masked - mean) * broadcast_mask) ** 2) / num_elements_not_masked + ) + return (tensor - mean) / torch.sqrt(variance + util.tiny_value_of_dtype(variance.dtype)) + + normed_weights = torch.nn.functional.softmax( + torch.cat([parameter for parameter in self.scalar_parameters]), dim=0 + ) + normed_weights = torch.split(normed_weights, split_size_or_sections=1) + + if not self.do_layer_norm: + pieces = [] + for weight, tensor in zip(normed_weights, tensors): + pieces.append(weight * tensor) + return self.gamma * sum(pieces) + + else: + assert mask is not None + broadcast_mask = mask.unsqueeze(-1) + input_dim = tensors[0].size(-1) + num_elements_not_masked = torch.sum(mask) * input_dim + + pieces = [] + for weight, tensor in zip(normed_weights, tensors): + pieces.append( + weight * _do_layer_norm(tensor, broadcast_mask, num_elements_not_masked) + ) + return self.gamma * sum(pieces) diff --git a/combo/modules/seq2seq_encoder.py b/combo/modules/seq2seq_encoder.py deleted file mode 100644 index 71413f3c5f0539caf337e10e7c05c09e50dd5c19..0000000000000000000000000000000000000000 --- a/combo/modules/seq2seq_encoder.py +++ /dev/null @@ -1,33 +0,0 @@ -class Seq2SeqEncoder: - """ - A `Seq2SeqEncoder` is a `Module` that takes as input a sequence of vectors and returns a - modified sequence of vectors. Input shape : `(batch_size, sequence_length, input_dim)`; output - shape : `(batch_size, sequence_length, output_dim)`. - - We add two methods to the basic `Module` API: `get_input_dim()` and `get_output_dim()`. - You might need this if you want to construct a `Linear` layer using the output of this encoder, - or to raise sensible errors for mis-matching input dimensions. - """ - - def get_input_dim(self) -> int: - """ - Returns the dimension of the vector input for each element in the sequence input - to a `Seq2SeqEncoder`. This is `not` the shape of the input tensor, but the - last element of that shape. - """ - raise NotImplementedError - - def get_output_dim(self) -> int: - """ - Returns the dimension of each vector in the sequence output by this `Seq2SeqEncoder`. - This is `not` the shape of the returned tensor, but the last element of that shape. - """ - raise NotImplementedError - - def is_bidirectional(self) -> bool: - """ - Returns `True` if this encoder is bidirectional. If so, we assume the forward direction - of the encoder is the first half of the final dimension, and the backward direction is the - second half. 
- """ - raise NotImplementedError diff --git a/combo/modules/seq2seq_encoders/__init__.py b/combo/modules/seq2seq_encoders/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/combo/modules/seq2seq_encoders/seq2seq_encoder.py b/combo/modules/seq2seq_encoders/seq2seq_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..78cc6482639cd9df51615e67fbfcd1f14b82c2f1 --- /dev/null +++ b/combo/modules/seq2seq_encoders/seq2seq_encoder.py @@ -0,0 +1,103 @@ +import torch +from torch.nn.utils.rnn import pad_packed_sequence + +from combo.modules.encoder import _EncoderBase +from combo.config.from_parameters import FromParameters +from combo.utils import ConfigurationError + + +class Seq2SeqEncoder(_EncoderBase, FromParameters): + """ + A `Seq2SeqEncoder` is a `Module` that takes as input a sequence of vectors and returns a + modified sequence of vectors. Input shape : `(batch_size, sequence_length, input_dim)`; output + shape : `(batch_size, sequence_length, output_dim)`. + + We add two methods to the basic `Module` API: `get_input_dim()` and `get_output_dim()`. + You might need this if you want to construct a `Linear` layer using the output of this encoder, + or to raise sensible errors for mis-matching input dimensions. + """ + + def __init__(self, module: torch.nn.Module, stateful: bool = False) -> None: + super().__init__(stateful) + self._module = module + try: + if not self._module.batch_first: + raise ConfigurationError("Our encoder semantics assumes batch is always first!") + except AttributeError: + pass + + try: + self._is_bidirectional = self._module.bidirectional + except AttributeError: + self._is_bidirectional = False + if self._is_bidirectional: + self._num_directions = 2 + else: + self._num_directions = 1 + + def get_input_dim(self) -> int: + return self._module.input_size + + def get_output_dim(self) -> int: + return self._module.hidden_size * self._num_directions + + def is_bidirectional(self) -> bool: + return self._is_bidirectional + + def forward( + self, inputs: torch.Tensor, mask: torch.BoolTensor, hidden_state: torch.Tensor = None + ) -> torch.Tensor: + + if self.stateful and mask is None: + raise ValueError("Always pass a mask with stateful RNNs.") + if self.stateful and hidden_state is not None: + raise ValueError("Stateful RNNs provide their own initial hidden_state.") + + if mask is None: + return self._module(inputs, hidden_state)[0] + + batch_size, total_sequence_length = mask.size() + + packed_sequence_output, final_states, restoration_indices = self.sort_and_run_forward( + self._module, inputs, mask, hidden_state + ) + + unpacked_sequence_tensor, _ = pad_packed_sequence(packed_sequence_output, batch_first=True) + + num_valid = unpacked_sequence_tensor.size(0) + # Some RNNs (GRUs) only return one state as a Tensor. Others (LSTMs) return two. + # If one state, use a single element list to handle in a consistent manner below. + if not isinstance(final_states, (list, tuple)) and self.stateful: + final_states = [final_states] + + # Add back invalid rows. + if num_valid < batch_size: + _, length, output_dim = unpacked_sequence_tensor.size() + zeros = unpacked_sequence_tensor.new_zeros(batch_size - num_valid, length, output_dim) + unpacked_sequence_tensor = torch.cat([unpacked_sequence_tensor, zeros], 0) + + # The states also need to have invalid rows added back. 
+ if self.stateful: + new_states = [] + for state in final_states: + num_layers, _, state_dim = state.size() + zeros = state.new_zeros(num_layers, batch_size - num_valid, state_dim) + new_states.append(torch.cat([state, zeros], 1)) + final_states = new_states + + # It's possible to need to pass sequences which are padded to longer than the + # max length of the sequence to a Seq2SeqEncoder. However, packing and unpacking + # the sequences mean that the returned tensor won't include these dimensions, because + # the RNN did not need to process them. We add them back on in the form of zeros here. + sequence_length_difference = total_sequence_length - unpacked_sequence_tensor.size(1) + if sequence_length_difference > 0: + zeros = unpacked_sequence_tensor.new_zeros( + batch_size, sequence_length_difference, unpacked_sequence_tensor.size(-1) + ) + unpacked_sequence_tensor = torch.cat([unpacked_sequence_tensor, zeros], 1) + + if self.stateful: + self._update_states(final_states, restoration_indices) + + # Restore the original indices and return the sequence. + return unpacked_sequence_tensor.index_select(0, restoration_indices) diff --git a/combo/modules/text_field_embedders/__init__.py b/combo/modules/text_field_embedders/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..705155749be42b0d015c348b24846f78933d2f3c --- /dev/null +++ b/combo/modules/text_field_embedders/__init__.py @@ -0,0 +1,2 @@ +from .text_field_embedder import TextFieldEmbedder +from .basic_text_field_embedder import BasicTextFieldEmbedder diff --git a/combo/modules/text_field_embedders/basic_text_field_embedder.py b/combo/modules/text_field_embedders/basic_text_field_embedder.py new file mode 100644 index 0000000000000000000000000000000000000000..56a5cbe93b10d8d4efbe7275e39e83139c1fbf87 --- /dev/null +++ b/combo/modules/text_field_embedders/basic_text_field_embedder.py @@ -0,0 +1,117 @@ +""" +Adapted from AllenNLP +https://github.com/allenai/allennlp/blob/main/allennlp/modules/text_field_embedders/basic_text_field_embedder.py +""" +from typing import Dict +import inspect + +import torch + +from combo.common.params import Params +from combo.config import Registry +from combo.data.fields.text_field import TextFieldTensors +from combo.modules.text_field_embedders.text_field_embedder import TextFieldEmbedder +from combo.modules.time_distributed import TimeDistributed +from combo.modules.token_embedders import EmptyEmbedder +from combo.modules.token_embedders.token_embedder import TokenEmbedder +from combo.utils import ConfigurationError + + +@Registry.register(TextFieldEmbedder, "basic") +class BasicTextFieldEmbedder(TextFieldEmbedder): + """ + This is a `TextFieldEmbedder` that wraps a collection of + [`TokenEmbedder`](../token_embedders/token_embedder.md) objects. Each + `TokenEmbedder` embeds or encodes the representation output from one + [`allennlp.data.TokenIndexer`](../../data/token_indexers/token_indexer.md). As the data produced by a + [`allennlp.data.fields.TextField`](../../data/fields/text_field.md) is a dictionary mapping names to these + representations, we take `TokenEmbedders` with corresponding names. Each `TokenEmbedders` + embeds its input, and the result is concatenated in an arbitrary (but consistent) order. + + Registered as a `TextFieldEmbedder` with name "basic", which is also the default. + + # Parameters + + token_embedders : `Dict[str, TokenEmbedder]`, required. + A dictionary mapping token embedder names to implementations. 
+ These names should match the corresponding indexer used to generate + the tensor passed to the TokenEmbedder. + """ + + def __init__(self, token_embedders: Dict[str, TokenEmbedder]) -> None: + super().__init__() + # NOTE(mattg): I'd prefer to just use ModuleDict(token_embedders) here, but that changes + # weight locations in torch state dictionaries and invalidates all prior models, just for a + # cosmetic change in the code. + self._token_embedders = token_embedders + for key, embedder in token_embedders.items(): + name = "token_embedder_%s" % key + if isinstance(embedder, Params): + embedder_params = dict(embedder) + embedder = Registry.resolve( + TokenEmbedder, embedder_params.get("type", "basic") + ).from_parameters(embedder_params) + self.add_module(name, embedder) + self._ordered_embedder_keys = sorted(self._token_embedders.keys()) + + def get_output_dim(self) -> int: + output_dim = 0 + for embedder in self._token_embedders.values(): + output_dim += embedder.get_output_dim() + return output_dim + + def forward( + self, text_field_input: TextFieldTensors, num_wrapping_dims: int = 0, **kwargs + ) -> torch.Tensor: + if sorted(self._token_embedders.keys()) != sorted(text_field_input.keys()): + message = "Mismatched token keys: %s and %s" % ( + str(self._token_embedders.keys()), + str(text_field_input.keys()), + ) + embedder_keys = set(self._token_embedders.keys()) + input_keys = set(text_field_input.keys()) + if embedder_keys > input_keys and all( + isinstance(embedder, EmptyEmbedder) + for name, embedder in self._token_embedders.items() + if name in embedder_keys - input_keys + ): + # Allow extra embedders that are only in the token embedders (but not input) and are empty to pass + # config check + pass + else: + raise ConfigurationError(message) + + embedded_representations = [] + for key in self._ordered_embedder_keys: + # Note: need to use getattr here so that the pytorch voodoo + # with submodules works with multiple GPUs. + embedder = getattr(self, "token_embedder_{}".format(key)) + if isinstance(embedder, EmptyEmbedder): + # Skip empty embedders + continue + forward_params = inspect.signature(embedder.forward).parameters + forward_params_values = {} + missing_tensor_args = set() + for param in forward_params.keys(): + if param in kwargs: + forward_params_values[param] = kwargs[param] + else: + missing_tensor_args.add(param) + + for _ in range(num_wrapping_dims): + embedder = TimeDistributed(embedder) + + tensors: Dict[str, torch.Tensor] = text_field_input[key] + if len(tensors) == 1 and len(missing_tensor_args) == 1: + # If there's only one tensor argument to the embedder, and we just have one tensor to + # embed, we can just pass in that tensor, without requiring a name match. + token_vectors = embedder(list(tensors.values())[0], **forward_params_values) + else: + # If there are multiple tensor arguments, we have to require matching names from the + # TokenIndexer. I don't think there's an easy way around that. + token_vectors = embedder(**tensors, **forward_params_values) + if token_vectors is not None: + # To handle some very rare use cases, we allow the return value of the embedder to + # be None; we just skip it in that case. 
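# A usage sketch for BasicTextFieldEmbedder with a single "token"-keyed
# embedder; the key has to match the indexer name that produced the
# TextFieldTensors dictionary. Names and sizes are illustrative.
import torch

from combo.modules.text_field_embedders import BasicTextFieldEmbedder
from combo.modules.token_embedders.embedding import Embedding

embedder = BasicTextFieldEmbedder(
    {"token": Embedding(embedding_dim=100, num_embeddings=5000)}
)

# TextFieldTensors: indexer name -> {argument name -> tensor of token ids}
text_field_input = {"token": {"tokens": torch.randint(0, 5000, (2, 7))}}
embedded = embedder(text_field_input)                   # shape: (2, 7, 100)
assert embedder.get_output_dim() == 100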
+ embedded_representations.append(token_vectors) + return torch.cat(embedded_representations, dim=-1) diff --git a/combo/modules/text_field_embedders/text_field_embedder.py b/combo/modules/text_field_embedders/text_field_embedder.py new file mode 100644 index 0000000000000000000000000000000000000000..f44c45ffe8abd1aace807057978d51229f1418a0 --- /dev/null +++ b/combo/modules/text_field_embedders/text_field_embedder.py @@ -0,0 +1,55 @@ +""" +Adapted from AllenNLP +https://github.com/allenai/allennlp/blob/main/allennlp/modules/text_field_embedders/text_field_embedder.py +""" + +import torch + +from combo.config import FromParameters +from combo.data.fields.text_field import TextFieldTensors + + +class TextFieldEmbedder(torch.nn.Module, FromParameters): + """ + A `TextFieldEmbedder` is a `Module` that takes as input the + [`DataArray`](../../data/fields/text_field.md) produced by a [`TextField`](../../data/fields/text_field.md) and + returns as output an embedded representation of the tokens in that field. + + The `DataArrays` produced by `TextFields` are _dictionaries_ with named representations, like + "words" and "characters". When you create a `TextField`, you pass in a dictionary of + [`TokenIndexer`](../../data/token_indexers/token_indexer.md) objects, telling the field how exactly the + tokens in the field should be represented. This class changes the type signature of `Module.forward`, + restricting `TextFieldEmbedders` to take inputs corresponding to a single `TextField`, which is + a dictionary of tensors with the same names as were passed to the `TextField`. + + We also add a method to the basic `Module` API: `get_output_dim()`. You might need this + if you want to construct a `Linear` layer using the output of this embedder, for instance. + """ + + default_implementation = "basic" + + def forward( + self, text_field_input: TextFieldTensors, num_wrapping_dims: int = 0, **kwargs + ) -> torch.Tensor: + """ + # Parameters + + text_field_input : `TextFieldTensors` + A dictionary that was the output of a call to `TextField.as_tensor`. Each tensor in + here is assumed to have a shape roughly similar to `(batch_size, sequence_length)` + (perhaps with an extra trailing dimension for the characters in each token). + num_wrapping_dims : `int`, optional (default=`0`) + If you have a `ListField[TextField]` that created the `text_field_input`, you'll + end up with tensors of shape `(batch_size, wrapping_dim1, wrapping_dim2, ..., + sequence_length)`. This parameter tells us how many wrapping dimensions there are, so + that we can correctly `TimeDistribute` the embedding of each named representation. + """ + raise NotImplementedError + + def get_output_dim(self) -> int: + """ + Returns the dimension of the vector representing each token in the output of this + `TextFieldEmbedder`. This is _not_ the shape of the returned tensor, but the last element + of that shape. 
+ """ + raise NotImplementedError diff --git a/combo/modules/time_distributed.py b/combo/modules/time_distributed.py new file mode 100644 index 0000000000000000000000000000000000000000..e7975d737d3ccfbb082498c7b81cc17afdac379e --- /dev/null +++ b/combo/modules/time_distributed.py @@ -0,0 +1,73 @@ +""" +Adapted from AllenNLP +""" +from typing import List + +import torch +from overrides import overrides + +from combo.config import Registry, FromParameters +from combo.modules.module import Module + + +@Registry.register(Module, 'time_distributed') +class TimeDistributed(Module, FromParameters): + """ + Given an input shaped like `(batch_size, time_steps, [rest])` and a `Module` that takes + inputs like `(batch_size, [rest])`, `TimeDistributed` reshapes the input to be + `(batch_size * time_steps, [rest])`, applies the contained `Module`, then reshapes it back. + + Note that while the above gives shapes with `batch_size` first, this `Module` also works if + `batch_size` is second - we always just combine the first two dimensions, then split them. + + It also reshapes keyword arguments unless they are not tensors or their name is specified in + the optional `pass_through` iterable. + """ + + def __init__(self, module: torch.nn.Module): + super().__init__() + self._module = module + + @overrides + def forward(self, *inputs, pass_through: List[str] = None, **kwargs): + + pass_through = pass_through or [] + + reshaped_inputs = [self._reshape_tensor(input_tensor) for input_tensor in inputs] + + # Need some input to then get the batch_size and time_steps. + some_input = None + if inputs: + some_input = inputs[-1] + + reshaped_kwargs = {} + for key, value in kwargs.items(): + if isinstance(value, torch.Tensor) and key not in pass_through: + if some_input is None: + some_input = value + + value = self._reshape_tensor(value) + + reshaped_kwargs[key] = value + + reshaped_outputs = self._module(*reshaped_inputs, **reshaped_kwargs) + + if some_input is None: + raise RuntimeError("No input tensor to time-distribute") + + # Now get the output back into the right shape. + # (batch_size, time_steps, **output_size) + new_size = some_input.size()[:2] + reshaped_outputs.size()[1:] + outputs = reshaped_outputs.contiguous().view(new_size) + + return outputs + + @staticmethod + def _reshape_tensor(input_tensor): + input_size = input_tensor.size() + if len(input_size) <= 2: + raise RuntimeError(f"No dimension to distribute: {input_size}") + # Squash batch_size and time_steps into a single axis; result has shape + # (batch_size * time_steps, **input_size). 
+ squashed_shape = [-1] + list(input_size[2:]) + return input_tensor.contiguous().view(*squashed_shape) diff --git a/combo/modules/token_embedders/__init__.py b/combo/modules/token_embedders/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..75d4b18645671603b131aea2d1d6118668e4c0e2 --- /dev/null +++ b/combo/modules/token_embedders/__init__.py @@ -0,0 +1,7 @@ +from .token_embedder import * +from .empty_embedder import EmptyEmbedder +from .character_token_embedder import CharacterBasedWordEmbedder +from .pretrained_transformer_embedder import PretrainedTransformerEmbedder +from .pretrained_transformer_mismatched_embedder import PretrainedTransformerMismatchedEmbedder +from .transformers_words_embeddings import TransformersWordEmbedder +from .projected_words_embedder import ProjectedWordEmbedder diff --git a/combo/modules/token_embedders/character_token_embedder.py b/combo/modules/token_embedders/character_token_embedder.py new file mode 100644 index 0000000000000000000000000000000000000000..b066cea88a24562284105f538b2365d9d3055196 --- /dev/null +++ b/combo/modules/token_embedders/character_token_embedder.py @@ -0,0 +1,53 @@ +""" +Adapted from COMBO +Author: Mateusz Klimaszewski +https://gitlab.clarin-pl.eu/syntactic-tools/combo/-/blob/master/combo/models/embeddings.py +""" +from overrides import overrides + +from combo.config import Registry +from combo.data import Vocabulary +from combo.modules.dilated_cnn import DilatedCnnEncoder +from combo.modules.token_embedders import TokenEmbedder + +"""Embeddings.""" +from typing import Optional + +import torch +import torch.nn as nn +from combo import modules + + +# @token_embedders.TokenEmbedder.register("char_embeddings") +@Registry.register(TokenEmbedder, "char_embeddings_from_config") +class CharacterBasedWordEmbedder(TokenEmbedder): + """Character-based word embeddings.""" + + def __init__(self, + embedding_dim: int, + vocabulary: Vocabulary, + dilated_cnn_encoder: DilatedCnnEncoder, + vocab_namespace: str = "token_characters"): + assert vocab_namespace in vocabulary.get_namespaces() + super().__init__() + self.char_embed = nn.Embedding( + num_embeddings=vocabulary.get_vocab_size(vocab_namespace), + embedding_dim=embedding_dim, + ) + self.dilated_cnn_encoder = modules.TimeDistributed(dilated_cnn_encoder) + self.output_dim = embedding_dim + + def forward(self, + x: torch.Tensor, + char_mask: Optional[torch.BoolTensor] = None) -> torch.Tensor: + if char_mask is None: + char_mask = x.new_ones(x.size()) + + x = self.char_embed(x) + x = x * char_mask.unsqueeze(-1).float() + x = self.dilated_cnn_encoder(x.transpose(2, 3)) + return torch.max(x, dim=-1)[0] + + @overrides + def get_output_dim(self) -> int: + return self.output_dim diff --git a/combo/modules/token_embedders/embedding.py b/combo/modules/token_embedders/embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..7a628e1be9d3230f06e2d5b29c70c49e6cc3894f --- /dev/null +++ b/combo/modules/token_embedders/embedding.py @@ -0,0 +1,675 @@ +""" +Adapted from AllenNLP +https://github.com/allenai/allennlp/blob/main/allennlp/modules/token_embedders/token_embedder.py +""" +import io +import itertools +import logging +import re +import tarfile +import warnings +import zipfile +from typing import Any, cast, Iterator, NamedTuple, Optional, Sequence, Tuple, BinaryIO + +import numpy +import torch + +from torch.nn.functional import embedding + +from combo.common import Tqdm +from combo.config import Registry +from combo.utils import ConfigurationError 
+from combo.utils.file_utils import cached_path, get_file_extension +from cached_path import is_url_or_existing_file +from combo.data.vocabulary import Vocabulary +from combo.modules.time_distributed import TimeDistributed +from combo.modules.token_embedders.token_embedder import TokenEmbedder +from combo.nn import util + +with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=FutureWarning) + import h5py + +logger = logging.getLogger(__name__) + + +@Registry.register(TokenEmbedder, "embedding") +class Embedding(TokenEmbedder): + """ + A more featureful embedding module than the default in Pytorch. Adds the ability to: + + 1. embed higher-order inputs + 2. pre-specify the weight matrix + 3. use a non-trainable embedding + 4. project the resultant embeddings to some other dimension (which only makes sense with + non-trainable embeddings). + + Note that if you are using our data API and are trying to embed a + [`TextField`](../../data/fields/text_field.md), you should use a + [`TextFieldEmbedder`](../text_field_embedders/text_field_embedder.md) instead of using this directly. + + Registered as a `TokenEmbedder` with name "embedding". + + # Parameters + + num_embeddings : `int` + Size of the dictionary of embeddings (vocabulary size). + embedding_dim : `int` + The size of each embedding vector. + projection_dim : `int`, optional (default=`None`) + If given, we add a projection layer after the embedding layer. This really only makes + sense if `trainable` is `False`. + weight : `torch.FloatTensor`, optional (default=`None`) + A pre-initialised weight matrix for the embedding lookup, allowing the use of + pretrained vectors. + padding_index : `int`, optional (default=`None`) + If given, pads the output with zeros whenever it encounters the index. + trainable : `bool`, optional (default=`True`) + Whether or not to optimize the embedding parameters. + max_norm : `float`, optional (default=`None`) + If given, will renormalize the embeddings to always have a norm lesser than this + norm_type : `float`, optional (default=`2`) + The p of the p-norm to compute for the max_norm option + scale_grad_by_freq : `bool`, optional (default=`False`) + If given, this will scale gradients by the frequency of the words in the mini-batch. + sparse : `bool`, optional (default=`False`) + Whether or not the Pytorch backend should use a sparse representation of the embedding weight. + vocab_namespace : `str`, optional (default=`None`) + In case of fine-tuning/transfer learning, the model's embedding matrix needs to be + extended according to the size of extended-vocabulary. To be able to know how much to + extend the embedding-matrix, it's necessary to know which vocab_namspace was used to + construct it in the original training. We store vocab_namespace used during the original + training as an attribute, so that it can be retrieved during fine-tuning. + pretrained_file : `str`, optional (default=`None`) + Path to a file of word vectors to initialize the embedding matrix. It can be the + path to a local file or a URL of a (cached) remote file. Two formats are supported: + * hdf5 file - containing an embedding matrix in the form of a torch.Tensor; + * text file - an utf-8 encoded text file with space separated fields. + vocabulary : `Vocabulary`, optional (default = `None`) + Used to construct an embedding from a pretrained file. 
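# A construction sketch covering the two modes described above: a plain
# trainable lookup sized by hand, and a frozen lookup sized from a vocabulary
# and initialised from a pretrained vector file. `vocab` and the file path are
# assumptions.
from combo.modules.token_embedders.embedding import Embedding

plain = Embedding(embedding_dim=100, num_embeddings=20000)

pretrained = Embedding(
    embedding_dim=300,
    vocabulary=vocab,                           # assumed: a populated Vocabulary
    vocab_namespace="tokens",
    pretrained_file="/path/to/vectors.txt",     # assumed: utf-8 "word dim1 dim2 ..." file
    trainable=False,
)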
+ + In a typical AllenNLP configuration file, this parameter does not get an entry under the + "embedding", it gets specified as a top-level parameter, then is passed in to this module + separately. + + # Returns + + An Embedding module. + """ + + def __init__( + self, + embedding_dim: int, + num_embeddings: int = None, + projection_dim: int = None, + weight: torch.FloatTensor = None, + padding_index: int = None, + trainable: bool = True, + max_norm: float = None, + norm_type: float = 2.0, + scale_grad_by_freq: bool = False, + sparse: bool = False, + vocab_namespace: str = "tokens", + pretrained_file: str = None, + vocabulary: Vocabulary = None, + ) -> None: + super().__init__() + + if num_embeddings is None and vocabulary is None: + raise ConfigurationError( + "Embedding must be constructed with either num_embeddings or a vocabulary." + ) + + _vocab_namespace: Optional[str] = vocab_namespace + if num_embeddings is None: + num_embeddings = vocabulary.get_vocab_size(_vocab_namespace) # type: ignore + else: + # If num_embeddings is present, set default namespace to None so that extend_vocab + # call doesn't misinterpret that some namespace was originally used. + _vocab_namespace = None # type: ignore + + self.num_embeddings = num_embeddings + self.padding_index = padding_index + self.max_norm = max_norm + self.norm_type = norm_type + self.scale_grad_by_freq = scale_grad_by_freq + self.sparse = sparse + self._vocab_namespace = _vocab_namespace + self._pretrained_file = pretrained_file + + self.output_dim = projection_dim or embedding_dim + + if weight is not None and pretrained_file: + raise ConfigurationError( + "Embedding was constructed with both a weight and a pretrained file." + ) + + elif pretrained_file is not None: + + if vocabulary is None: + raise ConfigurationError( + "To construct an Embedding from a pretrained file, you must also pass a vocabulary." + ) + + # If we're loading a saved model, we don't want to actually read a pre-trained + # embedding file - the embeddings will just be in our saved weights, and we might not + # have the original embedding file anymore, anyway. + + # TODO: having to pass tokens here is SUPER gross, but otherwise this breaks the + # extend_vocab method, which relies on the value of vocab_namespace being None + # to infer at what stage the embedding has been constructed. Phew. + weight = _read_pretrained_embeddings_file( + pretrained_file, embedding_dim, vocabulary, vocab_namespace + ) + self.weight = torch.nn.Parameter(weight, requires_grad=trainable) + + elif weight is not None: + self.weight = torch.nn.Parameter(weight, requires_grad=trainable) + + else: + weight = torch.FloatTensor(num_embeddings, embedding_dim) + self.weight = torch.nn.Parameter(weight, requires_grad=trainable) + torch.nn.init.xavier_uniform_(self.weight) + + # Whatever way we have constructed the embedding, it should be consistent with + # num_embeddings and embedding_dim. + if self.weight.size() != (num_embeddings, embedding_dim): + raise ConfigurationError( + "A weight matrix was passed with contradictory embedding shapes." 
+ ) + + if self.padding_index is not None: + self.weight.data[self.padding_index].fill_(0) + + if projection_dim: + self._projection = torch.nn.Linear(embedding_dim, projection_dim) + else: + self._projection = None + + def get_output_dim(self) -> int: + return self.output_dim + + def forward(self, tokens: torch.Tensor) -> torch.Tensor: + # tokens may have extra dimensions (batch_size, d1, ..., dn, sequence_length), + # but embedding expects (batch_size, sequence_length), so pass tokens to + # util.combine_initial_dims (which is a no-op if there are no extra dimensions). + # Remember the original size. + original_size = tokens.size() + tokens = util.combine_initial_dims(tokens) + + embedded = embedding( + tokens, + self.weight, + padding_idx=self.padding_index, + max_norm=self.max_norm, + norm_type=self.norm_type, + scale_grad_by_freq=self.scale_grad_by_freq, + sparse=self.sparse, + ) + + # Now (if necessary) add back in the extra dimensions. + embedded = util.uncombine_initial_dims(embedded, original_size) + + if self._projection: + projection = self._projection + for _ in range(embedded.dim() - 2): + projection = TimeDistributed(projection) + embedded = projection(embedded) + return embedded + + def extend_vocab( + self, + extended_vocab: Vocabulary, + vocab_namespace: str = None, + extension_pretrained_file: str = None, + model_path: str = None, + ): + """ + Extends the embedding matrix according to the extended vocabulary. + If extension_pretrained_file is available, it will be used for initializing the new words + embeddings in the extended vocabulary; otherwise we will check if _pretrained_file attribute + is already available. If none is available, they will be initialized with xavier uniform. + + # Parameters + + extended_vocab : `Vocabulary` + Vocabulary extended from original vocabulary used to construct + this `Embedding`. + vocab_namespace : `str`, (optional, default=`None`) + In case you know what vocab_namespace should be used for extension, you + can pass it. If not passed, it will check if vocab_namespace used at the + time of `Embedding` construction is available. If so, this namespace + will be used or else extend_vocab will be a no-op. + extension_pretrained_file : `str`, (optional, default=`None`) + A file containing pretrained embeddings can be specified here. It can be + the path to a local file or an URL of a (cached) remote file. Check format + details in `from_params` of `Embedding` class. + model_path : `str`, (optional, default=`None`) + Path traversing the model attributes upto this embedding module. + Eg. "_text_field_embedder.token_embedder_tokens". This is only useful + to give a helpful error message when extend_vocab is implicitly called + by train or any other command. + """ + # Caveat: For allennlp v0.8.1 and below, we weren't storing vocab_namespace as an attribute, + # knowing which is necessary at time of embedding vocabulary extension. So old archive models are + # currently unextendable. + + vocab_namespace = vocab_namespace or self._vocab_namespace + if not vocab_namespace: + # It's not safe to default to "tokens" or any other namespace. + logger.info( + "Loading a model trained before embedding extension was implemented; " + "pass an explicit vocabulary namespace if you want to extend the vocabulary." + ) + return + + extended_num_embeddings = extended_vocab.get_vocab_size(vocab_namespace) + if extended_num_embeddings == self.num_embeddings: + # It's already been extended. 
No need to initialize / read pretrained file in first place (no-op) + return + + if extended_num_embeddings < self.num_embeddings: + raise ConfigurationError( + f"Size of namespace, {vocab_namespace} for extended_vocab is smaller than " + f"embedding. You likely passed incorrect vocabulary or namespace for extension." + ) + + # Case 1: user passed extension_pretrained_file and it's available. + if extension_pretrained_file and is_url_or_existing_file(extension_pretrained_file): + # Don't have to do anything here, this is the happy case. + pass + # Case 2: user passed extension_pretrained_file and it's not available + elif extension_pretrained_file: + raise ConfigurationError( + f"You passed pretrained embedding file {extension_pretrained_file} " + f"for model_path {model_path} but it's not available." + ) + # Case 3: user didn't pass extension_pretrained_file, but pretrained_file attribute was + # saved during training and is available. + elif is_url_or_existing_file(self._pretrained_file): + extension_pretrained_file = self._pretrained_file + # Case 4: no file is available, hope that pretrained embeddings weren't used in the first place and warn + elif self._pretrained_file is not None: + # Warn here instead of an exception to allow a fine-tuning even without the original pretrained_file + logger.warning( + f"Embedding at model_path, {model_path} cannot locate the pretrained_file. " + f"Originally pretrained_file was at '{self._pretrained_file}'." + ) + else: + # When loading a model from archive there is no way to distinguish between whether a pretrained-file + # was or wasn't used during the original training. So we leave an info here. + logger.info( + "If you are fine-tuning and want to use a pretrained_file for " + "embedding extension, please pass the mapping by --embedding-sources argument." + ) + + embedding_dim = self.weight.data.shape[-1] + if not extension_pretrained_file: + extra_num_embeddings = extended_num_embeddings - self.num_embeddings + extra_weight = torch.FloatTensor(extra_num_embeddings, embedding_dim) + torch.nn.init.xavier_uniform_(extra_weight) + else: + # It's easiest to just reload the embeddings for the entire vocabulary, + # then only keep the ones we need. + whole_weight = _read_pretrained_embeddings_file( + extension_pretrained_file, embedding_dim, extended_vocab, vocab_namespace + ) + extra_weight = whole_weight[self.num_embeddings :, :] + + device = self.weight.data.device + extended_weight = torch.cat([self.weight.data, extra_weight.to(device)], dim=0) + self.weight = torch.nn.Parameter(extended_weight, requires_grad=self.weight.requires_grad) + self.num_embeddings = extended_num_embeddings + + +def _read_pretrained_embeddings_file( + file_uri: str, embedding_dim: int, vocab: Vocabulary, namespace: str = "tokens" +) -> torch.FloatTensor: + """ + Returns and embedding matrix for the given vocabulary using the pretrained embeddings + contained in the given file. Embeddings for tokens not found in the pretrained embedding file + are randomly initialized using a normal distribution with mean and standard deviation equal to + those of the pretrained embeddings. + + We support two file formats: + + * text format - utf-8 encoded text file with space separated fields: [word] [dim 1] [dim 2] ... + The text file can eventually be compressed, and even resides in an archive with multiple files. 
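# A sketch of the extension-based dispatch described above; the paths and the
# `vocab` object are assumptions, and the helper is module-internal.
from combo.modules.token_embedders.embedding import _read_pretrained_embeddings_file

# Text format, one token per line: "<token> <dim 1> <dim 2> ...", for example
#   the 0.418 0.24968 -0.41242 ...
weights_txt = _read_pretrained_embeddings_file("/path/to/vectors.txt", 300, vocab)

# Files ending in .h5 / .hdf5 are read as a torch.Tensor stored under "embedding".
weights_h5 = _read_pretrained_embeddings_file("/path/to/vectors.hdf5", 300, vocab)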
+ If the file resides in an archive with other files, then `embeddings_filename` must + be a URI "(archive_uri)#file_path_inside_the_archive" + + * hdf5 format - hdf5 file containing an embedding matrix in the form of a torch.Tensor. + + If the filename ends with '.hdf5' or '.h5' then we load from hdf5, otherwise we assume + text format. + + # Parameters + + file_uri : `str`, required. + It can be: + + * a file system path or a URL of an eventually compressed text file or a zip/tar archive + containing a single file. + + * URI of the type `(archive_path_or_url)#file_path_inside_archive` if the text file + is contained in a multi-file archive. + + vocabulary : `Vocabulary`, required. + A Vocabulary object. + namespace : `str`, (optional, default=`"tokens"`) + The namespace of the vocabulary to find pretrained embeddings for. + trainable : `bool`, (optional, default=`True`) + Whether or not the embedding parameters should be optimized. + + # Returns + + A weight matrix with embeddings initialized from the read file. The matrix has shape + `(vocabulary.get_vocab_size(namespace), embedding_dim)`, where the indices of words appearing in + the pretrained embedding file are initialized to the pretrained embedding value. + """ + file_ext = get_file_extension(file_uri) + if file_ext in [".h5", ".hdf5"]: + return _read_embeddings_from_hdf5(file_uri, embedding_dim, vocab, namespace) + + return _read_embeddings_from_text_file(file_uri, embedding_dim, vocab, namespace) + + +def _read_embeddings_from_text_file( + file_uri: str, embedding_dim: int, vocab: Vocabulary, namespace: str = "tokens" +) -> torch.FloatTensor: + """ + Read pre-trained word vectors from an eventually compressed text file, possibly contained + inside an archive with multiple files. The text file is assumed to be utf-8 encoded with + space-separated fields: [word] [dim 1] [dim 2] ... + + Lines that contain more numerical tokens than `embedding_dim` raise a warning and are skipped. + + The remainder of the docstring is identical to `_read_pretrained_embeddings_file`. + """ + tokens_to_keep = set(vocab.get_index_to_token_vocabulary(namespace).values()) + vocab_size = vocab.get_vocab_size(namespace) + embeddings = {} + + # First we read the embeddings from the file, only keeping vectors for the words we need. + logger.info("Reading pretrained embeddings from file") + + with EmbeddingsTextFile(file_uri) as embeddings_file: + for line in Tqdm.tqdm(embeddings_file): + token = line.split(" ", 1)[0] + if token in tokens_to_keep: + fields = line.rstrip().split(" ") + if len(fields) - 1 != embedding_dim: + # Sometimes there are funny unicode parsing problems that lead to different + # fields lengths (e.g., a word with a unicode space character that splits + # into more than one column). We skip those lines. Note that if you have + # some kind of long header, this could result in all of your lines getting + # skipped. It's hard to check for that here; you just have to look in the + # embedding_misses_file and at the model summary to make sure things look + # like they are supposed to. 
+ logger.warning( + "Found line with wrong number of dimensions (expected: %d; actual: %d): %s", + embedding_dim, + len(fields) - 1, + line, + ) + continue + + vector = numpy.asarray(fields[1:], dtype="float32") + embeddings[token] = vector + + if not embeddings: + raise ConfigurationError( + "No embeddings of correct dimension found; you probably " + "misspecified your embedding_dim parameter, or didn't " + "pre-populate your Vocabulary" + ) + + all_embeddings = numpy.asarray(list(embeddings.values())) + embeddings_mean = float(numpy.mean(all_embeddings)) + embeddings_std = float(numpy.std(all_embeddings)) + # Now we initialize the weight matrix for an embedding layer, starting with random vectors, + # then filling in the word vectors we just read. + logger.info("Initializing pre-trained embedding layer") + embedding_matrix = torch.FloatTensor(vocab_size, embedding_dim).normal_( + embeddings_mean, embeddings_std + ) + num_tokens_found = 0 + index_to_token = vocab.get_index_to_token_vocabulary(namespace) + for i in range(vocab_size): + token = index_to_token[i] + + # If we don't have a pre-trained vector for this word, we'll just leave this row alone, + # so the word has a random initialization. + if token in embeddings: + embedding_matrix[i] = torch.FloatTensor(embeddings[token]) + num_tokens_found += 1 + else: + logger.debug( + "Token %s was not found in the embedding file. Initialising randomly.", token + ) + + logger.info( + "Pretrained embeddings were found for %d out of %d tokens", num_tokens_found, vocab_size + ) + + return embedding_matrix + + +def _read_embeddings_from_hdf5( + embeddings_filename: str, embedding_dim: int, vocabulary: Vocabulary, namespace: str = "tokens" +) -> torch.FloatTensor: + """ + Reads from a hdf5 formatted file. The embedding matrix is assumed to + be keyed by 'embedding' and of size `(num_tokens, embedding_dim)`. + """ + with h5py.File(embeddings_filename, "r") as fin: + embeddings = fin["embedding"][...] + + if list(embeddings.shape) != [vocabulary.get_vocab_size(namespace), embedding_dim]: + raise ConfigurationError( + "Read shape {0} embeddings from the file, but expected {1}".format( + list(embeddings.shape), [vocabulary.get_vocab_size(namespace), embedding_dim] + ) + ) + + return torch.FloatTensor(embeddings) + + +def format_embeddings_file_uri( + main_file_path_or_url: str, path_inside_archive: Optional[str] = None +) -> str: + if path_inside_archive: + return "({})#{}".format(main_file_path_or_url, path_inside_archive) + return main_file_path_or_url + + +class EmbeddingsFileURI(NamedTuple): + main_file_uri: str + path_inside_archive: Optional[str] = None + + +def parse_embeddings_file_uri(uri: str) -> "EmbeddingsFileURI": + match = re.fullmatch(r"\((.*)\)#(.*)", uri) + if match: + fields = cast(Tuple[str, str], match.groups()) + return EmbeddingsFileURI(*fields) + else: + return EmbeddingsFileURI(uri, None) + + +class EmbeddingsTextFile(Iterator[str]): + """ + Utility class for opening embeddings text files. Handles various compression formats, + as well as context management. + + # Parameters + + file_uri : `str` + It can be: + + * a file system path or a URL of an eventually compressed text file or a zip/tar archive + containing a single file. + * URI of the type `(archive_path_or_url)#file_path_inside_archive` if the text file + is contained in a multi-file archive. 
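# A round-trip sketch for the "(archive)#path" URI convention used here; the
# archive and member names are illustrative.
from combo.modules.token_embedders.embedding import (
    format_embeddings_file_uri,
    parse_embeddings_file_uri,
)

uri = format_embeddings_file_uri("embeddings.zip", "glove/vectors.txt")
# -> "(embeddings.zip)#glove/vectors.txt"

parsed = parse_embeddings_file_uri(uri)
assert parsed.main_file_uri == "embeddings.zip"
assert parsed.path_inside_archive == "glove/vectors.txt"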
+ + encoding : `str` + cache_dir : `str` + """ + + DEFAULT_ENCODING = "utf-8" + + def __init__( + self, file_uri: str, encoding: str = DEFAULT_ENCODING, cache_dir: str = None + ) -> None: + + self.uri = file_uri + self._encoding = encoding + self._cache_dir = cache_dir + self._archive_handle: Any = None # only if the file is inside an archive + + main_file_uri, path_inside_archive = parse_embeddings_file_uri(file_uri) + main_file_local_path = cached_path(main_file_uri, cache_dir=cache_dir) + + if zipfile.is_zipfile(main_file_local_path): # ZIP archive + self._open_inside_zip(main_file_uri, path_inside_archive) + + elif tarfile.is_tarfile(main_file_local_path): # TAR archive + self._open_inside_tar(main_file_uri, path_inside_archive) + + else: # all the other supported formats, including uncompressed files + if path_inside_archive: + raise ValueError("Unsupported archive format: %s" + main_file_uri) + + # All the python packages for compressed files share the same interface of io.open + extension = get_file_extension(main_file_uri) + + # Some systems don't have support for all of these libraries, so we import them only + # when necessary. + package = None + if extension in [".txt", ".vec"]: + package = io + elif extension == ".gz": + import gzip + + package = gzip + elif extension == ".bz2": + import bz2 + + package = bz2 + elif extension == ".xz": + import lzma + + package = lzma + + if package is None: + logger.warning( + 'The embeddings file has an unknown file extension "%s". ' + "We will assume the file is an (uncompressed) text file", + extension, + ) + package = io + + self._handle = package.open( # type: ignore + main_file_local_path, "rt", encoding=encoding + ) + + # To use this with tqdm we'd like to know the number of tokens. It's possible that the + # first line of the embeddings file contains this: if it does, we want to start iteration + # from the 2nd line, otherwise we want to start from the 1st. + # Unfortunately, once we read the first line, we cannot move back the file iterator + # because the underlying file may be "not seekable"; we use itertools.chain instead. 
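With the dispatch above, plain text, gzip/bz2/xz files and single members of zip/tar archives all open the same way. A small usage sketch with hypothetical paths; each data line is expected to look like `<token> <dim 1> ... <dim N>`:

for uri in [
    "glove.6B.100d.txt",                 # plain text
    "glove.6B.100d.txt.gz",              # gzip-compressed text
    "(glove.6B.zip)#glove.6B.100d.txt",  # one member of a multi-file zip archive
]:
    with EmbeddingsTextFile(uri) as embeddings_file:
        # num_tokens is filled in when the file starts with a "<count> <dim>" header line,
        # otherwise it stays None and iteration starts from the first vector line.
        print(embeddings_file.num_tokens)
        first_line = embeddings_file.readline()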
+ first_line = next(self._handle) # this moves the iterator forward + self.num_tokens = EmbeddingsTextFile._get_num_tokens_from_first_line(first_line) + if self.num_tokens: + # the first line is a header line: start iterating from the 2nd line + self._iterator = self._handle + else: + # the first line is not a header line: start iterating from the 1st line + self._iterator = itertools.chain([first_line], self._handle) + + def _open_inside_zip(self, archive_path: str, member_path: Optional[str] = None) -> None: + cached_archive_path = cached_path(archive_path, cache_dir=self._cache_dir) + archive = zipfile.ZipFile(cached_archive_path, "r") + if member_path is None: + members_list = archive.namelist() + member_path = self._get_the_only_file_in_the_archive(members_list, archive_path) + member_path = cast(str, member_path) + member_file = cast(BinaryIO, archive.open(member_path, "r")) + self._handle = io.TextIOWrapper(member_file, encoding=self._encoding) + self._archive_handle = archive + + def _open_inside_tar(self, archive_path: str, member_path: Optional[str] = None) -> None: + cached_archive_path = cached_path(archive_path, cache_dir=self._cache_dir) + archive = tarfile.open(cached_archive_path, "r") + if member_path is None: + members_list = archive.getnames() + member_path = self._get_the_only_file_in_the_archive(members_list, archive_path) + member_path = cast(str, member_path) + member = archive.getmember(member_path) # raises exception if not present + member_file = cast(BinaryIO, archive.extractfile(member)) + self._handle = io.TextIOWrapper(member_file, encoding=self._encoding) + self._archive_handle = archive + + def read(self) -> str: + return "".join(self._iterator) + + def readline(self) -> str: + return next(self._iterator) + + def close(self) -> None: + self._handle.close() + if self._archive_handle: + self._archive_handle.close() + + def __enter__(self) -> "EmbeddingsTextFile": + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + self.close() + + def __iter__(self) -> "EmbeddingsTextFile": + return self + + def __next__(self) -> str: + return next(self._iterator) + + def __len__(self) -> Optional[int]: + if self.num_tokens: + return self.num_tokens + raise AttributeError( + "an object of type EmbeddingsTextFile implements `__len__` only if the underlying " + "text file declares the number of tokens (i.e. the number of lines following)" + "in the first line. That is not the case of this particular instance." + ) + + @staticmethod + def _get_the_only_file_in_the_archive(members_list: Sequence[str], archive_path: str) -> str: + if len(members_list) > 1: + raise ValueError( + "The archive %s contains multiple files, so you must select " + "one of the files inside providing a uri of the type: %s." + % ( + archive_path, + format_embeddings_file_uri("path_or_url_to_archive", "path_inside_archive"), + ) + ) + return members_list[0] + + @staticmethod + def _get_num_tokens_from_first_line(line: str) -> Optional[int]: + """This function takes in input a string and if it contains 1 or 2 integers, it assumes the + largest one it the number of tokens. 
Returns None if the line doesn't match that pattern.""" + fields = line.split(" ") + if 1 <= len(fields) <= 2: + try: + int_fields = [int(x) for x in fields] + except ValueError: + return None + else: + num_tokens = max(int_fields) + logger.info( + "Recognized a header line in the embedding file with number of tokens: %d", + num_tokens, + ) + return num_tokens + return None diff --git a/combo/modules/token_embedders/empty_embedder.py b/combo/modules/token_embedders/empty_embedder.py new file mode 100644 index 0000000000000000000000000000000000000000..1def0f854fc03c055a7070e5b3e6130f36367874 --- /dev/null +++ b/combo/modules/token_embedders/empty_embedder.py @@ -0,0 +1,28 @@ +import torch + +from combo.config import Registry +from combo.modules.token_embedders import TokenEmbedder + + +@Registry.register(TokenEmbedder, "empty") +class EmptyEmbedder(TokenEmbedder): + """ + Assumes you want to completely ignore the output of a `TokenIndexer` for some reason, and does + not return anything when asked to embed it. + + You should almost never need to use this; normally you would just not use a particular + `TokenIndexer`. It's only in very rare cases, like simplicity in data processing for language + modeling (where we use just one `TextField` to handle input embedding and computing target ids), + where you might want to use this. + + Registered as a `TokenEmbedder` with name "empty". + """ + + def __init__(self) -> None: + super().__init__() + + def get_output_dim(self): + return 0 + + def forward(self, *inputs, **kwargs) -> torch.Tensor: + return None diff --git a/combo/modules/token_embedders/pretrained_transformer_embedder.py b/combo/modules/token_embedders/pretrained_transformer_embedder.py new file mode 100644 index 0000000000000000000000000000000000000000..d32549096c019d7b8cf85690792b01bd42deb6ba --- /dev/null +++ b/combo/modules/token_embedders/pretrained_transformer_embedder.py @@ -0,0 +1,415 @@ +""" +Adapted from AllenNLP +https://github.com/allenai/allennlp/blob/main/allennlp/modules/token_embedders/pretrained_transformer_embedder.py +""" + +import logging +import math +from typing import Any, Dict, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +from overrides import overrides + +from combo.config import Registry +from combo.data.tokenizers import PretrainedTransformerTokenizer +from combo.modules.scalar_mix import ScalarMix +from combo.modules.token_embedders.token_embedder import TokenEmbedder +from combo.nn.util import batched_index_select +from transformers import XLNetConfig + +logger = logging.getLogger(__name__) + + +@Registry.register(TokenEmbedder, "pretrained_transformer") +class PretrainedTransformerEmbedder(TokenEmbedder): + """ + Uses a pretrained model from `transformers` as a `TokenEmbedder`. + + Registered as a `TokenEmbedder` with name "pretrained_transformer". + + # Parameters + + model_name : `str` + The name of the `transformers` model to use. Should be the same as the corresponding + `PretrainedTransformerIndexer`. + max_length : `int`, optional (default = `None`) + If positive, folds input token IDs into multiple segments of this length, pass them + through the transformer model independently, and concatenate the final representations. + Should be set to the same value as the `max_length` option on the + `PretrainedTransformerIndexer`. + sub_module: `str`, optional (default = `None`) + The name of a submodule of the transformer to be used as the embedder. Some transformers naturally act + as embedders such as BERT. 
+ However, other models consist of an encoder and a decoder, in which case we just
+ want to use the encoder.
+ train_parameters: `bool`, optional (default = `True`)
+ If this is `True`, the transformer weights get updated during training. If this is `False`, the
+ transformer weights are not updated during training.
+ eval_mode: `bool`, optional (default = `False`)
+ If this is `True`, the model is always set to evaluation mode (e.g., the dropout is disabled and the
+ batch normalization layer statistics are not updated). If this is `False`, such dropout and batch
+ normalization layers are only set to evaluation mode when the model is evaluating on development
+ or test data.
+ last_layer_only: `bool`, optional (default = `True`)
+ When `True` (the default), only the final layer of the pretrained transformer is taken
+ for the embeddings. But if set to `False`, a scalar mix of all of the layers
+ is used.
+ override_weights_file: `Optional[str]`, optional (default = `None`)
+ If set, this specifies a file from which to load alternate weights that override the
+ weights from huggingface. The file is expected to contain a PyTorch `state_dict`, created
+ with `torch.save()`.
+ override_weights_strip_prefix: `Optional[str]`, optional (default = `None`)
+ If set, strip the given prefix from the state dict when loading it.
+ reinit_modules: `Optional[Union[int, Tuple[int, ...], Tuple[str, ...]]]`, optional (default = `None`)
+ If this is an integer, the last `reinit_modules` layers of the transformer will be
+ re-initialized. If this is a tuple of integers, the layers indexed by `reinit_modules` will
+ be re-initialized. Note, because the module structure of the transformer `model_name` can
+ differ, we cannot guarantee that providing an integer or tuple of integers will work. If
+ this fails, you can instead provide a tuple of strings, which will be treated as regexes and
+ any module with a name matching the regex will be re-initialized. Re-initializing the last
+ few layers of a pretrained transformer can reduce the instability of fine-tuning on small
+ datasets and may improve performance (https://arxiv.org/abs/2006.05987v3). Has no effect
+ if `load_weights` is `False` or `override_weights_file` is not `None`.
+ load_weights: `bool`, optional (default = `True`)
+ Whether to load the pretrained weights. If you're loading your model/predictor from an AllenNLP archive
+ it usually makes sense to set this to `False` (via the `overrides` parameter)
+ to avoid unnecessarily caching and loading the original pretrained weights,
+ since the archive will already contain all of the weights needed.
+ gradient_checkpointing: `bool`, optional (default = `None`)
+ Enable or disable gradient checkpointing.
+ tokenizer_kwargs: `Dict[str, Any]`, optional (default = `None`)
+ Dictionary with
+ [additional arguments](https://github.com/huggingface/transformers/blob/155c782a2ccd103cf63ad48a2becd7c76a7d2115/transformers/tokenization_utils.py#L691)
+ for `AutoTokenizer.from_pretrained`.
+ transformer_kwargs: `Dict[str, Any]`, optional (default = `None`)
+ Dictionary with
+ [additional arguments](https://github.com/huggingface/transformers/blob/155c782a2ccd103cf63ad48a2becd7c76a7d2115/transformers/modeling_utils.py#L253)
+ for `AutoModel.from_pretrained`.
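To make the interplay of these options concrete, a typical instantiation might look like the sketch below; the model name and values are only examples, and building the embedder requires the pretrained weights to be locally cached or downloadable:

embedder = PretrainedTransformerEmbedder(
    "bert-base-cased",          # example model name
    max_length=512,             # fold longer inputs into 512-wordpiece segments
    last_layer_only=False,      # use a scalar mix of all hidden layers
    train_parameters=False,     # freeze the transformer weights
)
print(embedder.get_output_dim())  # hidden size of the underlying transformer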
+ """ # noqa: E501 + + authorized_missing_keys = [r"position_ids$"] + + def __init__( + self, + model_name: str, + *, + max_length: int = None, + sub_module: str = None, + train_parameters: bool = True, + eval_mode: bool = False, + last_layer_only: bool = True, + override_weights_file: Optional[str] = None, + override_weights_strip_prefix: Optional[str] = None, + reinit_modules: Optional[Union[int, Tuple[int, ...], Tuple[str, ...]]] = None, + load_weights: bool = True, + gradient_checkpointing: Optional[bool] = None, + tokenizer_kwargs: Optional[Dict[str, Any]] = None, + transformer_kwargs: Optional[Dict[str, Any]] = None, + ) -> None: + super().__init__() + from combo.common import cached_transformers + + self.transformer_model = cached_transformers.get( + model_name, + True, + override_weights_file=override_weights_file, + override_weights_strip_prefix=override_weights_strip_prefix, + reinit_modules=reinit_modules, + load_weights=load_weights, + **(transformer_kwargs or {}), + ) + + if gradient_checkpointing is not None: + self.transformer_model.config.update({"gradient_checkpointing": gradient_checkpointing}) + + self.config = self.transformer_model.config + if sub_module: + assert hasattr(self.transformer_model, sub_module) + self.transformer_model = getattr(self.transformer_model, sub_module) + self._max_length = max_length + + # I'm not sure if this works for all models; open an issue on github if you find a case + # where it doesn't work. + self.output_dim = self.config.hidden_size + + self._scalar_mix: Optional[ScalarMix] = None + if not last_layer_only: + self._scalar_mix = ScalarMix(self.config.num_hidden_layers) + self.config.output_hidden_states = True + + tokenizer = PretrainedTransformerTokenizer( + model_name, + tokenizer_kwargs=tokenizer_kwargs, + ) + + try: + if self.transformer_model.get_input_embeddings().num_embeddings != len( + tokenizer.tokenizer + ): + self.transformer_model.resize_token_embeddings(len(tokenizer.tokenizer)) + except NotImplementedError: + # Can't resize for transformers models that don't implement base_model.get_input_embeddings() + logger.warning( + "Could not resize the token embedding matrix of the transformer model. " + "This model does not support resizing." 
+ ) + + self._num_added_start_tokens = len(tokenizer.single_sequence_start_tokens) + self._num_added_end_tokens = len(tokenizer.single_sequence_end_tokens) + self._num_added_tokens = self._num_added_start_tokens + self._num_added_end_tokens + + self.train_parameters = train_parameters + if not train_parameters: + for param in self.transformer_model.parameters(): + param.requires_grad = False + + self.eval_mode = eval_mode + if eval_mode: + self.transformer_model.eval() + + def train(self, mode: bool = True): + self.training = mode + for name, module in self.named_children(): + if self.eval_mode and name == "transformer_model": + module.eval() + else: + module.train(mode) + return self + + @overrides + def get_output_dim(self) -> int: + return self.output_dim + + def _number_of_token_type_embeddings(self): + if isinstance(self.config, XLNetConfig): + return 3 # XLNet has 3 type ids + elif hasattr(self.config, "type_vocab_size"): + return self.config.type_vocab_size + else: + return 0 + + def forward( + self, + token_ids: torch.LongTensor, + mask: torch.BoolTensor, + type_ids: Optional[torch.LongTensor] = None, + segment_concat_mask: Optional[torch.BoolTensor] = None, + ) -> torch.Tensor: # type: ignore + """ + # Parameters + + token_ids: `torch.LongTensor` + Shape: `[batch_size, num_wordpieces if max_length is None else num_segment_concat_wordpieces]`. + num_segment_concat_wordpieces is num_wordpieces plus special tokens inserted in the + middle, e.g. the length of: "[CLS] A B C [SEP] [CLS] D E F [SEP]" (see indexer logic). + mask: `torch.BoolTensor` + Shape: [batch_size, num_wordpieces]. + type_ids: `Optional[torch.LongTensor]` + Shape: `[batch_size, num_wordpieces if max_length is None else num_segment_concat_wordpieces]`. + segment_concat_mask: `Optional[torch.BoolTensor]` + Shape: `[batch_size, num_segment_concat_wordpieces]`. + + # Returns + + `torch.Tensor` + Shape: `[batch_size, num_wordpieces, embedding_size]`. + + """ + # Some of the huggingface transformers don't support type ids at all and crash when you supply + # them. For others, you can supply a tensor of zeros, and if you don't, they act as if you did. + # There is no practical difference to the caller, so here we pretend that one case is the same + # as another case. + if type_ids is not None: + max_type_id = type_ids.max() + if max_type_id == 0: + type_ids = None + else: + if max_type_id >= self._number_of_token_type_embeddings(): + raise ValueError("Found type ids too large for the chosen transformer model.") + assert token_ids.shape == type_ids.shape + + fold_long_sequences = self._max_length is not None and token_ids.size(1) > self._max_length + if fold_long_sequences: + batch_size, num_segment_concat_wordpieces = token_ids.size() + token_ids, segment_concat_mask, type_ids = self._fold_long_sequences( + token_ids, segment_concat_mask, type_ids + ) + + transformer_mask = segment_concat_mask if self._max_length is not None else mask + assert transformer_mask is not None + # Shape: [batch_size, num_wordpieces, embedding_size], + # or if self._max_length is not None: + # [batch_size * num_segments, self._max_length, embedding_size] + + # We call this with kwargs because some of the huggingface models don't have the + # token_type_ids parameter and fail even when it's given as None. + # Also, as of transformers v2.5.1, they are taking FloatTensor masks. 
+ parameters = {"input_ids": token_ids, "attention_mask": transformer_mask.float()} + if type_ids is not None: + parameters["token_type_ids"] = type_ids + + transformer_output = self.transformer_model(**parameters) + if self._scalar_mix is not None: + # The hidden states will also include the embedding layer, which we don't + # include in the scalar mix. Hence the `[1:]` slicing. + hidden_states = transformer_output.hidden_states[1:] + embeddings = self._scalar_mix(hidden_states) + else: + embeddings = transformer_output.last_hidden_state + + if fold_long_sequences: + embeddings = self._unfold_long_sequences( + embeddings, segment_concat_mask, batch_size, num_segment_concat_wordpieces + ) + + return embeddings + + def _fold_long_sequences( + self, + token_ids: torch.LongTensor, + mask: torch.BoolTensor, + type_ids: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.LongTensor, torch.LongTensor, Optional[torch.LongTensor]]: + """ + We fold 1D sequences (for each element in batch), returned by `PretrainedTransformerIndexer` + that are in reality multiple segments concatenated together, to 2D tensors, e.g. + + [ [CLS] A B C [SEP] [CLS] D E [SEP] ] + -> [ [ [CLS] A B C [SEP] ], [ [CLS] D E [SEP] [PAD] ] ] + The [PAD] positions can be found in the returned `mask`. + + # Parameters + + token_ids: `torch.LongTensor` + Shape: `[batch_size, num_segment_concat_wordpieces]`. + num_segment_concat_wordpieces is num_wordpieces plus special tokens inserted in the + middle, i.e. the length of: "[CLS] A B C [SEP] [CLS] D E F [SEP]" (see indexer logic). + mask: `torch.BoolTensor` + Shape: `[batch_size, num_segment_concat_wordpieces]`. + The mask for the concatenated segments of wordpieces. The same as `segment_concat_mask` + in `forward()`. + type_ids: `Optional[torch.LongTensor]` + Shape: [batch_size, num_segment_concat_wordpieces]. + + # Returns: + + token_ids: `torch.LongTensor` + Shape: [batch_size * num_segments, self._max_length]. + mask: `torch.BoolTensor` + Shape: [batch_size * num_segments, self._max_length]. + """ + num_segment_concat_wordpieces = token_ids.size(1) + num_segments = math.ceil(num_segment_concat_wordpieces / self._max_length) # type: ignore + padded_length = num_segments * self._max_length # type: ignore + length_to_pad = padded_length - num_segment_concat_wordpieces + + def fold(tensor): # Shape: [batch_size, num_segment_concat_wordpieces] + # Shape: [batch_size, num_segments * self._max_length] + tensor = F.pad(tensor, [0, length_to_pad], value=0) + # Shape: [batch_size * num_segments, self._max_length] + return tensor.reshape(-1, self._max_length) + + return fold(token_ids), fold(mask), fold(type_ids) if type_ids is not None else None + + def _unfold_long_sequences( + self, + embeddings: torch.FloatTensor, + mask: torch.BoolTensor, + batch_size: int, + num_segment_concat_wordpieces: int, + ) -> torch.FloatTensor: + """ + We take 2D segments of a long sequence and flatten them out to get the whole sequence + representation while remove unnecessary special tokens. + + [ [ [CLS]_emb A_emb B_emb C_emb [SEP]_emb ], [ [CLS]_emb D_emb E_emb [SEP]_emb [PAD]_emb ] ] + -> [ [CLS]_emb A_emb B_emb C_emb D_emb E_emb [SEP]_emb ] + + We truncate the start and end tokens for all segments, recombine the segments, + and manually add back the start and end tokens. + + # Parameters + + embeddings: `torch.FloatTensor` + Shape: [batch_size * num_segments, self._max_length, embedding_size]. + mask: `torch.BoolTensor` + Shape: [batch_size * num_segments, self._max_length]. 
+ The mask for the concatenated segments of wordpieces. The same as `segment_concat_mask` + in `forward()`. + batch_size: `int` + num_segment_concat_wordpieces: `int` + The length of the original "[ [CLS] A B C [SEP] [CLS] D E F [SEP] ]", i.e. + the original `token_ids.size(1)`. + + # Returns: + + embeddings: `torch.FloatTensor` + Shape: [batch_size, self._num_wordpieces, embedding_size]. + """ + + def lengths_to_mask(lengths, max_len, device): + return torch.arange(max_len, device=device).expand( + lengths.size(0), max_len + ) < lengths.unsqueeze(1) + + device = embeddings.device + num_segments = int(embeddings.size(0) / batch_size) + embedding_size = embeddings.size(2) + + # We want to remove all segment-level special tokens but maintain sequence-level ones + num_wordpieces = num_segment_concat_wordpieces - (num_segments - 1) * self._num_added_tokens + + embeddings = embeddings.reshape( + batch_size, num_segments * self._max_length, embedding_size # type: ignore + ) + mask = mask.reshape(batch_size, num_segments * self._max_length) # type: ignore + # We assume that all 1s in the mask precede all 0s, and add an assert for that. + # Open an issue on GitHub if this breaks for you. + # Shape: (batch_size,) + seq_lengths = mask.sum(-1) + if not (lengths_to_mask(seq_lengths, mask.size(1), device) == mask).all(): + raise ValueError( + "Long sequence splitting only supports masks with all 1s preceding all 0s." + ) + # Shape: (batch_size, self._num_added_end_tokens); this is a broadcast op + end_token_indices = ( + seq_lengths.unsqueeze(-1) - torch.arange(self._num_added_end_tokens, device=device) - 1 + ) + + # Shape: (batch_size, self._num_added_start_tokens, embedding_size) + start_token_embeddings = embeddings[:, : self._num_added_start_tokens, :] + # Shape: (batch_size, self._num_added_end_tokens, embedding_size) + end_token_embeddings = batched_index_select(embeddings, end_token_indices) + + embeddings = embeddings.reshape(batch_size, num_segments, self._max_length, embedding_size) + embeddings = embeddings[ + :, :, self._num_added_start_tokens : embeddings.size(2) - self._num_added_end_tokens, : + ] # truncate segment-level start/end tokens + embeddings = embeddings.reshape(batch_size, -1, embedding_size) # flatten + + # Now try to put end token embeddings back which is a little tricky. + + # The number of segment each sequence spans, excluding padding. Mimicking ceiling operation. + # Shape: (batch_size,) + num_effective_segments = (seq_lengths + self._max_length - 1) // self._max_length + # The number of indices that end tokens should shift back. + num_removed_non_end_tokens = ( + num_effective_segments * self._num_added_tokens - self._num_added_end_tokens + ) + # Shape: (batch_size, self._num_added_end_tokens) + end_token_indices -= num_removed_non_end_tokens.unsqueeze(-1) + assert (end_token_indices >= self._num_added_start_tokens).all() + # Add space for end embeddings + embeddings = torch.cat([embeddings, torch.zeros_like(end_token_embeddings)], 1) + # Add end token embeddings back + embeddings.scatter_( + 1, end_token_indices.unsqueeze(-1).expand_as(end_token_embeddings), end_token_embeddings + ) + + # Now put back start tokens. We can do this before putting back end tokens, but then + # we need to change `num_removed_non_end_tokens` a little. 
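A minimal arithmetic sketch of the bookkeeping that `_fold_long_sequences` and `_unfold_long_sequences` perform, with made-up sizes and one start plus one end special token per segment (as in BERT-style models):

import math

num_segment_concat_wordpieces = 22   # length of "[CLS] A ... [SEP] [CLS] ... [SEP]"
max_length = 10
num_added_tokens = 1 + 1             # start + end special tokens per segment

num_segments = math.ceil(num_segment_concat_wordpieces / max_length)                     # 3
padded_length = num_segments * max_length                                                # 30, i.e. 8 [PAD]s
num_wordpieces = num_segment_concat_wordpieces - (num_segments - 1) * num_added_tokens   # 18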
+ embeddings = torch.cat([start_token_embeddings, embeddings], 1) + + # Truncate to original length + embeddings = embeddings[:, :num_wordpieces, :] + return embeddings \ No newline at end of file diff --git a/combo/modules/token_embedders/pretrained_transformer_mismatched_embedder.py b/combo/modules/token_embedders/pretrained_transformer_mismatched_embedder.py new file mode 100644 index 0000000000000000000000000000000000000000..d6f7a6f969b03de009c0abbd6fa36bc934114a13 --- /dev/null +++ b/combo/modules/token_embedders/pretrained_transformer_mismatched_embedder.py @@ -0,0 +1,183 @@ +""" +Adapted from AllenNLP +https://github.com/allenai/allennlp/blob/main/allennlp/modules/token_embedders/pretrained_transformer_mismatched_embedder.py +""" + +from typing import Optional, Dict, Any + + +import torch + +from combo.modules.token_embedders import PretrainedTransformerEmbedder, TokenEmbedder +from combo.nn import util + +from combo.config import Registry +from combo.utils import ConfigurationError + + +@Registry.register(TokenEmbedder, "pretrained_transformer_mismatched") +class PretrainedTransformerMismatchedEmbedder(TokenEmbedder): + """ + Use this embedder to embed wordpieces given by `PretrainedTransformerMismatchedIndexer` + and to get word-level representations. + + Registered as a `TokenEmbedder` with name "pretrained_transformer_mismatched". + + # Parameters + + model_name : `str` + The name of the `transformers` model to use. Should be the same as the corresponding + `PretrainedTransformerMismatchedIndexer`. + max_length : `int`, optional (default = `None`) + If positive, folds input token IDs into multiple segments of this length, pass them + through the transformer model independently, and concatenate the final representations. + Should be set to the same value as the `max_length` option on the + `PretrainedTransformerMismatchedIndexer`. + sub_module: `str`, optional (default = `None`) + The name of a submodule of the transformer to be used as the embedder. Some transformers naturally act + as embedders such as BERT. However, other models consist of encoder and decoder, in which case we just + want to use the encoder. + train_parameters: `bool`, optional (default = `True`) + If this is `True`, the transformer weights get updated during training. + last_layer_only: `bool`, optional (default = `True`) + When `True` (the default), only the final layer of the pretrained transformer is taken + for the embeddings. But if set to `False`, a scalar mix of all of the layers + is used. + override_weights_file: `Optional[str]`, optional (default = `None`) + If set, this specifies a file from which to load alternate weights that override the + weights from huggingface. The file is expected to contain a PyTorch `state_dict`, created + with `torch.save()`. + override_weights_strip_prefix: `Optional[str]`, optional (default = `None`) + If set, strip the given prefix from the state dict when loading it. + load_weights: `bool`, optional (default = `True`) + Whether to load the pretrained weights. If you're loading your model/predictor from an AllenNLP archive + it usually makes sense to set this to `False` (via the `overrides` parameter) + to avoid unnecessarily caching and loading the original pretrained weights, + since the archive will already contain all of the weights needed. + gradient_checkpointing: `bool`, optional (default = `None`) + Enable or disable gradient checkpointing. 
+ tokenizer_kwargs: `Dict[str, Any]`, optional (default = `None`) + Dictionary with + [additional arguments](https://github.com/huggingface/transformers/blob/155c782a2ccd103cf63ad48a2becd7c76a7d2115/transformers/tokenization_utils.py#L691) + for `AutoTokenizer.from_pretrained`. + transformer_kwargs: `Dict[str, Any]`, optional (default = `None`) + Dictionary with + [additional arguments](https://github.com/huggingface/transformers/blob/155c782a2ccd103cf63ad48a2becd7c76a7d2115/transformers/modeling_utils.py#L253) + for `AutoModel.from_pretrained`. + sub_token_mode: `Optional[str]`, optional (default= `avg`) + If `sub_token_mode` is set to `first`, return first sub-token representation as word-level representation + If `sub_token_mode` is set to `avg`, return average of all the sub-tokens representation as word-level representation + If `sub_token_mode` is not specified it defaults to `avg` + If invalid `sub_token_mode` is provided, throw `ConfigurationError` + + """ # noqa: E501 + + def __init__( + self, + model_name: str, + max_length: int = None, + sub_module: str = None, + train_parameters: bool = True, + last_layer_only: bool = True, + override_weights_file: Optional[str] = None, + override_weights_strip_prefix: Optional[str] = None, + load_weights: bool = True, + gradient_checkpointing: Optional[bool] = None, + tokenizer_kwargs: Optional[Dict[str, Any]] = None, + transformer_kwargs: Optional[Dict[str, Any]] = None, + sub_token_mode: Optional[str] = "avg", + ) -> None: + super().__init__() + # The matched version v.s. mismatched + self._matched_embedder = PretrainedTransformerEmbedder( + model_name, + max_length=max_length, + sub_module=sub_module, + train_parameters=train_parameters, + last_layer_only=last_layer_only, + override_weights_file=override_weights_file, + override_weights_strip_prefix=override_weights_strip_prefix, + load_weights=load_weights, + gradient_checkpointing=gradient_checkpointing, + tokenizer_kwargs=tokenizer_kwargs, + transformer_kwargs=transformer_kwargs, + ) + self.sub_token_mode = sub_token_mode + + def get_output_dim(self): + return self._matched_embedder.get_output_dim() + + def forward( + self, + token_ids: torch.LongTensor, + mask: torch.BoolTensor, + offsets: torch.LongTensor, + wordpiece_mask: torch.BoolTensor, + type_ids: Optional[torch.LongTensor] = None, + segment_concat_mask: Optional[torch.BoolTensor] = None, + ) -> torch.Tensor: # type: ignore + """ + # Parameters + + token_ids: `torch.LongTensor` + Shape: [batch_size, num_wordpieces] (for exception see `PretrainedTransformerEmbedder`). + mask: `torch.BoolTensor` + Shape: [batch_size, num_orig_tokens]. + offsets: `torch.LongTensor` + Shape: [batch_size, num_orig_tokens, 2]. + Maps indices for the original tokens, i.e. those given as input to the indexer, + to a span in token_ids. `token_ids[i][offsets[i][j][0]:offsets[i][j][1] + 1]` + corresponds to the original j-th token from the i-th batch. + wordpiece_mask: `torch.BoolTensor` + Shape: [batch_size, num_wordpieces]. + type_ids: `Optional[torch.LongTensor]` + Shape: [batch_size, num_wordpieces]. + segment_concat_mask: `Optional[torch.BoolTensor]` + See `PretrainedTransformerEmbedder`. + + # Returns + + `torch.Tensor` + Shape: [batch_size, num_orig_tokens, embedding_size]. + """ + # Shape: [batch_size, num_wordpieces, embedding_size]. 
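Before the implementation below, here is a standalone sketch of the pooling that `sub_token_mode` selects, with shapes reduced to a single sentence and made-up values:

import torch

# 1 sentence, 3 original tokens, at most 2 wordpieces per token, embedding size 4.
span_embeddings = torch.randn(1, 3, 2, 4)                    # (batch, orig_tokens, max_span, dim)
span_mask = torch.tensor([[[1, 1], [1, 0], [1, 0]]]).bool()  # which wordpieces are real
masked = span_embeddings * span_mask.unsqueeze(-1)           # zero out padding wordpieces

first = masked[:, :, 0, :]                                          # sub_token_mode == "first"
avg = masked.sum(2) / span_mask.sum(2, keepdim=True).clamp_min(1)   # sub_token_mode == "avg"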
+ embeddings = self._matched_embedder( + token_ids, wordpiece_mask, type_ids=type_ids, segment_concat_mask=segment_concat_mask + ) + + # span_embeddings: (batch_size, num_orig_tokens, max_span_length, embedding_size) + # span_mask: (batch_size, num_orig_tokens, max_span_length) + span_embeddings, span_mask = util.batched_span_select(embeddings.contiguous(), offsets) + + span_mask = span_mask.unsqueeze(-1) + + # Shape: (batch_size, num_orig_tokens, max_span_length, embedding_size) + span_embeddings *= span_mask # zero out paddings + + # If "sub_token_mode" is set to "first", return the first sub-token embedding + if self.sub_token_mode == "first": + # Select first sub-token embeddings from span embeddings + # Shape: (batch_size, num_orig_tokens, embedding_size) + orig_embeddings = span_embeddings[:, :, 0, :] + + # If "sub_token_mode" is set to "avg", return the average of embeddings of all sub-tokens of a word + elif self.sub_token_mode == "avg": + # Sum over embeddings of all sub-tokens of a word + # Shape: (batch_size, num_orig_tokens, embedding_size) + span_embeddings_sum = span_embeddings.sum(2) + + # Shape (batch_size, num_orig_tokens) + span_embeddings_len = span_mask.sum(2) + + # Find the average of sub-tokens embeddings by dividing `span_embedding_sum` by `span_embedding_len` + # Shape: (batch_size, num_orig_tokens, embedding_size) + orig_embeddings = span_embeddings_sum / torch.clamp_min(span_embeddings_len, 1) + + # All the places where the span length is zero, write in zeros. + orig_embeddings[(span_embeddings_len == 0).expand(orig_embeddings.shape)] = 0 + + # If invalid "sub_token_mode" is provided, throw error + else: + raise ConfigurationError(f"Do not recognise 'sub_token_mode' {self.sub_token_mode}") + + return orig_embeddings diff --git a/combo/modules/token_embedders/projected_words_embedder.py b/combo/modules/token_embedders/projected_words_embedder.py new file mode 100644 index 0000000000000000000000000000000000000000..cd5b5dc71f6a1c12fc0d0b779ab4d2b07d4c8621 --- /dev/null +++ b/combo/modules/token_embedders/projected_words_embedder.py @@ -0,0 +1,50 @@ +from typing import Optional + +import torch +from overrides import overrides + +from combo.config import Registry +from combo.data.vocabulary import Vocabulary +from combo.nn.base import Linear +from combo.modules.token_embedders.token_embedder import TokenEmbedder +from combo.modules.token_embedders.embedding import Embedding + + +@Registry.register(TokenEmbedder, "embeddings_projected") +class ProjectedWordEmbedder(Embedding): + """Word embeddings.""" + + def __init__(self, + embedding_dim: int, + num_embeddings: int = None, + weight: torch.FloatTensor = None, + padding_index: int = None, + trainable: bool = True, + max_norm: float = None, + norm_type: float = 2.0, + scale_grad_by_freq: bool = False, + sparse: bool = False, + vocab_namespace: str = "tokens", + pretrained_file: str = None, + vocabulary: Vocabulary = None, + projection_layer: Linear = None): # Change to Optional[Linear] + super().__init__( + embedding_dim=embedding_dim, + num_embeddings=num_embeddings, + weight=weight, + padding_index=padding_index, + trainable=trainable, + max_norm=max_norm, + norm_type=norm_type, + scale_grad_by_freq=scale_grad_by_freq, + sparse=sparse, + vocab_namespace=vocab_namespace, + pretrained_file=pretrained_file, + vocabulary=vocabulary + ) + self._projection = projection_layer + self.output_dim = embedding_dim if projection_layer is None else projection_layer.out_features + + @overrides + def get_output_dim(self) -> int: + 
return self.output_dim diff --git a/combo/models/embeddings.py b/combo/modules/token_embedders/token_embedder.py similarity index 62% rename from combo/models/embeddings.py rename to combo/modules/token_embedders/token_embedder.py index 35b732ae5c4497893a28e67ecff36cbb383b3d79..ed8b61981c3bc5e32714df96d9e3d55c08069e56 100644 --- a/combo/models/embeddings.py +++ b/combo/modules/token_embedders/token_embedder.py @@ -5,19 +5,19 @@ from overrides import overrides from torch import nn from torchtext.vocab import Vectors, GloVe, FastText, CharNGram +from combo.config import FromParameters, Registry from combo.data import Vocabulary -from combo.models.base import TimeDistributed -from combo.models.dilated_cnn import DilatedCnnEncoder from combo.models.utils import tiny_value_of_dtype +from combo.modules.module import Module +from combo.modules.time_distributed import TimeDistributed from combo.utils import ConfigurationError -class TokenEmbedder(nn.Module): +class TokenEmbedder(Module, FromParameters): def __init__(self): super(TokenEmbedder, self).__init__() - @property - def output_dim(self) -> int: + def get_output_dim(self) -> int: raise NotImplementedError() def forward(self, @@ -26,10 +26,11 @@ class TokenEmbedder(nn.Module): raise NotImplementedError() -class _TorchEmbedder(TokenEmbedder): +@Registry.register(TokenEmbedder, 'base') +class TorchEmbedder(TokenEmbedder): def __init__(self, - num_embeddings: int, - embedding_dim: int, + num_TokenEmbedders: int, + TokenEmbedder_dim: int, padding_idx: Optional[int] = None, max_norm: Optional[float] = None, norm_type: float = 2., @@ -40,28 +41,28 @@ class _TorchEmbedder(TokenEmbedder): weight: Optional[torch.Tensor] = None, trainable: bool = True, projection_dim: Optional[int] = None): - super(_TorchEmbedder, self).__init__() - self._embedding_dim = embedding_dim - self._embedding = nn.Embedding(num_embeddings=num_embeddings, - embedding_dim=embedding_dim, - padding_idx=padding_idx, - max_norm=max_norm, - norm_type=norm_type, - scale_grad_by_freq=scale_grad_by_freq, - sparse=sparse) + super(TorchEmbedder, self).__init__() + self._TokenEmbedder_dim = TokenEmbedder_dim + # self._TokenEmbedder = nn.TokenEmbedder(num_TokenEmbedders=num_TokenEmbedders, + # TokenEmbedder_dim=TokenEmbedder_dim, + # padding_idx=padding_idx, + # max_norm=max_norm, + # norm_type=norm_type, + # scale_grad_by_freq=scale_grad_by_freq, + # sparse=sparse) self.__vocab_namespace = vocab_namespace self.__vocab = vocab if weight is not None: - if weight.shape() != (num_embeddings, embedding_dim): + if weight.shape() != (num_TokenEmbedders, TokenEmbedder_dim): raise ConfigurationError( - "Weight matrix must be of shape (num_embeddings, embedding_dim)." + + "Weight matrix must be of shape (num_TokenEmbedders, TokenEmbedder_dim)." 
+ f"Got: ({weight.shape()})" ) self.__weight = torch.nn.Parameter(weight, requires_grad=trainable) else: - self.__weight = torch.nn.Parameter(torch.FloatTensor(num_embeddings, embedding_dim), + self.__weight = torch.nn.Parameter(torch.FloatTensor(num_TokenEmbedders, TokenEmbedder_dim), requires_grad=trainable) torch.nn.init.xavier_uniform_(self.__weight) @@ -69,21 +70,17 @@ class _TorchEmbedder(TokenEmbedder): self.__weight.data[padding_idx].fill_(0) if projection_dim: - self._projection = torch.nn.Linear(embedding_dim, projection_dim) - self._output_dim = projection_dim + self._projection = torch.nn.Linear(TokenEmbedder_dim, projection_dim) + self.output_dim = projection_dim else: self._projection = None - self._output_dim = embedding_dim - - @overrides - def output_dim(self) -> int: - return self._output_dim + self.output_dim = TokenEmbedder_dim @overrides def forward(self, x: torch.Tensor, char_mask: Optional[torch.BoolTensor] = None) -> torch.Tensor: - embedded = self._embedding(x) + embedded = self._TokenEmbedder(x) if self._projection: projection = self._projection for p in range(embedded.dim()-2): @@ -92,6 +89,7 @@ class _TorchEmbedder(TokenEmbedder): return embedded +@Registry.register(TokenEmbedder, 'torchtext_vectors') class _TorchtextVectorsEmbedder(TokenEmbedder): """ Torchtext Vectors object wrapper @@ -110,7 +108,7 @@ class _TorchtextVectorsEmbedder(TokenEmbedder): self.__lower_case_backup = lower_case_backup @overrides - def output_dim(self) -> int: + def get_output_dim(self) -> int: return len(self.__torchtext_embedder) @overrides @@ -120,75 +118,75 @@ class _TorchtextVectorsEmbedder(TokenEmbedder): return self.__torchtext_embedder.get_vecs_by_tokens(x, self.__lower_case_backup) +@Registry.register(TokenEmbedder, 'glove42b') class GloVe42BEmbedder(_TorchtextVectorsEmbedder): def __init__(self, dim: int = 300): super(GloVe42BEmbedder, self).__init__(GloVe("42B", dim)) +@Registry.register(TokenEmbedder, 'glove840b') class GloVe840BEmbedder(_TorchtextVectorsEmbedder): def __init__(self, dim: int = 300): super(GloVe840BEmbedder, self).__init__(GloVe("840B", dim)) +@Registry.register(TokenEmbedder, 'glove_twitter27b') class GloVeTwitter27BEmbedder(_TorchtextVectorsEmbedder): def __init__(self, dim: int = 300): super(GloVeTwitter27BEmbedder, self).__init__(GloVe("twitter.27B", dim)) +@Registry.register(TokenEmbedder, 'glove6b') class GloVe6BEmbedder(_TorchtextVectorsEmbedder): def __init__(self, dim: int = 300): super(GloVe6BEmbedder, self).__init__(GloVe("6B", dim)) +@Registry.register(TokenEmbedder, 'fast_text') class FastTextEmbedder(_TorchtextVectorsEmbedder): def __init__(self, language: str = "en"): super(FastTextEmbedder, self).__init__(FastText(language)) +@Registry.register(TokenEmbedder, 'char_ngram') class CharNGramEmbedder(_TorchtextVectorsEmbedder): def __init__(self): super(CharNGramEmbedder, self).__init__(CharNGram()) -class CharacterBasedWordEmbedder(TokenEmbedder): - def __init__(self, - num_embeddings: int, - embedding_dim: int, - dilated_cnn_encoder: DilatedCnnEncoder): - super(CharacterBasedWordEmbedder, self).__init__() - self.__embedding_dim = embedding_dim - self.__dilated_cnn_encoder = dilated_cnn_encoder - self.char_embed = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim) - - @overrides - def output_dim(self) -> int: - return self.__embedding_dim - - @overrides - def forward(self, - x: torch.Tensor, - char_mask: Optional[torch.BoolTensor] = None) -> torch.Tensor: - if char_mask is None: - char_mask = x.new_ones(x.size()) - - x = 
self.char_embed(x) - x = x * char_mask.unsqueeze(-1).float() - x = self.__dilated_cnn_encoder(x.transpose(2, 3)) - return torch.max(x, dim=-1)[0] - - -class PretrainedTransformerMismatchedEmbedder(TokenEmbedder): - pass - - -class TransformersWordEmbedder(PretrainedTransformerMismatchedEmbedder): - pass - - -class FeatsTokenEmbedder(_TorchEmbedder): +# @Registry.register(RegistryCategory.TokenEmbedder, 'character_based_word') +# class CharacterBasedWordEmbedder(TokenEmbedder): +# def __init__(self, +# num_TokenEmbedders: int, +# TokenEmbedder_dim: int, +# dilated_cnn_encoder: DilatedCnnEncoder): +# super(CharacterBasedWordEmbedder, self).__init__() +# self.__TokenEmbedder_dim = TokenEmbedder_dim +# self.__dilated_cnn_encoder = dilated_cnn_encoder +# self.char_embed = nn.TokenEmbedder(num_TokenEmbedders=num_TokenEmbedders, TokenEmbedder_dim=TokenEmbedder_dim) +# +# @overrides +# def output_dim(self) -> int: +# return self.__TokenEmbedder_dim +# +# @overrides +# def forward(self, +# x: torch.Tensor, +# char_mask: Optional[torch.BoolTensor] = None) -> torch.Tensor: +# if char_mask is None: +# char_mask = x.new_ones(x.size()) +# +# x = self.char_embed(x) +# x = x * char_mask.unsqueeze(-1).float() +# x = self.__dilated_cnn_encoder(x.transpose(2, 3)) +# return torch.max(x, dim=-1)[0] + + +@Registry.register(TokenEmbedder, 'feats_token') +class FeatsTokenEmbedder(TorchEmbedder): def __init__(self, - num_embeddings: int, - embedding_dim: int, + num_TokenEmbedders: int, + TokenEmbedder_dim: int, padding_idx: Optional[int] = None, max_norm: Optional[float] = None, norm_type: float = 2., @@ -198,8 +196,8 @@ class FeatsTokenEmbedder(_TorchEmbedder): vocab: Vocabulary = None, weight: Optional[torch.Tensor] = None, trainable: bool = True): - super(FeatsTokenEmbedder, self).__init__(num_embeddings, - embedding_dim, + super(FeatsTokenEmbedder, self).__init__(num_TokenEmbedders, + TokenEmbedder_dim, padding_idx, max_norm, norm_type, diff --git a/combo/modules/token_embedders/transformers_words_embeddings.py b/combo/modules/token_embedders/transformers_words_embeddings.py new file mode 100644 index 0000000000000000000000000000000000000000..c25de4d2d94196aa2960af5db38c4ed2c09f0207 --- /dev/null +++ b/combo/modules/token_embedders/transformers_words_embeddings.py @@ -0,0 +1,67 @@ +""" +Adapted from COMBO +Author: Mateusz Klimaszewski +""" +from overrides import overrides + +from combo.config import Registry +from combo.nn.activations import Activation +from combo.modules.token_embedders.token_embedder import TokenEmbedder +from combo.modules.token_embedders.pretrained_transformer_mismatched_embedder import PretrainedTransformerMismatchedEmbedder +import torch +from combo.nn import base +from typing import Any, Dict, Optional + + +@Registry.register(TokenEmbedder, "transformers_word_embeddings") +class TransformersWordEmbedder(PretrainedTransformerMismatchedEmbedder): + """ + Transformers word embeddings as last hidden state + optional projection layers. + + Tested with Bert (but should work for other models as well). 
+ """ + + authorized_missing_keys = [r"position_ids$"] + + def __init__(self, + model_name: str, + projection_dim: int = 0, + projection_activation: Optional[Activation] = lambda x: x, + projection_dropout_rate: Optional[float] = 0.0, + freeze_transformer: bool = True, + last_layer_only: bool = True, + tokenizer_kwargs: Optional[Dict[str, Any]] = None, + transformer_kwargs: Optional[Dict[str, Any]] = None): + super().__init__(model_name, + train_parameters=not freeze_transformer, + last_layer_only=last_layer_only, + tokenizer_kwargs=tokenizer_kwargs, + transformer_kwargs=transformer_kwargs) + if projection_dim: + self.projection_layer = base.Linear(in_features=super().get_output_dim(), + out_features=projection_dim, + dropout_rate=projection_dropout_rate, + activation=projection_activation) + self.output_dim = projection_dim + else: + self.projection_layer = None + self.output_dim = super().get_output_dim() + + #@overrides + def forward( + self, + token_ids: torch.LongTensor, + mask: torch.BoolTensor, + offsets: torch.LongTensor, + wordpiece_mask: torch.BoolTensor, + type_ids: Optional[torch.LongTensor] = None, + segment_concat_mask: Optional[torch.BoolTensor] = None, + ) -> torch.Tensor: # type: ignore + x = super().forward(token_ids, mask, offsets, wordpiece_mask, type_ids, segment_concat_mask) + if self.projection_layer: + x = self.projection_layer(x) + return x + + @overrides + def get_output_dim(self) -> int: + return self.output_dim diff --git a/combo/nn/__init__.py b/combo/nn/__init__.py index a01e009a13678c207b06a1e2957587bf063192f0..4216e25b112314c81009a2dd6296d0c694520bc1 100644 --- a/combo/nn/__init__.py +++ b/combo/nn/__init__.py @@ -1 +1,2 @@ -from .regularizers import * \ No newline at end of file +from .regularizers import * +from .activations import * diff --git a/combo/nn/activations.py b/combo/nn/activations.py new file mode 100644 index 0000000000000000000000000000000000000000..c04cdbda43fd28cf8ced527e5bacb126bf7c3d39 --- /dev/null +++ b/combo/nn/activations.py @@ -0,0 +1,100 @@ +import torch +import torch.nn as nn +from overrides import overrides + +from combo.config.registry import Registry + + +class Activation(nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + raise NotImplementedError + + +@Registry.register(Activation, 'linear') +class LinearActivation(Activation): + @overrides + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x + + +@Registry.register(Activation, 'relu') +class ReLUActivation(Activation): + def __init__(self): + super().__init__() + self.__torch_activation = torch.nn.ReLU() + @overrides + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.__torch_activation.forward(x) + + +@Registry.register(Activation, 'leaky_relu') +class ReLUActivation(Activation): + def __init__(self): + super().__init__() + self.__torch_activation = torch.nn.LeakyReLU() + @overrides + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.__torch_activation.forward(x) + + +@Registry.register(Activation, 'gelu') +class ReLUActivation(Activation): + def __init__(self): + super().__init__() + self.__torch_activation = torch.nn.GELU() + @overrides + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.__torch_activation.forward(x) + + +@Registry.register(Activation, 'sigmoid') +class SigmoidActivation(Activation): + def __init__(self): + super().__init__() + self.__torch_activation = torch.nn.Sigmoid() + @overrides + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.__torch_activation.forward(x) 
+ + +@Registry.register(Activation, 'tanh') +class TanhActivation(Activation): + def __init__(self): + super().__init__() + self.__torch_activation = torch.nn.Tanh() + @overrides + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.__torch_activation.forward(x) + + +@Registry.register(Activation, 'mish') +class MishActivation(Activation): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x * torch.tanh(torch.nn.functional.softplus(x)) + + +@Registry.register(Activation, 'swish') +class SwishActivation(Activation): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x * torch.sigmoid(x) + + +@Registry.register(Activation, 'gelu_new') +class GeluNew(Activation): + """ + Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also + see the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415 + """ + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return ( + 0.5 + * x + * (1.0 + torch.tanh(torch.math.sqrt(2.0 / torch.math.pi) * (x + 0.044715 * torch.pow(x, 3.0)))) + ) + + +@Registry.register(Activation, 'gelu_fast') +class GeluFast(Activation): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x))) diff --git a/combo/models/base.py b/combo/nn/base.py similarity index 73% rename from combo/models/base.py rename to combo/nn/base.py index 45eae041affc66f1bd34e3224b893e22285d0028..c752066f06fef3e274549a687799ece43337adff 100644 --- a/combo/models/base.py +++ b/combo/nn/base.py @@ -2,20 +2,23 @@ from typing import Dict, Optional, List, Union, Tuple import torch import torch.nn as nn -from overrides import overrides -from combo.models.combo_nn import Activation +from combo.config.registry import Registry +from combo.config.from_parameters import FromParameters +from combo.nn.activations import Activation import combo.utils.checks as checks from combo.data.vocabulary import Vocabulary from combo.models.utils import masked_cross_entropy +from combo.modules.module import Module from combo.predictors.predictor import Predictor -class Linear(nn.Linear): +@Registry.register_base_class("linear", default=True) +class Linear(nn.Linear, FromParameters): def __init__(self, in_features: int, out_features: int, - activation: Optional[Activation] = None, + activation: Activation = None, dropout_rate: Optional[float] = 0.0): super().__init__(in_features, out_features) self.activation = activation if activation else self.identity @@ -34,7 +37,7 @@ class Linear(nn.Linear): return x -class FeedForward(torch.nn.Module): +class FeedForward(Module): """ Modified copy of allennlp.modules.feedforward.FeedForward @@ -85,7 +88,7 @@ class FeedForward(torch.nn.Module): input_dim: int, num_layers: int, hidden_dims: Union[int, List[int]], - activations: Union[Activation, List[Activation]], + activations: List[Activation], # Change to Union[Activation, List[Activation]] dropout: Union[float, List[float]] = 0.0, ) -> None: @@ -187,7 +190,7 @@ class FeedForwardPredictor(Predictor): input_dim: int, num_layers: int, hidden_dims: List[int], - activations: Union[Activation, List[Activation]], + activations: List[Activation], # TODO: change to Union[Activation, List[Activation]] dropout: Union[float, List[float]] = 0.0, ): if len(hidden_dims) + 1 != num_layers: @@ -205,70 +208,3 @@ class FeedForwardPredictor(Predictor): hidden_dims=hidden_dims, activations=activations, dropout=dropout)) - - -""" -Adapted from AllenNLP -""" - - -class 
TimeDistributed(torch.nn.Module): - """ - Given an input shaped like `(batch_size, time_steps, [rest])` and a `Module` that takes - inputs like `(batch_size, [rest])`, `TimeDistributed` reshapes the input to be - `(batch_size * time_steps, [rest])`, applies the contained `Module`, then reshapes it back. - - Note that while the above gives shapes with `batch_size` first, this `Module` also works if - `batch_size` is second - we always just combine the first two dimensions, then split them. - - It also reshapes keyword arguments unless they are not tensors or their name is specified in - the optional `pass_through` iterable. - """ - - def __init__(self, module): - super().__init__() - self._module = module - - @overrides - def forward(self, *inputs, pass_through: List[str] = None, **kwargs): - - pass_through = pass_through or [] - - reshaped_inputs = [self._reshape_tensor(input_tensor) for input_tensor in inputs] - - # Need some input to then get the batch_size and time_steps. - some_input = None - if inputs: - some_input = inputs[-1] - - reshaped_kwargs = {} - for key, value in kwargs.items(): - if isinstance(value, torch.Tensor) and key not in pass_through: - if some_input is None: - some_input = value - - value = self._reshape_tensor(value) - - reshaped_kwargs[key] = value - - reshaped_outputs = self._module(*reshaped_inputs, **reshaped_kwargs) - - if some_input is None: - raise RuntimeError("No input tensor to time-distribute") - - # Now get the output back into the right shape. - # (batch_size, time_steps, **output_size) - new_size = some_input.size()[:2] + reshaped_outputs.size()[1:] - outputs = reshaped_outputs.contiguous().view(new_size) - - return outputs - - @staticmethod - def _reshape_tensor(input_tensor): - input_size = input_tensor.size() - if len(input_size) <= 2: - raise RuntimeError(f"No dimension to distribute: {input_size}") - # Squash batch_size and time_steps into a single axis; result has shape - # (batch_size * time_steps, **input_size). - squashed_shape = [-1] + list(input_size[2:]) - return input_tensor.contiguous().view(*squashed_shape) diff --git a/combo/nn/util.py b/combo/nn/util.py index 56c564da2da4f673ffda2ec09acbfa28619bd552..ae0cca77915972d134b9fdf2bedfbc1220d86684 100644 --- a/combo/nn/util.py +++ b/combo/nn/util.py @@ -2,7 +2,7 @@ Adapted from AllenNLP https://github.com/allenai/allennlp/blob/80fb6061e568cb9d6ab5d45b661e86eb61b92c82/allennlp/nn/util.py """ -from typing import Union, Dict, Optional, List, Any +from typing import Union, Dict, Optional, List, Any, NamedTuple import torch @@ -10,6 +10,113 @@ from combo.common.util import int_to_device from combo.utils import ConfigurationError +StateDictType = Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"] + + +def tiny_value_of_dtype(dtype: torch.dtype): + """ + Returns a moderately tiny value for a given PyTorch data type that is used to avoid numerical + issues such as division by zero. + This is different from `info_value_of_dtype(dtype).tiny` because it causes some NaN bugs. + Only supports floating point dtypes. + """ + if not dtype.is_floating_point: + raise TypeError("Only supports floating point dtypes.") + if dtype == torch.float or dtype == torch.double: + return 1e-13 + elif dtype == torch.half: + return 1e-4 + else: + raise TypeError("Does not support dtype " + str(dtype)) + + +def get_device_of(tensor: torch.Tensor) -> int: + """ + Returns the device of the tensor. 
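`tiny_value_of_dtype` exists to keep divisions and similar operations numerically safe; an illustrative use, normalizing a weight vector that may be all zeros:

import torch

weights = torch.zeros(3)
normalized = weights / (weights.sum() + tiny_value_of_dtype(weights.dtype))  # no division by zero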
+ """ + if not tensor.is_cuda: + return -1 + else: + return tensor.get_device() + +def combine_initial_dims(tensor: torch.Tensor) -> torch.Tensor: + """ + Given a (possibly higher order) tensor of ids with shape + (d1, ..., dn, sequence_length) + Return a view that's (d1 * ... * dn, sequence_length). + If original tensor is 1-d or 2-d, return it as is. + """ + if tensor.dim() <= 2: + return tensor + else: + return tensor.view(-1, tensor.size(-1)) + + +def uncombine_initial_dims(tensor: torch.Tensor, original_size: torch.Size) -> torch.Tensor: + """ + Given a tensor of embeddings with shape + (d1 * ... * dn, sequence_length, embedding_dim) + and the original shape + (d1, ..., dn, sequence_length), + return the reshaped tensor of embeddings with shape + (d1, ..., dn, sequence_length, embedding_dim). + If original size is 1-d or 2-d, return it as is. + """ + if len(original_size) <= 2: + return tensor + else: + view_args = list(original_size) + [tensor.size(-1)] + return tensor.view(*view_args) + + +def get_range_vector(size: int, device: int) -> torch.Tensor: + """ + Returns a range vector with the desired size, starting at 0. The CUDA implementation + is meant to avoid copy data from CPU to GPU. + """ + if device > -1: + return torch.cuda.LongTensor(size, device=device).fill_(1).cumsum(0) - 1 + else: + return torch.arange(0, size, dtype=torch.long) + + +def flatten_and_batch_shift_indices(indices: torch.Tensor, sequence_length: int) -> torch.Tensor: + if torch.max(indices) >= sequence_length or torch.min(indices) < 0: + raise ConfigurationError( + f"All elements in indices should be in range (0, {sequence_length - 1})" + ) + offsets = get_range_vector(indices.size(0), get_device_of(indices)) * sequence_length + for _ in range(len(indices.size()) - 1): + offsets = offsets.unsqueeze(1) + + # Shape: (batch_size, d_1, ..., d_n) + offset_indices = indices + offsets + + # Shape: (batch_size * d_1 * ... * d_n) + offset_indices = offset_indices.view(-1) + return offset_indices + + +def batched_index_select( + target: torch.Tensor, + indices: torch.LongTensor, + flattened_indices: Optional[torch.LongTensor] = None, +) -> torch.Tensor: + if flattened_indices is None: + # Shape: (batch_size * d_1 * ... * d_n) + flattened_indices = flatten_and_batch_shift_indices(indices, target.size(1)) + + # Shape: (batch_size * sequence_length, embedding_size) + flattened_target = target.view(-1, target.size(-1)) + + # Shape: (batch_size * d_1 * ... 
 def move_to_device(obj, device: Union[torch.device, int]):
     """
     Given a structure (possibly) containing Tensors,
@@ -257,3 +364,33 @@ def get_token_offsets_from_text_field_inputs(
             return embedder_arg_value
     return None
 
+def _check_incompatible_keys(
+    module, missing_keys: List[str], unexpected_keys: List[str], strict: bool
+):
+    error_msgs: List[str] = []
+    if missing_keys:
+        error_msgs.append(
+            "Missing key(s) in state_dict: {}".format(", ".join(f'"{k}"' for k in missing_keys))
+        )
+    if unexpected_keys:
+        error_msgs.append(
+            "Unexpected key(s) in state_dict: {}".format(
+                ", ".join(f'"{k}"' for k in unexpected_keys)
+            )
+        )
+    if error_msgs and strict:
+        raise RuntimeError(
+            "Error(s) in loading state_dict for {}:\n\t{}".format(
+                module.__class__.__name__, "\n\t".join(error_msgs)
+            )
+        )
+
+class _IncompatibleKeys(NamedTuple):
+    missing_keys: List[str]
+    unexpected_keys: List[str]
+
+    def __repr__(self):
+        if not self.missing_keys and not self.unexpected_keys:
+            return "<All keys matched successfully>"
+        return f"(missing_keys = {self.missing_keys}, unexpected_keys = {self.unexpected_keys})"
+
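Illustrative sketch (not part of the patch): `_check_incompatible_keys` is a module-private helper whose intended contract is to turn the result of a non-strict `load_state_dict` call into a hard error when strict checking is requested.

import torch

from combo.nn.util import _check_incompatible_keys

module = torch.nn.Linear(4, 2)
state_dict = {"weight": torch.zeros(2, 4)}  # "bias" is deliberately missing
missing, unexpected = module.load_state_dict(state_dict, strict=False)
# Raises RuntimeError listing the missing "bias" key because strict=True.
_check_incompatible_keys(module, missing, unexpected, strict=True)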
diff --git a/combo/predict.py b/combo/predict.py
index aa547e428004ff97ff893a7c91c81d8a7e3f6e6e..a9c06078e7596580d92cd9b55e44e5c1831837b8 100644
--- a/combo/predict.py
+++ b/combo/predict.py
@@ -8,8 +8,10 @@ from overrides import overrides
 from combo import data, models
 from combo.common import util
-from combo.data import sentence2conllu, tokens2conllu, conllu2sentence, tokenizers, Instance
+from combo.config import Registry
+from combo.data import tokenizers, Instance, conllu2sentence, tokens2conllu, sentence2conllu
 from combo.data.dataset_readers.dataset_reader import DatasetReader
+
 from combo.data.instance import JsonDict
 from combo.predictors.predictor import Predictor
 from combo.utils import download, graph
@@ -17,6 +19,7 @@ from combo.utils import download, graph
 logger = logging.getLogger(__name__)
 
 
+@Registry.register(Predictor, 'combo')
 class COMBO(Predictor):
 
     def __init__(self,
@@ -28,7 +31,6 @@ class COMBO(Predictor):
         super().__init__(model, dataset_reader)
         self.batch_size = batch_size
         self.vocab = model.vocab
-        self.dataset_reader = self._dataset_reader
         self.dataset_reader.generate_labels = False
         self.dataset_reader.lazy = True
         self._tokenizer = tokenizer
@@ -246,6 +248,10 @@ class COMBO(Predictor):
         archive = models.load_archive(model_path, cuda_device=cuda_device)
         model = archive.model
-        dataset_reader = DatasetReader.from_params(
-            archive.config["dataset_reader"])
+        dataset_reader_class = archive.config["dataset_reader"].get(
+            "type", "conllu")
+        dataset_reader = Registry.resolve(
+            DatasetReader,
+            dataset_reader_class
+        ).from_parameters(archive.config["dataset_reader"])
 
         return cls(model, dataset_reader, tokenizer, batch_size)
diff --git a/combo/predictors/predictor.py b/combo/predictors/predictor.py
index e4917a7c8466d4fda3214043f4109a1703d3e639..94876e40de05c6b9dbbcf80042f148eee3275308 100644
--- a/combo/predictors/predictor.py
+++ b/combo/predictors/predictor.py
@@ -3,12 +3,11 @@ Adapted from AllenNLP
 https://github.com/allenai/allennlp/blob/main/allennlp/predictors/predictor.py
 """
-from typing import List, Iterator, Dict, Tuple, Any, Type, Union
+from typing import List, Iterator, Dict, Tuple, Any
 import logging
 import json
 import re
 from contextlib import contextmanager
-from pathlib import Path
 
 import numpy
 import torch
@@ -17,16 +16,17 @@ from torch import Tensor
 from torch import backends
 
 from combo.common.util import sanitize
+from combo.config import FromParameters
 from combo.data.batch import Batch
 from combo.data.dataset_readers.dataset_reader import DatasetReader
 from combo.data.instance import JsonDict, Instance
-from combo.models.model import Model
+from combo.modules.model import Model
 from combo.nn import util
 
 logger = logging.getLogger(__name__)
 
 
-class Predictor:
+class Predictor(FromParameters):
     """
     a `Predictor` is a thin wrapper around an AllenNLP model that handles JSON -> JSON predictions
     that can be used for serving models through the web API or making predictions in bulk.
@@ -36,7 +36,7 @@ class Predictor:
         if frozen:
             model.eval()
         self._model = model
-        self._dataset_reader = dataset_reader
+        self.dataset_reader = dataset_reader
         self.cuda_device = next(self._model.named_parameters())[1].get_device()
         self._token_offsets: List[Tensor] = []
 
diff --git a/tests/config/__init__.py b/tests/config/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/tests/config/test_configuration.py b/tests/config/test_configuration.py
new file mode 100644
index 0000000000000000000000000000000000000000..ec69bb0147923a68ff906c937bed7a7d4dd8c430
--- /dev/null
+++ b/tests/config/test_configuration.py
@@ -0,0 +1,31 @@
+import unittest
+
+from combo.config import Registry
+from combo.data import WhitespaceTokenizer, DatasetReader, UniversalDependenciesDatasetReader
+from combo.data.token_indexers.token_characters_indexer import TokenCharactersIndexer
+
+
+class ConfigurationTest(unittest.TestCase):
+
+    def test_dataset_reader_from_registry(self):
+        dataset_reader = Registry.resolve(DatasetReader, 'conllu')()
+        self.assertEqual(type(dataset_reader), UniversalDependenciesDatasetReader)
+
+    def test_dataset_reader_from_registry_with_parameters(self):
+        parameters = {'token_indexers': {'char': {'type': 'token_characters'}},
+                      'use_sem': True}
+        dataset_reader = Registry.resolve(DatasetReader, 'conllu').from_parameters(parameters)
+        self.assertEqual(type(dataset_reader), UniversalDependenciesDatasetReader)
+        self.assertEqual(type(dataset_reader.token_indexers['char']), TokenCharactersIndexer)
+        self.assertEqual(dataset_reader.use_sem, True)
+
+    def test_dataset_reader_from_registry_with_token_indexer_parameters(self):
+        parameters = {'token_indexers': {'char': {'type': 'token_characters',
+                                                  'namespace': 'custom_namespace',
+                                                  'tokenizer': {
+                                                      'type': 'whitespace'
+                                                  }}}}
+        dataset_reader = Registry.resolve(DatasetReader, 'conllu').from_parameters(parameters)
+        self.assertEqual(type(dataset_reader), UniversalDependenciesDatasetReader)
+        self.assertEqual(dataset_reader.token_indexers['char']._namespace, 'custom_namespace')
+        self.assertEqual(type(dataset_reader.token_indexers['char']._character_tokenizer), WhitespaceTokenizer)
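For completeness, a hypothetical sketch (the reader class and the 'my_conllu' key are invented for illustration) of how a new component could hook into the same `Registry` mechanism exercised by the tests above, assuming `Registry.register` works for `DatasetReader` subclasses the same way it does for `Predictor` in combo/predict.py:

from combo.config import Registry
from combo.data import DatasetReader, UniversalDependenciesDatasetReader


@Registry.register(DatasetReader, 'my_conllu')  # hypothetical registry key
class MyConlluReader(UniversalDependenciesDatasetReader):
    pass


# Resolution mirrors the pattern used in ConfigurationTest above.
reader = Registry.resolve(DatasetReader, 'my_conllu').from_parameters({'use_sem': True})
assert isinstance(reader, MyConlluReader)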