amd-gaia 0.14.3__py3-none-any.whl → 0.15.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (181)
  1. {amd_gaia-0.14.3.dist-info → amd_gaia-0.15.1.dist-info}/METADATA +223 -223
  2. amd_gaia-0.15.1.dist-info/RECORD +178 -0
  3. {amd_gaia-0.14.3.dist-info → amd_gaia-0.15.1.dist-info}/entry_points.txt +1 -0
  4. {amd_gaia-0.14.3.dist-info → amd_gaia-0.15.1.dist-info}/licenses/LICENSE.md +20 -20
  5. gaia/__init__.py +29 -29
  6. gaia/agents/__init__.py +19 -19
  7. gaia/agents/base/__init__.py +9 -9
  8. gaia/agents/base/agent.py +2177 -2177
  9. gaia/agents/base/api_agent.py +120 -120
  10. gaia/agents/base/console.py +1841 -1841
  11. gaia/agents/base/errors.py +237 -237
  12. gaia/agents/base/mcp_agent.py +86 -86
  13. gaia/agents/base/tools.py +83 -83
  14. gaia/agents/blender/agent.py +556 -556
  15. gaia/agents/blender/agent_simple.py +133 -135
  16. gaia/agents/blender/app.py +211 -211
  17. gaia/agents/blender/app_simple.py +41 -41
  18. gaia/agents/blender/core/__init__.py +16 -16
  19. gaia/agents/blender/core/materials.py +506 -506
  20. gaia/agents/blender/core/objects.py +316 -316
  21. gaia/agents/blender/core/rendering.py +225 -225
  22. gaia/agents/blender/core/scene.py +220 -220
  23. gaia/agents/blender/core/view.py +146 -146
  24. gaia/agents/chat/__init__.py +9 -9
  25. gaia/agents/chat/agent.py +835 -835
  26. gaia/agents/chat/app.py +1058 -1058
  27. gaia/agents/chat/session.py +508 -508
  28. gaia/agents/chat/tools/__init__.py +15 -15
  29. gaia/agents/chat/tools/file_tools.py +96 -96
  30. gaia/agents/chat/tools/rag_tools.py +1729 -1729
  31. gaia/agents/chat/tools/shell_tools.py +436 -436
  32. gaia/agents/code/__init__.py +7 -7
  33. gaia/agents/code/agent.py +549 -549
  34. gaia/agents/code/cli.py +377 -0
  35. gaia/agents/code/models.py +135 -135
  36. gaia/agents/code/orchestration/__init__.py +24 -24
  37. gaia/agents/code/orchestration/checklist_executor.py +1763 -1763
  38. gaia/agents/code/orchestration/checklist_generator.py +713 -713
  39. gaia/agents/code/orchestration/factories/__init__.py +9 -9
  40. gaia/agents/code/orchestration/factories/base.py +63 -63
  41. gaia/agents/code/orchestration/factories/nextjs_factory.py +118 -118
  42. gaia/agents/code/orchestration/factories/python_factory.py +106 -106
  43. gaia/agents/code/orchestration/orchestrator.py +841 -841
  44. gaia/agents/code/orchestration/project_analyzer.py +391 -391
  45. gaia/agents/code/orchestration/steps/__init__.py +67 -67
  46. gaia/agents/code/orchestration/steps/base.py +188 -188
  47. gaia/agents/code/orchestration/steps/error_handler.py +314 -314
  48. gaia/agents/code/orchestration/steps/nextjs.py +828 -828
  49. gaia/agents/code/orchestration/steps/python.py +307 -307
  50. gaia/agents/code/orchestration/template_catalog.py +469 -469
  51. gaia/agents/code/orchestration/workflows/__init__.py +14 -14
  52. gaia/agents/code/orchestration/workflows/base.py +80 -80
  53. gaia/agents/code/orchestration/workflows/nextjs.py +186 -186
  54. gaia/agents/code/orchestration/workflows/python.py +94 -94
  55. gaia/agents/code/prompts/__init__.py +11 -11
  56. gaia/agents/code/prompts/base_prompt.py +77 -77
  57. gaia/agents/code/prompts/code_patterns.py +2036 -2036
  58. gaia/agents/code/prompts/nextjs_prompt.py +40 -40
  59. gaia/agents/code/prompts/python_prompt.py +109 -109
  60. gaia/agents/code/schema_inference.py +365 -365
  61. gaia/agents/code/system_prompt.py +41 -41
  62. gaia/agents/code/tools/__init__.py +42 -42
  63. gaia/agents/code/tools/cli_tools.py +1138 -1138
  64. gaia/agents/code/tools/code_formatting.py +319 -319
  65. gaia/agents/code/tools/code_tools.py +769 -769
  66. gaia/agents/code/tools/error_fixing.py +1347 -1347
  67. gaia/agents/code/tools/external_tools.py +180 -180
  68. gaia/agents/code/tools/file_io.py +845 -845
  69. gaia/agents/code/tools/prisma_tools.py +190 -190
  70. gaia/agents/code/tools/project_management.py +1016 -1016
  71. gaia/agents/code/tools/testing.py +321 -321
  72. gaia/agents/code/tools/typescript_tools.py +122 -122
  73. gaia/agents/code/tools/validation_parsing.py +461 -461
  74. gaia/agents/code/tools/validation_tools.py +806 -806
  75. gaia/agents/code/tools/web_dev_tools.py +1758 -1758
  76. gaia/agents/code/validators/__init__.py +16 -16
  77. gaia/agents/code/validators/antipattern_checker.py +241 -241
  78. gaia/agents/code/validators/ast_analyzer.py +197 -197
  79. gaia/agents/code/validators/requirements_validator.py +145 -145
  80. gaia/agents/code/validators/syntax_validator.py +171 -171
  81. gaia/agents/docker/__init__.py +7 -7
  82. gaia/agents/docker/agent.py +642 -642
  83. gaia/agents/emr/__init__.py +8 -8
  84. gaia/agents/emr/agent.py +1506 -1506
  85. gaia/agents/emr/cli.py +1322 -1322
  86. gaia/agents/emr/constants.py +475 -475
  87. gaia/agents/emr/dashboard/__init__.py +4 -4
  88. gaia/agents/emr/dashboard/server.py +1974 -1974
  89. gaia/agents/jira/__init__.py +11 -11
  90. gaia/agents/jira/agent.py +894 -894
  91. gaia/agents/jira/jql_templates.py +299 -299
  92. gaia/agents/routing/__init__.py +7 -7
  93. gaia/agents/routing/agent.py +567 -570
  94. gaia/agents/routing/system_prompt.py +75 -75
  95. gaia/agents/summarize/__init__.py +11 -0
  96. gaia/agents/summarize/agent.py +885 -0
  97. gaia/agents/summarize/prompts.py +129 -0
  98. gaia/api/__init__.py +23 -23
  99. gaia/api/agent_registry.py +238 -238
  100. gaia/api/app.py +305 -305
  101. gaia/api/openai_server.py +575 -575
  102. gaia/api/schemas.py +186 -186
  103. gaia/api/sse_handler.py +373 -373
  104. gaia/apps/__init__.py +4 -4
  105. gaia/apps/llm/__init__.py +6 -6
  106. gaia/apps/llm/app.py +173 -169
  107. gaia/apps/summarize/app.py +116 -633
  108. gaia/apps/summarize/html_viewer.py +133 -133
  109. gaia/apps/summarize/pdf_formatter.py +284 -284
  110. gaia/audio/__init__.py +2 -2
  111. gaia/audio/audio_client.py +439 -439
  112. gaia/audio/audio_recorder.py +269 -269
  113. gaia/audio/kokoro_tts.py +599 -599
  114. gaia/audio/whisper_asr.py +432 -432
  115. gaia/chat/__init__.py +16 -16
  116. gaia/chat/app.py +430 -430
  117. gaia/chat/prompts.py +522 -522
  118. gaia/chat/sdk.py +1228 -1225
  119. gaia/cli.py +5481 -5621
  120. gaia/database/__init__.py +10 -10
  121. gaia/database/agent.py +176 -176
  122. gaia/database/mixin.py +290 -290
  123. gaia/database/testing.py +64 -64
  124. gaia/eval/batch_experiment.py +2332 -2332
  125. gaia/eval/claude.py +542 -542
  126. gaia/eval/config.py +37 -37
  127. gaia/eval/email_generator.py +512 -512
  128. gaia/eval/eval.py +3179 -3179
  129. gaia/eval/groundtruth.py +1130 -1130
  130. gaia/eval/transcript_generator.py +582 -582
  131. gaia/eval/webapp/README.md +167 -167
  132. gaia/eval/webapp/package-lock.json +875 -875
  133. gaia/eval/webapp/package.json +20 -20
  134. gaia/eval/webapp/public/app.js +3402 -3402
  135. gaia/eval/webapp/public/index.html +87 -87
  136. gaia/eval/webapp/public/styles.css +3661 -3661
  137. gaia/eval/webapp/server.js +415 -415
  138. gaia/eval/webapp/test-setup.js +72 -72
  139. gaia/llm/__init__.py +9 -2
  140. gaia/llm/base_client.py +60 -0
  141. gaia/llm/exceptions.py +12 -0
  142. gaia/llm/factory.py +70 -0
  143. gaia/llm/lemonade_client.py +3236 -3221
  144. gaia/llm/lemonade_manager.py +294 -294
  145. gaia/llm/providers/__init__.py +9 -0
  146. gaia/llm/providers/claude.py +108 -0
  147. gaia/llm/providers/lemonade.py +120 -0
  148. gaia/llm/providers/openai_provider.py +79 -0
  149. gaia/llm/vlm_client.py +382 -382
  150. gaia/logger.py +189 -189
  151. gaia/mcp/agent_mcp_server.py +245 -245
  152. gaia/mcp/blender_mcp_client.py +138 -138
  153. gaia/mcp/blender_mcp_server.py +648 -648
  154. gaia/mcp/context7_cache.py +332 -332
  155. gaia/mcp/external_services.py +518 -518
  156. gaia/mcp/mcp_bridge.py +811 -550
  157. gaia/mcp/servers/__init__.py +6 -6
  158. gaia/mcp/servers/docker_mcp.py +83 -83
  159. gaia/perf_analysis.py +361 -0
  160. gaia/rag/__init__.py +10 -10
  161. gaia/rag/app.py +293 -293
  162. gaia/rag/demo.py +304 -304
  163. gaia/rag/pdf_utils.py +235 -235
  164. gaia/rag/sdk.py +2194 -2194
  165. gaia/security.py +163 -163
  166. gaia/talk/app.py +289 -289
  167. gaia/talk/sdk.py +538 -538
  168. gaia/testing/__init__.py +87 -87
  169. gaia/testing/assertions.py +330 -330
  170. gaia/testing/fixtures.py +333 -333
  171. gaia/testing/mocks.py +493 -493
  172. gaia/util.py +46 -46
  173. gaia/utils/__init__.py +33 -33
  174. gaia/utils/file_watcher.py +675 -675
  175. gaia/utils/parsing.py +223 -223
  176. gaia/version.py +100 -100
  177. amd_gaia-0.14.3.dist-info/RECORD +0 -168
  178. gaia/agents/code/app.py +0 -266
  179. gaia/llm/llm_client.py +0 -729
  180. {amd_gaia-0.14.3.dist-info → amd_gaia-0.15.1.dist-info}/WHEEL +0 -0
  181. {amd_gaia-0.14.3.dist-info → amd_gaia-0.15.1.dist-info}/top_level.txt +0 -0
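The headline change in 0.15.1 is the LLM client refactor: the monolithic gaia/llm/llm_client.py is removed (entry 179) and replaced by a factory (gaia/llm/factory.py, entry 142) with per-provider backends under gaia/llm/providers/ (entries 145-148). The diff below shows a typical call-site migration in gaia/audio/audio_client.py. As a rough sketch of the new entry point — only the create_client name, its import path, and the use_claude/use_openai/system_prompt keywords are confirmed by this diff; the provider-routing comments are assumptions inferred from the new file names:

# Sketch of 0.15.1 factory-based client creation. Confirmed by this diff:
# `from gaia.llm import create_client` and the three keywords below.
from gaia.llm import create_client

# 0.14.3 equivalent (removed in this release):
#   from gaia.llm.llm_client import LLMClient
#   client = LLMClient(use_claude=False, use_openai=False, system_prompt=...)

client = create_client(
    use_claude=False,  # presumably routes to gaia/llm/providers/claude.py
    use_openai=False,  # presumably routes to gaia/llm/providers/openai_provider.py
    system_prompt="You are a helpful assistant.",
)

# The streaming interface used by AudioClient is unchanged across versions:
for chunk in client.generate("Hello!", stream=True):
    print(chunk, end="", flush=True)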
gaia/audio/audio_client.py
@@ -1,439 +1,439 @@
 # Copyright(C) 2025-2026 Advanced Micro Devices, Inc. All rights reserved.
 # SPDX-License-Identifier: MIT

 import asyncio
 import queue
 import threading
 import time

-from gaia.llm.llm_client import LLMClient
+from gaia.llm import create_client
 from gaia.logger import get_logger


 class AudioClient:
     """Handles all audio-related functionality including TTS, ASR, and voice chat."""

     def __init__(
         self,
         whisper_model_size="base",
         audio_device_index=None,  # Use default input device
         silence_threshold=0.5,
         enable_tts=True,
         logging_level="INFO",
         use_claude=False,
         use_chatgpt=False,
         system_prompt=None,
     ):
         self.log = get_logger(__name__)
         self.log.setLevel(getattr(__import__("logging"), logging_level))

         # Audio configuration
         self.whisper_model_size = whisper_model_size
         self.audio_device_index = audio_device_index
         self.silence_threshold = silence_threshold
         self.enable_tts = enable_tts

         # Audio state
         self.is_speaking = False
         self.tts_thread = None
         self.whisper_asr = None
         self.transcription_queue = queue.Queue()
         self.tts = None

-        # Initialize LLM client (base_url handled automatically)
-        self.llm_client = LLMClient(
+        # Initialize LLM client - factory auto-detects provider from flags
+        self.llm_client = create_client(
             use_claude=use_claude,
-            use_openai=use_chatgpt,  # LLMClient uses use_openai, not use_chatgpt
+            use_openai=use_chatgpt,
             system_prompt=system_prompt,
         )

         self.log.info("Audio client initialized.")

     async def start_voice_chat(self, message_processor_callback):
         """Start a voice-based chat session."""
         try:
             self.log.debug("Initializing voice chat...")
             print(
                 "Starting voice chat.\n"
                 "Say 'stop' to quit application "
                 "or 'restart' to clear the chat history.\n"
                 "Press Enter key to stop during audio playback."
             )

             # Initialize TTS before starting voice chat
             self.initialize_tts()

             from gaia.audio.whisper_asr import WhisperAsr

             # Create WhisperAsr with custom thresholds
             # Your audio shows energy levels of 0.02-0.03 when speaking
             self.whisper_asr = WhisperAsr(
                 model_size=self.whisper_model_size,
                 device_index=self.audio_device_index,
                 transcription_queue=self.transcription_queue,
                 silence_threshold=0.01,  # Set higher to ensure detection (your levels are 0.01-0.2+)
                 min_audio_length=16000 * 1.0,  # 1 second minimum at 16kHz
             )

             # Log the thresholds being used (reduce verbosity)
             self.log.debug(
                 f"Audio settings: SILENCE_THRESHOLD={self.whisper_asr.SILENCE_THRESHOLD}, "
                 f"MIN_LENGTH={self.whisper_asr.MIN_AUDIO_LENGTH/self.whisper_asr.RATE:.1f}s"
             )

             device_name = self.whisper_asr.get_device_name()
             self.log.debug(f"Using audio device: {device_name}")

             # Start recording
             self.log.debug("Starting audio recording...")
             self.whisper_asr.start_recording()

             # Start the processing thread after recording is initialized
             self.log.debug("Starting audio processing thread...")
             process_thread = threading.Thread(
                 target=self._process_audio_wrapper, args=(message_processor_callback,)
             )
             process_thread.daemon = True
             process_thread.start()

             # Keep the main thread alive while processing
             self.log.debug("Listening for voice input...")
             try:
                 while True:
                     if not process_thread.is_alive():
                         self.log.debug("Process thread stopped unexpectedly")
                         break
                     if not self.whisper_asr or not self.whisper_asr.is_recording:
                         self.log.warning("Recording stopped unexpectedly")
                         break
                     await asyncio.sleep(0.1)

             except KeyboardInterrupt:
                 self.log.info("Received keyboard interrupt")
                 print("\nStopping voice chat...")
             except Exception as e:
                 self.log.error(f"Error in main processing loop: {str(e)}")
                 raise
             finally:
                 if self.whisper_asr:
                     self.log.debug("Stopping recording...")
                     self.whisper_asr.stop_recording()
                 self.log.debug("Waiting for process thread to finish...")
                 process_thread.join(timeout=2.0)

         except ImportError:
             self.log.error(
                 'WhisperAsr not found. Please install voice support with: uv pip install ".[talk]"'
             )
             raise
         except Exception as e:
             self.log.error(f"Failed to initialize voice chat: {str(e)}")
             raise
         finally:
             if self.whisper_asr:
                 self.whisper_asr.stop_recording()
                 self.log.info("Voice recording stopped")

     async def process_voice_input(self, text, get_stats_callback=None):
         """Process transcribed voice input and get AI response"""

         # Initialize TTS streaming
         text_queue = None
         tts_finished = threading.Event()  # Add event to track TTS completion
         interrupt_event = threading.Event()  # Add event for keyboard interrupts

         try:
             # Check if we're currently generating and halt if needed
             if self.llm_client.is_generating():
                 self.log.debug("Generation in progress, halting...")
                 if self.llm_client.halt_generation():
                     print("\nGeneration interrupted.")
                 await asyncio.sleep(0.5)

             # Pause audio recording before sending query
             if self.whisper_asr:
                 self.whisper_asr.pause_recording()
                 self.log.debug("Recording paused before generation")

             self.log.debug(f"Sending message to LLM: {text[:50]}...")
             print("\nGaia: ", end="", flush=True)

             # Keyboard listener thread for both generation and playback
             def keyboard_listener():
                 input()  # Wait for any input

                 # Use LLMClient to halt generation
                 if self.llm_client.halt_generation():
                     print("\nGeneration interrupted.")
                 else:
                     print("\nInterrupt requested.")

                 interrupt_event.set()
                 if text_queue:
                     text_queue.put("__HALT__")  # Signal TTS to stop immediately

             # Start keyboard listener thread
             keyboard_thread = threading.Thread(target=keyboard_listener)
             keyboard_thread.daemon = True
             keyboard_thread.start()

             if self.enable_tts:
                 text_queue = queue.Queue(maxsize=100)

                 # Define status callback to update speaking state
                 def tts_status_callback(is_speaking):
                     self.is_speaking = is_speaking
                     if not is_speaking:  # When TTS finishes speaking
                         tts_finished.set()
                         if self.whisper_asr:
                             self.whisper_asr.resume_recording()
                     else:  # When TTS starts speaking
                         if self.whisper_asr:
                             self.whisper_asr.pause_recording()
                     self.log.debug(f"TTS speaking state: {is_speaking}")

                 self.tts_thread = threading.Thread(
                     target=self.tts.generate_speech_streaming,
                     args=(text_queue,),
                     kwargs={
                         "status_callback": tts_status_callback,
                         "interrupt_event": interrupt_event,
                     },
                     daemon=True,
                 )
                 self.tts_thread.start()

             # Use LLMClient streaming instead of WebSocket
             accumulated_response = ""
             initial_buffer = ""  # Buffer for the start of response
             initial_buffer_sent = False

             try:
                 # Start LLM generation with streaming
                 response_stream = self.llm_client.generate(text, stream=True)

                 # Process streaming response
                 for chunk in response_stream:
                     if interrupt_event.is_set():
                         self.log.debug("Keyboard interrupt detected, stopping...")
                         if text_queue:
                             text_queue.put("__END__")
                         break

                     if self.transcription_queue.qsize() > 0:
                         self.log.debug(
                             "New input detected during generation, stopping..."
                         )
                         if text_queue:
                             text_queue.put("__END__")
                         # Use LLMClient to halt generation
                         if self.llm_client.halt_generation():
                             self.log.debug("Generation interrupted for new input.")
                         return

                     if chunk:
                         print(chunk, end="", flush=True)
                         if text_queue:
                             if not initial_buffer_sent:
                                 initial_buffer += chunk
                                 # Send if we've reached 20 chars or if we get a clear end marker
                                 if len(initial_buffer) >= 20 or chunk.endswith(
                                     ("\n", ". ", "! ", "? ")
                                 ):
                                     text_queue.put(initial_buffer)
                                     initial_buffer_sent = True
                             else:
                                 text_queue.put(chunk)
                         accumulated_response += chunk

                 # Send any remaining buffered content
                 if text_queue:
                     if not initial_buffer_sent and initial_buffer:
                         # Small delay for very short responses
                         if len(initial_buffer) <= 20:
                             await asyncio.sleep(0.1)
                         text_queue.put(initial_buffer)
                     text_queue.put("__END__")

             except Exception as e:
                 if text_queue:
                     text_queue.put("__END__")
                 raise e
             finally:
                 if self.tts_thread and self.tts_thread.is_alive():
                     self.tts_thread.join(timeout=1.0)  # Add timeout to thread join
                 keyboard_thread.join(timeout=1.0)  # Add timeout to keyboard thread join

             print("\n")
             # Get performance stats from LLMClient
             if get_stats_callback:
                 # First try the provided callback for backward compatibility
                 stats = get_stats_callback()
             else:
                 # Use LLMClient stats
                 stats = self.llm_client.get_performance_stats()

             if stats:
                 from pprint import pprint

                 formatted_stats = {
                     k: round(v, 1) if isinstance(v, float) else v
                     for k, v in stats.items()
                 }
                 pprint(formatted_stats)

         except Exception as e:
             if text_queue:
                 text_queue.put("__END__")
             raise e
         finally:
             if self.tts_thread and self.tts_thread.is_alive():
                 # Wait for TTS to finish before resuming recording
                 tts_finished.wait(timeout=2.0)  # Add reasonable timeout
                 self.tts_thread.join(timeout=1.0)

             # Only resume recording after TTS is completely finished
             if self.whisper_asr:
                 self.whisper_asr.resume_recording()

     def initialize_tts(self):
         """Initialize TTS if enabled."""
         if self.enable_tts:
             try:
                 from gaia.audio.kokoro_tts import KokoroTTS

                 self.tts = KokoroTTS()
                 self.log.debug("TTS initialized successfully")
             except Exception as e:
                 raise RuntimeError(
                     f'Failed to initialize TTS:\n{e}\nInstall talk dependencies with: uv pip install ".[talk]"\nYou can also use --no-tts option to disable TTS'
                 )

     async def speak_text(self, text: str) -> None:
         """Speak text using initialized TTS, if available."""
         if not self.enable_tts:
             return
         if not getattr(self, "tts", None):
             self.log.debug("TTS is not initialized; skipping speak_text")
             return
         # Reuse the streaming path used in process_voice_input
         text_queue = queue.Queue(maxsize=100)
         interrupt_event = threading.Event()
         tts_thread = threading.Thread(
             target=self.tts.generate_speech_streaming,
             args=(text_queue,),
             kwargs={"interrupt_event": interrupt_event},
             daemon=True,
         )
         tts_thread.start()
         # Send the whole text and end
         text_queue.put(text)
         text_queue.put("__END__")
         tts_thread.join(timeout=5.0)

     def _process_audio_wrapper(self, message_processor_callback):
         """Wrapper method to process audio and handle transcriptions"""
         try:
             accumulated_text = []
             current_display = ""
             last_transcription_time = time.time()
             spinner_chars = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"]
             dots_animation = ["   ", ".  ", ".. ", "..."]
             spinner_idx = 0
             dots_idx = 0
             animation_counter = 0
             self.is_speaking = False  # Initialize speaking state

             while self.whisper_asr and self.whisper_asr.is_recording:
                 try:
                     text = self.transcription_queue.get(timeout=0.1)

                     current_time = time.time()
                     time_since_last = current_time - last_transcription_time
                     cleaned_text = text.lower().strip().rstrip(".!?")

                     # Handle special commands
                     if cleaned_text in ["stop"]:
                         print("\nStopping voice chat...")
                         self.whisper_asr.stop_recording()
                         break

                     # Update animations
                     spinner_idx = (spinner_idx + 1) % len(spinner_chars)
                     animation_counter += 1
                     if animation_counter % 4 == 0:  # Update dots every fourth cycle
                         dots_idx = (dots_idx + 1) % len(dots_animation)
                     spinner = spinner_chars[spinner_idx]
                     dots = dots_animation[dots_idx]

                     # Normal text processing - only if it's not a system message
                     if text != current_display:
                         # Clear the current line and display updated text with spinner
                         print(f"\r\033[K{spinner} {text}", end="", flush=True)
                         current_display = text

                         # Only add new text if it's significantly different
                         if not any(text in existing for existing in accumulated_text):
                             accumulated_text = [text]  # Replace instead of append
                         last_transcription_time = current_time

                     # Process accumulated text after silence threshold
                     if time_since_last > self.silence_threshold:
                         if accumulated_text:
                             complete_text = accumulated_text[
                                 -1
                             ]  # Use only the last transcription
                             print()  # Add a newline before agent response
                             asyncio.run(message_processor_callback(complete_text))
                             accumulated_text = []
                             current_display = ""

                 except queue.Empty:
                     # Update animations
                     spinner_idx = (spinner_idx + 1) % len(spinner_chars)
                     animation_counter += 1
                     if animation_counter % 4 == 0:
                         dots_idx = (dots_idx + 1) % len(dots_animation)
                     spinner = spinner_chars[spinner_idx]
                     dots = dots_animation[dots_idx]

                     if current_display:
                         print(
                             f"\r\033[K{spinner} {current_display}", end="", flush=True
                         )
                     else:
                         # Access the class-level speaking state
                         status = (
                             "Speaking"
                             if getattr(self, "is_speaking", False)
                             else "Listening"
                         )
                         print(f"\r\033[K{spinner} {status}{dots}", end="", flush=True)

                     if (
                         accumulated_text
                         and (time.time() - last_transcription_time)
                         > self.silence_threshold
                     ):
                         complete_text = accumulated_text[-1]
                         print()  # Add a newline before agent response
                         asyncio.run(message_processor_callback(complete_text))
                         accumulated_text = []
                         current_display = ""

         except Exception as e:
             self.log.error(f"Error in process_audio_wrapper: {str(e)}")
         finally:
             if self.whisper_asr:
                 self.whisper_asr.stop_recording()
             if self.tts_thread and self.tts_thread.is_alive():
                 self.tts_thread.join(timeout=1.0)  # Add timeout to thread join

     async def halt_generation(self):
         """Send a request to halt the current generation."""
         if self.llm_client.halt_generation():
             self.log.debug("Successfully halted generation via LLMClient")
             print("\nGeneration interrupted.")
         else:
             self.log.debug("Halt requested - generation will stop on next iteration")
             print("\nInterrupt requested.")