Skip to content
Snippets Groups Projects
Commit 18cb59e7 authored by Maja Jabłońska's avatar Maja Jabłońska Committed by Martyna Wiącek
Browse files

Add from_file and to_file for Vocabulary

parent 6dc7c92c
Branches
Tags
1 merge request!46Merge COMBO 3.0 into master
import codecs
import os
import glob
from collections import defaultdict, OrderedDict
from typing import Dict, Union, Optional, Iterable, Callable, Any, Set, List
from torchtext.vocab import Vocab as TorchtextVocab
from torchtext.vocab import vocab as torchtext_vocab
import logging
from filelock import FileLock
# Use getLogger() so the logger participates in the logging hierarchy and
# picks up handlers/levels configured by the application. Instantiating
# logging.Logger directly creates an orphan logger with no parent that
# never propagates to configured handlers.
logger = logging.getLogger(__name__)
DEFAULT_NON_PADDED_NAMESPACES = ("*tags", "*labels")
DEFAULT_PADDING_TOKEN = "@@PADDING@@"
......@@ -79,6 +87,81 @@ class Vocabulary:
for token in tokens:
self._vocab[namespace].append_token(token)
@classmethod
def from_files(cls,
               directory: str,
               padding_token: Optional[str] = DEFAULT_PADDING_TOKEN,
               oov_token: Optional[str] = DEFAULT_OOV_TOKEN) -> None:
    """
    Load a previously serialized Vocabulary from ``directory``.

    Adapted from https://github.com/allenai/allennlp/blob/main/allennlp/data/vocabulary.py

    :param directory: directory holding one ``<namespace>.txt`` file per
        namespace, plus an optional non-padded-namespaces file.
    :param padding_token: padding token (defaults to DEFAULT_PADDING_TOKEN).
    :param oov_token: out-of-vocabulary token (defaults to DEFAULT_OOV_TOKEN).
    :return: None
    """
    files = glob.glob(os.path.join(directory, '*.txt'))
    if not files:
        # lazy %-style args instead of the original f-string/% mix
        logger.warning('Directory %s is empty', directory)
    non_padded_namespaces = []
    try:
        # BUG FIX: originally opened with mode "w", which TRUNCATED the
        # namespaces file instead of reading it. Open for reading.
        with codecs.open(
                os.path.join(directory, NAMESPACE_PADDING_FILE), "r", "utf-8"
        ) as namespace_file:
            non_padded_namespaces = [namespace.strip() for namespace in namespace_file]
    except FileNotFoundError:
        logger.warning("No file %s - all namespaces will be treated as padded namespaces." % NAMESPACE_PADDING_FILE)
    for file in files:
        # os.path.basename instead of split('/') so this also works on Windows
        if os.path.basename(file) == NAMESPACE_PADDING_FILE:
            # Namespaces file - already read
            continue
        namespace_name = os.path.basename(file).replace('.txt', '')
        # BUG FIX: originally opened with mode "w" (truncating every token
        # file it was supposed to load). Open for reading.
        with codecs.open(
                file, "r", "utf-8"
        ) as namespace_tokens_file:
            tokens = [token.strip() for token in namespace_tokens_file]
        # NOTE(review): the visible code never uses `tokens`, `namespace_name`,
        # `padding_token` or `oov_token` and returns None — the construction of
        # the Vocabulary presumably lives in code outside this view; behavior
        # beyond the read-mode fix is intentionally left unchanged. TODO confirm.
def save_to_files(self, directory: str) -> None:
    """
    Persist this Vocabulary to files, so it can be reloaded later.
    Each namespace corresponds to one file.

    Adapted from https://github.com/allenai/allennlp/blob/main/allennlp/data/vocabulary.py

    # Parameters

    directory : `str`
        The directory where we save the serialized vocabulary.
    """
    os.makedirs(directory, exist_ok=True)
    if os.listdir(directory):
        logger.warning("Directory %s is not empty", directory)
    # We use a lock file to avoid race conditions where multiple processes
    # might be reading/writing from/to the same vocab files at once.
    with FileLock(os.path.join(directory, ".lock")):
        with codecs.open(
                os.path.join(directory, NAMESPACE_PADDING_FILE), "w", "utf-8"
        ) as namespace_file:
            for namespace_str in self._non_padded_namespaces:
                print(namespace_str, file=namespace_file)
        for namespace, vocab in self._vocab.items():
            # Each namespace gets written to its own file, in index order.
            with codecs.open(
                    os.path.join(directory, namespace + ".txt"), "w", "utf-8"
            ) as token_file:
                mapping = vocab.get_itos()
                num_tokens = len(mapping)
                # BUG FIX: guard the empty namespace — indexing mapping[0]
                # on an empty vocab raised IndexError.
                if num_tokens == 0:
                    continue
                # Skip the padding token if it occupies index 0.
                start_index = 1 if mapping[0] == self._padding_token else 0
                for i in range(start_index, num_tokens):
                    # Newlines inside a token would corrupt the line-based
                    # file format, so escape them.
                    print(mapping[i].replace("\n", "@@NEWLINE@@"), file=token_file)
def is_padded(self, namespace: str) -> bool:
    """Whether ``namespace`` reserves index 0 for the padding token."""
    first_token = self._vocab[namespace].get_itos()[0]
    return first_token == self._padding_token
def add_token_to_namespace(self, token: str, namespace: str = DEFAULT_NAMESPACE):
"""
Add the token if not present and return the index even if token was already in the namespace.
......@@ -103,8 +186,8 @@ class Vocabulary:
"""
if not isinstance(tokens, List):
raise ValueError("Vocabulary tokens must be passed as a list of strings. Got %s with type %s" % (
repr(tokens), type(tokens)))
raise ValueError("Vocabulary tokens must be passed as a list of strings. Got tokens with type %s" % (
type(tokens)))
for token in tokens:
self._vocab[namespace].append_token(token)
......@@ -126,17 +209,20 @@ class Vocabulary:
return self._vocab[namespace].get_stoi()
def get_token_index(self, token: str, namespace: str = DEFAULT_NAMESPACE) -> int:
    """
    Return the integer index of ``token`` within ``namespace``, falling back
    to the index of the OOV token when ``token`` is unknown.

    :raises KeyError: if neither the token nor the OOV token is present.
    """
    stoi = self._vocab[namespace].get_stoi()
    try:
        return stoi[token]
    except KeyError:
        # BUG FIX: the original fallback was stoi[token][self._oov_token],
        # which re-performs the exact lookup that just raised KeyError and
        # then indexes its (never-obtained) result — the OOV fallback could
        # never succeed. Look the OOV token itself up in the mapping.
        try:
            return stoi[self._oov_token]
        except KeyError:
            raise KeyError("Namespace %s doesn't contain token %s or default OOV token %s" %
                           (namespace, repr(token), repr(self._oov_token)))
def get_token_from_index(self, index: int, namespace: str = DEFAULT_NAMESPACE) -> str:
    """Return the token stored at position ``index`` of ``namespace``."""
    index_to_token = self._vocab[namespace].get_itos()
    return index_to_token[index]
def get_vocab_size(self, namespace: str = DEFAULT_NAMESPACE) -> int:
    """Return how many tokens ``namespace`` currently holds."""
    index_to_token = self._vocab[namespace].get_itos()
    return len(index_to_token)
def get_namespaces(self) -> Set[str]:
    """Return the names of every namespace in this vocabulary."""
    return {name for name in self._vocab}
@classmethod
def empty(cls):
    """Return a new Vocabulary created with the constructor's defaults."""
    return cls()
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment