From 826e57a756c8ba96b26c06406c03ac57807b6e19 Mon Sep 17 00:00:00 2001
From: Mateusz Klimaszewski <mk.klimaszewski@gmail.com>
Date: Fri, 11 Dec 2020 10:19:12 +0100
Subject: [PATCH] Hotfix off by one in enhanced graphs.

---
 combo/models/graph_parser.py | 6 +++---
 combo/predict.py             | 9 +++++++--
 2 files changed, 10 insertions(+), 5 deletions(-)

diff --git a/combo/models/graph_parser.py b/combo/models/graph_parser.py
index 6799d49..2dc02dc 100644
--- a/combo/models/graph_parser.py
+++ b/combo/models/graph_parser.py
@@ -164,9 +164,9 @@ class GraphDependencyRelationModel(base.Predictor):
               heads_true: torch.Tensor,
               mask: torch.BoolTensor,
               sample_weights: torch.Tensor) -> torch.Tensor:
-
-        true = true[true.long() > 0]
-        pred = pred[heads_true.long() == 1]
+        correct_heads_mask = heads_true.long() == 1
+        true = true[correct_heads_mask]
+        pred = pred[correct_heads_mask]
         loss = F.cross_entropy(pred, true.long())
         return loss.sum() / pred.size(0)
 
diff --git a/combo/predict.py b/combo/predict.py
index 070975f..e52b42e 100644
--- a/combo/predict.py
+++ b/combo/predict.py
@@ -197,8 +197,13 @@ class SemanticMultitaskPredictor(predictor.Predictor):
                     raise NotImplementedError(f"Unknown field name {field_name}!")
 
         if "enhanced_head" in predictions and predictions["enhanced_head"]:
-            graph.sdp_to_dag_deps(arc_scores=np.array(predictions["enhanced_head"]),
-                                  rel_scores=np.array(predictions["enhanced_deprel_prob"]),
+            # TODO off-by-one hotfix, refactor
+            h = np.array(predictions["enhanced_head"])
+            h = np.concatenate((h[-1:], h[:-1]))
+            r = np.array(predictions["enhanced_deprel_prob"])
+            r = np.concatenate((r[-1:], r[:-1]))
+            graph.sdp_to_dag_deps(arc_scores=h,
+                                  rel_scores=r,
                                   tree_tokens=tree_tokens,
                                   root_idx=self.vocab.get_token_index("root", "deprel_labels"),
                                   vocab_index=self.vocab.get_index_to_token_vocabulary("deprel_labels"))
-- 
GitLab