From 5e79a44e23bc53a8ed6ecec36aff4513de036f8d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Maja=20Jab=C5=82o=C5=84ska?= <majajjablonska@gmail.com>
Date: Tue, 28 Mar 2023 20:14:08 +0200
Subject: [PATCH] Add metrics and metrics_test

---
 combo/utils/metrics.py      | 381 +++++++++++++++++++++++++++++++++++-
 tests/utils/test_metrics.py | 211 ++++++++++++++++++++
 2 files changed, 587 insertions(+), 5 deletions(-)
 create mode 100644 tests/utils/test_metrics.py

diff --git a/combo/utils/metrics.py b/combo/utils/metrics.py
index 9181040..47904ec 100644
--- a/combo/utils/metrics.py
+++ b/combo/utils/metrics.py
@@ -1,17 +1,388 @@
+from typing import Optional, List, Dict, Iterable
+
+import torch
+from overrides import overrides
+
+
+"""
+Class Metric adapted from AllenNLP
+https://github.com/allenai/allennlp/blob/80fb6061e568cb9d6ab5d45b661e86eb61b92c82/allennlp/training/metrics/metric.py
+"""
+
+
 class Metric:
-    pass
+    """
+    A very general abstract class representing a metric which can be
+    accumulated.
+    """
+
+    supports_distributed = False
+
+    def __call__(
+        self, predictions: torch.Tensor, gold_labels: torch.Tensor, mask: Optional[torch.BoolTensor]
+    ):
+        """
+        # Parameters
+        predictions : `torch.Tensor`, required.
+            A tensor of predictions.
+        gold_labels : `torch.Tensor`, required.
+            A tensor corresponding to some gold label to evaluate against.
+        mask : `torch.BoolTensor`, optional (default = `None`).
+            A mask can be passed, in order to deal with metrics which are
+            computed over potentially padded elements, such as sequence labels.
+        """
+        raise NotImplementedError
+
+    def get_metric(self, reset: bool):
+        """
+        Compute and return the metric. Optionally also call `self.reset`.
+        """
+        raise NotImplementedError
+
+    def reset(self) -> None:
+        """
+        Reset any accumulators or internal state.
+        """
+        raise NotImplementedError
+
+    @staticmethod
+    def detach_tensors(*tensors: torch.Tensor) -> Iterable[torch.Tensor]:
+        """
+        If gradient-tracking tensors are passed to a Metric, they keep the whole
+        computation graph alive and cause a severe memory leak. This method
+        ensures the tensors are detached.
+        """
+        # Check if it's actually a tensor in case something else was passed.
+        return (x.detach() if isinstance(x, torch.Tensor) else x for x in tensors)
+
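+# A minimal usage sketch of the Metric interface (illustrative only: `MyMetric`
+# stands in for any concrete subclass below, and `batches` for an iterable of
+# (predictions, gold_labels, mask) triples):
+#
+#     metric = MyMetric()
+#     for predictions, gold_labels, mask in batches:  # accumulate per batch
+#         metric(predictions, gold_labels, mask)
+#     score = metric.get_metric(reset=True)           # read and reset per epoch
+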
+
+"""
+The metric implementations below are adapted from COMBO.
+Author: Mateusz Klimaszewski
+"""
+
 
 class LemmaAccuracy(Metric):
-    pass
+
+    def __init__(self):
+        self._correct_count = 0.0
+        self._total_count = 0.0
+        self.correct_indices = torch.ones([])
+
+    @overrides
+    def __call__(self,
+                 predictions: torch.Tensor,
+                 gold_labels: torch.Tensor,
+                 mask: Optional[torch.BoolTensor] = None):
+        if gold_labels is None:
+            return
+        predictions, gold_labels, mask = self.detach_tensors(predictions,
+                                                             gold_labels,
+                                                             mask)
+
+        # Some sanity checks.
+        if gold_labels.size() != predictions.size():
+            raise ValueError(
+                f"gold_labels must have shape == predictions.size() but "
+                f"found tensor of shape: {gold_labels.size()}"
+            )
+        if mask is not None and mask.size() not in [predictions.size()[:-1], predictions.size()]:
+            raise ValueError(
+                f"mask must have shape in one of [predictions.size()[:-1], predictions.size()] but "
+                f"found tensor of shape: {mask.size()}"
+            )
+        if mask is None:
+            mask = predictions.new_ones(predictions.size()[:-1]).bool()
+        if mask.dim() < predictions.dim():
+            mask = mask.unsqueeze(-1)
+
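+        # A lemma counts as correct only if every non-padding character position
+        # (gold label > 0) matches the prediction; `mask` then drops padded tokens.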
+        padding_mask = gold_labels.gt(0)
+        correct = predictions.eq(gold_labels) * padding_mask
+        correct = (correct.int().sum(-1) == padding_mask.int().sum(-1)) * mask.squeeze(-1)
+        correct = correct.float()
+
+        self.correct_indices = correct.flatten().bool()
+        self._correct_count += correct.sum()
+        self._total_count += mask.sum()
+
+    @overrides
+    def get_metric(self, reset: bool) -> float:
+        if self._total_count > 0:
+            accuracy = float(self._correct_count) / float(self._total_count)
+        else:
+            accuracy = 0.0
+        if reset:
+            self.reset()
+        return accuracy
+
+    @overrides
+    def reset(self) -> None:
+        self._correct_count = 0.0
+        self._total_count = 0.0
+        self.correct_indices = torch.ones([])
 
 
 class SequenceBoolAccuracy(Metric):
-    pass
+    """BoolAccuracy implementation to handle sequences."""
+
+    def __init__(self, prod_last_dim: bool = False):
+        self._correct_count = 0.0
+        self._total_count = 0.0
+        self.prod_last_dim = prod_last_dim
+        self.correct_indices = torch.ones([])
+
+    @overrides
+    def __call__(self,
+                 predictions: torch.Tensor,
+                 gold_labels: torch.Tensor,
+                 mask: Optional[torch.BoolTensor] = None):
+        if gold_labels is None:
+            return
+        predictions, gold_labels, mask = self.detach_tensors(predictions,
+                                                             gold_labels,
+                                                             mask)
+
+        # Some sanity checks.
+        if gold_labels.size() != predictions.size():
+            raise ValueError(
+                f"gold_labels must have shape == predictions.size() but "
+                f"found tensor of shape: {gold_labels.size()}"
+            )
+        if mask is not None and mask.size() not in [predictions.size()[:-1], predictions.size()]:
+            raise ValueError(
+                f"mask must have shape in one of [predictions.size()[:-1], predictions.size()] but "
+                f"found tensor of shape: {mask.size()}"
+            )
+        if mask is None:
+            mask = predictions.new_ones(predictions.size()[:-1]).bool()
+        if mask.dim() < predictions.dim():
+            mask = mask.unsqueeze(-1)
+
+        correct = predictions.eq(gold_labels) * mask
+
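+        # With prod_last_dim, an element counts as correct only if all entries
+        # along its last dimension match (e.g. every FEATS category of a token).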
+        if self.prod_last_dim:
+            correct = correct.prod(-1).unsqueeze(-1)
+
+        correct = correct.float()
+
+        self.correct_indices = correct.flatten().bool()
+        self._correct_count += correct.sum()
+        self._total_count += mask.sum()
+
+    @overrides
+    def get_metric(self, reset: bool) -> float:
+        if self._total_count > 0:
+            accuracy = float(self._correct_count) / float(self._total_count)
+        else:
+            accuracy = 0.0
+        if reset:
+            self.reset()
+        return accuracy
+
+    @overrides
+    def reset(self) -> None:
+        self._correct_count = 0.0
+        self._total_count = 0.0
+        self.correct_indices = torch.ones([])
 
 
 class AttachmentScores(Metric):
-    pass
+    """
+    Computes labeled and unlabeled attachment scores for a
+    dependency parse, as well as sentence level exact match
+    for both labeled and unlabeled trees. Note that the input
+    to this metric is the sampled predictions, not the distribution
+    itself.
+
+    # Parameters
+
+    ignore_classes : `List[int]`, optional (default = `None`)
+        A list of label ids to ignore when computing metrics.
+    """
+
+    def __init__(self, ignore_classes: Optional[List[int]] = None) -> None:
+        self._labeled_correct = 0.0
+        self._unlabeled_correct = 0.0
+        self._exact_labeled_correct = 0.0
+        self._exact_unlabeled_correct = 0.0
+        self._total_words = 0.0
+        self._total_sentences = 0.0
+        self.correct_indices = torch.ones([])
+
+        self._ignore_classes: List[int] = ignore_classes or []
+
+    def __call__(  # type: ignore
+            self,
+            predicted_indices: torch.Tensor,
+            predicted_labels: torch.Tensor,
+            gold_indices: torch.Tensor,
+            gold_labels: torch.Tensor,
+            mask: Optional[torch.BoolTensor] = None,
+    ):
+        """
+        # Parameters
+
+        predicted_indices : `torch.Tensor`, required.
+            A tensor of head index predictions of shape (batch_size, timesteps).
+        predicted_labels : `torch.Tensor`, required.
+            A tensor of arc label predictions of shape (batch_size, timesteps).
+        gold_indices : `torch.Tensor`, required.
+            A tensor of the same shape as `predicted_indices`.
+        gold_labels : `torch.Tensor`, required.
+            A tensor of the same shape as `predicted_labels`.
+        mask : `torch.BoolTensor`, optional (default = `None`).
+            A tensor of the same shape as `predicted_indices`.
+        """
+        if gold_labels is None or gold_indices is None:
+            return
+        detached = self.detach_tensors(
+            predicted_indices, predicted_labels, gold_indices, gold_labels, mask
+        )
+        predicted_indices, predicted_labels, gold_indices, gold_labels, mask = detached
+
+        if mask is None:
+            mask = torch.ones_like(predicted_indices).bool()
+
+        predicted_indices = predicted_indices.long()
+        predicted_labels = predicted_labels.long()
+        gold_indices = gold_indices.long()
+        gold_labels = gold_labels.long()
+
+        # Remove the locations of ignored gold labels from the mask.
+        for label in self._ignore_classes:
+            label_mask = gold_labels.eq(label)
+            mask = mask & ~label_mask
+
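+        # Treating masked positions as trivially correct (via `~mask`) makes the
+        # product over the sequence dimension equal 1 only when every real token
+        # gets the correct head (and, for labeled exact match, the correct label).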
+        correct_indices = predicted_indices.eq(gold_indices).long() * mask
+        unlabeled_exact_match = (correct_indices + ~mask).prod(dim=-1)
+        if len(correct_indices.size()) > 2:
+            unlabeled_exact_match = unlabeled_exact_match.prod(dim=-1)
+        correct_labels = predicted_labels.eq(gold_labels).long() * mask
+        correct_labels_and_indices = correct_indices * correct_labels
+        self.correct_indices = correct_labels_and_indices.flatten()
+        labeled_exact_match = (correct_labels_and_indices + ~mask).prod(dim=-1)
+        if len(correct_indices.size()) > 2:
+            labeled_exact_match = labeled_exact_match.prod(dim=-1)
+
+        self._unlabeled_correct += correct_indices.sum()
+        self._exact_unlabeled_correct += unlabeled_exact_match.sum()
+        self._labeled_correct += correct_labels_and_indices.sum()
+        self._exact_labeled_correct += labeled_exact_match.sum()
+        self._total_sentences += correct_indices.size(0)
+        self._total_words += correct_indices.numel() - (~mask).sum()
+
+    def get_metric(self, reset: bool = False):
+        """
+        # Returns
+
+        The accumulated metrics as a dictionary.
+        """
+        unlabeled_attachment_score = 0.0
+        labeled_attachment_score = 0.0
+        unlabeled_exact_match = 0.0
+        labeled_exact_match = 0.0
+        if self._total_words > 0.0:
+            unlabeled_attachment_score = float(self._unlabeled_correct) / float(self._total_words)
+            labeled_attachment_score = float(self._labeled_correct) / float(self._total_words)
+        if self._total_sentences > 0:
+            unlabeled_exact_match = float(self._exact_unlabeled_correct) / float(
+                self._total_sentences
+            )
+            labeled_exact_match = float(self._exact_labeled_correct) / float(self._total_sentences)
+        if reset:
+            self.reset()
+        return {
+            "UAS": unlabeled_attachment_score,
+            "LAS": labeled_attachment_score,
+            "UEM": unlabeled_exact_match,
+            "LEM": labeled_exact_match,
+        }
+
+    @overrides
+    def reset(self):
+        self._labeled_correct = 0.0
+        self._unlabeled_correct = 0.0
+        self._exact_labeled_correct = 0.0
+        self._exact_unlabeled_correct = 0.0
+        self._total_words = 0.0
+        self._total_sentences = 0.0
+        self.correct_indices = torch.ones([])
 
 
 class SemanticMetrics(Metric):
-    pass
+    """Groups metrics for all predictions."""
+
+    def __init__(self) -> None:
+        self.upos_score = SequenceBoolAccuracy()
+        self.xpos_score = SequenceBoolAccuracy()
+        self.semrel_score = SequenceBoolAccuracy()
+        self.feats_score = SequenceBoolAccuracy(prod_last_dim=True)
+        self.lemma_score = LemmaAccuracy()
+        self.attachment_scores = AttachmentScores()
+        # Ignore PADDING and OOV
+        self.enhanced_attachment_scores = AttachmentScores(ignore_classes=[0, 1])
+        self.em_score = 0.0
+
+    def __call__(  # type: ignore
+            self,
+            predictions: Dict[str, torch.Tensor],
+            gold_labels: Dict[str, torch.Tensor],
+            mask: torch.BoolTensor):
+        self.upos_score(predictions["upostag"], gold_labels["upostag"], mask)
+        self.xpos_score(predictions["xpostag"], gold_labels["xpostag"], mask)
+        self.semrel_score(predictions["semrel"], gold_labels["semrel"], mask)
+        self.feats_score(predictions["feats"], gold_labels["feats"], mask)
+        self.lemma_score(predictions["lemma"], gold_labels["lemma"], mask)
+        self.attachment_scores(predictions["head"],
+                               predictions["deprel"],
+                               gold_labels["head"],
+                               gold_labels["deprel"],
+                               mask)
+        self.enhanced_attachment_scores(predictions["enhanced_head"],
+                                        predictions["enhanced_deprel"],
+                                        gold_labels["enhanced_head"],
+                                        gold_labels["enhanced_deprel"],
+                                        mask=None)
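+        # The enhanced scores are computed without a mask over arc matrices that
+        # include an extra first (ROOT) position, so their per-arc flags are
+        # reshaped back to (batch, sequence_length + 1, -1), the ROOT row and
+        # column are dropped, and the result is collapsed to one correctness flag
+        # per token so it lines up with the flattened mask below.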
+        enhanced_indices = (
+            self.enhanced_attachment_scores.correct_indices.reshape(mask.size(0), mask.size(1) + 1, -1)[:, 1:, 1:].sum(
+                -1).reshape(-1).bool()
+            if len(self.enhanced_attachment_scores.correct_indices.size()) > 0
+            else self.enhanced_attachment_scores.correct_indices
+        )
+        total = mask.sum()
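+        # A token contributes to the exact match (EM) score only if every
+        # sub-task above marked it as correct and it is not masked out.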
+        correct_indices = (self.upos_score.correct_indices *
+                           self.xpos_score.correct_indices *
+                           self.semrel_score.correct_indices *
+                           self.feats_score.correct_indices *
+                           self.lemma_score.correct_indices *
+                           self.attachment_scores.correct_indices *
+                           enhanced_indices) * mask.flatten()
+
+        total, correct_indices = self.detach_tensors(total, correct_indices.float().sum())
+        self.em_score = (correct_indices / total).item()
+
+    def get_metric(self, reset: bool) -> Dict[str, float]:
+        metrics_dict = {
+            "UPOS_ACC": self.upos_score.get_metric(reset),
+            "XPOS_ACC": self.xpos_score.get_metric(reset),
+            "SEMREL_ACC": self.semrel_score.get_metric(reset),
+            "LEMMA_ACC": self.lemma_score.get_metric(reset),
+            "FEATS_ACC": self.feats_score.get_metric(reset),
+            "EM": self.em_score
+        }
+        metrics_dict.update(self.attachment_scores.get_metric(reset))
+        enhanced_metrics = {f"E{k}": v for k, v in self.enhanced_attachment_scores.get_metric(reset).items()}
+        metrics_dict.update(enhanced_metrics)
+        return metrics_dict
+
+    def reset(self) -> None:
+        self.upos_score.reset()
+        self.xpos_score.reset()
+        self.semrel_score.reset()
+        self.lemma_score.reset()
+        self.feats_score.reset()
+        self.attachment_scores.reset()
+        self.enhanced_attachment_scores.reset()
+        self.em_score = 0.0
+
diff --git a/tests/utils/test_metrics.py b/tests/utils/test_metrics.py
new file mode 100644
index 0000000..bf7f619
--- /dev/null
+++ b/tests/utils/test_metrics.py
@@ -0,0 +1,211 @@
+"""Metrics tests."""
+import unittest
+
+import torch
+
+from combo.utils import metrics
+
+
+class SemanticMetricsTest(unittest.TestCase):
+
+    def setUp(self) -> None:
+        self.mask: torch.BoolTensor = torch.tensor([
+            [True, True, True, True],
+            [True, True, True, False],
+            [True, True, True, False],
+        ])
+        pred = torch.tensor([
+            [0, 1, 2, 3],
+            [0, 1, 2, 3],
+            [0, 1, 2, 3],
+        ])
+        pred_seq = pred.reshape(3, 4, 1)
+        gold = pred.clone()
+        gold_seq = pred_seq.clone()
+        self.upostag, self.upostag_l = (("upostag", x) for x in [pred, gold])
+        self.xpostag, self.xpostag_l = (("xpostag", x) for x in [pred, gold])
+        self.semrel, self.semrel_l = (("semrel", x) for x in [pred, gold])
+        self.head, self.head_l = (("head", x) for x in [pred, gold])
+        self.deprel, self.deprel_l = (("deprel", x) for x in [pred, gold])
+        # TODO(mklimasz) Set up an example with size 3x5x5
+        self.enhanced_head, self.enhanced_head_l = (("enhanced_head", x) for x in [None, None])
+        self.enhanced_deprel, self.enhanced_deprel_l = (("enhanced_deprel", x) for x in [None, None])
+        self.feats, self.feats_l = (("feats", x) for x in [pred_seq, gold_seq])
+        self.lemma, self.lemma_l = (("lemma", x) for x in [pred_seq, gold_seq])
+        self.predictions = dict(
+            [self.upostag, self.xpostag, self.semrel, self.feats, self.lemma, self.head, self.deprel,
+             self.enhanced_head, self.enhanced_deprel])
+        self.gold_labels = dict([self.upostag_l, self.xpostag_l, self.semrel_l, self.feats_l, self.lemma_l, self.head_l,
+                                 self.deprel_l, self.enhanced_head_l, self.enhanced_deprel_l])
+        self.eps = 1e-6
+
+    def test_every_prediction_correct(self):
+        # given
+        metric = metrics.SemanticMetrics()
+
+        # when
+        metric(self.predictions, self.gold_labels, self.mask)
+
+        # then
+        self.assertEqual(1.0, metric.em_score)
+
+    def test_missing_predictions_for_one_target(self):
+        # given
+        metric = metrics.SemanticMetrics()
+        self.predictions["upostag"] = None
+        self.gold_labels["upostag"] = None
+
+        # when
+        metric(self.predictions, self.gold_labels, self.mask)
+
+        # then
+        self.assertEqual(1.0, metric.em_score)
+
+    def test_missing_predictions_for_two_targets(self):
+        # given
+        metric = metrics.SemanticMetrics()
+        self.predictions["upostag"] = None
+        self.gold_labels["upostag"] = None
+        self.predictions["lemma"] = None
+        self.gold_labels["lemma"] = None
+
+        # when
+        metric(self.predictions, self.gold_labels, self.mask)
+
+        # then
+        self.assertEqual(1.0, metric.em_score)
+
+    def test_one_classification_in_one_target_is_wrong(self):
+        # given
+        metric = metrics.SemanticMetrics()
+        self.predictions["upostag"][0][0] = 100
+
+        # when
+        metric(self.predictions, self.gold_labels, self.mask)
+
+        # then
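+        # 1 of the 10 unmasked tokens has a wrong upostag, so EM = 9 / 10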
+        self.assertAlmostEqual(0.9, metric.em_score, delta=self.eps)
+
+    def test_classification_errors_and_target_without_predictions(self):
+        # given
+        metric = metrics.SemanticMetrics()
+        self.predictions["feats"] = None
+        self.gold_labels["feats"] = None
+        self.predictions["upostag"][0][0] = 100
+        self.predictions["upostag"][2][0] = 100
+        # should be ignored due to masking
+        self.predictions["upostag"][1][3] = 100
+
+        # when
+        metric(self.predictions, self.gold_labels, self.mask)
+
+        # then
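+        # 2 of the 10 unmasked tokens have a wrong upostag (the third error is
+        # masked out), so EM = 8 / 10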
+        self.assertAlmostEqual(0.8, metric.em_score, delta=self.eps)
+
+
+class SequenceBoolAccuracyTest(unittest.TestCase):
+
+    def setUp(self) -> None:
+        self.mask: torch.BoolTensor = torch.tensor([
+            [True, True, True, True],
+            [True, True, True, False],
+            [True, True, True, False],
+        ])
+
+    def test_regular_classification_accuracy(self):
+        # given
+        metric = metrics.SequenceBoolAccuracy()
+        predictions = torch.tensor([
+            [1, 1, 0, 8],
+            [1, 2, 3, 4],
+            [9, 4, 3, 9],
+        ])
+        gold_labels = torch.tensor([
+            [11, 1, 0, 8],
+            [14, 2, 3, 14],
+            [9, 4, 13, 9],
+        ])
+
+        # when
+        metric(predictions, gold_labels, self.mask)
+
+        # then
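+        # 3 of the 10 unmasked positions differ from the gold labels, so 7 are correct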
+        self.assertEqual(metric._correct_count.item(), 7)
+        self.assertEqual(metric._total_count.item(), 10)
+
+    def test_feats_classification_accuracy(self):
+        # given
+        metric = metrics.SequenceBoolAccuracy(prod_last_dim=True)
+        # batch_size, sequence_length, classes
+        predictions = torch.tensor([
+            [[1, 4], [0, 2], [0, 2], [0, 3]],
+            [[1, 4], [0, 2], [0, 2], [0, 3]],
+            [[1, 4], [0, 2], [0, 2], [0, 3]],
+        ])
+        gold_labels = torch.tensor([
+            [[1, 14], [0, 2], [0, 2], [0, 3]],
+            [[11, 4], [0, 2], [0, 2], [10, 3]],
+            [[1, 4], [0, 2], [10, 12], [0, 3]],
+        ])
+
+        # when
+        metric(predictions, gold_labels, self.mask)
+
+        # then
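+        # 3 unmasked tokens differ from gold in at least one of their two
+        # columns, so 7 of the 10 unmasked tokens are correct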
+        self.assertEqual(metric._correct_count.item(), 7)
+        self.assertEqual(metric._total_count.item(), 10)
+
+
+class LemmaAccuracyTest(unittest.TestCase):
+
+    def setUp(self) -> None:
+        self.mask: torch.BoolTensor = torch.tensor([
+            [True, True, True, True],
+            [True, True, True, False],
+        ])
+
+    def test_prediction_has_error_in_not_padded_place(self):
+        # given
+        metric = metrics.LemmaAccuracy()
+        predictions = torch.tensor([
+            [[1, 1, 1], [1, 1, 1], [2, 2, 0], [1, 1, 4], ],
+            [[1, 1, 0], [1, 1000, 0], [1, 1, 0], [1, 1, 0], ],
+        ])
+        gold_labels = torch.tensor([
+            [[1, 1, 1], [1, 1, 1], [2, 2, 0], [1, 1, 4], ],
+            [[1, 1, 0], [1, 1, 0], [1, 1, 0], [1, 1, 0], ],
+        ])
+        expected_correct_count = 6
+        expected_total_count = 7
+        expected_correct_indices = torch.tensor([1, 1, 1, 1, 1, 0, 1, 0])
+
+        # when
+        metric(predictions, gold_labels, self.mask)
+
+        # then
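+        # the error in row 1, token 1 is at a non-padding character position,
+        # so only 6 of the 7 unmasked lemmas are correct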
+        self.assertEqual(metric._correct_count.item(), expected_correct_count)
+        self.assertEqual(metric._total_count.item(), expected_total_count)
+        self.assertTrue(torch.all(expected_correct_indices.eq(metric.correct_indices)))
+
+    def test_wrong_prediction_in_padding_should_be_ignored(self):
+        # given
+        metric = metrics.LemmaAccuracy()
+        predictions = torch.tensor([
+            [[1, 1, 1], [1, 1, 1], [2, 2, 0], [1, 1, 4], ],
+            [[1, 1, 1000], [1, 1, 0], [1, 1, 0], [1, 1, 0], ],
+        ])
+        gold_labels = torch.tensor([
+            [[1, 1, 1], [1, 1, 1], [2, 2, 0], [1, 1, 4], ],
+            [[1, 1, 0], [1, 1, 0], [1, 1, 0], [1, 1, 0], ],
+        ])
+        expected_correct_count = 7
+        expected_total_count = 7
+        expected_correct_indices = torch.tensor([1, 1, 1, 1, 1, 1, 1, 0])
+
+        # when
+        metric(predictions, gold_labels, self.mask)
+
+        # then
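+        # the differing character in row 1, token 0 sits at a gold padding
+        # position (gold label 0), so all 7 unmasked lemmas count as correct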
+        self.assertEqual(expected_correct_count, metric._correct_count.item())
+        self.assertEqual(expected_total_count, metric._total_count.item())
+        self.assertTrue(torch.all(expected_correct_indices.eq(metric.correct_indices)))
-- 
GitLab