From 1936693a2226bd1e0cf235cd02f7caf06ce90be1 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Marcin=20W=C4=85troba?= <markowanga@gmail.com>
Date: Sun, 15 Jan 2023 01:26:25 +0100
Subject: [PATCH] download_dataset command

---
 new_datasets/download_hf_dataset.py     | 18 ++++++++++++++++++
 new_experiment/add_to_queue_pipeline.py | 20 +++++++++++++++-----
 2 files changed, 33 insertions(+), 5 deletions(-)
 create mode 100644 new_datasets/download_hf_dataset.py

diff --git a/new_datasets/download_hf_dataset.py b/new_datasets/download_hf_dataset.py
new file mode 100644
index 0000000..1f6effc
--- /dev/null
+++ b/new_datasets/download_hf_dataset.py
@@ -0,0 +1,18 @@
+import argparse
+from typing import Optional
+
+import datasets
+
+
+def download_dataset(dataset_path: str, dataset_name: Optional[str], cache_dir: str):
+    """Load `dataset_path` / `dataset_name` via `datasets.load_dataset` with `cache_dir`, then print the result."""
+    print(datasets.load_dataset(dataset_path, dataset_name, cache_dir=cache_dir))
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description="Load a dataset through download_dataset and print it.")
+    parser.add_argument("--dataset_path", required=True)  # mandatory: fail in argparse, not inside the loader
+    parser.add_argument("--dataset_name")  # optional; None when omitted
+    parser.add_argument("--cache_dir")  # optional; None when omitted
+    args = parser.parse_args()
+    download_dataset(args.dataset_path, args.dataset_name, args.cache_dir)  # attribute names mirror the flags
diff --git a/new_experiment/add_to_queue_pipeline.py b/new_experiment/add_to_queue_pipeline.py
index b87ca62..c6e2f36 100644
--- a/new_experiment/add_to_queue_pipeline.py
+++ b/new_experiment/add_to_queue_pipeline.py
@@ -9,14 +9,14 @@ from pika.adapters.blocking_connection import BlockingChannel
 COMMANDS = ['run_word_wer_classic_pipeline', 'run_word_wer_embedding_pipeline', 'run_spacy_dep_tag_wer_pipeline',
             'run_spacy_ner_wer_pipeline', 'run_spacy_pos_wer_pipeline']
 LANGUAGES = [
-    # 'nl', 'fr', 'de',
+    'nl', 'fr', 'de',
     'it',
-    # 'pl', 'es', 'en'
+    'pl', 'es', 'en'
 ]
 WHISPER_ASR_MODEL = ['tiny', 'base', 'small', 'medium', 'large-v2']
 DATASETS = [
-    # 'google_fleurs',
-    # 'minds14',
+    'google_fleurs',
+    'minds14',
     'voxpopuli'
 ]
 
@@ -90,6 +90,16 @@ def add_facebook_hf_wav2vec2_pipeline(channel: BlockingChannel):
             add_to_queue(dataset_name, asr_name, command, channel, 'asr_benchmark_experiments')
 
 
+def add_nvidia(channel: BlockingChannel):
+    """Enqueue every COMMANDS entry for each nvidia_stt_* model / DATASETS pair on 'asr_benchmark_experiments'."""
+    for lang in ['de', 'en', 'es', 'fr', 'it']:
+        model = f'nvidia_stt_{lang}_conformer_transducer_large'
+        for dataset in DATASETS:
+            target = f'{lang}_{dataset}'
+            for command in COMMANDS:
+                add_to_queue(target, model, command, channel, 'asr_benchmark_experiments')
+
+
 def main():
     parameters = pika.URLParameters(
         'amqps://rabbit_user:kz6m4972OUHFmtUcPOHx4kF3Lj6yw7lo@rabbit-asr-benchmarks.theliver.pl:5671/')
@@ -98,8 +108,8 @@ def main():
     # add_whisper(channel)
     # add_facebook_hf_wav2vec2_asr(channel)
     # add_facebook_hf_wav2vec2_pipeline(channel)
+    add_nvidia(channel)
     connection.close()
-    # ['de', 'en', 'es', 'fr', 'it']
 
 
 if __name__ == '__main__':
-- 
GitLab