diff --git a/combo/models/parser.py b/combo/models/parser.py index bb0fc91834778323d594a879c2918ab651cc50b4..9b38be80ac03e7d6ef0435e7e6e4eb1d01e23f40 100644 --- a/combo/models/parser.py +++ b/combo/models/parser.py @@ -155,8 +155,10 @@ class DependencyRelationModel(base.Predictor): output = head_output output["embedding"] = dep_rel_pred #import pdb;pdb.set_trace() - output["deprel_label_distribution"] = F.softmax(relation_prediction[:, 1:, 1:], dim=-1) - output["deprel_tree_distribution"] = head_pred_soft + # output["deprel_label_distribution"] = F.softmax(relation_prediction[:, 1:, 1:], dim=-1) + output["deprel_label_distribution"] = relation_prediction[:, 1:, 1:] + # output["deprel_tree_distribution"] = head_pred_soft + output["deprel_tree_distribution"] = head_pred if self.training: output["prediction"] = (relation_prediction.argmax(-1)[:, 1:], head_output["prediction"])