1
This commit is contained in:
177
backend/app/routes/qa_routes.py
Normal file
177
backend/app/routes/qa_routes.py
Normal file
@@ -0,0 +1,177 @@
|
||||
from flask import Blueprint, current_app, request
|
||||
from flask_jwt_extended import jwt_required
|
||||
|
||||
from app.rag.llm_client import LLMConfigManager, ask_dify, ask_fastgpt
|
||||
from app.rag.local_retriever import LocalKnowledgeRetriever
|
||||
from app.utils.auth import admin_required, current_user
|
||||
from app.utils.response import ok
|
||||
|
||||
|
||||
qa_bp = Blueprint("qa", __name__)
|
||||
|
||||
|
||||
def _retriever() -> LocalKnowledgeRetriever:
|
||||
return LocalKnowledgeRetriever(current_app.config["LOCAL_KB_PATH"])
|
||||
|
||||
|
||||
def _llm_config() -> dict:
|
||||
manager = LLMConfigManager(current_app.config["LLM_CONFIG_PATH"])
|
||||
return manager.load()
|
||||
|
||||
|
||||
def _default_provider(config: dict) -> str:
|
||||
provider = (config.get("active_provider") or "fastgpt").strip().lower() if isinstance(config, dict) else "fastgpt"
|
||||
return provider if provider in {"fastgpt", "dify"} else "fastgpt"
|
||||
|
||||
|
||||
def _context_text(hits: list) -> str:
|
||||
chunks = []
|
||||
for idx, item in enumerate(hits, start=1):
|
||||
chunks.append(f"[{idx}] 问: {item['question']}\n答: {item['answer']}")
|
||||
return "\n\n".join(chunks)
|
||||
|
||||
|
||||
def _local_answer(hits: list) -> str:
|
||||
if not hits:
|
||||
return "本地知识库暂未检索到高相关内容,建议补充问题细节。"
|
||||
top = hits[0]
|
||||
return f"根据本地知识库:{top['answer']}"
|
||||
|
||||
|
||||
def _advice_template(advice_type: str) -> str:
|
||||
templates = {
|
||||
"gain_muscle": "增肌建议:每天蛋白质建议 1.6-2.2g/kg,训练后 30 分钟内补充蛋白+碳水,优先鸡胸肉/牛肉/鸡蛋/燕麦。",
|
||||
"lose_fat": "减脂建议:保持 10%-20% 轻热量缺口,优先高蛋白高纤维食物,减少高糖饮料和深夜加餐。",
|
||||
"keto": "生酮建议:控制碳水通常小于 50g/天,优先健康脂肪和适量蛋白,并关注电解质补充。",
|
||||
"nutritionist": "营养师建议:结合体重、体脂、活动量和目标制定周计划,每 7 天复盘体重和围度变化。",
|
||||
"general": "通用建议:保证规律作息、三餐结构稳定、适量运动和充足饮水。",
|
||||
}
|
||||
return templates.get(advice_type, templates["general"])
|
||||
|
||||
|
||||
def _ask_by_provider(provider: str, config: dict, question: str, context: str, user_id: str, username: str):
|
||||
if provider == "dify":
|
||||
return ask_dify(config, question, context=context, user=username)
|
||||
return ask_fastgpt(config, question, context=context, custom_uid=user_id)
|
||||
|
||||
|
||||
@qa_bp.get("/kb/search")
|
||||
@jwt_required(optional=True)
|
||||
def search_kb():
|
||||
query = (request.args.get("query") or "").strip()
|
||||
top_k = min(max(int(request.args.get("top_k", 5) or 5), 1), 10)
|
||||
|
||||
retriever = _retriever()
|
||||
hits = retriever.search(query, top_k=top_k)
|
||||
return ok({"items": hits})
|
||||
|
||||
|
||||
@qa_bp.post("/ask")
|
||||
@jwt_required()
|
||||
def ask_nutrition_qa():
|
||||
payload = request.get_json(silent=True) or {}
|
||||
question = (payload.get("question") or "").strip()
|
||||
if not question:
|
||||
return ok({"answer": "请输入问题", "items": []}, "empty_question")
|
||||
|
||||
mode = (payload.get("mode") or "auto").strip().lower()
|
||||
|
||||
llm_cfg = _llm_config()
|
||||
provider = (payload.get("provider") or _default_provider(llm_cfg)).strip().lower()
|
||||
if provider not in {"fastgpt", "dify"}:
|
||||
provider = _default_provider(llm_cfg)
|
||||
|
||||
retriever = _retriever()
|
||||
hits = retriever.search(question, top_k=3)
|
||||
context = _context_text(hits)
|
||||
|
||||
local_answer = _local_answer(hits)
|
||||
if mode == "local":
|
||||
return ok({"answer": local_answer, "provider": "local_kb", "items": hits})
|
||||
|
||||
user = current_user()
|
||||
llm_result = _ask_by_provider(
|
||||
provider,
|
||||
llm_cfg,
|
||||
question,
|
||||
context,
|
||||
str(user.id) if user else "",
|
||||
user.username if user else "anonymous",
|
||||
)
|
||||
|
||||
if mode == "llm":
|
||||
if llm_result.get("ok"):
|
||||
return ok({"answer": llm_result.get("answer"), "provider": provider, "items": hits})
|
||||
return ok(
|
||||
{
|
||||
"answer": f"LLM 调用失败,已回退本地知识库。{local_answer}",
|
||||
"provider": "local_kb_fallback",
|
||||
"items": hits,
|
||||
"llm_error": llm_result,
|
||||
}
|
||||
)
|
||||
|
||||
if llm_result.get("ok"):
|
||||
return ok({"answer": llm_result.get("answer"), "provider": provider, "items": hits})
|
||||
|
||||
return ok(
|
||||
{
|
||||
"answer": local_answer,
|
||||
"provider": "local_kb",
|
||||
"items": hits,
|
||||
"llm_error": llm_result,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@qa_bp.post("/advice")
|
||||
@jwt_required()
|
||||
def advice():
|
||||
payload = request.get_json(silent=True) or {}
|
||||
advice_type = (payload.get("advice_type") or "general").strip().lower()
|
||||
question = (payload.get("question") or "请根据我的状态给建议").strip()
|
||||
|
||||
llm_cfg = _llm_config()
|
||||
provider = (payload.get("provider") or _default_provider(llm_cfg)).strip().lower()
|
||||
if provider not in {"fastgpt", "dify"}:
|
||||
provider = _default_provider(llm_cfg)
|
||||
|
||||
retriever = _retriever()
|
||||
hits = retriever.search(f"{advice_type} {question}", top_k=3)
|
||||
context = _context_text(hits)
|
||||
base_text = _advice_template(advice_type)
|
||||
|
||||
user = current_user()
|
||||
llm_result = _ask_by_provider(
|
||||
provider,
|
||||
llm_cfg,
|
||||
f"请给出{advice_type}方向的个性化饮食建议。用户补充: {question}",
|
||||
f"{base_text}\n\n{context}",
|
||||
str(user.id) if user else "",
|
||||
user.username if user else "anonymous",
|
||||
)
|
||||
|
||||
if llm_result.get("ok"):
|
||||
answer = llm_result.get("answer")
|
||||
source = provider
|
||||
else:
|
||||
answer = f"{base_text}\n\n(提示:LLM 暂不可用,当前为本地模板建议)"
|
||||
source = "local_template"
|
||||
|
||||
return ok(
|
||||
{
|
||||
"advice_type": advice_type,
|
||||
"answer": answer,
|
||||
"provider": source,
|
||||
"knowledge_hits": hits,
|
||||
"llm_status": llm_result,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@qa_bp.post("/kb/reload")
|
||||
@admin_required
|
||||
def reload_kb():
|
||||
retriever = _retriever()
|
||||
retriever.reload()
|
||||
return ok({"count": len(retriever.documents)}, "知识库已重载")
|
||||
Reference in New Issue
Block a user