Skip to content
Snippets Groups Projects
Commit f0609100 authored by Marcin Wątroba's avatar Marcin Wątroba
Browse files

Update worker

parent 95002a46
Branches
No related merge requests found
......@@ -71,7 +71,7 @@ def add_facebook_hf_wav2vec2_asr(channel: BlockingChannel):
get_hf_facebook_wav2vec2_model_by_language_code(dataset_name[:2]),
'hf_facebook_wav2vec2_asr',
channel,
'hf_facebook_wav2vec2_asr'
'asr_benchmark_asr_run'
)
......@@ -88,8 +88,8 @@ def main():
connection = pika.BlockingConnection(parameters=parameters)
channel = connection.channel()
# add_whisper(channel)
# add_facebook_hf_wav2vec2_asr(channel)
add_facebook_hf_wav2vec2_pipeline(channel)
add_facebook_hf_wav2vec2_asr(channel)
# add_facebook_hf_wav2vec2_pipeline(channel)
connection.close()
......
......@@ -30,7 +30,7 @@ class Wav2Vec2AsrProcessor(AsrProcessor):
pred_ids = torch.argmax(logits, dim=-1)
result = self._wav2vec2_processor.batch_decode(pred_ids)[0]
return {
"transcription": [create_new_word(it) for it in result.split()],
"transcription": result.split(),
"full_text": result,
"words_time_alignment": None
}
......@@ -34,7 +34,7 @@ def run_hf_facebook_wav2vec2_asr_task(dataset_name: str, asr_name: str):
AsrTask(
asr_property_name=PropertyHelper.asr_result(asr_name),
task_name=f'AsrTask___{dataset_name}___{asr_name}',
require_update=False,
require_update=True,
asr_processor=get_asr_processor(asr_name),
record_path_provider=record_provider
)
......
import functools
import json
import os
import threading
import pika
from pika.adapters.blocking_connection import BlockingChannel
from new_experiment.pipeline.pipeline_process_asr import run_hf_facebook_wav2vec2_asr_task
from new_experiment.pipeline.pipeline_process_spacy_dep_tag_wer import run_spacy_dep_tag_wer_pipeline
from new_experiment.pipeline.pipeline_process_spacy_ner_wer import run_spacy_ner_wer_pipeline
from new_experiment.pipeline.pipeline_process_spacy_pos_wer import run_spacy_pos_wer_pipeline
from new_experiment.pipeline.pipeline_process_word_classic_wer import run_word_wer_classic_pipeline
from new_experiment.pipeline.pipeline_process_word_embedding_wer import run_word_wer_embedding_pipeline
from new_experiment.utils.param_util import get_param
# AMQP broker URL, overridable through the RABBIT_URL parameter.
# SECURITY(review): the fallback value embeds live credentials in source
# control — move this secret to environment/секret storage and rotate the
# password; never ship a real password as a default.
_RABBIT_URL = get_param('RABBIT_URL',
                        'amqps://rabbit_user:kz6m4972OUHFmtUcPOHx4kF3Lj6yw7lo@rabbit-asr-benchmarks.theliver.pl:5671/')
def ack_message(channel, delivery_tag):
    """Acknowledge *delivery_tag* on *channel*, but only while it is open.

    AMQP constraint: the ack must be sent on the very same pika channel
    instance that delivered the message being acknowledged.
    """
    if not channel.is_open:
        # The channel was closed before we could ack; the broker will
        # redeliver the message, so there is nothing useful to do here.
        return
    channel.basic_ack(delivery_tag)
from new_experiment.queue_base import process_queue
def do_work(connection, channel, delivery_tag, body):
    """Per-message worker, intended to run on its own thread (started by
    ``on_message``).

    NOTE(review): in this diff view only the thread-id lookup is visible —
    the remainder of the body appears elided by the hunk marker below, so
    the full message-processing logic cannot be documented from here.
    """
    # thread_id is currently only computed, not used (logging is commented out).
    thread_id = threading.get_ident()
    # fmt1 = 'Thread id: {} Delivery tag: {} Message body: {}'
    # LOGGER.info(fmt1.format(thread_id, delivery_tag, body))
def process_message(body: bytes):
print(body)
message_dict = json.loads(body.decode('utf-8'))
print(message_dict)
......@@ -45,40 +15,6 @@ def do_work(connection, channel, delivery_tag, body):
if task == 'hf_facebook_wav2vec2_asr':
run_hf_facebook_wav2vec2_asr_task(dataset, asr_name)
cb = functools.partial(ack_message, channel, delivery_tag)
connection.add_callback_threadsafe(cb)
print('\n#########################\n')
def on_message(channel: BlockingChannel, method_frame, header_frame, body, args):
    """Pika consumer callback: hand each delivery to a dedicated worker thread.

    ``args`` is the ``(connection, threads)`` pair bound via functools.partial;
    every started thread is recorded in ``threads`` so the caller can join them.
    """
    connection, threads = args
    worker = threading.Thread(
        target=do_work,
        args=(connection, channel, method_frame.delivery_tag, body),
    )
    worker.start()
    threads.append(worker)
def new_main():
    """Connect to RabbitMQ and consume 'hf_facebook_wav2vec2_asr' until
    interrupted, processing each message on its own thread.

    prefetch_count=1 keeps at most one unacked message in flight per
    consumer; on Ctrl-C we stop consuming, wait for in-flight worker
    threads, then close the connection.
    """
    connection = pika.BlockingConnection(pika.URLParameters(_RABBIT_URL))
    channel = connection.channel()
    channel.basic_qos(prefetch_count=1)

    threads = []
    channel.basic_consume(
        'hf_facebook_wav2vec2_asr',
        functools.partial(on_message, args=(connection, threads)),
    )
    try:
        channel.start_consuming()
    except KeyboardInterrupt:
        channel.stop_consuming()

    # Let every in-flight message finish before tearing the connection down.
    for worker in threads:
        worker.join()
    connection.close()
if __name__ == '__main__':
    # NOTE(review): this diff view has merged the removed and the added
    # entry points — the old threaded consumer (new_main) and the new
    # queue-helper dispatch (process_queue). In the committed file only one
    # of these two calls should be present; verify against the repository.
    new_main()
    process_queue('asr_benchmark_asr_run', process_message)
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment