llm-engine-kitty 0.1.0.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,562 @@
1
+ # llm_engine/general_engine.py
2
+
3
+ """
4
+ GeneralEngine:自包含的任务化推理引擎。
5
+
6
+ - 后台维护一个 asyncio 事件循环(独立 daemon 线程)
7
+ - submit() 返回 TaskHandle,非阻塞
8
+ - wait()/get() 消费结果;poll() 非消费快照
9
+ - inference()/batch_inference() 为同步便捷接口
10
+ - 不依赖 SimpleEngine / SimpleCoroutineEngine
11
+
12
+ 详细设计见 docs/general_engine_design.md
13
+ """
14
+
15
+ import asyncio
16
+ import httpx
17
+ import threading
18
+ import time
19
+
20
+ import kitty_logger
21
+ from pydantic import ValidationError
22
+ from tqdm import tqdm
23
+ from typing import IO, Any, TypeAlias
24
+
25
+ from .model_config import ModelConfig, ModelConfigRegistry
26
+ from .schemas import (
27
+ ChatCompletionChoice,
28
+ ChatCompletionResponse,
29
+ InferenceParameters,
30
+ InferenceRequest,
31
+ InferenceRequestResult,
32
+ Message,
33
+ MessageRole,
34
+ ModelOutput,
35
+ PreparedRequest,
36
+ TaskHandle,
37
+ TaskStatus,
38
+ )
39
+ from .utils import gen_unique_id
40
+
41
+ logger = kitty_logger.getLogger(__name__)
42
+
43
+
44
+ Key: TypeAlias = str | TaskHandle
45
+
46
+
47
+ class GeneralEngine:
48
+
49
+ def __init__(
50
+ self,
51
+ model_registry: ModelConfigRegistry,
52
+ default_model: str,
53
+ default_api_key: str | None = None,
54
+ default_inference_parameters: InferenceParameters | None = None,
55
+ extra_headers: dict[str, str] | None = None,
56
+ extra_payload: dict[str, Any] | None = None,
57
+ stream: bool = False,
58
+ max_global_concurrency: int = 64,
59
+ default_timeout: float | None = None,
60
+ default_max_retries: int = 5,
61
+ default_base_delay: int = 2,
62
+ ) -> None:
63
+ # —— 静态配置 ——
64
+ self.model_registry: ModelConfigRegistry = model_registry
65
+ self.default_model_name: str = default_model
66
+ self.default_model: ModelConfig = model_registry.get(name=default_model)
67
+ self.default_api_key: str | None = default_api_key
68
+ self.default_inference_parameters: InferenceParameters | None = default_inference_parameters
69
+ self.extra_headers: dict[str, str] = extra_headers or {}
70
+ self.extra_payload: dict[str, Any] = extra_payload or {}
71
+ self.stream: bool = stream
72
+ self.max_global_concurrency: int = max_global_concurrency
73
+ self.default_timeout: float | None = default_timeout
74
+ self.default_max_retries: int = default_max_retries
75
+ self.default_base_delay: int = default_base_delay
76
+
77
+ # —— 运行时(由 setup 填充)——
78
+ self._loop: asyncio.AbstractEventLoop | None = None
79
+ self._loop_thread: threading.Thread | None = None
80
+ self._running: bool = False
81
+
82
+ self._global_sem: asyncio.Semaphore | None = None
83
+
84
+ self._handles: dict[str, TaskHandle] = {}
85
+ self._handles_lock: threading.Lock = threading.Lock()
86
+
87
+ # ------------------------------------------------------------------
88
+ # 生命周期
89
+ # ------------------------------------------------------------------
90
+
91
+ def setup(self) -> None:
92
+ """启动后台事件循环线程。幂等。"""
93
+ if self._running:
94
+ logger.warning(msg="GeneralEngine 已启动,请勿重复调用 setup()")
95
+ return
96
+
97
+ ready: threading.Event = threading.Event()
98
+
99
+ def _run() -> None:
100
+ loop: asyncio.AbstractEventLoop = asyncio.new_event_loop()
101
+ asyncio.set_event_loop(loop)
102
+ self._loop = loop
103
+ # 在 loop 线程中创建 asyncio 原语
104
+ self._global_sem = asyncio.Semaphore(value=self.max_global_concurrency)
105
+ ready.set()
106
+ try:
107
+ loop.run_forever()
108
+ finally:
109
+ # run_forever 结束后,把所有待清理的生成器/任务关掉
110
+ try:
111
+ tasks: set[asyncio.Task[Any]] = asyncio.all_tasks(loop)
112
+ for t in tasks:
113
+ t.cancel()
114
+ if tasks:
115
+ loop.run_until_complete(asyncio.gather(*tasks, return_exceptions=True))
116
+ finally:
117
+ loop.close()
118
+
119
+ self._loop_thread = threading.Thread(target=_run, name="GeneralEngine-Loop", daemon=True)
120
+ self._loop_thread.start()
121
+ ready.wait()
122
+ self._running = True
123
+ logger.info("GeneralEngine 已启动 (max_global_concurrency=%d)", self.max_global_concurrency)
124
+
125
+ def teardown(self, cancel_running: bool = True, timeout: float = 10.0) -> None:
126
+ """停止后台事件循环,回收线程。"""
127
+ if not self._running:
128
+ return
129
+
130
+ assert self._loop is not None
131
+ assert self._loop_thread is not None
132
+
133
+ if cancel_running:
134
+ with self._handles_lock:
135
+ handles: list[TaskHandle] = list[TaskHandle](self._handles.values())
136
+ for h in handles:
137
+ if not h.is_finished() and h._asyncio_task is not None:
138
+ self._loop.call_soon_threadsafe(callback=h._asyncio_task.cancel)
139
+
140
+ self._loop.call_soon_threadsafe(callback=self._loop.stop)
141
+ self._loop_thread.join(timeout=timeout)
142
+ self._running = False
143
+ logger.info("GeneralEngine 已停止")
144
+
145
+ def is_alive(self) -> bool:
146
+ return self._running and self._loop_thread is not None and self._loop_thread.is_alive()
147
+
148
+ def __enter__(self) -> "GeneralEngine":
149
+ self.setup()
150
+ return self
151
+
152
+ def __exit__(self, exc_type, exc_val, exc_tb) -> None:
153
+ self.teardown()
154
+
155
+ # ------------------------------------------------------------------
156
+ # 提交与执行
157
+ # ------------------------------------------------------------------
158
+
159
+ def submit(
160
+ self,
161
+ request: InferenceRequest,
162
+ *,
163
+ persist: bool = False,
164
+ timeout: float | None = None,
165
+ ) -> TaskHandle:
166
+ if not self._running:
167
+ raise RuntimeError("GeneralEngine 未启动,请先调用 setup()")
168
+
169
+ task_id: str = gen_unique_id(prefix="task")
170
+
171
+ # 超时层级:request.timeout > submit timeout > default
172
+ effective_timeout: float | None = request.timeout if request.timeout is not None else (timeout if timeout is not None else self.default_timeout)
173
+
174
+ handle: TaskHandle = TaskHandle(
175
+ task_id=task_id,
176
+ request=request,
177
+ status=TaskStatus.SUBMITTED,
178
+ task_timeout=effective_timeout,
179
+ persist=persist,
180
+ )
181
+ # threading.Event 在任意线程都能创建
182
+ handle._done_threading_event = threading.Event()
183
+
184
+ with self._handles_lock:
185
+ self._handles[task_id] = handle
186
+
187
+ assert self._loop is not None
188
+ loop = self._loop
189
+
190
+ # 在 loop 线程里创建 Task,并同步回填到 handle。用 threading.Event 等回执,
191
+ # 确保 submit() 返回后 handle._asyncio_task 一定可用(否则 cancel() 竞态)。
192
+ bound: threading.Event = threading.Event()
193
+
194
+ def _schedule() -> None:
195
+ try:
196
+ handle._asyncio_task = loop.create_task(self._run_task(handle))
197
+ handle._done_event = asyncio.Event()
198
+ finally:
199
+ bound.set()
200
+
201
+ loop.call_soon_threadsafe(_schedule)
202
+ bound.wait()
203
+ return handle
204
+
205
+ async def _run_task(self, handle: TaskHandle) -> None:
206
+ request = handle.request
207
+ # —— 验证阶段:失败 → REJECTED ——
208
+ try:
209
+ model_name: str = request.model_name or self.default_model_name
210
+ if not model_name:
211
+ raise ValueError("no model_name (request/default all empty)")
212
+ model: ModelConfig = self.model_registry.get(name=model_name)
213
+ except Exception as e:
214
+ handle.status = TaskStatus.REJECTED
215
+ handle.error = f"{type(e).__name__}: {e}"
216
+ handle.error_exception = e
217
+ handle.result = InferenceRequestResult(
218
+ success=False,
219
+ task_id=handle.task_id,
220
+ request=request,
221
+ error_message=handle.error,
222
+ )
223
+ handle.end_time = time.time()
224
+ if handle._done_event is not None:
225
+ handle._done_event.set()
226
+ if handle._done_threading_event is not None:
227
+ handle._done_threading_event.set()
228
+ return
229
+
230
+ # 验证通过:ACCEPTED
231
+ handle.status = TaskStatus.ACCEPTED
232
+
233
+ assert self._global_sem is not None
234
+ try:
235
+ # 入队等 sem
236
+ handle.status = TaskStatus.PENDING
237
+ async with self._global_sem:
238
+ handle.status = TaskStatus.RUNNING
239
+ handle.start_time = time.time()
240
+ try:
241
+ # 用 asyncio.timeout 而非 asyncio.wait_for:超时直接抛
242
+ # TimeoutError,不经内层 CancelledError 上传,因此下面的
243
+ # CancelledError 分支可确定是"外部 cancel",避免双终态写入。
244
+ async with asyncio.timeout(handle.task_timeout):
245
+ model_output = await self._send_request(request, model)
246
+ handle.result = InferenceRequestResult(
247
+ success=True,
248
+ task_id=handle.task_id,
249
+ request=request,
250
+ model_output=model_output,
251
+ duration=time.time() - handle.start_time,
252
+ )
253
+ handle.status = TaskStatus.SUCCESS
254
+ except asyncio.CancelledError:
255
+ handle.status = TaskStatus.CANCELLED
256
+ handle.error = "cancelled"
257
+ handle.result = InferenceRequestResult(
258
+ success=False,
259
+ task_id=handle.task_id,
260
+ request=request,
261
+ error_message="cancelled",
262
+ )
263
+ logger.info("任务 %s 已取消", handle.task_id)
264
+ raise
265
+ except TimeoutError:
266
+ handle.status = TaskStatus.TIMEOUT
267
+ handle.error = "timeout"
268
+ handle.result = InferenceRequestResult(
269
+ success=False,
270
+ task_id=handle.task_id,
271
+ request=request,
272
+ error_message="timeout",
273
+ )
274
+ except Exception as e:
275
+ logger.exception("任务 %s 执行异常", handle.task_id)
276
+ handle.status = TaskStatus.FAILED
277
+ handle.error = f"{type(e).__name__}: {e}"
278
+ handle.error_exception = e
279
+ handle.result = InferenceRequestResult(
280
+ success=False,
281
+ task_id=handle.task_id,
282
+ request=request,
283
+ error_message=handle.error,
284
+ )
285
+ finally:
286
+ handle.end_time = time.time()
287
+ if handle._done_event is not None:
288
+ handle._done_event.set()
289
+ if handle._done_threading_event is not None:
290
+ handle._done_threading_event.set()
291
+
292
+ # ------------------------------------------------------------------
293
+ # 查询与控制
294
+ # ------------------------------------------------------------------
295
+
296
+ def _resolve(self, key: Key) -> TaskHandle:
297
+ if isinstance(key, TaskHandle):
298
+ return key
299
+ if isinstance(key, str):
300
+ with self._handles_lock:
301
+ h: TaskHandle | None = self._handles.get(key)
302
+ if h is None:
303
+ raise KeyError(f"未找到 task_id='{key}' 的任务")
304
+ return h
305
+ raise TypeError(f"key 必须是 str 或 TaskHandle,实际: {type(key).__name__}")
306
+
307
+ def _maybe_consume(self, handle: TaskHandle) -> None:
308
+ """终态 + 非 persist 时从 _handles 删除。"""
309
+ if handle.is_finished() and not handle.persist:
310
+ with self._handles_lock:
311
+ self._handles.pop(handle.task_id, None)
312
+
313
+ def poll(self, key: Key) -> TaskHandle:
314
+ """非消费快照:返回 handle 当前状态,不删除。"""
315
+ return self._resolve(key)
316
+
317
+ def get(self, key: Key) -> InferenceRequestResult | None:
318
+ """非阻塞消费:未完成返回 None;完成则返回 result 并(若非 persist)删除。"""
319
+ h: TaskHandle = self._resolve(key)
320
+ if not h.is_finished():
321
+ return None
322
+ result: Any | None = h.result
323
+ self._maybe_consume(handle=h)
324
+ return result
325
+
326
+ def wait(self, key: Key, timeout: float | None = None) -> InferenceRequestResult:
327
+ """阻塞到终态,返回 result。等待超时抛 TimeoutError(任务仍在后台运行)。"""
328
+ h: TaskHandle = self._resolve(key)
329
+ if h._done_threading_event is None:
330
+ raise RuntimeError("handle 缺少 done event,可能未经过 submit")
331
+ finished: bool = h._done_threading_event.wait(timeout=timeout)
332
+ if not finished:
333
+ raise TimeoutError(f"等待任务 {h.task_id} 超时({timeout}s),任务仍在后台运行")
334
+ result = h.result
335
+ self._maybe_consume(handle=h)
336
+ # 终态后 result 必然已被 _run_task 赋值(SUCCESS/FAILED/CANCELLED 均覆盖);
337
+ # 若仍为 None 则说明引擎内部状态异常,直接把失败信息返回给调用方。
338
+ if result is None:
339
+ return InferenceRequestResult(
340
+ success=False,
341
+ task_id=h.task_id,
342
+ request=h.request,
343
+ error_message=h.error or "unknown error (result missing)",
344
+ )
345
+ return result
346
+
347
+ def cancel(self, key: Key, force: bool = False) -> bool:
348
+ h: TaskHandle = self._resolve(key)
349
+ if h.is_finished():
350
+ return False
351
+ if h.status == TaskStatus.RUNNING and not force:
352
+ return False
353
+ if h._asyncio_task is not None and self._loop is not None:
354
+ self._loop.call_soon_threadsafe(callback=h._asyncio_task.cancel)
355
+ # 注意:实际状态更新在 _run_task 的 except CancelledError 分支完成
356
+ return True
357
+
358
+ def pop(self, key: Key) -> InferenceRequestResult | None:
359
+ """强制从 _handles 移除(无论 persist),返回 result(若有)。"""
360
+ h: TaskHandle = self._resolve(key)
361
+ with self._handles_lock:
362
+ self._handles.pop(h.task_id, None)
363
+ return h.result
364
+
365
+ def clear_done(self) -> int:
366
+ """清理所有终态 handle(含 persist)。返回清理个数。"""
367
+ with self._handles_lock:
368
+ done_ids: list[str] = [tid for tid, h in self._handles.items() if h.is_finished()]
369
+ for tid in done_ids:
370
+ self._handles.pop(tid, None)
371
+ return len(done_ids)
372
+
373
+ # ------------------------------------------------------------------
374
+ # 同步便捷接口
375
+ # ------------------------------------------------------------------
376
+
377
+ def inference(self, request: InferenceRequest) -> InferenceRequestResult:
378
+ handle: TaskHandle = self.submit(request)
379
+ # task_timeout 在 _run_task 的 asyncio.wait_for 里生效,此处 threading 侧不再重复设置,
380
+ # 避免两个超时竞争导致偶发 TimeoutError
381
+ return self.wait(handle)
382
+
383
+ def infer(self, query: str) -> str:
384
+ result: InferenceRequestResult = self.inference(request=InferenceRequest(messages=[Message(role=MessageRole.USER, content=query)]))
385
+ if result.success and result.model_output is not None:
386
+ return result.model_output.content
387
+ return ""
388
+
389
+ def batch_inference(
390
+ self,
391
+ requests: list[InferenceRequest],
392
+ output_file: str | None = None,
393
+ silent_mode: bool = False,
394
+ ) -> list[InferenceRequestResult]:
395
+ handles: list[TaskHandle] = [self.submit(request=req) for req in requests]
396
+
397
+ pbar: tqdm | None = tqdm(total=len(requests), desc=f"Batch: {self.default_model_name}") if not silent_mode else None
398
+ f: IO[str] | None = open(output_file, mode="w", encoding="utf-8") if output_file else None
399
+ results: list[InferenceRequestResult] = []
400
+ try:
401
+ for h in handles:
402
+ result: InferenceRequestResult = self.wait(key=h)
403
+ results.append(result)
404
+ if f is not None:
405
+ f.write(result.model_dump_json() + "\n")
406
+ f.flush()
407
+ if pbar is not None:
408
+ pbar.update(n=1)
409
+ finally:
410
+ if pbar is not None:
411
+ pbar.close()
412
+ if f is not None:
413
+ f.close()
414
+ return results
415
+
416
+ def batch_infer(self, queries: list[str], **kwargs: Any) -> list[str]:
417
+ reqs: list[InferenceRequest] = [InferenceRequest(messages=[Message(role=MessageRole.USER, content=q)]) for q in queries]
418
+ results: list[InferenceRequestResult] = self.batch_inference(requests=reqs, **kwargs)
419
+ return [r.model_output.content if r.success and r.model_output else "" for r in results]
420
+
421
+ # ------------------------------------------------------------------
422
+ # HTTP 核心(自包含)
423
+ # ------------------------------------------------------------------
424
+
425
+ def _get_api_key(self, request: InferenceRequest, model: ModelConfig) -> str:
426
+ if request.api_key is not None:
427
+ return request.api_key
428
+ if model.api_key is not None:
429
+ return model.api_key
430
+ if self.default_api_key is not None:
431
+ return self.default_api_key
432
+ logger.warning("没有找到可用的 api_key, 将传递空 api_key")
433
+ return ""
434
+
435
+ def _build_headers(self, request: InferenceRequest, model: ModelConfig) -> dict[str, str]:
436
+ headers: dict[str, str] = {
437
+ "Authorization": "Bearer " + self._get_api_key(request, model),
438
+ "Content-Type": "application/json",
439
+ }
440
+ headers.update(self.extra_headers)
441
+ headers.update(model.extra_headers)
442
+ headers.update(request.extra_headers)
443
+ return headers
444
+
445
+ def _build_payload(self, request: InferenceRequest, model: ModelConfig) -> dict[str, Any]:
446
+ payload: dict[str, Any] = {
447
+ "model": model.model_id,
448
+ "stream": request.stream if request.stream is not None else self.stream,
449
+ "messages": [msg.to_dict() for msg in request.messages],
450
+ }
451
+ if self.default_inference_parameters:
452
+ payload.update(self.default_inference_parameters.to_dict())
453
+ if model.default_inference_parameters:
454
+ payload.update(model.default_inference_parameters.to_dict())
455
+ if request.inference_parameters:
456
+ payload.update(request.inference_parameters.to_dict())
457
+ payload.update(self.extra_payload)
458
+ payload.update(model.extra_payload)
459
+ payload.update(request.extra_payload)
460
+ return payload
461
+
462
+ def _get_wait_time(self, status_code: int, attempt: int, base_delay: int) -> float:
463
+ if status_code == 429:
464
+ return 1.0
465
+ if status_code >= 500:
466
+ return float(base_delay**attempt)
467
+ return 5.0
468
+
469
+ async def _send_request(self, request: InferenceRequest, model: ModelConfig) -> ModelOutput:
470
+ """
471
+ 发送请求,带重试。每次 attempt 重新选 URL(支持多节点故障转移)。
472
+ 仅实现非流式;流式留待后续。
473
+ """
474
+ headers: dict[str, str] = self._build_headers(request, model)
475
+ payload: dict[str, Any] = self._build_payload(request, model)
476
+ stream = payload["stream"]
477
+ max_retries: int = self.default_max_retries
478
+ base_delay: int = self.default_base_delay
479
+
480
+ if stream:
481
+ raise NotImplementedError("GeneralEngine 首版暂未实现 stream 模式")
482
+
483
+ async with httpx.AsyncClient(timeout=httpx.Timeout(timeout=10.0, read=3600.0)) as client:
484
+ for attempt in range(max_retries):
485
+ is_last = attempt == max_retries - 1
486
+ url = model.get_url()
487
+ try:
488
+ resp = await client.post(url, json=payload, headers=headers)
489
+ if resp.status_code != 200:
490
+ body = resp.text
491
+ if is_last:
492
+ logger.error("HTTP %d,已达最大重试次数,响应: %s", resp.status_code, body[:500])
493
+ raise httpx.HTTPStatusError(f"HTTP {resp.status_code}", request=resp.request, response=resp)
494
+ wait_time = self._get_wait_time(resp.status_code, attempt, base_delay)
495
+ logger.warning(
496
+ "HTTP %d, 第 %d 次重试,等待 %.1fs... 响应: %s",
497
+ resp.status_code,
498
+ attempt + 1,
499
+ wait_time,
500
+ body[:500],
501
+ )
502
+ await asyncio.sleep(wait_time)
503
+ continue
504
+
505
+ try:
506
+ parsed = ChatCompletionResponse.model_validate_json(resp.text)
507
+ except ValidationError as ve:
508
+ if is_last:
509
+ raise RuntimeError(f"响应 JSON 解析失败,已达最大重试次数: {ve}; body: {resp.text[:500]}") from ve
510
+ logger.warning("响应 JSON 解析失败,第 %d 次重试: %r", attempt + 1, ve)
511
+ await asyncio.sleep(base_delay)
512
+ continue
513
+
514
+ if not parsed.choices:
515
+ if is_last:
516
+ raise RuntimeError(f"empty choices after {max_retries} attempts, body: {resp.text[:500]}")
517
+ logger.warning("resp.choices 为空,疑似风控,第 %d 次重试...", attempt + 1)
518
+ await asyncio.sleep(base_delay)
519
+ continue
520
+ choice = parsed.choices[0]
521
+ if choice.finish_reason and choice.finish_reason != "stop":
522
+ logger.warning(
523
+ "finish_reason='%s' (非 stop), content_len=%d, usage=%s",
524
+ choice.finish_reason,
525
+ len(choice.message.content or ""),
526
+ parsed.usage.model_dump() if parsed.usage else None,
527
+ )
528
+ return ModelOutput(
529
+ role=choice.message.role,
530
+ content=choice.message.content,
531
+ reasoning=choice.message.reasoning_content,
532
+ finish_reason=choice.finish_reason,
533
+ usage=parsed.usage.model_dump() if parsed.usage else None,
534
+ )
535
+ except (httpx.NetworkError, httpx.TimeoutException) as e:
536
+ # HTTPStatusError 仅会在上面 is_last 分支抛出,不再在此捕获
537
+ if is_last:
538
+ logger.error("达到最大重试次数,最后错误: %r", e)
539
+ raise
540
+ wait_time = base_delay**attempt
541
+ logger.warning("网络异常: %r, 等待 %.1fs 后重试...", e, wait_time)
542
+ await asyncio.sleep(wait_time)
543
+ finally:
544
+ model.release_url(url)
545
+
546
+ raise RuntimeError("Unexpected end of retry loop")
547
+
548
+ # ------------------------------------------------------------------
549
+ # 预留:B 方案(async-native),首版不实现
550
+ # ------------------------------------------------------------------
551
+
552
+ async def a_submit(self, request: InferenceRequest, *, persist: bool = False, timeout: float | None = None) -> TaskHandle:
553
+ raise NotImplementedError("async-native 接口尚未实现")
554
+
555
+ async def a_wait(self, key: Key, timeout: float | None = None) -> InferenceRequestResult:
556
+ raise NotImplementedError("async-native 接口尚未实现")
557
+
558
+ async def a_inference(self, request: InferenceRequest) -> InferenceRequestResult:
559
+ raise NotImplementedError("async-native 接口尚未实现")
560
+
561
+ async def a_batch_inference(self, requests: list[InferenceRequest]) -> list[InferenceRequestResult]:
562
+ raise NotImplementedError("async-native 接口尚未实现")
@@ -0,0 +1,8 @@
1
+ # llm_engine/kitty/__init__.py
2
+
3
+ """KittyEngine:独立进程化的 LLM 推理引擎。"""
4
+
5
+ from .client import ClientTaskHandle, KittyClient
6
+ from .server import KittyServer
7
+
8
+ __all__ = ["KittyClient", "KittyServer", "ClientTaskHandle"]
@@ -0,0 +1,46 @@
1
+ # llm_engine/kitty/__main__.py
2
+
3
+ """
4
+ CLI 入口。
5
+
6
+ 用法:
7
+ python -m llm_engine.kitty -c /path/to/kitty_engine.yaml
8
+ """
9
+
10
+ import argparse
11
+ import asyncio
12
+ import logging
13
+ import sys
14
+
15
+ import kitty_logger
16
+
17
+
18
+ def main() -> int:
19
+ parser = argparse.ArgumentParser(prog="python -m llm_engine.kitty", description="KittyEngine 后端 server")
20
+ group = parser.add_mutually_exclusive_group(required=True)
21
+ group.add_argument("-c", "--config", help="YAML 配置文件路径")
22
+ group.add_argument("--embedded", action="store_true", help="embedded 模式:从 stdin 读取 JSON config,stdin EOF 时自动退出")
23
+ args = parser.parse_args()
24
+
25
+ from .config import KittyEngineConfig
26
+
27
+ if args.embedded:
28
+ import json
29
+
30
+ config_dict = json.loads(sys.stdin.readline())
31
+ config = KittyEngineConfig.model_validate(config_dict)
32
+ else:
33
+ config = KittyEngineConfig.from_yaml(args.config)
34
+
35
+ level = getattr(logging, config.log_level.upper(), logging.INFO)
36
+ kitty_logger.setup_logging(level=level)
37
+
38
+ from .server import KittyServer
39
+
40
+ server = KittyServer(config)
41
+ asyncio.run(server.run(mode="embedded" if args.embedded else "standalone"))
42
+ return 0
43
+
44
+
45
+ if __name__ == "__main__":
46
+ sys.exit(main())