diff --git a/new_experiment/add_to_queue_pipeline.py b/new_experiment/add_to_queue_pipeline.py
index c8d5253f9d877124c9ed22ff63d1d89396f49382..edb3444d465604c3699fce0a53bf0ee3b7dbe758 100644
--- a/new_experiment/add_to_queue_pipeline.py
+++ b/new_experiment/add_to_queue_pipeline.py
@@ -1,11 +1,13 @@
 import datetime
 import json
-from typing import List
+from typing import List, Dict
 
 import pika
 from minio import Minio
 from pika.adapters.blocking_connection import BlockingChannel
 
+from new_experiment.new_dependency_provider import get_experiment_repository
+
 COMMANDS = [
     'run_word_wer_classic_pipeline',
     'run_word_wer_embedding_pipeline',
@@ -45,6 +47,10 @@ def get_minio_client() -> Minio:
 
 def add_to_queue(dataset: str, asr_name: str, task: str, channel: BlockingChannel, queue_name: str):
     message_dict = {'dataset': dataset, 'asr_name': asr_name, 'task': task}
+    add_message_to_queue(message_dict, channel, queue_name)
+
+
+def add_message_to_queue(message_dict: Dict[str, str], channel: BlockingChannel, queue_name: str):
     print(datetime.datetime.now().isoformat(), f'{queue_name} {message_dict}')
     message_bytes = json.dumps(message_dict).encode('utf-8')
     channel.queue_declare(queue=queue_name, durable=True)
@@ -116,6 +122,15 @@ def add_all(channel: BlockingChannel):
             add_to_queue(dataset_name, asr_name, command, channel, 'asr_benchmark_experiments')
 
 
+def add_calculate_metrics(channel: BlockingChannel):
+    for dataset_name in DATASETS:
+        repo = get_experiment_repository(dataset_name)
+        metric_properties = [it for it in repo.get_all_properties() if it.endswith('_metrics')]
+        for property_name in metric_properties:
+            add_message_to_queue({'dataset': dataset_name, 'metric_property': property_name}, channel,
+                                 'metric_stats_calculate')
+
+
 def main():
     parameters = pika.URLParameters(
         'amqps://rabbit_user:kz6m4972OUHFmtUcPOHx4kF3Lj6yw7lo@rabbit-asr-benchmarks.theliver.pl:5671/')
@@ -125,7 +140,7 @@ def main():
     # add_facebook_hf_wav2vec2_asr(channel)
     # add_facebook_hf_wav2vec2_pipeline(channel)
     # add_nvidia(channel)
-    add_all(channel)
+    add_calculate_metrics(channel)
     connection.close()
 
 
diff --git a/new_experiment/worker_metric_stats.py b/new_experiment/worker_metric_stats.py
new file mode 100644
index 0000000000000000000000000000000000000000..e94cddd2c059ee3b9ef150a070612130abfbc16a
--- /dev/null
+++ b/new_experiment/worker_metric_stats.py
@@ -0,0 +1,35 @@
+import argparse
+import json
+
+from call_experiment_stats import get_stats_for
+from new_experiment.new_dependency_provider import get_experiment_repository
+from new_experiment.pipeline.pipeline_process_flair_upos import run_flair_upos_multi_pipeline
+from new_experiment.pipeline.pipeline_process_spacy_dep_tag_wer import run_spacy_dep_tag_wer_pipeline
+from new_experiment.pipeline.pipeline_process_spacy_ner_wer import run_spacy_ner_wer_pipeline
+from new_experiment.pipeline.pipeline_process_spacy_pos_wer import run_spacy_pos_wer_pipeline
+from new_experiment.pipeline.pipeline_process_wikineural_ner_wer import run_wikineural_ner_pipeline
+from new_experiment.pipeline.pipeline_process_word_classic_wer import run_word_wer_classic_pipeline
+from new_experiment.pipeline.pipeline_process_word_embedding_wer import run_word_wer_embedding_pipeline
+from new_experiment.queue_base import process_queue
+
+
+def process_message(body: bytes):
+    print(body)
+    message_dict = json.loads(body.decode('utf-8'))
+    print(message_dict)
+
+    metric_property = message_dict['metric_property']
+    dataset = message_dict['dataset']
+
+    metric = get_stats_for(dataset, metric_property)
+
+    get_experiment_repository('metric_stats').update_property_for_key(record_id=metric_property, property_name=dataset,
+                                                                      property_value=metric)
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--prefetch_count")
+    args = parser.parse_args()
+    process_queue('metric_stats_calculate', process_message,
+                  1 if args.prefetch_count in [None, ''] else int(args.prefetch_count))
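For reference (not part of the change above): a message on the new 'metric_stats_calculate' queue carries only a dataset name and the name of a '*_metrics' property. The sketch below mirrors what add_message_to_queue does for a single message; the broker URL and the example values are placeholders, not values taken from the repository.

import json

import pika

# Placeholder broker URL -- substitute the real credentials used in add_to_queue_pipeline.py.
connection = pika.BlockingConnection(pika.URLParameters('amqps://user:password@rabbit-host:5671/'))
channel = connection.channel()

# Hypothetical example values; real messages are produced by add_calculate_metrics().
message = {'dataset': 'example_dataset', 'metric_property': 'example_asr_word_metrics'}

channel.queue_declare(queue='metric_stats_calculate', durable=True)
channel.basic_publish(exchange='',
                      routing_key='metric_stats_calculate',
                      body=json.dumps(message).encode('utf-8'),
                      properties=pika.BasicProperties(delivery_mode=2))  # persistent message, matching the durable queue
connection.close()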