Commit 2c4d1375 authored by pwalkow

Add queues

parent 2ac1bad2
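This commit replaces the previous sequential attack loop with a pipeline of worker processes connected by queues: a producer reads the input file, separate processes spoil the sentences, filter them by similarity, classify them, and post-process the results, and a saver appends the output to disk, with a None sentinel marking the end of the stream. A minimal, self-contained sketch of that pattern follows; the stage names and the trivial doubling transform are illustrative only and are not part of the project code.

from multiprocessing import Manager, Process


def producer(queue_out):
    for i in range(5):
        queue_out.put(i)          # push work items downstream
    queue_out.put(None)           # sentinel: no more data


def worker(queue_in, queue_out):
    while True:
        item = queue_in.get()
        if item is None:          # forward the sentinel and stop
            queue_out.put(None)
            break
        queue_out.put(item * 2)   # transform the item (stand-in for spoiling/classifying)


def saver(queue_in):
    while True:
        item = queue_in.get()
        if item is None:
            break
        print("saved:", item)


if __name__ == "__main__":
    m = Manager()
    q0, q1 = m.Queue(), m.Queue()
    stages = [
        Process(target=producer, args=(q0,)),
        Process(target=worker, args=(q0, q1)),
        Process(target=saver, args=(q1,)),
    ]
    [p.start() for p in stages]
    [p.join() for p in stages]

The actual script, shown below, uses the same sentinel-terminated queue chain but with five queues and the project's attack, similarity, and classification stages.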
"""Script for running attacks on datasets.""" """Script for running attacks on datasets."""
import json
import click import click
import pandas as pd import pandas as pd
import os import os
from tqdm import tqdm from tqdm import tqdm
from text_attacks.utils import get_classify_function from text_attacks.utils import get_classify_function
from textfooler import Attack, TextFooler, BaseLine, process from textfooler import Attack, TextFooler, BaseLine, process, run_queue, filter_similarity_queue
from queue import Full, Empty
from time import sleep
from multiprocessing import Process
from multiprocessing import Queue, Manager
from threading import Thread
TEXT = "text" TEXT = "text"
LEMMAS = "lemmas" LEMMAS = "lemmas"
TAGS = "tags" TAGS = "tags"
ORTHS = "orths" ORTHS = "orths"
ID = "id"
ATTACK_SUMMARY = "attacks_summary" ATTACK_SUMMARY = "attacks_summary"
ATTACK_SUCCEEDED = "attacks_succeeded" ATTACK_SUCCEEDED = "attacks_succeeded"
...@@ -24,6 +31,8 @@ EXPECTED = "expected" ...@@ -24,6 +31,8 @@ EXPECTED = "expected"
ACTUAL = "actual" ACTUAL = "actual"
COSINE_SCORE = "cosine_score" COSINE_SCORE = "cosine_score"
CLASS = "class" CLASS = "class"
SLEEP_TIME = 0.01
QUEUE_SIZE = 1000
os.environ["TOKENIZERS_PARALLELISM"] = "false" os.environ["TOKENIZERS_PARALLELISM"] = "false"
...@@ -35,13 +44,75 @@ DEFAULT_RES = { ...@@ -35,13 +44,75 @@ DEFAULT_RES = {
} }
def data_producer(queue_out, input_file):
    # Read the input JSONL file and push one work item per row to the first queue.
    dataset_df = pd.read_json(input_file, lines=True)
    for i, cols in tqdm(
        dataset_df[[TEXT, ID, LEMMAS, TAGS, ORTHS]].iterrows(), total=len(dataset_df)
    ):
        try:
            sentence, sent_id, lemmas, tags, orths = cols[0], cols[1], \
                cols[2], cols[3], cols[4]
            queue_out.put([sentence, orths, [], lemmas, tags, sent_id])
        except Full:
            sleep(SLEEP_TIME)
    # A None sentinel tells the next stage that no more data will arrive.
    try:
        queue_out.put(None)
    except Full:
        sleep(SLEEP_TIME)


def data_saver(queue_in, output_file):
    # Drain the last queue and append each finished record to the output file.
    item = 1
    while item is not None:
        try:
            item = queue_in.get(block=False)
        except Empty:
            sleep(SLEEP_TIME)
            continue
        if item is not None:
            with open(output_file, 'a') as file_out:
                json.dump(item, file_out, indent=2)


def classify_queue(queue_in, queue_out, classify_fun):
    # Classify the original sentence together with all of its changed variants in one batch.
    item = True
    while item is not None:
        try:
            item = queue_in.get(block=False)
        except Empty:
            sleep(SLEEP_TIME)
            continue
        if item is not None:
            try:
                sent_id, org_sentence, changed_sents = item
                sentences = [org_sentence] + [sent[TEXT] for sent in changed_sents]
                classified = classify_fun(sentences)
                queue_out.put((sent_id, org_sentence, changed_sents, classified))
            except Full:
                sleep(SLEEP_TIME)
                continue
    queue_out.put(None)


def log_queues(queues):
    # Periodically print the size of every queue so pipeline bottlenecks are visible.
    while True:
        sizes = [q.qsize() for q in queues]
        print(sizes, flush=True)
        sleep(2)

@click.command()
@click.option(
    "--dataset_name",
    help="Dataset name",
    type=str,
)
@click.option(
    "--attack_type",
    help="Attack type",
    type=str,
)
def main(dataset_name: str, attack_type: str):
    """Runs the selected attack on the dataset and saves the results."""
    lang = {
        "enron_spam": "en",
        ...
        "20_news": "en",
        "wiki_pl": "pl",
    }[dataset_name]
    attack = {
        "attack_textfooler": TextFooler(lang),
        "attack_basic": BaseLine(lang, 0.5, 0.4, 0.3),
    }[attack_type]
    # sim = Similarity(0.95, "distiluse-base-multilingual-cased-v1")
    output_dir = f"data/results/{attack_type}/{dataset_name}/"
    input_file = f"data/preprocessed/{dataset_name}/test.jsonl"
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, "test.jsonl")
    classify = get_classify_function(dataset_name=dataset_name)
    max_sub = 1

    m = Manager()
    queues = [m.Queue(maxsize=QUEUE_SIZE) for _ in range(5)]
    processes = [
        Process(target=data_producer, args=(queues[0], input_file,)),               # loading data: file_in -> 0
        Process(target=attack.spoil_queue, args=(queues[0], queues[1], max_sub,)),  # spoiling: 0 -> 1
        Process(target=filter_similarity_queue, args=(queues[1], queues[2], 0.95,
                                                      "distiluse-base-multilingual-cased-v1",)),  # cosine similarity: 1 -> 2
        Process(target=classify_queue, args=(queues[2], queues[3], classify,)),     # classify changed sentences: 2 -> 3
        Process(target=run_queue, args=(queues[3], queues[4], process,)),           # process results: 3 -> 4
        Process(target=data_saver, args=(queues[4], output_path,)),                 # saving: 4 -> file_out
    ]
    [p.start() for p in processes]

    log_que = Thread(target=log_queues, args=(queues,))
    log_que.daemon = True
    log_que.start()

    # wait for all processes to finish
    [p.join() for p in processes]
    log_que.join(timeout=0.5)

if __name__ == "__main__": if __name__ == "__main__":
......
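With the new CLI option, the script would presumably be invoked along the lines of: python <script>.py --dataset_name wiki_pl --attack_type attack_textfooler (the script's file name is not shown on this page). Valid values for --dataset_name and --attack_type are the keys of the lang and attack dictionaries in main(), and results are appended to data/results/<attack_type>/<dataset_name>/test.jsonl.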