From 08e1c133ce31f2fceeb9831effceadc9469500eb Mon Sep 17 00:00:00 2001
From: MGniew <m.f.gniewkowski@gmail.com>
Date: Fri, 10 Mar 2023 14:28:08 +0100
Subject: [PATCH] Add 20_news and wiki_pl datasets with classification models

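Adds DVC-tracked dataset and model artifacts for two new corpora,
20_news and wiki_pl, next to the existing enron_spam setup. Each
dataset gets its own module under text_attacks/models/, resolved by
name via importlib, so the module file name must match the dataset
name. Classification now takes a device argument (classify.py selects
CUDA when available) and runs inference in batches of 256 texts under
torch.no_grad(), moving logits back to the CPU to keep GPU memory
bounded. dvc[s3] and shap are pinned to known-good versions.

Typical usage, assuming access to the configured DVC remote and the
existing click options of classify.py (the option names below are a
guess; they are not visible in this diff):

    dvc pull data/datasets/20_news.dvc data/models/20_news.dvc
    python experiments/scripts/classify.py --dataset_name 20_news \
        --output_dir data/classification/20_news
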
---
 data/classification/.gitignore    |  2 ++
 data/datasets/.gitignore          |  2 ++
 data/datasets/20_news.dvc         |  5 ++++
 data/datasets/wiki_pl.dvc         |  5 ++++
 data/models/.gitignore            |  2 ++
 data/models/20_news.dvc           |  5 ++++
 data/models/wiki_pl.dvc           |  5 ++++
 experiments/scripts/classify.py   |  2 ++
 requirements.txt                  |  4 +--
 text_attacks/models/20_news.py    | 41 +++++++++++++++++++++++++++++++
 text_attacks/models/enron_spam.py | 30 ++++++++++++++--------
 text_attacks/models/wiki_pl.py    | 41 +++++++++++++++++++++++++++++++
 text_attacks/utils.py             |  6 +++---
 13 files changed, 135 insertions(+), 15 deletions(-)
 create mode 100644 data/datasets/20_news.dvc
 create mode 100644 data/datasets/wiki_pl.dvc
 create mode 100644 data/models/20_news.dvc
 create mode 100644 data/models/wiki_pl.dvc
 create mode 100644 text_attacks/models/20_news.py
 create mode 100644 text_attacks/models/wiki_pl.py

diff --git a/data/classification/.gitignore b/data/classification/.gitignore
index 60ba700..e695872 100644
--- a/data/classification/.gitignore
+++ b/data/classification/.gitignore
@@ -1 +1,3 @@
 /enron_spam
+/wiki_pl
+/20_news
diff --git a/data/datasets/.gitignore b/data/datasets/.gitignore
index af871df..43bd163 100644
--- a/data/datasets/.gitignore
+++ b/data/datasets/.gitignore
@@ -1,2 +1,4 @@
 /enron_spam
+/20_news
 /poleval
+/wiki_pl
diff --git a/data/datasets/20_news.dvc b/data/datasets/20_news.dvc
new file mode 100644
index 0000000..00b5cf4
--- /dev/null
+++ b/data/datasets/20_news.dvc
@@ -0,0 +1,5 @@
+outs:
+- md5: 999207f1c2c123c9943397b47f2c3b3a.dir
+  size: 23460358
+  nfiles: 3
+  path: 20_news
diff --git a/data/datasets/wiki_pl.dvc b/data/datasets/wiki_pl.dvc
new file mode 100644
index 0000000..f0f2afe
--- /dev/null
+++ b/data/datasets/wiki_pl.dvc
@@ -0,0 +1,5 @@
+outs:
+- md5: abcbccb3e352ed623cace1b95078bd63.dir
+  size: 29115538
+  nfiles: 3
+  path: wiki_pl
diff --git a/data/models/.gitignore b/data/models/.gitignore
index 60ba700..ea22867 100644
--- a/data/models/.gitignore
+++ b/data/models/.gitignore
@@ -1 +1,3 @@
 /enron_spam
+/20_news
+/wiki_pl
diff --git a/data/models/20_news.dvc b/data/models/20_news.dvc
new file mode 100644
index 0000000..d667d57
--- /dev/null
+++ b/data/models/20_news.dvc
@@ -0,0 +1,5 @@
+outs:
+- md5: 43d68a67ecb8149bd6bf50db9767cb64.dir
+  size: 439008808
+  nfiles: 6
+  path: 20_news
diff --git a/data/models/wiki_pl.dvc b/data/models/wiki_pl.dvc
new file mode 100644
index 0000000..fdf58d5
--- /dev/null
+++ b/data/models/wiki_pl.dvc
@@ -0,0 +1,5 @@
+outs:
+- md5: fd453042628fb09c080ef05d34a32cce.dir
+  size: 501711136
+  nfiles: 7
+  path: wiki_pl
diff --git a/experiments/scripts/classify.py b/experiments/scripts/classify.py
index 9639d29..ab34bd7 100644
--- a/experiments/scripts/classify.py
+++ b/experiments/scripts/classify.py
@@ -3,6 +3,7 @@ from pathlib import Path
 
 import click
 import pandas as pd
+import torch
 from sklearn.metrics import classification_report
 
 from text_attacks.utils import get_classify_function
@@ -27,6 +28,7 @@ def main(
     output_dir.mkdir(parents=True, exist_ok=True)
     classify = get_classify_function(
         dataset_name=dataset_name,
+        device="cuda" if torch.cuda.is_available() else "cpu"
     )
     test = pd.read_json(f"data/preprocessed/{dataset_name}/test.jsonl", lines=True)
     test_x = test["text"].tolist()
diff --git a/requirements.txt b/requirements.txt
index fec55bd..66b509a 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,8 +2,8 @@ datasets
 transformers
 click
 scikit-learn
-dvc[s3]
-shap
+dvc[s3]==2.46.0
+shap==0.41.0
 lpmn_client_biz
 
 --find-links https://download.pytorch.org/whl/torch_stable.html
diff --git a/text_attacks/models/20_news.py b/text_attacks/models/20_news.py
new file mode 100644
index 0000000..53712fa
--- /dev/null
+++ b/text_attacks/models/20_news.py
@@ -0,0 +1,41 @@
+"""Classification model for enron_spam"""
+import os
+
+import torch
+from tqdm import tqdm
+
+from transformers import AutoTokenizer, AutoModelForSequenceClassification
+
+
+def get_model_and_tokenizer():
+    model_path = "./data/models/20_news"
+    tokenizer = AutoTokenizer.from_pretrained(model_path)
+    model = AutoModelForSequenceClassification.from_pretrained(model_path)
+    return model, tokenizer
+
+
+def get_classify_function(device="cpu"):
+    model, tokenizer = get_model_and_tokenizer()
+    model.eval()
+    model = model.to(device)
+
+    def fun(texts):
+        logits = list()
+        # run inference in fixed batches of 256 texts to bound memory use
+        for chunk in tqdm(
+            [texts[pos:pos + 256] for pos in range(0, len(texts), 256)]
+        ):
+            encoded_inputs = tokenizer(
+                chunk,
+                return_tensors="pt",
+                padding=True,
+                truncation=True,
+                max_length=512
+            ).to(device)
+            with torch.no_grad():
+                logits.append(model(**encoded_inputs).logits.cpu())
+        logits = torch.cat(logits, dim=0)
+        pred_y = torch.argmax(logits, dim=1).tolist()
+        pred_y = [model.config.id2label[p] for p in pred_y]
+        return pred_y
+
+    return fun
diff --git a/text_attacks/models/enron_spam.py b/text_attacks/models/enron_spam.py
index 063a52a..9a1946d 100644
--- a/text_attacks/models/enron_spam.py
+++ b/text_attacks/models/enron_spam.py
@@ -2,12 +2,13 @@
 import os
 
 import torch
+from tqdm import tqdm
 
 from transformers import AutoTokenizer, AutoModelForSequenceClassification
 
 
 def get_model_and_tokenizer():
-    model_path = "data/models/endron_spam"
+    model_path = "./data/models/endron_spam"
     if not os.path.exists(model_path):
         model_path = "mrm8488/bert-tiny-finetuned-enron-spam-detection"
     tokenizer = AutoTokenizer.from_pretrained(model_path)
@@ -16,18 +17,27 @@ def get_model_and_tokenizer():
     return model, tokenizer
 
 
-def get_classify_function():
+def get_classify_function(device="cpu"):
     model, tokenizer = get_model_and_tokenizer()
+    model.eval()
+    model = model.to(device)
 
     def fun(texts):
-        encoded_inputs = tokenizer(
-            texts,
-            return_tensors="pt",
-            padding=True,
-            truncation=True,
-            max_length=512
-        )
-        logits = model(**encoded_inputs).logits
+        logits = list()
+        # run inference in fixed batches of 256 texts to bound memory use
+        for chunk in tqdm(
+            [texts[pos:pos + 256] for pos in range(0, len(texts), 256)]
+        ):
+            encoded_inputs = tokenizer(
+                chunk,
+                return_tensors="pt",
+                padding=True,
+                truncation=True,
+                max_length=512
+            ).to(device)
+            with torch.no_grad():
+                logits.append(model(**encoded_inputs).logits.cpu())
+        logits = torch.cat(logits, dim=0)
         pred_y = torch.argmax(logits, dim=1).tolist()
         pred_y = [model.config.id2label[p] for p in pred_y]
         return pred_y
diff --git a/text_attacks/models/wiki_pl.py b/text_attacks/models/wiki_pl.py
new file mode 100644
index 0000000..1ad1539
--- /dev/null
+++ b/text_attacks/models/wiki_pl.py
@@ -0,0 +1,41 @@
+"""Classification model for enron_spam"""
+import os
+
+import torch
+from tqdm import tqdm
+
+from transformers import AutoTokenizer, AutoModelForSequenceClassification
+
+
+def get_model_and_tokenizer():
+    model_path = "./data/models/wiki_pl"
+    tokenizer = AutoTokenizer.from_pretrained(model_path)
+    model = AutoModelForSequenceClassification.from_pretrained(model_path)
+    return model, tokenizer
+
+
+def get_classify_function(device="cpu"):
+    model, tokenizer = get_model_and_tokenizer()
+    model.eval()
+    model = model.to(device)
+
+    def fun(texts):
+        logits = list()
+        # run inference in fixed batches of 256 texts to bound memory use
+        for chunk in tqdm(
+            [texts[pos:pos + 256] for pos in range(0, len(texts), 256)]
+        ):
+            encoded_inputs = tokenizer(
+                chunk,
+                return_tensors="pt",
+                padding=True,
+                truncation=True,
+                max_length=512
+            ).to(device)
+            with torch.no_grad():
+                logits.append(model(**encoded_inputs).logits.cpu())
+        logits = torch.cat(logits, dim=0)
+        pred_y = torch.argmax(logits, dim=1).tolist()
+        pred_y = [model.config.id2label[p] for p in pred_y]
+        return pred_y
+
+    return fun
diff --git a/text_attacks/utils.py b/text_attacks/utils.py
index e47d520..6a05882 100644
--- a/text_attacks/utils.py
+++ b/text_attacks/utils.py
@@ -11,13 +11,13 @@ def get_model_and_tokenizer(dataset_name):
     return fun()
 
 
-def get_classify_function(dataset_name):
-    """Return get_model_and_tokenizer for a specific dataset."""
+def get_classify_function(dataset_name, device="cpu"):
+    """Return the classify function for a specific dataset."""
     fun = getattr(
         importlib.import_module(f"text_attacks.models.{dataset_name}"),
         "get_classify_function",
     )
-    return fun()
+    return fun(device=device)
 
 
 def download_dataset(dataset_name):
-- 
GitLab