Skip to content
Snippets Groups Projects
Select Git revision
  • 000a7941dca90b341b479843c60974d8d1ba954b
  • master default protected
  • vertical_relations
  • lu_without_semantic_frames
  • hierarchy
  • additional-unification-filters
  • v0.1.1
  • v0.1.0
  • v0.0.9
  • v0.0.8
  • v0.0.7
  • v0.0.6
  • v0.0.5
  • v0.0.4
  • v0.0.3
  • v0.0.2
  • v0.0.1
17 results

run-docker

Blame
  • call_experiment_stats.py 3.99 KiB
    from new_experiment.new_dependency_provider import get_experiment_repository
    from new_experiment.utils.get_spacy_model_name import get_spacy_model_name
    from new_experiment.utils.property_helper import PropertyHelper
    
    
    def get_stats_for(dataset_name: str, property_name: str) -> float:
        """Print and return the mean of a property's float values for a dataset.

        Only values that are ``float`` instances strictly between -2 and 10 take
        part in the mean; everything else stored under the property is ignored.

        Args:
            dataset_name: Name used to look up the experiment repository.
            property_name: Property whose per-record values are averaged.

        Returns:
            The arithmetic mean of the accepted values, or ``-1`` when no value
            passes the filter.
        """
        per_record = get_experiment_repository(dataset_name).get_all_values_from_property(property_name)
        accepted = []
        for record_id in per_record.keys():
            value = per_record[record_id]
            # Non-float entries and values outside the open range (-2, 10) are skipped.
            if isinstance(value, float) and -2 < value < 10:
                accepted.append(value)
        # -1 is the sentinel for "nothing to average".
        mean = sum(accepted) / len(accepted) if accepted else -1
        print(dataset_name, property_name, mean)
        return mean
    
    
    def get_stats_for_classic_wer(dataset_name: str, property_name: str) -> float:
        """Print and return the mean ``'classic_wer'`` entry of a property.

        Each record's value is expected to support ``in`` and item lookup; the
        records lacking a ``'classic_wer'`` key are skipped. Of the remaining
        entries only ``float`` instances strictly between -2 and 10 are averaged.

        Args:
            dataset_name: Name used to look up the experiment repository.
            property_name: Property holding the per-record metric containers.

        Returns:
            The arithmetic mean of the accepted entries, or ``-1`` when no entry
            passes the filters.
        """
        per_record = get_experiment_repository(dataset_name).get_all_values_from_property(property_name)
        accepted = []
        for record_id in per_record.keys():
            metrics = per_record[record_id]
            if 'classic_wer' not in metrics:
                continue
            score = metrics['classic_wer']
            # Non-float entries and values outside the open range (-2, 10) are skipped.
            if isinstance(score, float) and -2 < score < 10:
                accepted.append(score)
        # -1 is the sentinel for "nothing to average".
        mean = sum(accepted) / len(accepted) if accepted else -1
        print(dataset_name, property_name, mean)
        return mean
    
    
    def get_stats_for_soft_wer(dataset_name: str, property_name: str) -> float:
        """Print and return the mean ``'soft_wer'`` entry of a property.

        Records whose value lacks a ``'soft_wer'`` key are skipped. Of the
        remaining entries only ``float`` instances strictly between -2 and 10
        are averaged. The printed label is the property name with ``'_soft'``
        appended.

        Args:
            dataset_name: Name used to look up the experiment repository.
            property_name: Property holding the per-record metric containers.

        Returns:
            The arithmetic mean of the accepted entries, or ``-1`` when no entry
            passes the filters.
        """
        per_record = get_experiment_repository(dataset_name).get_all_values_from_property(property_name)
        accepted = []
        for record_id in per_record.keys():
            metrics = per_record[record_id]
            if 'soft_wer' not in metrics:
                continue
            score = metrics['soft_wer']
            # Non-float entries and values outside the open range (-2, 10) are skipped.
            if isinstance(score, float) and -2 < score < 10:
                accepted.append(score)
        # -1 is the sentinel for "nothing to average".
        mean = sum(accepted) / len(accepted) if accepted else -1
        print(dataset_name, property_name + '_soft', mean)
        return mean
    
    
    def get_stats_for_embedding_wer(dataset_name: str, property_name: str) -> float:
        """Print and return the mean ``'embedding_wer'`` entry of a property.

        The property is fetched record by record through the repository.
        Records whose value lacks an ``'embedding_wer'`` key are skipped, and
        only ``float`` entries are averaged; no range filter is applied here.
        The printed label is the property name with ``'_emb'`` appended.

        Args:
            dataset_name: Name used to look up the experiment repository.
            property_name: Property holding the per-record metric containers.

        Returns:
            The arithmetic mean of the accepted entries, or ``-1`` when no entry
            passes the filters.
        """
        repository = get_experiment_repository(dataset_name)
        accepted = []
        for record_id in repository.get_all_record_ids_for_property(property_name):
            metrics = repository.get_property_for_key(record_id, property_name)
            if 'embedding_wer' not in metrics:
                continue
            score = metrics['embedding_wer']
            if isinstance(score, float):
                accepted.append(score)
        # -1 is the sentinel for "nothing to average".
        mean = sum(accepted) / len(accepted) if accepted else -1
        print(dataset_name, property_name + '_emb', mean)
        return mean
    
    
    if __name__ == '__main__':
        # Script entry point: print the averaged metrics for every
        # (dataset, ASR model) combination, one metric family at a time.
        COMMANDS = ['run_word_wer_classic_pipeline', 'run_word_wer_embedding_pipeline', 'run_spacy_dep_tag_wer_pipeline',
                    'run_spacy_ner_wer_pipeline', 'run_spacy_pos_wer_pipeline']
        LANGUAGES = ['nl', 'fr', 'de', 'it', 'pl', 'es', 'en']
        WHISPER_ASR_MODEL = ['tiny', 'base', 'small', 'medium', 'large-v2']
        DATASETS = ['google_fleurs', 'minds14', 'voxpopuli']
        # Dataset names are "<language>_<dataset>", grouped language by language.
        FULL_DATASET_NAMES = [f'{language}_{corpus}' for language in LANGUAGES for corpus in DATASETS]

        FULL_LANGUAGE_MODELS = [f'whisper_{size}' for size in WHISPER_ASR_MODEL]

        # Each entry pairs a statistics function with a builder for the property
        # name it should read. The spaCy-based builders pass the first two
        # characters of the dataset name (its language prefix) to
        # get_spacy_model_name. Entries run in the order listed.
        REPORT_PASSES = (
            (get_stats_for,
             lambda ds, asr: PropertyHelper.ner_metrics(asr, get_spacy_model_name(ds[:2]))),
            (get_stats_for,
             lambda ds, asr: PropertyHelper.pos_metrics(asr, get_spacy_model_name(ds[:2]))),
            (get_stats_for,
             lambda ds, asr: PropertyHelper.dep_tag_metrics(asr, get_spacy_model_name(ds[:2]))),
            (get_stats_for_classic_wer,
             lambda ds, asr: PropertyHelper.word_wer_classic_metrics(asr)),
            (get_stats_for_soft_wer,
             lambda ds, asr: PropertyHelper.word_wer_embeddings_metrics(asr)),
            (get_stats_for_embedding_wer,
             lambda ds, asr: PropertyHelper.word_wer_embeddings_metrics(asr)),
        )

        for compute_stats, property_name_for in REPORT_PASSES:
            for dataset in FULL_DATASET_NAMES:
                for model in FULL_LANGUAGE_MODELS:
                    compute_stats(dataset, property_name_for(dataset, model))