# embedding_wer_metrics_task.py
from sziszapangma.core.alignment.alignment_embedding_calculator import AlignmentEmbeddingCalculator
from sziszapangma.core.alignment.alignment_soft_calculator import AlignmentSoftCalculator
from sziszapangma.core.transformer.cached_embedding_transformer import CachedEmbeddingTransformer
from sziszapangma.core.transformer.embedding_transformer import EmbeddingTransformer
from sziszapangma.core.wer.wer_calculator import WerCalculator
from sziszapangma.integration.mapper.alignment_step_mapper import AlignmentStepMapper
from sziszapangma.integration.repository.experiment_repository import ExperimentRepository
from sziszapangma.integration.task.processing_task import ProcessingTask
from sziszapangma.integration.task.task_util import TaskUtil
from sziszapangma.model.relation_manager import RelationManager
# Metric-name strings. Their values equal the keys of the WER results dict that
# EmbeddingWerMetricsTask.run_single_process stores in the repository.
# NOTE(review): the constants themselves are not referenced in the visible part
# of this module (the results dict is built from string literals) — confirm
# whether they are used elsewhere before removing or wiring them in.
_SOFT_WER = "soft_wer"
_EMBEDDING_WER = "embedding_wer"
# NOTE(review): not referenced in the visible code; purpose unknown — TODO confirm.
_WORD = "word"
class EmbeddingWerMetricsTask(ProcessingTask):
    """Processing task that stores soft-WER and embedding-WER results for a record.

    For one record the task lower-cases the gold transcript words and the ASR
    transcription, aligns them with both ``AlignmentSoftCalculator`` and
    ``AlignmentEmbeddingCalculator``, computes a WER value from each alignment
    and writes the alignments and the WER values to the experiment repository.
    """

    _metrics_property_name: str
    _alignment_property_name: str
    _gold_transcript_property_name: str
    _asr_property_name: str
    _embedding_transformer: CachedEmbeddingTransformer
    _alignment_embedding_calculator: AlignmentEmbeddingCalculator
    _alignment_soft_calculator: AlignmentSoftCalculator
    _wer_calculator: WerCalculator

    def __init__(
        self,
        task_name: str,
        gold_transcript_property_name: str,
        asr_property_name: str,
        metrics_property_name: str,
        alignment_property_name: str,
        require_update: bool,
        embedding_transformer: EmbeddingTransformer,
    ):
        """Create the task and the calculators it uses.

        Args:
            task_name: Forwarded to ``ProcessingTask.__init__``.
            gold_transcript_property_name: Stored on the instance; the methods
                of this class do not read it (gold words are taken from the
                relation manager instead).
            asr_property_name: Repository property from which the ASR result
                is read.
            metrics_property_name: Repository property under which the WER
                values are written; its presence makes the record skippable.
            alignment_property_name: Repository property under which the
                serialized alignments are written.
            require_update: Forwarded to ``ProcessingTask.__init__``.
            embedding_transformer: Transformer wrapped in a
                ``CachedEmbeddingTransformer`` that is shared by both
                alignment calculators.
        """
        super().__init__(task_name, require_update)
        self._gold_transcript_property_name = gold_transcript_property_name
        self._asr_property_name = asr_property_name
        self._metrics_property_name = metrics_property_name
        self._alignment_property_name = alignment_property_name
        # One cached transformer instance is handed to both calculators.
        self._embedding_transformer = CachedEmbeddingTransformer(embedding_transformer)
        self._alignment_embedding_calculator = AlignmentEmbeddingCalculator(
            self._embedding_transformer
        )
        self._alignment_soft_calculator = AlignmentSoftCalculator(self._embedding_transformer)
        self._wer_calculator = WerCalculator()

    def skip_for_record(self, record_id: str, experiment_repository: ExperimentRepository) -> bool:
        """Return True when the metrics property is already set for ``record_id``."""
        stored_metrics = experiment_repository.get_property_for_key(
            record_id, self._metrics_property_name
        )
        return stored_metrics is not None

    def run_single_process(
        self,
        record_id: str,
        experiment_repository: ExperimentRepository,
        relation_manager: RelationManager,
    ) -> None:
        """Compute both alignments and WER values for one record and store them.

        Nothing is written when the gold words are missing, the ASR result is
        missing, or the ASR result has no ``"transcription"`` entry.

        The embedding transformer cache is cleared before this method returns,
        including when an exception propagates out of it.

        Args:
            record_id: Identifier of the record being processed.
            experiment_repository: Repository the ASR result is read from and
                the results are written to.
            relation_manager: Source of the gold transcript words.
        """
        try:
            gold_transcript = TaskUtil.get_words_from_record(relation_manager)
            asr_result = experiment_repository.get_property_for_key(
                record_id, self._asr_property_name
            )
            has_inputs = (
                gold_transcript is not None
                and asr_result is not None
                and "transcription" in asr_result
            )
            if not has_inputs:
                return

            gold_transcript_lower = TaskUtil.words_to_lower(gold_transcript)
            asr_transcript_lower = TaskUtil.words_to_lower(asr_result["transcription"])

            soft_alignment = self._alignment_soft_calculator.calculate_alignment(
                gold_transcript_lower, asr_transcript_lower
            )
            embedding_alignment = self._alignment_embedding_calculator.calculate_alignment(
                gold_transcript_lower, asr_transcript_lower
            )

            wer_results = {
                "soft_wer": self._wer_calculator.calculate_wer(soft_alignment),
                "embedding_wer": self._wer_calculator.calculate_wer(embedding_alignment),
            }
            alignment_results = {
                "soft_alignment": [
                    AlignmentStepMapper.to_json_dict(step) for step in soft_alignment
                ],
                "embedding_alignment": [
                    AlignmentStepMapper.to_json_dict(step) for step in embedding_alignment
                ],
            }

            # Alignments are written before metrics: skip_for_record keys off the
            # metrics property, so a failure between the two writes leaves the
            # record eligible for reprocessing.
            experiment_repository.update_property_for_key(
                record_id, self._alignment_property_name, alignment_results
            )
            experiment_repository.update_property_for_key(
                record_id, self._metrics_property_name, wer_results
            )
        finally:
            # Always drop the per-record cache, even if a step above raised.
            self._embedding_transformer.clear()