From 9e93749560e17b03e28033dedad1b20d35050f31 Mon Sep 17 00:00:00 2001
From: Marcin Wątroba <markowanga@gmail.com>
Date: Fri, 13 Jan 2023 23:37:12 +0100
Subject: [PATCH] Add new worker

---
 new_experiment/add_to_queue_pipeline.py      |  29 +++-
 new_experiment/hf_asr/wav2vec2_hf.py         |   4 +-
 new_experiment/new_worker.py                 | 136 ++++++++++++++++++
 .../pipeline/pipeline_process_asr.py         |  27 ++--
 new_experiment/worker.py                     |  66 ---------
 5 files changed, 182 insertions(+), 80 deletions(-)
 create mode 100644 new_experiment/new_worker.py

diff --git a/new_experiment/add_to_queue_pipeline.py b/new_experiment/add_to_queue_pipeline.py
index 2053526..b929cf0 100644
--- a/new_experiment/add_to_queue_pipeline.py
+++ b/new_experiment/add_to_queue_pipeline.py
@@ -26,9 +26,8 @@ def get_minio_client() -> Minio:
     return Minio('minio-asr-benchmarks.theliver.pl', 'minio_user', 'eUxzEQbyYPdzrLxuvvethSbk18kB2s7G')
 
 
-def add_to_queue(dataset: str, asr_name: str, task: str, channel: BlockingChannel):
+def add_to_queue(dataset: str, asr_name: str, task: str, channel: BlockingChannel, queue_name: str = 'asr_benchmark_experiments'):
     message_dict = {'dataset': dataset, 'asr_name': asr_name, 'task': task}
-    queue_name = 'asr_benchmark_experiments'
     print(datetime.datetime.now().isoformat(), f'{queue_name} {message_dict}')
     message_bytes = json.dumps(message_dict).encode('utf-8')
     channel.queue_declare(queue=queue_name, durable=True)
@@ -53,12 +52,36 @@ def add_whisper(channel: BlockingChannel):
                 add_to_queue(dataset_name, asr_name, command, channel)
 
 
+def get_hf_facebook_wav2vec2_model_by_language_code(language_code: str) -> str:
+    return {
+        'nl': 'facebook_wav2vec2_large_xlsr_53_dutch',
+        'en': 'facebook_wav2vec2_xls_r_300m',
+        'fr': 'facebook_wav2vec2_large_xlsr_53_french',
+        'de': 'facebook_wav2vec2_large_xlsr_53_german',
+        'it': 'facebook_wav2vec2_large_xlsr_53_italian',
+        'pl': 'facebook_wav2vec2_large_xlsr_53_polish',
+        'es': 'facebook_wav2vec2_large_xlsr_53_spanish'
+    }[language_code]
+
+
+def add_facebook_hf_wav2vec2_asr(channel: BlockingChannel):
+    for dataset_name in get_all_datasets_with_language():
+        add_to_queue(
+            dataset_name,
+            get_hf_facebook_wav2vec2_model_by_language_code(dataset_name[:2]),
+            'hf_facebook_wav2vec2_asr',
+            channel,
+            'hf_facebook_wav2vec2_asr'
+        )
+
+
 def main():
     parameters = pika.URLParameters(
         'amqps://rabbit_user:kz6m4972OUHFmtUcPOHx4kF3Lj6yw7lo@rabbit-asr-benchmarks.theliver.pl:5671/')
     connection = pika.BlockingConnection(parameters=parameters)
     channel = connection.channel()
-    add_whisper(channel)
+    # add_whisper(channel)
+    add_facebook_hf_wav2vec2_asr(channel)
     connection.close()
 
 
diff --git a/new_experiment/hf_asr/wav2vec2_hf.py b/new_experiment/hf_asr/wav2vec2_hf.py
index c9e06a0..ff17959 100644
--- a/new_experiment/hf_asr/wav2vec2_hf.py
+++ b/new_experiment/hf_asr/wav2vec2_hf.py
@@ -21,9 +21,7 @@ class Wav2Vec2AsrProcessor(AsrProcessor):
         self._wav2vec2_processor = Wav2Vec2Processor.from_pretrained(model_name)
 
     def call_recognise(self, file_path: str) -> Dict[str, Any]:
-        # samplerate, data = wavfile.read(file_path)
-        # data, samplerate = soundfile.read(file_path)
-        data, samplerate = librosa.load(file_path)
+        data, samplerate = librosa.load(file_path, sr=16000)
         features = self._wav2vec2_processor(data, sampling_rate=samplerate, padding=True, return_tensors="pt")
         input_values = features.input_values.to(self._device)
         attention_mask = features.attention_mask.to(self._device)
diff --git a/new_experiment/new_worker.py b/new_experiment/new_worker.py
new file mode 100644
index 0000000..bfa1d06
--- /dev/null
+++ b/new_experiment/new_worker.py
@@ -0,0 +1,136 @@
+import json
+import os
+
+import functools
+import logging
+import pika
+import threading
+import time
+
+from pika.adapters.blocking_connection import BlockingChannel
+
+from new_experiment.pipeline.pipeline_process_asr import run_hf_facebook_wav2vec2_asr_task
+from new_experiment.pipeline.pipeline_process_spacy_dep_tag_wer import run_spacy_dep_tag_wer_pipeline
+from new_experiment.pipeline.pipeline_process_spacy_ner_wer import run_spacy_ner_wer_pipeline
+from new_experiment.pipeline.pipeline_process_spacy_pos_wer import run_spacy_pos_wer_pipeline
+from new_experiment.pipeline.pipeline_process_word_classic_wer import run_word_wer_classic_pipeline
+from new_experiment.pipeline.pipeline_process_word_embedding_wer import run_word_wer_embedding_pipeline
+
+
+# LOG_FORMAT = ('%(levelname) -10s %(asctime)s %(name) -30s %(funcName) '
+#               '-35s %(lineno) -5d: %(message)s')
+# LOGGER = logging.getLogger(__name__)
+
+# logging.basicConfig(level=logging.DEBUG, format=LOG_FORMAT)
+
+
+def get_param(name: str, default: str) -> str:
+    return os.environ[name] if name in os.environ else default
+
+
+_RABBIT_URL = get_param('RABBIT_URL',
+                        'amqps://rabbit_user:kz6m4972OUHFmtUcPOHx4kF3Lj6yw7lo@rabbit-asr-benchmarks.theliver.pl:5671/')
+
+
+def main():
+    parameters = pika.URLParameters(_RABBIT_URL)
+    parameters._heartbeat = 0
+    # parameters._heartbeat = 65535
+    connection = pika.BlockingConnection(parameters=parameters)
+    channel = connection.channel()
+    channel.basic_qos(prefetch_count=1)
+
+    queue_name = 'asr_benchmark_experiments'
+    for method_frame, properties, body in channel.consume(queue_name):
+        print(method_frame, properties, body)
+        message_dict = json.loads(body.decode('utf-8'))
+        print(message_dict)
+
+        task = message_dict['task']
+        dataset = message_dict['dataset']
+        asr_name = message_dict['asr_name']
+        if task == 'run_word_wer_classic_pipeline':
+            run_word_wer_classic_pipeline(dataset, asr_name)
+        elif task == 'run_word_wer_embedding_pipeline':
+            run_word_wer_embedding_pipeline(dataset, asr_name)
+        elif task == 'run_spacy_dep_tag_wer_pipeline':
+            run_spacy_dep_tag_wer_pipeline(dataset, asr_name)
+        elif task == 'run_spacy_ner_wer_pipeline':
+            run_spacy_ner_wer_pipeline(dataset, asr_name)
+        elif task == 'run_spacy_pos_wer_pipeline':
+            run_spacy_pos_wer_pipeline(dataset, asr_name)
+        else:
+            raise Exception(f"Bad message {message_dict}")
+
+        channel.basic_ack(method_frame.delivery_tag)
+        print('\n########################################################\n')
+
+    requeued_messages = channel.cancel()
+    print('Requeued %i messages' % requeued_messages)
+    connection.close()
+
+
+def ack_message(channel, delivery_tag):
+    """Note that `channel` must be the same pika channel instance via which
+    the message being ACKed was retrieved (AMQP protocol constraint).
+    """
+    if channel.is_open:
+        channel.basic_ack(delivery_tag)
+    else:
+        # Channel is already closed, so we can't ACK this message;
+        # log and/or do something that makes sense for your app in this case.
+        pass
+
+
+def do_work(connection, channel, delivery_tag, body):
+    thread_id = threading.get_ident()
+    # fmt1 = 'Thread id: {} Delivery tag: {} Message body: {}'
+    # LOGGER.info(fmt1.format(thread_id, delivery_tag, body))
+
+    print(body)
+    message_dict = json.loads(body.decode('utf-8'))
+    print(message_dict)
+    task = message_dict['task']
+    dataset = message_dict['dataset']
+    asr_name = message_dict['asr_name']
+
+    if task == 'hf_facebook_wav2vec2_asr':
+        run_hf_facebook_wav2vec2_asr_task(dataset, asr_name)
+
+    cb = functools.partial(ack_message, channel, delivery_tag)
+    connection.add_callback_threadsafe(cb)
+    print('\n#########################\n')
+
+
+def on_message(channel: BlockingChannel, method_frame, header_frame, body, args):
+    (connection, threads) = args
+    delivery_tag = method_frame.delivery_tag
+    t = threading.Thread(target=do_work, args=(connection, channel, delivery_tag, body))
+    t.start()
+    threads.append(t)
+
+
+def new_main():
+    parameters = pika.URLParameters(_RABBIT_URL)
+    connection = pika.BlockingConnection(parameters)
+    channel = connection.channel()
+    channel.basic_qos(prefetch_count=1)
+
+    threads = []
+    on_message_callback = functools.partial(on_message, args=(connection, threads))
+    channel.basic_consume('hf_facebook_wav2vec2_asr', on_message_callback)
+
+    try:
+        channel.start_consuming()
+    except KeyboardInterrupt:
+        channel.stop_consuming()
+
+    # Wait for all to complete
+    for thread in threads:
+        thread.join()
+
+    connection.close()
+
+
+if __name__ == '__main__':
+    new_main()
diff --git a/new_experiment/pipeline/pipeline_process_asr.py b/new_experiment/pipeline/pipeline_process_asr.py
index 79fe1cb..c883a3d 100644
--- a/new_experiment/pipeline/pipeline_process_asr.py
+++ b/new_experiment/pipeline/pipeline_process_asr.py
@@ -10,10 +10,22 @@ from sziszapangma.integration.task.asr_task import AsrTask
 def get_asr_processor(asr_name: str) -> AsrProcessor:
     if asr_name == 'facebook_wav2vec2_large_xlsr_53_dutch':
         return Wav2Vec2AsrProcessor('facebook/wav2vec2-large-xlsr-53-dutch')
+    if asr_name == 'facebook_wav2vec2_xls_r_300m':
+        return Wav2Vec2AsrProcessor('facebook/wav2vec2-xls-r-300m')
+    if asr_name == 'facebook_wav2vec2_large_xlsr_53_french':
+        return Wav2Vec2AsrProcessor('facebook/wav2vec2-large-xlsr-53-french')
+    if asr_name == 'facebook_wav2vec2_large_xlsr_53_german':
+        return Wav2Vec2AsrProcessor('facebook/wav2vec2-large-xlsr-53-german')
+    if asr_name == 'facebook_wav2vec2_large_xlsr_53_italian':
+        return Wav2Vec2AsrProcessor('facebook/wav2vec2-large-xlsr-53-italian')
+    if asr_name == 'facebook_wav2vec2_large_xlsr_53_polish':
+        return Wav2Vec2AsrProcessor('facebook/wav2vec2-large-xlsr-53-polish')
+    if asr_name == 'facebook_wav2vec2_large_xlsr_53_spanish':
+        return Wav2Vec2AsrProcessor('facebook/wav2vec2-large-xlsr-53-spanish')
     raise Exception(f'AsrProcessor not found for name: {asr_name}')
 
 
-def run_spacy_dep_tag_wer_pipeline(dataset_name: str, asr_name: str):
+def run_hf_facebook_wav2vec2_asr_task(dataset_name: str, asr_name: str):
     repository = get_experiment_repository(dataset_name)
     record_provider = LoadedRemoteDatasetHelper(repository, get_minio_audio_record_repository(), dataset_name)
     experiment_processor = ExperimentManager(
@@ -21,8 +33,8 @@ def run_spacy_dep_tag_wer_pipeline(dataset_name: str, asr_name: str):
         processing_tasks=[
             AsrTask(
                 asr_property_name=PropertyHelper.asr_result(asr_name),
-                task_name=f'SpacyDepTagSentenceWerProcessor___{dataset_name}___{asr_name}',
-                require_update=False,
+                task_name=f'AsrTask___{dataset_name}___{asr_name}',
+                require_update=True,
                 asr_processor=get_asr_processor(asr_name),
                 record_path_provider=record_provider
             )
@@ -31,8 +43,7 @@ def run_spacy_dep_tag_wer_pipeline(dataset_name: str, asr_name: str):
     )
     experiment_processor.process()
 
-
-if __name__ == '__main__':
-    run_spacy_dep_tag_wer_pipeline('nl_minds14', 'facebook_wav2vec2_large_xlsr_53_dutch')
-    run_spacy_dep_tag_wer_pipeline('nl_google_fleurs', 'facebook_wav2vec2_large_xlsr_53_dutch')
-    run_spacy_dep_tag_wer_pipeline('nl_voxpopuli', 'facebook_wav2vec2_large_xlsr_53_dutch')
+# if __name__ == '__main__':
+#     run_hf_facebook_wav2vec2_asr_task('nl_minds14', 'facebook_wav2vec2_large_xlsr_53_dutch')
+#     run_hf_facebook_wav2vec2_asr_task('nl_google_fleurs', 'facebook_wav2vec2_large_xlsr_53_dutch')
+#     run_hf_facebook_wav2vec2_asr_task('nl_voxpopuli', 'facebook_wav2vec2_large_xlsr_53_dutch')
diff --git a/new_experiment/worker.py b/new_experiment/worker.py
index 70dd0b2..ca20369 100644
--- a/new_experiment/worker.py
+++ b/new_experiment/worker.py
@@ -60,71 +60,5 @@ def main():
     connection.close()
 
 
-def new_main():
-    LOG_FORMAT = ('%(levelname) -10s %(asctime)s %(name) -30s %(funcName) '
-                  '-35s %(lineno) -5d: %(message)s')
-    LOGGER = logging.getLogger(__name__)
-
-    logging.basicConfig(level=logging.DEBUG, format=LOG_FORMAT)
-
-    def ack_message(channel, delivery_tag):
-        """Note that `channel` must be the same pika channel instance via which
-        the message being ACKed was retrieved (AMQP protocol constraint).
-        """
-        if channel.is_open:
-            channel.basic_ack(delivery_tag)
-        else:
-            # Channel is already closed, so we can't ACK this message;
-            # log and/or do something that makes sense for your app in this case.
-            pass
-
-    def do_work(connection, channel, delivery_tag, body):
-        thread_id = threading.get_ident()
-        fmt1 = 'Thread id: {} Delivery tag: {} Message body: {}'
-        LOGGER.info(fmt1.format(thread_id, delivery_tag, body))
-        # Sleeping to simulate 10 seconds of work
-        time.sleep(10)
-        cb = functools.partial(ack_message, channel, delivery_tag)
-        connection.add_callback_threadsafe(cb)
-
-    def on_message(channel, method_frame, header_frame, body, args):
-        (connection, threads) = args
-        delivery_tag = method_frame.delivery_tag
-        t = threading.Thread(target=do_work, args=(connection, channel, delivery_tag, body))
-        t.start()
-        threads.append(t)
-
-    credentials = pika.PlainCredentials('guest', 'guest')
-    # Note: sending a short heartbeat to prove that heartbeats are still
-    # sent even though the worker simulates long-running work
-    parameters = pika.ConnectionParameters('localhost', credentials=credentials, heartbeat=5)
-    connection = pika.BlockingConnection(parameters)
-
-    channel = connection.channel()
-    channel.exchange_declare(exchange="test_exchange", exchange_type="direct", passive=False, durable=True,
-                             auto_delete=False)
-    channel.queue_declare(queue="standard", auto_delete=True)
-    channel.queue_bind(queue="standard", exchange="test_exchange", routing_key="standard_key")
-    # Note: prefetch is set to 1 here as an example only and to keep the number of threads created
-    # to a reasonable amount. In production you will want to test with different prefetch values
-    # to find which one provides the best performance and usability for your solution
-    channel.basic_qos(prefetch_count=1)
-
-    threads = []
-    on_message_callback = functools.partial(on_message, args=(connection, threads))
-    channel.basic_consume('standard', on_message_callback)
-
-    try:
-        channel.start_consuming()
-    except KeyboardInterrupt:
-        channel.stop_consuming()
-
-    # Wait for all to complete
-    for thread in threads:
-        thread.join()
-
-    connection.close()
-
-
 if __name__ == '__main__':
     main()
-- 
GitLab
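
Reviewer note (not part of the patch): the new worker consumes JSON messages of the
form {'dataset': ..., 'asr_name': ..., 'task': ...} from the durable
'hf_facebook_wav2vec2_asr' queue. Below is a minimal publisher sketch for exercising
new_worker.py end to end. The broker URL, queue name, and message schema are taken
from the patch; the function name publish_wav2vec2_task and the example arguments are
illustrative assumptions, not part of the repository.

    import json

    import pika

    _RABBIT_URL = 'amqps://rabbit_user:kz6m4972OUHFmtUcPOHx4kF3Lj6yw7lo@rabbit-asr-benchmarks.theliver.pl:5671/'


    def publish_wav2vec2_task(dataset: str, asr_name: str) -> None:
        """Enqueue one ASR task in the format expected by do_work() in new_worker.py."""
        connection = pika.BlockingConnection(pika.URLParameters(_RABBIT_URL))
        channel = connection.channel()
        queue_name = 'hf_facebook_wav2vec2_asr'
        # Durable declaration mirrors add_to_queue() so publisher and worker agree.
        channel.queue_declare(queue=queue_name, durable=True)
        message = {'dataset': dataset, 'asr_name': asr_name, 'task': 'hf_facebook_wav2vec2_asr'}
        channel.basic_publish(exchange='', routing_key=queue_name,
                              body=json.dumps(message).encode('utf-8'))
        connection.close()


    if __name__ == '__main__':
        publish_wav2vec2_task('nl_minds14', 'facebook_wav2vec2_large_xlsr_53_dutch')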