Files
c/backend/app/ml/naive_bayes_classifier.py
刘正航 50440e84fb feat: 风险关键词红色标记 + 点击显示权重贡献
- 后端: _extract_reason_tokens 返回 [{token, weight}] 格式
- 前端: detect/batch 页面风险关键词使用红色标签样式
- 点击关键词弹窗显示权重值及判定倾向

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-04-21 22:55:42 +08:00

157 lines
5.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import hashlib
import logging
from collections import Counter
from datetime import datetime, timezone
from pathlib import Path

import joblib
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.naive_bayes import MultinomialNB
VALID_LABELS = {"spam", "ham"}

logger = logging.getLogger(__name__)


class NaiveBayesSpamClassifier:
    """Spam/ham text classifier: char 1-2 gram TF-IDF features + multinomial naive Bayes.

    The fitted vectorizer, model and training metadata are persisted together
    with ``joblib`` at ``model_path`` so that a later process can restore them
    with :meth:`load` instead of retraining.
    """

    def __init__(self, model_path: str):
        self.model_path = Path(model_path)
        self.vectorizer = None  # TfidfVectorizer set by train(), or whatever load() restored
        self.model = None  # MultinomialNB set by train(), or whatever load() restored
        self.metadata = {}  # dict built by _to_metadata() for the active model

    @staticmethod
    def normalize_label(label: str) -> str:
        """Return *label* stripped and lower-cased if it is ``spam``/``ham``, else ``""``."""
        value = (label or "").strip().lower()
        return value if value in VALID_LABELS else ""

    @staticmethod
    def normalize_text(text: str) -> str:
        """Collapse every whitespace run in *text* to one space and trim both ends."""
        # str.split() with no argument already discards leading/trailing whitespace.
        return " ".join((text or "").split())

    @staticmethod
    def _to_metadata(samples: list[dict], version_seed: str) -> dict:
        """Build the metadata dict stored alongside a freshly trained model.

        ``version`` is a fingerprint of *version_seed*, so retraining on the
        same cleaned samples yields the same version string.
        """
        dist = Counter(item["label"] for item in samples)
        # MD5 is only a content fingerprint here, not a security control;
        # usedforsecurity=False keeps it available on FIPS-restricted builds.
        digest = hashlib.md5(version_seed.encode("utf-8"), usedforsecurity=False).hexdigest()[:16]
        # Naive ISO-8601 UTC timestamp (same format datetime.utcnow() gave),
        # produced without the utcnow() API deprecated in Python 3.12.
        trained_at = datetime.now(timezone.utc).replace(tzinfo=None).isoformat()
        return {
            "trained_at": trained_at,
            "sample_count": len(samples),
            "label_distribution": dict(dist),
            "version": f"nb-{digest}",
        }

    def train(self, samples: list[dict]) -> dict:
        """Fit a new model on *samples*, persist it, and make it the active model.

        Args:
            samples: dicts with ``"text"`` and ``"label"`` keys. Rows whose text
                is empty after normalization, or whose label is not
                ``spam``/``ham``, are skipped.

        Returns:
            The metadata dict of the newly trained model.

        Raises:
            ValueError: fewer than 10 usable samples remain after cleaning.
        """
        clean_samples = []
        for row in samples:
            text = self.normalize_text(row.get("text"))
            label = self.normalize_label(row.get("label"))
            if text and label:
                clean_samples.append({"text": text, "label": label})
        if len(clean_samples) < 10:
            raise ValueError("训练样本太少至少需要10条有效样本")

        texts = [item["text"] for item in clean_samples]
        labels = [item["label"] for item in clean_samples]
        # Character 1-2 grams need no word tokenizer (presumably chosen for
        # unsegmented CJK text).
        vectorizer = TfidfVectorizer(analyzer="char", ngram_range=(1, 2), min_df=1)
        x = vectorizer.fit_transform(texts)
        model = MultinomialNB(alpha=0.4)
        model.fit(x, labels)

        version_seed = "||".join(f"{item['label']}::{item['text']}" for item in clean_samples)
        metadata = self._to_metadata(clean_samples, version_seed)
        self._save({"vectorizer": vectorizer, "model": model, "metadata": metadata})

        self.vectorizer = vectorizer
        self.model = model
        self.metadata = metadata
        return metadata

    def _save(self, payload: dict) -> None:
        """Write *payload* to ``model_path`` via a temp file and a rename."""
        self.model_path.parent.mkdir(parents=True, exist_ok=True)
        # Keep the real suffix last: joblib.dump infers compression from the
        # filename extension (e.g. ".gz"), so the temp file must end the same way.
        tmp_path = self.model_path.with_name(
            f"{self.model_path.stem}.tmp{self.model_path.suffix}"
        )
        try:
            joblib.dump(payload, tmp_path)
            # Replace in one rename (atomic on POSIX), so load() never sees a
            # half-written model file.
            tmp_path.replace(self.model_path)
        finally:
            # No-op after a successful rename; removes a partial dump otherwise.
            tmp_path.unlink(missing_ok=True)

    def load(self) -> bool:
        """Restore a previously saved model from ``model_path``.

        Returns:
            True when both a vectorizer and a model were restored; False when the
            file is missing or its payload does not contain them.
        """
        if not self.model_path.exists():
            return False
        # NOTE(security): joblib.load unpickles and can execute arbitrary code;
        # only point model_path at files this service wrote itself.
        payload = joblib.load(self.model_path)
        if not isinstance(payload, dict):
            return False
        self.vectorizer = payload.get("vectorizer")
        self.model = payload.get("model")
        self.metadata = payload.get("metadata", {})
        return self.vectorizer is not None and self.model is not None

    def ensure_ready(self, samples: list[dict]) -> dict:
        """Load the saved model if possible, otherwise train one on *samples*.

        Returns:
            The metadata of the active model.
        """
        if self.load():
            return self.metadata
        return self.train(samples)

    @staticmethod
    def _class_probability(probs, classes: list[str], label: str) -> float:
        """Return the probability column for *label*, or 0.0 if the model never saw it."""
        return float(probs[classes.index(label)]) if label in classes else 0.0

    def predict(self, text: str) -> dict:
        """Classify *text* as spam or ham.

        Args:
            text: raw input; whitespace is normalized before classification.

        Returns:
            A dict with the normalized ``text``, ``prediction`` (``"spam"`` or
            ``"ham"``; ties go to spam), its display string, both class
            probabilities and the larger one as ``confidence`` (each rounded to
            4 decimals), ``reason_tokens`` from :meth:`_extract_reason_tokens`,
            and the active model's version and training time.

        Raises:
            RuntimeError: no model has been trained or loaded.
            ValueError: the normalized text is shorter than 2 characters.
        """
        if self.vectorizer is None or self.model is None:
            raise RuntimeError("模型未加载,请先训练")
        cleaned = self.normalize_text(text)
        if len(cleaned) < 2:
            raise ValueError("待识别文本至少2个字符")
        x = self.vectorizer.transform([cleaned])
        probs = self.model.predict_proba(x)[0]
        classes = list(self.model.classes_)
        # A class absent from the model (e.g. trained on spam rows only) gets
        # probability 0 instead of borrowing another class's column.
        spam_prob = self._class_probability(probs, classes, "spam")
        ham_prob = self._class_probability(probs, classes, "ham")
        prediction = "spam" if spam_prob >= ham_prob else "ham"
        reason_tokens = self._extract_reason_tokens(cleaned, classes, x)
        confidence = max(spam_prob, ham_prob)
        return {
            "text": cleaned,
            "prediction": prediction,
            "prediction_text": "垃圾信息" if prediction == "spam" else "正常信息",
            "spam_probability": round(spam_prob, 4),
            "ham_probability": round(ham_prob, 4),
            "confidence": round(confidence, 4),
            "reason_tokens": reason_tokens,
            "model_version": self.metadata.get("version", ""),
            "trained_at": self.metadata.get("trained_at"),
        }

    def _extract_reason_tokens(self, text: str, classes: list[str], x_row, top_k: int = 5) -> list[dict]:
        """Return up to *top_k* features present in *x_row*, ranked by |weight|.

        Each entry is ``{"token": feature, "weight": w}``, where
        ``w = feature_log_prob_[spam] - feature_log_prob_[ham]`` for that
        feature, rounded to 4 decimals. Positive values lean spam and negative
        values lean ham. The weight is not scaled by the feature's TF-IDF value
        in this text.

        Args:
            text: the normalized text, used only by the fallback path.
            classes: ``list(self.model.classes_)``.
            x_row: the 1-row sparse matrix produced by ``vectorizer.transform``.
            top_k: maximum number of entries to return.

        Returns:
            ``[]`` when the model lacks either class. On any extraction failure
            (best effort), the first *top_k* characters of *text*, each with
            weight 0.0.
        """
        try:
            if "spam" not in classes or "ham" not in classes:
                # No spam-vs-ham contrast exists to explain.
                return []
            spam_log_prob = self.model.feature_log_prob_[classes.index("spam")]
            ham_log_prob = self.model.feature_log_prob_[classes.index("ham")]
            feature_names = self.vectorizer.get_feature_names_out()
            # Column indices of a single sparse row are already unique.
            scored = [
                {
                    "token": str(feature_names[idx]),
                    "weight": round(float(spam_log_prob[idx] - ham_log_prob[idx]), 4),
                }
                for idx in x_row.nonzero()[1]
            ]
        except Exception:
            logger.warning("reason token extraction failed; using raw characters", exc_info=True)
            return [{"token": ch, "weight": 0.0} for ch in text[:top_k]]
        scored.sort(key=lambda row: abs(row["weight"]), reverse=True)
        return scored[:top_k]

    def model_info(self) -> dict:
        """Summarize the active model: readiness, file path, and stored metadata."""
        return {
            "ready": self.vectorizer is not None and self.model is not None,
            "model_path": str(self.model_path),
            "version": self.metadata.get("version", ""),
            "trained_at": self.metadata.get("trained_at"),
            "sample_count": int(self.metadata.get("sample_count", 0) or 0),
            "label_distribution": self.metadata.get("label_distribution", {}),
        }