"""Character n-gram Naive Bayes spam/ham classifier persisted with joblib."""

import hashlib
from collections import Counter
from datetime import datetime, timezone
from pathlib import Path

import joblib
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.naive_bayes import MultinomialNB

# The only labels accepted for training; anything else is discarded.
VALID_LABELS = {"spam", "ham"}


class NaiveBayesSpamClassifier:
    """Train, persist, load and query a TF-IDF + MultinomialNB text classifier.

    The fitted vectorizer, the fitted model and a small metadata dict are
    stored together in one joblib file at ``model_path``.
    """

    def __init__(self, model_path: str):
        """Remember where the model lives; nothing is loaded until asked."""
        self.model_path = Path(model_path)
        self.vectorizer = None
        self.model = None
        self.metadata = {}

    @staticmethod
    def normalize_label(label: str) -> str:
        """Return the label trimmed and lower-cased, or "" when it is invalid.

        ``None`` and non-string values are treated as invalid instead of
        raising, so one malformed data row cannot abort a training run.
        """
        if not isinstance(label, str):
            return ""
        value = label.strip().lower()
        return value if value in VALID_LABELS else ""

    @staticmethod
    def normalize_text(text: str) -> str:
        """Collapse every whitespace run to one space and trim both ends.

        ``None`` and non-string values yield "".
        """
        if not isinstance(text, str):
            return ""
        # str.split() with no argument already drops leading/trailing blanks.
        return " ".join(text.split())

    @staticmethod
    def _to_metadata(samples: list[dict], version_seed: str) -> dict:
        """Build the metadata record stored next to the model.

        Args:
            samples: Cleaned samples, each a dict with a "label" key.
            version_seed: Text whose hash becomes the model version.

        Returns:
            Dict with "trained_at", "sample_count", "label_distribution" and
            "version".
        """
        distribution = Counter(item["label"] for item in samples)
        # MD5 is only a content fingerprint here, not a security measure.
        digest = hashlib.md5(
            version_seed.encode("utf-8"), usedforsecurity=False
        ).hexdigest()[:16]
        # datetime.utcnow() is deprecated; take an aware UTC time and drop the
        # tzinfo so the stored string keeps its historical naive-UTC format.
        trained_at = datetime.now(timezone.utc).replace(tzinfo=None).isoformat()
        return {
            "trained_at": trained_at,
            "sample_count": len(samples),
            "label_distribution": dict(distribution),
            "version": f"nb-{digest}",
        }

    def train(self, samples: list[dict]) -> dict:
        """Fit a new model on ``samples``, save it and make it the active one.

        Args:
            samples: Dicts carrying "text" and "label". Rows whose text is
                empty after normalisation or whose label is not in
                ``VALID_LABELS`` are skipped.

        Returns:
            The metadata dict describing the freshly trained model.

        Raises:
            ValueError: Fewer than 10 usable samples remain after cleaning.
        """
        clean_samples = []
        for row in samples:
            text = self.normalize_text(row.get("text"))
            label = self.normalize_label(row.get("label"))
            if text and label:
                clean_samples.append({"text": text, "label": label})

        if len(clean_samples) < 10:
            raise ValueError("训练样本太少,至少需要10条有效样本")

        texts = [item["text"] for item in clean_samples]
        labels = [item["label"] for item in clean_samples]

        # Character uni/bi-grams need no word segmentation of the input.
        vectorizer = TfidfVectorizer(analyzer="char", ngram_range=(1, 2), min_df=1)
        features = vectorizer.fit_transform(texts)
        model = MultinomialNB(alpha=0.4)
        model.fit(features, labels)

        # The version is derived from the exact training data, so identical
        # data always yields the identical version string.
        version_seed = "||".join(
            f"{item['label']}::{item['text']}" for item in clean_samples
        )
        metadata = self._to_metadata(clean_samples, version_seed)

        self.model_path.parent.mkdir(parents=True, exist_ok=True)
        joblib.dump(
            {"vectorizer": vectorizer, "model": model, "metadata": metadata},
            self.model_path,
        )

        # Swap in-memory state only after the model was written successfully.
        self.vectorizer = vectorizer
        self.model = model
        self.metadata = metadata
        return metadata

    def load(self) -> bool:
        """Load the persisted model, if there is a usable one.

        Returns:
            True when both vectorizer and model were loaded. False when the
            file is missing or its content is not a usable payload; in that
            case the current in-memory state is left untouched.
        """
        if not self.model_path.exists():
            return False

        # SECURITY: joblib.load unpickles the file and can execute arbitrary
        # code. Only point model_path at files this application wrote itself.
        payload = joblib.load(self.model_path)
        if not isinstance(payload, dict):
            return False

        vectorizer = payload.get("vectorizer")
        model = payload.get("model")
        if vectorizer is None or model is None:
            return False

        self.vectorizer = vectorizer
        self.model = model
        # A stored None would break the later self.metadata.get(...) calls.
        self.metadata = payload.get("metadata") or {}
        return True

    def ensure_ready(self, samples: list[dict]) -> dict:
        """Return metadata of a usable model, training on ``samples`` if needed.

        Raises:
            ValueError: Training was required and ``samples`` was too small.
        """
        if self.load():
            return self.metadata
        return self.train(samples)

    @staticmethod
    def _class_probability(probs, classes: list, label: str) -> float:
        """Return the probability of ``label``, or 0.0 if the model lacks it."""
        if label not in classes:
            return 0.0
        return float(probs[classes.index(label)])

    def predict(self, text: str) -> dict:
        """Classify ``text`` as spam or ham.

        Args:
            text: Raw message; it is whitespace-normalised before scoring.

        Returns:
            Dict with the cleaned text, the predicted label and its display
            text, both class probabilities, the confidence, up to five
            explanatory tokens and the model version / training time.

        Raises:
            RuntimeError: No model is loaded.
            ValueError: The cleaned text is shorter than two characters.
        """
        if self.vectorizer is None or self.model is None:
            raise RuntimeError("模型未加载,请先训练")

        cleaned = self.normalize_text(text)
        if len(cleaned) < 2:
            raise ValueError("待识别文本至少2个字符")

        features = self.vectorizer.transform([cleaned])
        probs = self.model.predict_proba(features)[0]
        classes = list(self.model.classes_)

        # A model fitted on a single class has only that class in classes_.
        # The missing class must score 0.0; reusing index 0 for it would
        # report 1.0 for both labels and always break the tie towards "spam".
        spam_prob = self._class_probability(probs, classes, "spam")
        ham_prob = self._class_probability(probs, classes, "ham")

        prediction = "spam" if spam_prob >= ham_prob else "ham"
        confidence = max(spam_prob, ham_prob)
        reason_tokens = self._extract_reason_tokens(cleaned, classes, features)

        return {
            "text": cleaned,
            "prediction": prediction,
            "prediction_text": "垃圾信息" if prediction == "spam" else "正常信息",
            "spam_probability": round(spam_prob, 4),
            "ham_probability": round(ham_prob, 4),
            "confidence": round(confidence, 4),
            "reason_tokens": reason_tokens,
            "model_version": self.metadata.get("version", ""),
            "trained_at": self.metadata.get("trained_at"),
        }

    def _extract_reason_tokens(self, text: str, classes: list[str], x_row) -> list[str]:
        """Return up to five n-grams of the input that best separate the classes.

        Tokens are ranked by the absolute difference between their spam and
        ham log-probabilities. When the model does not know both classes no
        ranking is possible and the first active n-grams are returned as-is.

        This is a best-effort explanation: if the model objects do not expose
        the expected attributes, the first five characters of ``text`` are
        returned instead of failing the prediction.
        """
        try:
            feature_names = self.vectorizer.get_feature_names_out()
            log_prob = self.model.feature_log_prob_
            # Column indices of the n-grams present in this one-row matrix.
            active = list(dict.fromkeys(int(idx) for idx in x_row.nonzero()[1]))

            if "spam" in classes and "ham" in classes:
                spam_row = log_prob[classes.index("spam")]
                ham_row = log_prob[classes.index("ham")]
                # The sort is stable, so ties keep their original order.
                active.sort(
                    key=lambda idx: abs(spam_row[idx] - ham_row[idx]),
                    reverse=True,
                )
            return [str(feature_names[idx]) for idx in active[:5]]
        except (AttributeError, IndexError, KeyError, TypeError, ValueError):
            return list(text[:5])

    def model_info(self) -> dict:
        """Summarise the active model: readiness, path and stored metadata."""
        ready = self.vectorizer is not None and self.model is not None
        return {
            "ready": ready,
            "model_path": str(self.model_path),
            "version": self.metadata.get("version", ""),
            "trained_at": self.metadata.get("trained_at"),
            "sample_count": int(self.metadata.get("sample_count", 0) or 0),
            "label_distribution": self.metadata.get("label_distribution", {}),
        }