63 lines
2.0 KiB
Python
63 lines
2.0 KiB
Python
import json
|
|
import re
|
|
from pathlib import Path
|
|
|
|
|
|
class LocalKnowledgeRetriever:
    """Keyword-overlap retriever over a local JSON knowledge base.

    The knowledge base is a JSON file whose top-level value is a list of
    objects.  ``search`` reads the ``question``, ``answer``, ``tags`` and
    ``source`` keys of each object; every key is optional.

    Matching is purely lexical: the query and each document are reduced to
    token sets (see ``_tokenize``) and a document's score is the fraction of
    query tokens that it contains.
    """

    # Runs of ASCII letters, digits and underscores form one "word" token.
    _WORD_PATTERN = re.compile(r"[A-Za-z0-9_]+")
    # Runs of characters in the CJK Unified Ideographs block (U+4E00-U+9FFF).
    _CJK_PATTERN = re.compile(r"[\u4e00-\u9fff]+")

    def __init__(self, kb_path: str):
        """Remember *kb_path* and load the knowledge base from it immediately.

        Args:
            kb_path: Path of the JSON knowledge-base file.

        Raises:
            json.JSONDecodeError: If the file exists but is not valid JSON.
            OSError: If the file exists but cannot be opened or read.
        """
        self.kb_path = Path(kb_path)
        self.documents = []
        self._load()

    @staticmethod
    def _tokenize(text: str) -> set:
        """Split *text* into a set of matching tokens.

        ASCII words are lower-cased and kept whole.  Each run of CJK
        ideographs is kept whole when it is one or two characters long and
        is otherwise broken into overlapping two-character pieces, so that
        a query can match in the middle of a longer run.

        Args:
            text: Text to tokenize; ``None`` or an empty string yields an
                empty set.

        Returns:
            The set of tokens found in *text*.
        """
        text = text or ""
        tokens = set(LocalKnowledgeRetriever._WORD_PATTERN.findall(text.lower()))
        for chunk in LocalKnowledgeRetriever._CJK_PATTERN.findall(text):
            if len(chunk) <= 2:
                tokens.add(chunk)
            else:
                tokens.update(
                    chunk[start : start + 2] for start in range(len(chunk) - 1)
                )
        return tokens

    @staticmethod
    def _document_text(item: dict) -> str:
        """Build the searchable text of one knowledge-base entry.

        Missing or ``None`` fields contribute nothing (rather than the
        literal text "None"), a string-valued ``tags`` is treated as a single
        tag instead of being split into characters, and non-string values
        are converted with ``str``.
        """

        def as_text(value) -> str:
            return "" if value is None else str(value)

        tags = item.get("tags")
        if tags is None:
            tag_parts = []
        elif isinstance(tags, (list, tuple, set, frozenset)):
            tag_parts = [as_text(tag) for tag in tags]
        else:
            tag_parts = [as_text(tags)]

        parts = [as_text(item.get("question")), as_text(item.get("answer"))]
        parts.extend(tag_parts)
        return " ".join(parts)

    def _load(self) -> None:
        """Read ``self.kb_path`` into ``self.documents``.

        A missing file, or a file whose top-level JSON value is not a list,
        results in an empty document list.  Malformed JSON is not swallowed:
        the ``json.JSONDecodeError`` propagates to the caller.
        """
        if not self.kb_path.exists():
            self.documents = []
            return
        # "utf-8-sig" transparently skips a leading byte-order mark if present.
        with self.kb_path.open("r", encoding="utf-8-sig") as file:
            rows = json.load(file)
        self.documents = rows if isinstance(rows, list) else []

    def reload(self) -> None:
        """Re-read the knowledge base from disk, replacing ``self.documents``."""
        self._load()

    def search(self, query: str, top_k: int = 3) -> list:
        """Return the entries that share the most tokens with *query*.

        Args:
            query: Free-text query.
            top_k: Maximum number of results.  Zero or a negative value
                yields no results; ``None`` returns every matching entry.

        Returns:
            A list of dicts with the keys ``score``, ``question``,
            ``answer``, ``tags`` and ``source``, ordered by descending
            score.  ``score`` is the fraction of query tokens present in
            the entry, rounded to four decimals.  Entries with equal scores
            keep their knowledge-base order.  Entries that are not JSON
            objects, or that share no token with the query, are skipped.
        """
        if top_k is not None and top_k <= 0:
            return []

        query_tokens = self._tokenize(query)
        if not query_tokens:
            return []

        scored = []
        for item in self.documents:
            # A hand-edited knowledge base may contain stray non-object rows.
            if not isinstance(item, dict):
                continue
            doc_tokens = self._tokenize(self._document_text(item))
            overlap = len(query_tokens & doc_tokens)
            if overlap == 0:
                continue
            score = overlap / len(query_tokens)
            scored.append(
                {
                    "score": round(score, 4),
                    "question": item.get("question", ""),
                    "answer": item.get("answer", ""),
                    "tags": item.get("tags", []),
                    "source": item.get("source", "本地知识库"),
                }
            )

        # list.sort is stable, so ties stay in knowledge-base order.
        scored.sort(key=lambda entry: entry["score"], reverse=True)
        return scored[:top_k]
|