diff --git a/combo/data/vocabulary.py b/combo/data/vocabulary.py
index feb184e2eb2d84ee19f29d9a025d0406cbd97e9a..7ddb7e37fbd1570f895670119f317a27ffc2a2f0 100644
--- a/combo/data/vocabulary.py
+++ b/combo/data/vocabulary.py
@@ -1,8 +1,16 @@
+import codecs
+import os
+import glob
 from collections import defaultdict, OrderedDict
 from typing import Dict, Union, Optional, Iterable, Callable, Any, Set, List
 
 from torchtext.vocab import Vocab as TorchtextVocab
 from torchtext.vocab import vocab as torchtext_vocab
+import logging
+
+from filelock import FileLock
+
+logger = logging.getLogger(__name__)
 
 DEFAULT_NON_PADDED_NAMESPACES = ("*tags", "*labels")
 DEFAULT_PADDING_TOKEN = "@@PADDING@@"
@@ -79,6 +87,81 @@ class Vocabulary:
         for token in tokens:
             self._vocab[namespace].append_token(token)
 
+    @classmethod
+    def from_files(cls,
+                   directory: str,
+                   padding_token: Optional[str] = DEFAULT_PADDING_TOKEN,
+                   oov_token: Optional[str] = DEFAULT_OOV_TOKEN) -> None:
+        """
+        Adapted from https://github.com/allenai/allennlp/blob/main/allennlp/data/vocabulary.py
+
+        :param directory:
+        :param padding_token:
+        :param oov_token:
+        :return:
+        """
+        files = [file for file in glob.glob(os.path.join(directory, '*.txt'))]
+
+        if len(files) == 0:
+            logger.warning('Directory %s is empty', directory)
+
+        non_padded_namespaces = []
+
+        try:
+            with codecs.open(
+                    os.path.join(directory, NAMESPACE_PADDING_FILE), "r", "utf-8"
+            ) as namespace_file:
+                non_padded_namespaces = [namespace.strip() for namespace in namespace_file]
+        except FileNotFoundError:
+            logger.warning("No file %s - all namespaces will be treated as padded namespaces."
+                           % NAMESPACE_PADDING_FILE)
+
+        for file in files:
+            if os.path.basename(file) == NAMESPACE_PADDING_FILE:
+                # Namespaces file - already read
+                continue
+            namespace_name = os.path.basename(file).replace('.txt', '')
+            with codecs.open(
+                    file, "r", "utf-8"
+            ) as namespace_tokens_file:
+                tokens = [token.strip() for token in namespace_tokens_file]
+
+    def save_to_files(self, directory: str) -> None:
+        """
+        Persist this Vocabulary to files, so it can be reloaded later.
+        Each namespace corresponds to one file.
+
+        Adapted from https://github.com/allenai/allennlp/blob/main/allennlp/data/vocabulary.py
+        # Parameters
+        directory : `str`
+            The directory where we save the serialized vocabulary.
+        """
+        os.makedirs(directory, exist_ok=True)
+        if os.listdir(directory):
+            logger.warning("Directory %s is not empty", directory)
+
+        # We use a lock file to avoid race conditions where multiple processes
+        # might be reading/writing from/to the same vocab files at once.
+        with FileLock(os.path.join(directory, ".lock")):
+            with codecs.open(
+                    os.path.join(directory, NAMESPACE_PADDING_FILE), "w", "utf-8"
+            ) as namespace_file:
+                for namespace_str in self._non_padded_namespaces:
+                    print(namespace_str, file=namespace_file)
+
+            for namespace, vocab in self._vocab.items():
+                # Each namespace gets written to its own file, in index order.
+                with codecs.open(
+                        os.path.join(directory, namespace + ".txt"), "w", "utf-8"
+                ) as token_file:
+                    mapping = vocab.get_itos()
+                    num_tokens = len(mapping)
+                    start_index = 1 if mapping[0] == self._padding_token else 0
+                    for i in range(start_index, num_tokens):
+                        print(mapping[i].replace("\n", "@@NEWLINE@@"), file=token_file)
+
+    def is_padded(self, namespace: str) -> bool:
+        return self._vocab[namespace].get_itos()[0] == self._padding_token
+
     def add_token_to_namespace(self, token: str, namespace: str = DEFAULT_NAMESPACE):
         """
         Add the token if not present and return the index even if token was already in the namespace.
@@ -103,8 +186,8 @@ class Vocabulary:
 
         """
         if not isinstance(tokens, List):
-            raise ValueError("Vocabulary tokens must be passed as a list of strings. Got %s with type %s" % (
-                repr(tokens), type(tokens)))
+            raise ValueError("Vocabulary tokens must be passed as a list of strings. Got tokens with type %s" % (
+                type(tokens)))
         for token in tokens:
             self._vocab[namespace].append_token(token)
 
@@ -126,17 +209,20 @@ class Vocabulary:
         return self._vocab[namespace].get_stoi()
 
     def get_token_index(self, token: str, namespace: str = DEFAULT_NAMESPACE) -> int:
-        return self.get_token_to_index_vocabulary(namespace).get(token)
-
-    def get_token_from_index(self, index: int, namespace: str = DEFAULT_NAMESPACE) -> Optional[str]:
-        return self.get_index_to_token_vocabulary(namespace).get(index)
+        try:
+            return self._vocab[namespace].get_stoi()[token]
+        except KeyError:
+            try:
+                return self._vocab[namespace].get_stoi()[self._oov_token]
+            except KeyError:
+                raise KeyError("Namespace %s doesn't contain token %s or default OOV token %s" %
+                               (namespace, repr(token), repr(self._oov_token)))
+
+    def get_token_from_index(self, index: int, namespace: str = DEFAULT_NAMESPACE) -> str:
+        return self._vocab[namespace].get_itos()[index]
 
     def get_vocab_size(self, namespace: str = DEFAULT_NAMESPACE) -> int:
         return len(self._vocab[namespace].get_itos())
 
     def get_namespaces(self) -> Set[str]:
         return set(self._vocab.keys())
-
-    @classmethod
-    def empty(cls):
-        return cls()