Skip to content
Snippets Groups Projects
Commit f8e0ba07 authored by Marcin Wątroba's avatar Marcin Wątroba
Browse files

Add metric calculator

parent 6619ed5a
No related merge requests found
import datetime
import json
from typing import List
from typing import List, Dict
import pika
from minio import Minio
from pika.adapters.blocking_connection import BlockingChannel
from new_experiment.new_dependency_provider import get_experiment_repository
COMMANDS = [
'run_word_wer_classic_pipeline',
'run_word_wer_embedding_pipeline',
......@@ -45,6 +47,10 @@ def get_minio_client() -> Minio:
def add_to_queue(dataset: str, asr_name: str, task: str, channel: BlockingChannel, queue_name: str):
message_dict = {'dataset': dataset, 'asr_name': asr_name, 'task': task}
add_message_to_queue(message_dict, channel, queue_name)
def add_message_to_queue(message_dict: Dict[str, str], channel: BlockingChannel, queue_name: str):
print(datetime.datetime.now().isoformat(), f'{queue_name} {message_dict}')
message_bytes = json.dumps(message_dict).encode('utf-8')
channel.queue_declare(queue=queue_name, durable=True)
......@@ -116,6 +122,15 @@ def add_all(channel: BlockingChannel):
add_to_queue(dataset_name, asr_name, command, channel, 'asr_benchmark_experiments')
def add_calculate_metrics(channel: BlockingChannel):
for dataset_name in DATASETS:
repo = get_experiment_repository(dataset_name)
metric_properties = [it for it in repo.get_all_properties() if it.endswith('_metrics')]
for property_name in metric_properties:
add_message_to_queue({'dataset': dataset_name, 'metric_property': property_name}, channel,
'metric_stats_calculate')
def main():
parameters = pika.URLParameters(
'amqps://rabbit_user:kz6m4972OUHFmtUcPOHx4kF3Lj6yw7lo@rabbit-asr-benchmarks.theliver.pl:5671/')
......@@ -125,7 +140,7 @@ def main():
# add_facebook_hf_wav2vec2_asr(channel)
# add_facebook_hf_wav2vec2_pipeline(channel)
# add_nvidia(channel)
add_all(channel)
add_calculate_metrics(channel)
connection.close()
......
import argparse
import json
from call_experiment_stats import get_stats_for
from new_experiment.new_dependency_provider import get_experiment_repository
from new_experiment.pipeline.pipeline_process_flair_upos import run_flair_upos_multi_pipeline
from new_experiment.pipeline.pipeline_process_spacy_dep_tag_wer import run_spacy_dep_tag_wer_pipeline
from new_experiment.pipeline.pipeline_process_spacy_ner_wer import run_spacy_ner_wer_pipeline
from new_experiment.pipeline.pipeline_process_spacy_pos_wer import run_spacy_pos_wer_pipeline
from new_experiment.pipeline.pipeline_process_wikineural_ner_wer import run_wikineural_ner_pipeline
from new_experiment.pipeline.pipeline_process_word_classic_wer import run_word_wer_classic_pipeline
from new_experiment.pipeline.pipeline_process_word_embedding_wer import run_word_wer_embedding_pipeline
from new_experiment.queue_base import process_queue
def process_message(body: bytes):
print(body)
message_dict = json.loads(body.decode('utf-8'))
print(message_dict)
metric_property = message_dict['metric_property']
dataset = message_dict['dataset']
metric = get_stats_for(dataset, metric_property)
get_experiment_repository('metric_stats').update_property_for_key(record_id=metric_property, property_name=dataset,
property_value=metric)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--prefetch_count")
args = parser.parse_args()
process_queue('asr_benchmark_experiments', process_message,
1 if args.prefetch_count in [None, ''] else int(args.prefetch_count))
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment