Commit 35b5433b authored by Łukasz Kopociński

Fix prediction

parent 19c6b577
@@ -25,7 +25,7 @@ class Predictor:
         orths = []
         vectors = []
-        for indices, context in zip(*indices_context):
+        for indices, context in indices_context:
             _orths = [
                 orth
                 for index, orth in enumerate(context)
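Note (not part of the commit): the new loop assumes the parser already yields `(indices, context)` pairs, so the extra `zip(*...)` transpose is no longer needed. A minimal sketch with made-up data; the filtering condition inside the comprehension is an assumption, since the tail of the comprehension is cut off by the hunk:

```python
# Hypothetical input: the parser is assumed to yield (indices, context) pairs.
indices_context = [
    ((0, 3), ["John", "works", "at", "Acme"]),
    ((0, 3), ["Mary", "lives", "in", "Warsaw"]),
]

for indices, context in indices_context:
    # Assumed condition: keep only tokens whose position is listed in `indices`
    # (the real condition is truncated in the hunk above).
    _orths = [orth for index, orth in enumerate(context) if index in indices]
    print(_orths)  # ['John', 'Acme'] / ['Mary', 'Warsaw']
```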
@@ -46,18 +46,18 @@ class Predictor:
         size = len(orths)
-        idx_from, idx_to = zip(*list(permutations(range(size))))
+        idx_from, idx_to = zip(*list(permutations(range(size), 2)))
-        elmo_from = vectors_elmo[idx_from]
-        elmo_to = vectors_elmo[idx_to]
+        elmo_from = vectors_elmo[[*idx_from]]
+        elmo_to = vectors_elmo[[*idx_to]]
-        fasttext_from = vectors_fasttext[idx_from]
-        fasttext_to = vectors_fasttext[idx_to]
+        fasttext_from = vectors_fasttext[[*idx_from]]
+        fasttext_to = vectors_fasttext[[*idx_to]]
-        elmo_vectors = torch.cat([elmo_from, elmo_to])
-        fasttext_vectors = torch.cat([fasttext_from, fasttext_to])
+        elmo_vectors = torch.cat([elmo_from, elmo_to], 1)
+        fasttext_vectors = torch.cat([fasttext_from, fasttext_to], 1)
-        vector = torch.cat([elmo_vectors, fasttext_vectors])
+        vector = torch.cat([elmo_vectors, fasttext_vectors], 1)
         return vector.to(self._device)
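Note (not part of the commit): `permutations(range(size))` without a length argument yields full-length permutations, so unpacking the transposed result into exactly two index sequences only works when `size == 2`; `permutations(range(size), 2)` yields ordered index pairs instead. Converting the index tuples to lists (`[*idx_from]`) matters because indexing a tensor with a tuple is treated as multi-dimensional indexing, while a list selects rows. A toy sketch with assumed, made-up tensor shapes:

```python
import torch
from itertools import permutations

size = 3
vectors_elmo = torch.arange(size * 4, dtype=torch.float32).view(size, 4)      # 3 tokens x 4-dim ELMo (toy)
vectors_fasttext = torch.arange(size * 2, dtype=torch.float32).view(size, 2)  # 3 tokens x 2-dim fastText (toy)

# 6 ordered index pairs: (0,1), (0,2), (1,0), (1,2), (2,0), (2,1)
idx_from, idx_to = zip(*list(permutations(range(size), 2)))

# List indexing picks one row per pair; cat along dim 1 puts the two members side by side.
elmo_pairs = torch.cat([vectors_elmo[[*idx_from]], vectors_elmo[[*idx_to]]], 1)              # (6, 8)
fasttext_pairs = torch.cat([vectors_fasttext[[*idx_from]], vectors_fasttext[[*idx_to]]], 1)  # (6, 4)

vector = torch.cat([elmo_pairs, fasttext_pairs], 1)  # (6, 12): one feature row per (from, to) pair
print(vector.shape)
```

Concatenating along `dim=1` keeps one row per pair; the previous default (`dim=0`) would have doubled the number of rows instead of building pairwise feature vectors.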
@@ -41,7 +41,7 @@ class SemrelWorker(nlp_ws.NLPWorker):
             device=self._device.index
         )
-        self._fasttext = None
+        self._fasttext = self._elmo
         # _log.critical("Loading FASTTEXT model ...")
         # self._fasttext = FastTextVectorizer(
         #     model_path=constant.FASTTEXT_MODEL
@@ -63,8 +63,9 @@ class SemrelWorker(nlp_ws.NLPWorker):
         document = Document(cclutils.read_ccl(input_path))
         for indices_context in parser(document):
-            print(indices_context)
-            # predictions = predictor.predict(indices_context)
+            predictions = predictor.predict(indices_context)
+            _log.critical(str(predictions))
             # save predictions
             # save_lines(Path(output_path), predictions)