Skip to content
Snippets Groups Projects
Commit 7ffc1072 authored by Mateusz Klimaszewski's avatar Mateusz Klimaszewski Committed by Mateusz Klimaszewski
Browse files

Fix enhanced dependency parsing metrics.

parent 51be6680
No related branches found
No related tags found
2 merge requests!9Enhanced dependency parsing develop to master,!8Enhanced dependency parsing
...@@ -140,10 +140,14 @@ class AttachmentScores(metrics.Metric): ...@@ -140,10 +140,14 @@ class AttachmentScores(metrics.Metric):
correct_indices = predicted_indices.eq(gold_indices).long() * mask correct_indices = predicted_indices.eq(gold_indices).long() * mask
unlabeled_exact_match = (correct_indices + ~mask).prod(dim=-1) unlabeled_exact_match = (correct_indices + ~mask).prod(dim=-1)
if len(correct_indices.size()) > 2:
unlabeled_exact_match = unlabeled_exact_match.prod(dim=-1)
correct_labels = predicted_labels.eq(gold_labels).long() * mask correct_labels = predicted_labels.eq(gold_labels).long() * mask
correct_labels_and_indices = correct_indices * correct_labels correct_labels_and_indices = correct_indices * correct_labels
self.correct_indices = correct_labels_and_indices.flatten() self.correct_indices = correct_labels_and_indices.flatten()
labeled_exact_match = (correct_labels_and_indices + ~mask).prod(dim=-1) labeled_exact_match = (correct_labels_and_indices + ~mask).prod(dim=-1)
if len(correct_indices.size()) > 2:
labeled_exact_match = labeled_exact_match.prod(dim=-1)
self._unlabeled_correct += correct_indices.sum() self._unlabeled_correct += correct_indices.sum()
self._exact_unlabeled_correct += unlabeled_exact_match.sum() self._exact_unlabeled_correct += unlabeled_exact_match.sum()
...@@ -200,7 +204,8 @@ class SemanticMetrics(metrics.Metric): ...@@ -200,7 +204,8 @@ class SemanticMetrics(metrics.Metric):
self.feats_score = SequenceBoolAccuracy(prod_last_dim=True) self.feats_score = SequenceBoolAccuracy(prod_last_dim=True)
self.lemma_score = SequenceBoolAccuracy(prod_last_dim=True) self.lemma_score = SequenceBoolAccuracy(prod_last_dim=True)
self.attachment_scores = AttachmentScores() self.attachment_scores = AttachmentScores()
self.enhanced_attachment_scores = AttachmentScores() # Ignore PADDING and OOV
self.enhanced_attachment_scores = AttachmentScores(ignore_classes=[0, 1])
self.em_score = 0.0 self.em_score = 0.0
def __call__( # type: ignore def __call__( # type: ignore
......
...@@ -73,8 +73,6 @@ local cycle_loss_n = 0; ...@@ -73,8 +73,6 @@ local cycle_loss_n = 0;
local word_length = 30; local word_length = 30;
# Whether to use tensorboard, bool # Whether to use tensorboard, bool
local use_tensorboard = if std.extVar("use_tensorboard") == "True" then true else false; local use_tensorboard = if std.extVar("use_tensorboard") == "True" then true else false;
# Path for tensorboard metrics, str
local metrics_dir = "./runs";
# Helper functions # Helper functions
local in_features(name) = !(std.length(std.find(name, features)) == 0); local in_features(name) = !(std.length(std.find(name, features)) == 0);
...@@ -413,7 +411,6 @@ assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't ...@@ -413,7 +411,6 @@ assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't
type: "combo_scheduler", type: "combo_scheduler",
}, },
tensorboard_writer: if use_tensorboard then { tensorboard_writer: if use_tensorboard then {
serialization_dir: metrics_dir,
should_log_learning_rate: false, should_log_learning_rate: false,
should_log_parameter_statistics: false, should_log_parameter_statistics: false,
summary_interval: 100, summary_interval: 100,
......
...@@ -71,8 +71,6 @@ local cycle_loss_n = 0; ...@@ -71,8 +71,6 @@ local cycle_loss_n = 0;
local word_length = 30; local word_length = 30;
# Whether to use tensorboard, bool # Whether to use tensorboard, bool
local use_tensorboard = if std.extVar("use_tensorboard") == "True" then true else false; local use_tensorboard = if std.extVar("use_tensorboard") == "True" then true else false;
# Path for tensorboard metrics, str
local metrics_dir = "./runs";
# Helper functions # Helper functions
local in_features(name) = !(std.length(std.find(name, features)) == 0); local in_features(name) = !(std.length(std.find(name, features)) == 0);
...@@ -382,7 +380,6 @@ assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't ...@@ -382,7 +380,6 @@ assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't
type: "combo_scheduler", type: "combo_scheduler",
}, },
tensorboard_writer: if use_tensorboard then { tensorboard_writer: if use_tensorboard then {
serialization_dir: metrics_dir,
should_log_learning_rate: false, should_log_learning_rate: false,
should_log_parameter_statistics: false, should_log_parameter_statistics: false,
summary_interval: 100, summary_interval: 100,
......
...@@ -27,7 +27,7 @@ class SemanticMetricsTest(unittest.TestCase): ...@@ -27,7 +27,7 @@ class SemanticMetricsTest(unittest.TestCase):
self.semrel, self.semrel_l = (("semrel", x) for x in [pred, gold]) self.semrel, self.semrel_l = (("semrel", x) for x in [pred, gold])
self.head, self.head_l = (("head", x) for x in [pred, gold]) self.head, self.head_l = (("head", x) for x in [pred, gold])
self.deprel, self.deprel_l = (("deprel", x) for x in [pred, gold]) self.deprel, self.deprel_l = (("deprel", x) for x in [pred, gold])
# TODO(mklimasz) Add examples with correct dimension (with ROOT token) # TODO(mklimasz) Set up an example with size 3x5x5
self.enhanced_head, self.enhanced_head_l = (("enhanced_head", x) for x in [None, None]) self.enhanced_head, self.enhanced_head_l = (("enhanced_head", x) for x in [None, None])
self.enhanced_deprel, self.enhanced_deprel_l = (("enhanced_deprel", x) for x in [None, None]) self.enhanced_deprel, self.enhanced_deprel_l = (("enhanced_deprel", x) for x in [None, None])
self.feats, self.feats_l = (("feats", x) for x in [pred_seq, gold_seq]) self.feats, self.feats_l = (("feats", x) for x in [pred_seq, gold_seq])
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment