flexllm-0.3.3-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. flexllm/__init__.py +224 -0
  2. flexllm/__main__.py +1096 -0
  3. flexllm/async_api/__init__.py +9 -0
  4. flexllm/async_api/concurrent_call.py +100 -0
  5. flexllm/async_api/concurrent_executor.py +1036 -0
  6. flexllm/async_api/core.py +373 -0
  7. flexllm/async_api/interface.py +12 -0
  8. flexllm/async_api/progress.py +277 -0
  9. flexllm/base_client.py +988 -0
  10. flexllm/batch_tools/__init__.py +16 -0
  11. flexllm/batch_tools/folder_processor.py +317 -0
  12. flexllm/batch_tools/table_processor.py +363 -0
  13. flexllm/cache/__init__.py +10 -0
  14. flexllm/cache/response_cache.py +293 -0
  15. flexllm/chain_of_thought_client.py +1120 -0
  16. flexllm/claudeclient.py +402 -0
  17. flexllm/client_pool.py +698 -0
  18. flexllm/geminiclient.py +563 -0
  19. flexllm/llm_client.py +523 -0
  20. flexllm/llm_parser.py +60 -0
  21. flexllm/mllm_client.py +559 -0
  22. flexllm/msg_processors/__init__.py +174 -0
  23. flexllm/msg_processors/image_processor.py +729 -0
  24. flexllm/msg_processors/image_processor_helper.py +485 -0
  25. flexllm/msg_processors/messages_processor.py +341 -0
  26. flexllm/msg_processors/unified_processor.py +1404 -0
  27. flexllm/openaiclient.py +256 -0
  28. flexllm/pricing/__init__.py +104 -0
  29. flexllm/pricing/data.json +1201 -0
  30. flexllm/pricing/updater.py +223 -0
  31. flexllm/provider_router.py +213 -0
  32. flexllm/token_counter.py +270 -0
  33. flexllm/utils/__init__.py +1 -0
  34. flexllm/utils/core.py +41 -0
  35. flexllm-0.3.3.dist-info/METADATA +573 -0
  36. flexllm-0.3.3.dist-info/RECORD +39 -0
  37. flexllm-0.3.3.dist-info/WHEEL +4 -0
  38. flexllm-0.3.3.dist-info/entry_points.txt +3 -0
  39. flexllm-0.3.3.dist-info/licenses/LICENSE +201 -0
flexllm/client_pool.py ADDED
@@ -0,0 +1,698 @@
+ """
+ LLMClientPool - multi-endpoint client pool
+
+ Provides load balancing and failover across multiple LLM endpoints, with the same interface as LLMClient.
+
+ Example:
+     # Option 1: pass endpoint configs
+     pool = LLMClientPool(
+         endpoints=[
+             {"base_url": "http://api1.com/v1", "api_key": "key1", "model": "qwen"},
+             {"base_url": "http://api2.com/v1", "api_key": "key2", "model": "qwen"},
+         ],
+         load_balance="round_robin",
+         fallback=True,
+     )
+
+     # Option 2: pass existing clients
+     pool = LLMClientPool(
+         clients=[client1, client2],
+         load_balance="round_robin",
+         fallback=True,
+     )
+
+     # Same interface as LLMClient
+     result = await pool.chat_completions(messages)
+     results = await pool.chat_completions_batch(messages_list)
+ """
+
+ import asyncio
+ from typing import List, Dict, Any, Union, Optional, Literal
+ from dataclasses import dataclass
+
+ from loguru import logger
+
+ from .llm_client import LLMClient
+ from .base_client import ChatCompletionResult
+ from .provider_router import ProviderRouter, ProviderConfig, Strategy
+
+
+ @dataclass
+ class EndpointConfig:
+     """Endpoint configuration"""
+     base_url: str
+     api_key: str = "EMPTY"
+     model: str = None
+     provider: Literal["openai", "gemini", "auto"] = "auto"
+     weight: float = 1.0
+     # Additional LLMClient parameters
+     extra: Dict[str, Any] = None
+
+     def __post_init__(self):
+         if self.extra is None:
+             self.extra = {}
+
+
+ class LLMClientPool:
+     """
+     Multi-endpoint client pool
+
+     Features:
+     - Load balancing: round_robin, weighted, random
+     - Failover: when fallback=True, other endpoints are tried automatically
+     - Health checks: failing endpoints are marked unhealthy and retried after a recovery period
+     - Unified interface: called exactly like LLMClient
+
+     Attributes:
+         load_balance: load balancing strategy
+         fallback: whether failover is enabled
+         max_fallback_attempts: maximum number of failover attempts
+     """
+
+     def __init__(
+         self,
+         endpoints: List[Union[Dict, EndpointConfig]] = None,
+         clients: List[LLMClient] = None,
+         load_balance: Strategy = "round_robin",
+         fallback: bool = True,
+         max_fallback_attempts: int = None,
+         failure_threshold: int = 3,
+         recovery_time: float = 60.0,
+         # Shared LLMClient parameters (only used when endpoints is given)
+         concurrency_limit: int = 10,
+         max_qps: int = 1000,
+         timeout: int = 120,
+         retry_times: int = 3,
+         **kwargs,
+     ):
+         """
+         Initialize the client pool.
+
+         Args:
+             endpoints: list of endpoint configs; each item may be a dict or an EndpointConfig
+             clients: list of pre-built LLMClient instances (mutually exclusive with endpoints)
+             load_balance: load balancing strategy
+                 - "round_robin": round robin
+                 - "weighted": weighted random
+                 - "random": random
+                 - "fallback": primary/backup mode
+             fallback: whether to fail over to other endpoints when one fails
+             max_fallback_attempts: maximum number of failover attempts, defaults to the number of endpoints
+             failure_threshold: number of consecutive failures before an endpoint is marked unhealthy
+             recovery_time: how long to wait before retrying an unhealthy endpoint (seconds)
+             concurrency_limit: concurrency limit per client
+             max_qps: QPS limit per client
+             timeout: request timeout
+             retry_times: number of retries
+             **kwargs: additional arguments passed to LLMClient
+         """
+         if not endpoints and not clients:
+             raise ValueError("Either endpoints or clients must be provided")
+         if endpoints and clients:
+             raise ValueError("endpoints and clients are mutually exclusive")
+
+         self._fallback = fallback
+         self._load_balance = load_balance
+
+         if clients:
+             # Use the provided clients
+             self._clients = clients
+             self._endpoints = [
+                 EndpointConfig(
+                     base_url=c._client._base_url,
+                     api_key=c._client._api_key or "EMPTY",
+                     model=c._model,
+                 )
+                 for c in clients
+             ]
+         else:
+             # Build clients from the endpoint configs
+             self._endpoints = []
+             self._clients = []
+
+             for ep in endpoints:
+                 if isinstance(ep, dict):
+                     ep = EndpointConfig(**ep)
+                 self._endpoints.append(ep)
+
+                 # Merge parameters
+                 client_kwargs = {
+                     "provider": ep.provider,
+                     "base_url": ep.base_url,
+                     "api_key": ep.api_key,
+                     "model": ep.model,
+                     "concurrency_limit": concurrency_limit,
+                     "max_qps": max_qps,
+                     "timeout": timeout,
+                     "retry_times": retry_times,
+                     **kwargs,
+                     **(ep.extra or {}),
+                 }
+                 self._clients.append(LLMClient(**client_kwargs))
+
+         # Create the router
+         provider_configs = [
+             ProviderConfig(
+                 base_url=ep.base_url,
+                 api_key=ep.api_key,
+                 weight=ep.weight,
+                 model=ep.model,
+             )
+             for ep in self._endpoints
+         ]
+         self._router = ProviderRouter(
+             providers=provider_configs,
+             strategy=load_balance,
+             failure_threshold=failure_threshold,
+             recovery_time=recovery_time,
+         )
+
+         # endpoint -> client mapping
+         self._client_map = {
+             ep.base_url: client for ep, client in zip(self._endpoints, self._clients)
+         }
+
+         self._max_fallback_attempts = max_fallback_attempts or len(self._clients)
+
+     def _get_client(self) -> tuple[LLMClient, ProviderConfig]:
+         """Get the next available client."""
+         provider = self._router.get_next()
+         client = self._client_map[provider.base_url]
+         return client, provider
+
+     async def chat_completions(
+         self,
+         messages: List[dict],
+         model: str = None,
+         return_raw: bool = False,
+         return_usage: bool = False,
+         **kwargs,
+     ) -> Union[str, ChatCompletionResult]:
+         """
+         Single chat completion (with failover).
+
+         Args:
+             messages: list of messages
+             model: model name (optional, defaults to the endpoint's configured model)
+             return_raw: whether to return the raw response
+             return_usage: whether to return a result that includes usage
+             **kwargs: additional parameters
+
+         Returns:
+             Same as LLMClient.chat_completions
+         """
+         last_error = None
+         tried_providers = set()
+
+         for attempt in range(self._max_fallback_attempts):
+             client, provider = self._get_client()
+
+             # Avoid retrying the same provider
+             if provider.base_url in tried_providers:
+                 # Stop once every provider has been tried
+                 if len(tried_providers) >= len(self._clients):
+                     break
+                 continue
+
+             tried_providers.add(provider.base_url)
+
+             try:
+                 result = await client.chat_completions(
+                     messages=messages,
+                     model=model or provider.model,
+                     return_raw=return_raw,
+                     return_usage=return_usage,
+                     **kwargs,
+                 )
+
+                 # Check whether a RequestResult was returned (indicating failure)
+                 if hasattr(result, 'status') and result.status != 'success':
+                     raise RuntimeError(f"Request failed: {getattr(result, 'error', result)}")
+
+                 self._router.mark_success(provider)
+                 return result
+
+             except Exception as e:
+                 last_error = e
+                 self._router.mark_failed(provider)
+                 logger.warning(f"Endpoint {provider.base_url} failed: {e}")
+
+                 if not self._fallback:
+                     raise
+
+         raise last_error or RuntimeError("All endpoints failed")
+
+     def chat_completions_sync(
+         self,
+         messages: List[dict],
+         model: str = None,
+         return_raw: bool = False,
+         return_usage: bool = False,
+         **kwargs,
+     ) -> Union[str, ChatCompletionResult]:
+         """Synchronous version of chat_completions."""
+         return asyncio.run(
+             self.chat_completions(
+                 messages=messages,
+                 model=model,
+                 return_raw=return_raw,
+                 return_usage=return_usage,
+                 **kwargs,
+             )
+         )
+
+     async def chat_completions_batch(
+         self,
+         messages_list: List[List[dict]],
+         model: str = None,
+         return_raw: bool = False,
+         return_usage: bool = False,
+         show_progress: bool = True,
+         return_summary: bool = False,
+         output_jsonl: Optional[str] = None,
+         flush_interval: float = 1.0,
+         distribute: bool = True,
+         **kwargs,
+     ) -> Union[List[str], List[ChatCompletionResult], tuple]:
+         """
+         Batch chat completions (with load balancing and failover).
+
+         Args:
+             messages_list: list of message lists
+             model: model name
+             return_raw: whether to return raw responses
+             return_usage: whether to return results that include usage
+             show_progress: whether to show a progress bar
+             return_summary: whether to also return a summary of statistics
+             output_jsonl: output file path (JSONL)
+             flush_interval: file flush interval (seconds)
+             distribute: whether to spread requests across multiple endpoints (True);
+                 if False, use a single endpoint with fallback
+             **kwargs: additional parameters
+
+         Returns:
+             Same as LLMClient.chat_completions_batch
+         """
+         # Validate the output_jsonl extension
+         if output_jsonl and not output_jsonl.endswith(".jsonl"):
+             raise ValueError(f"output_jsonl must use the .jsonl extension, got: {output_jsonl}")
+
+         if not distribute or len(self._clients) == 1:
+             # Single-endpoint mode: use fallback
+             return await self._batch_with_fallback(
+                 messages_list=messages_list,
+                 model=model,
+                 return_raw=return_raw,
+                 return_usage=return_usage,
+                 show_progress=show_progress,
+                 return_summary=return_summary,
+                 output_jsonl=output_jsonl,
+                 flush_interval=flush_interval,
+                 **kwargs,
+             )
+         else:
+             # Multi-endpoint distributed mode
+             return await self._batch_distributed(
+                 messages_list=messages_list,
+                 model=model,
+                 return_raw=return_raw,
+                 return_usage=return_usage,
+                 show_progress=show_progress,
+                 return_summary=return_summary,
+                 output_jsonl=output_jsonl,
+                 flush_interval=flush_interval,
+                 **kwargs,
+             )
+
+     async def _batch_with_fallback(
+         self,
+         messages_list: List[List[dict]],
+         model: str = None,
+         return_raw: bool = False,
+         return_usage: bool = False,
+         show_progress: bool = True,
+         return_summary: bool = False,
+         output_jsonl: Optional[str] = None,
+         flush_interval: float = 1.0,
+         **kwargs,
+     ):
+         """Batch call using a single endpoint with fallback."""
+         last_error = None
+         tried_providers = set()
+
+         for attempt in range(self._max_fallback_attempts):
+             client, provider = self._get_client()
+
+             if provider.base_url in tried_providers:
+                 if len(tried_providers) >= len(self._clients):
+                     break
+                 continue
+
+             tried_providers.add(provider.base_url)
+
+             try:
+                 result = await client.chat_completions_batch(
+                     messages_list=messages_list,
+                     model=model or provider.model,
+                     return_raw=return_raw,
+                     return_usage=return_usage,
+                     show_progress=show_progress,
+                     return_summary=return_summary,
+                     output_jsonl=output_jsonl,
+                     flush_interval=flush_interval,
+                     **kwargs,
+                 )
+                 self._router.mark_success(provider)
+                 return result
+
+             except Exception as e:
+                 last_error = e
+                 self._router.mark_failed(provider)
+                 logger.warning(f"Endpoint {provider.base_url} batch call failed: {e}")
+
+                 if not self._fallback:
+                     raise
+
+         raise last_error or RuntimeError("All endpoints failed")
+
+     async def _batch_distributed(
+         self,
+         messages_list: List[List[dict]],
+         model: str = None,
+         return_raw: bool = False,
+         return_usage: bool = False,
+         show_progress: bool = True,
+         return_summary: bool = False,
+         output_jsonl: Optional[str] = None,
+         flush_interval: float = 1.0,
+         **kwargs,
+     ):
+         """
+         Dynamic dispatch: multiple workers pull tasks from a shared queue.
+
+         Each client starts concurrency_limit workers, and all workers compete for tasks
+         from the same queue. Faster clients naturally take on more tasks, giving dynamic load balancing.
+         """
+         import json
+         import time
+         from pathlib import Path
+         from tqdm import tqdm
+
+         n = len(messages_list)
+         results = [None] * n
+         success_count = 0
+         failed_count = 0
+         cached_count = 0
+         start_time = time.time()
+
+         # Resume support: read already-completed records
+         completed_indices = set()
+         if output_jsonl:
+             output_path = Path(output_jsonl)
+             if output_path.exists():
+                 records = []
+                 with open(output_path, "r", encoding="utf-8") as f:
+                     for line in f:
+                         try:
+                             record = json.loads(line.strip())
+                             if record.get("status") == "success" and "input" in record:
+                                 idx = record.get("index")
+                                 if 0 <= idx < n:
+                                     records.append(record)
+                         except (json.JSONDecodeError, KeyError, TypeError):
+                             continue
+
+                 # Validate the first and last records
+                 file_valid = True
+                 if records:
+                     first, last = records[0], records[-1]
+                     if first["input"] != messages_list[first["index"]]:
+                         file_valid = False
+                     elif len(records) > 1 and last["input"] != messages_list[last["index"]]:
+                         file_valid = False
+
+                 if file_valid:
+                     for record in records:
+                         idx = record["index"]
+                         completed_indices.add(idx)
+                         results[idx] = record["output"]
+                     if completed_indices:
+                         logger.info(f"Resumed from file: {len(completed_indices)}/{n} already completed")
+                         cached_count = len(completed_indices)
+                 else:
+                     raise ValueError(
+                         f"File validation failed: the inputs in {output_jsonl} do not match the current messages_list. "
+                         f"Please delete or rename the file and try again."
+                     )
+
+         # Shared task queue (skipping completed tasks)
+         queue = asyncio.Queue()
+         for idx, msg in enumerate(messages_list):
+             if idx not in completed_indices:
+                 queue.put_nowait((idx, msg))
+
+         pending_count = queue.qsize()
+         if pending_count == 0:
+             logger.info("All tasks are already completed; nothing to do")
+             if return_summary:
+                 return results, {"total": n, "success": n, "failed": 0, "cached": cached_count, "elapsed": 0}
+             return results
+
+         logger.info(f"Pending: {pending_count}/{n}")
+
+         # Progress bar
+         pbar = tqdm(total=pending_count, desc="Processing", disable=not show_progress)
+
+         # File writing state
+         file_writer = None
+         file_buffer = []
+         last_flush_time = time.time()
+
+         if output_jsonl:
+             file_writer = open(output_jsonl, "a", encoding="utf-8")
+
+         # Lock for statistics and safe shared updates
+         lock = asyncio.Lock()
+
+         def flush_to_file():
+             """Flush the buffer to the file."""
+             nonlocal file_buffer, last_flush_time
+             if file_writer and file_buffer:
+                 for record in file_buffer:
+                     file_writer.write(json.dumps(record, ensure_ascii=False) + "\n")
+                 file_writer.flush()
+                 file_buffer = []
+                 last_flush_time = time.time()
+
+         async def worker(client_idx: int):
+             """Single worker: repeatedly pull a task from the queue and run it."""
+             nonlocal success_count, failed_count, last_flush_time
+
+             client = self._clients[client_idx]
+             provider = self._router._providers[client_idx].config
+             effective_model = model or provider.model
+
+             while True:
+                 try:
+                     idx, msg = queue.get_nowait()
+                 except asyncio.QueueEmpty:
+                     break
+
+                 try:
+                     result = await client.chat_completions(
+                         messages=msg,
+                         model=effective_model,
+                         return_raw=return_raw,
+                         return_usage=return_usage,
+                         **kwargs,
+                     )
+
+                     # Check whether a RequestResult was returned (indicating failure)
+                     if hasattr(result, 'status') and result.status != 'success':
+                         raise RuntimeError(f"Request failed: {getattr(result, 'error', result)}")
+
+                     results[idx] = result
+                     self._router.mark_success(provider)
+
+                     async with lock:
+                         success_count += 1
+                         pbar.update(1)
+
+                         # Write to the output file
+                         if file_writer:
+                             file_buffer.append({
+                                 "index": idx,
+                                 "output": result,
+                                 "status": "success",
+                                 "input": msg,
+                             })
+                             if time.time() - last_flush_time >= flush_interval:
+                                 flush_to_file()
+
+                 except Exception as e:
+                     logger.warning(f"Task {idx} failed on {provider.base_url}: {e}")
+                     results[idx] = None
+                     self._router.mark_failed(provider)
+
+                     async with lock:
+                         failed_count += 1
+                         pbar.update(1)
+
+                         # Write the failure record
+                         if file_writer:
+                             file_buffer.append({
+                                 "index": idx,
+                                 "output": None,
+                                 "status": "error",
+                                 "error": str(e),
+                                 "input": msg,
+                             })
+                             if time.time() - last_flush_time >= flush_interval:
+                                 flush_to_file()
+
+         try:
+             # Start all workers
+             # Each client starts concurrency_limit workers
+             workers = []
+             for client_idx, client in enumerate(self._clients):
+                 # Read the client's concurrency limit
+                 concurrency = getattr(client._client, '_concurrency_limit', 10)
+                 for _ in range(concurrency):
+                     workers.append(worker(client_idx))
+
+             # Run all workers concurrently
+             await asyncio.gather(*workers)
+
+         finally:
+             # Make sure the last buffered records are written
+             flush_to_file()
+             if file_writer:
+                 file_writer.close()
+             pbar.close()
+
+         if return_summary:
+             summary = {
+                 "total": n,
+                 "success": success_count + cached_count,
+                 "failed": failed_count,
+                 "cached": cached_count,
+                 "elapsed": time.time() - start_time,
+             }
+             return results, summary
+
+         return results
+
+     def chat_completions_batch_sync(
+         self,
+         messages_list: List[List[dict]],
+         model: str = None,
+         return_raw: bool = False,
+         return_usage: bool = False,
+         show_progress: bool = True,
+         return_summary: bool = False,
+         output_jsonl: Optional[str] = None,
+         flush_interval: float = 1.0,
+         distribute: bool = True,
+         **kwargs,
+     ) -> Union[List[str], List[ChatCompletionResult], tuple]:
+         """Synchronous version of chat_completions_batch."""
+         return asyncio.run(
+             self.chat_completions_batch(
+                 messages_list=messages_list,
+                 model=model,
+                 return_raw=return_raw,
+                 return_usage=return_usage,
+                 show_progress=show_progress,
+                 return_summary=return_summary,
+                 output_jsonl=output_jsonl,
+                 flush_interval=flush_interval,
+                 distribute=distribute,
+                 **kwargs,
+             )
+         )
+
+     async def chat_completions_stream(
+         self,
+         messages: List[dict],
+         model: str = None,
+         return_usage: bool = False,
+         **kwargs,
+     ):
+         """
+         Streaming chat completion (with failover).
+
+         Args:
+             messages: list of messages
+             model: model name
+             return_usage: whether to return usage information
+             **kwargs: additional parameters
+
+         Yields:
+             Same as LLMClient.chat_completions_stream
+         """
+         last_error = None
+         tried_providers = set()
+
+         for attempt in range(self._max_fallback_attempts):
+             client, provider = self._get_client()
+
+             if provider.base_url in tried_providers:
+                 if len(tried_providers) >= len(self._clients):
+                     break
+                 continue
+
+             tried_providers.add(provider.base_url)
+
+             try:
+                 async for chunk in client.chat_completions_stream(
+                     messages=messages,
+                     model=model or provider.model,
+                     return_usage=return_usage,
+                     **kwargs,
+                 ):
+                     yield chunk
+                 self._router.mark_success(provider)
+                 return
+
+             except Exception as e:
+                 last_error = e
+                 self._router.mark_failed(provider)
+                 logger.warning(f"Endpoint {provider.base_url} streaming call failed: {e}")
+
+                 if not self._fallback:
+                     raise
+
+         raise last_error or RuntimeError("All endpoints failed")
+
+     @property
+     def stats(self) -> dict:
+         """Return pool statistics."""
+         return {
+             "load_balance": self._load_balance,
+             "fallback": self._fallback,
+             "num_endpoints": len(self._clients),
+             "router_stats": self._router.stats,
+         }
+
+     def close(self):
+         """Close all clients."""
+         for client in self._clients:
+             client.close()
+
+     def __enter__(self):
+         return self
+
+     def __exit__(self, *args):
+         self.close()
+
+     async def __aenter__(self):
+         return self
+
+     async def __aexit__(self, *args):
+         self.close()
+
+     def __repr__(self) -> str:
+         return (
+             f"LLMClientPool(endpoints={len(self._clients)}, "
+             f"load_balance='{self._load_balance}', fallback={self._fallback})"
+         )
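
A minimal usage sketch for the pool added above, based only on its docstrings and signatures; the endpoint URLs, API keys, and model name are placeholders, and it assumes the distributed-batch and JSONL-resume behavior described in _batch_distributed.

import asyncio
from flexllm.client_pool import LLMClientPool

async def main():
    # Two hypothetical OpenAI-compatible endpoints serving the same model.
    pool = LLMClientPool(
        endpoints=[
            {"base_url": "http://api1.example.com/v1", "api_key": "key1", "model": "qwen"},
            {"base_url": "http://api2.example.com/v1", "api_key": "key2", "model": "qwen"},
        ],
        load_balance="round_robin",  # or "weighted" / "random" / "fallback"
        fallback=True,               # try another endpoint when one fails
        concurrency_limit=10,        # workers started per client in distributed batch mode
    )

    messages_list = [
        [{"role": "user", "content": f"Summarize item {i}"}] for i in range(100)
    ]

    # With distribute=True (the default) and more than one endpoint, requests are pulled
    # from a shared queue by each client's workers; completed records are appended to the
    # JSONL file, so re-running the same call resumes instead of starting over.
    results, summary = await pool.chat_completions_batch(
        messages_list,
        output_jsonl="results.jsonl",
        return_summary=True,
    )
    print(summary)

    pool.close()

asyncio.run(main())

Passing distribute=False instead routes the whole batch to a single endpoint and only falls back to the others if that endpoint raises an error.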