llm-engine-kitty 0.1.0.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,550 @@
1
+ # llm_engine/kitty/client.py
2
+
3
+ """
4
+ KittyClient:用户面向的客户端。
5
+
6
+ 两种部署模式:
7
+ 1. embedded(默认):传入 kwargs,内部 spawn 一个子进程启动 KittyServer,自动连接。
8
+ 适合单机单租户使用,生命周期与用户进程绑定。
9
+ 2. remote:KittyClient.connect(endpoint=..., **overrides) 连接到已经单独启动的
10
+ KittyServer(通过 `python -m llm_engine.kitty -c config.yaml` 启动)。
11
+
12
+ 用户无需接触 KittyEngineConfig / KittyEngineOverrides 类,所有参数以 kwargs 形式传入。
13
+ """
14
+
15
+ import json
16
+ import os
17
+ import socket
18
+ import subprocess
19
+ import sys
20
+ import tempfile
21
+ import threading
22
+ import time
23
+ from dataclasses import dataclass
24
+ from typing import IO, Any, Optional, TypeAlias
25
+
26
+ import kitty_logger
27
+ from tqdm import tqdm
28
+
29
+ from ..schemas import (
30
+ InferenceParameters,
31
+ InferenceRequest,
32
+ InferenceRequestResult,
33
+ Message,
34
+ MessageRole,
35
+ PreparedRequest,
36
+ TaskStatus,
37
+ )
38
+ from ..utils import gen_unique_id
39
+ from .config import KittyEngineConfig, KittyEngineOverrides
40
+ from .protocol import MsgType, connect_endpoint, encode_frame, recv_frame
41
+ from .schemas import MockRequest, TaskSnapshot, WireResult
42
+
43
logger = kitty_logger.getLogger(__name__)


@dataclass
class ClientTaskHandle:
    """Client-side handle for one submitted task.

    The original ``request`` is kept here because the server's ``WireResult`` does
    not carry it back; ``snapshot`` is the latest task state this client has seen.
    """

    task_id: str
    request: InferenceRequest
    snapshot: TaskSnapshot

    @property
    def status(self) -> TaskStatus:
        """Status stored in the latest local snapshot (no server round trip)."""
        latest = self.snapshot
        return latest.status

    def is_finished(self) -> bool:
        """Delegate to the latest local snapshot's ``is_finished()`` (no server round trip)."""
        latest = self.snapshot
        return latest.is_finished()


# Task reference accepted by KittyClient's task methods: either the task_id string
# or the ClientTaskHandle returned by submit(). Written as a string alias.
Key: TypeAlias = "str | ClientTaskHandle"
63
+
64
+
65
+ # ---------------------------------------------------------------------------
66
+ # subprocess 入口(embedded 模式)已移至 __main__.py --embedded 分支
67
+ # ---------------------------------------------------------------------------
68
+
69
+
70
+ # ---------------------------------------------------------------------------
71
+ # KittyClient
72
+ # ---------------------------------------------------------------------------
73
+
74
+
75
class KittyClient:
    """
    Client for a KittyServer. Constructing it directly uses embedded mode;
    ``KittyClient.connect()`` uses remote mode.

    Thread-safety / concurrency contract (important):
        This implementation is designed for **sequential, single-threaded use only**.
        ``_request`` serializes the whole send + recv round trip with one
        instance-wide ``self._lock``. Once a call blocks (typically ``wait()``: the
        server-side ``wait_task`` awaits ``done_event`` until the task reaches a
        terminal state), **every other RPC on this connection (including
        ``submit/poll/stats/cancel/clear_done``) queues behind it**, so concurrent
        use degrades to fully serial execution.

    For concurrent use:
        - Recommended: give each thread its own ``KittyClient``. Do not use embedded
          instances for this, because each one spawns its own worker subprocess.
          Use ``KittyClient.connect(endpoint=...)`` against one shared server instead.
        - Not recommended: sharing one instance across threads with your own
          locking. That is equivalent to single-threaded use.

    The protocol layer does not multiplex frames yet. If that changes, this
    constraint can be relaxed.
    """

    def __init__(
        self,
        *,
        # -- required (embedded mode) --
        registry_path: str,
        default_model: Optional[str] = None,
        # -- engine defaults --
        default_api_key: Optional[str] = None,
        default_inference_parameters: Optional[InferenceParameters] = None,
        extra_headers: Optional[dict[str, str]] = None,
        extra_payload: Optional[dict[str, Any]] = None,
        # -- concurrency / retry --
        max_global_concurrency: int = 64,
        default_timeout: Optional[float] = None,
        default_max_retries: int = 5,
        default_base_delay: int = 2,
        # -- server-only --
        listen: Optional[str] = None,
        log_level: str = "INFO",
        # -- connection tuning --
        http_client_connect_timeout: float = 10.0,
        http_client_read_timeout: float = 3600.0,
    ) -> None:
        """Spawn an embedded KittyServer worker subprocess and connect to it.

        All keyword arguments are packed into a ``KittyEngineConfig``, which
        validates them. When ``listen`` is omitted, a Unix domain socket path is
        generated under the system temp directory.

        Raises:
            Whatever config validation, ``connect_endpoint`` or the HELLO handshake
            raise. On failure after the spawn, the worker is reaped before re-raising.
        """
        logger.debug("初始化KittyClient...")
        self._init_common_state()

        # Embedded mode: build the full config and spawn the worker.
        if listen is None:
            # Per-process, millisecond-stamped UDS path so multiple instances don't collide.
            tmpdir = tempfile.gettempdir()
            listen = f"unix://{tmpdir}/kitty_{os.getpid()}_{int(time.time()*1000)}.sock"

        config = KittyEngineConfig(
            registry_path=registry_path,
            default_model=default_model,
            default_api_key=default_api_key,
            default_inference_parameters=default_inference_parameters,
            extra_headers=extra_headers or {},
            extra_payload=extra_payload or {},
            max_global_concurrency=max_global_concurrency,
            default_timeout=default_timeout,
            default_max_retries=default_max_retries,
            default_base_delay=default_base_delay,
            listen=listen,
            log_level=log_level,
            http_client_connect_timeout=http_client_connect_timeout,
            http_client_read_timeout=http_client_read_timeout,
        )
        self._endpoint = listen
        self._spawn_and_connect(config)
        logger.debug("KittyClient初始化完成")

    def _init_common_state(self) -> None:
        """Initialise the instance state shared by both deployment modes."""
        # Serializes each send+recv round trip (see the class docstring).
        self._lock = threading.Lock()
        self._sock: Optional[socket.socket] = None
        self._proc: Optional[subprocess.Popen] = None
        # True only in embedded mode, where this client spawned (and must reap) the worker.
        self._owns_process: bool = False
        self._endpoint: str = ""
        # task_id -> handle; entries are dropped when a result is consumed or cleared.
        self._handles: dict[str, ClientTaskHandle] = {}
        self._handles_lock = threading.Lock()
        self._connected: bool = False

    # ------------------------------------------------------------------
    # Connect / start-up
    # ------------------------------------------------------------------

    @classmethod
    def connect(
        cls,
        endpoint: str,
        *,
        default_model: Optional[str] = None,
        default_api_key: Optional[str] = None,
        default_inference_parameters: Optional[InferenceParameters] = None,
        extra_headers: Optional[dict[str, str]] = None,
        extra_payload: Optional[dict[str, Any]] = None,
        default_timeout: Optional[float] = None,
        default_max_retries: Optional[int] = None,
        default_base_delay: Optional[int] = None,
    ) -> "KittyClient":
        """Connect to an already running KittyServer (remote mode).

        Args:
            endpoint: Endpoint string understood by ``connect_endpoint``, in the same
                form as the embedded ``listen`` value (e.g. ``unix://...``).
            default_model .. default_base_delay: Per-connection overrides sent with
                HELLO. ``None`` means "not overridden".

        Returns:
            A connected client. Its ``close()`` does not stop the server unless
            ``shutdown_remote=True`` is passed.
        """
        overrides = KittyEngineOverrides(
            default_model=default_model,
            default_api_key=default_api_key,
            default_inference_parameters=default_inference_parameters,
            extra_headers=extra_headers or {},
            extra_payload=extra_payload or {},
            default_timeout=default_timeout,
            default_max_retries=default_max_retries,
            default_base_delay=default_base_delay,
        )
        # Remote mode bypasses the embedded __init__ (it requires registry_path and spawns a worker).
        self = cls.__new__(cls)
        self._init_common_state()
        self._endpoint = endpoint
        self._connect_remote(endpoint, overrides)
        return self

    def _spawn_and_connect(self, config: KittyEngineConfig) -> None:
        """Spawn the embedded worker, hand it ``config``, then connect and HELLO.

        The config is written as one JSON line on the child's stdin. stdin is then
        kept open so that EOF (parent exit or ``close()``) tells the child to stop.
        Any failure after the spawn closes the socket and reaps the child before
        re-raising, so a failed constructor does not leave an orphan worker.
        """
        # Merge the parent's sys.path into PYTHONPATH so the worker can import the
        # package even when it was loaded via sys.path.insert rather than pip install.
        env = os.environ.copy()
        extra_paths = os.pathsep.join(p for p in sys.path if p)
        existing = env.get("PYTHONPATH", "")
        env["PYTHONPATH"] = extra_paths + (os.pathsep + existing if existing else "")

        proc = subprocess.Popen(
            [sys.executable, "-m", "llm_engine.kitty", "--embedded"],
            stdin=subprocess.PIPE,
            env=env,
        )
        try:
            assert proc.stdin is not None
            proc.stdin.write((json.dumps(config.model_dump()) + "\n").encode("utf-8"))
            proc.stdin.flush()

            self._proc = proc
            self._owns_process = True
            logger.info("KittyClient spawn worker pid=%d endpoint=%s", proc.pid, self._endpoint)

            self._sock = connect_endpoint(self._endpoint, connect_timeout=30.0)
            self._hello(overrides=None)
        except BaseException:
            # Covers a dead child (BrokenPipeError on write), a connect timeout, a
            # HELLO rejection, or Ctrl-C. Don't leak the socket or the worker.
            self._abort_connection()
            self._reap_process(proc, grace=0.0, term_wait=5.0)
            self._proc = None
            self._owns_process = False
            raise
        self._connected = True

    def _connect_remote(self, endpoint: str, overrides: Optional[KittyEngineOverrides]) -> None:
        """Open the socket to ``endpoint`` and run HELLO; close the socket if HELLO fails."""
        self._sock = connect_endpoint(endpoint, connect_timeout=10.0)
        try:
            self._hello(overrides=overrides)
        except BaseException:
            self._abort_connection()
            raise
        self._connected = True
        logger.info("KittyClient 已连接到 %s", endpoint)

    def _hello(self, overrides: Optional[KittyEngineOverrides]) -> None:
        """Perform the HELLO handshake.

        The overrides payload is attached only when ``overrides.has_any()`` is true.

        Raises:
            RuntimeError: The server did not answer with ``MSG_WELCOME``, or it
                answered with ``MSG_ERROR``.
        """
        payload: dict[str, Any] = {}
        if overrides is not None and overrides.has_any():
            payload["overrides"] = overrides.model_dump(exclude_none=True)
        msg_type, resp = self._request(MsgType.CMD_HELLO, payload)
        if msg_type != MsgType.MSG_WELCOME:
            raise RuntimeError(f"HELLO 失败: msg_type={msg_type}, resp={resp}")

    # ------------------------------------------------------------------
    # Shutdown
    # ------------------------------------------------------------------

    def close(self, *, shutdown_remote: bool = False, timeout: float = 10.0) -> None:
        """Close the client. In embedded mode this also stops the worker process.

        Args:
            shutdown_remote: In remote mode, send ``CMD_SHUTDOWN`` instead of ``CMD_BYE``.
            timeout: Socket timeout for the final protocol messages. In embedded mode
                it is also the grace period for the worker to exit by itself after
                stdin EOF, before it is terminated and then killed.

        Safe to call more than once. Network I/O runs under a socket timeout so
        close cannot hang on a dead connection.
        """
        if not self._connected and self._sock is None and self._proc is None:
            return
        sock = self._sock
        try:
            if sock is not None:
                # Put a timeout on the final sendall/recv so close can't block forever.
                try:
                    sock.settimeout(max(0.1, timeout))
                except OSError:
                    pass
                if self._owns_process or shutdown_remote:
                    try:
                        self._request(MsgType.CMD_SHUTDOWN, {})
                    except Exception as e:
                        logger.debug("CMD_SHUTDOWN 失败(忽略): %r", e)
                else:
                    try:
                        with self._lock:
                            self._send(MsgType.CMD_BYE, {})
                    except Exception as e:
                        logger.debug("CMD_BYE 失败(忽略): %r", e)
            self._abort_connection()
        finally:
            proc = self._proc
            if self._owns_process and proc is not None:
                # Closing stdin (inside _reap_process) lets the embedded server see EOF and exit.
                self._reap_process(proc, grace=timeout, term_wait=5.0)
            self._proc = None
            self._connected = False

    def __enter__(self) -> "KittyClient":
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        self.close()

    def __del__(self) -> None:
        # No protocol I/O here: during interpreter shutdown socket/threading may be gone.
        # Only close the socket and reap the worker subprocess so it isn't orphaned.
        try:
            sock = getattr(self, "_sock", None)
            if sock is not None:
                try:
                    sock.close()
                except Exception:
                    pass
                self._sock = None
            proc = getattr(self, "_proc", None)
            if getattr(self, "_owns_process", False) and proc is not None:
                try:
                    # grace=0: don't wait for an EOF-driven exit here; escalate right away.
                    # _reap_process handles TimeoutExpired itself, so kill() is always reached.
                    self._reap_process(proc, grace=0.0, term_wait=2.0)
                except Exception:
                    pass
                self._proc = None
            self._connected = False
        except Exception:
            pass

    @staticmethod
    def _reap_process(proc: subprocess.Popen, *, grace: float, term_wait: float) -> None:
        """Stop ``proc``, escalating until it exits (best effort).

        1. Close its stdin (EOF) and wait up to ``grace`` seconds.
        2. If it is still running, ``terminate()`` and wait up to ``term_wait`` seconds.
        3. If it is still running, ``kill()`` and wait up to 2 seconds.
        A wait that times out only moves on to the next step; it never aborts the escalation.
        """
        if proc.stdin is not None:
            try:
                proc.stdin.close()
            except Exception:
                pass
        for escalate, wait_s in ((None, grace), (proc.terminate, term_wait), (proc.kill, 2.0)):
            if proc.poll() is not None:
                return
            if escalate is not None:
                escalate()
            try:
                proc.wait(timeout=wait_s)
            except subprocess.TimeoutExpired:
                pass

    def _abort_connection(self) -> None:
        """Close and forget the socket without any protocol I/O (best effort)."""
        sock, self._sock = self._sock, None
        if sock is not None:
            try:
                sock.close()
            except Exception:
                pass

    # ------------------------------------------------------------------
    # Protocol I/O (serialized by self._lock)
    # ------------------------------------------------------------------

    def _send(self, msg_type: int, payload: dict[str, Any]) -> None:
        """Send one frame without reading a reply. The caller (``close``) holds ``self._lock``."""
        assert self._sock is not None
        self._sock.sendall(encode_frame(msg_type, payload))

    def _request(self, msg_type: int, payload: dict[str, Any]) -> tuple[int, dict[str, Any]]:
        """Send one frame and read exactly one response frame under ``self._lock``.

        Returns:
            ``(response_type, response_payload)``.

        Raises:
            RuntimeError: The server answered with ``MSG_ERROR``.
        """
        with self._lock:
            assert self._sock is not None
            self._sock.sendall(encode_frame(msg_type, payload))
            resp_type, resp = recv_frame(self._sock)
        if resp_type == MsgType.MSG_ERROR:
            raise RuntimeError(f"server 返回错误: {resp.get('error')}")
        return resp_type, resp

    # ------------------------------------------------------------------
    # Task API
    # ------------------------------------------------------------------

    def _resolve_key(self, key: Key) -> ClientTaskHandle:
        """Map a task_id or handle to the local handle.

        Raises:
            KeyError: No local handle exists for the given task_id.
            TypeError: ``key`` is neither a ``str`` nor a ``ClientTaskHandle``.
        """
        if isinstance(key, ClientTaskHandle):
            return key
        if isinstance(key, str):
            with self._handles_lock:
                h = self._handles.get(key)
            if h is None:
                raise KeyError(f"未找到 task_id='{key}' 的本地句柄")
            return h
        raise TypeError(f"key 必须是 str 或 ClientTaskHandle,实际: {type(key).__name__}")

    def submit(
        self,
        request: InferenceRequest,
        *,
        persist: bool = False,
        timeout: Optional[float] = None,
    ) -> ClientTaskHandle:
        """Submit ``request`` under a freshly generated task_id and track its handle locally.

        ``persist`` and ``timeout`` are forwarded to the server unchanged.
        """
        task_id = gen_unique_id(prefix="task")
        _, resp = self._request(
            MsgType.CMD_SUBMIT,
            {
                "task_id": task_id,
                "request": request.model_dump(),
                "persist": persist,
                "timeout": timeout,
            },
        )
        snap = TaskSnapshot.model_validate(resp["snapshot"])
        handle = ClientTaskHandle(task_id=task_id, request=request, snapshot=snap)
        with self._handles_lock:
            self._handles[task_id] = handle
        return handle

    def poll(self, key: Key) -> ClientTaskHandle:
        """Refresh the handle's snapshot from the server and return the handle.

        If the server no longer knows the task (e.g. it was already popped), the
        handle is returned with its previous snapshot.
        """
        h = self._resolve_key(key)
        _, resp = self._request(MsgType.CMD_POLL, {"task_id": h.task_id})
        snap_dict = resp.get("snapshot")
        if snap_dict is None:
            return h
        h.snapshot = TaskSnapshot.model_validate(snap_dict)
        return h

    def get(self, key: Key) -> Optional[InferenceRequestResult]:
        """Non-blocking consume: return the result, or None when the server has none yet.

        A returned result is removed from both the server and the local handle table.
        """
        h = self._resolve_key(key)
        _, resp = self._request(MsgType.CMD_POP, {"task_id": h.task_id})
        if not resp.get("found"):
            return None
        return self._consume_result(h, WireResult.model_validate(resp["result"]))

    def wait(self, key: Key, timeout: Optional[float] = None) -> InferenceRequestResult:
        """Block until the task reaches a terminal state, then consume its result.

        Raises:
            TimeoutError: ``timeout`` elapsed; the task keeps running on the server.
            RuntimeError: The server does not know the task.
        """
        h = self._resolve_key(key)
        _, resp = self._request(MsgType.CMD_WAIT, {"task_id": h.task_id, "timeout": timeout})
        if resp.get("timed_out"):
            raise TimeoutError(f"等待任务 {h.task_id} 超时({timeout}s)")
        if not resp.get("found"):
            raise RuntimeError(f"任务 {h.task_id} 在 server 端不存在")
        return self._consume_result(h, WireResult.model_validate(resp["result"]))

    def cancel(self, key: Key, force: bool = False) -> bool:
        """Ask the server to cancel the task; return whether it reports it as cancelled."""
        h = self._resolve_key(key)
        _, resp = self._request(MsgType.CMD_CANCEL, {"task_id": h.task_id, "force": force})
        return bool(resp.get("cancelled"))

    def clear_done(self) -> int:
        """Drop finished tasks on the server and return how many were cleared."""
        _, resp = self._request(MsgType.CMD_CLEAR_DONE, {})
        # Sync local handles from the server's cleared_ids, NOT from the local
        # `is_finished()`. A task submitted and never polled/waited still shows
        # RUNNING/PENDING locally even though the server has already cleared it.
        # Deleting by local state would leave such handles behind forever.
        cleared_ids = resp.get("cleared_ids") or []
        with self._handles_lock:
            for tid in cleared_ids:
                self._handles.pop(tid, None)
        return int(resp.get("count", 0))

    def mock_request_infer(self, query: str) -> MockRequest:
        """String convenience wrapper around ``mock_request``, mirroring ``infer``.

        Args:
            query: User text, wrapped as a single USER message.

        Returns:
            The mock request (URLs, model name, headers, payload).
        """
        return self.mock_request(InferenceRequest(messages=[Message(role=MessageRole.USER, content=query)]))

    def mock_request(self, request: InferenceRequest) -> MockRequest:
        """Build the request that would be sent to the provider, without sending it.

        Intended for debugging: check that the URL, headers and payload look right.

        Args:
            request: The inference request.

        Returns:
            The mock request (URLs, model name, headers, payload).
        """
        _, resp = self._request(MsgType.CMD_MOCK_REQUEST, {"request": request.model_dump()})
        return MockRequest(urls=resp["urls"], model=resp["model"], headers=resp["headers"], payload=resp["payload"])

    def ping(self) -> None:
        """Round-trip a ``CMD_PING``; raises if the server answers with an error."""
        self._request(MsgType.CMD_PING, {})

    def stats(self) -> dict[str, Any]:
        """Return this connection's task status distribution, handy for a ``finished/total`` progress bar.

        Returns:
            A dict shaped like::

                {
                    "total": <int, sum of counts.values()>,
                    "finished": <int, SUCCESS + FAILED + TIMEOUT + CANCELLED + REJECTED>,
                    "counts": {"SUBMITTED": int, "ACCEPTED": int, ..., "REJECTED": int},
                }

        Counting semantics (important; computed by the server):
            - Non-terminal states (SUBMITTED/ACCEPTED/PENDING/RUNNING) count tasks
              **currently** in that state and drop back to 0 as tasks progress.
            - Terminal states (SUCCESS/FAILED/TIMEOUT/CANCELLED/REJECTED) only
              accumulate. Consuming results via POP/WAIT, or removing entries via
              CLEAR_DONE, does not decrement them, because those are not state
              transitions.
            - So for a connection that only submits and waits, ``total`` is the
              cumulative number of submitted tasks and ``finished`` the cumulative
              number completed. They can be used directly as a progress bar's ``n/total``.
        """
        _, resp = self._request(MsgType.CMD_STATS, {})
        return {
            "total": int(resp.get("total", 0)),
            "finished": int(resp.get("finished", 0)),
            "counts": dict(resp.get("counts") or {}),
        }

    # ------------------------------------------------------------------
    # Result assembly
    # ------------------------------------------------------------------

    def _consume_result(self, handle: ClientTaskHandle, wr: WireResult) -> InferenceRequestResult:
        """Finalize a consumed task: sync its status, drop the handle, build the result.

        The status is taken from the WireResult so FAILED / REJECTED / CANCELLED
        stay distinguishable locally.
        """
        handle.snapshot.status = wr.status
        with self._handles_lock:
            self._handles.pop(handle.task_id, None)
        return InferenceRequestResult(
            success=wr.success,
            task_id=wr.task_id,
            request=handle.request,
            model_output=wr.model_output,
            error_message=wr.error_message,
            duration=wr.duration,
        )

    # ------------------------------------------------------------------
    # Synchronous convenience API
    # ------------------------------------------------------------------

    def inference(self, request: InferenceRequest) -> InferenceRequestResult:
        """Submit ``request`` and block until its result is available."""
        return self.wait(self.submit(request))

    def infer(self, query: str) -> str:
        """Run one USER-message query and return the output text, or "" on failure."""
        result = self.inference(InferenceRequest(messages=[Message(role=MessageRole.USER, content=query)]))
        if result.success and result.model_output is not None:
            return result.model_output.content
        return ""

    def batch_inference(
        self,
        requests: list[InferenceRequest],
        output_file: Optional[str] = None,
        silent_mode: bool = False,
    ) -> list[InferenceRequestResult]:
        """Submit all ``requests`` up front, then wait for each in submission order.

        Args:
            requests: Requests to run.
            output_file: If given, the file is truncated and each result is written
                as one JSON line (``model_dump_json``), flushed as soon as it arrives.
            silent_mode: Suppress the tqdm progress bar.

        Returns:
            Results in the same order as ``requests``.
        """
        # Open the output file before submitting anything. An unwritable path then
        # fails here, instead of after every request is already running on the server.
        f: Optional[IO[str]] = open(output_file, "w", encoding="utf-8") if output_file else None
        pbar: Optional[tqdm] = None
        results: list[InferenceRequestResult] = []
        try:
            handles = [self.submit(req) for req in requests]
            pbar = tqdm(total=len(requests), desc="Batch: Kitty") if not silent_mode else None
            for h in handles:
                result = self.wait(h)
                results.append(result)
                if f is not None:
                    f.write(result.model_dump_json() + "\n")
                    f.flush()
                if pbar is not None:
                    pbar.update(1)
        finally:
            if pbar is not None:
                pbar.close()
            if f is not None:
                f.close()
        return results

    def batch_infer(self, queries: list[str], **kwargs: Any) -> list[str]:
        """Run ``queries`` as single USER messages; return output texts ("" for failures).

        Extra keyword arguments are forwarded to ``batch_inference``.
        """
        reqs = [InferenceRequest(messages=[Message(role=MessageRole.USER, content=q)]) for q in queries]
        results = self.batch_inference(reqs, **kwargs)
        return [r.model_output.content if r.success and r.model_output else "" for r in results]
@@ -0,0 +1,83 @@
1
+ # llm_engine/kitty/config.py
2
+
3
+ """
4
+ KittyEngine 配置数据结构。
5
+
6
+ - KittyEngineConfig:引擎完整配置。由 server 启动时从 YAML 加载;embedded 模式下
7
+ 由 KittyClient.__init__ 根据 kwargs 内部构造后序列化为 JSON,经 stdin 传给 worker 子进程。
8
+ - KittyEngineOverrides:client 连接时可选的覆盖集。每个字段均 Optional,不包含
9
+ server 全局资源字段(registry_path / max_global_concurrency / listen)。
10
+
11
+ 两者均为内部数据结构,不通过顶层 __init__.py 导出给用户。
12
+ """
13
+
14
+ from pathlib import Path
15
+ from typing import Any, Optional
16
+
17
+ import yaml
18
+ from pydantic import BaseModel, ConfigDict, Field
19
+
20
+ from ..schemas import InferenceParameters
21
+
22
+
23
class KittyEngineConfig(BaseModel):
    """Complete KittyEngine runtime configuration.

    Loaded by the server from YAML (``from_yaml``) or built from keyword arguments
    by the embedded-mode client. Unknown keys are rejected (``extra="forbid"``).
    """

    model_config = ConfigDict(extra="forbid")

    # -- model registry --
    registry_path: str = Field(description="ModelConfigRegistry 的 JSON 配置路径")

    # -- engine defaults --
    default_model: Optional[str] = None
    default_api_key: Optional[str] = None
    default_inference_parameters: Optional[InferenceParameters] = None
    extra_headers: dict[str, str] = Field(default_factory=dict)
    extra_payload: dict[str, Any] = Field(default_factory=dict)

    # -- concurrency & retry --
    # max_global_concurrency is server-global and cannot be overridden by a client.
    # default_timeout / default_max_retries / default_base_delay CAN be overridden
    # per connection via KittyEngineOverrides.
    max_global_concurrency: int = Field(default=64, ge=1)
    default_timeout: Optional[float] = None
    # max_retries means "total attempts including the first", so the minimum is 1
    # (run once, no retry). 0 would make _send_request's for-range loop never run and
    # fall into the "Unexpected end of retry loop" dead code.
    default_max_retries: int = Field(default=5, ge=1)
    default_base_delay: int = Field(default=2, ge=1)

    # -- server-only --
    listen: str = "unix:///tmp/kitty.sock"
    log_level: str = "INFO"

    # -- connection tuning --
    http_client_connect_timeout: float = 10.0
    http_client_read_timeout: float = 3600.0

    @classmethod
    def from_yaml(cls, path: str | Path) -> "KittyEngineConfig":
        """Load and validate a configuration from a YAML file.

        An empty document is treated as an empty mapping. It then fails validation
        because ``registry_path`` is required.

        Raises:
            ValueError: The YAML root is not a mapping (e.g. a list or a scalar),
                or the mapping fails model validation.
            OSError: The file cannot be opened.
        """
        path = Path(path)
        with path.open(mode="r", encoding="utf-8") as f:
            data = yaml.safe_load(f)
        # Only an empty document (None) means "no settings". `or {}` would also turn
        # falsy non-mappings such as `[]`, `0` or `false` into {} and hide this error.
        if data is None:
            data = {}
        if not isinstance(data, dict):
            raise ValueError(f"YAML 根节点必须是映射: {path}")
        return cls.model_validate(data)
62
+
63
+
64
class KittyEngineOverrides(BaseModel):
    """Per-connection overrides a client may send with HELLO.

    Every field is optional. Server-global resource fields (registry_path /
    max_global_concurrency / listen) are deliberately absent.
    """

    model_config = ConfigDict(extra="forbid")

    default_model: Optional[str] = None
    default_api_key: Optional[str] = None
    default_inference_parameters: Optional[InferenceParameters] = None
    extra_headers: dict[str, str] = Field(default_factory=dict)
    extra_payload: dict[str, Any] = Field(default_factory=dict)
    default_timeout: Optional[float] = None
    # Same rule as KittyEngineConfig: >= 1; None means "don't override, use the server default".
    default_max_retries: Optional[int] = Field(default=None, ge=1)
    default_base_delay: Optional[int] = Field(default=None, ge=1)

    def has_any(self) -> bool:
        """Return True if any field holds a truthy value once None fields are dropped.

        The check is plain truthiness. The empty-dict defaults of ``extra_headers`` /
        ``extra_payload`` therefore count as "not set", and so do falsy scalars such
        as ``default_timeout=0.0`` or ``default_model=""``. An empty result lets the
        client omit the overrides payload from HELLO.
        """
        return any(self.model_dump(exclude_none=True).values())