Commit bbf4b8e8 authored by pwalkow

Add queues

parent 2c4d1375
@@ -4,13 +4,16 @@ import json
 import click
 import pandas as pd
 import os
 import torch
 from tqdm import tqdm
 from text_attacks.utils import get_classify_function
-from textfooler import Attack, TextFooler, BaseLine, process, run_queue, filter_similarity_queue
-from queue import Full, Empty
+from textfooler import Attack, TextFooler, Similarity, BaseLine, \
+    process, run_queue, filter_similarity_queue, spoil_queue
 from time import sleep
 from multiprocessing import Process
 from multiprocessing import Queue, Manager
+import multiprocess
 from threading import Thread

 TEXT = "text"
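A note on the import pair: the file now pulls in both `multiprocessing` and `multiprocess`. The latter is the dill-based fork of the standard library module; dill can serialize objects that plain pickle rejects, such as closures and locally defined functions, which is presumably why the classifier stage below is launched with `multiprocess.Process` while the other stages keep the standard `Process`. A minimal sketch of what the fork buys you (illustrative names, not from this repo):

    import multiprocess  # dill-based fork; same API as multiprocessing

    def make_worker(prefix):
        # A closure: standard multiprocessing cannot pickle this target under
        # the "spawn" start method, but multiprocess serializes it with dill.
        def worker(q):
            q.put(prefix + ": done")
        return worker

    if __name__ == "__main__":
        q = multiprocess.Queue()
        p = multiprocess.Process(target=make_worker("classify"), args=(q,))
        p.start()
        print(q.get())  # classify: done
        p.join()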
@@ -31,7 +34,6 @@ EXPECTED = "expected"
 ACTUAL = "actual"
 COSINE_SCORE = "cosine_score"
 CLASS = "class"
-SLEEP_TIME = 0.01
 QUEUE_SIZE = 1000

 os.environ["TOKENIZERS_PARALLELISM"] = "false"
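SLEEP_TIME disappears because every consumer below switches from non-blocking polling (`get(block=False)` plus `except Empty: sleep(SLEEP_TIME)`) to a plain blocking `get()`. A blocking get parks the consumer until an item arrives instead of waking it a hundred times a second to find an empty queue. A side-by-side sketch of the two styles:

    from queue import Empty, Queue
    from time import sleep

    # Before: busy-wait polling; wakes ~100x/s even when nothing arrives.
    def polling_get(q: Queue, sleep_time: float = 0.01):
        while True:
            try:
                return q.get(block=False)
            except Empty:
                sleep(sleep_time)

    # After: blocking get; the thread sleeps until an item is available.
    def blocking_get(q: Queue):
        return q.get()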
@@ -49,48 +51,33 @@ def data_producer(queue_out, input_file):
     for i, cols in tqdm(
         dataset_df[[TEXT, ID, LEMMAS, TAGS, ORTHS]].iterrows(), total=len(dataset_df)
     ):
-        try:
-            sentence, sent_id, lemmas, tags, orths = cols[0], cols[1], \
-                cols[2], cols[3], cols[4]
-            queue_out.put([sentence, orths, [], lemmas, tags, sent_id])
-        except Full:
-            sleep(SLEEP_TIME)
-    try:
-        queue_out.put(None)
-    except Full:
-        sleep(SLEEP_TIME)
+        sentence, sent_id, lemmas, tags, orths = cols[0], cols[1], \
+            cols[2], cols[3], cols[4]
+        queue_out.put([sentence, orths, [], lemmas, tags, sent_id])
+    queue_out.put(None)

 def data_saver(queue_in, output_file):
     item = 1
     while item is not None:
-        try:
-            item = queue_in.get(block=False)
-        except Empty:
-            sleep(SLEEP_TIME)
-            continue
+        item = queue_in.get()
         if item is not None:
             with open(output_file, 'a') as file_out:
                 json.dump(item, file_out, indent=2)

-def classify_queue(queue_in, queue_out, classify_fun):
+def classify_queue(queue_in, queue_out, queue_log, classify_fun):
     item = True
     while item is not None:
-        try:
-            item = queue_in.get(block=False)
-        except Empty:
-            sleep(SLEEP_TIME)
-            continue
+        item = queue_in.get()
+        queue_log.put("Classify got from queue")
         if item is not None:
-            try:
-                sent_id, org_sentence, changed_sents = item
-                sentences = [org_sentence].extend([sent[TEXT] for sent in changed_sents])
-                classified = classify_fun(sentences)
-                queue_out.put((sent_id, org_sentence, changed_sents, classified))
-            except Full:
-                sleep(SLEEP_TIME)
-                continue
+            sent_id, org_sentence, changed_sents = item
+            sentences = [org_sentence]
+            sentences.extend([sent[TEXT] for sent in changed_sents])
+            queue_log.put(f"Classifying sentences {sentences[:100]}")
+            classified = classify_fun(sentences)
+            queue_out.put((sent_id, org_sentence, changed_sents, classified))
     queue_out.put(None)
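Two things worth noting in the rewrite above. First, it fixes a real bug: `[org_sentence].extend(...)` returns None because `list.extend` mutates in place, so the old code handed None to `classify_fun`; building the list in two statements is the fix. Second, each stage still forwards a final None so the next stage knows the stream has ended. A minimal sketch of that poison-pill chaining for a two-stage pipeline (names are illustrative):

    from multiprocessing import Process, Queue

    def producer(q_out):
        for item in range(3):
            q_out.put(item)
        q_out.put(None)          # poison pill: end of stream

    def stage(q_in, q_out):
        while True:
            item = q_in.get()    # blocks until something arrives
            if item is None:
                q_out.put(None)  # forward the pill so downstream stops too
                break
            q_out.put(item * 10)

    if __name__ == "__main__":
        q1, q2 = Queue(), Queue()
        Process(target=producer, args=(q1,)).start()
        Process(target=stage, args=(q1, q2)).start()
        while (item := q2.get()) is not None:
            print(item)          # 0, 10, 20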
@@ -101,6 +88,13 @@ def log_queues(queues):
         sleep(2)


+def log_info_queue(queue):
+    print("Logging queue")
+    while True:
+        item = queue.get()
+        print(item)


 @click.command()
 @click.option(
     "--dataset_name",
@@ -120,27 +114,25 @@ def main(dataset_name: str, attack_type: str):
         "20_news": "en",
         "wiki_pl": "pl",
     }[dataset_name]
-    attack = {
-        "attack_textfooler": TextFooler(lang),
-        "attack_basic": BaseLine(lang, 0.5, 0.4, 0.3)
+    params = {
+        "attack_textfooler": [lang],
+        "attack_basic": [lang, 0.5, 0.4, 0.3]
     }[attack_type]
-    # sim = Similarity(0.95, "distiluse-base-multilingual-cased-v1")
     output_dir = f"data/results/{attack_type}/{dataset_name}/"
     input_file = f"data/preprocessed/{dataset_name}/test.jsonl"
     os.makedirs(output_dir, exist_ok=True)
     output_path = os.path.join(output_dir, "test.jsonl")
-    classify = get_classify_function(dataset_name=dataset_name)
+    classify = get_classify_function(dataset_name=dataset_name, device="cpu")
     max_sub = 1

     m = Manager()
-    queues = [m.Queue(maxsize=QUEUE_SIZE) for _ in range(5)]
+    queues = [m.Queue(maxsize=QUEUE_SIZE) for _ in range(6)]
+    sim = Similarity(queues[5], 0.95, "distiluse-base-multilingual-cased-v1")
     processes = [Process(target=data_producer, args=(queues[0], input_file,)),  # loading data file_in -> 0
-                 Process(target=attack.spoil_queue, args=(queues[0], queues[1], max_sub,)),  # spoiling 0 -> 1
-                 Process(target=filter_similarity_queue, args=(queues[1], queues[2], 0.95,
-                                                               "distiluse-base-multilingual-cased-v1",)),  # cosim 1 -> 2
-                 Process(target=classify_queue, args=(queues[2], queues[3], classify, )),  # classify changed 2 -> 3
+                 Process(target=spoil_queue, args=(queues[0], queues[1], max_sub, attack_type, params)),  # spoiling 0 -> 1
+                 Process(target=filter_similarity_queue, args=(queues[1], queues[2], queues[5], sim)),  # cosim 1 -> 2
+                 multiprocess.Process(target=classify_queue, args=(queues[2], queues[3], queues[5], classify, )),  # classify changed 2 -> 3
                  Process(target=run_queue, args=(queues[3], queues[4], process,)),  # process 3 -> 4
                  Process(target=data_saver, args=(queues[4], output_path,))]  # saving 4 -> file_out
     [p.start() for p in processes]
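The processes list wires a linear six-queue topology: input file → data_producer → q0 → spoil_queue → q1 → similarity filter → q2 → classify → q3 → process → q4 → data_saver → output file, with queues[5] serving as the shared log channel. A generic reduction of that wiring, sized down to two stages (stage names here are placeholders, not the repo's functions):

    from multiprocessing import Manager, Process

    def stage(name, q_in, q_out):
        # Read until the poison pill, transform, forward the pill.
        while (item := q_in.get()) is not None:
            q_out.put(f"{name}({item})")
        q_out.put(None)

    if __name__ == "__main__":
        m = Manager()
        qs = [m.Queue(maxsize=1000) for _ in range(3)]
        stages = [Process(target=stage, args=(name, qs[i], qs[i + 1]))
                  for i, name in enumerate(["spoil", "classify"])]
        [p.start() for p in stages]
        for text in ["a", "b"]:
            qs[0].put(text)
        qs[0].put(None)
        while (out := qs[2].get()) is not None:
            print(out)  # classify(spoil(a)), classify(spoil(b))
        [p.join() for p in stages]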
@@ -148,9 +140,13 @@ def main(dataset_name: str, attack_type: str):
     log_que = Thread(target=log_queues, args=(queues, ))
     log_que.daemon = True
     log_que.start()
+    info_que = Thread(target=log_info_queue, args=(queues[5], ))
+    info_que.daemon = True
+    info_que.start()
     # wait for all processes to finish
     [p.join() for p in processes]
     log_que.join(timeout=0.5)
+    info_que.join(timeout=0.5)

 if __name__ == "__main__":