Skip to content
Snippets Groups Projects
Commit f5b5ac00 authored by Marcin Wątroba's avatar Marcin Wątroba
Browse files

Add stats

parent dea0a8df
Branches
No related merge requests found
from new_experiment.new_dependency_provider import get_experiment_repository
def get_stats_for(dataset_name: str, property_name: str) -> float:
repo = get_experiment_repository(dataset_name)
vals = [repo.get_property_for_key(it, property_name) for it in repo.get_all_record_ids_for_property(property_name)]
print(vals)
vals = [it for it in vals if isinstance(it, float)]
ret = 0.0
if len(vals) == 0:
ret = -1
else:
ret = sum(vals) / len(vals)
print(dataset_name, property_name, ret)
def get_stats_for_classic_wer(dataset_name: str, property_name: str) -> float:
repo = get_experiment_repository(dataset_name)
vals = [repo.get_property_for_key(it, property_name) for it in repo.get_all_record_ids_for_property(property_name)]
vals = [it['classic_wer'] for it in vals if 'classic_wer' in it]
vals = [it for it in vals if isinstance(it, float)]
return sum(vals) / len(vals)
def get_stats_for_soft_wer(dataset_name: str, property_name: str) -> float:
repo = get_experiment_repository(dataset_name)
vals = [repo.get_property_for_key(it, property_name) for it in repo.get_all_record_ids_for_property(property_name)]
vals = [it['soft_wer'] for it in vals if 'soft_wer' in it]
vals = [it for it in vals if isinstance(it, float)]
return sum(vals) / len(vals)
def get_stats_for_embedding_wer(dataset_name: str, property_name: str) -> float:
repo = get_experiment_repository(dataset_name)
vals = [repo.get_property_for_key(it, property_name) for it in repo.get_all_record_ids_for_property(property_name)]
vals = [it['embedding_wer'] for it in vals if 'embedding_wer' in it]
vals = [it for it in vals if isinstance(it, float)]
return sum(vals) / len(vals)
import spacy
def get_spacy_model_name(language_code_2_letter: str) -> str:
if language_code_2_letter == 'en':
return 'en_core_web_lg'
return f'{language_code_2_letter}_core_news_lg'
if __name__ == '__main__':
spacy.load('en_core_web_lg')
This diff is collapsed.
......@@ -278,7 +278,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.13"
"version": "3.9.15"
}
},
"nbformat": 4,
......
......@@ -33,6 +33,7 @@ scipy = "^1.10.0"
pika = "^1.3.1"
pyopenssl = "^23.0.0"
nltk = "^3.8.1"
jupyterlab = "^3.5.2"
[tool.poetry.group.dev.dependencies]
pytest = "^7.2.0"
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment