Skip to content
Snippets Groups Projects
Commit 372a01bd authored by Maja Jabłońska, committed by Martyna Wiącek
Browse files

Set up tests for Vocabulary.py

parent 18cb59e7
Branches
Tags
1 merge request: !46 "Merge COMBO 3.0 into master"
...@@ -2,3 +2,4 @@ from .samplers import TokenCountBatchSampler ...@@ -2,3 +2,4 @@ from .samplers import TokenCountBatchSampler
from .token import Token from .token import Token
from .token_indexers import * from .token_indexers import *
from .api import * from .api import *
from .vocabulary import Vocabulary
\ No newline at end of file
...@@ -160,7 +160,8 @@ class Vocabulary: ...@@ -160,7 +160,8 @@ class Vocabulary:
print(mapping[i].replace("\n", "@@NEWLINE@@"), file=token_file) print(mapping[i].replace("\n", "@@NEWLINE@@"), file=token_file)
def is_padded(self, namespace: str) -> bool: def is_padded(self, namespace: str) -> bool:
return self._vocab[namespace].get_itos()[0] == self._padding_token namespace_itos = self._vocab[namespace].get_itos()
return len(namespace_itos) > 0 and namespace_itos[0] == self._padding_token
def add_token_to_namespace(self, token: str, namespace: str = DEFAULT_NAMESPACE): def add_token_to_namespace(self, token: str, namespace: str = DEFAULT_NAMESPACE):
""" """
......
...@@ -9,4 +9,5 @@ pytorch-lightning~=1.9.0 ...@@ -9,4 +9,5 @@ pytorch-lightning~=1.9.0
requests~=2.28.2 requests~=2.28.2
tqdm~=4.64.1 tqdm~=4.64.1
urllib3~=1.26.14 urllib3~=1.26.14
filelock~=3.9.0 filelock~=3.9.0
\ No newline at end of file pytest~=7.2.2
\ No newline at end of file
import unittest
from combo.data import Vocabulary
class VocabularyTest(unittest.TestCase):
    """Tests for ``Vocabulary`` token bookkeeping in padded and non-padded namespaces.

    Each test runs against a fresh ``Vocabulary`` built with
    ``non_padded_namespaces=['non_padded_example']``. The namespace
    ``'padded_example'`` is not listed there, and the tests expect it to be
    treated as padded: its first user-added token is expected at index 2
    (two entries are expected to exist already -- presumably the padding and
    OOV tokens; confirm against ``Vocabulary``), whereas tokens added to the
    non-padded namespace are expected to start at index 0.
    """

    # Namespace names used throughout; only NON_PADDED is passed to Vocabulary.
    PADDED = 'padded_example'
    NON_PADDED = 'non_padded_example'

    def setUp(self):
        """Create a fresh vocabulary before each test."""
        self.vocab = Vocabulary(non_padded_namespaces=[self.NON_PADDED])

    # --- empty namespaces -------------------------------------------------

    def test_empty_padded_namespace_is_padded(self):
        self.assertTrue(self.vocab.is_padded(self.PADDED))

    def test_empty_non_padded_namespace_is_not_padded(self):
        self.assertFalse(self.vocab.is_padded(self.NON_PADDED))

    # --- padded namespace: single token -----------------------------------

    def test_add_token_to_padded_namespace_correct_vocab_size(self):
        self.vocab.add_token_to_namespace('test_token', self.PADDED)
        self.assertEqual(self.vocab.get_vocab_size(self.PADDED), 3)

    def test_add_token_to_padded_namespace_correct_token_index(self):
        self.vocab.add_token_to_namespace('test_token', self.PADDED)
        self.assertEqual(self.vocab.get_token_index('test_token', self.PADDED), 2)

    def test_add_token_to_padded_namespace_correct_token_for_index(self):
        self.vocab.add_token_to_namespace('test_token', self.PADDED)
        self.assertEqual(self.vocab.get_token_from_index(2, self.PADDED), 'test_token')

    # --- padded namespace: several tokens ---------------------------------

    def test_add_tokens_to_padded_namespace_correct_vocab_size(self):
        self.vocab.add_tokens_to_namespace(['test_token1', 'test_token2'], self.PADDED)
        self.assertEqual(self.vocab.get_vocab_size(self.PADDED), 4)

    def test_add_tokens_to_padded_namespace_correct_token_index(self):
        self.vocab.add_tokens_to_namespace(['test_token1', 'test_token2'], self.PADDED)
        self.assertEqual(self.vocab.get_token_index('test_token1', self.PADDED), 2)
        self.assertEqual(self.vocab.get_token_index('test_token2', self.PADDED), 3)

    def test_add_tokens_to_padded_namespace_correct_token_for_index(self):
        self.vocab.add_tokens_to_namespace(['test_token1', 'test_token2'], self.PADDED)
        self.assertEqual(self.vocab.get_token_from_index(2, self.PADDED), 'test_token1')
        self.assertEqual(self.vocab.get_token_from_index(3, self.PADDED), 'test_token2')

    # --- non-padded namespace: single token -------------------------------

    def test_add_token_to_non_padded_namespace_correct_vocab_size(self):
        self.vocab.add_token_to_namespace('test_token', self.NON_PADDED)
        self.assertEqual(self.vocab.get_vocab_size(self.NON_PADDED), 1)

    def test_add_token_to_non_padded_namespace_correct_token_index(self):
        self.vocab.add_token_to_namespace('test_token', self.NON_PADDED)
        self.assertEqual(self.vocab.get_token_index('test_token', self.NON_PADDED), 0)

    def test_add_token_to_non_padded_namespace_correct_token_for_index(self):
        self.vocab.add_token_to_namespace('test_token', self.NON_PADDED)
        self.assertEqual(self.vocab.get_token_from_index(0, self.NON_PADDED), 'test_token')

    # --- non-padded namespace: several tokens -----------------------------

    def test_add_tokens_to_non_padded_namespace_correct_vocab_size(self):
        self.vocab.add_tokens_to_namespace(['test_token1', 'test_token2'], self.NON_PADDED)
        self.assertEqual(self.vocab.get_vocab_size(self.NON_PADDED), 2)

    def test_add_tokens_to_non_padded_namespace_correct_token_index(self):
        self.vocab.add_tokens_to_namespace(['test_token1', 'test_token2'], self.NON_PADDED)
        self.assertEqual(self.vocab.get_token_index('test_token1', self.NON_PADDED), 0)
        self.assertEqual(self.vocab.get_token_index('test_token2', self.NON_PADDED), 1)

    def test_add_tokens_to_non_padded_namespace_correct_token_for_index(self):
        self.vocab.add_tokens_to_namespace(['test_token1', 'test_token2'], self.NON_PADDED)
        self.assertEqual(self.vocab.get_token_from_index(0, self.NON_PADDED), 'test_token1')
        self.assertEqual(self.vocab.get_token_from_index(1, self.NON_PADDED), 'test_token2')
# Run the test suite when this file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment