from typing import Dict, List

from sziszapangma.core.alignment.alignment_embedding_calculator import AlignmentEmbeddingCalculator
from sziszapangma.core.alignment.alignment_soft_calculator import AlignmentSoftCalculator
from sziszapangma.core.alignment.word import Word
from sziszapangma.core.transformer.cached_embedding_transformer import CachedEmbeddingTransformer
from sziszapangma.core.transformer.embedding_transformer import EmbeddingTransformer
from sziszapangma.core.wer.wer_calculator import WerCalculator
from sziszapangma.integration.mapper.alignment_step_mapper import AlignmentStepMapper
from sziszapangma.integration.mapper.word_mapper import WordMapper
from sziszapangma.integration.repository.experiment_repository import ExperimentRepository
from sziszapangma.integration.task.processing_task import ProcessingTask

_SOFT_WER = "soft_wer"
_EMBEDDING_WER = "embedding_wer"
_WORD = "word"


class EmbeddingWerMetricsTask(ProcessingTask):
    """Processing task computing soft WER and embedding WER for each record.

    For a record, the gold transcript and the ASR result are read from the experiment
    repository, converted to ``Word`` objects and aligned twice: once with
    ``AlignmentSoftCalculator`` and once with ``AlignmentEmbeddingCalculator``. A WER
    value is computed from each alignment with ``WerCalculator``. Both alignments are
    stored under ``alignment_property_name`` and both WER values under
    ``metrics_property_name``.
    """

    _metrics_property_name: str
    _alignment_property_name: str
    _gold_transcript_property_name: str
    _asr_property_name: str
    _embedding_transformer: CachedEmbeddingTransformer
    _alignment_embedding_calculator: AlignmentEmbeddingCalculator
    _alignment_soft_calculator: AlignmentSoftCalculator
    _wer_calculator: WerCalculator

    def __init__(
        self,
        task_name: str,
        gold_transcript_property_name: str,
        asr_property_name: str,
        metrics_property_name: str,
        alignment_property_name: str,
        require_update: bool,
        embedding_transformer: EmbeddingTransformer,
    ):
        """Configure the task.

        Args:
            task_name: Name passed to ``ProcessingTask``.
            gold_transcript_property_name: Repository property holding the gold
                transcript as a list of word dicts.
            asr_property_name: Repository property holding the ASR result; its
                ``"transcription"`` entry is a list of word dicts.
            metrics_property_name: Repository property the WER values are written to.
            alignment_property_name: Repository property the alignments are written to.
            require_update: Passed to ``ProcessingTask``.
            embedding_transformer: Transformer used by both alignment calculators.
        """
        super().__init__(task_name, require_update)
        self._gold_transcript_property_name = gold_transcript_property_name
        self._asr_property_name = asr_property_name
        self._metrics_property_name = metrics_property_name
        self._alignment_property_name = alignment_property_name
        # Both calculators share this single caching wrapper; its cache is cleared
        # after every record in ``run_single_process``.
        self._embedding_transformer = CachedEmbeddingTransformer(embedding_transformer)
        self._alignment_embedding_calculator = AlignmentEmbeddingCalculator(
            self._embedding_transformer
        )
        self._alignment_soft_calculator = AlignmentSoftCalculator(self._embedding_transformer)
        self._wer_calculator = WerCalculator()

    def skip_for_record(self, record_id: str, experiment_repository: ExperimentRepository) -> bool:
        """Return True if a truthy metrics value is already stored for ``record_id``.

        The stored value is coerced to ``bool`` so the declared return type holds. A
        falsy value (presumably ``None`` when nothing was stored) means the record
        still has to be processed.
        """
        return bool(
            experiment_repository.get_property_for_key(record_id, self._metrics_property_name)
        )

    def run_single_process(
        self, record_id: str, experiment_repository: ExperimentRepository
    ) -> None:
        """Compute and store alignments and WER metrics for one record.

        The embedding cache is cleared afterwards in every case, including when a
        calculation raises, so that cached embeddings do not leak into the next record.
        """
        try:
            self._process_record(record_id, experiment_repository)
        finally:
            self._embedding_transformer.clear()

    def _process_record(
        self, record_id: str, experiment_repository: ExperimentRepository
    ) -> None:
        """Align gold and ASR words of one record and persist the results (no cleanup)."""
        gold_transcript = experiment_repository.get_property_for_key(
            record_id, self._gold_transcript_property_name
        )
        asr_result = experiment_repository.get_property_for_key(record_id, self._asr_property_name)
        # Records missing either input (or an ASR result without a transcription)
        # are left untouched instead of raising while iterating ``None``.
        if gold_transcript is None or asr_result is None or "transcription" not in asr_result:
            return

        gold_transcript_words = self._map_words_to_domain(gold_transcript)
        asr_words = self._map_words_to_domain(asr_result["transcription"])

        soft_alignment = self._alignment_soft_calculator.calculate_alignment(
            gold_transcript_words, asr_words
        )
        embedding_alignment = self._alignment_embedding_calculator.calculate_alignment(
            gold_transcript_words, asr_words
        )

        soft_wer = self._wer_calculator.calculate_wer(soft_alignment)
        embedding_wer = self._wer_calculator.calculate_wer(embedding_alignment)

        alignment_results = {
            "soft_alignment": [AlignmentStepMapper.to_json_dict(it) for it in soft_alignment],
            "embedding_alignment": [
                AlignmentStepMapper.to_json_dict(it) for it in embedding_alignment
            ],
        }
        wer_results = {_SOFT_WER: soft_wer, _EMBEDDING_WER: embedding_wer}

        experiment_repository.update_property_for_key(
            record_id, self._alignment_property_name, alignment_results
        )
        experiment_repository.update_property_for_key(
            record_id, self._metrics_property_name, wer_results
        )

    @staticmethod
    def _map_words_to_domain(input_json_dicts: List[Dict[str, str]]) -> List[Word]:
        """Convert a list of word JSON dicts to ``Word`` objects via ``WordMapper``."""
        return [WordMapper.from_json_dict(word_dict) for word_dict in input_json_dicts]
