From 1936693a2226bd1e0cf235cd02f7caf06ce90be1 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Marcin=20W=C4=85troba?= <markowanga@gmail.com>
Date: Sun, 15 Jan 2023 01:26:25 +0100
Subject: [PATCH] download_dataset command

---
 new_datasets/download_hf_dataset.py     | 18 ++++++++++++++++++
 new_experiment/add_to_queue_pipeline.py | 20 +++++++++++++++-----
 2 files changed, 33 insertions(+), 5 deletions(-)
 create mode 100644 new_datasets/download_hf_dataset.py

diff --git a/new_datasets/download_hf_dataset.py b/new_datasets/download_hf_dataset.py
new file mode 100644
index 0000000..1f6effc
--- /dev/null
+++ b/new_datasets/download_hf_dataset.py
@@ -0,0 +1,18 @@
+import argparse
+from typing import Optional
+
+import datasets
+
+
+def download_dataset(dataset_path: str, dataset_name: Optional[str], cache_dir: str):
+    dataset = datasets.load_dataset(dataset_path, dataset_name, cache_dir=cache_dir)
+    print(dataset)
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--dataset_path")
+    parser.add_argument("--dataset_name")
+    parser.add_argument("--cache_dir")
+    args = parser.parse_args()
+    download_dataset(args.dataset_path, args.dataset_name, args.cache_dir)
diff --git a/new_experiment/add_to_queue_pipeline.py b/new_experiment/add_to_queue_pipeline.py
index b87ca62..c6e2f36 100644
--- a/new_experiment/add_to_queue_pipeline.py
+++ b/new_experiment/add_to_queue_pipeline.py
@@ -9,14 +9,14 @@ from pika.adapters.blocking_connection import BlockingChannel
 COMMANDS = ['run_word_wer_classic_pipeline', 'run_word_wer_embedding_pipeline', 'run_spacy_dep_tag_wer_pipeline',
             'run_spacy_ner_wer_pipeline', 'run_spacy_pos_wer_pipeline']
 LANGUAGES = [
-    # 'nl', 'fr', 'de',
+    'nl', 'fr', 'de',
     'it',
-    # 'pl', 'es', 'en'
+    'pl', 'es', 'en'
 ]
 WHISPER_ASR_MODEL = ['tiny', 'base', 'small', 'medium', 'large-v2']
 DATASETS = [
-    # 'google_fleurs',
-    # 'minds14',
+    'google_fleurs',
+    'minds14',
     'voxpopuli'
 ]
 
@@ -90,6 +90,16 @@ def add_facebook_hf_wav2vec2_pipeline(channel: BlockingChannel):
                 add_to_queue(dataset_name, asr_name, command, channel, 'asr_benchmark_experiments')
 
 
+def add_nvidia(channel: BlockingChannel):
+    languages = ['de', 'en', 'es', 'fr', 'it']
+    for language_code in languages:
+        asr_name = f'nvidia_stt_{language_code}_conformer_transducer_large'
+        for datasets in DATASETS:
+            dataset_name = f'{language_code}_{datasets}'
+            for command in COMMANDS:
+                add_to_queue(dataset_name, asr_name, command, channel, 'asr_benchmark_experiments')
+
+
 def main():
     parameters = pika.URLParameters(
         'amqps://rabbit_user:kz6m4972OUHFmtUcPOHx4kF3Lj6yw7lo@rabbit-asr-benchmarks.theliver.pl:5671/')
@@ -98,8 +108,8 @@ def main():
     # add_whisper(channel)
     # add_facebook_hf_wav2vec2_asr(channel)
     # add_facebook_hf_wav2vec2_pipeline(channel)
+    add_nvidia(channel)
     connection.close()
-    # ['de', 'en', 'es', 'fr', 'it']
 
 
 if __name__ == '__main__':
--
GitLab
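
A minimal usage sketch for the new download_dataset command, once this patch is applied; the dataset path, config name, and cache directory below are illustrative assumptions, not values taken from this patch:

# Sketch: driving new_datasets/download_hf_dataset.py from Python.
# The --dataset_path / --dataset_name / --cache_dir values are assumptions.
import subprocess

subprocess.run([
    'python', 'new_datasets/download_hf_dataset.py',
    '--dataset_path', 'facebook/voxpopuli',  # assumed Hugging Face dataset id
    '--dataset_name', 'it',                  # assumed dataset configuration
    '--cache_dir', './hf_cache',             # assumed local cache directory
], check=True)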