From 46acd3754c46b5cbcc8cc18eca96c777db40c641 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcin=20W=C4=85troba?= <markowanga@gmail.com> Date: Sat, 14 Jan 2023 00:48:39 +0100 Subject: [PATCH] Change facebook wav2vec2 model --- new_experiment/add_to_queue_pipeline.py | 37 +++++---- new_experiment/utils/param_util.py | 5 ++ new_experiment/worker_asr.py | 58 +------------- new_experiment/worker_pipeline.py | 102 +++++++++++++++--------- 4 files changed, 96 insertions(+), 106 deletions(-) create mode 100644 new_experiment/utils/param_util.py diff --git a/new_experiment/add_to_queue_pipeline.py b/new_experiment/add_to_queue_pipeline.py index 9b9728b..d6b1fdb 100644 --- a/new_experiment/add_to_queue_pipeline.py +++ b/new_experiment/add_to_queue_pipeline.py @@ -54,26 +54,32 @@ def add_whisper(channel: BlockingChannel): def get_hf_facebook_wav2vec2_model_by_language_code(language_code: str) -> str: return { - # 'nl': 'facebook_wav2vec2_large_xlsr_53_dutch', + 'nl': 'facebook_wav2vec2_large_xlsr_53_dutch', 'en': 'facebook_wav2vec2_large_960h_lv60_self', - # 'fr': 'facebook_wav2vec2_large_xlsr_53_french', - # 'de': 'facebook_wav2vec2_large_xlsr_53_german', - # 'it': 'facebook_wav2vec2_large_xlsr_53_italian', - # 'pl': 'facebook_wav2vec2_large_xlsr_53_polish', - # 'es': 'facebook_wav2vec2_large_xlsr_53_spanish' + 'fr': 'facebook_wav2vec2_large_xlsr_53_french', + 'de': 'facebook_wav2vec2_large_xlsr_53_german', + 'it': 'facebook_wav2vec2_large_xlsr_53_italian', + 'pl': 'facebook_wav2vec2_large_xlsr_53_polish', + 'es': 'facebook_wav2vec2_large_xlsr_53_spanish' }[language_code] def add_facebook_hf_wav2vec2_asr(channel: BlockingChannel): for dataset_name in get_all_datasets_with_language(): - if dataset_name.startswith('en'): - add_to_queue( - dataset_name, - get_hf_facebook_wav2vec2_model_by_language_code(dataset_name[:2]), - 'hf_facebook_wav2vec2_asr', - channel, - 'hf_facebook_wav2vec2_asr' - ) + add_to_queue( + dataset_name, + get_hf_facebook_wav2vec2_model_by_language_code(dataset_name[:2]), + 'hf_facebook_wav2vec2_asr', + channel, + 'hf_facebook_wav2vec2_asr' + ) + + +def add_facebook_hf_wav2vec2_pipeline(channel: BlockingChannel): + for dataset_name in get_all_datasets_with_language(): + asr_name = get_hf_facebook_wav2vec2_model_by_language_code(dataset_name[:2]) + for command in COMMANDS: + add_to_queue(dataset_name, asr_name, command, channel, 'asr_benchmark_experiments') def main(): @@ -82,7 +88,8 @@ def main(): connection = pika.BlockingConnection(parameters=parameters) channel = connection.channel() # add_whisper(channel) - add_facebook_hf_wav2vec2_asr(channel) + # add_facebook_hf_wav2vec2_asr(channel) + add_facebook_hf_wav2vec2_pipeline(channel) connection.close() diff --git a/new_experiment/utils/param_util.py b/new_experiment/utils/param_util.py new file mode 100644 index 0000000..0dd919c --- /dev/null +++ b/new_experiment/utils/param_util.py @@ -0,0 +1,5 @@ +import os + + +def get_param(name: str, default: str) -> str: + return os.environ[name] if name in os.environ else default diff --git a/new_experiment/worker_asr.py b/new_experiment/worker_asr.py index bfa1d06..e15fb4f 100644 --- a/new_experiment/worker_asr.py +++ b/new_experiment/worker_asr.py @@ -1,12 +1,9 @@ +import functools import json import os - -import functools -import logging -import pika import threading -import time +import pika from pika.adapters.blocking_connection import BlockingChannel from new_experiment.pipeline.pipeline_process_asr import run_hf_facebook_wav2vec2_asr_task @@ -15,61 +12,12 @@ from new_experiment.pipeline.pipeline_process_spacy_ner_wer import run_spacy_ner from new_experiment.pipeline.pipeline_process_spacy_pos_wer import run_spacy_pos_wer_pipeline from new_experiment.pipeline.pipeline_process_word_classic_wer import run_word_wer_classic_pipeline from new_experiment.pipeline.pipeline_process_word_embedding_wer import run_word_wer_embedding_pipeline - - -# LOG_FORMAT = ('%(levelname) -10s %(asctime)s %(name) -30s %(funcName) ' -# '-35s %(lineno) -5d: %(message)s') -# LOGGER = logging.getLogger(__name__) - -# logging.basicConfig(level=logging.DEBUG, format=LOG_FORMAT) - - -def get_param(name: str, default: str) -> str: - return os.environ[name] if name in os.environ else default - +from new_experiment.utils.param_util import get_param _RABBIT_URL = get_param('RABBIT_URL', 'amqps://rabbit_user:kz6m4972OUHFmtUcPOHx4kF3Lj6yw7lo@rabbit-asr-benchmarks.theliver.pl:5671/') -def main(): - parameters = pika.URLParameters(_RABBIT_URL) - parameters._heartbeat = 0 - # parameters._heartbeat = 65535 - connection = pika.BlockingConnection(parameters=parameters) - channel = connection.channel() - channel.basic_qos(prefetch_count=1) - - queue_name = f'asr_benchmark_experiments' - for method_frame, properties, body in channel.consume(queue_name): - print(method_frame, properties, body) - message_dict = json.loads(body.decode('utf-8')) - print(message_dict) - - task = message_dict['task'] - dataset = message_dict['dataset'] - asr_name = message_dict['asr_name'] - if task == 'run_word_wer_classic_pipeline': - run_word_wer_classic_pipeline(dataset, asr_name) - elif task == 'run_word_wer_embedding_pipeline': - run_word_wer_embedding_pipeline(dataset, asr_name) - elif task == 'run_spacy_dep_tag_wer_pipeline': - run_spacy_dep_tag_wer_pipeline(dataset, asr_name) - elif task == 'run_spacy_ner_wer_pipeline': - run_spacy_ner_wer_pipeline(dataset, asr_name) - elif task == 'run_spacy_pos_wer_pipeline': - run_spacy_pos_wer_pipeline(dataset, asr_name) - else: - raise Exception(f"Bad message {message_dict}") - - channel.basic_ack(method_frame.delivery_tag) - print('\n########################################################\n') - - requeued_messages = channel.cancel() - print('Requeued %i messages' % requeued_messages) - connection.close() - - def ack_message(channel, delivery_tag): """Note that `channel` must be the same pika channel instance via which the message being ACKed was retrieved (AMQP protocol constraint). diff --git a/new_experiment/worker_pipeline.py b/new_experiment/worker_pipeline.py index ca20369..78c7cef 100644 --- a/new_experiment/worker_pipeline.py +++ b/new_experiment/worker_pipeline.py @@ -5,58 +5,88 @@ import functools import logging import pika import threading -import time + +from pika.adapters.blocking_connection import BlockingChannel from new_experiment.pipeline.pipeline_process_spacy_dep_tag_wer import run_spacy_dep_tag_wer_pipeline from new_experiment.pipeline.pipeline_process_spacy_ner_wer import run_spacy_ner_wer_pipeline from new_experiment.pipeline.pipeline_process_spacy_pos_wer import run_spacy_pos_wer_pipeline from new_experiment.pipeline.pipeline_process_word_classic_wer import run_word_wer_classic_pipeline from new_experiment.pipeline.pipeline_process_word_embedding_wer import run_word_wer_embedding_pipeline +from new_experiment.utils.param_util import get_param +_RABBIT_URL = get_param('RABBIT_URL', + 'amqps://rabbit_user:kz6m4972OUHFmtUcPOHx4kF3Lj6yw7lo@rabbit-asr-benchmarks.theliver.pl:5671/') -def get_param(name: str, default: str) -> str: - return os.environ[name] if name in os.environ else default +def process_message(body: bytes): + print(body) + message_dict = json.loads(body.decode('utf-8')) + print(message_dict) -_RABBIT_URL = get_param('RABBIT_URL', - 'amqps://rabbit_user:kz6m4972OUHFmtUcPOHx4kF3Lj6yw7lo@rabbit-asr-benchmarks.theliver.pl:5671/') + task = message_dict['task'] + dataset = message_dict['dataset'] + asr_name = message_dict['asr_name'] + if task == 'run_word_wer_classic_pipeline': + run_word_wer_classic_pipeline(dataset, asr_name) + elif task == 'run_word_wer_embedding_pipeline': + run_word_wer_embedding_pipeline(dataset, asr_name) + elif task == 'run_spacy_dep_tag_wer_pipeline': + run_spacy_dep_tag_wer_pipeline(dataset, asr_name) + elif task == 'run_spacy_ner_wer_pipeline': + run_spacy_ner_wer_pipeline(dataset, asr_name) + elif task == 'run_spacy_pos_wer_pipeline': + run_spacy_pos_wer_pipeline(dataset, asr_name) + else: + raise Exception(f"Bad message {message_dict}") + + +def ack_message(channel, delivery_tag): + """Note that `channel` must be the same pika channel instance via which + the message being ACKed was retrieved (AMQP protocol constraint). + """ + if channel.is_open: + channel.basic_ack(delivery_tag) + else: + # Channel is already closed, so we can't ACK this message; + # log and/or do something that makes sense for your app in this case. + pass -def main(): +def do_work(connection, channel, delivery_tag, body): + process_message(body) + cb = functools.partial(ack_message, channel, delivery_tag) + connection.add_callback_threadsafe(cb) + print('\n#########################\n') + + +def on_message(channel: BlockingChannel, method_frame, header_frame, body, args): + (connection, threads) = args + delivery_tag = method_frame.delivery_tag + t = threading.Thread(target=do_work, args=(connection, channel, delivery_tag, body)) + t.start() + threads.append(t) + + +def new_main(): parameters = pika.URLParameters(_RABBIT_URL) - parameters._heartbeat = 0 - # parameters._heartbeat = 65535 - connection = pika.BlockingConnection(parameters=parameters) + connection = pika.BlockingConnection(parameters) channel = connection.channel() channel.basic_qos(prefetch_count=1) - queue_name = f'asr_benchmark_experiments' - for method_frame, properties, body in channel.consume(queue_name): - print(method_frame, properties, body) - message_dict = json.loads(body.decode('utf-8')) - print(message_dict) - - task = message_dict['task'] - dataset = message_dict['dataset'] - asr_name = message_dict['asr_name'] - if task == 'run_word_wer_classic_pipeline': - run_word_wer_classic_pipeline(dataset, asr_name) - elif task == 'run_word_wer_embedding_pipeline': - run_word_wer_embedding_pipeline(dataset, asr_name) - elif task == 'run_spacy_dep_tag_wer_pipeline': - run_spacy_dep_tag_wer_pipeline(dataset, asr_name) - elif task == 'run_spacy_ner_wer_pipeline': - run_spacy_ner_wer_pipeline(dataset, asr_name) - elif task == 'run_spacy_pos_wer_pipeline': - run_spacy_pos_wer_pipeline(dataset, asr_name) - else: - raise Exception(f"Bad message {message_dict}") - - channel.basic_ack(method_frame.delivery_tag) - print('\n########################################################\n') - - requeued_messages = channel.cancel() - print('Requeued %i messages' % requeued_messages) + threads = [] + on_message_callback = functools.partial(on_message, args=(connection, threads)) + channel.basic_consume('hf_facebook_wav2vec2_asr', on_message_callback) + + try: + channel.start_consuming() + except KeyboardInterrupt: + channel.stop_consuming() + + # Wait for all to complete + for thread in threads: + thread.join() + connection.close() -- GitLab