"""Run ASR pipelines (local Whisper or remote web ASR services) over speech datasets."""
from typing import Any, Dict

import whisper
from whisper import Whisper

from experiment.experiment_dependency_provider import get_record_provider, get_repository
from sziszapangma.integration.asr_processor import AsrProcessor, AsrWebClient
from sziszapangma.integration.experiment_manager import ExperimentManager
from sziszapangma.integration.task.asr_task import AsrTask


class WhisperAsrProcessor(AsrProcessor):
    """AsrProcessor backed by a locally loaded OpenAI Whisper model."""

    _whisper: Whisper

    def __init__(self):
        # "tiny" keeps memory/latency low; in_memory=True avoids re-reading
        # the checkpoint from disk on later use.
        self._whisper = whisper.load_model("tiny", in_memory=True)
        print(self._whisper.device)

    def call_recognise(self, file_path: str) -> Dict[str, Any]:
        """Transcribe the audio file at *file_path* and return a result dict.

        Keys: ``transcription`` (whitespace-split tokens), ``full_text``,
        ``words_time_alignment`` (always None — Whisper gives segment-level,
        not word-level, timing here), ``language`` and ``segments`` as
        reported by Whisper.
        """
        result = self._whisper.transcribe(file_path)
        full_text = result['text']
        return {
            "transcription": full_text.split(),
            "full_text": full_text,
            "words_time_alignment": None,
            "language": result['language'],
            "segments": result['segments'],
        }


def get_asr_client(asr_name: str) -> AsrProcessor:
    """Return the ASR processor registered under *asr_name*.

    ``ajn`` and ``wav2vec2`` are remote web services; ``whisper`` is a local
    model. Raises ValueError for an unknown name.
    """
    if asr_name == 'ajn':
        return AsrWebClient('http://localhost:5431/process_asr', '__example_token__')
    elif asr_name == 'wav2vec2':
        return AsrWebClient('http://localhost:5437/process_asr', '__example_token__')
    elif asr_name == 'whisper':
        return WhisperAsrProcessor()
    else:
        # Was `raise Exception` (bare class, no message) — an informative
        # ValueError still satisfies callers catching Exception.
        raise ValueError(f"Unknown ASR name: {asr_name!r}; "
                         f"expected one of 'ajn', 'wav2vec2', 'whisper'")


def run_asr_pipeline(dataset_name: str, asr_name: str):
    """Run the ASR task for every record of *dataset_name* using *asr_name*.

    Results are stored under the ``{asr_name}__result`` property; records
    already processed are skipped (require_update=False).
    """
    record_provider = get_record_provider(dataset_name)
    experiment_processor = ExperimentManager(
        record_id_iterator=record_provider,
        processing_tasks=[
            AsrTask(
                task_name=f'AsrTask___{dataset_name}___{asr_name}',
                asr_processor=get_asr_client(asr_name),
                asr_property_name=f'{asr_name}__result',
                require_update=False,
                record_path_provider=record_provider
            )
        ],
        experiment_repository=get_repository(dataset_name),
        relation_manager_provider=record_provider
    )
    experiment_processor.process()


if __name__ == '__main__':
    # run_asr_pipeline('pl_common_voice', 'whisper')
    run_asr_pipeline('pl_google_fleurs', 'whisper')
    run_asr_pipeline('pl_luna', 'whisper')
    run_asr_pipeline('pl_minds14', 'whisper')
    run_asr_pipeline('pl_voicelab_cbiz', 'whisper')