"""Enqueue ASR benchmark tasks on RabbitMQ.

Running this module as a script opens a RabbitMQ connection, publishes one
message per (dataset, ASR model, task) combination selected in ``main`` and
closes the connection again.  Helper functions for listing the contents of
the ``dataset-audio`` MinIO bucket are provided as well.
"""
import datetime
import json
from typing import List

import pika
from minio import Minio
from pika.adapters.blocking_connection import BlockingChannel

# Task names published for every Whisper model / dataset pair.
COMMANDS = [
    'run_word_wer_classic_pipeline',
    'run_word_wer_embedding_pipeline',
    'run_spacy_dep_tag_wer_pipeline',
    'run_spacy_ner_wer_pipeline',
    'run_spacy_pos_wer_pipeline',
]
# Language codes used as the prefix of a dataset name.
LANGUAGES = ['nl', 'fr', 'de', 'it', 'pl', 'es', 'en']
# Whisper variants; each one is enqueued as ``whisper_<variant>``.
WHISPER_ASR_MODEL = ['tiny', 'base', 'small', 'medium', 'large-v2']
# Dataset names used as the suffix of a dataset name.
DATASETS = ['google_fleurs', 'minds14', 'voxpopuli']


def get_all_datasets() -> List[str]:
    """Return the top-level entry names of the ``dataset-audio`` bucket.

    The last character of every listed object name is dropped.
    NOTE(review): presumably this strips the trailing ``/`` of each
    top-level prefix -- confirm against the bucket layout.
    """
    listing = get_minio_client().list_objects('dataset-audio', '')
    return [entry.object_name[:-1] for entry in listing]


def get_dataset_items_id(dataset_name: str) -> List[str]:
    """Return the item ids stored under ``<dataset_name>/`` in ``dataset-audio``.

    An id is the last path segment of the object name, cut at its first
    ``.`` (so ``a/b/item.x.wav`` yields ``item``).
    """
    listing = get_minio_client().list_objects('dataset-audio', f'{dataset_name}/')
    return [
        entry.object_name.split('/')[-1].split('.')[0]
        for entry in listing
    ]


def get_minio_client() -> Minio:
    """Create a new MinIO client for the benchmark storage endpoint."""
    # NOTE(review): credentials are hard-coded in source; they should be
    # moved to configuration / a secret store and rotated.
    return Minio(
        endpoint='minio-asr-benchmarks.theliver.pl',
        access_key='minio_user',
        secret_key='eUxzEQbyYPdzrLxuvvethSbk18kB2s7G',
    )


def add_to_queue(dataset: str, asr_name: str, task: str, channel: BlockingChannel, queue_name: str):
    """Publish one task message to ``queue_name``.

    The message is the UTF-8 encoded JSON object
    ``{"dataset": ..., "asr_name": ..., "task": ...}``.  The queue is
    declared as durable before publishing, and the message is published
    through the default exchange with the persistent delivery mode.  A line
    with the current timestamp, the queue name and the message is printed.
    """
    message_dict = {'dataset': dataset, 'asr_name': asr_name, 'task': task}
    print(datetime.datetime.now().isoformat(), f'{queue_name} {message_dict}')
    message_bytes = json.dumps(message_dict).encode('utf-8')

    # Declaring on every call keeps this function usable on its own: the
    # queue exists before the publish regardless of what the caller did.
    channel.queue_declare(queue=queue_name, durable=True)
    channel.basic_publish(
        exchange='',
        routing_key=queue_name,
        body=message_bytes,
        properties=pika.BasicProperties(
            delivery_mode=pika.spec.PERSISTENT_DELIVERY_MODE
        ),
    )


def get_all_datasets_with_language() -> List[str]:
    """Return every ``<language>_<dataset>`` name.

    Names are ordered by language first (``LANGUAGES`` order), then by
    dataset (``DATASETS`` order).
    """
    return [
        f'{language}_{dataset}'
        for language in LANGUAGES
        for dataset in DATASETS
    ]


def add_whisper(channel: BlockingChannel):
    """Enqueue every command for every Whisper variant and dataset.

    Messages go to the ``asr_benchmark_experiments`` queue, iterating
    Whisper variants, then datasets, then commands.
    """
    # The dataset list does not depend on the variant, so build it once.
    dataset_names = get_all_datasets_with_language()
    for whisper_variant in WHISPER_ASR_MODEL:
        asr_name = f'whisper_{whisper_variant}'
        for dataset_name in dataset_names:
            for command in COMMANDS:
                add_to_queue(dataset_name, asr_name, command, channel, 'asr_benchmark_experiments')


def get_hf_facebook_wav2vec2_model_by_language_code(language_code: str) -> str:
    """Return the wav2vec2 model name registered for ``language_code``.

    Only ``'en'`` is currently enabled; the other entries are commented out.

    Raises:
        KeyError: if ``language_code`` has no enabled entry.
    """
    return {
        # 'nl': 'facebook_wav2vec2_large_xlsr_53_dutch',
        'en': 'facebook_wav2vec2_large_960h_lv60_self',
        # 'fr': 'facebook_wav2vec2_large_xlsr_53_french',
        # 'de': 'facebook_wav2vec2_large_xlsr_53_german',
        # 'it': 'facebook_wav2vec2_large_xlsr_53_italian',
        # 'pl': 'facebook_wav2vec2_large_xlsr_53_polish',
        # 'es': 'facebook_wav2vec2_large_xlsr_53_spanish'
    }[language_code]


def add_facebook_hf_wav2vec2_asr(channel: BlockingChannel):
    """Enqueue a wav2vec2 task for every dataset whose name starts with ``en``.

    Messages go to the ``hf_facebook_wav2vec2_asr`` queue with the task name
    ``hf_facebook_wav2vec2_asr``; the model name is looked up from the first
    two characters of the dataset name.
    """
    for dataset_name in get_all_datasets_with_language():
        if dataset_name.startswith('en'):
            add_to_queue(
                dataset_name,
                get_hf_facebook_wav2vec2_model_by_language_code(dataset_name[:2]),
                'hf_facebook_wav2vec2_asr',
                channel,
                'hf_facebook_wav2vec2_asr'
            )


def main():
    """Connect to RabbitMQ, enqueue the selected tasks and disconnect."""
    # NOTE(review): the broker credentials are hard-coded in the URL; they
    # should be moved to configuration / a secret store and rotated.
    parameters = pika.URLParameters(
        'amqps://rabbit_user:kz6m4972OUHFmtUcPOHx4kF3Lj6yw7lo@rabbit-asr-benchmarks.theliver.pl:5671/')
    connection = pika.BlockingConnection(parameters=parameters)
    try:
        channel = connection.channel()
        # Whisper enqueueing is currently disabled.
        # add_whisper(channel)
        add_facebook_hf_wav2vec2_asr(channel)
    finally:
        # Release the connection even when publishing fails; skip the call
        # when the connection is already closed so that a close error does
        # not hide the original exception.
        if connection.is_open:
            connection.close()


if __name__ == '__main__':
    main()