import functools
import json
import logging
import os
import threading

import pika
from pika.adapters.blocking_connection import BlockingChannel

from new_experiment.pipeline.pipeline_process_asr import run_hf_facebook_wav2vec2_asr_task
from new_experiment.pipeline.pipeline_process_spacy_dep_tag_wer import run_spacy_dep_tag_wer_pipeline
from new_experiment.pipeline.pipeline_process_spacy_ner_wer import run_spacy_ner_wer_pipeline
from new_experiment.pipeline.pipeline_process_spacy_pos_wer import run_spacy_pos_wer_pipeline
from new_experiment.pipeline.pipeline_process_word_classic_wer import run_word_wer_classic_pipeline
from new_experiment.pipeline.pipeline_process_word_embedding_wer import run_word_wer_embedding_pipeline
from new_experiment.utils.param_util import get_param

# NOTE(review): this default URL embeds a live-looking username/password in
# source control. It should be rotated and the default replaced by a
# credential-free value (or no default), with the real URL supplied via the
# RABBIT_URL parameter. Left unchanged here to avoid breaking deployments
# that rely on the default.
_RABBIT_URL = get_param('RABBIT_URL', 'amqps://rabbit_user:kz6m4972OUHFmtUcPOHx4kF3Lj6yw7lo@rabbit-asr-benchmarks.theliver.pl:5671/')

logger = logging.getLogger(__name__)


def ack_message(channel, delivery_tag):
    """Acknowledge a message on the channel it was received on.

    Note that `channel` must be the same pika channel instance via which the
    message being ACKed was retrieved (AMQP protocol constraint).

    Intended to run on the connection's own thread, scheduled via
    ``connection.add_callback_threadsafe`` from a worker thread.
    """
    if channel.is_open:
        channel.basic_ack(delivery_tag)
    else:
        # The channel is already closed, so the ACK cannot be sent. The
        # broker will redeliver the message once it notices the closed
        # channel; just record that it happened.
        logger.warning('Channel closed; could not ACK delivery tag %s', delivery_tag)


def _nack_message(channel, delivery_tag):
    """Reject a message without requeueing it, on the channel it came from.

    ``requeue=False`` is deliberate. A message whose processing raised (bad
    JSON, missing keys, a failing task) would most likely fail again, and
    requeueing it would loop forever. If the queue has a dead-letter exchange
    configured, the broker routes the message there.
    """
    if channel.is_open:
        channel.basic_nack(delivery_tag, requeue=False)
    else:
        logger.warning('Channel closed; could not NACK delivery tag %s', delivery_tag)


def _run_task(body):
    """Decode a UTF-8 JSON message body and run the task it names.

    The message is expected to carry ``task``, ``dataset`` and ``asr_name``
    keys. A missing key raises ``KeyError`` and malformed JSON raises
    ``json.JSONDecodeError``; the caller handles both.
    """
    message_dict = json.loads(body.decode('utf-8'))
    logger.info('Received message: %s', message_dict)
    task = message_dict['task']
    dataset = message_dict['dataset']
    asr_name = message_dict['asr_name']
    if task == 'hf_facebook_wav2vec2_asr':
        run_hf_facebook_wav2vec2_asr_task(dataset, asr_name)
    else:
        # Unknown tasks were previously acknowledged silently. They are
        # still acknowledged, but now leave a trace in the log.
        logger.warning('Unknown task %r; acknowledging without running it', task)


def do_work(connection, channel, delivery_tag, body):
    """Process one message on a worker thread, then ACK or NACK it.

    pika connections are not thread-safe, so the ACK/NACK is not sent from
    this thread. It is scheduled onto the connection's thread with
    ``add_callback_threadsafe``.

    Any exception from processing is logged and the message is NACKed. Without
    that, the message would stay unacknowledged and, with ``prefetch_count=1``,
    the consumer would stop receiving work.
    """
    try:
        _run_task(body)
    except Exception:
        logger.exception('Processing failed for delivery tag %s', delivery_tag)
        cb = functools.partial(_nack_message, channel, delivery_tag)
    else:
        cb = functools.partial(ack_message, channel, delivery_tag)
    connection.add_callback_threadsafe(cb)
    logger.info('Finished delivery tag %s', delivery_tag)


def on_message(channel: BlockingChannel, method_frame, header_frame, body, args):
    """Consumer callback: hand the message off to a new worker thread.

    ``args`` is a ``(connection, threads)`` tuple bound via ``functools.partial``.
    ``threads`` is the list ``new_main`` joins on shutdown.
    """
    (connection, threads) = args
    # Drop finished threads so the list does not grow without bound over a
    # long-running consumer. This runs on the connection's thread, as does
    # the final join in new_main, so the in-place update needs no lock.
    threads[:] = [t for t in threads if t.is_alive()]
    delivery_tag = method_frame.delivery_tag
    t = threading.Thread(target=do_work, args=(connection, channel, delivery_tag, body))
    t.start()
    threads.append(t)


def new_main():
    """Consume the ``hf_facebook_wav2vec2_asr`` queue until interrupted.

    Each message is processed on its own thread. On shutdown, for any reason,
    all worker threads are joined and the connection is closed.
    """
    parameters = pika.URLParameters(_RABBIT_URL)
    connection = pika.BlockingConnection(parameters)
    channel = connection.channel()
    # Only one unacknowledged message at a time per consumer.
    channel.basic_qos(prefetch_count=1)

    threads = []
    on_message_callback = functools.partial(on_message, args=(connection, threads))
    channel.basic_consume('hf_facebook_wav2vec2_asr', on_message_callback)

    try:
        channel.start_consuming()
    except KeyboardInterrupt:
        channel.stop_consuming()
    finally:
        # Wait for all in-flight work to complete, even when start_consuming
        # exited with an error other than KeyboardInterrupt.
        for thread in threads:
            thread.join()
        # The connection may already be closed (e.g. the broker dropped it),
        # in which case close() would raise.
        if connection.is_open:
            connection.close()


if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    new_main()