Skip to content
Snippets Groups Projects

Enhanced dependency parsing develop to master

Merged Mateusz Klimaszewski requested to merge develop into master
Viewing commit 7ffc1072
Show latest version
4 files
+ 7
8
Compare changes
  • Side-by-side
  • Inline
Files
4
+ 6
1
@@ -140,10 +140,14 @@ class AttachmentScores(metrics.Metric):
correct_indices = predicted_indices.eq(gold_indices).long() * mask
unlabeled_exact_match = (correct_indices + ~mask).prod(dim=-1)
if len(correct_indices.size()) > 2:
unlabeled_exact_match = unlabeled_exact_match.prod(dim=-1)
correct_labels = predicted_labels.eq(gold_labels).long() * mask
correct_labels_and_indices = correct_indices * correct_labels
self.correct_indices = correct_labels_and_indices.flatten()
labeled_exact_match = (correct_labels_and_indices + ~mask).prod(dim=-1)
if len(correct_indices.size()) > 2:
labeled_exact_match = labeled_exact_match.prod(dim=-1)
self._unlabeled_correct += correct_indices.sum()
self._exact_unlabeled_correct += unlabeled_exact_match.sum()
@@ -200,7 +204,8 @@ class SemanticMetrics(metrics.Metric):
self.feats_score = SequenceBoolAccuracy(prod_last_dim=True)
self.lemma_score = SequenceBoolAccuracy(prod_last_dim=True)
self.attachment_scores = AttachmentScores()
self.enhanced_attachment_scores = AttachmentScores()
# Ignore PADDING and OOV
self.enhanced_attachment_scores = AttachmentScores(ignore_classes=[0, 1])
self.em_score = 0.0
def __call__( # type: ignore