From 8fcf797de7058fe48277de249948ba920352a548 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Maja=20Jab=C5=82o=C5=84ska?= <majajjablonska@gmail.com>
Date: Tue, 28 Mar 2023 20:02:07 +0200
Subject: [PATCH] Add checks and checks_test

---
 combo/utils/checks.py      | 19 +++++++++++++++++--
 tests/utils/test_checks.py | 37 +++++++++++++++++++++++++++++++++++++
 2 files changed, 54 insertions(+), 2 deletions(-)
 create mode 100644 tests/utils/test_checks.py

diff --git a/combo/utils/checks.py b/combo/utils/checks.py
index ea76ac1..d6269a9 100644
--- a/combo/utils/checks.py
+++ b/combo/utils/checks.py
@@ -1,3 +1,8 @@
+"""
+Adapted from COMBO
+Author: Mateusz Klimaszewski
+"""
+
 import torch
 
 
@@ -8,8 +13,18 @@ class ConfigurationError(Exception):
 
 
 def file_exists(*paths):
-    pass
+    """Check whether paths exists."""
+    for path in paths:
+        if path is None:
+            raise ConfigurationError("File cannot be None")
+        if not os.path.exists(path):
+            raise ConfigurationError(f"Could not find the file at path: '{path}'")
 
 
 def check_size_match(size_1: torch.Size, size_2: torch.Size, tensor_1_name: str, tensor_2_name: str):
-    pass
+    """Check if tensors' sizes are the same."""
+    if size_1 != size_2:
+        raise ConfigurationError(
+            f"{tensor_1_name} must match {tensor_2_name}, but got {size_1} "
+            f"and {size_2} instead"
+        )
\ No newline at end of file
diff --git a/tests/utils/test_checks.py b/tests/utils/test_checks.py
new file mode 100644
index 0000000..c9409cd
--- /dev/null
+++ b/tests/utils/test_checks.py
@@ -0,0 +1,37 @@
+"""Checks tests."""
+import unittest
+
+import torch
+
+from combo.utils import checks, ConfigurationError
+
+
+class SizeCheckTest(unittest.TestCase):
+
+    def test_equal_sizes(self):
+        # given
+        size = (10, 2)
+        tensor1 = torch.rand(size)
+        tensor2 = torch.rand(size)
+
+        # when
+        checks.check_size_match(size_1=tensor1.size(),
+                                size_2=tensor2.size(),
+                                tensor_1_name="", tensor_2_name="")
+
+        # then
+        # nothing happens
+        self.assertTrue(True)
+
+    def test_different_sizes(self):
+        # given
+        size1 = (10, 2)
+        size2 = (20, 1)
+        tensor1 = torch.rand(size1)
+        tensor2 = torch.rand(size2)
+
+        # when/then
+        with self.assertRaises(ConfigurationError):
+            checks.check_size_match(size_1=tensor1.size(),
+                                    size_2=tensor2.size(),
+                                    tensor_1_name="", tensor_2_name="")
-- 
GitLab