agent_cli-0.70.5-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agent_cli/__init__.py +5 -0
- agent_cli/__main__.py +6 -0
- agent_cli/_extras.json +14 -0
- agent_cli/_requirements/.gitkeep +0 -0
- agent_cli/_requirements/audio.txt +79 -0
- agent_cli/_requirements/faster-whisper.txt +215 -0
- agent_cli/_requirements/kokoro.txt +425 -0
- agent_cli/_requirements/llm.txt +183 -0
- agent_cli/_requirements/memory.txt +355 -0
- agent_cli/_requirements/mlx-whisper.txt +222 -0
- agent_cli/_requirements/piper.txt +176 -0
- agent_cli/_requirements/rag.txt +402 -0
- agent_cli/_requirements/server.txt +154 -0
- agent_cli/_requirements/speed.txt +77 -0
- agent_cli/_requirements/vad.txt +155 -0
- agent_cli/_requirements/wyoming.txt +71 -0
- agent_cli/_tools.py +368 -0
- agent_cli/agents/__init__.py +23 -0
- agent_cli/agents/_voice_agent_common.py +136 -0
- agent_cli/agents/assistant.py +383 -0
- agent_cli/agents/autocorrect.py +284 -0
- agent_cli/agents/chat.py +496 -0
- agent_cli/agents/memory/__init__.py +31 -0
- agent_cli/agents/memory/add.py +190 -0
- agent_cli/agents/memory/proxy.py +160 -0
- agent_cli/agents/rag_proxy.py +128 -0
- agent_cli/agents/speak.py +209 -0
- agent_cli/agents/transcribe.py +671 -0
- agent_cli/agents/transcribe_daemon.py +499 -0
- agent_cli/agents/voice_edit.py +291 -0
- agent_cli/api.py +22 -0
- agent_cli/cli.py +106 -0
- agent_cli/config.py +503 -0
- agent_cli/config_cmd.py +307 -0
- agent_cli/constants.py +27 -0
- agent_cli/core/__init__.py +1 -0
- agent_cli/core/audio.py +461 -0
- agent_cli/core/audio_format.py +299 -0
- agent_cli/core/chroma.py +88 -0
- agent_cli/core/deps.py +191 -0
- agent_cli/core/openai_proxy.py +139 -0
- agent_cli/core/process.py +195 -0
- agent_cli/core/reranker.py +120 -0
- agent_cli/core/sse.py +87 -0
- agent_cli/core/transcription_logger.py +70 -0
- agent_cli/core/utils.py +526 -0
- agent_cli/core/vad.py +175 -0
- agent_cli/core/watch.py +65 -0
- agent_cli/dev/__init__.py +14 -0
- agent_cli/dev/cli.py +1588 -0
- agent_cli/dev/coding_agents/__init__.py +19 -0
- agent_cli/dev/coding_agents/aider.py +24 -0
- agent_cli/dev/coding_agents/base.py +167 -0
- agent_cli/dev/coding_agents/claude.py +39 -0
- agent_cli/dev/coding_agents/codex.py +24 -0
- agent_cli/dev/coding_agents/continue_dev.py +15 -0
- agent_cli/dev/coding_agents/copilot.py +24 -0
- agent_cli/dev/coding_agents/cursor_agent.py +48 -0
- agent_cli/dev/coding_agents/gemini.py +28 -0
- agent_cli/dev/coding_agents/opencode.py +15 -0
- agent_cli/dev/coding_agents/registry.py +49 -0
- agent_cli/dev/editors/__init__.py +19 -0
- agent_cli/dev/editors/base.py +89 -0
- agent_cli/dev/editors/cursor.py +15 -0
- agent_cli/dev/editors/emacs.py +46 -0
- agent_cli/dev/editors/jetbrains.py +56 -0
- agent_cli/dev/editors/nano.py +31 -0
- agent_cli/dev/editors/neovim.py +33 -0
- agent_cli/dev/editors/registry.py +59 -0
- agent_cli/dev/editors/sublime.py +20 -0
- agent_cli/dev/editors/vim.py +42 -0
- agent_cli/dev/editors/vscode.py +15 -0
- agent_cli/dev/editors/zed.py +20 -0
- agent_cli/dev/project.py +568 -0
- agent_cli/dev/registry.py +52 -0
- agent_cli/dev/skill/SKILL.md +141 -0
- agent_cli/dev/skill/examples.md +571 -0
- agent_cli/dev/terminals/__init__.py +19 -0
- agent_cli/dev/terminals/apple_terminal.py +82 -0
- agent_cli/dev/terminals/base.py +56 -0
- agent_cli/dev/terminals/gnome.py +51 -0
- agent_cli/dev/terminals/iterm2.py +84 -0
- agent_cli/dev/terminals/kitty.py +77 -0
- agent_cli/dev/terminals/registry.py +48 -0
- agent_cli/dev/terminals/tmux.py +58 -0
- agent_cli/dev/terminals/warp.py +132 -0
- agent_cli/dev/terminals/zellij.py +78 -0
- agent_cli/dev/worktree.py +856 -0
- agent_cli/docs_gen.py +417 -0
- agent_cli/example-config.toml +185 -0
- agent_cli/install/__init__.py +5 -0
- agent_cli/install/common.py +89 -0
- agent_cli/install/extras.py +174 -0
- agent_cli/install/hotkeys.py +48 -0
- agent_cli/install/services.py +87 -0
- agent_cli/memory/__init__.py +7 -0
- agent_cli/memory/_files.py +250 -0
- agent_cli/memory/_filters.py +63 -0
- agent_cli/memory/_git.py +157 -0
- agent_cli/memory/_indexer.py +142 -0
- agent_cli/memory/_ingest.py +408 -0
- agent_cli/memory/_persistence.py +182 -0
- agent_cli/memory/_prompt.py +91 -0
- agent_cli/memory/_retrieval.py +294 -0
- agent_cli/memory/_store.py +169 -0
- agent_cli/memory/_streaming.py +44 -0
- agent_cli/memory/_tasks.py +48 -0
- agent_cli/memory/api.py +113 -0
- agent_cli/memory/client.py +272 -0
- agent_cli/memory/engine.py +361 -0
- agent_cli/memory/entities.py +43 -0
- agent_cli/memory/models.py +112 -0
- agent_cli/opts.py +433 -0
- agent_cli/py.typed +0 -0
- agent_cli/rag/__init__.py +3 -0
- agent_cli/rag/_indexer.py +67 -0
- agent_cli/rag/_indexing.py +226 -0
- agent_cli/rag/_prompt.py +30 -0
- agent_cli/rag/_retriever.py +156 -0
- agent_cli/rag/_store.py +48 -0
- agent_cli/rag/_utils.py +218 -0
- agent_cli/rag/api.py +175 -0
- agent_cli/rag/client.py +299 -0
- agent_cli/rag/engine.py +302 -0
- agent_cli/rag/models.py +55 -0
- agent_cli/scripts/.runtime/.gitkeep +0 -0
- agent_cli/scripts/__init__.py +1 -0
- agent_cli/scripts/check_plugin_skill_sync.py +50 -0
- agent_cli/scripts/linux-hotkeys/README.md +63 -0
- agent_cli/scripts/linux-hotkeys/toggle-autocorrect.sh +45 -0
- agent_cli/scripts/linux-hotkeys/toggle-transcription.sh +58 -0
- agent_cli/scripts/linux-hotkeys/toggle-voice-edit.sh +58 -0
- agent_cli/scripts/macos-hotkeys/README.md +45 -0
- agent_cli/scripts/macos-hotkeys/skhd-config-example +5 -0
- agent_cli/scripts/macos-hotkeys/toggle-autocorrect.sh +12 -0
- agent_cli/scripts/macos-hotkeys/toggle-transcription.sh +37 -0
- agent_cli/scripts/macos-hotkeys/toggle-voice-edit.sh +37 -0
- agent_cli/scripts/nvidia-asr-server/README.md +99 -0
- agent_cli/scripts/nvidia-asr-server/pyproject.toml +27 -0
- agent_cli/scripts/nvidia-asr-server/server.py +255 -0
- agent_cli/scripts/nvidia-asr-server/shell.nix +32 -0
- agent_cli/scripts/nvidia-asr-server/uv.lock +4654 -0
- agent_cli/scripts/run-openwakeword.sh +11 -0
- agent_cli/scripts/run-piper-windows.ps1 +30 -0
- agent_cli/scripts/run-piper.sh +24 -0
- agent_cli/scripts/run-whisper-linux.sh +40 -0
- agent_cli/scripts/run-whisper-macos.sh +6 -0
- agent_cli/scripts/run-whisper-windows.ps1 +51 -0
- agent_cli/scripts/run-whisper.sh +9 -0
- agent_cli/scripts/run_faster_whisper_server.py +136 -0
- agent_cli/scripts/setup-linux-hotkeys.sh +72 -0
- agent_cli/scripts/setup-linux.sh +108 -0
- agent_cli/scripts/setup-macos-hotkeys.sh +61 -0
- agent_cli/scripts/setup-macos.sh +76 -0
- agent_cli/scripts/setup-windows.ps1 +63 -0
- agent_cli/scripts/start-all-services-windows.ps1 +53 -0
- agent_cli/scripts/start-all-services.sh +178 -0
- agent_cli/scripts/sync_extras.py +138 -0
- agent_cli/server/__init__.py +3 -0
- agent_cli/server/cli.py +721 -0
- agent_cli/server/common.py +222 -0
- agent_cli/server/model_manager.py +288 -0
- agent_cli/server/model_registry.py +225 -0
- agent_cli/server/proxy/__init__.py +3 -0
- agent_cli/server/proxy/api.py +444 -0
- agent_cli/server/streaming.py +67 -0
- agent_cli/server/tts/__init__.py +3 -0
- agent_cli/server/tts/api.py +335 -0
- agent_cli/server/tts/backends/__init__.py +82 -0
- agent_cli/server/tts/backends/base.py +139 -0
- agent_cli/server/tts/backends/kokoro.py +403 -0
- agent_cli/server/tts/backends/piper.py +253 -0
- agent_cli/server/tts/model_manager.py +201 -0
- agent_cli/server/tts/model_registry.py +28 -0
- agent_cli/server/tts/wyoming_handler.py +249 -0
- agent_cli/server/whisper/__init__.py +3 -0
- agent_cli/server/whisper/api.py +413 -0
- agent_cli/server/whisper/backends/__init__.py +89 -0
- agent_cli/server/whisper/backends/base.py +97 -0
- agent_cli/server/whisper/backends/faster_whisper.py +225 -0
- agent_cli/server/whisper/backends/mlx.py +270 -0
- agent_cli/server/whisper/languages.py +116 -0
- agent_cli/server/whisper/model_manager.py +157 -0
- agent_cli/server/whisper/model_registry.py +28 -0
- agent_cli/server/whisper/wyoming_handler.py +203 -0
- agent_cli/services/__init__.py +343 -0
- agent_cli/services/_wyoming_utils.py +64 -0
- agent_cli/services/asr.py +506 -0
- agent_cli/services/llm.py +228 -0
- agent_cli/services/tts.py +450 -0
- agent_cli/services/wake_word.py +142 -0
- agent_cli-0.70.5.dist-info/METADATA +2118 -0
- agent_cli-0.70.5.dist-info/RECORD +196 -0
- agent_cli-0.70.5.dist-info/WHEEL +4 -0
- agent_cli-0.70.5.dist-info/entry_points.txt +4 -0
- agent_cli-0.70.5.dist-info/licenses/LICENSE +21 -0

agent_cli/scripts/macos-hotkeys/toggle-autocorrect.sh
@@ -0,0 +1,12 @@
+#!/usr/bin/env bash
+
+# Toggle script for agent-cli autocorrect on macOS
+
+/opt/homebrew/bin/terminal-notifier -title "📝 Autocorrect" -message "Processing clipboard text..."
+
+OUTPUT=$("$HOME/.local/bin/agent-cli" autocorrect --quiet 2>/dev/null)
+if [ -n "$OUTPUT" ]; then
+  /opt/homebrew/bin/terminal-notifier -title "✅ Corrected" -message "$OUTPUT"
+else
+  /opt/homebrew/bin/terminal-notifier -title "❌ Error" -message "No text to correct"
+fi

agent_cli/scripts/macos-hotkeys/toggle-transcription.sh
@@ -0,0 +1,37 @@
+#!/usr/bin/env bash
+
+# Toggle script for agent-cli transcription on macOS
+
+NOTIFIER=${NOTIFIER:-/opt/homebrew/bin/terminal-notifier}
+RECORDING_GROUP="agent-cli-transcribe-recording"
+TEMP_PREFIX="agent-cli-transcribe-temp"
+
+notify_temp() {
+  local title=$1
+  local message=$2
+  local duration=${3:-4} # 4 seconds default
+  local group="${TEMP_PREFIX}-${RANDOM}-$$"
+
+  "$NOTIFIER" -title "$title" -message "$message" -group "$group"
+  (
+    sleep "$duration"
+    "$NOTIFIER" -remove "$group" >/dev/null 2>&1 || true
+  ) &
+}
+
+if pgrep -f "agent-cli transcribe( |$)" > /dev/null; then
+  pkill -INT -f "agent-cli transcribe( |$)"
+  "$NOTIFIER" -remove "$RECORDING_GROUP" >/dev/null 2>&1 || true
+  notify_temp "🛑 Stopped" "Processing results..."
+else
+  "$NOTIFIER" -title "🎙️ Started" -message "Listening..." -group "$RECORDING_GROUP"
+  (
+    OUTPUT=$("$HOME/.local/bin/agent-cli" transcribe --llm --quiet 2>/dev/null)
+    "$NOTIFIER" -remove "$RECORDING_GROUP" >/dev/null 2>&1 || true
+    if [ -n "$OUTPUT" ]; then
+      notify_temp "📄 Result" "$OUTPUT"
+    else
+      notify_temp "❌ Error" "No output"
+    fi
+  ) &
+fi

agent_cli/scripts/macos-hotkeys/toggle-voice-edit.sh
@@ -0,0 +1,37 @@
+#!/usr/bin/env bash
+
+# Toggle script for agent-cli voice-edit on macOS
+
+NOTIFIER=${NOTIFIER:-/opt/homebrew/bin/terminal-notifier}
+RECORDING_GROUP="agent-cli-voice-edit-recording"
+TEMP_PREFIX="agent-cli-voice-edit-temp"
+
+notify_temp() {
+  local title=$1
+  local message=$2
+  local duration=${3:-4} # 4 seconds default
+  local group="${TEMP_PREFIX}-${RANDOM}-$$"
+
+  "$NOTIFIER" -title "$title" -message "$message" -group "$group"
+  (
+    sleep "$duration"
+    "$NOTIFIER" -remove "$group" >/dev/null 2>&1 || true
+  ) &
+}
+
+if pgrep -f "agent-cli voice-edit" > /dev/null; then
+  pkill -INT -f "agent-cli voice-edit"
+  "$NOTIFIER" -remove "$RECORDING_GROUP" >/dev/null 2>&1 || true
+  notify_temp "🛑 Stopped" "Processing voice command..."
+else
+  "$NOTIFIER" -title "🎙️ Started" -message "Listening for voice command..." -group "$RECORDING_GROUP"
+  (
+    OUTPUT=$("$HOME/.local/bin/agent-cli" voice-edit --quiet 2>/dev/null)
+    "$NOTIFIER" -remove "$RECORDING_GROUP" >/dev/null 2>&1 || true
+    if [ -n "$OUTPUT" ]; then
+      notify_temp "✨ Voice Edit Result" "$OUTPUT"
+    else
+      notify_temp "❌ Error" "No output"
+    fi
+  ) &
+fi
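
All three hotkey scripts share one toggle pattern: if a matching `agent-cli` process is already running, send it SIGINT (which makes it stop recording and emit its result); otherwise start a fresh invocation in the background. A minimal Python sketch of that same pattern, for illustration only — it assumes `pgrep`/`pkill` are available and `agent-cli` is on `PATH`; the function name and constant are hypothetical, not part of the package:

```python
# Hypothetical sketch of the toggle pattern used by the hotkey scripts above.
import subprocess

PATTERN = "agent-cli transcribe( |$)"  # same regex the shell script uses


def toggle() -> None:
    """Start transcription if idle; SIGINT the running process to stop it."""
    running = (
        subprocess.run(["pgrep", "-f", PATTERN], capture_output=True, check=False).returncode == 0
    )
    if running:
        # SIGINT tells `agent-cli transcribe` to stop recording and print its result.
        subprocess.run(["pkill", "-INT", "-f", PATTERN], check=False)
    else:
        subprocess.Popen(["agent-cli", "transcribe", "--llm", "--quiet"])


if __name__ == "__main__":
    toggle()
```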

agent_cli/scripts/nvidia-asr-server/README.md
@@ -0,0 +1,99 @@
+# NVIDIA ASR Server
+
+OpenAI-compatible API server for NVIDIA ASR models.
+
+## Quick Start
+
+```bash
+cd scripts/nvidia-asr-server
+uv run server.py
+```
+
+Server runs at `http://localhost:9898`
+
+## CLI Options
+
+- `--model`, `-m`: Model to use (default: `canary-qwen-2.5b`)
+  - `canary-qwen-2.5b`: Multilingual ASR (~5GB VRAM)
+  - `parakeet-tdt-0.6b-v2`: English with timestamps (~2GB VRAM)
+- `--port`, `-p`: Port (default: 9898)
+- `--device`, `-d`: Device (default: auto-select best GPU)
+
+```bash
+# Examples
+uv run server.py --model parakeet-tdt-0.6b-v2
+uv run server.py -m parakeet-tdt-0.6b-v2 -p 9090 -d cuda:1
+```
+
+## Using with Agent-CLI
+
+```bash
+# Start server
+cd scripts/nvidia-asr-server
+uv run server.py
+
+# In another terminal
+agent-cli transcribe \
+    --asr-provider openai \
+    --asr-openai-base-url http://localhost:9898/v1
+```
+
+**Note**: The `/v1` suffix is required for OpenAI compatibility.
+
+## API Usage
+
+### Python Example
+
+```python
+import requests
+
+with open("audio.wav", "rb") as f:
+    response = requests.post(
+        "http://localhost:9898/v1/audio/transcriptions",
+        files={"file": f},
+        data={"model": "parakeet-tdt-0.6b-v2"}
+    )
+
+print(response.json()["text"])
+```
+
+### With Timestamps (Parakeet only)
+
+```python
+response = requests.post(
+    "http://localhost:9898/v1/audio/transcriptions",
+    files={"file": open("audio.wav", "rb")},
+    data={
+        "model": "parakeet-tdt-0.6b-v2",
+        "timestamp_granularities": ["word"]
+    }
+)
+
+result = response.json()
+for word in result.get("words", []):
+    print(f"{word['start']:.2f}s - {word['end']:.2f}s: {word['word']}")
+```
+
+## Requirements
+
+- Python 3.13+
+- CUDA-compatible GPU (recommended)
+- ~2-5GB VRAM depending on model
+
+## Troubleshooting
+
+**GPU out of memory**: Try smaller model or CPU
+```bash
+uv run server.py --model parakeet-tdt-0.6b-v2
+uv run server.py --device cpu
+```
+
+**Port in use**: Change port
+```bash
+uv run server.py --port 9999
+```
+
+## License
+
+- Canary: NVIDIA AI Foundation Models Community License
+- Parakeet: CC-BY-4.0
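
Because the server exposes the standard OpenAI `/v1/audio/transcriptions` route, the official `openai` Python client can likely be pointed at it as well. A hedged sketch, assuming `pip install openai`, a dummy API key (the local server does not check one), and that the server honors only the parameters shown in its README:

```python
# Sketch: the official openai client against the local NVIDIA ASR server.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:9898/v1", api_key="not-needed")

with open("audio.wav", "rb") as f:
    transcription = client.audio.transcriptions.create(
        model="parakeet-tdt-0.6b-v2",
        file=f,
    )

print(transcription.text)
```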

agent_cli/scripts/nvidia-asr-server/pyproject.toml
@@ -0,0 +1,27 @@
+[project]
+name = "nvidia-asr-server"
+version = "1.0.0"
+description = "NVIDIA ASR server with OpenAI-compatible API"
+readme = "README.md"
+requires-python = ">=3.13"
+dependencies = [
+    "fastapi[standard]>=0.115.0",
+    "torch>=2.5.0",
+    "soundfile>=0.12.1",
+    "sacrebleu>=2.4.0",
+    "typer>=0.9.0",
+    "nemo-toolkit[asr,tts] @ git+https://github.com/NVIDIA/NeMo.git",
+]
+
+[tool.uv.sources]
+torch = [{ index = "pytorch-cu124" }]
+
+[[tool.uv.index]]
+name = "pytorch-cu124"
+url = "https://download.pytorch.org/whl/cu124"
+explicit = true
+
+[tool.uv]
+override-dependencies = [
+    "ml-dtypes>=0.5.0",
+]
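
The `[tool.uv.sources]` / `[[tool.uv.index]]` pair pins `torch` to the CUDA 12.4 wheel index while `explicit = true` keeps every other dependency on PyPI. A quick way to confirm which build uv actually resolved, sketched under the assumption you run it from this directory with `uv run python`:

```python
# Sketch: confirm the cu124 torch wheel from the pinned index was installed.
import torch

print(torch.__version__)          # expected to end in "+cu124"
print(torch.version.cuda)         # expected "12.4" with the pinned index
print(torch.cuda.is_available())  # True only if a usable GPU is present
```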

agent_cli/scripts/nvidia-asr-server/server.py
@@ -0,0 +1,255 @@
+#!/usr/bin/env -S uv run
+"""NVIDIA ASR server with OpenAI-compatible API.
+
+Supports multiple NVIDIA ASR models:
+- nvidia/canary-qwen-2.5b (default): Multilingual ASR with translation capabilities
+- nvidia/parakeet-tdt-0.6b-v2: High-quality English ASR with timestamps
+
+Usage:
+    cd scripts/nvidia-asr-server
+    uv run server.py
+    uv run server.py --model parakeet-tdt-0.6b-v2
+    uv run server.py --port 9090 --device cuda:1
+"""
+
+from __future__ import annotations
+
+import shutil
+import subprocess
+import tempfile
+import traceback
+from contextlib import asynccontextmanager, suppress
+from dataclasses import dataclass
+from enum import Enum
+from pathlib import Path
+from typing import TYPE_CHECKING, Annotated, Any
+
+import torch
+import typer
+import uvicorn
+from fastapi import FastAPI, File, Form, HTTPException, UploadFile
+from fastapi.responses import JSONResponse
+
+if TYPE_CHECKING:
+    from collections.abc import AsyncGenerator
+    from typing import TypedDict
+
+    class TranscriptionResult(TypedDict, total=False):
+        """Transcription result with optional word-level timestamps."""
+
+        text: str
+        words: list[dict[str, Any]]
+
+
+class ModelType(str, Enum):
+    """Supported ASR models."""
+
+    CANARY = "canary-qwen-2.5b"
+    PARAKEET = "parakeet-tdt-0.6b-v2"
+
+
+@dataclass
+class ServerConfig:
+    """Server configuration."""
+
+    model_type: ModelType
+    device: str
+    port: int
+
+
+def select_best_gpu() -> str:
+    """Select the GPU with the most free memory, or CPU if no GPU available."""
+    if not torch.cuda.is_available():
+        return "cpu"
+
+    if torch.cuda.device_count() == 1:
+        return "cuda:0"
+
+    best_gpu = max(
+        range(torch.cuda.device_count()),
+        key=lambda i: torch.cuda.mem_get_info(i)[0],
+    )
+    return f"cuda:{best_gpu}"
+
+
+def resample_audio(input_path: str) -> str:
+    """Resample audio to 16kHz mono WAV using ffmpeg."""
+    out_path = f"{input_path}_16k.wav"
+    cmd = [
+        "ffmpeg",
+        "-y",
+        "-i",
+        input_path,
+        "-ar",
+        "16000",
+        "-ac",
+        "1",
+        out_path,
+    ]
+    result = subprocess.run(cmd, capture_output=True, check=False)
+    if result.returncode != 0:
+        stderr = result.stderr.decode() if result.stderr else "No error output"
+        msg = f"ffmpeg failed: {stderr}"
+        raise RuntimeError(msg)
+    return out_path
+
+
+def load_asr_model(config: ServerConfig) -> Any:
+    """Load the appropriate ASR model based on configuration."""
+    import nemo.collections.asr as nemo_asr  # noqa: PLC0415
+    from nemo.collections.speechlm2.models import SALM  # noqa: PLC0415
+
+    model_name = f"nvidia/{config.model_type.value}"
+
+    # Print device info
+    if config.device.startswith("cuda"):
+        gpu_id = int(config.device.split(":")[1]) if ":" in config.device else 0
+        free_mem, total_mem = torch.cuda.mem_get_info(gpu_id)
+        free_gb = free_mem / 1024**3
+        total_gb = total_mem / 1024**3
+        print(
+            f"Loading {model_name} on {config.device} ({free_gb:.1f}GB / {total_gb:.1f}GB)",
+            flush=True,
+        )
+    else:
+        print(f"Loading {model_name} on {config.device}", flush=True)
+
+    if config.model_type == ModelType.CANARY:
+        model = SALM.from_pretrained(model_name)
+    elif config.model_type == ModelType.PARAKEET:
+        model = nemo_asr.models.ASRModel.from_pretrained(model_name)
+    else:
+        msg = f"Unsupported model type: {config.model_type}"
+        raise ValueError(msg)
+
+    return model.to(config.device).eval()
+
+
+asr_model: Any = None
+config: ServerConfig | None = None
+
+
+@asynccontextmanager
+async def lifespan(_app: FastAPI) -> AsyncGenerator[None]:
+    """Load the ASR model on startup."""
+    global asr_model
+    assert config is not None
+    asr_model = load_asr_model(config)
+    yield
+
+
+app = FastAPI(lifespan=lifespan)
+
+
+def transcribe_canary(audio_path: str, prompt: str | None) -> str:
+    """Transcribe audio using Canary model."""
+    user_prompt = prompt or "Transcribe the following:"
+    full_prompt = f"{user_prompt} {asr_model.audio_locator_tag}"
+
+    prompts = [[{"role": "user", "content": full_prompt, "audio": [audio_path]}]]
+    answer_ids = asr_model.generate(prompts=prompts, max_new_tokens=128)
+    return asr_model.tokenizer.ids_to_text(answer_ids[0].cpu())
+
+
+def transcribe_parakeet(
+    audio_path: str,
+    timestamp_granularities: list[str] | None,
+) -> TranscriptionResult:
+    """Transcribe audio using Parakeet model."""
+    enable_timestamps = bool(timestamp_granularities)
+    output = asr_model.transcribe([audio_path], timestamps=enable_timestamps)
+
+    result: TranscriptionResult = {"text": output[0].text}
+
+    if enable_timestamps and timestamp_granularities and "word" in timestamp_granularities:
+        word_timestamps = output[0].timestamp.get("word", [])
+        if word_timestamps:
+            result["words"] = [
+                {"word": w["word"], "start": w["start"], "end": w["end"]} for w in word_timestamps
+            ]
+
+    return result
+
+
+def cleanup_files(*paths: str | None) -> None:
+    """Clean up temporary files."""
+    for p in paths:
+        if p:
+            with suppress(OSError):
+                Path(p).unlink(missing_ok=True)
+
+
+@app.post("/v1/audio/transcriptions", response_model=None)
+async def transcribe(
+    file: Annotated[UploadFile, File()],
+    response_format: Annotated[str, Form()] = "json",
+    prompt: Annotated[str | None, Form()] = None,
+    timestamp_granularities: Annotated[list[str] | None, Form()] = None,
+) -> str | JSONResponse:
+    """Transcribe audio using ASR model with OpenAI-compatible API."""
+    if asr_model is None:
+        raise HTTPException(status_code=503, detail="Model not loaded yet")
+
+    with tempfile.NamedTemporaryFile(delete=False, suffix="") as tmp:
+        tmp_path = tmp.name
+        shutil.copyfileobj(file.file, tmp)
+
+    resampled_path = None
+    try:
+        resampled_path = resample_audio(tmp_path)
+
+        with torch.inference_mode():
+            assert config is not None
+            if config.model_type == ModelType.CANARY:
+                text = transcribe_canary(resampled_path, prompt)
+                return text if response_format == "text" else JSONResponse({"text": text})
+            if config.model_type == ModelType.PARAKEET:
+                result = transcribe_parakeet(resampled_path, timestamp_granularities)
+                return result["text"] if response_format == "text" else JSONResponse(result)
+
+    except Exception as e:
+        traceback.print_exc()
+        raise HTTPException(status_code=500, detail=str(e)) from e
+    finally:
+        cleanup_files(tmp_path, resampled_path)
+
+
+def main(
+    model: Annotated[
+        ModelType,
+        typer.Option("--model", "-m", help="ASR model to use"),
+    ] = ModelType.CANARY,
+    port: Annotated[int, typer.Option("--port", "-p", help="Server port")] = 9898,
+    device: Annotated[
+        str | None,
+        typer.Option(
+            "--device",
+            "-d",
+            help="Device to use (cpu, cuda, cuda:0, etc.). Auto-selects GPU with most free memory if not specified.",
+        ),
+    ] = None,
+) -> None:
+    """Run NVIDIA ASR server with OpenAI-compatible API.
+
+    Supports multiple models:
+    - canary-qwen-2.5b: Multilingual ASR with translation (default)
+    - parakeet-tdt-0.6b-v2: High-quality English ASR with timestamps
+    """
+    global config
+
+    config = ServerConfig(
+        model_type=model,
+        device=device or select_best_gpu(),
+        port=port,
+    )
+
+    print(f"Starting ASR server with model: {model.value}")
+    print(f"Device: {config.device}")
+    print(f"Port: {config.port}")
+    print()
+
+    uvicorn.run(app, host="0.0.0.0", port=config.port)  # noqa: S104
+
+
+if __name__ == "__main__":
+    typer.run(main)
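
`select_best_gpu()` and `resample_audio()` carry no NeMo dependency, so they can be exercised without downloading a model. A small sketch, assuming `ffmpeg` is on `PATH` and an `audio.wav` exists in the working directory; importing `server` builds the FastAPI app but does not load the model, which only happens in the `lifespan` hook:

```python
# Hypothetical smoke test for the helpers in server.py; run from
# scripts/nvidia-asr-server with `uv run python`. No ASR model is loaded.
from server import resample_audio, select_best_gpu

print(select_best_gpu())                  # "cpu", or the cuda:N with most free memory
resampled = resample_audio("audio.wav")   # writes audio.wav_16k.wav via ffmpeg
print(resampled)
```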

agent_cli/scripts/nvidia-asr-server/shell.nix
@@ -0,0 +1,32 @@
+{ pkgs ? import <nixpkgs> { config.allowUnfree = true; } }:
+
+pkgs.mkShell {
+  buildInputs = with pkgs; [
+    # Python and uv
+    python313
+    uv
+
+    # Audio libraries
+    ffmpeg
+  ];
+
+  shellHook = ''
+    # Set up CUDA environment (use system NVIDIA drivers and CUDA libraries)
+    export LD_LIBRARY_PATH=/run/opengl-driver/lib:/run/current-system/sw/lib:$LD_LIBRARY_PATH
+
+    # Tell triton where to find libcuda.so (avoids calling /sbin/ldconfig)
+    export TRITON_LIBCUDA_PATH=/run/opengl-driver/lib
+
+    # PyTorch memory management - avoid fragmentation
+    export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
+
+    # Canary server defaults
+    export CANARY_PORT=9898
+    # CANARY_DEVICE auto-detects GPU with most free memory (override if needed)
+
+    echo "CUDA environment configured (using system NVIDIA drivers)"
+    echo "TRITON_LIBCUDA_PATH: $TRITON_LIBCUDA_PATH"
+    echo "PYTORCH_CUDA_ALLOC_CONF: $PYTORCH_CUDA_ALLOC_CONF"
+    echo "Run 'uv run server.py' to start the server"
+  '';
+}
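
On NixOS the shell hook points the dynamic linker and Triton at the system driver paths rather than Nix store copies. A tiny sketch to confirm those exports are visible once inside the shell (run `nix-shell` in this directory first; the variable names come from the hook above):

```python
# Sketch: confirm the nix-shell exports are visible to Python.
import os

for var in ("LD_LIBRARY_PATH", "TRITON_LIBCUDA_PATH", "PYTORCH_CUDA_ALLOC_CONF", "CANARY_PORT"):
    print(f"{var}={os.environ.get(var)}")
```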