diff --git a/combo/models/graph_parser.py b/combo/models/graph_parser.py index 6799d4994222802af3cd1c6b869bd8a760523ffe..2dc02dc98e20ff7637d2333135857ec81cec009a 100644 --- a/combo/models/graph_parser.py +++ b/combo/models/graph_parser.py @@ -164,9 +164,9 @@ class GraphDependencyRelationModel(base.Predictor): heads_true: torch.Tensor, mask: torch.BoolTensor, sample_weights: torch.Tensor) -> torch.Tensor: - - true = true[true.long() > 0] - pred = pred[heads_true.long() == 1] + correct_heads_mask = heads_true.long() == 1 + true = true[correct_heads_mask] + pred = pred[correct_heads_mask] loss = F.cross_entropy(pred, true.long()) return loss.sum() / pred.size(0) diff --git a/combo/predict.py b/combo/predict.py index 070975f3ed3c7e3e7a75be91d311683bf91ff5f4..e52b42ef94f92239410489082530f564729ff18c 100644 --- a/combo/predict.py +++ b/combo/predict.py @@ -197,8 +197,13 @@ class SemanticMultitaskPredictor(predictor.Predictor): raise NotImplementedError(f"Unknown field name {field_name}!") if "enhanced_head" in predictions and predictions["enhanced_head"]: - graph.sdp_to_dag_deps(arc_scores=np.array(predictions["enhanced_head"]), - rel_scores=np.array(predictions["enhanced_deprel_prob"]), + # TODO off-by-one hotfix, refactor + h = np.array(predictions["enhanced_head"]) + h = np.concatenate((h[-1:], h[:-1])) + r = np.array(predictions["enhanced_deprel_prob"]) + r = np.concatenate((r[-1:], r[:-1])) + graph.sdp_to_dag_deps(arc_scores=h, + rel_scores=r, tree_tokens=tree_tokens, root_idx=self.vocab.get_token_index("root", "deprel_labels"), vocab_index=self.vocab.get_index_to_token_vocabulary("deprel_labels"))