Skip to content
Snippets Groups Projects
Commit e83a7797 authored by Maja Jabłońska's avatar Maja Jabłońska
Browse files

General structure

parent 1bf88a88
Branches
Tags
1 merge request!46Merge COMBO 3.0 into master
from abc import ABCMeta class Field:
class Field(metaclass=ABCMeta):
pass pass
from abc import ABCMeta class Sampler:
class Sampler(metaclass=ABCMeta):
pass pass
from abc import ABCMeta class TokenIndexer:
class TokenIndexer(metaclass=ABCMeta):
pass pass
......
from combo.data import TokenIndexer from .base_indexer import TokenIndexer
class PretrainedTransformerMismatchedIndexer(TokenIndexer): class PretrainedTransformerMismatchedIndexer(TokenIndexer):
......
from combo.data import TokenIndexer from .base_indexer import TokenIndexer
class TokenCharactersIndexer(TokenIndexer): class TokenCharactersIndexer(TokenIndexer):
......
"""Features indexer.""" """Features indexer."""
from combo.data import TokenIndexer from .base_indexer import TokenIndexer
class TokenFeatsIndexer(TokenIndexer): class TokenFeatsIndexer(TokenIndexer):
......
from collections import defaultdict, OrderedDict from collections import defaultdict, OrderedDict
from typing import Dict, Union, Optional, Iterable, Callable, Any, Set from typing import Dict, Union, Optional, Iterable, Callable, Any, Set, List
from torchtext.vocab import Vocab as TorchtextVocab from torchtext.vocab import Vocab as TorchtextVocab
from torchtext.vocab import vocab as torchtext_vocab from torchtext.vocab import vocab as torchtext_vocab
...@@ -17,7 +17,7 @@ def match_namespace(pattern: str, namespace: str): ...@@ -17,7 +17,7 @@ def match_namespace(pattern: str, namespace: str):
(type(pattern), type(namespace))) (type(pattern), type(namespace)))
if pattern == namespace: if pattern == namespace:
return True return True
if len(pattern)>2 and pattern[0] == '*' and namespace.endswith(pattern[1:]): if len(pattern) > 2 and pattern[0] == '*' and namespace.endswith(pattern[1:]):
return True return True
return False return False
...@@ -35,13 +35,14 @@ class _NamespaceDependentDefaultDict(defaultdict[str, TorchtextVocab]): ...@@ -35,13 +35,14 @@ class _NamespaceDependentDefaultDict(defaultdict[str, TorchtextVocab]):
def __missing__(self, namespace: str):
    """
    Create, store and return the vocabulary of a namespace looked up for the first time.

    A namespace that matches one of ``self._non_padded_namespaces`` starts
    empty; any other namespace starts with the padding and OOV tokens.

    :param namespace: name of the namespace being looked up
    :return: the newly created torchtext vocabulary for ``namespace``
    """
    # The entries of ``_non_padded_namespaces`` are the patterns (they may
    # start with '*'), so they must be passed first to
    # ``match_namespace(pattern, namespace)``. Passing them second made
    # wildcard entries impossible to match.
    is_non_padded = any(
        match_namespace(pattern, namespace)
        for pattern in self._non_padded_namespaces
    )
    if is_non_padded:
        value = torchtext_vocab(OrderedDict([]))
    else:
        # NOTE(review): torchtext's ``vocab`` factory is documented to read
        # the mapping values as token frequencies, not indices.
        value = torchtext_vocab(
            OrderedDict([
                (self._padding_token, 1),
                (self._oov_token, 1)
            ])
        )
    # Bypass any overridden ``__setitem__`` and store the new vocabulary directly.
    dict.__setitem__(self, namespace, value)
    return value
...@@ -78,21 +79,64 @@ class Vocabulary: ...@@ -78,21 +79,64 @@ class Vocabulary:
for token in tokens: for token in tokens:
self._vocab[namespace].append_token(token) self._vocab[namespace].append_token(token)
def add_token_to_namespace(self, token: str, namespace: str = DEFAULT_NAMESPACE) -> int:
    """
    Add the token if not present and return the index even if token was already in the namespace.

    :param token: token to be added
    :param namespace: namespace to add the token to
    :return: index of the token in the namespace
    :raises ValueError: if ``token`` is not a string
    """
    if not isinstance(token, str):
        raise ValueError("Vocabulary tokens must be strings. Got %s with type %s" % (repr(token), type(token)))

    vocab = self._vocab[namespace]
    # torchtext's ``append_token`` is documented to raise RuntimeError for a
    # token that already exists, so only new tokens are appended.
    if token not in vocab:
        vocab.append_token(token)
    return vocab[token]
def add_tokens_to_namespace(self, tokens: List[str], namespace: str = DEFAULT_NAMESPACE):
    """
    Add every token that is not yet present in the namespace.

    :param tokens: tokens to be added
    :param namespace: namespace to add the tokens to
    :raises ValueError: if ``tokens`` is not a list or contains a non-string element
    """
    if not isinstance(tokens, list):
        raise ValueError("Vocabulary tokens must be passed as a list of strings. Got %s with type %s" % (
            repr(tokens), type(tokens)))
    # Validate every element before touching the vocabulary so that invalid
    # input cannot leave the namespace partially updated.
    for token in tokens:
        if not isinstance(token, str):
            raise ValueError("Vocabulary tokens must be strings. Got %s with type %s" % (repr(token), type(token)))

    for token in tokens:
        # The namespace is looked up inside the loop so that an empty
        # ``tokens`` list still does not access the namespace at all.
        vocab = self._vocab[namespace]
        # torchtext's ``append_token`` is documented to raise RuntimeError
        # for a token that already exists, so duplicates are skipped.
        if token not in vocab:
            vocab.append_token(token)
def get_index_to_token_vocabulary(self, namespace: str = DEFAULT_NAMESPACE) -> Dict[int, str]:
    """
    Build the index -> token mapping of a namespace.

    :param namespace: namespace whose vocabulary is read
    :return: dictionary keyed by each token's position in the ``get_itos()`` list
    :raises ValueError: if ``namespace`` is not a string
    """
    if isinstance(namespace, str):
        tokens_by_position: List[str] = self._vocab[namespace].get_itos()
        return dict(enumerate(tokens_by_position))
    raise ValueError(
        "Namespace must be passed as string. Received %s with type %s" % (repr(namespace), type(namespace)))
def get_token_to_index_vocabulary(self, namespace: str = DEFAULT_NAMESPACE) -> Dict[str, int]:
    """
    Return the token -> index mapping of a namespace.

    :param namespace: namespace whose vocabulary is read
    :return: the result of ``get_stoi()`` on the namespace's vocabulary
    :raises ValueError: if ``namespace`` is not a string
    """
    if isinstance(namespace, str):
        namespace_vocab = self._vocab[namespace]
        return namespace_vocab.get_stoi()
    raise ValueError(
        "Namespace must be passed as string. Received %s with type %s" % (repr(namespace), type(namespace)))
def get_token_index(self, token: str, namespace: str = DEFAULT_NAMESPACE) -> Optional[int]:
    """
    Look up the index of a token in a namespace.

    :param token: token to look up
    :param namespace: namespace to search in
    :return: index of the token, or ``None`` when the token is not in the
        token -> index mapping (``dict.get`` supplies no other default)
    :raises ValueError: if ``namespace`` is not a string (raised by
        ``get_token_to_index_vocabulary``)
    """
    return self.get_token_to_index_vocabulary(namespace).get(token)
def get_token_from_index(self, index: int, namespace: str = DEFAULT_NAMESPACE) -> Optional[str]:
    """
    Look up the token stored under an index in a namespace.

    :param index: index to look up
    :param namespace: namespace to search in
    :return: the token, or ``None`` when the index is not in the mapping
    """
    index_to_token = self.get_index_to_token_vocabulary(namespace)
    return index_to_token.get(index)
def get_vocab_size(self, namespace: str = DEFAULT_NAMESPACE) -> int:
    """
    Count the tokens of a namespace.

    :param namespace: namespace whose vocabulary is measured
    :return: length of the namespace's ``get_itos()`` list
    """
    namespace_tokens = self._vocab[namespace].get_itos()
    return len(namespace_tokens)
def get_namespaces(self) -> Set[str]:
    """
    Collect the names of all namespaces currently stored.

    :return: a new set holding the keys of ``self._vocab``
    """
    namespace_names = self._vocab.keys()
    return set(namespace_names)
@classmethod
def empty(cls):
    """
    Alternate constructor: build an instance by calling the class with no arguments.

    :return: the object produced by ``cls()``
    """
    instance = cls()
    return instance
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment