Skip to content
Snippets Groups Projects
Commit 9e937495 authored by Marcin Wątroba's avatar Marcin Wątroba
Browse files

Add new worker

parent f867305c
Branches
No related merge requests found
......@@ -26,9 +26,8 @@ def get_minio_client() -> Minio:
return Minio('minio-asr-benchmarks.theliver.pl', 'minio_user', 'eUxzEQbyYPdzrLxuvvethSbk18kB2s7G')
def add_to_queue(dataset: str, asr_name: str, task: str, channel: BlockingChannel):
def add_to_queue(dataset: str, asr_name: str, task: str, channel: BlockingChannel, queue_name: str):
message_dict = {'dataset': dataset, 'asr_name': asr_name, 'task': task}
queue_name = 'asr_benchmark_experiments'
print(datetime.datetime.now().isoformat(), f'{queue_name} {message_dict}')
message_bytes = json.dumps(message_dict).encode('utf-8')
channel.queue_declare(queue=queue_name, durable=True)
......@@ -53,12 +52,36 @@ def add_whisper(channel: BlockingChannel):
add_to_queue(dataset_name, asr_name, command, channel)
# Maps ISO 639-1 language codes to the benchmark identifier of the matching
# HuggingFace Facebook wav2vec2 model. Hoisted to module level so the dict
# is built once instead of on every call.
_HF_FACEBOOK_WAV2VEC2_MODELS = {
    'nl': 'facebook_wav2vec2_large_xlsr_53_dutch',
    'en': 'facebook_wav2vec2_xls_r_300m',
    'fr': 'facebook_wav2vec2_large_xlsr_53_french',
    'de': 'facebook_wav2vec2_large_xlsr_53_german',
    'it': 'facebook_wav2vec2_large_xlsr_53_italian',
    'pl': 'facebook_wav2vec2_large_xlsr_53_polish',
    'es': 'facebook_wav2vec2_large_xlsr_53_spanish',
}


def get_hf_facebook_wav2vec2_model_by_language_code(language_code: str) -> str:
    """Return the Facebook wav2vec2 model name for *language_code*.

    Args:
        language_code: two-letter language code (e.g. 'pl', 'en').

    Returns:
        The benchmark model identifier string.

    Raises:
        KeyError: if no model is configured for the language code
            (same contract as the original inline-dict lookup).
    """
    return _HF_FACEBOOK_WAV2VEC2_MODELS[language_code]
def add_facebook_hf_wav2vec2_asr(channel: BlockingChannel):
    """Publish one wav2vec2 ASR task per dataset.

    The model is chosen from the dataset name's two-letter language prefix;
    both the task name and the target queue are 'hf_facebook_wav2vec2_asr'.
    """
    task_and_queue = 'hf_facebook_wav2vec2_asr'
    for dataset in get_all_datasets_with_language():
        model_name = get_hf_facebook_wav2vec2_model_by_language_code(dataset[:2])
        add_to_queue(dataset, model_name, task_and_queue, channel, task_and_queue)
def main():
    """Connect to the RabbitMQ broker and publish wav2vec2 ASR tasks.

    NOTE(review): the AMQP URL embeds live credentials in source control —
    consider moving it to an environment variable and rotating the secret.
    """
    amqp_url = 'amqps://rabbit_user:kz6m4972OUHFmtUcPOHx4kF3Lj6yw7lo@rabbit-asr-benchmarks.theliver.pl:5671/'
    connection = pika.BlockingConnection(parameters=pika.URLParameters(amqp_url))
    channel = connection.channel()
    # Whisper publishing is currently disabled; only wav2vec2 tasks go out.
    # add_whisper(channel)
    add_facebook_hf_wav2vec2_asr(channel)
    connection.close()
......
......@@ -21,9 +21,7 @@ class Wav2Vec2AsrProcessor(AsrProcessor):
self._wav2vec2_processor = Wav2Vec2Processor.from_pretrained(model_name)
def call_recognise(self, file_path: str) -> Dict[str, Any]:
# samplerate, data = wavfile.read(file_path)
# data, samplerate = soundfile.read(file_path)
data, samplerate = librosa.load(file_path)
data, samplerate = librosa.load(file_path, sr=16000)
features = self._wav2vec2_processor(data, sampling_rate=samplerate, padding=True, return_tensors="pt")
input_values = features.input_values.to(self._device)
attention_mask = features.attention_mask.to(self._device)
......
import json
import os
import functools
import logging
import pika
import threading
import time
from pika.adapters.blocking_connection import BlockingChannel
from new_experiment.pipeline.pipeline_process_asr import run_hf_facebook_wav2vec2_asr_task
from new_experiment.pipeline.pipeline_process_spacy_dep_tag_wer import run_spacy_dep_tag_wer_pipeline
from new_experiment.pipeline.pipeline_process_spacy_ner_wer import run_spacy_ner_wer_pipeline
from new_experiment.pipeline.pipeline_process_spacy_pos_wer import run_spacy_pos_wer_pipeline
from new_experiment.pipeline.pipeline_process_word_classic_wer import run_word_wer_classic_pipeline
from new_experiment.pipeline.pipeline_process_word_embedding_wer import run_word_wer_embedding_pipeline
# LOG_FORMAT = ('%(levelname) -10s %(asctime)s %(name) -30s %(funcName) '
# '-35s %(lineno) -5d: %(message)s')
# LOGGER = logging.getLogger(__name__)
# logging.basicConfig(level=logging.DEBUG, format=LOG_FORMAT)
def get_param(name: str, default: str) -> str:
    """Return the environment variable *name*, or *default* when unset.

    Uses a single ``os.environ.get`` lookup instead of the original
    membership-test-then-index pattern (one lookup instead of two, and no
    race window if the environment is mutated between the two operations).
    """
    return os.environ.get(name, default)
# Broker URL for the worker; override with the RABBIT_URL env variable.
# NOTE(review): the fallback embeds live credentials in source control —
# rotate the secret and prefer keeping only the env-var path.
_RABBIT_URL = get_param('RABBIT_URL',
                        'amqps://rabbit_user:kz6m4972OUHFmtUcPOHx4kF3Lj6yw7lo@rabbit-asr-benchmarks.theliver.pl:5671/')
# Dispatch table: message 'task' field -> pipeline entry point. Replaces the
# original five-branch if/elif chain; unknown tasks still raise below.
_WER_TASK_HANDLERS = {
    'run_word_wer_classic_pipeline': run_word_wer_classic_pipeline,
    'run_word_wer_embedding_pipeline': run_word_wer_embedding_pipeline,
    'run_spacy_dep_tag_wer_pipeline': run_spacy_dep_tag_wer_pipeline,
    'run_spacy_ner_wer_pipeline': run_spacy_ner_wer_pipeline,
    'run_spacy_pos_wer_pipeline': run_spacy_pos_wer_pipeline,
}


def main():
    """Consume WER-pipeline tasks from 'asr_benchmark_experiments' and run them.

    Each message body is a JSON object with 'task', 'dataset' and 'asr_name'
    keys. A message is ACKed only after its pipeline finishes, so a crashed
    worker leaves the message to be redelivered.
    """
    parameters = pika.URLParameters(_RABBIT_URL)
    # Disable heartbeats: pipeline runs can outlast any reasonable heartbeat
    # interval, and the broker would otherwise drop this blocking connection
    # mid-task. (_heartbeat is a private pika attribute — kept as-is to
    # preserve the existing behaviour.)
    parameters._heartbeat = 0
    connection = pika.BlockingConnection(parameters=parameters)
    channel = connection.channel()
    channel.basic_qos(prefetch_count=1)  # at most one unacked task per worker
    # The original used an f-string with no placeholders; plain literal.
    queue_name = 'asr_benchmark_experiments'
    for method_frame, properties, body in channel.consume(queue_name):
        print(method_frame, properties, body)
        message_dict = json.loads(body.decode('utf-8'))
        print(message_dict)
        task = message_dict['task']
        dataset = message_dict['dataset']
        asr_name = message_dict['asr_name']
        handler = _WER_TASK_HANDLERS.get(task)
        if handler is None:
            raise Exception(f"Bad message {message_dict}")
        handler(dataset, asr_name)
        channel.basic_ack(method_frame.delivery_tag)
        print('\n########################################################\n')
    # Reached only if the consume generator is exhausted/cancelled.
    requeued_messages = channel.cancel()
    print('Requeued %i messages' % requeued_messages)
    connection.close()
def ack_message(channel, delivery_tag):
    """ACK *delivery_tag* on *channel* if the channel is still open.

    AMQP requires the ACK to be sent on the same channel the message was
    delivered on; once that channel is closed the ACK can no longer be
    issued and the message will be redelivered by the broker.
    """
    if not channel.is_open:
        # Channel already closed — nothing we can do here; log and/or
        # handle in whatever way makes sense for the application.
        return
    channel.basic_ack(delivery_tag)
def do_work(connection, channel, delivery_tag, body):
    """Run the ASR task encoded in *body*, then schedule a thread-safe ACK.

    Executed on a worker thread spawned by on_message; the ACK itself is
    handed back to the connection's I/O thread via add_callback_threadsafe,
    as pika channels are not thread-safe.
    """
    thread_id = threading.get_ident()  # kept available for debug logging
    print(body)
    payload = json.loads(body.decode('utf-8'))
    print(payload)
    task = payload['task']
    dataset = payload['dataset']
    asr_name = payload['asr_name']
    if task == 'hf_facebook_wav2vec2_asr':
        run_hf_facebook_wav2vec2_asr_task(dataset, asr_name)
    ack_cb = functools.partial(ack_message, channel, delivery_tag)
    connection.add_callback_threadsafe(ack_cb)
    print('\n#########################\n')
def on_message(channel: BlockingChannel, method_frame, header_frame, body, args):
    """Hand each delivery off to a worker thread so the I/O loop stays live.

    *args* is the (connection, threads) pair bound with functools.partial in
    new_main; every spawned thread is recorded so new_main can join them.
    """
    connection, threads = args
    worker = threading.Thread(
        target=do_work,
        args=(connection, channel, method_frame.delivery_tag, body),
    )
    worker.start()
    threads.append(worker)
def new_main():
    """Consume 'hf_facebook_wav2vec2_asr' with one worker thread per message."""
    connection = pika.BlockingConnection(pika.URLParameters(_RABBIT_URL))
    channel = connection.channel()
    # A single unacked delivery at a time keeps the number of concurrently
    # running worker threads bounded.
    channel.basic_qos(prefetch_count=1)
    worker_threads = []
    callback = functools.partial(on_message, args=(connection, worker_threads))
    channel.basic_consume('hf_facebook_wav2vec2_asr', callback)
    try:
        channel.start_consuming()
    except KeyboardInterrupt:
        channel.stop_consuming()
    # Let every in-flight task finish before tearing the connection down.
    for worker in worker_threads:
        worker.join()
    connection.close()
if __name__ == '__main__':
new_main()
......@@ -10,10 +10,22 @@ from sziszapangma.integration.task.asr_task import AsrTask
# asr_name (benchmark identifier) -> HuggingFace hub model id.
# Table lookup replaces the original seven-branch if-chain.
_WAV2VEC2_MODEL_IDS = {
    'facebook_wav2vec2_large_xlsr_53_dutch': 'facebook/wav2vec2-large-xlsr-53-dutch',
    'facebook_wav2vec2_xls_r_300m': 'facebook/wav2vec2-xls-r-300m',
    'facebook_wav2vec2_large_xlsr_53_french': 'facebook/wav2vec2-large-xlsr-53-french',
    'facebook_wav2vec2_large_xlsr_53_german': 'facebook/wav2vec2-large-xlsr-53-german',
    'facebook_wav2vec2_large_xlsr_53_italian': 'facebook/wav2vec2-large-xlsr-53-italian',
    'facebook_wav2vec2_large_xlsr_53_polish': 'facebook/wav2vec2-large-xlsr-53-polish',
    'facebook_wav2vec2_large_xlsr_53_spanish': 'facebook/wav2vec2-large-xlsr-53-spanish',
}


def get_asr_processor(asr_name: str) -> AsrProcessor:
    """Build the AsrProcessor for the given benchmark ASR name.

    Args:
        asr_name: benchmark identifier of a supported wav2vec2 model.

    Returns:
        A Wav2Vec2AsrProcessor wrapping the corresponding HF model.

    Raises:
        Exception: when *asr_name* has no configured model (same message
            and exception type as the original if-chain).
    """
    model_id = _WAV2VEC2_MODEL_IDS.get(asr_name)
    if model_id is None:
        raise Exception(f'AsrProcessor not found for name: {asr_name}')
    return Wav2Vec2AsrProcessor(model_id)
def run_spacy_dep_tag_wer_pipeline(dataset_name: str, asr_name: str):
def run_hf_facebook_wav2vec2_asr_task(dataset_name: str, asr_name: str):
repository = get_experiment_repository(dataset_name)
record_provider = LoadedRemoteDatasetHelper(repository, get_minio_audio_record_repository(), dataset_name)
experiment_processor = ExperimentManager(
......@@ -21,8 +33,8 @@ def run_spacy_dep_tag_wer_pipeline(dataset_name: str, asr_name: str):
processing_tasks=[
AsrTask(
asr_property_name=PropertyHelper.asr_result(asr_name),
task_name=f'SpacyDepTagSentenceWerProcessor___{dataset_name}___{asr_name}',
require_update=False,
task_name=f'AsrTask___{dataset_name}___{asr_name}',
require_update=True,
asr_processor=get_asr_processor(asr_name),
record_path_provider=record_provider
)
......@@ -31,8 +43,7 @@ def run_spacy_dep_tag_wer_pipeline(dataset_name: str, asr_name: str):
)
experiment_processor.process()
if __name__ == '__main__':
run_spacy_dep_tag_wer_pipeline('nl_minds14', 'facebook_wav2vec2_large_xlsr_53_dutch')
run_spacy_dep_tag_wer_pipeline('nl_google_fleurs', 'facebook_wav2vec2_large_xlsr_53_dutch')
run_spacy_dep_tag_wer_pipeline('nl_voxpopuli', 'facebook_wav2vec2_large_xlsr_53_dutch')
# if __name__ == '__main__':
# run_spacy_dep_tag_wer_pipeline('nl_minds14', 'facebook_wav2vec2_large_xlsr_53_dutch')
# run_spacy_dep_tag_wer_pipeline('nl_google_fleurs', 'facebook_wav2vec2_large_xlsr_53_dutch')
# run_spacy_dep_tag_wer_pipeline('nl_voxpopuli', 'facebook_wav2vec2_large_xlsr_53_dutch')
......@@ -60,71 +60,5 @@ def main():
connection.close()
def new_main():
    """Demo consumer: per-message worker threads with thread-safe ACKs.

    Connects to a local RabbitMQ as guest, declares a test exchange/queue
    pair, and consumes from it. Each delivery is processed on its own
    thread while the main thread keeps the pika I/O loop (and heartbeats)
    running; all workers are joined before the connection is closed.
    """
    LOG_FORMAT = ('%(levelname) -10s %(asctime)s %(name) -30s %(funcName) '
                  '-35s %(lineno) -5d: %(message)s')
    LOGGER = logging.getLogger(__name__)
    logging.basicConfig(level=logging.DEBUG, format=LOG_FORMAT)

    def ack_message(channel, delivery_tag):
        """Note that `channel` must be the same pika channel instance via which
        the message being ACKed was retrieved (AMQP protocol constraint).
        """
        if channel.is_open:
            channel.basic_ack(delivery_tag)
        else:
            # Channel is already closed, so we can't ACK this message;
            # log and/or do something that makes sense for your app in this case.
            pass

    # Runs on a worker thread; the ACK is marshalled back to the
    # connection's I/O thread because pika objects are not thread-safe.
    def do_work(connection, channel, delivery_tag, body):
        thread_id = threading.get_ident()
        fmt1 = 'Thread id: {} Delivery tag: {} Message body: {}'
        LOGGER.info(fmt1.format(thread_id, delivery_tag, body))
        # Sleeping to simulate 10 seconds of work
        time.sleep(10)
        cb = functools.partial(ack_message, channel, delivery_tag)
        connection.add_callback_threadsafe(cb)

    # Delivery callback: spawn a worker thread per message and record it
    # so the tail of new_main can join every thread before closing.
    def on_message(channel, method_frame, header_frame, body, args):
        (connection, threads) = args
        delivery_tag = method_frame.delivery_tag
        t = threading.Thread(target=do_work, args=(connection, channel, delivery_tag, body))
        t.start()
        threads.append(t)

    credentials = pika.PlainCredentials('guest', 'guest')
    # Note: sending a short heartbeat to prove that heartbeats are still
    # sent even though the worker simulates long-running work
    parameters = pika.ConnectionParameters('localhost', credentials=credentials, heartbeat=5)
    connection = pika.BlockingConnection(parameters)
    channel = connection.channel()
    channel.exchange_declare(exchange="test_exchange", exchange_type="direct", passive=False, durable=True,
                             auto_delete=False)
    channel.queue_declare(queue="standard", auto_delete=True)
    channel.queue_bind(queue="standard", exchange="test_exchange", routing_key="standard_key")
    # Note: prefetch is set to 1 here as an example only and to keep the number of threads created
    # to a reasonable amount. In production you will want to test with different prefetch values
    # to find which one provides the best performance and usability for your solution
    channel.basic_qos(prefetch_count=1)
    threads = []
    on_message_callback = functools.partial(on_message, args=(connection, threads))
    channel.basic_consume('standard', on_message_callback)
    try:
        channel.start_consuming()
    except KeyboardInterrupt:
        channel.stop_consuming()
    # Wait for all to complete
    for thread in threads:
        thread.join()
    connection.close()
if __name__ == '__main__':
main()
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment