diff --git a/data/results/attack_basic/.gitignore b/data/results/attack_basic/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..bef990db996c16e172365e026cfc73df61bafa60
--- /dev/null
+++ b/data/results/attack_basic/.gitignore
@@ -0,0 +1 @@
+/wiki_pl
diff --git a/data/results/attack_textfooler/.gitignore b/data/results/attack_textfooler/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..e69587204a79358e2809c210d36b6519d6be3b81
--- /dev/null
+++ b/data/results/attack_textfooler/.gitignore
@@ -0,0 +1,3 @@
+/enron_spam
+/wiki_pl
+/20_news
diff --git a/data/results/attack_textfooler_discard/.gitignore b/data/results/attack_textfooler_discard/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..70a2cd627e76cfd8acff2853e1e4706c56373b8c
--- /dev/null
+++ b/data/results/attack_textfooler_discard/.gitignore
@@ -0,0 +1,2 @@
+/wiki_pl
+/enron_spam
diff --git a/dvc.lock b/dvc.lock
index 18c0a65a0f2917d3a74fb3b5ea84d4e9abde9187..93146df62808b9fffe01589b24079b85e22f3e7a 100644
--- a/dvc.lock
+++ b/dvc.lock
@@ -17,8 +17,8 @@ stages:
       --output_dir data/models/enron_spam
     deps:
     - path: data/preprocessed/enron_spam
-      md5: b75efba1a62182dc8ac32acd1faf92ed.dir
-      size: 61709260
+      md5: 30c63efbc615347ddcb5f61e011113bd.dir
+      size: 65971374
       nfiles: 3
     - path: experiments/scripts/get_model.py
       md5: 5050f51b4019bba97af47971f6c7cab4
@@ -37,16 +37,16 @@
       size: 18505614
       nfiles: 6
     - path: data/preprocessed/enron_spam/
-      md5: b75efba1a62182dc8ac32acd1faf92ed.dir
-      size: 61709260
+      md5: 30c63efbc615347ddcb5f61e011113bd.dir
+      size: 65971374
       nfiles: 3
     - path: experiments/scripts/classify.py
       md5: 6fc1a6a0a11ba6cd99a8b6625a96d9f5
       size: 1181
     outs:
     - path: data/classification/enron_spam
-      md5: 0450c0b672bc4a5db3cc7be2dac786bd.dir
-      size: 10674882
+      md5: 5de1a2fcbae0de94f5fbfd2bb747d919.dir
+      size: 14585920
       nfiles: 2
   explain@enron_spam:
     cmd: PYTHONPATH=. python experiments/scripts/explain.py --dataset_name enron_spam
@@ -134,16 +134,16 @@
       size: 501711136
       nfiles: 7
     - path: data/preprocessed/wiki_pl/
-      md5: 3e9b2e1e0542777e0a751d9d7f7f4241.dir
-      size: 55380570
+      md5: 0014b9bb52913cbc9a568d237ea2207b.dir
+      size: 65553079
       nfiles: 3
     - path: experiments/scripts/classify.py
       md5: 6fc1a6a0a11ba6cd99a8b6625a96d9f5
       size: 1181
     outs:
     - path: data/classification/wiki_pl
-      md5: 515330772505f489b55686545bcf23a0.dir
-      size: 34103198
+      md5: 88c3cea96b2cb3ddda1a82037bf6130a.dir
+      size: 44196727
       nfiles: 2
   preprocess_dataset@20_news:
     cmd: PYTHONPATH=. python experiments/scripts/tag_dataset.py --dataset_name 20_news
@@ -169,16 +169,16 @@
       size: 439008808
       nfiles: 6
     - path: data/preprocessed/20_news/
-      md5: 1ed5ef2dabe4bc05f7377175ed11137b.dir
-      size: 46845669
+      md5: 20da0980e52df537e5b7ca5db0305879.dir
+      size: 58582060
       nfiles: 3
     - path: experiments/scripts/classify.py
       md5: 6fc1a6a0a11ba6cd99a8b6625a96d9f5
       size: 1181
     outs:
     - path: data/classification/20_news
-      md5: 6831f104f7c20541548fe72250c45706.dir
-      size: 31286120
+      md5: b73611443c4189af91b827c083f37e0b.dir
+      size: 42897496
       nfiles: 2
   attack_basic@enron_spam:
     cmd: PYTHONPATH=. python experiments/scripts/attack.py --dataset_name enron_spam
@@ -239,3 +239,123 @@
       md5: c8ba90f9757a4e3cc4843d3791ef2446.dir
       size: 232912969
       nfiles: 14041
+  attack_basic@wiki_pl:
+    cmd: PYTHONPATH=. python experiments/scripts/attack.py --dataset_name wiki_pl
+      --attack_type attack_basic
+    deps:
+    - path: data/models/wiki_pl
+      md5: fd453042628fb09c080ef05d34a32cce.dir
+      size: 501711136
+      nfiles: 7
+    - path: data/preprocessed/wiki_pl
+      md5: 0014b9bb52913cbc9a568d237ea2207b.dir
+      size: 65553079
+      nfiles: 3
+    - path: experiments/scripts/attack.py
+      md5: 702997933e5af85d09d8286a14e2cc05
+      size: 2486
+    outs:
+    - path: data/results/attack_basic/wiki_pl/
+      md5: f118a41e391b5f713f77611140f2f2cc.dir
+      size: 1
+      nfiles: 1
+  attack_textfooler@enron_spam:
+    cmd: 'PYTHONPATH=. python experiments/scripts/attack.py --dataset_name enron_spam
+      --attack_type attack_textfooler '
+    deps:
+    - path: data/models/enron_spam
+      md5: 3e16b22f59532c66beeadea958e0579a.dir
+      size: 18505614
+      nfiles: 6
+    - path: data/preprocessed/enron_spam
+      md5: 30c63efbc615347ddcb5f61e011113bd.dir
+      size: 65971374
+      nfiles: 3
+    - path: experiments/scripts/attack.py
+      md5: b9d9a4d9fcba1cb4dfbb554ecc3e26fb
+      size: 10083
+    outs:
+    - path: data/results/attack_textfooler/enron_spam/
+      md5: 10ecd4c940e8df1058465048ffbe78d4.dir
+      size: 3291044
+      nfiles: 2
+  attack_textfooler@20_news:
+    cmd: 'PYTHONPATH=. python experiments/scripts/attack.py --dataset_name 20_news
+      --attack_type attack_textfooler '
+    deps:
+    - path: data/models/20_news
+      md5: 43d68a67ecb8149bd6bf50db9767cb64.dir
+      size: 439008808
+      nfiles: 6
+    - path: data/preprocessed/20_news
+      md5: 20da0980e52df537e5b7ca5db0305879.dir
+      size: 58582060
+      nfiles: 3
+    - path: experiments/scripts/attack.py
+      md5: 4fe9c6210ce0f3be66b54c2565ad2daa
+      size: 10132
+    outs:
+    - path: data/results/attack_textfooler/20_news/
+      md5: 007aba16e343ca283180c7bc7b9a0190.dir
+      size: 93666157
+      nfiles: 2
+  attack_textfooler@wiki_pl:
+    cmd: 'PYTHONPATH=. python experiments/scripts/attack.py --dataset_name wiki_pl
+      --attack_type attack_textfooler '
+    deps:
+    - path: data/classification/wiki_pl
+      md5: 88c3cea96b2cb3ddda1a82037bf6130a.dir
+      size: 44196727
+      nfiles: 2
+    - path: data/models/wiki_pl
+      md5: fd453042628fb09c080ef05d34a32cce.dir
+      size: 501711136
+      nfiles: 7
+    - path: experiments/scripts/attack.py
+      md5: 2977363ba8806c393498f98d5733c013
+      size: 11497
+    outs:
+    - path: data/results/attack_textfooler/wiki_pl/
+      md5: eccc12b9a5ae383ea02067cd1955753e.dir
+      size: 20293404
+      nfiles: 2
+  attack_textfooler_discard@wiki_pl:
+    cmd: PYTHONPATH=. python experiments/scripts/attack.py --dataset_name wiki_pl
+      --attack_type attack_textfooler_discard
+    deps:
+    - path: data/classification/wiki_pl
+      md5: 88c3cea96b2cb3ddda1a82037bf6130a.dir
+      size: 44196727
+      nfiles: 2
+    - path: data/models/wiki_pl
+      md5: fd453042628fb09c080ef05d34a32cce.dir
+      size: 501711136
+      nfiles: 7
+    - path: experiments/scripts/attack.py
+      md5: 2b9ddc1ff1f56855ff667171ba04ed78
+      size: 11606
+    outs:
+    - path: data/results/attack_textfooler_discard/wiki_pl/
+      md5: e41122c3cdf76ad1b163aba49acce0f0.dir
+      size: 14396685
+      nfiles: 2
+  attack_textfooler_discard@enron_spam:
+    cmd: PYTHONPATH=. python experiments/scripts/attack.py --dataset_name enron_spam
+      --attack_type attack_textfooler_discard
+    deps:
+    - path: data/classification/enron_spam
+      md5: 5de1a2fcbae0de94f5fbfd2bb747d919.dir
+      size: 14585920
+      nfiles: 2
+    - path: data/models/enron_spam
+      md5: 3e16b22f59532c66beeadea958e0579a.dir
+      size: 18505614
+      nfiles: 6
+    - path: experiments/scripts/attack.py
+      md5: 2b9ddc1ff1f56855ff667171ba04ed78
+      size: 11606
+    outs:
+    - path: data/results/attack_textfooler_discard/enron_spam/
+      md5: 8a78484bd77916f82021a72338342a44.dir
+      size: 2816160
+      nfiles: 2
diff --git a/dvc.yaml b/dvc.yaml
index 6862a88561c9e3c9841ae258749e018d430215ad..03c045f1148e592ae0fa0bb3a5da13f85af1d9ff 100644
--- a/dvc.yaml
+++ b/dvc.yaml
@@ -77,6 +77,38 @@
         - data/preprocessed/${item}
       outs:
         - data/explanations/${item}/
+  attack_textfooler:
+    foreach:
+      - enron_spam
+      - 20_news
+      - wiki_pl
+    do:
+      wdir: .
+      cmd: >-
+        PYTHONPATH=. python experiments/scripts/attack.py
+        --dataset_name ${item} --attack_type attack_textfooler
+      deps:
+        - experiments/scripts/attack.py
+        - data/models/${item}
+        - data/classification/${item}
+      outs:
+        - data/results/attack_textfooler/${item}/
+  attack_textfooler_discard:
+    foreach:
+      - enron_spam
+      - 20_news
+      - wiki_pl
+    do:
+      wdir: .
+      cmd: >-
+        PYTHONPATH=. python experiments/scripts/attack.py
+        --dataset_name ${item} --attack_type attack_textfooler_discard
+      deps:
+        - experiments/scripts/attack.py
+        - data/models/${item}
+        - data/classification/${item}
+      outs:
+        - data/results/attack_textfooler_discard/${item}/
   attack_basic:
     foreach:
       - enron_spam
@@ -86,11 +118,11 @@
       wdir: .
      cmd: >-
        PYTHONPATH=. python experiments/scripts/attack.py
-        --dataset_name ${item}
+        --dataset_name ${item} --attack_type attack_basic
      deps:
        - experiments/scripts/attack.py
        - data/models/${item}
-        - data/preprocessed/${item}
+        - data/classification/${item}
      outs:
        - data/results/attack_basic/${item}/
diff --git a/experiments/scripts/attack.py b/experiments/scripts/attack.py
index 721d46ffebbfa51d4fd9289e73729b54f6d8c397..78f97ca5604985d42b61d1b8246299a7a0798fa3 100644
--- a/experiments/scripts/attack.py
+++ b/experiments/scripts/attack.py
@@ -1,7 +1,7 @@
 """Script for running attacks on datasets."""
 import importlib
 import json
-
+from collections import defaultdict
 import click
 import pandas as pd
 import os
@@ -22,6 +22,7 @@ LEMMAS = "lemmas"
 TAGS = "tags"
 ORTHS = "orths"
 ID = "id"
+PRED = "pred_label"
 
 ATTACK_SUMMARY = "attacks_summary"
 ATTACK_SUCCEEDED = "attacks_succeeded"
@@ -36,6 +37,10 @@ ACTUAL = "actual"
 COSINE_SCORE = "cosine_score"
 CLASS = "class"
 QUEUE_SIZE = 1000
+FEATURES = "features"
+IMPORTANCE = "importance"
+SYNONYM = "synonym"
+DISCARD = "discard"
 
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
@@ -49,19 +54,20 @@ DEFAULT_RES = {
 
 def data_producer(queue_out, dataset_df):
     for i, cols in tqdm(
-            dataset_df[[TEXT, ID, LEMMAS, TAGS, ORTHS]].iterrows(), total=len(dataset_df)
+            dataset_df[[TEXT, ID, LEMMAS, TAGS, ORTHS, PRED]].iterrows(), total=len(dataset_df)
     ):
-        sentence, sent_id, lemmas, tags, orths = cols[0], cols[1], \
-            cols[2], cols[3], cols[4]
-        queue_out.put([sentence, orths, [], lemmas, tags, sent_id])
+        sentence, sent_id, lemmas, tags, orths, y_pred = cols[0], cols[1], \
+            cols[2], cols[3], cols[4], cols[5]
+        queue_out.put([sentence, orths, [], lemmas, tags, sent_id, y_pred])
 
 
-def data_saver(queue_in, queue_log, output_file, output_dir, cases_nbr, que_kill, to_kill_nbr):
+def data_saver(queue_in, queue_log, output_file, output_dir, cases_nbr, queues_kill, to_kill_nbr):
     processed_nbr, start = 0, time()
     item = 1
     test_y, pred_y = [], []
     spoiled_sents = []
     ch_suc, ch_all = 0, 0
+    end_time = time()
     while item is not None:
         item = queue_in.get()
         if item is not None:
@@ -74,13 +80,17 @@ def data_saver(queue_in, queue_log, output_file, output_dir, cases_nbr, que_kill
             ch_all += spoiled[ATTACK_SUMMARY][ALL]
             spoiled_sents.append(spoiled)
             if processed_nbr == cases_nbr:
-                [que_kill.put(None) for _ in range(to_kill_nbr)]
-    with open(output_file, 'a') as fd:
-        fd.write(
-            pd.DataFrame(spoiled_sents).to_json(
-                orient="records", lines=True
-            )
-        )
+                for que_kill in queues_kill:
+                    [que_kill.put(None) for _ in range(to_kill_nbr)]
+            if processed_nbr == cases_nbr - 10:
+                end_time = time()
+            if processed_nbr >= cases_nbr - 10:
+                if sum([q.qsize() for q in queues_kill]) == 0 and (time() - end_time) > 3600:
+                    for que_kill in queues_kill:
+                        [que_kill.put(None) for _ in range(to_kill_nbr)]
+    with open(output_file, 'wt') as fd:
+        fd.write(pd.DataFrame(spoiled_sents).to_json(
+            orient="records", lines=True))
     np.savetxt(f"{output_dir}/metrics.txt", confusion_matrix(test_y, pred_y))
     with open(f"{output_dir}/metrics.txt", mode="at") as fd:
         fd.write('\n')
@@ -89,7 +99,8 @@ def data_saver(queue_in, queue_log, output_file, output_dir, cases_nbr, que_kill
         fd.write(f"succeeded {ch_suc} all {ch_all}")
 
 
-def classify_queue(queue_in, queue_out, queue_log, dataset_name):
+def classify_queue(queue_in, queue_out, queue_log, dataset_name, cuda_device):
+    os.environ["CUDA_VISIBLE_DEVICES"] = cuda_device
     fun = getattr(
         importlib.import_module(f"text_attacks.models.{dataset_name}"),
         "get_classify_function",
@@ -101,21 +112,19 @@ def classify_queue(queue_in, queue_out, queue_log, dataset_name):
         item = queue_in.get()
         queue_log.put("Classify got from queue")
         if item is not None:
-            sent_id, org_sentence, changed_sents = item
-            sentences = [org_sentence]
-            sentences.extend([sent[TEXT] for sent in changed_sents])
+            sent_id, org_sentence, y_pred, changed_sents = item
+            sentences = [sent[TEXT] for sent in changed_sents]
             queue_log.put(f"Classifying sentences {len(sentences)}, id {sent_id}")
-            classified = classify_fun(sentences)
-            queue_out.put((sent_id, org_sentence, changed_sents, classified))
+            classified = classify_fun(sentences) if sentences else []
+            queue_out.put((sent_id, org_sentence, changed_sents, y_pred, classified))
             queue_log.put(f"Classified sentences {sent_id}")
-    queue_out.put(None)
 
 
 def log_queues(queues):
     while True:
         sizes = [q.qsize() for q in queues]
         print(sizes, flush=True)
-        sleep(2)
+        sleep(10)
 
 
 def log_info_queue(queue):
@@ -125,6 +134,31 @@
         print(item)
 
 
+def load_dir_files(dir_path):
+    result = {}
+    for filename in os.listdir(dir_path):
+        with open(os.path.join(dir_path, filename), 'r') as fin:
+            importance = json.load(fin)
+            result[filename.split("__")[0]] = {
+                word: importance[IMPORTANCE][idx]
+                for idx, word in importance[FEATURES].items()
+                if word
+            }
+    return result
+
+def load_xai_importance(input_dir):
+    global_xai_dir = os.path.join(input_dir, "global")
+    local_xai_dir = os.path.join(input_dir, "local")
+    local_dirs = os.listdir(local_xai_dir)
+    local_class_to_file = {dir_name: load_dir_files(os.path.join(local_xai_dir, dir_name))
+                           for dir_name in local_dirs}
+    local_file_to_class = defaultdict(dict)
+    for c_name in local_dirs:
+        for f_name, value_df in local_class_to_file[c_name].items():
+            local_file_to_class[f_name][c_name] = value_df
+    return load_dir_files(global_xai_dir), local_file_to_class
+
+
 @click.command()
 @click.option(
     "--dataset_name",
@@ -145,42 +179,77 @@ def main(dataset_name: str, attack_type: str):
         "wiki_pl": "pl",
     }[dataset_name]
     params = {
-        "attack_textfooler": [lang],
-        "attack_basic": [lang, 0.5, 0.4, 0.3]
+        "attack_textfooler": [lang, SYNONYM],
+        "attack_textfooler_discard": [lang, DISCARD],
+        "attack_basic": [lang, 0.5, 0.4, 0.3]  # probabilities: space > character deletion > word deletion
     }[attack_type]
     output_dir = f"data/results/{attack_type}/{dataset_name}/"
-    input_file = f"data/preprocessed/{dataset_name}/test.jsonl"
+    input_file = f"data/classification/{dataset_name}/test.jsonl"
     os.makedirs(output_dir, exist_ok=True)
     output_path = os.path.join(output_dir, "test.jsonl")
     dataset_df = pd.read_json(input_file, lines=True)
+
+    # xai_global, xai_local = load_xai_importance(
+    #     f"data/explanations/{dataset_name}"
+    # ) if attack_type == "attack_xai" else {}, {}
+
     max_sub = 1
     m = Manager()
     queues = [m.Queue(maxsize=QUEUE_SIZE) for _ in range(6)]
     sim = Similarity(queues[5], 0.95, "distiluse-base-multilingual-cased-v1")
-    processes = [Process(target=data_producer, args=(queues[0], dataset_df,)),  # loading data file_in -> 0
-                 Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, params)),
-                 Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, params)),
-                 Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, params)),
-                 Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, params)),
-                 Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, params)),
-                 Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, params)),
-                 Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, params)),
-                 Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, params)),
-                 Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, params)),
-                 Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, params)),  # spoiling 0 -> 1
-                 Process(target=filter_similarity_queue, args=(queues[1], queues[2], queues[5], sim)),  # cosim 1 -> 2
-                 Process(target=classify_queue, args=(queues[2], queues[3], queues[5], dataset_name, )),  # classify changed 2 -> 3
-                 Process(target=run_queue, args=(queues[3], queues[4], queues[5], process,)),  # process 3 -> 4
-                 Process(target=data_saver, args=(queues[4], queues[5], output_path, output_dir, len(dataset_df),
-                                                  queues[0], 11))]  # saving 4 -> file_out
+    processes = [
+        Process(target=data_producer, args=(queues[0], dataset_df,)),  # loading data file_in -> 0
+        Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, params)),
+        Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, params)),
+        Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, params)),
+        Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, params)),
+        Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, params)),
+        Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, params)),
+        Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, params)),
+        Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, params)),
+        Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, params)),
+        Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, params)),
+        Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, params)),
+        Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, params)),
+        Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, params)),
+        Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, params)),
+        Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, params)),
+        Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, params)),
+        Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, params)),
+        Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, params)),
+        Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, params)),
+        Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, params)),
+        Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, params)),
+        Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, params)),
+        Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, params)),
+        Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, params)),
+        Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, params)),
+        Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, params)),
+        # Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, params)),
+        # Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, params)),
+        # Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, params)),
+        # Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, params)),
+        # Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, params)),
+        Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, params)),
+        Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, params)),
+        # spoiling 0 -> 1
+        Process(target=filter_similarity_queue, args=(queues[1], queues[2], queues[5], sim)),
+        Process(target=filter_similarity_queue, args=(queues[1], queues[2], queues[5], sim)),  # cosim 1 -> 2
+        Process(target=classify_queue, args=(queues[2], queues[3], queues[5], dataset_name, "6")),
+        Process(target=classify_queue, args=(queues[2], queues[3], queues[5], dataset_name, "4")),
+        # classify changed 2 -> 3
+        Process(target=run_queue, args=(queues[3], queues[4], queues[5], process,)),  # process 3 -> 4
+        Process(target=data_saver, args=(queues[4], queues[5], output_path, output_dir, len(dataset_df), queues, 30))
+        # saving 4 -> file_out
+    ]
     [p.start() for p in processes]
 
-    log_que = Thread(target=log_queues, args=(queues[:5], ))
+    log_que = Thread(target=log_queues, args=(queues[:5],))
     log_que.daemon = True
     log_que.start()
 
-    info_que = Thread(target=log_info_queue, args=(queues[5], ))
+    info_que = Thread(target=log_info_queue, args=(queues[5],))
     info_que.daemon = True
     info_que.start()
     # wait for all processes to finish
diff --git a/text_attacks/models/20_news.py b/text_attacks/models/20_news.py
index 53712fa403a55b03a12aaf6962cdfcd6f4c503ce..8bebdac020574212b917f3a39e55c5b1f1703cb1 100644
--- a/text_attacks/models/20_news.py
+++ b/text_attacks/models/20_news.py
@@ -23,7 +23,7 @@ def get_classify_function(device="cpu"):
         logits = list()
         i = 0
         for chunk in tqdm(
-            [texts[pos:pos + 256] for pos in range(0, len(texts), 256)]
+            [texts[pos:pos + 128] for pos in range(0, len(texts), 128)]
         ):
             encoded_inputs = tokenizer(
                 chunk,
diff --git a/text_attacks/models/wiki_pl.py b/text_attacks/models/wiki_pl.py
index 1ad153955d6e5acfc89f4f922465fb624c1ecf5d..18e9f9f5ee263dfaa6a12183d8e179cf5f50adf5 100644
--- a/text_attacks/models/wiki_pl.py
+++ b/text_attacks/models/wiki_pl.py
@@ -23,7 +23,7 @@ def get_classify_function(device="cpu"):
         logits = list()
         i = 0
         for chunk in tqdm(
-            [texts[pos:pos + 256] for pos in range(0, len(texts), 256)]
+            [texts[pos:pos + 128] for pos in range(0, len(texts), 128)]
         ):
             encoded_inputs = tokenizer(
                 chunk,