diff --git a/combo/config.multitask.template.jsonnet b/combo/config.multitask.template.jsonnet
new file mode 100644
index 0000000000000000000000000000000000000000..5edff14362062be0adccc65474067c1f43a7d791
--- /dev/null
+++ b/combo/config.multitask.template.jsonnet
@@ -0,0 +1,404 @@
+# Configuration file for jointly training a model using CoNLL-U and IOB data.
+local shared_config = import "config.shared.libsonnet";
+########################################################################################
+#                                 BASIC configuration                                  #
+########################################################################################
+# Training data path, str
+# Must be in CoNLL-U format (or its extended version with a semantic relation field).
+# Accepts multiple paths concatenated with ',', e.g. "path1,path2"
+local training_data_path = std.extVar("training_data_path");
+# Validation data path, str
+# Accepts multiple paths concatenated with ',', e.g. "path1,path2"
+local validation_data_path = if std.length(std.extVar("validation_data_path")) > 0 then std.extVar("validation_data_path");
+# Path to pretrained tokens, str or null
+local pretrained_tokens = if std.length(std.extVar("pretrained_tokens")) > 0 then std.extVar("pretrained_tokens");
+# Name of pretrained transformer model, str or null
+local pretrained_transformer_name = if std.length(std.extVar("pretrained_transformer_name")) > 0 then std.extVar("pretrained_transformer_name");
+# Learning rate value, float
+local learning_rate = 0.002;
+# Number of epochs, int
+local num_epochs = std.parseInt(std.extVar("num_epochs"));
+# Cuda device id, -1 for cpu, int
+local cuda_device = std.parseInt(std.extVar("cuda_device"));
+# Minimum number of words in a batch, int
+local word_batch_size = std.parseInt(std.extVar("word_batch_size"));
+# Features used as input, list of str
+# Choice "upostag", "xpostag", "lemma"
+# Required "token", "char"
+local features = std.split(std.extVar("features"), " ");
+# Targets of the model, list of str
+# Choice "feats", "lemma", "upostag", "xpostag", "semrel", "sent"
"sent" +# Required "deprel", "head" +local targets = std.split(std.extVar("targets"), " "); +# Word embedding dimension, int +# If pretrained_tokens is not null must much provided dimensionality +local embedding_dim = std.parseInt(std.extVar("embedding_dim")); +# Dropout rate on predictors, float +# All of the models on top of the encoder uses this dropout +local predictors_dropout = 0.25; +# Xpostag embedding dimension, int +# (discarded if xpostag not in features) +local xpostag_dim = 32; +# Upostag embedding dimension, int +# (discarded if upostag not in features) +local upostag_dim = 32; +# Feats embedding dimension, int +# (discarded if feats not in featres) +local feats_dim = 32; +# Lemma embedding dimension, int +# (discarded if lemma not in features) +local lemma_char_dim = 64; +# Character embedding dim, int +local char_dim = 64; +# Word embedding projection dim, int +local projected_embedding_dim = 100; +# Loss weights, dict[str, int] +local loss_weights = { + xpostag: 0.05, + upostag: 0.05, + lemma: 0.05, + feats: 0.2, + deprel: 0.8, + head: 0.2, + semrel: 0.05, +}; +# Encoder hidden size, int +local hidden_size = 512; +# Number of layers in the encoder, int +local num_layers = 2; +# Cycle loss iterations, int +local cycle_loss_n = 0; +# Maximum length of the word, int +# Shorter words are padded, longer - truncated +local word_length = 30; +# Whether to use tensorboard, bool +local use_tensorboard = if std.extVar("use_tensorboard") == "True" then true else false; + +# Helper functions +local in_features(name) = !(std.length(std.find(name, features)) == 0); +local in_targets(name) = !(std.length(std.find(name, targets)) == 0); +local use_transformer = pretrained_transformer_name != null; + +# Verify some configuration requirements +assert in_features("token"): "Key 'token' must be in features!"; +assert in_features("char"): "Key 'char' must be in features!"; + +assert in_targets("deprel"): "Key 'deprel' must be in targets!"; +assert in_targets("head"): "Key 'head' must be in targets!"; + +assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't use pretrained tokens and pretrained transformer at the same time!"; + +######################################################################################## +# ADVANCED configuration # +######################################################################################## + +# Detailed dataset, training, vocabulary and model configuration. 
+{ + # Configuration type (default or finetuning), str + type: std.extVar('type'), + # Datasets used for vocab creation, list of str + # Choice "train", "valid" + datasets_for_vocab_creation: ['train'], + # Path to training data, str + train_data_path: { + conllu: training_data_path, + # TODO Add configuration + iob: "./data/nkjp-nested-simplified-v2.fixed.iob", + }, + # Path to validation data, str + validation_data_path: validation_data_path, + # Dataset reader configuration (conllu format) + dataset_reader: { + type: "multitask", + readers: { + "conllu": { + type: "conllu", + features: features, + targets: targets, + # Whether data contains semantic relation field, bool + use_sem: if in_targets("semrel") then true else false, + token_indexers: { + token: if use_transformer then { + type: "pretrained_transformer_mismatched", + model_name: pretrained_transformer_name, + } else { + # SingleIdTokenIndexer, token as single int + type: "single_id", + }, + upostag: { + type: "single_id", + namespace: "upostag", + feature_name: "pos_", + }, + xpostag: { + type: "single_id", + namespace: "xpostag", + feature_name: "tag_", + }, + lemma: { + type: "characters_const_padding", + character_tokenizer: { + start_tokens: ["__START__"], + end_tokens: ["__END__"], + }, + # +2 for start and end token + min_padding_length: word_length + 2, + }, + char: { + type: "characters_const_padding", + character_tokenizer: { + start_tokens: ["__START__"], + end_tokens: ["__END__"], + }, + # +2 for start and end token + min_padding_length: word_length + 2, + }, + feats: { + type: "feats_indexer", + }, + }, + lemma_indexers: { + char: { + type: "characters_const_padding", + namespace: "lemma_characters", + character_tokenizer: { + start_tokens: ["__START__"], + end_tokens: ["__END__"], + }, + # +2 for start and end token + min_padding_length: word_length + 2, + }, + }, + }, + iob: { + type: "iob", + tag_label: "ner", + label_namespace: "ner_labels", + token_indexers: { + token: if use_transformer then { + type: "pretrained_transformer_mismatched", + model_name: pretrained_transformer_name, + } else { + # SingleIdTokenIndexer, token as single int + type: "single_id", + }, + char: { + type: "characters_const_padding", + character_tokenizer: { + start_tokens: ["__START__"], + end_tokens: ["__END__"], + }, + # +2 for start and end token + min_padding_length: word_length + 2, + }, + }, + }, + }, + }, + # Data loader configuration + data_loader: { + type: "multitask", + scheduler: { + batch_size: 10 + }, + shuffle: true, +// batch_sampler: { +// type: "token_count", +// word_batch_size: word_batch_size, +// }, + }, + # Vocabulary configuration + vocabulary: std.prune({ + type: "from_instances_extended", + only_include_pretrained_words: true, + pretrained_files: { + tokens: pretrained_tokens, + }, + oov_token: "_", + padding_token: "__PAD__", + non_padded_namespaces: ["head_labels"], + }), + model: std.prune({ + type: "multitask_extended", + backbone: { + type: "combo_backbone", + text_field_embedder: { + type: "basic", + token_embedders: { + xpostag: if in_features("xpostag") then { + type: "embedding", + padding_index: 0, + embedding_dim: xpostag_dim, + vocab_namespace: "xpostag", + }, + upostag: if in_features("upostag") then { + type: "embedding", + padding_index: 0, + embedding_dim: upostag_dim, + vocab_namespace: "upostag", + }, + token: if use_transformer then { + type: "transformers_word_embeddings", + model_name: pretrained_transformer_name, + projection_dim: projected_embedding_dim, + } else { + type: 
"embeddings_projected", + embedding_dim: embedding_dim, + projection_layer: { + in_features: embedding_dim, + out_features: projected_embedding_dim, + dropout_rate: 0.25, + activation: "tanh" + }, + vocab_namespace: "tokens", + pretrained_file: pretrained_tokens, + trainable: if pretrained_tokens == null then true else false, + }, + char: { + type: "char_embeddings_from_config", + embedding_dim: char_dim, + dilated_cnn_encoder: { + input_dim: char_dim, + filters: [512, 256, char_dim], + kernel_size: [3, 3, 3], + stride: [1, 1, 1], + padding: [1, 2, 4], + dilation: [1, 2, 4], + activations: ["relu", "relu", "linear"], + }, + }, + lemma: if in_features("lemma") then { + type: "char_embeddings_from_config", + embedding_dim: lemma_char_dim, + dilated_cnn_encoder: { + input_dim: lemma_char_dim, + filters: [512, 256, lemma_char_dim], + kernel_size: [3, 3, 3], + stride: [1, 1, 1], + padding: [1, 2, 4], + dilation: [1, 2, 4], + activations: ["relu", "relu", "linear"], + }, + }, + feats: if in_features("feats") then { + type: "feats_embedding", + padding_index: 0, + embedding_dim: feats_dim, + vocab_namespace: "feats", + }, + }, + }, + seq_encoder: { + type: "combo_encoder", + layer_dropout_probability: 0.33, + stacked_bilstm: { + input_size: + (char_dim + projected_embedding_dim + + (if in_features('xpostag') then xpostag_dim else 0) + + (if in_features('lemma') then lemma_char_dim else 0) + + (if in_features('upostag') then upostag_dim else 0) + + (if in_features('feats') then feats_dim else 0)), + hidden_size: hidden_size, + num_layers: num_layers, + recurrent_dropout_probability: 0.33, + layer_dropout_probability: 0.33 + }, + } + }, + heads: { + conllu: { + type: "semantic_multitask_head", + loss_weights: loss_weights, + dependency_relation: { + type: "combo_dependency_parsing_from_vocab", + vocab_namespace: 'deprel_labels', + head_predictor: { + local projection_dim = 512, + cycle_loss_n: cycle_loss_n, + head_projection_layer: { + in_features: hidden_size * 2, + out_features: projection_dim, + activation: "tanh", + }, + dependency_projection_layer: { + in_features: hidden_size * 2, + out_features: projection_dim, + activation: "tanh", + }, + }, + local projection_dim = 128, + head_projection_layer: { + in_features: hidden_size * 2, + out_features: projection_dim, + dropout_rate: predictors_dropout, + activation: "tanh" + }, + dependency_projection_layer: { + in_features: hidden_size * 2, + out_features: projection_dim, + dropout_rate: predictors_dropout, + activation: "tanh" + }, + }, + morphological_feat: if in_targets("feats") then { + type: "combo_morpho_from_vocab", + vocab_namespace: "feats_labels", + input_dim: hidden_size * 2, + hidden_dims: [128], + activations: ["tanh", "linear"], + dropout: [predictors_dropout, 0.0], + num_layers: 2, + }, + lemmatizer: if in_targets("lemma") then shared_config.Model.lemma(hidden_size, predictors_dropout), + upos_tagger: if in_targets("upostag") then { + input_dim: hidden_size * 2, + hidden_dims: [64], + activations: ["tanh", "linear"], + dropout: [predictors_dropout, 0.0], + num_layers: 2, + vocab_namespace: "upostag_labels" + }, + xpos_tagger: if in_targets("xpostag") then { + input_dim: hidden_size * 2, + hidden_dims: [128], + activations: ["tanh", "linear"], + dropout: [predictors_dropout, 0.0], + num_layers: 2, + vocab_namespace: "xpostag_labels" + }, + semantic_relation: if in_targets("semrel") then { + input_dim: hidden_size * 2, + hidden_dims: [64], + activations: ["tanh", "linear"], + dropout: [predictors_dropout, 0.0], + num_layers: 2, + 
vocab_namespace: "semrel_labels" + }, + regularizer: { + regexes: [ + [".*conv1d.*", {type: "l2", alpha: 1e-6}], + [".*forward.*", {type: "l2", alpha: 1e-6}], + [".*backward.*", {type: "l2", alpha: 1e-6}], + [".*char_embed.*", {type: "l2", alpha: 1e-5}], + ], + }, + }, + iob: { + type: "ner_head", + feedforward_predictor: { + type: "feedforward_predictor_from_vocab", + input_dim: hidden_size * 2, + hidden_dims: [128], + activations: ["tanh", "linear"], + dropout: [predictors_dropout, 0.0], + num_layers: 2, + vocab_namespace: "ner_labels" + } + }, + }, + }), + trainer: shared_config.Trainer(cuda_device, num_epochs, learning_rate, use_tensorboard), + random_seed: 8787, + pytorch_seed: 8787, + numpy_seed: 8787, +} diff --git a/combo/config.shared.libsonnet b/combo/config.shared.libsonnet index 23d3f9a8717dc35bf3a33bd112d13de6daab13dc..bb262f6a64cd55e00df567ef5484930e679dd9c9 100644 --- a/combo/config.shared.libsonnet +++ b/combo/config.shared.libsonnet @@ -25,5 +25,27 @@ }, validation_metric: "+EM", }), - Trainer: trainer + + local lemma(hidden_size, dropout) = { + type: "combo_lemma_predictor_from_vocab", + char_vocab_namespace: "token_characters", + lemma_vocab_namespace: "lemma_characters", + embedding_dim: 256, + input_projection_layer: { + in_features: hidden_size * 2, + out_features: 32, + dropout_rate: dropout, + activation: "tanh" + }, + filters: [256, 256, 256], + kernel_size: [3, 3, 3, 1], + stride: [1, 1, 1, 1], + padding: [1, 2, 4, 0], + dilation: [1, 2, 4, 1], + activations: ["relu", "relu", "relu", "linear"], + }, + Trainer: trainer, + Model: { + lemma: lemma, + } } \ No newline at end of file diff --git a/combo/data/dataset.py b/combo/data/dataset.py index 870659fe27857d157e5061f77690f61f166330a7..6dc30b66315956de03495ad942c68ce0a3a33cf9 100644 --- a/combo/data/dataset.py +++ b/combo/data/dataset.py @@ -1,4 +1,5 @@ import copy +import itertools import logging import pathlib from typing import Union, List, Dict, Iterable, Optional, Any, Tuple @@ -6,8 +7,9 @@ from typing import Union, List, Dict, Iterable, Optional, Any, Tuple import conllu import torch from allennlp import data as allen_data -from allennlp.common import checks, util +from allennlp.common import checks, util, file_utils from allennlp.data import fields as allen_fields, vocabulary +from allennlp.data.dataset_readers import dataset_reader, conll2003 from conllu import parser from dataclasses import dataclass from overrides import overrides @@ -17,6 +19,36 @@ from combo.data import fields logger = logging.getLogger(__name__) +@allen_data.DatasetReader.register("iob") +class IOBDatasetReader(conll2003.Conll2003DatasetReader): + """Extension of the AllenNLP Conll2003DatasetReader with tab as a separator.""" + + def _read(self, file_path: dataset_reader.PathOrStr) -> Iterable[allen_data.Instance]: + # if `file_path` is a URL, redirect to the cache + file_path = file_utils.cached_path(file_path) + + with open(file_path, "r") as data_file: + logger.info("Reading instances from lines in file at: %s", file_path) + + # Group lines into sentence chunks based on the divider. + line_chunks = ( + lines + for is_divider, lines in itertools.groupby(data_file, conll2003._is_divider) + # Ignore the divider chunks, so that `lines` corresponds to the words + # of a single sentence. 
+ if not is_divider + ) + for lines in self.shard_iterable(line_chunks): + fields = [line.strip().split("\t") for line in lines] + # unzipping trick returns tuples, but our Fields need lists + fields = [list(field) for field in zip(*fields)] + tokens_, pos_tags, chunk_tags, ner_tags = fields + # TextField requires `Token` objects + tokens = [allen_data.Token(token) for token in tokens_] + + yield self.text_to_instance(tokens, pos_tags, chunk_tags, ner_tags) + + @allen_data.DatasetReader.register("conllu") class UniversalDependenciesDatasetReader(allen_data.DatasetReader): @@ -99,7 +131,7 @@ class UniversalDependenciesDatasetReader(allen_data.DatasetReader): # features text_field = allen_fields.TextField(tokens, self._token_indexers) - fields_["sentence"] = text_field + fields_["tokens"] = text_field # targets if self.generate_labels: diff --git a/combo/data/samplers/samplers.py b/combo/data/samplers/samplers.py index a1754985f115d95995f48a2f99d5a4485d5faa40..2dbc4221cf05902274767b6e3337ee0a90856099 100644 --- a/combo/data/samplers/samplers.py +++ b/combo/data/samplers/samplers.py @@ -18,7 +18,7 @@ class TokenCountBatchSampler(allen_data.BatchSampler): batches = [] batch = [] words_count = 0 - lengths = [len(instance.fields["sentence"].tokens) for instance in dataset] + lengths = [len(instance.fields["tokens"].tokens) for instance in dataset] argsorted_lengths = np.argsort(lengths) for idx in argsorted_lengths: words_count += lengths[idx] diff --git a/combo/models/__init__.py b/combo/models/__init__.py index ec7a1380e1cfc80b0302806e46cca4e5fc2d3568..04079101024eb7f7031fac63c37cceba2f322388 100644 --- a/combo/models/__init__.py +++ b/combo/models/__init__.py @@ -7,3 +7,4 @@ from .encoder import ComboEncoder from .lemma import LemmatizerModel from .model import ComboModel from .morpho import MorphologicalFeatures +from .multitask import MultiTaskModel diff --git a/combo/models/base.py b/combo/models/base.py index a5cb5fe61f85a98f78d143a54695d01948aa8dda..786ef245e1f7bb4860d94288de3dcb49f2b1afb1 100644 --- a/combo/models/base.py +++ b/combo/models/base.py @@ -106,7 +106,8 @@ class FeedForwardPredictor(Predictor): ) assert vocab_namespace in vocab.get_namespaces(), \ - f"There is not {vocab_namespace} in created vocabs, check if this field has any values to predict!" + f"There is not {vocab_namespace} in created vocabs ({','.join(vocab.get_namespaces())}), " \ + f"check if this field has any values to predict!" 
hidden_dims = hidden_dims + [vocab.get_vocab_size(vocab_namespace)] return cls(feedforward.FeedForward( diff --git a/combo/models/model.py b/combo/models/model.py index 9866bcb4fba41ed2506b2d33290e6cd0fe237d29..4c9d1bee43c6460c2ed636c3fa015ddeeafecce1 100644 --- a/combo/models/model.py +++ b/combo/models/model.py @@ -1,25 +1,79 @@ """Main COMBO model.""" -from typing import Optional, Dict, Any, List +from typing import Optional, Dict, Any, List, Union import torch -from allennlp import data, modules, models as allen_models, nn as allen_nn +from allennlp import data, modules, nn as allen_nn +from allennlp.models import heads from allennlp.modules import text_field_embedders from allennlp.nn import util +from allennlp.training import metrics as allen_metrics from overrides import overrides from combo.models import base from combo.utils import metrics -@allen_models.Model.register("semantic_multitask") -class ComboModel(allen_models.Model): +@modules.Backbone.register("combo_backbone") +class ComboBackbone(modules.Backbone): + + def __init__(self, text_field_embedder: text_field_embedders.TextFieldEmbedder, + seq_encoder: modules.Seq2SeqEncoder): + super().__init__() + self.text_field_embedder = text_field_embedder + self.seq_encoder = seq_encoder + + def forward(self, tokens: Dict[str, Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]: + char_mask = tokens["char"]["token_characters"].gt(0) + word_mask = util.get_text_field_mask(tokens) + encoder_input = self.text_field_embedder(tokens, char_mask=char_mask) + return dict(encoder_emb=self.seq_encoder(encoder_input, word_mask), + word_mask=word_mask, + char_mask=char_mask) + + +@heads.Head.register("ner_head") +class NERModel(heads.Head): + + def __init__(self, feedforward_predictor: base.Predictor, vocab: data.Vocabulary): + super().__init__(vocab) + self.feedforward_predictor = feedforward_predictor + self._accuracy_metric = allen_metrics.CategoricalAccuracy() + # self._f1_metric = allen_metrics.SpanBasedF1Measure(vocab, tag_namespace="ner_labels", label_encoding="IOB1", + # ignore_classes=["_"]) + self._loss = 0.0 + + def forward(self, + encoder_emb: Union[torch.Tensor, List[torch.Tensor]], + word_mask: Optional[torch.BoolTensor] = None, + tags: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None) -> Dict[str, torch.Tensor]: + output = self.feedforward_predictor( + x=encoder_emb, + mask=word_mask, + labels=tags, + ) + + if tags is not None: + self._loss = output["loss"] + self._accuracy_metric(output["probability"], tags, word_mask) + # self._f1_metric(output["probability"], tags, word_mask) + + return output + + @overrides + def get_metrics(self, reset: bool = False) -> Dict[str, float]: + return { + **{"accuracy": self._accuracy_metric.get_metric(reset), "loss": self._loss}, + # **self._f1_metric.get_metric(reset) + } + + +@heads.Head.register("semantic_multitask_head") +class ComboModel(heads.Head): """Main COMBO model.""" def __init__(self, vocab: data.Vocabulary, loss_weights: Dict[str, float], - text_field_embedder: text_field_embedders.TextFieldEmbedder, - seq_encoder: modules.Seq2SeqEncoder, use_sample_weight: bool = True, lemmatizer: Optional[base.Predictor] = None, upos_tagger: Optional[base.Predictor] = None, @@ -30,10 +84,8 @@ class ComboModel(allen_models.Model): enhanced_dependency_relation: Optional[base.Predictor] = None, regularizer: allen_nn.RegularizerApplicator = None) -> None: super().__init__(vocab, regularizer) - self.text_field_embedder = text_field_embedder self.loss_weights = loss_weights 
self.use_sample_weight = use_sample_weight - self.seq_encoder = seq_encoder self.lemmatizer = lemmatizer self.upos_tagger = upos_tagger self.xpos_tagger = xpos_tagger @@ -41,13 +93,16 @@ class ComboModel(allen_models.Model): self.morphological_feat = morphological_feat self.dependency_relation = dependency_relation self.enhanced_dependency_relation = enhanced_dependency_relation - self._head_sentinel = torch.nn.Parameter(torch.randn([1, 1, self.seq_encoder.get_output_dim()])) + self._head_sentinel = torch.nn.Parameter(torch.randn([1, 1, 1024])) self.scores = metrics.SemanticMetrics() self._partial_losses = None @overrides def forward(self, - sentence: Dict[str, Dict[str, torch.Tensor]], + tokens: Dict[str, Dict[str, torch.Tensor]], + encoder_emb: torch.Tensor, + word_mask: torch.Tensor, + char_mask: torch.Tensor, metadata: List[Dict[str, Any]], upostag: torch.Tensor = None, xpostag: torch.Tensor = None, @@ -58,19 +113,11 @@ class ComboModel(allen_models.Model): semrel: torch.Tensor = None, enhanced_heads: torch.Tensor = None, enhanced_deprels: torch.Tensor = None) -> Dict[str, torch.Tensor]: - - # Prepare masks - char_mask = sentence["char"]["token_characters"].gt(0) - word_mask = util.get_text_field_mask(sentence) - - device = word_mask.device + device = encoder_emb.device # If enabled weight samples loss by log(sentence_length) sample_weights = word_mask.sum(-1).float().log() if self.use_sample_weight else None - encoder_input = self.text_field_embedder(sentence, char_mask=char_mask) - encoder_emb = self.seq_encoder(encoder_input, word_mask) - batch_size, _, encoding_dim = encoder_emb.size() # Concatenate the head sentinel (ROOT) onto the sentence representation. @@ -99,8 +146,8 @@ class ComboModel(allen_models.Model): labels=feats, sample_weights=sample_weights) lemma_output = self._optional(self.lemmatizer, - (encoder_emb, sentence.get("char").get("token_characters") - if sentence.get("char") else None), + (encoder_emb, tokens.get("char").get("token_characters") + if tokens.get("char") else None), mask=word_mask, labels=lemma.get("char").get("token_characters") if lemma else None, sample_weights=sample_weights) diff --git a/combo/models/multitask.py b/combo/models/multitask.py new file mode 100644 index 0000000000000000000000000000000000000000..d557591903baec157b6ebf3c6e184c01ba30daed --- /dev/null +++ b/combo/models/multitask.py @@ -0,0 +1,66 @@ +import collections +from typing import Mapping, List, Dict, Union + +import torch +from allennlp import models + + +@models.Model.register("multitask_extended") +class MultiTaskModel(models.MultiTaskModel): + """Extension of the AllenNLP MultiTaskModel to handle dictionary inputs.""" + + def forward(self, **kwargs) -> Dict[str, torch.Tensor]: # type: ignore + if "task" not in kwargs: + raise ValueError( + "Instances for multitask training need to contain a MetadataField with " + "the name 'task' to indicate which task they belong to. Usually the " + "MultitaskDataLoader provides this field and you don't have to do anything." 
+ ) + + task_indices_just_for_mypy: Mapping[str, List[int]] = collections.defaultdict(lambda: []) + for i, task in enumerate(kwargs["task"]): + task_indices_just_for_mypy[task].append(i) + task_indices: Dict[str, torch.LongTensor] = { + task: torch.LongTensor(indices) for task, indices in task_indices_just_for_mypy.items() + } + + def make_inputs_for_task(task: str, whole_batch_input: Union[torch.Tensor, List, Dict]): + if isinstance(whole_batch_input, torch.Tensor): + task_indices[task] = task_indices[task].to(whole_batch_input.device) + return torch.index_select(whole_batch_input, 0, task_indices[task]) + if isinstance(whole_batch_input, dict): + return {k: make_inputs_for_task(task, v) for k, v in whole_batch_input.items()} + else: + return [whole_batch_input[i] for i in task_indices[task]] + + backbone_arguments = self._get_arguments(kwargs, "backbone") + backbone_outputs = self._backbone(**backbone_arguments) + combined_arguments = {**backbone_outputs, **kwargs} + + outputs = {**backbone_outputs} + loss = None + for head_name in self._heads: + if head_name not in task_indices: + continue + + head_arguments = self._get_arguments(combined_arguments, head_name) + head_arguments = { + key: make_inputs_for_task(head_name, value) for key, value in head_arguments.items() + } + + head_outputs = self._heads[head_name](**head_arguments) + for key in head_outputs: + outputs[f"{head_name}_{key}"] = head_outputs[key] + + if "loss" in head_outputs: + self._heads_called.add(head_name) + head_loss = self._loss_weights[head_name] * head_outputs["loss"] + if loss is None: + loss = head_loss + else: + loss += head_loss + + if loss is not None: + outputs["loss"] = loss + + return outputs diff --git a/scripts/fix_iob.py b/scripts/fix_iob.py new file mode 100644 index 0000000000000000000000000000000000000000..976e70472feec9fdb409402122f063b4f42039f2 --- /dev/null +++ b/scripts/fix_iob.py @@ -0,0 +1,35 @@ +"""Script which fixes -DOCSTART- misspellings.""" +import pathlib + +from absl import app +from absl import flags + +FLAGS = flags.FLAGS +flags.DEFINE_string(name="input_path", default=None, + help="Path to IOB file.") +flags.DEFINE_string(name="output_path", default=None, + help="Path to store fixed IOB file.") +flags.mark_flag_as_required("input_path") +flags.mark_flag_as_required("output_path") + + +def run(_): + input_path = pathlib.Path(FLAGS.input_path) + assert input_path.exists() and input_path.is_file(), "Input doesn't exists or is not a file." + with input_path.open("r") as input_fh: + with pathlib.Path(FLAGS.output_path).open("w") as output_fh: + for line in input_fh: + + # Replace -DOCSTART with -DOCSTART- + if "-DOCSTART" in line and "-DOCSTART-" not in line: + line = line.replace("-DOCSTART", "-DOCSTART-") + + output_fh.write(line) + + +def main(): + app.run(run) + + +if __name__ == "__main__": + main() diff --git a/tests/data/fields/test_samplers.py b/tests/data/fields/test_samplers.py index f13a26449c17da7810b811352afb930615020bcd..6a228b2a0892b67ebc6b5369899bdd291d069882 100644 --- a/tests/data/fields/test_samplers.py +++ b/tests/data/fields/test_samplers.py @@ -16,7 +16,7 @@ class TokenCountBatchSamplerTest(unittest.TestCase): tokens = [data.Token(t) for t in sentence.split()] text_field = fields.TextField(tokens, {}) - self.dataset.append(data.Instance({"sentence": text_field})) + self.dataset.append(data.Instance({"tokens": text_field})) def test_batches(self): # given