From 557b09e227740031d8be78d2423dc3097d8e4551 Mon Sep 17 00:00:00 2001 From: Maja Jablonska <majajjablonska@gmail.com> Date: Tue, 16 Jan 2024 13:19:38 +0100 Subject: [PATCH] Remove debug prints --- combo/nn/utils.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/combo/nn/utils.py b/combo/nn/utils.py index fbb17df..4333be0 100644 --- a/combo/nn/utils.py +++ b/combo/nn/utils.py @@ -14,12 +14,7 @@ StateDictType = Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"] def masked_cross_entropy(pred: torch.Tensor, true: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor: pred = pred + (mask.float().unsqueeze(-1) + 1e-45).log() - try: - return F.cross_entropy(pred, true, reduction="none") * mask - except Exception as e: - print("pred shape", pred.shape, "true shape", true.shape, "mask shape", mask.shape) - print(F.cross_entropy(pred, true, reduction="none").shape) - raise e + return F.cross_entropy(pred, true, reduction="none") * mask def tiny_value_of_dtype(dtype: torch.dtype): """ -- GitLab