Skip to content
Snippets Groups Projects
pipeline_process_asr_missing.py 2.24 KiB
Newer Older
from typing import Dict, Any

from whisper import Whisper
Marcin Wątroba's avatar
Marcin Wątroba committed

from experiment.experiment_dependency_provider import get_repository, get_record_provider
from sziszapangma.integration.asr_processor import AsrWebClient, AsrProcessor
Marcin Wątroba's avatar
Marcin Wątroba committed
from sziszapangma.integration.experiment_manager import ExperimentManager
from sziszapangma.integration.task.asr_task import AsrTask
import whisper


class WhisperAsrProcessor(AsrProcessor):
    """``AsrProcessor`` that transcribes audio locally with an OpenAI Whisper model.

    The model is loaded once in ``__init__`` and reused by every
    ``call_recognise`` call.
    """

    _whisper: Whisper
    _language: "str | None"

    def __init__(self, model_name: str = "tiny", language: "str | None" = None):
        """Load a Whisper checkpoint into memory.

        :param model_name: checkpoint name passed to ``whisper.load_model``;
            defaults to ``"tiny"``, the size this class previously hard-coded.
        :param language: optional language code forwarded to ``transcribe``;
            ``None`` (the default) leaves language selection to Whisper, as before.
        """
        self._whisper = whisper.load_model(model_name, in_memory=True)
        self._language = language
        # Debug output: the device the model was loaded on.
        print(self._whisper.device)

    def call_recognise(self, file_path: str) -> Dict[str, Any]:
        """Transcribe a single audio file.

        :param file_path: path of the audio file handed to ``transcribe``.
        :return: dict with the whitespace-split words (``transcription``), the
            raw text (``full_text``), ``words_time_alignment`` (always ``None``;
            no word timings are produced here), the ``language`` reported by
            Whisper and its raw ``segments``.
        """
        decode_options: Dict[str, Any] = {}
        # Only pass a language when one was configured, so the default call is
        # identical to the original ``transcribe(file_path)``.
        if self._language is not None:
            decode_options["language"] = self._language
        result = self._whisper.transcribe(file_path, **decode_options)
        text = result['text']
        return {
            "transcription": text.split(),
            "full_text": text,
            "words_time_alignment": None,
            "language": result['language'],
            "segments": result['segments'],
        }
def get_asr_client(asr_name: str) -> AsrProcessor:
    """Build the ASR processor registered under ``asr_name``.

    ``'ajn'`` and ``'wav2vec2'`` are ``AsrWebClient`` instances pointing at
    ``http://localhost`` ports 5431 and 5437; ``'whisper'`` is a local
    ``WhisperAsrProcessor``.

    :param asr_name: one of ``'ajn'``, ``'wav2vec2'``, ``'whisper'``.
    :raises ValueError: if ``asr_name`` is not one of the names above.
    """
    if asr_name == 'ajn':
        return AsrWebClient('http://localhost:5431/process_asr', '__example_token__')
    if asr_name == 'wav2vec2':
        return AsrWebClient('http://localhost:5437/process_asr', '__example_token__')
    if asr_name == 'whisper':
        return WhisperAsrProcessor()
    # ValueError subclasses Exception, so callers catching Exception still work,
    # but the message now says which name was rejected.
    raise ValueError(f"Unknown ASR name: {asr_name!r}")


def run_asr_pipeline(dataset_name: str, asr_name: str):
    """Run the ASR system ``asr_name`` over the records of ``dataset_name``.

    A single ``AsrTask`` is executed by an ``ExperimentManager`` whose records
    come from the dataset's record provider. The task writes under the
    property name ``'<asr_name>__result'`` (presumably into the repository
    from ``get_repository`` -- confirm in ``AsrTask``/``ExperimentManager``).
    """
    provider = get_record_provider(dataset_name)
    asr_task = AsrTask(
        task_name=f'AsrTask___{dataset_name}___{asr_name}',
        asr_processor=get_asr_client(asr_name),
        asr_property_name=f'{asr_name}__result',
        require_update=False,
        record_path_provider=provider,
    )
    manager = ExperimentManager(
        record_id_iterator=provider,
        processing_tasks=[asr_task],
        experiment_repository=get_repository(dataset_name),
        relation_manager_provider=provider,
    )
    manager.process()


if __name__ == '__main__':
    # 'pl_common_voice' is currently disabled; add it to the tuple to include it.
    for dataset in ('pl_google_fleurs', 'pl_luna', 'pl_minds14', 'pl_voicelab_cbiz'):
        run_asr_pipeline(dataset, 'whisper')