llm-engine-kitty 0.1.0.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llm_engine/__init__.py +54 -0
- llm_engine/engine.py +771 -0
- llm_engine/general_engine.py +562 -0
- llm_engine/kitty/__init__.py +8 -0
- llm_engine/kitty/__main__.py +46 -0
- llm_engine/kitty/client.py +550 -0
- llm_engine/kitty/config.py +83 -0
- llm_engine/kitty/engine.py +1077 -0
- llm_engine/kitty/protocol.py +213 -0
- llm_engine/kitty/schemas.py +89 -0
- llm_engine/kitty/server.py +408 -0
- llm_engine/model_config.py +112 -0
- llm_engine/schemas.py +251 -0
- llm_engine/utils.py +34 -0
- llm_engine_kitty-0.1.0.dev0.dist-info/METADATA +15 -0
- llm_engine_kitty-0.1.0.dev0.dist-info/RECORD +18 -0
- llm_engine_kitty-0.1.0.dev0.dist-info/WHEEL +5 -0
- llm_engine_kitty-0.1.0.dev0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,408 @@
|
|
|
1
|
+
# llm_engine/kitty/server.py
|
|
2
|
+
|
|
3
|
+
"""
|
|
4
|
+
KittyServer:在子进程中 driving 事件循环 + 接受 socket 连接的服务端。
|
|
5
|
+
|
|
6
|
+
- 一个进程内只有一个 _KittyEngine 实例被多个连接共享
|
|
7
|
+
- 每个连接维护自己的 _ConnectionContext,包含其携带的 overrides 以及它提交的 task_id 集合
|
|
8
|
+
- 帧级别请求→响应模型,submit 立即返回 snapshot,wait 在服务端阻塞 done_event
|
|
9
|
+
- 连接断开时:持有的未完成任务被取消,未持久化的已完成结果被回收
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
import asyncio
|
|
13
|
+
import signal
|
|
14
|
+
import socket
|
|
15
|
+
import sys
|
|
16
|
+
from dataclasses import dataclass, field
|
|
17
|
+
from typing import Any, Optional
|
|
18
|
+
|
|
19
|
+
import kitty_logger
|
|
20
|
+
|
|
21
|
+
from ..schemas import InferenceRequest, TaskStatus
|
|
22
|
+
from ..utils import gen_unique_id
|
|
23
|
+
from .config import KittyEngineConfig, KittyEngineOverrides
|
|
24
|
+
from .protocol import MsgType, aread_frame, encode_frame, parse_endpoint
|
|
25
|
+
from .engine import _KittyEngine
|
|
26
|
+
|
|
27
|
+
logger = kitty_logger.getLogger(__name__)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
@dataclass
|
|
31
|
+
class _ConnectionContext:
|
|
32
|
+
connection_id: str
|
|
33
|
+
overrides: Optional[KittyEngineOverrides] = None
|
|
34
|
+
task_ids: set[str] = field(default_factory=set)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class KittyServer:
    """Socket server that shares one ``_KittyEngine`` among all client connections.

    Every accepted connection gets its own ``_ConnectionContext``. Frames are
    processed strictly request -> response by ``_dispatch``. When a connection
    ends, the tasks it still owns are handed to ``engine.cleanup_tasks``.
    """

    def __init__(self, config: KittyEngineConfig) -> None:
        """Create the engine and the per-connection bookkeeping tables.

        Args:
            config: Engine/server configuration. ``config.listen`` is parsed
                by ``run`` to decide where to listen.
        """
        self.config = config
        self.engine = _KittyEngine(config)
        self._server: Optional[asyncio.AbstractServer] = None
        self._stop_event: Optional[asyncio.Event] = None
        self._endpoint_kind: Optional[str] = None
        self._endpoint_addr: Optional[tuple] = None
        # Reverse map: task_id -> owning connection_id. It is maintained at:
        # - CMD_SUBMIT: registered *before* engine.submit, because submit fires
        #   the first transition callback synchronously and the callback must
        #   already be able to resolve the owning connection.
        # - CMD_POP / CMD_WAIT: removed once a result was actually consumed.
        # - CMD_CLEAR_DONE: removed for every id returned by engine.clear_done().
        # - connection close: _handle_client's finally drops what is left.
        self._task_to_connection: dict[str, str] = {}
        # Per-connection status counters, keyed by the TaskStatus enum member.
        # They are not decremented when POP/WAIT consumes a result; members are
        # converted to their names only when replying to CMD_STATS.
        self._connection_statistics: dict[str, dict[TaskStatus, int]] = {}
        self.engine.set_transition_callback(self._on_task_transition)

    # ------------------------------------------------------------------
    # Status statistics (engine callback)
    # ------------------------------------------------------------------

    def _on_task_transition(
        self,
        task_id: str,
        old: Optional[TaskStatus],
        new: TaskStatus,
    ) -> None:
        """Update the owning connection's counters for one status transition.

        Args:
            task_id: Task whose status changed.
            old: Previous status, or None for the first transition of a task
                (then only ``new`` is incremented).
            new: Status the task moved to.

        If the task has no registered connection, or the connection has no
        statistics entry, the call is a no-op logged at debug level. That is
        the expected path once ``_handle_client`` has already dropped the
        mapping while cleaning up a closed connection. The decrement of
        ``old`` is clamped at zero so the counters never go negative. The
        method is synchronous and does no I/O besides logging.
        """
        connection_id: str | None = self._task_to_connection.get(task_id)
        if connection_id is None:
            logger.debug("KittyServer: can not find conn_id for task %s (likely cleanup path)", task_id)
            return
        statistics = self._connection_statistics.get(connection_id)
        if statistics is None:
            logger.debug("KittyServer: can not find statistics for connection %s (likely cleanup path)", connection_id)
            return
        if old is not None:
            statistics[old] = max(0, statistics.get(old, 0) - 1)
        statistics[new] = statistics.get(new, 0) + 1

    # ------------------------------------------------------------------
    # Entry point
    # ------------------------------------------------------------------

    async def run(self, mode: str = "standalone", auto_port: bool = False) -> None:
        """Start the server and block until a stop condition is met.

        Args:
            mode: ``"embedded"`` additionally watches stdin and stops on EOF;
                any other value (default ``"standalone"``) does not.
            auto_port: TCP only. When True, a failed bind is retried on a port
                picked by the OS; when False the bind error propagates.

        The engine is torn down and a created Unix socket file is removed in a
        ``finally`` block, so a failure while binding or serving no longer
        leaks the engine or leaves the socket file behind.
        """
        import os  # function-scope import, kept local as in the original code

        self._stop_event = asyncio.Event()
        await self.engine.setup()

        # Set only after the Unix socket was actually created by this call.
        unix_path: Optional[str] = None
        try:
            kind, addr = parse_endpoint(self.config.listen)
            self._endpoint_kind = kind
            self._endpoint_addr = addr

            if kind == "unix":
                path = addr[0]
                # Remove a stale socket file left by a previous run.
                try:
                    os.unlink(path)
                except FileNotFoundError:
                    pass
                self._server = await asyncio.start_unix_server(self._handle_client, path=path)
                unix_path = path
                os.chmod(path, 0o600)
                logger.info("KittyServer listening on unix://%s", path)
            else:
                host, port = addr
                port = await self._start_tcp_server(host, port, auto_port)
                logger.info("KittyServer listening on tcp://%s:%d", host, port)

            # Route SIGINT/SIGTERM to the stop event where the loop supports it.
            loop = asyncio.get_running_loop()
            for sig in (signal.SIGINT, signal.SIGTERM):
                try:
                    loop.add_signal_handler(sig, self._stop_event.set)
                except NotImplementedError:
                    # Not available on this platform's event loop (e.g. Windows).
                    pass

            async with self._server:
                stop_task = asyncio.create_task(self._stop_event.wait())
                serve_task = asyncio.create_task(self._server.serve_forever())
                tasks = {stop_task, serve_task}
                if mode == "embedded":
                    tasks.add(asyncio.create_task(self._watch_stdin()))
                done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
                for t in pending:
                    t.cancel()
                # Report failures of whichever task finished first.
                for t in done:
                    if t.cancelled():
                        # exception() would raise CancelledError on a cancelled task.
                        continue
                    exc = t.exception()
                    if exc is not None and not isinstance(exc, asyncio.CancelledError):
                        logger.error("server task 异常: %r", exc)
                # Let the cancelled tasks finish unwinding before tearing down.
                await asyncio.gather(*pending, return_exceptions=True)
        finally:
            try:
                await self.engine.teardown()
            finally:
                if unix_path is not None:
                    try:
                        os.unlink(unix_path)
                    except FileNotFoundError:
                        pass
        logger.info("KittyServer 已停止")

    async def _start_tcp_server(self, host: str, port: int, auto_port: bool) -> int:
        """Bind the TCP listener and return the port that was actually used.

        Without ``auto_port`` a bind failure propagates to the caller. With it,
        every ``OSError`` triggers a retry on a port obtained by binding a probe
        socket to port 0, and ``self.config.listen`` is updated to that port.
        """
        if not auto_port:
            self._server = await asyncio.start_server(self._handle_client, host=host, port=port)
            return port

        # NOTE(review): the retry fires on any OSError, not only "address in
        # use", and has no attempt limit — behaviour kept from the original.
        while True:
            try:
                self._server = await asyncio.start_server(self._handle_client, host=host, port=port)
                return port
            except OSError:
                # Ask the OS for a currently free port, then retry with it.
                with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
                    probe.bind((host, 0))
                    port = probe.getsockname()[1]
                # Keep config.listen in sync so the actual port can be queried.
                self.config = self.config.model_copy(update={"listen": f"tcp://{host}:{port}"})

    def stop(self) -> None:
        """Request shutdown by setting the stop event; no-op before ``run``."""
        if self._stop_event is not None:
            self._stop_event.set()

    async def _watch_stdin(self) -> None:
        """Embedded mode: drain stdin and stop the server once it reaches EOF.

        A read error is treated like EOF, but is now logged instead of being
        swallowed silently.
        """
        loop = asyncio.get_running_loop()
        reader = asyncio.StreamReader()
        protocol = asyncio.StreamReaderProtocol(reader)
        # Attach to the binary layer of stdin rather than the text wrapper.
        await loop.connect_read_pipe(lambda: protocol, sys.stdin.buffer)
        try:
            # read() returns b"" at EOF, which ends the loop.
            while await reader.read(4096):
                pass
        except Exception as e:
            logger.warning("KittyServer: stdin watcher read failed, treating as EOF: %r", e)
        logger.info("KittyServer: stdin EOF,embedded 模式自动退出")
        self.stop()

    # ------------------------------------------------------------------
    # Connection handling
    # ------------------------------------------------------------------

    async def _handle_client(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
        """Serve one client connection until it closes or asks to close.

        Reads frames in a loop and passes each to ``_dispatch``. A handler
        exception is logged and reported to the client as ``MSG_ERROR``; if
        that reply cannot be delivered the connection is closed instead of
        letting the send error escape the handler. On exit the connection's
        mappings and statistics are dropped, its remaining tasks are handed
        to ``engine.cleanup_tasks`` and the writer is closed.
        """
        ctx = _ConnectionContext(connection_id=gen_unique_id(prefix="conn"))
        # Start every TaskStatus at 0 so a reply always carries all statuses.
        self._connection_statistics[ctx.connection_id] = {s: 0 for s in TaskStatus}
        peer = writer.get_extra_info("peername") or writer.get_extra_info("sockname")
        logger.info("连接建立 connection_id=%s peer=%s", ctx.connection_id, peer)

        try:
            while True:
                try:
                    msg_type, payload = await aread_frame(reader)
                except (asyncio.IncompleteReadError, ConnectionError):
                    # Peer closed or reset the connection.
                    break
                except Exception as e:
                    logger.warning("connection_id=%s 读帧异常: %r", ctx.connection_id, e)
                    break

                try:
                    should_close = await self._dispatch(ctx, msg_type, payload, writer)
                except Exception as e:
                    logger.exception("connection_id=%s 处理 msg_type=%d 异常", ctx.connection_id, msg_type)
                    try:
                        await self._send(writer, MsgType.MSG_ERROR, {"error": f"{type(e).__name__}: {e}"})
                    except Exception as send_error:
                        # The error reply is best effort: if it cannot be
                        # delivered, the peer is unreachable, so stop serving.
                        logger.debug(
                            "connection_id=%s failed to deliver error reply: %r",
                            ctx.connection_id,
                            send_error,
                        )
                        break
                    should_close = False

                if should_close:
                    break
        finally:
            # Drop the task->connection mapping and the statistics first, so
            # transitions produced by cleanup_tasks become callback no-ops.
            leftover = list(ctx.task_ids)
            for task_id in leftover:
                self._task_to_connection.pop(task_id, None)
            self._connection_statistics.pop(ctx.connection_id, None)
            # Hand the tasks this connection still owns to the engine.
            if leftover:
                try:
                    await self.engine.cleanup_tasks(leftover)
                except Exception:
                    logger.exception("connection_id=%s cleanup_tasks 异常", ctx.connection_id)
            # Closing is best effort; the peer may already be gone.
            try:
                writer.close()
                await writer.wait_closed()
            except Exception:
                pass
            logger.info("连接关闭 connection_id=%s", ctx.connection_id)

    async def _send(self, writer: asyncio.StreamWriter, msg_type: int, payload: dict[str, Any]) -> None:
        """Encode one frame, write it and wait for the write buffer to drain."""
        writer.write(encode_frame(msg_type, payload))
        await writer.drain()

    # ------------------------------------------------------------------
    # Message dispatch
    # ------------------------------------------------------------------

    async def _dispatch(
        self,
        ctx: _ConnectionContext,
        msg_type: int,
        payload: dict[str, Any],
        writer: asyncio.StreamWriter,
    ) -> bool:
        """Handle one frame and reply on ``writer``.

        Args:
            ctx: State of the connection the frame arrived on.
            msg_type: Frame type, compared against ``MsgType`` members.
            payload: Decoded frame payload.
            writer: Stream used for the reply.

        Returns:
            True when the connection should be closed (CMD_BYE, CMD_SHUTDOWN),
            False otherwise. Unknown types get a ``MSG_ERROR`` reply.
        """

        if msg_type == MsgType.CMD_HELLO:
            ov_dict = payload.get("overrides")
            if ov_dict:
                ctx.overrides = KittyEngineOverrides.model_validate(ov_dict)
            # "conn_id" is a wire-protocol field name; do not rename it.
            await self._send(writer, MsgType.MSG_WELCOME, {"conn_id": ctx.connection_id})
            return False

        if msg_type == MsgType.CMD_PING:
            await self._send(writer, MsgType.MSG_PONG, {})
            return False

        if msg_type == MsgType.CMD_SUBMIT:
            task_id = payload.get("task_id") or gen_unique_id(prefix="task")
            request = InferenceRequest.model_validate(payload["request"])
            persist = bool(payload.get("persist", False))
            timeout = payload.get("timeout")
            # Register task->connection before engine.submit: submit fires the
            # first transition callback synchronously, and the callback must be
            # able to resolve the owner or that first event is dropped.
            previous_owner = self._task_to_connection.get(task_id)
            self._task_to_connection[task_id] = ctx.connection_id
            try:
                snap = self.engine.submit(
                    task_id=task_id,
                    request=request,
                    overrides=ctx.overrides,
                    persist=persist,
                    timeout=timeout,
                )
            except Exception:
                # Roll back the registration. If the id already had an owner
                # (client-chosen id submitted twice), restore that owner rather
                # than deleting the mapping of the task that is still alive.
                if previous_owner is None:
                    self._task_to_connection.pop(task_id, None)
                else:
                    self._task_to_connection[task_id] = previous_owner
                raise
            ctx.task_ids.add(task_id)
            await self._send(writer, MsgType.MSG_SNAPSHOT, {"snapshot": snap.model_dump()})
            return False

        if msg_type == MsgType.CMD_CANCEL:
            task_id = payload["task_id"]
            force = bool(payload.get("force", False))
            cancelled = self.engine.cancel(task_id, force=force)
            await self._send(writer, MsgType.MSG_PONG, {"cancelled": cancelled})
            return False

        if msg_type == MsgType.CMD_POLL:
            task_id = payload["task_id"]
            snap = self.engine.snapshot(task_id)
            await self._send(
                writer,
                MsgType.MSG_SNAPSHOT,
                {"snapshot": snap.model_dump() if snap is not None else None},
            )
            return False

        if msg_type == MsgType.CMD_POP:
            task_id = payload["task_id"]
            result = self.engine.pop_result(task_id)
            if result is not None:
                # The result was consumed: the task is no longer tracked here.
                ctx.task_ids.discard(task_id)
                self._task_to_connection.pop(task_id, None)
            await self._send(
                writer,
                MsgType.MSG_RESULT,
                {
                    "found": result is not None,
                    "result": result.model_dump() if result is not None else None,
                    "timed_out": False,
                },
            )
            return False

        if msg_type == MsgType.CMD_WAIT:
            task_id = payload["task_id"]
            timeout = payload.get("timeout")
            result, timed_out = await self.engine.wait_task(task_id, timeout)
            if result is not None:
                # The result was consumed: the task is no longer tracked here.
                ctx.task_ids.discard(task_id)
                self._task_to_connection.pop(task_id, None)
            await self._send(
                writer,
                MsgType.MSG_RESULT,
                {
                    "found": result is not None,
                    "result": result.model_dump() if result is not None else None,
                    "timed_out": timed_out,
                },
            )
            return False

        if msg_type == MsgType.CMD_CLEAR_DONE:
            cleared_ids = self.engine.clear_done()
            # Mirror the engine-side removal in the reverse map and in this
            # connection's own task set.
            for task_id in cleared_ids:
                self._task_to_connection.pop(task_id, None)
            if cleared_ids:
                ctx.task_ids.difference_update(cleared_ids)
            # The ids are returned so the client can drop its local handles by
            # id instead of relying on a possibly stale local status.
            await self._send(
                writer,
                MsgType.MSG_PONG,
                {"count": len(cleared_ids), "cleared_ids": cleared_ids},
            )
            return False

        if msg_type == MsgType.CMD_STATS:
            # Status distribution of this connection; the counters are not
            # reduced when POP/WAIT consumes a result.
            statistics = self._connection_statistics.get(ctx.connection_id, {})
            counts = {s.name: int(statistics.get(s, 0)) for s in TaskStatus}
            finished = sum(counts[s] for s in ("SUCCESS", "FAILED", "TIMEOUT", "CANCELLED", "REJECTED"))
            total = sum(counts.values())
            await self._send(
                writer,
                MsgType.MSG_STATS,
                {"total": total, "finished": finished, "counts": counts},
            )
            return False

        if msg_type == MsgType.CMD_BYE:
            return True

        if msg_type == MsgType.CMD_MOCK_REQUEST:
            request = InferenceRequest.model_validate(payload["request"])
            mr = self.engine.mock_request(request, overrides=ctx.overrides)
            await self._send(writer, MsgType.MSG_PREPARED_REQUEST, dict(mr))
            return False

        if msg_type == MsgType.CMD_SHUTDOWN:
            # Acknowledge first, then stop the whole server.
            await self._send(writer, MsgType.MSG_PONG, {})
            self.stop()
            return True

        await self._send(writer, MsgType.MSG_ERROR, {"error": f"未知 msg_type={msg_type}"})
        return False

    # (the wait_task logic now lives in _KittyEngine.wait_task)
|
|
@@ -0,0 +1,112 @@
|
|
|
1
|
+
# llm_engine/model_config.py
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
import threading
|
|
5
|
+
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr
|
|
8
|
+
from typing import Any, Dict, List, Literal, Optional
|
|
9
|
+
|
|
10
|
+
import kitty_logger
|
|
11
|
+
|
|
12
|
+
from .schemas import InferenceParameters
|
|
13
|
+
|
|
14
|
+
logger = kitty_logger.getLogger(__name__)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class ModelPricing(BaseModel):
    """Price pair that ``ModelConfig.calculate_cost`` multiplies by token counts."""

    # Price applied to each input token (unit/currency not specified here).
    input: float
    # Price applied to each output token (unit/currency not specified here).
    output: float
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class ModelConfig(BaseModel):
    """Configuration of one model plus per-URL bookkeeping of in-flight requests."""

    model_config = ConfigDict(populate_by_name=True)

    name: str
    model_id: str = Field(alias="model_id")
    # Candidate endpoints; get_url() hands out the one with the fewest active requests.
    api_urls: List[str] = Field(alias="api_urls")
    api_key: Optional[str] = Field(None, alias="api_key")
    platform: Optional[str] = None

    pricing: Optional[ModelPricing] = None
    # Populated from the "inference_parameters" key (or by field name).
    default_inference_parameters: Optional[InferenceParameters] = Field(default=None, alias="inference_parameters")
    extra_headers: Dict[str, str] = Field(default_factory=dict)
    extra_payload: Dict[str, Any] = Field(default_factory=dict)

    # --- private attributes: runtime-only state ---
    # url -> number of requests currently checked out via get_url().
    _url_active_counts: Dict[str, int] = PrivateAttr()
    # Guards every read/update of _url_active_counts.
    _url_lock: threading.Lock = PrivateAttr(default_factory=threading.Lock)

    def model_post_init(self, __context: Any) -> None:
        """Pydantic post-init hook: reject empty ``api_urls`` and zero the counters.

        Raises:
            ValueError: if ``api_urls`` is empty.
        """
        if not self.api_urls:
            raise ValueError(f"模型 '{self.name}' 的 api_urls 不能为空")
        self._url_active_counts = dict.fromkeys(self.api_urls, 0)

    def get_url(self) -> str:
        """Return the URL with the fewest active requests and count it as in use."""
        with self._url_lock:
            counts = self._url_active_counts
            chosen = min(counts, key=counts.__getitem__)
            counts[chosen] += 1
        return chosen

    def release_url(self, url: str) -> None:
        """Decrement the active count of ``url`` (never below 0); unknown URLs are ignored."""
        with self._url_lock:
            active = self._url_active_counts.get(url)
            if active is not None:
                self._url_active_counts[url] = max(0, active - 1)

    def calculate_cost(self, input_tokens: int, output_tokens: int) -> float:
        """Return token counts times their configured prices, or 0.0 without pricing."""
        pricing = self.pricing
        if not pricing:
            return 0.0
        input_cost = input_tokens * pricing.input
        output_cost = output_tokens * pricing.output
        return input_cost + output_cost
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
class ModelConfigRegistry:
    """In-memory registry that maps model names to ``ModelConfig`` objects."""

    __slots__ = ("model_dict",)

    def __init__(self):
        """Start with an empty registry."""
        self.model_dict: Dict[str, ModelConfig] = {}

    def load_from_json(self, config_path: str, mode: Literal["reload", "merge"] = "reload") -> bool:
        """Load model configurations from a JSON file into the registry.

        The file must contain a JSON object whose keys are model names and
        whose values are keyword arguments for ``ModelConfig``.

        Args:
            config_path: Path of the JSON file.
            mode: ``"reload"`` replaces the current content with the file's
                content; ``"merge"`` keeps existing entries and overwrites
                those whose name also appears in the file.

        Returns:
            True on success (failures are reported by raising).

        Raises:
            FileNotFoundError: if ``config_path`` is not an existing file.
            json.JSONDecodeError: if the file is not valid JSON.

        All entries are validated into a staging dict first and committed only
        when every one of them succeeded. A missing file, malformed JSON or an
        invalid entry therefore leaves the registry exactly as it was, instead
        of emptied or half-loaded.
        """
        path = Path(config_path)
        logger.info(f"正在加载模型配置: '{path}'")

        if not path.is_file():
            raise FileNotFoundError(f"模型配置文件不存在或不是文件: '{path}'")

        with open(path, mode="r", encoding="utf-8") as f:
            try:
                raw_data: Dict[str, Dict[str, Any]] = json.load(f)
            except json.JSONDecodeError as e:
                logger.error(f"JSON 格式错误: {e}")
                raise

        # Build everything off to the side; nothing is visible until commit.
        staged: Dict[str, ModelConfig] = {}
        for model_name, info in raw_data.items():
            if mode == "merge" and model_name in self.model_dict:
                logger.debug(f"模型 '{model_name}' 已存在,将被新配置覆盖。")
            staged[model_name] = ModelConfig(name=model_name, **info)

        # Commit in place so existing references to model_dict stay valid.
        if mode == "reload":
            self.model_dict.clear()
        self.model_dict.update(staged)

        logger.info(f"共加载 {len(self.model_dict)} 个模型配置。")

        return True

    def get(self, name: str) -> ModelConfig:
        """Return the configuration registered under ``name``.

        Raises:
            KeyError: if ``name`` is not registered; the message lists the
                registered model names.
        """
        try:
            return self.model_dict[name]
        except KeyError:
            available_models: str = ", ".join(self.model_dict)
            raise KeyError(f"模型 '{name}' 未注册。可用模型: [{available_models}]") from None
|