agent-cli 0.70.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (196)
  1. agent_cli/__init__.py +5 -0
  2. agent_cli/__main__.py +6 -0
  3. agent_cli/_extras.json +14 -0
  4. agent_cli/_requirements/.gitkeep +0 -0
  5. agent_cli/_requirements/audio.txt +79 -0
  6. agent_cli/_requirements/faster-whisper.txt +215 -0
  7. agent_cli/_requirements/kokoro.txt +425 -0
  8. agent_cli/_requirements/llm.txt +183 -0
  9. agent_cli/_requirements/memory.txt +355 -0
  10. agent_cli/_requirements/mlx-whisper.txt +222 -0
  11. agent_cli/_requirements/piper.txt +176 -0
  12. agent_cli/_requirements/rag.txt +402 -0
  13. agent_cli/_requirements/server.txt +154 -0
  14. agent_cli/_requirements/speed.txt +77 -0
  15. agent_cli/_requirements/vad.txt +155 -0
  16. agent_cli/_requirements/wyoming.txt +71 -0
  17. agent_cli/_tools.py +368 -0
  18. agent_cli/agents/__init__.py +23 -0
  19. agent_cli/agents/_voice_agent_common.py +136 -0
  20. agent_cli/agents/assistant.py +383 -0
  21. agent_cli/agents/autocorrect.py +284 -0
  22. agent_cli/agents/chat.py +496 -0
  23. agent_cli/agents/memory/__init__.py +31 -0
  24. agent_cli/agents/memory/add.py +190 -0
  25. agent_cli/agents/memory/proxy.py +160 -0
  26. agent_cli/agents/rag_proxy.py +128 -0
  27. agent_cli/agents/speak.py +209 -0
  28. agent_cli/agents/transcribe.py +671 -0
  29. agent_cli/agents/transcribe_daemon.py +499 -0
  30. agent_cli/agents/voice_edit.py +291 -0
  31. agent_cli/api.py +22 -0
  32. agent_cli/cli.py +106 -0
  33. agent_cli/config.py +503 -0
  34. agent_cli/config_cmd.py +307 -0
  35. agent_cli/constants.py +27 -0
  36. agent_cli/core/__init__.py +1 -0
  37. agent_cli/core/audio.py +461 -0
  38. agent_cli/core/audio_format.py +299 -0
  39. agent_cli/core/chroma.py +88 -0
  40. agent_cli/core/deps.py +191 -0
  41. agent_cli/core/openai_proxy.py +139 -0
  42. agent_cli/core/process.py +195 -0
  43. agent_cli/core/reranker.py +120 -0
  44. agent_cli/core/sse.py +87 -0
  45. agent_cli/core/transcription_logger.py +70 -0
  46. agent_cli/core/utils.py +526 -0
  47. agent_cli/core/vad.py +175 -0
  48. agent_cli/core/watch.py +65 -0
  49. agent_cli/dev/__init__.py +14 -0
  50. agent_cli/dev/cli.py +1588 -0
  51. agent_cli/dev/coding_agents/__init__.py +19 -0
  52. agent_cli/dev/coding_agents/aider.py +24 -0
  53. agent_cli/dev/coding_agents/base.py +167 -0
  54. agent_cli/dev/coding_agents/claude.py +39 -0
  55. agent_cli/dev/coding_agents/codex.py +24 -0
  56. agent_cli/dev/coding_agents/continue_dev.py +15 -0
  57. agent_cli/dev/coding_agents/copilot.py +24 -0
  58. agent_cli/dev/coding_agents/cursor_agent.py +48 -0
  59. agent_cli/dev/coding_agents/gemini.py +28 -0
  60. agent_cli/dev/coding_agents/opencode.py +15 -0
  61. agent_cli/dev/coding_agents/registry.py +49 -0
  62. agent_cli/dev/editors/__init__.py +19 -0
  63. agent_cli/dev/editors/base.py +89 -0
  64. agent_cli/dev/editors/cursor.py +15 -0
  65. agent_cli/dev/editors/emacs.py +46 -0
  66. agent_cli/dev/editors/jetbrains.py +56 -0
  67. agent_cli/dev/editors/nano.py +31 -0
  68. agent_cli/dev/editors/neovim.py +33 -0
  69. agent_cli/dev/editors/registry.py +59 -0
  70. agent_cli/dev/editors/sublime.py +20 -0
  71. agent_cli/dev/editors/vim.py +42 -0
  72. agent_cli/dev/editors/vscode.py +15 -0
  73. agent_cli/dev/editors/zed.py +20 -0
  74. agent_cli/dev/project.py +568 -0
  75. agent_cli/dev/registry.py +52 -0
  76. agent_cli/dev/skill/SKILL.md +141 -0
  77. agent_cli/dev/skill/examples.md +571 -0
  78. agent_cli/dev/terminals/__init__.py +19 -0
  79. agent_cli/dev/terminals/apple_terminal.py +82 -0
  80. agent_cli/dev/terminals/base.py +56 -0
  81. agent_cli/dev/terminals/gnome.py +51 -0
  82. agent_cli/dev/terminals/iterm2.py +84 -0
  83. agent_cli/dev/terminals/kitty.py +77 -0
  84. agent_cli/dev/terminals/registry.py +48 -0
  85. agent_cli/dev/terminals/tmux.py +58 -0
  86. agent_cli/dev/terminals/warp.py +132 -0
  87. agent_cli/dev/terminals/zellij.py +78 -0
  88. agent_cli/dev/worktree.py +856 -0
  89. agent_cli/docs_gen.py +417 -0
  90. agent_cli/example-config.toml +185 -0
  91. agent_cli/install/__init__.py +5 -0
  92. agent_cli/install/common.py +89 -0
  93. agent_cli/install/extras.py +174 -0
  94. agent_cli/install/hotkeys.py +48 -0
  95. agent_cli/install/services.py +87 -0
  96. agent_cli/memory/__init__.py +7 -0
  97. agent_cli/memory/_files.py +250 -0
  98. agent_cli/memory/_filters.py +63 -0
  99. agent_cli/memory/_git.py +157 -0
  100. agent_cli/memory/_indexer.py +142 -0
  101. agent_cli/memory/_ingest.py +408 -0
  102. agent_cli/memory/_persistence.py +182 -0
  103. agent_cli/memory/_prompt.py +91 -0
  104. agent_cli/memory/_retrieval.py +294 -0
  105. agent_cli/memory/_store.py +169 -0
  106. agent_cli/memory/_streaming.py +44 -0
  107. agent_cli/memory/_tasks.py +48 -0
  108. agent_cli/memory/api.py +113 -0
  109. agent_cli/memory/client.py +272 -0
  110. agent_cli/memory/engine.py +361 -0
  111. agent_cli/memory/entities.py +43 -0
  112. agent_cli/memory/models.py +112 -0
  113. agent_cli/opts.py +433 -0
  114. agent_cli/py.typed +0 -0
  115. agent_cli/rag/__init__.py +3 -0
  116. agent_cli/rag/_indexer.py +67 -0
  117. agent_cli/rag/_indexing.py +226 -0
  118. agent_cli/rag/_prompt.py +30 -0
  119. agent_cli/rag/_retriever.py +156 -0
  120. agent_cli/rag/_store.py +48 -0
  121. agent_cli/rag/_utils.py +218 -0
  122. agent_cli/rag/api.py +175 -0
  123. agent_cli/rag/client.py +299 -0
  124. agent_cli/rag/engine.py +302 -0
  125. agent_cli/rag/models.py +55 -0
  126. agent_cli/scripts/.runtime/.gitkeep +0 -0
  127. agent_cli/scripts/__init__.py +1 -0
  128. agent_cli/scripts/check_plugin_skill_sync.py +50 -0
  129. agent_cli/scripts/linux-hotkeys/README.md +63 -0
  130. agent_cli/scripts/linux-hotkeys/toggle-autocorrect.sh +45 -0
  131. agent_cli/scripts/linux-hotkeys/toggle-transcription.sh +58 -0
  132. agent_cli/scripts/linux-hotkeys/toggle-voice-edit.sh +58 -0
  133. agent_cli/scripts/macos-hotkeys/README.md +45 -0
  134. agent_cli/scripts/macos-hotkeys/skhd-config-example +5 -0
  135. agent_cli/scripts/macos-hotkeys/toggle-autocorrect.sh +12 -0
  136. agent_cli/scripts/macos-hotkeys/toggle-transcription.sh +37 -0
  137. agent_cli/scripts/macos-hotkeys/toggle-voice-edit.sh +37 -0
  138. agent_cli/scripts/nvidia-asr-server/README.md +99 -0
  139. agent_cli/scripts/nvidia-asr-server/pyproject.toml +27 -0
  140. agent_cli/scripts/nvidia-asr-server/server.py +255 -0
  141. agent_cli/scripts/nvidia-asr-server/shell.nix +32 -0
  142. agent_cli/scripts/nvidia-asr-server/uv.lock +4654 -0
  143. agent_cli/scripts/run-openwakeword.sh +11 -0
  144. agent_cli/scripts/run-piper-windows.ps1 +30 -0
  145. agent_cli/scripts/run-piper.sh +24 -0
  146. agent_cli/scripts/run-whisper-linux.sh +40 -0
  147. agent_cli/scripts/run-whisper-macos.sh +6 -0
  148. agent_cli/scripts/run-whisper-windows.ps1 +51 -0
  149. agent_cli/scripts/run-whisper.sh +9 -0
  150. agent_cli/scripts/run_faster_whisper_server.py +136 -0
  151. agent_cli/scripts/setup-linux-hotkeys.sh +72 -0
  152. agent_cli/scripts/setup-linux.sh +108 -0
  153. agent_cli/scripts/setup-macos-hotkeys.sh +61 -0
  154. agent_cli/scripts/setup-macos.sh +76 -0
  155. agent_cli/scripts/setup-windows.ps1 +63 -0
  156. agent_cli/scripts/start-all-services-windows.ps1 +53 -0
  157. agent_cli/scripts/start-all-services.sh +178 -0
  158. agent_cli/scripts/sync_extras.py +138 -0
  159. agent_cli/server/__init__.py +3 -0
  160. agent_cli/server/cli.py +721 -0
  161. agent_cli/server/common.py +222 -0
  162. agent_cli/server/model_manager.py +288 -0
  163. agent_cli/server/model_registry.py +225 -0
  164. agent_cli/server/proxy/__init__.py +3 -0
  165. agent_cli/server/proxy/api.py +444 -0
  166. agent_cli/server/streaming.py +67 -0
  167. agent_cli/server/tts/__init__.py +3 -0
  168. agent_cli/server/tts/api.py +335 -0
  169. agent_cli/server/tts/backends/__init__.py +82 -0
  170. agent_cli/server/tts/backends/base.py +139 -0
  171. agent_cli/server/tts/backends/kokoro.py +403 -0
  172. agent_cli/server/tts/backends/piper.py +253 -0
  173. agent_cli/server/tts/model_manager.py +201 -0
  174. agent_cli/server/tts/model_registry.py +28 -0
  175. agent_cli/server/tts/wyoming_handler.py +249 -0
  176. agent_cli/server/whisper/__init__.py +3 -0
  177. agent_cli/server/whisper/api.py +413 -0
  178. agent_cli/server/whisper/backends/__init__.py +89 -0
  179. agent_cli/server/whisper/backends/base.py +97 -0
  180. agent_cli/server/whisper/backends/faster_whisper.py +225 -0
  181. agent_cli/server/whisper/backends/mlx.py +270 -0
  182. agent_cli/server/whisper/languages.py +116 -0
  183. agent_cli/server/whisper/model_manager.py +157 -0
  184. agent_cli/server/whisper/model_registry.py +28 -0
  185. agent_cli/server/whisper/wyoming_handler.py +203 -0
  186. agent_cli/services/__init__.py +343 -0
  187. agent_cli/services/_wyoming_utils.py +64 -0
  188. agent_cli/services/asr.py +506 -0
  189. agent_cli/services/llm.py +228 -0
  190. agent_cli/services/tts.py +450 -0
  191. agent_cli/services/wake_word.py +142 -0
  192. agent_cli-0.70.5.dist-info/METADATA +2118 -0
  193. agent_cli-0.70.5.dist-info/RECORD +196 -0
  194. agent_cli-0.70.5.dist-info/WHEEL +4 -0
  195. agent_cli-0.70.5.dist-info/entry_points.txt +4 -0
  196. agent_cli-0.70.5.dist-info/licenses/LICENSE +21 -0
agent_cli/core/openai_proxy.py ADDED
@@ -0,0 +1,139 @@
+ """Shared OpenAI-compatible forwarding helpers."""
+
+ from __future__ import annotations
+
+ import json
+ import logging
+ from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable
+
+ if TYPE_CHECKING:
+     from collections.abc import AsyncGenerator, Iterable
+
+     from fastapi import Request, Response
+
+ LOGGER = logging.getLogger(__name__)
+
+
+ @runtime_checkable
+ class ChatRequestLike(Protocol):
+     """Minimal interface required to forward a chat request."""
+
+     stream: bool | None
+
+     def model_dump(self, *, exclude: set[str] | None = None) -> dict[str, Any]:
+         """Serialize request to a dict for forwarding."""
+
+
+ async def proxy_request_to_upstream(
+     request: Request,
+     path: str,
+     upstream_base_url: str,
+     api_key: str | None = None,
+ ) -> Response:
+     """Forward a raw HTTP request to an upstream OpenAI-compatible provider."""
+     import httpx  # noqa: PLC0415
+     from fastapi import Response  # noqa: PLC0415
+
+     auth_header = request.headers.get("Authorization")
+     headers = {}
+     if auth_header:
+         headers["Authorization"] = auth_header
+     elif api_key:
+         headers["Authorization"] = f"Bearer {api_key}"
+
+     if request.headers.get("Content-Type"):
+         headers["Content-Type"] = request.headers.get("Content-Type")
+
+     base = upstream_base_url.rstrip("/")
+     target_path = path
+
+     # Smart path joining to avoid /v1/v1/ if base already has it
+     if base.endswith("/v1") and (path == "v1" or path.startswith("v1/")):
+         target_path = path[2:].lstrip("/")
+
+     url = f"{base}/{target_path}"
+
+     try:
+         body = await request.body()
+         async with httpx.AsyncClient(timeout=60.0) as http:
+             req = http.build_request(
+                 request.method,
+                 url,
+                 headers=headers,
+                 content=body,
+                 params=request.query_params,
+             )
+             resp = await http.send(req)
+
+             return Response(
+                 content=resp.content,
+                 status_code=resp.status_code,
+                 media_type=resp.headers.get("Content-Type"),
+             )
+     except Exception:
+         LOGGER.warning("Proxy request failed to %s", url, exc_info=True)
+         return Response(status_code=502, content="Upstream Proxy Error")
+
+
+ async def forward_chat_request(
+     request: ChatRequestLike,
+     openai_base_url: str,
+     api_key: str | None = None,
+     *,
+     exclude_fields: Iterable[str] = (),
+ ) -> Any:
+     """Forward a chat request to a backend LLM."""
+     import httpx  # noqa: PLC0415
+     from fastapi import HTTPException  # noqa: PLC0415
+     from fastapi.responses import StreamingResponse  # noqa: PLC0415
+
+     forward_payload = request.model_dump(exclude=set(exclude_fields))
+     headers = {"Authorization": f"Bearer {api_key}"} if api_key else None
+
+     if getattr(request, "stream", False):
+
+         async def generate() -> AsyncGenerator[str, None]:
+             try:
+                 async with (
+                     httpx.AsyncClient(timeout=120.0) as client,
+                     client.stream(
+                         "POST",
+                         f"{openai_base_url.rstrip('/')}/chat/completions",
+                         json=forward_payload,
+                         headers=headers,
+                     ) as response,
+                 ):
+                     if response.status_code != 200:  # noqa: PLR2004
+                         error_text = await response.aread()
+                         yield f"data: {json.dumps({'error': str(error_text)})}\n\n"
+                         return
+
+                     async for chunk in response.aiter_raw():
+                         if isinstance(chunk, bytes):
+                             yield chunk.decode("utf-8")
+                         else:
+                             yield chunk
+             except Exception as exc:
+                 LOGGER.exception("Streaming error")
+                 yield f"data: {json.dumps({'error': str(exc)})}\n\n"
+
+         return StreamingResponse(generate(), media_type="text/event-stream")
+
+     async with httpx.AsyncClient(timeout=120.0) as client:
+         response = await client.post(
+             f"{openai_base_url.rstrip('/')}/chat/completions",
+             json=forward_payload,
+             headers=headers,
+         )
+         if response.status_code != 200:  # noqa: PLR2004
+             LOGGER.error(
+                 "Upstream error %s: %s",
+                 response.status_code,
+                 response.text,
+             )
+             raise HTTPException(
+                 status_code=response.status_code,
+                 detail=f"Upstream error: {response.text}",
+             )
+
+         return response.json()
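
Usage sketch (not part of the diff): `forward_chat_request` accepts any object satisfying `ChatRequestLike`, and a pydantic model qualifies since it provides `model_dump`. The route path, upstream URL, and the `tag` field below are illustrative assumptions, not taken from the package.

# Hypothetical illustration -- the route, upstream URL, and ChatRequest
# fields are assumptions for the sketch, not package code.
from fastapi import FastAPI
from pydantic import BaseModel

from agent_cli.core.openai_proxy import forward_chat_request

app = FastAPI()


class ChatRequest(BaseModel):
    model: str
    messages: list[dict]
    stream: bool | None = None
    tag: str | None = None  # illustrative extra field, stripped before forwarding


@app.post("/v1/chat/completions")
async def chat(req: ChatRequest):
    # Returns a StreamingResponse when req.stream is true, else the upstream JSON body.
    return await forward_chat_request(
        req,
        "http://localhost:11434/v1",  # any OpenAI-compatible upstream
        exclude_fields=["tag"],
    )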
agent_cli/core/process.py ADDED
@@ -0,0 +1,195 @@
+ """Process management utilities for Agent CLI tools."""
+
+ from __future__ import annotations
+
+ import os
+ import signal
+ import sys
+ import time
+ from contextlib import contextmanager
+ from pathlib import Path
+ from typing import TYPE_CHECKING
+
+ if TYPE_CHECKING:
+     from collections.abc import Generator
+
+ # Default location for PID files
+ PID_DIR = Path.home() / ".cache" / "agent-cli"
+
+ # Store the original process title before any modifications
+ _original_proctitle: str | None = None
+
+
+ def set_process_title(process_name: str) -> None:
+     """Set the process title and thread name for identification in ps/htop/btop.
+
+     Sets both:
+     - Process title: 'agent-cli-{name} ({original})' - identifiable prefix + original command
+     - Thread name: 'ag-{name}' (max 15 chars) - shown as program name in btop/htop
+
+     The original command line is captured on first call and reused on subsequent
+     calls to prevent nested titles like 'agent-cli-x (agent-cli-y (...))'.
+
+     Args:
+         process_name: The name of the process (e.g., 'transcribe', 'chat').
+
+     """
+     import setproctitle  # noqa: PLC0415
+
+     global _original_proctitle
+
+     # Capture the original command line only once, before any modification
+     if _original_proctitle is None:
+         _original_proctitle = setproctitle.getproctitle()
+
+     # Set the full process title: identifiable prefix + original command for debugging
+     setproctitle.setproctitle(f"agent-cli-{process_name} ({_original_proctitle})")
+
+     # Set the thread name (program name in htop/btop, limited to 15 chars on Linux)
+     # Use shorter prefix "ag-" to fit more of the command name
+     thread_name = f"ag-{process_name}"[:15]
+     setproctitle.setthreadtitle(thread_name)
+
+
+ def _get_pid_file(process_name: str) -> Path:
+     """Get the path to the PID file for a given process name."""
+     PID_DIR.mkdir(parents=True, exist_ok=True)
+     return PID_DIR / f"{process_name}.pid"
+
+
+ def _get_stop_file(process_name: str) -> Path:
+     """Get the path to the stop file for a given process name."""
+     PID_DIR.mkdir(parents=True, exist_ok=True)
+     return PID_DIR / f"{process_name}.stop"
+
+
+ def check_stop_file(process_name: str) -> bool:
+     """Check if a stop file exists (used for cross-process signaling on Windows)."""
+     return _get_stop_file(process_name).exists()
+
+
+ def clear_stop_file(process_name: str) -> None:
+     """Remove the stop file for the given process."""
+     stop_file = _get_stop_file(process_name)
+     if stop_file.exists():
+         stop_file.unlink()
+
+
+ def _is_pid_running(pid: int) -> bool:
+     """Check if a process with the given PID is running."""
+     if sys.platform == "win32":
+         # On Windows, os.kill(pid, 0) would terminate the process!
+         import psutil  # noqa: PLC0415
+
+         return psutil.pid_exists(pid)
+     try:
+         os.kill(pid, 0)
+         return True
+     except (ProcessLookupError, PermissionError):
+         return False
+
+
+ def _get_running_pid(process_name: str) -> int | None:
+     """Get PID if process is running, None otherwise. Cleans up stale files."""
+     pid_file = _get_pid_file(process_name)
+
+     if not pid_file.exists():
+         return None
+
+     try:
+         with pid_file.open() as f:
+             pid = int(f.read().strip())
+
+         # Check if process is actually running
+         if _is_pid_running(pid):
+             return pid
+
+     except (FileNotFoundError, ValueError):
+         pass
+
+     # Clean up stale/invalid PID file
+     if pid_file.exists():
+         pid_file.unlink()
+     return None
+
+
+ def is_process_running(process_name: str) -> bool:
+     """Check if a process is currently running."""
+     return _get_running_pid(process_name) is not None
+
+
+ def read_pid_file(process_name: str) -> int | None:
+     """Read PID from file if process is running."""
+     return _get_running_pid(process_name)
+
+
+ def kill_process(process_name: str) -> bool:
+     """Kill a process by name.
+
+     Returns True if killed or cleaned up, False if not found.
+     On Windows, creates a stop file first to allow graceful shutdown.
+     """
+     pid_file = _get_pid_file(process_name)
+
+     # If no PID file exists at all, nothing to do
+     if not pid_file.exists():
+         return False
+
+     # Check if we have a running process
+     pid = _get_running_pid(process_name)
+
+     # If _get_running_pid returned None but file existed, it cleaned up a stale file
+     if pid is None:
+         return True
+
+     # On Windows, create stop file to signal graceful shutdown
+     if sys.platform == "win32":
+         _get_stop_file(process_name).touch()
+
+     # Send SIGINT for graceful shutdown
+     try:
+         os.kill(pid, signal.SIGINT)
+         # Wait for process to terminate
+         for _ in range(10):  # 1 second max
+             if not is_process_running(process_name):
+                 break
+             time.sleep(0.1)
+     except (ProcessLookupError, PermissionError):
+         pass  # Process dead or no permission - we'll clean up regardless
+
+     # Clean up
+     if sys.platform == "win32":
+         clear_stop_file(process_name)
+     if pid_file.exists():
+         pid_file.unlink()
+
+     return True
+
+
+ @contextmanager
+ def pid_file_context(process_name: str) -> Generator[Path, None, None]:
+     """Context manager for PID file lifecycle.
+
+     Creates PID file on entry, cleans up on exit.
+     Exits with error if process already running.
+     """
+     if is_process_running(process_name):
+         existing_pid = _get_running_pid(process_name)
+         print(f"Process {process_name} is already running (PID: {existing_pid})")
+         sys.exit(1)
+
+     # Clear any stale stop file from previous run (Windows only)
+     if sys.platform == "win32":
+         clear_stop_file(process_name)
+
+     pid_file = _get_pid_file(process_name)
+     with pid_file.open("w") as f:
+         f.write(str(os.getpid()))
+
+     try:
+         yield pid_file
+     finally:
+         if pid_file.exists():
+             pid_file.unlink()
+         if sys.platform == "win32":
+             clear_stop_file(process_name)
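
Usage sketch (not part of the diff): a minimal single-instance daemon built from the functions above. The process name "demo" and the work loop are assumptions for illustration.

# Hypothetical illustration -- 'demo' is an arbitrary process name chosen
# for the sketch; only the imported functions come from the module above.
from agent_cli.core.process import (
    is_process_running,
    kill_process,
    pid_file_context,
    set_process_title,
)


def main() -> None:
    set_process_title("demo")  # appears as 'agent-cli-demo (...)' in ps/htop
    with pid_file_context("demo"):  # writes ~/.cache/agent-cli/demo.pid
        ...  # do work; a second 'demo' invocation exits with an error

# From another process, the same name addresses the daemon:
#   is_process_running("demo")  -> True while main() runs
#   kill_process("demo")        -> sends SIGINT, waits, cleans up the PID file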
agent_cli/core/reranker.py ADDED
@@ -0,0 +1,120 @@
+ """Shared ONNX Cross-Encoder for reranking (used by both RAG and Memory)."""
+
+ from __future__ import annotations
+
+ import logging
+
+ LOGGER = logging.getLogger(__name__)
+
+
+ def _download_onnx_model(model_name: str, onnx_filename: str) -> str:
+     """Download the ONNX model, favoring the common `onnx/` folder layout."""
+     from huggingface_hub import hf_hub_download  # noqa: PLC0415
+
+     if "/" in onnx_filename:
+         return hf_hub_download(repo_id=model_name, filename=onnx_filename)
+
+     try:
+         return hf_hub_download(repo_id=model_name, filename=onnx_filename, subfolder="onnx")
+     except Exception as first_error:
+         LOGGER.debug(
+             "ONNX file not found under onnx/ for %s: %s. Falling back to repo root.",
+             model_name,
+             first_error,
+         )
+         try:
+             return hf_hub_download(repo_id=model_name, filename=onnx_filename)
+         except Exception as second_error:
+             LOGGER.exception(
+                 "Failed to download ONNX model %s (filename=%s)",
+                 model_name,
+                 onnx_filename,
+                 exc_info=second_error,
+             )
+             raise
+
+
+ class OnnxCrossEncoder:
+     """A lightweight CrossEncoder using ONNX Runtime."""
+
+     def __init__(
+         self,
+         model_name: str = "Xenova/ms-marco-MiniLM-L-6-v2",
+         onnx_filename: str = "model.onnx",
+     ) -> None:
+         """Initialize the ONNX CrossEncoder."""
+         from onnxruntime import InferenceSession  # noqa: PLC0415
+         from transformers import AutoTokenizer  # noqa: PLC0415
+
+         self.model_name = model_name
+
+         # Download model if needed
+         LOGGER.info("Loading ONNX model: %s", model_name)
+         model_path = _download_onnx_model(model_name, onnx_filename)
+
+         self.session = InferenceSession(model_path)
+         self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+     def predict(
+         self,
+         pairs: list[tuple[str, str]],
+         batch_size: int = 32,
+     ) -> list[float]:
+         """Predict relevance scores for query-document pairs."""
+         import numpy as np  # noqa: PLC0415
+
+         if not pairs:
+             return []
+
+         all_scores = []
+
+         # Process in batches
+         for i in range(0, len(pairs), batch_size):
+             batch = pairs[i : i + batch_size]
+             queries = [q for q, d in batch]
+             docs = [d for q, d in batch]
+
+             # Tokenize
+             inputs = self.tokenizer(
+                 queries,
+                 docs,
+                 padding=True,
+                 truncation=True,
+                 return_tensors="np",
+                 max_length=512,
+             )
+
+             # ONNX inputs: models usually expect input_ids, attention_mask,
+             # and token_type_ids, but some models do not need token_type_ids.
+             ort_inputs = {
+                 "input_ids": inputs["input_ids"].astype(np.int64),
+                 "attention_mask": inputs["attention_mask"].astype(np.int64),
+             }
+             if "token_type_ids" in inputs:
+                 ort_inputs["token_type_ids"] = inputs["token_type_ids"].astype(np.int64)
+
+             # Run inference
+             logits = self.session.run(None, ort_inputs)[0]
+
+             # Extract scores (usually shape [batch, 1] or [batch])
+             batch_scores = logits.flatten() if logits.ndim > 1 else logits
+
+             all_scores.extend(batch_scores.tolist())
+
+         return all_scores
+
+
+ def get_reranker_model(
+     model_name: str = "Xenova/ms-marco-MiniLM-L-6-v2",
+ ) -> OnnxCrossEncoder:
+     """Load the CrossEncoder model."""
+     return OnnxCrossEncoder(model_name)
+
+
+ def predict_relevance(
+     model: OnnxCrossEncoder,
+     pairs: list[tuple[str, str]],
+ ) -> list[float]:
+     """Predict relevance scores for query-document pairs."""
+     return model.predict(pairs)
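
Usage sketch (not part of the diff): the scores are raw cross-encoder logits where higher means more relevant, so reranking is a sort over (document, score) pairs. The query and documents are made up for illustration.

# Hypothetical illustration -- the query and documents are invented;
# only the imported functions come from the module above.
from agent_cli.core.reranker import get_reranker_model, predict_relevance

model = get_reranker_model()  # downloads Xenova/ms-marco-MiniLM-L-6-v2 on first use
query = "how do I reset my password?"
docs = ["Resetting your password via email", "Pricing and billing FAQ"]

scores = predict_relevance(model, [(query, d) for d in docs])
ranked = sorted(zip(docs, scores), key=lambda pair: pair[1], reverse=True)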
agent_cli/core/sse.py ADDED
@@ -0,0 +1,87 @@
+ """Shared SSE (Server-Sent Events) formatting helpers for OpenAI-compatible streaming."""
+
+ from __future__ import annotations
+
+ import json
+ import time
+ from typing import Any
+
+
+ def format_chunk(
+     run_id: str,
+     model: str,
+     *,
+     content: str | None = None,
+     finish_reason: str | None = None,
+     extra: dict[str, Any] | None = None,
+ ) -> str:
+     """Format a single SSE chunk in OpenAI chat.completion.chunk format.
+
+     Args:
+         run_id: Unique identifier for this completion.
+         model: Model name to include in response.
+         content: Text content delta (None for finish chunk).
+         finish_reason: Reason for completion (e.g., "stop").
+         extra: Additional fields to include in the response.
+
+     Returns:
+         Formatted SSE data line.
+
+     """
+     data: dict[str, Any] = {
+         "id": f"chatcmpl-{run_id}",
+         "object": "chat.completion.chunk",
+         "created": int(time.time()),
+         "model": model,
+         "choices": [
+             {
+                 "index": 0,
+                 "delta": {"content": content} if content else {},
+                 "finish_reason": finish_reason,
+             },
+         ],
+     }
+     if extra:
+         data.update(extra)
+     return f"data: {json.dumps(data)}\n\n"
+
+
+ def format_done() -> str:
+     """Format the terminal [DONE] SSE message."""
+     return "data: [DONE]\n\n"
+
+
+ def parse_chunk(line: str) -> dict[str, Any] | None:
+     """Parse an SSE data line into a dict.
+
+     Args:
+         line: Raw SSE line (e.g., "data: {...}").
+
+     Returns:
+         Parsed JSON dict, or None if not parseable or [DONE].
+
+     """
+     if not line.startswith("data:"):
+         return None
+     payload = line[5:].strip()
+     if payload == "[DONE]":
+         return None
+     try:
+         return json.loads(payload)
+     except json.JSONDecodeError:
+         return None
+
+
+ def extract_content_from_chunk(chunk: dict[str, Any]) -> str:
+     """Extract text content from a parsed SSE chunk.
+
+     Args:
+         chunk: Parsed chunk dict from parse_chunk().
+
+     Returns:
+         Content string, or empty string if not found.
+
+     """
+     choices = chunk.get("choices") or [{}]
+     delta = choices[0].get("delta") or {}
+     return delta.get("content") or delta.get("text") or ""
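
Usage sketch (not part of the diff): format_chunk and parse_chunk are inverses up to the "data: " framing, so a producer can emit chunks and a consumer can recover the concatenated deltas. The run id, model name, and content are invented for the sketch.

# Hypothetical illustration -- run id, model name, and text are invented;
# only the imported functions come from the module above.
from agent_cli.core.sse import (
    extract_content_from_chunk,
    format_chunk,
    format_done,
    parse_chunk,
)

wire = [
    format_chunk("abc123", "my-model", content="Hello"),
    format_chunk("abc123", "my-model", content=" world"),
    format_chunk("abc123", "my-model", finish_reason="stop"),
    format_done(),  # parse_chunk returns None for this, so it drops out below
]

text = "".join(
    extract_content_from_chunk(chunk)
    for line in wire
    if (chunk := parse_chunk(line.strip())) is not None
)
assert text == "Hello world"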
agent_cli/core/transcription_logger.py ADDED
@@ -0,0 +1,70 @@
+ """Transcription logging utilities for automatic server-side logging."""
+
+ from __future__ import annotations
+
+ import json
+ import logging
+ from datetime import UTC, datetime
+ from pathlib import Path
+ from typing import Any
+
+
+ class TranscriptionLogger:
+     """Handles automatic logging of transcription results with timestamps."""
+
+     def __init__(self, log_file: Path | str | None = None) -> None:
+         """Initialize the transcription logger.
+
+         Args:
+             log_file: Path to the log file. If None, uses default location.
+
+         """
+         if log_file is None:
+             log_file = Path.home() / ".config" / "agent-cli" / "transcriptions.jsonl"
+         elif isinstance(log_file, str):
+             log_file = Path(log_file)
+
+         self.log_file = log_file
+
+         # Ensure the log directory exists
+         self.log_file.parent.mkdir(parents=True, exist_ok=True)
+
+     def log_transcription(
+         self,
+         *,
+         raw: str,
+         processed: str | None = None,
+     ) -> None:
+         """Log a transcription result.
+
+         Args:
+             raw: The raw transcript from ASR.
+             processed: The processed transcript from LLM.
+
+         """
+         log_entry: dict[str, Any] = {
+             "timestamp": datetime.now(UTC).isoformat(),
+             "raw": raw,
+             "processed": processed,
+         }
+
+         # Write to log file as JSON Lines format
+         try:
+             with self.log_file.open("a", encoding="utf-8") as f:
+                 f.write(json.dumps(log_entry, ensure_ascii=False) + "\n")
+         except OSError:
+             # Use Python's logging module to log errors with the logger itself
+             logger = logging.getLogger(__name__)
+             logger.exception("Failed to write transcription log")
+
+
+ # Default logger instance
+ _default_logger: TranscriptionLogger | None = None
+
+
+ def get_default_logger() -> TranscriptionLogger:
+     """Get the default transcription logger instance."""
+     global _default_logger
+     if _default_logger is None:
+         _default_logger = TranscriptionLogger()
+     return _default_logger
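
Usage sketch (not part of the diff): each call appends one JSON object per line to ~/.config/agent-cli/transcriptions.jsonl. The transcript strings are invented for illustration.

# Hypothetical illustration -- the transcripts are invented; only the
# imported function comes from the module above.
from agent_cli.core.transcription_logger import get_default_logger

logger = get_default_logger()
logger.log_transcription(
    raw="um so the meeting is at three pm",
    processed="The meeting is at 3 PM.",
)
# Appends a line like:
# {"timestamp": "...", "raw": "um so the meeting ...", "processed": "The meeting ..."}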