Skip to content
Snippets Groups Projects
Commit 557b09e2 authored by Maja Jablonska's avatar Maja Jablonska
Browse files

Remove debug prints

parent 3c877323
1 merge request!46Merge COMBO 3.0 into master
...@@ -14,12 +14,7 @@ StateDictType = Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"] ...@@ -14,12 +14,7 @@ StateDictType = Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"]
def masked_cross_entropy(pred: torch.Tensor, true: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor: def masked_cross_entropy(pred: torch.Tensor, true: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor:
pred = pred + (mask.float().unsqueeze(-1) + 1e-45).log() pred = pred + (mask.float().unsqueeze(-1) + 1e-45).log()
try: return F.cross_entropy(pred, true, reduction="none") * mask
return F.cross_entropy(pred, true, reduction="none") * mask
except Exception as e:
print("pred shape", pred.shape, "true shape", true.shape, "mask shape", mask.shape)
print(F.cross_entropy(pred, true, reduction="none").shape)
raise e
def tiny_value_of_dtype(dtype: torch.dtype): def tiny_value_of_dtype(dtype: torch.dtype):
""" """
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment