llm-engine-kitty 0.1.0.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,408 @@
1
+ # llm_engine/kitty/server.py
2
+
3
+ """
4
+ KittyServer:在子进程中 driving 事件循环 + 接受 socket 连接的服务端。
5
+
6
+ - 一个进程内只有一个 _KittyEngine 实例被多个连接共享
7
+ - 每个连接维护自己的 _ConnectionContext,包含其携带的 overrides 以及它提交的 task_id 集合
8
+ - 帧级别请求→响应模型,submit 立即返回 snapshot,wait 在服务端阻塞 done_event
9
+ - 连接断开时:持有的未完成任务被取消,未持久化的已完成结果被回收
10
+ """
11
+
12
+ import asyncio
13
+ import signal
14
+ import socket
15
+ import sys
16
+ from dataclasses import dataclass, field
17
+ from typing import Any, Optional
18
+
19
+ import kitty_logger
20
+
21
+ from ..schemas import InferenceRequest, TaskStatus
22
+ from ..utils import gen_unique_id
23
+ from .config import KittyEngineConfig, KittyEngineOverrides
24
+ from .protocol import MsgType, aread_frame, encode_frame, parse_endpoint
25
+ from .engine import _KittyEngine
26
+
27
+ logger = kitty_logger.getLogger(__name__)
28
+
29
+
30
@dataclass
class _ConnectionContext:
    """Mutable state that KittyServer keeps for one client connection."""

    # Server-generated id of the connection; echoed back to the client as
    # "conn_id" in the MSG_WELCOME reply to CMD_HELLO.
    connection_id: str
    # Overrides sent by the client in CMD_HELLO and forwarded to every
    # engine.submit / engine.mock_request made over this connection.
    # Stays None until the client sends non-empty overrides.
    overrides: Optional[KittyEngineOverrides] = None
    # Ids of tasks submitted over this connection that have not yet been
    # consumed (CMD_POP / CMD_WAIT) or cleared (CMD_CLEAR_DONE).
    task_ids: set[str] = field(default_factory=set)
35
+
36
+
37
class KittyServer:
    """Socket server that runs one shared ``_KittyEngine`` inside an event loop.

    - A single engine instance is shared by every connection.
    - Each connection has its own ``_ConnectionContext`` (the overrides it sent
      and the ids of the tasks it submitted) plus a per-connection status
      counter maintained from the engine's transition callback.
    - Frame-level request/response: CMD_SUBMIT replies with a snapshot right
      away, CMD_WAIT blocks server-side in ``engine.wait_task``.
    - When a connection goes away, its remaining tasks are handed to
      ``engine.cleanup_tasks``.
    """

    def __init__(self, config: KittyEngineConfig) -> None:
        self.config = config
        self.engine = _KittyEngine(config)
        self._server: Optional[asyncio.AbstractServer] = None
        self._stop_event: Optional[asyncio.Event] = None
        self._endpoint_kind: Optional[str] = None
        self._endpoint_addr: Optional[tuple] = None
        # Reverse map task_id -> owning connection_id. Maintained as follows:
        # - registered in CMD_SUBMIT *before* engine.submit, because submit fires
        #   the first transition callback synchronously and the initial event
        #   would otherwise find no connection and be dropped;
        # - removed when CMD_POP / CMD_WAIT actually consume the result;
        # - removed in bulk for the ids returned by engine.clear_done();
        # - whatever is left for a connection is removed in _handle_client's finally.
        self._task_to_connection: dict[str, str] = {}
        # Per-connection status counters (cumulative: POP/WAIT do not decrement).
        # Keys are TaskStatus members; they are turned into names when replying.
        self._connection_statistics: dict[str, dict[TaskStatus, int]] = {}
        # Handler tasks of the currently open connections. Shutdown cancels and
        # awaits them: on Python 3.12+ Server.wait_closed() (run by
        # `async with server`) waits for every open connection to finish.
        self._client_tasks: set[asyncio.Task] = set()
        self.engine.set_transition_callback(self._on_task_transition)

    # ------------------------------------------------------------------
    # Status statistics (callback)
    # ------------------------------------------------------------------

    def _on_task_transition(
        self,
        task_id: str,
        old: Optional[TaskStatus],
        new: TaskStatus,
    ) -> None:
        """Engine -> server status callback; maintains per-connection counters.

        Semantics:
        - ``old is None``: first appearance of the task (fired from submit);
          only ``new`` is incremented.
        - Regular transition: ``old`` -1, ``new`` +1, so non-terminal counters
          mean "tasks currently in this state". Terminal states are never left,
          so their counters only grow.

        Robustness:
        - Unknown task or connection -> quiet no-op (debug log). The expected
          hit is the last round of transitions produced by ``cleanup_tasks``
          after ``_handle_client``'s finally already unlinked the connection;
          a warning there would fire once per unfinished task.
        - The decrement of ``old`` is clamped at 0: this callback must never
          raise and never corrupt the counters.
        - Invoked synchronously from the engine's transition path
          (``_WorkerTask.transition``), so it must return quickly.
        """
        connection_id: Optional[str] = self._task_to_connection.get(task_id)
        if connection_id is None:
            logger.debug("KittyServer: can not find conn_id for task %s (likely cleanup path)", task_id)
            return
        statistics = self._connection_statistics.get(connection_id)
        if statistics is None:
            logger.debug("KittyServer: can not find statistics for connection %s (likely cleanup path)", connection_id)
            return
        if old is not None:
            statistics[old] = max(0, statistics.get(old, 0) - 1)
        statistics[new] = statistics.get(new, 0) + 1

    # ------------------------------------------------------------------
    # Entry point
    # ------------------------------------------------------------------

    async def run(self, mode: str = "standalone", auto_port: bool = False) -> None:
        """Start the server and block until a stop is requested.

        Args:
            mode: ``"standalone"`` (launched from the command line) or
                ``"embedded"`` (launched as a child process; stops automatically
                when stdin reaches EOF).
            auto_port: TCP only. When True and the configured port cannot be
                bound, an OS-assigned free port is used instead (and
                ``self.config.listen`` is updated); when False the bind error
                is raised.

        ``engine.teardown()`` always runs once ``engine.setup()`` succeeded,
        even if parsing the endpoint or binding the socket fails.
        """
        import os

        self._stop_event = asyncio.Event()
        await self.engine.setup()

        kind: Optional[str] = None
        addr: Optional[tuple] = None
        try:
            kind, addr = parse_endpoint(self.config.listen)
            self._endpoint_kind = kind
            self._endpoint_addr = addr
            await self._start_listening(kind, addr, auto_port)

            # Signals can only be registered from the main thread (a spawned
            # child process runs this in its main thread). Unsupported platforms
            # raise NotImplementedError; other threads raise RuntimeError.
            loop = asyncio.get_running_loop()
            for sig in (signal.SIGINT, signal.SIGTERM):
                try:
                    loop.add_signal_handler(sig, self._stop_event.set)
                except (NotImplementedError, RuntimeError):
                    pass

            async with self._server:
                await self._serve_until_stopped(mode)
        finally:
            await self.engine.teardown()
            # Remove the unix socket file, but only if we actually bound it.
            if kind == "unix" and self._server is not None:
                try:
                    os.unlink(addr[0])
                except FileNotFoundError:
                    pass
        logger.info("KittyServer 已停止")

    async def _start_listening(self, kind: str, addr: tuple, auto_port: bool) -> None:
        """Bind the endpoint described by (kind, addr) and store it in ``self._server``."""
        if kind == "unix":
            import os

            path = addr[0]
            try:
                os.unlink(path)  # drop a stale socket file from a previous run
            except FileNotFoundError:
                pass
            self._server = await asyncio.start_unix_server(self._handle_client, path=path)
            os.chmod(path, 0o600)
            logger.info("KittyServer listening on unix://%s", path)
            return

        host, port = addr
        while True:
            try:
                self._server = await asyncio.start_server(self._handle_client, host=host, port=port)
                break
            except OSError:
                if not auto_port:
                    raise
                # Port unavailable: let the OS pick a free port and retry with it.
                with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                    s.bind((host, 0))
                    port = s.getsockname()[1]
                # Keep config.listen in sync so callers can query the real port.
                self.config = self.config.model_copy(update={"listen": f"tcp://{host}:{port}"})
        logger.info("KittyServer listening on tcp://%s:%d", host, port)

    async def _serve_until_stopped(self, mode: str) -> None:
        """Serve until stop is requested, serve_forever ends, or (embedded) stdin hits EOF.

        On the way out: stop accepting, close every open connection (running
        its per-connection cleanup), then cancel and await the helper tasks.
        """
        tasks = {
            asyncio.create_task(self._stop_event.wait()),
            asyncio.create_task(self._server.serve_forever()),
        }
        if mode == "embedded":
            tasks.add(asyncio.create_task(self._watch_stdin()))
        done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)

        for t in done:
            # Task.exception() *raises* CancelledError for a cancelled task.
            if t.cancelled():
                continue
            exc = t.exception()
            if exc is not None:
                logger.error("server task 异常: %r", exc)

        # Stop accepting first so no new connection slips in during shutdown.
        self._server.close()
        await self._close_clients()
        for t in pending:
            t.cancel()
        # Await the cancellations so no task is left pending/destroyed.
        await asyncio.gather(*pending, return_exceptions=True)

    async def _close_clients(self) -> None:
        """Cancel every open connection's handler task and wait for it to finish."""
        handlers = [t for t in self._client_tasks if not t.done()]
        for t in handlers:
            t.cancel()
        if handlers:
            await asyncio.gather(*handlers, return_exceptions=True)

    def stop(self) -> None:
        """Request shutdown (CMD_SHUTDOWN, embedded stdin EOF); no-op before run()."""
        if self._stop_event is not None:
            self._stop_event.set()

    async def _watch_stdin(self) -> None:
        """Embedded mode: stop the server once stdin reaches EOF (parent exited)."""
        loop = asyncio.get_running_loop()
        reader = asyncio.StreamReader()
        protocol = asyncio.StreamReaderProtocol(reader)
        # Use the binary sys.stdin.buffer: the text layer's buffering could hide
        # bytes that already reached the fd from the asyncio transport, and it
        # keeps the fd position consistent with __main__.py having consumed the
        # first line via readline.
        transport, _ = await loop.connect_read_pipe(lambda: protocol, sys.stdin.buffer)
        try:
            while await reader.read(4096):
                pass  # content is irrelevant; only EOF matters
        except Exception as exc:
            # A broken stdin pipe is treated like EOF: the parent is most likely gone.
            logger.debug("KittyServer: stdin read failed (%r); treating as EOF", exc)
        finally:
            transport.close()
        logger.info("KittyServer: stdin EOF,embedded 模式自动退出")
        self.stop()

    # ------------------------------------------------------------------
    # Connection handling
    # ------------------------------------------------------------------

    async def _handle_client(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
        """Serve one connection until EOF, a read/send error, BYE/SHUTDOWN, or cancellation."""
        handler_task = asyncio.current_task()
        if handler_task is not None:
            self._client_tasks.add(handler_task)
        ctx = _ConnectionContext(connection_id=gen_unique_id(prefix="conn"))
        # Every TaskStatus starts explicitly at 0 so the client can index any
        # status directly (e.g. to render a progress bar).
        self._connection_statistics[ctx.connection_id] = {s: 0 for s in TaskStatus}
        peer = writer.get_extra_info("peername") or writer.get_extra_info("sockname")
        logger.info("连接建立 connection_id=%s peer=%s", ctx.connection_id, peer)

        try:
            while True:
                try:
                    msg_type, payload = await aread_frame(reader)
                except (asyncio.IncompleteReadError, ConnectionError):
                    break
                except Exception as e:
                    logger.warning("connection_id=%s 读帧异常: %r", ctx.connection_id, e)
                    break

                try:
                    should_close = await self._dispatch(ctx, msg_type, payload, writer)
                except Exception as e:
                    logger.exception("connection_id=%s 处理 msg_type=%d 异常", ctx.connection_id, msg_type)
                    try:
                        await self._send(writer, MsgType.MSG_ERROR, {"error": f"{type(e).__name__}: {e}"})
                    except Exception:
                        # Even the error reply failed: the connection is unusable.
                        break
                    should_close = False

                if should_close:
                    break
        finally:
            # Unlink task->connection and drop the counters first, so the
            # transition callback becomes a no-op for the leftover tasks and
            # cleanup_tasks does not update counters about to be discarded.
            leftover = list(ctx.task_ids)
            for task_id in leftover:
                self._task_to_connection.pop(task_id, None)
            self._connection_statistics.pop(ctx.connection_id, None)
            # Hand this connection's remaining tasks to the engine for cleanup.
            if leftover:
                try:
                    await self.engine.cleanup_tasks(leftover)
                except Exception:
                    logger.exception("connection_id=%s cleanup_tasks 异常", ctx.connection_id)
            try:
                writer.close()
                await writer.wait_closed()
            except Exception:
                pass  # best effort: the peer may already be gone
            if handler_task is not None:
                self._client_tasks.discard(handler_task)
            logger.info("连接关闭 connection_id=%s", ctx.connection_id)

    async def _send(self, writer: asyncio.StreamWriter, msg_type: int, payload: dict[str, Any]) -> None:
        """Encode one frame, write it, and wait for the transport buffer to drain."""
        writer.write(encode_frame(msg_type, payload))
        await writer.drain()

    # ------------------------------------------------------------------
    # Message dispatch
    # ------------------------------------------------------------------

    async def _dispatch(
        self,
        ctx: _ConnectionContext,
        msg_type: int,
        payload: dict[str, Any],
        writer: asyncio.StreamWriter,
    ) -> bool:
        """Handle one frame and send its reply; return True if the connection should close."""

        if msg_type == MsgType.CMD_HELLO:
            ov_dict = payload.get("overrides")
            if ov_dict:
                ctx.overrides = KittyEngineOverrides.model_validate(ov_dict)
            await self._send(writer, MsgType.MSG_WELCOME, {"conn_id": ctx.connection_id})  # protocol field name; do not rename
            return False

        if msg_type == MsgType.CMD_PING:
            await self._send(writer, MsgType.MSG_PONG, {})
            return False

        if msg_type == MsgType.CMD_SUBMIT:
            task_id = payload.get("task_id") or gen_unique_id(prefix="task")
            request = InferenceRequest.model_validate(payload["request"])
            persist = bool(payload.get("persist", False))
            timeout = payload.get("timeout")
            # Register task -> connection BEFORE engine.submit: submit fires the
            # first transition callback synchronously and that callback must be
            # able to find the owning connection, or the initial event is lost.
            previous_owner = self._task_to_connection.get(task_id)
            self._task_to_connection[task_id] = ctx.connection_id
            try:
                snap = self.engine.submit(
                    task_id=task_id,
                    request=request,
                    overrides=ctx.overrides,
                    persist=persist,
                    timeout=timeout,
                )
            except Exception:
                # Roll back the registration. If the id was already owned (e.g. a
                # reused client-supplied task_id the engine rejected), restore the
                # previous owner instead of dropping its mapping.
                if previous_owner is None:
                    self._task_to_connection.pop(task_id, None)
                else:
                    self._task_to_connection[task_id] = previous_owner
                raise
            ctx.task_ids.add(task_id)
            await self._send(writer, MsgType.MSG_SNAPSHOT, {"snapshot": snap.model_dump()})
            return False

        if msg_type == MsgType.CMD_CANCEL:
            task_id = payload["task_id"]
            force = bool(payload.get("force", False))
            cancelled = self.engine.cancel(task_id, force=force)
            await self._send(writer, MsgType.MSG_PONG, {"cancelled": cancelled})
            return False

        if msg_type == MsgType.CMD_POLL:
            task_id = payload["task_id"]
            snap = self.engine.snapshot(task_id)
            await self._send(
                writer,
                MsgType.MSG_SNAPSHOT,
                {"snapshot": snap.model_dump() if snap is not None else None},
            )
            return False

        if msg_type == MsgType.CMD_POP:
            task_id = payload["task_id"]
            result = self.engine.pop_result(task_id)
            if result is not None:
                # Result consumed: the task no longer belongs to any connection.
                ctx.task_ids.discard(task_id)
                self._task_to_connection.pop(task_id, None)
            await self._send(
                writer,
                MsgType.MSG_RESULT,
                {
                    "found": result is not None,
                    "result": result.model_dump() if result is not None else None,
                    "timed_out": False,
                },
            )
            return False

        if msg_type == MsgType.CMD_WAIT:
            task_id = payload["task_id"]
            timeout = payload.get("timeout")
            result, timed_out = await self.engine.wait_task(task_id, timeout)
            if result is not None:
                ctx.task_ids.discard(task_id)
                self._task_to_connection.pop(task_id, None)
            await self._send(
                writer,
                MsgType.MSG_RESULT,
                {
                    "found": result is not None,
                    "result": result.model_dump() if result is not None else None,
                    "timed_out": timed_out,
                },
            )
            return False

        if msg_type == MsgType.CMD_CLEAR_DONE:
            # NOTE(review): the original comment says clear_done() GCs *all*
            # terminal tasks in the engine, which would include other
            # connections' unconsumed results — confirm this is intended.
            cleared_ids = self.engine.clear_done()
            for task_id in cleared_ids:
                self._task_to_connection.pop(task_id, None)
            if cleared_ids:
                ctx.task_ids.difference_update(cleared_ids)
            # Send cleared_ids back so the client can sync its local handles by
            # id. The client cannot rely on its local snapshot.status: a task it
            # never polled may still look RUNNING/PENDING locally although the
            # server already removed it.
            await self._send(
                writer,
                MsgType.MSG_PONG,
                {"count": len(cleared_ids), "cleared_ids": cleared_ids},
            )
            return False

        if msg_type == MsgType.CMD_STATS:
            # Cumulative distribution for this connection: terminal counters
            # (SUCCESS/FAILED/TIMEOUT/CANCELLED/REJECTED) are not decremented by
            # POP/WAIT, which suits client-side progress bars.
            statistics = self._connection_statistics.get(ctx.connection_id, {})
            counts = {s.name: int(statistics.get(s, 0)) for s in TaskStatus}
            finished = sum(
                counts.get(name, 0) for name in ("SUCCESS", "FAILED", "TIMEOUT", "CANCELLED", "REJECTED")
            )
            total = sum(counts.values())
            await self._send(
                writer,
                MsgType.MSG_STATS,
                {"total": total, "finished": finished, "counts": counts},
            )
            return False

        if msg_type == MsgType.CMD_BYE:
            return True

        if msg_type == MsgType.CMD_MOCK_REQUEST:
            request = InferenceRequest.model_validate(payload["request"])
            mr = self.engine.mock_request(request, overrides=ctx.overrides)
            await self._send(writer, MsgType.MSG_PREPARED_REQUEST, dict(mr))
            return False

        if msg_type == MsgType.CMD_SHUTDOWN:
            await self._send(writer, MsgType.MSG_PONG, {})
            self.stop()
            return True

        await self._send(writer, MsgType.MSG_ERROR, {"error": f"未知 msg_type={msg_type}"})
        return False

    # (the wait_task logic lives in _KittyEngine.wait_task)
@@ -0,0 +1,112 @@
1
+ # llm_engine/model_config.py
2
+
3
+ import json
4
+ import threading
5
+
6
+ from pathlib import Path
7
+ from pydantic import BaseModel, ConfigDict, Field, PrivateAttr
8
+ from typing import Any, Dict, List, Literal, Optional
9
+
10
+ import kitty_logger
11
+
12
+ from .schemas import InferenceParameters
13
+
14
+ logger = kitty_logger.getLogger(__name__)
15
+
16
+
17
class ModelPricing(BaseModel):
    """Price coefficients used by ``ModelConfig.calculate_cost``.

    The cost is ``input_tokens * input + output_tokens * output``. The currency
    and the token unit (per token vs. per 1K/1M tokens) are not visible here —
    TODO confirm against the pricing source.
    """

    # Coefficient multiplied by the number of input tokens.
    input: float
    # Coefficient multiplied by the number of output tokens.
    output: float
21
+
22
+
23
class ModelConfig(BaseModel):
    """Configuration of one registered model plus per-URL request accounting.

    Besides the declarative fields, each instance counts the in-flight requests
    per URL (``get_url`` / ``release_url``) so callers can spread load over
    ``api_urls``.
    """

    # populate_by_name: allow construction by field name as well as by alias
    # (e.g. ``default_inference_parameters`` or ``inference_parameters``).
    # protected_namespaces=(): the ``model_id`` field otherwise triggers
    # Pydantic's "conflict with protected namespace 'model_'" UserWarning at
    # import time (the default namespace on Pydantic < 2.10).
    model_config = ConfigDict(populate_by_name=True, protected_namespaces=())

    name: str
    model_id: str = Field(alias="model_id")
    api_urls: List[str] = Field(alias="api_urls")
    api_key: Optional[str] = Field(None, alias="api_key")
    platform: Optional[str] = None

    pricing: Optional[ModelPricing] = None
    # Populated from the "inference_parameters" key of the config.
    default_inference_parameters: Optional[InferenceParameters] = Field(default=None, alias="inference_parameters")
    extra_headers: Dict[str, str] = Field(default_factory=dict)
    extra_payload: Dict[str, Any] = Field(default_factory=dict)

    # --- Private runtime state (not part of the schema) ---
    # In-flight request count per URL; filled in by model_post_init.
    _url_active_counts: Dict[str, int] = PrivateAttr()
    # Serializes every read/update of _url_active_counts.
    _url_lock: threading.Lock = PrivateAttr(default_factory=threading.Lock)

    def model_post_init(self, __context: Any) -> None:
        """Pydantic post-init hook: check ``api_urls`` and zero one counter per URL.

        Raises:
            ValueError: if ``api_urls`` is empty.
        """
        if not self.api_urls:
            raise ValueError(f"模型 '{self.name}' 的 api_urls 不能为空")
        # Duplicate URLs collapse into a single counter.
        self._url_active_counts = {url: 0 for url in self.api_urls}

    def get_url(self) -> str:
        """Reserve and return the URL with the fewest in-flight requests.

        The chosen URL's counter is incremented; pair every call with
        ``release_url``. Ties go to the URL listed first in ``api_urls``.
        """
        with self._url_lock:
            url = min(self._url_active_counts, key=lambda u: self._url_active_counts[u])
            self._url_active_counts[url] += 1
            return url

    def release_url(self, url: str) -> None:
        """Decrement ``url``'s in-flight counter (never below 0); unknown URLs are ignored."""
        with self._url_lock:
            if url in self._url_active_counts:
                self._url_active_counts[url] = max(0, self._url_active_counts[url] - 1)

    def calculate_cost(self, input_tokens: int, output_tokens: int) -> float:
        """Return ``input_tokens * pricing.input + output_tokens * pricing.output``.

        Returns 0.0 when no pricing is configured.
        """
        if self.pricing is None:
            return 0.0
        return input_tokens * self.pricing.input + output_tokens * self.pricing.output
64
+
65
+
66
class ModelConfigRegistry:
    """Registry of ``ModelConfig`` objects keyed by model name, loaded from JSON."""

    __slots__ = ("model_dict",)

    def __init__(self):
        # Registered configs keyed by model name (the top-level JSON keys).
        self.model_dict: Dict[str, ModelConfig] = {}

    def load_from_json(self, config_path: str, mode: Literal["reload", "merge"] = "reload") -> bool:
        """Load model configs from a JSON file.

        The file must hold a JSON object mapping model name -> ModelConfig
        fields; the key is passed to ModelConfig as ``name``.

        The whole file is parsed and every entry validated *before* the
        registry is modified, so a failed load leaves the previously registered
        models untouched.

        Args:
            config_path: path to the JSON file.
            mode: ``"reload"`` replaces every registered model; ``"merge"`` adds
                the file's models and overwrites same-named ones, keeping the rest.

        Returns:
            True once the models have been registered.

        Raises:
            FileNotFoundError: the path does not exist or is not a regular file.
            json.JSONDecodeError: the file is not valid JSON.
            ValueError: the top-level JSON value is not an object.
            Any error raised while constructing a ModelConfig entry propagates.
        """
        path = Path(config_path)
        logger.info("正在加载模型配置: '%s'", path)

        if not path.is_file():
            raise FileNotFoundError(f"模型配置文件不存在或不是文件: '{path}'")

        with open(path, mode="r", encoding="utf-8") as f:
            try:
                raw_data: Dict[str, Dict[str, Any]] = json.load(f)
            except json.JSONDecodeError as e:
                logger.error("JSON 格式错误: %s", e)
                raise

        if not isinstance(raw_data, dict):
            raise ValueError(f"模型配置文件顶层必须是 JSON 对象: '{path}'")

        # Build everything first; commit only after all entries validated.
        loaded: Dict[str, ModelConfig] = {}
        for model_name, info in raw_data.items():
            if mode == "merge" and model_name in self.model_dict:
                logger.debug("模型 '%s' 已存在,将被新配置覆盖。", model_name)
            loaded[model_name] = ModelConfig(name=model_name, **info)

        if mode == "reload":
            self.model_dict.clear()
        self.model_dict.update(loaded)

        logger.info("共加载 %d 个模型配置。", len(self.model_dict))

        return True

    def get(self, name: str) -> ModelConfig:
        """Return the config registered under ``name``.

        Raises:
            KeyError: ``name`` is not registered; the message lists the available models.
        """
        try:
            return self.model_dict[name]
        except KeyError:
            available_models = ", ".join(self.model_dict)
            raise KeyError(f"模型 '{name}' 未注册。可用模型: [{available_models}]") from None