union-py-app-stream-chat 1.0.0 → 1.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/README.md CHANGED
@@ -14,10 +14,10 @@ glm-ops-assistant/
14
14
  │ ├── service/
15
15
  │ │ └── chat_service.py # Service 层 - 业务逻辑(问题筛选、流程编排)
16
16
  │ ├── manager/
17
- │ │ ├── llm_manager.py # Manager 层 - GLM SDK 调用封装
18
- │ │ └── session_manager.py # Manager 层 - 多轮会话管理
17
+ │ │ └── toolcall_manager.py # Manager 层 - 工具调用封装
19
18
  │ ├── core/
20
- │ │ └── config_loader.py # 核心工具 - 配置加载(支持环境变量覆盖)
19
+ │ │ ├── config_loader.py # 核心工具 - 配置加载(支持环境变量覆盖)
20
+ │ │ └── logging_config.py # 核心工具 - 日志配置
21
21
  │ └── models/
22
22
  │ └── schemas.py # Pydantic 数据模型
23
23
  ├── tests/
@@ -48,8 +48,10 @@ llm:
48
48
  或通过环境变量覆盖:
49
49
 
50
50
  ```bash
51
+ export APP_CONFIG_PATH="/path/to/config.yaml"
51
52
  export LLM_API_KEY="your-api-key"
52
53
  export LLM_MODEL="glm-4-flash"
54
+ export RAG_ENABLED="true"
53
55
  ```
54
56
 
55
57
  ### 3. 启动服务
@@ -64,10 +66,7 @@ uvicorn main:app --host 0.0.0.0 --port 8000 --reload
64
66
 
65
67
  | 方法 | 路径 | 说明 |
66
68
  |------|------|------|
67
- | POST | `/api/v1/chat` | 普通对话(非流式) |
68
69
  | POST | `/api/v1/chat/stream` | 流式对话(SSE) |
69
- | POST | `/api/v1/session/clear` | 清空会话历史 |
70
- | GET | `/api/v1/session/info` | 获取会话信息 |
71
70
  | GET | `/api/v1/health` | 健康检查 |
72
71
 
73
72
  ## 多轮对话
@@ -83,3 +82,18 @@ uvicorn main:app --host 0.0.0.0 --port 8000 --reload
83
82
  ## 流式输出
84
83
 
85
84
  调用 `/api/v1/chat/stream` 接口,服务端使用 SSE 协议逐字推送模型生成内容。
85
+
86
+ 每条 `message` 事件的 `data` 为 `ChatResponse` JSON:
87
+
88
+ ```json
89
+ {
90
+ "session_id": "sess-xxx",
91
+ "content": "正式回复增量",
92
+ "reasoning_content": null,
93
+ "tool_call": null,
94
+ "tool_result": null,
95
+ "finish_reason": null
96
+ }
97
+ ```
98
+
99
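+ 模型正式回复写入 `content`,推理内容写入 `reasoning_content`,工具过程写入 `tool_call` / `tool_result`。中间增量的 `finish_reason` 为 `null`;最后一条 `message` 事件的 `finish_reason` 为 `stop`、`rejected` 或 `error`(`rejected` / `error` 时 `content` 携带拒答或错误提示),随后服务端再发送一条 `done` 事件,其 `data` 同为 `ChatResponse`,`finish_reason` 为 `done`。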
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "union-py-app-stream-chat",
3
- "version": "1.0.0",
3
+ "version": "1.1.0",
4
4
  "description": "Source package for the union operations stream chat Python app.",
5
5
  "license": "UNLICENSED",
6
6
  "private": false,
package/src/api/routes.py CHANGED
@@ -1,11 +1,16 @@
1
- import json
2
1
  from fastapi import APIRouter
3
2
  from sse_starlette.sse import EventSourceResponse, ServerSentEvent
4
3
 
5
- from src.models.schemas import ChatRequest
4
+ from src.models.schemas import ChatRequest, ChatResponse
6
5
  from src.manager.toolcall_manager import ToolCallManager
7
6
 
8
7
router = APIRouter()
# ToolCallManager resolves its ChatService lazily on first use, so a single
# shared instance can safely be created when this module is imported.
tool_call_manager = ToolCallManager()


@router.get("/health")
def health_check():
    """Liveness probe: report that the API process is up."""
    return dict(status="ok")
9
14
 
10
15
 
11
16
  @router.post("/chat/stream")
@@ -16,19 +21,18 @@ def chat_stream_endpoint(request: ChatRequest):
16
21
  - 接收用户问题
17
22
  - 逐块返回生成的内容
18
23
  """
19
- manager = ToolCallManager()
20
-
21
24
  def event_generator():
22
- for chunk in manager.tool_call_stream(request.session_id, request.question):
25
+ for chunk in tool_call_manager.tool_call_stream(request.session_id, request.question):
23
26
  yield ServerSentEvent(
24
27
  event="message",
25
- data=json.dumps(chunk, ensure_ascii=False),
28
+ data=chunk.model_dump_json(),
26
29
  )
27
30
 
28
31
  # SSE 结束标记
32
+ done = ChatResponse(session_id=request.session_id, finish_reason="done")
29
33
  yield ServerSentEvent(
30
34
  event="done",
31
- data=json.dumps({"session_id": request.session_id, "finish_reason": "done"}, ensure_ascii=False),
35
+ data=done.model_dump_json(),
32
36
  )
33
37
 
34
38
  return EventSourceResponse(event_generator())
@@ -1,26 +1,25 @@
1
1
  import os
2
+ import threading
2
3
  import yaml
3
4
  from typing import Any, Dict
4
5
 
6
+ from src.core.logging_config import get_logger
5
7
 
6
- class ConfigLoader:
7
- """配置加载器,支持从YAML文件加载配置,并可通过环境变量覆盖。"""
8
8
 
9
- _instance = None
10
- _config = None
9
# Module logger; get_logger attaches a console handler on first use.
logger = get_logger(name=__name__)
11
10
 
12
- def __new__(cls, *args, **kwargs):
13
- if not cls._instance:
14
- cls._instance = super().__new__(cls)
15
- return cls._instance
16
11
 
17
- def __init__(self, config_path: str = None):
18
- if self._config is not None:
19
- return
12
+ class ConfigLoader:
13
+ """配置加载器,支持从YAML文件加载配置,并可通过环境变量覆盖。"""
20
14
 
15
+ def __init__(self, config_path: str = None):
16
+ base_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
21
17
  if config_path is None:
22
- base_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
23
- config_path = os.path.join(base_dir, "config", "config.yaml")
18
+ config_path = os.environ.get("APP_CONFIG_PATH") or os.path.join(base_dir, "config", "config.yaml")
19
+ if not os.path.exists(config_path):
20
+ fallback_path = os.path.join(base_dir, "config", "config.example.yaml")
21
+ logger.warning("配置文件不存在,使用示例配置。config_path=%s fallback=%s", config_path, fallback_path)
22
+ config_path = fallback_path
24
23
 
25
24
  self._config_path = config_path
26
25
  self._config = self._load_yaml()
@@ -28,7 +27,10 @@ class ConfigLoader:
28
27
 
29
28
  def _load_yaml(self) -> Dict[str, Any]:
30
29
  with open(self._config_path, "r", encoding="utf-8") as f:
31
- return yaml.safe_load(f)
30
+ data = yaml.safe_load(f) or {}
31
+ if not isinstance(data, dict):
32
+ raise ValueError(f"配置文件格式错误,顶层必须是对象: {self._config_path}")
33
+ return data
32
34
 
33
35
  def _override_from_env(self):
34
36
  """通过环境变量覆盖配置,支持 LLM_API_KEY, LLM_MODEL, LLM_BASE_URL 等。"""
@@ -40,21 +42,28 @@ class ConfigLoader:
40
42
  "LLM_TEMPERATURE": ["llm", "temperature"],
41
43
  "SESSION_TTL": ["session", "ttl_seconds"],
42
44
  "SESSION_MAX_HISTORY": ["session", "max_history"],
45
+ "RAG_ENABLED": ["rag", "enabled"],
46
+ "RAG_TOP_K": ["rag", "top_k"],
43
47
  }
44
48
 
45
49
  for env_var, path in env_mappings.items():
46
50
  value = os.environ.get(env_var)
47
51
  if value is not None:
48
- # 尝试数字转换
49
- try:
50
- if "." in value:
51
- value = float(value)
52
- else:
53
- value = int(value)
54
- except ValueError:
55
- pass
52
+ value = self._parse_env_value(value)
56
53
  self._set_nested_value(self._config, path, value)
57
54
 
55
+ @staticmethod
56
+ def _parse_env_value(value: str) -> Any:
57
+ lower_value = value.lower()
58
+ if lower_value in {"true", "false"}:
59
+ return lower_value == "true"
60
+ try:
61
+ if "." in value:
62
+ return float(value)
63
+ return int(value)
64
+ except ValueError:
65
+ return value
66
+
58
67
  def _set_nested_value(self, config: Dict, path: list, value: Any):
59
68
  current = config
60
69
  for key in path[:-1]:
@@ -95,5 +104,33 @@ class ConfigLoader:
95
104
  return self._config
96
105
 
97
106
 
98
- # 全局配置实例
99
- config = ConfigLoader()
107
class LazyConfig:
    """Import-safe proxy around ``ConfigLoader``.

    Nothing is read until the first forwarded attribute lookup, ``get()`` or
    explicit ``load()``. After that the same loader is reused until ``reload()``
    (or ``load()`` with an explicit path) replaces it.
    """

    # Attributes stored on the proxy itself. They must never be forwarded:
    # an instance built via ``__new__`` without ``__init__`` (as ``copy.copy``
    # and pickle do) lacks them, and forwarding would make ``load()`` touch
    # ``self._lock`` -> ``__getattr__`` -> ``load()`` ... forever.
    _OWN_ATTRS = frozenset({"_loader", "_lock"})

    def __init__(self):
        self._loader = None
        self._lock = threading.Lock()

    def load(self, config_path: str = None) -> "ConfigLoader":
        """Return the cached loader, creating it on first use.

        Passing ``config_path`` always builds a fresh loader from that path.
        """
        with self._lock:
            if self._loader is None or config_path is not None:
                self._loader = ConfigLoader(config_path)
            return self._loader

    def reload(self, config_path: str = None) -> "ConfigLoader":
        """Unconditionally replace the cached loader with a newly built one."""
        with self._lock:
            self._loader = ConfigLoader(config_path)
            return self._loader

    def __getattr__(self, name: str) -> Any:
        # Only invoked for names not found normally. Proxy-internal names and
        # dunder protocol probes (__setstate__, __deepcopy__, ...) must fail
        # fast with AttributeError instead of triggering a config load.
        if name in self._OWN_ATTRS or (name.startswith("__") and name.endswith("__")):
            raise AttributeError(name)
        return getattr(self.load(), name)

    def get(self, *keys: str, default: Any = None) -> Any:
        """Proxy for ``ConfigLoader.get``; loads the configuration if needed."""
        return self.load().get(*keys, default=default)


config = LazyConfig()


def get_config() -> "ConfigLoader":
    """Return the process-wide loader, loading it on first call."""
    return config.load()
@@ -0,0 +1,16 @@
1
+ import logging
2
+
3
+
4
LOG_FORMAT = "%(asctime)s %(levelname)s [%(name)s] %(message)s"


def get_logger(name: str) -> logging.Logger:
    """Return the named logger, attaching a console handler the first time.

    A logger without handlers gets a ``StreamHandler`` formatted with
    ``LOG_FORMAT``, level INFO, and ``propagate = False`` so its records are not
    also emitted by ancestor loggers' handlers. A logger that already has
    handlers is returned as-is.
    """
    logger = logging.getLogger(name)
    if logger.handlers:
        return logger
    console_handler = logging.StreamHandler()
    console_handler.setFormatter(logging.Formatter(LOG_FORMAT))
    logger.addHandler(console_handler)
    logger.setLevel(logging.INFO)
    logger.propagate = False
    return logger
@@ -1,12 +1,21 @@
1
1
  """使用官方的sdk调用模型,整合工具调用能力"""
2
+ from typing import Generator
3
+
4
+ from src.models.schemas import ChatResponse
2
5
  from src.utils.function_utils import tools
3
- from src.service.chat_service import chat_service
6
+ from src.service.chat_service import ChatService, get_chat_service
4
7
 
5
8
 
6
9
class ToolCallManager:
    """Manager layer that forwards tool-enabled streaming chat to ChatService.

    The shared ChatService is looked up lazily, so constructing a manager
    (e.g. at module import) does not build the service.
    """

    def __init__(self):
        # Populated by the ``chat_service`` property on first access.
        self._chat_service = None

    @property
    def chat_service(self) -> ChatService:
        """Return the cached ChatService, fetching the shared instance on first use."""
        cached = self._chat_service
        if cached is None:
            cached = self._chat_service = get_chat_service()
        return cached

    def tool_call_stream(self, session_id: str, question: str) -> Generator[ChatResponse, None, None]:
        """Call the service-layer tool-call stream with this module's ``tools`` and relay its chunks."""
        service = self.chat_service
        yield from service.tool_call_stream(session_id, question, tools)
@@ -1,3 +1,5 @@
1
+ from typing import Literal, Optional
2
+
1
3
  from pydantic import BaseModel, Field
2
4
 
3
5
 
@@ -5,3 +7,16 @@ class ChatRequest(BaseModel):
5
7
  """聊天请求模型"""
6
8
  session_id: str = Field(..., description="会话ID,用于多轮对话上下文")
7
9
  question: str = Field(..., min_length=1, max_length=4096, description="用户问题")
10
+
11
+
12
class ChatResponse(BaseModel):
    """Payload of one streamed chat chunk (serialized as the SSE ``data``).

    Incremental chunks populate one of ``content``, ``reasoning_content``,
    ``tool_call`` or ``tool_result`` and leave ``finish_reason`` unset. Terminal
    chunks set ``finish_reason``; rejected/error chunks also carry ``content``.
    """
    session_id: str = Field(..., description="会话ID")
    content: Optional[str] = Field(None, description="模型正式回复内容增量")
    reasoning_content: Optional[str] = Field(None, description="模型推理内容增量")
    tool_call: Optional[str] = Field(None, description="工具调用信息")
    tool_result: Optional[str] = Field(None, description="工具执行结果")
    finish_reason: Optional[Literal["stop", "error", "rejected", "done"]] = Field(
        None,
        description="结束原因;中间流式增量为空",
    )
@@ -1,14 +1,23 @@
1
1
  import threading
2
2
  import time
3
- from typing import Dict, Generator, List
3
+ from typing import Dict, Generator, List, Optional
4
4
 
5
5
  from zai import ZhipuAiClient
6
6
 
7
7
  from src.core.config_loader import config
8
+ from src.core.logging_config import get_logger
9
+ from src.models.schemas import ChatResponse
8
10
  from src.utils.function_utils import call_function
9
11
  from src.service.rag_service import RagService
10
12
 
11
13
 
14
# Module logger; get_logger attaches a console handler on first use.
logger = get_logger(name=__name__)
15
+
16
+
17
+ def _preview(text: str, limit: int = 300) -> str:
18
+ return str(text).replace("\n", " ")[:limit]
19
+
20
+
12
21
  TOOL_ROUTING_PROMPT = """
13
22
  你正在处理网络支付清算平台联合运维客服问题。选择工具前先识别用户真实业务场景:
14
23
  1. 提到变更、投产、升级、发布、回滚、灾备、演练、通知、报备、关闭渠道,优先判断为生产变更,使用 evaluate_change_strategy。
@@ -135,25 +144,40 @@ class ChatService:
135
144
  messages.insert(insert_at, {"role": "system", "content": TOOL_ROUTING_PROMPT})
136
145
  return messages
137
146
 
138
- def tool_call_stream(self, session_id: str, question: str, tools) -> Generator[dict, None, None]:
147
+ def tool_call_stream(self, session_id: str, question: str, tools) -> Generator[ChatResponse, None, None]:
139
148
  """
140
149
  带工具调用的流式对话(支持交错思考与工具调用)
141
150
  - stream=True + tool_stream=True:模型在流式输出中同时返回推理过程、回答内容与工具调用
142
151
  - 工具执行结果回传模型后继续流式生成,循环直至模型不再调用工具或达到最大轮次
143
152
  """
144
153
  if not self._check_question_valid(question):
145
- yield {"session_id": session_id, "delta": self._rejection_message, "finish_reason": "rejected"}
154
+ logger.info("问题未通过业务过滤。session_id=%s question=%s", session_id, _preview(question, 120))
155
+ yield ChatResponse(
156
+ session_id=session_id,
157
+ content=self._rejection_message,
158
+ finish_reason="rejected",
159
+ )
146
160
  return
147
161
 
148
- def event(delta: str, etype: str) -> dict:
149
- return {"session_id": session_id, "delta": delta, "type": etype, "finish_reason": None}
162
+ def content_event(content: str) -> ChatResponse:
163
+ return ChatResponse(session_id=session_id, content=content)
164
+
165
+ def reasoning_event(reasoning_content: str) -> ChatResponse:
166
+ return ChatResponse(session_id=session_id, reasoning_content=reasoning_content)
167
+
168
+ def tool_call_event(tool_call: str) -> ChatResponse:
169
+ return ChatResponse(session_id=session_id, tool_call=tool_call)
170
+
171
+ def tool_result_event(tool_result: str) -> ChatResponse:
172
+ return ChatResponse(session_id=session_id, tool_result=tool_result)
150
173
 
151
174
  try:
152
175
  messages = self._build_tool_messages(session_id, question)
153
176
  max_rounds = config.tools.get("max_rounds", 5)
154
177
  final_answer = ""
155
178
 
156
- for _ in range(max_rounds):
179
+ logger.info("开始模型流式调用。session_id=%s model=%s question=%s", session_id, self._model, _preview(question, 120))
180
+ for round_idx in range(max_rounds):
157
181
  response = self._client.chat.completions.create(
158
182
  model=self._model,
159
183
  messages=messages,
@@ -167,6 +191,7 @@ class ChatService:
167
191
  )
168
192
 
169
193
  current_content = ""
194
+ reasoning_len = 0
170
195
  tool_calls_map: Dict[int, Dict] = {}
171
196
 
172
197
  for chunk in response:
@@ -176,16 +201,27 @@ class ChatService:
176
201
 
177
202
  reasoning = getattr(delta, "reasoning_content", None)
178
203
  if reasoning:
179
- yield event(reasoning, "reasoning")
204
+ reasoning_len += len(reasoning)
205
+ yield reasoning_event(reasoning)
180
206
 
181
207
  content = getattr(delta, "content", None)
182
208
  if content:
183
209
  current_content += content
184
- yield event(content, "content")
210
+ yield content_event(content)
185
211
 
186
212
  for tc in getattr(delta, "tool_calls", None) or []:
187
213
  self._merge_tool_call_delta(tool_calls_map, tc)
188
214
 
215
+ logger.info(
216
+ "模型流式返回完成。session_id=%s round=%s content_chars=%s reasoning_chars=%s tool_calls=%s content_preview=%s",
217
+ session_id,
218
+ round_idx + 1,
219
+ len(current_content),
220
+ reasoning_len,
221
+ len(tool_calls_map),
222
+ _preview(current_content),
223
+ )
224
+
189
225
  if not tool_calls_map:
190
226
  final_answer = current_content
191
227
  break
@@ -200,10 +236,22 @@ class ChatService:
200
236
  for tc in assistant_tool_calls:
201
237
  name = tc["function"]["name"]
202
238
  args = tc["function"]["arguments"]
203
- yield event(f"\n[调用工具: {name}({args})]\n", "tool_call")
239
+ logger.info(
240
+ "执行工具调用。session_id=%s tool=%s args=%s",
241
+ session_id,
242
+ name,
243
+ _preview(args, 200),
244
+ )
245
+ yield tool_call_event(f"\n[调用工具: {name}({args})]\n")
204
246
 
205
247
  result = call_function(name, args)
206
- yield event(result, "tool_result")
248
+ logger.info(
249
+ "工具调用完成。session_id=%s tool=%s result_preview=%s",
250
+ session_id,
251
+ name,
252
+ _preview(result, 300),
253
+ )
254
+ yield tool_result_event(result)
207
255
 
208
256
  messages.append({
209
257
  "role": "tool",
@@ -212,17 +260,24 @@ class ChatService:
212
260
  })
213
261
  else:
214
262
  final_answer = current_content or "[系统提示: 工具调用轮次已达上限]"
215
- yield event(final_answer, "content")
263
+ yield content_event(final_answer)
216
264
 
217
265
  self._append_exchange(session_id, question, final_answer)
218
- yield {"session_id": session_id, "delta": "", "finish_reason": "stop"}
266
+ logger.info(
267
+ "对话完成。session_id=%s final_answer_chars=%s final_answer_preview=%s",
268
+ session_id,
269
+ len(final_answer),
270
+ _preview(final_answer),
271
+ )
272
+ yield ChatResponse(session_id=session_id, finish_reason="stop")
219
273
 
220
274
  except Exception as e:
221
- yield {
222
- "session_id": session_id,
223
- "delta": f"[错误] 模型调用异常: {str(e)}",
224
- "finish_reason": "error",
225
- }
275
+ logger.exception("模型调用异常。session_id=%s question=%s", session_id, _preview(question, 120))
276
+ yield ChatResponse(
277
+ session_id=session_id,
278
+ content=f"[错误] 模型调用异常: {str(e)}",
279
+ finish_reason="error",
280
+ )
226
281
 
227
282
  @staticmethod
228
283
  def _merge_tool_call_delta(tool_calls_map: Dict[int, Dict], tc) -> None:
@@ -242,4 +297,14 @@ class ChatService:
242
297
  slot["function"]["arguments"] += fn.arguments
243
298
 
244
299
 
245
- chat_service = ChatService()
300
_chat_service: Optional[ChatService] = None
_chat_service_lock = threading.Lock()


def get_chat_service() -> ChatService:
    """Return the process-wide ChatService, creating it on first call.

    The fast path returns the cached instance without locking. On a miss the
    lock is taken and the cache re-checked, so concurrent first calls build
    only one ChatService.
    """
    global _chat_service
    if _chat_service is not None:
        return _chat_service
    with _chat_service_lock:
        if _chat_service is None:
            _chat_service = ChatService()
        return _chat_service
@@ -7,6 +7,14 @@ import yaml
7
7
  from zai import ZhipuAiClient
8
8
 
9
9
  from src.core.config_loader import config
10
+ from src.core.logging_config import get_logger
11
+
12
+
13
# Module logger; get_logger attaches a console handler on first use.
logger = get_logger(name=__name__)
14
+
15
+
16
+ def _preview(text: str, limit: int = 300) -> str:
17
+ return str(text).replace("\n", " ")[:limit]
10
18
 
11
19
 
12
20
  class RagService:
@@ -15,7 +23,7 @@ class RagService:
15
23
  def __init__(self):
16
24
  self._cfg = config.get("rag", default={})
17
25
  self._enabled = self._cfg.get("enabled", False)
18
- self._top_k = self._cfg.get("top_k", 5)
26
+ self._top_k = self._positive_int(self._cfg.get("top_k", 5), default=5)
19
27
  self._collection = None
20
28
  self._client = ZhipuAiClient(
21
29
  api_key=config.llm.get("api_key"),
@@ -25,13 +33,16 @@ class RagService:
25
33
  self._init_collection()
26
34
 
27
35
  def search(self, question: str) -> Tuple[str, List[Dict]]:
28
- if not self._collection:
36
+ if not self._ensure_collection():
37
+ logger.info("RAG 未启用或集合不可用,跳过检索。question=%s", _preview(question, 120))
29
38
  return "", []
30
39
 
31
40
  try:
32
41
  if self._collection.count() == 0:
42
+ logger.info("RAG 集合为空,开始重建知识库。")
33
43
  self.rebuild()
34
44
  if self._collection.count() == 0:
45
+ logger.info("RAG 重建后仍无可用文档。question=%s", _preview(question, 120))
35
46
  return "", []
36
47
  result = self._collection.query(
37
48
  query_embeddings=[self._embed(question)],
@@ -39,6 +50,7 @@ class RagService:
39
50
  include=["documents", "metadatas"],
40
51
  )
41
52
  except Exception:
53
+ logger.exception("RAG 检索异常,已降级为空上下文。question=%s", _preview(question, 120))
42
54
  self._collection = None
43
55
  return "", []
44
56
  docs = result.get("documents", [[]])[0]
@@ -50,12 +62,35 @@ class RagService:
50
62
  f"内容:{doc}"
51
63
  for i, (doc, m) in enumerate(zip(docs, metas))
52
64
  )
65
+ logger.info(
66
+ "RAG 检索完成。question=%s hit_count=%s sources=%s context_preview=%s",
67
+ _preview(question, 120),
68
+ len(docs),
69
+ sources,
70
+ _preview(context),
71
+ )
53
72
  return context, sources
54
73
 
74
+ def _ensure_collection(self) -> bool:
75
+ if self._collection:
76
+ return True
77
+ if not self._enabled:
78
+ return False
79
+ self._init_collection()
80
+ return self._collection is not None
81
+
82
+ @staticmethod
83
+ def _positive_int(value, default: int) -> int:
84
+ try:
85
+ return max(int(value), 1)
86
+ except (TypeError, ValueError):
87
+ return default
88
+
55
89
  def _init_collection(self):
56
90
  try:
57
91
  import chromadb
58
92
  except ImportError:
93
+ logger.warning("未安装 chromadb,RAG 检索不可用。")
59
94
  return
60
95
 
61
96
  root = Path(__file__).resolve().parents[2]
@@ -67,11 +102,13 @@ class RagService:
67
102
  if self._cfg.get("rebuild_on_startup", False):
68
103
  self.rebuild()
69
104
  except Exception:
105
+ logger.exception("RAG 初始化重建失败,已关闭当前集合。")
70
106
  self._collection = None
71
107
 
72
108
  def rebuild(self):
73
109
  docs = self._load_documents()
74
110
  if not docs:
111
+ logger.info("RAG 未加载到知识库文档,跳过重建。")
75
112
  return
76
113
  self._collection.upsert(
77
114
  ids=[d["id"] for d in docs],
@@ -79,6 +116,7 @@ class RagService:
79
116
  metadatas=[d["metadata"] for d in docs],
80
117
  embeddings=[self._embed(d["content"]) for d in docs],
81
118
  )
119
+ logger.info("RAG 知识库重建完成。doc_chunks=%s", len(docs))
82
120
 
83
121
  def _load_documents(self) -> List[Dict]:
84
122
  docs = []
@@ -8,6 +8,8 @@ LARGE_UNITS = {
8
8
  "微众银行", "网商银行", "农信银中心", "支付宝", "财付通",
9
9
  }
10
10
 
11
# Canonical member-unit categories; any other value passed in as a category
# is re-derived from the unit name by _resolve_unit_category.
UNIT_CATEGORIES = {
    "大型单位",
    "中型单位",
    "小型单位",
}
12
+
11
13
  MEDIUM_UNITS = {
12
14
  "中信银行", "光大银行", "民生银行", "兴业银行", "广发银行", "平安银行", "浦发银行",
13
15
  "浙江联社", "网银在线",
@@ -49,6 +51,13 @@ def _normalize_unit_name(unit_name: Optional[str]) -> str:
49
51
  return (unit_name or "").strip().replace("中国邮政储蓄银行", "邮储银行")
50
52
 
51
53
 
54
def _resolve_unit_category(unit_category: Optional[str] = None, unit_name: Optional[str] = None) -> str:
    """Return a unit category that is safe to use as a threshold-table key.

    ``unit_category`` may come from tool-call arguments, so it can be missing,
    unknown, padded with whitespace, or not a string at all. A recognised
    category (after stripping whitespace) is returned. Otherwise the category
    is derived from ``unit_name`` via ``classify_member_unit``, falling back to
    "小型单位" when that yields nothing.
    """
    if isinstance(unit_category, str):
        candidate = unit_category.strip()
        if candidate in UNIT_CATEGORIES:
            return candidate
    # Non-string values (e.g. lists/dicts, which are unhashable and would raise
    # TypeError on the set lookup) are treated like an unknown category.
    return classify_member_unit(unit_name).get("category") or "小型单位"
59
+
60
+
52
61
  def classify_member_unit(unit_name: Optional[str] = None, daily_txn_count: Optional[int] = None) -> Dict[str, Any]:
53
62
  """按指引附录B或上一年全年日均交易量识别成员单位分类。"""
54
63
  normalized = _normalize_unit_name(unit_name)
@@ -111,7 +120,7 @@ def evaluate_fault_grade(
111
120
  duration_minutes: Optional[float] = None,
112
121
  ) -> Dict[str, Any]:
113
122
  """根据异常交易笔数或异常持续时间判断运行故障级别。"""
114
- category = unit_category or classify_member_unit(unit_name).get("category")
123
+ category = _resolve_unit_category(unit_category, unit_name)
115
124
  thresholds = {
116
125
  "大型单位": {
117
126
  "轻微故障": {"txn_min": 1000, "txn_max": 25000, "duration_min": 0, "duration_max": 10},
@@ -170,7 +179,7 @@ def evaluate_operation_scene(
170
179
  bandwidth_usage_pct: Optional[float] = None,
171
180
  ) -> Dict[str, Any]:
172
181
  """按生产运行场景识别风险、联合处置或关闭渠道策略。"""
173
- category = unit_category or classify_member_unit(unit_name).get("category")
182
+ category = _resolve_unit_category(unit_category, unit_name)
174
183
  evidence = []
175
184
  scene = "未触发明确处置场景"
176
185
  action = "继续监控,补充系统成功率、失败笔数、耗时、异常交易数量、持续时间等指标后再判断"
@@ -1,3 +1,9 @@
1
+ from fastapi.testclient import TestClient
2
+
3
+ from main import app
4
+ from src.core.config_loader import LazyConfig
5
+ from src.manager.toolcall_manager import ToolCallManager
6
+ from src.models.schemas import ChatResponse
1
7
  from src.service.chat_service import ChatService
2
8
  from src.utils.function_utils import (
3
9
  classify_member_unit,
@@ -7,6 +13,9 @@ from src.utils.function_utils import (
7
13
  )
8
14
 
9
15
 
16
# In-process HTTP client bound to the FastAPI app, shared by the endpoint tests.
client = TestClient(app=app)
17
+
18
+
10
19
  def test_question_filter_allowed():
11
20
  """测试允许的问题"""
12
21
  svc = ChatService()
@@ -28,9 +37,28 @@ def test_rejected_response_format():
28
37
  """测试拒绝返回格式"""
29
38
  svc = ChatService()
30
39
  result = next(svc.tool_call_stream("session-test", "讲个笑话", tools=[]))
31
- assert result["session_id"] == "session-test"
32
- assert result["finish_reason"] == "rejected"
33
- assert result["delta"]
40
+ assert result.session_id == "session-test"
41
+ assert result.finish_reason == "rejected"
42
+ assert result.content
43
+
44
+
45
def test_chat_response_separates_stream_fields():
    """A reasoning-only chunk must leave the other stream fields unset."""
    chunk = ChatResponse(session_id="session-test", reasoning_content="推理")
    assert chunk.reasoning_content == "推理"
    assert chunk.content is None
    assert chunk.tool_call is None


def test_config_proxy_is_lazy_until_first_use():
    """LazyConfig must not build a loader until something is looked up."""
    proxy = LazyConfig()
    assert proxy._loader is None
    looked_up = proxy.get("missing", default="fallback")
    assert looked_up == "fallback"
    assert proxy._loader is not None


def test_tool_call_manager_does_not_init_chat_service_on_construction():
    """Constructing the manager must not eagerly fetch the ChatService."""
    assert ToolCallManager()._chat_service is None
34
62
 
35
63
 
36
64
  def test_conversation_history_uses_session_id():
@@ -57,6 +85,12 @@ def test_evaluate_fault_grade_uses_highest_matched_rule():
57
85
  assert "或关系" in result["rule"]
58
86
 
59
87
 
88
def test_evaluate_fault_grade_handles_invalid_category():
    """An unknown unit_category falls back to a valid category instead of raising."""
    outcome = evaluate_fault_grade(unit_category="未知类型", abnormal_txn_count=12000)
    assert outcome["unit_category"] == "小型单位"
    assert outcome["fault_level"] == "严重故障"
92
+
93
+
60
94
  def test_evaluate_operation_scene_close_channel():
61
95
  result = evaluate_operation_scene(
62
96
  unit_category="中型单位",
@@ -73,3 +107,9 @@ def test_evaluate_change_strategy_shutdown():
73
107
  assert result["change_scene"] == "场景四"
74
108
  assert "关闭渠道" in result["recommended_action"]
75
109
  assert "小于30分钟" in result["notice_judgement"]
110
+
111
+
112
def test_health_endpoint():
    """GET /api/v1/health answers 200 with the static OK payload."""
    resp = client.get("/api/v1/health")
    assert resp.status_code == 200
    assert resp.json() == {"status": "ok"}