Skip to content
Snippets Groups Projects
Commit 3868f577 authored by Paweł Walkowiak's avatar Paweł Walkowiak
Browse files

Merge branch 'xai_fooler' into 'master'

Xai fooler

See merge request adversarial-attacks/text-attacks!5
parents 6139763f 59c763a0
No related merge requests found
/wiki_pl
/enron_spam
/wiki_pl
/20_news
/wiki_pl
/enron_spam
/20_news
/enron_spam
/20_news
/wiki_pl
......@@ -45,8 +45,8 @@ stages:
size: 1181
outs:
- path: data/classification/enron_spam
md5: 0450c0b672bc4a5db3cc7be2dac786bd.dir
size: 10674882
md5: 5de1a2fcbae0de94f5fbfd2bb747d919.dir
size: 14585920
nfiles: 2
explain@enron_spam:
cmd: PYTHONPATH=. python experiments/scripts/explain.py --dataset_name enron_spam
......@@ -142,8 +142,8 @@ stages:
size: 1181
outs:
- path: data/classification/wiki_pl
md5: 515330772505f489b55686545bcf23a0.dir
size: 34103198
md5: 88c3cea96b2cb3ddda1a82037bf6130a.dir
size: 44196727
nfiles: 2
preprocess_dataset@20_news:
cmd: PYTHONPATH=. python experiments/scripts/tag_dataset.py --dataset_name 20_news
......@@ -177,8 +177,8 @@ stages:
size: 1181
outs:
- path: data/classification/20_news
md5: 6831f104f7c20541548fe72250c45706.dir
size: 31286120
md5: b73611443c4189af91b827c083f37e0b.dir
size: 42897496
nfiles: 2
attack_basic@enron_spam:
cmd: PYTHONPATH=. python experiments/scripts/attack.py --dataset_name enron_spam
......@@ -239,3 +239,203 @@ stages:
md5: ad1f9f0df287078edebed1e408df2c9f.dir
size: 869336544
nfiles: 140401
attack_basic@wiki_pl:
cmd: PYTHONPATH=. python experiments/scripts/attack.py --dataset_name wiki_pl
--attack_type attack_basic
deps:
- path: data/models/wiki_pl
md5: fd453042628fb09c080ef05d34a32cce.dir
size: 501711136
nfiles: 7
- path: data/preprocessed/wiki_pl
md5: 0014b9bb52913cbc9a568d237ea2207b.dir
size: 65553079
nfiles: 3
- path: experiments/scripts/attack.py
md5: 702997933e5af85d09d8286a14e2cc05
size: 2486
outs:
- path: data/results/attack_basic/wiki_pl/
md5: f118a41e391b5f713f77611140f2f2cc.dir
size: 1
nfiles: 1
attack_textfooler@enron_spam:
cmd: 'PYTHONPATH=. python experiments/scripts/attack.py --dataset_name enron_spam
--attack_type attack_textfooler '
deps:
- path: data/models/enron_spam
md5: 3e16b22f59532c66beeadea958e0579a.dir
size: 18505614
nfiles: 6
- path: data/preprocessed/enron_spam
md5: 30c63efbc615347ddcb5f61e011113bd.dir
size: 65971374
nfiles: 3
- path: experiments/scripts/attack.py
md5: b9d9a4d9fcba1cb4dfbb554ecc3e26fb
size: 10083
outs:
- path: data/results/attack_textfooler/enron_spam/
md5: 10ecd4c940e8df1058465048ffbe78d4.dir
size: 3291044
nfiles: 2
attack_textfooler@20_news:
cmd: 'PYTHONPATH=. python experiments/scripts/attack.py --dataset_name 20_news
--attack_type attack_textfooler '
deps:
- path: data/models/20_news
md5: 43d68a67ecb8149bd6bf50db9767cb64.dir
size: 439008808
nfiles: 6
- path: data/preprocessed/20_news
md5: 20da0980e52df537e5b7ca5db0305879.dir
size: 58582060
nfiles: 3
- path: experiments/scripts/attack.py
md5: 4fe9c6210ce0f3be66b54c2565ad2daa
size: 10132
outs:
- path: data/results/attack_textfooler/20_news/
md5: 007aba16e343ca283180c7bc7b9a0190.dir
size: 93666157
nfiles: 2
attack_textfooler@wiki_pl:
cmd: 'PYTHONPATH=. python experiments/scripts/attack.py --dataset_name wiki_pl
--attack_type attack_textfooler '
deps:
- path: data/classification/wiki_pl
md5: 88c3cea96b2cb3ddda1a82037bf6130a.dir
size: 44196727
nfiles: 2
- path: data/models/wiki_pl
md5: fd453042628fb09c080ef05d34a32cce.dir
size: 501711136
nfiles: 7
- path: experiments/scripts/attack.py
md5: 2977363ba8806c393498f98d5733c013
size: 11497
outs:
- path: data/results/attack_textfooler/wiki_pl/
md5: eccc12b9a5ae383ea02067cd1955753e.dir
size: 20293404
nfiles: 2
attack_textfooler_discard@wiki_pl:
cmd: PYTHONPATH=. python experiments/scripts/attack.py --dataset_name wiki_pl
--attack_type attack_textfooler_discard
deps:
- path: data/classification/wiki_pl
md5: 88c3cea96b2cb3ddda1a82037bf6130a.dir
size: 44196727
nfiles: 2
- path: data/models/wiki_pl
md5: fd453042628fb09c080ef05d34a32cce.dir
size: 501711136
nfiles: 7
- path: experiments/scripts/attack.py
md5: 2b9ddc1ff1f56855ff667171ba04ed78
size: 11606
outs:
- path: data/results/attack_textfooler_discard/wiki_pl/
md5: e41122c3cdf76ad1b163aba49acce0f0.dir
size: 14396685
nfiles: 2
attack_textfooler_discard@enron_spam:
cmd: PYTHONPATH=. python experiments/scripts/attack.py --dataset_name enron_spam
--attack_type attack_textfooler_discard
deps:
- path: data/classification/enron_spam
md5: 5de1a2fcbae0de94f5fbfd2bb747d919.dir
size: 14585920
nfiles: 2
- path: data/models/enron_spam
md5: 3e16b22f59532c66beeadea958e0579a.dir
size: 18505614
nfiles: 6
- path: experiments/scripts/attack.py
md5: 2b9ddc1ff1f56855ff667171ba04ed78
size: 11606
outs:
- path: data/results/attack_textfooler_discard/enron_spam/
md5: 8a78484bd77916f82021a72338342a44.dir
size: 2816160
nfiles: 2
attack_textfooler_discard@20_news:
cmd: PYTHONPATH=. python experiments/scripts/attack.py --dataset_name 20_news
--attack_type attack_textfooler_discard
deps:
- path: data/classification/20_news
md5: b73611443c4189af91b827c083f37e0b.dir
size: 42897496
nfiles: 2
- path: data/models/20_news
md5: 43d68a67ecb8149bd6bf50db9767cb64.dir
size: 439008808
nfiles: 6
- path: experiments/scripts/attack.py
md5: 9e913b341cb0993625a41c401d64a30b
size: 12017
outs:
- path: data/results/attack_textfooler_discard/20_news/
md5: 82d89b00a710e9de0a2157357fed5894.dir
size: 24977923
nfiles: 2
attack_xai@enron_spam:
cmd: PYTHONPATH=. python experiments/scripts/attack.py --dataset_name enron_spam
--attack_type attack_xai
deps:
- path: data/classification/enron_spam
md5: 5de1a2fcbae0de94f5fbfd2bb747d919.dir
size: 14585920
nfiles: 2
- path: data/models/enron_spam
md5: 3e16b22f59532c66beeadea958e0579a.dir
size: 18505614
nfiles: 6
- path: experiments/scripts/attack.py
md5: 87f54ee4e2a08f1259d9d8b2d01fe1b9
size: 12061
outs:
- path: data/results/attack_xai/enron_spam/
md5: ad19831866da140de113e64862da0bce.dir
size: 2860109
nfiles: 2
attack_xai@20_news:
cmd: PYTHONPATH=. python experiments/scripts/attack.py --dataset_name 20_news
--attack_type attack_xai
deps:
- path: data/classification/20_news
md5: b73611443c4189af91b827c083f37e0b.dir
size: 42897496
nfiles: 2
- path: data/models/20_news
md5: 43d68a67ecb8149bd6bf50db9767cb64.dir
size: 439008808
nfiles: 6
- path: experiments/scripts/attack.py
md5: 87f54ee4e2a08f1259d9d8b2d01fe1b9
size: 12061
outs:
- path: data/results/attack_xai/20_news/
md5: af00c730d4d73a0a8e2a047882c0d9aa.dir
size: 8739816
nfiles: 2
attack_xai@wiki_pl:
cmd: PYTHONPATH=. python experiments/scripts/attack.py --dataset_name wiki_pl
--attack_type attack_xai
deps:
- path: data/classification/wiki_pl
md5: 88c3cea96b2cb3ddda1a82037bf6130a.dir
size: 44196727
nfiles: 2
- path: data/models/wiki_pl
md5: fd453042628fb09c080ef05d34a32cce.dir
size: 501711136
nfiles: 7
- path: experiments/scripts/attack.py
md5: 87f54ee4e2a08f1259d9d8b2d01fe1b9
size: 12061
outs:
- path: data/results/attack_xai/wiki_pl/
md5: e24c456f63d8e13b92fcab51e0726141.dir
size: 8287334
nfiles: 2
......@@ -77,6 +77,102 @@ stages:
- data/preprocessed/${item}
outs:
- data/explanations/${item}/
attack_textfooler:
foreach:
- enron_spam
- 20_news
- wiki_pl
do:
wdir: .
cmd: >-
PYTHONPATH=. python experiments/scripts/attack.py
--dataset_name ${item} --attack_type attack_textfooler
deps:
- experiments/scripts/attack.py
- data/models/${item}
- data/classification/${item}
outs:
- data/results/attack_textfooler/${item}/
attack_textfooler_discard:
foreach:
- enron_spam
- 20_news
- wiki_pl
do:
wdir: .
cmd: >-
PYTHONPATH=. python experiments/scripts/attack.py
--dataset_name ${item} --attack_type attack_textfooler_discard
deps:
- experiments/scripts/attack.py
- data/models/${item}
- data/classification/${item}
outs:
- data/results/attack_textfooler_discard/${item}/
attack_xai:
foreach:
- enron_spam
- 20_news
- wiki_pl
do:
wdir: .
cmd: >-
PYTHONPATH=. python experiments/scripts/attack.py
--dataset_name ${item} --attack_type attack_xai
deps:
- experiments/scripts/attack.py
- data/models/${item}
- data/classification/${item}
outs:
- data/results/attack_xai/${item}/
attack_xai_discard:
foreach:
- enron_spam
- 20_news
- wiki_pl
do:
wdir: .
cmd: >-
PYTHONPATH=. python experiments/scripts/attack.py
--dataset_name ${item} --attack_type attack_xai_discard
deps:
- experiments/scripts/attack.py
- data/models/${item}
- data/classification/${item}
outs:
- data/results/attack_xai_discard/${item}/
attack_xai_local:
foreach:
- enron_spam
- 20_news
- wiki_pl
do:
wdir: .
cmd: >-
PYTHONPATH=. python experiments/scripts/attack.py
--dataset_name ${item} --attack_type attack_xai_local
deps:
- experiments/scripts/attack.py
- data/models/${item}
- data/classification/${item}
outs:
- data/results/attack_xai_local/${item}/
attack_xai_discard_local:
foreach:
- enron_spam
- 20_news
- wiki_pl
do:
wdir: .
cmd: >-
PYTHONPATH=. python experiments/scripts/attack.py
--dataset_name ${item} --attack_type attack_xai_discard_local
deps:
- experiments/scripts/attack.py
- data/models/${item}
- data/classification/${item}
outs:
- data/results/attack_xai_discard_local/${item}/
attack_basic:
foreach:
- enron_spam
......@@ -86,11 +182,11 @@ stages:
wdir: .
cmd: >-
PYTHONPATH=. python experiments/scripts/attack.py
--dataset_name ${item}
--dataset_name ${item} --attack_type attack_basic
deps:
- experiments/scripts/attack.py
- data/models/${item}
- data/preprocessed/${item}
- data/classification/${item}
outs:
- data/results/attack_basic/${item}/
"""Script for running attacks on datasets."""
import importlib
import json
from collections import defaultdict
import click
import pandas as pd
import os
from tqdm import tqdm
from text_attacks.utils import get_classify_function
from textfooler import Attack, TextFooler, BaseLine, process
import torch
from tqdm import tqdm
from textfooler import Attack, TextFooler, Similarity, BaseLine, \
process, run_queue, filter_similarity_queue, spoil_queue
from time import sleep, time
from multiprocessing import Process
from multiprocessing import Queue, Manager
from threading import Thread
from sklearn.metrics import classification_report, confusion_matrix
import numpy as np
TEXT = "text"
LEMMAS = "lemmas"
TAGS = "tags"
ORTHS = "orths"
ID = "id"
PRED = "pred_label"
ATTACK_SUMMARY = "attacks_summary"
ATTACK_SUCCEEDED = "attacks_succeeded"
......@@ -24,6 +36,14 @@ EXPECTED = "expected"
ACTUAL = "actual"
COSINE_SCORE = "cosine_score"
CLASS = "class"
QUEUE_SIZE = 1000
FEATURES = "features"
IMPORTANCE = "importance"
SYNONYM = "synonym"
DISCARD = "discard"
GLOBAL = "global"
LOCAL = "local"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
......@@ -35,13 +55,126 @@ DEFAULT_RES = {
}
def data_producer(queue_out, dataset_df):
    """Producer stage of the attack pipeline.

    Pushes one work item per dataset row onto ``queue_out`` in the form
    ``[sentence, orths, [], lemmas, tags, sent_id, y_pred]`` (the empty list
    slot is filled by later stages).
    """
    subset = dataset_df[[TEXT, ID, LEMMAS, TAGS, ORTHS, PRED]]
    for _, row in tqdm(subset.iterrows(), total=len(dataset_df)):
        sentence = row[0]
        sent_id = row[1]
        lemmas = row[2]
        tags = row[3]
        orths = row[4]
        y_pred = row[5]
        queue_out.put([sentence, orths, [], lemmas, tags, sent_id, y_pred])
def data_saver(queue_in, queue_log, output_file, output_dir, cases_nbr, queues_kill, to_kill_nbr):
    """Final consumer stage: gather attack results and write outputs.

    Reads ``(spoiled, class_test, class_pred)`` tuples from ``queue_in`` until a
    ``None`` sentinel arrives, then writes the spoiled sentences as JSON lines to
    ``output_file`` and a confusion matrix + classification report to
    ``<output_dir>/metrics.txt``.

    Args:
        queue_in: queue of result tuples; ``None`` is the shutdown sentinel.
        queue_log: queue receiving progress/log messages.
        output_file: path of the JSON-lines results file.
        output_dir: directory that receives ``metrics.txt``.
        cases_nbr: total number of cases expected from ``queue_in``.
        queues_kill: all pipeline queues to flood with ``None`` sentinels so the
            other worker processes shut down.
        to_kill_nbr: number of ``None`` sentinels pushed per queue.
    """
    processed_nbr, start = 0, time()
    item = 1  # non-None placeholder so the while loop body runs at least once
    test_y, pred_y = [], []
    spoiled_sents = []
    ch_suc, ch_all = 0, 0  # running totals of succeeded / attempted attacks
    end_time = time()
    while item is not None:
        item = queue_in.get()  # blocks until a result (or sentinel) arrives
        if item is not None:
            processed_nbr += 1
            spoiled, class_test, class_pred = item
            test_y.append(class_test)
            pred_y.append(class_pred)
            queue_log.put(f"Processed and saved {processed_nbr} in {time() - start} s")
            # SUCCEEDED / ALL are module-level summary keys defined elsewhere
            # in this file (not visible in this chunk).
            ch_suc += spoiled[ATTACK_SUMMARY][SUCCEEDED]
            ch_all += spoiled[ATTACK_SUMMARY][ALL]
            spoiled_sents.append(spoiled)
        if processed_nbr == cases_nbr:
            # All expected cases were processed: signal every worker to stop.
            for que_kill in queues_kill:
                [que_kill.put(None) for _ in range(to_kill_nbr)]
        if processed_nbr == cases_nbr - 10:
            end_time = time()  # start the straggler-timeout clock
        if processed_nbr >= cases_nbr - 10:
            # Fallback shutdown: if the last ~10 cases stall for over an hour
            # while every queue is drained, kill the workers anyway.
            # NOTE(review): qsize() is approximate on multiprocessing queues;
            # this is a best-effort heuristic, not an exact emptiness check.
            if sum([q.qsize() for q in queues_kill]) == 0 and (time() - end_time) > 3600:
                for que_kill in queues_kill:
                    [que_kill.put(None) for _ in range(to_kill_nbr)]
    # Results: one JSON object per spoiled sentence, JSON-lines format.
    with open(output_file, 'wt') as fd:
        fd.write(pd.DataFrame(spoiled_sents).to_json(
            orient="records", lines=True))
    # Metrics: confusion matrix first, then the text report appended below it.
    np.savetxt(f"{output_dir}/metrics.txt", confusion_matrix(test_y, pred_y))
    with open(f"{output_dir}/metrics.txt", mode="at") as fd:
        fd.write('\n')
        fd.write(classification_report(test_y, pred_y))
        fd.write('\n')
        fd.write(f"succeeded {ch_suc} all {ch_all}")
def classify_queue(queue_in, queue_out, queue_log, dataset_name, cuda_device):
    """Worker stage: classify batches of changed sentences.

    Pulls ``(sent_id, org_sentence, y_pred, changed_sents)`` items from
    ``queue_in``, classifies the texts of ``changed_sents`` with the
    dataset-specific model, and pushes
    ``(sent_id, org_sentence, changed_sents, y_pred, classified)`` to
    ``queue_out``. A ``None`` item is the shutdown sentinel.

    Args:
        queue_in: input queue of work items.
        queue_out: output queue for classified items.
        queue_log: queue receiving log messages.
        dataset_name: used to locate ``text_attacks.models.<dataset_name>``.
        cuda_device: value assigned to ``CUDA_VISIBLE_DEVICES`` to pin this
            worker process to a specific GPU.
    """
    # Restrict visible GPUs before the model is loaded in this process.
    # NOTE(review): torch is imported at module level; assumes CUDA has not
    # been initialized earlier in this process — confirm.
    os.environ["CUDA_VISIBLE_DEVICES"] = cuda_device
    fun = getattr(
        importlib.import_module(f"text_attacks.models.{dataset_name}"),
        "get_classify_function",
    )
    classify_fun = fun(device="cuda" if torch.cuda.is_available() else "cpu")
    queue_log.put(f"Classify device {'cuda' if torch.cuda.is_available() else 'cpu'}")
    item = True  # non-None placeholder so the loop body runs at least once
    while item is not None:
        item = queue_in.get()
        queue_log.put("Classify got from queue")
        if item is not None:
            sent_id, org_sentence, y_pred, changed_sents = item
            sentences = [sent[TEXT] for sent in changed_sents]
            queue_log.put(f"Classifying sentences {len(sentences)}, id {sent_id}")
            # Skip the model call entirely when there is nothing to classify.
            classified = classify_fun(sentences) if sentences else []
            queue_out.put((sent_id, org_sentence, changed_sents, y_pred, classified))
            queue_log.put(f"Classified sentences {sent_id}")
def log_queues(queues):
    """Print the size of every pipeline queue every 10 seconds.

    Runs forever; intended to be launched as a daemon thread.
    """
    while True:
        print([q.qsize() for q in queues], flush=True)
        sleep(10)
def log_info_queue(queue):
    """Echo every message arriving on the shared log queue to stdout.

    Runs forever; intended to be launched as a daemon thread.
    """
    print("Logging queue")
    while True:
        print(queue.get())
def load_dir_files(dir_path):
    """Load every JSON explanation file in ``dir_path``.

    Returns a dict keyed by the part of each filename before ``"__"``; each
    value maps a (non-empty) feature word to its importance score.
    """
    loaded = {}
    for fname in os.listdir(dir_path):
        with open(os.path.join(dir_path, fname), 'r') as handle:
            data = json.load(handle)
        # Feature index -> word; keep only non-empty words.
        loaded[fname.split("__")[0]] = {
            token: data[IMPORTANCE][idx]
            for idx, token in data[FEATURES].items()
            if token
        }
    return loaded
def load_xai_importance(input_dir):
    """Load global and local XAI importance maps from ``input_dir``.

    Returns a pair ``(global_map, local_map)`` where ``global_map`` comes from
    ``<input_dir>/global`` and ``local_map`` is a ``file -> class -> scores``
    mapping built by pivoting the per-class dirs under ``<input_dir>/local/test``.
    """
    global_dir = os.path.join(input_dir, "global")
    local_root = os.path.join(input_dir, "local", "test")
    # class -> {file -> importance scores}
    per_class = {
        class_name: load_dir_files(os.path.join(local_root, class_name))
        for class_name in os.listdir(local_root)
    }
    # Pivot to file -> {class -> importance scores}.
    per_file = defaultdict(dict)
    for class_name, files in per_class.items():
        for file_name, scores in files.items():
            per_file[file_name][class_name] = scores
    return load_dir_files(global_dir), per_file
@click.command()
@click.option(
"--dataset_name",
help="Dataset name",
type=str,
)
def main(dataset_name: str):
@click.option(
"--attack_type",
help="Attack type",
type=str,
)
def main(dataset_name: str, attack_type: str):
"""Downloads the dataset to the output directory."""
lang = {
"enron_spam": "en",
......@@ -49,32 +182,89 @@ def main(dataset_name: str):
"20_news": "en",
"wiki_pl": "pl",
}[dataset_name]
output_dir = f"data/results/{dataset_name}"
input_file = f"data/preprocessed/{dataset_name}/test.jsonl"
xai_global, xai_local = {}, {}
if "attack_xai" in attack_type:
importance = load_xai_importance(f"data/explanations/{dataset_name}")
xai_global, xai_local = importance[0], importance[1]
xai_sub = 5
params = {
"attack_textfooler": [lang, SYNONYM],
"attack_textfooler_discard": [lang, DISCARD],
"attack_basic": [lang, 0.5, 0.4, 0.3], # prawopodobieństwa spacji > usunięcia znaku > usunięcia słowa
"attack_xai": [lang, xai_global, xai_local, GLOBAL, SYNONYM, xai_sub],
"attack_xai_discard": [lang, xai_global, xai_local, GLOBAL, DISCARD, xai_sub],
"attack_xai_local": [lang, xai_global, xai_local, LOCAL, SYNONYM, xai_sub],
"attack_xai_discard_local": [lang, xai_global, xai_local, LOCAL, DISCARD, xai_sub]
}[attack_type]
output_dir = f"data/results/{attack_type}/{dataset_name}/"
input_file = f"data/classification/{dataset_name}/test.jsonl"
os.makedirs(output_dir, exist_ok=True)
output_path = os.path.join(output_dir, "test.jsonl")
classify = get_classify_function(dataset_name=dataset_name)
dataset_df = pd.read_json(input_file, lines=True)
spoiled, results = [], []
similarity, max_sub = 0.95, 1
classes = classify(dataset_df[TEXT].tolist())
attack = TextFooler(lang)
max_sub = 1
for i, cols in tqdm(
dataset_df[[TEXT, LEMMAS, TAGS, ORTHS]].iterrows(), total=len(dataset_df)
):
sentence, lemmas, tags, orths = cols[0], cols[1], cols[2], cols[3]
changed_sent = attack.spoil(sentence, [], lemmas, tags, orths, similarity, max_sub)
if changed_sent:
spoiled.append(process(changed_sent, classes[i], classify))
with open(output_path, mode="wt") as fd:
fd.write(
pd.DataFrame({"spoiled": spoiled}).to_json(
orient="records", lines=True
)
)
m = Manager()
queues = [m.Queue(maxsize=QUEUE_SIZE) for _ in range(6)]
sim = Similarity(queues[5], 0.95, "distiluse-base-multilingual-cased-v1")
processes = [
Process(target=data_producer, args=(queues[0], dataset_df,)), # loading data file_in -> 0
Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, params)),
Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, params)),
Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, params)),
Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, params)),
Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, params)),
Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, params)),
Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, params)),
Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, params)),
Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, params)),
Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, params)),
Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, params)),
Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, params)),
Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, params)),
Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, params)),
Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, params)),
Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, params)),
Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, params)),
Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, params)),
Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, params)),
Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, params)),
Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, params)),
Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, params)),
Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, params)),
Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, params)),
Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, params)),
Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, params)),
# Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, params)),
# Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, params)),
# Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, params)),
# Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, params)),
# Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, params)),
Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, params)),
Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, params)),
# spoiling 0 -> 1
Process(target=filter_similarity_queue, args=(queues[1], queues[2], queues[5], sim)),
Process(target=filter_similarity_queue, args=(queues[1], queues[2], queues[5], sim)), # cosim 1 -> 2
Process(target=classify_queue, args=(queues[2], queues[3], queues[5], dataset_name, "6")),
Process(target=classify_queue, args=(queues[2], queues[3], queues[5], dataset_name, "4")),
# classify changed 2 -> 3
Process(target=run_queue, args=(queues[3], queues[4], queues[5], process,)), # process 3 -> 4
Process(target=data_saver, args=(queues[4], queues[5], output_path, output_dir, len(dataset_df), queues, 30))
# saving 4 -> file_out
]
[p.start() for p in processes]
log_que = Thread(target=log_queues, args=(queues[:5],))
log_que.daemon = True
log_que.start()
info_que = Thread(target=log_info_queue, args=(queues[5],))
info_que.daemon = True
info_que.start()
# wait for all processes to finish
[p.join() for p in processes]
log_que.join(timeout=0.5)
info_que.join(timeout=0.5)
if __name__ == "__main__":
......
......@@ -23,7 +23,7 @@ def get_classify_function(device="cpu"):
logits = list()
i = 0
for chunk in tqdm(
[texts[pos:pos + 256] for pos in range(0, len(texts), 256)]
[texts[pos:pos + 128] for pos in range(0, len(texts), 128)]
):
encoded_inputs = tokenizer(
chunk,
......
......@@ -23,7 +23,7 @@ def get_classify_function(device="cpu"):
logits = list()
i = 0
for chunk in tqdm(
[texts[pos:pos + 256] for pos in range(0, len(texts), 256)]
[texts[pos:pos + 128] for pos in range(0, len(texts), 128)]
):
encoded_inputs = tokenizer(
chunk,
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment