union-py-app-stream-chat 1.0.0 → 1.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +20 -6
- package/package.json +1 -1
- package/src/api/routes.py +11 -7
- package/src/core/config_loader.py +61 -24
- package/src/core/logging_config.py +16 -0
- package/src/manager/toolcall_manager.py +13 -4
- package/src/models/schemas.py +15 -0
- package/src/service/chat_service.py +83 -18
- package/src/service/rag_service.py +40 -2
- package/src/utils/function_utils.py +11 -2
- package/tests/test_chat_service.py +43 -3
package/README.md
CHANGED
|
@@ -14,10 +14,10 @@ glm-ops-assistant/
|
|
|
14
14
|
│ ├── service/
|
|
15
15
|
│ │ └── chat_service.py # Service 层 - 业务逻辑(问题筛选、流程编排)
|
|
16
16
|
│ ├── manager/
|
|
17
|
-
│ │
|
|
18
|
-
│ │ └── session_manager.py # Manager 层 - 多轮会话管理
|
|
17
|
+
│ │ └── toolcall_manager.py # Manager 层 - 工具调用封装
|
|
19
18
|
│ ├── core/
|
|
20
|
-
│ │
|
|
19
|
+
│ │ ├── config_loader.py # 核心工具 - 配置加载(支持环境变量覆盖)
|
|
20
|
+
│ │ └── logging_config.py # 核心工具 - 日志配置
|
|
21
21
|
│ └── models/
|
|
22
22
|
│ └── schemas.py # Pydantic 数据模型
|
|
23
23
|
├── tests/
|
|
@@ -48,8 +48,10 @@ llm:
|
|
|
48
48
|
或通过环境变量覆盖:
|
|
49
49
|
|
|
50
50
|
```bash
|
|
51
|
+
export APP_CONFIG_PATH="/path/to/config.yaml"
|
|
51
52
|
export LLM_API_KEY="your-api-key"
|
|
52
53
|
export LLM_MODEL="glm-4-flash"
|
|
54
|
+
export RAG_ENABLED="true"
|
|
53
55
|
```
|
|
54
56
|
|
|
55
57
|
### 3. 启动服务
|
|
@@ -64,10 +66,7 @@ uvicorn main:app --host 0.0.0.0 --port 8000 --reload
|
|
|
64
66
|
|
|
65
67
|
| 方法 | 路径 | 说明 |
|
|
66
68
|
|------|------|------|
|
|
67
|
-
| POST | `/api/v1/chat` | 普通对话(非流式) |
|
|
68
69
|
| POST | `/api/v1/chat/stream` | 流式对话(SSE) |
|
|
69
|
-
| POST | `/api/v1/session/clear` | 清空会话历史 |
|
|
70
|
-
| GET | `/api/v1/session/info` | 获取会话信息 |
|
|
71
70
|
| GET | `/api/v1/health` | 健康检查 |
|
|
72
71
|
|
|
73
72
|
## 多轮对话
|
|
@@ -83,3 +82,18 @@ uvicorn main:app --host 0.0.0.0 --port 8000 --reload
|
|
|
83
82
|
## 流式输出
|
|
84
83
|
|
|
85
84
|
调用 `/api/v1/chat/stream` 接口,服务端使用 SSE 协议逐字推送模型生成内容。
|
|
85
|
+
|
|
86
|
+
每条 `message` 事件的 `data` 为 `ChatResponse` JSON:
|
|
87
|
+
|
|
88
|
+
```json
|
|
89
|
+
{
|
|
90
|
+
"session_id": "sess-xxx",
|
|
91
|
+
"content": "正式回复增量",
|
|
92
|
+
"reasoning_content": null,
|
|
93
|
+
"tool_call": null,
|
|
94
|
+
"tool_result": null,
|
|
95
|
+
"finish_reason": null
|
|
96
|
+
}
|
|
97
|
+
```
|
|
98
|
+
|
|
99
|
+
模型正式回复写入 `content`,推理内容写入 `reasoning_content`,工具过程写入 `tool_call` / `tool_result`。
|
package/package.json
CHANGED
package/src/api/routes.py
CHANGED
|
@@ -1,11 +1,16 @@
|
|
|
1
|
-
import json
|
|
2
1
|
from fastapi import APIRouter
|
|
3
2
|
from sse_starlette.sse import EventSourceResponse, ServerSentEvent
|
|
4
3
|
|
|
5
|
-
from src.models.schemas import ChatRequest
|
|
4
|
+
from src.models.schemas import ChatRequest, ChatResponse
|
|
6
5
|
from src.manager.toolcall_manager import ToolCallManager
|
|
7
6
|
|
|
8
7
|
router = APIRouter()
|
|
8
|
+
tool_call_manager = ToolCallManager()
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@router.get("/health")
|
|
12
|
+
def health_check():
|
|
13
|
+
return {"status": "ok"}
|
|
9
14
|
|
|
10
15
|
|
|
11
16
|
@router.post("/chat/stream")
|
|
@@ -16,19 +21,18 @@ def chat_stream_endpoint(request: ChatRequest):
|
|
|
16
21
|
- 接收用户问题
|
|
17
22
|
- 逐块返回生成的内容
|
|
18
23
|
"""
|
|
19
|
-
manager = ToolCallManager()
|
|
20
|
-
|
|
21
24
|
def event_generator():
|
|
22
|
-
for chunk in
|
|
25
|
+
for chunk in tool_call_manager.tool_call_stream(request.session_id, request.question):
|
|
23
26
|
yield ServerSentEvent(
|
|
24
27
|
event="message",
|
|
25
|
-
data=
|
|
28
|
+
data=chunk.model_dump_json(),
|
|
26
29
|
)
|
|
27
30
|
|
|
28
31
|
# SSE 结束标记
|
|
32
|
+
done = ChatResponse(session_id=request.session_id, finish_reason="done")
|
|
29
33
|
yield ServerSentEvent(
|
|
30
34
|
event="done",
|
|
31
|
-
data=
|
|
35
|
+
data=done.model_dump_json(),
|
|
32
36
|
)
|
|
33
37
|
|
|
34
38
|
return EventSourceResponse(event_generator())
|
|
@@ -1,26 +1,25 @@
|
|
|
1
1
|
import os
|
|
2
|
+
import threading
|
|
2
3
|
import yaml
|
|
3
4
|
from typing import Any, Dict
|
|
4
5
|
|
|
6
|
+
from src.core.logging_config import get_logger
|
|
5
7
|
|
|
6
|
-
class ConfigLoader:
|
|
7
|
-
"""配置加载器,支持从YAML文件加载配置,并可通过环境变量覆盖。"""
|
|
8
8
|
|
|
9
|
-
|
|
10
|
-
_config = None
|
|
9
|
+
logger = get_logger(__name__)
|
|
11
10
|
|
|
12
|
-
def __new__(cls, *args, **kwargs):
|
|
13
|
-
if not cls._instance:
|
|
14
|
-
cls._instance = super().__new__(cls)
|
|
15
|
-
return cls._instance
|
|
16
11
|
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
return
|
|
12
|
+
class ConfigLoader:
|
|
13
|
+
"""配置加载器,支持从YAML文件加载配置,并可通过环境变量覆盖。"""
|
|
20
14
|
|
|
15
|
+
def __init__(self, config_path: str = None):
|
|
16
|
+
base_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
21
17
|
if config_path is None:
|
|
22
|
-
|
|
23
|
-
|
|
18
|
+
config_path = os.environ.get("APP_CONFIG_PATH") or os.path.join(base_dir, "config", "config.yaml")
|
|
19
|
+
if not os.path.exists(config_path):
|
|
20
|
+
fallback_path = os.path.join(base_dir, "config", "config.example.yaml")
|
|
21
|
+
logger.warning("配置文件不存在,使用示例配置。config_path=%s fallback=%s", config_path, fallback_path)
|
|
22
|
+
config_path = fallback_path
|
|
24
23
|
|
|
25
24
|
self._config_path = config_path
|
|
26
25
|
self._config = self._load_yaml()
|
|
@@ -28,7 +27,10 @@ class ConfigLoader:
|
|
|
28
27
|
|
|
29
28
|
def _load_yaml(self) -> Dict[str, Any]:
|
|
30
29
|
with open(self._config_path, "r", encoding="utf-8") as f:
|
|
31
|
-
|
|
30
|
+
data = yaml.safe_load(f) or {}
|
|
31
|
+
if not isinstance(data, dict):
|
|
32
|
+
raise ValueError(f"配置文件格式错误,顶层必须是对象: {self._config_path}")
|
|
33
|
+
return data
|
|
32
34
|
|
|
33
35
|
def _override_from_env(self):
|
|
34
36
|
"""通过环境变量覆盖配置,支持 LLM_API_KEY, LLM_MODEL, LLM_BASE_URL 等。"""
|
|
@@ -40,21 +42,28 @@ class ConfigLoader:
|
|
|
40
42
|
"LLM_TEMPERATURE": ["llm", "temperature"],
|
|
41
43
|
"SESSION_TTL": ["session", "ttl_seconds"],
|
|
42
44
|
"SESSION_MAX_HISTORY": ["session", "max_history"],
|
|
45
|
+
"RAG_ENABLED": ["rag", "enabled"],
|
|
46
|
+
"RAG_TOP_K": ["rag", "top_k"],
|
|
43
47
|
}
|
|
44
48
|
|
|
45
49
|
for env_var, path in env_mappings.items():
|
|
46
50
|
value = os.environ.get(env_var)
|
|
47
51
|
if value is not None:
|
|
48
|
-
|
|
49
|
-
try:
|
|
50
|
-
if "." in value:
|
|
51
|
-
value = float(value)
|
|
52
|
-
else:
|
|
53
|
-
value = int(value)
|
|
54
|
-
except ValueError:
|
|
55
|
-
pass
|
|
52
|
+
value = self._parse_env_value(value)
|
|
56
53
|
self._set_nested_value(self._config, path, value)
|
|
57
54
|
|
|
55
|
+
@staticmethod
|
|
56
|
+
def _parse_env_value(value: str) -> Any:
|
|
57
|
+
lower_value = value.lower()
|
|
58
|
+
if lower_value in {"true", "false"}:
|
|
59
|
+
return lower_value == "true"
|
|
60
|
+
try:
|
|
61
|
+
if "." in value:
|
|
62
|
+
return float(value)
|
|
63
|
+
return int(value)
|
|
64
|
+
except ValueError:
|
|
65
|
+
return value
|
|
66
|
+
|
|
58
67
|
def _set_nested_value(self, config: Dict, path: list, value: Any):
|
|
59
68
|
current = config
|
|
60
69
|
for key in path[:-1]:
|
|
@@ -95,5 +104,33 @@ class ConfigLoader:
|
|
|
95
104
|
return self._config
|
|
96
105
|
|
|
97
106
|
|
|
98
|
-
|
|
99
|
-
config
|
|
107
|
+
class LazyConfig:
|
|
108
|
+
"""Import-safe config proxy; real files/env are read only on first use."""
|
|
109
|
+
|
|
110
|
+
def __init__(self):
|
|
111
|
+
self._loader = None
|
|
112
|
+
self._lock = threading.Lock()
|
|
113
|
+
|
|
114
|
+
def load(self, config_path: str = None) -> ConfigLoader:
|
|
115
|
+
with self._lock:
|
|
116
|
+
if self._loader is None or config_path is not None:
|
|
117
|
+
self._loader = ConfigLoader(config_path)
|
|
118
|
+
return self._loader
|
|
119
|
+
|
|
120
|
+
def reload(self, config_path: str = None) -> ConfigLoader:
|
|
121
|
+
with self._lock:
|
|
122
|
+
self._loader = ConfigLoader(config_path)
|
|
123
|
+
return self._loader
|
|
124
|
+
|
|
125
|
+
def __getattr__(self, name: str) -> Any:
|
|
126
|
+
return getattr(self.load(), name)
|
|
127
|
+
|
|
128
|
+
def get(self, *keys: str, default: Any = None) -> Any:
|
|
129
|
+
return self.load().get(*keys, default=default)
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
config = LazyConfig()
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
def get_config() -> ConfigLoader:
|
|
136
|
+
return config.load()
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
LOG_FORMAT = "%(asctime)s %(levelname)s [%(name)s] %(message)s"
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def get_logger(name: str) -> logging.Logger:
|
|
8
|
+
"""Return a console logger with a small, consistent default format."""
|
|
9
|
+
logger = logging.getLogger(name)
|
|
10
|
+
if not logger.handlers:
|
|
11
|
+
handler = logging.StreamHandler()
|
|
12
|
+
handler.setFormatter(logging.Formatter(LOG_FORMAT))
|
|
13
|
+
logger.addHandler(handler)
|
|
14
|
+
logger.setLevel(logging.INFO)
|
|
15
|
+
logger.propagate = False
|
|
16
|
+
return logger
|
|
@@ -1,12 +1,21 @@
|
|
|
1
1
|
"""使用官方的sdk调用模型,整合工具调用能力"""
|
|
2
|
+
from typing import Generator
|
|
3
|
+
|
|
4
|
+
from src.models.schemas import ChatResponse
|
|
2
5
|
from src.utils.function_utils import tools
|
|
3
|
-
from src.service.chat_service import
|
|
6
|
+
from src.service.chat_service import ChatService, get_chat_service
|
|
4
7
|
|
|
5
8
|
|
|
6
9
|
class ToolCallManager:
|
|
7
10
|
def __init__(self):
|
|
8
|
-
self._chat_service =
|
|
11
|
+
self._chat_service = None
|
|
12
|
+
|
|
13
|
+
@property
|
|
14
|
+
def chat_service(self) -> ChatService:
|
|
15
|
+
if self._chat_service is None:
|
|
16
|
+
self._chat_service = get_chat_service()
|
|
17
|
+
return self._chat_service
|
|
9
18
|
|
|
10
|
-
def tool_call_stream(self, session_id: str, question: str):
|
|
19
|
+
def tool_call_stream(self, session_id: str, question: str) -> Generator[ChatResponse, None, None]:
|
|
11
20
|
"""调用service层的工具调用方法,传入自定义的tools,格式化返回流式结果"""
|
|
12
|
-
yield from self.
|
|
21
|
+
yield from self.chat_service.tool_call_stream(session_id, question, tools)
|
package/src/models/schemas.py
CHANGED
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
from typing import Literal, Optional
|
|
2
|
+
|
|
1
3
|
from pydantic import BaseModel, Field
|
|
2
4
|
|
|
3
5
|
|
|
@@ -5,3 +7,16 @@ class ChatRequest(BaseModel):
|
|
|
5
7
|
"""聊天请求模型"""
|
|
6
8
|
session_id: str = Field(..., description="会话ID,用于多轮对话上下文")
|
|
7
9
|
question: str = Field(..., min_length=1, max_length=4096, description="用户问题")
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class ChatResponse(BaseModel):
|
|
13
|
+
"""聊天流式响应模型"""
|
|
14
|
+
session_id: str = Field(..., description="会话ID")
|
|
15
|
+
content: Optional[str] = Field(default=None, description="模型正式回复内容增量")
|
|
16
|
+
reasoning_content: Optional[str] = Field(default=None, description="模型推理内容增量")
|
|
17
|
+
tool_call: Optional[str] = Field(default=None, description="工具调用信息")
|
|
18
|
+
tool_result: Optional[str] = Field(default=None, description="工具执行结果")
|
|
19
|
+
finish_reason: Optional[Literal["stop", "error", "rejected", "done"]] = Field(
|
|
20
|
+
default=None,
|
|
21
|
+
description="结束原因;中间流式增量为空",
|
|
22
|
+
)
|
|
@@ -1,14 +1,23 @@
|
|
|
1
1
|
import threading
|
|
2
2
|
import time
|
|
3
|
-
from typing import Dict, Generator, List
|
|
3
|
+
from typing import Dict, Generator, List, Optional
|
|
4
4
|
|
|
5
5
|
from zai import ZhipuAiClient
|
|
6
6
|
|
|
7
7
|
from src.core.config_loader import config
|
|
8
|
+
from src.core.logging_config import get_logger
|
|
9
|
+
from src.models.schemas import ChatResponse
|
|
8
10
|
from src.utils.function_utils import call_function
|
|
9
11
|
from src.service.rag_service import RagService
|
|
10
12
|
|
|
11
13
|
|
|
14
|
+
logger = get_logger(__name__)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def _preview(text: str, limit: int = 300) -> str:
|
|
18
|
+
return str(text).replace("\n", " ")[:limit]
|
|
19
|
+
|
|
20
|
+
|
|
12
21
|
TOOL_ROUTING_PROMPT = """
|
|
13
22
|
你正在处理网络支付清算平台联合运维客服问题。选择工具前先识别用户真实业务场景:
|
|
14
23
|
1. 提到变更、投产、升级、发布、回滚、灾备、演练、通知、报备、关闭渠道,优先判断为生产变更,使用 evaluate_change_strategy。
|
|
@@ -135,25 +144,40 @@ class ChatService:
|
|
|
135
144
|
messages.insert(insert_at, {"role": "system", "content": TOOL_ROUTING_PROMPT})
|
|
136
145
|
return messages
|
|
137
146
|
|
|
138
|
-
def tool_call_stream(self, session_id: str, question: str, tools) -> Generator[
|
|
147
|
+
def tool_call_stream(self, session_id: str, question: str, tools) -> Generator[ChatResponse, None, None]:
|
|
139
148
|
"""
|
|
140
149
|
带工具调用的流式对话(支持交错思考与工具调用)
|
|
141
150
|
- stream=True + tool_stream=True:模型在流式输出中同时返回推理过程、回答内容与工具调用
|
|
142
151
|
- 工具执行结果回传模型后继续流式生成,循环直至模型不再调用工具或达到最大轮次
|
|
143
152
|
"""
|
|
144
153
|
if not self._check_question_valid(question):
|
|
145
|
-
|
|
154
|
+
logger.info("问题未通过业务过滤。session_id=%s question=%s", session_id, _preview(question, 120))
|
|
155
|
+
yield ChatResponse(
|
|
156
|
+
session_id=session_id,
|
|
157
|
+
content=self._rejection_message,
|
|
158
|
+
finish_reason="rejected",
|
|
159
|
+
)
|
|
146
160
|
return
|
|
147
161
|
|
|
148
|
-
def
|
|
149
|
-
return
|
|
162
|
+
def content_event(content: str) -> ChatResponse:
|
|
163
|
+
return ChatResponse(session_id=session_id, content=content)
|
|
164
|
+
|
|
165
|
+
def reasoning_event(reasoning_content: str) -> ChatResponse:
|
|
166
|
+
return ChatResponse(session_id=session_id, reasoning_content=reasoning_content)
|
|
167
|
+
|
|
168
|
+
def tool_call_event(tool_call: str) -> ChatResponse:
|
|
169
|
+
return ChatResponse(session_id=session_id, tool_call=tool_call)
|
|
170
|
+
|
|
171
|
+
def tool_result_event(tool_result: str) -> ChatResponse:
|
|
172
|
+
return ChatResponse(session_id=session_id, tool_result=tool_result)
|
|
150
173
|
|
|
151
174
|
try:
|
|
152
175
|
messages = self._build_tool_messages(session_id, question)
|
|
153
176
|
max_rounds = config.tools.get("max_rounds", 5)
|
|
154
177
|
final_answer = ""
|
|
155
178
|
|
|
156
|
-
|
|
179
|
+
logger.info("开始模型流式调用。session_id=%s model=%s question=%s", session_id, self._model, _preview(question, 120))
|
|
180
|
+
for round_idx in range(max_rounds):
|
|
157
181
|
response = self._client.chat.completions.create(
|
|
158
182
|
model=self._model,
|
|
159
183
|
messages=messages,
|
|
@@ -167,6 +191,7 @@ class ChatService:
|
|
|
167
191
|
)
|
|
168
192
|
|
|
169
193
|
current_content = ""
|
|
194
|
+
reasoning_len = 0
|
|
170
195
|
tool_calls_map: Dict[int, Dict] = {}
|
|
171
196
|
|
|
172
197
|
for chunk in response:
|
|
@@ -176,16 +201,27 @@ class ChatService:
|
|
|
176
201
|
|
|
177
202
|
reasoning = getattr(delta, "reasoning_content", None)
|
|
178
203
|
if reasoning:
|
|
179
|
-
|
|
204
|
+
reasoning_len += len(reasoning)
|
|
205
|
+
yield reasoning_event(reasoning)
|
|
180
206
|
|
|
181
207
|
content = getattr(delta, "content", None)
|
|
182
208
|
if content:
|
|
183
209
|
current_content += content
|
|
184
|
-
yield
|
|
210
|
+
yield content_event(content)
|
|
185
211
|
|
|
186
212
|
for tc in getattr(delta, "tool_calls", None) or []:
|
|
187
213
|
self._merge_tool_call_delta(tool_calls_map, tc)
|
|
188
214
|
|
|
215
|
+
logger.info(
|
|
216
|
+
"模型流式返回完成。session_id=%s round=%s content_chars=%s reasoning_chars=%s tool_calls=%s content_preview=%s",
|
|
217
|
+
session_id,
|
|
218
|
+
round_idx + 1,
|
|
219
|
+
len(current_content),
|
|
220
|
+
reasoning_len,
|
|
221
|
+
len(tool_calls_map),
|
|
222
|
+
_preview(current_content),
|
|
223
|
+
)
|
|
224
|
+
|
|
189
225
|
if not tool_calls_map:
|
|
190
226
|
final_answer = current_content
|
|
191
227
|
break
|
|
@@ -200,10 +236,22 @@ class ChatService:
|
|
|
200
236
|
for tc in assistant_tool_calls:
|
|
201
237
|
name = tc["function"]["name"]
|
|
202
238
|
args = tc["function"]["arguments"]
|
|
203
|
-
|
|
239
|
+
logger.info(
|
|
240
|
+
"执行工具调用。session_id=%s tool=%s args=%s",
|
|
241
|
+
session_id,
|
|
242
|
+
name,
|
|
243
|
+
_preview(args, 200),
|
|
244
|
+
)
|
|
245
|
+
yield tool_call_event(f"\n[调用工具: {name}({args})]\n")
|
|
204
246
|
|
|
205
247
|
result = call_function(name, args)
|
|
206
|
-
|
|
248
|
+
logger.info(
|
|
249
|
+
"工具调用完成。session_id=%s tool=%s result_preview=%s",
|
|
250
|
+
session_id,
|
|
251
|
+
name,
|
|
252
|
+
_preview(result, 300),
|
|
253
|
+
)
|
|
254
|
+
yield tool_result_event(result)
|
|
207
255
|
|
|
208
256
|
messages.append({
|
|
209
257
|
"role": "tool",
|
|
@@ -212,17 +260,24 @@ class ChatService:
|
|
|
212
260
|
})
|
|
213
261
|
else:
|
|
214
262
|
final_answer = current_content or "[系统提示: 工具调用轮次已达上限]"
|
|
215
|
-
yield
|
|
263
|
+
yield content_event(final_answer)
|
|
216
264
|
|
|
217
265
|
self._append_exchange(session_id, question, final_answer)
|
|
218
|
-
|
|
266
|
+
logger.info(
|
|
267
|
+
"对话完成。session_id=%s final_answer_chars=%s final_answer_preview=%s",
|
|
268
|
+
session_id,
|
|
269
|
+
len(final_answer),
|
|
270
|
+
_preview(final_answer),
|
|
271
|
+
)
|
|
272
|
+
yield ChatResponse(session_id=session_id, finish_reason="stop")
|
|
219
273
|
|
|
220
274
|
except Exception as e:
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
"
|
|
225
|
-
|
|
275
|
+
logger.exception("模型调用异常。session_id=%s question=%s", session_id, _preview(question, 120))
|
|
276
|
+
yield ChatResponse(
|
|
277
|
+
session_id=session_id,
|
|
278
|
+
content=f"[错误] 模型调用异常: {str(e)}",
|
|
279
|
+
finish_reason="error",
|
|
280
|
+
)
|
|
226
281
|
|
|
227
282
|
@staticmethod
|
|
228
283
|
def _merge_tool_call_delta(tool_calls_map: Dict[int, Dict], tc) -> None:
|
|
@@ -242,4 +297,14 @@ class ChatService:
|
|
|
242
297
|
slot["function"]["arguments"] += fn.arguments
|
|
243
298
|
|
|
244
299
|
|
|
245
|
-
|
|
300
|
+
_chat_service: Optional[ChatService] = None
|
|
301
|
+
_chat_service_lock = threading.Lock()
|
|
302
|
+
|
|
303
|
+
|
|
304
|
+
def get_chat_service() -> ChatService:
|
|
305
|
+
global _chat_service
|
|
306
|
+
if _chat_service is None:
|
|
307
|
+
with _chat_service_lock:
|
|
308
|
+
if _chat_service is None:
|
|
309
|
+
_chat_service = ChatService()
|
|
310
|
+
return _chat_service
|
|
@@ -7,6 +7,14 @@ import yaml
|
|
|
7
7
|
from zai import ZhipuAiClient
|
|
8
8
|
|
|
9
9
|
from src.core.config_loader import config
|
|
10
|
+
from src.core.logging_config import get_logger
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
logger = get_logger(__name__)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def _preview(text: str, limit: int = 300) -> str:
|
|
17
|
+
return str(text).replace("\n", " ")[:limit]
|
|
10
18
|
|
|
11
19
|
|
|
12
20
|
class RagService:
|
|
@@ -15,7 +23,7 @@ class RagService:
|
|
|
15
23
|
def __init__(self):
|
|
16
24
|
self._cfg = config.get("rag", default={})
|
|
17
25
|
self._enabled = self._cfg.get("enabled", False)
|
|
18
|
-
self._top_k = self._cfg.get("top_k", 5)
|
|
26
|
+
self._top_k = self._positive_int(self._cfg.get("top_k", 5), default=5)
|
|
19
27
|
self._collection = None
|
|
20
28
|
self._client = ZhipuAiClient(
|
|
21
29
|
api_key=config.llm.get("api_key"),
|
|
@@ -25,13 +33,16 @@ class RagService:
|
|
|
25
33
|
self._init_collection()
|
|
26
34
|
|
|
27
35
|
def search(self, question: str) -> Tuple[str, List[Dict]]:
|
|
28
|
-
if not self.
|
|
36
|
+
if not self._ensure_collection():
|
|
37
|
+
logger.info("RAG 未启用或集合不可用,跳过检索。question=%s", _preview(question, 120))
|
|
29
38
|
return "", []
|
|
30
39
|
|
|
31
40
|
try:
|
|
32
41
|
if self._collection.count() == 0:
|
|
42
|
+
logger.info("RAG 集合为空,开始重建知识库。")
|
|
33
43
|
self.rebuild()
|
|
34
44
|
if self._collection.count() == 0:
|
|
45
|
+
logger.info("RAG 重建后仍无可用文档。question=%s", _preview(question, 120))
|
|
35
46
|
return "", []
|
|
36
47
|
result = self._collection.query(
|
|
37
48
|
query_embeddings=[self._embed(question)],
|
|
@@ -39,6 +50,7 @@ class RagService:
|
|
|
39
50
|
include=["documents", "metadatas"],
|
|
40
51
|
)
|
|
41
52
|
except Exception:
|
|
53
|
+
logger.exception("RAG 检索异常,已降级为空上下文。question=%s", _preview(question, 120))
|
|
42
54
|
self._collection = None
|
|
43
55
|
return "", []
|
|
44
56
|
docs = result.get("documents", [[]])[0]
|
|
@@ -50,12 +62,35 @@ class RagService:
|
|
|
50
62
|
f"内容:{doc}"
|
|
51
63
|
for i, (doc, m) in enumerate(zip(docs, metas))
|
|
52
64
|
)
|
|
65
|
+
logger.info(
|
|
66
|
+
"RAG 检索完成。question=%s hit_count=%s sources=%s context_preview=%s",
|
|
67
|
+
_preview(question, 120),
|
|
68
|
+
len(docs),
|
|
69
|
+
sources,
|
|
70
|
+
_preview(context),
|
|
71
|
+
)
|
|
53
72
|
return context, sources
|
|
54
73
|
|
|
74
|
+
def _ensure_collection(self) -> bool:
|
|
75
|
+
if self._collection:
|
|
76
|
+
return True
|
|
77
|
+
if not self._enabled:
|
|
78
|
+
return False
|
|
79
|
+
self._init_collection()
|
|
80
|
+
return self._collection is not None
|
|
81
|
+
|
|
82
|
+
@staticmethod
|
|
83
|
+
def _positive_int(value, default: int) -> int:
|
|
84
|
+
try:
|
|
85
|
+
return max(int(value), 1)
|
|
86
|
+
except (TypeError, ValueError):
|
|
87
|
+
return default
|
|
88
|
+
|
|
55
89
|
def _init_collection(self):
|
|
56
90
|
try:
|
|
57
91
|
import chromadb
|
|
58
92
|
except ImportError:
|
|
93
|
+
logger.warning("未安装 chromadb,RAG 检索不可用。")
|
|
59
94
|
return
|
|
60
95
|
|
|
61
96
|
root = Path(__file__).resolve().parents[2]
|
|
@@ -67,11 +102,13 @@ class RagService:
|
|
|
67
102
|
if self._cfg.get("rebuild_on_startup", False):
|
|
68
103
|
self.rebuild()
|
|
69
104
|
except Exception:
|
|
105
|
+
logger.exception("RAG 初始化重建失败,已关闭当前集合。")
|
|
70
106
|
self._collection = None
|
|
71
107
|
|
|
72
108
|
def rebuild(self):
|
|
73
109
|
docs = self._load_documents()
|
|
74
110
|
if not docs:
|
|
111
|
+
logger.info("RAG 未加载到知识库文档,跳过重建。")
|
|
75
112
|
return
|
|
76
113
|
self._collection.upsert(
|
|
77
114
|
ids=[d["id"] for d in docs],
|
|
@@ -79,6 +116,7 @@ class RagService:
|
|
|
79
116
|
metadatas=[d["metadata"] for d in docs],
|
|
80
117
|
embeddings=[self._embed(d["content"]) for d in docs],
|
|
81
118
|
)
|
|
119
|
+
logger.info("RAG 知识库重建完成。doc_chunks=%s", len(docs))
|
|
82
120
|
|
|
83
121
|
def _load_documents(self) -> List[Dict]:
|
|
84
122
|
docs = []
|
|
@@ -8,6 +8,8 @@ LARGE_UNITS = {
|
|
|
8
8
|
"微众银行", "网商银行", "农信银中心", "支付宝", "财付通",
|
|
9
9
|
}
|
|
10
10
|
|
|
11
|
+
UNIT_CATEGORIES = {"大型单位", "中型单位", "小型单位"}
|
|
12
|
+
|
|
11
13
|
MEDIUM_UNITS = {
|
|
12
14
|
"中信银行", "光大银行", "民生银行", "兴业银行", "广发银行", "平安银行", "浦发银行",
|
|
13
15
|
"浙江联社", "网银在线",
|
|
@@ -49,6 +51,13 @@ def _normalize_unit_name(unit_name: Optional[str]) -> str:
|
|
|
49
51
|
return (unit_name or "").strip().replace("中国邮政储蓄银行", "邮储银行")
|
|
50
52
|
|
|
51
53
|
|
|
54
|
+
def _resolve_unit_category(unit_category: Optional[str] = None, unit_name: Optional[str] = None) -> str:
|
|
55
|
+
"""Resolve a valid unit category and avoid tool-call input causing KeyError."""
|
|
56
|
+
if unit_category in UNIT_CATEGORIES:
|
|
57
|
+
return unit_category
|
|
58
|
+
return classify_member_unit(unit_name).get("category", "小型单位")
|
|
59
|
+
|
|
60
|
+
|
|
52
61
|
def classify_member_unit(unit_name: Optional[str] = None, daily_txn_count: Optional[int] = None) -> Dict[str, Any]:
|
|
53
62
|
"""按指引附录B或上一年全年日均交易量识别成员单位分类。"""
|
|
54
63
|
normalized = _normalize_unit_name(unit_name)
|
|
@@ -111,7 +120,7 @@ def evaluate_fault_grade(
|
|
|
111
120
|
duration_minutes: Optional[float] = None,
|
|
112
121
|
) -> Dict[str, Any]:
|
|
113
122
|
"""根据异常交易笔数或异常持续时间判断运行故障级别。"""
|
|
114
|
-
category = unit_category
|
|
123
|
+
category = _resolve_unit_category(unit_category, unit_name)
|
|
115
124
|
thresholds = {
|
|
116
125
|
"大型单位": {
|
|
117
126
|
"轻微故障": {"txn_min": 1000, "txn_max": 25000, "duration_min": 0, "duration_max": 10},
|
|
@@ -170,7 +179,7 @@ def evaluate_operation_scene(
|
|
|
170
179
|
bandwidth_usage_pct: Optional[float] = None,
|
|
171
180
|
) -> Dict[str, Any]:
|
|
172
181
|
"""按生产运行场景识别风险、联合处置或关闭渠道策略。"""
|
|
173
|
-
category = unit_category
|
|
182
|
+
category = _resolve_unit_category(unit_category, unit_name)
|
|
174
183
|
evidence = []
|
|
175
184
|
scene = "未触发明确处置场景"
|
|
176
185
|
action = "继续监控,补充系统成功率、失败笔数、耗时、异常交易数量、持续时间等指标后再判断"
|
|
@@ -1,3 +1,9 @@
|
|
|
1
|
+
from fastapi.testclient import TestClient
|
|
2
|
+
|
|
3
|
+
from main import app
|
|
4
|
+
from src.core.config_loader import LazyConfig
|
|
5
|
+
from src.manager.toolcall_manager import ToolCallManager
|
|
6
|
+
from src.models.schemas import ChatResponse
|
|
1
7
|
from src.service.chat_service import ChatService
|
|
2
8
|
from src.utils.function_utils import (
|
|
3
9
|
classify_member_unit,
|
|
@@ -7,6 +13,9 @@ from src.utils.function_utils import (
|
|
|
7
13
|
)
|
|
8
14
|
|
|
9
15
|
|
|
16
|
+
client = TestClient(app)
|
|
17
|
+
|
|
18
|
+
|
|
10
19
|
def test_question_filter_allowed():
|
|
11
20
|
"""测试允许的问题"""
|
|
12
21
|
svc = ChatService()
|
|
@@ -28,9 +37,28 @@ def test_rejected_response_format():
|
|
|
28
37
|
"""测试拒绝返回格式"""
|
|
29
38
|
svc = ChatService()
|
|
30
39
|
result = next(svc.tool_call_stream("session-test", "讲个笑话", tools=[]))
|
|
31
|
-
assert result
|
|
32
|
-
assert result
|
|
33
|
-
assert result
|
|
40
|
+
assert result.session_id == "session-test"
|
|
41
|
+
assert result.finish_reason == "rejected"
|
|
42
|
+
assert result.content
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def test_chat_response_separates_stream_fields():
|
|
46
|
+
response = ChatResponse(session_id="session-test", reasoning_content="推理")
|
|
47
|
+
assert response.reasoning_content == "推理"
|
|
48
|
+
assert response.content is None
|
|
49
|
+
assert response.tool_call is None
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def test_config_proxy_is_lazy_until_first_use():
|
|
53
|
+
lazy_config = LazyConfig()
|
|
54
|
+
assert lazy_config._loader is None
|
|
55
|
+
assert lazy_config.get("missing", default="fallback") == "fallback"
|
|
56
|
+
assert lazy_config._loader is not None
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def test_tool_call_manager_does_not_init_chat_service_on_construction():
|
|
60
|
+
manager = ToolCallManager()
|
|
61
|
+
assert manager._chat_service is None
|
|
34
62
|
|
|
35
63
|
|
|
36
64
|
def test_conversation_history_uses_session_id():
|
|
@@ -57,6 +85,12 @@ def test_evaluate_fault_grade_uses_highest_matched_rule():
|
|
|
57
85
|
assert "或关系" in result["rule"]
|
|
58
86
|
|
|
59
87
|
|
|
88
|
+
def test_evaluate_fault_grade_handles_invalid_category():
|
|
89
|
+
result = evaluate_fault_grade(unit_category="未知类型", abnormal_txn_count=12000)
|
|
90
|
+
assert result["unit_category"] == "小型单位"
|
|
91
|
+
assert result["fault_level"] == "严重故障"
|
|
92
|
+
|
|
93
|
+
|
|
60
94
|
def test_evaluate_operation_scene_close_channel():
|
|
61
95
|
result = evaluate_operation_scene(
|
|
62
96
|
unit_category="中型单位",
|
|
@@ -73,3 +107,9 @@ def test_evaluate_change_strategy_shutdown():
|
|
|
73
107
|
assert result["change_scene"] == "场景四"
|
|
74
108
|
assert "关闭渠道" in result["recommended_action"]
|
|
75
109
|
assert "小于30分钟" in result["notice_judgement"]
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def test_health_endpoint():
|
|
113
|
+
response = client.get("/api/v1/health")
|
|
114
|
+
assert response.status_code == 200
|
|
115
|
+
assert response.json() == {"status": "ok"}
|