Skip to content
Snippets Groups Projects
hf_dataset_importer.py 2.53 KiB
Newer Older
import datetime
from abc import ABC, abstractmethod
from hashlib import sha1
from pathlib import Path
from typing import List, Any, Dict

import numpy as np
from datasets import Dataset
from minio import Minio

from new_experiment.utils.minio_audio_record_repository import MinioAudioRecordRepository
from new_experiment.utils.property_helper import PropertyHelper
from sziszapangma.integration.repository.experiment_repository import ExperimentRepository
from sziszapangma.model.model_creators import create_new_word


class HfDatasetImporter(ABC):
    """Abstract base for importing dataset records into an experiment repository.

    Subclasses supply the four extraction hooks (``get_words``,
    ``get_raw_transcription``, ``get_audio_file`` and ``get_record_id``);
    this class drives the iteration over the dataset and writes two
    properties per record through the experiment repository.
    """

    # Repository that receives the per-record property updates.
    _experiment_repository: ExperimentRepository
    # Audio repository; stored here, but the upload call is currently disabled.
    _minio_audio_record_repository: MinioAudioRecordRepository
    # Dataset name that the (currently disabled) audio upload would pass along.
    _experiment_dataset_name: str

    def __init__(
            self,
            experiment_repository: ExperimentRepository,
            minio_audio_record_repository: MinioAudioRecordRepository,
            experiment_dataset_name: str):
        """Store the collaborators and the dataset name on the instance."""
        self._experiment_dataset_name = experiment_dataset_name
        self._minio_audio_record_repository = minio_audio_record_repository
        self._experiment_repository = experiment_repository

    @abstractmethod
    def get_words(self, record: Dict[str, Any]) -> List[str]:
        """Return the list of words extracted from ``record``."""

    @abstractmethod
    def get_raw_transcription(self, record: Dict[str, Any]) -> str:
        """Return the raw transcription text extracted from ``record``."""

    @abstractmethod
    def get_audio_file(self, record: Dict[str, Any]) -> Path:
        """Return the path of the audio file belonging to ``record``."""

    @abstractmethod
    def get_record_id(self, record: Dict[str, Any]) -> str:
        """Return the identifier under which ``record`` is stored."""

    def process_dataset(self, dataset: Dataset):
        """Process every item of ``dataset`` in order, printing a progress line for each.

        The progress line carries an ISO timestamp, the 1-based position of
        the item and the item itself; it is printed before the item is
        handed to ``process_record``.
        """
        for position, item in enumerate(dataset, start=1):
            timestamp = datetime.datetime.now().isoformat()
            print(timestamp, f'process_dataset item {position} {item}')
            self.process_record(item)

    def process_record(self, record: Dict[str, Any]):
        """Extract the fields of one record and write them to the experiment repository.

        Two properties are updated for the record id: the word list (each
        word passed through ``create_new_word``) and a dict wrapping the raw
        transcription under the ``'gold_transcript_raw'`` key.
        """
        # The hooks are invoked in a fixed order: id, words, transcription, audio.
        record_id = self.get_record_id(record)
        gold_words = list(map(create_new_word, self.get_words(record)))
        transcription = self.get_raw_transcription(record)
        # The audio hook is still called, but its result is unused while the
        # upload below remains commented out.
        audio_file_path = self.get_audio_file(record)

        repository = self._experiment_repository
        repository.update_property_for_key(
            record_id=record_id,
            property_name=PropertyHelper.get_gold_transcript_words(),
            property_value=gold_words,
        )
        repository.update_property_for_key(
            record_id=record_id,
            property_name=PropertyHelper.get_gold_transcript_raw(),
            property_value={'gold_transcript_raw': transcription},
        )
        # TODO uncomment
        # self._minio_audio_record_repository.save_file(audio_file_path, self._experiment_dataset_name, record_id)