From 9bdaf9cba085e2fc481b04c1fb20034bf7499ffb Mon Sep 17 00:00:00 2001
From: Mateusz Gniewkowski <mateusz.gniewkowski@pwr.edu.pl>
Date: Thu, 23 Mar 2023 12:45:45 +0000
Subject: [PATCH] XAI: explain the test set with SHAP and add spaCy NER tagging

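Compute SHAP explanations for the plain test set in addition to the
adversarial examples, and enrich the preprocessed datasets with spaCy
named-entity annotations.

* experiments/scripts/explain.py: truncate every input to 512 tokens
  with an encode/decode round trip before explaining, and write the
  per-text importance files under local/adversarial/ and local/test/.
* experiments/scripts/tag_dataset.py: tag adversarial.jsonl alongside
  test.jsonl, add NER spans via en_core_web_trf (English) or
  pl_core_news_lg (Polish), expose the output directory as --output,
  and drop the tagging pool from 8 to 2 workers.
* requirements.txt: add spacy and the two pinned model wheels.
* dvc.lock: refresh the hashes and sizes of the affected stages.

For reference, the truncation round trip in explain.py behaves like
the minimal sketch below; the checkpoint name and sample text are
illustrative only, the script itself loads the project model through
get_model_and_tokenizer().

    from transformers import AutoTokenizer

    # Any HF tokenizer works for the illustration; the real one comes
    # from get_model_and_tokenizer().
    tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
    long_text = "lorem ipsum " * 2000

    # Encode with truncation, then decode without special tokens: the
    # result is plain text capped at 512 tokens, so the SHAP explainer
    # never receives over-length inputs.
    ids = tokenizer.encode(
        long_text, padding="do_not_pad", max_length=512, truncation=True
    )
    truncated = tokenizer.decode(ids, skip_special_tokens=True)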
---
 dvc.lock                           | 90 +++++++++++++++---------------
 experiments/scripts/explain.py     | 44 ++++++++++++++-
 experiments/scripts/tag_dataset.py | 52 +++++++++++++----
 requirements.txt                   |  3 +
 4 files changed, 132 insertions(+), 57 deletions(-)

diff --git a/dvc.lock b/dvc.lock
index 18c0a65..d1fc740 100644
--- a/dvc.lock
+++ b/dvc.lock
@@ -17,8 +17,8 @@ stages:
       --output_dir data/models/enron_spam
     deps:
     - path: data/preprocessed/enron_spam
-      md5: b75efba1a62182dc8ac32acd1faf92ed.dir
-      size: 61709260
+      md5: 99d604f84516cee94948054a97ffec5e.dir
+      size: 71403809
       nfiles: 3
     - path: experiments/scripts/get_model.py
       md5: 5050f51b4019bba97af47971f6c7cab4
@@ -37,8 +37,8 @@ stages:
       size: 18505614
       nfiles: 6
     - path: data/preprocessed/enron_spam/
-      md5: b75efba1a62182dc8ac32acd1faf92ed.dir
-      size: 61709260
+      md5: 99d604f84516cee94948054a97ffec5e.dir
+      size: 71403809
       nfiles: 3
     - path: experiments/scripts/classify.py
       md5: 6fc1a6a0a11ba6cd99a8b6625a96d9f5
@@ -57,17 +57,17 @@ stages:
       size: 18505614
       nfiles: 6
     - path: data/preprocessed/enron_spam
-      md5: b75efba1a62182dc8ac32acd1faf92ed.dir
-      size: 61709260
+      md5: 99d604f84516cee94948054a97ffec5e.dir
+      size: 71403809
       nfiles: 3
     - path: experiments/scripts/explain.py
-      md5: afc02ef263a59c911098dea969faa932
-      size: 3234
+      md5: 4d82a557627f59c884f52fd7994ed80a
+      size: 4617
     outs:
     - path: data/explanations/enron_spam/
-      md5: 70c3f3d04e0b73fd56eecfda04914bd4.dir
-      size: 13589794
-      nfiles: 403
+      md5: 4a4630717bf6bf5027f1b70191c69a4c.dir
+      size: 33723647
+      nfiles: 4003
   download_dataset@poleval:
     cmd: PYTHONPATH=. python experiments/scripts/download_dataset.py --dataset_name
       poleval --output_dir data/datasets/poleval
@@ -88,12 +88,12 @@ stages:
       size: 1688836
       nfiles: 3
     - path: experiments/scripts/tag_dataset.py
-      md5: f73e2203fdb988a00d4e8363a349c617
-      size: 3932
+      md5: ebadced7a031a31bdaf935d2b22e5e05
+      size: 4632
     outs:
     - path: data/preprocessed/poleval/
-      md5: 9d067db65ba6a27db19effce45b01876.dir
-      size: 2541105
+      md5: b0ea9f0ad1dba6d3b474c0a3cedf866e.dir
+      size: 2812175
       nfiles: 3
   preprocess_dataset@enron_spam:
     cmd: PYTHONPATH=. python experiments/scripts/tag_dataset.py --dataset_name enron_spam
@@ -103,12 +103,12 @@ stages:
       size: 53096069
       nfiles: 3
     - path: experiments/scripts/tag_dataset.py
-      md5: f73e2203fdb988a00d4e8363a349c617
-      size: 3932
+      md5: ebadced7a031a31bdaf935d2b22e5e05
+      size: 4632
     outs:
     - path: data/preprocessed/enron_spam/
-      md5: 30c63efbc615347ddcb5f61e011113bd.dir
-      size: 65971374
+      md5: 99d604f84516cee94948054a97ffec5e.dir
+      size: 71403809
       nfiles: 3
   preprocess_dataset@wiki_pl:
     cmd: PYTHONPATH=. python experiments/scripts/tag_dataset.py --dataset_name wiki_pl
@@ -118,12 +118,12 @@ stages:
       size: 29115538
       nfiles: 3
     - path: experiments/scripts/tag_dataset.py
-      md5: f73e2203fdb988a00d4e8363a349c617
-      size: 3932
+      md5: ebadced7a031a31bdaf935d2b22e5e05
+      size: 4632
     outs:
     - path: data/preprocessed/wiki_pl/
-      md5: 0014b9bb52913cbc9a568d237ea2207b.dir
-      size: 65553079
+      md5: 066634606f832b6c9d1db95293de7e04.dir
+      size: 77818549
       nfiles: 3
   classify@wiki_pl:
     cmd: PYTHONPATH=. python experiments/scripts/classify.py --dataset_name wiki_pl
@@ -134,8 +134,8 @@ stages:
       size: 501711136
       nfiles: 7
     - path: data/preprocessed/wiki_pl/
-      md5: 3e9b2e1e0542777e0a751d9d7f7f4241.dir
-      size: 55380570
+      md5: 066634606f832b6c9d1db95293de7e04.dir
+      size: 77818549
       nfiles: 3
     - path: experiments/scripts/classify.py
       md5: 6fc1a6a0a11ba6cd99a8b6625a96d9f5
@@ -153,12 +153,12 @@ stages:
       size: 23460358
       nfiles: 3
     - path: experiments/scripts/tag_dataset.py
-      md5: f73e2203fdb988a00d4e8363a349c617
-      size: 3932
+      md5: ebadced7a031a31bdaf935d2b22e5e05
+      size: 4632
     outs:
     - path: data/preprocessed/20_news/
-      md5: 20da0980e52df537e5b7ca5db0305879.dir
-      size: 58582060
+      md5: a3d2da9ac72423e555ae7ed051741b30.dir
+      size: 69405970
       nfiles: 3
   classify@20_news:
     cmd: PYTHONPATH=. python experiments/scripts/classify.py --dataset_name 20_news
@@ -169,8 +169,8 @@ stages:
       size: 439008808
       nfiles: 6
     - path: data/preprocessed/20_news/
-      md5: 1ed5ef2dabe4bc05f7377175ed11137b.dir
-      size: 46845669
+      md5: a3d2da9ac72423e555ae7ed051741b30.dir
+      size: 69405970
       nfiles: 3
     - path: experiments/scripts/classify.py
       md5: 6fc1a6a0a11ba6cd99a8b6625a96d9f5
@@ -208,17 +208,17 @@ stages:
       size: 501711136
       nfiles: 7
     - path: data/preprocessed/wiki_pl
-      md5: 3e9b2e1e0542777e0a751d9d7f7f4241.dir
-      size: 55380570
+      md5: 066634606f832b6c9d1db95293de7e04.dir
+      size: 77818549
       nfiles: 3
     - path: experiments/scripts/explain.py
-      md5: afc02ef263a59c911098dea969faa932
-      size: 3234
+      md5: 4d82a557627f59c884f52fd7994ed80a
+      size: 4617
     outs:
     - path: data/explanations/wiki_pl/
-      md5: 5a3b9b069024456412078143e3af15d7.dir
-      size: 331450794
-      nfiles: 10065
+      md5: a11dc3a6b329a1d6bacd50c158ff3e6c.dir
+      size: 1096043911
+      nfiles: 100403
   explain@20_news:
     cmd: PYTHONPATH=. python experiments/scripts/explain.py --dataset_name 20_news
       --output_dir data/explanations/20_news
@@ -228,14 +228,14 @@ stages:
       size: 439008808
       nfiles: 6
     - path: data/preprocessed/20_news
-      md5: 1ed5ef2dabe4bc05f7377175ed11137b.dir
-      size: 46845669
+      md5: a3d2da9ac72423e555ae7ed051741b30.dir
+      size: 69405970
       nfiles: 3
     - path: experiments/scripts/explain.py
-      md5: afc02ef263a59c911098dea969faa932
-      size: 3234
+      md5: 4d82a557627f59c884f52fd7994ed80a
+      size: 4617
     outs:
     - path: data/explanations/20_news/
-      md5: c8ba90f9757a4e3cc4843d3791ef2446.dir
-      size: 232912969
-      nfiles: 14041
+      md5: ad1f9f0df287078edebed1e408df2c9f.dir
+      size: 869336544
+      nfiles: 140401
diff --git a/experiments/scripts/explain.py b/experiments/scripts/explain.py
index 92c85df..46120cd 100644
--- a/experiments/scripts/explain.py
+++ b/experiments/scripts/explain.py
@@ -8,6 +8,7 @@ import numpy as np
 import pandas as pd
 import shap
 import torch
+from tqdm import tqdm
 
 from text_attacks.utils import get_model_and_tokenizer
 
@@ -79,6 +80,14 @@ def main(
         f"data/preprocessed/{dataset_name}/adversarial.jsonl", lines=True
     )
     test_x = test["text"].tolist()
+    test_x = [
+        tokenizer.decode(
+            tokenizer.encode(
+                t, padding="do_not_pad", max_length=512, truncation=True
+            ),
+            skip_special_tokens=True
+        ) for t in test_x
+    ]
 
     predict = build_predict_fun(model, tokenizer)
     explainer = shap.Explainer(
@@ -100,14 +109,47 @@ def main(
 
     # LOCAL IMPORTANCE
     for class_id, class_name in model.config.id2label.items():
-        sub_dir = output_dir / "local" / class_name
+        sub_dir = output_dir / "local" / "adversarial" / class_name
         os.makedirs(sub_dir, exist_ok=True)
         for shap_id, text_id in enumerate(test["id"]):
             importance_df = get_importance(shap_values[shap_id, :, class_id])
             importance_df.to_json(
                 sub_dir / f"{text_id}__importance.json",
             )
+
+    # LOCAL IMPORTANCE (test set)
+    test = pd.read_json(
+        f"data/preprocessed/{dataset_name}/test.jsonl", lines=True
+    )
+    test_x = test["text"].tolist()
+    test_x = [
+        tokenizer.decode(
+            tokenizer.encode(
+                t, padding="do_not_pad", max_length=512, truncation=True
+            ),
+            skip_special_tokens=True
+        ) for t in test_x
+    ]
 
+    predict = build_predict_fun(model, tokenizer)
+    explainer = shap.Explainer(
+        predict,
+        masker=tokenizer,
+        output_names=list(model.config.id2label.values())
+    )
+    for text_id, text in tqdm(
+        zip(test["id"], test_x),
+        total=len(test_x),
+        desc="SHAP for test set",
+    ):
+        shap_values = explainer([text])
+        for class_id, class_name in model.config.id2label.items():
+            sub_dir = output_dir / "local" / "test" / class_name
+            os.makedirs(sub_dir, exist_ok=True)
+            importance_df = get_importance(shap_values[0, :, class_id])
+            importance_df.to_json(
+                sub_dir / f"{text_id}__importance.json",
+            )
 
 if __name__ == "__main__":
     main()
diff --git a/experiments/scripts/tag_dataset.py b/experiments/scripts/tag_dataset.py
index 911983a..2ecc511 100644
--- a/experiments/scripts/tag_dataset.py
+++ b/experiments/scripts/tag_dataset.py
@@ -6,6 +6,8 @@ import json
 import os
 from tqdm import tqdm
 from multiprocessing import cpu_count, Pool
+import spacy
+
 
 TOKENS = "tokens"
 ORTH = "orth"
@@ -54,7 +56,7 @@ def tag_sentence(sentence: str, lang: str):
 def process_file(dataset_df, lang, output_path):
     test_with_tags = pd.DataFrame(dataset_df)
     lemmas_col, tags_col, orth_col = [], [], []
-    cpus = 8
+    cpus = 2
     with Pool(processes=cpus) as pool:
         results = []
         for idx in tqdm(range(0, len(dataset_df), cpus)):
@@ -73,8 +75,28 @@ def process_file(dataset_df, lang, output_path):
     test_with_tags[TAGS] = tags_col
     test_with_tags[ORTHS] = orth_col
 
-    with open(output_path, mode="wt") as fd:
-        fd.write(test_with_tags.to_json(orient="records", lines=True))
+    return test_with_tags
+
+
+def add_ner(dataset_df, language):
+    model = "en_core_web_trf" if language == "en" else "pl_core_news_lg"
+    nlp = spacy.load(model)
+    ner_data = list()
+
+    for text in tqdm(dataset_df["text"]):
+        doc = nlp(text)
+        doc_ner = list()
+        for ent in doc.ents:
+            doc_ner.append({
+                "text": ent.text,
+                "start_char": ent.start_char,
+                "end_char": ent.end_char,
+                "label": ent.label_,
+            })
+        ner_data.append(doc_ner)
+
+    dataset_df["ner"] = ner_data
+    return dataset_df
 
 
 @click.command()
@@ -83,7 +105,13 @@ def process_file(dataset_df, lang, output_path):
     help="Dataset name",
     type=str,
 )
-def main(dataset_name: str):
+@click.option(
+    "--output",
+    help="Output directory",
+    type=str,
+    default="data/preprocessed",
+)
+def main(dataset_name: str, output: str):
     """Downloads the dataset to the output directory."""
     lang = {
         "enron_spam": "en",
@@ -91,18 +119,19 @@ def main(dataset_name: str):
         "20_news": "en",
         "wiki_pl": "pl",
     }[dataset_name]
-    output_dir = f"data/preprocessed/{dataset_name}"
+    output_dir = f"{output}/{dataset_name}"
     os.makedirs(output_dir, exist_ok=True)
 
     input_dir = f"data/datasets/{dataset_name}"
     for file in os.listdir(input_dir):
         if os.path.isfile(os.path.join(input_dir, file)):
-            if file == "test.jsonl":
-                process_file(
+            if file in ["test.jsonl", "adversarial.jsonl"]:
+                test_with_tags = process_file(
                     pd.read_json(os.path.join(input_dir, file), lines=True),
                     lang,
                     os.path.join(output_dir, file),
                 )
+                test_with_tags = add_ner(test_with_tags, lang)
             else:
                 test_with_tags = pd.DataFrame(
                     pd.read_json(os.path.join(input_dir, file), lines=True)
@@ -111,10 +140,11 @@ def main(dataset_name: str):
                 test_with_tags[LEMMAS] = empty_list
                 test_with_tags[TAGS] = empty_list
                 test_with_tags[ORTHS] = empty_list
-                with open(os.path.join(output_dir, file), mode="wt") as fd:
-                    fd.write(
-                        test_with_tags.to_json(orient="records", lines=True)
-                    )
+                test_with_tags["ner"] = empty_list
+            with open(os.path.join(output_dir, file), mode="wt") as fd:
+                fd.write(
+                    test_with_tags.to_json(orient="records", lines=True)
+                )
 
 
 if __name__ == "__main__":
diff --git a/requirements.txt b/requirements.txt
index 8e069b3..bc59ed1 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -9,6 +9,9 @@ tokenizers==0.13.2
 sentence-transformers==2.2.2
 jupyter
 matplotlib
+spacy
+https://github.com/explosion/spacy-models/releases/download/en_core_web_trf-3.5.0/en_core_web_trf-3.5.0-py3-none-any.whl
+https://github.com/explosion/spacy-models/releases/download/pl_core_news_lg-3.5.0/pl_core_news_lg-3.5.0-py3-none-any.whl
 
 --find-links https://download.pytorch.org/whl/torch_stable.html
 torch==1.12.0+cu116
-- 
GitLab