# Source: call_experiment_stats.py (1.72 KiB), committed by Marcin Wątroba.
from new_experiment.new_dependency_provider import get_experiment_repository


def get_stats_for(dataset_name: str, property_name: str) -> float:
    """Print and return the mean of a property's float values for a dataset.

    Reads ``property_name`` for every record id returned by the repository's
    ``get_all_record_ids_for_property(property_name)`` and averages the values
    that are ``float`` instances; values of any other type are ignored.

    The raw value list and a ``dataset_name property_name mean`` summary line
    are printed to stdout.

    Args:
        dataset_name: Name passed to ``get_experiment_repository``.
        property_name: Name of the property to read and average.

    Returns:
        The arithmetic mean of the float values, or ``-1`` when there are no
        float values to average.
    """
    repo = get_experiment_repository(dataset_name)
    raw_values = [
        repo.get_property_for_key(record_id, property_name)
        for record_id in repo.get_all_record_ids_for_property(property_name)
    ]
    print(raw_values)
    float_values = [value for value in raw_values if isinstance(value, float)]
    # -1 is the "no data" sentinel; it also avoids dividing by zero.
    ret = sum(float_values) / len(float_values) if float_values else -1
    print(dataset_name, property_name, ret)
    # The signature promises a float; previously the function fell off the end
    # and implicitly returned None.
    return ret


def get_stats_for_classic_wer(dataset_name: str, property_name: str) -> float:
    """Return the mean ``classic_wer`` entry of a property across a dataset.

    Reads ``property_name`` for every record id returned by the repository's
    ``get_all_record_ids_for_property(property_name)``, keeps the
    ``'classic_wer'`` entry of each value that contains that key, and averages
    the entries that are ``float`` instances.

    Args:
        dataset_name: Name passed to ``get_experiment_repository``.
        property_name: Name of the property whose values hold the metric.

    Returns:
        The arithmetic mean of the float ``classic_wer`` entries, or ``-1.0``
        when there are none (rather than raising ``ZeroDivisionError``).
    """
    repo = get_experiment_repository(dataset_name)
    property_values = [
        repo.get_property_for_key(record_id, property_name)
        for record_id in repo.get_all_record_ids_for_property(property_name)
    ]
    metric_values = [value['classic_wer'] for value in property_values if 'classic_wer' in value]
    metric_values = [metric for metric in metric_values if isinstance(metric, float)]
    if not metric_values:
        # Nothing to average: use the -1 "no data" sentinel instead of dividing by zero.
        return -1.0
    return sum(metric_values) / len(metric_values)


def get_stats_for_soft_wer(dataset_name: str, property_name: str) -> float:
    """Return the mean ``soft_wer`` entry of a property across a dataset.

    Reads ``property_name`` for every record id returned by the repository's
    ``get_all_record_ids_for_property(property_name)``, keeps the
    ``'soft_wer'`` entry of each value that contains that key, and averages
    the entries that are ``float`` instances.

    Args:
        dataset_name: Name passed to ``get_experiment_repository``.
        property_name: Name of the property whose values hold the metric.

    Returns:
        The arithmetic mean of the float ``soft_wer`` entries, or ``-1.0``
        when there are none (rather than raising ``ZeroDivisionError``).
    """
    repo = get_experiment_repository(dataset_name)
    property_values = [
        repo.get_property_for_key(record_id, property_name)
        for record_id in repo.get_all_record_ids_for_property(property_name)
    ]
    metric_values = [value['soft_wer'] for value in property_values if 'soft_wer' in value]
    metric_values = [metric for metric in metric_values if isinstance(metric, float)]
    if not metric_values:
        # Nothing to average: use the -1 "no data" sentinel instead of dividing by zero.
        return -1.0
    return sum(metric_values) / len(metric_values)


def get_stats_for_embedding_wer(dataset_name: str, property_name: str) -> float:
    """Return the mean ``embedding_wer`` entry of a property across a dataset.

    Reads ``property_name`` for every record id returned by the repository's
    ``get_all_record_ids_for_property(property_name)``, keeps the
    ``'embedding_wer'`` entry of each value that contains that key, and
    averages the entries that are ``float`` instances.

    Args:
        dataset_name: Name passed to ``get_experiment_repository``.
        property_name: Name of the property whose values hold the metric.

    Returns:
        The arithmetic mean of the float ``embedding_wer`` entries, or
        ``-1.0`` when there are none (rather than raising
        ``ZeroDivisionError``).
    """
    repo = get_experiment_repository(dataset_name)
    property_values = [
        repo.get_property_for_key(record_id, property_name)
        for record_id in repo.get_all_record_ids_for_property(property_name)
    ]
    metric_values = [value['embedding_wer'] for value in property_values if 'embedding_wer' in value]
    metric_values = [metric for metric in metric_values if isinstance(metric, float)]
    if not metric_values:
        # Nothing to average: use the -1 "no data" sentinel instead of dividing by zero.
        return -1.0
    return sum(metric_values) / len(metric_values)