From e407a441e9775b6e8bd29165c0f886b092540f51 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Marcin=20W=C4=85troba?= <markowanga@gmail.com>
Date: Thu, 12 Jan 2023 01:48:09 +0100
Subject: [PATCH] Add pipe

---
 dvc.yaml                                       |  4 +-
 .../fleurs_dataset_importer.py                 |  2 +-
 .../{ => dataset_importer}/import_datasets.py  |  0
 .../{ => dataset_importer}/import_fleurs.py    |  2 +-
 .../{ => dataset_importer}/import_minds14.py   |  2 +-
 .../dataset_importer/import_voxpopuli.py       | 57 +++++++++++++++++++
 new_experiment/pipeline/import_voxpopuli.py    | 10 ----
 .../pipeline_process_word_classic_wer.py       | 43 ++++++++++++++
 .../pipeline_process_word_embedding_wer.py     | 36 ++++++++++++
 .../utils/loaded_remote_dataset_helper.py      | 14 ++---
 .../fasttext_embedding_transformer.py          | 25 ++++
 .../integration/experiment_manager.py          |  6 +-
 .../repository/mongo_experiment_repository.py  |  7 +++
 .../task/classic_wer_metric_task.py            |  6 +-
 14 files changed, 184 insertions(+), 30 deletions(-)
 rename new_experiment/pipeline/{ => dataset_importer}/import_datasets.py (100%)
 rename new_experiment/pipeline/{ => dataset_importer}/import_fleurs.py (81%)
 rename new_experiment/pipeline/{ => dataset_importer}/import_minds14.py (80%)
 create mode 100644 new_experiment/pipeline/dataset_importer/import_voxpopuli.py
 delete mode 100644 new_experiment/pipeline/import_voxpopuli.py
 create mode 100644 new_experiment/pipeline/pipeline_process_word_classic_wer.py
 create mode 100644 new_experiment/pipeline/pipeline_process_word_embedding_wer.py
 create mode 100644 sziszapangma/core/transformer/fasttext_embedding_transformer.py

diff --git a/dvc.yaml b/dvc.yaml
index 4c9fc0b..43616ee 100644
--- a/dvc.yaml
+++ b/dvc.yaml
@@ -154,9 +154,9 @@ stages:
       - dataset: pl_minds14
         asr: whisper_tiny
     do:
-      cmd: PYTHONPATH=. python experiment/pipeline_process_word_wer.py --dataset=${item.dataset} --asr=${item.asr}
+      cmd: PYTHONPATH=. python experiment/pipeline_process_word_classic_wer.py --dataset=${item.dataset} --asr=${item.asr}
       deps:
-        - experiment/pipeline_process_word_wer.py
+        - experiment/pipeline_process_word_classic_wer.py
         - experiment_data/dataset/${item.dataset}
         - experiment_data/pipeline/${item.dataset}/gold_transcript
         - experiment_data/pipeline/${item.dataset}/${item.asr}__result
diff --git a/new_experiment/pipeline/dataset_importer/fleurs_dataset_importer.py b/new_experiment/pipeline/dataset_importer/fleurs_dataset_importer.py
index 9b88c89..11eaca6 100644
--- a/new_experiment/pipeline/dataset_importer/fleurs_dataset_importer.py
+++ b/new_experiment/pipeline/dataset_importer/fleurs_dataset_importer.py
@@ -27,4 +27,4 @@ class FleursDatasetImporter(HfDatasetImporter):
         return record['path']
 
     def get_record_id(self, record: Dict[str, Any]) -> str:
-        return record["id"]
+        return str(record["id"])
diff --git a/new_experiment/pipeline/import_datasets.py b/new_experiment/pipeline/dataset_importer/import_datasets.py
similarity index 100%
rename from new_experiment/pipeline/import_datasets.py
rename to new_experiment/pipeline/dataset_importer/import_datasets.py
diff --git a/new_experiment/pipeline/import_fleurs.py b/new_experiment/pipeline/dataset_importer/import_fleurs.py
similarity index 81%
rename from new_experiment/pipeline/import_fleurs.py
rename to new_experiment/pipeline/dataset_importer/import_fleurs.py
index f08197d..a81c5f3 100644
--- a/new_experiment/pipeline/import_fleurs.py
+++ b/new_experiment/pipeline/dataset_importer/import_fleurs.py
@@ -1,4 +1,4 @@
-from new_experiment.pipeline.import_datasets import import_fleurs_dataset
+from new_experiment.pipeline.dataset_importer.import_datasets import import_fleurs_dataset
 
 if __name__ == '__main__':
     import_fleurs_dataset('nl_nl', 'nl_google_fleurs')
diff --git a/new_experiment/pipeline/import_minds14.py b/new_experiment/pipeline/dataset_importer/import_minds14.py
similarity index 80%
rename from new_experiment/pipeline/import_minds14.py
rename to new_experiment/pipeline/dataset_importer/import_minds14.py
index 3501909..01871ad 100644
--- a/new_experiment/pipeline/import_minds14.py
+++ b/new_experiment/pipeline/dataset_importer/import_minds14.py
@@ -1,4 +1,4 @@
-from new_experiment.pipeline.import_datasets import import_minds14_dataset
+from new_experiment.pipeline.dataset_importer.import_datasets import import_minds14_dataset
 
 if __name__ == '__main__':
     import_minds14_dataset('nl-NL', 'nl_minds14')
diff --git a/new_experiment/pipeline/dataset_importer/import_voxpopuli.py b/new_experiment/pipeline/dataset_importer/import_voxpopuli.py
new file mode 100644
index 0000000..7e4ba81
--- /dev/null
+++ b/new_experiment/pipeline/dataset_importer/import_voxpopuli.py
@@ -0,0 +1,57 @@
+import json
+from pathlib import Path
+from typing import Any, List
+
+from nltk import RegexpTokenizer
+
+from new_experiment.new_dependency_provider import get_experiment_repository
+from new_experiment.utils.property_helper import PropertyHelper
+from sziszapangma.model.model_creators import create_new_word
+
+
+# de_voxpopuli
+
+def get_words(raw: str) -> List[str]:
+    tokenizer = RegexpTokenizer(r'\w+')
+    return tokenizer.tokenize(raw)
+
+
+def import_from_file(lang: str):
+    path = Path(f'/Users/marcinwatroba/Desktop/MY_PROJECTS/playground/librispeech/cache_items_{lang}_voxpopuli.jsonl')
+    with open(path, 'r') as reader:
+        dataset_name = f'{lang}_voxpopuli'
+        repo = get_experiment_repository(dataset_name)
+        for line in reader.read().splitlines(keepends=False):
+            it_dict = json.loads(line)
+            print(it_dict)
+            record_id = str(it_dict['audio_unique_id'])
+            raw_text = it_dict['raw_text']
+            normalized_text_words = [create_new_word(it) for it in get_words(it_dict['normalized_text'])]
+            repo.update_property_for_key(
+                record_id=record_id,
+                property_name=PropertyHelper.get_gold_transcript_words(),
+                property_value=normalized_text_words
+            )
+            repo.update_property_for_key(
+                record_id=record_id,
+                property_name=PropertyHelper.get_gold_transcript_raw(),
+                property_value={'gold_transcript_raw': raw_text}
+            )
+
+
+if __name__ == '__main__':
+    # import_voxpopuli_dataset('nl', 'nl_voxpopuli')
+    # import_voxpopuli_dataset('fr', 'fr_voxpopuli')
+    # import_voxpopuli_dataset('de', 'de_voxpopuli')
+    # import_voxpopuli_dataset('it', 'it_voxpopuli')
+    # import_voxpopuli_dataset('pl', 'pl_voxpopuli')
+    # import_voxpopuli_dataset('es', 'es_voxpopuli')
+    # import_voxpopuli_dataset('en', 'en_voxpopuli')
+
+    import_from_file('nl')
+    import_from_file('fr')
+    import_from_file('de')
+    # import_from_file('it')
+    import_from_file('pl')
+    import_from_file('es')
+    import_from_file('en')
diff --git a/new_experiment/pipeline/import_voxpopuli.py b/new_experiment/pipeline/import_voxpopuli.py
deleted file mode 100644
index 1ecfb6d..0000000
--- a/new_experiment/pipeline/import_voxpopuli.py
+++ /dev/null
@@ -1,10 +0,0 @@
-from new_experiment.pipeline.import_datasets import import_voxpopuli_dataset
-
-if __name__ == '__main__':
-    import_voxpopuli_dataset('nl', 'nl_voxpopuli')
-    import_voxpopuli_dataset('fr', 'fr_voxpopuli')
-    import_voxpopuli_dataset('de', 'de_voxpopuli')
-    import_voxpopuli_dataset('it', 'it_voxpopuli')
-    import_voxpopuli_dataset('pl', 'pl_voxpopuli')
-    import_voxpopuli_dataset('es', 'es_voxpopuli')
-    import_voxpopuli_dataset('en', 'en_voxpopuli')
diff --git a/new_experiment/pipeline/pipeline_process_word_classic_wer.py b/new_experiment/pipeline/pipeline_process_word_classic_wer.py
new file mode 100644
index 0000000..9489f64
--- /dev/null
+++ b/new_experiment/pipeline/pipeline_process_word_classic_wer.py
@@ -0,0 +1,43 @@
+import argparse
+
+from experiment.const_pipeline_names import GOLD_TRANSCRIPT
+from experiment.experiment_dependency_provider import get_record_provider, get_repository
+from new_experiment.new_dependency_provider import get_experiment_repository, get_minio_audio_record_repository
+from new_experiment.utils.loaded_remote_dataset_helper import LoadedRemoteDatasetHelper
+from new_experiment.utils.property_helper import PropertyHelper
+from sziszapangma.core.transformer.web_embedding_transformer import WebEmbeddingTransformer
+from sziszapangma.integration.experiment_manager import ExperimentManager
+from sziszapangma.integration.task.classic_wer_metric_task import ClassicWerMetricTask
+from sziszapangma.integration.task.embedding_wer_metrics_task import EmbeddingWerMetricsTask
+
+
+def run_word_wer_classic_pipeline(dataset_name: str, asr_name: str):
+    repository = get_experiment_repository(dataset_name)
+    experiment_processor = ExperimentManager(
+        record_id_iterator=LoadedRemoteDatasetHelper(repository, get_minio_audio_record_repository(), dataset_name),
+        processing_tasks=[
+            ClassicWerMetricTask(
+                task_name=f'ClassicWerMetricTask___{dataset_name}___{asr_name}',
+                asr_property_name=PropertyHelper.asr_result(asr_name),
+                gold_transcript_property_name=PropertyHelper.get_gold_transcript_words(),
+                metrics_property_name=PropertyHelper.word_wer_classic_metrics(asr_name),
+                require_update=True,
+                alignment_property_name=PropertyHelper.word_wer_classic_alignment(asr_name)
+            ),
+            # EmbeddingWerMetricsTask(
+            #     task_name='EmbeddingWerMetricsTask',
+            #     asr_property_name=f'{asr_name}__result',
+            #     gold_transcript_property_name=GOLD_TRANSCRIPT,
+            #     metrics_property_name=f'{asr_name}__word_wer_embeddings_metrics',
+            #     require_update=False,
+            #     embedding_transformer=WebEmbeddingTransformer('pl', 'http://localhost:5003', 'fjsd-mkwe-oius-m9h2'),
+            #     alignment_property_name=f'{asr_name}__word_wer_embeddings_alignment'
+            # )
+        ],
+        experiment_repository=repository
+    )
+    experiment_processor.process()
+
+
+if __name__ == '__main__':
+    run_word_wer_classic_pipeline('de_google_fleurs', 'whisper_tiny')
diff --git a/new_experiment/pipeline/pipeline_process_word_embedding_wer.py b/new_experiment/pipeline/pipeline_process_word_embedding_wer.py
new file mode 100644
index 0000000..071be0e
--- /dev/null
+++ b/new_experiment/pipeline/pipeline_process_word_embedding_wer.py
@@ -0,0 +1,36 @@
+import argparse
+
+from experiment.const_pipeline_names import GOLD_TRANSCRIPT
+from experiment.experiment_dependency_provider import get_record_provider, get_repository
+from new_experiment.new_dependency_provider import get_experiment_repository, get_minio_audio_record_repository
+from new_experiment.utils.loaded_remote_dataset_helper import LoadedRemoteDatasetHelper
+from new_experiment.utils.property_helper import PropertyHelper
+from sziszapangma.core.transformer.fasttext_embedding_transformer import FasttextEmbeddingTransformer
+from sziszapangma.core.transformer.web_embedding_transformer import WebEmbeddingTransformer
+from sziszapangma.integration.experiment_manager import ExperimentManager
+from sziszapangma.integration.task.classic_wer_metric_task import ClassicWerMetricTask
+from sziszapangma.integration.task.embedding_wer_metrics_task import EmbeddingWerMetricsTask
+
+
+def run_word_wer_classic_pipeline(dataset_name: str, asr_name: str):
+    repository = get_experiment_repository(dataset_name)
+    experiment_processor = ExperimentManager(
+        record_id_iterator=LoadedRemoteDatasetHelper(repository, get_minio_audio_record_repository(), dataset_name),
+        processing_tasks=[
+            EmbeddingWerMetricsTask(
+                task_name='EmbeddingWerMetricsTask',
+                asr_property_name=PropertyHelper.asr_result(asr_name),
+                gold_transcript_property_name=PropertyHelper.get_gold_transcript_words(),
+                metrics_property_name=PropertyHelper.word_wer_embeddings_metrics(asr_name),
+                require_update=False,
+                embedding_transformer=FasttextEmbeddingTransformer(dataset_name[:2]),
+                alignment_property_name=PropertyHelper.word_wer_embeddings_alignment(asr_name)
+            )
+        ],
+        experiment_repository=repository
+    )
+    experiment_processor.process()
+
+
+if __name__ == '__main__':
+    run_word_wer_classic_pipeline('de_google_fleurs', 'whisper_tiny')
diff --git a/new_experiment/utils/loaded_remote_dataset_helper.py b/new_experiment/utils/loaded_remote_dataset_helper.py
index 78dc8bf..2ead6e3 100644
--- a/new_experiment/utils/loaded_remote_dataset_helper.py
+++ b/new_experiment/utils/loaded_remote_dataset_helper.py
@@ -1,24 +1,22 @@
 from pathlib import Path
 from typing import Set
 
-from minio import Minio
-from urllib3 import HTTPResponse
-
 from experiment.dataset_helper import DatasetHelper
-from new_experiment.utils.minio_audio_record_repository import MinioRecordRepository
+from new_experiment.utils.minio_audio_record_repository import MinioAudioRecordRepository
 from new_experiment.utils.property_helper import PropertyHelper
 from sziszapangma.integration.repository.experiment_repository import ExperimentRepository
 
 
 class LoadedRemoteDatasetHelper(DatasetHelper):
     _experiment_repository: ExperimentRepository
-    _minio_record_repository: MinioRecordRepository
+    _minio_audio_record_repository: MinioAudioRecordRepository
    _dataset_name: str
 
-    def __init__(self, experiment_repository: ExperimentRepository, minio_record_repository: MinioRecordRepository,
+    def __init__(self, experiment_repository: ExperimentRepository,
+                 minio_audio_record_repository: MinioAudioRecordRepository,
                  dataset_name: str):
         self._experiment_repository = experiment_repository
-        self._minio_record_repository = minio_record_repository
+        self._minio_audio_record_repository = minio_audio_record_repository
         self._dataset_name = dataset_name
 
     def get_all_records(self) -> Set[str]:
@@ -28,5 +26,5 @@ class LoadedRemoteDatasetHelper(DatasetHelper):
         record_path = Path.home() / f'.cache/asr_benchmark/{self._dataset_name}/{record_id}.wav'
         if record_path.exists():
             return record_path.as_posix()
-        self._minio_record_repository.save_file(record_path, self._dataset_name, record_id)
+        self._minio_audio_record_repository.save_file(record_path, self._dataset_name, record_id)
         return record_path.as_posix()
diff --git a/sziszapangma/core/transformer/fasttext_embedding_transformer.py b/sziszapangma/core/transformer/fasttext_embedding_transformer.py
new file mode 100644
index 0000000..b95b93c
--- /dev/null
+++ b/sziszapangma/core/transformer/fasttext_embedding_transformer.py
@@ -0,0 +1,25 @@
+import json
+from typing import Dict, List, Optional
+
+import numpy as np
+import requests
+from fasttext.FastText import _FastText
+
+from sziszapangma.core.transformer.embedding_transformer import EmbeddingTransformer
+import fasttext.util
+
+
+class FasttextEmbeddingTransformer(EmbeddingTransformer):
+    _lang_id: str
+    _model: _FastText
+
+    def __init__(self, lang_id: str):
+        self._lang_id = lang_id
+        fasttext.util.download_model(lang_id, if_exists='ignore')
+        ft = fasttext.load_model(f'cc.{lang_id}.300.bin')
+
+    def get_embedding(self, word: str) -> np.ndarray:
+        return self._model.get_word_vector(word)
+
+    def get_embeddings(self, words: List[str]) -> Dict[str, np.ndarray]:
+        return {it: self.get_embedding(it) for it in words}
diff --git a/sziszapangma/integration/experiment_manager.py b/sziszapangma/integration/experiment_manager.py
index 5ecd58f..14c2ba3 100644
--- a/sziszapangma/integration/experiment_manager.py
+++ b/sziszapangma/integration/experiment_manager.py
@@ -11,25 +11,21 @@ class ExperimentManager:
     _experiment_repository: ExperimentRepository
     _record_id_iterator: RecordIdIterator
     _processing_tasks: List[ProcessingTask]
-    _relation_manager_provider: RelationManagerProvider
 
     def __init__(
         self,
         experiment_repository: ExperimentRepository,
         record_id_iterator: RecordIdIterator,
         processing_tasks: List[ProcessingTask],
-        relation_manager_provider: RelationManagerProvider,
     ):
         self._experiment_repository = experiment_repository
         self._record_id_iterator = record_id_iterator
         self._processing_tasks = processing_tasks
-        self._relation_manager_provider = relation_manager_provider
 
     def process(self):
         self._experiment_repository.initialise()
         for processing_task in self._processing_tasks:
             processing_task.process(
                 self._record_id_iterator,
-                self._experiment_repository,
-                self._relation_manager_provider,
+                self._experiment_repository
             )
diff --git a/sziszapangma/integration/repository/mongo_experiment_repository.py b/sziszapangma/integration/repository/mongo_experiment_repository.py
index 6c87a1d..0e9fb11 100644
--- a/sziszapangma/integration/repository/mongo_experiment_repository.py
+++ b/sziszapangma/integration/repository/mongo_experiment_repository.py
@@ -24,9 +24,16 @@ class MongoExperimentRepository(ExperimentRepository):
     def property_exists(self, record_id: str, property_name: str) -> bool:
         database = self._get_database()
         all_collections = database.list_collection_names()
+        print(property_name, all_collections)
         if property_name not in all_collections:
+            print('collection not found')
             return False
         else:
+            print('self.get_all_record_ids_for_property(property_name)', record_id, record_id.__class__,
+                  # record_id in self.get_all_record_ids_for_property(property_name),
+                  # len(self.get_all_record_ids_for_property(property_name)),
+                  list(self.get_all_record_ids_for_property(property_name))[0],
+                  list(self.get_all_record_ids_for_property(property_name))[0].__class__)
             return database[property_name].find_one({ID: record_id}) is not None
 
     def update_property_for_key(self, record_id: str, property_name: str, property_value: Any):
diff --git a/sziszapangma/integration/task/classic_wer_metric_task.py b/sziszapangma/integration/task/classic_wer_metric_task.py
index 27ee08a..af7a46b 100644
--- a/sziszapangma/integration/task/classic_wer_metric_task.py
+++ b/sziszapangma/integration/task/classic_wer_metric_task.py
@@ -1,3 +1,4 @@
+from pprint import pprint
 from typing import Any, Dict, List
 
 from sziszapangma.core.alignment.alignment_classic_calculator import AlignmentClassicCalculator
@@ -47,9 +48,10 @@ class ClassicWerMetricTask(ProcessingTask):
         self,
         record_id: str,
         experiment_repository: ExperimentRepository,
-        relation_manager: RelationManager,
     ):
-        gold_transcript = TaskUtil.get_words_from_record(relation_manager)
+        print('#############')
+        gold_transcript = experiment_repository.get_property_for_key(record_id, self._gold_transcript_property_name)
+        print('$$$$$$$$$$$$$', gold_transcript)
         asr_result = experiment_repository.get_property_for_key(record_id, self._asr_property_name)
         if gold_transcript is not None and asr_result is not None and "transcription" in asr_result:
             alignment_steps = self._get_alignment(gold_transcript, asr_result["transcription"])
-- 
GitLab
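
Note on the new dataset_importer/import_voxpopuli.py: import_from_file() reads a local JSONL cache and expects every line to be a JSON object carrying at least the audio_unique_id, raw_text and normalized_text fields accessed above. A purely illustrative record (the field values are invented; only the key names come from the patch):

{"audio_unique_id": "20180131-0900-PLENARY-de_00123", "raw_text": "Guten Morgen, meine Damen und Herren.", "normalized_text": "guten morgen meine damen und herren"}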
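
Note on the new pipeline entry points: both pipeline_process_word_classic_wer.py and pipeline_process_word_embedding_wer.py import argparse but hard-code 'de_google_fleurs' and 'whisper_tiny' in __main__, while the dvc.yaml stage invokes experiment/pipeline_process_word_classic_wer.py with --dataset=${item.dataset} --asr=${item.asr}. If the new scripts are meant to be driven the same way, a minimal sketch of the missing wiring (flag names taken from the dvc.yaml cmd; everything else unchanged) could look like:

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', required=True)  # e.g. de_google_fleurs
    parser.add_argument('--asr', required=True)      # e.g. whisper_tiny
    args = parser.parse_args()
    run_word_wer_classic_pipeline(args.dataset, args.asr)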
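
Note on sziszapangma/core/transformer/fasttext_embedding_transformer.py: __init__ binds the loaded model to a local variable ft, while get_embedding() reads self._model, so the first embedding lookup would raise AttributeError (json, requests and Optional are also imported but unused). A minimal corrected sketch, assuming the load was simply meant to target the attribute and using the same fasttext calls as the patch:

from typing import Dict, List

import fasttext
import fasttext.util
import numpy as np
from fasttext.FastText import _FastText

from sziszapangma.core.transformer.embedding_transformer import EmbeddingTransformer


class FasttextEmbeddingTransformer(EmbeddingTransformer):
    _lang_id: str
    _model: _FastText

    def __init__(self, lang_id: str):
        self._lang_id = lang_id
        # download cc.<lang>.300.bin once, then keep the model on the instance
        fasttext.util.download_model(lang_id, if_exists='ignore')
        self._model = fasttext.load_model(f'cc.{lang_id}.300.bin')

    def get_embedding(self, word: str) -> np.ndarray:
        return self._model.get_word_vector(word)

    def get_embeddings(self, words: List[str]) -> Dict[str, np.ndarray]:
        return {it: self.get_embedding(it) for it in words}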