llm-engine-kitty 0.1.0.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llm_engine/__init__.py +54 -0
- llm_engine/engine.py +771 -0
- llm_engine/general_engine.py +562 -0
- llm_engine/kitty/__init__.py +8 -0
- llm_engine/kitty/__main__.py +46 -0
- llm_engine/kitty/client.py +550 -0
- llm_engine/kitty/config.py +83 -0
- llm_engine/kitty/engine.py +1077 -0
- llm_engine/kitty/protocol.py +213 -0
- llm_engine/kitty/schemas.py +89 -0
- llm_engine/kitty/server.py +408 -0
- llm_engine/model_config.py +112 -0
- llm_engine/schemas.py +251 -0
- llm_engine/utils.py +34 -0
- llm_engine_kitty-0.1.0.dev0.dist-info/METADATA +15 -0
- llm_engine_kitty-0.1.0.dev0.dist-info/RECORD +18 -0
- llm_engine_kitty-0.1.0.dev0.dist-info/WHEEL +5 -0
- llm_engine_kitty-0.1.0.dev0.dist-info/top_level.txt +1 -0
llm_engine/engine.py
ADDED
|
@@ -0,0 +1,771 @@
|
|
|
1
|
+
# llm_engine/engine.py
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import httpx
|
|
5
|
+
import json
|
|
6
|
+
import time
|
|
7
|
+
|
|
8
|
+
from abc import ABC, abstractmethod
|
|
9
|
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
|
10
|
+
from tqdm import tqdm
|
|
11
|
+
from typing import Any, Dict, List, Optional
|
|
12
|
+
|
|
13
|
+
import kitty_logger
|
|
14
|
+
|
|
15
|
+
from .model_config import ModelConfig, ModelConfigRegistry
|
|
16
|
+
from .schemas import (
|
|
17
|
+
InferenceParameters,
|
|
18
|
+
InferenceRequest,
|
|
19
|
+
InferenceRequestResult,
|
|
20
|
+
PreparedRequest,
|
|
21
|
+
Message,
|
|
22
|
+
MessageRole,
|
|
23
|
+
ModelOutput,
|
|
24
|
+
ChunkChoice,
|
|
25
|
+
ChunkDelta,
|
|
26
|
+
ChatCompletionChunk,
|
|
27
|
+
ChatCompletionChoice,
|
|
28
|
+
ChatCompletionResponse,
|
|
29
|
+
)
|
|
30
|
+
from .utils import gen_unique_id
|
|
31
|
+
|
|
32
|
+
# Module-level logger obtained from the project's ``kitty_logger`` package; the
# engine classes in this module report warnings and errors through it.
logger = kitty_logger.getLogger(__name__)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class BaseEngine(ABC):
    """Common root type for every engine flavour defined in this module.

    It contributes no behaviour of its own. ``__slots__`` is empty so that
    subclasses declaring their own slots do not pick up a per-instance
    ``__dict__`` from this base.
    """

    __slots__ = ()
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class SyncEngine(BaseEngine):
    """Blocking engine interface.

    Subclasses must implement :meth:`inference`. Everything else is built on
    top of it: :meth:`batch_inference` falls back to running requests one by
    one, and :meth:`infer` / :meth:`batch_infer` wrap plain-text queries into
    single user-message requests and return only the reply text.
    """

    __slots__ = ()

    @abstractmethod
    def inference(self, request: InferenceRequest) -> InferenceRequestResult:
        """Run a single inference request and return its result."""

    def batch_inference(self, requests: List[InferenceRequest], max_concurrency: int = 1) -> List[InferenceRequestResult]:
        """Serial fallback for engines without a concurrent batch implementation.

        ``max_concurrency`` is accepted for interface compatibility but ignored;
        a warning is logged so the caller knows the requests run one after
        another. Results are returned in the same order as ``requests``.
        """
        class_name = self.__class__.__name__

        # fmt: off
        logger.warning(
            f"引擎 '{class_name}' 未实现高效的 'batch_inference' 方法,将回退至串行模式。"
            f"指定的并发数 (max_concurrency={max_concurrency}) 不会生效,"
            f"正在串行处理 {len(requests)} 个请求,这可能会影响执行效率。"
        )
        # fmt: on

        return [self.inference(req) for req in requests]

    def infer(self, query: str) -> str:
        """Send ``query`` as a single user message and return the reply text.

        Returns ``""`` when the result is unsuccessful, has no model output, or
        the output carries no content. Exceptions raised by :meth:`inference`
        propagate to the caller.
        """
        inference_result: InferenceRequestResult = self.inference(InferenceRequest(messages=[Message(role=MessageRole.USER, content=query)]))

        if inference_result.success is True and inference_result.model_output is not None:
            # Non-streaming replies may carry content=None; keep the ``str`` contract.
            return inference_result.model_output.content or ""
        return ""

    def batch_infer(self, queries: List[str], max_concurrency: int = 8) -> List[str]:
        """Text-only batch counterpart of :meth:`infer`.

        Each query becomes a single user-message request; the returned list is
        aligned with ``queries`` and holds ``""`` for failed or empty replies.
        """
        inference_requests: List[InferenceRequest] = [InferenceRequest(messages=[Message(role=MessageRole.USER, content=query)]) for query in queries]

        inference_results: List[InferenceRequestResult] = self.batch_inference(inference_requests, max_concurrency=max_concurrency)

        return [(result.model_output.content or "") if result.success is True and result.model_output is not None else "" for result in inference_results]
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
class AsyncEngine(BaseEngine):
    """Placeholder engine flavour; it currently adds nothing to ``BaseEngine``.

    ``__slots__`` stays empty so subclasses with their own slots remain
    ``__dict__``-free.
    """

    __slots__ = ()
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
class CoroutineEngine(BaseEngine):
    """Asyncio engine interface.

    Subclasses must implement the coroutine :meth:`inference`. The other
    coroutines are built on top of it: :meth:`batch_inference` falls back to
    awaiting requests one by one, and :meth:`infer` / :meth:`batch_infer` wrap
    plain-text queries into single user-message requests.
    """

    __slots__ = ()

    @abstractmethod
    async def inference(self, request: InferenceRequest) -> InferenceRequestResult:
        """Run a single inference request and return its result."""

    async def batch_inference(self, requests: List[InferenceRequest], max_concurrency: int = 1) -> List[InferenceRequestResult]:
        """Serial fallback for engines without a concurrent batch implementation.

        ``max_concurrency`` is accepted for interface compatibility but ignored;
        a warning is logged and the requests are awaited one after another.
        Results are returned in the same order as ``requests``.
        """
        class_name = self.__class__.__name__

        # fmt: off
        logger.warning(
            f"引擎 '{class_name}' 未实现高效的 'batch_inference' 方法,将回退至串行模式。"
            f"指定的并发数 (max_concurrency={max_concurrency}) 不会生效,"
            f"正在串行处理 {len(requests)} 个请求,这可能会影响执行效率。"
        )
        # fmt: on

        return [await self.inference(req) for req in requests]

    async def infer(self, query: str) -> str:
        """Send ``query`` as a single user message and return the reply text.

        Any ``Exception`` raised by :meth:`inference` is logged and turned into
        ``""``; an unsuccessful result, a missing model output or empty content
        also yield ``""``.
        """
        try:
            inference_result: InferenceRequestResult = await self.inference(InferenceRequest(messages=[Message(role=MessageRole.USER, content=query)]))
        except Exception as e:
            logger.error(f"infer 抛出未捕获异常: {e}")
            return ""

        if inference_result.success is True and inference_result.model_output is not None:
            # Non-streaming replies may carry content=None; keep the ``str`` contract.
            return inference_result.model_output.content or ""
        return ""

    async def batch_infer(self, queries: List[str], max_concurrency: int = 8) -> List[str]:
        """Text-only batch counterpart of :meth:`infer`.

        The returned list is aligned with ``queries`` and holds ``""`` for
        failed or empty replies.
        """
        inference_requests: List[InferenceRequest] = [InferenceRequest(messages=[Message(role=MessageRole.USER, content=query)]) for query in queries]

        inference_results: List[InferenceRequestResult] = await self.batch_inference(inference_requests, max_concurrency=max_concurrency)

        return [(result.model_output.content or "") if result.success is True and result.model_output is not None else "" for result in inference_results]
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
class SimpleEngine(SyncEngine):
    """Blocking client for OpenAI-compatible chat-completion endpoints.

    Headers and payload are assembled in layers: engine-wide defaults, then the
    ``ModelConfig`` resolved from ``model_registry``, then the individual
    ``InferenceRequest``. Each later layer overrides the earlier ones. Requests
    are sent with ``httpx`` and retried on non-200 statuses, empty ``choices``
    lists and transient transport errors.
    """

    __slots__ = (
        "model_registry",
        "default_model_name",
        "default_model",
        "default_api_key",
        "default_inference_parameters",
        "extra_headers",
        "extra_payload",
        "stream",
    )

    def __init__(
        self,
        model_registry: ModelConfigRegistry,
        default_model: str,
        default_api_key: Optional[str] = None,
        default_inference_parameters: Optional[InferenceParameters] = None,
        extra_headers: Optional[Dict[str, str]] = None,
        extra_payload: Optional[Dict[str, Any]] = None,
        stream: bool = False,
    ) -> None:
        """Store engine-wide defaults.

        Args:
            model_registry: Registry used to resolve model names to ``ModelConfig``.
            default_model: Model name used when a request names none. It is
                resolved immediately via ``model_registry.get``.
            default_api_key: Lowest-priority API-key fallback.
            default_inference_parameters: Engine-wide sampling parameters.
            extra_headers: Engine-wide extra HTTP headers (stored as given, not copied).
            extra_payload: Engine-wide extra JSON payload fields (stored as given).
            stream: Value of ``payload["stream"]`` when a request leaves it unset.
        """
        self.model_registry: ModelConfigRegistry = model_registry
        self.default_model_name: str = default_model
        self.default_model: ModelConfig = self.model_registry.get(default_model)
        self.default_api_key: Optional[str] = default_api_key
        self.default_inference_parameters: Optional[InferenceParameters] = default_inference_parameters
        self.extra_headers: Dict[str, str] = extra_headers if extra_headers is not None else {}
        self.extra_payload: Dict[str, Any] = extra_payload if extra_payload is not None else {}
        self.stream: bool = stream

    def get_model_api_key(self, inference_request: InferenceRequest, model: ModelConfig) -> str:
        """Pick the API key: request key, then model key, then engine default.

        Returns ``""``, with a warning, when none of them is set.
        """
        if inference_request.api_key is not None:
            return inference_request.api_key
        if model.api_key is not None:
            return model.api_key
        if self.default_api_key is not None:
            return self.default_api_key
        logger.warning("没有找到可用的 api_key, 将传递空 api_key")
        return ""

    def build_request_headers(self, inference_request: InferenceRequest, model: ModelConfig) -> Dict[str, str]:
        """Build HTTP headers: bearer auth + JSON content type, then engine,
        model and request ``extra_headers`` (later ones win)."""
        headers: Dict[str, str] = {
            "Authorization": "Bearer " + self.get_model_api_key(inference_request, model),
            "Content-Type": "application/json",
        }

        headers.update(self.extra_headers)
        headers.update(model.extra_headers)
        headers.update(inference_request.extra_headers)

        return headers

    def build_request_payload(self, inference_request: InferenceRequest, model: ModelConfig) -> Dict[str, Any]:
        """Build the JSON body.

        Layering order: base fields (model id, stream flag, messages), then
        inference parameters (engine, model, request), then ``extra_payload``
        (engine, model, request). Later layers override earlier ones.
        """
        payload: Dict[str, Any] = {
            "model": model.model_id,
            "stream": inference_request.stream if inference_request.stream is not None else self.stream,
            "messages": [msg.to_dict() for msg in inference_request.messages],
        }

        if self.default_inference_parameters:
            payload.update(self.default_inference_parameters.to_dict())
        if model.default_inference_parameters:
            payload.update(model.default_inference_parameters.to_dict())
        if inference_request.inference_parameters:
            payload.update(inference_request.inference_parameters.to_dict())

        payload.update(self.extra_payload)
        payload.update(model.extra_payload)
        payload.update(inference_request.extra_payload)

        return payload

    def build_request(self, request: InferenceRequest) -> PreparedRequest:
        """Turn a business-level request into concrete HTTP request parameters."""
        model_name = request.model_name or self.default_model_name
        model = self.model_registry.get(model_name)

        # Build headers and payload first: ``get_url()`` is paired with
        # ``release_url()`` by callers, so acquiring it last means a failure in
        # header/payload construction cannot leave a URL un-released.
        headers = self.build_request_headers(request, model)
        payload = self.build_request_payload(request, model)

        return PreparedRequest(
            model_name=model_name,
            url=model.get_url(),
            headers=headers,
            payload=payload,
        )

    def send_request(
        self,
        url: str,
        headers: Dict[str, str],
        payload: Dict[str, Any],
        stream: bool,
        max_retries: int = 5,
        base_delay: int = 2,
        max_delay: float = 60.0,
    ) -> ModelOutput:
        """POST ``payload`` to ``url`` and parse the chat-completion reply.

        Args:
            url, headers, payload: The prepared HTTP request.
            stream: Parse the body as an SSE chunk stream when true, otherwise
                as a single JSON ``ChatCompletionResponse``.
            max_retries: Total number of attempts.
            base_delay: Base of the exponential backoff for transport errors
                (also passed to ``_get_wait_time`` for non-200 statuses).
            max_delay: Upper bound in seconds on the transport-error backoff.

        Raises:
            httpx.HTTPStatusError: The last attempt still got a non-200 status.
            httpx.NetworkError / httpx.TimeoutException / httpx.RemoteProtocolError:
                Transport failure on the last attempt.
            RuntimeError: Every attempt returned an empty ``choices`` list, or
                ``max_retries`` is not positive.
        """
        # One client for all attempts so pooled connections can be reused.
        with httpx.Client(timeout=httpx.Timeout(timeout=10.0, read=3600.0)) as http_client:
            for attempt in range(max_retries):
                is_last_attempt: bool = attempt == (max_retries - 1)
                try:
                    with http_client.stream("POST", url, json=payload, headers=headers) as http_response:

                        if http_response.status_code != 200:
                            if is_last_attempt:
                                response_body = http_response.read().decode("utf-8", errors="replace")
                                logger.error(f"HTTP {http_response.status_code},已达最大重试次数,响应内容: {response_body[:500]}")
                                raise httpx.HTTPStatusError(f"HTTP {http_response.status_code}", request=http_response.request, response=http_response)

                            wait_time = self._get_wait_time(http_response.status_code, attempt, base_delay)
                            response_body = http_response.read().decode("utf-8", errors="replace")
                            logger.warning(f"HTTP {http_response.status_code}, 第 {attempt+1} 次重试,等待 {wait_time}s... 响应内容: {response_body[:500]}")
                            time.sleep(wait_time)
                            continue

                        if stream:
                            return self._collect_stream(http_response)

                        full_body = http_response.read().decode("utf-8")
                        resp: ChatCompletionResponse = ChatCompletionResponse.model_validate_json(full_body)

                        # An empty choices list would otherwise crash with IndexError;
                        # retry it the same way SimpleCoroutineEngine does.
                        if not resp.choices:
                            if is_last_attempt:
                                logger.error(f"resp.choices 为空,已达最大重试次数,疑似被风控拦截。响应内容: {full_body[:500]}")
                                raise RuntimeError(f"empty choices after {max_retries} attempts, body: {full_body[:500]}")
                            wait_time = base_delay
                            logger.warning(f"resp.choices 为空,疑似被风控拦截,第 {attempt+1} 次重试,等待 {wait_time}s... 响应内容: {full_body[:500]}")
                            time.sleep(wait_time)
                            continue

                        choice: ChatCompletionChoice = resp.choices[0]
                        usage = resp.usage.model_dump() if resp.usage else None
                        if choice.finish_reason and choice.finish_reason != "stop":
                            logger.warning(f"finish_reason='{choice.finish_reason}' (非 stop),content_len={len(choice.message.content or '')}, usage={usage}")
                        return ModelOutput(
                            role=choice.message.role,
                            content=choice.message.content,
                            reasoning=choice.message.reasoning_content,
                            finish_reason=choice.finish_reason,
                            usage=usage,
                        )

                # RemoteProtocolError (e.g. server dropping the connection mid-stream)
                # is a ProtocolError, not a NetworkError, so it must be listed explicitly.
                except (httpx.NetworkError, httpx.TimeoutException, httpx.RemoteProtocolError, httpx.HTTPStatusError) as e:
                    if is_last_attempt:
                        logger.error(f"达到最大重试次数,最后一次错误: {e}")
                        raise
                    if not isinstance(e, httpx.HTTPStatusError):
                        wait_time = min(base_delay**attempt, max_delay)
                        logger.warning(f"网络异常: {e}, 等待 {wait_time}s 后重试...")
                        time.sleep(wait_time)

        raise RuntimeError("Unexpected end of retry loop")

    def _collect_stream(self, http_response: httpx.Response) -> ModelOutput:
        """Fold an SSE stream of ``ChatCompletionChunk`` packets into one ``ModelOutput``.

        Blank lines are skipped. Lines without a ``data: `` prefix, or whose JSON
        fails to parse, are logged and skipped. Reading stops at ``[DONE]``.
        Content and reasoning deltas are concatenated; for ``role``,
        ``finish_reason`` and ``usage`` the last value received wins.
        """
        reasoning_parts: List[str] = []
        content_parts: List[str] = []
        role: Optional[str] = None
        usage: Optional[Dict] = None
        finish_reason: Optional[str] = None

        for line in http_response.iter_lines():
            if not line.strip():
                continue

            logger.debug(f"line: '{line}'")

            if not line.startswith("data: "):
                logger.error(f"数据行没有以'data: '开头,将跳过。line: {line}")
                continue

            raw_data = line.removeprefix("data: ").strip()
            if raw_data == "[DONE]":
                logger.debug("收到结束标志'[DONE]', 不再解析后续包")
                break

            try:
                chunk: ChatCompletionChunk = ChatCompletionChunk.model_validate_json(raw_data)
            except Exception as e:
                logger.error(f"解析数据包失败: {e}, 原始数据: {raw_data}")
                continue

            if chunk.usage:
                if usage is not None:
                    logger.warning("收到多个含有usage的包,后收到的usage将会覆盖先前的token用量信息。")
                usage = chunk.usage.model_dump()

            if not chunk.choices:
                continue

            chunk_choice: ChunkChoice = chunk.choices[0]
            chunk_delta: ChunkDelta = chunk_choice.delta

            if chunk_delta.content:
                content_parts.append(chunk_delta.content)
            if chunk_delta.reasoning_content:
                reasoning_parts.append(chunk_delta.reasoning_content)

            if chunk_delta.role:
                if role is not None:
                    logger.warning("收到多个含有role的包,后收到的role将会覆盖先前的role信息。")
                role = chunk_delta.role

            if chunk_choice.finish_reason:
                if finish_reason is not None:
                    logger.warning("收到多个含有finish_reason的包,后收到的finish_reason将会覆盖先前的finish_reason信息。")
                finish_reason = chunk_choice.finish_reason

        model_response = "".join(content_parts)
        if finish_reason and finish_reason != "stop":
            logger.warning(f"finish_reason='{finish_reason}' (非 stop),content_len={len(model_response)}, usage={usage}")
        return ModelOutput(
            role=role,
            content=model_response,
            reasoning="".join(reasoning_parts),
            finish_reason=finish_reason,
            usage=usage,
        )

    def _get_wait_time(self, status_code: int, attempt: int, base_delay: int) -> float:
        """Return the pause, in seconds, before retrying a non-200 response.

        429 backs off linearly (10 s per attempt), 5xx exponentially
        (``base_delay ** attempt``), and any other status waits 5 s.
        """
        if status_code == 429:
            # Rate limiting usually needs a longer pause.
            return 10 * (attempt + 1)
        if status_code >= 500:
            return float(base_delay**attempt)
        return 5.0

    def inference(self, request: InferenceRequest) -> InferenceRequestResult:
        """Send one request and wrap the reply in a successful ``InferenceRequestResult``.

        Exceptions from ``send_request`` propagate; the URL acquired by
        ``build_request`` is released in every case.
        """
        prepared_request: PreparedRequest = self.build_request(request)
        model = self.model_registry.get(request.model_name or self.default_model_name)

        # The URL is chosen once in build_request and reused by every retry.
        # To re-pick a URL per retry (multi-node setups), selection would have to
        # move inside send_request's attempt loop.
        start_time: float = time.monotonic()
        try:
            model_output: ModelOutput = self.send_request(
                url=prepared_request.url,
                headers=prepared_request.headers,
                payload=prepared_request.payload,
                stream=prepared_request.payload["stream"],
            )
        finally:
            model.release_url(prepared_request.url)

        return InferenceRequestResult(
            success=True,
            task_id="",
            request=request,
            model_output=model_output,
            duration=time.monotonic() - start_time,
        )

    def batch_inference(
        self,
        requests: List[InferenceRequest],
        max_concurrency: int = 8,
        output_file: Optional[str] = None,
        silent_mode: bool = False,
    ) -> List[InferenceRequestResult]:
        """Run ``requests`` on a thread pool and return results in input order.

        Args:
            requests: Requests to run.
            max_concurrency: Worker-thread count.
            output_file: If given, each result is appended as one JSON line, in
                completion order, as soon as it is available.
            silent_mode: Disable the tqdm progress bar.

        A request that raises is converted into an error result via
        ``_get_error_result`` instead of aborting the batch.
        """
        result_dict: Dict[int, InferenceRequestResult] = {}

        # Open the output file before scheduling any work so a bad path fails
        # immediately instead of after every request has been submitted.
        f = open(output_file, "w", encoding="utf-8") if output_file else None

        try:
            with ThreadPoolExecutor(max_workers=max_concurrency) as executor:
                future_to_idx = {executor.submit(self.inference, req): i for i, req in enumerate(requests)}

                futures_iterator = as_completed(future_to_idx)
                if not silent_mode:
                    futures_iterator = tqdm(
                        futures_iterator,
                        total=len(requests),
                        desc=f"Batch: {self.default_model_name}",
                    )

                for future in futures_iterator:
                    idx = future_to_idx[future]
                    try:
                        result: InferenceRequestResult = future.result()
                        result_dict[idx] = result

                        if f is not None:
                            f.write(result.model_dump_json() + "\n")
                            f.flush()

                    except Exception as e:
                        logger.error(f"Request #{idx} 抛出未捕获异常: {e}")
                        error_res = self._get_error_result(requests[idx], e)
                        result_dict[idx] = error_res
                        if f is not None:
                            f.write(error_res.model_dump_json() + "\n")
                            f.flush()
        finally:
            if f is not None:
                f.close()

        return [result_dict[i] for i in range(len(requests))]

    def _get_error_result(self, request: InferenceRequest, error: Exception) -> InferenceRequestResult:
        """Build the standardized result for a request that failed outright."""
        return InferenceRequestResult(
            success=False,
            task_id="error",
            request=request,
            error_message=f"{type(error).__name__}: {str(error)}",
        )

    def mock_request(self, query: Optional[str] = None, request: Optional[InferenceRequest] = None) -> PreparedRequest:
        """Build, but do not send, a request and return its URL, headers and payload.

        Useful for checking config assembly, API-key resolution and parameter
        overrides. With neither argument, a unique placeholder query is used.
        """
        if request is None:
            if query is None:
                query = f"请求文本: [{gen_unique_id()}]"
            request = InferenceRequest(messages=[Message(role=MessageRole.USER, content=query)])

        prepared_request: PreparedRequest = self.build_request(request)

        # build_request acquired a URL via get_url(); inference() always pairs
        # that with release_url(), so do the same here since nothing is sent.
        self.model_registry.get(request.model_name or self.default_model_name).release_url(prepared_request.url)

        return prepared_request
|
|
450
|
+
|
|
451
|
+
|
|
452
|
+
class SimpleCoroutineEngine(CoroutineEngine):
    """Asyncio client for OpenAI-compatible chat-completion endpoints.

    Headers and payload are layered exactly like ``SimpleEngine``: engine
    defaults, then the ``ModelConfig`` from ``model_registry``, then the
    request, with later layers overriding earlier ones. Requests are sent with
    ``httpx.AsyncClient`` and retried (up to 256 attempts by default) on non-200
    statuses, empty ``choices`` lists and transient transport errors.
    """

    __slots__ = (
        "model_registry",
        "default_model_name",
        "default_model",
        "default_api_key",
        "default_inference_parameters",
        "extra_headers",
        "extra_payload",
        "stream",
    )

    def __init__(
        self,
        model_registry: ModelConfigRegistry,
        default_model: str,
        default_api_key: Optional[str] = None,
        default_inference_parameters: Optional[InferenceParameters] = None,
        extra_headers: Optional[Dict[str, str]] = None,
        extra_payload: Optional[Dict[str, Any]] = None,
        stream: bool = False,
    ) -> None:
        """Store engine-wide defaults.

        Args:
            model_registry: Registry used to resolve model names to ``ModelConfig``.
            default_model: Model name used when a request names none. It is
                resolved immediately via ``model_registry.get``.
            default_api_key: Lowest-priority API-key fallback.
            default_inference_parameters: Engine-wide sampling parameters.
            extra_headers: Engine-wide extra HTTP headers (stored as given, not copied).
            extra_payload: Engine-wide extra JSON payload fields (stored as given).
            stream: Value of ``payload["stream"]`` when a request leaves it unset.
        """
        self.model_registry: ModelConfigRegistry = model_registry
        self.default_model_name: str = default_model
        self.default_model: ModelConfig = self.model_registry.get(default_model)
        self.default_api_key: Optional[str] = default_api_key
        self.default_inference_parameters: Optional[InferenceParameters] = default_inference_parameters
        self.extra_headers: Dict[str, str] = extra_headers if extra_headers is not None else {}
        self.extra_payload: Dict[str, Any] = extra_payload if extra_payload is not None else {}
        self.stream: bool = stream

    def get_model_api_key(self, inference_request: InferenceRequest, model: ModelConfig) -> str:
        """Pick the API key: request key, then model key, then engine default.

        Returns ``""``, with a warning, when none of them is set.
        """
        if inference_request.api_key is not None:
            return inference_request.api_key
        if model.api_key is not None:
            return model.api_key
        if self.default_api_key is not None:
            return self.default_api_key
        logger.warning("没有找到可用的 api_key, 将传递空 api_key")
        return ""

    def build_request_headers(self, inference_request: InferenceRequest, model: ModelConfig) -> Dict[str, str]:
        """Build HTTP headers: bearer auth + JSON content type, then engine,
        model and request ``extra_headers`` (later ones win)."""
        headers: Dict[str, str] = {
            "Authorization": "Bearer " + self.get_model_api_key(inference_request, model),
            "Content-Type": "application/json",
        }

        headers.update(self.extra_headers)
        headers.update(model.extra_headers)
        headers.update(inference_request.extra_headers)

        return headers

    def build_request_payload(self, inference_request: InferenceRequest, model: ModelConfig) -> Dict[str, Any]:
        """Build the JSON body.

        Layering order: base fields (model id, stream flag, messages), then
        inference parameters (engine, model, request), then ``extra_payload``
        (engine, model, request). Later layers override earlier ones.
        """
        payload: Dict[str, Any] = {
            "model": model.model_id,
            "stream": inference_request.stream if inference_request.stream is not None else self.stream,
            "messages": [msg.to_dict() for msg in inference_request.messages],
        }

        if self.default_inference_parameters:
            payload.update(self.default_inference_parameters.to_dict())
        if model.default_inference_parameters:
            payload.update(model.default_inference_parameters.to_dict())
        if inference_request.inference_parameters:
            payload.update(inference_request.inference_parameters.to_dict())

        payload.update(self.extra_payload)
        payload.update(model.extra_payload)
        payload.update(inference_request.extra_payload)

        return payload

    def build_request(self, request: InferenceRequest) -> PreparedRequest:
        """Turn a business-level request into concrete HTTP request parameters."""
        model_name = request.model_name or self.default_model_name
        model = self.model_registry.get(model_name)

        # Build headers and payload first: ``get_url()`` is paired with
        # ``release_url()`` by callers, so acquiring it last means a failure in
        # header/payload construction cannot leave a URL un-released.
        headers = self.build_request_headers(request, model)
        payload = self.build_request_payload(request, model)

        return PreparedRequest(
            model_name=model_name,
            url=model.get_url(),
            headers=headers,
            payload=payload,
        )

    def _get_wait_time(self, status_code: int, attempt: int, base_delay: int) -> float:
        """Return the pause, in seconds, before retrying a non-200 response.

        This engine uses a flat 1-second delay for every status (429, 5xx and
        everything else). ``status_code``, ``attempt`` and ``base_delay`` are
        kept so the signature matches ``SimpleEngine._get_wait_time``.
        """
        return 1.0

    def _get_error_result(self, request: InferenceRequest, error: Exception) -> InferenceRequestResult:
        """Build the standardized result for a request that failed outright."""
        return InferenceRequestResult(
            success=False,
            task_id="error",
            request=request,
            error_message=f"{type(error).__name__}: {str(error)}",
        )

    def mock_request(self, query: Optional[str] = None, request: Optional[InferenceRequest] = None) -> PreparedRequest:
        """Build, but do not send, a request and return its URL, headers and payload.

        Useful for checking config assembly, API-key resolution and parameter
        overrides. With neither argument, a unique placeholder query is used.
        """
        if request is None:
            if query is None:
                query = f"请求文本: [{gen_unique_id()}]"
            request = InferenceRequest(messages=[Message(role=MessageRole.USER, content=query)])

        prepared_request: PreparedRequest = self.build_request(request)

        # build_request acquired a URL via get_url(); inference() always pairs
        # that with release_url(), so do the same here since nothing is sent.
        self.model_registry.get(request.model_name or self.default_model_name).release_url(prepared_request.url)

        return prepared_request

    async def send_request(
        self,
        url: str,
        headers: Dict[str, str],
        payload: Dict[str, Any],
        stream: bool,
        max_retries: int = 256,
        base_delay: int = 2,
        max_delay: float = 60.0,
    ) -> ModelOutput:
        """POST ``payload`` to ``url`` and parse the chat-completion reply.

        Args:
            url, headers, payload: The prepared HTTP request.
            stream: Parse the body as an SSE chunk stream when true, otherwise
                as a single JSON ``ChatCompletionResponse``.
            max_retries: Total number of attempts.
            base_delay: Wait for empty-``choices`` retries, and base of the
                exponential backoff for transport errors.
            max_delay: Upper bound in seconds on the transport-error backoff.
                Without it, ``base_delay ** attempt`` grows without limit over
                256 attempts.

        Raises:
            httpx.HTTPStatusError: The last attempt still got a non-200 status.
            httpx.NetworkError / httpx.TimeoutException / httpx.RemoteProtocolError:
                Transport failure on the last attempt.
            RuntimeError: Every attempt returned an empty ``choices`` list, or
                ``max_retries`` is not positive.
        """
        async with httpx.AsyncClient(timeout=httpx.Timeout(timeout=10.0, read=1800.0)) as http_client:
            for attempt in range(max_retries):
                is_last_attempt: bool = attempt == (max_retries - 1)
                try:
                    async with http_client.stream("POST", url, json=payload, headers=headers) as http_response:

                        if http_response.status_code != 200:
                            if is_last_attempt:
                                response_body = (await http_response.aread()).decode("utf-8", errors="replace")
                                logger.error(f"HTTP {http_response.status_code},已达最大重试次数,响应内容: {response_body[:500]}")
                                raise httpx.HTTPStatusError(f"HTTP {http_response.status_code}", request=http_response.request, response=http_response)

                            wait_time = self._get_wait_time(http_response.status_code, attempt, base_delay)
                            response_body = (await http_response.aread()).decode("utf-8", errors="replace")
                            logger.warning(f"HTTP {http_response.status_code}, 第 {attempt+1} 次重试,等待 {wait_time}s... 响应内容: {response_body[:500]}")
                            await asyncio.sleep(wait_time)
                            continue

                        if stream:
                            return await self._collect_stream(http_response)

                        full_body = (await http_response.aread()).decode("utf-8")
                        resp: ChatCompletionResponse = ChatCompletionResponse.model_validate_json(full_body)

                        # An empty choices list (e.g. content filtering) is retried.
                        if not resp.choices:
                            if is_last_attempt:
                                logger.error(f"resp.choices 为空,已达最大重试次数,疑似被风控拦截。响应内容: {full_body[:500]}")
                                raise RuntimeError(f"empty choices after {max_retries} attempts, body: {full_body[:500]}")
                            wait_time = base_delay
                            logger.warning(f"resp.choices 为空,疑似被风控拦截,第 {attempt+1} 次重试,等待 {wait_time}s... 响应内容: {full_body[:500]}")
                            await asyncio.sleep(wait_time)
                            continue

                        choice: ChatCompletionChoice = resp.choices[0]
                        usage = resp.usage.model_dump() if resp.usage else None
                        if choice.finish_reason and choice.finish_reason != "stop":
                            logger.warning(f"finish_reason='{choice.finish_reason}' (非 stop),content_len={len(choice.message.content or '')}, usage={usage}")
                        return ModelOutput(
                            role=choice.message.role,
                            content=choice.message.content,
                            reasoning=choice.message.reasoning_content,
                            finish_reason=choice.finish_reason,
                            usage=usage,
                        )

                # RemoteProtocolError (e.g. server dropping the connection mid-stream)
                # is a ProtocolError, not a NetworkError, so it must be listed explicitly.
                except (httpx.NetworkError, httpx.TimeoutException, httpx.RemoteProtocolError, httpx.HTTPStatusError) as e:
                    if is_last_attempt:
                        logger.error(f"达到最大重试次数,最后一次错误: {e}")
                        raise
                    if not isinstance(e, httpx.HTTPStatusError):
                        # Cap the exponential backoff; with 256 attempts it would otherwise explode.
                        wait_time = min(base_delay**attempt, max_delay)
                        logger.warning(f"网络异常: {type(e).__name__}: {e!r}, 等待 {wait_time}s 后重试...")
                        await asyncio.sleep(wait_time)

        raise RuntimeError("Unexpected end of retry loop")

    async def _collect_stream(self, http_response: httpx.Response) -> ModelOutput:
        """Fold an SSE stream of ``ChatCompletionChunk`` packets into one ``ModelOutput``.

        Blank lines are skipped. Lines without a ``data: `` prefix, or whose JSON
        fails to parse, are logged and skipped. Reading stops at ``[DONE]``.
        Content and reasoning deltas are concatenated; for ``role``,
        ``finish_reason`` and ``usage`` the last value received wins.
        """
        reasoning_parts: List[str] = []
        content_parts: List[str] = []
        role: Optional[str] = None
        usage: Optional[Dict] = None
        finish_reason: Optional[str] = None

        async for line in http_response.aiter_lines():
            if not line.strip():
                continue

            logger.debug(f"line: '{line}'")

            if not line.startswith("data: "):
                logger.error(f"数据行没有以'data: '开头,将跳过。line: {line}")
                continue

            raw_data = line.removeprefix("data: ").strip()
            if raw_data == "[DONE]":
                logger.debug("收到结束标志'[DONE]', 不再解析后续包")
                break

            try:
                chunk: ChatCompletionChunk = ChatCompletionChunk.model_validate_json(raw_data)
            except Exception as e:
                logger.error(f"解析数据包失败: {e}, 原始数据: {raw_data}")
                continue

            if chunk.usage:
                if usage is not None:
                    logger.warning("收到多个含有usage的包,后收到的usage将会覆盖先前的token用量信息。")
                usage = chunk.usage.model_dump()

            if not chunk.choices:
                continue

            chunk_choice: ChunkChoice = chunk.choices[0]
            chunk_delta: ChunkDelta = chunk_choice.delta

            if chunk_delta.content:
                content_parts.append(chunk_delta.content)
            if chunk_delta.reasoning_content:
                reasoning_parts.append(chunk_delta.reasoning_content)

            if chunk_delta.role:
                if role is not None:
                    logger.warning("收到多个含有role的包,后收到的role将会覆盖先前的role信息。")
                role = chunk_delta.role

            if chunk_choice.finish_reason:
                if finish_reason is not None:
                    logger.warning("收到多个含有finish_reason的包,后收到的finish_reason将会覆盖先前的finish_reason信息。")
                finish_reason = chunk_choice.finish_reason

        model_response = "".join(content_parts)
        if finish_reason and finish_reason != "stop":
            logger.warning(f"finish_reason='{finish_reason}' (非 stop),content_len={len(model_response)}, usage={usage}")
        return ModelOutput(
            role=role,
            content=model_response,
            reasoning="".join(reasoning_parts),
            finish_reason=finish_reason,
            usage=usage,
        )

    async def inference(self, request: InferenceRequest) -> InferenceRequestResult:
        """Send one request and wrap the reply in a successful ``InferenceRequestResult``.

        Exceptions from ``send_request`` propagate; the URL acquired by
        ``build_request`` is released in every case.
        """
        prepared_request: PreparedRequest = self.build_request(request)
        model = self.model_registry.get(request.model_name or self.default_model_name)

        # The URL is chosen once in build_request and reused by every retry.
        # To re-pick a URL per retry (multi-node setups), selection would have to
        # move inside send_request's attempt loop.
        start_time: float = time.monotonic()
        try:
            model_output: ModelOutput = await self.send_request(
                url=prepared_request.url,
                headers=prepared_request.headers,
                payload=prepared_request.payload,
                stream=prepared_request.payload["stream"],
            )
        finally:
            model.release_url(prepared_request.url)

        return InferenceRequestResult(
            success=True,
            task_id="",
            request=request,
            model_output=model_output,
            duration=time.monotonic() - start_time,
        )

    async def batch_inference(
        self,
        requests: List[InferenceRequest],
        max_concurrency: int = 8,
        output_file: Optional[str] = None,
        silent_mode: bool = False,
    ) -> List[InferenceRequestResult]:
        """Run ``requests`` concurrently (bounded by a semaphore) and return results in input order.

        Args:
            requests: Requests to run.
            max_concurrency: Maximum number of requests in flight at once.
            output_file: If given, each result is appended as one JSON line, in
                completion order, as soon as it is available.
            silent_mode: Disable the tqdm progress bar.

        A request that raises ``Exception`` is converted into an error result via
        ``_get_error_result`` instead of aborting the batch. If this coroutine
        itself exits early (e.g. it is cancelled), unfinished tasks are cancelled.
        """
        semaphore = asyncio.Semaphore(max_concurrency)
        result_dict: Dict[int, InferenceRequestResult] = {}

        async def run_one(idx: int, req: InferenceRequest):
            async with semaphore:
                try:
                    result = await self.inference(req)
                except Exception as e:
                    logger.error(f"Request #{idx} 抛出未捕获异常: {e}")
                    result = self._get_error_result(req, e)
                result_dict[idx] = result
                return idx, result

        # Open the output file before spawning tasks so a bad path fails
        # immediately instead of leaving tasks running in the background.
        f = open(output_file, "w", encoding="utf-8") if output_file else None
        tasks: List[asyncio.Task] = []
        pbar = None

        try:
            tasks = [asyncio.create_task(run_one(i, req)) for i, req in enumerate(requests)]
            pbar = tqdm(total=len(requests), desc=f"Batch: {self.default_model_name}") if not silent_mode else None

            for coro in asyncio.as_completed(tasks):
                idx, result = await coro
                if f is not None:
                    f.write(result.model_dump_json() + "\n")
                    f.flush()
                if pbar is not None:
                    pbar.update(1)
        finally:
            # Don't leave orphaned requests running if the loop exits early.
            for task in tasks:
                if not task.done():
                    task.cancel()
            # ``is not None``: a tqdm bar with total=0 is falsy and would never be closed.
            if pbar is not None:
                pbar.close()
            if f is not None:
                f.close()

        return [result_dict[i] for i in range(len(requests))]
|