Commit 3c4991eb authored by Mateusz Klimaszewski

Add span F1 metrics for NER.

parent fb9eb280
Pipeline #2905 passed
@@ -216,7 +216,7 @@ assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't
     },
     oov_token: "_",
     padding_token: "__PAD__",
-    non_padded_namespaces: ["head_labels"],
+    non_padded_namespaces: ["head_labels", "ner_labels"],
   }),
   model: std.prune({
     type: "multitask_extended",
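Note on the config change: adding "ner_labels" to non_padded_namespaces means the AllenNLP Vocabulary will not reserve @@PADDING@@/@@UNKNOWN@@ entries in that namespace, so it contains only real tags and indices start at 0. A minimal sketch of that effect, assuming the standard allennlp.data.Vocabulary API; the tag values are illustrative, not taken from this repository.

from allennlp.data import Vocabulary

# Hypothetical vocabulary mirroring the config's non_padded_namespaces setting.
vocab = Vocabulary(non_padded_namespaces=["head_labels", "ner_labels"])
for tag in ["O", "B-PER", "I-PER"]:  # illustrative IOB1-style tags
    vocab.add_token_to_namespace(tag, namespace="ner_labels")

# No @@PADDING@@/@@UNKNOWN@@ entries are added to a non-padded namespace.
print(vocab.get_vocab_size("ner_labels"))        # 3
print(vocab.get_token_index("O", "ner_labels"))  # 0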
@@ -38,8 +38,8 @@ class NERModel(heads.Head):
         super().__init__(vocab)
         self.feedforward_predictor = feedforward_predictor
         self._accuracy_metric = allen_metrics.CategoricalAccuracy()
-        # self._f1_metric = allen_metrics.SpanBasedF1Measure(vocab, tag_namespace="ner_labels", label_encoding="IOB1",
-        #                                                    ignore_classes=["_"])
+        self._f1_metric = allen_metrics.SpanBasedF1Measure(vocab, tag_namespace="ner_labels", label_encoding="IOB1",
+                                                           ignore_classes=["_"])
         self._loss = 0.0

     def forward(self,
@@ -55,16 +55,19 @@ class NERModel(heads.Head):
         if tags is not None:
             self._loss = output["loss"]
             self._accuracy_metric(output["probability"], tags, word_mask)
-            # self._f1_metric(output["probability"], tags, word_mask)
+            self._f1_metric(output["probability"], tags, word_mask)
         return output

     @overrides
     def get_metrics(self, reset: bool = False) -> Dict[str, float]:
-        return {
-            **{"accuracy": self._accuracy_metric.get_metric(reset), "loss": self._loss},
-            # **self._f1_metric.get_metric(reset)
-        }
+        metrics_ = {"accuracy": self._accuracy_metric.get_metric(reset), "loss": self._loss}
+        for name, value in self._f1_metric.get_metric(reset).items():
+            if "overall" in name:
+                metrics_[name] = value
+            else:
+                metrics_[f"_{name}"] = value
+        return metrics_


 @heads.Head.register("semantic_multitask_head")
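Note on the metric filtering above: SpanBasedF1Measure.get_metric returns per-label precision/recall/F1 alongside aggregate keys containing "overall" (e.g. "f1-measure-overall"), so the loop keeps the aggregates visible and hides per-label scores behind a "_" prefix. A standalone sketch of that filtering, assuming a metric dict shaped like SpanBasedF1Measure's output; the numbers are made up.

from typing import Dict

def split_f1_metrics(f1_metrics: Dict[str, float]) -> Dict[str, float]:
    """Keep "overall" aggregates visible and hide per-label scores behind a "_" prefix."""
    metrics_: Dict[str, float] = {}
    for name, value in f1_metrics.items():
        if "overall" in name:
            metrics_[name] = value
        else:
            metrics_[f"_{name}"] = value
    return metrics_

# Illustrative values only; the key names follow SpanBasedF1Measure's convention.
example = {
    "precision-overall": 0.91, "recall-overall": 0.89, "f1-measure-overall": 0.90,
    "precision-PER": 0.95, "recall-PER": 0.93, "f1-measure-PER": 0.94,
}
print(split_f1_metrics(example))
# {'precision-overall': 0.91, ..., '_precision-PER': 0.95, ...}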
@@ -3,11 +3,12 @@ from typing import Mapping, List, Dict, Union
 import torch
 from allennlp import models
+from overrides import overrides


 @models.Model.register("multitask_extended")
 class MultiTaskModel(models.MultiTaskModel):
-    """Extension of the AllenNLP MultiTaskModel to handle dictionary inputs."""
+    """Extension of the AllenNLP MultiTaskModel to handle dictionary inputs and silence _ prefixed metrics."""

     def forward(self, **kwargs) -> Dict[str, torch.Tensor]:  # type: ignore
         if "task" not in kwargs:
@@ -64,3 +65,18 @@ class MultiTaskModel(models.MultiTaskModel):
             outputs["loss"] = loss

         return outputs
+
+    @overrides
+    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
+        metrics = {}
+        for head_name in self._heads_called:
+            for key, value in self._heads[head_name].get_metrics(reset).items():
+                # Metrics starting with "_" should be silenced.
+                if key.startswith("_"):
+                    metrics[f"_{head_name}{key}"] = value
+                else:
+                    metrics[f"{head_name}_{key}"] = value
+        if reset:
+            self._heads_called.clear()
+        return metrics
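Note on the override above: the "_" prefix matters because AllenNLP's trainer conventionally skips underscore-prefixed metric keys when printing the progress-bar description, while still tracking them, which is what "silence" means here. A minimal sketch of the flattening rule in isolation; the head name and values are hypothetical.

from typing import Dict

def combine_head_metrics(per_head: Dict[str, Dict[str, float]]) -> Dict[str, float]:
    """Flatten per-head metrics, keeping hidden ("_"-prefixed) metrics hidden."""
    metrics: Dict[str, float] = {}
    for head_name, head_metrics in per_head.items():
        for key, value in head_metrics.items():
            if key.startswith("_"):
                metrics[f"_{head_name}{key}"] = value   # e.g. "_ner_precision-PER"
            else:
                metrics[f"{head_name}_{key}"] = value   # e.g. "ner_accuracy"
    return metrics

# Hypothetical head name and values, mirroring the get_metrics shape above.
print(combine_head_metrics({
    "ner": {"accuracy": 0.97, "f1-measure-overall": 0.90, "_precision-PER": 0.95},
}))
# {'ner_accuracy': 0.97, 'ner_f1-measure-overall': 0.9, '_ner_precision-PER': 0.95}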