llm-engine-kitty 0.1.0.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llm_engine/__init__.py +54 -0
- llm_engine/engine.py +771 -0
- llm_engine/general_engine.py +562 -0
- llm_engine/kitty/__init__.py +8 -0
- llm_engine/kitty/__main__.py +46 -0
- llm_engine/kitty/client.py +550 -0
- llm_engine/kitty/config.py +83 -0
- llm_engine/kitty/engine.py +1077 -0
- llm_engine/kitty/protocol.py +213 -0
- llm_engine/kitty/schemas.py +89 -0
- llm_engine/kitty/server.py +408 -0
- llm_engine/model_config.py +112 -0
- llm_engine/schemas.py +251 -0
- llm_engine/utils.py +34 -0
- llm_engine_kitty-0.1.0.dev0.dist-info/METADATA +15 -0
- llm_engine_kitty-0.1.0.dev0.dist-info/RECORD +18 -0
- llm_engine_kitty-0.1.0.dev0.dist-info/WHEEL +5 -0
- llm_engine_kitty-0.1.0.dev0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,550 @@
|
|
|
1
|
+
# llm_engine/kitty/client.py
|
|
2
|
+
|
|
3
|
+
"""
|
|
4
|
+
KittyClient:用户面向的客户端。
|
|
5
|
+
|
|
6
|
+
两种部署模式:
|
|
7
|
+
1. embedded(默认):传入 kwargs,内部 spawn 一个子进程启动 KittyServer,自动连接。
|
|
8
|
+
适合单机单租户使用,生命周期与用户进程绑定。
|
|
9
|
+
2. remote:KittyClient.connect(endpoint=..., **overrides) 连接到已经单独启动的
|
|
10
|
+
KittyServer(通过 `python -m llm_engine.kitty -c config.yaml` 启动)。
|
|
11
|
+
|
|
12
|
+
用户无需接触 KittyEngineConfig / KittyEngineOverrides 类,所有参数以 kwargs 形式传入。
|
|
13
|
+
"""
|
|
14
|
+
|
|
15
|
+
import json
|
|
16
|
+
import os
|
|
17
|
+
import socket
|
|
18
|
+
import subprocess
|
|
19
|
+
import sys
|
|
20
|
+
import tempfile
|
|
21
|
+
import threading
|
|
22
|
+
import time
|
|
23
|
+
from dataclasses import dataclass
|
|
24
|
+
from typing import IO, Any, Optional, TypeAlias
|
|
25
|
+
|
|
26
|
+
import kitty_logger
|
|
27
|
+
from tqdm import tqdm
|
|
28
|
+
|
|
29
|
+
from ..schemas import (
|
|
30
|
+
InferenceParameters,
|
|
31
|
+
InferenceRequest,
|
|
32
|
+
InferenceRequestResult,
|
|
33
|
+
Message,
|
|
34
|
+
MessageRole,
|
|
35
|
+
PreparedRequest,
|
|
36
|
+
TaskStatus,
|
|
37
|
+
)
|
|
38
|
+
from ..utils import gen_unique_id
|
|
39
|
+
from .config import KittyEngineConfig, KittyEngineOverrides
|
|
40
|
+
from .protocol import MsgType, connect_endpoint, encode_frame, recv_frame
|
|
41
|
+
from .schemas import MockRequest, TaskSnapshot, WireResult
|
|
42
|
+
|
|
43
|
+
# Module-level logger, obtained from kitty_logger and named after this module.
logger = kitty_logger.getLogger(__name__)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
@dataclass
class ClientTaskHandle:
    """Client-side handle for one submitted task.

    The originating request is cached on the handle because the wire-level
    result (``WireResult``) does not carry the request back.
    """

    # Identifier generated by the client when the task was submitted.
    task_id: str
    # The request that was submitted, kept so results can be re-associated.
    request: InferenceRequest
    # Most recent task snapshot known to this client.
    snapshot: TaskSnapshot

    @property
    def status(self) -> TaskStatus:
        """Return the status recorded in the locally cached snapshot."""
        cached = self.snapshot
        return cached.status

    def is_finished(self) -> bool:
        """Return whether the locally cached snapshot reports itself finished."""
        cached = self.snapshot
        return cached.is_finished()
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
# A task may be referred to either by its task_id string or by its handle.
Key: TypeAlias = "str | ClientTaskHandle"
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
# ---------------------------------------------------------------------------
|
|
66
|
+
# subprocess 入口(embedded 模式)已移至 __main__.py --embedded 分支
|
|
67
|
+
# ---------------------------------------------------------------------------
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
# ---------------------------------------------------------------------------
|
|
71
|
+
# KittyClient
|
|
72
|
+
# ---------------------------------------------------------------------------
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
class KittyClient:
|
|
76
|
+
"""
|
|
77
|
+
KittyClient 客户端。直接构造为 embedded 模式;KittyClient.connect() 为 remote 模式。
|
|
78
|
+
|
|
79
|
+
线程安全 / 并发使用约束(重要):
|
|
80
|
+
本实现仅按**单线程顺序使用**设计。`_request` 用一把全局 `self._lock`
|
|
81
|
+
把 send + recv 整段串行化,因此一旦某个调用阻塞(典型如 `wait()`:
|
|
82
|
+
server 端 `wait_task` 会一直 await `done_event` 到任务终态),**该
|
|
83
|
+
连接上所有其它 RPC(含 `submit/poll/stats/cancel/clear_done`)都会
|
|
84
|
+
排队**,并发使用会退化为完全串行。
|
|
85
|
+
|
|
86
|
+
如需并发使用:
|
|
87
|
+
- 推荐:每个使用线程持有自己的 `KittyClient`(embedded 不可,否则
|
|
88
|
+
每条都会拉一个 worker 子进程;用 `KittyClient.connect(endpoint=...)`
|
|
89
|
+
连同一个 server 即可)。
|
|
90
|
+
- 不推荐:在多个线程间共享一个实例并自己加锁——等价于退化为单线程。
|
|
91
|
+
|
|
92
|
+
协议层目前未做帧多路复用;未来若改造,本约束可放宽。
|
|
93
|
+
"""
|
|
94
|
+
|
|
95
|
+
def __init__(
    self,
    *,
    # -- required (embedded mode) --
    registry_path: str,
    default_model: Optional[str] = None,
    # -- engine defaults --
    default_api_key: Optional[str] = None,
    default_inference_parameters: Optional[InferenceParameters] = None,
    extra_headers: Optional[dict[str, str]] = None,
    extra_payload: Optional[dict[str, Any]] = None,
    # -- concurrency / retry --
    max_global_concurrency: int = 64,
    default_timeout: Optional[float] = None,
    default_max_retries: int = 5,
    default_base_delay: int = 2,
    # -- server only --
    listen: Optional[str] = None,
    log_level: str = "INFO",
    # -- connection tuning --
    http_client_connect_timeout: float = 10.0,
    http_client_read_timeout: float = 3600.0,
) -> None:
    """Create an embedded client: spawn a worker process and connect to it.

    All keyword arguments are forwarded unchanged into a
    ``KittyEngineConfig`` (which performs the validation); ``None`` values
    for ``extra_headers`` / ``extra_payload`` become empty dicts.

    Args:
        registry_path: Path handed to the config as ``registry_path``.
        default_model: Forwarded to the config.
        default_api_key: Forwarded to the config.
        default_inference_parameters: Forwarded to the config.
        extra_headers: Forwarded to the config (``{}`` when ``None``).
        extra_payload: Forwarded to the config (``{}`` when ``None``).
        max_global_concurrency: Forwarded to the config.
        default_timeout: Forwarded to the config.
        default_max_retries: Forwarded to the config.
        default_base_delay: Forwarded to the config.
        listen: Endpoint the worker should listen on. When ``None`` a
            unique ``unix://`` socket path in the temp directory is used.
        log_level: Forwarded to the config.
        http_client_connect_timeout: Forwarded to the config.
        http_client_read_timeout: Forwarded to the config.

    Raises:
        Exception: Whatever config validation, process spawning or the
            initial connection/handshake raises.
    """
    logger.debug("初始化KittyClient...")
    self._init_common_state()

    # Embedded mode: build the full config and spawn the worker.
    if listen is None:
        # Default to a unique UDS path. pid + wall-clock milliseconds alone can
        # collide when two clients are created in the same process within the
        # same millisecond, so a random suffix is used instead of the timestamp.
        tmpdir = tempfile.gettempdir()
        listen = f"unix://{tmpdir}/kitty_{os.getpid()}_{os.urandom(6).hex()}.sock"

    config = KittyEngineConfig(
        registry_path=registry_path,
        default_model=default_model,
        default_api_key=default_api_key,
        default_inference_parameters=default_inference_parameters,
        extra_headers=extra_headers or {},
        extra_payload=extra_payload or {},
        max_global_concurrency=max_global_concurrency,
        default_timeout=default_timeout,
        default_max_retries=default_max_retries,
        default_base_delay=default_base_delay,
        listen=listen,
        log_level=log_level,
        http_client_connect_timeout=http_client_connect_timeout,
        http_client_read_timeout=http_client_read_timeout,
    )
    self._endpoint = listen
    self._spawn_and_connect(config)
    logger.debug("KittyClient初始化完成")
|
|
146
|
+
|
|
147
|
+
def _init_common_state(self) -> None:
|
|
148
|
+
"""初始化跨两种部署模式共用的实例状态。"""
|
|
149
|
+
self._lock = threading.Lock()
|
|
150
|
+
self._sock: Optional[socket.socket] = None
|
|
151
|
+
self._proc: Optional[subprocess.Popen] = None
|
|
152
|
+
self._owns_process: bool = False
|
|
153
|
+
self._endpoint: str = ""
|
|
154
|
+
self._handles: dict[str, ClientTaskHandle] = {}
|
|
155
|
+
self._handles_lock = threading.Lock()
|
|
156
|
+
self._connected: bool = False
|
|
157
|
+
|
|
158
|
+
# ------------------------------------------------------------------
|
|
159
|
+
# 连接/启动
|
|
160
|
+
# ------------------------------------------------------------------
|
|
161
|
+
|
|
162
|
+
@classmethod
def connect(
    cls,
    endpoint: str,
    *,
    default_model: Optional[str] = None,
    default_api_key: Optional[str] = None,
    default_inference_parameters: Optional[InferenceParameters] = None,
    extra_headers: Optional[dict[str, str]] = None,
    extra_payload: Optional[dict[str, Any]] = None,
    default_timeout: Optional[float] = None,
    default_max_retries: Optional[int] = None,
    default_base_delay: Optional[int] = None,
) -> "KittyClient":
    """Connect to an already running KittyServer (remote mode).

    The keyword arguments are connection-level overrides; they are packed
    into a ``KittyEngineOverrides`` and handed to the handshake.

    Args:
        endpoint: Endpoint string of the running server.

    Returns:
        A connected ``KittyClient`` that does not own a worker process.
    """
    override_fields: dict[str, Any] = {
        "default_model": default_model,
        "default_api_key": default_api_key,
        "default_inference_parameters": default_inference_parameters,
        "extra_headers": extra_headers or {},
        "extra_payload": extra_payload or {},
        "default_timeout": default_timeout,
        "default_max_retries": default_max_retries,
        "default_base_delay": default_base_delay,
    }
    overrides = KittyEngineOverrides(**override_fields)

    # Bypass the embedded-mode __init__: it requires registry_path and would
    # spawn a worker process.
    client = cls.__new__(cls)
    client._init_common_state()
    client._endpoint = endpoint
    client._connect_remote(endpoint, overrides)
    return client
|
|
193
|
+
|
|
194
|
+
def _spawn_and_connect(self, config: KittyEngineConfig) -> None:
    """Spawn the embedded worker process, connect to it and say HELLO.

    The config is written to the worker's stdin as a single JSON line. Stdin
    is deliberately left open afterwards: EOF on it signals to the worker
    that the parent process has gone away.

    If any step after spawning fails (writing the config, connecting, or the
    handshake), the socket is closed and the worker is stopped before the
    original exception is re-raised, so no half-initialised worker is left
    behind.

    Args:
        config: Full engine configuration to hand to the worker.

    Raises:
        Exception: Whatever spawning, connecting or the handshake raises.
    """
    # Merge the parent's sys.path into PYTHONPATH so the worker can import
    # the package even when it was loaded via sys.path.insert rather than
    # installed with pip.
    env = os.environ.copy()
    extra_paths = os.pathsep.join(p for p in sys.path if p)
    existing = env.get("PYTHONPATH", "")
    env["PYTHONPATH"] = extra_paths + (os.pathsep + existing if existing else "")

    proc = subprocess.Popen(
        [sys.executable, "-m", "llm_engine.kitty", "--embedded"],
        stdin=subprocess.PIPE,
        env=env,
    )
    self._proc = proc
    self._owns_process = True

    try:
        if proc.stdin is None:
            raise RuntimeError("worker stdin pipe is unavailable")
        # First line on stdin is the JSON-encoded config; keep stdin open.
        proc.stdin.write((json.dumps(config.model_dump()) + "\n").encode("utf-8"))
        proc.stdin.flush()
        logger.info("KittyClient spawn worker pid=%d endpoint=%s", proc.pid, self._endpoint)

        self._sock = connect_endpoint(self._endpoint, connect_timeout=30.0)
        self._hello(overrides=None)
    except BaseException:
        # Best-effort teardown; must never replace the original exception.
        sock, self._sock = self._sock, None
        if sock is not None:
            try:
                sock.close()
            except OSError:
                pass
        if proc.stdin is not None:
            try:
                proc.stdin.close()
            except (OSError, ValueError):
                pass
        if proc.poll() is None:
            proc.terminate()
            try:
                proc.wait(timeout=5.0)
            except subprocess.TimeoutExpired:
                # The worker ignored SIGTERM: escalate instead of leaking it.
                proc.kill()
                try:
                    proc.wait(timeout=2.0)
                except subprocess.TimeoutExpired:
                    pass
        self._proc = None
        self._owns_process = False
        raise
    self._connected = True
|
|
225
|
+
|
|
226
|
+
def _connect_remote(self, endpoint: str, overrides: Optional[KittyEngineOverrides]) -> None:
    """Open a socket to ``endpoint``, perform the handshake and mark connected.

    Args:
        endpoint: Endpoint string of the running server.
        overrides: Optional connection-level overrides passed to the handshake.
    """
    self._sock = connect_endpoint(endpoint, connect_timeout=10.0)
    self._hello(overrides=overrides)
    self._connected = True
    logger.info("KittyClient 已连接到 %s", endpoint)
|
|
232
|
+
|
|
233
|
+
def _hello(self, overrides: Optional[KittyEngineOverrides]) -> None:
    """Send CMD_HELLO and require a MSG_WELCOME reply.

    Overrides are attached to the payload only when an overrides object was
    given and it reports ``has_any()``.

    Raises:
        RuntimeError: If the reply is not MSG_WELCOME.
    """
    attach_overrides = overrides is not None and overrides.has_any()
    hello_payload: dict[str, Any] = (
        {"overrides": overrides.model_dump(exclude_none=True)} if attach_overrides else {}
    )
    msg_type, resp = self._request(MsgType.CMD_HELLO, hello_payload)
    if msg_type == MsgType.MSG_WELCOME:
        return
    raise RuntimeError(f"HELLO 失败: msg_type={msg_type}, resp={resp}")
|
|
240
|
+
|
|
241
|
+
# ------------------------------------------------------------------
|
|
242
|
+
# 关闭
|
|
243
|
+
# ------------------------------------------------------------------
|
|
244
|
+
|
|
245
|
+
def close(self, *, shutdown_remote: bool = False, timeout: float = 10.0) -> None:
    """Close the client; in embedded mode also stop the worker process.

    All network I/O done here is guarded by a socket timeout so that close
    cannot block forever on a dead connection.

    Args:
        shutdown_remote: When True, send CMD_SHUTDOWN even if this client
            does not own the worker process.
        timeout: Socket timeout (clamped to at least 0.1s) and the time to
            wait for the worker to exit after its stdin is closed.
    """
    # Nothing to do when the client never finished connecting or was closed.
    if not self._connected:
        return
    sock = self._sock
    try:
        if sock is not None:
            # Put a timeout on the following sendall/recv so close cannot block forever.
            try:
                sock.settimeout(max(0.1, timeout))
            except OSError:
                pass
            if self._owns_process or shutdown_remote:
                # Ask the server to shut down; failures are logged and ignored.
                try:
                    self._request(MsgType.CMD_SHUTDOWN, {})
                except Exception as e:
                    logger.debug("CMD_SHUTDOWN 失败(忽略): %r", e)
            else:
                # Only say goodbye; no reply is read for CMD_BYE.
                try:
                    with self._lock:
                        self._send(MsgType.CMD_BYE, {})
                except Exception as e:
                    logger.debug("CMD_BYE 失败(忽略): %r", e)
            try:
                sock.close()
            except Exception:
                pass
            self._sock = None
    finally:
        if self._owns_process and self._proc is not None:
            # Closing stdin gives the embedded server EOF, which makes it exit.
            if self._proc.stdin is not None:
                try:
                    self._proc.stdin.close()
                except Exception:
                    pass
            # Escalate step by step: graceful wait, then terminate, then kill.
            try:
                self._proc.wait(timeout=timeout)
            except subprocess.TimeoutExpired:
                pass
            if self._proc.poll() is None:
                self._proc.terminate()
                try:
                    self._proc.wait(timeout=5.0)
                except subprocess.TimeoutExpired:
                    pass
            if self._proc.poll() is None:
                self._proc.kill()
                try:
                    self._proc.wait(timeout=2.0)
                except subprocess.TimeoutExpired:
                    pass
            self._proc = None
        self._connected = False
|
|
302
|
+
|
|
303
|
+
def __enter__(self) -> "KittyClient":
|
|
304
|
+
return self
|
|
305
|
+
|
|
306
|
+
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
|
|
307
|
+
self.close()
|
|
308
|
+
|
|
309
|
+
def __del__(self) -> None:
|
|
310
|
+
# __del__ 中不做网络 I/O:解释器关闭阶段 socket/threading 可能已不可用。
|
|
311
|
+
# 只尽力回收 worker 子进程,避免孤儿。
|
|
312
|
+
try:
|
|
313
|
+
if getattr(self, "_sock", None) is not None:
|
|
314
|
+
try:
|
|
315
|
+
self._sock.close() # type: ignore[union-attr]
|
|
316
|
+
except Exception:
|
|
317
|
+
pass
|
|
318
|
+
self._sock = None
|
|
319
|
+
proc = getattr(self, "_proc", None)
|
|
320
|
+
if getattr(self, "_owns_process", False) and proc is not None:
|
|
321
|
+
try:
|
|
322
|
+
if proc.poll() is None:
|
|
323
|
+
proc.terminate()
|
|
324
|
+
proc.wait(timeout=2.0)
|
|
325
|
+
if proc.poll() is None:
|
|
326
|
+
proc.kill()
|
|
327
|
+
except Exception:
|
|
328
|
+
pass
|
|
329
|
+
self._proc = None
|
|
330
|
+
self._connected = False
|
|
331
|
+
except Exception:
|
|
332
|
+
pass
|
|
333
|
+
|
|
334
|
+
# ------------------------------------------------------------------
|
|
335
|
+
# 协议 I/O(全部加锁串行)
|
|
336
|
+
# ------------------------------------------------------------------
|
|
337
|
+
|
|
338
|
+
def _send(self, msg_type: int, payload: dict[str, Any]) -> None:
|
|
339
|
+
assert self._sock is not None
|
|
340
|
+
self._sock.sendall(encode_frame(msg_type, payload))
|
|
341
|
+
|
|
342
|
+
def _request(self, msg_type: int, payload: dict[str, Any]) -> tuple[int, dict[str, Any]]:
|
|
343
|
+
with self._lock:
|
|
344
|
+
assert self._sock is not None
|
|
345
|
+
self._sock.sendall(encode_frame(msg_type, payload))
|
|
346
|
+
resp_type, resp = recv_frame(self._sock)
|
|
347
|
+
if resp_type == MsgType.MSG_ERROR:
|
|
348
|
+
raise RuntimeError(f"server 返回错误: {resp.get('error')}")
|
|
349
|
+
return resp_type, resp
|
|
350
|
+
|
|
351
|
+
# ------------------------------------------------------------------
|
|
352
|
+
# 任务 API
|
|
353
|
+
# ------------------------------------------------------------------
|
|
354
|
+
|
|
355
|
+
def _resolve_key(self, key: Key) -> ClientTaskHandle:
    """Turn a task id or a handle into the corresponding handle.

    Args:
        key: Either a ``ClientTaskHandle`` (returned as is) or a task id
            string looked up in the local handle table.

    Raises:
        KeyError: If a task id string has no local handle.
        TypeError: If ``key`` is neither a string nor a handle.
    """
    if isinstance(key, ClientTaskHandle):
        return key
    if not isinstance(key, str):
        raise TypeError(f"key 必须是 str 或 ClientTaskHandle,实际: {type(key).__name__}")

    with self._handles_lock:
        found = self._handles.get(key)
    if found is not None:
        return found
    raise KeyError(f"未找到 task_id='{key}' 的本地句柄")
|
|
365
|
+
|
|
366
|
+
def submit(
    self,
    request: InferenceRequest,
    *,
    persist: bool = False,
    timeout: Optional[float] = None,
) -> ClientTaskHandle:
    """Submit a request to the server and register a local handle for it.

    Args:
        request: The inference request; sent as ``request.model_dump()``.
        persist: Sent to the server as the ``persist`` field.
        timeout: Sent to the server as the ``timeout`` field.

    Returns:
        The newly created handle, also stored in the local handle table.
    """
    new_id = gen_unique_id(prefix="task")
    submit_payload = {
        "task_id": new_id,
        "request": request.model_dump(),
        "persist": persist,
        "timeout": timeout,
    }
    _, resp = self._request(MsgType.CMD_SUBMIT, submit_payload)

    handle = ClientTaskHandle(
        task_id=new_id,
        request=request,
        snapshot=TaskSnapshot.model_validate(resp["snapshot"]),
    )
    with self._handles_lock:
        self._handles[new_id] = handle
    return handle
|
|
388
|
+
|
|
389
|
+
def poll(self, key: Key) -> ClientTaskHandle:
    """Refresh the handle's snapshot from the server and return the handle.

    When the reply carries no snapshot (the task may already have been
    popped on the server), the handle is returned unchanged.
    """
    handle = self._resolve_key(key)
    _, resp = self._request(MsgType.CMD_POLL, {"task_id": handle.task_id})
    fresh = resp.get("snapshot")
    if fresh is not None:
        handle.snapshot = TaskSnapshot.model_validate(fresh)
    return handle
|
|
398
|
+
|
|
399
|
+
def get(self, key: Key) -> Optional[InferenceRequestResult]:
    """Consume the result without blocking.

    Returns:
        The result when the server reports ``found``; otherwise ``None``.
    """
    handle = self._resolve_key(key)
    _, resp = self._request(MsgType.CMD_POP, {"task_id": handle.task_id})
    if resp.get("found"):
        wire_result = WireResult.model_validate(resp["result"])
        return self._consume_result(handle, wire_result)
    return None
|
|
406
|
+
|
|
407
|
+
def wait(self, key: Key, timeout: Optional[float] = None) -> InferenceRequestResult:
    """Block until the task reaches a final state and return its result.

    Args:
        key: Task id or handle.
        timeout: Sent to the server as the wait timeout.

    Raises:
        TimeoutError: If the server reports ``timed_out`` (the task keeps
            running on the server).
        RuntimeError: If the server reports the task as not found.
    """
    handle = self._resolve_key(key)
    wait_payload = {"task_id": handle.task_id, "timeout": timeout}
    _, resp = self._request(MsgType.CMD_WAIT, wait_payload)

    if resp.get("timed_out"):
        raise TimeoutError(f"等待任务 {handle.task_id} 超时({timeout}s)")
    if not resp.get("found"):
        raise RuntimeError(f"任务 {handle.task_id} 在 server 端不存在")

    wire_result = WireResult.model_validate(resp["result"])
    return self._consume_result(handle, wire_result)
|
|
416
|
+
|
|
417
|
+
def cancel(self, key: Key, force: bool = False) -> bool:
    """Ask the server to cancel a task.

    Args:
        key: Task id or handle.
        force: Sent to the server as the ``force`` field.

    Returns:
        The truthiness of the reply's ``cancelled`` field.
    """
    handle = self._resolve_key(key)
    cancel_payload = {"task_id": handle.task_id, "force": force}
    _, resp = self._request(MsgType.CMD_CANCEL, cancel_payload)
    cancelled = resp.get("cancelled")
    return bool(cancelled)
|
|
421
|
+
|
|
422
|
+
def clear_done(self) -> int:
    """Send CMD_CLEAR_DONE and drop the matching local handles.

    Local handles are removed according to the ``cleared_ids`` returned by
    the server rather than the local ``is_finished()`` state: a caller may
    have submitted without ever polling or waiting, so the local snapshot
    can still say RUNNING/PENDING for a task the server already cleared.
    Removing by local state would leave such handles behind forever.

    Returns:
        The reply's ``count`` field as an int (0 when absent).
    """
    _, resp = self._request(MsgType.CMD_CLEAR_DONE, {})
    removed_ids = resp.get("cleared_ids") or []
    with self._handles_lock:
        for task_id in removed_ids:
            self._handles.pop(task_id, None)
    count = resp.get("count", 0)
    return int(count)
|
|
433
|
+
|
|
434
|
+
def mock_request_infer(self, query: str) -> MockRequest:
    """String convenience wrapper around ``mock_request``, mirroring ``infer``.

    Args:
        query: User input, wrapped into a single USER message.

    Returns:
        The mock request (URLs, model name, headers and payload).
    """
    user_message = Message(role=MessageRole.USER, content=query)
    single_turn = InferenceRequest(messages=[user_message])
    return self.mock_request(single_turn)
|
|
444
|
+
|
|
445
|
+
def mock_request(self, request: InferenceRequest) -> MockRequest:
    """Ask the server to assemble the provider request for ``request``.

    Intended for debugging: lets the caller check that URL, headers and
    payload look as expected. Per the original documentation, no call to the
    model provider is made.

    Args:
        request: The inference request; sent as ``request.model_dump()``.

    Returns:
        A ``MockRequest`` built from the reply's ``urls``, ``model``,
        ``headers`` and ``payload`` fields.
    """
    mock_payload = {"request": request.model_dump()}
    _, resp = self._request(MsgType.CMD_MOCK_REQUEST, mock_payload)
    urls = resp["urls"]
    model = resp["model"]
    headers = resp["headers"]
    payload = resp["payload"]
    return MockRequest(urls=urls, model=model, headers=headers, payload=payload)
|
|
458
|
+
|
|
459
|
+
def ping(self) -> None:
    """Send CMD_PING and wait for the reply; the reply content is discarded."""
    _ = self._request(MsgType.CMD_PING, {})
    return None
|
|
461
|
+
|
|
462
|
+
def stats(self) -> dict[str, Any]:
    """Query this connection's task status distribution.

    Suitable for driving a ``finished/total`` progress bar.

    Returns:
        A dict of the form::

            {
                "total": <int, sum of counts.values()>,
                "finished": <int, SUCCESS + FAILED + TIMEOUT + CANCELLED + REJECTED>,
                "counts": {"SUBMITTED": int, "ACCEPTED": int, ..., "REJECTED": int},
            }

    Counting semantics, as documented by the original author:
        - Non-final states (SUBMITTED/ACCEPTED/PENDING/RUNNING) hold the
          number of tasks *currently* in that state and fall back to 0 as
          tasks advance.
        - Final states (SUCCESS/FAILED/TIMEOUT/CANCELLED/REJECTED) only ever
          increase; consuming a result via POP/WAIT or removing entries via
          CLEAR_DONE does not decrease them (those operations do not trigger
          a state transition).
        - So for a connection that only submits and waits, ``total`` is the
          cumulative number of submitted tasks and ``finished`` the
          cumulative number completed, usable directly as ``n/total``.
    """
    _, resp = self._request(MsgType.CMD_STATS, {})
    total = int(resp.get("total", 0))
    finished = int(resp.get("finished", 0))
    counts = dict(resp.get("counts") or {})
    return {"total": total, "finished": finished, "counts": counts}
|
|
489
|
+
|
|
490
|
+
# ------------------------------------------------------------------
|
|
491
|
+
# 结果拼装
|
|
492
|
+
# ------------------------------------------------------------------
|
|
493
|
+
|
|
494
|
+
def _consume_result(self, handle: ClientTaskHandle, wr: WireResult) -> InferenceRequestResult:
    """Finalise a handle and build the user-facing result.

    The final status carried by the wire result is copied onto the local
    snapshot (so FAILED / REJECTED / CANCELLED stay distinguishable), the
    handle is removed from the local table, and the cached request is
    attached to the returned result.
    """
    handle.snapshot.status = wr.status
    with self._handles_lock:
        self._handles.pop(handle.task_id, None)

    result_fields = {
        "success": wr.success,
        "task_id": wr.task_id,
        "request": handle.request,
        "model_output": wr.model_output,
        "error_message": wr.error_message,
        "duration": wr.duration,
    }
    return InferenceRequestResult(**result_fields)
|
|
507
|
+
|
|
508
|
+
# ------------------------------------------------------------------
|
|
509
|
+
# 同步便捷接口
|
|
510
|
+
# ------------------------------------------------------------------
|
|
511
|
+
|
|
512
|
+
def inference(self, request: InferenceRequest) -> InferenceRequestResult:
    """Submit one request and block until its result is available."""
    handle = self.submit(request)
    return self.wait(handle)
|
|
514
|
+
|
|
515
|
+
def infer(self, query: str) -> str:
    """Run a single-turn USER query and return the model's text.

    Returns:
        The output content on success; an empty string when the request
        failed or produced no model output.
    """
    user_message = Message(role=MessageRole.USER, content=query)
    outcome = self.inference(InferenceRequest(messages=[user_message]))
    if not outcome.success or outcome.model_output is None:
        return ""
    return outcome.model_output.content
|
|
520
|
+
|
|
521
|
+
def batch_inference(
    self,
    requests: list[InferenceRequest],
    output_file: Optional[str] = None,
    silent_mode: bool = False,
) -> list[InferenceRequestResult]:
    """Submit all requests, then wait for each result in submission order.

    Args:
        requests: Requests to run.
        output_file: Optional path; when given, each result is appended as
            one JSON line (the file is truncated first) and flushed
            immediately.
        silent_mode: When True, no progress bar is shown.

    Returns:
        The results, in the same order as ``requests``.

    Raises:
        OSError: If ``output_file`` cannot be opened. This is raised before
            any task is submitted.
        Exception: Whatever ``submit`` / ``wait`` raise; the file and the
            progress bar are closed in that case.
    """
    # Open the output file first: a bad path must fail before any work has
    # been submitted, otherwise the submitted tasks would be orphaned.
    f: Optional[IO[str]] = open(output_file, "w", encoding="utf-8") if output_file else None
    pbar: Optional[tqdm] = None
    results: list[InferenceRequestResult] = []
    try:
        handles = [self.submit(req) for req in requests]
        if not silent_mode:
            pbar = tqdm(total=len(requests), desc="Batch: Kitty")
        for h in handles:
            result = self.wait(h)
            results.append(result)
            if f is not None:
                # Flush per result so partial output survives an interruption.
                f.write(result.model_dump_json() + "\n")
                f.flush()
            if pbar is not None:
                pbar.update(1)
    finally:
        if pbar is not None:
            pbar.close()
        if f is not None:
            f.close()
    return results
|
|
546
|
+
|
|
547
|
+
def batch_infer(self, queries: list[str], **kwargs: Any) -> list[str]:
    """Run several single-turn USER queries via ``batch_inference``.

    Args:
        queries: One query string per request.
        **kwargs: Passed through to ``batch_inference``.

    Returns:
        One string per query: the output content, or ``""`` when the request
        failed or has no model output.
    """
    reqs = []
    for text in queries:
        reqs.append(InferenceRequest(messages=[Message(role=MessageRole.USER, content=text)]))

    answers: list[str] = []
    for item in self.batch_inference(reqs, **kwargs):
        if item.success and item.model_output:
            answers.append(item.model_output.content)
        else:
            answers.append("")
    return answers
|
|
@@ -0,0 +1,83 @@
|
|
|
1
|
+
# llm_engine/kitty/config.py
|
|
2
|
+
|
|
3
|
+
"""
|
|
4
|
+
KittyEngine 配置数据结构。
|
|
5
|
+
|
|
6
|
+
- KittyEngineConfig:引擎完整配置。由 server 启动时从 YAML 加载;embedded 模式下
|
|
7
|
+
由 KittyClient.__init__ 根据 kwargs 内部构造后,序列化为一行 JSON 经 stdin 传给 worker 子进程。
|
|
8
|
+
- KittyEngineOverrides:client 连接时可选的覆盖集。每个字段均 Optional,不包含
|
|
9
|
+
server 全局资源字段(registry_path / max_global_concurrency / listen)。
|
|
10
|
+
|
|
11
|
+
两者均为内部数据结构,不通过顶层 __init__.py 导出给用户。
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
from pathlib import Path
|
|
15
|
+
from typing import Any, Optional
|
|
16
|
+
|
|
17
|
+
import yaml
|
|
18
|
+
from pydantic import BaseModel, ConfigDict, Field
|
|
19
|
+
|
|
20
|
+
from ..schemas import InferenceParameters
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class KittyEngineConfig(BaseModel):
    """Complete runtime configuration for KittyEngine."""

    # Unknown keys are rejected rather than silently ignored.
    model_config = ConfigDict(extra="forbid")

    # -- model registry --
    registry_path: str = Field(description="ModelConfigRegistry 的 JSON 配置路径")

    # -- engine defaults --
    default_model: Optional[str] = None
    default_api_key: Optional[str] = None
    default_inference_parameters: Optional[InferenceParameters] = None
    extra_headers: dict[str, str] = Field(default_factory=dict)
    extra_payload: dict[str, Any] = Field(default_factory=dict)

    # -- concurrency and retry (server-global, not overridable by a client) --
    max_global_concurrency: int = Field(default=64, ge=1)
    default_timeout: Optional[float] = None
    # max_retries means "total attempts including the first", so the minimum
    # is 1 (run once, no retry). Per the original note, passing 0 would make
    # the retry loop in _send_request not execute at all and fall through to
    # its "Unexpected end of retry loop" dead code.
    default_max_retries: int = Field(default=5, ge=1)
    default_base_delay: int = Field(default=2, ge=1)

    # -- server only --
    listen: str = "unix:///tmp/kitty.sock"
    log_level: str = "INFO"

    # -- connection tuning --
    http_client_connect_timeout: float = 10.0
    http_client_read_timeout: float = 3600.0

    @classmethod
    def from_yaml(cls, path: str | Path) -> "KittyEngineConfig":
        """Load and validate a config from a YAML file.

        An empty document is treated as an empty mapping.

        Raises:
            ValueError: If the YAML root is not a mapping.
        """
        path = Path(path)
        with path.open(mode="r", encoding="utf-8") as stream:
            loaded = yaml.safe_load(stream)
        data = loaded or {}
        if isinstance(data, dict):
            return cls.model_validate(data)
        raise ValueError(f"YAML 根节点必须是映射: {path}")
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
class KittyEngineOverrides(BaseModel):
    """Overrides a client may send when connecting.

    Contains no server-global resource fields.
    """

    # Unknown keys are rejected rather than silently ignored.
    model_config = ConfigDict(extra="forbid")

    default_model: Optional[str] = None
    default_api_key: Optional[str] = None
    default_inference_parameters: Optional[InferenceParameters] = None
    extra_headers: dict[str, str] = Field(default_factory=dict)
    extra_payload: dict[str, Any] = Field(default_factory=dict)
    default_timeout: Optional[float] = None
    # Same constraint as KittyEngineConfig: >= 1. None means "do not override,
    # use the server default".
    default_max_retries: Optional[int] = Field(default=None, ge=1)
    default_base_delay: Optional[int] = Field(default=None, ge=1)

    def has_any(self) -> bool:
        """Return whether at least one override is actually set.

        A field counts as set when it is not ``None`` and not an empty
        mapping. Empty mappings are the defaults of ``extra_headers`` /
        ``extra_payload`` and survive ``exclude_none``, so they are filtered
        explicitly. Other falsy values such as ``default_timeout=0.0`` are
        explicit overrides and count as set; a plain truthiness test would
        drop them when they are the only override.

        When this returns False the client may omit the ``overrides`` entry
        from its HELLO payload.
        """
        dumped = self.model_dump(exclude_none=True)
        return any(value != {} for value in dumped.values())
|