"id": "955a0385-29fb-47dc-b012-729e49570594",
"metadata": {},
"outputs": [],
"source": [
"from new_experiment.utils.get_spacy_model_name import *\n",
"\n",
"from call_experiment_stats import *\n",
"\n",
"from new_experiment.utils.property_helper import PropertyHelper\n",
"from new_experiment.utils.get_spacy_model_name import get_spacy_model_name\n",
"from new_experiment.new_dependency_provider import get_experiment_repository\n",
"from new_experiment.add_to_queue_pipeline import get_hf_facebook_wav2vec2_model_by_language_code\n",
"execution_count": null,
"id": "3f1221d3-5f70-4441-af07-58fa176e31e9",
"metadata": {},
"outputs": [],
"source": [
"METRICS_FILE = 'metrics.txt'"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "eda46e65-8079-40b9-9c4e-37fe74caec45",
"metadata": {},
"outputs": [],
"source": [
"metric_repository = get_experiment_repository('metric_stats')\n",
"with open(METRICS_FILE, 'w') as writer:\n",
" for dataset_property in metric_repository.get_all_properties():\n",
" values_dict = metric_repository.get_all_values_from_property(dataset_property)\n",
" for value_key in values_dict.keys():\n",
" line = f'{dataset_property} {value_key} {values_dict[value_key]}'\n",
" writer.write(f'{line}\\n')"
]
},
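{
"cell_type": "markdown",
"id": "metrics-dump-format-note",
"metadata": {},
"source": [
"A quick sanity check of the dump format (the cell id and `preview_line` name are illustrative): each line written above is `<dataset_property> <value_key> <value>`, space-separated, which is the layout the reader cell further down expects."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "metrics-dump-format-preview",
"metadata": {},
"outputs": [],
"source": [
"# Show the first few lines of the flat metrics dump to confirm the space-separated layout\n",
"with open(METRICS_FILE, 'r') as reader:\n",
"    for preview_line in reader.read().splitlines()[:5]:\n",
"        print(preview_line)"
]
},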
{
"cell_type": "code",
"execution_count": 15,
"id": "9f5e44a6-f211-4b61-8cb4-5636c7672c6a",
"metadata": {},
"outputs": [],
"source": [
"COMMANDS = ['run_word_wer_classic_pipeline', 'run_word_wer_embedding_pipeline', 'run_spacy_dep_tag_wer_pipeline',\n",
" 'run_spacy_ner_wer_pipeline', 'run_spacy_pos_wer_pipeline']\n",
"LANGUAGES = ['nl', 'fr', 'de', 'it', 'pl', 'es', 'en']\n",
"WHISPER_ASR_MODEL = ['tiny', 'base', 'small', 'medium', 'large-v2']\n",
"DATASETS = ['google_fleurs', 'minds14', 'voxpopuli']\n",
"FULL_DATASET_NAMES = []\n",
"for itt in LANGUAGES:\n",
" for it in DATASETS:\n",
" FULL_DATASET_NAMES.append(f'{itt}_{it}')\n",
"FULL_LANGUAGE_MODELS = [f'whisper_{it}' for it in WHISPER_ASR_MODEL] + ['facebook_wav2vec2', 'nvidia_stt']"
"id": "d2465ceb-7439-4fa5-adf8-e95d7e6106b9",
"metadata": {},
"outputs": [],
"source": [
"vals = dict()\n",
"with open(METRICS_FILE, 'r') as reader:\n",
" lines = reader.read().splitlines(keepends=False)\n",
" for line in lines:\n",
" # print(line)\n",
" words = line.split()\n",
" key = f'{words[0]}_{words[1]}'\n",
"execution_count": 12,
"id": "e41b19d0-37cb-4810-896a-fa0f73dd86e0",
"metadata": {},
"outputs": [],
"source": [
"def get_model_for_dataset_name(dataset: str, model: str):\n",
" language_code = dataset[:2]\n",
" if model.startswith('whisper'):\n",
" return model\n",
" elif model.startswith('facebook_wav2vec2'):\n",
" return get_hf_facebook_wav2vec2_model_by_language_code(language_code)\n",
" elif model.startswith('nvidia_stt'):\n",
" return f'nvidia_stt_{language_code}_conformer_transducer_large'\n",
" else:\n",
" raise Exception('asr name not found')"
]
},
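{
"cell_type": "markdown",
"id": "model-name-mapping-note",
"metadata": {},
"source": [
"A minimal illustration of the mapping above (cell id and example arguments are illustrative): Whisper names pass through unchanged, the NVIDIA STT name is expanded with the dataset's two-letter language prefix, and the wav2vec2 name is resolved through `get_hf_facebook_wav2vec2_model_by_language_code`, whose return value depends on the language."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "model-name-mapping-example",
"metadata": {},
"outputs": [],
"source": [
"# Whisper model names are language-agnostic, so they are returned unchanged\n",
"print(get_model_for_dataset_name('nl_google_fleurs', 'whisper_tiny'))\n",
"# The NVIDIA STT name embeds the language code taken from the dataset prefix\n",
"print(get_model_for_dataset_name('de_voxpopuli', 'nvidia_stt'))\n",
"# The wav2vec2 name comes from the helper imported at the top of the notebook\n",
"print(get_model_for_dataset_name('fr_minds14', 'facebook_wav2vec2'))"
]
},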
{
"cell_type": "code",
"execution_count": 24,
"id": "22d84451-b7e3-4dba-9758-068dae23ace4",
"metadata": {},
"outputs": [],
"source": [
"spacy_ner = [\n",
" [vals.get(f'{dataset}_{PropertyHelper.ner_metrics(get_model_for_dataset_name(dataset, model), get_spacy_model_name(dataset[:2]))}', -1.0) \n",
" for model in FULL_LANGUAGE_MODELS]\n",
" for dataset in FULL_DATASET_NAMES\n",
"]\n",
"spacy_pos = [\n",
" [vals.get(f'{dataset}_{PropertyHelper.pos_metrics(get_model_for_dataset_name(dataset, model), get_spacy_model_name(dataset[:2]))}', -1.0) \n",
" for model in FULL_LANGUAGE_MODELS]\n",
" for dataset in FULL_DATASET_NAMES\n",
"]\n",
"spacy_dep = [\n",
" [vals.get(f'{dataset}_{PropertyHelper.pos_metrics(get_model_for_dataset_name(dataset, model), get_spacy_model_name(dataset[:2]))}', -1.0) \n",
" for model in FULL_LANGUAGE_MODELS]\n",
" for dataset in FULL_DATASET_NAMES\n",
"]\n",
"word_wer_classic_metrics = [\n",
" [vals.get(f'{dataset}_{PropertyHelper.word_wer_classic_metrics(get_model_for_dataset_name(dataset, model))}', -1.0) for model in FULL_LANGUAGE_MODELS]\n",
" for dataset in FULL_DATASET_NAMES\n",
"]\n",
"word_wer_soft_metrics = [\n",
" [vals.get(f'{dataset}_{PropertyHelper.word_wer_soft_metrics(get_model_for_dataset_name(dataset, model))}', -1.0) for model in FULL_LANGUAGE_MODELS]\n",
" for dataset in FULL_DATASET_NAMES\n",
"]\n",
"word_wer_embedding_metrics = [\n",
" [vals.get(f'{dataset}_{PropertyHelper.word_wer_embeddings_metrics(get_model_for_dataset_name(dataset, model))}', -1.0) for model in FULL_LANGUAGE_MODELS]\n",
" for dataset in FULL_DATASET_NAMES\n",
"]\n",
"flair_pos = [\n",
" [vals.get(f'{dataset}_{PropertyHelper.word_wer_embeddings_metrics(get_model_for_dataset_name(dataset, model))}', -1.0) for model in FULL_LANGUAGE_MODELS]\n",
" for dataset in FULL_DATASET_NAMES\n",
"]"
"id": "45fd851c-644f-48e6-b711-5bd312404b8b",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>whisper_tiny</th>\n",
" <th>whisper_base</th>\n",
" <th>whisper_small</th>\n",
" <th>whisper_medium</th>\n",
" <th>whisper_large-v2</th>\n",
" <th>facebook_wav2vec2</th>\n",
" <th>nvidia_stt</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>nl_google_fleurs</th>\n",
" <td>0.316124</td>\n",
" <td>0.230845</td>\n",
" <td>0.186936</td>\n",
" <td>0.170150</td>\n",
" <td>0.165057</td>\n",
" <td>0.082781</td>\n",
" <td>-1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>nl_minds14</th>\n",
" <td>0.463084</td>\n",
" <td>0.409993</td>\n",
" <td>0.360934</td>\n",
" <td>0.331613</td>\n",
" <td>0.324172</td>\n",
" <td>0.142155</td>\n",
" <td>-1.000000</td>\n",