From 5e79a44e23bc53a8ed6ecec36aff4513de036f8d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maja=20Jab=C5=82o=C5=84ska?= <majajjablonska@gmail.com> Date: Tue, 28 Mar 2023 20:14:08 +0200 Subject: [PATCH] Add metrics and metrics_test --- combo/utils/metrics.py | 381 +++++++++++++++++++++++++++++++++++- tests/utils/test_metrics.py | 211 ++++++++++++++++++++ 2 files changed, 587 insertions(+), 5 deletions(-) create mode 100644 tests/utils/test_metrics.py diff --git a/combo/utils/metrics.py b/combo/utils/metrics.py index 9181040..47904ec 100644 --- a/combo/utils/metrics.py +++ b/combo/utils/metrics.py @@ -1,17 +1,388 @@ +from typing import Optional, List, Dict, Iterable + +import torch +from overrides import overrides + + +""" +Class Metric adapted from AllenNLP +https://github.com/allenai/allennlp/blob/80fb6061e568cb9d6ab5d45b661e86eb61b92c82/allennlp/training/metrics/metric.py +""" + + class Metric: - pass + """ + A very general abstract class representing a metric which can be + accumulated. + """ + + supports_distributed = False + + def __call__( + self, predictions: torch.Tensor, gold_labels: torch.Tensor, mask: Optional[torch.BoolTensor] + ): + """ + # Parameters + predictions : `torch.Tensor`, required. + A tensor of predictions. + gold_labels : `torch.Tensor`, required. + A tensor corresponding to some gold label to evaluate against. + mask : `torch.BoolTensor`, optional (default = `None`). + A mask can be passed, in order to deal with metrics which are + computed over potentially padded elements, such as sequence labels. + """ + raise NotImplementedError + + def get_metric(self, reset: bool): + """ + Compute and return the metric. Optionally also call `self.reset`. + """ + raise NotImplementedError + + def reset(self) -> None: + """ + Reset any accumulators or internal state. + """ + raise NotImplementedError + + @staticmethod + def detach_tensors(*tensors: torch.Tensor) -> Iterable[torch.Tensor]: + """ + If you actually passed gradient-tracking Tensors to a Metric, there will be + a huge memory leak, because it will prevent garbage collection for the computation + graph. This method ensures the tensors are detached. + """ + # Check if it's actually a tensor in case something else was passed. + return (x.detach() if isinstance(x, torch.Tensor) else x for x in tensors) + + +""" +Adapted from COMBO +Author: Mateusz Klimaszewski +""" + class LemmaAccuracy(Metric): - pass + + def __init__(self): + self._correct_count = 0.0 + self._total_count = 0.0 + self.correct_indices = torch.ones([]) + + @overrides + def __call__(self, + predictions: torch.Tensor, + gold_labels: torch.Tensor, + mask: Optional[torch.BoolTensor] = None): + if gold_labels is None: + return + predictions, gold_labels, mask = self.detach_tensors(predictions, + gold_labels, + mask) + + # Some sanity checks. 
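+        # predictions and gold_labels must share the same shape; the mask may
+        # cover either whole tokens (predictions.size()[:-1]) or individual
+        # characters (the full predictions.size()); a token-level mask is
+        # unsqueezed below so it broadcasts over the character dimension.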
+ if gold_labels.size() != predictions.size(): + raise ValueError( + f"gold_labels must have shape == predictions.size() but " + f"found tensor of shape: {gold_labels.size()}" + ) + if mask is not None and mask.size() not in [predictions.size()[:-1], predictions.size()]: + raise ValueError( + f"mask must have shape in one of [predictions.size()[:-1], predictions.size()] but " + f"found tensor of shape: {mask.size()}" + ) + if mask is None: + mask = predictions.new_ones(predictions.size()[:-1]).bool() + if mask.dim() < predictions.dim(): + mask = mask.unsqueeze(-1) + + padding_mask = gold_labels.gt(0) + correct = predictions.eq(gold_labels) * padding_mask + correct = (correct.int().sum(-1) == padding_mask.int().sum(-1)) * mask.squeeze(-1) + correct = correct.float() + + self.correct_indices = correct.flatten().bool() + self._correct_count += correct.sum() + self._total_count += mask.sum() + + @overrides + def get_metric(self, reset: bool) -> float: + if self._total_count > 0: + accuracy = float(self._correct_count) / float(self._total_count) + else: + accuracy = 0.0 + if reset: + self.reset() + return accuracy + + @overrides + def reset(self) -> None: + self._correct_count = 0.0 + self._total_count = 0.0 + self.correct_indices = torch.ones([]) class SequenceBoolAccuracy(Metric): - pass + """BoolAccuracy implementation to handle sequences.""" + + def __init__(self, prod_last_dim: bool = False): + self._correct_count = 0.0 + self._total_count = 0.0 + self.prod_last_dim = prod_last_dim + self.correct_indices = torch.ones([]) + + @overrides + def __call__(self, + predictions: torch.Tensor, + gold_labels: torch.Tensor, + mask: Optional[torch.BoolTensor] = None): + if gold_labels is None: + return + predictions, gold_labels, mask = self.detach_tensors(predictions, + gold_labels, + mask) + + # Some sanity checks. + if gold_labels.size() != predictions.size(): + raise ValueError( + f"gold_labels must have shape == predictions.size() but " + f"found tensor of shape: {gold_labels.size()}" + ) + if mask is not None and mask.size() not in [predictions.size()[:-1], predictions.size()]: + raise ValueError( + f"mask must have shape in one of [predictions.size()[:-1], predictions.size()] but " + f"found tensor of shape: {mask.size()}" + ) + if mask is None: + mask = predictions.new_ones(predictions.size()[:-1]).bool() + if mask.dim() < predictions.dim(): + mask = mask.unsqueeze(-1) + + correct = predictions.eq(gold_labels) * mask + + if self.prod_last_dim: + correct = correct.prod(-1).unsqueeze(-1) + + correct = correct.float() + + self.correct_indices = correct.flatten().bool() + self._correct_count += correct.sum() + self._total_count += mask.sum() + + @overrides + def get_metric(self, reset: bool) -> float: + if self._total_count > 0: + accuracy = float(self._correct_count) / float(self._total_count) + else: + accuracy = 0.0 + if reset: + self.reset() + return accuracy + + @overrides + def reset(self) -> None: + self._correct_count = 0.0 + self._total_count = 0.0 + self.correct_indices = torch.ones([]) class AttachmentScores(Metric): - pass + """ + Computes labeled and unlabeled attachment scores for a + dependency parse, as well as sentence level exact match + for both labeled and unlabeled trees. Note that the input + to this metric is the sampled predictions, not the distribution + itself. + + # Parameters + + ignore_classes : `List[int]`, optional (default = None) + A list of label ids to ignore when computing metrics. 
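+        Padding and out-of-vocabulary label ids are typical candidates; for
+        example, `SemanticMetrics` passes `ignore_classes=[0, 1]` for the
+        enhanced attachment scores.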
+ """ + + def __init__(self, ignore_classes: List[int] = None) -> None: + self._labeled_correct = 0.0 + self._unlabeled_correct = 0.0 + self._exact_labeled_correct = 0.0 + self._exact_unlabeled_correct = 0.0 + self._total_words = 0.0 + self._total_sentences = 0.0 + self.correct_indices = torch.ones([]) + + self._ignore_classes: List[int] = ignore_classes or [] + + def __call__( # type: ignore + self, + predicted_indices: torch.Tensor, + predicted_labels: torch.Tensor, + gold_indices: torch.Tensor, + gold_labels: torch.Tensor, + mask: Optional[torch.BoolTensor] = None, + ): + """ + # Parameters + + predicted_indices : `torch.Tensor`, required. + A tensor of head index predictions of shape (batch_size, timesteps). + predicted_labels : `torch.Tensor`, required. + A tensor of arc label predictions of shape (batch_size, timesteps). + gold_indices : `torch.Tensor`, required. + A tensor of the same shape as `predicted_indices`. + gold_labels : `torch.Tensor`, required. + A tensor of the same shape as `predicted_labels`. + mask : `torch.BoolTensor`, optional (default = None). + A tensor of the same shape as `predicted_indices`. + """ + if gold_labels is None or gold_indices is None: + return + detached = self.detach_tensors( + predicted_indices, predicted_labels, gold_indices, gold_labels, mask + ) + predicted_indices, predicted_labels, gold_indices, gold_labels, mask = detached + + if mask is None: + mask = torch.ones_like(predicted_indices).bool() + + predicted_indices = predicted_indices.long() + predicted_labels = predicted_labels.long() + gold_indices = gold_indices.long() + gold_labels = gold_labels.long() + + # Multiply by a mask denoting locations of + # gold labels which we should ignore. + for label in self._ignore_classes: + label_mask = gold_labels.eq(label) + mask = mask & ~label_mask + + correct_indices = predicted_indices.eq(gold_indices).long() * mask + unlabeled_exact_match = (correct_indices + ~mask).prod(dim=-1) + if len(correct_indices.size()) > 2: + unlabeled_exact_match = unlabeled_exact_match.prod(dim=-1) + correct_labels = predicted_labels.eq(gold_labels).long() * mask + correct_labels_and_indices = correct_indices * correct_labels + self.correct_indices = correct_labels_and_indices.flatten() + labeled_exact_match = (correct_labels_and_indices + ~mask).prod(dim=-1) + if len(correct_indices.size()) > 2: + labeled_exact_match = labeled_exact_match.prod(dim=-1) + + self._unlabeled_correct += correct_indices.sum() + self._exact_unlabeled_correct += unlabeled_exact_match.sum() + self._labeled_correct += correct_labels_and_indices.sum() + self._exact_labeled_correct += labeled_exact_match.sum() + self._total_sentences += correct_indices.size(0) + self._total_words += correct_indices.numel() - (~mask).sum() + + def get_metric(self, reset: bool = False): + """ + # Returns + + The accumulated metrics as a dictionary. 
+ """ + unlabeled_attachment_score = 0.0 + labeled_attachment_score = 0.0 + unlabeled_exact_match = 0.0 + labeled_exact_match = 0.0 + if self._total_words > 0.0: + unlabeled_attachment_score = float(self._unlabeled_correct) / float(self._total_words) + labeled_attachment_score = float(self._labeled_correct) / float(self._total_words) + if self._total_sentences > 0: + unlabeled_exact_match = float(self._exact_unlabeled_correct) / float( + self._total_sentences + ) + labeled_exact_match = float(self._exact_labeled_correct) / float(self._total_sentences) + if reset: + self.reset() + return { + "UAS": unlabeled_attachment_score, + "LAS": labeled_attachment_score, + "UEM": unlabeled_exact_match, + "LEM": labeled_exact_match, + } + + @overrides + def reset(self): + self._labeled_correct = 0.0 + self._unlabeled_correct = 0.0 + self._exact_labeled_correct = 0.0 + self._exact_unlabeled_correct = 0.0 + self._total_words = 0.0 + self._total_sentences = 0.0 + self.correct_indices = torch.ones([]) class SemanticMetrics(Metric): - pass + """Groups metrics for all predictions.""" + + def __init__(self) -> None: + self.upos_score = SequenceBoolAccuracy() + self.xpos_score = SequenceBoolAccuracy() + self.semrel_score = SequenceBoolAccuracy() + self.feats_score = SequenceBoolAccuracy(prod_last_dim=True) + self.lemma_score = LemmaAccuracy() + self.attachment_scores = AttachmentScores() + # Ignore PADDING and OOV + self.enhanced_attachment_scores = AttachmentScores(ignore_classes=[0, 1]) + self.em_score = 0.0 + + def __call__( # type: ignore + self, + predictions: Dict[str, torch.Tensor], + gold_labels: Dict[str, torch.Tensor], + mask: torch.BoolTensor): + self.upos_score(predictions["upostag"], gold_labels["upostag"], mask) + self.xpos_score(predictions["xpostag"], gold_labels["xpostag"], mask) + self.semrel_score(predictions["semrel"], gold_labels["semrel"], mask) + self.feats_score(predictions["feats"], gold_labels["feats"], mask) + self.lemma_score(predictions["lemma"], gold_labels["lemma"], mask) + self.attachment_scores(predictions["head"], + predictions["deprel"], + gold_labels["head"], + gold_labels["deprel"], + mask) + self.enhanced_attachment_scores(predictions["enhanced_head"], + predictions["enhanced_deprel"], + gold_labels["enhanced_head"], + gold_labels["enhanced_deprel"], + mask=None) + enhanced_indices = ( + self.enhanced_attachment_scores.correct_indices.reshape(mask.size(0), mask.size(1) + 1, -1)[:, 1:, 1:].sum( + -1).reshape(-1).bool() + if len(self.enhanced_attachment_scores.correct_indices.size()) > 0 + else self.enhanced_attachment_scores.correct_indices + ) + total = mask.sum() + correct_indices = (self.upos_score.correct_indices * + self.xpos_score.correct_indices * + self.semrel_score.correct_indices * + self.feats_score.correct_indices * + self.lemma_score.correct_indices * + self.attachment_scores.correct_indices * + enhanced_indices) * mask.flatten() + + total, correct_indices = self.detach_tensors(total, correct_indices.float().sum()) + self.em_score = (correct_indices / total).item() + + def get_metric(self, reset: bool) -> Dict[str, float]: + metrics_dict = { + "UPOS_ACC": self.upos_score.get_metric(reset), + "XPOS_ACC": self.xpos_score.get_metric(reset), + "SEMREL_ACC": self.semrel_score.get_metric(reset), + "LEMMA_ACC": self.lemma_score.get_metric(reset), + "FEATS_ACC": self.feats_score.get_metric(reset), + "EM": self.em_score + } + metrics_dict.update(self.attachment_scores.get_metric(reset)) + enhanced_metrics = {f"E{k}": v for k, v in 
self.enhanced_attachment_scores.get_metric(reset).items()} + metrics_dict.update(enhanced_metrics) + return metrics_dict + + def reset(self) -> None: + self.upos_score.reset() + self.xpos_score.reset() + self.semrel_score.reset() + self.lemma_score.reset() + self.feats_score.reset() + self.attachment_scores.reset() + self.enhanced_attachment_scores.reset() + self.em_score = 0.0 + diff --git a/tests/utils/test_metrics.py b/tests/utils/test_metrics.py new file mode 100644 index 0000000..bf7f619 --- /dev/null +++ b/tests/utils/test_metrics.py @@ -0,0 +1,211 @@ +"""Metrics tests.""" +import unittest + +import torch + +from combo.utils import metrics + + +class SemanticMetricsTest(unittest.TestCase): + + def setUp(self) -> None: + self.mask: torch.BoolTensor = torch.tensor([ + [True, True, True, True], + [True, True, True, False], + [True, True, True, False], + ]) + pred = torch.tensor([ + [0, 1, 2, 3], + [0, 1, 2, 3], + [0, 1, 2, 3], + ]) + pred_seq = pred.reshape(3, 4, 1) + gold = pred.clone() + gold_seq = pred_seq.clone() + self.upostag, self.upostag_l = (("upostag", x) for x in [pred, gold]) + self.xpostag, self.xpostag_l = (("xpostag", x) for x in [pred, gold]) + self.semrel, self.semrel_l = (("semrel", x) for x in [pred, gold]) + self.head, self.head_l = (("head", x) for x in [pred, gold]) + self.deprel, self.deprel_l = (("deprel", x) for x in [pred, gold]) + # TODO(mklimasz) Set up an example with size 3x5x5 + self.enhanced_head, self.enhanced_head_l = (("enhanced_head", x) for x in [None, None]) + self.enhanced_deprel, self.enhanced_deprel_l = (("enhanced_deprel", x) for x in [None, None]) + self.feats, self.feats_l = (("feats", x) for x in [pred_seq, gold_seq]) + self.lemma, self.lemma_l = (("lemma", x) for x in [pred_seq, gold_seq]) + self.predictions = dict( + [self.upostag, self.xpostag, self.semrel, self.feats, self.lemma, self.head, self.deprel, + self.enhanced_head, self.enhanced_deprel]) + self.gold_labels = dict([self.upostag_l, self.xpostag_l, self.semrel_l, self.feats_l, self.lemma_l, self.head_l, + self.deprel_l, self.enhanced_head_l, self.enhanced_deprel_l]) + self.eps = 1e-6 + + def test_every_prediction_correct(self): + # given + metric = metrics.SemanticMetrics() + + # when + metric(self.predictions, self.gold_labels, self.mask) + + # then + self.assertEqual(1.0, metric.em_score) + + def test_missing_predictions_for_one_target(self): + # given + metric = metrics.SemanticMetrics() + self.predictions["upostag"] = None + self.gold_labels["upostag"] = None + + # when + metric(self.predictions, self.gold_labels, self.mask) + + # then + self.assertEqual(1.0, metric.em_score) + + def test_missing_predictions_for_two_targets(self): + # given + metric = metrics.SemanticMetrics() + self.predictions["upostag"] = None + self.gold_labels["upostag"] = None + self.predictions["lemma"] = None + self.gold_labels["lemma"] = None + + # when + metric(self.predictions, self.gold_labels, self.mask) + + # then + self.assertEqual(1.0, metric.em_score) + + def test_one_classification_in_one_target_is_wrong(self): + # given + metric = metrics.SemanticMetrics() + self.predictions["upostag"][0][0] = 100 + + # when + metric(self.predictions, self.gold_labels, self.mask) + + # then + self.assertAlmostEqual(0.9, metric.em_score, delta=self.eps) + + def test_classification_errors_and_target_without_predictions(self): + # given + metric = metrics.SemanticMetrics() + self.predictions["feats"] = None + self.gold_labels["feats"] = None + self.predictions["upostag"][0][0] = 100 + 
self.predictions["upostag"][2][0] = 100 + # should be ignored due to masking + self.predictions["upostag"][1][3] = 100 + + # when + metric(self.predictions, self.gold_labels, self.mask) + + # then + self.assertAlmostEqual(0.8, metric.em_score, delta=self.eps) + + +class SequenceBoolAccuracyTest(unittest.TestCase): + + def setUp(self) -> None: + self.mask: torch.BoolTensor = torch.tensor([ + [True, True, True, True], + [True, True, True, False], + [True, True, True, False], + ]) + + def test_regular_classification_accuracy(self): + # given + metric = metrics.SequenceBoolAccuracy() + predictions = torch.tensor([ + [1, 1, 0, 8], + [1, 2, 3, 4], + [9, 4, 3, 9], + ]) + gold_labels = torch.tensor([ + [11, 1, 0, 8], + [14, 2, 3, 14], + [9, 4, 13, 9], + ]) + + # when + metric(predictions, gold_labels, self.mask) + + # then + self.assertEqual(metric._correct_count.item(), 7) + self.assertEqual(metric._total_count.item(), 10) + + def test_feats_classification_accuracy(self): + # given + metric = metrics.SequenceBoolAccuracy(prod_last_dim=True) + # batch_size, sequence_length, classes + predictions = torch.tensor([ + [[1, 4], [0, 2], [0, 2], [0, 3]], + [[1, 4], [0, 2], [0, 2], [0, 3]], + [[1, 4], [0, 2], [0, 2], [0, 3]], + ]) + gold_labels = torch.tensor([ + [[1, 14], [0, 2], [0, 2], [0, 3]], + [[11, 4], [0, 2], [0, 2], [10, 3]], + [[1, 4], [0, 2], [10, 12], [0, 3]], + ]) + + # when + metric(predictions, gold_labels, self.mask) + + # then + self.assertEqual(metric._correct_count.item(), 7) + self.assertEqual(metric._total_count.item(), 10) + + +class LemmaAccuracyTest(unittest.TestCase): + + def setUp(self) -> None: + self.mask: torch.BoolTensor = torch.tensor([ + [True, True, True, True], + [True, True, True, False], + ]) + + def test_prediction_has_error_in_not_padded_place(self): + # given + metric = metrics.LemmaAccuracy() + predictions = torch.tensor([ + [[1, 1, 1], [1, 1, 1], [2, 2, 0], [1, 1, 4], ], + [[1, 1, 0], [1, 1000, 0], [1, 1, 0], [1, 1, 0], ], + ]) + gold_labels = torch.tensor([ + [[1, 1, 1], [1, 1, 1], [2, 2, 0], [1, 1, 4], ], + [[1, 1, 0], [1, 1, 0], [1, 1, 0], [1, 1, 0], ], + ]) + expected_correct_count = 6 + expected_total_count = 7 + expected_correct_indices = torch.tensor([1, 1, 1, 1, 1, 0, 1, 0]) + + # when + metric(predictions, gold_labels, self.mask) + + # then + self.assertEqual(metric._correct_count.item(), expected_correct_count) + self.assertEqual(metric._total_count.item(), expected_total_count) + self.assertTrue(torch.all(expected_correct_indices.eq(metric.correct_indices))) + + def test_prediction_wrong_prediction_in_padding_should_be_ignored(self): + # given + metric = metrics.LemmaAccuracy() + predictions = torch.tensor([ + [[1, 1, 1], [1, 1, 1], [2, 2, 0], [1, 1, 4], ], + [[1, 1, 1000], [1, 1, 0], [1, 1, 0], [1, 1, 0], ], + ]) + gold_labels = torch.tensor([ + [[1, 1, 1], [1, 1, 1], [2, 2, 0], [1, 1, 4], ], + [[1, 1, 0], [1, 1, 0], [1, 1, 0], [1, 1, 0], ], + ]) + expected_correct_count = 7 + expected_total_count = 7 + expected_correct_indices = torch.tensor([1, 1, 1, 1, 1, 1, 1, 0]) + + # when + metric(predictions, gold_labels, self.mask) + + # then + self.assertEqual(expected_correct_count, metric._correct_count.item()) + self.assertEqual(expected_total_count, metric._total_count.item()) + self.assertTrue(torch.all(expected_correct_indices.eq(metric.correct_indices))) -- GitLab