import os
from typing import Tuple, List
from xml.etree import ElementTree

from experiment.dataset_specific.pl_luna.luna_record_provider import LunaRecordProvider
from sziszapangma.integration.path_filter import ExtensionPathFilter
from sziszapangma.model.model import Word, SingleAnnotation
from sziszapangma.model.model_creators import create_new_word, create_new_single_annotation, \
    create_new_span_annotation, create_new_document
from sziszapangma.model.relation_manager import RelationManager


class LunaAdapter:
    """Import the XML annotation files of a LUNA record into its relation manager.

    For every record the adapter derives four file paths from the record path
    (``*_words.xml``, ``*_attvalue.xml``, ``*_chunks.xml``, ``*_turns.xml``),
    parses them and stores the resulting words, annotations and relations in
    the ``RelationManager`` supplied by the record provider.
    """

    _record_provider: LunaRecordProvider

    def __init__(self, record_provider: LunaRecordProvider):
        """Remember the provider used to resolve record paths and relation managers."""
        self._record_provider = record_provider

    @staticmethod
    def save_words(
        words_path: str,
        relation_manager: RelationManager
    ) -> Tuple[List[Word], List[SingleAnnotation]]:
        """Parse the words file and store words, their annotations and a document.

        Every ``word`` element directly under the XML root becomes a word (text
        taken from its ``word`` attribute) with three single annotations built
        from the ``lemma``, ``POS`` and ``morph`` attributes. A document item
        listing all word ids is saved last and related to each word.

        :param words_path: path of the XML file with ``word`` elements.
        :param relation_manager: store receiving items and relations.
        :return: the created words (in file order) and all their annotations.
        :raises KeyError: when a ``word`` element lacks one of the attributes.
        """
        # (annotation type, XML attribute holding its value)
        annotation_sources = (("lemma", "lemma"), ("pos", "POS"), ("morph", "morph"))

        words = []
        annotations = []
        for word_element in ElementTree.parse(words_path).getroot():
            if word_element.tag != "word":
                continue
            word = create_new_word(text=word_element.attrib["word"])
            relation_manager.save_item(word)
            # Build all annotations first so a missing attribute fails before
            # any annotation of this word has been stored.
            word_annotations = [
                create_new_single_annotation(
                    annotation_type=annotation_type,
                    value=word_element.attrib[attribute_name],
                    reference_id=word['id']
                )
                for annotation_type, attribute_name in annotation_sources
            ]
            for word_annotation in word_annotations:
                relation_manager.save_item(word_annotation)
                relation_manager.save_relation(word, word_annotation)
            words.append(word)
            annotations.extend(word_annotations)

        document = create_new_document([word['id'] for word in words])
        relation_manager.save_item(document)
        for word in words:
            relation_manager.save_relation(word, document)
        return words, annotations

    @staticmethod
    def parse_id_expression(word_id: str) -> int:
        """Return the integer that follows the first five characters of ``word_id``.

        Assumes identifiers shaped like ``word_<n>`` (a five-character prefix)
        - TODO confirm against the dataset.

        :raises ValueError: when the remainder is not an integer literal.
        """
        return int(word_id[5:])

    def get_word_ids_list(self, words: List[Word], word_ids: str) -> List[str]:
        """Translate a span expression into the ids of the referenced words.

        Supported expressions:

        * ``"empty"`` - no words, returns ``[]``;
        * a single identifier - the word at that 1-based position;
        * ``"<first>..<last>"`` - the inclusive 1-based range of words.

        :param words: words in file order, as returned by ``save_words``.
        :param word_ids: span expression read from an XML attribute.
        :return: ids (``word['id']``) of the selected words.
        :raises IndexError: when a single identifier points past the last word.
        """
        if word_ids == "empty":
            return []
        bounds = word_ids.split("..")
        if len(bounds) == 1:
            return [words[self.parse_id_expression(word_ids) - 1]['id']]
        # File positions are 1-based and inclusive; slices are 0-based and
        # exclusive at the end, so only the start needs shifting.
        index_from = self.parse_id_expression(bounds[0]) - 1
        index_to = self.parse_id_expression(bounds[1])
        return [word['id'] for word in words[index_from:index_to]]

    @staticmethod
    def _link_words_to_span(
        words: List[Word],
        word_ids: List[str],
        span,
        relation_manager: RelationManager
    ) -> None:
        """Save a relation between ``span`` and each word whose id is in ``word_ids``.

        Words are visited in the order of ``words``; a set makes each
        membership test O(1) instead of scanning the id list per word.
        """
        span_word_ids = set(word_ids)
        for word in words:
            if word['id'] in span_word_ids:
                relation_manager.save_relation(word, span)

    def read_concepts(
        self,
        words: List[Word], concept_path: str, relation_manager: RelationManager
    ) -> None:
        """Store every ``concept`` element of the file as a span with an annotation.

        The span covers the words named by the ``span`` attribute; the
        annotation value holds the ``attribute`` and ``value`` attributes.

        :raises KeyError: when a ``concept`` element lacks a required attribute.
        """
        for concept_element in ElementTree.parse(concept_path).getroot():
            if concept_element.tag != "concept":
                continue
            span_word_ids = self.get_word_ids_list(words, concept_element.attrib["span"])
            span = create_new_span_annotation('concept', span_word_ids)
            relation_manager.save_item(span)
            concept_value = {
                "attribute": concept_element.attrib["attribute"],
                "value": concept_element.attrib["value"],
            }
            annotation = create_new_single_annotation(
                annotation_type="concept", value=concept_value, reference_id=span['id']
            )
            relation_manager.save_item(annotation)
            # NOTE(review): argument order (annotation, span) differs from
            # read_chunks/read_turns, which pass (span, annotation) - confirm
            # whether the relation manager treats relations as symmetric.
            relation_manager.save_relation(annotation, span)
            self._link_words_to_span(words, span_word_ids, span, relation_manager)

    def read_chunks(self,
                    words: List[Word], chunks_path: str,
                    relation_manager: RelationManager
                    ) -> None:
        """Store every ``chunk`` element of the file as a span with an annotation.

        The span covers the words named by the ``span`` attribute; the
        annotation value holds ``cat`` and the optional ``main`` attribute
        (``None`` when absent).

        :raises KeyError: when a ``chunk`` element has no ``span`` attribute.
        """
        for chunk_element in ElementTree.parse(chunks_path).getroot():
            if chunk_element.tag != "chunk":
                continue
            span_word_ids = self.get_word_ids_list(words, chunk_element.attrib["span"])
            span = create_new_span_annotation('chunk', elements=span_word_ids)
            chunk_value = {
                # NOTE(review): "cat" is filled from the "span" attribute; this
                # looks like it was meant to read a "cat" attribute - confirm
                # against the chunks XML before changing.
                "cat": chunk_element.attrib["span"],
                # TODO: change to the id of the word within the collection.
                "main": chunk_element.attrib.get("main"),
            }
            annotation = create_new_single_annotation(
                annotation_type="chunk", value=chunk_value, reference_id=span['id']
            )
            relation_manager.save_item(span)
            relation_manager.save_item(annotation)
            relation_manager.save_relation(span, annotation)
            self._link_words_to_span(words, span_word_ids, span, relation_manager)

    def read_turns(
        self,
        words: List[Word], turns_path: str,
        relation_manager: RelationManager
    ) -> None:
        """Store every ``Turn`` element of the file as a span with an annotation.

        The span covers the words named by the ``words`` attribute; the
        annotation value holds the ``speaker``, ``startTime`` and ``endTime``
        attributes exactly as read from the XML (strings).

        :raises KeyError: when a ``Turn`` element lacks a required attribute.
        """
        for turn_element in ElementTree.parse(turns_path).getroot():
            if turn_element.tag != "Turn":
                continue
            span_word_ids = self.get_word_ids_list(words, turn_element.attrib["words"])
            span = create_new_span_annotation(name='turn', elements=span_word_ids)
            turn_metadata = {
                "speaker": turn_element.attrib["speaker"],
                "startTime": turn_element.attrib["startTime"],
                "endTime": turn_element.attrib["endTime"],
            }
            annotation = create_new_single_annotation(
                annotation_type="turn", value=turn_metadata, reference_id=span['id']
            )
            relation_manager.save_item(span)
            relation_manager.save_item(annotation)
            relation_manager.save_relation(span, annotation)
            self._link_words_to_span(words, span_word_ids, span, relation_manager)

    def import_record(self, record_id: str) -> None:
        """Rebuild the relation-manager content of one record from its XML files.

        The record's relation manager is cleared, then words, concepts, chunks
        and turns are read (in that order) and the result is committed.

        :param record_id: identifier understood by the record provider.
        """
        print(f'record {record_id}')
        relation_manager = self._record_provider.get_relation_manager(record_id)
        relation_manager.clear_all()
        # Drop the file extension, whatever its length, to get the common
        # prefix shared by the annotation files.
        basic_path, _extension = os.path.splitext(self._record_provider.get_path(record_id))

        words_path = f"{basic_path}_words.xml"
        concept_path = f"{basic_path}_attvalue.xml"
        chunks_path = f"{basic_path}_chunks.xml"
        turn_path = f"{basic_path}_turns.xml"

        # Word-level annotations are already stored by save_words; only the
        # words are needed to resolve spans in the remaining files.
        words, _ = self.save_words(words_path, relation_manager)
        self.read_concepts(words, concept_path, relation_manager)
        self.read_chunks(words, chunks_path, relation_manager)
        self.read_turns(words, turn_path, relation_manager)
        relation_manager.commit()


def main():
    """Import every record returned by the LUNA record provider.

    The provider is configured with an ``ExtensionPathFilter`` for the ``wav``
    extension rooted at ``experiment_data/dataset/pl_luna/LUNA.PL``. Progress
    is printed as ``<position>/<total>`` before each record is imported.
    """
    luna_directory = 'experiment_data/dataset/pl_luna'
    luna_record_provider = LunaRecordProvider(
        ExtensionPathFilter(
            root_directory=f'{luna_directory}/LUNA.PL',
            extension='wav'
        ),
        relation_manager_root_path='experiment_data/dataset_relation_manager_data/pl_luna'
    )
    luna_adapter = LunaAdapter(luna_record_provider)
    # Materialise the ids up front so the total is known for progress output.
    record_ids = list(luna_record_provider.get_all_records())
    total = len(record_ids)
    for index, record_id in enumerate(record_ids, start=1):
        print(f'{index}/{total}')
        luna_adapter.import_record(record_id)


# Run the import only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
