From 51ae81aab9ddcfa59a7929b195875c7de6484bba Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Marcin=20W=C4=85troba?= <markowanga@gmail.com>
Date: Tue, 29 Jun 2021 15:42:36 +0200
Subject: [PATCH] Fix integration

---
 sziszapangma/integration/mapper/__init__.py   |  0
 .../mapper/alignment_step_mapper.py           | 16 +++++
 .../integration/mapper/step_words_mapper.py   | 27 +++++++++
 .../integration/mapper/word_mapper.py         | 22 +++++++
 .../task/classic_wer_metric_task.py           | 60 ++++++++++++-------
 .../task/embedding_wer_metrics_task.py        | 60 ++++++++++++-------
 6 files changed, 145 insertions(+), 40 deletions(-)
 create mode 100644 sziszapangma/integration/mapper/__init__.py
 create mode 100644 sziszapangma/integration/mapper/alignment_step_mapper.py
 create mode 100644 sziszapangma/integration/mapper/step_words_mapper.py
 create mode 100644 sziszapangma/integration/mapper/word_mapper.py

diff --git a/sziszapangma/integration/mapper/__init__.py b/sziszapangma/integration/mapper/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/sziszapangma/integration/mapper/alignment_step_mapper.py b/sziszapangma/integration/mapper/alignment_step_mapper.py
new file mode 100644
index 0000000..8b3bf9b
--- /dev/null
+++ b/sziszapangma/integration/mapper/alignment_step_mapper.py
@@ -0,0 +1,16 @@
+from typing import Dict
+
+from sziszapangma.core.alignment.alignment_step import AlignmentStep
+from sziszapangma.integration.mapper.step_words_mapper import StepWordsMapper
+
+
+class AlignmentStepMapper:
+
+    @staticmethod
+    def to_json_dict(alignment_step: AlignmentStep) -> Dict[str, any]:
+        return {
+            'step_type': alignment_step.step_type.name,
+            'step_words': StepWordsMapper.to_json_dict(
+                alignment_step.step_words),
+            'step_cost': alignment_step.step_cost
+        }
diff --git a/sziszapangma/integration/mapper/step_words_mapper.py b/sziszapangma/integration/mapper/step_words_mapper.py
new file mode 100644
index 0000000..a28b532
--- /dev/null
+++ b/sziszapangma/integration/mapper/step_words_mapper.py
@@ -0,0 +1,27 @@
+from typing import Dict
+
+from sziszapangma.core.alignment.step_words import StepWords
+from sziszapangma.integration.mapper.word_mapper import WordMapper
+
+
+class StepWordsMapper:
+
+    @staticmethod
+    def to_json_dict(step_words: StepWords) -> Dict[str, any]:
+        to_return = dict()
+        if step_words.hypothesis_word is not None:
+            to_return['hypothesis_word'] = WordMapper.to_json_dict(
+                step_words.hypothesis_word)
+        if step_words.reference_word is not None:
+            to_return['reference_word'] = WordMapper.to_json_dict(
+                step_words.reference_word)
+        return to_return
+
+    @staticmethod
+    def from_json_dict(input_json_dict: Dict[str, any]) -> StepWords:
+        return StepWords(
+            None if 'reference_word' not in input_json_dict
+            else WordMapper.from_json_dict(input_json_dict['reference_word']),
+            None if 'hypothesis_word' not in input_json_dict
+            else WordMapper.from_json_dict(input_json_dict['hypothesis_word']),
+        )
diff --git a/sziszapangma/integration/mapper/word_mapper.py b/sziszapangma/integration/mapper/word_mapper.py
new file mode 100644
index 0000000..04c30b1
--- /dev/null
+++ b/sziszapangma/integration/mapper/word_mapper.py
@@ -0,0 +1,22 @@
+from typing import Dict
+
+from sziszapangma.core.alignment.alignment_step import AlignmentStep
+from sziszapangma.core.alignment.step_words import StepWords
+from sziszapangma.core.alignment.word import Word
+
+_ID = 'id'
+_VALUE = 'value'
+
+
+class WordMapper:
+
+    @staticmethod
+    def to_json_dict(word: Word) -> Dict[str, str]:
+        return {
+            _ID: word.id,
+            _VALUE: word.value
+        }
+
+    @staticmethod
+    def from_json_dict(input_json_dict: Dict[str, str]) -> Word:
+        return Word(input_json_dict[_ID], input_json_dict[_VALUE])
diff --git a/sziszapangma/integration/task/classic_wer_metric_task.py b/sziszapangma/integration/task/classic_wer_metric_task.py
index 1f4ba70..1b6c17d 100644
--- a/sziszapangma/integration/task/classic_wer_metric_task.py
+++ b/sziszapangma/integration/task/classic_wer_metric_task.py
@@ -2,17 +2,22 @@ from typing import List, Dict
 
 from sziszapangma.core.alignment.alignment_classic_calculator import \
     AlignmentClassicCalculator
+from sziszapangma.core.alignment.alignment_step import AlignmentStep
+from sziszapangma.core.alignment.word import Word
 from sziszapangma.core.wer.wer_calculator import WerCalculator
+from sziszapangma.integration.mapper.alignment_step_mapper import \
+    AlignmentStepMapper
+from sziszapangma.integration.mapper.word_mapper import WordMapper
 from sziszapangma.integration.repository.experiment_repository import \
     ExperimentRepository
 from sziszapangma.integration.task.processing_task import ProcessingTask
 
 _CLASSIC_WER = 'classic_wer'
-_WORD = 'word'
 
 
 class ClassicWerMetricTask(ProcessingTask):
     _metrics_property_name: str
+    _alignment_property_name: str
     _gold_transcript_property_name: str
     _alignment_classic_calculator: AlignmentClassicCalculator
     _wer_calculator: WerCalculator
@@ -23,17 +28,22 @@ class ClassicWerMetricTask(ProcessingTask):
         gold_transcript_property_name: str,
         asr_property_name: str,
         metrics_property_name: str,
+        alignment_property_name: str,
         require_update: bool
     ):
         super().__init__(task_name, require_update)
         self._gold_transcript_property_name = gold_transcript_property_name
         self._asr_property_name = asr_property_name
+        self._alignment_property_name = alignment_property_name
         self._metrics_property_name = metrics_property_name
         self._alignment_classic_calculator = AlignmentClassicCalculator()
         self._wer_calculator = WerCalculator()
 
-    def skip_for_record(self, record_id: str,
-                        experiment_repository: ExperimentRepository) -> bool:
+    def skip_for_record(
+        self,
+        record_id: str,
+        experiment_repository: ExperimentRepository
+    ) -> bool:
         return experiment_repository \
             .get_property_for_key(record_id, self._metrics_property_name)
 
@@ -45,34 +55,44 @@ class ClassicWerMetricTask(ProcessingTask):
         asr_result = experiment_repository \
             .get_property_for_key(record_id, self._asr_property_name)
         if 'transcription' in asr_result:
+            alignment_steps = self._get_alignment(
+                gold_transcript, asr_result['transcription']
+            )
+            experiment_repository.update_property_for_key(
+                record_id,
+                self._alignment_property_name,
+                [AlignmentStepMapper.to_json_dict(it)
+                 for it in alignment_steps]
+            )
             experiment_repository.update_property_for_key(
                 record_id,
                 self._metrics_property_name,
-                self.calculate_metrics(
-                    gold_transcript=gold_transcript,
-                    asr_result=asr_result['transcription']
-                )
+                self.calculate_metrics(alignment_steps)
             )
 
-    def _run_wer_calculations(
+    def _get_alignment(
         self,
         gold_transcript: List[Dict[str, any]],
-        asr_result: List[str]
-    ) -> float:
-        return self._wer_calculator.calculate_wer(
-            self._alignment_classic_calculator.calculate_alignment(
-                reference=[it[_WORD] for it in gold_transcript],
-                hypothesis=[it for it in asr_result],
-            )
-        )
+        asr_result: List[Dict[str, any]]
+    ) -> List[AlignmentStep]:
+        gold_transcript_words = [
+            WordMapper.from_json_dict(word_dict)
+            for word_dict in gold_transcript
+        ]
+        asr_words = [
+            WordMapper.from_json_dict(word_dict)
+            for word_dict in asr_result
+        ]
+        return self._alignment_classic_calculator \
+            .calculate_alignment(reference=gold_transcript_words,
+                                 hypothesis=asr_words)
 
     def calculate_metrics(
         self,
-        gold_transcript: List[Dict[str, any]],
-        asr_result: List[str]
+        alignment_steps: List[AlignmentStep]
     ) -> Dict[str, any]:
         """Calculate all metrics for data sample."""
         metrics = dict()
-        metrics[_CLASSIC_WER] = self._run_wer_calculations(
-            gold_transcript, asr_result)
+        metrics[_CLASSIC_WER] = self._wer_calculator.calculate_wer(
+            alignment_steps)
         return metrics
diff --git a/sziszapangma/integration/task/embedding_wer_metrics_task.py b/sziszapangma/integration/task/embedding_wer_metrics_task.py
index 3145fbf..3eb3476 100644
--- a/sziszapangma/integration/task/embedding_wer_metrics_task.py
+++ b/sziszapangma/integration/task/embedding_wer_metrics_task.py
@@ -4,11 +4,15 @@ from sziszapangma.core.alignment.alignment_embedding_calculator import \
     AlignmentEmbeddingCalculator
 from sziszapangma.core.alignment.alignment_soft_calculator import \
     AlignmentSoftCalculator
+from sziszapangma.core.alignment.word import Word
 from sziszapangma.core.transformer.cached_embedding_transformer import \
     CachedEmbeddingTransformer
 from sziszapangma.core.transformer.embedding_transformer import \
     EmbeddingTransformer
 from sziszapangma.core.wer.wer_calculator import WerCalculator
+from sziszapangma.integration.mapper.alignment_step_mapper import \
+    AlignmentStepMapper
+from sziszapangma.integration.mapper.word_mapper import WordMapper
 from sziszapangma.integration.repository.experiment_repository import \
     ExperimentRepository
 from sziszapangma.integration.task.processing_task import ProcessingTask
@@ -20,6 +24,7 @@ _WORD = 'word'
 
 class EmbeddingWerMetricsTask(ProcessingTask):
     _metrics_property_name: str
+    _alignment_property_name: str
     _gold_transcript_property_name: str
     _embedding_transformer: CachedEmbeddingTransformer
     _alignment_embedding_calculator: AlignmentEmbeddingCalculator
@@ -32,6 +37,7 @@ class EmbeddingWerMetricsTask(ProcessingTask):
         gold_transcript_property_name: str,
         asr_property_name: str,
         metrics_property_name: str,
+        alignment_property_name: str,
         require_update: bool,
         embedding_transformer: EmbeddingTransformer
     ):
@@ -46,6 +52,7 @@ class EmbeddingWerMetricsTask(ProcessingTask):
         self._alignment_soft_calculator = \
             AlignmentSoftCalculator(self._embedding_transformer)
         self._wer_calculator = WerCalculator()
+        self._alignment_property_name = alignment_property_name
 
     def skip_for_record(self, record_id: str,
                         experiment_repository: ExperimentRepository) -> bool:
@@ -60,26 +67,39 @@ class EmbeddingWerMetricsTask(ProcessingTask):
         asr_result = experiment_repository \
             .get_property_for_key(record_id, self._asr_property_name)
         if 'transcription' in asr_result:
+            gold_transcript_words = self._map_words_to_domain(gold_transcript)
+            asr_words = self._map_words_to_domain(asr_result['transcription'])
+
+            soft_alignment = self._alignment_soft_calculator \
+                .calculate_alignment(gold_transcript_words, asr_words)
+            embedding_alignment = self._alignment_embedding_calculator \
+                .calculate_alignment(gold_transcript_words, asr_words)
+
+            soft_wer = self._wer_calculator.calculate_wer(soft_alignment)
+            embedding_wer = self._wer_calculator \
+                .calculate_wer(embedding_alignment)
+
+            alignment_results = {
+                'soft_alignment': [AlignmentStepMapper.to_json_dict(it)
+                                   for it in soft_alignment],
+                'embedding_alignment': [AlignmentStepMapper.to_json_dict(it)
+                                        for it in embedding_alignment],
+            }
+            wer_results = {'soft_wer': soft_wer,
+                           'embedding_wer': embedding_wer}
+
             experiment_repository.update_property_for_key(
-                record_id,
-                self._metrics_property_name,
-                self.calculate_metrics(
-                    gold_transcript=gold_transcript,
-                    asr_result=asr_result['transcription']
-                )
-            )
+                record_id, self._alignment_property_name, alignment_results)
+            experiment_repository.update_property_for_key(
+                record_id, self._metrics_property_name, wer_results)
+
         self._embedding_transformer.clear()
 
-    def calculate_metrics(
-        self,
-        gold_transcript: List[Dict[str, any]],
-        asr_result: List[str]
-    ) -> Dict[str, any]:
-        """Calculate all metrics for data sample."""
-        metrics = dict()
-        reference = [it[_WORD] for it in gold_transcript]
-        metrics[_SOFT_WER] = self._alignment_soft_calculator\
-            .calculate_alignment(reference, asr_result)[0]
-        metrics[_EMBEDDING_WER] = self._alignment_embedding_calculator\
-            .calculate_wer(reference, asr_result)[0]
-        return metrics
+    @staticmethod
+    def _map_words_to_domain(
+        input_json_dicts: List[Dict[str, str]]
+    ) -> List[Word]:
+        return [
+            WordMapper.from_json_dict(word_dict)
+            for word_dict in input_json_dicts
+        ]
-- 
GitLab