"""Utility functions for replacing substrings in a string."""

from typing import List, Tuple, TypeVar


def replace(original_string: str, replacements: List[Tuple[int, int, str]]) -> str:
    """Replaces substrings in a string.

    !!! Important: This function assumes that there are no overlapping
    annotations.

    Every position refers to ``original_string`` (not to the partially
    rewritten string) and describes a half-open span ``[start, end)``; a
    span with ``start == end`` is a pure insertion. Replacements are applied
    in ``(start, end)`` order, so an insertion at position ``p`` is placed
    before a replacement that begins at ``p``. Entries with identical spans
    keep their input order.

    Args:
        original_string (str): The original string.
        replacements (List[Tuple[int, int, str]]): A list of tuples containing
            (start, end, replacement).

    Returns:
        str: The string with replacements applied.

    """
    # Sort by (start, end), not start alone. Otherwise a zero-width insertion
    # at ``p`` listed after a replacement starting at ``p`` would be
    # offset by that replacement's length change and land in the wrong place.
    ordered = sorted(replacements, key=lambda item: (item[0], item[1]))

    # Collect the untouched gaps and the replacement texts, then join once,
    # instead of re-slicing the whole string for every replacement.
    parts = []
    position = 0  # first index of original_string not yet copied/replaced
    for item in ordered:
        start, end, text = item[0], item[1], item[2]
        parts.append(original_string[position:start])
        parts.append(text)
        position = end
    parts.append(original_string[position:])
    return "".join(parts)


# Payload type of the "other" annotations; it is carried through unchanged.
_T = TypeVar("_T")


def replace_and_update(
    original_string: str,
    replacements: List[Tuple[int, int, str]],
    other_annotations: List[Tuple[int, int, _T]],
) -> Tuple[str, List[Tuple[int, int, _T]]]:
    """Replaces parts of a string and updates annotations to match new string.

    !!! Important: This function assumes that there are no overlapping
    annotations.

    All positions refer to ``original_string`` and describe half-open spans
    ``[start, end)``. Replacements and annotations are processed together in
    ``(start, end)`` order. Each annotation is shifted by the total length
    change of the replacements processed before it. When a replacement and
    an annotation have identical spans, the replacement is processed first.

    Args:
        original_string (str): The original string.
        replacements (List[Tuple[int, int, str]]): A list of tuples containing
            (start, end, replacement).
        other_annotations (List[Tuple[int, int, _T]]): A list of other
            annotations.

    Returns:
        Tuple[str, List[Tuple[int, int, _T]]]: The string with replacements
            applied and other annotations with new positions. The
            annotations are returned in position order, not input order.

    """
    # Tag every entry (True = replacement) so both kinds can be walked in one
    # position-ordered pass. Replacements are listed first so the stable sort
    # keeps them ahead of annotations that share the same span.
    tagged = [(r[0], r[1], r[2], True) for r in replacements] + [
        (a[0], a[1], a[2], False) for a in other_annotations
    ]
    # Sort by (start, end), not start alone, so a zero-width insertion at
    # ``p`` is applied before a replacement starting at ``p`` regardless of
    # input order.
    tagged.sort(key=lambda item: (item[0], item[1]))

    parts = []
    new_other_annotations = []
    position = 0  # first index of original_string not yet copied/replaced
    delta = 0  # length change accumulated from replacements applied so far
    for start, end, payload, is_replacement in tagged:
        if is_replacement:
            parts.append(original_string[position:start])
            parts.append(payload)
            position = end
            delta += len(payload) - (end - start)
        else:
            new_other_annotations.append((start + delta, end + delta, payload))
    parts.append(original_string[position:])

    return "".join(parts), new_other_annotations