from typing import Any, Dict, List

from sziszapangma.core.alignment.alignment_classic_calculator import AlignmentClassicCalculator
from sziszapangma.core.alignment.alignment_step import AlignmentStep
from sziszapangma.core.wer.wer_calculator import WerCalculator
from sziszapangma.integration.mapper.alignment_step_mapper import AlignmentStepMapper
from sziszapangma.integration.mapper.word_mapper import WordMapper
from sziszapangma.integration.repository.experiment_repository import ExperimentRepository
from sziszapangma.integration.task.processing_task import ProcessingTask

# Key under which ClassicWerMetricTask.calculate_metrics stores the classic WER value.
_CLASSIC_WER = "classic_wer"


class ClassicWerMetricTask(ProcessingTask):
    """Align gold and ASR transcripts and store the classic WER for each record.

    For every record, the gold transcript and the ASR result are read from the
    experiment repository. When the ASR result contains a ``"transcription"``
    entry, both word sequences are aligned with ``AlignmentClassicCalculator``;
    the serialized alignment steps and the computed metrics are then written
    back to the repository under the configured property names.
    """

    _metrics_property_name: str
    _alignment_property_name: str
    _gold_transcript_property_name: str
    _asr_property_name: str
    _alignment_classic_calculator: AlignmentClassicCalculator
    _wer_calculator: WerCalculator

    def __init__(
        self,
        task_name: str,
        gold_transcript_property_name: str,
        asr_property_name: str,
        metrics_property_name: str,
        alignment_property_name: str,
        require_update: bool,
    ):
        """Create the task.

        Args:
            task_name: Name forwarded to the base ``ProcessingTask``.
            gold_transcript_property_name: Repository property holding the
                reference transcript as a list of word dicts.
            asr_property_name: Repository property holding the ASR result; its
                ``"transcription"`` entry is read as a list of word dicts.
            metrics_property_name: Repository property the metrics dict is
                written to.
            alignment_property_name: Repository property the serialized
                alignment steps are written to.
            require_update: Forwarded to the base ``ProcessingTask``.
        """
        super().__init__(task_name, require_update)
        self._gold_transcript_property_name = gold_transcript_property_name
        self._asr_property_name = asr_property_name
        self._alignment_property_name = alignment_property_name
        self._metrics_property_name = metrics_property_name
        self._alignment_classic_calculator = AlignmentClassicCalculator()
        self._wer_calculator = WerCalculator()

    def skip_for_record(self, record_id: str, experiment_repository: ExperimentRepository) -> bool:
        """Return True when metrics for ``record_id`` are already stored."""
        return (
            experiment_repository.get_property_for_key(record_id, self._metrics_property_name)
            is not None
        )

    def run_single_process(
        self, record_id: str, experiment_repository: ExperimentRepository
    ) -> None:
        """Compute and persist the alignment and metrics for one record.

        Records whose ASR result is missing (``None``) or has no
        ``"transcription"`` entry are left untouched.

        Args:
            record_id: Key of the record in the repository.
            experiment_repository: Source of the transcripts and destination
                of the alignment and metrics properties.
        """
        gold_transcript = experiment_repository.get_property_for_key(
            record_id, self._gold_transcript_property_name
        )
        asr_result = experiment_repository.get_property_for_key(record_id, self._asr_property_name)
        # A record without any ASR result has nothing to align, exactly like a
        # result without a transcription; `in` on None would raise TypeError.
        if asr_result is None or "transcription" not in asr_result:
            return
        alignment_steps = self._get_alignment(gold_transcript, asr_result["transcription"])
        experiment_repository.update_property_for_key(
            record_id,
            self._alignment_property_name,
            [AlignmentStepMapper.to_json_dict(step) for step in alignment_steps],
        )
        experiment_repository.update_property_for_key(
            record_id, self._metrics_property_name, self.calculate_metrics(alignment_steps)
        )

    def _get_alignment(
        self, gold_transcript: List[Dict[str, Any]], asr_result: List[Dict[str, Any]]
    ) -> List[AlignmentStep]:
        """Align the ASR words (hypothesis) against the gold words (reference).

        Args:
            gold_transcript: Reference words as JSON dicts understood by
                ``WordMapper.from_json_dict``.
            asr_result: Hypothesis words in the same JSON-dict form.

        Returns:
            The alignment steps produced by ``AlignmentClassicCalculator``.
        """
        gold_transcript_words = [
            WordMapper.from_json_dict(word_dict) for word_dict in gold_transcript
        ]
        asr_words = [WordMapper.from_json_dict(word_dict) for word_dict in asr_result]
        return self._alignment_classic_calculator.calculate_alignment(
            reference=gold_transcript_words, hypothesis=asr_words
        )

    def calculate_metrics(self, alignment_steps: List[AlignmentStep]) -> Dict[str, Any]:
        """Calculate all metrics for data sample.

        Returns:
            A dict holding the classic WER under the ``"classic_wer"`` key.
        """
        return {_CLASSIC_WER: self._wer_calculator.calculate_wer(alignment_steps)}
