# Marcin Wątroba authored 01293fe6
# pipeline_process_asr.py 2.66 KiB
from new_experiment.hf_asr.wav2vec2_hf import Wav2Vec2AsrProcessor
from new_experiment.new_dependency_provider import get_experiment_repository, get_minio_audio_record_repository
from new_experiment.utils.loaded_remote_dataset_helper import LoadedRemoteDatasetHelper
from new_experiment.utils.property_helper import PropertyHelper
from sziszapangma.integration.asr_processor import AsrProcessor
from sziszapangma.integration.experiment_manager import ExperimentManager
from sziszapangma.integration.task.asr_task import AsrTask
def get_asr_processor(asr_name: str) -> AsrProcessor:
    """Build the wav2vec2 ASR processor registered under ``asr_name``.

    ``asr_name`` is the underscore-separated key used by the experiments; it
    is translated to the model identifier string handed to
    ``Wav2Vec2AsrProcessor``.

    Raises:
        Exception: if ``asr_name`` is not one of the known keys.
    """
    # Experiment key -> model identifier passed to Wav2Vec2AsrProcessor.
    model_id_by_name = {
        'facebook_wav2vec2_large_xlsr_53_dutch': 'facebook/wav2vec2-large-xlsr-53-dutch',
        'facebook_wav2vec2_large_960h_lv60_self': 'facebook/wav2vec2-large-960h-lv60-self',
        'facebook_wav2vec2_large_xlsr_53_french': 'facebook/wav2vec2-large-xlsr-53-french',
        'facebook_wav2vec2_large_xlsr_53_german': 'facebook/wav2vec2-large-xlsr-53-german',
        'facebook_wav2vec2_large_xlsr_53_italian': 'facebook/wav2vec2-large-xlsr-53-italian',
        'facebook_wav2vec2_large_xlsr_53_polish': 'facebook/wav2vec2-large-xlsr-53-polish',
        'facebook_wav2vec2_large_xlsr_53_spanish': 'facebook/wav2vec2-large-xlsr-53-spanish',
    }
    model_id = model_id_by_name.get(asr_name)
    if model_id is None:
        raise Exception(f'AsrProcessor not found for name: {asr_name}')
    # The processor is only constructed once the key is known to be valid.
    return Wav2Vec2AsrProcessor(model_id)
def run_hf_facebook_wav2vec2_asr_task(dataset_name: str, asr_name: str):
    """Run a single ``AsrTask`` for ``asr_name`` over ``dataset_name``.

    Builds the experiment repository and the remote dataset helper for the
    dataset, wraps the processor returned by ``get_asr_processor(asr_name)``
    in an ``AsrTask``, and executes it through an ``ExperimentManager``.

    Raises:
        Exception: propagated from ``get_asr_processor`` when ``asr_name``
            is not a known key.
    """
    experiment_repository = get_experiment_repository(dataset_name)
    audio_repository = get_minio_audio_record_repository()
    dataset_helper = LoadedRemoteDatasetHelper(
        experiment_repository,
        audio_repository,
        dataset_name,
    )

    # The same helper object is used both to iterate record ids and to
    # provide record paths to the task.
    asr_task = AsrTask(
        asr_property_name=PropertyHelper.asr_result(asr_name),
        task_name=f'AsrTask___{dataset_name}___{asr_name}',
        require_update=False,
        asr_processor=get_asr_processor(asr_name),
        record_path_provider=dataset_helper,
    )

    manager = ExperimentManager(
        record_id_iterator=dataset_helper,
        processing_tasks=[asr_task],
        experiment_repository=experiment_repository,
    )
    manager.process()
# Example manual invocation (kept commented out):
# if __name__ == '__main__':
#     run_hf_facebook_wav2vec2_asr_task('nl_minds14', 'facebook_wav2vec2_large_xlsr_53_dutch')
#     run_hf_facebook_wav2vec2_asr_task('nl_google_fleurs', 'facebook_wav2vec2_large_xlsr_53_dutch')
#     run_hf_facebook_wav2vec2_asr_task('nl_voxpopuli', 'facebook_wav2vec2_large_xlsr_53_dutch')