This commit is contained in:
刘正航
2026-04-21 22:45:19 +08:00
commit b5237f9038
159 changed files with 7769 additions and 0 deletions

37
backend/app/__init__.py Normal file
View File

@@ -0,0 +1,37 @@
from flask import Flask
from app.config import Config
from app.extensions import cors, db, jwt
from app.routes.admin_routes import admin_bp
from app.routes.auth_routes import auth_bp
from app.routes.content_routes import content_bp
from app.routes.spam_routes import spam_bp
from app.routes.user_routes import user_bp
def create_app() -> Flask:
app = Flask(__name__)
app.config.from_object(Config)
db.init_app(app)
jwt.init_app(app)
cors.init_app(app, supports_credentials=True)
app.register_blueprint(auth_bp, url_prefix="/api/auth")
app.register_blueprint(user_bp, url_prefix="/api/user")
app.register_blueprint(spam_bp, url_prefix="/api/spam")
app.register_blueprint(content_bp, url_prefix="/api/content")
app.register_blueprint(admin_bp, url_prefix="/api/admin")
@app.get("/api/health")
def health_check():
return {
"code": 0,
"message": "ok",
"data": {
"service": "spam-detect-backend",
"version": "2.1.0",
},
}
return app

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

37
backend/app/config.py Normal file
View File

@@ -0,0 +1,37 @@
import json
import os
from pathlib import Path
from urllib.parse import quote_plus
BASE_DIR = Path(__file__).resolve().parents[1]
MYSQL_CONFIG_PATH = BASE_DIR / "mysqlconfig.json"
def load_mysql_config() -> dict:
if not MYSQL_CONFIG_PATH.exists():
raise FileNotFoundError(f"未找到 MySQL 配置文件: {MYSQL_CONFIG_PATH}")
with MYSQL_CONFIG_PATH.open("r", encoding="utf-8-sig") as file:
return json.load(file)
def build_mysql_uri(mysql_cfg: dict) -> str:
user = mysql_cfg["user"]
password = quote_plus(mysql_cfg["password"])
host = mysql_cfg.get("host", "127.0.0.1")
port = mysql_cfg.get("port", 3306)
database = mysql_cfg["database"]
charset = mysql_cfg.get("charset", "utf8mb4")
return f"mysql+pymysql://{user}:{password}@{host}:{port}/{database}?charset={charset}"
class Config:
MYSQL_CONFIG = load_mysql_config()
SQLALCHEMY_DATABASE_URI = build_mysql_uri(MYSQL_CONFIG)
SQLALCHEMY_TRACK_MODIFICATIONS = False
JWT_SECRET_KEY = os.getenv("JWT_SECRET_KEY", "replace-this-jwt-secret-key-in-production")
SECRET_KEY = os.getenv("FLASK_SECRET_KEY", "replace-this-flask-secret-key-in-production")
SPAM_DATASET_PATH = str(BASE_DIR / "seed" / "spam_samples_seed.json")
NB_MODEL_PATH = str(BASE_DIR / "models" / "spam_nb_model.joblib")

View File

@@ -0,0 +1,8 @@
from flask_cors import CORS
from flask_jwt_extended import JWTManager
from flask_sqlalchemy import SQLAlchemy
db = SQLAlchemy()
jwt = JWTManager()
cors = CORS()

View File

@@ -0,0 +1 @@

Binary file not shown.

Binary file not shown.

View File

@@ -0,0 +1,156 @@
import hashlib
from collections import Counter
from datetime import datetime
from pathlib import Path
import joblib
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.naive_bayes import MultinomialNB
VALID_LABELS = {"spam", "ham"}
class NaiveBayesSpamClassifier:
def __init__(self, model_path: str):
self.model_path = Path(model_path)
self.vectorizer = None
self.model = None
self.metadata = {}
@staticmethod
def normalize_label(label: str) -> str:
value = (label or "").strip().lower()
return value if value in VALID_LABELS else ""
@staticmethod
def normalize_text(text: str) -> str:
return " ".join((text or "").strip().split())
@staticmethod
def _to_metadata(samples: list[dict], version_seed: str) -> dict:
dist = Counter([item["label"] for item in samples])
digest = hashlib.md5(version_seed.encode("utf-8")).hexdigest()[:16]
return {
"trained_at": datetime.utcnow().isoformat(),
"sample_count": len(samples),
"label_distribution": dict(dist),
"version": f"nb-{digest}",
}
def train(self, samples: list[dict]) -> dict:
clean_samples = []
for row in samples:
text = self.normalize_text(row.get("text"))
label = self.normalize_label(row.get("label"))
if not text or not label:
continue
clean_samples.append({"text": text, "label": label})
if len(clean_samples) < 10:
raise ValueError("训练样本太少至少需要10条有效样本")
texts = [item["text"] for item in clean_samples]
labels = [item["label"] for item in clean_samples]
vectorizer = TfidfVectorizer(analyzer="char", ngram_range=(1, 2), min_df=1)
x = vectorizer.fit_transform(texts)
model = MultinomialNB(alpha=0.4)
model.fit(x, labels)
version_seed = "||".join([f"{item['label']}::{item['text']}" for item in clean_samples])
metadata = self._to_metadata(clean_samples, version_seed)
self.model_path.parent.mkdir(parents=True, exist_ok=True)
joblib.dump({"vectorizer": vectorizer, "model": model, "metadata": metadata}, self.model_path)
self.vectorizer = vectorizer
self.model = model
self.metadata = metadata
return metadata
def load(self) -> bool:
if not self.model_path.exists():
return False
payload = joblib.load(self.model_path)
self.vectorizer = payload.get("vectorizer")
self.model = payload.get("model")
self.metadata = payload.get("metadata", {})
return self.vectorizer is not None and self.model is not None
def ensure_ready(self, samples: list[dict]) -> dict:
if self.load():
return self.metadata
return self.train(samples)
def predict(self, text: str) -> dict:
if self.vectorizer is None or self.model is None:
raise RuntimeError("模型未加载,请先训练")
cleaned = self.normalize_text(text)
if len(cleaned) < 2:
raise ValueError("待识别文本至少2个字符")
x = self.vectorizer.transform([cleaned])
probs = self.model.predict_proba(x)[0]
classes = list(self.model.classes_)
spam_idx = classes.index("spam") if "spam" in classes else 0
ham_idx = classes.index("ham") if "ham" in classes else 0
spam_prob = float(probs[spam_idx])
ham_prob = float(probs[ham_idx])
prediction = "spam" if spam_prob >= ham_prob else "ham"
reason_tokens = self._extract_reason_tokens(cleaned, classes, x)
confidence = max(spam_prob, ham_prob)
return {
"text": cleaned,
"prediction": prediction,
"prediction_text": "垃圾信息" if prediction == "spam" else "正常信息",
"spam_probability": round(spam_prob, 4),
"ham_probability": round(ham_prob, 4),
"confidence": round(confidence, 4),
"reason_tokens": reason_tokens,
"model_version": self.metadata.get("version", ""),
"trained_at": self.metadata.get("trained_at"),
}
def _extract_reason_tokens(self, text: str, classes: list[str], x_row) -> list[str]:
try:
vocab = self.vectorizer.vocabulary_
feature_names = self.vectorizer.get_feature_names_out()
class_log_prob = self.model.feature_log_prob_
spam_idx = classes.index("spam") if "spam" in classes else 0
ham_idx = classes.index("ham") if "ham" in classes else 0
token_counter = Counter()
for idx in x_row.nonzero()[1]:
token = feature_names[idx]
token_counter[token] += 1
scored = []
for token in token_counter:
idx = vocab.get(token)
if idx is None:
continue
delta = class_log_prob[spam_idx][idx] - class_log_prob[ham_idx][idx]
scored.append((token, delta))
scored.sort(key=lambda row: abs(row[1]), reverse=True)
return [token for token, _ in scored[:5]]
except Exception:
return list(text[:5])
def model_info(self) -> dict:
return {
"ready": self.vectorizer is not None and self.model is not None,
"model_path": str(self.model_path),
"version": self.metadata.get("version", ""),
"trained_at": self.metadata.get("trained_at"),
"sample_count": int(self.metadata.get("sample_count", 0) or 0),
"label_distribution": self.metadata.get("label_distribution", {}),
}

View File

@@ -0,0 +1,252 @@
import hashlib
import json
from datetime import datetime
from pathlib import Path
import joblib
import numpy as np
from sklearn.ensemble import RandomForestRegressor
GOAL_MAP = {
"maintain": 0,
"lose_fat": 1,
"gain_muscle": 2,
"keto": 3,
}
OCCUPATION_MAP = {
"通用": 0,
"student": 1,
"office": 2,
"teacher": 3,
"developer": 4,
"healthcare": 5,
"fitness": 6,
"manual": 7,
}
class RandomForestDietRecommender:
def __init__(self, model_path: str):
self.model_path = Path(model_path)
self.model = None
self.recipe_signature = None
@staticmethod
def _encode_goal(goal: str) -> int:
return GOAL_MAP.get(goal or "maintain", 0)
@staticmethod
def _encode_occupation(occupation: str) -> int:
occupation = occupation or "通用"
if occupation in OCCUPATION_MAP:
return OCCUPATION_MAP[occupation]
return OCCUPATION_MAP["通用"]
def _signature(self, recipes: list) -> str:
raw = [
{
"id": item.id,
"name": item.name,
"calories": item.calories,
"protein": item.protein,
"fat": item.fat,
"carbs": item.carbs,
"fiber": item.fiber,
"updated": item.updated_at.isoformat() if item.updated_at else "",
}
for item in recipes
]
raw_json = json.dumps(raw, ensure_ascii=False, sort_keys=True)
return hashlib.md5(raw_json.encode("utf-8")).hexdigest()
@staticmethod
def _daily_target_kcal(profile: dict) -> float:
goal = profile.get("goal", "maintain")
baseline = 1800 + float(profile.get("exercise_kcal", 0)) * 0.4
if goal == "lose_fat":
baseline *= 0.82
elif goal == "gain_muscle":
baseline *= 1.12
elif goal == "keto":
baseline *= 0.9
return max(baseline, 1200)
@staticmethod
def _heuristic_score(profile: dict, recipe) -> float:
goal = profile.get("goal", "maintain")
daily_target = RandomForestDietRecommender._daily_target_kcal(profile)
target_per_meal = daily_target / 3
cal_gap_ratio = abs(recipe.calories - target_per_meal) / max(target_per_meal, 1)
protein_ratio = recipe.protein / max(recipe.calories, 1)
carbs_ratio = recipe.carbs / max(recipe.calories, 1)
fat_ratio = recipe.fat / max(recipe.calories, 1)
score = 100.0
score -= min(cal_gap_ratio * 55, 50)
if goal == "lose_fat":
score += min(recipe.protein * 0.6, 18)
score -= max((recipe.fat - 20) * 0.7, 0)
score -= max((recipe.carbs - 55) * 0.3, 0)
elif goal == "gain_muscle":
score += min(recipe.protein * 0.8, 26)
score += min(recipe.carbs * 0.2, 10)
elif goal == "keto":
score += min(recipe.fat * 0.4, 18)
score -= max(recipe.carbs - 30, 0) * 0.8
else:
score += min(recipe.fiber * 1.2, 8)
body_fat = float(profile.get("body_fat", 20))
if body_fat > 28:
score -= max(recipe.calories - 520, 0) * 0.03
intake_kcal = float(profile.get("intake_kcal", 1800))
if intake_kcal > daily_target:
score -= max(recipe.calories - 460, 0) * 0.02
score += np.clip((protein_ratio - 0.12) * 100, -8, 8)
score += np.clip((0.08 - carbs_ratio) * 80 if goal == "keto" else 0, -6, 6)
score += np.clip((0.25 - fat_ratio) * 30 if goal == "lose_fat" else 0, -5, 5)
return float(np.clip(score, 1, 100))
def _build_feature(self, profile: dict, recipe) -> list:
return [
float(profile.get("weight", 65)),
float(profile.get("body_fat", 20)),
float(profile.get("exercise_kcal", 300)),
float(profile.get("intake_kcal", 1800)),
float(profile.get("age", 25)),
float(profile.get("height_cm", 170)),
float(self._encode_goal(profile.get("goal", "maintain"))),
float(self._encode_occupation(profile.get("occupation", "通用"))),
float(recipe.calories),
float(recipe.protein),
float(recipe.fat),
float(recipe.carbs),
float(recipe.fiber or 0),
]
def _sample_profiles(self, n: int = 600) -> list:
rng = np.random.default_rng(2026)
goals = list(GOAL_MAP.keys())
occupations = list(OCCUPATION_MAP.keys())
profiles = []
for _ in range(n):
goal = goals[int(rng.integers(0, len(goals)))]
occupation = occupations[int(rng.integers(0, len(occupations)))]
profiles.append(
{
"weight": float(rng.uniform(45, 100)),
"body_fat": float(rng.uniform(10, 38)),
"exercise_kcal": float(rng.uniform(50, 850)),
"intake_kcal": float(rng.uniform(1200, 3200)),
"age": float(rng.uniform(18, 55)),
"height_cm": float(rng.uniform(150, 190)),
"goal": goal,
"occupation": occupation,
}
)
return profiles
def train(self, recipes: list) -> None:
if not recipes:
raise ValueError("训练随机森林前至少需要 1 条食谱数据")
x_rows = []
y_rows = []
sampled_profiles = self._sample_profiles()
for profile in sampled_profiles:
for recipe in recipes:
x_rows.append(self._build_feature(profile, recipe))
y_rows.append(self._heuristic_score(profile, recipe))
x = np.array(x_rows)
y = np.array(y_rows)
model = RandomForestRegressor(
n_estimators=240,
random_state=2026,
max_depth=12,
min_samples_leaf=2,
n_jobs=-1,
)
model.fit(x, y)
self.model = model
self.recipe_signature = self._signature(recipes)
self.model_path.parent.mkdir(parents=True, exist_ok=True)
joblib.dump(
{
"model": model,
"recipe_signature": self.recipe_signature,
"trained_at": datetime.utcnow().isoformat(),
},
self.model_path,
)
def load_or_train(self, recipes: list) -> None:
current_signature = self._signature(recipes)
if self.model_path.exists():
payload = joblib.load(self.model_path)
if payload.get("recipe_signature") == current_signature:
self.model = payload["model"]
self.recipe_signature = current_signature
return
self.train(recipes)
@staticmethod
def _build_reason(profile: dict, recipe, score: float) -> str:
goal = profile.get("goal", "maintain")
if goal == "lose_fat":
return f"热量适中,蛋白质 {recipe.protein}g适合减脂期控热量和保肌。"
if goal == "gain_muscle":
return f"蛋白质与碳水配置较高,适合增肌训练后的恢复。"
if goal == "keto":
return f"碳水 {recipe.carbs}g偏低碳结构适合生酮期参考。"
if score > 80:
return "营养均衡度高,适合作为日常轻食搭配。"
return "综合营养结构较均衡,可作为个性化备选方案。"
def recommend(self, profile: dict, recipes: list, top_k: int = 5) -> list:
if not recipes:
return []
self.load_or_train(recipes)
x = np.array([self._build_feature(profile, recipe) for recipe in recipes])
pred_scores = self.model.predict(x)
result = []
for recipe, score in zip(recipes, pred_scores):
row = recipe.to_dict()
row["rf_score"] = round(float(score), 2)
row["reason"] = self._build_reason(profile, recipe, float(score))
result.append(row)
result.sort(key=lambda item: item["rf_score"], reverse=True)
return result[:top_k]
def merge_profile_with_history(base_profile: dict, history: list) -> dict:
if not history:
return base_profile
weights = [item.weight for item in history]
body_fats = [item.body_fat for item in history]
exercise = [item.exercise_kcal for item in history]
intake = [item.intake_kcal for item in history]
merged = dict(base_profile)
merged["weight"] = float(np.mean(weights))
merged["body_fat"] = float(np.mean(body_fats))
merged["exercise_kcal"] = float(np.mean(exercise))
merged["intake_kcal"] = float(np.mean(intake))
return merged

180
backend/app/models.py Normal file
View File

@@ -0,0 +1,180 @@
from datetime import datetime
from werkzeug.security import check_password_hash, generate_password_hash
from app.extensions import db
class User(db.Model):
__tablename__ = "users"
id = db.Column(db.Integer, primary_key=True)
username = db.Column(db.String(64), unique=True, nullable=False, index=True)
password_hash = db.Column(db.String(255), nullable=False)
nickname = db.Column(db.String(64), nullable=False)
company = db.Column(db.String(128), default="")
title = db.Column(db.String(64), default="")
phone = db.Column(db.String(32), default="")
is_admin = db.Column(db.Boolean, default=False)
created_at = db.Column(db.DateTime, default=datetime.utcnow)
updated_at = db.Column(db.DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
prediction_logs = db.relationship("SpamPredictionLog", backref="user", lazy=True, cascade="all, delete-orphan")
training_samples = db.relationship("SpamTrainingSample", backref="creator", lazy=True, foreign_keys="SpamTrainingSample.created_by")
sent_posts = db.relationship("ContentPost", backref="author", lazy=True, foreign_keys="ContentPost.user_id")
received_posts = db.relationship("ContentPost", backref="recipient", lazy=True, foreign_keys="ContentPost.recipient_user_id")
reviewed_posts = db.relationship("ContentPost", backref="reviewer", lazy=True, foreign_keys="ContentPost.manual_review_by")
def set_password(self, password: str) -> None:
self.password_hash = generate_password_hash(password)
def check_password(self, password: str) -> bool:
return check_password_hash(self.password_hash, password)
def to_dict(self) -> dict:
return {
"id": self.id,
"username": self.username,
"nickname": self.nickname,
"company": self.company,
"title": self.title,
"phone": self.phone,
"is_admin": self.is_admin,
"created_at": self.created_at.isoformat() if self.created_at else None,
"updated_at": self.updated_at.isoformat() if self.updated_at else None,
}
class SpamTrainingSample(db.Model):
__tablename__ = "spam_training_samples"
id = db.Column(db.Integer, primary_key=True)
text = db.Column(db.Text, nullable=False)
label = db.Column(db.String(16), nullable=False, index=True) # spam | ham
source = db.Column(db.String(32), default="seed") # seed | import | feedback | manual_review
created_by = db.Column(db.Integer, db.ForeignKey("users.id"), nullable=True, index=True)
is_active = db.Column(db.Boolean, default=True)
created_at = db.Column(db.DateTime, default=datetime.utcnow)
updated_at = db.Column(db.DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
def to_dict(self) -> dict:
return {
"id": self.id,
"text": self.text,
"label": self.label,
"source": self.source,
"created_by": self.created_by,
"is_active": self.is_active,
"created_at": self.created_at.isoformat() if self.created_at else None,
"updated_at": self.updated_at.isoformat() if self.updated_at else None,
}
class SpamPredictionLog(db.Model):
__tablename__ = "spam_prediction_logs"
id = db.Column(db.Integer, primary_key=True)
user_id = db.Column(db.Integer, db.ForeignKey("users.id"), nullable=False, index=True)
text = db.Column(db.Text, nullable=False)
prediction = db.Column(db.String(16), nullable=False) # spam | ham
spam_probability = db.Column(db.Float, nullable=False)
ham_probability = db.Column(db.Float, nullable=False)
confidence = db.Column(db.Float, nullable=False)
reason_tokens = db.Column(db.JSON, default=list)
model_version = db.Column(db.String(64), default="")
created_at = db.Column(db.DateTime, default=datetime.utcnow, index=True)
def to_dict(self) -> dict:
return {
"id": self.id,
"user_id": self.user_id,
"text": self.text,
"prediction": self.prediction,
"spam_probability": round(float(self.spam_probability), 4),
"ham_probability": round(float(self.ham_probability), 4),
"confidence": round(float(self.confidence), 4),
"reason_tokens": self.reason_tokens or [],
"model_version": self.model_version,
"created_at": self.created_at.isoformat() if self.created_at else None,
}
class DetectionConfig(db.Model):
__tablename__ = "detection_configs"
id = db.Column(db.Integer, primary_key=True)
spam_threshold = db.Column(db.Float, nullable=False, default=0.75)
updated_by = db.Column(db.Integer, db.ForeignKey("users.id"), nullable=True)
updated_at = db.Column(db.DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
def to_dict(self) -> dict:
return {
"id": self.id,
"spam_threshold": round(float(self.spam_threshold), 4),
"updated_by": self.updated_by,
"updated_at": self.updated_at.isoformat() if self.updated_at else None,
}
class ContentPost(db.Model):
__tablename__ = "content_posts"
id = db.Column(db.Integer, primary_key=True)
user_id = db.Column(db.Integer, db.ForeignKey("users.id"), nullable=False, index=True)
recipient_user_id = db.Column(db.Integer, db.ForeignKey("users.id"), nullable=True, index=True)
text = db.Column(db.Text, nullable=False)
visibility = db.Column(db.String(16), nullable=False, default="public") # public | private | direct
status = db.Column(db.String(16), nullable=False, default="published") # published | blocked
prediction = db.Column(db.String(16), nullable=False, default="ham")
spam_probability = db.Column(db.Float, nullable=False, default=0)
ham_probability = db.Column(db.Float, nullable=False, default=0)
confidence = db.Column(db.Float, nullable=False, default=0)
threshold = db.Column(db.Float, nullable=False, default=0.75)
reason_tokens = db.Column(db.JSON, default=list)
model_version = db.Column(db.String(64), default="")
manual_review_status = db.Column(db.String(32), nullable=False, default="none") # none | pending | confirmed_spam | approved_ham
manual_review_by = db.Column(db.Integer, db.ForeignKey("users.id"), nullable=True)
manual_review_note = db.Column(db.String(255), default="")
manual_review_at = db.Column(db.DateTime, nullable=True)
appeal_status = db.Column(db.String(16), nullable=False, default="none") # none | pending | approved | rejected
appeal_reason = db.Column(db.String(255), default="")
appeal_admin_note = db.Column(db.String(255), default="")
appeal_submitted_at = db.Column(db.DateTime, nullable=True)
appeal_processed_at = db.Column(db.DateTime, nullable=True)
appeal_processed_by = db.Column(db.Integer, db.ForeignKey("users.id"), nullable=True)
created_at = db.Column(db.DateTime, default=datetime.utcnow, index=True)
updated_at = db.Column(db.DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
def to_dict(self) -> dict:
return {
"id": self.id,
"user_id": self.user_id,
"recipient_user_id": self.recipient_user_id,
"text": self.text,
"visibility": self.visibility,
"status": self.status,
"prediction": self.prediction,
"spam_probability": round(float(self.spam_probability), 4),
"ham_probability": round(float(self.ham_probability), 4),
"confidence": round(float(self.confidence), 4),
"threshold": round(float(self.threshold), 4),
"reason_tokens": self.reason_tokens or [],
"model_version": self.model_version,
"manual_review_status": self.manual_review_status,
"manual_review_by": self.manual_review_by,
"manual_review_note": self.manual_review_note,
"manual_review_at": self.manual_review_at.isoformat() if self.manual_review_at else None,
"appeal_status": self.appeal_status,
"appeal_reason": self.appeal_reason,
"appeal_admin_note": self.appeal_admin_note,
"appeal_submitted_at": self.appeal_submitted_at.isoformat() if self.appeal_submitted_at else None,
"appeal_processed_at": self.appeal_processed_at.isoformat() if self.appeal_processed_at else None,
"appeal_processed_by": self.appeal_processed_by,
"created_at": self.created_at.isoformat() if self.created_at else None,
"updated_at": self.updated_at.isoformat() if self.updated_at else None,
}

View File

@@ -0,0 +1 @@

Binary file not shown.

Binary file not shown.

View File

@@ -0,0 +1,151 @@
import json
from pathlib import Path
from typing import Any
import requests
class LLMConfigManager:
def __init__(self, config_path: str):
self.config_path = Path(config_path)
def load(self) -> dict:
if not self.config_path.exists():
return {}
with self.config_path.open("r", encoding="utf-8-sig") as file:
return json.load(file)
def _join_url(base_url: str, path: str) -> str:
return f"{base_url.rstrip('/')}/{path.lstrip('/')}"
def _timeout(config: dict) -> int:
section = config.get("request", {}) if isinstance(config, dict) else {}
try:
return int(section.get("timeout_seconds", 45))
except Exception:
return 45
def _extract_openai_text(payload: Any) -> str:
if isinstance(payload, dict):
choices = payload.get("choices")
if isinstance(choices, list) and choices:
msg = choices[0].get("message") if isinstance(choices[0], dict) else {}
if isinstance(msg, dict) and msg.get("content"):
return str(msg["content"])
if payload.get("answer"):
return str(payload["answer"])
if payload.get("data"):
return json.dumps(payload["data"], ensure_ascii=False)
return json.dumps(payload, ensure_ascii=False) if isinstance(payload, (dict, list)) else str(payload)
def ask_fastgpt(config: dict, prompt: str, context: str = "", custom_uid: str = "") -> dict:
section = config.get("fastgpt", {})
base_url = section.get("base_url", "").strip()
api_key = section.get("api_key", "").strip()
chat_id = str(section.get("chat_id", "111"))
model = section.get("model", "")
timeout = _timeout(config)
if not base_url or not api_key:
return {"ok": False, "message": "FastGPT 未配置 base_url 或 api_key"}
url = _join_url(base_url, "/v1/chat/completions")
headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
}
content = prompt if not context else f"请基于以下本地知识库内容回答。\n\n知识库:\n{context}\n\n问题: {prompt}"
messages = []
system_prompt = section.get("system_prompt")
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
messages.append({"role": "user", "content": content})
use_custom_uid = custom_uid or section.get("custom_uid", "")
payload = {
"chatId": chat_id,
"stream": False,
"detail": False,
"messages": messages,
}
if model:
payload["model"] = model
if use_custom_uid:
payload["customUid"] = use_custom_uid
try:
response = requests.post(url, headers=headers, json=payload, timeout=timeout, allow_redirects=True)
if response.status_code >= 400:
return {
"ok": False,
"message": f"FastGPT 调用失败: HTTP {response.status_code}",
"detail": response.text,
}
data = response.json()
return {
"ok": True,
"provider": "fastgpt",
"raw": data,
"answer": _extract_openai_text(data),
}
except Exception as exc:
return {
"ok": False,
"message": "FastGPT 请求异常",
"detail": str(exc),
}
def ask_dify(config: dict, prompt: str, context: str = "", user: str = "anonymous") -> dict:
section = config.get("dify", {})
base_url = section.get("base_url", "").strip()
api_key = section.get("api_key", "").strip()
timeout = _timeout(config)
if not base_url or not api_key:
return {"ok": False, "message": "Dify 未配置 base_url 或 api_key"}
url = _join_url(base_url, "/v1/chat-messages")
headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
}
query = prompt if not context else f"知识库上下文:\n{context}\n\n用户问题:\n{prompt}"
payload = {
"inputs": {},
"query": query,
"response_mode": "blocking",
"conversation_id": "",
"user": user,
}
try:
response = requests.post(url, headers=headers, json=payload, timeout=timeout, allow_redirects=True)
if response.status_code >= 400:
return {
"ok": False,
"message": f"Dify 调用失败: HTTP {response.status_code}",
"detail": response.text,
}
data = response.json()
answer = data.get("answer") or _extract_openai_text(data)
return {
"ok": True,
"provider": "dify",
"raw": data,
"answer": answer,
}
except Exception as exc:
return {
"ok": False,
"message": "Dify 请求异常",
"detail": str(exc),
}

View File

@@ -0,0 +1,62 @@
import json
import re
from pathlib import Path
class LocalKnowledgeRetriever:
def __init__(self, kb_path: str):
self.kb_path = Path(kb_path)
self.documents = []
self._load()
@staticmethod
def _tokenize(text: str) -> set:
text = text or ""
words = set(re.findall(r"[A-Za-z0-9_]+", text.lower()))
cjk_chunks = re.findall(r"[\u4e00-\u9fff]+", text)
for chunk in cjk_chunks:
if len(chunk) <= 2:
words.add(chunk)
else:
for idx in range(len(chunk) - 1):
words.add(chunk[idx : idx + 2])
return words
def _load(self):
if not self.kb_path.exists():
self.documents = []
return
with self.kb_path.open("r", encoding="utf-8-sig") as file:
rows = json.load(file)
self.documents = rows if isinstance(rows, list) else []
def reload(self):
self._load()
def search(self, query: str, top_k: int = 3) -> list:
query_tokens = self._tokenize(query)
if not query_tokens:
return []
scored = []
for item in self.documents:
content = f"{item.get('question', '')} {item.get('answer', '')} {' '.join(item.get('tags', []))}"
doc_tokens = self._tokenize(content)
if not doc_tokens:
continue
overlap = len(query_tokens & doc_tokens)
if overlap == 0:
continue
score = overlap / max(len(query_tokens), 1)
scored.append(
{
"score": round(score, 4),
"question": item.get("question", ""),
"answer": item.get("answer", ""),
"tags": item.get("tags", []),
"source": item.get("source", "本地知识库"),
}
)
scored.sort(key=lambda x: x["score"], reverse=True)
return scored[:top_k]

View File

@@ -0,0 +1 @@

View File

@@ -0,0 +1,429 @@
from collections import Counter
from datetime import datetime, timedelta
from flask import Blueprint, current_app, request
from sqlalchemy import func, or_
from app.extensions import db
from app.ml.naive_bayes_classifier import NaiveBayesSpamClassifier
from app.models import ContentPost, DetectionConfig, SpamPredictionLog, SpamTrainingSample, User
from app.utils.auth import admin_required, current_user
from app.utils.response import fail, ok
admin_bp = Blueprint("admin", __name__)
def _day_key(day_value) -> str:
if hasattr(day_value, "isoformat"):
return day_value.isoformat()
return str(day_value)
def _tokenize(text: str) -> list[str]:
content = (text or "").strip()
if len(content) <= 1:
return []
tokens = []
for i in range(len(content) - 1):
token = content[i : i + 2]
if token.strip():
tokens.append(token)
return tokens
def _get_or_create_config() -> DetectionConfig:
cfg = DetectionConfig.query.order_by(DetectionConfig.id.asc()).first()
if cfg:
return cfg
cfg = DetectionConfig(spam_threshold=0.75)
db.session.add(cfg)
db.session.commit()
return cfg
def _serialize_post(item: ContentPost) -> dict:
row = item.to_dict()
row["username"] = item.author.username if item.author else ""
row["nickname"] = item.author.nickname if item.author else ""
row["recipient_username"] = item.recipient.username if item.recipient else ""
row["recipient_nickname"] = item.recipient.nickname if item.recipient else ""
row["reviewer_username"] = item.reviewer.username if item.reviewer else ""
return row
def _upsert_manual_sample(text: str, label: str, admin_id: int | None) -> None:
existed = SpamTrainingSample.query.filter_by(text=text, label=label).first()
if existed:
existed.is_active = True
existed.source = existed.source or "manual_review"
return
row = SpamTrainingSample(
text=text,
label=label,
source="manual_review",
created_by=admin_id,
is_active=True,
)
db.session.add(row)
@admin_bp.get("/stats")
@admin_required
def stats():
user_count = User.query.count()
sample_count = SpamTrainingSample.query.count()
predict_count = SpamPredictionLog.query.count()
post_count = ContentPost.query.count()
blocked_count = ContentPost.query.filter_by(status="blocked").count()
published_count = ContentPost.query.filter_by(status="published").count()
pending_appeal_count = ContentPost.query.filter_by(appeal_status="pending").count()
now = datetime.utcnow()
week_ago = now - timedelta(days=6)
trend_rows = (
db.session.query(func.date(ContentPost.created_at), func.count(ContentPost.id))
.filter(ContentPost.created_at >= week_ago)
.group_by(func.date(ContentPost.created_at))
.all()
)
blocked_7d_count = ContentPost.query.filter(ContentPost.created_at >= week_ago, ContentPost.status == "blocked").count() or 0
total_7d_count = ContentPost.query.filter(ContentPost.created_at >= week_ago).count() or 0
day_map = {_day_key(day): int(count or 0) for day, count in trend_rows}
trend = []
today = now.date()
for offset in range(6, -1, -1):
day = today - timedelta(days=offset)
key = day.isoformat()
trend.append({"date": key, "label": day.strftime("%m-%d"), "post_count": day_map.get(key, 0)})
source_rows = (
db.session.query(SpamTrainingSample.source, func.count(SpamTrainingSample.id))
.group_by(SpamTrainingSample.source)
.order_by(func.count(SpamTrainingSample.id).desc())
.all()
)
source_dist = [{"name": (name or "unknown"), "value": int(value or 0)} for name, value in source_rows]
blocked_logs = (
ContentPost.query.filter(ContentPost.created_at >= week_ago, ContentPost.status == "blocked")
.order_by(ContentPost.id.desc())
.limit(1000)
.all()
)
token_counter = Counter()
for row in blocked_logs:
token_counter.update(_tokenize(row.text))
top_keywords = [{"token": token, "count": count} for token, count in token_counter.most_common(12)]
cfg = _get_or_create_config()
clf = NaiveBayesSpamClassifier(current_app.config["NB_MODEL_PATH"])
clf.load()
return ok(
{
"user_count": user_count,
"sample_count": sample_count,
"predict_count": predict_count,
"post_count": post_count,
"blocked_count": blocked_count,
"published_count": published_count,
"pending_appeal_count": pending_appeal_count,
"blocked_ratio_7d": round(blocked_7d_count / total_7d_count, 4) if total_7d_count else 0,
"total_7d": total_7d_count,
"trend_7d": trend,
"source_distribution": source_dist,
"top_keywords": top_keywords,
"model_info": clf.model_info(),
"threshold": cfg.to_dict(),
}
)
@admin_bp.get("/detection/threshold")
@admin_required
def get_threshold():
return ok(_get_or_create_config().to_dict())
@admin_bp.put("/detection/threshold")
@admin_required
def set_threshold():
payload = request.get_json(silent=True) or {}
try:
threshold = float(payload.get("spam_threshold"))
except Exception:
return fail("spam_threshold 必须是数字", 400)
if threshold < 0.01 or threshold > 0.99:
return fail("spam_threshold 必须在 0.01 到 0.99 之间", 400)
cfg = _get_or_create_config()
admin = current_user()
cfg.spam_threshold = threshold
cfg.updated_by = admin.id if admin else None
db.session.commit()
return ok(cfg.to_dict(), "阈值更新成功")
@admin_bp.get("/intercepts")
@admin_required
def list_intercepts():
keyword = (request.args.get("keyword") or "").strip()
status = (request.args.get("status") or "blocked").strip().lower()
review_status = (request.args.get("review_status") or "").strip().lower()
page = max(int(request.args.get("page", 1) or 1), 1)
page_size = min(max(int(request.args.get("page_size", 20) or 20), 1), 100)
query = ContentPost.query
if keyword:
query = query.filter(ContentPost.text.like(f"%{keyword}%"))
if status in {"blocked", "published"}:
query = query.filter(ContentPost.status == status)
if review_status in {"none", "pending", "confirmed_spam", "approved_ham"}:
query = query.filter(ContentPost.manual_review_status == review_status)
pagination = query.order_by(ContentPost.id.desc()).paginate(page=page, per_page=page_size, error_out=False)
return ok(
{
"items": [_serialize_post(item) for item in pagination.items],
"total": pagination.total,
"page": page,
"page_size": page_size,
}
)
@admin_bp.put("/intercepts/<int:post_id>/review")
@admin_required
def review_intercept(post_id: int):
row = ContentPost.query.get(post_id)
if not row:
return fail("记录不存在", 404)
payload = request.get_json(silent=True) or {}
decision = (payload.get("decision") or "").strip().lower()
note = (payload.get("note") or "").strip()
if decision not in {"spam", "ham"}:
return fail("decision 必须是 spam 或 ham", 400)
admin = current_user()
now = datetime.utcnow()
row.manual_review_by = admin.id if admin else None
row.manual_review_note = note
row.manual_review_at = now
if decision == "spam":
row.status = "blocked"
row.prediction = "spam"
row.manual_review_status = "confirmed_spam"
if row.appeal_status == "pending":
row.appeal_status = "rejected"
row.appeal_admin_note = note or "人工复核确认为垃圾信息"
row.appeal_processed_at = now
row.appeal_processed_by = admin.id if admin else None
_upsert_manual_sample(row.text, "spam", admin.id if admin else None)
else:
row.status = "published"
row.prediction = "ham"
row.manual_review_status = "approved_ham"
if row.appeal_status == "pending":
row.appeal_status = "approved"
row.appeal_admin_note = note or "人工复核后解除拦截"
row.appeal_processed_at = now
row.appeal_processed_by = admin.id if admin else None
_upsert_manual_sample(row.text, "ham", admin.id if admin else None)
db.session.commit()
return ok(_serialize_post(row), "人工复核完成")
@admin_bp.get("/appeals")
@admin_required
def list_appeals():
keyword = (request.args.get("keyword") or "").strip()
status = (request.args.get("status") or "pending").strip().lower()
page = max(int(request.args.get("page", 1) or 1), 1)
page_size = min(max(int(request.args.get("page_size", 20) or 20), 1), 100)
query = ContentPost.query.filter(ContentPost.appeal_status != "none")
if keyword:
query = query.filter(
or_(
ContentPost.text.like(f"%{keyword}%"),
ContentPost.appeal_reason.like(f"%{keyword}%"),
ContentPost.appeal_admin_note.like(f"%{keyword}%"),
)
)
if status in {"pending", "approved", "rejected"}:
query = query.filter(ContentPost.appeal_status == status)
pagination = query.order_by(ContentPost.id.desc()).paginate(page=page, per_page=page_size, error_out=False)
return ok(
{
"items": [_serialize_post(item) for item in pagination.items],
"total": pagination.total,
"page": page,
"page_size": page_size,
}
)
@admin_bp.put("/appeals/<int:post_id>/process")
@admin_required
def process_appeal(post_id: int):
row = ContentPost.query.get(post_id)
if not row:
return fail("记录不存在", 404)
if row.appeal_status != "pending":
return fail("该申诉不在待处理状态", 400)
payload = request.get_json(silent=True) or {}
action = (payload.get("action") or "").strip().lower()
note = (payload.get("note") or "").strip()
if action not in {"approve", "reject"}:
return fail("action 必须是 approve 或 reject", 400)
admin = current_user()
now = datetime.utcnow()
row.appeal_status = "approved" if action == "approve" else "rejected"
row.appeal_admin_note = note
row.appeal_processed_at = now
row.appeal_processed_by = admin.id if admin else None
row.manual_review_by = admin.id if admin else None
row.manual_review_note = note
row.manual_review_at = now
if action == "approve":
row.status = "published"
row.prediction = "ham"
row.manual_review_status = "approved_ham"
_upsert_manual_sample(row.text, "ham", admin.id if admin else None)
else:
row.status = "blocked"
row.prediction = "spam"
row.manual_review_status = "confirmed_spam"
_upsert_manual_sample(row.text, "spam", admin.id if admin else None)
db.session.commit()
return ok(_serialize_post(row), "申诉处理完成")
@admin_bp.get("/users")
@admin_required
def list_users():
keyword = (request.args.get("keyword") or "").strip()
page = max(int(request.args.get("page", 1) or 1), 1)
page_size = min(max(int(request.args.get("page_size", 20) or 20), 1), 100)
query = User.query
if keyword:
query = query.filter(User.username.like(f"%{keyword}%") | User.nickname.like(f"%{keyword}%"))
pagination = query.order_by(User.id.desc()).paginate(page=page, per_page=page_size, error_out=False)
return ok(
{
"items": [item.to_dict() for item in pagination.items],
"total": pagination.total,
"page": page,
"page_size": page_size,
}
)
@admin_bp.post("/users/import")
@admin_required
def import_users():
payload = request.get_json(silent=True) or {}
items = payload.get("items") or []
if not isinstance(items, list) or not items:
return fail("items 必须是非空数组", 400)
created = 0
updated = 0
for row in items:
username = (row.get("username") or "").strip()
if len(username) < 3:
continue
user = User.query.filter_by(username=username).first()
if not user:
user = User(
username=username,
nickname=(row.get("nickname") or username).strip(),
company=(row.get("company") or "").strip(),
title=(row.get("title") or "").strip(),
phone=(row.get("phone") or "").strip(),
is_admin=bool(row.get("is_admin", False)),
)
user.set_password(row.get("password") or "123456")
db.session.add(user)
created += 1
continue
user.nickname = (row.get("nickname") or user.nickname).strip()
user.company = (row.get("company") or user.company).strip()
user.title = (row.get("title") or user.title).strip()
user.phone = (row.get("phone") or user.phone).strip()
if "is_admin" in row:
user.is_admin = bool(row.get("is_admin"))
if row.get("password"):
user.set_password(row["password"])
updated += 1
db.session.commit()
return ok({"created": created, "updated": updated}, "用户导入完成")
@admin_bp.put("/users/<int:user_id>")
@admin_required
def update_user(user_id: int):
user = User.query.get(user_id)
if not user:
return fail("用户不存在", 404)
payload = request.get_json(silent=True) or {}
if "nickname" in payload:
user.nickname = (payload.get("nickname") or user.nickname).strip()
if "company" in payload:
user.company = (payload.get("company") or "").strip()
if "title" in payload:
user.title = (payload.get("title") or "").strip()
if "phone" in payload:
user.phone = (payload.get("phone") or "").strip()
if "is_admin" in payload:
user.is_admin = bool(payload.get("is_admin"))
if payload.get("password"):
if len(payload["password"]) < 6:
return fail("密码至少6位", 400)
user.set_password(payload["password"])
db.session.commit()
return ok(user.to_dict(), "用户更新成功")
@admin_bp.delete("/users/<int:user_id>")
@admin_required
def delete_user(user_id: int):
user = User.query.get(user_id)
if not user:
return fail("用户不存在", 404)
if user.is_admin and User.query.filter_by(is_admin=True).count() <= 1:
return fail("至少保留一个管理员账号", 400)
db.session.delete(user)
db.session.commit()
return ok({}, "用户已删除")

View File

@@ -0,0 +1,67 @@
from flask import Blueprint, request
from flask_jwt_extended import create_access_token, jwt_required
from app.extensions import db
from app.models import User
from app.utils.auth import current_user
from app.utils.response import fail, ok
auth_bp = Blueprint("auth", __name__)
@auth_bp.post("/register")
def register():
payload = request.get_json(silent=True) or {}
username = (payload.get("username") or "").strip()
password = payload.get("password") or ""
nickname = (payload.get("nickname") or username).strip()
if len(username) < 3:
return fail("用户名至少3位", 400)
if len(password) < 6:
return fail("密码至少6位", 400)
if User.query.filter_by(username=username).first():
return fail("用户名已存在", 409)
user = User(
username=username,
nickname=nickname or username,
company=(payload.get("company") or "").strip(),
title=(payload.get("title") or "").strip(),
phone=(payload.get("phone") or "").strip(),
is_admin=bool(payload.get("is_admin", False)),
)
user.set_password(password)
db.session.add(user)
db.session.commit()
return ok(user.to_dict(), "注册成功")
@auth_bp.post("/login")
def login():
payload = request.get_json(silent=True) or {}
username = (payload.get("username") or "").strip()
password = payload.get("password") or ""
user = User.query.filter_by(username=username).first()
if not user or not user.check_password(password):
return fail("用户名或密码错误", 401)
access_token = create_access_token(
identity=str(user.id),
additional_claims={"is_admin": bool(user.is_admin), "username": user.username},
)
return ok({"token": access_token, "user": user.to_dict()}, "登录成功")
@auth_bp.get("/me")
@jwt_required()
def me():
user = current_user()
if not user:
return fail("用户不存在", 404)
return ok(user.to_dict())

View File

@@ -0,0 +1,360 @@
from datetime import datetime
from flask import Blueprint, current_app, request
from flask_jwt_extended import jwt_required
from app.extensions import db
from app.ml.naive_bayes_classifier import NaiveBayesSpamClassifier
from app.models import ContentPost, DetectionConfig, SpamPredictionLog, SpamTrainingSample, User
from app.utils.auth import current_user
from app.utils.response import fail, ok
content_bp = Blueprint("content", __name__)
def _classifier() -> NaiveBayesSpamClassifier:
return NaiveBayesSpamClassifier(current_app.config["NB_MODEL_PATH"])
def _active_samples() -> list[dict]:
rows = SpamTrainingSample.query.filter_by(is_active=True).order_by(SpamTrainingSample.id.asc()).all()
return [{"text": row.text, "label": row.label} for row in rows]
def _ensure_ready() -> NaiveBayesSpamClassifier:
clf = _classifier()
clf.ensure_ready(_active_samples())
return clf
def _get_config() -> DetectionConfig:
cfg = DetectionConfig.query.order_by(DetectionConfig.id.asc()).first()
if cfg:
return cfg
cfg = DetectionConfig(spam_threshold=0.75)
db.session.add(cfg)
db.session.commit()
return cfg
def _serialize_post(row: ContentPost) -> dict:
payload = row.to_dict()
payload["username"] = row.author.username if row.author else ""
payload["nickname"] = row.author.nickname if row.author else ""
payload["recipient_username"] = row.recipient.username if row.recipient else ""
payload["recipient_nickname"] = row.recipient.nickname if row.recipient else ""
payload["reviewer_username"] = row.reviewer.username if row.reviewer else ""
return payload
def _resolve_visibility(value: str) -> str:
key = (value or "public").strip().lower()
return key if key in {"public", "private", "direct"} else "public"
def _resolve_recipient(payload: dict, visibility: str, current_user_id: int):
if visibility != "direct":
return None, None
recipient = None
raw_id = payload.get("recipient_user_id")
username = (payload.get("recipient_username") or "").strip()
if raw_id is not None and str(raw_id).strip() != "":
try:
recipient = User.query.get(int(raw_id))
except Exception:
return None, "recipient_user_id 无效"
elif username:
recipient = User.query.filter_by(username=username).first()
if not recipient:
return None, "私信发布必须指定有效接收人"
if recipient.id == current_user_id:
return None, "不能给自己发送私信"
return recipient, None
def _predict_and_decide(text: str) -> tuple[dict, float, bool]:
clf = _ensure_ready()
result = clf.predict(text)
threshold = float(_get_config().spam_threshold)
blocked = float(result["spam_probability"]) >= threshold
return result, threshold, blocked
@content_bp.post("/publish")
@jwt_required()
def publish_text():
user = current_user()
if not user:
return fail("用户不存在", 404)
payload = request.get_json(silent=True) or {}
text = (payload.get("text") or "").strip()
visibility = _resolve_visibility(payload.get("visibility"))
if len(text) < 2:
return fail("发布文本至少2个字符", 400)
recipient, err = _resolve_recipient(payload, visibility, user.id)
if err:
return fail(err, 400)
result, threshold, blocked = _predict_and_decide(text)
post = ContentPost(
user_id=user.id,
recipient_user_id=recipient.id if recipient else None,
text=result["text"],
visibility=visibility,
status="blocked" if blocked else "published",
prediction=result["prediction"],
spam_probability=result["spam_probability"],
ham_probability=result["ham_probability"],
confidence=result["confidence"],
threshold=threshold,
reason_tokens=result["reason_tokens"],
model_version=result.get("model_version", ""),
manual_review_status="pending" if blocked else "none",
)
detect_log = SpamPredictionLog(
user_id=user.id,
text=result["text"],
prediction=result["prediction"],
spam_probability=result["spam_probability"],
ham_probability=result["ham_probability"],
confidence=result["confidence"],
reason_tokens=result["reason_tokens"],
model_version=result.get("model_version", ""),
)
db.session.add(post)
db.session.add(detect_log)
db.session.commit()
feedback = "发布成功" if not blocked else "疑似垃圾信息,系统已拦截,可提交申诉"
return ok(
{
"publish_allowed": not blocked,
"action": "published" if not blocked else "blocked",
"feedback": feedback,
"post": _serialize_post(post),
"detect": result,
},
feedback,
)
@content_bp.put("/posts/<int:post_id>")
@jwt_required()
def edit_post(post_id: int):
user = current_user()
if not user:
return fail("用户不存在", 404)
post = ContentPost.query.filter_by(id=post_id, user_id=user.id).first()
if not post:
return fail("发布记录不存在", 404)
payload = request.get_json(silent=True) or {}
text = (payload.get("text") or post.text).strip()
visibility = _resolve_visibility(payload.get("visibility") or post.visibility)
if len(text) < 2:
return fail("发布文本至少2个字符", 400)
recipient, err = _resolve_recipient(payload, visibility, user.id)
if err:
return fail(err, 400)
result, threshold, blocked = _predict_and_decide(text)
post.text = result["text"]
post.visibility = visibility
post.recipient_user_id = recipient.id if recipient else None
post.status = "blocked" if blocked else "published"
post.prediction = result["prediction"]
post.spam_probability = result["spam_probability"]
post.ham_probability = result["ham_probability"]
post.confidence = result["confidence"]
post.threshold = threshold
post.reason_tokens = result["reason_tokens"]
post.model_version = result.get("model_version", "")
post.manual_review_status = "pending" if blocked else "none"
post.manual_review_by = None
post.manual_review_note = ""
post.manual_review_at = None
post.appeal_status = "none"
post.appeal_reason = ""
post.appeal_admin_note = ""
post.appeal_submitted_at = None
post.appeal_processed_at = None
post.appeal_processed_by = None
db.session.commit()
feedback = "更新并重新发布成功" if not blocked else "更新后触发拦截,可提交申诉"
return ok(
{
"publish_allowed": not blocked,
"action": "published" if not blocked else "blocked",
"feedback": feedback,
"post": _serialize_post(post),
"detect": result,
},
feedback,
)
@content_bp.get("/posts/history")
@jwt_required()
def my_posts():
user = current_user()
if not user:
return fail("用户不存在", 404)
status = (request.args.get("status") or "").strip().lower()
visibility = (request.args.get("visibility") or "").strip().lower()
page = max(int(request.args.get("page", 1) or 1), 1)
page_size = min(max(int(request.args.get("page_size", 20) or 20), 1), 100)
query = ContentPost.query.filter_by(user_id=user.id)
if status in {"published", "blocked"}:
query = query.filter(ContentPost.status == status)
if visibility in {"public", "private", "direct"}:
query = query.filter(ContentPost.visibility == visibility)
pagination = query.order_by(ContentPost.id.desc()).paginate(page=page, per_page=page_size, error_out=False)
return ok(
{
"items": [_serialize_post(item) for item in pagination.items],
"total": pagination.total,
"page": page,
"page_size": page_size,
}
)
@content_bp.get("/posts/inbox")
@jwt_required()
def my_inbox():
user = current_user()
if not user:
return fail("用户不存在", 404)
page = max(int(request.args.get("page", 1) or 1), 1)
page_size = min(max(int(request.args.get("page_size", 20) or 20), 1), 100)
pagination = (
ContentPost.query.filter_by(recipient_user_id=user.id, visibility="direct", status="published")
.order_by(ContentPost.id.desc())
.paginate(page=page, per_page=page_size, error_out=False)
)
return ok(
{
"items": [_serialize_post(item) for item in pagination.items],
"total": pagination.total,
"page": page,
"page_size": page_size,
}
)
@content_bp.delete("/posts/<int:post_id>")
@jwt_required()
def delete_post(post_id: int):
user = current_user()
if not user:
return fail("用户不存在", 404)
row = ContentPost.query.filter_by(id=post_id, user_id=user.id).first()
if not row:
return fail("记录不存在", 404)
db.session.delete(row)
db.session.commit()
return ok({}, "记录已删除")
@content_bp.post("/posts/<int:post_id>/appeal")
@jwt_required()
def submit_appeal(post_id: int):
user = current_user()
if not user:
return fail("用户不存在", 404)
post = ContentPost.query.filter_by(id=post_id, user_id=user.id).first()
if not post:
return fail("发布记录不存在", 404)
if post.status != "blocked":
return fail("仅被拦截的信息可申诉", 400)
payload = request.get_json(silent=True) or {}
reason = (payload.get("reason") or "").strip()
if len(reason) < 2:
return fail("申诉理由至少2个字符", 400)
if post.appeal_status == "pending":
return fail("该记录已在申诉处理中", 400)
post.appeal_status = "pending"
post.appeal_reason = reason
post.appeal_submitted_at = datetime.utcnow()
post.appeal_admin_note = ""
post.appeal_processed_at = None
post.appeal_processed_by = None
post.manual_review_status = "pending"
db.session.commit()
return ok(_serialize_post(post), "申诉提交成功")
@content_bp.get("/appeals/my")
@jwt_required()
def my_appeals():
user = current_user()
if not user:
return fail("用户不存在", 404)
page = max(int(request.args.get("page", 1) or 1), 1)
page_size = min(max(int(request.args.get("page_size", 20) or 20), 1), 100)
pagination = (
ContentPost.query.filter(ContentPost.user_id == user.id, ContentPost.appeal_status != "none")
.order_by(ContentPost.id.desc())
.paginate(page=page, per_page=page_size, error_out=False)
)
return ok(
{
"items": [_serialize_post(item) for item in pagination.items],
"total": pagination.total,
"page": page,
"page_size": page_size,
}
)
@content_bp.get("/posts/public")
@jwt_required(optional=True)
def public_feed():
page = max(int(request.args.get("page", 1) or 1), 1)
page_size = min(max(int(request.args.get("page_size", 20) or 20), 1), 100)
pagination = (
ContentPost.query.filter_by(visibility="public", status="published")
.order_by(ContentPost.id.desc())
.paginate(page=page, per_page=page_size, error_out=False)
)
return ok(
{
"items": [_serialize_post(item) for item in pagination.items],
"total": pagination.total,
"page": page,
"page_size": page_size,
}
)

View File

@@ -0,0 +1,121 @@
from datetime import datetime
from flask import Blueprint, request
from flask_jwt_extended import jwt_required
from app.extensions import db
from app.models import DietStatus
from app.utils.auth import current_user
from app.utils.response import fail, ok
diet_bp = Blueprint("diet", __name__)
def _parse_status_payload(payload: dict) -> dict:
return {
"weight": float(payload.get("weight", 0)),
"body_fat": float(payload.get("body_fat", 0)),
"exercise_kcal": float(payload.get("exercise_kcal", 0)),
"intake_kcal": float(payload.get("intake_kcal", 0)),
"sleep_hours": float(payload.get("sleep_hours", 7)),
"note": (payload.get("note") or "").strip(),
}
@diet_bp.post("/status")
@jwt_required()
def create_status():
user = current_user()
if not user:
return fail("用户不存在", 404)
payload = request.get_json(silent=True) or {}
data = _parse_status_payload(payload)
if data["weight"] <= 0 or data["body_fat"] <= 0:
return fail("体重和体脂率必须大于0", 400)
status = DietStatus(user_id=user.id, **data)
if payload.get("recorded_at"):
status.recorded_at = datetime.fromisoformat(payload["recorded_at"])
db.session.add(status)
db.session.commit()
return ok(status.to_dict(), "饮食状态记录成功")
@diet_bp.put("/status/latest")
@jwt_required()
def update_latest_status():
user = current_user()
if not user:
return fail("用户不存在", 404)
latest = (
DietStatus.query.filter_by(user_id=user.id)
.order_by(DietStatus.recorded_at.desc(), DietStatus.id.desc())
.first()
)
if not latest:
return fail("暂无可编辑的饮食状态", 404)
payload = request.get_json(silent=True) or {}
data = _parse_status_payload(payload)
for key, value in data.items():
setattr(latest, key, value)
db.session.commit()
return ok(latest.to_dict(), "最新饮食状态已更新")
@diet_bp.get("/status/latest")
@jwt_required()
def get_latest_status():
user = current_user()
if not user:
return fail("用户不存在", 404)
latest = (
DietStatus.query.filter_by(user_id=user.id)
.order_by(DietStatus.recorded_at.desc(), DietStatus.id.desc())
.first()
)
if not latest:
return ok({})
return ok(latest.to_dict())
@diet_bp.get("/history")
@jwt_required()
def history():
user = current_user()
if not user:
return fail("用户不存在", 404)
limit = min(max(int(request.args.get("limit", 30) or 30), 1), 200)
rows = (
DietStatus.query.filter_by(user_id=user.id)
.order_by(DietStatus.recorded_at.desc(), DietStatus.id.desc())
.limit(limit)
.all()
)
return ok([item.to_dict() for item in rows])
@diet_bp.delete("/history/<int:status_id>")
@jwt_required()
def delete_history(status_id: int):
user = current_user()
if not user:
return fail("用户不存在", 404)
row = DietStatus.query.filter_by(id=status_id, user_id=user.id).first()
if not row:
return fail("记录不存在", 404)
db.session.delete(row)
db.session.commit()
return ok({}, "记录已删除")

View File

@@ -0,0 +1,177 @@
from flask import Blueprint, current_app, request
from flask_jwt_extended import jwt_required
from app.rag.llm_client import LLMConfigManager, ask_dify, ask_fastgpt
from app.rag.local_retriever import LocalKnowledgeRetriever
from app.utils.auth import admin_required, current_user
from app.utils.response import ok
qa_bp = Blueprint("qa", __name__)
def _retriever() -> LocalKnowledgeRetriever:
return LocalKnowledgeRetriever(current_app.config["LOCAL_KB_PATH"])
def _llm_config() -> dict:
manager = LLMConfigManager(current_app.config["LLM_CONFIG_PATH"])
return manager.load()
def _default_provider(config: dict) -> str:
provider = (config.get("active_provider") or "fastgpt").strip().lower() if isinstance(config, dict) else "fastgpt"
return provider if provider in {"fastgpt", "dify"} else "fastgpt"
def _context_text(hits: list) -> str:
chunks = []
for idx, item in enumerate(hits, start=1):
chunks.append(f"[{idx}] 问: {item['question']}\n答: {item['answer']}")
return "\n\n".join(chunks)
def _local_answer(hits: list) -> str:
if not hits:
return "本地知识库暂未检索到高相关内容,建议补充问题细节。"
top = hits[0]
return f"根据本地知识库:{top['answer']}"
def _advice_template(advice_type: str) -> str:
templates = {
"gain_muscle": "增肌建议:每天蛋白质建议 1.6-2.2g/kg训练后 30 分钟内补充蛋白+碳水,优先鸡胸肉/牛肉/鸡蛋/燕麦。",
"lose_fat": "减脂建议:保持 10%-20% 轻热量缺口,优先高蛋白高纤维食物,减少高糖饮料和深夜加餐。",
"keto": "生酮建议:控制碳水通常小于 50g/天,优先健康脂肪和适量蛋白,并关注电解质补充。",
"nutritionist": "营养师建议:结合体重、体脂、活动量和目标制定周计划,每 7 天复盘体重和围度变化。",
"general": "通用建议:保证规律作息、三餐结构稳定、适量运动和充足饮水。",
}
return templates.get(advice_type, templates["general"])
def _ask_by_provider(provider: str, config: dict, question: str, context: str, user_id: str, username: str):
if provider == "dify":
return ask_dify(config, question, context=context, user=username)
return ask_fastgpt(config, question, context=context, custom_uid=user_id)
@qa_bp.get("/kb/search")
@jwt_required(optional=True)
def search_kb():
query = (request.args.get("query") or "").strip()
top_k = min(max(int(request.args.get("top_k", 5) or 5), 1), 10)
retriever = _retriever()
hits = retriever.search(query, top_k=top_k)
return ok({"items": hits})
@qa_bp.post("/ask")
@jwt_required()
def ask_nutrition_qa():
payload = request.get_json(silent=True) or {}
question = (payload.get("question") or "").strip()
if not question:
return ok({"answer": "请输入问题", "items": []}, "empty_question")
mode = (payload.get("mode") or "auto").strip().lower()
llm_cfg = _llm_config()
provider = (payload.get("provider") or _default_provider(llm_cfg)).strip().lower()
if provider not in {"fastgpt", "dify"}:
provider = _default_provider(llm_cfg)
retriever = _retriever()
hits = retriever.search(question, top_k=3)
context = _context_text(hits)
local_answer = _local_answer(hits)
if mode == "local":
return ok({"answer": local_answer, "provider": "local_kb", "items": hits})
user = current_user()
llm_result = _ask_by_provider(
provider,
llm_cfg,
question,
context,
str(user.id) if user else "",
user.username if user else "anonymous",
)
if mode == "llm":
if llm_result.get("ok"):
return ok({"answer": llm_result.get("answer"), "provider": provider, "items": hits})
return ok(
{
"answer": f"LLM 调用失败,已回退本地知识库。{local_answer}",
"provider": "local_kb_fallback",
"items": hits,
"llm_error": llm_result,
}
)
if llm_result.get("ok"):
return ok({"answer": llm_result.get("answer"), "provider": provider, "items": hits})
return ok(
{
"answer": local_answer,
"provider": "local_kb",
"items": hits,
"llm_error": llm_result,
}
)
@qa_bp.post("/advice")
@jwt_required()
def advice():
payload = request.get_json(silent=True) or {}
advice_type = (payload.get("advice_type") or "general").strip().lower()
question = (payload.get("question") or "请根据我的状态给建议").strip()
llm_cfg = _llm_config()
provider = (payload.get("provider") or _default_provider(llm_cfg)).strip().lower()
if provider not in {"fastgpt", "dify"}:
provider = _default_provider(llm_cfg)
retriever = _retriever()
hits = retriever.search(f"{advice_type} {question}", top_k=3)
context = _context_text(hits)
base_text = _advice_template(advice_type)
user = current_user()
llm_result = _ask_by_provider(
provider,
llm_cfg,
f"请给出{advice_type}方向的个性化饮食建议。用户补充: {question}",
f"{base_text}\n\n{context}",
str(user.id) if user else "",
user.username if user else "anonymous",
)
if llm_result.get("ok"):
answer = llm_result.get("answer")
source = provider
else:
answer = f"{base_text}\n\n提示LLM 暂不可用,当前为本地模板建议)"
source = "local_template"
return ok(
{
"advice_type": advice_type,
"answer": answer,
"provider": source,
"knowledge_hits": hits,
"llm_status": llm_result,
}
)
@qa_bp.post("/kb/reload")
@admin_required
def reload_kb():
retriever = _retriever()
retriever.reload()
return ok({"count": len(retriever.documents)}, "知识库已重载")

View File

@@ -0,0 +1,135 @@
import json
from pathlib import Path
from flask import Blueprint, current_app, request
from flask_jwt_extended import jwt_required
from app.extensions import db
from app.models import Recipe
from app.utils.auth import admin_required
from app.utils.response import fail, ok
recipe_bp = Blueprint("recipe", __name__)
def _recipe_payload(recipe: Recipe, payload: dict) -> None:
recipe.name = payload.get("name", recipe.name)
recipe.category = payload.get("category", recipe.category or "轻食")
recipe.description = payload.get("description", recipe.description or "")
recipe.calories = float(payload.get("calories", recipe.calories or 0))
recipe.protein = float(payload.get("protein", recipe.protein or 0))
recipe.fat = float(payload.get("fat", recipe.fat or 0))
recipe.carbs = float(payload.get("carbs", recipe.carbs or 0))
recipe.fiber = float(payload.get("fiber", recipe.fiber or 0))
recipe.tags = payload.get("tags", recipe.tags or [])
recipe.difficulty = payload.get("difficulty", recipe.difficulty or "easy")
@recipe_bp.get("")
@jwt_required(optional=True)
def list_recipes():
keyword = (request.args.get("keyword") or "").strip()
category = (request.args.get("category") or "").strip()
page = max(int(request.args.get("page", 1) or 1), 1)
page_size = min(max(int(request.args.get("page_size", 10) or 10), 1), 50)
query = Recipe.query
if keyword:
query = query.filter(Recipe.name.like(f"%{keyword}%"))
if category:
query = query.filter(Recipe.category == category)
pagination = query.order_by(Recipe.id.desc()).paginate(page=page, per_page=page_size, error_out=False)
data = {
"items": [item.to_dict() for item in pagination.items],
"total": pagination.total,
"page": page,
"page_size": page_size,
}
return ok(data)
@recipe_bp.get("/<int:recipe_id>")
@jwt_required(optional=True)
def get_recipe(recipe_id: int):
recipe = Recipe.query.get(recipe_id)
if not recipe:
return fail("食谱不存在", 404)
return ok(recipe.to_dict())
@recipe_bp.post("")
@admin_required
def create_recipe():
payload = request.get_json(silent=True) or {}
if not payload.get("name"):
return fail("缺少食谱名称", 400)
recipe = Recipe()
_recipe_payload(recipe, payload)
db.session.add(recipe)
db.session.commit()
return ok(recipe.to_dict(), "食谱创建成功")
@recipe_bp.put("/<int:recipe_id>")
@admin_required
def update_recipe(recipe_id: int):
recipe = Recipe.query.get(recipe_id)
if not recipe:
return fail("食谱不存在", 404)
payload = request.get_json(silent=True) or {}
_recipe_payload(recipe, payload)
db.session.commit()
return ok(recipe.to_dict(), "食谱更新成功")
@recipe_bp.delete("/<int:recipe_id>")
@admin_required
def delete_recipe(recipe_id: int):
recipe = Recipe.query.get(recipe_id)
if not recipe:
return fail("食谱不存在", 404)
db.session.delete(recipe)
db.session.commit()
return ok({}, "食谱删除成功")
@recipe_bp.post("/import")
@admin_required
def import_recipes():
payload = request.get_json(silent=True) or {}
rows = payload.get("items")
if not isinstance(rows, list):
seed_file = Path(current_app.root_path).parents[0] / "seed" / "recipes_seed.json"
if seed_file.exists():
with seed_file.open("r", encoding="utf-8-sig") as file:
rows = json.load(file)
else:
return fail("导入数据为空,且未找到默认种子文件", 400)
created_count = 0
updated_count = 0
for item in rows:
name = (item.get("name") or "").strip()
if not name:
continue
recipe = Recipe.query.filter_by(name=name).first()
if recipe:
_recipe_payload(recipe, item)
updated_count += 1
else:
recipe = Recipe()
_recipe_payload(recipe, item)
db.session.add(recipe)
created_count += 1
db.session.commit()
return ok({"created": created_count, "updated": updated_count}, "食谱导入完成")

View File

@@ -0,0 +1,314 @@
from statistics import mean
from flask import Blueprint, current_app, request
from flask_jwt_extended import jwt_required
from app.extensions import db
from app.ml.rf_recommender import RandomForestDietRecommender, merge_profile_with_history
from app.models import DietStatus, RecommendationLog, Recipe
from app.utils.auth import current_user
from app.utils.response import fail, ok
recommend_bp = Blueprint("recommend", __name__)
def _build_profile(user, latest_status):
profile = {
"age": user.age,
"height_cm": user.height_cm,
"goal": user.goal,
"occupation": user.occupation,
"weight": 65,
"body_fat": 20,
"exercise_kcal": 300,
"intake_kcal": 1800,
}
if latest_status:
profile.update(
{
"weight": latest_status.weight,
"body_fat": latest_status.body_fat,
"exercise_kcal": latest_status.exercise_kcal,
"intake_kcal": latest_status.intake_kcal,
}
)
return profile
def _goal_adjust(goal: str, recipe: Recipe) -> float:
if goal == "lose_fat":
return recipe.protein * 0.8 - recipe.fat * 0.4 - max(recipe.calories - 500, 0) * 0.04
if goal == "gain_muscle":
return recipe.protein * 1.0 + recipe.carbs * 0.25
if goal == "keto":
return recipe.fat * 0.5 - recipe.carbs * 0.8
return recipe.protein * 0.3 + recipe.fiber * 1.2
def _occupation_adjust(occupation: str, recipe: Recipe) -> float:
occupation = (occupation or "").lower()
if occupation in {"developer", "office"}:
return recipe.fiber * 0.9 + recipe.protein * 0.2
if occupation in {"fitness", "manual"}:
return recipe.protein * 0.6 + recipe.carbs * 0.2
if occupation in {"teacher", "student"}:
return recipe.carbs * 0.15 + recipe.protein * 0.3
return recipe.protein * 0.25
def _health_adjust(profile: dict, recipe: Recipe) -> float:
score = 70.0
body_fat = float(profile.get("body_fat", 20))
intake_kcal = float(profile.get("intake_kcal", 1800))
exercise_kcal = float(profile.get("exercise_kcal", 300))
if body_fat > 28:
score += recipe.protein * 0.5
score -= max(recipe.calories - 480, 0) * 0.05
elif body_fat < 14:
score += recipe.carbs * 0.15 + recipe.fat * 0.1
if intake_kcal > exercise_kcal + 1700:
score -= max(recipe.calories - 450, 0) * 0.04
else:
score += recipe.protein * 0.3
score += min(recipe.fiber * 1.0, 10)
return round(score, 2)
def _serialize_with_score(recipes):
return [
{
**item[0].to_dict(),
"score": round(float(item[1]), 2),
"reason": item[2],
}
for item in recipes
]
def _save_log(user_id: int, rec_type: str, payload: dict) -> None:
log = RecommendationLog(user_id=user_id, rec_type=rec_type, payload=payload)
db.session.add(log)
db.session.commit()
@recommend_bp.get("/current")
@jwt_required()
def recommend_current_status():
user = current_user()
if not user:
return fail("用户不存在", 404)
top_k = min(max(int(request.args.get("top_k", 5) or 5), 1), 20)
latest = (
DietStatus.query.filter_by(user_id=user.id)
.order_by(DietStatus.recorded_at.desc(), DietStatus.id.desc())
.first()
)
history = (
DietStatus.query.filter_by(user_id=user.id)
.order_by(DietStatus.recorded_at.desc(), DietStatus.id.desc())
.limit(14)
.all()
)
recipes = Recipe.query.order_by(Recipe.id.asc()).all()
if not recipes:
return fail("当前没有可推荐食谱,请先导入食谱", 400)
profile = _build_profile(user, latest)
profile = merge_profile_with_history(profile, history)
recommender = RandomForestDietRecommender(current_app.config["MODEL_PATH"])
items = recommender.recommend(profile, recipes, top_k=top_k)
response_data = {
"type": "current_status_rf",
"profile": profile,
"items": items,
}
_save_log(user.id, "current_status_rf", response_data)
return ok(response_data)
@recommend_bp.get("/health")
@jwt_required()
def recommend_by_health():
user = current_user()
if not user:
return fail("用户不存在", 404)
top_k = min(max(int(request.args.get("top_k", 5) or 5), 1), 20)
latest = (
DietStatus.query.filter_by(user_id=user.id)
.order_by(DietStatus.recorded_at.desc(), DietStatus.id.desc())
.first()
)
profile = _build_profile(user, latest)
recipes = Recipe.query.order_by(Recipe.id.asc()).all()
if not recipes:
return fail("当前没有可推荐食谱,请先导入食谱", 400)
scored = []
for recipe in recipes:
score = _health_adjust(profile, recipe)
reason = "结合当前体脂和摄入消耗状态,推荐更合理的营养结构。"
scored.append((recipe, score, reason))
scored.sort(key=lambda row: row[1], reverse=True)
items = _serialize_with_score(scored[:top_k])
response_data = {
"type": "health_state",
"profile": profile,
"items": items,
}
_save_log(user.id, "health_state", response_data)
return ok(response_data)
@recommend_bp.get("/plan/history")
@jwt_required()
def recommend_plan_by_history():
user = current_user()
if not user:
return fail("用户不存在", 404)
top_k = min(max(int(request.args.get("top_k", 5) or 5), 1), 20)
history = (
DietStatus.query.filter_by(user_id=user.id)
.order_by(DietStatus.recorded_at.desc(), DietStatus.id.desc())
.limit(30)
.all()
)
if len(history) < 3:
return fail("历史数据太少至少记录3天后可用该功能", 400)
avg_weight = mean([h.weight for h in history])
early_weight = history[-1].weight
latest_weight = history[0].weight
weight_trend = latest_weight - early_weight
avg_intake = mean([h.intake_kcal for h in history])
avg_exercise = mean([h.exercise_kcal for h in history])
cal_limit = 520
strategy = "维持计划"
if weight_trend > 0.8 or avg_intake > avg_exercise + 1700:
cal_limit = 440
strategy = "减脂优先计划"
elif weight_trend < -0.8:
cal_limit = 620
strategy = "增肌恢复计划"
recipes = Recipe.query.filter(Recipe.calories <= cal_limit).order_by(Recipe.protein.desc()).limit(top_k).all()
items = []
for recipe in recipes:
items.append(
{
**recipe.to_dict(),
"score": round(recipe.protein * 2 - recipe.fat * 0.3, 2),
"reason": f"基于历史趋势({strategy}),优先推荐该热量区间食谱。",
}
)
response_data = {
"type": "history_plan",
"strategy": strategy,
"metrics": {
"avg_weight": round(avg_weight, 2),
"weight_trend": round(weight_trend, 2),
"avg_intake": round(avg_intake, 1),
"avg_exercise": round(avg_exercise, 1),
},
"items": items,
}
_save_log(user.id, "history_plan", response_data)
return ok(response_data)
@recommend_bp.get("/plan/goal")
@jwt_required()
def recommend_plan_by_goal():
user = current_user()
if not user:
return fail("用户不存在", 404)
top_k = min(max(int(request.args.get("top_k", 5) or 5), 1), 20)
goal = request.args.get("goal") or user.goal
recipes = Recipe.query.order_by(Recipe.id.asc()).all()
if not recipes:
return fail("当前没有可推荐食谱,请先导入食谱", 400)
scored = []
for recipe in recipes:
score = 60 + _goal_adjust(goal, recipe)
reason = f"根据目标({goal})筛选营养比例更匹配的食谱。"
scored.append((recipe, score, reason))
scored.sort(key=lambda x: x[1], reverse=True)
items = _serialize_with_score(scored[:top_k])
response_data = {
"type": "future_goal_plan",
"goal": goal,
"items": items,
}
_save_log(user.id, "future_goal_plan", response_data)
return ok(response_data)
@recommend_bp.get("/plan/occupation")
@jwt_required()
def recommend_plan_by_occupation():
user = current_user()
if not user:
return fail("用户不存在", 404)
top_k = min(max(int(request.args.get("top_k", 5) or 5), 1), 20)
occupation = request.args.get("occupation") or user.occupation
recipes = Recipe.query.order_by(Recipe.id.asc()).all()
if not recipes:
return fail("当前没有可推荐食谱,请先导入食谱", 400)
scored = []
for recipe in recipes:
score = 65 + _occupation_adjust(occupation, recipe)
reason = f"结合职业({occupation})的能量消耗特点进行推荐。"
scored.append((recipe, score, reason))
scored.sort(key=lambda x: x[1], reverse=True)
items = _serialize_with_score(scored[:top_k])
response_data = {
"type": "occupation_plan",
"occupation": occupation,
"items": items,
}
_save_log(user.id, "occupation_plan", response_data)
return ok(response_data)
@recommend_bp.get("/logs")
@jwt_required()
def my_recommend_logs():
user = current_user()
if not user:
return fail("用户不存在", 404)
limit = min(max(int(request.args.get("limit", 20) or 20), 1), 100)
rows = (
RecommendationLog.query.filter_by(user_id=user.id)
.order_by(RecommendationLog.id.desc())
.limit(limit)
.all()
)
return ok([item.to_dict() for item in rows])

View File

@@ -0,0 +1,334 @@
from flask import Blueprint, current_app, request
from flask_jwt_extended import jwt_required
from app.extensions import db
from app.ml.naive_bayes_classifier import NaiveBayesSpamClassifier
from app.models import DetectionConfig, SpamPredictionLog, SpamTrainingSample
from app.utils.auth import admin_required, current_user
from app.utils.response import fail, ok
spam_bp = Blueprint("spam", __name__)
def _classifier() -> NaiveBayesSpamClassifier:
return NaiveBayesSpamClassifier(current_app.config["NB_MODEL_PATH"])
def _active_samples() -> list[dict]:
rows = SpamTrainingSample.query.filter_by(is_active=True).order_by(SpamTrainingSample.id.asc()).all()
return [{"text": row.text, "label": row.label} for row in rows]
def _ensure_ready() -> NaiveBayesSpamClassifier:
clf = _classifier()
samples = _active_samples()
clf.ensure_ready(samples)
return clf
def _threshold() -> float:
row = DetectionConfig.query.order_by(DetectionConfig.id.asc()).first()
return float(row.spam_threshold) if row else 0.75
@spam_bp.post("/predict")
@jwt_required()
def predict_one():
user = current_user()
if not user:
return fail("用户不存在", 404)
payload = request.get_json(silent=True) or {}
text = (payload.get("text") or "").strip()
if len(text) < 2:
return fail("请输入至少2个字符的待识别文本", 400)
clf = _ensure_ready()
result = clf.predict(text)
threshold = _threshold()
blocked = float(result["spam_probability"]) >= threshold
row = SpamPredictionLog(
user_id=user.id,
text=result["text"],
prediction=result["prediction"],
spam_probability=result["spam_probability"],
ham_probability=result["ham_probability"],
confidence=result["confidence"],
reason_tokens=result["reason_tokens"],
model_version=result.get("model_version", ""),
)
db.session.add(row)
db.session.commit()
return ok({**result, "log_id": row.id, "threshold": threshold, "blocked_by_threshold": blocked}, "识别成功")
@spam_bp.post("/predict/batch")
@jwt_required()
def predict_batch():
user = current_user()
if not user:
return fail("用户不存在", 404)
payload = request.get_json(silent=True) or {}
items = payload.get("items") or []
if not isinstance(items, list) or not items:
return fail("items 必须是非空数组", 400)
if len(items) > 100:
return fail("单次最多识别100条", 400)
clf = _ensure_ready()
rows = []
results = []
threshold = _threshold()
for text in items:
content = (text or "").strip()
if len(content) < 2:
continue
result = clf.predict(content)
result["blocked_by_threshold"] = float(result["spam_probability"]) >= threshold
rows.append(
SpamPredictionLog(
user_id=user.id,
text=result["text"],
prediction=result["prediction"],
spam_probability=result["spam_probability"],
ham_probability=result["ham_probability"],
confidence=result["confidence"],
reason_tokens=result["reason_tokens"],
model_version=result.get("model_version", ""),
)
)
results.append(result)
if not rows:
return fail("没有可识别的有效文本", 400)
db.session.add_all(rows)
db.session.commit()
spam_count = len([item for item in results if item["prediction"] == "spam"])
blocked_count = len([item for item in results if item["blocked_by_threshold"]])
ham_count = len(results) - spam_count
return ok(
{
"items": results,
"summary": {
"total": len(results),
"spam_count": spam_count,
"ham_count": ham_count,
"blocked_count": blocked_count,
"spam_ratio": round(spam_count / len(results), 4) if results else 0,
"blocked_ratio": round(blocked_count / len(results), 4) if results else 0,
"threshold": threshold,
},
},
"批量识别完成",
)
@spam_bp.get("/history")
@jwt_required()
def my_history():
user = current_user()
if not user:
return fail("用户不存在", 404)
page = max(int(request.args.get("page", 1) or 1), 1)
page_size = min(max(int(request.args.get("page_size", 20) or 20), 1), 100)
pagination = (
SpamPredictionLog.query.filter_by(user_id=user.id)
.order_by(SpamPredictionLog.id.desc())
.paginate(page=page, per_page=page_size, error_out=False)
)
return ok(
{
"items": [item.to_dict() for item in pagination.items],
"total": pagination.total,
"page": page,
"page_size": page_size,
}
)
@spam_bp.delete("/history/<int:log_id>")
@jwt_required()
def delete_history(log_id: int):
user = current_user()
if not user:
return fail("用户不存在", 404)
row = SpamPredictionLog.query.filter_by(id=log_id, user_id=user.id).first()
if not row:
return fail("记录不存在", 404)
db.session.delete(row)
db.session.commit()
return ok({}, "记录已删除")
@spam_bp.post("/feedback")
@jwt_required()
def save_feedback():
user = current_user()
if not user:
return fail("用户不存在", 404)
payload = request.get_json(silent=True) or {}
text = (payload.get("text") or "").strip()
label = NaiveBayesSpamClassifier.normalize_label(payload.get("label"))
if len(text) < 2:
return fail("文本至少2个字符", 400)
if not label:
return fail("label 必须是 spam 或 ham", 400)
row = SpamTrainingSample(text=text, label=label, source="feedback", created_by=user.id, is_active=True)
db.session.add(row)
db.session.commit()
return ok(row.to_dict(), "反馈样本已记录")
@spam_bp.get("/model/info")
@jwt_required(optional=True)
def model_info():
clf = _classifier()
clf.load()
info = clf.model_info()
info["threshold"] = _threshold()
return ok(info)
@spam_bp.post("/train")
@admin_required
def train_model():
clf = _classifier()
samples = _active_samples()
metadata = clf.train(samples)
return ok(metadata, "模型训练完成")
@spam_bp.get("/samples")
@admin_required
def list_samples():
keyword = (request.args.get("keyword") or "").strip()
label = (request.args.get("label") or "").strip().lower()
page = max(int(request.args.get("page", 1) or 1), 1)
page_size = min(max(int(request.args.get("page_size", 20) or 20), 1), 100)
query = SpamTrainingSample.query
if keyword:
query = query.filter(SpamTrainingSample.text.like(f"%{keyword}%"))
if label in {"spam", "ham"}:
query = query.filter(SpamTrainingSample.label == label)
pagination = query.order_by(SpamTrainingSample.id.desc()).paginate(page=page, per_page=page_size, error_out=False)
return ok(
{
"items": [item.to_dict() for item in pagination.items],
"total": pagination.total,
"page": page,
"page_size": page_size,
}
)
@spam_bp.post("/samples")
@admin_required
def create_sample():
payload = request.get_json(silent=True) or {}
text = (payload.get("text") or "").strip()
label = NaiveBayesSpamClassifier.normalize_label(payload.get("label"))
if len(text) < 2:
return fail("文本至少2个字符", 400)
if not label:
return fail("label 必须是 spam 或 ham", 400)
user = current_user()
row = SpamTrainingSample(text=text, label=label, source="import", created_by=user.id if user else None, is_active=True)
db.session.add(row)
db.session.commit()
return ok(row.to_dict(), "样本创建成功")
@spam_bp.put("/samples/<int:sample_id>")
@admin_required
def update_sample(sample_id: int):
row = SpamTrainingSample.query.get(sample_id)
if not row:
return fail("样本不存在", 404)
payload = request.get_json(silent=True) or {}
if "text" in payload:
text = (payload.get("text") or "").strip()
if len(text) < 2:
return fail("文本至少2个字符", 400)
row.text = text
if "label" in payload:
label = NaiveBayesSpamClassifier.normalize_label(payload.get("label"))
if not label:
return fail("label 必须是 spam 或 ham", 400)
row.label = label
if "is_active" in payload:
row.is_active = bool(payload.get("is_active"))
db.session.commit()
return ok(row.to_dict(), "样本更新成功")
@spam_bp.delete("/samples/<int:sample_id>")
@admin_required
def delete_sample(sample_id: int):
row = SpamTrainingSample.query.get(sample_id)
if not row:
return fail("样本不存在", 404)
db.session.delete(row)
db.session.commit()
return ok({}, "样本已删除")
@spam_bp.post("/samples/import")
@admin_required
def import_samples():
payload = request.get_json(silent=True) or {}
items = payload.get("items") or []
if not isinstance(items, list) or not items:
return fail("items 必须是非空数组", 400)
user = current_user()
created = 0
updated = 0
for item in items:
text = (item.get("text") or "").strip()
label = NaiveBayesSpamClassifier.normalize_label(item.get("label"))
if len(text) < 2 or not label:
continue
row = SpamTrainingSample.query.filter_by(text=text).first()
if row:
row.label = label
row.is_active = bool(item.get("is_active", True))
row.source = item.get("source") or row.source
updated += 1
else:
row = SpamTrainingSample(
text=text,
label=label,
source=item.get("source") or "import",
created_by=user.id if user else None,
is_active=bool(item.get("is_active", True)),
)
db.session.add(row)
created += 1
db.session.commit()
return ok({"created": created, "updated": updated}, "样本导入完成")

View File

@@ -0,0 +1,62 @@
from flask import Blueprint, request
from flask_jwt_extended import jwt_required
from app.extensions import db
from app.models import User
from app.utils.auth import current_user
from app.utils.response import fail, ok
user_bp = Blueprint("user", __name__)
@user_bp.get("/profile")
@jwt_required()
def get_profile():
user = current_user()
if not user:
return fail("用户不存在", 404)
return ok(user.to_dict())
@user_bp.put("/profile")
@jwt_required()
def update_profile():
user = current_user()
if not user:
return fail("用户不存在", 404)
payload = request.get_json(silent=True) or {}
if "nickname" in payload:
user.nickname = (payload.get("nickname") or user.nickname).strip() or user.nickname
if "company" in payload:
user.company = (payload.get("company") or "").strip()
if "title" in payload:
user.title = (payload.get("title") or "").strip()
if "phone" in payload:
user.phone = (payload.get("phone") or "").strip()
new_password = payload.get("password")
if new_password:
if len(new_password) < 6:
return fail("新密码至少6位", 400)
user.set_password(new_password)
db.session.commit()
return ok(user.to_dict(), "个人信息更新成功")
@user_bp.get("/search")
@jwt_required()
def search_users():
keyword = (request.args.get("keyword") or "").strip()
if not keyword:
return ok([])
users = (
User.query.filter(User.username.like(f"%{keyword}%") | User.nickname.like(f"%{keyword}%"))
.order_by(User.id.desc())
.limit(20)
.all()
)
return ok([item.to_dict() for item in users])

View File

@@ -0,0 +1 @@

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

25
backend/app/utils/auth.py Normal file
View File

@@ -0,0 +1,25 @@
from functools import wraps
from flask import jsonify
from flask_jwt_extended import get_jwt, get_jwt_identity, verify_jwt_in_request
from app.models import User
def current_user() -> User | None:
identity = get_jwt_identity()
if not identity:
return None
return User.query.get(int(identity))
def admin_required(fn):
@wraps(fn)
def wrapper(*args, **kwargs):
verify_jwt_in_request()
claims = get_jwt()
if not claims.get("is_admin", False):
return jsonify({"code": 403, "message": "需要管理员权限"}), 403
return fn(*args, **kwargs)
return wrapper

View File

@@ -0,0 +1,9 @@
from flask import jsonify
def ok(data=None, message="success"):
return jsonify({"code": 0, "message": message, "data": data or {}})
def fail(message="error", code=400, data=None):
return jsonify({"code": code, "message": message, "data": data or {}}), code