From e83a7797b26fa54f3c11d477e8f6b4e94607e70b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Maja=20Jab=C5=82o=C5=84ska?= <majajjablonska@gmail.com>
Date: Tue, 7 Mar 2023 20:58:45 +0100
Subject: [PATCH] General structure

---
 combo/data/fields/base_field.py               |  5 +-
 combo/data/samplers/base_sampler.py           |  5 +-
 combo/data/token_indexers/base_indexer.py     |  5 +-
 ...etrained_transformer_mismatched_indexer.py |  2 +-
 .../token_characters_indexer.py               |  2 +-
 .../token_indexers/token_features_indexer.py  |  2 +-
 combo/data/vocabulary.py                      | 82 ++++++++++++++-----
 7 files changed, 69 insertions(+), 34 deletions(-)

diff --git a/combo/data/fields/base_field.py b/combo/data/fields/base_field.py
index 83ea563..bf8ccb2 100644
--- a/combo/data/fields/base_field.py
+++ b/combo/data/fields/base_field.py
@@ -1,5 +1,2 @@
-from abc import ABCMeta
-
-
-class Field(metaclass=ABCMeta):
+class Field:
     pass
diff --git a/combo/data/samplers/base_sampler.py b/combo/data/samplers/base_sampler.py
index 6e5cd40..e570a36 100644
--- a/combo/data/samplers/base_sampler.py
+++ b/combo/data/samplers/base_sampler.py
@@ -1,5 +1,2 @@
-from abc import ABCMeta
-
-
-class Sampler(metaclass=ABCMeta):
+class Sampler:
     pass
diff --git a/combo/data/token_indexers/base_indexer.py b/combo/data/token_indexers/base_indexer.py
index 2fb48c0..fa70d63 100644
--- a/combo/data/token_indexers/base_indexer.py
+++ b/combo/data/token_indexers/base_indexer.py
@@ -1,7 +1,4 @@
-from abc import ABCMeta
-
-
-class TokenIndexer(metaclass=ABCMeta):
+class TokenIndexer:
     pass
 
 
diff --git a/combo/data/token_indexers/pretrained_transformer_mismatched_indexer.py b/combo/data/token_indexers/pretrained_transformer_mismatched_indexer.py
index 9aa3616..ea6a663 100644
--- a/combo/data/token_indexers/pretrained_transformer_mismatched_indexer.py
+++ b/combo/data/token_indexers/pretrained_transformer_mismatched_indexer.py
@@ -1,4 +1,4 @@
-from combo.data import TokenIndexer
+from .base_indexer import TokenIndexer
 
 
 class PretrainedTransformerMismatchedIndexer(TokenIndexer):
diff --git a/combo/data/token_indexers/token_characters_indexer.py b/combo/data/token_indexers/token_characters_indexer.py
index 6f5dbf0..f99923e 100644
--- a/combo/data/token_indexers/token_characters_indexer.py
+++ b/combo/data/token_indexers/token_characters_indexer.py
@@ -1,4 +1,4 @@
-from combo.data import TokenIndexer
+from .base_indexer import TokenIndexer
 
 
 class TokenCharactersIndexer(TokenIndexer):
diff --git a/combo/data/token_indexers/token_features_indexer.py b/combo/data/token_indexers/token_features_indexer.py
index b6267a4..901ffdb 100644
--- a/combo/data/token_indexers/token_features_indexer.py
+++ b/combo/data/token_indexers/token_features_indexer.py
@@ -1,6 +1,6 @@
 """Features indexer."""
 
-from combo.data import TokenIndexer
+from .base_indexer import TokenIndexer
 
 
 class TokenFeatsIndexer(TokenIndexer):
diff --git a/combo/data/vocabulary.py b/combo/data/vocabulary.py
index d482b6c..feb184e 100644
--- a/combo/data/vocabulary.py
+++ b/combo/data/vocabulary.py
@@ -1,5 +1,5 @@
 from collections import defaultdict, OrderedDict
-from typing import Dict, Union, Optional, Iterable, Callable, Any, Set
+from typing import Dict, Union, Optional, Iterable, Callable, Any, Set, List
 
 from torchtext.vocab import Vocab as TorchtextVocab
 from torchtext.vocab import vocab as torchtext_vocab
@@ -17,7 +17,7 @@ def match_namespace(pattern: str, namespace: str):
                          (type(pattern), type(namespace)))
     if pattern == namespace:
         return True
-    if len(pattern)>2 and pattern[0] == '*' and namespace.endswith(pattern[1:]):
+    if len(pattern) > 2 and pattern[0] == '*' and namespace.endswith(pattern[1:]):
         return True
     return False
 
@@ -35,13 +35,14 @@ class _NamespaceDependentDefaultDict(defaultdict[str, TorchtextVocab]):
     def __missing__(self, namespace: str):
         # Non-padded namespace
         if any([match_namespace(namespace, npn) for npn in self._non_padded_namespaces]):
+            value = torchtext_vocab(OrderedDict([]))
+        else:
             value = torchtext_vocab(
                 OrderedDict([
-                    (self._padding_token, 0),
-                    (self._oov_token, 1)])
+                    (self._padding_token, 1),
+                    (self._oov_token, 1)
+                ])
             )
-        else:
-            value = torchtext_vocab(OrderedDict([]))
         dict.__setitem__(self, namespace, value)
         return value
 
@@ -78,21 +79,64 @@ class Vocabulary:
             for token in tokens:
                 self._vocab[namespace].append_token(token)
 
-    # def add_token_to_namespace(self, token: str, namespace: str = DEFAULT_NAMESPACE):
-    #     """
-    #     Add the token if not present and return the index even if token was already in the namespace.
-    #
-    #     :param token: token to be added
-    #     :param namespace: namespace to add the token to
-    #     :return: index of the token in the namespace
-    #     """
-    #
-    #     if not isinstance(token, str):
-    #         raise ValueError("Vocabulary tokens must be strings. Got %s with type %s" % (repr(token), type(token)))
-    #
+    def add_token_to_namespace(self, token: str, namespace: str = DEFAULT_NAMESPACE):
+        """
+        Add the token to the namespace if it is not already present.
+
+        :param token: token to be added
+        :param namespace: namespace to add the token to
+        :raises ValueError: if the token is not a string
+        """
+
+        if not isinstance(token, str):
+            raise ValueError("Vocabulary tokens must be strings. Got %s with type %s" % (repr(token), type(token)))
+
+        self._vocab[namespace].append_token(token)
+
+    def add_tokens_to_namespace(self, tokens: List[str], namespace: str = DEFAULT_NAMESPACE):
+        """
+        Add each of the given tokens to the namespace if it is not already present.
+
+        :param tokens: tokens to be added
+        :param namespace: namespace to add the tokens to
+        :raises ValueError: if tokens is not a list
+        """
+
+        if not isinstance(tokens, list):
+            raise ValueError("Vocabulary tokens must be passed as a list of strings. Got %s with type %s" % (
+                repr(tokens), type(tokens)))
 
+        for token in tokens:
+            self._vocab[namespace].append_token(token)
+
+    def get_index_to_token_vocabulary(self, namespace: str = DEFAULT_NAMESPACE) -> Dict[int, str]:
+        if not isinstance(namespace, str):
+            raise ValueError(
+                "Namespace must be passed as string. Received %s with type %s" % (repr(namespace), type(namespace)))
+
+        itos: List[str] = self._vocab[namespace].get_itos()
+
+        return {i: s for i, s in enumerate(itos)}
+
+    def get_token_to_index_vocabulary(self, namespace: str = DEFAULT_NAMESPACE) -> Dict[str, int]:
+        if not isinstance(namespace, str):
+            raise ValueError(
+                "Namespace must be passed as string. Received %s with type %s" % (repr(namespace), type(namespace)))
+
+        return self._vocab[namespace].get_stoi()
+
+    def get_token_index(self, token: str, namespace: str = DEFAULT_NAMESPACE) -> Optional[int]:
+        return self.get_token_to_index_vocabulary(namespace).get(token)
+
+    def get_token_from_index(self, index: int, namespace: str = DEFAULT_NAMESPACE) -> Optional[str]:
+        return self.get_index_to_token_vocabulary(namespace).get(index)
+
+    def get_vocab_size(self, namespace: str = DEFAULT_NAMESPACE) -> int:
+        return len(self._vocab[namespace].get_itos())
+
+    def get_namespaces(self) -> Set[str]:
+        return set(self._vocab.keys())
 
     @classmethod
     def empty(cls):
         return cls()
-
-- 
GitLab