From 95002a468e758efe1c20451ce1b767fda15bf3da Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Marcin=20W=C4=85troba?= <markowanga@gmail.com>
Date: Sat, 14 Jan 2023 00:54:51 +0100
Subject: [PATCH] Change facebook wav2vec2 model

---
Review changes folded into this revision (new_experiment/queue_base.py only):
  - log a warning instead of silently passing when the channel is already
    closed at ACK time
  - log and NACK a message whose handler raised, so the delivery does not
    stay unacknowledged (new optional argument `requeue_on_error`)
  - drop finished handler threads from the bookkeeping list on each delivery
  - flag the credentials embedded in the RABBIT_URL fallback (value unchanged)
Hunk sizes and the diffstat are updated; blob ids are carried over unchanged.

 new_experiment/queue_base.py      | 103 ++++++++++++++++++++++++++++++
 new_experiment/worker_pipeline.py |  64 +------------------
 2 files changed, 105 insertions(+), 62 deletions(-)
 create mode 100644 new_experiment/queue_base.py

diff --git a/new_experiment/queue_base.py b/new_experiment/queue_base.py
new file mode 100644
index 0000000..6307f56
--- /dev/null
+++ b/new_experiment/queue_base.py
@@ -0,0 +1,103 @@
+import functools
+import logging
+import threading
+from typing import Callable
+
+import pika
+from pika.adapters.blocking_connection import BlockingChannel
+
+from new_experiment.utils.param_util import get_param
+
+_LOGGER = logging.getLogger(__name__)
+
+# SECURITY(review): the fallback passed to get_param below embeds broker
+# credentials in the repository. It is kept so the module behaves as before;
+# rotate the password and drop the default once RABBIT_URL is always supplied.
+_RABBIT_URL = get_param('RABBIT_URL',
+                        'amqps://rabbit_user:kz6m4972OUHFmtUcPOHx4kF3Lj6yw7lo@rabbit-asr-benchmarks.theliver.pl:5671/')
+
+
+def ack_message(channel, delivery_tag):
+    """Acknowledge `delivery_tag` on `channel`, or log if the channel is closed.
+
+    Note that `channel` must be the same pika channel instance via which
+    the message being ACKed was retrieved (AMQP protocol constraint).
+    """
+    if channel.is_open:
+        channel.basic_ack(delivery_tag)
+    else:
+        # Channel is already closed, so we can't ACK this message; record it
+        # instead of dropping the event silently.
+        _LOGGER.warning('Channel closed, delivery %s was not ACKed', delivery_tag)
+
+
+def nack_message(channel, delivery_tag, requeue):
+    """Negatively acknowledge `delivery_tag` on `channel`.
+
+    The same-channel constraint of `ack_message` applies. `requeue` is passed
+    straight to `basic_nack`.
+    """
+    if channel.is_open:
+        channel.basic_nack(delivery_tag, requeue=requeue)
+    else:
+        _LOGGER.warning('Channel closed, delivery %s was not NACKed', delivery_tag)
+
+
+def process_queue(queue_name: str, process_message: Callable[[bytes], None],
+                  requeue_on_error: bool = False):
+    """Consume `queue_name` and run `process_message` on each message body.
+
+    Every message is handled in its own thread; the ACK (or NACK, when the
+    handler raised) is then handed back to the connection through
+    `add_callback_threadsafe`. The call blocks while consuming; on
+    KeyboardInterrupt it stops consuming, joins the handler threads that are
+    still tracked and closes the connection.
+
+    :param queue_name: name of the queue to consume from
+    :param process_message: handler called with the raw message body
+    :param requeue_on_error: `requeue` flag used for the NACK of a message
+        whose handler raised; the default (False) avoids redelivering a
+        message that fails every time
+    """
+
+    def do_work(connection, channel, delivery_tag, body):
+        try:
+            process_message(body)
+        except Exception:
+            # Without a reply the delivery stays unacknowledged; with
+            # prefetch_count=1 that is expected to stop further deliveries to
+            # this consumer, so report the failure and NACK instead.
+            _LOGGER.exception('Processing of delivery %s failed', delivery_tag)
+            cb = functools.partial(nack_message, channel, delivery_tag, requeue_on_error)
+        else:
+            cb = functools.partial(ack_message, channel, delivery_tag)
+        connection.add_callback_threadsafe(cb)
+        print('\n#########################\n')
+
+    def on_message(channel: BlockingChannel, method_frame, header_frame, body, args):
+        (connection, threads) = args
+        delivery_tag = method_frame.delivery_tag
+        # Forget handlers that already finished so the list does not grow for
+        # as long as the consumer runs.
+        threads[:] = [worker for worker in threads if worker.is_alive()]
+        t = threading.Thread(target=do_work, args=(connection, channel, delivery_tag, body))
+        t.start()
+        threads.append(t)
+
+    parameters = pika.URLParameters(_RABBIT_URL)
+    connection = pika.BlockingConnection(parameters)
+    channel = connection.channel()
+    channel.basic_qos(prefetch_count=1)
+
+    threads = []
+    on_message_callback = functools.partial(on_message, args=(connection, threads))
+    channel.basic_consume(queue_name, on_message_callback)
+
+    try:
+        channel.start_consuming()
+    except KeyboardInterrupt:
+        channel.stop_consuming()
+
+    # Wait for all to complete
+    for thread in threads:
+        thread.join()
+
+    connection.close()
diff --git a/new_experiment/worker_pipeline.py b/new_experiment/worker_pipeline.py
index 9acbe7c..27dbc61 100644
--- a/new_experiment/worker_pipeline.py
+++ b/new_experiment/worker_pipeline.py
@@ -1,22 +1,11 @@
 import json
-import os
-
-import functools
-import logging
-import pika
-import threading
-
-from pika.adapters.blocking_connection import BlockingChannel
 
 from new_experiment.pipeline.pipeline_process_spacy_dep_tag_wer import run_spacy_dep_tag_wer_pipeline
 from new_experiment.pipeline.pipeline_process_spacy_ner_wer import run_spacy_ner_wer_pipeline
 from new_experiment.pipeline.pipeline_process_spacy_pos_wer import run_spacy_pos_wer_pipeline
 from new_experiment.pipeline.pipeline_process_word_classic_wer import run_word_wer_classic_pipeline
 from new_experiment.pipeline.pipeline_process_word_embedding_wer import run_word_wer_embedding_pipeline
-from new_experiment.utils.param_util import get_param
-
-_RABBIT_URL = get_param('RABBIT_URL',
-                        'amqps://rabbit_user:kz6m4972OUHFmtUcPOHx4kF3Lj6yw7lo@rabbit-asr-benchmarks.theliver.pl:5671/')
+from new_experiment.queue_base import process_queue
 
 
 def process_message(body: bytes):
@@ -41,54 +30,5 @@ def process_message(body: bytes):
         raise Exception(f"Bad message {message_dict}")
 
 
-def ack_message(channel, delivery_tag):
-    """Note that `channel` must be the same pika channel instance via which
-    the message being ACKed was retrieved (AMQP protocol constraint).
-    """
-    if channel.is_open:
-        channel.basic_ack(delivery_tag)
-    else:
-        # Channel is already closed, so we can't ACK this message;
-        # log and/or do something that makes sense for your app in this case.
-        pass
-
-
-def do_work(connection, channel, delivery_tag, body):
-    process_message(body)
-    cb = functools.partial(ack_message, channel, delivery_tag)
-    connection.add_callback_threadsafe(cb)
-    print('\n#########################\n')
-
-
-def on_message(channel: BlockingChannel, method_frame, header_frame, body, args):
-    (connection, threads) = args
-    delivery_tag = method_frame.delivery_tag
-    t = threading.Thread(target=do_work, args=(connection, channel, delivery_tag, body))
-    t.start()
-    threads.append(t)
-
-
-def new_main():
-    parameters = pika.URLParameters(_RABBIT_URL)
-    connection = pika.BlockingConnection(parameters)
-    channel = connection.channel()
-    channel.basic_qos(prefetch_count=1)
-
-    threads = []
-    on_message_callback = functools.partial(on_message, args=(connection, threads))
-    channel.basic_consume('hf_facebook_wav2vec2_asr', on_message_callback)
-
-    try:
-        channel.start_consuming()
-    except KeyboardInterrupt:
-        channel.stop_consuming()
-
-    # Wait for all to complete
-    for thread in threads:
-        thread.join()
-
-    connection.close()
-
-
 if __name__ == '__main__':
-    new_main()
+    process_queue('asr_benchmark_experiments', process_message)
-- 
GitLab