From 3c4991ebbada0f94702b1b806cc1f114d0349f11 Mon Sep 17 00:00:00 2001
From: Mateusz Klimaszewski <mk.klimaszewski@gmail.com>
Date: Thu, 22 Apr 2021 08:39:09 +0200
Subject: [PATCH] Add span F1 metrics for NER.

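Re-enable the SpanBasedF1Measure that was previously commented out in the
NER head. The head reports the aggregated "overall" precision, recall and
F1 as-is and prefixes the per-tag scores with "_". The extended
MultiTaskModel overrides get_metrics() to namespace head metrics while
preserving the "_" prefix that AllenNLP uses to keep a metric out of the
progress logs. "ner_labels" is also registered as a non-padded namespace
so the tag vocabulary seen by the span metric contains only real labels.

For illustration (numbers are made up), the NER head's get_metrics() now
returns something along the lines of:

    {"accuracy": 0.97, "loss": 0.41,
     "precision-overall": 0.84, "recall-overall": 0.81,
     "f1-measure-overall": 0.82,
     "_precision-PER": 0.88, "_recall-PER": 0.85, "_f1-measure-PER": 0.86}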
---
 combo/config.multitask.template.jsonnet |  2 +-
 combo/models/model.py                   | 19 ++++++++++++-------
 combo/models/multitask.py               | 18 +++++++++++++++++-
 3 files changed, 30 insertions(+), 9 deletions(-)

diff --git a/combo/config.multitask.template.jsonnet b/combo/config.multitask.template.jsonnet
index 5edff14..cf91ef1 100644
--- a/combo/config.multitask.template.jsonnet
+++ b/combo/config.multitask.template.jsonnet
@@ -216,7 +216,7 @@ assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't
         },
         oov_token: "_",
         padding_token: "__PAD__",
-        non_padded_namespaces: ["head_labels"],
+        non_padded_namespaces: ["head_labels", "ner_labels"],
     }),
     model: std.prune({
         type: "multitask_extended",
diff --git a/combo/models/model.py b/combo/models/model.py
index 4c9d1be..cb386fb 100644
--- a/combo/models/model.py
+++ b/combo/models/model.py
@@ -38,8 +38,8 @@ class NERModel(heads.Head):
         super().__init__(vocab)
         self.feedforward_predictor = feedforward_predictor
         self._accuracy_metric = allen_metrics.CategoricalAccuracy()
-        # self._f1_metric = allen_metrics.SpanBasedF1Measure(vocab, tag_namespace="ner_labels", label_encoding="IOB1",
-        #                                                    ignore_classes=["_"])
+        self._f1_metric = allen_metrics.SpanBasedF1Measure(vocab, tag_namespace="ner_labels", label_encoding="IOB1",
+                                                           ignore_classes=["_"])
         self._loss = 0.0
 
     def forward(self,
@@ -55,16 +55,21 @@
         if tags is not None:
             self._loss = output["loss"]
             self._accuracy_metric(output["probability"], tags, word_mask)
-            # self._f1_metric(output["probability"], tags, word_mask)
+            self._f1_metric(output["probability"], tags, word_mask)
 
         return output
 
     @overrides
     def get_metrics(self, reset: bool = False) -> Dict[str, float]:
-        return {
-            **{"accuracy": self._accuracy_metric.get_metric(reset), "loss": self._loss},
-            # **self._f1_metric.get_metric(reset)
-        }
+        metrics_ = {"accuracy": self._accuracy_metric.get_metric(reset), "loss": self._loss}
+        for name, value in self._f1_metric.get_metric(reset).items():
+            if "overall" in name:
+                metrics_[name] = value
+            else:
+                metrics_[f"_{name}"] = value
+        return metrics_
 
 
 @heads.Head.register("semantic_multitask_head")
diff --git a/combo/models/multitask.py b/combo/models/multitask.py
index d557591..794d0df 100644
--- a/combo/models/multitask.py
+++ b/combo/models/multitask.py
@@ -3,11 +3,12 @@ from typing import Mapping, List, Dict, Union
 
 import torch
 from allennlp import models
+from overrides import overrides
 
 
 @models.Model.register("multitask_extended")
 class MultiTaskModel(models.MultiTaskModel):
-    """Extension of the AllenNLP MultiTaskModel to handle dictionary inputs."""
+    """Extension of the AllenNLP MultiTaskModel to handle dictionary inputs and silence _ prefixed metrics."""
 
     def forward(self, **kwargs) -> Dict[str, torch.Tensor]:  # type: ignore
         if "task" not in kwargs:
@@ -64,3 +65,18 @@ class MultiTaskModel(models.MultiTaskModel):
             outputs["loss"] = loss
 
         return outputs
+
+    @overrides
+    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
+        metrics = {}
+        for head_name in self._heads_called:
+            for key, value in self._heads[head_name].get_metrics(reset).items():
+                # Preserve the leading "_" (AllenNLP hides such metrics) while namespacing each key by head name.
+                if key.startswith("_"):
+                    metrics[f"_{head_name}{key}"] = value
+                else:
+                    metrics[f"{head_name}_{key}"] = value
+
+        if reset:
+            self._heads_called.clear()
+        return metrics
-- 
GitLab