From cc33cb84ac9d352d1844d5b4d21870db06744b22 Mon Sep 17 00:00:00 2001 From: pwalkow <pwalkow@gpu-server.ws.clarin> Date: Fri, 10 Mar 2023 13:19:43 +0100 Subject: [PATCH] Add textfooler req --- dvc.lock | 14 ++++---- dvc.yaml | 8 ++--- experiments/scripts/attack.py | 61 +++++++++++++++++++++++++++++++++++ requirements.txt | 4 +++ 4 files changed, 76 insertions(+), 11 deletions(-) create mode 100644 experiments/scripts/attack.py diff --git a/dvc.lock b/dvc.lock index b3d2e13..346ede5 100644 --- a/dvc.lock +++ b/dvc.lock @@ -52,20 +52,20 @@ stages: cmd: PYTHONPATH=. python experiments/scripts/explain.py --dataset_name enron_spam --output_dir data/explanations/enron_spam deps: - - path: data/datasets/enron_spam - md5: 66d44efedf37990b1989c81bbee085e0.dir - size: 53096069 - nfiles: 3 - path: data/models/enron_spam md5: 3e16b22f59532c66beeadea958e0579a.dir size: 18505614 nfiles: 6 + - path: data/preprocessed/enron_spam + md5: b75efba1a62182dc8ac32acd1faf92ed.dir + size: 61709260 + nfiles: 3 - path: experiments/scripts/explain.py - md5: c85cbb774f2682ee39948e701fa0b0ca - size: 1445 + md5: 4e40a6415038ec6eb4140b54ff65c9c0 + size: 1449 outs: - path: data/explanations/enron_spam/ - md5: 376bd1619c08b4989564788e74de8e06.dir + md5: 345282e7c4e774d55aba55ed56ec464f.dir size: 7870394 nfiles: 1 download_dataset@poleval: diff --git a/dvc.yaml b/dvc.yaml index a110e67..533298e 100644 --- a/dvc.yaml +++ b/dvc.yaml @@ -30,7 +30,7 @@ stages: get_model: foreach: - enron_spam - - poleval + # - poleval do: wdir: . cmd: >- @@ -45,7 +45,7 @@ stages: classify: foreach: - enron_spam - - poleval + #- poleval do: wdir: . cmd: >- @@ -61,7 +61,7 @@ stages: explain: foreach: - enron_spam - - poleval + #- poleval do: wdir: . cmd: >- @@ -71,6 +71,6 @@ stages: deps: - experiments/scripts/explain.py - data/models/${item} - - data/proprocessed/${item} + - data/preprocessed/${item} outs: - data/explanations/${item}/ diff --git a/experiments/scripts/attack.py b/experiments/scripts/attack.py new file mode 100644 index 0000000..0980b60 --- /dev/null +++ b/experiments/scripts/attack.py @@ -0,0 +1,61 @@ +"""Script for running attacks on datasets.""" +import click +import pandas as pd +import json +import os +from tqdm import tqdm +from multiprocessing import cpu_count, Pool +from text_attacks.utils import get_classify_function +from textfooler import Attack, TextFooler + + +TEXT = 'text' +LEMMAS = 'lemmas' +TAGS = 'tags' + + +def spoil_sentence(sentence, lemmas, tags, lang, classify_fun, similarity): + attack = TextFooler(lang) + return attack.process(sentence, lemmas, tags, classify_fun, similarity) + + +@click.command() +@click.option( + "--dataset_name", + help="Dataset name", + type=str, +) +def main(dataset_name: str): + """Downloads the dataset to the output directory.""" + lang = 'en' if dataset_name == 'enron_spam' else 'pl' + output_dir = f"data/results/{dataset_name}" + input_file = f"data/preprocessed/{dataset_name}/test.jsonl" + os.makedirs(output_dir, exist_ok=True) + output_path = os.path.join(output_dir, 'test.jsonl') + classify = get_classify_function( + dataset_name=dataset_name, + ) + dataset_df = pd.read_json(input_file, lines=True) + spoiled = [] + similarity = 0.95 + cpus = cpu_count() + with Pool(processes=cpus) as pool: + results = [] + for idx in tqdm(range(0, len(dataset_df), cpus)): + end = min(idx+cpus, len(dataset_df) + 1) + for sentence, lemmas, tags in dataset_df[[TEXT, LEMMAS, TAGS], idx:end]: + results.append(pool.apply_async(spoil_sentence, args=[sentence, lemmas, + tags, lang, classify, similarity])) + for res in results: + spoiled_sent = res.get() + spoiled.append(spoiled_sent) + results = [] + + with open(output_path, mode="wt") as fd: + fd.write(pd.DataFrame( + {"spoiled": spoiled}).to_json( + orient='records', lines=True)) + + +if __name__ == "__main__": + main() diff --git a/requirements.txt b/requirements.txt index 62256fe..fec55bd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,3 +8,7 @@ lpmn_client_biz --find-links https://download.pytorch.org/whl/torch_stable.html torch==1.12.0+cu116 + +--index-url https://pypi.clarin-pl.eu/simple/ +plwn-api +git+ssh://git@gitlab.clarin-pl.eu/adversarial-attacks/textfooling.git@develop -- GitLab