# add_to_queue_pipeline.py: publish ASR benchmark jobs to RabbitMQ work queues.
import datetime
import json
import os
from typing import List

import pika
from minio import Minio
from pika.adapters.blocking_connection import BlockingChannel

# Task names enqueued for every Whisper model / dataset pair (used by add_whisper).
COMMANDS = [
    'run_word_wer_classic_pipeline',
    'run_word_wer_embedding_pipeline',
    'run_spacy_dep_tag_wer_pipeline',
    'run_spacy_ner_wer_pipeline',
    'run_spacy_pos_wer_pipeline',
]
# Language codes combined with DATASETS to form '<lang>_<dataset>' names.
LANGUAGES = ['nl', 'fr', 'de', 'it', 'pl', 'es', 'en']
# Whisper model sizes; each becomes the asr_name 'whisper_<size>'.
WHISPER_ASR_MODEL = ['tiny', 'base', 'small', 'medium', 'large-v2']
# Dataset families available for every language in LANGUAGES.
DATASETS = ['google_fleurs', 'minds14', 'voxpopuli']


def get_all_datasets() -> List[str]:
    """Return the names of the top-level entries in the 'dataset-audio' bucket.

    Top-level "directory" prefixes are listed with a trailing '/', which is
    stripped. Entries without a trailing slash are returned unchanged rather
    than having their last character chopped off.
    """
    names = []
    for obj in get_minio_client().list_objects('dataset-audio', ''):
        name = obj.object_name
        # Only drop a real trailing slash; an unconditional [:-1] would corrupt plain object names.
        names.append(name[:-1] if name.endswith('/') else name)
    return names


def get_dataset_items_id(dataset_name: str) -> List[str]:
    """List the item IDs stored under ``<dataset_name>/`` in the 'dataset-audio' bucket.

    An item ID is the object's base name (the text after the last '/')
    truncated at its first '.', so 'ds/abc.wav' yields 'abc'.
    """
    listing = get_minio_client().list_objects('dataset-audio', f'{dataset_name}/')
    ids = []
    for entry in listing:
        base_name = entry.object_name.rpartition('/')[2]
        ids.append(base_name.partition('.')[0])
    return ids


def get_minio_client() -> Minio:
    """Create a MinIO client for the ASR benchmark object store.

    Credentials are read from the environment instead of being hard-coded:
    ``MINIO_ACCESS_KEY`` and ``MINIO_SECRET_KEY`` are required. ``MINIO_ENDPOINT``
    optionally overrides the default benchmark host.

    Raises:
        RuntimeError: if either credential variable is unset or empty.
    """
    endpoint = os.environ.get('MINIO_ENDPOINT', 'minio-asr-benchmarks.theliver.pl')
    access_key = os.environ.get('MINIO_ACCESS_KEY')
    secret_key = os.environ.get('MINIO_SECRET_KEY')
    if not access_key or not secret_key:
        raise RuntimeError('MINIO_ACCESS_KEY and MINIO_SECRET_KEY environment variables must be set')
    return Minio(endpoint, access_key=access_key, secret_key=secret_key)


def add_to_queue(dataset: str, asr_name: str, task: str, channel: BlockingChannel, queue_name: str):
    """Publish one persistent JSON task message to ``queue_name``.

    A timestamped line describing the message is printed first. The queue is
    then declared durable, and the UTF-8 encoded JSON body is published via
    the default exchange with the queue name as routing key.
    """
    payload = {'dataset': dataset, 'asr_name': asr_name, 'task': task}
    stamp = datetime.datetime.now().isoformat()
    print(stamp, f'{queue_name}   {payload}')
    encoded = json.dumps(payload).encode('utf-8')
    channel.queue_declare(queue=queue_name, durable=True)
    persistent = pika.BasicProperties(delivery_mode=pika.spec.PERSISTENT_DELIVERY_MODE)
    channel.basic_publish(
        exchange='',
        routing_key=queue_name,
        body=encoded,
        properties=persistent,
    )


def get_all_datasets_with_language() -> List[str]:
    """Return every '<language>_<dataset>' name, grouped by language in LANGUAGES order."""
    return [f'{lang}_{family}' for lang in LANGUAGES for family in DATASETS]


def add_whisper(channel: BlockingChannel):
    """Enqueue every Whisper model / dataset / command combination.

    Messages go to the 'asr_benchmark_experiments' queue with ``asr_name``
    set to ``whisper_<variant>`` for each entry of WHISPER_ASR_MODEL.
    """
    target_queue = 'asr_benchmark_experiments'
    for variant in WHISPER_ASR_MODEL:
        model_name = f'whisper_{variant}'
        jobs = [(ds, cmd) for ds in get_all_datasets_with_language() for cmd in COMMANDS]
        for ds, cmd in jobs:
            add_to_queue(ds, model_name, cmd, channel, target_queue)


def get_hf_facebook_wav2vec2_model_by_language_code(language_code: str) -> str:
    """Return the wav2vec2 ASR model name for a language code such as 'en'.

    Only English is currently enabled; the other languages are kept as
    commented-out entries so they can be re-enabled.

    Raises:
        KeyError: if ``language_code`` has no enabled model.
    """
    enabled_models = {
        'en': 'facebook_wav2vec2_large_960h_lv60_self',
        # Disabled entries (uncomment to enable):
        # 'nl': 'facebook_wav2vec2_large_xlsr_53_dutch',
        # 'fr': 'facebook_wav2vec2_large_xlsr_53_french',
        # 'de': 'facebook_wav2vec2_large_xlsr_53_german',
        # 'it': 'facebook_wav2vec2_large_xlsr_53_italian',
        # 'pl': 'facebook_wav2vec2_large_xlsr_53_polish',
        # 'es': 'facebook_wav2vec2_large_xlsr_53_spanish',
    }
    model = enabled_models.get(language_code)
    if model is None:
        raise KeyError(language_code)
    return model


def add_facebook_hf_wav2vec2_asr(channel: BlockingChannel):
    """Enqueue a wav2vec2 ASR job for every English dataset.

    Both the task name and the target queue are 'hf_facebook_wav2vec2_asr';
    the model name is looked up from the dataset's two-letter language prefix.
    """
    english = [name for name in get_all_datasets_with_language() if name.startswith('en')]
    for name in english:
        model = get_hf_facebook_wav2vec2_model_by_language_code(name[:2])
        add_to_queue(name, model, 'hf_facebook_wav2vec2_asr', channel, 'hf_facebook_wav2vec2_asr')


def main():
    """Connect to RabbitMQ and enqueue the currently selected benchmark jobs.

    The AMQP URL, including credentials, is read from the ``RABBITMQ_URL``
    environment variable (e.g. ``amqps://user:password@host:5671/``) instead
    of being hard-coded in source.

    Raises:
        RuntimeError: if ``RABBITMQ_URL`` is unset or empty.
    """
    amqp_url = os.environ.get('RABBITMQ_URL')
    if not amqp_url:
        raise RuntimeError('RABBITMQ_URL environment variable is not set')
    connection = pika.BlockingConnection(parameters=pika.URLParameters(amqp_url))
    try:
        channel = connection.channel()
        # add_whisper(channel)  # uncomment to also enqueue the Whisper jobs
        add_facebook_hf_wav2vec2_asr(channel)
    finally:
        # Close on every exit path. Skip if already closed (e.g. by the broker):
        # closing a closed pika connection raises and would mask the original error.
        if connection.is_open:
            connection.close()


if __name__ == '__main__':
    main()