Skip to content
Snippets Groups Projects
Commit e83a7797 authored by Maja Jabłońska
Browse files

General structure

parent 1bf88a88
Branches
Tags
1 merge request: !46 Merge COMBO 3.0 into master
from abc import ABCMeta
class Field(metaclass=ABCMeta):
class Field:
pass
from abc import ABCMeta
class Sampler(metaclass=ABCMeta):
class Sampler:
pass
from abc import ABCMeta
class TokenIndexer(metaclass=ABCMeta):
class TokenIndexer:
pass
......
from combo.data import TokenIndexer
from .base_indexer import TokenIndexer
class PretrainedTransformerMismatchedIndexer(TokenIndexer):
......
from combo.data import TokenIndexer
from .base_indexer import TokenIndexer
class TokenCharactersIndexer(TokenIndexer):
......
"""Features indexer."""
from combo.data import TokenIndexer
from .base_indexer import TokenIndexer
class TokenFeatsIndexer(TokenIndexer):
......
from collections import defaultdict, OrderedDict
from typing import Dict, Union, Optional, Iterable, Callable, Any, Set
from typing import Dict, Union, Optional, Iterable, Callable, Any, Set, List
from torchtext.vocab import Vocab as TorchtextVocab
from torchtext.vocab import vocab as torchtext_vocab
......@@ -17,7 +17,7 @@ def match_namespace(pattern: str, namespace: str):
(type(pattern), type(namespace)))
if pattern == namespace:
return True
if len(pattern)>2 and pattern[0] == '*' and namespace.endswith(pattern[1:]):
if len(pattern) > 2 and pattern[0] == '*' and namespace.endswith(pattern[1:]):
return True
return False
......@@ -35,13 +35,14 @@ class _NamespaceDependentDefaultDict(defaultdict[str, TorchtextVocab]):
def __missing__(self, namespace: str):
    """
    Create, store and return the vocabulary for a namespace on first access.

    A namespace that matches one of the non-padded namespace patterns starts
    with an empty vocabulary; any other namespace starts with the padding
    token followed by the OOV token.

    :param namespace: name of the namespace that was looked up but is missing
    :return: the newly created vocabulary, now stored under ``namespace``
    """
    # match_namespace takes (pattern, namespace): the entries of
    # _non_padded_namespaces are the patterns (they may carry the leading '*'
    # wildcard), so they must be passed first. Passing them second meant a
    # wildcard pattern could never match a concrete namespace.
    if any(match_namespace(pattern, namespace) for pattern in self._non_padded_namespaces):
        # Non-padded namespace
        value = torchtext_vocab(OrderedDict([]))
    else:
        # Both special tokens carry a frequency of 1: torchtext's vocab()
        # factory documents a default min_freq of 1 and leaves out tokens
        # whose frequency is below it.
        value = torchtext_vocab(
            OrderedDict([
                (self._padding_token, 1),
                (self._oov_token, 1),
            ])
        )
    # Store through dict.__setitem__ directly, as the original does.
    dict.__setitem__(self, namespace, value)
    return value
......@@ -78,21 +79,64 @@ class Vocabulary:
for token in tokens:
self._vocab[namespace].append_token(token)
# def add_token_to_namespace(self, token: str, namespace: str = DEFAULT_NAMESPACE):
# """
# Add the token if not present and return the index even if token was already in the namespace.
#
# :param token: token to be added
# :param namespace: namespace to add the token to
# :return: index of the token in the namespace
# """
#
# if not isinstance(token, str):
# raise ValueError("Vocabulary tokens must be strings. Got %s with type %s" % (repr(token), type(token)))
#
def add_token_to_namespace(self, token: str, namespace: str = DEFAULT_NAMESPACE) -> int:
    """
    Add the token if not present and return its index.

    The index is returned whether the token was newly added or was already
    in the namespace.

    :param token: token to be added
    :param namespace: namespace to add the token to
    :return: index of the token in the namespace
    :raises ValueError: if ``token`` is not a string
    """
    if not isinstance(token, str):
        raise ValueError("Vocabulary tokens must be strings. Got %s with type %s" % (repr(token), type(token)))

    vocab = self._vocab[namespace]
    # torchtext's Vocab.append_token raises RuntimeError when the token is
    # already in the vocab, so only append tokens that are new here.
    if token not in vocab:
        vocab.append_token(token)
    return vocab[token]
def add_tokens_to_namespace(self, tokens: List[str], namespace: str = DEFAULT_NAMESPACE):
    """
    Add every token that is not yet present to the namespace.

    Tokens that are already in the namespace, including repeats inside
    ``tokens`` itself, are skipped.

    :param tokens: tokens to be added
    :param namespace: namespace to add the tokens to
    :raises ValueError: if ``tokens`` is not a list
    """
    if not isinstance(tokens, List):
        raise ValueError("Vocabulary tokens must be passed as a list of strings. Got %s with type %s" % (
            repr(tokens), type(tokens)))

    for token in tokens:
        # Looked up inside the loop so that an empty ``tokens`` list still
        # leaves the namespace untouched, exactly as before.
        vocab = self._vocab[namespace]
        # torchtext's Vocab.append_token raises RuntimeError when the token
        # is already in the vocab, so skip tokens that are present.
        if token not in vocab:
            vocab.append_token(token)
def get_index_to_token_vocabulary(self, namespace: str = DEFAULT_NAMESPACE) -> Dict[int, str]:
    """
    Build the index -> token mapping of a namespace.

    Each token is keyed by its position in the list returned by the
    namespace vocabulary's ``get_itos()``.

    :param namespace: namespace whose mapping is requested
    :return: dictionary from index to token
    :raises ValueError: if ``namespace`` is not a string
    """
    if isinstance(namespace, str):
        tokens_by_position: List[str] = self._vocab[namespace].get_itos()
        return dict(enumerate(tokens_by_position))
    raise ValueError(
        "Namespace must be passed as string. Received %s with type %s" % (repr(namespace), type(namespace)))
def get_token_to_index_vocabulary(self, namespace: str = DEFAULT_NAMESPACE) -> Dict[str, int]:
    """
    Return the token -> index mapping of a namespace.

    The mapping is whatever the namespace vocabulary's ``get_stoi()`` returns.

    :param namespace: namespace whose mapping is requested
    :return: dictionary from token to index
    :raises ValueError: if ``namespace`` is not a string
    """
    if isinstance(namespace, str):
        namespace_vocab = self._vocab[namespace]
        return namespace_vocab.get_stoi()
    raise ValueError(
        "Namespace must be passed as string. Received %s with type %s" % (repr(namespace), type(namespace)))
def get_token_index(self, token: str, namespace: str = DEFAULT_NAMESPACE) -> Optional[int]:
    """
    Look up the index of a token in a namespace.

    :param token: token to look up
    :param namespace: namespace to search in
    :return: index of the token, or ``None`` when the token is not in the
        namespace (the lookup uses ``dict.get`` without a default)
    """
    token_to_index = self.get_token_to_index_vocabulary(namespace)
    return token_to_index.get(token)
def get_token_from_index(self, index: int, namespace: str = DEFAULT_NAMESPACE) -> Optional[str]:
    """
    Look up the token stored at an index of a namespace.

    :param index: index to look up
    :param namespace: namespace to search in
    :return: token at ``index``, or ``None`` when the index is not present
    """
    index_to_token = self.get_index_to_token_vocabulary(namespace)
    return index_to_token.get(index)
def get_vocab_size(self, namespace: str = DEFAULT_NAMESPACE) -> int:
    """
    Return the number of tokens stored in a namespace.

    :param namespace: namespace to measure
    :return: number of tokens in the namespace vocabulary
    """
    # len() on the torchtext Vocab gives the same count as
    # len(vocab.get_itos()) without building the full token list first.
    return len(self._vocab[namespace])
def get_namespaces(self) -> Set[str]:
    """
    Return the names of all namespaces currently held by the vocabulary.

    :return: a new set built from the keys of the namespace mapping
    """
    return {*self._vocab.keys()}
@classmethod
def empty(cls):
    """
    Alternate constructor: build an instance by calling the class with no arguments.

    :return: a new instance of ``cls``
    """
    instance = cls()
    return instance
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment