From e407a441e9775b6e8bd29165c0f886b092540f51 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Marcin=20W=C4=85troba?= <markowanga@gmail.com>
Date: Thu, 12 Jan 2023 01:48:09 +0100
Subject: [PATCH] Add pipeline

---
 dvc.yaml                                      |  4 +-
 .../fleurs_dataset_importer.py                |  2 +-
 .../{ => dataset_importer}/import_datasets.py |  0
 .../{ => dataset_importer}/import_fleurs.py   |  2 +-
 .../{ => dataset_importer}/import_minds14.py  |  2 +-
 .../dataset_importer/import_voxpopuli.py      | 57 +++++++++++++++++++
 new_experiment/pipeline/import_voxpopuli.py   | 10 ----
 .../pipeline_process_word_classic_wer.py      | 43 ++++++++++++++
 .../pipeline_process_word_embedding_wer.py    | 36 ++++++++++++
 .../utils/loaded_remote_dataset_helper.py     | 14 ++---
 .../fasttext_embedding_transformer.py         | 25 ++++++++
 .../integration/experiment_manager.py         |  6 +-
 .../repository/mongo_experiment_repository.py |  7 +++
 .../task/classic_wer_metric_task.py           |  6 +-
 14 files changed, 184 insertions(+), 30 deletions(-)
 rename new_experiment/pipeline/{ => dataset_importer}/import_datasets.py (100%)
 rename new_experiment/pipeline/{ => dataset_importer}/import_fleurs.py (81%)
 rename new_experiment/pipeline/{ => dataset_importer}/import_minds14.py (80%)
 create mode 100644 new_experiment/pipeline/dataset_importer/import_voxpopuli.py
 delete mode 100644 new_experiment/pipeline/import_voxpopuli.py
 create mode 100644 new_experiment/pipeline/pipeline_process_word_classic_wer.py
 create mode 100644 new_experiment/pipeline/pipeline_process_word_embedding_wer.py
 create mode 100644 sziszapangma/core/transformer/fasttext_embedding_transformer.py

diff --git a/dvc.yaml b/dvc.yaml
index 4c9fc0b..43616ee 100644
--- a/dvc.yaml
+++ b/dvc.yaml
@@ -154,9 +154,9 @@ stages:
             -   dataset: pl_minds14
                 asr: whisper_tiny
         do:
-            cmd: PYTHONPATH=. python experiment/pipeline_process_word_wer.py --dataset=${item.dataset} --asr=${item.asr}
+            cmd: PYTHONPATH=. python experiment/pipeline_process_word_classic_wer.py --dataset=${item.dataset} --asr=${item.asr}
             deps:
-                - experiment/pipeline_process_word_wer.py
+                - experiment/pipeline_process_word_classic_wer.py
                 - experiment_data/dataset/${item.dataset}
                 - experiment_data/pipeline/${item.dataset}/gold_transcript
                 - experiment_data/pipeline/${item.dataset}/${item.asr}__result
diff --git a/new_experiment/pipeline/dataset_importer/fleurs_dataset_importer.py b/new_experiment/pipeline/dataset_importer/fleurs_dataset_importer.py
index 9b88c89..11eaca6 100644
--- a/new_experiment/pipeline/dataset_importer/fleurs_dataset_importer.py
+++ b/new_experiment/pipeline/dataset_importer/fleurs_dataset_importer.py
@@ -27,4 +27,4 @@ class FleursDatasetImporter(HfDatasetImporter):
         return record['path']
 
     def get_record_id(self, record: Dict[str, Any]) -> str:
-        return record["id"]
+        return str(record["id"])
diff --git a/new_experiment/pipeline/import_datasets.py b/new_experiment/pipeline/dataset_importer/import_datasets.py
similarity index 100%
rename from new_experiment/pipeline/import_datasets.py
rename to new_experiment/pipeline/dataset_importer/import_datasets.py
diff --git a/new_experiment/pipeline/import_fleurs.py b/new_experiment/pipeline/dataset_importer/import_fleurs.py
similarity index 81%
rename from new_experiment/pipeline/import_fleurs.py
rename to new_experiment/pipeline/dataset_importer/import_fleurs.py
index f08197d..a81c5f3 100644
--- a/new_experiment/pipeline/import_fleurs.py
+++ b/new_experiment/pipeline/dataset_importer/import_fleurs.py
@@ -1,4 +1,4 @@
-from new_experiment.pipeline.import_datasets import import_fleurs_dataset
+from new_experiment.pipeline.dataset_importer.import_datasets import import_fleurs_dataset
 
 if __name__ == '__main__':
     import_fleurs_dataset('nl_nl', 'nl_google_fleurs')
diff --git a/new_experiment/pipeline/import_minds14.py b/new_experiment/pipeline/dataset_importer/import_minds14.py
similarity index 80%
rename from new_experiment/pipeline/import_minds14.py
rename to new_experiment/pipeline/dataset_importer/import_minds14.py
index 3501909..01871ad 100644
--- a/new_experiment/pipeline/import_minds14.py
+++ b/new_experiment/pipeline/dataset_importer/import_minds14.py
@@ -1,4 +1,4 @@
-from new_experiment.pipeline.import_datasets import import_minds14_dataset
+from new_experiment.pipeline.dataset_importer.import_datasets import import_minds14_dataset
 
 if __name__ == '__main__':
     import_minds14_dataset('nl-NL', 'nl_minds14')
diff --git a/new_experiment/pipeline/dataset_importer/import_voxpopuli.py b/new_experiment/pipeline/dataset_importer/import_voxpopuli.py
new file mode 100644
index 0000000..7e4ba81
--- /dev/null
+++ b/new_experiment/pipeline/dataset_importer/import_voxpopuli.py
@@ -0,0 +1,57 @@
+import json
+from pathlib import Path
+from typing import List
+
+from nltk import RegexpTokenizer
+
+from new_experiment.new_dependency_provider import get_experiment_repository
+from new_experiment.utils.property_helper import PropertyHelper
+from sziszapangma.model.model_creators import create_new_word
+
+
+def get_words(raw: str) -> List[str]:
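+    # \w+ keeps runs of word characters, so punctuation is stripped from the normalized text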
+    tokenizer = RegexpTokenizer(r'\w+')
+    return tokenizer.tokenize(raw)
+
+
+def import_from_file(lang: str):
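+    # reads a locally cached JSONL export of the voxpopuli split (one JSON object per line)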
+    path = Path(f'/Users/marcinwatroba/Desktop/MY_PROJECTS/playground/librispeech/cache_items_{lang}_voxpopuli.jsonl')
+    with open(path, 'r') as reader:
+        dataset_name = f'{lang}_voxpopuli'
+        repo = get_experiment_repository(dataset_name)
+        for line in reader.read().splitlines():
+            it_dict = json.loads(line)
+            record_id = str(it_dict['audio_unique_id'])
+            print(record_id)
+            raw_text = it_dict['raw_text']
+            normalized_text_words = [create_new_word(it) for it in get_words(it_dict['normalized_text'])]
+            repo.update_property_for_key(
+                record_id=record_id,
+                property_name=PropertyHelper.get_gold_transcript_words(),
+                property_value=normalized_text_words
+            )
+            repo.update_property_for_key(
+                record_id=record_id,
+                property_name=PropertyHelper.get_gold_transcript_raw(),
+                property_value={'gold_transcript_raw': raw_text}
+            )
+
+
+if __name__ == '__main__':
+    import_from_file('nl')
+    import_from_file('fr')
+    import_from_file('de')
+    # import_from_file('it')
+    import_from_file('pl')
+    import_from_file('es')
+    import_from_file('en')
diff --git a/new_experiment/pipeline/import_voxpopuli.py b/new_experiment/pipeline/import_voxpopuli.py
deleted file mode 100644
index 1ecfb6d..0000000
--- a/new_experiment/pipeline/import_voxpopuli.py
+++ /dev/null
@@ -1,10 +0,0 @@
-from new_experiment.pipeline.import_datasets import import_voxpopuli_dataset
-
-if __name__ == '__main__':
-    import_voxpopuli_dataset('nl', 'nl_voxpopuli')
-    import_voxpopuli_dataset('fr', 'fr_voxpopuli')
-    import_voxpopuli_dataset('de', 'de_voxpopuli')
-    import_voxpopuli_dataset('it', 'it_voxpopuli')
-    import_voxpopuli_dataset('pl', 'pl_voxpopuli')
-    import_voxpopuli_dataset('es', 'es_voxpopuli')
-    import_voxpopuli_dataset('en', 'en_voxpopuli')
diff --git a/new_experiment/pipeline/pipeline_process_word_classic_wer.py b/new_experiment/pipeline/pipeline_process_word_classic_wer.py
new file mode 100644
index 0000000..9489f64
--- /dev/null
+++ b/new_experiment/pipeline/pipeline_process_word_classic_wer.py
@@ -0,0 +1,43 @@
+import argparse
+
+from new_experiment.new_dependency_provider import get_experiment_repository, get_minio_audio_record_repository
+from new_experiment.utils.loaded_remote_dataset_helper import LoadedRemoteDatasetHelper
+from new_experiment.utils.property_helper import PropertyHelper
+from sziszapangma.integration.experiment_manager import ExperimentManager
+from sziszapangma.integration.task.classic_wer_metric_task import ClassicWerMetricTask
+
+
+def run_word_wer_classic_pipeline(dataset_name: str, asr_name: str):
+    repository = get_experiment_repository(dataset_name)
+    experiment_processor = ExperimentManager(
+        record_id_iterator=LoadedRemoteDatasetHelper(repository, get_minio_audio_record_repository(), dataset_name),
+        processing_tasks=[
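+            # aligns the ASR transcription against the gold transcript and stores classic WER metrics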
+            ClassicWerMetricTask(
+                task_name=f'ClassicWerMetricTask___{dataset_name}___{asr_name}',
+                asr_property_name=PropertyHelper.asr_result(asr_name),
+                gold_transcript_property_name=PropertyHelper.get_gold_transcript_words(),
+                metrics_property_name=PropertyHelper.word_wer_classic_metrics(asr_name),
+                require_update=True,
+                alignment_property_name=PropertyHelper.word_wer_classic_alignment(asr_name)
+            ),
+        ],
+        experiment_repository=repository
+    )
+    experiment_processor.process()
+
+
+if __name__ == '__main__':
+    # same flags the dvc stage passes; defaults preserve the previous hardcoded run
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--dataset', default='de_google_fleurs')
+    parser.add_argument('--asr', default='whisper_tiny')
+    args = parser.parse_args()
+    run_word_wer_classic_pipeline(args.dataset, args.asr)
diff --git a/new_experiment/pipeline/pipeline_process_word_embedding_wer.py b/new_experiment/pipeline/pipeline_process_word_embedding_wer.py
new file mode 100644
index 0000000..071be0e
--- /dev/null
+++ b/new_experiment/pipeline/pipeline_process_word_embedding_wer.py
@@ -0,0 +1,36 @@
+import argparse
+
+from new_experiment.new_dependency_provider import get_experiment_repository, get_minio_audio_record_repository
+from new_experiment.utils.loaded_remote_dataset_helper import LoadedRemoteDatasetHelper
+from new_experiment.utils.property_helper import PropertyHelper
+from sziszapangma.core.transformer.fasttext_embedding_transformer import FasttextEmbeddingTransformer
+from sziszapangma.integration.experiment_manager import ExperimentManager
+from sziszapangma.integration.task.embedding_wer_metrics_task import EmbeddingWerMetricsTask
+
+
+def run_word_wer_embedding_pipeline(dataset_name: str, asr_name: str):
+    repository = get_experiment_repository(dataset_name)
+    experiment_processor = ExperimentManager(
+        record_id_iterator=LoadedRemoteDatasetHelper(repository, get_minio_audio_record_repository(), dataset_name),
+        processing_tasks=[
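+            # embedding-based (soft) WER: substitution costs come from fastText word vectors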
+            EmbeddingWerMetricsTask(
+                task_name=f'EmbeddingWerMetricsTask___{dataset_name}___{asr_name}',
+                asr_property_name=PropertyHelper.asr_result(asr_name),
+                gold_transcript_property_name=PropertyHelper.get_gold_transcript_words(),
+                metrics_property_name=PropertyHelper.word_wer_embeddings_metrics(asr_name),
+                require_update=False,
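+                # dataset names start with a two-letter language code, e.g. 'de_google_fleurs' -> 'de'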
+                embedding_transformer=FasttextEmbeddingTransformer(dataset_name[:2]),
+                alignment_property_name=PropertyHelper.word_wer_embeddings_alignment(asr_name)
+            )
+        ],
+        experiment_repository=repository
+    )
+    experiment_processor.process()
+
+
+if __name__ == '__main__':
+    # same flags as the classic WER pipeline; defaults preserve the previous hardcoded run
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--dataset', default='de_google_fleurs')
+    parser.add_argument('--asr', default='whisper_tiny')
+    args = parser.parse_args()
+    run_word_wer_embedding_pipeline(args.dataset, args.asr)
diff --git a/new_experiment/utils/loaded_remote_dataset_helper.py b/new_experiment/utils/loaded_remote_dataset_helper.py
index 78dc8bf..2ead6e3 100644
--- a/new_experiment/utils/loaded_remote_dataset_helper.py
+++ b/new_experiment/utils/loaded_remote_dataset_helper.py
@@ -1,24 +1,22 @@
 from pathlib import Path
 from typing import Set
 
-from minio import Minio
-from urllib3 import HTTPResponse
-
 from experiment.dataset_helper import DatasetHelper
-from new_experiment.utils.minio_audio_record_repository import MinioRecordRepository
+from new_experiment.utils.minio_audio_record_repository import MinioAudioRecordRepository
 from new_experiment.utils.property_helper import PropertyHelper
 from sziszapangma.integration.repository.experiment_repository import ExperimentRepository
 
 
 class LoadedRemoteDatasetHelper(DatasetHelper):
     _experiment_repository: ExperimentRepository
-    _minio_record_repository: MinioRecordRepository
+    _minio_audio_record_repository: MinioAudioRecordRepository
     _dataset_name: str
 
-    def __init__(self, experiment_repository: ExperimentRepository, minio_record_repository: MinioRecordRepository,
+    def __init__(self, experiment_repository: ExperimentRepository,
+                 minio_audio_record_repository: MinioAudioRecordRepository,
                  dataset_name: str):
         self._experiment_repository = experiment_repository
-        self._minio_record_repository = minio_record_repository
+        self._minio_audio_record_repository = minio_audio_record_repository
         self._dataset_name = dataset_name
 
     def get_all_records(self) -> Set[str]:
@@ -28,5 +26,5 @@ class LoadedRemoteDatasetHelper(DatasetHelper):
         record_path = Path.home() / f'.cache/asr_benchmark/{self._dataset_name}/{record_id}.wav'
         if record_path.exists():
             return record_path.as_posix()
-        self._minio_record_repository.save_file(record_path, self._dataset_name, record_id)
+        self._minio_audio_record_repository.save_file(record_path, self._dataset_name, record_id)
         return record_path.as_posix()
diff --git a/sziszapangma/core/transformer/fasttext_embedding_transformer.py b/sziszapangma/core/transformer/fasttext_embedding_transformer.py
new file mode 100644
index 0000000..b95b93c
--- /dev/null
+++ b/sziszapangma/core/transformer/fasttext_embedding_transformer.py
@@ -0,0 +1,25 @@
+from typing import Dict, List
+
+import fasttext.util
+import numpy as np
+from fasttext.FastText import _FastText
+
+from sziszapangma.core.transformer.embedding_transformer import EmbeddingTransformer
+
+
+class FasttextEmbeddingTransformer(EmbeddingTransformer):
+    _lang_id: str
+    _model: _FastText
+
+    def __init__(self, lang_id: str):
+        self._lang_id = lang_id
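+        # fasttext.util.download_model fetches cc.<lang_id>.300.bin into the current working directory on first use (a multi-gigabyte download)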
+        fasttext.util.download_model(lang_id, if_exists='ignore')
+        self._model = fasttext.load_model(f'cc.{lang_id}.300.bin')
+
+    def get_embedding(self, word: str) -> np.ndarray:
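+        # fastText composes vectors for out-of-vocabulary words from subword n-grams, so every word yields a 300-d vector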
+        return self._model.get_word_vector(word)
+
+    def get_embeddings(self, words: List[str]) -> Dict[str, np.ndarray]:
+        return {it: self.get_embedding(it) for it in words}
diff --git a/sziszapangma/integration/experiment_manager.py b/sziszapangma/integration/experiment_manager.py
index 5ecd58f..14c2ba3 100644
--- a/sziszapangma/integration/experiment_manager.py
+++ b/sziszapangma/integration/experiment_manager.py
@@ -11,25 +11,21 @@ class ExperimentManager:
     _experiment_repository: ExperimentRepository
     _record_id_iterator: RecordIdIterator
     _processing_tasks: List[ProcessingTask]
-    _relation_manager_provider: RelationManagerProvider
 
     def __init__(
         self,
         experiment_repository: ExperimentRepository,
         record_id_iterator: RecordIdIterator,
         processing_tasks: List[ProcessingTask],
-        relation_manager_provider: RelationManagerProvider,
     ):
         self._experiment_repository = experiment_repository
         self._record_id_iterator = record_id_iterator
         self._processing_tasks = processing_tasks
-        self._relation_manager_provider = relation_manager_provider
 
     def process(self):
         self._experiment_repository.initialise()
         for processing_task in self._processing_tasks:
             processing_task.process(
                 self._record_id_iterator,
-                self._experiment_repository,
-                self._relation_manager_provider,
+                self._experiment_repository
             )
diff --git a/sziszapangma/integration/repository/mongo_experiment_repository.py b/sziszapangma/integration/repository/mongo_experiment_repository.py
index 6c87a1d..0e9fb11 100644
--- a/sziszapangma/integration/repository/mongo_experiment_repository.py
+++ b/sziszapangma/integration/repository/mongo_experiment_repository.py
@@ -24,9 +24,16 @@ class MongoExperimentRepository(ExperimentRepository):
     def property_exists(self, record_id: str, property_name: str) -> bool:
         database = self._get_database()
         all_collections = database.list_collection_names()
+        print(property_name, all_collections)
         if property_name not in all_collections:
+            print('collection not found')
             return False
         else:
+            # debug: compare the queried id with one stored id to spot str vs ObjectId mismatches
+            stored_ids = list(self.get_all_record_ids_for_property(property_name))
+            print('property_exists', record_id, type(record_id),
+                  stored_ids[0] if stored_ids else None,
+                  type(stored_ids[0]) if stored_ids else None)
             return database[property_name].find_one({ID: record_id}) is not None
 
     def update_property_for_key(self, record_id: str, property_name: str, property_value: Any):
diff --git a/sziszapangma/integration/task/classic_wer_metric_task.py b/sziszapangma/integration/task/classic_wer_metric_task.py
index 27ee08a..af7a46b 100644
--- a/sziszapangma/integration/task/classic_wer_metric_task.py
+++ b/sziszapangma/integration/task/classic_wer_metric_task.py
@@ -1,3 +1,4 @@
+from pprint import pprint
 from typing import Any, Dict, List
 
 from sziszapangma.core.alignment.alignment_classic_calculator import AlignmentClassicCalculator
@@ -47,9 +48,10 @@ class ClassicWerMetricTask(ProcessingTask):
         self,
         record_id: str,
         experiment_repository: ExperimentRepository,
-        relation_manager: RelationManager,
     ):
-        gold_transcript = TaskUtil.get_words_from_record(relation_manager)
+        gold_transcript = experiment_repository.get_property_for_key(record_id, self._gold_transcript_property_name)
+        pprint(gold_transcript)
         asr_result = experiment_repository.get_property_for_key(record_id, self._asr_property_name)
         if gold_transcript is not None and asr_result is not None and "transcription" in asr_result:
             alignment_steps = self._get_alignment(gold_transcript, asr_result["transcription"])
-- 
GitLab