Skip to content
Snippets Groups Projects
worker_asr.py 2.88 KiB
Newer Older
import functools
Marcin Wątroba's avatar
Marcin Wątroba committed
import json
import os
import threading

import pika
Marcin Wątroba's avatar
Marcin Wątroba committed
from pika.adapters.blocking_connection import BlockingChannel

from new_experiment.pipeline.pipeline_process_asr import run_hf_facebook_wav2vec2_asr_task
from new_experiment.pipeline.pipeline_process_spacy_dep_tag_wer import run_spacy_dep_tag_wer_pipeline
from new_experiment.pipeline.pipeline_process_spacy_ner_wer import run_spacy_ner_wer_pipeline
from new_experiment.pipeline.pipeline_process_spacy_pos_wer import run_spacy_pos_wer_pipeline
from new_experiment.pipeline.pipeline_process_word_classic_wer import run_word_wer_classic_pipeline
from new_experiment.pipeline.pipeline_process_word_embedding_wer import run_word_wer_embedding_pipeline
from new_experiment.utils.param_util import get_param
Marcin Wątroba's avatar
Marcin Wątroba committed

# Broker connection URL, looked up through get_param under the name
# 'RABBIT_URL' (presumably an environment/config lookup — confirm in
# new_experiment.utils.param_util), with the literal below as the fallback.
# NOTE(review): the fallback embeds a real-looking username and password in
# source control. Treat that credential as leaked: rotate it on the broker and
# provide RABBIT_URL from the environment / a secret store rather than relying
# on this default.
_RABBIT_URL = get_param('RABBIT_URL',
                        'amqps://rabbit_user:kz6m4972OUHFmtUcPOHx4kF3Lj6yw7lo@rabbit-asr-benchmarks.theliver.pl:5671/')


def ack_message(channel, delivery_tag):
    """Acknowledge a delivery on the channel it was received from.

    ``channel`` must be the same pika channel instance via which the message
    being ACKed was retrieved (AMQP protocol constraint). The worker threads
    in this module do not call this directly; they schedule it onto the
    connection's thread with ``connection.add_callback_threadsafe``.

    Args:
        channel: Channel the message was delivered on.
        delivery_tag: Broker-assigned tag identifying the delivery.

    Returns:
        True if the ACK was sent, False if the channel had already closed.
        Existing callers ignore the return value.
    """
    if not channel.is_open:
        # Once the channel is gone the broker requeues every delivery that was
        # never ACKed on it, so there is nothing to retry here. Make the lost
        # ACK visible instead of silently dropping it.
        print(f'Cannot ACK delivery tag {delivery_tag}: channel is closed')
        return False
    channel.basic_ack(delivery_tag)
    return True


def _nack_message(channel, delivery_tag):
    """Reject a delivery whose processing failed, without requeueing it.

    Counterpart of ``ack_message``: it must run on the connection's thread and
    use the channel the message arrived on. ``requeue=False`` keeps a message
    that always fails from being redelivered in a tight loop. Whether it is
    dead-lettered or dropped depends on the queue's broker-side configuration.
    """
    if channel.is_open:
        channel.basic_nack(delivery_tag, requeue=False)
    else:
        print(f'Cannot NACK delivery tag {delivery_tag}: channel is closed')


def do_work(connection, channel, delivery_tag, body):
    """Process one task message in a worker thread, then settle it.

    ``body`` must be UTF-8 encoded JSON with ``task``, ``dataset`` and
    ``asr_name`` keys. Only the ``'hf_facebook_wav2vec2_asr'`` task is run;
    any other task name is reported and acknowledged without processing.

    The ACK/NACK is not sent from this thread. It is handed to the
    connection's thread via ``connection.add_callback_threadsafe``, because
    pika connections/channels must only be used from their own thread.

    If decoding the message or running the task raises, the error is printed
    and the delivery is NACKed (no requeue) rather than left unsettled. An
    unsettled delivery would hold the consumer's only prefetch slot forever.
    """
    import traceback  # local import keeps this block self-contained

    try:
        print(body)
        message_dict = json.loads(body.decode('utf-8'))
        print(message_dict)
        task = message_dict['task']
        dataset = message_dict['dataset']
        asr_name = message_dict['asr_name']

        if task == 'hf_facebook_wav2vec2_asr':
            run_hf_facebook_wav2vec2_asr_task(dataset, asr_name)
        else:
            print(f'Unknown task {task!r}; acknowledging without processing')
    except Exception:
        traceback.print_exc()
        settle = _nack_message
    else:
        settle = ack_message

    connection.add_callback_threadsafe(functools.partial(settle, channel, delivery_tag))
    print('\n#########################\n')


def on_message(channel: BlockingChannel, method_frame, header_frame, body, args):
    """Consumer callback: hand each delivery to its own worker thread.

    This runs on the connection's thread. It returns immediately and leaves
    the long-running task to ``do_work``, so the consumer loop is not blocked
    while a task executes.

    Args:
        channel: Channel the message was delivered on; the worker must settle
            the message on this same channel.
        method_frame: Delivery metadata; only ``delivery_tag`` is used.
        header_frame: Message properties (unused).
        body: Raw message payload, forwarded to ``do_work`` unchanged.
        args: ``(connection, threads)`` bound via ``functools.partial``.
            ``threads`` is the caller's list of worker threads and is
            updated in place.
    """
    connection, threads = args

    # Forget workers that have already finished. Otherwise the list would keep
    # one Thread object per message for the whole lifetime of the consumer.
    # Slice assignment keeps the caller's list object, which it later waits on.
    threads[:] = [worker for worker in threads if worker.is_alive()]

    worker = threading.Thread(
        target=do_work,
        args=(connection, channel, method_frame.delivery_tag, body),
    )
    worker.start()
    threads.append(worker)


def new_main(queue_name='hf_facebook_wav2vec2_asr', prefetch_count=1):
    """Consume ``queue_name`` until interrupted, running each task in a thread.

    Args:
        queue_name: Queue to consume from. The default is the queue this
            worker has always consumed.
        prefetch_count: Maximum number of unacknowledged deliveries the broker
            sends to this consumer at once. With the default of 1, the next
            message arrives only after the previous one has been settled.
    """
    connection = pika.BlockingConnection(pika.URLParameters(_RABBIT_URL))
    channel = connection.channel()
    channel.basic_qos(prefetch_count=prefetch_count)

    threads = []
    channel.basic_consume(
        queue_name,
        functools.partial(on_message, args=(connection, threads)),
    )

    try:
        channel.start_consuming()
    except KeyboardInterrupt:
        channel.stop_consuming()

    # Wait for in-flight workers. Their ACKs are scheduled with
    # add_callback_threadsafe and only execute while this thread services the
    # connection. Keep processing events instead of blocking in join(), which
    # would lose those ACKs and stop heartbeats for as long as a task runs.
    while any(worker.is_alive() for worker in threads):
        connection.process_data_events(time_limit=1)
    for worker in threads:
        worker.join()
    # Dispatch ACKs queued by the last workers before the connection closes.
    connection.process_data_events(time_limit=0)

    connection.close()


if __name__ == '__main__':
    new_main()