1
This commit is contained in:
156
backend/app/ml/naive_bayes_classifier.py
Normal file
156
backend/app/ml/naive_bayes_classifier.py
Normal file
@@ -0,0 +1,156 @@
|
||||
import hashlib
|
||||
from collections import Counter
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import joblib
|
||||
from sklearn.feature_extraction.text import TfidfVectorizer
|
||||
from sklearn.naive_bayes import MultinomialNB
|
||||
|
||||
|
||||
VALID_LABELS = {"spam", "ham"}
|
||||
|
||||
|
||||
class NaiveBayesSpamClassifier:
|
||||
def __init__(self, model_path: str):
|
||||
self.model_path = Path(model_path)
|
||||
self.vectorizer = None
|
||||
self.model = None
|
||||
self.metadata = {}
|
||||
|
||||
@staticmethod
|
||||
def normalize_label(label: str) -> str:
|
||||
value = (label or "").strip().lower()
|
||||
return value if value in VALID_LABELS else ""
|
||||
|
||||
@staticmethod
|
||||
def normalize_text(text: str) -> str:
|
||||
return " ".join((text or "").strip().split())
|
||||
|
||||
@staticmethod
|
||||
def _to_metadata(samples: list[dict], version_seed: str) -> dict:
|
||||
dist = Counter([item["label"] for item in samples])
|
||||
digest = hashlib.md5(version_seed.encode("utf-8")).hexdigest()[:16]
|
||||
return {
|
||||
"trained_at": datetime.utcnow().isoformat(),
|
||||
"sample_count": len(samples),
|
||||
"label_distribution": dict(dist),
|
||||
"version": f"nb-{digest}",
|
||||
}
|
||||
|
||||
def train(self, samples: list[dict]) -> dict:
|
||||
clean_samples = []
|
||||
for row in samples:
|
||||
text = self.normalize_text(row.get("text"))
|
||||
label = self.normalize_label(row.get("label"))
|
||||
if not text or not label:
|
||||
continue
|
||||
clean_samples.append({"text": text, "label": label})
|
||||
|
||||
if len(clean_samples) < 10:
|
||||
raise ValueError("训练样本太少,至少需要10条有效样本")
|
||||
|
||||
texts = [item["text"] for item in clean_samples]
|
||||
labels = [item["label"] for item in clean_samples]
|
||||
|
||||
vectorizer = TfidfVectorizer(analyzer="char", ngram_range=(1, 2), min_df=1)
|
||||
x = vectorizer.fit_transform(texts)
|
||||
|
||||
model = MultinomialNB(alpha=0.4)
|
||||
model.fit(x, labels)
|
||||
|
||||
version_seed = "||".join([f"{item['label']}::{item['text']}" for item in clean_samples])
|
||||
metadata = self._to_metadata(clean_samples, version_seed)
|
||||
|
||||
self.model_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
joblib.dump({"vectorizer": vectorizer, "model": model, "metadata": metadata}, self.model_path)
|
||||
|
||||
self.vectorizer = vectorizer
|
||||
self.model = model
|
||||
self.metadata = metadata
|
||||
return metadata
|
||||
|
||||
def load(self) -> bool:
|
||||
if not self.model_path.exists():
|
||||
return False
|
||||
|
||||
payload = joblib.load(self.model_path)
|
||||
self.vectorizer = payload.get("vectorizer")
|
||||
self.model = payload.get("model")
|
||||
self.metadata = payload.get("metadata", {})
|
||||
return self.vectorizer is not None and self.model is not None
|
||||
|
||||
def ensure_ready(self, samples: list[dict]) -> dict:
|
||||
if self.load():
|
||||
return self.metadata
|
||||
return self.train(samples)
|
||||
|
||||
def predict(self, text: str) -> dict:
|
||||
if self.vectorizer is None or self.model is None:
|
||||
raise RuntimeError("模型未加载,请先训练")
|
||||
|
||||
cleaned = self.normalize_text(text)
|
||||
if len(cleaned) < 2:
|
||||
raise ValueError("待识别文本至少2个字符")
|
||||
|
||||
x = self.vectorizer.transform([cleaned])
|
||||
probs = self.model.predict_proba(x)[0]
|
||||
classes = list(self.model.classes_)
|
||||
|
||||
spam_idx = classes.index("spam") if "spam" in classes else 0
|
||||
ham_idx = classes.index("ham") if "ham" in classes else 0
|
||||
|
||||
spam_prob = float(probs[spam_idx])
|
||||
ham_prob = float(probs[ham_idx])
|
||||
prediction = "spam" if spam_prob >= ham_prob else "ham"
|
||||
|
||||
reason_tokens = self._extract_reason_tokens(cleaned, classes, x)
|
||||
confidence = max(spam_prob, ham_prob)
|
||||
|
||||
return {
|
||||
"text": cleaned,
|
||||
"prediction": prediction,
|
||||
"prediction_text": "垃圾信息" if prediction == "spam" else "正常信息",
|
||||
"spam_probability": round(spam_prob, 4),
|
||||
"ham_probability": round(ham_prob, 4),
|
||||
"confidence": round(confidence, 4),
|
||||
"reason_tokens": reason_tokens,
|
||||
"model_version": self.metadata.get("version", ""),
|
||||
"trained_at": self.metadata.get("trained_at"),
|
||||
}
|
||||
|
||||
def _extract_reason_tokens(self, text: str, classes: list[str], x_row) -> list[str]:
|
||||
try:
|
||||
vocab = self.vectorizer.vocabulary_
|
||||
feature_names = self.vectorizer.get_feature_names_out()
|
||||
class_log_prob = self.model.feature_log_prob_
|
||||
spam_idx = classes.index("spam") if "spam" in classes else 0
|
||||
ham_idx = classes.index("ham") if "ham" in classes else 0
|
||||
|
||||
token_counter = Counter()
|
||||
for idx in x_row.nonzero()[1]:
|
||||
token = feature_names[idx]
|
||||
token_counter[token] += 1
|
||||
|
||||
scored = []
|
||||
for token in token_counter:
|
||||
idx = vocab.get(token)
|
||||
if idx is None:
|
||||
continue
|
||||
delta = class_log_prob[spam_idx][idx] - class_log_prob[ham_idx][idx]
|
||||
scored.append((token, delta))
|
||||
|
||||
scored.sort(key=lambda row: abs(row[1]), reverse=True)
|
||||
return [token for token, _ in scored[:5]]
|
||||
except Exception:
|
||||
return list(text[:5])
|
||||
|
||||
def model_info(self) -> dict:
|
||||
return {
|
||||
"ready": self.vectorizer is not None and self.model is not None,
|
||||
"model_path": str(self.model_path),
|
||||
"version": self.metadata.get("version", ""),
|
||||
"trained_at": self.metadata.get("trained_at"),
|
||||
"sample_count": int(self.metadata.get("sample_count", 0) or 0),
|
||||
"label_distribution": self.metadata.get("label_distribution", {}),
|
||||
}
|
||||
Reference in New Issue
Block a user