diff --git a/combo/utils/graph.py b/combo/utils/graph.py index 3995f4f36cc805167a22fd3861053a520fb53a40..ed2c1ff1c81cac086a59942168d153d1a23d24a7 100644 --- a/combo/utils/graph.py +++ b/combo/utils/graph.py @@ -74,8 +74,11 @@ def graph_and_tree_merge(tree_arc_scores, parse_graph[d] = sorted(parse_graph[d]) for i, g in enumerate(parse_graph): - heads = [x[0] for x in g] - rels = [x[1] for x in g] + heads = np.array([x[0] for x in g]) + rels = np.array([x[1] for x in g]) + indices = rels.argsort() + heads = heads[indices].tolist() + rels = rels[indices].tolist() deps = '|'.join(f'{h}:{r}' for h, r in zip(heads, rels)) tokens[i - 1]["deps"] = deps return