diff --git a/combo/nn/utils.py b/combo/nn/utils.py index fbb17df555dc00415a97684522321c612b90b049..4333be0a1fc5736c87a27efb18426eaee1c5d2e7 100644 --- a/combo/nn/utils.py +++ b/combo/nn/utils.py @@ -14,12 +14,7 @@ StateDictType = Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"] def masked_cross_entropy(pred: torch.Tensor, true: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor: pred = pred + (mask.float().unsqueeze(-1) + 1e-45).log() - try: - return F.cross_entropy(pred, true, reduction="none") * mask - except Exception as e: - print("pred shape", pred.shape, "true shape", true.shape, "mask shape", mask.shape) - print(F.cross_entropy(pred, true, reduction="none").shape) - raise e + return F.cross_entropy(pred, true, reduction="none") * mask def tiny_value_of_dtype(dtype: torch.dtype): """