def restore_collapse_edges(tree_tokens):
    """Expand collapsed enhanced-dependency edges into explicit empty nodes.

    Each token's ``"deps"`` value is a ``|``-separated list of
    ``{head}:{relation}`` entries.  An entry whose relation contains ``>``
    encodes a path that runs through one or more empty nodes::

        {head}:{empty_node_relation}>{current_node_relation}

    Such an entry is rewritten so that

    * a new empty node is created with deps ``{head}:{empty_node_relation}``
    * the current token's entry becomes
      ``{new_empty_node_id}:{current_node_relation}``

    A path with several ``>`` separators (``head:r1>r2>r3``) yields a chain
    of empty nodes, each one attached to the previously created one.

    Args:
        tree_tokens: sequence of dict-like tokens.  Tokens are modified in
            place: their ``"deps"`` string is rewritten whenever it holds a
            collapsed edge.  Tokens whose ``"deps"`` is missing, ``None`` or
            empty are left untouched.

    Returns:
        A list of ``{"id": ..., "deps": ...}`` dicts, one per created empty
        node, in creation order.  Ids have the form
        ``"{len(tree_tokens)}.{k}"`` with ``k`` counting from 1.
    """
    empty_tokens = []
    # The integer part of every empty-node id is the token count, i.e. the
    # new nodes are numbered as if they followed the last token.
    # TODO consider where to put new_empty_node_id (currently at the end)
    sentence_length = len(tree_tokens)

    for token in tree_tokens:
        raw_deps = token.get("deps")
        if not raw_deps:
            # Nothing to restore (None / "" / key absent).
            continue

        deps = raw_deps.split("|")
        for i, dep in enumerate(deps):
            # Split on the first ':' only: relations may carry subtypes
            # such as "conj:and".
            head, separator, relation = dep.partition(":")
            if not separator or ">" not in relation:
                # Regular edge, placeholder ("_") or malformed entry:
                # keep it exactly as it is.
                continue

            # Everything before the last '>' belongs to empty nodes; the
            # last segment is the relation of the current token.
            *empty_node_relations, current_node_relation = relation.split(">")
            for empty_node_relation in empty_node_relations:
                ehead = f"{sentence_length}.{len(empty_tokens) + 1}"
                empty_tokens.append(
                    {
                        "id": ehead,
                        "deps": f"{head}:{empty_node_relation}",
                    }
                )
                # The next node in the chain hangs off the one just made.
                head = ehead
            deps[i] = f"{head}:{current_node_relation}"

        token["deps"] = "|".join(deps)
    return empty_tokens