diff --git a/combo/data/__init__.py b/combo/data/__init__.py
index 2d69cca8bb04168159e0c2b3557dabc4b7e18ab2..3c35c824c4f5dace97ae45f9bc38dc9b5b21d83c 100644
--- a/combo/data/__init__.py
+++ b/combo/data/__init__.py
@@ -2,3 +2,4 @@ from .samplers import TokenCountBatchSampler
 from .token import Token
 from .token_indexers import *
 from .api import *
+from .vocabulary import Vocabulary
\ No newline at end of file
diff --git a/combo/data/vocabulary.py b/combo/data/vocabulary.py
index 7ddb7e37fbd1570f895670119f317a27ffc2a2f0..eb48859718b0f88d6a1892454d52d9ea9a673b27 100644
--- a/combo/data/vocabulary.py
+++ b/combo/data/vocabulary.py
@@ -160,7 +160,8 @@ class Vocabulary:
                     print(mapping[i].replace("\n", "@@NEWLINE@@"), file=token_file)
 
     def is_padded(self, namespace: str) -> bool:
-        return self._vocab[namespace].get_itos()[0] == self._padding_token
+        namespace_itos = self._vocab[namespace].get_itos()
+        return len(namespace_itos) > 0 and namespace_itos[0] == self._padding_token
 
     def add_token_to_namespace(self, token: str, namespace: str = DEFAULT_NAMESPACE):
         """
diff --git a/requirements.txt b/requirements.txt
index 75e6626a61db39cda4966e5ba80861da8af6fad5..75adf56aaaff9694b1661fbc2d405a63885a98ef 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -9,4 +9,5 @@ pytorch-lightning~=1.9.0
 requests~=2.28.2
 tqdm~=4.64.1
 urllib3~=1.26.14
-filelock~=3.9.0
\ No newline at end of file
+filelock~=3.9.0
+pytest~=7.2.2
\ No newline at end of file
diff --git a/tests/data/test_vocabulary.py b/tests/data/test_vocabulary.py
new file mode 100644
--- /dev/null
+++ b/tests/data/test_vocabulary.py
@@ -0,0 +1,83 @@
+"""Unit tests for namespace padding and token indexing in ``Vocabulary``."""
+import unittest
+
+from combo.data import Vocabulary
+
+PADDED = 'padded_example'
+NON_PADDED = 'non_padded_example'
+
+
+class VocabularyTest(unittest.TestCase):
+    """Exercise padded and non-padded namespaces of a fresh ``Vocabulary``.
+
+    The assertions expect a padded namespace to hold two reserved entries
+    (indices 0 and 1) before any added token, and a non-padded namespace to
+    start empty at index 0.
+    """
+
+    def setUp(self):
+        # Only NON_PADDED is declared; the tests expect any other namespace,
+        # such as PADDED, to be padded.
+        self.vocab = Vocabulary(non_padded_namespaces=[NON_PADDED])
+
+    def test_empty_padded_namespace_is_padded(self):
+        self.assertTrue(self.vocab.is_padded(PADDED))
+
+    def test_empty_non_padded_namespace_is_not_padded(self):
+        self.assertFalse(self.vocab.is_padded(NON_PADDED))
+
+    def test_add_token_to_padded_namespace_correct_vocab_size(self):
+        self.vocab.add_token_to_namespace('test_token', PADDED)
+        self.assertEqual(self.vocab.get_vocab_size(PADDED), 3)
+
+    def test_add_token_to_padded_namespace_correct_token_index(self):
+        self.vocab.add_token_to_namespace('test_token', PADDED)
+        self.assertEqual(self.vocab.get_token_index('test_token', PADDED), 2)
+
+    def test_add_token_to_padded_namespace_correct_token_for_index(self):
+        self.vocab.add_token_to_namespace('test_token', PADDED)
+        self.assertEqual(self.vocab.get_token_from_index(2, PADDED), 'test_token')
+
+    def test_add_tokens_to_padded_namespace_correct_vocab_size(self):
+        self.vocab.add_tokens_to_namespace(['test_token1', 'test_token2'], PADDED)
+        self.assertEqual(self.vocab.get_vocab_size(PADDED), 4)
+
+    def test_add_tokens_to_padded_namespace_correct_token_index(self):
+        self.vocab.add_tokens_to_namespace(['test_token1', 'test_token2'], PADDED)
+        self.assertEqual(self.vocab.get_token_index('test_token1', PADDED), 2)
+        self.assertEqual(self.vocab.get_token_index('test_token2', PADDED), 3)
+
+    def test_add_tokens_to_padded_namespace_correct_token_for_index(self):
+        self.vocab.add_tokens_to_namespace(['test_token1', 'test_token2'], PADDED)
+        self.assertEqual(self.vocab.get_token_from_index(2, PADDED), 'test_token1')
+        self.assertEqual(self.vocab.get_token_from_index(3, PADDED), 'test_token2')
+
+    def test_add_token_to_non_padded_namespace_correct_vocab_size(self):
+        self.vocab.add_token_to_namespace('test_token', NON_PADDED)
+        self.assertEqual(self.vocab.get_vocab_size(NON_PADDED), 1)
+
+    def test_add_token_to_non_padded_namespace_correct_token_index(self):
+        self.vocab.add_token_to_namespace('test_token', NON_PADDED)
+        self.assertEqual(self.vocab.get_token_index('test_token', NON_PADDED), 0)
+
+    def test_add_token_to_non_padded_namespace_correct_token_for_index(self):
+        self.vocab.add_token_to_namespace('test_token', NON_PADDED)
+        self.assertEqual(self.vocab.get_token_from_index(0, NON_PADDED), 'test_token')
+
+    def test_add_tokens_to_non_padded_namespace_correct_vocab_size(self):
+        self.vocab.add_tokens_to_namespace(['test_token1', 'test_token2'], NON_PADDED)
+        self.assertEqual(self.vocab.get_vocab_size(NON_PADDED), 2)
+
+    def test_add_tokens_to_non_padded_namespace_correct_token_index(self):
+        self.vocab.add_tokens_to_namespace(['test_token1', 'test_token2'], NON_PADDED)
+        self.assertEqual(self.vocab.get_token_index('test_token1', NON_PADDED), 0)
+        self.assertEqual(self.vocab.get_token_index('test_token2', NON_PADDED), 1)
+
+    def test_add_tokens_to_non_padded_namespace_correct_token_for_index(self):
+        self.vocab.add_tokens_to_namespace(['test_token1', 'test_token2'], NON_PADDED)
+        self.assertEqual(self.vocab.get_token_from_index(0, NON_PADDED), 'test_token1')
+        self.assertEqual(self.vocab.get_token_from_index(1, NON_PADDED), 'test_token2')
+
+
+if __name__ == '__main__':
+    unittest.main()