Files
c/backend/app/rag/local_retriever.py
刘正航 b5237f9038 1
2026-04-21 22:45:19 +08:00

63 lines
2.0 KiB
Python

import json
import re
from pathlib import Path
class LocalKnowledgeRetriever:
    """Keyword-overlap retriever over a local JSON knowledge base.

    The knowledge-base file is expected to contain a JSON array of objects;
    each object may carry ``question``, ``answer``, ``tags`` and ``source``
    keys. Retrieval is purely lexical: documents are ranked by the fraction
    of query tokens (see ``_tokenize``) that also occur in the document.
    """

    # Compiled once at class creation instead of being looked up per call.
    _WORD_PATTERN = re.compile(r"[A-Za-z0-9_]+")
    _CJK_PATTERN = re.compile(r"[\u4e00-\u9fff]+")

    def __init__(self, kb_path: str):
        """Remember ``kb_path`` and load its documents immediately.

        Args:
            kb_path: Path to the JSON knowledge-base file. A missing file
                yields an empty knowledge base rather than an error.
        """
        self.kb_path = Path(kb_path)
        self.documents = []
        self._load()

    @staticmethod
    def _tokenize(text: str) -> set:
        """Split ``text`` into a set of lexical tokens.

        * Runs of ASCII letters, digits and underscores become lower-cased
          word tokens.
        * Runs of CJK Unified Ideographs (U+4E00-U+9FFF) contain no word
          boundaries, so runs of one or two characters are kept whole and
          longer runs are split into overlapping character bigrams.

        ``None`` or empty input yields an empty set.
        """
        text = str(text) if text else ""
        tokens = set(LocalKnowledgeRetriever._WORD_PATTERN.findall(text.lower()))
        for run in LocalKnowledgeRetriever._CJK_PATTERN.findall(text):
            if len(run) <= 2:
                tokens.add(run)
            else:
                tokens.update(first + second for first, second in zip(run, run[1:]))
        return tokens

    @staticmethod
    def _text_field(item: dict, key: str) -> str:
        """Return ``item[key]`` as text, mapping a missing or ``None`` value to ``""``."""
        value = item.get(key)
        # Without this, f-string formatting would turn None into the word "None",
        # which then matches queries containing "none".
        return "" if value is None else str(value)

    @staticmethod
    def _tag_list(tags) -> list:
        """Normalize a raw ``tags`` value to a list of strings for tokenization."""
        if tags is None:
            return []
        if isinstance(tags, (list, tuple, set)):
            return [str(tag) for tag in tags]
        # A single scalar tag (e.g. a bare string) is treated as one tag rather
        # than being iterated character by character.
        return [str(tags)]

    def _load(self):
        """(Re)read the knowledge-base file into ``self.documents``.

        A missing file results in an empty list. A top-level JSON value that
        is not an array is also treated as empty, and array entries that are
        not JSON objects are dropped because ``search`` reads them via
        ``.get``. Malformed JSON propagates ``json.JSONDecodeError``.
        """
        if not self.kb_path.exists():
            self.documents = []
            return
        # utf-8-sig also accepts files that begin with a UTF-8 byte-order mark.
        with self.kb_path.open("r", encoding="utf-8-sig") as file:
            rows = json.load(file)
        if not isinstance(rows, list):
            self.documents = []
            return
        self.documents = [row for row in rows if isinstance(row, dict)]

    def reload(self):
        """Re-read the knowledge-base file, replacing the in-memory documents."""
        self._load()

    def search(self, query: str, top_k: int = 3) -> list:
        """Return up to ``top_k`` documents ranked by token overlap with ``query``.

        Args:
            query: Free-text query; tokenized with ``_tokenize``.
            top_k: Maximum number of hits to return. ``0`` or a negative
                value returns no hits; ``None`` returns every matching document.

        Returns:
            A list of dicts with keys ``score`` (fraction of query tokens found
            in the document, rounded to 4 decimals), ``question``, ``answer``,
            ``tags`` and ``source``, sorted by descending score. Documents
            sharing no token with the query are omitted; ties keep
            knowledge-base order because ``list.sort`` is stable.
        """
        if top_k is not None and top_k <= 0:
            # A negative slice bound would otherwise silently drop hits from
            # the end of the ranking instead of returning nothing.
            return []
        query_tokens = self._tokenize(query)
        if not query_tokens:
            return []

        scored = []
        for item in self.documents:
            if not isinstance(item, dict):
                # documents may have been assigned directly, bypassing _load's filter.
                continue
            content = " ".join(
                [
                    self._text_field(item, "question"),
                    self._text_field(item, "answer"),
                    " ".join(self._tag_list(item.get("tags"))),
                ]
            )
            overlap = len(query_tokens & self._tokenize(content))
            if overlap == 0:
                continue
            score = overlap / len(query_tokens)  # query_tokens is non-empty here
            scored.append(
                {
                    "score": round(score, 4),
                    "question": item.get("question", ""),
                    "answer": item.get("answer", ""),
                    "tags": item.get("tags", []),
                    "source": item.get("source", "本地知识库"),
                }
            )
        scored.sort(key=lambda hit: hit["score"], reverse=True)
        return scored[:top_k]