Commit 1ccb777f authored by Paweł Walkowiak

Add sft samples creation

parent 7bbc37de
"""Script for running attacks on datasets."""
import importlib
import json
from collections import defaultdict
import click
import pandas as pd
import os
import torch
from tqdm import tqdm
from textfooler import Attack, TextFooler, Similarity, BaseLine, \
process, run_queue, filter_similarity_queue, spoil_queue, \
AttackMethod, get_xai_importance_diff
from time import sleep, time
from datetime import datetime
from multiprocessing import Process
from multiprocessing import Queue, Manager
from threading import Thread
from sklearn.metrics import classification_report, confusion_matrix
from string import punctuation
import fasttext.util
from transformers import logging  # assumed: set_verbosity_error() here is transformers' logging helper

logging.set_verbosity_error()
TEXT = "text"
LEMMAS = "lemmas"
TAGS = "tags"
ORTHS = "orths"
ID = "id"
PRED = "pred_label"
NER = "ner"
ATTACK_SUMMARY = "attacks_summary"
ATTACK_SUCCEEDED = "attacks_succeeded"
SIMILARITY = "similarity"
CHANGED = "changed"
CHANGED_WORDS = "changed_words"
SUCCEEDED = "succeeded"
ALL = "all"
DIFF = "diff"
EXPECTED = "expected"
ACTUAL = "actual"
COSINE_SCORE = "cosine_score"
CLASS = "class"
QUEUE_SIZE = 60
FEATURES = "features"
IMPORTANCE = "importance"
SYNONYM = "synonym"
DISCARD = "discard"
GLOBAL = "global"
LOCAL = "local"
CHAR_DISCARD = "char_discard"
SYNONYMS_NBR = "synonyms_nbr"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
DEFAULT_RES = {
"spoiled": {
"attacks_summary": {"succeeded": 0, "all": 1},
"attacks_succeeded": [],
}
}
def join_punct(words):
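    """Rejoin a list of tokens into a sentence, attaching punctuation tokens to the preceding word."""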
punc = set(punctuation)
return "".join(w if set(w) <= punc else " " + w for w in words).lstrip()
def data_producer(queue_out, dataset_df, queue_recurse, queue_log, log_file):
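    """Stream dataset rows (text, orths, NER spans, lemmas, tags, prediction) into the work
    queue, then serve recursive substitution requests from queue_recurse until a None
    sentinel arrives; a None on the output queue signals consumers to stop."""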
try:
for i, cols in tqdm(
dataset_df[[TEXT, ID, LEMMAS, TAGS, ORTHS, PRED, NER]].iterrows(), total=len(dataset_df)
):
            sentence, sent_id, lemmas, tags, orths, y_pred, ner = (
                cols[0], cols[1], cols[2], cols[3], cols[4], cols[5], cols[6]
            )
ners = []
for n in ner:
ners.extend(n[TEXT].split(" "))
queue_out.put([sentence, orths, ners, lemmas, tags, sent_id, y_pred, 1])
queue_log.put("Finished first Iteration")
item = 1
while item is not None:
item = queue_recurse.get()
if item is not None:
sent_id, sub = item[0], item[1]
queue_log.put(f"Recurse item id: {sent_id}, subs: {sub}")
row = dataset_df[dataset_df[ID] == sent_id].iloc[0]
                    sentence, orths, lemmas, tags, y_pred, ner = (
                        row[TEXT], row[ORTHS], row[LEMMAS], row[TAGS], row[PRED], row[NER]
                    )
ners = []
for n in ner:
ners.extend(n[TEXT].split(" "))
queue_out.put([sentence, orths, ners, lemmas, tags, sent_id, y_pred, sub])
except Exception as e:
queue_log.put(f"Error in data producer: {e}")
with open(log_file, "a") as f:
f.write(f"Producer failed with {e}\n")
queue_out.put(None)
def data_saver(queue_in, queue_log, queue_recurse, output_file,
output_dir, cases_nbr, queues_kill, to_kill_nbr, max_sub, log_file):
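    """Consume filtered attack results, rebuild each spoiled sentence from its word
    substitutions, and append up to five generated variants per sample to the output
    JSONL file."""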
try:
item = True
while item is not None:
item = queue_in.get()
queue_log.put("Saving got from queue")
if item is not None:
results = []
sent_id, org_sentence, y_pred, changed, synonyms_nbr, sent_words, subs = item
sentences = []
for subst in changed:
subst = subst[0]
sent_words_copy = [*sent_words]
for idx, word_change in subst.items():
sent_words_copy[idx] = word_change['word']
sentences.append(join_punct(sent_words_copy))
for sent in sentences:
results.append({
ID: sent_id,
TEXT: sent,
"label": y_pred,
})
results = pd.DataFrame(results[:5]) # cut to 5 per sample
results.to_json(output_file, orient="records", lines=True, mode="a")
except Exception as e:
queue_log.put(f"Error in classifier: {e}")
with open(log_file, "a") as f:
f.write("Saver failed with {e}\n")
queue_in.put(None)
def log_queues(queues, log_file):
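    """Periodically report the sizes of the work queues to stdout and the log file."""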
while True:
sizes = [q.qsize() for q in queues]
print(sizes, flush=True)
with open(log_file, "a") as f:
f.write(f"\t{sizes}\n")
sleep(10)
def log_info_queue(queue, log_file):
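    """Forward diagnostic messages from the shared info queue to stdout and the log file."""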
print("Logging queue")
    while True:
        item = queue.get()
        if item is None:
            break
        print(item)
        with open(log_file, "a") as f:
            f.write(f"\t{item}\n")
    print("Logging queue finished")
def load_dir_files(dir_path):
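    """Load XAI importance files from a directory into a {file_id: {word: importance}} mapping."""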
result = {}
for filename in os.listdir(dir_path):
with open(os.path.join(dir_path, filename), 'r') as fin:
importance = json.load(fin)
result[filename.split("__")[0]] = {
word: importance[IMPORTANCE][idx]
for idx, word in importance[FEATURES].items()
if word
}
return result
def load_xai_importance(input_dir):
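    """Load global word-importance scores and per-example local scores (grouped by class)
    from the explanations directory."""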
global_xai_dir = os.path.join(input_dir, "global")
local_xai_dir = os.path.join(input_dir, "local", "test")
local_dirs = os.listdir(local_xai_dir)
local_class_to_file = {dir_name: load_dir_files(os.path.join(local_xai_dir, dir_name))
for dir_name in local_dirs}
local_file_to_class = defaultdict(dict)
for c_name in local_dirs:
for f_name, value_df in local_class_to_file[c_name].items():
local_file_to_class[f_name][c_name] = value_df
return load_dir_files(global_xai_dir), dict(local_file_to_class)
@click.command()
@click.option(
"--dataset_name",
help="Dataset name",
type=str,
)
@click.option(
"--attack_type",
help="Attack type",
type=str,
)
def main(dataset_name: str, attack_type: str):
"""Downloads the dataset to the output directory."""
parameters = {}
with open('experiments/parameters.json', 'r') as fin:
parameters = json.load(fin)
lang = parameters.get("lang", {})[dataset_name]
xai_global, xai_local = {}, {}
xai_global_directed = {}
if "attack_xai" in attack_type:
importance = load_xai_importance(f"data/explanations/{dataset_name}")
xai_global, xai_local = importance[0], importance[1]
if "attack_xai_directed" in attack_type:
class_mapping = parameters.get("class_mapping", {})[dataset_name]
xai_global_directed = {source: get_xai_importance_diff(xai_global[source], xai_global[target])
for source, target in class_mapping.items()
}
max_sub = parameters.get("max_sub", 10)
word_change_size = parameters.get("word_change_size", 0.5)
similarity_bound = parameters.get("similarity_bound", 0.8)
word_synonym_threshold = parameters.get("word_synonym_threshold", 0.65)
sent_model = parameters.get("sent_model", "distiluse-base-multilingual-cased-v1")
menli_model = parameters.get("menli_model", "microsoft/deberta-large-mnli")
params = {
"attack_textfooler": [lang, AttackMethod.SYNONYM, word_synonym_threshold],
"attack_textfooler_discard": [lang, AttackMethod.DISCARD],
"attack_basic": [lang, 0.5, 0.4, 0.3], # prawopodobieństwa spacji > usunięcia znaku > usunięcia słowa
"attack_xai": [lang, xai_global, xai_local, GLOBAL, AttackMethod.SYNONYM, word_synonym_threshold],
"attack_xai_discard": [lang, xai_global, xai_local, GLOBAL, AttackMethod.DISCARD],
"attack_xai_local": [lang, xai_global, xai_local, LOCAL, AttackMethod.SYNONYM, word_synonym_threshold],
"attack_xai_discard_local": [lang, xai_global, xai_local, LOCAL, AttackMethod.DISCARD],
"attack_xai_char_discard": [lang, xai_global, xai_local, GLOBAL, AttackMethod.LETTER_DISCARD, word_synonym_threshold, word_change_size],
"attack_xai_char_swap": [lang, xai_global, xai_local, GLOBAL, AttackMethod.LETTER_SWAP, word_synonym_threshold, word_change_size],
"attack_xai_char_insert": [lang, xai_global, xai_local, GLOBAL, AttackMethod.LETTER_INSERT, word_synonym_threshold, word_change_size],
"attack_xai_char_substitute": [lang, xai_global, xai_local, GLOBAL, AttackMethod.LETTER_SUBSTITUTE, word_synonym_threshold, word_change_size],
"attack_xai_char_mixin": [lang, xai_global, xai_local, GLOBAL, AttackMethod.LETTER_MIX, word_synonym_threshold, word_change_size],
"attack_xai_directed_char_mixin": [lang, xai_global_directed, xai_local, GLOBAL, AttackMethod.LETTER_MIX, word_synonym_threshold, word_change_size],
"attack_xai_directed": [lang, xai_global_directed, xai_local, GLOBAL, AttackMethod.SYNONYM, word_synonym_threshold],
"attack_xai_directed_discard": [lang, xai_global_directed, xai_local, GLOBAL, AttackMethod.DISCARD],
"attack_xai_reverse": [lang, xai_global, xai_local, GLOBAL, AttackMethod.SYNONYM, word_synonym_threshold, word_change_size, False],
"attack_xai_reverse_discard": [lang, xai_global, xai_local, GLOBAL, AttackMethod.DISCARD, word_change_size, False],
"attack_xai_reverse_char_mixin": [lang, xai_global, xai_local, GLOBAL, AttackMethod.LETTER_MIX, word_synonym_threshold, word_change_size, False],
"attack_xai_insert": [lang, xai_global, xai_local, GLOBAL, AttackMethod.INSERT, word_synonym_threshold, word_change_size],
}[attack_type]
output_dir = f"data/results_sft/{attack_type}/{dataset_name}/"
input_file = f"data/classification/{dataset_name}/test.jsonl"
os.makedirs(output_dir, exist_ok=True)
output_path = os.path.join(output_dir, "test.jsonl")
dataset_df = pd.read_json(input_file, lines=True)
log_file = f"logs/{attack_type}_{dataset_name}_{datetime.now().strftime('%Y-%m-%d_%H-%M')}.log"
m = Manager()
queues = [m.Queue(maxsize=QUEUE_SIZE) for _ in range(4)]
queues.append(m.Queue(maxsize=max(int(1.5 * len(dataset_df)), QUEUE_SIZE)))
queues.append(m.Queue(maxsize=QUEUE_SIZE))
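    # queues[0]: raw samples, queues[1]: spoiled candidates, queues[2]: similarity-filtered results,
    # queues[3]/[4]: recursion feedback, queues[5]: log/info messages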
log_que = Thread(target=log_queues, args=(queues[:5], log_file))
log_que.daemon = True
log_que.start()
info_que = Thread(target=log_info_queue, args=(queues[5], log_file))
info_que.daemon = True
info_que.start()
# load fasttext model
if attack_type in [
"attack_xai",
"attack_xai_local",
"attack_textfooler",
"attack_xai_directed",
"attack_xai_directed_local",
"attack_xai_reverse"
]:
print("Downloading fasttext model")
if lang == "en":
fasttext.util.download_model('en', if_exists='ignore')
elif lang == "pl":
fasttext.util.download_model('pl', if_exists='ignore')
ft_model_name = f"cc.{lang}.300.bin"
print("Downloading fasttext model finished")
else:
ft_model_name = None
processes_nbr = 12
sim = Similarity(queues[5], similarity_bound, sent_model, menli_model, lang)
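    # Pipeline: producer -> spoilers -> similarity filters -> saver, connected by the queues above.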
processes = [Process(target=data_producer, args=(queues[0], dataset_df, queues[4], queues[5], log_file))] # loading data file_in -> 0
processes.extend([Process(target=spoil_queue, args=(queues[0], queues[1], queues[5],
attack_type, params, ft_model_name, log_file))
for _ in range(processes_nbr)]) # spoiling 0 -> 1
processes.extend([Process(target=filter_similarity_queue, args=(queues[1], queues[2], queues[5], sim, log_file)),
Process(target=filter_similarity_queue, args=(queues[1], queues[2], queues[5], sim, log_file)), # cosim 1 -> 2
Process(target=data_saver, args=(queues[2], queues[5], queues[3], output_path,
output_dir, len(dataset_df), queues, processes_nbr+6, max_sub,
                                                     log_file))  # saving 2 -> file_out
])
    for p in processes:
        p.start()
    # wait for all processes to finish
    for p in processes:
        p.join()
log_que.join(timeout=0.5)
info_que.join(timeout=0.5)
if __name__ == "__main__":
main()