Commit 6139763f authored by Mateusz Gniewkowski

Merge branch 'xai_test' into 'master'

Xai test

See merge request adversarial-attacks/text-attacks!4
parents 2ac1bad2 9bdaf9cb
dvc.lock
@@ -17,8 +17,8 @@ stages:
--output_dir data/models/enron_spam
deps:
- path: data/preprocessed/enron_spam
md5: b75efba1a62182dc8ac32acd1faf92ed.dir
size: 61709260
md5: 99d604f84516cee94948054a97ffec5e.dir
size: 71403809
nfiles: 3
- path: experiments/scripts/get_model.py
md5: 5050f51b4019bba97af47971f6c7cab4
@@ -37,8 +37,8 @@ stages:
size: 18505614
nfiles: 6
- path: data/preprocessed/enron_spam/
md5: b75efba1a62182dc8ac32acd1faf92ed.dir
size: 61709260
md5: 99d604f84516cee94948054a97ffec5e.dir
size: 71403809
nfiles: 3
- path: experiments/scripts/classify.py
md5: 6fc1a6a0a11ba6cd99a8b6625a96d9f5
@@ -57,17 +57,17 @@ stages:
size: 18505614
nfiles: 6
- path: data/preprocessed/enron_spam
md5: b75efba1a62182dc8ac32acd1faf92ed.dir
size: 61709260
md5: 99d604f84516cee94948054a97ffec5e.dir
size: 71403809
nfiles: 3
- path: experiments/scripts/explain.py
md5: afc02ef263a59c911098dea969faa932
size: 3234
md5: 4d82a557627f59c884f52fd7994ed80a
size: 4617
outs:
- path: data/explanations/enron_spam/
md5: 70c3f3d04e0b73fd56eecfda04914bd4.dir
size: 13589794
nfiles: 403
md5: 4a4630717bf6bf5027f1b70191c69a4c.dir
size: 33723647
nfiles: 4003
download_dataset@poleval:
cmd: PYTHONPATH=. python experiments/scripts/download_dataset.py --dataset_name
poleval --output_dir data/datasets/poleval
@@ -88,12 +88,12 @@ stages:
size: 1688836
nfiles: 3
- path: experiments/scripts/tag_dataset.py
md5: f73e2203fdb988a00d4e8363a349c617
size: 3932
md5: ebadced7a031a31bdaf935d2b22e5e05
size: 4632
outs:
- path: data/preprocessed/poleval/
md5: 9d067db65ba6a27db19effce45b01876.dir
size: 2541105
md5: b0ea9f0ad1dba6d3b474c0a3cedf866e.dir
size: 2812175
nfiles: 3
preprocess_dataset@enron_spam:
cmd: PYTHONPATH=. python experiments/scripts/tag_dataset.py --dataset_name enron_spam
@@ -103,12 +103,12 @@ stages:
size: 53096069
nfiles: 3
- path: experiments/scripts/tag_dataset.py
md5: f73e2203fdb988a00d4e8363a349c617
size: 3932
md5: ebadced7a031a31bdaf935d2b22e5e05
size: 4632
outs:
- path: data/preprocessed/enron_spam/
md5: 30c63efbc615347ddcb5f61e011113bd.dir
size: 65971374
md5: 99d604f84516cee94948054a97ffec5e.dir
size: 71403809
nfiles: 3
preprocess_dataset@wiki_pl:
cmd: PYTHONPATH=. python experiments/scripts/tag_dataset.py --dataset_name wiki_pl
@@ -118,12 +118,12 @@ stages:
size: 29115538
nfiles: 3
- path: experiments/scripts/tag_dataset.py
md5: f73e2203fdb988a00d4e8363a349c617
size: 3932
md5: ebadced7a031a31bdaf935d2b22e5e05
size: 4632
outs:
- path: data/preprocessed/wiki_pl/
md5: 0014b9bb52913cbc9a568d237ea2207b.dir
size: 65553079
md5: 066634606f832b6c9d1db95293de7e04.dir
size: 77818549
nfiles: 3
classify@wiki_pl:
cmd: PYTHONPATH=. python experiments/scripts/classify.py --dataset_name wiki_pl
@@ -134,8 +134,8 @@ stages:
size: 501711136
nfiles: 7
- path: data/preprocessed/wiki_pl/
md5: 3e9b2e1e0542777e0a751d9d7f7f4241.dir
size: 55380570
md5: 066634606f832b6c9d1db95293de7e04.dir
size: 77818549
nfiles: 3
- path: experiments/scripts/classify.py
md5: 6fc1a6a0a11ba6cd99a8b6625a96d9f5
@@ -153,12 +153,12 @@ stages:
size: 23460358
nfiles: 3
- path: experiments/scripts/tag_dataset.py
md5: f73e2203fdb988a00d4e8363a349c617
size: 3932
md5: ebadced7a031a31bdaf935d2b22e5e05
size: 4632
outs:
- path: data/preprocessed/20_news/
md5: 20da0980e52df537e5b7ca5db0305879.dir
size: 58582060
md5: a3d2da9ac72423e555ae7ed051741b30.dir
size: 69405970
nfiles: 3
classify@20_news:
cmd: PYTHONPATH=. python experiments/scripts/classify.py --dataset_name 20_news
@@ -169,8 +169,8 @@ stages:
size: 439008808
nfiles: 6
- path: data/preprocessed/20_news/
md5: 1ed5ef2dabe4bc05f7377175ed11137b.dir
size: 46845669
md5: a3d2da9ac72423e555ae7ed051741b30.dir
size: 69405970
nfiles: 3
- path: experiments/scripts/classify.py
md5: 6fc1a6a0a11ba6cd99a8b6625a96d9f5
@@ -208,17 +208,17 @@ stages:
size: 501711136
nfiles: 7
- path: data/preprocessed/wiki_pl
md5: 3e9b2e1e0542777e0a751d9d7f7f4241.dir
size: 55380570
md5: 066634606f832b6c9d1db95293de7e04.dir
size: 77818549
nfiles: 3
- path: experiments/scripts/explain.py
md5: afc02ef263a59c911098dea969faa932
size: 3234
md5: 4d82a557627f59c884f52fd7994ed80a
size: 4617
outs:
- path: data/explanations/wiki_pl/
md5: 5a3b9b069024456412078143e3af15d7.dir
size: 331450794
nfiles: 10065
md5: a11dc3a6b329a1d6bacd50c158ff3e6c.dir
size: 1096043911
nfiles: 100403
explain@20_news:
cmd: PYTHONPATH=. python experiments/scripts/explain.py --dataset_name 20_news
--output_dir data/explanations/20_news
@@ -228,14 +228,14 @@ stages:
size: 439008808
nfiles: 6
- path: data/preprocessed/20_news
md5: 1ed5ef2dabe4bc05f7377175ed11137b.dir
size: 46845669
md5: a3d2da9ac72423e555ae7ed051741b30.dir
size: 69405970
nfiles: 3
- path: experiments/scripts/explain.py
md5: afc02ef263a59c911098dea969faa932
size: 3234
md5: 4d82a557627f59c884f52fd7994ed80a
size: 4617
outs:
- path: data/explanations/20_news/
md5: c8ba90f9757a4e3cc4843d3791ef2446.dir
size: 232912969
nfiles: 14041
md5: ad1f9f0df287078edebed1e408df2c9f.dir
size: 869336544
nfiles: 140401
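The checksum and size changes above are regenerated DVC outputs. The explanation directories grow by roughly an order of magnitude in file count (403 to 4003 files for enron_spam, 10065 to 100403 for wiki_pl, 14041 to 140401 for 20_news), presumably because explain.py now also writes per-text importances for the full test split rather than only the adversarial examples, while the preprocessed datasets grow because tag_dataset.py now stores NER annotations alongside the lemmas, tags and orths.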
experiments/scripts/explain.py
@@ -8,6 +8,7 @@ import numpy as np
import pandas as pd
import shap
import torch
from tqdm import tqdm
from text_attacks.utils import get_model_and_tokenizer
@@ -79,6 +80,14 @@ def main(
f"data/preprocessed/{dataset_name}/adversarial.jsonl", lines=True
)
test_x = test["text"].tolist()
test_x = [
tokenizer.decode(
tokenizer.encode(
t, padding="do_not_pad", max_length=512, truncation=True
),
skip_special_tokens=True
) for t in test_x
]
predict = build_predict_fun(model, tokenizer)
explainer = shap.Explainer(
@@ -100,14 +109,47 @@
# LOCAL IMPORTANCE
for class_id, class_name in model.config.id2label.items():
sub_dir = output_dir / "local" / class_name
sub_dir = output_dir / "local" / "adversarial" /class_name
os.makedirs(sub_dir, exist_ok=True)
for shap_id, text_id in enumerate(test["id"]):
importance_df = get_importance(shap_values[shap_id, :, class_id])
importance_df.to_json(
sub_dir / f"{text_id}__importance.json",
)
# LOCAL IMPORTANCE (test set)
test = pd.read_json(
f"data/preprocessed/{dataset_name}/test.jsonl", lines=True
)
test_x = test["text"].tolist()
test_x = [
tokenizer.decode(
tokenizer.encode(
t, padding="do_not_pad", max_length=512, truncation=True
),
skip_special_tokens=True
) for t in test_x
]
predict = build_predict_fun(model, tokenizer)
explainer = shap.Explainer(
predict,
masker=tokenizer,
output_names=list(model.config.id2label.values())
)
for text_id, text in tqdm(
zip(test["id"], test_x),
total=len(test_x),
desc="Shap for test DS",
):
shap_values = explainer([text])
for class_id, class_name in model.config.id2label.items():
sub_dir = output_dir / "local" / "test" / class_name
os.makedirs(sub_dir, exist_ok=True)
importance_df = get_importance(shap_values[0, :, class_id])
importance_df.to_json(
sub_dir / f"{text_id}__importance.json",
)
if __name__ == "__main__":
main()
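For orientation, the changes above do two things: they truncate every input to the model's 512-token window before it reaches SHAP (encode with truncation, then decode without special tokens), and they repeat the local-importance loop for the plain test split in addition to the adversarial set, writing one <text_id>__importance.json per text and class. The sketch below condenses that per-text flow; build_predict_fun and get_importance are simplified stand-ins for helpers that already exist in explain.py but whose bodies are outside this diff.

# Minimal sketch of the per-text SHAP flow used above. build_predict_fun and
# get_importance are simplified stand-ins, not the project's actual helpers.
import pandas as pd
import shap
import torch


def build_predict_fun(model, tokenizer):
    def predict(texts):
        # Assumes the model is on CPU and is a sequence-classification head.
        enc = tokenizer(
            list(texts), padding=True, truncation=True,
            max_length=512, return_tensors="pt",
        )
        with torch.no_grad():
            logits = model(**enc).logits
        return torch.softmax(logits, dim=-1).numpy()
    return predict


def get_importance(shap_slice):
    # One row per token: the token string and its SHAP value for one class.
    return pd.DataFrame(
        {"token": shap_slice.data, "importance": shap_slice.values}
    )


def explain_one(text, model, tokenizer):
    # Truncate to the model window exactly as the diff does.
    text = tokenizer.decode(
        tokenizer.encode(text, max_length=512, truncation=True),
        skip_special_tokens=True,
    )
    explainer = shap.Explainer(
        build_predict_fun(model, tokenizer),
        masker=tokenizer,
        output_names=list(model.config.id2label.values()),
    )
    shap_values = explainer([text])
    return {
        class_name: get_importance(shap_values[0, :, class_id])
        for class_id, class_name in model.config.id2label.items()
    }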
experiments/scripts/tag_dataset.py
@@ -6,6 +6,8 @@ import json
import os
from tqdm import tqdm
from multiprocessing import cpu_count, Pool
import spacy
TOKENS = "tokens"
ORTH = "orth"
@@ -54,7 +56,7 @@ def tag_sentence(sentence: str, lang: str):
def process_file(dataset_df, lang, output_path):
test_with_tags = pd.DataFrame(dataset_df)
lemmas_col, tags_col, orth_col = [], [], []
cpus = 8
cpus = 2
with Pool(processes=cpus) as pool:
results = []
for idx in tqdm(range(0, len(dataset_df), cpus)):
@@ -73,8 +75,28 @@ def process_file(dataset_df, lang, output_path):
test_with_tags[TAGS] = tags_col
test_with_tags[ORTHS] = orth_col
with open(output_path, mode="wt") as fd:
fd.write(test_with_tags.to_json(orient="records", lines=True))
return test_with_tags
def add_ner(dataset_df, language):
model = "en_core_web_trf" if language == "en" else "pl_core_news_lg"
nlp = spacy.load(model)
ner_data = list()
for text in tqdm(dataset_df["text"]):
doc = nlp(text)
doc_ner = list()
for ent in doc.ents:
doc_ner.append({
"text": ent.text,
"start_char": ent.start_char,
"end_char": ent.end_char,
"label": ent.label_,
})
ner_data.append(doc_ner)
dataset_df["ner"] = ner_data
return dataset_df
@click.command()
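The add_ner helper introduced above records, for every document, a list of entity spans as plain dicts (surface form, character offsets, label) in a new "ner" column. A small usage sketch, assuming the spaCy model named in the diff (en_core_web_trf for English) has been downloaded beforehand; the example text and printed entities are illustrative only.

# Illustrative use of add_ner; needs: python -m spacy download en_core_web_trf
import pandas as pd

df = pd.DataFrame({"text": ["Apple opened a new office in Warsaw in 2023."]})
df = add_ner(df, "en")
print(df.loc[0, "ner"])
# Roughly: [{'text': 'Apple', 'start_char': 0, 'end_char': 5, 'label': 'ORG'},
#           {'text': 'Warsaw', 'start_char': 29, 'end_char': 35, 'label': 'GPE'}, ...]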
@@ -83,7 +105,13 @@ def process_file(dataset_df, lang, output_path):
help="Dataset name",
type=str,
)
def main(dataset_name: str):
@click.option(
"--output",
help="Output directory",
type=str,
)
def main(dataset_name: str, output: str):
"""Downloads the dataset to the output directory."""
lang = {
"enron_spam": "en",
@@ -91,18 +119,19 @@ def main(dataset_name: str):
"20_news": "en",
"wiki_pl": "pl",
}[dataset_name]
output_dir = f"data/preprocessed/{dataset_name}"
output_dir = f"{output}/{dataset_name}"
os.makedirs(output_dir, exist_ok=True)
input_dir = f"data/datasets/{dataset_name}"
for file in os.listdir(input_dir):
if os.path.isfile(os.path.join(input_dir, file)):
if file == "test.jsonl":
process_file(
if file in ["test.jsonl", "adversarial.jsonl"]:
test_with_tags = process_file(
pd.read_json(os.path.join(input_dir, file), lines=True),
lang,
os.path.join(output_dir, file),
)
test_with_tags = add_ner(test_with_tags, lang)
else:
test_with_tags = pd.DataFrame(
pd.read_json(os.path.join(input_dir, file), lines=True)
@@ -111,10 +140,11 @@ def main(dataset_name: str):
test_with_tags[LEMMAS] = empty_list
test_with_tags[TAGS] = empty_list
test_with_tags[ORTHS] = empty_list
with open(os.path.join(output_dir, file), mode="wt") as fd:
fd.write(
test_with_tags.to_json(orient="records", lines=True)
)
test_with_tags["ner"] = empty_list
with open(os.path.join(output_dir, file), mode="wt") as fd:
fd.write(
test_with_tags.to_json(orient="records", lines=True)
)
if __name__ == "__main__":
......
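Usage note: with the new --output option, the tagging step would presumably be invoked along the lines of PYTHONPATH=. python experiments/scripts/tag_dataset.py --dataset_name enron_spam --output data/preprocessed; the value data/preprocessed is an assumption here, chosen so that the resulting {output}/{dataset_name} path matches the previously hard-coded data/preprocessed/{dataset_name} directory.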