flexllm 0.3.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flexllm/__init__.py +224 -0
- flexllm/__main__.py +1096 -0
- flexllm/async_api/__init__.py +9 -0
- flexllm/async_api/concurrent_call.py +100 -0
- flexllm/async_api/concurrent_executor.py +1036 -0
- flexllm/async_api/core.py +373 -0
- flexllm/async_api/interface.py +12 -0
- flexllm/async_api/progress.py +277 -0
- flexllm/base_client.py +988 -0
- flexllm/batch_tools/__init__.py +16 -0
- flexllm/batch_tools/folder_processor.py +317 -0
- flexllm/batch_tools/table_processor.py +363 -0
- flexllm/cache/__init__.py +10 -0
- flexllm/cache/response_cache.py +293 -0
- flexllm/chain_of_thought_client.py +1120 -0
- flexllm/claudeclient.py +402 -0
- flexllm/client_pool.py +698 -0
- flexllm/geminiclient.py +563 -0
- flexllm/llm_client.py +523 -0
- flexllm/llm_parser.py +60 -0
- flexllm/mllm_client.py +559 -0
- flexllm/msg_processors/__init__.py +174 -0
- flexllm/msg_processors/image_processor.py +729 -0
- flexllm/msg_processors/image_processor_helper.py +485 -0
- flexllm/msg_processors/messages_processor.py +341 -0
- flexllm/msg_processors/unified_processor.py +1404 -0
- flexllm/openaiclient.py +256 -0
- flexllm/pricing/__init__.py +104 -0
- flexllm/pricing/data.json +1201 -0
- flexllm/pricing/updater.py +223 -0
- flexllm/provider_router.py +213 -0
- flexllm/token_counter.py +270 -0
- flexllm/utils/__init__.py +1 -0
- flexllm/utils/core.py +41 -0
- flexllm-0.3.3.dist-info/METADATA +573 -0
- flexllm-0.3.3.dist-info/RECORD +39 -0
- flexllm-0.3.3.dist-info/WHEEL +4 -0
- flexllm-0.3.3.dist-info/entry_points.txt +3 -0
- flexllm-0.3.3.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,1120 @@
|
|
|
1
|
+
#! /usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
|
|
4
|
+
"""
|
|
5
|
+
Chain of Thought client for orchestrating multiple LLM calls.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import asyncio
|
|
9
|
+
import time
|
|
10
|
+
from abc import ABC, abstractmethod
|
|
11
|
+
from enum import Enum
|
|
12
|
+
from typing import Callable, Dict, Any, Optional, List, Union
|
|
13
|
+
from dataclasses import dataclass, field
|
|
14
|
+
|
|
15
|
+
import logging
|
|
16
|
+
import sys
|
|
17
|
+
from datetime import datetime
|
|
18
|
+
from .openaiclient import OpenAIClient
|
|
19
|
+
from .async_api.progress import ProgressTracker, ProgressBarConfig
|
|
20
|
+
from .async_api.concurrent_call import concurrent_executor
|
|
21
|
+
|
|
22
|
+
# Rich库安全导入和使用
|
|
23
|
+
|
|
24
|
+
from rich.console import Console
|
|
25
|
+
from rich.panel import Panel
|
|
26
|
+
from rich.text import Text
|
|
27
|
+
|
|
28
|
+
# Module-wide Rich console used by safe_chain_print (fixed 100-column width).
chain_console = Console(color_system="auto", width=100, force_terminal=True)


def safe_chain_print(*args, **kwargs):
    """Print through the shared Rich console, degrading to the built-in print.

    If Rich raises for any reason, each string argument has every ``[...]`` /
    ``[/...]`` span (Rich markup) removed and all non-ASCII characters dropped;
    non-string arguments are converted with ``str()``. The cleaned values are
    then written with ``builtins.print`` using the same keyword arguments.
    """
    try:
        chain_console.print(*args, **kwargs)
    except Exception:
        # Plain-text fallback when Rich cannot render.
        import builtins
        import re

        markup_pattern = re.compile(r"\[/?[^\]]*\]")

        def _to_plain(value):
            if not isinstance(value, str):
                return str(value)
            without_markup = markup_pattern.sub("", value)
            return without_markup.encode("ascii", "ignore").decode("ascii")

        builtins.print(*[_to_plain(value) for value in args], **kwargs)
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
class ChainProgressTracker:
    """Progress tracker for a batch of chains, adapted onto ``ProgressTracker``.

    Local counters are always maintained. When ``show_progress`` is true, each
    update is also forwarded to an internal ``ProgressTracker`` as a small
    ``ChainResult`` record.
    """

    def __init__(self, total_chains: int, show_progress: bool = True):
        self.total_chains = total_chains
        self.completed_chains = 0
        self.start_time = time.time()
        self.show_progress = show_progress

        if not self.show_progress:
            self.tracker = None
            return

        self.config = ProgressBarConfig(
            bar_length=40,
            show_percentage=True,
            show_speed=True,
            show_counts=True,
            show_time_stats=True,
            use_colors=True,
        )

        # Lightweight record handed to ProgressTracker.update() for each chain.
        @dataclass
        class ChainResult:
            request_id: int
            data: Any = None
            status: str = "success"
            latency: float = 0.0

        self.ChainResult = ChainResult
        self.tracker = ProgressTracker(total_chains, concurrency=1, config=self.config)

    def update(
        self, chain_index: int, success: bool = True, execution_time: float = 0.0
    ):
        """Count one finished chain and, if enabled, advance the progress bar."""
        self.completed_chains += 1
        if not self.tracker:
            return
        record = self.ChainResult(
            request_id=chain_index,
            status="success" if success else "error",
            latency=execution_time,
        )
        self.tracker.update(record)

    def finish(self):
        """Emit a trailing newline once all chains are done, keeping the final bar."""
        if self.tracker and self.completed_chains == self.total_chains:
            safe_chain_print()

    def get_progress_info(self) -> dict:
        """Return counters, percentage, elapsed/remaining time and rate."""
        elapsed = time.time() - self.start_time
        done = self.completed_chains
        total = self.total_chains

        remaining = 0.0
        if done > 0:
            # Linear extrapolation from the average time per finished chain.
            remaining = (elapsed / done) * (total - done)

        return {
            "completed": done,
            "total": total,
            "progress_percent": (done / total * 100) if total > 0 else 0,
            "elapsed_time": elapsed,
            "remaining_time": remaining,
            "rate": done / elapsed if elapsed > 0 else 0,
        }
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
# Dedicated logger for ChainOfThoughtClient.
def setup_chain_logger():
    """Configure the ``maque.chain`` logger on first call and return it.

    The first call attaches one stderr handler with a compact
    ``HH:MM:SS | message`` format (mimicking loguru's terse output) and turns
    off propagation to the root logger so messages are not printed twice.
    The configured logger is cached on this function object, so later calls
    return the same logger without reconfiguring it.

    The logger level is DEBUG so DEBUG records are not discarded here.
    ``DefaultChainMonitor`` already filters messages against
    ``ExecutionConfig.log_level`` before calling ``chain_logger.debug``, so a
    ``log_level="DEBUG"`` configuration actually produces debug output.

    Returns:
        logging.Logger: the configured ``maque.chain`` logger.
    """
    if not hasattr(setup_chain_logger, "_configured"):
        logger = logging.getLogger("maque.chain")
        logger.setLevel(logging.DEBUG)

        # Don't attach a second handler if one is already present.
        if not logger.handlers:

            class ChainFormatter(logging.Formatter):
                """Render records as ``HH:MM:SS | message`` (plus traceback, if any)."""

                def format(self, record):
                    # Stamp the time the record was created, not when it is formatted.
                    timestamp = datetime.fromtimestamp(record.created).strftime(
                        "%H:%M:%S"
                    )
                    text = f"{timestamp} | {record.getMessage()}"
                    if record.exc_info:
                        # Keep tracebacks instead of silently dropping them.
                        text = f"{text}\n{self.formatException(record.exc_info)}"
                    return text

            handler = logging.StreamHandler(sys.stderr)
            handler.setFormatter(ChainFormatter())
            logger.addHandler(handler)

            # Don't propagate to the root logger (avoids duplicate output).
            logger.propagate = False

        setup_chain_logger._configured = True
        setup_chain_logger._logger = logger

    return setup_chain_logger._logger


# Module-level logger instance (a plain logging.Logger at this point).
chain_logger = setup_chain_logger()


# Wrapper that exposes a loguru-compatible method set (including ``success``).
class ChainLoggerWrapper:
    """Adapter giving a stdlib logger the loguru-style methods used in this module.

    Each method forwards a single, already-formatted message to the wrapped
    logger. ``success`` has no stdlib level and is emitted at INFO.
    """

    def __init__(self, logger):
        self._logger = logger

    def info(self, message):
        self._logger.info(message)

    def debug(self, message):
        self._logger.debug(message)

    def warning(self, message):
        self._logger.warning(message)

    def error(self, message):
        self._logger.error(message)

    def success(self, message):
        # stdlib logging has no SUCCESS level; report it as INFO.
        self._logger.info(message)


# Module code below logs through the loguru-style wrapper.
chain_logger = ChainLoggerWrapper(chain_logger)
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
class StepStatus(Enum):
    """Lifecycle state of a single step execution."""

    # the step is in progress
    RUNNING = "running"
    # the step finished successfully
    COMPLETED = "completed"
    # the step raised an error
    FAILED = "failed"
    # the step exceeded its time limit
    TIMEOUT = "timeout"
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
class ChainStatus(Enum):
    """Lifecycle state of a whole chain execution."""

    # the chain is in progress
    RUNNING = "running"
    # the chain finished
    COMPLETED = "completed"
    # the chain failed
    FAILED = "failed"
    # the chain exceeded its time limit
    TIMEOUT = "timeout"
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
@dataclass
class ExecutionConfig:
    """
    Settings that control how a chain is executed.

    Attributes:
        step_timeout: Per-step timeout in seconds; None disables it.
        chain_timeout: Whole-chain timeout in seconds; None disables it.
        max_retries: Number of extra attempts allowed for a failing step.
        retry_delay: Seconds to wait between attempts.
        enable_monitoring: Whether the monitor reports events.
        log_level: Monitor verbosity ("DEBUG", "INFO", "WARNING", "ERROR").
        enable_progress: Whether step progress messages are printed.
    """

    step_timeout: Optional[float] = None  # seconds; None = unlimited
    chain_timeout: Optional[float] = None  # seconds; None = unlimited
    max_retries: int = 0
    retry_delay: float = 1.0  # seconds
    enable_monitoring: bool = True
    log_level: str = "WARNING"
    enable_progress: bool = False
|
|
231
|
+
|
|
232
|
+
|
|
233
|
+
@dataclass
class StepExecutionInfo:
    """
    Runtime record for one step execution.

    Attributes:
        step_name: Name of the step.
        status: Current execution status.
        start_time: Start timestamp.
        end_time: End timestamp.
        execution_time: Duration of the step.
        retry_count: Number of retries performed.
        error: Error message, if any.
    """

    step_name: str
    status: StepStatus = StepStatus.RUNNING
    start_time: Optional[float] = None
    end_time: Optional[float] = None
    execution_time: Optional[float] = None
    retry_count: int = 0
    error: Optional[str] = None
|
|
255
|
+
|
|
256
|
+
|
|
257
|
+
@dataclass
class ChainExecutionInfo:
    """
    Runtime record for one chain execution.

    Attributes:
        chain_id: Identifier of the chain run.
        status: Current execution status.
        start_time: Start timestamp.
        end_time: End timestamp.
        total_execution_time: Total duration of the chain.
        steps_info: Per-step execution records.
        completed_steps: Number of completed steps.
        error: Error message, if any.
    """

    chain_id: str
    status: ChainStatus = ChainStatus.RUNNING
    start_time: Optional[float] = None
    end_time: Optional[float] = None
    total_execution_time: Optional[float] = None
    steps_info: List[StepExecutionInfo] = field(default_factory=list)  # fresh list per instance
    completed_steps: int = 0
    error: Optional[str] = None
|
|
281
|
+
|
|
282
|
+
|
|
283
|
+
class ChainMonitor(ABC):
    """Abstract interface for observers of chain and step lifecycle events."""

    @abstractmethod
    async def on_chain_start(self, chain_info: ChainExecutionInfo) -> None:
        """Called when a chain starts executing."""

    @abstractmethod
    async def on_chain_end(self, chain_info: ChainExecutionInfo) -> None:
        """Called when a chain finishes executing."""

    @abstractmethod
    async def on_step_start(
        self, step_info: StepExecutionInfo, chain_info: ChainExecutionInfo
    ) -> None:
        """Called when a step starts executing."""

    @abstractmethod
    async def on_step_end(
        self, step_info: StepExecutionInfo, chain_info: ChainExecutionInfo
    ) -> None:
        """Called when a step finishes executing."""

    @abstractmethod
    async def on_error(self, error: Exception, chain_info: ChainExecutionInfo) -> None:
        """Called when an error occurs."""

    @abstractmethod
    async def on_timeout(
        self, timeout_type: str, chain_info: ChainExecutionInfo
    ) -> None:
        """Called when a timeout occurs (``timeout_type`` names the scope)."""
|
|
321
|
+
|
|
322
|
+
|
|
323
|
+
class DefaultChainMonitor(ChainMonitor):
    """Default monitor that reports chain and step events through ``chain_logger``."""

    def __init__(self, config: ExecutionConfig):
        self.config = config
        # Numeric ranks used to compare a message level against config.log_level.
        self.log_levels = {"DEBUG": 0, "INFO": 1, "WARNING": 2, "ERROR": 3}
        self.current_level = self.log_levels.get(config.log_level, 1)

    def _get_chain_prefix(self, chain_info: ChainExecutionInfo) -> str:
        """Build a short tag from the last 4 characters of the id's final segment."""
        tail = chain_info.chain_id.rsplit("_", 1)[-1]
        return f"[链条{tail[-4:]}] "

    def _should_log(self, level: str) -> bool:
        """Return True if ``level`` is at or above the configured threshold (unknown = INFO)."""
        return self.current_level <= self.log_levels.get(level, 1)

    def _log(
        self, level: str, message: str, chain_info: Optional[ChainExecutionInfo] = None
    ) -> None:
        """Emit ``message`` at ``level``, tagged with the chain prefix when available."""
        if not self._should_log(level):
            return

        # The chain tag lets interleaved logs from different chains be told apart.
        text = f"{self._get_chain_prefix(chain_info)}{message}" if chain_info else message

        emit = {
            "DEBUG": chain_logger.debug,
            "INFO": chain_logger.info,
            "WARNING": chain_logger.warning,
            "ERROR": chain_logger.error,
        }.get(level, chain_logger.info)
        emit(text)

    async def on_chain_start(self, chain_info: ChainExecutionInfo) -> None:
        if not self.config.enable_monitoring:
            return
        self._log("INFO", "链条开始执行", chain_info)

    async def on_chain_end(self, chain_info: ChainExecutionInfo) -> None:
        if not self.config.enable_monitoring:
            return
        parts = [f"链条执行结束 - 状态: {chain_info.status.value}"]
        if chain_info.total_execution_time:
            parts.append(f", 总耗时: {chain_info.total_execution_time:.2f}秒")
        parts.append(f", 完成步骤: {chain_info.completed_steps}")
        self._log("INFO", "".join(parts), chain_info)

    async def on_step_start(
        self, step_info: StepExecutionInfo, chain_info: ChainExecutionInfo
    ) -> None:
        if not self.config.enable_monitoring:
            return
        step_number = chain_info.completed_steps + 1
        self._log(
            "DEBUG", f"步骤 {step_info.step_name} 开始执行 ({step_number})", chain_info
        )

        if self.config.enable_progress:
            # Progress lines bypass the level filter and always carry the chain tag.
            prefix = self._get_chain_prefix(chain_info)
            chain_logger.info(
                f"{prefix}执行进度: 步骤 {step_number} - {step_info.step_name}"
            )

    async def on_step_end(
        self, step_info: StepExecutionInfo, chain_info: ChainExecutionInfo
    ) -> None:
        if not self.config.enable_monitoring:
            return
        summary = f"步骤 {step_info.step_name} 执行完成 - 状态: {step_info.status.value}"
        if step_info.execution_time:
            summary += f", 耗时: {step_info.execution_time:.2f}秒"
        if step_info.retry_count > 0:
            summary += f", 重试次数: {step_info.retry_count}"
        self._log("DEBUG", summary, chain_info)

    async def on_error(self, error: Exception, chain_info: ChainExecutionInfo) -> None:
        if not self.config.enable_monitoring:
            return
        self._log("ERROR", f"发生错误: {error!s}", chain_info)

    async def on_timeout(
        self, timeout_type: str, chain_info: ChainExecutionInfo
    ) -> None:
        if not self.config.enable_monitoring:
            return
        self._log("WARNING", f"{timeout_type}超时", chain_info)
|
|
411
|
+
|
|
412
|
+
|
|
413
|
+
class ExecutionController:
    """Evaluates execution limits described by an ``ExecutionConfig``."""

    def __init__(self, config: ExecutionConfig):
        self.config = config

    async def check_timeout(self, start_time: float, timeout: Optional[float]) -> bool:
        """Return True when more than ``timeout`` seconds have passed since ``start_time``.

        ``start_time`` is compared against ``time.time()``; a ``None`` timeout
        never expires.
        """
        return timeout is not None and (time.time() - start_time) > timeout
|
|
424
|
+
|
|
425
|
+
|
|
426
|
+
@dataclass
class StepResult:
    """
    Outcome of one executed step.

    Attributes:
        step_name: Name of the step.
        messages: Messages sent to the LLM.
        response: LLM response text.
        model_params: Model parameters used for the call.
        execution_time: Duration in seconds.
        status: Execution status.
        retry_count: Number of retries.
        error: Error message, if any.
    """

    step_name: str
    messages: List[Dict[str, Any]]
    response: str
    model_params: Dict[str, Any] = field(default_factory=dict)  # fresh dict per instance
    execution_time: Optional[float] = None
    status: StepStatus = StepStatus.COMPLETED
    retry_count: int = 0
    error: Optional[str] = None
|
|
450
|
+
|
|
451
|
+
|
|
452
|
+
@dataclass
class Context:
    """
    Execution context shared across the steps of a chain.

    Attributes:
        query: Optional initial user query (for general-purpose use).
        history: Results of every executed step, in order.
        custom_data: Arbitrary extra data keyed by name.
        execution_info: Chain execution record (used for monitoring).
    """

    history: List[StepResult] = field(default_factory=list)
    query: Optional[str] = None
    custom_data: Dict[str, Any] = field(default_factory=dict)
    execution_info: Optional[ChainExecutionInfo] = None

    def get_last_response(self) -> Optional[str]:
        """Return the most recent step's response, or None if no step has run."""
        if not self.history:
            return None
        return self.history[-1].response

    def get_response_by_step(self, step_name: str) -> Optional[str]:
        """Return the response of the first history entry named ``step_name``."""
        return next(
            (entry.response for entry in self.history if entry.step_name == step_name),
            None,
        )

    def get_step_count(self) -> int:
        """Return how many steps have been executed."""
        return len(self.history)

    def add_custom_data(self, key: str, value: Any) -> None:
        """Store ``value`` under ``key`` in ``custom_data``."""
        self.custom_data[key] = value

    def get_custom_data(self, key: str, default: Any = None) -> Any:
        """Return ``custom_data[key]``, or ``default`` when absent."""
        return self.custom_data.get(key, default)

    def get_execution_summary(self) -> Dict[str, Any]:
        """Summarize step count, total time, retries, failures and success rate."""
        steps = self.history
        completed = sum(1 for s in steps if s.status == StepStatus.COMPLETED)
        return {
            "total_steps": len(steps),
            "total_execution_time": sum(s.execution_time or 0 for s in steps),
            "total_retries": sum(s.retry_count for s in steps),
            "failed_steps": [s.step_name for s in steps if s.status == StepStatus.FAILED],
            "success_rate": completed / len(steps) if steps else 0,
        }
|
|
512
|
+
|
|
513
|
+
|
|
514
|
+
@dataclass
class Step:
    """
    One step of a chain of thought.

    Attributes:
        name: Unique step name.
        prepare_messages_fn: Callable taking the Context and returning the
            message list (List[Dict]) for the LLM call.
        get_next_step_fn: Callable taking this step's response (str) and the
            full Context, returning the next step's name, or None to stop.
        model_params: Model parameters for the LLM call, e.g. model, temperature.
    """

    name: str
    prepare_messages_fn: Callable[[Context], List[Dict[str, Any]]]
    get_next_step_fn: Callable[[str, Context], Optional[str]]
    model_params: Dict[str, Any] = field(default_factory=dict)  # fresh dict per instance
|
|
530
|
+
|
|
531
|
+
|
|
532
|
+
@dataclass
class LinearStep:
    """
    One step of a linear chain (simplified form of ``Step`` with no routing).

    Attributes:
        name: Unique step name.
        prepare_messages_fn: Callable taking the Context and returning the
            message list (List[Dict]) for the LLM call.
        model_params: Model parameters for the LLM call, e.g. model, temperature.
    """

    name: str
    prepare_messages_fn: Callable[[Context], List[Dict[str, Any]]]
    model_params: Dict[str, Any] = field(default_factory=dict)  # fresh dict per instance
|
|
546
|
+
|
|
547
|
+
|
|
548
|
+
class ChainOfThoughtClient:
|
|
549
|
+
"""
|
|
550
|
+
一个客户端,用于执行由多个步骤组成的思想链(Chain of Thought)。
|
|
551
|
+
它允许根据一个模型调用的结果动态决定下一个调用的模型和内容。
|
|
552
|
+
"""
|
|
553
|
+
|
|
554
|
+
def __init__(
    self,
    openai_client: OpenAIClient,
    execution_config: Optional[ExecutionConfig] = None,
):
    """
    Create a chain-of-thought client.

    Args:
        openai_client: OpenAIClient instance that performs the underlying LLM calls.
        execution_config: Execution settings; a default ExecutionConfig is used when None.
    """
    config = execution_config or ExecutionConfig()

    self.openai_client = openai_client
    self.steps: Dict[str, Step] = {}
    self.execution_config = config
    self.monitor: ChainMonitor = DefaultChainMonitor(config)
    self._chain_counter = 0  # incremented by _generate_chain_id
|
|
571
|
+
|
|
572
|
+
def set_monitor(self, monitor: ChainMonitor) -> None:
    """Replace the monitor that receives chain and step lifecycle events."""
    self.monitor = monitor
|
|
575
|
+
|
|
576
|
+
def add_step(self, step: Step):
    """
    Register a single step under its name.

    Args:
        step: The Step instance to register.

    Raises:
        ValueError: if a step with the same name is already registered.
    """
    name = step.name
    if name in self.steps:
        raise ValueError(f"步骤 '{name}' 已存在。请确保每个步骤名称唯一。")
    self.steps[name] = step
|
|
586
|
+
|
|
587
|
+
def add_steps(self, steps: List[Step]):
    """
    Register several steps at once, all-or-nothing.

    Every name is checked against the already-registered steps and against
    the other entries in ``steps`` before anything is registered. A
    duplicate therefore no longer leaves the client holding only the part of
    the batch that preceded it.

    Args:
        steps: Step instances to register (any iterable is accepted).

    Raises:
        ValueError: if any step name is already registered or repeated in
            ``steps``. The message names the first offending step.
    """
    # Materialize once so generators/iterators can be validated and then registered.
    pending = list(steps)

    seen = set()
    for step in pending:
        if step.name in self.steps or step.name in seen:
            raise ValueError(f"步骤 '{step.name}' 已存在。请确保每个步骤名称唯一。")
        seen.add(step.name)

    for step in pending:
        self.add_step(step)
|
|
596
|
+
|
|
597
|
+
def create_linear_chain(
    self, linear_steps: List[LinearStep], chain_name: str = "linear_chain"
):
    """
    Register a linear chain whose steps run one after another.

    Step ``i`` is registered as ``f"{chain_name}_{i}"`` and hands off to
    ``f"{chain_name}_{i + 1}"``; the last step ends the chain. All generated
    names are checked before any registration. A clash (for example, reusing
    ``chain_name``) therefore raises without leaving a half-registered chain
    behind.

    Args:
        linear_steps: LinearStep instances in execution order.
        chain_name: Prefix used to name the generated steps.

    Returns:
        The name of the first step (``f"{chain_name}_0"``).

    Raises:
        ValueError: if ``linear_steps`` is empty or a generated step name is
            already registered.
    """
    if not linear_steps:
        raise ValueError("线性链条至少需要一个步骤。")

    total = len(linear_steps)
    names = [f"{chain_name}_{i}" for i in range(total)]

    # Validate up front so a conflict leaves the registry untouched.
    for name in names:
        if name in self.steps:
            raise ValueError(f"步骤 '{name}' 已存在。请确保每个步骤名称唯一。")

    def create_next_step_fn(current_index: int, total_steps: int):
        """Build the routing function for the step at ``current_index``."""

        # Index and total are bound as parameters, avoiding late-binding closures.
        def next_step_fn(response: str, context: Context) -> Optional[str]:
            if current_index < total_steps - 1:
                return f"{chain_name}_{current_index + 1}"
            return None  # end of the chain

        return next_step_fn

    # Convert each LinearStep into a full Step and register it.
    for i, linear_step in enumerate(linear_steps):
        self.add_step(
            Step(
                name=names[i],
                prepare_messages_fn=linear_step.prepare_messages_fn,
                get_next_step_fn=create_next_step_fn(i, total),
                model_params=linear_step.model_params,
            )
        )

    return names[0]
|
|
633
|
+
|
|
634
|
+
def create_context(self, initial_data: Optional[Dict[str, Any]] = None) -> Context:
    """
    Build a new Context.

    Args:
        initial_data: Optional mapping. Its ``"query"`` entry (if any) becomes
            ``Context.query``; every other entry goes into ``custom_data``.

    Returns:
        The newly created Context.
    """
    if initial_data is None:
        return Context()

    extras = dict(initial_data)
    query = extras.pop("query", None)
    return Context(query=query, custom_data=extras)
|
|
654
|
+
|
|
655
|
+
def _generate_chain_id(self) -> str:
    """Return a new id of the form ``chain_<counter>_<epoch milliseconds>``."""
    self._chain_counter += 1
    # The millisecond timestamp is included to help keep ids unique.
    millis = int(time.time() * 1000)
    return "chain_{}_{}".format(self._chain_counter, millis)
|
|
661
|
+
|
|
662
|
+
async def _execute_step_with_retry(
    self,
    step: Step,
    context: Context,
    controller: ExecutionController,
    step_info: StepExecutionInfo,
    chain_info: ChainExecutionInfo,
    show_step_details: bool = False,
) -> Optional[str]:
    """Execute a single step, retrying failed attempts.

    Up to ``execution_config.max_retries + 1`` attempts are made. Each
    attempt does the following:
    - rebuilds the messages with ``step.prepare_messages_fn(context)``;
    - calls ``self.openai_client.chat_completions``, wrapped in
      ``asyncio.wait_for`` when ``step_timeout`` is set;
    - requires a ``str`` response.

    ``step_info`` is updated in place (retry_count, status, execution_time,
    error). Timeouts and errors are reported to ``self.monitor``, and
    ``retry_delay`` seconds are slept between attempts.

    Args:
        step: The step to execute.
        context: Chain context passed to ``prepare_messages_fn``.
        controller: Execution controller (not used by this method).
        step_info: Execution record for this step; mutated in place.
        chain_info: Execution record of the enclosing chain.
        show_step_details: Log input messages, response text and timing.

    Returns:
        The response text of the first successful attempt, or ``None`` if no
        attempt was made (only possible when ``max_retries`` is negative).

    Raises:
        TimeoutError: if the last attempt timed out.
        Exception: whatever the last failed attempt raised otherwise.
    """
    config = self.execution_config

    # Short chain tag for detail logs, e.g. "[链条1234] " (computed once).
    chain_short_id = chain_info.chain_id.split("_")[-1][-4:]
    chain_prefix = f"[链条{chain_short_id}] "

    def _preview(value: Any, limit: int) -> str:
        # Message content is not guaranteed to be a string (it may be None or
        # a list of content parts). Render it as text so detail logging can
        # never turn a good LLM call into a failed step.
        if value is None:
            text = ""
        elif isinstance(value, str):
            text = value
        else:
            text = str(value)
        return f"{text[:limit]}{'...' if len(text) > limit else ''}"

    last_error: Optional[Exception] = None

    for attempt in range(config.max_retries + 1):
        step_info.retry_count = attempt

        try:
            messages = step.prepare_messages_fn(context)

            # Detail output: input side.
            if show_step_details:
                chain_logger.info(
                    f"{chain_prefix}\n📝 步骤 '{step_info.step_name}' 输入消息:"
                )
                for i, msg in enumerate(messages):
                    role = msg.get("role", "unknown")
                    content = _preview(msg.get("content", ""), 100)
                    chain_logger.info(f"{chain_prefix} {i + 1}. [{role}]: {content}")
                chain_logger.info(f"{chain_prefix}🔧 模型参数: {step.model_params}")
                if attempt > 0:
                    chain_logger.warning(f"{chain_prefix}🔄 重试第 {attempt} 次")

            start_time = time.time()
            llm_task = self.openai_client.chat_completions(
                messages=messages,
                preprocess_msg=True,
                show_progress=False,  # the per-call LLM progress bar is always off
                **step.model_params,
            )
            if config.step_timeout:
                response_content = await asyncio.wait_for(
                    llm_task, timeout=config.step_timeout
                )
            else:
                response_content = await llm_task

            execution_time = time.time() - start_time
            step_info.execution_time = execution_time

            if not isinstance(response_content, str):  # also rejects None
                raise ValueError("LLM调用返回空响应")

            # Detail output: response side.
            if show_step_details:
                chain_logger.success(
                    f"{chain_prefix}✅ 步骤 '{step_info.step_name}' 输出响应:"
                )
                chain_logger.info(
                    f"{chain_prefix} 📄 响应内容: {_preview(response_content, 200)}"
                )
                chain_logger.info(f"{chain_prefix} ⏱️ 执行时间: {execution_time:.3f}秒")
                if attempt > 0:
                    chain_logger.success(f"{chain_prefix} 🔄 重试成功")

            step_info.status = StepStatus.COMPLETED
            # A successful retry must not keep the error of an earlier attempt.
            step_info.error = None
            return response_content

        except asyncio.TimeoutError:
            step_info.status = StepStatus.TIMEOUT
            step_info.error = f"步骤执行超时({config.step_timeout}秒)"
            if show_step_details:
                chain_logger.error(
                    f"{chain_prefix}⏰ 步骤 '{step_info.step_name}' 执行超时"
                )
                chain_logger.warning(
                    f"{chain_prefix} ⚠️ 超时时间: {config.step_timeout}秒"
                )
            await self.monitor.on_timeout("step", chain_info)
            last_error = TimeoutError(step_info.error)

        except Exception as e:
            step_info.status = StepStatus.FAILED
            step_info.error = str(e)
            if show_step_details:
                chain_logger.error(
                    f"{chain_prefix}❌ 步骤 '{step_info.step_name}' 执行失败"
                )
                chain_logger.error(f"{chain_prefix} 🐛 错误类型: {type(e).__name__}")
                chain_logger.error(f"{chain_prefix} 📝 错误信息: {str(e)}")
                if attempt < config.max_retries:
                    chain_logger.warning(
                        f"{chain_prefix} 🔄 将在 {config.retry_delay}秒后重试..."
                    )
            last_error = e
            await self.monitor.on_error(e, chain_info)

        # Wait before the next attempt unless this was the last one.
        if attempt < config.max_retries:
            await asyncio.sleep(config.retry_delay)

    # Every attempt failed: surface the last error.
    if last_error is not None:
        raise last_error

    return None
|
|
788
|
+
|
|
789
|
+
async def execute_chain(
    self,
    initial_step_name: str,
    initial_context: Optional[Union[Dict[str, Any], Context]] = None,
    show_step_details: bool = False,
) -> Context:
    """
    Asynchronously execute a complete chain of thought.

    Starting at ``initial_step_name``, each registered step is run through
    ``_execute_step_with_retry``. Its result is appended to
    ``context.history``, and the step's ``get_next_step_fn`` picks the next
    step. The chain stops in four cases:

    - the next step name is falsy;
    - the configured chain timeout is exceeded (checked between steps);
    - a step raises;
    - a step produces no result.

    Args:
        initial_step_name: Name of the first step to run.
        initial_context: Initial context for the first step. A dict is
            converted with ``create_context``, a ``Context`` is used as-is,
            anything else yields an empty ``Context``.
        show_step_details: Whether to log per-step details (input messages,
            output response, execution time, ...).

    Returns:
        The final context holding the history of every executed step. Its
        ``execution_info`` records the chain status and timing.

    Raises:
        ValueError: If ``initial_step_name`` or a step name produced during
            execution is not registered. Exceptions raised while executing a
            step are not re-raised; they mark the chain as FAILED instead.
    """
    if initial_step_name not in self.steps:
        raise ValueError(f"起始步骤 '{initial_step_name}' 未注册。")

    # Normalise the initial context into a Context instance.
    if isinstance(initial_context, Context):
        context = initial_context
    elif isinstance(initial_context, dict):
        context = self.create_context(initial_context)
    else:
        context = Context()

    # Create the execution bookkeeping and the controller.
    chain_id = self._generate_chain_id()
    chain_info = ChainExecutionInfo(
        chain_id=chain_id, status=ChainStatus.RUNNING, start_time=time.time()
    )
    context.execution_info = chain_info

    controller = ExecutionController(self.execution_config)

    try:
        await self.monitor.on_chain_start(chain_info)

        current_step_name: Optional[str] = initial_step_name
        chain_start_time = time.time()

        while current_step_name:
            # The chain-level timeout is only checked between steps.
            if self.execution_config.chain_timeout:
                if await controller.check_timeout(
                    chain_start_time, self.execution_config.chain_timeout
                ):
                    chain_info.status = ChainStatus.TIMEOUT
                    await self.monitor.on_timeout("chain", chain_info)
                    break

            if current_step_name not in self.steps:
                raise ValueError(
                    f"执行过程中发现未注册的步骤 '{current_step_name}'。"
                )

            step = self.steps[current_step_name]

            # Per-step execution record.
            step_info = StepExecutionInfo(
                step_name=current_step_name,
                status=StepStatus.RUNNING,
                start_time=time.time(),
            )

            chain_info.steps_info.append(step_info)
            await self.monitor.on_step_start(step_info, chain_info)

            try:
                # Execute the step (retry logic lives in the helper).
                response_content = await self._execute_step_with_retry(
                    step,
                    context,
                    controller,
                    step_info,
                    chain_info,
                    show_step_details,
                )

                step_info.end_time = time.time()
                step_info.execution_time = step_info.end_time - (
                    step_info.start_time or 0
                )

                if response_content is None:
                    # No result means the step failed or was cancelled.
                    # Without this the chain would fall through to the
                    # COMPLETED status below and be counted as a success.
                    # Statuses already moved off RUNNING are left untouched.
                    if step_info.status == StepStatus.RUNNING:
                        step_info.status = StepStatus.FAILED
                    if chain_info.status == ChainStatus.RUNNING:
                        chain_info.status = ChainStatus.FAILED
                        chain_info.error = (
                            step_info.error
                            or f"步骤 '{current_step_name}' 未返回结果"
                        )
                    break

                # Record the step result.
                step_result = StepResult(
                    step_name=current_step_name,
                    messages=step.prepare_messages_fn(context),
                    response=response_content,
                    model_params=step.model_params,
                    execution_time=step_info.execution_time,
                    status=step_info.status,
                    retry_count=step_info.retry_count,
                    error=step_info.error,
                )
                context.history.append(step_result)

                chain_info.completed_steps += 1
                await self.monitor.on_step_end(step_info, chain_info)

                # Decide which step runs next (falsy ends the chain).
                current_step_name = step.get_next_step_fn(response_content, context)

            except Exception as e:
                # The retry helper sets TIMEOUT before re-raising a timeout.
                # Keep that status, and classify everything else as FAILED.
                if step_info.status != StepStatus.TIMEOUT:
                    step_info.status = StepStatus.FAILED
                step_info.error = str(e)
                step_info.end_time = time.time()
                step_info.execution_time = step_info.end_time - (
                    step_info.start_time or 0
                )

                await self.monitor.on_step_end(step_info, chain_info)
                await self.monitor.on_error(e, chain_info)

                chain_info.status = ChainStatus.FAILED
                chain_info.error = str(e)
                break

        # Finalise chain timing and status.
        chain_info.end_time = time.time()
        chain_info.total_execution_time = (
            chain_info.end_time - chain_info.start_time
        )

        if chain_info.status == ChainStatus.RUNNING:
            chain_info.status = ChainStatus.COMPLETED

        await self.monitor.on_chain_end(chain_info)

    except Exception as e:
        chain_info.status = ChainStatus.FAILED
        chain_info.error = str(e)
        chain_info.end_time = time.time()
        if chain_info.start_time:
            chain_info.total_execution_time = (
                chain_info.end_time - chain_info.start_time
            )

        await self.monitor.on_error(e, chain_info)
        await self.monitor.on_chain_end(chain_info)
        raise

    return context
|
|
935
|
+
|
|
936
|
+
async def execute_chains_batch(
    self,
    chain_requests: List[Dict[str, Any]],
    show_step_details: bool = False,
    show_progress: bool = True,
) -> List[Context]:
    """
    Execute several chains of thought concurrently.

    All requests are scheduled at once. Results are consumed with
    ``asyncio.as_completed`` so the progress tracker updates as each chain
    finishes, while the returned list keeps the order of ``chain_requests``.

    Args:
        chain_requests: List of request dicts. Each has an
            ``'initial_step_name'`` key and an optional ``'initial_context'``
            key, e.g.
            ``[{'initial_step_name': 'step1', 'initial_context': {'query': 'hello'}}]``.
        show_step_details: Whether each chain logs per-step details (input
            messages, output response, execution time, ...).
        show_progress: Whether to show the batch progress bar and the final
            summary.

    Returns:
        One ``Context`` per request, in request order. A request whose chain
        raised is represented by a fresh ``Context`` whose ``execution_info``
        has status FAILED and the error message.
    """
    total_chains = len(chain_requests)
    if total_chains == 0:
        return []

    completed_count = 0

    # Progress tracker (it may decide on its own not to display anything,
    # so both `show_progress` and `progress_tracker.show_progress` are checked).
    progress_tracker = ChainProgressTracker(total_chains, show_progress)

    if show_progress and progress_tracker.show_progress:
        safe_chain_print(
            f"[bold green]🚀 开始执行 {total_chains} 个链条的批处理...[/bold green]"
        )
        safe_chain_print(f"[dim]{'=' * 80}[/dim]")

    def _succeeded(ctx: Optional[Context]) -> bool:
        """Return True when *ctx* holds a chain that finished as COMPLETED.

        Shared by the per-chain progress update and the final statistics so
        both use the same enum comparison.
        """
        return bool(
            ctx
            and ctx.execution_info
            and ctx.execution_info.status == ChainStatus.COMPLETED
        )

    async def wrapped_execute_chain(request_index: int, request: Dict[str, Any]):
        """Run one request; convert any exception into a FAILED context."""
        try:
            result = await self.execute_chain(
                initial_step_name=request["initial_step_name"],
                initial_context=request.get("initial_context"),
                show_step_details=show_step_details,
            )
            return request_index, result, None
        except Exception as e:
            # Build a placeholder context describing the failure.
            error_context = Context()
            error_context.execution_info = ChainExecutionInfo(
                chain_id=self._generate_chain_id(),
                status=ChainStatus.FAILED,
                error=str(e),
            )
            return request_index, error_context, e

    tasks = [
        wrapped_execute_chain(i, request) for i, request in enumerate(chain_requests)
    ]

    # Pre-allocated so results can be stored by request index as they finish.
    final_results: List[Optional[Context]] = [None] * total_chains

    try:
        # as_completed yields chains in completion order for live progress.
        for future in asyncio.as_completed(tasks):
            request_index, result, _error = await future
            final_results[request_index] = result
            completed_count += 1

            execution_time = 0.0
            if result and result.execution_info:
                execution_time = result.execution_info.total_execution_time or 0.0

            progress_tracker.update(
                chain_index=request_index,
                success=_succeeded(result),
                execution_time=execution_time,
            )

            # With monitoring enabled, log batch progress (only when the
            # progress bar is not being shown).
            if self.execution_config.enable_monitoring and not (
                show_progress and progress_tracker.show_progress
            ):
                progress_info = progress_tracker.get_progress_info()
                chain_logger.info(
                    f"批处理进度: {completed_count}/{total_chains} ({progress_info['progress_percent']:.1f}%) - "
                    f"已用时: {progress_info['elapsed_time']:.2f}秒, "
                    f"预计剩余: {progress_info['remaining_time']:.2f}秒"
                )

    finally:
        progress_tracker.finish()

    # Final statistics.
    progress_info = progress_tracker.get_progress_info()
    successful_chains = sum(1 for result in final_results if _succeeded(result))
    failed_chains = total_chains - successful_chains
    success_rate = successful_chains / total_chains * 100

    if show_progress and progress_tracker.show_progress:
        safe_chain_print(f"[dim]{'=' * 80}[/dim]")

        if chain_console:
            # Rich is available: render the summary as a Panel.
            try:
                result_text = Text()
                result_text.append("批处理执行完成!\n\n", style="bold green")
                result_text.append(
                    f"📊 总计: {total_chains} 个链条\n", style="cyan"
                )
                result_text.append(
                    f"⏱️ 总耗时: {progress_info['elapsed_time']:.2f}秒\n",
                    style="cyan",
                )
                result_text.append(
                    f"✅ 成功: {successful_chains} 个\n", style="green"
                )
                result_text.append(
                    f"❌ 失败: {failed_chains} 个\n",
                    style="red" if failed_chains > 0 else "dim",
                )
                result_text.append(
                    f"📈 成功率: {success_rate:.1f}%\n",
                    style="yellow",
                )
                result_text.append(
                    f"⚡ 平均速率: {progress_info['rate']:.2f} 链/秒", style="blue"
                )

                panel = Panel(
                    result_text,
                    title="[bold blue]Chain of Thought 执行结果[/bold blue]",
                    border_style="green" if failed_chains == 0 else "yellow",
                )
                chain_console.print(panel)
            except Exception:
                # Fall back to plain markup output if Panel rendering fails.
                safe_chain_print("[bold green]批处理执行完成![/bold green]")
                safe_chain_print(f"[cyan]📊 总计: {total_chains} 个链条[/cyan]")
                safe_chain_print(
                    f"[cyan]⏱️ 总耗时: {progress_info['elapsed_time']:.2f}秒[/cyan]"
                )
                safe_chain_print(f"[green]✅ 成功: {successful_chains} 个[/green]")
                safe_chain_print(
                    f"[red]❌ 失败: {failed_chains} 个[/red]"
                    if failed_chains > 0
                    else f"[dim]❌ 失败: {failed_chains} 个[/dim]"
                )
                safe_chain_print(
                    f"[yellow]📈 成功率: {success_rate:.1f}%[/yellow]"
                )
                safe_chain_print(
                    f"[blue]⚡ 平均速率: {progress_info['rate']:.2f} 链/秒[/blue]"
                )
        else:
            # No Rich console available: plain-text summary.
            safe_chain_print("批处理执行完成!")
            safe_chain_print(f"总计: {total_chains} 个链条")
            safe_chain_print(f"总耗时: {progress_info['elapsed_time']:.2f}秒")
            safe_chain_print(f"成功: {successful_chains} 个")
            safe_chain_print(f"失败: {failed_chains} 个")
            safe_chain_print(
                f"成功率: {success_rate:.1f}%"
            )
            safe_chain_print(f"平均速率: {progress_info['rate']:.2f} 链/秒")

    if self.execution_config.enable_monitoring and not (
        show_progress and progress_tracker.show_progress
    ):
        chain_logger.info(
            f"批处理完成 - 总耗时: {progress_info['elapsed_time']:.2f}秒, "
            f"成功: {successful_chains}, 失败: {failed_chains}, "
            f"成功率: {success_rate:.1f}%"
        )

    return final_results
|