agent_cli-0.70.5-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agent_cli/__init__.py +5 -0
- agent_cli/__main__.py +6 -0
- agent_cli/_extras.json +14 -0
- agent_cli/_requirements/.gitkeep +0 -0
- agent_cli/_requirements/audio.txt +79 -0
- agent_cli/_requirements/faster-whisper.txt +215 -0
- agent_cli/_requirements/kokoro.txt +425 -0
- agent_cli/_requirements/llm.txt +183 -0
- agent_cli/_requirements/memory.txt +355 -0
- agent_cli/_requirements/mlx-whisper.txt +222 -0
- agent_cli/_requirements/piper.txt +176 -0
- agent_cli/_requirements/rag.txt +402 -0
- agent_cli/_requirements/server.txt +154 -0
- agent_cli/_requirements/speed.txt +77 -0
- agent_cli/_requirements/vad.txt +155 -0
- agent_cli/_requirements/wyoming.txt +71 -0
- agent_cli/_tools.py +368 -0
- agent_cli/agents/__init__.py +23 -0
- agent_cli/agents/_voice_agent_common.py +136 -0
- agent_cli/agents/assistant.py +383 -0
- agent_cli/agents/autocorrect.py +284 -0
- agent_cli/agents/chat.py +496 -0
- agent_cli/agents/memory/__init__.py +31 -0
- agent_cli/agents/memory/add.py +190 -0
- agent_cli/agents/memory/proxy.py +160 -0
- agent_cli/agents/rag_proxy.py +128 -0
- agent_cli/agents/speak.py +209 -0
- agent_cli/agents/transcribe.py +671 -0
- agent_cli/agents/transcribe_daemon.py +499 -0
- agent_cli/agents/voice_edit.py +291 -0
- agent_cli/api.py +22 -0
- agent_cli/cli.py +106 -0
- agent_cli/config.py +503 -0
- agent_cli/config_cmd.py +307 -0
- agent_cli/constants.py +27 -0
- agent_cli/core/__init__.py +1 -0
- agent_cli/core/audio.py +461 -0
- agent_cli/core/audio_format.py +299 -0
- agent_cli/core/chroma.py +88 -0
- agent_cli/core/deps.py +191 -0
- agent_cli/core/openai_proxy.py +139 -0
- agent_cli/core/process.py +195 -0
- agent_cli/core/reranker.py +120 -0
- agent_cli/core/sse.py +87 -0
- agent_cli/core/transcription_logger.py +70 -0
- agent_cli/core/utils.py +526 -0
- agent_cli/core/vad.py +175 -0
- agent_cli/core/watch.py +65 -0
- agent_cli/dev/__init__.py +14 -0
- agent_cli/dev/cli.py +1588 -0
- agent_cli/dev/coding_agents/__init__.py +19 -0
- agent_cli/dev/coding_agents/aider.py +24 -0
- agent_cli/dev/coding_agents/base.py +167 -0
- agent_cli/dev/coding_agents/claude.py +39 -0
- agent_cli/dev/coding_agents/codex.py +24 -0
- agent_cli/dev/coding_agents/continue_dev.py +15 -0
- agent_cli/dev/coding_agents/copilot.py +24 -0
- agent_cli/dev/coding_agents/cursor_agent.py +48 -0
- agent_cli/dev/coding_agents/gemini.py +28 -0
- agent_cli/dev/coding_agents/opencode.py +15 -0
- agent_cli/dev/coding_agents/registry.py +49 -0
- agent_cli/dev/editors/__init__.py +19 -0
- agent_cli/dev/editors/base.py +89 -0
- agent_cli/dev/editors/cursor.py +15 -0
- agent_cli/dev/editors/emacs.py +46 -0
- agent_cli/dev/editors/jetbrains.py +56 -0
- agent_cli/dev/editors/nano.py +31 -0
- agent_cli/dev/editors/neovim.py +33 -0
- agent_cli/dev/editors/registry.py +59 -0
- agent_cli/dev/editors/sublime.py +20 -0
- agent_cli/dev/editors/vim.py +42 -0
- agent_cli/dev/editors/vscode.py +15 -0
- agent_cli/dev/editors/zed.py +20 -0
- agent_cli/dev/project.py +568 -0
- agent_cli/dev/registry.py +52 -0
- agent_cli/dev/skill/SKILL.md +141 -0
- agent_cli/dev/skill/examples.md +571 -0
- agent_cli/dev/terminals/__init__.py +19 -0
- agent_cli/dev/terminals/apple_terminal.py +82 -0
- agent_cli/dev/terminals/base.py +56 -0
- agent_cli/dev/terminals/gnome.py +51 -0
- agent_cli/dev/terminals/iterm2.py +84 -0
- agent_cli/dev/terminals/kitty.py +77 -0
- agent_cli/dev/terminals/registry.py +48 -0
- agent_cli/dev/terminals/tmux.py +58 -0
- agent_cli/dev/terminals/warp.py +132 -0
- agent_cli/dev/terminals/zellij.py +78 -0
- agent_cli/dev/worktree.py +856 -0
- agent_cli/docs_gen.py +417 -0
- agent_cli/example-config.toml +185 -0
- agent_cli/install/__init__.py +5 -0
- agent_cli/install/common.py +89 -0
- agent_cli/install/extras.py +174 -0
- agent_cli/install/hotkeys.py +48 -0
- agent_cli/install/services.py +87 -0
- agent_cli/memory/__init__.py +7 -0
- agent_cli/memory/_files.py +250 -0
- agent_cli/memory/_filters.py +63 -0
- agent_cli/memory/_git.py +157 -0
- agent_cli/memory/_indexer.py +142 -0
- agent_cli/memory/_ingest.py +408 -0
- agent_cli/memory/_persistence.py +182 -0
- agent_cli/memory/_prompt.py +91 -0
- agent_cli/memory/_retrieval.py +294 -0
- agent_cli/memory/_store.py +169 -0
- agent_cli/memory/_streaming.py +44 -0
- agent_cli/memory/_tasks.py +48 -0
- agent_cli/memory/api.py +113 -0
- agent_cli/memory/client.py +272 -0
- agent_cli/memory/engine.py +361 -0
- agent_cli/memory/entities.py +43 -0
- agent_cli/memory/models.py +112 -0
- agent_cli/opts.py +433 -0
- agent_cli/py.typed +0 -0
- agent_cli/rag/__init__.py +3 -0
- agent_cli/rag/_indexer.py +67 -0
- agent_cli/rag/_indexing.py +226 -0
- agent_cli/rag/_prompt.py +30 -0
- agent_cli/rag/_retriever.py +156 -0
- agent_cli/rag/_store.py +48 -0
- agent_cli/rag/_utils.py +218 -0
- agent_cli/rag/api.py +175 -0
- agent_cli/rag/client.py +299 -0
- agent_cli/rag/engine.py +302 -0
- agent_cli/rag/models.py +55 -0
- agent_cli/scripts/.runtime/.gitkeep +0 -0
- agent_cli/scripts/__init__.py +1 -0
- agent_cli/scripts/check_plugin_skill_sync.py +50 -0
- agent_cli/scripts/linux-hotkeys/README.md +63 -0
- agent_cli/scripts/linux-hotkeys/toggle-autocorrect.sh +45 -0
- agent_cli/scripts/linux-hotkeys/toggle-transcription.sh +58 -0
- agent_cli/scripts/linux-hotkeys/toggle-voice-edit.sh +58 -0
- agent_cli/scripts/macos-hotkeys/README.md +45 -0
- agent_cli/scripts/macos-hotkeys/skhd-config-example +5 -0
- agent_cli/scripts/macos-hotkeys/toggle-autocorrect.sh +12 -0
- agent_cli/scripts/macos-hotkeys/toggle-transcription.sh +37 -0
- agent_cli/scripts/macos-hotkeys/toggle-voice-edit.sh +37 -0
- agent_cli/scripts/nvidia-asr-server/README.md +99 -0
- agent_cli/scripts/nvidia-asr-server/pyproject.toml +27 -0
- agent_cli/scripts/nvidia-asr-server/server.py +255 -0
- agent_cli/scripts/nvidia-asr-server/shell.nix +32 -0
- agent_cli/scripts/nvidia-asr-server/uv.lock +4654 -0
- agent_cli/scripts/run-openwakeword.sh +11 -0
- agent_cli/scripts/run-piper-windows.ps1 +30 -0
- agent_cli/scripts/run-piper.sh +24 -0
- agent_cli/scripts/run-whisper-linux.sh +40 -0
- agent_cli/scripts/run-whisper-macos.sh +6 -0
- agent_cli/scripts/run-whisper-windows.ps1 +51 -0
- agent_cli/scripts/run-whisper.sh +9 -0
- agent_cli/scripts/run_faster_whisper_server.py +136 -0
- agent_cli/scripts/setup-linux-hotkeys.sh +72 -0
- agent_cli/scripts/setup-linux.sh +108 -0
- agent_cli/scripts/setup-macos-hotkeys.sh +61 -0
- agent_cli/scripts/setup-macos.sh +76 -0
- agent_cli/scripts/setup-windows.ps1 +63 -0
- agent_cli/scripts/start-all-services-windows.ps1 +53 -0
- agent_cli/scripts/start-all-services.sh +178 -0
- agent_cli/scripts/sync_extras.py +138 -0
- agent_cli/server/__init__.py +3 -0
- agent_cli/server/cli.py +721 -0
- agent_cli/server/common.py +222 -0
- agent_cli/server/model_manager.py +288 -0
- agent_cli/server/model_registry.py +225 -0
- agent_cli/server/proxy/__init__.py +3 -0
- agent_cli/server/proxy/api.py +444 -0
- agent_cli/server/streaming.py +67 -0
- agent_cli/server/tts/__init__.py +3 -0
- agent_cli/server/tts/api.py +335 -0
- agent_cli/server/tts/backends/__init__.py +82 -0
- agent_cli/server/tts/backends/base.py +139 -0
- agent_cli/server/tts/backends/kokoro.py +403 -0
- agent_cli/server/tts/backends/piper.py +253 -0
- agent_cli/server/tts/model_manager.py +201 -0
- agent_cli/server/tts/model_registry.py +28 -0
- agent_cli/server/tts/wyoming_handler.py +249 -0
- agent_cli/server/whisper/__init__.py +3 -0
- agent_cli/server/whisper/api.py +413 -0
- agent_cli/server/whisper/backends/__init__.py +89 -0
- agent_cli/server/whisper/backends/base.py +97 -0
- agent_cli/server/whisper/backends/faster_whisper.py +225 -0
- agent_cli/server/whisper/backends/mlx.py +270 -0
- agent_cli/server/whisper/languages.py +116 -0
- agent_cli/server/whisper/model_manager.py +157 -0
- agent_cli/server/whisper/model_registry.py +28 -0
- agent_cli/server/whisper/wyoming_handler.py +203 -0
- agent_cli/services/__init__.py +343 -0
- agent_cli/services/_wyoming_utils.py +64 -0
- agent_cli/services/asr.py +506 -0
- agent_cli/services/llm.py +228 -0
- agent_cli/services/tts.py +450 -0
- agent_cli/services/wake_word.py +142 -0
- agent_cli-0.70.5.dist-info/METADATA +2118 -0
- agent_cli-0.70.5.dist-info/RECORD +196 -0
- agent_cli-0.70.5.dist-info/WHEEL +4 -0
- agent_cli-0.70.5.dist-info/entry_points.txt +4 -0
- agent_cli-0.70.5.dist-info/licenses/LICENSE +21 -0
agent_cli/services/llm.py
@@ -0,0 +1,228 @@
"""Client for interacting with LLMs."""

from __future__ import annotations

import sys
import time
from typing import TYPE_CHECKING

from rich.live import Live

from agent_cli.core.utils import console, live_timer, print_error_message, print_output_panel

if TYPE_CHECKING:
    import logging

    from pydantic_ai import Agent
    from pydantic_ai.models.gemini import GeminiModel
    from pydantic_ai.models.openai import OpenAIModel
    from pydantic_ai.tools import Tool

    from agent_cli import config


def _openai_llm_model(openai_cfg: config.OpenAILLM) -> OpenAIModel:
    from pydantic_ai.models.openai import OpenAIModel  # noqa: PLC0415
    from pydantic_ai.providers.openai import OpenAIProvider  # noqa: PLC0415

    # For custom base URLs (like llama-server), API key might not be required
    if openai_cfg.openai_base_url:
        # Custom endpoint - API key is optional
        provider = OpenAIProvider(
            api_key=openai_cfg.openai_api_key or "dummy",
            base_url=openai_cfg.openai_base_url,
        )
    else:
        # Standard OpenAI - API key is required
        if not openai_cfg.openai_api_key:
            msg = "OpenAI API key is not set."
            raise ValueError(msg)
        provider = OpenAIProvider(api_key=openai_cfg.openai_api_key)

    model_name = openai_cfg.llm_openai_model
    return OpenAIModel(model_name=model_name, provider=provider)


def _ollama_llm_model(ollama_cfg: config.Ollama) -> OpenAIModel:
    from pydantic_ai.models.openai import OpenAIModel  # noqa: PLC0415
    from pydantic_ai.providers.openai import OpenAIProvider  # noqa: PLC0415

    provider = OpenAIProvider(base_url=f"{ollama_cfg.llm_ollama_host}/v1")
    model_name = ollama_cfg.llm_ollama_model
    return OpenAIModel(model_name=model_name, provider=provider)


def _gemini_llm_model(gemini_cfg: config.GeminiLLM) -> GeminiModel:
    from pydantic_ai.models.gemini import GeminiModel  # noqa: PLC0415
    from pydantic_ai.providers.google_gla import GoogleGLAProvider  # noqa: PLC0415

    if not gemini_cfg.gemini_api_key:
        msg = "Gemini API key is not set."
        raise ValueError(msg)
    provider = GoogleGLAProvider(api_key=gemini_cfg.gemini_api_key)
    model_name = gemini_cfg.llm_gemini_model
    return GeminiModel(model_name=model_name, provider=provider)


def create_llm_agent(
    provider_cfg: config.ProviderSelection,
    ollama_cfg: config.Ollama,
    openai_cfg: config.OpenAILLM,
    gemini_cfg: config.GeminiLLM,
    *,
    system_prompt: str | None = None,
    instructions: str | None = None,
    tools: list[Tool] | None = None,
) -> Agent:
    """Construct and return a PydanticAI agent."""
    from pydantic_ai import Agent  # noqa: PLC0415

    if provider_cfg.llm_provider == "openai":
        llm_model = _openai_llm_model(openai_cfg)
    elif provider_cfg.llm_provider == "ollama":
        llm_model = _ollama_llm_model(ollama_cfg)
    elif provider_cfg.llm_provider == "gemini":
        llm_model = _gemini_llm_model(gemini_cfg)

    return Agent(
        model=llm_model,
        system_prompt=system_prompt or (),
        instructions=instructions,
        tools=tools or [],
    )
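
A minimal sketch of how this factory might be driven, e.g. against a local Ollama backend. The field values below are illustrative assumptions, not defaults shipped with the package, and the config classes are assumed to accept these fields as constructor keywords:

    import logging

    from agent_cli import config
    from agent_cli.services.llm import create_llm_agent

    provider_cfg = config.ProviderSelection(llm_provider="ollama")
    ollama_cfg = config.Ollama(
        llm_ollama_host="http://localhost:11434",  # assumed host value
        llm_ollama_model="llama3.2",               # assumed model name
    )
    agent = create_llm_agent(
        provider_cfg,
        ollama_cfg,
        openai_cfg=config.OpenAILLM(),  # unused when llm_provider == "ollama"
        gemini_cfg=config.GeminiLLM(),
        instructions="Answer concisely.",
    )
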
# --- LLM (Editing) Logic ---

INPUT_TEMPLATE = """
{context_block}<original-text>
{original_text}
</original-text>

<instruction>
{instruction}
</instruction>
"""
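
Rendered, the template produces an XML-tagged prompt. With a context block present, a filled-in example (sample values, for illustration only) would look roughly like:

    <context>
    User is editing a git commit message.
    </context>

    <original-text>
    fix bug in tts playback
    </original-text>

    <instruction>
    Rewrite as a conventional commit subject line.
    </instruction>
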
async def get_llm_response(
    *,
    system_prompt: str,
    agent_instructions: str,
    user_input: str,
    provider_cfg: config.ProviderSelection,
    ollama_cfg: config.Ollama,
    openai_cfg: config.OpenAILLM,
    gemini_cfg: config.GeminiLLM,
    logger: logging.Logger,
    live: Live | None = None,
    tools: list[Tool] | None = None,
    quiet: bool = False,
    clipboard: bool = False,
    show_output: bool = False,
    exit_on_error: bool = False,
) -> str | None:
    """Get a response from the LLM with optional clipboard and output handling."""
    agent = create_llm_agent(
        provider_cfg=provider_cfg,
        ollama_cfg=ollama_cfg,
        openai_cfg=openai_cfg,
        gemini_cfg=gemini_cfg,
        system_prompt=system_prompt,
        instructions=agent_instructions,
        tools=tools,
    )

    start_time = time.monotonic()

    try:
        if provider_cfg.llm_provider == "ollama":
            model_name = ollama_cfg.llm_ollama_model
        elif provider_cfg.llm_provider == "openai":
            model_name = openai_cfg.llm_openai_model
        elif provider_cfg.llm_provider == "gemini":
            model_name = gemini_cfg.llm_gemini_model

        async with live_timer(
            live or Live(console=console),
            f"🤖 Applying instruction with {model_name}",
            style="bold yellow",
            quiet=quiet,
        ):
            result = await agent.run(user_input)

        elapsed = time.monotonic() - start_time
        result_text = result.output

        if clipboard:
            import pyperclip  # noqa: PLC0415

            pyperclip.copy(result_text)
            logger.info("Copied result to clipboard.")

        if show_output and not quiet:
            print_output_panel(
                result_text,
                title="✨ Result (Copied to Clipboard)" if clipboard else "✨ Result",
                subtitle=f"[dim]took {elapsed:.2f}s[/dim]",
            )
        elif quiet and clipboard:
            print(result_text)

        return result_text

    except Exception as e:
        logger.exception("An error occurred during LLM processing.")
        if provider_cfg.llm_provider == "openai":
            msg = "Please check your OpenAI API key."
        elif provider_cfg.llm_provider == "gemini":
            msg = "Please check your Gemini API key."
        elif provider_cfg.llm_provider == "ollama":
            msg = f"Please check your Ollama server at [cyan]{ollama_cfg.llm_ollama_host}[/cyan]"
        print_error_message(f"An unexpected LLM error occurred: {e}", msg)
        if exit_on_error:
            sys.exit(1)
        return None


async def process_and_update_clipboard(
    system_prompt: str,
    agent_instructions: str,
    *,
    provider_cfg: config.ProviderSelection,
    ollama_cfg: config.Ollama,
    openai_cfg: config.OpenAILLM,
    gemini_cfg: config.GeminiLLM,
    logger: logging.Logger,
    original_text: str,
    instruction: str,
    clipboard: bool,
    quiet: bool,
    live: Live | None,
    context: str | None = None,
) -> str | None:
    """Processes the text with the LLM, updates the clipboard, and displays the result."""
    context_block = ""
    if context:
        context_block = f"<context>\n{context}\n</context>\n\n"
    user_input = INPUT_TEMPLATE.format(
        context_block=context_block,
        original_text=original_text,
        instruction=instruction,
    )

    return await get_llm_response(
        system_prompt=system_prompt,
        agent_instructions=agent_instructions,
        user_input=user_input,
        provider_cfg=provider_cfg,
        ollama_cfg=ollama_cfg,
        openai_cfg=openai_cfg,
        gemini_cfg=gemini_cfg,
        logger=logger,
        quiet=quiet,
        clipboard=clipboard,
        live=live,
        show_output=True,
        exit_on_error=False,  # Don't exit the server on LLM errors
    )
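
For orientation, a hedged sketch of a call site, reusing the hypothetical config objects from the earlier sketch and assumed to run inside an event loop:

    import logging

    corrected = await process_and_update_clipboard(
        "You are a careful copy editor.",     # system_prompt
        "Apply the instruction faithfully.",  # agent_instructions
        provider_cfg=provider_cfg,
        ollama_cfg=ollama_cfg,
        openai_cfg=config.OpenAILLM(),
        gemini_cfg=config.GeminiLLM(),
        logger=logging.getLogger(__name__),
        original_text="teh quick brown fox",
        instruction="Fix the spelling.",
        clipboard=True,
        quiet=False,
        live=None,
    )
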
agent_cli/services/tts.py
@@ -0,0 +1,450 @@
"""Module for Text-to-Speech using Wyoming or OpenAI."""

from __future__ import annotations

import asyncio
import importlib.util
import io
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING

from rich.live import Live

from agent_cli import config, constants
from agent_cli.core.audio import open_audio_stream, setup_output_stream
from agent_cli.core.audio_format import extract_pcm_from_wav
from agent_cli.core.utils import (
    InteractiveStopEvent,
    live_timer,
    manage_send_receive_tasks,
    print_error_message,
    print_with_style,
)
from agent_cli.services import pcm_to_wav, synthesize_speech_gemini, synthesize_speech_openai
from agent_cli.services._wyoming_utils import wyoming_client_context

if TYPE_CHECKING:
    import logging
    from collections.abc import Awaitable, Callable

    from rich.live import Live
    from wyoming.client import AsyncClient
    from wyoming.tts import Synthesize


has_audiostretchy = importlib.util.find_spec("audiostretchy") is not None


def create_synthesizer(
    provider_cfg: config.ProviderSelection,
    audio_output_cfg: config.AudioOutput,
    wyoming_tts_cfg: config.WyomingTTS,
    openai_tts_cfg: config.OpenAITTS,
    kokoro_tts_cfg: config.KokoroTTS,
    gemini_tts_cfg: config.GeminiTTS | None = None,
) -> Callable[..., Awaitable[bytes | None]]:
    """Return the appropriate synthesizer based on the config."""
    if not audio_output_cfg.enable_tts:
        return _dummy_synthesizer
    if provider_cfg.tts_provider == "openai":
        return partial(
            _synthesize_speech_openai,
            openai_tts_cfg=openai_tts_cfg,
        )
    if provider_cfg.tts_provider == "kokoro":
        return partial(
            _synthesize_speech_kokoro,
            kokoro_tts_cfg=kokoro_tts_cfg,
        )
    if provider_cfg.tts_provider == "gemini":
        assert gemini_tts_cfg is not None, "Gemini TTS config required"
        return partial(_synthesize_speech_gemini, gemini_tts_cfg=gemini_tts_cfg)
    if provider_cfg.tts_provider == "wyoming":
        return partial(_synthesize_speech_wyoming, wyoming_tts_cfg=wyoming_tts_cfg)
    msg = f"Unknown TTS provider: {provider_cfg.tts_provider}"
    raise NotImplementedError(msg)
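
Each backend helper below accepts `**_kwargs`, so the returned coroutine exposes one uniform keyword interface regardless of provider: callers pass every backend's config and each helper picks out only what it needs. A hedged sketch of driving the factory directly (config objects and `logger`/`live` assumed constructed elsewhere, inside an event loop):

    synthesizer = create_synthesizer(
        provider_cfg,
        audio_output_cfg,
        wyoming_tts_cfg,
        openai_tts_cfg,
        kokoro_tts_cfg,
    )
    wav_bytes = await synthesizer(
        text="Hello!",
        wyoming_tts_cfg=wyoming_tts_cfg,
        openai_tts_cfg=openai_tts_cfg,
        kokoro_tts_cfg=kokoro_tts_cfg,
        gemini_tts_cfg=None,
        logger=logger,
        quiet=True,
        live=live,
    )
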
async def handle_tts_playback(
    *,
    text: str,
    provider_cfg: config.ProviderSelection,
    audio_output_cfg: config.AudioOutput,
    wyoming_tts_cfg: config.WyomingTTS,
    openai_tts_cfg: config.OpenAITTS,
    kokoro_tts_cfg: config.KokoroTTS,
    gemini_tts_cfg: config.GeminiTTS | None = None,
    save_file: Path | None,
    quiet: bool,
    logger: logging.Logger,
    play_audio: bool = True,
    status_message: str = "🔊 Speaking...",
    description: str = "Audio",
    stop_event: InteractiveStopEvent | None = None,
    live: Live,
) -> bytes | None:
    """Handle TTS synthesis, playback, and file saving."""
    try:
        if not quiet and status_message:
            print_with_style(status_message, style="blue")

        audio_data = await _speak_text(
            text=text,
            provider_cfg=provider_cfg,
            audio_output_cfg=audio_output_cfg,
            wyoming_tts_cfg=wyoming_tts_cfg,
            openai_tts_cfg=openai_tts_cfg,
            kokoro_tts_cfg=kokoro_tts_cfg,
            gemini_tts_cfg=gemini_tts_cfg,
            logger=logger,
            quiet=quiet,
            play_audio_flag=play_audio,
            stop_event=stop_event,
            live=live,
        )

        if save_file and audio_data:
            await _save_audio_file(
                audio_data,
                save_file,
                quiet,
                logger,
                description=description,
            )

        return audio_data

    except (OSError, ConnectionError, TimeoutError) as e:
        logger.warning("Failed TTS operation: %s", e)
        if not quiet:
            print_with_style(f"⚠️ TTS failed: {e}", style="yellow")
        return None
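
This is the module's lone public entry point (see `__all__` at the bottom). A hedged usage sketch from an agent loop, with the config objects and `logger` assumed in scope and the call made inside an event loop:

    from pathlib import Path

    from rich.live import Live

    with Live() as live:
        await handle_tts_playback(
            text="All done.",
            provider_cfg=provider_cfg,
            audio_output_cfg=audio_output_cfg,
            wyoming_tts_cfg=wyoming_tts_cfg,
            openai_tts_cfg=openai_tts_cfg,
            kokoro_tts_cfg=kokoro_tts_cfg,
            save_file=Path("reply.wav"),  # optionally persist the WAV
            quiet=False,
            logger=logger,
            live=live,
        )
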
# --- Helper Functions ---


def _create_synthesis_request(
    text: str,
    *,
    voice_name: str | None = None,
    language: str | None = None,
    speaker: str | None = None,
) -> Synthesize:
    """Create a synthesis request with optional voice parameters."""
    from wyoming.tts import Synthesize, SynthesizeVoice  # noqa: PLC0415

    synthesize_event = Synthesize(text=text)

    # Add voice parameters if specified
    if voice_name or language or speaker:
        synthesize_event.voice = SynthesizeVoice(
            name=voice_name,
            language=language,
            speaker=speaker,
        )

    return synthesize_event


async def _process_audio_events(
    client: AsyncClient,
    logger: logging.Logger,
) -> tuple[bytes, int | None, int | None, int | None]:
    """Process audio events from TTS server and return audio data with metadata."""
    from wyoming.audio import AudioChunk, AudioStart, AudioStop  # noqa: PLC0415

    audio_data = io.BytesIO()
    sample_rate = None
    sample_width = None
    channels = None

    while True:
        event = await client.read_event()
        if event is None:
            logger.warning("Connection to TTS server lost.")
            break

        if AudioStart.is_type(event.type):
            audio_start = AudioStart.from_event(event)
            sample_rate = audio_start.rate
            sample_width = audio_start.width
            channels = audio_start.channels
            logger.debug(
                "Audio stream started: %dHz, %d channels, %d bytes/sample",
                sample_rate,
                channels,
                sample_width,
            )

        elif AudioChunk.is_type(event.type):
            chunk = AudioChunk.from_event(event)
            audio_data.write(chunk.audio)
            logger.debug("Received %d bytes of audio", len(chunk.audio))

        elif AudioStop.is_type(event.type):
            logger.debug("Audio stream completed")
            break
        else:
            logger.debug("Ignoring event type: %s", event.type)

    return audio_data.getvalue(), sample_rate, sample_width, channels


async def _dummy_synthesizer(**_kwargs: object) -> bytes | None:
    """A dummy synthesizer that does nothing."""
    return None


async def _synthesize_speech_openai(
    *,
    text: str,
    openai_tts_cfg: config.OpenAITTS,
    logger: logging.Logger,
    **_kwargs: object,
) -> bytes | None:
    """Synthesize speech from text using OpenAI-compatible TTS server."""
    return await synthesize_speech_openai(
        text=text,
        openai_tts_cfg=openai_tts_cfg,
        logger=logger,
    )


async def _synthesize_speech_kokoro(
    *,
    text: str,
    kokoro_tts_cfg: config.KokoroTTS,
    logger: logging.Logger,
    **_kwargs: object,
) -> bytes | None:
    """Synthesize speech from text using Kokoro TTS server via OpenAI client."""
    openai_tts_cfg = config.OpenAITTS(
        tts_openai_model=kokoro_tts_cfg.tts_kokoro_model,
        tts_openai_voice=kokoro_tts_cfg.tts_kokoro_voice,
        tts_openai_base_url=kokoro_tts_cfg.tts_kokoro_host,
    )
    try:
        return await synthesize_speech_openai(
            text=text,
            openai_tts_cfg=openai_tts_cfg,
            logger=logger,
        )
    except Exception:
        logger.exception("Error during Kokoro speech synthesis")
        return None


async def _synthesize_speech_gemini(
    *,
    text: str,
    gemini_tts_cfg: config.GeminiTTS,
    logger: logging.Logger,
    **_kwargs: object,
) -> bytes | None:
    """Synthesize speech from text using Gemini TTS."""
    return await synthesize_speech_gemini(text=text, gemini_tts_cfg=gemini_tts_cfg, logger=logger)


async def _synthesize_speech_wyoming(
    *,
    text: str,
    wyoming_tts_cfg: config.WyomingTTS,
    logger: logging.Logger,
    quiet: bool = False,
    live: Live,
    **_kwargs: object,
) -> bytes | None:
    """Synthesize speech from text using Wyoming TTS server."""
    try:
        async with wyoming_client_context(
            wyoming_tts_cfg.tts_wyoming_ip,
            wyoming_tts_cfg.tts_wyoming_port,
            "TTS",
            logger,
            quiet=quiet,
        ) as client:
            async with live_timer(live, "🔊 Synthesizing text", style="blue", quiet=quiet):
                synthesize_event = _create_synthesis_request(
                    text,
                    voice_name=wyoming_tts_cfg.tts_wyoming_voice,
                    language=wyoming_tts_cfg.tts_wyoming_language,
                    speaker=wyoming_tts_cfg.tts_wyoming_speaker,
                )
                _send_task, recv_task = await manage_send_receive_tasks(
                    client.write_event(synthesize_event.event()),
                    _process_audio_events(client, logger),
                )
            audio_data, sample_rate, sample_width, channels = recv_task.result()
            if sample_rate and sample_width and channels and audio_data:
                wav_data = pcm_to_wav(
                    audio_data,
                    sample_rate=sample_rate,
                    sample_width=sample_width,
                    channels=channels,
                )
                logger.info("Speech synthesis completed: %d bytes", len(wav_data))
                return wav_data
            logger.warning("No audio data received from TTS server")
            return None
    except (ConnectionRefusedError, Exception):
        return None


def _apply_speed_adjustment(
    audio_data: io.BytesIO,
    speed: float,
) -> tuple[io.BytesIO, bool]:
    """Apply speed adjustment to audio data."""
    if speed == 1.0 or not has_audiostretchy:
        return audio_data, False
    from audiostretchy.stretch import AudioStretch  # noqa: PLC0415

    audio_data.seek(0)
    input_copy = io.BytesIO(audio_data.read())
    audio_stretch = AudioStretch()
    audio_stretch.open(file=input_copy, format="wav")
    audio_stretch.stretch(ratio=1 / speed)
    out = io.BytesIO()
    audio_stretch.save_wav(out, close=False)
    out.seek(0)
    return out, True
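
Two speed paths meet here: when audiostretchy is available, `stretch(ratio=1 / speed)` changes duration while preserving pitch; when it is not, `_play_audio` below falls back to scaling the playback sample rate instead, which changes pitch along with duration. A worked example with hypothetical numbers:

    # speed = 2.0 on a 22050 Hz source:
    #   audiostretchy path: stretch(ratio=0.5) halves the duration, played back at 22050 Hz
    #   fallback path:      original samples played at int(22050 * 2.0) = 44100 Hz,
    #                       also half the duration, but one octave higher in pitch
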
async def _play_audio(
    audio_data: bytes,
    logger: logging.Logger,
    *,
    audio_output_cfg: config.AudioOutput,
    quiet: bool = False,
    stop_event: InteractiveStopEvent | None = None,
    live: Live,
) -> None:
    """Play WAV audio data using SoundDevice."""
    import numpy as np  # noqa: PLC0415

    try:
        wav_io = io.BytesIO(audio_data)
        speed = audio_output_cfg.tts_speed
        wav_io, speed_changed = _apply_speed_adjustment(wav_io, speed)
        wav = extract_pcm_from_wav(wav_io.read())
        sample_rate = wav.sample_rate if speed_changed else int(wav.sample_rate * speed)
        base_msg = f"🔊 Playing audio at {speed}x speed" if speed != 1.0 else "🔊 Playing audio"
        async with live_timer(live, base_msg, style="blue", quiet=quiet):
            stream_config = setup_output_stream(
                audio_output_cfg.output_device_index,
                sample_rate=sample_rate,
                sample_width=wav.sample_width,
                channels=wav.num_channels,
            )
            dtype = stream_config.dtype

            with open_audio_stream(stream_config) as stream:
                chunk_size_frames = constants.AUDIO_CHUNK_SIZE
                bytes_per_frame = wav.num_channels * wav.sample_width
                chunk_bytes = chunk_size_frames * bytes_per_frame

                for i in range(0, len(wav.pcm_data), chunk_bytes):
                    if stop_event and stop_event.is_set():
                        logger.info("Audio playback interrupted")
                        if not quiet:
                            print_with_style("⚠️ Audio playback interrupted", style="yellow")
                        break
                    chunk = wav.pcm_data[i : i + chunk_bytes]

                    # Convert bytes to numpy array for sounddevice
                    audio_array = np.frombuffer(chunk, dtype=dtype)
                    if wav.num_channels > 1:
                        audio_array = audio_array.reshape(-1, wav.num_channels)

                    stream.write(audio_array)
                    await asyncio.sleep(0)
        if not (stop_event and stop_event.is_set()):
            logger.info("Audio playback completed (speed: %.1fx)", speed)
            if not quiet:
                print_with_style("✅ Audio playback finished")
    except Exception as e:
        logger.exception("Error during audio playback")
        if not quiet:
            print_error_message(f"Playback error: {e}")


async def _speak_text(
    *,
    text: str,
    provider_cfg: config.ProviderSelection,
    audio_output_cfg: config.AudioOutput,
    wyoming_tts_cfg: config.WyomingTTS,
    openai_tts_cfg: config.OpenAITTS,
    kokoro_tts_cfg: config.KokoroTTS,
    gemini_tts_cfg: config.GeminiTTS | None = None,
    logger: logging.Logger,
    quiet: bool = False,
    play_audio_flag: bool = True,
    stop_event: InteractiveStopEvent | None = None,
    live: Live,
) -> bytes | None:
    """Synthesize and optionally play speech from text."""
    synthesizer = create_synthesizer(
        provider_cfg,
        audio_output_cfg,
        wyoming_tts_cfg,
        openai_tts_cfg,
        kokoro_tts_cfg,
        gemini_tts_cfg,
    )
    audio_data = None
    try:
        async with live_timer(live, "🔊 Synthesizing text", style="blue", quiet=quiet):
            audio_data = await synthesizer(
                text=text,
                wyoming_tts_cfg=wyoming_tts_cfg,
                openai_tts_cfg=openai_tts_cfg,
                kokoro_tts_cfg=kokoro_tts_cfg,
                gemini_tts_cfg=gemini_tts_cfg,
                logger=logger,
                quiet=quiet,
                live=live,
            )
    except Exception:
        logger.exception("Error during speech synthesis")
        return None

    if audio_data and play_audio_flag:
        await _play_audio(
            audio_data,
            logger,
            audio_output_cfg=audio_output_cfg,
            quiet=quiet,
            stop_event=stop_event,
            live=live,
        )

    return audio_data


async def _save_audio_file(
    audio_data: bytes,
    save_file: Path,
    quiet: bool,
    logger: logging.Logger,
    *,
    description: str = "Audio",
) -> None:
    try:
        save_path = Path(save_file)
        await asyncio.to_thread(save_path.write_bytes, audio_data)
        if not quiet:
            print_with_style(f"💾 {description} saved to {save_file}")
        logger.info("%s saved to %s", description, save_file)
    except (OSError, PermissionError) as e:
        logger.exception("Failed to save %s", description.lower())
        if not quiet:
            print_with_style(
                f"❌ Failed to save {description.lower()}: {e}",
                style="red",
            )


__all__ = ["handle_tts_playback"]