# experiment_dependency_provider.py (3.2 KiB)
# Copied from the repository web view; last commit by Marcin Wątroba.
import os

from datasets import load_dataset

from experiment.dataset_helper import DatasetHelper
from experiment.dataset_specific.pl_luna.luna_record_provider import LunaRecordProvider
from experiment.dataset_specific.pl_voicelab_cbiz.voicelab_gold_transcript_processor import \
    VoicelabGoldTranscriptProcessor
from experiment.dataset_specific.pl_voicelab_cbiz.voicelab_telco_record_provider import VoicelabTelcoRecordProvider
from experiment.hf_dataset_helper.hf_gold_transcript_processor import HfGoldTranscriptProcessor
from experiment.hf_dataset_helper.hf_record_provider import HfRecordProvider
from experiment.dataset_specific.pl_luna.pipeline.task.luna_gold_transcript_processor import LunaGoldTranscriptProcessor
from sziszapangma.integration.gold_transcript_processor import GoldTranscriptProcessor
from sziszapangma.integration.path_filter import ExtensionPathFilter
from sziszapangma.integration.repository.multi_files_experiment_repository import MultiFilesExperimentRepository

# Root directory handed to MultiFilesExperimentRepository by get_repository().
PIPELINE_DATA_DIRECTORY = 'experiment_data/pipeline'

# Hugging Face access token used when loading the Common Voice dataset.
# SECURITY: a credential committed to source control should be treated as
# leaked. Prefer supplying the token through the HF_ACCESS_TOKEN environment
# variable; the literal below is kept only as a backward-compatible fallback
# and should be revoked and removed.
# `or` (rather than a .get() default) also treats an empty variable as unset.
HF_ACCESS_TOKEN = os.environ.get('HF_ACCESS_TOKEN') or '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'


def get_repository(dataset_name: str) -> MultiFilesExperimentRepository:
    """Build the experiment repository for one dataset.

    :param dataset_name: name of the dataset, forwarded to the repository
        as its second constructor argument.
    :return: a ``MultiFilesExperimentRepository`` constructed with
        ``PIPELINE_DATA_DIRECTORY`` and ``dataset_name``.
    """
    repository = MultiFilesExperimentRepository(
        PIPELINE_DATA_DIRECTORY,
        dataset_name,
    )
    return repository


def get_record_provider(dataset_name: str) -> DatasetHelper:
    """Create the record provider for the dataset called ``dataset_name``.

    Supported names and what is built for them:

    * ``'pl_common_voice'`` -- ``HfRecordProvider`` over the ``'test'`` split
      of ``mozilla-foundation/common_voice_9_0`` (config ``pl``), loaded with
      ``HF_ACCESS_TOKEN``.
    * ``'pl_minds14'`` -- ``HfRecordProvider`` over the ``'train'`` split of
      ``PolyAI/minds14`` (config ``pl-PL``).
    * ``'pl_google_fleurs'`` -- ``HfRecordProvider`` over the ``'test'`` split
      of ``google/fleurs`` (config ``pl_pl``).
    * ``'pl_luna'`` -- ``LunaRecordProvider`` over the ``wav`` files under
      ``experiment_data/dataset/pl_luna/LUNA.PL``.
    * ``'pl_voicelab_cbiz'`` -- ``VoicelabTelcoRecordProvider`` over the
      ``wav`` files under ``experiment_data/dataset/pl_voicelab_cbiz``.

    Every provider also receives its own directory below
    ``experiment_data/dataset_relation_manager_data``.

    :param dataset_name: one of the dataset names listed above.
    :return: the record provider for that dataset.
    :raises Exception: if ``dataset_name`` is not one of the supported names.
    """
    if dataset_name == 'pl_common_voice':
        dataset = load_dataset("mozilla-foundation/common_voice_9_0", "pl", use_auth_token=HF_ACCESS_TOKEN)['test']
        return HfRecordProvider(dataset, 'experiment_data/dataset_relation_manager_data/pl_common_voice')
    elif dataset_name == 'pl_minds14':
        dataset = load_dataset("PolyAI/minds14", "pl-PL")['train']
        # NOTE(review): the trailing positional argument (2) is interpreted by
        # HfRecordProvider; its meaning is not visible here -- confirm there.
        return HfRecordProvider(dataset, 'experiment_data/dataset_relation_manager_data/pl_minds14', 2)
    elif dataset_name == 'pl_google_fleurs':
        # NOTE(review): the trailing positional arguments (1, '') are
        # interpreted by HfRecordProvider -- confirm their meaning there.
        return HfRecordProvider(
            load_dataset("google/fleurs", "pl_pl")['test'],
            'experiment_data/dataset_relation_manager_data/pl_google_fleurs', 1, ''
        )
    elif dataset_name == 'pl_luna':
        return LunaRecordProvider(
            ExtensionPathFilter(
                root_directory='experiment_data/dataset/pl_luna/LUNA.PL',
                extension='wav'
            ),
            relation_manager_root_path='experiment_data/dataset_relation_manager_data/pl_luna'
        )
    elif dataset_name == 'pl_voicelab_cbiz':
        return VoicelabTelcoRecordProvider(
            ExtensionPathFilter(
                root_directory='experiment_data/dataset/pl_voicelab_cbiz',
                extension='wav'
            ),
            relation_manager_root_path='experiment_data/dataset_relation_manager_data/pl_voicelab_cbiz'
        )
    else:
        raise Exception('dataset not found')


def get_gold_transcript_processor(dataset_name: str, dataset_helper: DatasetHelper) -> GoldTranscriptProcessor:
    """Pick the gold-transcript processor that matches ``dataset_name``.

    :param dataset_name: name of the dataset being processed.
    :param dataset_helper: helper passed to the chosen processor's constructor.
    :return: ``HfGoldTranscriptProcessor`` for the three Hugging Face datasets,
        ``LunaGoldTranscriptProcessor`` for ``'pl_luna'`` and
        ``VoicelabGoldTranscriptProcessor`` for ``'pl_voicelab_cbiz'``.
    :raises Exception: if ``dataset_name`` is none of the above.
    """
    hf_dataset_names = ['pl_common_voice', 'pl_minds14', 'pl_google_fleurs']

    if dataset_name in hf_dataset_names:
        return HfGoldTranscriptProcessor(dataset_helper)
    if dataset_name == 'pl_luna':
        return LunaGoldTranscriptProcessor(dataset_helper)
    if dataset_name == 'pl_voicelab_cbiz':
        return VoicelabGoldTranscriptProcessor(dataset_helper)

    # No branch matched: the dataset name is not supported.
    raise Exception('dataset not found')