From 95002a468e758efe1c20451ce1b767fda15bf3da Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Marcin=20W=C4=85troba?= <markowanga@gmail.com>
Date: Sat, 14 Jan 2023 00:54:51 +0100
Subject: [PATCH] Change facebook wav2vec2 model

---
 new_experiment/queue_base.py      | 58 ++++++++++++++++++++++++++++
 new_experiment/worker_pipeline.py | 64 +------------------------------
 2 files changed, 60 insertions(+), 62 deletions(-)
 create mode 100644 new_experiment/queue_base.py

diff --git a/new_experiment/queue_base.py b/new_experiment/queue_base.py
new file mode 100644
index 0000000..6307f56
--- /dev/null
+++ b/new_experiment/queue_base.py
@@ -0,0 +1,58 @@
+import functools
+import threading
+from typing import Callable
+
+import pika
+from pika.adapters.blocking_connection import BlockingChannel
+
+from new_experiment.utils.param_util import get_param
+
+_RABBIT_URL = get_param('RABBIT_URL',
+                        'amqps://rabbit_user:kz6m4972OUHFmtUcPOHx4kF3Lj6yw7lo@rabbit-asr-benchmarks.theliver.pl:5671/')
+
+
+def ack_message(channel, delivery_tag):
+    """Note that `channel` must be the same pika channel instance via which
+    the message being ACKed was retrieved (AMQP protocol constraint).
+    """
+    if channel.is_open:
+        channel.basic_ack(delivery_tag)
+    else:
+        # Channel is already closed, so we can't ACK this message;
+        # log and/or do something that makes sense for your app in this case.
+        pass
+
+
+def process_queue(queue_name: str, process_message: Callable[[bytes], None]):
+    def do_work(connection, channel, delivery_tag, body):
+        process_message(body)
+        cb = functools.partial(ack_message, channel, delivery_tag)
+        connection.add_callback_threadsafe(cb)
+        print('\n#########################\n')
+
+    def on_message(channel: BlockingChannel, method_frame, header_frame, body, args):
+        (connection, threads) = args
+        delivery_tag = method_frame.delivery_tag
+        t = threading.Thread(target=do_work, args=(connection, channel, delivery_tag, body))
+        t.start()
+        threads.append(t)
+
+    parameters = pika.URLParameters(_RABBIT_URL)
+    connection = pika.BlockingConnection(parameters)
+    channel = connection.channel()
+    channel.basic_qos(prefetch_count=1)
+
+    threads = []
+    on_message_callback = functools.partial(on_message, args=(connection, threads))
+    channel.basic_consume(queue_name, on_message_callback)
+
+    try:
+        channel.start_consuming()
+    except KeyboardInterrupt:
+        channel.stop_consuming()
+
+    # Wait for all to complete
+    for thread in threads:
+        thread.join()
+
+    connection.close()
diff --git a/new_experiment/worker_pipeline.py b/new_experiment/worker_pipeline.py
index 9acbe7c..27dbc61 100644
--- a/new_experiment/worker_pipeline.py
+++ b/new_experiment/worker_pipeline.py
@@ -1,22 +1,11 @@
 import json
-import os
-
-import functools
-import logging
-import pika
-import threading
-
-from pika.adapters.blocking_connection import BlockingChannel
 
 from new_experiment.pipeline.pipeline_process_spacy_dep_tag_wer import run_spacy_dep_tag_wer_pipeline
 from new_experiment.pipeline.pipeline_process_spacy_ner_wer import run_spacy_ner_wer_pipeline
 from new_experiment.pipeline.pipeline_process_spacy_pos_wer import run_spacy_pos_wer_pipeline
 from new_experiment.pipeline.pipeline_process_word_classic_wer import run_word_wer_classic_pipeline
 from new_experiment.pipeline.pipeline_process_word_embedding_wer import run_word_wer_embedding_pipeline
-from new_experiment.utils.param_util import get_param
-
-_RABBIT_URL = get_param('RABBIT_URL',
-                        'amqps://rabbit_user:kz6m4972OUHFmtUcPOHx4kF3Lj6yw7lo@rabbit-asr-benchmarks.theliver.pl:5671/')
+from new_experiment.queue_base import process_queue
 
 
 def process_message(body: bytes):
@@ -41,54 +30,5 @@ def process_message(body: bytes):
         raise Exception(f"Bad message {message_dict}")
 
 
-def ack_message(channel, delivery_tag):
-    """Note that `channel` must be the same pika channel instance via which
-    the message being ACKed was retrieved (AMQP protocol constraint).
-    """
-    if channel.is_open:
-        channel.basic_ack(delivery_tag)
-    else:
-        # Channel is already closed, so we can't ACK this message;
-        # log and/or do something that makes sense for your app in this case.
-        pass
-
-
-def do_work(connection, channel, delivery_tag, body):
-    process_message(body)
-    cb = functools.partial(ack_message, channel, delivery_tag)
-    connection.add_callback_threadsafe(cb)
-    print('\n#########################\n')
-
-
-def on_message(channel: BlockingChannel, method_frame, header_frame, body, args):
-    (connection, threads) = args
-    delivery_tag = method_frame.delivery_tag
-    t = threading.Thread(target=do_work, args=(connection, channel, delivery_tag, body))
-    t.start()
-    threads.append(t)
-
-
-def new_main():
-    parameters = pika.URLParameters(_RABBIT_URL)
-    connection = pika.BlockingConnection(parameters)
-    channel = connection.channel()
-    channel.basic_qos(prefetch_count=1)
-
-    threads = []
-    on_message_callback = functools.partial(on_message, args=(connection, threads))
-    channel.basic_consume('hf_facebook_wav2vec2_asr', on_message_callback)
-
-    try:
-        channel.start_consuming()
-    except KeyboardInterrupt:
-        channel.stop_consuming()
-
-    # Wait for all to complete
-    for thread in threads:
-        thread.join()
-
-    connection.close()
-
-
 if __name__ == '__main__':
-    new_main()
+    process_queue('asr_benchmark_experiments', process_message)
-- 
GitLab