amd-gaia 0.15.0__py3-none-any.whl → 0.15.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {amd_gaia-0.15.0.dist-info → amd_gaia-0.15.2.dist-info}/METADATA +222 -223
- amd_gaia-0.15.2.dist-info/RECORD +182 -0
- {amd_gaia-0.15.0.dist-info → amd_gaia-0.15.2.dist-info}/WHEEL +1 -1
- {amd_gaia-0.15.0.dist-info → amd_gaia-0.15.2.dist-info}/entry_points.txt +1 -0
- {amd_gaia-0.15.0.dist-info → amd_gaia-0.15.2.dist-info}/licenses/LICENSE.md +20 -20
- gaia/__init__.py +29 -29
- gaia/agents/__init__.py +19 -19
- gaia/agents/base/__init__.py +9 -9
- gaia/agents/base/agent.py +2132 -2177
- gaia/agents/base/api_agent.py +119 -120
- gaia/agents/base/console.py +1967 -1841
- gaia/agents/base/errors.py +237 -237
- gaia/agents/base/mcp_agent.py +86 -86
- gaia/agents/base/tools.py +88 -83
- gaia/agents/blender/__init__.py +7 -0
- gaia/agents/blender/agent.py +553 -556
- gaia/agents/blender/agent_simple.py +133 -135
- gaia/agents/blender/app.py +211 -211
- gaia/agents/blender/app_simple.py +41 -41
- gaia/agents/blender/core/__init__.py +16 -16
- gaia/agents/blender/core/materials.py +506 -506
- gaia/agents/blender/core/objects.py +316 -316
- gaia/agents/blender/core/rendering.py +225 -225
- gaia/agents/blender/core/scene.py +220 -220
- gaia/agents/blender/core/view.py +146 -146
- gaia/agents/chat/__init__.py +9 -9
- gaia/agents/chat/agent.py +809 -835
- gaia/agents/chat/app.py +1065 -1058
- gaia/agents/chat/session.py +508 -508
- gaia/agents/chat/tools/__init__.py +15 -15
- gaia/agents/chat/tools/file_tools.py +96 -96
- gaia/agents/chat/tools/rag_tools.py +1744 -1729
- gaia/agents/chat/tools/shell_tools.py +437 -436
- gaia/agents/code/__init__.py +7 -7
- gaia/agents/code/agent.py +549 -549
- gaia/agents/code/cli.py +377 -0
- gaia/agents/code/models.py +135 -135
- gaia/agents/code/orchestration/__init__.py +24 -24
- gaia/agents/code/orchestration/checklist_executor.py +1763 -1763
- gaia/agents/code/orchestration/checklist_generator.py +713 -713
- gaia/agents/code/orchestration/factories/__init__.py +9 -9
- gaia/agents/code/orchestration/factories/base.py +63 -63
- gaia/agents/code/orchestration/factories/nextjs_factory.py +118 -118
- gaia/agents/code/orchestration/factories/python_factory.py +106 -106
- gaia/agents/code/orchestration/orchestrator.py +841 -841
- gaia/agents/code/orchestration/project_analyzer.py +391 -391
- gaia/agents/code/orchestration/steps/__init__.py +67 -67
- gaia/agents/code/orchestration/steps/base.py +188 -188
- gaia/agents/code/orchestration/steps/error_handler.py +314 -314
- gaia/agents/code/orchestration/steps/nextjs.py +828 -828
- gaia/agents/code/orchestration/steps/python.py +307 -307
- gaia/agents/code/orchestration/template_catalog.py +469 -469
- gaia/agents/code/orchestration/workflows/__init__.py +14 -14
- gaia/agents/code/orchestration/workflows/base.py +80 -80
- gaia/agents/code/orchestration/workflows/nextjs.py +186 -186
- gaia/agents/code/orchestration/workflows/python.py +94 -94
- gaia/agents/code/prompts/__init__.py +11 -11
- gaia/agents/code/prompts/base_prompt.py +77 -77
- gaia/agents/code/prompts/code_patterns.py +2034 -2036
- gaia/agents/code/prompts/nextjs_prompt.py +40 -40
- gaia/agents/code/prompts/python_prompt.py +109 -109
- gaia/agents/code/schema_inference.py +365 -365
- gaia/agents/code/system_prompt.py +41 -41
- gaia/agents/code/tools/__init__.py +42 -42
- gaia/agents/code/tools/cli_tools.py +1138 -1138
- gaia/agents/code/tools/code_formatting.py +319 -319
- gaia/agents/code/tools/code_tools.py +769 -769
- gaia/agents/code/tools/error_fixing.py +1347 -1347
- gaia/agents/code/tools/external_tools.py +180 -180
- gaia/agents/code/tools/file_io.py +845 -845
- gaia/agents/code/tools/prisma_tools.py +190 -190
- gaia/agents/code/tools/project_management.py +1016 -1016
- gaia/agents/code/tools/testing.py +321 -321
- gaia/agents/code/tools/typescript_tools.py +122 -122
- gaia/agents/code/tools/validation_parsing.py +461 -461
- gaia/agents/code/tools/validation_tools.py +806 -806
- gaia/agents/code/tools/web_dev_tools.py +1758 -1758
- gaia/agents/code/validators/__init__.py +16 -16
- gaia/agents/code/validators/antipattern_checker.py +241 -241
- gaia/agents/code/validators/ast_analyzer.py +197 -197
- gaia/agents/code/validators/requirements_validator.py +145 -145
- gaia/agents/code/validators/syntax_validator.py +171 -171
- gaia/agents/docker/__init__.py +7 -7
- gaia/agents/docker/agent.py +643 -642
- gaia/agents/emr/__init__.py +8 -8
- gaia/agents/emr/agent.py +1504 -1506
- gaia/agents/emr/cli.py +1322 -1322
- gaia/agents/emr/constants.py +475 -475
- gaia/agents/emr/dashboard/__init__.py +4 -4
- gaia/agents/emr/dashboard/server.py +1972 -1974
- gaia/agents/jira/__init__.py +11 -11
- gaia/agents/jira/agent.py +894 -894
- gaia/agents/jira/jql_templates.py +299 -299
- gaia/agents/routing/__init__.py +7 -7
- gaia/agents/routing/agent.py +567 -570
- gaia/agents/routing/system_prompt.py +75 -75
- gaia/agents/summarize/__init__.py +11 -0
- gaia/agents/summarize/agent.py +885 -0
- gaia/agents/summarize/prompts.py +129 -0
- gaia/api/__init__.py +23 -23
- gaia/api/agent_registry.py +238 -238
- gaia/api/app.py +305 -305
- gaia/api/openai_server.py +575 -575
- gaia/api/schemas.py +186 -186
- gaia/api/sse_handler.py +373 -373
- gaia/apps/__init__.py +4 -4
- gaia/apps/llm/__init__.py +6 -6
- gaia/apps/llm/app.py +184 -169
- gaia/apps/summarize/app.py +116 -633
- gaia/apps/summarize/html_viewer.py +133 -133
- gaia/apps/summarize/pdf_formatter.py +284 -284
- gaia/audio/__init__.py +2 -2
- gaia/audio/audio_client.py +439 -439
- gaia/audio/audio_recorder.py +269 -269
- gaia/audio/kokoro_tts.py +599 -599
- gaia/audio/whisper_asr.py +432 -432
- gaia/chat/__init__.py +16 -16
- gaia/chat/app.py +428 -430
- gaia/chat/prompts.py +522 -522
- gaia/chat/sdk.py +1228 -1225
- gaia/cli.py +5659 -5632
- gaia/database/__init__.py +10 -10
- gaia/database/agent.py +176 -176
- gaia/database/mixin.py +290 -290
- gaia/database/testing.py +64 -64
- gaia/eval/batch_experiment.py +2332 -2332
- gaia/eval/claude.py +542 -542
- gaia/eval/config.py +37 -37
- gaia/eval/email_generator.py +512 -512
- gaia/eval/eval.py +3179 -3179
- gaia/eval/groundtruth.py +1130 -1130
- gaia/eval/transcript_generator.py +582 -582
- gaia/eval/webapp/README.md +167 -167
- gaia/eval/webapp/package-lock.json +875 -875
- gaia/eval/webapp/package.json +20 -20
- gaia/eval/webapp/public/app.js +3402 -3402
- gaia/eval/webapp/public/index.html +87 -87
- gaia/eval/webapp/public/styles.css +3661 -3661
- gaia/eval/webapp/server.js +415 -415
- gaia/eval/webapp/test-setup.js +72 -72
- gaia/installer/__init__.py +23 -0
- gaia/installer/init_command.py +1275 -0
- gaia/installer/lemonade_installer.py +619 -0
- gaia/llm/__init__.py +10 -2
- gaia/llm/base_client.py +60 -0
- gaia/llm/exceptions.py +12 -0
- gaia/llm/factory.py +70 -0
- gaia/llm/lemonade_client.py +3421 -3221
- gaia/llm/lemonade_manager.py +294 -294
- gaia/llm/providers/__init__.py +9 -0
- gaia/llm/providers/claude.py +108 -0
- gaia/llm/providers/lemonade.py +118 -0
- gaia/llm/providers/openai_provider.py +79 -0
- gaia/llm/vlm_client.py +382 -382
- gaia/logger.py +189 -189
- gaia/mcp/agent_mcp_server.py +245 -245
- gaia/mcp/blender_mcp_client.py +138 -138
- gaia/mcp/blender_mcp_server.py +648 -648
- gaia/mcp/context7_cache.py +332 -332
- gaia/mcp/external_services.py +518 -518
- gaia/mcp/mcp_bridge.py +811 -550
- gaia/mcp/servers/__init__.py +6 -6
- gaia/mcp/servers/docker_mcp.py +83 -83
- gaia/perf_analysis.py +361 -0
- gaia/rag/__init__.py +10 -10
- gaia/rag/app.py +293 -293
- gaia/rag/demo.py +304 -304
- gaia/rag/pdf_utils.py +235 -235
- gaia/rag/sdk.py +2194 -2194
- gaia/security.py +183 -163
- gaia/talk/app.py +287 -289
- gaia/talk/sdk.py +538 -538
- gaia/testing/__init__.py +87 -87
- gaia/testing/assertions.py +330 -330
- gaia/testing/fixtures.py +333 -333
- gaia/testing/mocks.py +493 -493
- gaia/util.py +46 -46
- gaia/utils/__init__.py +33 -33
- gaia/utils/file_watcher.py +675 -675
- gaia/utils/parsing.py +223 -223
- gaia/version.py +100 -100
- amd_gaia-0.15.0.dist-info/RECORD +0 -168
- gaia/agents/code/app.py +0 -266
- gaia/llm/llm_client.py +0 -723
- {amd_gaia-0.15.0.dist-info → amd_gaia-0.15.2.dist-info}/top_level.txt +0 -0
|
@@ -1,307 +1,307 @@
|
|
|
1
|
-
# Copyright(C) 2025-2026 Advanced Micro Devices, Inc. All rights reserved.
|
|
2
|
-
# SPDX-License-Identifier: MIT
|
|
3
|
-
"""
|
|
4
|
-
Python step implementations.
|
|
5
|
-
|
|
6
|
-
Steps wrap the existing Code Agent tools with standardized interfaces.
|
|
7
|
-
"""
|
|
8
|
-
|
|
9
|
-
from dataclasses import dataclass
|
|
10
|
-
from typing import Any, Dict, Optional, Tuple
|
|
11
|
-
|
|
12
|
-
from .base import BaseStep, ErrorCategory, StepResult, UserContext
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
@dataclass
class CreateProjectStep(BaseStep):
    """Workflow step that invokes the ``create_project`` tool.

    The request text sent to the tool is this step's own ``user_request``
    when it is non-empty, otherwise the ``user_request`` held by the context.
    """

    name: str = "create_project"
    description: str = "Generate Python project"
    # Optional per-step request text; an empty string defers to the context.
    user_request: str = ""

    def get_tool_invocation(
        self, context: UserContext
    ) -> Optional[Tuple[str, Dict[str, Any]]]:
        """Build the ``(tool_name, arguments)`` pair for ``create_project``."""
        query = self.user_request or context.user_request
        arguments: Dict[str, Any] = {"query": query}
        return ("create_project", arguments)

    def handle_result(self, result: Any, context: UserContext) -> StepResult:
        """Translate the tool's raw return value into a ``StepResult``.

        A non-dict value is reported as an unexpected format; a dict without
        a truthy ``"success"`` entry is reported as a creation failure;
        otherwise the ``"project_name"`` and ``"files"`` entries are passed
        along on a successful result.
        """
        if not isinstance(result, dict):
            return StepResult.make_error(
                "Unexpected result format", str(result), ErrorCategory.UNKNOWN
            )

        if not result.get("success"):
            failure_detail = result.get("error", "Unknown error")
            return StepResult.make_error(
                "Failed to create project",
                failure_detail,
                ErrorCategory.COMPILATION,
            )

        # Project details ride along as keyword data on the result
        # (presumably picked up by later steps via the context -- verify
        # against the caller).
        created_name = result.get("project_name", "")
        created_files = result.get("files", [])
        return StepResult.ok(
            f"Project {created_name} created",
            project_name=created_name,
            files=created_files,
        )
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
@dataclass
class ListFilesStep(BaseStep):
    """Workflow step that invokes the ``list_files`` tool on the project."""

    name: str = "list_files"
    description: str = "List project files"

    def get_tool_invocation(
        self, context: UserContext
    ) -> Optional[Tuple[str, Dict[str, Any]]]:
        """Build the ``(tool_name, arguments)`` pair for ``list_files``.

        The path is the ``project_name`` recorded under the
        ``"create_project"`` step output when present and non-empty,
        otherwise ``context.project_dir``.
        """
        creation_output = context.step_outputs.get("create_project", {})
        recorded_name = creation_output.get("project_name", "")
        target = recorded_name or context.project_dir
        return ("list_files", {"path": target})

    def handle_result(self, result: Any, context: UserContext) -> StepResult:
        """Translate the tool's raw return value into a ``StepResult``.

        Accepts either a dict (successful when ``"success"`` is truthy or a
        ``"files"`` key exists) or a plain list of files; anything else is
        reported as an unexpected format.
        """
        if isinstance(result, dict):
            # Fails only when neither success indicator is present.
            if not result.get("success") and "files" not in result:
                return StepResult.make_error(
                    "Failed to list files",
                    result.get("error", "Unknown error"),
                    ErrorCategory.UNKNOWN,
                )
            listed = result.get("files", [])
            return StepResult.ok(f"Found {len(listed)} files", files=listed)

        # The tool may also hand back the listing directly.
        if isinstance(result, list):
            return StepResult.ok(f"Found {len(result)} files", files=result)

        return StepResult.make_error(
            "Unexpected result format", str(result), ErrorCategory.UNKNOWN
        )
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
@dataclass
class ValidateProjectStep(BaseStep):
    """Workflow step that invokes the ``validate_project`` tool."""

    name: str = "validate_project"
    description: str = "Validate project structure"

    def get_tool_invocation(
        self, context: UserContext
    ) -> Optional[Tuple[str, Dict[str, Any]]]:
        """Build the ``(tool_name, arguments)`` pair for ``validate_project``.

        The project path is the ``project_name`` recorded under the
        ``"create_project"`` step output when present and non-empty,
        otherwise ``context.project_dir``. ``fix`` is always sent as ``True``.
        """
        creation_output = context.step_outputs.get("create_project", {})
        recorded_name = creation_output.get("project_name", "")
        arguments: Dict[str, Any] = {
            "project_path": recorded_name or context.project_dir,
            "fix": True,
        }
        return ("validate_project", arguments)

    def handle_result(self, result: Any, context: UserContext) -> StepResult:
        """Translate the tool's raw return value into a ``StepResult``.

        A dict with a truthy ``"valid"`` or ``"success"`` entry is a success;
        any other dict yields a warning carrying its ``"issues"`` entry; a
        non-dict value is reported as an unexpected format.
        """
        if not isinstance(result, dict):
            return StepResult.make_error(
                "Unexpected result format", str(result), ErrorCategory.UNKNOWN
            )

        if result.get("valid") or result.get("success"):
            return StepResult.ok("Project structure validated")

        # Validation problems are surfaced as a warning, not an error.
        found_issues = result.get("issues", [])
        return StepResult.warning("Project has issues", issues=found_issues)
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
@dataclass
class AutoFixSyntaxStep(BaseStep):
    """Workflow step that invokes the ``auto_fix_syntax_errors`` tool."""

    name: str = "auto_fix_syntax"
    description: str = "Fix syntax errors"

    def get_tool_invocation(
        self, context: UserContext
    ) -> Optional[Tuple[str, Dict[str, Any]]]:
        """Build the ``(tool_name, arguments)`` pair for ``auto_fix_syntax_errors``.

        The project path is the ``project_name`` recorded under the
        ``"create_project"`` step output when present and non-empty,
        otherwise ``context.project_dir``.
        """
        creation_output = context.step_outputs.get("create_project", {})
        recorded_name = creation_output.get("project_name", "")
        target = recorded_name or context.project_dir
        return ("auto_fix_syntax_errors", {"project_path": target})

    def handle_result(self, result: Any, context: UserContext) -> StepResult:
        """Translate the tool's raw return value into a ``StepResult``.

        On a successful dict result, a positive ``"files_fixed"`` count
        produces a warning (something had to be repaired) and any other
        count produces a plain success. An unsuccessful dict is a syntax
        error; a non-dict value is reported as an unexpected format.
        """
        if not isinstance(result, dict):
            return StepResult.make_error(
                "Unexpected result format", str(result), ErrorCategory.UNKNOWN
            )

        if not result.get("success"):
            return StepResult.make_error(
                "Failed to fix syntax errors",
                result.get("error", "Unknown error"),
                ErrorCategory.SYNTAX,
            )

        repaired = result.get("files_fixed", 0)
        if repaired > 0:
            return StepResult.warning(
                f"Fixed syntax errors in {repaired} files",
                files_fixed=repaired,
            )
        return StepResult.ok("No syntax errors found")
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
@dataclass
class AnalyzePylintStep(BaseStep):
    """Workflow step that invokes the ``analyze_with_pylint`` tool."""

    name: str = "analyze_pylint"
    description: str = "Run pylint analysis"

    def get_tool_invocation(
        self, context: UserContext
    ) -> Optional[Tuple[str, Dict[str, Any]]]:
        """Build the ``(tool_name, arguments)`` pair for ``analyze_with_pylint``.

        ``file_path`` is the ``project_name`` recorded under the
        ``"create_project"`` step output when present and non-empty,
        otherwise ``context.project_dir``.
        """
        creation_output = context.step_outputs.get("create_project", {})
        recorded_name = creation_output.get("project_name", "")
        target = recorded_name or context.project_dir
        return ("analyze_with_pylint", {"file_path": target})

    def handle_result(self, result: Any, context: UserContext) -> StepResult:
        """Translate the tool's raw return value into a ``StepResult``.

        For a dict result: a ``"score"`` of at least 8.0 is a success; a
        lower score with a non-empty ``"issues"`` entry is a warning; a lower
        score without issues is still a success. A non-dict value is
        reported as an unexpected format.
        """
        if not isinstance(result, dict):
            return StepResult.make_error(
                "Unexpected result format", str(result), ErrorCategory.UNKNOWN
            )

        rating = result.get("score", 0)
        findings = result.get("issues", [])

        if rating >= 8.0:
            return StepResult.ok(f"Pylint score: {rating}/10", score=rating)

        if findings:
            summary = f"Pylint score: {rating}/10 ({len(findings)} issues)"
            return StepResult.warning(summary, score=rating, issues=findings)

        return StepResult.ok(
            f"Pylint completed with score: {rating}/10", score=rating
        )
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
@dataclass
class FixLintingStep(BaseStep):
    """Workflow step that invokes the ``fix_linting_errors`` tool."""

    name: str = "fix_linting"
    description: str = "Fix linting issues"

    def should_skip(self, context: UserContext) -> Optional[str]:
        """Return a skip reason when the recorded pylint score is at least 8.0.

        The score is read from the ``"analyze_pylint"`` step output and
        defaults to 0 when missing. Returns ``None`` when the step should run.
        """
        rating = context.step_outputs.get("analyze_pylint", {}).get("score", 0)
        return (
            f"Pylint score {rating}/10 is good, no fixing needed"
            if rating >= 8.0
            else None
        )

    def get_tool_invocation(
        self, context: UserContext
    ) -> Optional[Tuple[str, Dict[str, Any]]]:
        """Build the ``(tool_name, arguments)`` pair for ``fix_linting_errors``.

        The project path is the ``project_name`` recorded under the
        ``"create_project"`` step output when present and non-empty,
        otherwise ``context.project_dir``.
        """
        creation_output = context.step_outputs.get("create_project", {})
        recorded_name = creation_output.get("project_name", "")
        target = recorded_name or context.project_dir
        return ("fix_linting_errors", {"project_path": target})

    def handle_result(self, result: Any, context: UserContext) -> StepResult:
        """Translate the tool's raw return value into a ``StepResult``.

        A successful dict reports how many files were fixed; an unsuccessful
        dict becomes a warning carrying its ``"error"`` entry; a non-dict
        value is reported as an unexpected format.
        """
        if not isinstance(result, dict):
            return StepResult.make_error(
                "Unexpected result format", str(result), ErrorCategory.UNKNOWN
            )

        if not result.get("success"):
            # Leftover lint problems are a warning rather than an error.
            return StepResult.warning(
                "Some linting issues could not be auto-fixed",
                error=result.get("error", ""),
            )

        repaired = result.get("files_fixed", 0)
        return StepResult.ok(f"Fixed linting issues in {repaired} files")
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
@dataclass
class RunPytestStep(BaseStep):
    """Workflow step that invokes the ``run_tests`` tool."""

    name: str = "run_tests"
    description: str = "Run pytest"

    def get_tool_invocation(
        self, context: UserContext
    ) -> Optional[Tuple[str, Dict[str, Any]]]:
        """Build the ``(tool_name, arguments)`` pair for ``run_tests``.

        The project path is the ``project_name`` recorded under the
        ``"create_project"`` step output when present and non-empty,
        otherwise ``context.project_dir``.
        """
        creation_output = context.step_outputs.get("create_project", {})
        recorded_name = creation_output.get("project_name", "")
        target = recorded_name or context.project_dir
        return ("run_tests", {"project_path": target})

    def handle_result(self, result: Any, context: UserContext) -> StepResult:
        """Translate the tool's raw return value into a ``StepResult``.

        A dict counts as passing when ``"tests_passed"`` is truthy or
        ``"return_code"`` equals 0 (missing keys default to ``False`` and
        ``1``). A failing dict yields a warning with the failure count and
        captured output; a non-dict value is reported as an unexpected format.
        """
        if not isinstance(result, dict):
            return StepResult.make_error(
                "Unexpected result format", str(result), ErrorCategory.UNKNOWN
            )

        passed_flag = result.get("tests_passed", False)
        exit_status = result.get("return_code", 1)

        if passed_flag or exit_status == 0:
            pass_count = result.get("passed", 0)
            return StepResult.ok(
                f"All tests passed ({pass_count} tests)", passed=pass_count
            )

        fail_count = result.get("failed", 0)
        return StepResult.warning(
            f"Some tests failed ({fail_count} failures)",
            failed=fail_count,
            output=result.get("output", ""),
        )
|
|
1
|
+
# Copyright(C) 2025-2026 Advanced Micro Devices, Inc. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: MIT
|
|
3
|
+
"""
|
|
4
|
+
Python step implementations.
|
|
5
|
+
|
|
6
|
+
Steps wrap the existing Code Agent tools with standardized interfaces.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from dataclasses import dataclass
|
|
10
|
+
from typing import Any, Dict, Optional, Tuple
|
|
11
|
+
|
|
12
|
+
from .base import BaseStep, ErrorCategory, StepResult, UserContext
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@dataclass
class CreateProjectStep(BaseStep):
    """Workflow step that invokes the ``create_project`` tool.

    The request text sent to the tool is this step's own ``user_request``
    when it is non-empty, otherwise the ``user_request`` held by the context.
    """

    name: str = "create_project"
    description: str = "Generate Python project"
    # Optional per-step request text; an empty string defers to the context.
    user_request: str = ""

    def get_tool_invocation(
        self, context: UserContext
    ) -> Optional[Tuple[str, Dict[str, Any]]]:
        """Build the ``(tool_name, arguments)`` pair for ``create_project``."""
        query = self.user_request or context.user_request
        arguments: Dict[str, Any] = {"query": query}
        return ("create_project", arguments)

    def handle_result(self, result: Any, context: UserContext) -> StepResult:
        """Translate the tool's raw return value into a ``StepResult``.

        A non-dict value is reported as an unexpected format; a dict without
        a truthy ``"success"`` entry is reported as a creation failure;
        otherwise the ``"project_name"`` and ``"files"`` entries are passed
        along on a successful result.
        """
        if not isinstance(result, dict):
            return StepResult.make_error(
                "Unexpected result format", str(result), ErrorCategory.UNKNOWN
            )

        if not result.get("success"):
            failure_detail = result.get("error", "Unknown error")
            return StepResult.make_error(
                "Failed to create project",
                failure_detail,
                ErrorCategory.COMPILATION,
            )

        # Project details ride along as keyword data on the result
        # (presumably picked up by later steps via the context -- verify
        # against the caller).
        created_name = result.get("project_name", "")
        created_files = result.get("files", [])
        return StepResult.ok(
            f"Project {created_name} created",
            project_name=created_name,
            files=created_files,
        )
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
@dataclass
class ListFilesStep(BaseStep):
    """Workflow step that invokes the ``list_files`` tool on the project."""

    name: str = "list_files"
    description: str = "List project files"

    def get_tool_invocation(
        self, context: UserContext
    ) -> Optional[Tuple[str, Dict[str, Any]]]:
        """Build the ``(tool_name, arguments)`` pair for ``list_files``.

        The path is the ``project_name`` recorded under the
        ``"create_project"`` step output when present and non-empty,
        otherwise ``context.project_dir``.
        """
        creation_output = context.step_outputs.get("create_project", {})
        recorded_name = creation_output.get("project_name", "")
        target = recorded_name or context.project_dir
        return ("list_files", {"path": target})

    def handle_result(self, result: Any, context: UserContext) -> StepResult:
        """Translate the tool's raw return value into a ``StepResult``.

        Accepts either a dict (successful when ``"success"`` is truthy or a
        ``"files"`` key exists) or a plain list of files; anything else is
        reported as an unexpected format.
        """
        if isinstance(result, dict):
            # Fails only when neither success indicator is present.
            if not result.get("success") and "files" not in result:
                return StepResult.make_error(
                    "Failed to list files",
                    result.get("error", "Unknown error"),
                    ErrorCategory.UNKNOWN,
                )
            listed = result.get("files", [])
            return StepResult.ok(f"Found {len(listed)} files", files=listed)

        # The tool may also hand back the listing directly.
        if isinstance(result, list):
            return StepResult.ok(f"Found {len(result)} files", files=result)

        return StepResult.make_error(
            "Unexpected result format", str(result), ErrorCategory.UNKNOWN
        )
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
@dataclass
class ValidateProjectStep(BaseStep):
    """Workflow step that invokes the ``validate_project`` tool."""

    name: str = "validate_project"
    description: str = "Validate project structure"

    def get_tool_invocation(
        self, context: UserContext
    ) -> Optional[Tuple[str, Dict[str, Any]]]:
        """Build the ``(tool_name, arguments)`` pair for ``validate_project``.

        The project path is the ``project_name`` recorded under the
        ``"create_project"`` step output when present and non-empty,
        otherwise ``context.project_dir``. ``fix`` is always sent as ``True``.
        """
        creation_output = context.step_outputs.get("create_project", {})
        recorded_name = creation_output.get("project_name", "")
        arguments: Dict[str, Any] = {
            "project_path": recorded_name or context.project_dir,
            "fix": True,
        }
        return ("validate_project", arguments)

    def handle_result(self, result: Any, context: UserContext) -> StepResult:
        """Translate the tool's raw return value into a ``StepResult``.

        A dict with a truthy ``"valid"`` or ``"success"`` entry is a success;
        any other dict yields a warning carrying its ``"issues"`` entry; a
        non-dict value is reported as an unexpected format.
        """
        if not isinstance(result, dict):
            return StepResult.make_error(
                "Unexpected result format", str(result), ErrorCategory.UNKNOWN
            )

        if result.get("valid") or result.get("success"):
            return StepResult.ok("Project structure validated")

        # Validation problems are surfaced as a warning, not an error.
        found_issues = result.get("issues", [])
        return StepResult.warning("Project has issues", issues=found_issues)
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
@dataclass
class AutoFixSyntaxStep(BaseStep):
    """Workflow step that invokes the ``auto_fix_syntax_errors`` tool."""

    name: str = "auto_fix_syntax"
    description: str = "Fix syntax errors"

    def get_tool_invocation(
        self, context: UserContext
    ) -> Optional[Tuple[str, Dict[str, Any]]]:
        """Build the ``(tool_name, arguments)`` pair for ``auto_fix_syntax_errors``.

        The project path is the ``project_name`` recorded under the
        ``"create_project"`` step output when present and non-empty,
        otherwise ``context.project_dir``.
        """
        creation_output = context.step_outputs.get("create_project", {})
        recorded_name = creation_output.get("project_name", "")
        target = recorded_name or context.project_dir
        return ("auto_fix_syntax_errors", {"project_path": target})

    def handle_result(self, result: Any, context: UserContext) -> StepResult:
        """Translate the tool's raw return value into a ``StepResult``.

        On a successful dict result, a positive ``"files_fixed"`` count
        produces a warning (something had to be repaired) and any other
        count produces a plain success. An unsuccessful dict is a syntax
        error; a non-dict value is reported as an unexpected format.
        """
        if not isinstance(result, dict):
            return StepResult.make_error(
                "Unexpected result format", str(result), ErrorCategory.UNKNOWN
            )

        if not result.get("success"):
            return StepResult.make_error(
                "Failed to fix syntax errors",
                result.get("error", "Unknown error"),
                ErrorCategory.SYNTAX,
            )

        repaired = result.get("files_fixed", 0)
        if repaired > 0:
            return StepResult.warning(
                f"Fixed syntax errors in {repaired} files",
                files_fixed=repaired,
            )
        return StepResult.ok("No syntax errors found")
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
@dataclass
class AnalyzePylintStep(BaseStep):
    """Workflow step that invokes the ``analyze_with_pylint`` tool."""

    name: str = "analyze_pylint"
    description: str = "Run pylint analysis"

    def get_tool_invocation(
        self, context: UserContext
    ) -> Optional[Tuple[str, Dict[str, Any]]]:
        """Build the ``(tool_name, arguments)`` pair for ``analyze_with_pylint``.

        ``file_path`` is the ``project_name`` recorded under the
        ``"create_project"`` step output when present and non-empty,
        otherwise ``context.project_dir``.
        """
        creation_output = context.step_outputs.get("create_project", {})
        recorded_name = creation_output.get("project_name", "")
        target = recorded_name or context.project_dir
        return ("analyze_with_pylint", {"file_path": target})

    def handle_result(self, result: Any, context: UserContext) -> StepResult:
        """Translate the tool's raw return value into a ``StepResult``.

        For a dict result: a ``"score"`` of at least 8.0 is a success; a
        lower score with a non-empty ``"issues"`` entry is a warning; a lower
        score without issues is still a success. A non-dict value is
        reported as an unexpected format.
        """
        if not isinstance(result, dict):
            return StepResult.make_error(
                "Unexpected result format", str(result), ErrorCategory.UNKNOWN
            )

        rating = result.get("score", 0)
        findings = result.get("issues", [])

        if rating >= 8.0:
            return StepResult.ok(f"Pylint score: {rating}/10", score=rating)

        if findings:
            summary = f"Pylint score: {rating}/10 ({len(findings)} issues)"
            return StepResult.warning(summary, score=rating, issues=findings)

        return StepResult.ok(
            f"Pylint completed with score: {rating}/10", score=rating
        )
|
|
220
|
+
|
|
221
|
+
|
|
222
|
+
@dataclass
class FixLintingStep(BaseStep):
    """Workflow step that invokes the ``fix_linting_errors`` tool."""

    name: str = "fix_linting"
    description: str = "Fix linting issues"

    def should_skip(self, context: UserContext) -> Optional[str]:
        """Return a skip reason when the recorded pylint score is at least 8.0.

        The score is read from the ``"analyze_pylint"`` step output and
        defaults to 0 when missing. Returns ``None`` when the step should run.
        """
        rating = context.step_outputs.get("analyze_pylint", {}).get("score", 0)
        return (
            f"Pylint score {rating}/10 is good, no fixing needed"
            if rating >= 8.0
            else None
        )

    def get_tool_invocation(
        self, context: UserContext
    ) -> Optional[Tuple[str, Dict[str, Any]]]:
        """Build the ``(tool_name, arguments)`` pair for ``fix_linting_errors``.

        The project path is the ``project_name`` recorded under the
        ``"create_project"`` step output when present and non-empty,
        otherwise ``context.project_dir``.
        """
        creation_output = context.step_outputs.get("create_project", {})
        recorded_name = creation_output.get("project_name", "")
        target = recorded_name or context.project_dir
        return ("fix_linting_errors", {"project_path": target})

    def handle_result(self, result: Any, context: UserContext) -> StepResult:
        """Translate the tool's raw return value into a ``StepResult``.

        A successful dict reports how many files were fixed; an unsuccessful
        dict becomes a warning carrying its ``"error"`` entry; a non-dict
        value is reported as an unexpected format.
        """
        if not isinstance(result, dict):
            return StepResult.make_error(
                "Unexpected result format", str(result), ErrorCategory.UNKNOWN
            )

        if not result.get("success"):
            # Leftover lint problems are a warning rather than an error.
            return StepResult.warning(
                "Some linting issues could not be auto-fixed",
                error=result.get("error", ""),
            )

        repaired = result.get("files_fixed", 0)
        return StepResult.ok(f"Fixed linting issues in {repaired} files")
|
|
264
|
+
|
|
265
|
+
|
|
266
|
+
@dataclass
class RunPytestStep(BaseStep):
    """Workflow step that invokes the ``run_tests`` tool."""

    name: str = "run_tests"
    description: str = "Run pytest"

    def get_tool_invocation(
        self, context: UserContext
    ) -> Optional[Tuple[str, Dict[str, Any]]]:
        """Build the ``(tool_name, arguments)`` pair for ``run_tests``.

        The project path is the ``project_name`` recorded under the
        ``"create_project"`` step output when present and non-empty,
        otherwise ``context.project_dir``.
        """
        creation_output = context.step_outputs.get("create_project", {})
        recorded_name = creation_output.get("project_name", "")
        target = recorded_name or context.project_dir
        return ("run_tests", {"project_path": target})

    def handle_result(self, result: Any, context: UserContext) -> StepResult:
        """Translate the tool's raw return value into a ``StepResult``.

        A dict counts as passing when ``"tests_passed"`` is truthy or
        ``"return_code"`` equals 0 (missing keys default to ``False`` and
        ``1``). A failing dict yields a warning with the failure count and
        captured output; a non-dict value is reported as an unexpected format.
        """
        if not isinstance(result, dict):
            return StepResult.make_error(
                "Unexpected result format", str(result), ErrorCategory.UNKNOWN
            )

        passed_flag = result.get("tests_passed", False)
        exit_status = result.get("return_code", 1)

        if passed_flag or exit_status == 0:
            pass_count = result.get("passed", 0)
            return StepResult.ok(
                f"All tests passed ({pass_count} tests)", passed=pass_count
            )

        fail_count = result.get("failed", 0)
        return StepResult.warning(
            f"Some tests failed ({fail_count} failures)",
            failed=fail_count,
            output=result.get("output", ""),
        )
|