423 lines
13 KiB
Python
423 lines
13 KiB
Python
from flask import Blueprint, current_app, request, send_file, after_this_request
|
||
from flask_jwt_extended import jwt_required
|
||
|
||
from app.extensions import db
|
||
from app.ml.naive_bayes_classifier import NaiveBayesSpamClassifier
|
||
from app.ml.spam_categorizer import categorize_spam, get_category_label
|
||
from app.models import DetectionConfig, SpamPredictionLog, SpamTrainingSample
|
||
from app.utils.auth import admin_required, current_user
|
||
from app.utils.response import fail, ok
|
||
|
||
|
||
# Blueprint on which every spam-detection route in this module is registered.
spam_bp = Blueprint("spam", __name__)
|
||
|
||
|
||
def _classifier() -> NaiveBayesSpamClassifier:
    """Build a classifier bound to the path stored under ``NB_MODEL_PATH`` in the app config."""
    model_path = current_app.config["NB_MODEL_PATH"]
    return NaiveBayesSpamClassifier(model_path)
|
||
|
||
|
||
def _active_samples() -> list[dict]:
    """Return all active training samples as ``{"text", "label"}`` dicts, ordered by ascending id."""
    active_query = SpamTrainingSample.query.filter_by(is_active=True)
    ordered_query = active_query.order_by(SpamTrainingSample.id.asc())

    samples = []
    for sample in ordered_query.all():
        samples.append({"text": sample.text, "label": sample.label})
    return samples
|
||
|
||
|
||
def _ensure_ready() -> NaiveBayesSpamClassifier:
    """Return a classifier after passing the active samples to its ``ensure_ready`` method."""
    classifier = _classifier()
    # NOTE(review): this loads every active sample from the database on each
    # call; presumably ensure_ready only trains when no model exists - confirm.
    classifier.ensure_ready(_active_samples())
    return classifier
|
||
|
||
|
||
def _threshold() -> float:
    """Return the base spam threshold from the lowest-id ``DetectionConfig`` row, or 0.75 if there is none."""
    config = DetectionConfig.query.order_by(DetectionConfig.id.asc()).first()
    if not config:
        return 0.75
    return float(config.spam_threshold)
|
||
|
||
|
||
def _adjusted_threshold(user_credit: int = 100) -> float:
    """Scale the base spam threshold according to the user's credit score.

    A score of 100 leaves the base threshold unchanged. Every point above or
    below 100 shifts the multiplier by 0.0015, and the multiplier is clamped
    to the range 0.85 - 1.15 before being applied.
    """
    base = _threshold()
    raw_factor = 1.0 + (user_credit - 100) * 0.0015
    # Multiplier range: 0.85 - 1.15
    bounded_factor = max(0.85, min(1.15, raw_factor))
    return base * bounded_factor
|
||
|
||
|
||
@spam_bp.post("/predict")
@jwt_required()
def predict_one():
    """Classify one text for the current user and log the prediction.

    Reads ``text`` from the JSON body; it must be a string of at least two
    characters after stripping. The classifier result is compared against the
    user's credit-adjusted threshold, the prediction is stored as a
    ``SpamPredictionLog`` row, and the result is returned together with
    ``log_id``, ``threshold``, ``blocked_by_threshold``, ``category`` and
    ``category_label``.

    Responds 404 when the current user cannot be resolved and 400 when the
    text is missing, not a string, or too short.
    """
    user = current_user()
    if not user:
        return fail("用户不存在", 404)

    # A JSON body that is not an object (e.g. a list) is treated as empty
    # instead of crashing on ``.get``.
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict):
        payload = {}
    raw_text = payload.get("text")
    text = raw_text.strip() if isinstance(raw_text, str) else ""
    if len(text) < 2:
        return fail("请输入至少2个字符的待识别文本", 400)

    clf = _ensure_ready()
    result = clf.predict(text)

    # Only a missing score falls back to 100; a real score of 0 must still
    # lower the threshold (``credit_score or 100`` would discard it).
    credit = user.credit_score if user.credit_score is not None else 100
    threshold = _adjusted_threshold(credit)
    blocked = float(result["spam_probability"]) >= threshold

    # Category label: only sub-classify texts that were judged to be spam.
    category = ""
    category_label = ""
    if blocked:
        category, category_label = categorize_spam(result["text"])

    row = SpamPredictionLog(
        user_id=user.id,
        text=result["text"],
        prediction=result["prediction"],
        category=category,
        spam_probability=result["spam_probability"],
        ham_probability=result["ham_probability"],
        confidence=result["confidence"],
        reason_tokens=result["reason_tokens"],
        model_version=result.get("model_version", ""),
    )
    db.session.add(row)
    db.session.commit()

    return ok({
        **result,
        "log_id": row.id,
        "threshold": threshold,
        "blocked_by_threshold": blocked,
        "category": category,
        "category_label": category_label,
    }, "识别成功")
|
||
|
||
|
||
@spam_bp.post("/predict/batch")
@jwt_required()
def predict_batch():
    """Classify up to 100 texts for the current user and log each prediction.

    Reads ``items`` (a non-empty list) from the JSON body. Entries that are
    not strings, or are shorter than two characters after stripping, are
    skipped. Every remaining text is classified, compared against the user's
    credit-adjusted threshold, and stored as a ``SpamPredictionLog`` row.

    Returns the per-item results plus a ``summary`` with totals, spam/ham
    counts, blocked count, ratios and the threshold used. Responds 404 when
    the current user cannot be resolved and 400 when ``items`` is invalid,
    too long, or contains no usable text.
    """
    user = current_user()
    if not user:
        return fail("用户不存在", 404)

    # A JSON body that is not an object is treated as empty.
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict):
        payload = {}
    items = payload.get("items") or []
    if not isinstance(items, list) or not items:
        return fail("items 必须是非空数组", 400)
    if len(items) > 100:
        return fail("单次最多识别100条", 400)

    clf = _ensure_ready()
    rows = []
    results = []
    # Only a missing score falls back to 100; a real score of 0 must still
    # lower the threshold (``credit_score or 100`` would discard it).
    credit = user.credit_score if user.credit_score is not None else 100
    threshold = _adjusted_threshold(credit)

    for item in items:
        # Non-string entries are skipped like too-short texts rather than
        # failing the whole batch on ``.strip()``.
        if not isinstance(item, str):
            continue
        content = item.strip()
        if len(content) < 2:
            continue

        result = clf.predict(content)
        blocked = float(result["spam_probability"]) >= threshold
        result["blocked_by_threshold"] = blocked

        # Category label: only sub-classify texts that were blocked.
        category = ""
        category_label = ""
        if blocked:
            category, category_label = categorize_spam(result["text"])
        result["category"] = category
        result["category_label"] = category_label

        rows.append(
            SpamPredictionLog(
                user_id=user.id,
                text=result["text"],
                prediction=result["prediction"],
                category=category,
                spam_probability=result["spam_probability"],
                ham_probability=result["ham_probability"],
                confidence=result["confidence"],
                reason_tokens=result["reason_tokens"],
                model_version=result.get("model_version", ""),
            )
        )
        results.append(result)

    if not rows:
        return fail("没有可识别的有效文本", 400)

    db.session.add_all(rows)
    db.session.commit()

    total = len(results)
    spam_count = sum(1 for entry in results if entry["prediction"] == "spam")
    blocked_count = sum(1 for entry in results if entry["blocked_by_threshold"])
    ham_count = total - spam_count

    return ok(
        {
            "items": results,
            "summary": {
                "total": total,
                "spam_count": spam_count,
                "ham_count": ham_count,
                "blocked_count": blocked_count,
                "spam_ratio": round(spam_count / total, 4) if total else 0,
                "blocked_ratio": round(blocked_count / total, 4) if total else 0,
                "threshold": threshold,
            },
        },
        "批量识别完成",
    )
|
||
|
||
|
||
@spam_bp.get("/history")
@jwt_required()
def my_history():
    """Return one page of the current user's prediction logs, newest first.

    Query parameters:
        page: 1-based page number; values below 1 are raised to 1.
        page_size: rows per page, clamped to the range 1-100 (default 20).

    Missing, empty or non-numeric values fall back to the defaults instead of
    raising. Responds 404 when the current user cannot be resolved.
    """
    user = current_user()
    if not user:
        return fail("用户不存在", 404)

    # ``type=int`` makes the lookup return the default when conversion fails,
    # so ``?page=abc`` no longer escapes as an unhandled ValueError.
    page = max(request.args.get("page", 1, type=int), 1)
    page_size = min(max(request.args.get("page_size", 20, type=int), 1), 100)

    pagination = (
        SpamPredictionLog.query.filter_by(user_id=user.id)
        .order_by(SpamPredictionLog.id.desc())
        .paginate(page=page, per_page=page_size, error_out=False)
    )

    return ok(
        {
            "items": [item.to_dict() for item in pagination.items],
            "total": pagination.total,
            "page": page,
            "page_size": page_size,
        }
    )
|
||
|
||
|
||
@spam_bp.delete("/history/<int:log_id>")
@jwt_required()
def delete_history(log_id: int):
    """Delete the prediction log ``log_id`` if it belongs to the current user.

    Responds 404 when the current user cannot be resolved or when no log with
    that id exists for this user.
    """
    owner = current_user()
    if not owner:
        return fail("用户不存在", 404)

    # Filtering on user_id as well as id keeps users from deleting each
    # other's records.
    entry = SpamPredictionLog.query.filter_by(id=log_id, user_id=owner.id).first()
    if entry:
        db.session.delete(entry)
        db.session.commit()
        return ok({}, "记录已删除")

    return fail("记录不存在", 404)
|
||
|
||
|
||
@spam_bp.post("/feedback")
@jwt_required()
def save_feedback():
    """Store a user-labelled text as an active training sample.

    Reads ``text`` (a string of at least two characters after stripping) and
    ``label`` from the JSON body; the label is passed through
    ``NaiveBayesSpamClassifier.normalize_label``. The sample is saved with
    source ``"feedback"`` and the current user as creator.

    Responds 404 when the current user cannot be resolved and 400 when the
    text or label is invalid.
    """
    user = current_user()
    if not user:
        return fail("用户不存在", 404)

    # A JSON body that is not an object is treated as empty, and a non-string
    # ``text`` as missing, so bad input yields 400 rather than a crash.
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict):
        payload = {}
    raw_text = payload.get("text")
    text = raw_text.strip() if isinstance(raw_text, str) else ""
    label = NaiveBayesSpamClassifier.normalize_label(payload.get("label"))

    if len(text) < 2:
        return fail("文本至少2个字符", 400)
    if not label:
        return fail("label 必须是 spam 或 ham", 400)

    row = SpamTrainingSample(
        text=text,
        label=label,
        source="feedback",
        created_by=user.id,
        is_active=True,
    )
    db.session.add(row)
    db.session.commit()
    return ok(row.to_dict(), "反馈样本已记录")
|
||
|
||
|
||
@spam_bp.get("/model/info")
@jwt_required(optional=True)
def model_info():
    """Return the classifier's ``model_info()`` with the base spam threshold added under ``threshold``."""
    classifier = _classifier()
    classifier.load()

    details = classifier.model_info()
    details["threshold"] = _threshold()
    return ok(details)
|
||
|
||
|
||
@spam_bp.post("/train")
@admin_required
def train_model():
    """Train the classifier on all active samples and return what ``train`` reports."""
    classifier = _classifier()
    training_result = classifier.train(_active_samples())
    return ok(training_result, "模型训练完成")
|
||
|
||
|
||
@spam_bp.get("/samples")
@admin_required
def list_samples():
    """Return one page of training samples, newest first, with optional filters.

    Query parameters:
        keyword: when non-empty, only samples whose text contains it (SQL LIKE).
        label: ``spam`` or ``ham`` (case-insensitive); other values are ignored.
        page: 1-based page number; values below 1 are raised to 1.
        page_size: rows per page, clamped to the range 1-100 (default 20).

    Missing, empty or non-numeric paging values fall back to the defaults
    instead of raising.
    """
    keyword = (request.args.get("keyword") or "").strip()
    label = (request.args.get("label") or "").strip().lower()
    # ``type=int`` makes the lookup return the default when conversion fails,
    # so ``?page=abc`` no longer escapes as an unhandled ValueError.
    page = max(request.args.get("page", 1, type=int), 1)
    page_size = min(max(request.args.get("page_size", 20, type=int), 1), 100)

    query = SpamTrainingSample.query
    if keyword:
        # NOTE(review): ``%`` and ``_`` inside keyword still act as LIKE wildcards.
        query = query.filter(SpamTrainingSample.text.like(f"%{keyword}%"))
    if label in {"spam", "ham"}:
        query = query.filter(SpamTrainingSample.label == label)

    pagination = query.order_by(SpamTrainingSample.id.desc()).paginate(
        page=page, per_page=page_size, error_out=False
    )
    return ok(
        {
            "items": [item.to_dict() for item in pagination.items],
            "total": pagination.total,
            "page": page,
            "page_size": page_size,
        }
    )
|
||
|
||
|
||
@spam_bp.post("/samples")
@admin_required
def create_sample():
    """Create one active training sample with source ``"import"``.

    Reads ``text`` (a string of at least two characters after stripping) and
    ``label`` from the JSON body; the label is passed through
    ``NaiveBayesSpamClassifier.normalize_label``. ``created_by`` is the
    current user's id, or ``None`` when no user is resolved.

    Responds 400 when the text or label is invalid.
    """
    # A JSON body that is not an object is treated as empty, and a non-string
    # ``text`` as missing, so bad input yields 400 rather than a crash.
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict):
        payload = {}
    raw_text = payload.get("text")
    text = raw_text.strip() if isinstance(raw_text, str) else ""
    label = NaiveBayesSpamClassifier.normalize_label(payload.get("label"))

    if len(text) < 2:
        return fail("文本至少2个字符", 400)
    if not label:
        return fail("label 必须是 spam 或 ham", 400)

    user = current_user()
    row = SpamTrainingSample(
        text=text,
        label=label,
        source="import",
        created_by=user.id if user else None,
        is_active=True,
    )
    db.session.add(row)
    db.session.commit()
    return ok(row.to_dict(), "样本创建成功")
|
||
|
||
|
||
@spam_bp.put("/samples/<int:sample_id>")
@admin_required
def update_sample(sample_id: int):
    """Update the ``text``, ``label`` and/or ``is_active`` fields of a training sample.

    Only keys present in the JSON body are touched. ``text`` must be a string
    of at least two characters after stripping, ``label`` must survive
    ``NaiveBayesSpamClassifier.normalize_label``, and ``is_active`` is coerced
    with ``bool``.

    Responds 404 when the sample does not exist and 400 when a supplied field
    is invalid.
    """
    row = SpamTrainingSample.query.get(sample_id)
    if not row:
        return fail("样本不存在", 404)

    payload = request.get_json(silent=True)
    if not isinstance(payload, dict):
        payload = {}

    # Validate everything first so a rejected request never leaves the row
    # half-modified in the session.
    changes = {}
    if "text" in payload:
        raw_text = payload.get("text")
        # A non-string text counts as empty instead of crashing on ``.strip()``.
        text = raw_text.strip() if isinstance(raw_text, str) else ""
        if len(text) < 2:
            return fail("文本至少2个字符", 400)
        changes["text"] = text
    if "label" in payload:
        label = NaiveBayesSpamClassifier.normalize_label(payload.get("label"))
        if not label:
            return fail("label 必须是 spam 或 ham", 400)
        changes["label"] = label
    if "is_active" in payload:
        changes["is_active"] = bool(payload.get("is_active"))

    for field, value in changes.items():
        setattr(row, field, value)

    db.session.commit()
    return ok(row.to_dict(), "样本更新成功")
|
||
|
||
|
||
@spam_bp.delete("/samples/<int:sample_id>")
@admin_required
def delete_sample(sample_id: int):
    """Delete the training sample ``sample_id``; responds 404 when it does not exist."""
    sample = SpamTrainingSample.query.get(sample_id)
    if sample:
        db.session.delete(sample)
        db.session.commit()
        return ok({}, "样本已删除")

    return fail("样本不存在", 404)
|
||
|
||
|
||
@spam_bp.post("/samples/import")
@admin_required
def import_samples():
    """Bulk-import training samples, updating any sample whose text already exists.

    Reads ``items`` (a non-empty list of objects) from the JSON body. Each
    item needs a string ``text`` of at least two characters after stripping
    and a ``label`` accepted by ``NaiveBayesSpamClassifier.normalize_label``;
    ``source`` and ``is_active`` are optional. Entries that are not objects
    or fail validation are skipped.

    Returns the number of samples ``created`` and ``updated``. Responds 400
    when ``items`` is missing, empty or not a list.
    """
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict):
        payload = {}
    items = payload.get("items") or []
    if not isinstance(items, list) or not items:
        return fail("items 必须是非空数组", 400)

    user = current_user()
    created = 0
    updated = 0

    for item in items:
        # Malformed entries are skipped like invalid ones instead of aborting
        # the whole import with an AttributeError.
        if not isinstance(item, dict):
            continue
        raw_text = item.get("text")
        text = raw_text.strip() if isinstance(raw_text, str) else ""
        label = NaiveBayesSpamClassifier.normalize_label(item.get("label"))
        if len(text) < 2 or not label:
            continue

        row = SpamTrainingSample.query.filter_by(text=text).first()
        if row:
            row.label = label
            row.is_active = bool(item.get("is_active", True))
            # Keep the existing source unless the item supplies one.
            row.source = item.get("source") or row.source
            updated += 1
        else:
            row = SpamTrainingSample(
                text=text,
                label=label,
                source=item.get("source") or "import",
                created_by=user.id if user else None,
                is_active=bool(item.get("is_active", True)),
            )
            db.session.add(row)
            created += 1

    db.session.commit()
    return ok({"created": created, "updated": updated}, "样本导入完成")
|
||
|
||
|
||
@spam_bp.post("/export/xlsx")
@jwt_required()
def export_xlsx():
    """Render client-supplied detection results as a downloadable .xlsx file.

    Reads ``items`` (a non-empty list of result objects) from the JSON body
    and writes one spreadsheet row per object: text, verdict, category label,
    confidence, spam probability, ham probability and risk keywords.
    Entries that are not objects are skipped, and probability values that
    cannot be parsed as numbers are rendered as 0.00%.

    The workbook is built in memory, so no temporary file has to be created
    or cleaned up. Responds 404 when the current user cannot be resolved and
    400 when ``items`` is missing, empty, not a list, or holds no objects.
    """
    user = current_user()
    if not user:
        return fail("用户不存在", 404)

    payload = request.get_json(silent=True)
    if not isinstance(payload, dict):
        payload = {}
    items = payload.get("items") or []
    if not isinstance(items, list) or not items:
        return fail("items 必须是非空数组", 400)

    # Imported lazily so pandas is only loaded when an export is requested.
    import io

    import pandas as pd

    def as_percent(value) -> str:
        """Format a 0-1 fraction as a two-decimal percentage string."""
        try:
            number = float(value or 0)
        except (TypeError, ValueError):
            number = 0.0
        return f"{number * 100:.2f}%"

    rows = []
    for item in items:
        if not isinstance(item, dict):
            continue

        tokens = item.get("reason_tokens") or []
        if isinstance(tokens, list):
            # Ignore malformed token entries and coerce values to str so the
            # join cannot fail on non-string tokens.
            token_str = "; ".join(
                str(token.get("token") or "")
                for token in tokens
                if isinstance(token, dict)
            )
        else:
            token_str = ""
        prediction_text = "垃圾信息" if item.get("prediction") == "spam" else "正常信息"

        rows.append({
            "文本": item.get("text", ""),
            "判定结果": prediction_text,
            "分类标签": item.get("category_label", ""),
            "置信度": as_percent(item.get("confidence", 0)),
            "垃圾概率": as_percent(item.get("spam_probability", 0)),
            "正常概率": as_percent(item.get("ham_probability", 0)),
            "风险关键词": token_str,
        })

    if not rows:
        return fail("items 必须是非空数组", 400)

    # Writing to a BytesIO buffer avoids the temp file that previously leaked
    # when to_excel raised or when the post-request unlink failed.
    buffer = io.BytesIO()
    pd.DataFrame(rows).to_excel(buffer, index=False, engine="openpyxl")
    buffer.seek(0)

    return send_file(
        buffer,
        mimetype="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
        as_attachment=True,
        download_name="batch_detect.xlsx",
    )
|