# coding: utf8

# Copyright (C) 2017 Michał Kaliński
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

from __future__ import absolute_import, division


import collections as coll
import functools as funct
import json
import xml.etree.cElementTree as et

import six

from .. import enums as en


# Public names of this module (what ``from <module> import *`` exposes).
__all__ = (
    'GraphMLWordNet',
    'GraphMLBuilder',
    'GRAPH_TYPE_SYNSET',
    'GRAPH_TYPE_UNIT',
    'GRAPH_TYPE_MIXED',
    'UNS_HAS_LU',
    'UNS_IN_SYN',
)


# Constants for GraphML exporting (a library user should just use the string
# values).
# ``GRAPH_TYPE_SYNSET`` and ``GRAPH_TYPE_UNIT`` also double as prefixes for
# node IDs (see ``GraphMLBuilder._prefix_synset_id`` and
# ``GraphMLBuilder._prefix_lexunit_id``).
GRAPH_TYPE_SYNSET = 'synset'
GRAPH_TYPE_UNIT = 'lexical_unit'
GRAPH_TYPE_MIXED = 'mixed'
# Names of the two edges that tie a synset node to each of its lexical unit
# nodes in a mixed graph: synset -> unit and unit -> synset respectively.
UNS_HAS_LU = u'has_unit'
UNS_IN_SYN = u'in_synset'


class GraphMLWordNet(object):
    """Stores plWordNet data as a GraphML tree.

    The tree consists of a ``graphml`` root holding ``key`` elements
    (attribute declarations) and a single directed ``graph`` element holding
    ``node`` and ``edge`` elements.

    This is an auxiliary class which usually shouldn't be constructed directly.
    Use an appropriate method from :class:`plwn.bases.PLWordNet`.
    """

    #: Defines a possible type of a GraphML graph attribute. ``typename`` is
    #: the name of the type (value of ``attr.type`` attribute in GraphML), and
    #: ``convert`` is a function that takes a single argument and converts it
    #: to string which will be the content of a ``data`` tag.
    _DataType = coll.namedtuple('_DataType', ('typename', 'convert'))

    DATA_TYPE_INT = _DataType(u'long', lambda val: six.text_type(int(val)))
    DATA_TYPE_STR = _DataType(u'string', six.text_type)
    # String that can also be None (stored as an empty string).
    DATA_TYPE_OPTSTR = _DataType(
        u'string',
        lambda val: u'' if val is None else six.text_type(val),
    )
    DATA_TYPE_BOOL = _DataType(
        u'boolean',
        lambda val: u'true' if val else u'false',
    )
    # Arbitrary JSON-serializable value, stored as its JSON text.
    DATA_TYPE_JSON = _DataType(u'string', json.dumps)
    DATA_TYPE_ENUMVAL = _DataType(
        u'string',
        lambda val: six.text_type(val.value),
    )
    # Data type for enum that can also be None.
    DATA_TYPE_OPTENUMVAL = _DataType(
        u'string',
        lambda val: u'' if val is None else six.text_type(val.value),
    )
    # Sequence of enums, stored as JSON text of their values.
    DATA_TYPE_ENUMSEQ = _DataType(
        u'string',
        lambda val: json.dumps(en.make_values_tuple(val)),
    )

    def __init__(self):
        """Create an empty GraphML tree with a single directed graph."""
        self._root = et.Element(
            u'graphml',
            # The commented out xmlns declaration is correct, but inserting
            # it causes the namespace machinery of ElementTree to complicate
            # everything. Only uncomment if more namespaces need to be
            # embedded in the output.
            # {'xmlns': 'http://graphml.graphdrawing.org/xmlns'},
        )
        self._graph = et.SubElement(
            self._root,
            u'graph',
            {u'id': u'plWordNet', u'edgedefault': u'directed'},
        )
        self._tree = et.ElementTree(self._root)
        # Maps attribute ID -> _DataType, filled by add_attribute_type().
        self._attr_types = {}

    def add_attribute_type(self, id_, name, type_, for_=u'node'):
        """Adds attribute which can be then assigned to node or edge instances.

        :param str id_: Unique (in the whole XML) identifier of the attribute
            type.
        :param str name: Name of the attribute.
        :param _DataType type_: One of the ``DATA_TYPE_*`` constants, defining
            the type of the attribute.
        :param str for_: Should be either "node" or "edge", depending if it's a
            node attribute or an edge attribute.

        :raises ValueError: If ``type_`` or ``for_`` were passed an illegal
            value.
        """
        if not isinstance(type_, self._DataType):
            raise ValueError('type_={!r}'.format(type_))
        if for_ not in (u'node', u'edge'):
            raise ValueError('for={!r}'.format(for_))

        self._attr_types[id_] = type_

        et.SubElement(
            self._root,
            u'key',
            {
                u'id': id_,
                u'for': for_,
                u'attr.name': name,
                u'attr.type': type_.typename,
            },
        )

    def add_node(self, id_, attributes=None):
        """Add a node to the GraphML graph.

        This would be either a synset node or a lexical unit node, this method
        doesn't distinguish between them. The caller should include some way to
        tell them apart.

        :param str id_: Unique (in the whole XML) identifier of the node.
        :param Mapping[str,Union[int,bool,float,str]] attributes: Maps
            attribute IDs to their values. The IDs should have been previously
            defined by :meth:`.add_attribute_type`. ``None`` (the default)
            means no attributes.

        :raises KeyError: If any of the names in ``attributes`` was not
            previously defined. In that case nothing is added to the graph.
        """
        # Convert first, so an unknown attribute ID cannot leave a half-built
        # node in the tree.
        data = self._convert_attributes(attributes)
        node = et.SubElement(
            self._graph,
            u'node',
            {u'id': id_},
        )
        self._add_data_to(node, data)

    def add_edge(self, id_, source, target, attributes=None):
        """Add an edge to the GraphML graph.

        An edge would normally designate a relation, but this method doesn't
        assume that. The caller should set an appropriate attribute for that.

        Source and target nodes do not need to have been previously defined,
        but should be eventually or the graph will be invalid.

        :param str id_: Unique (in the whole XML) identifier of the edge.
        :param str source: Identifier of the source node.
        :param str target: Identifier of the target node.
        :param Mapping[str,Union[int,bool,float,str]] attributes: Maps
            attribute IDs to their values. The IDs should have been previously
            defined by :meth:`.add_attribute_type`. ``None`` (the default)
            means no attributes.

        :raises KeyError: If any of the names in ``attributes`` was not
            previously defined. In that case nothing is added to the graph.
        """
        # Same ordering as in add_node(): validate and convert, then attach.
        data = self._convert_attributes(attributes)
        edge = et.SubElement(
            self._graph,
            u'edge',
            {u'id': id_, u'source': source, u'target': target},
        )
        self._add_data_to(edge, data)

    def write(self, file_):
        """Saves the GraphML representation to a file.

        :param Union[str,IO] file_: Name of the file, or a writable stream, to
            which the graph should be written. NOTE(review): ``utf-8`` is
            requested explicitly from ElementTree, so on Python 3 a stream
            presumably has to accept bytes (be opened in binary mode) --
            confirm against the ElementTree documentation.
        """
        self._tree.write(file_, encoding='utf-8')

    def _convert_attributes(self, attributes):
        """Turn an attribute mapping into ``(attribute ID, text)`` pairs.

        :raises KeyError: If an attribute ID was not declared with
            :meth:`.add_attribute_type`.
        """
        if attributes is None:
            return []
        return [
            (attr_id, self._attr_types[attr_id].convert(attr_val))
            for attr_id, attr_val in six.iteritems(attributes)
        ]

    @staticmethod
    def _add_data_to(element, converted):
        """Append one ``data`` child to ``element`` per converted pair."""
        for attr_id, text in converted:
            data = et.SubElement(
                element,
                u'data',
                {u'key': attr_id},
            )
            data.text = text

    def _add_attributes_to(self, element, attributes):
        """Convert ``attributes`` and attach them to ``element``.

        Kept for callers of the original private helper.
        """
        self._add_data_to(element, self._convert_attributes(attributes))


class GraphMLBuilder(object):
    """Bridge between a plWordNet instance and a :class:`GraphMLWordNet`.

    Extracts data from a :class:`plwn.bases.PLWordNetBase` and puts it into
    a :class:`GraphMLWordNet` in the appropriate format.

    This is an auxiliary class which usually shouldn't be constructed directly.
    Use an appropriate method from :class:`plwn.bases.PLWordNet`.
    """

    # Templates of edge IDs; filled with source ID, target ID and edge name.
    _EDGE_LEX_TEMPLATE = u'lu--{}--{}--{}'
    _EDGE_SYN_TEMPLATE = u'syn--{}--{}--{}'
    _EDGE_UNS_TEMPLATE = u'uns--{}--{}--{}'

    def __init__(self, plwn, gmlwn):
        """Store both ends of the bridge and declare the edge attributes.

        :param plwn: The plWordNet instance from which the data will be
            extracted.
        :type plwn: plwn.bases.PLWordNetBase
        :param gmlwn: The GraphML storage which will receive data from
            ``plwn``.
        :type gmlwn: GraphMLWordNet
        """
        self._plwn = plwn
        self._graphout = gmlwn

        # Add attributes for relation edges. Edges are present for all graphs,
        # so they will be needed anyway.
        self._graphout.add_attribute_type(
            'edge-type',
            'type',
            GraphMLWordNet.DATA_TYPE_STR,
            'edge',
        )
        self._graphout.add_attribute_type(
            'edge-name',
            'name',
            GraphMLWordNet.DATA_TYPE_STR,
            'edge',
        )

    def synset_graph(self,
                     prefix_ids,
                     include_attributes,
                     included_attributes,
                     excluded_attributes,
                     included_nodes,
                     excluded_nodes,
                     included_relations,
                     excluded_relations,
                     skip_artificial_synsets=True):
        """See :meth:`plwn.bases.PLWordNetBase.to_graphml` for description."""
        added_attributes = (
            self._add_synset_attrs(included_attributes, excluded_attributes)
            if self._wants_attributes(include_attributes,
                                      included_attributes,
                                      excluded_attributes)
            else ()
        )
        self._add_relation_graph(
            self._plwn.synset_relation_edges(included_relations,
                                             excluded_relations,
                                             skip_artificial_synsets),
            self._prefix_synset_id,
            self._EDGE_SYN_TEMPLATE,
            prefix_ids,
            added_attributes,
            included_nodes,
            excluded_nodes,
        )

    def lexical_unit_graph(self,
                           prefix_ids,
                           include_attributes,
                           included_attributes,
                           excluded_attributes,
                           included_nodes,
                           excluded_nodes,
                           included_relations,
                           excluded_relations):
        """Build a graph of lexical units connected by lexical relations.

        Works like :meth:`.synset_graph`, but takes its edges from
        ``lexical_relation_edges`` of the plWordNet instance and declares
        lexical unit attributes.
        """
        added_attributes = (
            self._add_lexunit_attrs(included_attributes, excluded_attributes)
            if self._wants_attributes(include_attributes,
                                      included_attributes,
                                      excluded_attributes)
            else ()
        )
        self._add_relation_graph(
            self._plwn.lexical_relation_edges(included_relations,
                                              excluded_relations),
            self._prefix_lexunit_id,
            self._EDGE_LEX_TEMPLATE,
            prefix_ids,
            added_attributes,
            included_nodes,
            excluded_nodes,
        )

    def mixed_graph(self,
                    include_attributes,
                    included_synset_attributes,
                    excluded_synset_attributes,
                    included_lexical_unit_attributes,
                    excluded_lexical_unit_attributes,
                    included_synset_relations,
                    excluded_synset_relations,
                    included_lexical_unit_relations,
                    excluded_lexical_unit_relations,
                    included_synset_nodes,
                    excluded_synset_nodes,
                    included_lexical_unit_nodes,
                    excluded_lexical_unit_nodes,
                    skip_artificial_synsets=True):
        """Build a graph containing both synsets and lexical units.

        Synset relation edges are processed first, then lexical relation
        edges. Whenever a synset node is added, the nodes of its lexical
        units that pass the lexical unit node filters are added too, and
        tied to the synset by a pair of ``has_unit`` / ``in_synset`` edges.
        Node IDs are always prefixed in this graph type.

        Each ``included_*`` / ``excluded_*`` pair is a filter where ``None``
        means "no restriction".
        """
        synset_attributes = (
            self._add_synset_attrs(
                included_synset_attributes,
                excluded_synset_attributes,
            )
            if self._wants_attributes(include_attributes,
                                      included_synset_attributes,
                                      excluded_synset_attributes)
            else ()
        )

        lexunit_attributes = (
            self._add_lexunit_attrs(
                included_lexical_unit_attributes,
                excluded_lexical_unit_attributes,
            )
            if self._wants_attributes(include_attributes,
                                      included_lexical_unit_attributes,
                                      excluded_lexical_unit_attributes)
            else ()
        )

        # IDs of synsets already in the graph, and of synsets known to have
        # no lexical units left after filtering.
        added_synsets = set()
        empty_synsets = set()

        # Add synset edges, then add their lexunit nodes (which were not
        # excluded). Do not include lexical units from synsets that were
        # excluded.
        for syn_edge in self._plwn.synset_relation_edges(
                included_synset_relations,
                excluded_synset_relations,
                skip_artificial_synsets,
        ):
            if self._check_include_exclude_2(syn_edge.source.id,
                                             syn_edge.target.id,
                                             included_synset_nodes,
                                             excluded_synset_nodes):
                self._add_mixed_synset_edge(
                    syn_edge,
                    synset_attributes,
                    lexunit_attributes,
                    added_synsets,
                    empty_synsets,
                    included_lexical_unit_nodes,
                    excluded_lexical_unit_nodes,
                )

        for lex_edge in self._plwn.lexical_relation_edges(
                included_lexical_unit_relations,
                excluded_lexical_unit_relations,
        ):
            if self._check_include_exclude_2(lex_edge.source.id,
                                             lex_edge.target.id,
                                             included_lexical_unit_nodes,
                                             excluded_lexical_unit_nodes):
                self._add_mixed_lexunit_edge(
                    lex_edge,
                    synset_attributes,
                    lexunit_attributes,
                    added_synsets,
                    empty_synsets,
                    included_synset_nodes,
                    excluded_synset_nodes,
                    # Units of a newly added synset must be filtered with
                    # the lexical unit sets, not with the synset ones.
                    included_lexs=included_lexical_unit_nodes,
                    excluded_lexs=excluded_lexical_unit_nodes,
                )

    def _add_relation_graph(self,
                            edges,
                            prefixer,
                            edge_template,
                            prefix_ids,
                            added_attributes,
                            included_nodes,
                            excluded_nodes):
        """Add ``edges`` and their endpoint nodes to the output graph.

        Shared body of :meth:`.synset_graph` and :meth:`.lexical_unit_graph`.

        :param edges: Iterable of edge objects with ``source``, ``target``
            and ``relation`` attributes.
        :param prefixer: Either :meth:`._prefix_synset_id` or
            :meth:`._prefix_lexunit_id`.
        :param edge_template: Format string for the edge ID.
        """
        visited_nodes = set()

        for edge in edges:
            # Add an edge if both its endpoints are not excluded. Nodes are
            # added along edges, but it's not a problem if a valid node is not
            # included, because it will eventually be included by another edge,
            # if it's not completely secluded (and if it is, we don't want it).
            if not self._check_include_exclude_2(edge.source.id,
                                                 edge.target.id,
                                                 included_nodes,
                                                 excluded_nodes):
                continue

            prefixed_source = prefixer(edge.source.id, prefix_ids)
            prefixed_target = prefixer(edge.target.id, prefix_ids)

            for endpoint, prefixed_id in ((edge.source, prefixed_source),
                                          (edge.target, prefixed_target)):
                if endpoint.id not in visited_nodes:
                    visited_nodes.add(endpoint.id)
                    self._graphout.add_node(
                        prefixed_id,
                        self._make_attr_dict(endpoint, added_attributes),
                    )

            # Now, add the edge itself
            self._graphout.add_edge(
                edge_template.format(
                    prefixed_source,
                    prefixed_target,
                    edge.relation,
                ),
                prefixed_source,
                prefixed_target,
                {u'edge-type': u'relation', u'edge-name': edge.relation},
            )

    def _add_mixed_synset_edge(self,
                               syn_edge,
                               syn_attrs,
                               lex_attrs,
                               added_syns,
                               empty_syns,
                               included_lexs,
                               excluded_lexs):
        """Add a synset relation edge (and missing endpoints) to a mixed graph.

        The edge is skipped when either endpoint has no lexical units left
        after filtering with ``included_lexs`` / ``excluded_lexs``.
        ``added_syns`` and ``empty_syns`` are updated in place.
        """
        source = syn_edge.source
        target = syn_edge.target

        # If the synsets have not been added yet, get lexical units that
        # belong to them. If an empty synset is encountered, remember it.
        source_units = self._pending_units(
            source,
            added_syns,
            empty_syns,
            included_lexs,
            excluded_lexs,
        )
        target_units = self._pending_units(
            target,
            added_syns,
            empty_syns,
            included_lexs,
            excluded_lexs,
        )

        # Only add the edge if both endpoints are not empty (don't check
        # from *_units, because the endpoint wasn't necessarily added in
        # this step).
        if source.id in empty_syns or target.id in empty_syns:
            return

        # If the source or target was not yet added, it will have the units
        # set with a true value. If it's false, then it was already added
        # earlier (if it was really empty, it would be in empty_syns).
        if source_units:
            self._add_mixed_synset_node(
                source,
                source_units,
                syn_attrs,
                lex_attrs,
                added_syns,
            )
        # The extra membership test keeps a self-loop edge (source is also
        # the target) from adding the same node twice.
        if target_units and target.id not in added_syns:
            self._add_mixed_synset_node(
                target,
                target_units,
                syn_attrs,
                lex_attrs,
                added_syns,
            )

        prefixed_syn_source = self._prefix_synset_id(source.id, True)
        prefixed_syn_target = self._prefix_synset_id(target.id, True)
        self._graphout.add_edge(
            self._EDGE_SYN_TEMPLATE.format(
                prefixed_syn_source,
                prefixed_syn_target,
                syn_edge.relation,
            ),
            prefixed_syn_source,
            prefixed_syn_target,
            {u'edge-type': u'relation', u'edge-name': syn_edge.relation},
        )

    def _add_mixed_lexunit_edge(self,
                                lex_edge,
                                syn_attrs,
                                lex_attrs,
                                added_syns,
                                empty_syns,
                                included_syns,
                                excluded_syns,
                                included_lexs=None,
                                excluded_lexs=None):
        """Add a lexical relation edge (and missing synsets) to a mixed graph.

        :param included_syns: Filter for the synsets of the edge endpoints.
        :param excluded_syns: Filter for the synsets of the edge endpoints.
        :param included_lexs: Filter for lexical units of a synset that gets
            added here; ``None`` means no restriction.
        :param excluded_lexs: Filter for lexical units of a synset that gets
            added here; ``None`` means no restriction.
        """
        source_synset = lex_edge.source.synset
        target_synset = lex_edge.target.synset

        # Check if one of the lexunits' synset is empty or otherwise
        # excluded.
        if not self._check_include_exclude_2(source_synset.id,
                                             target_synset.id,
                                             included_syns,
                                             excluded_syns):
            return
        if source_synset.id in empty_syns or target_synset.id in empty_syns:
            return

        # At this point, both lexunits and their synsets are eligible for
        # being added. And if the synset is added, then all of its
        # not-excluded lexical units get added too. This takes care of
        # adding the currently processed lexical unit nodes to the graph.
        for synset in (source_synset, target_synset):
            if synset.id not in added_syns:
                self._add_mixed_synset_node(
                    synset,
                    self._make_units_of_synset(
                        synset,
                        included_lexs,
                        excluded_lexs,
                    ),
                    syn_attrs,
                    lex_attrs,
                    added_syns,
                )

        prefixed_lex_source = self._prefix_lexunit_id(
            lex_edge.source.id,
            True,
        )
        prefixed_lex_target = self._prefix_lexunit_id(
            lex_edge.target.id,
            True,
        )
        self._graphout.add_edge(
            self._EDGE_LEX_TEMPLATE.format(
                prefixed_lex_source,
                prefixed_lex_target,
                lex_edge.relation,
            ),
            prefixed_lex_source,
            prefixed_lex_target,
            {u'edge-type': u'relation', u'edge-name': lex_edge.relation},
        )

    def _pending_units(self,
                       synset,
                       added_syns,
                       empty_syns,
                       included_lexs,
                       excluded_lexs):
        """Return filtered units of a synset not seen before, else ``None``.

        A synset whose filtered unit set turns out empty is recorded in
        ``empty_syns``.
        """
        if synset.id in added_syns or synset.id in empty_syns:
            return None

        units = self._make_units_of_synset(
            synset,
            included_lexs,
            excluded_lexs,
        )
        if not units:
            empty_syns.add(synset.id)
        return units

    def _add_mixed_synset_node(self,
                               synset,
                               units,
                               syn_attrs,
                               lex_attrs,
                               added_syns):
        """Add a synset node with its unit nodes and mark it as added."""
        prefixed_synset = self._prefix_synset_id(synset.id, True)
        self._graphout.add_node(
            prefixed_synset,
            self._make_attr_dict(synset, syn_attrs),
        )
        self._add_units_of_synset(prefixed_synset, units, lex_attrs)
        added_syns.add(synset.id)

    def _add_units_of_synset(self,
                             prefixed_synset_id,
                             units_of_synset,
                             attributes):
        """Add unit nodes and tie each one to its synset in both directions."""
        for lu in units_of_synset:
            prefixed_lex = self._prefix_lexunit_id(lu.id, True)
            self._graphout.add_node(
                prefixed_lex,
                self._make_attr_dict(lu, attributes),
            )
            self._graphout.add_edge(
                self._EDGE_UNS_TEMPLATE.format(
                    prefixed_synset_id,
                    prefixed_lex,
                    UNS_HAS_LU,
                ),
                prefixed_synset_id,
                prefixed_lex,
                {u'edge-type': u'unit_and_synset', u'edge-name': UNS_HAS_LU},
            )
            self._graphout.add_edge(
                self._EDGE_UNS_TEMPLATE.format(
                    prefixed_lex,
                    prefixed_synset_id,
                    UNS_IN_SYN,
                ),
                prefixed_lex,
                prefixed_synset_id,
                {u'edge-type': u'unit_and_synset', u'edge-name': UNS_IN_SYN},
            )

    def _add_synset_attrs(self, included_attrs, excluded_attrs):
        """Declare the synset attributes that pass the filters.

        :return: Set of ``(GraphML attribute ID, object attribute name)``
            pairs, suitable for :meth:`._make_attr_dict`.
        """
        includer = _AttrIncluder(
            self._graphout,
            u'syn_data',
            funct.partial(
                self._check_include_exclude,
                include_set=included_attrs,
                exclude_set=excluded_attrs,
            ),
        )

        includer(u'definition', GraphMLWordNet.DATA_TYPE_OPTSTR)
        includer(u'is_artificial', GraphMLWordNet.DATA_TYPE_BOOL)

        return includer.included_attrs

    def _add_lexunit_attrs(self, included_attrs, excluded_attrs):
        """Declare the lexical unit attributes that pass the filters.

        :return: Set of ``(GraphML attribute ID, object attribute name)``
            pairs, suitable for :meth:`._make_attr_dict`.
        """
        includer = _AttrIncluder(
            self._graphout,
            u'lu_data',
            funct.partial(
                self._check_include_exclude,
                include_set=included_attrs,
                exclude_set=excluded_attrs,
            ),
        )

        includer(u'lemma', GraphMLWordNet.DATA_TYPE_STR)
        includer(u'pos', GraphMLWordNet.DATA_TYPE_ENUMVAL)
        includer(u'variant', GraphMLWordNet.DATA_TYPE_INT)
        includer(u'definition', GraphMLWordNet.DATA_TYPE_OPTSTR)
        includer(u'sense_examples', GraphMLWordNet.DATA_TYPE_JSON)
        includer(u'sense_examples_sources', GraphMLWordNet.DATA_TYPE_JSON)
        includer(u'external_links', GraphMLWordNet.DATA_TYPE_JSON)
        includer(u'usage_notes', GraphMLWordNet.DATA_TYPE_JSON)
        includer(u'domain', GraphMLWordNet.DATA_TYPE_ENUMVAL)
        includer(u'verb_aspect', GraphMLWordNet.DATA_TYPE_OPTENUMVAL)
        includer(u'is_emotional', GraphMLWordNet.DATA_TYPE_BOOL)
        includer(u'emotion_markedness', GraphMLWordNet.DATA_TYPE_OPTENUMVAL)
        includer(u'emotion_names', GraphMLWordNet.DATA_TYPE_ENUMSEQ)
        includer(u'emotion_valuations', GraphMLWordNet.DATA_TYPE_ENUMSEQ)
        includer(u'emotion_example', GraphMLWordNet.DATA_TYPE_STR)
        includer(u'emotion_example_secondary', GraphMLWordNet.DATA_TYPE_STR)

        return includer.included_attrs

    @staticmethod
    def _wants_attributes(include_attributes, included_attrs, excluded_attrs):
        """``True`` if attributes were requested by flag or by either set."""
        return bool(include_attributes or
                    included_attrs is not None or
                    excluded_attrs is not None)

    @classmethod
    def _make_units_of_synset(cls, synset, included_nodes, excluded_nodes):
        """Return the synset's lexical units whose IDs pass the filters."""
        return frozenset(lu
                         for lu in synset.lexical_units
                         if cls._check_include_exclude(lu.id,
                                                       included_nodes,
                                                       excluded_nodes))

    @classmethod
    def _prefix_synset_id(cls, id_, do_prefix):
        """Return ``id_`` as text, prefixed with the synset graph type."""
        return (u'{}-{}'.format(GRAPH_TYPE_SYNSET, id_)
                if do_prefix
                else six.text_type(id_))

    @classmethod
    def _prefix_lexunit_id(cls, id_, do_prefix):
        """Return ``id_`` as text, prefixed with the unit graph type."""
        return (u'{}-{}'.format(GRAPH_TYPE_UNIT, id_)
                if do_prefix
                else six.text_type(id_))

    @staticmethod
    def _check_include_exclude(item, include_set, exclude_set):
        """``True`` if item is in include and not in exclude.

        If the set is ``None``, the check for the set is ``True``.
        """
        return ((include_set is None or item in include_set) and
                (exclude_set is None or item not in exclude_set))

    @staticmethod
    def _check_include_exclude_2(item1, item2, include_set, exclude_set):
        """Check for two items in include/exclude (ex. for edges)."""
        return ((include_set is None or
                 (item1 in include_set and item2 in include_set)) and
                (exclude_set is None or
                 (item1 not in exclude_set and item2 not in exclude_set)))

    @staticmethod
    def _make_attr_dict(item, added_attrs):
        """Map GraphML attribute IDs to the values read from ``item``."""
        # It's assumed that by the time this private method gets called,
        # something has made sure that added_attrs contains only legal values.
        # added_attrs should be a set of pairs, first one the GraphML attribute
        # key, and the second the name of the attribute of the lexical unit /
        # synset object.
        return {attrkey: getattr(item, attrname)
                for attrkey, attrname in added_attrs}


class _AttrIncluder(object):
    """Callable helper that declares attributes on a GraphML output.

    Each call runs the "should this attribute be included?" check and, when
    it passes, both declares the attribute type on the output graph and
    remembers it for later lookup through :attr:`included_attrs`.
    """

    def __init__(self, graphout, type_prefix, checkfunc):
        """Set up an includer with nothing included yet.

        :param GraphMLWordNet graphout: The output graph instance.

        :param str type_prefix: Unique names of attributes will be prefixed
            with this.

        :param checkfunc: Callable that should take a name of an attribute and
            return ``True`` if it should be included and ``False`` otherwise.
        :type checkfunc: Callable[[str], bool]
        """
        self._added = set()
        self._check = checkfunc
        self._prefix = type_prefix
        self._graphout = graphout

    @property
    def included_attrs(self):
        """Set of ``(prefixed attribute ID, attribute name)`` pairs so far."""
        return self._added

    def __call__(self, attr_name, attr_type):
        """Declare ``attr_name`` with ``attr_type`` if the check accepts it."""
        if not self._check(attr_name):
            return

        attr_id = u'{}-{}'.format(self._prefix, attr_name)
        self._added.add((attr_id, attr_name))
        self._graphout.add_attribute_type(attr_id, attr_name, type_=attr_type)
