Skip to content
Snippets Groups Projects
Commit 7ffc1072 authored by Mateusz Klimaszewski's avatar Mateusz Klimaszewski Committed by Mateusz Klimaszewski
Browse files

Fix enhanced dependency parsing metrics.

parent 51be6680
Branches
Tags
2 merge requests!9Enhanced dependency parsing develop to master,!8Enhanced dependency parsing
This commit is part of merge request !9. Comments created here will be created in the context of that merge request.
......@@ -140,10 +140,14 @@ class AttachmentScores(metrics.Metric):
correct_indices = predicted_indices.eq(gold_indices).long() * mask
unlabeled_exact_match = (correct_indices + ~mask).prod(dim=-1)
if len(correct_indices.size()) > 2:
unlabeled_exact_match = unlabeled_exact_match.prod(dim=-1)
correct_labels = predicted_labels.eq(gold_labels).long() * mask
correct_labels_and_indices = correct_indices * correct_labels
self.correct_indices = correct_labels_and_indices.flatten()
labeled_exact_match = (correct_labels_and_indices + ~mask).prod(dim=-1)
if len(correct_indices.size()) > 2:
labeled_exact_match = labeled_exact_match.prod(dim=-1)
self._unlabeled_correct += correct_indices.sum()
self._exact_unlabeled_correct += unlabeled_exact_match.sum()
......@@ -200,7 +204,8 @@ class SemanticMetrics(metrics.Metric):
self.feats_score = SequenceBoolAccuracy(prod_last_dim=True)
self.lemma_score = SequenceBoolAccuracy(prod_last_dim=True)
self.attachment_scores = AttachmentScores()
self.enhanced_attachment_scores = AttachmentScores()
# Ignore PADDING and OOV
self.enhanced_attachment_scores = AttachmentScores(ignore_classes=[0, 1])
self.em_score = 0.0
def __call__( # type: ignore
......
......@@ -73,8 +73,6 @@ local cycle_loss_n = 0;
local word_length = 30;
# Whether to use tensorboard, bool
local use_tensorboard = if std.extVar("use_tensorboard") == "True" then true else false;
# Path for tensorboard metrics, str
local metrics_dir = "./runs";
# Helper functions
local in_features(name) = !(std.length(std.find(name, features)) == 0);
......@@ -413,7 +411,6 @@ assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't
type: "combo_scheduler",
},
tensorboard_writer: if use_tensorboard then {
serialization_dir: metrics_dir,
should_log_learning_rate: false,
should_log_parameter_statistics: false,
summary_interval: 100,
......
......@@ -71,8 +71,6 @@ local cycle_loss_n = 0;
local word_length = 30;
# Whether to use tensorboard, bool
local use_tensorboard = if std.extVar("use_tensorboard") == "True" then true else false;
# Path for tensorboard metrics, str
local metrics_dir = "./runs";
# Helper functions
local in_features(name) = !(std.length(std.find(name, features)) == 0);
......@@ -382,7 +380,6 @@ assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't
type: "combo_scheduler",
},
tensorboard_writer: if use_tensorboard then {
serialization_dir: metrics_dir,
should_log_learning_rate: false,
should_log_parameter_statistics: false,
summary_interval: 100,
......
......@@ -27,7 +27,7 @@ class SemanticMetricsTest(unittest.TestCase):
self.semrel, self.semrel_l = (("semrel", x) for x in [pred, gold])
self.head, self.head_l = (("head", x) for x in [pred, gold])
self.deprel, self.deprel_l = (("deprel", x) for x in [pred, gold])
# TODO(mklimasz) Add examples with correct dimension (with ROOT token)
# TODO(mklimasz) Set up an example with size 3x5x5
self.enhanced_head, self.enhanced_head_l = (("enhanced_head", x) for x in [None, None])
self.enhanced_deprel, self.enhanced_deprel_l = (("enhanced_deprel", x) for x in [None, None])
self.feats, self.feats_l = (("feats", x) for x in [pred_seq, gold_seq])
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment