Commit 41fff8eb authored by Łukasz Kopociński's avatar Łukasz Kopociński

Fix unpacking

parent 96fb3e6b
from itertools import permutations
from typing import Tuple, List
from typing import Tuple, List, NamedTuple
import numpy
import torch
......@@ -8,6 +8,11 @@ from import Vectorizer
from semrel.model.scripts import RelNet
class Results(NamedTuple):
orths: List[Tuple[str, str]]
predictions: numpy.array
class Predictor:
def __init__(
......@@ -60,7 +65,7 @@ class Predictor:
def predict(
self, indices_context: List[Tuple[List[int], List[str]]]
) -> [List[Tuple[str, str]], numpy.array]:
) -> Results:
orths, vectors = self._make_vectors(indices_context)
predictions = self._predict(vectors)
return orths, predictions
return Results(orths, predictions)
from typing import List, Tuple
from typing import List
import numpy
import torch
from semrel.model.scripts import RelNet
from worker.scripts.prediction import Results
def load_model(
......@@ -17,9 +17,10 @@ def load_model(
def format_output(
orths: List[Tuple[str, str]], predictions: numpy.array
results: List[Results]
) -> List[str]:
return [
f'{orth_from} : {orth_to} - {prediction}'
for orths, predictions in results
for (orth_from, orth_to), prediction in zip(orths, predictions)
......@@ -51,12 +51,12 @@ class SemrelWorker(nlp_ws.NLPWorker):
predictor = Predictor(net, self._elmo, self._device)
document = Document(cclutils.read_ccl(input_path))
orths, predictions = zip(*[
results = [
for indices_context in parser(document)
lines = format_output(orths, predictions)
lines = format_output(results)
save_lines(Path(output_path), lines, mode='a+')
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment