Skip to content
Snippets Groups Projects
Commit 89790343 authored by Maja Jabłońska's avatar Maja Jabłońska
Browse files

Add checks and checks_test

parent 15f4ca6b
No related branches found
No related tags found
1 merge request!46Merge COMBO 3.0 into master
"""
Adapted from COMBO
Author: Mateusz Klimaszewski
"""
import torch
......@@ -8,8 +13,18 @@ class ConfigurationError(Exception):
def file_exists(*paths):
pass
"""Check whether paths exists."""
for path in paths:
if path is None:
raise ConfigurationError("File cannot be None")
if not os.path.exists(path):
raise ConfigurationError(f"Could not find the file at path: '{path}'")
def check_size_match(size_1: torch.Size, size_2: torch.Size, tensor_1_name: str, tensor_2_name: str):
pass
"""Check if tensors' sizes are the same."""
if size_1 != size_2:
raise ConfigurationError(
f"{tensor_1_name} must match {tensor_2_name}, but got {size_1} "
f"and {size_2} instead"
)
\ No newline at end of file
"""Checks tests."""
import unittest
import torch
from combo.utils import checks, ConfigurationError
class SizeCheckTest(unittest.TestCase):
def test_equal_sizes(self):
# given
size = (10, 2)
tensor1 = torch.rand(size)
tensor2 = torch.rand(size)
# when
checks.check_size_match(size_1=tensor1.size(),
size_2=tensor2.size(),
tensor_1_name="", tensor_2_name="")
# then
# nothing happens
self.assertTrue(True)
def test_different_sizes(self):
# given
size1 = (10, 2)
size2 = (20, 1)
tensor1 = torch.rand(size1)
tensor2 = torch.rand(size2)
# when/then
with self.assertRaises(ConfigurationError):
checks.check_size_match(size_1=tensor1.size(),
size_2=tensor2.size(),
tensor_1_name="", tensor_2_name="")
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment