Commit 19c6b577 authored by Łukasz Kopociński's avatar Łukasz Kopociński

Change device extraction

parent dffe4ab6
......@@ -32,7 +32,7 @@ class RelNet(nn.Module):
)
def forward(self, x: torch.tensor):
""" x: a concatenation of word vectors """
"""x: a concatenation of word vectors."""
return self.net(x)
def extract_layer_weights(self, layer_name: str):
......
......@@ -90,6 +90,15 @@ COPY deps/model.pt model.pt
ENV PREDICTION_MODEL="/home/model.pt"
# sent2vec
WORKDIR /home/install
RUN git clone https://github.com/epfml/sent2vec.git && \
cd sent2vec/ && \
pip install . && \
cd /home/install && \
rm -rf sent2vec
# install dependencies used in worker
RUN git clone https://gitlab.clarin-pl.eu/team-semantics/semrel-extraction.git && \
cd semrel-extraction && \
......
......@@ -5,13 +5,14 @@ boto==2.49.0
boto3==1.10.27
botocore==1.13.27
click==7.0.0
cython==0.29.15
corpus_ccl==0.9
gensim==3.8.1
matplotlib==3.1.2
nlp_ws==0.6.0
numpy==1.17.4
pandas==0.25.3
pika==0.10.0
pika==0.12.0
PyYAML==5.1.2
scikit-learn==0.21.3
scipy==1.3.3
......
......@@ -13,12 +13,13 @@ class Predictor:
self,
net_model: RelNet,
elmo: ElmoVectorizer,
fasttext: FastTextVectorizer
fasttext: FastTextVectorizer,
device: torch.device
):
self._net = net_model
self._elmo = elmo
self._fasttext = fasttext
self._device = self._net.get_device()
self._device = device
def _make_vectors(self, indices_context: List[Tuple]):
orths = []
......
......@@ -41,15 +41,16 @@ class SemrelWorker(nlp_ws.NLPWorker):
device=self._device.index
)
_log.critical("Loading FASTTEXT model ...")
self._fasttext = FastTextVectorizer(
model_path=constant.FASTTEXT_MODEL
)
self._fasttext = None
# _log.critical("Loading FASTTEXT model ...")
# self._fasttext = FastTextVectorizer(
# model_path=constant.FASTTEXT_MODEL
# )
_log.critical("Loading models completed.")
def process(self, input_path: str, task_options: Dict, output_path: str):
# load model
_log.critical("Load MODEL")
net = load_model(constant.PREDICTION_MODEL)
net = net.to(self._device)
......@@ -58,14 +59,15 @@ class SemrelWorker(nlp_ws.NLPWorker):
else:
parser = Parser(find_nouns)
predictor = Predictor(net, self._elmo, self._fasttext)
predictor = Predictor(net, self._elmo, self._fasttext, self._device)
document = Document(cclutils.read_ccl(input_path))
for indices_context in parser(document):
predictions = predictor.predict(indices_context)
print(indices_context)
# predictions = predictor.predict(indices_context)
# save predictions
save_lines(Path(output_path), predictions)
# save_lines(Path(output_path), predictions)
# def _predict(self, predictor: Predictor, pairs: Iterator):
# pairs = list(pairs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment