@aws/ml-container-creator 1.0.4 → 1.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +9 -0
- package/bin/cli.js +57 -0
- package/config/agent.json +16 -0
- package/package.json +4 -1
- package/pyproject.toml +3 -0
- package/servers/agent-knowledge/index.js +592 -0
- package/servers/agent-knowledge/package.json +15 -0
- package/src/agent/__init__.py +2 -0
- package/src/agent/__pycache__/__init__.cpython-312.pyc +0 -0
- package/src/agent/__pycache__/config_loader.cpython-312.pyc +0 -0
- package/src/agent/__pycache__/context.cpython-312.pyc +0 -0
- package/src/agent/__pycache__/health_check.cpython-312.pyc +0 -0
- package/src/agent/agent.py +513 -0
- package/src/agent/config_loader.py +215 -0
- package/src/agent/context.py +380 -0
- package/src/agent/data/capability-matrix.json +106 -0
- package/src/agent/health_check.py +341 -0
- package/src/agent/prompts/system.md +173 -0
- package/src/agent/requirements-agent.txt +3 -0
- package/src/lib/generated/cli-options.js +1 -1
- package/src/lib/generated/parameter-matrix.js +1 -1
- package/src/lib/generated/validation-rules.js +1 -1
- package/src/lib/tune-config-state.js +89 -68
- package/templates/do/config +6 -1
- package/src/lib/auto-prompt-builder.js +0 -172
- package/src/lib/cli-handler.js +0 -529
- package/src/lib/community-reports-validator.js +0 -91
- package/src/lib/configuration-exporter.js +0 -204
- package/src/lib/dataset-slug.js +0 -152
- package/src/lib/docker-introspection-validator.js +0 -51
- package/src/lib/known-flags-validator.js +0 -200
- package/src/lib/schema-validator.js +0 -157
- package/src/lib/train-config-parser.js +0 -136
- package/src/lib/train-config-persistence.js +0 -143
- package/src/lib/train-config-validator.js +0 -112
- package/src/lib/train-feedback.js +0 -46
- package/src/lib/train-idempotency.js +0 -97
- package/src/lib/train-request-builder.js +0 -120
- package/src/lib/tune-dataset-validator.js +0 -279
- package/src/lib/tune-output-resolver.js +0 -66
|
@@ -0,0 +1,513 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
"""ml-container-creator hey — Advisory agent powered by Strands.
|
|
5
|
+
|
|
6
|
+
Entry point for the interactive REPL that connects to MCP servers
|
|
7
|
+
and provides ML infrastructure guidance via Claude on Bedrock.
|
|
8
|
+
|
|
9
|
+
Usage:
|
|
10
|
+
python3 src/agent/agent.py --project-dir <path> [--offline|-o]
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
from __future__ import annotations
|
|
14
|
+
|
|
15
|
+
import json
|
|
16
|
+
import os
|
|
17
|
+
import signal
|
|
18
|
+
import sys
|
|
19
|
+
from pathlib import Path
|
|
20
|
+
from typing import Any
|
|
21
|
+
|
|
22
|
+
# Request unbuffered Python output via the environment, without overriding a
# value the caller already set (setdefault).
# NOTE(review): CPython reads PYTHONUNBUFFERED at interpreter startup. This
# assignment therefore probably cannot change buffering for the
# already-running process. At most, child processes that inherit this
# environment would see it. Confirm the intent; sys.stdout.reconfigure(...)
# is what would affect this process.
os.environ.setdefault("PYTHONUNBUFFERED", "1")
|
|
23
|
+
|
|
24
|
+
from strands import Agent, tool
|
|
25
|
+
from strands.tools.mcp import MCPClient
|
|
26
|
+
from mcp.client.stdio import StdioServerParameters, stdio_client
|
|
27
|
+
|
|
28
|
+
from config_loader import load_agent_config
|
|
29
|
+
from context import ProjectContext
|
|
30
|
+
from health_check import EnvironmentHealthCheck, print_health_report
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
# ─── Constants ────────────────────────────────────────────────────────────────
|
|
34
|
+
|
|
35
|
+
# Directory holding this module (src/agent/). The prompt template and the
# capability matrix ship alongside it.
_AGENT_DIR = Path(__file__).resolve().parent
# Package root: two levels above src/agent/.
_PACKAGE_ROOT = _AGENT_DIR.parent.parent
_MCP_CONFIG_PATH = _PACKAGE_ROOT / "config" / "mcp.json"
_SYSTEM_PROMPT_PATH = _AGENT_DIR / "prompts" / "system.md"
_CAPABILITY_MATRIX_PATH = _AGENT_DIR / "data" / "capability-matrix.json"
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
# ─── write_file tool ──────────────────────────────────────────────────────────
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def _create_write_file_tool(project_dir: Path):
    """Create a ``write_file`` tool confined to *project_dir*.

    The returned tool resolves each requested path, following ``..`` and
    symlinks. It refuses to write anywhere that does not land inside the
    project root, which prevents path traversal. It also refuses to treat a
    directory (including the root itself) as a file. I/O failures are
    reported back as an ``Error:`` string instead of raising out of the tool.

    Args:
        project_dir: Path to the project root. It is resolved again here, so
            a relative or symlinked path still yields a correct containment
            check.

    Returns:
        A Strands ``@tool``-decorated function.
    """
    # Compare canonical paths even if the caller passed an unresolved path.
    root = Path(project_dir).resolve()

    @tool
    def write_file(file_path: str, content: str) -> str:
        """Write content to a file within the project directory.

        Use this to save action plans, TODO lists, or recommendation summaries.
        The file path must be relative to the project root. Parent directories
        are created automatically.

        Args:
            file_path: Relative path within the project (e.g., "TODO.md", "docs/plan.md").
            content: Text content to write to the file.

        Returns:
            Confirmation message with the absolute path written.
        """
        # Resolve the target path and validate that it stays within the root.
        target = (root / file_path).resolve()
        try:
            target.relative_to(root)
        except ValueError:
            return f"Error: path '{file_path}' escapes the project directory. Refusing to write."

        # "", "." or an existing directory name would resolve to a directory;
        # write_text on it would raise, so reject it up front.
        if target == root or target.is_dir():
            return f"Error: path '{file_path}' is a directory, not a file. Refusing to write."

        try:
            target.parent.mkdir(parents=True, exist_ok=True)
            target.write_text(content, encoding="utf-8")
        except OSError as e:
            # Permission denied, read-only filesystem, etc. Surface the failure
            # to the model in the same "Error: ..." form as the checks above.
            return f"Error: could not write '{file_path}': {e}"

        return f"Written to {target}"

    return write_file
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
# ─── MCP Server Management ───────────────────────────────────────────────────
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def _load_mcp_config() -> dict[str, Any]:
|
|
93
|
+
"""Load and parse config/mcp.json from the package root.
|
|
94
|
+
|
|
95
|
+
Returns:
|
|
96
|
+
Dict of server configurations from mcpServers key.
|
|
97
|
+
|
|
98
|
+
Raises:
|
|
99
|
+
SystemExit: If the config file is missing or unparseable.
|
|
100
|
+
"""
|
|
101
|
+
if not _MCP_CONFIG_PATH.is_file():
|
|
102
|
+
print(f"\033[31mError:\033[0m config/mcp.json not found at {_MCP_CONFIG_PATH}")
|
|
103
|
+
sys.exit(1)
|
|
104
|
+
|
|
105
|
+
try:
|
|
106
|
+
data = json.loads(_MCP_CONFIG_PATH.read_text(encoding="utf-8"))
|
|
107
|
+
return data.get("mcpServers", {})
|
|
108
|
+
except (json.JSONDecodeError, OSError) as e:
|
|
109
|
+
print(f"\033[31mError:\033[0m Cannot parse config/mcp.json: {e}")
|
|
110
|
+
sys.exit(1)
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
def _make_stdio_client(command: str, args: list[Any]) -> MCPClient:
    """Build an MCPClient whose transport launches ``command *args`` over stdio."""
    server_params = StdioServerParameters(command=command, args=args)
    # Bind the params as a default argument so each client's factory keeps its
    # own server rather than late-binding to whichever was created last.
    return MCPClient(lambda sp=server_params: stdio_client(sp))


def _resolve_server_arg(arg: Any) -> Any:
    """Return *arg* as an absolute path if it names a file under the package root.

    Non-string args and strings that are not package-relative files (flags,
    absolute paths, values) are returned unchanged.
    """
    if isinstance(arg, str):
        candidate = _PACKAGE_ROOT / arg
        if candidate.is_file():
            return str(candidate)
    return arg


def _start_mcp_servers(
    server_names: frozenset[str],
    timeout: int = 30,
) -> list[MCPClient]:
    """Create MCP clients for the servers the advisory agent needs.

    Steps:

    1. Read ``config/mcp.json`` and keep only the entries named in
       *server_names*.
    2. Resolve package-relative script paths against the package root.
    3. Build one stdio ``MCPClient`` per entry.

    If ``agent-knowledge`` is requested but not configured, it is added
    explicitly from ``servers/agent-knowledge/index.js``.

    Failures to build an individual client are printed as warnings and
    skipped, so one bad entry does not prevent the others from loading.
    NOTE(review): this only constructs the clients. Presumably they are
    actually launched when handed to the Agent; confirm.

    Args:
        server_names: Names of the MCP servers to connect to.
        timeout: Seconds to wait for each MCP server to start (reserved for
            future use; currently unused).

    Returns:
        List of constructed MCPClient instances.
    """
    all_servers = _load_mcp_config()
    clients: list[MCPClient] = []

    for name, server_config in all_servers.items():
        if name not in server_names:
            continue
        if not isinstance(server_config, dict):
            print(f" \033[33m⚠\033[0m Skipping MCP server '{name}': config entry is not a JSON object")
            continue

        command = server_config.get("command", "node")
        resolved_args = [_resolve_server_arg(arg) for arg in server_config.get("args", [])]
        try:
            clients.append(_make_stdio_client(command, resolved_args))
        except Exception as e:
            print(f" \033[33m⚠\033[0m Could not start MCP server '{name}': {e}")

    # The bundled agent-knowledge server may not be listed in mcp.json; fall
    # back to its in-package location.
    if "agent-knowledge" in server_names and "agent-knowledge" not in all_servers:
        knowledge_path = _PACKAGE_ROOT / "servers" / "agent-knowledge" / "index.js"
        if knowledge_path.is_file():
            try:
                clients.append(_make_stdio_client("node", [str(knowledge_path)]))
            except Exception as e:
                print(f" \033[33m⚠\033[0m Could not start agent-knowledge server: {e}")
        else:
            print(f" \033[33m⚠\033[0m agent-knowledge server not found at {knowledge_path}")

    return clients
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
def _stop_mcp_servers(clients: list[MCPClient]) -> None:
    """Shut down every MCP client in *clients*, ignoring individual failures.

    Each client's ``stop`` is called with three ``None`` arguments. This
    presumably mirrors the no-exception form of ``__exit__(exc_type, exc, tb)``.

    Args:
        clients: MCPClient instances to shut down.
    """
    for mcp_client in clients:
        try:
            mcp_client.stop(None, None, None)
        except Exception:
            # Best-effort cleanup: a client that fails to stop must not keep
            # the remaining ones from being stopped.
            continue
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
# ─── System Prompt Construction ───────────────────────────────────────────────
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
def _build_system_prompt(context: dict[str, Any]) -> str:
|
|
187
|
+
"""Build the system prompt by loading the template and injecting context.
|
|
188
|
+
|
|
189
|
+
Substitutes placeholders:
|
|
190
|
+
- {project_context_json} — serialized project context
|
|
191
|
+
- {capability_matrix_json} — capability matrix data
|
|
192
|
+
- {user_context_md} — user-provided context markdown (or empty)
|
|
193
|
+
|
|
194
|
+
Args:
|
|
195
|
+
context: Project context dict from ProjectContext.load().
|
|
196
|
+
|
|
197
|
+
Returns:
|
|
198
|
+
Fully rendered system prompt string.
|
|
199
|
+
"""
|
|
200
|
+
# Load the prompt template
|
|
201
|
+
if _SYSTEM_PROMPT_PATH.is_file():
|
|
202
|
+
template = _SYSTEM_PROMPT_PATH.read_text(encoding="utf-8")
|
|
203
|
+
else:
|
|
204
|
+
template = "You are the ml-container-creator advisor.\n\n{project_context_json}"
|
|
205
|
+
|
|
206
|
+
# Load capability matrix
|
|
207
|
+
capability_matrix = "{}"
|
|
208
|
+
if _CAPABILITY_MATRIX_PATH.is_file():
|
|
209
|
+
try:
|
|
210
|
+
capability_matrix = _CAPABILITY_MATRIX_PATH.read_text(encoding="utf-8")
|
|
211
|
+
except OSError:
|
|
212
|
+
pass
|
|
213
|
+
|
|
214
|
+
# Serialize project context (exclude internal fields)
|
|
215
|
+
context_json = json.dumps(context, indent=2, default=str)
|
|
216
|
+
|
|
217
|
+
# Extract user context
|
|
218
|
+
user_context = context.get("user_context") or "No user-provided context file found."
|
|
219
|
+
|
|
220
|
+
# Perform substitutions
|
|
221
|
+
prompt = template.replace("{project_context_json}", context_json)
|
|
222
|
+
prompt = prompt.replace("{capability_matrix_json}", capability_matrix)
|
|
223
|
+
prompt = prompt.replace("{user_context_md}", user_context)
|
|
224
|
+
|
|
225
|
+
return prompt
|
|
226
|
+
|
|
227
|
+
|
|
228
|
+
# ─── Cost Tracking ────────────────────────────────────────────────────────────
|
|
229
|
+
|
|
230
|
+
|
|
231
|
+
class CostTracker:
    """Running tally of token usage and estimated spend for one REPL session.

    Token counts are whatever the caller reports through ``record_turn``.
    The USD cost is derived from per-1K-token input and output prices. The
    defaults ($0.003 in / $0.015 out per 1K tokens) are intended to match
    Claude Sonnet pricing. NOTE(review): verify against current Bedrock
    pricing for the configured model.
    """

    def __init__(self, input_cost_per_1k: float = 0.003, output_cost_per_1k: float = 0.015) -> None:
        # Prices in USD per 1,000 tokens.
        self._input_cost_per_1k = input_cost_per_1k
        self._output_cost_per_1k = output_cost_per_1k
        # Session-wide running totals.
        self.input_tokens: int = 0
        self.output_tokens: int = 0
        self.turns: int = 0

    def record_turn(self, input_tokens: int = 0, output_tokens: int = 0) -> None:
        """Add one turn's token usage to the session totals.

        Args:
            input_tokens: Input tokens consumed by the turn.
            output_tokens: Output tokens generated by the turn.
        """
        self.input_tokens = self.input_tokens + input_tokens
        self.output_tokens = self.output_tokens + output_tokens
        self.turns = self.turns + 1

    @property
    def estimated_cost(self) -> float:
        """Estimated USD cost of all recorded turns."""
        per_thousand = 1000
        spent_on_input = (self.input_tokens / per_thousand) * self._input_cost_per_1k
        spent_on_output = (self.output_tokens / per_thousand) * self._output_cost_per_1k
        return spent_on_input + spent_on_output

    def print_summary(self) -> None:
        """Print the session totals to stdout; prints nothing if no turns were recorded."""
        if not self.turns:
            return

        summary_lines = (
            "\n\033[1mSession Summary\033[0m",
            "─" * 40,
            f" Turns: {self.turns}",
            f" Input tokens: ~{self.input_tokens:,}",
            f" Output tokens: ~{self.output_tokens:,}",
            f" Estimated cost: ~${self.estimated_cost:.4f}",
            "",
        )
        for line in summary_lines:
            print(line)
|
|
275
|
+
|
|
276
|
+
|
|
277
|
+
# ─── CLI Argument Parsing ─────────────────────────────────────────────────────
|
|
278
|
+
|
|
279
|
+
|
|
280
|
+
def _parse_args() -> tuple[str, bool]:
|
|
281
|
+
"""Parse command-line arguments.
|
|
282
|
+
|
|
283
|
+
Returns:
|
|
284
|
+
Tuple of (project_dir, offline_mode).
|
|
285
|
+
"""
|
|
286
|
+
args = sys.argv[1:]
|
|
287
|
+
project_dir = os.getcwd()
|
|
288
|
+
offline = False
|
|
289
|
+
|
|
290
|
+
i = 0
|
|
291
|
+
while i < len(args):
|
|
292
|
+
arg = args[i]
|
|
293
|
+
if arg == "--project-dir" and i + 1 < len(args):
|
|
294
|
+
project_dir = args[i + 1]
|
|
295
|
+
i += 2
|
|
296
|
+
elif arg in ("--offline", "-o"):
|
|
297
|
+
offline = True
|
|
298
|
+
i += 1
|
|
299
|
+
else:
|
|
300
|
+
i += 1
|
|
301
|
+
|
|
302
|
+
return project_dir, offline
|
|
303
|
+
|
|
304
|
+
|
|
305
|
+
# ─── REPL ─────────────────────────────────────────────────────────────────────
|
|
306
|
+
|
|
307
|
+
|
|
308
|
+
def _run_repl(
    agent: Agent,
    context: dict[str, Any],
    project_dir: str,
    cost: CostTracker,
    exit_commands: frozenset[str],
    reload_commands: frozenset[str],
) -> None:
    """Run the interactive question/answer loop until the user quits.

    Supports:
    - Configurable exit commands to quit (case-insensitive).
    - Configurable reload commands that re-read the project context and
      rebuild the agent's system prompt.
    - Ctrl+C / EOF at the prompt for a graceful exit. Ctrl+C during a
      response cancels only that response.

    Token accounting: usage is read from
    ``response.metrics.accumulated_usage`` when available. NOTE(review): in
    Strands, that metrics object appears to be the agent's own
    ``EventLoopMetrics``, which is reused and accumulated across invocations.
    Recording its totals every turn would count earlier turns again.
    Therefore, when the same metrics object comes back, only the increase
    since the previous turn is recorded; a new object is taken at face value.
    Without metrics, tokens are estimated as roughly two per word.

    Args:
        agent: Configured Strands Agent instance.
        context: Current project context dict; replaced in place on reload.
        project_dir: Path to the project directory.
        cost: CostTracker instance for session metrics.
        exit_commands: Lower-case commands that exit the REPL.
        reload_commands: Lower-case commands that reload project context.
    """
    print("\n\033[1mReady.\033[0m Type your question, or 'exit' to quit.\n")

    # Last cumulative usage snapshot, used to turn running totals into deltas.
    last_metrics: Any = None
    last_input_total = 0
    last_output_total = 0

    while True:
        try:
            user_input = input("\033[36myou:\033[0m ").strip()
        except (EOFError, KeyboardInterrupt):
            print("\n")
            break

        if not user_input:
            continue

        command = user_input.lower()
        if command in exit_commands:
            break

        if command in reload_commands:
            print(" Reloading project context...")
            try:
                new_context = ProjectContext(project_dir).load()
                agent.system_prompt = _build_system_prompt(new_context)
                # Replace rather than merge, so keys that disappeared from the
                # project do not linger in the shared context dict.
                context.clear()
                context.update(new_context)
                print(" \033[32m✓\033[0m Context reloaded.\n")
            except Exception as e:
                print(f" \033[31m✗\033[0m Reload failed: {e}\n")
            continue

        # Send to agent with streaming.
        try:
            print("\033[90magent:\033[0m ", end="", flush=True)
            response = agent(user_input)

            metrics = getattr(response, "metrics", None)
            if metrics and hasattr(metrics, "accumulated_usage"):
                usage = metrics.accumulated_usage or {}
                input_total = usage.get("inputTokens", 0) or 0
                output_total = usage.get("outputTokens", 0) or 0
                if metrics is last_metrics:
                    turn_input = max(input_total - last_input_total, 0)
                    turn_output = max(output_total - last_output_total, 0)
                else:
                    turn_input, turn_output = input_total, output_total
                last_metrics = metrics
                last_input_total, last_output_total = input_total, output_total
                cost.record_turn(input_tokens=turn_input, output_tokens=turn_output)
            else:
                # Fallback: approximate from word count.
                cost.record_turn(
                    input_tokens=len(user_input.split()) * 2,
                    output_tokens=len(str(response).split()) * 2,
                )

            print("\n")
        except KeyboardInterrupt:
            print("\n (interrupted)\n")
            continue
        except Exception as e:
            print(f"\n \033[31mError:\033[0m {e}\n")
            continue
|
|
387
|
+
|
|
388
|
+
|
|
389
|
+
# ─── Main ─────────────────────────────────────────────────────────────────────
|
|
390
|
+
|
|
391
|
+
|
|
392
|
+
def _load_startup_context(project_path: Path) -> tuple[bool, dict[str, Any]]:
    """Load the project context for *project_path* and print the startup banner.

    A directory counts as a project when it contains ``do/config``.

    Returns:
        ``(in_project, context)``: the ``ProjectContext`` dict for a project,
        or a ``{"mode": "getting-started"}`` stub otherwise.
    """
    if not (project_path / "do" / "config").is_file():
        print("\n\033[1m👋 Welcome to ml-container-creator!\033[0m")
        print(" No do/config found — running in getting-started mode.")
        return False, {"mode": "getting-started"}

    context = ProjectContext(str(project_path)).load()
    project_name = context.get("project_name") or project_path.name
    engine = context.get("engine") or "unknown"
    target = context.get("deployment_target") or "unknown"
    model = context.get("model") or "not set"
    instance = context.get("instance_type") or "not set"
    print(f"\n\033[1m📁 Project:\033[0m {project_name} ({engine}, {target})")
    print(f" Model: {model} on {instance}")
    return True, context


def _report_init_failure(error_msg: str) -> None:
    """Print an agent-initialization error with troubleshooting hints.

    Messages that mention Bedrock or credentials get AWS-specific suggestions;
    anything else gets a generic hint to try offline mode.
    """
    lowered = error_msg.lower()
    if "bedrock" in lowered or "credential" in lowered:
        print(f"\n\033[31mError:\033[0m Could not connect to Bedrock: {error_msg}")
        print("\n Suggestions:")
        print(" • Check AWS credentials (aws sts get-caller-identity)")
        print(" • Verify Bedrock model access in your AWS account")
        print(" • Run with --offline for static reference mode")
    else:
        print(f"\n\033[31mError:\033[0m Failed to initialize agent: {error_msg}")
        print(" Try running with --offline for static reference mode.")


def main() -> None:
    """Entry point for the advisory agent.

    Steps:

    1. Parse arguments and load the agent configuration.
    2. Print the project (or getting-started) banner and run the environment
       health check.
    3. Unless ``--offline`` was given, build the MCP clients and the Strands
       agent, then run the interactive REPL.

    Ctrl+C keeps Python's default behaviour (raise ``KeyboardInterrupt``).
    That lets ``_run_repl`` cancel a single in-flight response, or exit
    cleanly at the prompt. MCP shutdown and the cost summary run exactly
    once, from a single ``finally``, on every exit path: normal quit,
    Ctrl+C, initialization failure, or ``SystemExit``.
    """
    project_dir, offline_mode = _parse_args()
    project_path = Path(project_dir).resolve()

    # Load external configuration; frozensets give fast membership tests.
    config = load_agent_config()
    agent_mcp_servers = frozenset(config.mcp_servers)
    exit_commands = frozenset(config.exit_commands)
    reload_commands = frozenset(config.reload_commands)

    in_project, context = _load_startup_context(project_path)

    # Always run the health check.
    print()
    health_check = EnvironmentHealthCheck()
    items = health_check.run(str(project_path) if in_project else None)
    print_health_report(items)

    if offline_mode:
        print("📄 \033[1mOffline mode\033[0m — no Bedrock calls, no MCP servers.")
        print(" Run without --offline for interactive conversation.")
        return

    mcp_clients: list[MCPClient] = []
    cost = CostTracker(
        input_cost_per_1k=config.input_cost_per_1k,
        output_cost_per_1k=config.output_cost_per_1k,
    )

    # Ensure Ctrl+C raises KeyboardInterrupt. A handler that called
    # sys.exit() would bypass the REPL's per-response interrupt handling,
    # and would also make the cleanup below run twice.
    signal.signal(signal.SIGINT, signal.default_int_handler)

    try:
        try:
            print("Connecting to MCP servers...")
            mcp_clients = _start_mcp_servers(
                server_names=agent_mcp_servers,
                timeout=config.mcp_server_timeout,
            )

            if mcp_clients:
                print(f" \033[32m✓\033[0m {len(mcp_clients)} MCP servers configured.")
            else:
                print(" \033[33m⚠\033[0m No MCP servers configured. Tool calls will be unavailable.")

            # MCPClient instances are passed directly as tools, plus write_file.
            tools: list[Any] = list(mcp_clients)
            tools.append(_create_write_file_tool(project_path))

            agent = Agent(
                model=config.model_id,
                system_prompt=_build_system_prompt(context),
                tools=tools,
            )
            print(f" \033[32m✓\033[0m Agent ready (model: {config.model_id})")
        except Exception as e:
            _report_init_failure(str(e))
            sys.exit(1)  # The finally below still stops any MCP clients.

        _run_repl(agent, context, str(project_path), cost, exit_commands, reload_commands)
    except KeyboardInterrupt:
        # Ctrl+C outside the REPL's own handlers (e.g. during setup).
        print("\n\nShutting down...")
    finally:
        _stop_mcp_servers(mcp_clients)
        cost.print_summary()
|
|
510
|
+
|
|
511
|
+
|
|
512
|
+
# Allow running the module directly as a script, e.g.
# `python3 src/agent/agent.py --project-dir <path> [--offline|-o]`.
if __name__ == "__main__":
    main()
|