"""Script to train Enhanced Dependency Parsing models based on IWPT'21 Shared Task data.

For requirements, see the comments in train_eud.py.
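
Example invocation (script name and paths are placeholders):

    python train_iwpt21.py --lang en,fi --data_dir /path/to/iwpt2021 --serialization_dir /tmp/models --cuda_device 0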
"""

import os
import pathlib
from typing import List

from absl import app
from absl import flags

from scripts import utils

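# Maps each supported language code to the UD treebanks merged for that language.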
LANG2TREEBANK = {
    "ar": ["Arabic-PADT"],
    "bg": ["Bulgarian-BTB"],
    "cs": ["Czech-FicTree", "Czech-CAC", "Czech-PDT", "Czech-PUD"],
    "nl": ["Dutch-Alpino", "Dutch-LassySmall"],
    "en": ["English-EWT", "English-PUD", "English-GUM"],
    "et": ["Estonian-EDT", "Estonian-EWT"],
    "fi": ["Finnish-TDT", "Finnish-PUD"],
    "fr": ["French-Sequoia", "French-FQB"],
    "it": ["Italian-ISDT"],
    "lv": ["Latvian-LVTB"],
    "lt": ["Lithuanian-ALKSNIS"],
    "pl": ["Polish-LFG", "Polish-PDB", "Polish-PUD"],
    "ru": ["Russian-SynTagRus"],
    "sk": ["Slovak-SNK"],
    "sv": ["Swedish-Talbanken", "Swedish-PUD"],
    "ta": ["Tamil-TTB"],
    "uk": ["Ukrainian-IU"],
}

FLAGS = flags.FLAGS
flags.DEFINE_list(name="lang", default=list(LANG2TREEBANK.keys()),
                  help=f"Languages of the models to train. Possible values: {', '.join(LANG2TREEBANK)}.")
flags.DEFINE_string(name="data_dir", default="",
                    help="Path to IWPT'21 data directory.")
flags.DEFINE_string(name="serialization_dir", default="/tmp/",
                    help="Model serialization dir.")
flags.DEFINE_integer(name="cuda_device", default=-1,
                     help="Cuda device id (-1 for cpu).")


def path_to_str(path: pathlib.Path) -> str:
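    """Return the resolved (absolute) path as a string."""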
    return str(path.resolve())


def merge_files(files: List[str], output: pathlib.Path):
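    """Concatenate 'files' into 'output' unless the output already exists.

    Paths are passed to the shell unquoted, so they are assumed to contain
    no spaces or shell metacharacters.
    """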
    if not output.exists():
        os.system(f"cat {' '.join(files)} > {output}")


def collapse_nodes(data_dir: pathlib.Path, treebank_file: pathlib.Path, output: str):
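    """Collapse empty nodes in 'treebank_file' into 'output'.

    Uses the shared task's 'enhanced_collapse_empty_nodes.pl' tool; existing
    outputs are reused.
    """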
    output_path = pathlib.Path(output)
    if not output_path.exists():
        utils.execute_command(f"perl {path_to_str(data_dir / 'tools' / 'enhanced_collapse_empty_nodes.pl')} "
                              f"{path_to_str(treebank_file)}", output)


def run(_):
    data_dir = pathlib.Path(FLAGS.data_dir)
    assert data_dir.is_dir(), f"'{data_dir}' is not a directory!"
    for lang in FLAGS.lang:
        assert lang in LANG2TREEBANK, f"'{lang}' must be one of {list(LANG2TREEBANK.keys())}."
        assert lang in utils.LANG2TRANSFORMER, f"Transformer for '{lang}' isn't defined. See the 'LANG2TRANSFORMER' dict."

        treebanks = LANG2TREEBANK[lang]
        full_language = treebanks[0].split("-")[0]
        train_paths = []
        dev_paths = []
        train_raw_paths = []
        dev_raw_paths = []
        # TODO Uncomment the 'test' branch below when the IWPT'21 Shared Task ends;
        # test data is not available while the shared task is running.
        test_paths = []
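        # Collapse empty nodes in every train/dev CoNLL-U file and collect both
        # the resulting '*.fixed.conllu' paths and the raw '*.txt' paths.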
        for treebank in treebanks:
            treebank_dir = data_dir / f"UD_{treebank}"
            assert treebank_dir.exists() and treebank_dir.is_dir(), f"'{treebank_dir}' directory doesn't exist."
            for treebank_file in treebank_dir.iterdir():
                name = treebank_file.name
                if "conllu" in name and "fixed" not in name:
                    output = path_to_str(treebank_file).replace('.conllu', '.fixed.conllu')
                    if "train" in name:
                        collapse_nodes(data_dir, treebank_file, output)
                        train_paths.append(output)
                    elif "dev" in name:
                        collapse_nodes(data_dir, treebank_file, output)
                        dev_paths.append(output)
                    # elif "test" in name:
                    #     collapse_nodes(data_dir, treebank_file, output)
                    #     test_paths.append(output)
                if ".txt" in name:
                    if "train" in name:
                        train_raw_paths.append(path_to_str(treebank_file))
                    elif "dev" in name:
                        dev_raw_paths.append(path_to_str(treebank_file))

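        # Merge all treebanks for this language into one UD_<Language>-IWPT dataset.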
        merged_dataset_name = "IWPT"
        lang_data_dir = data_dir / f"UD_{full_language}-{merged_dataset_name}"
        lang_data_dir.mkdir(exist_ok=True)

        suffix = f"{lang}_{merged_dataset_name}-ud".lower()
        train_path = lang_data_dir / f"{suffix}-train.conllu"
        dev_path = lang_data_dir / f"{suffix}-dev.conllu"
        test_path = lang_data_dir / f"{suffix}-test.conllu"
        train_raw_path = lang_data_dir / f"{suffix}-train.txt"
        dev_raw_path = lang_data_dir / f"{suffix}-dev.txt"
        test_raw_path = lang_data_dir / f"{suffix}-test.txt"

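        # Concatenate the per-treebank splits into single train/dev/test files.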
        merge_files(train_paths, output=train_path)
        merge_files(dev_paths, output=dev_path)
        # TODO Use test_paths instead of dev_paths after IWPT'21.
        merge_files(dev_paths, output=test_path)

        merge_files(train_raw_paths, output=train_raw_path)
        merge_files(dev_raw_paths, output=dev_raw_path)
        # TODO Use test_raw_paths instead of dev_raw_paths after IWPT'21.
        merge_files(dev_raw_paths, output=test_raw_path)

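        # Each language gets its own subdirectory under --serialization_dir.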
        serialization_dir = pathlib.Path(FLAGS.serialization_dir) / lang
        serialization_dir.mkdir(exist_ok=True, parents=True)

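        # Multi-line 'combo' training command; it is flattened to a single
        # shell line before execution.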
        command = f"""combo --mode train
        --training_data {train_path}
        --validation_data {dev_path}
        --targets feats,upostag,xpostag,head,deprel,lemma,deps
        --pretrained_transformer_name {utils.LANG2TRANSFORMER[lang]}
        --serialization_dir {serialization_dir}
        --cuda_device {FLAGS.cuda_device}
        --word_batch_size 2500
        --config_path {pathlib.Path.cwd() / 'combo' / 'config.graph.template.jsonnet'}
        --notensorboard
        """

        # French treebanks lack XPOS, so override the targets without 'xpostag'
        # (the later '--targets' flag takes precedence; 'deps' is kept for EUD).
        if lang in {"fr"}:
            command += " --targets deprel,head,upostag,lemma,feats,deps"

        utils.execute_command("".join(command.splitlines()))


def main():
    app.run(run)


if __name__ == "__main__":
    main()