amd-gaia 0.14.3__py3-none-any.whl → 0.15.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (181)
  1. {amd_gaia-0.14.3.dist-info → amd_gaia-0.15.1.dist-info}/METADATA +223 -223
  2. amd_gaia-0.15.1.dist-info/RECORD +178 -0
  3. {amd_gaia-0.14.3.dist-info → amd_gaia-0.15.1.dist-info}/entry_points.txt +1 -0
  4. {amd_gaia-0.14.3.dist-info → amd_gaia-0.15.1.dist-info}/licenses/LICENSE.md +20 -20
  5. gaia/__init__.py +29 -29
  6. gaia/agents/__init__.py +19 -19
  7. gaia/agents/base/__init__.py +9 -9
  8. gaia/agents/base/agent.py +2177 -2177
  9. gaia/agents/base/api_agent.py +120 -120
  10. gaia/agents/base/console.py +1841 -1841
  11. gaia/agents/base/errors.py +237 -237
  12. gaia/agents/base/mcp_agent.py +86 -86
  13. gaia/agents/base/tools.py +83 -83
  14. gaia/agents/blender/agent.py +556 -556
  15. gaia/agents/blender/agent_simple.py +133 -135
  16. gaia/agents/blender/app.py +211 -211
  17. gaia/agents/blender/app_simple.py +41 -41
  18. gaia/agents/blender/core/__init__.py +16 -16
  19. gaia/agents/blender/core/materials.py +506 -506
  20. gaia/agents/blender/core/objects.py +316 -316
  21. gaia/agents/blender/core/rendering.py +225 -225
  22. gaia/agents/blender/core/scene.py +220 -220
  23. gaia/agents/blender/core/view.py +146 -146
  24. gaia/agents/chat/__init__.py +9 -9
  25. gaia/agents/chat/agent.py +835 -835
  26. gaia/agents/chat/app.py +1058 -1058
  27. gaia/agents/chat/session.py +508 -508
  28. gaia/agents/chat/tools/__init__.py +15 -15
  29. gaia/agents/chat/tools/file_tools.py +96 -96
  30. gaia/agents/chat/tools/rag_tools.py +1729 -1729
  31. gaia/agents/chat/tools/shell_tools.py +436 -436
  32. gaia/agents/code/__init__.py +7 -7
  33. gaia/agents/code/agent.py +549 -549
  34. gaia/agents/code/cli.py +377 -0
  35. gaia/agents/code/models.py +135 -135
  36. gaia/agents/code/orchestration/__init__.py +24 -24
  37. gaia/agents/code/orchestration/checklist_executor.py +1763 -1763
  38. gaia/agents/code/orchestration/checklist_generator.py +713 -713
  39. gaia/agents/code/orchestration/factories/__init__.py +9 -9
  40. gaia/agents/code/orchestration/factories/base.py +63 -63
  41. gaia/agents/code/orchestration/factories/nextjs_factory.py +118 -118
  42. gaia/agents/code/orchestration/factories/python_factory.py +106 -106
  43. gaia/agents/code/orchestration/orchestrator.py +841 -841
  44. gaia/agents/code/orchestration/project_analyzer.py +391 -391
  45. gaia/agents/code/orchestration/steps/__init__.py +67 -67
  46. gaia/agents/code/orchestration/steps/base.py +188 -188
  47. gaia/agents/code/orchestration/steps/error_handler.py +314 -314
  48. gaia/agents/code/orchestration/steps/nextjs.py +828 -828
  49. gaia/agents/code/orchestration/steps/python.py +307 -307
  50. gaia/agents/code/orchestration/template_catalog.py +469 -469
  51. gaia/agents/code/orchestration/workflows/__init__.py +14 -14
  52. gaia/agents/code/orchestration/workflows/base.py +80 -80
  53. gaia/agents/code/orchestration/workflows/nextjs.py +186 -186
  54. gaia/agents/code/orchestration/workflows/python.py +94 -94
  55. gaia/agents/code/prompts/__init__.py +11 -11
  56. gaia/agents/code/prompts/base_prompt.py +77 -77
  57. gaia/agents/code/prompts/code_patterns.py +2036 -2036
  58. gaia/agents/code/prompts/nextjs_prompt.py +40 -40
  59. gaia/agents/code/prompts/python_prompt.py +109 -109
  60. gaia/agents/code/schema_inference.py +365 -365
  61. gaia/agents/code/system_prompt.py +41 -41
  62. gaia/agents/code/tools/__init__.py +42 -42
  63. gaia/agents/code/tools/cli_tools.py +1138 -1138
  64. gaia/agents/code/tools/code_formatting.py +319 -319
  65. gaia/agents/code/tools/code_tools.py +769 -769
  66. gaia/agents/code/tools/error_fixing.py +1347 -1347
  67. gaia/agents/code/tools/external_tools.py +180 -180
  68. gaia/agents/code/tools/file_io.py +845 -845
  69. gaia/agents/code/tools/prisma_tools.py +190 -190
  70. gaia/agents/code/tools/project_management.py +1016 -1016
  71. gaia/agents/code/tools/testing.py +321 -321
  72. gaia/agents/code/tools/typescript_tools.py +122 -122
  73. gaia/agents/code/tools/validation_parsing.py +461 -461
  74. gaia/agents/code/tools/validation_tools.py +806 -806
  75. gaia/agents/code/tools/web_dev_tools.py +1758 -1758
  76. gaia/agents/code/validators/__init__.py +16 -16
  77. gaia/agents/code/validators/antipattern_checker.py +241 -241
  78. gaia/agents/code/validators/ast_analyzer.py +197 -197
  79. gaia/agents/code/validators/requirements_validator.py +145 -145
  80. gaia/agents/code/validators/syntax_validator.py +171 -171
  81. gaia/agents/docker/__init__.py +7 -7
  82. gaia/agents/docker/agent.py +642 -642
  83. gaia/agents/emr/__init__.py +8 -8
  84. gaia/agents/emr/agent.py +1506 -1506
  85. gaia/agents/emr/cli.py +1322 -1322
  86. gaia/agents/emr/constants.py +475 -475
  87. gaia/agents/emr/dashboard/__init__.py +4 -4
  88. gaia/agents/emr/dashboard/server.py +1974 -1974
  89. gaia/agents/jira/__init__.py +11 -11
  90. gaia/agents/jira/agent.py +894 -894
  91. gaia/agents/jira/jql_templates.py +299 -299
  92. gaia/agents/routing/__init__.py +7 -7
  93. gaia/agents/routing/agent.py +567 -570
  94. gaia/agents/routing/system_prompt.py +75 -75
  95. gaia/agents/summarize/__init__.py +11 -0
  96. gaia/agents/summarize/agent.py +885 -0
  97. gaia/agents/summarize/prompts.py +129 -0
  98. gaia/api/__init__.py +23 -23
  99. gaia/api/agent_registry.py +238 -238
  100. gaia/api/app.py +305 -305
  101. gaia/api/openai_server.py +575 -575
  102. gaia/api/schemas.py +186 -186
  103. gaia/api/sse_handler.py +373 -373
  104. gaia/apps/__init__.py +4 -4
  105. gaia/apps/llm/__init__.py +6 -6
  106. gaia/apps/llm/app.py +173 -169
  107. gaia/apps/summarize/app.py +116 -633
  108. gaia/apps/summarize/html_viewer.py +133 -133
  109. gaia/apps/summarize/pdf_formatter.py +284 -284
  110. gaia/audio/__init__.py +2 -2
  111. gaia/audio/audio_client.py +439 -439
  112. gaia/audio/audio_recorder.py +269 -269
  113. gaia/audio/kokoro_tts.py +599 -599
  114. gaia/audio/whisper_asr.py +432 -432
  115. gaia/chat/__init__.py +16 -16
  116. gaia/chat/app.py +430 -430
  117. gaia/chat/prompts.py +522 -522
  118. gaia/chat/sdk.py +1228 -1225
  119. gaia/cli.py +5481 -5621
  120. gaia/database/__init__.py +10 -10
  121. gaia/database/agent.py +176 -176
  122. gaia/database/mixin.py +290 -290
  123. gaia/database/testing.py +64 -64
  124. gaia/eval/batch_experiment.py +2332 -2332
  125. gaia/eval/claude.py +542 -542
  126. gaia/eval/config.py +37 -37
  127. gaia/eval/email_generator.py +512 -512
  128. gaia/eval/eval.py +3179 -3179
  129. gaia/eval/groundtruth.py +1130 -1130
  130. gaia/eval/transcript_generator.py +582 -582
  131. gaia/eval/webapp/README.md +167 -167
  132. gaia/eval/webapp/package-lock.json +875 -875
  133. gaia/eval/webapp/package.json +20 -20
  134. gaia/eval/webapp/public/app.js +3402 -3402
  135. gaia/eval/webapp/public/index.html +87 -87
  136. gaia/eval/webapp/public/styles.css +3661 -3661
  137. gaia/eval/webapp/server.js +415 -415
  138. gaia/eval/webapp/test-setup.js +72 -72
  139. gaia/llm/__init__.py +9 -2
  140. gaia/llm/base_client.py +60 -0
  141. gaia/llm/exceptions.py +12 -0
  142. gaia/llm/factory.py +70 -0
  143. gaia/llm/lemonade_client.py +3236 -3221
  144. gaia/llm/lemonade_manager.py +294 -294
  145. gaia/llm/providers/__init__.py +9 -0
  146. gaia/llm/providers/claude.py +108 -0
  147. gaia/llm/providers/lemonade.py +120 -0
  148. gaia/llm/providers/openai_provider.py +79 -0
  149. gaia/llm/vlm_client.py +382 -382
  150. gaia/logger.py +189 -189
  151. gaia/mcp/agent_mcp_server.py +245 -245
  152. gaia/mcp/blender_mcp_client.py +138 -138
  153. gaia/mcp/blender_mcp_server.py +648 -648
  154. gaia/mcp/context7_cache.py +332 -332
  155. gaia/mcp/external_services.py +518 -518
  156. gaia/mcp/mcp_bridge.py +811 -550
  157. gaia/mcp/servers/__init__.py +6 -6
  158. gaia/mcp/servers/docker_mcp.py +83 -83
  159. gaia/perf_analysis.py +361 -0
  160. gaia/rag/__init__.py +10 -10
  161. gaia/rag/app.py +293 -293
  162. gaia/rag/demo.py +304 -304
  163. gaia/rag/pdf_utils.py +235 -235
  164. gaia/rag/sdk.py +2194 -2194
  165. gaia/security.py +163 -163
  166. gaia/talk/app.py +289 -289
  167. gaia/talk/sdk.py +538 -538
  168. gaia/testing/__init__.py +87 -87
  169. gaia/testing/assertions.py +330 -330
  170. gaia/testing/fixtures.py +333 -333
  171. gaia/testing/mocks.py +493 -493
  172. gaia/util.py +46 -46
  173. gaia/utils/__init__.py +33 -33
  174. gaia/utils/file_watcher.py +675 -675
  175. gaia/utils/parsing.py +223 -223
  176. gaia/version.py +100 -100
  177. amd_gaia-0.14.3.dist-info/RECORD +0 -168
  178. gaia/agents/code/app.py +0 -266
  179. gaia/llm/llm_client.py +0 -729
  180. {amd_gaia-0.14.3.dist-info → amd_gaia-0.15.1.dist-info}/WHEEL +0 -0
  181. {amd_gaia-0.14.3.dist-info → amd_gaia-0.15.1.dist-info}/top_level.txt +0 -0
gaia/testing/mocks.py CHANGED
@@ -1,493 +1,493 @@
1
- # Copyright(C) 2025-2026 Advanced Micro Devices, Inc. All rights reserved.
2
- # SPDX-License-Identifier: MIT
3
-
4
- """Mock providers for testing GAIA agents without real LLM/VLM services."""
5
-
6
- import logging
7
- import time
8
- from typing import Any, Dict, Iterator, List, Optional
9
-
10
- logger = logging.getLogger(__name__)
11
-
12
-
13
- class MockLLMProvider:
14
- """
15
- Mock LLM provider for testing agents without real API calls.
16
-
17
- Returns pre-configured responses instead of calling a real LLM.
18
- Tracks all calls for test assertions.
19
-
20
- Example:
21
- from gaia.testing import MockLLMProvider
22
-
23
- mock_llm = MockLLMProvider(responses=["First response", "Second response"])
24
-
25
- # Use in tests
26
- result = mock_llm.generate("Test prompt")
27
- assert result == "First response"
28
-
29
- result = mock_llm.generate("Another prompt")
30
- assert result == "Second response"
31
-
32
- # Check call history
33
- assert len(mock_llm.call_history) == 2
34
- assert mock_llm.call_history[0]["prompt"] == "Test prompt"
35
- """
36
-
37
- def __init__(
38
- self,
39
- responses: Optional[List[str]] = None,
40
- tool_responses: Optional[Dict[str, Any]] = None,
41
- default_response: str = "Mock LLM response",
42
- ):
43
- """
44
- Initialize mock LLM provider.
45
-
46
- Args:
47
- responses: List of responses to return in sequence.
48
- Cycles back to first if more calls than responses.
49
- tool_responses: Dict mapping tool names to their mock results.
50
- Used when simulating tool calls.
51
- default_response: Response when responses list is exhausted or empty.
52
- """
53
- self.responses = responses or []
54
- self.tool_responses = tool_responses or {}
55
- self.default_response = default_response
56
- self.call_history: List[Dict[str, Any]] = []
57
- self._response_index = 0
58
-
59
- def generate(
60
- self,
61
- prompt: str,
62
- system_prompt: Optional[str] = None,
63
- temperature: float = 0.7,
64
- max_tokens: Optional[int] = None,
65
- **kwargs,
66
- ) -> str:
67
- """
68
- Generate mock response.
69
-
70
- Args:
71
- prompt: Input prompt (recorded but not processed)
72
- system_prompt: System prompt (recorded)
73
- temperature: Temperature setting (recorded)
74
- max_tokens: Max tokens (recorded)
75
- **kwargs: Additional parameters (recorded)
76
-
77
- Returns:
78
- Next response from response list, or default_response
79
- """
80
- self.call_history.append(
81
- {
82
- "method": "generate",
83
- "prompt": prompt,
84
- "system_prompt": system_prompt,
85
- "temperature": temperature,
86
- "max_tokens": max_tokens,
87
- "kwargs": kwargs,
88
- "timestamp": time.time(),
89
- }
90
- )
91
-
92
- if self.responses:
93
- response = self.responses[self._response_index % len(self.responses)]
94
- self._response_index += 1
95
- return response
96
-
97
- return self.default_response
98
-
99
- def chat(
100
- self,
101
- messages: List[Dict[str, str]],
102
- **kwargs,
103
- ) -> str:
104
- """
105
- Mock chat completion (messages format).
106
-
107
- Args:
108
- messages: List of message dicts with 'role' and 'content'
109
- **kwargs: Additional parameters
110
-
111
- Returns:
112
- Next response from response list
113
- """
114
- # Extract the last user message as the prompt
115
- prompt = ""
116
- for msg in reversed(messages):
117
- if msg.get("role") == "user":
118
- prompt = msg.get("content", "")
119
- break
120
-
121
- self.call_history.append(
122
- {
123
- "method": "chat",
124
- "messages": messages,
125
- "prompt": prompt,
126
- "kwargs": kwargs,
127
- "timestamp": time.time(),
128
- }
129
- )
130
-
131
- if self.responses:
132
- response = self.responses[self._response_index % len(self.responses)]
133
- self._response_index += 1
134
- return response
135
-
136
- return self.default_response
137
-
138
- def stream(
139
- self,
140
- prompt: str,
141
- **kwargs,
142
- ) -> Iterator[str]:
143
- """
144
- Mock streaming response.
145
-
146
- Yields the full response as a single chunk for simplicity.
147
-
148
- Args:
149
- prompt: Input prompt
150
- **kwargs: Additional parameters
151
-
152
- Yields:
153
- Response chunks (full response as single chunk)
154
- """
155
- response = self.generate(prompt, **kwargs)
156
- # Update the last call to note it was streaming
157
- if self.call_history:
158
- self.call_history[-1]["method"] = "stream"
159
- yield response
160
-
161
- def complete(self, prompt: str, **kwargs) -> str:
162
- """Alias for generate() for compatibility."""
163
- return self.generate(prompt, **kwargs)
164
-
165
- def get_tool_response(self, tool_name: str) -> Any:
166
- """
167
- Get mock response for a tool call.
168
-
169
- Args:
170
- tool_name: Name of the tool
171
-
172
- Returns:
173
- Configured mock result or default dict
174
- """
175
- return self.tool_responses.get(tool_name, {"status": "success"})
176
-
177
- @property
178
- def was_called(self) -> bool:
179
- """Check if any method was called."""
180
- return len(self.call_history) > 0
181
-
182
- @property
183
- def call_count(self) -> int:
184
- """Number of times LLM was called."""
185
- return len(self.call_history)
186
-
187
- @property
188
- def last_prompt(self) -> Optional[str]:
189
- """Get the last prompt that was sent."""
190
- if self.call_history:
191
- return self.call_history[-1].get("prompt")
192
- return None
193
-
194
- def reset(self) -> None:
195
- """Reset call history and response index."""
196
- self.call_history = []
197
- self._response_index = 0
198
-
199
- def set_responses(self, responses: List[str]) -> None:
200
- """
201
- Set new responses and reset index.
202
-
203
- Args:
204
- responses: New list of responses
205
- """
206
- self.responses = responses
207
- self._response_index = 0
208
-
209
-
210
- class MockVLMClient:
211
- """
212
- Mock VLM client for testing image processing without real API calls.
213
-
214
- Returns pre-configured text instead of processing images.
215
- Tracks all calls for test assertions.
216
-
217
- Example:
218
- from gaia.testing import MockVLMClient
219
-
220
- mock_vlm = MockVLMClient(
221
- extracted_text='{"name": "John", "dob": "1990-01-01"}'
222
- )
223
-
224
- # Inject into agent
225
- agent = MyAgent()
226
- agent.vlm = mock_vlm
227
-
228
- # Test extraction
229
- result = agent.extract_form("test.png")
230
-
231
- # Verify VLM was called
232
- assert mock_vlm.was_called
233
- assert mock_vlm.call_count == 1
234
- """
235
-
236
- def __init__(
237
- self,
238
- extracted_text: str = "Mock extracted text",
239
- extraction_results: Optional[List[str]] = None,
240
- is_available: bool = True,
241
- ):
242
- """
243
- Initialize mock VLM client.
244
-
245
- Args:
246
- extracted_text: Default text to return from extract_from_image()
247
- extraction_results: List of results to return in sequence
248
- is_available: Whether check_availability() returns True
249
- """
250
- self.extracted_text = extracted_text
251
- self.extraction_results = extraction_results or []
252
- self.is_available = is_available
253
- self.call_history: List[Dict[str, Any]] = []
254
- self._result_index = 0
255
-
256
- def check_availability(self) -> bool:
257
- """
258
- Check if VLM is available.
259
-
260
- Returns:
261
- Configured is_available value
262
- """
263
- return self.is_available
264
-
265
- def extract_from_image(
266
- self,
267
- image_bytes: bytes,
268
- prompt: Optional[str] = None,
269
- **kwargs,
270
- ) -> str:
271
- """
272
- Mock image text extraction.
273
-
274
- Args:
275
- image_bytes: Image data (recorded but not processed)
276
- prompt: Extraction prompt (recorded)
277
- **kwargs: Additional parameters
278
-
279
- Returns:
280
- Pre-configured extracted text
281
- """
282
- self.call_history.append(
283
- {
284
- "method": "extract_from_image",
285
- "image_size": len(image_bytes) if image_bytes else 0,
286
- "prompt": prompt,
287
- "kwargs": kwargs,
288
- "timestamp": time.time(),
289
- }
290
- )
291
-
292
- if self.extraction_results:
293
- result = self.extraction_results[
294
- self._result_index % len(self.extraction_results)
295
- ]
296
- self._result_index += 1
297
- return result
298
-
299
- return self.extracted_text
300
-
301
- def extract_from_file(
302
- self,
303
- file_path: str,
304
- prompt: Optional[str] = None,
305
- **kwargs,
306
- ) -> str:
307
- """
308
- Mock file-based extraction.
309
-
310
- Args:
311
- file_path: Path to image file
312
- prompt: Extraction prompt
313
-
314
- Returns:
315
- Pre-configured extracted text
316
- """
317
- self.call_history.append(
318
- {
319
- "method": "extract_from_file",
320
- "file_path": file_path,
321
- "prompt": prompt,
322
- "kwargs": kwargs,
323
- "timestamp": time.time(),
324
- }
325
- )
326
-
327
- if self.extraction_results:
328
- result = self.extraction_results[
329
- self._result_index % len(self.extraction_results)
330
- ]
331
- self._result_index += 1
332
- return result
333
-
334
- return self.extracted_text
335
-
336
- def describe_image(
337
- self,
338
- image_bytes: bytes,
339
- prompt: Optional[str] = None,
340
- **kwargs,
341
- ) -> str:
342
- """
343
- Mock image description.
344
-
345
- Args:
346
- image_bytes: Image data
347
- prompt: Description prompt
348
-
349
- Returns:
350
- Pre-configured text
351
- """
352
- return self.extract_from_image(image_bytes, prompt, **kwargs)
353
-
354
- @property
355
- def was_called(self) -> bool:
356
- """Check if any extraction method was called."""
357
- return len(self.call_history) > 0
358
-
359
- @property
360
- def call_count(self) -> int:
361
- """Number of times extraction was called."""
362
- return len(self.call_history)
363
-
364
- @property
365
- def last_prompt(self) -> Optional[str]:
366
- """Get the last prompt that was sent."""
367
- if self.call_history:
368
- return self.call_history[-1].get("prompt")
369
- return None
370
-
371
- def reset(self) -> None:
372
- """Reset call history and result index."""
373
- self.call_history = []
374
- self._result_index = 0
375
-
376
- def set_extracted_text(self, text: str) -> None:
377
- """
378
- Set new extracted text.
379
-
380
- Args:
381
- text: New text to return
382
- """
383
- self.extracted_text = text
384
-
385
-
386
- class MockToolExecutor:
387
- """
388
- Mock tool executor for testing tool calls.
389
-
390
- Tracks tool calls and returns configurable results.
391
-
392
- Example:
393
- from gaia.testing import MockToolExecutor
394
-
395
- executor = MockToolExecutor(
396
- results={
397
- "search": {"results": ["item1", "item2"]},
398
- "create_record": {"id": 123, "status": "created"},
399
- }
400
- )
401
-
402
- result = executor.execute("search", {"query": "test"})
403
- assert result == {"results": ["item1", "item2"]}
404
-
405
- assert executor.was_tool_called("search")
406
- assert executor.get_tool_args("search") == {"query": "test"}
407
- """
408
-
409
- def __init__(
410
- self,
411
- results: Optional[Dict[str, Any]] = None,
412
- default_result: Optional[Dict[str, Any]] = None,
413
- ):
414
- """
415
- Initialize mock tool executor.
416
-
417
- Args:
418
- results: Dict mapping tool names to their results
419
- default_result: Default result for unknown tools
420
- """
421
- self.results = results or {}
422
- self.default_result = default_result or {"status": "success"}
423
- self.call_history: List[Dict[str, Any]] = []
424
-
425
- def execute(self, tool_name: str, args: Dict[str, Any]) -> Any:
426
- """
427
- Execute a mock tool.
428
-
429
- Args:
430
- tool_name: Name of the tool
431
- args: Tool arguments
432
-
433
- Returns:
434
- Configured result for the tool
435
- """
436
- self.call_history.append(
437
- {
438
- "tool": tool_name,
439
- "args": args,
440
- "timestamp": time.time(),
441
- }
442
- )
443
-
444
- return self.results.get(tool_name, self.default_result)
445
-
446
- def was_tool_called(self, tool_name: str) -> bool:
447
- """
448
- Check if a specific tool was called.
449
-
450
- Args:
451
- tool_name: Name of the tool
452
-
453
- Returns:
454
- True if tool was called at least once
455
- """
456
- return any(call["tool"] == tool_name for call in self.call_history)
457
-
458
- def get_tool_calls(self, tool_name: str) -> List[Dict[str, Any]]:
459
- """
460
- Get all calls to a specific tool.
461
-
462
- Args:
463
- tool_name: Name of the tool
464
-
465
- Returns:
466
- List of call records for that tool
467
- """
468
- return [call for call in self.call_history if call["tool"] == tool_name]
469
-
470
- def get_tool_args(self, tool_name: str, call_index: int = 0) -> Optional[Dict]:
471
- """
472
- Get arguments from a specific tool call.
473
-
474
- Args:
475
- tool_name: Name of the tool
476
- call_index: Which call to get (0 = first call)
477
-
478
- Returns:
479
- Arguments dict or None if not found
480
- """
481
- calls = self.get_tool_calls(tool_name)
482
- if call_index < len(calls):
483
- return calls[call_index]["args"]
484
- return None
485
-
486
- @property
487
- def tool_names_called(self) -> List[str]:
488
- """Get list of all tool names that were called."""
489
- return list(set(call["tool"] for call in self.call_history))
490
-
491
- def reset(self) -> None:
492
- """Reset call history."""
493
- self.call_history = []
1
+ # Copyright(C) 2025-2026 Advanced Micro Devices, Inc. All rights reserved.
2
+ # SPDX-License-Identifier: MIT
3
+
4
+ """Mock providers for testing GAIA agents without real LLM/VLM services."""
5
+
6
+ import logging
7
+ import time
8
+ from typing import Any, Dict, Iterator, List, Optional
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+
13
+ class MockLLMProvider:
14
+ """
15
+ Mock LLM provider for testing agents without real API calls.
16
+
17
+ Returns pre-configured responses instead of calling a real LLM.
18
+ Tracks all calls for test assertions.
19
+
20
+ Example:
21
+ from gaia.testing import MockLLMProvider
22
+
23
+ mock_llm = MockLLMProvider(responses=["First response", "Second response"])
24
+
25
+ # Use in tests
26
+ result = mock_llm.generate("Test prompt")
27
+ assert result == "First response"
28
+
29
+ result = mock_llm.generate("Another prompt")
30
+ assert result == "Second response"
31
+
32
+ # Check call history
33
+ assert len(mock_llm.call_history) == 2
34
+ assert mock_llm.call_history[0]["prompt"] == "Test prompt"
35
+ """
36
+
37
+ def __init__(
38
+ self,
39
+ responses: Optional[List[str]] = None,
40
+ tool_responses: Optional[Dict[str, Any]] = None,
41
+ default_response: str = "Mock LLM response",
42
+ ):
43
+ """
44
+ Initialize mock LLM provider.
45
+
46
+ Args:
47
+ responses: List of responses to return in sequence.
48
+ Cycles back to first if more calls than responses.
49
+ tool_responses: Dict mapping tool names to their mock results.
50
+ Used when simulating tool calls.
51
+ default_response: Response when responses list is exhausted or empty.
52
+ """
53
+ self.responses = responses or []
54
+ self.tool_responses = tool_responses or {}
55
+ self.default_response = default_response
56
+ self.call_history: List[Dict[str, Any]] = []
57
+ self._response_index = 0
58
+
59
+ def generate(
60
+ self,
61
+ prompt: str,
62
+ system_prompt: Optional[str] = None,
63
+ temperature: float = 0.7,
64
+ max_tokens: Optional[int] = None,
65
+ **kwargs,
66
+ ) -> str:
67
+ """
68
+ Generate mock response.
69
+
70
+ Args:
71
+ prompt: Input prompt (recorded but not processed)
72
+ system_prompt: System prompt (recorded)
73
+ temperature: Temperature setting (recorded)
74
+ max_tokens: Max tokens (recorded)
75
+ **kwargs: Additional parameters (recorded)
76
+
77
+ Returns:
78
+ Next response from response list, or default_response
79
+ """
80
+ self.call_history.append(
81
+ {
82
+ "method": "generate",
83
+ "prompt": prompt,
84
+ "system_prompt": system_prompt,
85
+ "temperature": temperature,
86
+ "max_tokens": max_tokens,
87
+ "kwargs": kwargs,
88
+ "timestamp": time.time(),
89
+ }
90
+ )
91
+
92
+ if self.responses:
93
+ response = self.responses[self._response_index % len(self.responses)]
94
+ self._response_index += 1
95
+ return response
96
+
97
+ return self.default_response
98
+
99
+ def chat(
100
+ self,
101
+ messages: List[Dict[str, str]],
102
+ **kwargs,
103
+ ) -> str:
104
+ """
105
+ Mock chat completion (messages format).
106
+
107
+ Args:
108
+ messages: List of message dicts with 'role' and 'content'
109
+ **kwargs: Additional parameters
110
+
111
+ Returns:
112
+ Next response from response list
113
+ """
114
+ # Extract the last user message as the prompt
115
+ prompt = ""
116
+ for msg in reversed(messages):
117
+ if msg.get("role") == "user":
118
+ prompt = msg.get("content", "")
119
+ break
120
+
121
+ self.call_history.append(
122
+ {
123
+ "method": "chat",
124
+ "messages": messages,
125
+ "prompt": prompt,
126
+ "kwargs": kwargs,
127
+ "timestamp": time.time(),
128
+ }
129
+ )
130
+
131
+ if self.responses:
132
+ response = self.responses[self._response_index % len(self.responses)]
133
+ self._response_index += 1
134
+ return response
135
+
136
+ return self.default_response
137
+
138
+ def stream(
139
+ self,
140
+ prompt: str,
141
+ **kwargs,
142
+ ) -> Iterator[str]:
143
+ """
144
+ Mock streaming response.
145
+
146
+ Yields the full response as a single chunk for simplicity.
147
+
148
+ Args:
149
+ prompt: Input prompt
150
+ **kwargs: Additional parameters
151
+
152
+ Yields:
153
+ Response chunks (full response as single chunk)
154
+ """
155
+ response = self.generate(prompt, **kwargs)
156
+ # Update the last call to note it was streaming
157
+ if self.call_history:
158
+ self.call_history[-1]["method"] = "stream"
159
+ yield response
160
+
161
+ def complete(self, prompt: str, **kwargs) -> str:
162
+ """Alias for generate() for compatibility."""
163
+ return self.generate(prompt, **kwargs)
164
+
165
+ def get_tool_response(self, tool_name: str) -> Any:
166
+ """
167
+ Get mock response for a tool call.
168
+
169
+ Args:
170
+ tool_name: Name of the tool
171
+
172
+ Returns:
173
+ Configured mock result or default dict
174
+ """
175
+ return self.tool_responses.get(tool_name, {"status": "success"})
176
+
177
+ @property
178
+ def was_called(self) -> bool:
179
+ """Check if any method was called."""
180
+ return len(self.call_history) > 0
181
+
182
+ @property
183
+ def call_count(self) -> int:
184
+ """Number of times LLM was called."""
185
+ return len(self.call_history)
186
+
187
+ @property
188
+ def last_prompt(self) -> Optional[str]:
189
+ """Get the last prompt that was sent."""
190
+ if self.call_history:
191
+ return self.call_history[-1].get("prompt")
192
+ return None
193
+
194
+ def reset(self) -> None:
195
+ """Reset call history and response index."""
196
+ self.call_history = []
197
+ self._response_index = 0
198
+
199
+ def set_responses(self, responses: List[str]) -> None:
200
+ """
201
+ Set new responses and reset index.
202
+
203
+ Args:
204
+ responses: New list of responses
205
+ """
206
+ self.responses = responses
207
+ self._response_index = 0
208
+
209
+
210
+ class MockVLMClient:
211
+ """
212
+ Mock VLM client for testing image processing without real API calls.
213
+
214
+ Returns pre-configured text instead of processing images.
215
+ Tracks all calls for test assertions.
216
+
217
+ Example:
218
+ from gaia.testing import MockVLMClient
219
+
220
+ mock_vlm = MockVLMClient(
221
+ extracted_text='{"name": "John", "dob": "1990-01-01"}'
222
+ )
223
+
224
+ # Inject into agent
225
+ agent = MyAgent()
226
+ agent.vlm = mock_vlm
227
+
228
+ # Test extraction
229
+ result = agent.extract_form("test.png")
230
+
231
+ # Verify VLM was called
232
+ assert mock_vlm.was_called
233
+ assert mock_vlm.call_count == 1
234
+ """
235
+
236
+ def __init__(
237
+ self,
238
+ extracted_text: str = "Mock extracted text",
239
+ extraction_results: Optional[List[str]] = None,
240
+ is_available: bool = True,
241
+ ):
242
+ """
243
+ Initialize mock VLM client.
244
+
245
+ Args:
246
+ extracted_text: Default text to return from extract_from_image()
247
+ extraction_results: List of results to return in sequence
248
+ is_available: Whether check_availability() returns True
249
+ """
250
+ self.extracted_text = extracted_text
251
+ self.extraction_results = extraction_results or []
252
+ self.is_available = is_available
253
+ self.call_history: List[Dict[str, Any]] = []
254
+ self._result_index = 0
255
+
256
+ def check_availability(self) -> bool:
257
+ """
258
+ Check if VLM is available.
259
+
260
+ Returns:
261
+ Configured is_available value
262
+ """
263
+ return self.is_available
264
+
265
+ def extract_from_image(
266
+ self,
267
+ image_bytes: bytes,
268
+ prompt: Optional[str] = None,
269
+ **kwargs,
270
+ ) -> str:
271
+ """
272
+ Mock image text extraction.
273
+
274
+ Args:
275
+ image_bytes: Image data (recorded but not processed)
276
+ prompt: Extraction prompt (recorded)
277
+ **kwargs: Additional parameters
278
+
279
+ Returns:
280
+ Pre-configured extracted text
281
+ """
282
+ self.call_history.append(
283
+ {
284
+ "method": "extract_from_image",
285
+ "image_size": len(image_bytes) if image_bytes else 0,
286
+ "prompt": prompt,
287
+ "kwargs": kwargs,
288
+ "timestamp": time.time(),
289
+ }
290
+ )
291
+
292
+ if self.extraction_results:
293
+ result = self.extraction_results[
294
+ self._result_index % len(self.extraction_results)
295
+ ]
296
+ self._result_index += 1
297
+ return result
298
+
299
+ return self.extracted_text
300
+
301
+ def extract_from_file(
302
+ self,
303
+ file_path: str,
304
+ prompt: Optional[str] = None,
305
+ **kwargs,
306
+ ) -> str:
307
+ """
308
+ Mock file-based extraction.
309
+
310
+ Args:
311
+ file_path: Path to image file
312
+ prompt: Extraction prompt
313
+
314
+ Returns:
315
+ Pre-configured extracted text
316
+ """
317
+ self.call_history.append(
318
+ {
319
+ "method": "extract_from_file",
320
+ "file_path": file_path,
321
+ "prompt": prompt,
322
+ "kwargs": kwargs,
323
+ "timestamp": time.time(),
324
+ }
325
+ )
326
+
327
+ if self.extraction_results:
328
+ result = self.extraction_results[
329
+ self._result_index % len(self.extraction_results)
330
+ ]
331
+ self._result_index += 1
332
+ return result
333
+
334
+ return self.extracted_text
335
+
336
+ def describe_image(
337
+ self,
338
+ image_bytes: bytes,
339
+ prompt: Optional[str] = None,
340
+ **kwargs,
341
+ ) -> str:
342
+ """
343
+ Mock image description.
344
+
345
+ Args:
346
+ image_bytes: Image data
347
+ prompt: Description prompt
348
+
349
+ Returns:
350
+ Pre-configured text
351
+ """
352
+ return self.extract_from_image(image_bytes, prompt, **kwargs)
353
+
354
+ @property
355
+ def was_called(self) -> bool:
356
+ """Check if any extraction method was called."""
357
+ return len(self.call_history) > 0
358
+
359
+ @property
360
+ def call_count(self) -> int:
361
+ """Number of times extraction was called."""
362
+ return len(self.call_history)
363
+
364
+ @property
365
+ def last_prompt(self) -> Optional[str]:
366
+ """Get the last prompt that was sent."""
367
+ if self.call_history:
368
+ return self.call_history[-1].get("prompt")
369
+ return None
370
+
371
+ def reset(self) -> None:
372
+ """Reset call history and result index."""
373
+ self.call_history = []
374
+ self._result_index = 0
375
+
376
+ def set_extracted_text(self, text: str) -> None:
377
+ """
378
+ Set new extracted text.
379
+
380
+ Args:
381
+ text: New text to return
382
+ """
383
+ self.extracted_text = text
384
+
385
+
386
+ class MockToolExecutor:
387
+ """
388
+ Mock tool executor for testing tool calls.
389
+
390
+ Tracks tool calls and returns configurable results.
391
+
392
+ Example:
393
+ from gaia.testing import MockToolExecutor
394
+
395
+ executor = MockToolExecutor(
396
+ results={
397
+ "search": {"results": ["item1", "item2"]},
398
+ "create_record": {"id": 123, "status": "created"},
399
+ }
400
+ )
401
+
402
+ result = executor.execute("search", {"query": "test"})
403
+ assert result == {"results": ["item1", "item2"]}
404
+
405
+ assert executor.was_tool_called("search")
406
+ assert executor.get_tool_args("search") == {"query": "test"}
407
+ """
408
+
409
+ def __init__(
410
+ self,
411
+ results: Optional[Dict[str, Any]] = None,
412
+ default_result: Optional[Dict[str, Any]] = None,
413
+ ):
414
+ """
415
+ Initialize mock tool executor.
416
+
417
+ Args:
418
+ results: Dict mapping tool names to their results
419
+ default_result: Default result for unknown tools
420
+ """
421
+ self.results = results or {}
422
+ self.default_result = default_result or {"status": "success"}
423
+ self.call_history: List[Dict[str, Any]] = []
424
+
425
+ def execute(self, tool_name: str, args: Dict[str, Any]) -> Any:
426
+ """
427
+ Execute a mock tool.
428
+
429
+ Args:
430
+ tool_name: Name of the tool
431
+ args: Tool arguments
432
+
433
+ Returns:
434
+ Configured result for the tool
435
+ """
436
+ self.call_history.append(
437
+ {
438
+ "tool": tool_name,
439
+ "args": args,
440
+ "timestamp": time.time(),
441
+ }
442
+ )
443
+
444
+ return self.results.get(tool_name, self.default_result)
445
+
446
+ def was_tool_called(self, tool_name: str) -> bool:
447
+ """
448
+ Check if a specific tool was called.
449
+
450
+ Args:
451
+ tool_name: Name of the tool
452
+
453
+ Returns:
454
+ True if tool was called at least once
455
+ """
456
+ return any(call["tool"] == tool_name for call in self.call_history)
457
+
458
+ def get_tool_calls(self, tool_name: str) -> List[Dict[str, Any]]:
459
+ """
460
+ Get all calls to a specific tool.
461
+
462
+ Args:
463
+ tool_name: Name of the tool
464
+
465
+ Returns:
466
+ List of call records for that tool
467
+ """
468
+ return [call for call in self.call_history if call["tool"] == tool_name]
469
+
470
+ def get_tool_args(self, tool_name: str, call_index: int = 0) -> Optional[Dict]:
471
+ """
472
+ Get arguments from a specific tool call.
473
+
474
+ Args:
475
+ tool_name: Name of the tool
476
+ call_index: Which call to get (0 = first call)
477
+
478
+ Returns:
479
+ Arguments dict or None if not found
480
+ """
481
+ calls = self.get_tool_calls(tool_name)
482
+ if call_index < len(calls):
483
+ return calls[call_index]["args"]
484
+ return None
485
+
486
+ @property
487
+ def tool_names_called(self) -> List[str]:
488
+ """Get list of all tool names that were called."""
489
+ return list(set(call["tool"] for call in self.call_history))
490
+
491
+ def reset(self) -> None:
492
+ """Reset call history."""
493
+ self.call_history = []