Skip to content
Snippets Groups Projects
Commit 826e57a7 authored by Mateusz Klimaszewski's avatar Mateusz Klimaszewski Committed by Mateusz Klimaszewski
Browse files

Hotfix off by one in enhanced graphs.

parent e295e563
Branches
Tags
2 merge requests!9Enhanced dependency parsing develop to master,!8Enhanced dependency parsing
This commit is part of merge request !8. Comments created here will be created in the context of that merge request.
......@@ -164,9 +164,9 @@ class GraphDependencyRelationModel(base.Predictor):
heads_true: torch.Tensor,
mask: torch.BoolTensor,
sample_weights: torch.Tensor) -> torch.Tensor:
true = true[true.long() > 0]
pred = pred[heads_true.long() == 1]
correct_heads_mask = heads_true.long() == 1
true = true[correct_heads_mask]
pred = pred[correct_heads_mask]
loss = F.cross_entropy(pred, true.long())
return loss.sum() / pred.size(0)
......
......@@ -197,8 +197,13 @@ class SemanticMultitaskPredictor(predictor.Predictor):
raise NotImplementedError(f"Unknown field name {field_name}!")
if "enhanced_head" in predictions and predictions["enhanced_head"]:
graph.sdp_to_dag_deps(arc_scores=np.array(predictions["enhanced_head"]),
rel_scores=np.array(predictions["enhanced_deprel_prob"]),
# TODO off-by-one hotfix, refactor
h = np.array(predictions["enhanced_head"])
h = np.concatenate((h[-1:], h[:-1]))
r = np.array(predictions["enhanced_deprel_prob"])
r = np.concatenate((r[-1:], r[:-1]))
graph.sdp_to_dag_deps(arc_scores=h,
rel_scores=r,
tree_tokens=tree_tokens,
root_idx=self.vocab.get_token_index("root", "deprel_labels"),
vocab_index=self.vocab.get_index_to_token_vocabulary("deprel_labels"))
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment