315 lines
9.4 KiB
Python
315 lines
9.4 KiB
Python
from statistics import mean
|
||
|
||
from flask import Blueprint, current_app, request
|
||
from flask_jwt_extended import jwt_required
|
||
|
||
from app.extensions import db
|
||
from app.ml.rf_recommender import RandomForestDietRecommender, merge_profile_with_history
|
||
from app.models import DietStatus, RecommendationLog, Recipe
|
||
from app.utils.auth import current_user
|
||
from app.utils.response import fail, ok
|
||
|
||
|
||
# Blueprint on which the recommendation routes in this module are registered.
recommend_bp = Blueprint("recommend", __name__)
|
||
|
||
|
||
def _build_profile(user, latest_status):
    """Assemble the feature dict describing *user* for the recommenders.

    Args:
        user: object exposing ``age``, ``height_cm``, ``goal`` and
            ``occupation``.
        latest_status: the most recent status record, exposing ``weight``,
            ``body_fat``, ``exercise_kcal`` and ``intake_kcal``, or a falsy
            value when the user has no record.

    Returns:
        A dict with the four user attributes followed by the four body
        metrics. A metric is taken from ``latest_status`` only when the
        record is present and the value is not ``None``; otherwise the
        built-in default (65 / 20 / 300 / 1800) is kept.
    """
    profile = {
        "age": user.age,
        "height_cm": user.height_cm,
        "goal": user.goal,
        "occupation": user.occupation,
        "weight": 65,
        "body_fat": 20,
        "exercise_kcal": 300,
        "intake_kcal": 1800,
    }
    if latest_status:
        for field in ("weight", "body_fat", "exercise_kcal", "intake_kcal"):
            measured = getattr(latest_status, field)
            # A metric left unset on the record must not erase the default:
            # the scoring code converts these values with float(), which
            # raises on None. Zero is a legitimate measurement and is kept.
            if measured is not None:
                profile[field] = measured
    return profile
|
||
|
||
|
||
def _goal_adjust(goal: str, recipe: Recipe) -> float:
    """Return the goal-specific score adjustment for *recipe*.

    Recognised goals are ``"lose_fat"``, ``"gain_muscle"`` and ``"keto"``;
    any other value (including ``None``) uses the generic formula based on
    protein and fiber.
    """
    if goal == "lose_fat":
        # Only the calories above 500 are penalised.
        excess_kcal = max(recipe.calories - 500, 0)
        delta = recipe.protein * 0.8 - recipe.fat * 0.4 - excess_kcal * 0.04
    elif goal == "gain_muscle":
        delta = recipe.protein * 1.0 + recipe.carbs * 0.25
    elif goal == "keto":
        delta = recipe.fat * 0.5 - recipe.carbs * 0.8
    else:
        delta = recipe.protein * 0.3 + recipe.fiber * 1.2
    return delta
|
||
|
||
|
||
def _occupation_adjust(occupation: str, recipe: Recipe) -> float:
    """Return the occupation-specific score adjustment for *recipe*.

    The occupation is compared case-insensitively; a falsy value is treated
    as an empty string. Occupations outside the three known groups use the
    protein-only fallback.
    """
    job = occupation.lower() if occupation else ""
    if job in ("developer", "office"):
        bonus = recipe.fiber * 0.9 + recipe.protein * 0.2
    elif job in ("fitness", "manual"):
        bonus = recipe.protein * 0.6 + recipe.carbs * 0.2
    elif job in ("teacher", "student"):
        bonus = recipe.carbs * 0.15 + recipe.protein * 0.3
    else:
        bonus = recipe.protein * 0.25
    return bonus
|
||
|
||
|
||
def _health_adjust(profile: dict, recipe: Recipe) -> float:
    """Score *recipe* against the body metrics in *profile*.

    Reads ``body_fat``, ``intake_kcal`` and ``exercise_kcal`` from the
    profile (defaults 20 / 1800 / 300 when a key is absent), starts from a
    base of 70 and returns the adjusted score rounded to two decimals.
    """
    fat_pct = float(profile.get("body_fat", 20))
    eaten_kcal = float(profile.get("intake_kcal", 1800))
    burned_kcal = float(profile.get("exercise_kcal", 300))

    total = 70.0

    if fat_pct > 28:
        # High body fat: reward protein, penalise calories above 480.
        total += recipe.protein * 0.5
        total -= max(recipe.calories - 480, 0) * 0.05
    elif fat_pct < 14:
        # Low body fat: reward carbs and fat.
        total += recipe.carbs * 0.15 + recipe.fat * 0.1

    over_eating = eaten_kcal > burned_kcal + 1700
    if over_eating:
        total -= max(recipe.calories - 450, 0) * 0.04
    else:
        total += recipe.protein * 0.3

    # The fiber bonus is capped at 10 points.
    total += min(recipe.fiber * 1.0, 10)
    return round(total, 2)
|
||
|
||
|
||
def _serialize_with_score(recipes):
    """Convert scored rows into response dicts.

    Each row is indexed as ``(recipe, score, reason)``. The result for a row
    is ``recipe.to_dict()`` extended with ``"score"`` (cast to float and
    rounded to two decimals) and ``"reason"``.
    """
    serialized = []
    for row in recipes:
        entry = {**row[0].to_dict()}
        entry["score"] = round(float(row[1]), 2)
        entry["reason"] = row[2]
        serialized.append(entry)
    return serialized
|
||
|
||
|
||
def _save_log(user_id: int, rec_type: str, payload: dict) -> None:
    """Store one recommendation result as a ``RecommendationLog`` row.

    The row is added to the database session and committed immediately.
    """
    session = db.session
    session.add(RecommendationLog(user_id=user_id, rec_type=rec_type, payload=payload))
    session.commit()
|
||
|
||
|
||
@recommend_bp.get("/current")
@jwt_required()
def recommend_current_status():
    """Recommend recipes for the user's current status with the RF recommender.

    Query parameters:
        top_k: number of recipes to return, clamped to 1..20. A missing,
            empty or non-integer value falls back to 5.

    Returns:
        ``ok`` with the recommendation type, the profile handed to the
        recommender and the recommended items; ``fail`` with 404 when no
        user is resolved, or with 400 when there are no recipes.
    """
    user = current_user()
    if not user:
        return fail("用户不存在", 404)

    # type=int makes a malformed value yield the default instead of raising
    # ValueError, which would surface as an HTTP 500.
    requested = request.args.get("top_k", default=5, type=int)
    top_k = min(max(requested, 1), 20)

    history = (
        DietStatus.query.filter_by(user_id=user.id)
        .order_by(DietStatus.recorded_at.desc(), DietStatus.id.desc())
        .limit(14)
        .all()
    )
    # history is ordered newest-first, so its head is the latest record;
    # a separate query for it is unnecessary.
    latest = history[0] if history else None

    recipes = Recipe.query.order_by(Recipe.id.asc()).all()
    if not recipes:
        return fail("当前没有可推荐食谱,请先导入食谱", 400)

    profile = _build_profile(user, latest)
    profile = merge_profile_with_history(profile, history)

    recommender = RandomForestDietRecommender(current_app.config["MODEL_PATH"])
    items = recommender.recommend(profile, recipes, top_k=top_k)

    response_data = {
        "type": "current_status_rf",
        "profile": profile,
        "items": items,
    }
    _save_log(user.id, "current_status_rf", response_data)
    return ok(response_data)
|
||
|
||
|
||
@recommend_bp.get("/health")
@jwt_required()
def recommend_by_health():
    """Recommend recipes ranked by the rule-based health score.

    Query parameters:
        top_k: number of recipes to return, clamped to 1..20. A missing,
            empty or non-integer value falls back to 5.

    Returns:
        ``ok`` with the recommendation type, the profile used for scoring
        and the top-scored items; ``fail`` with 404 when no user is
        resolved, or with 400 when there are no recipes.
    """
    user = current_user()
    if not user:
        return fail("用户不存在", 404)

    # type=int makes a malformed value yield the default instead of raising
    # ValueError, which would surface as an HTTP 500.
    requested = request.args.get("top_k", default=5, type=int)
    top_k = min(max(requested, 1), 20)

    latest = (
        DietStatus.query.filter_by(user_id=user.id)
        .order_by(DietStatus.recorded_at.desc(), DietStatus.id.desc())
        .first()
    )
    profile = _build_profile(user, latest)

    recipes = Recipe.query.order_by(Recipe.id.asc()).all()
    if not recipes:
        return fail("当前没有可推荐食谱,请先导入食谱", 400)

    # The explanation is the same for every recipe, so build it once.
    reason = "结合当前体脂和摄入消耗状态,推荐更合理的营养结构。"
    scored = [(recipe, _health_adjust(profile, recipe), reason) for recipe in recipes]
    # list.sort is stable, so equally scored recipes keep their id order.
    scored.sort(key=lambda row: row[1], reverse=True)
    items = _serialize_with_score(scored[:top_k])

    response_data = {
        "type": "health_state",
        "profile": profile,
        "items": items,
    }
    _save_log(user.id, "health_state", response_data)
    return ok(response_data)
|
||
|
||
|
||
@recommend_bp.get("/plan/history")
@jwt_required()
def recommend_plan_by_history():
    """Recommend a calorie-bounded plan derived from the user's recent history.

    Uses up to the 30 most recent status records to compute the weight
    trend and the average intake and exercise, picks a strategy with its
    calorie ceiling, and returns the highest-protein recipes under that
    ceiling.

    Query parameters:
        top_k: number of recipes to return, clamped to 1..20. A missing,
            empty or non-integer value falls back to 5.

    Returns:
        ``ok`` with the strategy, the computed metrics and the items;
        ``fail`` with 404 when no user is resolved, or with 400 when fewer
        than three history records exist.
    """
    user = current_user()
    if not user:
        return fail("用户不存在", 404)

    # type=int makes a malformed value yield the default instead of raising
    # ValueError, which would surface as an HTTP 500.
    requested = request.args.get("top_k", default=5, type=int)
    top_k = min(max(requested, 1), 20)

    history = (
        DietStatus.query.filter_by(user_id=user.id)
        .order_by(DietStatus.recorded_at.desc(), DietStatus.id.desc())
        .limit(30)
        .all()
    )
    if len(history) < 3:
        return fail("历史数据太少,至少记录3天后可用该功能", 400)

    avg_weight = mean(h.weight for h in history)
    # history is newest-first: index 0 is the latest record and index -1 the
    # oldest one inside the window.
    weight_trend = history[0].weight - history[-1].weight

    avg_intake = mean(h.intake_kcal for h in history)
    avg_exercise = mean(h.exercise_kcal for h in history)

    if weight_trend > 0.8 or avg_intake > avg_exercise + 1700:
        cal_limit, strategy = 440, "减脂优先计划"
    elif weight_trend < -0.8:
        cal_limit, strategy = 620, "增肌恢复计划"
    else:
        cal_limit, strategy = 520, "维持计划"

    # NOTE(review): recipes are selected and ordered by protein, while the
    # reported score also subtracts fat, so items are not necessarily in
    # descending score order - confirm this is intended.
    recipes = (
        Recipe.query.filter(Recipe.calories <= cal_limit)
        .order_by(Recipe.protein.desc())
        .limit(top_k)
        .all()
    )
    # The explanation depends only on the strategy, so build it once.
    reason = f"基于历史趋势({strategy}),优先推荐该热量区间食谱。"
    items = [
        {
            **recipe.to_dict(),
            "score": round(recipe.protein * 2 - recipe.fat * 0.3, 2),
            "reason": reason,
        }
        for recipe in recipes
    ]

    response_data = {
        "type": "history_plan",
        "strategy": strategy,
        "metrics": {
            "avg_weight": round(avg_weight, 2),
            "weight_trend": round(weight_trend, 2),
            "avg_intake": round(avg_intake, 1),
            "avg_exercise": round(avg_exercise, 1),
        },
        "items": items,
    }
    _save_log(user.id, "history_plan", response_data)
    return ok(response_data)
|
||
|
||
|
||
@recommend_bp.get("/plan/goal")
@jwt_required()
def recommend_plan_by_goal():
    """Recommend recipes ranked for a nutrition goal.

    Query parameters:
        top_k: number of recipes to return, clamped to 1..20. A missing,
            empty or non-integer value falls back to 5.
        goal: goal to rank for; when missing or empty the user's stored
            goal is used.

    Returns:
        ``ok`` with the recommendation type, the goal used and the
        top-scored items; ``fail`` with 404 when no user is resolved, or
        with 400 when there are no recipes.
    """
    user = current_user()
    if not user:
        return fail("用户不存在", 404)

    # type=int makes a malformed value yield the default instead of raising
    # ValueError, which would surface as an HTTP 500.
    requested = request.args.get("top_k", default=5, type=int)
    top_k = min(max(requested, 1), 20)
    goal = request.args.get("goal") or user.goal

    recipes = Recipe.query.order_by(Recipe.id.asc()).all()
    if not recipes:
        return fail("当前没有可推荐食谱,请先导入食谱", 400)

    # The explanation depends only on the goal, so build it once.
    reason = f"根据目标({goal})筛选营养比例更匹配的食谱。"
    scored = [(recipe, 60 + _goal_adjust(goal, recipe), reason) for recipe in recipes]
    # list.sort is stable, so equally scored recipes keep their id order.
    scored.sort(key=lambda row: row[1], reverse=True)
    items = _serialize_with_score(scored[:top_k])

    response_data = {
        "type": "future_goal_plan",
        "goal": goal,
        "items": items,
    }
    _save_log(user.id, "future_goal_plan", response_data)
    return ok(response_data)
|
||
|
||
|
||
@recommend_bp.get("/plan/occupation")
@jwt_required()
def recommend_plan_by_occupation():
    """Recommend recipes ranked for an occupation.

    Query parameters:
        top_k: number of recipes to return, clamped to 1..20. A missing,
            empty or non-integer value falls back to 5.
        occupation: occupation to rank for; when missing or empty the
            user's stored occupation is used.

    Returns:
        ``ok`` with the recommendation type, the occupation used and the
        top-scored items; ``fail`` with 404 when no user is resolved, or
        with 400 when there are no recipes.
    """
    user = current_user()
    if not user:
        return fail("用户不存在", 404)

    # type=int makes a malformed value yield the default instead of raising
    # ValueError, which would surface as an HTTP 500.
    requested = request.args.get("top_k", default=5, type=int)
    top_k = min(max(requested, 1), 20)
    occupation = request.args.get("occupation") or user.occupation

    recipes = Recipe.query.order_by(Recipe.id.asc()).all()
    if not recipes:
        return fail("当前没有可推荐食谱,请先导入食谱", 400)

    # The explanation depends only on the occupation, so build it once.
    reason = f"结合职业({occupation})的能量消耗特点进行推荐。"
    scored = [
        (recipe, 65 + _occupation_adjust(occupation, recipe), reason)
        for recipe in recipes
    ]
    # list.sort is stable, so equally scored recipes keep their id order.
    scored.sort(key=lambda row: row[1], reverse=True)
    items = _serialize_with_score(scored[:top_k])

    response_data = {
        "type": "occupation_plan",
        "occupation": occupation,
        "items": items,
    }
    _save_log(user.id, "occupation_plan", response_data)
    return ok(response_data)
|
||
|
||
|
||
@recommend_bp.get("/logs")
@jwt_required()
def my_recommend_logs():
    """Return the current user's most recent recommendation logs.

    Query parameters:
        limit: number of logs to return, clamped to 1..100. A missing,
            empty or non-integer value falls back to 20.

    Returns:
        ``ok`` with the serialized logs ordered by descending id, or
        ``fail`` with 404 when no user is resolved.
    """
    user = current_user()
    if not user:
        return fail("用户不存在", 404)

    # type=int makes a malformed value yield the default instead of raising
    # ValueError, which would surface as an HTTP 500.
    requested = request.args.get("limit", default=20, type=int)
    limit = min(max(requested, 1), 100)

    rows = (
        RecommendationLog.query.filter_by(user_id=user.id)
        .order_by(RecommendationLog.id.desc())
        .limit(limit)
        .all()
    )
    return ok([row.to_dict() for row in rows])
|