Commit 5e79a44e authored by Maja Jabłońska

Add metrics and metrics_test

parent 89790343
1 merge request: !46 Merge COMBO 3.0 into master
from typing import Optional, List, Dict, Iterable
import torch
from overrides import overrides
"""
Class Metric adapted from AllenNLP
https://github.com/allenai/allennlp/blob/80fb6061e568cb9d6ab5d45b661e86eb61b92c82/allennlp/training/metrics/metric.py
"""
class Metric:
    """
A very general abstract class representing a metric which can be
accumulated.
"""
supports_distributed = False
def __call__(
self, predictions: torch.Tensor, gold_labels: torch.Tensor, mask: Optional[torch.BoolTensor]
):
"""
# Parameters
predictions : `torch.Tensor`, required.
A tensor of predictions.
gold_labels : `torch.Tensor`, required.
A tensor corresponding to some gold label to evaluate against.
mask : `torch.BoolTensor`, optional (default = `None`).
A mask can be passed, in order to deal with metrics which are
computed over potentially padded elements, such as sequence labels.
"""
raise NotImplementedError
def get_metric(self, reset: bool):
"""
Compute and return the metric. Optionally also call `self.reset`.
"""
raise NotImplementedError
def reset(self) -> None:
"""
Reset any accumulators or internal state.
"""
raise NotImplementedError
@staticmethod
def detach_tensors(*tensors: torch.Tensor) -> Iterable[torch.Tensor]:
"""
If you actually passed gradient-tracking Tensors to a Metric, there will be
a huge memory leak, because it will prevent garbage collection for the computation
graph. This method ensures the tensors are detached.
"""
# Check if it's actually a tensor in case something else was passed.
return (x.detach() if isinstance(x, torch.Tensor) else x for x in tensors)
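
# Illustrative sketch only (not part of the original commit): a minimal concrete
# Metric showing the accumulate / get_metric / reset lifecycle that the classes
# below follow. The name `_ExampleAccuracy` is hypothetical.
class _ExampleAccuracy(Metric):
    """Token-level accuracy over (optionally masked) predictions."""

    def __init__(self) -> None:
        self._correct = 0.0
        self._total = 0.0

    def __call__(self,
                 predictions: torch.Tensor,
                 gold_labels: torch.Tensor,
                 mask: Optional[torch.BoolTensor] = None):
        predictions, gold_labels, mask = self.detach_tensors(predictions, gold_labels, mask)
        if mask is None:
            mask = torch.ones_like(gold_labels).bool()
        correct = predictions.eq(gold_labels) & mask
        self._correct += correct.sum().item()
        self._total += mask.sum().item()

    def get_metric(self, reset: bool) -> float:
        accuracy = self._correct / self._total if self._total > 0 else 0.0
        if reset:
            self.reset()
        return accuracy

    def reset(self) -> None:
        self._correct = 0.0
        self._total = 0.0
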
"""
Adapted from COMBO
Author: Mateusz Klimaszewski
"""
class LemmaAccuracy(Metric):
def __init__(self):
self._correct_count = 0.0
self._total_count = 0.0
self.correct_indices = torch.ones([])
@overrides
def __call__(self,
predictions: torch.Tensor,
gold_labels: torch.Tensor,
mask: Optional[torch.BoolTensor] = None):
if gold_labels is None:
return
predictions, gold_labels, mask = self.detach_tensors(predictions,
gold_labels,
mask)
# Some sanity checks.
if gold_labels.size() != predictions.size():
raise ValueError(
f"gold_labels must have shape == predictions.size() but "
f"found tensor of shape: {gold_labels.size()}"
)
if mask is not None and mask.size() not in [predictions.size()[:-1], predictions.size()]:
raise ValueError(
f"mask must have shape in one of [predictions.size()[:-1], predictions.size()] but "
f"found tensor of shape: {mask.size()}"
)
if mask is None:
mask = predictions.new_ones(predictions.size()[:-1]).bool()
if mask.dim() < predictions.dim():
mask = mask.unsqueeze(-1)
padding_mask = gold_labels.gt(0)
correct = predictions.eq(gold_labels) * padding_mask
correct = (correct.int().sum(-1) == padding_mask.int().sum(-1)) * mask.squeeze(-1)
correct = correct.float()
self.correct_indices = correct.flatten().bool()
self._correct_count += correct.sum()
self._total_count += mask.sum()
@overrides
def get_metric(self, reset: bool) -> float:
if self._total_count > 0:
accuracy = float(self._correct_count) / float(self._total_count)
else:
accuracy = 0.0
if reset:
self.reset()
return accuracy
@overrides
def reset(self) -> None:
self._correct_count = 0.0
self._total_count = 0.0
self.correct_indices = torch.ones([])
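
# Usage sketch (illustrative values only, not part of the original commit):
# LemmaAccuracy scores whole character sequences, so a lemma counts as correct
# only if every non-padding character id matches the gold sequence.
def _example_lemma_accuracy_usage() -> float:
    metric = LemmaAccuracy()
    # One sentence, two "words", three character slots each; 0 is padding inside a lemma.
    predictions = torch.tensor([[[1, 2, 0], [3, 4, 5]]])
    gold_labels = torch.tensor([[[1, 2, 0], [3, 9, 5]]])
    metric(predictions, gold_labels)
    # The first lemma matches on all non-padding characters, the second does not -> 0.5.
    return metric.get_metric(reset=True)
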
class SequenceBoolAccuracy(Metric):
    """BoolAccuracy implementation to handle sequences."""
def __init__(self, prod_last_dim: bool = False):
self._correct_count = 0.0
self._total_count = 0.0
self.prod_last_dim = prod_last_dim
self.correct_indices = torch.ones([])
@overrides
def __call__(self,
predictions: torch.Tensor,
gold_labels: torch.Tensor,
mask: Optional[torch.BoolTensor] = None):
if gold_labels is None:
return
predictions, gold_labels, mask = self.detach_tensors(predictions,
gold_labels,
mask)
# Some sanity checks.
if gold_labels.size() != predictions.size():
raise ValueError(
f"gold_labels must have shape == predictions.size() but "
f"found tensor of shape: {gold_labels.size()}"
)
if mask is not None and mask.size() not in [predictions.size()[:-1], predictions.size()]:
raise ValueError(
f"mask must have shape in one of [predictions.size()[:-1], predictions.size()] but "
f"found tensor of shape: {mask.size()}"
)
if mask is None:
mask = predictions.new_ones(predictions.size()[:-1]).bool()
if mask.dim() < predictions.dim():
mask = mask.unsqueeze(-1)
correct = predictions.eq(gold_labels) * mask
if self.prod_last_dim:
correct = correct.prod(-1).unsqueeze(-1)
correct = correct.float()
self.correct_indices = correct.flatten().bool()
self._correct_count += correct.sum()
self._total_count += mask.sum()
@overrides
def get_metric(self, reset: bool) -> float:
if self._total_count > 0:
accuracy = float(self._correct_count) / float(self._total_count)
else:
accuracy = 0.0
if reset:
self.reset()
return accuracy
@overrides
def reset(self) -> None:
self._correct_count = 0.0
self._total_count = 0.0
self.correct_indices = torch.ones([])
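
# Usage sketch (illustrative values only, not part of the original commit):
# with prod_last_dim=True a token counts as correct only if all of its
# predicted values along the last dimension (e.g. all FEATS categories) match.
def _example_sequence_bool_accuracy_usage() -> float:
    metric = SequenceBoolAccuracy(prod_last_dim=True)
    # batch_size=1, sequence_length=2, two feature categories per token.
    predictions = torch.tensor([[[1, 2], [3, 4]]])
    gold_labels = torch.tensor([[[1, 2], [3, 9]]])
    mask = torch.tensor([[True, True]])
    metric(predictions, gold_labels, mask)
    # Only the first token has every feature correct -> 1 / 2 = 0.5.
    return metric.get_metric(reset=True)
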
class AttachmentScores(Metric):
    """
Computes labeled and unlabeled attachment scores for a
dependency parse, as well as sentence level exact match
for both labeled and unlabeled trees. Note that the input
to this metric is the sampled predictions, not the distribution
itself.
# Parameters
ignore_classes : `List[int]`, optional (default = None)
A list of label ids to ignore when computing metrics.
"""
def __init__(self, ignore_classes: List[int] = None) -> None:
self._labeled_correct = 0.0
self._unlabeled_correct = 0.0
self._exact_labeled_correct = 0.0
self._exact_unlabeled_correct = 0.0
self._total_words = 0.0
self._total_sentences = 0.0
self.correct_indices = torch.ones([])
self._ignore_classes: List[int] = ignore_classes or []
def __call__( # type: ignore
self,
predicted_indices: torch.Tensor,
predicted_labels: torch.Tensor,
gold_indices: torch.Tensor,
gold_labels: torch.Tensor,
mask: Optional[torch.BoolTensor] = None,
):
"""
# Parameters
predicted_indices : `torch.Tensor`, required.
A tensor of head index predictions of shape (batch_size, timesteps).
predicted_labels : `torch.Tensor`, required.
A tensor of arc label predictions of shape (batch_size, timesteps).
gold_indices : `torch.Tensor`, required.
A tensor of the same shape as `predicted_indices`.
gold_labels : `torch.Tensor`, required.
A tensor of the same shape as `predicted_labels`.
mask : `torch.BoolTensor`, optional (default = None).
A tensor of the same shape as `predicted_indices`.
"""
if gold_labels is None or gold_indices is None:
return
detached = self.detach_tensors(
predicted_indices, predicted_labels, gold_indices, gold_labels, mask
)
predicted_indices, predicted_labels, gold_indices, gold_labels, mask = detached
if mask is None:
mask = torch.ones_like(predicted_indices).bool()
predicted_indices = predicted_indices.long()
predicted_labels = predicted_labels.long()
gold_indices = gold_indices.long()
gold_labels = gold_labels.long()
# Multiply by a mask denoting locations of
# gold labels which we should ignore.
for label in self._ignore_classes:
label_mask = gold_labels.eq(label)
mask = mask & ~label_mask
correct_indices = predicted_indices.eq(gold_indices).long() * mask
unlabeled_exact_match = (correct_indices + ~mask).prod(dim=-1)
if len(correct_indices.size()) > 2:
unlabeled_exact_match = unlabeled_exact_match.prod(dim=-1)
correct_labels = predicted_labels.eq(gold_labels).long() * mask
correct_labels_and_indices = correct_indices * correct_labels
self.correct_indices = correct_labels_and_indices.flatten()
labeled_exact_match = (correct_labels_and_indices + ~mask).prod(dim=-1)
if len(correct_indices.size()) > 2:
labeled_exact_match = labeled_exact_match.prod(dim=-1)
self._unlabeled_correct += correct_indices.sum()
self._exact_unlabeled_correct += unlabeled_exact_match.sum()
self._labeled_correct += correct_labels_and_indices.sum()
self._exact_labeled_correct += labeled_exact_match.sum()
self._total_sentences += correct_indices.size(0)
self._total_words += correct_indices.numel() - (~mask).sum()
def get_metric(self, reset: bool = False):
"""
# Returns
The accumulated metrics as a dictionary.
"""
unlabeled_attachment_score = 0.0
labeled_attachment_score = 0.0
unlabeled_exact_match = 0.0
labeled_exact_match = 0.0
if self._total_words > 0.0:
unlabeled_attachment_score = float(self._unlabeled_correct) / float(self._total_words)
labeled_attachment_score = float(self._labeled_correct) / float(self._total_words)
if self._total_sentences > 0:
unlabeled_exact_match = float(self._exact_unlabeled_correct) / float(
self._total_sentences
)
labeled_exact_match = float(self._exact_labeled_correct) / float(self._total_sentences)
if reset:
self.reset()
return {
"UAS": unlabeled_attachment_score,
"LAS": labeled_attachment_score,
"UEM": unlabeled_exact_match,
"LEM": labeled_exact_match,
}
@overrides
def reset(self):
self._labeled_correct = 0.0
self._unlabeled_correct = 0.0
self._exact_labeled_correct = 0.0
self._exact_unlabeled_correct = 0.0
self._total_words = 0.0
self._total_sentences = 0.0
self.correct_indices = torch.ones([])
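
# Usage sketch (illustrative values only, not part of the original commit):
# UAS counts tokens whose predicted head is correct, LAS additionally requires
# the dependency label to be correct.
def _example_attachment_scores_usage() -> Dict[str, float]:
    scores = AttachmentScores()
    predicted_indices = torch.tensor([[0, 1, 2, 1]])
    gold_indices = torch.tensor([[0, 1, 2, 3]])
    predicted_labels = torch.tensor([[5, 6, 7, 8]])
    gold_labels = torch.tensor([[5, 9, 7, 8]])
    mask = torch.tensor([[True, True, True, True]])
    scores(predicted_indices, predicted_labels, gold_indices, gold_labels, mask)
    # 3/4 heads correct -> UAS 0.75; 2/4 heads and labels both correct -> LAS 0.5.
    return scores.get_metric(reset=True)
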
class SemanticMetrics(Metric):
    """Groups metrics for all predictions."""
def __init__(self) -> None:
self.upos_score = SequenceBoolAccuracy()
self.xpos_score = SequenceBoolAccuracy()
self.semrel_score = SequenceBoolAccuracy()
self.feats_score = SequenceBoolAccuracy(prod_last_dim=True)
self.lemma_score = LemmaAccuracy()
self.attachment_scores = AttachmentScores()
# Ignore PADDING and OOV
self.enhanced_attachment_scores = AttachmentScores(ignore_classes=[0, 1])
self.em_score = 0.0
def __call__( # type: ignore
self,
predictions: Dict[str, torch.Tensor],
gold_labels: Dict[str, torch.Tensor],
mask: torch.BoolTensor):
self.upos_score(predictions["upostag"], gold_labels["upostag"], mask)
self.xpos_score(predictions["xpostag"], gold_labels["xpostag"], mask)
self.semrel_score(predictions["semrel"], gold_labels["semrel"], mask)
self.feats_score(predictions["feats"], gold_labels["feats"], mask)
self.lemma_score(predictions["lemma"], gold_labels["lemma"], mask)
self.attachment_scores(predictions["head"],
predictions["deprel"],
gold_labels["head"],
gold_labels["deprel"],
mask)
self.enhanced_attachment_scores(predictions["enhanced_head"],
predictions["enhanced_deprel"],
gold_labels["enhanced_head"],
gold_labels["enhanced_deprel"],
mask=None)
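        # The enhanced graph metric is computed without a mask and its adjacency
        # matrices include the extra 0-th (ROOT) position, so its flattened
        # correctness flags have (sequence_length + 1) ** 2 entries per sentence.
        # The reshape below drops the ROOT row and column and reduces over the
        # possible heads, flagging a token as correct if it has at least one
        # correctly predicted and labeled arc; this yields one boolean per token
        # that can be combined with the other per-token correctness flags.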
enhanced_indices = (
self.enhanced_attachment_scores.correct_indices.reshape(mask.size(0), mask.size(1) + 1, -1)[:, 1:, 1:].sum(
-1).reshape(-1).bool()
if len(self.enhanced_attachment_scores.correct_indices.size()) > 0
else self.enhanced_attachment_scores.correct_indices
)
total = mask.sum()
correct_indices = (self.upos_score.correct_indices *
self.xpos_score.correct_indices *
self.semrel_score.correct_indices *
self.feats_score.correct_indices *
self.lemma_score.correct_indices *
self.attachment_scores.correct_indices *
enhanced_indices) * mask.flatten()
total, correct_indices = self.detach_tensors(total, correct_indices.float().sum())
self.em_score = (correct_indices / total).item()
def get_metric(self, reset: bool) -> Dict[str, float]:
metrics_dict = {
"UPOS_ACC": self.upos_score.get_metric(reset),
"XPOS_ACC": self.xpos_score.get_metric(reset),
"SEMREL_ACC": self.semrel_score.get_metric(reset),
"LEMMA_ACC": self.lemma_score.get_metric(reset),
"FEATS_ACC": self.feats_score.get_metric(reset),
"EM": self.em_score
}
metrics_dict.update(self.attachment_scores.get_metric(reset))
enhanced_metrics = {f"E{k}": v for k, v in self.enhanced_attachment_scores.get_metric(reset).items()}
metrics_dict.update(enhanced_metrics)
return metrics_dict
def reset(self) -> None:
self.upos_score.reset()
self.xpos_score.reset()
self.semrel_score.reset()
self.lemma_score.reset()
self.feats_score.reset()
self.attachment_scores.reset()
self.enhanced_attachment_scores.reset()
self.em_score = 0.0
"""Metrics tests."""
import unittest
import torch
from combo.utils import metrics
class SemanticMetricsTest(unittest.TestCase):
def setUp(self) -> None:
self.mask: torch.BoolTensor = torch.tensor([
[True, True, True, True],
[True, True, True, False],
[True, True, True, False],
])
pred = torch.tensor([
[0, 1, 2, 3],
[0, 1, 2, 3],
[0, 1, 2, 3],
])
pred_seq = pred.reshape(3, 4, 1)
gold = pred.clone()
gold_seq = pred_seq.clone()
self.upostag, self.upostag_l = (("upostag", x) for x in [pred, gold])
self.xpostag, self.xpostag_l = (("xpostag", x) for x in [pred, gold])
self.semrel, self.semrel_l = (("semrel", x) for x in [pred, gold])
self.head, self.head_l = (("head", x) for x in [pred, gold])
self.deprel, self.deprel_l = (("deprel", x) for x in [pred, gold])
# TODO(mklimasz) Set up an example with size 3x5x5
self.enhanced_head, self.enhanced_head_l = (("enhanced_head", x) for x in [None, None])
self.enhanced_deprel, self.enhanced_deprel_l = (("enhanced_deprel", x) for x in [None, None])
self.feats, self.feats_l = (("feats", x) for x in [pred_seq, gold_seq])
self.lemma, self.lemma_l = (("lemma", x) for x in [pred_seq, gold_seq])
self.predictions = dict(
[self.upostag, self.xpostag, self.semrel, self.feats, self.lemma, self.head, self.deprel,
self.enhanced_head, self.enhanced_deprel])
self.gold_labels = dict([self.upostag_l, self.xpostag_l, self.semrel_l, self.feats_l, self.lemma_l, self.head_l,
self.deprel_l, self.enhanced_head_l, self.enhanced_deprel_l])
self.eps = 1e-6
def test_every_prediction_correct(self):
# given
metric = metrics.SemanticMetrics()
# when
metric(self.predictions, self.gold_labels, self.mask)
# then
self.assertEqual(1.0, metric.em_score)
def test_missing_predictions_for_one_target(self):
# given
metric = metrics.SemanticMetrics()
self.predictions["upostag"] = None
self.gold_labels["upostag"] = None
# when
metric(self.predictions, self.gold_labels, self.mask)
# then
self.assertEqual(1.0, metric.em_score)
def test_missing_predictions_for_two_targets(self):
# given
metric = metrics.SemanticMetrics()
self.predictions["upostag"] = None
self.gold_labels["upostag"] = None
self.predictions["lemma"] = None
self.gold_labels["lemma"] = None
# when
metric(self.predictions, self.gold_labels, self.mask)
# then
self.assertEqual(1.0, metric.em_score)
def test_one_classification_in_one_target_is_wrong(self):
# given
metric = metrics.SemanticMetrics()
self.predictions["upostag"][0][0] = 100
# when
metric(self.predictions, self.gold_labels, self.mask)
# then
self.assertAlmostEqual(0.9, metric.em_score, delta=self.eps)
def test_classification_errors_and_target_without_predictions(self):
# given
metric = metrics.SemanticMetrics()
self.predictions["feats"] = None
self.gold_labels["feats"] = None
self.predictions["upostag"][0][0] = 100
self.predictions["upostag"][2][0] = 100
# should be ignored due to masking
self.predictions["upostag"][1][3] = 100
# when
metric(self.predictions, self.gold_labels, self.mask)
# then
self.assertAlmostEqual(0.8, metric.em_score, delta=self.eps)
class SequenceBoolAccuracyTest(unittest.TestCase):
def setUp(self) -> None:
self.mask: torch.BoolTensor = torch.tensor([
[True, True, True, True],
[True, True, True, False],
[True, True, True, False],
])
def test_regular_classification_accuracy(self):
# given
metric = metrics.SequenceBoolAccuracy()
predictions = torch.tensor([
[1, 1, 0, 8],
[1, 2, 3, 4],
[9, 4, 3, 9],
])
gold_labels = torch.tensor([
[11, 1, 0, 8],
[14, 2, 3, 14],
[9, 4, 13, 9],
])
# when
metric(predictions, gold_labels, self.mask)
# then
self.assertEqual(metric._correct_count.item(), 7)
self.assertEqual(metric._total_count.item(), 10)
def test_feats_classification_accuracy(self):
# given
metric = metrics.SequenceBoolAccuracy(prod_last_dim=True)
# batch_size, sequence_length, classes
predictions = torch.tensor([
[[1, 4], [0, 2], [0, 2], [0, 3]],
[[1, 4], [0, 2], [0, 2], [0, 3]],
[[1, 4], [0, 2], [0, 2], [0, 3]],
])
gold_labels = torch.tensor([
[[1, 14], [0, 2], [0, 2], [0, 3]],
[[11, 4], [0, 2], [0, 2], [10, 3]],
[[1, 4], [0, 2], [10, 12], [0, 3]],
])
# when
metric(predictions, gold_labels, self.mask)
# then
self.assertEqual(metric._correct_count.item(), 7)
self.assertEqual(metric._total_count.item(), 10)
class LemmaAccuracyTest(unittest.TestCase):
def setUp(self) -> None:
self.mask: torch.BoolTensor = torch.tensor([
[True, True, True, True],
[True, True, True, False],
])
def test_prediction_has_error_in_not_padded_place(self):
# given
metric = metrics.LemmaAccuracy()
predictions = torch.tensor([
[[1, 1, 1], [1, 1, 1], [2, 2, 0], [1, 1, 4], ],
[[1, 1, 0], [1, 1000, 0], [1, 1, 0], [1, 1, 0], ],
])
gold_labels = torch.tensor([
[[1, 1, 1], [1, 1, 1], [2, 2, 0], [1, 1, 4], ],
[[1, 1, 0], [1, 1, 0], [1, 1, 0], [1, 1, 0], ],
])
expected_correct_count = 6
expected_total_count = 7
expected_correct_indices = torch.tensor([1, 1, 1, 1, 1, 0, 1, 0])
# when
metric(predictions, gold_labels, self.mask)
# then
self.assertEqual(metric._correct_count.item(), expected_correct_count)
self.assertEqual(metric._total_count.item(), expected_total_count)
self.assertTrue(torch.all(expected_correct_indices.eq(metric.correct_indices)))
def test_prediction_wrong_prediction_in_padding_should_be_ignored(self):
# given
metric = metrics.LemmaAccuracy()
predictions = torch.tensor([
[[1, 1, 1], [1, 1, 1], [2, 2, 0], [1, 1, 4], ],
[[1, 1, 1000], [1, 1, 0], [1, 1, 0], [1, 1, 0], ],
])
gold_labels = torch.tensor([
[[1, 1, 1], [1, 1, 1], [2, 2, 0], [1, 1, 4], ],
[[1, 1, 0], [1, 1, 0], [1, 1, 0], [1, 1, 0], ],
])
expected_correct_count = 7
expected_total_count = 7
expected_correct_indices = torch.tensor([1, 1, 1, 1, 1, 1, 1, 0])
# when
metric(predictions, gold_labels, self.mask)
# then
self.assertEqual(expected_correct_count, metric._correct_count.item())
self.assertEqual(expected_total_count, metric._total_count.item())
self.assertTrue(torch.all(expected_correct_indices.eq(metric.correct_indices)))