# google_cloud_processor_minds.py
import datetime
import os.path

from datasets import load_dataset
from google.cloud.speech_v1p1beta1 import RecognitionConfig
from joblib import Parallel, delayed

import json
import uuid
from pathlib import Path
from typing import Any
from google.protobuf.json_format import MessageToJson

from google.cloud.storage import Blob
import os
from google.cloud import speech
from google.cloud import storage
from scipy.io.wavfile import write
from experiment.luna.pipeline.dependency_provider import get_record_provider

# Hugging Face access token, taken from the HF_TOKEN environment variable
# (None when unset). Never hard-code tokens in source: a committed token is
# effectively public and must be revoked.
ACCESS_TOKEN = os.environ.get('HF_TOKEN')


class GoogleCloudAsrProcessor:
    """Transcribe local audio files with Google Cloud Speech-to-Text.

    Each file is uploaded to a GCS bucket under a unique object name. It is
    then recognized with a long-running request, and the temporary object is
    deleted again.
    """

    bucket_name: str
    storage_client: storage.Client
    speech_client: speech.SpeechClient
    language_code: str

    def __init__(
        self,
        bucket_name: str,
        storage_client: storage.Client,
        speech_client: speech.SpeechClient,
        language_code: str
    ):
        """Store the clients and settings used for every transcription.

        Args:
            bucket_name: Name of the GCS bucket used for temporary uploads.
            storage_client: Client used to upload and delete the audio objects.
            speech_client: Client used to issue recognition requests.
            language_code: Recognition language, e.g. 'pl-PL'.
        """
        self.speech_client = speech_client
        self.storage_client = storage_client
        self.bucket_name = bucket_name
        self.language_code = language_code

    def save_file_in_gcs(self, file_path: str) -> Blob:
        """Upload ``file_path`` to the bucket and return the created blob.

        The object is named '<uuid4>__<basename>'. The random prefix keeps
        uploads of files that share a basename from overwriting each other.
        """
        bucket = self.storage_client.get_bucket(self.bucket_name)
        blob = bucket.blob(f'{uuid.uuid4()}__{Path(file_path).name}')
        blob.upload_from_filename(file_path)
        return blob

    def process_file(self, file_path: str) -> str:
        """Transcribe ``file_path`` and return the recognition response as JSON.

        Word time offsets and automatic punctuation are requested. The uploaded
        GCS object is always deleted, including when recognition fails.

        Returns:
            The full recognition response serialized with ``MessageToJson``.
        """
        blob = self.save_file_in_gcs(file_path)
        try:
            audio = speech.RecognitionAudio(uri=f'gs://{self.bucket_name}/{blob.name}')
            # NOTE(review): encoding / sample_rate_hertz are deliberately left
            # unset. The Speech API can infer them from WAV/FLAC headers;
            # confirm this holds for every input format used.
            config = speech.RecognitionConfig(
                enable_automatic_punctuation=True,
                language_code=self.language_code,
                use_enhanced=True,
                enable_word_time_offsets=True,
            )
            operation = self.speech_client.long_running_recognize(
                config=config,
                audio=audio
            )
            response = operation.result()
        finally:
            # Without this, a failed or timed-out recognition would leave the
            # uploaded audio behind in the bucket.
            blob.delete()
        # `_pb` is the raw protobuf behind the proto-plus wrapper, which is
        # what MessageToJson expects.
        return MessageToJson(response._pb)


def process_single(record: Any, root_result_path: str = 'minds_google_asr') -> None:
    """Transcribe one dataset record and cache the result as a JSON file.

    The output file is '<root_result_path>/<a>__<b>.json', where a and b are
    the last two '/'-separated components of ``record['path']``. Records whose
    output file already exists are skipped, so an interrupted run can resume.

    Args:
        record: Dataset row. Only its 'path' key (a local audio file path) is read.
        root_result_path: Directory the JSON results are written to. The default
            matches the directory the ``__main__`` block creates, and the
            directory must already exist.
    """
    print(record)
    # Respect credentials already configured in the environment instead of
    # overwriting them; fall back to this benchmark's service-account key.
    os.environ.setdefault(
        'GOOGLE_APPLICATION_CREDENTIALS',
        '/home/marcinwatroba/PWR_ASR/asr-benchmarks/secret_key.json',
    )
    record_path = record['path']
    record_id = '__'.join(record_path.split('/')[-2:])
    result_path = os.path.join(root_result_path, f'{record_id}.json')
    if os.path.exists(result_path):
        return
    # Clients are created only for records that still need work.
    processor = GoogleCloudAsrProcessor('asr-benchmarks-data', storage.Client(), speech.SpeechClient(), 'pl-PL')
    result = processor.process_file(record_path)
    with open(result_path, 'w', encoding='utf-8') as f:
        f.write(result)
    print(result)


if __name__ == '__main__':
    # Respect credentials already configured in the environment; fall back to
    # this benchmark's service-account key.
    os.environ.setdefault(
        'GOOGLE_APPLICATION_CREDENTIALS',
        '/home/marcinwatroba/PWR_ASR/asr-benchmarks/secret_key.json',
    )
    # Output directory for the per-record JSON transcripts.
    root_result_path = 'minds_google_asr'
    os.makedirs(root_result_path, exist_ok=True)
    dataset = load_dataset("PolyAI/minds14", "pl-PL")['train']
    print(dataset)
    # 32 concurrent workers. Each one spends most of its time waiting on a
    # long-running Speech API operation.
    Parallel(n_jobs=32)(delayed(process_single)(record) for record in dataset)