amd-gaia 0.14.3__py3-none-any.whl → 0.15.1__py3-none-any.whl

This diff shows the content of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registry.
Files changed (181)
  1. {amd_gaia-0.14.3.dist-info → amd_gaia-0.15.1.dist-info}/METADATA +223 -223
  2. amd_gaia-0.15.1.dist-info/RECORD +178 -0
  3. {amd_gaia-0.14.3.dist-info → amd_gaia-0.15.1.dist-info}/entry_points.txt +1 -0
  4. {amd_gaia-0.14.3.dist-info → amd_gaia-0.15.1.dist-info}/licenses/LICENSE.md +20 -20
  5. gaia/__init__.py +29 -29
  6. gaia/agents/__init__.py +19 -19
  7. gaia/agents/base/__init__.py +9 -9
  8. gaia/agents/base/agent.py +2177 -2177
  9. gaia/agents/base/api_agent.py +120 -120
  10. gaia/agents/base/console.py +1841 -1841
  11. gaia/agents/base/errors.py +237 -237
  12. gaia/agents/base/mcp_agent.py +86 -86
  13. gaia/agents/base/tools.py +83 -83
  14. gaia/agents/blender/agent.py +556 -556
  15. gaia/agents/blender/agent_simple.py +133 -135
  16. gaia/agents/blender/app.py +211 -211
  17. gaia/agents/blender/app_simple.py +41 -41
  18. gaia/agents/blender/core/__init__.py +16 -16
  19. gaia/agents/blender/core/materials.py +506 -506
  20. gaia/agents/blender/core/objects.py +316 -316
  21. gaia/agents/blender/core/rendering.py +225 -225
  22. gaia/agents/blender/core/scene.py +220 -220
  23. gaia/agents/blender/core/view.py +146 -146
  24. gaia/agents/chat/__init__.py +9 -9
  25. gaia/agents/chat/agent.py +835 -835
  26. gaia/agents/chat/app.py +1058 -1058
  27. gaia/agents/chat/session.py +508 -508
  28. gaia/agents/chat/tools/__init__.py +15 -15
  29. gaia/agents/chat/tools/file_tools.py +96 -96
  30. gaia/agents/chat/tools/rag_tools.py +1729 -1729
  31. gaia/agents/chat/tools/shell_tools.py +436 -436
  32. gaia/agents/code/__init__.py +7 -7
  33. gaia/agents/code/agent.py +549 -549
  34. gaia/agents/code/cli.py +377 -0
  35. gaia/agents/code/models.py +135 -135
  36. gaia/agents/code/orchestration/__init__.py +24 -24
  37. gaia/agents/code/orchestration/checklist_executor.py +1763 -1763
  38. gaia/agents/code/orchestration/checklist_generator.py +713 -713
  39. gaia/agents/code/orchestration/factories/__init__.py +9 -9
  40. gaia/agents/code/orchestration/factories/base.py +63 -63
  41. gaia/agents/code/orchestration/factories/nextjs_factory.py +118 -118
  42. gaia/agents/code/orchestration/factories/python_factory.py +106 -106
  43. gaia/agents/code/orchestration/orchestrator.py +841 -841
  44. gaia/agents/code/orchestration/project_analyzer.py +391 -391
  45. gaia/agents/code/orchestration/steps/__init__.py +67 -67
  46. gaia/agents/code/orchestration/steps/base.py +188 -188
  47. gaia/agents/code/orchestration/steps/error_handler.py +314 -314
  48. gaia/agents/code/orchestration/steps/nextjs.py +828 -828
  49. gaia/agents/code/orchestration/steps/python.py +307 -307
  50. gaia/agents/code/orchestration/template_catalog.py +469 -469
  51. gaia/agents/code/orchestration/workflows/__init__.py +14 -14
  52. gaia/agents/code/orchestration/workflows/base.py +80 -80
  53. gaia/agents/code/orchestration/workflows/nextjs.py +186 -186
  54. gaia/agents/code/orchestration/workflows/python.py +94 -94
  55. gaia/agents/code/prompts/__init__.py +11 -11
  56. gaia/agents/code/prompts/base_prompt.py +77 -77
  57. gaia/agents/code/prompts/code_patterns.py +2036 -2036
  58. gaia/agents/code/prompts/nextjs_prompt.py +40 -40
  59. gaia/agents/code/prompts/python_prompt.py +109 -109
  60. gaia/agents/code/schema_inference.py +365 -365
  61. gaia/agents/code/system_prompt.py +41 -41
  62. gaia/agents/code/tools/__init__.py +42 -42
  63. gaia/agents/code/tools/cli_tools.py +1138 -1138
  64. gaia/agents/code/tools/code_formatting.py +319 -319
  65. gaia/agents/code/tools/code_tools.py +769 -769
  66. gaia/agents/code/tools/error_fixing.py +1347 -1347
  67. gaia/agents/code/tools/external_tools.py +180 -180
  68. gaia/agents/code/tools/file_io.py +845 -845
  69. gaia/agents/code/tools/prisma_tools.py +190 -190
  70. gaia/agents/code/tools/project_management.py +1016 -1016
  71. gaia/agents/code/tools/testing.py +321 -321
  72. gaia/agents/code/tools/typescript_tools.py +122 -122
  73. gaia/agents/code/tools/validation_parsing.py +461 -461
  74. gaia/agents/code/tools/validation_tools.py +806 -806
  75. gaia/agents/code/tools/web_dev_tools.py +1758 -1758
  76. gaia/agents/code/validators/__init__.py +16 -16
  77. gaia/agents/code/validators/antipattern_checker.py +241 -241
  78. gaia/agents/code/validators/ast_analyzer.py +197 -197
  79. gaia/agents/code/validators/requirements_validator.py +145 -145
  80. gaia/agents/code/validators/syntax_validator.py +171 -171
  81. gaia/agents/docker/__init__.py +7 -7
  82. gaia/agents/docker/agent.py +642 -642
  83. gaia/agents/emr/__init__.py +8 -8
  84. gaia/agents/emr/agent.py +1506 -1506
  85. gaia/agents/emr/cli.py +1322 -1322
  86. gaia/agents/emr/constants.py +475 -475
  87. gaia/agents/emr/dashboard/__init__.py +4 -4
  88. gaia/agents/emr/dashboard/server.py +1974 -1974
  89. gaia/agents/jira/__init__.py +11 -11
  90. gaia/agents/jira/agent.py +894 -894
  91. gaia/agents/jira/jql_templates.py +299 -299
  92. gaia/agents/routing/__init__.py +7 -7
  93. gaia/agents/routing/agent.py +567 -570
  94. gaia/agents/routing/system_prompt.py +75 -75
  95. gaia/agents/summarize/__init__.py +11 -0
  96. gaia/agents/summarize/agent.py +885 -0
  97. gaia/agents/summarize/prompts.py +129 -0
  98. gaia/api/__init__.py +23 -23
  99. gaia/api/agent_registry.py +238 -238
  100. gaia/api/app.py +305 -305
  101. gaia/api/openai_server.py +575 -575
  102. gaia/api/schemas.py +186 -186
  103. gaia/api/sse_handler.py +373 -373
  104. gaia/apps/__init__.py +4 -4
  105. gaia/apps/llm/__init__.py +6 -6
  106. gaia/apps/llm/app.py +173 -169
  107. gaia/apps/summarize/app.py +116 -633
  108. gaia/apps/summarize/html_viewer.py +133 -133
  109. gaia/apps/summarize/pdf_formatter.py +284 -284
  110. gaia/audio/__init__.py +2 -2
  111. gaia/audio/audio_client.py +439 -439
  112. gaia/audio/audio_recorder.py +269 -269
  113. gaia/audio/kokoro_tts.py +599 -599
  114. gaia/audio/whisper_asr.py +432 -432
  115. gaia/chat/__init__.py +16 -16
  116. gaia/chat/app.py +430 -430
  117. gaia/chat/prompts.py +522 -522
  118. gaia/chat/sdk.py +1228 -1225
  119. gaia/cli.py +5481 -5621
  120. gaia/database/__init__.py +10 -10
  121. gaia/database/agent.py +176 -176
  122. gaia/database/mixin.py +290 -290
  123. gaia/database/testing.py +64 -64
  124. gaia/eval/batch_experiment.py +2332 -2332
  125. gaia/eval/claude.py +542 -542
  126. gaia/eval/config.py +37 -37
  127. gaia/eval/email_generator.py +512 -512
  128. gaia/eval/eval.py +3179 -3179
  129. gaia/eval/groundtruth.py +1130 -1130
  130. gaia/eval/transcript_generator.py +582 -582
  131. gaia/eval/webapp/README.md +167 -167
  132. gaia/eval/webapp/package-lock.json +875 -875
  133. gaia/eval/webapp/package.json +20 -20
  134. gaia/eval/webapp/public/app.js +3402 -3402
  135. gaia/eval/webapp/public/index.html +87 -87
  136. gaia/eval/webapp/public/styles.css +3661 -3661
  137. gaia/eval/webapp/server.js +415 -415
  138. gaia/eval/webapp/test-setup.js +72 -72
  139. gaia/llm/__init__.py +9 -2
  140. gaia/llm/base_client.py +60 -0
  141. gaia/llm/exceptions.py +12 -0
  142. gaia/llm/factory.py +70 -0
  143. gaia/llm/lemonade_client.py +3236 -3221
  144. gaia/llm/lemonade_manager.py +294 -294
  145. gaia/llm/providers/__init__.py +9 -0
  146. gaia/llm/providers/claude.py +108 -0
  147. gaia/llm/providers/lemonade.py +120 -0
  148. gaia/llm/providers/openai_provider.py +79 -0
  149. gaia/llm/vlm_client.py +382 -382
  150. gaia/logger.py +189 -189
  151. gaia/mcp/agent_mcp_server.py +245 -245
  152. gaia/mcp/blender_mcp_client.py +138 -138
  153. gaia/mcp/blender_mcp_server.py +648 -648
  154. gaia/mcp/context7_cache.py +332 -332
  155. gaia/mcp/external_services.py +518 -518
  156. gaia/mcp/mcp_bridge.py +811 -550
  157. gaia/mcp/servers/__init__.py +6 -6
  158. gaia/mcp/servers/docker_mcp.py +83 -83
  159. gaia/perf_analysis.py +361 -0
  160. gaia/rag/__init__.py +10 -10
  161. gaia/rag/app.py +293 -293
  162. gaia/rag/demo.py +304 -304
  163. gaia/rag/pdf_utils.py +235 -235
  164. gaia/rag/sdk.py +2194 -2194
  165. gaia/security.py +163 -163
  166. gaia/talk/app.py +289 -289
  167. gaia/talk/sdk.py +538 -538
  168. gaia/testing/__init__.py +87 -87
  169. gaia/testing/assertions.py +330 -330
  170. gaia/testing/fixtures.py +333 -333
  171. gaia/testing/mocks.py +493 -493
  172. gaia/util.py +46 -46
  173. gaia/utils/__init__.py +33 -33
  174. gaia/utils/file_watcher.py +675 -675
  175. gaia/utils/parsing.py +223 -223
  176. gaia/version.py +100 -100
  177. amd_gaia-0.14.3.dist-info/RECORD +0 -168
  178. gaia/agents/code/app.py +0 -266
  179. gaia/llm/llm_client.py +0 -729
  180. {amd_gaia-0.14.3.dist-info → amd_gaia-0.15.1.dist-info}/WHEEL +0 -0
  181. {amd_gaia-0.14.3.dist-info → amd_gaia-0.15.1.dist-info}/top_level.txt +0 -0
gaia/audio/whisper_asr.py CHANGED
@@ -1,432 +1,432 @@
# Copyright(C) 2025-2026 Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: MIT

# Standard library imports
import os
import queue
import threading
import time

# Third-party imports
import numpy as np

try:
    import pyaudio
except ImportError:
    pyaudio = None

try:
    import torch
except ImportError:
    torch = None

try:
    import whisper
except ImportError:
    whisper = None

from gaia.audio.audio_recorder import AudioRecorder

# First-party imports
from gaia.logger import get_logger


class WhisperAsr(AudioRecorder):
    log = get_logger(__name__)

    def __init__(
        self,
        model_size="small",
        device_index=None,  # Use default input device
        transcription_queue=None,
        enable_cuda=False,
        silence_threshold=None,  # Custom silence threshold
        min_audio_length=None,  # Custom minimum audio length
    ):
        # Check for required dependencies
        missing = []
        if pyaudio is None:
            missing.append("pyaudio")
        if torch is None:
            missing.append("torch")
        if whisper is None:
            missing.append("openai-whisper")

        if missing:
            error_msg = (
                f"\n❌ Error: Missing required talk dependencies: {', '.join(missing)}\n\n"
                f"Please install the talk dependencies:\n"
                f' uv pip install -e ".[talk]"\n\n'
                f"Or install packages directly:\n"
                f" uv pip install {' '.join(missing)}\n"
            )
            raise ImportError(error_msg)

        super().__init__(device_index)

        # Override thresholds if provided
        if silence_threshold is not None:
            self.SILENCE_THRESHOLD = silence_threshold
        if min_audio_length is not None:
            self.MIN_AUDIO_LENGTH = min_audio_length
        self.log = self.__class__.log

        # Initialize Whisper model with optimized settings
        self.log.debug(f"Loading Whisper model: {model_size}")
        self.model = whisper.load_model(model_size)

        # Add compute type optimization if GPU available
        self.using_cuda = enable_cuda and torch.cuda.is_available()
        if self.using_cuda:
            self.model.to(torch.device("cuda"))
            torch.set_float32_matmul_precision("high")
            # Enable torch compile for better performance
            if hasattr(torch, "compile"):
                self.model = torch.compile(self.model)
            self.log.debug("GPU acceleration enabled with optimizations")

        # Add batch processing capability
        self.batch_size = 3  # Process multiple audio segments at once
        self.audio_buffer = []
        self.last_process_time = time.time()
        self.process_interval = 0.5  # Process every 0.5 seconds

        # Rest of initialization
        self.transcription_queue = transcription_queue

    def _record_audio_streaming(self):
        """Record audio for streaming mode - puts chunks directly into queue."""
        pa = pyaudio.PyAudio()

        try:
            # Log device info
            if self.device_index is not None:
                device_info = pa.get_device_info_by_index(self.device_index)
            else:
                device_info = pa.get_default_input_device_info()
                self.device_index = device_info["index"]

            self.log.debug(
                f"Using audio device [{self.device_index}]: {device_info['name']}"
            )

            self.stream = pa.open(
                format=self.FORMAT,
                channels=self.CHANNELS,
                rate=self.RATE,
                input=True,
                input_device_index=self.device_index,
                frames_per_buffer=self.CHUNK,
            )

            self.log.debug("Streaming recording started...")
            audio_buffer = np.array([], dtype=np.float32)
            chunks_processed = 0

            # Use 3-second chunks for better context (Whisper works better with longer segments)
            chunk_duration = 3.0  # seconds
            overlap_duration = 0.5  # seconds of overlap to avoid cutting words

            chunk_size = int(self.RATE * chunk_duration)
            overlap_size = int(self.RATE * overlap_duration)

            # Simple VAD - only send chunks with sufficient audio energy
            min_energy_threshold = 0.001  # Minimum energy to consider as speech

            while self.is_recording:
                try:
                    data = np.frombuffer(
                        self.stream.read(self.CHUNK, exception_on_overflow=False),
                        dtype=np.float32,
                    )
                    audio_buffer = np.concatenate((audio_buffer, data))

                    # Process when we have enough audio (3 seconds)
                    if len(audio_buffer) >= chunk_size:
                        chunk = audio_buffer[:chunk_size].copy()

                        # Only process if chunk has sufficient audio energy (not silence)
                        energy = np.abs(chunk).mean()
                        chunks_processed += 1

                        if energy > min_energy_threshold:
                            self.audio_queue.put(chunk)
                            self.log.debug(
                                f"Chunk {chunks_processed}: Added to queue (energy: {energy:.6f})"
                            )
                        else:
                            self.log.debug(
                                f"Chunk {chunks_processed}: Skipped - too quiet (energy: {energy:.6f})"
                            )

                        # Keep overlap to maintain context between chunks
                        audio_buffer = audio_buffer[chunk_size - overlap_size :]

                except Exception as e:
                    self.log.error(f"Error reading from stream: {e}")
                    break

            # Process any remaining audio
            if len(audio_buffer) > self.RATE * 0.5:  # At least 0.5 seconds
                self.audio_queue.put(audio_buffer.copy())

        finally:
            if self.stream:
                self.stream.stop_stream()
                self.stream.close()
            pa.terminate()

    def start_recording_streaming(self):
        """Start recording in streaming mode."""
        self.is_recording = True
        self.record_thread = threading.Thread(target=self._record_audio_streaming)
        self.record_thread.start()
        time.sleep(0.1)
        self.process_thread = threading.Thread(target=self._process_audio)
        self.process_thread.start()
        time.sleep(0.1)

    def _process_audio(self):
        """Internal method to process audio with batching and optimizations."""
        self.log.debug("Starting optimized audio processing...")
        processed_count = 0

        while self.is_recording:
            try:
                current_time = time.time()

                # Collect audio segments into buffer
                while len(self.audio_buffer) < self.batch_size:
                    try:
                        audio = self.audio_queue.get_nowait()
                        if len(audio) > 0:
                            self.audio_buffer.append(audio)
                            self.log.debug(
                                f"Added audio to buffer (size: {len(self.audio_buffer)}/{self.batch_size})"
                            )
                    except queue.Empty:
                        break

                # Process batch if enough time has passed or buffer is full
                if len(self.audio_buffer) >= self.batch_size or (
                    len(self.audio_buffer) > 0
                    and current_time - self.last_process_time >= self.process_interval
                ):

                    try:
                        processed_count += 1
                        self.log.debug(
                            f"Processing batch {processed_count} with {len(self.audio_buffer)} segments..."
                        )

                        with torch.inference_mode():
                            # Process batch of audio segments with better quality settings
                            results = [
                                self.model.transcribe(
                                    audio,
                                    language="en",
                                    temperature=0.0,  # Deterministic, no randomness
                                    no_speech_threshold=0.6,  # Higher threshold to filter noise
                                    condition_on_previous_text=False,  # Don't use previous text as it can cause hallucinations
                                    beam_size=5,  # Larger beam for better quality
                                    best_of=5,  # More attempts for better quality
                                    fp16=self.using_cuda,
                                    suppress_blank=True,  # Suppress blank outputs
                                    suppress_tokens=[-1],  # Suppress special tokens
                                    without_timestamps=False,  # Keep timestamps for context
                                )
                                for audio in self.audio_buffer
                            ]

                        # Send transcriptions to queue
                        for i, result in enumerate(results):
                            transcribed_text = result["text"].strip()
                            if transcribed_text and self.transcription_queue:
                                self.transcription_queue.put(transcribed_text)
                                self.log.debug(
                                    f"Transcribed segment {i+1}: {transcribed_text}"
                                )
                            else:
                                self.log.debug(f"Segment {i+1}: No text or empty")

                        self.audio_buffer = []
                        self.last_process_time = current_time

                    except Exception as e:
                        self.log.error(f"Batch transcription error: {e}")
                        self.audio_buffer = []  # Clear buffer on error

                else:
                    # Small sleep to prevent CPU spinning
                    time.sleep(0.01)

            except Exception as e:
                self.log.error(f"Error in audio processing: {e}")
                if not self.is_recording:
                    break

        self.log.debug("Audio processing stopped")

    def transcribe_file(self, file_path):
        """Transcribe an existing audio file."""
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"Audio file not found: {file_path}")

        result = self.model.transcribe(file_path)
        return result["text"]


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Whisper ASR Demo")
    parser.add_argument(
        "--mode",
        choices=["file", "mic", "both"],
        default="file",
        help="Test mode: file, mic, or both",
    )
    parser.add_argument(
        "--duration",
        type=int,
        default=5,
        help="Recording duration in seconds for mic mode",
    )
    parser.add_argument(
        "--model",
        default="base",
        help="Whisper model size (tiny, base, small, medium, large)",
    )
    parser.add_argument(
        "--cuda", action="store_true", help="Enable CUDA acceleration if available"
    )
    parser.add_argument(
        "--stream",
        action="store_true",
        help="Stream transcriptions in real-time as they arrive",
    )
    args = parser.parse_args()

    print("=== Whisper ASR Demo ===")
    print(f"Model: {args.model}, CUDA: {args.cuda}")

    # Test file transcription
    if args.mode in ["file", "both"]:
        print("\n--- File Transcription Test ---")
        asr = WhisperAsr(model_size=args.model, enable_cuda=args.cuda)
        try:
            test_file = "./data/audio/test.m4a"
            start_time = time.time()
            text = asr.transcribe_file(test_file)
            elapsed = time.time() - start_time
            print(f"Transcription: {text}")
            print(f"Time taken: {elapsed:.2f} seconds")
        except FileNotFoundError:
            print(f"No audio file found at {test_file}")

    # Test microphone transcription
    if args.mode in ["mic", "both"]:
        print("\n--- Microphone Transcription Test ---")
        print(f"Recording for {args.duration} seconds...")
        print(f"Mode: {'Streaming' if args.stream else 'Batch'}")

        # Create a queue to collect transcriptions
        transcription_queue = queue.Queue()
        asr = WhisperAsr(
            model_size=args.model,
            transcription_queue=transcription_queue,
            enable_cuda=args.cuda,
        )

        start_time = time.time()
        transcriptions = []

        if args.stream:
            # Streaming mode - show text as it arrives
            print("Starting recording threads...")
            asr.start_recording_streaming()  # Use streaming-specific method

            print("\n[STREAMING] Transcriptions as they arrive:")
            print("-" * 50)

            # Give recording a moment to start properly
            time.sleep(0.5)

            print(f"Recording status: {asr.is_recording}")
            print(f"Listening for {args.duration} seconds...")

            end_time = start_time + args.duration
            checks = 0

            try:
                while time.time() < end_time:
                    checks += 1
                    # Check for new transcriptions
                    while not transcription_queue.empty():
                        try:
                            text = transcription_queue.get_nowait()
                            if text:
                                transcriptions.append(text)
                                # Stream the text immediately with timestamp
                                time_offset = time.time() - start_time
                                print(f"[{time_offset:5.1f}s] {text}")
                        except queue.Empty:
                            break

                    # Debug: Show we're still checking
                    if checks % 20 == 0:  # Every second (20 * 0.05)
                        print(
                            f" ... still listening (audio_queue size: ~{asr.audio_queue.qsize()})"
                        )

                    # Small sleep to prevent CPU spinning
                    time.sleep(0.05)

            finally:
                # Stop recording
                asr.stop_recording()

                # Collect any remaining transcriptions
                time.sleep(0.5)  # Give a moment for final processing
                while not transcription_queue.empty():
                    try:
                        text = transcription_queue.get_nowait()
                        if text:
                            transcriptions.append(text)
                            time_offset = time.time() - start_time
                            print(f"[{time_offset:5.1f}s] {text}")
                    except queue.Empty:
                        break

                print("-" * 50)

        else:
            # Batch mode - collect all text then display
            asr.start_recording(duration=args.duration)  # Blocking

            # Collect all transcriptions after recording
            while not transcription_queue.empty():
                try:
                    text = transcription_queue.get_nowait()
                    if text:
                        transcriptions.append(text)
                except queue.Empty:
                    break

        elapsed = time.time() - start_time

        # Display results
        print("\nResults:")
        if transcriptions:
            print(f" Transcription segments: {len(transcriptions)}")
            if not args.stream:  # Show individual segments in batch mode
                for i, text in enumerate(transcriptions, 1):
                    print(f" {i}. {text}")
            print(f" Full transcript: {' '.join(transcriptions)}")
        else:
            print(" No transcriptions received (possibly no speech detected)")

        print(f" Total time: {elapsed:.2f} seconds")
        print(f" Processing efficiency: {args.duration/elapsed:.2f}x realtime")

    print("\nDemo completed!")
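For context, a minimal usage sketch (not part of the diff) of the streaming API this file defines. It assumes the optional talk dependencies (pyaudio, torch, openai-whisper) are installed; stop_recording() is provided by the AudioRecorder base class rather than this module:

# Minimal sketch: live microphone transcription via WhisperAsr.
# Assumes the optional talk extras (pyaudio, torch, openai-whisper) are installed.
import queue
import time

from gaia.audio.whisper_asr import WhisperAsr

transcripts = queue.Queue()
asr = WhisperAsr(model_size="base", transcription_queue=transcripts)

asr.start_recording_streaming()  # spawns the record and process threads
try:
    deadline = time.time() + 10  # listen for ten seconds
    while time.time() < deadline:
        try:
            # Transcribed segments arrive on the queue as plain strings.
            print(transcripts.get(timeout=0.25))
        except queue.Empty:
            pass
finally:
    asr.stop_recording()  # inherited from AudioRecorder (not defined in this file)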