diff --git a/combo/data/fields/base_field.py b/combo/data/fields/base_field.py index 83ea563e6e7cb424cadbe23428045059f2cb04a5..bf8ccb254ef055baac86464f0bf74bfcd7c91a6f 100644 --- a/combo/data/fields/base_field.py +++ b/combo/data/fields/base_field.py @@ -1,5 +1,2 @@ -from abc import ABCMeta - - -class Field(metaclass=ABCMeta): +class Field: pass diff --git a/combo/data/samplers/base_sampler.py b/combo/data/samplers/base_sampler.py index 6e5cd4017069e289f4661211e2d8b0aa38a653f3..e570a36d5a38abc945e93a133a5b91de376b0489 100644 --- a/combo/data/samplers/base_sampler.py +++ b/combo/data/samplers/base_sampler.py @@ -1,5 +1,2 @@ -from abc import ABCMeta - - -class Sampler(metaclass=ABCMeta): +class Sampler: pass diff --git a/combo/data/token_indexers/base_indexer.py b/combo/data/token_indexers/base_indexer.py index 2fb48c0c047365b26cf3c7343da1523fc5077d1f..fa70d6355c0109b55b5df5336bda3b3bfe6c7285 100644 --- a/combo/data/token_indexers/base_indexer.py +++ b/combo/data/token_indexers/base_indexer.py @@ -1,7 +1,4 @@ -from abc import ABCMeta - - -class TokenIndexer(metaclass=ABCMeta): +class TokenIndexer: pass diff --git a/combo/data/token_indexers/pretrained_transformer_mismatched_indexer.py b/combo/data/token_indexers/pretrained_transformer_mismatched_indexer.py index 9aa3616cfd44bf7dc5baabfb6e3c98417e0070fb..ea6a663610912b241fa1470a2f3a6eb7ccf8af40 100644 --- a/combo/data/token_indexers/pretrained_transformer_mismatched_indexer.py +++ b/combo/data/token_indexers/pretrained_transformer_mismatched_indexer.py @@ -1,4 +1,4 @@ -from combo.data import TokenIndexer +from .base_indexer import TokenIndexer class PretrainedTransformerMismatchedIndexer(TokenIndexer): diff --git a/combo/data/token_indexers/token_characters_indexer.py b/combo/data/token_indexers/token_characters_indexer.py index 6f5dbf035ce518783aa799e7330e51c0c92a64ba..f99923eae83458d71c15d8c746bb175446efa52e 100644 --- a/combo/data/token_indexers/token_characters_indexer.py +++ 
b/combo/data/token_indexers/token_characters_indexer.py @@ -1,4 +1,4 @@ -from combo.data import TokenIndexer +from .base_indexer import TokenIndexer class TokenCharactersIndexer(TokenIndexer): diff --git a/combo/data/token_indexers/token_features_indexer.py b/combo/data/token_indexers/token_features_indexer.py index b6267a4377bcbe665a2716c8bf7697454b2f738f..901ffdb1e8d1ce4f872f95833cb6ceefb012ed68 100644 --- a/combo/data/token_indexers/token_features_indexer.py +++ b/combo/data/token_indexers/token_features_indexer.py @@ -1,6 +1,6 @@ """Features indexer.""" -from combo.data import TokenIndexer +from .base_indexer import TokenIndexer class TokenFeatsIndexer(TokenIndexer): diff --git a/combo/data/vocabulary.py b/combo/data/vocabulary.py index d482b6c51be0f7a4f01aa574f944d788a68bfaa3..feb184e2eb2d84ee19f29d9a025d0406cbd97e9a 100644 --- a/combo/data/vocabulary.py +++ b/combo/data/vocabulary.py @@ -1,5 +1,5 @@ from collections import defaultdict, OrderedDict -from typing import Dict, Union, Optional, Iterable, Callable, Any, Set +from typing import Dict, Union, Optional, Iterable, Callable, Any, Set, List from torchtext.vocab import Vocab as TorchtextVocab from torchtext.vocab import vocab as torchtext_vocab @@ -17,7 +17,7 @@ def match_namespace(pattern: str, namespace: str): (type(pattern), type(namespace))) if pattern == namespace: return True - if len(pattern)>2 and pattern[0] == '*' and namespace.endswith(pattern[1:]): + if len(pattern) > 2 and pattern[0] == '*' and namespace.endswith(pattern[1:]): return True return False @@ -35,13 +35,14 @@ class _NamespaceDependentDefaultDict(defaultdict[str, TorchtextVocab]): def __missing__(self, namespace: str): # Non-padded namespace if any([match_namespace(namespace, npn) for npn in self._non_padded_namespaces]): + value = torchtext_vocab(OrderedDict([])) + else: value = torchtext_vocab( OrderedDict([ - (self._padding_token, 0), - (self._oov_token, 1)]) + (self._padding_token, 1), + (self._oov_token, 1) + ]) ) - else: - 
value = torchtext_vocab(OrderedDict([])) dict.__setitem__(self, namespace, value) return value @@ -78,21 +79,64 @@ class Vocabulary: for token in tokens: self._vocab[namespace].append_token(token) - # def add_token_to_namespace(self, token: str, namespace: str = DEFAULT_NAMESPACE): - # """ - # Add the token if not present and return the index even if token was already in the namespace. - # - # :param token: token to be added - # :param namespace: namespace to add the token to - # :return: index of the token in the namespace - # """ - # - # if not isinstance(token, str): - # raise ValueError("Vocabulary tokens must be strings. Got %s with type %s" % (repr(token), type(token))) - # + def add_token_to_namespace(self, token: str, namespace: str = DEFAULT_NAMESPACE): + """ + Add the token to the namespace if it is not already present. + + :param token: token to be added + :param namespace: namespace to add the token to + :return: None + """ + + if not isinstance(token, str): + raise ValueError("Vocabulary tokens must be strings. Got %s with type %s" % (repr(token), type(token))) + + self._vocab[namespace].append_token(token) + + def add_tokens_to_namespace(self, tokens: List[str], namespace: str = DEFAULT_NAMESPACE): + """ + Add each token to the namespace if it is not already present. + + :param tokens: tokens to be added + :param namespace: namespace to add the tokens to + :return: None + """ + + if not isinstance(tokens, List): + raise ValueError("Vocabulary tokens must be passed as a list of strings. Got %s with type %s" % ( + repr(tokens), type(tokens))) + for token in tokens: + self._vocab[namespace].append_token(token) + + def get_index_to_token_vocabulary(self, namespace: str = DEFAULT_NAMESPACE) -> Dict[int, str]: + if not isinstance(namespace, str): + raise ValueError( + "Namespace must be passed as string. 
Received %s with type %s" % (repr(namespace), type(namespace))) + + itos: List[str] = self._vocab[namespace].get_itos() + + return {i: s for i, s in enumerate(itos)} + + def get_token_to_index_vocabulary(self, namespace: str = DEFAULT_NAMESPACE) -> Dict[str, int]: + if not isinstance(namespace, str): + raise ValueError( + "Namespace must be passed as string. Received %s with type %s" % (repr(namespace), type(namespace))) + + return self._vocab[namespace].get_stoi() + + def get_token_index(self, token: str, namespace: str = DEFAULT_NAMESPACE) -> int: + return self.get_token_to_index_vocabulary(namespace).get(token) + + def get_token_from_index(self, index: int, namespace: str = DEFAULT_NAMESPACE) -> Optional[str]: + return self.get_index_to_token_vocabulary(namespace).get(index) + + def get_vocab_size(self, namespace: str = DEFAULT_NAMESPACE) -> int: + return len(self._vocab[namespace].get_itos()) + + def get_namespaces(self) -> Set[str]: + return set(self._vocab.keys()) @classmethod def empty(cls): return cls() -