amd-gaia 0.15.0__py3-none-any.whl → 0.15.1__py3-none-any.whl
This diff compares the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their public registry. Each entry below lists the lines added (+) and removed (-) per file.
- {amd_gaia-0.15.0.dist-info → amd_gaia-0.15.1.dist-info}/METADATA +223 -223
- amd_gaia-0.15.1.dist-info/RECORD +178 -0
- {amd_gaia-0.15.0.dist-info → amd_gaia-0.15.1.dist-info}/entry_points.txt +1 -0
- {amd_gaia-0.15.0.dist-info → amd_gaia-0.15.1.dist-info}/licenses/LICENSE.md +20 -20
- gaia/__init__.py +29 -29
- gaia/agents/__init__.py +19 -19
- gaia/agents/base/__init__.py +9 -9
- gaia/agents/base/agent.py +2177 -2177
- gaia/agents/base/api_agent.py +120 -120
- gaia/agents/base/console.py +1841 -1841
- gaia/agents/base/errors.py +237 -237
- gaia/agents/base/mcp_agent.py +86 -86
- gaia/agents/base/tools.py +83 -83
- gaia/agents/blender/agent.py +556 -556
- gaia/agents/blender/agent_simple.py +133 -135
- gaia/agents/blender/app.py +211 -211
- gaia/agents/blender/app_simple.py +41 -41
- gaia/agents/blender/core/__init__.py +16 -16
- gaia/agents/blender/core/materials.py +506 -506
- gaia/agents/blender/core/objects.py +316 -316
- gaia/agents/blender/core/rendering.py +225 -225
- gaia/agents/blender/core/scene.py +220 -220
- gaia/agents/blender/core/view.py +146 -146
- gaia/agents/chat/__init__.py +9 -9
- gaia/agents/chat/agent.py +835 -835
- gaia/agents/chat/app.py +1058 -1058
- gaia/agents/chat/session.py +508 -508
- gaia/agents/chat/tools/__init__.py +15 -15
- gaia/agents/chat/tools/file_tools.py +96 -96
- gaia/agents/chat/tools/rag_tools.py +1729 -1729
- gaia/agents/chat/tools/shell_tools.py +436 -436
- gaia/agents/code/__init__.py +7 -7
- gaia/agents/code/agent.py +549 -549
- gaia/agents/code/cli.py +377 -0
- gaia/agents/code/models.py +135 -135
- gaia/agents/code/orchestration/__init__.py +24 -24
- gaia/agents/code/orchestration/checklist_executor.py +1763 -1763
- gaia/agents/code/orchestration/checklist_generator.py +713 -713
- gaia/agents/code/orchestration/factories/__init__.py +9 -9
- gaia/agents/code/orchestration/factories/base.py +63 -63
- gaia/agents/code/orchestration/factories/nextjs_factory.py +118 -118
- gaia/agents/code/orchestration/factories/python_factory.py +106 -106
- gaia/agents/code/orchestration/orchestrator.py +841 -841
- gaia/agents/code/orchestration/project_analyzer.py +391 -391
- gaia/agents/code/orchestration/steps/__init__.py +67 -67
- gaia/agents/code/orchestration/steps/base.py +188 -188
- gaia/agents/code/orchestration/steps/error_handler.py +314 -314
- gaia/agents/code/orchestration/steps/nextjs.py +828 -828
- gaia/agents/code/orchestration/steps/python.py +307 -307
- gaia/agents/code/orchestration/template_catalog.py +469 -469
- gaia/agents/code/orchestration/workflows/__init__.py +14 -14
- gaia/agents/code/orchestration/workflows/base.py +80 -80
- gaia/agents/code/orchestration/workflows/nextjs.py +186 -186
- gaia/agents/code/orchestration/workflows/python.py +94 -94
- gaia/agents/code/prompts/__init__.py +11 -11
- gaia/agents/code/prompts/base_prompt.py +77 -77
- gaia/agents/code/prompts/code_patterns.py +2036 -2036
- gaia/agents/code/prompts/nextjs_prompt.py +40 -40
- gaia/agents/code/prompts/python_prompt.py +109 -109
- gaia/agents/code/schema_inference.py +365 -365
- gaia/agents/code/system_prompt.py +41 -41
- gaia/agents/code/tools/__init__.py +42 -42
- gaia/agents/code/tools/cli_tools.py +1138 -1138
- gaia/agents/code/tools/code_formatting.py +319 -319
- gaia/agents/code/tools/code_tools.py +769 -769
- gaia/agents/code/tools/error_fixing.py +1347 -1347
- gaia/agents/code/tools/external_tools.py +180 -180
- gaia/agents/code/tools/file_io.py +845 -845
- gaia/agents/code/tools/prisma_tools.py +190 -190
- gaia/agents/code/tools/project_management.py +1016 -1016
- gaia/agents/code/tools/testing.py +321 -321
- gaia/agents/code/tools/typescript_tools.py +122 -122
- gaia/agents/code/tools/validation_parsing.py +461 -461
- gaia/agents/code/tools/validation_tools.py +806 -806
- gaia/agents/code/tools/web_dev_tools.py +1758 -1758
- gaia/agents/code/validators/__init__.py +16 -16
- gaia/agents/code/validators/antipattern_checker.py +241 -241
- gaia/agents/code/validators/ast_analyzer.py +197 -197
- gaia/agents/code/validators/requirements_validator.py +145 -145
- gaia/agents/code/validators/syntax_validator.py +171 -171
- gaia/agents/docker/__init__.py +7 -7
- gaia/agents/docker/agent.py +642 -642
- gaia/agents/emr/__init__.py +8 -8
- gaia/agents/emr/agent.py +1506 -1506
- gaia/agents/emr/cli.py +1322 -1322
- gaia/agents/emr/constants.py +475 -475
- gaia/agents/emr/dashboard/__init__.py +4 -4
- gaia/agents/emr/dashboard/server.py +1974 -1974
- gaia/agents/jira/__init__.py +11 -11
- gaia/agents/jira/agent.py +894 -894
- gaia/agents/jira/jql_templates.py +299 -299
- gaia/agents/routing/__init__.py +7 -7
- gaia/agents/routing/agent.py +567 -570
- gaia/agents/routing/system_prompt.py +75 -75
- gaia/agents/summarize/__init__.py +11 -0
- gaia/agents/summarize/agent.py +885 -0
- gaia/agents/summarize/prompts.py +129 -0
- gaia/api/__init__.py +23 -23
- gaia/api/agent_registry.py +238 -238
- gaia/api/app.py +305 -305
- gaia/api/openai_server.py +575 -575
- gaia/api/schemas.py +186 -186
- gaia/api/sse_handler.py +373 -373
- gaia/apps/__init__.py +4 -4
- gaia/apps/llm/__init__.py +6 -6
- gaia/apps/llm/app.py +173 -169
- gaia/apps/summarize/app.py +116 -633
- gaia/apps/summarize/html_viewer.py +133 -133
- gaia/apps/summarize/pdf_formatter.py +284 -284
- gaia/audio/__init__.py +2 -2
- gaia/audio/audio_client.py +439 -439
- gaia/audio/audio_recorder.py +269 -269
- gaia/audio/kokoro_tts.py +599 -599
- gaia/audio/whisper_asr.py +432 -432
- gaia/chat/__init__.py +16 -16
- gaia/chat/app.py +430 -430
- gaia/chat/prompts.py +522 -522
- gaia/chat/sdk.py +1228 -1225
- gaia/cli.py +5481 -5632
- gaia/database/__init__.py +10 -10
- gaia/database/agent.py +176 -176
- gaia/database/mixin.py +290 -290
- gaia/database/testing.py +64 -64
- gaia/eval/batch_experiment.py +2332 -2332
- gaia/eval/claude.py +542 -542
- gaia/eval/config.py +37 -37
- gaia/eval/email_generator.py +512 -512
- gaia/eval/eval.py +3179 -3179
- gaia/eval/groundtruth.py +1130 -1130
- gaia/eval/transcript_generator.py +582 -582
- gaia/eval/webapp/README.md +167 -167
- gaia/eval/webapp/package-lock.json +875 -875
- gaia/eval/webapp/package.json +20 -20
- gaia/eval/webapp/public/app.js +3402 -3402
- gaia/eval/webapp/public/index.html +87 -87
- gaia/eval/webapp/public/styles.css +3661 -3661
- gaia/eval/webapp/server.js +415 -415
- gaia/eval/webapp/test-setup.js +72 -72
- gaia/llm/__init__.py +9 -2
- gaia/llm/base_client.py +60 -0
- gaia/llm/exceptions.py +12 -0
- gaia/llm/factory.py +70 -0
- gaia/llm/lemonade_client.py +3236 -3221
- gaia/llm/lemonade_manager.py +294 -294
- gaia/llm/providers/__init__.py +9 -0
- gaia/llm/providers/claude.py +108 -0
- gaia/llm/providers/lemonade.py +120 -0
- gaia/llm/providers/openai_provider.py +79 -0
- gaia/llm/vlm_client.py +382 -382
- gaia/logger.py +189 -189
- gaia/mcp/agent_mcp_server.py +245 -245
- gaia/mcp/blender_mcp_client.py +138 -138
- gaia/mcp/blender_mcp_server.py +648 -648
- gaia/mcp/context7_cache.py +332 -332
- gaia/mcp/external_services.py +518 -518
- gaia/mcp/mcp_bridge.py +811 -550
- gaia/mcp/servers/__init__.py +6 -6
- gaia/mcp/servers/docker_mcp.py +83 -83
- gaia/perf_analysis.py +361 -0
- gaia/rag/__init__.py +10 -10
- gaia/rag/app.py +293 -293
- gaia/rag/demo.py +304 -304
- gaia/rag/pdf_utils.py +235 -235
- gaia/rag/sdk.py +2194 -2194
- gaia/security.py +163 -163
- gaia/talk/app.py +289 -289
- gaia/talk/sdk.py +538 -538
- gaia/testing/__init__.py +87 -87
- gaia/testing/assertions.py +330 -330
- gaia/testing/fixtures.py +333 -333
- gaia/testing/mocks.py +493 -493
- gaia/util.py +46 -46
- gaia/utils/__init__.py +33 -33
- gaia/utils/file_watcher.py +675 -675
- gaia/utils/parsing.py +223 -223
- gaia/version.py +100 -100
- amd_gaia-0.15.0.dist-info/RECORD +0 -168
- gaia/agents/code/app.py +0 -266
- gaia/llm/llm_client.py +0 -723
- {amd_gaia-0.15.0.dist-info → amd_gaia-0.15.1.dist-info}/WHEEL +0 -0
- {amd_gaia-0.15.0.dist-info → amd_gaia-0.15.1.dist-info}/top_level.txt +0 -0
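Notable structural changes in 0.15.1, as read from the listing above: the monolithic gaia/llm/llm_client.py is removed in favor of a provider layer (gaia/llm/base_client.py, exceptions.py, factory.py, and a providers/ package with claude.py, lemonade.py, and openai_provider.py); a new gaia/agents/summarize package appears; gaia/agents/code/app.py is replaced by gaia/agents/code/cli.py; and gaia/perf_analysis.py is new. A minimal sketch of the factory-plus-providers pattern this layout suggests follows; all names in it are illustrative, not the package's actual API.

# Hypothetical sketch of a provider/factory split like the one the new
# gaia/llm layout (base_client.py, factory.py, providers/*) implies.
# Every name here is illustrative only.
from abc import ABC, abstractmethod


class BaseLLMClient(ABC):
    """Common interface each provider implements."""

    @abstractmethod
    def generate(self, prompt: str, **kwargs) -> str: ...


class LemonadeClient(BaseLLMClient):
    def generate(self, prompt: str, **kwargs) -> str:
        return f"[lemonade] {prompt}"


class ClaudeClient(BaseLLMClient):
    def generate(self, prompt: str, **kwargs) -> str:
        return f"[claude] {prompt}"


# Registry consulted by the factory; adding a provider means adding an entry.
_PROVIDERS = {"lemonade": LemonadeClient, "claude": ClaudeClient}


def create_client(provider: str) -> BaseLLMClient:
    """Factory: look up and instantiate the requested provider."""
    try:
        return _PROVIDERS[provider]()
    except KeyError:
        raise ValueError(f"Unknown provider: {provider}") from None


client = create_client("claude")  # callers depend only on BaseLLMClient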
gaia/testing/mocks.py
CHANGED
@@ -1,493 +1,493 @@
All 493 lines were removed and re-added; the old and new sides of the diff render identically (the change is likely whitespace or line endings), so the file content is shown once below.

# Copyright(C) 2025-2026 Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: MIT

"""Mock providers for testing GAIA agents without real LLM/VLM services."""

import logging
import time
from typing import Any, Dict, Iterator, List, Optional

logger = logging.getLogger(__name__)


class MockLLMProvider:
    """
    Mock LLM provider for testing agents without real API calls.

    Returns pre-configured responses instead of calling a real LLM.
    Tracks all calls for test assertions.

    Example:
        from gaia.testing import MockLLMProvider

        mock_llm = MockLLMProvider(responses=["First response", "Second response"])

        # Use in tests
        result = mock_llm.generate("Test prompt")
        assert result == "First response"

        result = mock_llm.generate("Another prompt")
        assert result == "Second response"

        # Check call history
        assert len(mock_llm.call_history) == 2
        assert mock_llm.call_history[0]["prompt"] == "Test prompt"
    """

    def __init__(
        self,
        responses: Optional[List[str]] = None,
        tool_responses: Optional[Dict[str, Any]] = None,
        default_response: str = "Mock LLM response",
    ):
        """
        Initialize mock LLM provider.

        Args:
            responses: List of responses to return in sequence.
                Cycles back to first if more calls than responses.
            tool_responses: Dict mapping tool names to their mock results.
                Used when simulating tool calls.
            default_response: Response when responses list is exhausted or empty.
        """
        self.responses = responses or []
        self.tool_responses = tool_responses or {}
        self.default_response = default_response
        self.call_history: List[Dict[str, Any]] = []
        self._response_index = 0

    def generate(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
        **kwargs,
    ) -> str:
        """
        Generate mock response.

        Args:
            prompt: Input prompt (recorded but not processed)
            system_prompt: System prompt (recorded)
            temperature: Temperature setting (recorded)
            max_tokens: Max tokens (recorded)
            **kwargs: Additional parameters (recorded)

        Returns:
            Next response from response list, or default_response
        """
        self.call_history.append(
            {
                "method": "generate",
                "prompt": prompt,
                "system_prompt": system_prompt,
                "temperature": temperature,
                "max_tokens": max_tokens,
                "kwargs": kwargs,
                "timestamp": time.time(),
            }
        )

        if self.responses:
            response = self.responses[self._response_index % len(self.responses)]
            self._response_index += 1
            return response

        return self.default_response

    def chat(
        self,
        messages: List[Dict[str, str]],
        **kwargs,
    ) -> str:
        """
        Mock chat completion (messages format).

        Args:
            messages: List of message dicts with 'role' and 'content'
            **kwargs: Additional parameters

        Returns:
            Next response from response list
        """
        # Extract the last user message as the prompt
        prompt = ""
        for msg in reversed(messages):
            if msg.get("role") == "user":
                prompt = msg.get("content", "")
                break

        self.call_history.append(
            {
                "method": "chat",
                "messages": messages,
                "prompt": prompt,
                "kwargs": kwargs,
                "timestamp": time.time(),
            }
        )

        if self.responses:
            response = self.responses[self._response_index % len(self.responses)]
            self._response_index += 1
            return response

        return self.default_response

    def stream(
        self,
        prompt: str,
        **kwargs,
    ) -> Iterator[str]:
        """
        Mock streaming response.

        Yields the full response as a single chunk for simplicity.

        Args:
            prompt: Input prompt
            **kwargs: Additional parameters

        Yields:
            Response chunks (full response as single chunk)
        """
        response = self.generate(prompt, **kwargs)
        # Update the last call to note it was streaming
        if self.call_history:
            self.call_history[-1]["method"] = "stream"
        yield response

    def complete(self, prompt: str, **kwargs) -> str:
        """Alias for generate() for compatibility."""
        return self.generate(prompt, **kwargs)

    def get_tool_response(self, tool_name: str) -> Any:
        """
        Get mock response for a tool call.

        Args:
            tool_name: Name of the tool

        Returns:
            Configured mock result or default dict
        """
        return self.tool_responses.get(tool_name, {"status": "success"})

    @property
    def was_called(self) -> bool:
        """Check if any method was called."""
        return len(self.call_history) > 0

    @property
    def call_count(self) -> int:
        """Number of times LLM was called."""
        return len(self.call_history)

    @property
    def last_prompt(self) -> Optional[str]:
        """Get the last prompt that was sent."""
        if self.call_history:
            return self.call_history[-1].get("prompt")
        return None

    def reset(self) -> None:
        """Reset call history and response index."""
        self.call_history = []
        self._response_index = 0

    def set_responses(self, responses: List[str]) -> None:
        """
        Set new responses and reset index.

        Args:
            responses: New list of responses
        """
        self.responses = responses
        self._response_index = 0


class MockVLMClient:
    """
    Mock VLM client for testing image processing without real API calls.

    Returns pre-configured text instead of processing images.
    Tracks all calls for test assertions.

    Example:
        from gaia.testing import MockVLMClient

        mock_vlm = MockVLMClient(
            extracted_text='{"name": "John", "dob": "1990-01-01"}'
        )

        # Inject into agent
        agent = MyAgent()
        agent.vlm = mock_vlm

        # Test extraction
        result = agent.extract_form("test.png")

        # Verify VLM was called
        assert mock_vlm.was_called
        assert mock_vlm.call_count == 1
    """

    def __init__(
        self,
        extracted_text: str = "Mock extracted text",
        extraction_results: Optional[List[str]] = None,
        is_available: bool = True,
    ):
        """
        Initialize mock VLM client.

        Args:
            extracted_text: Default text to return from extract_from_image()
            extraction_results: List of results to return in sequence
            is_available: Whether check_availability() returns True
        """
        self.extracted_text = extracted_text
        self.extraction_results = extraction_results or []
        self.is_available = is_available
        self.call_history: List[Dict[str, Any]] = []
        self._result_index = 0

    def check_availability(self) -> bool:
        """
        Check if VLM is available.

        Returns:
            Configured is_available value
        """
        return self.is_available

    def extract_from_image(
        self,
        image_bytes: bytes,
        prompt: Optional[str] = None,
        **kwargs,
    ) -> str:
        """
        Mock image text extraction.

        Args:
            image_bytes: Image data (recorded but not processed)
            prompt: Extraction prompt (recorded)
            **kwargs: Additional parameters

        Returns:
            Pre-configured extracted text
        """
        self.call_history.append(
            {
                "method": "extract_from_image",
                "image_size": len(image_bytes) if image_bytes else 0,
                "prompt": prompt,
                "kwargs": kwargs,
                "timestamp": time.time(),
            }
        )

        if self.extraction_results:
            result = self.extraction_results[
                self._result_index % len(self.extraction_results)
            ]
            self._result_index += 1
            return result

        return self.extracted_text

    def extract_from_file(
        self,
        file_path: str,
        prompt: Optional[str] = None,
        **kwargs,
    ) -> str:
        """
        Mock file-based extraction.

        Args:
            file_path: Path to image file
            prompt: Extraction prompt

        Returns:
            Pre-configured extracted text
        """
        self.call_history.append(
            {
                "method": "extract_from_file",
                "file_path": file_path,
                "prompt": prompt,
                "kwargs": kwargs,
                "timestamp": time.time(),
            }
        )

        if self.extraction_results:
            result = self.extraction_results[
                self._result_index % len(self.extraction_results)
            ]
            self._result_index += 1
            return result

        return self.extracted_text

    def describe_image(
        self,
        image_bytes: bytes,
        prompt: Optional[str] = None,
        **kwargs,
    ) -> str:
        """
        Mock image description.

        Args:
            image_bytes: Image data
            prompt: Description prompt

        Returns:
            Pre-configured text
        """
        return self.extract_from_image(image_bytes, prompt, **kwargs)

    @property
    def was_called(self) -> bool:
        """Check if any extraction method was called."""
        return len(self.call_history) > 0

    @property
    def call_count(self) -> int:
        """Number of times extraction was called."""
        return len(self.call_history)

    @property
    def last_prompt(self) -> Optional[str]:
        """Get the last prompt that was sent."""
        if self.call_history:
            return self.call_history[-1].get("prompt")
        return None

    def reset(self) -> None:
        """Reset call history and result index."""
        self.call_history = []
        self._result_index = 0

    def set_extracted_text(self, text: str) -> None:
        """
        Set new extracted text.

        Args:
            text: New text to return
        """
        self.extracted_text = text


class MockToolExecutor:
    """
    Mock tool executor for testing tool calls.

    Tracks tool calls and returns configurable results.

    Example:
        from gaia.testing import MockToolExecutor

        executor = MockToolExecutor(
            results={
                "search": {"results": ["item1", "item2"]},
                "create_record": {"id": 123, "status": "created"},
            }
        )

        result = executor.execute("search", {"query": "test"})
        assert result == {"results": ["item1", "item2"]}

        assert executor.was_tool_called("search")
        assert executor.get_tool_args("search") == {"query": "test"}
    """

    def __init__(
        self,
        results: Optional[Dict[str, Any]] = None,
        default_result: Optional[Dict[str, Any]] = None,
    ):
        """
        Initialize mock tool executor.

        Args:
            results: Dict mapping tool names to their results
            default_result: Default result for unknown tools
        """
        self.results = results or {}
        self.default_result = default_result or {"status": "success"}
        self.call_history: List[Dict[str, Any]] = []

    def execute(self, tool_name: str, args: Dict[str, Any]) -> Any:
        """
        Execute a mock tool.

        Args:
            tool_name: Name of the tool
            args: Tool arguments

        Returns:
            Configured result for the tool
        """
        self.call_history.append(
            {
                "tool": tool_name,
                "args": args,
                "timestamp": time.time(),
            }
        )

        return self.results.get(tool_name, self.default_result)

    def was_tool_called(self, tool_name: str) -> bool:
        """
        Check if a specific tool was called.

        Args:
            tool_name: Name of the tool

        Returns:
            True if tool was called at least once
        """
        return any(call["tool"] == tool_name for call in self.call_history)

    def get_tool_calls(self, tool_name: str) -> List[Dict[str, Any]]:
        """
        Get all calls to a specific tool.

        Args:
            tool_name: Name of the tool

        Returns:
            List of call records for that tool
        """
        return [call for call in self.call_history if call["tool"] == tool_name]

    def get_tool_args(self, tool_name: str, call_index: int = 0) -> Optional[Dict]:
        """
        Get arguments from a specific tool call.

        Args:
            tool_name: Name of the tool
            call_index: Which call to get (0 = first call)

        Returns:
            Arguments dict or None if not found
        """
        calls = self.get_tool_calls(tool_name)
        if call_index < len(calls):
            return calls[call_index]["args"]
        return None

    @property
    def tool_names_called(self) -> List[str]:
        """Get list of all tool names that were called."""
        return list(set(call["tool"] for call in self.call_history))

    def reset(self) -> None:
        """Reset call history."""
        self.call_history = []
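A minimal sketch of how the three mocks above combine in a test, using only the constructors and methods shown in this file and assuming the classes are importable from gaia.testing as the docstrings state:

from gaia.testing import MockLLMProvider, MockToolExecutor, MockVLMClient

# Scripted LLM: first call returns a tool request, second a final answer.
llm = MockLLMProvider(responses=['{"tool": "search"}', "Done."])
assert llm.generate("find docs") == '{"tool": "search"}'

# Tool results are canned per tool name; unknown tools get the default.
tools = MockToolExecutor(results={"search": {"results": ["a", "b"]}})
assert tools.execute("search", {"query": "docs"}) == {"results": ["a", "b"]}
assert tools.was_tool_called("search")
assert tools.get_tool_args("search") == {"query": "docs"}

# Responses cycle in order; call history records every invocation.
assert llm.generate("summarize results") == "Done."
assert llm.call_count == 2 and llm.last_prompt == "summarize results"

# The VLM mock returns its configured text for any image bytes.
vlm = MockVLMClient(extracted_text="hello")
assert vlm.extract_from_image(b"\x89PNG") == "hello"
assert vlm.was_called and vlm.call_count == 1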