Newer
Older
import json
from typing import List, Dict
import numpy as np
from sziszapangma.core.transformer.embedding_transformer import \
EmbeddingTransformer
class FileStoredEmbeddingTransformer(EmbeddingTransformer):
    """Embedding transformer backed by a precomputed JSON file.

    The whole file is read once in ``__init__`` and every entry is kept in
    an in-memory dict; later lookups never touch the file again.
    """

    # Maps each word (a top-level JSON key) to its embedding vector.
    _cache: Dict[str, np.ndarray]

    def __init__(self, file_path: str):
        """Load all embeddings from ``file_path``.

        Args:
            file_path: Path to a JSON file whose top-level object maps words
                to embedding values. Each value is passed to ``np.array``
                unchanged (presumably a list of numbers — confirm against
                the file producer).
        """
        # JSON text is UTF-8 by specification (RFC 8259). Relying on the
        # platform default encoding would mis-decode non-ASCII words on
        # systems whose locale is not UTF-8.
        with open(file_path, 'r', encoding='utf-8') as f:
            json_content = json.load(f)
        self._cache = {
            key: np.array(value)
            for key, value in json_content.items()
        }

    def get_embeddings(self, words: List[str]) -> Dict[str, np.ndarray]:
        """Return a mapping from each word in ``words`` to its embedding.

        The arrays returned are the cached objects themselves, not copies.

        Raises:
            KeyError: If any word was not present in the loaded file.
        """
        return {word: self._cache[word] for word in words}

    def get_embedding(self, word: str) -> np.ndarray:
        """Return the cached embedding array for a single ``word``.

        Raises:
            KeyError: If ``word`` was not present in the loaded file.
        """
        return self._cache[word]