178 lines
5.9 KiB
Python
178 lines
5.9 KiB
Python
from flask import Blueprint, current_app, request
|
||
from flask_jwt_extended import jwt_required
|
||
|
||
from app.rag.llm_client import LLMConfigManager, ask_dify, ask_fastgpt
|
||
from app.rag.local_retriever import LocalKnowledgeRetriever
|
||
from app.utils.auth import admin_required, current_user
|
||
from app.utils.response import ok
|
||
|
||
|
||
# Blueprint holding the knowledge-base search/reload, Q&A and advice routes defined in this module.
qa_bp = Blueprint("qa", __name__)
|
||
|
||
|
||
def _retriever() -> LocalKnowledgeRetriever:
    """Construct a local knowledge-base retriever for the app's ``LOCAL_KB_PATH``.

    A new ``LocalKnowledgeRetriever`` instance is created on every call.
    """
    kb_path = current_app.config["LOCAL_KB_PATH"]
    return LocalKnowledgeRetriever(kb_path)
|
||
|
||
|
||
def _llm_config() -> dict:
    """Load the LLM provider settings through ``LLMConfigManager``.

    The manager is pointed at the app's ``LLM_CONFIG_PATH`` setting and a new
    manager is built on every call.
    """
    return LLMConfigManager(current_app.config["LLM_CONFIG_PATH"]).load()
|
||
|
||
|
||
def _default_provider(config: dict) -> str:
|
||
provider = (config.get("active_provider") or "fastgpt").strip().lower() if isinstance(config, dict) else "fastgpt"
|
||
return provider if provider in {"fastgpt", "dify"} else "fastgpt"
|
||
|
||
|
||
def _context_text(hits: list) -> str:
|
||
chunks = []
|
||
for idx, item in enumerate(hits, start=1):
|
||
chunks.append(f"[{idx}] 问: {item['question']}\n答: {item['answer']}")
|
||
return "\n\n".join(chunks)
|
||
|
||
|
||
def _local_answer(hits: list) -> str:
|
||
if not hits:
|
||
return "本地知识库暂未检索到高相关内容,建议补充问题细节。"
|
||
top = hits[0]
|
||
return f"根据本地知识库:{top['answer']}"
|
||
|
||
|
||
def _advice_template(advice_type: str) -> str:
|
||
templates = {
|
||
"gain_muscle": "增肌建议:每天蛋白质建议 1.6-2.2g/kg,训练后 30 分钟内补充蛋白+碳水,优先鸡胸肉/牛肉/鸡蛋/燕麦。",
|
||
"lose_fat": "减脂建议:保持 10%-20% 轻热量缺口,优先高蛋白高纤维食物,减少高糖饮料和深夜加餐。",
|
||
"keto": "生酮建议:控制碳水通常小于 50g/天,优先健康脂肪和适量蛋白,并关注电解质补充。",
|
||
"nutritionist": "营养师建议:结合体重、体脂、活动量和目标制定周计划,每 7 天复盘体重和围度变化。",
|
||
"general": "通用建议:保证规律作息、三餐结构稳定、适量运动和充足饮水。",
|
||
}
|
||
return templates.get(advice_type, templates["general"])
|
||
|
||
|
||
def _ask_by_provider(provider: str, config: dict, question: str, context: str, user_id: str, username: str):
    """Send ``question`` plus ``context`` to the selected LLM backend.

    ``"dify"`` is routed to ``ask_dify`` with ``username`` as its ``user``;
    every other provider value is routed to ``ask_fastgpt`` with ``user_id``
    as ``custom_uid``. The backend's result is returned unchanged.
    """
    if provider != "dify":
        return ask_fastgpt(config, question, context=context, custom_uid=user_id)
    return ask_dify(config, question, context=context, user=username)
|
||
|
||
|
||
@qa_bp.get("/kb/search")
@jwt_required(optional=True)
def search_kb():
    """Search the local knowledge base (authentication optional).

    Query parameters:
        query: search text, whitespace-trimmed; defaults to ``""``.
        top_k: maximum number of hits, clamped to the range [1, 10]. A missing,
            empty or non-integer value falls back to 5 rather than raising.

    Returns:
        ``ok({"items": hits})`` with the retriever's hits.
    """
    query = (request.args.get("query") or "").strip()
    # MultiDict.get returns ``default`` when the ``type`` conversion raises
    # ValueError, so "?top_k=abc" no longer becomes an HTTP 500.
    requested = request.args.get("top_k", default=5, type=int)
    top_k = min(max(requested, 1), 10)

    hits = _retriever().search(query, top_k=top_k)
    return ok({"items": hits})
|
||
|
||
|
||
@qa_bp.post("/ask")
@jwt_required()
def ask_nutrition_qa():
    """Answer a nutrition question from the local knowledge base and/or an LLM.

    JSON body fields:
        question: required. A missing or blank value returns an
            ``"empty_question"`` response.
        mode: ``"local"`` answers from the knowledge base only. ``"llm"`` uses
            the LLM and, on failure, returns the local answer prefixed with a
            failure notice. Any other value (default ``"auto"``) uses the LLM
            answer when it succeeds and the plain local answer otherwise.
        provider: ``"fastgpt"`` or ``"dify"``; any other value is replaced by
            the configured default provider.

    The top 3 knowledge-base hits are always returned under ``"items"``.
    """
    payload = request.get_json(silent=True)
    # A valid JSON body may still be a list/number/string; treat it as empty.
    if not isinstance(payload, dict):
        payload = {}

    # str() guards against non-string JSON values, which would fail on .strip().
    question = str(payload.get("question") or "").strip()
    if not question:
        return ok({"answer": "请输入问题", "items": []}, "empty_question")

    mode = str(payload.get("mode") or "auto").strip().lower()

    hits = _retriever().search(question, top_k=3)
    local_answer = _local_answer(hits)
    if mode == "local":
        # Knowledge-base-only answers do not depend on the LLM configuration.
        return ok({"answer": local_answer, "provider": "local_kb", "items": hits})

    llm_cfg = _llm_config()
    provider = str(payload.get("provider") or "").strip().lower()
    if provider not in {"fastgpt", "dify"}:
        provider = _default_provider(llm_cfg)

    user = current_user()
    llm_result = _ask_by_provider(
        provider,
        llm_cfg,
        question,
        _context_text(hits),
        str(user.id) if user else "",
        user.username if user else "anonymous",
    )

    if llm_result.get("ok"):
        return ok({"answer": llm_result.get("answer"), "provider": provider, "items": hits})

    # NOTE(review): llm_result is echoed to the client verbatim; confirm it never
    # carries sensitive provider details (keys, internal URLs).
    if mode == "llm":
        return ok(
            {
                "answer": f"LLM 调用失败,已回退本地知识库。{local_answer}",
                "provider": "local_kb_fallback",
                "items": hits,
                "llm_error": llm_result,
            }
        )

    return ok(
        {
            "answer": local_answer,
            "provider": "local_kb",
            "items": hits,
            "llm_error": llm_result,
        }
    )
|
||
|
||
|
||
@qa_bp.post("/advice")
@jwt_required()
def advice():
    """Return personalised diet advice for an advice type.

    JSON body fields:
        advice_type: template key such as ``"gain_muscle"``, ``"lose_fat"``,
            ``"keto"`` or ``"nutritionist"``. A missing or blank value becomes
            ``"general"``; unknown values still use the general template.
        question: optional extra detail from the user. A missing or blank value
            becomes a generic request.
        provider: ``"fastgpt"`` or ``"dify"``; any other value is replaced by
            the configured default provider.

    The LLM is prompted with the advice template plus the top 3 knowledge-base
    hits as context. If the LLM call fails, the local template text is
    returned with a notice, and ``"provider"`` is set to ``"local_template"``.
    """
    payload = request.get_json(silent=True)
    # A valid JSON body may still be a list/number/string; treat it as empty.
    if not isinstance(payload, dict):
        payload = {}

    # Strip before defaulting so blank values ("   ") still get the default;
    # str() guards against non-string JSON values.
    advice_type = str(payload.get("advice_type") or "").strip().lower() or "general"
    question = str(payload.get("question") or "").strip() or "请根据我的状态给建议"

    llm_cfg = _llm_config()
    provider = str(payload.get("provider") or "").strip().lower()
    if provider not in {"fastgpt", "dify"}:
        provider = _default_provider(llm_cfg)

    hits = _retriever().search(f"{advice_type} {question}", top_k=3)
    base_text = _advice_template(advice_type)

    user = current_user()
    llm_result = _ask_by_provider(
        provider,
        llm_cfg,
        f"请给出{advice_type}方向的个性化饮食建议。用户补充: {question}",
        f"{base_text}\n\n{_context_text(hits)}",
        str(user.id) if user else "",
        user.username if user else "anonymous",
    )

    if llm_result.get("ok"):
        answer = llm_result.get("answer")
        source = provider
    else:
        answer = f"{base_text}\n\n(提示:LLM 暂不可用,当前为本地模板建议)"
        source = "local_template"

    # NOTE(review): llm_result is echoed to the client verbatim; confirm it never
    # carries sensitive provider details (keys, internal URLs).
    return ok(
        {
            "advice_type": advice_type,
            "answer": answer,
            "provider": source,
            "knowledge_hits": hits,
            "llm_status": llm_result,
        }
    )
|
||
|
||
|
||
@qa_bp.post("/kb/reload")
@admin_required
def reload_kb():
    """Admin-only: reload the local knowledge base and report its document count.

    NOTE(review): ``_retriever()`` constructs a fresh ``LocalKnowledgeRetriever``
    on each call, so this reload only affects other requests if the retriever
    shares state across instances — confirm in ``app.rag.local_retriever``.
    """
    kb = _retriever()
    kb.reload()
    document_count = len(kb.documents)
    return ok({"count": document_count}, "知识库已重载")
|