From e83a7797b26fa54f3c11d477e8f6b4e94607e70b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Maja=20Jab=C5=82o=C5=84ska?= <majajjablonska@gmail.com>
Date: Tue, 7 Mar 2023 20:58:45 +0100
Subject: [PATCH] General structure

---
 combo/data/fields/base_field.py               |  5 +-
 combo/data/samplers/base_sampler.py           |  5 +-
 combo/data/token_indexers/base_indexer.py     |  5 +-
 ...etrained_transformer_mismatched_indexer.py |  2 +-
 .../token_characters_indexer.py               |  2 +-
 .../token_indexers/token_features_indexer.py  |  2 +-
 combo/data/vocabulary.py                      | 82 ++++++++++++++-----
 7 files changed, 69 insertions(+), 34 deletions(-)

diff --git a/combo/data/fields/base_field.py b/combo/data/fields/base_field.py
index 83ea563..bf8ccb2 100644
--- a/combo/data/fields/base_field.py
+++ b/combo/data/fields/base_field.py
@@ -1,5 +1,2 @@
-from abc import ABCMeta
-
-
-class Field(metaclass=ABCMeta):
+class Field:
     pass
diff --git a/combo/data/samplers/base_sampler.py b/combo/data/samplers/base_sampler.py
index 6e5cd40..e570a36 100644
--- a/combo/data/samplers/base_sampler.py
+++ b/combo/data/samplers/base_sampler.py
@@ -1,5 +1,2 @@
-from abc import ABCMeta
-
-
-class Sampler(metaclass=ABCMeta):
+class Sampler:
     pass
diff --git a/combo/data/token_indexers/base_indexer.py b/combo/data/token_indexers/base_indexer.py
index 2fb48c0..fa70d63 100644
--- a/combo/data/token_indexers/base_indexer.py
+++ b/combo/data/token_indexers/base_indexer.py
@@ -1,7 +1,4 @@
-from abc import ABCMeta
-
-
-class TokenIndexer(metaclass=ABCMeta):
+class TokenIndexer:
     pass
 
 
diff --git a/combo/data/token_indexers/pretrained_transformer_mismatched_indexer.py b/combo/data/token_indexers/pretrained_transformer_mismatched_indexer.py
index 9aa3616..ea6a663 100644
--- a/combo/data/token_indexers/pretrained_transformer_mismatched_indexer.py
+++ b/combo/data/token_indexers/pretrained_transformer_mismatched_indexer.py
@@ -1,4 +1,4 @@
-from combo.data import TokenIndexer
+from .base_indexer import TokenIndexer
 
 
 class PretrainedTransformerMismatchedIndexer(TokenIndexer):
diff --git a/combo/data/token_indexers/token_characters_indexer.py b/combo/data/token_indexers/token_characters_indexer.py
index 6f5dbf0..f99923e 100644
--- a/combo/data/token_indexers/token_characters_indexer.py
+++ b/combo/data/token_indexers/token_characters_indexer.py
@@ -1,4 +1,4 @@
-from combo.data import TokenIndexer
+from .base_indexer import TokenIndexer
 
 
 class TokenCharactersIndexer(TokenIndexer):
diff --git a/combo/data/token_indexers/token_features_indexer.py b/combo/data/token_indexers/token_features_indexer.py
index b6267a4..901ffdb 100644
--- a/combo/data/token_indexers/token_features_indexer.py
+++ b/combo/data/token_indexers/token_features_indexer.py
@@ -1,6 +1,6 @@
 """Features indexer."""
 
-from combo.data import TokenIndexer
+from .base_indexer import TokenIndexer
 
 
 class TokenFeatsIndexer(TokenIndexer):
diff --git a/combo/data/vocabulary.py b/combo/data/vocabulary.py
index d482b6c..feb184e 100644
--- a/combo/data/vocabulary.py
+++ b/combo/data/vocabulary.py
@@ -1,5 +1,5 @@
 from collections import defaultdict, OrderedDict
-from typing import Dict, Union, Optional, Iterable, Callable, Any, Set
+from typing import Dict, Union, Optional, Iterable, Callable, Any, Set, List
 
 from torchtext.vocab import Vocab as TorchtextVocab
 from torchtext.vocab import vocab as torchtext_vocab
@@ -17,7 +17,7 @@ def match_namespace(pattern: str, namespace: str):
                          (type(pattern), type(namespace)))
     if pattern == namespace:
         return True
-    if len(pattern)>2 and pattern[0] == '*' and namespace.endswith(pattern[1:]):
+    if len(pattern) > 2 and pattern[0] == '*' and namespace.endswith(pattern[1:]):
         return True
     return False
 
@@ -35,13 +35,14 @@ class _NamespaceDependentDefaultDict(defaultdict[str, TorchtextVocab]):
     def __missing__(self, namespace: str):
         # Non-padded namespace
         if any([match_namespace(namespace, npn) for npn in self._non_padded_namespaces]):
+            value = torchtext_vocab(OrderedDict([]))
+        else:
             value = torchtext_vocab(
                 OrderedDict([
-                    (self._padding_token, 0),
-                    (self._oov_token, 1)])
+                    (self._padding_token, 1),
+                    (self._oov_token, 1)
+                ])
             )
-        else:
-            value = torchtext_vocab(OrderedDict([]))
         dict.__setitem__(self, namespace, value)
         return value
 
@@ -78,21 +79,64 @@ class Vocabulary:
             for token in tokens:
                 self._vocab[namespace].append_token(token)
 
-    # def add_token_to_namespace(self, token: str, namespace: str = DEFAULT_NAMESPACE):
-    #     """
-    #     Add the token if not present and return the index even if token was already in the namespace.
-    #
-    #     :param token: token to be added
-    #     :param namespace: namespace to add the token to
-    #     :return: index of the token in the namespace
-    #     """
-    #
-    #     if not isinstance(token, str):
-    #         raise ValueError("Vocabulary tokens must be strings. Got %s with type %s" % (repr(token), type(token)))
-    #
+    def add_token_to_namespace(self, token: str, namespace: str = DEFAULT_NAMESPACE):
+        """
+        Add the token if not present and return the index even if token was already in the namespace.
+
+        :param token: token to be added
+        :param namespace: namespace to add the token to
+        :return: index of the token in the namespace
+        """
+
+        if not isinstance(token, str):
+            raise ValueError("Vocabulary tokens must be strings. Got %s with type %s" % (repr(token), type(token)))
+
+        self._vocab[namespace].append_token(token)
+
+    def add_tokens_to_namespace(self, tokens: List[str], namespace: str = DEFAULT_NAMESPACE):
+        """
+        Add the token if not present and return the index even if token was already in the namespace.
+
+        :param tokens: tokens to be added
+        :param namespace: namespace to add the token to
+        :return: index of the token in the namespace
+        """
+
+        if not isinstance(tokens, List):
+            raise ValueError("Vocabulary tokens must be passed as a list of strings. Got %s with type %s" % (
+                repr(tokens), type(tokens)))
+        for token in tokens:
+            self._vocab[namespace].append_token(token)
+
+    def get_index_to_token_vocabulary(self, namespace: str = DEFAULT_NAMESPACE) -> Dict[int, str]:
+        if not isinstance(namespace, str):
+            raise ValueError(
+                "Namespace must be passed as string. Received %s with type %s" % (repr(namespace), type(namespace)))
+
+        itos: List[str] = self._vocab[namespace].get_itos()
+
+        return {i: s for i, s in enumerate(itos)}
+
+    def get_token_to_index_vocabulary(self, namespace: str = DEFAULT_NAMESPACE) -> Dict[str, int]:
+        if not isinstance(namespace, str):
+            raise ValueError(
+                "Namespace must be passed as string. Received %s with type %s" % (repr(namespace), type(namespace)))
+
+        return self._vocab[namespace].get_stoi()
+
+    def get_token_index(self, token: str, namespace: str = DEFAULT_NAMESPACE) -> int:
+        return self.get_token_to_index_vocabulary(namespace).get(token)
+
+    def get_token_from_index(self, index: int, namespace: str = DEFAULT_NAMESPACE) -> Optional[str]:
+        return self.get_index_to_token_vocabulary(namespace).get(index)
+
+    def get_vocab_size(self, namespace: str = DEFAULT_NAMESPACE) -> int:
+        return len(self._vocab[namespace].get_itos())
+
+    def get_namespaces(self) -> Set[str]:
+        return set(self._vocab.keys())
 
     @classmethod
     def empty(cls):
         return cls()
 
-
-- 
GitLab
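
Usage sketch (not part of the patch): a minimal, hypothetical example of how the
Vocabulary API introduced above might be exercised, assuming the combo package is
installed and that Vocabulary is importable from combo.data.vocabulary as the file
path in the diff suggests; the "animals" namespace and the printed values are made
up for illustration.

# usage_sketch.py -- illustrative only, not shipped with this commit
from combo.data.vocabulary import Vocabulary

# Start from an empty vocabulary (the classmethod shown in the diff).
vocab = Vocabulary.empty()

# Add a single token and a batch of tokens; namespaces other than the default
# are created lazily by the namespace-dependent defaultdict.
vocab.add_token_to_namespace("dog")
vocab.add_tokens_to_namespace(["cat", "fish"], namespace="animals")

# Lookups are thin wrappers over torchtext's stoi/itos mappings.
print(vocab.get_token_index("dog"))              # index of "dog" in the default namespace
print(vocab.get_token_from_index(0, "animals"))  # token stored at index 0
print(vocab.get_vocab_size("animals"))           # size includes padding/OOV for padded namespaces
print(vocab.get_namespaces())                    # set of namespaces created so far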