Skip to content
Snippets Groups Projects
annotation_mapping.py 1.82 KiB
Newer Older
Michał Pogoda's avatar
Michał Pogoda committed
from typing import Dict, List, Tuple, TypeVar

# Type of the label (third element) of reference annotations in map_annotatios.
T1 = TypeVar("T1")
# Type of the label (third element) of target-column annotations in map_annotatios.
T2 = TypeVar("T2")

Michał Pogoda's avatar
Michał Pogoda committed
def map_annotatios(
    ref_annotations: List[Tuple[int, int, T1]],
    all_annotations: Dict[str, List[Tuple[int, int, T2]]],
    target_columns: List[str],
) -> Dict[Tuple[int, int, T1], Dict[str, Tuple[int, int, T2]]]:
    """Map annotations from target columns onto reference annotations by span.

    A target annotation matches a reference annotation when both have exactly
    the same ``(start, end)`` span; labels are not compared. Every reference
    annotation with a given span receives the matches for that span, so
    several reference labels on one span are all mapped.

    Example:
        >> ref_annotations = [(0, 3, "Andrzej"), (7, 11, "psa")]
        >> all_annotations = {
        >>     "A": [(0, 3, "Andrzej"), (7, 11, "psa")],
        >>     "B": [(0, 3, "AndrzejB"), (7, 11, "psaA")],
        >>     "C": [(0, 3, "AndrzejC"), (8, 9, "psaC")],
        >> }
        >> target_columns = ["B", "C"]
        >> map_annotatios(ref_annotations, all_annotations, target_columns)
        {
            (0, 3, "Andrzej"): {"B": (0, 3, "AndrzejB"), "C": (0, 3, "AndrzejC")},
            (7, 11, "psa"): {"B": (7, 11, "psaA")},
        }

    Args:
        ref_annotations: Reference ``(start, end, label)`` annotations. Each
            one becomes a key of the result, so labels must be hashable.
        all_annotations: ``(start, end, label)`` annotations per column name.
        target_columns: Names of the columns of ``all_annotations`` to map.

    Returns:
        A dict with one entry per distinct reference annotation, in input
        order. Each value maps a target column name to that column's
        annotation whose span equals the reference span; columns without such
        an annotation are absent. If a column holds several annotations with
        the same span, the last one listed wins.

    Raises:
        KeyError: If a name in ``target_columns`` is not in ``all_annotations``.
    """
    result: Dict[Tuple[int, int, T1], Dict[str, Tuple[int, int, T2]]] = {}
    # A span may carry several reference labels, so keep all of them per span
    # instead of letting a later reference overwrite an earlier one.
    refs_by_span: Dict[Tuple[int, int], List[Tuple[int, int, T1]]] = {}

    for s_start, s_end, s_anno in ref_annotations:
        ref = (s_start, s_end, s_anno)
        result[ref] = {}
        refs_by_span.setdefault((s_start, s_end), []).append(ref)

    for target_column in target_columns:
        for t_start, t_end, t_anno in all_annotations[target_column]:
            for ref in refs_by_span.get((t_start, t_end), ()):
                result[ref][target_column] = (t_start, t_end, t_anno)

    return result


# Correctly spelled alias; ``map_annotatios`` is kept for existing callers.
map_annotations = map_annotatios