llm-engine-kitty 0.1.0.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,213 @@
1
+ # llm_engine/kitty/protocol.py
2
+
3
+ """
4
+ KittyEngine 的线缆协议。
5
+
6
+ 帧格式:
7
+ [4 bytes BE uint32: payload_len (含 msg_type 字节)]
8
+ [1 byte: msg_type]
9
+ [payload_len - 1 bytes: msgpack 编码的 dict]
10
+
11
+ 序列化:
12
+ msgpack(安全、紧凑;对比 pickle 不支持任意对象反序列化,消除 RCE 风险)。
13
+ 所有 payload 均由 pydantic .model_dump() 产出,已是纯 dict/list/scalar 结构。
14
+
15
+ 传输:
16
+ UDS (AF_UNIX + SOCK_STREAM) 或 TCP (AF_INET + SOCK_STREAM)。
17
+ """
18
+
19
+ import socket
20
+ import struct
21
+ import asyncio
22
+ from enum import IntEnum
23
+ from typing import Any
24
+ from urllib.parse import urlparse
25
+
26
+ import msgpack
27
+
28
# Upper bound for the frame length prefix (msg_type byte + msgpack body);
# both the encoder and the frame readers reject anything larger.
MAX_FRAME_SIZE = 16 * 1024 * 1024  # 16 MB
# Size in bytes of the big-endian uint32 length header that starts each frame.
_HEADER_LEN = 4
30
+
31
+
32
class MsgType(IntEnum):
    """One-byte message type tag carried in every frame.

    Values 1-12 are commands sent by the client; values 101-107 are
    messages sent by the server.
    """

    # client → server
    CMD_HELLO = 1
    CMD_SUBMIT = 2
    CMD_CANCEL = 3
    CMD_POLL = 4
    CMD_POP = 5
    CMD_CLEAR_DONE = 6
    CMD_PING = 7
    CMD_BYE = 8
    CMD_SHUTDOWN = 9  # only used by embedded-mode clients: tells the server to exit entirely
    CMD_WAIT = 10  # block until the task reaches a terminal state
    CMD_STATS = 11  # query the status distribution of tasks submitted on this connection (for progress bars)
    CMD_MOCK_REQUEST = 12  # assemble the request without sending it; returns a PreparedRequest

    # server → client
    MSG_WELCOME = 101
    MSG_RESULT = 102
    MSG_SNAPSHOT = 103
    MSG_PONG = 104
    MSG_ERROR = 105
    MSG_STATS = 106
    MSG_PREPARED_REQUEST = 107
55
+
56
+
57
def encode_frame(msg_type: int, payload: dict[str, Any]) -> bytes:
    """Serialize a message type and its payload into one wire frame.

    Args:
        msg_type (int): Message type, normally a ``MsgType`` member.
        payload (dict[str, Any]): Message body; must be msgpack-serializable.

    Returns:
        bytes: 4-byte big-endian length header, then 1 message-type byte,
        then the msgpack-encoded payload. The length counts the type byte.

    Raises:
        ValueError: If msgpack yields ``None`` or the framed length exceeds
            ``MAX_FRAME_SIZE``.
    """
    # use_bin_type=True (the default since msgpack 1.0):
    # str -> msgpack str, bytes -> msgpack bin.
    packed = msgpack.packb(payload, use_bin_type=True)
    if packed is None:
        raise ValueError("msgpack.packb returned None")

    # The length prefix covers the type byte as well as the packed payload.
    framed_len = len(packed) + 1
    if framed_len > MAX_FRAME_SIZE:
        raise ValueError(f"帧超限: {framed_len} > {MAX_FRAME_SIZE}")

    length_prefix = struct.pack(">I", framed_len)
    type_byte = bytes([int(msg_type)])
    return b"".join((length_prefix, type_byte, packed))
75
+
76
+
77
def _decode_body(body: bytes) -> dict[str, Any]:
    """Decode a msgpack body and require the top-level value to be a dict.

    Raises:
        ValueError: If the decoded value is not a ``dict``.
    """
    # raw=False decodes msgpack str into Python str; strict_map_key=False
    # tolerates non-string map keys (this side never emits them).
    decoded = msgpack.unpackb(body, raw=False, strict_map_key=False)
    if isinstance(decoded, dict):
        return decoded
    raise ValueError(f"非法 payload 类型: {type(decoded).__name__}")
83
+
84
+
85
def recv_exact(sock: socket.socket, n: int) -> bytes:
    """Read exactly ``n`` bytes from ``sock``.

    Keeps calling ``recv`` until ``n`` bytes have been collected. A
    non-positive ``n`` returns ``b""`` without touching the socket.

    Raises:
        ConnectionError: If ``recv`` returns no data before ``n`` bytes arrived.
    """
    collected = bytearray()
    while (missing := n - len(collected)) > 0:
        piece = sock.recv(missing)
        if not piece:
            raise ConnectionError("socket EOF")
        collected += piece
    return bytes(collected)
96
+
97
+
98
async def aread_frame(reader: "asyncio.StreamReader") -> tuple[int, dict[str, Any]]:
    """Read one frame from an ``asyncio.StreamReader``.

    Returns:
        tuple[int, dict[str, Any]]: ``(msg_type, payload)``, where
        ``msg_type`` is the first byte after the length header.

    Raises:
        ValueError: If the announced length is zero or above ``MAX_FRAME_SIZE``.
        Errors raised by ``readexactly`` or by payload decoding propagate.
    """
    raw_len = await reader.readexactly(_HEADER_LEN)
    frame_len = struct.unpack(">I", raw_len)[0]
    if not (0 < frame_len <= MAX_FRAME_SIZE):
        raise ValueError(f"非法帧长: {frame_len}")

    frame = await reader.readexactly(frame_len)
    # First byte is the message type; the rest is the msgpack payload.
    return frame[0], _decode_body(frame[1:])
107
+
108
+
109
def recv_frame(sock: socket.socket) -> tuple[int, dict[str, Any]]:
    """Read one frame from a blocking socket.

    Returns:
        tuple[int, dict[str, Any]]: ``(msg_type, payload)``.

    Raises:
        ValueError: If the announced length is zero or above ``MAX_FRAME_SIZE``.
        Errors raised by ``recv_exact`` or by payload decoding propagate.
    """
    (frame_len,) = struct.unpack(">I", recv_exact(sock, _HEADER_LEN))
    length_ok = 0 < frame_len <= MAX_FRAME_SIZE
    if not length_ok:
        raise ValueError(f"非法帧长: {frame_len}")

    frame = recv_exact(sock, frame_len)
    # Split the leading type byte from the msgpack payload.
    type_byte, encoded_payload = frame[0], frame[1:]
    return type_byte, _decode_body(encoded_payload)
118
+
119
+
120
+ # —— Endpoint 解析 —— #
121
+
122
+
123
def parse_endpoint(endpoint: str) -> tuple[str, tuple]:
    """Parse an endpoint string into ``(kind, addr)``.

    Examples:
        ``"unix:///path/to.sock"`` -> ``("unix", ("/path/to.sock",))``
        ``"tcp://host:port"``      -> ``("tcp", (host, port))``

    Args:
        endpoint (str): Endpoint URL using the ``unix`` or ``tcp`` scheme.

    Returns:
        tuple[str, tuple]: ``kind`` is ``"unix"`` or ``"tcp"``. ``addr`` is a
        1-tuple holding the socket path for unix, or ``(host, port)`` for
        tcp. A tcp endpoint without a host falls back to ``"127.0.0.1"``.

    Raises:
        ValueError: If the scheme is unsupported, a unix endpoint has no path
            or carries a host part, or a tcp endpoint has no port (or an
            invalid one, as reported by ``urlparse``).
    """
    parsed = urlparse(endpoint)
    if parsed.scheme == "unix":
        # urlparse("unix:///tmp/x.sock") -> netloc="", path="/tmp/x.sock".
        # With only two slashes ("unix://tmp/x.sock") the first path segment
        # lands in netloc and path becomes "/x.sock". Using that path would
        # silently bind to (and unlink) the wrong file, so reject it.
        # "localhost" stays accepted for URL-style endpoints.
        if parsed.netloc and parsed.hostname != "localhost":
            raise ValueError(
                f"unix endpoint 不应包含主机部分 {parsed.netloc!r}"
                f"(应写作 unix:///path/to.sock): {endpoint}"
            )
        path = parsed.path
        if not path:
            raise ValueError(f"unix endpoint 缺路径: {endpoint}")
        return "unix", (path,)
    if parsed.scheme == "tcp":
        host = parsed.hostname or "127.0.0.1"
        port = parsed.port
        if port is None:
            raise ValueError(f"tcp endpoint 缺端口: {endpoint}")
        return "tcp", (host, port)
    raise ValueError(f"不支持的 endpoint scheme: {parsed.scheme}")
143
+
144
+
145
def create_listen_socket(endpoint: str, backlog: int = 128) -> socket.socket:
    """Create a listening socket for ``endpoint``.

    For a unix endpoint, any existing file at the socket path is unlinked
    first, and the new socket file is chmod-ed to ``0o600`` after binding.
    For a tcp endpoint, ``SO_REUSEADDR`` is set before binding.

    Args:
        endpoint (str): Endpoint string accepted by ``parse_endpoint``.
        backlog (int): Backlog passed to ``listen``.

    Returns:
        socket.socket: A bound socket that is already listening.

    Raises:
        ValueError: Propagated from ``parse_endpoint`` for a bad endpoint.
        OSError: If unlink, bind, chmod, setsockopt or listen fails. The
            socket is closed before the error propagates.
    """
    kind, addr = parse_endpoint(endpoint)
    if kind == "unix":
        import os

        path = addr[0]
        try:
            os.unlink(path)
        except FileNotFoundError:
            pass
        sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        try:
            sock.bind(path)
            # NOTE(review): between bind() and chmod() the socket file has
            # umask-derived permissions; acceptable only if the parent
            # directory is already private.
            os.chmod(path, 0o600)
            sock.listen(backlog)
        except BaseException:
            # Do not leak the fd when setup fails part-way.
            sock.close()
            raise
        return sock

    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        sock.bind(addr)
        sock.listen(backlog)
    except BaseException:
        # Do not leak the fd when setup fails part-way.
        sock.close()
        raise
    return sock
167
+
168
+
169
def connect_endpoint(endpoint: str, connect_timeout: float = 30.0, retry_interval: float = 0.1) -> socket.socket:
    """Connect to ``endpoint``, retrying with exponential back-off.

    Retrying exists because in embedded mode the worker needs time to start.
    At least one attempt is always made. The delay between attempts starts
    at ``retry_interval``, doubles after each failure and is capped at 1.0 s.

    Args:
        endpoint (str): Endpoint string accepted by ``parse_endpoint``.
        connect_timeout (float): Total retry budget in seconds. A value
            ``<= 0`` means a single attempt with no waiting or retrying.
        retry_interval (float): Initial delay between attempts in seconds.

    Returns:
        socket.socket: A connected socket (``TCP_NODELAY`` set for tcp).

    Raises:
        ValueError: Propagated from ``parse_endpoint`` for a bad endpoint.
        TimeoutError: When the budget is exhausted (or the single attempt
            failed); the last ``OSError`` is attached as ``__cause__``.
    """
    import time

    kind, addr = parse_endpoint(endpoint)
    if kind == "unix":
        family, target = socket.AF_UNIX, addr[0]
    else:
        family, target = socket.AF_INET, addr

    def _try_once() -> socket.socket:
        sock = socket.socket(family, socket.SOCK_STREAM)
        # Close the socket on any failure; otherwise every failed attempt
        # in the retry loop would leak one fd.
        try:
            if kind != "unix":
                sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
            sock.connect(target)
        except BaseException:
            sock.close()
            raise
        return sock

    # Monotonic clock: the deadline must not move when the wall clock is
    # adjusted (NTP step, manual change) during the retry window.
    deadline = time.monotonic() + connect_timeout if connect_timeout > 0 else None
    interval = retry_interval
    attempts = 0
    last_err: Exception | None = None
    while True:
        attempts += 1
        try:
            return _try_once()
        except OSError as e:
            # FileNotFoundError and ConnectionRefusedError are OSError subclasses.
            last_err = e
        if deadline is None:
            break
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break
        # Never sleep past the deadline.
        time.sleep(min(interval, remaining))
        interval = min(interval * 2, 1.0)

    raise TimeoutError(
        f"连接 {endpoint} 超时(尝试 {attempts} 次,connect_timeout={connect_timeout}s): {last_err!r}"
    ) from last_err
@@ -0,0 +1,89 @@
1
+ # llm_engine/kitty/schemas.py
2
+
3
+ """
4
+ KittyEngine 专属数据结构:TaskSnapshot 与 WireResult。
5
+
6
+ - TaskSnapshot:跨进程可序列化的任务快照,替代 GeneralEngine 的 TaskHandle。
7
+ 不嵌 result,避免每次 poll 都回传大对象。
8
+ - WireResult:线缆上传输的瘦身 Result,不含 request 字段。客户端收到后用本地
9
+ cache 的原始 request 拼回完整 InferenceRequestResult。
10
+ - TransitionRecord:一次状态转移记录。由 worker 产生,随 snapshot 跨进程回传。
11
+ """
12
+
13
+ from datetime import datetime
14
+ from typing import Any, Optional, TypedDict
15
+
16
+ from pydantic import BaseModel, Field
17
+
18
+ from ..schemas import ModelOutput, TaskStatus
19
+
20
+
21
class TransitionRecord(BaseModel):
    """A single state-transition entry.

    - ``status``: the state being entered.
    - ``timestamp``: human-readable local time at which the state was
      entered, e.g. '2026-05-11 10:30:45.123456'.
    - ``duration``: seconds actually spent in this state; back-filled by the
      caller when the next state is entered, so the last entry holds None.
    - ``desc``: the reason this state was entered.
    """

    status: TaskStatus
    timestamp: str
    duration: Optional[float] = None
    desc: str = ""

    @property
    def epoch(self) -> float:
        """Return ``timestamp`` parsed as ISO text and converted to a POSIX timestamp."""
        entered_at = datetime.fromisoformat(self.timestamp)
        return entered_at.timestamp()
38
+
39
+
40
class TaskSnapshot(BaseModel):
    """Snapshot of a task's state, meant for cross-process transfer.

    All time-point and duration information lives in the ``transitions``
    log. No derived submit_time / start_time / end_time / duration fields
    are attached to this structure, which avoids:
      - pydantic ``@property`` values not appearing in ``model_dump()`` and
        so making serialization inconsistent;
      - the burden of keeping the same information in sync in both a
        "field" and a "derived" form.
    Callers that need the time or duration of a phase iterate over
    ``transitions`` themselves.
    """

    task_id: str
    status: TaskStatus
    transitions: list[TransitionRecord] = Field(default_factory=list)
    task_timeout: Optional[float] = None
    persist: bool = False
    error: Optional[str] = None
    has_result: bool = False

    def is_finished(self) -> bool:
        """Return True when ``status`` is one of the five terminal states."""
        terminal_states = (
            TaskStatus.SUCCESS,
            TaskStatus.FAILED,
            TaskStatus.TIMEOUT,
            TaskStatus.CANCELLED,
            TaskStatus.REJECTED,
        )
        return self.status in terminal_states
66
+
67
+
68
class WireResult(BaseModel):
    """Result as sent over the wire. It carries no request; the receiver re-attaches it.

    ``status`` has no default: callers must state the terminal status
    explicitly, so that a forgotten argument cannot turn a failure into a
    recorded SUCCESS.
    """

    success: bool
    task_id: str
    status: TaskStatus  # terminal status; the client uses it to sync snapshot.status (SUCCESS/FAILED/TIMEOUT/REJECTED/CANCELLED)
    model_output: Optional[ModelOutput] = None
    error_message: Optional[str] = None
    duration: Optional[float] = None
81
+
82
+
83
class MockRequest(TypedDict):
    """Complete request data destined for the LLM provider; for debugging only, no network call is made."""

    urls: list[str]
    model: str
    headers: dict[str, str]
    payload: dict[str, Any]