diff --git a/data/classification/.gitignore b/data/classification/.gitignore index 27e43f9f65497c9f7467263b0282d95ee264628b..24679632493487acc23ca16bcf24175dee4a9cb1 100644 --- a/data/classification/.gitignore +++ b/data/classification/.gitignore @@ -2,3 +2,4 @@ /wiki_pl /20_news /poleval +/ag_news diff --git a/data/datasets/.gitignore b/data/datasets/.gitignore index 43bd163e72f8f4c3216a67dd67aaa506107fbda5..0cb29715b77c64d56542b9c3c917f05a5066c7ad 100644 --- a/data/datasets/.gitignore +++ b/data/datasets/.gitignore @@ -2,3 +2,4 @@ /20_news /poleval /wiki_pl +/ag_news diff --git a/data/explanations/.gitignore b/data/explanations/.gitignore index 27e43f9f65497c9f7467263b0282d95ee264628b..24679632493487acc23ca16bcf24175dee4a9cb1 100644 --- a/data/explanations/.gitignore +++ b/data/explanations/.gitignore @@ -2,3 +2,4 @@ /wiki_pl /20_news /poleval +/ag_news diff --git a/data/models/.gitignore b/data/models/.gitignore index f37e1f64e859cabcf44bfc8c94822fd01f9fa0c9..3919f3faee5ad2083c721d146e04f59598fdfd86 100644 --- a/data/models/.gitignore +++ b/data/models/.gitignore @@ -2,3 +2,4 @@ /20_news /wiki_pl /poleval +/ag_news diff --git a/data/preprocessed/.gitignore b/data/preprocessed/.gitignore index 7fdd029740deec7ff19796ce6b503b9ca1a4c89b..e682bdb32d14353df93b156e4a08563cd15583e3 100644 --- a/data/preprocessed/.gitignore +++ b/data/preprocessed/.gitignore @@ -2,3 +2,4 @@ /enron_spam /wiki_pl /20_news +/ag_news diff --git a/data/reduced/.gitignore b/data/reduced/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..4b5836b852092e0470f6b49f5e4f5ff316896fcf --- /dev/null +++ b/data/reduced/.gitignore @@ -0,0 +1,4 @@ +/wiki_pl +/enron_spam +/poleval +/ag_news diff --git a/data/results/attack_xai_char_discard/.gitignore b/data/results/attack_xai_char_discard/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..d1220947ea19b5959f7c9b5ef897466f82de0956 --- /dev/null +++ b/data/results/attack_xai_char_discard/.gitignore @@ -0,0 +1,4 @@ +/poleval +/enron_spam +/20_news +/wiki_pl diff --git a/dvc.lock b/dvc.lock index bbd98c8f6e63f175cabfd778145467fe945bf8bd..9d048a30ca12a97d6bd72cecfed9fc0e2e0ff75f 100644 --- a/dvc.lock +++ b/dvc.lock @@ -36,17 +36,17 @@ stages: md5: 3e16b22f59532c66beeadea958e0579a.dir size: 18505614 nfiles: 6 - - path: data/preprocessed/enron_spam/ - md5: 99d604f84516cee94948054a97ffec5e.dir - size: 71403809 - nfiles: 3 + - path: data/reduced/enron_spam/ + md5: ee6f2c141cd68b86e620f022f0ca0b5a.dir + size: 12933383 + nfiles: 1 - path: experiments/scripts/classify.py - md5: 6fc1a6a0a11ba6cd99a8b6625a96d9f5 - size: 1181 + md5: 8c4dc8293bc7d7f8b87b4788cea1b81e + size: 1176 outs: - path: data/classification/enron_spam - md5: 5de1a2fcbae0de94f5fbfd2bb747d919.dir - size: 14585920 + md5: 7e0017fe7f10a3a8bbd2c3dcf355cb34.dir + size: 12968818 nfiles: 2 explain@enron_spam: cmd: PYTHONPATH=. 
python experiments/scripts/explain.py --dataset_name enron_spam @@ -88,8 +88,8 @@ stages: size: 1688836 nfiles: 3 - path: experiments/scripts/tag_dataset.py - md5: ebadced7a031a31bdaf935d2b22e5e05 - size: 4632 + md5: 19015179757440a5639ee263dcabfde3 + size: 5010 outs: - path: data/preprocessed/poleval/ md5: b0ea9f0ad1dba6d3b474c0a3cedf866e.dir @@ -103,8 +103,8 @@ stages: size: 53096069 nfiles: 3 - path: experiments/scripts/tag_dataset.py - md5: ebadced7a031a31bdaf935d2b22e5e05 - size: 4632 + md5: 19015179757440a5639ee263dcabfde3 + size: 5010 outs: - path: data/preprocessed/enron_spam/ md5: 99d604f84516cee94948054a97ffec5e.dir @@ -118,8 +118,8 @@ stages: size: 29115538 nfiles: 3 - path: experiments/scripts/tag_dataset.py - md5: ebadced7a031a31bdaf935d2b22e5e05 - size: 4632 + md5: 19015179757440a5639ee263dcabfde3 + size: 5010 outs: - path: data/preprocessed/wiki_pl/ md5: 066634606f832b6c9d1db95293de7e04.dir @@ -133,17 +133,17 @@ stages: md5: fd453042628fb09c080ef05d34a32cce.dir size: 501711136 nfiles: 7 - - path: data/preprocessed/wiki_pl/ - md5: 066634606f832b6c9d1db95293de7e04.dir - size: 77818549 - nfiles: 3 + - path: data/reduced/wiki_pl/ + md5: 30359a1d253a3c1cee7affa7ae365ef3.dir + size: 31644651 + nfiles: 1 - path: experiments/scripts/classify.py - md5: 6fc1a6a0a11ba6cd99a8b6625a96d9f5 - size: 1181 + md5: 8c4dc8293bc7d7f8b87b4788cea1b81e + size: 1176 outs: - path: data/classification/wiki_pl - md5: 88c3cea96b2cb3ddda1a82037bf6130a.dir - size: 44196727 + md5: 8455064b5b3e39ffc35d3ac712b41c2d.dir + size: 31721772 nfiles: 2 preprocess_dataset@20_news: cmd: PYTHONPATH=. python experiments/scripts/tag_dataset.py --dataset_name 20_news @@ -153,8 +153,8 @@ stages: size: 23460358 nfiles: 3 - path: experiments/scripts/tag_dataset.py - md5: ebadced7a031a31bdaf935d2b22e5e05 - size: 4632 + md5: 4282e97ce099dddd615af62c45eb655a + size: 5074 outs: - path: data/preprocessed/20_news/ md5: a3d2da9ac72423e555ae7ed051741b30.dir @@ -168,13 +168,13 @@ stages: md5: 43d68a67ecb8149bd6bf50db9767cb64.dir size: 439008808 nfiles: 6 - - path: data/preprocessed/20_news/ - md5: a3d2da9ac72423e555ae7ed051741b30.dir - size: 69405970 - nfiles: 3 + - path: data/reduced/20_news/ + md5: d751713988987e9331980363e24189ce.dir + size: 0 + nfiles: 0 - path: experiments/scripts/classify.py - md5: 6fc1a6a0a11ba6cd99a8b6625a96d9f5 - size: 1181 + md5: 8c4dc8293bc7d7f8b87b4788cea1b81e + size: 1176 outs: - path: data/classification/20_news md5: b73611443c4189af91b827c083f37e0b.dir @@ -433,12 +433,12 @@ stages: size: 501711136 nfiles: 7 - path: experiments/scripts/attack.py - md5: 87f54ee4e2a08f1259d9d8b2d01fe1b9 - size: 12061 + md5: fa754531f756242413103dd4a039ecbb + size: 10650 outs: - path: data/results/attack_xai/wiki_pl/ - md5: e24c456f63d8e13b92fcab51e0726141.dir - size: 8287334 + md5: ff52c5a1f070d3b935437f149ba0ef1f.dir + size: 387376283 nfiles: 2 attack_xai_local@wiki_pl: cmd: PYTHONPATH=. 
python experiments/scripts/attack.py --dataset_name wiki_pl @@ -684,17 +684,17 @@ stages: md5: 8f806cb1b2eb0dd097811d42e4bf9c2d.dir size: 501609312 nfiles: 7 - - path: data/preprocessed/poleval/ - md5: b0ea9f0ad1dba6d3b474c0a3cedf866e.dir - size: 2812175 - nfiles: 3 + - path: data/reduced/poleval/ + md5: 3586ee4d363638e7627efce92e55e6b0.dir + size: 755230 + nfiles: 1 - path: experiments/scripts/classify.py - md5: 6fc1a6a0a11ba6cd99a8b6625a96d9f5 - size: 1181 + md5: 8c4dc8293bc7d7f8b87b4788cea1b81e + size: 1176 outs: - path: data/classification/poleval - md5: f207458f9365a74672c31b5ffb2a83af.dir - size: 787456 + md5: ad7f82ab04c69cd9e1d18b17e6d94d82.dir + size: 775288 nfiles: 2 attack_xai_local@poleval: cmd: PYTHONPATH=. python experiments/scripts/attack.py --dataset_name poleval @@ -709,12 +709,12 @@ stages: size: 501609312 nfiles: 7 - path: experiments/scripts/attack.py - md5: 9518ec9af275d6a12fede47dff6767e1 - size: 11530 + md5: c464fe658004e1d0b2f45bf0dbdbfb42 + size: 11947 outs: - path: data/results/attack_xai_local/poleval/ - md5: 7597e90d1ddfa82615e79f6821d90e1b.dir - size: 188754 + md5: 3f47355fe91d8df7cb5b598da22bccdc.dir + size: 275308 nfiles: 2 attack_xai_discard_local@poleval: cmd: PYTHONPATH=. python experiments/scripts/attack.py --dataset_name poleval @@ -749,12 +749,12 @@ stages: size: 501609312 nfiles: 7 - path: experiments/scripts/attack.py - md5: 9518ec9af275d6a12fede47dff6767e1 - size: 11530 + md5: 5c37737865b0e3524be76396330e683f + size: 9556 outs: - path: data/results/attack_xai/poleval/ - md5: d368af0f7069a5f43b9cf6f3a0422522.dir - size: 189001 + md5: 1ed13c64a5ae2ed24598be64b36ad26e.dir + size: 84472229 nfiles: 2 attack_xai_discard@poleval: cmd: PYTHONPATH=. python experiments/scripts/attack.py --dataset_name poleval @@ -769,12 +769,12 @@ stages: size: 501609312 nfiles: 7 - path: experiments/scripts/attack.py - md5: 9518ec9af275d6a12fede47dff6767e1 - size: 11530 + md5: e46bca05ac076c87522c7318257026ba + size: 10247 outs: - path: data/results/attack_xai_discard/poleval/ - md5: 83bbb41d4e1303329330c981cf50ece6.dir - size: 188316 + md5: 70dbf6dacdfb2be5e41e462bd9b6ad8d.dir + size: 8606443 nfiles: 2 attack_basic@poleval: cmd: PYTHONPATH=. python experiments/scripts/attack.py --dataset_name poleval @@ -789,12 +789,12 @@ stages: size: 501609312 nfiles: 7 - path: experiments/scripts/attack.py - md5: 9518ec9af275d6a12fede47dff6767e1 - size: 11530 + md5: d2a15a7f4c3d065c67db7caf4aaa0dae + size: 9556 outs: - path: data/results/attack_basic/poleval/ - md5: 2ba20316e75c6401e764a42c8c9ba02d.dir - size: 220962 + md5: 70849358008ba01eebce555a4a1e1482.dir + size: 238609 nfiles: 2 attack_textfooler_discard@poleval: cmd: PYTHONPATH=. python experiments/scripts/attack.py --dataset_name poleval @@ -809,12 +809,12 @@ stages: size: 501609312 nfiles: 7 - path: experiments/scripts/attack.py - md5: 9518ec9af275d6a12fede47dff6767e1 - size: 11530 + md5: d2a15a7f4c3d065c67db7caf4aaa0dae + size: 9556 outs: - path: data/results/attack_textfooler_discard/poleval/ - md5: 71e521e256665812795d75e545ee4e9a.dir - size: 205890 + md5: 884f622d1b9a8e531583ff6dcabe3f95.dir + size: 1536088 nfiles: 2 attack_textfooler@poleval: cmd: 'PYTHONPATH=. 
python experiments/scripts/attack.py --dataset_name poleval @@ -829,10 +829,185 @@ stages: size: 501609312 nfiles: 7 - path: experiments/scripts/attack.py - md5: 9518ec9af275d6a12fede47dff6767e1 - size: 11530 + md5: 5c37737865b0e3524be76396330e683f + size: 9556 outs: - path: data/results/attack_textfooler/poleval/ - md5: 1cc24839e9c653182cb312f5c66a6a88.dir - size: 275318 + md5: 37e5f7b554c1959fbf1fabb84bf32ed8.dir + size: 89141787 + nfiles: 2 + attack_xai_char_discard@poleval: + cmd: PYTHONPATH=. python experiments/scripts/attack.py --dataset_name poleval + --attack_type attack_xai_char_discard + deps: + - path: data/classification/poleval + md5: f207458f9365a74672c31b5ffb2a83af.dir + size: 787456 nfiles: 2 + - path: data/models/poleval + md5: 8f806cb1b2eb0dd097811d42e4bf9c2d.dir + size: 501609312 + nfiles: 7 + - path: experiments/scripts/attack.py + md5: aa42fd50ddee64a0002e210270376b88 + size: 10247 + outs: + - path: data/results/attack_xai_char_discard/poleval/ + md5: 134ee8022b841597f6a14796bdbbcf30.dir + size: 142837300 + nfiles: 2 + attack_xai_char_discard@wiki_pl: + cmd: PYTHONPATH=. python experiments/scripts/attack.py --dataset_name wiki_pl + --attack_type attack_xai_char_discard + deps: + - path: data/classification/wiki_pl + md5: 88c3cea96b2cb3ddda1a82037bf6130a.dir + size: 44196727 + nfiles: 2 + - path: data/models/wiki_pl + md5: fd453042628fb09c080ef05d34a32cce.dir + size: 501711136 + nfiles: 7 + - path: experiments/scripts/attack.py + md5: 6a16ddc830a8ba50d01412600a19a4ea + size: 11037 + outs: + - path: data/results/attack_xai_char_discard/wiki_pl/ + md5: db1b512415d278115f76a74112f31c53.dir + size: 57649801 + nfiles: 2 + reduce@wiki_pl: + cmd: PYTHONPATH=. python experiments/scripts/token_reduce.py --dataset_name wiki_pl + --output_dir data/reduced/wiki_pl + deps: + - path: data/models/wiki_pl/ + md5: fd453042628fb09c080ef05d34a32cce.dir + size: 501711136 + nfiles: 7 + - path: data/preprocessed/wiki_pl/ + md5: 066634606f832b6c9d1db95293de7e04.dir + size: 77818549 + nfiles: 3 + - path: experiments/scripts/token_reduce.py + md5: 091c622fce2349978e61b8f6e4d22f27 + size: 4945 + outs: + - path: data/reduced/wiki_pl + md5: 30359a1d253a3c1cee7affa7ae365ef3.dir + size: 31644651 + nfiles: 1 + reduce@enron_spam: + cmd: PYTHONPATH=. python experiments/scripts/token_reduce.py --dataset_name enron_spam + --output_dir data/reduced/enron_spam + deps: + - path: data/models/enron_spam/ + md5: 3e16b22f59532c66beeadea958e0579a.dir + size: 18505614 + nfiles: 6 + - path: data/preprocessed/enron_spam/ + md5: 99d604f84516cee94948054a97ffec5e.dir + size: 71403809 + nfiles: 3 + - path: experiments/scripts/token_reduce.py + md5: 091c622fce2349978e61b8f6e4d22f27 + size: 4945 + outs: + - path: data/reduced/enron_spam + md5: ee6f2c141cd68b86e620f022f0ca0b5a.dir + size: 12933383 + nfiles: 1 + reduce@poleval: + cmd: PYTHONPATH=. python experiments/scripts/token_reduce.py --dataset_name poleval + --output_dir data/reduced/poleval + deps: + - path: data/models/poleval/ + md5: 8f806cb1b2eb0dd097811d42e4bf9c2d.dir + size: 501609312 + nfiles: 7 + - path: data/preprocessed/poleval/ + md5: b0ea9f0ad1dba6d3b474c0a3cedf866e.dir + size: 2812175 + nfiles: 3 + - path: experiments/scripts/token_reduce.py + md5: 091c622fce2349978e61b8f6e4d22f27 + size: 4945 + outs: + - path: data/reduced/poleval + md5: 3586ee4d363638e7627efce92e55e6b0.dir + size: 755230 + nfiles: 1 + preprocess_dataset@ag_news: + cmd: PYTHONPATH=. 
python experiments/scripts/tag_dataset.py --dataset_name ag_news + deps: + - path: data/datasets/ag_news/ + md5: 98d7192236c2f868e52f772e9a3aa01b.dir + size: 34020031 + nfiles: 3 + - path: experiments/scripts/tag_dataset.py + md5: 19015179757440a5639ee263dcabfde3 + size: 5010 + outs: + - path: data/preprocessed/ag_news/ + md5: 225096ec90dfb5085b6a8e3835681155.dir + size: 49998839 + nfiles: 3 + reduce@ag_news: + cmd: PYTHONPATH=. python experiments/scripts/token_reduce.py --dataset_name ag_news + --output_dir data/reduced/ag_news + deps: + - path: data/models/ag_news/ + md5: af9183a3115fd5f07ec2ba8e9086200b.dir + size: 438958577 + nfiles: 6 + - path: data/preprocessed/ag_news/ + md5: 225096ec90dfb5085b6a8e3835681155.dir + size: 49998839 + nfiles: 3 + - path: experiments/scripts/token_reduce.py + md5: 091c622fce2349978e61b8f6e4d22f27 + size: 4945 + outs: + - path: data/reduced/ag_news + md5: adac9c05622da5a0895de488a2dc80fa.dir + size: 11715622 + nfiles: 1 + classify@ag_news: + cmd: PYTHONPATH=. python experiments/scripts/classify.py --dataset_name ag_news + --output_dir data/classification/ag_news + deps: + - path: data/models/ag_news/ + md5: af9183a3115fd5f07ec2ba8e9086200b.dir + size: 438958577 + nfiles: 6 + - path: data/reduced/ag_news/ + md5: 5ab65584e3b005a19cb527fd69dfaf2c.dir + size: 11648332 + nfiles: 1 + - path: experiments/scripts/classify.py + md5: 8c4dc8293bc7d7f8b87b4788cea1b81e + size: 1176 + outs: + - path: data/classification/ag_news + md5: d9cabd3c3554172bf43b296079eec947.dir + size: 11806224 + nfiles: 2 + explain@ag_news: + cmd: PYTHONPATH=. python experiments/scripts/explain.py --dataset_name ag_news + --output_dir data/explanations/ag_news + deps: + - path: data/models/ag_news + md5: af9183a3115fd5f07ec2ba8e9086200b.dir + size: 438958577 + nfiles: 6 + - path: data/preprocessed/ag_news + md5: 225096ec90dfb5085b6a8e3835681155.dir + size: 49998839 + nfiles: 3 + - path: experiments/scripts/explain.py + md5: 2c1eca09f7cbdc5d93278b2cd27b126c + size: 4717 + outs: + - path: data/explanations/ag_news/ + md5: ae37f6468f0b44e40cf421e2544d2646.dir + size: 54861964 + nfiles: 30405 diff --git a/dvc.yaml b/dvc.yaml index 7baf31cbd02848793317d029589cb56bbbe4dc2b..81902f26c1e2cfe79908497c87957191acf60228 100644 --- a/dvc.yaml +++ b/dvc.yaml @@ -17,7 +17,7 @@ stages: foreach: - enron_spam - poleval - - 20_news + - ag_news - wiki_pl do: wdir: . @@ -44,11 +44,29 @@ stages: - data/preprocessed/${item} outs: - data/models/${item}/ + reduce: + foreach: + - enron_spam + - poleval + - ag_news + - wiki_pl + do: + wdir: . + cmd: >- + PYTHONPATH=. python experiments/scripts/token_reduce.py + --dataset_name ${item} + --output_dir data/reduced/${item} + deps: + - experiments/scripts/token_reduce.py + - data/models/${item}/ + - data/preprocessed/${item}/ + outs: + - data/reduced/${item} classify: foreach: - enron_spam - poleval - - 20_news + - ag_news - wiki_pl do: wdir: . @@ -59,14 +77,14 @@ stages: deps: - experiments/scripts/classify.py - data/models/${item}/ - - data/preprocessed/${item}/ + - data/reduced/${item}/ outs: - data/classification/${item} explain: foreach: - enron_spam - poleval - - 20_news + - ag_news - wiki_pl do: wdir: . @@ -199,4 +217,37 @@ stages: - data/classification/${item} outs: - data/results/attack_basic/${item}/ - + attack_xai_char_discard: + foreach: + - enron_spam + - poleval + - 20_news + - wiki_pl + do: + wdir: . + cmd: >- + PYTHONPATH=. 
python experiments/scripts/attack.py + --dataset_name ${item} --attack_type attack_xai_char_discard + deps: + - experiments/scripts/attack.py + - data/models/${item} + - data/classification/${item} + outs: + - data/results/attack_xai_char_discard/${item}/ + attack_xai_char_discard_local: + foreach: + - enron_spam + - poleval + - 20_news + - wiki_pl + do: + wdir: . + cmd: >- + PYTHONPATH=. python experiments/scripts/attack.py + --dataset_name ${item} --attack_type attack_xai_char_discard_local + deps: + - experiments/scripts/attack.py + - data/models/${item} + - data/classification/${item} + outs: + - data/results/attack_xai_char_discard_local/${item}/ diff --git a/experiments/scripts/attack.py b/experiments/scripts/attack.py index b660adf9a5954d27e1e51dfd020bb2529171cfaa..a1450fd8bd444a09e89eefc3f6881cec1bca5ce0 100644 --- a/experiments/scripts/attack.py +++ b/experiments/scripts/attack.py @@ -15,7 +15,8 @@ from multiprocessing import Process from multiprocessing import Queue, Manager from threading import Thread from sklearn.metrics import classification_report, confusion_matrix -import numpy as np +from string import punctuation + TEXT = "text" LEMMAS = "lemmas" @@ -36,13 +37,14 @@ EXPECTED = "expected" ACTUAL = "actual" COSINE_SCORE = "cosine_score" CLASS = "class" -QUEUE_SIZE = 1000 +QUEUE_SIZE = 60 FEATURES = "features" IMPORTANCE = "importance" SYNONYM = "synonym" DISCARD = "discard" GLOBAL = "global" LOCAL = "local" +CHAR_DISCARD = "char_discard" os.environ["TOKENIZERS_PARALLELISM"] = "false" @@ -55,6 +57,11 @@ DEFAULT_RES = { } +def join_punct(words): + punc = set(punctuation) + return "".join(w if set(w) <= punc else " " + w for w in words).lstrip() + + def data_producer(queue_out, dataset_df): for i, cols in tqdm( dataset_df[[TEXT, ID, LEMMAS, TAGS, ORTHS, PRED]].iterrows(), total=len(dataset_df) @@ -68,20 +75,31 @@ def data_saver(queue_in, queue_log, output_file, output_dir, cases_nbr, queues_k processed_nbr, start = 0, time() item = 1 test_y, pred_y = [], [] - spoiled_sents = [] - ch_suc, ch_all = 0, 0 + ch_suc, ch_all, synonyms_nbr = 0, 0, 0 + samples, samples_succ = 0, 0 + count_tokens, sum_tokens = 0, 0 + end_time = time() while item is not None: item = queue_in.get() if item is not None: processed_nbr += 1 - spoiled, class_test, class_pred = item + spoiled, class_test, class_pred, synonym_nbr = process(*item) test_y.append(class_test) pred_y.append(class_pred) queue_log.put(f"Processed and saved {processed_nbr} in {time() - start} s") + samples_succ = samples_succ + 1 if spoiled[ATTACK_SUMMARY][SUCCEEDED] > 0 else samples_succ + samples += 1 + for success in spoiled[ATTACK_SUCCEEDED]: + if CHANGED_WORDS in success: + count_tokens += len(success[CHANGED_WORDS]) + sum_tokens += 1 ch_suc += spoiled[ATTACK_SUMMARY][SUCCEEDED] ch_all += spoiled[ATTACK_SUMMARY][ALL] - spoiled_sents.append(spoiled) + synonyms_nbr += synonym_nbr + with open(output_file, 'at') as fd: + fd.write(pd.DataFrame([spoiled]).to_json(orient="records", lines=True)) + spoiled = None if processed_nbr == cases_nbr: for que_kill in queues_kill: [que_kill.put(None) for _ in range(to_kill_nbr)] @@ -91,15 +109,27 @@ def data_saver(queue_in, queue_log, output_file, output_dir, cases_nbr, queues_k if sum([q.qsize() for q in queues_kill]) == 0 and (time() - end_time) > 3600: for que_kill in queues_kill: [que_kill.put(None) for _ in range(to_kill_nbr)] - with open(output_file, 'wt') as fd: - fd.write(pd.DataFrame(spoiled_sents).to_json( - orient="records", lines=True)) - np.savetxt(f"{output_dir}/metrics.txt", 
confusion_matrix(test_y, pred_y)) - with open(f"{output_dir}/metrics.txt", mode="at") as fd: - fd.write('\n') - fd.write(classification_report(test_y, pred_y)) - fd.write('\n') - fd.write(f"succeeded {ch_suc} all {ch_all}") + + metrics = { + "confusion_matrix": confusion_matrix(test_y, pred_y).tolist(), + "classification_report": classification_report(test_y, pred_y, output_dict=True), + "attacks_succeeded": ch_suc, + "attacks_all": ch_all, + "synonyms_nbr": synonyms_nbr, + "success_rate": ch_suc / ch_all, + "success_rate_per_synonym": ch_suc / synonyms_nbr, + "time": time() - start, + "samples": samples, + "samples_succ": samples_succ, + "count_tokens": count_tokens, + "sum_tokens": sum_tokens, + "%F": (samples - samples_succ) / samples if samples > 0 else 0, + "%C": count_tokens / sum_tokens if sum_tokens > 0 else 0, + "BLEU": 0, + "P": 0 + } + with open(f"{output_dir}/metrics.json", mode="w") as fd: + json.dump(metrics, fd) def classify_queue(queue_in, queue_out, queue_log, dataset_name, cuda_device): @@ -115,11 +145,17 @@ def classify_queue(queue_in, queue_out, queue_log, dataset_name, cuda_device): item = queue_in.get() queue_log.put("Classify got from queue") if item is not None: - sent_id, org_sentence, y_pred, changed_sents = item - sentences = [sent[TEXT] for sent in changed_sents] - queue_log.put(f"Classifying sentences {len(sentences)}, id {sent_id}") + sent_id, org_sentence, y_pred, changed, synonyms_nbr, sent_words = item + sentences = [] + for subst, _ in changed: + sent_words_copy = [*sent_words] + for idx, word_change in subst.items(): + sent_words_copy[idx] = word_change['word'] + sentences.append(join_punct(sent_words_copy)) + + queue_log.put(f"Classifying sentences {synonyms_nbr}, id {sent_id}") classified = classify_fun(sentences) if sentences else [] - queue_out.put((sent_id, org_sentence, changed_sents, y_pred, classified)) + queue_out.put((sent_id, org_sentence, changed, y_pred, classified, synonyms_nbr, sent_words)) queue_log.put(f"Classified sentences {sent_id}") @@ -134,7 +170,9 @@ def log_info_queue(queue): print("Logging queue") while True: item = queue.get() - print(item) + if item is not None: + print(item) + print("Logging queue finished") def load_dir_files(dir_path): @@ -181,12 +219,17 @@ def main(dataset_name: str, attack_type: str): "poleval": "pl", "20_news": "en", "wiki_pl": "pl", + "ag_news": "en", }[dataset_name] xai_global, xai_local = {}, {} if "attack_xai" in attack_type: importance = load_xai_importance(f"data/explanations/{dataset_name}") xai_global, xai_local = importance[0], importance[1] - xai_sub = 5 + xai_sub = 0.15 + max_sub = 3 + char_delete_size = 0.4 + similarity_bound = 0.3 + params = { "attack_textfooler": [lang, SYNONYM], "attack_textfooler_discard": [lang, DISCARD], @@ -194,68 +237,55 @@ def main(dataset_name: str, attack_type: str): "attack_xai": [lang, xai_global, xai_local, GLOBAL, SYNONYM, xai_sub], "attack_xai_discard": [lang, xai_global, xai_local, GLOBAL, DISCARD, xai_sub], "attack_xai_local": [lang, xai_global, xai_local, LOCAL, SYNONYM, xai_sub], - "attack_xai_discard_local": [lang, xai_global, xai_local, LOCAL, DISCARD, xai_sub] + "attack_xai_discard_local": [lang, xai_global, xai_local, LOCAL, DISCARD, xai_sub], + "attack_xai_char_discard": [lang, xai_global, xai_local, GLOBAL, CHAR_DISCARD, xai_sub, char_delete_size], + "attack_xai_char_discard_local": [lang, xai_global, xai_local, LOCAL, CHAR_DISCARD, xai_sub, char_delete_size] }[attack_type] - output_dir = f"data/results/{attack_type}/{dataset_name}/" input_file = 
f"data/classification/{dataset_name}/test.jsonl" + os.makedirs(output_dir, exist_ok=True) output_path = os.path.join(output_dir, "test.jsonl") dataset_df = pd.read_json(input_file, lines=True) - max_sub = 1 + + test_sent_ids = ["Komputery_199721.txt", "Zydzi_976178.txt", "Kotowate_2015873.txt", "Zydzi_1602490.txt", + "Pilka-nozna_2899267.txt", "Optyka_1926807.txt", "Zydzi_929483.txt", + "Niemieccy-wojskowi_2410107.txt"] + + # dataset_df = dataset_df[dataset_df['id'].isin(test_sent_ids)] + # dataset_df = dataset_df.reset_index(drop=True) + + dataset_df = dataset_df[:20] m = Manager() - queues = [m.Queue(maxsize=QUEUE_SIZE) for _ in range(6)] - sim = Similarity(queues[5], 0.95, "distiluse-base-multilingual-cased-v1") - processes = [ - Process(target=data_producer, args=(queues[0], dataset_df,)), # loading data file_in -> 0 - Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, params)), - Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, params)), - Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, params)), - Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, params)), - Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, params)), - Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, params)), - Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, params)), - Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, params)), - Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, params)), - Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, params)), - Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, params)), - Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, params)), - Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, params)), - Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, params)), - Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, params)), - Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, params)), - Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, params)), - Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, params)), - Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, params)), - Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, params)), - Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, params)), - Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, params)), - Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, params)), - Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, params)), - Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, params)), - Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, params)), - Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, 
params)),
-        Process(target=spoil_queue, args=(queues[0], queues[1], queues[5], max_sub, attack_type, params)),
-        # spoiling 0 -> 1
-        Process(target=filter_similarity_queue, args=(queues[1], queues[2], queues[5], sim)),
-        Process(target=filter_similarity_queue, args=(queues[1], queues[2], queues[5], sim)),  # cosim 1 -> 2
-        Process(target=classify_queue, args=(queues[2], queues[3], queues[5], dataset_name, "6")),
-        Process(target=classify_queue, args=(queues[2], queues[3], queues[5], dataset_name, "4")),
-        # classify changed 2 -> 3
-        Process(target=run_queue, args=(queues[3], queues[4], queues[5], process,)),  # process 3 -> 4
-        Process(target=data_saver, args=(queues[4], queues[5], output_path, output_dir, len(dataset_df), queues, 30))
-        # saving 4 -> file_out
-    ]
-    [p.start() for p in processes]
+    queues = [m.Queue(maxsize=QUEUE_SIZE) for _ in range(5)]
 
-    log_que = Thread(target=log_queues, args=(queues[:5],))
+    log_que = Thread(target=log_queues, args=(queues[:4],))
     log_que.daemon = True
     log_que.start()
-    info_que = Thread(target=log_info_queue, args=(queues[5],))
+    info_que = Thread(target=log_info_queue, args=(queues[4],))
     info_que.daemon = True
     info_que.start()
+
+    processes_nbr = 15
+    sim = Similarity(queues[4], similarity_bound, "distiluse-base-multilingual-cased-v1")
+    processes = [Process(target=data_producer, args=(queues[0], dataset_df,))]  # loading data file_in -> 0
+
+    processes.extend([Process(target=spoil_queue, args=(queues[0], queues[1], queues[4], max_sub, attack_type, params))
+                      for _ in range(processes_nbr)])  # spoiling 0 -> 1
+
+    processes.extend([Process(target=filter_similarity_queue, args=(queues[1], queues[2], queues[4], sim)),
+                      Process(target=filter_similarity_queue, args=(queues[1], queues[2], queues[4], sim)),  # cosim 1 -> 2
+                      Process(target=classify_queue, args=(queues[2], queues[3], queues[4], dataset_name, "3")),
+                      Process(target=classify_queue, args=(queues[2], queues[3], queues[4], dataset_name, "3")),
+                      # classify changed 2 -> 3
+                      # Process(target=run_queue, args=(queues[3], queues[4], queues[5], process,)),  # process 3 -> 4
+                      Process(target=data_saver, args=(queues[3], queues[4], output_path, output_dir, len(dataset_df), queues, processes_nbr+6))  # saving 3 -> file_out
+                      ])
+
+    [p.start() for p in processes]
+    # wait for all processes to finish
     [p.join() for p in processes]
     log_que.join(timeout=0.5)
diff --git a/experiments/scripts/classify.py b/experiments/scripts/classify.py
index ab34bd70e815f74effd4044efa041f3ebb5249d6..055e0b1ea45cf7079d48f4b27c6da2db7f3dd6cb 100644
--- a/experiments/scripts/classify.py
+++ b/experiments/scripts/classify.py
@@ -30,7 +30,7 @@ def main(
         dataset_name=dataset_name, device="cuda" if torch.cuda.is_available() else "cpu"
     )
 
-    test = pd.read_json(f"data/preprocessed/{dataset_name}/test.jsonl", lines=True)
+    test = pd.read_json(f"data/reduced/{dataset_name}/test.jsonl", lines=True)
     test_x = test["text"].tolist()
     test_y = test["label"]
     pred_y = classify(test_x)
diff --git a/experiments/scripts/explain.py b/experiments/scripts/explain.py
index 46120cd8003f86615a8894dd8258e281f4f92540..e9d058164ee4173fcee99bf830f4dec8f9eac37b 100644
--- a/experiments/scripts/explain.py
+++ b/experiments/scripts/explain.py
@@ -103,12 +103,14 @@ def main(
     os.makedirs(output_dir / "global", exist_ok=True)
     for class_id, class_name in model.config.id2label.items():
         importance_df = get_importance(shap_values[:, :, class_id].mean(0))
+        class_name = class_name.replace("/", "_")
         importance_df.to_json(
             output_dir / "global" / f"{class_name}__importance.json",
         )
 
     # LOCAL IMPORTANCE
     for class_id, class_name in model.config.id2label.items():
+        class_name = class_name.replace("/", "_")
         sub_dir = output_dir / "local" / "adversarial" /class_name
         os.makedirs(sub_dir, exist_ok=True)
         for shap_id, text_id in enumerate(test["id"]):
diff --git a/experiments/scripts/tag_dataset.py b/experiments/scripts/tag_dataset.py
index 2ecc51197e33e9c87a1f2d4053e3b1b748447b10..f283bed960613ba395209d5c5ea76dc187f710dd 100644
--- a/experiments/scripts/tag_dataset.py
+++ b/experiments/scripts/tag_dataset.py
@@ -1,13 +1,13 @@
 """Script for running tagger on datasets."""
 import click
 import pandas as pd
-from lpmn_client_biz import Connection, IOType, Task, download
+from lpmn_client_biz import Connection, IOType, Task, download, upload
 import json
 import os
 from tqdm import tqdm
-from multiprocessing import cpu_count, Pool
 import spacy
-
+import shutil
+import uuid
 
 TOKENS = "tokens"
 ORTH = "orth"
@@ -18,63 +18,74 @@ TEXT = "text"
 LEMMAS = "lemmas"
 TAGS = "tags"
 ORTHS = "orths"
+NER = "ner"
 
 
-def tag_sentence(sentence: str, lang: str):
+def tag_sentences(sentences, lang: str):
+    results = {}
     connection = Connection(config_file="experiments/configs/config.yml")
-    lpmn = [{"spacy": {"lang": "en"}}]
-    if lang == "pl":
-        lpmn = [
-            "morphodita",
-            {"posconverter": {"input_format": "ccl", "output_format": "json"}},
-        ]
-
-    task = Task(lpmn, connection=connection)
-    output_file_id = task.run(str(sentence), IOType.TEXT)
-    tokens = []
-    try:
-        clarin_json = json.loads(
-            download(connection, output_file_id, IOType.TEXT).decode("utf-8")
-        )
-        tokens = clarin_json[TOKENS]
-    except json.decoder.JSONDecodeError:
-        downloaded = download(connection, output_file_id, IOType.FILE)
-        with open(downloaded, "r") as file:
+    lpmn = [[{"postagger": {"lang": lang}}], 'makezip']
+    input_dir = str(uuid.uuid4())
+    os.makedirs(input_dir)
+    for idx, sentence in enumerate(sentences):
+        with open(f'{input_dir}/file_{idx}',
+                  'w', encoding='utf8') as fout:
+            fout.write(sentence)
+
+    uploaded = upload(connection, input_dir)
+    task = Task(lpmn, connection)
+    result = task.run(uploaded, IOType.FILE, verbose=True)
+    archive_path = download(
+        connection,
+        result,
+        IOType.FILE,
+        filename=f'{uuid.uuid4()}.zip'
+    )
+    output_path = archive_path.replace('.zip', '')
+    shutil.unpack_archive(archive_path, output_path)
+    files = sorted(os.listdir(output_path), key=lambda x: int(x.split('_')[1]))
+    for j, filename in enumerate(files):
+        with open(f'{output_path}/{filename}', 'r') as file:
             lines = [json.loads(line) for line in file.readlines()]
-            for line in lines:
-                tokens.extend(line[TOKENS])
-        os.remove(downloaded)
-    lemmas, tags, orths = [], [], []
-    for token in tokens:
-        lexeme = token[LEXEMES][0]
-        lemmas.append(lexeme[LEMMA])
-        tags.append(lexeme[MSTAG])
-        orths.append(token[ORTH])
-    return lemmas, tags, orths
-
-
-def process_file(dataset_df, lang, output_path):
+            lemmas, tags, orths = [], [], []
+            if len(lines) > 0:
+                for idx, line in enumerate(lines):
+                    tokens = line[TOKENS]
+                    for token in tokens:
+                        lexeme = token[LEXEMES][0]
+                        lemmas.append(lexeme[LEMMA])
+                        tags.append(lexeme[MSTAG])
+                        orths.append(token[ORTH])
+            else:
+                tokens = lines[0][TOKENS]
+                for token in tokens:
+                    lexeme = token[LEXEMES][0]
+                    lemmas.append(lexeme[LEMMA])
+                    tags.append(lexeme[MSTAG])
+                    orths.append(token[ORTH])
+            results[int(filename.split('_')[1])] = {
+                LEMMAS: lemmas,
+                TAGS: tags,
+                ORTHS: orths
+            }
+    shutil.rmtree(input_dir)
+    os.remove(archive_path)
+    shutil.rmtree(output_path)
+    return results
+
+
+def process_file(dataset_df, lang):
     test_with_tags = pd.DataFrame(dataset_df)
     lemmas_col, tags_col, orth_col = [], [], []
-    cpus = 2
-    with Pool(processes=cpus) as pool:
-        results = []
-        for idx in tqdm(range(0, len(dataset_df), cpus)):
-            end = min(idx + cpus, len(dataset_df) + 1)
-            for sentence in dataset_df[TEXT][idx:end]:
-                results.append(
-                    pool.apply_async(tag_sentence, args=[sentence, lang])
-                )
-            for res in results:
-                lemmas, tags, orths = res.get()
-                lemmas_col.append(lemmas)
-                tags_col.append(tags)
-                orth_col.append(orths)
-            results = []
+
+    tagged_sentences = tag_sentences(dataset_df[TEXT].tolist(), lang)
+    for idx, tokens in tagged_sentences.items():
+        lemmas_col.append(tokens[LEMMAS])
+        tags_col.append(tokens[TAGS])
+        orth_col.append(tokens[ORTHS])
     test_with_tags[LEMMAS] = lemmas_col
     test_with_tags[TAGS] = tags_col
     test_with_tags[ORTHS] = orth_col
-
     return test_with_tags
 
 
@@ -105,21 +116,16 @@ def add_ner(dataset_df, language):
     help="Dataset name",
     type=str,
 )
-@click.option(
-    "--output",
-    help="Output directory",
-    type=str,
-
-)
-def main(dataset_name: str, output: str):
+def main(dataset_name: str):
     """Downloads the dataset to the output directory."""
     lang = {
         "enron_spam": "en",
         "poleval": "pl",
         "20_news": "en",
         "wiki_pl": "pl",
+        "ag_news": "en",
     }[dataset_name]
-    output_dir = f"{output}/{dataset_name}"
+    output_dir = f"data/preprocessed/{dataset_name}/"
     os.makedirs(output_dir, exist_ok=True)
 
     input_dir = f"data/datasets/{dataset_name}"
@@ -128,8 +134,7 @@ def main(dataset_name: str, output: str):
         if file in ["test.jsonl", "adversarial.jsonl"]:
             test_with_tags = process_file(
                 pd.read_json(os.path.join(input_dir, file), lines=True),
-                lang,
-                os.path.join(output_dir, file),
+                lang
             )
             test_with_tags = add_ner(test_with_tags, lang)
         else:
diff --git a/experiments/scripts/token_reduce.py b/experiments/scripts/token_reduce.py
new file mode 100644
index 0000000000000000000000000000000000000000..1aeb000d42c2aa28edcdd5faa367c9ba23d9235c
--- /dev/null
+++ b/experiments/scripts/token_reduce.py
@@ -0,0 +1,154 @@
+"""Reduce sample size to 512 tokens"""
+
+from pathlib import Path
+import click
+import pandas as pd
+import spacy
+import uuid
+import shutil
+from tqdm import tqdm
+import os
+import json
+from text_attacks.utils import get_model_and_tokenizer
+from lpmn_client_biz import Connection, IOType, Task, download, upload
+
+TOKENS = "tokens"
+ORTH = "orth"
+LEXEMES = "lexemes"
+LEMMA = "lemma"
+MSTAG = "mstag"
+TEXT = "text"
+LEMMAS = "lemmas"
+TAGS = "tags"
+ORTHS = "orths"
+NER = "ner"
+
+
+def tag_sentences(sentences, lang: str):
+    results = {}
+    connection = Connection(config_file="experiments/configs/config.yml")
+    lpmn = [[{"postagger": {"lang": lang}}], 'makezip']
+    input_dir = str(uuid.uuid4())
+    os.makedirs(input_dir)
+    for idx, sentence in sentences.items():
+        with open(f'{input_dir}/file_{idx}',
+                  'w', encoding='utf8') as fout:
+            fout.write(sentence)
+
+    uploaded = upload(connection, input_dir)
+    task = Task(lpmn, connection)
+    result = task.run(uploaded, IOType.FILE, verbose=True)
+    archive_path = download(
+        connection,
+        result,
+        IOType.FILE,
+        filename=f'{uuid.uuid4()}.zip'
+    )
+    output_path = archive_path.replace('.zip', '')
+    shutil.unpack_archive(archive_path, output_path)
+    files = sorted(os.listdir(output_path), key=lambda x: int(x.split('_')[1]))
+    for j, filename in enumerate(files):
+        with open(f'{output_path}/{filename}', 'r') as file:
+            lines = [json.loads(line) for line in file.readlines()]
+            lemmas, tags, orths = [], [], []
+            if len(lines) > 0:
+                for idx, line in enumerate(lines):
+                    tokens = line[TOKENS]
+                    for token in tokens:
+                        lexeme = token[LEXEMES][0]
+                        lemmas.append(lexeme[LEMMA])
+                        tags.append(lexeme[MSTAG])
+                        orths.append(token[ORTH])
+            else:
+                tokens = lines[0][TOKENS]
+                for token in tokens:
+                    lexeme = token[LEXEMES][0]
+                    lemmas.append(lexeme[LEMMA])
+                    tags.append(lexeme[MSTAG])
+                    orths.append(token[ORTH])
+            results[int(filename.split('_')[1])] = {
+                LEMMAS: lemmas,
+                TAGS: tags,
+                ORTHS: orths
+            }
+    shutil.rmtree(input_dir)
+    os.remove(archive_path)
+    shutil.rmtree(output_path)
+    return results
+
+
+def add_ner(sentences, language):
+    model = "en_core_web_trf" if language == "en" else "pl_core_news_lg"
+    nlp = spacy.load(model)
+    ner_data = {}
+
+    for idx, text in tqdm(sentences.items()):
+        doc = nlp(text)
+        doc_ner = list()
+        for ent in doc.ents:
+            doc_ner.append({
+                "text": ent.text,
+                "start_char": ent.start_char,
+                "end_char": ent.end_char,
+                "label": ent.label_,
+            })
+        ner_data[idx] = doc_ner
+    return ner_data
+
+
+@click.command()
+@click.option(
+    "--dataset_name",
+    help="Dataset name",
+    type=str,
+)
+@click.option(
+    "--output_dir",
+    help="Path to output directory",
+    type=click.Path(path_type=Path),
+)
+def main(
+    dataset_name: str,
+    output_dir: Path,
+):
+    lang = {
+        "enron_spam": "en",
+        "poleval": "pl",
+        "20_news": "en",
+        "wiki_pl": "pl",
+        "ag_news": "en",
+    }[dataset_name]
+    output_dir.mkdir(parents=True, exist_ok=True)
+    model, tokenizer = get_model_and_tokenizer(
+        dataset_name=dataset_name
+    )
+    model.to("cpu")
+    model.eval()
+    test = pd.read_json(f"data/preprocessed/{dataset_name}/test.jsonl", lines=True)
+
+    texts = test["text"].tolist()
+    texts_reduced = {}
+    for i, sentence in test["text"].items():
+        encoded = tokenizer.encode(sentence, add_special_tokens=True, max_length=512, truncation=True)
+        decod_res = tokenizer.decode(encoded, skip_special_tokens=True, clean_up_tokenization_spaces=True)
+        last_word = decod_res.split(" ")[-1]
+        max_len = len(" ".join(sentence.split(" ")[:512]))
+        idx = sentence.rfind(last_word, 0, max_len)
+        if idx + len(last_word) < len(sentence) and idx > 0:
+            texts_reduced[i] = sentence[:idx + len(last_word)]
+    print("To reduce ", len(texts_reduced), " of ", len(texts))
+
+    if len(texts_reduced) > 0:
+        tagged_reduced = tag_sentences(texts_reduced, lang)
+        ner_reduced = add_ner(texts_reduced, lang)
+        for idx, sentence in texts_reduced.items():
+            test.loc[idx, TEXT] = sentence
+            test.at[idx, LEMMAS] = tagged_reduced[idx][LEMMAS]
+            test.at[idx, TAGS] = tagged_reduced[idx][TAGS]
+            test.at[idx, ORTHS] = tagged_reduced[idx][ORTHS]
+            test.at[idx, NER] = ner_reduced[idx]
+    test.to_json(output_dir / "test.jsonl", orient="records", lines=True)
+
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/text_attacks/models/ag_news.py b/text_attacks/models/ag_news.py
new file mode 100644
index 0000000000000000000000000000000000000000..9811bb6d1755dc6cacb741224f63ac2081d4c19d
--- /dev/null
+++ b/text_attacks/models/ag_news.py
@@ -0,0 +1,40 @@
+"""Classification model for ag_news"""
+
+import torch
+from tqdm import tqdm
+from transformers import AutoTokenizer, AutoModelForSequenceClassification
+
+
+def get_model_and_tokenizer():
+    model_path = "./data/models/ag_news"
+    tokenizer = AutoTokenizer.from_pretrained(model_path)
+    model = AutoModelForSequenceClassification.from_pretrained(model_path)
+    return model, tokenizer
+
+
+def get_classify_function(device="cpu"):
+    model, tokenizer = get_model_and_tokenizer()
+    model.eval()
+    model = model.to(device)
+
+    def fun(texts):
+        logits = list()
+        i = 0
+        for chunk in tqdm(
+            [texts[pos:pos + 128] for pos in range(0, len(texts), 128)]
+        ):
+            encoded_inputs = tokenizer(
+                chunk,
+                return_tensors="pt",
+                padding=True,
+                truncation=True,
+                max_length=512
+            ).to(device)
+            with torch.no_grad():
+                logits.append(model(**encoded_inputs).logits.cpu())
+        logits = torch.cat(logits, dim=0)
+        pred_y = torch.argmax(logits, dim=1).tolist()
+        pred_y = [model.config.id2label[p] for p in pred_y]
+        return pred_y
+
+    return fun
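
Usage sketch (illustrative, not part of the patch): the new ag_news module mirrors the existing per-dataset model modules, so its classifier can be exercised as below once a fine-tuned checkpoint sits under ./data/models/ag_news. The sample texts and the printed label names are assumptions here; the actual labels come from the checkpoint's model.config.id2label mapping.

    # Minimal sketch: load the AG News classifier added above and label a few texts.
    from text_attacks.models.ag_news import get_classify_function

    classify = get_classify_function(device="cpu")  # or "cuda" if a GPU is available
    labels = classify([
        "Stocks rallied after the quarterly earnings report.",  # hypothetical sample
        "The home side scored twice in the second half.",       # hypothetical sample
    ])
    print(labels)  # e.g. ["Business", "Sports"], depending on the model's id2label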