From c96708a501e902703e681ab3cf49105548023f3c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Marcin=20W=C4=85troba?= <markowanga@gmail.com>
Date: Thu, 12 Jan 2023 03:26:06 +0100
Subject: [PATCH] Add worker for pipeline

---
 new_experiment/add_to_queue_pipeline.py        | 66 +++++++++++++++++++
 .../pipeline_process_spacy_dep_tag_wer.py      |  2 +-
 .../pipeline_process_spacy_ner_wer.py          |  2 +-
 new_experiment/worker.py                       |  7 --
 4 files changed, 68 insertions(+), 9 deletions(-)
 create mode 100644 new_experiment/add_to_queue_pipeline.py

diff --git a/new_experiment/add_to_queue_pipeline.py b/new_experiment/add_to_queue_pipeline.py
new file mode 100644
index 0000000..2053526
--- /dev/null
+++ b/new_experiment/add_to_queue_pipeline.py
@@ -0,0 +1,66 @@
+import datetime
+import json
+from typing import List
+
+import pika
+from minio import Minio
+from pika.adapters.blocking_connection import BlockingChannel
+
+COMMANDS = ['run_word_wer_classic_pipeline', 'run_word_wer_embedding_pipeline', 'run_spacy_dep_tag_wer_pipeline',
+            'run_spacy_ner_wer_pipeline', 'run_spacy_pos_wer_pipeline']
+LANGUAGES = ['nl', 'fr', 'de', 'it', 'pl', 'es', 'en']
+WHISPER_ASR_MODEL = ['tiny', 'base', 'small', 'medium', 'large-v2']
+DATASETS = ['google_fleurs', 'minds14', 'voxpopuli']
+
+
+def get_all_datasets() -> List[str]:
+    return [it.object_name[:-1] for it in get_minio_client().list_objects('dataset-audio', '')]
+
+
+def get_dataset_items_id(dataset_name: str) -> List[str]:
+    return [it.object_name.split('/')[-1].split('.')[0] for it in
+            get_minio_client().list_objects('dataset-audio', f'{dataset_name}/')]
+
+
+def get_minio_client() -> Minio:
+    return Minio('minio-asr-benchmarks.theliver.pl', 'minio_user', 'eUxzEQbyYPdzrLxuvvethSbk18kB2s7G')
+
+
+def add_to_queue(dataset: str, asr_name: str, task: str, channel: BlockingChannel):
+    message_dict = {'dataset': dataset, 'asr_name': asr_name, 'task': task}
+    queue_name = 'asr_benchmark_experiments'
+    print(datetime.datetime.now().isoformat(), f'{queue_name} {message_dict}')
+    message_bytes = json.dumps(message_dict).encode('utf-8')
+    channel.queue_declare(queue=queue_name, durable=True)
+    channel.basic_publish(exchange='', routing_key=queue_name, body=message_bytes, properties=pika.BasicProperties(
+        delivery_mode=pika.spec.PERSISTENT_DELIVERY_MODE
+    ))
+
+
+def get_all_datasets_with_language() -> List[str]:
+    to_return = []
+    for it in LANGUAGES:
+        for itt in DATASETS:
+            to_return.append(f'{it}_{itt}')
+    return to_return
+
+
+def add_whisper(channel: BlockingChannel):
+    for whisper_variant in WHISPER_ASR_MODEL:
+        asr_name = f'whisper_{whisper_variant}'
+        for dataset_name in get_all_datasets_with_language():
+            for command in COMMANDS:
+                add_to_queue(dataset_name, asr_name, command, channel)
+
+
+def main():
+    parameters = pika.URLParameters(
+        'amqps://rabbit_user:kz6m4972OUHFmtUcPOHx4kF3Lj6yw7lo@rabbit-asr-benchmarks.theliver.pl:5671/')
+    connection = pika.BlockingConnection(parameters=parameters)
+    channel = connection.channel()
+    add_whisper(channel)
+    connection.close()
+
+
+if __name__ == '__main__':
+    main()

diff --git a/new_experiment/pipeline/pipeline_process_spacy_dep_tag_wer.py b/new_experiment/pipeline/pipeline_process_spacy_dep_tag_wer.py
index 03420db..8e2d568 100644
--- a/new_experiment/pipeline/pipeline_process_spacy_dep_tag_wer.py
+++ b/new_experiment/pipeline/pipeline_process_spacy_dep_tag_wer.py
@@ -25,7 +25,7 @@ def run_spacy_dep_tag_wer_pipeline(dataset_name: str, asr_name: str):
                 alignment_property_name=PropertyHelper.dep_tag_alignment(asr_name, model_name),
                 wer_property_name=PropertyHelper.dep_tag_metrics(asr_name, model_name),
                 task_name=f'SpacyDepTagSentenceWerProcessor___{dataset_name}___{asr_name}',
-                require_update=False
+                require_update=True
             )
         ],
         experiment_repository=repository,

diff --git a/new_experiment/pipeline/pipeline_process_spacy_ner_wer.py b/new_experiment/pipeline/pipeline_process_spacy_ner_wer.py
index 96cc9ef..c74e419 100644
--- a/new_experiment/pipeline/pipeline_process_spacy_ner_wer.py
+++ b/new_experiment/pipeline/pipeline_process_spacy_ner_wer.py
@@ -25,7 +25,7 @@ def run_spacy_ner_wer_pipeline(dataset_name: str, asr_name: str):
                 alignment_property_name=PropertyHelper.ner_alignment(asr_name, model_name),
                 wer_property_name=PropertyHelper.ner_metrics(asr_name, model_name),
                 task_name=f'SpacyNerSentenceWerProcessor___{dataset_name}___{asr_name}',
-                require_update=False
+                require_update=True
             )
         ],
         experiment_repository=repository,

diff --git a/new_experiment/worker.py b/new_experiment/worker.py
index 93a861b..7422ea9 100644
--- a/new_experiment/worker.py
+++ b/new_experiment/worker.py
@@ -1,19 +1,13 @@
 import json
 import os
-import uuid
 
 import pika
-from minio import Minio
-from pymongo import MongoClient
-from urllib3 import HTTPResponse
 
-from new_datasets.whisper_processor import WhisperAsrProcessor
 from new_experiment.pipeline.pipeline_process_spacy_dep_tag_wer import run_spacy_dep_tag_wer_pipeline
 from new_experiment.pipeline.pipeline_process_spacy_ner_wer import run_spacy_ner_wer_pipeline
 from new_experiment.pipeline.pipeline_process_spacy_pos_wer import run_spacy_pos_wer_pipeline
 from new_experiment.pipeline.pipeline_process_word_classic_wer import run_word_wer_classic_pipeline
 from new_experiment.pipeline.pipeline_process_word_embedding_wer import run_word_wer_embedding_pipeline
-from sziszapangma.integration.repository.mongo_experiment_repository import MongoExperimentRepository
 
 
 def get_param(name: str, default: str) -> str:
@@ -39,7 +33,6 @@ def main():
         task = message_dict['task']
         dataset = message_dict['dataset']
         asr_name = message_dict['asr_name']
-
         if task == 'run_word_wer_classic_pipeline':
             run_word_wer_classic_pipeline(dataset, asr_name)
         elif task == 'run_word_wer_embedding_pipeline':
--
GitLab
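For context, the messages that add_to_queue() publishes are picked up by new_experiment/worker.py. Below is a minimal sketch of that consuming side, assuming pika's blocking-consumer API and the same durable queue name; the broker URL and the handle_message() body are illustrative placeholders, not the worker's actual code, which dispatches each task to one of the imported pipeline functions.

import json

import pika

# Same durable queue that add_to_queue() declares and publishes to.
QUEUE_NAME = 'asr_benchmark_experiments'


def handle_message(channel, method, properties, body):
    # Decode the JSON payload produced by add_to_queue().
    message_dict = json.loads(body)
    # Placeholder for the worker's dispatch: it routes message_dict['task']
    # to the matching run_*_pipeline(dataset, asr_name) function.
    print('received', message_dict['task'], message_dict['dataset'], message_dict['asr_name'])
    # Acknowledge only after the work is done, so an unfinished task is redelivered.
    channel.basic_ack(delivery_tag=method.delivery_tag)


def main():
    # Placeholder URL: the real credentials live in the worker's configuration.
    parameters = pika.URLParameters('amqps://user:password@rabbit-host:5671/')
    connection = pika.BlockingConnection(parameters=parameters)
    channel = connection.channel()
    channel.queue_declare(queue=QUEUE_NAME, durable=True)
    # One unacknowledged message per worker keeps long pipeline runs from piling up.
    channel.basic_qos(prefetch_count=1)
    channel.basic_consume(queue=QUEUE_NAME, on_message_callback=handle_message)
    channel.start_consuming()


if __name__ == '__main__':
    main()

Declaring the queue as durable on both producer and consumer, together with PERSISTENT_DELIVERY_MODE in add_to_queue(), means queued experiment tasks survive a broker restart.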