"""Factories that map a dataset name to its repository, record provider and
gold-transcript processor for the experiment pipeline."""
import os
from typing import Optional

from datasets import load_dataset

from experiment.dataset_helper import DatasetHelper
from experiment.dataset_specific.pl_luna.luna_record_provider import LunaRecordProvider
from experiment.dataset_specific.pl_luna.pipeline.task.luna_gold_transcript_processor import (
    LunaGoldTranscriptProcessor,
)
from experiment.dataset_specific.pl_voicelab_cbiz.voicelab_gold_transcript_processor import (
    VoicelabGoldTranscriptProcessor,
)
from experiment.dataset_specific.pl_voicelab_cbiz.voicelab_telco_record_provider import (
    VoicelabTelcoRecordProvider,
)
from experiment.hf_dataset_helper.hf_gold_transcript_processor import HfGoldTranscriptProcessor
from experiment.hf_dataset_helper.hf_record_provider import HfRecordProvider
from sziszapangma.integration.gold_transcript_processor import GoldTranscriptProcessor
from sziszapangma.integration.path_filter import ExtensionPathFilter
from sziszapangma.integration.repository.multi_files_experiment_repository import (
    MultiFilesExperimentRepository,
)

PIPELINE_DATA_DIRECTORY = 'experiment_data/pipeline'

# SECURITY: the access token used to be hard-coded here. Credentials must not
# live in source control, so it is now read from the environment. The token
# that was previously committed should be revoked.
# NOTE(review): when the variable is unset this is None; `load_dataset` is then
# expected to fall back to the locally cached Hugging Face login -- confirm
# against the installed `datasets` version.
HF_ACCESS_TOKEN: Optional[str] = os.environ.get('HF_ACCESS_TOKEN')

# Common parent directory of the per-dataset relation-manager data.
_RELATION_MANAGER_ROOT = 'experiment_data/dataset_relation_manager_data'

# Dataset names that are served by the Hugging Face based helpers.
_HF_DATASET_NAMES = ('pl_common_voice', 'pl_minds14', 'pl_google_fleurs')


def get_repository(dataset_name: str) -> MultiFilesExperimentRepository:
    """Return the experiment repository for ``dataset_name``.

    The repository is constructed with ``PIPELINE_DATA_DIRECTORY`` and the
    dataset name.
    """
    return MultiFilesExperimentRepository(PIPELINE_DATA_DIRECTORY, dataset_name)


def get_record_provider(dataset_name: str) -> DatasetHelper:
    """Build the record provider for one of the supported datasets.

    Supported names: ``pl_common_voice``, ``pl_minds14``, ``pl_google_fleurs``,
    ``pl_luna`` and ``pl_voicelab_cbiz``.

    Raises:
        ValueError: if ``dataset_name`` is not one of the supported names.
    """
    relation_manager_path = f'{_RELATION_MANAGER_ROOT}/{dataset_name}'

    if dataset_name == 'pl_common_voice':
        # This is the only dataset loaded with an explicit access token.
        dataset = load_dataset(
            "mozilla-foundation/common_voice_9_0",
            "pl",
            use_auth_token=HF_ACCESS_TOKEN,
        )['test']
        return HfRecordProvider(dataset, relation_manager_path)

    if dataset_name == 'pl_minds14':
        dataset = load_dataset("PolyAI/minds14", "pl-PL")['train']
        # NOTE(review): the meaning of the extra positional argument is defined
        # by HfRecordProvider, which is not visible here.
        return HfRecordProvider(dataset, relation_manager_path, 2)

    if dataset_name == 'pl_google_fleurs':
        dataset = load_dataset("google/fleurs", "pl_pl")['test']
        # NOTE(review): see HfRecordProvider for the meaning of `1` and `''`.
        return HfRecordProvider(dataset, relation_manager_path, 1, '')

    if dataset_name == 'pl_luna':
        return LunaRecordProvider(
            ExtensionPathFilter(
                root_directory='experiment_data/dataset/pl_luna/LUNA.PL',
                extension='wav',
            ),
            relation_manager_root_path=relation_manager_path,
        )

    if dataset_name == 'pl_voicelab_cbiz':
        return VoicelabTelcoRecordProvider(
            ExtensionPathFilter(
                root_directory='experiment_data/dataset/pl_voicelab_cbiz',
                extension='wav',
            ),
            relation_manager_root_path=relation_manager_path,
        )

    # ValueError is a subclass of Exception, so existing handlers still match.
    raise ValueError(f'dataset not found: {dataset_name}')


def get_gold_transcript_processor(
    dataset_name: str,
    dataset_helper: DatasetHelper,
) -> GoldTranscriptProcessor:
    """Build the gold-transcript processor matching ``dataset_name``.

    ``dataset_helper`` is passed unchanged to the selected processor's
    constructor.

    Raises:
        ValueError: if ``dataset_name`` is not one of the supported names.
    """
    if dataset_name in _HF_DATASET_NAMES:
        return HfGoldTranscriptProcessor(dataset_helper)
    if dataset_name == 'pl_luna':
        return LunaGoldTranscriptProcessor(dataset_helper)
    if dataset_name == 'pl_voicelab_cbiz':
        return VoicelabGoldTranscriptProcessor(dataset_helper)

    # ValueError is a subclass of Exception, so existing handlers still match.
    raise ValueError(f'dataset not found: {dataset_name}')