amd-gaia 0.15.0__py3-none-any.whl → 0.15.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (185)
  1. {amd_gaia-0.15.0.dist-info → amd_gaia-0.15.2.dist-info}/METADATA +222 -223
  2. amd_gaia-0.15.2.dist-info/RECORD +182 -0
  3. {amd_gaia-0.15.0.dist-info → amd_gaia-0.15.2.dist-info}/WHEEL +1 -1
  4. {amd_gaia-0.15.0.dist-info → amd_gaia-0.15.2.dist-info}/entry_points.txt +1 -0
  5. {amd_gaia-0.15.0.dist-info → amd_gaia-0.15.2.dist-info}/licenses/LICENSE.md +20 -20
  6. gaia/__init__.py +29 -29
  7. gaia/agents/__init__.py +19 -19
  8. gaia/agents/base/__init__.py +9 -9
  9. gaia/agents/base/agent.py +2132 -2177
  10. gaia/agents/base/api_agent.py +119 -120
  11. gaia/agents/base/console.py +1967 -1841
  12. gaia/agents/base/errors.py +237 -237
  13. gaia/agents/base/mcp_agent.py +86 -86
  14. gaia/agents/base/tools.py +88 -83
  15. gaia/agents/blender/__init__.py +7 -0
  16. gaia/agents/blender/agent.py +553 -556
  17. gaia/agents/blender/agent_simple.py +133 -135
  18. gaia/agents/blender/app.py +211 -211
  19. gaia/agents/blender/app_simple.py +41 -41
  20. gaia/agents/blender/core/__init__.py +16 -16
  21. gaia/agents/blender/core/materials.py +506 -506
  22. gaia/agents/blender/core/objects.py +316 -316
  23. gaia/agents/blender/core/rendering.py +225 -225
  24. gaia/agents/blender/core/scene.py +220 -220
  25. gaia/agents/blender/core/view.py +146 -146
  26. gaia/agents/chat/__init__.py +9 -9
  27. gaia/agents/chat/agent.py +809 -835
  28. gaia/agents/chat/app.py +1065 -1058
  29. gaia/agents/chat/session.py +508 -508
  30. gaia/agents/chat/tools/__init__.py +15 -15
  31. gaia/agents/chat/tools/file_tools.py +96 -96
  32. gaia/agents/chat/tools/rag_tools.py +1744 -1729
  33. gaia/agents/chat/tools/shell_tools.py +437 -436
  34. gaia/agents/code/__init__.py +7 -7
  35. gaia/agents/code/agent.py +549 -549
  36. gaia/agents/code/cli.py +377 -0
  37. gaia/agents/code/models.py +135 -135
  38. gaia/agents/code/orchestration/__init__.py +24 -24
  39. gaia/agents/code/orchestration/checklist_executor.py +1763 -1763
  40. gaia/agents/code/orchestration/checklist_generator.py +713 -713
  41. gaia/agents/code/orchestration/factories/__init__.py +9 -9
  42. gaia/agents/code/orchestration/factories/base.py +63 -63
  43. gaia/agents/code/orchestration/factories/nextjs_factory.py +118 -118
  44. gaia/agents/code/orchestration/factories/python_factory.py +106 -106
  45. gaia/agents/code/orchestration/orchestrator.py +841 -841
  46. gaia/agents/code/orchestration/project_analyzer.py +391 -391
  47. gaia/agents/code/orchestration/steps/__init__.py +67 -67
  48. gaia/agents/code/orchestration/steps/base.py +188 -188
  49. gaia/agents/code/orchestration/steps/error_handler.py +314 -314
  50. gaia/agents/code/orchestration/steps/nextjs.py +828 -828
  51. gaia/agents/code/orchestration/steps/python.py +307 -307
  52. gaia/agents/code/orchestration/template_catalog.py +469 -469
  53. gaia/agents/code/orchestration/workflows/__init__.py +14 -14
  54. gaia/agents/code/orchestration/workflows/base.py +80 -80
  55. gaia/agents/code/orchestration/workflows/nextjs.py +186 -186
  56. gaia/agents/code/orchestration/workflows/python.py +94 -94
  57. gaia/agents/code/prompts/__init__.py +11 -11
  58. gaia/agents/code/prompts/base_prompt.py +77 -77
  59. gaia/agents/code/prompts/code_patterns.py +2034 -2036
  60. gaia/agents/code/prompts/nextjs_prompt.py +40 -40
  61. gaia/agents/code/prompts/python_prompt.py +109 -109
  62. gaia/agents/code/schema_inference.py +365 -365
  63. gaia/agents/code/system_prompt.py +41 -41
  64. gaia/agents/code/tools/__init__.py +42 -42
  65. gaia/agents/code/tools/cli_tools.py +1138 -1138
  66. gaia/agents/code/tools/code_formatting.py +319 -319
  67. gaia/agents/code/tools/code_tools.py +769 -769
  68. gaia/agents/code/tools/error_fixing.py +1347 -1347
  69. gaia/agents/code/tools/external_tools.py +180 -180
  70. gaia/agents/code/tools/file_io.py +845 -845
  71. gaia/agents/code/tools/prisma_tools.py +190 -190
  72. gaia/agents/code/tools/project_management.py +1016 -1016
  73. gaia/agents/code/tools/testing.py +321 -321
  74. gaia/agents/code/tools/typescript_tools.py +122 -122
  75. gaia/agents/code/tools/validation_parsing.py +461 -461
  76. gaia/agents/code/tools/validation_tools.py +806 -806
  77. gaia/agents/code/tools/web_dev_tools.py +1758 -1758
  78. gaia/agents/code/validators/__init__.py +16 -16
  79. gaia/agents/code/validators/antipattern_checker.py +241 -241
  80. gaia/agents/code/validators/ast_analyzer.py +197 -197
  81. gaia/agents/code/validators/requirements_validator.py +145 -145
  82. gaia/agents/code/validators/syntax_validator.py +171 -171
  83. gaia/agents/docker/__init__.py +7 -7
  84. gaia/agents/docker/agent.py +643 -642
  85. gaia/agents/emr/__init__.py +8 -8
  86. gaia/agents/emr/agent.py +1504 -1506
  87. gaia/agents/emr/cli.py +1322 -1322
  88. gaia/agents/emr/constants.py +475 -475
  89. gaia/agents/emr/dashboard/__init__.py +4 -4
  90. gaia/agents/emr/dashboard/server.py +1972 -1974
  91. gaia/agents/jira/__init__.py +11 -11
  92. gaia/agents/jira/agent.py +894 -894
  93. gaia/agents/jira/jql_templates.py +299 -299
  94. gaia/agents/routing/__init__.py +7 -7
  95. gaia/agents/routing/agent.py +567 -570
  96. gaia/agents/routing/system_prompt.py +75 -75
  97. gaia/agents/summarize/__init__.py +11 -0
  98. gaia/agents/summarize/agent.py +885 -0
  99. gaia/agents/summarize/prompts.py +129 -0
  100. gaia/api/__init__.py +23 -23
  101. gaia/api/agent_registry.py +238 -238
  102. gaia/api/app.py +305 -305
  103. gaia/api/openai_server.py +575 -575
  104. gaia/api/schemas.py +186 -186
  105. gaia/api/sse_handler.py +373 -373
  106. gaia/apps/__init__.py +4 -4
  107. gaia/apps/llm/__init__.py +6 -6
  108. gaia/apps/llm/app.py +184 -169
  109. gaia/apps/summarize/app.py +116 -633
  110. gaia/apps/summarize/html_viewer.py +133 -133
  111. gaia/apps/summarize/pdf_formatter.py +284 -284
  112. gaia/audio/__init__.py +2 -2
  113. gaia/audio/audio_client.py +439 -439
  114. gaia/audio/audio_recorder.py +269 -269
  115. gaia/audio/kokoro_tts.py +599 -599
  116. gaia/audio/whisper_asr.py +432 -432
  117. gaia/chat/__init__.py +16 -16
  118. gaia/chat/app.py +428 -430
  119. gaia/chat/prompts.py +522 -522
  120. gaia/chat/sdk.py +1228 -1225
  121. gaia/cli.py +5659 -5632
  122. gaia/database/__init__.py +10 -10
  123. gaia/database/agent.py +176 -176
  124. gaia/database/mixin.py +290 -290
  125. gaia/database/testing.py +64 -64
  126. gaia/eval/batch_experiment.py +2332 -2332
  127. gaia/eval/claude.py +542 -542
  128. gaia/eval/config.py +37 -37
  129. gaia/eval/email_generator.py +512 -512
  130. gaia/eval/eval.py +3179 -3179
  131. gaia/eval/groundtruth.py +1130 -1130
  132. gaia/eval/transcript_generator.py +582 -582
  133. gaia/eval/webapp/README.md +167 -167
  134. gaia/eval/webapp/package-lock.json +875 -875
  135. gaia/eval/webapp/package.json +20 -20
  136. gaia/eval/webapp/public/app.js +3402 -3402
  137. gaia/eval/webapp/public/index.html +87 -87
  138. gaia/eval/webapp/public/styles.css +3661 -3661
  139. gaia/eval/webapp/server.js +415 -415
  140. gaia/eval/webapp/test-setup.js +72 -72
  141. gaia/installer/__init__.py +23 -0
  142. gaia/installer/init_command.py +1275 -0
  143. gaia/installer/lemonade_installer.py +619 -0
  144. gaia/llm/__init__.py +10 -2
  145. gaia/llm/base_client.py +60 -0
  146. gaia/llm/exceptions.py +12 -0
  147. gaia/llm/factory.py +70 -0
  148. gaia/llm/lemonade_client.py +3421 -3221
  149. gaia/llm/lemonade_manager.py +294 -294
  150. gaia/llm/providers/__init__.py +9 -0
  151. gaia/llm/providers/claude.py +108 -0
  152. gaia/llm/providers/lemonade.py +118 -0
  153. gaia/llm/providers/openai_provider.py +79 -0
  154. gaia/llm/vlm_client.py +382 -382
  155. gaia/logger.py +189 -189
  156. gaia/mcp/agent_mcp_server.py +245 -245
  157. gaia/mcp/blender_mcp_client.py +138 -138
  158. gaia/mcp/blender_mcp_server.py +648 -648
  159. gaia/mcp/context7_cache.py +332 -332
  160. gaia/mcp/external_services.py +518 -518
  161. gaia/mcp/mcp_bridge.py +811 -550
  162. gaia/mcp/servers/__init__.py +6 -6
  163. gaia/mcp/servers/docker_mcp.py +83 -83
  164. gaia/perf_analysis.py +361 -0
  165. gaia/rag/__init__.py +10 -10
  166. gaia/rag/app.py +293 -293
  167. gaia/rag/demo.py +304 -304
  168. gaia/rag/pdf_utils.py +235 -235
  169. gaia/rag/sdk.py +2194 -2194
  170. gaia/security.py +183 -163
  171. gaia/talk/app.py +287 -289
  172. gaia/talk/sdk.py +538 -538
  173. gaia/testing/__init__.py +87 -87
  174. gaia/testing/assertions.py +330 -330
  175. gaia/testing/fixtures.py +333 -333
  176. gaia/testing/mocks.py +493 -493
  177. gaia/util.py +46 -46
  178. gaia/utils/__init__.py +33 -33
  179. gaia/utils/file_watcher.py +675 -675
  180. gaia/utils/parsing.py +223 -223
  181. gaia/version.py +100 -100
  182. amd_gaia-0.15.0.dist-info/RECORD +0 -168
  183. gaia/agents/code/app.py +0 -266
  184. gaia/llm/llm_client.py +0 -723
  185. {amd_gaia-0.15.0.dist-info → amd_gaia-0.15.2.dist-info}/top_level.txt +0 -0
@@ -1,439 +1,439 @@
1
- # Copyright(C) 2025-2026 Advanced Micro Devices, Inc. All rights reserved.
2
- # SPDX-License-Identifier: MIT
3
-
4
- import asyncio
5
- import queue
6
- import threading
7
- import time
8
-
9
- from gaia.llm.llm_client import LLMClient
10
- from gaia.logger import get_logger
11
-
12
-
13
- class AudioClient:
14
- """Handles all audio-related functionality including TTS, ASR, and voice chat."""
15
-
16
- def __init__(
17
- self,
18
- whisper_model_size="base",
19
- audio_device_index=None, # Use default input device
20
- silence_threshold=0.5,
21
- enable_tts=True,
22
- logging_level="INFO",
23
- use_claude=False,
24
- use_chatgpt=False,
25
- system_prompt=None,
26
- ):
27
- self.log = get_logger(__name__)
28
- self.log.setLevel(getattr(__import__("logging"), logging_level))
29
-
30
- # Audio configuration
31
- self.whisper_model_size = whisper_model_size
32
- self.audio_device_index = audio_device_index
33
- self.silence_threshold = silence_threshold
34
- self.enable_tts = enable_tts
35
-
36
- # Audio state
37
- self.is_speaking = False
38
- self.tts_thread = None
39
- self.whisper_asr = None
40
- self.transcription_queue = queue.Queue()
41
- self.tts = None
42
-
43
- # Initialize LLM client (base_url handled automatically)
44
- self.llm_client = LLMClient(
45
- use_claude=use_claude,
46
- use_openai=use_chatgpt, # LLMClient uses use_openai, not use_chatgpt
47
- system_prompt=system_prompt,
48
- )
49
-
50
- self.log.info("Audio client initialized.")
51
-
52
- async def start_voice_chat(self, message_processor_callback):
53
- """Start a voice-based chat session."""
54
- try:
55
- self.log.debug("Initializing voice chat...")
56
- print(
57
- "Starting voice chat.\n"
58
- "Say 'stop' to quit application "
59
- "or 'restart' to clear the chat history.\n"
60
- "Press Enter key to stop during audio playback."
61
- )
62
-
63
- # Initialize TTS before starting voice chat
64
- self.initialize_tts()
65
-
66
- from gaia.audio.whisper_asr import WhisperAsr
67
-
68
- # Create WhisperAsr with custom thresholds
69
- # Your audio shows energy levels of 0.02-0.03 when speaking
70
- self.whisper_asr = WhisperAsr(
71
- model_size=self.whisper_model_size,
72
- device_index=self.audio_device_index,
73
- transcription_queue=self.transcription_queue,
74
- silence_threshold=0.01, # Set higher to ensure detection (your levels are 0.01-0.2+)
75
- min_audio_length=16000 * 1.0, # 1 second minimum at 16kHz
76
- )
77
-
78
- # Log the thresholds being used (reduce verbosity)
79
- self.log.debug(
80
- f"Audio settings: SILENCE_THRESHOLD={self.whisper_asr.SILENCE_THRESHOLD}, "
81
- f"MIN_LENGTH={self.whisper_asr.MIN_AUDIO_LENGTH/self.whisper_asr.RATE:.1f}s"
82
- )
83
-
84
- device_name = self.whisper_asr.get_device_name()
85
- self.log.debug(f"Using audio device: {device_name}")
86
-
87
- # Start recording
88
- self.log.debug("Starting audio recording...")
89
- self.whisper_asr.start_recording()
90
-
91
- # Start the processing thread after recording is initialized
92
- self.log.debug("Starting audio processing thread...")
93
- process_thread = threading.Thread(
94
- target=self._process_audio_wrapper, args=(message_processor_callback,)
95
- )
96
- process_thread.daemon = True
97
- process_thread.start()
98
-
99
- # Keep the main thread alive while processing
100
- self.log.debug("Listening for voice input...")
101
- try:
102
- while True:
103
- if not process_thread.is_alive():
104
- self.log.debug("Process thread stopped unexpectedly")
105
- break
106
- if not self.whisper_asr or not self.whisper_asr.is_recording:
107
- self.log.warning("Recording stopped unexpectedly")
108
- break
109
- await asyncio.sleep(0.1)
110
-
111
- except KeyboardInterrupt:
112
- self.log.info("Received keyboard interrupt")
113
- print("\nStopping voice chat...")
114
- except Exception as e:
115
- self.log.error(f"Error in main processing loop: {str(e)}")
116
- raise
117
- finally:
118
- if self.whisper_asr:
119
- self.log.debug("Stopping recording...")
120
- self.whisper_asr.stop_recording()
121
- self.log.debug("Waiting for process thread to finish...")
122
- process_thread.join(timeout=2.0)
123
-
124
- except ImportError:
125
- self.log.error(
126
- 'WhisperAsr not found. Please install voice support with: uv pip install ".[talk]"'
127
- )
128
- raise
129
- except Exception as e:
130
- self.log.error(f"Failed to initialize voice chat: {str(e)}")
131
- raise
132
- finally:
133
- if self.whisper_asr:
134
- self.whisper_asr.stop_recording()
135
- self.log.info("Voice recording stopped")
136
-
137
- async def process_voice_input(self, text, get_stats_callback=None):
138
- """Process transcribed voice input and get AI response"""
139
-
140
- # Initialize TTS streaming
141
- text_queue = None
142
- tts_finished = threading.Event() # Add event to track TTS completion
143
- interrupt_event = threading.Event() # Add event for keyboard interrupts
144
-
145
- try:
146
- # Check if we're currently generating and halt if needed
147
- if self.llm_client.is_generating():
148
- self.log.debug("Generation in progress, halting...")
149
- if self.llm_client.halt_generation():
150
- print("\nGeneration interrupted.")
151
- await asyncio.sleep(0.5)
152
-
153
- # Pause audio recording before sending query
154
- if self.whisper_asr:
155
- self.whisper_asr.pause_recording()
156
- self.log.debug("Recording paused before generation")
157
-
158
- self.log.debug(f"Sending message to LLM: {text[:50]}...")
159
- print("\nGaia: ", end="", flush=True)
160
-
161
- # Keyboard listener thread for both generation and playback
162
- def keyboard_listener():
163
- input() # Wait for any input
164
-
165
- # Use LLMClient to halt generation
166
- if self.llm_client.halt_generation():
167
- print("\nGeneration interrupted.")
168
- else:
169
- print("\nInterrupt requested.")
170
-
171
- interrupt_event.set()
172
- if text_queue:
173
- text_queue.put("__HALT__") # Signal TTS to stop immediately
174
-
175
- # Start keyboard listener thread
176
- keyboard_thread = threading.Thread(target=keyboard_listener)
177
- keyboard_thread.daemon = True
178
- keyboard_thread.start()
179
-
180
- if self.enable_tts:
181
- text_queue = queue.Queue(maxsize=100)
182
-
183
- # Define status callback to update speaking state
184
- def tts_status_callback(is_speaking):
185
- self.is_speaking = is_speaking
186
- if not is_speaking: # When TTS finishes speaking
187
- tts_finished.set()
188
- if self.whisper_asr:
189
- self.whisper_asr.resume_recording()
190
- else: # When TTS starts speaking
191
- if self.whisper_asr:
192
- self.whisper_asr.pause_recording()
193
- self.log.debug(f"TTS speaking state: {is_speaking}")
194
-
195
- self.tts_thread = threading.Thread(
196
- target=self.tts.generate_speech_streaming,
197
- args=(text_queue,),
198
- kwargs={
199
- "status_callback": tts_status_callback,
200
- "interrupt_event": interrupt_event,
201
- },
202
- daemon=True,
203
- )
204
- self.tts_thread.start()
205
-
206
- # Use LLMClient streaming instead of WebSocket
207
- accumulated_response = ""
208
- initial_buffer = "" # Buffer for the start of response
209
- initial_buffer_sent = False
210
-
211
- try:
212
- # Start LLM generation with streaming
213
- response_stream = self.llm_client.generate(text, stream=True)
214
-
215
- # Process streaming response
216
- for chunk in response_stream:
217
- if interrupt_event.is_set():
218
- self.log.debug("Keyboard interrupt detected, stopping...")
219
- if text_queue:
220
- text_queue.put("__END__")
221
- break
222
-
223
- if self.transcription_queue.qsize() > 0:
224
- self.log.debug(
225
- "New input detected during generation, stopping..."
226
- )
227
- if text_queue:
228
- text_queue.put("__END__")
229
- # Use LLMClient to halt generation
230
- if self.llm_client.halt_generation():
231
- self.log.debug("Generation interrupted for new input.")
232
- return
233
-
234
- if chunk:
235
- print(chunk, end="", flush=True)
236
- if text_queue:
237
- if not initial_buffer_sent:
238
- initial_buffer += chunk
239
- # Send if we've reached 20 chars or if we get a clear end marker
240
- if len(initial_buffer) >= 20 or chunk.endswith(
241
- ("\n", ". ", "! ", "? ")
242
- ):
243
- text_queue.put(initial_buffer)
244
- initial_buffer_sent = True
245
- else:
246
- text_queue.put(chunk)
247
- accumulated_response += chunk
248
-
249
- # Send any remaining buffered content
250
- if text_queue:
251
- if not initial_buffer_sent and initial_buffer:
252
- # Small delay for very short responses
253
- if len(initial_buffer) <= 20:
254
- await asyncio.sleep(0.1)
255
- text_queue.put(initial_buffer)
256
- text_queue.put("__END__")
257
-
258
- except Exception as e:
259
- if text_queue:
260
- text_queue.put("__END__")
261
- raise e
262
- finally:
263
- if self.tts_thread and self.tts_thread.is_alive():
264
- self.tts_thread.join(timeout=1.0) # Add timeout to thread join
265
- keyboard_thread.join(timeout=1.0) # Add timeout to keyboard thread join
266
-
267
- print("\n")
268
- # Get performance stats from LLMClient
269
- if get_stats_callback:
270
- # First try the provided callback for backward compatibility
271
- stats = get_stats_callback()
272
- else:
273
- # Use LLMClient stats
274
- stats = self.llm_client.get_performance_stats()
275
-
276
- if stats:
277
- from pprint import pprint
278
-
279
- formatted_stats = {
280
- k: round(v, 1) if isinstance(v, float) else v
281
- for k, v in stats.items()
282
- }
283
- pprint(formatted_stats)
284
-
285
- except Exception as e:
286
- if text_queue:
287
- text_queue.put("__END__")
288
- raise e
289
- finally:
290
- if self.tts_thread and self.tts_thread.is_alive():
291
- # Wait for TTS to finish before resuming recording
292
- tts_finished.wait(timeout=2.0) # Add reasonable timeout
293
- self.tts_thread.join(timeout=1.0)
294
-
295
- # Only resume recording after TTS is completely finished
296
- if self.whisper_asr:
297
- self.whisper_asr.resume_recording()
298
-
299
- def initialize_tts(self):
300
- """Initialize TTS if enabled."""
301
- if self.enable_tts:
302
- try:
303
- from gaia.audio.kokoro_tts import KokoroTTS
304
-
305
- self.tts = KokoroTTS()
306
- self.log.debug("TTS initialized successfully")
307
- except Exception as e:
308
- raise RuntimeError(
309
- f'Failed to initialize TTS:\n{e}\nInstall talk dependencies with: uv pip install ".[talk]"\nYou can also use --no-tts option to disable TTS'
310
- )
311
-
312
- async def speak_text(self, text: str) -> None:
313
- """Speak text using initialized TTS, if available."""
314
- if not self.enable_tts:
315
- return
316
- if not getattr(self, "tts", None):
317
- self.log.debug("TTS is not initialized; skipping speak_text")
318
- return
319
- # Reuse the streaming path used in process_voice_input
320
- text_queue = queue.Queue(maxsize=100)
321
- interrupt_event = threading.Event()
322
- tts_thread = threading.Thread(
323
- target=self.tts.generate_speech_streaming,
324
- args=(text_queue,),
325
- kwargs={"interrupt_event": interrupt_event},
326
- daemon=True,
327
- )
328
- tts_thread.start()
329
- # Send the whole text and end
330
- text_queue.put(text)
331
- text_queue.put("__END__")
332
- tts_thread.join(timeout=5.0)
333
-
334
- def _process_audio_wrapper(self, message_processor_callback):
335
- """Wrapper method to process audio and handle transcriptions"""
336
- try:
337
- accumulated_text = []
338
- current_display = ""
339
- last_transcription_time = time.time()
340
- spinner_chars = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"]
341
- dots_animation = [" ", ". ", ".. ", "..."]
342
- spinner_idx = 0
343
- dots_idx = 0
344
- animation_counter = 0
345
- self.is_speaking = False # Initialize speaking state
346
-
347
- while self.whisper_asr and self.whisper_asr.is_recording:
348
- try:
349
- text = self.transcription_queue.get(timeout=0.1)
350
-
351
- current_time = time.time()
352
- time_since_last = current_time - last_transcription_time
353
- cleaned_text = text.lower().strip().rstrip(".!?")
354
-
355
- # Handle special commands
356
- if cleaned_text in ["stop"]:
357
- print("\nStopping voice chat...")
358
- self.whisper_asr.stop_recording()
359
- break
360
-
361
- # Update animations
362
- spinner_idx = (spinner_idx + 1) % len(spinner_chars)
363
- animation_counter += 1
364
- if animation_counter % 4 == 0: # Update dots every fourth cycle
365
- dots_idx = (dots_idx + 1) % len(dots_animation)
366
- spinner = spinner_chars[spinner_idx]
367
- dots = dots_animation[dots_idx]
368
-
369
- # Normal text processing - only if it's not a system message
370
- if text != current_display:
371
- # Clear the current line and display updated text with spinner
372
- print(f"\r\033[K{spinner} {text}", end="", flush=True)
373
- current_display = text
374
-
375
- # Only add new text if it's significantly different
376
- if not any(text in existing for existing in accumulated_text):
377
- accumulated_text = [text] # Replace instead of append
378
- last_transcription_time = current_time
379
-
380
- # Process accumulated text after silence threshold
381
- if time_since_last > self.silence_threshold:
382
- if accumulated_text:
383
- complete_text = accumulated_text[
384
- -1
385
- ] # Use only the last transcription
386
- print() # Add a newline before agent response
387
- asyncio.run(message_processor_callback(complete_text))
388
- accumulated_text = []
389
- current_display = ""
390
-
391
- except queue.Empty:
392
- # Update animations
393
- spinner_idx = (spinner_idx + 1) % len(spinner_chars)
394
- animation_counter += 1
395
- if animation_counter % 4 == 0:
396
- dots_idx = (dots_idx + 1) % len(dots_animation)
397
- spinner = spinner_chars[spinner_idx]
398
- dots = dots_animation[dots_idx]
399
-
400
- if current_display:
401
- print(
402
- f"\r\033[K{spinner} {current_display}", end="", flush=True
403
- )
404
- else:
405
- # Access the class-level speaking state
406
- status = (
407
- "Speaking"
408
- if getattr(self, "is_speaking", False)
409
- else "Listening"
410
- )
411
- print(f"\r\033[K{spinner} {status}{dots}", end="", flush=True)
412
-
413
- if (
414
- accumulated_text
415
- and (time.time() - last_transcription_time)
416
- > self.silence_threshold
417
- ):
418
- complete_text = accumulated_text[-1]
419
- print() # Add a newline before agent response
420
- asyncio.run(message_processor_callback(complete_text))
421
- accumulated_text = []
422
- current_display = ""
423
-
424
- except Exception as e:
425
- self.log.error(f"Error in process_audio_wrapper: {str(e)}")
426
- finally:
427
- if self.whisper_asr:
428
- self.whisper_asr.stop_recording()
429
- if self.tts_thread and self.tts_thread.is_alive():
430
- self.tts_thread.join(timeout=1.0) # Add timeout to thread join
431
-
432
- async def halt_generation(self):
433
- """Send a request to halt the current generation."""
434
- if self.llm_client.halt_generation():
435
- self.log.debug("Successfully halted generation via LLMClient")
436
- print("\nGeneration interrupted.")
437
- else:
438
- self.log.debug("Halt requested - generation will stop on next iteration")
439
- print("\nInterrupt requested.")
1
+ # Copyright(C) 2025-2026 Advanced Micro Devices, Inc. All rights reserved.
2
+ # SPDX-License-Identifier: MIT
3
+
4
+ import asyncio
5
+ import queue
6
+ import threading
7
+ import time
8
+
9
+ from gaia.llm import create_client
10
+ from gaia.logger import get_logger
11
+
12
+
13
+ class AudioClient:
14
+ """Handles all audio-related functionality including TTS, ASR, and voice chat."""
15
+
16
+ def __init__(
17
+ self,
18
+ whisper_model_size="base",
19
+ audio_device_index=None, # Use default input device
20
+ silence_threshold=0.5,
21
+ enable_tts=True,
22
+ logging_level="INFO",
23
+ use_claude=False,
24
+ use_chatgpt=False,
25
+ system_prompt=None,
26
+ ):
27
+ self.log = get_logger(__name__)
28
+ self.log.setLevel(getattr(__import__("logging"), logging_level))
29
+
30
+ # Audio configuration
31
+ self.whisper_model_size = whisper_model_size
32
+ self.audio_device_index = audio_device_index
33
+ self.silence_threshold = silence_threshold
34
+ self.enable_tts = enable_tts
35
+
36
+ # Audio state
37
+ self.is_speaking = False
38
+ self.tts_thread = None
39
+ self.whisper_asr = None
40
+ self.transcription_queue = queue.Queue()
41
+ self.tts = None
42
+
43
+ # Initialize LLM client - factory auto-detects provider from flags
44
+ self.llm_client = create_client(
45
+ use_claude=use_claude,
46
+ use_openai=use_chatgpt,
47
+ system_prompt=system_prompt,
48
+ )
49
+
50
+ self.log.info("Audio client initialized.")
51
+
52
+ async def start_voice_chat(self, message_processor_callback):
53
+ """Start a voice-based chat session."""
54
+ try:
55
+ self.log.debug("Initializing voice chat...")
56
+ print(
57
+ "Starting voice chat.\n"
58
+ "Say 'stop' to quit application "
59
+ "or 'restart' to clear the chat history.\n"
60
+ "Press Enter key to stop during audio playback."
61
+ )
62
+
63
+ # Initialize TTS before starting voice chat
64
+ self.initialize_tts()
65
+
66
+ from gaia.audio.whisper_asr import WhisperAsr
67
+
68
+ # Create WhisperAsr with custom thresholds
69
+ # Your audio shows energy levels of 0.02-0.03 when speaking
70
+ self.whisper_asr = WhisperAsr(
71
+ model_size=self.whisper_model_size,
72
+ device_index=self.audio_device_index,
73
+ transcription_queue=self.transcription_queue,
74
+ silence_threshold=0.01, # Set higher to ensure detection (your levels are 0.01-0.2+)
75
+ min_audio_length=16000 * 1.0, # 1 second minimum at 16kHz
76
+ )
77
+
78
+ # Log the thresholds being used (reduce verbosity)
79
+ self.log.debug(
80
+ f"Audio settings: SILENCE_THRESHOLD={self.whisper_asr.SILENCE_THRESHOLD}, "
81
+ f"MIN_LENGTH={self.whisper_asr.MIN_AUDIO_LENGTH/self.whisper_asr.RATE:.1f}s"
82
+ )
83
+
84
+ device_name = self.whisper_asr.get_device_name()
85
+ self.log.debug(f"Using audio device: {device_name}")
86
+
87
+ # Start recording
88
+ self.log.debug("Starting audio recording...")
89
+ self.whisper_asr.start_recording()
90
+
91
+ # Start the processing thread after recording is initialized
92
+ self.log.debug("Starting audio processing thread...")
93
+ process_thread = threading.Thread(
94
+ target=self._process_audio_wrapper, args=(message_processor_callback,)
95
+ )
96
+ process_thread.daemon = True
97
+ process_thread.start()
98
+
99
+ # Keep the main thread alive while processing
100
+ self.log.debug("Listening for voice input...")
101
+ try:
102
+ while True:
103
+ if not process_thread.is_alive():
104
+ self.log.debug("Process thread stopped unexpectedly")
105
+ break
106
+ if not self.whisper_asr or not self.whisper_asr.is_recording:
107
+ self.log.warning("Recording stopped unexpectedly")
108
+ break
109
+ await asyncio.sleep(0.1)
110
+
111
+ except KeyboardInterrupt:
112
+ self.log.info("Received keyboard interrupt")
113
+ print("\nStopping voice chat...")
114
+ except Exception as e:
115
+ self.log.error(f"Error in main processing loop: {str(e)}")
116
+ raise
117
+ finally:
118
+ if self.whisper_asr:
119
+ self.log.debug("Stopping recording...")
120
+ self.whisper_asr.stop_recording()
121
+ self.log.debug("Waiting for process thread to finish...")
122
+ process_thread.join(timeout=2.0)
123
+
124
+ except ImportError:
125
+ self.log.error(
126
+ 'WhisperAsr not found. Please install voice support with: uv pip install ".[talk]"'
127
+ )
128
+ raise
129
+ except Exception as e:
130
+ self.log.error(f"Failed to initialize voice chat: {str(e)}")
131
+ raise
132
+ finally:
133
+ if self.whisper_asr:
134
+ self.whisper_asr.stop_recording()
135
+ self.log.info("Voice recording stopped")
136
+
137
+ async def process_voice_input(self, text, get_stats_callback=None):
138
+ """Process transcribed voice input and get AI response"""
139
+
140
+ # Initialize TTS streaming
141
+ text_queue = None
142
+ tts_finished = threading.Event() # Add event to track TTS completion
143
+ interrupt_event = threading.Event() # Add event for keyboard interrupts
144
+
145
+ try:
146
+ # Check if we're currently generating and halt if needed
147
+ if self.llm_client.is_generating():
148
+ self.log.debug("Generation in progress, halting...")
149
+ if self.llm_client.halt_generation():
150
+ print("\nGeneration interrupted.")
151
+ await asyncio.sleep(0.5)
152
+
153
+ # Pause audio recording before sending query
154
+ if self.whisper_asr:
155
+ self.whisper_asr.pause_recording()
156
+ self.log.debug("Recording paused before generation")
157
+
158
+ self.log.debug(f"Sending message to LLM: {text[:50]}...")
159
+ print("\nGaia: ", end="", flush=True)
160
+
161
+ # Keyboard listener thread for both generation and playback
162
+ def keyboard_listener():
163
+ input() # Wait for any input
164
+
165
+ # Use LLMClient to halt generation
166
+ if self.llm_client.halt_generation():
167
+ print("\nGeneration interrupted.")
168
+ else:
169
+ print("\nInterrupt requested.")
170
+
171
+ interrupt_event.set()
172
+ if text_queue:
173
+ text_queue.put("__HALT__") # Signal TTS to stop immediately
174
+
175
+ # Start keyboard listener thread
176
+ keyboard_thread = threading.Thread(target=keyboard_listener)
177
+ keyboard_thread.daemon = True
178
+ keyboard_thread.start()
179
+
180
+ if self.enable_tts:
181
+ text_queue = queue.Queue(maxsize=100)
182
+
183
+ # Define status callback to update speaking state
184
+ def tts_status_callback(is_speaking):
185
+ self.is_speaking = is_speaking
186
+ if not is_speaking: # When TTS finishes speaking
187
+ tts_finished.set()
188
+ if self.whisper_asr:
189
+ self.whisper_asr.resume_recording()
190
+ else: # When TTS starts speaking
191
+ if self.whisper_asr:
192
+ self.whisper_asr.pause_recording()
193
+ self.log.debug(f"TTS speaking state: {is_speaking}")
194
+
195
+ self.tts_thread = threading.Thread(
196
+ target=self.tts.generate_speech_streaming,
197
+ args=(text_queue,),
198
+ kwargs={
199
+ "status_callback": tts_status_callback,
200
+ "interrupt_event": interrupt_event,
201
+ },
202
+ daemon=True,
203
+ )
204
+ self.tts_thread.start()
205
+
206
+ # Use LLMClient streaming instead of WebSocket
207
+ accumulated_response = ""
208
+ initial_buffer = "" # Buffer for the start of response
209
+ initial_buffer_sent = False
210
+
211
+ try:
212
+ # Start LLM generation with streaming
213
+ response_stream = self.llm_client.generate(text, stream=True)
214
+
215
+ # Process streaming response
216
+ for chunk in response_stream:
217
+ if interrupt_event.is_set():
218
+ self.log.debug("Keyboard interrupt detected, stopping...")
219
+ if text_queue:
220
+ text_queue.put("__END__")
221
+ break
222
+
223
+ if self.transcription_queue.qsize() > 0:
224
+ self.log.debug(
225
+ "New input detected during generation, stopping..."
226
+ )
227
+ if text_queue:
228
+ text_queue.put("__END__")
229
+ # Use LLMClient to halt generation
230
+ if self.llm_client.halt_generation():
231
+ self.log.debug("Generation interrupted for new input.")
232
+ return
233
+
234
+ if chunk:
235
+ print(chunk, end="", flush=True)
236
+ if text_queue:
237
+ if not initial_buffer_sent:
238
+ initial_buffer += chunk
239
+ # Send if we've reached 20 chars or if we get a clear end marker
240
+ if len(initial_buffer) >= 20 or chunk.endswith(
241
+ ("\n", ". ", "! ", "? ")
242
+ ):
243
+ text_queue.put(initial_buffer)
244
+ initial_buffer_sent = True
245
+ else:
246
+ text_queue.put(chunk)
247
+ accumulated_response += chunk
248
+
249
+ # Send any remaining buffered content
250
+ if text_queue:
251
+ if not initial_buffer_sent and initial_buffer:
252
+ # Small delay for very short responses
253
+ if len(initial_buffer) <= 20:
254
+ await asyncio.sleep(0.1)
255
+ text_queue.put(initial_buffer)
256
+ text_queue.put("__END__")
257
+
258
+ except Exception as e:
259
+ if text_queue:
260
+ text_queue.put("__END__")
261
+ raise e
262
+ finally:
263
+ if self.tts_thread and self.tts_thread.is_alive():
264
+ self.tts_thread.join(timeout=1.0) # Add timeout to thread join
265
+ keyboard_thread.join(timeout=1.0) # Add timeout to keyboard thread join
266
+
267
+ print("\n")
268
+ # Get performance stats from LLMClient
269
+ if get_stats_callback:
270
+ # First try the provided callback for backward compatibility
271
+ stats = get_stats_callback()
272
+ else:
273
+ # Use LLMClient stats
274
+ stats = self.llm_client.get_performance_stats()
275
+
276
+ if stats:
277
+ from pprint import pprint
278
+
279
+ formatted_stats = {
280
+ k: round(v, 1) if isinstance(v, float) else v
281
+ for k, v in stats.items()
282
+ }
283
+ pprint(formatted_stats)
284
+
285
+ except Exception as e:
286
+ if text_queue:
287
+ text_queue.put("__END__")
288
+ raise e
289
+ finally:
290
+ if self.tts_thread and self.tts_thread.is_alive():
291
+ # Wait for TTS to finish before resuming recording
292
+ tts_finished.wait(timeout=2.0) # Add reasonable timeout
293
+ self.tts_thread.join(timeout=1.0)
294
+
295
+ # Only resume recording after TTS is completely finished
296
+ if self.whisper_asr:
297
+ self.whisper_asr.resume_recording()
298
+
299
+ def initialize_tts(self):
300
+ """Initialize TTS if enabled."""
301
+ if self.enable_tts:
302
+ try:
303
+ from gaia.audio.kokoro_tts import KokoroTTS
304
+
305
+ self.tts = KokoroTTS()
306
+ self.log.debug("TTS initialized successfully")
307
+ except Exception as e:
308
+ raise RuntimeError(
309
+ f'Failed to initialize TTS:\n{e}\nInstall talk dependencies with: uv pip install ".[talk]"\nYou can also use --no-tts option to disable TTS'
310
+ )
311
+
312
+ async def speak_text(self, text: str) -> None:
313
+ """Speak text using initialized TTS, if available."""
314
+ if not self.enable_tts:
315
+ return
316
+ if not getattr(self, "tts", None):
317
+ self.log.debug("TTS is not initialized; skipping speak_text")
318
+ return
319
+ # Reuse the streaming path used in process_voice_input
320
+ text_queue = queue.Queue(maxsize=100)
321
+ interrupt_event = threading.Event()
322
+ tts_thread = threading.Thread(
323
+ target=self.tts.generate_speech_streaming,
324
+ args=(text_queue,),
325
+ kwargs={"interrupt_event": interrupt_event},
326
+ daemon=True,
327
+ )
328
+ tts_thread.start()
329
+ # Send the whole text and end
330
+ text_queue.put(text)
331
+ text_queue.put("__END__")
332
+ tts_thread.join(timeout=5.0)
333
+
334
    def _process_audio_wrapper(self, message_processor_callback):
        """Wrapper method to process audio and handle transcriptions.

        Runs in a background thread: polls ``self.transcription_queue`` while
        the ASR is recording, renders a spinner/status line to the terminal,
        and — after ``self.silence_threshold`` seconds without a new
        transcription — hands the latest transcription to
        ``message_processor_callback`` via ``asyncio.run``.

        Args:
            message_processor_callback: Async callable taking the final
                transcribed text; invoked once per utterance.
        """
        try:
            accumulated_text = []
            current_display = ""
            last_transcription_time = time.time()
            # Braille spinner frames and trailing-dots frames for the
            # "Listening/Speaking" idle animation.
            spinner_chars = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"]
            dots_animation = [" ", ". ", ".. ", "..."]
            spinner_idx = 0
            dots_idx = 0
            animation_counter = 0
            self.is_speaking = False  # Initialize speaking state

            while self.whisper_asr and self.whisper_asr.is_recording:
                try:
                    # Block briefly so the animation still ticks (~10 Hz)
                    # when no transcription arrives (queue.Empty branch).
                    text = self.transcription_queue.get(timeout=0.1)

                    current_time = time.time()
                    time_since_last = current_time - last_transcription_time
                    # Normalize for command matching: lowercase, trimmed,
                    # trailing punctuation stripped.
                    cleaned_text = text.lower().strip().rstrip(".!?")

                    # Handle special commands
                    if cleaned_text in ["stop"]:
                        print("\nStopping voice chat...")
                        self.whisper_asr.stop_recording()
                        break

                    # Update animations
                    spinner_idx = (spinner_idx + 1) % len(spinner_chars)
                    animation_counter += 1
                    if animation_counter % 4 == 0:  # Update dots every fourth cycle
                        dots_idx = (dots_idx + 1) % len(dots_animation)
                    spinner = spinner_chars[spinner_idx]
                    # NOTE(review): `dots` is assigned but only rendered in the
                    # queue.Empty branch below — unused on this path.
                    dots = dots_animation[dots_idx]

                    # Normal text processing - only if it's not a system message
                    if text != current_display:
                        # Clear the current line and display updated text with spinner
                        # ("\r\033[K" = carriage return + ANSI erase-to-end-of-line).
                        print(f"\r\033[K{spinner} {text}", end="", flush=True)
                        current_display = text

                    # Only add new text if it's significantly different
                    if not any(text in existing for existing in accumulated_text):
                        accumulated_text = [text]  # Replace instead of append
                        last_transcription_time = current_time

                    # Process accumulated text after silence threshold
                    if time_since_last > self.silence_threshold:
                        if accumulated_text:
                            complete_text = accumulated_text[
                                -1
                            ]  # Use only the last transcription
                            print()  # Add a newline before agent response
                            # Bridge into async: the callback is a coroutine
                            # function run to completion on this worker thread.
                            asyncio.run(message_processor_callback(complete_text))
                            accumulated_text = []
                            current_display = ""

                except queue.Empty:
                    # No transcription this tick — keep the animation alive.
                    # Update animations
                    spinner_idx = (spinner_idx + 1) % len(spinner_chars)
                    animation_counter += 1
                    if animation_counter % 4 == 0:
                        dots_idx = (dots_idx + 1) % len(dots_animation)
                    spinner = spinner_chars[spinner_idx]
                    dots = dots_animation[dots_idx]

                    if current_display:
                        # Keep showing the in-progress transcription.
                        print(
                            f"\r\033[K{spinner} {current_display}", end="", flush=True
                        )
                    else:
                        # Access the class-level speaking state
                        status = (
                            "Speaking"
                            if getattr(self, "is_speaking", False)
                            else "Listening"
                        )
                        print(f"\r\033[K{spinner} {status}{dots}", end="", flush=True)

                    # Silence threshold can also elapse while the queue is
                    # empty — flush the pending utterance here too.
                    if (
                        accumulated_text
                        and (time.time() - last_transcription_time)
                        > self.silence_threshold
                    ):
                        complete_text = accumulated_text[-1]
                        print()  # Add a newline before agent response
                        asyncio.run(message_processor_callback(complete_text))
                        accumulated_text = []
                        current_display = ""

        except Exception as e:
            self.log.error(f"Error in process_audio_wrapper: {str(e)}")
        finally:
            # Always stop recording and wind down the TTS thread on exit.
            if self.whisper_asr:
                self.whisper_asr.stop_recording()
            if self.tts_thread and self.tts_thread.is_alive():
                self.tts_thread.join(timeout=1.0)  # Add timeout to thread join
431
+
432
    async def halt_generation(self):
        """Send a request to halt the current generation.

        Delegates to ``self.llm_client.halt_generation()``. If the client
        reports an immediate halt, announces the interruption; otherwise
        reports that the interrupt was requested (generation stops on the
        next streaming iteration).
        """
        if self.llm_client.halt_generation():
            self.log.debug("Successfully halted generation via LLMClient")
            print("\nGeneration interrupted.")
        else:
            self.log.debug("Halt requested - generation will stop on next iteration")
            print("\nInterrupt requested.")