agent-cli 0.70.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agent_cli/__init__.py +5 -0
- agent_cli/__main__.py +6 -0
- agent_cli/_extras.json +14 -0
- agent_cli/_requirements/.gitkeep +0 -0
- agent_cli/_requirements/audio.txt +79 -0
- agent_cli/_requirements/faster-whisper.txt +215 -0
- agent_cli/_requirements/kokoro.txt +425 -0
- agent_cli/_requirements/llm.txt +183 -0
- agent_cli/_requirements/memory.txt +355 -0
- agent_cli/_requirements/mlx-whisper.txt +222 -0
- agent_cli/_requirements/piper.txt +176 -0
- agent_cli/_requirements/rag.txt +402 -0
- agent_cli/_requirements/server.txt +154 -0
- agent_cli/_requirements/speed.txt +77 -0
- agent_cli/_requirements/vad.txt +155 -0
- agent_cli/_requirements/wyoming.txt +71 -0
- agent_cli/_tools.py +368 -0
- agent_cli/agents/__init__.py +23 -0
- agent_cli/agents/_voice_agent_common.py +136 -0
- agent_cli/agents/assistant.py +383 -0
- agent_cli/agents/autocorrect.py +284 -0
- agent_cli/agents/chat.py +496 -0
- agent_cli/agents/memory/__init__.py +31 -0
- agent_cli/agents/memory/add.py +190 -0
- agent_cli/agents/memory/proxy.py +160 -0
- agent_cli/agents/rag_proxy.py +128 -0
- agent_cli/agents/speak.py +209 -0
- agent_cli/agents/transcribe.py +671 -0
- agent_cli/agents/transcribe_daemon.py +499 -0
- agent_cli/agents/voice_edit.py +291 -0
- agent_cli/api.py +22 -0
- agent_cli/cli.py +106 -0
- agent_cli/config.py +503 -0
- agent_cli/config_cmd.py +307 -0
- agent_cli/constants.py +27 -0
- agent_cli/core/__init__.py +1 -0
- agent_cli/core/audio.py +461 -0
- agent_cli/core/audio_format.py +299 -0
- agent_cli/core/chroma.py +88 -0
- agent_cli/core/deps.py +191 -0
- agent_cli/core/openai_proxy.py +139 -0
- agent_cli/core/process.py +195 -0
- agent_cli/core/reranker.py +120 -0
- agent_cli/core/sse.py +87 -0
- agent_cli/core/transcription_logger.py +70 -0
- agent_cli/core/utils.py +526 -0
- agent_cli/core/vad.py +175 -0
- agent_cli/core/watch.py +65 -0
- agent_cli/dev/__init__.py +14 -0
- agent_cli/dev/cli.py +1588 -0
- agent_cli/dev/coding_agents/__init__.py +19 -0
- agent_cli/dev/coding_agents/aider.py +24 -0
- agent_cli/dev/coding_agents/base.py +167 -0
- agent_cli/dev/coding_agents/claude.py +39 -0
- agent_cli/dev/coding_agents/codex.py +24 -0
- agent_cli/dev/coding_agents/continue_dev.py +15 -0
- agent_cli/dev/coding_agents/copilot.py +24 -0
- agent_cli/dev/coding_agents/cursor_agent.py +48 -0
- agent_cli/dev/coding_agents/gemini.py +28 -0
- agent_cli/dev/coding_agents/opencode.py +15 -0
- agent_cli/dev/coding_agents/registry.py +49 -0
- agent_cli/dev/editors/__init__.py +19 -0
- agent_cli/dev/editors/base.py +89 -0
- agent_cli/dev/editors/cursor.py +15 -0
- agent_cli/dev/editors/emacs.py +46 -0
- agent_cli/dev/editors/jetbrains.py +56 -0
- agent_cli/dev/editors/nano.py +31 -0
- agent_cli/dev/editors/neovim.py +33 -0
- agent_cli/dev/editors/registry.py +59 -0
- agent_cli/dev/editors/sublime.py +20 -0
- agent_cli/dev/editors/vim.py +42 -0
- agent_cli/dev/editors/vscode.py +15 -0
- agent_cli/dev/editors/zed.py +20 -0
- agent_cli/dev/project.py +568 -0
- agent_cli/dev/registry.py +52 -0
- agent_cli/dev/skill/SKILL.md +141 -0
- agent_cli/dev/skill/examples.md +571 -0
- agent_cli/dev/terminals/__init__.py +19 -0
- agent_cli/dev/terminals/apple_terminal.py +82 -0
- agent_cli/dev/terminals/base.py +56 -0
- agent_cli/dev/terminals/gnome.py +51 -0
- agent_cli/dev/terminals/iterm2.py +84 -0
- agent_cli/dev/terminals/kitty.py +77 -0
- agent_cli/dev/terminals/registry.py +48 -0
- agent_cli/dev/terminals/tmux.py +58 -0
- agent_cli/dev/terminals/warp.py +132 -0
- agent_cli/dev/terminals/zellij.py +78 -0
- agent_cli/dev/worktree.py +856 -0
- agent_cli/docs_gen.py +417 -0
- agent_cli/example-config.toml +185 -0
- agent_cli/install/__init__.py +5 -0
- agent_cli/install/common.py +89 -0
- agent_cli/install/extras.py +174 -0
- agent_cli/install/hotkeys.py +48 -0
- agent_cli/install/services.py +87 -0
- agent_cli/memory/__init__.py +7 -0
- agent_cli/memory/_files.py +250 -0
- agent_cli/memory/_filters.py +63 -0
- agent_cli/memory/_git.py +157 -0
- agent_cli/memory/_indexer.py +142 -0
- agent_cli/memory/_ingest.py +408 -0
- agent_cli/memory/_persistence.py +182 -0
- agent_cli/memory/_prompt.py +91 -0
- agent_cli/memory/_retrieval.py +294 -0
- agent_cli/memory/_store.py +169 -0
- agent_cli/memory/_streaming.py +44 -0
- agent_cli/memory/_tasks.py +48 -0
- agent_cli/memory/api.py +113 -0
- agent_cli/memory/client.py +272 -0
- agent_cli/memory/engine.py +361 -0
- agent_cli/memory/entities.py +43 -0
- agent_cli/memory/models.py +112 -0
- agent_cli/opts.py +433 -0
- agent_cli/py.typed +0 -0
- agent_cli/rag/__init__.py +3 -0
- agent_cli/rag/_indexer.py +67 -0
- agent_cli/rag/_indexing.py +226 -0
- agent_cli/rag/_prompt.py +30 -0
- agent_cli/rag/_retriever.py +156 -0
- agent_cli/rag/_store.py +48 -0
- agent_cli/rag/_utils.py +218 -0
- agent_cli/rag/api.py +175 -0
- agent_cli/rag/client.py +299 -0
- agent_cli/rag/engine.py +302 -0
- agent_cli/rag/models.py +55 -0
- agent_cli/scripts/.runtime/.gitkeep +0 -0
- agent_cli/scripts/__init__.py +1 -0
- agent_cli/scripts/check_plugin_skill_sync.py +50 -0
- agent_cli/scripts/linux-hotkeys/README.md +63 -0
- agent_cli/scripts/linux-hotkeys/toggle-autocorrect.sh +45 -0
- agent_cli/scripts/linux-hotkeys/toggle-transcription.sh +58 -0
- agent_cli/scripts/linux-hotkeys/toggle-voice-edit.sh +58 -0
- agent_cli/scripts/macos-hotkeys/README.md +45 -0
- agent_cli/scripts/macos-hotkeys/skhd-config-example +5 -0
- agent_cli/scripts/macos-hotkeys/toggle-autocorrect.sh +12 -0
- agent_cli/scripts/macos-hotkeys/toggle-transcription.sh +37 -0
- agent_cli/scripts/macos-hotkeys/toggle-voice-edit.sh +37 -0
- agent_cli/scripts/nvidia-asr-server/README.md +99 -0
- agent_cli/scripts/nvidia-asr-server/pyproject.toml +27 -0
- agent_cli/scripts/nvidia-asr-server/server.py +255 -0
- agent_cli/scripts/nvidia-asr-server/shell.nix +32 -0
- agent_cli/scripts/nvidia-asr-server/uv.lock +4654 -0
- agent_cli/scripts/run-openwakeword.sh +11 -0
- agent_cli/scripts/run-piper-windows.ps1 +30 -0
- agent_cli/scripts/run-piper.sh +24 -0
- agent_cli/scripts/run-whisper-linux.sh +40 -0
- agent_cli/scripts/run-whisper-macos.sh +6 -0
- agent_cli/scripts/run-whisper-windows.ps1 +51 -0
- agent_cli/scripts/run-whisper.sh +9 -0
- agent_cli/scripts/run_faster_whisper_server.py +136 -0
- agent_cli/scripts/setup-linux-hotkeys.sh +72 -0
- agent_cli/scripts/setup-linux.sh +108 -0
- agent_cli/scripts/setup-macos-hotkeys.sh +61 -0
- agent_cli/scripts/setup-macos.sh +76 -0
- agent_cli/scripts/setup-windows.ps1 +63 -0
- agent_cli/scripts/start-all-services-windows.ps1 +53 -0
- agent_cli/scripts/start-all-services.sh +178 -0
- agent_cli/scripts/sync_extras.py +138 -0
- agent_cli/server/__init__.py +3 -0
- agent_cli/server/cli.py +721 -0
- agent_cli/server/common.py +222 -0
- agent_cli/server/model_manager.py +288 -0
- agent_cli/server/model_registry.py +225 -0
- agent_cli/server/proxy/__init__.py +3 -0
- agent_cli/server/proxy/api.py +444 -0
- agent_cli/server/streaming.py +67 -0
- agent_cli/server/tts/__init__.py +3 -0
- agent_cli/server/tts/api.py +335 -0
- agent_cli/server/tts/backends/__init__.py +82 -0
- agent_cli/server/tts/backends/base.py +139 -0
- agent_cli/server/tts/backends/kokoro.py +403 -0
- agent_cli/server/tts/backends/piper.py +253 -0
- agent_cli/server/tts/model_manager.py +201 -0
- agent_cli/server/tts/model_registry.py +28 -0
- agent_cli/server/tts/wyoming_handler.py +249 -0
- agent_cli/server/whisper/__init__.py +3 -0
- agent_cli/server/whisper/api.py +413 -0
- agent_cli/server/whisper/backends/__init__.py +89 -0
- agent_cli/server/whisper/backends/base.py +97 -0
- agent_cli/server/whisper/backends/faster_whisper.py +225 -0
- agent_cli/server/whisper/backends/mlx.py +270 -0
- agent_cli/server/whisper/languages.py +116 -0
- agent_cli/server/whisper/model_manager.py +157 -0
- agent_cli/server/whisper/model_registry.py +28 -0
- agent_cli/server/whisper/wyoming_handler.py +203 -0
- agent_cli/services/__init__.py +343 -0
- agent_cli/services/_wyoming_utils.py +64 -0
- agent_cli/services/asr.py +506 -0
- agent_cli/services/llm.py +228 -0
- agent_cli/services/tts.py +450 -0
- agent_cli/services/wake_word.py +142 -0
- agent_cli-0.70.5.dist-info/METADATA +2118 -0
- agent_cli-0.70.5.dist-info/RECORD +196 -0
- agent_cli-0.70.5.dist-info/WHEEL +4 -0
- agent_cli-0.70.5.dist-info/entry_points.txt +4 -0
- agent_cli-0.70.5.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,222 @@
|
|
|
1
|
+
"""Common utilities for FastAPI server modules."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import asyncio
|
|
6
|
+
import contextlib
|
|
7
|
+
import importlib
|
|
8
|
+
import logging
|
|
9
|
+
from contextlib import asynccontextmanager
|
|
10
|
+
from typing import TYPE_CHECKING, Any, Protocol
|
|
11
|
+
|
|
12
|
+
from rich.logging import RichHandler
|
|
13
|
+
|
|
14
|
+
from agent_cli import constants
|
|
15
|
+
from agent_cli.core.utils import console
|
|
16
|
+
|
|
17
|
+
if TYPE_CHECKING:
|
|
18
|
+
import wave
|
|
19
|
+
from collections.abc import AsyncIterator, Callable, Coroutine
|
|
20
|
+
from contextlib import AbstractAsyncContextManager
|
|
21
|
+
|
|
22
|
+
from fastapi import FastAPI, Request
|
|
23
|
+
|
|
24
|
+
logger = logging.getLogger(__name__)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class RegistryProtocol(Protocol):
    """Structural interface that a model registry must satisfy.

    Any object exposing these three coroutine methods can be handed to
    ``create_lifespan``; no inheritance is required.
    """

    async def start(self) -> None:
        """Start the registry."""

    async def stop(self) -> None:
        """Stop the registry."""

    async def preload(self) -> None:
        """Preload models."""
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def create_lifespan(
    registry: RegistryProtocol,
    *,
    wyoming_handler_module: str,
    enable_wyoming: bool = True,
    wyoming_uri: str = "tcp://0.0.0.0:10300",
) -> Callable[[FastAPI], AbstractAsyncContextManager[None]]:
    """Create a lifespan context manager for a server.

    On startup the registry is started and, if enabled, the Wyoming server is
    launched as a background task. On shutdown the Wyoming task is cancelled
    and the registry is stopped. Teardown runs even when the application body
    raises or the Wyoming task has already failed.

    Args:
        registry: The model registry to manage.
        wyoming_handler_module: Module path containing start_wyoming_server function.
        enable_wyoming: Whether to start Wyoming server.
        wyoming_uri: URI for Wyoming server.

    Returns:
        A lifespan context manager function for FastAPI.

    """

    @asynccontextmanager
    async def lifespan(_app: FastAPI) -> AsyncIterator[None]:
        """Manage application lifecycle."""
        wyoming_task: asyncio.Task[None] | None = None

        # Start the registry; if this fails there is nothing to tear down.
        await registry.start()

        try:
            # Start Wyoming server if enabled
            if enable_wyoming:
                try:
                    module = importlib.import_module(wyoming_handler_module)
                    start_wyoming_server: Callable[
                        [Any, str],
                        Coroutine[Any, Any, None],
                    ] = module.start_wyoming_server

                    wyoming_task = asyncio.create_task(
                        start_wyoming_server(registry, wyoming_uri),
                    )
                except ImportError:
                    logger.warning("Wyoming not available, skipping Wyoming server")
                except Exception:
                    logger.exception("Failed to start Wyoming server")

            yield
        finally:
            try:
                # Stop Wyoming server
                if wyoming_task is not None:
                    wyoming_task.cancel()
                    with contextlib.suppress(asyncio.CancelledError):
                        try:
                            await wyoming_task
                        except Exception:
                            # The task may have died earlier (e.g. bind failure);
                            # awaiting it re-raises that error, which must not
                            # prevent the registry from being stopped.
                            logger.exception("Wyoming server stopped with an error")
            finally:
                # Stop the registry
                await registry.stop()

    return lifespan
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def configure_app(app: FastAPI) -> None:
    """Attach the shared middleware stack to a FastAPI application.

    Two pieces of middleware are registered, in this order:
    - CORS middleware allowing all origins
    - Request logging middleware

    Args:
        app: The FastAPI application to configure.

    """
    from fastapi.middleware.cors import CORSMiddleware  # noqa: PLC0415

    async def log_requests(request: Any, call_next: Any) -> Any:
        """Log basic request information."""
        return await log_requests_middleware(request, call_next)

    # Fully permissive CORS policy: any origin, method and header.
    cors_options: dict[str, Any] = {
        "allow_origins": ["*"],
        "allow_credentials": True,
        "allow_methods": ["*"],
        "allow_headers": ["*"],
    }
    app.add_middleware(CORSMiddleware, **cors_options)

    # Register the request logger as HTTP middleware.
    app.middleware("http")(log_requests)
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
def setup_rich_logging(log_level: str = "info") -> None:
    """Configure logging to use Rich for consistent, pretty output.

    This configures:
    - The root logger to use a single RichHandler (existing handlers are removed)
    - Uvicorn's loggers to use the same handler, without propagating to root

    The handler writes to the module-level ``console`` imported from
    ``agent_cli.core.utils``.

    Args:
        log_level: Logging level name (debug, info, warning, error), matched
            case-insensitively. Unknown names fall back to INFO.

    """
    level = getattr(logging, log_level.upper(), logging.INFO)
    if not isinstance(level, int):
        # The name matched a non-level attribute of the logging module
        # (e.g. "basic_format"); setLevel would reject it, so use the default.
        level = logging.INFO

    # Create Rich handler with clean format
    handler = RichHandler(
        console=console,
        show_time=True,
        show_level=True,
        show_path=False,  # Don't show file:line - too verbose
        rich_tracebacks=True,
        markup=True,
    )
    handler.setFormatter(logging.Formatter("%(message)s"))

    # Configure root logger
    root = logging.getLogger()
    root.handlers.clear()
    root.addHandler(handler)
    root.setLevel(level)

    # Configure uvicorn loggers to use same handler
    for uvicorn_logger_name in ("uvicorn", "uvicorn.access", "uvicorn.error"):
        uvicorn_logger = logging.getLogger(uvicorn_logger_name)
        uvicorn_logger.handlers.clear()
        uvicorn_logger.addHandler(handler)
        uvicorn_logger.setLevel(level)
        # Prevent records from being emitted twice via the root handler.
        uvicorn_logger.propagate = False
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
def setup_wav_file(
    wav_file: wave.Wave_write,
    *,
    rate: int | None = None,
    channels: int | None = None,
    sample_width: int | None = None,
) -> None:
    """Apply channel count, sample width and frame rate to a WAV writer.

    Any argument that is omitted (or falsy) is replaced by the matching
    project-wide constant.

    Args:
        wav_file: The WAV file writer to configure.
        rate: Sample rate in Hz (default: constants.AUDIO_RATE).
        channels: Number of channels (default: constants.AUDIO_CHANNELS).
        sample_width: Sample width in bytes (default: constants.AUDIO_FORMAT_WIDTH).

    """
    n_channels = channels if channels else constants.AUDIO_CHANNELS
    width = sample_width if sample_width else constants.AUDIO_FORMAT_WIDTH
    frame_rate = rate if rate else constants.AUDIO_RATE

    wav_file.setnchannels(n_channels)
    wav_file.setsampwidth(width)
    wav_file.setframerate(frame_rate)
|
|
190
|
+
|
|
191
|
+
|
|
192
|
+
async def log_requests_middleware(
    request: Request,
    call_next: Any,
) -> Any:
    """Log each incoming request and warn when the response is an error.

    Intended to be wrapped by a function registered with FastAPI's
    @app.middleware("http") decorator.

    Args:
        request: The incoming request.
        call_next: The next middleware/handler in the chain.

    Returns:
        The response from the next handler.

    """
    peer = request.client
    client_ip = "unknown" if not peer else peer.host
    logger.info("%s %s from %s", request.method, request.url.path, client_ip)

    response = await call_next(request)

    # Successful and redirect responses need no further logging.
    if response.status_code < 400:  # noqa: PLR2004
        return response

    logger.warning(
        "Request failed: %s %s → %d",
        request.method,
        request.url.path,
        response.status_code,
    )
    return response
|
|
@@ -0,0 +1,288 @@
|
|
|
1
|
+
"""Model manager with TTL-based unloading.
|
|
2
|
+
|
|
3
|
+
This module provides a concrete model manager that handles:
|
|
4
|
+
- Lazy loading of models on first request
|
|
5
|
+
- TTL-based automatic unloading when idle
|
|
6
|
+
- Active request tracking to prevent unload during processing
|
|
7
|
+
- Concurrent request coordination
|
|
8
|
+
|
|
9
|
+
The manager works with any backend that implements the BackendProtocol.
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
from __future__ import annotations
|
|
13
|
+
|
|
14
|
+
import asyncio
|
|
15
|
+
import contextlib
|
|
16
|
+
import logging
|
|
17
|
+
import time
|
|
18
|
+
from contextlib import asynccontextmanager
|
|
19
|
+
from dataclasses import dataclass, field
|
|
20
|
+
from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable
|
|
21
|
+
|
|
22
|
+
if TYPE_CHECKING:
|
|
23
|
+
from collections.abc import AsyncIterator
|
|
24
|
+
from pathlib import Path
|
|
25
|
+
|
|
26
|
+
logger = logging.getLogger(__name__)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
@dataclass
class ModelConfig:
    """Settings describing a single managed model."""

    # Identifier of the model.
    model_name: str
    # Device selector; "auto" is the default.
    device: str = "auto"
    # Idle time in seconds after which the model may be unloaded.
    ttl_seconds: int = 300
    # Optional cache directory; None when unset.
    cache_dir: Path | None = None

    def __post_init__(self) -> None:
        """Reject a TTL shorter than one second."""
        if self.ttl_seconds < 1:
            raise ValueError(f"ttl_seconds must be >= 1, got {self.ttl_seconds}")
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
@dataclass
class ModelStats:
    """Mutable counters and timestamps describing a model at runtime."""

    # Number of times the backend has been loaded / unloaded.
    load_count: int = 0
    unload_count: int = 0

    # Aggregate request counters (not updated by ModelManager itself;
    # presumably maintained by the code that serves requests).
    total_requests: int = 0
    total_audio_seconds: float = 0.0
    total_processing_seconds: float = 0.0

    # time.time() values recorded by ModelManager; None until first set.
    last_load_time: float | None = None
    last_request_time: float | None = None

    # Duration reported by the most recent backend load; None until first load.
    load_duration_seconds: float | None = None

    # Free-form extra metrics; each instance gets its own dict.
    extra: dict[str, float] = field(default_factory=dict)
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
@runtime_checkable
class BackendProtocol(Protocol):
    """Structural interface a model backend must provide.

    The protocol is runtime-checkable, so ``isinstance`` can be used to verify
    that an object exposes all four members.
    """

    @property
    def is_loaded(self) -> bool:
        """Check if the model is loaded."""

    @property
    def device(self) -> str | None:
        """Get the device the model is loaded on."""

    async def load(self) -> float:
        """Load the model, return load duration in seconds."""

    async def unload(self) -> None:
        """Unload the model."""
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
class ModelManager:
    """Manages a model with TTL-based unloading.

    The model is loaded lazily on first request and unloaded after
    being idle for longer than the configured TTL.

    All state changes are serialized through a single ``asyncio.Condition``:
    requests wait while an unload is in progress, and an explicit unload waits
    until no requests are active.

    Usage:
        manager = ModelManager(backend, config)
        await manager.start()

        # Use request context for processing
        async with manager.request():
            result = await backend.process(...)

        await manager.stop()
    """

    def __init__(
        self,
        backend: BackendProtocol,
        config: ModelConfig,
        stats: ModelStats | None = None,
    ) -> None:
        """Initialize the model manager.

        Args:
            backend: The backend instance to manage.
            config: Model configuration.
            stats: Optional stats instance (creates new one if not provided).

        """
        self.backend = backend
        self.config = config
        self.stats = stats or ModelStats()
        # Guards every field below and the backend's load/unload transitions.
        self._condition = asyncio.Condition()
        self._active_requests = 0
        # True while unload() is waiting for requests to drain / unloading.
        self._unloading = False
        self._unload_task: asyncio.Task[None] | None = None
        self._shutdown = False

    @property
    def is_loaded(self) -> bool:
        """Check if the model is currently loaded."""
        return self.backend.is_loaded

    @property
    def device(self) -> str | None:
        """Get the device the model is loaded on."""
        return self.backend.device

    @property
    def active_requests(self) -> int:
        """Get the number of active requests."""
        return self._active_requests

    @property
    def ttl_remaining(self) -> float | None:
        """Get seconds remaining before model unloads, or None if not loaded.

        The value is clamped at 0.0 and is measured from the last recorded
        request time.
        """
        if not self.is_loaded or self.stats.last_request_time is None:
            return None
        elapsed = time.time() - self.stats.last_request_time
        remaining = self.config.ttl_seconds - elapsed
        return max(0.0, remaining)

    async def start(self) -> None:
        """Start the TTL unload watcher.

        Safe to call repeatedly and after ``stop()``: the shutdown flag is
        cleared so a restarted watcher actually runs, and a watcher task that
        has already finished is replaced.
        """
        self._shutdown = False
        if self._unload_task is None or self._unload_task.done():
            self._unload_task = asyncio.create_task(self._unload_watcher())

    async def stop(self) -> None:
        """Stop the manager and unload the model."""
        self._shutdown = True
        if self._unload_task is not None:
            self._unload_task.cancel()
            with contextlib.suppress(asyncio.CancelledError):
                await self._unload_task
            self._unload_task = None
        await self.unload()

    async def get_model(self) -> Any:
        """Get the backend, loading it if necessary.

        Unlike ``request()``, this does not register an active request.
        """
        await self._ensure_loaded()
        return self.backend

    @asynccontextmanager
    async def request(self) -> AsyncIterator[None]:
        """Context manager for processing requests.

        Ensures the model is loaded and tracks active requests.
        Use this around any backend operations.

        Example:
            async with manager.request():
                result = await manager.backend.synthesize(text)

        """
        # If loading fails the request was never counted, so there is
        # nothing to undo; hence _begin_request sits outside the try.
        await self._begin_request()
        try:
            yield
        finally:
            await self._end_request()

    async def unload(self) -> bool:
        """Unload the model from memory.

        Waits for any unload already in progress and for all active requests
        to finish before unloading.

        Returns True if model was unloaded, False if it wasn't loaded.
        """
        async with self._condition:
            while self._unloading:
                await self._condition.wait()

            if not self.backend.is_loaded:
                return False

            # Block new requests while we drain the active ones.
            self._unloading = True
            try:
                while self._active_requests > 0:
                    logger.info(
                        "Waiting for %d active requests before unloading %s",
                        self._active_requests,
                        self.config.model_name,
                    )
                    await self._condition.wait()

                # The lock is released while waiting, so re-check the state.
                if not self.backend.is_loaded:
                    return False

                await self.backend.unload()
                self.stats.unload_count += 1
                return True
            finally:
                self._unloading = False
                self._condition.notify_all()

    async def _load_if_needed_locked(self) -> None:
        """Load the model if needed (expects condition lock held).

        On a fresh load the load statistics are updated and the idle clock
        (``last_request_time``) is started.
        """
        if not self.backend.is_loaded:
            load_duration = await self.backend.load()
            self.stats.load_count += 1
            self.stats.last_load_time = time.time()
            self.stats.load_duration_seconds = load_duration
            self.stats.last_request_time = time.time()

    async def _ensure_loaded(self) -> None:
        """Ensure the model is loaded."""
        async with self._condition:
            while self._unloading:
                await self._condition.wait()
            await self._load_if_needed_locked()

    async def _begin_request(self) -> None:
        """Begin a request, waiting if unload is in progress."""
        async with self._condition:
            while self._unloading:
                await self._condition.wait()
            await self._load_if_needed_locked()
            self._active_requests += 1

    async def _end_request(self) -> None:
        """End a request and notify waiters if no more active requests."""
        async with self._condition:
            self._active_requests -= 1
            self.stats.last_request_time = time.time()
            if self._active_requests == 0:
                self._condition.notify_all()

    async def _unload_watcher(self) -> None:
        """Background task that unloads model after TTL expires."""
        # Poll at least twice per TTL window, but no less often than every 30s.
        check_interval = min(30, self.config.ttl_seconds / 2)

        while not self._shutdown:
            try:
                await asyncio.sleep(check_interval)

                async with self._condition:
                    if self._unloading:
                        continue
                    if not self.backend.is_loaded:
                        continue

                    if self.stats.last_request_time is None:
                        continue

                    idle_time = time.time() - self.stats.last_request_time

                    if idle_time >= self.config.ttl_seconds:
                        if self._active_requests == 0:
                            logger.info(
                                "Model %s idle for %.0fs (ttl=%ds), unloading",
                                self.config.model_name,
                                idle_time,
                                self.config.ttl_seconds,
                            )
                            # The lock is held throughout, so no request can
                            # start while the backend unloads.
                            await self.backend.unload()
                            self.stats.unload_count += 1
                        else:
                            logger.debug(
                                "Model %s would unload but has %d active requests",
                                self.config.model_name,
                                self._active_requests,
                            )

            except asyncio.CancelledError:
                break
            except Exception:
                # Keep the watcher alive; a single failed check must not
                # disable TTL unloading permanently.
                logger.exception("Error in unload watcher")
|