shotgun-sh 0.1.0.dev12-py3-none-any.whl → 0.1.0.dev13-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of shotgun-sh might be problematic; the original page links to more details.

Files changed (49):
  1. shotgun/agents/common.py +94 -79
  2. shotgun/agents/config/constants.py +18 -0
  3. shotgun/agents/config/manager.py +68 -16
  4. shotgun/agents/config/provider.py +11 -6
  5. shotgun/agents/models.py +6 -0
  6. shotgun/agents/plan.py +15 -37
  7. shotgun/agents/research.py +10 -45
  8. shotgun/agents/specify.py +97 -0
  9. shotgun/agents/tasks.py +7 -36
  10. shotgun/agents/tools/artifact_management.py +450 -0
  11. shotgun/agents/tools/file_management.py +2 -2
  12. shotgun/artifacts/__init__.py +17 -0
  13. shotgun/artifacts/exceptions.py +89 -0
  14. shotgun/artifacts/manager.py +529 -0
  15. shotgun/artifacts/models.py +332 -0
  16. shotgun/artifacts/service.py +463 -0
  17. shotgun/artifacts/templates/__init__.py +10 -0
  18. shotgun/artifacts/templates/loader.py +252 -0
  19. shotgun/artifacts/templates/models.py +136 -0
  20. shotgun/artifacts/templates/plan/delivery_and_release_plan.yaml +66 -0
  21. shotgun/artifacts/templates/research/market_research.yaml +585 -0
  22. shotgun/artifacts/templates/research/sdk_comparison.yaml +257 -0
  23. shotgun/artifacts/templates/specify/prd.yaml +331 -0
  24. shotgun/artifacts/templates/specify/product_spec.yaml +301 -0
  25. shotgun/artifacts/utils.py +76 -0
  26. shotgun/cli/plan.py +1 -4
  27. shotgun/cli/specify.py +69 -0
  28. shotgun/cli/tasks.py +0 -4
  29. shotgun/logging_config.py +23 -7
  30. shotgun/main.py +7 -6
  31. shotgun/prompts/agents/partials/artifact_system.j2 +32 -0
  32. shotgun/prompts/agents/partials/common_agent_system_prompt.j2 +28 -2
  33. shotgun/prompts/agents/partials/content_formatting.j2 +65 -0
  34. shotgun/prompts/agents/partials/interactive_mode.j2 +10 -2
  35. shotgun/prompts/agents/plan.j2 +31 -32
  36. shotgun/prompts/agents/research.j2 +37 -29
  37. shotgun/prompts/agents/specify.j2 +31 -0
  38. shotgun/prompts/agents/tasks.j2 +27 -12
  39. shotgun/sdk/artifact_models.py +186 -0
  40. shotgun/sdk/artifacts.py +448 -0
  41. shotgun/tui/app.py +26 -7
  42. shotgun/tui/screens/chat.py +28 -3
  43. shotgun/tui/screens/directory_setup.py +113 -0
  44. {shotgun_sh-0.1.0.dev12.dist-info → shotgun_sh-0.1.0.dev13.dist-info}/METADATA +2 -2
  45. {shotgun_sh-0.1.0.dev12.dist-info → shotgun_sh-0.1.0.dev13.dist-info}/RECORD +48 -25
  46. shotgun/prompts/user/research.j2 +0 -5
  47. {shotgun_sh-0.1.0.dev12.dist-info → shotgun_sh-0.1.0.dev13.dist-info}/WHEEL +0 -0
  48. {shotgun_sh-0.1.0.dev12.dist-info → shotgun_sh-0.1.0.dev13.dist-info}/entry_points.txt +0 -0
  49. {shotgun_sh-0.1.0.dev12.dist-info → shotgun_sh-0.1.0.dev13.dist-info}/licenses/LICENSE +0 -0
shotgun/agents/common.py CHANGED
@@ -2,7 +2,6 @@
2
2
 
3
3
  import asyncio
4
4
  from collections.abc import Callable
5
- from pathlib import Path
6
5
  from typing import Any
7
6
 
8
7
  from pydantic_ai import (
@@ -15,7 +14,9 @@ from pydantic_ai import (
15
14
  from pydantic_ai.agent import AgentRunResult
16
15
  from pydantic_ai.messages import (
17
16
  ModelMessage,
17
+ ModelRequest,
18
18
  ModelResponse,
19
+ SystemPromptPart,
19
20
  TextPart,
20
21
  )
21
22
 
@@ -38,6 +39,14 @@ from .tools import (
38
39
  retrieve_code,
39
40
  write_file,
40
41
  )
42
+ from .tools.artifact_management import (
43
+ create_artifact,
44
+ list_artifact_templates,
45
+ list_artifacts,
46
+ read_artifact,
47
+ read_artifact_section,
48
+ write_artifact_section,
49
+ )
41
50
 
42
51
  logger = get_logger(__name__)
43
52
 
@@ -45,70 +54,6 @@ logger = get_logger(__name__)
45
54
  prompt_loader = PromptLoader()
46
55
 
47
56
 
48
- def ensure_file_exists(filename: str, header: str) -> str:
49
- """Ensure a markdown file exists with proper header and return its content.
50
-
51
- Args:
52
- filename: Name of the file (e.g., "research.md")
53
- header: Header to add if file is empty (e.g., "# Research")
54
-
55
- Returns:
56
- Current file content
57
- """
58
- shotgun_dir = Path.cwd() / ".shotgun"
59
- file_path = shotgun_dir / filename
60
-
61
- try:
62
- if file_path.exists():
63
- content = file_path.read_text(encoding="utf-8")
64
- if not content.strip():
65
- # File exists but is empty, add header
66
- header_content = f"{header}\n\n"
67
- file_path.write_text(header_content, encoding="utf-8")
68
- return header_content
69
- return content
70
- else:
71
- # File doesn't exist, create it with header
72
- shotgun_dir.mkdir(exist_ok=True)
73
- header_content = f"{header}\n\n"
74
- file_path.write_text(header_content, encoding="utf-8")
75
- return header_content
76
- except Exception as e:
77
- logger.error("Failed to initialize %s: %s", filename, str(e))
78
- return f"{header}\n\n"
79
-
80
-
81
- def register_common_tools(
82
- agent: Agent[AgentDeps], additional_tools: list[Any], interactive_mode: bool
83
- ) -> None:
84
- """Register common tools with an agent.
85
-
86
- Args:
87
- agent: The Pydantic AI agent to register tools with
88
- additional_tools: List of additional tools specific to this agent
89
- interactive_mode: Whether to register interactive tools
90
- """
91
- logger.debug("📌 Registering tools with agent")
92
-
93
- # Register additional tools first (agent-specific)
94
- for tool in additional_tools:
95
- agent.tool_plain(tool)
96
-
97
- # Register interactive tool if enabled
98
- if interactive_mode:
99
- agent.tool(ask_user)
100
- logger.debug("📞 User interaction tool registered")
101
- else:
102
- logger.debug("🚫 User interaction disabled (non-interactive mode)")
103
-
104
- # Register common file management tools
105
- agent.tool_plain(read_file)
106
- agent.tool_plain(write_file)
107
- agent.tool_plain(append_file)
108
-
109
- logger.debug("✅ Tool registration complete")
110
-
111
-
112
57
  async def add_system_status_message(
113
58
  deps: AgentDeps,
114
59
  message_history: list[ModelMessage] | None = None,
@@ -128,7 +73,6 @@ async def add_system_status_message(
128
73
  system_state = prompt_loader.render(
129
74
  "agents/state/system_state.j2",
130
75
  codebase_understanding_graphs=codebase_understanding_graphs,
131
- context="system state",
132
76
  )
133
77
  message_history.append(
134
78
  ModelResponse(
@@ -179,6 +123,7 @@ def create_base_agent(
179
123
  **agent_runtime_options.model_dump(),
180
124
  llm_model=model_config,
181
125
  codebase_service=codebase_service,
126
+ system_prompt_fn=system_prompt_fn,
182
127
  )
183
128
 
184
129
  except Exception as e:
@@ -194,8 +139,9 @@ def create_base_agent(
194
139
  history_processors=[token_limit_compactor],
195
140
  )
196
141
 
197
- # Decorate the system prompt function
198
- agent.system_prompt(system_prompt_fn)
142
+ # System prompt function is stored in deps and will be called manually in run_agent
143
+ func_name = getattr(system_prompt_fn, "__name__", str(system_prompt_fn))
144
+ logger.debug("🔧 System prompt function stored: %s", func_name)
199
145
 
200
146
  # Register additional tools first (agent-specific)
201
147
  for tool in additional_tools or []:
@@ -211,7 +157,15 @@ def create_base_agent(
211
157
  agent.tool_plain(write_file)
212
158
  agent.tool_plain(append_file)
213
159
 
214
- # Register codebase understanding tools (always available)
160
+ # Register artifact management tools (always available)
161
+ agent.tool_plain(create_artifact)
162
+ agent.tool_plain(list_artifacts)
163
+ agent.tool_plain(list_artifact_templates)
164
+ agent.tool_plain(read_artifact)
165
+ agent.tool_plain(read_artifact_section)
166
+ agent.tool_plain(write_artifact_section)
167
+
168
+ # Register codebase understanding tools (conditional)
215
169
  if load_codebase_understanding_tools:
216
170
  agent.tool(query_graph)
217
171
  agent.tool(retrieve_code)
@@ -222,10 +176,47 @@ def create_base_agent(
222
176
  else:
223
177
  logger.debug("🚫🧠 Codebase understanding tools not registered")
224
178
 
225
- logger.debug("✅ Agent creation complete")
179
+ logger.debug("✅ Agent creation complete with artifact and codebase tools")
226
180
  return agent, deps
227
181
 
228
182
 
183
+ def build_agent_system_prompt(
184
+ agent_type: str,
185
+ ctx: RunContext[AgentDeps],
186
+ context_name: str | None = None,
187
+ ) -> str:
188
+ """Build system prompt for any agent type.
189
+
190
+ Args:
191
+ agent_type: Type of agent ('research', 'plan', 'tasks')
192
+ ctx: RunContext containing AgentDeps
193
+ context_name: Optional context name for template rendering
194
+
195
+ Returns:
196
+ Rendered system prompt
197
+ """
198
+ prompt_loader = PromptLoader()
199
+
200
+ # Add logging if research agent
201
+ if agent_type == "research":
202
+ logger.debug("🔧 Building research agent system prompt...")
203
+ logger.debug("Interactive mode: %s", ctx.deps.interactive_mode)
204
+
205
+ result = prompt_loader.render(
206
+ f"agents/{agent_type}.j2",
207
+ interactive_mode=ctx.deps.interactive_mode,
208
+ mode=agent_type,
209
+ )
210
+
211
+ if agent_type == "research":
212
+ logger.debug(
213
+ "✅ Research system prompt built successfully (length: %d chars)",
214
+ len(result),
215
+ )
216
+
217
+ return result
218
+
219
+
229
220
  def create_usage_limits() -> UsageLimits:
230
221
  """Create reasonable usage limits for agent runs.
231
222
 
@@ -238,20 +229,41 @@ def create_usage_limits() -> UsageLimits:
238
229
  )
239
230
 
240
231
 
241
- def get_file_history(filename: str) -> str:
242
- """Get the history content from a file.
232
+ async def add_system_prompt_message(
233
+ deps: AgentDeps,
234
+ message_history: list[ModelMessage] | None = None,
235
+ ) -> list[ModelMessage]:
236
+ """Add the system prompt as the first message in the message history.
243
237
 
244
238
  Args:
245
- filename: Name of the file (e.g., "research.md")
239
+ deps: Agent dependencies containing system_prompt_fn
240
+ message_history: Existing message history
246
241
 
247
242
  Returns:
248
- File content or fallback message
243
+ Updated message history with system prompt prepended as first message
249
244
  """
250
- try:
251
- return read_file(filename)
252
- except Exception as e:
253
- logger.debug("Could not load %s history: %s", filename, str(e))
254
- return f"No {filename.replace('.md', '')} history available."
245
+ message_history = message_history or []
246
+
247
+ # Create a minimal RunContext to call the system prompt function
248
+ # We'll pass None for model and usage since they're not used by our system prompt functions
249
+ context = type(
250
+ "RunContext", (), {"deps": deps, "retry": 0, "model": None, "usage": None}
251
+ )()
252
+
253
+ # Render the system prompt using the stored function
254
+ system_prompt_content = deps.system_prompt_fn(context)
255
+ logger.debug(
256
+ "🎯 Rendered system prompt (length: %d chars)", len(system_prompt_content)
257
+ )
258
+
259
+ # Create system message and prepend to message history
260
+ system_message = ModelRequest(
261
+ parts=[SystemPromptPart(content=system_prompt_content)]
262
+ )
263
+ message_history.insert(0, system_message)
264
+ logger.debug("✅ System prompt prepended as first message")
265
+
266
+ return message_history
255
267
 
256
268
 
257
269
  async def run_agent(
@@ -261,6 +273,9 @@ async def run_agent(
261
273
  message_history: list[ModelMessage] | None = None,
262
274
  usage_limits: UsageLimits | None = None,
263
275
  ) -> AgentRunResult[str | DeferredToolRequests]:
276
+ # Add system prompt as first message
277
+ message_history = await add_system_prompt_message(deps, message_history)
278
+
264
279
  result = await agent.run(
265
280
  prompt,
266
281
  deps=deps,
shotgun/agents/config/constants.py ADDED
@@ -0,0 +1,18 @@
1
+ """Configuration constants for Shotgun agents."""
2
+
3
+ # Field names
4
+ API_KEY_FIELD = "api_key"
5
+ MODEL_NAME_FIELD = "model_name"
6
+ DEFAULT_PROVIDER_FIELD = "default_provider"
7
+ USER_ID_FIELD = "user_id"
8
+ CONFIG_VERSION_FIELD = "config_version"
9
+
10
+ # Provider names (for consistency with data dict keys)
11
+ OPENAI_PROVIDER = "openai"
12
+ ANTHROPIC_PROVIDER = "anthropic"
13
+ GOOGLE_PROVIDER = "google"
14
+
15
+ # Environment variable names
16
+ OPENAI_API_KEY_ENV = "OPENAI_API_KEY"
17
+ ANTHROPIC_API_KEY_ENV = "ANTHROPIC_API_KEY"
18
+ GEMINI_API_KEY_ENV = "GEMINI_API_KEY"
shotgun/agents/config/manager.py CHANGED
@@ -1,6 +1,7 @@
1
1
  """Configuration manager for Shotgun CLI."""
2
2
 
3
3
  import json
4
+ import os
4
5
  import uuid
5
6
  from pathlib import Path
6
7
  from typing import Any
@@ -10,6 +11,15 @@ from pydantic import SecretStr
10
11
  from shotgun.logging_config import get_logger
11
12
  from shotgun.utils import get_shotgun_home
12
13
 
14
+ from .constants import (
15
+ ANTHROPIC_API_KEY_ENV,
16
+ ANTHROPIC_PROVIDER,
17
+ API_KEY_FIELD,
18
+ GEMINI_API_KEY_ENV,
19
+ GOOGLE_PROVIDER,
20
+ OPENAI_API_KEY_ENV,
21
+ OPENAI_PROVIDER,
22
+ )
13
23
  from .models import ProviderType, ShotgunConfig
14
24
 
15
25
  logger = get_logger(__name__)
@@ -58,6 +68,22 @@ class ConfigManager:
58
68
 
59
69
  self._config = ShotgunConfig.model_validate(data)
60
70
  logger.debug("Configuration loaded successfully from %s", self.config_path)
71
+
72
+ # Check if the default provider has a key, if not find one that does
73
+ if not self.has_provider_key(self._config.default_provider):
74
+ original_default = self._config.default_provider
75
+ # Find first provider with a configured key
76
+ for provider in ProviderType:
77
+ if self.has_provider_key(provider):
78
+ logger.info(
79
+ "Default provider %s has no API key, updating to %s",
80
+ original_default.value,
81
+ provider.value,
82
+ )
83
+ self._config.default_provider = provider
84
+ self.save(self._config)
85
+ break
86
+
61
87
  return self._config
62
88
 
63
89
  except Exception as e:
@@ -114,17 +140,25 @@ class ConfigManager:
114
140
  provider_config = self._get_provider_config(config, provider_enum)
115
141
 
116
142
  # Only support api_key updates
117
- if "api_key" in kwargs:
118
- api_key_value = kwargs["api_key"]
143
+ if API_KEY_FIELD in kwargs:
144
+ api_key_value = kwargs[API_KEY_FIELD]
119
145
  provider_config.api_key = (
120
146
  SecretStr(api_key_value) if api_key_value is not None else None
121
147
  )
122
148
 
123
149
  # Reject other fields
124
- unsupported_fields = set(kwargs.keys()) - {"api_key"}
150
+ unsupported_fields = set(kwargs.keys()) - {API_KEY_FIELD}
125
151
  if unsupported_fields:
126
152
  raise ValueError(f"Unsupported configuration fields: {unsupported_fields}")
127
153
 
154
+ # If no other providers have keys configured and we just added one,
155
+ # set this provider as the default
156
+ if API_KEY_FIELD in kwargs and api_key_value is not None:
157
+ other_providers = [p for p in ProviderType if p != provider_enum]
158
+ has_other_keys = any(self.has_provider_key(p) for p in other_providers)
159
+ if not has_other_keys:
160
+ config.default_provider = provider_enum
161
+
128
162
  self.save(config)
129
163
 
130
164
  def clear_provider_key(self, provider: ProviderType | str) -> None:
@@ -136,11 +170,27 @@ class ConfigManager:
136
170
  self.save(config)
137
171
 
138
172
  def has_provider_key(self, provider: ProviderType | str) -> bool:
139
- """Check if the given provider has a non-empty API key configured."""
173
+ """Check if the given provider has a non-empty API key configured.
174
+
175
+ This checks both the configuration file and environment variables.
176
+ """
140
177
  config = self.load()
141
178
  provider_enum = self._ensure_provider_enum(provider)
142
179
  provider_config = self._get_provider_config(config, provider_enum)
143
- return self._provider_has_api_key(provider_config)
180
+
181
+ # Check config first
182
+ if self._provider_has_api_key(provider_config):
183
+ return True
184
+
185
+ # Check environment variable
186
+ if provider_enum == ProviderType.OPENAI:
187
+ return bool(os.getenv(OPENAI_API_KEY_ENV))
188
+ elif provider_enum == ProviderType.ANTHROPIC:
189
+ return bool(os.getenv(ANTHROPIC_API_KEY_ENV))
190
+ elif provider_enum == ProviderType.GOOGLE:
191
+ return bool(os.getenv(GEMINI_API_KEY_ENV))
192
+
193
+ return False
144
194
 
145
195
  def has_any_provider_key(self) -> bool:
146
196
  """Determine whether any provider has a configured API key."""
@@ -175,25 +225,27 @@ class ConfigManager:
175
225
 
176
226
  def _convert_secrets_to_secretstr(self, data: dict[str, Any]) -> None:
177
227
  """Convert plain text secrets in data to SecretStr objects."""
178
- for provider in ["openai", "anthropic", "google"]:
228
+ for provider in [OPENAI_PROVIDER, ANTHROPIC_PROVIDER, GOOGLE_PROVIDER]:
179
229
  if provider in data and isinstance(data[provider], dict):
180
230
  if (
181
- "api_key" in data[provider]
182
- and data[provider]["api_key"] is not None
231
+ API_KEY_FIELD in data[provider]
232
+ and data[provider][API_KEY_FIELD] is not None
183
233
  ):
184
- data[provider]["api_key"] = SecretStr(data[provider]["api_key"])
234
+ data[provider][API_KEY_FIELD] = SecretStr(
235
+ data[provider][API_KEY_FIELD]
236
+ )
185
237
 
186
238
  def _convert_secretstr_to_plain(self, data: dict[str, Any]) -> None:
187
239
  """Convert SecretStr objects in data to plain text for JSON serialization."""
188
- for provider in ["openai", "anthropic", "google"]:
240
+ for provider in [OPENAI_PROVIDER, ANTHROPIC_PROVIDER, GOOGLE_PROVIDER]:
189
241
  if provider in data and isinstance(data[provider], dict):
190
242
  if (
191
- "api_key" in data[provider]
192
- and data[provider]["api_key"] is not None
243
+ API_KEY_FIELD in data[provider]
244
+ and data[provider][API_KEY_FIELD] is not None
193
245
  ):
194
- if hasattr(data[provider]["api_key"], "get_secret_value"):
195
- data[provider]["api_key"] = data[provider][
196
- "api_key"
246
+ if hasattr(data[provider][API_KEY_FIELD], "get_secret_value"):
247
+ data[provider][API_KEY_FIELD] = data[provider][
248
+ API_KEY_FIELD
197
249
  ].get_secret_value()
198
250
 
199
251
  def _ensure_provider_enum(self, provider: ProviderType | str) -> ProviderType:
@@ -216,7 +268,7 @@ class ConfigManager:
216
268
 
217
269
  def _provider_has_api_key(self, provider_config: Any) -> bool:
218
270
  """Return True if the provider config contains a usable API key."""
219
- api_key = getattr(provider_config, "api_key", None)
271
+ api_key = getattr(provider_config, API_KEY_FIELD, None)
220
272
  if api_key is None:
221
273
  return False
222
274
 
shotgun/agents/config/provider.py CHANGED
@@ -13,6 +13,11 @@ from pydantic_ai.providers.openai import OpenAIProvider
13
13
 
14
14
  from shotgun.logging_config import get_logger
15
15
 
16
+ from .constants import (
17
+ ANTHROPIC_API_KEY_ENV,
18
+ GEMINI_API_KEY_ENV,
19
+ OPENAI_API_KEY_ENV,
20
+ )
16
21
  from .manager import get_config_manager
17
22
  from .models import MODEL_SPECS, ModelConfig, ProviderType
18
23
 
@@ -86,10 +91,10 @@ def get_provider_model(provider: ProviderType | None = None) -> ModelConfig:
86
91
  )
87
92
 
88
93
  if provider_enum == ProviderType.OPENAI:
89
- api_key = _get_api_key(config.openai.api_key, "OPENAI_API_KEY")
94
+ api_key = _get_api_key(config.openai.api_key, OPENAI_API_KEY_ENV)
90
95
  if not api_key:
91
96
  raise ValueError(
92
- "OpenAI API key not configured. Set via environment variable OPENAI_API_KEY or config."
97
+ f"OpenAI API key not configured. Set via environment variable {OPENAI_API_KEY_ENV} or config."
93
98
  )
94
99
 
95
100
  # Get model spec
@@ -108,10 +113,10 @@ def get_provider_model(provider: ProviderType | None = None) -> ModelConfig:
108
113
  )
109
114
 
110
115
  elif provider_enum == ProviderType.ANTHROPIC:
111
- api_key = _get_api_key(config.anthropic.api_key, "ANTHROPIC_API_KEY")
116
+ api_key = _get_api_key(config.anthropic.api_key, ANTHROPIC_API_KEY_ENV)
112
117
  if not api_key:
113
118
  raise ValueError(
114
- "Anthropic API key not configured. Set via environment variable ANTHROPIC_API_KEY or config."
119
+ f"Anthropic API key not configured. Set via environment variable {ANTHROPIC_API_KEY_ENV} or config."
115
120
  )
116
121
 
117
122
  # Get model spec
@@ -130,10 +135,10 @@ def get_provider_model(provider: ProviderType | None = None) -> ModelConfig:
130
135
  )
131
136
 
132
137
  elif provider_enum == ProviderType.GOOGLE:
133
- api_key = _get_api_key(config.google.api_key, "GEMINI_API_KEY")
138
+ api_key = _get_api_key(config.google.api_key, GEMINI_API_KEY_ENV)
134
139
  if not api_key:
135
140
  raise ValueError(
136
- "Gemini API key not configured. Set via environment variable GEMINI_API_KEY or config."
141
+ f"Gemini API key not configured. Set via environment variable {GEMINI_API_KEY_ENV} or config."
137
142
  )
138
143
 
139
144
  # Get model spec
shotgun/agents/models.py CHANGED
@@ -1,10 +1,12 @@
1
1
  """Pydantic models for agent dependencies and configuration."""
2
2
 
3
3
  from asyncio import Future, Queue
4
+ from collections.abc import Callable
4
5
  from pathlib import Path
5
6
  from typing import TYPE_CHECKING
6
7
 
7
8
  from pydantic import BaseModel, ConfigDict, Field
9
+ from pydantic_ai import RunContext
8
10
 
9
11
  from .config.models import ModelConfig
10
12
 
@@ -83,6 +85,10 @@ class AgentDeps(AgentRuntimeOptions):
83
85
  description="Codebase service for code analysis tools",
84
86
  )
85
87
 
88
+ system_prompt_fn: Callable[[RunContext["AgentDeps"]], str] = Field(
89
+ description="Function that generates the system prompt for this agent",
90
+ )
91
+
86
92
 
87
93
  # Rebuild model to resolve forward references after imports are available
88
94
  try:
shotgun/agents/plan.py CHANGED
@@ -1,51 +1,33 @@
1
1
  """Plan agent factory and functions using Pydantic AI with file-based memory."""
2
2
 
3
+ from functools import partial
4
+
3
5
  from pydantic_ai import (
4
6
  Agent,
5
7
  DeferredToolRequests,
6
- RunContext,
7
8
  )
8
9
  from pydantic_ai.agent import AgentRunResult
9
10
  from pydantic_ai.messages import ModelMessage
10
11
 
11
12
  from shotgun.agents.config import ProviderType
12
13
  from shotgun.logging_config import get_logger
13
- from shotgun.prompts import PromptLoader
14
14
 
15
15
  from .common import (
16
16
  add_system_status_message,
17
+ build_agent_system_prompt,
17
18
  create_base_agent,
18
19
  create_usage_limits,
19
- ensure_file_exists,
20
- get_file_history,
21
20
  run_agent,
22
21
  )
23
22
  from .models import AgentDeps, AgentRuntimeOptions
24
23
 
25
24
  logger = get_logger(__name__)
26
25
 
27
- # Global prompt loader instance
28
- prompt_loader = PromptLoader()
29
-
30
-
31
- def _build_plan_agent_system_prompt(ctx: RunContext[AgentDeps]) -> str:
32
- """Build the system prompt for the plan agent.
33
-
34
- Args:
35
- ctx: RunContext containing AgentDeps with interactive_mode and other settings
36
-
37
- Returns:
38
- The complete system prompt string for the plan agent
39
- """
40
- return prompt_loader.render(
41
- "agents/plan.j2", interactive_mode=ctx.deps.interactive_mode, context="plans"
42
- )
43
-
44
26
 
45
27
  def create_plan_agent(
46
28
  agent_runtime_options: AgentRuntimeOptions, provider: ProviderType | None = None
47
29
  ) -> tuple[Agent[AgentDeps, str | DeferredToolRequests], AgentDeps]:
48
- """Create a plan agent with file management capabilities.
30
+ """Create a plan agent with artifact management capabilities.
49
31
 
50
32
  Args:
51
33
  agent_runtime_options: Agent runtime options for the agent
@@ -55,8 +37,15 @@ def create_plan_agent(
55
37
  Tuple of (Configured Pydantic AI agent for planning tasks, Agent dependencies)
56
38
  """
57
39
  logger.debug("Initializing plan agent")
40
+ # Use partial to create system prompt function for plan agent
41
+ system_prompt_fn = partial(build_agent_system_prompt, "plan")
42
+
58
43
  agent, deps = create_base_agent(
59
- _build_plan_agent_system_prompt, agent_runtime_options, provider=provider
44
+ system_prompt_fn,
45
+ agent_runtime_options,
46
+ load_codebase_understanding_tools=True,
47
+ additional_tools=None,
48
+ provider=provider,
60
49
  )
61
50
  return agent, deps
62
51
 
@@ -67,7 +56,7 @@ async def run_plan_agent(
67
56
  deps: AgentDeps,
68
57
  message_history: list[ModelMessage] | None = None,
69
58
  ) -> AgentRunResult[str | DeferredToolRequests]:
70
- """Create or update a plan based on the given goal.
59
+ """Create or update a plan based on the given goal using artifacts.
71
60
 
72
61
  Args:
73
62
  agent: The configured plan agent
@@ -80,11 +69,9 @@ async def run_plan_agent(
80
69
  """
81
70
  logger.debug("📋 Starting planning for goal: %s", goal)
82
71
 
83
- # Ensure plan.md exists
84
- ensure_file_exists("plan.md", "# Plan")
85
-
86
- # Let the agent use its tools to read existing plan and research
72
+ # Simple prompt - the agent system prompt has all the artifact instructions
87
73
  full_prompt = f"Create a comprehensive plan for: {goal}"
74
+
88
75
  try:
89
76
  # Create usage limits for responsible API usage
90
77
  usage_limits = create_usage_limits()
@@ -108,12 +95,3 @@ async def run_plan_agent(
108
95
  logger.error("Full traceback:\n%s", traceback.format_exc())
109
96
  logger.error("❌ Planning failed: %s", str(e))
110
97
  raise
111
-
112
-
113
- def get_plan_history() -> str:
114
- """Get the full plan history from the file.
115
-
116
- Returns:
117
- Plan history content or fallback message
118
- """
119
- return get_file_history("plan.md")