diff --git a/dvc.yaml b/dvc.yaml index 4c9fc0b99643dc31ec32621d8d10e48723813ffe..43616eec06a1af84b9443e3b216c2042ed74da14 100644 --- a/dvc.yaml +++ b/dvc.yaml @@ -154,9 +154,9 @@ stages: - dataset: pl_minds14 asr: whisper_tiny do: - cmd: PYTHONPATH=. python experiment/pipeline_process_word_wer.py --dataset=${item.dataset} --asr=${item.asr} + cmd: PYTHONPATH=. python experiment/pipeline_process_word_classic_wer.py --dataset=${item.dataset} --asr=${item.asr} deps: - - experiment/pipeline_process_word_wer.py + - experiment/pipeline_process_word_classic_wer.py - experiment_data/dataset/${item.dataset} - experiment_data/pipeline/${item.dataset}/gold_transcript - experiment_data/pipeline/${item.dataset}/${item.asr}__result diff --git a/new_experiment/pipeline/dataset_importer/fleurs_dataset_importer.py b/new_experiment/pipeline/dataset_importer/fleurs_dataset_importer.py index 9b88c89112b4eef3fb620dbda4d34979a79ef133..11eaca6b778f72f5584de97c24317f3e42b87eee 100644 --- a/new_experiment/pipeline/dataset_importer/fleurs_dataset_importer.py +++ b/new_experiment/pipeline/dataset_importer/fleurs_dataset_importer.py @@ -27,4 +27,4 @@ class FleursDatasetImporter(HfDatasetImporter): return record['path'] def get_record_id(self, record: Dict[str, Any]) -> str: - return record["id"] + return str(record["id"]) diff --git a/new_experiment/pipeline/import_datasets.py b/new_experiment/pipeline/dataset_importer/import_datasets.py similarity index 100% rename from new_experiment/pipeline/import_datasets.py rename to new_experiment/pipeline/dataset_importer/import_datasets.py diff --git a/new_experiment/pipeline/import_fleurs.py b/new_experiment/pipeline/dataset_importer/import_fleurs.py similarity index 81% rename from new_experiment/pipeline/import_fleurs.py rename to new_experiment/pipeline/dataset_importer/import_fleurs.py index f08197d50132213a8e23b313b04893bb0dae0b81..a81c5f3044b25aebfaefbc7c77c6584de3856ece 100644 --- a/new_experiment/pipeline/import_fleurs.py +++ b/new_experiment/pipeline/dataset_importer/import_fleurs.py @@ -1,4 +1,4 @@ -from new_experiment.pipeline.import_datasets import import_fleurs_dataset +from new_experiment.pipeline.dataset_importer.import_datasets import import_fleurs_dataset if __name__ == '__main__': import_fleurs_dataset('nl_nl', 'nl_google_fleurs') diff --git a/new_experiment/pipeline/import_minds14.py b/new_experiment/pipeline/dataset_importer/import_minds14.py similarity index 80% rename from new_experiment/pipeline/import_minds14.py rename to new_experiment/pipeline/dataset_importer/import_minds14.py index 3501909ce2433d42d7932dd929453e333880007d..01871ad0d73168178139bd642c2c38eebace1bde 100644 --- a/new_experiment/pipeline/import_minds14.py +++ b/new_experiment/pipeline/dataset_importer/import_minds14.py @@ -1,4 +1,4 @@ -from new_experiment.pipeline.import_datasets import import_minds14_dataset +from new_experiment.pipeline.dataset_importer.import_datasets import import_minds14_dataset if __name__ == '__main__': import_minds14_dataset('nl-NL', 'nl_minds14') diff --git a/new_experiment/pipeline/dataset_importer/import_voxpopuli.py b/new_experiment/pipeline/dataset_importer/import_voxpopuli.py new file mode 100644 index 0000000000000000000000000000000000000000..7e4ba81b1fe85e7ef165daae1a5d17e408797aac --- /dev/null +++ b/new_experiment/pipeline/dataset_importer/import_voxpopuli.py @@ -0,0 +1,57 @@ +import json +from pathlib import Path +from typing import Any, List + +from nltk import RegexpTokenizer + +from new_experiment.new_dependency_provider import get_experiment_repository +from new_experiment.utils.property_helper import PropertyHelper +from sziszapangma.model.model_creators import create_new_word + + +# de_voxpopuli + +def get_words(raw: str) -> List[str]: + tokenizer = RegexpTokenizer(r'\w+') + return tokenizer.tokenize(raw) + + +def import_from_file(lang: str): + path = Path(f'/Users/marcinwatroba/Desktop/MY_PROJECTS/playground/librispeech/cache_items_{lang}_voxpopuli.jsonl') + with open(path, 'r') as reader: + dataset_name = f'{lang}_voxpopuli' + repo = get_experiment_repository(dataset_name) + for line in reader.read().splitlines(keepends=False): + it_dict = json.loads(line) + print(it_dict) + record_id = str(it_dict['audio_unique_id']) + raw_text = it_dict['raw_text'] + normalized_text_words = [create_new_word(it) for it in get_words(it_dict['normalized_text'])] + repo.update_property_for_key( + record_id=record_id, + property_name=PropertyHelper.get_gold_transcript_words(), + property_value=normalized_text_words + ) + repo.update_property_for_key( + record_id=record_id, + property_name=PropertyHelper.get_gold_transcript_raw(), + property_value={'gold_transcript_raw': raw_text} + ) + + +if __name__ == '__main__': + # import_voxpopuli_dataset('nl', 'nl_voxpopuli') + # import_voxpopuli_dataset('fr', 'fr_voxpopuli') + # import_voxpopuli_dataset('de', 'de_voxpopuli') + # import_voxpopuli_dataset('it', 'it_voxpopuli') + # import_voxpopuli_dataset('pl', 'pl_voxpopuli') + # import_voxpopuli_dataset('es', 'es_voxpopuli') + # import_voxpopuli_dataset('en', 'en_voxpopuli') + + import_from_file('nl') + import_from_file('fr') + import_from_file('de') + # import_from_file('it') + import_from_file('pl') + import_from_file('es') + import_from_file('en') diff --git a/new_experiment/pipeline/import_voxpopuli.py b/new_experiment/pipeline/import_voxpopuli.py deleted file mode 100644 index 1ecfb6d2b077e1315ae5c0f3fc2661de44d99dc9..0000000000000000000000000000000000000000 --- a/new_experiment/pipeline/import_voxpopuli.py +++ /dev/null @@ -1,10 +0,0 @@ -from new_experiment.pipeline.import_datasets import import_voxpopuli_dataset - -if __name__ == '__main__': - import_voxpopuli_dataset('nl', 'nl_voxpopuli') - import_voxpopuli_dataset('fr', 'fr_voxpopuli') - import_voxpopuli_dataset('de', 'de_voxpopuli') - import_voxpopuli_dataset('it', 'it_voxpopuli') - import_voxpopuli_dataset('pl', 'pl_voxpopuli') - import_voxpopuli_dataset('es', 'es_voxpopuli') - import_voxpopuli_dataset('en', 'en_voxpopuli') diff --git a/new_experiment/pipeline/pipeline_process_word_classic_wer.py b/new_experiment/pipeline/pipeline_process_word_classic_wer.py new file mode 100644 index 0000000000000000000000000000000000000000..9489f64a4244d48aa05f34b81876c5634e02c958 --- /dev/null +++ b/new_experiment/pipeline/pipeline_process_word_classic_wer.py @@ -0,0 +1,43 @@ +import argparse + +from experiment.const_pipeline_names import GOLD_TRANSCRIPT +from experiment.experiment_dependency_provider import get_record_provider, get_repository +from new_experiment.new_dependency_provider import get_experiment_repository, get_minio_audio_record_repository +from new_experiment.utils.loaded_remote_dataset_helper import LoadedRemoteDatasetHelper +from new_experiment.utils.property_helper import PropertyHelper +from sziszapangma.core.transformer.web_embedding_transformer import WebEmbeddingTransformer +from sziszapangma.integration.experiment_manager import ExperimentManager +from sziszapangma.integration.task.classic_wer_metric_task import ClassicWerMetricTask +from sziszapangma.integration.task.embedding_wer_metrics_task import EmbeddingWerMetricsTask + + +def run_word_wer_classic_pipeline(dataset_name: str, asr_name: str): + repository = get_experiment_repository(dataset_name) + experiment_processor = ExperimentManager( + record_id_iterator=LoadedRemoteDatasetHelper(repository, get_minio_audio_record_repository(), dataset_name), + processing_tasks=[ + ClassicWerMetricTask( + task_name=f'ClassicWerMetricTask___{dataset_name}___{asr_name}', + asr_property_name=PropertyHelper.asr_result(asr_name), + gold_transcript_property_name=PropertyHelper.get_gold_transcript_words(), + metrics_property_name=PropertyHelper.word_wer_classic_metrics(asr_name), + require_update=True, + alignment_property_name=PropertyHelper.word_wer_classic_alignment(asr_name) + ), + # EmbeddingWerMetricsTask( + # task_name='EmbeddingWerMetricsTask', + # asr_property_name=f'{asr_name}__result', + # gold_transcript_property_name=GOLD_TRANSCRIPT, + # metrics_property_name=f'{asr_name}__word_wer_embeddings_metrics', + # require_update=False, + # embedding_transformer=WebEmbeddingTransformer('pl', 'http://localhost:5003', 'fjsd-mkwe-oius-m9h2'), + # alignment_property_name=f'{asr_name}__word_wer_embeddings_alignment' + # ) + ], + experiment_repository=repository + ) + experiment_processor.process() + + +if __name__ == '__main__': + run_word_wer_classic_pipeline('de_google_fleurs', 'whisper_tiny') diff --git a/new_experiment/pipeline/pipeline_process_word_embedding_wer.py b/new_experiment/pipeline/pipeline_process_word_embedding_wer.py new file mode 100644 index 0000000000000000000000000000000000000000..071be0ee8ac566ab36ea6d10e757f00f292dc84b --- /dev/null +++ b/new_experiment/pipeline/pipeline_process_word_embedding_wer.py @@ -0,0 +1,36 @@ +import argparse + +from experiment.const_pipeline_names import GOLD_TRANSCRIPT +from experiment.experiment_dependency_provider import get_record_provider, get_repository +from new_experiment.new_dependency_provider import get_experiment_repository, get_minio_audio_record_repository +from new_experiment.utils.loaded_remote_dataset_helper import LoadedRemoteDatasetHelper +from new_experiment.utils.property_helper import PropertyHelper +from sziszapangma.core.transformer.fasttext_embedding_transformer import FasttextEmbeddingTransformer +from sziszapangma.core.transformer.web_embedding_transformer import WebEmbeddingTransformer +from sziszapangma.integration.experiment_manager import ExperimentManager +from sziszapangma.integration.task.classic_wer_metric_task import ClassicWerMetricTask +from sziszapangma.integration.task.embedding_wer_metrics_task import EmbeddingWerMetricsTask + + +def run_word_wer_classic_pipeline(dataset_name: str, asr_name: str): + repository = get_experiment_repository(dataset_name) + experiment_processor = ExperimentManager( + record_id_iterator=LoadedRemoteDatasetHelper(repository, get_minio_audio_record_repository(), dataset_name), + processing_tasks=[ + EmbeddingWerMetricsTask( + task_name='EmbeddingWerMetricsTask', + asr_property_name=PropertyHelper.asr_result(asr_name), + gold_transcript_property_name=PropertyHelper.get_gold_transcript_words(), + metrics_property_name=PropertyHelper.word_wer_embeddings_metrics(asr_name), + require_update=False, + embedding_transformer=FasttextEmbeddingTransformer(dataset_name[:2]), + alignment_property_name=PropertyHelper.word_wer_embeddings_alignment(asr_name) + ) + ], + experiment_repository=repository + ) + experiment_processor.process() + + +if __name__ == '__main__': + run_word_wer_classic_pipeline('de_google_fleurs', 'whisper_tiny') diff --git a/new_experiment/utils/loaded_remote_dataset_helper.py b/new_experiment/utils/loaded_remote_dataset_helper.py index 78dc8bf3a6d7a8ef0d2a5c3abd4317e000d950de..2ead6e3ce754e57bfbef2ed6d017fb65df5fc388 100644 --- a/new_experiment/utils/loaded_remote_dataset_helper.py +++ b/new_experiment/utils/loaded_remote_dataset_helper.py @@ -1,24 +1,22 @@ from pathlib import Path from typing import Set -from minio import Minio -from urllib3 import HTTPResponse - from experiment.dataset_helper import DatasetHelper -from new_experiment.utils.minio_audio_record_repository import MinioRecordRepository +from new_experiment.utils.minio_audio_record_repository import MinioAudioRecordRepository from new_experiment.utils.property_helper import PropertyHelper from sziszapangma.integration.repository.experiment_repository import ExperimentRepository class LoadedRemoteDatasetHelper(DatasetHelper): _experiment_repository: ExperimentRepository - _minio_record_repository: MinioRecordRepository + _minio_audio_record_repository: MinioAudioRecordRepository _dataset_name: str - def __init__(self, experiment_repository: ExperimentRepository, minio_record_repository: MinioRecordRepository, + def __init__(self, experiment_repository: ExperimentRepository, + minio_audio_record_repository: MinioAudioRecordRepository, dataset_name: str): self._experiment_repository = experiment_repository - self._minio_record_repository = minio_record_repository + self._minio_audio_record_repository = minio_audio_record_repository self._dataset_name = dataset_name def get_all_records(self) -> Set[str]: @@ -28,5 +26,5 @@ class LoadedRemoteDatasetHelper(DatasetHelper): record_path = Path.home() / f'.cache/asr_benchmark/{self._dataset_name}/{record_id}.wav' if record_path.exists(): return record_path.as_posix() - self._minio_record_repository.save_file(record_path, self._dataset_name, record_id) + self._minio_audio_record_repository.save_file(record_path, self._dataset_name, record_id) return record_path.as_posix() diff --git a/sziszapangma/core/transformer/fasttext_embedding_transformer.py b/sziszapangma/core/transformer/fasttext_embedding_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..b95b93cf6fb8d09767cb3a3cc78c86a1314a2c1e --- /dev/null +++ b/sziszapangma/core/transformer/fasttext_embedding_transformer.py @@ -0,0 +1,25 @@ +import json +from typing import Dict, List, Optional + +import numpy as np +import requests +from fasttext.FastText import _FastText + +from sziszapangma.core.transformer.embedding_transformer import EmbeddingTransformer +import fasttext.util + + +class FasttextEmbeddingTransformer(EmbeddingTransformer): + _lang_id: str + _model: _FastText + + def __init__(self, lang_id: str): + self._lang_id = lang_id + fasttext.util.download_model(lang_id, if_exists='ignore') + ft = fasttext.load_model(f'cc.{lang_id}.300.bin') + + def get_embedding(self, word: str) -> np.ndarray: + return self._model.get_word_vector(word) + + def get_embeddings(self, words: List[str]) -> Dict[str, np.ndarray]: + return {it: self.get_embedding(it) for it in words} diff --git a/sziszapangma/integration/experiment_manager.py b/sziszapangma/integration/experiment_manager.py index 5ecd58f4f11a9965958e91a4135532f393f9a9ed..14c2ba3b5a46d8c04eb2909b3a05161d2da0e73a 100644 --- a/sziszapangma/integration/experiment_manager.py +++ b/sziszapangma/integration/experiment_manager.py @@ -11,25 +11,21 @@ class ExperimentManager: _experiment_repository: ExperimentRepository _record_id_iterator: RecordIdIterator _processing_tasks: List[ProcessingTask] - _relation_manager_provider: RelationManagerProvider def __init__( self, experiment_repository: ExperimentRepository, record_id_iterator: RecordIdIterator, processing_tasks: List[ProcessingTask], - relation_manager_provider: RelationManagerProvider, ): self._experiment_repository = experiment_repository self._record_id_iterator = record_id_iterator self._processing_tasks = processing_tasks - self._relation_manager_provider = relation_manager_provider def process(self): self._experiment_repository.initialise() for processing_task in self._processing_tasks: processing_task.process( self._record_id_iterator, - self._experiment_repository, - self._relation_manager_provider, + self._experiment_repository ) diff --git a/sziszapangma/integration/repository/mongo_experiment_repository.py b/sziszapangma/integration/repository/mongo_experiment_repository.py index 6c87a1dc87b4e83c636b8006bebb5ac8ef3202c5..0e9fb114ca2f09115579281fbb81ed1a999284f5 100644 --- a/sziszapangma/integration/repository/mongo_experiment_repository.py +++ b/sziszapangma/integration/repository/mongo_experiment_repository.py @@ -24,9 +24,16 @@ class MongoExperimentRepository(ExperimentRepository): def property_exists(self, record_id: str, property_name: str) -> bool: database = self._get_database() all_collections = database.list_collection_names() + print(property_name, all_collections) if property_name not in all_collections: + print('collection not found') return False else: + print('self.get_all_record_ids_for_property(property_name)', record_id, record_id.__class__, + # record_id in self.get_all_record_ids_for_property(property_name), + # len(self.get_all_record_ids_for_property(property_name)), + list(self.get_all_record_ids_for_property(property_name))[0], + list(self.get_all_record_ids_for_property(property_name))[0].__class__) return database[property_name].find_one({ID: record_id}) is not None def update_property_for_key(self, record_id: str, property_name: str, property_value: Any): diff --git a/sziszapangma/integration/task/classic_wer_metric_task.py b/sziszapangma/integration/task/classic_wer_metric_task.py index 27ee08aeee9b6eeb7179d3c05947da47f49b1d91..af7a46bc233d90e47e5a0a492286aed452cf3792 100644 --- a/sziszapangma/integration/task/classic_wer_metric_task.py +++ b/sziszapangma/integration/task/classic_wer_metric_task.py @@ -1,3 +1,4 @@ +from pprint import pprint from typing import Any, Dict, List from sziszapangma.core.alignment.alignment_classic_calculator import AlignmentClassicCalculator @@ -47,9 +48,10 @@ class ClassicWerMetricTask(ProcessingTask): self, record_id: str, experiment_repository: ExperimentRepository, - relation_manager: RelationManager, ): - gold_transcript = TaskUtil.get_words_from_record(relation_manager) + print('#############') + gold_transcript = experiment_repository.get_property_for_key(record_id, self._gold_transcript_property_name) + print('$$$$$$$$$$$$$', gold_transcript) asr_result = experiment_repository.get_property_for_key(record_id, self._asr_property_name) if gold_transcript is not None and asr_result is not None and "transcription" in asr_result: alignment_steps = self._get_alignment(gold_transcript, asr_result["transcription"])