Commit e407a441 authored by Marcin Wątroba

Add pipe

parent 96f3d25c
Showing with 184 additions and 30 deletions
@@ -154,9 +154,9 @@ stages:
       - dataset: pl_minds14
         asr: whisper_tiny
     do:
-      cmd: PYTHONPATH=. python experiment/pipeline_process_word_wer.py --dataset=${item.dataset} --asr=${item.asr}
+      cmd: PYTHONPATH=. python experiment/pipeline_process_word_classic_wer.py --dataset=${item.dataset} --asr=${item.asr}
       deps:
-        - experiment/pipeline_process_word_wer.py
+        - experiment/pipeline_process_word_classic_wer.py
         - experiment_data/dataset/${item.dataset}
         - experiment_data/pipeline/${item.dataset}/gold_transcript
         - experiment_data/pipeline/${item.dataset}/${item.asr}__result
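For one matrix item, DVC expands the templated cmd into a concrete invocation along these lines (a hypothetical expansion, shown for the pl_minds14/whisper_tiny pair):

    PYTHONPATH=. python experiment/pipeline_process_word_classic_wer.py --dataset=pl_minds14 --asr=whisper_tiny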
@@ -27,4 +27,4 @@ class FleursDatasetImporter(HfDatasetImporter):
         return record['path']

     def get_record_id(self, record: Dict[str, Any]) -> str:
-        return record["id"]
+        return str(record["id"])
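The str() cast keeps record ids type-consistent with the rest of the pipeline, which stores and queries them as strings; MongoDB lookups are type-sensitive, so an integer id written at import time would never match a string query key. A minimal sketch with made-up values:

    # Hypothetical illustration: int and str keys do not match in a Mongo query.
    stored = {'_id': 42}                 # id written as an int by an importer
    query_key = str(42)                  # pipeline queries with string record ids
    assert stored['_id'] != query_key    # 42 != '42', so find_one() misses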
-from new_experiment.pipeline.import_datasets import import_fleurs_dataset
+from new_experiment.pipeline.dataset_importer.import_datasets import import_fleurs_dataset

 if __name__ == '__main__':
     import_fleurs_dataset('nl_nl', 'nl_google_fleurs')
-from new_experiment.pipeline.import_datasets import import_minds14_dataset
+from new_experiment.pipeline.dataset_importer.import_datasets import import_minds14_dataset

 if __name__ == '__main__':
     import_minds14_dataset('nl-NL', 'nl_minds14')
import json
from pathlib import Path
from typing import Any, List

from nltk import RegexpTokenizer

from new_experiment.new_dependency_provider import get_experiment_repository
from new_experiment.utils.property_helper import PropertyHelper
from sziszapangma.model.model_creators import create_new_word


def get_words(raw: str) -> List[str]:
    tokenizer = RegexpTokenizer(r'\w+')
    return tokenizer.tokenize(raw)


def import_from_file(lang: str):
    path = Path(f'/Users/marcinwatroba/Desktop/MY_PROJECTS/playground/librispeech/cache_items_{lang}_voxpopuli.jsonl')
    with open(path, 'r') as reader:
        dataset_name = f'{lang}_voxpopuli'
        repo = get_experiment_repository(dataset_name)
        for line in reader.read().splitlines(keepends=False):
            it_dict = json.loads(line)
            print(it_dict)
            record_id = str(it_dict['audio_unique_id'])
            raw_text = it_dict['raw_text']
            normalized_text_words = [create_new_word(it) for it in get_words(it_dict['normalized_text'])]
            repo.update_property_for_key(
                record_id=record_id,
                property_name=PropertyHelper.get_gold_transcript_words(),
                property_value=normalized_text_words
            )
            repo.update_property_for_key(
                record_id=record_id,
                property_name=PropertyHelper.get_gold_transcript_raw(),
                property_value={'gold_transcript_raw': raw_text}
            )


if __name__ == '__main__':
    # import_voxpopuli_dataset('nl', 'nl_voxpopuli')
    # import_voxpopuli_dataset('fr', 'fr_voxpopuli')
    # import_voxpopuli_dataset('de', 'de_voxpopuli')
    # import_voxpopuli_dataset('it', 'it_voxpopuli')
    # import_voxpopuli_dataset('pl', 'pl_voxpopuli')
    # import_voxpopuli_dataset('es', 'es_voxpopuli')
    # import_voxpopuli_dataset('en', 'en_voxpopuli')
    import_from_file('nl')
    import_from_file('fr')
    import_from_file('de')
    # import_from_file('it')
    import_from_file('pl')
    import_from_file('es')
    import_from_file('en')
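Each line of the cached JSONL file must carry at least the three fields import_from_file() reads; a hypothetical example line (field values invented, and the id format only approximates VoxPopuli's):

    {"audio_unique_id": "20180213-0900-PLENARY-nl-0001", "raw_text": "Dank u wel, mijnheer de Voorzitter.", "normalized_text": "dank u wel mijnheer de voorzitter"}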
from new_experiment.pipeline.import_datasets import import_voxpopuli_dataset

if __name__ == '__main__':
    import_voxpopuli_dataset('nl', 'nl_voxpopuli')
    import_voxpopuli_dataset('fr', 'fr_voxpopuli')
    import_voxpopuli_dataset('de', 'de_voxpopuli')
    import_voxpopuli_dataset('it', 'it_voxpopuli')
    import_voxpopuli_dataset('pl', 'pl_voxpopuli')
    import_voxpopuli_dataset('es', 'es_voxpopuli')
    import_voxpopuli_dataset('en', 'en_voxpopuli')
import argparse

from experiment.const_pipeline_names import GOLD_TRANSCRIPT
from experiment.experiment_dependency_provider import get_record_provider, get_repository
from new_experiment.new_dependency_provider import get_experiment_repository, get_minio_audio_record_repository
from new_experiment.utils.loaded_remote_dataset_helper import LoadedRemoteDatasetHelper
from new_experiment.utils.property_helper import PropertyHelper
from sziszapangma.core.transformer.web_embedding_transformer import WebEmbeddingTransformer
from sziszapangma.integration.experiment_manager import ExperimentManager
from sziszapangma.integration.task.classic_wer_metric_task import ClassicWerMetricTask
from sziszapangma.integration.task.embedding_wer_metrics_task import EmbeddingWerMetricsTask


def run_word_wer_classic_pipeline(dataset_name: str, asr_name: str):
    repository = get_experiment_repository(dataset_name)
    experiment_processor = ExperimentManager(
        record_id_iterator=LoadedRemoteDatasetHelper(repository, get_minio_audio_record_repository(), dataset_name),
        processing_tasks=[
            ClassicWerMetricTask(
                task_name=f'ClassicWerMetricTask___{dataset_name}___{asr_name}',
                asr_property_name=PropertyHelper.asr_result(asr_name),
                gold_transcript_property_name=PropertyHelper.get_gold_transcript_words(),
                metrics_property_name=PropertyHelper.word_wer_classic_metrics(asr_name),
                require_update=True,
                alignment_property_name=PropertyHelper.word_wer_classic_alignment(asr_name)
            ),
            # EmbeddingWerMetricsTask(
            #     task_name='EmbeddingWerMetricsTask',
            #     asr_property_name=f'{asr_name}__result',
            #     gold_transcript_property_name=GOLD_TRANSCRIPT,
            #     metrics_property_name=f'{asr_name}__word_wer_embeddings_metrics',
            #     require_update=False,
            #     embedding_transformer=WebEmbeddingTransformer('pl', 'http://localhost:5003', 'fjsd-mkwe-oius-m9h2'),
            #     alignment_property_name=f'{asr_name}__word_wer_embeddings_alignment'
            # )
        ],
        experiment_repository=repository
    )
    experiment_processor.process()


if __name__ == '__main__':
    run_word_wer_classic_pipeline('de_google_fleurs', 'whisper_tiny')
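The dvc.yaml stage above calls this script with --dataset and --asr flags, yet the __main__ block hard-codes one pair and the argparse import goes unused. A sketch of the CLI wiring the stage seems to expect (an assumption, not part of the commit):

    # Hypothetical argparse wiring matching the dvc.yaml cmd flags.
    if __name__ == '__main__':
        parser = argparse.ArgumentParser()
        parser.add_argument('--dataset', required=True)
        parser.add_argument('--asr', required=True)
        args = parser.parse_args()
        run_word_wer_classic_pipeline(args.dataset, args.asr)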
import argparse

from experiment.const_pipeline_names import GOLD_TRANSCRIPT
from experiment.experiment_dependency_provider import get_record_provider, get_repository
from new_experiment.new_dependency_provider import get_experiment_repository, get_minio_audio_record_repository
from new_experiment.utils.loaded_remote_dataset_helper import LoadedRemoteDatasetHelper
from new_experiment.utils.property_helper import PropertyHelper
from sziszapangma.core.transformer.fasttext_embedding_transformer import FasttextEmbeddingTransformer
from sziszapangma.core.transformer.web_embedding_transformer import WebEmbeddingTransformer
from sziszapangma.integration.experiment_manager import ExperimentManager
from sziszapangma.integration.task.classic_wer_metric_task import ClassicWerMetricTask
from sziszapangma.integration.task.embedding_wer_metrics_task import EmbeddingWerMetricsTask


def run_word_wer_classic_pipeline(dataset_name: str, asr_name: str):
    repository = get_experiment_repository(dataset_name)
    experiment_processor = ExperimentManager(
        record_id_iterator=LoadedRemoteDatasetHelper(repository, get_minio_audio_record_repository(), dataset_name),
        processing_tasks=[
            EmbeddingWerMetricsTask(
                task_name='EmbeddingWerMetricsTask',
                asr_property_name=PropertyHelper.asr_result(asr_name),
                gold_transcript_property_name=PropertyHelper.get_gold_transcript_words(),
                metrics_property_name=PropertyHelper.word_wer_embeddings_metrics(asr_name),
                require_update=False,
                embedding_transformer=FasttextEmbeddingTransformer(dataset_name[:2]),
                alignment_property_name=PropertyHelper.word_wer_embeddings_alignment(asr_name)
            )
        ],
        experiment_repository=repository
    )
    experiment_processor.process()


if __name__ == '__main__':
    run_word_wer_classic_pipeline('de_google_fleurs', 'whisper_tiny')
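Note that FasttextEmbeddingTransformer receives dataset_name[:2], which presumes every dataset name begins with its two-letter ISO 639-1 language code:

    # 'de_google_fleurs'[:2] -> 'de'; 'pl_minds14'[:2] -> 'pl'
    assert 'de_google_fleurs'[:2] == 'de'
    assert 'pl_minds14'[:2] == 'pl'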
 from pathlib import Path
 from typing import Set

 from minio import Minio
 from urllib3 import HTTPResponse

 from experiment.dataset_helper import DatasetHelper
-from new_experiment.utils.minio_audio_record_repository import MinioRecordRepository
+from new_experiment.utils.minio_audio_record_repository import MinioAudioRecordRepository
 from new_experiment.utils.property_helper import PropertyHelper
 from sziszapangma.integration.repository.experiment_repository import ExperimentRepository


 class LoadedRemoteDatasetHelper(DatasetHelper):
     _experiment_repository: ExperimentRepository
-    _minio_record_repository: MinioRecordRepository
+    _minio_audio_record_repository: MinioAudioRecordRepository
     _dataset_name: str

-    def __init__(self, experiment_repository: ExperimentRepository, minio_record_repository: MinioRecordRepository,
+    def __init__(self, experiment_repository: ExperimentRepository,
+                 minio_audio_record_repository: MinioAudioRecordRepository,
                  dataset_name: str):
         self._experiment_repository = experiment_repository
-        self._minio_record_repository = minio_record_repository
+        self._minio_audio_record_repository = minio_audio_record_repository
         self._dataset_name = dataset_name

     def get_all_records(self) -> Set[str]:
@@ -28,5 +26,5 @@ class LoadedRemoteDatasetHelper(DatasetHelper):
         record_path = Path.home() / f'.cache/asr_benchmark/{self._dataset_name}/{record_id}.wav'
         if record_path.exists():
             return record_path.as_posix()
-        self._minio_record_repository.save_file(record_path, self._dataset_name, record_id)
+        self._minio_audio_record_repository.save_file(record_path, self._dataset_name, record_id)
         return record_path.as_posix()
import json
from typing import Dict, List, Optional

import numpy as np
import requests
import fasttext.util
from fasttext.FastText import _FastText

from sziszapangma.core.transformer.embedding_transformer import EmbeddingTransformer


class FasttextEmbeddingTransformer(EmbeddingTransformer):
    _lang_id: str
    _model: _FastText

    def __init__(self, lang_id: str):
        self._lang_id = lang_id
        # Fetch the pretrained cc.<lang>.300 vectors once and keep the loaded
        # model on the instance so get_embedding() can use it.
        fasttext.util.download_model(lang_id, if_exists='ignore')
        self._model = fasttext.load_model(f'cc.{lang_id}.300.bin')

    def get_embedding(self, word: str) -> np.ndarray:
        return self._model.get_word_vector(word)

    def get_embeddings(self, words: List[str]) -> Dict[str, np.ndarray]:
        return {it: self.get_embedding(it) for it in words}
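A minimal usage sketch, assuming the pretrained vectors for the language can be downloaded (the cc.*.300.bin files run to several GB); the cosine similarity is illustrative and not part of the class:

    transformer = FasttextEmbeddingTransformer('de')
    emb = transformer.get_embeddings(['hund', 'katze'])
    # Cosine similarity between the two 300-dimensional word vectors.
    cos = float(np.dot(emb['hund'], emb['katze'])
                / (np.linalg.norm(emb['hund']) * np.linalg.norm(emb['katze'])))
    print(f'cosine(hund, katze) = {cos:.3f}')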
@@ -11,25 +11,21 @@ class ExperimentManager:
     _experiment_repository: ExperimentRepository
     _record_id_iterator: RecordIdIterator
     _processing_tasks: List[ProcessingTask]
-    _relation_manager_provider: RelationManagerProvider

     def __init__(
         self,
         experiment_repository: ExperimentRepository,
         record_id_iterator: RecordIdIterator,
         processing_tasks: List[ProcessingTask],
-        relation_manager_provider: RelationManagerProvider,
     ):
         self._experiment_repository = experiment_repository
         self._record_id_iterator = record_id_iterator
         self._processing_tasks = processing_tasks
-        self._relation_manager_provider = relation_manager_provider

     def process(self):
         self._experiment_repository.initialise()
         for processing_task in self._processing_tasks:
             processing_task.process(
                 self._record_id_iterator,
-                self._experiment_repository,
-                self._relation_manager_provider,
+                self._experiment_repository
             )
@@ -24,9 +24,16 @@ class MongoExperimentRepository(ExperimentRepository):

     def property_exists(self, record_id: str, property_name: str) -> bool:
         database = self._get_database()
         all_collections = database.list_collection_names()
+        print(property_name, all_collections)
         if property_name not in all_collections:
+            print('collection not found')
             return False
         else:
+            print('self.get_all_record_ids_for_property(property_name)', record_id, record_id.__class__,
+                  # record_id in self.get_all_record_ids_for_property(property_name),
+                  # len(self.get_all_record_ids_for_property(property_name)),
+                  list(self.get_all_record_ids_for_property(property_name))[0],
+                  list(self.get_all_record_ids_for_property(property_name))[0].__class__)
             return database[property_name].find_one({ID: record_id}) is not None

     def update_property_for_key(self, record_id: str, property_name: str, property_value: Any):
+from pprint import pprint
 from typing import Any, Dict, List

 from sziszapangma.core.alignment.alignment_classic_calculator import AlignmentClassicCalculator
@@ -47,9 +48,10 @@ class ClassicWerMetricTask(ProcessingTask):
         self,
         record_id: str,
         experiment_repository: ExperimentRepository,
-        relation_manager: RelationManager,
     ):
-        gold_transcript = TaskUtil.get_words_from_record(relation_manager)
+        print('#############')
+        gold_transcript = experiment_repository.get_property_for_key(record_id, self._gold_transcript_property_name)
+        print('$$$$$$$$$$$$$', gold_transcript)
         asr_result = experiment_repository.get_property_for_key(record_id, self._asr_property_name)
         if gold_transcript is not None and asr_result is not None and "transcription" in asr_result:
             alignment_steps = self._get_alignment(gold_transcript, asr_result["transcription"])