Skip to content
Snippets Groups Projects
enron_spam.py 1021 B
Newer Older
MGniew's avatar
MGniew committed
"""Classification model for enron_spam"""
MGniew's avatar
MGniew committed
import os

import torch
MGniew's avatar
MGniew committed

from transformers import AutoTokenizer, AutoModelForSequenceClassification


def get_model_and_tokenizer(
    model_path="data/models/endron_spam",
    fallback_model="mrm8488/bert-tiny-finetuned-enron-spam-detection",
):
    """Load the enron_spam sequence-classification model and its tokenizer.

    A checkpoint at ``model_path`` is preferred; when that path does not
    exist, ``fallback_model`` (a Hugging Face Hub id or another local path)
    is loaded instead.

    Args:
        model_path: Local path of a saved checkpoint to try first.
        fallback_model: Checkpoint loaded when ``model_path`` does not exist.

    Returns:
        A ``(model, tokenizer)`` tuple. The model's ``config.id2label`` maps
        class index 0 to "ham" and 1 to "spam", and ``config.label2id`` is
        set to the matching inverse mapping.
    """
    # NOTE(review): the default directory is spelled "endron_spam" (sic). It is
    # kept unchanged so existing local checkpoints are still found; confirm the
    # on-disk name before correcting it.
    source = model_path if os.path.exists(model_path) else fallback_model
    tokenizer = AutoTokenizer.from_pretrained(source)
    model = AutoModelForSequenceClassification.from_pretrained(source)
    # Replace the checkpoint's label names and keep both directions of the
    # mapping consistent, so config.label2id agrees with config.id2label.
    id2label = {0: "ham", 1: "spam"}
    model.config.id2label = id2label
    model.config.label2id = {label: idx for idx, label in id2label.items()}
    return model, tokenizer
MGniew's avatar
MGniew committed


def get_classify_function(batch_size=None):
    """Build a callable that labels texts using the enron_spam model.

    The model and tokenizer are loaded once, when this factory is called,
    and are captured by the returned closure.

    Args:
        batch_size: Maximum number of texts sent through the model in one
            forward pass. ``None`` (the default) classifies all texts in a
            single batch.

    Returns:
        A function ``fun(texts)`` that takes a string or a sequence of strings
        and returns a list with one label (taken from
        ``model.config.id2label``) per text. A single string yields a
        one-element list, and an empty sequence yields ``[]``.

    Raises:
        ValueError: If ``batch_size`` is given and is smaller than 1.
    """
    if batch_size is not None and batch_size < 1:
        raise ValueError(f"batch_size must be >= 1 or None, got {batch_size}")

    model, tokenizer = get_model_and_tokenizer()

    def _classify_batch(batch):
        # Tokenize one batch (padded to the longest text, truncated to
        # 512 tokens) and map each argmax class index to its label.
        encoded_inputs = tokenizer(
            batch,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512,
        )
        # Inference only: no_grad avoids building an autograd graph and
        # keeping intermediate activations alive on every call.
        with torch.no_grad():
            logits = model(**encoded_inputs).logits
        pred_y = torch.argmax(logits, dim=1).tolist()
        return [model.config.id2label[p] for p in pred_y]

    def fun(texts):
        if isinstance(texts, str):
            return _classify_batch(texts)
        if len(texts) == 0:
            # Avoid running the tokenizer/model on an empty batch.
            return []
        if batch_size is None:
            return _classify_batch(texts)
        return [
            label
            for start in range(0, len(texts), batch_size)
            for label in _classify_batch(texts[start:start + batch_size])
        ]

    return fun