# File: classic_wer_metric_task.py (3.62 KiB)
# Authored by Marcin Wątroba; commit 4a03782a (unverified)
from typing import Any, Dict, List

from sziszapangma.core.alignment.alignment_classic_calculator import \
    AlignmentClassicCalculator
from sziszapangma.core.alignment.alignment_step import AlignmentStep
from sziszapangma.core.wer.wer_calculator import WerCalculator
from sziszapangma.integration.mapper.alignment_step_mapper import \
    AlignmentStepMapper
from sziszapangma.integration.mapper.word_mapper import WordMapper
from sziszapangma.integration.repository.experiment_repository import \
    ExperimentRepository
from sziszapangma.integration.task.processing_task import ProcessingTask
# Key under which the classic WER value is stored in the metrics dict.
_CLASSIC_WER = 'classic_wer'
class ClassicWerMetricTask(ProcessingTask):
    """Processing task that computes the classic WER metric for a record.

    For a single record the task reads the gold transcript and the ASR result
    from the experiment repository, aligns the two word sequences with
    ``AlignmentClassicCalculator`` and writes back both the serialized
    alignment steps and a metrics dict containing the classic WER value.
    """

    # Names of the repository properties this task reads from and writes to.
    _metrics_property_name: str
    _alignment_property_name: str
    _gold_transcript_property_name: str
    _asr_property_name: str
    # Collaborators created once per task instance.
    _alignment_classic_calculator: AlignmentClassicCalculator
    _wer_calculator: WerCalculator

    def __init__(
        self,
        task_name: str,
        gold_transcript_property_name: str,
        asr_property_name: str,
        metrics_property_name: str,
        alignment_property_name: str,
        require_update: bool
    ):
        """Store the property names and create the calculators.

        Args:
            task_name: Forwarded to ``ProcessingTask.__init__``.
            gold_transcript_property_name: Property holding the gold
                transcript of a record.
            asr_property_name: Property holding the ASR result of a record.
            metrics_property_name: Property the metrics dict is written to.
            alignment_property_name: Property the serialized alignment is
                written to.
            require_update: Forwarded to ``ProcessingTask.__init__``.
        """
        super().__init__(task_name, require_update)
        self._gold_transcript_property_name = gold_transcript_property_name
        self._asr_property_name = asr_property_name
        self._alignment_property_name = alignment_property_name
        self._metrics_property_name = metrics_property_name
        self._alignment_classic_calculator = AlignmentClassicCalculator()
        self._wer_calculator = WerCalculator()

    def skip_for_record(
        self,
        record_id: str,
        experiment_repository: ExperimentRepository
    ) -> bool:
        """Tell whether metrics are already stored for ``record_id``.

        Returns:
            ``True`` when the value stored under the metrics property is
            truthy, ``False`` otherwise.
        """
        stored_metrics = experiment_repository.get_property_for_key(
            record_id, self._metrics_property_name)
        # Coerce to a real bool so the return value matches the annotation
        # instead of leaking the stored metrics object to the caller.
        return bool(stored_metrics)

    def run_single_process(
        self,
        record_id: str,
        experiment_repository: ExperimentRepository
    ) -> None:
        """Compute and store the alignment and metrics for one record.

        Nothing is written when the ASR result is ``None`` or does not
        contain a ``'transcription'`` entry.
        """
        gold_transcript = experiment_repository.get_property_for_key(
            record_id, self._gold_transcript_property_name)
        asr_result = experiment_repository.get_property_for_key(
            record_id, self._asr_property_name)

        # A None result would make the membership test raise TypeError;
        # treat it like a result without a transcription.
        # NOTE(review): assumes the repository yields None for a missing
        # property -- confirm against ExperimentRepository.
        if asr_result is None or 'transcription' not in asr_result:
            return

        alignment_steps = self._get_alignment(
            gold_transcript, asr_result['transcription']
        )
        experiment_repository.update_property_for_key(
            record_id,
            self._alignment_property_name,
            [AlignmentStepMapper.to_json_dict(step)
             for step in alignment_steps]
        )
        experiment_repository.update_property_for_key(
            record_id,
            self._metrics_property_name,
            self.calculate_metrics(alignment_steps)
        )

    def _get_alignment(
        self,
        gold_transcript: List[Dict[str, Any]],
        asr_result: List[Dict[str, Any]]
    ) -> List[AlignmentStep]:
        """Align the gold transcript (reference) with the ASR words.

        Both arguments are lists of word dicts that are converted with
        ``WordMapper.from_json_dict`` before the alignment is calculated.
        """
        gold_transcript_words = [
            WordMapper.from_json_dict(word_dict)
            for word_dict in gold_transcript
        ]
        asr_words = [
            WordMapper.from_json_dict(word_dict)
            for word_dict in asr_result
        ]
        return self._alignment_classic_calculator.calculate_alignment(
            reference=gold_transcript_words,
            hypothesis=asr_words
        )

    def calculate_metrics(
        self,
        alignment_steps: List[AlignmentStep]
    ) -> Dict[str, Any]:
        """Calculate all metrics for data sample.

        Returns:
            A dict with the WER computed by ``WerCalculator.calculate_wer``
            stored under the ``_CLASSIC_WER`` key.
        """
        return {
            _CLASSIC_WER: self._wer_calculator.calculate_wer(alignment_steps)
        }