Skip to content
Snippets Groups Projects
Select Git revision
  • master
  • develop
  • python2.7
3 results

FindMorfeusz2.cmake

Blame
  • google_cloud_processor.py 3.34 KiB
    import datetime
    import os.path
    
    from datasets import load_dataset
    from google.cloud.speech_v1p1beta1 import RecognitionConfig
    from joblib import Parallel, delayed
    
    import json
    import uuid
    from pathlib import Path
    from typing import Any
    from google.protobuf.json_format import MessageToJson
    
    from google.cloud.storage import Blob
    import os
    from google.cloud import speech
    from google.cloud import storage
    from scipy.io.wavfile import write
    from experiment.luna.pipeline.dependency_provider import get_record_provider
    
    # Hugging Face token for the gated Common Voice dataset, read from the environment
    # rather than hard-coded: a literal token committed to source control is a leaked
    # credential (the previously committed value should be revoked).
    # NOTE(review): when HF_TOKEN is unset this is None; presumably load_dataset then
    # falls back to a locally cached `huggingface-cli login` token -- confirm.
    ACCESS_TOKEN = os.environ.get('HF_TOKEN')
    
    
    class GoogleCloudAsrProcessor:
        """Transcribe local audio files with Google Cloud Speech-to-Text.

        Each file is staged in the GCS bucket ``bucket_name`` under a unique name,
        recognized with ``long_running_recognize`` from its ``gs://`` URI, and the
        staged copy is deleted afterwards.
        """

        bucket_name: str
        storage_client: storage.Client
        speech_client: speech.SpeechClient
        language_code: str
        encoding: RecognitionConfig.AudioEncoding
        sample_rate_hertz: int

        def __init__(
            self,
            bucket_name: str,
            storage_client: storage.Client,
            speech_client: speech.SpeechClient,
            language_code: str,
            encoding: RecognitionConfig.AudioEncoding = RecognitionConfig.AudioEncoding.MP3,
            sample_rate_hertz: int = 48000,
        ):
            """Store the clients and recognition settings.

            Args:
                bucket_name: GCS bucket used to stage audio for recognition.
                storage_client: Client used to upload and delete the staged audio.
                speech_client: Client used to run long-running recognition.
                language_code: Recognition language, e.g. ``'pl-PL'``.
                encoding: Audio encoding sent in the recognition config
                    (defaults to MP3, the previously hard-coded value).
                sample_rate_hertz: Sample rate sent in the recognition config
                    (defaults to 48000, the previously hard-coded value).
            """
            self.speech_client = speech_client
            self.storage_client = storage_client
            self.bucket_name = bucket_name
            self.language_code = language_code
            self.encoding = encoding
            self.sample_rate_hertz = sample_rate_hertz

        def save_file_in_gcs(self, file_path: str) -> Blob:
            """Upload ``file_path`` to the bucket and return the created blob.

            The object name is ``<uuid4>__<file name>`` so uploads of files that share
            a name (e.g. from parallel workers) never overwrite each other.
            """
            bucket = self.storage_client.get_bucket(self.bucket_name)
            blob = bucket.blob(f'{uuid.uuid4()}__{Path(file_path).name}')
            blob.upload_from_filename(file_path)
            return blob

        def process_file(self, file_path: str) -> str:
            """Transcribe one local audio file and return the response as a JSON string.

            The file is uploaded to GCS, recognized via ``long_running_recognize``
            (blocking on the operation result), and serialized with ``MessageToJson``.
            The staged blob is deleted in a ``finally`` so a failed recognition does
            not leave the audio behind in the bucket; the original error propagates.
            """
            blob = self.save_file_in_gcs(file_path)
            try:
                audio = speech.RecognitionAudio(uri=f'gs://{self.bucket_name}/{blob.name}')
                # NOTE(review): the encoding enum is imported from speech_v1p1beta1 while the
                # config and client come from google.cloud.speech -- mixed API versions; confirm
                # the v1 endpoint accepts this encoding.
                config = speech.RecognitionConfig(
                    enable_automatic_punctuation=True,
                    language_code=self.language_code,
                    use_enhanced=True,
                    enable_word_time_offsets=True,
                    encoding=self.encoding,
                    sample_rate_hertz=self.sample_rate_hertz,
                )
                operation = self.speech_client.long_running_recognize(
                    config=config,
                    audio=audio,
                )
                response = operation.result()
            finally:
                blob.delete()
            # MessageToJson needs the raw protobuf message underlying the proto-plus wrapper.
            return MessageToJson(response._pb)
    
    
    def process_single(record: Any, result_dir=None):
        """Transcribe one dataset record with Google Cloud ASR and save the JSON result.

        The result goes to ``<result_dir>/<audio file name>.json``. Records whose result
        file already exists are skipped, so an interrupted run can be resumed.

        Args:
            record: A dataset row. Only ``record['path']``, the path of a local audio
                file, is used (the whole record is also printed).
            result_dir: Directory for the JSON results. When ``None``, the module-level
                ``root_result_path`` is used, as before (it is assigned in the
                ``__main__`` block; a NameError is raised if it is missing).
        """
        print(record)
        # setdefault: honour credentials already provided in the environment instead of
        # unconditionally overwriting them with this machine-specific key path.
        os.environ.setdefault(
            'GOOGLE_APPLICATION_CREDENTIALS',
            '/home/marcinwatroba/PWR_ASR/asr-benchmarks/secret_key.json',
        )
        if result_dir is None:
            result_dir = root_result_path
        # Datasets that provide decoded audio instead of a file on disk would need the
        # array written to a temporary file first (e.g. with scipy.io.wavfile.write).
        record_path = record['path']
        result_path = os.path.join(result_dir, f'{Path(record_path).name}.json')
        if os.path.exists(result_path):
            return
        # Clients are only built when there is actual work, not for cached records.
        processor = GoogleCloudAsrProcessor(
            'asr-benchmarks-data',
            storage.Client(),
            speech.SpeechClient(),
            'pl-PL',
        )
        result = processor.process_file(record_path)
        with open(result_path, 'w', encoding='utf-8') as f:
            f.write(result)
        print(result)
    
    
    if __name__ == '__main__':
        # setdefault: keep credentials the user already exported; otherwise fall back to
        # the author's key file.
        os.environ.setdefault(
            'GOOGLE_APPLICATION_CREDENTIALS',
            '/home/marcinwatroba/PWR_ASR/asr-benchmarks/secret_key.json',
        )
        # Kept as a module-level name on purpose: process_single reads root_result_path.
        root_result_path = 'common_voice_9_0_google_asr'
        # exist_ok avoids the check-then-create race of exists() + mkdir().
        os.makedirs(root_result_path, exist_ok=True)
        dataset = load_dataset("mozilla-foundation/common_voice_9_0", "pl", use_auth_token=ACCESS_TOKEN)['test']
        # Alternative benchmark set: load_dataset("google/fleurs", "pl_pl")
        print(dataset)
        Parallel(n_jobs=32)(delayed(process_single)(record) for record in dataset)