From c96708a501e902703e681ab3cf49105548023f3c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Marcin=20W=C4=85troba?= <markowanga@gmail.com>
Date: Thu, 12 Jan 2023 03:26:06 +0100
Subject: [PATCH] Add queue producer for pipeline worker

---
 new_experiment/add_to_queue_pipeline.py       | 74 +++++++++++++++++++
 .../pipeline_process_spacy_dep_tag_wer.py     |  2 +-
 .../pipeline_process_spacy_ner_wer.py         |  2 +-
 new_experiment/worker.py                      |  8 +-------
 4 files changed, 77 insertions(+), 9 deletions(-)
 create mode 100644 new_experiment/add_to_queue_pipeline.py

diff --git a/new_experiment/add_to_queue_pipeline.py b/new_experiment/add_to_queue_pipeline.py
new file mode 100644
index 0000000..2053526
--- /dev/null
+++ b/new_experiment/add_to_queue_pipeline.py
@@ -0,0 +1,74 @@
+import datetime
+import json
+from typing import List
+
+import pika
+from minio import Minio
+from pika.adapters.blocking_connection import BlockingChannel
+
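+# Task names understood by new_experiment/worker.py, plus the language,
+# Whisper model and dataset axes that span the experiment grid.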
+COMMANDS = ['run_word_wer_classic_pipeline', 'run_word_wer_embedding_pipeline', 'run_spacy_dep_tag_wer_pipeline',
+            'run_spacy_ner_wer_pipeline', 'run_spacy_pos_wer_pipeline']
+LANGUAGES = ['nl', 'fr', 'de', 'it', 'pl', 'es', 'en']
+WHISPER_ASR_MODEL = ['tiny', 'base', 'small', 'medium', 'large-v2']
+DATASETS = ['google_fleurs', 'minds14', 'voxpopuli']
+
+
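+# MinIO lists "directories" with a trailing slash and items as
+# '<dataset>/<item_id>.<ext>'; the helpers below strip both down to bare names.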
+def get_all_datasets() -> List[str]:
+    return [obj.object_name[:-1] for obj in get_minio_client().list_objects('dataset-audio', '')]
+
+
+def get_dataset_items_id(dataset_name: str) -> List[str]:
+    return [obj.object_name.split('/')[-1].split('.')[0] for obj in
+            get_minio_client().list_objects('dataset-audio', f'{dataset_name}/')]
+
+
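+# Client for the MinIO instance holding the benchmark audio datasets.
+# NOTE: the access credentials are hardcoded below.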
+def get_minio_client() -> Minio:
+    return Minio('minio-asr-benchmarks.theliver.pl', access_key='minio_user', secret_key='eUxzEQbyYPdzrLxuvvethSbk18kB2s7G')
+
+
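+# The queue is declared durable and messages are published as persistent,
+# so queued experiments survive a broker restart.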
+def add_to_queue(dataset: str, asr_name: str, task: str, channel: BlockingChannel):
+    message_dict = {'dataset': dataset, 'asr_name': asr_name, 'task': task}
+    queue_name = 'asr_benchmark_experiments'
+    print(datetime.datetime.now().isoformat(), f'{queue_name}   {message_dict}')
+    message_bytes = json.dumps(message_dict).encode('utf-8')
+    channel.queue_declare(queue=queue_name, durable=True)
+    channel.basic_publish(exchange='', routing_key=queue_name, body=message_bytes, properties=pika.BasicProperties(
+        delivery_mode=pika.spec.PERSISTENT_DELIVERY_MODE
+    ))
+
+
+def get_all_datasets_with_language() -> List[str]:
+    return [f'{language}_{dataset}' for language in LANGUAGES for dataset in DATASETS]
+
+
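+# Enqueue the full cartesian product: every Whisper variant for every
+# language/dataset combination and every pipeline command.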
+def add_whisper(channel: BlockingChannel):
+    for whisper_variant in WHISPER_ASR_MODEL:
+        asr_name = f'whisper_{whisper_variant}'
+        for dataset_name in get_all_datasets_with_language():
+            for command in COMMANDS:
+                add_to_queue(dataset_name, asr_name, command, channel)
+
+
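+# Opens an AMQPS connection to RabbitMQ (credentials hardcoded in the URL),
+# enqueues all Whisper jobs and closes the connection.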
+def main():
+    parameters = pika.URLParameters(
+        'amqps://rabbit_user:kz6m4972OUHFmtUcPOHx4kF3Lj6yw7lo@rabbit-asr-benchmarks.theliver.pl:5671/')
+    connection = pika.BlockingConnection(parameters=parameters)
+    channel = connection.channel()
+    add_whisper(channel)
+    connection.close()
+
+
+if __name__ == '__main__':
+    main()
diff --git a/new_experiment/pipeline/pipeline_process_spacy_dep_tag_wer.py b/new_experiment/pipeline/pipeline_process_spacy_dep_tag_wer.py
index 03420db..8e2d568 100644
--- a/new_experiment/pipeline/pipeline_process_spacy_dep_tag_wer.py
+++ b/new_experiment/pipeline/pipeline_process_spacy_dep_tag_wer.py
@@ -25,7 +25,7 @@ def run_spacy_dep_tag_wer_pipeline(dataset_name: str, asr_name: str):
                 alignment_property_name=PropertyHelper.dep_tag_alignment(asr_name, model_name),
                 wer_property_name=PropertyHelper.dep_tag_metrics(asr_name, model_name),
                 task_name=f'SpacyDepTagSentenceWerProcessor___{dataset_name}___{asr_name}',
-                require_update=False
+                require_update=True  # presumably re-runs items that already have stored results
             )
         ],
         experiment_repository=repository,
diff --git a/new_experiment/pipeline/pipeline_process_spacy_ner_wer.py b/new_experiment/pipeline/pipeline_process_spacy_ner_wer.py
index 96cc9ef..c74e419 100644
--- a/new_experiment/pipeline/pipeline_process_spacy_ner_wer.py
+++ b/new_experiment/pipeline/pipeline_process_spacy_ner_wer.py
@@ -25,7 +25,7 @@ def run_spacy_ner_wer_pipeline(dataset_name: str, asr_name: str):
                 alignment_property_name=PropertyHelper.ner_alignment(asr_name, model_name),
                 wer_property_name=PropertyHelper.ner_metrics(asr_name, model_name),
                 task_name=f'SpacyNerSentenceWerProcessor___{dataset_name}___{asr_name}',
-                require_update=False
+                require_update=True  # presumably re-runs items that already have stored results
             )
         ],
         experiment_repository=repository,
diff --git a/new_experiment/worker.py b/new_experiment/worker.py
index 93a861b..7422ea9 100644
--- a/new_experiment/worker.py
+++ b/new_experiment/worker.py
@@ -1,19 +1,13 @@
 import json
 import os
-import uuid
 
 import pika
-from minio import Minio
-from pymongo import MongoClient
-from urllib3 import HTTPResponse
 
-from new_datasets.whisper_processor import WhisperAsrProcessor
 from new_experiment.pipeline.pipeline_process_spacy_dep_tag_wer import run_spacy_dep_tag_wer_pipeline
 from new_experiment.pipeline.pipeline_process_spacy_ner_wer import run_spacy_ner_wer_pipeline
 from new_experiment.pipeline.pipeline_process_spacy_pos_wer import run_spacy_pos_wer_pipeline
 from new_experiment.pipeline.pipeline_process_word_classic_wer import run_word_wer_classic_pipeline
 from new_experiment.pipeline.pipeline_process_word_embedding_wer import run_word_wer_embedding_pipeline
-from sziszapangma.integration.repository.mongo_experiment_repository import MongoExperimentRepository
 
 
 def get_param(name: str, default: str) -> str:
@@ -39,7 +33,7 @@ def main():
         task = message_dict['task']
         dataset = message_dict['dataset']
         asr_name = message_dict['asr_name']
-
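+        # dispatch the queued task name to the matching pipeline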
         if task == 'run_word_wer_classic_pipeline':
             run_word_wer_classic_pipeline(dataset, asr_name)
         elif task == 'run_word_wer_embedding_pipeline':
-- 
GitLab