157 lines
5.7 KiB
Python
157 lines
5.7 KiB
Python
import hashlib
|
||
from collections import Counter
|
||
from datetime import datetime
|
||
from pathlib import Path
|
||
|
||
import joblib
|
||
from sklearn.feature_extraction.text import TfidfVectorizer
|
||
from sklearn.naive_bayes import MultinomialNB
|
||
|
||
|
||
# Class labels accepted for training samples: "spam" for junk messages and
# "ham" for normal ones. Any other label normalizes to the empty string.
VALID_LABELS = {"ham", "spam"}
|
||
|
||
|
||
class NaiveBayesSpamClassifier:
    """Spam/ham text classifier: character n-gram TF-IDF + multinomial Naive Bayes.

    The fitted vectorizer, the fitted model and a small metadata dict are
    persisted together in a single joblib file at ``model_path``.
    """

    def __init__(self, model_path: str):
        """Remember where the model file lives; nothing is loaded yet.

        Args:
            model_path: Filesystem path of the joblib file used by
                :meth:`train` (write) and :meth:`load` (read).
        """
        self.model_path = Path(model_path)
        self.vectorizer = None
        self.model = None
        self.metadata = {}

    @staticmethod
    def normalize_label(label: str) -> str:
        """Return the label stripped and lower-cased, or ``""`` if it is invalid.

        ``None``, non-string values and labels outside ``VALID_LABELS`` all
        normalize to the empty string.
        """
        if not isinstance(label, str):
            return ""
        value = label.strip().lower()
        return value if value in VALID_LABELS else ""

    @staticmethod
    def normalize_text(text: str) -> str:
        """Collapse every whitespace run to one space and trim both ends.

        ``None`` and non-string values normalize to the empty string.
        """
        if not isinstance(text, str):
            return ""
        return " ".join(text.split())

    @staticmethod
    def _class_index(classes: list, label: str):
        """Return the position of ``label`` in ``classes``, or ``None`` if absent."""
        return classes.index(label) if label in classes else None

    @staticmethod
    def _to_metadata(samples: list[dict], version_seed: str) -> dict:
        """Build the metadata dict stored next to a trained model.

        Args:
            samples: Cleaned samples; each must have a ``"label"`` key.
            version_seed: String whose MD5 digest (first 16 hex chars) becomes
                the model version, so identical training data yields the same
                version.

        Returns:
            Dict with ``trained_at`` (naive UTC ISO-8601 string),
            ``sample_count``, ``label_distribution`` and ``version``.
        """
        # Local import: only ``datetime`` itself is imported at file level.
        from datetime import timezone

        dist = Counter(item["label"] for item in samples)
        digest = hashlib.md5(version_seed.encode("utf-8")).hexdigest()[:16]
        # datetime.utcnow() is deprecated since Python 3.12. Dropping tzinfo
        # keeps the stored string in the same naive-UTC format as before.
        trained_at = datetime.now(timezone.utc).replace(tzinfo=None).isoformat()
        return {
            "trained_at": trained_at,
            "sample_count": len(samples),
            "label_distribution": dict(dist),
            "version": f"nb-{digest}",
        }

    def train(self, samples: list[dict]) -> dict:
        """Fit vectorizer and model on ``samples``, persist them, and return metadata.

        Rows whose text or label normalizes to the empty string are skipped.

        Args:
            samples: Dicts with ``"text"`` and ``"label"`` keys.

        Returns:
            The metadata dict produced by :meth:`_to_metadata`.

        Raises:
            ValueError: If fewer than 10 valid samples remain after cleaning.
        """
        clean_samples = []
        for row in samples:
            text = self.normalize_text(row.get("text"))
            label = self.normalize_label(row.get("label"))
            if text and label:
                clean_samples.append({"text": text, "label": label})

        if len(clean_samples) < 10:
            raise ValueError("训练样本太少,至少需要10条有效样本")

        texts = [item["text"] for item in clean_samples]
        labels = [item["label"] for item in clean_samples]

        vectorizer = TfidfVectorizer(analyzer="char", ngram_range=(1, 2), min_df=1)
        features = vectorizer.fit_transform(texts)

        model = MultinomialNB(alpha=0.4)
        model.fit(features, labels)

        version_seed = "||".join(
            f"{item['label']}::{item['text']}" for item in clean_samples
        )
        metadata = self._to_metadata(clean_samples, version_seed)

        self.model_path.parent.mkdir(parents=True, exist_ok=True)
        payload = {"vectorizer": vectorizer, "model": model, "metadata": metadata}
        # Write to a sibling temp file and rename, so an interrupted dump never
        # leaves a truncated model file behind for load() to trip over. The
        # temp name keeps the original suffix so joblib's extension-based
        # compression choice is unchanged.
        tmp_path = self.model_path.with_name(f".tmp-{self.model_path.name}")
        try:
            joblib.dump(payload, tmp_path)
            tmp_path.replace(self.model_path)
        except Exception:
            tmp_path.unlink(missing_ok=True)
            raise

        self.vectorizer = vectorizer
        self.model = model
        self.metadata = metadata
        return metadata

    def load(self) -> bool:
        """Load a persisted model from ``model_path``.

        In-memory state is only replaced when the file holds both a vectorizer
        and a model, so a failed reload never discards a working model.

        Returns:
            True if a usable model was loaded, False if the file is missing or
            its payload is incomplete or not a dict.
        """
        if not self.model_path.exists():
            return False

        # NOTE(security): joblib.load unpickles arbitrary objects. Only point
        # model_path at files this application wrote itself.
        payload = joblib.load(self.model_path)
        if not isinstance(payload, dict):
            return False

        vectorizer = payload.get("vectorizer")
        model = payload.get("model")
        if vectorizer is None or model is None:
            return False

        self.vectorizer = vectorizer
        self.model = model
        # A stored None would break later ``self.metadata.get`` calls.
        self.metadata = payload.get("metadata") or {}
        return True

    def ensure_ready(self, samples: list[dict]) -> dict:
        """Load the persisted model if possible, otherwise train on ``samples``.

        Returns:
            The metadata of whichever model ends up active.
        """
        if self.load():
            return self.metadata
        return self.train(samples)

    def predict(self, text: str) -> dict:
        """Classify ``text`` as spam or ham.

        Args:
            text: Raw message text; it is whitespace-normalized first.

        Returns:
            Dict with the normalized ``text``, ``prediction`` ("spam"/"ham"),
            a display string ``prediction_text``, both class probabilities and
            ``confidence`` (rounded to 4 places), up to five ``reason_tokens``,
            and the model's ``model_version`` / ``trained_at``.

        Raises:
            RuntimeError: If no model has been trained or loaded.
            ValueError: If the normalized text is shorter than 2 characters.
        """
        if self.vectorizer is None or self.model is None:
            raise RuntimeError("模型未加载,请先训练")

        cleaned = self.normalize_text(text)
        if len(cleaned) < 2:
            raise ValueError("待识别文本至少2个字符")

        x = self.vectorizer.transform([cleaned])
        probs = self.model.predict_proba(x)[0]
        classes = list(self.model.classes_)

        # A model fitted on one label only has a single probability column.
        # The absent class must score 0.0 rather than borrow column 0, or a
        # ham-only model would report spam_probability == 1.0.
        spam_idx = self._class_index(classes, "spam")
        ham_idx = self._class_index(classes, "ham")
        spam_prob = float(probs[spam_idx]) if spam_idx is not None else 0.0
        ham_prob = float(probs[ham_idx]) if ham_idx is not None else 0.0

        # Ties go to "spam".
        prediction = "spam" if spam_prob >= ham_prob else "ham"
        reason_tokens = self._extract_reason_tokens(cleaned, classes, x)
        confidence = max(spam_prob, ham_prob)

        return {
            "text": cleaned,
            "prediction": prediction,
            "prediction_text": "垃圾信息" if prediction == "spam" else "正常信息",
            "spam_probability": round(spam_prob, 4),
            "ham_probability": round(ham_prob, 4),
            "confidence": round(confidence, 4),
            "reason_tokens": reason_tokens,
            "model_version": self.metadata.get("version", ""),
            "trained_at": self.metadata.get("trained_at"),
        }

    def _extract_reason_tokens(self, text: str, classes: list[str], x_row) -> list[str]:
        """Return up to five n-grams of the input that best separate spam from ham.

        Each feature present in ``x_row`` is scored by the difference between
        its spam and ham log-probabilities; the largest absolute differences
        come first. If the model knows only one class there is nothing to
        contrast, so the first five present features are returned as-is.

        This is a best-effort explanation: on any failure it falls back to the
        first five characters of ``text`` instead of failing the prediction.

        Args:
            text: The normalized input text (used only for the fallback).
            classes: Class labels in the model's column order.
            x_row: The vectorized input, a single-row matrix.
        """
        try:
            feature_names = self.vectorizer.get_feature_names_out()
            log_prob = self.model.feature_log_prob_
            spam_idx = self._class_index(classes, "spam")
            ham_idx = self._class_index(classes, "ham")

            active = [int(idx) for idx in x_row.nonzero()[1]]
            if spam_idx is None or ham_idx is None:
                return [str(feature_names[idx]) for idx in active[:5]]

            scored = [
                (str(feature_names[idx]), log_prob[spam_idx][idx] - log_prob[ham_idx][idx])
                for idx in active
            ]
            scored.sort(key=lambda pair: abs(pair[1]), reverse=True)
            return [token for token, _ in scored[:5]]
        except Exception:
            # Deliberately broad: the explanation must never break predict().
            return list(text[:5])

    def model_info(self) -> dict:
        """Return a summary of the active model (readiness, path and metadata)."""
        return {
            "ready": self.vectorizer is not None and self.model is not None,
            "model_path": str(self.model_path),
            "version": self.metadata.get("version", ""),
            "trained_at": self.metadata.get("trained_at"),
            "sample_count": int(self.metadata.get("sample_count", 0) or 0),
            "label_distribution": self.metadata.get("label_distribution", {}),
        }
|