Skip to content
Snippets Groups Projects
Commit 826e57a7 authored by Mateusz Klimaszewski's avatar Mateusz Klimaszewski Committed by Mateusz Klimaszewski
Browse files

Hotfix off by one in enhanced graphs.

parent e295e563
2 merge requests!9Enhanced dependency parsing develop to master,!8Enhanced dependency parsing
This commit is part of merge request !8. Comments created here will be created in the context of that merge request.
...@@ -164,9 +164,9 @@ class GraphDependencyRelationModel(base.Predictor): ...@@ -164,9 +164,9 @@ class GraphDependencyRelationModel(base.Predictor):
heads_true: torch.Tensor, heads_true: torch.Tensor,
mask: torch.BoolTensor, mask: torch.BoolTensor,
sample_weights: torch.Tensor) -> torch.Tensor: sample_weights: torch.Tensor) -> torch.Tensor:
correct_heads_mask = heads_true.long() == 1
true = true[true.long() > 0] true = true[correct_heads_mask]
pred = pred[heads_true.long() == 1] pred = pred[correct_heads_mask]
loss = F.cross_entropy(pred, true.long()) loss = F.cross_entropy(pred, true.long())
return loss.sum() / pred.size(0) return loss.sum() / pred.size(0)
......
...@@ -197,8 +197,13 @@ class SemanticMultitaskPredictor(predictor.Predictor): ...@@ -197,8 +197,13 @@ class SemanticMultitaskPredictor(predictor.Predictor):
raise NotImplementedError(f"Unknown field name {field_name}!") raise NotImplementedError(f"Unknown field name {field_name}!")
if "enhanced_head" in predictions and predictions["enhanced_head"]: if "enhanced_head" in predictions and predictions["enhanced_head"]:
graph.sdp_to_dag_deps(arc_scores=np.array(predictions["enhanced_head"]), # TODO off-by-one hotfix, refactor
rel_scores=np.array(predictions["enhanced_deprel_prob"]), h = np.array(predictions["enhanced_head"])
h = np.concatenate((h[-1:], h[:-1]))
r = np.array(predictions["enhanced_deprel_prob"])
r = np.concatenate((r[-1:], r[:-1]))
graph.sdp_to_dag_deps(arc_scores=h,
rel_scores=r,
tree_tokens=tree_tokens, tree_tokens=tree_tokens,
root_idx=self.vocab.get_token_index("root", "deprel_labels"), root_idx=self.vocab.get_token_index("root", "deprel_labels"),
vocab_index=self.vocab.get_index_to_token_vocabulary("deprel_labels")) vocab_index=self.vocab.get_index_to_token_vocabulary("deprel_labels"))
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment