diff --git a/combo/config.multitask.template.jsonnet b/combo/config.multitask.template.jsonnet index a977819eadb339d94f2f59f2421c103c6d82667e..d7699796e2b9f154b8b02e050845cad5ba97e83a 100644 --- a/combo/config.multitask.template.jsonnet +++ b/combo/config.multitask.template.jsonnet @@ -199,7 +199,7 @@ assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't data_loader: { type: "multitask", scheduler: { - batch_size: 20 + batch_size: 100 }, shuffle: true, // batch_sampler: { @@ -384,16 +384,13 @@ assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't }, }, iob: { - type: "ner_head", - feedforward_predictor: { - type: "feedforward_predictor_from_vocab", - input_dim: hidden_size * 2, - hidden_dims: [128], - activations: ["tanh", "linear"], - dropout: [predictors_dropout, 0.0], - num_layers: 2, - vocab_namespace: "ner_labels" - } + type: "ner_head_from_vocab", + input_dim: hidden_size * 2, + hidden_dims: [128], + activations: ["tanh", "linear"], + dropout: [predictors_dropout, 0.0], + num_layers: 2, + vocab_namespace: "ner_labels" }, }, }), diff --git a/combo/models/model.py b/combo/models/model.py index dd3629ad803fdb79662a7346340248df85230c1e..a3db919d6f7cc2fd93942ccc4a8df6eb3503bc85 100644 --- a/combo/models/model.py +++ b/combo/models/model.py @@ -1,10 +1,11 @@ """Main COMBO model.""" -from typing import Optional, Dict, Any, List, Union +from typing import Optional, Dict, Any, List, Union, cast import torch from allennlp import data, modules, nn as allen_nn +from allennlp.common import checks from allennlp.models import heads -from allennlp.modules import text_field_embedders +from allennlp.modules import text_field_embedders, conditional_random_field as crf from allennlp.nn import util from allennlp.training import metrics as allen_metrics from overrides import overrides @@ -31,33 +32,86 @@ class ComboBackbone(modules.Backbone): char_mask=char_mask) -@heads.Head.register("ner_head") +@heads.Head.register("ner_head_from_vocab", constructor="from_vocab") class NERModel(heads.Head): + """Based on AllenNLP-models CrfTagger.""" - def __init__(self, feedforward_predictor: base.Predictor, vocab: data.Vocabulary): + def __init__(self, feedforward_network: modules.FeedForward, vocab: data.Vocabulary, + label_namespace: str = "ner_labels", + include_start_end_transitions: bool = True, + label_encoding: str = "IOB1"): super().__init__(vocab) - self.feedforward_predictor = feedforward_predictor + self._feedforward_network = feedforward_network + + labels = self.vocab.get_index_to_token_vocabulary(label_namespace) + constraints = crf.allowed_transitions(label_encoding, labels) + self.include_start_end_transitions = include_start_end_transitions + self.num_tags = self.vocab.get_vocab_size(label_namespace) + self._crf = modules.ConditionalRandomField( + self.num_tags, constraints, include_start_end_transitions=include_start_end_transitions + ) + self._accuracy_metric = allen_metrics.CategoricalAccuracy() - self._f1_metric = allen_metrics.SpanBasedF1Measure(vocab, tag_namespace="ner_labels", label_encoding="IOB1") + self._f1_metric = allen_metrics.SpanBasedF1Measure(vocab, + tag_namespace=label_namespace, + label_encoding=label_encoding) self._loss = 0.0 def forward(self, encoder_emb: Union[torch.Tensor, List[torch.Tensor]], word_mask: Optional[torch.BoolTensor] = None, tags: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None) -> Dict[str, torch.Tensor]: - output = self.feedforward_predictor( - x=encoder_emb, - mask=word_mask, - labels=tags, - ) + batch_size = encoder_emb.size(0) + logits = self._feedforward_network(encoder_emb) + best_paths = self._crf.viterbi_tags(logits, word_mask, top_k=1) + predicted_tags = cast(List[List[int]], [x[0][0] for x in best_paths]) + output = {"tags": predicted_tags} if tags is not None: + log_likelihood: torch.Tensor = self._crf(logits, tags, word_mask) / batch_size # Mean instead of sum. + output["loss"] = -log_likelihood + class_probabilities = logits * 0.0 + for i, instance_tags in enumerate(predicted_tags): + for j, tag_id in enumerate(instance_tags): + class_probabilities[i, j, tag_id] = 1 + self._loss = output["loss"].cpu().item() - self._accuracy_metric(output["probability"], tags, word_mask) - self._f1_metric(output["probability"], tags, word_mask) + self._accuracy_metric(class_probabilities, tags, word_mask) + self._f1_metric(class_probabilities, tags, word_mask) return output + @classmethod + def from_vocab(cls, + vocab: data.Vocabulary, + vocab_namespace: str, + input_dim: int, + num_layers: int, + hidden_dims: List[int], + activations: Union[allen_nn.Activation, List[allen_nn.Activation]], + dropout: Union[float, List[float]] = 0.0, + ): + if len(hidden_dims) + 1 != num_layers: + raise checks.ConfigurationError( + f"len(hidden_dims) ({len(hidden_dims):d}) + 1 != num_layers ({num_layers:d})" + ) + + assert vocab_namespace in vocab.get_namespaces(), \ + f"There is not {vocab_namespace} in created vocabs ({','.join(vocab.get_namespaces())}), " \ + f"check if this field has any values to predict!" + hidden_dims = hidden_dims + [vocab.get_vocab_size(vocab_namespace)] + return cls( + feedforward_network=modules.FeedForward( + input_dim=input_dim, + num_layers=num_layers, + hidden_dims=hidden_dims, + activations=activations, + dropout=dropout, + ), + label_namespace=vocab_namespace, + vocab=vocab + ) + @overrides def get_metrics(self, reset: bool = False) -> Dict[str, float]: metrics_ = {"accuracy": self._accuracy_metric.get_metric(reset), "loss": self._loss}