import unittest
import combo.utils.graph as graph

import conllu
import numpy as np


class GraphTest(unittest.TestCase):
    """Checks the ``deps`` values that ``graph.sdp_to_dag_deps`` writes onto tokens.

    Every test builds a three-token tree, passes a 4x4 arc-score matrix and a
    4x4xN label-score tensor, and compares the per-token ``deps`` strings
    (``"head:label"`` entries joined with ``|``) against an expected list.

    NOTE(review): the expectations below imply that scores are indexed as
    ``[dependent][head]`` (plus a trailing label axis) with index 0 standing
    for the root -- confirm against the implementation.
    """

    @staticmethod
    def _deps_after_merge(arc_scores, label_scores, tree, vocab_index):
        """Run ``sdp_to_dag_deps`` on ``tree`` and return each token's ``deps`` value.

        The tokens of ``tree`` are handed to the function under test directly,
        and ``deps`` is read back from those same tokens afterwards.
        """
        graph.sdp_to_dag_deps(
            arc_scores, label_scores, tree.tokens, root_idx=0, vocab_index=vocab_index
        )
        return [token["deps"] for token in tree.tokens]

    def test_adding_empty_graph_with_the_same_labels(self):
        """All-zero scores: ``deps`` should equal the tree's own head/label pairs."""
        # given
        tree = conllu.TokenList(
            tokens=[
                {"head": 0, "deprel": "root", "form": "word1"},
                {"head": 3, "deprel": "yes", "form": "word2"},
                {"head": 1, "deprel": "yes", "form": "word3"},
            ]
        )
        # Every non-root index maps to the same label the tree already uses.
        vocab_index = {0: "root", 1: "yes", 2: "yes", 3: "yes"}
        empty_graph = np.zeros((4, 4))
        graph_labels = np.zeros((4, 4, 4))
        expected_deps = ["0:root", "3:yes", "1:yes"]

        # when
        actual_deps = self._deps_after_merge(empty_graph, graph_labels, tree, vocab_index)

        # then
        self.assertEqual(expected_deps, actual_deps)

    def test_adding_empty_graph_with_different_labels(self):
        """All-zero arc scores: heads stay as in the tree, labels follow the label scores."""
        # given
        tree = conllu.TokenList(
            tokens=[
                {"head": 0, "deprel": "root", "form": "word1"},
                {"head": 3, "deprel": "tree_label", "form": "word2"},
                {"head": 1, "deprel": "tree_label", "form": "word3"},
            ]
        )
        vocab_index = {0: "root", 1: "tree_label", 2: "graph_label"}
        empty_graph = np.zeros((4, 4))
        graph_labels = np.zeros((4, 4, 3))
        # Make label index 2 ("graph_label") dominate for both tree arcs.
        graph_labels[2][3][2] = 10e10
        graph_labels[3][1][2] = 10e10
        expected_deps = ["0:root", "3:graph_label", "1:graph_label"]

        # when
        actual_deps = self._deps_after_merge(empty_graph, graph_labels, tree, vocab_index)

        # then
        self.assertEqual(expected_deps, actual_deps)

    def test_extending_tree_with_graph(self):
        """A nonzero ``arc_scores[3][1]`` should add head 1 to the third token's ``deps``."""
        # given
        tree = conllu.TokenList(
            tokens=[
                {"head": 0, "deprel": "root", "form": "word1"},
                {"head": 1, "deprel": "tree_label", "form": "word2"},
                {"head": 2, "deprel": "tree_label", "form": "word3"},
            ]
        )
        vocab_index = {0: "root", 1: "tree_label", 2: "graph_label"}
        # Rows 1-3 repeat the tree's arcs; the extra 1 at [3][1] is the graph-only arc.
        arc_scores = np.array([
            [0, 0, 0, 0],
            [1, 0, 0, 0],
            [0, 1, 0, 0],
            [0, 1, 1, 0],
        ])
        graph_labels = np.zeros((4, 4, 3))
        graph_labels[3][1][2] = 10e10
        expected_deps = ["0:root", "1:tree_label", "1:graph_label|2:tree_label"]

        # when
        actual_deps = self._deps_after_merge(arc_scores, graph_labels, tree, vocab_index)

        # then
        self.assertEqual(expected_deps, actual_deps)

    def test_extending_tree_with_self_loop_edge_shouldnt_add_edge(self):
        """A nonzero ``arc_scores[3][3]`` (token pointing at itself) must not show up in ``deps``."""
        # given
        tree = conllu.TokenList(
            tokens=[
                {"head": 0, "deprel": "root", "form": "word1"},
                {"head": 1, "deprel": "tree_label", "form": "word2"},
                {"head": 2, "deprel": "tree_label", "form": "word3"},
            ]
        )
        vocab_index = {0: "root", 1: "tree_label", 2: "graph_label"}
        # Rows 1-3 repeat the tree's arcs; the extra 1 at [3][3] is the self-loop.
        arc_scores = np.array([
            [0, 0, 0, 0],
            [1, 0, 0, 0],
            [0, 1, 0, 0],
            [0, 0, 1, 1],
        ])
        graph_labels = np.zeros((4, 4, 3))
        graph_labels[3][3][2] = 10e10
        expected_deps = ["0:root", "1:tree_label", "2:tree_label"]
        # TODO: the implementation was observed to add the self-loop, yielding
        # ["0:root", "1:tree_label", "2:tree_label|3:graph_label"] instead.

        # when
        actual_deps = self._deps_after_merge(arc_scores, graph_labels, tree, vocab_index)

        # then
        self.assertEqual(expected_deps, actual_deps)
