diff --git a/new_experiment/add_to_queue_pipeline.py b/new_experiment/add_to_queue_pipeline.py
index d6b1fdb5536a5d4f971d7e119694647d0bcd3890..ef3bf2f4c970732960815ac379735f4c1332bc14 100644
--- a/new_experiment/add_to_queue_pipeline.py
+++ b/new_experiment/add_to_queue_pipeline.py
@@ -71,7 +71,7 @@ def add_facebook_hf_wav2vec2_asr(channel: BlockingChannel):
         get_hf_facebook_wav2vec2_model_by_language_code(dataset_name[:2]),
         'hf_facebook_wav2vec2_asr',
         channel,
-        'hf_facebook_wav2vec2_asr'
+        'asr_benchmark_asr_run'
     )
 
 
@@ -88,8 +88,8 @@ def main():
     connection = pika.BlockingConnection(parameters=parameters)
     channel = connection.channel()
     # add_whisper(channel)
-    # add_facebook_hf_wav2vec2_asr(channel)
-    add_facebook_hf_wav2vec2_pipeline(channel)
+    add_facebook_hf_wav2vec2_asr(channel)
+    # add_facebook_hf_wav2vec2_pipeline(channel)
 
     connection.close()
 
diff --git a/new_experiment/hf_asr/wav2vec2_hf.py b/new_experiment/hf_asr/wav2vec2_hf.py
index ff17959129c1be506382d8ea0f6bf05d8a82cfee..255ab272b9aafd93e098b91bc308d2a52046d176 100644
--- a/new_experiment/hf_asr/wav2vec2_hf.py
+++ b/new_experiment/hf_asr/wav2vec2_hf.py
@@ -30,7 +30,7 @@ class Wav2Vec2AsrProcessor(AsrProcessor):
         pred_ids = torch.argmax(logits, dim=-1)
         result = self._wav2vec2_processor.batch_decode(pred_ids)[0]
         return {
-            "transcription": [create_new_word(it) for it in result.split()],
+            "transcription": result.split(),
             "full_text": result,
             "words_time_alignment": None
         }
diff --git a/new_experiment/pipeline/pipeline_process_asr.py b/new_experiment/pipeline/pipeline_process_asr.py
index 1860df6435f6e22e427f879b32b491979e7be5f8..e1657a527caf5f622365908b5d4e1daeeb703227 100644
--- a/new_experiment/pipeline/pipeline_process_asr.py
+++ b/new_experiment/pipeline/pipeline_process_asr.py
@@ -34,7 +34,7 @@ def run_hf_facebook_wav2vec2_asr_task(dataset_name: str, asr_name: str):
         AsrTask(
             asr_property_name=PropertyHelper.asr_result(asr_name),
             task_name=f'AsrTask___{dataset_name}___{asr_name}',
-            require_update=False,
+            require_update=True,
             asr_processor=get_asr_processor(asr_name),
             record_path_provider=record_provider
         )
diff --git a/new_experiment/worker_asr.py b/new_experiment/worker_asr.py
index e15fb4f7c6526c97f975bb852833ac45932256f7..593690a198730ba96ea51c0571da32d3231b88c3 100644
--- a/new_experiment/worker_asr.py
+++ b/new_experiment/worker_asr.py
@@ -1,40 +1,10 @@
-import functools
 import json
-import os
-import threading
-
-import pika
-from pika.adapters.blocking_connection import BlockingChannel
 
 from new_experiment.pipeline.pipeline_process_asr import run_hf_facebook_wav2vec2_asr_task
-from new_experiment.pipeline.pipeline_process_spacy_dep_tag_wer import run_spacy_dep_tag_wer_pipeline
-from new_experiment.pipeline.pipeline_process_spacy_ner_wer import run_spacy_ner_wer_pipeline
-from new_experiment.pipeline.pipeline_process_spacy_pos_wer import run_spacy_pos_wer_pipeline
-from new_experiment.pipeline.pipeline_process_word_classic_wer import run_word_wer_classic_pipeline
-from new_experiment.pipeline.pipeline_process_word_embedding_wer import run_word_wer_embedding_pipeline
-from new_experiment.utils.param_util import get_param
-
-_RABBIT_URL = get_param('RABBIT_URL',
-                        'amqps://rabbit_user:kz6m4972OUHFmtUcPOHx4kF3Lj6yw7lo@rabbit-asr-benchmarks.theliver.pl:5671/')
-
-
-def ack_message(channel, delivery_tag):
-    """Note that `channel` must be the same pika channel instance via which
-    the message being ACKed was retrieved (AMQP protocol constraint).
-    """
-    if channel.is_open:
-        channel.basic_ack(delivery_tag)
-    else:
-        # Channel is already closed, so we can't ACK this message;
-        # log and/or do something that makes sense for your app in this case.
-        pass
+from new_experiment.queue_base import process_queue
 
 
-def do_work(connection, channel, delivery_tag, body):
-    thread_id = threading.get_ident()
-    # fmt1 = 'Thread id: {} Delivery tag: {} Message body: {}'
-    # LOGGER.info(fmt1.format(thread_id, delivery_tag, body))
-
+def process_message(body: bytes):
     print(body)
     message_dict = json.loads(body.decode('utf-8'))
     print(message_dict)
@@ -45,40 +15,6 @@ def do_work(connection, channel, delivery_tag, body):
     if task == 'hf_facebook_wav2vec2_asr':
         run_hf_facebook_wav2vec2_asr_task(dataset, asr_name)
 
-    cb = functools.partial(ack_message, channel, delivery_tag)
-    connection.add_callback_threadsafe(cb)
-    print('\n#########################\n')
-
-
-def on_message(channel: BlockingChannel, method_frame, header_frame, body, args):
-    (connection, threads) = args
-    delivery_tag = method_frame.delivery_tag
-    t = threading.Thread(target=do_work, args=(connection, channel, delivery_tag, body))
-    t.start()
-    threads.append(t)
-
-
-def new_main():
-    parameters = pika.URLParameters(_RABBIT_URL)
-    connection = pika.BlockingConnection(parameters)
-    channel = connection.channel()
-    channel.basic_qos(prefetch_count=1)
-
-    threads = []
-    on_message_callback = functools.partial(on_message, args=(connection, threads))
-    channel.basic_consume('hf_facebook_wav2vec2_asr', on_message_callback)
-
-    try:
-        channel.start_consuming()
-    except KeyboardInterrupt:
-        channel.stop_consuming()
-
-    # Wait for all to complete
-    for thread in threads:
-        thread.join()
-
-    connection.close()
-
 
 if __name__ == '__main__':
-    new_main()
+    process_queue('asr_benchmark_asr_run', process_message)