Commit baf27917 authored by Łukasz Kopociński

Make one vectorizer worker

parent f4f6a906
......@@ -3,7 +3,7 @@ from typing import Tuple, List
import torch
from semrel.data.scripts.vectorizers import ElmoVectorizer, FastTextVectorizer
from semrel.data.scripts.vectorizers import Vectorizer
from semrel.model.scripts import RelNet
......@@ -12,55 +12,38 @@ class Predictor:
def __init__(
        self,
        net_model: RelNet,
        vectorizer: Vectorizer,
        device: torch.device
):
    """Bind the relation network, a single vectorizer and the device.

    :param net_model: trained ``RelNet`` relation classifier used for
        prediction.
    :param vectorizer: the single embedding provider (this commit
        replaces the former separate ELMo/fastText pair with one
        ``Vectorizer``).
    :param device: torch device that prediction tensors are moved to.
    """
    self._net = net_model
    self._vectorizer = vectorizer
    self._device = device
def _make_vectors(self, indices_context: List[Tuple]):
orths = []
vectors = []
for indices, context in indices_context:
orths_pairs = [
orth
for index, orth in enumerate(context)
if index in indices
]
_vectors_elmo = self._elmo.embed(context)
_vectors_fasttext = self._fasttext.embed(context)
_vectors_elmo = _vectors_elmo[indices]
_vectors_fasttext = _vectors_fasttext[indices]
orths.extend(orths_pairs)
vectors.append((_vectors_elmo, _vectors_fasttext))
vectors_elmo, vectors_fasttext = zip(*vectors)
vectors_elmo = torch.cat(vectors_elmo)
vectors_fasttext = torch.cat(vectors_fasttext)
size = len(orths)
idx_from, idx_to = zip(*list(permutations(range(size), 2)))
elmo_from = vectors_elmo[[*idx_from]]
elmo_to = vectors_elmo[[*idx_to]]
fasttext_from = vectors_fasttext[[*idx_from]]
fasttext_to = vectors_fasttext[[*idx_to]]
elmo_vectors = torch.cat([elmo_from, elmo_to], 1)
fasttext_vectors = torch.cat([fasttext_from, fasttext_to], 1)
vector = torch.cat([elmo_vectors, fasttext_vectors], 1)
def _make_vectors(
self,
indices_context: List[Tuple]
) -> Tuple[List[Tuple[str, str]], torch.Tensor]:
orths = [orth
for indices, context in indices_context
for index, orth in enumerate(context)
if index in indices]
vectors = [self._vectorizer.embed(context)[indices]
for indices, context in indices_context]
vectors = torch.cat(vectors)
orths_size = len(orths)
orths_indices = range(orths_size)
indices_pairs = [*permutations(orths_indices, 2)]
idx_from, idx_to = zip(*indices_pairs)
vec_from = vectors[[*idx_from]]
vec_to = vectors[[*idx_to]]
vector = torch.cat([vec_from, vec_to], 1)
orths_pairs = [(orths[idx_f], orths[idx_t])
for idx_f, idx_t in zip(idx_from, idx_to)]
for idx_f, idx_t in indices_pairs]
return orths_pairs, vector.to(self._device)
def _predict(self, vectors: torch.Tensor):
......
import logging
from pathlib import Path
from typing import Dict, Iterator
from typing import Dict
import nlp_ws
from corpus_ccl import cclutils
from semrel.data.scripts.corpus import Document
from semrel.data.scripts.utils.io import save_lines
from semrel.data.scripts.vectorizers import ElmoVectorizer, FastTextVectorizer
from semrel.data.scripts.vectorizers import ElmoVectorizer
from semrel.model.scripts import RelNet
from semrel.model.scripts.utils.utils import get_device
from worker.scripts import constant
......@@ -41,10 +39,10 @@ class SemrelWorker(nlp_ws.NLPWorker):
device=self._device.index
)
_log.critical("Loading FASTTEXT model ...")
self._fasttext = FastTextVectorizer(
model_path=constant.FASTTEXT_MODEL
)
# _log.critical("Loading FASTTEXT model ...")
# self._fasttext = FastTextVectorizer(
# model_path=constant.FASTTEXT_MODEL
# )
_log.critical("Loading models completed.")
......@@ -58,7 +56,7 @@ class SemrelWorker(nlp_ws.NLPWorker):
else:
parser = Parser(find_nouns)
predictor = Predictor(net, self._elmo, self._fasttext, self._device)
predictor = Predictor(net, self._elmo, self._device)
document = Document(cclutils.read_ccl(input_path))
for indices_context in parser(document):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.