Skip to content
Snippets Groups Projects
Commit 46acd375 authored by Marcin Wątroba's avatar Marcin Wątroba
Browse files

Change facebook wav2vec2 model

parent 01293fe6
Branches
No related merge requests found
......@@ -54,26 +54,32 @@ def add_whisper(channel: BlockingChannel):
def get_hf_facebook_wav2vec2_model_by_language_code(language_code: str) -> str:
    """Map a two-letter language code to its Facebook wav2vec2 model name.

    Raises:
        KeyError: if no model is configured for *language_code*.
    """
    models = {
        'nl': 'facebook_wav2vec2_large_xlsr_53_dutch',
        'en': 'facebook_wav2vec2_large_960h_lv60_self',
        'fr': 'facebook_wav2vec2_large_xlsr_53_french',
        'de': 'facebook_wav2vec2_large_xlsr_53_german',
        'it': 'facebook_wav2vec2_large_xlsr_53_italian',
        'pl': 'facebook_wav2vec2_large_xlsr_53_polish',
        'es': 'facebook_wav2vec2_large_xlsr_53_spanish',
    }
    return models[language_code]
def add_facebook_hf_wav2vec2_asr(channel: BlockingChannel):
    """Queue a wav2vec2 ASR task for every dataset on the wav2vec2 queue.

    Fix: the diff residue kept both the old English-only branch and the new
    unconditional call, duplicating the enqueue for 'en' datasets; only the
    unconditional form is kept.
    """
    for dataset_name in get_all_datasets_with_language():
        # Dataset names are prefixed with their two-letter language code.
        add_to_queue(
            dataset_name,
            get_hf_facebook_wav2vec2_model_by_language_code(dataset_name[:2]),
            'hf_facebook_wav2vec2_asr',
            channel,
            'hf_facebook_wav2vec2_asr'
        )
def add_facebook_hf_wav2vec2_pipeline(channel: BlockingChannel):
    """Queue every pipeline command for every dataset on the experiments queue."""
    for name in get_all_datasets_with_language():
        # The model is fixed per language (first two characters of the name).
        model = get_hf_facebook_wav2vec2_model_by_language_code(name[:2])
        for cmd in COMMANDS:
            add_to_queue(name, model, cmd, channel, 'asr_benchmark_experiments')
def main():
......@@ -82,7 +88,8 @@ def main():
connection = pika.BlockingConnection(parameters=parameters)
channel = connection.channel()
# add_whisper(channel)
add_facebook_hf_wav2vec2_asr(channel)
# add_facebook_hf_wav2vec2_asr(channel)
add_facebook_hf_wav2vec2_pipeline(channel)
connection.close()
......
import os
def get_param(name: str, default: str) -> str:
    """Return the environment variable *name*, or *default* if it is unset."""
    # os.environ.get does the membership test and lookup in one idiomatic call.
    return os.environ.get(name, default)
import functools
import json
import os
import functools
import logging
import pika
import threading
import time
import pika
from pika.adapters.blocking_connection import BlockingChannel
from new_experiment.pipeline.pipeline_process_asr import run_hf_facebook_wav2vec2_asr_task
......@@ -15,61 +12,12 @@ from new_experiment.pipeline.pipeline_process_spacy_ner_wer import run_spacy_ner
from new_experiment.pipeline.pipeline_process_spacy_pos_wer import run_spacy_pos_wer_pipeline
from new_experiment.pipeline.pipeline_process_word_classic_wer import run_word_wer_classic_pipeline
from new_experiment.pipeline.pipeline_process_word_embedding_wer import run_word_wer_embedding_pipeline
# LOG_FORMAT = ('%(levelname) -10s %(asctime)s %(name) -30s %(funcName) '
# '-35s %(lineno) -5d: %(message)s')
# LOGGER = logging.getLogger(__name__)
# logging.basicConfig(level=logging.DEBUG, format=LOG_FORMAT)
def get_param(name: str, default: str) -> str:
    """Return the environment variable *name*, or *default* if it is unset."""
    # os.environ.get does the membership test and lookup in one idiomatic call.
    return os.environ.get(name, default)
from new_experiment.utils.param_util import get_param
# Broker connection string, overridable via the RABBIT_URL environment variable.
# NOTE(review): credentials are hard-coded in source control — move the default
# out of the repository and rotate the password.
_RABBIT_URL = get_param('RABBIT_URL',
                        'amqps://rabbit_user:kz6m4972OUHFmtUcPOHx4kF3Lj6yw7lo@rabbit-asr-benchmarks.theliver.pl:5671/')
def main():
    """Consume the experiments queue and run each requested pipeline in turn."""
    parameters = pika.URLParameters(_RABBIT_URL)
    # Heartbeats disabled: a single task may run longer than any timeout.
    parameters._heartbeat = 0
    connection = pika.BlockingConnection(parameters=parameters)
    channel = connection.channel()
    channel.basic_qos(prefetch_count=1)
    # Each known task name maps to the pipeline entry point that runs it.
    handlers = {
        'run_word_wer_classic_pipeline': run_word_wer_classic_pipeline,
        'run_word_wer_embedding_pipeline': run_word_wer_embedding_pipeline,
        'run_spacy_dep_tag_wer_pipeline': run_spacy_dep_tag_wer_pipeline,
        'run_spacy_ner_wer_pipeline': run_spacy_ner_wer_pipeline,
        'run_spacy_pos_wer_pipeline': run_spacy_pos_wer_pipeline,
    }
    for method_frame, properties, body in channel.consume('asr_benchmark_experiments'):
        print(method_frame, properties, body)
        message_dict = json.loads(body.decode('utf-8'))
        print(message_dict)
        task = message_dict['task']
        dataset = message_dict['dataset']
        asr_name = message_dict['asr_name']
        handler = handlers.get(task)
        if handler is None:
            raise Exception(f"Bad message {message_dict}")
        handler(dataset, asr_name)
        # ACK only after the pipeline finished, so a crash requeues the task.
        channel.basic_ack(method_frame.delivery_tag)
        print('\n########################################################\n')
    requeued_messages = channel.cancel()
    print('Requeued %i messages' % requeued_messages)
    connection.close()
def ack_message(channel, delivery_tag):
"""Note that `channel` must be the same pika channel instance via which
the message being ACKed was retrieved (AMQP protocol constraint).
......
......@@ -5,58 +5,88 @@ import functools
import logging
import pika
import threading
import time
from pika.adapters.blocking_connection import BlockingChannel
from new_experiment.pipeline.pipeline_process_spacy_dep_tag_wer import run_spacy_dep_tag_wer_pipeline
from new_experiment.pipeline.pipeline_process_spacy_ner_wer import run_spacy_ner_wer_pipeline
from new_experiment.pipeline.pipeline_process_spacy_pos_wer import run_spacy_pos_wer_pipeline
from new_experiment.pipeline.pipeline_process_word_classic_wer import run_word_wer_classic_pipeline
from new_experiment.pipeline.pipeline_process_word_embedding_wer import run_word_wer_embedding_pipeline
from new_experiment.utils.param_util import get_param
# AMQP broker URL; the RABBIT_URL environment variable takes precedence.
# NOTE(review): a live password is committed here — remove it from VCS.
_RABBIT_URL = get_param('RABBIT_URL',
                        'amqps://rabbit_user:kz6m4972OUHFmtUcPOHx4kF3Lj6yw7lo@rabbit-asr-benchmarks.theliver.pl:5671/')
def get_param(name: str, default: str) -> str:
    """Return the environment variable *name*, or *default* if it is unset."""
    # os.environ.get does the membership test and lookup in one idiomatic call.
    return os.environ.get(name, default)
def process_message(body: bytes):
    """Decode one queue message and run the pipeline it names.

    The body is a JSON object with 'task', 'dataset' and 'asr_name' keys.

    Fix: removed a stray `_RABBIT_URL = get_param(...)` reassignment that the
    diff residue left in the middle of this function — it is a module-level
    constant and had no effect here.

    Raises:
        Exception: if 'task' does not name a known pipeline.
    """
    print(body)
    message_dict = json.loads(body.decode('utf-8'))
    print(message_dict)
    task = message_dict['task']
    dataset = message_dict['dataset']
    asr_name = message_dict['asr_name']
    if task == 'run_word_wer_classic_pipeline':
        run_word_wer_classic_pipeline(dataset, asr_name)
    elif task == 'run_word_wer_embedding_pipeline':
        run_word_wer_embedding_pipeline(dataset, asr_name)
    elif task == 'run_spacy_dep_tag_wer_pipeline':
        run_spacy_dep_tag_wer_pipeline(dataset, asr_name)
    elif task == 'run_spacy_ner_wer_pipeline':
        run_spacy_ner_wer_pipeline(dataset, asr_name)
    elif task == 'run_spacy_pos_wer_pipeline':
        run_spacy_pos_wer_pipeline(dataset, asr_name)
    else:
        raise Exception(f"Bad message {message_dict}")
def ack_message(channel, delivery_tag):
    """ACK *delivery_tag* on *channel* if the channel is still usable.

    The channel must be the same pika channel instance via which the
    message being ACKed was retrieved (AMQP protocol constraint).
    """
    if not channel.is_open:
        # The channel is already closed, so this message cannot be ACKed;
        # deliberately a no-op — the broker will redeliver it.
        return
    channel.basic_ack(delivery_tag)
def do_work(connection, channel, delivery_tag, body):
    """Process one message on a worker thread, then schedule its ACK.

    Fix: a stray removed-diff `def main():` header preceded this function,
    leaving it orphaned; it is a module-level function.
    """
    process_message(body)
    # basic_ack must run on the connection's I/O thread, not this worker,
    # so hand it back via add_callback_threadsafe.
    cb = functools.partial(ack_message, channel, delivery_tag)
    connection.add_callback_threadsafe(cb)
    print('\n#########################\n')
def on_message(channel: BlockingChannel, method_frame, header_frame, body, args):
    """Hand an incoming message off to a fresh worker thread."""
    connection, threads = args
    worker = threading.Thread(
        target=do_work,
        args=(connection, channel, method_frame.delivery_tag, body),
    )
    worker.start()
    # Track the thread so the consumer can join it on shutdown.
    threads.append(worker)
def new_main():
    """Consume the wav2vec2 ASR queue, handling each message on a worker thread.

    Fix: the diff residue interleaved the old synchronous consume loop and a
    duplicated connection line into this function; only the post-change
    threaded-consumer body is kept.
    """
    parameters = pika.URLParameters(_RABBIT_URL)
    # Heartbeats disabled: one task can outlive any reasonable timeout.
    parameters._heartbeat = 0
    connection = pika.BlockingConnection(parameters)
    channel = connection.channel()
    # At most one unacked message per consumer at a time.
    channel.basic_qos(prefetch_count=1)
    threads = []
    on_message_callback = functools.partial(on_message, args=(connection, threads))
    channel.basic_consume('hf_facebook_wav2vec2_asr', on_message_callback)
    try:
        channel.start_consuming()
    except KeyboardInterrupt:
        channel.stop_consuming()
    # Wait for all in-flight workers to finish before closing the connection.
    for thread in threads:
        thread.join()
    connection.close()
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment