This commit is contained in:
刘正航
2026-04-21 22:45:19 +08:00
commit b5237f9038
159 changed files with 7769 additions and 0 deletions

Binary file not shown.

Binary file not shown.

Binary file not shown.

37
backend/app/__init__.py Normal file
View File

@@ -0,0 +1,37 @@
from flask import Flask
from app.config import Config
from app.extensions import cors, db, jwt
from app.routes.admin_routes import admin_bp
from app.routes.auth_routes import auth_bp
from app.routes.content_routes import content_bp
from app.routes.spam_routes import spam_bp
from app.routes.user_routes import user_bp
def create_app() -> Flask:
    """Application factory: build and wire up the Flask backend.

    Loads settings from ``Config``, binds the shared extensions (database,
    JWT, CORS), mounts each feature blueprint under its ``/api/...`` prefix
    and registers a small health-check endpoint.

    Returns:
        The configured ``Flask`` application.
    """
    flask_app = Flask(__name__)
    flask_app.config.from_object(Config)

    # Bind the module-level extension singletons to this application.
    db.init_app(flask_app)
    jwt.init_app(flask_app)
    cors.init_app(flask_app, supports_credentials=True)

    # Blueprint -> mount point; registration follows the listing order.
    mounts = (
        (auth_bp, "/api/auth"),
        (user_bp, "/api/user"),
        (spam_bp, "/api/spam"),
        (content_bp, "/api/content"),
        (admin_bp, "/api/admin"),
    )
    for blueprint, prefix in mounts:
        flask_app.register_blueprint(blueprint, url_prefix=prefix)

    @flask_app.get("/api/health")
    def health_check():
        """Report the service name and version in the standard envelope."""
        service_info = {
            "service": "spam-detect-backend",
            "version": "2.1.0",
        }
        return {"code": 0, "message": "ok", "data": service_info}

    return flask_app

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

37
backend/app/config.py Normal file
View File

@@ -0,0 +1,37 @@
import json
import os
from pathlib import Path
from urllib.parse import quote_plus
BASE_DIR = Path(__file__).resolve().parents[1]
MYSQL_CONFIG_PATH = BASE_DIR / "mysqlconfig.json"
def load_mysql_config() -> dict:
    """Read the MySQL connection settings from ``MYSQL_CONFIG_PATH``.

    The file is decoded as ``utf-8-sig`` so a leading byte-order mark is
    tolerated.

    Returns:
        The parsed JSON content of the settings file.

    Raises:
        FileNotFoundError: If the settings file does not exist.
    """
    if MYSQL_CONFIG_PATH.exists():
        with MYSQL_CONFIG_PATH.open("r", encoding="utf-8-sig") as handle:
            return json.load(handle)
    raise FileNotFoundError(f"未找到 MySQL 配置文件: {MYSQL_CONFIG_PATH}")
def build_mysql_uri(mysql_cfg: dict) -> str:
    """Build a ``mysql+pymysql`` SQLAlchemy connection URI from a settings dict.

    Both the user name and the password are URL-escaped, so characters such
    as ``@``, ``:`` or ``/`` in the credentials cannot corrupt the URI.
    Credentials are converted with ``str`` first, so a numeric password
    coming from JSON is accepted.

    Args:
        mysql_cfg: Settings with required keys ``user``, ``password`` and
            ``database``; optional keys ``host`` (default ``127.0.0.1``),
            ``port`` (default ``3306``) and ``charset`` (default ``utf8mb4``).

    Returns:
        The connection URI string.

    Raises:
        KeyError: If a required key is missing.
    """
    user = quote_plus(str(mysql_cfg["user"]))
    password = quote_plus(str(mysql_cfg["password"]))
    host = mysql_cfg.get("host", "127.0.0.1")
    port = mysql_cfg.get("port", 3306)
    database = mysql_cfg["database"]
    charset = mysql_cfg.get("charset", "utf8mb4")
    return f"mysql+pymysql://{user}:{password}@{host}:{port}/{database}?charset={charset}"
class Config:
    """Settings object loaded through ``app.config.from_object``.

    Every attribute is evaluated when the class body runs, which includes
    reading the MySQL settings file.
    """

    # Database connection, derived from the JSON settings file.
    MYSQL_CONFIG = load_mysql_config()
    SQLALCHEMY_DATABASE_URI = build_mysql_uri(MYSQL_CONFIG)
    SQLALCHEMY_TRACK_MODIFICATIONS = False

    # Secrets are read from the environment; the literals are fallbacks only.
    JWT_SECRET_KEY = os.environ.get("JWT_SECRET_KEY", "replace-this-jwt-secret-key-in-production")
    SECRET_KEY = os.environ.get("FLASK_SECRET_KEY", "replace-this-flask-secret-key-in-production")

    # Locations of the seed dataset and the persisted naive Bayes model.
    SPAM_DATASET_PATH = str(BASE_DIR.joinpath("seed", "spam_samples_seed.json"))
    NB_MODEL_PATH = str(BASE_DIR.joinpath("models", "spam_nb_model.joblib"))

View File

@@ -0,0 +1,8 @@
from flask_cors import CORS
from flask_jwt_extended import JWTManager
from flask_sqlalchemy import SQLAlchemy
# Shared extension singletons, created without an application; each one is
# bound to a Flask app later through its init_app() method.
db = SQLAlchemy()
jwt = JWTManager()
cors = CORS()

View File

@@ -0,0 +1 @@

Binary file not shown.

Binary file not shown.

View File

@@ -0,0 +1,156 @@
import hashlib
from collections import Counter
from datetime import datetime
from pathlib import Path
import joblib
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.naive_bayes import MultinomialNB
# The only labels accepted for training samples.
VALID_LABELS = {"spam", "ham"}


class NaiveBayesSpamClassifier:
    """Spam/ham text classifier: character n-gram TF-IDF + multinomial naive Bayes.

    The fitted vectorizer, model and a metadata dict are persisted together
    in a single joblib file at ``model_path``.
    """

    def __init__(self, model_path: str):
        """Remember where the model file lives; nothing is loaded yet."""
        self.model_path = Path(model_path)
        self.vectorizer = None
        self.model = None
        self.metadata = {}

    @staticmethod
    def normalize_label(label: str) -> str:
        """Return the trimmed, lower-cased label, or ``""`` if it is not valid."""
        value = (label or "").strip().lower()
        return value if value in VALID_LABELS else ""

    @staticmethod
    def normalize_text(text: str) -> str:
        """Trim the text and collapse every whitespace run to a single space."""
        return " ".join((text or "").strip().split())

    @staticmethod
    def _class_probability(probs, classes: list, label: str) -> float:
        """Return the probability of ``label``, or 0.0 if the model never saw it."""
        if label not in classes:
            return 0.0
        return float(probs[classes.index(label)])

    @staticmethod
    def _to_metadata(samples: list[dict], version_seed: str) -> dict:
        """Describe a training run: timestamp, sample count, label counts, version.

        The version is derived from an MD5 digest of ``version_seed``; the
        digest is used as a fingerprint only.
        """
        dist = Counter(item["label"] for item in samples)
        digest = hashlib.md5(version_seed.encode("utf-8")).hexdigest()[:16]
        return {
            "trained_at": datetime.utcnow().isoformat(),
            "sample_count": len(samples),
            "label_distribution": dict(dist),
            "version": f"nb-{digest}",
        }

    def train(self, samples: list[dict]) -> dict:
        """Fit the vectorizer and model on ``samples`` and persist them.

        Args:
            samples: Dicts with ``text`` and ``label`` keys. Rows whose text
                is empty or whose label is not spam/ham are dropped.

        Returns:
            The metadata dict of the new model.

        Raises:
            ValueError: If fewer than 10 valid samples remain.
        """
        clean_samples = []
        for row in samples:
            text = self.normalize_text(row.get("text"))
            label = self.normalize_label(row.get("label"))
            if not text or not label:
                continue
            clean_samples.append({"text": text, "label": label})
        if len(clean_samples) < 10:
            raise ValueError("训练样本太少至少需要10条有效样本")

        texts = [item["text"] for item in clean_samples]
        labels = [item["label"] for item in clean_samples]
        vectorizer = TfidfVectorizer(analyzer="char", ngram_range=(1, 2), min_df=1)
        x = vectorizer.fit_transform(texts)
        model = MultinomialNB(alpha=0.4)
        model.fit(x, labels)

        # The version fingerprint depends on the exact cleaned sample list.
        version_seed = "||".join(f"{item['label']}::{item['text']}" for item in clean_samples)
        metadata = self._to_metadata(clean_samples, version_seed)

        self.model_path.parent.mkdir(parents=True, exist_ok=True)
        joblib.dump({"vectorizer": vectorizer, "model": model, "metadata": metadata}, self.model_path)

        self.vectorizer = vectorizer
        self.model = model
        self.metadata = metadata
        return metadata

    def load(self) -> bool:
        """Load a persisted model if the file exists.

        Returns:
            True when both a vectorizer and a model were loaded, else False.
        """
        if not self.model_path.exists():
            return False
        # NOTE(review): joblib.load unpickles the file, so the model path must
        # only ever point at files this application wrote itself.
        payload = joblib.load(self.model_path)
        self.vectorizer = payload.get("vectorizer")
        self.model = payload.get("model")
        self.metadata = payload.get("metadata", {})
        return self.vectorizer is not None and self.model is not None

    def ensure_ready(self, samples: list[dict]) -> dict:
        """Load the persisted model, or train one from ``samples`` if loading fails."""
        if self.load():
            return self.metadata
        return self.train(samples)

    def predict(self, text: str) -> dict:
        """Classify ``text`` as spam or ham.

        A class the model was never trained on gets probability 0.0, so a
        model fitted on a single label reports that label instead of mixing
        up the two probabilities.

        Returns:
            A dict with the cleaned text, the prediction, both class
            probabilities, the confidence, explanatory tokens and the model
            version / training time.

        Raises:
            RuntimeError: If no model has been loaded or trained.
            ValueError: If the cleaned text is shorter than 2 characters.
        """
        if self.vectorizer is None or self.model is None:
            raise RuntimeError("模型未加载,请先训练")
        cleaned = self.normalize_text(text)
        if len(cleaned) < 2:
            raise ValueError("待识别文本至少2个字符")

        x = self.vectorizer.transform([cleaned])
        probs = self.model.predict_proba(x)[0]
        classes = list(self.model.classes_)
        spam_prob = self._class_probability(probs, classes, "spam")
        ham_prob = self._class_probability(probs, classes, "ham")

        # Ties are resolved in favour of "spam".
        prediction = "spam" if spam_prob >= ham_prob else "ham"
        reason_tokens = self._extract_reason_tokens(cleaned, classes, x)
        confidence = max(spam_prob, ham_prob)
        return {
            "text": cleaned,
            "prediction": prediction,
            "prediction_text": "垃圾信息" if prediction == "spam" else "正常信息",
            "spam_probability": round(spam_prob, 4),
            "ham_probability": round(ham_prob, 4),
            "confidence": round(confidence, 4),
            "reason_tokens": reason_tokens,
            "model_version": self.metadata.get("version", ""),
            "trained_at": self.metadata.get("trained_at"),
        }

    def _extract_reason_tokens(self, text: str, classes: list[str], x_row) -> list[str]:
        """Return up to five n-grams of the input that best separate spam from ham.

        Tokens are ranked by the absolute difference between their spam and
        ham log-probabilities. This is a best-effort explanation: on any
        failure the first five characters of ``text`` are returned instead.
        """
        try:
            feature_names = self.vectorizer.get_feature_names_out()
            class_log_prob = self.model.feature_log_prob_
            present = list(x_row.nonzero()[1])
            if "spam" not in classes or "ham" not in classes:
                # Single-class model: there is no contrast to rank by, so the
                # tokens are returned in feature order.
                return [feature_names[idx] for idx in present[:5]]

            spam_row = class_log_prob[classes.index("spam")]
            ham_row = class_log_prob[classes.index("ham")]
            scored = [(feature_names[idx], spam_row[idx] - ham_row[idx]) for idx in present]
            scored.sort(key=lambda row: abs(row[1]), reverse=True)
            return [token for token, _ in scored[:5]]
        except Exception:
            # Deliberate fallback: an explanation failure must not break a prediction.
            return list(text[:5])

    def model_info(self) -> dict:
        """Summarize readiness, file path and training metadata of the model."""
        return {
            "ready": self.vectorizer is not None and self.model is not None,
            "model_path": str(self.model_path),
            "version": self.metadata.get("version", ""),
            "trained_at": self.metadata.get("trained_at"),
            "sample_count": int(self.metadata.get("sample_count", 0) or 0),
            "label_distribution": self.metadata.get("label_distribution", {}),
        }

View File

@@ -0,0 +1,252 @@
import hashlib
import json
from datetime import datetime
from pathlib import Path
import joblib
import numpy as np
from sklearn.ensemble import RandomForestRegressor
# Numeric encoding of the dietary goal used as a model feature.
GOAL_MAP = {
    "maintain": 0,
    "lose_fat": 1,
    "gain_muscle": 2,
    "keto": 3,
}

# Numeric encoding of the occupation; "通用" is the fallback for unknown values.
OCCUPATION_MAP = {
    "通用": 0,
    "student": 1,
    "office": 2,
    "teacher": 3,
    "developer": 4,
    "healthcare": 5,
    "fitness": 6,
    "manual": 7,
}


class RandomForestDietRecommender:
    """Recipe recommender backed by a random-forest regressor.

    The forest is trained on synthetic user profiles whose target scores
    come from ``_heuristic_score``. The fitted model is cached on disk
    together with a signature of the recipe list and retrained whenever
    that signature changes.
    """

    def __init__(self, model_path: str):
        """Remember the model file location; nothing is loaded yet."""
        self.model_path = Path(model_path)
        self.model = None
        self.recipe_signature = None

    @staticmethod
    def _encode_goal(goal: str) -> int:
        """Map a goal name to its numeric code (unknown or empty -> 0)."""
        return GOAL_MAP.get(goal or "maintain", 0)

    @staticmethod
    def _encode_occupation(occupation: str) -> int:
        """Map an occupation to its numeric code (unknown or empty -> "通用")."""
        return OCCUPATION_MAP.get(occupation or "通用", OCCUPATION_MAP["通用"])

    def _signature(self, recipes: list) -> str:
        """Return an MD5 fingerprint of the recipes' identity and nutrition fields."""
        raw = [
            {
                "id": item.id,
                "name": item.name,
                "calories": item.calories,
                "protein": item.protein,
                "fat": item.fat,
                "carbs": item.carbs,
                "fiber": item.fiber,
                "updated": item.updated_at.isoformat() if item.updated_at else "",
            }
            for item in recipes
        ]
        raw_json = json.dumps(raw, ensure_ascii=False, sort_keys=True)
        return hashlib.md5(raw_json.encode("utf-8")).hexdigest()

    @staticmethod
    def _daily_target_kcal(profile: dict) -> float:
        """Estimate the daily calorie target for a profile (never below 1200)."""
        goal = profile.get("goal", "maintain")
        baseline = 1800 + float(profile.get("exercise_kcal", 0)) * 0.4
        if goal == "lose_fat":
            baseline *= 0.82
        elif goal == "gain_muscle":
            baseline *= 1.12
        elif goal == "keto":
            baseline *= 0.9
        return max(baseline, 1200)

    @staticmethod
    def _heuristic_score(profile: dict, recipe) -> float:
        """Score how well ``recipe`` suits ``profile`` on a 1-100 scale.

        The score starts at 100, loses points for deviating from one third
        of the daily calorie target, and is then adjusted by goal-specific
        macro-nutrient rules. A missing ``recipe.fiber`` is treated as 0,
        the same way ``_build_feature`` treats it.
        """
        goal = profile.get("goal", "maintain")
        daily_target = RandomForestDietRecommender._daily_target_kcal(profile)
        target_per_meal = daily_target / 3
        cal_gap_ratio = abs(recipe.calories - target_per_meal) / max(target_per_meal, 1)
        protein_ratio = recipe.protein / max(recipe.calories, 1)
        carbs_ratio = recipe.carbs / max(recipe.calories, 1)
        fat_ratio = recipe.fat / max(recipe.calories, 1)

        score = 100.0
        score -= min(cal_gap_ratio * 55, 50)
        if goal == "lose_fat":
            score += min(recipe.protein * 0.6, 18)
            score -= max((recipe.fat - 20) * 0.7, 0)
            score -= max((recipe.carbs - 55) * 0.3, 0)
        elif goal == "gain_muscle":
            score += min(recipe.protein * 0.8, 26)
            score += min(recipe.carbs * 0.2, 10)
        elif goal == "keto":
            score += min(recipe.fat * 0.4, 18)
            score -= max(recipe.carbs - 30, 0) * 0.8
        else:
            # Fiber may be unset on a recipe; count it as 0 instead of crashing.
            score += min((recipe.fiber or 0) * 1.2, 8)

        body_fat = float(profile.get("body_fat", 20))
        if body_fat > 28:
            score -= max(recipe.calories - 520, 0) * 0.03
        intake_kcal = float(profile.get("intake_kcal", 1800))
        if intake_kcal > daily_target:
            score -= max(recipe.calories - 460, 0) * 0.02

        score += np.clip((protein_ratio - 0.12) * 100, -8, 8)
        score += np.clip((0.08 - carbs_ratio) * 80 if goal == "keto" else 0, -6, 6)
        score += np.clip((0.25 - fat_ratio) * 30 if goal == "lose_fat" else 0, -5, 5)
        return float(np.clip(score, 1, 100))

    def _build_feature(self, profile: dict, recipe) -> list:
        """Build the 13-value feature row for one (profile, recipe) pair."""
        return [
            float(profile.get("weight", 65)),
            float(profile.get("body_fat", 20)),
            float(profile.get("exercise_kcal", 300)),
            float(profile.get("intake_kcal", 1800)),
            float(profile.get("age", 25)),
            float(profile.get("height_cm", 170)),
            float(self._encode_goal(profile.get("goal", "maintain"))),
            float(self._encode_occupation(profile.get("occupation", "通用"))),
            float(recipe.calories),
            float(recipe.protein),
            float(recipe.fat),
            float(recipe.carbs),
            float(recipe.fiber or 0),
        ]

    def _sample_profiles(self, n: int = 600) -> list:
        """Generate ``n`` synthetic profiles from a fixed-seed random generator."""
        rng = np.random.default_rng(2026)
        goals = list(GOAL_MAP.keys())
        occupations = list(OCCUPATION_MAP.keys())
        profiles = []
        for _ in range(n):
            goal = goals[int(rng.integers(0, len(goals)))]
            occupation = occupations[int(rng.integers(0, len(occupations)))]
            profiles.append(
                {
                    "weight": float(rng.uniform(45, 100)),
                    "body_fat": float(rng.uniform(10, 38)),
                    "exercise_kcal": float(rng.uniform(50, 850)),
                    "intake_kcal": float(rng.uniform(1200, 3200)),
                    "age": float(rng.uniform(18, 55)),
                    "height_cm": float(rng.uniform(150, 190)),
                    "goal": goal,
                    "occupation": occupation,
                }
            )
        return profiles

    def train(self, recipes: list) -> None:
        """Fit the forest on synthetic profiles x ``recipes`` and persist it.

        Raises:
            ValueError: If ``recipes`` is empty.
        """
        if not recipes:
            raise ValueError("训练随机森林前至少需要 1 条食谱数据")

        x_rows = []
        y_rows = []
        for profile in self._sample_profiles():
            for recipe in recipes:
                x_rows.append(self._build_feature(profile, recipe))
                y_rows.append(self._heuristic_score(profile, recipe))
        x = np.array(x_rows)
        y = np.array(y_rows)

        model = RandomForestRegressor(
            n_estimators=240,
            random_state=2026,
            max_depth=12,
            min_samples_leaf=2,
            n_jobs=-1,
        )
        model.fit(x, y)
        self.model = model
        self.recipe_signature = self._signature(recipes)

        self.model_path.parent.mkdir(parents=True, exist_ok=True)
        joblib.dump(
            {
                "model": model,
                "recipe_signature": self.recipe_signature,
                "trained_at": datetime.utcnow().isoformat(),
            },
            self.model_path,
        )

    def load_or_train(self, recipes: list) -> None:
        """Reuse the persisted model if it matches ``recipes``; otherwise retrain."""
        current_signature = self._signature(recipes)
        if self.model_path.exists():
            # NOTE(review): joblib.load unpickles the file, so the model path
            # must only ever point at files this application wrote itself.
            payload = joblib.load(self.model_path)
            if payload.get("recipe_signature") == current_signature:
                self.model = payload["model"]
                self.recipe_signature = current_signature
                return
        self.train(recipes)

    @staticmethod
    def _build_reason(profile: dict, recipe, score: float) -> str:
        """Return a short user-facing explanation for a recommendation."""
        goal = profile.get("goal", "maintain")
        if goal == "lose_fat":
            return f"热量适中,蛋白质 {recipe.protein}g适合减脂期控热量和保肌。"
        if goal == "gain_muscle":
            return "蛋白质与碳水配置较高,适合增肌训练后的恢复。"
        if goal == "keto":
            return f"碳水 {recipe.carbs}g偏低碳结构适合生酮期参考。"
        if score > 80:
            return "营养均衡度高,适合作为日常轻食搭配。"
        return "综合营养结构较均衡,可作为个性化备选方案。"

    def recommend(self, profile: dict, recipes: list, top_k: int = 5) -> list:
        """Return the ``top_k`` best-scoring recipes for ``profile``.

        Each result is the recipe's ``to_dict()`` output extended with
        ``rf_score`` (rounded to 2 decimals) and ``reason``. An empty recipe
        list yields an empty result without touching the model.
        """
        if not recipes:
            return []
        self.load_or_train(recipes)
        x = np.array([self._build_feature(profile, recipe) for recipe in recipes])
        pred_scores = self.model.predict(x)

        result = []
        for recipe, score in zip(recipes, pred_scores):
            row = recipe.to_dict()
            row["rf_score"] = round(float(score), 2)
            row["reason"] = self._build_reason(profile, recipe, float(score))
            result.append(row)
        result.sort(key=lambda item: item["rf_score"], reverse=True)
        return result[:top_k]
def merge_profile_with_history(base_profile: dict, history: list) -> dict:
    """Overlay the averages of past measurements onto a base profile.

    Args:
        base_profile: The user's static profile values.
        history: Records exposing ``weight``, ``body_fat``,
            ``exercise_kcal`` and ``intake_kcal`` attributes.

    Returns:
        ``base_profile`` itself when ``history`` is empty; otherwise a new
        dict in which those four fields hold the mean over ``history``.
    """
    if not history:
        return base_profile
    averaged = {
        field: float(np.mean([getattr(entry, field) for entry in history]))
        for field in ("weight", "body_fat", "exercise_kcal", "intake_kcal")
    }
    return {**base_profile, **averaged}

180
backend/app/models.py Normal file
View File

@@ -0,0 +1,180 @@
from datetime import datetime
from werkzeug.security import check_password_hash, generate_password_hash
from app.extensions import db
class User(db.Model):
    """Account row for a regular user or an administrator."""

    __tablename__ = "users"

    id = db.Column(db.Integer, primary_key=True)
    username = db.Column(db.String(64), unique=True, nullable=False, index=True)
    password_hash = db.Column(db.String(255), nullable=False)
    nickname = db.Column(db.String(64), nullable=False)
    company = db.Column(db.String(128), default="")
    title = db.Column(db.String(64), default="")
    phone = db.Column(db.String(32), default="")
    is_admin = db.Column(db.Boolean, default=False)
    created_at = db.Column(db.DateTime, default=datetime.utcnow)
    updated_at = db.Column(db.DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)

    # ContentPost has several foreign keys to users, so each relationship
    # names the key it follows.
    prediction_logs = db.relationship("SpamPredictionLog", backref="user", lazy=True, cascade="all, delete-orphan")
    training_samples = db.relationship("SpamTrainingSample", backref="creator", lazy=True, foreign_keys="SpamTrainingSample.created_by")
    sent_posts = db.relationship("ContentPost", backref="author", lazy=True, foreign_keys="ContentPost.user_id")
    received_posts = db.relationship("ContentPost", backref="recipient", lazy=True, foreign_keys="ContentPost.recipient_user_id")
    reviewed_posts = db.relationship("ContentPost", backref="reviewer", lazy=True, foreign_keys="ContentPost.manual_review_by")

    def set_password(self, password: str) -> None:
        """Store a salted hash of ``password`` (the plain text is not kept)."""
        hashed = generate_password_hash(password)
        self.password_hash = hashed

    def check_password(self, password: str) -> bool:
        """Return True when ``password`` matches the stored hash."""
        matches = check_password_hash(self.password_hash, password)
        return matches

    def to_dict(self) -> dict:
        """Serialize the account for API output; the password hash is left out."""
        created = self.created_at.isoformat() if self.created_at else None
        updated = self.updated_at.isoformat() if self.updated_at else None
        return {
            "id": self.id,
            "username": self.username,
            "nickname": self.nickname,
            "company": self.company,
            "title": self.title,
            "phone": self.phone,
            "is_admin": self.is_admin,
            "created_at": created,
            "updated_at": updated,
        }
class SpamTrainingSample(db.Model):
    """One labelled text sample available for training the spam model."""

    __tablename__ = "spam_training_samples"

    id = db.Column(db.Integer, primary_key=True)
    text = db.Column(db.Text, nullable=False)
    label = db.Column(db.String(16), nullable=False, index=True)  # spam | ham
    source = db.Column(db.String(32), default="seed")  # seed | import | feedback | manual_review
    created_by = db.Column(db.Integer, db.ForeignKey("users.id"), nullable=True, index=True)
    is_active = db.Column(db.Boolean, default=True)
    created_at = db.Column(db.DateTime, default=datetime.utcnow)
    updated_at = db.Column(db.DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)

    def to_dict(self) -> dict:
        """Serialize the sample; timestamps become ISO strings or None."""
        created = self.created_at.isoformat() if self.created_at else None
        updated = self.updated_at.isoformat() if self.updated_at else None
        return {
            "id": self.id,
            "text": self.text,
            "label": self.label,
            "source": self.source,
            "created_by": self.created_by,
            "is_active": self.is_active,
            "created_at": created,
            "updated_at": updated,
        }
class SpamPredictionLog(db.Model):
    """Record of one prediction made for a user."""

    __tablename__ = "spam_prediction_logs"

    id = db.Column(db.Integer, primary_key=True)
    user_id = db.Column(db.Integer, db.ForeignKey("users.id"), nullable=False, index=True)
    text = db.Column(db.Text, nullable=False)
    prediction = db.Column(db.String(16), nullable=False)  # spam | ham
    spam_probability = db.Column(db.Float, nullable=False)
    ham_probability = db.Column(db.Float, nullable=False)
    confidence = db.Column(db.Float, nullable=False)
    reason_tokens = db.Column(db.JSON, default=list)
    model_version = db.Column(db.String(64), default="")
    created_at = db.Column(db.DateTime, default=datetime.utcnow, index=True)

    def to_dict(self) -> dict:
        """Serialize the log entry; numeric scores are rounded to 4 decimals."""
        spam_score = round(float(self.spam_probability), 4)
        ham_score = round(float(self.ham_probability), 4)
        certainty = round(float(self.confidence), 4)
        created = self.created_at.isoformat() if self.created_at else None
        return {
            "id": self.id,
            "user_id": self.user_id,
            "text": self.text,
            "prediction": self.prediction,
            "spam_probability": spam_score,
            "ham_probability": ham_score,
            "confidence": certainty,
            "reason_tokens": self.reason_tokens or [],
            "model_version": self.model_version,
            "created_at": created,
        }
class DetectionConfig(db.Model):
    """Spam-detection settings row holding the probability threshold."""

    __tablename__ = "detection_configs"

    id = db.Column(db.Integer, primary_key=True)
    spam_threshold = db.Column(db.Float, nullable=False, default=0.75)
    updated_by = db.Column(db.Integer, db.ForeignKey("users.id"), nullable=True)
    updated_at = db.Column(db.DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)

    def to_dict(self) -> dict:
        """Serialize the settings; the threshold is rounded to 4 decimals."""
        threshold = round(float(self.spam_threshold), 4)
        updated = self.updated_at.isoformat() if self.updated_at else None
        return {
            "id": self.id,
            "spam_threshold": threshold,
            "updated_by": self.updated_by,
            "updated_at": updated,
        }
class ContentPost(db.Model):
    """A user post together with its detection result, review and appeal state."""

    __tablename__ = "content_posts"

    id = db.Column(db.Integer, primary_key=True)
    user_id = db.Column(db.Integer, db.ForeignKey("users.id"), nullable=False, index=True)
    recipient_user_id = db.Column(db.Integer, db.ForeignKey("users.id"), nullable=True, index=True)
    text = db.Column(db.Text, nullable=False)
    visibility = db.Column(db.String(16), nullable=False, default="public")  # public | private | direct
    status = db.Column(db.String(16), nullable=False, default="published")  # published | blocked

    # Detection result recorded for the post.
    prediction = db.Column(db.String(16), nullable=False, default="ham")
    spam_probability = db.Column(db.Float, nullable=False, default=0)
    ham_probability = db.Column(db.Float, nullable=False, default=0)
    confidence = db.Column(db.Float, nullable=False, default=0)
    threshold = db.Column(db.Float, nullable=False, default=0.75)
    reason_tokens = db.Column(db.JSON, default=list)
    model_version = db.Column(db.String(64), default="")

    # Manual review bookkeeping.
    manual_review_status = db.Column(db.String(32), nullable=False, default="none")  # none | pending | confirmed_spam | approved_ham
    manual_review_by = db.Column(db.Integer, db.ForeignKey("users.id"), nullable=True)
    manual_review_note = db.Column(db.String(255), default="")
    manual_review_at = db.Column(db.DateTime, nullable=True)

    # Appeal bookkeeping.
    appeal_status = db.Column(db.String(16), nullable=False, default="none")  # none | pending | approved | rejected
    appeal_reason = db.Column(db.String(255), default="")
    appeal_admin_note = db.Column(db.String(255), default="")
    appeal_submitted_at = db.Column(db.DateTime, nullable=True)
    appeal_processed_at = db.Column(db.DateTime, nullable=True)
    appeal_processed_by = db.Column(db.Integer, db.ForeignKey("users.id"), nullable=True)

    created_at = db.Column(db.DateTime, default=datetime.utcnow, index=True)
    updated_at = db.Column(db.DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)

    def to_dict(self) -> dict:
        """Serialize the post.

        Numeric scores are rounded to 4 decimals and every timestamp becomes
        an ISO string, or None when unset.
        """

        def iso(moment):
            # Unset timestamps serialize as None.
            return moment.isoformat() if moment else None

        def rounded(value):
            return round(float(value), 4)

        return {
            "id": self.id,
            "user_id": self.user_id,
            "recipient_user_id": self.recipient_user_id,
            "text": self.text,
            "visibility": self.visibility,
            "status": self.status,
            "prediction": self.prediction,
            "spam_probability": rounded(self.spam_probability),
            "ham_probability": rounded(self.ham_probability),
            "confidence": rounded(self.confidence),
            "threshold": rounded(self.threshold),
            "reason_tokens": self.reason_tokens or [],
            "model_version": self.model_version,
            "manual_review_status": self.manual_review_status,
            "manual_review_by": self.manual_review_by,
            "manual_review_note": self.manual_review_note,
            "manual_review_at": iso(self.manual_review_at),
            "appeal_status": self.appeal_status,
            "appeal_reason": self.appeal_reason,
            "appeal_admin_note": self.appeal_admin_note,
            "appeal_submitted_at": iso(self.appeal_submitted_at),
            "appeal_processed_at": iso(self.appeal_processed_at),
            "appeal_processed_by": self.appeal_processed_by,
            "created_at": iso(self.created_at),
            "updated_at": iso(self.updated_at),
        }

View File

@@ -0,0 +1 @@

Binary file not shown.

Binary file not shown.

View File

@@ -0,0 +1,151 @@
import json
from pathlib import Path
from typing import Any
import requests
class LLMConfigManager:
    """Reads LLM provider settings from a JSON file on demand."""

    def __init__(self, config_path: str):
        """Remember the settings file location; the file is not read yet."""
        self.config_path = Path(config_path)

    def load(self) -> dict:
        """Parse and return the settings file.

        Returns:
            The parsed JSON content, or an empty dict when the file does not
            exist. The file is decoded as ``utf-8-sig``.
        """
        if self.config_path.exists():
            with self.config_path.open("r", encoding="utf-8-sig") as handle:
                return json.load(handle)
        return {}
def _join_url(base_url: str, path: str) -> str:
return f"{base_url.rstrip('/')}/{path.lstrip('/')}"
def _timeout(config: dict) -> int:
section = config.get("request", {}) if isinstance(config, dict) else {}
try:
return int(section.get("timeout_seconds", 45))
except Exception:
return 45
def _extract_openai_text(payload: Any) -> str:
if isinstance(payload, dict):
choices = payload.get("choices")
if isinstance(choices, list) and choices:
msg = choices[0].get("message") if isinstance(choices[0], dict) else {}
if isinstance(msg, dict) and msg.get("content"):
return str(msg["content"])
if payload.get("answer"):
return str(payload["answer"])
if payload.get("data"):
return json.dumps(payload["data"], ensure_ascii=False)
return json.dumps(payload, ensure_ascii=False) if isinstance(payload, (dict, list)) else str(payload)
def ask_fastgpt(config: dict, prompt: str, context: str = "", custom_uid: str = "") -> dict:
    """Send a question to a FastGPT chat-completions endpoint.

    Args:
        config: Loaded settings; the ``fastgpt`` section supplies
            ``base_url``, ``api_key`` and the optional ``chat_id``,
            ``model``, ``system_prompt`` and ``custom_uid``.
        prompt: The user's question.
        context: Optional knowledge-base text placed in front of the question.
        custom_uid: Optional user id; overrides the configured one.

    Returns:
        On success ``{"ok": True, "provider", "raw", "answer"}``; otherwise
        ``{"ok": False, "message", ...}`` with a ``detail`` entry for HTTP
        and request errors. Request failures are reported this way rather
        than raised.
    """
    settings = config.get("fastgpt", {})
    base_url = settings.get("base_url", "").strip()
    api_key = settings.get("api_key", "").strip()
    chat_id = str(settings.get("chat_id", "111"))
    model = settings.get("model", "")
    timeout = _timeout(config)
    if not (base_url and api_key):
        return {"ok": False, "message": "FastGPT 未配置 base_url 或 api_key"}

    endpoint = _join_url(base_url, "/v1/chat/completions")
    request_headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }

    if context:
        user_content = f"请基于以下本地知识库内容回答。\n\n知识库:\n{context}\n\n问题: {prompt}"
    else:
        user_content = prompt

    # The optional system prompt goes first, ahead of the user message.
    conversation = [{"role": "user", "content": user_content}]
    system_prompt = settings.get("system_prompt")
    if system_prompt:
        conversation.insert(0, {"role": "system", "content": system_prompt})

    effective_uid = custom_uid or settings.get("custom_uid", "")
    body = {
        "chatId": chat_id,
        "stream": False,
        "detail": False,
        "messages": conversation,
    }
    if model:
        body["model"] = model
    if effective_uid:
        body["customUid"] = effective_uid

    try:
        reply = requests.post(endpoint, headers=request_headers, json=body, timeout=timeout, allow_redirects=True)
        if reply.status_code >= 400:
            return {
                "ok": False,
                "message": f"FastGPT 调用失败: HTTP {reply.status_code}",
                "detail": reply.text,
            }
        parsed = reply.json()
        return {
            "ok": True,
            "provider": "fastgpt",
            "raw": parsed,
            "answer": _extract_openai_text(parsed),
        }
    except Exception as error:
        # Any failure (network, JSON decoding, ...) becomes an error result.
        return {
            "ok": False,
            "message": "FastGPT 请求异常",
            "detail": str(error),
        }
def ask_dify(config: dict, prompt: str, context: str = "", user: str = "anonymous") -> dict:
    """Send a question to a Dify chat-messages endpoint in blocking mode.

    Args:
        config: Loaded settings; the ``dify`` section supplies ``base_url``
            and ``api_key``.
        prompt: The user's question.
        context: Optional knowledge-base text placed in front of the question.
        user: User identifier sent with the request.

    Returns:
        On success ``{"ok": True, "provider", "raw", "answer"}``; otherwise
        ``{"ok": False, "message", ...}`` with a ``detail`` entry for HTTP
        and request errors. Request failures are reported this way rather
        than raised.
    """
    settings = config.get("dify", {})
    base_url = settings.get("base_url", "").strip()
    api_key = settings.get("api_key", "").strip()
    timeout = _timeout(config)
    if not (base_url and api_key):
        return {"ok": False, "message": "Dify 未配置 base_url 或 api_key"}

    endpoint = _join_url(base_url, "/v1/chat-messages")
    request_headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }

    if context:
        question = f"知识库上下文:\n{context}\n\n用户问题:\n{prompt}"
    else:
        question = prompt

    body = {
        "inputs": {},
        "query": question,
        "response_mode": "blocking",
        "conversation_id": "",
        "user": user,
    }

    try:
        reply = requests.post(endpoint, headers=request_headers, json=body, timeout=timeout, allow_redirects=True)
        if reply.status_code >= 400:
            return {
                "ok": False,
                "message": f"Dify 调用失败: HTTP {reply.status_code}",
                "detail": reply.text,
            }
        parsed = reply.json()
        # Prefer the top-level answer field; otherwise use the generic extractor.
        text = parsed.get("answer") or _extract_openai_text(parsed)
        return {
            "ok": True,
            "provider": "dify",
            "raw": parsed,
            "answer": text,
        }
    except Exception as error:
        # Any failure (network, JSON decoding, ...) becomes an error result.
        return {
            "ok": False,
            "message": "Dify 请求异常",
            "detail": str(error),
        }

View File

@@ -0,0 +1,62 @@
import json
import re
from pathlib import Path
class LocalKnowledgeRetriever:
    """Keyword-overlap search over a local JSON knowledge base.

    The knowledge base is a JSON list of entries with ``question``,
    ``answer``, ``tags`` and optionally ``source``.
    """

    def __init__(self, kb_path: str):
        """Remember the file location and load the documents immediately."""
        self.kb_path = Path(kb_path)
        self.documents = []
        self._load()

    @staticmethod
    def _tokenize(text: str) -> set:
        """Split text into a set of match tokens.

        ASCII words (letters, digits, underscore) are lower-cased and kept
        whole. Runs of CJK characters are kept whole when at most two
        characters long, otherwise cut into overlapping two-character pieces.
        """
        text = text or ""
        words = set(re.findall(r"[A-Za-z0-9_]+", text.lower()))
        for chunk in re.findall(r"[\u4e00-\u9fff]+", text):
            if len(chunk) <= 2:
                words.add(chunk)
            else:
                words.update(chunk[idx : idx + 2] for idx in range(len(chunk) - 1))
        return words

    def _load(self):
        """Read the knowledge base; a missing file or non-list content yields no documents."""
        if not self.kb_path.exists():
            self.documents = []
            return
        with self.kb_path.open("r", encoding="utf-8-sig") as file:
            rows = json.load(file)
        self.documents = rows if isinstance(rows, list) else []

    def reload(self):
        """Re-read the knowledge base from disk."""
        self._load()

    def search(self, query: str, top_k: int = 3) -> list:
        """Return the ``top_k`` entries sharing the most tokens with ``query``.

        The score is the fraction of query tokens found in the entry's
        question, answer and tags, rounded to 4 decimals. Entries with no
        overlap are omitted. Malformed entries do not break the search:
        rows that are not dicts are skipped, and a missing, null or
        non-list ``tags`` value is normalized to a list.

        Returns:
            Result dicts with ``score``, ``question``, ``answer``, ``tags``
            and ``source``, best match first.
        """
        query_tokens = self._tokenize(query)
        if not query_tokens:
            return []

        scored = []
        for item in self.documents:
            if not isinstance(item, dict):
                continue
            tags = item.get("tags") or []
            if not isinstance(tags, (list, tuple)):
                # A single scalar tag is treated as a one-element list.
                tags = [tags]
            tag_text = " ".join(str(tag) for tag in tags)
            content = f"{item.get('question', '')} {item.get('answer', '')} {tag_text}"
            doc_tokens = self._tokenize(content)
            if not doc_tokens:
                continue
            overlap = len(query_tokens & doc_tokens)
            if overlap == 0:
                continue
            score = overlap / max(len(query_tokens), 1)
            scored.append(
                {
                    "score": round(score, 4),
                    "question": item.get("question", ""),
                    "answer": item.get("answer", ""),
                    "tags": tags,
                    "source": item.get("source", "本地知识库"),
                }
            )
        scored.sort(key=lambda x: x["score"], reverse=True)
        return scored[:top_k]

View File

@@ -0,0 +1 @@

View File

@@ -0,0 +1,429 @@
from collections import Counter
from datetime import datetime, timedelta
from flask import Blueprint, current_app, request
from sqlalchemy import func, or_
from app.extensions import db
from app.ml.naive_bayes_classifier import NaiveBayesSpamClassifier
from app.models import ContentPost, DetectionConfig, SpamPredictionLog, SpamTrainingSample, User
from app.utils.auth import admin_required, current_user
from app.utils.response import fail, ok
# Blueprint that collects the administrator endpoints defined in this module.
admin_bp = Blueprint("admin", __name__)
def _day_key(day_value) -> str:
if hasattr(day_value, "isoformat"):
return day_value.isoformat()
return str(day_value)
def _tokenize(text: str) -> list[str]:
content = (text or "").strip()
if len(content) <= 1:
return []
tokens = []
for i in range(len(content) - 1):
token = content[i : i + 2]
if token.strip():
tokens.append(token)
return tokens
def _get_or_create_config() -> DetectionConfig:
    """Return the detection settings row, creating it on first use.

    The row with the lowest id is used. When none exists, one with a
    threshold of 0.75 is inserted and committed.
    """
    config_row = DetectionConfig.query.order_by(DetectionConfig.id.asc()).first()
    if not config_row:
        config_row = DetectionConfig(spam_threshold=0.75)
        db.session.add(config_row)
        db.session.commit()
    return config_row
def _serialize_post(item: ContentPost) -> dict:
row = item.to_dict()
row["username"] = item.author.username if item.author else ""
row["nickname"] = item.author.nickname if item.author else ""
row["recipient_username"] = item.recipient.username if item.recipient else ""
row["recipient_nickname"] = item.recipient.nickname if item.recipient else ""
row["reviewer_username"] = item.reviewer.username if item.reviewer else ""
return row
def _upsert_manual_sample(text: str, label: str, admin_id: int | None) -> None:
    """Make sure an active training sample exists for ``(text, label)``.

    An existing sample is reactivated and keeps its source (an empty source
    becomes ``manual_review``). Otherwise a new ``manual_review`` sample
    credited to ``admin_id`` is added to the session. Nothing is committed
    here.
    """
    match = SpamTrainingSample.query.filter_by(text=text, label=label).first()
    if not match:
        db.session.add(
            SpamTrainingSample(
                text=text,
                label=label,
                source="manual_review",
                created_by=admin_id,
                is_active=True,
            )
        )
        return
    match.is_active = True
    match.source = match.source or "manual_review"
@admin_bp.get("/stats")
@admin_required
def stats():
    """Return the dashboard statistics for the admin console.

    Includes overall counts, a 7-day posting trend, the blocked ratio and
    the most frequent two-character tokens of blocked posts over the same
    window, the training-sample source distribution, model information and
    the current threshold.

    The 7-day window covers whole days: it starts at midnight six days ago
    (on the same naive-UTC clock used for ``created_at``), so the oldest
    trend bucket is not cut off at the current time of day.
    """
    user_count = User.query.count()
    sample_count = SpamTrainingSample.query.count()
    predict_count = SpamPredictionLog.query.count()
    post_count = ContentPost.query.count()
    blocked_count = ContentPost.query.filter_by(status="blocked").count()
    published_count = ContentPost.query.filter_by(status="published").count()
    pending_appeal_count = ContentPost.query.filter_by(appeal_status="pending").count()

    now = datetime.utcnow()
    today = now.date()
    # Midnight six days ago: today plus the six previous full days.
    window_start = now.replace(hour=0, minute=0, second=0, microsecond=0) - timedelta(days=6)

    trend_rows = (
        db.session.query(func.date(ContentPost.created_at), func.count(ContentPost.id))
        .filter(ContentPost.created_at >= window_start)
        .group_by(func.date(ContentPost.created_at))
        .all()
    )
    blocked_7d_count = ContentPost.query.filter(
        ContentPost.created_at >= window_start, ContentPost.status == "blocked"
    ).count()
    total_7d_count = ContentPost.query.filter(ContentPost.created_at >= window_start).count()

    # Fill every day of the window so days without posts show up as 0.
    day_map = {_day_key(day): int(count or 0) for day, count in trend_rows}
    trend = []
    for offset in range(6, -1, -1):
        day = today - timedelta(days=offset)
        key = day.isoformat()
        trend.append({"date": key, "label": day.strftime("%m-%d"), "post_count": day_map.get(key, 0)})

    source_rows = (
        db.session.query(SpamTrainingSample.source, func.count(SpamTrainingSample.id))
        .group_by(SpamTrainingSample.source)
        .order_by(func.count(SpamTrainingSample.id).desc())
        .all()
    )
    source_dist = [{"name": (name or "unknown"), "value": int(value or 0)} for name, value in source_rows]

    # Keyword statistics look at no more than the 1000 newest blocked posts.
    blocked_logs = (
        ContentPost.query.filter(ContentPost.created_at >= window_start, ContentPost.status == "blocked")
        .order_by(ContentPost.id.desc())
        .limit(1000)
        .all()
    )
    token_counter = Counter()
    for row in blocked_logs:
        token_counter.update(_tokenize(row.text))
    top_keywords = [{"token": token, "count": count} for token, count in token_counter.most_common(12)]

    cfg = _get_or_create_config()
    clf = NaiveBayesSpamClassifier(current_app.config["NB_MODEL_PATH"])
    clf.load()

    return ok(
        {
            "user_count": user_count,
            "sample_count": sample_count,
            "predict_count": predict_count,
            "post_count": post_count,
            "blocked_count": blocked_count,
            "published_count": published_count,
            "pending_appeal_count": pending_appeal_count,
            "blocked_ratio_7d": round(blocked_7d_count / total_7d_count, 4) if total_7d_count else 0,
            "total_7d": total_7d_count,
            "trend_7d": trend,
            "source_distribution": source_dist,
            "top_keywords": top_keywords,
            "model_info": clf.model_info(),
            "threshold": cfg.to_dict(),
        }
    )
@admin_bp.get("/detection/threshold")
@admin_required
def get_threshold():
    """Return the current detection threshold settings."""
    config_row = _get_or_create_config()
    return ok(config_row.to_dict())
@admin_bp.put("/detection/threshold")
@admin_required
def set_threshold():
    """Update the spam probability threshold.

    Expects a JSON body with a numeric ``spam_threshold`` between 0.01 and
    0.99 inclusive. Values that are not numbers, including NaN, are
    rejected with HTTP 400.
    """
    payload = request.get_json(silent=True) or {}
    try:
        threshold = float(payload.get("spam_threshold"))
    except (AttributeError, TypeError, ValueError, OverflowError):
        # AttributeError covers a JSON body that is not an object.
        return fail("spam_threshold 必须是数字", 400)
    # Written as a positive range test so NaN, for which every comparison
    # is false, is rejected as well.
    if not (0.01 <= threshold <= 0.99):
        return fail("spam_threshold 必须在 0.01 到 0.99 之间", 400)

    cfg = _get_or_create_config()
    admin = current_user()
    cfg.spam_threshold = threshold
    cfg.updated_by = admin.id if admin else None
    db.session.commit()
    return ok(cfg.to_dict(), "阈值更新成功")
@admin_bp.get("/intercepts")
@admin_required
def list_intercepts():
    """List content posts for moderation, newest first, with optional filters.

    Query parameters:
        keyword: substring matched against the post text.
        status: ``blocked`` (default) or ``published``; any other value
            disables the status filter.
        review_status: ``none``, ``pending``, ``confirmed_spam`` or
            ``approved_ham``; anything else disables the filter.
        page, page_size: pagination; ``page_size`` is clamped to 1..100.
    """
    keyword = (request.args.get("keyword") or "").strip()
    status = (request.args.get("status") or "blocked").strip().lower()
    review_status = (request.args.get("review_status") or "").strip().lower()
    # type=int makes the lookup fall back to the default on a non-numeric
    # value instead of raising ValueError (which surfaced as HTTP 500).
    page = max(request.args.get("page", default=1, type=int), 1)
    page_size = min(max(request.args.get("page_size", default=20, type=int), 1), 100)
    query = ContentPost.query
    if keyword:
        query = query.filter(ContentPost.text.like(f"%{keyword}%"))
    if status in {"blocked", "published"}:
        query = query.filter(ContentPost.status == status)
    if review_status in {"none", "pending", "confirmed_spam", "approved_ham"}:
        query = query.filter(ContentPost.manual_review_status == review_status)
    pagination = query.order_by(ContentPost.id.desc()).paginate(page=page, per_page=page_size, error_out=False)
    return ok(
        {
            "items": [_serialize_post(item) for item in pagination.items],
            "total": pagination.total,
            "page": page,
            "page_size": page_size,
        }
    )
@admin_bp.put("/intercepts/<int:post_id>/review")
@admin_required
def review_intercept(post_id: int):
    """Record an admin's manual spam/ham decision for one post.

    The JSON body carries ``decision`` (``spam`` or ``ham``) and an optional
    ``note``. The decision sets the post status, prediction and manual review
    fields, settles a pending appeal if there is one, and is passed to
    ``_upsert_manual_sample`` before the session is committed.
    """
    post = ContentPost.query.get(post_id)
    if not post:
        return fail("记录不存在", 404)
    body = request.get_json(silent=True) or {}
    decision = (body.get("decision") or "").strip().lower()
    note = (body.get("note") or "").strip()
    if decision not in {"spam", "ham"}:
        return fail("decision 必须是 spam 或 ham", 400)

    reviewer = current_user()
    reviewer_id = reviewer.id if reviewer else None
    reviewed_at = datetime.utcnow()

    # Everything that differs between the two outcomes, resolved up front.
    if decision == "spam":
        new_status = "blocked"
        review_status = "confirmed_spam"
        appeal_outcome = "rejected"
        fallback_note = "人工复核确认为垃圾信息"
    else:
        new_status = "published"
        review_status = "approved_ham"
        appeal_outcome = "approved"
        fallback_note = "人工复核后解除拦截"

    post.manual_review_by = reviewer_id
    post.manual_review_note = note
    post.manual_review_at = reviewed_at
    post.status = new_status
    post.prediction = decision
    post.manual_review_status = review_status
    if post.appeal_status == "pending":
        # A manual review also closes an appeal that is still open.
        post.appeal_status = appeal_outcome
        post.appeal_admin_note = note or fallback_note
        post.appeal_processed_at = reviewed_at
        post.appeal_processed_by = reviewer_id
    _upsert_manual_sample(post.text, decision, reviewer_id)
    db.session.commit()
    return ok(_serialize_post(post), "人工复核完成")
@admin_bp.get("/appeals")
@admin_required
def list_appeals():
    """List posts that have an appeal, newest first.

    Query parameters:
        keyword: substring matched against the post text, the appeal reason
            or the admin note.
        status: ``pending`` (default), ``approved`` or ``rejected``; any other
            value lists appeals in every state.
        page, page_size: pagination; ``page_size`` is clamped to 1..100.
    """
    keyword = (request.args.get("keyword") or "").strip()
    status = (request.args.get("status") or "pending").strip().lower()
    # type=int makes the lookup fall back to the default on a non-numeric
    # value instead of raising ValueError (which surfaced as HTTP 500).
    page = max(request.args.get("page", default=1, type=int), 1)
    page_size = min(max(request.args.get("page_size", default=20, type=int), 1), 100)
    query = ContentPost.query.filter(ContentPost.appeal_status != "none")
    if keyword:
        query = query.filter(
            or_(
                ContentPost.text.like(f"%{keyword}%"),
                ContentPost.appeal_reason.like(f"%{keyword}%"),
                ContentPost.appeal_admin_note.like(f"%{keyword}%"),
            )
        )
    if status in {"pending", "approved", "rejected"}:
        query = query.filter(ContentPost.appeal_status == status)
    pagination = query.order_by(ContentPost.id.desc()).paginate(page=page, per_page=page_size, error_out=False)
    return ok(
        {
            "items": [_serialize_post(item) for item in pagination.items],
            "total": pagination.total,
            "page": page,
            "page_size": page_size,
        }
    )
@admin_bp.put("/appeals/<int:post_id>/process")
@admin_required
def process_appeal(post_id: int):
    """Approve or reject a pending appeal.

    The JSON body carries ``action`` (``approve`` or ``reject``) and an
    optional ``note``. Approving publishes the post and labels it ham;
    rejecting keeps it blocked and labels it spam. The outcome is also passed
    to ``_upsert_manual_sample`` before the session is committed.
    """
    post = ContentPost.query.get(post_id)
    if not post:
        return fail("记录不存在", 404)
    if post.appeal_status != "pending":
        return fail("该申诉不在待处理状态", 400)
    body = request.get_json(silent=True) or {}
    action = (body.get("action") or "").strip().lower()
    note = (body.get("note") or "").strip()
    if action not in {"approve", "reject"}:
        return fail("action 必须是 approve 或 reject", 400)

    handler = current_user()
    handler_id = handler.id if handler else None
    processed_at = datetime.utcnow()
    approved = action == "approve"
    label = "ham" if approved else "spam"

    # Appeal bookkeeping and manual-review bookkeeping share note/time/admin.
    post.appeal_status = "approved" if approved else "rejected"
    post.appeal_admin_note = note
    post.appeal_processed_at = processed_at
    post.appeal_processed_by = handler_id
    post.manual_review_by = handler_id
    post.manual_review_note = note
    post.manual_review_at = processed_at

    post.status = "published" if approved else "blocked"
    post.prediction = label
    post.manual_review_status = "approved_ham" if approved else "confirmed_spam"
    _upsert_manual_sample(post.text, label, handler_id)
    db.session.commit()
    return ok(_serialize_post(post), "申诉处理完成")
@admin_bp.get("/users")
@admin_required
def list_users():
    """List users, newest first, optionally filtered by a keyword.

    Query parameters:
        keyword: substring matched against username or nickname.
        page, page_size: pagination; ``page_size`` is clamped to 1..100.
    """
    keyword = (request.args.get("keyword") or "").strip()
    # type=int makes the lookup fall back to the default on a non-numeric
    # value instead of raising ValueError (which surfaced as HTTP 500).
    page = max(request.args.get("page", default=1, type=int), 1)
    page_size = min(max(request.args.get("page_size", default=20, type=int), 1), 100)
    query = User.query
    if keyword:
        query = query.filter(User.username.like(f"%{keyword}%") | User.nickname.like(f"%{keyword}%"))
    pagination = query.order_by(User.id.desc()).paginate(page=page, per_page=page_size, error_out=False)
    return ok(
        {
            "items": [item.to_dict() for item in pagination.items],
            "total": pagination.total,
            "page": page,
            "page_size": page_size,
        }
    )
@admin_bp.post("/users/import")
@admin_required
def import_users():
    """Create or update users in bulk from the ``items`` array of the JSON body.

    Each item is matched by ``username``. Entries that are not objects, or
    whose username is shorter than 3 characters, are skipped. Existing users
    only have the fields overwritten for which the item supplies a non-empty
    value. Responds with the number of created and updated users.
    """

    def _clean(value) -> str:
        """Return *value* as stripped text; falsy values become ''."""
        return str(value).strip() if value else ""

    payload = request.get_json(silent=True) or {}
    items = payload.get("items") or []
    if not isinstance(items, list) or not items:
        return fail("items 必须是非空数组", 400)
    created = 0
    updated = 0
    for row in items:
        # A malformed entry must not abort the whole batch with an
        # AttributeError; it is skipped like an entry with a bad username.
        if not isinstance(row, dict):
            continue
        username = _clean(row.get("username"))
        if len(username) < 3:
            continue
        user = User.query.filter_by(username=username).first()
        if not user:
            user = User(
                username=username,
                nickname=_clean(row.get("nickname")) or username,
                company=_clean(row.get("company")),
                title=_clean(row.get("title")),
                phone=_clean(row.get("phone")),
                is_admin=bool(row.get("is_admin", False)),
            )
            # NOTE(review): imported accounts without a password get the
            # well-known default "123456". Kept for compatibility, but this
            # should be replaced by a generated or forced-reset password.
            user.set_password(row.get("password") or "123456")
            db.session.add(user)
            created += 1
            continue
        # Only overwrite what the item provides; this also avoids calling
        # .strip() on a stored value that is None.
        for field in ("nickname", "company", "title", "phone"):
            value = _clean(row.get(field))
            if value:
                setattr(user, field, value)
        if "is_admin" in row:
            user.is_admin = bool(row.get("is_admin"))
        if row.get("password"):
            user.set_password(row["password"])
        updated += 1
    db.session.commit()
    return ok({"created": created, "updated": updated}, "用户导入完成")
@admin_bp.put("/users/<int:user_id>")
@admin_required
def update_user(user_id: int):
    """Update profile fields, admin flag and/or password of one user.

    Only keys present in the JSON body are applied. All validation happens
    before any attribute is changed, so a rejected request leaves the user
    untouched. Responds with 404 for an unknown user and 400 for a password
    shorter than 6 characters or an attempt to demote the last admin.
    """
    user = User.query.get(user_id)
    if not user:
        return fail("用户不存在", 404)
    payload = request.get_json(silent=True) or {}

    new_password = payload.get("password")
    if new_password and (not isinstance(new_password, str) or len(new_password) < 6):
        return fail("密码至少6位", 400)
    if "is_admin" in payload:
        wants_admin = bool(payload.get("is_admin"))
        # Same invariant delete_user enforces: never lose the last admin.
        if user.is_admin and not wants_admin and User.query.filter_by(is_admin=True).count() <= 1:
            return fail("至少保留一个管理员账号", 400)

    if "nickname" in payload:
        user.nickname = (payload.get("nickname") or user.nickname).strip()
    if "company" in payload:
        user.company = (payload.get("company") or "").strip()
    if "title" in payload:
        user.title = (payload.get("title") or "").strip()
    if "phone" in payload:
        user.phone = (payload.get("phone") or "").strip()
    if "is_admin" in payload:
        user.is_admin = bool(payload.get("is_admin"))
    if new_password:
        user.set_password(new_password)
    db.session.commit()
    return ok(user.to_dict(), "用户更新成功")
@admin_bp.delete("/users/<int:user_id>")
@admin_required
def delete_user(user_id: int):
    """Delete one user, refusing to remove the only remaining admin."""
    target = User.query.get(user_id)
    if not target:
        return fail("用户不存在", 404)
    # The count query only runs when the target is an admin.
    is_last_admin = target.is_admin and User.query.filter_by(is_admin=True).count() <= 1
    if is_last_admin:
        return fail("至少保留一个管理员账号", 400)
    db.session.delete(target)
    db.session.commit()
    return ok({}, "用户已删除")

View File

@@ -0,0 +1,67 @@
from flask import Blueprint, request
from flask_jwt_extended import create_access_token, jwt_required
from app.extensions import db
from app.models import User
from app.utils.auth import current_user
from app.utils.response import fail, ok
# Blueprint holding the authentication endpoints defined below
# (register, login, me).
auth_bp = Blueprint("auth", __name__)
@auth_bp.post("/register")
def register():
    """Create a regular user account from the JSON body.

    Requires ``username`` (at least 3 characters after stripping) and
    ``password`` (a string of at least 6 characters). ``nickname`` defaults to
    the username. Responds with 400 on invalid input and 409 when the
    username is already taken.
    """
    payload = request.get_json(silent=True) or {}
    username = (payload.get("username") or "").strip()
    password = payload.get("password") or ""
    nickname = (payload.get("nickname") or username).strip()
    if len(username) < 3:
        return fail("用户名至少3位", 400)
    # A non-string password (e.g. a JSON number) would make len() raise.
    if not isinstance(password, str) or len(password) < 6:
        return fail("密码至少6位", 400)
    if User.query.filter_by(username=username).first():
        return fail("用户名已存在", 409)
    user = User(
        username=username,
        nickname=nickname or username,
        company=(payload.get("company") or "").strip(),
        title=(payload.get("title") or "").strip(),
        phone=(payload.get("phone") or "").strip(),
        # This endpoint is unauthenticated, so the admin flag must never be
        # taken from the request body: anyone could grant themselves admin
        # rights. Admin accounts are managed through the admin user routes.
        is_admin=False,
    )
    user.set_password(password)
    db.session.add(user)
    db.session.commit()
    return ok(user.to_dict(), "注册成功")
@auth_bp.post("/login")
def login():
    """Check username/password and issue a JWT access token.

    The token identity is the user id as a string; ``is_admin`` and
    ``username`` are attached as additional claims. Responds with 401 when
    the user is unknown or the password does not match.
    """
    body = request.get_json(silent=True) or {}
    username = (body.get("username") or "").strip()
    password = body.get("password") or ""
    account = User.query.filter_by(username=username).first()
    if not account or not account.check_password(password):
        return fail("用户名或密码错误", 401)
    claims = {"is_admin": bool(account.is_admin), "username": account.username}
    token = create_access_token(identity=str(account.id), additional_claims=claims)
    return ok({"token": token, "user": account.to_dict()}, "登录成功")
@auth_bp.get("/me")
@jwt_required()
def me():
    """Return the profile of the user resolved by ``current_user()``, or 404."""
    account = current_user()
    if account:
        return ok(account.to_dict())
    return fail("用户不存在", 404)

View File

@@ -0,0 +1,360 @@
from datetime import datetime
from flask import Blueprint, current_app, request
from flask_jwt_extended import jwt_required
from app.extensions import db
from app.ml.naive_bayes_classifier import NaiveBayesSpamClassifier
from app.models import ContentPost, DetectionConfig, SpamPredictionLog, SpamTrainingSample, User
from app.utils.auth import current_user
from app.utils.response import fail, ok
# Blueprint holding the content endpoints defined below: publishing and
# editing posts, history/inbox/public listings, and appeals.
content_bp = Blueprint("content", __name__)
def _classifier() -> NaiveBayesSpamClassifier:
    """Build a classifier for the ``NB_MODEL_PATH`` entry of the app config."""
    model_path = current_app.config["NB_MODEL_PATH"]
    return NaiveBayesSpamClassifier(model_path)
def _active_samples() -> list[dict]:
    """Return text/label dicts for all active training samples, ordered by id."""
    active = SpamTrainingSample.query.filter_by(is_active=True)
    samples = []
    for sample in active.order_by(SpamTrainingSample.id.asc()).all():
        samples.append({"text": sample.text, "label": sample.label})
    return samples
def _ensure_ready() -> NaiveBayesSpamClassifier:
    """Return a classifier after calling ``ensure_ready`` with the active samples."""
    samples = _active_samples()
    classifier = _classifier()
    classifier.ensure_ready(samples)
    return classifier
def _get_config() -> DetectionConfig:
    """Return the detection config with the lowest id.

    If none exists, one with ``spam_threshold=0.75`` is created and committed.
    """
    existing = DetectionConfig.query.order_by(DetectionConfig.id.asc()).first()
    if not existing:
        existing = DetectionConfig(spam_threshold=0.75)
        db.session.add(existing)
        db.session.commit()
    return existing
def _serialize_post(row: ContentPost) -> dict:
payload = row.to_dict()
payload["username"] = row.author.username if row.author else ""
payload["nickname"] = row.author.nickname if row.author else ""
payload["recipient_username"] = row.recipient.username if row.recipient else ""
payload["recipient_nickname"] = row.recipient.nickname if row.recipient else ""
payload["reviewer_username"] = row.reviewer.username if row.reviewer else ""
return payload
def _resolve_visibility(value: str) -> str:
key = (value or "public").strip().lower()
return key if key in {"public", "private", "direct"} else "public"
def _resolve_recipient(payload: dict, visibility: str, current_user_id: int):
if visibility != "direct":
return None, None
recipient = None
raw_id = payload.get("recipient_user_id")
username = (payload.get("recipient_username") or "").strip()
if raw_id is not None and str(raw_id).strip() != "":
try:
recipient = User.query.get(int(raw_id))
except Exception:
return None, "recipient_user_id 无效"
elif username:
recipient = User.query.filter_by(username=username).first()
if not recipient:
return None, "私信发布必须指定有效接收人"
if recipient.id == current_user_id:
return None, "不能给自己发送私信"
return recipient, None
def _predict_and_decide(text: str) -> tuple[dict, float, bool]:
    """Classify *text* and compare it against the configured threshold.

    Returns ``(result, threshold, blocked)``; ``blocked`` is true when the
    spam probability is greater than or equal to the threshold.
    """
    result = _ensure_ready().predict(text)
    threshold = float(_get_config().spam_threshold)
    spam_probability = float(result["spam_probability"])
    return result, threshold, spam_probability >= threshold
@content_bp.post("/publish")
@jwt_required()
def publish_text():
    """Run spam detection on a new post and store it as published or blocked.

    The JSON body carries ``text`` (at least 2 characters), an optional
    ``visibility`` and, for direct messages, the recipient. A prediction log
    row is stored alongside the post. The response contains the decision,
    the serialized post and the raw detection result.
    """
    author = current_user()
    if not author:
        return fail("用户不存在", 404)
    body = request.get_json(silent=True) or {}
    text = (body.get("text") or "").strip()
    visibility = _resolve_visibility(body.get("visibility"))
    if len(text) < 2:
        return fail("发布文本至少2个字符", 400)
    recipient, error = _resolve_recipient(body, visibility, author.id)
    if error:
        return fail(error, 400)

    result, threshold, blocked = _predict_and_decide(text)
    # Fields written identically to the post and to the prediction log.
    detection = {
        "text": result["text"],
        "prediction": result["prediction"],
        "spam_probability": result["spam_probability"],
        "ham_probability": result["ham_probability"],
        "confidence": result["confidence"],
        "reason_tokens": result["reason_tokens"],
        "model_version": result.get("model_version", ""),
    }
    post = ContentPost(
        user_id=author.id,
        recipient_user_id=recipient.id if recipient else None,
        visibility=visibility,
        status="blocked" if blocked else "published",
        threshold=threshold,
        manual_review_status="pending" if blocked else "none",
        **detection,
    )
    detect_log = SpamPredictionLog(user_id=author.id, **detection)
    db.session.add(post)
    db.session.add(detect_log)
    db.session.commit()

    if blocked:
        action, feedback = "blocked", "疑似垃圾信息,系统已拦截,可提交申诉"
    else:
        action, feedback = "published", "发布成功"
    response = {
        "publish_allowed": not blocked,
        "action": action,
        "feedback": feedback,
        "post": _serialize_post(post),
        "detect": result,
    }
    return ok(response, feedback)
@content_bp.put("/posts/<int:post_id>")
@jwt_required()
def edit_post(post_id: int):
    """Edit one of the caller's posts and run spam detection on it again.

    ``text`` and ``visibility`` fall back to the stored values when omitted,
    and so does the recipient of a direct message. All detection fields are
    overwritten and any manual review or appeal state is reset.
    """
    user = current_user()
    if not user:
        return fail("用户不存在", 404)
    post = ContentPost.query.filter_by(id=post_id, user_id=user.id).first()
    if not post:
        return fail("发布记录不存在", 404)
    payload = request.get_json(silent=True) or {}
    text = (payload.get("text") or post.text).strip()
    visibility = _resolve_visibility(payload.get("visibility") or post.visibility)
    if len(text) < 2:
        return fail("发布文本至少2个字符", 400)

    # Partial edits keep the stored text/visibility, so they must keep the
    # stored recipient too; otherwise editing only the text of a direct
    # message is rejected for lacking a recipient.
    recipient_payload = payload
    raw_recipient_id = payload.get("recipient_user_id")
    has_recipient_id = raw_recipient_id is not None and str(raw_recipient_id).strip() != ""
    has_recipient_name = bool((payload.get("recipient_username") or "").strip())
    if visibility == "direct" and not has_recipient_id and not has_recipient_name and post.recipient_user_id:
        recipient_payload = {**payload, "recipient_user_id": post.recipient_user_id}
    recipient, err = _resolve_recipient(recipient_payload, visibility, user.id)
    if err:
        return fail(err, 400)

    result, threshold, blocked = _predict_and_decide(text)
    post.text = result["text"]
    post.visibility = visibility
    post.recipient_user_id = recipient.id if recipient else None
    post.status = "blocked" if blocked else "published"
    post.prediction = result["prediction"]
    post.spam_probability = result["spam_probability"]
    post.ham_probability = result["ham_probability"]
    post.confidence = result["confidence"]
    post.threshold = threshold
    post.reason_tokens = result["reason_tokens"]
    post.model_version = result.get("model_version", "")
    # Edited content invalidates earlier human decisions and appeals.
    post.manual_review_status = "pending" if blocked else "none"
    post.manual_review_by = None
    post.manual_review_note = ""
    post.manual_review_at = None
    post.appeal_status = "none"
    post.appeal_reason = ""
    post.appeal_admin_note = ""
    post.appeal_submitted_at = None
    post.appeal_processed_at = None
    post.appeal_processed_by = None
    db.session.commit()
    feedback = "更新并重新发布成功" if not blocked else "更新后触发拦截,可提交申诉"
    return ok(
        {
            "publish_allowed": not blocked,
            "action": "published" if not blocked else "blocked",
            "feedback": feedback,
            "post": _serialize_post(post),
            "detect": result,
        },
        feedback,
    )
@content_bp.get("/posts/history")
@jwt_required()
def my_posts():
    """List the caller's own posts, newest first.

    Query parameters:
        status: ``published`` or ``blocked``; other values disable the filter.
        visibility: ``public``, ``private`` or ``direct``; other values
            disable the filter.
        page, page_size: pagination; ``page_size`` is clamped to 1..100.
    """
    user = current_user()
    if not user:
        return fail("用户不存在", 404)
    status = (request.args.get("status") or "").strip().lower()
    visibility = (request.args.get("visibility") or "").strip().lower()
    # type=int makes the lookup fall back to the default on a non-numeric
    # value instead of raising ValueError (which surfaced as HTTP 500).
    page = max(request.args.get("page", default=1, type=int), 1)
    page_size = min(max(request.args.get("page_size", default=20, type=int), 1), 100)
    query = ContentPost.query.filter_by(user_id=user.id)
    if status in {"published", "blocked"}:
        query = query.filter(ContentPost.status == status)
    if visibility in {"public", "private", "direct"}:
        query = query.filter(ContentPost.visibility == visibility)
    pagination = query.order_by(ContentPost.id.desc()).paginate(page=page, per_page=page_size, error_out=False)
    return ok(
        {
            "items": [_serialize_post(item) for item in pagination.items],
            "total": pagination.total,
            "page": page,
            "page_size": page_size,
        }
    )
@content_bp.get("/posts/inbox")
@jwt_required()
def my_inbox():
    """List published direct messages addressed to the caller, newest first.

    Query parameters ``page`` and ``page_size`` control pagination;
    ``page_size`` is clamped to 1..100.
    """
    user = current_user()
    if not user:
        return fail("用户不存在", 404)
    # type=int makes the lookup fall back to the default on a non-numeric
    # value instead of raising ValueError (which surfaced as HTTP 500).
    page = max(request.args.get("page", default=1, type=int), 1)
    page_size = min(max(request.args.get("page_size", default=20, type=int), 1), 100)
    pagination = (
        ContentPost.query.filter_by(recipient_user_id=user.id, visibility="direct", status="published")
        .order_by(ContentPost.id.desc())
        .paginate(page=page, per_page=page_size, error_out=False)
    )
    return ok(
        {
            "items": [_serialize_post(item) for item in pagination.items],
            "total": pagination.total,
            "page": page,
            "page_size": page_size,
        }
    )
@content_bp.delete("/posts/<int:post_id>")
@jwt_required()
def delete_post(post_id: int):
    """Delete one of the caller's own posts; 404 if it is not theirs."""
    owner = current_user()
    if not owner:
        return fail("用户不存在", 404)
    # Filtering by user_id as well keeps other users' posts out of reach.
    post = ContentPost.query.filter_by(id=post_id, user_id=owner.id).first()
    if post:
        db.session.delete(post)
        db.session.commit()
        return ok({}, "记录已删除")
    return fail("记录不存在", 404)
@content_bp.post("/posts/<int:post_id>/appeal")
@jwt_required()
def submit_appeal(post_id: int):
    """Open an appeal for one of the caller's blocked posts.

    The JSON body carries ``reason`` (at least 2 characters). Responds with
    400 when the post is not blocked, the reason is too short, or an appeal
    is already pending.
    """
    owner = current_user()
    if not owner:
        return fail("用户不存在", 404)
    target = ContentPost.query.filter_by(id=post_id, user_id=owner.id).first()
    if not target:
        return fail("发布记录不存在", 404)
    if target.status != "blocked":
        return fail("仅被拦截的信息可申诉", 400)
    body = request.get_json(silent=True) or {}
    reason = (body.get("reason") or "").strip()
    if len(reason) < 2:
        return fail("申诉理由至少2个字符", 400)
    if target.appeal_status == "pending":
        return fail("该记录已在申诉处理中", 400)

    submitted_at = datetime.utcnow()
    # Clear the outcome of any earlier appeal and queue the post for review.
    target.appeal_status = "pending"
    target.appeal_reason = reason
    target.appeal_submitted_at = submitted_at
    target.appeal_admin_note = ""
    target.appeal_processed_at = None
    target.appeal_processed_by = None
    target.manual_review_status = "pending"
    db.session.commit()
    return ok(_serialize_post(target), "申诉提交成功")
@content_bp.get("/appeals/my")
@jwt_required()
def my_appeals():
    """List the caller's posts that have an appeal in any state, newest first.

    Query parameters ``page`` and ``page_size`` control pagination;
    ``page_size`` is clamped to 1..100.
    """
    user = current_user()
    if not user:
        return fail("用户不存在", 404)
    # type=int makes the lookup fall back to the default on a non-numeric
    # value instead of raising ValueError (which surfaced as HTTP 500).
    page = max(request.args.get("page", default=1, type=int), 1)
    page_size = min(max(request.args.get("page_size", default=20, type=int), 1), 100)
    pagination = (
        ContentPost.query.filter(ContentPost.user_id == user.id, ContentPost.appeal_status != "none")
        .order_by(ContentPost.id.desc())
        .paginate(page=page, per_page=page_size, error_out=False)
    )
    return ok(
        {
            "items": [_serialize_post(item) for item in pagination.items],
            "total": pagination.total,
            "page": page,
            "page_size": page_size,
        }
    )
@content_bp.get("/posts/public")
@jwt_required(optional=True)
def public_feed():
    """List published public posts, newest first; a token is optional.

    Query parameters ``page`` and ``page_size`` control pagination;
    ``page_size`` is clamped to 1..100.
    """
    # type=int makes the lookup fall back to the default on a non-numeric
    # value instead of raising ValueError (which surfaced as HTTP 500).
    page = max(request.args.get("page", default=1, type=int), 1)
    page_size = min(max(request.args.get("page_size", default=20, type=int), 1), 100)
    pagination = (
        ContentPost.query.filter_by(visibility="public", status="published")
        .order_by(ContentPost.id.desc())
        .paginate(page=page, per_page=page_size, error_out=False)
    )
    return ok(
        {
            "items": [_serialize_post(item) for item in pagination.items],
            "total": pagination.total,
            "page": page,
            "page_size": page_size,
        }
    )

View File

@@ -0,0 +1,121 @@
from datetime import datetime
from flask import Blueprint, request
from flask_jwt_extended import jwt_required
from app.extensions import db
from app.models import DietStatus
from app.utils.auth import current_user
from app.utils.response import fail, ok
# Blueprint holding the diet-status endpoints defined below
# (create, update latest, read latest, history, delete).
diet_bp = Blueprint("diet", __name__)
def _parse_status_payload(payload: dict) -> dict:
return {
"weight": float(payload.get("weight", 0)),
"body_fat": float(payload.get("body_fat", 0)),
"exercise_kcal": float(payload.get("exercise_kcal", 0)),
"intake_kcal": float(payload.get("intake_kcal", 0)),
"sleep_hours": float(payload.get("sleep_hours", 7)),
"note": (payload.get("note") or "").strip(),
}
@diet_bp.post("/status")
@jwt_required()
def create_status():
    """Store a new diet-status record for the caller.

    ``weight`` and ``body_fat`` must be greater than 0. An optional
    ``recorded_at`` is parsed with ``datetime.fromisoformat``. Malformed
    numbers or timestamps are answered with HTTP 400.
    """
    user = current_user()
    if not user:
        return fail("用户不存在", 404)
    payload = request.get_json(silent=True) or {}
    try:
        data = _parse_status_payload(payload)
    except (TypeError, ValueError):
        return fail("数值字段必须是数字", 400)
    if data["weight"] <= 0 or data["body_fat"] <= 0:
        return fail("体重和体脂率必须大于0", 400)
    recorded_at = None
    if payload.get("recorded_at"):
        try:
            recorded_at = datetime.fromisoformat(payload["recorded_at"])
        except (TypeError, ValueError):
            return fail("recorded_at 必须是 ISO 8601 格式的时间", 400)
    status = DietStatus(user_id=user.id, **data)
    if recorded_at is not None:
        status.recorded_at = recorded_at
    db.session.add(status)
    db.session.commit()
    return ok(status.to_dict(), "饮食状态记录成功")
@diet_bp.put("/status/latest")
@jwt_required()
def update_latest_status():
    """Update the caller's most recent diet-status record.

    Only fields present in the JSON body are changed. A supplied ``weight``
    or ``body_fat`` must be greater than 0, matching the rule applied on
    creation. Malformed numbers are answered with HTTP 400.
    """
    user = current_user()
    if not user:
        return fail("用户不存在", 404)
    latest = (
        DietStatus.query.filter_by(user_id=user.id)
        .order_by(DietStatus.recorded_at.desc(), DietStatus.id.desc())
        .first()
    )
    if not latest:
        return fail("暂无可编辑的饮食状态", 404)
    payload = request.get_json(silent=True) or {}
    try:
        data = _parse_status_payload(payload)
    except (TypeError, ValueError):
        return fail("数值字段必须是数字", 400)
    # The parser fills in defaults for absent keys; applying those would
    # overwrite stored values (e.g. weight -> 0) on a partial update.
    changes = {key: value for key, value in data.items() if key in payload}
    for key in ("weight", "body_fat"):
        if key in changes and changes[key] <= 0:
            return fail("体重和体脂率必须大于0", 400)
    for key, value in changes.items():
        setattr(latest, key, value)
    db.session.commit()
    return ok(latest.to_dict(), "最新饮食状态已更新")
@diet_bp.get("/status/latest")
@jwt_required()
def get_latest_status():
    """Return the caller's most recent diet status, or an empty object if none."""
    owner = current_user()
    if not owner:
        return fail("用户不存在", 404)
    newest_first = DietStatus.query.filter_by(user_id=owner.id).order_by(
        DietStatus.recorded_at.desc(), DietStatus.id.desc()
    )
    record = newest_first.first()
    return ok(record.to_dict()) if record else ok({})
@diet_bp.get("/history")
@jwt_required()
def history():
    """Return the caller's diet-status records, newest first.

    The ``limit`` query parameter defaults to 30 and is clamped to 1..200.
    """
    user = current_user()
    if not user:
        return fail("用户不存在", 404)
    # type=int makes the lookup fall back to the default on a non-numeric
    # value instead of raising ValueError (which surfaced as HTTP 500).
    limit = min(max(request.args.get("limit", default=30, type=int), 1), 200)
    rows = (
        DietStatus.query.filter_by(user_id=user.id)
        .order_by(DietStatus.recorded_at.desc(), DietStatus.id.desc())
        .limit(limit)
        .all()
    )
    return ok([item.to_dict() for item in rows])
@diet_bp.delete("/history/<int:status_id>")
@jwt_required()
def delete_history(status_id: int):
    """Delete one of the caller's own diet-status records; 404 if not theirs."""
    owner = current_user()
    if not owner:
        return fail("用户不存在", 404)
    # Filtering by user_id as well keeps other users' records out of reach.
    record = DietStatus.query.filter_by(id=status_id, user_id=owner.id).first()
    if record:
        db.session.delete(record)
        db.session.commit()
        return ok({}, "记录已删除")
    return fail("记录不存在", 404)

View File

@@ -0,0 +1,177 @@
from flask import Blueprint, current_app, request
from flask_jwt_extended import jwt_required
from app.rag.llm_client import LLMConfigManager, ask_dify, ask_fastgpt
from app.rag.local_retriever import LocalKnowledgeRetriever
from app.utils.auth import admin_required, current_user
from app.utils.response import ok
# Blueprint holding the question-answering endpoints defined below
# (knowledge-base search/reload, ask, advice).
qa_bp = Blueprint("qa", __name__)
def _retriever() -> LocalKnowledgeRetriever:
    """Build a retriever for the ``LOCAL_KB_PATH`` entry of the app config."""
    kb_path = current_app.config["LOCAL_KB_PATH"]
    return LocalKnowledgeRetriever(kb_path)
def _llm_config() -> dict:
    """Load the LLM configuration from the ``LLM_CONFIG_PATH`` app setting."""
    config_path = current_app.config["LLM_CONFIG_PATH"]
    return LLMConfigManager(config_path).load()
def _default_provider(config: dict) -> str:
provider = (config.get("active_provider") or "fastgpt").strip().lower() if isinstance(config, dict) else "fastgpt"
return provider if provider in {"fastgpt", "dify"} else "fastgpt"
def _context_text(hits: list) -> str:
chunks = []
for idx, item in enumerate(hits, start=1):
chunks.append(f"[{idx}] 问: {item['question']}\n答: {item['answer']}")
return "\n\n".join(chunks)
def _local_answer(hits: list) -> str:
if not hits:
return "本地知识库暂未检索到高相关内容,建议补充问题细节。"
top = hits[0]
return f"根据本地知识库:{top['answer']}"
def _advice_template(advice_type: str) -> str:
templates = {
"gain_muscle": "增肌建议:每天蛋白质建议 1.6-2.2g/kg训练后 30 分钟内补充蛋白+碳水,优先鸡胸肉/牛肉/鸡蛋/燕麦。",
"lose_fat": "减脂建议:保持 10%-20% 轻热量缺口,优先高蛋白高纤维食物,减少高糖饮料和深夜加餐。",
"keto": "生酮建议:控制碳水通常小于 50g/天,优先健康脂肪和适量蛋白,并关注电解质补充。",
"nutritionist": "营养师建议:结合体重、体脂、活动量和目标制定周计划,每 7 天复盘体重和围度变化。",
"general": "通用建议:保证规律作息、三餐结构稳定、适量运动和充足饮水。",
}
return templates.get(advice_type, templates["general"])
def _ask_by_provider(provider: str, config: dict, question: str, context: str, user_id: str, username: str):
    """Send the question to Dify when *provider* is ``dify``, else to FastGPT.

    Dify receives *username* as ``user``; FastGPT receives *user_id* as
    ``custom_uid``.
    """
    if provider != "dify":
        return ask_fastgpt(config, question, context=context, custom_uid=user_id)
    return ask_dify(config, question, context=context, user=username)
@qa_bp.get("/kb/search")
@jwt_required(optional=True)
def search_kb():
    """Search the local knowledge base; a token is optional.

    Query parameters:
        query: the search text.
        top_k: number of hits, default 5, clamped to 1..10.
    """
    query = (request.args.get("query") or "").strip()
    # type=int makes the lookup fall back to the default on a non-numeric
    # value instead of raising ValueError (which surfaced as HTTP 500).
    top_k = min(max(request.args.get("top_k", default=5, type=int), 1), 10)
    retriever = _retriever()
    hits = retriever.search(query, top_k=top_k)
    return ok({"items": hits})
@qa_bp.post("/ask")
@jwt_required()
def ask_nutrition_qa():
    """Answer a question from the local knowledge base and/or an LLM provider.

    ``mode`` selects the strategy: ``local`` answers from the knowledge base
    only; any other value asks the LLM first and falls back to the local
    answer when the call fails. In ``llm`` mode the fallback answer is
    prefixed with a failure notice and reported as ``local_kb_fallback``.
    """
    body = request.get_json(silent=True) or {}
    question = (body.get("question") or "").strip()
    if not question:
        return ok({"answer": "请输入问题", "items": []}, "empty_question")

    mode = (body.get("mode") or "auto").strip().lower()
    llm_cfg = _llm_config()
    provider = (body.get("provider") or _default_provider(llm_cfg)).strip().lower()
    if provider not in {"fastgpt", "dify"}:
        provider = _default_provider(llm_cfg)

    hits = _retriever().search(question, top_k=3)
    context = _context_text(hits)
    local_answer = _local_answer(hits)
    if mode == "local":
        return ok({"answer": local_answer, "provider": "local_kb", "items": hits})

    asker = current_user()
    asker_id = str(asker.id) if asker else ""
    asker_name = asker.username if asker else "anonymous"
    llm_result = _ask_by_provider(provider, llm_cfg, question, context, asker_id, asker_name)

    # A successful LLM answer is returned the same way in every non-local mode.
    if llm_result.get("ok"):
        return ok({"answer": llm_result.get("answer"), "provider": provider, "items": hits})

    if mode == "llm":
        answer = f"LLM 调用失败,已回退本地知识库。{local_answer}"
        source = "local_kb_fallback"
    else:
        answer = local_answer
        source = "local_kb"
    return ok(
        {
            "answer": answer,
            "provider": source,
            "items": hits,
            "llm_error": llm_result,
        }
    )
@qa_bp.post("/advice")
@jwt_required()
def advice():
    """Return diet advice for an ``advice_type``, preferring the LLM answer.

    The canned template for the type plus the knowledge-base hits form the
    LLM context. When the LLM call fails, the template text with a notice is
    returned and the provider is reported as ``local_template``.
    """
    body = request.get_json(silent=True) or {}
    advice_type = (body.get("advice_type") or "general").strip().lower()
    question = (body.get("question") or "请根据我的状态给建议").strip()

    llm_cfg = _llm_config()
    provider = (body.get("provider") or _default_provider(llm_cfg)).strip().lower()
    if provider not in {"fastgpt", "dify"}:
        provider = _default_provider(llm_cfg)

    hits = _retriever().search(f"{advice_type} {question}", top_k=3)
    context = _context_text(hits)
    base_text = _advice_template(advice_type)

    asker = current_user()
    prompt = f"请给出{advice_type}方向的个性化饮食建议。用户补充: {question}"
    llm_context = f"{base_text}\n\n{context}"
    llm_result = _ask_by_provider(
        provider,
        llm_cfg,
        prompt,
        llm_context,
        str(asker.id) if asker else "",
        asker.username if asker else "anonymous",
    )

    succeeded = llm_result.get("ok")
    answer = llm_result.get("answer") if succeeded else f"{base_text}\n\n提示LLM 暂不可用,当前为本地模板建议)"
    source = provider if succeeded else "local_template"
    return ok(
        {
            "advice_type": advice_type,
            "answer": answer,
            "provider": source,
            "knowledge_hits": hits,
            "llm_status": llm_result,
        }
    )
@qa_bp.post("/kb/reload")
@admin_required
def reload_kb():
    """Reload the local knowledge base and report its document count."""
    kb = _retriever()
    kb.reload()
    document_count = len(kb.documents)
    return ok({"count": document_count}, "知识库已重载")

View File

@@ -0,0 +1,135 @@
import json
from pathlib import Path
from flask import Blueprint, current_app, request
from flask_jwt_extended import jwt_required
from app.extensions import db
from app.models import Recipe
from app.utils.auth import admin_required
from app.utils.response import fail, ok
# Blueprint holding the recipe endpoints defined below
# (list, read, create, update, delete, import).
recipe_bp = Blueprint("recipe", __name__)
def _recipe_payload(recipe: Recipe, payload: dict) -> None:
recipe.name = payload.get("name", recipe.name)
recipe.category = payload.get("category", recipe.category or "轻食")
recipe.description = payload.get("description", recipe.description or "")
recipe.calories = float(payload.get("calories", recipe.calories or 0))
recipe.protein = float(payload.get("protein", recipe.protein or 0))
recipe.fat = float(payload.get("fat", recipe.fat or 0))
recipe.carbs = float(payload.get("carbs", recipe.carbs or 0))
recipe.fiber = float(payload.get("fiber", recipe.fiber or 0))
recipe.tags = payload.get("tags", recipe.tags or [])
recipe.difficulty = payload.get("difficulty", recipe.difficulty or "easy")
@recipe_bp.get("")
@jwt_required(optional=True)
def list_recipes():
    """List recipes, newest first; a token is optional.

    Query parameters:
        keyword: substring matched against the recipe name.
        category: exact category match.
        page, page_size: pagination; ``page_size`` defaults to 10 and is
            clamped to 1..50.
    """
    keyword = (request.args.get("keyword") or "").strip()
    category = (request.args.get("category") or "").strip()
    # type=int makes the lookup fall back to the default on a non-numeric
    # value instead of raising ValueError (which surfaced as HTTP 500).
    page = max(request.args.get("page", default=1, type=int), 1)
    page_size = min(max(request.args.get("page_size", default=10, type=int), 1), 50)
    query = Recipe.query
    if keyword:
        query = query.filter(Recipe.name.like(f"%{keyword}%"))
    if category:
        query = query.filter(Recipe.category == category)
    pagination = query.order_by(Recipe.id.desc()).paginate(page=page, per_page=page_size, error_out=False)
    data = {
        "items": [item.to_dict() for item in pagination.items],
        "total": pagination.total,
        "page": page,
        "page_size": page_size,
    }
    return ok(data)
@recipe_bp.get("/<int:recipe_id>")
@jwt_required(optional=True)
def get_recipe(recipe_id: int):
    """Return one recipe by id, or 404; a token is optional."""
    found = Recipe.query.get(recipe_id)
    if found:
        return ok(found.to_dict())
    return fail("食谱不存在", 404)
@recipe_bp.post("")
@admin_required
def create_recipe():
    """Create a recipe from the JSON body; ``name`` is required.

    Responds with 400 when the name is missing or a nutrition value is not
    numeric.
    """
    payload = request.get_json(silent=True) or {}
    if not payload.get("name"):
        return fail("缺少食谱名称", 400)
    recipe = Recipe()
    try:
        _recipe_payload(recipe, payload)
    except (TypeError, ValueError):
        # The recipe was never added to the session, so nothing to undo.
        return fail("营养数值必须是数字", 400)
    db.session.add(recipe)
    db.session.commit()
    return ok(recipe.to_dict(), "食谱创建成功")
@recipe_bp.put("/<int:recipe_id>")
@admin_required
def update_recipe(recipe_id: int):
    """Update a recipe from the JSON body.

    Responds with 404 for an unknown recipe and 400 when a nutrition value is
    not numeric.
    """
    recipe = Recipe.query.get(recipe_id)
    if not recipe:
        return fail("食谱不存在", 404)
    payload = request.get_json(silent=True) or {}
    try:
        _recipe_payload(recipe, payload)
    except (TypeError, ValueError):
        # Fields assigned before the failing conversion must not be persisted.
        db.session.rollback()
        return fail("营养数值必须是数字", 400)
    db.session.commit()
    return ok(recipe.to_dict(), "食谱更新成功")
@recipe_bp.delete("/<int:recipe_id>")
@admin_required
def delete_recipe(recipe_id: int):
    """Delete one recipe by id; 404 if it does not exist."""
    target = Recipe.query.get(recipe_id)
    if target:
        db.session.delete(target)
        db.session.commit()
        return ok({}, "食谱删除成功")
    return fail("食谱不存在", 404)
@recipe_bp.post("/import")
@admin_required
def import_recipes():
    """Create or update recipes in bulk, matched by name.

    Rows come from the ``items`` array of the JSON body. When ``items`` is
    not a list, the rows are read from ``seed/recipes_seed.json`` next to the
    application package. Entries that are not objects or have no name are
    skipped. A row with a non-numeric nutrition value aborts the import with
    HTTP 400 and nothing is committed.
    """
    payload = request.get_json(silent=True) or {}
    rows = payload.get("items")
    if not isinstance(rows, list):
        seed_file = Path(current_app.root_path).parents[0] / "seed" / "recipes_seed.json"
        if not seed_file.exists():
            return fail("导入数据为空,且未找到默认种子文件", 400)
        with seed_file.open("r", encoding="utf-8-sig") as file:
            rows = json.load(file)
        if not isinstance(rows, list):
            return fail("默认种子文件格式无效", 400)
    created_count = 0
    updated_count = 0
    for item in rows:
        # A malformed entry must not abort the batch with an AttributeError.
        if not isinstance(item, dict):
            continue
        name = str(item.get("name") or "").strip()
        if not name:
            continue
        recipe = Recipe.query.filter_by(name=name).first()
        is_new = recipe is None
        if is_new:
            recipe = Recipe()
        try:
            _recipe_payload(recipe, item)
        except (TypeError, ValueError):
            # Discard everything changed so far rather than committing a
            # half-applied import.
            db.session.rollback()
            return fail(f"食谱 {name} 的营养数值无效", 400)
        if is_new:
            db.session.add(recipe)
            created_count += 1
        else:
            updated_count += 1
    db.session.commit()
    return ok({"created": created_count, "updated": updated_count}, "食谱导入完成")

View File

@@ -0,0 +1,314 @@
from statistics import mean
from flask import Blueprint, current_app, request
from flask_jwt_extended import jwt_required
from app.extensions import db
from app.ml.rf_recommender import RandomForestDietRecommender, merge_profile_with_history
from app.models import DietStatus, RecommendationLog, Recipe
from app.utils.auth import current_user
from app.utils.response import fail, ok
# Blueprint holding the recommendation endpoints defined below.
recommend_bp = Blueprint("recommend", __name__)
def _build_profile(user, latest_status):
profile = {
"age": user.age,
"height_cm": user.height_cm,
"goal": user.goal,
"occupation": user.occupation,
"weight": 65,
"body_fat": 20,
"exercise_kcal": 300,
"intake_kcal": 1800,
}
if latest_status:
profile.update(
{
"weight": latest_status.weight,
"body_fat": latest_status.body_fat,
"exercise_kcal": latest_status.exercise_kcal,
"intake_kcal": latest_status.intake_kcal,
}
)
return profile
def _goal_adjust(goal: str, recipe: Recipe) -> float:
if goal == "lose_fat":
return recipe.protein * 0.8 - recipe.fat * 0.4 - max(recipe.calories - 500, 0) * 0.04
if goal == "gain_muscle":
return recipe.protein * 1.0 + recipe.carbs * 0.25
if goal == "keto":
return recipe.fat * 0.5 - recipe.carbs * 0.8
return recipe.protein * 0.3 + recipe.fiber * 1.2
def _occupation_adjust(occupation: str, recipe: Recipe) -> float:
occupation = (occupation or "").lower()
if occupation in {"developer", "office"}:
return recipe.fiber * 0.9 + recipe.protein * 0.2
if occupation in {"fitness", "manual"}:
return recipe.protein * 0.6 + recipe.carbs * 0.2
if occupation in {"teacher", "student"}:
return recipe.carbs * 0.15 + recipe.protein * 0.3
return recipe.protein * 0.25
def _health_adjust(profile: dict, recipe: Recipe) -> float:
    """Score ``recipe`` against the body-fat and energy-balance values in ``profile``.

    Starts from a base of 70 and applies, in order: a body-fat adjustment, an
    intake-versus-exercise adjustment, and a fiber bonus capped at 10. The
    result is rounded to two decimals.
    """
    body_fat_pct = float(profile.get("body_fat", 20))
    kcal_in = float(profile.get("intake_kcal", 1800))
    kcal_out = float(profile.get("exercise_kcal", 300))

    total = 70.0
    if body_fat_pct > 28:
        total += recipe.protein * 0.5
        total -= max(recipe.calories - 480, 0) * 0.05
    elif body_fat_pct < 14:
        total += recipe.carbs * 0.15 + recipe.fat * 0.1

    eating_over_budget = kcal_in > kcal_out + 1700
    if eating_over_budget:
        total -= max(recipe.calories - 450, 0) * 0.04
    else:
        total += recipe.protein * 0.3

    total += min(recipe.fiber * 1.0, 10)
    return round(total, 2)
def _serialize_with_score(recipes):
    """Turn ``(recipe, score, reason)`` rows into JSON-ready dicts.

    Each output dict is the recipe's ``to_dict()`` data plus a ``score``
    (rounded to two decimals) and the ``reason`` string.
    """
    serialized = []
    for row in recipes:
        # Copy so the dict returned by to_dict() is never mutated.
        entry = dict(row[0].to_dict())
        entry["score"] = round(float(row[1]), 2)
        entry["reason"] = row[2]
        serialized.append(entry)
    return serialized
def _save_log(user_id: int, rec_type: str, payload: dict) -> None:
    """Persist one recommendation result as a ``RecommendationLog`` row and commit."""
    session = db.session
    session.add(RecommendationLog(user_id=user_id, rec_type=rec_type, payload=payload))
    session.commit()
@recommend_bp.get("/current")
@jwt_required()
def recommend_current_status():
    """Recommend recipes for the user's current status with the random-forest model.

    Query parameters:
        top_k: Number of recipes to return, clamped to 1..20. A missing or
            non-integer value falls back to 5.

    Returns:
        The ``ok`` envelope with the merged profile and recommended items, or a
        ``fail`` response (404 unknown user, 400 when no recipes exist).
    """
    user = current_user()
    if not user:
        return fail("用户不存在", 404)
    # type=int makes the lookup return the default instead of raising on bad input.
    top_k = min(max(request.args.get("top_k", default=5, type=int), 1), 20)
    history = (
        DietStatus.query.filter_by(user_id=user.id)
        .order_by(DietStatus.recorded_at.desc(), DietStatus.id.desc())
        .limit(14)
        .all()
    )
    # The history is ordered newest-first, so its first row is the latest
    # record; a second query for it is unnecessary.
    latest = history[0] if history else None
    recipes = Recipe.query.order_by(Recipe.id.asc()).all()
    if not recipes:
        return fail("当前没有可推荐食谱,请先导入食谱", 400)
    profile = _build_profile(user, latest)
    profile = merge_profile_with_history(profile, history)
    recommender = RandomForestDietRecommender(current_app.config["MODEL_PATH"])
    items = recommender.recommend(profile, recipes, top_k=top_k)
    response_data = {
        "type": "current_status_rf",
        "profile": profile,
        "items": items,
    }
    _save_log(user.id, "current_status_rf", response_data)
    return ok(response_data)
@recommend_bp.get("/health")
@jwt_required()
def recommend_by_health():
    """Recommend recipes ranked by the rule-based health score.

    Query parameters:
        top_k: Number of recipes to return, clamped to 1..20. A missing or
            non-integer value falls back to 5.

    Returns:
        The ``ok`` envelope with the profile and the top-scoring recipes, or a
        ``fail`` response (404 unknown user, 400 when no recipes exist).
    """
    user = current_user()
    if not user:
        return fail("用户不存在", 404)
    # type=int makes the lookup return the default instead of raising on bad input.
    top_k = min(max(request.args.get("top_k", default=5, type=int), 1), 20)
    latest = (
        DietStatus.query.filter_by(user_id=user.id)
        .order_by(DietStatus.recorded_at.desc(), DietStatus.id.desc())
        .first()
    )
    profile = _build_profile(user, latest)
    recipes = Recipe.query.order_by(Recipe.id.asc()).all()
    if not recipes:
        return fail("当前没有可推荐食谱,请先导入食谱", 400)
    reason = "结合当前体脂和摄入消耗状态,推荐更合理的营养结构。"
    scored = [(recipe, _health_adjust(profile, recipe), reason) for recipe in recipes]
    scored.sort(key=lambda row: row[1], reverse=True)
    items = _serialize_with_score(scored[:top_k])
    response_data = {
        "type": "health_state",
        "profile": profile,
        "items": items,
    }
    _save_log(user.id, "health_state", response_data)
    return ok(response_data)
@recommend_bp.get("/plan/history")
@jwt_required()
def recommend_plan_by_history():
    """Recommend a plan derived from the trend of the user's recent status records.

    Uses up to the 30 most recent records. The weight trend and the average
    intake/exercise pick a strategy and a calorie ceiling; recipes under that
    ceiling are returned ordered by protein.

    Query parameters:
        top_k: Number of recipes to return, clamped to 1..20. A missing or
            non-integer value falls back to 5.

    Returns:
        The ``ok`` envelope with strategy, metrics and items, or a ``fail``
        response (404 unknown user, 400 with fewer than 3 records).
    """
    user = current_user()
    if not user:
        return fail("用户不存在", 404)
    # type=int makes the lookup return the default instead of raising on bad input.
    top_k = min(max(request.args.get("top_k", default=5, type=int), 1), 20)
    history = (
        DietStatus.query.filter_by(user_id=user.id)
        .order_by(DietStatus.recorded_at.desc(), DietStatus.id.desc())
        .limit(30)
        .all()
    )
    if len(history) < 3:
        return fail("历史数据太少至少记录3天后可用该功能", 400)
    avg_weight = mean([h.weight for h in history])
    # History is newest-first: index 0 is the latest record, -1 the oldest.
    early_weight = history[-1].weight
    latest_weight = history[0].weight
    weight_trend = latest_weight - early_weight
    avg_intake = mean([h.intake_kcal for h in history])
    avg_exercise = mean([h.exercise_kcal for h in history])
    cal_limit = 520
    strategy = "维持计划"
    if weight_trend > 0.8 or avg_intake > avg_exercise + 1700:
        cal_limit = 440
        strategy = "减脂优先计划"
    elif weight_trend < -0.8:
        cal_limit = 620
        strategy = "增肌恢复计划"
    recipes = (
        Recipe.query.filter(Recipe.calories <= cal_limit)
        .order_by(Recipe.protein.desc())
        .limit(top_k)
        .all()
    )
    items = [
        {
            **recipe.to_dict(),
            "score": round(recipe.protein * 2 - recipe.fat * 0.3, 2),
            "reason": f"基于历史趋势({strategy}),优先推荐该热量区间食谱。",
        }
        for recipe in recipes
    ]
    response_data = {
        "type": "history_plan",
        "strategy": strategy,
        "metrics": {
            "avg_weight": round(avg_weight, 2),
            "weight_trend": round(weight_trend, 2),
            "avg_intake": round(avg_intake, 1),
            "avg_exercise": round(avg_exercise, 1),
        },
        "items": items,
    }
    _save_log(user.id, "history_plan", response_data)
    return ok(response_data)
@recommend_bp.get("/plan/goal")
@jwt_required()
def recommend_plan_by_goal():
    """Recommend recipes ranked for a nutrition goal.

    Query parameters:
        goal: Goal to score against; defaults to the user's stored goal.
        top_k: Number of recipes to return, clamped to 1..20. A missing or
            non-integer value falls back to 5.

    Returns:
        The ``ok`` envelope with the goal and the top-scoring recipes, or a
        ``fail`` response (404 unknown user, 400 when no recipes exist).
    """
    user = current_user()
    if not user:
        return fail("用户不存在", 404)
    # type=int makes the lookup return the default instead of raising on bad input.
    top_k = min(max(request.args.get("top_k", default=5, type=int), 1), 20)
    goal = request.args.get("goal") or user.goal
    recipes = Recipe.query.order_by(Recipe.id.asc()).all()
    if not recipes:
        return fail("当前没有可推荐食谱,请先导入食谱", 400)
    reason = f"根据目标({goal})筛选营养比例更匹配的食谱。"
    scored = [(recipe, 60 + _goal_adjust(goal, recipe), reason) for recipe in recipes]
    scored.sort(key=lambda row: row[1], reverse=True)
    items = _serialize_with_score(scored[:top_k])
    response_data = {
        "type": "future_goal_plan",
        "goal": goal,
        "items": items,
    }
    _save_log(user.id, "future_goal_plan", response_data)
    return ok(response_data)
@recommend_bp.get("/plan/occupation")
@jwt_required()
def recommend_plan_by_occupation():
    """Recommend recipes ranked for an occupation.

    Query parameters:
        occupation: Occupation to score against; defaults to the user's stored one.
        top_k: Number of recipes to return, clamped to 1..20. A missing or
            non-integer value falls back to 5.

    Returns:
        The ``ok`` envelope with the occupation and the top-scoring recipes, or
        a ``fail`` response (404 unknown user, 400 when no recipes exist).
    """
    user = current_user()
    if not user:
        return fail("用户不存在", 404)
    # type=int makes the lookup return the default instead of raising on bad input.
    top_k = min(max(request.args.get("top_k", default=5, type=int), 1), 20)
    occupation = request.args.get("occupation") or user.occupation
    recipes = Recipe.query.order_by(Recipe.id.asc()).all()
    if not recipes:
        return fail("当前没有可推荐食谱,请先导入食谱", 400)
    reason = f"结合职业({occupation})的能量消耗特点进行推荐。"
    scored = [(recipe, 65 + _occupation_adjust(occupation, recipe), reason) for recipe in recipes]
    scored.sort(key=lambda row: row[1], reverse=True)
    items = _serialize_with_score(scored[:top_k])
    response_data = {
        "type": "occupation_plan",
        "occupation": occupation,
        "items": items,
    }
    _save_log(user.id, "occupation_plan", response_data)
    return ok(response_data)
@recommend_bp.get("/logs")
@jwt_required()
def my_recommend_logs():
    """Return the current user's most recent recommendation logs, newest first.

    Query parameters:
        limit: Maximum number of logs, clamped to 1..100. A missing or
            non-integer value falls back to 20.
    """
    user = current_user()
    if not user:
        return fail("用户不存在", 404)
    # type=int makes the lookup return the default instead of raising on bad input.
    limit = min(max(request.args.get("limit", default=20, type=int), 1), 100)
    rows = (
        RecommendationLog.query.filter_by(user_id=user.id)
        .order_by(RecommendationLog.id.desc())
        .limit(limit)
        .all()
    )
    return ok([item.to_dict() for item in rows])

View File

@@ -0,0 +1,334 @@
from flask import Blueprint, current_app, request
from flask_jwt_extended import jwt_required
from app.extensions import db
from app.ml.naive_bayes_classifier import NaiveBayesSpamClassifier
from app.models import DetectionConfig, SpamPredictionLog, SpamTrainingSample
from app.utils.auth import admin_required, current_user
from app.utils.response import fail, ok
spam_bp = Blueprint("spam", __name__)
def _classifier() -> NaiveBayesSpamClassifier:
    """Build a classifier bound to the model path configured as ``NB_MODEL_PATH``."""
    model_path = current_app.config["NB_MODEL_PATH"]
    return NaiveBayesSpamClassifier(model_path)
def _active_samples() -> list[dict]:
    """Return every active training sample as ``{"text", "label"}`` dicts, ordered by id."""
    active = SpamTrainingSample.query.filter_by(is_active=True)
    samples = []
    for sample in active.order_by(SpamTrainingSample.id.asc()).all():
        samples.append({"text": sample.text, "label": sample.label})
    return samples
def _ensure_ready() -> NaiveBayesSpamClassifier:
    """Return a classifier after calling its ``ensure_ready`` with the active samples."""
    classifier = _classifier()
    classifier.ensure_ready(_active_samples())
    return classifier
def _threshold() -> float:
    """Return the stored spam threshold, or 0.75 when no config row exists."""
    config_row = DetectionConfig.query.order_by(DetectionConfig.id.asc()).first()
    if not config_row:
        return 0.75
    return float(config_row.spam_threshold)
@spam_bp.post("/predict")
@jwt_required()
def predict_one():
    """Classify one text, log the prediction and return it with the threshold verdict.

    Expects a JSON object with a string ``text`` of at least 2 characters
    (after stripping). A body that is not a JSON object, or a ``text`` that is
    not a string, is treated as missing text and rejected with 400.
    """
    user = current_user()
    if not user:
        return fail("用户不存在", 404)
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict):
        payload = {}
    raw_text = payload.get("text")
    # Non-string values have no .strip(); treat them as empty so they get a 400.
    text = raw_text.strip() if isinstance(raw_text, str) else ""
    if len(text) < 2:
        return fail("请输入至少2个字符的待识别文本", 400)
    clf = _ensure_ready()
    result = clf.predict(text)
    threshold = _threshold()
    blocked = float(result["spam_probability"]) >= threshold
    row = SpamPredictionLog(
        user_id=user.id,
        text=result["text"],
        prediction=result["prediction"],
        spam_probability=result["spam_probability"],
        ham_probability=result["ham_probability"],
        confidence=result["confidence"],
        reason_tokens=result["reason_tokens"],
        model_version=result.get("model_version", ""),
    )
    db.session.add(row)
    db.session.commit()
    return ok({**result, "log_id": row.id, "threshold": threshold, "blocked_by_threshold": blocked}, "识别成功")
@spam_bp.post("/predict/batch")
@jwt_required()
def predict_batch():
    """Classify up to 100 texts, log each prediction and return items plus a summary.

    Expects a JSON object with a non-empty ``items`` list. Entries that are
    not strings, or are shorter than 2 characters after stripping, are
    skipped; if nothing remains the request is rejected with 400.
    """
    user = current_user()
    if not user:
        return fail("用户不存在", 404)
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict):
        payload = {}
    items = payload.get("items") or []
    if not isinstance(items, list) or not items:
        return fail("items 必须是非空数组", 400)
    if len(items) > 100:
        return fail("单次最多识别100条", 400)
    clf = _ensure_ready()
    rows = []
    results = []
    threshold = _threshold()
    for text in items:
        # Non-string entries have no .strip(); skip them like too-short text.
        content = text.strip() if isinstance(text, str) else ""
        if len(content) < 2:
            continue
        result = clf.predict(content)
        result["blocked_by_threshold"] = float(result["spam_probability"]) >= threshold
        rows.append(
            SpamPredictionLog(
                user_id=user.id,
                text=result["text"],
                prediction=result["prediction"],
                spam_probability=result["spam_probability"],
                ham_probability=result["ham_probability"],
                confidence=result["confidence"],
                reason_tokens=result["reason_tokens"],
                model_version=result.get("model_version", ""),
            )
        )
        results.append(result)
    if not rows:
        return fail("没有可识别的有效文本", 400)
    db.session.add_all(rows)
    db.session.commit()
    # results is non-empty here (one entry per row), so the ratios cannot divide by zero.
    total = len(results)
    spam_count = sum(1 for item in results if item["prediction"] == "spam")
    blocked_count = sum(1 for item in results if item["blocked_by_threshold"])
    ham_count = total - spam_count
    return ok(
        {
            "items": results,
            "summary": {
                "total": total,
                "spam_count": spam_count,
                "ham_count": ham_count,
                "blocked_count": blocked_count,
                "spam_ratio": round(spam_count / total, 4),
                "blocked_ratio": round(blocked_count / total, 4),
                "threshold": threshold,
            },
        },
        "批量识别完成",
    )
@spam_bp.get("/history")
@jwt_required()
def my_history():
    """Return one page of the current user's prediction logs, newest first.

    Query parameters:
        page: 1-based page number (minimum 1, default 1).
        page_size: Rows per page, clamped to 1..100 (default 20).
        Non-integer values fall back to the defaults.
    """
    user = current_user()
    if not user:
        return fail("用户不存在", 404)
    # type=int makes the lookup return the default instead of raising on bad input.
    page = max(request.args.get("page", default=1, type=int), 1)
    page_size = min(max(request.args.get("page_size", default=20, type=int), 1), 100)
    pagination = (
        SpamPredictionLog.query.filter_by(user_id=user.id)
        .order_by(SpamPredictionLog.id.desc())
        .paginate(page=page, per_page=page_size, error_out=False)
    )
    return ok(
        {
            "items": [item.to_dict() for item in pagination.items],
            "total": pagination.total,
            "page": page,
            "page_size": page_size,
        }
    )
@spam_bp.delete("/history/<int:log_id>")
@jwt_required()
def delete_history(log_id: int):
    """Delete one prediction log, restricted to logs owned by the current user."""
    user = current_user()
    if not user:
        return fail("用户不存在", 404)
    # Filtering on user_id as well means other users' logs read as "not found".
    owned_log = SpamPredictionLog.query.filter_by(id=log_id, user_id=user.id).first()
    if owned_log:
        db.session.delete(owned_log)
        db.session.commit()
        return ok({}, "记录已删除")
    return fail("记录不存在", 404)
@spam_bp.post("/feedback")
@jwt_required()
def save_feedback():
    """Store a user-labelled text as an active training sample with source ``feedback``.

    Expects a JSON object with a string ``text`` (at least 2 characters after
    stripping) and a ``label`` accepted by ``normalize_label``. A body that is
    not a JSON object, or a non-string ``text``, is rejected with 400.
    """
    user = current_user()
    if not user:
        return fail("用户不存在", 404)
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict):
        payload = {}
    raw_text = payload.get("text")
    # Non-string values have no .strip(); treat them as empty so they get a 400.
    text = raw_text.strip() if isinstance(raw_text, str) else ""
    label = NaiveBayesSpamClassifier.normalize_label(payload.get("label"))
    if len(text) < 2:
        return fail("文本至少2个字符", 400)
    if not label:
        return fail("label 必须是 spam 或 ham", 400)
    row = SpamTrainingSample(text=text, label=label, source="feedback", created_by=user.id, is_active=True)
    db.session.add(row)
    db.session.commit()
    return ok(row.to_dict(), "反馈样本已记录")
@spam_bp.get("/model/info")
@jwt_required(optional=True)
def model_info():
    """Return the loaded model's info dict extended with the current threshold."""
    classifier = _classifier()
    classifier.load()
    details = classifier.model_info()
    details["threshold"] = _threshold()
    return ok(details)
@spam_bp.post("/train")
@admin_required
def train_model():
    """Train the classifier on all active samples and return the training metadata."""
    training_meta = _classifier().train(_active_samples())
    return ok(training_meta, "模型训练完成")
@spam_bp.get("/samples")
@admin_required
def list_samples():
    """Return one page of training samples, newest first, with optional filters.

    Query parameters:
        keyword: Substring matched against the sample text with SQL LIKE.
        label: ``spam`` or ``ham``; any other value disables the label filter.
        page: 1-based page number (minimum 1, default 1).
        page_size: Rows per page, clamped to 1..100 (default 20).
        Non-integer page values fall back to the defaults.
    """
    keyword = (request.args.get("keyword") or "").strip()
    label = (request.args.get("label") or "").strip().lower()
    # type=int makes the lookup return the default instead of raising on bad input.
    page = max(request.args.get("page", default=1, type=int), 1)
    page_size = min(max(request.args.get("page_size", default=20, type=int), 1), 100)
    query = SpamTrainingSample.query
    if keyword:
        query = query.filter(SpamTrainingSample.text.like(f"%{keyword}%"))
    if label in {"spam", "ham"}:
        query = query.filter(SpamTrainingSample.label == label)
    pagination = query.order_by(SpamTrainingSample.id.desc()).paginate(
        page=page, per_page=page_size, error_out=False
    )
    return ok(
        {
            "items": [item.to_dict() for item in pagination.items],
            "total": pagination.total,
            "page": page,
            "page_size": page_size,
        }
    )
@spam_bp.post("/samples")
@admin_required
def create_sample():
    """Create one active training sample with source ``import``.

    Expects a JSON object with a string ``text`` (at least 2 characters after
    stripping) and a ``label`` accepted by ``normalize_label``. A body that is
    not a JSON object, or a non-string ``text``, is rejected with 400.
    """
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict):
        payload = {}
    raw_text = payload.get("text")
    # Non-string values have no .strip(); treat them as empty so they get a 400.
    text = raw_text.strip() if isinstance(raw_text, str) else ""
    label = NaiveBayesSpamClassifier.normalize_label(payload.get("label"))
    if len(text) < 2:
        return fail("文本至少2个字符", 400)
    if not label:
        return fail("label 必须是 spam 或 ham", 400)
    user = current_user()
    row = SpamTrainingSample(
        text=text,
        label=label,
        source="import",
        created_by=user.id if user else None,
        is_active=True,
    )
    db.session.add(row)
    db.session.commit()
    return ok(row.to_dict(), "样本创建成功")
@spam_bp.put("/samples/<int:sample_id>")
@admin_required
def update_sample(sample_id: int):
    """Update the text, label and/or active flag of one training sample.

    Only keys present in the JSON body are changed. Every supplied field is
    validated before the row is touched, so a rejected request leaves the
    loaded row unmodified.
    """
    # Session.get replaces the legacy Query.get API in SQLAlchemy 2.x.
    row = db.session.get(SpamTrainingSample, sample_id)
    if not row:
        return fail("样本不存在", 404)
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict):
        payload = {}
    changes = {}
    if "text" in payload:
        raw_text = payload.get("text")
        # Non-string values have no .strip(); treat them as empty so they get a 400.
        text = raw_text.strip() if isinstance(raw_text, str) else ""
        if len(text) < 2:
            return fail("文本至少2个字符", 400)
        changes["text"] = text
    if "label" in payload:
        label = NaiveBayesSpamClassifier.normalize_label(payload.get("label"))
        if not label:
            return fail("label 必须是 spam 或 ham", 400)
        changes["label"] = label
    if "is_active" in payload:
        changes["is_active"] = bool(payload.get("is_active"))
    for field, value in changes.items():
        setattr(row, field, value)
    db.session.commit()
    return ok(row.to_dict(), "样本更新成功")
@spam_bp.delete("/samples/<int:sample_id>")
@admin_required
def delete_sample(sample_id: int):
    """Delete one training sample by primary key; 404 when it does not exist."""
    # Session.get replaces the legacy Query.get API in SQLAlchemy 2.x.
    row = db.session.get(SpamTrainingSample, sample_id)
    if not row:
        return fail("样本不存在", 404)
    db.session.delete(row)
    db.session.commit()
    return ok({}, "样本已删除")
@spam_bp.post("/samples/import")
@admin_required
def import_samples():
    """Bulk upsert training samples keyed by their text.

    Expects a JSON object with a non-empty ``items`` list of objects carrying
    ``text``, ``label`` and optionally ``is_active`` / ``source``. Entries that
    are not objects, have a non-string or too-short text, or an invalid label
    are skipped instead of aborting the whole import.

    Returns:
        The ``ok`` envelope with the number of created and updated samples.
    """
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict):
        payload = {}
    items = payload.get("items") or []
    if not isinstance(items, list) or not items:
        return fail("items 必须是非空数组", 400)
    user = current_user()
    created = 0
    updated = 0
    for item in items:
        # One malformed entry must not turn the whole import into a 500.
        if not isinstance(item, dict):
            continue
        raw_text = item.get("text")
        text = raw_text.strip() if isinstance(raw_text, str) else ""
        label = NaiveBayesSpamClassifier.normalize_label(item.get("label"))
        if len(text) < 2 or not label:
            continue
        row = SpamTrainingSample.query.filter_by(text=text).first()
        if row:
            row.label = label
            row.is_active = bool(item.get("is_active", True))
            row.source = item.get("source") or row.source
            updated += 1
        else:
            row = SpamTrainingSample(
                text=text,
                label=label,
                source=item.get("source") or "import",
                created_by=user.id if user else None,
                is_active=bool(item.get("is_active", True)),
            )
            db.session.add(row)
            created += 1
    db.session.commit()
    return ok({"created": created, "updated": updated}, "样本导入完成")

View File

@@ -0,0 +1,62 @@
from flask import Blueprint, request
from flask_jwt_extended import jwt_required
from app.extensions import db
from app.models import User
from app.utils.auth import current_user
from app.utils.response import fail, ok
user_bp = Blueprint("user", __name__)
@user_bp.get("/profile")
@jwt_required()
def get_profile():
    """Return the authenticated user's profile, or 404 when the user no longer exists."""
    user = current_user()
    if user:
        return ok(user.to_dict())
    return fail("用户不存在", 404)
@user_bp.put("/profile")
@jwt_required()
def update_profile():
    """Update the current user's nickname, company, title, phone and/or password.

    Only keys present in the JSON body are changed. ``company``, ``title`` and
    ``phone`` may be cleared with an empty value; an empty nickname is ignored.
    Every field is validated before the user object is touched, so a rejected
    request changes nothing.

    Returns:
        The ``ok`` envelope with the updated profile, or a ``fail`` response
        (404 unknown user, 400 for a non-string field or a password shorter
        than 6 characters).
    """
    user = current_user()
    if not user:
        return fail("用户不存在", 404)
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict):
        payload = {}

    text_updates = {}
    for field in ("nickname", "company", "title", "phone"):
        if field not in payload:
            continue
        # Falsy values (null, "", 0) mean "empty", as before.
        value = payload[field] or ""
        # Truthy non-strings have no .strip() and used to crash with a 500.
        if not isinstance(value, str):
            return fail(f"{field} 必须是字符串", 400)
        text_updates[field] = value.strip()

    new_password = payload.get("password")
    if new_password and (not isinstance(new_password, str) or len(new_password) < 6):
        return fail("新密码至少6位", 400)

    for field, value in text_updates.items():
        if field == "nickname" and not value:
            # Never blank out the nickname.
            continue
        setattr(user, field, value)
    if new_password:
        user.set_password(new_password)
    db.session.commit()
    return ok(user.to_dict(), "个人信息更新成功")
@user_bp.get("/search")
@jwt_required()
def search_users():
    """Return up to 20 users whose username or nickname contains ``keyword``, newest first.

    An empty keyword returns an empty result without querying.
    """
    term = (request.args.get("keyword") or "").strip()
    if not term:
        return ok([])
    pattern = f"%{term}%"
    name_match = User.username.like(pattern) | User.nickname.like(pattern)
    matches = User.query.filter(name_match).order_by(User.id.desc()).limit(20).all()
    return ok([match.to_dict() for match in matches])

View File

@@ -0,0 +1 @@

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

25
backend/app/utils/auth.py Normal file
View File

@@ -0,0 +1,25 @@
from functools import wraps
from flask import jsonify
from flask_jwt_extended import get_jwt, get_jwt_identity, verify_jwt_in_request
from app.models import User
def current_user() -> User | None:
    """Load the ``User`` identified by the JWT identity, or ``None`` without an identity."""
    identity = get_jwt_identity()
    return User.query.get(int(identity)) if identity else None
def admin_required(fn):
    """Decorator that only lets requests with an ``is_admin`` JWT claim reach ``fn``.

    The JWT is verified first; when the ``is_admin`` claim is missing or falsy
    the wrapped view is not called and a 403 JSON response is returned.
    """

    @wraps(fn)
    def wrapper(*args, **kwargs):
        verify_jwt_in_request()
        claims = get_jwt()
        if not claims.get("is_admin", False):
            # Carry a "data" key so the body has the same code/message/data
            # shape as the responses built by app.utils.response.
            return jsonify({"code": 403, "message": "需要管理员权限", "data": {}}), 403
        return fn(*args, **kwargs)

    return wrapper

View File

@@ -0,0 +1,9 @@
from flask import jsonify
def ok(data=None, message="success"):
    """Build the success JSON envelope ``{"code": 0, "message", "data"}``.

    Args:
        data: Payload to return. Only ``None`` is replaced by ``{}``; other
            falsy payloads such as ``[]``, ``0`` or ``False`` are sent as-is,
            so list endpoints return an array even when it is empty.
        message: Human-readable status message.
    """
    return jsonify({"code": 0, "message": message, "data": {} if data is None else data})
def fail(message="error", code=400, data=None):
    """Build the error JSON envelope and return it with ``code`` as the HTTP status.

    Args:
        message: Human-readable error message.
        code: Value used both as the ``code`` field and as the HTTP status.
        data: Optional payload. Only ``None`` is replaced by ``{}``; other
            falsy payloads such as ``[]`` are sent as-is.

    Returns:
        tuple: ``(response, code)``.
    """
    return jsonify({"code": code, "message": message, "data": {} if data is None else data}), code

152
backend/init_db.py Normal file
View File

@@ -0,0 +1,152 @@
import json
from pathlib import Path
import pymysql
from app import create_app
from app.extensions import db
from app.ml.naive_bayes_classifier import NaiveBayesSpamClassifier
from app.models import DetectionConfig, SpamTrainingSample, User
BASE_DIR = Path(__file__).resolve().parent
MYSQL_CONFIG_PATH = BASE_DIR / "mysqlconfig.json"
SPAM_SEED_PATH = BASE_DIR / "seed" / "spam_samples_seed.json"
def load_mysql_cfg() -> dict:
    """Read and parse ``mysqlconfig.json``.

    Raises:
        FileNotFoundError: When the config file does not exist.
    """
    if MYSQL_CONFIG_PATH.exists():
        # utf-8-sig tolerates a leading byte-order mark.
        return json.loads(MYSQL_CONFIG_PATH.read_text(encoding="utf-8-sig"))
    raise FileNotFoundError(f"未找到配置文件: {MYSQL_CONFIG_PATH}")
def create_database(mysql_cfg: dict) -> None:
    """Create the configured database if it does not exist yet.

    Connects without selecting a database and issues
    ``CREATE DATABASE IF NOT EXISTS`` using the configured name and charset.
    The connection is always closed afterwards.
    """
    charset = mysql_cfg.get("charset", "utf8mb4")
    connection = pymysql.connect(
        host=mysql_cfg.get("host", "127.0.0.1"),
        port=int(mysql_cfg.get("port", 3306)),
        user=mysql_cfg["user"],
        password=mysql_cfg["password"],
        charset=charset,
        autocommit=True,
    )
    try:
        statement = (
            f"CREATE DATABASE IF NOT EXISTS `{mysql_cfg['database']}` DEFAULT CHARACTER SET {charset}"
        )
        with connection.cursor() as cur:
            cur.execute(statement)
    finally:
        connection.close()
def ensure_seed_file() -> None:
    """Write a small default seed file when none exists; otherwise do nothing."""
    if not SPAM_SEED_PATH.exists():
        SPAM_SEED_PATH.parent.mkdir(parents=True, exist_ok=True)
        fallback_samples = [
            {"text": "点击链接领取1000元现金红包先到先得", "label": "spam"},
            {"text": "今晚项目例会改到19点记得带周报", "label": "ham"},
            {"text": "恭喜你成为幸运用户,马上回复领取大奖", "label": "spam"},
            {"text": "明天上午10点客户演示请提前准备材料", "label": "ham"},
        ]
        serialized = json.dumps(fallback_samples, ensure_ascii=False, indent=2)
        SPAM_SEED_PATH.write_text(serialized, encoding="utf-8")
def seed_samples() -> tuple[int, int]:
    """Upsert the seed file's samples into the training table.

    Texts are whitespace-normalised. Entries with a text shorter than 2
    characters or a label other than ``spam`` / ``ham`` are skipped. Existing
    samples (matched by text) get their label refreshed and are reactivated.

    Returns:
        tuple[int, int]: ``(created, updated)`` counts.
    """
    ensure_seed_file()
    entries = json.loads(SPAM_SEED_PATH.read_text(encoding="utf-8-sig"))
    created = updated = 0
    for entry in entries:
        text = " ".join((entry.get("text") or "").strip().split())
        label = (entry.get("label") or "").strip().lower()
        if len(text) < 2 or label not in {"spam", "ham"}:
            continue
        existing = SpamTrainingSample.query.filter_by(text=text).first()
        if not existing:
            db.session.add(SpamTrainingSample(text=text, label=label, source="seed", is_active=True))
            created += 1
            continue
        existing.label = label
        existing.is_active = True
        # Keep an already recorded source; only fill it in when empty.
        existing.source = existing.source or "seed"
        updated += 1
    db.session.commit()
    return created, updated
def ensure_detection_config(mysql_cfg: dict) -> float:
    """Return the stored spam threshold, creating the config row when it is missing.

    When no ``DetectionConfig`` row exists, the initial threshold comes from
    ``mysql_cfg["detection_init"]["spam_threshold"]``. A missing, null or
    non-object ``detection_init`` and a missing, non-numeric or NaN threshold
    all fall back to 0.75. The value is clamped to 0.01..0.99 before saving.

    Returns:
        float: The threshold now stored in the database.
    """
    cfg = DetectionConfig.query.order_by(DetectionConfig.id.asc()).first()
    if cfg:
        return float(cfg.spam_threshold)
    init_cfg = mysql_cfg.get("detection_init") if isinstance(mysql_cfg, dict) else None
    # "detection_init": null (or any non-object) must not crash on .get().
    if not isinstance(init_cfg, dict):
        init_cfg = {}
    try:
        threshold = float(init_cfg.get("spam_threshold", 0.75))
    except (TypeError, ValueError, OverflowError):
        threshold = 0.75
    # NaN is the only float unequal to itself; min()/max() would let it through.
    if threshold != threshold:
        threshold = 0.75
    threshold = min(max(threshold, 0.01), 0.99)
    cfg = DetectionConfig(spam_threshold=threshold)
    db.session.add(cfg)
    db.session.commit()
    return threshold
def init_admin(mysql_cfg: dict) -> str:
    """Create the default admin account described by ``admin_init``, if enabled.

    Returns:
        str: A status message saying whether creation was disabled, the
        account already existed, or it was created.
    """
    settings = mysql_cfg.get("admin_init", {})
    if not settings.get("create_default_admin", True):
        return "默认管理员创建已关闭"
    username = settings.get("username", "admin")
    password = settings.get("password", "Admin@123456")
    nickname = settings.get("nickname", "系统管理员")
    if User.query.filter_by(username=username).first():
        return f"管理员已存在: {username}"
    account = User(username=username, nickname=nickname, is_admin=True)
    account.set_password(password)
    db.session.add(account)
    db.session.commit()
    return f"管理员已创建: {username}"
def train_initial_model(model_path: str) -> dict:
    """Train a classifier at ``model_path`` on all active samples and return its metadata."""
    active_rows = (
        SpamTrainingSample.query.filter_by(is_active=True)
        .order_by(SpamTrainingSample.id.asc())
        .all()
    )
    samples = []
    for row in active_rows:
        samples.append({"text": row.text, "label": row.label})
    return NaiveBayesSpamClassifier(model_path).train(samples)
def main():
    """Initialise the database end to end and print a summary.

    Steps: create the database, create tables, seed samples, ensure the
    detection config, create the default admin, and train the initial model.
    """
    mysql_cfg = load_mysql_cfg()
    create_database(mysql_cfg)
    app = create_app()
    with app.app_context():
        db.create_all()
        created, updated = seed_samples()
        threshold = ensure_detection_config(mysql_cfg)
        admin_msg = init_admin(mysql_cfg)
        model_meta = train_initial_model(app.config["NB_MODEL_PATH"])
        summary = (
            "数据库初始化完成",
            f"- 样本新增: {created}",
            f"- 样本更新: {updated}",
            f"- 初始阈值: {threshold}",
            f"- {admin_msg}",
            f"- 模型版本: {model_meta.get('version')}",
        )
        for line in summary:
            print(line)


if __name__ == "__main__":
    main()

18
backend/llm_config.json Normal file
View File

@@ -0,0 +1,18 @@
{
"active_provider": "fastgpt",
"fastgpt": {
"base_url": "https://cloud.fastgpt.io/api",
"api_key": "fastgpt-tdn84kTbv0hdIuEdHpVsQcoJuTR9doopjc9c5DRxsUjmfJ9B9Sn3gw1ywEoE",
"chat_id": "111",
"custom_uid": "",
"model": "",
"system_prompt": "你是一位专业营养师,请结合用户饮食状态给出可执行建议。"
},
"dify": {
"base_url": "http://localhost:5001",
"api_key": "app-your-dify-api-key"
},
"request": {
"timeout_seconds": 45
}
}

Binary file not shown.

Binary file not shown.

17
backend/mysqlconfig.json Normal file
View File

@@ -0,0 +1,17 @@
{
"host": "192.168.2.183",
"port": 3308,
"user": "root",
"password": "rootroot",
"database": "spam_nb_miniapp",
"charset": "utf8mb4",
"admin_init": {
"create_default_admin": true,
"username": "admin",
"password": "Admin@123456",
"nickname": "系统管理员"
},
"detection_init": {
"spam_threshold": 0.75
}
}

View File

@@ -0,0 +1,17 @@
{
"host": "127.0.0.1",
"port": 3306,
"user": "root",
"password": "pk123123",
"database": "spam_nb_miniapp",
"charset": "utf8mb4",
"admin_init": {
"create_default_admin": true,
"username": "admin",
"password": "Admin@123456",
"nickname": "系统管理员"
},
"detection_init": {
"spam_threshold": 0.75
}
}

13
backend/requirements.txt Normal file
View File

@@ -0,0 +1,13 @@
Flask==3.1.0
Flask-SQLAlchemy==3.1.1
Flask-JWT-Extended==4.6.0
Flask-Cors==5.0.0
PyMySQL==1.1.1
SQLAlchemy==2.0.36
scikit-learn==1.5.2
numpy==2.1.3
pandas==2.2.3
joblib==1.4.2
python-dotenv==1.0.1
requests==2.32.3
Werkzeug==3.1.3

8
backend/run.py Normal file
View File

@@ -0,0 +1,8 @@
import os

from app import create_app

app = create_app()

if __name__ == "__main__":
    # The Werkzeug debugger lets anyone who can reach it execute arbitrary
    # code, and this server binds to every interface. Debug mode is therefore
    # opt-in through FLASK_DEBUG rather than always on.
    debug = os.environ.get("FLASK_DEBUG", "").strip().lower() in {"1", "true", "yes", "on"}
    app.run(host="0.0.0.0", port=5000, debug=debug)

View File

@@ -0,0 +1,122 @@
[
{
"question": "减脂期每天蛋白质摄入多少?",
"answer": "通常建议 1.6-2.2g/kg 体重;训练较多时可接近上限。",
"tags": ["减脂", "蛋白质"],
"source": "中国居民膳食指南与运动营养实践"
},
{
"question": "减脂为什么还要吃碳水?",
"answer": "适量碳水可维持训练表现和代谢,过低碳水可能影响恢复与饱腹感。",
"tags": ["减脂", "碳水"],
"source": "本地营养知识整理"
},
{
"question": "生酮饮食碳水要控制到多少?",
"answer": "常见做法是每日净碳水控制在 20-50g但需要结合个体耐受和专业建议。",
"tags": ["生酮", "低碳"],
"source": "本地营养知识整理"
},
{
"question": "生酮期间为什么要补电解质?",
"answer": "低碳初期水分与电解质流失增加,适度补充钠钾镁可减轻不适。",
"tags": ["生酮", "电解质"],
"source": "本地营养知识整理"
},
{
"question": "增肌期热量应该怎么安排?",
"answer": "建议在维持热量基础上增加约 5%-15%,并确保优质蛋白和训练刺激。",
"tags": ["增肌", "热量"],
"source": "运动营养实践"
},
{
"question": "增肌期蛋白质摄入建议",
"answer": "每日 1.6-2.2g/kg 体重较常见,可分散到 3-5 餐摄入。",
"tags": ["增肌", "蛋白质"],
"source": "运动营养实践"
},
{
"question": "办公室人群如何控制饮食",
"answer": "优先高蛋白高纤维便当,减少含糖饮料与高油零食,固定进餐时间。",
"tags": ["职业", "办公室"],
"source": "本地营养知识整理"
},
{
"question": "程序员久坐饮食建议",
"answer": "主食粗细搭配,搭配蛋白和蔬菜,下午加餐可选酸奶或坚果。",
"tags": ["职业", "开发"],
"source": "本地营养知识整理"
},
{
"question": "体脂率高于 30% 应先做什么",
"answer": "先建立热量缺口与规律运动,优先提升蛋白和蔬菜比例,再逐步加大训练量。",
"tags": ["体脂率", "减脂"],
"source": "本地营养知识整理"
},
{
"question": "如何判断每天热量是否超标",
"answer": "结合体重趋势、围度和饮食记录评估,连续 2-4 周上升通常意味着摄入偏高。",
"tags": ["热量", "记录"],
"source": "本地营养知识整理"
},
{
"question": "晚餐吃轻食会饿怎么办",
"answer": "提高蛋白质和纤维比例,适当增加低 GI 主食,避免只吃蔬菜。",
"tags": ["晚餐", "饱腹"],
"source": "本地营养知识整理"
},
{
"question": "一天喝多少水比较合理",
"answer": "常见建议约 30-35ml/kg 体重,运动出汗时需额外补充。",
"tags": ["饮水", "基础"],
"source": "本地营养知识整理"
},
{
"question": "轻食是不是只能吃沙拉",
"answer": "不是。轻食强调低负担和营养均衡,可包含热餐、汤品、全谷物和优质蛋白。",
"tags": ["轻食", "误区"],
"source": "本地营养知识整理"
},
{
"question": "减脂期间可以吃水果吗",
"answer": "可以。注意控制总量和整体热量,优先低糖高纤水果并安排在白天。",
"tags": ["减脂", "水果"],
"source": "本地营养知识整理"
},
{
"question": "运动后多长时间吃饭较好",
"answer": "通常在训练后 30-90 分钟补充蛋白和碳水有助于恢复。",
"tags": ["训练", "恢复"],
"source": "运动营养实践"
},
{
"question": "早餐不吃会影响减脂吗",
"answer": "关键是全天能量和营养结构;不吃早餐可能导致后续暴食,应根据作息选择。",
"tags": ["早餐", "减脂"],
"source": "本地营养知识整理"
},
{
"question": "体重平台期怎么破",
"answer": "复查记录准确度,微调热量或活动量,保证睡眠与训练强度,再观察 2 周。",
"tags": ["平台期", "减脂"],
"source": "本地营养知识整理"
},
{
"question": "高蛋白饮食会不会伤肾",
"answer": "健康人群在合理范围内通常问题不大,但已有肾脏疾病者需遵医嘱。",
"tags": ["蛋白质", "安全"],
"source": "本地营养知识整理"
},
{
"question": "如何设置一周饮食计划",
"answer": "先按目标设定热量和宏量营养,再按早餐/午餐/晚餐分配,预留 10% 灵活空间。",
"tags": ["计划", "实操"],
"source": "本地营养知识整理"
},
{
"question": "轻断食适合所有人吗",
"answer": "并不适合所有人,孕妇、青少年、慢病人群需谨慎并咨询专业人士。",
"tags": ["轻断食", "安全"],
"source": "本地营养知识整理"
}
]

View File

@@ -0,0 +1,242 @@
[
{
"name": "香煎鸡胸藜麦碗",
"category": "增肌轻食",
"description": "鸡胸肉、藜麦、生菜、番茄、玉米粒",
"calories": 460,
"protein": 42,
"fat": 12,
"carbs": 45,
"fiber": 8,
"tags": ["高蛋白", "低脂", "训练后"],
"difficulty": "easy"
},
{
"name": "烟熏三文鱼牛油果沙拉",
"category": "生酮轻食",
"description": "三文鱼、牛油果、混合生菜、坚果",
"calories": 510,
"protein": 30,
"fat": 34,
"carbs": 18,
"fiber": 7,
"tags": ["生酮", "优质脂肪"],
"difficulty": "easy"
},
{
"name": "鸡蛋金枪鱼全麦卷",
"category": "减脂轻食",
"description": "全麦饼、金枪鱼、水煮蛋、酸奶酱",
"calories": 390,
"protein": 33,
"fat": 11,
"carbs": 36,
"fiber": 6,
"tags": ["减脂", "高蛋白"],
"difficulty": "easy"
},
{
"name": "牛肉彩椒糙米饭",
"category": "增肌轻食",
"description": "瘦牛肉、糙米、彩椒、西兰花",
"calories": 560,
"protein": 39,
"fat": 16,
"carbs": 62,
"fiber": 9,
"tags": ["增肌", "高碳水"],
"difficulty": "medium"
},
{
"name": "虾仁西兰花意面",
"category": "均衡轻食",
"description": "全麦意面、虾仁、西兰花、蒜香橄榄油",
"calories": 480,
"protein": 34,
"fat": 14,
"carbs": 54,
"fiber": 8,
"tags": ["均衡", "高纤维"],
"difficulty": "medium"
},
{
"name": "低脂鸡肉凯撒沙拉",
"category": "减脂轻食",
"description": "鸡胸肉、生菜、圣女果、低脂凯撒酱",
"calories": 360,
"protein": 37,
"fat": 10,
"carbs": 20,
"fiber": 5,
"tags": ["低卡", "减脂"],
"difficulty": "easy"
},
{
"name": "豆腐牛油果暖沙拉",
"category": "素食轻食",
"description": "豆腐、牛油果、蘑菇、菠菜",
"calories": 410,
"protein": 21,
"fat": 24,
"carbs": 25,
"fiber": 9,
"tags": ["素食", "生酮友好"],
"difficulty": "easy"
},
{
"name": "鸡肉鹰嘴豆能量碗",
"category": "均衡轻食",
"description": "鸡胸肉、鹰嘴豆、紫甘蓝、胡萝卜",
"calories": 520,
"protein": 36,
"fat": 13,
"carbs": 63,
"fiber": 11,
"tags": ["高纤维", "饱腹"],
"difficulty": "medium"
},
{
"name": "芦笋牛排轻食拼盘",
"category": "增肌轻食",
"description": "牛排、芦笋、南瓜泥、番茄",
"calories": 590,
"protein": 44,
"fat": 24,
"carbs": 38,
"fiber": 6,
"tags": ["高蛋白", "力量训练"],
"difficulty": "medium"
},
{
"name": "魔芋鸡丝低碳碗",
"category": "生酮轻食",
"description": "魔芋丝、鸡胸肉、黄瓜、芝麻酱",
"calories": 320,
"protein": 29,
"fat": 15,
"carbs": 12,
"fiber": 4,
"tags": ["低碳", "生酮", "减脂"],
"difficulty": "easy"
},
{
"name": "三文鱼羽衣甘蓝碗",
"category": "均衡轻食",
"description": "三文鱼、羽衣甘蓝、红薯、藜麦",
"calories": 540,
"protein": 35,
"fat": 22,
"carbs": 49,
"fiber": 10,
"tags": ["抗氧化", "Omega-3"],
"difficulty": "medium"
},
{
"name": "鸡腿肉菌菇焖饭",
"category": "日常轻食",
"description": "去皮鸡腿肉、糙米、香菇、玉米",
"calories": 500,
"protein": 32,
"fat": 14,
"carbs": 57,
"fiber": 7,
"tags": ["家常", "均衡"],
"difficulty": "medium"
},
{
"name": "火鸡胸蔬菜三明治",
"category": "减脂轻食",
"description": "全麦吐司、火鸡胸、生菜、番茄",
"calories": 340,
"protein": 28,
"fat": 9,
"carbs": 34,
"fiber": 5,
"tags": ["便携", "办公室"],
"difficulty": "easy"
},
{
"name": "鸡蛋菠菜奶酪欧姆雷",
"category": "生酮轻食",
"description": "鸡蛋、菠菜、低碳奶酪、蘑菇",
"calories": 370,
"protein": 25,
"fat": 26,
"carbs": 8,
"fiber": 3,
"tags": ["早餐", "低碳"],
"difficulty": "easy"
},
{
"name": "酸奶浆果燕麦杯",
"category": "日常轻食",
"description": "希腊酸奶、燕麦、蓝莓、奇亚籽",
"calories": 300,
"protein": 19,
"fat": 7,
"carbs": 42,
"fiber": 7,
"tags": ["早餐", "高纤维"],
"difficulty": "easy"
},
{
"name": "鸡胸南瓜蔬菜汤",
"category": "减脂轻食",
"description": "鸡胸肉、南瓜、西芹、胡萝卜",
"calories": 280,
"protein": 26,
"fat": 6,
"carbs": 30,
"fiber": 6,
"tags": ["低卡", "晚餐"],
"difficulty": "easy"
},
{
"name": "鳕鱼豆腐味噌碗",
"category": "均衡轻食",
"description": "鳕鱼、嫩豆腐、糙米、海带",
"calories": 430,
"protein": 33,
"fat": 11,
"carbs": 44,
"fiber": 5,
"tags": ["清淡", "高蛋白"],
"difficulty": "medium"
},
{
"name": "黑椒鸡排花椰菜饭",
"category": "生酮轻食",
"description": "鸡排、花椰菜米、彩椒、坚果碎",
"calories": 400,
"protein": 35,
"fat": 21,
"carbs": 16,
"fiber": 6,
"tags": ["低碳", "减脂"],
"difficulty": "easy"
},
{
"name": "牛油果鸡蛋藜麦沙拉",
"category": "均衡轻食",
"description": "牛油果、水煮蛋、藜麦、生菜",
"calories": 450,
"protein": 22,
"fat": 24,
"carbs": 36,
"fiber": 9,
"tags": ["均衡", "饱腹"],
"difficulty": "easy"
},
{
"name": "鸡胸肉荞麦冷面",
"category": "日常轻食",
"description": "荞麦面、鸡胸丝、黄瓜、溏心蛋",
"calories": 470,
"protein": 31,
"fat": 12,
"carbs": 58,
"fiber": 6,
"tags": ["夏季", "办公室"],
"difficulty": "medium"
}
]

View File

@@ -0,0 +1,67 @@
[
{"text": "尊敬的用户您已获赠100元话费点击链接立即到账", "label": "spam"},
{"text": "本周五下午两点进行季度复盘,请准时参加", "label": "ham"},
{"text": "最后3个名额免费领取手机一台回复1立即领取", "label": "spam"},
{"text": "您好,合同已发送到邮箱,请查收并反馈修改意见", "label": "ham"},
{"text": "你的快递因地址异常被退回,点击网址重新填写", "label": "spam"},
{"text": "明天出差高铁票已订好,车次信息已同步到群里", "label": "ham"},
{"text": "恭喜你成为平台幸运粉丝,马上领现金红包", "label": "spam"},
{"text": "研发环境今晚22点维护预计30分钟恢复", "label": "ham"},
{"text": "内部渠道兼职日结500添加微信了解详情", "label": "spam"},
{"text": "周报模板已更新,请使用新模板提交", "label": "ham"},
{"text": "低价出售苹果手机,全新未拆封,先到先得", "label": "spam"},
{"text": "客户反馈文档在共享盘,路径已发你私聊", "label": "ham"},
{"text": "紧急通知:你的银行卡存在风险,请立即验证", "label": "spam"},
{"text": "今天的日报我已补充到项目看板", "label": "ham"},
{"text": "官方补贴发放中,输入验证码即可领取", "label": "spam"},
{"text": "下午四点产品评审,麻烦准备交互稿", "label": "ham"},
{"text": "无需面试,高薪在家办公,扫码进群", "label": "spam"},
{"text": "发票已开具完成,纸质件今天寄出", "label": "ham"},
{"text": "您的贷款已通过,点击查看额度", "label": "spam"},
{"text": "会议纪要我整理好了,已上传飞书文档", "label": "ham"},
{"text": "双十一秒杀提前抢,点此领隐藏优惠券", "label": "spam"},
{"text": "新同事今天入职,请大家中午一起欢迎", "label": "ham"},
{"text": "你有一笔退款待确认,马上处理避免失效", "label": "spam"},
{"text": "设计稿第二版我已经按你建议调整完了", "label": "ham"},
{"text": "点击领取年度会员原价699现价9.9", "label": "spam"},
{"text": "明天早会由我来同步上线计划", "label": "ham"},
{"text": "官方通知:账号异常将被冻结,请立即解封", "label": "spam"},
{"text": "请把测试环境数据库备份到指定目录", "label": "ham"},
{"text": "刷单项目火热招募,宝妈学生都能做", "label": "spam"},
{"text": "你发的需求我已经拆分成开发任务", "label": "ham"},
{"text": "中奖通知:你获得平板电脑一台,限时领取", "label": "spam"},
{"text": "客户明天下午三点会远程验收新功能", "label": "ham"},
{"text": "陌生链接请勿泄露验证码,谨防被骗", "label": "ham"},
{"text": "马上关注公众号领取无门槛现金券", "label": "spam"},
{"text": "今天的构建失败是依赖冲突,我在修复", "label": "ham"},
{"text": "免费领取课程资料,扫码后自动发放", "label": "spam"},
{"text": "请确认一下下周排期是否需要调整", "label": "ham"},
{"text": "特惠机票内部价,回复姓名立刻锁座", "label": "spam"},
{"text": "我已经把版本回滚流程补充到Wiki", "label": "ham"},
{"text": "贷款秒批到账额度最高20万", "label": "spam"},
{"text": "合同法务意见已返回,请你二次确认", "label": "ham"},
{"text": "限时返现活动,点击进入马上到账", "label": "spam"},
{"text": "这个bug我复现到了定位在缓存层", "label": "ham"},
{"text": "你有新的违章信息,点开链接立即处理", "label": "spam"},
{"text": "早上好,今天先做性能压测再发版", "label": "ham"},
{"text": "邀请码最后1小时有效错过不再补发", "label": "spam"},
{"text": "中午12点在会议室A开需求评审会", "label": "ham"},
{"text": "苹果14只要1999货到付款保真", "label": "spam"},
{"text": "供应商报价单已更新到共享文件夹", "label": "ham"},
{"text": "想赚外快吗?加我秒赚零花钱", "label": "spam"},
{"text": "今天下午我去客户现场,晚些回公司", "label": "ham"},
{"text": "官方补贴计划启动,名额有限速来登记", "label": "spam"},
{"text": "测试报告已发你邮箱,包含复现步骤", "label": "ham"},
{"text": "欠费停机提醒,立即充值恢复使用", "label": "spam"},
{"text": "这个接口我加了幂等,避免重复提交", "label": "ham"},
{"text": "点击抽取盲盒大奖,百分百中奖", "label": "spam"},
{"text": "版本发布说明我已经整理成公告", "label": "ham"},
{"text": "独家内部消息,股票必涨,速进群", "label": "spam"},
{"text": "周一上午需要和财务对齐预算数据", "label": "ham"},
{"text": "紧急!你的社保账户异常,立即核验", "label": "spam"},
{"text": "我下午会把接口文档补全到OpenAPI", "label": "ham"},
{"text": "游戏皮肤免费领,输入手机号立刻到账", "label": "spam"},
{"text": "晚上的培训链接我刚刚发到部门群", "label": "ham"},
{"text": "预约体检补贴开通,点击立即申请", "label": "spam"},
{"text": "新需求优先级调高了,请先排进本周", "label": "ham"}
]

17
backend/train_model.py Normal file
View File

@@ -0,0 +1,17 @@
from app import create_app
from app.ml.naive_bayes_classifier import NaiveBayesSpamClassifier
from app.models import SpamTrainingSample
def main():
    """Retrain the spam classifier on all active samples and print the resulting version."""
    app = create_app()
    with app.app_context():
        active_rows = (
            SpamTrainingSample.query.filter_by(is_active=True)
            .order_by(SpamTrainingSample.id.asc())
            .all()
        )
        samples = []
        for row in active_rows:
            samples.append({"text": row.text, "label": row.label})
        classifier = NaiveBayesSpamClassifier(app.config["NB_MODEL_PATH"])
        meta = classifier.train(samples)
        print(f"模型训练完成: {meta.get('version')} 样本数={meta.get('sample_count')}")


if __name__ == "__main__":
    main()