Skip to content
Snippets Groups Projects
Commit 1936693a authored by Marcin Wątroba
Browse files

download_dataset command

parent aa7b7f1e
No related merge requests found
import argparse
from typing import Optional
import datasets
def download_dataset(dataset_path: str, dataset_name: Optional[str], cache_dir: str):
    """Load a dataset with ``datasets.load_dataset`` and print what it returns.

    Args:
        dataset_path: Forwarded as the first positional argument of
            ``datasets.load_dataset``.
        dataset_name: Forwarded as the second positional argument of
            ``datasets.load_dataset``; may be ``None``.
        cache_dir: Forwarded as the ``cache_dir`` keyword argument.
    """
    print(datasets.load_dataset(dataset_path, dataset_name, cache_dir=cache_dir))
if __name__ == '__main__':
    # Command-line entry point: parse the three options and forward them to
    # download_dataset in the order its signature declares them.
    parser = argparse.ArgumentParser()
    # download_dataset passes this value as the first positional argument of
    # datasets.load_dataset, so fail early with a usage error when it is absent.
    parser.add_argument("--dataset_path", required=True)
    parser.add_argument("--dataset_name")
    parser.add_argument("--cache_dir")
    args = parser.parse_args()
    # argparse stores each option under its own name. The previous call read
    # args.dataset and args.asr, which are never defined (AttributeError), and
    # supplied only two of the three arguments download_dataset requires.
    download_dataset(args.dataset_path, args.dataset_name, args.cache_dir)
...@@ -9,14 +9,14 @@ from pika.adapters.blocking_connection import BlockingChannel ...@@ -9,14 +9,14 @@ from pika.adapters.blocking_connection import BlockingChannel
COMMANDS = ['run_word_wer_classic_pipeline', 'run_word_wer_embedding_pipeline', 'run_spacy_dep_tag_wer_pipeline', COMMANDS = ['run_word_wer_classic_pipeline', 'run_word_wer_embedding_pipeline', 'run_spacy_dep_tag_wer_pipeline',
'run_spacy_ner_wer_pipeline', 'run_spacy_pos_wer_pipeline'] 'run_spacy_ner_wer_pipeline', 'run_spacy_pos_wer_pipeline']
LANGUAGES = [ LANGUAGES = [
# 'nl', 'fr', 'de', 'nl', 'fr', 'de',
'it', 'it',
# 'pl', 'es', 'en' 'pl', 'es', 'en'
] ]
WHISPER_ASR_MODEL = ['tiny', 'base', 'small', 'medium', 'large-v2'] WHISPER_ASR_MODEL = ['tiny', 'base', 'small', 'medium', 'large-v2']
DATASETS = [ DATASETS = [
# 'google_fleurs', 'google_fleurs',
# 'minds14', 'minds14',
'voxpopuli' 'voxpopuli'
] ]
...@@ -90,6 +90,16 @@ def add_facebook_hf_wav2vec2_pipeline(channel: BlockingChannel): ...@@ -90,6 +90,16 @@ def add_facebook_hf_wav2vec2_pipeline(channel: BlockingChannel):
add_to_queue(dataset_name, asr_name, command, channel, 'asr_benchmark_experiments') add_to_queue(dataset_name, asr_name, command, channel, 'asr_benchmark_experiments')
def add_nvidia(channel: BlockingChannel):
    """Call ``add_to_queue`` for every NVIDIA model / dataset / command combination.

    For each language code in the fixed list the ASR name is
    ``nvidia_stt_<code>_conformer_transducer_large`` and the dataset name is
    ``<code>_<dataset>`` for each entry of ``DATASETS``. Every entry of
    ``COMMANDS`` is then submitted with the given ``channel`` and the literal
    ``'asr_benchmark_experiments'`` as the final argument.
    """
    # Flattened form of the language -> dataset -> command nesting; the
    # generator yields combinations in exactly that order.
    jobs = (
        (f'{code}_{dataset}', f'nvidia_stt_{code}_conformer_transducer_large', command)
        for code in ('de', 'en', 'es', 'fr', 'it')
        for dataset in DATASETS
        for command in COMMANDS
    )
    for dataset_name, asr_name, command in jobs:
        add_to_queue(dataset_name, asr_name, command, channel, 'asr_benchmark_experiments')
def main(): def main():
parameters = pika.URLParameters( parameters = pika.URLParameters(
'amqps://rabbit_user:kz6m4972OUHFmtUcPOHx4kF3Lj6yw7lo@rabbit-asr-benchmarks.theliver.pl:5671/') 'amqps://rabbit_user:kz6m4972OUHFmtUcPOHx4kF3Lj6yw7lo@rabbit-asr-benchmarks.theliver.pl:5671/')
...@@ -98,8 +108,8 @@ def main(): ...@@ -98,8 +108,8 @@ def main():
# add_whisper(channel) # add_whisper(channel)
# add_facebook_hf_wav2vec2_asr(channel) # add_facebook_hf_wav2vec2_asr(channel)
# add_facebook_hf_wav2vec2_pipeline(channel) # add_facebook_hf_wav2vec2_pipeline(channel)
add_nvidia(channel)
connection.close() connection.close()
# ['de', 'en', 'es', 'fr', 'it']
if __name__ == '__main__': if __name__ == '__main__':
......
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment