Commit 6139763f authored by Mateusz Gniewkowski

Merge branch 'xai_test' into 'master'

Xai test

See merge request adversarial-attacks/text-attacks!4
parents 2ac1bad2 9bdaf9cb
diff --git a/dvc.lock b/dvc.lock
@@ -17,8 +17,8 @@ stages:
       --output_dir data/models/enron_spam
     deps:
     - path: data/preprocessed/enron_spam
-      md5: b75efba1a62182dc8ac32acd1faf92ed.dir
-      size: 61709260
+      md5: 99d604f84516cee94948054a97ffec5e.dir
+      size: 71403809
       nfiles: 3
     - path: experiments/scripts/get_model.py
       md5: 5050f51b4019bba97af47971f6c7cab4
@@ -37,8 +37,8 @@ stages:
       size: 18505614
       nfiles: 6
     - path: data/preprocessed/enron_spam/
-      md5: b75efba1a62182dc8ac32acd1faf92ed.dir
-      size: 61709260
+      md5: 99d604f84516cee94948054a97ffec5e.dir
+      size: 71403809
       nfiles: 3
     - path: experiments/scripts/classify.py
       md5: 6fc1a6a0a11ba6cd99a8b6625a96d9f5
@@ -57,17 +57,17 @@ stages:
       size: 18505614
       nfiles: 6
     - path: data/preprocessed/enron_spam
-      md5: b75efba1a62182dc8ac32acd1faf92ed.dir
-      size: 61709260
+      md5: 99d604f84516cee94948054a97ffec5e.dir
+      size: 71403809
       nfiles: 3
     - path: experiments/scripts/explain.py
-      md5: afc02ef263a59c911098dea969faa932
-      size: 3234
+      md5: 4d82a557627f59c884f52fd7994ed80a
+      size: 4617
     outs:
     - path: data/explanations/enron_spam/
-      md5: 70c3f3d04e0b73fd56eecfda04914bd4.dir
-      size: 13589794
-      nfiles: 403
+      md5: 4a4630717bf6bf5027f1b70191c69a4c.dir
+      size: 33723647
+      nfiles: 4003
   download_dataset@poleval:
     cmd: PYTHONPATH=. python experiments/scripts/download_dataset.py --dataset_name
       poleval --output_dir data/datasets/poleval
@@ -88,12 +88,12 @@ stages:
       size: 1688836
       nfiles: 3
     - path: experiments/scripts/tag_dataset.py
-      md5: f73e2203fdb988a00d4e8363a349c617
-      size: 3932
+      md5: ebadced7a031a31bdaf935d2b22e5e05
+      size: 4632
     outs:
     - path: data/preprocessed/poleval/
-      md5: 9d067db65ba6a27db19effce45b01876.dir
-      size: 2541105
+      md5: b0ea9f0ad1dba6d3b474c0a3cedf866e.dir
+      size: 2812175
       nfiles: 3
   preprocess_dataset@enron_spam:
     cmd: PYTHONPATH=. python experiments/scripts/tag_dataset.py --dataset_name enron_spam
@@ -103,12 +103,12 @@ stages:
       size: 53096069
       nfiles: 3
     - path: experiments/scripts/tag_dataset.py
-      md5: f73e2203fdb988a00d4e8363a349c617
-      size: 3932
+      md5: ebadced7a031a31bdaf935d2b22e5e05
+      size: 4632
     outs:
     - path: data/preprocessed/enron_spam/
-      md5: 30c63efbc615347ddcb5f61e011113bd.dir
-      size: 65971374
+      md5: 99d604f84516cee94948054a97ffec5e.dir
+      size: 71403809
       nfiles: 3
   preprocess_dataset@wiki_pl:
     cmd: PYTHONPATH=. python experiments/scripts/tag_dataset.py --dataset_name wiki_pl
@@ -118,12 +118,12 @@ stages:
       size: 29115538
       nfiles: 3
     - path: experiments/scripts/tag_dataset.py
-      md5: f73e2203fdb988a00d4e8363a349c617
-      size: 3932
+      md5: ebadced7a031a31bdaf935d2b22e5e05
+      size: 4632
     outs:
     - path: data/preprocessed/wiki_pl/
-      md5: 0014b9bb52913cbc9a568d237ea2207b.dir
-      size: 65553079
+      md5: 066634606f832b6c9d1db95293de7e04.dir
+      size: 77818549
       nfiles: 3
   classify@wiki_pl:
     cmd: PYTHONPATH=. python experiments/scripts/classify.py --dataset_name wiki_pl
@@ -134,8 +134,8 @@ stages:
       size: 501711136
       nfiles: 7
     - path: data/preprocessed/wiki_pl/
-      md5: 3e9b2e1e0542777e0a751d9d7f7f4241.dir
-      size: 55380570
+      md5: 066634606f832b6c9d1db95293de7e04.dir
+      size: 77818549
       nfiles: 3
     - path: experiments/scripts/classify.py
       md5: 6fc1a6a0a11ba6cd99a8b6625a96d9f5
@@ -153,12 +153,12 @@ stages:
      size: 23460358
       nfiles: 3
     - path: experiments/scripts/tag_dataset.py
-      md5: f73e2203fdb988a00d4e8363a349c617
-      size: 3932
+      md5: ebadced7a031a31bdaf935d2b22e5e05
+      size: 4632
     outs:
     - path: data/preprocessed/20_news/
-      md5: 20da0980e52df537e5b7ca5db0305879.dir
-      size: 58582060
+      md5: a3d2da9ac72423e555ae7ed051741b30.dir
+      size: 69405970
       nfiles: 3
   classify@20_news:
     cmd: PYTHONPATH=. python experiments/scripts/classify.py --dataset_name 20_news
@@ -169,8 +169,8 @@ stages:
       size: 439008808
       nfiles: 6
     - path: data/preprocessed/20_news/
-      md5: 1ed5ef2dabe4bc05f7377175ed11137b.dir
-      size: 46845669
+      md5: a3d2da9ac72423e555ae7ed051741b30.dir
+      size: 69405970
       nfiles: 3
     - path: experiments/scripts/classify.py
       md5: 6fc1a6a0a11ba6cd99a8b6625a96d9f5
@@ -208,17 +208,17 @@ stages:
       size: 501711136
       nfiles: 7
     - path: data/preprocessed/wiki_pl
-      md5: 3e9b2e1e0542777e0a751d9d7f7f4241.dir
-      size: 55380570
+      md5: 066634606f832b6c9d1db95293de7e04.dir
+      size: 77818549
       nfiles: 3
     - path: experiments/scripts/explain.py
-      md5: afc02ef263a59c911098dea969faa932
-      size: 3234
+      md5: 4d82a557627f59c884f52fd7994ed80a
+      size: 4617
     outs:
     - path: data/explanations/wiki_pl/
-      md5: 5a3b9b069024456412078143e3af15d7.dir
-      size: 331450794
-      nfiles: 10065
+      md5: a11dc3a6b329a1d6bacd50c158ff3e6c.dir
+      size: 1096043911
+      nfiles: 100403
   explain@20_news:
     cmd: PYTHONPATH=. python experiments/scripts/explain.py --dataset_name 20_news
       --output_dir data/explanations/20_news
@@ -228,14 +228,14 @@ stages:
       size: 439008808
       nfiles: 6
     - path: data/preprocessed/20_news
-      md5: 1ed5ef2dabe4bc05f7377175ed11137b.dir
-      size: 46845669
+      md5: a3d2da9ac72423e555ae7ed051741b30.dir
+      size: 69405970
       nfiles: 3
     - path: experiments/scripts/explain.py
-      md5: afc02ef263a59c911098dea969faa932
-      size: 3234
+      md5: 4d82a557627f59c884f52fd7994ed80a
+      size: 4617
     outs:
     - path: data/explanations/20_news/
-      md5: c8ba90f9757a4e3cc4843d3791ef2446.dir
-      size: 232912969
-      nfiles: 14041
+      md5: ad1f9f0df287078edebed1e408df2c9f.dir
+      size: 869336544
+      nfiles: 140401
diff --git a/experiments/scripts/explain.py b/experiments/scripts/explain.py
@@ -8,6 +8,7 @@ import numpy as np
 import pandas as pd
 import shap
 import torch
+from tqdm import tqdm

 from text_attacks.utils import get_model_and_tokenizer

@@ -79,6 +80,14 @@ def main(
         f"data/preprocessed/{dataset_name}/adversarial.jsonl", lines=True
     )
     test_x = test["text"].tolist()
+    test_x = [
+        tokenizer.decode(
+            tokenizer.encode(
+                t, padding="do_not_pad", max_length=512, truncation=True
+            ),
+            skip_special_tokens=True
+        ) for t in test_x
+    ]

     predict = build_predict_fun(model, tokenizer)
     explainer = shap.Explainer(
@@ -100,7 +109,7 @@ def main(

     # LOCAL IMPORTANCE
     for class_id, class_name in model.config.id2label.items():
-        sub_dir = output_dir / "local" / class_name
+        sub_dir = output_dir / "local" / "adversarial" / class_name
         os.makedirs(sub_dir, exist_ok=True)
         for shap_id, text_id in enumerate(test["id"]):
             importance_df = get_importance(shap_values[shap_id, :, class_id])
@@ -108,6 +117,39 @@ def main(
                 sub_dir / f"{text_id}__importance.json",
             )

+    # LOCAL IMPORTANCE (test set)
+    test = pd.read_json(
+        f"data/preprocessed/{dataset_name}/test.jsonl", lines=True
+    )
+    test_x = test["text"].tolist()
+    test_x = [
+        tokenizer.decode(
+            tokenizer.encode(
+                t, padding="do_not_pad", max_length=512, truncation=True
+            ),
+            skip_special_tokens=True
+        ) for t in test_x
+    ]
+    predict = build_predict_fun(model, tokenizer)
+    explainer = shap.Explainer(
+        predict,
+        masker=tokenizer,
+        output_names=list(model.config.id2label.values())
+    )
+    for text_id, text in tqdm(
+        zip(test["id"], test_x),
+        total=len(test_x),
+        desc="Shap for test DS",
+    ):
+        shap_values = explainer([text])
+        for class_id, class_name in model.config.id2label.items():
+            sub_dir = output_dir / "local" / "test" / class_name
+            os.makedirs(sub_dir, exist_ok=True)
+            importance_df = get_importance(shap_values[0, :, class_id])
+            importance_df.to_json(
+                sub_dir / f"{text_id}__importance.json",
+            )

 if __name__ == "__main__":
     main()
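Note on the explain.py changes above: they follow a common SHAP-for-transformers pattern. Each text is round-tripped through tokenizer.encode/tokenizer.decode so the string SHAP perturbs tokenizes back to exactly the (at most 512-token) input the model actually scores, and the explainer is then run one document at a time, which keeps memory flat at the output sizes now recorded in dvc.lock. Below is a minimal, self-contained sketch of that pattern; the checkpoint name and sample texts are placeholder assumptions, not values taken from this repository.

import numpy as np
import shap
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Assumed stand-in checkpoint, not the repository's model:
checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSequenceClassification.from_pretrained(checkpoint)
model.eval()

def predict(texts):
    # Return class probabilities for a batch of strings, as SHAP expects.
    enc = tokenizer(list(texts), padding=True, truncation=True,
                    max_length=512, return_tensors="pt")
    with torch.no_grad():
        logits = model(**enc).logits
    return torch.softmax(logits, dim=-1).numpy()

texts = ["A genuinely moving film.", "Two hours of my life I want back."]
# Encode/decode round trip: the string handed to SHAP now tokenizes to the
# same truncated sequence the model scores.
texts = [
    tokenizer.decode(
        tokenizer.encode(t, max_length=512, truncation=True),
        skip_special_tokens=True,
    )
    for t in texts
]

explainer = shap.Explainer(
    predict,
    masker=tokenizer,  # the HF tokenizer doubles as SHAP's text masker
    output_names=list(model.config.id2label.values()),
)
for text in texts:
    shap_values = explainer([text])  # one document at a time bounds memory
    # Per-token attributions for class 1 of this document:
    print(list(zip(shap_values[0].data,
                   np.round(shap_values[0].values[:, 1], 3))))

Explaining documents individually trades batching speed for predictable memory use, which matters when the explanation outputs run to the ~100k files listed in the lock file.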
diff --git a/experiments/scripts/tag_dataset.py b/experiments/scripts/tag_dataset.py
@@ -6,6 +6,8 @@ import json
 import os
 from tqdm import tqdm
 from multiprocessing import cpu_count, Pool
+import spacy

 TOKENS = "tokens"
 ORTH = "orth"
@@ -54,7 +56,7 @@ def tag_sentence(sentence: str, lang: str):
 def process_file(dataset_df, lang, output_path):
     test_with_tags = pd.DataFrame(dataset_df)
     lemmas_col, tags_col, orth_col = [], [], []
-    cpus = 8
+    cpus = 2
     with Pool(processes=cpus) as pool:
         results = []
         for idx in tqdm(range(0, len(dataset_df), cpus)):
@@ -73,8 +75,28 @@ def process_file(dataset_df, lang, output_path):
     test_with_tags[TAGS] = tags_col
     test_with_tags[ORTHS] = orth_col
-    with open(output_path, mode="wt") as fd:
-        fd.write(test_with_tags.to_json(orient="records", lines=True))
+    return test_with_tags
+
+
+def add_ner(dataset_df, language):
+    model = "en_core_web_trf" if language == "en" else "pl_core_news_lg"
+    nlp = spacy.load(model)
+    ner_data = list()
+    for text in tqdm(dataset_df["text"]):
+        doc = nlp(text)
+        doc_ner = list()
+        for ent in doc.ents:
+            doc_ner.append({
+                "text": ent.text,
+                "start_char": ent.start_char,
+                "end_char": ent.end_char,
+                "label": ent.label_,
+            })
+        ner_data.append(doc_ner)
+    dataset_df["ner"] = ner_data
+    return dataset_df


 @click.command()
@@ -83,7 +105,13 @@ def process_file(dataset_df, lang, output_path):
     help="Dataset name",
     type=str,
 )
-def main(dataset_name: str):
+@click.option(
+    "--output",
+    help="Output directory",
+    type=str,
+)
+def main(dataset_name: str, output: str):
     """Downloads the dataset to the output directory."""
     lang = {
         "enron_spam": "en",
@@ -91,18 +119,19 @@ def main(dataset_name: str):
         "20_news": "en",
         "wiki_pl": "pl",
     }[dataset_name]
-    output_dir = f"data/preprocessed/{dataset_name}"
+    output_dir = f"{output}/{dataset_name}"
     os.makedirs(output_dir, exist_ok=True)
     input_dir = f"data/datasets/{dataset_name}"
     for file in os.listdir(input_dir):
         if os.path.isfile(os.path.join(input_dir, file)):
-            if file == "test.jsonl":
-                process_file(
+            if file in ["test.jsonl", "adversarial.jsonl"]:
+                test_with_tags = process_file(
                     pd.read_json(os.path.join(input_dir, file), lines=True),
                     lang,
                     os.path.join(output_dir, file),
                 )
+                test_with_tags = add_ner(test_with_tags, lang)
             else:
                 test_with_tags = pd.DataFrame(
                     pd.read_json(os.path.join(input_dir, file), lines=True)
@@ -111,6 +140,7 @@ def main(dataset_name: str):
                 test_with_tags[LEMMAS] = empty_list
                 test_with_tags[TAGS] = empty_list
                 test_with_tags[ORTHS] = empty_list
+                test_with_tags["ner"] = empty_list
             with open(os.path.join(output_dir, file), mode="wt") as fd:
                 fd.write(
                     test_with_tags.to_json(orient="records", lines=True)
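Note on the new add_ner step: it attaches, for every row, a JSON-serializable list of spaCy entities with character offsets, which downstream attack code can use to locate and preserve named entities in the original string. A minimal sketch of the same pattern follows, using the small English model as a lightweight stand-in (an assumption; the requirements below pin en_core_web_trf and pl_core_news_lg).

import spacy

# Assumed lightweight stand-in for the models pinned in requirements.txt:
nlp = spacy.load("en_core_web_sm")

texts = ["Apple is buying a London startup for $1 billion."]
ner_data = []
for doc in nlp.pipe(texts):  # nlp.pipe batches texts more efficiently than calling nlp() per text
    ner_data.append([
        {
            "text": ent.text,
            "start_char": ent.start_char,
            "end_char": ent.end_char,
            "label": ent.label_,
        }
        for ent in doc.ents
    ])

print(ner_data[0])
# e.g. [{'text': 'Apple', 'start_char': 0, 'end_char': 5, 'label': 'ORG'}, ...]

Storing start_char/end_char rather than token indices keeps the annotation valid no matter how the attack code later re-tokenizes the text.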
diff --git a/requirements.txt b/requirements.txt
@@ -9,6 +9,9 @@ tokenizers==0.13.2
 sentence-transformers==2.2.2
 jupyter
 matplotlib
+spacy
+https://github.com/explosion/spacy-models/releases/download/en_core_web_trf-3.5.0/en_core_web_trf-3.5.0-py3-none-any.whl
+https://github.com/explosion/spacy-models/releases/download/pl_core_news_lg-3.5.0/pl_core_news_lg-3.5.0-py3-none-any.whl

 --find-links https://download.pytorch.org/whl/torch_stable.html
 torch==1.12.0+cu116