amd-gaia 0.15.0-py3-none-any.whl → 0.15.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
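For verification, a per-file listing like the one below can be reproduced from the two wheels directly. A minimal sketch, assuming both wheel files have been downloaded locally (the paths are placeholders) and using only the Python standard library:

import difflib
import zipfile

OLD = "amd_gaia-0.15.0-py3-none-any.whl"  # hypothetical local paths
NEW = "amd_gaia-0.15.2-py3-none-any.whl"

with zipfile.ZipFile(OLD) as old_whl, zipfile.ZipFile(NEW) as new_whl:
    old_names = set(old_whl.namelist())
    new_names = set(new_whl.namelist())

    for name in sorted(old_names | new_names):
        if name not in old_names:
            print(f"added:   {name}")
        elif name not in new_names:
            print(f"removed: {name}")
        else:
            # Decode with replacement so binary members don't abort the scan
            old_lines = old_whl.read(name).decode("utf-8", errors="replace").splitlines()
            new_lines = new_whl.read(name).decode("utf-8", errors="replace").splitlines()
            diff = list(difflib.unified_diff(old_lines, new_lines, lineterm=""))
            # Skip the "---"/"+++" header pair before counting +/- lines
            added = sum(1 for line in diff[2:] if line.startswith("+"))
            removed = sum(1 for line in diff[2:] if line.startswith("-"))
            if added or removed:
                print(f"changed: {name} +{added} -{removed}")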
Files changed (185)
  1. {amd_gaia-0.15.0.dist-info → amd_gaia-0.15.2.dist-info}/METADATA +222 -223
  2. amd_gaia-0.15.2.dist-info/RECORD +182 -0
  3. {amd_gaia-0.15.0.dist-info → amd_gaia-0.15.2.dist-info}/WHEEL +1 -1
  4. {amd_gaia-0.15.0.dist-info → amd_gaia-0.15.2.dist-info}/entry_points.txt +1 -0
  5. {amd_gaia-0.15.0.dist-info → amd_gaia-0.15.2.dist-info}/licenses/LICENSE.md +20 -20
  6. gaia/__init__.py +29 -29
  7. gaia/agents/__init__.py +19 -19
  8. gaia/agents/base/__init__.py +9 -9
  9. gaia/agents/base/agent.py +2132 -2177
  10. gaia/agents/base/api_agent.py +119 -120
  11. gaia/agents/base/console.py +1967 -1841
  12. gaia/agents/base/errors.py +237 -237
  13. gaia/agents/base/mcp_agent.py +86 -86
  14. gaia/agents/base/tools.py +88 -83
  15. gaia/agents/blender/__init__.py +7 -0
  16. gaia/agents/blender/agent.py +553 -556
  17. gaia/agents/blender/agent_simple.py +133 -135
  18. gaia/agents/blender/app.py +211 -211
  19. gaia/agents/blender/app_simple.py +41 -41
  20. gaia/agents/blender/core/__init__.py +16 -16
  21. gaia/agents/blender/core/materials.py +506 -506
  22. gaia/agents/blender/core/objects.py +316 -316
  23. gaia/agents/blender/core/rendering.py +225 -225
  24. gaia/agents/blender/core/scene.py +220 -220
  25. gaia/agents/blender/core/view.py +146 -146
  26. gaia/agents/chat/__init__.py +9 -9
  27. gaia/agents/chat/agent.py +809 -835
  28. gaia/agents/chat/app.py +1065 -1058
  29. gaia/agents/chat/session.py +508 -508
  30. gaia/agents/chat/tools/__init__.py +15 -15
  31. gaia/agents/chat/tools/file_tools.py +96 -96
  32. gaia/agents/chat/tools/rag_tools.py +1744 -1729
  33. gaia/agents/chat/tools/shell_tools.py +437 -436
  34. gaia/agents/code/__init__.py +7 -7
  35. gaia/agents/code/agent.py +549 -549
  36. gaia/agents/code/cli.py +377 -0
  37. gaia/agents/code/models.py +135 -135
  38. gaia/agents/code/orchestration/__init__.py +24 -24
  39. gaia/agents/code/orchestration/checklist_executor.py +1763 -1763
  40. gaia/agents/code/orchestration/checklist_generator.py +713 -713
  41. gaia/agents/code/orchestration/factories/__init__.py +9 -9
  42. gaia/agents/code/orchestration/factories/base.py +63 -63
  43. gaia/agents/code/orchestration/factories/nextjs_factory.py +118 -118
  44. gaia/agents/code/orchestration/factories/python_factory.py +106 -106
  45. gaia/agents/code/orchestration/orchestrator.py +841 -841
  46. gaia/agents/code/orchestration/project_analyzer.py +391 -391
  47. gaia/agents/code/orchestration/steps/__init__.py +67 -67
  48. gaia/agents/code/orchestration/steps/base.py +188 -188
  49. gaia/agents/code/orchestration/steps/error_handler.py +314 -314
  50. gaia/agents/code/orchestration/steps/nextjs.py +828 -828
  51. gaia/agents/code/orchestration/steps/python.py +307 -307
  52. gaia/agents/code/orchestration/template_catalog.py +469 -469
  53. gaia/agents/code/orchestration/workflows/__init__.py +14 -14
  54. gaia/agents/code/orchestration/workflows/base.py +80 -80
  55. gaia/agents/code/orchestration/workflows/nextjs.py +186 -186
  56. gaia/agents/code/orchestration/workflows/python.py +94 -94
  57. gaia/agents/code/prompts/__init__.py +11 -11
  58. gaia/agents/code/prompts/base_prompt.py +77 -77
  59. gaia/agents/code/prompts/code_patterns.py +2034 -2036
  60. gaia/agents/code/prompts/nextjs_prompt.py +40 -40
  61. gaia/agents/code/prompts/python_prompt.py +109 -109
  62. gaia/agents/code/schema_inference.py +365 -365
  63. gaia/agents/code/system_prompt.py +41 -41
  64. gaia/agents/code/tools/__init__.py +42 -42
  65. gaia/agents/code/tools/cli_tools.py +1138 -1138
  66. gaia/agents/code/tools/code_formatting.py +319 -319
  67. gaia/agents/code/tools/code_tools.py +769 -769
  68. gaia/agents/code/tools/error_fixing.py +1347 -1347
  69. gaia/agents/code/tools/external_tools.py +180 -180
  70. gaia/agents/code/tools/file_io.py +845 -845
  71. gaia/agents/code/tools/prisma_tools.py +190 -190
  72. gaia/agents/code/tools/project_management.py +1016 -1016
  73. gaia/agents/code/tools/testing.py +321 -321
  74. gaia/agents/code/tools/typescript_tools.py +122 -122
  75. gaia/agents/code/tools/validation_parsing.py +461 -461
  76. gaia/agents/code/tools/validation_tools.py +806 -806
  77. gaia/agents/code/tools/web_dev_tools.py +1758 -1758
  78. gaia/agents/code/validators/__init__.py +16 -16
  79. gaia/agents/code/validators/antipattern_checker.py +241 -241
  80. gaia/agents/code/validators/ast_analyzer.py +197 -197
  81. gaia/agents/code/validators/requirements_validator.py +145 -145
  82. gaia/agents/code/validators/syntax_validator.py +171 -171
  83. gaia/agents/docker/__init__.py +7 -7
  84. gaia/agents/docker/agent.py +643 -642
  85. gaia/agents/emr/__init__.py +8 -8
  86. gaia/agents/emr/agent.py +1504 -1506
  87. gaia/agents/emr/cli.py +1322 -1322
  88. gaia/agents/emr/constants.py +475 -475
  89. gaia/agents/emr/dashboard/__init__.py +4 -4
  90. gaia/agents/emr/dashboard/server.py +1972 -1974
  91. gaia/agents/jira/__init__.py +11 -11
  92. gaia/agents/jira/agent.py +894 -894
  93. gaia/agents/jira/jql_templates.py +299 -299
  94. gaia/agents/routing/__init__.py +7 -7
  95. gaia/agents/routing/agent.py +567 -570
  96. gaia/agents/routing/system_prompt.py +75 -75
  97. gaia/agents/summarize/__init__.py +11 -0
  98. gaia/agents/summarize/agent.py +885 -0
  99. gaia/agents/summarize/prompts.py +129 -0
  100. gaia/api/__init__.py +23 -23
  101. gaia/api/agent_registry.py +238 -238
  102. gaia/api/app.py +305 -305
  103. gaia/api/openai_server.py +575 -575
  104. gaia/api/schemas.py +186 -186
  105. gaia/api/sse_handler.py +373 -373
  106. gaia/apps/__init__.py +4 -4
  107. gaia/apps/llm/__init__.py +6 -6
  108. gaia/apps/llm/app.py +184 -169
  109. gaia/apps/summarize/app.py +116 -633
  110. gaia/apps/summarize/html_viewer.py +133 -133
  111. gaia/apps/summarize/pdf_formatter.py +284 -284
  112. gaia/audio/__init__.py +2 -2
  113. gaia/audio/audio_client.py +439 -439
  114. gaia/audio/audio_recorder.py +269 -269
  115. gaia/audio/kokoro_tts.py +599 -599
  116. gaia/audio/whisper_asr.py +432 -432
  117. gaia/chat/__init__.py +16 -16
  118. gaia/chat/app.py +428 -430
  119. gaia/chat/prompts.py +522 -522
  120. gaia/chat/sdk.py +1228 -1225
  121. gaia/cli.py +5659 -5632
  122. gaia/database/__init__.py +10 -10
  123. gaia/database/agent.py +176 -176
  124. gaia/database/mixin.py +290 -290
  125. gaia/database/testing.py +64 -64
  126. gaia/eval/batch_experiment.py +2332 -2332
  127. gaia/eval/claude.py +542 -542
  128. gaia/eval/config.py +37 -37
  129. gaia/eval/email_generator.py +512 -512
  130. gaia/eval/eval.py +3179 -3179
  131. gaia/eval/groundtruth.py +1130 -1130
  132. gaia/eval/transcript_generator.py +582 -582
  133. gaia/eval/webapp/README.md +167 -167
  134. gaia/eval/webapp/package-lock.json +875 -875
  135. gaia/eval/webapp/package.json +20 -20
  136. gaia/eval/webapp/public/app.js +3402 -3402
  137. gaia/eval/webapp/public/index.html +87 -87
  138. gaia/eval/webapp/public/styles.css +3661 -3661
  139. gaia/eval/webapp/server.js +415 -415
  140. gaia/eval/webapp/test-setup.js +72 -72
  141. gaia/installer/__init__.py +23 -0
  142. gaia/installer/init_command.py +1275 -0
  143. gaia/installer/lemonade_installer.py +619 -0
  144. gaia/llm/__init__.py +10 -2
  145. gaia/llm/base_client.py +60 -0
  146. gaia/llm/exceptions.py +12 -0
  147. gaia/llm/factory.py +70 -0
  148. gaia/llm/lemonade_client.py +3421 -3221
  149. gaia/llm/lemonade_manager.py +294 -294
  150. gaia/llm/providers/__init__.py +9 -0
  151. gaia/llm/providers/claude.py +108 -0
  152. gaia/llm/providers/lemonade.py +118 -0
  153. gaia/llm/providers/openai_provider.py +79 -0
  154. gaia/llm/vlm_client.py +382 -382
  155. gaia/logger.py +189 -189
  156. gaia/mcp/agent_mcp_server.py +245 -245
  157. gaia/mcp/blender_mcp_client.py +138 -138
  158. gaia/mcp/blender_mcp_server.py +648 -648
  159. gaia/mcp/context7_cache.py +332 -332
  160. gaia/mcp/external_services.py +518 -518
  161. gaia/mcp/mcp_bridge.py +811 -550
  162. gaia/mcp/servers/__init__.py +6 -6
  163. gaia/mcp/servers/docker_mcp.py +83 -83
  164. gaia/perf_analysis.py +361 -0
  165. gaia/rag/__init__.py +10 -10
  166. gaia/rag/app.py +293 -293
  167. gaia/rag/demo.py +304 -304
  168. gaia/rag/pdf_utils.py +235 -235
  169. gaia/rag/sdk.py +2194 -2194
  170. gaia/security.py +183 -163
  171. gaia/talk/app.py +287 -289
  172. gaia/talk/sdk.py +538 -538
  173. gaia/testing/__init__.py +87 -87
  174. gaia/testing/assertions.py +330 -330
  175. gaia/testing/fixtures.py +333 -333
  176. gaia/testing/mocks.py +493 -493
  177. gaia/util.py +46 -46
  178. gaia/utils/__init__.py +33 -33
  179. gaia/utils/file_watcher.py +675 -675
  180. gaia/utils/parsing.py +223 -223
  181. gaia/version.py +100 -100
  182. amd_gaia-0.15.0.dist-info/RECORD +0 -168
  183. gaia/agents/code/app.py +0 -266
  184. gaia/llm/llm_client.py +0 -723
  185. {amd_gaia-0.15.0.dist-info → amd_gaia-0.15.2.dist-info}/top_level.txt +0 -0
gaia/audio/whisper_asr.py CHANGED
@@ -1,432 +1,432 @@
# Copyright(C) 2025-2026 Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: MIT

# Standard library imports
import os
import queue
import threading
import time

# Third-party imports
import numpy as np

try:
    import pyaudio
except ImportError:
    pyaudio = None

try:
    import torch
except ImportError:
    torch = None

try:
    import whisper
except ImportError:
    whisper = None

from gaia.audio.audio_recorder import AudioRecorder

# First-party imports
from gaia.logger import get_logger


class WhisperAsr(AudioRecorder):
    log = get_logger(__name__)

    def __init__(
        self,
        model_size="small",
        device_index=None,  # Use default input device
        transcription_queue=None,
        enable_cuda=False,
        silence_threshold=None,  # Custom silence threshold
        min_audio_length=None,  # Custom minimum audio length
    ):
        # Check for required dependencies
        missing = []
        if pyaudio is None:
            missing.append("pyaudio")
        if torch is None:
            missing.append("torch")
        if whisper is None:
            missing.append("openai-whisper")

        if missing:
            error_msg = (
                f"\n❌ Error: Missing required talk dependencies: {', '.join(missing)}\n\n"
                f"Please install the talk dependencies:\n"
                f' uv pip install -e ".[talk]"\n\n'
                f"Or install packages directly:\n"
                f" uv pip install {' '.join(missing)}\n"
            )
            raise ImportError(error_msg)

        super().__init__(device_index)

        # Override thresholds if provided
        if silence_threshold is not None:
            self.SILENCE_THRESHOLD = silence_threshold
        if min_audio_length is not None:
            self.MIN_AUDIO_LENGTH = min_audio_length
        self.log = self.__class__.log

        # Initialize Whisper model with optimized settings
        self.log.debug(f"Loading Whisper model: {model_size}")
        self.model = whisper.load_model(model_size)

        # Add compute type optimization if GPU available
        self.using_cuda = enable_cuda and torch.cuda.is_available()
        if self.using_cuda:
            self.model.to(torch.device("cuda"))
            torch.set_float32_matmul_precision("high")
            # Enable torch compile for better performance
            if hasattr(torch, "compile"):
                self.model = torch.compile(self.model)
            self.log.debug("GPU acceleration enabled with optimizations")

        # Add batch processing capability
        self.batch_size = 3  # Process multiple audio segments at once
        self.audio_buffer = []
        self.last_process_time = time.time()
        self.process_interval = 0.5  # Process every 0.5 seconds

        # Rest of initialization
        self.transcription_queue = transcription_queue

    def _record_audio_streaming(self):
        """Record audio for streaming mode - puts chunks directly into queue."""
        pa = pyaudio.PyAudio()

        try:
            # Log device info
            if self.device_index is not None:
                device_info = pa.get_device_info_by_index(self.device_index)
            else:
                device_info = pa.get_default_input_device_info()
                self.device_index = device_info["index"]

            self.log.debug(
                f"Using audio device [{self.device_index}]: {device_info['name']}"
            )

            self.stream = pa.open(
                format=self.FORMAT,
                channels=self.CHANNELS,
                rate=self.RATE,
                input=True,
                input_device_index=self.device_index,
                frames_per_buffer=self.CHUNK,
            )

            self.log.debug("Streaming recording started...")
            audio_buffer = np.array([], dtype=np.float32)
            chunks_processed = 0

            # Use 3-second chunks for better context (Whisper works better with longer segments)
            chunk_duration = 3.0  # seconds
            overlap_duration = 0.5  # seconds of overlap to avoid cutting words

            chunk_size = int(self.RATE * chunk_duration)
            overlap_size = int(self.RATE * overlap_duration)

            # Simple VAD - only send chunks with sufficient audio energy
            min_energy_threshold = 0.001  # Minimum energy to consider as speech

            while self.is_recording:
                try:
                    data = np.frombuffer(
                        self.stream.read(self.CHUNK, exception_on_overflow=False),
                        dtype=np.float32,
                    )
                    audio_buffer = np.concatenate((audio_buffer, data))

                    # Process when we have enough audio (3 seconds)
                    if len(audio_buffer) >= chunk_size:
                        chunk = audio_buffer[:chunk_size].copy()

                        # Only process if chunk has sufficient audio energy (not silence)
                        energy = np.abs(chunk).mean()
                        chunks_processed += 1

                        if energy > min_energy_threshold:
                            self.audio_queue.put(chunk)
                            self.log.debug(
                                f"Chunk {chunks_processed}: Added to queue (energy: {energy:.6f})"
                            )
                        else:
                            self.log.debug(
                                f"Chunk {chunks_processed}: Skipped - too quiet (energy: {energy:.6f})"
                            )

                        # Keep overlap to maintain context between chunks
                        audio_buffer = audio_buffer[chunk_size - overlap_size :]

                except Exception as e:
                    self.log.error(f"Error reading from stream: {e}")
                    break

            # Process any remaining audio
            if len(audio_buffer) > self.RATE * 0.5:  # At least 0.5 seconds
                self.audio_queue.put(audio_buffer.copy())

        finally:
            if self.stream:
                self.stream.stop_stream()
                self.stream.close()
            pa.terminate()

    def start_recording_streaming(self):
        """Start recording in streaming mode."""
        self.is_recording = True
        self.record_thread = threading.Thread(target=self._record_audio_streaming)
        self.record_thread.start()
        time.sleep(0.1)
        self.process_thread = threading.Thread(target=self._process_audio)
        self.process_thread.start()
        time.sleep(0.1)

    def _process_audio(self):
        """Internal method to process audio with batching and optimizations."""
        self.log.debug("Starting optimized audio processing...")
        processed_count = 0

        while self.is_recording:
            try:
                current_time = time.time()

                # Collect audio segments into buffer
                while len(self.audio_buffer) < self.batch_size:
                    try:
                        audio = self.audio_queue.get_nowait()
                        if len(audio) > 0:
                            self.audio_buffer.append(audio)
                            self.log.debug(
                                f"Added audio to buffer (size: {len(self.audio_buffer)}/{self.batch_size})"
                            )
                    except queue.Empty:
                        break

                # Process batch if enough time has passed or buffer is full
                if len(self.audio_buffer) >= self.batch_size or (
                    len(self.audio_buffer) > 0
                    and current_time - self.last_process_time >= self.process_interval
                ):

                    try:
                        processed_count += 1
                        self.log.debug(
                            f"Processing batch {processed_count} with {len(self.audio_buffer)} segments..."
                        )

                        with torch.inference_mode():
                            # Process batch of audio segments with better quality settings
                            results = [
                                self.model.transcribe(
                                    audio,
                                    language="en",
                                    temperature=0.0,  # Deterministic, no randomness
                                    no_speech_threshold=0.6,  # Higher threshold to filter noise
                                    condition_on_previous_text=False,  # Don't use previous text as it can cause hallucinations
                                    beam_size=5,  # Larger beam for better quality
                                    best_of=5,  # More attempts for better quality
                                    fp16=self.using_cuda,
                                    suppress_blank=True,  # Suppress blank outputs
                                    suppress_tokens=[-1],  # Suppress special tokens
                                    without_timestamps=False,  # Keep timestamps for context
                                )
                                for audio in self.audio_buffer
                            ]

                        # Send transcriptions to queue
                        for i, result in enumerate(results):
                            transcribed_text = result["text"].strip()
                            if transcribed_text and self.transcription_queue:
                                self.transcription_queue.put(transcribed_text)
                                self.log.debug(
                                    f"Transcribed segment {i+1}: {transcribed_text}"
                                )
                            else:
                                self.log.debug(f"Segment {i+1}: No text or empty")

                        self.audio_buffer = []
                        self.last_process_time = current_time

                    except Exception as e:
                        self.log.error(f"Batch transcription error: {e}")
                        self.audio_buffer = []  # Clear buffer on error

                else:
                    # Small sleep to prevent CPU spinning
                    time.sleep(0.01)

            except Exception as e:
                self.log.error(f"Error in audio processing: {e}")
                if not self.is_recording:
                    break

        self.log.debug("Audio processing stopped")

    def transcribe_file(self, file_path):
        """Transcribe an existing audio file."""
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"Audio file not found: {file_path}")

        result = self.model.transcribe(file_path)
        return result["text"]


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Whisper ASR Demo")
    parser.add_argument(
        "--mode",
        choices=["file", "mic", "both"],
        default="file",
        help="Test mode: file, mic, or both",
    )
    parser.add_argument(
        "--duration",
        type=int,
        default=5,
        help="Recording duration in seconds for mic mode",
    )
    parser.add_argument(
        "--model",
        default="base",
        help="Whisper model size (tiny, base, small, medium, large)",
    )
    parser.add_argument(
        "--cuda", action="store_true", help="Enable CUDA acceleration if available"
    )
    parser.add_argument(
        "--stream",
        action="store_true",
        help="Stream transcriptions in real-time as they arrive",
    )
    args = parser.parse_args()

    print("=== Whisper ASR Demo ===")
    print(f"Model: {args.model}, CUDA: {args.cuda}")

    # Test file transcription
    if args.mode in ["file", "both"]:
        print("\n--- File Transcription Test ---")
        asr = WhisperAsr(model_size=args.model, enable_cuda=args.cuda)
        try:
            test_file = "./data/audio/test.m4a"
            start_time = time.time()
            text = asr.transcribe_file(test_file)
            elapsed = time.time() - start_time
            print(f"Transcription: {text}")
            print(f"Time taken: {elapsed:.2f} seconds")
        except FileNotFoundError:
            print(f"No audio file found at {test_file}")

    # Test microphone transcription
    if args.mode in ["mic", "both"]:
        print("\n--- Microphone Transcription Test ---")
        print(f"Recording for {args.duration} seconds...")
        print(f"Mode: {'Streaming' if args.stream else 'Batch'}")

        # Create a queue to collect transcriptions
        transcription_queue = queue.Queue()
        asr = WhisperAsr(
            model_size=args.model,
            transcription_queue=transcription_queue,
            enable_cuda=args.cuda,
        )

        start_time = time.time()
        transcriptions = []

        if args.stream:
            # Streaming mode - show text as it arrives
            print("Starting recording threads...")
            asr.start_recording_streaming()  # Use streaming-specific method

            print("\n[STREAMING] Transcriptions as they arrive:")
            print("-" * 50)

            # Give recording a moment to start properly
            time.sleep(0.5)

            print(f"Recording status: {asr.is_recording}")
            print(f"Listening for {args.duration} seconds...")

            end_time = start_time + args.duration
            checks = 0

            try:
                while time.time() < end_time:
                    checks += 1
                    # Check for new transcriptions
                    while not transcription_queue.empty():
                        try:
                            text = transcription_queue.get_nowait()
                            if text:
                                transcriptions.append(text)
                                # Stream the text immediately with timestamp
                                time_offset = time.time() - start_time
                                print(f"[{time_offset:5.1f}s] {text}")
                        except queue.Empty:
                            break

                    # Debug: Show we're still checking
                    if checks % 20 == 0:  # Every second (20 * 0.05)
                        print(
                            f" ... still listening (audio_queue size: ~{asr.audio_queue.qsize()})"
                        )

                    # Small sleep to prevent CPU spinning
                    time.sleep(0.05)

            finally:
                # Stop recording
                asr.stop_recording()

                # Collect any remaining transcriptions
                time.sleep(0.5)  # Give a moment for final processing
                while not transcription_queue.empty():
                    try:
                        text = transcription_queue.get_nowait()
                        if text:
                            transcriptions.append(text)
                            time_offset = time.time() - start_time
                            print(f"[{time_offset:5.1f}s] {text}")
                    except queue.Empty:
                        break

                print("-" * 50)

        else:
            # Batch mode - collect all text then display
            asr.start_recording(duration=args.duration)  # Blocking

            # Collect all transcriptions after recording
            while not transcription_queue.empty():
                try:
                    text = transcription_queue.get_nowait()
                    if text:
                        transcriptions.append(text)
                except queue.Empty:
                    break

        elapsed = time.time() - start_time

        # Display results
        print("\nResults:")
        if transcriptions:
            print(f" Transcription segments: {len(transcriptions)}")
            if not args.stream:  # Show individual segments in batch mode
                for i, text in enumerate(transcriptions, 1):
                    print(f" {i}. {text}")
            print(f" Full transcript: {' '.join(transcriptions)}")
        else:
            print(" No transcriptions received (possibly no speech detected)")

        print(f" Total time: {elapsed:.2f} seconds")
        print(f" Processing efficiency: {args.duration/elapsed:.2f}x realtime")

    print("\nDemo completed!")
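The module above is self-contained enough to drive from a few lines of code. A minimal usage sketch of the streaming path, assuming the talk extras (pyaudio, torch, openai-whisper) are installed, a default input device exists, and that stop_recording() is inherited from AudioRecorder, as in the module's own __main__ demo:

import queue
import time

from gaia.audio.whisper_asr import WhisperAsr

texts = queue.Queue()
asr = WhisperAsr(model_size="base", transcription_queue=texts)

asr.start_recording_streaming()  # starts the record + transcribe threads
try:
    time.sleep(10)  # speak into the microphone for ~10 seconds
finally:
    asr.stop_recording()

# Drain whatever the background threads transcribed
while not texts.empty():
    print(texts.get_nowait())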