From 7ffc1072953f43d6b35575fd4e145dc211901f21 Mon Sep 17 00:00:00 2001 From: Mateusz Klimaszewski <mk.klimaszewski@gmail.com> Date: Mon, 21 Dec 2020 14:10:43 +0100 Subject: [PATCH] Fix enhanced dependency parsing metrics. --- combo/utils/metrics.py | 7 ++++++- config.graph.template.jsonnet | 3 --- config.template.jsonnet | 3 --- tests/utils/test_metrics.py | 2 +- 4 files changed, 7 insertions(+), 8 deletions(-) diff --git a/combo/utils/metrics.py b/combo/utils/metrics.py index ae73db8..682e885 100644 --- a/combo/utils/metrics.py +++ b/combo/utils/metrics.py @@ -140,10 +140,14 @@ class AttachmentScores(metrics.Metric): correct_indices = predicted_indices.eq(gold_indices).long() * mask unlabeled_exact_match = (correct_indices + ~mask).prod(dim=-1) + if len(correct_indices.size()) > 2: + unlabeled_exact_match = unlabeled_exact_match.prod(dim=-1) correct_labels = predicted_labels.eq(gold_labels).long() * mask correct_labels_and_indices = correct_indices * correct_labels self.correct_indices = correct_labels_and_indices.flatten() labeled_exact_match = (correct_labels_and_indices + ~mask).prod(dim=-1) + if len(correct_indices.size()) > 2: + labeled_exact_match = labeled_exact_match.prod(dim=-1) self._unlabeled_correct += correct_indices.sum() self._exact_unlabeled_correct += unlabeled_exact_match.sum() @@ -200,7 +204,8 @@ class SemanticMetrics(metrics.Metric): self.feats_score = SequenceBoolAccuracy(prod_last_dim=True) self.lemma_score = SequenceBoolAccuracy(prod_last_dim=True) self.attachment_scores = AttachmentScores() - self.enhanced_attachment_scores = AttachmentScores() + # Ignore PADDING and OOV + self.enhanced_attachment_scores = AttachmentScores(ignore_classes=[0, 1]) self.em_score = 0.0 def __call__( # type: ignore diff --git a/config.graph.template.jsonnet b/config.graph.template.jsonnet index d55cb89..6975aba 100644 --- a/config.graph.template.jsonnet +++ b/config.graph.template.jsonnet @@ -73,8 +73,6 @@ local cycle_loss_n = 0; local word_length = 30; # Whether to use tensorboard, bool local use_tensorboard = if std.extVar("use_tensorboard") == "True" then true else false; -# Path for tensorboard metrics, str -local metrics_dir = "./runs"; # Helper functions local in_features(name) = !(std.length(std.find(name, features)) == 0); @@ -413,7 +411,6 @@ assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't type: "combo_scheduler", }, tensorboard_writer: if use_tensorboard then { - serialization_dir: metrics_dir, should_log_learning_rate: false, should_log_parameter_statistics: false, summary_interval: 100, diff --git a/config.template.jsonnet b/config.template.jsonnet index 8e5ddc9..f41ba62 100644 --- a/config.template.jsonnet +++ b/config.template.jsonnet @@ -71,8 +71,6 @@ local cycle_loss_n = 0; local word_length = 30; # Whether to use tensorboard, bool local use_tensorboard = if std.extVar("use_tensorboard") == "True" then true else false; -# Path for tensorboard metrics, str -local metrics_dir = "./runs"; # Helper functions local in_features(name) = !(std.length(std.find(name, features)) == 0); @@ -382,7 +380,6 @@ assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't type: "combo_scheduler", }, tensorboard_writer: if use_tensorboard then { - serialization_dir: metrics_dir, should_log_learning_rate: false, should_log_parameter_statistics: false, summary_interval: 100, diff --git a/tests/utils/test_metrics.py b/tests/utils/test_metrics.py index 1d1ad3b..242eaa3 100644 --- a/tests/utils/test_metrics.py +++ b/tests/utils/test_metrics.py @@ -27,7 +27,7 @@ class SemanticMetricsTest(unittest.TestCase): self.semrel, self.semrel_l = (("semrel", x) for x in [pred, gold]) self.head, self.head_l = (("head", x) for x in [pred, gold]) self.deprel, self.deprel_l = (("deprel", x) for x in [pred, gold]) - # TODO(mklimasz) Add examples with correct dimension (with ROOT token) + # TODO(mklimasz) Set up an example with size 3x5x5 self.enhanced_head, self.enhanced_head_l = (("enhanced_head", x) for x in [None, None]) self.enhanced_deprel, self.enhanced_deprel_l = (("enhanced_deprel", x) for x in [None, None]) self.feats, self.feats_l = (("feats", x) for x in [pred_seq, gold_seq]) -- GitLab