From 2aa41eba98c2af0857d2122be7c47549711f970e Mon Sep 17 00:00:00 2001
From: Mateusz Klimaszewski <mk.klimaszewski@gmail.com>
Date: Wed, 6 Jan 2021 23:35:38 +0100
Subject: [PATCH] Remove lemma padding from lemma loss and metric.
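
The lemma loss and the lemma accuracy metric previously treated padded
character positions like real ones. Build the loss mask from gold
character ids greater than 0 instead, and score lemmas with a dedicated
LemmaAccuracy metric that counts a word as correct only when all of its
non-padding characters match.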

---
 combo/models/lemma.py       |  5 ++--
 combo/utils/metrics.py      | 66 ++++++++++++++++++++++++++++++++++++++-
 tests/utils/test_metrics.py | 55 ++++++++++++++++++++++++++++++++
 3 files changed, 123 insertions(+), 3 deletions(-)

diff --git a/combo/models/lemma.py b/combo/models/lemma.py
index 4ff9c92..828ba3e 100644
--- a/combo/models/lemma.py
+++ b/combo/models/lemma.py
@@ -62,11 +62,12 @@ class LemmatizerModel(base.Predictor):
         BATCH_SIZE, SENTENCE_LENGTH, MAX_WORD_LENGTH, CHAR_CLASSES = pred.size()
         pred = pred.reshape(-1, CHAR_CLASSES)
 
-        valid_positions = mask.sum()
-        mask = mask.reshape(-1)
         true = true.reshape(-1)
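+        # Gold character id 0 is padding, so mask directly on the gold ids.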
+        mask = true.gt(0)
         loss = utils.masked_cross_entropy(pred, true, mask)
         loss = loss.reshape(BATCH_SIZE, -1) * sample_weights.unsqueeze(-1)
+        valid_positions = mask.sum()
         return loss.sum() / valid_positions
 
     @classmethod
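
For reference, the masking behaviour the patched loss relies on can be
sketched as follows. This assumes utils.masked_cross_entropy returns the
per-position cross-entropy zeroed at masked positions (a plausible
reading, not necessarily the exact implementation in combo/utils):

    import torch
    import torch.nn.functional as F

    def masked_cross_entropy(pred, true, mask):
        # Hypothetical stand-in for combo.utils.masked_cross_entropy:
        # per-position loss, zeroed wherever the mask is False.
        return F.cross_entropy(pred, true, reduction="none") * mask

    pred = torch.randn(24, 5)           # (BATCH*SENTENCE*WORD, CHAR_CLASSES)
    true = torch.tensor([1, 2, 0] * 8)  # gold character ids; 0 is padding
    mask = true.gt(0)                   # character-level mask, as in this hunk
    loss = masked_cross_entropy(pred, true, mask)
    print(loss.sum() / mask.sum())      # averaged over real characters only
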
diff --git a/combo/utils/metrics.py b/combo/utils/metrics.py
index 1a17540..c2a1202 100644
--- a/combo/utils/metrics.py
+++ b/combo/utils/metrics.py
@@ -6,6 +6,70 @@ from allennlp.training import metrics
 from overrides import overrides
 
 
+class LemmaAccuracy(metrics.Metric):
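+    """Accuracy over whole lemmas: a word counts as correct only when
+    every non-padding (gold id > 0) character is predicted correctly."""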
+
+    def __init__(self):
+        self._correct_count = 0.0
+        self._total_count = 0.0
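+        # Flattened per-word correctness mask from the most recent call.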
+        self.correct_indices = torch.ones([])
+
+    @overrides
+    def __call__(self,
+                 predictions: torch.Tensor,
+                 gold_labels: torch.Tensor,
+                 mask: Optional[torch.BoolTensor] = None):
+        if gold_labels is None:
+            return
+        predictions, gold_labels, mask = self.detach_tensors(predictions,
+                                                             gold_labels,
+                                                             mask)
+
+        # Some sanity checks.
+        if gold_labels.size() != predictions.size():
+            raise ValueError(
+                f"gold_labels must have the same shape as predictions "
+                f"{predictions.size()}, but found tensor of shape: {gold_labels.size()}"
+            )
+        if mask is not None and mask.size() not in [predictions.size()[:-1], predictions.size()]:
+            raise ValueError(
+                f"mask must have shape predictions.size() or predictions.size()[:-1], "
+                f"but found tensor of shape: {mask.size()}"
+            )
+        if mask is None:
+            mask = predictions.new_ones(predictions.size()[:-1]).bool()
+        if mask.dim() < predictions.dim():
+            mask = mask.unsqueeze(-1)
+
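+        # Gold character id 0 is padding; compare real character positions only.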
+        padding_mask = gold_labels.gt(0)
+        correct = predictions.eq(gold_labels) * padding_mask
+        correct = (correct.int().sum(-1) == padding_mask.int().sum(-1)) * mask.squeeze(-1)
+        correct = correct.float()
+
+        self.correct_indices = correct.flatten().bool()
+        self._correct_count += correct.sum()
+        self._total_count += mask.sum()
+
+    @overrides
+    def get_metric(self, reset: bool = False) -> float:
+        if self._total_count > 0:
+            accuracy = float(self._correct_count) / float(self._total_count)
+        else:
+            accuracy = 0.0
+        if reset:
+            self.reset()
+        return accuracy
+
+    @overrides
+    def reset(self) -> None:
+        self._correct_count = 0.0
+        self._total_count = 0.0
+        self.correct_indices = torch.ones([])
+
+
 class SequenceBoolAccuracy(metrics.Metric):
     """BoolAccuracy implementation to handle sequences."""
 
@@ -202,7 +262,7 @@ class SemanticMetrics(metrics.Metric):
         self.xpos_score = SequenceBoolAccuracy()
         self.semrel_score = SequenceBoolAccuracy()
         self.feats_score = SequenceBoolAccuracy(prod_last_dim=True)
-        self.lemma_score = SequenceBoolAccuracy(prod_last_dim=True)
+        self.lemma_score = LemmaAccuracy()
         self.attachment_scores = AttachmentScores()
         # Ignore PADDING and OOV
         self.enhanced_attachment_scores = AttachmentScores(ignore_classes=[0, 1])
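
On a toy batch, the new metric scores a word as correct only when every
non-padding character matches, with padded words excluded by the
word-level mask (mirroring the tests added below; the import path is
assumed here):

    import torch
    from combo.utils import metrics

    metric = metrics.LemmaAccuracy()
    predictions = torch.tensor([[[1, 1, 1], [1, 9, 0]]])  # (batch, words, chars)
    gold_labels = torch.tensor([[[1, 1, 1], [1, 1, 0]]])  # char id 0 is padding
    mask = torch.tensor([[True, True]])                   # word-level mask
    metric(predictions, gold_labels, mask)
    print(metric.get_metric(reset=True))  # 0.5 -- only the first word is right
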
diff --git a/tests/utils/test_metrics.py b/tests/utils/test_metrics.py
index 242eaa3..bf7f619 100644
--- a/tests/utils/test_metrics.py
+++ b/tests/utils/test_metrics.py
@@ -154,3 +154,58 @@ class SequenceBoolAccuracyTest(unittest.TestCase):
         # then
         self.assertEqual(metric._correct_count.item(), 7)
         self.assertEqual(metric._total_count.item(), 10)
+
+
+class LemmaAccuracyTest(unittest.TestCase):
+
+    def setUp(self) -> None:
+        self.mask: torch.BoolTensor = torch.tensor([
+            [True, True, True, True],
+            [True, True, True, False],
+        ])
+
+    def test_error_in_non_padded_position_counts_as_incorrect(self):
+        # given
+        metric = metrics.LemmaAccuracy()
+        predictions = torch.tensor([
+            [[1, 1, 1], [1, 1, 1], [2, 2, 0], [1, 1, 4], ],
+            [[1, 1, 0], [1, 1000, 0], [1, 1, 0], [1, 1, 0], ],
+        ])
+        gold_labels = torch.tensor([
+            [[1, 1, 1], [1, 1, 1], [2, 2, 0], [1, 1, 4], ],
+            [[1, 1, 0], [1, 1, 0], [1, 1, 0], [1, 1, 0], ],
+        ])
+        expected_correct_count = 6
+        expected_total_count = 7
+        expected_correct_indices = torch.tensor([1, 1, 1, 1, 1, 0, 1, 0])
+
+        # when
+        metric(predictions, gold_labels, self.mask)
+
+        # then
+        self.assertEqual(metric._correct_count.item(), expected_correct_count)
+        self.assertEqual(metric._total_count.item(), expected_total_count)
+        self.assertTrue(torch.all(expected_correct_indices.eq(metric.correct_indices)))
+
+    def test_wrong_prediction_in_padding_is_ignored(self):
+        # given
+        metric = metrics.LemmaAccuracy()
+        predictions = torch.tensor([
+            [[1, 1, 1], [1, 1, 1], [2, 2, 0], [1, 1, 4], ],
+            [[1, 1, 1000], [1, 1, 0], [1, 1, 0], [1, 1, 0], ],
+        ])
+        gold_labels = torch.tensor([
+            [[1, 1, 1], [1, 1, 1], [2, 2, 0], [1, 1, 4], ],
+            [[1, 1, 0], [1, 1, 0], [1, 1, 0], [1, 1, 0], ],
+        ])
+        expected_correct_count = 7
+        expected_total_count = 7
+        expected_correct_indices = torch.tensor([1, 1, 1, 1, 1, 1, 1, 0])
+
+        # when
+        metric(predictions, gold_labels, self.mask)
+
+        # then
+        self.assertEqual(expected_correct_count, metric._correct_count.item())
+        self.assertEqual(expected_total_count, metric._total_count.item())
+        self.assertTrue(torch.all(expected_correct_indices.eq(metric.correct_indices)))
-- 
GitLab