From 4c52f2a2670bfa751e3c7ecc3126d329452bde17 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Maja=20Jab=C5=82o=C5=84ska?= <majajjablonska@gmail.com>
Date: Thu, 6 Apr 2023 17:11:04 +0200
Subject: [PATCH] Add chu_liu_edmonds.py

---
 combo/nn/__init__.py        |   0
 combo/nn/chu_liu_edmonds.py | 297 ++++++++++++++++++++++++++++++++++++
 2 files changed, 297 insertions(+)
 create mode 100644 combo/nn/__init__.py
 create mode 100644 combo/nn/chu_liu_edmonds.py

diff --git a/combo/nn/__init__.py b/combo/nn/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/combo/nn/chu_liu_edmonds.py b/combo/nn/chu_liu_edmonds.py
new file mode 100644
index 0000000..be8f037
--- /dev/null
+++ b/combo/nn/chu_liu_edmonds.py
@@ -0,0 +1,297 @@
+"""
+Adapted from AllenNLP
+https://github.com/allenai/allennlp/blob/main/allennlp/nn/chu_liu_edmonds.py
+"""
+from typing import List, Set, Tuple, Dict
+import numpy
+
+from combo.utils import ConfigurationError
+
+
+def decode_mst(
+    energy: numpy.ndarray, length: int, has_labels: bool = True
+) -> Tuple[numpy.ndarray, numpy.ndarray]:
+    """
+    Note: Counter to typical intuition, this function decodes the _maximum_
+    spanning tree.
+
+    Decode the optimal MST tree with the Chu-Liu-Edmonds algorithm for
+    maximum spanning arborescences on graphs.
+
+    # Parameters
+
+    energy : `numpy.ndarray`, required.
+        A tensor with shape (num_labels, timesteps, timesteps)
+        containing the energy of each edge. If has_labels is `False`,
+        the tensor should have shape (timesteps, timesteps) instead.
+    length : `int`, required.
+        The length of this sequence, as the energy may have come
+        from a padded batch.
+    has_labels : `bool`, optional, (default = `True`)
+        Whether the graph has labels or not.
+    """
+    if has_labels and energy.ndim != 3:
+        raise ConfigurationError("The dimension of the energy array is not equal to 3.")
+    elif not has_labels and energy.ndim != 2:
+        raise ConfigurationError("The dimension of the energy array is not equal to 2.")
+    input_shape = energy.shape
+    max_length = input_shape[-1]
+
+    # Our energy matrix might have been batched -
+    # here we clip it to contain only non padded tokens.
+    if has_labels:
+        energy = energy[:, :length, :length]
+        # get best label for each edge.
+        label_id_matrix = energy.argmax(axis=0)
+        energy = energy.max(axis=0)
+    else:
+        energy = energy[:length, :length]
+        label_id_matrix = None
+    # get original score matrix
+    original_score_matrix = energy
+    # initialize score matrix to original score matrix
+    score_matrix = numpy.array(original_score_matrix, copy=True)
+
+    old_input = numpy.zeros([length, length], dtype=numpy.int32)
+    old_output = numpy.zeros([length, length], dtype=numpy.int32)
+    current_nodes = [True for _ in range(length)]
+    representatives: List[Set[int]] = []
+
+    for node1 in range(length):
+        original_score_matrix[node1, node1] = 0.0
+        score_matrix[node1, node1] = 0.0
+        representatives.append({node1})
+
+        for node2 in range(node1 + 1, length):
+            old_input[node1, node2] = node1
+            old_output[node1, node2] = node2
+
+            old_input[node2, node1] = node2
+            old_output[node2, node1] = node1
+
+    final_edges: Dict[int, int] = {}
+
+    # The main algorithm operates inplace.
+    chu_liu_edmonds(
+        length, score_matrix, current_nodes, final_edges, old_input, old_output, representatives
+    )
+
+    heads = numpy.zeros([max_length], numpy.int32)
+    if has_labels:
+        head_type = numpy.ones([max_length], numpy.int32)
+    else:
+        head_type = None
+
+    for child, parent in final_edges.items():
+        heads[child] = parent
+        if has_labels:
+            head_type[child] = label_id_matrix[parent, child]
+
+    return heads, head_type
+
+
+def chu_liu_edmonds(
+    length: int,
+    score_matrix: numpy.ndarray,
+    current_nodes: List[bool],
+    final_edges: Dict[int, int],
+    old_input: numpy.ndarray,
+    old_output: numpy.ndarray,
+    representatives: List[Set[int]],
+):
+    """
+    Applies the chu-liu-edmonds algorithm recursively
+    to a graph with edge weights defined by score_matrix.
+
+    Note that this function operates in place, so variables
+    will be modified.
+
+    # Parameters
+
+    length : `int`, required.
+        The number of nodes.
+    score_matrix : `numpy.ndarray`, required.
+        The score matrix representing the scores for pairs
+        of nodes.
+    current_nodes : `List[bool]`, required.
+        The nodes which are representatives in the graph.
+        A representative at it's most basic represents a node,
+        but as the algorithm progresses, individual nodes will
+        represent collapsed cycles in the graph.
+    final_edges : `Dict[int, int]`, required.
+        An empty dictionary which will be populated with the
+        nodes which are connected in the maximum spanning tree.
+    old_input : `numpy.ndarray`, required.
+    old_output : `numpy.ndarray`, required.
+    representatives : `List[Set[int]]`, required.
+        A list containing the nodes that a particular node
+        is representing at this iteration in the graph.
+
+    # Returns
+
+    Nothing - all variables are modified in place.
+
+    """
+    # Set the initial graph to be the greedy best one.
+    parents = [-1]
+    for node1 in range(1, length):
+        parents.append(0)
+        if current_nodes[node1]:
+            max_score = score_matrix[0, node1]
+            for node2 in range(1, length):
+                if node2 == node1 or not current_nodes[node2]:
+                    continue
+
+                new_score = score_matrix[node2, node1]
+                if new_score > max_score:
+                    max_score = new_score
+                    parents[node1] = node2
+
+    # Check if this solution has a cycle.
+    has_cycle, cycle = _find_cycle(parents, length, current_nodes)
+    # If there are no cycles, find all edges and return.
+    if not has_cycle:
+        final_edges[0] = -1
+        for node in range(1, length):
+            if not current_nodes[node]:
+                continue
+
+            parent = old_input[parents[node], node]
+            child = old_output[parents[node], node]
+            final_edges[child] = parent
+        return
+
+    # Otherwise, we have a cycle so we need to remove an edge.
+    # From here until the recursive call is the contraction stage of the algorithm.
+    cycle_weight = 0.0
+    # Find the weight of the cycle.
+    index = 0
+    for node in cycle:
+        index += 1
+        cycle_weight += score_matrix[parents[node], node]
+
+    # For each node in the graph, find the maximum weight incoming
+    # and outgoing edge into the cycle.
+    cycle_representative = cycle[0]
+    for node in range(length):
+        if not current_nodes[node] or node in cycle:
+            continue
+
+        in_edge_weight = float("-inf")
+        in_edge = -1
+        out_edge_weight = float("-inf")
+        out_edge = -1
+
+        for node_in_cycle in cycle:
+            if score_matrix[node_in_cycle, node] > in_edge_weight:
+                in_edge_weight = score_matrix[node_in_cycle, node]
+                in_edge = node_in_cycle
+
+            # Add the new edge score to the cycle weight
+            # and subtract the edge we're considering removing.
+            score = (
+                cycle_weight
+                + score_matrix[node, node_in_cycle]
+                - score_matrix[parents[node_in_cycle], node_in_cycle]
+            )
+
+            if score > out_edge_weight:
+                out_edge_weight = score
+                out_edge = node_in_cycle
+
+        score_matrix[cycle_representative, node] = in_edge_weight
+        old_input[cycle_representative, node] = old_input[in_edge, node]
+        old_output[cycle_representative, node] = old_output[in_edge, node]
+
+        score_matrix[node, cycle_representative] = out_edge_weight
+        old_output[node, cycle_representative] = old_output[node, out_edge]
+        old_input[node, cycle_representative] = old_input[node, out_edge]
+
+    # For the next recursive iteration, we want to consider the cycle as a
+    # single node. Here we collapse the cycle into the first node in the
+    # cycle (first node is arbitrary), set all the other nodes not be
+    # considered in the next iteration. We also keep track of which
+    # representatives we are considering this iteration because we need
+    # them below to check if we're done.
+    considered_representatives: List[Set[int]] = []
+    for i, node_in_cycle in enumerate(cycle):
+        considered_representatives.append(set())
+        if i > 0:
+            # We need to consider at least one
+            # node in the cycle, arbitrarily choose
+            # the first.
+            current_nodes[node_in_cycle] = False
+
+        for node in representatives[node_in_cycle]:
+            considered_representatives[i].add(node)
+            if i > 0:
+                representatives[cycle_representative].add(node)
+
+    chu_liu_edmonds(
+        length, score_matrix, current_nodes, final_edges, old_input, old_output, representatives
+    )
+
+    # Expansion stage.
+    # check each node in cycle, if one of its representatives
+    # is a key in the final_edges, it is the one we need.
+    found = False
+    key_node = -1
+    for i, node in enumerate(cycle):
+        for cycle_rep in considered_representatives[i]:
+            if cycle_rep in final_edges:
+                key_node = node
+                found = True
+                break
+        if found:
+            break
+
+    previous = parents[key_node]
+    while previous != key_node:
+        child = old_output[parents[previous], previous]
+        parent = old_input[parents[previous], previous]
+        final_edges[child] = parent
+        previous = parents[previous]
+
+
+def _find_cycle(
+    parents: List[int], length: int, current_nodes: List[bool]
+) -> Tuple[bool, List[int]]:
+
+    added = [False for _ in range(length)]
+    added[0] = True
+    cycle = set()
+    has_cycle = False
+    for i in range(1, length):
+        if has_cycle:
+            break
+        # don't redo nodes we've already
+        # visited or aren't considering.
+        if added[i] or not current_nodes[i]:
+            continue
+        # Initialize a new possible cycle.
+        this_cycle = set()
+        this_cycle.add(i)
+        added[i] = True
+        has_cycle = True
+        next_node = i
+        while parents[next_node] not in this_cycle:
+            next_node = parents[next_node]
+            # If we see a node we've already processed,
+            # we can stop, because the node we are
+            # processing would have been in that cycle.
+            if added[next_node]:
+                has_cycle = False
+                break
+            added[next_node] = True
+            this_cycle.add(next_node)
+
+        if has_cycle:
+            original = next_node
+            cycle.add(original)
+            next_node = parents[original]
+            while next_node != original:
+                cycle.add(next_node)
+                next_node = parents[next_node]
+            break
+
+    return has_cycle, list(cycle)
-- 
GitLab