From 7ffc1072953f43d6b35575fd4e145dc211901f21 Mon Sep 17 00:00:00 2001
From: Mateusz Klimaszewski <mk.klimaszewski@gmail.com>
Date: Mon, 21 Dec 2020 14:10:43 +0100
Subject: [PATCH] Fix enhanced dependency parsing metrics.

---
 combo/utils/metrics.py        | 7 ++++++-
 config.graph.template.jsonnet | 3 ---
 config.template.jsonnet       | 3 ---
 tests/utils/test_metrics.py   | 2 +-
 4 files changed, 7 insertions(+), 8 deletions(-)

diff --git a/combo/utils/metrics.py b/combo/utils/metrics.py
index ae73db8..682e885 100644
--- a/combo/utils/metrics.py
+++ b/combo/utils/metrics.py
@@ -140,10 +140,14 @@ class AttachmentScores(metrics.Metric):
 
         correct_indices = predicted_indices.eq(gold_indices).long() * mask
         unlabeled_exact_match = (correct_indices + ~mask).prod(dim=-1)
+        if len(correct_indices.size()) > 2:
+            unlabeled_exact_match = unlabeled_exact_match.prod(dim=-1)
         correct_labels = predicted_labels.eq(gold_labels).long() * mask
         correct_labels_and_indices = correct_indices * correct_labels
         self.correct_indices = correct_labels_and_indices.flatten()
         labeled_exact_match = (correct_labels_and_indices + ~mask).prod(dim=-1)
+        if len(correct_indices.size()) > 2:
+            labeled_exact_match = labeled_exact_match.prod(dim=-1)
 
         self._unlabeled_correct += correct_indices.sum()
         self._exact_unlabeled_correct += unlabeled_exact_match.sum()
@@ -200,7 +204,8 @@ class SemanticMetrics(metrics.Metric):
         self.feats_score = SequenceBoolAccuracy(prod_last_dim=True)
         self.lemma_score = SequenceBoolAccuracy(prod_last_dim=True)
         self.attachment_scores = AttachmentScores()
-        self.enhanced_attachment_scores = AttachmentScores()
+        # Ignore the PADDING and OOV label classes (vocabulary indices 0 and 1)
+        self.enhanced_attachment_scores = AttachmentScores(ignore_classes=[0, 1])
         self.em_score = 0.0
 
     def __call__(  # type: ignore
diff --git a/config.graph.template.jsonnet b/config.graph.template.jsonnet
index d55cb89..6975aba 100644
--- a/config.graph.template.jsonnet
+++ b/config.graph.template.jsonnet
@@ -73,8 +73,6 @@ local cycle_loss_n = 0;
 local word_length = 30;
 # Whether to use tensorboard, bool
 local use_tensorboard = if std.extVar("use_tensorboard") == "True" then true else false;
-# Path for tensorboard metrics, str
-local metrics_dir = "./runs";
 
 # Helper functions
 local in_features(name) = !(std.length(std.find(name, features)) == 0);
@@ -413,7 +411,6 @@ assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't
             type: "combo_scheduler",
         },
         tensorboard_writer: if use_tensorboard then {
-            serialization_dir: metrics_dir,
             should_log_learning_rate: false,
             should_log_parameter_statistics: false,
             summary_interval: 100,
diff --git a/config.template.jsonnet b/config.template.jsonnet
index 8e5ddc9..f41ba62 100644
--- a/config.template.jsonnet
+++ b/config.template.jsonnet
@@ -71,8 +71,6 @@ local cycle_loss_n = 0;
 local word_length = 30;
 # Whether to use tensorboard, bool
 local use_tensorboard = if std.extVar("use_tensorboard") == "True" then true else false;
-# Path for tensorboard metrics, str
-local metrics_dir = "./runs";
 
 # Helper functions
 local in_features(name) = !(std.length(std.find(name, features)) == 0);
@@ -382,7 +380,6 @@ assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't
             type: "combo_scheduler",
         },
         tensorboard_writer: if use_tensorboard then {
-            serialization_dir: metrics_dir,
             should_log_learning_rate: false,
             should_log_parameter_statistics: false,
             summary_interval: 100,
diff --git a/tests/utils/test_metrics.py b/tests/utils/test_metrics.py
index 1d1ad3b..242eaa3 100644
--- a/tests/utils/test_metrics.py
+++ b/tests/utils/test_metrics.py
@@ -27,7 +27,7 @@ class SemanticMetricsTest(unittest.TestCase):
         self.semrel, self.semrel_l = (("semrel", x) for x in [pred, gold])
         self.head, self.head_l = (("head", x) for x in [pred, gold])
         self.deprel, self.deprel_l = (("deprel", x) for x in [pred, gold])
-        # TODO(mklimasz) Add examples with correct dimension (with ROOT token)
+        # TODO(mklimasz) Set up an example with size 3x5x5
         self.enhanced_head, self.enhanced_head_l = (("enhanced_head", x) for x in [None, None])
         self.enhanced_deprel, self.enhanced_deprel_l = (("enhanced_deprel", x) for x in [None, None])
         self.feats, self.feats_l = (("feats", x) for x in [pred_seq, gold_seq])
-- 
GitLab