Commit 213310db authored by Maja Jablonska

Fix for unittest

parent 05aab344
1 merge request: !46 Merge COMBO 3.0 into master
import codecs
import os
import re
import glob
from collections import defaultdict, OrderedDict
from typing import Dict, Union, Optional, Iterable, Callable, Any, Set, List
from collections import defaultdict
from typing import Dict, Optional, Iterable, Set, List
from torchtext.vocab import Vocab as TorchtextVocab
from torchtext.vocab import vocab as torchtext_vocab
import logging
from filelock import FileLock
from transformers import PreTrainedTokenizer
from combo.common import Tqdm
from combo.data.token_embedders.embedding import EmbeddingsTextFile
from combo.utils import ConfigurationError
from combo.utils.file_utils import cached_path
logger = logging.getLogger(__name__)
@@ -21,6 +20,7 @@ DEFAULT_PADDING_TOKEN = "@@PADDING@@"
DEFAULT_OOV_TOKEN = "@@UNKNOWN@@"
NAMESPACE_PADDING_FILE = "non_padded_namespaces.txt"
DEFAULT_NAMESPACE = "tokens"
_NEW_LINE_REGEX = re.compile(r"\n|\r\n")
def match_namespace(pattern: str, namespace: str):
@@ -34,24 +34,43 @@ def match_namespace(pattern: str, namespace: str):
return False
def _read_pretrained_tokens(embeddings_file_uri: str) -> List[str]:
    # Moving this import to the top breaks everything (cycling import, I guess)
    logger.info("Reading pretrained tokens from: %s", embeddings_file_uri)
    tokens: List[str] = []
    with EmbeddingsTextFile(embeddings_file_uri) as embeddings_file:
        for line_number, line in enumerate(Tqdm.tqdm(embeddings_file), start=1):
            token_end = line.find(" ")
            if token_end >= 0:
                token = line[:token_end]
                tokens.append(token)
            else:
                line_begin = line[:20] + "..." if len(line) > 20 else line
                logger.warning("Skipping line number %d: %s", line_number, line_begin)
    return tokens


class NamespaceVocabulary:
    def __init__(self,
                 padding_token: Optional[str] = None,
                 oov_token: Optional[str] = None):
        if padding_token and oov_token:
            self._itos = {0: padding_token, 1: oov_token}
            self._stoi = {padding_token: 0, oov_token: 1}
        elif not padding_token and not oov_token:
            self._itos = {}
            self._stoi = {}
        else:
            raise ValueError('Padding token and OOV token must be either both None or both provided')

    def append_token(self, token: str):
        # TODO: Should I check if tokens are duplicated here?
        vocab_size = len(self._itos)
        self._itos[vocab_size] = token
        self._stoi[token] = vocab_size

    def append_tokens(self, tokens: Iterable[str]):
        next_index = len(self._itos)
        for ind, token in enumerate(tokens):
            self._itos[next_index + ind] = token
            self._stoi[token] = next_index + ind

    def insert_token(self, token: str, index: int):
        self._itos[index] = token
        self._stoi[token] = index

    def get_itos(self) -> Dict[int, str]:
        return self._itos

    def get_stoi(self) -> Dict[str, int]:
        return self._stoi


class _NamespaceDependentDefaultDict(defaultdict[str, TorchtextVocab]):
class _NamespaceDependentDefaultDict(defaultdict[str, NamespaceVocabulary]):
def __init__(self,
non_padded_namespaces: Iterable[str],
padding_token: str,
@@ -61,29 +80,21 @@ class _NamespaceDependentDefaultDict(defaultdict[str, TorchtextVocab]):
self._oov_token = oov_token
super().__init__()
def add_non_padded_namespaces(self, non_padded_namespaces: Set[str]):
self._non_padded_namespaces.update(non_padded_namespaces)
def __missing__(self, namespace: str):
# Non-padded namespace
if any([match_namespace(npn, namespace) for npn in self._non_padded_namespaces]):
value = torchtext_vocab(OrderedDict([]))
value = NamespaceVocabulary()
else:
value = torchtext_vocab(
OrderedDict([
(self._padding_token, 1),
(self._oov_token, 1)
])
)
value = NamespaceVocabulary(self._padding_token, self._oov_token)
dict.__setitem__(self, namespace, value)
return value
def add_non_padded_namespaces(self, non_padded_namespaces: Set[str]):
self._non_padded_namespaces.update(non_padded_namespaces)
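A minimal sketch of how the two classes above are meant to behave (illustrative only, not part of the commit; it assumes match_namespace returns True for an exact pattern match and that the elided lines of __init__ store the remaining arguments on self):

# Illustrative sketch, not part of the commit.
namespaces = _NamespaceDependentDefaultDict(
    non_padded_namespaces=["lemmas"],
    padding_token=DEFAULT_PADDING_TOKEN,
    oov_token=DEFAULT_OOV_TOKEN,
)
padded = namespaces["tokens"]       # missing key -> vocabulary seeded with padding/OOV entries
padded.append_tokens(["cat", "dog"])
assert padded.get_stoi() == {DEFAULT_PADDING_TOKEN: 0, DEFAULT_OOV_TOKEN: 1, "cat": 2, "dog": 3}
assert padded.get_itos()[2] == "cat"
non_padded = namespaces["lemmas"]   # matches a non-padded pattern -> starts empty
assert non_padded.get_stoi() == {}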
class Vocabulary:
def __init__(self,
counter: Dict[str, Dict[str, int]] = None,
min_count: Dict[str, int] = None,
max_vocab_size: Union[int, Dict[str, int]] = None,
non_padded_namespaces: Iterable[str] = DEFAULT_NON_PADDED_NAMESPACES,
padding_token: Optional[str] = DEFAULT_PADDING_TOKEN,
oov_token: Optional[str] = DEFAULT_OOV_TOKEN):
@@ -105,46 +116,7 @@ class Vocabulary:
def _extend(self,
tokens_to_add: Dict[str, Dict[str, int]]):
for namespace, tokens in tokens_to_add.items():
for token in tokens:
self._vocab[namespace].append_token(token)
@classmethod
def from_files(cls,
directory: str,
padding_token: Optional[str] = DEFAULT_PADDING_TOKEN,
oov_token: Optional[str] = DEFAULT_OOV_TOKEN) -> None:
"""
Adapted from https://github.com/allenai/allennlp/blob/main/allennlp/data/vocabulary.py
:param directory:
:param padding_token:
:param oov_token:
:return:
"""
files = [file for file in glob.glob(os.path.join(directory, '*.txt'))]
if len(files) == 0:
logger.warning(f'Directory %s is empty' % directory)
non_padded_namespaces = []
try:
with codecs.open(
os.path.join(directory, NAMESPACE_PADDING_FILE), "w", "utf-8"
) as namespace_file:
non_padded_namespaces = [namespace.strip() for namespace in namespace_file]
except FileNotFoundError:
logger.warning("No file %s - all namespaces will be treated as padded namespaces." % NAMESPACE_PADDING_FILE)
for file in files:
if file.split('/')[-1] == NAMESPACE_PADDING_FILE:
# Namespaces file - already read
continue
namespace_name = file.split('/')[-1].replace('.txt', '')
with codecs.open(
file, "w", "utf-8"
) as namespace_tokens_file:
tokens = [token.strip() for token in namespace_tokens_file]
self._vocab[namespace].append_tokens(tokens)
def save_to_files(self, directory: str) -> None:
"""
@@ -215,7 +187,7 @@ class Vocabulary:
self._vocab[namespace].append_token(token)
def add_transformer_vocab(
self, tokenizer: PreTrainedTokenizer, namespace: str = "tokens"
self, tokenizer: PreTrainedTokenizer, namespace: str = "tokens"
) -> None:
"""
Copies tokens from a transformer tokenizer's vocab into the given namespace.
@@ -232,14 +204,24 @@ class Vocabulary:
self._non_padded_namespaces.add(namespace)
def extend_from_vocab(self, vocab: "Vocabulary") -> None:
"""
Adds all vocabulary items from all namespaces in the given vocabulary to this vocabulary.
Useful if you want to load a model and extend its vocabulary from new instances.
We also add all non-padded namespaces from the given vocabulary to this vocabulary.
"""
self._non_padded_namespaces.update(vocab._non_padded_namespaces)
for namespace in vocab.get_namespaces():
for token in vocab.get_token_to_index_vocabulary(namespace):
self.add_token_to_namespace(token, namespace)
def get_index_to_token_vocabulary(self, namespace: str = DEFAULT_NAMESPACE) -> Dict[int, str]:
if not isinstance(namespace, str):
raise ValueError(
"Namespace must be passed as string. Received %s with type %s" % (repr(namespace), type(namespace)))
itos: List[str] = self._vocab[namespace].get_itos()
return {i: s for i, s in enumerate(itos)}
return self._vocab[namespace].get_itos()
def get_token_to_index_vocabulary(self, namespace: str = DEFAULT_NAMESPACE) -> Dict[str, int]:
if not isinstance(namespace, str):
@@ -266,3 +248,108 @@ class Vocabulary:
def get_namespaces(self) -> Set[str]:
return set(self._vocab.keys())
def set_from_file(self,
filename: str,
is_padded: bool = True,
oov_token: str = DEFAULT_OOV_TOKEN,
namespace: str = "tokens"):
if is_padded:
self._vocab[namespace].insert_token(self._padding_token, 0)
with codecs.open(filename, "r", "utf-8") as input_file:
lines = _NEW_LINE_REGEX.split(input_file.read())
# Be flexible about having final newline or not
if lines and lines[-1] == "":
lines = lines[:-1]
for i, line in enumerate(lines):
index = i + 1 if is_padded else i
token = line.replace("@@NEWLINE@@", "\n")
if token == oov_token:
token = self._oov_token
self._vocab[namespace].insert_token(token, index)
if is_padded:
assert self._oov_token in self._vocab[namespace].get_stoi(), "OOV token not found!"
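A hedged sketch of extend_from_vocab in use (illustrative only; it assumes Vocabulary() can be constructed with the default None counter, which the elided constructor body appears to allow):

# Illustrative sketch, not part of the commit.
base = Vocabulary()
extra = Vocabulary()
extra.add_token_to_namespace("neologism", "tokens")
base.extend_from_vocab(extra)
assert "neologism" in base.get_token_to_index_vocabulary("tokens")
assert "tokens" in base.get_namespaces()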
class PretrainedTransformerVocabulary(Vocabulary):
def __init__(self,
model_name: str,
namespace: str = "tokens",
oov_token: Optional[str] = None):
"""
Initialize a vocabulary from the vocabulary of a pretrained transformer model.
If `oov_token` is not given, we will try to infer it from the transformer tokenizer.
"""
from combo.common import cached_transformers
tokenizer = cached_transformers.get_tokenizer(model_name)
if oov_token is None:
if hasattr(tokenizer, "_unk_token"):
oov_token = tokenizer._unk_token
elif hasattr(tokenizer, "special_tokens_map"):
oov_token = tokenizer.special_tokens_map.get("unk_token")
super().__init__(non_padded_namespaces=[namespace],
oov_token=oov_token)
self.add_transformer_vocab(tokenizer, namespace)
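A sketch of the intended call (illustrative only; "bert-base-cased" is just an example model name, and building the vocabulary fetches the tokenizer via cached_transformers.get_tokenizer):

# Illustrative sketch, not part of the commit; the tokenizer must be downloadable or cached locally.
vocab = PretrainedTransformerVocabulary(model_name="bert-base-cased", namespace="tokens")
print(len(vocab.get_token_to_index_vocabulary("tokens")))  # size of the copied tokenizer vocab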
class FromFilesVocabulary(Vocabulary):
def __init__(self,
directory: str,
padding_token: Optional[str] = DEFAULT_PADDING_TOKEN,
oov_token: Optional[str] = DEFAULT_OOV_TOKEN) -> None:
"""
Adapted from https://github.com/allenai/allennlp/blob/main/allennlp/data/vocabulary.py
:param directory:
:param padding_token:
:param oov_token:
:return:
"""
logger.info("Loading token dictionary from %s.", directory)
padding_token = padding_token if padding_token is not None else DEFAULT_PADDING_TOKEN
oov_token = oov_token if oov_token is not None else DEFAULT_OOV_TOKEN
if not os.path.isdir(directory):
base_directory = cached_path(directory, extract_archive=True)
# For convenience we'll check for a 'vocabulary' subdirectory of the archive.
# That way you can use model archives directly.
vocab_subdir = os.path.join(base_directory, "vocabulary")
if os.path.isdir(vocab_subdir):
directory = vocab_subdir
elif os.path.isdir(base_directory):
directory = base_directory
else:
raise ConfigurationError(f"{directory} is neither a directory nor an archive")
files = [file for file in glob.glob(os.path.join(directory, '*.txt'))]
if len(files) == 0:
logger.warning(f'Directory %s is empty' % directory)
with FileLock(os.path.join(directory, ".lock")):
with codecs.open(
os.path.join(directory, NAMESPACE_PADDING_FILE), "r", "utf-8"
) as namespace_file:
non_padded_namespaces = [namespace_str.strip() for namespace_str in namespace_file]
super().__init__(
non_padded_namespaces=non_padded_namespaces,
padding_token=padding_token,
oov_token=oov_token,
)
for namespace_filename in os.listdir(directory):
if namespace_filename == NAMESPACE_PADDING_FILE:
continue
if namespace_filename.startswith("."):
continue
namespace = namespace_filename.replace(".txt", "")
if any(match_namespace(pattern, namespace) for pattern in non_padded_namespaces):
is_padded = False
else:
is_padded = True
filename = os.path.join(directory, namespace_filename)
self.set_from_file(filename, is_padded, namespace=namespace, oov_token=oov_token)
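Finally, a sketch of the save/load round trip FromFilesVocabulary is meant to support (illustrative only; the directory path is hypothetical, and save_to_files, declared earlier in this file, is assumed to write non_padded_namespaces.txt plus one .txt file per namespace, as in the AllenNLP original):

# Illustrative sketch, not part of the commit.
vocab = Vocabulary()
vocab.add_token_to_namespace("kot", "tokens")
vocab.save_to_files("/tmp/combo_vocab")          # hypothetical target directory
reloaded = FromFilesVocabulary(directory="/tmp/combo_vocab")
assert "kot" in reloaded.get_token_to_index_vocabulary("tokens")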
import unittest
import os
from combo.data import ConllDatasetReader
@@ -6,19 +7,19 @@ from combo.data import ConllDatasetReader
class ConllDatasetReaderTest(unittest.TestCase):
def test_read_all_tokens(self):
reader = ConllDatasetReader(coding_scheme='IOB2')
tokens = [token for token in reader('conll_test_file.txt')]
tokens = [token for token in reader(os.path.join(os.path.dirname(__file__), 'conll_test_file.txt'))]
self.assertEqual(len(tokens), 6)
def test_tokenize_correct_tokens(self):
reader = ConllDatasetReader(coding_scheme='IOB2')
token = next(iter(reader('conll_test_file.txt')))
token = next(iter(reader(os.path.join(os.path.dirname(__file__), 'conll_test_file.txt'))))
self.assertListEqual([str(t) for t in token['tokens'].tokens],
['SOCCER', '-', 'JAPAN', 'GET', 'LUCKY', 'WIN', ',',
'CHINA', 'IN', 'SURPRISE', 'DEFEAT', '.'])
def test_tokenize_correct_tags(self):
reader = ConllDatasetReader(coding_scheme='IOB2')
token = next(iter(reader('conll_test_file.txt')))
token = next(iter(reader(os.path.join(os.path.dirname(__file__), 'conll_test_file.txt'))))
self.assertListEqual(token['tags'].labels,
['O', 'O', 'B-LOC', 'O', 'O', 'O', 'O',
'B-PER', 'O', 'O', 'O', 'O'])
import unittest
import os
from combo.data.dataset_readers import TextClassificationJSONReader
from combo.data.fields import LabelField, TextField, ListField
@@ -8,12 +9,12 @@ from combo.data.tokenizers import SpacySentenceSplitter
class TextClassificationJSONReaderTest(unittest.TestCase):
def test_read_two_tokens(self):
reader = TextClassificationJSONReader()
tokens = [token for token in reader('text_classification_json_reader.json')]
tokens = [token for token in reader(os.path.join(os.path.dirname(__file__), 'text_classification_json_reader.json'))]
self.assertEqual(len(tokens), 2)
def test_read_two_examples_fields_without_sentence_splitting(self):
reader = TextClassificationJSONReader()
tokens = [token for token in reader('text_classification_json_reader.json')]
tokens = [token for token in reader(os.path.join(os.path.dirname(__file__), 'text_classification_json_reader.json'))]
self.assertEqual(len(tokens[0].fields.items()), 2)
self.assertIsInstance(tokens[0].fields.get('label'), LabelField)
self.assertEqual(tokens[0].fields.get('label').label, 'label1')
@@ -23,7 +24,7 @@ class TextClassificationJSONReaderTest(unittest.TestCase):
def test_read_two_examples_tokens_without_sentence_splitting(self):
reader = TextClassificationJSONReader()
tokens = [token for token in reader('text_classification_json_reader.json')]
tokens = [token for token in reader(os.path.join(os.path.dirname(__file__), 'text_classification_json_reader.json'))]
self.assertEqual(len(tokens[0].fields.items()), 2)
self.assertIsInstance(tokens[0].fields.get('tokens'), TextField)
self.assertEqual(len(tokens[0].fields.get('tokens').tokens), 13)
@@ -33,7 +34,7 @@ class TextClassificationJSONReaderTest(unittest.TestCase):
def test_read_two_examples_tokens_with_sentence_splitting(self):
reader = TextClassificationJSONReader(sentence_segmenter=SpacySentenceSplitter())
tokens = [token for token in reader('text_classification_json_reader.json')]
tokens = [token for token in reader(os.path.join(os.path.dirname(__file__), 'text_classification_json_reader.json'))]
self.assertEqual(len(tokens[0].fields.items()), 2)
self.assertIsInstance(tokens[0].fields.get('tokens'), ListField)
self.assertEqual(len(tokens[0].fields.get('tokens').field_list), 2)
......
import unittest
import os
from combo.data import UniversalDependenciesDatasetReader
@@ -6,24 +7,24 @@ from combo.data import UniversalDependenciesDatasetReader
class UniversalDependenciesDatasetReaderTest(unittest.TestCase):
def test_read_all_tokens(self):
t = UniversalDependenciesDatasetReader()
tokens = [token for token in t('tl_trg-ud-test.conllu')]
tokens = [token for token in t(os.path.join(os.path.dirname(__file__), 'tl_trg-ud-test.conllu'))]
self.assertEqual(len(tokens), 128)
def test_read_text(self):
t = UniversalDependenciesDatasetReader()
token = next(iter(t('tl_trg-ud-test.conllu')))
token = next(iter(t(os.path.join(os.path.dirname(__file__), 'tl_trg-ud-test.conllu'))))
self.assertListEqual([str(t) for t in token['sentence'].tokens],
['Gumising', 'ang', 'bata', '.'])
def test_read_deprel(self):
t = UniversalDependenciesDatasetReader()
token = next(iter(t('tl_trg-ud-test.conllu')))
token = next(iter(t(os.path.join(os.path.dirname(__file__), 'tl_trg-ud-test.conllu'))))
self.assertListEqual(token['deprel'].labels,
['root', 'case', 'nsubj', 'punct'])
def test_read_upostag(self):
t = UniversalDependenciesDatasetReader()
token = next(iter(t('tl_trg-ud-test.conllu')))
token = next(iter(t(os.path.join(os.path.dirname(__file__), 'tl_trg-ud-test.conllu'))))
self.assertListEqual(token['upostag'].labels,
['VERB', 'ADP', 'NOUN', 'PUNCT'])