Newer
Older
import datetime
import json
from typing import List
import pika
from minio import Minio
from pika.adapters.blocking_connection import BlockingChannel
# Pipeline task names that can be queued for every (dataset, ASR model) pair.
COMMANDS = [
    'run_word_wer_classic_pipeline',
    'run_word_wer_embedding_pipeline',
    'run_spacy_dep_tag_wer_pipeline',
    'run_spacy_ner_wer_pipeline',
    'run_spacy_pos_wer_pipeline',
]

# Language codes used as the prefix of dataset names (``<language>_<dataset>``).
LANGUAGES = [
    'nl',
    'fr',
    'de',
    'it',
    'pl',
    'es',
    'en',
]

# Whisper model sizes; each is turned into an ASR name ``whisper_<size>``.
WHISPER_ASR_MODEL = [
    'tiny',
    'base',
    'small',
    'medium',
    'large-v2',
]

# Base dataset names used as the suffix of dataset names.
DATASETS = [
    'google_fleurs',
    'minds14',
    'voxpopuli',
]
def get_all_datasets() -> List[str]:
    """List the entries found at the root of the ``dataset-audio`` bucket.

    The last character of every object name is dropped (presumably the
    trailing ``/`` of a folder-style entry -- confirm against the bucket
    layout).

    Returns:
        The trimmed object names, in the order MinIO lists them.
    """
    client = get_minio_client()
    listing = client.list_objects('dataset-audio', '')
    return [entry.object_name[:-1] for entry in listing]
def get_dataset_items_id(dataset_name: str) -> List[str]:
    """Return the item identifiers stored under one dataset prefix.

    Objects are listed from the ``dataset-audio`` bucket with the prefix
    ``<dataset_name>/``. An identifier is the last path segment of the
    object name, cut at its first ``.``.

    Args:
        dataset_name: Prefix (without trailing slash) to list.

    Returns:
        One identifier per listed object, in listing order.
    """
    client = get_minio_client()
    listing = client.list_objects('dataset-audio', f'{dataset_name}/')
    item_ids = []
    for entry in listing:
        # Keep only what follows the last '/', then what precedes the first '.'.
        file_name = entry.object_name.rpartition('/')[2]
        item_ids.append(file_name.partition('.')[0])
    return item_ids
def get_minio_client() -> Minio:
    """Build a MinIO client for the ASR benchmarks object store.

    The endpoint and credentials can be overridden with the environment
    variables ``ASR_BENCHMARKS_MINIO_ENDPOINT``,
    ``ASR_BENCHMARKS_MINIO_ACCESS_KEY`` and
    ``ASR_BENCHMARKS_MINIO_SECRET_KEY``. An unset or empty variable falls
    back to the value this function has always used, so existing callers
    keep working unchanged.

    Returns:
        A newly constructed ``Minio`` client (a fresh one on every call).
    """
    import os  # local import so this function does not depend on file-level imports

    # NOTE(review): the fallback credentials below are committed to source
    # control. They should be rotated and supplied only via the environment.
    endpoint = (os.environ.get('ASR_BENCHMARKS_MINIO_ENDPOINT')
                or 'minio-asr-benchmarks.theliver.pl')
    access_key = (os.environ.get('ASR_BENCHMARKS_MINIO_ACCESS_KEY')
                  or 'minio_user')
    secret_key = (os.environ.get('ASR_BENCHMARKS_MINIO_SECRET_KEY')
                  or 'eUxzEQbyYPdzrLxuvvethSbk18kB2s7G')
    # Keyword arguments make the meaning of each value explicit.
    return Minio(endpoint=endpoint, access_key=access_key, secret_key=secret_key)
def add_to_queue(dataset: str, asr_name: str, task: str, channel: BlockingChannel, queue_name: str):
    """Publish one persistent JSON job message to a durable queue.

    The message is a JSON object with the keys ``dataset``, ``asr_name`` and
    ``task``. A timestamped line with the queue name and message is printed
    before publishing.

    Args:
        dataset: Dataset name placed in the message.
        asr_name: ASR model name placed in the message.
        task: Task name placed in the message.
        channel: Open channel used to declare the queue and publish.
        queue_name: Queue to declare (durable) and use as routing key.
    """
    payload = dict(dataset=dataset, asr_name=asr_name, task=task)
    timestamp = datetime.datetime.now().isoformat()
    print(timestamp, f'{queue_name} {payload}')

    body = json.dumps(payload).encode('utf-8')

    # Declared on every call; the queue is requested as durable.
    channel.queue_declare(queue=queue_name, durable=True)

    # Persistent delivery mode is requested for the message itself.
    persistent = pika.BasicProperties(delivery_mode=pika.spec.PERSISTENT_DELIVERY_MODE)
    channel.basic_publish(
        exchange='',
        routing_key=queue_name,
        body=body,
        properties=persistent,
    )
def get_all_datasets_with_language() -> List[str]:
    """Return every ``<language>_<dataset>`` combination.

    Languages come from ``LANGUAGES`` and datasets from ``DATASETS``. The
    result is ordered language-major: all datasets for the first language,
    then all datasets for the second, and so on.

    Returns:
        A new list of combined names.
    """
    return [
        f'{language}_{dataset}'
        for language in LANGUAGES
        for dataset in DATASETS
    ]
def add_whisper(channel: BlockingChannel):
    """Queue every command for every Whisper variant and dataset.

    For each size in ``WHISPER_ASR_MODEL`` the ASR name ``whisper_<size>`` is
    combined with each language-qualified dataset name and each entry of
    ``COMMANDS``. Each combination is published to the
    ``asr_benchmark_experiments`` queue.

    Args:
        channel: Open channel handed to ``add_to_queue``.
    """
    queue_name = 'asr_benchmark_experiments'
    # The dataset list does not depend on the model variant, so build it once
    # instead of once per variant.
    dataset_names = get_all_datasets_with_language()
    for whisper_variant in WHISPER_ASR_MODEL:
        asr_name = f'whisper_{whisper_variant}'
        for dataset_name in dataset_names:
            for command in COMMANDS:
                add_to_queue(dataset_name, asr_name, command, channel, queue_name)
def get_hf_facebook_wav2vec2_model_by_language_code(language_code: str) -> str:
    """Look up the wav2vec2 model name configured for a language code.

    Args:
        language_code: Language code to look up; only ``'en'`` is currently
            enabled.

    Returns:
        The model name registered for ``language_code``.

    Raises:
        KeyError: If no model is enabled for ``language_code``.
    """
    # Entries for the other languages are kept commented out (disabled):
    #   'nl': 'facebook_wav2vec2_large_xlsr_53_dutch'
    #   'fr': 'facebook_wav2vec2_large_xlsr_53_french'
    #   'de': 'facebook_wav2vec2_large_xlsr_53_german'
    #   'it': 'facebook_wav2vec2_large_xlsr_53_italian'
    #   'pl': 'facebook_wav2vec2_large_xlsr_53_polish'
    #   'es': 'facebook_wav2vec2_large_xlsr_53_spanish'
    model_by_language = {
        'en': 'facebook_wav2vec2_large_960h_lv60_self',
    }
    return model_by_language[language_code]
def add_facebook_hf_wav2vec2_asr(channel: BlockingChannel):
    """Queue a wav2vec2 ASR job for each dataset whose name starts with ``en``.

    The model name is resolved from the first two characters of the dataset
    name. The task name and the queue name are both
    ``hf_facebook_wav2vec2_asr``.

    Args:
        channel: Open channel handed to ``add_to_queue``.
    """
    task_and_queue = 'hf_facebook_wav2vec2_asr'
    for dataset_name in get_all_datasets_with_language():
        # Only datasets with the 'en' prefix are queued.
        if not dataset_name.startswith('en'):
            continue
        language_code = dataset_name[:2]
        model_name = get_hf_facebook_wav2vec2_model_by_language_code(language_code)
        add_to_queue(dataset_name, model_name, task_and_queue, channel, task_and_queue)
def main():
    """Connect to RabbitMQ and enqueue the wav2vec2 benchmark jobs.

    The broker URL can be overridden with the environment variable
    ``ASR_BENCHMARKS_AMQP_URL``. When it is unset or empty, the URL this
    script has always used is the fallback. The connection is closed even if
    creating the channel or enqueueing fails.
    """
    import os  # local import so this function does not depend on file-level imports

    # NOTE(review): the fallback URL embeds credentials committed to source
    # control. They should be rotated and supplied only via the environment.
    amqp_url = (
        os.environ.get('ASR_BENCHMARKS_AMQP_URL')
        or 'amqps://rabbit_user:kz6m4972OUHFmtUcPOHx4kF3Lj6yw7lo@rabbit-asr-benchmarks.theliver.pl:5671/'
    )
    parameters = pika.URLParameters(amqp_url)
    connection = pika.BlockingConnection(parameters=parameters)
    try:
        channel = connection.channel()
        # Whisper jobs are currently disabled:
        # add_whisper(channel)
        add_facebook_hf_wav2vec2_asr(channel)
    finally:
        # Closing an already-closed connection raises and would mask the
        # original error, so only close while it is still open.
        if connection.is_open:
            connection.close()
# Run the enqueue script only when this file is executed directly, not on import.
if __name__ == '__main__':
    main()