# Skip to content
# Snippets Groups Projects
# classic_wer_metric_task.py 3.62 KiB
from typing import Any, Dict, List

from sziszapangma.core.alignment.alignment_classic_calculator import \
    AlignmentClassicCalculator
from sziszapangma.core.alignment.alignment_step import AlignmentStep
from sziszapangma.core.wer.wer_calculator import WerCalculator
from sziszapangma.integration.mapper.alignment_step_mapper import \
    AlignmentStepMapper
from sziszapangma.integration.mapper.word_mapper import WordMapper
from sziszapangma.integration.repository.experiment_repository import \
    ExperimentRepository
from sziszapangma.integration.task.processing_task import ProcessingTask

# Key under which the classic WER value is stored in the metrics dict
# built by ClassicWerMetricTask.calculate_metrics.
_CLASSIC_WER = 'classic_wer'


class ClassicWerMetricTask(ProcessingTask):
    """Task computing the classic alignment and WER metric for a record.

    For every processed record the task reads the gold transcript and the
    ASR result from the experiment repository, aligns them with
    ``AlignmentClassicCalculator`` and writes back two properties: the
    alignment steps (as JSON dicts) and a metrics dict holding the value
    returned by ``WerCalculator.calculate_wer``.
    """

    # Repository property names used to read inputs and store outputs.
    _metrics_property_name: str
    _alignment_property_name: str
    _gold_transcript_property_name: str
    _asr_property_name: str
    # Stateless helpers created once per task instance.
    _alignment_classic_calculator: AlignmentClassicCalculator
    _wer_calculator: WerCalculator

    def __init__(
        self,
        task_name: str,
        gold_transcript_property_name: str,
        asr_property_name: str,
        metrics_property_name: str,
        alignment_property_name: str,
        require_update: bool
    ):
        """Create the task.

        :param task_name: name forwarded to ``ProcessingTask``.
        :param gold_transcript_property_name: property holding the gold
            (reference) transcript of a record.
        :param asr_property_name: property holding the ASR result of a
            record.
        :param metrics_property_name: property the computed metrics are
            written to.
        :param alignment_property_name: property the computed alignment
            is written to.
        :param require_update: flag forwarded to ``ProcessingTask``.
        """
        super().__init__(task_name, require_update)
        self._gold_transcript_property_name = gold_transcript_property_name
        self._asr_property_name = asr_property_name
        self._alignment_property_name = alignment_property_name
        self._metrics_property_name = metrics_property_name
        self._alignment_classic_calculator = AlignmentClassicCalculator()
        self._wer_calculator = WerCalculator()

    def skip_for_record(
        self,
        record_id: str,
        experiment_repository: ExperimentRepository
    ) -> bool:
        """Tell whether metrics are already stored for the record.

        :param record_id: identifier of the record to check.
        :param experiment_repository: repository queried for the metrics
            property.
        :return: ``True`` when the stored metrics property is truthy.
        """
        stored_metrics = experiment_repository.get_property_for_key(
            record_id, self._metrics_property_name)
        # Coerce to a real bool so the annotated return type holds; the
        # truthiness seen by callers is the same as for the raw value.
        return bool(stored_metrics)

    def run_single_process(self, record_id: str,
                           experiment_repository: ExperimentRepository
                           ) -> None:
        """Compute and store the alignment and metrics for one record.

        Nothing is written when the gold transcript or the ASR result is
        unavailable (``None``), or when the ASR result has no
        ``'transcription'`` entry.

        :param record_id: identifier of the record to process.
        :param experiment_repository: repository used to read the inputs
            and to store the results.
        """
        gold_transcript = experiment_repository.get_property_for_key(
            record_id, self._gold_transcript_property_name)
        asr_result = experiment_repository.get_property_for_key(
            record_id, self._asr_property_name)

        # NOTE(review): assumes the repository yields None for a missing
        # property; without this guard the membership test / iteration
        # below would raise TypeError on None.
        if gold_transcript is None or asr_result is None:
            return
        if 'transcription' not in asr_result:
            return

        alignment_steps = self._get_alignment(
            gold_transcript, asr_result['transcription']
        )
        experiment_repository.update_property_for_key(
            record_id,
            self._alignment_property_name,
            [AlignmentStepMapper.to_json_dict(it)
             for it in alignment_steps]
        )
        experiment_repository.update_property_for_key(
            record_id,
            self._metrics_property_name,
            self.calculate_metrics(alignment_steps)
        )

    def _get_alignment(
        self,
        gold_transcript: List[Dict[str, Any]],
        asr_result: List[Dict[str, Any]]
    ) -> List[AlignmentStep]:
        """Align the gold transcript (reference) with the ASR words.

        Both inputs are lists of word dicts which are converted with
        ``WordMapper.from_json_dict`` before being aligned.

        :param gold_transcript: word dicts of the reference transcript.
        :param asr_result: word dicts of the ASR transcription.
        :return: alignment steps produced by the classic calculator.
        """
        gold_transcript_words = [
            WordMapper.from_json_dict(word_dict)
            for word_dict in gold_transcript
        ]
        asr_words = [
            WordMapper.from_json_dict(word_dict)
            for word_dict in asr_result
        ]
        return self._alignment_classic_calculator.calculate_alignment(
            reference=gold_transcript_words,
            hypothesis=asr_words
        )

    def calculate_metrics(
        self,
        alignment_steps: List[AlignmentStep]
    ) -> Dict[str, Any]:
        """Calculate all metrics for data sample.

        :param alignment_steps: alignment of the reference and hypothesis.
        :return: dict with a single ``'classic_wer'`` entry holding the
            result of ``WerCalculator.calculate_wer``.
        """
        return {
            _CLASSIC_WER: self._wer_calculator.calculate_wer(
                alignment_steps),
        }