diff --git a/.pylintrc b/.pylintrc new file mode 100644 index 0000000000000000000000000000000000000000..88d1c8dbb6d083104d27efd4a7833759b6bd52e0 --- /dev/null +++ b/.pylintrc @@ -0,0 +1,588 @@ +[MASTER] + +# A comma-separated list of package or module names from where C extensions may +# be loaded. Extensions are loading into the active Python interpreter and may +# run arbitrary code. +extension-pkg-whitelist= + +# Add files or directories to the blacklist. They should be base names, not +# paths. +ignore=CVS,venv + +# Add files or directories matching the regex patterns to the blacklist. The +# regex matches against base names, not paths. +ignore-patterns=/data + +# Python code to execute, usually for sys.path manipulation such as +# pygtk.require(). +#init-hook= + +# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the +# number of processors available to use. +jobs=1 + +# Control the amount of potential inferred values when inferring a single +# object. This can help the performance when dealing with large functions or +# complex, nested conditions. +limit-inference-results=100 + +# List of plugins (as comma separated values of python module names) to load, +# usually to register additional checkers. +load-plugins=pylint_quotes + +# Pickle collected data for later comparisons. +persistent=yes + +# Specify a configuration file. +#rcfile= + +# When enabled, pylint would attempt to guess common misconfiguration and emit +# user-friendly hints instead of false-positive error messages. +suggestion-mode=yes + +# Allow loading of arbitrary C extensions. Extensions are imported into the +# active Python interpreter and may run arbitrary code. +unsafe-load-any-extension=no + + +[MESSAGES CONTROL] + +# Only show warnings with the listed confidence levels. Leave empty to show +# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED. +confidence= + +# Disable the message, report, category or checker with the given id(s). 
You +# can either give multiple identifiers separated by comma (,) or put this +# option multiple times (only on the command line, not in the configuration +# file where it should appear only once). You can also use "--disable=all" to +# disable everything first and then reenable specific checks. For example, if +# you want to run only the similarities checker, you can use "--disable=all +# --enable=similarities". If you want to run only the classes checker, but have +# no Warning level messages displayed, use "--disable=all --enable=classes +# --disable=W". +disable=print-statement, + parameter-unpacking, + unpacking-in-except, + old-raise-syntax, + backtick, + long-suffix, + old-ne-operator, + old-octal-literal, + import-star-module-level, + non-ascii-bytes-literal, + raw-checker-failed, + bad-inline-option, + locally-disabled, + file-ignored, + suppressed-message, + useless-suppression, + deprecated-pragma, + use-symbolic-message-instead, + apply-builtin, + basestring-builtin, + buffer-builtin, + cmp-builtin, + coerce-builtin, + execfile-builtin, + file-builtin, + long-builtin, + raw_input-builtin, + reduce-builtin, + standarderror-builtin, + unicode-builtin, + xrange-builtin, + coerce-method, + delslice-method, + getslice-method, + setslice-method, + no-absolute-import, + old-division, + dict-iter-method, + dict-view-method, + next-method-called, + metaclass-assignment, + indexing-exception, + raising-string, + reload-builtin, + oct-method, + hex-method, + nonzero-method, + cmp-method, + input-builtin, + round-builtin, + intern-builtin, + unichr-builtin, + map-builtin-not-iterating, + zip-builtin-not-iterating, + range-builtin-not-iterating, + filter-builtin-not-iterating, + using-cmp-argument, + eq-without-hash, + div-method, + idiv-method, + rdiv-method, + exception-message-attribute, + invalid-str-codec, + sys-max-int, + bad-python3-import, + deprecated-string-function, + deprecated-str-translate-call, + deprecated-itertools-function, + 
deprecated-types-field, + next-method-defined, + dict-items-not-iterating, + dict-keys-not-iterating, + dict-values-not-iterating, + deprecated-operator-function, + deprecated-urllib-function, + xreadlines-attribute, + deprecated-sys-function, + exception-escape, + comprehension-escape, + arguments-differ, + +# Enable the message, report, category or checker with the given id(s). You can +# either give multiple identifier separated by comma (,) or put this option +# multiple time (only on the command line, not in the configuration file where +# it should appear only once). See also the "--disable" option for examples. +enable=c-extension-no-member + + +[REPORTS] + +# Python expression which should return a score less than or equal to 10. You +# have access to the variables 'error', 'warning', 'refactor', and 'convention' +# which contain the number of messages in each category, as well as 'statement' +# which is the total number of statements analyzed. This score is used by the +# global evaluation report (RP0004). +evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10) + +# Template used to display messages. This is a python new-style format string +# used to format the message information. See doc for all details. +#msg-template= + +# Set the output format. Available formats are text, parseable, colorized, json +# and msvs (visual studio). You can also give a reporter class, e.g. +# mypackage.mymodule.MyReporterClass. +output-format=text + +# Tells whether to display a full report or only the messages. +reports=no + +# Activate the evaluation score. +score=yes + + +[REFACTORING] + +# Maximum number of nested blocks for function / method body +max-nested-blocks=5 + +# Complete name of functions that never returns. When checking for +# inconsistent-return-statements if a never returning function is called then +# it will be considered as an explicit return statement and no message will be +# printed. 
+never-returning-functions=sys.exit + + +[BASIC] + +# Naming style matching correct argument names. +argument-naming-style=snake_case + +# Regular expression matching correct argument names. Overrides argument- +# naming-style. +#argument-rgx= + +# Naming style matching correct attribute names. +attr-naming-style=snake_case + +# Regular expression matching correct attribute names. Overrides attr-naming- +# style. +#attr-rgx= + +# Bad variable names which should always be refused, separated by a comma. +bad-names=foo, + bar, + baz, + toto, + tutu, + tata + +# Naming style matching correct class attribute names. +class-attribute-naming-style=any + +# Regular expression matching correct class attribute names. Overrides class- +# attribute-naming-style. +#class-attribute-rgx= + +# Naming style matching correct class names. +class-naming-style=PascalCase + +# Regular expression matching correct class names. Overrides class-naming- +# style. +#class-rgx= + +# Naming style matching correct constant names. +const-naming-style=UPPER_CASE + +# Regular expression matching correct constant names. Overrides const-naming- +# style. +#const-rgx= + +# Minimum line length for functions/classes that require docstrings, shorter +# ones are exempt. +docstring-min-length=-1 + +# Naming style matching correct function names. +function-naming-style=snake_case + +# Regular expression matching correct function names. Overrides function- +# naming-style. +#function-rgx= + +# Good variable names which should always be accepted, separated by a comma. +good-names=i, + j, + k, + ex, + Run, + x, + _ + +# Include a hint for the correct naming format with invalid-name. +include-naming-hint=no + +# Naming style matching correct inline iteration names. +inlinevar-naming-style=any + +# Regular expression matching correct inline iteration names. Overrides +# inlinevar-naming-style. +#inlinevar-rgx= + +# Naming style matching correct method names. 
+method-naming-style=snake_case + +# Regular expression matching correct method names. Overrides method-naming- +# style. +#method-rgx= + +# Naming style matching correct module names. +module-naming-style=snake_case + +# Regular expression matching correct module names. Overrides module-naming- +# style. +#module-rgx= + +# Colon-delimited sets of names that determine each other's naming style when +# the name regexes allow several styles. +name-group= + +# Regular expression which should only match function or class names that do +# not require a docstring. +no-docstring-rgx=^_|test.*|^.*Test + +# List of decorators that produce properties, such as abc.abstractproperty. Add +# to this list to register other decorators that produce valid properties. +# These decorators are taken in consideration only for invalid-name. +property-classes=abc.abstractproperty + +# Naming style matching correct variable names. +variable-naming-style=snake_case + +# Regular expression matching correct variable names. Overrides variable- +# naming-style. +#variable-rgx= + + +[SPELLING] + +# Limits count of emitted suggestions for spelling mistakes. +max-spelling-suggestions=4 + +# Spelling dictionary name. Available dictionaries: none. To make it work, +# install the python-enchant package. +spelling-dict= + +# List of comma separated words that should not be checked. +spelling-ignore-words= + +# A path to a file that contains the private dictionary; one word per line. +spelling-private-dict-file= + +# Tells whether to store unknown words to the private dictionary (see the +# --spelling-private-dict-file option) instead of raising a message. +spelling-store-unknown-words=no + + +[FORMAT] + +# Expected format of line ending, e.g. empty (any line ending), LF or CRLF. +expected-line-ending-format= + +# Regexp for a line that is allowed to be longer than the limit. +ignore-long-lines=^\s*(# )?<?https?://\S+>?$ + +# Number of spaces of indent required inside a hanging or continued line. 
+indent-after-paren=4 + +# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1 +# tab). +indent-string=' ' + +# Maximum number of characters on a single line. +max-line-length=120 + +# Maximum number of lines in a module. +max-module-lines=1000 + +# List of optional constructs for which whitespace checking is disabled. `dict- +# separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}. +# `trailing-comma` allows a space between comma and closing bracket: (a, ). +# `empty-line` allows space-only lines. +no-space-check=trailing-comma, + dict-separator + +# Allow the body of a class to be on the same line as the declaration if body +# contains single statement. +single-line-class-stmt=no + +# Allow the body of an if to be on the same line as the test if there is no +# else. +single-line-if-stmt=no + + +[MISCELLANEOUS] + +# List of note tags to take in consideration, separated by a comma. +notes=FIXME, + XXX, + + +[LOGGING] + +# Format style used to check logging format string. `old` means using % +# formatting, `new` is for `{}` formatting,and `fstr` is for f-strings. +logging-format-style=old + +# Logging modules to check that the string format arguments are in logging +# function parameter format. +logging-modules=logging + + +[STRING] + +# This flag controls whether the implicit-str-concat-in-sequence should +# generate a warning on implicit string concatenation in sequences defined over +# several lines. +check-str-concat-over-line-jumps=no + + +[STRING_QUOTES] + +string-quote=single +triple-quote=double +docstring-quote=double + +[TYPECHECK] + +# List of decorators that produce context managers, such as +# contextlib.contextmanager. Add to this list to register other decorators that +# produce valid context managers. +contextmanager-decorators=contextlib.contextmanager + +# List of members which are set dynamically and missed by pylint inference +# system, and so shouldn't trigger E1101 when accessed. 
Python regular +# expressions are accepted. +generated-members=numpy.*, torch.* + +# Tells whether missing members accessed in mixin class should be ignored. A +# mixin class is detected if its name ends with "mixin" (case insensitive). +ignore-mixin-members=yes + +# Tells whether to warn about missing members when the owner of the attribute +# is inferred to be None. +ignore-none=yes + +# This flag controls whether pylint should warn about no-member and similar +# checks whenever an opaque object is returned when inferring. The inference +# can return multiple potential results while evaluating a Python object, but +# some branches might not be evaluated, which results in partial inference. In +# that case, it might be useful to still emit no-member and other checks for +# the rest of the inferred objects. +ignore-on-opaque-inference=yes + +# List of class names for which member attributes should not be checked (useful +# for classes with dynamically set attributes). This supports the use of +# qualified names. +ignored-classes=optparse.Values,thread._local,_thread._local + +# List of module names for which member attributes should not be checked +# (useful for modules/projects where namespaces are manipulated during runtime +# and thus existing member attributes cannot be deduced by static analysis). It +# supports qualified module names, as well as Unix pattern matching. +ignored-modules= + +# Show a hint with possible names when a member name was not found. The aspect +# of finding the hint is based on edit distance. +missing-member-hint=yes + +# The minimum edit distance a name should have in order to be considered a +# similar match for a missing member name. +missing-member-hint-distance=1 + +# The total number of similar names that should be taken in consideration when +# showing a hint for a missing member. +missing-member-max-choices=1 + +# List of decorators that change the signature of a decorated function. 
+signature-mutators= + + +[VARIABLES] + +# List of additional names supposed to be defined in builtins. Remember that +# you should avoid defining new builtins when possible. +additional-builtins= + +# Tells whether unused global variables should be treated as a violation. +allow-global-unused-variables=yes + +# List of strings which can identify a callback function by name. A callback +# name must start or end with one of those strings. +callbacks=cb_, + _cb + +# A regular expression matching the name of dummy variables (i.e. expected to +# not be used). +dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_ + +# Argument names that match this expression will be ignored. Default to name +# with leading underscore. +ignored-argument-names=_.*|^ignored_|^unused_ + +# Tells whether we should check for unused import in __init__ files. +init-import=no + +# List of qualified module names which can have objects that can redefine +# builtins. +redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io + + +[SIMILARITIES] + +# Ignore comments when computing similarities. +ignore-comments=yes + +# Ignore docstrings when computing similarities. +ignore-docstrings=yes + +# Ignore imports when computing similarities. +ignore-imports=no + +# Minimum lines number of a similarity. +min-similarity-lines=4 + + +[DESIGN] + +# Maximum number of arguments for function / method. +max-args=7 + +# Maximum number of attributes for a class (see R0902). +max-attributes=7 + +# Maximum number of boolean expressions in an if statement (see R0916). +max-bool-expr=5 + +# Maximum number of branch for function / method body. +max-branches=12 + +# Maximum number of locals for function / method body. +max-locals=15 + +# Maximum number of parents for a class (see R0901). +max-parents=7 + +# Maximum number of public methods for a class (see R0904). +max-public-methods=20 + +# Maximum number of return / yield for function / method body. 
+max-returns=6 + +# Maximum number of statements in function / method body. +max-statements=50 + +# Minimum number of public methods for a class (see R0903). +min-public-methods=2 + + +[IMPORTS] + +# List of modules that can be imported at any level, not just the top level +# one. +allow-any-import-level= + +# Allow wildcard imports from modules that define __all__. +allow-wildcard-with-all=no + +# Analyse import fallback blocks. This can be used to support both Python 2 and +# 3 compatible code, which means that the block might have code that exists +# only in one or another interpreter, leading to false positives when analysed. +analyse-fallback-blocks=no + +# Deprecated modules which should not be used, separated by a comma. +deprecated-modules=optparse,tkinter.tix + +# Create a graph of external dependencies in the given file (report RP0402 must +# not be disabled). +ext-import-graph= + +# Create a graph of every (i.e. internal and external) dependencies in the +# given file (report RP0402 must not be disabled). +import-graph= + +# Create a graph of internal dependencies in the given file (report RP0402 must +# not be disabled). +int-import-graph= + +# Force import order to recognize a module as part of the standard +# compatibility libraries. +known-standard-library= + +# Force import order to recognize a module as part of a third party library. +known-third-party=enchant + +# Couples of modules and preferred modules, separated by a comma. +preferred-modules= + + +[CLASSES] + +# List of method names used to declare (i.e. assign) instance attributes. +defining-attr-methods=__init__, + __new__, + setUp, + __post_init__ + +# List of member names, which should be excluded from the protected access +# warning. +exclude-protected=_asdict, + _fields, + _replace, + _source, + _make + +# List of valid names for the first argument in a class method. +valid-classmethod-first-arg=cls + +# List of valid names for the first argument in a metaclass class method. 
+valid-metaclass-classmethod-first-arg=cls + + +[EXCEPTIONS] + +# Exceptions that will emit a warning when being caught. Defaults to +# "BaseException, Exception". +overgeneral-exceptions=BaseException, + Exception diff --git a/Makefile b/Makefile new file mode 100644 index 0000000000000000000000000000000000000000..df0c4836a738746e334ca86483674a5049cb7a5d --- /dev/null +++ b/Makefile @@ -0,0 +1,14 @@ +clean: + rm -rf COMBO.egg-info + rm -rf .eggs + rm -rf .pytest_cache + +develop: + python setup.py develop + +install: + python setup.py install + +test: + python setup.py test + pylint --rcfile=.pylintrc tests combo \ No newline at end of file diff --git a/README.md b/README.md index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0e4c660aba94e8403d5c01ac0363db82892b8476 100644 --- a/README.md +++ b/README.md @@ -0,0 +1,83 @@ +## Installation + +Clone this repository and run: +```bash +python setup.py develop +``` + +## Training + +Command: +```bash +combo --mode train \ + --training_data_path your_training_path \ + --validation_data_path your_validation_path +``` + +Options: +```bash +combo --helpfull +``` + +Examples (for clarity without training/validation data paths): + +* train on gpu 0 + ```bash + combo --mode train --cuda_davice 0 + ``` +* use pretrained embeddings: + ```bash + combo --mode train --pretrained_tokens your_pretrained_embeddings_path --embedding_dim your_embeddings_dim + ``` +* use pretrained transformer embeddings: + ```bash + combo --mode train --pretrained_transformer_name your_choosen_pretrained_transformer + ``` +* predict only dependency tree: + ```bash + combo --mode train --targets head --targets deprel + ``` +* use part-of-speech tags for predicting only dependency tree + ```bash + combo --mode train --targets head --targets deprel --features token --features char --features upostag + ``` +Advanced configuration: [Configuration](#configuration) + +## Prediction + +### ConLLU file prediction: +Input and output are both in `*.conllu` format. 
+```bash +combo --mode predict --model_path your_model_tar_gz --input_file your_conllu_file --output_file your_output_file --silent +``` + +### Console +Works for models where input was text-based only. + +Interactive testing in console (load model and just type sentence in console). + +```bash +combo --mode predict --model_path your_model_tar_gz --input_file "-" +``` +### Raw text +Works for models where input was text-based only. + +Input: one sentence per line. + +Output: List of token jsons. + +```bash +combo --mode predict --model_path your_model_tar_gz --input_file your_text_file --output_file your_output_file --silent +``` +#### Advanced + +There are 2 tokenizers: whitespace and spacy-based (`en_core_web_sm` model). + +Use either `--predictor_name semantic-multitask-predictor` or `--predictor_name semantic-multitask-predictor-spacy`. + +## Configuration + +### Advanced +Config template [config.template.jsonnet](config.template.jsonnet) is formed in `allennlp` format so you can freely modify it. +There is configuration for all the training/model parameters (learning rates, epochs number etc.). +Some of them use `jsonnet` syntax to get values from configuration flags, however most of them can be modified directly there. 
diff --git a/combo/__init__.py b/combo/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/combo/commands/__init__.py b/combo/commands/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..32071442cc22f05bede1d4533fba4480a7db8563 --- /dev/null +++ b/combo/commands/__init__.py @@ -0,0 +1 @@ +from .train import FinetuningTrainModel diff --git a/combo/commands/train.py b/combo/commands/train.py new file mode 100644 index 0000000000000000000000000000000000000000..44c291596c3dfba29811a35ec316dbaeee5910c0 --- /dev/null +++ b/combo/commands/train.py @@ -0,0 +1,171 @@ +import os +from typing import List + +from allennlp import data, models, common, training +from allennlp.commands import train +from allennlp.common import checks +from allennlp.common import util as common_util +from allennlp.training import util as training_util + + +@train.TrainModel.register("finetuning", constructor="from_partial_objects_finetuning") +class FinetuningTrainModel(train.TrainModel): + """Class made only for finetuning, the only difference is saving vocab from concatenated + (archive and current) datasets.""" + + @classmethod + def from_partial_objects_finetuning( + cls, + serialization_dir: str, + local_rank: int, + batch_weight_key: str, + dataset_reader: data.DatasetReader, + train_data_path: str, + model: common.Lazy[models.Model], + data_loader: common.Lazy[data.DataLoader], + trainer: common.Lazy[training.Trainer], + vocabulary: common.Lazy[data.Vocabulary] = None, + datasets_for_vocab_creation: List[str] = None, + validation_dataset_reader: data.DatasetReader = None, + validation_data_path: str = None, + validation_data_loader: common.Lazy[data.DataLoader] = None, + test_data_path: str = None, + evaluate_on_test: bool = False, + ) -> "train.TrainModel": + """ + This method is intended for use with our `FromParams` logic, to construct a `TrainModel` + object from a config file 
passed to the `allennlp train` command. The arguments to this + method are the allowed top-level keys in a configuration file (except for the first three, + which are obtained separately). + + You *could* use this outside of our `FromParams` logic if you really want to, but there + might be easier ways to accomplish your goal than instantiating `Lazy` objects. If you are + writing your own training loop, we recommend that you look at the implementation of this + method for inspiration and possibly some utility functions you can call, but you very likely + should not use this method directly. + + The `Lazy` type annotations here are a mechanism for building dependencies to an object + sequentially - the `TrainModel` object needs data, a model, and a trainer, but the model + needs to see the data before it's constructed (to create a vocabulary) and the trainer needs + the data and the model before it's constructed. Objects that have sequential dependencies + like this are labeled as `Lazy` in their type annotations, and we pass the missing + dependencies when we call their `construct()` method, which you can see in the code below. + + # Parameters + serialization_dir: `str` + The directory where logs and model archives will be saved. + local_rank: `int` + The process index that is initialized using the GPU device id. + batch_weight_key: `str` + The name of metric used to weight the loss on a per-batch basis. + dataset_reader: `DatasetReader` + The `DatasetReader` that will be used for training and (by default) for validation. + train_data_path: `str` + The file (or directory) that will be passed to `dataset_reader.read()` to construct the + training data. + model: `Lazy[Model]` + The model that we will train. This is lazy because it depends on the `Vocabulary`; + after constructing the vocabulary we call `model.construct(vocab=vocabulary)`. 
+ data_loader: `Lazy[DataLoader]` + The data_loader we use to batch instances from the dataset reader at training and (by + default) validation time. This is lazy because it takes a dataset in it's constructor. + trainer: `Lazy[Trainer]` + The `Trainer` that actually implements the training loop. This is a lazy object because + it depends on the model that's going to be trained. + vocabulary: `Lazy[Vocabulary]`, optional (default=None) + The `Vocabulary` that we will use to convert strings in the data to integer ids (and + possibly set sizes of embedding matrices in the `Model`). By default we construct the + vocabulary from the instances that we read. + datasets_for_vocab_creation: `List[str]`, optional (default=None) + If you pass in more than one dataset but don't want to use all of them to construct a + vocabulary, you can pass in this key to limit it. Valid entries in the list are + "train", "validation" and "test". + validation_dataset_reader: `DatasetReader`, optional (default=None) + If given, we will use this dataset reader for the validation data instead of + `dataset_reader`. + validation_data_path: `str`, optional (default=None) + If given, we will use this data for computing validation metrics and early stopping. + validation_data_loader: `Lazy[DataLoader]`, optional (default=None) + If given, the data_loader we use to batch instances from the dataset reader at + validation and test time. This is lazy because it takes a dataset in it's constructor. + test_data_path: `str`, optional (default=None) + If given, we will use this as test data. This makes it available for vocab creation by + default, but nothing else. + evaluate_on_test: `bool`, optional (default=False) + If given, we will evaluate the final model on this data at the end of training. Note + that we do not recommend using this for actual test data in every-day experimentation; + you should only very rarely evaluate your model on actual test data. 
+ """ + + datasets = training_util.read_all_datasets( + train_data_path=train_data_path, + dataset_reader=dataset_reader, + validation_dataset_reader=validation_dataset_reader, + validation_data_path=validation_data_path, + test_data_path=test_data_path, + ) + + if datasets_for_vocab_creation: + for key in datasets_for_vocab_creation: + if key not in datasets: + raise checks.ConfigurationError(f"invalid 'dataset_for_vocab_creation' {key}") + + instance_generator = ( + instance + for key, dataset in datasets.items() + if not datasets_for_vocab_creation or key in datasets_for_vocab_creation + for instance in dataset + ) + + vocabulary_ = vocabulary.construct(instances=instance_generator) + if not vocabulary_: + vocabulary_ = data.Vocabulary.from_instances(instance_generator) + model_ = model.construct(vocab=vocabulary_) + + # Initializing the model can have side effect of expanding the vocabulary. + # Save the vocab only in the master. In the degenerate non-distributed + # case, we're trivially the master. + if common_util.is_master(): + vocabulary_path = os.path.join(serialization_dir, "vocabulary") + # Only difference compared to TrainModel! + model_.vocab.save_to_files(vocabulary_path) + + for dataset in datasets.values(): + dataset.index_with(model_.vocab) + + data_loader_ = data_loader.construct(dataset=datasets["train"]) + validation_data = datasets.get("validation") + if validation_data is not None: + # Because of the way Lazy[T] works, we can't check it's existence + # _before_ we've tried to construct it. It returns None if it is not + # present, so we try to construct it first, and then afterward back off + # to the data_loader configuration used for training if it returns None. 
+ validation_data_loader_ = validation_data_loader.construct(dataset=validation_data) + if validation_data_loader_ is None: + validation_data_loader_ = data_loader.construct(dataset=validation_data) + else: + validation_data_loader_ = None + + test_data = datasets.get("test") + if test_data is not None: + test_data_loader = validation_data_loader.construct(dataset=test_data) + if test_data_loader is None: + test_data_loader = data_loader.construct(dataset=test_data) + else: + test_data_loader = None + + # We don't need to pass serialization_dir and local_rank here, because they will have been + # passed through the trainer by from_params already, because they were keyword arguments to + # construct this class in the first place. + trainer_ = trainer.construct( + model=model_, data_loader=data_loader_, validation_data_loader=validation_data_loader_, + ) + + return cls( + serialization_dir=serialization_dir, + model=model_, + trainer=trainer_, + evaluation_data_loader=test_data_loader, + evaluate_on_test=evaluate_on_test, + batch_weight_key=batch_weight_key, + ) diff --git a/combo/data/__init__.py b/combo/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7abffa2b3206b29c93e14f86ae76101a6a859cf3 --- /dev/null +++ b/combo/data/__init__.py @@ -0,0 +1,2 @@ +from .samplers import TokenCountBatchSampler +from .token_indexers import TokenCharactersIndexer diff --git a/combo/data/dataset.py b/combo/data/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..25a4f07a074bca03c452fbf3b3ce682e678bcd83 --- /dev/null +++ b/combo/data/dataset.py @@ -0,0 +1,227 @@ +import logging +from typing import Union, List, Dict, Iterable, Optional + +import conllu +from allennlp import data as allen_data +from allennlp.common import checks +from allennlp.data import fields as allen_fields, vocabulary +from conllu import parser +from overrides import overrides + +from combo.data import fields + +logger = logging.getLogger(__name__) + 
+ +@allen_data.DatasetReader.register('conllu') +class UniversalDependenciesDatasetReader(allen_data.DatasetReader): + + def __init__( + self, + token_indexers: Dict[str, allen_data.TokenIndexer] = None, + lemma_indexers: Dict[str, allen_data.TokenIndexer] = None, + features: List[str] = None, + targets: List[str] = None, + use_sem: bool = False, + **kwargs, + ) -> None: + super().__init__(**kwargs) + if features is None: + features = ['token', 'char'] + if targets is None: + targets = ['head', 'deprel', 'upostag', 'xpostag', 'lemma', 'feats'] + + if 'token' not in features and 'char' not in features: + raise checks.ConfigurationError("There must be at least one ('char' or 'token') text-based feature!") + + intersection = set(features).intersection(set(targets)) + if len(intersection) != 0: + raise checks.ConfigurationError( + "Features and targets cannot share elements! " + "Remove {} from either features or targets.".format(intersection) + ) + self._use_sem = use_sem + + # *.conllu readers configuration + fields = list(parser.DEFAULT_FIELDS) + fields[1] = 'token' # use 'token' instead of 'form' + field_parsers = parser.DEFAULT_FIELD_PARSERS + if self._use_sem: + fields = list(fields) + fields.append('semrel') + field_parsers['semrel'] = lambda line, i: parser.parse_nullable_value(line[i]), + self.field_parsers = field_parsers + self.fields = tuple(fields) + + self._token_indexers = token_indexers + self._lemma_indexers = lemma_indexers + self._targets = targets + self._features = features + self.generate_labels = True + # Filter out not required token indexers to avoid + # Mismatched token keys ConfigurationError + for indexer_name in list(self._token_indexers.keys()): + if indexer_name not in self._features: + del self._token_indexers[indexer_name] + + @overrides + def _read(self, file_path: str) -> Iterable[allen_data.Instance]: + file_path = [file_path] if len(file_path.split(':')) == 0 else file_path.split(':') + + for conllu_file in file_path: + with 
open(conllu_file, 'r') as f: + for annotation in conllu.parse_incr(f, fields=self.fields, field_parsers=self.field_parsers): + # CoNLLU annotations sometimes add back in words that have been elided + # in the original sentence; we remove these, as we're just predicting + # dependencies for the original sentence. + # We filter by integers here as elided words have a non-integer word id, + # as parsed by the conllu python library. + annotation = conllu.TokenList([x for x in annotation if isinstance(x['id'], int)]) + yield self.text_to_instance(annotation) + + @overrides + def text_to_instance(self, tree: conllu.TokenList) -> allen_data.Instance: + fields_: Dict[str, allen_data.Field] = {} + tokens = [allen_data.Token(t['token'], + pos_=t.get('upostag'), + tag_=t.get('xpostag'), + lemma_=t.get('lemma')) + for t in tree] + + # features + text_field = allen_fields.TextField(tokens, self._token_indexers) + fields_['sentence'] = text_field + + # targets + if self.generate_labels: + for target_name in self._targets: + if target_name != 'sent': + target_values = [t[target_name] for t in tree.tokens] + if target_name == 'lemma': + target_values = [allen_data.Token(v) for v in target_values] + fields_[target_name] = allen_fields.TextField(target_values, self._lemma_indexers) + elif target_name == 'feats': + target_values = self._feat_values(tree) + fields_[target_name] = fields.SequenceMultiLabelField(target_values, + self._feats_to_index_multi_label, + text_field, + label_namespace="feats_labels") + elif target_name == 'head': + target_values = [0 if v == '_' else int(v) for v in target_values] + fields_[target_name] = allen_fields.SequenceLabelField(target_values, text_field, + label_namespace=target_name + "_labels") + else: + fields_[target_name] = allen_fields.SequenceLabelField(target_values, text_field, + label_namespace=target_name + "_labels") + + # metadata + fields_["metadata"] = allen_fields.MetadataField({'input': tree, 'field_names': self.fields}) + return 
allen_data.Instance(fields_) + + @staticmethod + def _feat_values(tree: conllu.TokenList): + features = [] + for token in tree: + token_features = [] + if token['feats'] is not None: + for feat, value in token['feats'].items(): + if feat in ['_', '__ROOT__']: + pass + else: + token_features.append(feat + '=' + value) + features.append(token_features) + return features + + @staticmethod + def _feats_to_index_multi_label(vocab: allen_data.Vocabulary): + label_namespace = "feats_labels" + vocab_size = vocab.get_vocab_size(label_namespace) + slices = get_slices_if_not_provided(vocab) + + def _m_from_n_ones_encoding(multi_label: List[str]) -> List[int]: + one_hot_encoding = [0] * vocab_size + for cat, cat_indices in slices.items(): + if cat not in ['__PAD__', '_']: + label_from_cat = [label for label in multi_label if cat == label.split('=')[0]] + if label_from_cat: + label_from_cat = label_from_cat[0] + index = vocab.get_token_index(label_from_cat, label_namespace) + else: + # Get Cat=None index + index = vocab.get_token_index(cat + "=None", label_namespace) + one_hot_encoding[index] = 1 + return one_hot_encoding + + return _m_from_n_ones_encoding + + +@allen_data.Vocabulary.register('from_instances_extended', constructor='from_instances_extended') +class Vocabulary(allen_data.Vocabulary): + + @classmethod + def from_instances_extended( + cls, + instances: Iterable[allen_data.Instance], + min_count: Dict[str, int] = None, + max_vocab_size: Union[int, Dict[str, int]] = None, + non_padded_namespaces: Iterable[str] = vocabulary.DEFAULT_NON_PADDED_NAMESPACES, + pretrained_files: Optional[Dict[str, str]] = None, + only_include_pretrained_words: bool = False, + min_pretrained_embeddings: Dict[str, int] = None, + padding_token: Optional[str] = vocabulary.DEFAULT_PADDING_TOKEN, + oov_token: Optional[str] = vocabulary.DEFAULT_OOV_TOKEN, + ) -> "Vocabulary": + """ + Extension to manually fill gaps in missing 'feats_labels'. 
+ """ + # Load manually tokens from pretrained file (using different strategy + # - only words add all embedding file, without checking if were seen + # in any dataset. + tokens_to_add = None + if pretrained_files and 'tokens' in pretrained_files: + pretrained_set = set(vocabulary._read_pretrained_tokens(pretrained_files['tokens'])) + tokens_to_add = {'tokens': list(pretrained_set)} + pretrained_files = None + + vocab = super().from_instances( + instances=instances, + min_count=min_count, + max_vocab_size=max_vocab_size, + non_padded_namespaces=non_padded_namespaces, + pretrained_files=pretrained_files, + only_include_pretrained_words=only_include_pretrained_words, + tokens_to_add=tokens_to_add, + min_pretrained_embeddings=min_pretrained_embeddings, + padding_token=padding_token, + oov_token=oov_token + ) + # Extending vocab with features that does not show up explicitly. + # To know all features we need to read full dataset first. + # Adding auxiliary '=None' feature for each category is needed + # to perform classification. + get_slices_if_not_provided(vocab) + return vocab + + +def get_slices_if_not_provided(vocab: allen_data.Vocabulary): + if hasattr(vocab, 'slices'): + return vocab.slices + + if 'feats_labels' in vocab.get_namespaces(): + idx2token = vocab.get_index_to_token_vocabulary('feats_labels') + for k, v in dict(idx2token).items(): + if v not in ['_', '__PAD__']: + empty_value = v.split("=")[0] + "=None" + vocab.add_token_to_namespace(empty_value, 'feats_labels') + + slices = {} + for idx, name in vocab.get_index_to_token_vocabulary('feats_labels').items(): + # There are 2 types features: with (Case=Acc) or without assigment (None). + # Here we group their indices by name (before assigment sign). 
+ name = name.split('=')[0] + if name in slices: + slices[name].append(idx) + else: + slices[name] = [idx] + vocab.slices = slices + return vocab.slices diff --git a/combo/data/fields/__init__.py b/combo/data/fields/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4f8d3a265f506aab43e0cda78ec931cd5c44c92e --- /dev/null +++ b/combo/data/fields/__init__.py @@ -0,0 +1 @@ +from .sequence_multilabel_field import SequenceMultiLabelField diff --git a/combo/data/fields/sequence_multilabel_field.py b/combo/data/fields/sequence_multilabel_field.py new file mode 100644 index 0000000000000000000000000000000000000000..4e98a148aee35e42af0b4828a031368fe0eafc12 --- /dev/null +++ b/combo/data/fields/sequence_multilabel_field.py @@ -0,0 +1,138 @@ +"""Sequence multilabel field implementation.""" +import logging +import textwrap +from typing import Set, List, Callable, Iterator, Union, Dict + +import torch +from allennlp import data +from allennlp.common import checks, util +from allennlp.data import fields +from overrides import overrides + +logger = logging.getLogger(__name__) + + +class SequenceMultiLabelField(data.Field[torch.Tensor]): + """ + A `SequenceMultiLabelField` is an extension of the :class:`MultiLabelField` that allows for multiple labels + while keeping sequence dimension. + + This field will get converted into a sequence of vectors of length equal to the vocabulary size with + M from N encoding for the labels (all zeros, and ones for the labels). + + # Parameters + + multi_labels : `List[List[str]]` + multi_label_indexer : `Callable[[data.Vocabulary], Callable[[List[str]], List[int]]]` + Nested callable which based on vocab creates mapper for multilabel field in the sequence from strings + to indexed, int values. + sequence_field : `SequenceField` + A field containing the sequence that this `SequenceMultiLabelField` is labeling. Most often, this is a + `TextField`, for tagging individual tokens in a sentence. 
+ label_namespace : `str`, optional (default="labels") + The namespace to use for converting label strings into integers. We map label strings to + integers for you (e.g., "entailment" and "contradiction" get converted to 0, 1, ...), + and this namespace tells the `Vocabulary` object which mapping from strings to integers + to use (so "entailment" as a label doesn't get the same integer id as "entailment" as a + word). If you have multiple different label fields in your data, you should make sure you + use different namespaces for each one, always using the suffix "labels" (e.g., + "passage_labels" and "question_labels"). + """ + _already_warned_namespaces: Set[str] = set() + + def __init__( + self, + multi_labels: List[List[str]], + multi_label_indexer: Callable[[data.Vocabulary], Callable[[List[str]], List[int]]], + sequence_field: fields.SequenceField, + label_namespace: str = "labels", + ) -> None: + self.multi_labels = multi_labels + self.sequence_field = sequence_field + self.multi_label_indexer = multi_label_indexer + self._label_namespace = label_namespace + self._indexed_multi_labels = None + self._maybe_warn_for_namespace(label_namespace) + if len(multi_labels) != sequence_field.sequence_length(): + raise checks.ConfigurationError( + "Label length and sequence length " + "don't match: %d and %d" % (len(multi_labels), sequence_field.sequence_length()) + ) + + if not all([isinstance(x, str) for multi_label in multi_labels for x in multi_label]): + raise checks.ConfigurationError( + "SequenceMultiLabelField must be passed either all " + "strings or all ints. Found labels {} with " + "types: {}.".format(multi_labels, [type(x) for multi_label in multi_labels for x in multi_label]) + ) + + def _maybe_warn_for_namespace(self, label_namespace: str) -> None: + if not (self._label_namespace.endswith("labels") or self._label_namespace.endswith("tags")): + if label_namespace not in self._already_warned_namespaces: + logger.warning( + "Your label namespace was '%s'. 
We recommend you use a namespace " + "ending with 'labels' or 'tags', so we don't add UNK and PAD tokens by " + "default to your vocabulary. See documentation for " + "`non_padded_namespaces` parameter in Vocabulary.", + self._label_namespace, + ) + self._already_warned_namespaces.add(label_namespace) + + # Sequence methods + def __iter__(self) -> Iterator[Union[List[str], int]]: + return iter(self.multi_labels) + + def __getitem__(self, idx: int) -> Union[List[str], int]: + return self.multi_labels[idx] + + def __len__(self) -> int: + return len(self.multi_labels) + + @overrides + def count_vocab_items(self, counter: Dict[str, Dict[str, int]]): + if self._indexed_multi_labels is None: + for multi_label in self.multi_labels: + for label in multi_label: + counter[self._label_namespace][label] += 1 # type: ignore + + @overrides + def index(self, vocab: data.Vocabulary): + indexer = self.multi_label_indexer(vocab) + + indexed = [] + for multi_label in self.multi_labels: + indexed.append(indexer(multi_label)) + self._indexed_multi_labels = indexed + + @overrides + def get_padding_lengths(self) -> Dict[str, int]: + return {"num_tokens": self.sequence_field.sequence_length()} + + @overrides + def as_tensor(self, padding_lengths: Dict[str, int]) -> torch.Tensor: + desired_num_tokens = padding_lengths["num_tokens"] + assert len(self._indexed_multi_labels) > 0 + classes_count = len(self._indexed_multi_labels[0]) + default_value = [0.0] * classes_count + padded_tags = util.pad_sequence_to_length(self._indexed_multi_labels, desired_num_tokens, lambda: default_value) + tensor = torch.LongTensor(padded_tags) + return tensor + + @overrides + def empty_field(self) -> "SequenceMultiLabelField": + # The empty_list here is needed for mypy + empty_list: List[List[str]] = [[]] + sequence_label_field = SequenceMultiLabelField(empty_list, lambda x: lambda y: y, + self.sequence_field.empty_field()) + sequence_label_field._indexed_labels = empty_list + return sequence_label_field + + def 
__str__(self) -> str: + length = self.sequence_field.sequence_length() + formatted_labels = "".join( + "\t\t" + labels + "\n" for labels in textwrap.wrap(repr(self.multi_labels), 100) + ) + return ( + f"SequenceMultiLabelField of length {length} with " + f"labels:\n {formatted_labels} \t\tin namespace: '{self._label_namespace}'." + ) diff --git a/combo/data/samplers/__init__.py b/combo/data/samplers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ab003fe4341da38983553324b0454681a715eda1 --- /dev/null +++ b/combo/data/samplers/__init__.py @@ -0,0 +1 @@ +from .samplers import TokenCountBatchSampler diff --git a/combo/data/samplers/samplers.py b/combo/data/samplers/samplers.py new file mode 100644 index 0000000000000000000000000000000000000000..a3c86f4953d59efcc4b6c5246ce16242a48f3f77 --- /dev/null +++ b/combo/data/samplers/samplers.py @@ -0,0 +1,54 @@ +from typing import List + +import numpy as np + +from allennlp import data as allen_data + + +@allen_data.BatchSampler.register("token_count") +class TokenCountBatchSampler(allen_data.BatchSampler): + + def __init__(self, dataset, word_batch_size: int = 2500, shuffle_dataset: bool = True): + self._index = 0 + self.shuffle_dataset = shuffle_dataset + self.batch_dataset = self._batchify(dataset, word_batch_size) + if shuffle_dataset: + self._shuffle() + + @staticmethod + def _batchify(dataset, word_batch_size) -> List[List[int]]: + dataset = list(dataset) + batches = [] + batch = [] + words_count = 0 + lengths = [len(instance.fields['sentence'].tokens) for instance in dataset] + argsorted_lengths = np.argsort(lengths) + for idx in argsorted_lengths: + words_count += lengths[idx] + batch.append(idx) + if words_count > word_batch_size: + batches.append(batch) + words_count = 0 + batch = [] + return batches + + def __iter__(self): + return self + + def __next__(self): + if self._index >= len(self.batch_dataset): + if self.shuffle_dataset: + self._index = 0 + self._shuffle() + raise 
StopIteration() + + batch = self.batch_dataset[self._index] + self._index += 1 + return batch + + def _shuffle(self): + indices = np.random.permutation(range(len(self.batch_dataset))) + self.batch_dataset = np.array(self.batch_dataset)[indices].tolist() + + def __len__(self): + return len(self.batch_dataset) diff --git a/combo/data/token_indexers/__init__.py b/combo/data/token_indexers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6b2a7b7510477a68af50c60d40cc21eef3d863bc --- /dev/null +++ b/combo/data/token_indexers/__init__.py @@ -0,0 +1 @@ +from .token_characters_indexer import TokenCharactersIndexer diff --git a/combo/data/token_indexers/token_characters_indexer.py b/combo/data/token_indexers/token_characters_indexer.py new file mode 100644 index 0000000000000000000000000000000000000000..ea7a3eaab32a79758573fe6f778c59060d9bce6c --- /dev/null +++ b/combo/data/token_indexers/token_characters_indexer.py @@ -0,0 +1,62 @@ +"""Custom character token indexer.""" +import itertools +from typing import List, Dict + +import torch +from allennlp import data +from allennlp.common import util +from allennlp.data import tokenizers +from allennlp.data.token_indexers import token_characters_indexer +from overrides import overrides + + +@data.TokenIndexer.register("characters_const_padding") +class TokenCharactersIndexer(token_characters_indexer.TokenCharactersIndexer): + """Wrapper around allennlp token indexer with const padding.""" + + def __init__(self, + namespace: str = "token_characters", + character_tokenizer: tokenizers.CharacterTokenizer = tokenizers.CharacterTokenizer(), + start_tokens: List[str] = None, + end_tokens: List[str] = None, + min_padding_length: int = 0, + token_min_padding_length: int = 0): + super().__init__(namespace, character_tokenizer, start_tokens, end_tokens, min_padding_length, + token_min_padding_length) + + @overrides + def get_padding_lengths(self, indexed_tokens: data.IndexedTokenList) -> Dict[str, int]: + 
padding_lengths = {"token_characters": len(indexed_tokens["token_characters"]), + "num_token_characters": self._min_padding_length} + return padding_lengths + + @overrides + def as_padded_tensor_dict( + self, tokens: data.IndexedTokenList, padding_lengths: Dict[str, int] + ) -> Dict[str, torch.Tensor]: + # Pad the tokens. + padded_tokens = util.pad_sequence_to_length( + tokens["token_characters"], + padding_lengths["token_characters"], + default_value=lambda: [], + ) + + # Pad the characters within the tokens. + desired_token_length = padding_lengths["num_token_characters"] + longest_token: List[int] = max(tokens["token_characters"], key=len, default=[]) # type: ignore + padding_value = 0 + if desired_token_length > len(longest_token): + # Since we want to pad to greater than the longest token, we add a + # "dummy token" so we can take advantage of the fast implementation of itertools.zip_longest. + padded_tokens.append([padding_value] * desired_token_length) + # pad the list of lists to the longest sublist, appending 0's + padded_tokens = list(zip(*itertools.zip_longest(*padded_tokens, fillvalue=padding_value))) + if desired_token_length > len(longest_token): + # Removes the "dummy token". + padded_tokens.pop() + # Truncates all the tokens to the desired length, and return the result. 
+ return { + "token_characters": torch.LongTensor( + [list(token[:desired_token_length]) for token in padded_tokens] + ) + } diff --git a/combo/main.py b/combo/main.py new file mode 100644 index 0000000000000000000000000000000000000000..2f37661ac42fb1ea34aea756da1dc966ec662909 --- /dev/null +++ b/combo/main.py @@ -0,0 +1,189 @@ +"""Main entry point.""" +import logging +import os +import tempfile +from typing import Dict + +import torch +from absl import app +from absl import flags +from allennlp import common, models, predictors +from allennlp.commands import train, predict as allen_predict +from allennlp.common import checks as allen_checks, util +from allennlp.models import archival + +from combo import predict +from combo.data import dataset +from combo.utils import checks + +logger = logging.getLogger(__name__) + +FLAGS = flags.FLAGS +flags.DEFINE_enum(name='mode', default=None, enum_values=['train', 'predict'], + help="Specify COMBO mode: train or precit") + +# Common flags +flags.DEFINE_integer(name='cuda_device', default=-1, + help="Cuda device id (default -1 cpu)") + +# Training flags +flags.DEFINE_string(name='training_data_path', default="./tests/fixtures/example.conllu", + help='Training data path') +flags.DEFINE_string(name='validation_data_path', default='', + help='Validation data path') +flags.DEFINE_string(name='pretrained_tokens', default='', + help='Pretrained tokens embeddings path') +flags.DEFINE_integer(name='embedding_dim', default=300, + help='Embeddings dim') +flags.DEFINE_string(name='pretrained_transformer_name', default='', + help='Pretrained transformer model name (see transformers from HuggingFace library for list of' + 'available models) for transformers based embeddings.') +flags.DEFINE_multi_enum(name='features', default=['token', 'char'], + enum_values=['token', 'char', 'upostag', 'xpostag', 'lemma'], + help='Features used to train model (required \'token\' and \'char\')') +flags.DEFINE_multi_enum(name='targets', default=['deprel', 
'feats', 'head', 'lemma', 'upostag', 'xpostag'],
                        enum_values=['deprel', 'feats', 'head', 'lemma', 'upostag', 'xpostag', 'semrel', 'sent'],
                        help='Targets of the model (required \'deprel\' and \'head\')')
flags.DEFINE_string(name='serialization_dir', default=None,
                    help='Model serialization directory (default - system temp dir).')

# Finetune after training flags
flags.DEFINE_string(name='finetuning_training_data_path', default='',
                    help='Training data path')
flags.DEFINE_string(name='finetuning_validation_data_path', default='',
                    help='Validation data path')
flags.DEFINE_string(name='config_path', default='config.template.jsonnet',
                    help='Config file path.')

# Test after training flags
flags.DEFINE_string(name='result', default='result.conll',
                    help='Test result path file')
flags.DEFINE_string(name='test_path', default=None,
                    help='Test path file.')

# Experimental
flags.DEFINE_boolean(name='use_pure_config', default=False,
                     help='Ignore ext flags (experimental).')

# Prediction flags
flags.DEFINE_string(name='model_path', default=None,
                    help='Pretrained model path.')
flags.DEFINE_string(name='input_file', default=None,
                    help='File to predict path')
flags.DEFINE_string(name='output_file', default='output.log',
                    help='Predictions result file.')
flags.DEFINE_integer(name='batch_size', default=1,
                     help='Prediction batch size.')
flags.DEFINE_boolean(name='silent', default=True,
                     help='Silent prediction to file (without printing to console).')
flags.DEFINE_enum(name='predictor_name', default='semantic-multitask-predictor-spacy',
                  enum_values=['semantic-multitask-predictor', 'semantic-multitask-predictor-spacy'],
                  help='Use predictor with whitespace or spacy tokenizer.')


def run(_):
    """Run model."""
    # Imports are required to make Registrable modules visible without passing parameter
    util.import_module_and_submodules('combo.commands')
    util.import_module_and_submodules('combo.models')
    util.import_module_and_submodules('combo.training')

    if FLAGS.mode == 'train':
        checks.file_exists(FLAGS.config_path)
        params = common.Params.from_file(FLAGS.config_path, ext_vars=_get_ext_vars())
        # Raw model definition is kept so it can be written back into the
        # finetuned archive's config below.
        model_params = params.get('model').as_ordered_dict()
        serialization_dir = tempfile.mkdtemp(prefix='allennlp', dir=FLAGS.serialization_dir)
        model = train.train_model(params, serialization_dir=serialization_dir, file_friendly_logging=True)
        logger.info(f'Training model stored in: {serialization_dir}')

        if FLAGS.finetuning_training_data_path:
            checks.file_exists(FLAGS.finetuning_training_data_path)

            # Loading will be performed from stored model.tar.gz
            del model
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            params = common.Params.from_file(FLAGS.config_path, ext_vars=_get_ext_vars(finetuning=True))
            # Replace model definition with pretrained archive
            params['model'] = {
                'type': 'from_archive',
                'archive_file': serialization_dir + '/model.tar.gz',
            }
            serialization_dir = tempfile.mkdtemp(prefix='allennlp', suffix='-finetuning', dir=FLAGS.serialization_dir)
            model = train.train_model(params.duplicate(), serialization_dir=serialization_dir,
                                      file_friendly_logging=True)

            # Make finetuning model serialization independent from training serialization
            # Storing model definition instead of archive
            params['model'] = model_params
            params.to_file(os.path.join(serialization_dir, archival.CONFIG_NAME))
            archival.archive_model(serialization_dir)

            logger.info(f'Finetuned model stored in: {serialization_dir}')

        if FLAGS.test_path and FLAGS.result:
            checks.file_exists(FLAGS.test_path)
            params = common.Params.from_file(FLAGS.config_path)['dataset_reader']
            params.pop('type')
            dataset_reader = dataset.UniversalDependenciesDatasetReader.from_params(params)
            predictor = predict.SemanticMultitaskPredictor(
                model=model,
                dataset_reader=dataset_reader
            )
            test_path = FLAGS.test_path
            test_trees = dataset_reader.read(test_path)
            with open(FLAGS.result, 'w') as f:
                for tree in test_trees:
                    f.writelines(predictor.predict_instance_as_tree(tree).serialize())
    else:
        # Predict mode.
        # NOTE(review): assumes --input_file was provided; a missing flag
        # raises AttributeError on `.lower()` — confirm intended behavior.
        use_dataset_reader = ".conllu" in FLAGS.input_file.lower()
        manager = allen_predict._PredictManager(
            _get_predictor(),
            FLAGS.input_file,
            FLAGS.output_file,
            FLAGS.batch_size,
            FLAGS.silent,
            use_dataset_reader,
        )
        manager.run()


def _get_predictor() -> predictors.Predictor:
    """Load the archived model and wrap it in the predictor selected by --predictor_name."""
    allen_checks.check_for_gpu(FLAGS.cuda_device)
    checks.file_exists(FLAGS.model_path)
    archive = models.load_archive(
        FLAGS.model_path,
        cuda_device=FLAGS.cuda_device,
    )

    return predictors.Predictor.from_archive(
        archive, FLAGS.predictor_name
    )


def _get_ext_vars(finetuning: bool = False) -> Dict:
    """External variables substituted into the jsonnet config template."""
    if FLAGS.use_pure_config:
        return {}
    else:
        return {
            'training_data_path': FLAGS.training_data_path if not finetuning else FLAGS.finetuning_training_data_path,
            'validation_data_path': (
                FLAGS.validation_data_path if not finetuning else FLAGS.finetuning_validation_data_path),
            'pretrained_tokens': FLAGS.pretrained_tokens,
            'pretrained_transformer_name': FLAGS.pretrained_transformer_name,
            'features': ' '.join(FLAGS.features),
            'targets': ' '.join(FLAGS.targets),
            'type': 'finetuning' if finetuning else 'default',
            'embedding_dim': str(FLAGS.embedding_dim),
        }


def main():
    """Parse flags."""
    flags.mark_flag_as_required('mode')
    app.run(run)


if __name__ == '__main__':
    main()
diff --git a/combo/models/__init__.py b/combo/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5aa7b283c70af76a27d97f63783368bdbe4ffa3f
--- /dev/null
+++ b/combo/models/__init__.py
"""Models module."""
from .base import FeedForwardPredictor
from .parser import DependencyRelationModel
from .embeddings import CharacterBasedWordEmbeddings
from .encoder import ComboEncoder
from .lemma import LemmatizerModel
from .model import SemanticMultitaskModel
from .morpho import
MorphologicalFeatures
diff --git a/combo/models/base.py b/combo/models/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..dc49f90dc02dc0553cb536b0e09dc2a510866069
--- /dev/null
+++ b/combo/models/base.py
from typing import Dict, Optional, List, Union

import torch
import torch.nn as nn
from allennlp import common, data
from allennlp import nn as allen_nn
from allennlp.common import checks
from allennlp.modules import feedforward
from allennlp.nn import Activation

from combo.models import utils


class Predictor(nn.Module, common.Registrable):
    """Registrable base class for per-token prediction heads."""

    default_implementation = 'feedforward_predictor_from_vocab'

    def forward(self,
                x: Union[torch.Tensor, List[torch.Tensor]],
                mask: Optional[torch.BoolTensor] = None,
                labels: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
                sample_weights: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None) -> Dict[str, torch.Tensor]:
        raise NotImplementedError()


class Linear(nn.Linear, common.FromParams):
    """Linear layer with optional activation and dropout applied after it."""

    def __init__(self,
                 in_features: int,
                 out_features: int,
                 activation: Optional[allen_nn.Activation] = lambda x: x,
                 dropout_rate: Optional[float] = 0.0):
        super().__init__(in_features, out_features)
        self.activation = activation
        # Identity function when no dropout is requested.
        self.dropout = nn.Dropout(p=dropout_rate) if dropout_rate else lambda x: x

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        x = super().forward(x)
        x = self.activation(x)
        return self.dropout(x)

    def get_output_dim(self) -> int:
        return self.out_features


@Predictor.register('feedforward_predictor')
@Predictor.register('feedforward_predictor_from_vocab', constructor='from_vocab')
class FeedForwardPredictor(Predictor):
    """Feedforward predictor. Should be used on top of Seq2Seq encoder."""

    def __init__(self, feedforward_network: feedforward.FeedForward):
        super().__init__()
        self.feedforward_network = feedforward_network

    def forward(self,
                x: Union[torch.Tensor, List[torch.Tensor]],
                mask: Optional[torch.BoolTensor] = None,
                labels: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
                sample_weights: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None) -> Dict[str, torch.Tensor]:
        # Default mask: everything valid.
        if mask is None:
            mask = x.new_ones(x.size()[:-1])

        x = self.feedforward_network(x)
        output = {
            'prediction': x.argmax(-1),
            'probability': x
        }

        if labels is not None:
            # Default sample weights: one per sentence in the batch.
            if sample_weights is None:
                sample_weights = labels.new_ones([mask.size(0)])
            output['loss'] = self._loss(x, labels, mask, sample_weights)

        return output

    def _loss(self,
              pred: torch.Tensor,
              true: torch.Tensor,
              mask: torch.BoolTensor,
              sample_weights: torch.Tensor) -> torch.Tensor:
        # Masked cross entropy, weighted per sentence and averaged over the
        # number of valid (unmasked) positions.
        BATCH_SIZE, _, CLASSES = pred.size()
        valid_positions = mask.sum()
        pred = pred.reshape(-1, CLASSES)
        true = true.reshape(-1)
        mask = mask.reshape(-1)
        loss = utils.masked_cross_entropy(pred, true, mask) * mask
        loss = loss.reshape(BATCH_SIZE, -1) * sample_weights.unsqueeze(-1)
        return loss.sum() / valid_positions

    @classmethod
    def from_vocab(cls,
                   vocab: data.Vocabulary,
                   vocab_namespace: str,
                   input_dim: int,
                   num_layers: int,
                   hidden_dims: List[int],
                   activations: Union[Activation, List[Activation]],
                   dropout: Union[float, List[float]] = 0.0,
                   ):
        """Build a predictor whose final layer width is the namespace's vocab size."""
        if len(hidden_dims) + 1 != num_layers:
            raise checks.ConfigurationError(
                "len(hidden_dims) (%d) + 1 != num_layers (%d)" % (len(hidden_dims), num_layers)
            )

        assert vocab_namespace in vocab.get_namespaces()
        # The output layer size is appended here, hence the "+ 1" check above.
        hidden_dims = hidden_dims + [vocab.get_vocab_size(vocab_namespace)]

        return cls(feedforward.FeedForward(
            input_dim=input_dim,
            num_layers=num_layers,
            hidden_dims=hidden_dims,
            activations=activations,
            dropout=dropout)
        )
diff --git
a/combo/models/dilated_cnn.py b/combo/models/dilated_cnn.py
new file mode 100644
index 0000000000000000000000000000000000000000..59d25f3e3dbd76ec29d4c5f704f851469713c44d
--- /dev/null
+++ b/combo/models/dilated_cnn.py
from typing import List

import torch
import torch.nn as nn

from allennlp import common
from allennlp import nn as allen_nn


class DilatedCnnEncoder(nn.Module, common.FromParams):
    """Stack of 1-D (dilated) convolutions, each followed by its activation."""

    def __init__(self,
                 input_dim: int,
                 filters: List[int],
                 kernel_size: List[int],
                 stride: List[int],
                 padding: List[int],
                 dilation: List[int],
                 activations: List[allen_nn.Activation]):
        super().__init__()
        conv1d_layers = []
        # Each layer consumes the previous layer's filter count.
        input_dims = [input_dim] + filters[:-1]
        output_dims = filters
        for idx in range(len(activations)):
            conv1d_layers.append(nn.Conv1d(
                in_channels=input_dims[idx],
                out_channels=output_dims[idx],
                kernel_size=kernel_size[idx],
                stride=stride[idx],
                padding=padding[idx],
                dilation=dilation[idx]))
        self.conv1d_layers = nn.ModuleList(conv1d_layers)
        self.activations = activations
        assert len(self.activations) == len(self.conv1d_layers)

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        for layer, activation in zip(self.conv1d_layers, self.activations):
            x = activation(layer(x))
        return x
diff --git a/combo/models/embeddings.py b/combo/models/embeddings.py
new file mode 100644
index 0000000000000000000000000000000000000000..b5bf26262fe2bf119d4ad89336ef2291508516a0
--- /dev/null
+++ b/combo/models/embeddings.py
"""Embeddings."""
from typing import Optional

import torch
import torch.nn as nn
from allennlp import nn as allen_nn, data
from allennlp.modules import token_embedders
from overrides import overrides
from transformers import modeling_auto

from combo.models import base, dilated_cnn


@token_embedders.TokenEmbedder.register('char_embeddings')
@token_embedders.TokenEmbedder.register('char_embeddings_from_config', constructor='from_config')
class CharacterBasedWordEmbeddings(token_embedders.TokenEmbedder):
    """Character-based word embeddings."""

    def __init__(self,
                 num_embeddings: int,
                 embedding_dim: int,
                 dilated_cnn_encoder: dilated_cnn.DilatedCnnEncoder):
        super().__init__()
        self.char_embed = nn.Embedding(
            num_embeddings=num_embeddings,
            embedding_dim=embedding_dim,
        )
        self.dilated_cnn_encoder = dilated_cnn_encoder
        self.output_dim = embedding_dim

    def forward(self,
                x: torch.Tensor,
                char_mask: Optional[torch.BoolTensor] = None) -> torch.Tensor:
        # Default mask: all characters valid.
        if char_mask is None:
            char_mask = x.new_ones(x.size())

        x = self.char_embed(x)
        x = x * char_mask.unsqueeze(-1).float()

        BATCH_SIZE, SENTENCE_LENGTH, MAX_WORD_LENGTH, CHAR_EMB = x.size()

        # Encode each word's character sequence with the CNN, then max-pool
        # over the character dimension to get one vector per word.
        words = []
        for i in range(SENTENCE_LENGTH):
            word = x[:, i, :, :].reshape(BATCH_SIZE, MAX_WORD_LENGTH, CHAR_EMB).transpose(1, 2)
            word = self.dilated_cnn_encoder(word)
            word, _ = torch.max(word, dim=2)
            words.append(word)
        return torch.stack(words, dim=1)

    @overrides
    def get_output_dim(self) -> int:
        return self.output_dim

    @classmethod
    def from_config(cls,
                    embedding_dim: int,
                    vocab: data.Vocabulary,
                    dilated_cnn_encoder: dilated_cnn.DilatedCnnEncoder,
                    vocab_namespace: str = 'token_characters'):
        """Alternate constructor deriving `num_embeddings` from the vocabulary."""
        assert vocab_namespace in vocab.get_namespaces()
        return cls(
            embedding_dim=embedding_dim,
            num_embeddings=vocab.get_vocab_size(vocab_namespace),
            dilated_cnn_encoder=dilated_cnn_encoder
        )


@token_embedders.TokenEmbedder.register('embeddings_projected')
class ProjectedWordEmbedder(token_embedders.Embedding):
    """Word embeddings."""

    def __init__(self,
                 embedding_dim: int,
                 num_embeddings: int = None,
                 weight: torch.FloatTensor = None,
                 padding_index: int = None,
                 trainable: bool = True,
                 max_norm: float = None,
                 norm_type: float = 2.0,
                 scale_grad_by_freq: bool = False,
                 sparse: bool = False,
                 vocab_namespace: str = "tokens",
                 pretrained_file: str = None,
                 vocab: data.Vocabulary = None,
projection_layer: Optional[base.Linear] = None): + super().__init__( + embedding_dim=embedding_dim, + num_embeddings=num_embeddings, + weight=weight, + padding_index=padding_index, + trainable=trainable, + max_norm=max_norm, + norm_type=norm_type, + scale_grad_by_freq=scale_grad_by_freq, + sparse=sparse, + vocab_namespace=vocab_namespace, + pretrained_file=pretrained_file, + vocab=vocab + ) + self._projection = projection_layer + self.output_dim = embedding_dim if projection_layer is None else projection_layer.out_features + + @overrides + def get_output_dim(self) -> int: + return self.output_dim + + +@token_embedders.TokenEmbedder.register('transformers_word_embeddings') +class TransformersWordEmbedder(token_embedders.PretrainedTransformerMismatchedEmbedder): + """ + Transformers word embeddings as last hidden state + optional projection layers. + + Tested with Bert (but should work for other models as well). + """ + + def __init__(self, + model_name: str, + projection_dim: int, + projection_activation: Optional[allen_nn.Activation] = lambda x: x, + projection_dropout_rate: Optional[float] = 0.0): + super().__init__(model_name) + self.transformers_encoder = modeling_auto.AutoModel.from_pretrained(model_name) + self.output_dim = self.transformers_encoder.config.hidden_size + if projection_dim: + self.projection_layer = base.Linear(in_features=self.output_dim, + out_features=projection_dim, + dropout_rate=projection_dropout_rate, + activation=projection_activation) + self.output_dim = projection_dim + else: + self.projection_layer = None + + def forward( + self, + token_ids: torch.LongTensor, + mask: torch.BoolTensor, + offsets: torch.LongTensor, + wordpiece_mask: torch.BoolTensor, + type_ids: Optional[torch.LongTensor] = None, + segment_concat_mask: Optional[torch.BoolTensor] = None, + ) -> torch.Tensor: + x = super().forward(token_ids=token_ids, mask=mask, offsets=offsets, wordpiece_mask=wordpiece_mask, + type_ids=type_ids, segment_concat_mask=segment_concat_mask) 
+ if self.projection_layer: + x = self.projection_layer(x) + return x + + @overrides + def get_output_dim(self): + return self.output_dim diff --git a/combo/models/encoder.py b/combo/models/encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..d13e7e5dd96bce0719e572e68bb4afc9a144ef3e --- /dev/null +++ b/combo/models/encoder.py @@ -0,0 +1,76 @@ +"""Encoder.""" +from typing import Optional, Tuple + +import torch +import torch.nn.utils.rnn as rnn +from allennlp import common, modules +from allennlp.modules import input_variational_dropout, stacked_bidirectional_lstm, seq2seq_encoders +from overrides import overrides + + +class StackedBiLSTM(stacked_bidirectional_lstm.StackedBidirectionalLstm, common.FromParams): + + def __init__(self, input_size: int, hidden_size: int, num_layers: int, recurrent_dropout_probability: float, + layer_dropout_probability: float, use_highway: bool = False): + super().__init__(input_size=input_size, + hidden_size=hidden_size, + num_layers=num_layers, + recurrent_dropout_probability=recurrent_dropout_probability, + layer_dropout_probability=layer_dropout_probability, + use_highway=use_highway) + + @overrides + def forward(self, + inputs: rnn.PackedSequence, + initial_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None + ) -> Tuple[rnn.PackedSequence, Tuple[torch.Tensor, torch.Tensor]]: + """Changes when compared to stacked_bidirectional_lstm.StackedBidirectionalLstm + * dropout also on last layer + * accepts BxTxD tensor + * state from n-1 layer used as n layer initial state + + :param inputs: + :param initial_state: + :return: + """ + output_sequence = inputs + state_fwd = None + state_bwd = None + for i in range(self.num_layers): + forward_layer = getattr(self, 'forward_layer_{}'.format(i)) + backward_layer = getattr(self, 'backward_layer_{}'.format(i)) + + forward_output, state_fwd = forward_layer(output_sequence, state_fwd) + backward_output, state_bwd = backward_layer(output_sequence, state_bwd) + + 
forward_output, lengths = rnn.pad_packed_sequence(forward_output, batch_first=True) + backward_output, _ = rnn.pad_packed_sequence(backward_output, batch_first=True) + + output_sequence = torch.cat([forward_output, backward_output], -1) + + output_sequence = self.layer_dropout(output_sequence) + output_sequence = rnn.pack_padded_sequence(output_sequence, lengths, batch_first=True) + + return output_sequence, (state_fwd, state_bwd) + + +@modules.Seq2SeqEncoder.register('combo_encoder') +class ComboEncoder(seq2seq_encoders.PytorchSeq2SeqWrapper): + """COMBO encoder (https://www.aclweb.org/anthology/K18-2004.pdf). + + This implementation uses Variational Dropout on the input and then outputs of each BiLSTM layer + (instead of used Gaussian Dropout and Gaussian Noise). + """ + + def __init__(self, stacked_bilstm: StackedBiLSTM, layer_dropout_probability: float): + super().__init__(stacked_bilstm, stateful=False) + self.layer_dropout = input_variational_dropout.InputVariationalDropout(p=layer_dropout_probability) + + @overrides + def forward(self, + inputs: torch.Tensor, + mask: torch.BoolTensor, + hidden_state: torch.Tensor = None) -> torch.Tensor: + x = self.layer_dropout(inputs) + x = super().forward(x, mask) + return self.layer_dropout(x) diff --git a/combo/models/lemma.py b/combo/models/lemma.py new file mode 100644 index 0000000000000000000000000000000000000000..8ab4030245ceff6931a7bd13049ce92cc2a0ccb6 --- /dev/null +++ b/combo/models/lemma.py @@ -0,0 +1,116 @@ +"""Lemmatizer models.""" +from typing import Optional, Dict, List, Union + +import torch +import torch.nn as nn +from allennlp import data, nn as allen_nn +from allennlp.common import checks + +from combo.models import base, dilated_cnn, utils + + +@base.Predictor.register('combo_lemma_predictor_from_vocab', constructor='from_vocab') +class LemmatizerModel(base.Predictor): + """Lemmatizer model.""" + + def __init__(self, + num_embeddings: int, + embedding_dim: int, + dilated_cnn_encoder: 
dilated_cnn.DilatedCnnEncoder, + input_projection_layer: base.Linear): + super().__init__() + self.char_embed = nn.Embedding( + num_embeddings=num_embeddings, + embedding_dim=embedding_dim, + ) + self.dilated_cnn_encoder = dilated_cnn_encoder + self.input_projection_layer = input_projection_layer + + def forward(self, + x: Union[torch.Tensor, List[torch.Tensor]], + mask: Optional[torch.BoolTensor] = None, + labels: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None, + sample_weights: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None) -> Dict[str, torch.Tensor]: + encoder_emb, chars = x + + encoder_emb = self.input_projection_layer(encoder_emb) + char_embeddings = self.char_embed(chars) + + BATCH_SIZE, SENTENCE_LENGTH, WORD_EMB = encoder_emb.size() + _, _, MAX_WORD_LENGTH, CHAR_EMB = char_embeddings.size() + + + encoder_emb = encoder_emb.unsqueeze(2).repeat(1, 1, MAX_WORD_LENGTH, 1) + + pred = [] + for i in range(SENTENCE_LENGTH): + word_emb = (encoder_emb[:, i, :, :].reshape(BATCH_SIZE, MAX_WORD_LENGTH, -1)) + char_sent_emb = char_embeddings[:, i, :].reshape(BATCH_SIZE, MAX_WORD_LENGTH, CHAR_EMB) + x = torch.cat((char_sent_emb, word_emb), -1).transpose(2, 1) + x = self.dilated_cnn_encoder(x) + pred.append(x) + x = torch.stack(pred, dim=1).transpose(2, 3) + output = { + 'prediction': x.argmax(-1), + 'probability': x + } + + if labels is not None: + if mask is None: + mask = encoder_emb.new_ones(encoder_emb.size()[:-2]) + if sample_weights is None: + sample_weights = labels.new_ones(BATCH_SIZE) + mask = mask.unsqueeze(2).repeat(1, 1, MAX_WORD_LENGTH).bool() + output['loss'] = self._loss(x, labels, mask, sample_weights) + + return output + + @staticmethod + def _loss(pred: torch.Tensor, true: torch.Tensor, mask: torch.BoolTensor, + sample_weights: torch.Tensor) -> torch.Tensor: + BATCH_SIZE, SENTENCE_LENGTH, MAX_WORD_LENGTH, CHAR_CLASSES = pred.size() + pred = pred.reshape(-1, CHAR_CLASSES) + + valid_positions = mask.sum() + mask = mask.reshape(-1) + true 
= true.reshape(-1) + loss = utils.masked_cross_entropy(pred, true, mask) * mask + loss = loss.reshape(BATCH_SIZE, -1) * sample_weights.unsqueeze(-1) + return loss.sum() / valid_positions + + @classmethod + def from_vocab(cls, + vocab: data.Vocabulary, + char_vocab_namespace: str, + lemma_vocab_namespace: str, + embedding_dim: int, + input_projection_layer: base.Linear, + filters: List[int], + kernel_size: List[int], + stride: List[int], + padding: List[int], + dilation: List[int], + activations: List[allen_nn.Activation], + ): + assert char_vocab_namespace in vocab.get_namespaces() + assert lemma_vocab_namespace in vocab.get_namespaces() + + if len(filters) + 1 != len(kernel_size): + raise checks.ConfigurationError( + "len(filters) (%d) + 1 != kernel_size (%d)" % (len(filters), len(kernel_size)) + ) + filters = filters + [vocab.get_vocab_size(lemma_vocab_namespace)] + + dilated_cnn_encoder = dilated_cnn.DilatedCnnEncoder( + input_dim=embedding_dim + input_projection_layer.get_output_dim(), + filters=filters, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + activations=activations, + ) + return cls(num_embeddings=vocab.get_vocab_size(char_vocab_namespace), + embedding_dim=embedding_dim, + dilated_cnn_encoder=dilated_cnn_encoder, + input_projection_layer=input_projection_layer) diff --git a/combo/models/model.py b/combo/models/model.py new file mode 100644 index 0000000000000000000000000000000000000000..e4a6249dd81b3b5d56688ce61dc7028755929ee6 --- /dev/null +++ b/combo/models/model.py @@ -0,0 +1,195 @@ +"""Main COMBO model.""" +from typing import Optional, Dict, Any, List + +import torch +from allennlp import data, modules, models as allen_models, nn as allen_nn +from allennlp.modules import text_field_embedders +from allennlp.nn import util +from overrides import overrides + +from combo.models import base +from combo.utils import metrics + + +@allen_models.Model.register('semantic_multitask') +class 
SemanticMultitaskModel(allen_models.Model): + """Main COMBO model.""" + + def __init__(self, + vocab: data.Vocabulary, + loss_weights: Dict[str, float], + text_field_embedder: text_field_embedders.TextFieldEmbedder, + seq_encoder: modules.Seq2SeqEncoder, + use_sample_weight: bool = True, + lemmatizer: Optional[base.Predictor] = None, + upos_tagger: Optional[base.Predictor] = None, + xpos_tagger: Optional[base.Predictor] = None, + semantic_relation: Optional[base.Predictor] = None, + morphological_feat: Optional[base.Predictor] = None, + dependency_relation: Optional[base.Predictor] = None, + regularizer: allen_nn.RegularizerApplicator = None) -> None: + super().__init__(vocab, regularizer) + self.text_field_embedder = text_field_embedder + self.loss_weights = loss_weights + self.use_sample_weight = use_sample_weight + self.seq_encoder = seq_encoder + self.lemmatizer = lemmatizer + self.upos_tagger = upos_tagger + self.xpos_tagger = xpos_tagger + self.semantic_relation = semantic_relation + self.morphological_feat = morphological_feat + self.dependency_relation = dependency_relation + self._head_sentinel = torch.nn.Parameter(torch.randn([1, 1, self.seq_encoder.get_output_dim()])) + self.scores = metrics.SemanticMetrics() + self._partial_losses = None + + @overrides + def forward(self, + sentence: Dict[str, Dict[str, torch.Tensor]], + metadata: List[Dict[str, Any]], + upostag: torch.Tensor = None, + xpostag: torch.Tensor = None, + lemma: Dict[str, Dict[str, torch.Tensor]] = None, + feats: torch.Tensor = None, + head: torch.Tensor = None, + deprel: torch.Tensor = None, + semrel: torch.Tensor = None,) -> Dict[str, torch.Tensor]: + + # Prepare masks + char_mask: torch.BoolTensor = sentence['char']['token_characters'] > 0 + word_mask = util.get_text_field_mask(sentence) + + # If enabled weight samples loss by log(sentence_length) + sample_weights = word_mask.sum(-1).float().log() if self.use_sample_weight else None + + encoder_input = self.text_field_embedder(sentence, 
char_mask=char_mask) + encoder_emb = self.seq_encoder(encoder_input, word_mask) + + batch_size, _, encoding_dim = encoder_emb.size() + + # Concatenate the head sentinel (ROOT) onto the sentence representation. + head_sentinel = self._head_sentinel.expand(batch_size, 1, encoding_dim) + encoder_emb = torch.cat([head_sentinel, encoder_emb], 1) + word_mask = torch.cat([word_mask.new_ones((batch_size, 1)), word_mask], 1) + + upos_output = self._optional(self.upos_tagger, + encoder_emb[:, 1:], + mask=word_mask[:, 1:], + labels=upostag, + sample_weights=sample_weights) + xpos_output = self._optional(self.xpos_tagger, + encoder_emb[:, 1:], + mask=word_mask[:, 1:], + labels=xpostag, + sample_weights=sample_weights) + semrel_output = self._optional(self.semantic_relation, + encoder_emb[:, 1:], + mask=word_mask[:, 1:], + labels=semrel, + sample_weights=sample_weights) + morpho_output = self._optional(self.morphological_feat, + encoder_emb[:, 1:], + mask=word_mask[:, 1:], + labels=feats, + sample_weights=sample_weights) + lemma_output = self._optional(self.lemmatizer, + (encoder_emb[:, 1:], sentence.get('char').get('token_characters') + if sentence.get('char') else None), + mask=word_mask[:, 1:], + labels=lemma.get('char').get('token_characters') if lemma else None, + sample_weights=sample_weights) + parser_output = self._optional(self.dependency_relation, + encoder_emb, + returns_tuple=True, + mask=word_mask, + labels=(deprel, head), + sample_weights=sample_weights) + relations_pred, head_pred = parser_output['prediction'] + output = { + 'upostag': upos_output['prediction'], + 'xpostag': xpos_output['prediction'], + 'semrel': semrel_output['prediction'], + 'feats': morpho_output['prediction'], + 'lemma': lemma_output['prediction'], + 'head': head_pred, + 'deprel': relations_pred, + 'sentence_embedding': torch.max(encoder_emb[:, 1:], dim=1)[0], + } + + if self._has_labels([upostag, xpostag, lemma, feats, head, deprel, semrel]): + + # Feats mapping + if self.morphological_feat: 
+ mapped_gold_labels = [] + for cat, cat_indices in self.morphological_feat.slices.items(): + mapped_gold_labels.append(feats[:, :, cat_indices].argmax(dim=-1)) + + feats = torch.stack(mapped_gold_labels, dim=-1) + + labels = { + 'upostag': upostag, + 'xpostag': xpostag, + 'semrel': semrel, + 'feats': feats, + 'lemma': lemma.get('char').get('token_characters') if lemma else None, + 'head': head, + 'deprel': deprel, + } + self.scores(output, labels, word_mask[:, 1:]) + relations_loss, head_loss = parser_output['loss'] + losses = { + 'upostag_loss': upos_output['loss'], + 'xpostag_loss': xpos_output['loss'], + 'semrel_loss': semrel_output['loss'], + 'feats_loss': morpho_output['loss'], + 'lemma_loss': lemma_output['loss'], + 'head_loss': head_loss, + 'deprel_loss': relations_loss, + # Cycle loss is only for the metrics purposes. + 'cycle_loss': parser_output.get('cycle_loss') + } + self._partial_losses = losses.copy() + losses['loss'] = self._calculate_loss(losses) + output.update(losses) + + return self._clean(output) + + @staticmethod + def _has_labels(labels): + return any(x is not None for x in labels) + + def _calculate_loss(self, output): + losses = [] + for name, value in self.loss_weights.items(): + if output.get(f'{name}_loss'): + losses.append(output[f'{name}_loss'] * value) + return torch.stack(losses).sum() + + @staticmethod + def _optional(callable_model: Optional[torch.nn.Module], + *args, + returns_tuple: bool = False, + **kwargs): + if callable_model: + return callable_model(*args, **kwargs) + else: + if returns_tuple: + return {'prediction': (None, None), 'loss': (None, None)} + else: + return {'prediction': None, 'loss': None} + + @staticmethod + def _clean(output): + for k, v in dict(output).items(): + if v is None: + del output[k] + return output + + @overrides + def get_metrics(self, reset: bool = False) -> Dict[str, float]: + metrics = self.scores.get_metric(reset) + if self._partial_losses: + losses = self._clean(self._partial_losses) + losses 
= {f"partial_loss/{k}": v.detach().item() for k, v in losses.items()} + metrics.update(losses) + return metrics diff --git a/combo/models/morpho.py b/combo/models/morpho.py new file mode 100644 index 0000000000000000000000000000000000000000..238d4284ce97d5508ccb6ef663a25c317efc77bf --- /dev/null +++ b/combo/models/morpho.py @@ -0,0 +1,102 @@ +"""Morphological features models.""" +from typing import Dict, List, Optional, Union + +import torch +from allennlp import data +from allennlp.common import checks +from allennlp.modules import feedforward +from allennlp.nn import Activation + +from combo.data import dataset +from combo.models import base, utils + + +@base.Predictor.register('combo_morpho_from_vocab', constructor='from_vocab') +class MorphologicalFeatures(base.Predictor): + """Morphological features predicting model.""" + + def __init__(self, feedforward_network: feedforward.FeedForward, slices: Dict[str, List[int]]): + super().__init__() + self.feedforward_network = feedforward_network + self.slices = slices + + def forward(self, + x: Union[torch.Tensor, List[torch.Tensor]], + mask: Optional[torch.BoolTensor] = None, + labels: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None, + sample_weights: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None) -> Dict[str, torch.Tensor]: + if mask is None: + mask = x.new_ones(x.size()[:-1]) + + x = self.feedforward_network(x) + + prediction = [] + for cat, cat_indices in self.slices.items(): + prediction.append(x[:, :, cat_indices].argmax(dim=-1)) + + output = { + 'prediction': torch.stack(prediction, dim=-1), + 'probability': x + } + + if labels is not None: + if sample_weights is None: + sample_weights = labels.new_ones([mask.size(0)]) + output['loss'] = self._loss(x, labels, mask, sample_weights) + + return output + + def _loss(self, pred: torch.Tensor, true: torch.Tensor, mask: torch.BoolTensor, + sample_weights: torch.Tensor) -> torch.Tensor: + assert pred.size() == true.size() + BATCH_SIZE, 
SENTENCE_LENGTH, MORPHOLOGICAL_FEATURES = pred.size() + + valid_positions = mask.sum() + + pred = pred.reshape(-1, MORPHOLOGICAL_FEATURES) + true = true.reshape(-1, MORPHOLOGICAL_FEATURES) + mask = mask.reshape(-1) + loss = None + loss_func = utils.masked_cross_entropy + for cat, cat_indices in self.slices.items(): + if cat not in ['__PAD__', '_']: + if loss is None: + loss = loss_func(pred[:, cat_indices], + true[:, cat_indices].argmax(dim=1), + mask) * mask + else: + loss += loss_func(pred[:, cat_indices], + true[:, cat_indices].argmax(dim=1), + mask) * mask + loss = loss.reshape(BATCH_SIZE, -1) * sample_weights.unsqueeze(-1) + return loss.sum() / valid_positions + + @classmethod + def from_vocab(cls, + vocab: data.Vocabulary, + vocab_namespace: str, + input_dim: int, + num_layers: int, + hidden_dims: List[int], + activations: Union[Activation, List[Activation]], + dropout: Union[float, List[float]] = 0.0, + ): + if len(hidden_dims) + 1 != num_layers: + raise checks.ConfigurationError( + "len(hidden_dims) (%d) + 1 != num_layers (%d)" % (len(hidden_dims), num_layers) + ) + + assert vocab_namespace in vocab.get_namespaces() + hidden_dims = hidden_dims + [vocab.get_vocab_size(vocab_namespace)] + + slices = dataset.get_slices_if_not_provided(vocab) + + return cls( + feedforward_network=feedforward.FeedForward( + input_dim=input_dim, + num_layers=num_layers, + hidden_dims=hidden_dims, + activations=activations, + dropout=dropout), + slices=slices + ) diff --git a/combo/models/parser.py b/combo/models/parser.py new file mode 100644 index 0000000000000000000000000000000000000000..63f4c8af6641c9aca49ca683bbe257bd3a0b1b6f --- /dev/null +++ b/combo/models/parser.py @@ -0,0 +1,187 @@ +"""Dependency parsing models.""" +from typing import Tuple, Dict, Optional, Union, List + +import numpy as np +import torch +import torch.nn.functional as F +from allennlp import data +from allennlp.nn import chu_liu_edmonds + +from combo.models import base, utils + + +class 
HeadPredictionModel(base.Predictor): + + def __init__(self, + head_projection_layer: base.Linear, + dependency_projection_layer: base.Linear, + cycle_loss_n: int = 0): + super().__init__() + self.head_projection_layer = head_projection_layer + self.dependency_projection_layer = dependency_projection_layer + self.cycle_loss_n = cycle_loss_n + + def forward(self, + x: Union[torch.Tensor, List[torch.Tensor]], + mask: Optional[torch.BoolTensor] = None, + labels: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None, + sample_weights: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None) -> Dict[str, torch.Tensor]: + if mask is None: + mask = x.new_ones(x.size()[-1]) + + head_arc_emb = self.head_projection_layer(x) + dep_arc_emb = self.dependency_projection_layer(x) + x = dep_arc_emb.bmm(head_arc_emb.transpose(2, 1)) + + if self.training: + pred = x.argmax(-1) + else: + pred = [] + # Adding non existing in mask ROOT to lengths + lengths = mask.data.sum(dim=1).long().cpu().numpy() + 1 + for idx, length in enumerate(lengths): + probs = x[idx, :].softmax(dim=-1).cpu().numpy() + probs[:, 0] = 0 + heads, _ = chu_liu_edmonds.decode_mst(probs.T, length=length, has_labels=False) + heads[0] = 0 + pred.append(heads) + pred = torch.from_numpy(np.stack(pred)).to(x.device) + + output = { + 'prediction': pred[:, 1:], + 'probability': x + } + + if labels is not None: + if sample_weights is None: + sample_weights = labels.new_ones([mask.size(0)]) + output['loss'], output['cycle_loss'] = self._loss(x, labels, mask, sample_weights) + + return output + + def _cycle_loss(self, pred: torch.Tensor): + BATCH_SIZE, SENTENCE_LENGTH, _ = pred.size() + loss = pred.new_zeros(BATCH_SIZE) + # 1: as using non __ROOT__ tokens + yn = pred[:, 1:, 1:] + for i in range(self.cycle_loss_n): + loss += self._batch_trace(yn) / BATCH_SIZE + yn = yn.bmm(pred[:, 1:, 1:]) + + return loss + + def _batch_trace(self, x: torch.Tensor) -> torch.Tensor: + assert len(x.size()) == 3 + BATCH_SIZE, N, M = x.size() + 
assert N == M + identity = x.new_tensor(torch.eye(N)) + identity = identity.reshape((1, N, N)) + batch_identity = identity.repeat(BATCH_SIZE, 1, 1) + return (x * batch_identity).sum() + + def _loss(self, pred: torch.Tensor, true: torch.Tensor, mask: torch.BoolTensor, + sample_weights: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + BATCH_SIZE, N, M = pred.size() + assert N == M + SENTENCE_LENGTH = N + + valid_positions = mask.sum() + + result = [] + # Ignore first pred dimension as it is ROOT token prediction + for i in range(SENTENCE_LENGTH - 1): + pred_i = pred[:, i + 1, :].reshape(BATCH_SIZE, SENTENCE_LENGTH) + true_i = true[:, i].reshape(-1) + mask_i = mask[:, i] + cross_entropy_loss = utils.masked_cross_entropy(pred_i, true_i, mask_i) * mask_i + result.append(cross_entropy_loss) + cycle_loss = self._cycle_loss(pred) + loss = torch.stack(result).transpose(1, 0) * sample_weights.unsqueeze(-1) + return loss.sum() / valid_positions + cycle_loss.mean(), cycle_loss.mean() + + +@base.Predictor.register('combo_dependency_parsing_from_vocab', constructor='from_vocab') +class DependencyRelationModel(base.Predictor): + + def __init__(self, + head_predictor: HeadPredictionModel, + head_projection_layer: base.Linear, + dependency_projection_layer: base.Linear, + relation_prediction_layer: base.Linear): + super().__init__() + self.head_predictor = head_predictor + self.head_projection_layer = head_projection_layer + self.dependency_projection_layer = dependency_projection_layer + self.relation_prediction_layer = relation_prediction_layer + + def forward(self, + x: Union[torch.Tensor, List[torch.Tensor]], + mask: Optional[torch.BoolTensor] = None, + labels: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None, + sample_weights: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None) -> Dict[str, torch.Tensor]: + if mask is not None: + mask = mask[:, 1:] + relations_labels, head_labels = None, None + if labels is not None and labels[0] is not None: + 
relations_labels, head_labels = labels + if mask is None: + mask = head_labels.new_ones(head_labels.size()) + + head_output = self.head_predictor(x, mask, head_labels, sample_weights) + head_pred = head_output['probability'] + head_pred_soft = F.softmax(head_pred, dim=-1) + + head_rel_emb = self.head_projection_layer(x) + + dep_rel_emb = self.dependency_projection_layer(x) + + dep_rel_pred = head_pred_soft.bmm(head_rel_emb) + dep_rel_pred = torch.cat((dep_rel_pred, dep_rel_emb), dim=-1) + relation_prediction = self.relation_prediction_layer(dep_rel_pred) + output = head_output + + output['prediction'] = (relation_prediction.argmax(-1)[:, 1:], head_output['prediction']) + + if labels is not None and labels[0] is not None: + if sample_weights is None: + sample_weights = labels.new_ones([mask.size(0)]) + loss = self._loss(relation_prediction[:, 1:], relations_labels, mask, sample_weights) + output['loss'] = (loss, head_output['loss']) + + return output + + def _loss(self, + pred: torch.Tensor, + true: torch.Tensor, + mask: torch.BoolTensor, + sample_weights: torch.Tensor) -> torch.Tensor: + + valid_positions = mask.sum() + + BATCH_SIZE, _, DEPENDENCY_RELATIONS = pred.size() + pred = pred.reshape(-1, DEPENDENCY_RELATIONS) + true = true.reshape(-1) + mask = mask.reshape(-1) + loss = utils.masked_cross_entropy(pred, true, mask) * mask + loss = loss.reshape(BATCH_SIZE, -1) * sample_weights.unsqueeze(-1) + return loss.sum() / valid_positions + + @classmethod + def from_vocab(cls, + vocab: data.Vocabulary, + vocab_namespace: str, + head_predictor: HeadPredictionModel, + head_projection_layer: base.Linear, + dependency_projection_layer: base.Linear + ): + assert vocab_namespace in vocab.get_namespaces() + relation_prediction_layer = base.Linear( + in_features=head_projection_layer.get_output_dim() + dependency_projection_layer.get_output_dim(), + out_features=vocab.get_vocab_size(vocab_namespace) + ) + return cls( + head_predictor=head_predictor, + 
head_projection_layer=head_projection_layer, + dependency_projection_layer=dependency_projection_layer, + relation_prediction_layer=relation_prediction_layer + ) diff --git a/combo/models/utils.py b/combo/models/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f97913303415552f6c594ac857b563cab52b6c3b --- /dev/null +++ b/combo/models/utils.py @@ -0,0 +1,8 @@ +import torch +import torch.nn.functional as F + + +def masked_cross_entropy(pred: torch.Tensor, true: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor: + mask = mask.float().unsqueeze(-1) + pred = pred + (mask + 1e-45).log() + return F.cross_entropy(pred, true, reduction='none') diff --git a/combo/predict.py b/combo/predict.py new file mode 100644 index 0000000000000000000000000000000000000000..9e4f8326702724efcdcdefb629d567353606b976 --- /dev/null +++ b/combo/predict.py @@ -0,0 +1,141 @@ +import collections +import logging +import time +from typing import List + +import conllu +from allennlp import data as allen_data, common, models +from allennlp.common import util +from allennlp.data import tokenizers +from allennlp.predictors import predictor +from overrides import overrides + +logger = logging.getLogger(__name__) + + +@predictor.Predictor.register('semantic-multitask-predictor') +@predictor.Predictor.register('semantic-multitask-predictor-spacy', constructor='with_spacy_tokenizer') +class SemanticMultitaskPredictor(predictor.Predictor): + + def __init__(self, + model: models.Model, + dataset_reader: allen_data.DatasetReader, + tokenizer: allen_data.Tokenizer = tokenizers.WhitespaceTokenizer()) -> None: + super().__init__(model, dataset_reader) + self.vocab = model.vocab + self._dataset_reader.generate_labels = False + self._tokenizer = tokenizer + + @overrides + def _json_to_instance(self, json_dict: common.JsonDict) -> allen_data.Instance: + tokens = self._tokenizer.tokenize(json_dict['sentence']) + tree = self._sentence_to_tree([t.text for t in tokens]) + return 
self._dataset_reader.text_to_instance(tree) + + @overrides + def load_line(self, line: str) -> common.JsonDict: + return {'sentence': line.replace("\n", " ").strip()} + + @overrides + def dump_line(self, outputs: common.JsonDict) -> str: + # Check whether serialized (str) tree or token's list + # Serialized tree has already separators between lines + if type(outputs['tree']) == str: + return str(outputs['tree']) + else: + return str(outputs['tree']) + "\n" + + @overrides + def predict_instance(self, instance: allen_data.Instance) -> common.JsonDict: + start_time = time.time() + tree = self.predict_instance_as_tree(instance) + tree_json = util.sanitize(tree.serialize()) + result = collections.OrderedDict([ + ('tree', tree_json), + ]) + end_time = time.time() + logger.info('Took {} ms'.format((end_time - start_time) * 1000.0)) + return result + + @overrides + def predict_json(self, inputs: common.JsonDict) -> common.JsonDict: + start_time = time.time() + instance = self._json_to_instance(inputs) + tree = self.predict_instance_as_tree(instance) + tree_json = util.sanitize(tree) + result = collections.OrderedDict([ + ('tree', tree_json), + ]) + end_time = time.time() + logger.info('Took {} ms'.format((end_time - start_time) * 1000.0)) + return result + + def predict_instance_as_tree(self, instance: allen_data.Instance) -> conllu.TokenList: + predictions = super().predict_instance(instance) + return self._predictions_as_tree(predictions, instance) + + @staticmethod + def _sentence_to_tree(sentence: List[str]): + d = collections.OrderedDict + return conllu.TokenList( + [d({'id': idx, 'token': token}) for + idx, token + in enumerate(sentence)] + ) + + def _predictions_as_tree(self, predictions, instance): + tree = instance.fields['metadata']['input'] + field_names = instance.fields['metadata']['field_names'] + for idx, token in enumerate(tree): + for field_name in field_names: + if field_name in predictions: + if field_name in ['xpostag', 'upostag', 'semrel', 'deprel']: + 
value = self.vocab.get_token_from_index(predictions[field_name][idx], field_name + '_labels') + token[field_name] = value + elif field_name in ['head']: + token[field_name] = int(predictions[field_name][idx]) + elif field_name in ['feats']: + slices = self._model.morphological_feat.slices + features = [] + prediction = predictions[field_name][idx] + for (cat, cat_indices), pred_idx in zip(slices.items(), prediction): + if cat not in ['__PAD__', '_']: + value = self.vocab.get_token_from_index(cat_indices[pred_idx], + field_name + '_labels') + # Exclude auxiliary values + if '=None' not in value: + features.append(value) + if len(features) == 0: + field_value = '_' + else: + field_value = '|'.join(sorted(features)) + + token[field_name] = field_value + elif field_name == 'head': + pass + elif field_name == 'lemma': + prediction = predictions[field_name][idx] + word_chars = [] + for char_idx in prediction[1:-1]: + pred_char = self.vocab.get_token_from_index(char_idx, 'lemma_characters') + + if pred_char == '__END__': + break + elif pred_char == '__PAD__': + continue + elif '_' in pred_char: + pred_char = '?' 
+ + word_chars.append(pred_char) + token[field_name] = ''.join(word_chars) + else: + raise NotImplementedError(f'Unknown field name {field_name}!') + + if self._dataset_reader and 'sent' in self._dataset_reader._targets: + tree.metadata = {'sentence_embedding': str(predictions['sentence_embedding'])} + return tree + + @classmethod + def with_spacy_tokenizer(cls, model: models.Model, + dataset_reader: allen_data.DatasetReader): + return cls(model, dataset_reader, tokenizers.SpacyTokenizer()) diff --git a/combo/training/__init__.py b/combo/training/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9980d5fda85246e3f8b38b434257defa165413be --- /dev/null +++ b/combo/training/__init__.py @@ -0,0 +1,3 @@ +"""Training tools.""" +from .scheduler import Scheduler +from .trainer import GradientDescentTrainer diff --git a/combo/training/scheduler.py b/combo/training/scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..8752d739029f16289db94ee493d00bbd92158ce7 --- /dev/null +++ b/combo/training/scheduler.py @@ -0,0 +1,40 @@ +import torch.optim.lr_scheduler as lr_scheduler +from allennlp.training.learning_rate_schedulers import learning_rate_scheduler +from overrides import overrides + + +@learning_rate_scheduler.LearningRateScheduler.register("combo_scheduler") +class Scheduler(learning_rate_scheduler._PyTorchLearningRateSchedulerWrapper): + + def __init__(self, optimizer, patience: int = 6, decreases: int = 2, threshold: float = 1e-3): + super().__init__(lr_scheduler.LambdaLR(optimizer, lr_lambda=[self._lr_lambda])) + self.threshold = threshold + self.decreases = decreases + self.patience = patience + self.start_patience = patience + self.best_score = 0.0 + + @staticmethod + def _lr_lambda(idx: int) -> float: + return 1.0 / (1.0 + idx * 1e-4) + + @overrides + def step(self, metric: float = None) -> None: + self.lr_scheduler.step() + + if metric is not None: + if metric - self.best_score > self.threshold: + 
self.best_score = metric if metric > self.best_score else self.best_score + self.patience = self.start_patience + else: + if self.patience <= 1: + if self.decreases == 0: + # This is condition for Trainer to trigger early stopping + self.patience = 0 + else: + self.patience = self.start_patience + self.decreases -= 1 + self.threshold /= 2 + self.lr_scheduler.base_lrs = [x / 2 for x in self.lr_scheduler.base_lrs] + else: + self.patience -= 1 diff --git a/combo/training/trainer.py b/combo/training/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..01773c394199c097214e9887e872e77580cfad54 --- /dev/null +++ b/combo/training/trainer.py @@ -0,0 +1,213 @@ +import datetime +import logging +import os +import time +import traceback +from typing import Any, Dict, List, Optional + +import torch +import torch.distributed as dist +import torch.optim as optim +import torch.optim.lr_scheduler +import torch.utils.data as data +from allennlp import training +from allennlp.common import checks +from allennlp.common import util as common_util +from allennlp.models import model +from allennlp.training import checkpointer +from allennlp.training import learning_rate_schedulers +from allennlp.training import momentum_schedulers +from allennlp.training import moving_average +from allennlp.training import tensorboard_writer +from allennlp.training import util as training_util +from overrides import overrides + +logger = logging.getLogger(__name__) + + +@training.EpochCallback.register('transfer_patience') +class TransferPatienceEpochCallback(training.EpochCallback): + + def __call__(self, trainer: "training.GradientDescentTrainer", metrics: Dict[str, Any], epoch: int) -> None: + if trainer._learning_rate_scheduler and trainer._learning_rate_scheduler.patience is not None: + trainer._metric_tracker._patience = trainer._learning_rate_scheduler.patience + trainer._metric_tracker._epochs_with_no_improvement = 0 + else: + raise checks.ConfigurationError("Learning 
rate scheduler isn't properly setup!") + + +@training.Trainer.register("gradient_descent_validate_n", constructor="from_partial_objects") +class GradientDescentTrainer(training.GradientDescentTrainer): + + def __init__(self, model: model.Model, optimizer: optim.Optimizer, data_loader: data.DataLoader, + patience: Optional[int] = None, validation_metric: str = "-loss", + validation_data_loader: data.DataLoader = None, num_epochs: int = 20, + serialization_dir: Optional[str] = None, checkpointer: checkpointer.Checkpointer = None, + cuda_device: int = -1, + grad_norm: Optional[float] = None, grad_clipping: Optional[float] = None, + learning_rate_scheduler: Optional[learning_rate_schedulers.LearningRateScheduler] = None, + momentum_scheduler: Optional[momentum_schedulers.MomentumScheduler] = None, + tensorboard_writer: tensorboard_writer.TensorboardWriter = None, + moving_average: Optional[moving_average.MovingAverage] = None, + batch_callbacks: List[training.BatchCallback] = None, + epoch_callbacks: List[training.EpochCallback] = None, distributed: bool = False, local_rank: int = 0, + world_size: int = 1, num_gradient_accumulation_steps: int = 1, + opt_level: Optional[str] = None) -> None: + + super().__init__(model, optimizer, data_loader, patience, validation_metric, validation_data_loader, num_epochs, + serialization_dir, checkpointer, cuda_device, grad_norm, grad_clipping, + learning_rate_scheduler, momentum_scheduler, tensorboard_writer, moving_average, + batch_callbacks, epoch_callbacks, distributed, local_rank, world_size, + num_gradient_accumulation_steps, opt_level) + # TODO extract param to constructor (+ constructor method?) + self.validate_every_n = 5 + + @overrides + def train(self) -> Dict[str, Any]: + """ + Trains the supplied model with the supplied parameters. + """ + try: + epoch_counter = self._restore_checkpoint() + except RuntimeError: + traceback.print_exc() + raise checks.ConfigurationError( + "Could not recover training from the checkpoint. 
Did you mean to output to " + "a different serialization directory or delete the existing serialization " + "directory?" + ) + + training_util.enable_gradient_clipping(self.model, self._grad_clipping) + + logger.info("Beginning training.") + + val_metrics: Dict[str, float] = {} + this_epoch_val_metric: float = None + metrics: Dict[str, Any] = {} + epochs_trained = 0 + training_start_time = time.time() + + metrics["best_epoch"] = self._metric_tracker.best_epoch + for key, value in self._metric_tracker.best_epoch_metrics.items(): + metrics["best_validation_" + key] = value + + for callback in self._epoch_callbacks: + callback(self, metrics={}, epoch=-1) + + for epoch in range(epoch_counter, self._num_epochs): + epoch_start_time = time.time() + train_metrics = self._train_epoch(epoch) + + # get peak of memory usage + if "cpu_memory_MB" in train_metrics: + metrics["peak_cpu_memory_MB"] = max( + metrics.get("peak_cpu_memory_MB", 0), train_metrics["cpu_memory_MB"] + ) + for key, value in train_metrics.items(): + if key.startswith("gpu_"): + metrics["peak_" + key] = max(metrics.get("peak_" + key, 0), value) + + if self._validation_data_loader is not None: + val_metrics = {} + this_epoch_val_metric = None + if epoch % self.validate_every_n == 0: + with torch.no_grad(): + # We have a validation set, so compute all the metrics on it. + val_loss, val_reg_loss, num_batches = self._validation_loss(epoch) + + # It is safe again to wait till the validation is done. This is + # important to get the metrics right. 
+                        if self._distributed:
+                            dist.barrier()
+
+                        val_metrics = training_util.get_metrics(
+                            self.model,
+                            val_loss,
+                            val_reg_loss,
+                            num_batches,
+                            reset=True,
+                            world_size=self._world_size,
+                            cuda_device=[self.cuda_device],
+                        )
+
+                        # Check validation metric for early stopping
+                        this_epoch_val_metric = val_metrics[self._validation_metric]
+                        # self._metric_tracker.add_metric(this_epoch_val_metric)
+                        # NOTE(review): add_metric is deliberately disabled here, so the
+                        # tracker's patience is driven by the scheduler via the
+                        # transfer_patience callback instead — confirm this is intended.
+
+                # Patience is surfaced in the logged metrics for monitoring.
+                train_metrics['patience'] = self._metric_tracker._patience
+                if self._metric_tracker.should_stop_early():
+                    logger.info("Ran out of patience. Stopping training.")
+                    break
+
+            if self._master:
+                self._tensorboard.log_metrics(
+                    train_metrics, val_metrics=val_metrics, log_to_console=True, epoch=epoch + 1
+                )  # +1 because tensorboard doesn't like 0
+
+            # Create overall metrics dict
+            training_elapsed_time = time.time() - training_start_time
+            metrics["training_duration"] = str(datetime.timedelta(seconds=training_elapsed_time))
+            metrics["training_start_epoch"] = epoch_counter
+            metrics["training_epochs"] = epochs_trained
+            metrics["epoch"] = epoch
+
+            for key, value in train_metrics.items():
+                metrics["training_" + key] = value
+            for key, value in val_metrics.items():
+                metrics["validation_" + key] = value
+
+            if self._metric_tracker.is_best_so_far():
+                # Update all the best_ metrics.
+                # (Otherwise they just stay the same as they were.)
+                metrics["best_epoch"] = epoch
+                for key, value in val_metrics.items():
+                    metrics["best_validation_" + key] = value
+
+                self._metric_tracker.best_epoch_metrics = val_metrics
+
+            if self._serialization_dir and self._master:
+                common_util.dump_metrics(
+                    os.path.join(self._serialization_dir, f"metrics_epoch_{epoch}.json"), metrics
+                )
+
+            # The Scheduler API is agnostic to whether your schedule requires a validation metric -
+            # if it doesn't, the validation metric passed here is ignored.
+ if self._learning_rate_scheduler: + self._learning_rate_scheduler.step(this_epoch_val_metric) + if self._momentum_scheduler: + self._momentum_scheduler.step(this_epoch_val_metric) + + if self._master: + self._checkpointer.save_checkpoint( + epoch, self, is_best_so_far=self._metric_tracker.is_best_so_far() + ) + + # Wait for the master to finish saving the checkpoint + if self._distributed: + dist.barrier() + + for callback in self._epoch_callbacks: + callback(self, metrics=metrics, epoch=epoch) + + epoch_elapsed_time = time.time() - epoch_start_time + logger.info("Epoch duration: %s", datetime.timedelta(seconds=epoch_elapsed_time)) + + if epoch < self._num_epochs - 1: + training_elapsed_time = time.time() - training_start_time + estimated_time_remaining = training_elapsed_time * ( + (self._num_epochs - epoch_counter) / float(epoch - epoch_counter + 1) - 1 + ) + formatted_time = str(datetime.timedelta(seconds=int(estimated_time_remaining))) + logger.info("Estimated training time remaining: %s", formatted_time) + + epochs_trained += 1 + + # make sure pending events are flushed to disk and files are closed properly + self._tensorboard.close() + + # Load the best model state before returning + best_model_state = self._checkpointer.best_model_state() + if best_model_state: + self.model.load_state_dict(best_model_state) + + return metrics diff --git a/combo/utils/__init__.py b/combo/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/combo/utils/checks.py b/combo/utils/checks.py new file mode 100644 index 0000000000000000000000000000000000000000..d322c15192c6e82647cb51fba8ca9b2d78e5ee3a --- /dev/null +++ b/combo/utils/checks.py @@ -0,0 +1,20 @@ +import os + +import torch +from allennlp.common import checks as allen_checks + + +def file_exists(*paths): + for path in paths: + if path is None: + raise allen_checks.ConfigurationError(f'File cannot be None') + if not 
os.path.exists(path): + raise allen_checks.ConfigurationError(f'Could not find the file at path: `{path}`') + + +def check_size_match(size_1: torch.Size, size_2: torch.Size, tensor_1_name: str, tensor_2_name: str): + if size_1 != size_2: + raise allen_checks.ConfigurationError( + f"{tensor_1_name} must match {tensor_2_name}, but got {size_1} " + f"and {size_2} instead" + ) diff --git a/combo/utils/metrics.py b/combo/utils/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..bc142ff244d081f63f7ae54fe203d6c8a785ffe0 --- /dev/null +++ b/combo/utils/metrics.py @@ -0,0 +1,246 @@ +from typing import Optional, List, Dict + +import torch +from allennlp.training import metrics +from overrides import overrides + + +class SequenceBoolAccuracy(metrics.Metric): + + def __init__(self, prod_last_dim: bool = False): + self._correct_count = 0.0 + self._total_count = 0.0 + self.prod_last_dim = prod_last_dim + self.correct_indices = torch.ones([]) + + @overrides + def __call__(self, + predictions: torch.Tensor, + gold_labels: torch.Tensor, + mask: Optional[torch.BoolTensor] = None): + if gold_labels is None: + return + predictions, gold_labels, mask = self.detach_tensors(predictions, + gold_labels, + mask) + + # Some sanity checks. 
+        if gold_labels.size() != predictions.size():
+            raise ValueError(
+                f"gold_labels must have shape == predictions.size() but "
+                f"found tensor of shape: {gold_labels.size()}"
+            )
+        if mask is not None and mask.size() not in [predictions.size()[:-1], predictions.size()]:
+            raise ValueError(
+                f"mask must have shape in one of [predictions.size()[:-1], predictions.size()] but "
+                f"found tensor of shape: {mask.size()}"
+            )
+        if mask is None:
+            mask = predictions.new_ones(predictions.size()[:-1]).bool()
+        if mask.dim() < predictions.dim():
+            # Broadcast a (batch, seq) mask over the trailing label dimension.
+            mask = mask.unsqueeze(-1)
+
+        correct = predictions.eq(gold_labels) * mask
+
+        if self.prod_last_dim:
+            # A token counts as correct only if ALL last-dim entries are correct.
+            correct = correct.prod(-1).unsqueeze(-1)
+
+        correct = correct.float()
+
+        self.correct_indices = correct.flatten().bool()
+        self._correct_count += correct.sum()
+        # NOTE(review): when prod_last_dim=True, `correct` has one entry per token
+        # but mask.sum() still counts every last-dim element, so the denominator
+        # looks over-counted — confirm whether this scaling is intended.
+        self._total_count += mask.sum()
+
+    @overrides
+    def get_metric(self, reset: bool) -> float:
+        # Guard against division by zero before the first update.
+        if self._total_count > 0:
+            accuracy = float(self._correct_count) / float(self._total_count)
+        else:
+            accuracy = 0.0
+        if reset:
+            self.reset()
+        return accuracy
+
+    @overrides
+    def reset(self) -> None:
+        self._correct_count = 0.0
+        self._total_count = 0.0
+        self.correct_indices = torch.ones([])
+
+
+class AttachmentScores(metrics.Metric):
+    """
+    Computes labeled and unlabeled attachment scores for a
+    dependency parse, as well as sentence level exact match
+    for both labeled and unlabeled trees. Note that the input
+    to this metric is the sampled predictions, not the distribution
+    itself.
+
+    # Parameters
+
+    ignore_classes : `List[int]`, optional (default = None)
+        A list of label ids to ignore when computing metrics.
+ """ + + def __init__(self, ignore_classes: List[int] = None) -> None: + self._labeled_correct = 0.0 + self._unlabeled_correct = 0.0 + self._exact_labeled_correct = 0.0 + self._exact_unlabeled_correct = 0.0 + self._total_words = 0.0 + self._total_sentences = 0.0 + self.correct_indices = torch.ones([]) + + self._ignore_classes: List[int] = ignore_classes or [] + + def __call__( # type: ignore + self, + predicted_indices: torch.Tensor, + predicted_labels: torch.Tensor, + gold_indices: torch.Tensor, + gold_labels: torch.Tensor, + mask: Optional[torch.BoolTensor] = None, + ): + """ + # Parameters + + predicted_indices : `torch.Tensor`, required. + A tensor of head index predictions of shape (batch_size, timesteps). + predicted_labels : `torch.Tensor`, required. + A tensor of arc label predictions of shape (batch_size, timesteps). + gold_indices : `torch.Tensor`, required. + A tensor of the same shape as `predicted_indices`. + gold_labels : `torch.Tensor`, required. + A tensor of the same shape as `predicted_labels`. + mask : `torch.BoolTensor`, optional (default = None). + A tensor of the same shape as `predicted_indices`. + """ + detached = self.detach_tensors( + predicted_indices, predicted_labels, gold_indices, gold_labels, mask + ) + predicted_indices, predicted_labels, gold_indices, gold_labels, mask = detached + + if mask is None: + mask = torch.ones_like(predicted_indices).bool() + + predicted_indices = predicted_indices.long() + predicted_labels = predicted_labels.long() + gold_indices = gold_indices.long() + gold_labels = gold_labels.long() + + # Multiply by a mask denoting locations of + # gold labels which we should ignore. 
+ for label in self._ignore_classes: + label_mask = gold_labels.eq(label) + mask = mask & ~label_mask + + correct_indices = predicted_indices.eq(gold_indices).long() * mask + unlabeled_exact_match = (correct_indices + ~mask).prod(dim=-1) + correct_labels = predicted_labels.eq(gold_labels).long() * mask + correct_labels_and_indices = correct_indices * correct_labels + self.correct_indices = correct_labels_and_indices.flatten() + labeled_exact_match = (correct_labels_and_indices + ~mask).prod(dim=-1) + + self._unlabeled_correct += correct_indices.sum() + self._exact_unlabeled_correct += unlabeled_exact_match.sum() + self._labeled_correct += correct_labels_and_indices.sum() + self._exact_labeled_correct += labeled_exact_match.sum() + self._total_sentences += correct_indices.size(0) + self._total_words += correct_indices.numel() - (~mask).sum() + + def get_metric(self, reset: bool = False): + """ + # Returns + + The accumulated metrics as a dictionary. + """ + unlabeled_attachment_score = 0.0 + labeled_attachment_score = 0.0 + unlabeled_exact_match = 0.0 + labeled_exact_match = 0.0 + if self._total_words > 0.0: + unlabeled_attachment_score = float(self._unlabeled_correct) / float(self._total_words) + labeled_attachment_score = float(self._labeled_correct) / float(self._total_words) + if self._total_sentences > 0: + unlabeled_exact_match = float(self._exact_unlabeled_correct) / float( + self._total_sentences + ) + labeled_exact_match = float(self._exact_labeled_correct) / float(self._total_sentences) + if reset: + self.reset() + return { + "UAS": unlabeled_attachment_score, + "LAS": labeled_attachment_score, + "UEM": unlabeled_exact_match, + "LEM": labeled_exact_match, + } + + @overrides + def reset(self): + self._labeled_correct = 0.0 + self._unlabeled_correct = 0.0 + self._exact_labeled_correct = 0.0 + self._exact_unlabeled_correct = 0.0 + self._total_words = 0.0 + self._total_sentences = 0.0 + self.correct_indices = torch.ones([]) + + +class 
SemanticMetrics(metrics.Metric): + + def __init__(self) -> None: + self.upos_score = SequenceBoolAccuracy() + self.xpos_score = SequenceBoolAccuracy() + self.semrel_score = SequenceBoolAccuracy() + self.feats_score = SequenceBoolAccuracy(prod_last_dim=True) + self.lemma_score = SequenceBoolAccuracy(prod_last_dim=True) + self.attachment_scores = AttachmentScores() + self.em_score = 0.0 + + def __call__( # type: ignore + self, + predictions: Dict[str, torch.Tensor], + gold_labels: Dict[str, torch.Tensor], + mask: torch.BoolTensor): + self.upos_score(predictions['upostag'], gold_labels['upostag'], mask) + self.xpos_score(predictions['xpostag'], gold_labels['xpostag'], mask) + self.semrel_score(predictions['semrel'], gold_labels['semrel'], mask) + self.feats_score(predictions['feats'], gold_labels['feats'], mask) + self.lemma_score(predictions['lemma'], gold_labels['lemma'], mask) + self.attachment_scores(predictions['head'], + predictions['deprel'], + gold_labels['head'], + gold_labels['deprel'], + mask) + total = mask.sum() + correct_indices = (self.upos_score.correct_indices * + self.xpos_score.correct_indices * + self.semrel_score.correct_indices * + self.feats_score.correct_indices * + self.lemma_score.correct_indices * + self.attachment_scores.correct_indices + ) + + total, correct_indices = self.detach_tensors(total, correct_indices) + self.em_score = (correct_indices.float().sum() / total).item() + + def get_metric(self, reset: bool) -> Dict[str, float]: + metrics = { + "UPOS_ACC": self.upos_score.get_metric(reset), + "XPOS_ACC": self.xpos_score.get_metric(reset), + "SEMREL_ACC": self.semrel_score.get_metric(reset), + "LEMMA_ACC": self.lemma_score.get_metric(reset), + "FEATS_ACC": self.feats_score.get_metric(reset), + "EM": self.em_score + } + metrics.update(self.attachment_scores.get_metric(reset)) + return metrics + + def reset(self) -> None: + self.upos_score.reset() + self.xpos_score.reset() + self.semrel_score.reset() + self.lemma_score.reset() + 
self.feats_score.reset()
+        self.attachment_scores.reset()
+        self.em_score = 0.0
diff --git a/config.template.jsonnet b/config.template.jsonnet
new file mode 100644
index 0000000000000000000000000000000000000000..4ec3ec4304c8bceb02a19be702d8c2e6bbfa0087
--- /dev/null
+++ b/config.template.jsonnet
@@ -0,0 +1,369 @@
+########################################################################################
+#                                 BASIC configuration                                  #
+########################################################################################
+# Training data path, str
+# Must be in CoNLL-U format (or its extended version with semantic relation field).
+# Can accept multiple paths when concatenated with ':', "path1:path2"
+local training_data_path = std.extVar("training_data_path");
+# Validation data path, str
+# Can accept multiple paths when concatenated with ':', "path1:path2"
+local validation_data_path = if std.length(std.extVar("validation_data_path")) > 0 then std.extVar("validation_data_path");
+# Path to pretrained tokens, str or null
+local pretrained_tokens = if std.length(std.extVar("pretrained_tokens")) > 0 then std.extVar("pretrained_tokens");
+# Name of pretrained transformer model, str or null
+local pretrained_transformer_name = if std.length(std.extVar("pretrained_transformer_name")) > 0 then std.extVar("pretrained_transformer_name");
+# Learning rate value, float
+local learning_rate = 0.002;
+# Number of epochs, int
+local num_epochs = 1;
+# Cuda device id, -1 for cpu, int
+local cuda_device = -1;
+# Minimum number of words in batch, int
+local word_batch_size = 1;
+# Features used as input, list of str
+# Choice "upostag", "xpostag", "lemma"
+# Required "token", "char"
+local features = std.split(std.extVar("features"), " ");
+# Targets of the model, list of str
+# Choice "feats", "lemma", "upostag", "xpostag", "semrel"
+# Required "deprel", "head"
+local targets = std.split(std.extVar("targets"), " ");
+# Path for tensorboard metrics, str
+local metrics_dir = "./runs";
+# Word embedding dimension, int
+# If pretrained_tokens is not null, must match the provided dimensionality
+local embedding_dim = std.parseInt(std.extVar("embedding_dim"));
+# Dropout rate on predictors, float
+# All of the models on top of the encoder use this dropout
+local predictors_dropout = 0.25;
+# Xpostag embedding dimension, int
+# (discarded if xpostag not in features)
+local xpostag_dim = 100;
+# Upostag embedding dimension, int
+# (discarded if upostag not in features)
+local upostag_dim = 100;
+# Lemma embedding dimension, int
+# (discarded if lemma not in features)
+local lemma_char_dim = 64;
+# Character embedding dim, int
+local char_dim = 64;
+# Word embedding projection dim, int
+local projected_embedding_dim = 100;
+# Loss weights, dict[str, float]
+local loss_weights = {
+    xpostag: 0.05,
+    upostag: 0.05,
+    lemma: 0.05,
+    feats: 0.2,
+    deprel: 0.8,
+    head: 0.2,
+};
+# Encoder hidden size, int
+local hidden_size = 512;
+# Number of layers in the encoder, int
+local num_layers = 2;
+# Cycle loss iterations, int
+local cycle_loss_n = 0;
+# Maximum length of the word, int
+# Shorter words are padded, longer - truncated
+local word_length = 30;
+
+
+# Helper functions
+local in_features(name) = !(std.length(std.find(name, features)) == 0);
+local in_targets(name) = !(std.length(std.find(name, targets)) == 0);
+local use_transformer = pretrained_transformer_name != null;
+
+# Verify some configuration requirements
+assert in_features("token"): "Key 'token' must be in features!";
+assert in_features("char"): "Key 'char' must be in features!";
+
+assert in_targets("deprel"): "Key 'deprel' must be in targets!";
+assert in_targets("head"): "Key 'head' must be in targets!";
+
+assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't use pretrained tokens and pretrained transformer at the same time!";
+
+########################################################################################
+#                               ADVANCED configuration                                 #
+######################################################################################## + +# Detailed dataset, training, vocabulary and model configuration. +{ + # Configuration type (default or finetuning), str + type: std.extVar('type'), + # Datasets used for vocab creation, list of str + # Choice "train", "valid" + datasets_for_vocab_creation: ['train'], + # Path to training data, str + train_data_path: training_data_path, + # Path to validation data, str + validation_data_path: validation_data_path, + # Dataset reader configuration (conllu format) + dataset_reader: { + type: "conllu", + features: features, + targets: targets, + # Whether data contains semantic relation field, bool + use_sem: if in_features("semrel") then true else false, + token_indexers: { + token: if use_transformer then { + type: "pretrained_transformer_mismatched", + model_name: pretrained_transformer_name, + } else { + # SingleIdTokenIndexer, token as single int + type: "single_id", + }, + upostag: { + type: "single_id", + namespace: "upostag", + feature_name: "pos_", + }, + xpostag: { + type: "single_id", + namespace: "xpostag", + feature_name: "tag_", + }, + lemma: { + type: "characters_const_padding", + character_tokenizer: { + start_tokens: ["__START__"], + end_tokens: ["__END__"], + }, + # +2 for start and end token + min_padding_length: word_length + 2, + }, + char: { + type: "characters_const_padding", + character_tokenizer: { + start_tokens: ["__START__"], + end_tokens: ["__END__"], + }, + # +2 for start and end token + min_padding_length: word_length + 2, + }, + }, + lemma_indexers: { + char: { + type: "characters_const_padding", + namespace: "lemma_characters", + character_tokenizer: { + start_tokens: ["__START__"], + end_tokens: ["__END__"], + }, + # +2 for start and end token + min_padding_length: word_length + 2, + }, + }, + }, + # Data loader configuration + data_loader: { + batch_sampler: { + type: "token_count", + word_batch_size: word_batch_size, + }, + }, + # Vocabulary 
configuration + vocabulary: std.prune({ + type: "from_instances_extended", + only_include_pretrained_words: true, + pretrained_files: { + tokens: pretrained_tokens, + }, + oov_token: "_", + padding_token: "__PAD__", + non_padded_namespaces: ["head_labels"], + }), + model: std.prune({ + type: "semantic_multitask", + text_field_embedder: { + type: "basic", + token_embedders: { + xpostag: if in_features("xpostag") then { + type: "embedding", + padding_index: 0, + embedding_dim: xpostag_dim, + vocab_namespace: "xpostag", + }, + upostag: if in_features("upostag") then { + type: "embedding", + padding_index: 0, + embedding_dim: upostag_dim, + vocab_namespace: "upostag", + }, + token: if use_transformer then { + type: "transformers_word_embeddings", + model_name: pretrained_transformer_name, + projection_dim: projected_embedding_dim, + } else { + type: "embeddings_projected", + embedding_dim: embedding_dim, + projection_layer: { + in_features: embedding_dim, + out_features: projected_embedding_dim, + dropout_rate: 0.25, + activation: "tanh" + }, + vocab_namespace: "tokens", + pretrained_file: pretrained_tokens, + trainable: if pretrained_tokens == null then true else false, + }, + char: { + type: "char_embeddings_from_config", + embedding_dim: char_dim, + dilated_cnn_encoder: { + input_dim: char_dim, + filters: [512, 256, char_dim], + kernel_size: [3, 3, 3], + stride: [1, 1, 1], + padding: [1, 2, 4], + dilation: [1, 2, 4], + activations: ["relu", "relu", "linear"], + }, + }, + lemma: if in_features("lemma") then { + type: "char_embeddings_from_config", + embedding_dim: lemma_char_dim, + dilated_cnn_encoder: { + input_dim: lemma_char_dim, + filters: [512, 256, lemma_char_dim], + kernel_size: [3, 3, 3], + stride: [1, 1, 1], + padding: [1, 2, 4], + dilation: [1, 2, 4], + activations: ["relu", "relu", "linear"], + }, + }, + }, + }, + loss_weights: loss_weights, + seq_encoder: { + type: "combo_encoder", + layer_dropout_probability: 0.33, + stacked_bilstm: { + input_size: + 
char_dim + projected_embedding_dim + + if in_features('xpostag') then xpostag_dim else 0 + + if in_features('lemma') then lemma_char_dim else 0 + + if in_features('upostag') then upostag_dim else 0, + hidden_size: hidden_size, + num_layers: num_layers, + recurrent_dropout_probability: 0.33, + layer_dropout_probability: 0.33 + }, + }, + dependency_relation: { + type: "combo_dependency_parsing_from_vocab", + vocab_namespace: 'deprel_labels', + head_predictor: { + local projection_dim = 512, + cycle_loss_n: cycle_loss_n, + head_projection_layer: { + in_features: hidden_size * 2, + out_features: projection_dim, + activation: "tanh", + }, + dependency_projection_layer: { + in_features: hidden_size * 2, + out_features: projection_dim, + activation: "tanh", + }, + }, + local projection_dim = 128, + head_projection_layer: { + in_features: hidden_size * 2, + out_features: projection_dim, + dropout_rate: predictors_dropout, + activation: "tanh" + }, + dependency_projection_layer: { + in_features: hidden_size * 2, + out_features: projection_dim, + dropout_rate: predictors_dropout, + activation: "tanh" + }, + }, + morphological_feat: if in_targets("feats") then { + type: "combo_morpho_from_vocab", + vocab_namespace: "feats_labels", + input_dim: hidden_size * 2, + hidden_dims: [128], + activations: ["tanh", "linear"], + dropout: [predictors_dropout, 0.0], + num_layers: 2, + }, + lemmatizer: if in_targets("lemma") then { + type: "combo_lemma_predictor_from_vocab", + char_vocab_namespace: "token_characters", + lemma_vocab_namespace: "lemma_characters", + embedding_dim: 256, + input_projection_layer: { + in_features: hidden_size * 2, + out_features: 32, + dropout_rate: predictors_dropout, + activation: "tanh" + }, + filters: [256, 256, 256], + kernel_size: [3, 3, 3, 1], + stride: [1, 1, 1, 1], + padding: [1, 2, 4, 0], + dilation: [1, 2, 4, 1], + activations: ["relu", "relu", "relu", "linear"], + }, + upos_tagger: if in_targets("upostag") then { + input_dim: hidden_size * 2, + 
hidden_dims: [64], + activations: ["tanh", "linear"], + dropout: [predictors_dropout, 0.0], + num_layers: 2, + vocab_namespace: "upostag_labels" + }, + xpos_tagger: if in_targets("xpostag") then { + input_dim: hidden_size * 2, + hidden_dims: [128], + activations: ["tanh", "linear"], + dropout: [predictors_dropout, 0.0], + num_layers: 2, + vocab_namespace: "xpostag_labels" + }, + semantic_relation: if in_targets("semrel") then { + input_dim: hidden_size * 2, + hidden_dims: [64], + activations: ["tanh", "linear"], + dropout: [predictors_dropout, 0.0], + num_layers: 2, + vocab_namespace: "semrel_labels" + }, + regularizer: { + regexes: [ + [".*conv1d.*", {type: "l2", alpha: 1e-6}], + [".*forward.*", {type: "l2", alpha: 1e-6}], + [".*backward.*", {type: "l2", alpha: 1e-6}], + [".*char_embed.*", {type: "l2", alpha: 1e-5}], + ], + }, + }), + trainer: { + type: "gradient_descent_validate_n", + cuda_device: cuda_device, + grad_clipping: 5.0, + num_epochs: num_epochs, + optimizer: { + type: "adam", + lr: learning_rate, + betas: [0.9, 0.9], + }, + patience: 1, # it will be overwriten by callback + epoch_callbacks: [ + { type: "transfer_patience" }, + ], + learning_rate_scheduler: { + type: "combo_scheduler", + }, + tensorboard_writer: { + serialization_dir: metrics_dir, + should_log_learning_rate: true, + summary_interval: 2, + }, + validation_metric: "+EM", + }, +} diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..68c7453553d2fd81d2a044b229a2ece7f0f2a45b --- /dev/null +++ b/requirements.txt @@ -0,0 +1,3 @@ +pylint==2.4.4 +pylint-quotes==0.2.1 +pytest==5.3.5 diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000000000000000000000000000000000000..b7e478982ccf9ab1963c74e1084dfccb6e42c583 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,2 @@ +[aliases] +test=pytest diff --git a/setup.py b/setup.py new file mode 100644 index 
0000000000000000000000000000000000000000..c7caeb3b98e89a09d0737a809e2c8af865f65a94 --- /dev/null +++ b/setup.py @@ -0,0 +1,25 @@ +"""Setup.""" +from setuptools import find_packages, setup + +REQUIREMENTS = [ + 'absl-py==0.9.0', + 'allennlp==1.0.0rc4', + 'conllu==2.3.2', + 'joblib==0.14.1', + 'jsonnet==0.15.0', + 'overrides==2.8.0', + 'tensorboard==2.1.0', + 'torch==1.5.0', + 'torchvision==0.6.0', + 'transformers==2.9.1', +] + +setup( + name='COMBO', + version='0.0.1', + install_requires=REQUIREMENTS, + packages=find_packages(exclude=['tests']), + setup_requires=['pytest-runner', 'pytest-pylint'], + tests_require=['pytest', 'pylint'], + entry_points={'console_scripts': ['combo = combo.main:main']}, +) diff --git a/tests/data/fields/test_samplers.py b/tests/data/fields/test_samplers.py new file mode 100644 index 0000000000000000000000000000000000000000..81689fc6ac573687708e3667d4ec734cb5f9b971 --- /dev/null +++ b/tests/data/fields/test_samplers.py @@ -0,0 +1,32 @@ +"""Sampler tests.""" +import unittest + +from allennlp import data +from allennlp.data import fields + +from combo.data import TokenCountBatchSampler + + +class TokenCountBatchSamplerTest(unittest.TestCase): + + def setUp(self) -> None: + self.dataset = [] + self.sentences = ['First sentence makes full batch.', 'Short', 'This ends first batch'] + for sentence in self.sentences: + tokens = [data.Token(t) + for t in sentence.split()] + text_field = fields.TextField(tokens, {}) + self.dataset.append(data.Instance({'sentence': text_field})) + + def test_batches(self): + # given + sampler = TokenCountBatchSampler(self.dataset, word_batch_size=2, shuffle_dataset=False) + + # when + length = len(sampler) + values = list(sampler) + + # then + self.assertEqual(2, length) + # sort by lengths + word_batch_size makes 1, 2 first batch + self.assertListEqual([[1, 2], [0]], values) diff --git a/tests/data/fields/test_sequence_multilabel_field.py b/tests/data/fields/test_sequence_multilabel_field.py new file mode 100644 
index 0000000000000000000000000000000000000000..85d28baf44bf5d05b9d209031280da10cade834f --- /dev/null +++ b/tests/data/fields/test_sequence_multilabel_field.py @@ -0,0 +1,96 @@ +"""Sequence multilabel field tests.""" +import unittest +from typing import List + +import torch +from allennlp import data as allen_data +from allennlp.data import fields as allen_fields + +from combo.data import fields + + +class IndexingSequenceMultiLabelFieldTest(unittest.TestCase): + + def setUp(self) -> None: + self.namespace = 'test_labels' + self.vocab = allen_data.Vocabulary() + self.vocab.add_tokens_to_namespace( + tokens=['t' + str(idx) for idx in range(3)], + namespace=self.namespace + ) + + def _indexer(vocab: allen_data.Vocabulary): + vocab_size = vocab.get_vocab_size(self.namespace) + + def _mapper(multi_label: List[str]) -> List[int]: + one_hot = [0] * vocab_size + for label in multi_label: + index = vocab.get_token_index(label, self.namespace) + one_hot[index] = 1 + return one_hot + + return _mapper + + self.indexer = _indexer + self.sequence_field = _SequenceFieldTestWrapper(self.vocab.get_vocab_size(self.namespace)) + + def test_indexing(self): + # given + field = fields.SequenceMultiLabelField( + multi_labels=[['t1', 't2'], [], ['t0']], + multi_label_indexer=self.indexer, + sequence_field=self.sequence_field, + label_namespace=self.namespace + ) + expected = [[0, 1, 1], [0, 0, 0], [1, 0, 0]] + + # when + field.index(self.vocab) + + # then + self.assertEqual(expected, field._indexed_multi_labels) + + def test_mapping_to_tensor(self): + # given + field = fields.SequenceMultiLabelField( + multi_labels=[['t1', 't2'], [], ['t0']], + multi_label_indexer=self.indexer, + sequence_field=self.sequence_field, + label_namespace=self.namespace + ) + field.index(self.vocab) + expected = torch.tensor([[0, 1, 1], [0, 0, 0], [1, 0, 0]]) + + # when + actual = field.as_tensor(field.get_padding_lengths()) + + # then + self.assertTrue(torch.all(expected.eq(actual))) + + def 
test_sequence_method(self): + # given + field = fields.SequenceMultiLabelField( + multi_labels=[['t1', 't2'], [], ['t0']], + multi_label_indexer=self.indexer, + sequence_field=self.sequence_field, + label_namespace=self.namespace + ) + + # when + length = len(field) + iter_length = len(list(iter(field))) + middle_value = field[1] + + # then + self.assertEqual(3, length) + self.assertEqual(3, iter_length) + self.assertEqual([], middle_value) + + +class _SequenceFieldTestWrapper(allen_fields.SequenceField): + + def __init__(self, length: int): + self.length = length + + def sequence_length(self) -> int: + return self.length diff --git a/tests/fixtures/example.conllu b/tests/fixtures/example.conllu new file mode 100644 index 0000000000000000000000000000000000000000..0cf30f04351db8904513b7363a6e8414b27cd347 --- /dev/null +++ b/tests/fixtures/example.conllu @@ -0,0 +1,5 @@ +# sent_id = test-s1 +# text = Easy sentence. +1 Verylongwordwhichmustbetruncatedbythesystemto30 easy ADJ adj AdpType=Prep 1 amod _ _ +2 Sentence verylonglemmawhichmustbetruncatedbythesystemto30 NOUN nom Number=Sing 0 root _ _ +3 . . PUNCT . 
_ 1 punct _ _ diff --git a/tests/fixtures/example.vec b/tests/fixtures/example.vec new file mode 100644 index 0000000000000000000000000000000000000000..99d498691f6e086f954417062d7cff8d21b14b5f --- /dev/null +++ b/tests/fixtures/example.vec @@ -0,0 +1,4 @@ +2 300 +Sentence 0.1271 0.0356 -0.0134 -0.0100 0.1353 0.0263 -0.0597 0.0231 0.0074 -0.0126 0.0209 -0.0309 0.2831 -0.0016 -0.0338 0.1002 0.0494 -0.0874 -0.0929 -0.0134 -0.1597 -0.0016 0.0273 0.0035 -0.1760 -0.0756 -0.0853 0.0095 0.0008 -0.0660 0.1774 0.0660 -0.0071 -0.0210 0.0229 -0.0173 -0.0379 -0.0682 0.0198 0.0254 -0.0108 -0.0778 -0.0094 0.0625 -0.0221 0.0631 0.0354 -0.0296 0.0617 0.0314 0.1582 0.1129 0.0281 0.0021 0.4368 -0.1094 0.0555 -0.0240 -0.1006 -0.0608 -0.0936 0.0184 -0.0265 0.0042 -0.1166 -0.0007 -0.0332 -0.0832 0.0454 -0.0446 -0.0499 -0.0852 -0.0684 -0.0598 -0.0065 -0.0321 0.0578 -0.0149 -0.0353 0.0319 0.0691 0.0843 0.8444 0.0066 -0.0465 -0.0164 0.0589 0.0603 -0.0589 -0.0497 -0.0184 0.0095 -0.1072 0.0254 -0.0094 -0.0271 0.0378 0.0048 -0.0015 -0.1009 -0.0182 -0.0513 0.0573 -0.0002 0.0518 -0.0759 0.0564 -0.0137 -0.0261 0.0656 0.0061 -0.0295 -0.0349 -0.1159 -0.0101 -0.0236 -0.0483 0.0343 0.0057 -0.1205 -0.0833 0.0393 0.0236 0.0129 -0.0899 -0.0360 -0.0380 -0.0386 0.0778 -0.0329 -0.0757 0.0701 -0.0246 -0.3195 -0.1065 -0.0229 0.2152 0.0918 -0.0843 -0.0342 -0.0281 0.0780 -0.0004 0.1064 0.1484 0.0727 -0.0202 -0.0514 -0.0393 0.0165 -0.0911 -0.0117 0.0395 -0.1155 -0.2148 -0.0061 0.0328 -0.0558 -0.1027 0.0284 -0.0193 -0.0006 -0.0815 0.0062 0.0224 0.0079 -0.1508 -0.2233 0.0603 -0.0200 0.0484 -0.0298 0.2781 0.0728 -0.0166 -0.0532 -0.1248 0.1290 -0.0349 0.0578 -0.0680 -0.0327 0.1507 -0.0220 -0.0769 -0.1918 0.0003 0.0308 -0.0105 0.0436 0.0709 -0.0622 0.0654 0.1635 -0.0525 0.0198 0.0719 0.0586 0.0509 -0.0101 -0.2630 0.0578 0.0433 -0.0168 0.1403 -0.0320 0.0609 0.0251 -0.0038 -0.0138 -0.0424 0.0864 -0.1150 -0.0703 -0.1763 0.0424 0.0629 -0.0049 -0.3170 0.1916 -0.0464 0.0751 0.0237 0.0864 -0.1301 0.0279 -0.2121 -0.0120 
0.1025 0.0207 -0.0091 -0.0104 0.0493 -0.0151 -0.0463 -0.0199 0.0247 0.0271 0.0203 0.0210 0.0967 0.0239 0.0313 0.0170 -0.1281 0.0023 -0.0771 -0.0273 -0.0128 -0.0623 0.0169 0.0569 0.0365 0.0747 -0.0060 0.0538 0.0408 -0.0028 -0.0357 0.0390 0.0170 -0.1010 -0.0654 -0.0552 0.0706 -0.0406 0.0005 -0.0368 -0.0214 -0.0939 -0.0396 0.0535 0.0406 -0.0293 0.0240 0.0699 -0.0381 -0.0011 -0.0323 0.0977 -0.0687 0.1055 0.0692 -0.0002 -0.0174 -0.1106 0.0914 -0.0234 0.2364 0.0245 -0.0698 0.1212 0.0032 -0.0884 -0.1106 -0.0209 0.0748 -0.0175 -0.0111 0.0092 +. -0.1028 0.0208 -0.0235 0.0794 0.0284 0.0097 0.0328 -0.0054 -0.0338 -0.0541 -0.0293 -0.0616 0.3285 -0.1766 0.0214 0.3172 0.1139 0.0546 0.0106 -0.0492 0.1317 0.0441 -0.0442 -0.0822 -0.2058 -0.3436 0.0357 0.0860 0.0631 -0.0236 0.0253 0.1204 -0.0011 0.0833 0.0457 0.0463 -0.0049 -0.0002 0.0356 -0.0221 -0.0241 0.1429 -0.0498 0.1081 -0.0569 0.1883 -0.0259 0.0110 0.0795 0.0906 0.0159 -0.1774 0.0466 0.1067 0.4010 0.0221 0.0473 0.0630 -0.0191 0.0152 -0.0745 0.0615 -0.0565 0.0168 0.0042 -0.0890 -0.0333 0.0493 -0.0694 0.0076 -0.1073 0.0070 0.0190 -0.0456 -0.1193 0.0096 -0.0013 -0.0110 -0.0602 0.0485 -0.0192 0.1404 1.0454 -0.0198 0.1328 0.0495 -0.0621 -0.0135 -0.0507 -0.0276 0.0010 0.0499 0.1893 -0.0062 0.0283 -0.0601 0.0179 0.0620 0.0073 -0.0410 -0.0922 -0.0439 -0.0612 0.0197 0.0297 0.0109 0.0535 0.0724 -0.0540 0.0939 0.0062 -0.0886 -0.0044 -0.0572 -0.0058 -0.0100 -0.0138 0.0796 -0.0190 0.0539 0.1392 0.2092 -0.0483 0.0948 0.0293 -0.1015 -0.0377 -0.0052 -0.1322 0.0353 -0.0000 0.0968 -0.0553 -0.1574 -0.2598 -0.0186 0.2303 -0.0370 0.0001 -0.0023 0.0240 0.0100 -0.0249 0.2826 -0.1605 -0.0762 -0.0917 -0.0243 -0.0795 -0.0109 -0.0247 -0.1053 0.1111 -0.0671 -0.2930 0.0796 0.0465 -0.0236 0.0477 0.0900 -0.0841 -0.0175 -0.0492 -0.0029 0.0029 -0.0543 0.3960 -0.0695 0.0380 0.0090 0.0191 -0.0685 0.3998 -0.2886 -0.0274 -0.0080 -0.0081 0.6646 -0.0150 -0.0635 -0.0385 -0.0017 0.0057 -0.0346 0.0091 -0.5501 0.0295 -0.0726 0.0424 -0.0347 0.0247 -0.0223 -0.0820 
0.5079 0.0582 0.0771 0.0283 0.0328 0.0002 -0.0144 0.2214 0.0470 0.0793 0.0079 -0.0763 0.0095 -0.0511 -0.0454 -0.0095 0.0097 -0.0196 0.1779 0.0348 0.0653 0.0193 0.0236 0.0269 -0.0042 -0.0604 -0.1879 0.0400 0.0415 0.0368 -0.0349 0.1267 0.0685 0.2566 -0.3385 0.0788 -0.0458 0.0736 0.3825 0.0496 0.0342 0.0906 -0.0084 -0.0540 0.0018 -0.0042 0.0416 0.0251 -0.0075 0.0347 -0.0899 -0.0071 0.0186 0.0034 -0.0192 -0.0017 -0.1191 0.0037 -0.0363 0.0364 0.0664 -0.2065 -0.1264 0.0571 -0.0173 -0.1454 -0.0173 0.0830 -0.0438 0.0029 0.1918 0.2079 0.0393 0.0081 0.0228 -0.0086 -0.1322 -0.0006 0.1892 0.0536 0.0155 -0.0371 -0.0056 0.0188 0.0864 -0.1048 0.2564 -0.1761 0.0433 0.0597 -0.0005 0.0409 -0.0803 0.0188 0.0211 -0.0255 -0.0026 -0.0438 -0.2800 -0.0682 -0.0544 -0.0386 -0.0479 0.1007 0.0435 0.0229 0.0632 + diff --git a/tests/test_main.py b/tests/test_main.py new file mode 100644 index 0000000000000000000000000000000000000000..b7628b1258a4eff70c58616e310736795e4ac3e2 --- /dev/null +++ b/tests/test_main.py @@ -0,0 +1,52 @@ +import logging +import os +import pathlib +import shutil +import tempfile +import unittest +from allennlp.commands import train +from allennlp.common import Params, util + + +class TrainingEndToEndTest(unittest.TestCase): + PROJECT_ROOT = (pathlib.Path(__file__).parent / '..').resolve() + MODULE_ROOT = PROJECT_ROOT / 'combo' + TESTS_ROOT = PROJECT_ROOT / 'tests' + FIXTURES_ROOT = TESTS_ROOT / 'fixtures' + TEST_DIR = pathlib.Path(tempfile.mkdtemp(prefix='allennlp_tests')) + + def setUp(self) -> None: + logging.getLogger('allennlp.common.util').disabled = True + logging.getLogger('allennlp.training.tensorboard_writer').disabled = True + logging.getLogger('allennlp.common.params').disabled = True + logging.getLogger('allennlp.nn.initializers').disabled = True + + def test_training_produces_model(self): + # given + util.import_module_and_submodules('combo.models') + util.import_module_and_submodules('combo.training') + ext_vars = { + 'training_data_path': 
os.path.join(self.FIXTURES_ROOT, 'example.conllu'), + 'validation_data_path': os.path.join(self.FIXTURES_ROOT, 'example.conllu'), + 'features': 'token char', + 'targets': 'deprel head lemma feats upostag xpostag', + 'type': 'default', + 'pretrained_tokens': os.path.join(self.FIXTURES_ROOT, 'example.vec'), + 'pretrained_transformer_name': '', + 'embedding_dim': '300', + + } + params = Params.from_file(os.path.join(self.PROJECT_ROOT, 'config.template.jsonnet'), + ext_vars=ext_vars) + params['trainer']['tensorboard_writer']['serialization_dir'] = os.path.join(self.TEST_DIR, 'metrics') + params['trainer']['num_epochs'] = 1 + params['data_loader']['batch_sampler']['word_batch_size'] = 1 + + # when + model = train.train_model(params, serialization_dir=self.TEST_DIR) + + # then + self.assertIsNotNone(model) + + def tearDown(self) -> None: + shutil.rmtree(self.TEST_DIR) diff --git a/tests/utils/test_checks.py b/tests/utils/test_checks.py new file mode 100644 index 0000000000000000000000000000000000000000..df9e624246d7a0ae243cae65ea5756f2933b5829 --- /dev/null +++ b/tests/utils/test_checks.py @@ -0,0 +1,38 @@ +"""Checks tests.""" +import unittest + +import torch +from allennlp.common import checks as allen_checks + +from combo.utils import checks + + +class SizeCheckTest(unittest.TestCase): + + def test_equal_sizes(self): + # given + size = (10, 2) + tensor1 = torch.rand(size) + tensor2 = torch.rand(size) + + # when + checks.check_size_match(size_1=tensor1.size(), + size_2=tensor2.size(), + tensor_1_name='', tensor_2_name='') + + # then + # nothing happens + self.assertTrue(True) + + def test_different_sizes(self): + # given + size1 = (10, 2) + size2 = (20, 1) + tensor1 = torch.rand(size1) + tensor2 = torch.rand(size2) + + # when/then + with self.assertRaises(allen_checks.ConfigurationError): + checks.check_size_match(size_1=tensor1.size(), + size_2=tensor2.size(), + tensor_1_name='', tensor_2_name='') diff --git a/tests/utils/test_metrics.py b/tests/utils/test_metrics.py 
new file mode 100644 index 0000000000000000000000000000000000000000..305939c032d344a2e7d9a1d0e772a0bc8acf19f0 --- /dev/null +++ b/tests/utils/test_metrics.py @@ -0,0 +1,152 @@ +"""Metrics tests.""" +import unittest + +import torch + +from combo.utils import metrics + + +class SemanticMetricsTest(unittest.TestCase): + + def setUp(self) -> None: + self.mask: torch.BoolTensor = torch.tensor([ + [True, True, True, True], + [True, True, True, False], + [True, True, True, False], + ]) + pred = torch.tensor([ + [0, 1, 2, 3], + [0, 1, 2, 3], + [0, 1, 2, 3], + ]) + pred_seq = pred.reshape(3, 4, 1) + gold = pred.clone() + gold_seq = pred_seq.clone() + self.upostag, self.upostag_l = (('upostag', x) for x in [pred, gold]) + self.xpostag, self.xpostag_l = (('xpostag', x) for x in [pred, gold]) + self.semrel, self.semrel_l = (('semrel', x) for x in [pred, gold]) + self.head, self.head_l = (('head', x) for x in [pred, gold]) + self.deprel, self.deprel_l = (('deprel', x) for x in [pred, gold]) + self.feats, self.feats_l = (('feats', x) for x in [pred_seq, gold_seq]) + self.lemma, self.lemma_l = (('lemma', x) for x in [pred_seq, gold_seq]) + self.predictions = dict( + [self.upostag, self.xpostag, self.semrel, self.feats, self.lemma, self.head, self.deprel]) + self.gold_labels = dict([self.upostag_l, self.xpostag_l, self.semrel_l, self.feats_l, self.lemma_l, self.head_l, + self.deprel_l]) + self.eps = 1e-6 + + def test_every_prediction_correct(self): + # given + metric = metrics.SemanticMetrics() + + # when + metric(self.predictions, self.gold_labels, self.mask) + + # then + self.assertEqual(1.0, metric.em_score) + + def test_missing_predictions_for_one_target(self): + # given + metric = metrics.SemanticMetrics() + self.predictions['upostag'] = None + self.gold_labels['upostag'] = None + + # when + metric(self.predictions, self.gold_labels, self.mask) + + # then + self.assertEqual(1.0, metric.em_score) + + def test_missing_predictions_for_two_targets(self): + # given + metric = 
metrics.SemanticMetrics() + self.predictions['upostag'] = None + self.gold_labels['upostag'] = None + self.predictions['lemma'] = None + self.gold_labels['lemma'] = None + + # when + metric(self.predictions, self.gold_labels, self.mask) + + # then + self.assertEqual(1.0, metric.em_score) + + def test_one_classification_in_one_target_is_wrong(self): + # given + metric = metrics.SemanticMetrics() + self.predictions['upostag'][0][0] = 100 + + # when + metric(self.predictions, self.gold_labels, self.mask) + + # then + self.assertAlmostEqual(0.9, metric.em_score, delta=self.eps) + + def test_classification_errors_and_target_without_predictions(self): + # given + metric = metrics.SemanticMetrics() + self.predictions['feats'] = None + self.gold_labels['feats'] = None + self.predictions['upostag'][0][0] = 100 + self.predictions['upostag'][2][0] = 100 + # should be ignored due to masking + self.predictions['upostag'][1][3] = 100 + + # when + metric(self.predictions, self.gold_labels, self.mask) + + # then + self.assertAlmostEqual(0.8, metric.em_score, delta=self.eps) + + +class SequenceBoolAccuracyTest(unittest.TestCase): + + def setUp(self) -> None: + self.mask: torch.BoolTensor = torch.tensor([ + [True, True, True, True], + [True, True, True, False], + [True, True, True, False], + ]) + + def test_regular_classification_accuracy(self): + # given + metric = metrics.SequenceBoolAccuracy() + predictions = torch.tensor([ + [1, 1, 0, 8], + [1, 2, 3, 4], + [9, 4, 3, 9], + ]) + gold_labels = torch.tensor([ + [11, 1, 0, 8], + [14, 2, 3, 14], + [9, 4, 13, 9], + ]) + + # when + metric(predictions, gold_labels, self.mask) + + # then + self.assertEqual(metric._correct_count.item(), 7) + self.assertEqual(metric._total_count.item(), 10) + + def test_feats_classification_accuracy(self): + # given + metric = metrics.SequenceBoolAccuracy(prod_last_dim=True) + # batch_size, sequence_length, classes + predictions = torch.tensor([ + [[1, 4], [0, 2], [0, 2], [0, 3]], + [[1, 4], [0, 2], [0, 2], 
[0, 3]], + [[1, 4], [0, 2], [0, 2], [0, 3]], + ]) + gold_labels = torch.tensor([ + [[1, 14], [0, 2], [0, 2], [0, 3]], + [[11, 4], [0, 2], [0, 2], [10, 3]], + [[1, 4], [0, 2], [10, 12], [0, 3]], + ]) + + # when + metric(predictions, gold_labels, self.mask) + + # then + self.assertEqual(metric._correct_count.item(), 7) + self.assertEqual(metric._total_count.item(), 10)