Skip to content
Snippets Groups Projects
Commit e3dc6fdf authored by Maja Jabłońska
Browse files

Set up tests for Vocabulary.py

parent 56f29c3a
No related branches found
No related tags found
1 merge request: !46 "Merge COMBO 3.0 into master"
......@@ -2,3 +2,4 @@ from .samplers import TokenCountBatchSampler
from .token import Token
from .token_indexers import *
from .api import *
from .vocabulary import Vocabulary
\ No newline at end of file
......@@ -160,7 +160,8 @@ class Vocabulary:
print(mapping[i].replace("\n", "@@NEWLINE@@"), file=token_file)
def is_padded(self, namespace: str) -> bool:
    """Return whether ``namespace`` starts with the padding token.

    A namespace with no entries is reported as not padded. The emptiness
    check must run before indexing, otherwise looking at position 0 of an
    empty namespace raises ``IndexError``.

    :param namespace: name of the vocabulary namespace to inspect.
    :return: ``True`` if the first entry of the namespace equals
        ``self._padding_token``, ``False`` otherwise (including when the
        namespace is empty).
    """
    namespace_itos = self._vocab[namespace].get_itos()
    # Guard first: an empty namespace has no index 0 to compare against.
    return bool(namespace_itos) and namespace_itos[0] == self._padding_token
def add_token_to_namespace(self, token: str, namespace: str = DEFAULT_NAMESPACE):
"""
......
......@@ -10,3 +10,4 @@ requests~=2.28.2
tqdm~=4.64.1
urllib3~=1.26.14
filelock~=3.9.0
pytest~=7.2.2
\ No newline at end of file
import unittest
from combo.data import Vocabulary
class VocabularyTest(unittest.TestCase):
    """Unit tests for padding and token bookkeeping in ``Vocabulary``.

    Each test runs against a fresh ``Vocabulary`` built with
    ``'non_padded_example'`` listed in ``non_padded_namespaces``. The
    namespace ``'padded_example'`` is never declared, and the tests expect it
    to behave as a padded namespace.

    The expected indices encode the contract under test: in the padded
    namespace the first added token is expected at index 2 (presumably after
    the padding and out-of-vocabulary entries -- confirm against
    ``Vocabulary``), while in the non-padded namespace it is expected at
    index 0.
    """

    # Namespace names shared by all tests.
    PADDED = 'padded_example'
    NON_PADDED = 'non_padded_example'

    def setUp(self):
        """Create a fresh vocabulary before every test."""
        self.vocab = Vocabulary(non_padded_namespaces=[self.NON_PADDED])

    # --- padding status of untouched namespaces ---------------------------

    def test_empty_padded_namespace_is_padded(self):
        self.assertTrue(self.vocab.is_padded(self.PADDED))

    def test_empty_non_padded_namespace_is_not_padded(self):
        self.assertFalse(self.vocab.is_padded(self.NON_PADDED))

    # --- single token, padded namespace -----------------------------------

    def test_add_token_to_padded_namespace_correct_vocab_size(self):
        self.vocab.add_token_to_namespace('test_token', self.PADDED)
        self.assertEqual(self.vocab.get_vocab_size(self.PADDED), 3)

    def test_add_token_to_padded_namespace_correct_token_index(self):
        self.vocab.add_token_to_namespace('test_token', self.PADDED)
        self.assertEqual(
            self.vocab.get_token_index('test_token', self.PADDED), 2)

    def test_add_token_to_padded_namespace_correct_token_for_index(self):
        self.vocab.add_token_to_namespace('test_token', self.PADDED)
        self.assertEqual(
            self.vocab.get_token_from_index(2, self.PADDED), 'test_token')

    # --- several tokens, padded namespace ---------------------------------

    def test_add_tokens_to_padded_namespace_correct_vocab_size(self):
        self.vocab.add_tokens_to_namespace(
            ['test_token1', 'test_token2'], self.PADDED)
        self.assertEqual(self.vocab.get_vocab_size(self.PADDED), 4)

    def test_add_tokens_to_padded_namespace_correct_token_index(self):
        self.vocab.add_tokens_to_namespace(
            ['test_token1', 'test_token2'], self.PADDED)
        self.assertEqual(
            self.vocab.get_token_index('test_token1', self.PADDED), 2)
        self.assertEqual(
            self.vocab.get_token_index('test_token2', self.PADDED), 3)

    def test_add_tokens_to_padded_namespace_correct_token_for_index(self):
        self.vocab.add_tokens_to_namespace(
            ['test_token1', 'test_token2'], self.PADDED)
        self.assertEqual(
            self.vocab.get_token_from_index(2, self.PADDED), 'test_token1')
        self.assertEqual(
            self.vocab.get_token_from_index(3, self.PADDED), 'test_token2')

    # --- single token, non-padded namespace -------------------------------

    def test_add_token_to_non_padded_namespace_correct_vocab_size(self):
        self.vocab.add_token_to_namespace('test_token', self.NON_PADDED)
        self.assertEqual(self.vocab.get_vocab_size(self.NON_PADDED), 1)

    def test_add_token_to_non_padded_namespace_correct_token_index(self):
        self.vocab.add_token_to_namespace('test_token', self.NON_PADDED)
        self.assertEqual(
            self.vocab.get_token_index('test_token', self.NON_PADDED), 0)

    def test_add_token_to_non_padded_namespace_correct_token_for_index(self):
        self.vocab.add_token_to_namespace('test_token', self.NON_PADDED)
        self.assertEqual(
            self.vocab.get_token_from_index(0, self.NON_PADDED), 'test_token')

    # --- several tokens, non-padded namespace -----------------------------

    def test_add_tokens_to_non_padded_namespace_correct_vocab_size(self):
        self.vocab.add_tokens_to_namespace(
            ['test_token1', 'test_token2'], self.NON_PADDED)
        self.assertEqual(self.vocab.get_vocab_size(self.NON_PADDED), 2)

    def test_add_tokens_to_non_padded_namespace_correct_token_index(self):
        self.vocab.add_tokens_to_namespace(
            ['test_token1', 'test_token2'], self.NON_PADDED)
        self.assertEqual(
            self.vocab.get_token_index('test_token1', self.NON_PADDED), 0)
        self.assertEqual(
            self.vocab.get_token_index('test_token2', self.NON_PADDED), 1)

    def test_add_tokens_to_non_padded_namespace_correct_token_for_index(self):
        self.vocab.add_tokens_to_namespace(
            ['test_token1', 'test_token2'], self.NON_PADDED)
        self.assertEqual(
            self.vocab.get_token_from_index(0, self.NON_PADDED), 'test_token1')
        self.assertEqual(
            self.vocab.get_token_from_index(1, self.NON_PADDED), 'test_token2')
# Run the test suite when this module is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment