Commit f02ca270 authored by Łukasz Kopociński's avatar Łukasz Kopociński

Save results to file

parent 4c1decc8
......@@ -49,9 +49,9 @@ class Predictor:
def _predict(self, vectors: torch.Tensor):
with torch.no_grad():
predictions = self._net(vectors)
return torch.argmax(predictions)
return torch.argmax(predictions, 1)
def predict(self, indices_context: List[Tuple]):
orths, vectors = self._make_vectors(indices_context)
predictions = self._predict(vectors)
return orths, predictions
return orths, predictions.numpy()
import logging
from pathlib import Path
from typing import Dict
import nlp_ws
from corpus_ccl import cclutils
from semrel.data.scripts.corpus import Document
from semrel.data.scripts.utils.io import save_lines
from semrel.data.scripts.vectorizers import ElmoVectorizer
from semrel.model.scripts import RelNet
from semrel.model.scripts.utils.utils import get_device
......@@ -60,11 +62,13 @@ class SemrelWorker(nlp_ws.NLPWorker):
document = Document(cclutils.read_ccl(input_path))
for indices_context in parser(document):
predictions = predictor.predict(indices_context)
_log.critical(str(predictions))
# save predictions
# save_lines(Path(output_path), predictions)
orths, predictions = predictor.predict(indices_context)
lines = [
f'{orth_from} : {orth_to} - {prediction}'
for (orth_from, orth_to), prediction in zip(orths, predictions)
]
# save predictions
save_lines(Path(output_path), lines, mode='a+')
if __name__ == '__main__':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment