Newer
Older
"""Classification model for enron_spam"""
import pandas as pd
from datasets import load_dataset
from sklearn.model_selection import train_test_split
def download_dataset():
dataset = load_dataset("SetFit/enron_spam")
train = pd.DataFrame(dataset["train"].to_dict())
test = pd.DataFrame(dataset["test"].to_dict())
train["label"] = train["label_text"]
train = train.rename(columns={"message_id": "id"})
train = train.drop(columns=["label_text", "subject", "message", "date"])
test["label"] = test["label_text"]
test = test.rename(columns={"message_id": "id"})
test = test.drop(columns=["label_text", "subject", "message", "date"])
adversarial, test = train_test_split(
test,
test_size=0.9,
stratify=test["label"]
)
return train, test, adversarial