diff --git a/combo/nn/utils.py b/combo/nn/utils.py
index fbb17df555dc00415a97684522321c612b90b049..4333be0a1fc5736c87a27efb18426eaee1c5d2e7 100644
--- a/combo/nn/utils.py
+++ b/combo/nn/utils.py
@@ -14,12 +14,7 @@ StateDictType = Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"]
 
 def masked_cross_entropy(pred: torch.Tensor, true: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor:
     pred = pred + (mask.float().unsqueeze(-1) + 1e-45).log()
-    try:
-        return F.cross_entropy(pred, true, reduction="none") * mask
-    except Exception as e:
-        print("pred shape", pred.shape, "true shape", true.shape, "mask shape", mask.shape)
-        print(F.cross_entropy(pred, true, reduction="none").shape)
-        raise e
+    return F.cross_entropy(pred, true, reduction="none") * mask
 
 def tiny_value_of_dtype(dtype: torch.dtype):
     """