335 lines
10 KiB
Python
335 lines
10 KiB
Python
from flask import Blueprint, current_app, request
|
|
from flask_jwt_extended import jwt_required
|
|
|
|
from app.extensions import db
|
|
from app.ml.naive_bayes_classifier import NaiveBayesSpamClassifier
|
|
from app.models import DetectionConfig, SpamPredictionLog, SpamTrainingSample
|
|
from app.utils.auth import admin_required, current_user
|
|
from app.utils.response import fail, ok
|
|
|
|
|
|
# Blueprint that every spam-detection route in this module is registered on
# (prediction, history, feedback, model info/training, sample management).
spam_bp = Blueprint(name="spam", import_name=__name__)
|
|
|
|
|
|
def _classifier() -> NaiveBayesSpamClassifier:
    """Build a classifier bound to the app's configured model path.

    The path is read from ``current_app.config["NB_MODEL_PATH"]`` on every
    call, so each caller receives a fresh ``NaiveBayesSpamClassifier``.
    """
    model_path = current_app.config["NB_MODEL_PATH"]
    return NaiveBayesSpamClassifier(model_path)
|
|
|
|
|
|
def _active_samples() -> list[dict]:
    """Return every active training sample as a ``{"text", "label"}`` dict.

    Rows are ordered by ascending primary key so the sample order handed to
    the classifier is stable between calls.
    """
    active = (
        SpamTrainingSample.query.filter_by(is_active=True)
        .order_by(SpamTrainingSample.id.asc())
        .all()
    )
    return [{"text": sample.text, "label": sample.label} for sample in active]
|
|
|
|
|
|
def _ensure_ready() -> NaiveBayesSpamClassifier:
    """Return a classifier after running its ``ensure_ready`` hook.

    The currently active training samples are passed to
    ``NaiveBayesSpamClassifier.ensure_ready``. Presumably that trains a model
    when none is usable on disk (TODO confirm in the classifier module).
    """
    classifier = _classifier()
    classifier.ensure_ready(_active_samples())
    return classifier
|
|
|
|
|
|
def _threshold() -> float:
    """Return the spam-probability threshold for flagging a text as blocked.

    Reads ``spam_threshold`` from the ``DetectionConfig`` row with the lowest
    id. Falls back to ``0.75`` when no configuration row exists.
    """
    config = DetectionConfig.query.order_by(DetectionConfig.id.asc()).first()
    if not config:
        return 0.75
    return float(config.spam_threshold)
|
|
|
|
|
|
@spam_bp.post("/predict")
@jwt_required()
def predict_one():
    """Classify one text for the current user and record the prediction.

    Expects a JSON object ``{"text": "..."}``. After stripping, the text must
    be a string of at least 2 characters. The classifier result is persisted
    as a ``SpamPredictionLog`` row. It is returned together with ``log_id``,
    the active ``threshold`` and ``blocked_by_threshold``, which says whether
    ``spam_probability`` is at or above that threshold.

    Returns:
        404 when the JWT identity does not map to a user.
        400 for missing or invalid text.
        Otherwise, the enriched prediction payload.
    """
    user = current_user()
    if not user:
        return fail("用户不存在", 404)

    payload = request.get_json(silent=True) or {}
    # A JSON array/scalar body has no ``.get``; treat it like an empty object.
    if not isinstance(payload, dict):
        payload = {}
    raw_text = payload.get("text")
    # Non-string values (numbers, lists, objects) have no ``.strip()`` and
    # used to surface as a 500; reject them as invalid input instead.
    text = raw_text.strip() if isinstance(raw_text, str) else ""
    if len(text) < 2:
        return fail("请输入至少2个字符的待识别文本", 400)

    clf = _ensure_ready()
    result = clf.predict(text)
    threshold = _threshold()
    blocked = float(result["spam_probability"]) >= threshold

    row = SpamPredictionLog(
        user_id=user.id,
        text=result["text"],
        prediction=result["prediction"],
        spam_probability=result["spam_probability"],
        ham_probability=result["ham_probability"],
        confidence=result["confidence"],
        reason_tokens=result["reason_tokens"],
        model_version=result.get("model_version", ""),
    )
    db.session.add(row)
    db.session.commit()

    return ok(
        {**result, "log_id": row.id, "threshold": threshold, "blocked_by_threshold": blocked},
        "识别成功",
    )
|
|
|
|
|
|
@spam_bp.post("/predict/batch")
@jwt_required()
def predict_batch():
    """Classify up to 100 texts in one request and log every prediction.

    Expects ``{"items": ["text", ...]}``. Entries that are not strings, or are
    shorter than 2 characters after stripping, are skipped. Every classified
    text is stored as a ``SpamPredictionLog`` row. The response carries the
    per-item results, each with ``log_id`` and ``blocked_by_threshold``. It
    also includes a summary of spam / ham / blocked counts and ratios.

    Returns:
        404 for an unknown user.
        400 when ``items`` is not a non-empty list, holds more than 100
        entries, or contains no classifiable text.
    """
    user = current_user()
    if not user:
        return fail("用户不存在", 404)

    payload = request.get_json(silent=True) or {}
    # A JSON array/scalar body has no ``.get``; treat it like an empty object.
    if not isinstance(payload, dict):
        payload = {}
    items = payload.get("items") or []
    if not isinstance(items, list) or not items:
        return fail("items 必须是非空数组", 400)
    if len(items) > 100:
        return fail("单次最多识别100条", 400)

    clf = _ensure_ready()
    threshold = _threshold()
    rows = []
    results = []

    for item in items:
        # Non-string entries (numbers, objects, null) previously crashed on
        # ``.strip()``; skip them like any other unusable entry.
        if not isinstance(item, str):
            continue
        content = item.strip()
        if len(content) < 2:
            continue
        result = clf.predict(content)
        result["blocked_by_threshold"] = float(result["spam_probability"]) >= threshold
        rows.append(
            SpamPredictionLog(
                user_id=user.id,
                text=result["text"],
                prediction=result["prediction"],
                spam_probability=result["spam_probability"],
                ham_probability=result["ham_probability"],
                confidence=result["confidence"],
                reason_tokens=result["reason_tokens"],
                model_version=result.get("model_version", ""),
            )
        )
        results.append(result)

    if not rows:
        return fail("没有可识别的有效文本", 400)

    db.session.add_all(rows)
    db.session.commit()

    # Ids exist only after the commit. Expose them per item, as the
    # single-text endpoint does with ``log_id``.
    for result, row in zip(results, rows):
        result["log_id"] = row.id

    # ``results`` is non-empty here (guarded by the ``rows`` check above),
    # so the ratios below never divide by zero.
    total = len(results)
    spam_count = sum(1 for item in results if item["prediction"] == "spam")
    blocked_count = sum(1 for item in results if item["blocked_by_threshold"])
    ham_count = total - spam_count

    return ok(
        {
            "items": results,
            "summary": {
                "total": total,
                "spam_count": spam_count,
                "ham_count": ham_count,
                "blocked_count": blocked_count,
                "spam_ratio": round(spam_count / total, 4),
                "blocked_ratio": round(blocked_count / total, 4),
                "threshold": threshold,
            },
        },
        "批量识别完成",
    )
|
|
|
|
|
|
@spam_bp.get("/history")
@jwt_required()
def my_history():
    """Return the current user's prediction log, newest first, paginated.

    Query parameters:
        page: 1-based page number, default 1 (values below 1 become 1).
        page_size: items per page, default 20, clamped to 1..100.
    Non-numeric values fall back to the defaults instead of raising.
    """
    user = current_user()
    if not user:
        return fail("用户不存在", 404)

    # ``type=int`` makes werkzeug return the default on unparsable input
    # (e.g. ``?page=abc`` or ``?page=``). A bare ``int(...)`` raised
    # ValueError there and produced a 500.
    page = max(request.args.get("page", 1, type=int), 1)
    page_size = min(max(request.args.get("page_size", 20, type=int), 1), 100)

    pagination = (
        SpamPredictionLog.query.filter_by(user_id=user.id)
        .order_by(SpamPredictionLog.id.desc())
        .paginate(page=page, per_page=page_size, error_out=False)
    )

    return ok(
        {
            "items": [item.to_dict() for item in pagination.items],
            "total": pagination.total,
            "page": page,
            "page_size": page_size,
        }
    )
|
|
|
|
|
|
@spam_bp.delete("/history/<int:log_id>")
@jwt_required()
def delete_history(log_id: int):
    """Delete one of the current user's prediction-log entries.

    The lookup is restricted to rows owned by the caller. Another user's
    ``log_id`` therefore yields the same 404 as a non-existent one.
    """
    owner = current_user()
    if not owner:
        return fail("用户不存在", 404)

    entry = SpamPredictionLog.query.filter(
        SpamPredictionLog.id == log_id,
        SpamPredictionLog.user_id == owner.id,
    ).first()
    if not entry:
        return fail("记录不存在", 404)

    db.session.delete(entry)
    db.session.commit()
    return ok({}, "记录已删除")
|
|
|
|
|
|
@spam_bp.post("/feedback")
@jwt_required()
def save_feedback():
    """Record a user-labelled text as an active training sample.

    Expects ``{"text": "...", "label": "..."}``. The sample is stored with
    ``source="feedback"`` and attributed to the current user.

    Returns:
        404 for an unknown user.
        400 when the text is not a string of at least 2 characters, or when
        ``normalize_label`` rejects the label.
        Otherwise, the serialized sample.
    """
    user = current_user()
    if not user:
        return fail("用户不存在", 404)

    payload = request.get_json(silent=True) or {}
    # A JSON array/scalar body has no ``.get``; treat it like an empty object.
    if not isinstance(payload, dict):
        payload = {}
    raw_text = payload.get("text")
    # Non-string values have no ``.strip()``; treat them as invalid text.
    text = raw_text.strip() if isinstance(raw_text, str) else ""
    label = NaiveBayesSpamClassifier.normalize_label(payload.get("label"))

    if len(text) < 2:
        return fail("文本至少2个字符", 400)
    if not label:
        return fail("label 必须是 spam 或 ham", 400)

    row = SpamTrainingSample(text=text, label=label, source="feedback", created_by=user.id, is_active=True)
    db.session.add(row)
    db.session.commit()
    return ok(row.to_dict(), "反馈样本已记录")
|
|
|
|
|
|
@spam_bp.get("/model/info")
@jwt_required(optional=True)
def model_info():
    """Describe the persisted model and the current spam threshold.

    A JWT is optional. The endpoint calls ``load()`` on a classifier bound to
    ``NB_MODEL_PATH``, then returns its ``model_info()`` dict with a
    ``threshold`` key added.
    """
    classifier = _classifier()
    classifier.load()
    details = classifier.model_info()
    details["threshold"] = _threshold()
    return ok(details)
|
|
|
|
|
|
@spam_bp.post("/train")
@admin_required
def train_model():
    """Retrain the model from all active training samples (admin only).

    Responds with whatever metadata ``NaiveBayesSpamClassifier.train``
    returns.
    """
    classifier = _classifier()
    training_set = _active_samples()
    return ok(classifier.train(training_set), "模型训练完成")
|
|
|
|
|
|
@spam_bp.get("/samples")
@admin_required
def list_samples():
    """List training samples, newest first, with optional filters (admin only).

    Query parameters:
        keyword: substring the sample text must contain (matched literally).
        label: ``spam`` or ``ham`` (case-insensitive); other values are ignored.
        page: 1-based page number, default 1 (values below 1 become 1).
        page_size: items per page, default 20, clamped to 1..100.
    Non-numeric ``page``/``page_size`` fall back to their defaults.
    """
    keyword = (request.args.get("keyword") or "").strip()
    label = (request.args.get("label") or "").strip().lower()
    # ``type=int`` returns the default on unparsable input instead of the
    # ValueError (500) a bare ``int(...)`` raised.
    page = max(request.args.get("page", 1, type=int), 1)
    page_size = min(max(request.args.get("page_size", 20, type=int), 1), 100)

    query = SpamTrainingSample.query
    if keyword:
        # ``autoescape=True`` escapes ``%`` and ``_`` in the user's keyword so
        # they match literally instead of acting as LIKE wildcards.
        query = query.filter(SpamTrainingSample.text.contains(keyword, autoescape=True))
    if label in {"spam", "ham"}:
        query = query.filter(SpamTrainingSample.label == label)

    pagination = query.order_by(SpamTrainingSample.id.desc()).paginate(
        page=page, per_page=page_size, error_out=False
    )
    return ok(
        {
            "items": [item.to_dict() for item in pagination.items],
            "total": pagination.total,
            "page": page,
            "page_size": page_size,
        }
    )
|
|
|
|
|
|
@spam_bp.post("/samples")
@admin_required
def create_sample():
    """Create an active training sample from ``{"text", "label"}`` (admin only).

    The sample is stored with ``source="import"``. It is attributed to the
    current user when one resolves, otherwise ``created_by`` is ``None``.

    Returns:
        400 when the text is not a string of at least 2 characters, or when
        ``normalize_label`` rejects the label.
        Otherwise, the serialized sample.
    """
    payload = request.get_json(silent=True) or {}
    # A JSON array/scalar body has no ``.get``; treat it like an empty object.
    if not isinstance(payload, dict):
        payload = {}
    raw_text = payload.get("text")
    # Non-string values have no ``.strip()``; treat them as invalid text.
    text = raw_text.strip() if isinstance(raw_text, str) else ""
    label = NaiveBayesSpamClassifier.normalize_label(payload.get("label"))

    if len(text) < 2:
        return fail("文本至少2个字符", 400)
    if not label:
        return fail("label 必须是 spam 或 ham", 400)

    user = current_user()
    row = SpamTrainingSample(text=text, label=label, source="import", created_by=user.id if user else None, is_active=True)
    db.session.add(row)
    db.session.commit()
    return ok(row.to_dict(), "样本创建成功")
|
|
|
|
|
|
@spam_bp.put("/samples/<int:sample_id>")
@admin_required
def update_sample(sample_id: int):
    """Partially update a training sample (admin only).

    Any of ``text``, ``label`` and ``is_active`` may be supplied; omitted keys
    leave the stored value unchanged. Every supplied field is validated before
    any of them is written. A 400 therefore never leaves the row
    half-modified in the session.

    Returns:
        404 when the sample does not exist.
        400 for an invalid ``text`` or ``label``.
        Otherwise, the updated sample.
    """
    row = SpamTrainingSample.query.get(sample_id)
    if not row:
        return fail("样本不存在", 404)

    payload = request.get_json(silent=True) or {}
    # A JSON array/scalar body has no ``.get``; treat it like an empty object.
    if not isinstance(payload, dict):
        payload = {}

    # Collect validated values first; apply them only once everything passed.
    updates = {}
    if "text" in payload:
        raw_text = payload.get("text")
        # Non-string values have no ``.strip()``; treat them as invalid text.
        text = raw_text.strip() if isinstance(raw_text, str) else ""
        if len(text) < 2:
            return fail("文本至少2个字符", 400)
        updates["text"] = text
    if "label" in payload:
        label = NaiveBayesSpamClassifier.normalize_label(payload.get("label"))
        if not label:
            return fail("label 必须是 spam 或 ham", 400)
        updates["label"] = label
    if "is_active" in payload:
        updates["is_active"] = bool(payload.get("is_active"))

    for field, value in updates.items():
        setattr(row, field, value)
    db.session.commit()
    return ok(row.to_dict(), "样本更新成功")
|
|
|
|
|
|
@spam_bp.delete("/samples/<int:sample_id>")
@admin_required
def delete_sample(sample_id: int):
    """Permanently remove a training sample (admin only); 404 if it is absent."""
    sample = SpamTrainingSample.query.get(sample_id)
    if sample:
        db.session.delete(sample)
        db.session.commit()
        return ok({}, "样本已删除")
    return fail("样本不存在", 404)
|
|
|
|
|
|
@spam_bp.post("/samples/import")
@admin_required
def import_samples():
    """Bulk upsert training samples, keyed by exact (stripped) text. Admin only.

    Expects ``{"items": [{"text", "label", "is_active"?, "source"?}, ...]}``.

    - Existing text: the row's ``label`` and ``is_active`` are updated, and
      ``source`` is replaced when a non-empty one is supplied.
    - New text: a row is created with ``source`` defaulting to ``"import"``
      and ``is_active`` defaulting to ``True``.
    - Skipped: entries that are not objects, whose text is not a string of at
      least 2 characters, or whose label ``normalize_label`` rejects. These
      are counted in ``skipped``.

    Returns:
        400 when ``items`` is not a non-empty list.
        Otherwise, the ``created`` / ``updated`` / ``skipped`` counts.
    """
    payload = request.get_json(silent=True) or {}
    # A JSON array/scalar body has no ``.get``; treat it like an empty object.
    if not isinstance(payload, dict):
        payload = {}
    items = payload.get("items") or []
    if not isinstance(items, list) or not items:
        return fail("items 必须是非空数组", 400)

    user = current_user()
    created = 0
    updated = 0
    skipped = 0

    for item in items:
        # Non-object entries have no ``.get`` and previously raised a 500.
        if not isinstance(item, dict):
            skipped += 1
            continue
        raw_text = item.get("text")
        # Non-string text has no ``.strip()``; treat it as invalid.
        text = raw_text.strip() if isinstance(raw_text, str) else ""
        label = NaiveBayesSpamClassifier.normalize_label(item.get("label"))
        if len(text) < 2 or not label:
            skipped += 1
            continue

        is_active = bool(item.get("is_active", True))
        source = item.get("source")
        # NOTE(review): rows added earlier in this loop are only found here if
        # the session autoflushes before queries (SQLAlchemy's default).
        # Confirm, since that is what turns in-batch duplicates into updates.
        row = SpamTrainingSample.query.filter_by(text=text).first()
        if row:
            row.label = label
            row.is_active = is_active
            row.source = source or row.source
            updated += 1
        else:
            db.session.add(
                SpamTrainingSample(
                    text=text,
                    label=label,
                    source=source or "import",
                    created_by=user.id if user else None,
                    is_active=is_active,
                )
            )
            created += 1

    db.session.commit()
    return ok({"created": created, "updated": updated, "skipped": skipped}, "样本导入完成")
|