Skip to content
Snippets Groups Projects
Commit 984e06ef authored by Mateusz Klimaszewski's avatar Mateusz Klimaszewski
Browse files

Add CRF to NER.

parent f3dcbc39
Branches
No related merge requests found
Pipeline #2914 passed with stage
in 3 minutes and 56 seconds
...@@ -199,7 +199,7 @@ assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't ...@@ -199,7 +199,7 @@ assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't
data_loader: { data_loader: {
type: "multitask", type: "multitask",
scheduler: { scheduler: {
batch_size: 20 batch_size: 100
}, },
shuffle: true, shuffle: true,
// batch_sampler: { // batch_sampler: {
...@@ -384,16 +384,13 @@ assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't ...@@ -384,16 +384,13 @@ assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't
}, },
}, },
iob: { iob: {
type: "ner_head", type: "ner_head_from_vocab",
feedforward_predictor: { input_dim: hidden_size * 2,
type: "feedforward_predictor_from_vocab", hidden_dims: [128],
input_dim: hidden_size * 2, activations: ["tanh", "linear"],
hidden_dims: [128], dropout: [predictors_dropout, 0.0],
activations: ["tanh", "linear"], num_layers: 2,
dropout: [predictors_dropout, 0.0], vocab_namespace: "ner_labels"
num_layers: 2,
vocab_namespace: "ner_labels"
}
}, },
}, },
}), }),
......
"""Main COMBO model.""" """Main COMBO model."""
from typing import Optional, Dict, Any, List, Union from typing import Optional, Dict, Any, List, Union, cast
import torch import torch
from allennlp import data, modules, nn as allen_nn from allennlp import data, modules, nn as allen_nn
from allennlp.common import checks
from allennlp.models import heads from allennlp.models import heads
from allennlp.modules import text_field_embedders from allennlp.modules import text_field_embedders, conditional_random_field as crf
from allennlp.nn import util from allennlp.nn import util
from allennlp.training import metrics as allen_metrics from allennlp.training import metrics as allen_metrics
from overrides import overrides from overrides import overrides
...@@ -31,33 +32,86 @@ class ComboBackbone(modules.Backbone): ...@@ -31,33 +32,86 @@ class ComboBackbone(modules.Backbone):
char_mask=char_mask) char_mask=char_mask)
@heads.Head.register("ner_head_from_vocab", constructor="from_vocab")
class NERModel(heads.Head):
    """NER head with CRF decoding. Based on AllenNLP-models CrfTagger.

    A feed-forward network maps encoder embeddings to per-tag logits; a
    ConditionalRandomField layer scores gold tag sequences (training loss)
    and Viterbi-decodes predictions under ``label_encoding`` transition
    constraints.
    """

    def __init__(self,
                 feedforward_network: modules.FeedForward,
                 vocab: data.Vocabulary,
                 label_namespace: str = "ner_labels",
                 include_start_end_transitions: bool = True,
                 label_encoding: str = "IOB1"):
        super().__init__(vocab)
        self._feedforward_network = feedforward_network
        labels = self.vocab.get_index_to_token_vocabulary(label_namespace)
        # Forbid tag transitions that are invalid under the chosen scheme (e.g. IOB1).
        constraints = crf.allowed_transitions(label_encoding, labels)
        self.include_start_end_transitions = include_start_end_transitions
        self.num_tags = self.vocab.get_vocab_size(label_namespace)
        self._crf = modules.ConditionalRandomField(
            self.num_tags, constraints, include_start_end_transitions=include_start_end_transitions
        )
        self._accuracy_metric = allen_metrics.CategoricalAccuracy()
        self._f1_metric = allen_metrics.SpanBasedF1Measure(vocab,
                                                           tag_namespace=label_namespace,
                                                           label_encoding=label_encoding)
        # Last computed loss value, exposed through get_metrics().
        self._loss = 0.0

    def forward(self,
                encoder_emb: Union[torch.Tensor, List[torch.Tensor]],
                word_mask: Optional[torch.BoolTensor] = None,
                tags: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None) -> Dict[str, torch.Tensor]:
        """Predict tag sequences; when gold ``tags`` are given, also compute loss and metrics.

        Returns a dict with ``"tags"`` (Viterbi-decoded tag id lists) and, when
        ``tags`` is provided, ``"loss"`` (negative mean CRF log-likelihood).
        """
        # assumes encoder_emb is a (batch, seq, dim) tensor here — TODO confirm against callers
        batch_size = encoder_emb.size(0)
        logits = self._feedforward_network(encoder_emb)
        best_paths = self._crf.viterbi_tags(logits, word_mask, top_k=1)
        # viterbi_tags yields (tag_sequence, score) pairs; keep only the sequences.
        predicted_tags = cast(List[List[int]], [x[0][0] for x in best_paths])

        output = {"tags": predicted_tags}

        if tags is not None:
            log_likelihood: torch.Tensor = self._crf(logits, tags, word_mask) / batch_size  # Mean instead of sum.
            output["loss"] = -log_likelihood

            # One-hot encode the Viterbi paths so categorical metrics can consume them.
            class_probabilities = logits * 0.0
            for i, instance_tags in enumerate(predicted_tags):
                for j, tag_id in enumerate(instance_tags):
                    class_probabilities[i, j, tag_id] = 1

            self._loss = output["loss"].cpu().item()
            self._accuracy_metric(class_probabilities, tags, word_mask)
            self._f1_metric(class_probabilities, tags, word_mask)
        return output

    @classmethod
    def from_vocab(cls,
                   vocab: data.Vocabulary,
                   vocab_namespace: str,
                   input_dim: int,
                   num_layers: int,
                   hidden_dims: List[int],
                   activations: Union[allen_nn.Activation, List[allen_nn.Activation]],
                   dropout: Union[float, List[float]] = 0.0,
                   ):
        """Alternate constructor: build the feed-forward projection from config + vocab.

        ``hidden_dims`` lists only the intermediate layer sizes; the output layer
        size is appended automatically from the label vocabulary, hence the
        ``len(hidden_dims) + 1 == num_layers`` requirement.

        Raises:
            checks.ConfigurationError: on a layer-count mismatch or a missing
                vocabulary namespace.
        """
        if len(hidden_dims) + 1 != num_layers:
            raise checks.ConfigurationError(
                f"len(hidden_dims) ({len(hidden_dims):d}) + 1 != num_layers ({num_layers:d})"
            )
        # Raise instead of assert: input validation must survive `python -O`,
        # and this matches the ConfigurationError style of the check above.
        if vocab_namespace not in vocab.get_namespaces():
            raise checks.ConfigurationError(
                f"There is not {vocab_namespace} in created vocabs ({','.join(vocab.get_namespaces())}), "
                f"check if this field has any values to predict!"
            )
        # Final layer projects onto the label space.
        hidden_dims = hidden_dims + [vocab.get_vocab_size(vocab_namespace)]
        return cls(
            feedforward_network=modules.FeedForward(
                input_dim=input_dim,
                num_layers=num_layers,
                hidden_dims=hidden_dims,
                activations=activations,
                dropout=dropout,
            ),
            label_namespace=vocab_namespace,
            vocab=vocab
        )
@overrides @overrides
def get_metrics(self, reset: bool = False) -> Dict[str, float]: def get_metrics(self, reset: bool = False) -> Dict[str, float]:
metrics_ = {"accuracy": self._accuracy_metric.get_metric(reset), "loss": self._loss} metrics_ = {"accuracy": self._accuracy_metric.get_metric(reset), "loss": self._loss}
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment