llm-engine-kitty 0.1.0.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1077 @@
1
+ # llm_engine/kitty/engine.py
2
+
3
+ """KittyEngine 后端进程的任务执行器层。
4
+
5
+ 本模块只包含两个核心对象:
6
+
7
+ - `_WorkerTask`:后端进程内的任务状态容器(仅进程内使用,不跨进程)。
8
+ - `_KittyEngine`:任务生命周期管理器,由 KittyServer 在子进程主线程的
9
+ asyncio loop 中驱动。
10
+
11
+ 分层设计:
12
+ - 协议/序列化:由 `llm_engine.kitty.schemas`(`TaskSnapshot` / `WireResult`
13
+ / `TransitionRecord`)与 `llm_engine.kitty.config` 负责;worker 不感知。
14
+ - 事件循环 / socket:由 KittyServer 负责。worker 只暴露任务级 API
15
+ (`submit` / `cancel` / `snapshot` / `pop_result` / `clear_done`
16
+ / `force_remove` / `wait_task` / `cleanup_tasks`),不接触 socket。
17
+ - HTTP / 模型路由:由 `_KittyEngine` 直接使用 `httpx.AsyncClient` +
18
+ `ModelConfigRegistry`。上层 client 不感知。
19
+
20
+ 并发模型:
21
+ - worker 运行在子进程的唯一 asyncio 事件循环线程里。
22
+ - 所有 `self._tasks` / `_WorkerTask` 字段的读写都在该线程内完成,故不需要锁。
23
+ - 并发度由 `asyncio.Semaphore(max_global_concurrency)` 控制;每个任务是
24
+ 一条 `asyncio.Task`,运行 `_KittyEngine._run_task`。
25
+
26
+ 状态机(见 `llm_engine.schemas.TaskStatus`)::
27
+
28
+ SUBMITTED ──► ACCEPTED ──► PENDING ──► RUNNING ──► SUCCESS
29
+ │ │ │ │
30
+ │ │ │ ├──► FAILED (异常/HTTP/解析失败)
31
+ │ │ │ ├──► TIMEOUT (task_timeout 触发)
32
+ │ │ │ └──► CANCELLED (运行中取消)
33
+ │ └───────────┴──► CANCELLED (排队中取消)
34
+ └──► REJECTED(校验失败:缺 model_name、模型未注册……)
35
+
36
+ - SUBMITTED / ACCEPTED:任务已落盘到 worker,但尚未入队 sem。REJECTED 从这一
37
+ 段产生。
38
+ - PENDING:已入队 `_sem`,等待并发名额。
39
+ - RUNNING:持有 sem,正在发 HTTP。
40
+ - SUCCESS / FAILED / TIMEOUT / CANCELLED / REJECTED:终态,由 `is_finished` 判定。
41
+
42
+ 将来若支持 "RUNNING 中异常回到 PENDING 再跑" 的重试,状态会多次经过 PENDING /
43
+ RUNNING。`TaskSnapshot.submit_time` / `start_time` 等派生属性按"首次出现"
44
+ 语义实现,已为此预留。
45
+
46
+ 状态转移的唯一入口:
47
+ 所有状态变更必须经由 `_WorkerTask.transition`。该方法会:
48
+
49
+ 1. 写入新状态;
50
+ 2. 追加一条 `TransitionRecord`;
51
+ 3. 回填上一条记录的 `duration`(= 上一个状态的实际持续秒数)。
52
+
53
+ 禁止直接 `wt.status = X`,否则时间线会丢失。
54
+ """
55
+
56
+ import asyncio
57
+ import time
58
+ from dataclasses import dataclass, field
59
+ from datetime import datetime
60
+ from typing import Any, Callable, Optional
61
+
62
+ import httpx
63
+ import kitty_logger
64
+ from pydantic import ValidationError
65
+
66
+ from ..model_config import ModelConfig, ModelConfigRegistry
67
+ from ..schemas import (
68
+ ChatCompletionResponse,
69
+ InferenceParameters,
70
+ InferenceRequest,
71
+ PreparedRequest,
72
+ ModelOutput,
73
+ TaskStatus,
74
+ )
75
+ from .config import KittyEngineConfig, KittyEngineOverrides
76
+ from .schemas import MockRequest, TaskSnapshot, TransitionRecord, WireResult
77
+
78
+ logger = kitty_logger.getLogger(__name__)
79
+
80
+
81
+ def _now_str() -> str:
82
+ """返回人类可读的本地时间戳,形如 `'2026-05-11 10:30:45.123456'`。
83
+
84
+ 用法约定:
85
+ - 仅用于 `TransitionRecord.timestamp`(字符串便于日志 / 跨进程传输 /
86
+ 可读性)。
87
+ - 需要数值秒时通过 `TransitionRecord.epoch`
88
+ (`datetime.fromisoformat(...).timestamp()`)换算,保持
89
+ "字符串为真,epoch 派生"。
90
+
91
+ Returns:
92
+ 带微秒精度的本地时间字符串。
93
+ """
94
+ return datetime.now().isoformat(sep=" ", timespec="microseconds")
95
+
96
+
97
+ @dataclass
98
+ class _WorkerTask:
99
+ """后端进程内部的任务状态容器。不跨进程传输。
100
+
101
+ 为什么用 dataclass 而非 pydantic:
102
+ - 持有 `asyncio.Task` / `asyncio.Event` 等不可序列化对象;
103
+ - 仅在事件循环线程内存在,不需要 `model_validate`;
104
+ - 需要频繁原地变更(状态/结果/错误信息)。
105
+
106
+ 时间线语义(重要):
107
+ - 不保存 `submit_time / start_time / end_time` 字段。所有时间点统一
108
+ 由 `transitions` 列表记录;每次进入新状态追加一条 `TransitionRecord`,
109
+ 并把上一条的 `duration` 回填为上一个状态实际持续的秒数。
110
+ - 若未来支持 "RUNNING 失败→回到 PENDING→再次 RUNNING" 的重试路径,
111
+ `transitions` 会出现多个同状态条目;对应的 submit/start 等派生值
112
+ 按"首次出现"取值。
113
+
114
+ 与外部对象的协作要求:
115
+ - `_KittyEngine.submit` 在构造完实例之后、`create_task` 之前,必须
116
+ 显式设置 `done_event = asyncio.Event()`,并挂上 `asyncio_task`。
117
+ (当前 `__post_init__` 只初始化 `transitions`,不创建 event,以避免
118
+ 在无事件循环的线程里被误构造。)
119
+ - `snapshot` 对 `transitions` 做浅拷贝;`TransitionRecord` 在 worker
120
+ 内部一旦被追加后,除回填末条 `duration` 外不应再被修改,否则外部
121
+ 已经 pop 走的快照可能被"幽灵修改"。
122
+
123
+ Attributes:
124
+ task_id: 任务唯一 ID(由上层 client/engine 生成,本模块只负责查表)。
125
+ request: 原始 `InferenceRequest`。用于 `_run_task` 发起 HTTP。
126
+ overrides: 本次提交附带的 `KittyEngineOverrides`(可为 None),在模型 /
127
+ 超时 / header / payload 解析时作为中间优先级兜底。
128
+ status: 当前状态。禁止直接赋值,必须通过 `transition` 更新。
129
+ task_timeout: 单次任务(发 HTTP + 解析)的 wall-clock 超时。由 submit
130
+ 按优先级解析得到:`request.timeout` > submit 参数 >
131
+ `overrides.default_timeout` > `config.default_timeout`。
132
+ persist: true 则终态后仍保留在 `_tasks` 中(可多次 snapshot / 重复 pop
133
+ 语义见 `_KittyEngine.pop_result`);false 则一旦 pop_result 拿走
134
+ 即删除。
135
+ error: 终态为 FAILED / REJECTED / CANCELLED 时的简要错误信息
136
+ (同样会写入 `result`)。
137
+ result: 终态的 `WireResult`,由 `_run_task` 在各终态分支显式构造。
138
+ asyncio_task: 运行 `_run_task` 的底层 `asyncio.Task`,用于 cancel。
139
+ done_event: 到达终态时由 `_run_task` 的 `finally` 置位。`wait_task` /
140
+ `cleanup_tasks` 依赖它做异步等待。
141
+ transitions: 状态转移流水账。`transitions[0]` 为 SUBMITTED 初始记录,
142
+ 顺序即状态机实际走过的路径。
143
+ """
144
+
145
+ task_id: str
146
+ request: InferenceRequest
147
+ overrides: Optional[KittyEngineOverrides]
148
+ status: TaskStatus = TaskStatus.SUBMITTED
149
+ task_timeout: Optional[float] = None
150
+ persist: bool = False
151
+ error: Optional[str] = None
152
+ result: Optional[WireResult] = None
153
+ asyncio_task: Optional[asyncio.Task] = None
154
+ done_event: Optional[asyncio.Event] = None
155
+ transitions: list[TransitionRecord] = field(default_factory=list)
156
+ # 状态转移回调。由 `_KittyEngine.submit` 在构造后注入(当前只支持单订阅者);
157
+ # 不持有 worker 反向引用,避免循环。签名: (task_id, old, new) -> None;
158
+ # `old` 为 None 表示"首次落地"(由 `_KittyEngine.submit` 显式触发,不经 transition)。
159
+ on_transition: Optional[Callable[[str, Optional[TaskStatus], TaskStatus], None]] = None
160
+
161
+ def __post_init__(self) -> None:
162
+ # 初始状态(默认 SUBMITTED)作为第一条转移记录。
163
+ # 注意:这里直接 append,不走 transition(),因为此时没有"上一条"需要回填 duration,
164
+ # 且 transition() 的语义是"从旧状态进入新状态",不适用于首次落地。
165
+ if not self.transitions:
166
+ self.transitions.append(TransitionRecord(status=self.status, timestamp=_now_str(), desc="task created"))
167
+
168
+ # ------------------------------------------------------------------
169
+ # 状态转移(唯一入口)
170
+ # ------------------------------------------------------------------
171
+
172
+ def transition(self, new_status: TaskStatus, desc: str = "") -> None:
173
+ """更新 `status`、追加转移记录,并回填上一条的 `duration`。
174
+
175
+ 行为:
176
+ 1. 取当前时间 `now`(单次 `datetime.now()`);
177
+ 2. 把 `transitions[-1].duration` 置为 `now.timestamp() - prev.epoch`;
178
+ 3. 将 `self.status` 置为 `new_status`;
179
+ 4. 追加 `TransitionRecord(new_status, now.isoformat(...), desc=desc)`。
180
+
181
+ Args:
182
+ new_status: 要进入的新状态。
183
+ desc: 进入该状态的原因,便于日志 / 排障。
184
+
185
+ Note:
186
+ - 所有状态变更都应走这里,不要直接赋值 `self.status`,否则会丢失
187
+ 时间线。
188
+ - 注意:本模块的 `transition()` 约束仅限 `_WorkerTask`;
189
+ `TaskHandle`(GeneralEngine)当前仍直接赋值 `status`,未接入
190
+ transitions 流水账。两边状态机由各自维护。
191
+ - 不做状态合法性校验:调用方自己确保转移顺序合理(见模块开头状态机图)。
192
+ """
193
+ # 单次取 now,避免"写入字符串 → 再 fromisoformat 解回来"的冗余解析。
194
+ now = datetime.now()
195
+ now_str = now.isoformat(sep=" ", timespec="microseconds")
196
+ old = self.status
197
+ if self.transitions:
198
+ self.transitions[-1].duration = now.timestamp() - self.transitions[-1].epoch
199
+ self.status = new_status
200
+ self.transitions.append(TransitionRecord(status=new_status, timestamp=now_str, desc=desc))
201
+ # 通知回调。callback 异常绝不能污染 worker 状态机。
202
+ callback = self.on_transition
203
+ if callback is not None:
204
+ try:
205
+ callback(self.task_id, old, new_status)
206
+ except Exception:
207
+ logger.exception("transition callback 异常 (task=%s)", self.task_id)
208
+
209
+ def is_finished(self) -> bool:
210
+ """判断是否已到达终态(SUCCESS / FAILED / TIMEOUT / CANCELLED / REJECTED)。
211
+
212
+ Returns:
213
+ 到达终态返回 True,否则 False。
214
+ """
215
+ return self.status in (
216
+ TaskStatus.SUCCESS,
217
+ TaskStatus.FAILED,
218
+ TaskStatus.TIMEOUT,
219
+ TaskStatus.CANCELLED,
220
+ TaskStatus.REJECTED,
221
+ )
222
+
223
+ def snapshot(self) -> TaskSnapshot:
224
+ """生成可跨进程传输的 `TaskSnapshot`。
225
+
226
+ Returns:
227
+ 当前任务状态的快照。
228
+
229
+ Note:
230
+ - `transitions` 做**深拷贝**(逐条 `TransitionRecord.model_copy()`):
231
+ `transition()` 会原地修改 `transitions[-1].duration` 来回填上一个
232
+ 状态的持续时长;若只做浅拷贝,则已经交给客户端的 snapshot 末条
233
+ 记录会被后续 transition 调用"幽灵修改"。深拷贝彻底消除该隐患。
234
+ - 不回传 `result`:瘦身,避免每次 poll status 都搬运模型输出。
235
+ 客户端通过 `has_result` 判断是否需要单独 pop。
236
+ - 首条 SUBMITTED 记录由 `__post_init__` 在构造时自动追加,此处无需
237
+ 额外处理;REJECTED 分支里调用 `transition(REJECTED)` 时,会自然
238
+ 回填 SUBMITTED 的 duration,等价于"校验阶段耗时"。
239
+ """
240
+ return TaskSnapshot(
241
+ task_id=self.task_id,
242
+ status=self.status,
243
+ # 深拷贝:TransitionRecord 可被 transition() 原地改 duration,
244
+ # 不深拷贝会把已发出的 snapshot 一起带跑。
245
+ transitions=[r.model_copy() for r in self.transitions],
246
+ task_timeout=self.task_timeout,
247
+ persist=self.persist,
248
+ error=self.error,
249
+ has_result=self.result is not None,
250
+ )
251
+
252
+
253
+ class _KittyEngine:
254
+ """KittyEngine 后端进程的任务执行器。
255
+
256
+ 职责:
257
+ - 维护模型注册表 + 单个 `httpx.AsyncClient` + 全局并发信号量;
258
+ - 维护 `_tasks: dict[task_id, _WorkerTask]`;
259
+ - 暴露任务级 API(submit / cancel / snapshot / pop_result / wait_task /
260
+ cleanup_tasks / clear_done / force_remove)。
261
+
262
+ 使用约束:
263
+ - 必须在 asyncio 事件循环线程内调用所有方法(含同步方法:它们会创建
264
+ `asyncio.Task` / `asyncio.Event`,依赖当前 loop)。
265
+ - 生命周期:`setup()` 一次 → 正常工作 → `teardown()` 一次。重复 setup
266
+ / teardown 会 warning 但不抛错。
267
+ """
268
+
269
+ def __init__(self, config: KittyEngineConfig) -> None:
270
+ self.config = config
271
+
272
+ # —— 模型注册表(与 GeneralEngine 一致,reload 模式) ——
273
+ self.model_registry = ModelConfigRegistry()
274
+ self.model_registry.load_from_json(config.registry_path, mode="reload")
275
+
276
+ # —— 运行时资源,由 setup 初始化;teardown 置回 None / False ——
277
+ # _sem 用于 RUNNING 并发上限控制;PENDING → RUNNING 之间 await 它。
278
+ self._sem: Optional[asyncio.Semaphore] = None
279
+ # 单一 AsyncClient,连接复用;timeout 由 config 指定。
280
+ self._http_client: Optional[httpx.AsyncClient] = None
281
+ # 活跃任务表。终态任务是否保留取决于 persist。
282
+ self._tasks: dict[str, _WorkerTask] = {}
283
+ self._setup_done: bool = False
284
+ # 状态转移回调(单订阅者,一般由 KittyServer 注册,用于维护 per-connection
285
+ # 状态统计)。callback 在事件循环线程内被**同步**调用,不允许阻塞。
286
+ self._transition_callback: Optional[Callable[[str, Optional[TaskStatus], TaskStatus], None]] = None
287
+
288
+ def set_transition_callback(
289
+ self,
290
+ callback_function: Optional[Callable[[str, Optional[TaskStatus], TaskStatus], None]],
291
+ ) -> None:
292
+ """注册(或清除)状态转移回调。
293
+
294
+ 约定:
295
+ - 只能有一个订阅者;多次调用以最后一次为准。
296
+ - 设置时机应在 `submit` 首次调用之前,否则更早的任务不会带上 callback。
297
+ - callback 内不能抛异常(调用端会吞掉 + 记 exception 日志)。
298
+ """
299
+ self._transition_callback = callback_function
300
+
301
+ # ------------------------------------------------------------------
302
+ # 生命周期
303
+ # ------------------------------------------------------------------
304
+
305
+ async def setup(self) -> None:
306
+ """初始化 sem + AsyncClient。幂等(重复调用仅 warning)。"""
307
+ if self._setup_done:
308
+ logger.warning("KittyWorker 已启动,无需重复启动。")
309
+ return
310
+ self._sem = asyncio.Semaphore(self.config.max_global_concurrency)
311
+ self._http_client = httpx.AsyncClient(
312
+ timeout=httpx.Timeout(
313
+ timeout=self.config.http_client_connect_timeout,
314
+ read=self.config.http_client_read_timeout,
315
+ ),
316
+ limits=httpx.Limits(
317
+ max_connections=self.config.max_global_concurrency,
318
+ max_keepalive_connections=self.config.max_global_concurrency,
319
+ ),
320
+ )
321
+ self._setup_done = True
322
+ logger.info(
323
+ "KittyWorker 启动 (max_global_concurrency=%d, models=%d)",
324
+ self.config.max_global_concurrency,
325
+ len(self.model_registry.model_dict),
326
+ )
327
+
328
+ async def teardown(self) -> None:
329
+ """取消所有未完成任务并关闭 HTTP 客户端。幂等。
330
+
331
+ 流程:
332
+ 1. 收集所有未 finished 的 `asyncio_task`;
333
+ 2. 统一 `cancel()`;
334
+ 3. `gather(..., return_exceptions=True)` 等它们真正退出,防止在关闭
335
+ client 之后还有未结束协程尝试发请求;
336
+ 4. 关闭 `_http_client`。
337
+ """
338
+ if not self._setup_done:
339
+ logger.warning("KittyWorker 未启动,无需停止。")
340
+ return
341
+ # 取消未完成任务
342
+ pending = [t for t in self._tasks.values() if not t.is_finished() and t.asyncio_task is not None]
343
+ for t in pending:
344
+ t.asyncio_task.cancel() # type: ignore[union-attr]
345
+ if pending:
346
+ await asyncio.gather(*[t.asyncio_task for t in pending if t.asyncio_task], return_exceptions=True)
347
+ if self._http_client is not None:
348
+ await self._http_client.aclose()
349
+ self._http_client = None
350
+ self._setup_done = False
351
+ logger.info("KittyWorker 已停止")
352
+
353
+ # ------------------------------------------------------------------
354
+ # 任务 API(均在事件循环线程内调用)
355
+ # ------------------------------------------------------------------
356
+
357
+ def submit(
358
+ self,
359
+ *,
360
+ task_id: str,
361
+ request: InferenceRequest,
362
+ overrides: Optional[KittyEngineOverrides],
363
+ persist: bool,
364
+ timeout: Optional[float],
365
+ ) -> TaskSnapshot:
366
+ """提交一个任务并立刻返回其初始 snapshot。
367
+
368
+ 行为:
369
+ - 按 "request.timeout > submit timeout > overrides.default_timeout >
370
+ config.default_timeout" 的优先级解析 `task_timeout`;
371
+ - 构造 `_WorkerTask`(此时已经写入 SUBMITTED 记录);
372
+ - 创建 `done_event` 与 `asyncio_task`(= `_run_task(wt)`);
373
+ - 登记到 `_tasks`;
374
+ - 返回 snapshot 给调用方(此时 task 大概率还在 SUBMITTED;真正的
375
+ 校验 / ACCEPTED / REJECTED 由 `_run_task` 在下一个 tick 写入)。
376
+
377
+ Args:
378
+ task_id: 上层分配的任务 ID。
379
+ request: 推理请求。
380
+ overrides: 可选的 per-submit 覆盖配置。
381
+ persist: 终态后是否保留在 `_tasks` 表里。
382
+ timeout: 单次任务超时秒数(可被 `request.timeout` 覆盖)。
383
+
384
+ Returns:
385
+ 初始 `TaskSnapshot`。
386
+
387
+ Note:
388
+ - 不做同 `task_id` 的去重;上层 client/engine 负责保证唯一性。
389
+ - 不立刻做模型校验:校验在 `_run_task` 头部完成,REJECTED 也走统一
390
+ 的终态路径(置 result + transition + done_event.set)。
391
+ """
392
+ # 超时层级:request.timeout > submit timeout > overrides.default_timeout > config.default_timeout
393
+ effective_timeout = request.timeout
394
+ if effective_timeout is None:
395
+ effective_timeout = timeout
396
+ if effective_timeout is None and overrides is not None:
397
+ effective_timeout = overrides.default_timeout
398
+ if effective_timeout is None:
399
+ effective_timeout = self.config.default_timeout
400
+
401
+ wt = _WorkerTask(
402
+ task_id=task_id,
403
+ request=request,
404
+ overrides=overrides,
405
+ task_timeout=effective_timeout,
406
+ persist=persist,
407
+ )
408
+ # done_event 必须在 event loop 线程内创建,因此放到此处而非 __post_init__
409
+ wt.done_event = asyncio.Event()
410
+ wt.on_transition = self._transition_callback
411
+ # 注意:**先**登记到 _tasks,**再**显式 fire 一次"首次落地"事件(old=None),
412
+ # 最后才 create_task。这样:
413
+ # - 回调被触发时,`worker._tasks[task_id]` 已可见,callback 里若需要
414
+ # 回查 snapshot 不会 KeyError;
415
+ # - create_task 调度的 _run_task 只会在当前同步段结束后才实际执行,
416
+ # 因此它里面后续的 transition 事件一定晚于此处首次事件。
417
+ # 之所以不让 `__post_init__` 走 transition():初始记录不回填 duration、
418
+ # 没有"旧状态",语义上不是一次 transition。把首次 fire 放在这里,可以保证
419
+ # 订阅方看到任意 task 全生命周期的第一条事件总是由 worker 主动发出。
420
+ self._tasks[task_id] = wt
421
+ if self._transition_callback is not None:
422
+ try:
423
+ self._transition_callback(wt.task_id, None, wt.status)
424
+ except Exception:
425
+ logger.exception("transition callback 异常(initial, task=%s)", wt.task_id)
426
+ wt.asyncio_task = asyncio.create_task(self._run_task(wt))
427
+ return wt.snapshot()
428
+
429
+ def cancel(self, task_id: str, force: bool = False) -> bool:
430
+ """取消任务。
431
+
432
+ 语义:
433
+ - 任务不存在 / 已到终态:返回 False。
434
+ - RUNNING 中的任务:默认不取消(保护已经在打的 HTTP 请求),需要
435
+ `force=True` 才取消。
436
+ - 其它非终态(SUBMITTED / ACCEPTED / PENDING):直接取消
437
+ `asyncio_task`;`_run_task` 的 `except CancelledError` 分支会把
438
+ 状态推到 CANCELLED。
439
+
440
+ Args:
441
+ task_id: 目标任务 ID。
442
+ force: 是否强制取消 RUNNING 中的任务。
443
+
444
+ Returns:
445
+ True 表示已向 `asyncio_task` 发出 cancel。真正进入终态需等
446
+ `_run_task` 清理 + `done_event.set()`(可用 `wait_task` 同步等待)。
447
+ """
448
+ wt = self._tasks.get(task_id)
449
+ if wt is None or wt.is_finished():
450
+ return False
451
+ if wt.status == TaskStatus.RUNNING and not force:
452
+ return False
453
+ if wt.asyncio_task is not None:
454
+ wt.asyncio_task.cancel()
455
+ return True
456
+
457
+ def snapshot(self, task_id: str) -> Optional[TaskSnapshot]:
458
+ """返回当前任务的快照。
459
+
460
+ Args:
461
+ task_id: 目标任务 ID。
462
+
463
+ Returns:
464
+ 对应任务的 `TaskSnapshot`;任务不存在返回 None。
465
+ """
466
+ wt = self._tasks.get(task_id)
467
+ return wt.snapshot() if wt is not None else None
468
+
469
+ def pop_result(self, task_id: str) -> Optional[WireResult]:
470
+ """取结果。仅在终态有效。
471
+
472
+ Args:
473
+ task_id: 目标任务 ID。
474
+
475
+ Returns:
476
+ - 任务不存在 / 未终态 → None;
477
+ - 终态 + `persist=False` → 返回 result 并从 `_tasks` 删除该条目;
478
+ - 终态 + `persist=True` → 返回 result,条目保留(可再次 pop 到同一 result)。
479
+ """
480
+ wt = self._tasks.get(task_id)
481
+ if wt is None or not wt.is_finished():
482
+ return None
483
+ result = wt.result
484
+ if not wt.persist:
485
+ self._tasks.pop(task_id, None)
486
+ return result
487
+
488
+ def clear_done(self) -> list[str]:
489
+ """主动清理所有终态任务(含 persist)。
490
+
491
+ 典型用法:客户端做周期 GC,或上层显式要求丢弃历史。
492
+
493
+ Returns:
494
+ 被清理的任务 ID 列表。调用方据此同步自己的索引(例如 server 的
495
+ `_task_to_connection` 映射)。长度即 "实际清理的任务数"。
496
+ """
497
+ done = [tid for tid, wt in self._tasks.items() if wt.is_finished()]
498
+ for tid in done:
499
+ self._tasks.pop(tid, None)
500
+ return done
501
+
502
+ def force_remove(self, task_id: str) -> Optional[WireResult]:
503
+ """无视状态强制移除条目。
504
+
505
+ 与 `cancel` 的区别:cancel 只是触发取消流程、等待终态;force_remove
506
+ 直接从 `_tasks` 摘除,不管任务有没有跑完。一般只用于"客户端已断线、
507
+ 条目不再有意义"的兜底清理。
508
+
509
+ Args:
510
+ task_id: 目标任务 ID。
511
+
512
+ Returns:
513
+ 该任务的 `WireResult`(若已有),否则 None。
514
+ """
515
+ wt = self._tasks.pop(task_id, None)
516
+ return wt.result if wt is not None else None
517
+
518
+ async def wait_task(self, task_id: str, timeout: Optional[float]) -> tuple[Optional[WireResult], bool]:
519
+ """异步等待任务终态后 pop result。
520
+
521
+ 依赖:`done_event` 在 `_run_task` 的 `finally` 里被 set;所以所有终态
522
+ 路径都能唤醒本方法。
523
+
524
+ Args:
525
+ task_id: 目标任务 ID。
526
+ timeout: 等待秒数;None 表示不限时。
527
+
528
+ Returns:
529
+ `(result, timed_out)`:
530
+
531
+ - 任务不存在 → `(None, False)`;
532
+ - 超时 → `(None, True)`(task 仍在进行,状态不变);
533
+ - 正常 → `(result, False)`;`result` 的 pop 遵循 `pop_result` 的
534
+ persist 语义。
535
+ """
536
+ wt = self._tasks.get(task_id)
537
+ if wt is None:
538
+ return None, False
539
+ if not wt.is_finished():
540
+ if wt.done_event is None:
541
+ return None, False
542
+ try:
543
+ await asyncio.wait_for(wt.done_event.wait(), timeout=timeout)
544
+ except asyncio.TimeoutError:
545
+ return None, True
546
+ return self.pop_result(task_id), False
547
+
548
+ async def cleanup_tasks(self, task_ids, *, wait_timeout: float = 5.0) -> None:
549
+ """批量取消一组任务并等其真正进入终态;非 persist 的一并摘除。
550
+
551
+ 用于客户端断连、engine 主动 drain 等场景。使用 `force=True` 以便 RUNNING
552
+ 任务也会被取消。
553
+
554
+ Args:
555
+ task_ids: 要清理的任务 ID 序列。
556
+ wait_timeout: 等待全部任务进入终态的整体超时秒数。
557
+ """
558
+ tids = [t for t in task_ids if t in self._tasks]
559
+ if not tids:
560
+ return
561
+ events: list[asyncio.Event] = []
562
+ for tid in tids:
563
+ wt = self._tasks.get(tid)
564
+ if wt is None:
565
+ continue
566
+ self.cancel(tid, force=True)
567
+ if not wt.is_finished() and wt.done_event is not None:
568
+ events.append(wt.done_event)
569
+ if events:
570
+ try:
571
+ await asyncio.wait_for(
572
+ asyncio.gather(*[e.wait() for e in events], return_exceptions=True),
573
+ timeout=wait_timeout,
574
+ )
575
+ except asyncio.TimeoutError:
576
+ logger.warning("cleanup_tasks 等待 %d 个任务终态超时", len(events))
577
+ for tid in tids:
578
+ wt = self._tasks.get(tid)
579
+ if wt is not None and not wt.persist:
580
+ self._tasks.pop(tid, None)
581
+
582
+ # ------------------------------------------------------------------
583
+ # 任务执行
584
+ # ------------------------------------------------------------------
585
+
586
+ async def _run_task(self, wt: _WorkerTask) -> None:
587
+ """单个任务的完整生命周期协程。状态机在此闭合。
588
+
589
+ 阶段:
590
+ 1. 模型解析:`request.model_name` > `overrides.default_model` >
591
+ `config.default_model`,全空则 REJECTED。
592
+ 2. 模型校验:在 `model_registry` 中查表,未命中则 REJECTED。
593
+ 3. `transition(ACCEPTED)`:通过校验。
594
+ 4. `transition(PENDING)` → `async with self._sem` →
595
+ `transition(RUNNING)`:排队 → 持有 sem → 开始执行。
596
+ 5. `async with asyncio.timeout(wt.task_timeout):` 内 `await self._send_request(...)`:
597
+ 发 HTTP + 解析;
598
+ - 正常 → 填 `result` + `transition(SUCCESS)`;
599
+ - `CancelledError` → 填 `CANCELLED` result + transition,再 raise
600
+ 让外层 task 正确结束;此处可确定是外部 cancel,因为
601
+ `asyncio.timeout` 超时直接抛 `TimeoutError`,不经 inner cancel;
602
+ - `TimeoutError` → TIMEOUT;
603
+ - 其它异常 → FAILED,error=`type: msg`。
604
+ 6. 外层 `except CancelledError`:覆盖 ACCEPTED / PENDING 阶段(尚未
605
+ 进入 RUNNING)被取消的情形。若此时还不是终态,补写 CANCELLED。再
606
+ raise。
607
+ 7. `finally`:无论哪条路径,只要 `done_event` 存在就 set,唤醒所有
608
+ `wait_task` / `cleanup_tasks`。
609
+
610
+ 与其它组件的配合要点:
611
+ - 每个终态分支同时写 `wt.error` / `wt.result` / `transition`,
612
+ 缺一不可:
613
+ - `wt.result` 要带正确的 `status` 字段,client 据此同步
614
+ snapshot.status;
615
+ - `transition` 保证 `transitions` 的时间线闭合;
616
+ - `done_event.set()` 保证等待方被唤醒。
617
+ - 绝不在此函数外修改 `wt.status`(见 `_WorkerTask.transition` 的
618
+ 约束)。
619
+
620
+ Args:
621
+ wt: 要驱动的任务实例,已由 `submit` 登记到 `_tasks`。
622
+ """
623
+ request = wt.request
624
+ overrides = wt.overrides
625
+
626
+ # ---- 1. 模型名解析 ----
627
+ model_name = request.model_name
628
+ if model_name is None and overrides is not None:
629
+ model_name = overrides.default_model
630
+ if model_name is None:
631
+ model_name = self.config.default_model
632
+ if model_name is None:
633
+ wt.error = "no model_name (request/overrides/config all empty)"
634
+ wt.result = WireResult(success=False, task_id=wt.task_id, status=TaskStatus.REJECTED, error_message=wt.error)
635
+ wt.transition(TaskStatus.REJECTED, desc="no model_name")
636
+ if wt.done_event is not None:
637
+ wt.done_event.set()
638
+ return
639
+
640
+ # ---- 2. 模型查表校验 ----
641
+ try:
642
+ model = self.model_registry.get(name=model_name)
643
+ except KeyError as e:
644
+ wt.error = str(e)
645
+ wt.result = WireResult(success=False, task_id=wt.task_id, status=TaskStatus.REJECTED, error_message=wt.error)
646
+ wt.transition(TaskStatus.REJECTED, desc=f"model not found: {model_name}")
647
+ if wt.done_event is not None:
648
+ wt.done_event.set()
649
+ return
650
+
651
+ # ---- 3. 校验通过 ----
652
+ # ACCEPTED 与下一行 PENDING 在正常路径下紧邻;保留两条独立转移,便于
653
+ # 外部观察"校验耗时"与"排队耗时"的界限。
654
+ wt.transition(TaskStatus.ACCEPTED, desc="validation passed")
655
+
656
+ assert self._sem is not None
657
+ try:
658
+ # ---- 4. 排队等 sem ----
659
+ wt.transition(TaskStatus.PENDING, desc="waiting for sem")
660
+ async with self._sem:
661
+ # ---- 5. 持有 sem,开跑 ----
662
+ wt.transition(TaskStatus.RUNNING, desc="sem acquired")
663
+ run_start = time.time()
664
+ try:
665
+ # 用 asyncio.timeout 而非 asyncio.wait_for:前者超时时直接
666
+ # 抛 TimeoutError,不会经由内层 CancelledError 上传,因此
667
+ # `except asyncio.CancelledError` 分支可以确定是"外部 cancel",
668
+ # 避免 timeout 与 cancel 路径竞争写两次终态(参见 issue 历史)。
669
+ async with asyncio.timeout(wt.task_timeout):
670
+ model_output = await self._send_request(request, model, overrides)
671
+ wt.result = WireResult(
672
+ success=True,
673
+ task_id=wt.task_id,
674
+ status=TaskStatus.SUCCESS,
675
+ model_output=model_output,
676
+ duration=time.time() - run_start,
677
+ )
678
+ wt.transition(TaskStatus.SUCCESS, desc="completed")
679
+ except asyncio.CancelledError:
680
+ # RUNNING 中被外部 cancel:先写终态(外层 finally 统一触发 done_event),再继续 raise。
681
+ wt.error = "cancelled"
682
+ wt.result = WireResult(success=False, task_id=wt.task_id, status=TaskStatus.CANCELLED, error_message="cancelled")
683
+ wt.transition(TaskStatus.CANCELLED, desc="cancelled while running")
684
+ logger.info("任务 %s 已取消(RUNNING)", wt.task_id)
685
+ raise
686
+ except TimeoutError:
687
+ # asyncio.timeout 超时;与 CancelledError 分支互斥,不会重复写。
688
+ wt.error = "timeout"
689
+ wt.result = WireResult(success=False, task_id=wt.task_id, status=TaskStatus.TIMEOUT, error_message="timeout")
690
+ wt.transition(TaskStatus.TIMEOUT, desc="timeout")
691
+ except Exception as e:
692
+ logger.exception("任务 %s 执行异常", wt.task_id)
693
+ wt.error = f"{type(e).__name__}: {e}"
694
+ wt.result = WireResult(success=False, task_id=wt.task_id, status=TaskStatus.FAILED, error_message=wt.error)
695
+ wt.transition(TaskStatus.FAILED, desc=wt.error)
696
+ except asyncio.CancelledError:
697
+ # ACCEPTED / PENDING 阶段(等 sem 时)被取消:此处兜底,保证状态进入终态。
698
+ # 若 RUNNING 的分支已经写过 CANCELLED,则 is_finished() 为 True,这里不会重复写。
699
+ if not wt.is_finished():
700
+ wt.error = "cancelled"
701
+ wt.result = WireResult(success=False, task_id=wt.task_id, status=TaskStatus.CANCELLED, error_message="cancelled")
702
+ wt.transition(TaskStatus.CANCELLED, desc="cancelled while pending")
703
+ logger.info("任务 %s 已取消(PENDING)", wt.task_id)
704
+ raise
705
+ finally:
706
+ # 无论正常 / 异常 / 取消,所有终态路径都在此唤醒等待方。
707
+ if wt.done_event is not None:
708
+ wt.done_event.set()
709
+
710
+ # ------------------------------------------------------------------
711
+ # HTTP 核心
712
+ #
713
+ # 下面这组方法负责把 (request, model, overrides, config) 四层配置压平成
714
+ # 一次 HTTP 调用,并处理重试。优先级统一为:
715
+ # request > model > overrides > config
716
+ # 仅对列表/字典采用 **merge**(后者覆盖前者);对单值(api_key、默认 IP 等)
717
+ # 采用"优先级高者非空即返回"。
718
+ # ------------------------------------------------------------------
719
+
720
+ def _resolve_default_ip(
721
+ self,
722
+ overrides: Optional[KittyEngineOverrides],
723
+ ) -> Optional[InferenceParameters]:
724
+ """解析"默认推理参数"层,overrides 优先于 config。
725
+
726
+ Args:
727
+ overrides: per-submit 覆盖配置。
728
+
729
+ Returns:
730
+ 选中的默认 `InferenceParameters`;全空则 None。
731
+
732
+ Note:
733
+ model 和 request 的 inference_parameters 在 `_build_payload` 里
734
+ 另行 merge,不在此处返回。
735
+ """
736
+ if overrides is not None and overrides.default_inference_parameters is not None:
737
+ return overrides.default_inference_parameters
738
+ return self.config.default_inference_parameters
739
+
740
+ def _get_api_key(self, request: InferenceRequest, model: ModelConfig, overrides: Optional[KittyEngineOverrides]) -> str:
741
+ """解析 API Key,按 request > model > overrides > config 取第一个非 None。
742
+
743
+ 全部为空时返回空字符串并 warning(有些自研网关允许无 key;严格环境应在
744
+ 上游校验)。
745
+
746
+ Args:
747
+ request: 推理请求。
748
+ model: 已解析的模型配置。
749
+ overrides: per-submit 覆盖配置。
750
+
751
+ Returns:
752
+ 选中的 API Key 字符串。
753
+ """
754
+ if request.api_key is not None:
755
+ return request.api_key
756
+ if model.api_key is not None:
757
+ return model.api_key
758
+ if overrides is not None and overrides.default_api_key is not None:
759
+ return overrides.default_api_key
760
+ if self.config.default_api_key is not None:
761
+ return self.config.default_api_key
762
+ logger.warning("没有找到可用的 api_key, 将传递空 api_key")
763
+ return ""
764
+
765
+ def _build_headers(self, request: InferenceRequest, model: ModelConfig, overrides: Optional[KittyEngineOverrides]) -> dict[str, str]:
766
+ """组装请求头。
767
+
768
+ 顺序:基础(Authorization + Content-Type)→ config.extra_headers →
769
+ overrides.extra_headers → model.extra_headers → request.extra_headers。
770
+ 后写入的 key 会覆盖前者。
771
+
772
+ Args:
773
+ request: 推理请求。
774
+ model: 已解析的模型配置。
775
+ overrides: per-submit 覆盖配置。
776
+
777
+ Returns:
778
+ 合并好的 headers 字典。
779
+ """
780
+ headers: dict[str, str] = {
781
+ "Authorization": "Bearer " + self._get_api_key(request, model, overrides),
782
+ "Content-Type": "application/json",
783
+ }
784
+ headers.update(self.config.extra_headers)
785
+ if overrides is not None:
786
+ headers.update(overrides.extra_headers)
787
+ headers.update(model.extra_headers)
788
+ headers.update(request.extra_headers)
789
+ return headers
790
+
791
+ def _build_payload(self, request: InferenceRequest, model: ModelConfig, overrides: Optional[KittyEngineOverrides]) -> dict[str, Any]:
792
+ """组装 OpenAI 兼容 payload。
793
+
794
+ 合并顺序(后者覆盖前者):
795
+ 1. 骨架:`model` / `stream` / `messages`;
796
+ 2. default_inference_parameters(overrides / config 二选一,见
797
+ `_resolve_default_ip`);
798
+ 3. `model.default_inference_parameters`;
799
+ 4. `request.inference_parameters`;
800
+ 5. `config.extra_payload`;
801
+ 6. `overrides.extra_payload`;
802
+ 7. `model.extra_payload`;
803
+ 8. `request.extra_payload`。
804
+
805
+ Args:
806
+ request: 推理请求。
807
+ model: 已解析的模型配置。
808
+ overrides: per-submit 覆盖配置。
809
+
810
+ Returns:
811
+ 合并好的 payload 字典。
812
+ """
813
+ payload: dict[str, Any] = {
814
+ "model": model.model_id,
815
+ "stream": request.stream if request.stream is not None else False,
816
+ "messages": [msg.to_dict() for msg in request.messages],
817
+ }
818
+ default_ip = self._resolve_default_ip(overrides)
819
+ if default_ip:
820
+ payload.update(default_ip.to_dict())
821
+ if model.default_inference_parameters:
822
+ payload.update(model.default_inference_parameters.to_dict())
823
+ if request.inference_parameters:
824
+ payload.update(request.inference_parameters.to_dict())
825
+ payload.update(self.config.extra_payload)
826
+ if overrides is not None:
827
+ payload.update(overrides.extra_payload)
828
+ payload.update(model.extra_payload)
829
+ payload.update(request.extra_payload)
830
+ return payload
831
+
832
+ def _get_wait_time(self, status_code: int, attempt: int, base_delay: int) -> float:
833
+ """按 HTTP 状态码决定下一次重试前的等待秒数。
834
+
835
+ Args:
836
+ status_code: 上一次 HTTP 响应状态码。
837
+ attempt: 当前已尝试的次数(0-based)。
838
+ base_delay: 指数退避基数。
839
+
840
+ Returns:
841
+ 等待秒数。
842
+
843
+ - 429(限流):固定 1 秒,避免指数退避把 QPS 压到 0;
844
+ - 408(请求超时):按指数退避;
845
+ - 5xx(服务端错误):指数退避 `base_delay ** attempt`;
846
+ - 其它 4xx:调用方应在外层判定为不可重试,不会进到这里。
847
+
848
+ Note:
849
+ 4xx 非 429 / 408 视为客户端错误(鉴权 / 参数 / 资源不存在等),
850
+ 重试只是浪费时间。判断逻辑由 `_send_request` 在调用本函数之前
851
+ 完成,本函数仅给"应当重试的状态码"算等待时长。
852
+ """
853
+ if status_code == 429:
854
+ return 1.0
855
+ if status_code == 408 or status_code >= 500:
856
+ return float(base_delay**attempt)
857
+ # 理论上不可达:调用方已用 `_is_retriable_status` 把其它 4xx 拦在外面。
858
+ # 留作 defensive 兜底,避免未来调用点漏判时退化成热循环。
859
+ return 5.0
860
+
861
+ @staticmethod
862
+ def _is_retriable_status(status_code: int) -> bool:
863
+ """HTTP 状态码是否值得重试。
864
+
865
+ - 429 / 408:限流 / 请求超时,可重试。
866
+ - 5xx:服务端错误,可重试。
867
+ - 其它 4xx:客户端错误(401/403/404/422 …),重试无意义,立刻失败。
868
+ """
869
+ return status_code in (408, 429) or status_code >= 500
870
+
871
+ def _resolve_retry_params(self, overrides: Optional[KittyEngineOverrides]) -> tuple[int, int]:
872
+ """解析 `(max_retries, base_delay)`,overrides 中非 None 的字段覆盖 config。
873
+
874
+ Args:
875
+ overrides: per-submit 覆盖配置。
876
+
877
+ Returns:
878
+ `(max_retries, base_delay)` 二元组。
879
+ """
880
+ max_retries = self.config.default_max_retries
881
+ base_delay = self.config.default_base_delay
882
+ if overrides is not None:
883
+ if overrides.default_max_retries is not None:
884
+ max_retries = overrides.default_max_retries
885
+ if overrides.default_base_delay is not None:
886
+ base_delay = overrides.default_base_delay
887
+ return max_retries, base_delay
888
+
889
+ def build_request(
890
+ self,
891
+ request: InferenceRequest,
892
+ overrides: Optional[KittyEngineOverrides] = None,
893
+ ) -> PreparedRequest:
894
+ """组装最终发给大模型服务商的请求,不发出任何网络调用。
895
+
896
+ 用于调试:验证 URL、Headers、Payload 是否符合预期。
897
+
898
+ Args:
899
+ request (InferenceRequest): 推理请求对象。
900
+ overrides (Optional[KittyEngineOverrides]): 覆盖配置,默认为 None。
901
+
902
+ Returns:
903
+ PreparedRequest: 包含模型名、URLs、请求头和负载的已准备请求。
904
+ """
905
+
906
+ model_name = request.model_name
907
+ if not model_name:
908
+ model_name = overrides.default_model if (overrides and overrides.default_model) else self.config.default_model
909
+ if not model_name:
910
+ raise ValueError("未指定 model_name 且 config 中无 default_model")
911
+ model = self.model_registry.get(name=model_name)
912
+ if model is None:
913
+ raise ValueError(f"模型 '{model_name}' 不在注册表中")
914
+ return PreparedRequest(
915
+ model_name=model_name,
916
+ url="",
917
+ urls=model.api_urls,
918
+ headers=self._build_headers(request, model, overrides),
919
+ payload=self._build_payload(request, model, overrides),
920
+ )
921
+
922
+ def mock_request(
923
+ self,
924
+ request: InferenceRequest,
925
+ overrides: Optional[KittyEngineOverrides] = None,
926
+ ) -> MockRequest:
927
+ """组装最终发给大模型服务商的请求,不发出任何网络调用。
928
+
929
+ 用于调试:验证 URL、Headers、Payload 是否符合预期。
930
+
931
+ Args:
932
+ request (InferenceRequest): 推理请求对象。
933
+ overrides (Optional[KittyEngineOverrides]): 覆盖配置,默认为 None。
934
+
935
+ Returns:
936
+ MockRequest: 包含 URLs、模型名、请求头和负载的模拟请求。
937
+ """
938
+ prepared: PreparedRequest = self.build_request(request, overrides=overrides)
939
+ return MockRequest(
940
+ urls=prepared.urls,
941
+ model=prepared.model_name,
942
+ headers=prepared.headers,
943
+ payload=prepared.payload,
944
+ )
945
+
946
+ async def _send_request(
947
+ self,
948
+ request: InferenceRequest,
949
+ model: ModelConfig,
950
+ overrides: Optional[KittyEngineOverrides],
951
+ ) -> ModelOutput:
952
+ """一次完整的"带重试的 HTTP 调用 + 响应解析"。
953
+
954
+ 重试策略:
955
+ 共 `max_retries` 次尝试(非 `max_retries + 1`)。每次尝试内:
956
+
957
+ - 网络异常 (NetworkError / TimeoutException):
958
+ `wait = base_delay ** attempt` 后重试;最后一次直接 raise。
959
+ - HTTP 不可重试状态(4xx 非 408/429):立即 raise
960
+ `httpx.HTTPStatusError`,不消耗剩余 attempts(鉴权/参数错重试
961
+ 无意义)。
962
+ - HTTP 可重试状态(408 / 429 / 5xx):按 `_get_wait_time`
963
+ 等待后重试;最后一次 raise `httpx.HTTPStatusError`。
964
+ - 响应 JSON 解析失败:等 `base_delay` 后重试;最后一次 raise
965
+ `RuntimeError`。
966
+ - choices 为空(疑似风控):等 `base_delay` 后重试;最后一次 raise
967
+ `RuntimeError`。
968
+ - finish_reason != 'stop':warning 但不重试,视为正常返回(例如
969
+ length、tool_calls 等)。
970
+
971
+ 与 model 的协作:
972
+ - `model.get_url()` 在每次 attempt 内取 URL(支持多节点轮询 /
973
+ 负载均衡);
974
+ - 必须在 `finally` 中 `model.release_url(url)` 释放该节点的占用
975
+ 计数,否则会把节点计数泄漏,最终把模型自己的并发打爆。
976
+
977
+ 外部超时:
978
+ 本方法自身不设顶层超时;由 `_run_task` 用
979
+ `asyncio.wait_for(..., timeout=wt.task_timeout)` 包裹,超时会转成
980
+ `asyncio.TimeoutError`。
981
+
982
+ Stream:
983
+ 首版不支持 stream;若 payload `stream=True` 直接抛
984
+ `NotImplementedError`。
985
+
986
+ Args:
987
+ request: 推理请求。
988
+ model: 已解析的模型配置。
989
+ overrides: per-submit 覆盖配置。
990
+
991
+ Returns:
992
+ 解析完成的 `ModelOutput`。
993
+
994
+ Raises:
995
+ NotImplementedError: payload 中 stream=True。
996
+ httpx.HTTPStatusError: 最后一次 attempt 仍为非 200。
997
+ httpx.NetworkError: 最后一次 attempt 仍为网络异常。
998
+ httpx.TimeoutException: 最后一次 attempt 仍为读超时。
999
+ RuntimeError: 最后一次 attempt 响应解析失败 / choices 为空。
1000
+ """
1001
+ assert self._http_client is not None
1002
+ client = self._http_client
1003
+
1004
+ prepared = self.build_request(request, overrides)
1005
+ if prepared.payload["stream"]:
1006
+ raise NotImplementedError("KittyEngine 暂未实现 stream 模式")
1007
+ max_retries, base_delay = self._resolve_retry_params(overrides)
1008
+
1009
+ for attempt in range(max_retries):
1010
+ is_last = attempt == max_retries - 1
1011
+ url = model.get_url()
1012
+ try:
1013
+ resp = await client.post(url, json=prepared.payload, headers=prepared.headers)
1014
+ if resp.status_code != 200:
1015
+ body = resp.text
1016
+ # 不可重试的 4xx 直接抛,不浪费剩余 attempts。
1017
+ if not self._is_retriable_status(resp.status_code):
1018
+ logger.error("HTTP %d (不可重试),响应: %s", resp.status_code, body[:500])
1019
+ raise httpx.HTTPStatusError(f"HTTP {resp.status_code}", request=resp.request, response=resp)
1020
+ if is_last:
1021
+ logger.error("HTTP %d,已达最大重试次数,响应: %s", resp.status_code, body[:500])
1022
+ raise httpx.HTTPStatusError(f"HTTP {resp.status_code}", request=resp.request, response=resp)
1023
+ wait_time = self._get_wait_time(resp.status_code, attempt, base_delay)
1024
+ logger.warning(
1025
+ "HTTP %d, 第 %d 次重试,等待 %.1fs... 响应: %s",
1026
+ resp.status_code,
1027
+ attempt + 1,
1028
+ wait_time,
1029
+ body[:500],
1030
+ )
1031
+ await asyncio.sleep(wait_time)
1032
+ continue
1033
+
1034
+ try:
1035
+ parsed = ChatCompletionResponse.model_validate_json(resp.text)
1036
+ except ValidationError as ve:
1037
+ if is_last:
1038
+ raise RuntimeError(f"响应 JSON 解析失败,已达最大重试次数: {ve}; body: {resp.text[:500]}") from ve
1039
+ logger.warning("响应 JSON 解析失败,第 %d 次重试: %r", attempt + 1, ve)
1040
+ await asyncio.sleep(base_delay)
1041
+ continue
1042
+
1043
+ if not parsed.choices:
1044
+ if is_last:
1045
+ raise RuntimeError(f"empty choices after {max_retries} attempts, body: {resp.text[:500]}")
1046
+ logger.warning("resp.choices 为空,疑似风控,第 %d 次重试...", attempt + 1)
1047
+ await asyncio.sleep(base_delay)
1048
+ continue
1049
+ choice = parsed.choices[0]
1050
+ if choice.finish_reason and choice.finish_reason != "stop":
1051
+ # 不重试:length / tool_calls / content_filter 等属业务层语义,交由上层判断。
1052
+ logger.warning(
1053
+ "finish_reason='%s' (非 stop), content_len=%d, usage=%s",
1054
+ choice.finish_reason,
1055
+ len(choice.message.content or ""),
1056
+ parsed.usage.model_dump() if parsed.usage else None,
1057
+ )
1058
+ return ModelOutput(
1059
+ role=choice.message.role,
1060
+ content=choice.message.content,
1061
+ reasoning=choice.message.reasoning_content,
1062
+ finish_reason=choice.finish_reason,
1063
+ usage=parsed.usage.model_dump() if parsed.usage else None,
1064
+ )
1065
+ except (httpx.NetworkError, httpx.TimeoutException) as e:
1066
+ if is_last:
1067
+ logger.error("达到最大重试次数,最后错误: %r", e)
1068
+ raise
1069
+ wait_time = base_delay**attempt
1070
+ logger.warning("网络异常: %r, 等待 %.1fs 后重试...", e, wait_time)
1071
+ await asyncio.sleep(wait_time)
1072
+ finally:
1073
+ # 必须释放节点占用,否则 ModelConfig 的负载均衡计数会泄漏。
1074
+ model.release_url(url)
1075
+
1076
+ # 理论上到不了这里:要么在循环内 return,要么最后一次 attempt raise。
1077
+ raise RuntimeError("Unexpected end of retry loop")