flexllm 0.3.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flexllm/__init__.py +224 -0
- flexllm/__main__.py +1096 -0
- flexllm/async_api/__init__.py +9 -0
- flexllm/async_api/concurrent_call.py +100 -0
- flexllm/async_api/concurrent_executor.py +1036 -0
- flexllm/async_api/core.py +373 -0
- flexllm/async_api/interface.py +12 -0
- flexllm/async_api/progress.py +277 -0
- flexllm/base_client.py +988 -0
- flexllm/batch_tools/__init__.py +16 -0
- flexllm/batch_tools/folder_processor.py +317 -0
- flexllm/batch_tools/table_processor.py +363 -0
- flexllm/cache/__init__.py +10 -0
- flexllm/cache/response_cache.py +293 -0
- flexllm/chain_of_thought_client.py +1120 -0
- flexllm/claudeclient.py +402 -0
- flexllm/client_pool.py +698 -0
- flexllm/geminiclient.py +563 -0
- flexllm/llm_client.py +523 -0
- flexllm/llm_parser.py +60 -0
- flexllm/mllm_client.py +559 -0
- flexllm/msg_processors/__init__.py +174 -0
- flexllm/msg_processors/image_processor.py +729 -0
- flexllm/msg_processors/image_processor_helper.py +485 -0
- flexllm/msg_processors/messages_processor.py +341 -0
- flexllm/msg_processors/unified_processor.py +1404 -0
- flexllm/openaiclient.py +256 -0
- flexllm/pricing/__init__.py +104 -0
- flexllm/pricing/data.json +1201 -0
- flexllm/pricing/updater.py +223 -0
- flexllm/provider_router.py +213 -0
- flexllm/token_counter.py +270 -0
- flexllm/utils/__init__.py +1 -0
- flexllm/utils/core.py +41 -0
- flexllm-0.3.3.dist-info/METADATA +573 -0
- flexllm-0.3.3.dist-info/RECORD +39 -0
- flexllm-0.3.3.dist-info/WHEEL +4 -0
- flexllm-0.3.3.dist-info/entry_points.txt +3 -0
- flexllm-0.3.3.dist-info/licenses/LICENSE +201 -0
flexllm/base_client.py
ADDED
|
@@ -0,0 +1,988 @@
|
|
|
1
|
+
"""
|
|
2
|
+
LLMClientBase - LLM 客户端抽象基类
|
|
3
|
+
|
|
4
|
+
提供通用的方法实现,子类只需实现核心的差异化方法。
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import asyncio
|
|
8
|
+
import json
|
|
9
|
+
import time
|
|
10
|
+
from abc import ABC
|
|
11
|
+
from dataclasses import dataclass, field
|
|
12
|
+
from pathlib import Path
|
|
13
|
+
from typing import TYPE_CHECKING, List, Union, Optional, Any
|
|
14
|
+
|
|
15
|
+
from loguru import logger
|
|
16
|
+
|
|
17
|
+
from .async_api import ConcurrentRequester
|
|
18
|
+
from .msg_processors.image_processor import ImageCacheConfig
|
|
19
|
+
from .msg_processors.messages_processor import messages_preprocess
|
|
20
|
+
from .msg_processors.unified_processor import batch_process_messages as optimized_batch_preprocess
|
|
21
|
+
from .cache import ResponseCache, ResponseCacheConfig
|
|
22
|
+
|
|
23
|
+
if TYPE_CHECKING:
|
|
24
|
+
from .async_api.interface import RequestResult
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
@dataclass
class ToolCall:
    """One tool invocation requested by the model."""

    # Identifier of this tool call.
    id: str
    # Call type; "function" per the original field comment.
    type: str
    # Function payload, shaped like {"name": "...", "arguments": "..."}.
    function: dict
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
@dataclass
class ChatCompletionResult:
    """Result of a single chat completion: the text plus optional extras.

    Only ``content`` is required; every other field defaults to None.
    """

    # Generated text (None when the client's content extractor returns None).
    content: str
    # Token usage, e.g. {"prompt_tokens": x, "completion_tokens": y, "total_tokens": z}.
    usage: Optional[dict] = None
    # Reasoning / "thinking" text from models that emit it (e.g. DeepSeek-R1, Qwen3).
    reasoning_content: Optional[str] = None
    # Tool calls requested by the model, if any.
    tool_calls: Optional[List["ToolCall"]] = None
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
@dataclass
class BatchResultItem:
    """One entry of a batch run: position, content, usage and outcome."""

    # Position of the request within the submitted batch.
    index: int
    # Extracted content, or None when nothing was produced.
    content: Optional[str]
    # Token usage dict, when available.
    usage: Optional[dict] = None
    # Outcome label: "success", "error" or "cached".
    status: str = "success"
    # Error message for failed items.
    error: Optional[str] = None
    # Request latency (presumably seconds, like the batch iterator's latency).
    latency: float = 0.0
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
class LLMClientBase(ABC):
|
|
56
|
+
"""
|
|
57
|
+
LLM 客户端抽象基类
|
|
58
|
+
|
|
59
|
+
子类只需实现 4 个核心方法:
|
|
60
|
+
- _get_url(model, stream) -> str
|
|
61
|
+
- _get_headers() -> dict
|
|
62
|
+
- _build_request_body(messages, model, **kwargs) -> dict
|
|
63
|
+
- _extract_content(response_data) -> str
|
|
64
|
+
|
|
65
|
+
可选覆盖:
|
|
66
|
+
- _extract_stream_content(data) -> str
|
|
67
|
+
- _get_stream_url(model) -> str
|
|
68
|
+
"""
|
|
69
|
+
|
|
70
|
+
def __init__(
    self,
    base_url: str = None,
    api_key: str = None,
    model: str = None,
    concurrency_limit: int = 10,
    max_qps: int = 1000,
    timeout: int = 120,
    retry_times: int = 3,
    retry_delay: float = 1.0,
    cache_image: bool = False,
    cache_dir: str = "image_cache",
    cache: Union[bool, ResponseCacheConfig, None] = None,
    **kwargs,
):
    """Store connection settings and build the requester and caches.

    Args:
        base_url: API base URL; a trailing "/" is stripped. Falsy -> None.
        api_key: API key, stored on the instance.
        model: Default model name used when a call omits ``model``.
        concurrency_limit: Concurrent request limit.
        max_qps: Maximum requests per second.
        timeout: Request timeout in seconds.
        retry_times: Retry count.
        retry_delay: Delay between retries in seconds.
        cache_image: Whether to cache images during message preprocessing.
        cache_dir: Image cache directory.
        cache: Response cache setting:
            - True: enabled via ``ResponseCacheConfig.ipc()`` (IPC mode,
              24h TTL per the original docs).
            - False / None (default): disabled.
            - ResponseCacheConfig: used as given.
        **kwargs: Accepted and ignored here (presumably for subclass
            compatibility — TODO confirm).
    """
    self._base_url = None if not base_url else base_url.rstrip("/")
    self._api_key = api_key
    self._model = model
    self._concurrency_limit = concurrency_limit
    self._timeout = timeout

    # Shared requester that enforces concurrency, QPS, timeout and retries.
    self._client = ConcurrentRequester(
        concurrency_limit=concurrency_limit,
        max_qps=max_qps,
        timeout=timeout,
        retry_times=retry_times,
        retry_delay=retry_delay,
    )

    self._cache_config = ImageCacheConfig(
        enabled=cache_image,
        cache_dir=cache_dir,
        force_refresh=False,
        retry_failed=False,
    )

    # Normalize `cache` into a ResponseCacheConfig; identity checks keep
    # 0/1 and similar values from being mistaken for booleans.
    if cache is True:
        response_cfg = ResponseCacheConfig.ipc()
    elif cache is None or cache is False:
        response_cfg = ResponseCacheConfig.disabled()
    else:
        response_cfg = cache
    self._response_cache = ResponseCache(response_cfg) if response_cfg.enabled else None
|
|
129
|
+
|
|
130
|
+
# ========== 核心抽象方法(子类必须实现)==========
|
|
131
|
+
|
|
132
|
+
def _get_url(self, model: str, stream: bool = False) -> str:
    """Return the request URL for ``model`` (streaming variant when ``stream``).

    Core hook that subclasses must implement.

    Raises:
        NotImplementedError: always, in the base class.
    """
    raise NotImplementedError()
|
|
134
|
+
|
|
135
|
+
def _get_headers(self) -> dict:
    """Return the HTTP headers sent with every request.

    Core hook that subclasses must implement.

    Raises:
        NotImplementedError: always, in the base class.
    """
    raise NotImplementedError()
|
|
137
|
+
|
|
138
|
+
def _build_request_body(
    self, messages: List[dict], model: str, stream: bool = False, **kwargs
) -> dict:
    """Build the JSON request body for ``messages`` and ``model``.

    Core hook that subclasses must implement; extra ``kwargs`` are the
    per-call options forwarded by the chat methods.

    Raises:
        NotImplementedError: always, in the base class.
    """
    raise NotImplementedError()
|
|
142
|
+
|
|
143
|
+
def _extract_content(self, response_data: dict) -> Optional[str]:
    """Return the generated text from a non-streaming response payload.

    Core hook that subclasses must implement.

    Raises:
        NotImplementedError: always, in the base class.
    """
    raise NotImplementedError()
|
|
145
|
+
|
|
146
|
+
def _extract_usage(self, response_data: dict) -> Optional[dict]:
    """Return the ``"usage"`` entry of a response payload (subclasses may override).

    Returns None for a falsy payload (None or empty dict) or when the key
    is absent.
    """
    return response_data.get("usage") if response_data else None
|
|
151
|
+
|
|
152
|
+
def _extract_tool_calls(self, response_data: dict) -> Optional[List["ToolCall"]]:
    """Return the tool calls found in ``response_data``.

    The base implementation reports none (returns None); subclasses may
    override it to parse tool calls from their response payload.
    """
    tool_calls = None
    return tool_calls
|
|
155
|
+
|
|
156
|
+
# ========== 可选覆盖的钩子方法 ==========
|
|
157
|
+
|
|
158
|
+
def _extract_stream_content(self, data: dict) -> Optional[str]:
    """Extract text from one streamed chunk.

    Defaults to the non-streaming ``_extract_content``; subclasses whose
    stream chunks have a different shape can override this hook.
    """
    extract = self._extract_content
    return extract(data)
|
|
160
|
+
|
|
161
|
+
def _get_stream_url(self, model: str) -> str:
    """Return the streaming endpoint URL; defaults to ``_get_url(model, stream=True)``."""
    stream_url = self._get_url(model, stream=True)
    return stream_url
|
|
163
|
+
|
|
164
|
+
# ========== 通用工具方法 ==========
|
|
165
|
+
|
|
166
|
+
def _get_effective_model(self, model: str = None) -> str:
    """Resolve which model a call should use.

    Args:
        model: Per-call model name; any falsy value falls back to the
            default model given at initialization.

    Returns:
        ``model`` when truthy, otherwise ``self._model``.

    Raises:
        ValueError: when neither is set.
    """
    chosen = model if model else self._model
    if chosen:
        return chosen
    raise ValueError("必须提供 model 参数或在初始化时指定 model")
|
|
171
|
+
|
|
172
|
+
async def _preprocess_messages(
    self, messages: List[dict], preprocess_msg: bool = False
) -> List[dict]:
    """Optionally preprocess one conversation (e.g. image to base64).

    When ``preprocess_msg`` is falsy, the same ``messages`` object is
    returned untouched. Otherwise the work is delegated to
    ``messages_preprocess`` with this client's image cache config.
    """
    if not preprocess_msg:
        return messages
    return await messages_preprocess(
        messages, preprocess_msg=preprocess_msg, cache_config=self._cache_config
    )
|
|
181
|
+
|
|
182
|
+
async def _preprocess_messages_batch(
    self, messages_list: List[List[dict]], preprocess_msg: bool = False
) -> List[List[dict]]:
    """Optionally preprocess many conversations at once.

    When ``preprocess_msg`` is falsy, the same ``messages_list`` object is
    returned untouched. Otherwise it goes to ``optimized_batch_preprocess``
    with ``max_concurrent`` set to this client's concurrency limit and the
    image cache config.
    """
    if not preprocess_msg:
        return messages_list
    return await optimized_batch_preprocess(
        messages_list,
        max_concurrent=self._concurrency_limit,
        cache_config=self._cache_config,
    )
|
|
191
|
+
|
|
192
|
+
# ========== 通用接口实现 ==========
|
|
193
|
+
|
|
194
|
+
async def chat_completions(
    self,
    messages: List[dict],
    model: str = None,
    return_raw: bool = False,
    return_usage: bool = False,
    show_progress: bool = False,
    preprocess_msg: bool = False,
    url: str = None,
    **kwargs,
) -> Union[str, "ChatCompletionResult", "RequestResult"]:
    """
    Run a single chat completion.

    Args:
        messages: Message list.
        model: Model name; falls back to the default model.
        return_raw: Return the raw RequestResult.
        return_usage: Return a ChatCompletionResult (content, usage, tool_calls).
        show_progress: Show a progress bar.
        preprocess_msg: Preprocess messages (e.g. images) first.
        url: Request URL override; defaults to ``_get_url()``.
        **kwargs: Forwarded to ``_build_request_body`` and used in cache keys.

    Returns:
        - return_raw=True: the RequestResult as returned by the requester.
        - return_usage=True: ChatCompletionResult.
        - otherwise: the extracted content string.
        A failed request returns its RequestResult regardless of flags.

    Note:
        The response cache (configured at init) holds only content, so it
        is skipped when return_raw or return_usage is set.
    """
    model_name = self._get_effective_model(model)
    messages = await self._preprocess_messages(messages, preprocess_msg)

    cache = self._response_cache
    cache_enabled = cache is not None and not (return_raw or return_usage)
    if cache_enabled:
        hit = cache.get(messages, model=model_name, **kwargs)
        if hit is not None:
            return hit

    payload = {
        "json": self._build_request_body(messages, model_name, stream=False, **kwargs),
        "headers": self._get_headers(),
    }
    target = url if url else self._get_url(model_name, stream=False)
    results, _ = await self._client.process_requests(
        request_params=[payload], url=target, method="POST", show_progress=show_progress
    )
    outcome = results[0]

    # Raw mode and failed requests both hand back the RequestResult itself.
    if return_raw or outcome.status != "success":
        return outcome

    text = self._extract_content(outcome.data)
    if cache_enabled and text is not None:
        cache.set(messages, text, model=model_name, **kwargs)
    if not return_usage:
        return text
    return ChatCompletionResult(
        content=text,
        usage=self._extract_usage(outcome.data),
        tool_calls=self._extract_tool_calls(outcome.data),
    )
|
|
258
|
+
|
|
259
|
+
def chat_completions_sync(
    self,
    messages: List[dict],
    model: str = None,
    return_raw: bool = False,
    return_usage: bool = False,
    **kwargs,
) -> Union[str, "ChatCompletionResult", "RequestResult"]:
    """Blocking wrapper around ``chat_completions``.

    Drives the coroutine with ``asyncio.run``, so it cannot be used from a
    thread that already runs an event loop. Extra keyword arguments are
    forwarded unchanged.
    """
    coro = self.chat_completions(
        messages=messages,
        model=model,
        return_raw=return_raw,
        return_usage=return_usage,
        **kwargs,
    )
    return asyncio.run(coro)
|
|
273
|
+
|
|
274
|
+
async def chat_completions_batch(
    self,
    messages_list: List[List[dict]],
    model: str = None,
    return_raw: bool = False,
    return_usage: bool = False,
    show_progress: bool = True,
    return_summary: bool = False,
    preprocess_msg: bool = False,
    output_jsonl: Optional[str] = None,
    flush_interval: float = 1.0,
    metadata_list: Optional[List[dict]] = None,
    url: str = None,
    **kwargs,
) -> Union[List[str], List["ChatCompletionResult"], tuple]:
    """
    Batch chat completion with optional JSONL persistence and resume.

    Args:
        messages_list: One message list per request.
        model: Model name; falls back to the default model.
        return_raw: Return each successful request's raw payload
            (``result.data``) instead of extracted text. Takes precedence
            over ``return_usage``; the response cache is bypassed.
        return_usage: Return ChatCompletionResult items (content, usage,
            tool_calls); the response cache is bypassed.
        show_progress: Show a progress bar.
        return_summary: Also return the execution summary.
        preprocess_msg: Preprocess messages (e.g. images) first.
        output_jsonl: ``.jsonl`` path; every result is appended as a record
            with "index", "output", "status", "input" (plus "metadata",
            "usage", "error" when present). If the file already exists,
            indices recorded there as successful are not requested again
            and their stored output is returned.
        flush_interval: Minimum seconds between buffered file writes.
        metadata_list: Per-request metadata, same length as messages_list,
            stored in the matching output record.
        url: Request URL override; defaults to ``_get_url()``.
        **kwargs: Forwarded to ``_build_request_body`` and used in cache keys.

    Returns:
        A list aligned with ``messages_list`` (None for failed items), or
        ``(list, summary)`` when ``return_summary`` is True; ``summary`` is
        None when no request had to be sent.

    Raises:
        ValueError: mismatched ``metadata_list`` length, a non-``.jsonl``
            output path, or an existing output file whose first/last
            recorded inputs differ from ``messages_list``.
    """
    effective_model = self._get_effective_model(model)
    effective_url = url or self._get_url(effective_model, stream=False)
    headers = self._get_headers()

    if metadata_list is not None and len(metadata_list) != len(messages_list):
        raise ValueError(
            f"metadata_list 长度 ({len(metadata_list)}) 必须与 messages_list 长度 ({len(messages_list)}) 一致"
        )
    if output_jsonl and not output_jsonl.endswith(".jsonl"):
        raise ValueError(f"output_jsonl 必须使用 .jsonl 扩展名,当前: {output_jsonl}")

    messages_list = await self._preprocess_messages_batch(messages_list, preprocess_msg)
    total = len(messages_list)

    # The response cache stores only content text, so raw and usage
    # requests bypass it (same rule as chat_completions).
    use_cache = self._response_cache is not None and not return_raw and not return_usage
    # ChatCompletionResult items are produced only when raw mode is off.
    wants_result_obj = return_usage and not return_raw

    def extract(result):
        """Convert a successful request result into the caller-facing value."""
        if return_raw:
            return result.data
        content = self._extract_content(result.data)
        if not wants_result_obj:
            return content
        return ChatCompletionResult(
            content=content,
            usage=self._extract_usage(result.data),
            tool_calls=self._extract_tool_calls(result.data),
        )

    def from_record(record):
        """Rebuild the caller-facing value from a stored success record."""
        if wants_result_obj:
            return ChatCompletionResult(content=record.get("output"), usage=record.get("usage"))
        return record.get("output")

    # ---- Resume: collect success records already in the output file ----
    resumed = {}  # index -> most recent success record for that index
    needs_newline = False
    if output_jsonl:
        output_path = Path(output_jsonl)
        if output_path.exists():
            records = []
            last_line = ""
            with open(output_path, "r", encoding="utf-8") as f:
                for line in f:
                    last_line = line
                    try:
                        record = json.loads(line.strip())
                        if record.get("status") == "success" and "input" in record:
                            idx = record.get("index")
                            if isinstance(idx, int) and 0 <= idx < total:
                                records.append(record)
                    except (json.JSONDecodeError, KeyError, TypeError, AttributeError):
                        # Blank, truncated or non-object lines are ignored.
                        continue
            # A crash can leave a final line without "\n"; appending straight
            # after it would glue the next record onto that fragment.
            needs_newline = bool(last_line) and not last_line.endswith("\n")

            # Cheap sanity check: compare only the first and last inputs.
            if records:
                first, last = records[0], records[-1]
                if first["input"] != messages_list[first["index"]] or (
                    len(records) > 1 and last["input"] != messages_list[last["index"]]
                ):
                    raise ValueError(
                        f"文件校验失败: {output_jsonl} 中的 input 与当前 messages_list 不匹配。"
                        f"请删除或重命名该文件后重试。"
                    )
            # Later lines win, so the latest success per index is kept.
            resumed = {r["index"]: r for r in records}
            if resumed:
                logger.info(f"从文件恢复: 已完成 {len(resumed)}/{total}")
    completed = set(resumed)

    # ---- Buffered JSONL writer ----
    file_writer = None
    file_buffer = []
    last_flush_time = time.time()

    def flush_to_file():
        """Write buffered records to the output file and reset the buffer."""
        nonlocal file_buffer, last_flush_time
        if file_writer and file_buffer:
            for record in file_buffer:
                file_writer.write(json.dumps(record, ensure_ascii=False) + "\n")
            file_writer.flush()
            file_buffer = []
            last_flush_time = time.time()

    def on_file_result(original_idx: int, content: Any, status: str = "success", error: str = None, usage: dict = None):
        """Buffer one output record; flush once ``flush_interval`` has elapsed."""
        if file_writer is None:
            return
        record = {
            "index": original_idx,
            "output": content,
            "status": status,
            "input": messages_list[original_idx],
        }
        if metadata_list is not None:
            record["metadata"] = metadata_list[original_idx]
        if usage is not None:
            record["usage"] = usage
        if error:
            record["error"] = error
        file_buffer.append(record)
        if time.time() - last_flush_time >= flush_interval:
            flush_to_file()

    progress = None
    try:
        if output_jsonl:
            file_writer = open(output_path, "a", encoding="utf-8")
            if needs_newline:
                file_writer.write("\n")

        if use_cache:
            responses, uncached = self._response_cache.get_batch(
                messages_list, model=effective_model, **kwargs
            )
            # Record cache hits in the file unless it already has them.
            for i, resp in enumerate(responses):
                if resp is not None and i not in completed:
                    on_file_result(i, resp)
            indices_to_run = [i for i in uncached if i not in completed]
        else:
            responses = [None] * total
            indices_to_run = [i for i in range(total) if i not in completed]

        # Items finished in a previous run get their stored output back
        # (a cache hit for the same index is kept instead).
        for i, record in resumed.items():
            if responses[i] is None:
                responses[i] = from_record(record)

        if indices_to_run:
            logger.info(f"待执行: {len(indices_to_run)}/{total}")
            request_params = [
                {"json": self._build_request_body(messages_list[i], effective_model, **kwargs), "headers": headers}
                for i in indices_to_run
            ]
            # Stream results so each one is persisted as soon as it completes.
            async for batch in self._client.aiter_stream_requests(
                request_params=request_params,
                url=effective_url,
                method="POST",
                show_progress=show_progress,
                total_requests=len(indices_to_run),
            ):
                for result in batch.completed_requests:
                    original_idx = indices_to_run[result.request_id]
                    if result.status != "success":
                        error_msg = (
                            result.data.get("error", "Unknown error")
                            if isinstance(result.data, dict)
                            else str(result.data)
                        )
                        logger.warning(f"请求失败: {error_msg}")
                        responses[original_idx] = None
                        on_file_result(original_idx, None, "error", error_msg)
                        continue
                    try:
                        value = extract(result)
                        responses[original_idx] = value
                        if use_cache:
                            self._response_cache.set(
                                messages_list[original_idx], value, model=effective_model, **kwargs
                            )
                        if wants_result_obj:
                            on_file_result(original_idx, value.content, usage=value.usage)
                        else:
                            on_file_result(original_idx, value)
                    except Exception as e:
                        logger.warning(f"提取结果失败: {e}")
                        responses[original_idx] = None
                        on_file_result(original_idx, None, "error", str(e))
                if batch.is_final:
                    progress = batch.progress
    finally:
        # Persist whatever is buffered, then deduplicate the file.
        flush_to_file()
        if file_writer:
            file_writer.close()
            self._compact_output_file(output_jsonl)

    summary = progress.summary(print_to_console=False) if progress else None
    return (responses, summary) if return_summary else responses
|
|
545
|
+
|
|
546
|
+
def _compact_output_file(self, file_path: str):
    """Deduplicate a JSONL output file in place, keeping one record per index.

    Selection rule per ``index``:
      * a ``"success"`` record always replaces what was kept before, so the
        last success in the file wins;
      * a non-success record replaces only a non-success record, so the
        latest failure is kept for an index that never succeeded.

    Blank lines are ignored. Lines that are not valid JSON objects (e.g. a
    line truncated by a crash) and records without an ``"index"`` are
    dropped instead of aborting the whole pass. The result is sorted by
    index, written to ``<file_path>.tmp`` and swapped in with
    ``os.replace``. This is best effort: any failure is logged as a warning
    and the original file is left in place.
    """
    import os

    tmp_path = file_path + ".tmp"
    try:
        records = {}
        skipped = 0
        with open(file_path, "r", encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    continue
                try:
                    r = json.loads(line)
                except json.JSONDecodeError:
                    skipped += 1
                    continue
                if not isinstance(r, dict) or r.get("index") is None:
                    skipped += 1
                    continue
                idx = r["index"]
                kept = records.get(idx)
                if r.get("status") == "success" or kept is None or kept.get("status") != "success":
                    records[idx] = r
        if skipped:
            logger.warning(f"Compact 输出文件: 跳过 {skipped} 行无效记录")

        # Write to a temp file first so a failure never truncates the original.
        with open(tmp_path, "w", encoding="utf-8") as f:
            for r in sorted(records.values(), key=lambda x: x["index"]):
                f.write(json.dumps(r, ensure_ascii=False) + "\n")

        # Atomic on the same filesystem.
        os.replace(tmp_path, file_path)
    except Exception as e:
        logger.warning(f"Compact 输出文件失败: {e}")
        # Remove a leftover temp file; never let cleanup raise.
        try:
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
        except OSError:
            pass
|
|
577
|
+
|
|
578
|
+
def chat_completions_batch_sync(
    self,
    messages_list: List[List[dict]],
    model: str = None,
    return_raw: bool = False,
    return_usage: bool = False,
    show_progress: bool = True,
    return_summary: bool = False,
    output_jsonl: Optional[str] = None,
    flush_interval: float = 1.0,
    metadata_list: Optional[List[dict]] = None,
    **kwargs,
) -> Union[List[str], List["ChatCompletionResult"], tuple]:
    """Blocking wrapper around ``chat_completions_batch``.

    Drives the coroutine with ``asyncio.run``, so it cannot be used from a
    thread that already runs an event loop. Options not listed here (e.g.
    ``preprocess_msg``, ``url``) pass through ``kwargs`` unchanged.
    """
    forwarded = dict(
        messages_list=messages_list,
        model=model,
        return_raw=return_raw,
        return_usage=return_usage,
        show_progress=show_progress,
        return_summary=return_summary,
        output_jsonl=output_jsonl,
        flush_interval=flush_interval,
        metadata_list=metadata_list,
    )
    return asyncio.run(self.chat_completions_batch(**forwarded, **kwargs))
|
|
606
|
+
|
|
607
|
+
async def iter_chat_completions_batch(
|
|
608
|
+
self,
|
|
609
|
+
messages_list: List[List[dict]],
|
|
610
|
+
model: str = None,
|
|
611
|
+
return_raw: bool = False,
|
|
612
|
+
return_usage: bool = False,
|
|
613
|
+
show_progress: bool = True,
|
|
614
|
+
preprocess_msg: bool = False,
|
|
615
|
+
output_jsonl: Optional[str] = None,
|
|
616
|
+
flush_interval: float = 1.0,
|
|
617
|
+
metadata_list: Optional[List[dict]] = None,
|
|
618
|
+
batch_size: int = None,
|
|
619
|
+
url: str = None,
|
|
620
|
+
**kwargs,
|
|
621
|
+
):
|
|
622
|
+
"""
|
|
623
|
+
迭代式批量聊天完成(边请求边返回结果)
|
|
624
|
+
|
|
625
|
+
与 chat_completions_batch 功能相同,但以流式方式逐条返回结果,
|
|
626
|
+
适合处理大批量数据时节省内存。
|
|
627
|
+
|
|
628
|
+
Args:
|
|
629
|
+
messages_list: 消息列表
|
|
630
|
+
model: 模型名称
|
|
631
|
+
return_raw: 是否返回原始响应(影响 result.content 的内容)
|
|
632
|
+
return_usage: 是否在 result 对象上添加 usage 属性
|
|
633
|
+
show_progress: 是否显示进度条
|
|
634
|
+
preprocess_msg: 是否预处理消息
|
|
635
|
+
output_jsonl: 输出文件路径(JSONL 格式),用于持久化保存结果
|
|
636
|
+
flush_interval: 文件刷新间隔(秒),默认 1 秒
|
|
637
|
+
metadata_list: 元数据列表,与 messages_list 等长,每个元素保存到对应输出记录
|
|
638
|
+
batch_size: 每批返回的数量(传递给底层请求器)
|
|
639
|
+
url: 自定义请求 URL,默认使用 _get_url() 生成
|
|
640
|
+
|
|
641
|
+
Yields:
|
|
642
|
+
result: 包含以下属性的结果对象
|
|
643
|
+
- content: 提取后的内容 (str | dict)
|
|
644
|
+
- usage: token 用量信息(仅当 return_usage=True 时)
|
|
645
|
+
- original_idx: 原始索引
|
|
646
|
+
- latency: 请求延迟(秒)
|
|
647
|
+
- status: 状态 ('success', 'error', 'cached')
|
|
648
|
+
- error: 错误信息(如果有)
|
|
649
|
+
- data: 原始响应数据
|
|
650
|
+
- summary: 最后一个 result 包含整体统计 (dict),其他为 None
|
|
651
|
+
- total: 总请求数
|
|
652
|
+
- success: 成功数
|
|
653
|
+
- failed: 失败数
|
|
654
|
+
- cached: 缓存命中数
|
|
655
|
+
- elapsed: 总耗时(秒)
|
|
656
|
+
- avg_latency: 平均延迟(秒)
|
|
657
|
+
|
|
658
|
+
Note:
|
|
659
|
+
缓存由初始化时的 cache 参数控制,return_usage=True 时自动跳过缓存
|
|
660
|
+
"""
|
|
661
|
+
effective_model = self._get_effective_model(model)
|
|
662
|
+
effective_url = url or self._get_url(effective_model, stream=False)
|
|
663
|
+
headers = self._get_headers()
|
|
664
|
+
|
|
665
|
+
# metadata_list 长度校验
|
|
666
|
+
if metadata_list is not None and len(metadata_list) != len(messages_list):
|
|
667
|
+
raise ValueError(
|
|
668
|
+
f"metadata_list 长度 ({len(metadata_list)}) 必须与 messages_list 长度 ({len(messages_list)}) 一致"
|
|
669
|
+
)
|
|
670
|
+
|
|
671
|
+
# output_jsonl 扩展名校验
|
|
672
|
+
if output_jsonl and not output_jsonl.endswith(".jsonl"):
|
|
673
|
+
raise ValueError(f"output_jsonl 必须使用 .jsonl 扩展名,当前: {output_jsonl}")
|
|
674
|
+
|
|
675
|
+
messages_list = await self._preprocess_messages_batch(messages_list, preprocess_msg)
|
|
676
|
+
|
|
677
|
+
# return_usage 时跳过缓存
|
|
678
|
+
use_cache = self._response_cache is not None and not return_usage
|
|
679
|
+
|
|
680
|
+
def extractor(result):
|
|
681
|
+
if return_raw:
|
|
682
|
+
return result.data
|
|
683
|
+
return self._extract_content(result.data) if result.data else None
|
|
684
|
+
|
|
685
|
+
# 文件输出相关状态
|
|
686
|
+
file_writer = None
|
|
687
|
+
file_buffer = []
|
|
688
|
+
last_flush_time = time.time()
|
|
689
|
+
completed_indices = set()
|
|
690
|
+
|
|
691
|
+
# 如果指定了输出文件,读取已完成的索引(断点续传)
|
|
692
|
+
if output_jsonl:
|
|
693
|
+
output_path = Path(output_jsonl)
|
|
694
|
+
if output_path.exists():
|
|
695
|
+
records = []
|
|
696
|
+
with open(output_path, "r", encoding="utf-8") as f:
|
|
697
|
+
for line in f:
|
|
698
|
+
try:
|
|
699
|
+
record = json.loads(line.strip())
|
|
700
|
+
if record.get("status") == "success" and "input" in record:
|
|
701
|
+
idx = record.get("index")
|
|
702
|
+
if 0 <= idx < len(messages_list):
|
|
703
|
+
records.append(record)
|
|
704
|
+
except (json.JSONDecodeError, KeyError, TypeError):
|
|
705
|
+
continue
|
|
706
|
+
|
|
707
|
+
# 首尾校验
|
|
708
|
+
file_valid = True
|
|
709
|
+
if records:
|
|
710
|
+
first, last = records[0], records[-1]
|
|
711
|
+
if first["input"] != messages_list[first["index"]]:
|
|
712
|
+
file_valid = False
|
|
713
|
+
elif len(records) > 1 and last["input"] != messages_list[last["index"]]:
|
|
714
|
+
file_valid = False
|
|
715
|
+
|
|
716
|
+
if file_valid:
|
|
717
|
+
completed_indices = {r["index"] for r in records}
|
|
718
|
+
if completed_indices:
|
|
719
|
+
logger.info(f"从文件恢复: 已完成 {len(completed_indices)}/{len(messages_list)}")
|
|
720
|
+
else:
|
|
721
|
+
raise ValueError(
|
|
722
|
+
f"文件校验失败: {output_jsonl} 中的 input 与当前 messages_list 不匹配。"
|
|
723
|
+
f"请删除或重命名该文件后重试。"
|
|
724
|
+
)
|
|
725
|
+
|
|
726
|
+
file_writer = open(output_path, "a", encoding="utf-8")
|
|
727
|
+
|
|
728
|
+
def flush_to_file():
|
|
729
|
+
nonlocal file_buffer, last_flush_time
|
|
730
|
+
if file_writer and file_buffer:
|
|
731
|
+
for record in file_buffer:
|
|
732
|
+
file_writer.write(json.dumps(record, ensure_ascii=False) + "\n")
|
|
733
|
+
file_writer.flush()
|
|
734
|
+
file_buffer = []
|
|
735
|
+
last_flush_time = time.time()
|
|
736
|
+
|
|
737
|
+
def on_file_result(original_idx: int, content: Any, status: str = "success", error: str = None, usage: dict = None):
|
|
738
|
+
nonlocal last_flush_time
|
|
739
|
+
if file_writer is None:
|
|
740
|
+
return
|
|
741
|
+
record = {
|
|
742
|
+
"index": original_idx,
|
|
743
|
+
"output": content,
|
|
744
|
+
"status": status,
|
|
745
|
+
"input": messages_list[original_idx],
|
|
746
|
+
}
|
|
747
|
+
if metadata_list is not None:
|
|
748
|
+
record["metadata"] = metadata_list[original_idx]
|
|
749
|
+
if usage is not None:
|
|
750
|
+
record["usage"] = usage
|
|
751
|
+
if error:
|
|
752
|
+
record["error"] = error
|
|
753
|
+
file_buffer.append(record)
|
|
754
|
+
if time.time() - last_flush_time >= flush_interval:
|
|
755
|
+
flush_to_file()
|
|
756
|
+
|
|
757
|
+
try:
|
|
758
|
+
# 统计信息
|
|
759
|
+
total_count = len(messages_list)
|
|
760
|
+
yielded_count = 0
|
|
761
|
+
success_count = 0
|
|
762
|
+
cached_count = 0
|
|
763
|
+
start_time = time.time()
|
|
764
|
+
total_latency = 0.0
|
|
765
|
+
last_progress = None
|
|
766
|
+
|
|
767
|
+
# 查询缓存
|
|
768
|
+
cached_responses = [None] * len(messages_list)
|
|
769
|
+
uncached_indices = list(range(len(messages_list)))
|
|
770
|
+
|
|
771
|
+
if use_cache and self._response_cache:
|
|
772
|
+
cached_responses, uncached_indices = self._response_cache.get_batch(
|
|
773
|
+
messages_list, model=effective_model, **kwargs
|
|
774
|
+
)
|
|
775
|
+
|
|
776
|
+
# 先 yield 缓存命中的结果
|
|
777
|
+
for i, resp in enumerate(cached_responses):
|
|
778
|
+
if resp is not None:
|
|
779
|
+
if i not in completed_indices:
|
|
780
|
+
on_file_result(i, resp)
|
|
781
|
+
# 缓存命中时创建结果对象
|
|
782
|
+
from types import SimpleNamespace
|
|
783
|
+
|
|
784
|
+
yielded_count += 1
|
|
785
|
+
cached_count += 1
|
|
786
|
+
success_count += 1
|
|
787
|
+
is_last = yielded_count == total_count
|
|
788
|
+
|
|
789
|
+
cached_result = SimpleNamespace(
|
|
790
|
+
content=resp,
|
|
791
|
+
usage=None, # 缓存不包含 usage 信息
|
|
792
|
+
original_idx=i,
|
|
793
|
+
latency=0.0,
|
|
794
|
+
status="cached",
|
|
795
|
+
error=None,
|
|
796
|
+
data=None,
|
|
797
|
+
summary=None,
|
|
798
|
+
)
|
|
799
|
+
if is_last:
|
|
800
|
+
cached_result.summary = {
|
|
801
|
+
"total": total_count,
|
|
802
|
+
"success": success_count,
|
|
803
|
+
"failed": total_count - success_count,
|
|
804
|
+
"cached": cached_count,
|
|
805
|
+
"elapsed": time.time() - start_time,
|
|
806
|
+
"avg_latency": total_latency / max(yielded_count - cached_count, 1),
|
|
807
|
+
}
|
|
808
|
+
yield cached_result
|
|
809
|
+
|
|
810
|
+
# 过滤掉文件中已完成的
|
|
811
|
+
actual_uncached = [i for i in uncached_indices if i not in completed_indices]
|
|
812
|
+
|
|
813
|
+
if actual_uncached:
|
|
814
|
+
logger.info(f"待执行: {len(actual_uncached)}/{len(messages_list)}")
|
|
815
|
+
|
|
816
|
+
uncached_messages = [messages_list[i] for i in actual_uncached]
|
|
817
|
+
request_params = [
|
|
818
|
+
{"json": self._build_request_body(m, effective_model, **kwargs), "headers": headers}
|
|
819
|
+
for m in uncached_messages
|
|
820
|
+
]
|
|
821
|
+
|
|
822
|
+
async for batch in self._client.aiter_stream_requests(
|
|
823
|
+
request_params=request_params,
|
|
824
|
+
url=effective_url,
|
|
825
|
+
method="POST",
|
|
826
|
+
show_progress=show_progress,
|
|
827
|
+
batch_size=batch_size,
|
|
828
|
+
total_requests=len(uncached_messages),
|
|
829
|
+
):
|
|
830
|
+
for result in batch.completed_requests:
|
|
831
|
+
original_idx = actual_uncached[result.request_id]
|
|
832
|
+
yielded_count += 1
|
|
833
|
+
is_last = yielded_count == total_count
|
|
834
|
+
|
|
835
|
+
# 检查请求状态
|
|
836
|
+
if result.status != "success":
|
|
837
|
+
error_msg = result.data.get("error", "Unknown error") if isinstance(result.data, dict) else str(result.data)
|
|
838
|
+
logger.warning(f"请求失败: {error_msg}")
|
|
839
|
+
on_file_result(original_idx, None, "error", error_msg)
|
|
840
|
+
result.content = None
|
|
841
|
+
result.usage = None
|
|
842
|
+
result.original_idx = original_idx
|
|
843
|
+
result.error = error_msg
|
|
844
|
+
else:
|
|
845
|
+
try:
|
|
846
|
+
content = extractor(result)
|
|
847
|
+
usage = self._extract_usage(result.data) if return_usage else None
|
|
848
|
+
# 写入缓存
|
|
849
|
+
if use_cache and self._response_cache and content is not None and not return_raw:
|
|
850
|
+
self._response_cache.set(
|
|
851
|
+
messages_list[original_idx], content, model=effective_model, **kwargs
|
|
852
|
+
)
|
|
853
|
+
on_file_result(original_idx, content, usage=usage)
|
|
854
|
+
# 在 result 对象上添加属性
|
|
855
|
+
result.content = content
|
|
856
|
+
result.usage = usage
|
|
857
|
+
result.original_idx = original_idx
|
|
858
|
+
success_count += 1
|
|
859
|
+
total_latency += result.latency
|
|
860
|
+
except Exception as e:
|
|
861
|
+
logger.warning(f"提取结果失败: {e}")
|
|
862
|
+
on_file_result(original_idx, None, "error", str(e))
|
|
863
|
+
result.content = None
|
|
864
|
+
result.usage = None
|
|
865
|
+
result.original_idx = original_idx
|
|
866
|
+
|
|
867
|
+
# 最后一个 result 添加 summary
|
|
868
|
+
result.summary = None
|
|
869
|
+
if is_last:
|
|
870
|
+
result.summary = {
|
|
871
|
+
"total": total_count,
|
|
872
|
+
"success": success_count,
|
|
873
|
+
"failed": total_count - success_count,
|
|
874
|
+
"cached": cached_count,
|
|
875
|
+
"elapsed": time.time() - start_time,
|
|
876
|
+
"avg_latency": total_latency / max(yielded_count - cached_count, 1),
|
|
877
|
+
}
|
|
878
|
+
yield result
|
|
879
|
+
if batch.is_final:
|
|
880
|
+
last_progress = batch.progress
|
|
881
|
+
|
|
882
|
+
finally:
|
|
883
|
+
flush_to_file()
|
|
884
|
+
if file_writer:
|
|
885
|
+
file_writer.close()
|
|
886
|
+
# 自动 compact:去重,保留每个 index 的最新成功记录
|
|
887
|
+
self._compact_output_file(output_jsonl)
|
|
888
|
+
|
|
889
|
+
async def chat_completions_stream(
    self,
    messages: List[dict],
    model: str = None,
    return_usage: bool = False,
    preprocess_msg: bool = False,
    url: str = None,
    timeout: int = None,
    **kwargs,
):
    """
    Stream a chat completion, yielding content fragments as they arrive.

    The response body is read line by line as Server-Sent Events: each
    ``data:`` line is parsed as JSON and handed to
    ``self._extract_stream_content``. Reading stops at ``data: [DONE]`` or
    when the server closes the stream.

    Args:
        messages: Chat messages.
        model: Model name; resolved through ``self._get_effective_model``.
        return_usage: Selects the shape of the yielded items. When True, each
            item is a dict:
              - {"type": "content", "content": "..."} for a content fragment
              - {"type": "usage", "usage": {...}} for token usage; yielded at
                most once, always as the last item
            When False (default), each item is a ``str`` content fragment.
        preprocess_msg: Whether to preprocess the messages before sending.
        url: Custom request URL; defaults to ``self._get_stream_url(model)``.
        timeout: Total request timeout in seconds; defaults to the client's
            configured timeout (``self._timeout``).
        **kwargs: Extra options forwarded to ``self._build_request_body``.

    Yields:
        - return_usage=False: ``str`` content fragments
        - return_usage=True: dicts tagged by ``"type"`` as described above

    Raises:
        Exception: if the server answers with a non-200 HTTP status.
    """
    import json

    import aiohttp

    effective_model = self._get_effective_model(model)
    messages = await self._preprocess_messages(messages, preprocess_msg)

    body = self._build_request_body(messages, effective_model, stream=True, **kwargs)

    # Ask for token usage in the stream (OpenAI ``stream_options`` format)
    # only when the caller wants it.
    if return_usage:
        body["stream_options"] = {"include_usage": True}

    effective_url = url or self._get_stream_url(effective_model)
    headers = self._get_headers()

    effective_timeout = timeout if timeout is not None else self._timeout
    aio_timeout = aiohttp.ClientTimeout(total=effective_timeout)

    # Depending on the provider, usage may arrive on a chunk that also carries
    # content, or be repeated on several chunks. Remember the latest report and
    # emit it once, after all content, so no content fragment is ever dropped.
    final_usage = None

    async with aiohttp.ClientSession(trust_env=True) as session:
        async with session.post(effective_url, json=body, headers=headers, timeout=aio_timeout) as response:
            if response.status != 200:
                error_text = await response.text()
                raise Exception(f"HTTP {response.status}: {error_text}")

            async for raw_line in response.content:
                line = raw_line.decode("utf-8").strip()
                # In SSE the space after "data:" is optional. Lines without a
                # data field (blank separators, ": comment" keep-alives) carry
                # no payload.
                if not line.startswith("data:"):
                    continue
                data_str = line[len("data:"):].strip()
                if data_str == "[DONE]":
                    break
                try:
                    data = json.loads(data_str)
                except json.JSONDecodeError:
                    continue

                usage = data.get("usage") if return_usage and isinstance(data, dict) else None
                if usage:
                    final_usage = usage
                    try:
                        content = self._extract_stream_content(data)
                    except (KeyError, IndexError, TypeError, AttributeError):
                        # A usage-only chunk (OpenAI sends one whose "choices"
                        # list is empty) may not match the content extractor's
                        # expectations; treat it as carrying no content.
                        content = None
                else:
                    content = self._extract_stream_content(data)

                if content:
                    yield {"type": "content", "content": content} if return_usage else content

    if final_usage is not None:
        yield {"type": "usage", "usage": final_usage}
|
|
963
|
+
|
|
964
|
+
def model_list(self) -> List[str]:
    """Intended to list the model names this client can use.

    The base class provides no implementation. Subclasses are expected to
    override it.

    Raises:
        NotImplementedError: always, in this base implementation.
    """
    message = "子类需要实现 model_list 方法"
    raise NotImplementedError(message)
|
|
966
|
+
|
|
967
|
+
def __repr__(self) -> str:
    """Debug representation naming the concrete class and its configured model."""
    class_name = type(self).__name__
    return f"{class_name}(model='{self._model}')"
|
|
969
|
+
|
|
970
|
+
# ========== 资源管理 ==========
|
|
971
|
+
|
|
972
|
+
def close(self):
    """Close the client and release its resources (e.g. the response-cache connection).

    If a response cache is attached, it is closed and then detached. Calling
    this again afterwards does nothing.
    """
    cache = self._response_cache
    if cache is None:
        return
    cache.close()
    self._response_cache = None
|
|
977
|
+
|
|
978
|
+
def __enter__(self):
    """Enter a synchronous ``with`` block.

    The client itself is bound to the ``as`` target.
    """
    return self
|
|
980
|
+
|
|
981
|
+
def __exit__(self, *args):
    """Leave a synchronous ``with`` block by calling ``self.close()``.

    The exception details passed in ``args`` are ignored. Nothing truthy is
    returned, so any exception raised inside the block propagates.
    """
    self.close()
    return None
|
|
983
|
+
|
|
984
|
+
async def __aenter__(self):
    """Enter an ``async with`` block.

    The client itself is bound to the ``as`` target.
    """
    return self
|
|
986
|
+
|
|
987
|
+
async def __aexit__(self, *args):
    """Leave an ``async with`` block by calling the synchronous ``self.close()``.

    The exception details passed in ``args`` are ignored. Nothing truthy is
    returned, so any exception raised inside the block propagates.
    """
    self.close()
    return None
|