From 8fcf797de7058fe48277de249948ba920352a548 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maja=20Jab=C5=82o=C5=84ska?= <majajjablonska@gmail.com> Date: Tue, 28 Mar 2023 20:02:07 +0200 Subject: [PATCH] Add checks and checks_test --- combo/utils/checks.py | 19 +++++++++++++++++-- tests/utils/test_checks.py | 37 +++++++++++++++++++++++++++++++++++++ 2 files changed, 54 insertions(+), 2 deletions(-) create mode 100644 tests/utils/test_checks.py diff --git a/combo/utils/checks.py b/combo/utils/checks.py index ea76ac1..d6269a9 100644 --- a/combo/utils/checks.py +++ b/combo/utils/checks.py @@ -1,3 +1,8 @@ +""" +Adapted from COMBO +Author: Mateusz Klimaszewski +""" + import torch @@ -8,8 +13,18 @@ class ConfigurationError(Exception): def file_exists(*paths): - pass + """Check whether paths exists.""" + for path in paths: + if path is None: + raise ConfigurationError("File cannot be None") + if not os.path.exists(path): + raise ConfigurationError(f"Could not find the file at path: '{path}'") def check_size_match(size_1: torch.Size, size_2: torch.Size, tensor_1_name: str, tensor_2_name: str): - pass + """Check if tensors' sizes are the same.""" + if size_1 != size_2: + raise ConfigurationError( + f"{tensor_1_name} must match {tensor_2_name}, but got {size_1} " + f"and {size_2} instead" + ) \ No newline at end of file diff --git a/tests/utils/test_checks.py b/tests/utils/test_checks.py new file mode 100644 index 0000000..c9409cd --- /dev/null +++ b/tests/utils/test_checks.py @@ -0,0 +1,37 @@ +"""Checks tests.""" +import unittest + +import torch + +from combo.utils import checks, ConfigurationError + + +class SizeCheckTest(unittest.TestCase): + + def test_equal_sizes(self): + # given + size = (10, 2) + tensor1 = torch.rand(size) + tensor2 = torch.rand(size) + + # when + checks.check_size_match(size_1=tensor1.size(), + size_2=tensor2.size(), + tensor_1_name="", tensor_2_name="") + + # then + # nothing happens + self.assertTrue(True) + + def test_different_sizes(self): + # given + size1 = (10, 2) + size2 = (20, 1) + tensor1 = torch.rand(size1) + tensor2 = torch.rand(size2) + + # when/then + with self.assertRaises(ConfigurationError): + checks.check_size_match(size_1=tensor1.size(), + size_2=tensor2.size(), + tensor_1_name="", tensor_2_name="") -- GitLab