agent_cli-0.70.5-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (196)
  1. agent_cli/__init__.py +5 -0
  2. agent_cli/__main__.py +6 -0
  3. agent_cli/_extras.json +14 -0
  4. agent_cli/_requirements/.gitkeep +0 -0
  5. agent_cli/_requirements/audio.txt +79 -0
  6. agent_cli/_requirements/faster-whisper.txt +215 -0
  7. agent_cli/_requirements/kokoro.txt +425 -0
  8. agent_cli/_requirements/llm.txt +183 -0
  9. agent_cli/_requirements/memory.txt +355 -0
  10. agent_cli/_requirements/mlx-whisper.txt +222 -0
  11. agent_cli/_requirements/piper.txt +176 -0
  12. agent_cli/_requirements/rag.txt +402 -0
  13. agent_cli/_requirements/server.txt +154 -0
  14. agent_cli/_requirements/speed.txt +77 -0
  15. agent_cli/_requirements/vad.txt +155 -0
  16. agent_cli/_requirements/wyoming.txt +71 -0
  17. agent_cli/_tools.py +368 -0
  18. agent_cli/agents/__init__.py +23 -0
  19. agent_cli/agents/_voice_agent_common.py +136 -0
  20. agent_cli/agents/assistant.py +383 -0
  21. agent_cli/agents/autocorrect.py +284 -0
  22. agent_cli/agents/chat.py +496 -0
  23. agent_cli/agents/memory/__init__.py +31 -0
  24. agent_cli/agents/memory/add.py +190 -0
  25. agent_cli/agents/memory/proxy.py +160 -0
  26. agent_cli/agents/rag_proxy.py +128 -0
  27. agent_cli/agents/speak.py +209 -0
  28. agent_cli/agents/transcribe.py +671 -0
  29. agent_cli/agents/transcribe_daemon.py +499 -0
  30. agent_cli/agents/voice_edit.py +291 -0
  31. agent_cli/api.py +22 -0
  32. agent_cli/cli.py +106 -0
  33. agent_cli/config.py +503 -0
  34. agent_cli/config_cmd.py +307 -0
  35. agent_cli/constants.py +27 -0
  36. agent_cli/core/__init__.py +1 -0
  37. agent_cli/core/audio.py +461 -0
  38. agent_cli/core/audio_format.py +299 -0
  39. agent_cli/core/chroma.py +88 -0
  40. agent_cli/core/deps.py +191 -0
  41. agent_cli/core/openai_proxy.py +139 -0
  42. agent_cli/core/process.py +195 -0
  43. agent_cli/core/reranker.py +120 -0
  44. agent_cli/core/sse.py +87 -0
  45. agent_cli/core/transcription_logger.py +70 -0
  46. agent_cli/core/utils.py +526 -0
  47. agent_cli/core/vad.py +175 -0
  48. agent_cli/core/watch.py +65 -0
  49. agent_cli/dev/__init__.py +14 -0
  50. agent_cli/dev/cli.py +1588 -0
  51. agent_cli/dev/coding_agents/__init__.py +19 -0
  52. agent_cli/dev/coding_agents/aider.py +24 -0
  53. agent_cli/dev/coding_agents/base.py +167 -0
  54. agent_cli/dev/coding_agents/claude.py +39 -0
  55. agent_cli/dev/coding_agents/codex.py +24 -0
  56. agent_cli/dev/coding_agents/continue_dev.py +15 -0
  57. agent_cli/dev/coding_agents/copilot.py +24 -0
  58. agent_cli/dev/coding_agents/cursor_agent.py +48 -0
  59. agent_cli/dev/coding_agents/gemini.py +28 -0
  60. agent_cli/dev/coding_agents/opencode.py +15 -0
  61. agent_cli/dev/coding_agents/registry.py +49 -0
  62. agent_cli/dev/editors/__init__.py +19 -0
  63. agent_cli/dev/editors/base.py +89 -0
  64. agent_cli/dev/editors/cursor.py +15 -0
  65. agent_cli/dev/editors/emacs.py +46 -0
  66. agent_cli/dev/editors/jetbrains.py +56 -0
  67. agent_cli/dev/editors/nano.py +31 -0
  68. agent_cli/dev/editors/neovim.py +33 -0
  69. agent_cli/dev/editors/registry.py +59 -0
  70. agent_cli/dev/editors/sublime.py +20 -0
  71. agent_cli/dev/editors/vim.py +42 -0
  72. agent_cli/dev/editors/vscode.py +15 -0
  73. agent_cli/dev/editors/zed.py +20 -0
  74. agent_cli/dev/project.py +568 -0
  75. agent_cli/dev/registry.py +52 -0
  76. agent_cli/dev/skill/SKILL.md +141 -0
  77. agent_cli/dev/skill/examples.md +571 -0
  78. agent_cli/dev/terminals/__init__.py +19 -0
  79. agent_cli/dev/terminals/apple_terminal.py +82 -0
  80. agent_cli/dev/terminals/base.py +56 -0
  81. agent_cli/dev/terminals/gnome.py +51 -0
  82. agent_cli/dev/terminals/iterm2.py +84 -0
  83. agent_cli/dev/terminals/kitty.py +77 -0
  84. agent_cli/dev/terminals/registry.py +48 -0
  85. agent_cli/dev/terminals/tmux.py +58 -0
  86. agent_cli/dev/terminals/warp.py +132 -0
  87. agent_cli/dev/terminals/zellij.py +78 -0
  88. agent_cli/dev/worktree.py +856 -0
  89. agent_cli/docs_gen.py +417 -0
  90. agent_cli/example-config.toml +185 -0
  91. agent_cli/install/__init__.py +5 -0
  92. agent_cli/install/common.py +89 -0
  93. agent_cli/install/extras.py +174 -0
  94. agent_cli/install/hotkeys.py +48 -0
  95. agent_cli/install/services.py +87 -0
  96. agent_cli/memory/__init__.py +7 -0
  97. agent_cli/memory/_files.py +250 -0
  98. agent_cli/memory/_filters.py +63 -0
  99. agent_cli/memory/_git.py +157 -0
  100. agent_cli/memory/_indexer.py +142 -0
  101. agent_cli/memory/_ingest.py +408 -0
  102. agent_cli/memory/_persistence.py +182 -0
  103. agent_cli/memory/_prompt.py +91 -0
  104. agent_cli/memory/_retrieval.py +294 -0
  105. agent_cli/memory/_store.py +169 -0
  106. agent_cli/memory/_streaming.py +44 -0
  107. agent_cli/memory/_tasks.py +48 -0
  108. agent_cli/memory/api.py +113 -0
  109. agent_cli/memory/client.py +272 -0
  110. agent_cli/memory/engine.py +361 -0
  111. agent_cli/memory/entities.py +43 -0
  112. agent_cli/memory/models.py +112 -0
  113. agent_cli/opts.py +433 -0
  114. agent_cli/py.typed +0 -0
  115. agent_cli/rag/__init__.py +3 -0
  116. agent_cli/rag/_indexer.py +67 -0
  117. agent_cli/rag/_indexing.py +226 -0
  118. agent_cli/rag/_prompt.py +30 -0
  119. agent_cli/rag/_retriever.py +156 -0
  120. agent_cli/rag/_store.py +48 -0
  121. agent_cli/rag/_utils.py +218 -0
  122. agent_cli/rag/api.py +175 -0
  123. agent_cli/rag/client.py +299 -0
  124. agent_cli/rag/engine.py +302 -0
  125. agent_cli/rag/models.py +55 -0
  126. agent_cli/scripts/.runtime/.gitkeep +0 -0
  127. agent_cli/scripts/__init__.py +1 -0
  128. agent_cli/scripts/check_plugin_skill_sync.py +50 -0
  129. agent_cli/scripts/linux-hotkeys/README.md +63 -0
  130. agent_cli/scripts/linux-hotkeys/toggle-autocorrect.sh +45 -0
  131. agent_cli/scripts/linux-hotkeys/toggle-transcription.sh +58 -0
  132. agent_cli/scripts/linux-hotkeys/toggle-voice-edit.sh +58 -0
  133. agent_cli/scripts/macos-hotkeys/README.md +45 -0
  134. agent_cli/scripts/macos-hotkeys/skhd-config-example +5 -0
  135. agent_cli/scripts/macos-hotkeys/toggle-autocorrect.sh +12 -0
  136. agent_cli/scripts/macos-hotkeys/toggle-transcription.sh +37 -0
  137. agent_cli/scripts/macos-hotkeys/toggle-voice-edit.sh +37 -0
  138. agent_cli/scripts/nvidia-asr-server/README.md +99 -0
  139. agent_cli/scripts/nvidia-asr-server/pyproject.toml +27 -0
  140. agent_cli/scripts/nvidia-asr-server/server.py +255 -0
  141. agent_cli/scripts/nvidia-asr-server/shell.nix +32 -0
  142. agent_cli/scripts/nvidia-asr-server/uv.lock +4654 -0
  143. agent_cli/scripts/run-openwakeword.sh +11 -0
  144. agent_cli/scripts/run-piper-windows.ps1 +30 -0
  145. agent_cli/scripts/run-piper.sh +24 -0
  146. agent_cli/scripts/run-whisper-linux.sh +40 -0
  147. agent_cli/scripts/run-whisper-macos.sh +6 -0
  148. agent_cli/scripts/run-whisper-windows.ps1 +51 -0
  149. agent_cli/scripts/run-whisper.sh +9 -0
  150. agent_cli/scripts/run_faster_whisper_server.py +136 -0
  151. agent_cli/scripts/setup-linux-hotkeys.sh +72 -0
  152. agent_cli/scripts/setup-linux.sh +108 -0
  153. agent_cli/scripts/setup-macos-hotkeys.sh +61 -0
  154. agent_cli/scripts/setup-macos.sh +76 -0
  155. agent_cli/scripts/setup-windows.ps1 +63 -0
  156. agent_cli/scripts/start-all-services-windows.ps1 +53 -0
  157. agent_cli/scripts/start-all-services.sh +178 -0
  158. agent_cli/scripts/sync_extras.py +138 -0
  159. agent_cli/server/__init__.py +3 -0
  160. agent_cli/server/cli.py +721 -0
  161. agent_cli/server/common.py +222 -0
  162. agent_cli/server/model_manager.py +288 -0
  163. agent_cli/server/model_registry.py +225 -0
  164. agent_cli/server/proxy/__init__.py +3 -0
  165. agent_cli/server/proxy/api.py +444 -0
  166. agent_cli/server/streaming.py +67 -0
  167. agent_cli/server/tts/__init__.py +3 -0
  168. agent_cli/server/tts/api.py +335 -0
  169. agent_cli/server/tts/backends/__init__.py +82 -0
  170. agent_cli/server/tts/backends/base.py +139 -0
  171. agent_cli/server/tts/backends/kokoro.py +403 -0
  172. agent_cli/server/tts/backends/piper.py +253 -0
  173. agent_cli/server/tts/model_manager.py +201 -0
  174. agent_cli/server/tts/model_registry.py +28 -0
  175. agent_cli/server/tts/wyoming_handler.py +249 -0
  176. agent_cli/server/whisper/__init__.py +3 -0
  177. agent_cli/server/whisper/api.py +413 -0
  178. agent_cli/server/whisper/backends/__init__.py +89 -0
  179. agent_cli/server/whisper/backends/base.py +97 -0
  180. agent_cli/server/whisper/backends/faster_whisper.py +225 -0
  181. agent_cli/server/whisper/backends/mlx.py +270 -0
  182. agent_cli/server/whisper/languages.py +116 -0
  183. agent_cli/server/whisper/model_manager.py +157 -0
  184. agent_cli/server/whisper/model_registry.py +28 -0
  185. agent_cli/server/whisper/wyoming_handler.py +203 -0
  186. agent_cli/services/__init__.py +343 -0
  187. agent_cli/services/_wyoming_utils.py +64 -0
  188. agent_cli/services/asr.py +506 -0
  189. agent_cli/services/llm.py +228 -0
  190. agent_cli/services/tts.py +450 -0
  191. agent_cli/services/wake_word.py +142 -0
  192. agent_cli-0.70.5.dist-info/METADATA +2118 -0
  193. agent_cli-0.70.5.dist-info/RECORD +196 -0
  194. agent_cli-0.70.5.dist-info/WHEEL +4 -0
  195. agent_cli-0.70.5.dist-info/entry_points.txt +4 -0
  196. agent_cli-0.70.5.dist-info/licenses/LICENSE +21 -0
agent_cli/scripts/macos-hotkeys/toggle-autocorrect.sh
@@ -0,0 +1,12 @@
+ #!/usr/bin/env bash
+
+ # Toggle script for agent-cli autocorrect on macOS
+
+ /opt/homebrew/bin/terminal-notifier -title "📝 Autocorrect" -message "Processing clipboard text..."
+
+ OUTPUT=$("$HOME/.local/bin/agent-cli" autocorrect --quiet 2>/dev/null)
+ if [ -n "$OUTPUT" ]; then
+     /opt/homebrew/bin/terminal-notifier -title "✅ Corrected" -message "$OUTPUT"
+ else
+     /opt/homebrew/bin/terminal-notifier -title "❌ Error" -message "No text to correct"
+ fi
agent_cli/scripts/macos-hotkeys/toggle-transcription.sh
@@ -0,0 +1,37 @@
+ #!/usr/bin/env bash
+
+ # Toggle script for agent-cli transcription on macOS
+
+ NOTIFIER=${NOTIFIER:-/opt/homebrew/bin/terminal-notifier}
+ RECORDING_GROUP="agent-cli-transcribe-recording"
+ TEMP_PREFIX="agent-cli-transcribe-temp"
+
+ notify_temp() {
+     local title=$1
+     local message=$2
+     local duration=${3:-4}  # 4 seconds default
+     local group="${TEMP_PREFIX}-${RANDOM}-$$"
+
+     "$NOTIFIER" -title "$title" -message "$message" -group "$group"
+     (
+         sleep "$duration"
+         "$NOTIFIER" -remove "$group" >/dev/null 2>&1 || true
+     ) &
+ }
+
+ if pgrep -f "agent-cli transcribe( |$)" > /dev/null; then
+     pkill -INT -f "agent-cli transcribe( |$)"
+     "$NOTIFIER" -remove "$RECORDING_GROUP" >/dev/null 2>&1 || true
+     notify_temp "🛑 Stopped" "Processing results..."
+ else
+     "$NOTIFIER" -title "🎙️ Started" -message "Listening..." -group "$RECORDING_GROUP"
+     (
+         OUTPUT=$("$HOME/.local/bin/agent-cli" transcribe --llm --quiet 2>/dev/null)
+         "$NOTIFIER" -remove "$RECORDING_GROUP" >/dev/null 2>&1 || true
+         if [ -n "$OUTPUT" ]; then
+             notify_temp "📄 Result" "$OUTPUT"
+         else
+             notify_temp "❌ Error" "No output"
+         fi
+     ) &
+ fi
agent_cli/scripts/macos-hotkeys/toggle-voice-edit.sh
@@ -0,0 +1,37 @@
+ #!/usr/bin/env bash
+
+ # Toggle script for agent-cli voice-edit on macOS
+
+ NOTIFIER=${NOTIFIER:-/opt/homebrew/bin/terminal-notifier}
+ RECORDING_GROUP="agent-cli-voice-edit-recording"
+ TEMP_PREFIX="agent-cli-voice-edit-temp"
+
+ notify_temp() {
+     local title=$1
+     local message=$2
+     local duration=${3:-4}  # 4 seconds default
+     local group="${TEMP_PREFIX}-${RANDOM}-$$"
+
+     "$NOTIFIER" -title "$title" -message "$message" -group "$group"
+     (
+         sleep "$duration"
+         "$NOTIFIER" -remove "$group" >/dev/null 2>&1 || true
+     ) &
+ }
+
+ if pgrep -f "agent-cli voice-edit" > /dev/null; then
+     pkill -INT -f "agent-cli voice-edit"
+     "$NOTIFIER" -remove "$RECORDING_GROUP" >/dev/null 2>&1 || true
+     notify_temp "🛑 Stopped" "Processing voice command..."
+ else
+     "$NOTIFIER" -title "🎙️ Started" -message "Listening for voice command..." -group "$RECORDING_GROUP"
+     (
+         OUTPUT=$("$HOME/.local/bin/agent-cli" voice-edit --quiet 2>/dev/null)
+         "$NOTIFIER" -remove "$RECORDING_GROUP" >/dev/null 2>&1 || true
+         if [ -n "$OUTPUT" ]; then
+             notify_temp "✨ Voice Edit Result" "$OUTPUT"
+         else
+             notify_temp "❌ Error" "No output"
+         fi
+     ) &
+ fi
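
The three macOS toggle scripts above are meant to be bound to global hotkeys; the package ships a `skhd-config-example` for exactly this (see the file list). As a minimal sketch, assuming `skhd` is installed and the scripts have been copied somewhere on disk — the key combinations and script paths below are illustrative, not taken from the package:

```bash
# Hypothetical skhd wiring; key combos and script locations are assumptions.
brew install koekeishiya/formulae/skhd terminal-notifier
cat >> ~/.config/skhd/skhdrc <<'EOF'
cmd + shift - a : ~/scripts/toggle-autocorrect.sh
cmd + shift - t : ~/scripts/toggle-transcription.sh
cmd + shift - v : ~/scripts/toggle-voice-edit.sh
EOF
skhd --reload  # tell a running skhd instance to pick up the new bindings
```

Because the transcription and voice-edit scripts toggle on process presence (`pgrep`/`pkill -INT`), the same hotkey starts a recording and, pressed again, stops it and kicks off the notification flow.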
agent_cli/scripts/nvidia-asr-server/README.md
@@ -0,0 +1,99 @@
+ # NVIDIA ASR Server
+
+ OpenAI-compatible API server for NVIDIA ASR models.
+
+ ## Quick Start
+
+ ```bash
+ cd scripts/nvidia-asr-server
+ uv run server.py
+ ```
+
+ Server runs at `http://localhost:9898`
+
+ ## CLI Options
+
+ - `--model`, `-m`: Model to use (default: `canary-qwen-2.5b`)
+   - `canary-qwen-2.5b`: Multilingual ASR (~5GB VRAM)
+   - `parakeet-tdt-0.6b-v2`: English with timestamps (~2GB VRAM)
+ - `--port`, `-p`: Port (default: 9898)
+ - `--device`, `-d`: Device (default: auto-select best GPU)
+
+ ```bash
+ # Examples
+ uv run server.py --model parakeet-tdt-0.6b-v2
+ uv run server.py -m parakeet-tdt-0.6b-v2 -p 9090 -d cuda:1
+ ```
+
+ ## Using with Agent-CLI
+
+ ```bash
+ # Start server
+ cd scripts/nvidia-asr-server
+ uv run server.py
+
+ # In another terminal
+ agent-cli transcribe \
+   --asr-provider openai \
+   --asr-openai-base-url http://localhost:9898/v1
+ ```
+
+ **Note**: The `/v1` suffix is required for OpenAI compatibility.
+
+ ## API Usage
+
+ ### Python Example
+
+ ```python
+ import requests
+
+ with open("audio.wav", "rb") as f:
+     response = requests.post(
+         "http://localhost:9898/v1/audio/transcriptions",
+         files={"file": f},
+         data={"model": "parakeet-tdt-0.6b-v2"}
+     )
+
+ print(response.json()["text"])
+ ```
+
+ ### With Timestamps (Parakeet only)
+
+ ```python
+ response = requests.post(
+     "http://localhost:9898/v1/audio/transcriptions",
+     files={"file": open("audio.wav", "rb")},
+     data={
+         "model": "parakeet-tdt-0.6b-v2",
+         "timestamp_granularities": ["word"]
+     }
+ )
+
+ result = response.json()
+ for word in result.get("words", []):
+     print(f"{word['start']:.2f}s - {word['end']:.2f}s: {word['word']}")
+ ```
+
+ ## Requirements
+
+ - Python 3.13+
+ - CUDA-compatible GPU (recommended)
+ - ~2-5GB VRAM depending on model
+
+ ## Troubleshooting
+
+ **GPU out of memory**: Try a smaller model or run on CPU
+ ```bash
+ uv run server.py --model parakeet-tdt-0.6b-v2
+ uv run server.py --device cpu
+ ```
+
+ **Port in use**: Use a different port
+ ```bash
+ uv run server.py --port 9999
+ ```
+
+ ## License
+
+ - Canary: NVIDIA AI Foundation Models Community License
+ - Parakeet: CC-BY-4.0
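
Since the endpoint speaks plain multipart form data, it can also be smoke-tested with `curl`. A sketch, assuming the server is running locally and `audio.wav` stands in for a real recording; the handler declares no `model` form field, so the server transcribes with whichever model it was started with:

```bash
# Plain-text transcription via the OpenAI-compatible endpoint.
curl -s http://localhost:9898/v1/audio/transcriptions \
  -F file=@audio.wav \
  -F response_format=text
```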
agent_cli/scripts/nvidia-asr-server/pyproject.toml
@@ -0,0 +1,27 @@
+ [project]
+ name = "nvidia-asr-server"
+ version = "1.0.0"
+ description = "NVIDIA ASR server with OpenAI-compatible API"
+ readme = "README.md"
+ requires-python = ">=3.13"
+ dependencies = [
+     "fastapi[standard]>=0.115.0",
+     "torch>=2.5.0",
+     "soundfile>=0.12.1",
+     "sacrebleu>=2.4.0",
+     "typer>=0.9.0",
+     "nemo-toolkit[asr,tts] @ git+https://github.com/NVIDIA/NeMo.git",
+ ]
+
+ [tool.uv.sources]
+ torch = [{ index = "pytorch-cu124" }]
+
+ [[tool.uv.index]]
+ name = "pytorch-cu124"
+ url = "https://download.pytorch.org/whl/cu124"
+ explicit = true
+
+ [tool.uv]
+ override-dependencies = [
+     "ml-dtypes>=0.5.0",
+ ]
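
The `[tool.uv.sources]` and `[[tool.uv.index]]` entries pin `torch` to the cu124 wheel index, so `uv run` resolves a CUDA build instead of the default PyPI wheel. A quick sanity check of what actually resolved (a sketch; run from the project directory):

```bash
# Prints the installed torch version and whether CUDA is usable.
uv run python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
```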
agent_cli/scripts/nvidia-asr-server/server.py
@@ -0,0 +1,255 @@
+ #!/usr/bin/env -S uv run
+ """NVIDIA ASR server with OpenAI-compatible API.
+
+ Supports multiple NVIDIA ASR models:
+ - nvidia/canary-qwen-2.5b (default): Multilingual ASR with translation capabilities
+ - nvidia/parakeet-tdt-0.6b-v2: High-quality English ASR with timestamps
+
+ Usage:
+     cd scripts/nvidia-asr-server
+     uv run server.py
+     uv run server.py --model parakeet-tdt-0.6b-v2
+     uv run server.py --port 9090 --device cuda:1
+ """
+
+ from __future__ import annotations
+
+ import shutil
+ import subprocess
+ import tempfile
+ import traceback
+ from contextlib import asynccontextmanager, suppress
+ from dataclasses import dataclass
+ from enum import Enum
+ from pathlib import Path
+ from typing import TYPE_CHECKING, Annotated, Any
+
+ import torch
+ import typer
+ import uvicorn
+ from fastapi import FastAPI, File, Form, HTTPException, UploadFile
+ from fastapi.responses import JSONResponse
+
+ if TYPE_CHECKING:
+     from collections.abc import AsyncGenerator
+     from typing import TypedDict
+
+     class TranscriptionResult(TypedDict, total=False):
+         """Transcription result with optional word-level timestamps."""
+
+         text: str
+         words: list[dict[str, Any]]
+
+
+ class ModelType(str, Enum):
+     """Supported ASR models."""
+
+     CANARY = "canary-qwen-2.5b"
+     PARAKEET = "parakeet-tdt-0.6b-v2"
+
+
+ @dataclass
+ class ServerConfig:
+     """Server configuration."""
+
+     model_type: ModelType
+     device: str
+     port: int
+
+
+ def select_best_gpu() -> str:
+     """Select the GPU with the most free memory, or CPU if no GPU available."""
+     if not torch.cuda.is_available():
+         return "cpu"
+
+     if torch.cuda.device_count() == 1:
+         return "cuda:0"
+
+     best_gpu = max(
+         range(torch.cuda.device_count()),
+         key=lambda i: torch.cuda.mem_get_info(i)[0],
+     )
+     return f"cuda:{best_gpu}"
+
+
+ def resample_audio(input_path: str) -> str:
+     """Resample audio to 16kHz mono WAV using ffmpeg."""
+     out_path = f"{input_path}_16k.wav"
+     cmd = [
+         "ffmpeg",
+         "-y",
+         "-i",
+         input_path,
+         "-ar",
+         "16000",
+         "-ac",
+         "1",
+         out_path,
+     ]
+     result = subprocess.run(cmd, capture_output=True, check=False)
+     if result.returncode != 0:
+         stderr = result.stderr.decode() if result.stderr else "No error output"
+         msg = f"ffmpeg failed: {stderr}"
+         raise RuntimeError(msg)
+     return out_path
+
+
+ def load_asr_model(config: ServerConfig) -> Any:
+     """Load the appropriate ASR model based on configuration."""
+     import nemo.collections.asr as nemo_asr  # noqa: PLC0415
+     from nemo.collections.speechlm2.models import SALM  # noqa: PLC0415
+
+     model_name = f"nvidia/{config.model_type.value}"
+
+     # Print device info
+     if config.device.startswith("cuda"):
+         gpu_id = int(config.device.split(":")[1]) if ":" in config.device else 0
+         free_mem, total_mem = torch.cuda.mem_get_info(gpu_id)
+         free_gb = free_mem / 1024**3
+         total_gb = total_mem / 1024**3
+         print(
+             f"Loading {model_name} on {config.device} ({free_gb:.1f}GB / {total_gb:.1f}GB)",
+             flush=True,
+         )
+     else:
+         print(f"Loading {model_name} on {config.device}", flush=True)
+
+     if config.model_type == ModelType.CANARY:
+         model = SALM.from_pretrained(model_name)
+     elif config.model_type == ModelType.PARAKEET:
+         model = nemo_asr.models.ASRModel.from_pretrained(model_name)
+     else:
+         msg = f"Unsupported model type: {config.model_type}"
+         raise ValueError(msg)
+
+     return model.to(config.device).eval()
+
+
+ asr_model: Any = None
+ config: ServerConfig | None = None
+
+
+ @asynccontextmanager
+ async def lifespan(_app: FastAPI) -> AsyncGenerator[None]:
+     """Load the ASR model on startup."""
+     global asr_model
+     assert config is not None
+     asr_model = load_asr_model(config)
+     yield
+
+
+ app = FastAPI(lifespan=lifespan)
+
+
+ def transcribe_canary(audio_path: str, prompt: str | None) -> str:
+     """Transcribe audio using Canary model."""
+     user_prompt = prompt or "Transcribe the following:"
+     full_prompt = f"{user_prompt} {asr_model.audio_locator_tag}"
+
+     prompts = [[{"role": "user", "content": full_prompt, "audio": [audio_path]}]]
+     answer_ids = asr_model.generate(prompts=prompts, max_new_tokens=128)
+     return asr_model.tokenizer.ids_to_text(answer_ids[0].cpu())
+
+
+ def transcribe_parakeet(
+     audio_path: str,
+     timestamp_granularities: list[str] | None,
+ ) -> TranscriptionResult:
+     """Transcribe audio using Parakeet model."""
+     enable_timestamps = bool(timestamp_granularities)
+     output = asr_model.transcribe([audio_path], timestamps=enable_timestamps)
+
+     result: TranscriptionResult = {"text": output[0].text}
+
+     if enable_timestamps and timestamp_granularities and "word" in timestamp_granularities:
+         word_timestamps = output[0].timestamp.get("word", [])
+         if word_timestamps:
+             result["words"] = [
+                 {"word": w["word"], "start": w["start"], "end": w["end"]} for w in word_timestamps
+             ]
+
+     return result
+
+
+ def cleanup_files(*paths: str | None) -> None:
+     """Clean up temporary files."""
+     for p in paths:
+         if p:
+             with suppress(OSError):
+                 Path(p).unlink(missing_ok=True)
+
+
+ @app.post("/v1/audio/transcriptions", response_model=None)
+ async def transcribe(
+     file: Annotated[UploadFile, File()],
+     response_format: Annotated[str, Form()] = "json",
+     prompt: Annotated[str | None, Form()] = None,
+     timestamp_granularities: Annotated[list[str] | None, Form()] = None,
+ ) -> str | JSONResponse:
+     """Transcribe audio using ASR model with OpenAI-compatible API."""
+     if asr_model is None:
+         raise HTTPException(status_code=503, detail="Model not loaded yet")
+
+     with tempfile.NamedTemporaryFile(delete=False, suffix="") as tmp:
+         tmp_path = tmp.name
+         shutil.copyfileobj(file.file, tmp)
+
+     resampled_path = None
+     try:
+         resampled_path = resample_audio(tmp_path)
+
+         with torch.inference_mode():
+             assert config is not None
+             if config.model_type == ModelType.CANARY:
+                 text = transcribe_canary(resampled_path, prompt)
+                 return text if response_format == "text" else JSONResponse({"text": text})
+             if config.model_type == ModelType.PARAKEET:
+                 result = transcribe_parakeet(resampled_path, timestamp_granularities)
+                 return result["text"] if response_format == "text" else JSONResponse(result)
+
+     except Exception as e:
+         traceback.print_exc()
+         raise HTTPException(status_code=500, detail=str(e)) from e
+     finally:
+         cleanup_files(tmp_path, resampled_path)
+
+
+ def main(
+     model: Annotated[
+         ModelType,
+         typer.Option("--model", "-m", help="ASR model to use"),
+     ] = ModelType.CANARY,
+     port: Annotated[int, typer.Option("--port", "-p", help="Server port")] = 9898,
+     device: Annotated[
+         str | None,
+         typer.Option(
+             "--device",
+             "-d",
+             help="Device to use (cpu, cuda, cuda:0, etc.). Auto-selects GPU with most free memory if not specified.",
+         ),
+     ] = None,
+ ) -> None:
+     """Run NVIDIA ASR server with OpenAI-compatible API.
+
+     Supports multiple models:
+     - canary-qwen-2.5b: Multilingual ASR with translation (default)
+     - parakeet-tdt-0.6b-v2: High-quality English ASR with timestamps
+     """
+     global config
+
+     config = ServerConfig(
+         model_type=model,
+         device=device or select_best_gpu(),
+         port=port,
+     )
+
+     print(f"Starting ASR server with model: {model.value}")
+     print(f"Device: {config.device}")
+     print(f"Port: {config.port}")
+     print()
+
+     uvicorn.run(app, host="0.0.0.0", port=config.port)  # noqa: S104
+
+
+ if __name__ == "__main__":
+     typer.run(main)
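
For the Parakeet model, word timestamps surface through this same endpoint; repeating the `timestamp_granularities` form field is how a multipart request fills the `list[str]` parameter above. A sketch, assuming the server was started with `--model parakeet-tdt-0.6b-v2` and `jq` is installed:

```bash
# One line per word with its start/end offsets from the "words" array.
curl -s http://localhost:9898/v1/audio/transcriptions \
  -F file=@audio.wav \
  -F timestamp_granularities=word \
  | jq -r '.words[] | "\(.start)s-\(.end)s \(.word)"'
```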
agent_cli/scripts/nvidia-asr-server/shell.nix
@@ -0,0 +1,32 @@
+ { pkgs ? import <nixpkgs> { config.allowUnfree = true; } }:
+
+ pkgs.mkShell {
+   buildInputs = with pkgs; [
+     # Python and uv
+     python313
+     uv
+
+     # Audio libraries
+     ffmpeg
+   ];
+
+   shellHook = ''
+     # Set up CUDA environment (use system NVIDIA drivers and CUDA libraries)
+     export LD_LIBRARY_PATH=/run/opengl-driver/lib:/run/current-system/sw/lib:$LD_LIBRARY_PATH
+
+     # Tell triton where to find libcuda.so (avoids calling /sbin/ldconfig)
+     export TRITON_LIBCUDA_PATH=/run/opengl-driver/lib
+
+     # PyTorch memory management - avoid fragmentation
+     export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
+
+     # Canary server defaults
+     export CANARY_PORT=9898
+     # CANARY_DEVICE auto-detects GPU with most free memory (override if needed)
+
+     echo "CUDA environment configured (using system NVIDIA drivers)"
+     echo "TRITON_LIBCUDA_PATH: $TRITON_LIBCUDA_PATH"
+     echo "PYTORCH_CUDA_ALLOC_CONF: $PYTORCH_CUDA_ALLOC_CONF"
+     echo "Run 'uv run server.py' to start the server"
+   '';
+ }
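
On NixOS, the `shell.nix` above wraps the same workflow; a typical session might look like this (paths assume the wheel's layout):

```bash
cd agent_cli/scripts/nvidia-asr-server
nix-shell          # shellHook exports the CUDA environment and prints it
uv run server.py   # uv resolves dependencies from uv.lock on first run
```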