Skip to content
Snippets Groups Projects
enron_spam.py 822 B
Newer Older
MGniew's avatar
MGniew committed
"""Classification model for enron_spam"""
import pandas as pd
from datasets import load_dataset
from sklearn.model_selection import train_test_split


def download_dataset():
    dataset = load_dataset("SetFit/enron_spam")
    train = pd.DataFrame(dataset["train"].to_dict())
    test = pd.DataFrame(dataset["test"].to_dict())

    train["label"] = train["label_text"]
    train = train.rename(columns={"message_id": "id"})
    train = train.drop(columns=["label_text", "subject", "message", "date"])

    test["label"] = test["label_text"]
    test = test.rename(columns={"message_id": "id"})
    test = test.drop(columns=["label_text", "subject", "message", "date"])
    adversarial, test = train_test_split(
        test,
        test_size=0.9,
        stratify=test["label"]
    )

    return train, test, adversarial