union-py-app-stream-chat 1.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +85 -0
- package/config/config.example.yaml +60 -0
- package/knowledge/files/scenario-netunion-runtime-risk-v1.md +111 -0
- package/knowledge/files/severity-netunion-runtime-classification-v1.md +112 -0
- package/knowledge/联合运维知识库建立规范.md +272 -0
- package/main.py +46 -0
- package/package.json +21 -0
- package/pyproject.toml +23 -0
- package/requirements.txt +11 -0
- package/src/api/__init__.py +0 -0
- package/src/api/routes.py +34 -0
- package/src/core/__init__.py +0 -0
- package/src/core/config_loader.py +99 -0
- package/src/manager/toolcall_manager.py +12 -0
- package/src/models/__init__.py +0 -0
- package/src/models/schemas.py +7 -0
- package/src/service/__init__.py +0 -0
- package/src/service/chat_service.py +245 -0
- package/src/service/rag_service.py +154 -0
- package/src/utils/function_utils.py +421 -0
- package/tests/__init__.py +0 -0
- package/tests/test_chat_service.py +75 -0
- package/trainingDocs/网络支付清算平台联合运维运行实践指引V1.0.md +33 -0
package/package.json
ADDED
```diff
@@ -0,0 +1,21 @@
+{
+  "name": "union-py-app-stream-chat",
+  "version": "1.0.0",
+  "description": "Source package for the union operations stream chat Python app.",
+  "license": "UNLICENSED",
+  "private": false,
+  "files": [
+    "src",
+    "knowledge",
+    "trainingDocs",
+    "config/config.example.yaml",
+    "main.py",
+    "requirements.txt",
+    "pyproject.toml",
+    "README.md",
+    "tests"
+  ],
+  "scripts": {
+    "pack:check": "npm pack --dry-run"
+  }
+}
```
package/pyproject.toml
ADDED
```diff
@@ -0,0 +1,23 @@
+[build-system]
+requires = ["setuptools>=61.0", "wheel"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "glm-ops-assistant"
+version = "1.0.0"
+description = "Intelligent assistant for ops-platform runtime performance, built on the GLM large language model"
+requires-python = ">=3.9"
+dependencies = [
+    "fastapi>=0.110.0",
+    "uvicorn[standard]>=0.29.0",
+    "sse-starlette>=2.0.0",
+    "pydantic>=2.6.0",
+    "pyyaml>=6.0.1",
+    "zhipuai>=2.1.0",
+    "httpx>=0.23.0",
+    "sniffio>=1.3.0",
+]
+
+[tool.setuptools.packages.find]
+where = ["."]
+include = ["src*", "config*"]
```
package/requirements.txt
ADDED
package/src/api/__init__.py
ADDED
File without changes

package/src/api/routes.py
ADDED
```diff
@@ -0,0 +1,34 @@
+import json
+from fastapi import APIRouter
+from sse_starlette.sse import EventSourceResponse, ServerSentEvent
+
+from src.models.schemas import ChatRequest
+from src.manager.toolcall_manager import ToolCallManager
+
+router = APIRouter()
+
+
+@router.post("/chat/stream")
+def chat_stream_endpoint(request: ChatRequest):
+    """
+    Streaming chat endpoint (SSE).
+    Calls the LLM with tool support:
+    - receives the user question
+    - returns generated content chunk by chunk
+    """
+    manager = ToolCallManager()
+
+    def event_generator():
+        for chunk in manager.tool_call_stream(request.session_id, request.question):
+            yield ServerSentEvent(
+                event="message",
+                data=json.dumps(chunk, ensure_ascii=False),
+            )
+
+        # SSE end-of-stream marker
+        yield ServerSentEvent(
+            event="done",
+            data=json.dumps({"session_id": request.session_id, "finish_reason": "done"}, ensure_ascii=False),
+        )
+
+    return EventSourceResponse(event_generator())
```
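This SSE contract can be exercised with the package's own httpx dependency. A minimal client sketch follows; the host and the assumption that the router is mounted at the application root are illustrative, not taken from this diff.

```python
# Minimal SSE client sketch (assumes the app serves this router at
# http://localhost:8000; host and mount path are assumptions).
import json

import httpx

payload = {"session_id": "demo-1", "question": "How should a drop in system success rate be handled?"}
with httpx.stream("POST", "http://localhost:8000/chat/stream", json=payload, timeout=None) as response:
    for line in response.iter_lines():
        # sse-starlette frames each event as "event: ..." / "data: ..." lines
        if line.startswith("data: "):
            chunk = json.loads(line[len("data: "):])
            print(chunk.get("delta", ""), end="", flush=True)
```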
package/src/core/__init__.py
ADDED
File without changes

package/src/core/config_loader.py
ADDED
```diff
@@ -0,0 +1,99 @@
+import os
+import yaml
+from typing import Any, Dict
+
+
+class ConfigLoader:
+    """Configuration loader: reads settings from a YAML file, with environment-variable overrides."""
+
+    _instance = None
+    _config = None
+
+    def __new__(cls, *args, **kwargs):
+        if not cls._instance:
+            cls._instance = super().__new__(cls)
+        return cls._instance
+
+    def __init__(self, config_path: str = None):
+        if self._config is not None:
+            return
+
+        if config_path is None:
+            base_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+            config_path = os.path.join(base_dir, "config", "config.yaml")
+
+        self._config_path = config_path
+        self._config = self._load_yaml()
+        self._override_from_env()
+
+    def _load_yaml(self) -> Dict[str, Any]:
+        with open(self._config_path, "r", encoding="utf-8") as f:
+            return yaml.safe_load(f)
+
+    def _override_from_env(self):
+        """Override config via environment variables; supports LLM_API_KEY, LLM_MODEL, LLM_BASE_URL, etc."""
+        env_mappings = {
+            "LLM_API_KEY": ["llm", "api_key"],
+            "LLM_MODEL": ["llm", "model"],
+            "LLM_BASE_URL": ["llm", "base_url"],
+            "LLM_MAX_TOKENS": ["llm", "max_tokens"],
+            "LLM_TEMPERATURE": ["llm", "temperature"],
+            "SESSION_TTL": ["session", "ttl_seconds"],
+            "SESSION_MAX_HISTORY": ["session", "max_history"],
+        }
+
+        for env_var, path in env_mappings.items():
+            value = os.environ.get(env_var)
+            if value is not None:
+                # Attempt numeric conversion
+                try:
+                    if "." in value:
+                        value = float(value)
+                    else:
+                        value = int(value)
+                except ValueError:
+                    pass
+                self._set_nested_value(self._config, path, value)
+
+    def _set_nested_value(self, config: Dict, path: list, value: Any):
+        current = config
+        for key in path[:-1]:
+            current = current.setdefault(key, {})
+        current[path[-1]] = value
+
+    def get(self, *keys: str, default: Any = None) -> Any:
+        current = self._config
+        for key in keys:
+            if isinstance(current, dict) and key in current:
+                current = current[key]
+            else:
+                return default
+        return current
+
+    @property
+    def llm(self) -> Dict[str, Any]:
+        return self._config.get("llm", {})
+
+    @property
+    def system(self) -> Dict[str, Any]:
+        return self._config.get("system", {})
+
+    @property
+    def filter_config(self) -> Dict[str, Any]:
+        return self._config.get("filter", {})
+
+    @property
+    def session(self) -> Dict[str, Any]:
+        return self._config.get("session", {})
+
+    @property
+    def tools(self) -> Dict[str, Any]:
+        return self._config.get("tools", {})
+
+    @property
+    def raw(self) -> Dict[str, Any]:
+        return self._config
+
+
+# Global config instance
+config = ConfigLoader()
```
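Because the module-level `config = ConfigLoader()` runs at import time, environment overrides must be set before the first import. A minimal sketch, assuming config/config.example.yaml has been copied to config/config.yaml at the package root (all values here are hypothetical):

```python
# Env vars must be set before the first import of the module, since the
# singleton is built at import time.
import os

os.environ["LLM_TEMPERATURE"] = "0.2"      # contains ".", coerced to float
os.environ["SESSION_MAX_HISTORY"] = "10"   # coerced to int

from src.core.config_loader import config

assert config.get("llm", "temperature") == 0.2
assert config.get("session", "max_history") == 10
print(config.get("llm", "missing_key", default="fallback"))  # -> "fallback"
```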
package/src/manager/toolcall_manager.py
ADDED
```diff
@@ -0,0 +1,12 @@
+"""Call the model via the official SDK, integrating tool-calling capability."""
+from src.utils.function_utils import tools
+from src.service.chat_service import chat_service
+
+
+class ToolCallManager:
+    def __init__(self):
+        self._chat_service = chat_service
+
+    def tool_call_stream(self, session_id: str, question: str):
+        """Invoke the service layer's tool-calling method with the custom tools and stream back formatted results."""
+        yield from self._chat_service.tool_call_stream(session_id, question, tools)
```
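ToolCallManager is a thin pass-through over the shared chat_service singleton, so it can also be driven directly, outside FastAPI. A minimal sketch, assuming config and the tools module import cleanly:

```python
# Direct use of the manager, bypassing the HTTP layer (illustrative only).
from src.manager.toolcall_manager import ToolCallManager

manager = ToolCallManager()
for chunk in manager.tool_call_stream("demo-1", "How is a minor fault graded?"):
    # Every chunk carries session_id and delta; "type" is present only on
    # intermediate events (reasoning/content/tool_call/tool_result).
    print(chunk.get("type", chunk.get("finish_reason")), chunk["delta"][:60])
```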
package/src/models/__init__.py
ADDED
File without changes
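src/models/schemas.py (+7 lines) is not shown in this extract. Based on how routes.py reads request.session_id and request.question, a minimal sketch of what it likely contains (the field types are assumptions):

```python
# Hypothetical reconstruction of src/models/schemas.py; only the two field
# names are confirmed by routes.py, the types are assumptions.
from pydantic import BaseModel


class ChatRequest(BaseModel):
    session_id: str
    question: str
```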
package/src/service/__init__.py
ADDED
File without changes

package/src/service/chat_service.py
ADDED
```diff
@@ -0,0 +1,245 @@
+import threading
+import time
+from typing import Dict, Generator, List
+
+from zai import ZhipuAiClient
+
+from src.core.config_loader import config
+from src.utils.function_utils import call_function
+from src.service.rag_service import RagService
+
+
+TOOL_ROUTING_PROMPT = """
+你正在处理网络支付清算平台联合运维客服问题。选择工具前先识别用户真实业务场景:
+1. 提到变更、投产、升级、发布、回滚、灾备、演练、通知、报备、关闭渠道,优先判断为生产变更,使用 evaluate_change_strategy。
+2. 提到系统成功率、业务成功率、失败笔数、耗时、请求量、带宽、异常交易、降低流量、暂停交易、恢复,优先判断为生产运行,使用 evaluate_operation_scene。
+3. 提到轻微/一般/严重故障、故障级别、异常交易笔数、持续时间、定级,使用 evaluate_fault_grade;只问账务调整时限时使用 get_fault_aftercare。
+4. 提到银行/支付机构/成员单位类别、大型/中型/小型单位,使用 classify_member_unit。
+5. 提到制度概念、业务范围、恢复策略、周期评价等解释类问题,使用 get_ops_guidance。
+6. 用户未提供必要指标时,不要编造数值;先说明缺少哪些指标,再给出可确定的规则和下一步需要补充的信息。
+工具返回结果是业务依据。最终回答应结合用户原问题解释结论、触发依据、建议动作,并明确引用《网络支付清算平台联合运维运行实践指引V1.0》。
+""".strip()
+
+
+RAG_PROMPT = """
+请优先依据【知识库检索结果】回答用户问题。
+如果知识库没有明确依据,请说明缺少依据,不要编造规则。
+涉及阈值时保留单位、比较符和适用对象;涉及生产操作时只给建议,不自动执行。
+
+【知识库检索结果】
+{context}
+
+【用户问题】
+{question}
+""".strip()
+
+
+class ChatService:
+    """
+    Chat service layer.
+    Combines a tool-calling streaming LLM with business-domain filtering.
+    """
+
+    def __init__(self):
+        # ---- Filter config ----
+        self._filter_enabled = config.filter_config.get("enabled", True)
+        self._allowed_keywords = config.filter_config.get("allowed_keywords", [])
+        self._rejection_message = config.filter_config.get(
+            "rejection_message",
+            "抱歉,我只能回答与系统运行表现相关的问题。"
+        )
+
+        # ---- LLM config ----
+        llm_cfg = config.llm
+        self._client = ZhipuAiClient(
+            api_key=llm_cfg.get("api_key"),
+            base_url=llm_cfg.get("base_url"),
+        )
+        self._model = llm_cfg.get("model", "glm-4-flash")
+        self._max_tokens = llm_cfg.get("max_tokens", 4096)
+        self._temperature = llm_cfg.get("temperature", 0.7)
+        self._top_p = llm_cfg.get("top_p", 0.9)
+        self._system_prompt = config.system.get("system_prompt", "")
+        self._rag = RagService()
+
+        # ---- Conversation context ----
+        self._sessions: Dict[str, Dict] = {}
+        self._max_history = config.session.get("max_history", 20)
+        self._ttl = config.session.get("ttl_seconds", 3600)
+        self._lock = threading.Lock()
+
+    # ========== Filtering ==========
+
+    def _check_question_valid(self, question: str) -> bool:
+        if not self._filter_enabled or not self._allowed_keywords:
+            return True
+        lower_question = question.lower()
+        return any(k.lower() in lower_question for k in self._allowed_keywords)
+
+    # ========== Conversation context ==========
+
+    def _cleanup_expired(self):
+        now = time.time()
+        expired = [sid for sid, s in self._sessions.items() if now - s["last_active"] > self._ttl]
+        for sid in expired:
+            del self._sessions[sid]
+
+    def _ensure_session(self, session_id: str):
+        if session_id not in self._sessions:
+            self._sessions[session_id] = {
+                "messages": [],
+                "created_at": time.time(),
+                "last_active": time.time(),
+            }
+
+    def _get_history(self, session_id: str) -> List[Dict[str, str]]:
+        with self._lock:
+            self._cleanup_expired()
+            self._ensure_session(session_id)
+            return list(self._sessions[session_id]["messages"])
+
+    def _append_exchange(self, session_id: str, user_question: str, assistant_answer: str):
+        with self._lock:
+            self._ensure_session(session_id)
+            session = self._sessions[session_id]
+            session["messages"].extend([
+                {"role": "user", "content": user_question},
+                {"role": "assistant", "content": assistant_answer},
+            ])
+            session["last_active"] = time.time()
+            max_messages = self._max_history * 2
+            if len(session["messages"]) > max_messages:
+                session["messages"] = session["messages"][-max_messages:]
+
+    # ========== LLM ==========
+
+    def _build_messages(
+        self,
+        session_id: str,
+        user_question: str,
+        rag_context: str = "",
+    ) -> List[Dict[str, str]]:
+        context = rag_context
+        user_content = RAG_PROMPT.format(context=context, question=user_question) if context else user_question
+        messages = []
+        if self._system_prompt:
+            messages.append({"role": "system", "content": self._system_prompt})
+        messages.extend(self._get_history(session_id))
+        messages.append({"role": "user", "content": user_content})
+        return messages
+
+    def _build_tool_messages(self, session_id: str, user_question: str) -> List[Dict[str, str]]:
+        context, _ = self._rag.search(user_question)
+        messages = self._build_messages(session_id, user_question, context)
+        insert_at = 1 if messages and messages[0].get("role") == "system" else 0
+        messages.insert(insert_at, {"role": "system", "content": TOOL_ROUTING_PROMPT})
+        return messages
+
+    def tool_call_stream(self, session_id: str, question: str, tools) -> Generator[dict, None, None]:
+        """
+        Streaming chat with tool calling (supports interleaved reasoning and tool calls).
+        - stream=True + tool_stream=True: the model streams reasoning, answer content, and tool calls together
+        - tool results are fed back to the model and generation continues, looping until the model stops calling tools or the round limit is reached
+        """
+        if not self._check_question_valid(question):
+            yield {"session_id": session_id, "delta": self._rejection_message, "finish_reason": "rejected"}
+            return
+
+        def event(delta: str, etype: str) -> dict:
+            return {"session_id": session_id, "delta": delta, "type": etype, "finish_reason": None}
+
+        try:
+            messages = self._build_tool_messages(session_id, question)
+            max_rounds = config.tools.get("max_rounds", 5)
+            final_answer = ""
+
+            for _ in range(max_rounds):
+                response = self._client.chat.completions.create(
+                    model=self._model,
+                    messages=messages,
+                    tools=tools,
+                    tool_choice="auto",
+                    stream=True,
+                    tool_stream=True,
+                    max_tokens=self._max_tokens,
+                    temperature=self._temperature,
+                    top_p=self._top_p,
+                )
+
+                current_content = ""
+                tool_calls_map: Dict[int, Dict] = {}
+
+                for chunk in response:
+                    if not chunk.choices:
+                        continue
+                    delta = chunk.choices[0].delta
+
+                    reasoning = getattr(delta, "reasoning_content", None)
+                    if reasoning:
+                        yield event(reasoning, "reasoning")
+
+                    content = getattr(delta, "content", None)
+                    if content:
+                        current_content += content
+                        yield event(content, "content")
+
+                    for tc in getattr(delta, "tool_calls", None) or []:
+                        self._merge_tool_call_delta(tool_calls_map, tc)
+
+                if not tool_calls_map:
+                    final_answer = current_content
+                    break
+
+                assistant_tool_calls = [tool_calls_map[i] for i in sorted(tool_calls_map)]
+                messages.append({
+                    "role": "assistant",
+                    "content": current_content or None,
+                    "tool_calls": assistant_tool_calls,
+                })
+
+                for tc in assistant_tool_calls:
+                    name = tc["function"]["name"]
+                    args = tc["function"]["arguments"]
+                    yield event(f"\n[调用工具: {name}({args})]\n", "tool_call")
+
+                    result = call_function(name, args)
+                    yield event(result, "tool_result")
+
+                    messages.append({
+                        "role": "tool",
+                        "content": result,
+                        "tool_call_id": tc["id"],
+                    })
+            else:
+                final_answer = current_content or "[系统提示: 工具调用轮次已达上限]"
+                yield event(final_answer, "content")
+
+            self._append_exchange(session_id, question, final_answer)
+            yield {"session_id": session_id, "delta": "", "finish_reason": "stop"}
+
+        except Exception as e:
+            yield {
+                "session_id": session_id,
+                "delta": f"[错误] 模型调用异常: {str(e)}",
+                "finish_reason": "error",
+            }
+
+    @staticmethod
+    def _merge_tool_call_delta(tool_calls_map: Dict[int, Dict], tc) -> None:
+        """Merge a single streamed tool_call delta into the accumulator dict, keyed by index."""
+        slot = tool_calls_map.setdefault(tc.index, {
+            "id": "",
+            "type": "function",
+            "function": {"name": "", "arguments": ""},
+        })
+        if tc.id:
+            slot["id"] = tc.id
+        fn = getattr(tc, "function", None)
+        if fn is not None:
+            if getattr(fn, "name", None):
+                slot["function"]["name"] += fn.name
+            if getattr(fn, "arguments", None):
+                slot["function"]["arguments"] += fn.arguments
+
+
+chat_service = ChatService()
```
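_merge_tool_call_delta reassembles tool calls that arrive split across stream chunks: the id arrives once, while name and arguments may be concatenated from several fragments. A minimal sketch with stand-in delta objects (SimpleNamespace substitutes for the SDK's delta types; importing the module assumes a valid config.yaml, since it builds the chat_service singleton at import time):

```python
# Demonstrates fragment accumulation with fake stream deltas.
from types import SimpleNamespace

from src.service.chat_service import ChatService

def fake_delta(index, id="", name=None, arguments=None):
    return SimpleNamespace(index=index, id=id,
                           function=SimpleNamespace(name=name, arguments=arguments))

calls = {}
ChatService._merge_tool_call_delta(calls, fake_delta(0, id="call_1", name="evaluate_fault_grade"))
ChatService._merge_tool_call_delta(calls, fake_delta(0, arguments='{"duration_'))
ChatService._merge_tool_call_delta(calls, fake_delta(0, arguments='minutes": 45}'))

assert calls[0]["id"] == "call_1"
assert calls[0]["function"]["name"] == "evaluate_fault_grade"
assert calls[0]["function"]["arguments"] == '{"duration_minutes": 45}'
```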
package/src/service/rag_service.py
ADDED
```diff
@@ -0,0 +1,154 @@
+import hashlib
+import re
+from pathlib import Path
+from typing import Dict, List, Tuple
+
+import yaml
+from zai import ZhipuAiClient
+
+from src.core.config_loader import config
+
+
+class RagService:
+    """Lightweight RAG service: loads the Markdown knowledge base, writes it into Chroma, and retrieves by question."""
+
+    def __init__(self):
+        self._cfg = config.get("rag", default={})
+        self._enabled = self._cfg.get("enabled", False)
+        self._top_k = self._cfg.get("top_k", 5)
+        self._collection = None
+        self._client = ZhipuAiClient(
+            api_key=config.llm.get("api_key"),
+            base_url=config.llm.get("base_url"),
+        )
+        if self._enabled:
+            self._init_collection()
+
+    def search(self, question: str) -> Tuple[str, List[Dict]]:
+        if not self._collection:
+            return "", []
+
+        try:
+            if self._collection.count() == 0:
+                self.rebuild()
+            if self._collection.count() == 0:
+                return "", []
+            result = self._collection.query(
+                query_embeddings=[self._embed(question)],
+                n_results=self._top_k,
+                include=["documents", "metadatas"],
+            )
+        except Exception:
+            self._collection = None
+            return "", []
+        docs = result.get("documents", [[]])[0]
+        metas = result.get("metadatas", [[]])[0]
+        sources = [self._source(m) for m in metas]
+        context = "\n\n".join(
+            f"[{i + 1}] 标题:{m.get('title', '')}\n"
+            f"来源:{m.get('source_doc', '')} {m.get('source_section', '')}\n"
+            f"内容:{doc}"
+            for i, (doc, m) in enumerate(zip(docs, metas))
+        )
+        return context, sources
+
+    def _init_collection(self):
+        try:
+            import chromadb
+        except ImportError:
+            return
+
+        root = Path(__file__).resolve().parents[2]
+        persist_dir = root / self._cfg.get("persist_dir", ".chroma")
+        self._knowledge_dir = root / self._cfg.get("knowledge_dir", "knowledge")
+        chroma = chromadb.PersistentClient(path=str(persist_dir))
+        self._collection = chroma.get_or_create_collection(self._cfg.get("collection", "ops_knowledge"))
+        try:
+            if self._cfg.get("rebuild_on_startup", False):
+                self.rebuild()
+        except Exception:
+            self._collection = None
+
+    def rebuild(self):
+        docs = self._load_documents()
+        if not docs:
+            return
+        self._collection.upsert(
+            ids=[d["id"] for d in docs],
+            documents=[d["content"] for d in docs],
+            metadatas=[d["metadata"] for d in docs],
+            embeddings=[self._embed(d["content"]) for d in docs],
+        )
+
+    def _load_documents(self) -> List[Dict]:
+        docs = []
+        for path in self._knowledge_dir.rglob("*.md"):
+            metadata, body = self._read_markdown(path)
+            if metadata.get("status", "active") != "active":
+                continue
+            for idx, chunk in enumerate(self._split(body)):
+                item_meta = self._clean_metadata({**metadata, "file_path": str(path), "chunk_index": idx})
+                docs.append({
+                    "id": self._chunk_id(path, idx, chunk),
+                    "content": chunk,
+                    "metadata": item_meta,
+                })
+        return docs
+
+    @staticmethod
+    def _read_markdown(path: Path) -> Tuple[Dict, str]:
+        text = path.read_text(encoding="utf-8")
+        match = re.match(r"^---\n(.*?)\n---\n(.*)$", text, re.S)
+        if not match:
+            return {"title": path.stem, "status": "active"}, text
+        return yaml.safe_load(match.group(1)) or {}, match.group(2).strip()
+
+    def _split(self, text: str) -> List[str]:
+        max_chars = self._cfg.get("chunk_size", 1200)
+        parts = re.split(r"\n(?=##\s+)", text)
+        chunks = []
+        for part in parts:
+            part = part.strip()
+            if not part:
+                continue
+            if len(part) <= max_chars:
+                chunks.append(part)
+            else:
+                chunks.extend(part[i:i + max_chars] for i in range(0, len(part), max_chars))
+        return chunks
+
+    def _embed(self, text: str) -> List[float]:
+        response = self._client.embeddings.create(
+            model=self._cfg.get("embedding_model", "embedding-3"),
+            input=text[: self._cfg.get("embedding_max_chars", 6000)],
+        )
+        item = response.data[0]
+        return item["embedding"] if isinstance(item, dict) else item.embedding
+
+    @staticmethod
+    def _clean_metadata(metadata: Dict) -> Dict:
+        cleaned = {}
+        for key, value in metadata.items():
+            if isinstance(value, list):
+                cleaned[key] = ",".join(map(str, value))
+            elif value is None:
+                cleaned[key] = ""
+            elif isinstance(value, (str, int, float, bool)):
+                cleaned[key] = value
+            else:
+                cleaned[key] = str(value)
+        return cleaned
+
+    @staticmethod
+    def _chunk_id(path: Path, idx: int, chunk: str) -> str:
+        digest = hashlib.md5(f"{path}:{idx}:{chunk}".encode("utf-8")).hexdigest()
+        return digest
+
+    @staticmethod
+    def _source(metadata: Dict) -> Dict:
+        return {
+            "title": metadata.get("title", ""),
+            "source_doc": metadata.get("source_doc", ""),
+            "source_section": metadata.get("source_section", ""),
+            "file_path": metadata.get("file_path", ""),
+        }
```
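The knowledge files under package/knowledge/ are expected to carry YAML front matter that _read_markdown parses and _source later surfaces as citation metadata. A minimal sketch of that layout, parsed with the same regex; the field values below are illustrative assumptions, not taken from the actual knowledge files:

```python
# Parses a sample knowledge document the way RagService._read_markdown does.
import re

import yaml

sample = """---
title: severity-netunion-runtime-classification-v1
status: active
source_doc: 网络支付清算平台联合运维运行实践指引V1.0
source_section: fault grading
---
## Severity levels

Body text; _split chunks the body on "## " headings.
"""

match = re.match(r"^---\n(.*?)\n---\n(.*)$", sample, re.S)
metadata, body = yaml.safe_load(match.group(1)), match.group(2).strip()
print(metadata["title"])     # severity-netunion-runtime-classification-v1
print(body.splitlines()[0])  # ## Severity levels
```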