Commit 96fb3e6b authored by Łukasz Kopociński's avatar Łukasz Kopociński

Refactor worker

parent 6c20a98a
......@@ -5,8 +5,6 @@ from typing import Deque, Tuple, Callable, List
from semrel.data.scripts.corpus import Document, DocSentence
WINDOW_SIZE = 3
class Slicer(ABC):
......@@ -38,8 +36,8 @@ class SentenceWindow(Slicer):
class Parser:
def __init__(self, slicer: Slicer, extractor: Callable):
    """Build a parser over pre-sliced document chunks.

    :param slicer: strategy that cuts a Document into chunks
        (e.g. ``SentenceWindow``); injected instead of being
        hard-coded so callers control the windowing policy.
    :param extractor: callable selecting candidate tokens
        (e.g. ``find_nouns`` or ``find_named_entities``).
    """
    # NOTE(review): the diff paste contained both the old constructor
    # (which hard-coded SentenceWindow(WINDOW_SIZE)) and this new one;
    # only the new dependency-injected version is kept.
    self._slicer = slicer
    self._extractor = extractor
def __call__(self, document: Document) -> List[Tuple]:
......
from itertools import permutations
from typing import Tuple, List
import numpy
import torch
from semrel.data.scripts.vectorizers import Vectorizer
......@@ -23,35 +24,43 @@ class Predictor:
self,
indices_context: List[Tuple]
) -> Tuple[List[Tuple[str, str]], torch.Tensor]:
orths = [orth
for indices, context in indices_context
for index, orth in enumerate(context)
if index in indices]
orths = [
orth
for indices, context in indices_context
for index, orth in enumerate(context)
if index in indices
]
vectors = [self._vectorizer.embed(context)[indices]
for indices, context in indices_context]
vectors = [
self._vectorizer.embed(context)[indices]
for indices, context in indices_context
]
vectors = torch.cat(vectors)
orths_size = len(orths)
orths_indices = range(orths_size)
indices_pairs = [*permutations(orths_indices, 2)]
idx_from, idx_to = zip(*indices_pairs)
indices_pairs = [*permutations(orths_indices, r=2)]
indices_from, indices_to = zip(*indices_pairs)
vec_from = vectors[[*idx_from]]
vec_to = vectors[[*idx_to]]
vector = torch.cat([vec_from, vec_to], 1)
vectors_from = vectors[[*indices_from]]
vectors_to = vectors[[*indices_to]]
vector = torch.cat([vectors_from, vectors_to], dim=1)
orths_pairs = [(orths[idx_f], orths[idx_t])
for idx_f, idx_t in indices_pairs]
orths_pairs = [
(orths[idx_f], orths[idx_t])
for idx_f, idx_t in indices_pairs
]
return orths_pairs, vector.to(self._device)
def _predict(self, vectors: torch.Tensor) -> numpy.ndarray:
    """Classify pre-built pair vectors with the relation network.

    Inference-only: gradients are disabled. Returns the argmax
    class index for each row, moved to CPU as a numpy array so
    callers never see device-bound tensors.
    """
    # NOTE(review): the diff paste left two consecutive returns (old
    # `torch.argmax(predictions, 1)` and the new one below); the dead
    # duplicate is removed and the new behavior kept.
    with torch.no_grad():
        predictions = self._net(vectors)
    # dim=1: one class decision per pair row.
    return torch.argmax(predictions, dim=1).cpu().numpy()
def predict(
    self, indices_context: List[Tuple[List[int], List[str]]]
) -> Tuple[List[Tuple[str, str]], numpy.ndarray]:
    """Predict a relation class for every ordered pair of tokens.

    :param indices_context: per-chunk list of ``(indices, context)``
        pairs — token indices of interest and the chunk's tokens.
    :return: ``(orth_pairs, predictions)`` — the ordered token-text
        pairs and, for each, the predicted class index (numpy array,
        already on CPU — see ``_predict``).
    """
    # NOTE(review): the diff paste duplicated old/new signatures and
    # returns; additionally the pasted return annotation was a list
    # literal `[List[...], numpy.array]`, which is not a valid type —
    # fixed to Tuple[...] here.
    orths, vectors = self._make_vectors(indices_context)
    predictions = self._predict(vectors)
    return orths, predictions
from typing import List, Tuple
import numpy
import torch
from semrel.model.scripts import RelNet
def load_model(
    model_path: str, vector_size: int, device: torch.device
) -> RelNet:
    """Restore a trained RelNet and prepare it for inference.

    :param model_path: path to the saved model weights.
    :param vector_size: input dimensionality of the network.
    :param device: device the network should run on.
    :return: the network, on ``device`` and switched to eval mode.
    """
    network = RelNet(in_dim=vector_size)
    network.load(model_path)
    # Inference-only usage: place on the target device and disable
    # training-mode behavior (dropout etc.).
    network = network.to(device)
    network.eval()
    return network
def format_output(
    orths: List[Tuple[str, str]], predictions: numpy.array
) -> List[str]:
    """Render one human-readable line per predicted relation.

    Each line pairs the two token texts with the prediction, in the
    form ``"<from> : <to> - <prediction>"``.
    """
    lines = []
    for (orth_from, orth_to), prediction in zip(orths, predictions):
        lines.append(f'{orth_from} : {orth_to} - {prediction}')
    return lines
......@@ -8,22 +8,16 @@ from corpus_ccl import cclutils
from semrel.data.scripts.corpus import Document
from semrel.data.scripts.utils.io import save_lines
from semrel.data.scripts.vectorizers import ElmoVectorizer
from semrel.model.scripts import RelNet
from semrel.model.scripts.utils.utils import get_device
from worker.scripts import constant
from worker.scripts.extractor import Parser, find_named_entities, find_nouns
from worker.scripts.extractor import \
Parser, find_named_entities, find_nouns, SentenceWindow
from worker.scripts.prediction import Predictor
from worker.scripts.utils import load_model, format_output
_log = logging.getLogger(__name__)
def load_model(model_path: str, vector_size: int = 2048) -> RelNet:
    """Restore a trained RelNet from disk, in eval mode.

    :param model_path: path to the saved model weights.
    :param vector_size: input dimensionality of the network
        (defaults to 2048).
    :return: the loaded network, switched to eval mode.
    """
    network = RelNet(in_dim=vector_size)
    network.load(model_path)
    network.eval()
    return network
class SemrelWorker(nlp_ws.NLPWorker):
@classmethod
......@@ -31,44 +25,39 @@ class SemrelWorker(nlp_ws.NLPWorker):
pass
def init(self):
    """Load all models the worker needs (called once at startup)."""
    _log.critical("Loading models.")
    # Input width of the relation network; presumably matches the
    # ELMo embedding size — TODO confirm against RelNet/ElmoVectorizer.
    self._vector_size = 2048
    # Number of sentences per chunk handed to the sentence-window slicer.
    self._window_size = 3
    self._device = get_device()
    _log.critical("Loading ELMO model ...")
    self._elmo = ElmoVectorizer(
        options_path=constant.ELMO_MODEL_OPTIONS,
        weights_path=constant.ELMO_MODEL_WEIGHTS,
        # ElmoVectorizer is given the numeric device index, not the
        # torch.device object itself.
        device=self._device.index
    )
    # Kept from a previous revision — FastText support currently disabled.
    # _log.critical("Loading FASTTEXT model ...")
    # self._fasttext = FastTextVectorizer(
    #     model_path=constant.FASTTEXT_MODEL
    # )
    _log.critical("Loading models completed.")
def process(
    self, input_path: str, task_options: Dict, output_path: str
) -> None:
    """Predict semantic relations for one CCL document.

    Reads the document at ``input_path``, pairs candidate tokens
    (named entities when the NER task option is set, nouns
    otherwise) within a sentence window, classifies every ordered
    pair, and appends one formatted line per pair to
    ``output_path``.
    """
    net = load_model(
        model_path=constant.PREDICTION_MODEL,
        vector_size=self._vector_size,
        device=self._device
    )

    # Choose which tokens form candidate pairs.
    is_ner_task = task_options.get(constant.NER_KEY, False)
    extractor = find_named_entities if is_ner_task else find_nouns

    slicer = SentenceWindow(window_size=self._window_size)
    parser = Parser(slicer, extractor)
    predictor = Predictor(net, self._elmo, self._device)

    document = Document(cclutils.read_ccl(input_path))

    # Predict chunk by chunk and flatten into one list of lines.
    # (The pasted `orths, predictions = zip(*[...])` variant would
    # hand whole per-chunk lists to format_output — mis-pairing its
    # unpacking — and raise on documents yielding no chunks.)
    lines = [
        line
        for indices_context in parser(document)
        for line in format_output(*predictor.predict(indices_context))
    ]
    save_lines(Path(output_path), lines, mode='a+')
if __name__ == '__main__':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment