This commit is contained in:
刘正航
2026-04-21 22:45:19 +08:00
commit b5237f9038
159 changed files with 7769 additions and 0 deletions

View File

@@ -0,0 +1,156 @@
import hashlib
from collections import Counter
from datetime import datetime, timezone
from pathlib import Path

import joblib
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.naive_bayes import MultinomialNB

# Labels accepted by the classifier; normalize_label() maps anything else to "".
VALID_LABELS = {"spam", "ham"}
class NaiveBayesSpamClassifier:
    """Spam/ham text classifier: character n-gram TF-IDF + multinomial naive Bayes.

    The fitted vectorizer, the fitted model and a small metadata dict are
    persisted together in one joblib file at ``model_path``.
    """

    def __init__(self, model_path: str):
        """Remember where the model file lives; nothing is loaded or trained yet."""
        self.model_path = Path(model_path)
        self.vectorizer = None
        self.model = None
        self.metadata = {}

    @staticmethod
    def normalize_label(label: str) -> str:
        """Return the stripped, lower-cased label, or "" if it is not a valid label."""
        value = (label or "").strip().lower()
        return value if value in VALID_LABELS else ""

    @staticmethod
    def normalize_text(text: str) -> str:
        """Collapse every run of whitespace to one space and trim both ends.

        ``None`` and empty input yield "".
        """
        return " ".join((text or "").strip().split())

    @staticmethod
    def _to_metadata(samples: list[dict], version_seed: str) -> dict:
        """Build the metadata dict stored next to a trained model.

        Args:
            samples: Cleaned samples; each must carry a "label" key.
            version_seed: String whose MD5 digest (first 16 hex chars) becomes
                the model version, so the same seed gives the same version.

        Returns:
            Dict with "trained_at" (naive UTC ISO-8601 string), "sample_count",
            "label_distribution" and "version".
        """
        distribution = Counter(item["label"] for item in samples)
        # MD5 is only a content fingerprint here, not a security measure;
        # saying so keeps it usable on FIPS-restricted Python builds.
        digest = hashlib.md5(
            version_seed.encode("utf-8"), usedforsecurity=False
        ).hexdigest()[:16]
        # datetime.utcnow() is deprecated; take an aware UTC "now" and drop the
        # tzinfo so the stored string keeps its original naive format.
        trained_at = datetime.now(timezone.utc).replace(tzinfo=None).isoformat()
        return {
            "trained_at": trained_at,
            "sample_count": len(samples),
            "label_distribution": dict(distribution),
            "version": f"nb-{digest}",
        }

    def train(self, samples: list[dict]) -> dict:
        """Fit vectorizer and model on ``samples``, persist them, and return metadata.

        Samples whose normalized text or label is empty are skipped.

        Args:
            samples: Dicts with "text" and "label" keys.

        Returns:
            The metadata dict of the freshly trained model.

        Raises:
            ValueError: Fewer than 10 usable samples remain after cleaning.
        """
        clean_samples = []
        for row in samples:
            text = self.normalize_text(row.get("text"))
            label = self.normalize_label(row.get("label"))
            if text and label:
                clean_samples.append({"text": text, "label": label})
        if len(clean_samples) < 10:
            raise ValueError("训练样本太少至少需要10条有效样本")

        texts = [item["text"] for item in clean_samples]
        labels = [item["label"] for item in clean_samples]

        # Character uni/bi-grams need no word tokenizer.
        vectorizer = TfidfVectorizer(analyzer="char", ngram_range=(1, 2), min_df=1)
        features = vectorizer.fit_transform(texts)
        model = MultinomialNB(alpha=0.4)
        model.fit(features, labels)

        # The version is derived from the exact cleaned training set.
        version_seed = "||".join(
            f"{item['label']}::{item['text']}" for item in clean_samples
        )
        metadata = self._to_metadata(clean_samples, version_seed)

        self.model_path.parent.mkdir(parents=True, exist_ok=True)
        joblib.dump(
            {"vectorizer": vectorizer, "model": model, "metadata": metadata},
            self.model_path,
        )

        # Publish the new state only after the dump succeeded.
        self.vectorizer = vectorizer
        self.model = model
        self.metadata = metadata
        return metadata

    def load(self) -> bool:
        """Load a persisted model from ``model_path``.

        Returns:
            True when both vectorizer and model were restored. False when the
            file is missing or its payload is unusable; in that case the
            current in-memory state is left untouched.
        """
        if not self.model_path.exists():
            return False
        # NOTE(security): joblib.load unpickles the file, which can execute
        # arbitrary code. Only point model_path at files this application wrote.
        payload = joblib.load(self.model_path)
        if not isinstance(payload, dict):
            return False
        vectorizer = payload.get("vectorizer")
        model = payload.get("model")
        if vectorizer is None or model is None:
            return False
        self.vectorizer = vectorizer
        self.model = model
        # Guard against a stored None so later .get() calls keep working.
        self.metadata = payload.get("metadata") or {}
        return True

    def ensure_ready(self, samples: list[dict]) -> dict:
        """Return metadata of the persisted model, training on ``samples`` if none loads."""
        if self.load():
            return self.metadata
        return self.train(samples)

    def predict(self, text: str) -> dict:
        """Classify ``text`` as spam or ham.

        Args:
            text: Raw message; it is whitespace-normalized before scoring.

        Returns:
            Dict with the cleaned text, the predicted label and its display
            text, both class probabilities (rounded to 4 places), the
            confidence, up to five reason tokens, and the model version and
            training time taken from the metadata.

        Raises:
            RuntimeError: No model has been loaded or trained.
            ValueError: The cleaned text is shorter than 2 characters.
        """
        if self.vectorizer is None or self.model is None:
            raise RuntimeError("模型未加载,请先训练")
        cleaned = self.normalize_text(text)
        if len(cleaned) < 2:
            raise ValueError("待识别文本至少2个字符")

        x = self.vectorizer.transform([cleaned])
        probs = self.model.predict_proba(x)[0]
        classes = [str(cls) for cls in self.model.classes_]

        # A class the model never saw has probability 0. Falling back to
        # column 0 would make a ham-only model report everything as spam.
        prob_by_class = {cls: float(p) for cls, p in zip(classes, probs)}
        spam_prob = prob_by_class.get("spam", 0.0)
        ham_prob = prob_by_class.get("ham", 0.0)

        prediction = "spam" if spam_prob >= ham_prob else "ham"
        reason_tokens = self._extract_reason_tokens(cleaned, classes, x)
        confidence = max(spam_prob, ham_prob)
        return {
            "text": cleaned,
            "prediction": prediction,
            "prediction_text": "垃圾信息" if prediction == "spam" else "正常信息",
            "spam_probability": round(spam_prob, 4),
            "ham_probability": round(ham_prob, 4),
            "confidence": round(confidence, 4),
            "reason_tokens": reason_tokens,
            "model_version": self.metadata.get("version", ""),
            "trained_at": self.metadata.get("trained_at"),
        }

    def _extract_reason_tokens(self, text: str, classes: list[str], x_row) -> list[str]:
        """Return up to five n-grams of the input that best separate spam from ham.

        Tokens are ranked by the absolute difference between their spam and
        ham feature log-probabilities. If only one class is known there is
        nothing to contrast, so every score is 0 and column order is kept.
        On any failure the first five characters of ``text`` are returned.

        Args:
            text: Cleaned input text, used only for the fallback.
            classes: Class labels in the model's column order.
            x_row: Vectorized input; ``nonzero()[1]`` gives its feature columns.
        """
        try:
            feature_names = self.vectorizer.get_feature_names_out()
            log_prob = self.model.feature_log_prob_
            # De-duplicate column indices while keeping first-seen order.
            columns = list(dict.fromkeys(x_row.nonzero()[1]))

            if "spam" in classes and "ham" in classes:
                spam_row = log_prob[classes.index("spam")]
                ham_row = log_prob[classes.index("ham")]
                scored = [
                    (feature_names[col], spam_row[col] - ham_row[col])
                    for col in columns
                ]
            else:
                scored = [(feature_names[col], 0.0) for col in columns]

            scored.sort(key=lambda pair: abs(pair[1]), reverse=True)
            return [token for token, _ in scored[:5]]
        except Exception:
            # Reason tokens are a best-effort explanation; a failure here must
            # never break the prediction itself.
            return list(text[:5])

    def model_info(self) -> dict:
        """Return a summary of the current model state and its metadata."""
        return {
            "ready": self.vectorizer is not None and self.model is not None,
            "model_path": str(self.model_path),
            "version": self.metadata.get("version", ""),
            "trained_at": self.metadata.get("trained_at"),
            "sample_count": int(self.metadata.get("sample_count", 0) or 0),
            "label_distribution": self.metadata.get("label_distribution", {}),
        }