1
This commit is contained in:
334
backend/app/routes/spam_routes.py
Normal file
334
backend/app/routes/spam_routes.py
Normal file
@@ -0,0 +1,334 @@
|
||||
from flask import Blueprint, current_app, request
|
||||
from flask_jwt_extended import jwt_required
|
||||
|
||||
from app.extensions import db
|
||||
from app.ml.naive_bayes_classifier import NaiveBayesSpamClassifier
|
||||
from app.models import DetectionConfig, SpamPredictionLog, SpamTrainingSample
|
||||
from app.utils.auth import admin_required, current_user
|
||||
from app.utils.response import fail, ok
|
||||
|
||||
|
||||
spam_bp = Blueprint("spam", __name__)
|
||||
|
||||
|
||||
def _classifier() -> NaiveBayesSpamClassifier:
|
||||
return NaiveBayesSpamClassifier(current_app.config["NB_MODEL_PATH"])
|
||||
|
||||
|
||||
def _active_samples() -> list[dict]:
|
||||
rows = SpamTrainingSample.query.filter_by(is_active=True).order_by(SpamTrainingSample.id.asc()).all()
|
||||
return [{"text": row.text, "label": row.label} for row in rows]
|
||||
|
||||
|
||||
def _ensure_ready() -> NaiveBayesSpamClassifier:
|
||||
clf = _classifier()
|
||||
samples = _active_samples()
|
||||
clf.ensure_ready(samples)
|
||||
return clf
|
||||
|
||||
|
||||
def _threshold() -> float:
|
||||
row = DetectionConfig.query.order_by(DetectionConfig.id.asc()).first()
|
||||
return float(row.spam_threshold) if row else 0.75
|
||||
|
||||
|
||||
@spam_bp.post("/predict")
|
||||
@jwt_required()
|
||||
def predict_one():
|
||||
user = current_user()
|
||||
if not user:
|
||||
return fail("用户不存在", 404)
|
||||
|
||||
payload = request.get_json(silent=True) or {}
|
||||
text = (payload.get("text") or "").strip()
|
||||
if len(text) < 2:
|
||||
return fail("请输入至少2个字符的待识别文本", 400)
|
||||
|
||||
clf = _ensure_ready()
|
||||
result = clf.predict(text)
|
||||
threshold = _threshold()
|
||||
blocked = float(result["spam_probability"]) >= threshold
|
||||
|
||||
row = SpamPredictionLog(
|
||||
user_id=user.id,
|
||||
text=result["text"],
|
||||
prediction=result["prediction"],
|
||||
spam_probability=result["spam_probability"],
|
||||
ham_probability=result["ham_probability"],
|
||||
confidence=result["confidence"],
|
||||
reason_tokens=result["reason_tokens"],
|
||||
model_version=result.get("model_version", ""),
|
||||
)
|
||||
db.session.add(row)
|
||||
db.session.commit()
|
||||
|
||||
return ok({**result, "log_id": row.id, "threshold": threshold, "blocked_by_threshold": blocked}, "识别成功")
|
||||
|
||||
|
||||
@spam_bp.post("/predict/batch")
|
||||
@jwt_required()
|
||||
def predict_batch():
|
||||
user = current_user()
|
||||
if not user:
|
||||
return fail("用户不存在", 404)
|
||||
|
||||
payload = request.get_json(silent=True) or {}
|
||||
items = payload.get("items") or []
|
||||
if not isinstance(items, list) or not items:
|
||||
return fail("items 必须是非空数组", 400)
|
||||
if len(items) > 100:
|
||||
return fail("单次最多识别100条", 400)
|
||||
|
||||
clf = _ensure_ready()
|
||||
rows = []
|
||||
results = []
|
||||
threshold = _threshold()
|
||||
|
||||
for text in items:
|
||||
content = (text or "").strip()
|
||||
if len(content) < 2:
|
||||
continue
|
||||
result = clf.predict(content)
|
||||
result["blocked_by_threshold"] = float(result["spam_probability"]) >= threshold
|
||||
rows.append(
|
||||
SpamPredictionLog(
|
||||
user_id=user.id,
|
||||
text=result["text"],
|
||||
prediction=result["prediction"],
|
||||
spam_probability=result["spam_probability"],
|
||||
ham_probability=result["ham_probability"],
|
||||
confidence=result["confidence"],
|
||||
reason_tokens=result["reason_tokens"],
|
||||
model_version=result.get("model_version", ""),
|
||||
)
|
||||
)
|
||||
results.append(result)
|
||||
|
||||
if not rows:
|
||||
return fail("没有可识别的有效文本", 400)
|
||||
|
||||
db.session.add_all(rows)
|
||||
db.session.commit()
|
||||
|
||||
spam_count = len([item for item in results if item["prediction"] == "spam"])
|
||||
blocked_count = len([item for item in results if item["blocked_by_threshold"]])
|
||||
ham_count = len(results) - spam_count
|
||||
|
||||
return ok(
|
||||
{
|
||||
"items": results,
|
||||
"summary": {
|
||||
"total": len(results),
|
||||
"spam_count": spam_count,
|
||||
"ham_count": ham_count,
|
||||
"blocked_count": blocked_count,
|
||||
"spam_ratio": round(spam_count / len(results), 4) if results else 0,
|
||||
"blocked_ratio": round(blocked_count / len(results), 4) if results else 0,
|
||||
"threshold": threshold,
|
||||
},
|
||||
},
|
||||
"批量识别完成",
|
||||
)
|
||||
|
||||
|
||||
@spam_bp.get("/history")
|
||||
@jwt_required()
|
||||
def my_history():
|
||||
user = current_user()
|
||||
if not user:
|
||||
return fail("用户不存在", 404)
|
||||
|
||||
page = max(int(request.args.get("page", 1) or 1), 1)
|
||||
page_size = min(max(int(request.args.get("page_size", 20) or 20), 1), 100)
|
||||
|
||||
pagination = (
|
||||
SpamPredictionLog.query.filter_by(user_id=user.id)
|
||||
.order_by(SpamPredictionLog.id.desc())
|
||||
.paginate(page=page, per_page=page_size, error_out=False)
|
||||
)
|
||||
|
||||
return ok(
|
||||
{
|
||||
"items": [item.to_dict() for item in pagination.items],
|
||||
"total": pagination.total,
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@spam_bp.delete("/history/<int:log_id>")
|
||||
@jwt_required()
|
||||
def delete_history(log_id: int):
|
||||
user = current_user()
|
||||
if not user:
|
||||
return fail("用户不存在", 404)
|
||||
|
||||
row = SpamPredictionLog.query.filter_by(id=log_id, user_id=user.id).first()
|
||||
if not row:
|
||||
return fail("记录不存在", 404)
|
||||
|
||||
db.session.delete(row)
|
||||
db.session.commit()
|
||||
return ok({}, "记录已删除")
|
||||
|
||||
|
||||
@spam_bp.post("/feedback")
|
||||
@jwt_required()
|
||||
def save_feedback():
|
||||
user = current_user()
|
||||
if not user:
|
||||
return fail("用户不存在", 404)
|
||||
|
||||
payload = request.get_json(silent=True) or {}
|
||||
text = (payload.get("text") or "").strip()
|
||||
label = NaiveBayesSpamClassifier.normalize_label(payload.get("label"))
|
||||
|
||||
if len(text) < 2:
|
||||
return fail("文本至少2个字符", 400)
|
||||
if not label:
|
||||
return fail("label 必须是 spam 或 ham", 400)
|
||||
|
||||
row = SpamTrainingSample(text=text, label=label, source="feedback", created_by=user.id, is_active=True)
|
||||
db.session.add(row)
|
||||
db.session.commit()
|
||||
return ok(row.to_dict(), "反馈样本已记录")
|
||||
|
||||
|
||||
@spam_bp.get("/model/info")
|
||||
@jwt_required(optional=True)
|
||||
def model_info():
|
||||
clf = _classifier()
|
||||
clf.load()
|
||||
info = clf.model_info()
|
||||
info["threshold"] = _threshold()
|
||||
return ok(info)
|
||||
|
||||
|
||||
@spam_bp.post("/train")
|
||||
@admin_required
|
||||
def train_model():
|
||||
clf = _classifier()
|
||||
samples = _active_samples()
|
||||
metadata = clf.train(samples)
|
||||
return ok(metadata, "模型训练完成")
|
||||
|
||||
|
||||
@spam_bp.get("/samples")
|
||||
@admin_required
|
||||
def list_samples():
|
||||
keyword = (request.args.get("keyword") or "").strip()
|
||||
label = (request.args.get("label") or "").strip().lower()
|
||||
page = max(int(request.args.get("page", 1) or 1), 1)
|
||||
page_size = min(max(int(request.args.get("page_size", 20) or 20), 1), 100)
|
||||
|
||||
query = SpamTrainingSample.query
|
||||
if keyword:
|
||||
query = query.filter(SpamTrainingSample.text.like(f"%{keyword}%"))
|
||||
if label in {"spam", "ham"}:
|
||||
query = query.filter(SpamTrainingSample.label == label)
|
||||
|
||||
pagination = query.order_by(SpamTrainingSample.id.desc()).paginate(page=page, per_page=page_size, error_out=False)
|
||||
return ok(
|
||||
{
|
||||
"items": [item.to_dict() for item in pagination.items],
|
||||
"total": pagination.total,
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@spam_bp.post("/samples")
|
||||
@admin_required
|
||||
def create_sample():
|
||||
payload = request.get_json(silent=True) or {}
|
||||
text = (payload.get("text") or "").strip()
|
||||
label = NaiveBayesSpamClassifier.normalize_label(payload.get("label"))
|
||||
|
||||
if len(text) < 2:
|
||||
return fail("文本至少2个字符", 400)
|
||||
if not label:
|
||||
return fail("label 必须是 spam 或 ham", 400)
|
||||
|
||||
user = current_user()
|
||||
row = SpamTrainingSample(text=text, label=label, source="import", created_by=user.id if user else None, is_active=True)
|
||||
db.session.add(row)
|
||||
db.session.commit()
|
||||
return ok(row.to_dict(), "样本创建成功")
|
||||
|
||||
|
||||
@spam_bp.put("/samples/<int:sample_id>")
|
||||
@admin_required
|
||||
def update_sample(sample_id: int):
|
||||
row = SpamTrainingSample.query.get(sample_id)
|
||||
if not row:
|
||||
return fail("样本不存在", 404)
|
||||
|
||||
payload = request.get_json(silent=True) or {}
|
||||
if "text" in payload:
|
||||
text = (payload.get("text") or "").strip()
|
||||
if len(text) < 2:
|
||||
return fail("文本至少2个字符", 400)
|
||||
row.text = text
|
||||
if "label" in payload:
|
||||
label = NaiveBayesSpamClassifier.normalize_label(payload.get("label"))
|
||||
if not label:
|
||||
return fail("label 必须是 spam 或 ham", 400)
|
||||
row.label = label
|
||||
if "is_active" in payload:
|
||||
row.is_active = bool(payload.get("is_active"))
|
||||
|
||||
db.session.commit()
|
||||
return ok(row.to_dict(), "样本更新成功")
|
||||
|
||||
|
||||
@spam_bp.delete("/samples/<int:sample_id>")
|
||||
@admin_required
|
||||
def delete_sample(sample_id: int):
|
||||
row = SpamTrainingSample.query.get(sample_id)
|
||||
if not row:
|
||||
return fail("样本不存在", 404)
|
||||
|
||||
db.session.delete(row)
|
||||
db.session.commit()
|
||||
return ok({}, "样本已删除")
|
||||
|
||||
|
||||
@spam_bp.post("/samples/import")
|
||||
@admin_required
|
||||
def import_samples():
|
||||
payload = request.get_json(silent=True) or {}
|
||||
items = payload.get("items") or []
|
||||
if not isinstance(items, list) or not items:
|
||||
return fail("items 必须是非空数组", 400)
|
||||
|
||||
user = current_user()
|
||||
created = 0
|
||||
updated = 0
|
||||
|
||||
for item in items:
|
||||
text = (item.get("text") or "").strip()
|
||||
label = NaiveBayesSpamClassifier.normalize_label(item.get("label"))
|
||||
if len(text) < 2 or not label:
|
||||
continue
|
||||
|
||||
row = SpamTrainingSample.query.filter_by(text=text).first()
|
||||
if row:
|
||||
row.label = label
|
||||
row.is_active = bool(item.get("is_active", True))
|
||||
row.source = item.get("source") or row.source
|
||||
updated += 1
|
||||
else:
|
||||
row = SpamTrainingSample(
|
||||
text=text,
|
||||
label=label,
|
||||
source=item.get("source") or "import",
|
||||
created_by=user.id if user else None,
|
||||
is_active=bool(item.get("is_active", True)),
|
||||
)
|
||||
db.session.add(row)
|
||||
created += 1
|
||||
|
||||
db.session.commit()
|
||||
return ok({"created": created, "updated": updated}, "样本导入完成")
|
||||
Reference in New Issue
Block a user