diff --git a/combo/models/lemma.py b/combo/models/lemma.py
index 4ff9c92d78d1f495c1e0d0df0992b368e735a47d..828ba3e9c0377b52c1bfe90dfb311d4f19689d3d 100644
--- a/combo/models/lemma.py
+++ b/combo/models/lemma.py
@@ -62,11 +62,11 @@ class LemmatizerModel(base.Predictor):
         BATCH_SIZE, SENTENCE_LENGTH, MAX_WORD_LENGTH, CHAR_CLASSES = pred.size()
         pred = pred.reshape(-1, CHAR_CLASSES)
-        valid_positions = mask.sum()
-        mask = mask.reshape(-1)
         true = true.reshape(-1)
+        mask = true.gt(0)
         loss = utils.masked_cross_entropy(pred, true, mask)
         loss = loss.reshape(BATCH_SIZE, -1) * sample_weights.unsqueeze(-1)
+        valid_positions = mask.sum()
         return loss.sum() / valid_positions
 
     @classmethod
diff --git a/combo/utils/metrics.py b/combo/utils/metrics.py
index 1a175402ec00b011221f2f7c1c76d7496b68281c..c2a1202148e602d433d2d4a3ac2a280f80bb76fb 100644
--- a/combo/utils/metrics.py
+++ b/combo/utils/metrics.py
@@ -6,6 +6,66 @@
 from allennlp.training import metrics
 from overrides import overrides
 
 
+class LemmaAccuracy(metrics.Metric):
+
+    def __init__(self):
+        self._correct_count = 0.0
+        self._total_count = 0.0
+        self.correct_indices = torch.ones([])
+
+    @overrides
+    def __call__(self,
+                 predictions: torch.Tensor,
+                 gold_labels: torch.Tensor,
+                 mask: Optional[torch.BoolTensor] = None):
+        if gold_labels is None:
+            return
+        predictions, gold_labels, mask = self.detach_tensors(predictions,
+                                                             gold_labels,
+                                                             mask)
+
+        # Some sanity checks.
+        if gold_labels.size() != predictions.size():
+            raise ValueError(
+                f"gold_labels must have shape == predictions.size() but "
+                f"found tensor of shape: {gold_labels.size()}"
+            )
+        if mask is not None and mask.size() not in [predictions.size()[:-1], predictions.size()]:
+            raise ValueError(
+                f"mask must have shape in one of [predictions.size()[:-1], predictions.size()] but "
+                f"found tensor of shape: {mask.size()}"
+            )
+        if mask is None:
+            mask = predictions.new_ones(predictions.size()[:-1]).bool()
+        if mask.dim() < predictions.dim():
+            mask = mask.unsqueeze(-1)
+
+        padding_mask = gold_labels.gt(0)
+        correct = predictions.eq(gold_labels) * padding_mask
+        correct = (correct.int().sum(-1) == padding_mask.int().sum(-1)) * mask.squeeze(-1)
+        correct = correct.float()
+
+        self.correct_indices = correct.flatten().bool()
+        self._correct_count += correct.sum()
+        self._total_count += mask.sum()
+
+    @overrides
+    def get_metric(self, reset: bool) -> float:
+        if self._total_count > 0:
+            accuracy = float(self._correct_count) / float(self._total_count)
+        else:
+            accuracy = 0.0
+        if reset:
+            self.reset()
+        return accuracy
+
+    @overrides
+    def reset(self) -> None:
+        self._correct_count = 0.0
+        self._total_count = 0.0
+        self.correct_indices = torch.ones([])
+
+
 class SequenceBoolAccuracy(metrics.Metric):
     """BoolAccuracy implementation to handle sequences."""
 
@@ -202,7 +262,7 @@ class SemanticMetrics(metrics.Metric):
         self.xpos_score = SequenceBoolAccuracy()
         self.semrel_score = SequenceBoolAccuracy()
         self.feats_score = SequenceBoolAccuracy(prod_last_dim=True)
-        self.lemma_score = SequenceBoolAccuracy(prod_last_dim=True)
+        self.lemma_score = LemmaAccuracy()
         self.attachment_scores = AttachmentScores()
         # Ignore PADDING and OOV
         self.enhanced_attachment_scores = AttachmentScores(ignore_classes=[0, 1])
diff --git a/tests/utils/test_metrics.py b/tests/utils/test_metrics.py
index 242eaa3dcffaf452c19a52e2f56625de00cd0433..bf7f619d19853700803598a72201815f264d1c70 100644
--- a/tests/utils/test_metrics.py
+++ b/tests/utils/test_metrics.py
@@ -154,3 +154,58 @@ class SequenceBoolAccuracyTest(unittest.TestCase):
         # then
         self.assertEqual(metric._correct_count.item(), 7)
         self.assertEqual(metric._total_count.item(), 10)
+
+
+class LemmaAccuracyTest(unittest.TestCase):
+
+    def setUp(self) -> None:
+        self.mask: torch.BoolTensor = torch.tensor([
+            [True, True, True, True],
+            [True, True, True, False],
+        ])
+
+    def test_prediction_has_error_in_not_padded_place(self):
+        # given
+        metric = metrics.LemmaAccuracy()
+        predictions = torch.tensor([
+            [[1, 1, 1], [1, 1, 1], [2, 2, 0], [1, 1, 4], ],
+            [[1, 1, 0], [1, 1000, 0], [1, 1, 0], [1, 1, 0], ],
+        ])
+        gold_labels = torch.tensor([
+            [[1, 1, 1], [1, 1, 1], [2, 2, 0], [1, 1, 4], ],
+            [[1, 1, 0], [1, 1, 0], [1, 1, 0], [1, 1, 0], ],
+        ])
+        expected_correct_count = 6
+        expected_total_count = 7
+        expected_correct_indices = torch.tensor([1, 1, 1, 1, 1, 0, 1, 0])
+
+        # when
+        metric(predictions, gold_labels, self.mask)
+
+        # then
+        self.assertEqual(metric._correct_count.item(), expected_correct_count)
+        self.assertEqual(metric._total_count.item(), expected_total_count)
+        self.assertTrue(torch.all(expected_correct_indices.eq(metric.correct_indices)))
+
+    def test_prediction_wrong_prediction_in_padding_should_be_ignored(self):
+        # given
+        metric = metrics.LemmaAccuracy()
+        predictions = torch.tensor([
+            [[1, 1, 1], [1, 1, 1], [2, 2, 0], [1, 1, 4], ],
+            [[1, 1, 1000], [1, 1, 0], [1, 1, 0], [1, 1, 0], ],
+        ])
+        gold_labels = torch.tensor([
+            [[1, 1, 1], [1, 1, 1], [2, 2, 0], [1, 1, 4], ],
+            [[1, 1, 0], [1, 1, 0], [1, 1, 0], [1, 1, 0], ],
+        ])
+        expected_correct_count = 7
+        expected_total_count = 7
+        expected_correct_indices = torch.tensor([1, 1, 1, 1, 1, 1, 1, 0])
+
+        # when
+        metric(predictions, gold_labels, self.mask)
+
+        # then
+        self.assertEqual(expected_correct_count, metric._correct_count.item())
+        self.assertEqual(expected_total_count, metric._total_count.item())
+        self.assertTrue(torch.all(expected_correct_indices.eq(metric.correct_indices)))
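
For reference, a minimal usage sketch of the new LemmaAccuracy metric, mirroring the first test case above. The `from combo.utils import metrics` import path is an assumption inferred from the file paths in this diff; everything else follows the metric's behaviour as exercised by the tests.

```python
import torch

# Assumed import path, inferred from combo/utils/metrics.py in the diff above.
from combo.utils import metrics

metric = metrics.LemmaAccuracy()

# Each token is a padded character-id sequence; character id 0 is padding
# (the metric ignores it via gold_labels.gt(0)).
predictions = torch.tensor([
    [[1, 1, 1], [1, 1, 1], [2, 2, 0], [1, 1, 4]],
    [[1, 1, 0], [1, 1000, 0], [1, 1, 0], [1, 1, 0]],
])
gold_labels = torch.tensor([
    [[1, 1, 1], [1, 1, 1], [2, 2, 0], [1, 1, 4]],
    [[1, 1, 0], [1, 1, 0], [1, 1, 0], [1, 1, 0]],
])
# Word-level mask: the last token of the second sentence is sentence padding.
mask = torch.tensor([
    [True, True, True, True],
    [True, True, True, False],
])

metric(predictions, gold_labels, mask)

# A lemma counts as correct only if every non-padding character matches,
# so this prints 6 correct lemmas out of 7 masked-in tokens (~0.857).
print(metric.get_metric(reset=False))
```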