llm-engine-kitty 0.1.0.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
llm_engine/engine.py ADDED
@@ -0,0 +1,771 @@
1
+ # llm_engine/engine.py
2
+
3
+ import asyncio
4
+ import httpx
5
+ import json
6
+ import time
7
+
8
+ from abc import ABC, abstractmethod
9
+ from concurrent.futures import ThreadPoolExecutor, as_completed
10
+ from tqdm import tqdm
11
+ from typing import Any, Dict, List, Optional
12
+
13
+ import kitty_logger
14
+
15
+ from .model_config import ModelConfig, ModelConfigRegistry
16
+ from .schemas import (
17
+ InferenceParameters,
18
+ InferenceRequest,
19
+ InferenceRequestResult,
20
+ PreparedRequest,
21
+ Message,
22
+ MessageRole,
23
+ ModelOutput,
24
+ ChunkChoice,
25
+ ChunkDelta,
26
+ ChatCompletionChunk,
27
+ ChatCompletionChoice,
28
+ ChatCompletionResponse,
29
+ )
30
+ from .utils import gen_unique_id
31
+
32
+ logger = kitty_logger.getLogger(__name__)
33
+
34
+
35
class BaseEngine(ABC):
    """Abstract root shared by every engine class in this module.

    The class declares neither state nor behaviour; it exists so the
    engine flavours have a common ancestor. The empty ``__slots__`` means
    this class itself adds no per-instance attribute storage.
    """

    __slots__ = ()
40
+
41
+
42
class SyncEngine(BaseEngine):
    """Blocking engine interface.

    Subclasses implement :meth:`inference`. The remaining methods are
    convenience wrappers layered on top of it.
    """

    __slots__ = ()

    @abstractmethod
    def inference(self, request: InferenceRequest) -> InferenceRequestResult:
        """Execute a single request and return its result."""

    def batch_inference(self, requests: List[InferenceRequest], max_concurrency: int = 1) -> List[InferenceRequestResult]:
        """Serial fallback that runs ``requests`` one after another.

        Args:
            requests: Requests to execute, in order.
            max_concurrency: Accepted for interface compatibility only; this
                fallback ignores it and logs a warning saying so.

        Returns:
            One result per request, in the same order as ``requests``.
            Exceptions raised by :meth:`inference` propagate to the caller.
        """
        class_name = self.__class__.__name__

        # fmt: off
        logger.warning(
            f"引擎 '{class_name}' 未实现高效的 'batch_inference' 方法,将回退至串行模式。"
            f"指定的并发数 (max_concurrency={max_concurrency}) 不会生效,"
            f"正在串行处理 {len(requests)} 个请求,这可能会影响执行效率。"
        )
        # fmt: on

        return [self.inference(req) for req in requests]

    def infer(self, query: str) -> str:
        """Send ``query`` as a single user message and return the reply text.

        Returns:
            The model's text, or ``""`` when the result is unsuccessful, has
            no model output, or carries no content. Exceptions raised by
            :meth:`inference` propagate to the caller.
        """
        request = InferenceRequest(messages=[Message(role=MessageRole.USER, content=query)])
        result: InferenceRequestResult = self.inference(request)

        if result.success is True and result.model_output is not None:
            # Content may be None (e.g. a non-stream reply without text);
            # coerce so the declared ``str`` return type always holds.
            return result.model_output.content or ""
        return ""

    def batch_infer(self, queries: List[str], max_concurrency: int = 8) -> List[str]:
        """Run :meth:`batch_inference` over plain-text queries.

        Args:
            queries: One user message per request.
            max_concurrency: Forwarded to :meth:`batch_inference`.

        Returns:
            Reply texts aligned with ``queries``; ``""`` stands in for any
            request that failed or produced no content.
        """
        inference_requests: List[InferenceRequest] = [
            InferenceRequest(messages=[Message(role=MessageRole.USER, content=query)]) for query in queries
        ]
        inference_results: List[InferenceRequestResult] = self.batch_inference(inference_requests, max_concurrency=max_concurrency)

        return [
            (result.model_output.content or "") if result.success is True and result.model_output is not None else ""
            for result in inference_results
        ]
83
+
84
+
85
class AsyncEngine(BaseEngine):
    """Placeholder engine category; it currently defines no members.

    The empty ``__slots__`` means this class itself adds no per-instance
    attribute storage.
    """

    __slots__ = ()
90
+
91
+
92
class CoroutineEngine(BaseEngine):
    """Coroutine-based engine interface.

    Subclasses implement the :meth:`inference` coroutine. The remaining
    coroutines are convenience wrappers layered on top of it.
    """

    __slots__ = ()

    @abstractmethod
    async def inference(self, request: InferenceRequest) -> InferenceRequestResult:
        """Execute a single request and return its result."""

    async def batch_inference(self, requests: List[InferenceRequest], max_concurrency: int = 1) -> List[InferenceRequestResult]:
        """Serial fallback that awaits ``requests`` one after another.

        Args:
            requests: Requests to execute, in order.
            max_concurrency: Accepted for interface compatibility only; this
                fallback ignores it and logs a warning saying so.

        Returns:
            One result per request, in the same order as ``requests``.
            Exceptions raised by :meth:`inference` propagate to the caller.
        """
        class_name = self.__class__.__name__

        # fmt: off
        logger.warning(
            f"引擎 '{class_name}' 未实现高效的 'batch_inference' 方法,将回退至串行模式。"
            f"指定的并发数 (max_concurrency={max_concurrency}) 不会生效,"
            f"正在串行处理 {len(requests)} 个请求,这可能会影响执行效率。"
        )
        # fmt: on

        return [await self.inference(req) for req in requests]

    async def infer(self, query: str) -> str:
        """Send ``query`` as a single user message and return the reply text.

        Returns:
            The model's text, or ``""`` when :meth:`inference` raises, the
            result is unsuccessful, there is no model output, or the output
            carries no content. Exceptions are logged, not re-raised.
        """
        request = InferenceRequest(messages=[Message(role=MessageRole.USER, content=query)])
        try:
            result: InferenceRequestResult = await self.inference(request)
        except Exception as e:
            logger.error(f"infer 抛出未捕获异常: {e}")
            return ""

        if result.success is True and result.model_output is not None:
            # Content may be None (e.g. a non-stream reply without text);
            # coerce so the declared ``str`` return type always holds.
            return result.model_output.content or ""
        return ""

    async def batch_infer(self, queries: List[str], max_concurrency: int = 8) -> List[str]:
        """Run :meth:`batch_inference` over plain-text queries.

        Args:
            queries: One user message per request.
            max_concurrency: Forwarded to :meth:`batch_inference`.

        Returns:
            Reply texts aligned with ``queries``; ``""`` stands in for any
            request that failed or produced no content.
        """
        inference_requests: List[InferenceRequest] = [
            InferenceRequest(messages=[Message(role=MessageRole.USER, content=query)]) for query in queries
        ]
        inference_results: List[InferenceRequestResult] = await self.batch_inference(inference_requests, max_concurrency=max_concurrency)

        return [
            (result.model_output.content or "") if result.success is True and result.model_output is not None else ""
            for result in inference_results
        ]
134
+
135
+
136
class SimpleEngine(SyncEngine):
    """Blocking engine that POSTs chat-completion requests over ``httpx``.

    Per-request settings are layered: engine-level defaults are overridden
    by model-level settings, which are overridden by the request itself.
    """

    __slots__ = (
        "model_registry",
        "default_model_name",
        "default_model",
        "default_api_key",
        "default_inference_parameters",
        "extra_headers",
        "extra_payload",
        "stream",
    )

    def __init__(
        self,
        model_registry: ModelConfigRegistry,
        default_model: str,
        default_api_key: Optional[str] = None,
        default_inference_parameters: Optional[InferenceParameters] = None,
        extra_headers: Optional[Dict[str, str]] = None,
        extra_payload: Optional[Dict[str, Any]] = None,
        stream: bool = False,
    ) -> None:
        """Store the engine-level defaults.

        Args:
            model_registry: Registry used to look models up by name.
            default_model: Name used when a request names no model; it is
                looked up in ``model_registry`` immediately.
            default_api_key: Key used when neither request nor model has one.
            default_inference_parameters: Lowest-priority inference parameters.
            extra_headers: Lowest-priority additional HTTP headers.
            extra_payload: Lowest-priority additional payload fields.
            stream: Streaming flag used when the request does not set one.
        """
        self.model_registry: ModelConfigRegistry = model_registry
        self.default_model_name: str = default_model
        self.default_model: ModelConfig = self.model_registry.get(default_model)
        self.default_api_key: Optional[str] = default_api_key
        self.default_inference_parameters: Optional[InferenceParameters] = default_inference_parameters
        self.extra_headers: Dict[str, str] = extra_headers if extra_headers is not None else {}
        self.extra_payload: Dict[str, Any] = extra_payload if extra_payload is not None else {}
        self.stream: bool = stream

    def get_model_api_key(self, inference_request: InferenceRequest, model: ModelConfig) -> str:
        """Pick the API key: request first, then model, then engine default.

        Returns:
            The first key that is not ``None``; ``""`` (with a warning) when
            none is configured.
        """
        for candidate in (inference_request.api_key, model.api_key, self.default_api_key):
            if candidate is not None:
                return candidate

        logger.warning(f"没有找到可用的 api_key, 将传递空 api_key")
        return ""

    def build_request_headers(self, inference_request: InferenceRequest, model: ModelConfig) -> Dict[str, str]:
        """Build the HTTP headers for one request.

        Starts from ``Authorization`` and ``Content-Type`` and then applies
        engine, model and request extra headers in that order, so later
        sources override earlier ones.
        """
        headers: Dict[str, str] = {
            "Authorization": "Bearer " + self.get_model_api_key(inference_request, model),
            "Content-Type": "application/json",
        }

        for extra in (self.extra_headers, model.extra_headers, inference_request.extra_headers):
            headers.update(extra)

        return headers

    def build_request_payload(self, inference_request: InferenceRequest, model: ModelConfig) -> Dict[str, Any]:
        """Build the JSON payload for one request.

        Inference parameters are merged engine -> model -> request, followed
        by the extra payload dictionaries in the same order; later sources
        override earlier ones.
        """
        use_stream = inference_request.stream if inference_request.stream is not None else self.stream
        payload: Dict[str, Any] = {
            "model": model.model_id,
            "stream": use_stream,
            "messages": [msg.to_dict() for msg in inference_request.messages],
        }

        parameter_layers = (
            self.default_inference_parameters,
            model.default_inference_parameters,
            inference_request.inference_parameters,
        )
        for params in parameter_layers:
            if params:
                payload.update(params.to_dict())

        for extra in (self.extra_payload, model.extra_payload, inference_request.extra_payload):
            payload.update(extra)

        return payload

    def build_request(self, request: InferenceRequest) -> PreparedRequest:
        """Core assembly step: turn a request into concrete HTTP parameters.

        The returned object carries a URL obtained from ``model.get_url()``.
        ``inference`` hands that URL back with ``model.release_url()``.
        """
        model_name = request.model_name or self.default_model_name
        model = self.model_registry.get(model_name)

        # Build everything that can fail before taking a URL, so an error in
        # header/payload assembly cannot leave a URL acquired but unreleased.
        headers = self.build_request_headers(request, model)
        payload = self.build_request_payload(request, model)

        return PreparedRequest(
            model_name=model_name,
            url=model.get_url(),
            headers=headers,
            payload=payload,
        )

    @staticmethod
    def _output_from_response(resp: ChatCompletionResponse) -> ModelOutput:
        """Convert a non-stream response (non-empty ``choices``) to a ModelOutput."""
        choice: ChatCompletionChoice = resp.choices[0]
        if choice.finish_reason and choice.finish_reason != "stop":
            logger.warning(f"finish_reason='{choice.finish_reason}' (非 stop),content_len={len(choice.message.content or '')}, usage={resp.usage.model_dump() if resp.usage else None}")
        return ModelOutput(
            role=choice.message.role,
            content=choice.message.content,
            reasoning=choice.message.reasoning_content,
            finish_reason=choice.finish_reason,
            usage=resp.usage.model_dump() if resp.usage else None,
        )

    @staticmethod
    def _collect_stream(lines: Any) -> ModelOutput:
        """Fold an iterable of ``data: ...`` lines into a single ModelOutput.

        Blank lines, lines without the ``data: `` prefix and chunks that fail
        to parse are skipped (the latter two are logged). Reading stops at
        ``[DONE]``. Only the first choice of each chunk is used.
        """
        content_parts: List[str] = []
        reasoning_parts: List[str] = []
        role: Optional[str] = None
        usage: Optional[Dict] = None
        finish_reason: Optional[str] = None

        for line in lines:
            if not line or line.strip() == "":
                continue

            logger.debug(f"line: '{line}'")

            if not line.startswith("data: "):
                logger.error(f"数据行没有以'data: '开头,将跳过。line: {line}")
                continue

            raw_data = line.removeprefix("data: ").strip()
            if raw_data == "[DONE]":
                logger.debug("收到结束标志'[DONE]', 不再解析后续包")
                break

            try:
                chunk: ChatCompletionChunk = ChatCompletionChunk.model_validate_json(raw_data)
            except Exception as e:
                logger.error(f"解析数据包失败: {e}, 原始数据: {raw_data}")
                continue

            if chunk.usage:
                if usage is not None:
                    logger.warning(f"收到多个含有usage的包,后收到的usage将会覆盖先前的token用量信息。")
                usage = chunk.usage.model_dump()

            if not chunk.choices:
                continue

            chunk_choice: ChunkChoice = chunk.choices[0]
            chunk_delta: ChunkDelta = chunk_choice.delta

            if chunk_delta.content:
                content_parts.append(chunk_delta.content)
            if chunk_delta.reasoning_content:
                reasoning_parts.append(chunk_delta.reasoning_content)

            if chunk_delta.role:
                if role is not None:
                    logger.warning(f"收到多个含有role的包,后收到的role将会覆盖先前的role信息。")
                role = chunk_delta.role

            if chunk_choice.finish_reason:
                if finish_reason is not None:
                    logger.warning(f"收到多个含有finish_reason的包,后收到的finish_reason将会覆盖先前的finish_reason信息。")
                finish_reason = chunk_choice.finish_reason

        model_response = "".join(content_parts)
        model_reasoning = "".join(reasoning_parts)

        if finish_reason and finish_reason != "stop":
            logger.warning(f"finish_reason='{finish_reason}' (非 stop),content_len={len(model_response)}, usage={usage}")
        return ModelOutput(
            role=role,
            content=model_response,
            reasoning=model_reasoning,
            finish_reason=finish_reason,
            usage=usage,
        )

    def send_request(
        self,
        url: str,
        headers: Dict[str, str],
        payload: Dict[str, Any],
        stream: bool,
        max_retries: int = 5,
        base_delay: int = 2,
    ) -> ModelOutput:
        """POST ``payload`` to ``url`` and parse the reply, retrying on failure.

        A new ``httpx.Client`` (10 s default timeout, 3600 s read timeout) is
        created for every attempt.

        Args:
            url: Endpoint to POST to.
            headers: HTTP headers for the request.
            payload: JSON body.
            stream: Parse the reply as ``data:`` lines when true, otherwise
                as one JSON document.
            max_retries: Total number of attempts.
            base_delay: Base of the exponential wait after network errors and
                the fixed wait after an empty ``choices`` list.

        Returns:
            The parsed model output.

        Raises:
            httpx.HTTPStatusError: Non-200 status on the final attempt.
            httpx.NetworkError, httpx.TimeoutException: Transport failure on
                the final attempt.
            RuntimeError: Empty ``choices`` on the final attempt, or the loop
                finished without returning (e.g. ``max_retries <= 0``).
        """
        for attempt in range(max_retries):
            is_last_attempt: bool = attempt == (max_retries - 1)
            try:
                with httpx.Client(timeout=httpx.Timeout(timeout=10.0, read=3600.0)) as http_client:
                    with http_client.stream("POST", url, json=payload, headers=headers) as http_response:

                        if http_response.status_code != 200:
                            response_body = http_response.read().decode("utf-8", errors="replace")
                            if is_last_attempt:
                                logger.error(f"HTTP {http_response.status_code},已达最大重试次数,响应内容: {response_body[:500]}")
                                raise httpx.HTTPStatusError(f"HTTP {http_response.status_code}", request=http_response.request, response=http_response)

                            wait_time = self._get_wait_time(http_response.status_code, attempt, base_delay)
                            logger.warning(f"HTTP {http_response.status_code}, 第 {attempt+1} 次重试,等待 {wait_time}s... 响应内容: {response_body[:500]}")
                            time.sleep(wait_time)
                            continue

                        if stream:
                            return self._collect_stream(http_response.iter_lines())

                        full_body = http_response.read().decode("utf-8")
                        resp: ChatCompletionResponse = ChatCompletionResponse.model_validate_json(full_body)
                        if not resp.choices:
                            # Indexing an empty list would raise IndexError;
                            # retry instead, as SimpleCoroutineEngine does.
                            if is_last_attempt:
                                logger.error(f"resp.choices 为空,已达最大重试次数,疑似被风控拦截。响应内容: {full_body[:500]}")
                                raise RuntimeError(f"empty choices after {max_retries} attempts, body: {full_body[:500]}")
                            wait_time = base_delay
                            logger.warning(f"resp.choices 为空,疑似被风控拦截,第 {attempt+1} 次重试,等待 {wait_time}s... 响应内容: {full_body[:500]}")
                            time.sleep(wait_time)
                            continue
                        return self._output_from_response(resp)

            except (httpx.NetworkError, httpx.TimeoutException, httpx.HTTPStatusError) as e:
                if is_last_attempt:
                    logger.error(f"达到最大重试次数,最后一次错误: {e}")
                    raise
                # HTTPStatusError is only raised above on the final attempt,
                # so only transport errors need a back-off here.
                if not isinstance(e, httpx.HTTPStatusError):
                    wait_time = base_delay**attempt
                    logger.warning(f"网络异常: {e}, 等待 {wait_time}s 后重试...")
                    time.sleep(wait_time)

        raise RuntimeError("Unexpected end of retry loop")

    def _get_wait_time(self, status_code: int, attempt: int, base_delay: int) -> float:
        """Central place for the wait time used after a non-200 response.

        429 waits ``10 * (attempt + 1)`` seconds, 5xx waits
        ``base_delay ** attempt`` seconds, anything else waits 5 seconds.
        """
        if status_code == 429:
            # Rate limiting usually needs a longer wait.
            return 10 * (attempt + 1)
        if status_code >= 500:
            # Server-side errors use exponential back-off.
            return float(base_delay**attempt)
        return 5.0  # default for every other status

    def inference(self, request: InferenceRequest) -> InferenceRequestResult:
        """Build, send and time one request.

        The URL is chosen once in ``build_request`` and reused for every
        retry; picking a new URL per retry (multi-node setups) would require
        moving URL selection into the attempt loop of ``send_request``.

        Returns:
            A successful result wrapping the model output. ``duration``
            covers ``send_request`` only, not request assembly.

        Raises:
            Whatever ``build_request`` or ``send_request`` raises; the URL is
            released in either case once it has been acquired.
        """
        prepared_request: PreparedRequest = self.build_request(request)
        model = self.model_registry.get(request.model_name or self.default_model_name)

        start_time: float = time.time()
        try:
            model_output: ModelOutput = self.send_request(
                url=prepared_request.url,
                headers=prepared_request.headers,
                payload=prepared_request.payload,
                stream=prepared_request.payload["stream"],
            )
        finally:
            model.release_url(prepared_request.url)

        return InferenceRequestResult(
            success=True,
            task_id="",
            request=request,
            model_output=model_output,
            duration=time.time() - start_time,
        )

    def batch_inference(
        self,
        requests: List[InferenceRequest],
        max_concurrency: int = 8,
        output_file: Optional[str] = None,
        silent_mode: bool = False,
    ) -> List[InferenceRequestResult]:
        """Run ``requests`` on a thread pool and collect results in input order.

        Args:
            requests: Requests to execute.
            max_concurrency: Number of worker threads.
            output_file: Optional path; each result is appended to it as one
                JSON line in completion order. The file is truncated first.
            silent_mode: Suppress the ``tqdm`` progress bar when true.

        Returns:
            One result per request, aligned with ``requests``. A request
            whose ``inference`` call raised is represented by an error result
            instead of aborting the batch.

        Raises:
            OSError: ``output_file`` cannot be opened or written.
        """
        result_dict: Dict[int, InferenceRequestResult] = {}

        # Open the sink before submitting anything: a bad path then fails
        # immediately instead of after every request has already been sent.
        f = open(output_file, "w", encoding="utf-8") if output_file else None
        try:
            with ThreadPoolExecutor(max_workers=max_concurrency) as executor:
                future_to_idx = {executor.submit(self.inference, req): i for i, req in enumerate(requests)}

                futures_iterator = as_completed(future_to_idx)
                if not silent_mode:
                    futures_iterator = tqdm(
                        futures_iterator,
                        total=len(requests),
                        desc=f"Batch: {self.default_model_name}",
                    )

                for future in futures_iterator:
                    idx = future_to_idx[future]
                    # Only the request itself is guarded; a failure to write
                    # the output file must not be reported as a request error.
                    try:
                        result: InferenceRequestResult = future.result()
                    except Exception as e:
                        logger.error(f"Request #{idx} 抛出未捕获异常: {e}")
                        result = self._get_error_result(requests[idx], e)

                    result_dict[idx] = result
                    if f:
                        f.write(result.model_dump_json() + "\n")
                        f.flush()
        finally:
            if f:
                f.close()

        return [result_dict[i] for i in range(len(requests))]

    def _get_error_result(self, request: InferenceRequest, error: Exception) -> InferenceRequestResult:
        """Internal helper: standard error result for a request that failed outright."""
        return InferenceRequestResult(
            success=False,
            task_id="error",
            request=request,
            error_message=f"{type(error).__name__}: {str(error)}",
        )

    def mock_request(self, query: Optional[str] = None, request: Optional[InferenceRequest] = None) -> PreparedRequest:
        """Build a request without sending it and return URL, headers and payload.

        Useful for checking configuration assembly, API key propagation and
        parameter overriding.

        Args:
            query: Text for a single user message; ignored when ``request``
                is given. A placeholder containing a unique id is generated
                when both are ``None``.
            request: Fully specified request to assemble instead.
        """
        if request is None:
            if query is None:
                query = f"请求文本: [{gen_unique_id()}]"
            request = InferenceRequest(messages=[Message(role=MessageRole.USER, content=query)])

        prepared_request: PreparedRequest = self.build_request(request)

        # Nothing is sent, so hand the URL straight back, mirroring what
        # inference() does in its ``finally``.
        # NOTE(review): assumes get_url()/release_url() are an
        # acquire/release pair - confirm against ModelConfig.
        model = self.model_registry.get(request.model_name or self.default_model_name)
        model.release_url(prepared_request.url)

        return prepared_request
450
+
451
+
452
class SimpleCoroutineEngine(CoroutineEngine):
    """Coroutine engine that POSTs chat-completion requests over ``httpx``.

    Per-request settings are layered: engine-level defaults are overridden
    by model-level settings, which are overridden by the request itself.
    """

    __slots__ = (
        "model_registry",
        "default_model_name",
        "default_model",
        "default_api_key",
        "default_inference_parameters",
        "extra_headers",
        "extra_payload",
        "stream",
    )

    def __init__(
        self,
        model_registry: ModelConfigRegistry,
        default_model: str,
        default_api_key: Optional[str] = None,
        default_inference_parameters: Optional[InferenceParameters] = None,
        extra_headers: Optional[Dict[str, str]] = None,
        extra_payload: Optional[Dict[str, Any]] = None,
        stream: bool = False,
    ) -> None:
        """Store the engine-level defaults.

        Args:
            model_registry: Registry used to look models up by name.
            default_model: Name used when a request names no model; it is
                looked up in ``model_registry`` immediately.
            default_api_key: Key used when neither request nor model has one.
            default_inference_parameters: Lowest-priority inference parameters.
            extra_headers: Lowest-priority additional HTTP headers.
            extra_payload: Lowest-priority additional payload fields.
            stream: Streaming flag used when the request does not set one.
        """
        self.model_registry: ModelConfigRegistry = model_registry
        self.default_model_name: str = default_model
        self.default_model: ModelConfig = self.model_registry.get(default_model)
        self.default_api_key: Optional[str] = default_api_key
        self.default_inference_parameters: Optional[InferenceParameters] = default_inference_parameters
        self.extra_headers: Dict[str, str] = extra_headers if extra_headers is not None else {}
        self.extra_payload: Dict[str, Any] = extra_payload if extra_payload is not None else {}
        self.stream: bool = stream

    def get_model_api_key(self, inference_request: InferenceRequest, model: ModelConfig) -> str:
        """Pick the API key: request first, then model, then engine default.

        Returns:
            The first key that is not ``None``; ``""`` (with a warning) when
            none is configured.
        """
        for candidate in (inference_request.api_key, model.api_key, self.default_api_key):
            if candidate is not None:
                return candidate

        logger.warning(f"没有找到可用的 api_key, 将传递空 api_key")
        return ""

    def build_request_headers(self, inference_request: InferenceRequest, model: ModelConfig) -> Dict[str, str]:
        """Build the HTTP headers for one request.

        Starts from ``Authorization`` and ``Content-Type`` and then applies
        engine, model and request extra headers in that order, so later
        sources override earlier ones.
        """
        headers: Dict[str, str] = {
            "Authorization": "Bearer " + self.get_model_api_key(inference_request, model),
            "Content-Type": "application/json",
        }

        for extra in (self.extra_headers, model.extra_headers, inference_request.extra_headers):
            headers.update(extra)

        return headers

    def build_request_payload(self, inference_request: InferenceRequest, model: ModelConfig) -> Dict[str, Any]:
        """Build the JSON payload for one request.

        Inference parameters are merged engine -> model -> request, followed
        by the extra payload dictionaries in the same order; later sources
        override earlier ones.
        """
        use_stream = inference_request.stream if inference_request.stream is not None else self.stream
        payload: Dict[str, Any] = {
            "model": model.model_id,
            "stream": use_stream,
            "messages": [msg.to_dict() for msg in inference_request.messages],
        }

        parameter_layers = (
            self.default_inference_parameters,
            model.default_inference_parameters,
            inference_request.inference_parameters,
        )
        for params in parameter_layers:
            if params:
                payload.update(params.to_dict())

        for extra in (self.extra_payload, model.extra_payload, inference_request.extra_payload):
            payload.update(extra)

        return payload

    def build_request(self, request: InferenceRequest) -> PreparedRequest:
        """Core assembly step: turn a request into concrete HTTP parameters.

        The returned object carries a URL obtained from ``model.get_url()``.
        ``inference`` hands that URL back with ``model.release_url()``.
        """
        model_name = request.model_name or self.default_model_name
        model = self.model_registry.get(model_name)

        # Build everything that can fail before taking a URL, so an error in
        # header/payload assembly cannot leave a URL acquired but unreleased.
        headers = self.build_request_headers(request, model)
        payload = self.build_request_payload(request, model)

        return PreparedRequest(
            model_name=model_name,
            url=model.get_url(),
            headers=headers,
            payload=payload,
        )

    def _get_wait_time(self, status_code: int, attempt: int, base_delay: int) -> float:
        """Central place for the wait time used after a non-200 response.

        Every status code currently waits a flat 1.0 second; the arguments
        are accepted but not used.
        """
        return 1.0

    def _get_error_result(self, request: InferenceRequest, error: Exception) -> InferenceRequestResult:
        """Internal helper: standard error result for a request that failed outright."""
        return InferenceRequestResult(
            success=False,
            task_id="error",
            request=request,
            error_message=f"{type(error).__name__}: {str(error)}",
        )

    def mock_request(self, query: Optional[str] = None, request: Optional[InferenceRequest] = None) -> PreparedRequest:
        """Build a request without sending it and return URL, headers and payload.

        Useful for checking configuration assembly, API key propagation and
        parameter overriding.

        Args:
            query: Text for a single user message; ignored when ``request``
                is given. A placeholder containing a unique id is generated
                when both are ``None``.
            request: Fully specified request to assemble instead.
        """
        if request is None:
            if query is None:
                query = f"请求文本: [{gen_unique_id()}]"
            request = InferenceRequest(messages=[Message(role=MessageRole.USER, content=query)])

        prepared_request: PreparedRequest = self.build_request(request)

        # Nothing is sent, so hand the URL straight back, mirroring what
        # inference() does in its ``finally``.
        # NOTE(review): assumes get_url()/release_url() are an
        # acquire/release pair - confirm against ModelConfig.
        model = self.model_registry.get(request.model_name or self.default_model_name)
        model.release_url(prepared_request.url)

        return prepared_request

    @staticmethod
    def _output_from_response(resp: ChatCompletionResponse) -> ModelOutput:
        """Convert a non-stream response (non-empty ``choices``) to a ModelOutput."""
        choice: ChatCompletionChoice = resp.choices[0]
        if choice.finish_reason and choice.finish_reason != "stop":
            logger.warning(f"finish_reason='{choice.finish_reason}' (非 stop),content_len={len(choice.message.content or '')}, usage={resp.usage.model_dump() if resp.usage else None}")
        return ModelOutput(
            role=choice.message.role,
            content=choice.message.content,
            reasoning=choice.message.reasoning_content,
            finish_reason=choice.finish_reason,
            usage=resp.usage.model_dump() if resp.usage else None,
        )

    @staticmethod
    async def _collect_stream(lines: Any) -> ModelOutput:
        """Fold an async iterable of ``data: ...`` lines into one ModelOutput.

        Blank lines, lines without the ``data: `` prefix and chunks that fail
        to parse are skipped (the latter two are logged). Reading stops at
        ``[DONE]``. Only the first choice of each chunk is used.
        """
        content_parts: List[str] = []
        reasoning_parts: List[str] = []
        role: Optional[str] = None
        usage: Optional[Dict] = None
        finish_reason: Optional[str] = None

        async for line in lines:
            if not line or line.strip() == "":
                continue

            logger.debug(f"line: '{line}'")

            if not line.startswith("data: "):
                logger.error(f"数据行没有以'data: '开头,将跳过。line: {line}")
                continue

            raw_data = line.removeprefix("data: ").strip()
            if raw_data == "[DONE]":
                logger.debug("收到结束标志'[DONE]', 不再解析后续包")
                break

            try:
                chunk: ChatCompletionChunk = ChatCompletionChunk.model_validate_json(raw_data)
            except Exception as e:
                logger.error(f"解析数据包失败: {e}, 原始数据: {raw_data}")
                continue

            if chunk.usage:
                if usage is not None:
                    logger.warning(f"收到多个含有usage的包,后收到的usage将会覆盖先前的token用量信息。")
                usage = chunk.usage.model_dump()

            if not chunk.choices:
                continue

            chunk_choice: ChunkChoice = chunk.choices[0]
            chunk_delta: ChunkDelta = chunk_choice.delta

            if chunk_delta.content:
                content_parts.append(chunk_delta.content)
            if chunk_delta.reasoning_content:
                reasoning_parts.append(chunk_delta.reasoning_content)

            if chunk_delta.role:
                if role is not None:
                    logger.warning(f"收到多个含有role的包,后收到的role将会覆盖先前的role信息。")
                role = chunk_delta.role

            if chunk_choice.finish_reason:
                if finish_reason is not None:
                    logger.warning(f"收到多个含有finish_reason的包,后收到的finish_reason将会覆盖先前的finish_reason信息。")
                finish_reason = chunk_choice.finish_reason

        model_response = "".join(content_parts)
        model_reasoning = "".join(reasoning_parts)

        if finish_reason and finish_reason != "stop":
            logger.warning(f"finish_reason='{finish_reason}' (非 stop),content_len={len(model_response)}, usage={usage}")
        return ModelOutput(
            role=role,
            content=model_response,
            reasoning=model_reasoning,
            finish_reason=finish_reason,
            usage=usage,
        )

    async def send_request(
        self,
        url: str,
        headers: Dict[str, str],
        payload: Dict[str, Any],
        stream: bool,
        max_retries: int = 256,
        base_delay: int = 2,
        max_delay: float = 60.0,
    ) -> ModelOutput:
        """POST ``payload`` to ``url`` and parse the reply, retrying on failure.

        One ``httpx.AsyncClient`` (10 s default timeout, 1800 s read timeout)
        is shared by all attempts of this call.

        Args:
            url: Endpoint to POST to.
            headers: HTTP headers for the request.
            payload: JSON body.
            stream: Parse the reply as ``data:`` lines when true, otherwise
                as one JSON document.
            max_retries: Total number of attempts.
            base_delay: Base of the exponential wait after network errors and
                the fixed wait after an empty ``choices`` list.
            max_delay: Upper bound, in seconds, on the exponential wait after
                a network error. Without a bound ``base_delay ** attempt``
                reaches days after a few dozen attempts.

        Returns:
            The parsed model output.

        Raises:
            httpx.HTTPStatusError: Non-200 status on the final attempt.
            httpx.NetworkError, httpx.TimeoutException: Transport failure on
                the final attempt.
            RuntimeError: Empty ``choices`` on the final attempt, or the loop
                finished without returning (e.g. ``max_retries <= 0``).
        """
        async with httpx.AsyncClient(timeout=httpx.Timeout(timeout=10.0, read=1800.0)) as http_client:
            for attempt in range(max_retries):
                is_last_attempt: bool = attempt == (max_retries - 1)
                try:
                    async with http_client.stream("POST", url, json=payload, headers=headers) as http_response:

                        if http_response.status_code != 200:
                            response_body = (await http_response.aread()).decode("utf-8", errors="replace")
                            if is_last_attempt:
                                logger.error(f"HTTP {http_response.status_code},已达最大重试次数,响应内容: {response_body[:500]}")
                                raise httpx.HTTPStatusError(f"HTTP {http_response.status_code}", request=http_response.request, response=http_response)

                            wait_time = self._get_wait_time(http_response.status_code, attempt, base_delay)
                            logger.warning(f"HTTP {http_response.status_code}, 第 {attempt+1} 次重试,等待 {wait_time}s... 响应内容: {response_body[:500]}")
                            await asyncio.sleep(wait_time)
                            continue

                        if stream:
                            return await self._collect_stream(http_response.aiter_lines())

                        full_body = (await http_response.aread()).decode("utf-8")
                        resp: ChatCompletionResponse = ChatCompletionResponse.model_validate_json(full_body)
                        if not resp.choices:
                            if is_last_attempt:
                                logger.error(f"resp.choices 为空,已达最大重试次数,疑似被风控拦截。响应内容: {full_body[:500]}")
                                raise RuntimeError(f"empty choices after {max_retries} attempts, body: {full_body[:500]}")
                            wait_time = base_delay
                            logger.warning(f"resp.choices 为空,疑似被风控拦截,第 {attempt+1} 次重试,等待 {wait_time}s... 响应内容: {full_body[:500]}")
                            await asyncio.sleep(wait_time)
                            continue
                        return self._output_from_response(resp)

                except (httpx.NetworkError, httpx.TimeoutException, httpx.HTTPStatusError) as e:
                    if is_last_attempt:
                        logger.error(f"达到最大重试次数,最后一次错误: {e}")
                        raise
                    # HTTPStatusError is only raised above on the final
                    # attempt, so only transport errors need a back-off here.
                    if not isinstance(e, httpx.HTTPStatusError):
                        # Capped: with up to 256 attempts an unbounded
                        # exponential would effectively sleep forever.
                        wait_time = min(base_delay**attempt, max_delay)
                        logger.warning(f"网络异常: {type(e).__name__}: {e!r}, 等待 {wait_time}s 后重试...")
                        await asyncio.sleep(wait_time)

        raise RuntimeError("Unexpected end of retry loop")

    async def inference(self, request: InferenceRequest) -> InferenceRequestResult:
        """Build, send and time one request.

        The URL is chosen once in ``build_request`` and reused for every
        retry; picking a new URL per retry (multi-node setups) would require
        moving URL selection into the attempt loop of ``send_request``.

        Returns:
            A successful result wrapping the model output. ``duration``
            covers ``send_request`` only, not request assembly.

        Raises:
            Whatever ``build_request`` or ``send_request`` raises; the URL is
            released in either case once it has been acquired.
        """
        prepared_request: PreparedRequest = self.build_request(request)
        model = self.model_registry.get(request.model_name or self.default_model_name)

        start_time: float = time.time()
        try:
            model_output: ModelOutput = await self.send_request(
                url=prepared_request.url,
                headers=prepared_request.headers,
                payload=prepared_request.payload,
                stream=prepared_request.payload["stream"],
            )
        finally:
            model.release_url(prepared_request.url)

        return InferenceRequestResult(
            success=True,
            task_id="",
            request=request,
            model_output=model_output,
            duration=time.time() - start_time,
        )

    async def batch_inference(
        self,
        requests: List[InferenceRequest],
        max_concurrency: int = 8,
        output_file: Optional[str] = None,
        silent_mode: bool = False,
    ) -> List[InferenceRequestResult]:
        """Run ``requests`` concurrently and collect results in input order.

        Args:
            requests: Requests to execute.
            max_concurrency: Maximum number of requests in flight; must be
                at least 1.
            output_file: Optional path; each result is appended to it as one
                JSON line in completion order. The file is truncated first.
            silent_mode: Suppress the ``tqdm`` progress bar when true.

        Returns:
            One result per request, aligned with ``requests``. A request
            whose ``inference`` call raised is represented by an error result
            instead of aborting the batch.

        Raises:
            ValueError: ``max_concurrency`` is smaller than 1.
            OSError: ``output_file`` cannot be opened or written.
        """
        if max_concurrency < 1:
            # asyncio.Semaphore(0) would park every task forever.
            raise ValueError("max_concurrency must be >= 1")

        semaphore = asyncio.Semaphore(max_concurrency)
        result_dict: Dict[int, InferenceRequestResult] = {}

        async def run_one(idx: int, req: InferenceRequest):
            async with semaphore:
                try:
                    result = await self.inference(req)
                except Exception as e:
                    logger.error(f"Request #{idx} 抛出未捕获异常: {e}")
                    result = self._get_error_result(req, e)
                result_dict[idx] = result
                return idx, result

        # Open the sink before scheduling anything: a bad path then fails
        # before any request has been started.
        f = open(output_file, "w", encoding="utf-8") if output_file else None
        pbar = None
        tasks: List[Any] = []

        try:
            if not silent_mode:
                pbar = tqdm(total=len(requests), desc=f"Batch: {self.default_model_name}")
            tasks = [asyncio.create_task(run_one(i, req)) for i, req in enumerate(requests)]

            for coro in asyncio.as_completed(tasks):
                idx, result = await coro
                if f:
                    f.write(result.model_dump_json() + "\n")
                    f.flush()
                if pbar:
                    pbar.update(1)
        finally:
            # On an error or cancellation, stop the remaining requests rather
            # than leaving them running unobserved. No-op on normal exit.
            for task in tasks:
                if not task.done():
                    task.cancel()
            if pbar:
                pbar.close()
            if f:
                f.close()

        return [result_dict[i] for i in range(len(requests))]