Files
c/backend/app/routes/qa_routes.py
刘正航 b5237f9038 1
2026-04-21 22:45:19 +08:00

178 lines
5.9 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from flask import Blueprint, current_app, request
from flask_jwt_extended import jwt_required
from app.rag.llm_client import LLMConfigManager, ask_dify, ask_fastgpt
from app.rag.local_retriever import LocalKnowledgeRetriever
from app.utils.auth import admin_required, current_user
from app.utils.response import ok
# Blueprint that the knowledge-base search, Q&A, advice and reload routes
# in this module are registered on.
qa_bp = Blueprint("qa", __name__)
def _retriever() -> LocalKnowledgeRetriever:
    """Build a knowledge-base retriever from the app's ``LOCAL_KB_PATH`` setting."""
    kb_path = current_app.config["LOCAL_KB_PATH"]
    return LocalKnowledgeRetriever(kb_path)
def _llm_config() -> dict:
    """Load the LLM settings through ``LLMConfigManager``.

    The manager is created from the app's ``LLM_CONFIG_PATH`` setting and the
    result of its ``load()`` call is returned unchanged.
    """
    config_path = current_app.config["LLM_CONFIG_PATH"]
    return LLMConfigManager(config_path).load()
def _default_provider(config: dict) -> str:
provider = (config.get("active_provider") or "fastgpt").strip().lower() if isinstance(config, dict) else "fastgpt"
return provider if provider in {"fastgpt", "dify"} else "fastgpt"
def _context_text(hits: list) -> str:
chunks = []
for idx, item in enumerate(hits, start=1):
chunks.append(f"[{idx}] 问: {item['question']}\n答: {item['answer']}")
return "\n\n".join(chunks)
def _local_answer(hits: list) -> str:
if not hits:
return "本地知识库暂未检索到高相关内容,建议补充问题细节。"
top = hits[0]
return f"根据本地知识库:{top['answer']}"
def _advice_template(advice_type: str) -> str:
templates = {
"gain_muscle": "增肌建议:每天蛋白质建议 1.6-2.2g/kg训练后 30 分钟内补充蛋白+碳水,优先鸡胸肉/牛肉/鸡蛋/燕麦。",
"lose_fat": "减脂建议:保持 10%-20% 轻热量缺口,优先高蛋白高纤维食物,减少高糖饮料和深夜加餐。",
"keto": "生酮建议:控制碳水通常小于 50g/天,优先健康脂肪和适量蛋白,并关注电解质补充。",
"nutritionist": "营养师建议:结合体重、体脂、活动量和目标制定周计划,每 7 天复盘体重和围度变化。",
"general": "通用建议:保证规律作息、三餐结构稳定、适量运动和充足饮水。",
}
return templates.get(advice_type, templates["general"])
def _ask_by_provider(provider: str, config: dict, question: str, context: str, user_id: str, username: str):
    """Send *question* together with *context* to the selected LLM backend.

    ``"dify"`` is routed to ``ask_dify`` and passes *username* as ``user``.
    Every other provider value goes to ``ask_fastgpt`` and passes *user_id*
    as ``custom_uid``. The backend helper's return value is passed through.
    """
    if provider != "dify":
        return ask_fastgpt(config, question, context=context, custom_uid=user_id)
    return ask_dify(config, question, context=context, user=username)
@qa_bp.get("/kb/search")
@jwt_required(optional=True)
def search_kb():
    """Search the local knowledge base.

    Query parameters:
        query: Search text; surrounding whitespace is trimmed.
        top_k: Number of hits to request, clamped to the range 1..10.
            Missing, empty or non-numeric values fall back to 5.

    Returns:
        The ``ok`` response wrapping ``{"items": hits}``.
    """
    query = (request.args.get("query") or "").strip()

    # top_k comes straight from the URL. A non-numeric value previously let
    # ValueError escape from int() and fail the request; use the default.
    try:
        requested = int(request.args.get("top_k") or 5)
    except (TypeError, ValueError):
        requested = 5
    top_k = min(max(requested, 1), 10)

    hits = _retriever().search(query, top_k=top_k)
    return ok({"items": hits})
@qa_bp.post("/ask")
@jwt_required()
def ask_nutrition_qa():
    """Answer a nutrition question from the local knowledge base and/or an LLM.

    JSON body fields:
        question: The user's question. When missing or empty, a prompt to
            enter a question is returned with the ``empty_question`` message.
        mode: ``"local"`` answers from the knowledge base only. Any other
            value (default ``"auto"``) asks the LLM and falls back to the
            knowledge-base answer when the LLM result is not ``ok``. Only
            ``"llm"`` prefixes that fallback with a failure notice and
            reports the provider as ``local_kb_fallback``.
        provider: ``"fastgpt"`` or ``"dify"``; anything else uses the
            configured default provider.

    A body that is not a JSON object, and field values that are not strings,
    are treated as absent instead of raising.

    Returns:
        The ``ok`` response with ``answer``, ``provider`` and ``items``, plus
        ``llm_error`` when the LLM call did not succeed.
    """
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict):
        # Valid JSON that is not an object (e.g. a list) has no .get().
        payload = {}

    def text_field(key: str, default: str) -> str:
        # Non-string JSON values would break .strip(); treat them as missing.
        raw = payload.get(key)
        return (raw if isinstance(raw, str) and raw else default).strip()

    question = text_field("question", "")
    if not question:
        return ok({"answer": "请输入问题", "items": []}, "empty_question")

    mode = text_field("mode", "auto").lower()
    llm_cfg = _llm_config()
    provider = text_field("provider", "").lower()
    if provider not in {"fastgpt", "dify"}:
        provider = _default_provider(llm_cfg)

    hits = _retriever().search(question, top_k=3)
    context = _context_text(hits)
    local_answer = _local_answer(hits)
    if mode == "local":
        return ok({"answer": local_answer, "provider": "local_kb", "items": hits})

    user = current_user()
    llm_result = _ask_by_provider(
        provider,
        llm_cfg,
        question,
        context,
        str(user.id) if user else "",
        user.username if user else "anonymous",
    )
    if llm_result.get("ok"):
        return ok({"answer": llm_result.get("answer"), "provider": provider, "items": hits})

    # The LLM call failed: answer from the knowledge base instead. Only the
    # explicit "llm" mode tells the caller that a fallback happened.
    if mode == "llm":
        answer = f"LLM 调用失败,已回退本地知识库。{local_answer}"
        source = "local_kb_fallback"
    else:
        answer = local_answer
        source = "local_kb"
    return ok(
        {
            "answer": answer,
            "provider": source,
            "items": hits,
            "llm_error": llm_result,
        }
    )
@qa_bp.post("/advice")
@jwt_required()
def advice():
    """Produce diet advice for a goal, preferring the LLM over a local template.

    JSON body fields (all optional):
        advice_type: Goal key, lower-cased; defaults to ``"general"``.
        question: Extra detail from the user; a default prompt is used when
            missing or empty.
        provider: ``"fastgpt"`` or ``"dify"``; anything else uses the
            configured default provider.

    The canned template for ``advice_type`` and the knowledge-base hits are
    passed to the LLM as context. If the LLM result is not ``ok``, the
    template text plus a notice is returned with provider ``local_template``.

    A body that is not a JSON object, and field values that are not strings,
    are treated as absent instead of raising.

    Returns:
        The ``ok`` response with ``advice_type``, ``answer``, ``provider``,
        ``knowledge_hits`` and ``llm_status``.
    """
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict):
        # Valid JSON that is not an object (e.g. a list) has no .get().
        payload = {}

    def text_field(key: str, default: str) -> str:
        # Non-string JSON values would break .strip(); treat them as missing.
        raw = payload.get(key)
        return (raw if isinstance(raw, str) and raw else default).strip()

    advice_type = text_field("advice_type", "general").lower()
    question = text_field("question", "请根据我的状态给建议")

    llm_cfg = _llm_config()
    provider = text_field("provider", "").lower()
    if provider not in {"fastgpt", "dify"}:
        provider = _default_provider(llm_cfg)

    hits = _retriever().search(f"{advice_type} {question}", top_k=3)
    context = _context_text(hits)
    base_text = _advice_template(advice_type)

    user = current_user()
    llm_result = _ask_by_provider(
        provider,
        llm_cfg,
        f"请给出{advice_type}方向的个性化饮食建议。用户补充: {question}",
        f"{base_text}\n\n{context}",
        str(user.id) if user else "",
        user.username if user else "anonymous",
    )

    if llm_result.get("ok"):
        answer = llm_result.get("answer")
        source = provider
    else:
        # NOTE(review): the notice below ends with a closing parenthesis that
        # has no opening partner; confirm the intended wording before editing.
        answer = f"{base_text}\n\n提示LLM 暂不可用,当前为本地模板建议)"
        source = "local_template"

    return ok(
        {
            "advice_type": advice_type,
            "answer": answer,
            "provider": source,
            "knowledge_hits": hits,
            "llm_status": llm_result,
        }
    )
@qa_bp.post("/kb/reload")
@admin_required
def reload_kb():
    """Reload the local knowledge base and report its document count.

    Builds a retriever, calls its ``reload()`` and returns the length of its
    ``documents`` attribute under the ``count`` key.
    """
    kb = _retriever()
    kb.reload()
    document_count = len(kb.documents)
    return ok({"count": document_count}, "知识库已重载")