llm-engine-kitty 0.1.0.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llm_engine/__init__.py +54 -0
- llm_engine/engine.py +771 -0
- llm_engine/general_engine.py +562 -0
- llm_engine/kitty/__init__.py +8 -0
- llm_engine/kitty/__main__.py +46 -0
- llm_engine/kitty/client.py +550 -0
- llm_engine/kitty/config.py +83 -0
- llm_engine/kitty/engine.py +1077 -0
- llm_engine/kitty/protocol.py +213 -0
- llm_engine/kitty/schemas.py +89 -0
- llm_engine/kitty/server.py +408 -0
- llm_engine/model_config.py +112 -0
- llm_engine/schemas.py +251 -0
- llm_engine/utils.py +34 -0
- llm_engine_kitty-0.1.0.dev0.dist-info/METADATA +15 -0
- llm_engine_kitty-0.1.0.dev0.dist-info/RECORD +18 -0
- llm_engine_kitty-0.1.0.dev0.dist-info/WHEEL +5 -0
- llm_engine_kitty-0.1.0.dev0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,213 @@
|
|
|
1
|
+
# llm_engine/kitty/protocol.py
|
|
2
|
+
|
|
3
|
+
"""
|
|
4
|
+
KittyEngine 的线缆协议。
|
|
5
|
+
|
|
6
|
+
帧格式:
|
|
7
|
+
[4 bytes BE uint32: payload_len (含 msg_type 字节)]
|
|
8
|
+
[1 byte: msg_type]
|
|
9
|
+
[payload_len - 1 bytes: msgpack 编码的 dict]
|
|
10
|
+
|
|
11
|
+
序列化:
|
|
12
|
+
msgpack(安全、紧凑;对比 pickle 不支持任意对象反序列化,消除 RCE 风险)。
|
|
13
|
+
所有 payload 均由 pydantic .model_dump() 产出,已是纯 dict/list/scalar 结构。
|
|
14
|
+
|
|
15
|
+
传输:
|
|
16
|
+
UDS (AF_UNIX + SOCK_STREAM) 或 TCP (AF_INET + SOCK_STREAM)。
|
|
17
|
+
"""
|
|
18
|
+
|
|
19
|
+
import socket
|
|
20
|
+
import struct
|
|
21
|
+
import asyncio
|
|
22
|
+
from enum import IntEnum
|
|
23
|
+
from typing import Any
|
|
24
|
+
from urllib.parse import urlparse
|
|
25
|
+
|
|
26
|
+
import msgpack
|
|
27
|
+
|
|
28
|
+
# Upper bound for a frame's length field (msg_type byte + msgpack body).
# It is checked when encoding and when decoding, so a corrupt or hostile
# peer cannot make the reader buffer an arbitrarily large frame.
MAX_FRAME_SIZE = 16 << 20  # 16 MiB
# Byte size of the big-endian uint32 length prefix that starts every frame.
_HEADER_LEN = struct.calcsize(">I")
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class MsgType(IntEnum):
    """One-byte message-type tag written right after a frame's length prefix.

    Values 1..12 are commands sent client -> server.
    Values 101..107 are messages sent server -> client.
    """

    # ---------------- client -> server ----------------
    CMD_HELLO = 1
    CMD_SUBMIT = 2
    CMD_CANCEL = 3
    CMD_POLL = 4
    CMD_POP = 5
    CMD_CLEAR_DONE = 6
    CMD_PING = 7
    CMD_BYE = 8
    CMD_SHUTDOWN = 9  # embedded-mode clients only: tell the server to shut down entirely
    CMD_WAIT = 10  # block until a task reaches a terminal state
    CMD_STATS = 11  # status distribution of the tasks submitted on this connection (for progress bars)
    CMD_MOCK_REQUEST = 12  # assemble the request without sending it; returns a PreparedRequest

    # ---------------- server -> client ----------------
    MSG_WELCOME = 101
    MSG_RESULT = 102
    MSG_SNAPSHOT = 103
    MSG_PONG = 104
    MSG_ERROR = 105
    MSG_STATS = 106
    MSG_PREPARED_REQUEST = 107
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def encode_frame(msg_type: int, payload: dict[str, Any]) -> bytes:
    """Serialize one message into a length-prefixed frame.

    Frame layout: ``[uint32 BE length][1-byte msg_type][msgpack(payload)]``.
    The length counts the type byte plus the msgpack body.

    Args:
        msg_type: Message type tag, normally a ``MsgType`` member; must fit in one byte.
        payload: The message body; must be msgpack-serializable.

    Returns:
        The complete frame as bytes.

    Raises:
        ValueError: If msgpack returns None, the frame length would exceed
            ``MAX_FRAME_SIZE``, or ``msg_type`` is outside 0..255.
    """
    # use_bin_type=True (the msgpack>=1.0 default): str -> msgpack str, bytes -> msgpack bin.
    packed = msgpack.packb(payload, use_bin_type=True)
    if packed is None:
        raise ValueError("msgpack.packb returned None")

    frame_len = len(packed) + 1  # +1 for the msg_type byte
    if frame_len > MAX_FRAME_SIZE:
        raise ValueError(f"帧超限: {frame_len} > {MAX_FRAME_SIZE}")

    length_prefix = struct.pack(">I", frame_len)
    type_tag = bytes([int(msg_type)])
    return b"".join((length_prefix, type_tag, packed))
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def _decode_body(body: bytes) -> dict[str, Any]:
    """Decode a frame's msgpack body and require a dict at the top level.

    ``raw=False`` decodes msgpack str into Python str.
    ``strict_map_key=False`` tolerates non-string map keys; our encoder never
    produces them, but the decoder accepts them anyway.

    Raises:
        ValueError: If the decoded value is not a dict.
        Errors raised by ``msgpack.unpackb`` on malformed input propagate unchanged.
    """
    decoded = msgpack.unpackb(body, raw=False, strict_map_key=False)
    if isinstance(decoded, dict):
        return decoded
    raise ValueError(f"非法 payload 类型: {type(decoded).__name__}")
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def recv_exact(sock: socket.socket, n: int) -> bytes:
    """Read exactly ``n`` bytes from a blocking socket.

    Returns ``b""`` when ``n <= 0``.

    Raises:
        ConnectionError: If the peer closes the connection before ``n`` bytes arrive.
    """
    buf = bytearray()
    while len(buf) < n:
        piece = sock.recv(n - len(buf))
        if not piece:
            # recv() returns b"" only at EOF.
            raise ConnectionError("socket EOF")
        buf += piece
    return bytes(buf)
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
async def aread_frame(reader: "asyncio.StreamReader") -> tuple[int, dict[str, Any]]:
    """Read one frame from an asyncio stream and return ``(msg_type, payload)``.

    Raises:
        asyncio.IncompleteReadError: If the stream hits EOF before the full
            header or body has arrived (raised by ``readexactly``).
        ValueError: If the length field is 0 or exceeds ``MAX_FRAME_SIZE``,
            or the body does not decode to a dict.
    """
    raw_header = await reader.readexactly(_HEADER_LEN)
    frame_len = struct.unpack(">I", raw_header)[0]
    # A zero length leaves no room for the type byte. The upper bound keeps a
    # peer from making us buffer an arbitrarily large frame.
    if not 0 < frame_len <= MAX_FRAME_SIZE:
        raise ValueError(f"非法帧长: {frame_len}")
    frame = await reader.readexactly(frame_len)
    type_byte = frame[0]
    return type_byte, _decode_body(frame[1:])
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def recv_frame(sock: socket.socket) -> tuple[int, dict[str, Any]]:
    """Blocking read of one frame from ``sock``; returns ``(msg_type, payload)``.

    Raises:
        ConnectionError: If the socket reaches EOF mid-frame (from ``recv_exact``).
        ValueError: If the length field is 0 or exceeds ``MAX_FRAME_SIZE``,
            or the body does not decode to a dict.
    """
    (frame_len,) = struct.unpack_from(">I", recv_exact(sock, _HEADER_LEN))
    # The length is an unsigned uint32, so "<= 0" reduces to "== 0".
    if frame_len == 0 or frame_len > MAX_FRAME_SIZE:
        raise ValueError(f"非法帧长: {frame_len}")
    frame = recv_exact(sock, frame_len)
    return frame[0], _decode_body(frame[1:])
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
# —— Endpoint 解析 —— #
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
def parse_endpoint(endpoint: str) -> tuple[str, tuple]:
    """Parse an endpoint string into ``(kind, addr)``.

    Accepted forms:
        "unix:///path/to.sock"  -> ("unix", ("/path/to.sock",))
        "unix:relative.sock"    -> ("unix", ("relative.sock",))
        "tcp://host:port"       -> ("tcp", (host, port))
        "tcp://:port"           -> ("tcp", ("127.0.0.1", port))
        "tcp://[::1]:port"      -> ("tcp", ("::1", port))

    Raises:
        ValueError: The scheme is unsupported. Or the unix path is missing.
            Or the unix endpoint has an authority/query/fragment part that would
            otherwise be silently dropped (e.g. "unix://tmp/x.sock").
            Or the tcp port is missing or invalid.
    """
    parsed = urlparse(endpoint)
    if parsed.scheme == "unix":
        # urlparse("unix:///tmp/x.sock") -> netloc="", path="/tmp/x.sock".
        # With only two slashes ("unix://tmp/x.sock") the first path segment
        # lands in netloc. Ignoring it would yield the wrong path "/x.sock".
        # Likewise a '?' or '#' in the path would be cut off silently.
        # Reject such endpoints instead of connecting/binding somewhere unexpected.
        if parsed.netloc or parsed.query or parsed.fragment:
            raise ValueError(f"unix endpoint 格式错误(应为 unix:///绝对路径 或 unix:相对路径): {endpoint}")
        path = parsed.path
        if not path:
            raise ValueError(f"unix endpoint 缺路径: {endpoint}")
        return "unix", (path,)
    if parsed.scheme == "tcp":
        host = parsed.hostname or "127.0.0.1"
        try:
            # .port raises ValueError for non-numeric or out-of-range ports.
            port = parsed.port
        except ValueError as e:
            raise ValueError(f"tcp endpoint 端口非法: {endpoint}") from e
        if port is None:
            raise ValueError(f"tcp endpoint 缺端口: {endpoint}")
        return "tcp", (host, port)
    raise ValueError(f"不支持的 endpoint scheme: {parsed.scheme}")
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
def create_listen_socket(endpoint: str, backlog: int = 128) -> socket.socket:
    """Create a listening stream socket for ``endpoint``.

    unix endpoints:
        A stale *socket* file at the path is removed first. Any other kind of
        file there is left untouched, and ``bind()`` then fails. The new socket
        file is chmod'ed to 0o600 (owner read/write only).

    tcp endpoints:
        ``SO_REUSEADDR`` is set. IPv6 literal hosts (e.g. ``tcp://[::1]:9000``)
        use ``AF_INET6``; everything else uses ``AF_INET``.

    Args:
        endpoint: An endpoint string understood by ``parse_endpoint``.
        backlog: The ``listen()`` backlog.

    Returns:
        A bound, listening socket.

    Raises:
        ValueError: If the endpoint cannot be parsed.
        OSError: If bind/chmod/listen fails. The socket is closed before the
            error propagates.
    """
    import os
    import stat

    kind, addr = parse_endpoint(endpoint)
    if kind == "unix":
        path = addr[0]
        # Only remove a leftover socket from a previous run. Unconditionally
        # unlinking would delete e.g. a regular file named by a mistyped endpoint.
        try:
            if stat.S_ISSOCK(os.stat(path).st_mode):
                os.unlink(path)
        except FileNotFoundError:
            pass
        sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        try:
            sock.bind(path)
            # chmod runs before listen(): until listen() no client can connect,
            # so the socket is never reachable with umask-default permissions.
            os.chmod(path, 0o600)
            sock.listen(backlog)
        except BaseException:
            sock.close()
            raise
        return sock

    host = addr[0]
    family = socket.AF_INET6 if ":" in host else socket.AF_INET
    sock = socket.socket(family, socket.SOCK_STREAM)
    try:
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        sock.bind(addr)
        sock.listen(backlog)
    except BaseException:
        # Don't leak the fd when e.g. the port is already in use.
        sock.close()
        raise
    return sock
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
def connect_endpoint(
    endpoint: str,
    connect_timeout: float = 30.0,
    retry_interval: float = 0.1,
    max_retry_interval: float = 1.0,
) -> socket.socket:
    """Connect to ``endpoint``, retrying with exponential backoff.

    Retries exist because in embedded mode the worker may still be starting up.

    Args:
        endpoint: An endpoint string understood by ``parse_endpoint``.
        connect_timeout: Total time budget in seconds for all attempts. It also
            bounds each individual ``connect()`` call, so an unresponsive TCP
            host cannot block past the deadline. A value ``<= 0`` means exactly
            one attempt with no waiting or retrying (and no per-call timeout
            beyond the socket default).
        retry_interval: Initial sleep between attempts; doubled after each failure.
        max_retry_interval: Upper bound for the backoff sleep. The default of
            1.0 s is the previously hard-coded value.

    Returns:
        A connected stream socket in the same timeout mode as a freshly created
        socket. TCP sockets have ``TCP_NODELAY`` set.

    Raises:
        ValueError: If the endpoint cannot be parsed.
        TimeoutError: If no attempt succeeded. The last ``OSError`` is chained
            as ``__cause__``.
    """
    import time

    kind, addr = parse_endpoint(endpoint)
    if kind == "unix":
        family = socket.AF_UNIX
        target = addr[0]
    else:
        # parse_endpoint yields host "::1" for "tcp://[::1]:port"; AF_INET can't reach it.
        family = socket.AF_INET6 if ":" in addr[0] else socket.AF_INET
        target = addr

    def _try_once(timeout: "float | None") -> socket.socket:
        # One connect attempt. The socket is closed on *any* failure, otherwise
        # the retry loop leaks one fd per failed attempt (especially visible
        # when the embedded-mode worker is slow to start).
        sock = socket.socket(family, socket.SOCK_STREAM)
        try:
            if kind != "unix":
                sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
            default_timeout = sock.gettimeout()
            if timeout is not None:
                sock.settimeout(timeout)
            sock.connect(target)
            # Hand the socket back in the timeout mode a fresh socket has.
            sock.settimeout(default_timeout)
        except BaseException:
            sock.close()
            raise
        return sock

    attempts = 0
    last_err: Exception | None = None
    # Always make at least one attempt, even when connect_timeout <= 0.
    # The monotonic clock is immune to wall-clock adjustments.
    deadline = time.monotonic() + connect_timeout if connect_timeout > 0 else None
    interval = retry_interval
    while True:
        attempts += 1
        attempt_timeout = None
        if deadline is not None:
            # Never pass 0 to settimeout(): that switches the socket to non-blocking mode.
            attempt_timeout = max(deadline - time.monotonic(), 0.001)
        try:
            return _try_once(attempt_timeout)
        except OSError as e:  # covers FileNotFoundError, ConnectionRefusedError, socket timeouts
            last_err = e
        if deadline is None:
            break
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break
        # Don't sleep past the deadline.
        time.sleep(min(interval, remaining))
        interval = min(interval * 2, max_retry_interval)

    raise TimeoutError(
        f"连接 {endpoint} 超时(尝试 {attempts} 次,connect_timeout={connect_timeout}s): {last_err!r}"
    ) from last_err
|
|
@@ -0,0 +1,89 @@
|
|
|
1
|
+
# llm_engine/kitty/schemas.py
|
|
2
|
+
|
|
3
|
+
"""
|
|
4
|
+
KittyEngine 专属数据结构:TaskSnapshot 与 WireResult。
|
|
5
|
+
|
|
6
|
+
- TaskSnapshot:跨进程可序列化的任务快照,替代 GeneralEngine 的 TaskHandle。
|
|
7
|
+
不嵌 result,避免每次 poll 都回传大对象。
|
|
8
|
+
- WireResult:线缆上传输的瘦身 Result,不含 request 字段。客户端收到后用本地
|
|
9
|
+
cache 的原始 request 拼回完整 InferenceRequestResult。
|
|
10
|
+
- TransitionRecord:一次状态转移记录。由 worker 产生,随 snapshot 跨进程回传。
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
from datetime import datetime
|
|
14
|
+
from typing import Any, Optional, TypedDict
|
|
15
|
+
|
|
16
|
+
from pydantic import BaseModel, Field
|
|
17
|
+
|
|
18
|
+
from ..schemas import ModelOutput, TaskStatus
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class TransitionRecord(BaseModel):
    """A single status transition in a task's history.

    Fields:
        status: The status that was entered.
        timestamp: Human-readable local time of entry, e.g. '2026-05-11 10:30:45.123456'.
        duration: Seconds actually spent in this status. The caller back-fills
            it when the next status is entered, so it is None on the latest record.
        desc: Why this status was entered.
    """

    status: TaskStatus
    timestamp: str
    duration: Optional[float] = Field(default=None)
    desc: str = Field(default="")

    @property
    def epoch(self) -> float:
        """``timestamp`` as a POSIX timestamp in seconds.

        A naive timestamp (no UTC offset) is interpreted as local time.
        """
        entered_at = datetime.fromisoformat(self.timestamp)
        return entered_at.timestamp()
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class TaskSnapshot(BaseModel):
    """Serializable point-in-time view of a task, used for cross-process transfer.

    All timing information lives in the ``transitions`` log. There are
    deliberately no derived submit_time / start_time / end_time / duration
    fields, for two reasons:
      - pydantic ``@property`` values are not included in ``model_dump()``,
        which would make serialization inconsistent;
      - keeping the same information both as a field and as a derived value
        would need manual synchronization.
    Callers that need a phase's time point or duration walk ``transitions`` themselves.
    """

    task_id: str
    status: TaskStatus
    # Status history in the order entered; the last record is the current status.
    transitions: list[TransitionRecord] = Field(default_factory=list)
    task_timeout: Optional[float] = None  # presumably seconds — TODO confirm with the engine
    persist: bool = False
    error: Optional[str] = None
    # The result itself is not embedded, so polls stay small; this flag only says whether one exists.
    has_result: bool = False

    def is_finished(self) -> bool:
        """Return True once the task has reached a terminal status."""
        terminal_states = (
            TaskStatus.SUCCESS,
            TaskStatus.FAILED,
            TaskStatus.TIMEOUT,
            TaskStatus.CANCELLED,
            TaskStatus.REJECTED,
        )
        return self.status in terminal_states
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
class WireResult(BaseModel):
    """Result as transmitted on the wire.

    The original request is omitted; the receiver re-attaches it from its own cache.

    ``status`` intentionally has no default. Callers must name the terminal
    status explicitly, so a forgotten argument can never turn a failure into
    a SUCCESS.
    """

    success: bool
    task_id: str
    # Terminal status; the client uses it to sync snapshot.status
    # (distinguishes SUCCESS / FAILED / TIMEOUT / REJECTED / CANCELLED).
    status: TaskStatus
    model_output: Optional[ModelOutput] = Field(default=None)
    error_message: Optional[str] = Field(default=None)
    duration: Optional[float] = Field(default=None)
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
class MockRequest(TypedDict):
    """The complete request that would be sent to the LLM provider.

    Debugging aid only: it is assembled without making any network call.
    """

    urls: list[str]  # candidate request URLs
    model: str  # model name
    headers: dict[str, str]  # HTTP headers
    payload: dict[str, Any]  # request body
|