from typing import List, Tuple

import pytest

from sziszapangma.core.wer.classic_wer_calculator import ClassicWerCalculator
from sziszapangma.core.wer.step_type import StepType
from sziszapangma.core.wer.step_words import StepWords


def get_sample_data() -> Tuple[List[str], List[str]]:
    """Return the (reference, hypothesis) word lists used by the tests.

    Both sentences are tokenized on whitespace; new lists are built on
    every call, so callers may mutate them freely.
    """
    reference_sentence = 'This great machine can recognize speech'
    hypothesis_sentence = 'This machine can wreck a nice beach'
    return reference_sentence.split(), hypothesis_sentence.split()


def test_classic_calculate_wer_value():
    """Check the WER value computed for the sample sentence pair.

    The first element of the ``calculate_wer`` result is compared against
    the constant 0.8333333 (presumably 5 / 6 rounded: five edit steps over
    a six-word reference -- confirm against the calculator).
    """
    reference, hypothesis = get_sample_data()
    wer_result = ClassicWerCalculator().calculate_wer(reference, hypothesis)
    # Wrap the *expected* constant in approx(), not the computed value, so
    # the tolerance is anchored on the known-good number and pytest's
    # failure report labels expected/obtained the right way round.
    assert wer_result[0] == pytest.approx(0.8333333)


def test_classic_calculate_wer_steps():
    """Check the step-by-step alignment returned for the sample sentences.

    The second element of the ``calculate_wer`` result is expected to hold
    eight steps. Each row of the table below describes one of them: its
    step type, its cost and the word pair passed to ``StepWords``, where
    ``None`` stands for the side that has no word in that step.
    """
    # (step type, step cost, first StepWords arg, second StepWords arg)
    expected_alignment = [
        (StepType.CORRECT, 0, 'This', 'This'),
        (StepType.DELETION, 1, 'great', None),
        (StepType.CORRECT, 0, 'machine', 'machine'),
        (StepType.CORRECT, 0, 'can', 'can'),
        (StepType.INSERTION, 1, None, 'wreck'),
        (StepType.INSERTION, 1, None, 'a'),
        (StepType.SUBSTITUTION, 1, 'recognize', 'nice'),
        (StepType.SUBSTITUTION, 1, 'speech', 'beach'),
    ]
    expected_types = [row[0] for row in expected_alignment]
    expected_costs = [row[1] for row in expected_alignment]
    expected_words = [
        StepWords(first, second)
        for _, _, first, second in expected_alignment
    ]

    reference, hypothesis = get_sample_data()
    calculator = ClassicWerCalculator()
    steps = calculator.calculate_wer(reference, hypothesis)[1]

    assert len(steps) == len(expected_alignment)
    assert [step.step_type for step in steps] == expected_types
    assert [step.step_cost for step in steps] == expected_costs
    assert [step.step_words for step in steps] == expected_words