From 557b09e227740031d8be78d2423dc3097d8e4551 Mon Sep 17 00:00:00 2001
From: Maja Jablonska <majajjablonska@gmail.com>
Date: Tue, 16 Jan 2024 13:19:38 +0100
Subject: [PATCH] Remove debug prints

---
 combo/nn/utils.py | 7 +------
 1 file changed, 1 insertion(+), 6 deletions(-)

diff --git a/combo/nn/utils.py b/combo/nn/utils.py
index fbb17df..4333be0 100644
--- a/combo/nn/utils.py
+++ b/combo/nn/utils.py
@@ -14,12 +14,7 @@ StateDictType = Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"]
 
 def masked_cross_entropy(pred: torch.Tensor, true: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor:
     pred = pred + (mask.float().unsqueeze(-1) + 1e-45).log()
-    try:
-        return F.cross_entropy(pred, true, reduction="none") * mask
-    except Exception as e:
-        print("pred shape", pred.shape, "true shape", true.shape, "mask shape", mask.shape)
-        print(F.cross_entropy(pred, true, reduction="none").shape)
-        raise e
+    return F.cross_entropy(pred, true, reduction="none") * mask
 
 def tiny_value_of_dtype(dtype: torch.dtype):
     """
-- 
GitLab