shotgun-sh 0.1.0.dev12__py3-none-any.whl → 0.1.0.dev13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of shotgun-sh might be problematic; see the registry's page for this release for more details.
- shotgun/agents/common.py +94 -79
- shotgun/agents/config/constants.py +18 -0
- shotgun/agents/config/manager.py +68 -16
- shotgun/agents/config/provider.py +11 -6
- shotgun/agents/models.py +6 -0
- shotgun/agents/plan.py +15 -37
- shotgun/agents/research.py +10 -45
- shotgun/agents/specify.py +97 -0
- shotgun/agents/tasks.py +7 -36
- shotgun/agents/tools/artifact_management.py +450 -0
- shotgun/agents/tools/file_management.py +2 -2
- shotgun/artifacts/__init__.py +17 -0
- shotgun/artifacts/exceptions.py +89 -0
- shotgun/artifacts/manager.py +529 -0
- shotgun/artifacts/models.py +332 -0
- shotgun/artifacts/service.py +463 -0
- shotgun/artifacts/templates/__init__.py +10 -0
- shotgun/artifacts/templates/loader.py +252 -0
- shotgun/artifacts/templates/models.py +136 -0
- shotgun/artifacts/templates/plan/delivery_and_release_plan.yaml +66 -0
- shotgun/artifacts/templates/research/market_research.yaml +585 -0
- shotgun/artifacts/templates/research/sdk_comparison.yaml +257 -0
- shotgun/artifacts/templates/specify/prd.yaml +331 -0
- shotgun/artifacts/templates/specify/product_spec.yaml +301 -0
- shotgun/artifacts/utils.py +76 -0
- shotgun/cli/plan.py +1 -4
- shotgun/cli/specify.py +69 -0
- shotgun/cli/tasks.py +0 -4
- shotgun/logging_config.py +23 -7
- shotgun/main.py +7 -6
- shotgun/prompts/agents/partials/artifact_system.j2 +32 -0
- shotgun/prompts/agents/partials/common_agent_system_prompt.j2 +28 -2
- shotgun/prompts/agents/partials/content_formatting.j2 +65 -0
- shotgun/prompts/agents/partials/interactive_mode.j2 +10 -2
- shotgun/prompts/agents/plan.j2 +31 -32
- shotgun/prompts/agents/research.j2 +37 -29
- shotgun/prompts/agents/specify.j2 +31 -0
- shotgun/prompts/agents/tasks.j2 +27 -12
- shotgun/sdk/artifact_models.py +186 -0
- shotgun/sdk/artifacts.py +448 -0
- shotgun/tui/app.py +26 -7
- shotgun/tui/screens/chat.py +28 -3
- shotgun/tui/screens/directory_setup.py +113 -0
- {shotgun_sh-0.1.0.dev12.dist-info → shotgun_sh-0.1.0.dev13.dist-info}/METADATA +2 -2
- {shotgun_sh-0.1.0.dev12.dist-info → shotgun_sh-0.1.0.dev13.dist-info}/RECORD +48 -25
- shotgun/prompts/user/research.j2 +0 -5
- {shotgun_sh-0.1.0.dev12.dist-info → shotgun_sh-0.1.0.dev13.dist-info}/WHEEL +0 -0
- {shotgun_sh-0.1.0.dev12.dist-info → shotgun_sh-0.1.0.dev13.dist-info}/entry_points.txt +0 -0
- {shotgun_sh-0.1.0.dev12.dist-info → shotgun_sh-0.1.0.dev13.dist-info}/licenses/LICENSE +0 -0
shotgun/agents/common.py
CHANGED
|
@@ -2,7 +2,6 @@
|
|
|
2
2
|
|
|
3
3
|
import asyncio
|
|
4
4
|
from collections.abc import Callable
|
|
5
|
-
from pathlib import Path
|
|
6
5
|
from typing import Any
|
|
7
6
|
|
|
8
7
|
from pydantic_ai import (
|
|
@@ -15,7 +14,9 @@ from pydantic_ai import (
|
|
|
15
14
|
from pydantic_ai.agent import AgentRunResult
|
|
16
15
|
from pydantic_ai.messages import (
|
|
17
16
|
ModelMessage,
|
|
17
|
+
ModelRequest,
|
|
18
18
|
ModelResponse,
|
|
19
|
+
SystemPromptPart,
|
|
19
20
|
TextPart,
|
|
20
21
|
)
|
|
21
22
|
|
|
@@ -38,6 +39,14 @@ from .tools import (
|
|
|
38
39
|
retrieve_code,
|
|
39
40
|
write_file,
|
|
40
41
|
)
|
|
42
|
+
from .tools.artifact_management import (
|
|
43
|
+
create_artifact,
|
|
44
|
+
list_artifact_templates,
|
|
45
|
+
list_artifacts,
|
|
46
|
+
read_artifact,
|
|
47
|
+
read_artifact_section,
|
|
48
|
+
write_artifact_section,
|
|
49
|
+
)
|
|
41
50
|
|
|
42
51
|
logger = get_logger(__name__)
|
|
43
52
|
|
|
@@ -45,70 +54,6 @@ logger = get_logger(__name__)
|
|
|
45
54
|
prompt_loader = PromptLoader()
|
|
46
55
|
|
|
47
56
|
|
|
48
|
-
def ensure_file_exists(filename: str, header: str) -> str:
|
|
49
|
-
"""Ensure a markdown file exists with proper header and return its content.
|
|
50
|
-
|
|
51
|
-
Args:
|
|
52
|
-
filename: Name of the file (e.g., "research.md")
|
|
53
|
-
header: Header to add if file is empty (e.g., "# Research")
|
|
54
|
-
|
|
55
|
-
Returns:
|
|
56
|
-
Current file content
|
|
57
|
-
"""
|
|
58
|
-
shotgun_dir = Path.cwd() / ".shotgun"
|
|
59
|
-
file_path = shotgun_dir / filename
|
|
60
|
-
|
|
61
|
-
try:
|
|
62
|
-
if file_path.exists():
|
|
63
|
-
content = file_path.read_text(encoding="utf-8")
|
|
64
|
-
if not content.strip():
|
|
65
|
-
# File exists but is empty, add header
|
|
66
|
-
header_content = f"{header}\n\n"
|
|
67
|
-
file_path.write_text(header_content, encoding="utf-8")
|
|
68
|
-
return header_content
|
|
69
|
-
return content
|
|
70
|
-
else:
|
|
71
|
-
# File doesn't exist, create it with header
|
|
72
|
-
shotgun_dir.mkdir(exist_ok=True)
|
|
73
|
-
header_content = f"{header}\n\n"
|
|
74
|
-
file_path.write_text(header_content, encoding="utf-8")
|
|
75
|
-
return header_content
|
|
76
|
-
except Exception as e:
|
|
77
|
-
logger.error("Failed to initialize %s: %s", filename, str(e))
|
|
78
|
-
return f"{header}\n\n"
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
def register_common_tools(
|
|
82
|
-
agent: Agent[AgentDeps], additional_tools: list[Any], interactive_mode: bool
|
|
83
|
-
) -> None:
|
|
84
|
-
"""Register common tools with an agent.
|
|
85
|
-
|
|
86
|
-
Args:
|
|
87
|
-
agent: The Pydantic AI agent to register tools with
|
|
88
|
-
additional_tools: List of additional tools specific to this agent
|
|
89
|
-
interactive_mode: Whether to register interactive tools
|
|
90
|
-
"""
|
|
91
|
-
logger.debug("📌 Registering tools with agent")
|
|
92
|
-
|
|
93
|
-
# Register additional tools first (agent-specific)
|
|
94
|
-
for tool in additional_tools:
|
|
95
|
-
agent.tool_plain(tool)
|
|
96
|
-
|
|
97
|
-
# Register interactive tool if enabled
|
|
98
|
-
if interactive_mode:
|
|
99
|
-
agent.tool(ask_user)
|
|
100
|
-
logger.debug("📞 User interaction tool registered")
|
|
101
|
-
else:
|
|
102
|
-
logger.debug("🚫 User interaction disabled (non-interactive mode)")
|
|
103
|
-
|
|
104
|
-
# Register common file management tools
|
|
105
|
-
agent.tool_plain(read_file)
|
|
106
|
-
agent.tool_plain(write_file)
|
|
107
|
-
agent.tool_plain(append_file)
|
|
108
|
-
|
|
109
|
-
logger.debug("✅ Tool registration complete")
|
|
110
|
-
|
|
111
|
-
|
|
112
57
|
async def add_system_status_message(
|
|
113
58
|
deps: AgentDeps,
|
|
114
59
|
message_history: list[ModelMessage] | None = None,
|
|
@@ -128,7 +73,6 @@ async def add_system_status_message(
|
|
|
128
73
|
system_state = prompt_loader.render(
|
|
129
74
|
"agents/state/system_state.j2",
|
|
130
75
|
codebase_understanding_graphs=codebase_understanding_graphs,
|
|
131
|
-
context="system state",
|
|
132
76
|
)
|
|
133
77
|
message_history.append(
|
|
134
78
|
ModelResponse(
|
|
@@ -179,6 +123,7 @@ def create_base_agent(
|
|
|
179
123
|
**agent_runtime_options.model_dump(),
|
|
180
124
|
llm_model=model_config,
|
|
181
125
|
codebase_service=codebase_service,
|
|
126
|
+
system_prompt_fn=system_prompt_fn,
|
|
182
127
|
)
|
|
183
128
|
|
|
184
129
|
except Exception as e:
|
|
@@ -194,8 +139,9 @@ def create_base_agent(
|
|
|
194
139
|
history_processors=[token_limit_compactor],
|
|
195
140
|
)
|
|
196
141
|
|
|
197
|
-
#
|
|
198
|
-
|
|
142
|
+
# System prompt function is stored in deps and will be called manually in run_agent
|
|
143
|
+
func_name = getattr(system_prompt_fn, "__name__", str(system_prompt_fn))
|
|
144
|
+
logger.debug("🔧 System prompt function stored: %s", func_name)
|
|
199
145
|
|
|
200
146
|
# Register additional tools first (agent-specific)
|
|
201
147
|
for tool in additional_tools or []:
|
|
@@ -211,7 +157,15 @@ def create_base_agent(
|
|
|
211
157
|
agent.tool_plain(write_file)
|
|
212
158
|
agent.tool_plain(append_file)
|
|
213
159
|
|
|
214
|
-
# Register
|
|
160
|
+
# Register artifact management tools (always available)
|
|
161
|
+
agent.tool_plain(create_artifact)
|
|
162
|
+
agent.tool_plain(list_artifacts)
|
|
163
|
+
agent.tool_plain(list_artifact_templates)
|
|
164
|
+
agent.tool_plain(read_artifact)
|
|
165
|
+
agent.tool_plain(read_artifact_section)
|
|
166
|
+
agent.tool_plain(write_artifact_section)
|
|
167
|
+
|
|
168
|
+
# Register codebase understanding tools (conditional)
|
|
215
169
|
if load_codebase_understanding_tools:
|
|
216
170
|
agent.tool(query_graph)
|
|
217
171
|
agent.tool(retrieve_code)
|
|
@@ -222,10 +176,47 @@ def create_base_agent(
|
|
|
222
176
|
else:
|
|
223
177
|
logger.debug("🚫🧠 Codebase understanding tools not registered")
|
|
224
178
|
|
|
225
|
-
logger.debug("✅ Agent creation complete")
|
|
179
|
+
logger.debug("✅ Agent creation complete with artifact and codebase tools")
|
|
226
180
|
return agent, deps
|
|
227
181
|
|
|
228
182
|
|
|
183
|
+
def build_agent_system_prompt(
|
|
184
|
+
agent_type: str,
|
|
185
|
+
ctx: RunContext[AgentDeps],
|
|
186
|
+
context_name: str | None = None,
|
|
187
|
+
) -> str:
|
|
188
|
+
"""Build system prompt for any agent type.
|
|
189
|
+
|
|
190
|
+
Args:
|
|
191
|
+
agent_type: Type of agent ('research', 'plan', 'tasks')
|
|
192
|
+
ctx: RunContext containing AgentDeps
|
|
193
|
+
context_name: Optional context name for template rendering
|
|
194
|
+
|
|
195
|
+
Returns:
|
|
196
|
+
Rendered system prompt
|
|
197
|
+
"""
|
|
198
|
+
prompt_loader = PromptLoader()
|
|
199
|
+
|
|
200
|
+
# Add logging if research agent
|
|
201
|
+
if agent_type == "research":
|
|
202
|
+
logger.debug("🔧 Building research agent system prompt...")
|
|
203
|
+
logger.debug("Interactive mode: %s", ctx.deps.interactive_mode)
|
|
204
|
+
|
|
205
|
+
result = prompt_loader.render(
|
|
206
|
+
f"agents/{agent_type}.j2",
|
|
207
|
+
interactive_mode=ctx.deps.interactive_mode,
|
|
208
|
+
mode=agent_type,
|
|
209
|
+
)
|
|
210
|
+
|
|
211
|
+
if agent_type == "research":
|
|
212
|
+
logger.debug(
|
|
213
|
+
"✅ Research system prompt built successfully (length: %d chars)",
|
|
214
|
+
len(result),
|
|
215
|
+
)
|
|
216
|
+
|
|
217
|
+
return result
|
|
218
|
+
|
|
219
|
+
|
|
229
220
|
def create_usage_limits() -> UsageLimits:
|
|
230
221
|
"""Create reasonable usage limits for agent runs.
|
|
231
222
|
|
|
@@ -238,20 +229,41 @@ def create_usage_limits() -> UsageLimits:
|
|
|
238
229
|
)
|
|
239
230
|
|
|
240
231
|
|
|
241
|
-
def
|
|
242
|
-
|
|
232
|
+
async def add_system_prompt_message(
|
|
233
|
+
deps: AgentDeps,
|
|
234
|
+
message_history: list[ModelMessage] | None = None,
|
|
235
|
+
) -> list[ModelMessage]:
|
|
236
|
+
"""Add the system prompt as the first message in the message history.
|
|
243
237
|
|
|
244
238
|
Args:
|
|
245
|
-
|
|
239
|
+
deps: Agent dependencies containing system_prompt_fn
|
|
240
|
+
message_history: Existing message history
|
|
246
241
|
|
|
247
242
|
Returns:
|
|
248
|
-
|
|
243
|
+
Updated message history with system prompt prepended as first message
|
|
249
244
|
"""
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
245
|
+
message_history = message_history or []
|
|
246
|
+
|
|
247
|
+
# Create a minimal RunContext to call the system prompt function
|
|
248
|
+
# We'll pass None for model and usage since they're not used by our system prompt functions
|
|
249
|
+
context = type(
|
|
250
|
+
"RunContext", (), {"deps": deps, "retry": 0, "model": None, "usage": None}
|
|
251
|
+
)()
|
|
252
|
+
|
|
253
|
+
# Render the system prompt using the stored function
|
|
254
|
+
system_prompt_content = deps.system_prompt_fn(context)
|
|
255
|
+
logger.debug(
|
|
256
|
+
"🎯 Rendered system prompt (length: %d chars)", len(system_prompt_content)
|
|
257
|
+
)
|
|
258
|
+
|
|
259
|
+
# Create system message and prepend to message history
|
|
260
|
+
system_message = ModelRequest(
|
|
261
|
+
parts=[SystemPromptPart(content=system_prompt_content)]
|
|
262
|
+
)
|
|
263
|
+
message_history.insert(0, system_message)
|
|
264
|
+
logger.debug("✅ System prompt prepended as first message")
|
|
265
|
+
|
|
266
|
+
return message_history
|
|
255
267
|
|
|
256
268
|
|
|
257
269
|
async def run_agent(
|
|
@@ -261,6 +273,9 @@ async def run_agent(
|
|
|
261
273
|
message_history: list[ModelMessage] | None = None,
|
|
262
274
|
usage_limits: UsageLimits | None = None,
|
|
263
275
|
) -> AgentRunResult[str | DeferredToolRequests]:
|
|
276
|
+
# Add system prompt as first message
|
|
277
|
+
message_history = await add_system_prompt_message(deps, message_history)
|
|
278
|
+
|
|
264
279
|
result = await agent.run(
|
|
265
280
|
prompt,
|
|
266
281
|
deps=deps,
|
|
shotgun/agents/config/constants.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
"""Configuration constants for Shotgun agents."""
|
|
2
|
+
|
|
3
|
+
# Field names
|
|
4
|
+
API_KEY_FIELD = "api_key"
|
|
5
|
+
MODEL_NAME_FIELD = "model_name"
|
|
6
|
+
DEFAULT_PROVIDER_FIELD = "default_provider"
|
|
7
|
+
USER_ID_FIELD = "user_id"
|
|
8
|
+
CONFIG_VERSION_FIELD = "config_version"
|
|
9
|
+
|
|
10
|
+
# Provider names (for consistency with data dict keys)
|
|
11
|
+
OPENAI_PROVIDER = "openai"
|
|
12
|
+
ANTHROPIC_PROVIDER = "anthropic"
|
|
13
|
+
GOOGLE_PROVIDER = "google"
|
|
14
|
+
|
|
15
|
+
# Environment variable names
|
|
16
|
+
OPENAI_API_KEY_ENV = "OPENAI_API_KEY"
|
|
17
|
+
ANTHROPIC_API_KEY_ENV = "ANTHROPIC_API_KEY"
|
|
18
|
+
GEMINI_API_KEY_ENV = "GEMINI_API_KEY"
|
shotgun/agents/config/manager.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
"""Configuration manager for Shotgun CLI."""
|
|
2
2
|
|
|
3
3
|
import json
|
|
4
|
+
import os
|
|
4
5
|
import uuid
|
|
5
6
|
from pathlib import Path
|
|
6
7
|
from typing import Any
|
|
@@ -10,6 +11,15 @@ from pydantic import SecretStr
|
|
|
10
11
|
from shotgun.logging_config import get_logger
|
|
11
12
|
from shotgun.utils import get_shotgun_home
|
|
12
13
|
|
|
14
|
+
from .constants import (
|
|
15
|
+
ANTHROPIC_API_KEY_ENV,
|
|
16
|
+
ANTHROPIC_PROVIDER,
|
|
17
|
+
API_KEY_FIELD,
|
|
18
|
+
GEMINI_API_KEY_ENV,
|
|
19
|
+
GOOGLE_PROVIDER,
|
|
20
|
+
OPENAI_API_KEY_ENV,
|
|
21
|
+
OPENAI_PROVIDER,
|
|
22
|
+
)
|
|
13
23
|
from .models import ProviderType, ShotgunConfig
|
|
14
24
|
|
|
15
25
|
logger = get_logger(__name__)
|
|
@@ -58,6 +68,22 @@ class ConfigManager:
|
|
|
58
68
|
|
|
59
69
|
self._config = ShotgunConfig.model_validate(data)
|
|
60
70
|
logger.debug("Configuration loaded successfully from %s", self.config_path)
|
|
71
|
+
|
|
72
|
+
# Check if the default provider has a key, if not find one that does
|
|
73
|
+
if not self.has_provider_key(self._config.default_provider):
|
|
74
|
+
original_default = self._config.default_provider
|
|
75
|
+
# Find first provider with a configured key
|
|
76
|
+
for provider in ProviderType:
|
|
77
|
+
if self.has_provider_key(provider):
|
|
78
|
+
logger.info(
|
|
79
|
+
"Default provider %s has no API key, updating to %s",
|
|
80
|
+
original_default.value,
|
|
81
|
+
provider.value,
|
|
82
|
+
)
|
|
83
|
+
self._config.default_provider = provider
|
|
84
|
+
self.save(self._config)
|
|
85
|
+
break
|
|
86
|
+
|
|
61
87
|
return self._config
|
|
62
88
|
|
|
63
89
|
except Exception as e:
|
|
@@ -114,17 +140,25 @@ class ConfigManager:
|
|
|
114
140
|
provider_config = self._get_provider_config(config, provider_enum)
|
|
115
141
|
|
|
116
142
|
# Only support api_key updates
|
|
117
|
-
if
|
|
118
|
-
api_key_value = kwargs[
|
|
143
|
+
if API_KEY_FIELD in kwargs:
|
|
144
|
+
api_key_value = kwargs[API_KEY_FIELD]
|
|
119
145
|
provider_config.api_key = (
|
|
120
146
|
SecretStr(api_key_value) if api_key_value is not None else None
|
|
121
147
|
)
|
|
122
148
|
|
|
123
149
|
# Reject other fields
|
|
124
|
-
unsupported_fields = set(kwargs.keys()) - {
|
|
150
|
+
unsupported_fields = set(kwargs.keys()) - {API_KEY_FIELD}
|
|
125
151
|
if unsupported_fields:
|
|
126
152
|
raise ValueError(f"Unsupported configuration fields: {unsupported_fields}")
|
|
127
153
|
|
|
154
|
+
# If no other providers have keys configured and we just added one,
|
|
155
|
+
# set this provider as the default
|
|
156
|
+
if API_KEY_FIELD in kwargs and api_key_value is not None:
|
|
157
|
+
other_providers = [p for p in ProviderType if p != provider_enum]
|
|
158
|
+
has_other_keys = any(self.has_provider_key(p) for p in other_providers)
|
|
159
|
+
if not has_other_keys:
|
|
160
|
+
config.default_provider = provider_enum
|
|
161
|
+
|
|
128
162
|
self.save(config)
|
|
129
163
|
|
|
130
164
|
def clear_provider_key(self, provider: ProviderType | str) -> None:
|
|
@@ -136,11 +170,27 @@ class ConfigManager:
|
|
|
136
170
|
self.save(config)
|
|
137
171
|
|
|
138
172
|
def has_provider_key(self, provider: ProviderType | str) -> bool:
|
|
139
|
-
"""Check if the given provider has a non-empty API key configured.
|
|
173
|
+
"""Check if the given provider has a non-empty API key configured.
|
|
174
|
+
|
|
175
|
+
This checks both the configuration file and environment variables.
|
|
176
|
+
"""
|
|
140
177
|
config = self.load()
|
|
141
178
|
provider_enum = self._ensure_provider_enum(provider)
|
|
142
179
|
provider_config = self._get_provider_config(config, provider_enum)
|
|
143
|
-
|
|
180
|
+
|
|
181
|
+
# Check config first
|
|
182
|
+
if self._provider_has_api_key(provider_config):
|
|
183
|
+
return True
|
|
184
|
+
|
|
185
|
+
# Check environment variable
|
|
186
|
+
if provider_enum == ProviderType.OPENAI:
|
|
187
|
+
return bool(os.getenv(OPENAI_API_KEY_ENV))
|
|
188
|
+
elif provider_enum == ProviderType.ANTHROPIC:
|
|
189
|
+
return bool(os.getenv(ANTHROPIC_API_KEY_ENV))
|
|
190
|
+
elif provider_enum == ProviderType.GOOGLE:
|
|
191
|
+
return bool(os.getenv(GEMINI_API_KEY_ENV))
|
|
192
|
+
|
|
193
|
+
return False
|
|
144
194
|
|
|
145
195
|
def has_any_provider_key(self) -> bool:
|
|
146
196
|
"""Determine whether any provider has a configured API key."""
|
|
@@ -175,25 +225,27 @@ class ConfigManager:
|
|
|
175
225
|
|
|
176
226
|
def _convert_secrets_to_secretstr(self, data: dict[str, Any]) -> None:
|
|
177
227
|
"""Convert plain text secrets in data to SecretStr objects."""
|
|
178
|
-
for provider in [
|
|
228
|
+
for provider in [OPENAI_PROVIDER, ANTHROPIC_PROVIDER, GOOGLE_PROVIDER]:
|
|
179
229
|
if provider in data and isinstance(data[provider], dict):
|
|
180
230
|
if (
|
|
181
|
-
|
|
182
|
-
and data[provider][
|
|
231
|
+
API_KEY_FIELD in data[provider]
|
|
232
|
+
and data[provider][API_KEY_FIELD] is not None
|
|
183
233
|
):
|
|
184
|
-
data[provider][
|
|
234
|
+
data[provider][API_KEY_FIELD] = SecretStr(
|
|
235
|
+
data[provider][API_KEY_FIELD]
|
|
236
|
+
)
|
|
185
237
|
|
|
186
238
|
def _convert_secretstr_to_plain(self, data: dict[str, Any]) -> None:
|
|
187
239
|
"""Convert SecretStr objects in data to plain text for JSON serialization."""
|
|
188
|
-
for provider in [
|
|
240
|
+
for provider in [OPENAI_PROVIDER, ANTHROPIC_PROVIDER, GOOGLE_PROVIDER]:
|
|
189
241
|
if provider in data and isinstance(data[provider], dict):
|
|
190
242
|
if (
|
|
191
|
-
|
|
192
|
-
and data[provider][
|
|
243
|
+
API_KEY_FIELD in data[provider]
|
|
244
|
+
and data[provider][API_KEY_FIELD] is not None
|
|
193
245
|
):
|
|
194
|
-
if hasattr(data[provider][
|
|
195
|
-
data[provider][
|
|
196
|
-
|
|
246
|
+
if hasattr(data[provider][API_KEY_FIELD], "get_secret_value"):
|
|
247
|
+
data[provider][API_KEY_FIELD] = data[provider][
|
|
248
|
+
API_KEY_FIELD
|
|
197
249
|
].get_secret_value()
|
|
198
250
|
|
|
199
251
|
def _ensure_provider_enum(self, provider: ProviderType | str) -> ProviderType:
|
|
@@ -216,7 +268,7 @@ class ConfigManager:
|
|
|
216
268
|
|
|
217
269
|
def _provider_has_api_key(self, provider_config: Any) -> bool:
|
|
218
270
|
"""Return True if the provider config contains a usable API key."""
|
|
219
|
-
api_key = getattr(provider_config,
|
|
271
|
+
api_key = getattr(provider_config, API_KEY_FIELD, None)
|
|
220
272
|
if api_key is None:
|
|
221
273
|
return False
|
|
222
274
|
|
|
shotgun/agents/config/provider.py
CHANGED
|
@@ -13,6 +13,11 @@ from pydantic_ai.providers.openai import OpenAIProvider
|
|
|
13
13
|
|
|
14
14
|
from shotgun.logging_config import get_logger
|
|
15
15
|
|
|
16
|
+
from .constants import (
|
|
17
|
+
ANTHROPIC_API_KEY_ENV,
|
|
18
|
+
GEMINI_API_KEY_ENV,
|
|
19
|
+
OPENAI_API_KEY_ENV,
|
|
20
|
+
)
|
|
16
21
|
from .manager import get_config_manager
|
|
17
22
|
from .models import MODEL_SPECS, ModelConfig, ProviderType
|
|
18
23
|
|
|
@@ -86,10 +91,10 @@ def get_provider_model(provider: ProviderType | None = None) -> ModelConfig:
|
|
|
86
91
|
)
|
|
87
92
|
|
|
88
93
|
if provider_enum == ProviderType.OPENAI:
|
|
89
|
-
api_key = _get_api_key(config.openai.api_key,
|
|
94
|
+
api_key = _get_api_key(config.openai.api_key, OPENAI_API_KEY_ENV)
|
|
90
95
|
if not api_key:
|
|
91
96
|
raise ValueError(
|
|
92
|
-
"OpenAI API key not configured. Set via environment variable
|
|
97
|
+
f"OpenAI API key not configured. Set via environment variable {OPENAI_API_KEY_ENV} or config."
|
|
93
98
|
)
|
|
94
99
|
|
|
95
100
|
# Get model spec
|
|
@@ -108,10 +113,10 @@ def get_provider_model(provider: ProviderType | None = None) -> ModelConfig:
|
|
|
108
113
|
)
|
|
109
114
|
|
|
110
115
|
elif provider_enum == ProviderType.ANTHROPIC:
|
|
111
|
-
api_key = _get_api_key(config.anthropic.api_key,
|
|
116
|
+
api_key = _get_api_key(config.anthropic.api_key, ANTHROPIC_API_KEY_ENV)
|
|
112
117
|
if not api_key:
|
|
113
118
|
raise ValueError(
|
|
114
|
-
"Anthropic API key not configured. Set via environment variable
|
|
119
|
+
f"Anthropic API key not configured. Set via environment variable {ANTHROPIC_API_KEY_ENV} or config."
|
|
115
120
|
)
|
|
116
121
|
|
|
117
122
|
# Get model spec
|
|
@@ -130,10 +135,10 @@ def get_provider_model(provider: ProviderType | None = None) -> ModelConfig:
|
|
|
130
135
|
)
|
|
131
136
|
|
|
132
137
|
elif provider_enum == ProviderType.GOOGLE:
|
|
133
|
-
api_key = _get_api_key(config.google.api_key,
|
|
138
|
+
api_key = _get_api_key(config.google.api_key, GEMINI_API_KEY_ENV)
|
|
134
139
|
if not api_key:
|
|
135
140
|
raise ValueError(
|
|
136
|
-
"Gemini API key not configured. Set via environment variable
|
|
141
|
+
f"Gemini API key not configured. Set via environment variable {GEMINI_API_KEY_ENV} or config."
|
|
137
142
|
)
|
|
138
143
|
|
|
139
144
|
# Get model spec
|
shotgun/agents/models.py
CHANGED
|
@@ -1,10 +1,12 @@
|
|
|
1
1
|
"""Pydantic models for agent dependencies and configuration."""
|
|
2
2
|
|
|
3
3
|
from asyncio import Future, Queue
|
|
4
|
+
from collections.abc import Callable
|
|
4
5
|
from pathlib import Path
|
|
5
6
|
from typing import TYPE_CHECKING
|
|
6
7
|
|
|
7
8
|
from pydantic import BaseModel, ConfigDict, Field
|
|
9
|
+
from pydantic_ai import RunContext
|
|
8
10
|
|
|
9
11
|
from .config.models import ModelConfig
|
|
10
12
|
|
|
@@ -83,6 +85,10 @@ class AgentDeps(AgentRuntimeOptions):
|
|
|
83
85
|
description="Codebase service for code analysis tools",
|
|
84
86
|
)
|
|
85
87
|
|
|
88
|
+
system_prompt_fn: Callable[[RunContext["AgentDeps"]], str] = Field(
|
|
89
|
+
description="Function that generates the system prompt for this agent",
|
|
90
|
+
)
|
|
91
|
+
|
|
86
92
|
|
|
87
93
|
# Rebuild model to resolve forward references after imports are available
|
|
88
94
|
try:
|
shotgun/agents/plan.py
CHANGED
|
@@ -1,51 +1,33 @@
|
|
|
1
1
|
"""Plan agent factory and functions using Pydantic AI with file-based memory."""
|
|
2
2
|
|
|
3
|
+
from functools import partial
|
|
4
|
+
|
|
3
5
|
from pydantic_ai import (
|
|
4
6
|
Agent,
|
|
5
7
|
DeferredToolRequests,
|
|
6
|
-
RunContext,
|
|
7
8
|
)
|
|
8
9
|
from pydantic_ai.agent import AgentRunResult
|
|
9
10
|
from pydantic_ai.messages import ModelMessage
|
|
10
11
|
|
|
11
12
|
from shotgun.agents.config import ProviderType
|
|
12
13
|
from shotgun.logging_config import get_logger
|
|
13
|
-
from shotgun.prompts import PromptLoader
|
|
14
14
|
|
|
15
15
|
from .common import (
|
|
16
16
|
add_system_status_message,
|
|
17
|
+
build_agent_system_prompt,
|
|
17
18
|
create_base_agent,
|
|
18
19
|
create_usage_limits,
|
|
19
|
-
ensure_file_exists,
|
|
20
|
-
get_file_history,
|
|
21
20
|
run_agent,
|
|
22
21
|
)
|
|
23
22
|
from .models import AgentDeps, AgentRuntimeOptions
|
|
24
23
|
|
|
25
24
|
logger = get_logger(__name__)
|
|
26
25
|
|
|
27
|
-
# Global prompt loader instance
|
|
28
|
-
prompt_loader = PromptLoader()
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
def _build_plan_agent_system_prompt(ctx: RunContext[AgentDeps]) -> str:
|
|
32
|
-
"""Build the system prompt for the plan agent.
|
|
33
|
-
|
|
34
|
-
Args:
|
|
35
|
-
ctx: RunContext containing AgentDeps with interactive_mode and other settings
|
|
36
|
-
|
|
37
|
-
Returns:
|
|
38
|
-
The complete system prompt string for the plan agent
|
|
39
|
-
"""
|
|
40
|
-
return prompt_loader.render(
|
|
41
|
-
"agents/plan.j2", interactive_mode=ctx.deps.interactive_mode, context="plans"
|
|
42
|
-
)
|
|
43
|
-
|
|
44
26
|
|
|
45
27
|
def create_plan_agent(
|
|
46
28
|
agent_runtime_options: AgentRuntimeOptions, provider: ProviderType | None = None
|
|
47
29
|
) -> tuple[Agent[AgentDeps, str | DeferredToolRequests], AgentDeps]:
|
|
48
|
-
"""Create a plan agent with
|
|
30
|
+
"""Create a plan agent with artifact management capabilities.
|
|
49
31
|
|
|
50
32
|
Args:
|
|
51
33
|
agent_runtime_options: Agent runtime options for the agent
|
|
@@ -55,8 +37,15 @@ def create_plan_agent(
|
|
|
55
37
|
Tuple of (Configured Pydantic AI agent for planning tasks, Agent dependencies)
|
|
56
38
|
"""
|
|
57
39
|
logger.debug("Initializing plan agent")
|
|
40
|
+
# Use partial to create system prompt function for plan agent
|
|
41
|
+
system_prompt_fn = partial(build_agent_system_prompt, "plan")
|
|
42
|
+
|
|
58
43
|
agent, deps = create_base_agent(
|
|
59
|
-
|
|
44
|
+
system_prompt_fn,
|
|
45
|
+
agent_runtime_options,
|
|
46
|
+
load_codebase_understanding_tools=True,
|
|
47
|
+
additional_tools=None,
|
|
48
|
+
provider=provider,
|
|
60
49
|
)
|
|
61
50
|
return agent, deps
|
|
62
51
|
|
|
@@ -67,7 +56,7 @@ async def run_plan_agent(
|
|
|
67
56
|
deps: AgentDeps,
|
|
68
57
|
message_history: list[ModelMessage] | None = None,
|
|
69
58
|
) -> AgentRunResult[str | DeferredToolRequests]:
|
|
70
|
-
"""Create or update a plan based on the given goal.
|
|
59
|
+
"""Create or update a plan based on the given goal using artifacts.
|
|
71
60
|
|
|
72
61
|
Args:
|
|
73
62
|
agent: The configured plan agent
|
|
@@ -80,11 +69,9 @@ async def run_plan_agent(
|
|
|
80
69
|
"""
|
|
81
70
|
logger.debug("📋 Starting planning for goal: %s", goal)
|
|
82
71
|
|
|
83
|
-
#
|
|
84
|
-
ensure_file_exists("plan.md", "# Plan")
|
|
85
|
-
|
|
86
|
-
# Let the agent use its tools to read existing plan and research
|
|
72
|
+
# Simple prompt - the agent system prompt has all the artifact instructions
|
|
87
73
|
full_prompt = f"Create a comprehensive plan for: {goal}"
|
|
74
|
+
|
|
88
75
|
try:
|
|
89
76
|
# Create usage limits for responsible API usage
|
|
90
77
|
usage_limits = create_usage_limits()
|
|
@@ -108,12 +95,3 @@ async def run_plan_agent(
|
|
|
108
95
|
logger.error("Full traceback:\n%s", traceback.format_exc())
|
|
109
96
|
logger.error("❌ Planning failed: %s", str(e))
|
|
110
97
|
raise
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
def get_plan_history() -> str:
|
|
114
|
-
"""Get the full plan history from the file.
|
|
115
|
-
|
|
116
|
-
Returns:
|
|
117
|
-
Plan history content or fallback message
|
|
118
|
-
"""
|
|
119
|
-
return get_file_history("plan.md")
|