import json
from typing import Dict, List

import numpy as np

from sziszapangma.core.transformer.embedding_transformer import EmbeddingTransformer


class FileStoredEmbeddingTransformer(EmbeddingTransformer):
    """Embedding transformer backed by a precomputed JSON file.

    The whole file is loaded eagerly in ``__init__``; afterwards every lookup
    is a plain dictionary read from the in-memory cache.
    """

    # Maps each word to its embedding vector, as loaded from the JSON file.
    _cache: Dict[str, np.ndarray]

    def __init__(self, file_path: str):
        """Load all embeddings from ``file_path``.

        Args:
            file_path: Path to a JSON file whose top-level value is an object
                mapping words to array-like embeddings; each value is converted
                with ``np.array``.
        """
        # JSON text is UTF-8 by specification (RFC 8259); relying on the
        # platform's locale encoding would mangle non-ASCII words.
        with open(file_path, "r", encoding="utf-8") as f:
            json_content = json.load(f)
        self._cache = {word: np.array(vector) for word, vector in json_content.items()}

    def get_embeddings(self, words: List[str]) -> Dict[str, np.ndarray]:
        """Return a mapping from each requested word to its embedding.

        Raises:
            KeyError: If any of ``words`` is not present in the loaded file.
        """
        return {word: self._cache[word] for word in words}

    def get_embedding(self, word: str) -> np.ndarray:
        """Return the embedding stored for ``word``.

        Raises:
            KeyError: If ``word`` is not present in the loaded file.
        """
        return self._cache[word]