llm-engine-kitty 0.1.0.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llm_engine/__init__.py +54 -0
- llm_engine/engine.py +771 -0
- llm_engine/general_engine.py +562 -0
- llm_engine/kitty/__init__.py +8 -0
- llm_engine/kitty/__main__.py +46 -0
- llm_engine/kitty/client.py +550 -0
- llm_engine/kitty/config.py +83 -0
- llm_engine/kitty/engine.py +1077 -0
- llm_engine/kitty/protocol.py +213 -0
- llm_engine/kitty/schemas.py +89 -0
- llm_engine/kitty/server.py +408 -0
- llm_engine/model_config.py +112 -0
- llm_engine/schemas.py +251 -0
- llm_engine/utils.py +34 -0
- llm_engine_kitty-0.1.0.dev0.dist-info/METADATA +15 -0
- llm_engine_kitty-0.1.0.dev0.dist-info/RECORD +18 -0
- llm_engine_kitty-0.1.0.dev0.dist-info/WHEEL +5 -0
- llm_engine_kitty-0.1.0.dev0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,562 @@
|
|
|
1
|
+
# llm_engine/general_engine.py
|
|
2
|
+
|
|
3
|
+
"""
|
|
4
|
+
GeneralEngine:自包含的任务化推理引擎。
|
|
5
|
+
|
|
6
|
+
- 后台维护一个 asyncio 事件循环(独立 daemon 线程)
|
|
7
|
+
- submit() 返回 TaskHandle,非阻塞
|
|
8
|
+
- wait()/get() 消费结果;poll() 非消费快照
|
|
9
|
+
- inference()/batch_inference() 为同步便捷接口
|
|
10
|
+
- 不依赖 SimpleEngine / SimpleCoroutineEngine
|
|
11
|
+
|
|
12
|
+
详细设计见 docs/general_engine_design.md
|
|
13
|
+
"""
|
|
14
|
+
|
|
15
|
+
import asyncio
|
|
16
|
+
import httpx
|
|
17
|
+
import threading
|
|
18
|
+
import time
|
|
19
|
+
|
|
20
|
+
import kitty_logger
|
|
21
|
+
from pydantic import ValidationError
|
|
22
|
+
from tqdm import tqdm
|
|
23
|
+
from typing import IO, Any, TypeAlias
|
|
24
|
+
|
|
25
|
+
from .model_config import ModelConfig, ModelConfigRegistry
|
|
26
|
+
from .schemas import (
|
|
27
|
+
ChatCompletionChoice,
|
|
28
|
+
ChatCompletionResponse,
|
|
29
|
+
InferenceParameters,
|
|
30
|
+
InferenceRequest,
|
|
31
|
+
InferenceRequestResult,
|
|
32
|
+
Message,
|
|
33
|
+
MessageRole,
|
|
34
|
+
ModelOutput,
|
|
35
|
+
PreparedRequest,
|
|
36
|
+
TaskHandle,
|
|
37
|
+
TaskStatus,
|
|
38
|
+
)
|
|
39
|
+
from .utils import gen_unique_id
|
|
40
|
+
|
|
41
|
+
logger = kitty_logger.getLogger(__name__)
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
Key: TypeAlias = str | TaskHandle
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
class GeneralEngine:
|
|
48
|
+
|
|
49
|
+
def __init__(
    self,
    model_registry: ModelConfigRegistry,
    default_model: str,
    default_api_key: str | None = None,
    default_inference_parameters: InferenceParameters | None = None,
    extra_headers: dict[str, str] | None = None,
    extra_payload: dict[str, Any] | None = None,
    stream: bool = False,
    max_global_concurrency: int = 64,
    default_timeout: float | None = None,
    default_max_retries: int = 5,
    default_base_delay: int = 2,
) -> None:
    """Record the engine configuration and create empty runtime state.

    Nothing is started here; the event loop, its thread and the global
    semaphore stay ``None`` until ``setup()`` fills them in.

    Args:
        model_registry: Registry used to look models up by name.
        default_model: Name of the model used when a request names none;
            it is resolved through ``model_registry.get`` immediately.
        default_api_key: Engine-level fallback API key.
        default_inference_parameters: Engine-level inference parameters
            merged into every payload first.
        extra_headers: Engine-level HTTP headers (``None`` becomes ``{}``).
        extra_payload: Engine-level payload fields (``None`` becomes ``{}``).
        stream: Default value for the payload ``stream`` flag.
        max_global_concurrency: Size of the global semaphore created in
            ``setup()``.
        default_timeout: Per-task timeout used when neither the request nor
            ``submit()`` provides one.
        default_max_retries: Number of attempts made per request.
        default_base_delay: Base used for the retry back-off.
    """
    # -- runtime state (populated by setup) --
    self._running: bool = False
    self._loop: asyncio.AbstractEventLoop | None = None
    self._loop_thread: threading.Thread | None = None
    self._global_sem: asyncio.Semaphore | None = None

    # Registry of submitted tasks keyed by task id, guarded by a thread lock
    # because it is touched from caller threads.
    self._handles_lock: threading.Lock = threading.Lock()
    self._handles: dict[str, TaskHandle] = {}

    # -- static configuration --
    self.model_registry = model_registry
    self.default_model_name = default_model
    self.default_model = model_registry.get(name=default_model)
    self.default_api_key = default_api_key
    self.default_inference_parameters = default_inference_parameters
    self.extra_headers = extra_headers if extra_headers else {}
    self.extra_payload = extra_payload if extra_payload else {}
    self.stream = stream
    self.max_global_concurrency = max_global_concurrency
    self.default_timeout = default_timeout
    self.default_max_retries = default_max_retries
    self.default_base_delay = default_base_delay
|
|
86
|
+
|
|
87
|
+
# ------------------------------------------------------------------
|
|
88
|
+
# 生命周期
|
|
89
|
+
# ------------------------------------------------------------------
|
|
90
|
+
|
|
91
|
+
def setup(self) -> None:
    """Start the background event-loop thread.

    A repeated call while the engine is already running only logs a warning
    and returns. Otherwise a daemon thread is started that creates a new
    event loop and the global semaphore, signals readiness, and then runs
    the loop until it is stopped. The caller blocks until readiness is
    signalled.
    """
    if self._running:
        logger.warning(msg="GeneralEngine 已启动,请勿重复调用 setup()")
        return

    loop_ready = threading.Event()

    def _loop_main() -> None:
        new_loop = asyncio.new_event_loop()
        asyncio.set_event_loop(new_loop)
        self._loop = new_loop
        # The asyncio primitive is created on the loop thread itself.
        self._global_sem = asyncio.Semaphore(value=self.max_global_concurrency)
        loop_ready.set()
        try:
            new_loop.run_forever()
        finally:
            # After run_forever() returns, cancel whatever is still scheduled
            # and let the cancellations finish before closing the loop.
            try:
                leftovers = asyncio.all_tasks(new_loop)
                for leftover in leftovers:
                    leftover.cancel()
                if leftovers:
                    new_loop.run_until_complete(asyncio.gather(*leftovers, return_exceptions=True))
            finally:
                new_loop.close()

    worker = threading.Thread(target=_loop_main, name="GeneralEngine-Loop", daemon=True)
    self._loop_thread = worker
    worker.start()
    loop_ready.wait()
    self._running = True
    logger.info("GeneralEngine 已启动 (max_global_concurrency=%d)", self.max_global_concurrency)
|
|
124
|
+
|
|
125
|
+
def teardown(self, cancel_running: bool = True, timeout: float = 10.0) -> None:
|
|
126
|
+
"""停止后台事件循环,回收线程。"""
|
|
127
|
+
if not self._running:
|
|
128
|
+
return
|
|
129
|
+
|
|
130
|
+
assert self._loop is not None
|
|
131
|
+
assert self._loop_thread is not None
|
|
132
|
+
|
|
133
|
+
if cancel_running:
|
|
134
|
+
with self._handles_lock:
|
|
135
|
+
handles: list[TaskHandle] = list[TaskHandle](self._handles.values())
|
|
136
|
+
for h in handles:
|
|
137
|
+
if not h.is_finished() and h._asyncio_task is not None:
|
|
138
|
+
self._loop.call_soon_threadsafe(callback=h._asyncio_task.cancel)
|
|
139
|
+
|
|
140
|
+
self._loop.call_soon_threadsafe(callback=self._loop.stop)
|
|
141
|
+
self._loop_thread.join(timeout=timeout)
|
|
142
|
+
self._running = False
|
|
143
|
+
logger.info("GeneralEngine 已停止")
|
|
144
|
+
|
|
145
|
+
def is_alive(self) -> bool:
    """Tell whether the engine is marked running and its loop thread is alive."""
    if not self._running:
        return False
    worker = self._loop_thread
    return worker is not None and worker.is_alive()
|
|
147
|
+
|
|
148
|
+
def __enter__(self) -> "GeneralEngine":
|
|
149
|
+
self.setup()
|
|
150
|
+
return self
|
|
151
|
+
|
|
152
|
+
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
|
|
153
|
+
self.teardown()
|
|
154
|
+
|
|
155
|
+
# ------------------------------------------------------------------
|
|
156
|
+
# 提交与执行
|
|
157
|
+
# ------------------------------------------------------------------
|
|
158
|
+
|
|
159
|
+
def submit(
    self,
    request: InferenceRequest,
    *,
    persist: bool = False,
    timeout: float | None = None,
) -> TaskHandle:
    """Register a request as a task and schedule it on the background loop.

    Args:
        request: The inference request to run.
        persist: Stored on the handle; consumed results of non-persistent
            handles are dropped from the registry.
        timeout: Task timeout used when the request itself carries none.

    Returns:
        The handle of the new task. Its asyncio task has been created (or
        the attempt has finished) by the time this returns.

    Raises:
        RuntimeError: If the engine has not been started.
    """
    if not self._running:
        raise RuntimeError("GeneralEngine 未启动,请先调用 setup()")

    new_id: str = gen_unique_id(prefix="task")

    # Timeout precedence: request.timeout, then the submit() argument,
    # then the engine default.
    if request.timeout is not None:
        chosen_timeout = request.timeout
    elif timeout is not None:
        chosen_timeout = timeout
    else:
        chosen_timeout = self.default_timeout

    handle = TaskHandle(
        task_id=new_id,
        request=request,
        status=TaskStatus.SUBMITTED,
        task_timeout=chosen_timeout,
        persist=persist,
    )
    # A threading.Event can be created from any thread.
    handle._done_threading_event = threading.Event()

    with self._handles_lock:
        self._handles[new_id] = handle

    assert self._loop is not None
    event_loop = self._loop

    # The asyncio task is created on the loop thread and written back to the
    # handle; this thread waits for that to happen so that cancel() can rely
    # on handle._asyncio_task as soon as submit() has returned.
    task_bound = threading.Event()

    def _create_on_loop() -> None:
        try:
            handle._asyncio_task = event_loop.create_task(self._run_task(handle))
            handle._done_event = asyncio.Event()
        finally:
            task_bound.set()

    event_loop.call_soon_threadsafe(_create_on_loop)
    task_bound.wait()
    return handle
|
|
204
|
+
|
|
205
|
+
async def _run_task(self, handle: TaskHandle) -> None:
    """Run one task to a terminal state and signal its completion events.

    Stages:
      1. Resolve the model; any error marks the task REJECTED.
      2. Mark ACCEPTED, then PENDING while waiting for the global semaphore.
      3. Mark RUNNING and send the request under ``asyncio.timeout``; the
         outcome is SUCCESS, TIMEOUT or FAILED.

    Cancellation is handled around the whole semaphore section, so a task
    cancelled while it is still waiting for the semaphore also ends up
    CANCELLED with a result, instead of firing its done events while still
    PENDING with no result. ``CancelledError`` is re-raised after the state
    has been recorded.
    """
    request = handle.request

    def _fail(status: TaskStatus, message: str, exc: BaseException | None = None) -> None:
        # Record an unsuccessful terminal state on the handle.
        handle.status = status
        handle.error = message
        if exc is not None:
            handle.error_exception = exc
        handle.result = InferenceRequestResult(
            success=False,
            task_id=handle.task_id,
            request=request,
            error_message=message,
        )

    def _signal_done() -> None:
        # Stamp the end time and wake both asyncio and thread waiters.
        handle.end_time = time.time()
        if handle._done_event is not None:
            handle._done_event.set()
        if handle._done_threading_event is not None:
            handle._done_threading_event.set()

    # -- validation: failure means REJECTED --
    try:
        model_name: str = request.model_name or self.default_model_name
        if not model_name:
            raise ValueError("no model_name (request/default all empty)")
        model: ModelConfig = self.model_registry.get(name=model_name)
    except Exception as e:
        _fail(TaskStatus.REJECTED, f"{type(e).__name__}: {e}", e)
        _signal_done()
        return

    handle.status = TaskStatus.ACCEPTED

    assert self._global_sem is not None
    try:
        # Queue for a concurrency slot.
        handle.status = TaskStatus.PENDING
        async with self._global_sem:
            handle.status = TaskStatus.RUNNING
            handle.start_time = time.time()
            try:
                # asyncio.timeout raises TimeoutError on expiry, so a
                # CancelledError that escapes this section is a
                # cancellation request rather than a timeout.
                async with asyncio.timeout(handle.task_timeout):
                    model_output = await self._send_request(request, model)
                handle.result = InferenceRequestResult(
                    success=True,
                    task_id=handle.task_id,
                    request=request,
                    model_output=model_output,
                    duration=time.time() - handle.start_time,
                )
                handle.status = TaskStatus.SUCCESS
            except TimeoutError:
                _fail(TaskStatus.TIMEOUT, "timeout")
            except Exception as e:
                logger.exception("任务 %s 执行异常", handle.task_id)
                _fail(TaskStatus.FAILED, f"{type(e).__name__}: {e}", e)
    except asyncio.CancelledError:
        # Reached both when cancelled mid-request and when cancelled while
        # still waiting for the semaphore.
        _fail(TaskStatus.CANCELLED, "cancelled")
        logger.info("任务 %s 已取消", handle.task_id)
        raise
    finally:
        _signal_done()
|
|
291
|
+
|
|
292
|
+
# ------------------------------------------------------------------
|
|
293
|
+
# 查询与控制
|
|
294
|
+
# ------------------------------------------------------------------
|
|
295
|
+
|
|
296
|
+
def _resolve(self, key: Key) -> TaskHandle:
    """Turn a task id or a handle into a handle.

    A ``TaskHandle`` is returned unchanged; a string is looked up in the
    handle registry.

    Raises:
        KeyError: If no task is registered under the given id.
        TypeError: If ``key`` is neither a string nor a ``TaskHandle``.
    """
    if isinstance(key, TaskHandle):
        return key
    if not isinstance(key, str):
        raise TypeError(f"key 必须是 str 或 TaskHandle,实际: {type(key).__name__}")

    with self._handles_lock:
        found = self._handles.get(key)
    if found is None:
        raise KeyError(f"未找到 task_id='{key}' 的任务")
    return found
|
|
306
|
+
|
|
307
|
+
def _maybe_consume(self, handle: TaskHandle) -> None:
    """Drop a finished, non-persistent handle from the registry."""
    if not handle.is_finished():
        return
    if handle.persist:
        return
    with self._handles_lock:
        self._handles.pop(handle.task_id, None)
|
|
312
|
+
|
|
313
|
+
def poll(self, key: Key) -> TaskHandle:
    """Look a task up without consuming it.

    Returns the handle itself and never removes it from the registry.
    """
    handle = self._resolve(key)
    return handle
|
|
316
|
+
|
|
317
|
+
def get(self, key: Key) -> InferenceRequestResult | None:
    """Fetch a task's result without blocking.

    Returns ``None`` while the task is unfinished. Once it is finished the
    result is returned and the handle is removed from the registry unless
    it is persistent.
    """
    handle = self._resolve(key)
    if handle.is_finished():
        outcome = handle.result
        self._maybe_consume(handle=handle)
        return outcome
    return None
|
|
325
|
+
|
|
326
|
+
def wait(self, key: Key, timeout: float | None = None) -> InferenceRequestResult:
    """Block until the task's done event is set and return its result.

    Args:
        key: Task id or handle.
        timeout: Maximum seconds to wait; ``None`` waits without limit.

    Raises:
        RuntimeError: If the handle has no threading done event.
        TimeoutError: If the wait expires; the task itself is not touched
            and keeps running in the background.
    """
    handle = self._resolve(key)
    done_event = handle._done_threading_event
    if done_event is None:
        raise RuntimeError("handle 缺少 done event,可能未经过 submit")

    if not done_event.wait(timeout=timeout):
        raise TimeoutError(f"等待任务 {handle.task_id} 超时({timeout}s),任务仍在后台运行")

    outcome = handle.result
    self._maybe_consume(handle=handle)
    if outcome is not None:
        return outcome

    # A set done event with no result means the engine's internal state is
    # off; hand the caller a failure result carrying whatever error exists.
    return InferenceRequestResult(
        success=False,
        task_id=handle.task_id,
        request=handle.request,
        error_message=handle.error or "unknown error (result missing)",
    )
|
|
346
|
+
|
|
347
|
+
def cancel(self, key: Key, force: bool = False) -> bool:
    """Request cancellation of a task.

    Returns ``False`` when the task is already finished, or when it is
    RUNNING and ``force`` is not set. Otherwise the asyncio task's
    ``cancel`` is scheduled on the loop thread (when both the task and the
    loop exist) and ``True`` is returned. The handle's status is updated
    later by the task itself when it observes the cancellation.
    """
    handle = self._resolve(key)
    if handle.is_finished():
        return False

    running_without_force = handle.status == TaskStatus.RUNNING and not force
    if running_without_force:
        return False

    task = handle._asyncio_task
    loop = self._loop
    if task is not None and loop is not None:
        loop.call_soon_threadsafe(callback=task.cancel)
    return True
|
|
357
|
+
|
|
358
|
+
def pop(self, key: Key) -> InferenceRequestResult | None:
    """Remove a task from the registry regardless of ``persist``.

    Returns the handle's current result, which may be ``None``.
    """
    handle = self._resolve(key)
    task_id = handle.task_id
    with self._handles_lock:
        self._handles.pop(task_id, None)
    return handle.result
|
|
364
|
+
|
|
365
|
+
def clear_done(self) -> int:
    """Remove every finished handle, persistent ones included.

    Returns:
        The number of handles removed.
    """
    with self._handles_lock:
        finished = [task_id for task_id, handle in self._handles.items() if handle.is_finished()]
        for task_id in finished:
            del self._handles[task_id]
    return len(finished)
|
|
372
|
+
|
|
373
|
+
# ------------------------------------------------------------------
|
|
374
|
+
# 同步便捷接口
|
|
375
|
+
# ------------------------------------------------------------------
|
|
376
|
+
|
|
377
|
+
def inference(self, request: InferenceRequest) -> InferenceRequestResult:
    """Run one request synchronously: submit it, then block for the result."""
    submitted = self.submit(request)
    # The task timeout is enforced inside _run_task on the loop side, so no
    # second timeout is applied to the thread-side wait; two competing
    # timeouts could produce sporadic TimeoutErrors.
    return self.wait(submitted)
|
|
382
|
+
|
|
383
|
+
def infer(self, query: str) -> str:
    """Send a single user message and return the reply text.

    Args:
        query: Text sent as the only user message.

    Returns:
        The content of the model output, or ``""`` when the request did not
        succeed, produced no model output, or produced an output whose
        content is ``None`` (so the declared ``str`` return type holds).
    """
    result: InferenceRequestResult = self.inference(
        request=InferenceRequest(messages=[Message(role=MessageRole.USER, content=query)])
    )
    if result.success and result.model_output is not None:
        # content can be None; normalise it to an empty string.
        return result.model_output.content or ""
    return ""
|
|
388
|
+
|
|
389
|
+
def batch_inference(
    self,
    requests: list[InferenceRequest],
    output_file: str | None = None,
    silent_mode: bool = False,
) -> list[InferenceRequestResult]:
    """Submit all requests, then collect their results in submission order.

    Args:
        requests: Requests to run.
        output_file: Optional path; when given, each result is appended as
            one JSON line and flushed as soon as it is available. The file
            is opened in write mode.
        silent_mode: When true, no progress bar is shown.

    Returns:
        The results, in the same order as ``requests``.

    The progress bar and the output file are both acquired inside the
    ``try`` block so that each one is closed even when acquiring the other
    fails.
    """
    handles: list[TaskHandle] = [self.submit(request=req) for req in requests]

    pbar: tqdm | None = None
    f: IO[str] | None = None
    results: list[InferenceRequestResult] = []
    try:
        if not silent_mode:
            pbar = tqdm(total=len(requests), desc=f"Batch: {self.default_model_name}")
        if output_file:
            f = open(output_file, mode="w", encoding="utf-8")

        for h in handles:
            result: InferenceRequestResult = self.wait(key=h)
            results.append(result)
            if f is not None:
                # Flush per result so partial output survives an interruption.
                f.write(result.model_dump_json() + "\n")
                f.flush()
            if pbar is not None:
                pbar.update(n=1)
    finally:
        if pbar is not None:
            pbar.close()
        if f is not None:
            f.close()
    return results
|
|
415
|
+
|
|
416
|
+
def batch_infer(self, queries: list[str], **kwargs: Any) -> list[str]:
    """Run a batch of plain-text queries and return the reply texts.

    Args:
        queries: Each entry is sent as a single user message.
        **kwargs: Passed through to ``batch_inference``.

    Returns:
        One string per query, in order. An entry is ``""`` when its request
        did not succeed, produced no model output, or produced an output
        whose content is ``None`` (so every element really is a ``str``).
    """
    reqs: list[InferenceRequest] = [
        InferenceRequest(messages=[Message(role=MessageRole.USER, content=q)]) for q in queries
    ]
    results: list[InferenceRequestResult] = self.batch_inference(requests=reqs, **kwargs)
    return [(r.model_output.content or "") if r.success and r.model_output else "" for r in results]
|
|
420
|
+
|
|
421
|
+
# ------------------------------------------------------------------
|
|
422
|
+
# HTTP 核心(自包含)
|
|
423
|
+
# ------------------------------------------------------------------
|
|
424
|
+
|
|
425
|
+
def _get_api_key(self, request: InferenceRequest, model: ModelConfig) -> str:
    """Pick the API key for a request.

    The first value that is not ``None`` wins, in this order: the request,
    the model configuration, the engine default. When none is set a warning
    is logged and an empty string is returned.
    """
    for candidate in (request.api_key, model.api_key, self.default_api_key):
        if candidate is not None:
            return candidate
    logger.warning("没有找到可用的 api_key, 将传递空 api_key")
    return ""
|
|
434
|
+
|
|
435
|
+
def _build_headers(self, request: InferenceRequest, model: ModelConfig) -> dict[str, str]:
    """Assemble the HTTP headers for a request.

    Starts from the bearer ``Authorization`` header and a JSON content
    type, then applies the engine, model and request extra headers in that
    order, so later layers override earlier ones.
    """
    token = self._get_api_key(request, model)
    merged: dict[str, str] = {
        "Authorization": "Bearer " + token,
        "Content-Type": "application/json",
    }
    for layer in (self.extra_headers, model.extra_headers, request.extra_headers):
        merged.update(layer)
    return merged
|
|
444
|
+
|
|
445
|
+
def _build_payload(self, request: InferenceRequest, model: ModelConfig) -> dict[str, Any]:
    """Assemble the JSON payload for a request.

    The base payload holds the model id, the stream flag (the request's
    value unless it is ``None``, else the engine default) and the
    serialised messages. Inference parameters are then merged in the order
    engine, model, request, followed by the extra payloads in the same
    order; later layers override earlier ones.
    """
    body: dict[str, Any] = {
        "model": model.model_id,
        "stream": self.stream if request.stream is None else request.stream,
        "messages": [message.to_dict() for message in request.messages],
    }

    parameter_layers = (
        self.default_inference_parameters,
        model.default_inference_parameters,
        request.inference_parameters,
    )
    for parameters in parameter_layers:
        if parameters:
            body.update(parameters.to_dict())

    for extra in (self.extra_payload, model.extra_payload, request.extra_payload):
        body.update(extra)
    return body
|
|
461
|
+
|
|
462
|
+
def _get_wait_time(self, status_code: int, attempt: int, base_delay: int) -> float:
|
|
463
|
+
if status_code == 429:
|
|
464
|
+
return 1.0
|
|
465
|
+
if status_code >= 500:
|
|
466
|
+
return float(base_delay**attempt)
|
|
467
|
+
return 5.0
|
|
468
|
+
|
|
469
|
+
async def _send_request(self, request: InferenceRequest, model: ModelConfig) -> ModelOutput:
    """POST the request to the model endpoint, retrying on failure.

    A URL is obtained from the model for every attempt and released again
    when the attempt ends. Non-200 responses, unparsable bodies, empty
    ``choices`` and network/timeout errors are retried until the attempt
    budget is used up; the last attempt raises instead. Only non-streaming
    requests are supported.

    Raises:
        NotImplementedError: If the payload asks for streaming.
        httpx.HTTPStatusError: Non-200 response on the last attempt.
        RuntimeError: Unparsable body or empty choices on the last attempt,
            or the loop ending without a result.
        httpx.NetworkError, httpx.TimeoutException: Re-raised on the last
            attempt.
    """
    headers: dict[str, str] = self._build_headers(request, model)
    payload: dict[str, Any] = self._build_payload(request, model)
    retry_limit: int = self.default_max_retries
    backoff_base: int = self.default_base_delay

    if payload["stream"]:
        raise NotImplementedError("GeneralEngine 首版暂未实现 stream 模式")

    http_timeout = httpx.Timeout(timeout=10.0, read=3600.0)
    async with httpx.AsyncClient(timeout=http_timeout) as client:
        for attempt in range(retry_limit):
            final_attempt = attempt == retry_limit - 1
            url = model.get_url()
            try:
                resp = await client.post(url, json=payload, headers=headers)
                status = resp.status_code

                if status != 200:
                    snippet = resp.text[:500]
                    if final_attempt:
                        logger.error("HTTP %d,已达最大重试次数,响应: %s", status, snippet)
                        raise httpx.HTTPStatusError(f"HTTP {status}", request=resp.request, response=resp)
                    delay = self._get_wait_time(status, attempt, backoff_base)
                    logger.warning(
                        "HTTP %d, 第 %d 次重试,等待 %.1fs... 响应: %s",
                        status,
                        attempt + 1,
                        delay,
                        snippet,
                    )
                    await asyncio.sleep(delay)
                    continue

                try:
                    parsed = ChatCompletionResponse.model_validate_json(resp.text)
                except ValidationError as ve:
                    if final_attempt:
                        raise RuntimeError(f"响应 JSON 解析失败,已达最大重试次数: {ve}; body: {resp.text[:500]}") from ve
                    logger.warning("响应 JSON 解析失败,第 %d 次重试: %r", attempt + 1, ve)
                    await asyncio.sleep(backoff_base)
                    continue

                if not parsed.choices:
                    if final_attempt:
                        raise RuntimeError(f"empty choices after {retry_limit} attempts, body: {resp.text[:500]}")
                    logger.warning("resp.choices 为空,疑似风控,第 %d 次重试...", attempt + 1)
                    await asyncio.sleep(backoff_base)
                    continue

                first = parsed.choices[0]
                reply = first.message
                usage = parsed.usage.model_dump() if parsed.usage else None
                if first.finish_reason and first.finish_reason != "stop":
                    # A non-"stop" finish is only reported, not retried.
                    logger.warning(
                        "finish_reason='%s' (非 stop), content_len=%d, usage=%s",
                        first.finish_reason,
                        len(reply.content or ""),
                        usage,
                    )
                return ModelOutput(
                    role=reply.role,
                    content=reply.content,
                    reasoning=reply.reasoning_content,
                    finish_reason=first.finish_reason,
                    usage=usage,
                )
            except (httpx.NetworkError, httpx.TimeoutException) as e:
                # HTTPStatusError is raised only on the last attempt above
                # and is deliberately not caught here.
                if final_attempt:
                    logger.error("达到最大重试次数,最后错误: %r", e)
                    raise
                delay = backoff_base**attempt
                logger.warning("网络异常: %r, 等待 %.1fs 后重试...", e, delay)
                await asyncio.sleep(delay)
            finally:
                model.release_url(url)

    raise RuntimeError("Unexpected end of retry loop")
|
|
547
|
+
|
|
548
|
+
# ------------------------------------------------------------------
|
|
549
|
+
# 预留:B 方案(async-native),首版不实现
|
|
550
|
+
# ------------------------------------------------------------------
|
|
551
|
+
|
|
552
|
+
async def a_submit(
    self,
    request: InferenceRequest,
    *,
    persist: bool = False,
    timeout: float | None = None,
) -> TaskHandle:
    """Async-native counterpart of ``submit``; reserved, always raises.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError("async-native 接口尚未实现")
|
|
554
|
+
|
|
555
|
+
async def a_wait(
    self,
    key: Key,
    timeout: float | None = None,
) -> InferenceRequestResult:
    """Async-native counterpart of ``wait``; reserved, always raises.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError("async-native 接口尚未实现")
|
|
557
|
+
|
|
558
|
+
async def a_inference(
    self,
    request: InferenceRequest,
) -> InferenceRequestResult:
    """Async-native counterpart of ``inference``; reserved, always raises.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError("async-native 接口尚未实现")
|
|
560
|
+
|
|
561
|
+
async def a_batch_inference(
    self,
    requests: list[InferenceRequest],
) -> list[InferenceRequestResult]:
    """Async-native counterpart of ``batch_inference``; reserved, always raises.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError("async-native 接口尚未实现")
|
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
# llm_engine/kitty/__main__.py
|
|
2
|
+
|
|
3
|
+
"""
|
|
4
|
+
CLI 入口。
|
|
5
|
+
|
|
6
|
+
用法:
|
|
7
|
+
python -m llm_engine.kitty -c /path/to/kitty_engine.yaml
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
import argparse
|
|
11
|
+
import asyncio
|
|
12
|
+
import logging
|
|
13
|
+
import sys
|
|
14
|
+
|
|
15
|
+
import kitty_logger
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def main() -> int:
    """Parse the command line, load the configuration and run the server.

    Exactly one of ``-c/--config`` (a YAML file path) or ``--embedded``
    (one line of JSON configuration read from stdin) must be given. Logging
    is configured from the configuration's ``log_level``, falling back to
    ``logging.INFO`` when the name is not an attribute of ``logging``.

    Returns:
        ``0`` after the server's ``run`` coroutine has completed.
    """
    cli = argparse.ArgumentParser(prog="python -m llm_engine.kitty", description="KittyEngine 后端 server")
    source = cli.add_mutually_exclusive_group(required=True)
    source.add_argument("-c", "--config", help="YAML 配置文件路径")
    source.add_argument(
        "--embedded",
        action="store_true",
        help="embedded 模式:从 stdin 读取 JSON config,stdin EOF 时自动退出",
    )
    options = cli.parse_args()

    # Imported only after argument parsing has succeeded.
    from .config import KittyEngineConfig

    if options.embedded:
        import json

        # Embedded mode: the configuration is the first line on stdin.
        first_line = sys.stdin.readline()
        config = KittyEngineConfig.model_validate(json.loads(first_line))
    else:
        config = KittyEngineConfig.from_yaml(options.config)

    log_level = getattr(logging, config.log_level.upper(), logging.INFO)
    kitty_logger.setup_logging(level=log_level)

    from .server import KittyServer

    server = KittyServer(config)
    run_mode = "embedded" if options.embedded else "standalone"
    asyncio.run(server.run(mode=run_mode))
    return 0
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
# Script entry point: use main()'s return value as the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())
|