Newer
Older
import os

from datasets import load_dataset
from sziszapangma.integration.gold_transcript_processor import GoldTranscriptProcessor
from sziszapangma.integration.path_filter import ExtensionPathFilter
from sziszapangma.integration.repository.multi_files_experiment_repository import MultiFilesExperimentRepository

from experiment.dataset_helper import DatasetHelper
from experiment.dataset_specific.pl_luna.luna_record_provider import LunaRecordProvider
from experiment.dataset_specific.pl_luna.pipeline.task.luna_gold_transcript_processor import LunaGoldTranscriptProcessor
from experiment.dataset_specific.pl_voicelab_cbiz.voicelab_gold_transcript_processor import \
    VoicelabGoldTranscriptProcessor
from experiment.dataset_specific.pl_voicelab_cbiz.voicelab_telco_record_provider import VoicelabTelcoRecordProvider
from experiment.hf_dataset_helper.hf_gold_transcript_processor import HfGoldTranscriptProcessor
from experiment.hf_dataset_helper.hf_record_provider import HfRecordProvider
# Root directory handed to MultiFilesExperimentRepository by get_repository().
PIPELINE_DATA_DIRECTORY = 'experiment_data/pipeline'

# Hugging Face access token passed to load_dataset() for 'pl_common_voice'.
# Prefer supplying it through the HF_ACCESS_TOKEN environment variable.
# NOTE(review): the literal fallback below is a credential committed to source
# control; it is kept only so existing setups keep working. It should be
# revoked on the Hugging Face side and removed from this file.
HF_ACCESS_TOKEN = os.environ.get('HF_ACCESS_TOKEN') or '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'
def get_repository(dataset_name: str) -> MultiFilesExperimentRepository:
    """Create the experiment repository for the given dataset.

    Args:
        dataset_name: Name of the dataset, forwarded unchanged to the repository.

    Returns:
        A ``MultiFilesExperimentRepository`` constructed from
        ``PIPELINE_DATA_DIRECTORY`` and ``dataset_name``.
    """
    repository = MultiFilesExperimentRepository(PIPELINE_DATA_DIRECTORY, dataset_name)
    return repository
def get_record_provider(dataset_name: str) -> DatasetHelper:
    """Build the record provider for one of the supported datasets.

    Args:
        dataset_name: One of ``'pl_common_voice'``, ``'pl_minds14'``,
            ``'pl_google_fleurs'``, ``'pl_luna'`` or ``'pl_voicelab_cbiz'``.

    Returns:
        The provider constructed for that dataset: an ``HfRecordProvider`` for
        the Hugging Face datasets, a ``LunaRecordProvider`` for ``'pl_luna'``
        and a ``VoicelabTelcoRecordProvider`` for ``'pl_voicelab_cbiz'``.

    Raises:
        ValueError: If ``dataset_name`` is not one of the names listed above.
    """
    # Every branch uses a relation-manager path named after the dataset itself.
    relation_manager_root = f'experiment_data/dataset_relation_manager_data/{dataset_name}'

    if dataset_name == 'pl_common_voice':
        # This is the only dataset loaded with the access token.
        dataset = load_dataset(
            "mozilla-foundation/common_voice_9_0", "pl", use_auth_token=HF_ACCESS_TOKEN
        )['test']
        return HfRecordProvider(dataset, relation_manager_root)

    if dataset_name == 'pl_minds14':
        dataset = load_dataset("PolyAI/minds14", "pl-PL")['train']
        # NOTE(review): meaning of the extra positional argument (2) is defined
        # by HfRecordProvider, which is not visible here -- confirm there.
        return HfRecordProvider(dataset, relation_manager_root, 2)

    if dataset_name == 'pl_google_fleurs':
        dataset = load_dataset("google/fleurs", "pl_pl")['test']
        # NOTE(review): meaning of the extra positional arguments (1, '') is
        # defined by HfRecordProvider, which is not visible here -- confirm there.
        return HfRecordProvider(dataset, relation_manager_root, 1, '')

    if dataset_name == 'pl_luna':
        return LunaRecordProvider(
            ExtensionPathFilter(
                root_directory='experiment_data/dataset/pl_luna/LUNA.PL',
                extension='wav'
            ),
            relation_manager_root_path=relation_manager_root
        )

    if dataset_name == 'pl_voicelab_cbiz':
        return VoicelabTelcoRecordProvider(
            ExtensionPathFilter(
                root_directory='experiment_data/dataset/pl_voicelab_cbiz',
                extension='wav'
            ),
            relation_manager_root_path=relation_manager_root
        )

    # ValueError is still an Exception, so existing broad handlers keep working,
    # and the message now says which name was rejected.
    raise ValueError(f'dataset not found: {dataset_name!r}')
def get_gold_transcript_processor(dataset_name: str, dataset_helper: DatasetHelper) -> GoldTranscriptProcessor:
    """Build the gold-transcript processor matching a supported dataset.

    Args:
        dataset_name: One of ``'pl_common_voice'``, ``'pl_minds14'``,
            ``'pl_google_fleurs'``, ``'pl_luna'`` or ``'pl_voicelab_cbiz'``.
        dataset_helper: Helper passed unchanged to the processor's constructor.

    Returns:
        An ``HfGoldTranscriptProcessor`` for the three Hugging Face datasets, a
        ``LunaGoldTranscriptProcessor`` for ``'pl_luna'`` or a
        ``VoicelabGoldTranscriptProcessor`` for ``'pl_voicelab_cbiz'``.

    Raises:
        ValueError: If ``dataset_name`` is not one of the names listed above.
    """
    if dataset_name in ('pl_common_voice', 'pl_minds14', 'pl_google_fleurs'):
        return HfGoldTranscriptProcessor(dataset_helper)
    if dataset_name == 'pl_luna':
        return LunaGoldTranscriptProcessor(dataset_helper)
    if dataset_name == 'pl_voicelab_cbiz':
        return VoicelabGoldTranscriptProcessor(dataset_helper)
    # ValueError is still an Exception, so existing broad handlers keep working,
    # and the message now says which name was rejected.
    raise ValueError(f'dataset not found: {dataset_name!r}')