Skip to content
Snippets Groups Projects
pipeline_process_asr.py 2.66 KiB
from new_experiment.hf_asr.wav2vec2_hf import Wav2Vec2AsrProcessor
from new_experiment.new_dependency_provider import get_experiment_repository, get_minio_audio_record_repository
from new_experiment.utils.loaded_remote_dataset_helper import LoadedRemoteDatasetHelper
from new_experiment.utils.property_helper import PropertyHelper
from sziszapangma.integration.asr_processor import AsrProcessor
from sziszapangma.integration.experiment_manager import ExperimentManager
from sziszapangma.integration.task.asr_task import AsrTask


def get_asr_processor(asr_name: str) -> AsrProcessor:
    """Build the wav2vec2 ASR processor registered under ``asr_name``.

    A new ``Wav2Vec2AsrProcessor`` is constructed on every successful call.

    :param asr_name: one of the supported names listed in the table below.
    :return: a ``Wav2Vec2AsrProcessor`` created with the matching model string.
    :raises Exception: when ``asr_name`` matches none of the supported names.
    """
    # (supported name, string handed to Wav2Vec2AsrProcessor) pairs.
    # NOTE(review): the second element looks like a Hugging Face Hub model id
    # -- confirm against Wav2Vec2AsrProcessor.
    supported_models = (
        ('facebook_wav2vec2_large_xlsr_53_dutch', 'facebook/wav2vec2-large-xlsr-53-dutch'),
        ('facebook_wav2vec2_large_960h_lv60_self', 'facebook/wav2vec2-large-960h-lv60-self'),
        ('facebook_wav2vec2_large_xlsr_53_french', 'facebook/wav2vec2-large-xlsr-53-french'),
        ('facebook_wav2vec2_large_xlsr_53_german', 'facebook/wav2vec2-large-xlsr-53-german'),
        ('facebook_wav2vec2_large_xlsr_53_italian', 'facebook/wav2vec2-large-xlsr-53-italian'),
        ('facebook_wav2vec2_large_xlsr_53_polish', 'facebook/wav2vec2-large-xlsr-53-polish'),
        ('facebook_wav2vec2_large_xlsr_53_spanish', 'facebook/wav2vec2-large-xlsr-53-spanish'),
    )
    for known_name, model_id in supported_models:
        if asr_name == known_name:
            return Wav2Vec2AsrProcessor(model_id)
    raise Exception(f'AsrProcessor not found for name: {asr_name}')


def run_hf_facebook_wav2vec2_asr_task(dataset_name: str, asr_name: str):
    """Run a single ``AsrTask`` over ``dataset_name`` using the ``asr_name`` model.

    The experiment repository is obtained for ``dataset_name`` and wrapped,
    together with the MinIO audio record repository, in a
    ``LoadedRemoteDatasetHelper``. That helper is used both as the record id
    iterator of the ``ExperimentManager`` and as the record path provider of
    the task.

    :param dataset_name: dataset identifier passed to the repository factory
        and the dataset helper, and embedded in the task name.
    :param asr_name: supported model name (see ``get_asr_processor``); also
        used to derive the result property name and the task name.
    :raises Exception: propagated from ``get_asr_processor`` when ``asr_name``
        is not supported.
    """
    experiment_repository = get_experiment_repository(dataset_name)
    audio_records = get_minio_audio_record_repository()
    dataset_helper = LoadedRemoteDatasetHelper(experiment_repository, audio_records, dataset_name)

    asr_task = AsrTask(
        asr_property_name=PropertyHelper.asr_result(asr_name),
        task_name=f'AsrTask___{dataset_name}___{asr_name}',
        # NOTE(review): presumably keeps already-computed results instead of
        # recomputing them -- confirm against AsrTask.
        require_update=False,
        asr_processor=get_asr_processor(asr_name),
        record_path_provider=dataset_helper
    )

    ExperimentManager(
        record_id_iterator=dataset_helper,
        processing_tasks=[asr_task],
        experiment_repository=experiment_repository,
    ).process()

# Example invocation (kept disabled):
# if __name__ == '__main__':
#     run_hf_facebook_wav2vec2_asr_task('nl_minds14', 'facebook_wav2vec2_large_xlsr_53_dutch')
#     run_hf_facebook_wav2vec2_asr_task('nl_google_fleurs', 'facebook_wav2vec2_large_xlsr_53_dutch')
#     run_hf_facebook_wav2vec2_asr_task('nl_voxpopuli', 'facebook_wav2vec2_large_xlsr_53_dutch')