Files
c/backend/app/routes/spam_routes.py
刘正航 6d62120443 feat: 用户行为信誉分系统
- User 新增 credit_score 字段(0-200,默认100)
- 信誉分影响检测阈值系数:高分降低敏感度,低分提高敏感度
- 发布成功+1分,被拦截-2分;申诉通过+10分,驳回-5分
- 新增手动调整和批量重算信誉分接口
- admin-users 页面显示信誉分进度条,支持编辑调整

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-04-21 23:52:47 +08:00

344 lines
10 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from flask import Blueprint, current_app, request
from flask_jwt_extended import jwt_required
from app.extensions import db
from app.ml.naive_bayes_classifier import NaiveBayesSpamClassifier
from app.models import DetectionConfig, SpamPredictionLog, SpamTrainingSample
from app.utils.auth import admin_required, current_user
from app.utils.response import fail, ok
# Blueprint on which every spam-detection route in this module is registered.
spam_bp = Blueprint("spam", __name__)
def _classifier() -> NaiveBayesSpamClassifier:
    """Build a classifier for the model path stored in the app config key ``NB_MODEL_PATH``."""
    model_path = current_app.config["NB_MODEL_PATH"]
    return NaiveBayesSpamClassifier(model_path)
def _active_samples() -> list[dict]:
    """Return all active training samples as ``{"text", "label"}`` dicts, ordered by ascending id."""
    active_query = SpamTrainingSample.query.filter_by(is_active=True)
    ordered_query = active_query.order_by(SpamTrainingSample.id.asc())
    samples = []
    for sample in ordered_query.all():
        samples.append({"text": sample.text, "label": sample.label})
    return samples
def _ensure_ready() -> NaiveBayesSpamClassifier:
    """Return a classifier after passing the active samples to its ``ensure_ready`` method."""
    classifier = _classifier()
    classifier.ensure_ready(_active_samples())
    return classifier
def _threshold() -> float:
    """Return the base spam threshold from the lowest-id DetectionConfig row, or 0.75 if none exists."""
    config = DetectionConfig.query.order_by(DetectionConfig.id.asc()).first()
    if not config:
        return 0.75
    return float(config.spam_threshold)
def _adjusted_threshold(user_credit: int = 100) -> float:
    """Return the base threshold scaled by a factor derived from the user's credit score.

    The factor is ``1.0 + (user_credit - 100) * 0.0015``, clamped to the
    range [0.85, 1.15]. A score of 100 leaves the base threshold unchanged;
    higher scores raise it and lower scores reduce it.

    NOTE(review): the product is not clamped, so a base threshold above
    roughly 0.87 combined with the maximum factor yields a value above 1.0.
    Confirm that this is intended.
    """
    base = _threshold()
    factor = 1.0 + (user_credit - 100) * 0.0015
    # Keep the multiplier inside 0.85 - 1.15.
    if factor < 0.85:
        factor = 0.85
    elif factor > 1.15:
        factor = 1.15
    return base * factor
@spam_bp.post("/predict")
@jwt_required()
def predict_one():
    """Classify a single text for the authenticated user and log the prediction.

    Reads a JSON object with a ``text`` field. Responds with 404 when the
    current user cannot be resolved and with 400 when the stripped text is
    shorter than two characters. On success the classifier result is stored
    as a ``SpamPredictionLog`` row and returned together with ``log_id``,
    the credit-adjusted ``threshold`` and ``blocked_by_threshold``.
    """
    user = current_user()
    if not user:
        return fail("用户不存在", 404)

    payload = request.get_json(silent=True)
    if not isinstance(payload, dict):
        # A missing, malformed or non-object JSON body is treated as empty
        # so it gets the regular 400 below instead of crashing on ``.get``.
        payload = {}
    raw_text = payload.get("text")
    # Non-string values (numbers, lists, ...) become "" and are rejected by
    # the length check rather than raising on ``.strip()``.
    text = raw_text.strip() if isinstance(raw_text, str) else ""
    if len(text) < 2:
        return fail("请输入至少2个字符的待识别文本", 400)

    clf = _ensure_ready()
    result = clf.predict(text)

    # Only a missing score falls back to the neutral 100. A score of 0 is a
    # valid value and must not be replaced, which ``score or 100`` would do.
    credit = 100 if user.credit_score is None else user.credit_score
    threshold = _adjusted_threshold(credit)
    blocked = float(result["spam_probability"]) >= threshold

    row = SpamPredictionLog(
        user_id=user.id,
        text=result["text"],
        prediction=result["prediction"],
        spam_probability=result["spam_probability"],
        ham_probability=result["ham_probability"],
        confidence=result["confidence"],
        reason_tokens=result["reason_tokens"],
        model_version=result.get("model_version", ""),
    )
    db.session.add(row)
    db.session.commit()
    return ok(
        {**result, "log_id": row.id, "threshold": threshold, "blocked_by_threshold": blocked},
        "识别成功",
    )
@spam_bp.post("/predict/batch")
@jwt_required()
def predict_batch():
    """Classify up to 100 texts for the authenticated user and log each prediction.

    Reads a JSON object with an ``items`` array. Entries that are not
    strings, or that are shorter than two characters after stripping, are
    skipped. Responds with 404 when the current user cannot be resolved and
    with 400 when ``items`` is missing, empty, not a list, longer than 100
    entries, or contains no usable text. On success returns the per-item
    results and a summary with counts, ratios and the threshold used.
    """
    user = current_user()
    if not user:
        return fail("用户不存在", 404)

    payload = request.get_json(silent=True)
    if not isinstance(payload, dict):
        # A non-object body is treated as empty so it yields the 400 below.
        payload = {}
    items = payload.get("items") or []
    if not isinstance(items, list) or not items:
        return fail("items 必须是非空数组", 400)
    if len(items) > 100:
        return fail("单次最多识别100条", 400)

    clf = _ensure_ready()
    # Only a missing score falls back to the neutral 100; 0 is a valid score.
    credit = 100 if user.credit_score is None else user.credit_score
    threshold = _adjusted_threshold(credit)

    rows = []
    results = []
    for item in items:
        # Skip non-string entries instead of raising on ``.strip()``.
        content = item.strip() if isinstance(item, str) else ""
        if len(content) < 2:
            continue
        result = clf.predict(content)
        result["blocked_by_threshold"] = float(result["spam_probability"]) >= threshold
        rows.append(
            SpamPredictionLog(
                user_id=user.id,
                text=result["text"],
                prediction=result["prediction"],
                spam_probability=result["spam_probability"],
                ham_probability=result["ham_probability"],
                confidence=result["confidence"],
                reason_tokens=result["reason_tokens"],
                model_version=result.get("model_version", ""),
            )
        )
        results.append(result)

    if not rows:
        return fail("没有可识别的有效文本", 400)

    db.session.add_all(rows)
    db.session.commit()

    # ``results`` is non-empty here because it grows in lockstep with ``rows``.
    total = len(results)
    spam_count = sum(1 for entry in results if entry["prediction"] == "spam")
    blocked_count = sum(1 for entry in results if entry["blocked_by_threshold"])
    return ok(
        {
            "items": results,
            "summary": {
                "total": total,
                "spam_count": spam_count,
                "ham_count": total - spam_count,
                "blocked_count": blocked_count,
                "spam_ratio": round(spam_count / total, 4),
                "blocked_ratio": round(blocked_count / total, 4),
                "threshold": threshold,
            },
        },
        "批量识别完成",
    )
@spam_bp.get("/history")
@jwt_required()
def my_history():
    """Return one page of the authenticated user's prediction logs, newest first.

    Query parameters: ``page`` (minimum 1, default 1) and ``page_size``
    (clamped to 1-100, default 20). Values that are missing, empty or not
    integers fall back to the defaults. Responds with 404 when the current
    user cannot be resolved.
    """
    user = current_user()
    if not user:
        return fail("用户不存在", 404)

    # ``type=int`` makes ``get`` return the default when conversion fails,
    # so a value such as ``?page=abc`` no longer raises ValueError.
    page = max(request.args.get("page", 1, type=int) or 1, 1)
    page_size = min(max(request.args.get("page_size", 20, type=int) or 20, 1), 100)

    pagination = (
        SpamPredictionLog.query.filter_by(user_id=user.id)
        .order_by(SpamPredictionLog.id.desc())
        .paginate(page=page, per_page=page_size, error_out=False)
    )
    return ok(
        {
            "items": [item.to_dict() for item in pagination.items],
            "total": pagination.total,
            "page": page,
            "page_size": page_size,
        }
    )
@spam_bp.delete("/history/<int:log_id>")
@jwt_required()
def delete_history(log_id: int):
    """Delete one prediction log that belongs to the authenticated user.

    Responds with 404 when the current user cannot be resolved or when no
    log with ``log_id`` exists for that user.
    """
    owner = current_user()
    if not owner:
        return fail("用户不存在", 404)

    # Filtering on user_id as well means another user's log is never matched.
    log_entry = SpamPredictionLog.query.filter_by(id=log_id, user_id=owner.id).first()
    if log_entry:
        db.session.delete(log_entry)
        db.session.commit()
        return ok({}, "记录已删除")
    return fail("记录不存在", 404)
@spam_bp.post("/feedback")
@jwt_required()
def save_feedback():
    """Store a user-labelled text as an active training sample with source ``feedback``.

    Reads a JSON object with ``text`` and ``label``. Responds with 404 when
    the current user cannot be resolved, and with 400 when the stripped
    text is shorter than two characters or ``normalize_label`` returns a
    falsy value for the label.
    """
    user = current_user()
    if not user:
        return fail("用户不存在", 404)

    payload = request.get_json(silent=True)
    if not isinstance(payload, dict):
        # A non-object body is treated as empty so it yields a 400 below.
        payload = {}
    raw_text = payload.get("text")
    # A non-string ``text`` becomes "" and fails validation instead of
    # raising on ``.strip()``.
    text = raw_text.strip() if isinstance(raw_text, str) else ""
    label = NaiveBayesSpamClassifier.normalize_label(payload.get("label"))
    if len(text) < 2:
        return fail("文本至少2个字符", 400)
    if not label:
        return fail("label 必须是 spam 或 ham", 400)

    row = SpamTrainingSample(
        text=text,
        label=label,
        source="feedback",
        created_by=user.id,
        is_active=True,
    )
    db.session.add(row)
    db.session.commit()
    return ok(row.to_dict(), "反馈样本已记录")
@spam_bp.get("/model/info")
@jwt_required(optional=True)
def model_info():
    """Return the classifier's ``model_info()`` with the base threshold added under ``threshold``.

    The JWT is optional for this route. The reported threshold is the base
    value from ``_threshold()``, not a credit-adjusted one.
    """
    classifier = _classifier()
    classifier.load()
    details = classifier.model_info()
    details["threshold"] = _threshold()
    return ok(details)
@spam_bp.post("/train")
@admin_required
def train_model():
    """Train the classifier on all active samples and return what ``train`` reports."""
    training_metadata = _classifier().train(_active_samples())
    return ok(training_metadata, "模型训练完成")
@spam_bp.get("/samples")
@admin_required
def list_samples():
    """Return one page of training samples, newest first, with optional filters.

    Query parameters: ``keyword`` (substring match on the sample text),
    ``label`` (applied only when it is ``spam`` or ``ham``), ``page``
    (minimum 1, default 1) and ``page_size`` (clamped to 1-100, default 20).
    Pagination values that are missing, empty or not integers fall back to
    the defaults.
    """
    keyword = (request.args.get("keyword") or "").strip()
    label = (request.args.get("label") or "").strip().lower()
    # ``type=int`` makes ``get`` return the default when conversion fails,
    # so a value such as ``?page=abc`` no longer raises ValueError.
    page = max(request.args.get("page", 1, type=int) or 1, 1)
    page_size = min(max(request.args.get("page_size", 20, type=int) or 20, 1), 100)

    query = SpamTrainingSample.query
    if keyword:
        # NOTE: ``%`` and ``_`` inside the keyword act as LIKE wildcards.
        query = query.filter(SpamTrainingSample.text.like(f"%{keyword}%"))
    if label in {"spam", "ham"}:
        query = query.filter(SpamTrainingSample.label == label)

    pagination = query.order_by(SpamTrainingSample.id.desc()).paginate(
        page=page, per_page=page_size, error_out=False
    )
    return ok(
        {
            "items": [item.to_dict() for item in pagination.items],
            "total": pagination.total,
            "page": page,
            "page_size": page_size,
        }
    )
@spam_bp.post("/samples")
@admin_required
def create_sample():
    """Create one active training sample with source ``import``.

    Reads a JSON object with ``text`` and ``label``. Responds with 400 when
    the stripped text is shorter than two characters or ``normalize_label``
    returns a falsy value for the label. ``created_by`` is the current
    user's id, or ``None`` when no user is resolved.
    """
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict):
        # A non-object body is treated as empty so it yields a 400 below.
        payload = {}
    raw_text = payload.get("text")
    # A non-string ``text`` becomes "" and fails validation instead of
    # raising on ``.strip()``.
    text = raw_text.strip() if isinstance(raw_text, str) else ""
    label = NaiveBayesSpamClassifier.normalize_label(payload.get("label"))
    if len(text) < 2:
        return fail("文本至少2个字符", 400)
    if not label:
        return fail("label 必须是 spam 或 ham", 400)

    user = current_user()
    row = SpamTrainingSample(
        text=text,
        label=label,
        source="import",
        created_by=user.id if user else None,
        is_active=True,
    )
    db.session.add(row)
    db.session.commit()
    return ok(row.to_dict(), "样本创建成功")
@spam_bp.put("/samples/<int:sample_id>")
@admin_required
def update_sample(sample_id: int):
    """Partially update a training sample's ``text``, ``label`` and/or ``is_active``.

    Only keys present in the JSON object are applied. Responds with 404
    when the sample does not exist, and with 400 when a supplied text is
    shorter than two characters after stripping or a supplied label is
    rejected by ``normalize_label``. All supplied fields are validated
    before the row is modified, so a rejected request leaves it untouched.
    """
    row = SpamTrainingSample.query.get(sample_id)
    if not row:
        return fail("样本不存在", 404)

    payload = request.get_json(silent=True)
    if not isinstance(payload, dict):
        # A non-object body carries no updatable fields.
        payload = {}

    changes = {}
    if "text" in payload:
        raw_text = payload.get("text")
        # A non-string ``text`` becomes "" and fails validation instead of
        # raising on ``.strip()``.
        text = raw_text.strip() if isinstance(raw_text, str) else ""
        if len(text) < 2:
            return fail("文本至少2个字符", 400)
        changes["text"] = text
    if "label" in payload:
        label = NaiveBayesSpamClassifier.normalize_label(payload.get("label"))
        if not label:
            return fail("label 必须是 spam 或 ham", 400)
        changes["label"] = label
    if "is_active" in payload:
        changes["is_active"] = bool(payload.get("is_active"))

    # Apply only after every supplied field has passed validation.
    for field, value in changes.items():
        setattr(row, field, value)
    db.session.commit()
    return ok(row.to_dict(), "样本更新成功")
@spam_bp.delete("/samples/<int:sample_id>")
@admin_required
def delete_sample(sample_id: int):
    """Delete the training sample with the given id, or respond with 404 if it does not exist."""
    sample = SpamTrainingSample.query.get(sample_id)
    if sample:
        db.session.delete(sample)
        db.session.commit()
        return ok({}, "样本已删除")
    return fail("样本不存在", 404)
@spam_bp.post("/samples/import")
@admin_required
def import_samples():
    """Bulk create or update training samples, matching existing rows by exact text.

    Reads a JSON object with an ``items`` array of objects that carry
    ``text`` and ``label`` and optionally ``source`` and ``is_active``.
    Entries that are not objects, whose text is not a string of at least
    two characters after stripping, or whose label is rejected by
    ``normalize_label`` are skipped. Responds with 400 when ``items`` is
    missing, empty or not a list. Returns the ``created`` and ``updated``
    counts; all changes are committed together at the end.
    """
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict):
        # A non-object body is treated as empty so it yields the 400 below.
        payload = {}
    items = payload.get("items") or []
    if not isinstance(items, list) or not items:
        return fail("items 必须是非空数组", 400)

    user = current_user()
    creator_id = user.id if user else None
    created = 0
    updated = 0
    for item in items:
        # One malformed entry must not abort the whole import.
        if not isinstance(item, dict):
            continue
        raw_text = item.get("text")
        text = raw_text.strip() if isinstance(raw_text, str) else ""
        label = NaiveBayesSpamClassifier.normalize_label(item.get("label"))
        if len(text) < 2 or not label:
            continue

        row = SpamTrainingSample.query.filter_by(text=text).first()
        if row:
            row.label = label
            row.is_active = bool(item.get("is_active", True))
            # Keep the existing source when the entry does not supply one.
            row.source = item.get("source") or row.source
            updated += 1
        else:
            row = SpamTrainingSample(
                text=text,
                label=label,
                source=item.get("source") or "import",
                created_by=creator_id,
                is_active=bool(item.get("is_active", True)),
            )
            db.session.add(row)
            created += 1

    db.session.commit()
    return ok({"created": created, "updated": updated}, "样本导入完成")