Files
c/backend/app/routes/spam_routes.py
刘正航 b5237f9038 1
2026-04-21 22:45:19 +08:00

335 lines
10 KiB
Python

from flask import Blueprint, current_app, request
from flask_jwt_extended import jwt_required
from app.extensions import db
from app.ml.naive_bayes_classifier import NaiveBayesSpamClassifier
from app.models import DetectionConfig, SpamPredictionLog, SpamTrainingSample
from app.utils.auth import admin_required, current_user
from app.utils.response import fail, ok
# Blueprint that owns every spam-detection route in this module. It is
# registered on the application elsewhere (not visible in this file).
spam_bp = Blueprint(name="spam", import_name=__name__)
def _classifier() -> NaiveBayesSpamClassifier:
    """Build a classifier bound to the app's configured model path.

    The path is read from ``current_app.config["NB_MODEL_PATH"]``, so this
    must run inside an application context. A fresh instance is created on
    every call.
    """
    model_path = current_app.config["NB_MODEL_PATH"]
    return NaiveBayesSpamClassifier(model_path)
def _active_samples() -> list[dict]:
    """Return every active training sample as a ``{"text", "label"}`` dict.

    Rows are ordered by ascending primary key, so the sequence handed to the
    classifier is deterministic for a given table state.
    """
    active_query = SpamTrainingSample.query.filter_by(is_active=True)
    ordered_query = active_query.order_by(SpamTrainingSample.id.asc())
    samples: list[dict] = []
    for sample in ordered_query.all():
        samples.append({"text": sample.text, "label": sample.label})
    return samples
def _ensure_ready() -> NaiveBayesSpamClassifier:
    """Return a classifier after passing it the current active samples.

    What ``ensure_ready`` does with the samples (e.g. load an existing model or
    train a new one) is decided by the classifier. TODO: confirm against
    ``NaiveBayesSpamClassifier``.
    """
    classifier = _classifier()
    classifier.ensure_ready(_active_samples())
    return classifier
def _threshold() -> float:
    """Return the spam-probability threshold at or above which a text is blocked.

    Reads the first ``DetectionConfig`` row (lowest id). Falls back to 0.75
    when no config row exists or when its ``spam_threshold`` is NULL. Without
    the NULL check, ``float(None)`` would raise ``TypeError`` on every
    prediction request.
    """
    default_threshold = 0.75
    row = DetectionConfig.query.order_by(DetectionConfig.id.asc()).first()
    if row is None or row.spam_threshold is None:
        return default_threshold
    return float(row.spam_threshold)
@spam_bp.post("/predict")
@jwt_required()
def predict_one():
    """Classify one text for the current user and persist a prediction log.

    Expects a JSON object body ``{"text": "..."}``. A body that is not a JSON
    object, or a ``text`` value that is not a string, is treated like missing
    text and rejected with 400. Previously both raised ``AttributeError`` and
    surfaced as a 500.

    Returns the classifier result merged with ``log_id``, ``threshold`` and
    ``blocked_by_threshold`` (True when spam probability >= threshold).
    """
    user = current_user()
    if not user:
        return fail("用户不存在", 404)
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict):
        # e.g. a JSON list or scalar body: there is no "text" key to read.
        payload = {}
    raw_text = payload.get("text")
    text = raw_text.strip() if isinstance(raw_text, str) else ""
    if len(text) < 2:
        return fail("请输入至少2个字符的待识别文本", 400)
    clf = _ensure_ready()
    result = clf.predict(text)
    threshold = _threshold()
    blocked = float(result["spam_probability"]) >= threshold
    row = SpamPredictionLog(
        user_id=user.id,
        text=result["text"],
        prediction=result["prediction"],
        spam_probability=result["spam_probability"],
        ham_probability=result["ham_probability"],
        confidence=result["confidence"],
        reason_tokens=result["reason_tokens"],
        model_version=result.get("model_version", ""),
    )
    db.session.add(row)
    db.session.commit()
    return ok({**result, "log_id": row.id, "threshold": threshold, "blocked_by_threshold": blocked}, "识别成功")
@spam_bp.post("/predict/batch")
@jwt_required()
def predict_batch():
    """Classify up to 100 texts in one request and persist a log row for each.

    Expects a JSON object body ``{"items": ["text", ...]}``. Entries are
    skipped rather than crashing the request when they are:

    * not strings (``None``, numbers, objects), or
    * shorter than 2 characters after stripping.

    If nothing usable remains, the request is rejected with 400.

    Returns the per-item classifier results, each extended with
    ``blocked_by_threshold``, plus a summary with spam, ham and blocked counts
    and ratios.
    """
    user = current_user()
    if not user:
        return fail("用户不存在", 404)
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict):
        payload = {}
    items = payload.get("items") or []
    if not isinstance(items, list) or not items:
        return fail("items 必须是非空数组", 400)
    if len(items) > 100:
        return fail("单次最多识别100条", 400)
    clf = _ensure_ready()
    threshold = _threshold()
    rows = []
    results = []
    for item in items:
        # Non-string entries used to raise AttributeError on .strip().
        if not isinstance(item, str):
            continue
        content = item.strip()
        if len(content) < 2:
            continue
        result = clf.predict(content)
        result["blocked_by_threshold"] = float(result["spam_probability"]) >= threshold
        rows.append(
            SpamPredictionLog(
                user_id=user.id,
                text=result["text"],
                prediction=result["prediction"],
                spam_probability=result["spam_probability"],
                ham_probability=result["ham_probability"],
                confidence=result["confidence"],
                reason_tokens=result["reason_tokens"],
                model_version=result.get("model_version", ""),
            )
        )
        results.append(result)
    if not rows:
        return fail("没有可识别的有效文本", 400)
    db.session.add_all(rows)
    db.session.commit()
    # rows and results grow together, so total > 0 after the early return above.
    total = len(results)
    spam_count = sum(1 for item in results if item["prediction"] == "spam")
    blocked_count = sum(1 for item in results if item["blocked_by_threshold"])
    ham_count = total - spam_count
    return ok(
        {
            "items": results,
            "summary": {
                "total": total,
                "spam_count": spam_count,
                "ham_count": ham_count,
                "blocked_count": blocked_count,
                "spam_ratio": round(spam_count / total, 4),
                "blocked_ratio": round(blocked_count / total, 4),
                "threshold": threshold,
            },
        },
        "批量识别完成",
    )
@spam_bp.get("/history")
@jwt_required()
def my_history():
    """Return the current user's prediction logs, newest first, paginated.

    Query params:

    * ``page``: defaults to 1, minimum 1.
    * ``page_size``: defaults to 20, clamped to 1..100.

    Non-integer values fall back to the defaults instead of raising
    ``ValueError`` (a 500), via the ``type=`` conversion of ``request.args.get``.
    """
    user = current_user()
    if not user:
        return fail("用户不存在", 404)
    page = max(request.args.get("page", 1, type=int), 1)
    page_size = min(max(request.args.get("page_size", 20, type=int), 1), 100)
    pagination = (
        SpamPredictionLog.query.filter_by(user_id=user.id)
        .order_by(SpamPredictionLog.id.desc())
        .paginate(page=page, per_page=page_size, error_out=False)
    )
    return ok(
        {
            "items": [item.to_dict() for item in pagination.items],
            "total": pagination.total,
            "page": page,
            "page_size": page_size,
        }
    )
@spam_bp.delete("/history/<int:log_id>")
@jwt_required()
def delete_history(log_id: int):
    """Delete one prediction log entry belonging to the current user.

    The lookup is scoped to the caller's ``user_id``, so an id owned by another
    user is reported as not found (404).
    """
    owner = current_user()
    if not owner:
        return fail("用户不存在", 404)
    log_entry = (
        SpamPredictionLog.query
        .filter_by(id=log_id, user_id=owner.id)
        .first()
    )
    if not log_entry:
        return fail("记录不存在", 404)
    session = db.session
    session.delete(log_entry)
    session.commit()
    return ok({}, "记录已删除")
@spam_bp.post("/feedback")
@jwt_required()
def save_feedback():
    """Store a user-labelled text as an active ``"feedback"`` training sample.

    Expects a JSON object ``{"text": "...", "label": ...}``.

    * ``label`` is passed through ``NaiveBayesSpamClassifier.normalize_label``
      and must come back truthy (the error message says spam or ham).
    * A non-object body, or a non-string ``text``, is rejected with 400 rather
      than raising ``AttributeError``.

    The row is created with ``is_active=True``, so ``_active_samples`` includes
    it in later training.
    """
    user = current_user()
    if not user:
        return fail("用户不存在", 404)
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict):
        payload = {}
    raw_text = payload.get("text")
    text = raw_text.strip() if isinstance(raw_text, str) else ""
    label = NaiveBayesSpamClassifier.normalize_label(payload.get("label"))
    if len(text) < 2:
        return fail("文本至少2个字符", 400)
    if not label:
        return fail("label 必须是 spam 或 ham", 400)
    row = SpamTrainingSample(text=text, label=label, source="feedback", created_by=user.id, is_active=True)
    db.session.add(row)
    db.session.commit()
    return ok(row.to_dict(), "反馈样本已记录")
@spam_bp.get("/model/info")
@jwt_required(optional=True)
def model_info():
    """Report the classifier's model info together with the active threshold.

    Authentication is optional. ``load()`` is called on a fresh classifier
    before ``model_info()`` is read. The current threshold is then written into
    that dict under ``"threshold"``.
    """
    classifier = _classifier()
    classifier.load()
    details = classifier.model_info()
    details["threshold"] = _threshold()
    return ok(details)
@spam_bp.post("/train")
@admin_required
def train_model():
    """Retrain the classifier on every active training sample (admins only).

    Responds with whatever metadata ``NaiveBayesSpamClassifier.train`` returns.
    """
    classifier = _classifier()
    training_set = _active_samples()
    training_metadata = classifier.train(training_set)
    return ok(training_metadata, "模型训练完成")
@spam_bp.get("/samples")
@admin_required
def list_samples():
    """List training samples for admins, newest first, with optional filters.

    Query params:

    * ``keyword``: substring match on sample text. ``%`` and ``_`` are matched
      literally rather than acting as LIKE wildcards.
    * ``label``: ``spam`` or ``ham`` (case-insensitive); any other value is
      ignored.
    * ``page`` / ``page_size``: same defaults and clamping as ``/history``.
      Non-integer values fall back to the defaults instead of causing a 500.
    """
    keyword = (request.args.get("keyword") or "").strip()
    label = (request.args.get("label") or "").strip().lower()
    page = max(request.args.get("page", 1, type=int), 1)
    page_size = min(max(request.args.get("page_size", 20, type=int), 1), 100)
    query = SpamTrainingSample.query
    if keyword:
        # autoescape=True escapes LIKE metacharacters in the user's keyword.
        query = query.filter(SpamTrainingSample.text.contains(keyword, autoescape=True))
    if label in {"spam", "ham"}:
        query = query.filter(SpamTrainingSample.label == label)
    pagination = query.order_by(SpamTrainingSample.id.desc()).paginate(page=page, per_page=page_size, error_out=False)
    return ok(
        {
            "items": [item.to_dict() for item in pagination.items],
            "total": pagination.total,
            "page": page,
            "page_size": page_size,
        }
    )
@spam_bp.post("/samples")
@admin_required
def create_sample():
    """Create one active training sample with source ``"import"`` (admins only).

    Expects a JSON object ``{"text": "...", "label": ...}``. A non-object body,
    or a non-string ``text``, is rejected with 400 instead of raising
    ``AttributeError``. ``created_by`` is the current user's id when one
    resolves, otherwise None.
    """
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict):
        payload = {}
    raw_text = payload.get("text")
    text = raw_text.strip() if isinstance(raw_text, str) else ""
    label = NaiveBayesSpamClassifier.normalize_label(payload.get("label"))
    if len(text) < 2:
        return fail("文本至少2个字符", 400)
    if not label:
        return fail("label 必须是 spam 或 ham", 400)
    user = current_user()
    row = SpamTrainingSample(text=text, label=label, source="import", created_by=user.id if user else None, is_active=True)
    db.session.add(row)
    db.session.commit()
    return ok(row.to_dict(), "样本创建成功")
@spam_bp.put("/samples/<int:sample_id>")
@admin_required
def update_sample(sample_id: int):
    """Partially update a training sample (admins only).

    Any subset of ``text``, ``label`` and ``is_active`` may be supplied.

    * Every supplied field is validated before the row is modified, so a
      rejected request no longer leaves a half-applied change in the session.
    * A non-object body is treated as empty.
    * A non-string ``text`` fails the length check.
    """
    # Session.get replaces the legacy Query.get (deprecated in SQLAlchemy 2.0).
    row = db.session.get(SpamTrainingSample, sample_id)
    if not row:
        return fail("样本不存在", 404)
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict):
        payload = {}
    changes = {}
    if "text" in payload:
        raw_text = payload.get("text")
        text = raw_text.strip() if isinstance(raw_text, str) else ""
        if len(text) < 2:
            return fail("文本至少2个字符", 400)
        changes["text"] = text
    if "label" in payload:
        label = NaiveBayesSpamClassifier.normalize_label(payload.get("label"))
        if not label:
            return fail("label 必须是 spam 或 ham", 400)
        changes["label"] = label
    if "is_active" in payload:
        # NOTE(review): bool("false") is True; string flags are not parsed.
        changes["is_active"] = bool(payload.get("is_active"))
    # Apply only after all validation passed.
    for field_name, value in changes.items():
        setattr(row, field_name, value)
    db.session.commit()
    return ok(row.to_dict(), "样本更新成功")
@spam_bp.delete("/samples/<int:sample_id>")
@admin_required
def delete_sample(sample_id: int):
    """Permanently delete a training sample by id (admins only); 404 if missing."""
    # Session.get replaces the legacy Query.get (deprecated in SQLAlchemy 2.0).
    row = db.session.get(SpamTrainingSample, sample_id)
    if not row:
        return fail("样本不存在", 404)
    db.session.delete(row)
    db.session.commit()
    return ok({}, "样本已删除")
@spam_bp.post("/samples/import")
@admin_required
def import_samples():
    """Bulk upsert training samples, matched on exact stripped text (admins only).

    Expects ``{"items": [{"text", "label", "is_active"?, "source"?}, ...]}``.
    Items are skipped when they:

    * are not JSON objects,
    * have a non-string ``text``,
    * have text shorter than 2 characters, or
    * have a label that does not normalise.

    Previously the first two raised ``AttributeError``.

    Items whose text matches an existing row update its label, active flag and
    (if given) source. All other items insert a new row. Repeated texts within
    one batch update the same row.

    Returns the number of created and updated rows.
    """
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict):
        payload = {}
    items = payload.get("items") or []
    if not isinstance(items, list) or not items:
        return fail("items 必须是非空数组", 400)
    user = current_user()
    created_by = user.id if user else None
    created = 0
    updated = 0
    # Rows seen in this request, keyed by text. Duplicate texts in one batch
    # then hit the same row without relying on session autoflush, and without
    # an extra query.
    touched: dict[str, SpamTrainingSample] = {}
    for item in items:
        if not isinstance(item, dict):
            continue
        raw_text = item.get("text")
        text = raw_text.strip() if isinstance(raw_text, str) else ""
        label = NaiveBayesSpamClassifier.normalize_label(item.get("label"))
        if len(text) < 2 or not label:
            continue
        is_active = bool(item.get("is_active", True))
        row = touched.get(text)
        if row is None:
            row = SpamTrainingSample.query.filter_by(text=text).first()
        if row:
            row.label = label
            row.is_active = is_active
            row.source = item.get("source") or row.source
            updated += 1
        else:
            row = SpamTrainingSample(
                text=text,
                label=label,
                source=item.get("source") or "import",
                created_by=created_by,
                is_active=is_active,
            )
            db.session.add(row)
            created += 1
        touched[text] = row
    db.session.commit()
    return ok({"created": created, "updated": updated}, "样本导入完成")