Skip to content
Snippets Groups Projects
test_classic_wer.py 2.19 KiB
from typing import List, Tuple

import pytest

from sziszapangma.core.alignment.alignment_classic_calculator import AlignmentClassicCalculator
from sziszapangma.core.alignment.step_type import StepType
from sziszapangma.core.alignment.step_words import StepWords
from sziszapangma.core.wer.wer_calculator import WerCalculator
from sziszapangma.model.model import Word
from sziszapangma.model.model_creators import create_new_word


def string_list_to_words(strings: List[str]) -> List[Word]:
    """Wrap each raw token in ``strings`` as a ``Word``, preserving order."""
    return list(map(create_new_word, strings))


def get_sample_data() -> Tuple[List[Word], List[Word]]:
    """Return the sample ``(reference, hypothesis)`` word lists shared by the tests."""
    reference_tokens = "This great machine can recognize speech".split()
    hypothesis_tokens = "This machine can wreck a nice beach".split()
    return (
        string_list_to_words(reference_tokens),
        string_list_to_words(hypothesis_tokens),
    )


def test_classic_calculate_wer_value():
    """Check the classic WER value for the sample sentence pair.

    The sample pair aligns with 1 deletion, 2 insertions and 2 substitutions
    over a 6-word reference, giving WER = (1 + 2 + 2) / 6 = 5/6.
    """
    reference, hypothesis = get_sample_data()
    alignment = AlignmentClassicCalculator().calculate_alignment(reference, hypothesis)
    wer_result = WerCalculator().calculate_wer(alignment)
    # pytest.approx wraps the *expected* value (documented usage); the exact
    # fraction replaces the former truncated literal 0.8333333.
    assert wer_result == pytest.approx(5 / 6)


def test_classic_calculate_wer_steps():
    """Check every alignment step produced for the sample sentence pair.

    Each expected step is listed as (step type, step cost, aligned words);
    a ``None`` word marks the side that is absent in a deletion or insertion.
    """
    ref, hyp = get_sample_data()
    alignment = AlignmentClassicCalculator().calculate_alignment(ref, hyp)

    expected_steps = [
        (StepType.CORRECT, 0, StepWords(ref[0], hyp[0])),
        (StepType.DELETION, 1, StepWords(ref[1], None)),
        (StepType.CORRECT, 0, StepWords(ref[2], hyp[1])),
        (StepType.CORRECT, 0, StepWords(ref[3], hyp[2])),
        (StepType.INSERTION, 1, StepWords(None, hyp[3])),
        (StepType.INSERTION, 1, StepWords(None, hyp[4])),
        (StepType.SUBSTITUTION, 1, StepWords(ref[4], hyp[5])),
        (StepType.SUBSTITUTION, 1, StepWords(ref[5], hyp[6])),
    ]
    expected_types = [step_type for step_type, _, _ in expected_steps]
    expected_costs = [cost for _, cost, _ in expected_steps]
    expected_words = [words for _, _, words in expected_steps]

    assert len(alignment) == len(expected_steps)
    assert [step.step_type for step in alignment] == expected_types
    assert [step.step_cost for step in alignment] == expected_costs
    assert [step.step_words for step in alignment] == expected_words