Skip to content
Snippets Groups Projects
Select Git revision
  • 9a94c87e64a1083d8964b927853fd029da09c823
  • main default protected
  • change_data_model
  • feature/add_auth_asr_service
  • fix/incorrect_import
  • feature/change_registry_clarin
  • feature/add_base_asr_service
  • feature/add_poetry
  • feature/add_word_ids
  • feature/add_sziszapangma
10 results

embedding_wer_metrics_task.py

Blame
  • user avatar
    Marcin Wątroba authored
    6736343c
    History
    embedding_wer_metrics_task.py 4.16 KiB
    from sziszapangma.core.alignment.alignment_embedding_calculator import AlignmentEmbeddingCalculator
    from sziszapangma.core.alignment.alignment_soft_calculator import AlignmentSoftCalculator
    from sziszapangma.core.transformer.cached_embedding_transformer import CachedEmbeddingTransformer
    from sziszapangma.core.transformer.embedding_transformer import EmbeddingTransformer
    from sziszapangma.core.wer.wer_calculator import WerCalculator
    from sziszapangma.integration.mapper.alignment_step_mapper import AlignmentStepMapper
    from sziszapangma.integration.repository.experiment_repository import ExperimentRepository
    from sziszapangma.integration.task.processing_task import ProcessingTask
    from sziszapangma.integration.task.task_util import TaskUtil
    from sziszapangma.model.relation_manager import RelationManager
    
    # Metric names produced by this module; their values are identical to the keys
    # of the WER-results dict that EmbeddingWerMetricsTask writes to the repository.
    _SOFT_WER, _EMBEDDING_WER = "soft_wer", "embedding_wer"
    # NOTE(review): not referenced anywhere in this file — confirm it is still needed.
    _WORD = "word"
    
    
    class EmbeddingWerMetricsTask(ProcessingTask):
        """Compute soft and embedding-based WER for a record and persist the results.

        For each record the lower-cased gold words are aligned with the lower-cased
        ASR transcription twice: once with ``AlignmentSoftCalculator`` and once with
        ``AlignmentEmbeddingCalculator``. A WER is computed from each alignment.
        Both alignments and both WER values are then written to the experiment
        repository.
        """

        _metrics_property_name: str
        _alignment_property_name: str
        _gold_transcript_property_name: str
        _asr_property_name: str
        _embedding_transformer: CachedEmbeddingTransformer
        _alignment_embedding_calculator: AlignmentEmbeddingCalculator
        _alignment_soft_calculator: AlignmentSoftCalculator
        _wer_calculator: WerCalculator

        def __init__(
            self,
            task_name: str,
            gold_transcript_property_name: str,
            asr_property_name: str,
            metrics_property_name: str,
            alignment_property_name: str,
            require_update: bool,
            embedding_transformer: EmbeddingTransformer,
        ) -> None:
            """Create the task.

            Args:
                task_name: Task name, forwarded to ``ProcessingTask``.
                gold_transcript_property_name: Stored on the instance. It is not
                    read by this class, which takes the gold words from the
                    record's ``RelationManager`` instead.
                asr_property_name: Repository property holding the ASR result.
                    The result must support ``"transcription" in result`` and
                    ``result["transcription"]`` (presumably a dict; verify with
                    the producer).
                metrics_property_name: Repository property the WER dict is
                    written to.
                alignment_property_name: Repository property the alignments are
                    written to.
                require_update: Forwarded to ``ProcessingTask``.
                embedding_transformer: Wrapped in a ``CachedEmbeddingTransformer``.
                    That one cached instance is shared by both alignment
                    calculators.
            """
            super().__init__(task_name, require_update)
            self._gold_transcript_property_name = gold_transcript_property_name
            self._asr_property_name = asr_property_name
            self._metrics_property_name = metrics_property_name
            self._alignment_property_name = alignment_property_name
            # A single cache serves both calculators, so embeddings for a word are
            # computed at most once per record.
            self._embedding_transformer = CachedEmbeddingTransformer(embedding_transformer)
            self._alignment_embedding_calculator = AlignmentEmbeddingCalculator(
                self._embedding_transformer
            )
            self._alignment_soft_calculator = AlignmentSoftCalculator(self._embedding_transformer)
            self._wer_calculator = WerCalculator()

        def skip_for_record(self, record_id: str, experiment_repository: ExperimentRepository) -> bool:
            """Return True if the metrics property is already set for ``record_id``."""
            return (
                experiment_repository.get_property_for_key(record_id, self._metrics_property_name)
                is not None
            )

        def run_single_process(
            self,
            record_id: str,
            experiment_repository: ExperimentRepository,
            relation_manager: RelationManager,
        ) -> None:
            """Align, score and store results for a single record.

            A record is left untouched when it has no gold words, no ASR result, or
            an ASR result without a ``"transcription"`` entry.

            ``clear()`` is called on the cached embedding transformer in every
            case, including when an exception propagates. Whatever was cached for
            this record is therefore not carried over to later records.
            """
            try:
                gold_transcript = TaskUtil.get_words_from_record(relation_manager)
                asr_result = experiment_repository.get_property_for_key(
                    record_id, self._asr_property_name
                )
                if (
                    gold_transcript is None
                    or asr_result is None
                    or "transcription" not in asr_result
                ):
                    return

                gold_transcript_lower = TaskUtil.words_to_lower(gold_transcript)
                asr_transcript_lower = TaskUtil.words_to_lower(asr_result["transcription"])

                soft_alignment = self._alignment_soft_calculator.calculate_alignment(
                    gold_transcript_lower, asr_transcript_lower
                )
                embedding_alignment = self._alignment_embedding_calculator.calculate_alignment(
                    gold_transcript_lower, asr_transcript_lower
                )

                alignment_results = {
                    "soft_alignment": [
                        AlignmentStepMapper.to_json_dict(step) for step in soft_alignment
                    ],
                    "embedding_alignment": [
                        AlignmentStepMapper.to_json_dict(step) for step in embedding_alignment
                    ],
                }
                wer_results = {
                    _SOFT_WER: self._wer_calculator.calculate_wer(soft_alignment),
                    _EMBEDDING_WER: self._wer_calculator.calculate_wer(embedding_alignment),
                }

                experiment_repository.update_property_for_key(
                    record_id, self._alignment_property_name, alignment_results
                )
                # Metrics are written last: skip_for_record keys off this property.
                # A failure before this point therefore leaves the record eligible
                # for reprocessing.
                experiment_repository.update_property_for_key(
                    record_id, self._metrics_property_name, wer_results
                )
            finally:
                self._embedding_transformer.clear()