# file_stored_embedding_transformer.py
import json
from typing import List, Dict

import numpy as np

from sziszapangma.core.transformer.embedding_transformer import \
    EmbeddingTransformer


class FileStoredEmbeddingTransformer(EmbeddingTransformer):
    """Embedding transformer backed by a precomputed JSON file.

    The whole file is read once in ``__init__`` and kept in memory; later
    lookups never touch the file again. The file is expected to hold a single
    JSON object whose keys are words and whose values are array-like
    (presumably lists of floats -- confirm against whatever produces the
    file). Each value is converted to a ``numpy`` array.
    """

    # word -> embedding vector, populated once from the JSON file.
    _cache: Dict[str, np.ndarray]

    def __init__(self, file_path: str):
        """Load every embedding stored in ``file_path`` into memory.

        Args:
            file_path: Path of the JSON file to read.

        Raises:
            OSError: If the file cannot be opened.
            json.JSONDecodeError: If the file is not valid JSON.
        """
        # Decode explicitly as UTF-8 (the encoding JSON text uses) rather than
        # the platform default, so non-ASCII words load identically everywhere.
        with open(file_path, 'r', encoding='utf-8') as f:
            json_content = json.load(f)
        # Parsing is finished; build the arrays after the file is closed.
        self._cache = {
            word: np.array(vector)
            for word, vector in json_content.items()
        }

    def get_embeddings(self, words: List[str]) -> Dict[str, np.ndarray]:
        """Return the embedding of every word in ``words``.

        Args:
            words: Words to look up; duplicates collapse into a single entry.

        Returns:
            Mapping from each requested word to its embedding. The arrays are
            the cached objects themselves, not copies, so modifying them in
            place also changes the cache.

        Raises:
            KeyError: If any word is missing from the loaded file.
        """
        return {word: self._cache[word] for word in words}

    def get_embedding(self, word: str) -> np.ndarray:
        """Return the embedding of a single ``word``.

        The returned array is the cached object itself, not a copy.

        Raises:
            KeyError: If ``word`` is missing from the loaded file.
        """
        return self._cache[word]