agent-cli 0.70.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (196) hide show
  1. agent_cli/__init__.py +5 -0
  2. agent_cli/__main__.py +6 -0
  3. agent_cli/_extras.json +14 -0
  4. agent_cli/_requirements/.gitkeep +0 -0
  5. agent_cli/_requirements/audio.txt +79 -0
  6. agent_cli/_requirements/faster-whisper.txt +215 -0
  7. agent_cli/_requirements/kokoro.txt +425 -0
  8. agent_cli/_requirements/llm.txt +183 -0
  9. agent_cli/_requirements/memory.txt +355 -0
  10. agent_cli/_requirements/mlx-whisper.txt +222 -0
  11. agent_cli/_requirements/piper.txt +176 -0
  12. agent_cli/_requirements/rag.txt +402 -0
  13. agent_cli/_requirements/server.txt +154 -0
  14. agent_cli/_requirements/speed.txt +77 -0
  15. agent_cli/_requirements/vad.txt +155 -0
  16. agent_cli/_requirements/wyoming.txt +71 -0
  17. agent_cli/_tools.py +368 -0
  18. agent_cli/agents/__init__.py +23 -0
  19. agent_cli/agents/_voice_agent_common.py +136 -0
  20. agent_cli/agents/assistant.py +383 -0
  21. agent_cli/agents/autocorrect.py +284 -0
  22. agent_cli/agents/chat.py +496 -0
  23. agent_cli/agents/memory/__init__.py +31 -0
  24. agent_cli/agents/memory/add.py +190 -0
  25. agent_cli/agents/memory/proxy.py +160 -0
  26. agent_cli/agents/rag_proxy.py +128 -0
  27. agent_cli/agents/speak.py +209 -0
  28. agent_cli/agents/transcribe.py +671 -0
  29. agent_cli/agents/transcribe_daemon.py +499 -0
  30. agent_cli/agents/voice_edit.py +291 -0
  31. agent_cli/api.py +22 -0
  32. agent_cli/cli.py +106 -0
  33. agent_cli/config.py +503 -0
  34. agent_cli/config_cmd.py +307 -0
  35. agent_cli/constants.py +27 -0
  36. agent_cli/core/__init__.py +1 -0
  37. agent_cli/core/audio.py +461 -0
  38. agent_cli/core/audio_format.py +299 -0
  39. agent_cli/core/chroma.py +88 -0
  40. agent_cli/core/deps.py +191 -0
  41. agent_cli/core/openai_proxy.py +139 -0
  42. agent_cli/core/process.py +195 -0
  43. agent_cli/core/reranker.py +120 -0
  44. agent_cli/core/sse.py +87 -0
  45. agent_cli/core/transcription_logger.py +70 -0
  46. agent_cli/core/utils.py +526 -0
  47. agent_cli/core/vad.py +175 -0
  48. agent_cli/core/watch.py +65 -0
  49. agent_cli/dev/__init__.py +14 -0
  50. agent_cli/dev/cli.py +1588 -0
  51. agent_cli/dev/coding_agents/__init__.py +19 -0
  52. agent_cli/dev/coding_agents/aider.py +24 -0
  53. agent_cli/dev/coding_agents/base.py +167 -0
  54. agent_cli/dev/coding_agents/claude.py +39 -0
  55. agent_cli/dev/coding_agents/codex.py +24 -0
  56. agent_cli/dev/coding_agents/continue_dev.py +15 -0
  57. agent_cli/dev/coding_agents/copilot.py +24 -0
  58. agent_cli/dev/coding_agents/cursor_agent.py +48 -0
  59. agent_cli/dev/coding_agents/gemini.py +28 -0
  60. agent_cli/dev/coding_agents/opencode.py +15 -0
  61. agent_cli/dev/coding_agents/registry.py +49 -0
  62. agent_cli/dev/editors/__init__.py +19 -0
  63. agent_cli/dev/editors/base.py +89 -0
  64. agent_cli/dev/editors/cursor.py +15 -0
  65. agent_cli/dev/editors/emacs.py +46 -0
  66. agent_cli/dev/editors/jetbrains.py +56 -0
  67. agent_cli/dev/editors/nano.py +31 -0
  68. agent_cli/dev/editors/neovim.py +33 -0
  69. agent_cli/dev/editors/registry.py +59 -0
  70. agent_cli/dev/editors/sublime.py +20 -0
  71. agent_cli/dev/editors/vim.py +42 -0
  72. agent_cli/dev/editors/vscode.py +15 -0
  73. agent_cli/dev/editors/zed.py +20 -0
  74. agent_cli/dev/project.py +568 -0
  75. agent_cli/dev/registry.py +52 -0
  76. agent_cli/dev/skill/SKILL.md +141 -0
  77. agent_cli/dev/skill/examples.md +571 -0
  78. agent_cli/dev/terminals/__init__.py +19 -0
  79. agent_cli/dev/terminals/apple_terminal.py +82 -0
  80. agent_cli/dev/terminals/base.py +56 -0
  81. agent_cli/dev/terminals/gnome.py +51 -0
  82. agent_cli/dev/terminals/iterm2.py +84 -0
  83. agent_cli/dev/terminals/kitty.py +77 -0
  84. agent_cli/dev/terminals/registry.py +48 -0
  85. agent_cli/dev/terminals/tmux.py +58 -0
  86. agent_cli/dev/terminals/warp.py +132 -0
  87. agent_cli/dev/terminals/zellij.py +78 -0
  88. agent_cli/dev/worktree.py +856 -0
  89. agent_cli/docs_gen.py +417 -0
  90. agent_cli/example-config.toml +185 -0
  91. agent_cli/install/__init__.py +5 -0
  92. agent_cli/install/common.py +89 -0
  93. agent_cli/install/extras.py +174 -0
  94. agent_cli/install/hotkeys.py +48 -0
  95. agent_cli/install/services.py +87 -0
  96. agent_cli/memory/__init__.py +7 -0
  97. agent_cli/memory/_files.py +250 -0
  98. agent_cli/memory/_filters.py +63 -0
  99. agent_cli/memory/_git.py +157 -0
  100. agent_cli/memory/_indexer.py +142 -0
  101. agent_cli/memory/_ingest.py +408 -0
  102. agent_cli/memory/_persistence.py +182 -0
  103. agent_cli/memory/_prompt.py +91 -0
  104. agent_cli/memory/_retrieval.py +294 -0
  105. agent_cli/memory/_store.py +169 -0
  106. agent_cli/memory/_streaming.py +44 -0
  107. agent_cli/memory/_tasks.py +48 -0
  108. agent_cli/memory/api.py +113 -0
  109. agent_cli/memory/client.py +272 -0
  110. agent_cli/memory/engine.py +361 -0
  111. agent_cli/memory/entities.py +43 -0
  112. agent_cli/memory/models.py +112 -0
  113. agent_cli/opts.py +433 -0
  114. agent_cli/py.typed +0 -0
  115. agent_cli/rag/__init__.py +3 -0
  116. agent_cli/rag/_indexer.py +67 -0
  117. agent_cli/rag/_indexing.py +226 -0
  118. agent_cli/rag/_prompt.py +30 -0
  119. agent_cli/rag/_retriever.py +156 -0
  120. agent_cli/rag/_store.py +48 -0
  121. agent_cli/rag/_utils.py +218 -0
  122. agent_cli/rag/api.py +175 -0
  123. agent_cli/rag/client.py +299 -0
  124. agent_cli/rag/engine.py +302 -0
  125. agent_cli/rag/models.py +55 -0
  126. agent_cli/scripts/.runtime/.gitkeep +0 -0
  127. agent_cli/scripts/__init__.py +1 -0
  128. agent_cli/scripts/check_plugin_skill_sync.py +50 -0
  129. agent_cli/scripts/linux-hotkeys/README.md +63 -0
  130. agent_cli/scripts/linux-hotkeys/toggle-autocorrect.sh +45 -0
  131. agent_cli/scripts/linux-hotkeys/toggle-transcription.sh +58 -0
  132. agent_cli/scripts/linux-hotkeys/toggle-voice-edit.sh +58 -0
  133. agent_cli/scripts/macos-hotkeys/README.md +45 -0
  134. agent_cli/scripts/macos-hotkeys/skhd-config-example +5 -0
  135. agent_cli/scripts/macos-hotkeys/toggle-autocorrect.sh +12 -0
  136. agent_cli/scripts/macos-hotkeys/toggle-transcription.sh +37 -0
  137. agent_cli/scripts/macos-hotkeys/toggle-voice-edit.sh +37 -0
  138. agent_cli/scripts/nvidia-asr-server/README.md +99 -0
  139. agent_cli/scripts/nvidia-asr-server/pyproject.toml +27 -0
  140. agent_cli/scripts/nvidia-asr-server/server.py +255 -0
  141. agent_cli/scripts/nvidia-asr-server/shell.nix +32 -0
  142. agent_cli/scripts/nvidia-asr-server/uv.lock +4654 -0
  143. agent_cli/scripts/run-openwakeword.sh +11 -0
  144. agent_cli/scripts/run-piper-windows.ps1 +30 -0
  145. agent_cli/scripts/run-piper.sh +24 -0
  146. agent_cli/scripts/run-whisper-linux.sh +40 -0
  147. agent_cli/scripts/run-whisper-macos.sh +6 -0
  148. agent_cli/scripts/run-whisper-windows.ps1 +51 -0
  149. agent_cli/scripts/run-whisper.sh +9 -0
  150. agent_cli/scripts/run_faster_whisper_server.py +136 -0
  151. agent_cli/scripts/setup-linux-hotkeys.sh +72 -0
  152. agent_cli/scripts/setup-linux.sh +108 -0
  153. agent_cli/scripts/setup-macos-hotkeys.sh +61 -0
  154. agent_cli/scripts/setup-macos.sh +76 -0
  155. agent_cli/scripts/setup-windows.ps1 +63 -0
  156. agent_cli/scripts/start-all-services-windows.ps1 +53 -0
  157. agent_cli/scripts/start-all-services.sh +178 -0
  158. agent_cli/scripts/sync_extras.py +138 -0
  159. agent_cli/server/__init__.py +3 -0
  160. agent_cli/server/cli.py +721 -0
  161. agent_cli/server/common.py +222 -0
  162. agent_cli/server/model_manager.py +288 -0
  163. agent_cli/server/model_registry.py +225 -0
  164. agent_cli/server/proxy/__init__.py +3 -0
  165. agent_cli/server/proxy/api.py +444 -0
  166. agent_cli/server/streaming.py +67 -0
  167. agent_cli/server/tts/__init__.py +3 -0
  168. agent_cli/server/tts/api.py +335 -0
  169. agent_cli/server/tts/backends/__init__.py +82 -0
  170. agent_cli/server/tts/backends/base.py +139 -0
  171. agent_cli/server/tts/backends/kokoro.py +403 -0
  172. agent_cli/server/tts/backends/piper.py +253 -0
  173. agent_cli/server/tts/model_manager.py +201 -0
  174. agent_cli/server/tts/model_registry.py +28 -0
  175. agent_cli/server/tts/wyoming_handler.py +249 -0
  176. agent_cli/server/whisper/__init__.py +3 -0
  177. agent_cli/server/whisper/api.py +413 -0
  178. agent_cli/server/whisper/backends/__init__.py +89 -0
  179. agent_cli/server/whisper/backends/base.py +97 -0
  180. agent_cli/server/whisper/backends/faster_whisper.py +225 -0
  181. agent_cli/server/whisper/backends/mlx.py +270 -0
  182. agent_cli/server/whisper/languages.py +116 -0
  183. agent_cli/server/whisper/model_manager.py +157 -0
  184. agent_cli/server/whisper/model_registry.py +28 -0
  185. agent_cli/server/whisper/wyoming_handler.py +203 -0
  186. agent_cli/services/__init__.py +343 -0
  187. agent_cli/services/_wyoming_utils.py +64 -0
  188. agent_cli/services/asr.py +506 -0
  189. agent_cli/services/llm.py +228 -0
  190. agent_cli/services/tts.py +450 -0
  191. agent_cli/services/wake_word.py +142 -0
  192. agent_cli-0.70.5.dist-info/METADATA +2118 -0
  193. agent_cli-0.70.5.dist-info/RECORD +196 -0
  194. agent_cli-0.70.5.dist-info/WHEEL +4 -0
  195. agent_cli-0.70.5.dist-info/entry_points.txt +4 -0
  196. agent_cli-0.70.5.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,222 @@
1
+ """Common utilities for FastAPI server modules."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import asyncio
6
+ import contextlib
7
+ import importlib
8
+ import logging
9
+ from contextlib import asynccontextmanager
10
+ from typing import TYPE_CHECKING, Any, Protocol
11
+
12
+ from rich.logging import RichHandler
13
+
14
+ from agent_cli import constants
15
+ from agent_cli.core.utils import console
16
+
17
+ if TYPE_CHECKING:
18
+ import wave
19
+ from collections.abc import AsyncIterator, Callable, Coroutine
20
+ from contextlib import AbstractAsyncContextManager
21
+
22
+ from fastapi import FastAPI, Request
23
+
24
+ logger = logging.getLogger(__name__)
25
+
26
+
27
class RegistryProtocol(Protocol):
    """Structural interface a model registry must satisfy.

    Any object that exposes these three coroutine methods conforms;
    explicit inheritance is not required.
    """

    async def start(self) -> None:
        """Bring the registry up."""

    async def stop(self) -> None:
        """Shut the registry down."""

    async def preload(self) -> None:
        """Preload the registry's models."""
41
+
42
+
43
def create_lifespan(
    registry: RegistryProtocol,
    *,
    wyoming_handler_module: str,
    enable_wyoming: bool = True,
    wyoming_uri: str = "tcp://0.0.0.0:10300",
) -> Callable[[FastAPI], AbstractAsyncContextManager[None]]:
    """Create a lifespan context manager for a server.

    On entry the registry is started and, when enabled, the Wyoming server is
    launched as a background task. On exit the Wyoming task is cancelled and
    the registry is stopped. Shutdown runs in ``finally`` blocks, so the
    registry is stopped even when the application body raises or the Wyoming
    task has already crashed.

    Args:
        registry: The model registry to manage.
        wyoming_handler_module: Module path containing start_wyoming_server function.
        enable_wyoming: Whether to start Wyoming server.
        wyoming_uri: URI for Wyoming server.

    Returns:
        A lifespan context manager function for FastAPI.

    """

    @asynccontextmanager
    async def lifespan(_app: FastAPI) -> AsyncIterator[None]:
        """Manage application lifecycle."""
        wyoming_task: asyncio.Task[None] | None = None

        # A failure here propagates as-is: nothing has been started yet, so
        # there is nothing to tear down.
        await registry.start()

        try:
            if enable_wyoming:
                try:
                    # Imported lazily so servers still run when the Wyoming
                    # handler module cannot be imported.
                    module = importlib.import_module(wyoming_handler_module)
                    start_wyoming_server: Callable[
                        [Any, str],
                        Coroutine[Any, Any, None],
                    ] = module.start_wyoming_server

                    wyoming_task = asyncio.create_task(
                        start_wyoming_server(registry, wyoming_uri),
                    )
                except ImportError:
                    logger.warning("Wyoming not available, skipping Wyoming server")
                except Exception:
                    logger.exception("Failed to start Wyoming server")

            yield
        finally:
            try:
                if wyoming_task is not None:
                    wyoming_task.cancel()
                    try:
                        await wyoming_task
                    except asyncio.CancelledError:
                        # Expected result of the cancel() above.
                        pass
                    except Exception:
                        # The task had already failed on its own; report it
                        # instead of letting it abort the rest of shutdown.
                        logger.exception("Wyoming server task failed")
            finally:
                # Always release the registry, whatever happened above.
                await registry.stop()

    return lifespan
100
+
101
+
102
def configure_app(app: FastAPI) -> None:
    """Attach the shared middleware stack to a FastAPI application.

    Registers, in this order:
    - CORS middleware accepting every origin, method and header
    - an HTTP middleware that logs each request

    Args:
        app: The FastAPI application to configure.

    """
    from fastapi.middleware.cors import CORSMiddleware  # noqa: PLC0415

    # NOTE(review): wildcard origins together with allow_credentials=True is
    # very permissive; presumably intended for local use - confirm.
    cors_options: dict[str, Any] = {
        "allow_origins": ["*"],
        "allow_credentials": True,
        "allow_methods": ["*"],
        "allow_headers": ["*"],
    }
    app.add_middleware(CORSMiddleware, **cors_options)

    async def log_requests(request: Any, call_next: Any) -> Any:
        """Log basic request information."""
        return await log_requests_middleware(request, call_next)

    # Same registration the @app.middleware("http") decorator performs.
    app.middleware("http")(log_requests)
129
+
130
+
131
def setup_rich_logging(log_level: str = "info") -> None:
    """Configure logging to use Rich for consistent, pretty output.

    This configures:
    - The root logger to use a single RichHandler (existing handlers are removed)
    - Uvicorn's loggers to use the same handler, without propagating to root

    Args:
        log_level: Logging level name, case-insensitive (debug, info, warning,
            error). Names that do not resolve to a numeric level fall back to
            INFO.

    """
    # getattr alone may hand back a non-level attribute of the logging module
    # (for example a format-string constant), which setLevel would reject, so
    # only integer levels are accepted.
    level = getattr(logging, log_level.upper(), None)
    if not isinstance(level, int):
        level = logging.INFO

    # Create Rich handler with clean format
    handler = RichHandler(
        console=console,
        show_time=True,
        show_level=True,
        show_path=False,  # Don't show file:line - too verbose
        rich_tracebacks=True,
        markup=True,
    )
    # Rich renders time and level itself; the formatter only supplies the text.
    handler.setFormatter(logging.Formatter("%(message)s"))

    # Configure root logger
    root = logging.getLogger()
    root.handlers.clear()
    root.addHandler(handler)
    root.setLevel(level)

    # Configure uvicorn loggers to use same handler
    for uvicorn_logger_name in ("uvicorn", "uvicorn.access", "uvicorn.error"):
        uvicorn_logger = logging.getLogger(uvicorn_logger_name)
        uvicorn_logger.handlers.clear()
        uvicorn_logger.addHandler(handler)
        uvicorn_logger.setLevel(level)
        # The handler is attached directly, so propagating to root would emit
        # every record twice.
        uvicorn_logger.propagate = False
169
+
170
+
171
def setup_wav_file(
    wav_file: wave.Wave_write,
    *,
    rate: int | None = None,
    channels: int | None = None,
    sample_width: int | None = None,
) -> None:
    """Apply channel count, sample width and frame rate to a WAV writer.

    Any argument that is falsy (``None`` or ``0``) is replaced by the
    corresponding project-wide constant.

    Args:
        wav_file: The WAV file writer to configure.
        rate: Sample rate in Hz (default: constants.AUDIO_RATE).
        channels: Number of channels (default: constants.AUDIO_CHANNELS).
        sample_width: Sample width in bytes (default: constants.AUDIO_FORMAT_WIDTH).

    """
    n_channels = channels if channels else constants.AUDIO_CHANNELS
    wav_file.setnchannels(n_channels)

    width_bytes = sample_width if sample_width else constants.AUDIO_FORMAT_WIDTH
    wav_file.setsampwidth(width_bytes)

    frame_rate = rate if rate else constants.AUDIO_RATE
    wav_file.setframerate(frame_rate)
190
+
191
+
192
async def log_requests_middleware(
    request: Request,
    call_next: Any,
) -> Any:
    """Log each request and warn when the response status is 400 or above.

    Intended to be called from a function registered with FastAPI's
    @app.middleware("http") decorator.

    Args:
        request: The incoming request.
        call_next: The next middleware/handler in the chain.

    Returns:
        The response from the next handler.

    """
    # request.client may be absent, in which case the peer is reported as
    # "unknown".
    if request.client:
        client_ip = request.client.host
    else:
        client_ip = "unknown"
    logger.info("%s %s from %s", request.method, request.url.path, client_ip)

    response = await call_next(request)

    if response.status_code < 400:  # noqa: PLR2004
        return response

    logger.warning(
        "Request failed: %s %s → %d",
        request.method,
        request.url.path,
        response.status_code,
    )
    return response
@@ -0,0 +1,288 @@
1
+ """Model manager with TTL-based unloading.
2
+
3
+ This module provides a concrete model manager that handles:
4
+ - Lazy loading of models on first request
5
+ - TTL-based automatic unloading when idle
6
+ - Active request tracking to prevent unload during processing
7
+ - Concurrent request coordination
8
+
9
+ The manager works with any backend that implements the BackendProtocol.
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ import asyncio
15
+ import contextlib
16
+ import logging
17
+ import time
18
+ from contextlib import asynccontextmanager
19
+ from dataclasses import dataclass, field
20
+ from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable
21
+
22
+ if TYPE_CHECKING:
23
+ from collections.abc import AsyncIterator
24
+ from pathlib import Path
25
+
26
+ logger = logging.getLogger(__name__)
27
+
28
+
29
@dataclass
class ModelConfig:
    """Settings for a single managed model.

    Raises:
        ValueError: On construction, when ``ttl_seconds`` is below 1.
    """

    # Name of the model.
    model_name: str
    # Device selector; defaults to "auto".
    device: str = "auto"
    # Idle time-to-live in seconds.
    ttl_seconds: int = 300
    # Optional cache directory.
    cache_dir: Path | None = None

    def __post_init__(self) -> None:
        """Reject a TTL shorter than one second."""
        if not self.ttl_seconds < 1:
            return
        msg = f"ttl_seconds must be >= 1, got {self.ttl_seconds}"
        raise ValueError(msg)
43
+
44
+
45
@dataclass
class ModelStats:
    """Mutable counters and timings collected while a model is in use."""

    # Number of loads and unloads.
    load_count: int = 0
    unload_count: int = 0

    # Cumulative request counters.
    total_requests: int = 0
    total_audio_seconds: float = 0.0
    total_processing_seconds: float = 0.0

    # Timing fields; None until a value has been recorded.
    last_load_time: float | None = None
    last_request_time: float | None = None
    load_duration_seconds: float | None = None

    # Additional named float metrics; each instance gets its own dict.
    extra: dict[str, float] = field(default_factory=dict)
58
+
59
+
60
@runtime_checkable
class BackendProtocol(Protocol):
    """Structural interface a model backend must provide.

    Decorated with ``runtime_checkable`` so conformance can be tested with
    ``isinstance``.
    """

    @property
    def is_loaded(self) -> bool:
        """Whether the model is currently loaded."""

    @property
    def device(self) -> str | None:
        """Device the model is loaded on, if any."""

    async def load(self) -> float:
        """Load the model and return the load duration in seconds."""

    async def unload(self) -> None:
        """Unload the model."""
81
+
82
+
83
class ModelManager:
    """Manages a model with TTL-based unloading.

    The model is loaded lazily on first request and unloaded after
    being idle for longer than the configured TTL.

    All state changes (loading, unloading, request counting) happen while
    holding a single ``asyncio.Condition``, which also serves to wake
    coroutines waiting for an unload to finish or for active requests to drain.

    Usage:
        manager = ModelManager(backend, config)
        await manager.start()

        # Use request context for processing
        async with manager.request():
            result = await backend.process(...)

        await manager.stop()
    """

    def __init__(
        self,
        backend: BackendProtocol,
        config: ModelConfig,
        stats: ModelStats | None = None,
    ) -> None:
        """Initialize the model manager.

        Args:
            backend: The backend instance to manage.
            config: Model configuration.
            stats: Optional stats instance (creates new one if not provided).

        """
        self.backend = backend
        self.config = config
        self.stats = ModelStats() if stats is None else stats
        self._condition = asyncio.Condition()
        self._active_requests = 0
        # True while unload() is waiting for requests to drain or unloading.
        self._unloading = False
        self._unload_task: asyncio.Task[None] | None = None
        self._shutdown = False

    @property
    def is_loaded(self) -> bool:
        """Check if the model is currently loaded."""
        return self.backend.is_loaded

    @property
    def device(self) -> str | None:
        """Get the device the model is loaded on."""
        return self.backend.device

    @property
    def active_requests(self) -> int:
        """Get the number of active requests."""
        return self._active_requests

    @property
    def ttl_remaining(self) -> float | None:
        """Get seconds remaining before model unloads, or None if not loaded."""
        if not self.is_loaded or self.stats.last_request_time is None:
            return None
        elapsed = time.time() - self.stats.last_request_time
        remaining = self.config.ttl_seconds - elapsed
        return max(0.0, remaining)

    async def start(self) -> None:
        """Start the TTL unload watcher.

        Calling this while a watcher is already running is a no-op. Calling
        it after ``stop()`` clears the shutdown flag and starts a new watcher.
        """
        if self._unload_task is not None and not self._unload_task.done():
            return
        # stop() leaves the flag set; without resetting it a new watcher
        # would exit immediately and TTL unloading would never happen again.
        self._shutdown = False
        self._unload_task = asyncio.create_task(self._unload_watcher())

    async def stop(self) -> None:
        """Stop the manager and unload the model."""
        self._shutdown = True
        if self._unload_task is not None:
            self._unload_task.cancel()
            with contextlib.suppress(asyncio.CancelledError):
                await self._unload_task
            self._unload_task = None
        await self.unload()

    async def get_model(self) -> Any:
        """Get the backend, loading it if necessary."""
        await self._ensure_loaded()
        return self.backend

    @asynccontextmanager
    async def request(self) -> AsyncIterator[None]:
        """Context manager for processing requests.

        Ensures the model is loaded and tracks active requests.
        Use this around any backend operations.

        Example:
            async with manager.request():
                result = await manager.backend.synthesize(text)

        """
        await self._begin_request()
        try:
            yield
        finally:
            await self._end_request()

    async def unload(self) -> bool:
        """Unload the model from memory.

        Waits for any unload already in progress and for all active requests
        to finish before unloading.

        Returns True if model was unloaded, False if it wasn't loaded.
        """
        async with self._condition:
            # Let a concurrent unload() finish first.
            while self._unloading:
                await self._condition.wait()

            if not self.backend.is_loaded:
                return False

            # From here on new requests block in _begin_request until the
            # flag is cleared in the finally clause.
            self._unloading = True
            try:
                while self._active_requests > 0:
                    logger.info(
                        "Waiting for %d active requests before unloading %s",
                        self._active_requests,
                        self.config.model_name,
                    )
                    await self._condition.wait()

                # wait() releases the lock, so re-check the state.
                if not self.backend.is_loaded:
                    return False

                await self.backend.unload()
                self.stats.unload_count += 1
                return True
            finally:
                self._unloading = False
                self._condition.notify_all()

    async def _load_if_needed_locked(self) -> None:
        """Load the model if needed (expects condition lock held)."""
        if not self.backend.is_loaded:
            load_duration = await self.backend.load()
            self.stats.load_count += 1
            self.stats.last_load_time = time.time()
            self.stats.load_duration_seconds = load_duration
        # Refresh the idle timestamp on every call, loaded or not.
        self.stats.last_request_time = time.time()

    async def _ensure_loaded(self) -> None:
        """Ensure the model is loaded."""
        async with self._condition:
            while self._unloading:
                await self._condition.wait()
            await self._load_if_needed_locked()

    async def _begin_request(self) -> None:
        """Begin a request, waiting if unload is in progress."""
        async with self._condition:
            while self._unloading:
                await self._condition.wait()
            await self._load_if_needed_locked()
            # Counted only after a successful load, so a failed load leaves
            # the counter untouched.
            self._active_requests += 1

    async def _end_request(self) -> None:
        """End a request and notify waiters if no more active requests."""
        async with self._condition:
            self._active_requests -= 1
            self.stats.last_request_time = time.time()
            if self._active_requests == 0:
                self._condition.notify_all()

    async def _unload_if_idle(self) -> bool:
        """Unload the model if it has been idle for at least the TTL.

        Returns True if the model was unloaded, False otherwise.
        """
        async with self._condition:
            if self._unloading or not self.backend.is_loaded:
                return False

            last_request_time = self.stats.last_request_time
            if last_request_time is None:
                return False

            idle_time = time.time() - last_request_time
            if idle_time < self.config.ttl_seconds:
                return False

            if self._active_requests != 0:
                logger.debug(
                    "Model %s would unload but has %d active requests",
                    self.config.model_name,
                    self._active_requests,
                )
                return False

            logger.info(
                "Model %s idle for %.0fs (ttl=%ds), unloading",
                self.config.model_name,
                idle_time,
                self.config.ttl_seconds,
            )
            # The lock is held for the whole unload, so no request can start
            # in the meantime.
            await self.backend.unload()
            self.stats.unload_count += 1
            return True

    async def _unload_watcher(self) -> None:
        """Background task that unloads model after TTL expires."""
        # Check at least twice per TTL period, and at least every 30 seconds.
        check_interval = min(30, self.config.ttl_seconds / 2)

        while not self._shutdown:
            try:
                await asyncio.sleep(check_interval)
                await self._unload_if_idle()
            except asyncio.CancelledError:
                # stop() cancels this task; leave the loop quietly.
                break
            except Exception:
                # Keep the watcher alive across backend errors.
                logger.exception("Error in unload watcher")