Skip to content
Snippets Groups Projects
Commit 761b4a25 authored by Mateusz Klimaszewski's avatar Mateusz Klimaszewski
Browse files

Fix masked cross entropy.

parent 08bf1a1e
No related merge requests found
......@@ -80,7 +80,6 @@ class LemmatizerModel(base.Predictor):
@classmethod
def from_vocab(cls,
vocab: data.Vocabulary,
char_vocab_namespace: str,
lemma_vocab_namespace: str,
......
......@@ -3,6 +3,5 @@ import torch.nn.functional as F
def masked_cross_entropy(pred: torch.Tensor, true: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor:
mask = mask.float().unsqueeze(-1)
pred = pred + (mask + 1e-45).log()
pred = pred + (mask.float().unsqueeze(-1) + 1e-45).log()
return F.cross_entropy(pred, true, reduction='none') * mask
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment