From 38826e3e90e59284a2d1cbeb1fd8733a6b3b4567 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Marcin=20W=C4=85troba?= <markowanga@gmail.com>
Date: Thu, 12 Jan 2023 02:50:02 +0100
Subject: [PATCH] Add worker for pipeline

---
 Dockerfile                                    | 17 +++++++--------
 .../pipeline_process_spacy_dep_tag_wer.py     |  4 ++--
 .../pipeline_process_word_embedding_wer.py    |  4 ++--
 new_experiment/worker.py                      | 21 +++++++++++++++++++
 4 files changed, 33 insertions(+), 13 deletions(-)

diff --git a/Dockerfile b/Dockerfile
index 84d3df4..795b6c1 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -6,15 +6,6 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y tzdata
 RUN add-apt-repository ppa:deadsnakes/ppa && apt-get update && apt-get install -y python3.8 python3-pip ffmpeg
 RUN alias python='python3' && alias pip='pip3' && pip install poetry
 
-RUN pip install spacy
-RUN python -m spacy download de_core_news_lg
-RUN python -m spacy download pl_core_news_lg
-RUN python -m spacy download en_core_news_lg
-RUN python -m spacy download it_core_news_lg
-RUN python -m spacy download nl_core_news_lg
-RUN python -m spacy download sp_core_news_lg
-RUN python -m spacy download pt_core_news_lg
-
 ADD poetry.lock ./
 ADD pyproject.toml ./
 ADD README.rst ./
@@ -27,3 +18,11 @@ RUN ls -l
 RUN poetry config virtualenvs.create false --local
 
 RUN poetry install
+
+RUN poetry run python -m spacy download de_core_news_lg
+RUN poetry run python -m spacy download pl_core_news_lg
+RUN poetry run python -m spacy download en_core_web_lg
+RUN poetry run python -m spacy download it_core_news_lg
+RUN poetry run python -m spacy download nl_core_news_lg
+RUN poetry run python -m spacy download es_core_news_lg
+RUN poetry run python -m spacy download pt_core_news_lg
diff --git a/new_experiment/pipeline/pipeline_process_spacy_dep_tag_wer.py b/new_experiment/pipeline/pipeline_process_spacy_dep_tag_wer.py
index f0727d1..03420db 100644
--- a/new_experiment/pipeline/pipeline_process_spacy_dep_tag_wer.py
+++ b/new_experiment/pipeline/pipeline_process_spacy_dep_tag_wer.py
@@ -10,7 +10,7 @@ from new_experiment.utils.property_helper import PropertyHelper
 from sziszapangma.integration.experiment_manager import ExperimentManager
 
 
-def run_spacy_pos_wer_pipeline(dataset_name: str, asr_name: str):
+def run_spacy_dep_tag_wer_pipeline(dataset_name: str, asr_name: str):
     repository = get_experiment_repository(dataset_name)
     record_provider = LoadedRemoteDatasetHelper(repository, get_minio_audio_record_repository(), dataset_name)
     language_code = dataset_name[:2]
@@ -34,4 +34,4 @@ def run_spacy_pos_wer_pipeline(dataset_name: str, asr_name: str):
 
 
 if __name__ == '__main__':
-    run_spacy_pos_wer_pipeline('de_minds14', 'whisper_tiny')
+    run_spacy_dep_tag_wer_pipeline('de_minds14', 'whisper_tiny')
diff --git a/new_experiment/pipeline/pipeline_process_word_embedding_wer.py b/new_experiment/pipeline/pipeline_process_word_embedding_wer.py
index dbdecb6..8f942f0 100644
--- a/new_experiment/pipeline/pipeline_process_word_embedding_wer.py
+++ b/new_experiment/pipeline/pipeline_process_word_embedding_wer.py
@@ -12,7 +12,7 @@ from sziszapangma.integration.task.classic_wer_metric_task import ClassicWerMetr
 from sziszapangma.integration.task.embedding_wer_metrics_task import EmbeddingWerMetricsTask
 
 
-def run_word_wer_classic_pipeline(dataset_name: str, asr_name: str):
+def run_word_wer_embedding_pipeline(dataset_name: str, asr_name: str):
     repository = get_experiment_repository(dataset_name)
     experiment_processor = ExperimentManager(
         record_id_iterator=LoadedRemoteDatasetHelper(repository, get_minio_audio_record_repository(), dataset_name),
@@ -33,4 +33,4 @@ def run_word_wer_classic_pipeline(dataset_name: str, asr_name: str):
 
 
 if __name__ == '__main__':
-    run_word_wer_classic_pipeline('de_minds14', 'whisper_tiny')
+    run_word_wer_embedding_pipeline('de_minds14', 'whisper_tiny')
diff --git a/new_experiment/worker.py b/new_experiment/worker.py
index 3a86742..93a861b 100644
--- a/new_experiment/worker.py
+++ b/new_experiment/worker.py
@@ -8,6 +8,11 @@ from pymongo import MongoClient
 from urllib3 import HTTPResponse
 
 from new_datasets.whisper_processor import WhisperAsrProcessor
+from new_experiment.pipeline.pipeline_process_spacy_dep_tag_wer import run_spacy_dep_tag_wer_pipeline
+from new_experiment.pipeline.pipeline_process_spacy_ner_wer import run_spacy_ner_wer_pipeline
+from new_experiment.pipeline.pipeline_process_spacy_pos_wer import run_spacy_pos_wer_pipeline
+from new_experiment.pipeline.pipeline_process_word_classic_wer import run_word_wer_classic_pipeline
+from new_experiment.pipeline.pipeline_process_word_embedding_wer import run_word_wer_embedding_pipeline
 from sziszapangma.integration.repository.mongo_experiment_repository import MongoExperimentRepository
 
 
@@ -17,6 +22,8 @@ def get_param(name: str, default: str) -> str:
 
 _RABBIT_URL = get_param('RABBIT_URL',
                         'amqps://rabbit_user:kz6m4972OUHFmtUcPOHx4kF3Lj6yw7lo@rabbit-asr-benchmarks.theliver.pl:5671/')
+
+
 def main():
     parameters = pika.URLParameters(_RABBIT_URL)
     connection = pika.BlockingConnection(parameters=parameters)
@@ -30,7 +37,21 @@ def main():
         print(message_dict)
 
         task = message_dict['task']
+        dataset = message_dict['dataset']
+        asr_name = message_dict['asr_name']
 
+        if task == 'run_word_wer_classic_pipeline':
+            run_word_wer_classic_pipeline(dataset, asr_name)
+        elif task == 'run_word_wer_embedding_pipeline':
+            run_word_wer_embedding_pipeline(dataset, asr_name)
+        elif task == 'run_spacy_dep_tag_wer_pipeline':
+            run_spacy_dep_tag_wer_pipeline(dataset, asr_name)
+        elif task == 'run_spacy_ner_wer_pipeline':
+            run_spacy_ner_wer_pipeline(dataset, asr_name)
+        elif task == 'run_spacy_pos_wer_pipeline':
+            run_spacy_pos_wer_pipeline(dataset, asr_name)
+        else:
+            raise Exception(f"Bad message {message_dict}")
 
         channel.basic_ack(method_frame.delivery_tag)
         print('\n########################################################\n')
-- 
GitLab