Skip to content
Snippets Groups Projects
Commit 826e57a7 authored by Mateusz Klimaszewski's avatar Mateusz Klimaszewski Committed by Mateusz Klimaszewski
Browse files

Hotfix off by one in enhanced graphs.

parent e295e563
No related branches found
No related tags found
2 merge requests!9Enhanced dependency parsing develop to master,!8Enhanced dependency parsing
......@@ -164,9 +164,9 @@ class GraphDependencyRelationModel(base.Predictor):
heads_true: torch.Tensor,
mask: torch.BoolTensor,
sample_weights: torch.Tensor) -> torch.Tensor:
true = true[true.long() > 0]
pred = pred[heads_true.long() == 1]
correct_heads_mask = heads_true.long() == 1
true = true[correct_heads_mask]
pred = pred[correct_heads_mask]
loss = F.cross_entropy(pred, true.long())
return loss.sum() / pred.size(0)
......
......@@ -197,8 +197,13 @@ class SemanticMultitaskPredictor(predictor.Predictor):
raise NotImplementedError(f"Unknown field name {field_name}!")
if "enhanced_head" in predictions and predictions["enhanced_head"]:
graph.sdp_to_dag_deps(arc_scores=np.array(predictions["enhanced_head"]),
rel_scores=np.array(predictions["enhanced_deprel_prob"]),
# TODO off-by-one hotfix, refactor
h = np.array(predictions["enhanced_head"])
h = np.concatenate((h[-1:], h[:-1]))
r = np.array(predictions["enhanced_deprel_prob"])
r = np.concatenate((r[-1:], r[:-1]))
graph.sdp_to_dag_deps(arc_scores=h,
rel_scores=r,
tree_tokens=tree_tokens,
root_idx=self.vocab.get_token_index("root", "deprel_labels"),
vocab_index=self.vocab.get_index_to_token_vocabulary("deprel_labels"))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment