Commit f4f6a906 authored by Łukasz Kopociński's avatar Łukasz Kopociński

Show predictions for orths

parent 35b5433b
......@@ -26,7 +26,7 @@ class Predictor:
vectors = []
for indices, context in indices_context:
_orths = [
orths_pairs = [
orth
for index, orth in enumerate(context)
if index in indices
......@@ -37,7 +37,7 @@ class Predictor:
_vectors_elmo = _vectors_elmo[indices]
_vectors_fasttext = _vectors_fasttext[indices]
orths.extend(_orths)
orths.extend(orths_pairs)
vectors.append((_vectors_elmo, _vectors_fasttext))
vectors_elmo, vectors_fasttext = zip(*vectors)
......@@ -59,14 +59,16 @@ class Predictor:
vector = torch.cat([elmo_vectors, fasttext_vectors], 1)
return vector.to(self._device)
orths_pairs = [(orths[idx_f], orths[idx_t])
for idx_f, idx_t in zip(idx_from, idx_to)]
return orths_pairs, vector.to(self._device)
def _predict(self, vectors: torch.Tensor):
with torch.no_grad():
predictions = self._net(vectors)
predictions = torch.argmax(predictions)
return predictions
return torch.argmax(predictions)
def predict(self, indices_context: List[Tuple]):
vectors = self._make_vectors(indices_context)
return self._predict(vectors)
orths, vectors = self._make_vectors(indices_context)
predictions = self._predict(vectors)
return zip(orths, predictions)
......@@ -41,11 +41,10 @@ class SemrelWorker(nlp_ws.NLPWorker):
device=self._device.index
)
self._fasttext = self._elmo
# _log.critical("Loading FASTTEXT model ...")
# self._fasttext = FastTextVectorizer(
# model_path=constant.FASTTEXT_MODEL
# )
_log.critical("Loading FASTTEXT model ...")
self._fasttext = FastTextVectorizer(
model_path=constant.FASTTEXT_MODEL
)
_log.critical("Loading models completed.")
......@@ -66,24 +65,9 @@ class SemrelWorker(nlp_ws.NLPWorker):
predictions = predictor.predict(indices_context)
_log.critical(str(predictions))
# save predictions
# save_lines(Path(output_path), predictions)
# def _predict(self, predictor: Predictor, pairs: Iterator):
# pairs = list(pairs)
# members_from, members_to = zip(*pairs)
# orths_from = [context[index] for index, context in members_from]
# orths_to = [context[index] for index, context in members_to]
#
# predictions = [predictor.predict(pair) for pair in pairs]
#
# return [
# f'{orth_from}\t{orth_to}: {decision}\n'
# for orth_from, orth_to, decision
# in zip(orths_from, orths_to, predictions)
# ]
if __name__ == '__main__':
_log.critical("Start semrel prediction.")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment