From 08e1c133ce31f2fceeb9831effceadc9469500eb Mon Sep 17 00:00:00 2001
From: MGniew <m.f.gniewkowski@gmail.com>
Date: Fri, 10 Mar 2023 14:28:08 +0100
Subject: [PATCH] Added 2 new datasets

---
 data/classification/.gitignore    |  2 ++
 data/datasets/.gitignore          |  2 ++
 data/datasets/20_news.dvc         |  5 ++++
 data/datasets/wiki_pl.dvc         |  5 ++++
 data/models/.gitignore            |  2 ++
 data/models/20_news.dvc           |  5 ++++
 data/models/wiki_pl.dvc           |  5 ++++
 experiments/scripts/classify.py   |  2 ++
 requirements.txt                  |  4 +--
 text_attacks/models/20_news.py    | 41 +++++++++++++++++++++++++++++++
 text_attacks/models/enron_spam.py | 29 ++++++++++++++--------
 text_attacks/models/wiki_pl.py    | 41 +++++++++++++++++++++++++++++++
 text_attacks/utils.py             |  4 +--
 13 files changed, 133 insertions(+), 14 deletions(-)
 create mode 100644 data/datasets/20_news.dvc
 create mode 100644 data/datasets/wiki_pl.dvc
 create mode 100644 data/models/20_news.dvc
 create mode 100644 data/models/wiki_pl.dvc
 create mode 100644 text_attacks/models/20_news.py
 create mode 100644 text_attacks/models/wiki_pl.py

diff --git a/data/classification/.gitignore b/data/classification/.gitignore
index 60ba700..e695872 100644
--- a/data/classification/.gitignore
+++ b/data/classification/.gitignore
@@ -1 +1,3 @@
 /enron_spam
+/wiki_pl
+/20_news
diff --git a/data/datasets/.gitignore b/data/datasets/.gitignore
index af871df..43bd163 100644
--- a/data/datasets/.gitignore
+++ b/data/datasets/.gitignore
@@ -1,2 +1,4 @@
 /enron_spam
+/20_news
 /poleval
+/wiki_pl
diff --git a/data/datasets/20_news.dvc b/data/datasets/20_news.dvc
new file mode 100644
index 0000000..00b5cf4
--- /dev/null
+++ b/data/datasets/20_news.dvc
@@ -0,0 +1,5 @@
+outs:
+- md5: 999207f1c2c123c9943397b47f2c3b3a.dir
+  size: 23460358
+  nfiles: 3
+  path: 20_news
diff --git a/data/datasets/wiki_pl.dvc b/data/datasets/wiki_pl.dvc
new file mode 100644
index 0000000..f0f2afe
--- /dev/null
+++ b/data/datasets/wiki_pl.dvc
@@ -0,0 +1,5 @@
+outs:
+- md5: abcbccb3e352ed623cace1b95078bd63.dir
+  size: 29115538
+  nfiles: 3
+  path: wiki_pl
diff --git a/data/models/.gitignore b/data/models/.gitignore
index 60ba700..ea22867 100644
--- a/data/models/.gitignore
+++ b/data/models/.gitignore
@@ -1 +1,3 @@
 /enron_spam
+/20_news
+/wiki_pl
diff --git a/data/models/20_news.dvc b/data/models/20_news.dvc
new file mode 100644
index 0000000..d667d57
--- /dev/null
+++ b/data/models/20_news.dvc
@@ -0,0 +1,5 @@
+outs:
+- md5: 43d68a67ecb8149bd6bf50db9767cb64.dir
+  size: 439008808
+  nfiles: 6
+  path: 20_news
diff --git a/data/models/wiki_pl.dvc b/data/models/wiki_pl.dvc
new file mode 100644
index 0000000..fdf58d5
--- /dev/null
+++ b/data/models/wiki_pl.dvc
@@ -0,0 +1,5 @@
+outs:
+- md5: fd453042628fb09c080ef05d34a32cce.dir
+  size: 501711136
+  nfiles: 7
+  path: wiki_pl
diff --git a/experiments/scripts/classify.py b/experiments/scripts/classify.py
index 9639d29..ab34bd7 100644
--- a/experiments/scripts/classify.py
+++ b/experiments/scripts/classify.py
@@ -3,6 +3,7 @@ from pathlib import Path
 
 import click
 import pandas as pd
+import torch
 from sklearn.metrics import classification_report
 
 from text_attacks.utils import get_classify_function
@@ -27,6 +28,7 @@ def main(
     output_dir.mkdir(parents=True, exist_ok=True)
     classify = get_classify_function(
         dataset_name=dataset_name,
+        device="cuda" if torch.cuda.is_available() else "cpu"
     )
     test = pd.read_json(f"data/preprocessed/{dataset_name}/test.jsonl", lines=True)
     test_x = test["text"].tolist()
diff --git a/requirements.txt b/requirements.txt
index fec55bd..66b509a 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,8 +2,8 @@ datasets
 transformers
 click
 scikit-learn
-dvc[s3]
-shap
+dvc[s3]==2.46.0
+shap==0.41.0
 lpmn_client_biz
 --find-links https://download.pytorch.org/whl/torch_stable.html
diff --git a/text_attacks/models/20_news.py b/text_attacks/models/20_news.py
new file mode 100644
index 0000000..53712fa
--- /dev/null
+++ b/text_attacks/models/20_news.py
@@ -0,0 +1,41 @@
+"""Classification model for 20_news"""
+import os
+
+import torch
+from tqdm import tqdm
+
+from transformers import AutoTokenizer, AutoModelForSequenceClassification
+
+
+def get_model_and_tokenizer():
+    model_path = "./data/models/20_news"
+    tokenizer = AutoTokenizer.from_pretrained(model_path)
+    model = AutoModelForSequenceClassification.from_pretrained(model_path)
+    return model, tokenizer
+
+
+def get_classify_function(device="cpu"):
+    model, tokenizer = get_model_and_tokenizer()
+    model.eval()
+    model = model.to(device)
+
+    def fun(texts):
+        logits = list()
+        for chunk in tqdm(
+            [texts[pos:pos + 256] for pos in range(0, len(texts), 256)]
+        ):
+            encoded_inputs = tokenizer(
+                chunk,
+                return_tensors="pt",
+                padding=True,
+                truncation=True,
+                max_length=512
+            ).to(device)
+            with torch.no_grad():
+                logits.append(model(**encoded_inputs).logits.cpu())
+        logits = torch.cat(logits, dim=0)
+        pred_y = torch.argmax(logits, dim=1).tolist()
+        pred_y = [model.config.id2label[p] for p in pred_y]
+        return pred_y
+
+    return fun
diff --git a/text_attacks/models/enron_spam.py b/text_attacks/models/enron_spam.py
index 063a52a..9a1946d 100644
--- a/text_attacks/models/enron_spam.py
+++ b/text_attacks/models/enron_spam.py
@@ -2,12 +2,13 @@
 import os
 
 import torch
+from tqdm import tqdm
 
 from transformers import AutoTokenizer, AutoModelForSequenceClassification
 
 
 def get_model_and_tokenizer():
-    model_path = "data/models/endron_spam"
+    model_path = "./data/models/enron_spam"
     if not os.path.exists(model_path):
         model_path = "mrm8488/bert-tiny-finetuned-enron-spam-detection"
     tokenizer = AutoTokenizer.from_pretrained(model_path)
@@ -16,18 +17,26 @@ def get_model_and_tokenizer():
     return model, tokenizer
 
 
-def get_classify_function():
+def get_classify_function(device="cpu"):
     model, tokenizer = get_model_and_tokenizer()
+    model.eval()
+    model = model.to(device)
 
     def fun(texts):
-        encoded_inputs = tokenizer(
-            texts,
-            return_tensors="pt",
-            padding=True,
-            truncation=True,
-            max_length=512
-        )
-        logits = model(**encoded_inputs).logits
+        logits = list()
+        for chunk in tqdm(
+            [texts[pos:pos + 256] for pos in range(0, len(texts), 256)]
+        ):
+            encoded_inputs = tokenizer(
+                chunk,
+                return_tensors="pt",
+                padding=True,
+                truncation=True,
+                max_length=512
+            ).to(device)
+            with torch.no_grad():
+                logits.append(model(**encoded_inputs).logits.cpu())
+        logits = torch.cat(logits, dim=0)
         pred_y = torch.argmax(logits, dim=1).tolist()
         pred_y = [model.config.id2label[p] for p in pred_y]
         return pred_y
diff --git a/text_attacks/models/wiki_pl.py b/text_attacks/models/wiki_pl.py
new file mode 100644
index 0000000..1ad1539
--- /dev/null
+++ b/text_attacks/models/wiki_pl.py
@@ -0,0 +1,41 @@
+"""Classification model for wiki_pl"""
+import os
+
+import torch
+from tqdm import tqdm
+
+from transformers import AutoTokenizer, AutoModelForSequenceClassification
+
+
+def get_model_and_tokenizer():
+    model_path = "./data/models/wiki_pl"
+    tokenizer = AutoTokenizer.from_pretrained(model_path)
+    model = AutoModelForSequenceClassification.from_pretrained(model_path)
+    return model, tokenizer
+
+
+def get_classify_function(device="cpu"):
+    model, tokenizer = get_model_and_tokenizer()
+    model.eval()
+    model = model.to(device)
+
+    def fun(texts):
+        logits = list()
+        for chunk in tqdm(
+            [texts[pos:pos + 256] for pos in range(0, len(texts), 256)]
+        ):
+            encoded_inputs = tokenizer(
+                chunk,
+                return_tensors="pt",
+                padding=True,
+                truncation=True,
+                max_length=512
+            ).to(device)
+            with torch.no_grad():
+                logits.append(model(**encoded_inputs).logits.cpu())
+        logits = torch.cat(logits, dim=0)
+        pred_y = torch.argmax(logits, dim=1).tolist()
+        pred_y = [model.config.id2label[p] for p in pred_y]
+        return pred_y
+
+    return fun
diff --git a/text_attacks/utils.py b/text_attacks/utils.py
index e47d520..6a05882 100644
--- a/text_attacks/utils.py
+++ b/text_attacks/utils.py
@@ -11,13 +11,13 @@ def get_model_and_tokenizer(dataset_name):
     return fun()
 
 
-def get_classify_function(dataset_name):
+def get_classify_function(dataset_name, device="cpu"):
     """Return get_model_and_tokenizer for a specific dataset."""
     fun = getattr(
         importlib.import_module(f"text_attacks.models.{dataset_name}"),
         "get_classify_function",
     )
-    return fun()
+    return fun(device=device)
 
 
 def download_dataset(dataset_name):
--
GitLab
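
For reviewers, a minimal usage sketch of the device-aware entry point this patch introduces, mirroring the call site in experiments/scripts/classify.py. This is illustrative only, not part of the commit: the sample input text and printed output are invented, and the sketch assumes the wiki_pl model directory has already been fetched, e.g. with `dvc pull data/models/wiki_pl.dvc`.

    import torch

    from text_attacks.utils import get_classify_function

    # Same device selection the patched classify.py performs.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # utils.get_classify_function imports text_attacks.models.wiki_pl and
    # forwards `device` to its factory, which moves the model to that device,
    # switches it to eval mode, and returns a classification closure.
    classify = get_classify_function(dataset_name="wiki_pl", device=device)

    # The closure batches inputs in chunks of 256 texts, tokenizes them with
    # padding/truncation to 512 tokens, runs inference under torch.no_grad(),
    # and maps argmax logits to label names via model.config.id2label.
    labels = classify(["Przykładowy artykuł o historii Polski."])
    print(labels)  # a list of label strings; the label set comes from the model config

One design note: each chunk's logits are moved back to the CPU before being accumulated, so peak GPU memory is bounded by a single 256-text batch rather than by the size of the whole test set.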