Skip to content
Snippets Groups Projects
Commit 7000a02f authored by Maja Jabłońska's avatar Maja Jabłońska Committed by Martyna Wiącek
Browse files

General structure

parent 520801e2
No related branches found
No related tags found
1 merge request!46Merge COMBO 3.0 into master
from abc import ABCMeta
class Field(metaclass=ABCMeta):
class Field:
pass
from abc import ABCMeta
class Sampler(metaclass=ABCMeta):
class Sampler:
pass
from abc import ABCMeta
class TokenIndexer(metaclass=ABCMeta):
class TokenIndexer:
pass
......
from combo.data import TokenIndexer
from .base_indexer import TokenIndexer
class PretrainedTransformerMismatchedIndexer(TokenIndexer):
......
from combo.data import TokenIndexer
from .base_indexer import TokenIndexer
class TokenCharactersIndexer(TokenIndexer):
......
"""Features indexer."""
from combo.data import TokenIndexer
from .base_indexer import TokenIndexer
class TokenFeatsIndexer(TokenIndexer):
......
from collections import defaultdict, OrderedDict
from typing import Dict, Union, Optional, Iterable, Callable, Any, Set
from typing import Dict, Union, Optional, Iterable, Callable, Any, Set, List
from torchtext.vocab import Vocab as TorchtextVocab
from torchtext.vocab import vocab as torchtext_vocab
......@@ -35,13 +35,14 @@ class _NamespaceDependentDefaultDict(defaultdict[str, TorchtextVocab]):
def __missing__(self, namespace: str):
# Non-padded namespace
if any([match_namespace(namespace, npn) for npn in self._non_padded_namespaces]):
value = torchtext_vocab(OrderedDict([]))
else:
value = torchtext_vocab(
OrderedDict([
(self._padding_token, 0),
(self._oov_token, 1)])
(self._padding_token, 1),
(self._oov_token, 1)
])
)
else:
value = torchtext_vocab(OrderedDict([]))
dict.__setitem__(self, namespace, value)
return value
......@@ -78,21 +79,64 @@ class Vocabulary:
for token in tokens:
self._vocab[namespace].append_token(token)
# def add_token_to_namespace(self, token: str, namespace: str = DEFAULT_NAMESPACE):
# """
# Add the token if not present and return the index even if token was already in the namespace.
#
# :param token: token to be added
# :param namespace: namespace to add the token to
# :return: index of the token in the namespace
# """
#
# if not isinstance(token, str):
# raise ValueError("Vocabulary tokens must be strings. Got %s with type %s" % (repr(token), type(token)))
#
def add_token_to_namespace(self, token: str, namespace: str = DEFAULT_NAMESPACE):
"""
Add the token if not present and return the index even if token was already in the namespace.
:param token: token to be added
:param namespace: namespace to add the token to
:return: index of the token in the namespace
"""
if not isinstance(token, str):
raise ValueError("Vocabulary tokens must be strings. Got %s with type %s" % (repr(token), type(token)))
self._vocab[namespace].append_token(token)
def add_tokens_to_namespace(self, tokens: List[str], namespace: str = DEFAULT_NAMESPACE):
"""
Add the token if not present and return the index even if token was already in the namespace.
:param tokens: tokens to be added
:param namespace: namespace to add the token to
:return: index of the token in the namespace
"""
if not isinstance(tokens, List):
raise ValueError("Vocabulary tokens must be passed as a list of strings. Got %s with type %s" % (
repr(tokens), type(tokens)))
for token in tokens:
self._vocab[namespace].append_token(token)
def get_index_to_token_vocabulary(self, namespace: str = DEFAULT_NAMESPACE) -> Dict[int, str]:
if not isinstance(namespace, str):
raise ValueError(
"Namespace must be passed as string. Received %s with type %s" % (repr(namespace), type(namespace)))
itos: List[str] = self._vocab[namespace].get_itos()
return {i: s for i, s in enumerate(itos)}
def get_token_to_index_vocabulary(self, namespace: str = DEFAULT_NAMESPACE) -> Dict[str, int]:
if not isinstance(namespace, str):
raise ValueError(
"Namespace must be passed as string. Received %s with type %s" % (repr(namespace), type(namespace)))
return self._vocab[namespace].get_stoi()
def get_token_index(self, token: str, namespace: str = DEFAULT_NAMESPACE) -> int:
return self.get_token_to_index_vocabulary(namespace).get(token)
def get_token_from_index(self, index: int, namespace: str = DEFAULT_NAMESPACE) -> Optional[str]:
return self.get_index_to_token_vocabulary(namespace).get(index)
def get_vocab_size(self, namespace: str = DEFAULT_NAMESPACE) -> int:
return len(self._vocab[namespace].get_itos())
def get_namespaces(self) -> Set[str]:
return set(self._vocab.keys())
@classmethod
def empty(cls):
return cls()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment