Skip to content
Snippets Groups Projects
file_stored_embedding_transformer.py 698 B
Newer Older
import json
Marcin Wątroba's avatar
Marcin Wątroba committed
from typing import Dict, List

import numpy as np

Marcin Wątroba's avatar
Marcin Wątroba committed
from sziszapangma.core.transformer.embedding_transformer import EmbeddingTransformer


class FileStoredEmbeddingTransformer(EmbeddingTransformer):
    """Embedding transformer that serves embeddings preloaded from a JSON file.

    The file is read once, at construction time. Its top-level JSON object is
    treated as a mapping from word to embedding: every value is converted with
    ``np.array`` and kept in memory, so later lookups never touch the disk.
    """

    # Word -> embedding array; filled once in ``__init__`` and only read afterwards.
    _cache: Dict[str, np.ndarray]

    def __init__(self, file_path: str):
        """Load every embedding stored in the JSON file at ``file_path``.

        Args:
            file_path: Path to a JSON file whose top-level value is an object
                mapping words to embeddings (presumably arrays of numbers;
                each value is passed to ``np.array`` as-is).

        Raises:
            OSError: If the file cannot be opened.
            json.JSONDecodeError: If the file content is not valid JSON.
        """
        # JSON text is UTF-8 by specification; do not depend on the platform's
        # locale encoding, which would corrupt or reject non-ASCII words.
        with open(file_path, "r", encoding="utf-8") as f:
            json_content = json.load(f)
        # The file handle is no longer needed once parsing is done, so the
        # cache is built outside the ``with`` block.
        self._cache = {word: np.array(vector) for word, vector in json_content.items()}

    def get_embeddings(self, words: List[str]) -> Dict[str, np.ndarray]:
        """Return the stored embeddings for ``words``, keyed by word.

        Duplicate words collapse into a single entry of the result.

        Raises:
            KeyError: If any of the words is absent from the loaded file.
        """
        return {word: self._cache[word] for word in words}

    def get_embedding(self, word: str) -> np.ndarray:
        """Return the stored embedding for ``word``.

        Raises:
            KeyError: If ``word`` is absent from the loaded file.
        """
        return self._cache[word]