"""XAI results."""
import pickle
from pathlib import Path

import click
import pandas as pd
import shap
import torch

from text_attacks.utils import get_model_and_tokenizer


def build_predict_fun(model, tokenizer, max_length=512):
    """Wrap a classification model as a text -> logits function for SHAP.

    Args:
        model: Callable that takes a batch of token-id tensors and returns an
            object with a ``logits`` attribute.
        tokenizer: Object exposing ``encode(text, padding=..., max_length=...,
            truncation=...)`` returning a list of token ids.
        max_length: Length every input is padded / truncated to
            (default 512, the previously hard-coded value).

    Returns:
        A function mapping an iterable of strings to a NumPy array of logits,
        one row per input string.
    """
    def f(x):
        encoded_inputs = torch.tensor(
            [
                tokenizer.encode(
                    v,
                    padding="max_length",
                    max_length=max_length,
                    truncation=True,
                )
                for v in x
            ]
        )
        # Move inputs next to the model when it reports a device
        # (e.g. a model placed on GPU); otherwise leave them on CPU.
        device = getattr(model, "device", None)
        if device is not None:
            encoded_inputs = encoded_inputs.to(device)
        # Pure inference: avoid building an autograd graph for every batch
        # SHAP evaluates, which would only waste memory.
        with torch.no_grad():
            logits = model(encoded_inputs).logits
        # Hand SHAP a plain host-side array rather than a torch tensor.
        return logits.detach().cpu().numpy()

    return f


@click.command()
@click.option(
    "--dataset_name",
    help="Dataset name",
    type=str,
    required=True,
)
@click.option(
    "--output_dir",
    help="Path to output directory",
    type=click.Path(path_type=Path),
    required=True,
)
def main(
        dataset_name: str,
        output_dir: Path,
):
    """Compute SHAP explanations for a dataset's adversarial examples.

    Loads the model and tokenizer for ``dataset_name``, reads the ``text``
    column of ``data/datasets/<dataset_name>/adversarial.jsonl``, explains
    every text with a SHAP explainer (the tokenizer acts as the masker), and
    pickles the resulting explanation object to
    ``<output_dir>/shap_values.pickle``.

    Both options are required; click reports a usage error if either is
    missing instead of failing later on ``None``.
    """
    output_dir.mkdir(parents=True, exist_ok=True)

    model, tokenizer = get_model_and_tokenizer(
        dataset_name=dataset_name,
    )
    test = pd.read_json(f"data/datasets/{dataset_name}/adversarial.jsonl", lines=True)
    test_x = test["text"].tolist()

    predict = build_predict_fun(model, tokenizer)
    explainer = shap.Explainer(
        predict,
        masker=tokenizer,
        # NOTE(review): assumes id2label is ordered by class index so these
        # names line up with the logit columns — confirm for each model.
        output_names=list(model.config.id2label.values())
    )
    shap_values = explainer(test_x)
    with open(output_dir / "shap_values.pickle", mode="wb") as fd:
        pickle.dump(shap_values, fd)


if __name__ == "__main__":
    main()