Commit 2aa41eba authored by Mateusz Klimaszewski

Remove lemma padding from lemma loss and metric.

parent c0835180
2 merge requests: !15 Merge develop to master, !14 Install and docs
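The change hinges on treating character index 0 as padding and deriving the padding mask from the gold character indices themselves (true.gt(0)), both when normalising the lemma loss and when counting words in the accuracy metric. A minimal sketch of that idea, with made-up tensors (names and shapes are assumptions, not taken from the repository):

import torch

# Hypothetical gold character indices for two words, padded with 0 to length 4.
true = torch.tensor([[5, 7, 2, 0],
                     [3, 0, 0, 0]])

# Padding mask derived from the gold labels: True only for real characters.
char_mask = true.gt(0)  # [[True, True, True, False], [True, False, False, False]]

# Dividing a per-character loss by char_mask.sum() (4 here) averages over real
# characters instead of over all 8 positions, padding included.
valid_positions = char_mask.sum()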
@@ -62,11 +62,11 @@ class LemmatizerModel(base.Predictor):
         BATCH_SIZE, SENTENCE_LENGTH, MAX_WORD_LENGTH, CHAR_CLASSES = pred.size()
         pred = pred.reshape(-1, CHAR_CLASSES)
-        valid_positions = mask.sum()
-        mask = mask.reshape(-1)
         true = true.reshape(-1)
+        mask = true.gt(0)
         loss = utils.masked_cross_entropy(pred, true, mask)
         loss = loss.reshape(BATCH_SIZE, -1) * sample_weights.unsqueeze(-1)
+        valid_positions = mask.sum()
         return loss.sum() / valid_positions

     @classmethod
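The rewritten loss passes the padding-derived mask to utils.masked_cross_entropy and divides the summed loss by the new valid_positions. The helper itself is not part of this diff; a plausible stand-in with the same call shape, assuming it returns a per-position loss with masked positions zeroed, would be:

import torch
import torch.nn.functional as F

def masked_cross_entropy(pred: torch.Tensor,
                         true: torch.Tensor,
                         mask: torch.Tensor) -> torch.Tensor:
    # pred: (N, CHAR_CLASSES) logits, true: (N,) gold indices, mask: (N,) bool.
    # Assumed behaviour: per-position cross entropy, zeroed where mask is False.
    loss = F.cross_entropy(pred, true, reduction="none")
    return loss * mask

Under that assumption, loss.sum() / valid_positions averages the cross entropy over non-padding characters only.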
@@ -6,6 +6,66 @@ from allennlp.training import metrics
 from overrides import overrides


+class LemmaAccuracy(metrics.Metric):
+
+    def __init__(self):
+        self._correct_count = 0.0
+        self._total_count = 0.0
+        self.correct_indices = torch.ones([])
+
+    @overrides
+    def __call__(self,
+                 predictions: torch.Tensor,
+                 gold_labels: torch.Tensor,
+                 mask: Optional[torch.BoolTensor] = None):
+        if gold_labels is None:
+            return
+
+        predictions, gold_labels, mask = self.detach_tensors(predictions,
+                                                             gold_labels,
+                                                             mask)
+
+        # Some sanity checks.
+        if gold_labels.size() != predictions.size():
+            raise ValueError(
+                f"gold_labels must have shape == predictions.size() but "
+                f"found tensor of shape: {gold_labels.size()}"
+            )
+        if mask is not None and mask.size() not in [predictions.size()[:-1], predictions.size()]:
+            raise ValueError(
+                f"mask must have shape in one of [predictions.size()[:-1], predictions.size()] but "
+                f"found tensor of shape: {mask.size()}"
+            )
+        if mask is None:
+            mask = predictions.new_ones(predictions.size()[:-1]).bool()
+        if mask.dim() < predictions.dim():
+            mask = mask.unsqueeze(-1)
+
+        padding_mask = gold_labels.gt(0)
+        correct = predictions.eq(gold_labels) * padding_mask
+        correct = (correct.int().sum(-1) == padding_mask.int().sum(-1)) * mask.squeeze(-1)
+        correct = correct.float()
+
+        self.correct_indices = correct.flatten().bool()
+        self._correct_count += correct.sum()
+        self._total_count += mask.sum()
+
+    @overrides
+    def get_metric(self, reset: bool) -> float:
+        if self._total_count > 0:
+            accuracy = float(self._correct_count) / float(self._total_count)
+        else:
+            accuracy = 0.0
+        if reset:
+            self.reset()
+        return accuracy
+
+    @overrides
+    def reset(self) -> None:
+        self._correct_count = 0.0
+        self._total_count = 0.0
+        self.correct_indices = torch.ones([])
+
+
 class SequenceBoolAccuracy(metrics.Metric):
     """BoolAccuracy implementation to handle sequences."""
@@ -202,7 +262,7 @@ class SemanticMetrics(metrics.Metric):
         self.xpos_score = SequenceBoolAccuracy()
         self.semrel_score = SequenceBoolAccuracy()
         self.feats_score = SequenceBoolAccuracy(prod_last_dim=True)
-        self.lemma_score = SequenceBoolAccuracy(prod_last_dim=True)
+        self.lemma_score = LemmaAccuracy()
         self.attachment_scores = AttachmentScores()
         # Ignore PADDING and OOV
         self.enhanced_attachment_scores = AttachmentScores(ignore_classes=[0, 1])
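For orientation, a self-contained sketch of the per-word check that LemmaAccuracy performs, on toy tensors shaped like the tests below (1 sentence, 2 words, words padded to 3 characters; the data is invented for illustration):

import torch

predictions = torch.tensor([[[4, 2, 0], [7, 7, 1]]])
gold_labels = torch.tensor([[[4, 2, 0], [7, 9, 1]]])
word_mask = torch.tensor([[True, True]])

# A word counts as correct only if every non-padding character matches.
padding_mask = gold_labels.gt(0)
correct = predictions.eq(gold_labels) * padding_mask
correct = correct.int().sum(-1) == padding_mask.int().sum(-1)
correct = correct * word_mask
print(correct)  # tensor([[ True, False]])
print(correct.sum().item() / word_mask.sum().item())  # 0.5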
@@ -154,3 +154,58 @@ class SequenceBoolAccuracyTest(unittest.TestCase):
         # then
         self.assertEqual(metric._correct_count.item(), 7)
         self.assertEqual(metric._total_count.item(), 10)
+
+
+class LemmaAccuracyTest(unittest.TestCase):
+
+    def setUp(self) -> None:
+        self.mask: torch.BoolTensor = torch.tensor([
+            [True, True, True, True],
+            [True, True, True, False],
+        ])
+
+    def test_prediction_has_error_in_not_padded_place(self):
+        # given
+        metric = metrics.LemmaAccuracy()
+        predictions = torch.tensor([
+            [[1, 1, 1], [1, 1, 1], [2, 2, 0], [1, 1, 4]],
+            [[1, 1, 0], [1, 1000, 0], [1, 1, 0], [1, 1, 0]],
+        ])
+        gold_labels = torch.tensor([
+            [[1, 1, 1], [1, 1, 1], [2, 2, 0], [1, 1, 4]],
+            [[1, 1, 0], [1, 1, 0], [1, 1, 0], [1, 1, 0]],
+        ])
+        expected_correct_count = 6
+        expected_total_count = 7
+        expected_correct_indices = torch.tensor([1, 1, 1, 1, 1, 0, 1, 0])
+
+        # when
+        metric(predictions, gold_labels, self.mask)
+
+        # then
+        self.assertEqual(metric._correct_count.item(), expected_correct_count)
+        self.assertEqual(metric._total_count.item(), expected_total_count)
+        self.assertTrue(torch.all(expected_correct_indices.eq(metric.correct_indices)))
+
+    def test_prediction_wrong_prediction_in_padding_should_be_ignored(self):
+        # given
+        metric = metrics.LemmaAccuracy()
+        predictions = torch.tensor([
+            [[1, 1, 1], [1, 1, 1], [2, 2, 0], [1, 1, 4]],
+            [[1, 1, 1000], [1, 1, 0], [1, 1, 0], [1, 1, 0]],
+        ])
+        gold_labels = torch.tensor([
+            [[1, 1, 1], [1, 1, 1], [2, 2, 0], [1, 1, 4]],
+            [[1, 1, 0], [1, 1, 0], [1, 1, 0], [1, 1, 0]],
+        ])
+        expected_correct_count = 7
+        expected_total_count = 7
+        expected_correct_indices = torch.tensor([1, 1, 1, 1, 1, 1, 1, 0])
+
+        # when
+        metric(predictions, gold_labels, self.mask)
+
+        # then
+        self.assertEqual(expected_correct_count, metric._correct_count.item())
+        self.assertEqual(expected_total_count, metric._total_count.item())
+        self.assertTrue(torch.all(expected_correct_indices.eq(metric.correct_indices)))