Commit 213310db authored by Maja Jablonska

Fix for unittest

parent 05aab344
Merge request !46: Merge COMBO 3.0 into master
import codecs
import os
+import re
import glob
-from collections import defaultdict, OrderedDict
+from collections import defaultdict
-from typing import Dict, Union, Optional, Iterable, Callable, Any, Set, List
+from typing import Dict, Optional, Iterable, Set, List
-from torchtext.vocab import Vocab as TorchtextVocab
-from torchtext.vocab import vocab as torchtext_vocab
import logging
from filelock import FileLock
from transformers import PreTrainedTokenizer
-from combo.common import Tqdm
+from combo.utils import ConfigurationError
-from combo.data.token_embedders.embedding import EmbeddingsTextFile
+from combo.utils.file_utils import cached_path

logger = logging.Logger(__name__)
@@ -21,6 +20,7 @@ DEFAULT_PADDING_TOKEN = "@@PADDING@@"
DEFAULT_OOV_TOKEN = "@@UNKNOWN@@"
NAMESPACE_PADDING_FILE = "non_padded_namespaces.txt"
DEFAULT_NAMESPACE = "tokens"
+_NEW_LINE_REGEX = re.compile(r"\n|\r\n")


def match_namespace(pattern: str, namespace: str):
@@ -34,24 +34,43 @@ def match_namespace(pattern: str, namespace: str):
    return False

-def _read_pretrained_tokens(embeddings_file_uri: str) -> List[str]:
-    # Moving this import to the top breaks everything (cycling import, I guess)
-    logger.info("Reading pretrained tokens from: %s", embeddings_file_uri)
-    tokens: List[str] = []
-    with EmbeddingsTextFile(embeddings_file_uri) as embeddings_file:
-        for line_number, line in enumerate(Tqdm.tqdm(embeddings_file), start=1):
-            token_end = line.find(" ")
-            if token_end >= 0:
-                token = line[:token_end]
-                tokens.append(token)
-            else:
-                line_begin = line[:20] + "..." if len(line) > 20 else line
-                logger.warning("Skipping line number %d: %s", line_number, line_begin)
-    return tokens
+class NamespaceVocabulary:
+    def __init__(self,
+                 padding_token: Optional[str] = None,
+                 oov_token: Optional[str] = None):
+        if padding_token and oov_token:
+            self._itos = {0: padding_token, 1: oov_token}
+            self._stoi = {padding_token: 0, oov_token: 1}
+        elif not padding_token and not oov_token:
+            self._itos = {}
+            self._stoi = {}
+        else:
+            raise ValueError('Padding token and OOV token must be either both None or both provided')
+
+    def append_token(self, token: str):
+        # TODO: Should I check if tokens are duplicated here?
+        vocab_size = len(self._itos)
+        self._itos[vocab_size] = token
+        self._stoi[token] = vocab_size
+
+    def append_tokens(self, tokens: Iterable[str]):
+        next_index = len(self._itos)
+        for ind, token in enumerate(tokens):
+            self._itos[next_index + ind] = token
+            self._stoi[token] = next_index + ind
+
+    def insert_token(self, token: str, index: int):
+        self._itos[index] = token
+        self._stoi[token] = index
+
+    def get_itos(self) -> Dict[int, str]:
+        return self._itos
+
+    def get_stoi(self) -> Dict[str, int]:
+        return self._stoi

-class _NamespaceDependentDefaultDict(defaultdict[str, TorchtextVocab]):
+class _NamespaceDependentDefaultDict(defaultdict[str, NamespaceVocabulary]):
    def __init__(self,
                 non_padded_namespaces: Iterable[str],
                 padding_token: str,
@@ -61,29 +80,21 @@ class _NamespaceDependentDefaultDict(defaultdict[str, TorchtextVocab]):
        self._oov_token = oov_token
        super().__init__()

+    def add_non_padded_namespaces(self, non_padded_namespaces: Set[str]):
+        self._non_padded_namespaces.update(non_padded_namespaces)
+
    def __missing__(self, namespace: str):
        # Non-padded _namespace
        if any([match_namespace(npn, namespace) for npn in self._non_padded_namespaces]):
-            value = torchtext_vocab(OrderedDict([]))
+            value = NamespaceVocabulary()
        else:
-            value = torchtext_vocab(
-                OrderedDict([
-                    (self._padding_token, 1),
-                    (self._oov_token, 1)
-                ])
-            )
+            value = NamespaceVocabulary(self._padding_token, self._oov_token)
        dict.__setitem__(self, namespace, value)
        return value
-
-    def add_non_padded_namespaces(self, non_padded_namespaces: Set[str]):
-        self._non_padded_namespaces.update(non_padded_namespaces)


class Vocabulary:
    def __init__(self,
-                 counter: Dict[str, Dict[str, int]] = None,
-                 min_count: Dict[str, int] = None,
-                 max_vocab_size: Union[int, Dict[str, int]] = None,
                 non_padded_namespaces: Iterable[str] = DEFAULT_NON_PADDED_NAMESPACES,
                 padding_token: Optional[str] = DEFAULT_PADDING_TOKEN,
                 oov_token: Optional[str] = DEFAULT_OOV_TOKEN):
@@ -105,46 +116,7 @@ class Vocabulary:
    def _extend(self,
                tokens_to_add: Dict[str, Dict[str, int]]):
        for namespace, tokens in tokens_to_add.items():
-            for token in tokens:
-                self._vocab[namespace].append_token(token)
+            self._vocab[namespace].append_tokens(tokens)
-    @classmethod
-    def from_files(cls,
-                   directory: str,
-                   padding_token: Optional[str] = DEFAULT_PADDING_TOKEN,
-                   oov_token: Optional[str] = DEFAULT_OOV_TOKEN) -> None:
-        """
-        Adapted from https://github.com/allenai/allennlp/blob/main/allennlp/data/vocabulary.py
-        :param directory:
-        :param padding_token:
-        :param oov_token:
-        :return:
-        """
-        files = [file for file in glob.glob(os.path.join(directory, '*.txt'))]
-        if len(files) == 0:
-            logger.warning(f'Directory %s is empty' % directory)
-        non_padded_namespaces = []
-        try:
-            with codecs.open(
-                    os.path.join(directory, NAMESPACE_PADDING_FILE), "w", "utf-8"
-            ) as namespace_file:
-                non_padded_namespaces = [namespace.strip() for namespace in namespace_file]
-        except FileNotFoundError:
-            logger.warning("No file %s - all namespaces will be treated as padded namespaces." % NAMESPACE_PADDING_FILE)
-        for file in files:
-            if file.split('/')[-1] == NAMESPACE_PADDING_FILE:
-                # Namespaces file - already read
-                continue
-            namespace_name = file.split('/')[-1].replace('.txt', '')
-            with codecs.open(
-                    file, "w", "utf-8"
-            ) as namespace_tokens_file:
-                tokens = [token.strip() for token in namespace_tokens_file]

    def save_to_files(self, directory: str) -> None:
        """
@@ -232,14 +204,24 @@
        self._non_padded_namespaces.add(namespace)

+    def extend_from_vocab(self, vocab: "Vocabulary") -> None:
+        """
+        Adds all vocabulary items from all namespaces in the given vocabulary to this vocabulary.
+        Useful if you want to load a model and extend its vocabulary from new instances.
+        We also add all non-padded namespaces from the given vocabulary to this vocabulary.
+        """
+        self._non_padded_namespaces.update(vocab._non_padded_namespaces)
+        for namespace in vocab.get_namespaces():
+            for token in vocab.get_token_to_index_vocabulary(namespace):
+                self.add_token_to_namespace(token, namespace)
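
As an editorial illustration (not part of the commit), extend_from_vocab merges another vocabulary into this one; the namespace name below is made up, and add_token_to_namespace is the existing single-token method referenced in the loop above.

base = Vocabulary()
extra = Vocabulary()
extra.add_token_to_namespace("nsubj", "deprel_labels")
base.extend_from_vocab(extra)
assert "nsubj" in base.get_token_to_index_vocabulary("deprel_labels")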

    def get_index_to_token_vocabulary(self, namespace: str = DEFAULT_NAMESPACE) -> Dict[int, str]:
        if not isinstance(namespace, str):
            raise ValueError(
                "Namespace must be passed as string. Received %s with type %s" % (repr(namespace), type(namespace)))
-        itos: List[str] = self._vocab[namespace].get_itos()
-        return {i: s for i, s in enumerate(itos)}
+        return self._vocab[namespace].get_itos()

    def get_token_to_index_vocabulary(self, namespace: str = DEFAULT_NAMESPACE) -> Dict[str, int]:
        if not isinstance(namespace, str):
@@ -266,3 +248,108 @@
    def get_namespaces(self) -> Set[str]:
        return set(self._vocab.keys())

+    def set_from_file(self,
+                      filename: str,
+                      is_padded: bool = True,
+                      oov_token: str = DEFAULT_OOV_TOKEN,
+                      namespace: str = "tokens"):
+        if is_padded:
+            self._vocab[namespace].insert_token(self._padding_token, 0)
+        with codecs.open(filename, "r", "utf-8") as input_file:
+            lines = _NEW_LINE_REGEX.split(input_file.read())
+            # Be flexible about having final newline or not
+            if lines and lines[-1] == "":
+                lines = lines[:-1]
+            for i, line in enumerate(lines):
+                index = i + 1 if is_padded else i
+                token = line.replace("@@NEWLINE@@", "\n")
+                if token == oov_token:
+                    token = self._oov_token
+                self._vocab[namespace].insert_token(token, index)
+        if is_padded:
+            assert self._oov_token in self._vocab[namespace].get_stoi(), "OOV token not found!"
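
A minimal sketch (editorial, not part of the commit), assuming a file tokens.txt (name illustrative) that holds one token per line, e.g. @@UNKNOWN@@, the, @@NEWLINE@@, where @@NEWLINE@@ stands in for a literal newline token.

v = Vocabulary()
v.set_from_file("tokens.txt", is_padded=True, namespace="tokens")
# index 0 -> @@PADDING@@ (reserved), 1 -> @@UNKNOWN@@, 2 -> "the", 3 -> "\n"
print(v.get_index_to_token_vocabulary("tokens"))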

+class PretrainedTransformerVocabulary(Vocabulary):
+    def __init__(self,
+                 model_name: str,
+                 namespace: str = "tokens",
+                 oov_token: Optional[str] = None):
+        """
+        Initialize a vocabulary from the vocabulary of a pretrained transformer model.
+        If `oov_token` is not given, we will try to infer it from the transformer tokenizer.
+        """
+        from combo.common import cached_transformers
+        tokenizer = cached_transformers.get_tokenizer(model_name)
+        if oov_token is None:
+            if hasattr(tokenizer, "_unk_token"):
+                oov_token = tokenizer._unk_token
+            elif hasattr(tokenizer, "special_tokens_map"):
+                oov_token = tokenizer.special_tokens_map.get("unk_token")
+        super().__init__(non_padded_namespaces=[namespace],
+                         oov_token=oov_token)
+        self.add_transformer_vocab(tokenizer, namespace)
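
A usage sketch (editorial, not part of the commit); the model name is illustrative and assumes the pretrained tokenizer can be downloaded or is already cached.

vocab = PretrainedTransformerVocabulary("bert-base-cased")
# one entry per token in the transformer's vocabulary, assuming add_transformer_vocab copies it into the namespace
print(len(vocab.get_token_to_index_vocabulary("tokens")))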

+class FromFilesVocabulary(Vocabulary):
+    def __init__(self,
+                 directory: str,
+                 padding_token: Optional[str] = DEFAULT_PADDING_TOKEN,
+                 oov_token: Optional[str] = DEFAULT_OOV_TOKEN) -> None:
+        """
+        Adapted from https://github.com/allenai/allennlp/blob/main/allennlp/data/vocabulary.py
+        :param directory:
+        :param padding_token:
+        :param oov_token:
+        :return:
+        """
+        logger.info("Loading token dictionary from %s.", directory)
+        padding_token = padding_token if padding_token is not None else DEFAULT_PADDING_TOKEN
+        oov_token = oov_token if oov_token is not None else DEFAULT_OOV_TOKEN
+        if not os.path.isdir(directory):
+            base_directory = cached_path(directory, extract_archive=True)
+            # For convenience we'll check for a 'vocabulary' subdirectory of the archive.
+            # That way you can use model archives directly.
+            vocab_subdir = os.path.join(base_directory, "vocabulary")
+            if os.path.isdir(vocab_subdir):
+                directory = vocab_subdir
+            elif os.path.isdir(base_directory):
+                directory = base_directory
+            else:
+                raise ConfigurationError(f"{directory} is neither a directory nor an archive")
+        files = [file for file in glob.glob(os.path.join(directory, '*.txt'))]
+        if len(files) == 0:
+            logger.warning(f'Directory %s is empty' % directory)
+        with FileLock(os.path.join(directory, ".lock")):
+            with codecs.open(
+                    os.path.join(directory, NAMESPACE_PADDING_FILE), "r", "utf-8"
+            ) as namespace_file:
+                non_padded_namespaces = [namespace_str.strip() for namespace_str in namespace_file]
+            super().__init__(
+                non_padded_namespaces=non_padded_namespaces,
+                padding_token=padding_token,
+                oov_token=oov_token,
+            )
+            for namespace_filename in os.listdir(directory):
+                if namespace_filename == NAMESPACE_PADDING_FILE:
+                    continue
+                if namespace_filename.startswith("."):
+                    continue
+                namespace = namespace_filename.replace(".txt", "")
+                if any(match_namespace(pattern, namespace) for pattern in non_padded_namespaces):
+                    is_padded = False
+                else:
+                    is_padded = True
+                filename = os.path.join(directory, namespace_filename)
+                self.set_from_file(filename, is_padded, namespace=namespace, oov_token=oov_token)
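
A usage sketch (editorial, not part of the commit): loading a vocabulary directory that contains non_padded_namespaces.txt plus one <namespace>.txt file per namespace, as produced by save_to_files; the path is illustrative.

vocab = FromFilesVocabulary("/tmp/model/vocabulary")
print(sorted(vocab.get_namespaces()))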

import unittest
+import os

from combo.data import ConllDatasetReader
@@ -6,19 +7,19 @@ from combo.data import ConllDatasetReader

class ConllDatasetReaderTest(unittest.TestCase):
    def test_read_all_tokens(self):
        reader = ConllDatasetReader(coding_scheme='IOB2')
-        tokens = [token for token in reader('conll_test_file.txt')]
+        tokens = [token for token in reader(os.path.join(os.path.dirname(__file__), 'conll_test_file.txt'))]
        self.assertEqual(len(tokens), 6)

    def test_tokenize_correct_tokens(self):
        reader = ConllDatasetReader(coding_scheme='IOB2')
-        token = next(iter(reader('conll_test_file.txt')))
+        token = next(iter(reader(os.path.join(os.path.dirname(__file__), 'conll_test_file.txt'))))
        self.assertListEqual([str(t) for t in token['tokens'].tokens],
                             ['SOCCER', '-', 'JAPAN', 'GET', 'LUCKY', 'WIN', ',',
                              'CHINA', 'IN', 'SURPRISE', 'DEFEAT', '.'])

    def test_tokenize_correct_tags(self):
        reader = ConllDatasetReader(coding_scheme='IOB2')
-        token = next(iter(reader('conll_test_file.txt')))
+        token = next(iter(reader(os.path.join(os.path.dirname(__file__), 'conll_test_file.txt'))))
        self.assertListEqual(token['tags'].labels,
                             ['O', 'O', 'B-LOC', 'O', 'O', 'O', 'O',
                              'B-PER', 'O', 'O', 'O', 'O'])

import unittest
+import os

from combo.data.dataset_readers import TextClassificationJSONReader
from combo.data.fields import LabelField, TextField, ListField
@@ -8,12 +9,12 @@ from combo.data.tokenizers import SpacySentenceSplitter

class TextClassificationJSONReaderTest(unittest.TestCase):
    def test_read_two_tokens(self):
        reader = TextClassificationJSONReader()
-        tokens = [token for token in reader('text_classification_json_reader.json')]
+        tokens = [token for token in reader(os.path.join(os.path.dirname(__file__), 'text_classification_json_reader.json'))]
        self.assertEqual(len(tokens), 2)

    def test_read_two_examples_fields_without_sentence_splitting(self):
        reader = TextClassificationJSONReader()
-        tokens = [token for token in reader('text_classification_json_reader.json')]
+        tokens = [token for token in reader(os.path.join(os.path.dirname(__file__), 'text_classification_json_reader.json'))]
        self.assertEqual(len(tokens[0].fields.items()), 2)
        self.assertIsInstance(tokens[0].fields.get('label'), LabelField)
        self.assertEqual(tokens[0].fields.get('label').label, 'label1')
@@ -23,7 +24,7 @@ class TextClassificationJSONReaderTest(unittest.TestCase):
    def test_read_two_examples_tokens_without_sentence_splitting(self):
        reader = TextClassificationJSONReader()
-        tokens = [token for token in reader('text_classification_json_reader.json')]
+        tokens = [token for token in reader(os.path.join(os.path.dirname(__file__), 'text_classification_json_reader.json'))]
        self.assertEqual(len(tokens[0].fields.items()), 2)
        self.assertIsInstance(tokens[0].fields.get('tokens'), TextField)
        self.assertEqual(len(tokens[0].fields.get('tokens').tokens), 13)
@@ -33,7 +34,7 @@ class TextClassificationJSONReaderTest(unittest.TestCase):
    def test_read_two_examples_tokens_with_sentence_splitting(self):
        reader = TextClassificationJSONReader(sentence_segmenter=SpacySentenceSplitter())
-        tokens = [token for token in reader('text_classification_json_reader.json')]
+        tokens = [token for token in reader(os.path.join(os.path.dirname(__file__), 'text_classification_json_reader.json'))]
        self.assertEqual(len(tokens[0].fields.items()), 2)
        self.assertIsInstance(tokens[0].fields.get('tokens'), ListField)
        self.assertEqual(len(tokens[0].fields.get('tokens').field_list), 2)
...

import unittest
+import os

from combo.data import UniversalDependenciesDatasetReader
@@ -6,24 +7,24 @@ from combo.data import UniversalDependenciesDatasetReader

class UniversalDependenciesDatasetReaderTest(unittest.TestCase):
    def test_read_all_tokens(self):
        t = UniversalDependenciesDatasetReader()
-        tokens = [token for token in t('tl_trg-ud-test.conllu')]
+        tokens = [token for token in t(os.path.join(os.path.dirname(__file__), 'tl_trg-ud-test.conllu'))]
        self.assertEqual(len(tokens), 128)

    def test_read_text(self):
        t = UniversalDependenciesDatasetReader()
-        token = next(iter(t('tl_trg-ud-test.conllu')))
+        token = next(iter(t(os.path.join(os.path.dirname(__file__), 'tl_trg-ud-test.conllu'))))
        self.assertListEqual([str(t) for t in token['sentence'].tokens],
                             ['Gumising', 'ang', 'bata', '.'])

    def test_read_deprel(self):
        t = UniversalDependenciesDatasetReader()
-        token = next(iter(t('tl_trg-ud-test.conllu')))
+        token = next(iter(t(os.path.join(os.path.dirname(__file__), 'tl_trg-ud-test.conllu'))))
        self.assertListEqual(token['deprel'].labels,
                             ['root', 'case', 'nsubj', 'punct'])

    def test_read_upostag(self):
        t = UniversalDependenciesDatasetReader()
-        token = next(iter(t('tl_trg-ud-test.conllu')))
+        token = next(iter(t(os.path.join(os.path.dirname(__file__), 'tl_trg-ud-test.conllu'))))
        self.assertListEqual(token['upostag'].labels,
                             ['VERB', 'ADP', 'NOUN', 'PUNCT'])