flexllm 0.3.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flexllm/__init__.py +224 -0
- flexllm/__main__.py +1096 -0
- flexllm/async_api/__init__.py +9 -0
- flexllm/async_api/concurrent_call.py +100 -0
- flexllm/async_api/concurrent_executor.py +1036 -0
- flexllm/async_api/core.py +373 -0
- flexllm/async_api/interface.py +12 -0
- flexllm/async_api/progress.py +277 -0
- flexllm/base_client.py +988 -0
- flexllm/batch_tools/__init__.py +16 -0
- flexllm/batch_tools/folder_processor.py +317 -0
- flexllm/batch_tools/table_processor.py +363 -0
- flexllm/cache/__init__.py +10 -0
- flexllm/cache/response_cache.py +293 -0
- flexllm/chain_of_thought_client.py +1120 -0
- flexllm/claudeclient.py +402 -0
- flexllm/client_pool.py +698 -0
- flexllm/geminiclient.py +563 -0
- flexllm/llm_client.py +523 -0
- flexllm/llm_parser.py +60 -0
- flexllm/mllm_client.py +559 -0
- flexllm/msg_processors/__init__.py +174 -0
- flexllm/msg_processors/image_processor.py +729 -0
- flexllm/msg_processors/image_processor_helper.py +485 -0
- flexllm/msg_processors/messages_processor.py +341 -0
- flexllm/msg_processors/unified_processor.py +1404 -0
- flexllm/openaiclient.py +256 -0
- flexllm/pricing/__init__.py +104 -0
- flexllm/pricing/data.json +1201 -0
- flexllm/pricing/updater.py +223 -0
- flexllm/provider_router.py +213 -0
- flexllm/token_counter.py +270 -0
- flexllm/utils/__init__.py +1 -0
- flexllm/utils/core.py +41 -0
- flexllm-0.3.3.dist-info/METADATA +573 -0
- flexllm-0.3.3.dist-info/RECORD +39 -0
- flexllm-0.3.3.dist-info/WHEEL +4 -0
- flexllm-0.3.3.dist-info/entry_points.txt +3 -0
- flexllm-0.3.3.dist-info/licenses/LICENSE +201 -0
flexllm/client_pool.py
ADDED
|
@@ -0,0 +1,698 @@
|
|
|
1
|
+
"""
|
|
2
|
+
LLMClientPool - 多 Endpoint 客户端池
|
|
3
|
+
|
|
4
|
+
提供多个 LLM endpoint 的负载均衡和故障转移能力,接口与 LLMClient 一致。
|
|
5
|
+
|
|
6
|
+
Example:
|
|
7
|
+
# 方式1:传入 endpoints 配置
|
|
8
|
+
pool = LLMClientPool(
|
|
9
|
+
endpoints=[
|
|
10
|
+
{"base_url": "http://api1.com/v1", "api_key": "key1", "model": "qwen"},
|
|
11
|
+
{"base_url": "http://api2.com/v1", "api_key": "key2", "model": "qwen"},
|
|
12
|
+
],
|
|
13
|
+
load_balance="round_robin",
|
|
14
|
+
fallback=True,
|
|
15
|
+
)
|
|
16
|
+
|
|
17
|
+
# 方式2:传入已有的 clients
|
|
18
|
+
pool = LLMClientPool(
|
|
19
|
+
clients=[client1, client2],
|
|
20
|
+
load_balance="round_robin",
|
|
21
|
+
fallback=True,
|
|
22
|
+
)
|
|
23
|
+
|
|
24
|
+
# 接口与 LLMClient 一致
|
|
25
|
+
result = await pool.chat_completions(messages)
|
|
26
|
+
results = await pool.chat_completions_batch(messages_list)
|
|
27
|
+
"""
|
|
28
|
+
|
|
29
|
+
import asyncio
|
|
30
|
+
from typing import List, Dict, Any, Union, Optional, Literal
|
|
31
|
+
from dataclasses import dataclass
|
|
32
|
+
|
|
33
|
+
from loguru import logger
|
|
34
|
+
|
|
35
|
+
from .llm_client import LLMClient
|
|
36
|
+
from .base_client import ChatCompletionResult
|
|
37
|
+
from .provider_router import ProviderRouter, ProviderConfig, Strategy
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
@dataclass
class EndpointConfig:
    """Configuration for a single LLM endpoint in the pool.

    Attributes are declared as dataclass fields, so an instance can be built
    from a plain dict with ``EndpointConfig(**cfg)``.
    """

    # Base URL of the endpoint, e.g. "http://api1.com/v1".
    base_url: str
    # API key sent to the endpoint; "EMPTY" is the placeholder default.
    api_key: str = "EMPTY"
    # Default model for this endpoint; None means no endpoint-level default.
    model: Optional[str] = None
    # Provider type forwarded to the client constructor.
    provider: Literal["openai", "gemini", "auto"] = "auto"
    # Relative weight forwarded to the router for this endpoint.
    weight: float = 1.0
    # Extra per-endpoint keyword arguments for the client constructor.
    extra: Optional[Dict[str, Any]] = None

    def __post_init__(self):
        """Normalise ``extra`` so every instance owns its own dict."""
        # A None default (instead of a literal {}) avoids sharing one mutable
        # dict between instances; an explicit extra=None is normalised too.
        if self.extra is None:
            self.extra = {}
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
class LLMClientPool:
    """
    Pool of LLM clients spread over several endpoints.

    Features:
    - Load balancing: the strategy is delegated to ``ProviderRouter``
    - Failover: with ``fallback=True`` a failed call is retried on another
      endpoint
    - Health tracking: every call outcome is reported to the router through
      ``mark_success`` / ``mark_failed``
    - Unified interface: the public methods mirror the ones called on the
      wrapped ``LLMClient`` objects

    Attributes:
        _clients: the wrapped clients, one per endpoint
        _endpoints: the ``EndpointConfig`` of every client (same order)
        _router: the ``ProviderRouter`` that picks the next endpoint
        _client_map: ``base_url -> client`` lookup used by ``_get_client``
        _fallback: whether failover is enabled
        _max_fallback_attempts: maximum number of distinct endpoints tried
            per call
    """

    def __init__(
        self,
        endpoints: Optional[List[Union[Dict, EndpointConfig]]] = None,
        clients: Optional[List[LLMClient]] = None,
        load_balance: Strategy = "round_robin",
        fallback: bool = True,
        max_fallback_attempts: Optional[int] = None,
        failure_threshold: int = 3,
        recovery_time: float = 60.0,
        # Shared LLMClient parameters (only used when `endpoints` is given)
        concurrency_limit: int = 10,
        max_qps: int = 1000,
        timeout: int = 120,
        retry_times: int = 3,
        **kwargs,
    ):
        """
        Initialise the client pool.

        Args:
            endpoints: endpoint configurations; each item is a dict or an
                ``EndpointConfig``
            clients: already constructed ``LLMClient`` objects (mutually
                exclusive with ``endpoints``)
            load_balance: load-balancing strategy passed to the router
                - "round_robin"
                - "weighted"
                - "random"
                - "fallback"
            fallback: whether to try another endpoint when one fails
            max_fallback_attempts: maximum number of endpoints tried per
                call; defaults to the number of clients
            failure_threshold: forwarded to ``ProviderRouter``
            recovery_time: forwarded to ``ProviderRouter`` (seconds)
            concurrency_limit: per-client concurrency limit
            max_qps: per-client QPS limit
            timeout: request timeout
            retry_times: retry count
            **kwargs: extra arguments forwarded to every ``LLMClient``

        Raises:
            ValueError: if neither or both of ``endpoints`` / ``clients``
                are provided.
        """
        if not endpoints and not clients:
            raise ValueError("必须提供 endpoints 或 clients")
        if endpoints and clients:
            raise ValueError("endpoints 和 clients 只能二选一")

        self._fallback = fallback
        self._load_balance = load_balance

        if clients:
            # Reuse the given clients and derive their endpoint configs.
            self._clients = clients
            self._endpoints = [
                EndpointConfig(
                    base_url=c._client._base_url,
                    api_key=c._client._api_key or "EMPTY",
                    model=c._model,
                )
                for c in clients
            ]
        else:
            # Build one client per endpoint configuration.
            self._endpoints = []
            self._clients = []

            for ep in endpoints:
                if isinstance(ep, dict):
                    ep = EndpointConfig(**ep)
                self._endpoints.append(ep)

                # Later entries win: shared kwargs override the explicit
                # parameters, and the endpoint's `extra` overrides both.
                client_kwargs = {
                    "provider": ep.provider,
                    "base_url": ep.base_url,
                    "api_key": ep.api_key,
                    "model": ep.model,
                    "concurrency_limit": concurrency_limit,
                    "max_qps": max_qps,
                    "timeout": timeout,
                    "retry_times": retry_times,
                    **kwargs,
                    **(ep.extra or {}),
                }
                self._clients.append(LLMClient(**client_kwargs))

        # Router that decides which endpoint serves the next call.
        provider_configs = [
            ProviderConfig(
                base_url=ep.base_url,
                api_key=ep.api_key,
                weight=ep.weight,
                model=ep.model,
            )
            for ep in self._endpoints
        ]
        self._router = ProviderRouter(
            providers=provider_configs,
            strategy=load_balance,
            failure_threshold=failure_threshold,
            recovery_time=recovery_time,
        )

        # base_url -> client lookup.
        self._client_map = {
            ep.base_url: client for ep, client in zip(self._endpoints, self._clients)
        }
        if len(self._client_map) != len(self._clients):
            # The lookup is keyed by base_url only, so endpoints that share a
            # URL collapse onto the last client registered for it.
            logger.warning(
                "LLMClientPool: several endpoints share the same base_url; "
                "only the last client registered for each URL is used by "
                "single-call routing"
            )

        self._max_fallback_attempts = max_fallback_attempts or len(self._clients)

    def _get_client(self) -> tuple[LLMClient, ProviderConfig]:
        """Ask the router for the next provider and return it with its client."""
        provider = self._router.get_next()
        client = self._client_map[provider.base_url]
        return client, provider

    def _iter_fallback_candidates(self):
        """
        Lazily yield ``(client, provider)`` pairs to try for one call.

        Each provider (identified by ``base_url``) is yielded at most once
        and at most ``_max_fallback_attempts`` pairs are produced. A provider
        the router returns a second time does not use up an attempt; the
        number of router draws is capped so the loop always terminates.

        The router is only consulted when the consumer asks for the next
        pair, i.e. after the previous failure has been reported to it.
        """
        tried = set()
        attempts = 0
        total = len(self._clients)
        draws_left = self._max_fallback_attempts * total

        while (
            attempts < self._max_fallback_attempts
            and len(tried) < total
            and draws_left > 0
        ):
            draws_left -= 1
            client, provider = self._get_client()
            if provider.base_url in tried:
                continue
            tried.add(provider.base_url)
            attempts += 1
            yield client, provider

    async def chat_completions(
        self,
        messages: List[dict],
        model: str = None,
        return_raw: bool = False,
        return_usage: bool = False,
        **kwargs,
    ) -> Union[str, ChatCompletionResult]:
        """
        Single chat completion with failover.

        Args:
            messages: message list
            model: model name; falls back to the provider's configured model
            return_raw: forwarded to the client
            return_usage: forwarded to the client
            **kwargs: forwarded to the client

        Returns:
            Whatever the selected client's ``chat_completions`` returns.

        Raises:
            Exception: the last error seen once every candidate failed, or
                the first error when failover is disabled.
            RuntimeError: if no endpoint could be tried at all.
        """
        last_error = None

        for client, provider in self._iter_fallback_candidates():
            try:
                result = await client.chat_completions(
                    messages=messages,
                    model=model or provider.model,
                    return_raw=return_raw,
                    return_usage=return_usage,
                    **kwargs,
                )

                # An object carrying a non-success `status` is a failure.
                if hasattr(result, 'status') and result.status != 'success':
                    raise RuntimeError(f"请求失败: {getattr(result, 'error', result)}")

                self._router.mark_success(provider)
                return result

            except Exception as e:
                last_error = e
                self._router.mark_failed(provider)
                logger.warning(f"Endpoint {provider.base_url} 失败: {e}")

                if not self._fallback:
                    raise

        raise last_error or RuntimeError("所有 endpoint 都失败了")

    def chat_completions_sync(
        self,
        messages: List[dict],
        model: str = None,
        return_raw: bool = False,
        return_usage: bool = False,
        **kwargs,
    ) -> Union[str, ChatCompletionResult]:
        """
        Synchronous wrapper around ``chat_completions``.

        Uses ``asyncio.run``, so it must not be called from a thread that
        already has a running event loop.
        """
        return asyncio.run(
            self.chat_completions(
                messages=messages,
                model=model,
                return_raw=return_raw,
                return_usage=return_usage,
                **kwargs,
            )
        )

    async def chat_completions_batch(
        self,
        messages_list: List[List[dict]],
        model: str = None,
        return_raw: bool = False,
        return_usage: bool = False,
        show_progress: bool = True,
        return_summary: bool = False,
        output_jsonl: Optional[str] = None,
        flush_interval: float = 1.0,
        distribute: bool = True,
        **kwargs,
    ) -> Union[List[str], List[ChatCompletionResult], tuple]:
        """
        Batch chat completion.

        Args:
            messages_list: list of message lists
            model: model name
            return_raw: forwarded to the client
            return_usage: forwarded to the client
            show_progress: whether to show a progress bar
            return_summary: whether to also return a statistics summary
            output_jsonl: output file path; must end with ".jsonl"
            flush_interval: file flush interval in seconds
            distribute: True spreads the requests over all endpoints;
                False sends the whole batch to one endpoint with failover
            **kwargs: forwarded to the client

        Returns:
            The result list, or ``(results, summary)`` in distributed mode
            when ``return_summary`` is set; in single-endpoint mode whatever
            the client's ``chat_completions_batch`` returns.

        Raises:
            ValueError: if ``output_jsonl`` does not end with ".jsonl".
        """
        if output_jsonl and not output_jsonl.endswith(".jsonl"):
            raise ValueError(f"output_jsonl 必须使用 .jsonl 扩展名,当前: {output_jsonl}")

        # With a single client there is nothing to distribute over.
        if not distribute or len(self._clients) == 1:
            runner = self._batch_with_fallback
        else:
            runner = self._batch_distributed

        return await runner(
            messages_list=messages_list,
            model=model,
            return_raw=return_raw,
            return_usage=return_usage,
            show_progress=show_progress,
            return_summary=return_summary,
            output_jsonl=output_jsonl,
            flush_interval=flush_interval,
            **kwargs,
        )

    async def _batch_with_fallback(
        self,
        messages_list: List[List[dict]],
        model: str = None,
        return_raw: bool = False,
        return_usage: bool = False,
        show_progress: bool = True,
        return_summary: bool = False,
        output_jsonl: Optional[str] = None,
        flush_interval: float = 1.0,
        **kwargs,
    ):
        """
        Run the whole batch on one endpoint, failing over as a whole.

        The batch is handed to a single client's ``chat_completions_batch``;
        if that call raises, the entire batch is retried on the next
        candidate endpoint (when failover is enabled).
        """
        last_error = None

        for client, provider in self._iter_fallback_candidates():
            try:
                result = await client.chat_completions_batch(
                    messages_list=messages_list,
                    model=model or provider.model,
                    return_raw=return_raw,
                    return_usage=return_usage,
                    show_progress=show_progress,
                    return_summary=return_summary,
                    output_jsonl=output_jsonl,
                    flush_interval=flush_interval,
                    **kwargs,
                )
                self._router.mark_success(provider)
                return result

            except Exception as e:
                last_error = e
                self._router.mark_failed(provider)
                logger.warning(f"Endpoint {provider.base_url} 批量调用失败: {e}")

                if not self._fallback:
                    raise

        raise last_error or RuntimeError("所有 endpoint 都失败了")

    async def _batch_distributed(
        self,
        messages_list: List[List[dict]],
        model: str = None,
        return_raw: bool = False,
        return_usage: bool = False,
        show_progress: bool = True,
        return_summary: bool = False,
        output_jsonl: Optional[str] = None,
        flush_interval: float = 1.0,
        **kwargs,
    ):
        """
        Dynamic distribution: many workers pull tasks from one shared queue.

        For every client, as many workers are started as that client's
        ``_concurrency_limit`` (10 if the attribute is missing). All workers
        take tasks from the same queue, so a client that answers faster ends
        up processing more tasks.

        When ``output_jsonl`` is given, finished tasks are appended to that
        file and successful records already in the file are reused instead of
        being requested again.

        NOTE(review): a task that fails is recorded as ``None`` and is not
        retried on another endpoint; a client that fails quickly keeps
        pulling tasks from the queue.

        Returns:
            ``results`` (``None`` for failed tasks), or ``(results, summary)``
            when ``return_summary`` is set.

        Raises:
            ValueError: if the existing output file does not match
                ``messages_list``.
        """
        import json
        import time
        from pathlib import Path
        from tqdm import tqdm

        n = len(messages_list)
        results = [None] * n
        success_count = 0
        failed_count = 0
        cached_count = 0
        start_time = time.time()

        # Resume support: load records that already completed successfully.
        completed_indices = set()
        if output_jsonl:
            output_path = Path(output_jsonl)
            if output_path.exists():
                records = []
                with open(output_path, "r", encoding="utf-8") as f:
                    for line in f:
                        try:
                            record = json.loads(line.strip())
                        except json.JSONDecodeError:
                            continue
                        # Only complete, successful, in-range records count;
                        # anything else is skipped instead of crashing.
                        if not isinstance(record, dict):
                            continue
                        if record.get("status") != "success":
                            continue
                        if "input" not in record or "output" not in record:
                            continue
                        idx = record.get("index")
                        if isinstance(idx, bool) or not isinstance(idx, int):
                            continue
                        if 0 <= idx < n:
                            records.append(record)

                # Spot-check the first and last record against the inputs.
                file_valid = True
                if records:
                    first, last = records[0], records[-1]
                    if first["input"] != messages_list[first["index"]]:
                        file_valid = False
                    elif len(records) > 1 and last["input"] != messages_list[last["index"]]:
                        file_valid = False

                if not file_valid:
                    raise ValueError(
                        f"文件校验失败: {output_jsonl} 中的 input 与当前 messages_list 不匹配。"
                        f"请删除或重命名该文件后重试。"
                    )

                for record in records:
                    idx = record["index"]
                    completed_indices.add(idx)
                    results[idx] = record["output"]
                if completed_indices:
                    logger.info(f"从文件恢复: 已完成 {len(completed_indices)}/{n}")
                    cached_count = len(completed_indices)

        # Shared task queue (already completed tasks are skipped).
        queue = asyncio.Queue()
        for idx, msg in enumerate(messages_list):
            if idx not in completed_indices:
                queue.put_nowait((idx, msg))

        pending_count = queue.qsize()
        if pending_count == 0:
            logger.info("所有任务已完成,无需执行")
            if return_summary:
                return results, {"total": n, "success": n, "failed": 0, "cached": cached_count, "elapsed": 0}
            return results

        logger.info(f"待执行: {pending_count}/{n}")

        pbar = tqdm(total=pending_count, desc="Processing", disable=not show_progress)

        # Output file state; the file itself is opened inside the try below
        # so that the progress bar is closed even if opening fails.
        file_writer = None
        file_buffer = []
        last_flush_time = time.time()

        # Guards the counters, the progress bar and the file buffer.
        lock = asyncio.Lock()

        def flush_to_file():
            """Write the buffered records to the output file."""
            nonlocal file_buffer, last_flush_time
            if not (file_writer and file_buffer):
                return

            # Detach the buffer first so a failure cannot cause the same
            # records to be written again by a later flush.
            pending, file_buffer = file_buffer, []
            lines = []
            for record in pending:
                try:
                    lines.append(json.dumps(record, ensure_ascii=False) + "\n")
                except (TypeError, ValueError) as exc:
                    logger.warning(
                        f"Record {record.get('index')} is not JSON serializable, "
                        f"skipped in {output_jsonl}: {exc}"
                    )
            file_writer.writelines(lines)
            file_writer.flush()
            last_flush_time = time.time()

        async def worker(client_idx: int):
            """One worker: keep taking tasks from the queue until it is empty."""
            nonlocal success_count, failed_count

            client = self._clients[client_idx]
            provider = self._router._providers[client_idx].config
            effective_model = model or provider.model

            while True:
                try:
                    idx, msg = queue.get_nowait()
                except asyncio.QueueEmpty:
                    break

                # Only the request itself decides success or failure; the
                # bookkeeping below must not turn a good answer into a
                # failed task.
                error = None
                result = None
                try:
                    result = await client.chat_completions(
                        messages=msg,
                        model=effective_model,
                        return_raw=return_raw,
                        return_usage=return_usage,
                        **kwargs,
                    )

                    # An object carrying a non-success `status` is a failure.
                    if hasattr(result, 'status') and result.status != 'success':
                        raise RuntimeError(f"请求失败: {getattr(result, 'error', result)}")
                except Exception as e:
                    error = e

                if error is None:
                    results[idx] = result
                    self._router.mark_success(provider)
                    record = {
                        "index": idx,
                        "output": result,
                        "status": "success",
                        "input": msg,
                    }
                else:
                    logger.warning(f"Task {idx} failed on {provider.base_url}: {error}")
                    results[idx] = None
                    self._router.mark_failed(provider)
                    record = {
                        "index": idx,
                        "output": None,
                        "status": "error",
                        "error": str(error),
                        "input": msg,
                    }

                async with lock:
                    if error is None:
                        success_count += 1
                    else:
                        failed_count += 1
                    pbar.update(1)

                    if file_writer:
                        file_buffer.append(record)
                        if time.time() - last_flush_time >= flush_interval:
                            flush_to_file()

        try:
            if output_jsonl:
                file_writer = open(output_jsonl, "a", encoding="utf-8")

            # One worker per concurrency slot of every client.
            workers = []
            for client_idx, client in enumerate(self._clients):
                concurrency = getattr(client._client, '_concurrency_limit', 10)
                for _ in range(concurrency):
                    workers.append(worker(client_idx))

            await asyncio.gather(*workers)

        finally:
            # Write whatever is still buffered, then always release the file
            # and the progress bar, even if that last flush fails.
            try:
                flush_to_file()
            finally:
                if file_writer:
                    file_writer.close()
                pbar.close()

        if return_summary:
            summary = {
                "total": n,
                "success": success_count + cached_count,
                "failed": failed_count,
                "cached": cached_count,
                "elapsed": time.time() - start_time,
            }
            return results, summary

        return results

    def chat_completions_batch_sync(
        self,
        messages_list: List[List[dict]],
        model: str = None,
        return_raw: bool = False,
        return_usage: bool = False,
        show_progress: bool = True,
        return_summary: bool = False,
        output_jsonl: Optional[str] = None,
        flush_interval: float = 1.0,
        distribute: bool = True,
        **kwargs,
    ) -> Union[List[str], List[ChatCompletionResult], tuple]:
        """
        Synchronous wrapper around ``chat_completions_batch``.

        Uses ``asyncio.run``, so it must not be called from a thread that
        already has a running event loop.
        """
        return asyncio.run(
            self.chat_completions_batch(
                messages_list=messages_list,
                model=model,
                return_raw=return_raw,
                return_usage=return_usage,
                show_progress=show_progress,
                return_summary=return_summary,
                output_jsonl=output_jsonl,
                flush_interval=flush_interval,
                distribute=distribute,
                **kwargs,
            )
        )

    async def chat_completions_stream(
        self,
        messages: List[dict],
        model: str = None,
        return_usage: bool = False,
        **kwargs,
    ):
        """
        Streaming chat completion with failover.

        Failover only happens while nothing has been yielded yet. Once a
        chunk was delivered to the consumer, an error is re-raised instead of
        restarting on another endpoint, because a restart would repeat the
        beginning of the answer after the partial output.

        Args:
            messages: message list
            model: model name; falls back to the provider's configured model
            return_usage: forwarded to the client
            **kwargs: forwarded to the client

        Yields:
            The chunks produced by the selected client's
            ``chat_completions_stream``.

        Raises:
            Exception: the stream error, when chunks were already emitted,
                failover is disabled, or every candidate failed.
            RuntimeError: if no endpoint could be tried at all.
        """
        last_error = None

        for client, provider in self._iter_fallback_candidates():
            emitted = False
            try:
                async for chunk in client.chat_completions_stream(
                    messages=messages,
                    model=model or provider.model,
                    return_usage=return_usage,
                    **kwargs,
                ):
                    emitted = True
                    yield chunk
                self._router.mark_success(provider)
                return

            except Exception as e:
                last_error = e
                self._router.mark_failed(provider)
                logger.warning(f"Endpoint {provider.base_url} 流式调用失败: {e}")

                if emitted or not self._fallback:
                    raise

        raise last_error or RuntimeError("所有 endpoint 都失败了")

    @property
    def stats(self) -> dict:
        """Return pool settings together with the router's statistics."""
        return {
            "load_balance": self._load_balance,
            "fallback": self._fallback,
            "num_endpoints": len(self._clients),
            "router_stats": self._router.stats,
        }

    def close(self):
        """Close every client in the pool."""
        for client in self._clients:
            client.close()

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.close()

    async def __aenter__(self):
        return self

    async def __aexit__(self, *args):
        self.close()

    def __repr__(self) -> str:
        return (
            f"LLMClientPool(endpoints={len(self._clients)}, "
            f"load_balance='{self._load_balance}', fallback={self._fallback})"
        )
|