1
This commit is contained in:
1
backend/app/routes/__init__.py
Normal file
1
backend/app/routes/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
|
||||
BIN
backend/app/routes/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
backend/app/routes/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
backend/app/routes/__pycache__/__init__.cpython-311.pyc
Normal file
BIN
backend/app/routes/__pycache__/__init__.cpython-311.pyc
Normal file
Binary file not shown.
BIN
backend/app/routes/__pycache__/admin_routes.cpython-310.pyc
Normal file
BIN
backend/app/routes/__pycache__/admin_routes.cpython-310.pyc
Normal file
Binary file not shown.
BIN
backend/app/routes/__pycache__/admin_routes.cpython-311.pyc
Normal file
BIN
backend/app/routes/__pycache__/admin_routes.cpython-311.pyc
Normal file
Binary file not shown.
BIN
backend/app/routes/__pycache__/auth_routes.cpython-310.pyc
Normal file
BIN
backend/app/routes/__pycache__/auth_routes.cpython-310.pyc
Normal file
Binary file not shown.
BIN
backend/app/routes/__pycache__/auth_routes.cpython-311.pyc
Normal file
BIN
backend/app/routes/__pycache__/auth_routes.cpython-311.pyc
Normal file
Binary file not shown.
BIN
backend/app/routes/__pycache__/content_routes.cpython-310.pyc
Normal file
BIN
backend/app/routes/__pycache__/content_routes.cpython-310.pyc
Normal file
Binary file not shown.
BIN
backend/app/routes/__pycache__/content_routes.cpython-311.pyc
Normal file
BIN
backend/app/routes/__pycache__/content_routes.cpython-311.pyc
Normal file
Binary file not shown.
BIN
backend/app/routes/__pycache__/diet_routes.cpython-311.pyc
Normal file
BIN
backend/app/routes/__pycache__/diet_routes.cpython-311.pyc
Normal file
Binary file not shown.
BIN
backend/app/routes/__pycache__/qa_routes.cpython-311.pyc
Normal file
BIN
backend/app/routes/__pycache__/qa_routes.cpython-311.pyc
Normal file
Binary file not shown.
BIN
backend/app/routes/__pycache__/recipe_routes.cpython-311.pyc
Normal file
BIN
backend/app/routes/__pycache__/recipe_routes.cpython-311.pyc
Normal file
Binary file not shown.
BIN
backend/app/routes/__pycache__/recommend_routes.cpython-311.pyc
Normal file
BIN
backend/app/routes/__pycache__/recommend_routes.cpython-311.pyc
Normal file
Binary file not shown.
BIN
backend/app/routes/__pycache__/spam_routes.cpython-310.pyc
Normal file
BIN
backend/app/routes/__pycache__/spam_routes.cpython-310.pyc
Normal file
Binary file not shown.
BIN
backend/app/routes/__pycache__/spam_routes.cpython-311.pyc
Normal file
BIN
backend/app/routes/__pycache__/spam_routes.cpython-311.pyc
Normal file
Binary file not shown.
BIN
backend/app/routes/__pycache__/user_routes.cpython-310.pyc
Normal file
BIN
backend/app/routes/__pycache__/user_routes.cpython-310.pyc
Normal file
Binary file not shown.
BIN
backend/app/routes/__pycache__/user_routes.cpython-311.pyc
Normal file
BIN
backend/app/routes/__pycache__/user_routes.cpython-311.pyc
Normal file
Binary file not shown.
429
backend/app/routes/admin_routes.py
Normal file
429
backend/app/routes/admin_routes.py
Normal file
@@ -0,0 +1,429 @@
|
||||
from collections import Counter
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from flask import Blueprint, current_app, request
|
||||
from sqlalchemy import func, or_
|
||||
|
||||
from app.extensions import db
|
||||
from app.ml.naive_bayes_classifier import NaiveBayesSpamClassifier
|
||||
from app.models import ContentPost, DetectionConfig, SpamPredictionLog, SpamTrainingSample, User
|
||||
from app.utils.auth import admin_required, current_user
|
||||
from app.utils.response import fail, ok
|
||||
|
||||
|
||||
admin_bp = Blueprint("admin", __name__)
|
||||
|
||||
|
||||
def _day_key(day_value) -> str:
|
||||
if hasattr(day_value, "isoformat"):
|
||||
return day_value.isoformat()
|
||||
return str(day_value)
|
||||
|
||||
|
||||
def _tokenize(text: str) -> list[str]:
|
||||
content = (text or "").strip()
|
||||
if len(content) <= 1:
|
||||
return []
|
||||
tokens = []
|
||||
for i in range(len(content) - 1):
|
||||
token = content[i : i + 2]
|
||||
if token.strip():
|
||||
tokens.append(token)
|
||||
return tokens
|
||||
|
||||
|
||||
def _get_or_create_config() -> DetectionConfig:
|
||||
cfg = DetectionConfig.query.order_by(DetectionConfig.id.asc()).first()
|
||||
if cfg:
|
||||
return cfg
|
||||
|
||||
cfg = DetectionConfig(spam_threshold=0.75)
|
||||
db.session.add(cfg)
|
||||
db.session.commit()
|
||||
return cfg
|
||||
|
||||
|
||||
def _serialize_post(item: ContentPost) -> dict:
|
||||
row = item.to_dict()
|
||||
row["username"] = item.author.username if item.author else ""
|
||||
row["nickname"] = item.author.nickname if item.author else ""
|
||||
row["recipient_username"] = item.recipient.username if item.recipient else ""
|
||||
row["recipient_nickname"] = item.recipient.nickname if item.recipient else ""
|
||||
row["reviewer_username"] = item.reviewer.username if item.reviewer else ""
|
||||
return row
|
||||
|
||||
|
||||
def _upsert_manual_sample(text: str, label: str, admin_id: int | None) -> None:
|
||||
existed = SpamTrainingSample.query.filter_by(text=text, label=label).first()
|
||||
if existed:
|
||||
existed.is_active = True
|
||||
existed.source = existed.source or "manual_review"
|
||||
return
|
||||
|
||||
row = SpamTrainingSample(
|
||||
text=text,
|
||||
label=label,
|
||||
source="manual_review",
|
||||
created_by=admin_id,
|
||||
is_active=True,
|
||||
)
|
||||
db.session.add(row)
|
||||
|
||||
|
||||
@admin_bp.get("/stats")
|
||||
@admin_required
|
||||
def stats():
|
||||
user_count = User.query.count()
|
||||
sample_count = SpamTrainingSample.query.count()
|
||||
predict_count = SpamPredictionLog.query.count()
|
||||
post_count = ContentPost.query.count()
|
||||
blocked_count = ContentPost.query.filter_by(status="blocked").count()
|
||||
published_count = ContentPost.query.filter_by(status="published").count()
|
||||
pending_appeal_count = ContentPost.query.filter_by(appeal_status="pending").count()
|
||||
|
||||
now = datetime.utcnow()
|
||||
week_ago = now - timedelta(days=6)
|
||||
|
||||
trend_rows = (
|
||||
db.session.query(func.date(ContentPost.created_at), func.count(ContentPost.id))
|
||||
.filter(ContentPost.created_at >= week_ago)
|
||||
.group_by(func.date(ContentPost.created_at))
|
||||
.all()
|
||||
)
|
||||
|
||||
blocked_7d_count = ContentPost.query.filter(ContentPost.created_at >= week_ago, ContentPost.status == "blocked").count() or 0
|
||||
total_7d_count = ContentPost.query.filter(ContentPost.created_at >= week_ago).count() or 0
|
||||
|
||||
day_map = {_day_key(day): int(count or 0) for day, count in trend_rows}
|
||||
trend = []
|
||||
today = now.date()
|
||||
for offset in range(6, -1, -1):
|
||||
day = today - timedelta(days=offset)
|
||||
key = day.isoformat()
|
||||
trend.append({"date": key, "label": day.strftime("%m-%d"), "post_count": day_map.get(key, 0)})
|
||||
|
||||
source_rows = (
|
||||
db.session.query(SpamTrainingSample.source, func.count(SpamTrainingSample.id))
|
||||
.group_by(SpamTrainingSample.source)
|
||||
.order_by(func.count(SpamTrainingSample.id).desc())
|
||||
.all()
|
||||
)
|
||||
source_dist = [{"name": (name or "unknown"), "value": int(value or 0)} for name, value in source_rows]
|
||||
|
||||
blocked_logs = (
|
||||
ContentPost.query.filter(ContentPost.created_at >= week_ago, ContentPost.status == "blocked")
|
||||
.order_by(ContentPost.id.desc())
|
||||
.limit(1000)
|
||||
.all()
|
||||
)
|
||||
|
||||
token_counter = Counter()
|
||||
for row in blocked_logs:
|
||||
token_counter.update(_tokenize(row.text))
|
||||
top_keywords = [{"token": token, "count": count} for token, count in token_counter.most_common(12)]
|
||||
|
||||
cfg = _get_or_create_config()
|
||||
|
||||
clf = NaiveBayesSpamClassifier(current_app.config["NB_MODEL_PATH"])
|
||||
clf.load()
|
||||
|
||||
return ok(
|
||||
{
|
||||
"user_count": user_count,
|
||||
"sample_count": sample_count,
|
||||
"predict_count": predict_count,
|
||||
"post_count": post_count,
|
||||
"blocked_count": blocked_count,
|
||||
"published_count": published_count,
|
||||
"pending_appeal_count": pending_appeal_count,
|
||||
"blocked_ratio_7d": round(blocked_7d_count / total_7d_count, 4) if total_7d_count else 0,
|
||||
"total_7d": total_7d_count,
|
||||
"trend_7d": trend,
|
||||
"source_distribution": source_dist,
|
||||
"top_keywords": top_keywords,
|
||||
"model_info": clf.model_info(),
|
||||
"threshold": cfg.to_dict(),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@admin_bp.get("/detection/threshold")
|
||||
@admin_required
|
||||
def get_threshold():
|
||||
return ok(_get_or_create_config().to_dict())
|
||||
|
||||
|
||||
@admin_bp.put("/detection/threshold")
|
||||
@admin_required
|
||||
def set_threshold():
|
||||
payload = request.get_json(silent=True) or {}
|
||||
try:
|
||||
threshold = float(payload.get("spam_threshold"))
|
||||
except Exception:
|
||||
return fail("spam_threshold 必须是数字", 400)
|
||||
|
||||
if threshold < 0.01 or threshold > 0.99:
|
||||
return fail("spam_threshold 必须在 0.01 到 0.99 之间", 400)
|
||||
|
||||
cfg = _get_or_create_config()
|
||||
admin = current_user()
|
||||
cfg.spam_threshold = threshold
|
||||
cfg.updated_by = admin.id if admin else None
|
||||
db.session.commit()
|
||||
return ok(cfg.to_dict(), "阈值更新成功")
|
||||
|
||||
|
||||
@admin_bp.get("/intercepts")
|
||||
@admin_required
|
||||
def list_intercepts():
|
||||
keyword = (request.args.get("keyword") or "").strip()
|
||||
status = (request.args.get("status") or "blocked").strip().lower()
|
||||
review_status = (request.args.get("review_status") or "").strip().lower()
|
||||
page = max(int(request.args.get("page", 1) or 1), 1)
|
||||
page_size = min(max(int(request.args.get("page_size", 20) or 20), 1), 100)
|
||||
|
||||
query = ContentPost.query
|
||||
if keyword:
|
||||
query = query.filter(ContentPost.text.like(f"%{keyword}%"))
|
||||
if status in {"blocked", "published"}:
|
||||
query = query.filter(ContentPost.status == status)
|
||||
if review_status in {"none", "pending", "confirmed_spam", "approved_ham"}:
|
||||
query = query.filter(ContentPost.manual_review_status == review_status)
|
||||
|
||||
pagination = query.order_by(ContentPost.id.desc()).paginate(page=page, per_page=page_size, error_out=False)
|
||||
return ok(
|
||||
{
|
||||
"items": [_serialize_post(item) for item in pagination.items],
|
||||
"total": pagination.total,
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@admin_bp.put("/intercepts/<int:post_id>/review")
|
||||
@admin_required
|
||||
def review_intercept(post_id: int):
|
||||
row = ContentPost.query.get(post_id)
|
||||
if not row:
|
||||
return fail("记录不存在", 404)
|
||||
|
||||
payload = request.get_json(silent=True) or {}
|
||||
decision = (payload.get("decision") or "").strip().lower()
|
||||
note = (payload.get("note") or "").strip()
|
||||
if decision not in {"spam", "ham"}:
|
||||
return fail("decision 必须是 spam 或 ham", 400)
|
||||
|
||||
admin = current_user()
|
||||
now = datetime.utcnow()
|
||||
|
||||
row.manual_review_by = admin.id if admin else None
|
||||
row.manual_review_note = note
|
||||
row.manual_review_at = now
|
||||
|
||||
if decision == "spam":
|
||||
row.status = "blocked"
|
||||
row.prediction = "spam"
|
||||
row.manual_review_status = "confirmed_spam"
|
||||
if row.appeal_status == "pending":
|
||||
row.appeal_status = "rejected"
|
||||
row.appeal_admin_note = note or "人工复核确认为垃圾信息"
|
||||
row.appeal_processed_at = now
|
||||
row.appeal_processed_by = admin.id if admin else None
|
||||
_upsert_manual_sample(row.text, "spam", admin.id if admin else None)
|
||||
else:
|
||||
row.status = "published"
|
||||
row.prediction = "ham"
|
||||
row.manual_review_status = "approved_ham"
|
||||
if row.appeal_status == "pending":
|
||||
row.appeal_status = "approved"
|
||||
row.appeal_admin_note = note or "人工复核后解除拦截"
|
||||
row.appeal_processed_at = now
|
||||
row.appeal_processed_by = admin.id if admin else None
|
||||
_upsert_manual_sample(row.text, "ham", admin.id if admin else None)
|
||||
|
||||
db.session.commit()
|
||||
return ok(_serialize_post(row), "人工复核完成")
|
||||
|
||||
|
||||
@admin_bp.get("/appeals")
|
||||
@admin_required
|
||||
def list_appeals():
|
||||
keyword = (request.args.get("keyword") or "").strip()
|
||||
status = (request.args.get("status") or "pending").strip().lower()
|
||||
page = max(int(request.args.get("page", 1) or 1), 1)
|
||||
page_size = min(max(int(request.args.get("page_size", 20) or 20), 1), 100)
|
||||
|
||||
query = ContentPost.query.filter(ContentPost.appeal_status != "none")
|
||||
if keyword:
|
||||
query = query.filter(
|
||||
or_(
|
||||
ContentPost.text.like(f"%{keyword}%"),
|
||||
ContentPost.appeal_reason.like(f"%{keyword}%"),
|
||||
ContentPost.appeal_admin_note.like(f"%{keyword}%"),
|
||||
)
|
||||
)
|
||||
if status in {"pending", "approved", "rejected"}:
|
||||
query = query.filter(ContentPost.appeal_status == status)
|
||||
|
||||
pagination = query.order_by(ContentPost.id.desc()).paginate(page=page, per_page=page_size, error_out=False)
|
||||
return ok(
|
||||
{
|
||||
"items": [_serialize_post(item) for item in pagination.items],
|
||||
"total": pagination.total,
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@admin_bp.put("/appeals/<int:post_id>/process")
|
||||
@admin_required
|
||||
def process_appeal(post_id: int):
|
||||
row = ContentPost.query.get(post_id)
|
||||
if not row:
|
||||
return fail("记录不存在", 404)
|
||||
if row.appeal_status != "pending":
|
||||
return fail("该申诉不在待处理状态", 400)
|
||||
|
||||
payload = request.get_json(silent=True) or {}
|
||||
action = (payload.get("action") or "").strip().lower()
|
||||
note = (payload.get("note") or "").strip()
|
||||
if action not in {"approve", "reject"}:
|
||||
return fail("action 必须是 approve 或 reject", 400)
|
||||
|
||||
admin = current_user()
|
||||
now = datetime.utcnow()
|
||||
|
||||
row.appeal_status = "approved" if action == "approve" else "rejected"
|
||||
row.appeal_admin_note = note
|
||||
row.appeal_processed_at = now
|
||||
row.appeal_processed_by = admin.id if admin else None
|
||||
|
||||
row.manual_review_by = admin.id if admin else None
|
||||
row.manual_review_note = note
|
||||
row.manual_review_at = now
|
||||
|
||||
if action == "approve":
|
||||
row.status = "published"
|
||||
row.prediction = "ham"
|
||||
row.manual_review_status = "approved_ham"
|
||||
_upsert_manual_sample(row.text, "ham", admin.id if admin else None)
|
||||
else:
|
||||
row.status = "blocked"
|
||||
row.prediction = "spam"
|
||||
row.manual_review_status = "confirmed_spam"
|
||||
_upsert_manual_sample(row.text, "spam", admin.id if admin else None)
|
||||
|
||||
db.session.commit()
|
||||
return ok(_serialize_post(row), "申诉处理完成")
|
||||
|
||||
|
||||
@admin_bp.get("/users")
|
||||
@admin_required
|
||||
def list_users():
|
||||
keyword = (request.args.get("keyword") or "").strip()
|
||||
page = max(int(request.args.get("page", 1) or 1), 1)
|
||||
page_size = min(max(int(request.args.get("page_size", 20) or 20), 1), 100)
|
||||
|
||||
query = User.query
|
||||
if keyword:
|
||||
query = query.filter(User.username.like(f"%{keyword}%") | User.nickname.like(f"%{keyword}%"))
|
||||
|
||||
pagination = query.order_by(User.id.desc()).paginate(page=page, per_page=page_size, error_out=False)
|
||||
return ok(
|
||||
{
|
||||
"items": [item.to_dict() for item in pagination.items],
|
||||
"total": pagination.total,
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@admin_bp.post("/users/import")
|
||||
@admin_required
|
||||
def import_users():
|
||||
payload = request.get_json(silent=True) or {}
|
||||
items = payload.get("items") or []
|
||||
if not isinstance(items, list) or not items:
|
||||
return fail("items 必须是非空数组", 400)
|
||||
|
||||
created = 0
|
||||
updated = 0
|
||||
|
||||
for row in items:
|
||||
username = (row.get("username") or "").strip()
|
||||
if len(username) < 3:
|
||||
continue
|
||||
|
||||
user = User.query.filter_by(username=username).first()
|
||||
if not user:
|
||||
user = User(
|
||||
username=username,
|
||||
nickname=(row.get("nickname") or username).strip(),
|
||||
company=(row.get("company") or "").strip(),
|
||||
title=(row.get("title") or "").strip(),
|
||||
phone=(row.get("phone") or "").strip(),
|
||||
is_admin=bool(row.get("is_admin", False)),
|
||||
)
|
||||
user.set_password(row.get("password") or "123456")
|
||||
db.session.add(user)
|
||||
created += 1
|
||||
continue
|
||||
|
||||
user.nickname = (row.get("nickname") or user.nickname).strip()
|
||||
user.company = (row.get("company") or user.company).strip()
|
||||
user.title = (row.get("title") or user.title).strip()
|
||||
user.phone = (row.get("phone") or user.phone).strip()
|
||||
if "is_admin" in row:
|
||||
user.is_admin = bool(row.get("is_admin"))
|
||||
if row.get("password"):
|
||||
user.set_password(row["password"])
|
||||
updated += 1
|
||||
|
||||
db.session.commit()
|
||||
return ok({"created": created, "updated": updated}, "用户导入完成")
|
||||
|
||||
|
||||
@admin_bp.put("/users/<int:user_id>")
|
||||
@admin_required
|
||||
def update_user(user_id: int):
|
||||
user = User.query.get(user_id)
|
||||
if not user:
|
||||
return fail("用户不存在", 404)
|
||||
|
||||
payload = request.get_json(silent=True) or {}
|
||||
if "nickname" in payload:
|
||||
user.nickname = (payload.get("nickname") or user.nickname).strip()
|
||||
if "company" in payload:
|
||||
user.company = (payload.get("company") or "").strip()
|
||||
if "title" in payload:
|
||||
user.title = (payload.get("title") or "").strip()
|
||||
if "phone" in payload:
|
||||
user.phone = (payload.get("phone") or "").strip()
|
||||
if "is_admin" in payload:
|
||||
user.is_admin = bool(payload.get("is_admin"))
|
||||
if payload.get("password"):
|
||||
if len(payload["password"]) < 6:
|
||||
return fail("密码至少6位", 400)
|
||||
user.set_password(payload["password"])
|
||||
|
||||
db.session.commit()
|
||||
return ok(user.to_dict(), "用户更新成功")
|
||||
|
||||
|
||||
@admin_bp.delete("/users/<int:user_id>")
|
||||
@admin_required
|
||||
def delete_user(user_id: int):
|
||||
user = User.query.get(user_id)
|
||||
if not user:
|
||||
return fail("用户不存在", 404)
|
||||
|
||||
if user.is_admin and User.query.filter_by(is_admin=True).count() <= 1:
|
||||
return fail("至少保留一个管理员账号", 400)
|
||||
|
||||
db.session.delete(user)
|
||||
db.session.commit()
|
||||
return ok({}, "用户已删除")
|
||||
|
||||
67
backend/app/routes/auth_routes.py
Normal file
67
backend/app/routes/auth_routes.py
Normal file
@@ -0,0 +1,67 @@
|
||||
from flask import Blueprint, request
|
||||
from flask_jwt_extended import create_access_token, jwt_required
|
||||
|
||||
from app.extensions import db
|
||||
from app.models import User
|
||||
from app.utils.auth import current_user
|
||||
from app.utils.response import fail, ok
|
||||
|
||||
|
||||
auth_bp = Blueprint("auth", __name__)
|
||||
|
||||
|
||||
@auth_bp.post("/register")
|
||||
def register():
|
||||
payload = request.get_json(silent=True) or {}
|
||||
username = (payload.get("username") or "").strip()
|
||||
password = payload.get("password") or ""
|
||||
nickname = (payload.get("nickname") or username).strip()
|
||||
|
||||
if len(username) < 3:
|
||||
return fail("用户名至少3位", 400)
|
||||
if len(password) < 6:
|
||||
return fail("密码至少6位", 400)
|
||||
if User.query.filter_by(username=username).first():
|
||||
return fail("用户名已存在", 409)
|
||||
|
||||
user = User(
|
||||
username=username,
|
||||
nickname=nickname or username,
|
||||
company=(payload.get("company") or "").strip(),
|
||||
title=(payload.get("title") or "").strip(),
|
||||
phone=(payload.get("phone") or "").strip(),
|
||||
is_admin=bool(payload.get("is_admin", False)),
|
||||
)
|
||||
user.set_password(password)
|
||||
|
||||
db.session.add(user)
|
||||
db.session.commit()
|
||||
|
||||
return ok(user.to_dict(), "注册成功")
|
||||
|
||||
|
||||
@auth_bp.post("/login")
|
||||
def login():
|
||||
payload = request.get_json(silent=True) or {}
|
||||
username = (payload.get("username") or "").strip()
|
||||
password = payload.get("password") or ""
|
||||
|
||||
user = User.query.filter_by(username=username).first()
|
||||
if not user or not user.check_password(password):
|
||||
return fail("用户名或密码错误", 401)
|
||||
|
||||
access_token = create_access_token(
|
||||
identity=str(user.id),
|
||||
additional_claims={"is_admin": bool(user.is_admin), "username": user.username},
|
||||
)
|
||||
|
||||
return ok({"token": access_token, "user": user.to_dict()}, "登录成功")
|
||||
|
||||
|
||||
@auth_bp.get("/me")
|
||||
@jwt_required()
|
||||
def me():
|
||||
user = current_user()
|
||||
if not user:
|
||||
return fail("用户不存在", 404)
|
||||
return ok(user.to_dict())
|
||||
360
backend/app/routes/content_routes.py
Normal file
360
backend/app/routes/content_routes.py
Normal file
@@ -0,0 +1,360 @@
|
||||
from datetime import datetime
|
||||
|
||||
from flask import Blueprint, current_app, request
|
||||
from flask_jwt_extended import jwt_required
|
||||
|
||||
from app.extensions import db
|
||||
from app.ml.naive_bayes_classifier import NaiveBayesSpamClassifier
|
||||
from app.models import ContentPost, DetectionConfig, SpamPredictionLog, SpamTrainingSample, User
|
||||
from app.utils.auth import current_user
|
||||
from app.utils.response import fail, ok
|
||||
|
||||
|
||||
content_bp = Blueprint("content", __name__)
|
||||
|
||||
|
||||
def _classifier() -> NaiveBayesSpamClassifier:
|
||||
return NaiveBayesSpamClassifier(current_app.config["NB_MODEL_PATH"])
|
||||
|
||||
|
||||
def _active_samples() -> list[dict]:
|
||||
rows = SpamTrainingSample.query.filter_by(is_active=True).order_by(SpamTrainingSample.id.asc()).all()
|
||||
return [{"text": row.text, "label": row.label} for row in rows]
|
||||
|
||||
|
||||
def _ensure_ready() -> NaiveBayesSpamClassifier:
|
||||
clf = _classifier()
|
||||
clf.ensure_ready(_active_samples())
|
||||
return clf
|
||||
|
||||
|
||||
def _get_config() -> DetectionConfig:
|
||||
cfg = DetectionConfig.query.order_by(DetectionConfig.id.asc()).first()
|
||||
if cfg:
|
||||
return cfg
|
||||
|
||||
cfg = DetectionConfig(spam_threshold=0.75)
|
||||
db.session.add(cfg)
|
||||
db.session.commit()
|
||||
return cfg
|
||||
|
||||
|
||||
def _serialize_post(row: ContentPost) -> dict:
|
||||
payload = row.to_dict()
|
||||
payload["username"] = row.author.username if row.author else ""
|
||||
payload["nickname"] = row.author.nickname if row.author else ""
|
||||
payload["recipient_username"] = row.recipient.username if row.recipient else ""
|
||||
payload["recipient_nickname"] = row.recipient.nickname if row.recipient else ""
|
||||
payload["reviewer_username"] = row.reviewer.username if row.reviewer else ""
|
||||
return payload
|
||||
|
||||
|
||||
def _resolve_visibility(value: str) -> str:
|
||||
key = (value or "public").strip().lower()
|
||||
return key if key in {"public", "private", "direct"} else "public"
|
||||
|
||||
|
||||
def _resolve_recipient(payload: dict, visibility: str, current_user_id: int):
|
||||
if visibility != "direct":
|
||||
return None, None
|
||||
|
||||
recipient = None
|
||||
raw_id = payload.get("recipient_user_id")
|
||||
username = (payload.get("recipient_username") or "").strip()
|
||||
|
||||
if raw_id is not None and str(raw_id).strip() != "":
|
||||
try:
|
||||
recipient = User.query.get(int(raw_id))
|
||||
except Exception:
|
||||
return None, "recipient_user_id 无效"
|
||||
elif username:
|
||||
recipient = User.query.filter_by(username=username).first()
|
||||
|
||||
if not recipient:
|
||||
return None, "私信发布必须指定有效接收人"
|
||||
if recipient.id == current_user_id:
|
||||
return None, "不能给自己发送私信"
|
||||
return recipient, None
|
||||
|
||||
|
||||
def _predict_and_decide(text: str) -> tuple[dict, float, bool]:
|
||||
clf = _ensure_ready()
|
||||
result = clf.predict(text)
|
||||
threshold = float(_get_config().spam_threshold)
|
||||
blocked = float(result["spam_probability"]) >= threshold
|
||||
return result, threshold, blocked
|
||||
|
||||
|
||||
@content_bp.post("/publish")
|
||||
@jwt_required()
|
||||
def publish_text():
|
||||
user = current_user()
|
||||
if not user:
|
||||
return fail("用户不存在", 404)
|
||||
|
||||
payload = request.get_json(silent=True) or {}
|
||||
text = (payload.get("text") or "").strip()
|
||||
visibility = _resolve_visibility(payload.get("visibility"))
|
||||
|
||||
if len(text) < 2:
|
||||
return fail("发布文本至少2个字符", 400)
|
||||
|
||||
recipient, err = _resolve_recipient(payload, visibility, user.id)
|
||||
if err:
|
||||
return fail(err, 400)
|
||||
|
||||
result, threshold, blocked = _predict_and_decide(text)
|
||||
|
||||
post = ContentPost(
|
||||
user_id=user.id,
|
||||
recipient_user_id=recipient.id if recipient else None,
|
||||
text=result["text"],
|
||||
visibility=visibility,
|
||||
status="blocked" if blocked else "published",
|
||||
prediction=result["prediction"],
|
||||
spam_probability=result["spam_probability"],
|
||||
ham_probability=result["ham_probability"],
|
||||
confidence=result["confidence"],
|
||||
threshold=threshold,
|
||||
reason_tokens=result["reason_tokens"],
|
||||
model_version=result.get("model_version", ""),
|
||||
manual_review_status="pending" if blocked else "none",
|
||||
)
|
||||
|
||||
detect_log = SpamPredictionLog(
|
||||
user_id=user.id,
|
||||
text=result["text"],
|
||||
prediction=result["prediction"],
|
||||
spam_probability=result["spam_probability"],
|
||||
ham_probability=result["ham_probability"],
|
||||
confidence=result["confidence"],
|
||||
reason_tokens=result["reason_tokens"],
|
||||
model_version=result.get("model_version", ""),
|
||||
)
|
||||
|
||||
db.session.add(post)
|
||||
db.session.add(detect_log)
|
||||
db.session.commit()
|
||||
|
||||
feedback = "发布成功" if not blocked else "疑似垃圾信息,系统已拦截,可提交申诉"
|
||||
return ok(
|
||||
{
|
||||
"publish_allowed": not blocked,
|
||||
"action": "published" if not blocked else "blocked",
|
||||
"feedback": feedback,
|
||||
"post": _serialize_post(post),
|
||||
"detect": result,
|
||||
},
|
||||
feedback,
|
||||
)
|
||||
|
||||
|
||||
@content_bp.put("/posts/<int:post_id>")
|
||||
@jwt_required()
|
||||
def edit_post(post_id: int):
|
||||
user = current_user()
|
||||
if not user:
|
||||
return fail("用户不存在", 404)
|
||||
|
||||
post = ContentPost.query.filter_by(id=post_id, user_id=user.id).first()
|
||||
if not post:
|
||||
return fail("发布记录不存在", 404)
|
||||
|
||||
payload = request.get_json(silent=True) or {}
|
||||
text = (payload.get("text") or post.text).strip()
|
||||
visibility = _resolve_visibility(payload.get("visibility") or post.visibility)
|
||||
|
||||
if len(text) < 2:
|
||||
return fail("发布文本至少2个字符", 400)
|
||||
|
||||
recipient, err = _resolve_recipient(payload, visibility, user.id)
|
||||
if err:
|
||||
return fail(err, 400)
|
||||
|
||||
result, threshold, blocked = _predict_and_decide(text)
|
||||
|
||||
post.text = result["text"]
|
||||
post.visibility = visibility
|
||||
post.recipient_user_id = recipient.id if recipient else None
|
||||
post.status = "blocked" if blocked else "published"
|
||||
post.prediction = result["prediction"]
|
||||
post.spam_probability = result["spam_probability"]
|
||||
post.ham_probability = result["ham_probability"]
|
||||
post.confidence = result["confidence"]
|
||||
post.threshold = threshold
|
||||
post.reason_tokens = result["reason_tokens"]
|
||||
post.model_version = result.get("model_version", "")
|
||||
|
||||
post.manual_review_status = "pending" if blocked else "none"
|
||||
post.manual_review_by = None
|
||||
post.manual_review_note = ""
|
||||
post.manual_review_at = None
|
||||
post.appeal_status = "none"
|
||||
post.appeal_reason = ""
|
||||
post.appeal_admin_note = ""
|
||||
post.appeal_submitted_at = None
|
||||
post.appeal_processed_at = None
|
||||
post.appeal_processed_by = None
|
||||
|
||||
db.session.commit()
|
||||
|
||||
feedback = "更新并重新发布成功" if not blocked else "更新后触发拦截,可提交申诉"
|
||||
return ok(
|
||||
{
|
||||
"publish_allowed": not blocked,
|
||||
"action": "published" if not blocked else "blocked",
|
||||
"feedback": feedback,
|
||||
"post": _serialize_post(post),
|
||||
"detect": result,
|
||||
},
|
||||
feedback,
|
||||
)
|
||||
|
||||
|
||||
@content_bp.get("/posts/history")
|
||||
@jwt_required()
|
||||
def my_posts():
|
||||
user = current_user()
|
||||
if not user:
|
||||
return fail("用户不存在", 404)
|
||||
|
||||
status = (request.args.get("status") or "").strip().lower()
|
||||
visibility = (request.args.get("visibility") or "").strip().lower()
|
||||
page = max(int(request.args.get("page", 1) or 1), 1)
|
||||
page_size = min(max(int(request.args.get("page_size", 20) or 20), 1), 100)
|
||||
|
||||
query = ContentPost.query.filter_by(user_id=user.id)
|
||||
if status in {"published", "blocked"}:
|
||||
query = query.filter(ContentPost.status == status)
|
||||
if visibility in {"public", "private", "direct"}:
|
||||
query = query.filter(ContentPost.visibility == visibility)
|
||||
|
||||
pagination = query.order_by(ContentPost.id.desc()).paginate(page=page, per_page=page_size, error_out=False)
|
||||
return ok(
|
||||
{
|
||||
"items": [_serialize_post(item) for item in pagination.items],
|
||||
"total": pagination.total,
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@content_bp.get("/posts/inbox")
|
||||
@jwt_required()
|
||||
def my_inbox():
|
||||
user = current_user()
|
||||
if not user:
|
||||
return fail("用户不存在", 404)
|
||||
|
||||
page = max(int(request.args.get("page", 1) or 1), 1)
|
||||
page_size = min(max(int(request.args.get("page_size", 20) or 20), 1), 100)
|
||||
|
||||
pagination = (
|
||||
ContentPost.query.filter_by(recipient_user_id=user.id, visibility="direct", status="published")
|
||||
.order_by(ContentPost.id.desc())
|
||||
.paginate(page=page, per_page=page_size, error_out=False)
|
||||
)
|
||||
return ok(
|
||||
{
|
||||
"items": [_serialize_post(item) for item in pagination.items],
|
||||
"total": pagination.total,
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@content_bp.delete("/posts/<int:post_id>")
|
||||
@jwt_required()
|
||||
def delete_post(post_id: int):
|
||||
user = current_user()
|
||||
if not user:
|
||||
return fail("用户不存在", 404)
|
||||
|
||||
row = ContentPost.query.filter_by(id=post_id, user_id=user.id).first()
|
||||
if not row:
|
||||
return fail("记录不存在", 404)
|
||||
|
||||
db.session.delete(row)
|
||||
db.session.commit()
|
||||
return ok({}, "记录已删除")
|
||||
|
||||
|
||||
@content_bp.post("/posts/<int:post_id>/appeal")
|
||||
@jwt_required()
|
||||
def submit_appeal(post_id: int):
|
||||
user = current_user()
|
||||
if not user:
|
||||
return fail("用户不存在", 404)
|
||||
|
||||
post = ContentPost.query.filter_by(id=post_id, user_id=user.id).first()
|
||||
if not post:
|
||||
return fail("发布记录不存在", 404)
|
||||
if post.status != "blocked":
|
||||
return fail("仅被拦截的信息可申诉", 400)
|
||||
|
||||
payload = request.get_json(silent=True) or {}
|
||||
reason = (payload.get("reason") or "").strip()
|
||||
if len(reason) < 2:
|
||||
return fail("申诉理由至少2个字符", 400)
|
||||
if post.appeal_status == "pending":
|
||||
return fail("该记录已在申诉处理中", 400)
|
||||
|
||||
post.appeal_status = "pending"
|
||||
post.appeal_reason = reason
|
||||
post.appeal_submitted_at = datetime.utcnow()
|
||||
post.appeal_admin_note = ""
|
||||
post.appeal_processed_at = None
|
||||
post.appeal_processed_by = None
|
||||
post.manual_review_status = "pending"
|
||||
db.session.commit()
|
||||
|
||||
return ok(_serialize_post(post), "申诉提交成功")
|
||||
|
||||
|
||||
@content_bp.get("/appeals/my")
|
||||
@jwt_required()
|
||||
def my_appeals():
|
||||
user = current_user()
|
||||
if not user:
|
||||
return fail("用户不存在", 404)
|
||||
|
||||
page = max(int(request.args.get("page", 1) or 1), 1)
|
||||
page_size = min(max(int(request.args.get("page_size", 20) or 20), 1), 100)
|
||||
|
||||
pagination = (
|
||||
ContentPost.query.filter(ContentPost.user_id == user.id, ContentPost.appeal_status != "none")
|
||||
.order_by(ContentPost.id.desc())
|
||||
.paginate(page=page, per_page=page_size, error_out=False)
|
||||
)
|
||||
|
||||
return ok(
|
||||
{
|
||||
"items": [_serialize_post(item) for item in pagination.items],
|
||||
"total": pagination.total,
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@content_bp.get("/posts/public")
|
||||
@jwt_required(optional=True)
|
||||
def public_feed():
|
||||
page = max(int(request.args.get("page", 1) or 1), 1)
|
||||
page_size = min(max(int(request.args.get("page_size", 20) or 20), 1), 100)
|
||||
|
||||
pagination = (
|
||||
ContentPost.query.filter_by(visibility="public", status="published")
|
||||
.order_by(ContentPost.id.desc())
|
||||
.paginate(page=page, per_page=page_size, error_out=False)
|
||||
)
|
||||
return ok(
|
||||
{
|
||||
"items": [_serialize_post(item) for item in pagination.items],
|
||||
"total": pagination.total,
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
}
|
||||
)
|
||||
121
backend/app/routes/diet_routes.py
Normal file
121
backend/app/routes/diet_routes.py
Normal file
@@ -0,0 +1,121 @@
|
||||
from datetime import datetime
|
||||
|
||||
from flask import Blueprint, request
|
||||
from flask_jwt_extended import jwt_required
|
||||
|
||||
from app.extensions import db
|
||||
from app.models import DietStatus
|
||||
from app.utils.auth import current_user
|
||||
from app.utils.response import fail, ok
|
||||
|
||||
|
||||
diet_bp = Blueprint("diet", __name__)
|
||||
|
||||
|
||||
def _parse_status_payload(payload: dict) -> dict:
|
||||
return {
|
||||
"weight": float(payload.get("weight", 0)),
|
||||
"body_fat": float(payload.get("body_fat", 0)),
|
||||
"exercise_kcal": float(payload.get("exercise_kcal", 0)),
|
||||
"intake_kcal": float(payload.get("intake_kcal", 0)),
|
||||
"sleep_hours": float(payload.get("sleep_hours", 7)),
|
||||
"note": (payload.get("note") or "").strip(),
|
||||
}
|
||||
|
||||
|
||||
@diet_bp.post("/status")
|
||||
@jwt_required()
|
||||
def create_status():
|
||||
user = current_user()
|
||||
if not user:
|
||||
return fail("用户不存在", 404)
|
||||
|
||||
payload = request.get_json(silent=True) or {}
|
||||
data = _parse_status_payload(payload)
|
||||
|
||||
if data["weight"] <= 0 or data["body_fat"] <= 0:
|
||||
return fail("体重和体脂率必须大于0", 400)
|
||||
|
||||
status = DietStatus(user_id=user.id, **data)
|
||||
if payload.get("recorded_at"):
|
||||
status.recorded_at = datetime.fromisoformat(payload["recorded_at"])
|
||||
|
||||
db.session.add(status)
|
||||
db.session.commit()
|
||||
return ok(status.to_dict(), "饮食状态记录成功")
|
||||
|
||||
|
||||
@diet_bp.put("/status/latest")
|
||||
@jwt_required()
|
||||
def update_latest_status():
|
||||
user = current_user()
|
||||
if not user:
|
||||
return fail("用户不存在", 404)
|
||||
|
||||
latest = (
|
||||
DietStatus.query.filter_by(user_id=user.id)
|
||||
.order_by(DietStatus.recorded_at.desc(), DietStatus.id.desc())
|
||||
.first()
|
||||
)
|
||||
if not latest:
|
||||
return fail("暂无可编辑的饮食状态", 404)
|
||||
|
||||
payload = request.get_json(silent=True) or {}
|
||||
data = _parse_status_payload(payload)
|
||||
for key, value in data.items():
|
||||
setattr(latest, key, value)
|
||||
|
||||
db.session.commit()
|
||||
return ok(latest.to_dict(), "最新饮食状态已更新")
|
||||
|
||||
|
||||
@diet_bp.get("/status/latest")
|
||||
@jwt_required()
|
||||
def get_latest_status():
|
||||
user = current_user()
|
||||
if not user:
|
||||
return fail("用户不存在", 404)
|
||||
|
||||
latest = (
|
||||
DietStatus.query.filter_by(user_id=user.id)
|
||||
.order_by(DietStatus.recorded_at.desc(), DietStatus.id.desc())
|
||||
.first()
|
||||
)
|
||||
if not latest:
|
||||
return ok({})
|
||||
|
||||
return ok(latest.to_dict())
|
||||
|
||||
|
||||
@diet_bp.get("/history")
|
||||
@jwt_required()
|
||||
def history():
|
||||
user = current_user()
|
||||
if not user:
|
||||
return fail("用户不存在", 404)
|
||||
|
||||
limit = min(max(int(request.args.get("limit", 30) or 30), 1), 200)
|
||||
rows = (
|
||||
DietStatus.query.filter_by(user_id=user.id)
|
||||
.order_by(DietStatus.recorded_at.desc(), DietStatus.id.desc())
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
|
||||
return ok([item.to_dict() for item in rows])
|
||||
|
||||
|
||||
@diet_bp.delete("/history/<int:status_id>")
|
||||
@jwt_required()
|
||||
def delete_history(status_id: int):
|
||||
user = current_user()
|
||||
if not user:
|
||||
return fail("用户不存在", 404)
|
||||
|
||||
row = DietStatus.query.filter_by(id=status_id, user_id=user.id).first()
|
||||
if not row:
|
||||
return fail("记录不存在", 404)
|
||||
|
||||
db.session.delete(row)
|
||||
db.session.commit()
|
||||
return ok({}, "记录已删除")
|
||||
177
backend/app/routes/qa_routes.py
Normal file
177
backend/app/routes/qa_routes.py
Normal file
@@ -0,0 +1,177 @@
|
||||
from flask import Blueprint, current_app, request
|
||||
from flask_jwt_extended import jwt_required
|
||||
|
||||
from app.rag.llm_client import LLMConfigManager, ask_dify, ask_fastgpt
|
||||
from app.rag.local_retriever import LocalKnowledgeRetriever
|
||||
from app.utils.auth import admin_required, current_user
|
||||
from app.utils.response import ok
|
||||
|
||||
|
||||
qa_bp = Blueprint("qa", __name__)
|
||||
|
||||
|
||||
def _retriever() -> LocalKnowledgeRetriever:
|
||||
return LocalKnowledgeRetriever(current_app.config["LOCAL_KB_PATH"])
|
||||
|
||||
|
||||
def _llm_config() -> dict:
|
||||
manager = LLMConfigManager(current_app.config["LLM_CONFIG_PATH"])
|
||||
return manager.load()
|
||||
|
||||
|
||||
def _default_provider(config: dict) -> str:
|
||||
provider = (config.get("active_provider") or "fastgpt").strip().lower() if isinstance(config, dict) else "fastgpt"
|
||||
return provider if provider in {"fastgpt", "dify"} else "fastgpt"
|
||||
|
||||
|
||||
def _context_text(hits: list) -> str:
|
||||
chunks = []
|
||||
for idx, item in enumerate(hits, start=1):
|
||||
chunks.append(f"[{idx}] 问: {item['question']}\n答: {item['answer']}")
|
||||
return "\n\n".join(chunks)
|
||||
|
||||
|
||||
def _local_answer(hits: list) -> str:
|
||||
if not hits:
|
||||
return "本地知识库暂未检索到高相关内容,建议补充问题细节。"
|
||||
top = hits[0]
|
||||
return f"根据本地知识库:{top['answer']}"
|
||||
|
||||
|
||||
def _advice_template(advice_type: str) -> str:
|
||||
templates = {
|
||||
"gain_muscle": "增肌建议:每天蛋白质建议 1.6-2.2g/kg,训练后 30 分钟内补充蛋白+碳水,优先鸡胸肉/牛肉/鸡蛋/燕麦。",
|
||||
"lose_fat": "减脂建议:保持 10%-20% 轻热量缺口,优先高蛋白高纤维食物,减少高糖饮料和深夜加餐。",
|
||||
"keto": "生酮建议:控制碳水通常小于 50g/天,优先健康脂肪和适量蛋白,并关注电解质补充。",
|
||||
"nutritionist": "营养师建议:结合体重、体脂、活动量和目标制定周计划,每 7 天复盘体重和围度变化。",
|
||||
"general": "通用建议:保证规律作息、三餐结构稳定、适量运动和充足饮水。",
|
||||
}
|
||||
return templates.get(advice_type, templates["general"])
|
||||
|
||||
|
||||
def _ask_by_provider(provider: str, config: dict, question: str, context: str, user_id: str, username: str):
|
||||
if provider == "dify":
|
||||
return ask_dify(config, question, context=context, user=username)
|
||||
return ask_fastgpt(config, question, context=context, custom_uid=user_id)
|
||||
|
||||
|
||||
@qa_bp.get("/kb/search")
|
||||
@jwt_required(optional=True)
|
||||
def search_kb():
|
||||
query = (request.args.get("query") or "").strip()
|
||||
top_k = min(max(int(request.args.get("top_k", 5) or 5), 1), 10)
|
||||
|
||||
retriever = _retriever()
|
||||
hits = retriever.search(query, top_k=top_k)
|
||||
return ok({"items": hits})
|
||||
|
||||
|
||||
@qa_bp.post("/ask")
|
||||
@jwt_required()
|
||||
def ask_nutrition_qa():
|
||||
payload = request.get_json(silent=True) or {}
|
||||
question = (payload.get("question") or "").strip()
|
||||
if not question:
|
||||
return ok({"answer": "请输入问题", "items": []}, "empty_question")
|
||||
|
||||
mode = (payload.get("mode") or "auto").strip().lower()
|
||||
|
||||
llm_cfg = _llm_config()
|
||||
provider = (payload.get("provider") or _default_provider(llm_cfg)).strip().lower()
|
||||
if provider not in {"fastgpt", "dify"}:
|
||||
provider = _default_provider(llm_cfg)
|
||||
|
||||
retriever = _retriever()
|
||||
hits = retriever.search(question, top_k=3)
|
||||
context = _context_text(hits)
|
||||
|
||||
local_answer = _local_answer(hits)
|
||||
if mode == "local":
|
||||
return ok({"answer": local_answer, "provider": "local_kb", "items": hits})
|
||||
|
||||
user = current_user()
|
||||
llm_result = _ask_by_provider(
|
||||
provider,
|
||||
llm_cfg,
|
||||
question,
|
||||
context,
|
||||
str(user.id) if user else "",
|
||||
user.username if user else "anonymous",
|
||||
)
|
||||
|
||||
if mode == "llm":
|
||||
if llm_result.get("ok"):
|
||||
return ok({"answer": llm_result.get("answer"), "provider": provider, "items": hits})
|
||||
return ok(
|
||||
{
|
||||
"answer": f"LLM 调用失败,已回退本地知识库。{local_answer}",
|
||||
"provider": "local_kb_fallback",
|
||||
"items": hits,
|
||||
"llm_error": llm_result,
|
||||
}
|
||||
)
|
||||
|
||||
if llm_result.get("ok"):
|
||||
return ok({"answer": llm_result.get("answer"), "provider": provider, "items": hits})
|
||||
|
||||
return ok(
|
||||
{
|
||||
"answer": local_answer,
|
||||
"provider": "local_kb",
|
||||
"items": hits,
|
||||
"llm_error": llm_result,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@qa_bp.post("/advice")
|
||||
@jwt_required()
|
||||
def advice():
|
||||
payload = request.get_json(silent=True) or {}
|
||||
advice_type = (payload.get("advice_type") or "general").strip().lower()
|
||||
question = (payload.get("question") or "请根据我的状态给建议").strip()
|
||||
|
||||
llm_cfg = _llm_config()
|
||||
provider = (payload.get("provider") or _default_provider(llm_cfg)).strip().lower()
|
||||
if provider not in {"fastgpt", "dify"}:
|
||||
provider = _default_provider(llm_cfg)
|
||||
|
||||
retriever = _retriever()
|
||||
hits = retriever.search(f"{advice_type} {question}", top_k=3)
|
||||
context = _context_text(hits)
|
||||
base_text = _advice_template(advice_type)
|
||||
|
||||
user = current_user()
|
||||
llm_result = _ask_by_provider(
|
||||
provider,
|
||||
llm_cfg,
|
||||
f"请给出{advice_type}方向的个性化饮食建议。用户补充: {question}",
|
||||
f"{base_text}\n\n{context}",
|
||||
str(user.id) if user else "",
|
||||
user.username if user else "anonymous",
|
||||
)
|
||||
|
||||
if llm_result.get("ok"):
|
||||
answer = llm_result.get("answer")
|
||||
source = provider
|
||||
else:
|
||||
answer = f"{base_text}\n\n(提示:LLM 暂不可用,当前为本地模板建议)"
|
||||
source = "local_template"
|
||||
|
||||
return ok(
|
||||
{
|
||||
"advice_type": advice_type,
|
||||
"answer": answer,
|
||||
"provider": source,
|
||||
"knowledge_hits": hits,
|
||||
"llm_status": llm_result,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@qa_bp.post("/kb/reload")
|
||||
@admin_required
|
||||
def reload_kb():
|
||||
retriever = _retriever()
|
||||
retriever.reload()
|
||||
return ok({"count": len(retriever.documents)}, "知识库已重载")
|
||||
135
backend/app/routes/recipe_routes.py
Normal file
135
backend/app/routes/recipe_routes.py
Normal file
@@ -0,0 +1,135 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
from flask import Blueprint, current_app, request
|
||||
from flask_jwt_extended import jwt_required
|
||||
|
||||
from app.extensions import db
|
||||
from app.models import Recipe
|
||||
from app.utils.auth import admin_required
|
||||
from app.utils.response import fail, ok
|
||||
|
||||
|
||||
recipe_bp = Blueprint("recipe", __name__)
|
||||
|
||||
|
||||
def _recipe_payload(recipe: Recipe, payload: dict) -> None:
|
||||
recipe.name = payload.get("name", recipe.name)
|
||||
recipe.category = payload.get("category", recipe.category or "轻食")
|
||||
recipe.description = payload.get("description", recipe.description or "")
|
||||
recipe.calories = float(payload.get("calories", recipe.calories or 0))
|
||||
recipe.protein = float(payload.get("protein", recipe.protein or 0))
|
||||
recipe.fat = float(payload.get("fat", recipe.fat or 0))
|
||||
recipe.carbs = float(payload.get("carbs", recipe.carbs or 0))
|
||||
recipe.fiber = float(payload.get("fiber", recipe.fiber or 0))
|
||||
recipe.tags = payload.get("tags", recipe.tags or [])
|
||||
recipe.difficulty = payload.get("difficulty", recipe.difficulty or "easy")
|
||||
|
||||
|
||||
@recipe_bp.get("")
|
||||
@jwt_required(optional=True)
|
||||
def list_recipes():
|
||||
keyword = (request.args.get("keyword") or "").strip()
|
||||
category = (request.args.get("category") or "").strip()
|
||||
page = max(int(request.args.get("page", 1) or 1), 1)
|
||||
page_size = min(max(int(request.args.get("page_size", 10) or 10), 1), 50)
|
||||
|
||||
query = Recipe.query
|
||||
if keyword:
|
||||
query = query.filter(Recipe.name.like(f"%{keyword}%"))
|
||||
if category:
|
||||
query = query.filter(Recipe.category == category)
|
||||
|
||||
pagination = query.order_by(Recipe.id.desc()).paginate(page=page, per_page=page_size, error_out=False)
|
||||
data = {
|
||||
"items": [item.to_dict() for item in pagination.items],
|
||||
"total": pagination.total,
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
}
|
||||
return ok(data)
|
||||
|
||||
|
||||
@recipe_bp.get("/<int:recipe_id>")
|
||||
@jwt_required(optional=True)
|
||||
def get_recipe(recipe_id: int):
|
||||
recipe = Recipe.query.get(recipe_id)
|
||||
if not recipe:
|
||||
return fail("食谱不存在", 404)
|
||||
return ok(recipe.to_dict())
|
||||
|
||||
|
||||
@recipe_bp.post("")
|
||||
@admin_required
|
||||
def create_recipe():
|
||||
payload = request.get_json(silent=True) or {}
|
||||
if not payload.get("name"):
|
||||
return fail("缺少食谱名称", 400)
|
||||
|
||||
recipe = Recipe()
|
||||
_recipe_payload(recipe, payload)
|
||||
|
||||
db.session.add(recipe)
|
||||
db.session.commit()
|
||||
return ok(recipe.to_dict(), "食谱创建成功")
|
||||
|
||||
|
||||
@recipe_bp.put("/<int:recipe_id>")
|
||||
@admin_required
|
||||
def update_recipe(recipe_id: int):
|
||||
recipe = Recipe.query.get(recipe_id)
|
||||
if not recipe:
|
||||
return fail("食谱不存在", 404)
|
||||
|
||||
payload = request.get_json(silent=True) or {}
|
||||
_recipe_payload(recipe, payload)
|
||||
db.session.commit()
|
||||
return ok(recipe.to_dict(), "食谱更新成功")
|
||||
|
||||
|
||||
@recipe_bp.delete("/<int:recipe_id>")
|
||||
@admin_required
|
||||
def delete_recipe(recipe_id: int):
|
||||
recipe = Recipe.query.get(recipe_id)
|
||||
if not recipe:
|
||||
return fail("食谱不存在", 404)
|
||||
|
||||
db.session.delete(recipe)
|
||||
db.session.commit()
|
||||
return ok({}, "食谱删除成功")
|
||||
|
||||
|
||||
@recipe_bp.post("/import")
|
||||
@admin_required
|
||||
def import_recipes():
|
||||
payload = request.get_json(silent=True) or {}
|
||||
rows = payload.get("items")
|
||||
|
||||
if not isinstance(rows, list):
|
||||
seed_file = Path(current_app.root_path).parents[0] / "seed" / "recipes_seed.json"
|
||||
if seed_file.exists():
|
||||
with seed_file.open("r", encoding="utf-8-sig") as file:
|
||||
rows = json.load(file)
|
||||
else:
|
||||
return fail("导入数据为空,且未找到默认种子文件", 400)
|
||||
|
||||
created_count = 0
|
||||
updated_count = 0
|
||||
|
||||
for item in rows:
|
||||
name = (item.get("name") or "").strip()
|
||||
if not name:
|
||||
continue
|
||||
|
||||
recipe = Recipe.query.filter_by(name=name).first()
|
||||
if recipe:
|
||||
_recipe_payload(recipe, item)
|
||||
updated_count += 1
|
||||
else:
|
||||
recipe = Recipe()
|
||||
_recipe_payload(recipe, item)
|
||||
db.session.add(recipe)
|
||||
created_count += 1
|
||||
|
||||
db.session.commit()
|
||||
return ok({"created": created_count, "updated": updated_count}, "食谱导入完成")
|
||||
314
backend/app/routes/recommend_routes.py
Normal file
314
backend/app/routes/recommend_routes.py
Normal file
@@ -0,0 +1,314 @@
|
||||
from statistics import mean
|
||||
|
||||
from flask import Blueprint, current_app, request
|
||||
from flask_jwt_extended import jwt_required
|
||||
|
||||
from app.extensions import db
|
||||
from app.ml.rf_recommender import RandomForestDietRecommender, merge_profile_with_history
|
||||
from app.models import DietStatus, RecommendationLog, Recipe
|
||||
from app.utils.auth import current_user
|
||||
from app.utils.response import fail, ok
|
||||
|
||||
|
||||
recommend_bp = Blueprint("recommend", __name__)
|
||||
|
||||
|
||||
def _build_profile(user, latest_status):
|
||||
profile = {
|
||||
"age": user.age,
|
||||
"height_cm": user.height_cm,
|
||||
"goal": user.goal,
|
||||
"occupation": user.occupation,
|
||||
"weight": 65,
|
||||
"body_fat": 20,
|
||||
"exercise_kcal": 300,
|
||||
"intake_kcal": 1800,
|
||||
}
|
||||
if latest_status:
|
||||
profile.update(
|
||||
{
|
||||
"weight": latest_status.weight,
|
||||
"body_fat": latest_status.body_fat,
|
||||
"exercise_kcal": latest_status.exercise_kcal,
|
||||
"intake_kcal": latest_status.intake_kcal,
|
||||
}
|
||||
)
|
||||
return profile
|
||||
|
||||
|
||||
def _goal_adjust(goal: str, recipe: Recipe) -> float:
|
||||
if goal == "lose_fat":
|
||||
return recipe.protein * 0.8 - recipe.fat * 0.4 - max(recipe.calories - 500, 0) * 0.04
|
||||
if goal == "gain_muscle":
|
||||
return recipe.protein * 1.0 + recipe.carbs * 0.25
|
||||
if goal == "keto":
|
||||
return recipe.fat * 0.5 - recipe.carbs * 0.8
|
||||
return recipe.protein * 0.3 + recipe.fiber * 1.2
|
||||
|
||||
|
||||
def _occupation_adjust(occupation: str, recipe: Recipe) -> float:
|
||||
occupation = (occupation or "").lower()
|
||||
if occupation in {"developer", "office"}:
|
||||
return recipe.fiber * 0.9 + recipe.protein * 0.2
|
||||
if occupation in {"fitness", "manual"}:
|
||||
return recipe.protein * 0.6 + recipe.carbs * 0.2
|
||||
if occupation in {"teacher", "student"}:
|
||||
return recipe.carbs * 0.15 + recipe.protein * 0.3
|
||||
return recipe.protein * 0.25
|
||||
|
||||
|
||||
def _health_adjust(profile: dict, recipe: Recipe) -> float:
|
||||
score = 70.0
|
||||
body_fat = float(profile.get("body_fat", 20))
|
||||
intake_kcal = float(profile.get("intake_kcal", 1800))
|
||||
exercise_kcal = float(profile.get("exercise_kcal", 300))
|
||||
|
||||
if body_fat > 28:
|
||||
score += recipe.protein * 0.5
|
||||
score -= max(recipe.calories - 480, 0) * 0.05
|
||||
elif body_fat < 14:
|
||||
score += recipe.carbs * 0.15 + recipe.fat * 0.1
|
||||
|
||||
if intake_kcal > exercise_kcal + 1700:
|
||||
score -= max(recipe.calories - 450, 0) * 0.04
|
||||
else:
|
||||
score += recipe.protein * 0.3
|
||||
|
||||
score += min(recipe.fiber * 1.0, 10)
|
||||
return round(score, 2)
|
||||
|
||||
|
||||
def _serialize_with_score(recipes):
|
||||
return [
|
||||
{
|
||||
**item[0].to_dict(),
|
||||
"score": round(float(item[1]), 2),
|
||||
"reason": item[2],
|
||||
}
|
||||
for item in recipes
|
||||
]
|
||||
|
||||
|
||||
def _save_log(user_id: int, rec_type: str, payload: dict) -> None:
|
||||
log = RecommendationLog(user_id=user_id, rec_type=rec_type, payload=payload)
|
||||
db.session.add(log)
|
||||
db.session.commit()
|
||||
|
||||
|
||||
@recommend_bp.get("/current")
|
||||
@jwt_required()
|
||||
def recommend_current_status():
|
||||
user = current_user()
|
||||
if not user:
|
||||
return fail("用户不存在", 404)
|
||||
|
||||
top_k = min(max(int(request.args.get("top_k", 5) or 5), 1), 20)
|
||||
|
||||
latest = (
|
||||
DietStatus.query.filter_by(user_id=user.id)
|
||||
.order_by(DietStatus.recorded_at.desc(), DietStatus.id.desc())
|
||||
.first()
|
||||
)
|
||||
history = (
|
||||
DietStatus.query.filter_by(user_id=user.id)
|
||||
.order_by(DietStatus.recorded_at.desc(), DietStatus.id.desc())
|
||||
.limit(14)
|
||||
.all()
|
||||
)
|
||||
|
||||
recipes = Recipe.query.order_by(Recipe.id.asc()).all()
|
||||
if not recipes:
|
||||
return fail("当前没有可推荐食谱,请先导入食谱", 400)
|
||||
|
||||
profile = _build_profile(user, latest)
|
||||
profile = merge_profile_with_history(profile, history)
|
||||
|
||||
recommender = RandomForestDietRecommender(current_app.config["MODEL_PATH"])
|
||||
items = recommender.recommend(profile, recipes, top_k=top_k)
|
||||
|
||||
response_data = {
|
||||
"type": "current_status_rf",
|
||||
"profile": profile,
|
||||
"items": items,
|
||||
}
|
||||
_save_log(user.id, "current_status_rf", response_data)
|
||||
return ok(response_data)
|
||||
|
||||
|
||||
@recommend_bp.get("/health")
|
||||
@jwt_required()
|
||||
def recommend_by_health():
|
||||
user = current_user()
|
||||
if not user:
|
||||
return fail("用户不存在", 404)
|
||||
|
||||
top_k = min(max(int(request.args.get("top_k", 5) or 5), 1), 20)
|
||||
|
||||
latest = (
|
||||
DietStatus.query.filter_by(user_id=user.id)
|
||||
.order_by(DietStatus.recorded_at.desc(), DietStatus.id.desc())
|
||||
.first()
|
||||
)
|
||||
profile = _build_profile(user, latest)
|
||||
|
||||
recipes = Recipe.query.order_by(Recipe.id.asc()).all()
|
||||
if not recipes:
|
||||
return fail("当前没有可推荐食谱,请先导入食谱", 400)
|
||||
|
||||
scored = []
|
||||
for recipe in recipes:
|
||||
score = _health_adjust(profile, recipe)
|
||||
reason = "结合当前体脂和摄入消耗状态,推荐更合理的营养结构。"
|
||||
scored.append((recipe, score, reason))
|
||||
|
||||
scored.sort(key=lambda row: row[1], reverse=True)
|
||||
items = _serialize_with_score(scored[:top_k])
|
||||
|
||||
response_data = {
|
||||
"type": "health_state",
|
||||
"profile": profile,
|
||||
"items": items,
|
||||
}
|
||||
_save_log(user.id, "health_state", response_data)
|
||||
return ok(response_data)
|
||||
|
||||
|
||||
@recommend_bp.get("/plan/history")
|
||||
@jwt_required()
|
||||
def recommend_plan_by_history():
|
||||
user = current_user()
|
||||
if not user:
|
||||
return fail("用户不存在", 404)
|
||||
|
||||
top_k = min(max(int(request.args.get("top_k", 5) or 5), 1), 20)
|
||||
history = (
|
||||
DietStatus.query.filter_by(user_id=user.id)
|
||||
.order_by(DietStatus.recorded_at.desc(), DietStatus.id.desc())
|
||||
.limit(30)
|
||||
.all()
|
||||
)
|
||||
if len(history) < 3:
|
||||
return fail("历史数据太少,至少记录3天后可用该功能", 400)
|
||||
|
||||
avg_weight = mean([h.weight for h in history])
|
||||
early_weight = history[-1].weight
|
||||
latest_weight = history[0].weight
|
||||
weight_trend = latest_weight - early_weight
|
||||
|
||||
avg_intake = mean([h.intake_kcal for h in history])
|
||||
avg_exercise = mean([h.exercise_kcal for h in history])
|
||||
|
||||
cal_limit = 520
|
||||
strategy = "维持计划"
|
||||
if weight_trend > 0.8 or avg_intake > avg_exercise + 1700:
|
||||
cal_limit = 440
|
||||
strategy = "减脂优先计划"
|
||||
elif weight_trend < -0.8:
|
||||
cal_limit = 620
|
||||
strategy = "增肌恢复计划"
|
||||
|
||||
recipes = Recipe.query.filter(Recipe.calories <= cal_limit).order_by(Recipe.protein.desc()).limit(top_k).all()
|
||||
items = []
|
||||
for recipe in recipes:
|
||||
items.append(
|
||||
{
|
||||
**recipe.to_dict(),
|
||||
"score": round(recipe.protein * 2 - recipe.fat * 0.3, 2),
|
||||
"reason": f"基于历史趋势({strategy}),优先推荐该热量区间食谱。",
|
||||
}
|
||||
)
|
||||
|
||||
response_data = {
|
||||
"type": "history_plan",
|
||||
"strategy": strategy,
|
||||
"metrics": {
|
||||
"avg_weight": round(avg_weight, 2),
|
||||
"weight_trend": round(weight_trend, 2),
|
||||
"avg_intake": round(avg_intake, 1),
|
||||
"avg_exercise": round(avg_exercise, 1),
|
||||
},
|
||||
"items": items,
|
||||
}
|
||||
_save_log(user.id, "history_plan", response_data)
|
||||
return ok(response_data)
|
||||
|
||||
|
||||
@recommend_bp.get("/plan/goal")
|
||||
@jwt_required()
|
||||
def recommend_plan_by_goal():
|
||||
user = current_user()
|
||||
if not user:
|
||||
return fail("用户不存在", 404)
|
||||
|
||||
top_k = min(max(int(request.args.get("top_k", 5) or 5), 1), 20)
|
||||
goal = request.args.get("goal") or user.goal
|
||||
|
||||
recipes = Recipe.query.order_by(Recipe.id.asc()).all()
|
||||
if not recipes:
|
||||
return fail("当前没有可推荐食谱,请先导入食谱", 400)
|
||||
|
||||
scored = []
|
||||
for recipe in recipes:
|
||||
score = 60 + _goal_adjust(goal, recipe)
|
||||
reason = f"根据目标({goal})筛选营养比例更匹配的食谱。"
|
||||
scored.append((recipe, score, reason))
|
||||
|
||||
scored.sort(key=lambda x: x[1], reverse=True)
|
||||
items = _serialize_with_score(scored[:top_k])
|
||||
|
||||
response_data = {
|
||||
"type": "future_goal_plan",
|
||||
"goal": goal,
|
||||
"items": items,
|
||||
}
|
||||
_save_log(user.id, "future_goal_plan", response_data)
|
||||
return ok(response_data)
|
||||
|
||||
|
||||
@recommend_bp.get("/plan/occupation")
|
||||
@jwt_required()
|
||||
def recommend_plan_by_occupation():
|
||||
user = current_user()
|
||||
if not user:
|
||||
return fail("用户不存在", 404)
|
||||
|
||||
top_k = min(max(int(request.args.get("top_k", 5) or 5), 1), 20)
|
||||
occupation = request.args.get("occupation") or user.occupation
|
||||
|
||||
recipes = Recipe.query.order_by(Recipe.id.asc()).all()
|
||||
if not recipes:
|
||||
return fail("当前没有可推荐食谱,请先导入食谱", 400)
|
||||
|
||||
scored = []
|
||||
for recipe in recipes:
|
||||
score = 65 + _occupation_adjust(occupation, recipe)
|
||||
reason = f"结合职业({occupation})的能量消耗特点进行推荐。"
|
||||
scored.append((recipe, score, reason))
|
||||
|
||||
scored.sort(key=lambda x: x[1], reverse=True)
|
||||
items = _serialize_with_score(scored[:top_k])
|
||||
|
||||
response_data = {
|
||||
"type": "occupation_plan",
|
||||
"occupation": occupation,
|
||||
"items": items,
|
||||
}
|
||||
_save_log(user.id, "occupation_plan", response_data)
|
||||
return ok(response_data)
|
||||
|
||||
|
||||
@recommend_bp.get("/logs")
|
||||
@jwt_required()
|
||||
def my_recommend_logs():
|
||||
user = current_user()
|
||||
if not user:
|
||||
return fail("用户不存在", 404)
|
||||
|
||||
limit = min(max(int(request.args.get("limit", 20) or 20), 1), 100)
|
||||
rows = (
|
||||
RecommendationLog.query.filter_by(user_id=user.id)
|
||||
.order_by(RecommendationLog.id.desc())
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
return ok([item.to_dict() for item in rows])
|
||||
334
backend/app/routes/spam_routes.py
Normal file
334
backend/app/routes/spam_routes.py
Normal file
@@ -0,0 +1,334 @@
|
||||
from flask import Blueprint, current_app, request
|
||||
from flask_jwt_extended import jwt_required
|
||||
|
||||
from app.extensions import db
|
||||
from app.ml.naive_bayes_classifier import NaiveBayesSpamClassifier
|
||||
from app.models import DetectionConfig, SpamPredictionLog, SpamTrainingSample
|
||||
from app.utils.auth import admin_required, current_user
|
||||
from app.utils.response import fail, ok
|
||||
|
||||
|
||||
spam_bp = Blueprint("spam", __name__)
|
||||
|
||||
|
||||
def _classifier() -> NaiveBayesSpamClassifier:
|
||||
return NaiveBayesSpamClassifier(current_app.config["NB_MODEL_PATH"])
|
||||
|
||||
|
||||
def _active_samples() -> list[dict]:
|
||||
rows = SpamTrainingSample.query.filter_by(is_active=True).order_by(SpamTrainingSample.id.asc()).all()
|
||||
return [{"text": row.text, "label": row.label} for row in rows]
|
||||
|
||||
|
||||
def _ensure_ready() -> NaiveBayesSpamClassifier:
|
||||
clf = _classifier()
|
||||
samples = _active_samples()
|
||||
clf.ensure_ready(samples)
|
||||
return clf
|
||||
|
||||
|
||||
def _threshold() -> float:
|
||||
row = DetectionConfig.query.order_by(DetectionConfig.id.asc()).first()
|
||||
return float(row.spam_threshold) if row else 0.75
|
||||
|
||||
|
||||
@spam_bp.post("/predict")
|
||||
@jwt_required()
|
||||
def predict_one():
|
||||
user = current_user()
|
||||
if not user:
|
||||
return fail("用户不存在", 404)
|
||||
|
||||
payload = request.get_json(silent=True) or {}
|
||||
text = (payload.get("text") or "").strip()
|
||||
if len(text) < 2:
|
||||
return fail("请输入至少2个字符的待识别文本", 400)
|
||||
|
||||
clf = _ensure_ready()
|
||||
result = clf.predict(text)
|
||||
threshold = _threshold()
|
||||
blocked = float(result["spam_probability"]) >= threshold
|
||||
|
||||
row = SpamPredictionLog(
|
||||
user_id=user.id,
|
||||
text=result["text"],
|
||||
prediction=result["prediction"],
|
||||
spam_probability=result["spam_probability"],
|
||||
ham_probability=result["ham_probability"],
|
||||
confidence=result["confidence"],
|
||||
reason_tokens=result["reason_tokens"],
|
||||
model_version=result.get("model_version", ""),
|
||||
)
|
||||
db.session.add(row)
|
||||
db.session.commit()
|
||||
|
||||
return ok({**result, "log_id": row.id, "threshold": threshold, "blocked_by_threshold": blocked}, "识别成功")
|
||||
|
||||
|
||||
@spam_bp.post("/predict/batch")
|
||||
@jwt_required()
|
||||
def predict_batch():
|
||||
user = current_user()
|
||||
if not user:
|
||||
return fail("用户不存在", 404)
|
||||
|
||||
payload = request.get_json(silent=True) or {}
|
||||
items = payload.get("items") or []
|
||||
if not isinstance(items, list) or not items:
|
||||
return fail("items 必须是非空数组", 400)
|
||||
if len(items) > 100:
|
||||
return fail("单次最多识别100条", 400)
|
||||
|
||||
clf = _ensure_ready()
|
||||
rows = []
|
||||
results = []
|
||||
threshold = _threshold()
|
||||
|
||||
for text in items:
|
||||
content = (text or "").strip()
|
||||
if len(content) < 2:
|
||||
continue
|
||||
result = clf.predict(content)
|
||||
result["blocked_by_threshold"] = float(result["spam_probability"]) >= threshold
|
||||
rows.append(
|
||||
SpamPredictionLog(
|
||||
user_id=user.id,
|
||||
text=result["text"],
|
||||
prediction=result["prediction"],
|
||||
spam_probability=result["spam_probability"],
|
||||
ham_probability=result["ham_probability"],
|
||||
confidence=result["confidence"],
|
||||
reason_tokens=result["reason_tokens"],
|
||||
model_version=result.get("model_version", ""),
|
||||
)
|
||||
)
|
||||
results.append(result)
|
||||
|
||||
if not rows:
|
||||
return fail("没有可识别的有效文本", 400)
|
||||
|
||||
db.session.add_all(rows)
|
||||
db.session.commit()
|
||||
|
||||
spam_count = len([item for item in results if item["prediction"] == "spam"])
|
||||
blocked_count = len([item for item in results if item["blocked_by_threshold"]])
|
||||
ham_count = len(results) - spam_count
|
||||
|
||||
return ok(
|
||||
{
|
||||
"items": results,
|
||||
"summary": {
|
||||
"total": len(results),
|
||||
"spam_count": spam_count,
|
||||
"ham_count": ham_count,
|
||||
"blocked_count": blocked_count,
|
||||
"spam_ratio": round(spam_count / len(results), 4) if results else 0,
|
||||
"blocked_ratio": round(blocked_count / len(results), 4) if results else 0,
|
||||
"threshold": threshold,
|
||||
},
|
||||
},
|
||||
"批量识别完成",
|
||||
)
|
||||
|
||||
|
||||
@spam_bp.get("/history")
|
||||
@jwt_required()
|
||||
def my_history():
|
||||
user = current_user()
|
||||
if not user:
|
||||
return fail("用户不存在", 404)
|
||||
|
||||
page = max(int(request.args.get("page", 1) or 1), 1)
|
||||
page_size = min(max(int(request.args.get("page_size", 20) or 20), 1), 100)
|
||||
|
||||
pagination = (
|
||||
SpamPredictionLog.query.filter_by(user_id=user.id)
|
||||
.order_by(SpamPredictionLog.id.desc())
|
||||
.paginate(page=page, per_page=page_size, error_out=False)
|
||||
)
|
||||
|
||||
return ok(
|
||||
{
|
||||
"items": [item.to_dict() for item in pagination.items],
|
||||
"total": pagination.total,
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@spam_bp.delete("/history/<int:log_id>")
|
||||
@jwt_required()
|
||||
def delete_history(log_id: int):
|
||||
user = current_user()
|
||||
if not user:
|
||||
return fail("用户不存在", 404)
|
||||
|
||||
row = SpamPredictionLog.query.filter_by(id=log_id, user_id=user.id).first()
|
||||
if not row:
|
||||
return fail("记录不存在", 404)
|
||||
|
||||
db.session.delete(row)
|
||||
db.session.commit()
|
||||
return ok({}, "记录已删除")
|
||||
|
||||
|
||||
@spam_bp.post("/feedback")
|
||||
@jwt_required()
|
||||
def save_feedback():
|
||||
user = current_user()
|
||||
if not user:
|
||||
return fail("用户不存在", 404)
|
||||
|
||||
payload = request.get_json(silent=True) or {}
|
||||
text = (payload.get("text") or "").strip()
|
||||
label = NaiveBayesSpamClassifier.normalize_label(payload.get("label"))
|
||||
|
||||
if len(text) < 2:
|
||||
return fail("文本至少2个字符", 400)
|
||||
if not label:
|
||||
return fail("label 必须是 spam 或 ham", 400)
|
||||
|
||||
row = SpamTrainingSample(text=text, label=label, source="feedback", created_by=user.id, is_active=True)
|
||||
db.session.add(row)
|
||||
db.session.commit()
|
||||
return ok(row.to_dict(), "反馈样本已记录")
|
||||
|
||||
|
||||
@spam_bp.get("/model/info")
|
||||
@jwt_required(optional=True)
|
||||
def model_info():
|
||||
clf = _classifier()
|
||||
clf.load()
|
||||
info = clf.model_info()
|
||||
info["threshold"] = _threshold()
|
||||
return ok(info)
|
||||
|
||||
|
||||
@spam_bp.post("/train")
|
||||
@admin_required
|
||||
def train_model():
|
||||
clf = _classifier()
|
||||
samples = _active_samples()
|
||||
metadata = clf.train(samples)
|
||||
return ok(metadata, "模型训练完成")
|
||||
|
||||
|
||||
@spam_bp.get("/samples")
|
||||
@admin_required
|
||||
def list_samples():
|
||||
keyword = (request.args.get("keyword") or "").strip()
|
||||
label = (request.args.get("label") or "").strip().lower()
|
||||
page = max(int(request.args.get("page", 1) or 1), 1)
|
||||
page_size = min(max(int(request.args.get("page_size", 20) or 20), 1), 100)
|
||||
|
||||
query = SpamTrainingSample.query
|
||||
if keyword:
|
||||
query = query.filter(SpamTrainingSample.text.like(f"%{keyword}%"))
|
||||
if label in {"spam", "ham"}:
|
||||
query = query.filter(SpamTrainingSample.label == label)
|
||||
|
||||
pagination = query.order_by(SpamTrainingSample.id.desc()).paginate(page=page, per_page=page_size, error_out=False)
|
||||
return ok(
|
||||
{
|
||||
"items": [item.to_dict() for item in pagination.items],
|
||||
"total": pagination.total,
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@spam_bp.post("/samples")
|
||||
@admin_required
|
||||
def create_sample():
|
||||
payload = request.get_json(silent=True) or {}
|
||||
text = (payload.get("text") or "").strip()
|
||||
label = NaiveBayesSpamClassifier.normalize_label(payload.get("label"))
|
||||
|
||||
if len(text) < 2:
|
||||
return fail("文本至少2个字符", 400)
|
||||
if not label:
|
||||
return fail("label 必须是 spam 或 ham", 400)
|
||||
|
||||
user = current_user()
|
||||
row = SpamTrainingSample(text=text, label=label, source="import", created_by=user.id if user else None, is_active=True)
|
||||
db.session.add(row)
|
||||
db.session.commit()
|
||||
return ok(row.to_dict(), "样本创建成功")
|
||||
|
||||
|
||||
@spam_bp.put("/samples/<int:sample_id>")
|
||||
@admin_required
|
||||
def update_sample(sample_id: int):
|
||||
row = SpamTrainingSample.query.get(sample_id)
|
||||
if not row:
|
||||
return fail("样本不存在", 404)
|
||||
|
||||
payload = request.get_json(silent=True) or {}
|
||||
if "text" in payload:
|
||||
text = (payload.get("text") or "").strip()
|
||||
if len(text) < 2:
|
||||
return fail("文本至少2个字符", 400)
|
||||
row.text = text
|
||||
if "label" in payload:
|
||||
label = NaiveBayesSpamClassifier.normalize_label(payload.get("label"))
|
||||
if not label:
|
||||
return fail("label 必须是 spam 或 ham", 400)
|
||||
row.label = label
|
||||
if "is_active" in payload:
|
||||
row.is_active = bool(payload.get("is_active"))
|
||||
|
||||
db.session.commit()
|
||||
return ok(row.to_dict(), "样本更新成功")
|
||||
|
||||
|
||||
@spam_bp.delete("/samples/<int:sample_id>")
|
||||
@admin_required
|
||||
def delete_sample(sample_id: int):
|
||||
row = SpamTrainingSample.query.get(sample_id)
|
||||
if not row:
|
||||
return fail("样本不存在", 404)
|
||||
|
||||
db.session.delete(row)
|
||||
db.session.commit()
|
||||
return ok({}, "样本已删除")
|
||||
|
||||
|
||||
@spam_bp.post("/samples/import")
|
||||
@admin_required
|
||||
def import_samples():
|
||||
payload = request.get_json(silent=True) or {}
|
||||
items = payload.get("items") or []
|
||||
if not isinstance(items, list) or not items:
|
||||
return fail("items 必须是非空数组", 400)
|
||||
|
||||
user = current_user()
|
||||
created = 0
|
||||
updated = 0
|
||||
|
||||
for item in items:
|
||||
text = (item.get("text") or "").strip()
|
||||
label = NaiveBayesSpamClassifier.normalize_label(item.get("label"))
|
||||
if len(text) < 2 or not label:
|
||||
continue
|
||||
|
||||
row = SpamTrainingSample.query.filter_by(text=text).first()
|
||||
if row:
|
||||
row.label = label
|
||||
row.is_active = bool(item.get("is_active", True))
|
||||
row.source = item.get("source") or row.source
|
||||
updated += 1
|
||||
else:
|
||||
row = SpamTrainingSample(
|
||||
text=text,
|
||||
label=label,
|
||||
source=item.get("source") or "import",
|
||||
created_by=user.id if user else None,
|
||||
is_active=bool(item.get("is_active", True)),
|
||||
)
|
||||
db.session.add(row)
|
||||
created += 1
|
||||
|
||||
db.session.commit()
|
||||
return ok({"created": created, "updated": updated}, "样本导入完成")
|
||||
62
backend/app/routes/user_routes.py
Normal file
62
backend/app/routes/user_routes.py
Normal file
@@ -0,0 +1,62 @@
|
||||
from flask import Blueprint, request
|
||||
from flask_jwt_extended import jwt_required
|
||||
|
||||
from app.extensions import db
|
||||
from app.models import User
|
||||
from app.utils.auth import current_user
|
||||
from app.utils.response import fail, ok
|
||||
|
||||
|
||||
user_bp = Blueprint("user", __name__)
|
||||
|
||||
|
||||
@user_bp.get("/profile")
|
||||
@jwt_required()
|
||||
def get_profile():
|
||||
user = current_user()
|
||||
if not user:
|
||||
return fail("用户不存在", 404)
|
||||
return ok(user.to_dict())
|
||||
|
||||
|
||||
@user_bp.put("/profile")
|
||||
@jwt_required()
|
||||
def update_profile():
|
||||
user = current_user()
|
||||
if not user:
|
||||
return fail("用户不存在", 404)
|
||||
|
||||
payload = request.get_json(silent=True) or {}
|
||||
if "nickname" in payload:
|
||||
user.nickname = (payload.get("nickname") or user.nickname).strip() or user.nickname
|
||||
if "company" in payload:
|
||||
user.company = (payload.get("company") or "").strip()
|
||||
if "title" in payload:
|
||||
user.title = (payload.get("title") or "").strip()
|
||||
if "phone" in payload:
|
||||
user.phone = (payload.get("phone") or "").strip()
|
||||
|
||||
new_password = payload.get("password")
|
||||
if new_password:
|
||||
if len(new_password) < 6:
|
||||
return fail("新密码至少6位", 400)
|
||||
user.set_password(new_password)
|
||||
|
||||
db.session.commit()
|
||||
return ok(user.to_dict(), "个人信息更新成功")
|
||||
|
||||
|
||||
@user_bp.get("/search")
|
||||
@jwt_required()
|
||||
def search_users():
|
||||
keyword = (request.args.get("keyword") or "").strip()
|
||||
if not keyword:
|
||||
return ok([])
|
||||
|
||||
users = (
|
||||
User.query.filter(User.username.like(f"%{keyword}%") | User.nickname.like(f"%{keyword}%"))
|
||||
.order_by(User.id.desc())
|
||||
.limit(20)
|
||||
.all()
|
||||
)
|
||||
return ok([item.to_dict() for item in users])
|
||||
Reference in New Issue
Block a user