Skip to content
Snippets Groups Projects
Commit c7b78af8 authored by Wiktor Walentynowicz :construction_worker_tone1:
Browse files

Update library version to 0.4.0. Fix prediction.

parent 679a70c2
Branches
1 merge request: !5 Develop
Pipeline #7053 passed with stage
in 29 seconds
torch==1.13.0
--index-url https://pypi.clarin-pl.eu/simple/
nlp_ws
winer[torch-gpu]==0.3.1
winer==0.4.0
awscli==1.22.57
\ No newline at end of file
--index-url https://pypi.clarin-pl.eu/simple/
nlp_ws
winer==0.3.1
winer==0.4.0
awscli==1.22.57
\ No newline at end of file
"""Implementation of punctuator service"""
from winer.datafiles import read_clarin_json, write_clarin_json
from winer.document import create_document_from_clarin_json, \
create_entities_from_hf_outputs
from winer.winer import Winer
from winer import Winer
from winer import allign_into_entities, allign_into_words
from winer import create_document_from_clarin_json, create_entities_from_hf_outputs
from winer import read_clarin_json, write_clarin_json
import logging
......@@ -16,7 +16,7 @@ class WinerWorker:
):
logging.info("Loading models...")
self.active_model = Winer(models_location)
self._model = Winer(models_location)
def process(
self,
......@@ -24,12 +24,25 @@ class WinerWorker:
task_options: dict,
output_path: str
) -> None:
# Read inputs
documents = [create_document_from_clarin_json(read_clarin_json(input_path))]
outputs = self.active_model.predict(
[str(document) for document in documents]
)
tok_inputs = [document.get_pretokenized_text() for document in documents]
plain_inputs = [document.get_text() for document in documents]
# Prediction
encoded_inputs = self._model.tokenize(tok_inputs)
predictions_per_token = self._model.predict(encoded_inputs)
# Aggregation
aggregations = [allign_into_words(tokens, labels)
for tokens, labels in zip(
[encoded_inputs.word_ids(index)
for index in range(len(predictions_per_token))],
predictions_per_token)]
entities = allign_into_entities(tok_inputs, aggregations,
inputs_sentences=plain_inputs)
for idx in range(len(documents)):
documents[idx].add_entites(create_entities_from_hf_outputs(outputs[idx]))
documents[idx].add_entites(create_entities_from_hf_outputs(entities[idx]))
write_clarin_json(documents[0].as_clarin_json(), output_path)
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment