amd-gaia 0.14.3__py3-none-any.whl → 0.15.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {amd_gaia-0.14.3.dist-info → amd_gaia-0.15.1.dist-info}/METADATA +223 -223
- amd_gaia-0.15.1.dist-info/RECORD +178 -0
- {amd_gaia-0.14.3.dist-info → amd_gaia-0.15.1.dist-info}/entry_points.txt +1 -0
- {amd_gaia-0.14.3.dist-info → amd_gaia-0.15.1.dist-info}/licenses/LICENSE.md +20 -20
- gaia/__init__.py +29 -29
- gaia/agents/__init__.py +19 -19
- gaia/agents/base/__init__.py +9 -9
- gaia/agents/base/agent.py +2177 -2177
- gaia/agents/base/api_agent.py +120 -120
- gaia/agents/base/console.py +1841 -1841
- gaia/agents/base/errors.py +237 -237
- gaia/agents/base/mcp_agent.py +86 -86
- gaia/agents/base/tools.py +83 -83
- gaia/agents/blender/agent.py +556 -556
- gaia/agents/blender/agent_simple.py +133 -135
- gaia/agents/blender/app.py +211 -211
- gaia/agents/blender/app_simple.py +41 -41
- gaia/agents/blender/core/__init__.py +16 -16
- gaia/agents/blender/core/materials.py +506 -506
- gaia/agents/blender/core/objects.py +316 -316
- gaia/agents/blender/core/rendering.py +225 -225
- gaia/agents/blender/core/scene.py +220 -220
- gaia/agents/blender/core/view.py +146 -146
- gaia/agents/chat/__init__.py +9 -9
- gaia/agents/chat/agent.py +835 -835
- gaia/agents/chat/app.py +1058 -1058
- gaia/agents/chat/session.py +508 -508
- gaia/agents/chat/tools/__init__.py +15 -15
- gaia/agents/chat/tools/file_tools.py +96 -96
- gaia/agents/chat/tools/rag_tools.py +1729 -1729
- gaia/agents/chat/tools/shell_tools.py +436 -436
- gaia/agents/code/__init__.py +7 -7
- gaia/agents/code/agent.py +549 -549
- gaia/agents/code/cli.py +377 -0
- gaia/agents/code/models.py +135 -135
- gaia/agents/code/orchestration/__init__.py +24 -24
- gaia/agents/code/orchestration/checklist_executor.py +1763 -1763
- gaia/agents/code/orchestration/checklist_generator.py +713 -713
- gaia/agents/code/orchestration/factories/__init__.py +9 -9
- gaia/agents/code/orchestration/factories/base.py +63 -63
- gaia/agents/code/orchestration/factories/nextjs_factory.py +118 -118
- gaia/agents/code/orchestration/factories/python_factory.py +106 -106
- gaia/agents/code/orchestration/orchestrator.py +841 -841
- gaia/agents/code/orchestration/project_analyzer.py +391 -391
- gaia/agents/code/orchestration/steps/__init__.py +67 -67
- gaia/agents/code/orchestration/steps/base.py +188 -188
- gaia/agents/code/orchestration/steps/error_handler.py +314 -314
- gaia/agents/code/orchestration/steps/nextjs.py +828 -828
- gaia/agents/code/orchestration/steps/python.py +307 -307
- gaia/agents/code/orchestration/template_catalog.py +469 -469
- gaia/agents/code/orchestration/workflows/__init__.py +14 -14
- gaia/agents/code/orchestration/workflows/base.py +80 -80
- gaia/agents/code/orchestration/workflows/nextjs.py +186 -186
- gaia/agents/code/orchestration/workflows/python.py +94 -94
- gaia/agents/code/prompts/__init__.py +11 -11
- gaia/agents/code/prompts/base_prompt.py +77 -77
- gaia/agents/code/prompts/code_patterns.py +2036 -2036
- gaia/agents/code/prompts/nextjs_prompt.py +40 -40
- gaia/agents/code/prompts/python_prompt.py +109 -109
- gaia/agents/code/schema_inference.py +365 -365
- gaia/agents/code/system_prompt.py +41 -41
- gaia/agents/code/tools/__init__.py +42 -42
- gaia/agents/code/tools/cli_tools.py +1138 -1138
- gaia/agents/code/tools/code_formatting.py +319 -319
- gaia/agents/code/tools/code_tools.py +769 -769
- gaia/agents/code/tools/error_fixing.py +1347 -1347
- gaia/agents/code/tools/external_tools.py +180 -180
- gaia/agents/code/tools/file_io.py +845 -845
- gaia/agents/code/tools/prisma_tools.py +190 -190
- gaia/agents/code/tools/project_management.py +1016 -1016
- gaia/agents/code/tools/testing.py +321 -321
- gaia/agents/code/tools/typescript_tools.py +122 -122
- gaia/agents/code/tools/validation_parsing.py +461 -461
- gaia/agents/code/tools/validation_tools.py +806 -806
- gaia/agents/code/tools/web_dev_tools.py +1758 -1758
- gaia/agents/code/validators/__init__.py +16 -16
- gaia/agents/code/validators/antipattern_checker.py +241 -241
- gaia/agents/code/validators/ast_analyzer.py +197 -197
- gaia/agents/code/validators/requirements_validator.py +145 -145
- gaia/agents/code/validators/syntax_validator.py +171 -171
- gaia/agents/docker/__init__.py +7 -7
- gaia/agents/docker/agent.py +642 -642
- gaia/agents/emr/__init__.py +8 -8
- gaia/agents/emr/agent.py +1506 -1506
- gaia/agents/emr/cli.py +1322 -1322
- gaia/agents/emr/constants.py +475 -475
- gaia/agents/emr/dashboard/__init__.py +4 -4
- gaia/agents/emr/dashboard/server.py +1974 -1974
- gaia/agents/jira/__init__.py +11 -11
- gaia/agents/jira/agent.py +894 -894
- gaia/agents/jira/jql_templates.py +299 -299
- gaia/agents/routing/__init__.py +7 -7
- gaia/agents/routing/agent.py +567 -570
- gaia/agents/routing/system_prompt.py +75 -75
- gaia/agents/summarize/__init__.py +11 -0
- gaia/agents/summarize/agent.py +885 -0
- gaia/agents/summarize/prompts.py +129 -0
- gaia/api/__init__.py +23 -23
- gaia/api/agent_registry.py +238 -238
- gaia/api/app.py +305 -305
- gaia/api/openai_server.py +575 -575
- gaia/api/schemas.py +186 -186
- gaia/api/sse_handler.py +373 -373
- gaia/apps/__init__.py +4 -4
- gaia/apps/llm/__init__.py +6 -6
- gaia/apps/llm/app.py +173 -169
- gaia/apps/summarize/app.py +116 -633
- gaia/apps/summarize/html_viewer.py +133 -133
- gaia/apps/summarize/pdf_formatter.py +284 -284
- gaia/audio/__init__.py +2 -2
- gaia/audio/audio_client.py +439 -439
- gaia/audio/audio_recorder.py +269 -269
- gaia/audio/kokoro_tts.py +599 -599
- gaia/audio/whisper_asr.py +432 -432
- gaia/chat/__init__.py +16 -16
- gaia/chat/app.py +430 -430
- gaia/chat/prompts.py +522 -522
- gaia/chat/sdk.py +1228 -1225
- gaia/cli.py +5481 -5621
- gaia/database/__init__.py +10 -10
- gaia/database/agent.py +176 -176
- gaia/database/mixin.py +290 -290
- gaia/database/testing.py +64 -64
- gaia/eval/batch_experiment.py +2332 -2332
- gaia/eval/claude.py +542 -542
- gaia/eval/config.py +37 -37
- gaia/eval/email_generator.py +512 -512
- gaia/eval/eval.py +3179 -3179
- gaia/eval/groundtruth.py +1130 -1130
- gaia/eval/transcript_generator.py +582 -582
- gaia/eval/webapp/README.md +167 -167
- gaia/eval/webapp/package-lock.json +875 -875
- gaia/eval/webapp/package.json +20 -20
- gaia/eval/webapp/public/app.js +3402 -3402
- gaia/eval/webapp/public/index.html +87 -87
- gaia/eval/webapp/public/styles.css +3661 -3661
- gaia/eval/webapp/server.js +415 -415
- gaia/eval/webapp/test-setup.js +72 -72
- gaia/llm/__init__.py +9 -2
- gaia/llm/base_client.py +60 -0
- gaia/llm/exceptions.py +12 -0
- gaia/llm/factory.py +70 -0
- gaia/llm/lemonade_client.py +3236 -3221
- gaia/llm/lemonade_manager.py +294 -294
- gaia/llm/providers/__init__.py +9 -0
- gaia/llm/providers/claude.py +108 -0
- gaia/llm/providers/lemonade.py +120 -0
- gaia/llm/providers/openai_provider.py +79 -0
- gaia/llm/vlm_client.py +382 -382
- gaia/logger.py +189 -189
- gaia/mcp/agent_mcp_server.py +245 -245
- gaia/mcp/blender_mcp_client.py +138 -138
- gaia/mcp/blender_mcp_server.py +648 -648
- gaia/mcp/context7_cache.py +332 -332
- gaia/mcp/external_services.py +518 -518
- gaia/mcp/mcp_bridge.py +811 -550
- gaia/mcp/servers/__init__.py +6 -6
- gaia/mcp/servers/docker_mcp.py +83 -83
- gaia/perf_analysis.py +361 -0
- gaia/rag/__init__.py +10 -10
- gaia/rag/app.py +293 -293
- gaia/rag/demo.py +304 -304
- gaia/rag/pdf_utils.py +235 -235
- gaia/rag/sdk.py +2194 -2194
- gaia/security.py +163 -163
- gaia/talk/app.py +289 -289
- gaia/talk/sdk.py +538 -538
- gaia/testing/__init__.py +87 -87
- gaia/testing/assertions.py +330 -330
- gaia/testing/fixtures.py +333 -333
- gaia/testing/mocks.py +493 -493
- gaia/util.py +46 -46
- gaia/utils/__init__.py +33 -33
- gaia/utils/file_watcher.py +675 -675
- gaia/utils/parsing.py +223 -223
- gaia/version.py +100 -100
- amd_gaia-0.14.3.dist-info/RECORD +0 -168
- gaia/agents/code/app.py +0 -266
- gaia/llm/llm_client.py +0 -729
- {amd_gaia-0.14.3.dist-info → amd_gaia-0.15.1.dist-info}/WHEEL +0 -0
- {amd_gaia-0.14.3.dist-info → amd_gaia-0.15.1.dist-info}/top_level.txt +0 -0
gaia/testing/assertions.py
CHANGED
|
@@ -1,330 +1,330 @@
|
|
|
1
|
-
# Copyright(C) 2025-2026 Advanced Micro Devices, Inc. All rights reserved.
|
|
2
|
-
# SPDX-License-Identifier: MIT
|
|
3
|
-
|
|
4
|
-
"""Assertion helpers for testing GAIA agents."""
|
|
5
|
-
|
|
6
|
-
from typing import Any, Dict, List, Optional, Union
|
|
7
|
-
|
|
8
|
-
from gaia.testing.mocks import MockLLMProvider, MockToolExecutor, MockVLMClient
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
def assert_llm_called(
    mock_llm: MockLLMProvider,
    times: Optional[int] = None,
    min_times: Optional[int] = None,
    max_times: Optional[int] = None,
) -> None:
    """
    Assert that the mock LLM was called.

    Failures are raised explicitly as AssertionError rather than through
    ``assert`` statements, so the check still runs when Python is started
    with ``-O`` (which strips ``assert``).

    Args:
        mock_llm: MockLLMProvider instance (only ``call_count`` is read)
        times: Exact number of expected calls (optional). When given,
            ``min_times`` and ``max_times`` are not consulted.
        min_times: Minimum number of calls (optional)
        max_times: Maximum number of calls (optional)

    Raises:
        AssertionError: If call count doesn't match expectations

    Example:
        from gaia.testing import MockLLMProvider, assert_llm_called

        mock_llm = MockLLMProvider(responses=["Hello"])
        mock_llm.generate("Test")

        assert_llm_called(mock_llm)  # At least once
        assert_llm_called(mock_llm, times=1)  # Exactly once
        assert_llm_called(mock_llm, min_times=1, max_times=5)  # Range
    """
    call_count = mock_llm.call_count

    # An exact count takes precedence over any range bounds.
    if times is not None:
        if not call_count == times:
            raise AssertionError(
                f"Expected LLM to be called {times} time(s), "
                f"but was called {call_count} time(s)"
            )
        return

    # No bounds at all: just check it was called at least once.
    if min_times is None and max_times is None:
        if not call_count > 0:
            raise AssertionError("Expected LLM to be called at least once")
        return

    if min_times is not None and not call_count >= min_times:
        raise AssertionError(
            f"Expected LLM to be called at least {min_times} time(s), "
            f"but was called {call_count} time(s)"
        )

    if max_times is not None and not call_count <= max_times:
        raise AssertionError(
            f"Expected LLM to be called at most {max_times} time(s), "
            f"but was called {call_count} time(s)"
        )
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
def assert_llm_prompt_contains(
    mock_llm: MockLLMProvider,
    text: str,
    call_index: int = -1,
) -> None:
    """
    Assert that an LLM prompt contains specific text.

    Failures are raised explicitly as AssertionError rather than through
    ``assert`` statements, so the check still runs under ``python -O``.

    Args:
        mock_llm: MockLLMProvider instance (only ``call_history`` is read)
        text: Text that should be in the prompt
        call_index: Which call to check (-1 = last call, 0 = first call)

    Raises:
        AssertionError: If the LLM was never called, if ``call_index`` does
            not refer to a recorded call, or if text is not found in prompt

    Example:
        assert_llm_prompt_contains(mock_llm, "customer")
        assert_llm_prompt_contains(mock_llm, "search", call_index=0)
    """
    history = mock_llm.call_history
    if not history:
        raise AssertionError("LLM was never called")

    # Report a bad index as a test failure instead of leaking IndexError.
    try:
        call = history[call_index]
    except IndexError:
        raise AssertionError(
            f"call_index {call_index} out of range; "
            f"LLM was called {len(history)} time(s)"
        ) from None

    # Each history entry is read with .get(), i.e. treated as a mapping;
    # a missing "prompt" key is handled as an empty prompt.
    prompt = call.get("prompt", "")

    if text not in prompt:
        raise AssertionError(
            f"Expected prompt to contain '{text}', but prompt was:\n{prompt[:500]}"
        )
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
def assert_vlm_called(
    mock_vlm: MockVLMClient,
    times: Optional[int] = None,
) -> None:
    """
    Assert that the mock VLM was called.

    Failures are raised explicitly as AssertionError rather than through
    ``assert`` statements, so the check still runs under ``python -O``.

    Args:
        mock_vlm: MockVLMClient instance (only ``call_count`` is read)
        times: Exact number of expected calls (optional). When omitted,
            at least one call is required.

    Raises:
        AssertionError: If call count doesn't match

    Example:
        assert_vlm_called(mock_vlm)  # At least once
        assert_vlm_called(mock_vlm, times=2)  # Exactly twice
    """
    call_count = mock_vlm.call_count

    if times is None:
        if not call_count > 0:
            raise AssertionError("Expected VLM to be called at least once")
        return

    if not call_count == times:
        raise AssertionError(
            f"Expected VLM to be called {times} time(s), "
            f"but was called {call_count} time(s)"
        )
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
def assert_tool_called(
    executor: MockToolExecutor,
    tool_name: str,
    times: Optional[int] = None,
) -> None:
    """
    Assert that a specific tool was called.

    Failures are raised explicitly as AssertionError rather than through
    ``assert`` statements, so the check still runs under ``python -O``.

    Args:
        executor: MockToolExecutor instance
        tool_name: Name of the tool
        times: Exact number of expected calls (optional). When omitted,
            at least one call is required.

    Raises:
        AssertionError: If tool wasn't called or count doesn't match

    Example:
        from gaia.testing import MockToolExecutor, assert_tool_called

        executor = MockToolExecutor()
        executor.execute("search", {"query": "test"})

        assert_tool_called(executor, "search")
        assert_tool_called(executor, "search", times=1)
    """
    calls = executor.get_tool_calls(tool_name)
    actual_count = len(calls)

    if times is not None:
        if not actual_count == times:
            raise AssertionError(
                f"Expected tool '{tool_name}' to be called {times} time(s), "
                f"but was called {actual_count} time(s)"
            )
        return

    if not actual_count > 0:
        # List what was called instead to make the failure easy to diagnose.
        raise AssertionError(
            f"Expected tool '{tool_name}' to be called, but it was never called. "
            f"Tools called: {executor.tool_names_called}"
        )
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
def assert_tool_args(
    executor: MockToolExecutor,
    tool_name: str,
    expected_args: Dict[str, Any],
    call_index: int = 0,
) -> None:
    """
    Assert that a tool was called with specific arguments.

    Only the keys listed in ``expected_args`` are compared; extra arguments
    in the recorded call are ignored. Failures are raised explicitly as
    AssertionError rather than through ``assert`` statements, so the check
    still runs under ``python -O``.

    Args:
        executor: MockToolExecutor instance
        tool_name: Name of the tool
        expected_args: Expected arguments (subset matching)
        call_index: Which call to check (0 = first call)

    Raises:
        AssertionError: If the call is not found or arguments don't match

    Example:
        executor.execute("search", {"query": "test", "limit": 10})
        assert_tool_args(executor, "search", {"query": "test"})
    """
    actual_args = executor.get_tool_args(tool_name, call_index)

    # A None result is treated as "no such call recorded".
    if actual_args is None:
        raise AssertionError(
            f"Tool '{tool_name}' was not called (or call_index {call_index} out of range)"
        )

    for key, expected_value in expected_args.items():
        if key not in actual_args:
            raise AssertionError(
                f"Expected argument '{key}' not found in tool call. "
                f"Actual args: {actual_args}"
            )
        if not actual_args[key] == expected_value:
            raise AssertionError(
                f"Argument '{key}' mismatch. "
                f"Expected: {expected_value}, Actual: {actual_args[key]}"
            )
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
def assert_result_has_keys(
    result: Dict[str, Any],
    keys: List[str],
) -> None:
    """
    Assert that a result dictionary has specific keys.

    Both the type check and the key check raise AssertionError explicitly,
    so neither is stripped under ``python -O``.

    Args:
        result: Result dictionary to check
        keys: List of required keys

    Raises:
        AssertionError: If result is not a dict or any key is missing

    Example:
        result = agent.process_query("test")
        assert_result_has_keys(result, ["answer", "steps_taken"])
    """
    if not isinstance(result, dict):
        raise AssertionError(
            f"Expected result to be dict, got {type(result).__name__}"
        )

    missing_keys = [key for key in keys if key not in result]
    if missing_keys:
        raise AssertionError(
            f"Result missing required keys: {missing_keys}. "
            f"Available keys: {list(result.keys())}"
        )
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
def assert_result_value(
    result: Dict[str, Any],
    key: str,
    expected: Any,
) -> None:
    """
    Assert that a result has a specific value for a key.

    Failures are raised explicitly as AssertionError rather than through
    ``assert`` statements, so the check still runs under ``python -O``.

    Args:
        result: Result dictionary
        key: Key to check
        expected: Expected value (compared with ``==``)

    Raises:
        AssertionError: If the key is missing or the value doesn't match

    Example:
        assert_result_value(result, "status", "success")
    """
    if key not in result:
        raise AssertionError(
            f"Key '{key}' not found in result: {list(result.keys())}"
        )

    actual = result[key]
    if not actual == expected:
        raise AssertionError(
            f"Value mismatch for key '{key}'. Expected: {expected}, Actual: {actual}"
        )
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
def assert_agent_completed(
    result: Union[Dict[str, Any], str],
    has_answer: bool = True,
) -> None:
    """
    Assert that an agent completed processing successfully.

    Failures are raised explicitly as AssertionError rather than through
    ``assert`` statements, so the check still runs under ``python -O``.

    Args:
        result: Result from agent.process_query(); either a dict or a
            plain string
        has_answer: Whether a dict result must contain one of the keys
            'answer', 'response' or 'result'. Not applied to string results.

    Raises:
        AssertionError: If agent didn't complete properly

    Example:
        result = agent.process_query("test")
        assert_agent_completed(result)
    """
    # Handle string results (some agents return strings directly)
    if isinstance(result, str):
        if not len(result) > 0:
            raise AssertionError("Agent returned empty string")
        return

    if not isinstance(result, dict):
        raise AssertionError(
            f"Expected result to be dict or str, got {type(result).__name__}"
        )

    # Check for error indicators; a falsy "error" value is not an error.
    if "error" in result and result["error"]:
        raise AssertionError(f"Agent returned error: {result['error']}")

    if "status" in result and result["status"] == "error":
        error_msg = result.get("message", result.get("error", "Unknown error"))
        raise AssertionError(f"Agent returned error status: {error_msg}")

    # Check for answer if required
    if has_answer and not (
        "answer" in result or "response" in result or "result" in result
    ):
        raise AssertionError(
            "Agent result missing answer/response/result key. "
            f"Keys present: {list(result.keys())}"
        )
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
def assert_no_errors(result: Dict[str, Any]) -> None:
    """
    Assert that a result contains no errors.

    Four indicators are checked in order: a truthy "error" value, a truthy
    "errors" value, a "status" equal to "error", and a "success" that is
    exactly False. Anything that is not a dict passes without inspection.

    Args:
        result: Result dictionary

    Raises:
        AssertionError: If result contains error indicators

    Example:
        result = agent.process_query("test")
        assert_no_errors(result)
    """
    if not isinstance(result, dict):
        # Non-dict results carry no error keys to inspect.
        return

    def _detail() -> Any:
        # Prefer the "message" entry, then "error", then a fixed placeholder.
        return result.get("message", result.get("error", "Unknown"))

    single_error = result.get("error")
    if single_error:
        raise AssertionError(f"Result contains error: {single_error}")

    error_list = result.get("errors")
    if error_list:
        raise AssertionError(f"Result contains errors: {error_list}")

    if result.get("status") == "error":
        raise AssertionError(f"Result has error status: {_detail()}")

    if result.get("success") is False:
        raise AssertionError(f"Result indicates failure: {_detail()}")
|
|
1
|
+
# Copyright(C) 2025-2026 Advanced Micro Devices, Inc. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: MIT
|
|
3
|
+
|
|
4
|
+
"""Assertion helpers for testing GAIA agents."""
|
|
5
|
+
|
|
6
|
+
from typing import Any, Dict, List, Optional, Union
|
|
7
|
+
|
|
8
|
+
from gaia.testing.mocks import MockLLMProvider, MockToolExecutor, MockVLMClient
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def assert_llm_called(
    mock_llm: MockLLMProvider,
    times: Optional[int] = None,
    min_times: Optional[int] = None,
    max_times: Optional[int] = None,
) -> None:
    """
    Assert that the mock LLM was called.

    Failures are raised explicitly as AssertionError rather than through
    ``assert`` statements, so the check still runs when Python is started
    with ``-O`` (which strips ``assert``).

    Args:
        mock_llm: MockLLMProvider instance (only ``call_count`` is read)
        times: Exact number of expected calls (optional). When given,
            ``min_times`` and ``max_times`` are not consulted.
        min_times: Minimum number of calls (optional)
        max_times: Maximum number of calls (optional)

    Raises:
        AssertionError: If call count doesn't match expectations

    Example:
        from gaia.testing import MockLLMProvider, assert_llm_called

        mock_llm = MockLLMProvider(responses=["Hello"])
        mock_llm.generate("Test")

        assert_llm_called(mock_llm)  # At least once
        assert_llm_called(mock_llm, times=1)  # Exactly once
        assert_llm_called(mock_llm, min_times=1, max_times=5)  # Range
    """
    call_count = mock_llm.call_count

    # An exact count takes precedence over any range bounds.
    if times is not None:
        if not call_count == times:
            raise AssertionError(
                f"Expected LLM to be called {times} time(s), "
                f"but was called {call_count} time(s)"
            )
        return

    # No bounds at all: just check it was called at least once.
    if min_times is None and max_times is None:
        if not call_count > 0:
            raise AssertionError("Expected LLM to be called at least once")
        return

    if min_times is not None and not call_count >= min_times:
        raise AssertionError(
            f"Expected LLM to be called at least {min_times} time(s), "
            f"but was called {call_count} time(s)"
        )

    if max_times is not None and not call_count <= max_times:
        raise AssertionError(
            f"Expected LLM to be called at most {max_times} time(s), "
            f"but was called {call_count} time(s)"
        )
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def assert_llm_prompt_contains(
    mock_llm: MockLLMProvider,
    text: str,
    call_index: int = -1,
) -> None:
    """
    Assert that an LLM prompt contains specific text.

    Failures are raised explicitly as AssertionError rather than through
    ``assert`` statements, so the check still runs under ``python -O``.

    Args:
        mock_llm: MockLLMProvider instance (only ``call_history`` is read)
        text: Text that should be in the prompt
        call_index: Which call to check (-1 = last call, 0 = first call)

    Raises:
        AssertionError: If the LLM was never called, if ``call_index`` does
            not refer to a recorded call, or if text is not found in prompt

    Example:
        assert_llm_prompt_contains(mock_llm, "customer")
        assert_llm_prompt_contains(mock_llm, "search", call_index=0)
    """
    history = mock_llm.call_history
    if not history:
        raise AssertionError("LLM was never called")

    # Report a bad index as a test failure instead of leaking IndexError.
    try:
        call = history[call_index]
    except IndexError:
        raise AssertionError(
            f"call_index {call_index} out of range; "
            f"LLM was called {len(history)} time(s)"
        ) from None

    # Each history entry is read with .get(), i.e. treated as a mapping;
    # a missing "prompt" key is handled as an empty prompt.
    prompt = call.get("prompt", "")

    if text not in prompt:
        raise AssertionError(
            f"Expected prompt to contain '{text}', but prompt was:\n{prompt[:500]}"
        )
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def assert_vlm_called(
    mock_vlm: MockVLMClient,
    times: Optional[int] = None,
) -> None:
    """
    Assert that the mock VLM was called.

    Failures are raised explicitly as AssertionError rather than through
    ``assert`` statements, so the check still runs under ``python -O``.

    Args:
        mock_vlm: MockVLMClient instance (only ``call_count`` is read)
        times: Exact number of expected calls (optional). When omitted,
            at least one call is required.

    Raises:
        AssertionError: If call count doesn't match

    Example:
        assert_vlm_called(mock_vlm)  # At least once
        assert_vlm_called(mock_vlm, times=2)  # Exactly twice
    """
    call_count = mock_vlm.call_count

    if times is None:
        if not call_count > 0:
            raise AssertionError("Expected VLM to be called at least once")
        return

    if not call_count == times:
        raise AssertionError(
            f"Expected VLM to be called {times} time(s), "
            f"but was called {call_count} time(s)"
        )
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
def assert_tool_called(
    executor: MockToolExecutor,
    tool_name: str,
    times: Optional[int] = None,
) -> None:
    """
    Assert that a specific tool was called.

    Failures are raised explicitly as AssertionError rather than through
    ``assert`` statements, so the check still runs under ``python -O``.

    Args:
        executor: MockToolExecutor instance
        tool_name: Name of the tool
        times: Exact number of expected calls (optional). When omitted,
            at least one call is required.

    Raises:
        AssertionError: If tool wasn't called or count doesn't match

    Example:
        from gaia.testing import MockToolExecutor, assert_tool_called

        executor = MockToolExecutor()
        executor.execute("search", {"query": "test"})

        assert_tool_called(executor, "search")
        assert_tool_called(executor, "search", times=1)
    """
    calls = executor.get_tool_calls(tool_name)
    actual_count = len(calls)

    if times is not None:
        if not actual_count == times:
            raise AssertionError(
                f"Expected tool '{tool_name}' to be called {times} time(s), "
                f"but was called {actual_count} time(s)"
            )
        return

    if not actual_count > 0:
        # List what was called instead to make the failure easy to diagnose.
        raise AssertionError(
            f"Expected tool '{tool_name}' to be called, but it was never called. "
            f"Tools called: {executor.tool_names_called}"
        )
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
def assert_tool_args(
    executor: MockToolExecutor,
    tool_name: str,
    expected_args: Dict[str, Any],
    call_index: int = 0,
) -> None:
    """
    Assert that a tool was called with specific arguments.

    Only the keys listed in ``expected_args`` are compared; extra arguments
    in the recorded call are ignored. Failures are raised explicitly as
    AssertionError rather than through ``assert`` statements, so the check
    still runs under ``python -O``.

    Args:
        executor: MockToolExecutor instance
        tool_name: Name of the tool
        expected_args: Expected arguments (subset matching)
        call_index: Which call to check (0 = first call)

    Raises:
        AssertionError: If the call is not found or arguments don't match

    Example:
        executor.execute("search", {"query": "test", "limit": 10})
        assert_tool_args(executor, "search", {"query": "test"})
    """
    actual_args = executor.get_tool_args(tool_name, call_index)

    # A None result is treated as "no such call recorded".
    if actual_args is None:
        raise AssertionError(
            f"Tool '{tool_name}' was not called (or call_index {call_index} out of range)"
        )

    for key, expected_value in expected_args.items():
        if key not in actual_args:
            raise AssertionError(
                f"Expected argument '{key}' not found in tool call. "
                f"Actual args: {actual_args}"
            )
        if not actual_args[key] == expected_value:
            raise AssertionError(
                f"Argument '{key}' mismatch. "
                f"Expected: {expected_value}, Actual: {actual_args[key]}"
            )
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
def assert_result_has_keys(
    result: Dict[str, Any],
    keys: List[str],
) -> None:
    """
    Assert that a result dictionary has specific keys.

    Both the type check and the key check raise AssertionError explicitly,
    so neither is stripped under ``python -O``.

    Args:
        result: Result dictionary to check
        keys: List of required keys

    Raises:
        AssertionError: If result is not a dict or any key is missing

    Example:
        result = agent.process_query("test")
        assert_result_has_keys(result, ["answer", "steps_taken"])
    """
    if not isinstance(result, dict):
        raise AssertionError(
            f"Expected result to be dict, got {type(result).__name__}"
        )

    missing_keys = [key for key in keys if key not in result]
    if missing_keys:
        raise AssertionError(
            f"Result missing required keys: {missing_keys}. "
            f"Available keys: {list(result.keys())}"
        )
|
|
229
|
+
|
|
230
|
+
|
|
231
|
+
def assert_result_value(
    result: Dict[str, Any],
    key: str,
    expected: Any,
) -> None:
    """
    Assert that a result has a specific value for a key.

    Failures are raised explicitly as AssertionError rather than through
    ``assert`` statements, so the check still runs under ``python -O``.

    Args:
        result: Result dictionary
        key: Key to check
        expected: Expected value (compared with ``==``)

    Raises:
        AssertionError: If the key is missing or the value doesn't match

    Example:
        assert_result_value(result, "status", "success")
    """
    if key not in result:
        raise AssertionError(
            f"Key '{key}' not found in result: {list(result.keys())}"
        )

    actual = result[key]
    if not actual == expected:
        raise AssertionError(
            f"Value mismatch for key '{key}'. Expected: {expected}, Actual: {actual}"
        )
|
|
255
|
+
|
|
256
|
+
|
|
257
|
+
def assert_agent_completed(
    result: Union[Dict[str, Any], str],
    has_answer: bool = True,
) -> None:
    """
    Assert that an agent completed processing successfully.

    Failures are raised explicitly as AssertionError rather than through
    ``assert`` statements, so the check still runs under ``python -O``.

    Args:
        result: Result from agent.process_query(); either a dict or a
            plain string
        has_answer: Whether a dict result must contain one of the keys
            'answer', 'response' or 'result'. Not applied to string results.

    Raises:
        AssertionError: If agent didn't complete properly

    Example:
        result = agent.process_query("test")
        assert_agent_completed(result)
    """
    # Handle string results (some agents return strings directly)
    if isinstance(result, str):
        if not len(result) > 0:
            raise AssertionError("Agent returned empty string")
        return

    if not isinstance(result, dict):
        raise AssertionError(
            f"Expected result to be dict or str, got {type(result).__name__}"
        )

    # Check for error indicators; a falsy "error" value is not an error.
    if "error" in result and result["error"]:
        raise AssertionError(f"Agent returned error: {result['error']}")

    if "status" in result and result["status"] == "error":
        error_msg = result.get("message", result.get("error", "Unknown error"))
        raise AssertionError(f"Agent returned error status: {error_msg}")

    # Check for answer if required
    if has_answer and not (
        "answer" in result or "response" in result or "result" in result
    ):
        raise AssertionError(
            "Agent result missing answer/response/result key. "
            f"Keys present: {list(result.keys())}"
        )
|
|
298
|
+
|
|
299
|
+
|
|
300
|
+
def assert_no_errors(result: Dict[str, Any]) -> None:
    """
    Assert that a result contains no errors.

    Four indicators are checked in order: a truthy "error" value, a truthy
    "errors" value, a "status" equal to "error", and a "success" that is
    exactly False. Anything that is not a dict passes without inspection.

    Args:
        result: Result dictionary

    Raises:
        AssertionError: If result contains error indicators

    Example:
        result = agent.process_query("test")
        assert_no_errors(result)
    """
    if not isinstance(result, dict):
        # Non-dict results carry no error keys to inspect.
        return

    def _detail() -> Any:
        # Prefer the "message" entry, then "error", then a fixed placeholder.
        return result.get("message", result.get("error", "Unknown"))

    single_error = result.get("error")
    if single_error:
        raise AssertionError(f"Result contains error: {single_error}")

    error_list = result.get("errors")
    if error_list:
        raise AssertionError(f"Result contains errors: {error_list}")

    if result.get("status") == "error":
        raise AssertionError(f"Result has error status: {_detail()}")

    if result.get("success") is False:
        raise AssertionError(f"Result indicates failure: {_detail()}")
|