diff --git a/combo/utils/checks.py b/combo/utils/checks.py index ea76ac1a49f53219471163f8c83126305f0cf138..d6269a9cf1024fe2a9c33090dd43372a0d8cd277 100644 --- a/combo/utils/checks.py +++ b/combo/utils/checks.py @@ -1,3 +1,8 @@ +""" +Adapted from COMBO +Author: Mateusz Klimaszewski +""" + import torch @@ -8,8 +13,18 @@ class ConfigurationError(Exception): def file_exists(*paths): - pass + """Check whether paths exists.""" + for path in paths: + if path is None: + raise ConfigurationError("File cannot be None") + if not os.path.exists(path): + raise ConfigurationError(f"Could not find the file at path: '{path}'") def check_size_match(size_1: torch.Size, size_2: torch.Size, tensor_1_name: str, tensor_2_name: str): - pass + """Check if tensors' sizes are the same.""" + if size_1 != size_2: + raise ConfigurationError( + f"{tensor_1_name} must match {tensor_2_name}, but got {size_1} " + f"and {size_2} instead" + ) \ No newline at end of file diff --git a/tests/utils/test_checks.py b/tests/utils/test_checks.py new file mode 100644 index 0000000000000000000000000000000000000000..c9409cd4cc97a45cf2479862a062122461c0ddb2 --- /dev/null +++ b/tests/utils/test_checks.py @@ -0,0 +1,37 @@ +"""Checks tests.""" +import unittest + +import torch + +from combo.utils import checks, ConfigurationError + + +class SizeCheckTest(unittest.TestCase): + + def test_equal_sizes(self): + # given + size = (10, 2) + tensor1 = torch.rand(size) + tensor2 = torch.rand(size) + + # when + checks.check_size_match(size_1=tensor1.size(), + size_2=tensor2.size(), + tensor_1_name="", tensor_2_name="") + + # then + # nothing happens + self.assertTrue(True) + + def test_different_sizes(self): + # given + size1 = (10, 2) + size2 = (20, 1) + tensor1 = torch.rand(size1) + tensor2 = torch.rand(size2) + + # when/then + with self.assertRaises(ConfigurationError): + checks.check_size_match(size_1=tensor1.size(), + size_2=tensor2.size(), + tensor_1_name="", tensor_2_name="")