amd-gaia 0.14.3__py3-none-any.whl → 0.15.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {amd_gaia-0.14.3.dist-info → amd_gaia-0.15.1.dist-info}/METADATA +223 -223
- amd_gaia-0.15.1.dist-info/RECORD +178 -0
- {amd_gaia-0.14.3.dist-info → amd_gaia-0.15.1.dist-info}/entry_points.txt +1 -0
- {amd_gaia-0.14.3.dist-info → amd_gaia-0.15.1.dist-info}/licenses/LICENSE.md +20 -20
- gaia/__init__.py +29 -29
- gaia/agents/__init__.py +19 -19
- gaia/agents/base/__init__.py +9 -9
- gaia/agents/base/agent.py +2177 -2177
- gaia/agents/base/api_agent.py +120 -120
- gaia/agents/base/console.py +1841 -1841
- gaia/agents/base/errors.py +237 -237
- gaia/agents/base/mcp_agent.py +86 -86
- gaia/agents/base/tools.py +83 -83
- gaia/agents/blender/agent.py +556 -556
- gaia/agents/blender/agent_simple.py +133 -135
- gaia/agents/blender/app.py +211 -211
- gaia/agents/blender/app_simple.py +41 -41
- gaia/agents/blender/core/__init__.py +16 -16
- gaia/agents/blender/core/materials.py +506 -506
- gaia/agents/blender/core/objects.py +316 -316
- gaia/agents/blender/core/rendering.py +225 -225
- gaia/agents/blender/core/scene.py +220 -220
- gaia/agents/blender/core/view.py +146 -146
- gaia/agents/chat/__init__.py +9 -9
- gaia/agents/chat/agent.py +835 -835
- gaia/agents/chat/app.py +1058 -1058
- gaia/agents/chat/session.py +508 -508
- gaia/agents/chat/tools/__init__.py +15 -15
- gaia/agents/chat/tools/file_tools.py +96 -96
- gaia/agents/chat/tools/rag_tools.py +1729 -1729
- gaia/agents/chat/tools/shell_tools.py +436 -436
- gaia/agents/code/__init__.py +7 -7
- gaia/agents/code/agent.py +549 -549
- gaia/agents/code/cli.py +377 -0
- gaia/agents/code/models.py +135 -135
- gaia/agents/code/orchestration/__init__.py +24 -24
- gaia/agents/code/orchestration/checklist_executor.py +1763 -1763
- gaia/agents/code/orchestration/checklist_generator.py +713 -713
- gaia/agents/code/orchestration/factories/__init__.py +9 -9
- gaia/agents/code/orchestration/factories/base.py +63 -63
- gaia/agents/code/orchestration/factories/nextjs_factory.py +118 -118
- gaia/agents/code/orchestration/factories/python_factory.py +106 -106
- gaia/agents/code/orchestration/orchestrator.py +841 -841
- gaia/agents/code/orchestration/project_analyzer.py +391 -391
- gaia/agents/code/orchestration/steps/__init__.py +67 -67
- gaia/agents/code/orchestration/steps/base.py +188 -188
- gaia/agents/code/orchestration/steps/error_handler.py +314 -314
- gaia/agents/code/orchestration/steps/nextjs.py +828 -828
- gaia/agents/code/orchestration/steps/python.py +307 -307
- gaia/agents/code/orchestration/template_catalog.py +469 -469
- gaia/agents/code/orchestration/workflows/__init__.py +14 -14
- gaia/agents/code/orchestration/workflows/base.py +80 -80
- gaia/agents/code/orchestration/workflows/nextjs.py +186 -186
- gaia/agents/code/orchestration/workflows/python.py +94 -94
- gaia/agents/code/prompts/__init__.py +11 -11
- gaia/agents/code/prompts/base_prompt.py +77 -77
- gaia/agents/code/prompts/code_patterns.py +2036 -2036
- gaia/agents/code/prompts/nextjs_prompt.py +40 -40
- gaia/agents/code/prompts/python_prompt.py +109 -109
- gaia/agents/code/schema_inference.py +365 -365
- gaia/agents/code/system_prompt.py +41 -41
- gaia/agents/code/tools/__init__.py +42 -42
- gaia/agents/code/tools/cli_tools.py +1138 -1138
- gaia/agents/code/tools/code_formatting.py +319 -319
- gaia/agents/code/tools/code_tools.py +769 -769
- gaia/agents/code/tools/error_fixing.py +1347 -1347
- gaia/agents/code/tools/external_tools.py +180 -180
- gaia/agents/code/tools/file_io.py +845 -845
- gaia/agents/code/tools/prisma_tools.py +190 -190
- gaia/agents/code/tools/project_management.py +1016 -1016
- gaia/agents/code/tools/testing.py +321 -321
- gaia/agents/code/tools/typescript_tools.py +122 -122
- gaia/agents/code/tools/validation_parsing.py +461 -461
- gaia/agents/code/tools/validation_tools.py +806 -806
- gaia/agents/code/tools/web_dev_tools.py +1758 -1758
- gaia/agents/code/validators/__init__.py +16 -16
- gaia/agents/code/validators/antipattern_checker.py +241 -241
- gaia/agents/code/validators/ast_analyzer.py +197 -197
- gaia/agents/code/validators/requirements_validator.py +145 -145
- gaia/agents/code/validators/syntax_validator.py +171 -171
- gaia/agents/docker/__init__.py +7 -7
- gaia/agents/docker/agent.py +642 -642
- gaia/agents/emr/__init__.py +8 -8
- gaia/agents/emr/agent.py +1506 -1506
- gaia/agents/emr/cli.py +1322 -1322
- gaia/agents/emr/constants.py +475 -475
- gaia/agents/emr/dashboard/__init__.py +4 -4
- gaia/agents/emr/dashboard/server.py +1974 -1974
- gaia/agents/jira/__init__.py +11 -11
- gaia/agents/jira/agent.py +894 -894
- gaia/agents/jira/jql_templates.py +299 -299
- gaia/agents/routing/__init__.py +7 -7
- gaia/agents/routing/agent.py +567 -570
- gaia/agents/routing/system_prompt.py +75 -75
- gaia/agents/summarize/__init__.py +11 -0
- gaia/agents/summarize/agent.py +885 -0
- gaia/agents/summarize/prompts.py +129 -0
- gaia/api/__init__.py +23 -23
- gaia/api/agent_registry.py +238 -238
- gaia/api/app.py +305 -305
- gaia/api/openai_server.py +575 -575
- gaia/api/schemas.py +186 -186
- gaia/api/sse_handler.py +373 -373
- gaia/apps/__init__.py +4 -4
- gaia/apps/llm/__init__.py +6 -6
- gaia/apps/llm/app.py +173 -169
- gaia/apps/summarize/app.py +116 -633
- gaia/apps/summarize/html_viewer.py +133 -133
- gaia/apps/summarize/pdf_formatter.py +284 -284
- gaia/audio/__init__.py +2 -2
- gaia/audio/audio_client.py +439 -439
- gaia/audio/audio_recorder.py +269 -269
- gaia/audio/kokoro_tts.py +599 -599
- gaia/audio/whisper_asr.py +432 -432
- gaia/chat/__init__.py +16 -16
- gaia/chat/app.py +430 -430
- gaia/chat/prompts.py +522 -522
- gaia/chat/sdk.py +1228 -1225
- gaia/cli.py +5481 -5621
- gaia/database/__init__.py +10 -10
- gaia/database/agent.py +176 -176
- gaia/database/mixin.py +290 -290
- gaia/database/testing.py +64 -64
- gaia/eval/batch_experiment.py +2332 -2332
- gaia/eval/claude.py +542 -542
- gaia/eval/config.py +37 -37
- gaia/eval/email_generator.py +512 -512
- gaia/eval/eval.py +3179 -3179
- gaia/eval/groundtruth.py +1130 -1130
- gaia/eval/transcript_generator.py +582 -582
- gaia/eval/webapp/README.md +167 -167
- gaia/eval/webapp/package-lock.json +875 -875
- gaia/eval/webapp/package.json +20 -20
- gaia/eval/webapp/public/app.js +3402 -3402
- gaia/eval/webapp/public/index.html +87 -87
- gaia/eval/webapp/public/styles.css +3661 -3661
- gaia/eval/webapp/server.js +415 -415
- gaia/eval/webapp/test-setup.js +72 -72
- gaia/llm/__init__.py +9 -2
- gaia/llm/base_client.py +60 -0
- gaia/llm/exceptions.py +12 -0
- gaia/llm/factory.py +70 -0
- gaia/llm/lemonade_client.py +3236 -3221
- gaia/llm/lemonade_manager.py +294 -294
- gaia/llm/providers/__init__.py +9 -0
- gaia/llm/providers/claude.py +108 -0
- gaia/llm/providers/lemonade.py +120 -0
- gaia/llm/providers/openai_provider.py +79 -0
- gaia/llm/vlm_client.py +382 -382
- gaia/logger.py +189 -189
- gaia/mcp/agent_mcp_server.py +245 -245
- gaia/mcp/blender_mcp_client.py +138 -138
- gaia/mcp/blender_mcp_server.py +648 -648
- gaia/mcp/context7_cache.py +332 -332
- gaia/mcp/external_services.py +518 -518
- gaia/mcp/mcp_bridge.py +811 -550
- gaia/mcp/servers/__init__.py +6 -6
- gaia/mcp/servers/docker_mcp.py +83 -83
- gaia/perf_analysis.py +361 -0
- gaia/rag/__init__.py +10 -10
- gaia/rag/app.py +293 -293
- gaia/rag/demo.py +304 -304
- gaia/rag/pdf_utils.py +235 -235
- gaia/rag/sdk.py +2194 -2194
- gaia/security.py +163 -163
- gaia/talk/app.py +289 -289
- gaia/talk/sdk.py +538 -538
- gaia/testing/__init__.py +87 -87
- gaia/testing/assertions.py +330 -330
- gaia/testing/fixtures.py +333 -333
- gaia/testing/mocks.py +493 -493
- gaia/util.py +46 -46
- gaia/utils/__init__.py +33 -33
- gaia/utils/file_watcher.py +675 -675
- gaia/utils/parsing.py +223 -223
- gaia/version.py +100 -100
- amd_gaia-0.14.3.dist-info/RECORD +0 -168
- gaia/agents/code/app.py +0 -266
- gaia/llm/llm_client.py +0 -729
- {amd_gaia-0.14.3.dist-info → amd_gaia-0.15.1.dist-info}/WHEEL +0 -0
- {amd_gaia-0.14.3.dist-info → amd_gaia-0.15.1.dist-info}/top_level.txt +0 -0
gaia/audio/audio_client.py
CHANGED
|
@@ -1,439 +1,439 @@
|
|
|
1
|
-
# Copyright(C) 2025-2026 Advanced Micro Devices, Inc. All rights reserved.
|
|
2
|
-
# SPDX-License-Identifier: MIT
|
|
3
|
-
|
|
4
|
-
import asyncio
|
|
5
|
-
import queue
|
|
6
|
-
import threading
|
|
7
|
-
import time
|
|
8
|
-
|
|
9
|
-
from gaia.llm.llm_client import LLMClient
|
|
10
|
-
from gaia.logger import get_logger
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
class AudioClient:
|
|
14
|
-
"""Handles all audio-related functionality including TTS, ASR, and voice chat."""
|
|
15
|
-
|
|
16
|
-
def __init__(
|
|
17
|
-
self,
|
|
18
|
-
whisper_model_size="base",
|
|
19
|
-
audio_device_index=None, # Use default input device
|
|
20
|
-
silence_threshold=0.5,
|
|
21
|
-
enable_tts=True,
|
|
22
|
-
logging_level="INFO",
|
|
23
|
-
use_claude=False,
|
|
24
|
-
use_chatgpt=False,
|
|
25
|
-
system_prompt=None,
|
|
26
|
-
):
|
|
27
|
-
self.log = get_logger(__name__)
|
|
28
|
-
self.log.setLevel(getattr(__import__("logging"), logging_level))
|
|
29
|
-
|
|
30
|
-
# Audio configuration
|
|
31
|
-
self.whisper_model_size = whisper_model_size
|
|
32
|
-
self.audio_device_index = audio_device_index
|
|
33
|
-
self.silence_threshold = silence_threshold
|
|
34
|
-
self.enable_tts = enable_tts
|
|
35
|
-
|
|
36
|
-
# Audio state
|
|
37
|
-
self.is_speaking = False
|
|
38
|
-
self.tts_thread = None
|
|
39
|
-
self.whisper_asr = None
|
|
40
|
-
self.transcription_queue = queue.Queue()
|
|
41
|
-
self.tts = None
|
|
42
|
-
|
|
43
|
-
# Initialize LLM client
|
|
44
|
-
self.llm_client = LLMClient(
|
|
45
|
-
use_claude=use_claude,
|
|
46
|
-
use_openai=use_chatgpt,
|
|
47
|
-
system_prompt=system_prompt,
|
|
48
|
-
)
|
|
49
|
-
|
|
50
|
-
self.log.info("Audio client initialized.")
|
|
51
|
-
|
|
52
|
-
async def start_voice_chat(self, message_processor_callback):
|
|
53
|
-
"""Start a voice-based chat session."""
|
|
54
|
-
try:
|
|
55
|
-
self.log.debug("Initializing voice chat...")
|
|
56
|
-
print(
|
|
57
|
-
"Starting voice chat.\n"
|
|
58
|
-
"Say 'stop' to quit application "
|
|
59
|
-
"or 'restart' to clear the chat history.\n"
|
|
60
|
-
"Press Enter key to stop during audio playback."
|
|
61
|
-
)
|
|
62
|
-
|
|
63
|
-
# Initialize TTS before starting voice chat
|
|
64
|
-
self.initialize_tts()
|
|
65
|
-
|
|
66
|
-
from gaia.audio.whisper_asr import WhisperAsr
|
|
67
|
-
|
|
68
|
-
# Create WhisperAsr with custom thresholds
|
|
69
|
-
# Your audio shows energy levels of 0.02-0.03 when speaking
|
|
70
|
-
self.whisper_asr = WhisperAsr(
|
|
71
|
-
model_size=self.whisper_model_size,
|
|
72
|
-
device_index=self.audio_device_index,
|
|
73
|
-
transcription_queue=self.transcription_queue,
|
|
74
|
-
silence_threshold=0.01, # Set higher to ensure detection (your levels are 0.01-0.2+)
|
|
75
|
-
min_audio_length=16000 * 1.0, # 1 second minimum at 16kHz
|
|
76
|
-
)
|
|
77
|
-
|
|
78
|
-
# Log the thresholds being used (reduce verbosity)
|
|
79
|
-
self.log.debug(
|
|
80
|
-
f"Audio settings: SILENCE_THRESHOLD={self.whisper_asr.SILENCE_THRESHOLD}, "
|
|
81
|
-
f"MIN_LENGTH={self.whisper_asr.MIN_AUDIO_LENGTH/self.whisper_asr.RATE:.1f}s"
|
|
82
|
-
)
|
|
83
|
-
|
|
84
|
-
device_name = self.whisper_asr.get_device_name()
|
|
85
|
-
self.log.debug(f"Using audio device: {device_name}")
|
|
86
|
-
|
|
87
|
-
# Start recording
|
|
88
|
-
self.log.debug("Starting audio recording...")
|
|
89
|
-
self.whisper_asr.start_recording()
|
|
90
|
-
|
|
91
|
-
# Start the processing thread after recording is initialized
|
|
92
|
-
self.log.debug("Starting audio processing thread...")
|
|
93
|
-
process_thread = threading.Thread(
|
|
94
|
-
target=self._process_audio_wrapper, args=(message_processor_callback,)
|
|
95
|
-
)
|
|
96
|
-
process_thread.daemon = True
|
|
97
|
-
process_thread.start()
|
|
98
|
-
|
|
99
|
-
# Keep the main thread alive while processing
|
|
100
|
-
self.log.debug("Listening for voice input...")
|
|
101
|
-
try:
|
|
102
|
-
while True:
|
|
103
|
-
if not process_thread.is_alive():
|
|
104
|
-
self.log.debug("Process thread stopped unexpectedly")
|
|
105
|
-
break
|
|
106
|
-
if not self.whisper_asr or not self.whisper_asr.is_recording:
|
|
107
|
-
self.log.warning("Recording stopped unexpectedly")
|
|
108
|
-
break
|
|
109
|
-
await asyncio.sleep(0.1)
|
|
110
|
-
|
|
111
|
-
except KeyboardInterrupt:
|
|
112
|
-
self.log.info("Received keyboard interrupt")
|
|
113
|
-
print("\nStopping voice chat...")
|
|
114
|
-
except Exception as e:
|
|
115
|
-
self.log.error(f"Error in main processing loop: {str(e)}")
|
|
116
|
-
raise
|
|
117
|
-
finally:
|
|
118
|
-
if self.whisper_asr:
|
|
119
|
-
self.log.debug("Stopping recording...")
|
|
120
|
-
self.whisper_asr.stop_recording()
|
|
121
|
-
self.log.debug("Waiting for process thread to finish...")
|
|
122
|
-
process_thread.join(timeout=2.0)
|
|
123
|
-
|
|
124
|
-
except ImportError:
|
|
125
|
-
self.log.error(
|
|
126
|
-
'WhisperAsr not found. Please install voice support with: uv pip install ".[talk]"'
|
|
127
|
-
)
|
|
128
|
-
raise
|
|
129
|
-
except Exception as e:
|
|
130
|
-
self.log.error(f"Failed to initialize voice chat: {str(e)}")
|
|
131
|
-
raise
|
|
132
|
-
finally:
|
|
133
|
-
if self.whisper_asr:
|
|
134
|
-
self.whisper_asr.stop_recording()
|
|
135
|
-
self.log.info("Voice recording stopped")
|
|
136
|
-
|
|
137
|
-
async def process_voice_input(self, text, get_stats_callback=None):
|
|
138
|
-
"""Process transcribed voice input and get AI response"""
|
|
139
|
-
|
|
140
|
-
# Initialize TTS streaming
|
|
141
|
-
text_queue = None
|
|
142
|
-
tts_finished = threading.Event() # Add event to track TTS completion
|
|
143
|
-
interrupt_event = threading.Event() # Add event for keyboard interrupts
|
|
144
|
-
|
|
145
|
-
try:
|
|
146
|
-
# Check if we're currently generating and halt if needed
|
|
147
|
-
if self.llm_client.is_generating():
|
|
148
|
-
self.log.debug("Generation in progress, halting...")
|
|
149
|
-
if self.llm_client.halt_generation():
|
|
150
|
-
print("\nGeneration interrupted.")
|
|
151
|
-
await asyncio.sleep(0.5)
|
|
152
|
-
|
|
153
|
-
# Pause audio recording before sending query
|
|
154
|
-
if self.whisper_asr:
|
|
155
|
-
self.whisper_asr.pause_recording()
|
|
156
|
-
self.log.debug("Recording paused before generation")
|
|
157
|
-
|
|
158
|
-
self.log.debug(f"Sending message to LLM: {text[:50]}...")
|
|
159
|
-
print("\nGaia: ", end="", flush=True)
|
|
160
|
-
|
|
161
|
-
# Keyboard listener thread for both generation and playback
|
|
162
|
-
def keyboard_listener():
|
|
163
|
-
input() # Wait for any input
|
|
164
|
-
|
|
165
|
-
# Use LLMClient to halt generation
|
|
166
|
-
if self.llm_client.halt_generation():
|
|
167
|
-
print("\nGeneration interrupted.")
|
|
168
|
-
else:
|
|
169
|
-
print("\nInterrupt requested.")
|
|
170
|
-
|
|
171
|
-
interrupt_event.set()
|
|
172
|
-
if text_queue:
|
|
173
|
-
text_queue.put("__HALT__") # Signal TTS to stop immediately
|
|
174
|
-
|
|
175
|
-
# Start keyboard listener thread
|
|
176
|
-
keyboard_thread = threading.Thread(target=keyboard_listener)
|
|
177
|
-
keyboard_thread.daemon = True
|
|
178
|
-
keyboard_thread.start()
|
|
179
|
-
|
|
180
|
-
if self.enable_tts:
|
|
181
|
-
text_queue = queue.Queue(maxsize=100)
|
|
182
|
-
|
|
183
|
-
# Define status callback to update speaking state
|
|
184
|
-
def tts_status_callback(is_speaking):
|
|
185
|
-
self.is_speaking = is_speaking
|
|
186
|
-
if not is_speaking: # When TTS finishes speaking
|
|
187
|
-
tts_finished.set()
|
|
188
|
-
if self.whisper_asr:
|
|
189
|
-
self.whisper_asr.resume_recording()
|
|
190
|
-
else: # When TTS starts speaking
|
|
191
|
-
if self.whisper_asr:
|
|
192
|
-
self.whisper_asr.pause_recording()
|
|
193
|
-
self.log.debug(f"TTS speaking state: {is_speaking}")
|
|
194
|
-
|
|
195
|
-
self.tts_thread = threading.Thread(
|
|
196
|
-
target=self.tts.generate_speech_streaming,
|
|
197
|
-
args=(text_queue,),
|
|
198
|
-
kwargs={
|
|
199
|
-
"status_callback": tts_status_callback,
|
|
200
|
-
"interrupt_event": interrupt_event,
|
|
201
|
-
},
|
|
202
|
-
daemon=True,
|
|
203
|
-
)
|
|
204
|
-
self.tts_thread.start()
|
|
205
|
-
|
|
206
|
-
# Use LLMClient streaming instead of WebSocket
|
|
207
|
-
accumulated_response = ""
|
|
208
|
-
initial_buffer = "" # Buffer for the start of response
|
|
209
|
-
initial_buffer_sent = False
|
|
210
|
-
|
|
211
|
-
try:
|
|
212
|
-
# Start LLM generation with streaming
|
|
213
|
-
response_stream = self.llm_client.generate(text, stream=True)
|
|
214
|
-
|
|
215
|
-
# Process streaming response
|
|
216
|
-
for chunk in response_stream:
|
|
217
|
-
if interrupt_event.is_set():
|
|
218
|
-
self.log.debug("Keyboard interrupt detected, stopping...")
|
|
219
|
-
if text_queue:
|
|
220
|
-
text_queue.put("__END__")
|
|
221
|
-
break
|
|
222
|
-
|
|
223
|
-
if self.transcription_queue.qsize() > 0:
|
|
224
|
-
self.log.debug(
|
|
225
|
-
"New input detected during generation, stopping..."
|
|
226
|
-
)
|
|
227
|
-
if text_queue:
|
|
228
|
-
text_queue.put("__END__")
|
|
229
|
-
# Use LLMClient to halt generation
|
|
230
|
-
if self.llm_client.halt_generation():
|
|
231
|
-
self.log.debug("Generation interrupted for new input.")
|
|
232
|
-
return
|
|
233
|
-
|
|
234
|
-
if chunk:
|
|
235
|
-
print(chunk, end="", flush=True)
|
|
236
|
-
if text_queue:
|
|
237
|
-
if not initial_buffer_sent:
|
|
238
|
-
initial_buffer += chunk
|
|
239
|
-
# Send if we've reached 20 chars or if we get a clear end marker
|
|
240
|
-
if len(initial_buffer) >= 20 or chunk.endswith(
|
|
241
|
-
("\n", ". ", "! ", "? ")
|
|
242
|
-
):
|
|
243
|
-
text_queue.put(initial_buffer)
|
|
244
|
-
initial_buffer_sent = True
|
|
245
|
-
else:
|
|
246
|
-
text_queue.put(chunk)
|
|
247
|
-
accumulated_response += chunk
|
|
248
|
-
|
|
249
|
-
# Send any remaining buffered content
|
|
250
|
-
if text_queue:
|
|
251
|
-
if not initial_buffer_sent and initial_buffer:
|
|
252
|
-
# Small delay for very short responses
|
|
253
|
-
if len(initial_buffer) <= 20:
|
|
254
|
-
await asyncio.sleep(0.1)
|
|
255
|
-
text_queue.put(initial_buffer)
|
|
256
|
-
text_queue.put("__END__")
|
|
257
|
-
|
|
258
|
-
except Exception as e:
|
|
259
|
-
if text_queue:
|
|
260
|
-
text_queue.put("__END__")
|
|
261
|
-
raise e
|
|
262
|
-
finally:
|
|
263
|
-
if self.tts_thread and self.tts_thread.is_alive():
|
|
264
|
-
self.tts_thread.join(timeout=1.0) # Add timeout to thread join
|
|
265
|
-
keyboard_thread.join(timeout=1.0) # Add timeout to keyboard thread join
|
|
266
|
-
|
|
267
|
-
print("\n")
|
|
268
|
-
# Get performance stats from LLMClient
|
|
269
|
-
if get_stats_callback:
|
|
270
|
-
# First try the provided callback for backward compatibility
|
|
271
|
-
stats = get_stats_callback()
|
|
272
|
-
else:
|
|
273
|
-
# Use LLMClient stats
|
|
274
|
-
stats = self.llm_client.get_performance_stats()
|
|
275
|
-
|
|
276
|
-
if stats:
|
|
277
|
-
from pprint import pprint
|
|
278
|
-
|
|
279
|
-
formatted_stats = {
|
|
280
|
-
k: round(v, 1) if isinstance(v, float) else v
|
|
281
|
-
for k, v in stats.items()
|
|
282
|
-
}
|
|
283
|
-
pprint(formatted_stats)
|
|
284
|
-
|
|
285
|
-
except Exception as e:
|
|
286
|
-
if text_queue:
|
|
287
|
-
text_queue.put("__END__")
|
|
288
|
-
raise e
|
|
289
|
-
finally:
|
|
290
|
-
if self.tts_thread and self.tts_thread.is_alive():
|
|
291
|
-
# Wait for TTS to finish before resuming recording
|
|
292
|
-
tts_finished.wait(timeout=2.0) # Add reasonable timeout
|
|
293
|
-
self.tts_thread.join(timeout=1.0)
|
|
294
|
-
|
|
295
|
-
# Only resume recording after TTS is completely finished
|
|
296
|
-
if self.whisper_asr:
|
|
297
|
-
self.whisper_asr.resume_recording()
|
|
298
|
-
|
|
299
|
-
def initialize_tts(self):
|
|
300
|
-
"""Initialize TTS if enabled."""
|
|
301
|
-
if self.enable_tts:
|
|
302
|
-
try:
|
|
303
|
-
from gaia.audio.kokoro_tts import KokoroTTS
|
|
304
|
-
|
|
305
|
-
self.tts = KokoroTTS()
|
|
306
|
-
self.log.debug("TTS initialized successfully")
|
|
307
|
-
except Exception as e:
|
|
308
|
-
raise RuntimeError(
|
|
309
|
-
f'Failed to initialize TTS:\n{e}\nInstall talk dependencies with: uv pip install ".[talk]"\nYou can also use --no-tts option to disable TTS'
|
|
310
|
-
)
|
|
311
|
-
|
|
312
|
-
async def speak_text(self, text: str) -> None:
|
|
313
|
-
"""Speak text using initialized TTS, if available."""
|
|
314
|
-
if not self.enable_tts:
|
|
315
|
-
return
|
|
316
|
-
if not getattr(self, "tts", None):
|
|
317
|
-
self.log.debug("TTS is not initialized; skipping speak_text")
|
|
318
|
-
return
|
|
319
|
-
# Reuse the streaming path used in process_voice_input
|
|
320
|
-
text_queue = queue.Queue(maxsize=100)
|
|
321
|
-
interrupt_event = threading.Event()
|
|
322
|
-
tts_thread = threading.Thread(
|
|
323
|
-
target=self.tts.generate_speech_streaming,
|
|
324
|
-
args=(text_queue,),
|
|
325
|
-
kwargs={"interrupt_event": interrupt_event},
|
|
326
|
-
daemon=True,
|
|
327
|
-
)
|
|
328
|
-
tts_thread.start()
|
|
329
|
-
# Send the whole text and end
|
|
330
|
-
text_queue.put(text)
|
|
331
|
-
text_queue.put("__END__")
|
|
332
|
-
tts_thread.join(timeout=5.0)
|
|
333
|
-
|
|
334
|
-
def _process_audio_wrapper(self, message_processor_callback):
|
|
335
|
-
"""Wrapper method to process audio and handle transcriptions"""
|
|
336
|
-
try:
|
|
337
|
-
accumulated_text = []
|
|
338
|
-
current_display = ""
|
|
339
|
-
last_transcription_time = time.time()
|
|
340
|
-
spinner_chars = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"]
|
|
341
|
-
dots_animation = [" ", ". ", ".. ", "..."]
|
|
342
|
-
spinner_idx = 0
|
|
343
|
-
dots_idx = 0
|
|
344
|
-
animation_counter = 0
|
|
345
|
-
self.is_speaking = False # Initialize speaking state
|
|
346
|
-
|
|
347
|
-
while self.whisper_asr and self.whisper_asr.is_recording:
|
|
348
|
-
try:
|
|
349
|
-
text = self.transcription_queue.get(timeout=0.1)
|
|
350
|
-
|
|
351
|
-
current_time = time.time()
|
|
352
|
-
time_since_last = current_time - last_transcription_time
|
|
353
|
-
cleaned_text = text.lower().strip().rstrip(".!?")
|
|
354
|
-
|
|
355
|
-
# Handle special commands
|
|
356
|
-
if cleaned_text in ["stop"]:
|
|
357
|
-
print("\nStopping voice chat...")
|
|
358
|
-
self.whisper_asr.stop_recording()
|
|
359
|
-
break
|
|
360
|
-
|
|
361
|
-
# Update animations
|
|
362
|
-
spinner_idx = (spinner_idx + 1) % len(spinner_chars)
|
|
363
|
-
animation_counter += 1
|
|
364
|
-
if animation_counter % 4 == 0: # Update dots every fourth cycle
|
|
365
|
-
dots_idx = (dots_idx + 1) % len(dots_animation)
|
|
366
|
-
spinner = spinner_chars[spinner_idx]
|
|
367
|
-
dots = dots_animation[dots_idx]
|
|
368
|
-
|
|
369
|
-
# Normal text processing - only if it's not a system message
|
|
370
|
-
if text != current_display:
|
|
371
|
-
# Clear the current line and display updated text with spinner
|
|
372
|
-
print(f"\r\033[K{spinner} {text}", end="", flush=True)
|
|
373
|
-
current_display = text
|
|
374
|
-
|
|
375
|
-
# Only add new text if it's significantly different
|
|
376
|
-
if not any(text in existing for existing in accumulated_text):
|
|
377
|
-
accumulated_text = [text] # Replace instead of append
|
|
378
|
-
last_transcription_time = current_time
|
|
379
|
-
|
|
380
|
-
# Process accumulated text after silence threshold
|
|
381
|
-
if time_since_last > self.silence_threshold:
|
|
382
|
-
if accumulated_text:
|
|
383
|
-
complete_text = accumulated_text[
|
|
384
|
-
-1
|
|
385
|
-
] # Use only the last transcription
|
|
386
|
-
print() # Add a newline before agent response
|
|
387
|
-
asyncio.run(message_processor_callback(complete_text))
|
|
388
|
-
accumulated_text = []
|
|
389
|
-
current_display = ""
|
|
390
|
-
|
|
391
|
-
except queue.Empty:
|
|
392
|
-
# Update animations
|
|
393
|
-
spinner_idx = (spinner_idx + 1) % len(spinner_chars)
|
|
394
|
-
animation_counter += 1
|
|
395
|
-
if animation_counter % 4 == 0:
|
|
396
|
-
dots_idx = (dots_idx + 1) % len(dots_animation)
|
|
397
|
-
spinner = spinner_chars[spinner_idx]
|
|
398
|
-
dots = dots_animation[dots_idx]
|
|
399
|
-
|
|
400
|
-
if current_display:
|
|
401
|
-
print(
|
|
402
|
-
f"\r\033[K{spinner} {current_display}", end="", flush=True
|
|
403
|
-
)
|
|
404
|
-
else:
|
|
405
|
-
# Access the class-level speaking state
|
|
406
|
-
status = (
|
|
407
|
-
"Speaking"
|
|
408
|
-
if getattr(self, "is_speaking", False)
|
|
409
|
-
else "Listening"
|
|
410
|
-
)
|
|
411
|
-
print(f"\r\033[K{spinner} {status}{dots}", end="", flush=True)
|
|
412
|
-
|
|
413
|
-
if (
|
|
414
|
-
accumulated_text
|
|
415
|
-
and (time.time() - last_transcription_time)
|
|
416
|
-
> self.silence_threshold
|
|
417
|
-
):
|
|
418
|
-
complete_text = accumulated_text[-1]
|
|
419
|
-
print() # Add a newline before agent response
|
|
420
|
-
asyncio.run(message_processor_callback(complete_text))
|
|
421
|
-
accumulated_text = []
|
|
422
|
-
current_display = ""
|
|
423
|
-
|
|
424
|
-
except Exception as e:
|
|
425
|
-
self.log.error(f"Error in process_audio_wrapper: {str(e)}")
|
|
426
|
-
finally:
|
|
427
|
-
if self.whisper_asr:
|
|
428
|
-
self.whisper_asr.stop_recording()
|
|
429
|
-
if self.tts_thread and self.tts_thread.is_alive():
|
|
430
|
-
self.tts_thread.join(timeout=1.0) # Add timeout to thread join
|
|
431
|
-
|
|
432
|
-
async def halt_generation(self):
|
|
433
|
-
"""Send a request to halt the current generation."""
|
|
434
|
-
if self.llm_client.halt_generation():
|
|
435
|
-
self.log.debug("Successfully halted generation via LLMClient")
|
|
436
|
-
print("\nGeneration interrupted.")
|
|
437
|
-
else:
|
|
438
|
-
self.log.debug("Halt requested - generation will stop on next iteration")
|
|
439
|
-
print("\nInterrupt requested.")
|
|
1
|
+
# Copyright(C) 2025-2026 Advanced Micro Devices, Inc. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: MIT
|
|
3
|
+
|
|
4
|
+
import asyncio
|
|
5
|
+
import queue
|
|
6
|
+
import threading
|
|
7
|
+
import time
|
|
8
|
+
|
|
9
|
+
from gaia.llm import create_client
|
|
10
|
+
from gaia.logger import get_logger
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class AudioClient:
|
|
14
|
+
"""Handles all audio-related functionality including TTS, ASR, and voice chat."""
|
|
15
|
+
|
|
16
|
+
def __init__(
|
|
17
|
+
self,
|
|
18
|
+
whisper_model_size="base",
|
|
19
|
+
audio_device_index=None, # Use default input device
|
|
20
|
+
silence_threshold=0.5,
|
|
21
|
+
enable_tts=True,
|
|
22
|
+
logging_level="INFO",
|
|
23
|
+
use_claude=False,
|
|
24
|
+
use_chatgpt=False,
|
|
25
|
+
system_prompt=None,
|
|
26
|
+
):
|
|
27
|
+
self.log = get_logger(__name__)
|
|
28
|
+
self.log.setLevel(getattr(__import__("logging"), logging_level))
|
|
29
|
+
|
|
30
|
+
# Audio configuration
|
|
31
|
+
self.whisper_model_size = whisper_model_size
|
|
32
|
+
self.audio_device_index = audio_device_index
|
|
33
|
+
self.silence_threshold = silence_threshold
|
|
34
|
+
self.enable_tts = enable_tts
|
|
35
|
+
|
|
36
|
+
# Audio state
|
|
37
|
+
self.is_speaking = False
|
|
38
|
+
self.tts_thread = None
|
|
39
|
+
self.whisper_asr = None
|
|
40
|
+
self.transcription_queue = queue.Queue()
|
|
41
|
+
self.tts = None
|
|
42
|
+
|
|
43
|
+
# Initialize LLM client - factory auto-detects provider from flags
|
|
44
|
+
self.llm_client = create_client(
|
|
45
|
+
use_claude=use_claude,
|
|
46
|
+
use_openai=use_chatgpt,
|
|
47
|
+
system_prompt=system_prompt,
|
|
48
|
+
)
|
|
49
|
+
|
|
50
|
+
self.log.info("Audio client initialized.")
|
|
51
|
+
|
|
52
|
+
async def start_voice_chat(self, message_processor_callback):
|
|
53
|
+
"""Start a voice-based chat session."""
|
|
54
|
+
try:
|
|
55
|
+
self.log.debug("Initializing voice chat...")
|
|
56
|
+
print(
|
|
57
|
+
"Starting voice chat.\n"
|
|
58
|
+
"Say 'stop' to quit application "
|
|
59
|
+
"or 'restart' to clear the chat history.\n"
|
|
60
|
+
"Press Enter key to stop during audio playback."
|
|
61
|
+
)
|
|
62
|
+
|
|
63
|
+
# Initialize TTS before starting voice chat
|
|
64
|
+
self.initialize_tts()
|
|
65
|
+
|
|
66
|
+
from gaia.audio.whisper_asr import WhisperAsr
|
|
67
|
+
|
|
68
|
+
# Create WhisperAsr with custom thresholds
|
|
69
|
+
# Your audio shows energy levels of 0.02-0.03 when speaking
|
|
70
|
+
self.whisper_asr = WhisperAsr(
|
|
71
|
+
model_size=self.whisper_model_size,
|
|
72
|
+
device_index=self.audio_device_index,
|
|
73
|
+
transcription_queue=self.transcription_queue,
|
|
74
|
+
silence_threshold=0.01, # Set higher to ensure detection (your levels are 0.01-0.2+)
|
|
75
|
+
min_audio_length=16000 * 1.0, # 1 second minimum at 16kHz
|
|
76
|
+
)
|
|
77
|
+
|
|
78
|
+
# Log the thresholds being used (reduce verbosity)
|
|
79
|
+
self.log.debug(
|
|
80
|
+
f"Audio settings: SILENCE_THRESHOLD={self.whisper_asr.SILENCE_THRESHOLD}, "
|
|
81
|
+
f"MIN_LENGTH={self.whisper_asr.MIN_AUDIO_LENGTH/self.whisper_asr.RATE:.1f}s"
|
|
82
|
+
)
|
|
83
|
+
|
|
84
|
+
device_name = self.whisper_asr.get_device_name()
|
|
85
|
+
self.log.debug(f"Using audio device: {device_name}")
|
|
86
|
+
|
|
87
|
+
# Start recording
|
|
88
|
+
self.log.debug("Starting audio recording...")
|
|
89
|
+
self.whisper_asr.start_recording()
|
|
90
|
+
|
|
91
|
+
# Start the processing thread after recording is initialized
|
|
92
|
+
self.log.debug("Starting audio processing thread...")
|
|
93
|
+
process_thread = threading.Thread(
|
|
94
|
+
target=self._process_audio_wrapper, args=(message_processor_callback,)
|
|
95
|
+
)
|
|
96
|
+
process_thread.daemon = True
|
|
97
|
+
process_thread.start()
|
|
98
|
+
|
|
99
|
+
# Keep the main thread alive while processing
|
|
100
|
+
self.log.debug("Listening for voice input...")
|
|
101
|
+
try:
|
|
102
|
+
while True:
|
|
103
|
+
if not process_thread.is_alive():
|
|
104
|
+
self.log.debug("Process thread stopped unexpectedly")
|
|
105
|
+
break
|
|
106
|
+
if not self.whisper_asr or not self.whisper_asr.is_recording:
|
|
107
|
+
self.log.warning("Recording stopped unexpectedly")
|
|
108
|
+
break
|
|
109
|
+
await asyncio.sleep(0.1)
|
|
110
|
+
|
|
111
|
+
except KeyboardInterrupt:
|
|
112
|
+
self.log.info("Received keyboard interrupt")
|
|
113
|
+
print("\nStopping voice chat...")
|
|
114
|
+
except Exception as e:
|
|
115
|
+
self.log.error(f"Error in main processing loop: {str(e)}")
|
|
116
|
+
raise
|
|
117
|
+
finally:
|
|
118
|
+
if self.whisper_asr:
|
|
119
|
+
self.log.debug("Stopping recording...")
|
|
120
|
+
self.whisper_asr.stop_recording()
|
|
121
|
+
self.log.debug("Waiting for process thread to finish...")
|
|
122
|
+
process_thread.join(timeout=2.0)
|
|
123
|
+
|
|
124
|
+
except ImportError:
|
|
125
|
+
self.log.error(
|
|
126
|
+
'WhisperAsr not found. Please install voice support with: uv pip install ".[talk]"'
|
|
127
|
+
)
|
|
128
|
+
raise
|
|
129
|
+
except Exception as e:
|
|
130
|
+
self.log.error(f"Failed to initialize voice chat: {str(e)}")
|
|
131
|
+
raise
|
|
132
|
+
finally:
|
|
133
|
+
if self.whisper_asr:
|
|
134
|
+
self.whisper_asr.stop_recording()
|
|
135
|
+
self.log.info("Voice recording stopped")
|
|
136
|
+
|
|
137
|
+
async def process_voice_input(self, text, get_stats_callback=None):
    """Process one transcribed utterance: stream an LLM reply and optionally speak it.

    Visible flow:
      1. Halt any in-flight generation and pause ASR recording.
      2. Start a daemon keyboard-listener thread that interrupts generation
         and TTS on any console input.
      3. If TTS is enabled, start a streaming TTS worker fed by a bounded queue.
      4. Stream LLM chunks to stdout and the TTS queue, buffering the first
         ~20 chars so TTS starts with a meaningful phrase.
      5. Print performance stats, then resume recording in the outer finally.

    Args:
        text: The transcribed user utterance to send to the LLM.
        get_stats_callback: Optional zero-arg callable returning a stats dict;
            used instead of ``self.llm_client.get_performance_stats()`` for
            backward compatibility.

    Raises:
        Exception: re-raises any error from generation after signalling the
            TTS worker to stop via the ``"__END__"`` sentinel.
    """

    # Initialize TTS streaming state. text_queue stays None when TTS is off;
    # the keyboard listener closure reads it late-bound, so it sees the queue
    # created further down once TTS is enabled.
    text_queue = None
    tts_finished = threading.Event()  # Set by the TTS status callback when playback ends
    interrupt_event = threading.Event()  # Set by the keyboard listener on user interrupt

    try:
        # Check if we're currently generating and halt if needed before
        # submitting the new query.
        if self.llm_client.is_generating():
            self.log.debug("Generation in progress, halting...")
            if self.llm_client.halt_generation():
                print("\nGeneration interrupted.")
                await asyncio.sleep(0.5)

        # Pause audio recording before sending query so the mic does not
        # pick up the assistant's own TTS output.
        if self.whisper_asr:
            self.whisper_asr.pause_recording()
            self.log.debug("Recording paused before generation")

        self.log.debug(f"Sending message to LLM: {text[:50]}...")
        print("\nGaia: ", end="", flush=True)

        # Keyboard listener thread for both generation and playback.
        # Blocks in input(); any line typed by the user triggers the halt.
        def keyboard_listener():
            input()  # Wait for any input

            # Use LLMClient to halt generation
            if self.llm_client.halt_generation():
                print("\nGeneration interrupted.")
            else:
                print("\nInterrupt requested.")

            interrupt_event.set()
            if text_queue:
                text_queue.put("__HALT__")  # Signal TTS to stop immediately

        # Start keyboard listener thread (daemon, so it never blocks shutdown;
        # its join below may time out while it waits in input()).
        keyboard_thread = threading.Thread(target=keyboard_listener)
        keyboard_thread.daemon = True
        keyboard_thread.start()

        if self.enable_tts:
            # Bounded queue provides back-pressure between the LLM stream
            # and the TTS worker.
            text_queue = queue.Queue(maxsize=100)

            # Define status callback to update speaking state and to pause or
            # resume ASR recording around TTS playback.
            def tts_status_callback(is_speaking):
                self.is_speaking = is_speaking
                if not is_speaking:  # When TTS finishes speaking
                    tts_finished.set()
                    if self.whisper_asr:
                        self.whisper_asr.resume_recording()
                else:  # When TTS starts speaking
                    if self.whisper_asr:
                        self.whisper_asr.pause_recording()
                self.log.debug(f"TTS speaking state: {is_speaking}")

            self.tts_thread = threading.Thread(
                target=self.tts.generate_speech_streaming,
                args=(text_queue,),
                kwargs={
                    "status_callback": tts_status_callback,
                    "interrupt_event": interrupt_event,
                },
                daemon=True,
            )
            self.tts_thread.start()

        # Use LLMClient streaming instead of WebSocket.
        accumulated_response = ""
        initial_buffer = ""  # Buffer for the start of response so TTS gets a phrase, not a fragment
        initial_buffer_sent = False

        try:
            # Start LLM generation with streaming
            response_stream = self.llm_client.generate(text, stream=True)

            # Process streaming response
            for chunk in response_stream:
                if interrupt_event.is_set():
                    self.log.debug("Keyboard interrupt detected, stopping...")
                    if text_queue:
                        text_queue.put("__END__")
                    break

                # New speech arrived while generating: abandon this reply
                # and return so the new utterance can be processed.
                if self.transcription_queue.qsize() > 0:
                    self.log.debug(
                        "New input detected during generation, stopping..."
                    )
                    if text_queue:
                        text_queue.put("__END__")
                    # Use LLMClient to halt generation
                    if self.llm_client.halt_generation():
                        self.log.debug("Generation interrupted for new input.")
                    return

                if chunk:
                    print(chunk, end="", flush=True)
                    if text_queue:
                        if not initial_buffer_sent:
                            initial_buffer += chunk
                            # Send if we've reached 20 chars or if we get a clear end marker
                            if len(initial_buffer) >= 20 or chunk.endswith(
                                ("\n", ". ", "! ", "? ")
                            ):
                                text_queue.put(initial_buffer)
                                initial_buffer_sent = True
                        else:
                            text_queue.put(chunk)
                    accumulated_response += chunk

            # Send any remaining buffered content, then the end-of-stream
            # sentinel the TTS worker watches for.
            if text_queue:
                if not initial_buffer_sent and initial_buffer:
                    # Small delay for very short responses
                    if len(initial_buffer) <= 20:
                        await asyncio.sleep(0.1)
                    text_queue.put(initial_buffer)
                text_queue.put("__END__")

        except Exception as e:
            # Make sure the TTS worker is released before propagating.
            if text_queue:
                text_queue.put("__END__")
            raise e
        finally:
            if self.tts_thread and self.tts_thread.is_alive():
                self.tts_thread.join(timeout=1.0)  # Add timeout to thread join
            keyboard_thread.join(timeout=1.0)  # Add timeout to keyboard thread join

        print("\n")
        # Get performance stats from LLMClient
        if get_stats_callback:
            # First try the provided callback for backward compatibility
            stats = get_stats_callback()
        else:
            # Use LLMClient stats
            stats = self.llm_client.get_performance_stats()

        if stats:
            from pprint import pprint

            # Round float metrics to one decimal for readable console output.
            formatted_stats = {
                k: round(v, 1) if isinstance(v, float) else v
                for k, v in stats.items()
            }
            pprint(formatted_stats)

    except Exception as e:
        # Outer safety net mirrors the inner handler for errors raised
        # outside the streaming try (e.g. during setup).
        if text_queue:
            text_queue.put("__END__")
        raise e
    finally:
        if self.tts_thread and self.tts_thread.is_alive():
            # Wait for TTS to finish before resuming recording
            tts_finished.wait(timeout=2.0)  # Add reasonable timeout
            self.tts_thread.join(timeout=1.0)

        # Only resume recording after TTS is completely finished
        if self.whisper_asr:
            self.whisper_asr.resume_recording()
|
299
|
+
def initialize_tts(self):
    """Initialize the Kokoro TTS engine when TTS is enabled.

    Does nothing when ``self.enable_tts`` is false.

    Raises:
        RuntimeError: if the TTS engine or its optional dependencies cannot
            be loaded; the original failure is chained as the cause.
    """
    # Guard clause instead of nesting the whole body under the flag.
    if not self.enable_tts:
        return
    try:
        # Imported lazily so the "[talk]" extra is only required when TTS is on.
        from gaia.audio.kokoro_tts import KokoroTTS

        self.tts = KokoroTTS()
        self.log.debug("TTS initialized successfully")
    except Exception as e:
        # Chain the original exception (``from e``) so the real import or
        # initialization failure stays visible in the traceback.
        raise RuntimeError(
            f'Failed to initialize TTS:\n{e}\nInstall talk dependencies with: uv pip install ".[talk]"\nYou can also use --no-tts option to disable TTS'
        ) from e
|
312
|
+
async def speak_text(self, text: str) -> None:
    """Speak *text* through the configured TTS engine, if one is available.

    No-op when TTS is disabled or the engine was never initialized.
    """
    if not self.enable_tts:
        return
    if not getattr(self, "tts", None):
        self.log.debug("TTS is not initialized; skipping speak_text")
        return

    # Drive the streaming TTS worker through a bounded queue, mirroring the
    # pipeline used for live voice responses.
    stream_queue = queue.Queue(maxsize=100)
    stop_signal = threading.Event()
    worker = threading.Thread(
        target=self.tts.generate_speech_streaming,
        args=(stream_queue,),
        kwargs={"interrupt_event": stop_signal},
        daemon=True,
    )
    worker.start()

    # Hand over the full utterance, then the end-of-stream sentinel.
    stream_queue.put(text)
    stream_queue.put("__END__")
    worker.join(timeout=5.0)
|
334
|
+
def _process_audio_wrapper(self, message_processor_callback):
    """Poll the transcription queue and dispatch utterances to the processor.

    Runs until recording stops. On each transcription it updates a console
    spinner line; after ``self.silence_threshold`` seconds without new text it
    calls ``message_processor_callback`` (a coroutine function) with the last
    transcription via ``asyncio.run`` — note that this blocks the polling loop
    for the duration of the response. Saying "stop" ends the session.

    Args:
        message_processor_callback: async callable taking the final
            transcription string (e.g. ``process_voice_input``).
    """
    try:
        accumulated_text = []        # Most recent candidate transcription(s); replaced, not appended
        current_display = ""         # What is currently shown on the spinner line
        last_transcription_time = time.time()
        spinner_chars = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"]
        dots_animation = [" ", ". ", ".. ", "..."]
        spinner_idx = 0
        dots_idx = 0
        animation_counter = 0
        self.is_speaking = False  # Initialize speaking state

        while self.whisper_asr and self.whisper_asr.is_recording:
            try:
                # Short timeout keeps the spinner animating via queue.Empty.
                text = self.transcription_queue.get(timeout=0.1)

                current_time = time.time()
                time_since_last = current_time - last_transcription_time
                # Normalize for command matching (case/punctuation-insensitive).
                cleaned_text = text.lower().strip().rstrip(".!?")

                # Handle special commands
                if cleaned_text in ["stop"]:
                    print("\nStopping voice chat...")
                    self.whisper_asr.stop_recording()
                    break

                # Update animations
                spinner_idx = (spinner_idx + 1) % len(spinner_chars)
                animation_counter += 1
                if animation_counter % 4 == 0:  # Update dots every fourth cycle
                    dots_idx = (dots_idx + 1) % len(dots_animation)
                spinner = spinner_chars[spinner_idx]
                # NOTE(review): `dots` is unused in this branch (only the
                # queue.Empty branch prints it) — appears to be dead here.
                dots = dots_animation[dots_idx]

                # Normal text processing - only if it's not a system message
                if text != current_display:
                    # Clear the current line and display updated text with spinner
                    # (\r\033[K = carriage return + ANSI erase-to-end-of-line).
                    print(f"\r\033[K{spinner} {text}", end="", flush=True)
                    current_display = text

                    # Only add new text if it's significantly different
                    if not any(text in existing for existing in accumulated_text):
                        accumulated_text = [text]  # Replace instead of append
                        last_transcription_time = current_time

                # Process accumulated text after silence threshold
                if time_since_last > self.silence_threshold:
                    if accumulated_text:
                        complete_text = accumulated_text[
                            -1
                        ]  # Use only the last transcription
                        print()  # Add a newline before agent response
                        asyncio.run(message_processor_callback(complete_text))
                        accumulated_text = []
                        current_display = ""

            except queue.Empty:
                # No new transcription this tick: keep animating and check
                # whether the silence threshold has been crossed.
                spinner_idx = (spinner_idx + 1) % len(spinner_chars)
                animation_counter += 1
                if animation_counter % 4 == 0:
                    dots_idx = (dots_idx + 1) % len(dots_animation)
                spinner = spinner_chars[spinner_idx]
                dots = dots_animation[dots_idx]

                if current_display:
                    print(
                        f"\r\033[K{spinner} {current_display}", end="", flush=True
                    )
                else:
                    # Access the class-level speaking state (set by the TTS
                    # status callback) to pick the idle caption.
                    status = (
                        "Speaking"
                        if getattr(self, "is_speaking", False)
                        else "Listening"
                    )
                    print(f"\r\033[K{spinner} {status}{dots}", end="", flush=True)

                # Same silence-threshold dispatch as the main branch, so an
                # utterance is still processed when the queue goes quiet.
                if (
                    accumulated_text
                    and (time.time() - last_transcription_time)
                    > self.silence_threshold
                ):
                    complete_text = accumulated_text[-1]
                    print()  # Add a newline before agent response
                    asyncio.run(message_processor_callback(complete_text))
                    accumulated_text = []
                    current_display = ""

    except Exception as e:
        # Best-effort loop: log and fall through to cleanup rather than crash
        # the whole voice session.
        self.log.error(f"Error in process_audio_wrapper: {str(e)}")
    finally:
        if self.whisper_asr:
            self.whisper_asr.stop_recording()
        if self.tts_thread and self.tts_thread.is_alive():
            self.tts_thread.join(timeout=1.0)  # Add timeout to thread join
|
432
|
+
async def halt_generation(self):
    """Send a request to halt the current generation.

    Reports whether the client confirmed the halt immediately or merely
    queued it for the next generation step.
    """
    halted_now = self.llm_client.halt_generation()
    if halted_now:
        self.log.debug("Successfully halted generation via LLMClient")
        print("\nGeneration interrupted.")
        return
    self.log.debug("Halt requested - generation will stop on next iteration")
    print("\nInterrupt requested.")