From 9bdaf9cba085e2fc481b04c1fb20034bf7499ffb Mon Sep 17 00:00:00 2001
From: Mateusz Gniewkowski <mateusz.gniewkowski@pwr.edu.pl>
Date: Thu, 23 Mar 2023 12:45:45 +0000
Subject: [PATCH] Xai test

---
 dvc.lock                           | 90 +++++++++++++++---------
 experiments/scripts/explain.py     | 44 ++++++++++++++-
 experiments/scripts/tag_dataset.py | 52 +++++++++++++----
 requirements.txt                   |  3 +
 4 files changed, 132 insertions(+), 57 deletions(-)

diff --git a/dvc.lock b/dvc.lock
index 18c0a65..d1fc740 100644
--- a/dvc.lock
+++ b/dvc.lock
@@ -17,8 +17,8 @@ stages:
       --output_dir data/models/enron_spam
     deps:
     - path: data/preprocessed/enron_spam
-      md5: b75efba1a62182dc8ac32acd1faf92ed.dir
-      size: 61709260
+      md5: 99d604f84516cee94948054a97ffec5e.dir
+      size: 71403809
       nfiles: 3
     - path: experiments/scripts/get_model.py
       md5: 5050f51b4019bba97af47971f6c7cab4
@@ -37,8 +37,8 @@ stages:
       size: 18505614
       nfiles: 6
     - path: data/preprocessed/enron_spam/
-      md5: b75efba1a62182dc8ac32acd1faf92ed.dir
-      size: 61709260
+      md5: 99d604f84516cee94948054a97ffec5e.dir
+      size: 71403809
       nfiles: 3
     - path: experiments/scripts/classify.py
       md5: 6fc1a6a0a11ba6cd99a8b6625a96d9f5
@@ -57,17 +57,17 @@ stages:
       size: 18505614
       nfiles: 6
     - path: data/preprocessed/enron_spam
-      md5: b75efba1a62182dc8ac32acd1faf92ed.dir
-      size: 61709260
+      md5: 99d604f84516cee94948054a97ffec5e.dir
+      size: 71403809
       nfiles: 3
     - path: experiments/scripts/explain.py
-      md5: afc02ef263a59c911098dea969faa932
-      size: 3234
+      md5: 4d82a557627f59c884f52fd7994ed80a
+      size: 4617
     outs:
     - path: data/explanations/enron_spam/
-      md5: 70c3f3d04e0b73fd56eecfda04914bd4.dir
-      size: 13589794
-      nfiles: 403
+      md5: 4a4630717bf6bf5027f1b70191c69a4c.dir
+      size: 33723647
+      nfiles: 4003
   download_dataset@poleval:
     cmd: PYTHONPATH=. python experiments/scripts/download_dataset.py --dataset_name poleval
       --output_dir data/datasets/poleval
@@ -88,12 +88,12 @@ stages:
       size: 1688836
       nfiles: 3
     - path: experiments/scripts/tag_dataset.py
-      md5: f73e2203fdb988a00d4e8363a349c617
-      size: 3932
+      md5: ebadced7a031a31bdaf935d2b22e5e05
+      size: 4632
     outs:
     - path: data/preprocessed/poleval/
-      md5: 9d067db65ba6a27db19effce45b01876.dir
-      size: 2541105
+      md5: b0ea9f0ad1dba6d3b474c0a3cedf866e.dir
+      size: 2812175
       nfiles: 3
   preprocess_dataset@enron_spam:
     cmd: PYTHONPATH=. python experiments/scripts/tag_dataset.py --dataset_name enron_spam
@@ -103,12 +103,12 @@ stages:
       size: 53096069
       nfiles: 3
     - path: experiments/scripts/tag_dataset.py
-      md5: f73e2203fdb988a00d4e8363a349c617
-      size: 3932
+      md5: ebadced7a031a31bdaf935d2b22e5e05
+      size: 4632
     outs:
     - path: data/preprocessed/enron_spam/
-      md5: 30c63efbc615347ddcb5f61e011113bd.dir
-      size: 65971374
+      md5: 99d604f84516cee94948054a97ffec5e.dir
+      size: 71403809
       nfiles: 3
   preprocess_dataset@wiki_pl:
     cmd: PYTHONPATH=. python experiments/scripts/tag_dataset.py --dataset_name wiki_pl
@@ -118,12 +118,12 @@ stages:
       size: 29115538
       nfiles: 3
     - path: experiments/scripts/tag_dataset.py
-      md5: f73e2203fdb988a00d4e8363a349c617
-      size: 3932
+      md5: ebadced7a031a31bdaf935d2b22e5e05
+      size: 4632
     outs:
     - path: data/preprocessed/wiki_pl/
-      md5: 0014b9bb52913cbc9a568d237ea2207b.dir
-      size: 65553079
+      md5: 066634606f832b6c9d1db95293de7e04.dir
+      size: 77818549
       nfiles: 3
   classify@wiki_pl:
     cmd: PYTHONPATH=. python experiments/scripts/classify.py --dataset_name wiki_pl
@@ -134,8 +134,8 @@ stages:
       size: 501711136
       nfiles: 7
     - path: data/preprocessed/wiki_pl/
-      md5: 3e9b2e1e0542777e0a751d9d7f7f4241.dir
-      size: 55380570
+      md5: 066634606f832b6c9d1db95293de7e04.dir
+      size: 77818549
       nfiles: 3
     - path: experiments/scripts/classify.py
       md5: 6fc1a6a0a11ba6cd99a8b6625a96d9f5
@@ -153,12 +153,12 @@ stages:
       size: 23460358
       nfiles: 3
     - path: experiments/scripts/tag_dataset.py
-      md5: f73e2203fdb988a00d4e8363a349c617
-      size: 3932
+      md5: ebadced7a031a31bdaf935d2b22e5e05
+      size: 4632
     outs:
     - path: data/preprocessed/20_news/
-      md5: 20da0980e52df537e5b7ca5db0305879.dir
-      size: 58582060
+      md5: a3d2da9ac72423e555ae7ed051741b30.dir
+      size: 69405970
       nfiles: 3
   classify@20_news:
     cmd: PYTHONPATH=. python experiments/scripts/classify.py --dataset_name 20_news
@@ -169,8 +169,8 @@ stages:
       size: 439008808
       nfiles: 6
     - path: data/preprocessed/20_news/
-      md5: 1ed5ef2dabe4bc05f7377175ed11137b.dir
-      size: 46845669
+      md5: a3d2da9ac72423e555ae7ed051741b30.dir
+      size: 69405970
       nfiles: 3
     - path: experiments/scripts/classify.py
       md5: 6fc1a6a0a11ba6cd99a8b6625a96d9f5
@@ -208,17 +208,17 @@ stages:
       size: 501711136
       nfiles: 7
     - path: data/preprocessed/wiki_pl
-      md5: 3e9b2e1e0542777e0a751d9d7f7f4241.dir
-      size: 55380570
+      md5: 066634606f832b6c9d1db95293de7e04.dir
+      size: 77818549
       nfiles: 3
     - path: experiments/scripts/explain.py
-      md5: afc02ef263a59c911098dea969faa932
-      size: 3234
+      md5: 4d82a557627f59c884f52fd7994ed80a
+      size: 4617
     outs:
     - path: data/explanations/wiki_pl/
-      md5: 5a3b9b069024456412078143e3af15d7.dir
-      size: 331450794
-      nfiles: 10065
+      md5: a11dc3a6b329a1d6bacd50c158ff3e6c.dir
+      size: 1096043911
+      nfiles: 100403
   explain@20_news:
     cmd: PYTHONPATH=. python experiments/scripts/explain.py --dataset_name 20_news
       --output_dir data/explanations/20_news
@@ -228,14 +228,14 @@ stages:
       size: 439008808
       nfiles: 6
     - path: data/preprocessed/20_news
-      md5: 1ed5ef2dabe4bc05f7377175ed11137b.dir
-      size: 46845669
+      md5: a3d2da9ac72423e555ae7ed051741b30.dir
+      size: 69405970
       nfiles: 3
     - path: experiments/scripts/explain.py
-      md5: afc02ef263a59c911098dea969faa932
-      size: 3234
+      md5: 4d82a557627f59c884f52fd7994ed80a
+      size: 4617
     outs:
     - path: data/explanations/20_news/
-      md5: c8ba90f9757a4e3cc4843d3791ef2446.dir
-      size: 232912969
-      nfiles: 14041
+      md5: ad1f9f0df287078edebed1e408df2c9f.dir
+      size: 869336544
+      nfiles: 140401
diff --git a/experiments/scripts/explain.py b/experiments/scripts/explain.py
index 92c85df..46120cd 100644
--- a/experiments/scripts/explain.py
+++ b/experiments/scripts/explain.py
@@ -8,6 +8,7 @@ import numpy as np
 import pandas as pd
 import shap
 import torch
+from tqdm import tqdm
 
 from text_attacks.utils import get_model_and_tokenizer
 
@@ -79,6 +80,14 @@ def main(
         f"data/preprocessed/{dataset_name}/adversarial.jsonl", lines=True
     )
     test_x = test["text"].tolist()
+    test_x = [
+        tokenizer.decode(
+            tokenizer.encode(
+                t, padding="do_not_pad", max_length=512, truncation=True
+            ),
+            skip_special_tokens=True
+        ) for t in test_x
+    ]
 
     predict = build_predict_fun(model, tokenizer)
     explainer = shap.Explainer(
@@ -100,14 +109,47 @@ def main(
 
     # LOCAL IMPORTANCE
     for class_id, class_name in model.config.id2label.items():
-        sub_dir = output_dir / "local" / class_name
+        sub_dir = output_dir / "local" / "adversarial" / class_name
         os.makedirs(sub_dir, exist_ok=True)
         for shap_id, text_id in enumerate(test["id"]):
             importance_df = get_importance(shap_values[shap_id, :, class_id])
             importance_df.to_json(
                 sub_dir / f"{text_id}__importance.json",
             )
+
+    # LOCAL IMPORTANCE (test set)
+    test = pd.read_json(
+        f"data/preprocessed/{dataset_name}/test.jsonl", lines=True
+    )
+    test_x = test["text"].tolist()
+    test_x = [
+        tokenizer.decode(
+            tokenizer.encode(
+                t, padding="do_not_pad", max_length=512, truncation=True
+            ),
+            skip_special_tokens=True
+        ) for t in test_x
+    ]
+    predict = build_predict_fun(model, tokenizer)
+    explainer = shap.Explainer(
+        predict,
+        masker=tokenizer,
+        output_names=list(model.config.id2label.values())
+    )
+    for text_id, text in tqdm(
+        zip(test["id"], test_x),
+        total=len(test_x),
+        desc="Shap for test DS",
+    ):
+        shap_values = explainer([text])
+        for class_id, class_name in model.config.id2label.items():
+            sub_dir = output_dir / "local" / "test" / class_name
+            os.makedirs(sub_dir, exist_ok=True)
+            importance_df = get_importance(shap_values[0, :, class_id])
+            importance_df.to_json(
+                sub_dir / f"{text_id}__importance.json",
+            )
 
 
 if __name__ == "__main__":
     main()
diff --git a/experiments/scripts/tag_dataset.py b/experiments/scripts/tag_dataset.py
index 911983a..2ecc511 100644
--- a/experiments/scripts/tag_dataset.py
+++ b/experiments/scripts/tag_dataset.py
@@ -6,6 +6,8 @@ import json
 import os
 from tqdm import tqdm
 from multiprocessing import cpu_count, Pool
+import spacy
+
 
 TOKENS = "tokens"
 ORTH = "orth"
@@ -54,7 +56,7 @@ def tag_sentence(sentence: str, lang: str):
 def process_file(dataset_df, lang, output_path):
     test_with_tags = pd.DataFrame(dataset_df)
     lemmas_col, tags_col, orth_col = [], [], []
-    cpus = 8
+    cpus = 2
     with Pool(processes=cpus) as pool:
         results = []
         for idx in tqdm(range(0, len(dataset_df), cpus)):
@@ -73,8 +75,28 @@ def process_file(dataset_df, lang, output_path):
     test_with_tags[TAGS] = tags_col
     test_with_tags[ORTHS] = orth_col
 
-    with open(output_path, mode="wt") as fd:
-        fd.write(test_with_tags.to_json(orient="records", lines=True))
+    return test_with_tags
+
+
+def add_ner(dataset_df, language):
+    model = "en_core_web_trf" if language == "en" else "pl_core_news_lg"
+    nlp = spacy.load(model)
+    ner_data = list()
+
+    for text in tqdm(dataset_df["text"]):
+        doc = nlp(text)
+        doc_ner = list()
+        for ent in doc.ents:
+            doc_ner.append({
+                "text": ent.text,
+                "start_char": ent.start_char,
+                "end_char": ent.end_char,
+                "label": ent.label_,
+            })
+        ner_data.append(doc_ner)
+
+    dataset_df["ner"] = ner_data
+    return dataset_df
 
 
 @click.command()
@@ -83,7 +105,13 @@ def process_file(dataset_df, lang, output_path):
     help="Dataset name",
     type=str,
 )
-def main(dataset_name: str):
+@click.option(
+    "--output",
+    help="Output directory",
+    type=str,
+
+)
+def main(dataset_name: str, output: str):
     """Downloads the dataset to the output directory."""
     lang = {
         "enron_spam": "en",
@@ -91,18 +119,19 @@ def main(dataset_name: str):
         "20_news": "en",
         "wiki_pl": "pl",
     }[dataset_name]
-    output_dir = f"data/preprocessed/{dataset_name}"
+    output_dir = f"{output}/{dataset_name}"
    os.makedirs(output_dir, exist_ok=True)
     input_dir = f"data/datasets/{dataset_name}"
     for file in os.listdir(input_dir):
         if os.path.isfile(os.path.join(input_dir, file)):
-            if file == "test.jsonl":
-                process_file(
+            if file in ["test.jsonl", "adversarial.jsonl"]:
+                test_with_tags = process_file(
                     pd.read_json(os.path.join(input_dir, file), lines=True),
                     lang,
                     os.path.join(output_dir, file),
                 )
+                test_with_tags = add_ner(test_with_tags, lang)
             else:
                 test_with_tags = pd.DataFrame(
                     pd.read_json(os.path.join(input_dir, file), lines=True)
                 )
                 empty_list = [[]] * len(test_with_tags)
                 test_with_tags[LEMMAS] = empty_list
                 test_with_tags[TAGS] = empty_list
                 test_with_tags[ORTHS] = empty_list
-        with open(os.path.join(output_dir, file), mode="wt") as fd:
-            fd.write(
-                test_with_tags.to_json(orient="records", lines=True)
-            )
+                test_with_tags["ner"] = empty_list
+            with open(os.path.join(output_dir, file), mode="wt") as fd:
+                fd.write(
+                    test_with_tags.to_json(orient="records", lines=True)
+                )
 
 
 if __name__ == "__main__":
diff --git a/requirements.txt b/requirements.txt
index 8e069b3..bc59ed1 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -9,6 +9,9 @@ tokenizers==0.13.2
 sentence-transformers==2.2.2
 jupyter
 matplotlib
+spacy
+https://github.com/explosion/spacy-models/releases/download/en_core_web_trf-3.5.0/en_core_web_trf-3.5.0-py3-none-any.whl
+https://github.com/explosion/spacy-models/releases/download/pl_core_news_lg-3.5.0/pl_core_news_lg-3.5.0-py3-none-any.whl
 --find-links https://download.pytorch.org/whl/torch_stable.html
 torch==1.12.0+cu116
--
GitLab