Skip to content
Snippets Groups Projects
Commit c96708a5 authored by Marcin Wątroba's avatar Marcin Wątroba
Browse files

Add worker for pipeline

parent d1783291
Branches
No related merge requests found
import datetime
import json
from typing import List, Optional

import pika
from minio import Minio
from pika.adapters.blocking_connection import BlockingChannel
# Pipeline task names understood by the worker; each one is published as the
# 'task' field of a queue message and dispatched by the consumer.
COMMANDS = ['run_word_wer_classic_pipeline', 'run_word_wer_embedding_pipeline', 'run_spacy_dep_tag_wer_pipeline',
            'run_spacy_ner_wer_pipeline', 'run_spacy_pos_wer_pipeline']
# ISO 639-1 language codes covered by the benchmark.
LANGUAGES = ['nl', 'fr', 'de', 'it', 'pl', 'es', 'en']
# Whisper model size variants; turned into ASR names as 'whisper_<variant>'.
WHISPER_ASR_MODEL = ['tiny', 'base', 'small', 'medium', 'large-v2']
# Source corpora; combined with LANGUAGES into names like 'pl_voxpopuli'.
DATASETS = ['google_fleurs', 'minds14', 'voxpopuli']
def get_all_datasets() -> List[str]:
    """Return the names of all datasets stored in the 'dataset-audio' bucket.

    MinIO returns top-level prefixes with a trailing '/', which is stripped.
    """
    names = []
    for entry in get_minio_client().list_objects('dataset-audio', ''):
        # entry.object_name looks like 'pl_voxpopuli/' — drop the slash.
        names.append(entry.object_name[:-1])
    return names
def get_dataset_items_id(dataset_name: str) -> List[str]:
    """Return the record ids (file name stems) of every object under dataset_name/."""
    record_ids = []
    for entry in get_minio_client().list_objects('dataset-audio', f'{dataset_name}/'):
        # 'pl_voxpopuli/1234.wav' -> '1234.wav' -> '1234'
        file_name = entry.object_name.split('/')[-1]
        record_ids.append(file_name.split('.')[0])
    return record_ids
def get_minio_client() -> Minio:
    """Build a MinIO client for the benchmark object store.

    The endpoint and credentials may be overridden via the MINIO_HOST,
    MINIO_ACCESS_KEY and MINIO_SECRET_KEY environment variables; the defaults
    reproduce the previous hard-coded behaviour exactly.

    NOTE(review): the secret key is committed in source control — it should be
    rotated and removed from the repository.
    """
    import os  # local import: avoids touching the module's import block
    return Minio(
        os.environ.get('MINIO_HOST', 'minio-asr-benchmarks.theliver.pl'),
        os.environ.get('MINIO_ACCESS_KEY', 'minio_user'),
        os.environ.get('MINIO_SECRET_KEY', 'eUxzEQbyYPdzrLxuvvethSbk18kB2s7G'),
    )
def add_to_queue(dataset: str, asr_name: str, task: str, channel: BlockingChannel):
    """Publish one persistent experiment task onto the RabbitMQ work queue."""
    queue_name = 'asr_benchmark_experiments'
    payload = {'dataset': dataset, 'asr_name': asr_name, 'task': task}
    print(datetime.datetime.now().isoformat(), f'{queue_name} {payload}')
    # Declare the queue as durable so it survives broker restarts.
    channel.queue_declare(queue=queue_name, durable=True)
    properties = pika.BasicProperties(delivery_mode=pika.spec.PERSISTENT_DELIVERY_MODE)
    channel.basic_publish(
        exchange='',
        routing_key=queue_name,
        body=json.dumps(payload).encode('utf-8'),
        properties=properties,
    )
def get_all_datasets_with_language(languages: Optional[List[str]] = None,
                                   datasets: Optional[List[str]] = None) -> List[str]:
    """Return every '<language>_<dataset>' combination.

    Args:
        languages: language codes to combine; defaults to the module-level
            LANGUAGES list (the original hard-coded behaviour).
        datasets: dataset names to combine; defaults to the module-level
            DATASETS list.

    Returns:
        Names like 'pl_voxpopuli', ordered language-major, dataset-minor.
    """
    languages = LANGUAGES if languages is None else languages
    datasets = DATASETS if datasets is None else datasets
    # Comprehension replaces the original nested append loop (same order).
    return [f'{language}_{dataset}' for language in languages for dataset in datasets]
def add_whisper(channel: BlockingChannel):
    """Enqueue every Whisper-variant x dataset x pipeline-command combination."""
    for variant in WHISPER_ASR_MODEL:
        asr = f'whisper_{variant}'
        for dataset in get_all_datasets_with_language():
            for command in COMMANDS:
                add_to_queue(dataset, asr, command, channel)
def main():
    """Connect to RabbitMQ and enqueue all Whisper benchmark tasks.

    The connection is closed in a finally block so the broker socket is not
    leaked if publishing raises (the original left it open on error).

    NOTE(review): the AMQP URL embeds credentials in source control — they
    should be rotated and moved to configuration.
    """
    amqp_url = 'amqps://rabbit_user:kz6m4972OUHFmtUcPOHx4kF3Lj6yw7lo@rabbit-asr-benchmarks.theliver.pl:5671/'
    connection = pika.BlockingConnection(parameters=pika.URLParameters(amqp_url))
    try:
        add_whisper(connection.channel())
    finally:
        connection.close()


if __name__ == '__main__':
    main()
......@@ -25,7 +25,7 @@ def run_spacy_dep_tag_wer_pipeline(dataset_name: str, asr_name: str):
alignment_property_name=PropertyHelper.dep_tag_alignment(asr_name, model_name),
wer_property_name=PropertyHelper.dep_tag_metrics(asr_name, model_name),
task_name=f'SpacyDepTagSentenceWerProcessor___{dataset_name}___{asr_name}',
require_update=False
require_update=True
)
],
experiment_repository=repository,
......
......@@ -25,7 +25,7 @@ def run_spacy_ner_wer_pipeline(dataset_name: str, asr_name: str):
alignment_property_name=PropertyHelper.ner_alignment(asr_name, model_name),
wer_property_name=PropertyHelper.ner_metrics(asr_name, model_name),
task_name=f'SpacyNerSentenceWerProcessor___{dataset_name}___{asr_name}',
require_update=False
require_update=True
)
],
experiment_repository=repository,
......
import json
import os
import uuid
import pika
from minio import Minio
from pymongo import MongoClient
from urllib3 import HTTPResponse
from new_datasets.whisper_processor import WhisperAsrProcessor
from new_experiment.pipeline.pipeline_process_spacy_dep_tag_wer import run_spacy_dep_tag_wer_pipeline
from new_experiment.pipeline.pipeline_process_spacy_ner_wer import run_spacy_ner_wer_pipeline
from new_experiment.pipeline.pipeline_process_spacy_pos_wer import run_spacy_pos_wer_pipeline
from new_experiment.pipeline.pipeline_process_word_classic_wer import run_word_wer_classic_pipeline
from new_experiment.pipeline.pipeline_process_word_embedding_wer import run_word_wer_embedding_pipeline
from sziszapangma.integration.repository.mongo_experiment_repository import MongoExperimentRepository
def get_param(name: str, default: str) -> str:
......@@ -39,7 +33,6 @@ def main():
task = message_dict['task']
dataset = message_dict['dataset']
asr_name = message_dict['asr_name']
if task == 'run_word_wer_classic_pipeline':
run_word_wer_classic_pipeline(dataset, asr_name)
elif task == 'run_word_wer_embedding_pipeline':
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment