Newer
Older
from typing import Dict, Any
from whisper import Whisper
from experiment.experiment_dependency_provider import get_repository, get_record_provider
from sziszapangma.integration.asr_processor import AsrWebClient, AsrProcessor
from sziszapangma.integration.experiment_manager import ExperimentManager
from sziszapangma.integration.task.asr_task import AsrTask
import whisper
class WhisperAsrProcessor(AsrProcessor):
    """ASR processor that transcribes audio files with a locally loaded Whisper model.

    The model is loaded once in the constructor and reused by every
    ``call_recognise`` call on the same instance.
    """

    _whisper: Whisper

    def __init__(self, model_name: str = "tiny"):
        """Load the Whisper model used for transcription.

        Args:
            model_name: Model identifier forwarded to ``whisper.load_model``.
                Defaults to ``"tiny"``, the value that was previously
                hard-coded, so ``WhisperAsrProcessor()`` behaves as before.
        """
        self._whisper = whisper.load_model(model_name, in_memory=True)
        # Report the device the loaded model reports for itself.
        print(self._whisper.device)

    def call_recognise(self, file_path: str) -> Dict[str, Any]:
        """Transcribe a single audio file.

        Args:
            file_path: Path handed directly to ``Whisper.transcribe``.

        Returns:
            A dict with the keys:

            * ``transcription`` - the recognised text split on whitespace,
            * ``full_text`` - the recognised text as returned by the model,
            * ``words_time_alignment`` - always ``None`` in this processor,
            * ``language`` - the ``language`` entry of the model result,
            * ``segments`` - the ``segments`` entry of the model result.
        """
        result = self._whisper.transcribe(file_path)
        full_text = result['text']
        return {
            "transcription": full_text.split(),
            "full_text": full_text,
            "words_time_alignment": None,
            "language": result['language'],
            "segments": result['segments'],
        }
def get_asr_client(asr_name: str) -> AsrProcessor:
    """Create the ASR processor registered under ``asr_name``.

    Args:
        asr_name: Name of the ASR backend; one of ``'ajn'``, ``'wav2vec2'``
            or ``'whisper'``.

    Returns:
        An ``AsrWebClient`` bound to the matching local endpoint for
        ``'ajn'`` and ``'wav2vec2'``, or a new ``WhisperAsrProcessor`` for
        ``'whisper'``.

    Raises:
        ValueError: If ``asr_name`` is not one of the supported names.
    """
    if asr_name == 'ajn':
        return AsrWebClient('http://localhost:5431/process_asr', '__example_token__')
    if asr_name == 'wav2vec2':
        return AsrWebClient('http://localhost:5437/process_asr', '__example_token__')
    if asr_name == 'whisper':
        return WhisperAsrProcessor()
    # ValueError subclasses Exception, so callers that caught the previous
    # bare Exception keep working while now getting a useful message.
    raise ValueError(
        f"Unknown ASR name {asr_name!r}; expected one of 'ajn', 'wav2vec2', 'whisper'"
    )
def run_asr_pipeline(dataset_name: str, asr_name: str):
    """Run a single ASR task over a dataset.

    Builds one ``AsrTask`` for the ``asr_name`` processor, wraps it in an
    ``ExperimentManager`` configured for ``dataset_name`` and calls its
    ``process()`` method.

    Args:
        dataset_name: Name passed to ``get_record_provider`` and
            ``get_repository``; also embedded in the task name.
        asr_name: Name passed to ``get_asr_client``; also used to build the
            task name and the ``<asr_name>__result`` property name.
    """
    # The same provider object is used as record id iterator, record path
    # provider and relation manager provider.
    provider = get_record_provider(dataset_name)

    asr_task = AsrTask(
        task_name=f'AsrTask___{dataset_name}___{asr_name}',
        asr_processor=get_asr_client(asr_name),
        asr_property_name=f'{asr_name}__result',
        require_update=False,
        record_path_provider=provider,
    )
    repository = get_repository(dataset_name)

    manager = ExperimentManager(
        record_id_iterator=provider,
        processing_tasks=[asr_task],
        experiment_repository=repository,
        relation_manager_provider=provider,
    )
    manager.process()
if __name__ == '__main__':
    # Datasets processed with the Whisper processor, in this order.
    datasets = (
        # 'pl_common_voice',  # left disabled, as in the original script
        'pl_google_fleurs',
        'pl_luna',
        'pl_minds14',
        'pl_voicelab_cbiz',
    )
    for dataset in datasets:
        run_asr_pipeline(dataset, 'whisper')