amd-gaia 0.15.0__py3-none-any.whl → 0.15.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (185) hide show
  1. {amd_gaia-0.15.0.dist-info → amd_gaia-0.15.2.dist-info}/METADATA +222 -223
  2. amd_gaia-0.15.2.dist-info/RECORD +182 -0
  3. {amd_gaia-0.15.0.dist-info → amd_gaia-0.15.2.dist-info}/WHEEL +1 -1
  4. {amd_gaia-0.15.0.dist-info → amd_gaia-0.15.2.dist-info}/entry_points.txt +1 -0
  5. {amd_gaia-0.15.0.dist-info → amd_gaia-0.15.2.dist-info}/licenses/LICENSE.md +20 -20
  6. gaia/__init__.py +29 -29
  7. gaia/agents/__init__.py +19 -19
  8. gaia/agents/base/__init__.py +9 -9
  9. gaia/agents/base/agent.py +2132 -2177
  10. gaia/agents/base/api_agent.py +119 -120
  11. gaia/agents/base/console.py +1967 -1841
  12. gaia/agents/base/errors.py +237 -237
  13. gaia/agents/base/mcp_agent.py +86 -86
  14. gaia/agents/base/tools.py +88 -83
  15. gaia/agents/blender/__init__.py +7 -0
  16. gaia/agents/blender/agent.py +553 -556
  17. gaia/agents/blender/agent_simple.py +133 -135
  18. gaia/agents/blender/app.py +211 -211
  19. gaia/agents/blender/app_simple.py +41 -41
  20. gaia/agents/blender/core/__init__.py +16 -16
  21. gaia/agents/blender/core/materials.py +506 -506
  22. gaia/agents/blender/core/objects.py +316 -316
  23. gaia/agents/blender/core/rendering.py +225 -225
  24. gaia/agents/blender/core/scene.py +220 -220
  25. gaia/agents/blender/core/view.py +146 -146
  26. gaia/agents/chat/__init__.py +9 -9
  27. gaia/agents/chat/agent.py +809 -835
  28. gaia/agents/chat/app.py +1065 -1058
  29. gaia/agents/chat/session.py +508 -508
  30. gaia/agents/chat/tools/__init__.py +15 -15
  31. gaia/agents/chat/tools/file_tools.py +96 -96
  32. gaia/agents/chat/tools/rag_tools.py +1744 -1729
  33. gaia/agents/chat/tools/shell_tools.py +437 -436
  34. gaia/agents/code/__init__.py +7 -7
  35. gaia/agents/code/agent.py +549 -549
  36. gaia/agents/code/cli.py +377 -0
  37. gaia/agents/code/models.py +135 -135
  38. gaia/agents/code/orchestration/__init__.py +24 -24
  39. gaia/agents/code/orchestration/checklist_executor.py +1763 -1763
  40. gaia/agents/code/orchestration/checklist_generator.py +713 -713
  41. gaia/agents/code/orchestration/factories/__init__.py +9 -9
  42. gaia/agents/code/orchestration/factories/base.py +63 -63
  43. gaia/agents/code/orchestration/factories/nextjs_factory.py +118 -118
  44. gaia/agents/code/orchestration/factories/python_factory.py +106 -106
  45. gaia/agents/code/orchestration/orchestrator.py +841 -841
  46. gaia/agents/code/orchestration/project_analyzer.py +391 -391
  47. gaia/agents/code/orchestration/steps/__init__.py +67 -67
  48. gaia/agents/code/orchestration/steps/base.py +188 -188
  49. gaia/agents/code/orchestration/steps/error_handler.py +314 -314
  50. gaia/agents/code/orchestration/steps/nextjs.py +828 -828
  51. gaia/agents/code/orchestration/steps/python.py +307 -307
  52. gaia/agents/code/orchestration/template_catalog.py +469 -469
  53. gaia/agents/code/orchestration/workflows/__init__.py +14 -14
  54. gaia/agents/code/orchestration/workflows/base.py +80 -80
  55. gaia/agents/code/orchestration/workflows/nextjs.py +186 -186
  56. gaia/agents/code/orchestration/workflows/python.py +94 -94
  57. gaia/agents/code/prompts/__init__.py +11 -11
  58. gaia/agents/code/prompts/base_prompt.py +77 -77
  59. gaia/agents/code/prompts/code_patterns.py +2034 -2036
  60. gaia/agents/code/prompts/nextjs_prompt.py +40 -40
  61. gaia/agents/code/prompts/python_prompt.py +109 -109
  62. gaia/agents/code/schema_inference.py +365 -365
  63. gaia/agents/code/system_prompt.py +41 -41
  64. gaia/agents/code/tools/__init__.py +42 -42
  65. gaia/agents/code/tools/cli_tools.py +1138 -1138
  66. gaia/agents/code/tools/code_formatting.py +319 -319
  67. gaia/agents/code/tools/code_tools.py +769 -769
  68. gaia/agents/code/tools/error_fixing.py +1347 -1347
  69. gaia/agents/code/tools/external_tools.py +180 -180
  70. gaia/agents/code/tools/file_io.py +845 -845
  71. gaia/agents/code/tools/prisma_tools.py +190 -190
  72. gaia/agents/code/tools/project_management.py +1016 -1016
  73. gaia/agents/code/tools/testing.py +321 -321
  74. gaia/agents/code/tools/typescript_tools.py +122 -122
  75. gaia/agents/code/tools/validation_parsing.py +461 -461
  76. gaia/agents/code/tools/validation_tools.py +806 -806
  77. gaia/agents/code/tools/web_dev_tools.py +1758 -1758
  78. gaia/agents/code/validators/__init__.py +16 -16
  79. gaia/agents/code/validators/antipattern_checker.py +241 -241
  80. gaia/agents/code/validators/ast_analyzer.py +197 -197
  81. gaia/agents/code/validators/requirements_validator.py +145 -145
  82. gaia/agents/code/validators/syntax_validator.py +171 -171
  83. gaia/agents/docker/__init__.py +7 -7
  84. gaia/agents/docker/agent.py +643 -642
  85. gaia/agents/emr/__init__.py +8 -8
  86. gaia/agents/emr/agent.py +1504 -1506
  87. gaia/agents/emr/cli.py +1322 -1322
  88. gaia/agents/emr/constants.py +475 -475
  89. gaia/agents/emr/dashboard/__init__.py +4 -4
  90. gaia/agents/emr/dashboard/server.py +1972 -1974
  91. gaia/agents/jira/__init__.py +11 -11
  92. gaia/agents/jira/agent.py +894 -894
  93. gaia/agents/jira/jql_templates.py +299 -299
  94. gaia/agents/routing/__init__.py +7 -7
  95. gaia/agents/routing/agent.py +567 -570
  96. gaia/agents/routing/system_prompt.py +75 -75
  97. gaia/agents/summarize/__init__.py +11 -0
  98. gaia/agents/summarize/agent.py +885 -0
  99. gaia/agents/summarize/prompts.py +129 -0
  100. gaia/api/__init__.py +23 -23
  101. gaia/api/agent_registry.py +238 -238
  102. gaia/api/app.py +305 -305
  103. gaia/api/openai_server.py +575 -575
  104. gaia/api/schemas.py +186 -186
  105. gaia/api/sse_handler.py +373 -373
  106. gaia/apps/__init__.py +4 -4
  107. gaia/apps/llm/__init__.py +6 -6
  108. gaia/apps/llm/app.py +184 -169
  109. gaia/apps/summarize/app.py +116 -633
  110. gaia/apps/summarize/html_viewer.py +133 -133
  111. gaia/apps/summarize/pdf_formatter.py +284 -284
  112. gaia/audio/__init__.py +2 -2
  113. gaia/audio/audio_client.py +439 -439
  114. gaia/audio/audio_recorder.py +269 -269
  115. gaia/audio/kokoro_tts.py +599 -599
  116. gaia/audio/whisper_asr.py +432 -432
  117. gaia/chat/__init__.py +16 -16
  118. gaia/chat/app.py +428 -430
  119. gaia/chat/prompts.py +522 -522
  120. gaia/chat/sdk.py +1228 -1225
  121. gaia/cli.py +5659 -5632
  122. gaia/database/__init__.py +10 -10
  123. gaia/database/agent.py +176 -176
  124. gaia/database/mixin.py +290 -290
  125. gaia/database/testing.py +64 -64
  126. gaia/eval/batch_experiment.py +2332 -2332
  127. gaia/eval/claude.py +542 -542
  128. gaia/eval/config.py +37 -37
  129. gaia/eval/email_generator.py +512 -512
  130. gaia/eval/eval.py +3179 -3179
  131. gaia/eval/groundtruth.py +1130 -1130
  132. gaia/eval/transcript_generator.py +582 -582
  133. gaia/eval/webapp/README.md +167 -167
  134. gaia/eval/webapp/package-lock.json +875 -875
  135. gaia/eval/webapp/package.json +20 -20
  136. gaia/eval/webapp/public/app.js +3402 -3402
  137. gaia/eval/webapp/public/index.html +87 -87
  138. gaia/eval/webapp/public/styles.css +3661 -3661
  139. gaia/eval/webapp/server.js +415 -415
  140. gaia/eval/webapp/test-setup.js +72 -72
  141. gaia/installer/__init__.py +23 -0
  142. gaia/installer/init_command.py +1275 -0
  143. gaia/installer/lemonade_installer.py +619 -0
  144. gaia/llm/__init__.py +10 -2
  145. gaia/llm/base_client.py +60 -0
  146. gaia/llm/exceptions.py +12 -0
  147. gaia/llm/factory.py +70 -0
  148. gaia/llm/lemonade_client.py +3421 -3221
  149. gaia/llm/lemonade_manager.py +294 -294
  150. gaia/llm/providers/__init__.py +9 -0
  151. gaia/llm/providers/claude.py +108 -0
  152. gaia/llm/providers/lemonade.py +118 -0
  153. gaia/llm/providers/openai_provider.py +79 -0
  154. gaia/llm/vlm_client.py +382 -382
  155. gaia/logger.py +189 -189
  156. gaia/mcp/agent_mcp_server.py +245 -245
  157. gaia/mcp/blender_mcp_client.py +138 -138
  158. gaia/mcp/blender_mcp_server.py +648 -648
  159. gaia/mcp/context7_cache.py +332 -332
  160. gaia/mcp/external_services.py +518 -518
  161. gaia/mcp/mcp_bridge.py +811 -550
  162. gaia/mcp/servers/__init__.py +6 -6
  163. gaia/mcp/servers/docker_mcp.py +83 -83
  164. gaia/perf_analysis.py +361 -0
  165. gaia/rag/__init__.py +10 -10
  166. gaia/rag/app.py +293 -293
  167. gaia/rag/demo.py +304 -304
  168. gaia/rag/pdf_utils.py +235 -235
  169. gaia/rag/sdk.py +2194 -2194
  170. gaia/security.py +183 -163
  171. gaia/talk/app.py +287 -289
  172. gaia/talk/sdk.py +538 -538
  173. gaia/testing/__init__.py +87 -87
  174. gaia/testing/assertions.py +330 -330
  175. gaia/testing/fixtures.py +333 -333
  176. gaia/testing/mocks.py +493 -493
  177. gaia/util.py +46 -46
  178. gaia/utils/__init__.py +33 -33
  179. gaia/utils/file_watcher.py +675 -675
  180. gaia/utils/parsing.py +223 -223
  181. gaia/version.py +100 -100
  182. amd_gaia-0.15.0.dist-info/RECORD +0 -168
  183. gaia/agents/code/app.py +0 -266
  184. gaia/llm/llm_client.py +0 -723
  185. {amd_gaia-0.15.0.dist-info → amd_gaia-0.15.2.dist-info}/top_level.txt +0 -0
gaia/testing/mocks.py CHANGED
@@ -1,493 +1,493 @@
1
- # Copyright(C) 2025-2026 Advanced Micro Devices, Inc. All rights reserved.
2
- # SPDX-License-Identifier: MIT
3
-
4
- """Mock providers for testing GAIA agents without real LLM/VLM services."""
5
-
6
- import logging
7
- import time
8
- from typing import Any, Dict, Iterator, List, Optional
9
-
10
- logger = logging.getLogger(__name__)
11
-
12
-
13
- class MockLLMProvider:
14
- """
15
- Mock LLM provider for testing agents without real API calls.
16
-
17
- Returns pre-configured responses instead of calling a real LLM.
18
- Tracks all calls for test assertions.
19
-
20
- Example:
21
- from gaia.testing import MockLLMProvider
22
-
23
- mock_llm = MockLLMProvider(responses=["First response", "Second response"])
24
-
25
- # Use in tests
26
- result = mock_llm.generate("Test prompt")
27
- assert result == "First response"
28
-
29
- result = mock_llm.generate("Another prompt")
30
- assert result == "Second response"
31
-
32
- # Check call history
33
- assert len(mock_llm.call_history) == 2
34
- assert mock_llm.call_history[0]["prompt"] == "Test prompt"
35
- """
36
-
37
- def __init__(
38
- self,
39
- responses: Optional[List[str]] = None,
40
- tool_responses: Optional[Dict[str, Any]] = None,
41
- default_response: str = "Mock LLM response",
42
- ):
43
- """
44
- Initialize mock LLM provider.
45
-
46
- Args:
47
- responses: List of responses to return in sequence.
48
- Cycles back to first if more calls than responses.
49
- tool_responses: Dict mapping tool names to their mock results.
50
- Used when simulating tool calls.
51
- default_response: Response when responses list is exhausted or empty.
52
- """
53
- self.responses = responses or []
54
- self.tool_responses = tool_responses or {}
55
- self.default_response = default_response
56
- self.call_history: List[Dict[str, Any]] = []
57
- self._response_index = 0
58
-
59
- def generate(
60
- self,
61
- prompt: str,
62
- system_prompt: Optional[str] = None,
63
- temperature: float = 0.7,
64
- max_tokens: Optional[int] = None,
65
- **kwargs,
66
- ) -> str:
67
- """
68
- Generate mock response.
69
-
70
- Args:
71
- prompt: Input prompt (recorded but not processed)
72
- system_prompt: System prompt (recorded)
73
- temperature: Temperature setting (recorded)
74
- max_tokens: Max tokens (recorded)
75
- **kwargs: Additional parameters (recorded)
76
-
77
- Returns:
78
- Next response from response list, or default_response
79
- """
80
- self.call_history.append(
81
- {
82
- "method": "generate",
83
- "prompt": prompt,
84
- "system_prompt": system_prompt,
85
- "temperature": temperature,
86
- "max_tokens": max_tokens,
87
- "kwargs": kwargs,
88
- "timestamp": time.time(),
89
- }
90
- )
91
-
92
- if self.responses:
93
- response = self.responses[self._response_index % len(self.responses)]
94
- self._response_index += 1
95
- return response
96
-
97
- return self.default_response
98
-
99
- def chat(
100
- self,
101
- messages: List[Dict[str, str]],
102
- **kwargs,
103
- ) -> str:
104
- """
105
- Mock chat completion (messages format).
106
-
107
- Args:
108
- messages: List of message dicts with 'role' and 'content'
109
- **kwargs: Additional parameters
110
-
111
- Returns:
112
- Next response from response list
113
- """
114
- # Extract the last user message as the prompt
115
- prompt = ""
116
- for msg in reversed(messages):
117
- if msg.get("role") == "user":
118
- prompt = msg.get("content", "")
119
- break
120
-
121
- self.call_history.append(
122
- {
123
- "method": "chat",
124
- "messages": messages,
125
- "prompt": prompt,
126
- "kwargs": kwargs,
127
- "timestamp": time.time(),
128
- }
129
- )
130
-
131
- if self.responses:
132
- response = self.responses[self._response_index % len(self.responses)]
133
- self._response_index += 1
134
- return response
135
-
136
- return self.default_response
137
-
138
- def stream(
139
- self,
140
- prompt: str,
141
- **kwargs,
142
- ) -> Iterator[str]:
143
- """
144
- Mock streaming response.
145
-
146
- Yields the full response as a single chunk for simplicity.
147
-
148
- Args:
149
- prompt: Input prompt
150
- **kwargs: Additional parameters
151
-
152
- Yields:
153
- Response chunks (full response as single chunk)
154
- """
155
- response = self.generate(prompt, **kwargs)
156
- # Update the last call to note it was streaming
157
- if self.call_history:
158
- self.call_history[-1]["method"] = "stream"
159
- yield response
160
-
161
- def complete(self, prompt: str, **kwargs) -> str:
162
- """Alias for generate() for compatibility."""
163
- return self.generate(prompt, **kwargs)
164
-
165
- def get_tool_response(self, tool_name: str) -> Any:
166
- """
167
- Get mock response for a tool call.
168
-
169
- Args:
170
- tool_name: Name of the tool
171
-
172
- Returns:
173
- Configured mock result or default dict
174
- """
175
- return self.tool_responses.get(tool_name, {"status": "success"})
176
-
177
- @property
178
- def was_called(self) -> bool:
179
- """Check if any method was called."""
180
- return len(self.call_history) > 0
181
-
182
- @property
183
- def call_count(self) -> int:
184
- """Number of times LLM was called."""
185
- return len(self.call_history)
186
-
187
- @property
188
- def last_prompt(self) -> Optional[str]:
189
- """Get the last prompt that was sent."""
190
- if self.call_history:
191
- return self.call_history[-1].get("prompt")
192
- return None
193
-
194
- def reset(self) -> None:
195
- """Reset call history and response index."""
196
- self.call_history = []
197
- self._response_index = 0
198
-
199
- def set_responses(self, responses: List[str]) -> None:
200
- """
201
- Set new responses and reset index.
202
-
203
- Args:
204
- responses: New list of responses
205
- """
206
- self.responses = responses
207
- self._response_index = 0
208
-
209
-
210
- class MockVLMClient:
211
- """
212
- Mock VLM client for testing image processing without real API calls.
213
-
214
- Returns pre-configured text instead of processing images.
215
- Tracks all calls for test assertions.
216
-
217
- Example:
218
- from gaia.testing import MockVLMClient
219
-
220
- mock_vlm = MockVLMClient(
221
- extracted_text='{"name": "John", "dob": "1990-01-01"}'
222
- )
223
-
224
- # Inject into agent
225
- agent = MyAgent()
226
- agent.vlm = mock_vlm
227
-
228
- # Test extraction
229
- result = agent.extract_form("test.png")
230
-
231
- # Verify VLM was called
232
- assert mock_vlm.was_called
233
- assert mock_vlm.call_count == 1
234
- """
235
-
236
- def __init__(
237
- self,
238
- extracted_text: str = "Mock extracted text",
239
- extraction_results: Optional[List[str]] = None,
240
- is_available: bool = True,
241
- ):
242
- """
243
- Initialize mock VLM client.
244
-
245
- Args:
246
- extracted_text: Default text to return from extract_from_image()
247
- extraction_results: List of results to return in sequence
248
- is_available: Whether check_availability() returns True
249
- """
250
- self.extracted_text = extracted_text
251
- self.extraction_results = extraction_results or []
252
- self.is_available = is_available
253
- self.call_history: List[Dict[str, Any]] = []
254
- self._result_index = 0
255
-
256
- def check_availability(self) -> bool:
257
- """
258
- Check if VLM is available.
259
-
260
- Returns:
261
- Configured is_available value
262
- """
263
- return self.is_available
264
-
265
- def extract_from_image(
266
- self,
267
- image_bytes: bytes,
268
- prompt: Optional[str] = None,
269
- **kwargs,
270
- ) -> str:
271
- """
272
- Mock image text extraction.
273
-
274
- Args:
275
- image_bytes: Image data (recorded but not processed)
276
- prompt: Extraction prompt (recorded)
277
- **kwargs: Additional parameters
278
-
279
- Returns:
280
- Pre-configured extracted text
281
- """
282
- self.call_history.append(
283
- {
284
- "method": "extract_from_image",
285
- "image_size": len(image_bytes) if image_bytes else 0,
286
- "prompt": prompt,
287
- "kwargs": kwargs,
288
- "timestamp": time.time(),
289
- }
290
- )
291
-
292
- if self.extraction_results:
293
- result = self.extraction_results[
294
- self._result_index % len(self.extraction_results)
295
- ]
296
- self._result_index += 1
297
- return result
298
-
299
- return self.extracted_text
300
-
301
- def extract_from_file(
302
- self,
303
- file_path: str,
304
- prompt: Optional[str] = None,
305
- **kwargs,
306
- ) -> str:
307
- """
308
- Mock file-based extraction.
309
-
310
- Args:
311
- file_path: Path to image file
312
- prompt: Extraction prompt
313
-
314
- Returns:
315
- Pre-configured extracted text
316
- """
317
- self.call_history.append(
318
- {
319
- "method": "extract_from_file",
320
- "file_path": file_path,
321
- "prompt": prompt,
322
- "kwargs": kwargs,
323
- "timestamp": time.time(),
324
- }
325
- )
326
-
327
- if self.extraction_results:
328
- result = self.extraction_results[
329
- self._result_index % len(self.extraction_results)
330
- ]
331
- self._result_index += 1
332
- return result
333
-
334
- return self.extracted_text
335
-
336
- def describe_image(
337
- self,
338
- image_bytes: bytes,
339
- prompt: Optional[str] = None,
340
- **kwargs,
341
- ) -> str:
342
- """
343
- Mock image description.
344
-
345
- Args:
346
- image_bytes: Image data
347
- prompt: Description prompt
348
-
349
- Returns:
350
- Pre-configured text
351
- """
352
- return self.extract_from_image(image_bytes, prompt, **kwargs)
353
-
354
- @property
355
- def was_called(self) -> bool:
356
- """Check if any extraction method was called."""
357
- return len(self.call_history) > 0
358
-
359
- @property
360
- def call_count(self) -> int:
361
- """Number of times extraction was called."""
362
- return len(self.call_history)
363
-
364
- @property
365
- def last_prompt(self) -> Optional[str]:
366
- """Get the last prompt that was sent."""
367
- if self.call_history:
368
- return self.call_history[-1].get("prompt")
369
- return None
370
-
371
- def reset(self) -> None:
372
- """Reset call history and result index."""
373
- self.call_history = []
374
- self._result_index = 0
375
-
376
- def set_extracted_text(self, text: str) -> None:
377
- """
378
- Set new extracted text.
379
-
380
- Args:
381
- text: New text to return
382
- """
383
- self.extracted_text = text
384
-
385
-
386
- class MockToolExecutor:
387
- """
388
- Mock tool executor for testing tool calls.
389
-
390
- Tracks tool calls and returns configurable results.
391
-
392
- Example:
393
- from gaia.testing import MockToolExecutor
394
-
395
- executor = MockToolExecutor(
396
- results={
397
- "search": {"results": ["item1", "item2"]},
398
- "create_record": {"id": 123, "status": "created"},
399
- }
400
- )
401
-
402
- result = executor.execute("search", {"query": "test"})
403
- assert result == {"results": ["item1", "item2"]}
404
-
405
- assert executor.was_tool_called("search")
406
- assert executor.get_tool_args("search") == {"query": "test"}
407
- """
408
-
409
- def __init__(
410
- self,
411
- results: Optional[Dict[str, Any]] = None,
412
- default_result: Optional[Dict[str, Any]] = None,
413
- ):
414
- """
415
- Initialize mock tool executor.
416
-
417
- Args:
418
- results: Dict mapping tool names to their results
419
- default_result: Default result for unknown tools
420
- """
421
- self.results = results or {}
422
- self.default_result = default_result or {"status": "success"}
423
- self.call_history: List[Dict[str, Any]] = []
424
-
425
- def execute(self, tool_name: str, args: Dict[str, Any]) -> Any:
426
- """
427
- Execute a mock tool.
428
-
429
- Args:
430
- tool_name: Name of the tool
431
- args: Tool arguments
432
-
433
- Returns:
434
- Configured result for the tool
435
- """
436
- self.call_history.append(
437
- {
438
- "tool": tool_name,
439
- "args": args,
440
- "timestamp": time.time(),
441
- }
442
- )
443
-
444
- return self.results.get(tool_name, self.default_result)
445
-
446
- def was_tool_called(self, tool_name: str) -> bool:
447
- """
448
- Check if a specific tool was called.
449
-
450
- Args:
451
- tool_name: Name of the tool
452
-
453
- Returns:
454
- True if tool was called at least once
455
- """
456
- return any(call["tool"] == tool_name for call in self.call_history)
457
-
458
- def get_tool_calls(self, tool_name: str) -> List[Dict[str, Any]]:
459
- """
460
- Get all calls to a specific tool.
461
-
462
- Args:
463
- tool_name: Name of the tool
464
-
465
- Returns:
466
- List of call records for that tool
467
- """
468
- return [call for call in self.call_history if call["tool"] == tool_name]
469
-
470
- def get_tool_args(self, tool_name: str, call_index: int = 0) -> Optional[Dict]:
471
- """
472
- Get arguments from a specific tool call.
473
-
474
- Args:
475
- tool_name: Name of the tool
476
- call_index: Which call to get (0 = first call)
477
-
478
- Returns:
479
- Arguments dict or None if not found
480
- """
481
- calls = self.get_tool_calls(tool_name)
482
- if call_index < len(calls):
483
- return calls[call_index]["args"]
484
- return None
485
-
486
- @property
487
- def tool_names_called(self) -> List[str]:
488
- """Get list of all tool names that were called."""
489
- return list(set(call["tool"] for call in self.call_history))
490
-
491
- def reset(self) -> None:
492
- """Reset call history."""
493
- self.call_history = []
1
+ # Copyright(C) 2025-2026 Advanced Micro Devices, Inc. All rights reserved.
2
+ # SPDX-License-Identifier: MIT
3
+
4
+ """Mock providers for testing GAIA agents without real LLM/VLM services."""
5
+
6
+ import logging
7
+ import time
8
+ from typing import Any, Dict, Iterator, List, Optional
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+
13
+ class MockLLMProvider:
14
+ """
15
+ Mock LLM provider for testing agents without real API calls.
16
+
17
+ Returns pre-configured responses instead of calling a real LLM.
18
+ Tracks all calls for test assertions.
19
+
20
+ Example:
21
+ from gaia.testing import MockLLMProvider
22
+
23
+ mock_llm = MockLLMProvider(responses=["First response", "Second response"])
24
+
25
+ # Use in tests
26
+ result = mock_llm.generate("Test prompt")
27
+ assert result == "First response"
28
+
29
+ result = mock_llm.generate("Another prompt")
30
+ assert result == "Second response"
31
+
32
+ # Check call history
33
+ assert len(mock_llm.call_history) == 2
34
+ assert mock_llm.call_history[0]["prompt"] == "Test prompt"
35
+ """
36
+
37
+ def __init__(
38
+ self,
39
+ responses: Optional[List[str]] = None,
40
+ tool_responses: Optional[Dict[str, Any]] = None,
41
+ default_response: str = "Mock LLM response",
42
+ ):
43
+ """
44
+ Initialize mock LLM provider.
45
+
46
+ Args:
47
+ responses: List of responses to return in sequence.
48
+ Cycles back to first if more calls than responses.
49
+ tool_responses: Dict mapping tool names to their mock results.
50
+ Used when simulating tool calls.
51
+ default_response: Response when responses list is exhausted or empty.
52
+ """
53
+ self.responses = responses or []
54
+ self.tool_responses = tool_responses or {}
55
+ self.default_response = default_response
56
+ self.call_history: List[Dict[str, Any]] = []
57
+ self._response_index = 0
58
+
59
+ def generate(
60
+ self,
61
+ prompt: str,
62
+ system_prompt: Optional[str] = None,
63
+ temperature: float = 0.7,
64
+ max_tokens: Optional[int] = None,
65
+ **kwargs,
66
+ ) -> str:
67
+ """
68
+ Generate mock response.
69
+
70
+ Args:
71
+ prompt: Input prompt (recorded but not processed)
72
+ system_prompt: System prompt (recorded)
73
+ temperature: Temperature setting (recorded)
74
+ max_tokens: Max tokens (recorded)
75
+ **kwargs: Additional parameters (recorded)
76
+
77
+ Returns:
78
+ Next response from response list, or default_response
79
+ """
80
+ self.call_history.append(
81
+ {
82
+ "method": "generate",
83
+ "prompt": prompt,
84
+ "system_prompt": system_prompt,
85
+ "temperature": temperature,
86
+ "max_tokens": max_tokens,
87
+ "kwargs": kwargs,
88
+ "timestamp": time.time(),
89
+ }
90
+ )
91
+
92
+ if self.responses:
93
+ response = self.responses[self._response_index % len(self.responses)]
94
+ self._response_index += 1
95
+ return response
96
+
97
+ return self.default_response
98
+
99
+ def chat(
100
+ self,
101
+ messages: List[Dict[str, str]],
102
+ **kwargs,
103
+ ) -> str:
104
+ """
105
+ Mock chat completion (messages format).
106
+
107
+ Args:
108
+ messages: List of message dicts with 'role' and 'content'
109
+ **kwargs: Additional parameters
110
+
111
+ Returns:
112
+ Next response from response list
113
+ """
114
+ # Extract the last user message as the prompt
115
+ prompt = ""
116
+ for msg in reversed(messages):
117
+ if msg.get("role") == "user":
118
+ prompt = msg.get("content", "")
119
+ break
120
+
121
+ self.call_history.append(
122
+ {
123
+ "method": "chat",
124
+ "messages": messages,
125
+ "prompt": prompt,
126
+ "kwargs": kwargs,
127
+ "timestamp": time.time(),
128
+ }
129
+ )
130
+
131
+ if self.responses:
132
+ response = self.responses[self._response_index % len(self.responses)]
133
+ self._response_index += 1
134
+ return response
135
+
136
+ return self.default_response
137
+
138
+ def stream(
139
+ self,
140
+ prompt: str,
141
+ **kwargs,
142
+ ) -> Iterator[str]:
143
+ """
144
+ Mock streaming response.
145
+
146
+ Yields the full response as a single chunk for simplicity.
147
+
148
+ Args:
149
+ prompt: Input prompt
150
+ **kwargs: Additional parameters
151
+
152
+ Yields:
153
+ Response chunks (full response as single chunk)
154
+ """
155
+ response = self.generate(prompt, **kwargs)
156
+ # Update the last call to note it was streaming
157
+ if self.call_history:
158
+ self.call_history[-1]["method"] = "stream"
159
+ yield response
160
+
161
+ def complete(self, prompt: str, **kwargs) -> str:
162
+ """Alias for generate() for compatibility."""
163
+ return self.generate(prompt, **kwargs)
164
+
165
+ def get_tool_response(self, tool_name: str) -> Any:
166
+ """
167
+ Get mock response for a tool call.
168
+
169
+ Args:
170
+ tool_name: Name of the tool
171
+
172
+ Returns:
173
+ Configured mock result or default dict
174
+ """
175
+ return self.tool_responses.get(tool_name, {"status": "success"})
176
+
177
+ @property
178
+ def was_called(self) -> bool:
179
+ """Check if any method was called."""
180
+ return len(self.call_history) > 0
181
+
182
+ @property
183
+ def call_count(self) -> int:
184
+ """Number of times LLM was called."""
185
+ return len(self.call_history)
186
+
187
+ @property
188
+ def last_prompt(self) -> Optional[str]:
189
+ """Get the last prompt that was sent."""
190
+ if self.call_history:
191
+ return self.call_history[-1].get("prompt")
192
+ return None
193
+
194
+ def reset(self) -> None:
195
+ """Reset call history and response index."""
196
+ self.call_history = []
197
+ self._response_index = 0
198
+
199
+ def set_responses(self, responses: List[str]) -> None:
200
+ """
201
+ Set new responses and reset index.
202
+
203
+ Args:
204
+ responses: New list of responses
205
+ """
206
+ self.responses = responses
207
+ self._response_index = 0
208
+
209
+
210
+ class MockVLMClient:
211
+ """
212
+ Mock VLM client for testing image processing without real API calls.
213
+
214
+ Returns pre-configured text instead of processing images.
215
+ Tracks all calls for test assertions.
216
+
217
+ Example:
218
+ from gaia.testing import MockVLMClient
219
+
220
+ mock_vlm = MockVLMClient(
221
+ extracted_text='{"name": "John", "dob": "1990-01-01"}'
222
+ )
223
+
224
+ # Inject into agent
225
+ agent = MyAgent()
226
+ agent.vlm = mock_vlm
227
+
228
+ # Test extraction
229
+ result = agent.extract_form("test.png")
230
+
231
+ # Verify VLM was called
232
+ assert mock_vlm.was_called
233
+ assert mock_vlm.call_count == 1
234
+ """
235
+
236
+ def __init__(
237
+ self,
238
+ extracted_text: str = "Mock extracted text",
239
+ extraction_results: Optional[List[str]] = None,
240
+ is_available: bool = True,
241
+ ):
242
+ """
243
+ Initialize mock VLM client.
244
+
245
+ Args:
246
+ extracted_text: Default text to return from extract_from_image()
247
+ extraction_results: List of results to return in sequence
248
+ is_available: Whether check_availability() returns True
249
+ """
250
+ self.extracted_text = extracted_text
251
+ self.extraction_results = extraction_results or []
252
+ self.is_available = is_available
253
+ self.call_history: List[Dict[str, Any]] = []
254
+ self._result_index = 0
255
+
256
+ def check_availability(self) -> bool:
257
+ """
258
+ Check if VLM is available.
259
+
260
+ Returns:
261
+ Configured is_available value
262
+ """
263
+ return self.is_available
264
+
265
+ def extract_from_image(
266
+ self,
267
+ image_bytes: bytes,
268
+ prompt: Optional[str] = None,
269
+ **kwargs,
270
+ ) -> str:
271
+ """
272
+ Mock image text extraction.
273
+
274
+ Args:
275
+ image_bytes: Image data (recorded but not processed)
276
+ prompt: Extraction prompt (recorded)
277
+ **kwargs: Additional parameters
278
+
279
+ Returns:
280
+ Pre-configured extracted text
281
+ """
282
+ self.call_history.append(
283
+ {
284
+ "method": "extract_from_image",
285
+ "image_size": len(image_bytes) if image_bytes else 0,
286
+ "prompt": prompt,
287
+ "kwargs": kwargs,
288
+ "timestamp": time.time(),
289
+ }
290
+ )
291
+
292
+ if self.extraction_results:
293
+ result = self.extraction_results[
294
+ self._result_index % len(self.extraction_results)
295
+ ]
296
+ self._result_index += 1
297
+ return result
298
+
299
+ return self.extracted_text
300
+
301
+ def extract_from_file(
302
+ self,
303
+ file_path: str,
304
+ prompt: Optional[str] = None,
305
+ **kwargs,
306
+ ) -> str:
307
+ """
308
+ Mock file-based extraction.
309
+
310
+ Args:
311
+ file_path: Path to image file
312
+ prompt: Extraction prompt
313
+
314
+ Returns:
315
+ Pre-configured extracted text
316
+ """
317
+ self.call_history.append(
318
+ {
319
+ "method": "extract_from_file",
320
+ "file_path": file_path,
321
+ "prompt": prompt,
322
+ "kwargs": kwargs,
323
+ "timestamp": time.time(),
324
+ }
325
+ )
326
+
327
+ if self.extraction_results:
328
+ result = self.extraction_results[
329
+ self._result_index % len(self.extraction_results)
330
+ ]
331
+ self._result_index += 1
332
+ return result
333
+
334
+ return self.extracted_text
335
+
336
+ def describe_image(
337
+ self,
338
+ image_bytes: bytes,
339
+ prompt: Optional[str] = None,
340
+ **kwargs,
341
+ ) -> str:
342
+ """
343
+ Mock image description.
344
+
345
+ Args:
346
+ image_bytes: Image data
347
+ prompt: Description prompt
348
+
349
+ Returns:
350
+ Pre-configured text
351
+ """
352
+ return self.extract_from_image(image_bytes, prompt, **kwargs)
353
+
354
+ @property
355
+ def was_called(self) -> bool:
356
+ """Check if any extraction method was called."""
357
+ return len(self.call_history) > 0
358
+
359
+ @property
360
+ def call_count(self) -> int:
361
+ """Number of times extraction was called."""
362
+ return len(self.call_history)
363
+
364
+ @property
365
+ def last_prompt(self) -> Optional[str]:
366
+ """Get the last prompt that was sent."""
367
+ if self.call_history:
368
+ return self.call_history[-1].get("prompt")
369
+ return None
370
+
371
+ def reset(self) -> None:
372
+ """Reset call history and result index."""
373
+ self.call_history = []
374
+ self._result_index = 0
375
+
376
+ def set_extracted_text(self, text: str) -> None:
377
+ """
378
+ Set new extracted text.
379
+
380
+ Args:
381
+ text: New text to return
382
+ """
383
+ self.extracted_text = text
384
+
385
+
386
+ class MockToolExecutor:
387
+ """
388
+ Mock tool executor for testing tool calls.
389
+
390
+ Tracks tool calls and returns configurable results.
391
+
392
+ Example:
393
+ from gaia.testing import MockToolExecutor
394
+
395
+ executor = MockToolExecutor(
396
+ results={
397
+ "search": {"results": ["item1", "item2"]},
398
+ "create_record": {"id": 123, "status": "created"},
399
+ }
400
+ )
401
+
402
+ result = executor.execute("search", {"query": "test"})
403
+ assert result == {"results": ["item1", "item2"]}
404
+
405
+ assert executor.was_tool_called("search")
406
+ assert executor.get_tool_args("search") == {"query": "test"}
407
+ """
408
+
409
+ def __init__(
410
+ self,
411
+ results: Optional[Dict[str, Any]] = None,
412
+ default_result: Optional[Dict[str, Any]] = None,
413
+ ):
414
+ """
415
+ Initialize mock tool executor.
416
+
417
+ Args:
418
+ results: Dict mapping tool names to their results
419
+ default_result: Default result for unknown tools
420
+ """
421
+ self.results = results or {}
422
+ self.default_result = default_result or {"status": "success"}
423
+ self.call_history: List[Dict[str, Any]] = []
424
+
425
+ def execute(self, tool_name: str, args: Dict[str, Any]) -> Any:
426
+ """
427
+ Execute a mock tool.
428
+
429
+ Args:
430
+ tool_name: Name of the tool
431
+ args: Tool arguments
432
+
433
+ Returns:
434
+ Configured result for the tool
435
+ """
436
+ self.call_history.append(
437
+ {
438
+ "tool": tool_name,
439
+ "args": args,
440
+ "timestamp": time.time(),
441
+ }
442
+ )
443
+
444
+ return self.results.get(tool_name, self.default_result)
445
+
446
+ def was_tool_called(self, tool_name: str) -> bool:
447
+ """
448
+ Check if a specific tool was called.
449
+
450
+ Args:
451
+ tool_name: Name of the tool
452
+
453
+ Returns:
454
+ True if tool was called at least once
455
+ """
456
+ return any(call["tool"] == tool_name for call in self.call_history)
457
+
458
+ def get_tool_calls(self, tool_name: str) -> List[Dict[str, Any]]:
459
+ """
460
+ Get all calls to a specific tool.
461
+
462
+ Args:
463
+ tool_name: Name of the tool
464
+
465
+ Returns:
466
+ List of call records for that tool
467
+ """
468
+ return [call for call in self.call_history if call["tool"] == tool_name]
469
+
470
+ def get_tool_args(self, tool_name: str, call_index: int = 0) -> Optional[Dict]:
471
+ """
472
+ Get arguments from a specific tool call.
473
+
474
+ Args:
475
+ tool_name: Name of the tool
476
+ call_index: Which call to get (0 = first call)
477
+
478
+ Returns:
479
+ Arguments dict or None if not found
480
+ """
481
+ calls = self.get_tool_calls(tool_name)
482
+ if call_index < len(calls):
483
+ return calls[call_index]["args"]
484
+ return None
485
+
486
+ @property
487
+ def tool_names_called(self) -> List[str]:
488
+ """Get list of all tool names that were called."""
489
+ return list(set(call["tool"] for call in self.call_history))
490
+
491
+ def reset(self) -> None:
492
+ """Reset call history."""
493
+ self.call_history = []