"""
Adapted from AllenNLP
https://github.com/allenai/allennlp/blob/main/allennlp/nn/chu_liu_edmonds.py
"""
from typing import Dict, List, Optional, Set, Tuple

import numpy

try:
    from combo.utils import ConfigurationError
except ImportError:  # The decoder itself only needs numpy; keep it usable standalone.

    class ConfigurationError(Exception):
        """Fallback used only when `combo.utils` cannot be imported."""


def decode_mst(
    energy: numpy.ndarray, length: int, has_labels: bool = True
) -> Tuple[numpy.ndarray, Optional[numpy.ndarray]]:
    """
    Note: Counter to typical intuition, this function decodes the _maximum_
    spanning tree.

    Decode the optimal MST tree with the Chu-Liu-Edmonds algorithm for
    maximum spanning arborescences on graphs. Node 0 is treated as the root.

    The `energy` array passed in is never modified.

    # Parameters

    energy : `numpy.ndarray`, required.
        A tensor with shape (num_labels, timesteps, timesteps)
        containing the energy of each edge. If has_labels is `False`,
        the tensor should have shape (timesteps, timesteps) instead.
        Entry `[..., parent, child]` scores the edge parent -> child.
    length : `int`, required.
        The length of this sequence, as the energy may have come
        from a padded batch.
    has_labels : `bool`, optional, (default = `True`)
        Whether the graph has labels or not.

    # Returns

    heads : `numpy.ndarray`
        An int32 array of size `timesteps`; `heads[child]` is the chosen
        parent of `child`. The root (index 0) gets `-1` and padded
        positions stay `0`.
    head_type : `Optional[numpy.ndarray]`
        `None` when `has_labels` is `False`. Otherwise an int32 array of
        size `timesteps` holding the best label id of each selected edge;
        positions without a selected edge stay `1`.

    # Raises

    ConfigurationError
        If the number of dimensions of `energy` does not match `has_labels`.
    """
    if has_labels and energy.ndim != 3:
        raise ConfigurationError("The dimension of the energy array is not equal to 3.")
    elif not has_labels and energy.ndim != 2:
        raise ConfigurationError("The dimension of the energy array is not equal to 2.")
    input_shape = energy.shape
    max_length = input_shape[-1]

    # Our energy matrix might have been batched -
    # here we clip it to contain only non padded tokens.
    if has_labels:
        energy = energy[:, :length, :length]
        # get best label for each edge.
        label_id_matrix = energy.argmax(axis=0)
        energy = energy.max(axis=0)
    else:
        energy = energy[:length, :length]
        label_id_matrix = None

    # Work on a private copy. Without labels `energy` is still a view of the
    # caller's array, so the diagonal must only be cleared on the copy.
    score_matrix = numpy.array(energy, copy=True)
    # Self loops are never part of a tree.
    numpy.fill_diagonal(score_matrix, 0.0)

    # old_input[i, j] / old_output[i, j] remember which original edge
    # (source, target) is currently represented by the edge i -> j. Initially
    # every off-diagonal edge represents itself.
    old_input = numpy.zeros([length, length], dtype=numpy.int32)
    old_output = numpy.zeros([length, length], dtype=numpy.int32)
    for node1 in range(length):
        for node2 in range(node1 + 1, length):
            old_input[node1, node2] = node1
            old_output[node1, node2] = node2

            old_input[node2, node1] = node2
            old_output[node2, node1] = node1

    current_nodes = [True] * length
    representatives: List[Set[int]] = [{node} for node in range(length)]
    final_edges: Dict[int, int] = {}

    # The main algorithm operates inplace.
    chu_liu_edmonds(
        length, score_matrix, current_nodes, final_edges, old_input, old_output, representatives
    )

    heads = numpy.zeros([max_length], numpy.int32)
    head_type = numpy.ones([max_length], numpy.int32) if has_labels else None

    for child, parent in final_edges.items():
        heads[child] = parent
        if has_labels:
            head_type[child] = label_id_matrix[parent, child]

    return heads, head_type


def chu_liu_edmonds(
    length: int,
    score_matrix: numpy.ndarray,
    current_nodes: List[bool],
    final_edges: Dict[int, int],
    old_input: numpy.ndarray,
    old_output: numpy.ndarray,
    representatives: List[Set[int]],
):
    """
    Applies the chu-liu-edmonds algorithm recursively
    to a graph with edge weights defined by score_matrix.

    Note that this function operates in place, so variables
    will be modified.

    # Parameters

    length : `int`, required.
        The number of nodes.
    score_matrix : `numpy.ndarray`, required.
        The score matrix representing the scores for pairs
        of nodes; `score_matrix[parent, child]` scores parent -> child.
    current_nodes : `List[bool]`, required.
        The nodes which are representatives in the graph.
        A representative at its most basic represents a node,
        but as the algorithm progresses, individual nodes will
        represent collapsed cycles in the graph.
    final_edges : `Dict[int, int]`, required.
        An empty dictionary which will be populated with the
        nodes which are connected in the maximum spanning tree,
        as a child -> parent mapping (the root maps to -1).
    old_input : `numpy.ndarray`, required.
        `old_input[i, j]` is the original source node of the edge that
        the (possibly contracted) edge i -> j currently stands for.
    old_output : `numpy.ndarray`, required.
        `old_output[i, j]` is the original target node of the edge that
        the (possibly contracted) edge i -> j currently stands for.
    representatives : `List[Set[int]]`, required.
        A list containing the nodes that a particular node
        is representing at this iteration in the graph.

    # Returns

    Nothing - all variables are modified in place.

    """
    # Set the initial graph to be the greedy best one: every active node
    # picks its highest scoring incoming edge (the root, node 0, has none).
    parents = [-1]
    for node1 in range(1, length):
        parents.append(0)
        if current_nodes[node1]:
            max_score = score_matrix[0, node1]
            for node2 in range(1, length):
                if node2 == node1 or not current_nodes[node2]:
                    continue

                new_score = score_matrix[node2, node1]
                if new_score > max_score:
                    max_score = new_score
                    parents[node1] = node2

    # Check if this solution has a cycle.
    has_cycle, cycle = _find_cycle(parents, length, current_nodes)
    # If there are no cycles, find all edges and return.
    if not has_cycle:
        final_edges[0] = -1
        for node in range(1, length):
            if not current_nodes[node]:
                continue

            # Translate the (possibly contracted) edge back to the
            # original edge it stands for.
            parent = old_input[parents[node], node]
            child = old_output[parents[node], node]
            final_edges[child] = parent
        return

    # Otherwise, we have a cycle so we need to remove an edge.
    # From here until the recursive call is the contraction stage of the algorithm.
    # Find the weight of the cycle.
    cycle_weight = 0.0
    for node in cycle:
        cycle_weight += score_matrix[parents[node], node]

    # For each node in the graph, find the maximum weight incoming
    # and outgoing edge into the cycle.
    cycle_representative = cycle[0]
    cycle_members = set(cycle)
    for node in range(length):
        if not current_nodes[node] or node in cycle_members:
            continue

        in_edge_weight = float("-inf")
        in_edge = -1
        out_edge_weight = float("-inf")
        out_edge = -1

        for node_in_cycle in cycle:
            if score_matrix[node_in_cycle, node] > in_edge_weight:
                in_edge_weight = score_matrix[node_in_cycle, node]
                in_edge = node_in_cycle

            # Add the new edge score to the cycle weight
            # and subtract the edge we're considering removing.
            score = (
                cycle_weight
                + score_matrix[node, node_in_cycle]
                - score_matrix[parents[node_in_cycle], node_in_cycle]
            )

            if score > out_edge_weight:
                out_edge_weight = score
                out_edge = node_in_cycle

        # Route the best edges through the cycle representative while
        # remembering which original edges they correspond to.
        score_matrix[cycle_representative, node] = in_edge_weight
        old_input[cycle_representative, node] = old_input[in_edge, node]
        old_output[cycle_representative, node] = old_output[in_edge, node]

        score_matrix[node, cycle_representative] = out_edge_weight
        old_output[node, cycle_representative] = old_output[node, out_edge]
        old_input[node, cycle_representative] = old_input[node, out_edge]

    # For the next recursive iteration, we want to consider the cycle as a
    # single node. Here we collapse the cycle into the first node in the
    # cycle (first node is arbitrary), set all the other nodes not be
    # considered in the next iteration. We also keep track of which
    # representatives we are considering this iteration because we need
    # them below to check if we're done.
    considered_representatives: List[Set[int]] = []
    for i, node_in_cycle in enumerate(cycle):
        considered_representatives.append(set())
        if i > 0:
            # We need to consider at least one
            # node in the cycle, arbitrarily choose
            # the first.
            current_nodes[node_in_cycle] = False

        for node in representatives[node_in_cycle]:
            considered_representatives[i].add(node)
            if i > 0:
                representatives[cycle_representative].add(node)

    chu_liu_edmonds(
        length, score_matrix, current_nodes, final_edges, old_input, old_output, representatives
    )

    # Expansion stage.
    # check each node in cycle, if one of its representatives
    # is a key in the final_edges, it is the one we need.
    key_node = -1
    for i, node in enumerate(cycle):
        if any(cycle_rep in final_edges for cycle_rep in considered_representatives[i]):
            key_node = node
            break

    # Walk the cycle backwards from the node that received the external
    # edge, keeping every cycle edge except the one entering that node.
    previous = parents[key_node]
    while previous != key_node:
        child = old_output[parents[previous], previous]
        parent = old_input[parents[previous], previous]
        final_edges[child] = parent
        previous = parents[previous]


def _find_cycle(
    parents: List[int], length: int, current_nodes: List[bool]
) -> Tuple[bool, List[int]]:
    """
    Look for a cycle in the graph described by `parents`.

    # Parameters

    parents : `List[int]`, required.
        `parents[node]` is the currently selected parent of `node`.
    length : `int`, required.
        The number of nodes.
    current_nodes : `List[bool]`, required.
        Flags marking the nodes that are still considered.

    # Returns

    A tuple `(has_cycle, cycle)`. `cycle` lists the nodes of the first
    cycle found and is empty when there is none.
    """
    added = [False for _ in range(length)]
    # The root can never be part of a cycle.
    added[0] = True
    cycle = set()
    has_cycle = False
    for i in range(1, length):
        if has_cycle:
            break
        # don't redo nodes we've already
        # visited or aren't considering.
        if added[i] or not current_nodes[i]:
            continue
        # Initialize a new possible cycle.
        this_cycle = set()
        this_cycle.add(i)
        added[i] = True
        has_cycle = True
        next_node = i
        while parents[next_node] not in this_cycle:
            next_node = parents[next_node]
            # If we see a node we've already processed,
            # we can stop, because the node we are
            # processing would have been in that cycle.
            if added[next_node]:
                has_cycle = False
                break
            added[next_node] = True
            this_cycle.add(next_node)

        if has_cycle:
            # `this_cycle` may contain a tail leading into the cycle, so
            # collect only the nodes that are actually on the loop.
            original = next_node
            cycle.add(original)
            next_node = parents[original]
            while next_node != original:
                cycle.add(next_node)
                next_node = parents[next_node]
            break

    return has_cycle, list(cycle)