from new_experiment.hf_asr.wav2vec2_hf import Wav2Vec2AsrProcessor
from new_experiment.new_dependency_provider import (
    get_experiment_repository,
    get_minio_audio_record_repository,
)
from new_experiment.utils.loaded_remote_dataset_helper import LoadedRemoteDatasetHelper
from new_experiment.utils.property_helper import PropertyHelper
from sziszapangma.integration.asr_processor import AsrProcessor
from sziszapangma.integration.experiment_manager import ExperimentManager
from sziszapangma.integration.task.asr_task import AsrTask


def get_asr_processor(asr_name: str) -> AsrProcessor:
    """Build the wav2vec2 ASR processor registered under ``asr_name``.

    Args:
        asr_name: One of the supported processor names listed in the
            lookup table below.

    Returns:
        A ``Wav2Vec2AsrProcessor`` constructed with the model identifier
        string associated with ``asr_name``.

    Raises:
        Exception: If ``asr_name`` is not one of the supported names.
    """
    # Supported processor name -> model identifier passed to the processor.
    model_id_by_name = {
        'facebook_wav2vec2_large_xlsr_53_dutch': 'facebook/wav2vec2-large-xlsr-53-dutch',
        'facebook_wav2vec2_large_960h_lv60_self': 'facebook/wav2vec2-large-960h-lv60-self',
        'facebook_wav2vec2_large_xlsr_53_french': 'facebook/wav2vec2-large-xlsr-53-french',
        'facebook_wav2vec2_large_xlsr_53_german': 'facebook/wav2vec2-large-xlsr-53-german',
        'facebook_wav2vec2_large_xlsr_53_italian': 'facebook/wav2vec2-large-xlsr-53-italian',
        'facebook_wav2vec2_large_xlsr_53_polish': 'facebook/wav2vec2-large-xlsr-53-polish',
        'facebook_wav2vec2_large_xlsr_53_spanish': 'facebook/wav2vec2-large-xlsr-53-spanish',
    }
    if asr_name not in model_id_by_name:
        raise Exception(f'AsrProcessor not found for name: {asr_name}')
    return Wav2Vec2AsrProcessor(model_id_by_name[asr_name])


def run_hf_facebook_wav2vec2_asr_task(dataset_name: str, asr_name: str):
    """Run a single ``AsrTask`` over the records of ``dataset_name``.

    The experiment repository for the dataset is obtained first, then a
    ``LoadedRemoteDatasetHelper`` is built on top of it; that helper is
    used both as the record id iterator of the ``ExperimentManager`` and
    as the record path provider of the task.

    Args:
        dataset_name: Dataset whose experiment repository is processed.
        asr_name: Processor name, resolved through ``get_asr_processor``.

    Raises:
        Exception: If ``asr_name`` is not supported by
            ``get_asr_processor``.
    """
    experiment_repository = get_experiment_repository(dataset_name)
    dataset_helper = LoadedRemoteDatasetHelper(
        experiment_repository,
        get_minio_audio_record_repository(),
        dataset_name,
    )

    # Resolved in the same order as the keyword arguments they feed.
    result_property = PropertyHelper.asr_result(asr_name)
    task_label = f'AsrTask___{dataset_name}___{asr_name}'
    processor = get_asr_processor(asr_name)

    asr_task = AsrTask(
        asr_property_name=result_property,
        task_name=task_label,
        require_update=False,
        asr_processor=processor,
        record_path_provider=dataset_helper,
    )
    manager = ExperimentManager(
        record_id_iterator=dataset_helper,
        processing_tasks=[asr_task],
        experiment_repository=experiment_repository,
    )
    manager.process()


# NOTE(review): the disabled example below calls `run_spacy_dep_tag_wer_pipeline`,
# which is not defined in the visible part of this module - presumably a
# leftover from another script; confirm before re-enabling.
# if __name__ == '__main__':
#     run_spacy_dep_tag_wer_pipeline('nl_minds14', 'facebook_wav2vec2_large_xlsr_53_dutch')
#     run_spacy_dep_tag_wer_pipeline('nl_google_fleurs', 'facebook_wav2vec2_large_xlsr_53_dutch')
#     run_spacy_dep_tag_wer_pipeline('nl_voxpopuli', 'facebook_wav2vec2_large_xlsr_53_dutch')