Files
c/backend/app/ml/naive_bayes_classifier.py
刘正航 b5237f9038 1
2026-04-21 22:45:19 +08:00

157 lines
5.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import hashlib
from collections import Counter
from datetime import datetime, timezone
from pathlib import Path

import joblib
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.naive_bayes import MultinomialNB
# Labels accepted by NaiveBayesSpamClassifier.normalize_label; any other label
# is treated as invalid and the sample is dropped during training.
VALID_LABELS = {"spam", "ham"}


class NaiveBayesSpamClassifier:
    """Spam/ham text classifier: char 1-2-gram TF-IDF + multinomial Naive Bayes.

    The fitted vectorizer, model and training metadata are persisted together
    with joblib at ``model_path`` and can be restored with :meth:`load`.
    """

    def __init__(self, model_path: str):
        """Create an unloaded classifier bound to ``model_path``.

        Nothing is read from disk here; call :meth:`load`, :meth:`train` or
        :meth:`ensure_ready` before :meth:`predict`.
        """
        self.model_path = Path(model_path)
        self.vectorizer = None  # fitted TfidfVectorizer after train()/load()
        self.model = None  # fitted MultinomialNB after train()/load()
        self.metadata = {}  # dict produced by _to_metadata (or loaded from disk)

    @staticmethod
    def normalize_label(label: str) -> str:
        """Return ``label`` stripped and lower-cased, or "" if it is not in VALID_LABELS."""
        value = (label or "").strip().lower()
        return value if value in VALID_LABELS else ""

    @staticmethod
    def normalize_text(text: str) -> str:
        """Collapse each whitespace run to a single space and trim both ends.

        ``None`` or empty input yields "".
        """
        # str.split() with no argument already discards leading/trailing whitespace.
        return " ".join((text or "").split())

    @staticmethod
    def _to_metadata(samples: list[dict], version_seed: str) -> dict:
        """Build the metadata stored next to a freshly trained model.

        Args:
            samples: Cleaned samples; each must have a ``"label"`` key.
            version_seed: String hashed into the ``version`` field, so the same
                seed always produces the same version.

        Returns:
            Dict with ``trained_at`` (naive UTC ISO-8601 string), ``sample_count``,
            ``label_distribution`` and ``version`` (``"nb-"`` + 16 hex chars).
        """
        dist = Counter(item["label"] for item in samples)
        # The digest is only a content fingerprint, not a security primitive;
        # usedforsecurity=False keeps md5 available on FIPS-restricted builds.
        digest = hashlib.md5(
            version_seed.encode("utf-8"), usedforsecurity=False
        ).hexdigest()[:16]
        # datetime.utcnow() is deprecated since Python 3.12; dropping tzinfo keeps
        # the exact same naive-UTC string format as before.
        trained_at = datetime.now(timezone.utc).replace(tzinfo=None).isoformat()
        return {
            "trained_at": trained_at,
            "sample_count": len(samples),
            "label_distribution": dict(dist),
            "version": f"nb-{digest}",
        }

    def train(self, samples: list[dict]) -> dict:
        """Fit a new model on ``samples``, persist it to ``model_path`` and activate it.

        Args:
            samples: Mappings read via ``row.get("text")`` / ``row.get("label")``.
                Rows whose normalized text is empty or whose label is not in
                VALID_LABELS are skipped.

        Returns:
            The metadata dict of the newly trained model.

        Raises:
            ValueError: fewer than 10 usable samples remain after cleaning.
        """
        clean_samples = []
        for row in samples:
            text = self.normalize_text(row.get("text"))
            label = self.normalize_label(row.get("label"))
            if text and label:
                clean_samples.append({"text": text, "label": label})
        if len(clean_samples) < 10:
            raise ValueError("训练样本太少至少需要10条有效样本")

        texts = [item["text"] for item in clean_samples]
        labels = [item["label"] for item in clean_samples]
        # Character 1-2 grams need no word segmentation (presumably chosen because
        # inputs may be Chinese text -- confirm with callers).
        vectorizer = TfidfVectorizer(analyzer="char", ngram_range=(1, 2), min_df=1)
        x = vectorizer.fit_transform(texts)
        model = MultinomialNB(alpha=0.4)
        model.fit(x, labels)

        version_seed = "||".join(
            f"{item['label']}::{item['text']}" for item in clean_samples
        )
        metadata = self._to_metadata(clean_samples, version_seed)
        self._dump_atomically(
            {"vectorizer": vectorizer, "model": model, "metadata": metadata}
        )
        # Only switch the in-memory model once the file has been written.
        self.vectorizer = vectorizer
        self.model = model
        self.metadata = metadata
        return metadata

    def _dump_atomically(self, payload: dict) -> None:
        """Write ``payload`` to ``model_path`` with joblib via a temp file + rename.

        The rename (os.replace semantics) is atomic on POSIX within one
        filesystem, so a crash mid-write or a concurrent :meth:`load` never
        observes a half-written model file.
        """
        self.model_path.parent.mkdir(parents=True, exist_ok=True)
        # Prefix (not suffix) the temp name so joblib's extension-based
        # compression detection (e.g. ".gz", ".xz") matches the final file.
        tmp_path = self.model_path.with_name(f".tmp-{self.model_path.name}")
        try:
            joblib.dump(payload, tmp_path)
            tmp_path.replace(self.model_path)
        except BaseException:
            tmp_path.unlink(missing_ok=True)
            raise

    def load(self) -> bool:
        """Restore a persisted model from ``model_path`` if one exists.

        Returns:
            True when both the vectorizer and the model were restored; False
            if the file is missing or does not hold a dict payload.

        Note:
            joblib.load unpickles the file, which can execute arbitrary code.
            Only point ``model_path`` at files written by this class.
        """
        if not self.model_path.exists():
            return False
        payload = joblib.load(self.model_path)
        if not isinstance(payload, dict):
            # Not a payload written by train(); report "not loaded" so the caller
            # can retrain instead of crashing on payload.get().
            return False
        self.vectorizer = payload.get("vectorizer")
        self.model = payload.get("model")
        # `or {}` also guards against a stored None, which would break .get() below.
        self.metadata = payload.get("metadata") or {}
        return self.vectorizer is not None and self.model is not None

    def ensure_ready(self, samples: list[dict]) -> dict:
        """Load the persisted model, or train on ``samples`` if loading fails.

        Returns:
            The metadata of the active model.
        """
        if self.load():
            return self.metadata
        return self.train(samples)

    def predict(self, text: str) -> dict:
        """Classify ``text`` as spam or ham.

        Returns:
            Dict with the cleaned text, ``prediction`` ("spam"/"ham") and its
            display text, both class probabilities and ``confidence`` (max of
            the two, each rounded to 4 places), up to 5 ``reason_tokens`` and
            the model's ``model_version`` / ``trained_at``.

        Raises:
            RuntimeError: no model has been trained or loaded.
            ValueError: the normalized text is shorter than 2 characters.
        """
        if self.vectorizer is None or self.model is None:
            raise RuntimeError("模型未加载,请先训练")
        cleaned = self.normalize_text(text)
        if len(cleaned) < 2:
            raise ValueError("待识别文本至少2个字符")

        x = self.vectorizer.transform([cleaned])
        probs = self.model.predict_proba(x)[0]
        classes = list(self.model.classes_)
        # A model fitted on a single label has only one probability column; the
        # absent label must get 0.0. Reading column 0 for it would report both
        # labels at probability 1.0.
        spam_prob = float(probs[classes.index("spam")]) if "spam" in classes else 0.0
        ham_prob = float(probs[classes.index("ham")]) if "ham" in classes else 0.0
        # Ties resolve to "spam".
        prediction = "spam" if spam_prob >= ham_prob else "ham"
        reason_tokens = self._extract_reason_tokens(cleaned, classes, x)
        confidence = max(spam_prob, ham_prob)
        return {
            "text": cleaned,
            "prediction": prediction,
            "prediction_text": "垃圾信息" if prediction == "spam" else "正常信息",
            "spam_probability": round(spam_prob, 4),
            "ham_probability": round(ham_prob, 4),
            "confidence": round(confidence, 4),
            "reason_tokens": reason_tokens,
            "model_version": self.metadata.get("version", ""),
            "trained_at": self.metadata.get("trained_at"),
        }

    def _extract_reason_tokens(self, text: str, classes: list[str], x_row) -> list[str]:
        """Return up to 5 n-grams of ``x_row`` that best separate spam from ham.

        Each n-gram present in the vectorized input is scored by
        ``log P(ngram|spam) - log P(ngram|ham)`` from the NB model and ranked by
        absolute value, so strongly ham-indicating n-grams can appear as well.
        If a label is absent from the model, both rows resolve to row 0, all
        scores are 0 and n-grams come back in column order.

        Best-effort: on any failure, the first 5 characters of ``text`` are
        returned instead.
        """
        try:
            feature_names = self.vectorizer.get_feature_names_out()
            class_log_prob = self.model.feature_log_prob_
            spam_row = class_log_prob[classes.index("spam") if "spam" in classes else 0]
            ham_row = class_log_prob[classes.index("ham") if "ham" in classes else 0]
            # Column indices of the transformed row are vocabulary indices
            # already, so no feature-name -> index round trip is needed.
            scored = [
                (feature_names[idx], spam_row[idx] - ham_row[idx])
                for idx in x_row.nonzero()[1]
            ]
            scored.sort(key=lambda item: abs(item[1]), reverse=True)
            return [token for token, _ in scored[:5]]
        except Exception:
            return list(text[:5])

    def model_info(self) -> dict:
        """Return a summary of the current model state.

        ``ready`` is True only when both the vectorizer and the model are set;
        the remaining fields come from ``self.metadata`` with empty defaults.
        """
        return {
            "ready": self.vectorizer is not None and self.model is not None,
            "model_path": str(self.model_path),
            "version": self.metadata.get("version", ""),
            "trained_at": self.metadata.get("trained_at"),
            "sample_count": int(self.metadata.get("sample_count", 0) or 0),
            "label_distribution": self.metadata.get("label_distribution", {}),
        }