import time
from dataclasses import dataclass
from typing import Any, Dict, List
@dataclass
class Chunk:
    """A single retrieved passage, with provenance and a relevance score."""

    # Identifier of the source document the text came from (e.g. "doc:pricing").
    doc_id: str
    # Retrieval relevance score; higher means a better match for the query.
    score: float
    # The passage text itself.
    text: str
def normalize_query(q: str) -> str:
    """Trim *q* and collapse any run of whitespace to a single space.

    This is the hook point for further normalization: full/half-width
    folding, case folding, synonym expansion, entity canonicalization, etc.
    """
    tokens = q.strip().split()
    return " ".join(tokens)
def retrieve(q: str, topk: int = 5) -> List[Chunk]:
    """Mock retriever — swap in your BM25 / vector search here.

    Returns doc_id/score/text per hit so downstream code can trace
    exactly what was retrieved.
    """
    hits = [
        Chunk(doc_id="doc:pricing", score=0.78, text="..."),
        Chunk(doc_id="doc:limits", score=0.74, text="..."),
    ]
    return hits[:topk]
def build_prompt(q: str, chunks: List[Chunk], max_chars: int = 6000) -> str:
    """Assemble the grounded-QA prompt: instruction, question, then sources.

    Each chunk is rendered with its doc_id and score so the model (and any
    later debugging) can see exactly which sources were injected.

    Args:
        q: The (already normalized) user question.
        chunks: Retrieved passages to ground the answer on.
        max_chars: Hard upper bound on the returned prompt's length.

    Returns:
        The prompt string, guaranteed to be at most ``max_chars`` characters;
        a truncated prompt ends with the ``[TRUNCATED]`` marker.
    """
    context = "\n\n".join(
        f"[source:{c.doc_id} score={c.score:.2f}]\n{c.text}" for c in chunks
    )
    prompt = (
        "你是一个严谨的助手。只允许基于给定的 sources 回答,并在结尾列出引用。\n\n"
        f"Question:\n{q}\n\n"
        f"Sources:\n{context}\n\n"
        "Answer:\n"
    )
    marker = "\n\n[TRUNCATED]"
    if len(prompt) > max_chars:
        # BUG FIX: the original appended the marker *after* slicing to
        # max_chars, so truncated prompts exceeded the stated budget.
        # Reserve room for the marker so the result never exceeds max_chars.
        prompt = prompt[: max(0, max_chars - len(marker))] + marker
    return prompt
def rag_once(question: str) -> Dict[str, Any]:
    """Run one retrieval-augmented pass and return a traceable result dict.

    The dict carries the normalized query, retrieved doc ids/scores, prompt
    size, answer, and wall-clock latency — enough to debug "what did the
    system actually see" after the fact.
    """
    started = time.time()
    query = normalize_query(question)
    hits = retrieve(query, topk=8)
    prompt = build_prompt(query, hits, max_chars=6000)
    # Replace this stub with the real LLM call.
    answer = "(mock) ..."
    return {
        "latency_ms": int((time.time() - started) * 1000),
        "query": query,
        "top_docs": [{"doc_id": h.doc_id, "score": h.score} for h in hits],
        "prompt_chars": len(prompt),
        "answer": answer,
    }
if __name__ == "__main__":
    # Smoke-run a single question and dump the full trace dict.
    print(rag_once("你们套餐的价格和限制是什么?"))