"""Spam-detection API routes: prediction, history, feedback, training and samples."""

from flask import Blueprint, current_app, request
from flask_jwt_extended import jwt_required

from app.extensions import db
from app.ml.naive_bayes_classifier import NaiveBayesSpamClassifier
from app.ml.spam_categorizer import categorize_spam, get_category_label
from app.models import DetectionConfig, SpamPredictionLog, SpamTrainingSample
from app.utils.auth import admin_required, current_user
from app.utils.response import fail, ok

spam_bp = Blueprint("spam", __name__)

# Threshold used when no DetectionConfig row (or no stored value) exists.
_DEFAULT_SPAM_THRESHOLD = 0.75
# Credit score assumed when the user record carries no score.
_DEFAULT_CREDIT_SCORE = 100
# Credit adjustment: each point away from the default shifts the factor by
# this step, and the factor is clamped to [_MIN, _MAX].
_CREDIT_FACTOR_STEP = 0.0015
_MIN_CREDIT_FACTOR = 0.85
_MAX_CREDIT_FACTOR = 1.15
# Input limits. The user-facing error messages below mention these numbers
# literally, so keep the two in sync when changing them.
_MIN_TEXT_LENGTH = 2
_MAX_BATCH_ITEMS = 100
_DEFAULT_PAGE_SIZE = 20
_MAX_PAGE_SIZE = 100


# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------


def _classifier() -> NaiveBayesSpamClassifier:
    """Build a classifier bound to the model path from the app config."""
    return NaiveBayesSpamClassifier(current_app.config["NB_MODEL_PATH"])


def _active_samples() -> list[dict]:
    """Return all active training samples as ``{"text", "label"}`` dicts, oldest first."""
    rows = (
        SpamTrainingSample.query.filter_by(is_active=True)
        .order_by(SpamTrainingSample.id.asc())
        .all()
    )
    return [{"text": row.text, "label": row.label} for row in rows]


def _ensure_ready() -> NaiveBayesSpamClassifier:
    """Return a classifier after calling ``ensure_ready`` with the active samples.

    NOTE(review): ``ensure_ready`` presumably loads or trains the model when
    needed; confirm against NaiveBayesSpamClassifier.
    """
    clf = _classifier()
    clf.ensure_ready(_active_samples())
    return clf


def _threshold() -> float:
    """Return the configured base spam threshold.

    Reads the first ``DetectionConfig`` row (lowest id). Falls back to the
    default when there is no row or the stored value is NULL, instead of
    raising ``TypeError`` on ``float(None)``.
    """
    row = DetectionConfig.query.order_by(DetectionConfig.id.asc()).first()
    if row is None or row.spam_threshold is None:
        return _DEFAULT_SPAM_THRESHOLD
    return float(row.spam_threshold)


def _adjusted_threshold(user_credit: int = 100) -> float:
    """Scale the base threshold by the user's credit score.

    The factor is ``1.0 + (user_credit - 100) * 0.0015`` clamped to the range
    0.85 - 1.15, so a lower credit score yields a lower (stricter) threshold.

    NOTE(review): the product can exceed 1.0 when the base threshold is high;
    confirm whether that is intended.
    """
    raw_factor = 1.0 + (user_credit - _DEFAULT_CREDIT_SCORE) * _CREDIT_FACTOR_STEP
    credit_factor = max(_MIN_CREDIT_FACTOR, min(_MAX_CREDIT_FACTOR, raw_factor))
    return _threshold() * credit_factor


def _user_credit(user) -> int:
    """Return the user's credit score, defaulting only when it is missing.

    A plain ``credit_score or 100`` would turn a legitimate score of 0 into
    100 and give the lowest-credit users the neutral threshold.
    """
    score = user.credit_score
    return _DEFAULT_CREDIT_SCORE if score is None else score


def _json_payload() -> dict:
    """Return the request's JSON body, or ``{}`` when it is absent or not an object."""
    payload = request.get_json(silent=True)
    return payload if isinstance(payload, dict) else {}


def _clean_text(value) -> str:
    """Return *value* stripped, or ``""`` when it is not a string."""
    return value.strip() if isinstance(value, str) else ""


def _pagination_args() -> tuple[int, int]:
    """Read ``page`` / ``page_size`` from the query string.

    Non-numeric values fall back to the defaults rather than raising
    ``ValueError``. ``page`` is at least 1; ``page_size`` is clamped to 1-100.
    """
    page = request.args.get("page", 1, type=int) or 1
    page_size = (
        request.args.get("page_size", _DEFAULT_PAGE_SIZE, type=int)
        or _DEFAULT_PAGE_SIZE
    )
    return max(page, 1), min(max(page_size, 1), _MAX_PAGE_SIZE)


def _page_response(pagination, page: int, page_size: int) -> dict:
    """Serialize a pagination object into the list-response payload."""
    return {
        "items": [item.to_dict() for item in pagination.items],
        "total": pagination.total,
        "page": page,
        "page_size": page_size,
    }


def _evaluate(clf: NaiveBayesSpamClassifier, text: str, threshold: float):
    """Classify *text* and apply the threshold.

    Returns ``(result, blocked, category, category_label)``. The category is
    only computed when the spam probability reaches the threshold; otherwise
    both category fields are empty strings.
    """
    result = clf.predict(text)
    blocked = float(result["spam_probability"]) >= threshold
    category, category_label = "", ""
    if blocked:
        # Categorize the text echoed back in the prediction result.
        category, category_label = categorize_spam(result["text"])
    return result, blocked, category, category_label


def _log_row(user_id, result: dict, category: str) -> SpamPredictionLog:
    """Build (but do not persist) a prediction log row from a result dict."""
    return SpamPredictionLog(
        user_id=user_id,
        text=result["text"],
        prediction=result["prediction"],
        category=category,
        spam_probability=result["spam_probability"],
        ham_probability=result["ham_probability"],
        confidence=result["confidence"],
        reason_tokens=result["reason_tokens"],
        model_version=result.get("model_version", ""),
    )


# ---------------------------------------------------------------------------
# Prediction
# ---------------------------------------------------------------------------


@spam_bp.post("/predict")
@jwt_required()
def predict_one():
    """Classify a single text, log the prediction and return the result."""
    user = current_user()
    if not user:
        return fail("用户不存在", 404)

    text = _clean_text(_json_payload().get("text"))
    if len(text) < _MIN_TEXT_LENGTH:
        return fail("请输入至少2个字符的待识别文本", 400)

    clf = _ensure_ready()
    threshold = _adjusted_threshold(_user_credit(user))
    result, blocked, category, category_label = _evaluate(clf, text, threshold)

    row = _log_row(user.id, result, category)
    db.session.add(row)
    db.session.commit()

    return ok(
        {
            **result,
            "log_id": row.id,
            "threshold": threshold,
            "blocked_by_threshold": blocked,
            "category": category,
            "category_label": category_label,
        },
        "识别成功",
    )


@spam_bp.post("/predict/batch")
@jwt_required()
def predict_batch():
    """Classify up to 100 texts at once and return per-item results plus a summary.

    Items that are not strings or are shorter than 2 characters after
    stripping are skipped silently.
    """
    user = current_user()
    if not user:
        return fail("用户不存在", 404)

    items = _json_payload().get("items") or []
    if not isinstance(items, list) or not items:
        return fail("items 必须是非空数组", 400)
    if len(items) > _MAX_BATCH_ITEMS:
        return fail("单次最多识别100条", 400)

    clf = _ensure_ready()
    threshold = _adjusted_threshold(_user_credit(user))

    rows = []
    results = []
    for raw in items:
        content = _clean_text(raw)
        if len(content) < _MIN_TEXT_LENGTH:
            continue
        result, blocked, category, category_label = _evaluate(clf, content, threshold)
        result["blocked_by_threshold"] = blocked
        result["category"] = category
        result["category_label"] = category_label
        rows.append(_log_row(user.id, result, category))
        results.append(result)

    if not rows:
        return fail("没有可识别的有效文本", 400)

    db.session.add_all(rows)
    db.session.commit()

    total = len(results)
    spam_count = sum(1 for item in results if item["prediction"] == "spam")
    blocked_count = sum(1 for item in results if item["blocked_by_threshold"])
    return ok(
        {
            "items": results,
            "summary": {
                "total": total,
                "spam_count": spam_count,
                "ham_count": total - spam_count,
                "blocked_count": blocked_count,
                "spam_ratio": round(spam_count / total, 4),
                "blocked_ratio": round(blocked_count / total, 4),
                "threshold": threshold,
            },
        },
        "批量识别完成",
    )


# ---------------------------------------------------------------------------
# History and feedback
# ---------------------------------------------------------------------------


@spam_bp.get("/history")
@jwt_required()
def my_history():
    """Return the current user's prediction logs, newest first, paginated."""
    user = current_user()
    if not user:
        return fail("用户不存在", 404)

    page, page_size = _pagination_args()
    pagination = (
        SpamPredictionLog.query.filter_by(user_id=user.id)
        .order_by(SpamPredictionLog.id.desc())
        .paginate(page=page, per_page=page_size, error_out=False)
    )
    return ok(_page_response(pagination, page, page_size))


# The rule must declare the ``log_id`` URL variable; without it Flask calls
# the view with no arguments and it fails with TypeError.
@spam_bp.delete("/history/<int:log_id>")
@jwt_required()
def delete_history(log_id: int):
    """Delete one of the current user's prediction logs."""
    user = current_user()
    if not user:
        return fail("用户不存在", 404)

    # Filtering by user_id as well keeps users from deleting others' logs.
    row = SpamPredictionLog.query.filter_by(id=log_id, user_id=user.id).first()
    if not row:
        return fail("记录不存在", 404)

    db.session.delete(row)
    db.session.commit()
    return ok({}, "记录已删除")


@spam_bp.post("/feedback")
@jwt_required()
def save_feedback():
    """Store a user-labelled text as an active training sample (source "feedback")."""
    user = current_user()
    if not user:
        return fail("用户不存在", 404)

    payload = _json_payload()
    text = _clean_text(payload.get("text"))
    label = NaiveBayesSpamClassifier.normalize_label(payload.get("label"))
    if len(text) < _MIN_TEXT_LENGTH:
        return fail("文本至少2个字符", 400)
    if not label:
        return fail("label 必须是 spam 或 ham", 400)

    row = SpamTrainingSample(
        text=text,
        label=label,
        source="feedback",
        created_by=user.id,
        is_active=True,
    )
    db.session.add(row)
    db.session.commit()
    return ok(row.to_dict(), "反馈样本已记录")


# ---------------------------------------------------------------------------
# Model
# ---------------------------------------------------------------------------


@spam_bp.get("/model/info")
@jwt_required(optional=True)
def model_info():
    """Return the classifier's model info together with the base threshold."""
    clf = _classifier()
    clf.load()
    info = clf.model_info()
    info["threshold"] = _threshold()
    return ok(info)


@spam_bp.post("/train")
@admin_required
def train_model():
    """Train the classifier on all active samples and return the training metadata."""
    clf = _classifier()
    metadata = clf.train(_active_samples())
    return ok(metadata, "模型训练完成")


# ---------------------------------------------------------------------------
# Training samples (admin)
# ---------------------------------------------------------------------------


@spam_bp.get("/samples")
@admin_required
def list_samples():
    """List training samples, optionally filtered by keyword and label, paginated."""
    keyword = (request.args.get("keyword") or "").strip()
    label = (request.args.get("label") or "").strip().lower()
    page, page_size = _pagination_args()

    query = SpamTrainingSample.query
    if keyword:
        query = query.filter(SpamTrainingSample.text.like(f"%{keyword}%"))
    if label in {"spam", "ham"}:
        query = query.filter(SpamTrainingSample.label == label)

    pagination = query.order_by(SpamTrainingSample.id.desc()).paginate(
        page=page, per_page=page_size, error_out=False
    )
    return ok(_page_response(pagination, page, page_size))


@spam_bp.post("/samples")
@admin_required
def create_sample():
    """Create a single active training sample (source "import")."""
    payload = _json_payload()
    text = _clean_text(payload.get("text"))
    label = NaiveBayesSpamClassifier.normalize_label(payload.get("label"))
    if len(text) < _MIN_TEXT_LENGTH:
        return fail("文本至少2个字符", 400)
    if not label:
        return fail("label 必须是 spam 或 ham", 400)

    user = current_user()
    row = SpamTrainingSample(
        text=text,
        label=label,
        source="import",
        created_by=user.id if user else None,
        is_active=True,
    )
    db.session.add(row)
    db.session.commit()
    return ok(row.to_dict(), "样本创建成功")


# ``sample_id`` must be declared in the rule so Flask passes it to the view.
@spam_bp.put("/samples/<int:sample_id>")
@admin_required
def update_sample(sample_id: int):
    """Update the text, label and/or active flag of a training sample.

    Only the keys present in the JSON body are changed.
    """
    row = SpamTrainingSample.query.get(sample_id)
    if not row:
        return fail("样本不存在", 404)

    payload = _json_payload()
    if "text" in payload:
        text = _clean_text(payload.get("text"))
        if len(text) < _MIN_TEXT_LENGTH:
            return fail("文本至少2个字符", 400)
        row.text = text
    if "label" in payload:
        label = NaiveBayesSpamClassifier.normalize_label(payload.get("label"))
        if not label:
            return fail("label 必须是 spam 或 ham", 400)
        row.label = label
    if "is_active" in payload:
        row.is_active = bool(payload.get("is_active"))

    db.session.commit()
    return ok(row.to_dict(), "样本更新成功")


# ``sample_id`` must be declared in the rule so Flask passes it to the view.
@spam_bp.delete("/samples/<int:sample_id>")
@admin_required
def delete_sample(sample_id: int):
    """Delete a training sample by id."""
    row = SpamTrainingSample.query.get(sample_id)
    if not row:
        return fail("样本不存在", 404)

    db.session.delete(row)
    db.session.commit()
    return ok({}, "样本已删除")


@spam_bp.post("/samples/import")
@admin_required
def import_samples():
    """Bulk upsert training samples keyed by their text.

    Each item is expected to be an object with ``text`` and ``label`` and
    optional ``is_active`` / ``source``. Items that are not objects, have a
    text shorter than 2 characters, or have an invalid label are skipped.
    Returns the number of created and updated samples.
    """
    items = _json_payload().get("items") or []
    if not isinstance(items, list) or not items:
        return fail("items 必须是非空数组", 400)

    user = current_user()
    creator_id = user.id if user else None
    created = 0
    updated = 0
    for item in items:
        if not isinstance(item, dict):
            continue
        text = _clean_text(item.get("text"))
        label = NaiveBayesSpamClassifier.normalize_label(item.get("label"))
        if len(text) < _MIN_TEXT_LENGTH or not label:
            continue

        is_active = bool(item.get("is_active", True))
        row = SpamTrainingSample.query.filter_by(text=text).first()
        if row:
            row.label = label
            row.is_active = is_active
            row.source = item.get("source") or row.source
            updated += 1
        else:
            db.session.add(
                SpamTrainingSample(
                    text=text,
                    label=label,
                    source=item.get("source") or "import",
                    created_by=creator_id,
                    is_active=is_active,
                )
            )
            created += 1

    db.session.commit()
    return ok({"created": created, "updated": updated}, "样本导入完成")