diff --git a/sziszapangma/integration/mapper/__init__.py b/sziszapangma/integration/mapper/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/sziszapangma/integration/mapper/alignment_step_mapper.py b/sziszapangma/integration/mapper/alignment_step_mapper.py new file mode 100644 index 0000000000000000000000000000000000000000..8b3bf9b32aaef3b1be84cc5c036208e3b31d4bc4 --- /dev/null +++ b/sziszapangma/integration/mapper/alignment_step_mapper.py @@ -0,0 +1,16 @@ +from typing import Dict + +from sziszapangma.core.alignment.alignment_step import AlignmentStep +from sziszapangma.integration.mapper.step_words_mapper import StepWordsMapper + + +class AlignmentStepMapper: + + @staticmethod + def to_json_dict(alignment_step: AlignmentStep) -> Dict[str, any]: + return { + 'step_type': alignment_step.step_type.name, + 'step_words': StepWordsMapper.to_json_dict( + alignment_step.step_words), + 'step_cost': alignment_step.step_cost + } diff --git a/sziszapangma/integration/mapper/step_words_mapper.py b/sziszapangma/integration/mapper/step_words_mapper.py new file mode 100644 index 0000000000000000000000000000000000000000..a28b532411317d7510ef92723eb0274583a18430 --- /dev/null +++ b/sziszapangma/integration/mapper/step_words_mapper.py @@ -0,0 +1,27 @@ +from typing import Dict + +from sziszapangma.core.alignment.step_words import StepWords +from sziszapangma.integration.mapper.word_mapper import WordMapper + + +class StepWordsMapper: + + @staticmethod + def to_json_dict(step_words: StepWords) -> Dict[str, any]: + to_return = dict() + if step_words.hypothesis_word is not None: + to_return['hypothesis_word'] = WordMapper.to_json_dict( + step_words.hypothesis_word) + if step_words.reference_word is not None: + to_return['reference_word'] = WordMapper.to_json_dict( + step_words.reference_word) + return to_return + + @staticmethod + def from_json_dict(input_json_dict: Dict[str, any]) -> StepWords: + return StepWords( + None if 'reference_word' not in input_json_dict + else WordMapper.from_json_dict(input_json_dict['reference_word']), + None if 'hypothesis_word' not in input_json_dict + else WordMapper.from_json_dict(input_json_dict['hypothesis_word']), + ) diff --git a/sziszapangma/integration/mapper/word_mapper.py b/sziszapangma/integration/mapper/word_mapper.py new file mode 100644 index 0000000000000000000000000000000000000000..04c30b10a4a024c583e766e216d5a5b8dcb5746f --- /dev/null +++ b/sziszapangma/integration/mapper/word_mapper.py @@ -0,0 +1,22 @@ +from typing import Dict + +from sziszapangma.core.alignment.alignment_step import AlignmentStep +from sziszapangma.core.alignment.step_words import StepWords +from sziszapangma.core.alignment.word import Word + +_ID = 'id' +_VALUE = 'value' + + +class WordMapper: + + @staticmethod + def to_json_dict(word: Word) -> Dict[str, str]: + return { + _ID: word.id, + _VALUE: word.value + } + + @staticmethod + def from_json_dict(input_json_dict: Dict[str, str]) -> Word: + return Word(input_json_dict[_ID], input_json_dict[_VALUE]) diff --git a/sziszapangma/integration/task/classic_wer_metric_task.py b/sziszapangma/integration/task/classic_wer_metric_task.py index 1f4ba703fe56f77d14b11763e36549db75fd57cf..1b6c17d6e9af34533cacc0debda8e54bb599041d 100644 --- a/sziszapangma/integration/task/classic_wer_metric_task.py +++ b/sziszapangma/integration/task/classic_wer_metric_task.py @@ -2,17 +2,22 @@ from typing import List, Dict from sziszapangma.core.alignment.alignment_classic_calculator import \ AlignmentClassicCalculator +from sziszapangma.core.alignment.alignment_step import AlignmentStep +from sziszapangma.core.alignment.word import Word from sziszapangma.core.wer.wer_calculator import WerCalculator +from sziszapangma.integration.mapper.alignment_step_mapper import \ + AlignmentStepMapper +from sziszapangma.integration.mapper.word_mapper import WordMapper from sziszapangma.integration.repository.experiment_repository import \ ExperimentRepository from sziszapangma.integration.task.processing_task import ProcessingTask _CLASSIC_WER = 'classic_wer' -_WORD = 'word' class ClassicWerMetricTask(ProcessingTask): _metrics_property_name: str + _alignment_property_name: str _gold_transcript_property_name: str _alignment_classic_calculator: AlignmentClassicCalculator _wer_calculator: WerCalculator @@ -23,17 +28,22 @@ class ClassicWerMetricTask(ProcessingTask): gold_transcript_property_name: str, asr_property_name: str, metrics_property_name: str, + alignment_property_name: str, require_update: bool ): super().__init__(task_name, require_update) self._gold_transcript_property_name = gold_transcript_property_name self._asr_property_name = asr_property_name + self._alignment_property_name = alignment_property_name self._metrics_property_name = metrics_property_name self._alignment_classic_calculator = AlignmentClassicCalculator() self._wer_calculator = WerCalculator() - def skip_for_record(self, record_id: str, - experiment_repository: ExperimentRepository) -> bool: + def skip_for_record( + self, + record_id: str, + experiment_repository: ExperimentRepository + ) -> bool: return experiment_repository \ .get_property_for_key(record_id, self._metrics_property_name) @@ -45,34 +55,44 @@ class ClassicWerMetricTask(ProcessingTask): asr_result = experiment_repository \ .get_property_for_key(record_id, self._asr_property_name) if 'transcription' in asr_result: + alignment_steps = self._get_alignment( + gold_transcript, asr_result['transcription'] + ) + experiment_repository.update_property_for_key( + record_id, + self._alignment_property_name, + [AlignmentStepMapper.to_json_dict(it) + for it in alignment_steps] + ) experiment_repository.update_property_for_key( record_id, self._metrics_property_name, - self.calculate_metrics( - gold_transcript=gold_transcript, - asr_result=asr_result['transcription'] - ) + self.calculate_metrics(alignment_steps) ) - def _run_wer_calculations( + def _get_alignment( self, gold_transcript: List[Dict[str, any]], - asr_result: List[str] - ) -> float: - return self._wer_calculator.calculate_wer( - self._alignment_classic_calculator.calculate_alignment( - reference=[it[_WORD] for it in gold_transcript], - hypothesis=[it for it in asr_result], - ) - ) + asr_result: List[Dict[str, any]] + ) -> List[AlignmentStep]: + gold_transcript_words = [ + WordMapper.from_json_dict(word_dict) + for word_dict in gold_transcript + ] + asr_words = [ + WordMapper.from_json_dict(word_dict) + for word_dict in asr_result + ] + return self._alignment_classic_calculator \ + .calculate_alignment(reference=gold_transcript_words, + hypothesis=asr_words) def calculate_metrics( self, - gold_transcript: List[Dict[str, any]], - asr_result: List[str] + alignment_steps: List[AlignmentStep] ) -> Dict[str, any]: """Calculate all metrics for data sample.""" metrics = dict() - metrics[_CLASSIC_WER] = self._run_wer_calculations( - gold_transcript, asr_result) + metrics[_CLASSIC_WER] = self._wer_calculator.calculate_wer( + alignment_steps) return metrics diff --git a/sziszapangma/integration/task/embedding_wer_metrics_task.py b/sziszapangma/integration/task/embedding_wer_metrics_task.py index 3145fbfdf177e1db528b47c3916cb0b2e6b66c09..3eb3476171ee99318a434c70ba06627300d645ac 100644 --- a/sziszapangma/integration/task/embedding_wer_metrics_task.py +++ b/sziszapangma/integration/task/embedding_wer_metrics_task.py @@ -4,11 +4,15 @@ from sziszapangma.core.alignment.alignment_embedding_calculator import \ AlignmentEmbeddingCalculator from sziszapangma.core.alignment.alignment_soft_calculator import \ AlignmentSoftCalculator +from sziszapangma.core.alignment.word import Word from sziszapangma.core.transformer.cached_embedding_transformer import \ CachedEmbeddingTransformer from sziszapangma.core.transformer.embedding_transformer import \ EmbeddingTransformer from sziszapangma.core.wer.wer_calculator import WerCalculator +from sziszapangma.integration.mapper.alignment_step_mapper import \ + AlignmentStepMapper +from sziszapangma.integration.mapper.word_mapper import WordMapper from sziszapangma.integration.repository.experiment_repository import \ ExperimentRepository from sziszapangma.integration.task.processing_task import ProcessingTask @@ -20,6 +24,7 @@ _WORD = 'word' class EmbeddingWerMetricsTask(ProcessingTask): _metrics_property_name: str + _alignment_property_name: str _gold_transcript_property_name: str _embedding_transformer: CachedEmbeddingTransformer _alignment_embedding_calculator: AlignmentEmbeddingCalculator @@ -32,6 +37,7 @@ class EmbeddingWerMetricsTask(ProcessingTask): gold_transcript_property_name: str, asr_property_name: str, metrics_property_name: str, + alignment_property_name: str, require_update: bool, embedding_transformer: EmbeddingTransformer ): @@ -46,6 +52,7 @@ class EmbeddingWerMetricsTask(ProcessingTask): self._alignment_soft_calculator = \ AlignmentSoftCalculator(self._embedding_transformer) self._wer_calculator = WerCalculator() + self._alignment_property_name = alignment_property_name def skip_for_record(self, record_id: str, experiment_repository: ExperimentRepository) -> bool: @@ -60,26 +67,39 @@ class EmbeddingWerMetricsTask(ProcessingTask): asr_result = experiment_repository \ .get_property_for_key(record_id, self._asr_property_name) if 'transcription' in asr_result: + gold_transcript_words = self._map_words_to_domain(gold_transcript) + asr_words = self._map_words_to_domain(asr_result['transcription']) + + soft_alignment = self._alignment_soft_calculator \ + .calculate_alignment(gold_transcript_words, asr_words) + embedding_alignment = self._alignment_embedding_calculator \ + .calculate_alignment(gold_transcript_words, asr_words) + + soft_wer = self._wer_calculator.calculate_wer(soft_alignment) + embedding_wer = self._wer_calculator \ + .calculate_wer(embedding_alignment) + + alignment_results = { + 'soft_alignment': [AlignmentStepMapper.to_json_dict(it) + for it in soft_alignment], + 'embedding_alignment': [AlignmentStepMapper.to_json_dict(it) + for it in embedding_alignment], + } + wer_results = {'soft_wer': soft_wer, + 'embedding_wer': embedding_wer} + experiment_repository.update_property_for_key( - record_id, - self._metrics_property_name, - self.calculate_metrics( - gold_transcript=gold_transcript, - asr_result=asr_result['transcription'] - ) - ) + record_id, self._alignment_property_name, alignment_results) + experiment_repository.update_property_for_key( + record_id, self._metrics_property_name, wer_results) + self._embedding_transformer.clear() - def calculate_metrics( - self, - gold_transcript: List[Dict[str, any]], - asr_result: List[str] - ) -> Dict[str, any]: - """Calculate all metrics for data sample.""" - metrics = dict() - reference = [it[_WORD] for it in gold_transcript] - metrics[_SOFT_WER] = self._alignment_soft_calculator\ - .calculate_alignment(reference, asr_result)[0] - metrics[_EMBEDDING_WER] = self._alignment_embedding_calculator\ - .calculate_wer(reference, asr_result)[0] - return metrics + @staticmethod + def _map_words_to_domain( + input_json_dicts: List[Dict[str, str]] + ) -> List[Word]: + return [ + WordMapper.from_json_dict(word_dict) + for word_dict in input_json_dicts + ]