flexllm 0.3.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. flexllm/__init__.py +224 -0
  2. flexllm/__main__.py +1096 -0
  3. flexllm/async_api/__init__.py +9 -0
  4. flexllm/async_api/concurrent_call.py +100 -0
  5. flexllm/async_api/concurrent_executor.py +1036 -0
  6. flexllm/async_api/core.py +373 -0
  7. flexllm/async_api/interface.py +12 -0
  8. flexllm/async_api/progress.py +277 -0
  9. flexllm/base_client.py +988 -0
  10. flexllm/batch_tools/__init__.py +16 -0
  11. flexllm/batch_tools/folder_processor.py +317 -0
  12. flexllm/batch_tools/table_processor.py +363 -0
  13. flexllm/cache/__init__.py +10 -0
  14. flexllm/cache/response_cache.py +293 -0
  15. flexllm/chain_of_thought_client.py +1120 -0
  16. flexllm/claudeclient.py +402 -0
  17. flexllm/client_pool.py +698 -0
  18. flexllm/geminiclient.py +563 -0
  19. flexllm/llm_client.py +523 -0
  20. flexllm/llm_parser.py +60 -0
  21. flexllm/mllm_client.py +559 -0
  22. flexllm/msg_processors/__init__.py +174 -0
  23. flexllm/msg_processors/image_processor.py +729 -0
  24. flexllm/msg_processors/image_processor_helper.py +485 -0
  25. flexllm/msg_processors/messages_processor.py +341 -0
  26. flexllm/msg_processors/unified_processor.py +1404 -0
  27. flexllm/openaiclient.py +256 -0
  28. flexllm/pricing/__init__.py +104 -0
  29. flexllm/pricing/data.json +1201 -0
  30. flexllm/pricing/updater.py +223 -0
  31. flexllm/provider_router.py +213 -0
  32. flexllm/token_counter.py +270 -0
  33. flexllm/utils/__init__.py +1 -0
  34. flexllm/utils/core.py +41 -0
  35. flexllm-0.3.3.dist-info/METADATA +573 -0
  36. flexllm-0.3.3.dist-info/RECORD +39 -0
  37. flexllm-0.3.3.dist-info/WHEEL +4 -0
  38. flexllm-0.3.3.dist-info/entry_points.txt +3 -0
  39. flexllm-0.3.3.dist-info/licenses/LICENSE +201 -0
flexllm/base_client.py ADDED
@@ -0,0 +1,988 @@
+ """
+ LLMClientBase - Abstract base class for LLM clients
+
+ Provides the shared method implementations; subclasses only implement the core provider-specific methods.
+ """
+
+ import asyncio
+ import json
+ import time
+ from abc import ABC
+ from dataclasses import dataclass, field
+ from pathlib import Path
+ from typing import TYPE_CHECKING, List, Union, Optional, Any
+
+ from loguru import logger
+
+ from .async_api import ConcurrentRequester
+ from .msg_processors.image_processor import ImageCacheConfig
+ from .msg_processors.messages_processor import messages_preprocess
+ from .msg_processors.unified_processor import batch_process_messages as optimized_batch_preprocess
+ from .cache import ResponseCache, ResponseCacheConfig
+
+ if TYPE_CHECKING:
+     from .async_api.interface import RequestResult
+
+
+ @dataclass
+ class ToolCall:
+     """Tool call information"""
+     id: str
+     type: str  # "function"
+     function: dict  # {"name": "...", "arguments": "..."}
+
+
+ @dataclass
+ class ChatCompletionResult:
+     """Result of a chat completion, including content and token usage"""
+     content: str
+     usage: Optional[dict] = None  # {"prompt_tokens": x, "completion_tokens": y, "total_tokens": z}
+     reasoning_content: Optional[str] = None  # reasoning content (DeepSeek-R1, Qwen3, etc.)
+     tool_calls: Optional[List["ToolCall"]] = None  # list of tool calls
+
+
+ @dataclass
+ class BatchResultItem:
+     """A single result in a batch request: index, content, and usage"""
+     index: int
+     content: Optional[str]
+     usage: Optional[dict] = None
+     status: str = "success"  # success, error, cached
+     error: Optional[str] = None
+     latency: float = 0.0
+
+
+ class LLMClientBase(ABC):
+     """
+     Abstract base class for LLM clients
+
+     Subclasses only need to implement 4 core methods:
+     - _get_url(model, stream) -> str
+     - _get_headers() -> dict
+     - _build_request_body(messages, model, **kwargs) -> dict
+     - _extract_content(response_data) -> str
+
+     Optional overrides:
+     - _extract_stream_content(data) -> str
+     - _get_stream_url(model) -> str
+     """
+
+     def __init__(
+         self,
+         base_url: str = None,
+         api_key: str = None,
+         model: str = None,
+         concurrency_limit: int = 10,
+         max_qps: int = 1000,
+         timeout: int = 120,
+         retry_times: int = 3,
+         retry_delay: float = 1.0,
+         cache_image: bool = False,
+         cache_dir: str = "image_cache",
+         cache: Union[bool, ResponseCacheConfig, None] = None,
+         **kwargs,
+     ):
+         """
+         Args:
+             base_url: API base URL
+             api_key: API key
+             model: default model name
+             concurrency_limit: maximum number of concurrent requests
+             max_qps: maximum requests per second
+             timeout: request timeout in seconds
+             retry_times: number of retries
+             retry_delay: delay between retries in seconds
+             cache_image: whether to cache images
+             cache_dir: image cache directory
+             cache: response cache configuration
+                 - True: enable caching (defaults to IPC mode with a 24-hour TTL)
+                 - False/None: disable caching (the default)
+                 - ResponseCacheConfig: custom configuration
+         """
+         self._base_url = base_url.rstrip("/") if base_url else None
+         self._api_key = api_key
+         self._model = model
+         self._concurrency_limit = concurrency_limit
+         self._timeout = timeout
+
+         self._client = ConcurrentRequester(
+             concurrency_limit=concurrency_limit,
+             max_qps=max_qps,
+             timeout=timeout,
+             retry_times=retry_times,
+             retry_delay=retry_delay,
+         )
+
+         self._cache_config = ImageCacheConfig(
+             enabled=cache_image,
+             cache_dir=cache_dir,
+             force_refresh=False,
+             retry_failed=False,
+         )
+
+         # Response cache
+         if cache is True:
+             cache = ResponseCacheConfig.ipc()  # defaults to IPC mode with a 24-hour TTL
+         elif cache is None or cache is False:
+             cache = ResponseCacheConfig.disabled()
+         self._response_cache = ResponseCache(cache) if cache.enabled else None
+
+     # ========== Core abstract methods (subclasses must implement) ==========
+
+     def _get_url(self, model: str, stream: bool = False) -> str:
+         raise NotImplementedError
+
+     def _get_headers(self) -> dict:
+         raise NotImplementedError
+
+     def _build_request_body(
+         self, messages: List[dict], model: str, stream: bool = False, **kwargs
+     ) -> dict:
+         raise NotImplementedError
+
+     def _extract_content(self, response_data: dict) -> Optional[str]:
+         raise NotImplementedError
+
+     def _extract_usage(self, response_data: dict) -> Optional[dict]:
+         """Extract usage info (subclasses may override)"""
+         if not response_data:
+             return None
+         return response_data.get("usage")
+
+     def _extract_tool_calls(self, response_data: dict) -> Optional[List[ToolCall]]:
+         """Extract tool call info (subclasses may override)"""
+         return None
+
+     # ========== Optional hook methods ==========
+
+     def _extract_stream_content(self, data: dict) -> Optional[str]:
+         return self._extract_content(data)
+
+     def _get_stream_url(self, model: str) -> str:
+         return self._get_url(model, stream=True)
+
+     # ========== Shared utility methods ==========
+
+     def _get_effective_model(self, model: str = None) -> str:
+         effective_model = model or self._model
+         if not effective_model:
+             raise ValueError("A model must be passed as an argument or set at initialization")
+         return effective_model
+
+     async def _preprocess_messages(
+         self, messages: List[dict], preprocess_msg: bool = False
+     ) -> List[dict]:
+         """Message preprocessing (e.g., converting images to base64)"""
+         if preprocess_msg:
+             return await messages_preprocess(
+                 messages, preprocess_msg=preprocess_msg, cache_config=self._cache_config
+             )
+         return messages
+
+     async def _preprocess_messages_batch(
+         self, messages_list: List[List[dict]], preprocess_msg: bool = False
+     ) -> List[List[dict]]:
+         """Batch message preprocessing"""
+         if preprocess_msg:
+             return await optimized_batch_preprocess(
+                 messages_list, max_concurrent=self._concurrency_limit, cache_config=self._cache_config
+             )
+         return messages_list
+
+     # ========== Shared interface implementations ==========
+
+     async def chat_completions(
+         self,
+         messages: List[dict],
+         model: str = None,
+         return_raw: bool = False,
+         return_usage: bool = False,
+         show_progress: bool = False,
+         preprocess_msg: bool = False,
+         url: str = None,
+         **kwargs,
+     ) -> Union[str, ChatCompletionResult, "RequestResult"]:
+         """
+         Single chat completion
+
+         Args:
+             messages: list of messages
+             model: model name
+             return_raw: whether to return the raw response (RequestResult)
+             return_usage: whether to return a result that includes usage (ChatCompletionResult)
+             show_progress: whether to show a progress bar
+             preprocess_msg: whether to preprocess messages
+             url: custom request URL; defaults to the one generated by _get_url()
+
+         Returns:
+             - return_raw=True: the raw RequestResult
+             - return_usage=True: ChatCompletionResult(content, usage, reasoning_content)
+             - default: the content text as str
+
+         Note:
+             Caching is controlled by the cache parameter at initialization; it is skipped
+             automatically when return_raw or return_usage is set.
+         """
+         effective_model = self._get_effective_model(model)
+         messages = await self._preprocess_messages(messages, preprocess_msg)
+
+         # Check the cache (cached entries carry no usage info, so skip when return_raw/return_usage is set)
+         use_cache = self._response_cache is not None and not return_raw and not return_usage
+         if use_cache:
+             cached = self._response_cache.get(messages, model=effective_model, **kwargs)
+             if cached is not None:
+                 return cached
+
+         body = self._build_request_body(messages, effective_model, stream=False, **kwargs)
+         request_params = {"json": body, "headers": self._get_headers()}
+         effective_url = url or self._get_url(effective_model, stream=False)
+
+         results, _ = await self._client.process_requests(
+             request_params=[request_params], url=effective_url, method="POST", show_progress=show_progress
+         )
+
+         data = results[0]
+         if return_raw:
+             return data
+         if data.status == "success":
+             content = self._extract_content(data.data)
+             # Write to the cache
+             if use_cache and content is not None:
+                 self._response_cache.set(messages, content, model=effective_model, **kwargs)
+
+             if return_usage:
+                 usage = self._extract_usage(data.data)
+                 tool_calls = self._extract_tool_calls(data.data)
+                 return ChatCompletionResult(content=content, usage=usage, tool_calls=tool_calls)
+             return content
+         return data
+
+     def chat_completions_sync(
+         self,
+         messages: List[dict],
+         model: str = None,
+         return_raw: bool = False,
+         return_usage: bool = False,
+         **kwargs,
+     ) -> Union[str, ChatCompletionResult, "RequestResult"]:
+         """Synchronous version of chat completion"""
+         return asyncio.run(
+             self.chat_completions(
+                 messages=messages, model=model, return_raw=return_raw, return_usage=return_usage, **kwargs
+             )
+         )
+
+     async def chat_completions_batch(
+         self,
+         messages_list: List[List[dict]],
+         model: str = None,
+         return_raw: bool = False,
+         return_usage: bool = False,
+         show_progress: bool = True,
+         return_summary: bool = False,
+         preprocess_msg: bool = False,
+         output_jsonl: Optional[str] = None,
+         flush_interval: float = 1.0,
+         metadata_list: Optional[List[dict]] = None,
+         url: str = None,
+         **kwargs,
+     ) -> Union[List[str], List[ChatCompletionResult], tuple]:
+         """
+         Batch chat completion (supports checkpoint resume)
+
+         Args:
+             messages_list: list of message lists
+             model: model name
+             return_raw: whether to return raw responses
+             return_usage: whether to return results that include usage (a list of ChatCompletionResult)
+             show_progress: whether to show a progress bar
+             return_summary: whether to also return an execution summary
+             preprocess_msg: whether to preprocess messages
+             output_jsonl: output file path (JSONL format) for persisting results
+             flush_interval: file flush interval in seconds (default 1 second)
+             metadata_list: metadata list, same length as messages_list; each element is saved with its output record
+             url: custom request URL; defaults to the one generated by _get_url()
+
+         Returns:
+             - return_usage=True: List[ChatCompletionResult] or (List[ChatCompletionResult], summary)
+             - default: List[str] or (List[str], summary)
+
+         Note:
+             Caching is controlled by the cache parameter at initialization; it is skipped
+             automatically when return_usage=True.
+         """
+         effective_model = self._get_effective_model(model)
+         effective_url = url or self._get_url(effective_model, stream=False)
+         headers = self._get_headers()
+
+         # Validate metadata_list length
+         if metadata_list is not None and len(metadata_list) != len(messages_list):
+             raise ValueError(
+                 f"metadata_list length ({len(metadata_list)}) must match messages_list length ({len(messages_list)})"
+             )
+
+         # Validate the output_jsonl extension
+         if output_jsonl and not output_jsonl.endswith(".jsonl"):
+             raise ValueError(f"output_jsonl must use the .jsonl extension, got: {output_jsonl}")
+
+         messages_list = await self._preprocess_messages_batch(messages_list, preprocess_msg)
+
+         # Skip the cache when return_usage is set (cached entries carry no usage info)
+         use_cache = self._response_cache is not None and not return_usage
+
+         def extractor(result):
+             return self._extract_content(result.data)
+
+         def extractor_with_usage(result):
+             content = self._extract_content(result.data)
+             usage = self._extract_usage(result.data)
+             tool_calls = self._extract_tool_calls(result.data)
+             return ChatCompletionResult(content=content, usage=usage, tool_calls=tool_calls)
+
+         # File-output state
+         file_writer = None
+         file_buffer = []
+         last_flush_time = time.time()
+         completed_indices = set()
+
+         # If an output file is given, read the already-completed indices (checkpoint resume)
+         if output_jsonl:
+             output_path = Path(output_jsonl)
+             if output_path.exists():
+                 # Read all valid records
+                 records = []
+                 with open(output_path, "r", encoding="utf-8") as f:
+                     for line in f:
+                         try:
+                             record = json.loads(line.strip())
+                             if record.get("status") == "success" and "input" in record:
+                                 idx = record.get("index")
+                                 if 0 <= idx < len(messages_list):
+                                     records.append(record)
+                         except (json.JSONDecodeError, KeyError, TypeError):
+                             continue
+
+                 # Head/tail validation: compare only the first and last records' input
+                 file_valid = True
+                 if records:
+                     first, last = records[0], records[-1]
+                     if first["input"] != messages_list[first["index"]]:
+                         file_valid = False
+                     elif len(records) > 1 and last["input"] != messages_list[last["index"]]:
+                         file_valid = False
+
+                 if file_valid:
+                     completed_indices = {r["index"] for r in records}
+                     if completed_indices:
+                         logger.info(f"Resumed from file: {len(completed_indices)}/{len(messages_list)} already completed")
+                 else:
+                     raise ValueError(
+                         f"File validation failed: the inputs in {output_jsonl} do not match the current messages_list. "
+                         f"Delete or rename the file and retry."
+                     )
+
+             file_writer = open(output_path, "a", encoding="utf-8")
+
+         def flush_to_file():
+             """Flush the buffer to the file"""
+             nonlocal file_buffer, last_flush_time
+             if file_writer and file_buffer:
+                 for record in file_buffer:
+                     file_writer.write(json.dumps(record, ensure_ascii=False) + "\n")
+                 file_writer.flush()
+                 file_buffer = []
+                 last_flush_time = time.time()
+
+         def on_file_result(original_idx: int, content: Any, status: str = "success", error: str = None, usage: dict = None):
+             """File-output callback"""
+             nonlocal last_flush_time
+             if file_writer is None:
+                 return
+             record = {
+                 "index": original_idx,
+                 "output": content,
+                 "status": status,
+                 "input": messages_list[original_idx],
+             }
+             if metadata_list is not None:
+                 record["metadata"] = metadata_list[original_idx]
+             if usage is not None:
+                 record["usage"] = usage
+             if error:
+                 record["error"] = error
+             file_buffer.append(record)
+             # Time-based flushing
+             if time.time() - last_flush_time >= flush_interval:
+                 flush_to_file()
+
+         try:
+             # Compute the indices that actually need to run (excluding ones already completed in the file)
+             all_indices = set(range(len(messages_list)))
+             indices_to_skip = completed_indices & all_indices
+             if indices_to_skip:
+                 logger.info(f"Skipped via file resume: {len(indices_to_skip)}/{len(messages_list)}")
+
+             # Run with caching
+             if use_cache and self._response_cache:
+                 # Query the cache (kwargs are passed so different parameter configurations use different cache keys)
+                 cached_responses, uncached_indices = self._response_cache.get_batch(
+                     messages_list, model=effective_model, **kwargs
+                 )
+
+                 # Write cache hits to the file (if not already there)
+                 for i, resp in enumerate(cached_responses):
+                     if resp is not None and i not in completed_indices:
+                         on_file_result(i, resp)
+
+                 # Filter out indices already completed in the file
+                 actual_uncached = [i for i in uncached_indices if i not in completed_indices]
+
+                 progress = None
+                 if actual_uncached:
+                     logger.info(f"Pending: {len(actual_uncached)}/{len(messages_list)}")
+
+                     uncached_messages = [messages_list[i] for i in actual_uncached]
+                     request_params = [
+                         {"json": self._build_request_body(m, effective_model, **kwargs), "headers": headers}
+                         for m in uncached_messages
+                     ]
+
+                     # Choose the extractor
+                     extract_fn = extractor_with_usage if return_usage else extractor
+
+                     async for batch in self._client.aiter_stream_requests(
+                         request_params=request_params,
+                         url=effective_url,
+                         method="POST",
+                         show_progress=show_progress,
+                         total_requests=len(uncached_messages),
+                     ):
+                         for result in batch.completed_requests:
+                             original_idx = actual_uncached[result.request_id]
+                             # Check the request status
+                             if result.status != "success":
+                                 error_msg = result.data.get("error", "Unknown error") if isinstance(result.data, dict) else str(result.data)
+                                 logger.warning(f"Request failed: {error_msg}")
+                                 cached_responses[original_idx] = None
+                                 on_file_result(original_idx, None, "error", error_msg)
+                                 continue
+                             try:
+                                 extracted = extract_fn(result)
+                                 cached_responses[original_idx] = extracted
+                                 # Write to the cache (only when usage is not requested, since the cache stores no usage)
+                                 if not return_usage:
+                                     self._response_cache.set(
+                                         messages_list[original_idx], extracted, model=effective_model, **kwargs
+                                     )
+                                 # File output (stores content and usage)
+                                 if return_usage:
+                                     on_file_result(original_idx, extracted.content, usage=extracted.usage)
+                                 else:
+                                     on_file_result(original_idx, extracted)
+                             except Exception as e:
+                                 logger.warning(f"Failed to extract result: {e}")
+                                 cached_responses[original_idx] = None
+                                 on_file_result(original_idx, None, "error", str(e))
+                         if batch.is_final:
+                             progress = batch.progress
+
+                 responses = cached_responses
+             else:
+                 # No cache: run the batch directly (streamed so results can be saved incrementally)
+                 indices_to_run = [i for i in range(len(messages_list)) if i not in completed_indices]
+                 responses = [None] * len(messages_list)
+
+                 # Choose the extractor
+                 extract_fn = extractor_with_usage if return_usage else extractor
+
+                 progress = None
+                 if indices_to_run:
+                     messages_to_run = [messages_list[i] for i in indices_to_run]
+                     request_params = [
+                         {"json": self._build_request_body(m, effective_model, **kwargs), "headers": headers}
+                         for m in messages_to_run
+                     ]
+                     # Stream the requests, writing to the file as each one completes
+                     async for batch in self._client.aiter_stream_requests(
+                         request_params=request_params,
+                         url=effective_url,
+                         method="POST",
+                         show_progress=show_progress,
+                         total_requests=len(messages_to_run),
+                     ):
+                         for result in batch.completed_requests:
+                             original_idx = indices_to_run[result.request_id]
+                             # Check the request status
+                             if result.status != "success":
+                                 error_msg = result.data.get("error", "Unknown error") if isinstance(result.data, dict) else str(result.data)
+                                 logger.warning(f"Request failed: {error_msg}")
+                                 responses[original_idx] = None
+                                 on_file_result(original_idx, None, "error", error_msg)
+                                 continue
+                             try:
+                                 extracted = extract_fn(result)
+                                 responses[original_idx] = extracted
+                                 # File output (stores content and usage)
+                                 if return_usage:
+                                     on_file_result(original_idx, extracted.content, usage=extracted.usage)
+                                 else:
+                                     on_file_result(original_idx, extracted)
+                             except Exception as e:
+                                 logger.warning(f"Error: {e}, set content to None")
+                                 responses[original_idx] = None
+                                 on_file_result(original_idx, None, "error", str(e))
+                         if batch.is_final:
+                             progress = batch.progress
+
+         finally:
+             # Make sure the remaining data is written
+             flush_to_file()
+             if file_writer:
+                 file_writer.close()
+                 # Auto-compact: deduplicate, keeping the latest successful record for each index
+                 self._compact_output_file(output_jsonl)
+
+         summary = progress.summary(print_to_console=False) if progress else None
+         return (responses, summary) if return_summary else responses
+
+     def _compact_output_file(self, file_path: str):
+         """Deduplicate the output file, keeping the latest successful record for each index"""
+         import os
+
+         tmp_path = file_path + ".tmp"
+         try:
+             records = {}
+             with open(file_path, "r", encoding="utf-8") as f:
+                 for line in f:
+                     if not line.strip():
+                         continue
+                     r = json.loads(line)
+                     idx = r.get("index")
+                     if idx is None:
+                         continue
+                     # Successful records take precedence, or the index has no record yet
+                     if r.get("status") == "success" or idx not in records:
+                         records[idx] = r
+
+             # Write to a temporary file first
+             with open(tmp_path, "w", encoding="utf-8") as f:
+                 for r in sorted(records.values(), key=lambda x: x["index"]):
+                     f.write(json.dumps(r, ensure_ascii=False) + "\n")
+
+             # Atomic replace (os.replace is atomic within the same filesystem)
+             os.replace(tmp_path, file_path)
+         except Exception as e:
+             logger.warning(f"Failed to compact output file: {e}")
+             # Clean up any leftover temporary file
+             if os.path.exists(tmp_path):
+                 os.remove(tmp_path)
+
+     def chat_completions_batch_sync(
+         self,
+         messages_list: List[List[dict]],
+         model: str = None,
+         return_raw: bool = False,
+         return_usage: bool = False,
+         show_progress: bool = True,
+         return_summary: bool = False,
+         output_jsonl: Optional[str] = None,
+         flush_interval: float = 1.0,
+         metadata_list: Optional[List[dict]] = None,
+         **kwargs,
+     ) -> Union[List[str], List[ChatCompletionResult], tuple]:
+         """Synchronous version of batch chat completion"""
+         return asyncio.run(
+             self.chat_completions_batch(
+                 messages_list=messages_list,
+                 model=model,
+                 return_raw=return_raw,
+                 return_usage=return_usage,
+                 show_progress=show_progress,
+                 return_summary=return_summary,
+                 output_jsonl=output_jsonl,
+                 flush_interval=flush_interval,
+                 metadata_list=metadata_list,
+                 **kwargs,
+             )
+         )
+
+     async def iter_chat_completions_batch(
+         self,
+         messages_list: List[List[dict]],
+         model: str = None,
+         return_raw: bool = False,
+         return_usage: bool = False,
+         show_progress: bool = True,
+         preprocess_msg: bool = False,
+         output_jsonl: Optional[str] = None,
+         flush_interval: float = 1.0,
+         metadata_list: Optional[List[dict]] = None,
+         batch_size: int = None,
+         url: str = None,
+         **kwargs,
+     ):
+         """
+         Iterative batch chat completion (yields results as requests complete)
+
+         Functionally the same as chat_completions_batch, but returns results one by one
+         in a streaming fashion, which saves memory on large batches.
+
+         Args:
+             messages_list: list of message lists
+             model: model name
+             return_raw: whether to return raw responses (affects what result.content holds)
+             return_usage: whether to attach a usage attribute to each result
+             show_progress: whether to show a progress bar
+             preprocess_msg: whether to preprocess messages
+             output_jsonl: output file path (JSONL format) for persisting results
+             flush_interval: file flush interval in seconds (default 1 second)
+             metadata_list: metadata list, same length as messages_list; each element is saved with its output record
+             batch_size: number of results per yielded batch (passed to the underlying requester)
+             url: custom request URL; defaults to the one generated by _get_url()
+
+         Yields:
+             result: a result object with the following attributes
+                 - content: the extracted content (str | dict)
+                 - usage: token usage info (only when return_usage=True)
+                 - original_idx: the original index
+                 - latency: request latency in seconds
+                 - status: status ('success', 'error', 'cached')
+                 - error: error message, if any
+                 - data: the raw response data
+                 - summary: overall statistics (dict) on the last result, None otherwise
+                     - total: total number of requests
+                     - success: number of successes
+                     - failed: number of failures
+                     - cached: number of cache hits
+                     - elapsed: total elapsed time in seconds
+                     - avg_latency: average latency in seconds
+
+         Note:
+             Caching is controlled by the cache parameter at initialization; it is skipped
+             automatically when return_usage=True.
+         """
+         effective_model = self._get_effective_model(model)
+         effective_url = url or self._get_url(effective_model, stream=False)
+         headers = self._get_headers()
+
+         # Validate metadata_list length
+         if metadata_list is not None and len(metadata_list) != len(messages_list):
+             raise ValueError(
+                 f"metadata_list length ({len(metadata_list)}) must match messages_list length ({len(messages_list)})"
+             )
+
+         # Validate the output_jsonl extension
+         if output_jsonl and not output_jsonl.endswith(".jsonl"):
+             raise ValueError(f"output_jsonl must use the .jsonl extension, got: {output_jsonl}")
+
+         messages_list = await self._preprocess_messages_batch(messages_list, preprocess_msg)
+
+         # Skip the cache when return_usage is set
+         use_cache = self._response_cache is not None and not return_usage
+
+         def extractor(result):
+             if return_raw:
+                 return result.data
+             return self._extract_content(result.data) if result.data else None
+
+         # File-output state
+         file_writer = None
+         file_buffer = []
+         last_flush_time = time.time()
+         completed_indices = set()
+
+         # If an output file is given, read the already-completed indices (checkpoint resume)
+         if output_jsonl:
+             output_path = Path(output_jsonl)
+             if output_path.exists():
+                 records = []
+                 with open(output_path, "r", encoding="utf-8") as f:
+                     for line in f:
+                         try:
+                             record = json.loads(line.strip())
+                             if record.get("status") == "success" and "input" in record:
+                                 idx = record.get("index")
+                                 if 0 <= idx < len(messages_list):
+                                     records.append(record)
+                         except (json.JSONDecodeError, KeyError, TypeError):
+                             continue
+
+                 # Head/tail validation
+                 file_valid = True
+                 if records:
+                     first, last = records[0], records[-1]
+                     if first["input"] != messages_list[first["index"]]:
+                         file_valid = False
+                     elif len(records) > 1 and last["input"] != messages_list[last["index"]]:
+                         file_valid = False
+
+                 if file_valid:
+                     completed_indices = {r["index"] for r in records}
+                     if completed_indices:
+                         logger.info(f"Resumed from file: {len(completed_indices)}/{len(messages_list)} already completed")
+                 else:
+                     raise ValueError(
+                         f"File validation failed: the inputs in {output_jsonl} do not match the current messages_list. "
+                         f"Delete or rename the file and retry."
+                     )
+
+             file_writer = open(output_path, "a", encoding="utf-8")
+
+         def flush_to_file():
+             nonlocal file_buffer, last_flush_time
+             if file_writer and file_buffer:
+                 for record in file_buffer:
+                     file_writer.write(json.dumps(record, ensure_ascii=False) + "\n")
+                 file_writer.flush()
+                 file_buffer = []
+                 last_flush_time = time.time()
+
+         def on_file_result(original_idx: int, content: Any, status: str = "success", error: str = None, usage: dict = None):
+             nonlocal last_flush_time
+             if file_writer is None:
+                 return
+             record = {
+                 "index": original_idx,
+                 "output": content,
+                 "status": status,
+                 "input": messages_list[original_idx],
+             }
+             if metadata_list is not None:
+                 record["metadata"] = metadata_list[original_idx]
+             if usage is not None:
+                 record["usage"] = usage
+             if error:
+                 record["error"] = error
+             file_buffer.append(record)
+             if time.time() - last_flush_time >= flush_interval:
+                 flush_to_file()
+
+         try:
+             # Statistics
+             total_count = len(messages_list)
+             yielded_count = 0
+             success_count = 0
+             cached_count = 0
+             start_time = time.time()
+             total_latency = 0.0
+             last_progress = None
+
+             # Query the cache
+             cached_responses = [None] * len(messages_list)
+             uncached_indices = list(range(len(messages_list)))
+
+             if use_cache and self._response_cache:
+                 cached_responses, uncached_indices = self._response_cache.get_batch(
+                     messages_list, model=effective_model, **kwargs
+                 )
+
+             # Yield the cache hits first
+             for i, resp in enumerate(cached_responses):
+                 if resp is not None:
+                     if i not in completed_indices:
+                         on_file_result(i, resp)
+                     # Build a result object for the cache hit
+                     from types import SimpleNamespace
+
+                     yielded_count += 1
+                     cached_count += 1
+                     success_count += 1
+                     is_last = yielded_count == total_count
+
+                     cached_result = SimpleNamespace(
+                         content=resp,
+                         usage=None,  # cached entries carry no usage info
+                         original_idx=i,
+                         latency=0.0,
+                         status="cached",
+                         error=None,
+                         data=None,
+                         summary=None,
+                     )
+                     if is_last:
+                         cached_result.summary = {
+                             "total": total_count,
+                             "success": success_count,
+                             "failed": total_count - success_count,
+                             "cached": cached_count,
+                             "elapsed": time.time() - start_time,
+                             "avg_latency": total_latency / max(yielded_count - cached_count, 1),
+                         }
+                     yield cached_result
+
+             # Filter out indices already completed in the file
+             actual_uncached = [i for i in uncached_indices if i not in completed_indices]
+
+             if actual_uncached:
+                 logger.info(f"Pending: {len(actual_uncached)}/{len(messages_list)}")
+
+                 uncached_messages = [messages_list[i] for i in actual_uncached]
+                 request_params = [
+                     {"json": self._build_request_body(m, effective_model, **kwargs), "headers": headers}
+                     for m in uncached_messages
+                 ]
+
+                 async for batch in self._client.aiter_stream_requests(
+                     request_params=request_params,
+                     url=effective_url,
+                     method="POST",
+                     show_progress=show_progress,
+                     batch_size=batch_size,
+                     total_requests=len(uncached_messages),
+                 ):
+                     for result in batch.completed_requests:
+                         original_idx = actual_uncached[result.request_id]
+                         yielded_count += 1
+                         is_last = yielded_count == total_count
+
+                         # Check the request status
+                         if result.status != "success":
+                             error_msg = result.data.get("error", "Unknown error") if isinstance(result.data, dict) else str(result.data)
+                             logger.warning(f"Request failed: {error_msg}")
+                             on_file_result(original_idx, None, "error", error_msg)
+                             result.content = None
+                             result.usage = None
+                             result.original_idx = original_idx
+                             result.error = error_msg
+                         else:
+                             try:
+                                 content = extractor(result)
+                                 usage = self._extract_usage(result.data) if return_usage else None
+                                 # Write to the cache
+                                 if use_cache and self._response_cache and content is not None and not return_raw:
+                                     self._response_cache.set(
+                                         messages_list[original_idx], content, model=effective_model, **kwargs
+                                     )
+                                 on_file_result(original_idx, content, usage=usage)
+                                 # Attach attributes to the result object
+                                 result.content = content
+                                 result.usage = usage
+                                 result.original_idx = original_idx
+                                 success_count += 1
+                                 total_latency += result.latency
+                             except Exception as e:
+                                 logger.warning(f"Failed to extract result: {e}")
+                                 on_file_result(original_idx, None, "error", str(e))
+                                 result.content = None
+                                 result.usage = None
+                                 result.original_idx = original_idx
+
+                         # Attach the summary to the last result
+                         result.summary = None
+                         if is_last:
+                             result.summary = {
+                                 "total": total_count,
+                                 "success": success_count,
+                                 "failed": total_count - success_count,
+                                 "cached": cached_count,
+                                 "elapsed": time.time() - start_time,
+                                 "avg_latency": total_latency / max(yielded_count - cached_count, 1),
+                             }
+                         yield result
+                     if batch.is_final:
+                         last_progress = batch.progress
+
+         finally:
+             flush_to_file()
+             if file_writer:
+                 file_writer.close()
+                 # Auto-compact: deduplicate, keeping the latest successful record for each index
+                 self._compact_output_file(output_jsonl)
+
+     async def chat_completions_stream(
+         self,
+         messages: List[dict],
+         model: str = None,
+         return_usage: bool = False,
+         preprocess_msg: bool = False,
+         url: str = None,
+         timeout: int = None,
+         **kwargs,
+     ):
+         """
+         Streaming chat completion
+
+         Args:
+             messages: list of messages
+             model: model name
+             return_usage: whether to yield usage info. When True, dicts are yielded:
+                 - {"type": "content", "content": "..."} for a content chunk
+                 - {"type": "usage", "usage": {...}} for token usage (the last item)
+                 When False (the default), str content chunks are yielded.
+             preprocess_msg: whether to preprocess messages
+             url: custom request URL; defaults to the one generated by _get_stream_url()
+             timeout: timeout in seconds; defaults to the client's configured value
+
+         Yields:
+             - return_usage=False: str content chunks
+             - return_usage=True: dicts with a type field and the corresponding data
+         """
+         import aiohttp
+         import json
+
+         effective_model = self._get_effective_model(model)
+         messages = await self._preprocess_messages(messages, preprocess_msg)
+
+         body = self._build_request_body(messages, effective_model, stream=True, **kwargs)
+
+         # When usage is requested, add stream_options (OpenAI format)
+         if return_usage:
+             body["stream_options"] = {"include_usage": True}
+
+         effective_url = url or self._get_stream_url(effective_model)
+         headers = self._get_headers()
+
+         effective_timeout = timeout if timeout is not None else self._timeout
+         aio_timeout = aiohttp.ClientTimeout(total=effective_timeout)
+
+         async with aiohttp.ClientSession(trust_env=True) as session:
+             async with session.post(effective_url, json=body, headers=headers, timeout=aio_timeout) as response:
+                 if response.status != 200:
+                     error_text = await response.text()
+                     raise Exception(f"HTTP {response.status}: {error_text}")
+
+                 async for line in response.content:
+                     line = line.decode("utf-8").strip()
+                     if line.startswith("data: "):
+                         data_str = line[6:]
+                         if data_str == "[DONE]":
+                             break
+                         try:
+                             data = json.loads(data_str)
+
+                             # Check for usage (the last chunk of a streaming response)
+                             if return_usage and "usage" in data and data["usage"]:
+                                 yield {"type": "usage", "usage": data["usage"]}
+                                 continue
+
+                             content = self._extract_stream_content(data)
+                             if content:
+                                 if return_usage:
+                                     yield {"type": "content", "content": content}
+                                 else:
+                                     yield content
+                         except json.JSONDecodeError:
+                             continue
+
+     def model_list(self) -> List[str]:
+         raise NotImplementedError("Subclasses must implement model_list")
+
+     def __repr__(self) -> str:
+         return f"{self.__class__.__name__}(model='{self._model}')"
+
+     # ========== Resource management ==========
+
+     def close(self):
+         """Close the client and release resources (e.g., cache connections)"""
+         if self._response_cache is not None:
+             self._response_cache.close()
+             self._response_cache = None
+
+     def __enter__(self):
+         return self
+
+     def __exit__(self, *args):
+         self.close()
+
+     async def __aenter__(self):
+         return self
+
+     async def __aexit__(self, *args):
+         self.close()
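
A minimal usage sketch of the contract above, assuming an OpenAI-compatible /chat/completions endpoint. The subclass name, URL, and key below are hypothetical placeholders; the wheel's real provider subclasses are openaiclient.py, claudeclient.py, and geminiclient.py, which may differ in detail.

    from flexllm.base_client import LLMClientBase


    class OpenAICompatClient(LLMClientBase):
        """Hypothetical subclass targeting an OpenAI-compatible endpoint."""

        def _get_url(self, model: str, stream: bool = False) -> str:
            # OpenAI-style APIs use the same URL for streaming and non-streaming.
            return f"{self._base_url}/chat/completions"

        def _get_headers(self) -> dict:
            return {
                "Content-Type": "application/json",
                "Authorization": f"Bearer {self._api_key}",
            }

        def _build_request_body(self, messages, model: str, stream: bool = False, **kwargs) -> dict:
            return {"model": model, "messages": messages, "stream": stream, **kwargs}

        def _extract_content(self, response_data: dict):
            # OpenAI-style responses put the text at choices[0].message.content.
            try:
                return response_data["choices"][0]["message"]["content"]
            except (KeyError, IndexError, TypeError):
                return None


    client = OpenAICompatClient(
        base_url="https://api.example.com/v1",  # placeholder endpoint
        api_key="sk-...",                       # placeholder key
        model="some-model",
    )
    print(client.chat_completions_sync([{"role": "user", "content": "Hello"}]))

With only these four overrides, the subclass inherits the single, batch, iterative, and streaming completion methods from LLMClientBase, including the response cache and the output_jsonl checkpoint-resume behavior implemented above (e.g., chat_completions_batch_sync(..., output_jsonl="results.jsonl")).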