shotgun-sh 0.1.0.dev12__py3-none-any.whl → 0.1.0.dev14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of shotgun-sh has been flagged by the registry.
- shotgun/agents/agent_manager.py +16 -3
- shotgun/agents/artifact_state.py +58 -0
- shotgun/agents/common.py +137 -88
- shotgun/agents/config/constants.py +18 -0
- shotgun/agents/config/manager.py +68 -16
- shotgun/agents/config/models.py +61 -0
- shotgun/agents/config/provider.py +11 -6
- shotgun/agents/history/compaction.py +85 -0
- shotgun/agents/history/constants.py +19 -0
- shotgun/agents/history/context_extraction.py +108 -0
- shotgun/agents/history/history_building.py +104 -0
- shotgun/agents/history/history_processors.py +354 -157
- shotgun/agents/history/message_utils.py +46 -0
- shotgun/agents/history/token_counting.py +429 -0
- shotgun/agents/history/token_estimation.py +138 -0
- shotgun/agents/models.py +131 -1
- shotgun/agents/plan.py +15 -37
- shotgun/agents/research.py +10 -45
- shotgun/agents/specify.py +97 -0
- shotgun/agents/tasks.py +7 -36
- shotgun/agents/tools/artifact_management.py +482 -0
- shotgun/agents/tools/file_management.py +31 -12
- shotgun/agents/tools/web_search/anthropic.py +78 -17
- shotgun/agents/tools/web_search/gemini.py +1 -1
- shotgun/agents/tools/web_search/openai.py +16 -2
- shotgun/artifacts/__init__.py +17 -0
- shotgun/artifacts/exceptions.py +89 -0
- shotgun/artifacts/manager.py +530 -0
- shotgun/artifacts/models.py +334 -0
- shotgun/artifacts/service.py +463 -0
- shotgun/artifacts/templates/__init__.py +10 -0
- shotgun/artifacts/templates/loader.py +252 -0
- shotgun/artifacts/templates/models.py +136 -0
- shotgun/artifacts/templates/plan/delivery_and_release_plan.yaml +66 -0
- shotgun/artifacts/templates/research/market_research.yaml +585 -0
- shotgun/artifacts/templates/research/sdk_comparison.yaml +257 -0
- shotgun/artifacts/templates/specify/prd.yaml +331 -0
- shotgun/artifacts/templates/specify/product_spec.yaml +301 -0
- shotgun/artifacts/utils.py +76 -0
- shotgun/cli/plan.py +1 -4
- shotgun/cli/specify.py +69 -0
- shotgun/cli/tasks.py +0 -4
- shotgun/codebase/core/nl_query.py +4 -4
- shotgun/logging_config.py +23 -7
- shotgun/main.py +7 -6
- shotgun/prompts/agents/partials/artifact_system.j2 +35 -0
- shotgun/prompts/agents/partials/codebase_understanding.j2 +1 -2
- shotgun/prompts/agents/partials/common_agent_system_prompt.j2 +28 -2
- shotgun/prompts/agents/partials/content_formatting.j2 +65 -0
- shotgun/prompts/agents/partials/interactive_mode.j2 +10 -2
- shotgun/prompts/agents/plan.j2 +33 -32
- shotgun/prompts/agents/research.j2 +39 -29
- shotgun/prompts/agents/specify.j2 +32 -0
- shotgun/prompts/agents/state/artifact_templates_available.j2 +18 -0
- shotgun/prompts/agents/state/codebase/codebase_graphs_available.j2 +3 -1
- shotgun/prompts/agents/state/existing_artifacts_available.j2 +23 -0
- shotgun/prompts/agents/state/system_state.j2 +9 -1
- shotgun/prompts/agents/tasks.j2 +27 -12
- shotgun/prompts/history/incremental_summarization.j2 +53 -0
- shotgun/sdk/artifact_models.py +186 -0
- shotgun/sdk/artifacts.py +448 -0
- shotgun/sdk/services.py +14 -0
- shotgun/tui/app.py +26 -7
- shotgun/tui/screens/chat.py +32 -5
- shotgun/tui/screens/directory_setup.py +113 -0
- shotgun/utils/file_system_utils.py +6 -1
- {shotgun_sh-0.1.0.dev12.dist-info → shotgun_sh-0.1.0.dev14.dist-info}/METADATA +3 -2
- shotgun_sh-0.1.0.dev14.dist-info/RECORD +138 -0
- shotgun/prompts/user/research.j2 +0 -5
- shotgun_sh-0.1.0.dev12.dist-info/RECORD +0 -104
- {shotgun_sh-0.1.0.dev12.dist-info → shotgun_sh-0.1.0.dev14.dist-info}/WHEEL +0 -0
- {shotgun_sh-0.1.0.dev12.dist-info → shotgun_sh-0.1.0.dev14.dist-info}/entry_points.txt +0 -0
- {shotgun_sh-0.1.0.dev12.dist-info → shotgun_sh-0.1.0.dev14.dist-info}/licenses/LICENSE +0 -0
shotgun/agents/agent_manager.py
CHANGED
@@ -3,13 +3,19 @@
 from enum import Enum
 from typing import Any
 
-from pydantic_ai import
+from pydantic_ai import (
+    Agent,
+    DeferredToolRequests,
+    DeferredToolResults,
+    UsageLimits,
+)
 from pydantic_ai.agent import AgentRunResult
 from pydantic_ai.messages import ModelMessage, ModelRequest
 from textual.message import Message
 from textual.widget import Widget
 
-from .
+from .history.compaction import apply_persistent_compaction
+from .models import AgentDeps, AgentRuntimeOptions, FileOperation
 from .plan import create_plan_agent
 from .research import create_research_agent
 from .tasks import create_tasks_agent
@@ -84,6 +90,7 @@ class AgentManager(Widget):
         # Maintain shared message history
         self.ui_message_history: list[ModelMessage] = []
         self.message_history: list[ModelMessage] = []
+        self.recently_change_files: list[FileOperation] = []
 
     @property
     def current_agent(self) -> Agent[AgentDeps, str | DeferredToolRequests]:
@@ -181,9 +188,15 @@ class AgentManager(Widget):
             mes for mes in result.new_messages() if not isinstance(mes, ModelRequest)
         ]
 
-
+        # Apply compaction to persistent message history to prevent cascading growth
+        self.message_history = await apply_persistent_compaction(
+            result.all_messages(), deps
+        )
         self._post_messages_updated()
 
+        # Log file operations summary if any files were modified
+        self.recently_change_files = deps.file_tracker.operations.copy()
+
         return result
 
     def _post_messages_updated(self) -> None:
shotgun/agents/artifact_state.py
ADDED
@@ -0,0 +1,58 @@
+"""Utilities for collecting and organizing artifact state information."""
+
+from datetime import datetime
+from typing import TypedDict
+
+from shotgun.artifacts.models import ArtifactSummary
+from shotgun.artifacts.templates.models import TemplateSummary
+from shotgun.sdk.services import get_artifact_service
+
+
+class ArtifactState(TypedDict):
+    """Type definition for artifact state information."""
+
+    available_templates: dict[str, list[TemplateSummary]]
+    existing_artifacts: dict[str, list[ArtifactSummary]]
+    current_date: str
+
+
+def collect_artifact_state() -> ArtifactState:
+    """Collect and organize artifact state information for system context.
+
+    Returns:
+        ArtifactState containing organized templates and artifacts by mode, plus current date
+    """
+    artifact_service = get_artifact_service()
+
+    # Get available templates
+    available_templates_list = artifact_service.list_templates()
+
+    # Group templates by mode for better organization
+    templates_by_mode: dict[str, list[TemplateSummary]] = {}
+    for template in available_templates_list:
+        mode_name = template.template_id.split("/")[0]
+        if mode_name not in templates_by_mode:
+            templates_by_mode[mode_name] = []
+        templates_by_mode[mode_name].append(template)
+
+    # Get ALL existing artifacts regardless of current agent mode for complete visibility
+    existing_artifacts_list = (
+        artifact_service.list_artifacts()
+    )  # No mode filter = all modes
+
+    # Group artifacts by mode for organized display
+    artifacts_by_mode: dict[str, list[ArtifactSummary]] = {}
+    for artifact in existing_artifacts_list:
+        mode_name = artifact.agent_mode.value
+        if mode_name not in artifacts_by_mode:
+            artifacts_by_mode[mode_name] = []
+        artifacts_by_mode[mode_name].append(artifact)
+
+    # Get current date for temporal context (month in words for clarity)
+    current_date = datetime.now().strftime("%B %d, %Y")
+
+    return {
+        "available_templates": templates_by_mode,
+        "existing_artifacts": artifacts_by_mode,
+        "current_date": current_date,
+    }
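The new collect_artifact_state() helper returns a plain dict that common.py later unpacks into the system-state template with **artifact_state. As a rough, hypothetical usage sketch (the template string below is made up for illustration, and a configured artifact service is assumed):

```python
# Hypothetical sketch: rendering the ArtifactState dict into a prompt fragment.
# Uses a plain Jinja2 Template instead of the project's PromptLoader; the
# template text is illustrative, not one of the package's .j2 files.
from jinja2 import Template

from shotgun.agents.artifact_state import collect_artifact_state

state = collect_artifact_state()

template = Template(
    "Today is {{ current_date }}.\n"
    "{% for mode, items in existing_artifacts.items() %}"
    "{{ mode }}: {{ items | length }} artifact(s)\n"
    "{% endfor %}"
)
# Same **state unpacking that common.py passes to prompt_loader.render().
print(template.render(**state))
```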
shotgun/agents/common.py
CHANGED
@@ -2,7 +2,6 @@
 
 import asyncio
 from collections.abc import Callable
-from pathlib import Path
 from typing import Any
 
 from pydantic_ai import (
@@ -15,17 +14,20 @@ from pydantic_ai import (
 from pydantic_ai.agent import AgentRunResult
 from pydantic_ai.messages import (
     ModelMessage,
+    ModelRequest,
     ModelResponse,
+    SystemPromptPart,
     TextPart,
 )
 
 from shotgun.agents.config import ProviderType, get_config_manager, get_provider_model
 from shotgun.logging_config import get_logger
 from shotgun.prompts import PromptLoader
-from shotgun.sdk.services import get_codebase_service
+from shotgun.sdk.services import get_artifact_service, get_codebase_service
 from shotgun.utils import ensure_shotgun_directory_exists
 
 from .history import token_limit_compactor
+from .history.compaction import apply_persistent_compaction
 from .models import AgentDeps, AgentRuntimeOptions
 from .tools import (
     append_file,
@@ -38,6 +40,14 @@ from .tools import (
     retrieve_code,
     write_file,
 )
+from .tools.artifact_management import (
+    create_artifact,
+    list_artifact_templates,
+    list_artifacts,
+    read_artifact,
+    read_artifact_section,
+    write_artifact_section,
+)
 
 logger = get_logger(__name__)
 
@@ -45,70 +55,6 @@ logger = get_logger(__name__)
 prompt_loader = PromptLoader()
 
 
-def ensure_file_exists(filename: str, header: str) -> str:
-    """Ensure a markdown file exists with proper header and return its content.
-
-    Args:
-        filename: Name of the file (e.g., "research.md")
-        header: Header to add if file is empty (e.g., "# Research")
-
-    Returns:
-        Current file content
-    """
-    shotgun_dir = Path.cwd() / ".shotgun"
-    file_path = shotgun_dir / filename
-
-    try:
-        if file_path.exists():
-            content = file_path.read_text(encoding="utf-8")
-            if not content.strip():
-                # File exists but is empty, add header
-                header_content = f"{header}\n\n"
-                file_path.write_text(header_content, encoding="utf-8")
-                return header_content
-            return content
-        else:
-            # File doesn't exist, create it with header
-            shotgun_dir.mkdir(exist_ok=True)
-            header_content = f"{header}\n\n"
-            file_path.write_text(header_content, encoding="utf-8")
-            return header_content
-    except Exception as e:
-        logger.error("Failed to initialize %s: %s", filename, str(e))
-        return f"{header}\n\n"
-
-
-def register_common_tools(
-    agent: Agent[AgentDeps], additional_tools: list[Any], interactive_mode: bool
-) -> None:
-    """Register common tools with an agent.
-
-    Args:
-        agent: The Pydantic AI agent to register tools with
-        additional_tools: List of additional tools specific to this agent
-        interactive_mode: Whether to register interactive tools
-    """
-    logger.debug("📌 Registering tools with agent")
-
-    # Register additional tools first (agent-specific)
-    for tool in additional_tools:
-        agent.tool_plain(tool)
-
-    # Register interactive tool if enabled
-    if interactive_mode:
-        agent.tool(ask_user)
-        logger.debug("📞 User interaction tool registered")
-    else:
-        logger.debug("🚫 User interaction disabled (non-interactive mode)")
-
-    # Register common file management tools
-    agent.tool_plain(read_file)
-    agent.tool_plain(write_file)
-    agent.tool_plain(append_file)
-
-    logger.debug("✅ Tool registration complete")
-
-
 async def add_system_status_message(
     deps: AgentDeps,
     message_history: list[ModelMessage] | None = None,
@@ -125,11 +71,17 @@ async def add_system_status_message(
     message_history = message_history or []
     codebase_understanding_graphs = await deps.codebase_service.list_graphs()
 
+    # Collect artifact state information
+    from .artifact_state import collect_artifact_state
+
+    artifact_state = collect_artifact_state()
+
     system_state = prompt_loader.render(
         "agents/state/system_state.j2",
         codebase_understanding_graphs=codebase_understanding_graphs,
-
+        **artifact_state,
     )
+
     message_history.append(
         ModelResponse(
             parts=[
@@ -173,12 +125,15 @@ def create_base_agent(
         # Use the Model instance directly (has API key baked in)
         model = model_config.model_instance
 
-        # Create deps with model config and
+        # Create deps with model config and services
        codebase_service = get_codebase_service()
+        artifact_service = get_artifact_service()
         deps = AgentDeps(
             **agent_runtime_options.model_dump(),
             llm_model=model_config,
             codebase_service=codebase_service,
+            artifact_service=artifact_service,
+            system_prompt_fn=system_prompt_fn,
         )
 
     except Exception as e:
@@ -186,16 +141,30 @@
         logger.debug("🤖 Creating agent with fallback OpenAI GPT-4o")
         raise ValueError("Configured model is required") from e
 
+    # Create a history processor that has access to deps via closure
+    async def history_processor(messages: list[ModelMessage]) -> list[ModelMessage]:
+        """History processor with access to deps via closure."""
+
+        # Create a minimal context for compaction
+        class ProcessorContext:
+            def __init__(self, deps: AgentDeps):
+                self.deps = deps
+                self.usage = None  # Will be estimated from messages
+
+        ctx = ProcessorContext(deps)
+        return await token_limit_compactor(ctx, messages)
+
     agent = Agent(
         model,
         output_type=[str, DeferredToolRequests],
         deps_type=AgentDeps,
         instrument=True,
-        history_processors=[
+        history_processors=[history_processor],
     )
 
-    #
-
+    # System prompt function is stored in deps and will be called manually in run_agent
+    func_name = getattr(system_prompt_fn, "__name__", str(system_prompt_fn))
+    logger.debug("🔧 System prompt function stored: %s", func_name)
 
     # Register additional tools first (agent-specific)
     for tool in additional_tools or []:
@@ -207,11 +176,19 @@
         logger.debug("📞 Interactive mode enabled - ask_user tool registered")
 
     # Register common file management tools (always available)
-    agent.
-    agent.
-    agent.
-
-    # Register
+    agent.tool(read_file)
+    agent.tool(write_file)
+    agent.tool(append_file)
+
+    # Register artifact management tools (always available)
+    agent.tool(create_artifact)
+    agent.tool(list_artifacts)
+    agent.tool(list_artifact_templates)
+    agent.tool(read_artifact)
+    agent.tool(read_artifact_section)
+    agent.tool(write_artifact_section)
+
+    # Register codebase understanding tools (conditional)
     if load_codebase_understanding_tools:
         agent.tool(query_graph)
         agent.tool(retrieve_code)
@@ -222,10 +199,47 @@
     else:
         logger.debug("🚫🧠 Codebase understanding tools not registered")
 
-    logger.debug("✅ Agent creation complete")
+    logger.debug("✅ Agent creation complete with artifact and codebase tools")
     return agent, deps
 
 
+def build_agent_system_prompt(
+    agent_type: str,
+    ctx: RunContext[AgentDeps],
+    context_name: str | None = None,
+) -> str:
+    """Build system prompt for any agent type.
+
+    Args:
+        agent_type: Type of agent ('research', 'plan', 'tasks')
+        ctx: RunContext containing AgentDeps
+        context_name: Optional context name for template rendering
+
+    Returns:
+        Rendered system prompt
+    """
+    prompt_loader = PromptLoader()
+
+    # Add logging if research agent
+    if agent_type == "research":
+        logger.debug("🔧 Building research agent system prompt...")
+        logger.debug("Interactive mode: %s", ctx.deps.interactive_mode)
+
+    result = prompt_loader.render(
+        f"agents/{agent_type}.j2",
+        interactive_mode=ctx.deps.interactive_mode,
+        mode=agent_type,
+    )
+
+    if agent_type == "research":
+        logger.debug(
+            "✅ Research system prompt built successfully (length: %d chars)",
+            len(result),
+        )
+
+    return result
+
+
 def create_usage_limits() -> UsageLimits:
     """Create reasonable usage limits for agent runs.
 
@@ -238,20 +252,41 @@
     )
 
 
-def
-
+async def add_system_prompt_message(
+    deps: AgentDeps,
+    message_history: list[ModelMessage] | None = None,
+) -> list[ModelMessage]:
+    """Add the system prompt as the first message in the message history.
 
     Args:
-
+        deps: Agent dependencies containing system_prompt_fn
+        message_history: Existing message history
 
     Returns:
-
+        Updated message history with system prompt prepended as first message
     """
-
-
-
-
-
+    message_history = message_history or []
+
+    # Create a minimal RunContext to call the system prompt function
+    # We'll pass None for model and usage since they're not used by our system prompt functions
+    context = type(
+        "RunContext", (), {"deps": deps, "retry": 0, "model": None, "usage": None}
+    )()
+
+    # Render the system prompt using the stored function
+    system_prompt_content = deps.system_prompt_fn(context)
+    logger.debug(
+        "🎯 Rendered system prompt (length: %d chars)", len(system_prompt_content)
+    )
+
+    # Create system message and prepend to message history
+    system_message = ModelRequest(
+        parts=[SystemPromptPart(content=system_prompt_content)]
+    )
+    message_history.insert(0, system_message)
+    logger.debug("✅ System prompt prepended as first message")
+
+    return message_history
 
 
 async def run_agent(
@@ -261,6 +296,13 @@
     message_history: list[ModelMessage] | None = None,
     usage_limits: UsageLimits | None = None,
 ) -> AgentRunResult[str | DeferredToolRequests]:
+    # Clear file tracker for new run
+    deps.file_tracker.clear()
+    logger.debug("🔧 Cleared file tracker for new agent run")
+
+    # Add system prompt as first message
+    message_history = await add_system_prompt_message(deps, message_history)
+
     result = await agent.run(
         prompt,
         deps=deps,
@@ -268,7 +310,8 @@
         message_history=message_history,
     )
 
-
+    # Apply persistent compaction to prevent cascading token growth across CLI commands
+    messages = await apply_persistent_compaction(result.all_messages(), deps)
     while isinstance(result.output, DeferredToolRequests):
         logger.info("got deferred tool requests")
         await deps.queue.join()
@@ -291,6 +334,12 @@
             message_history=messages,
             deferred_tool_results=results,
         )
-
+        # Apply persistent compaction to prevent cascading token growth in multi-turn loops
+        messages = await apply_persistent_compaction(result.all_messages(), deps)
+
+    # Log file operations summary if any files were modified
+    if deps.file_tracker.operations:
+        summary = deps.file_tracker.format_summary()
+        logger.info("📁 %s", summary)
 
     return result
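The history_processor closure in create_base_agent is the notable pattern here: the processor only receives the message list, so deps is captured by closure rather than passed as an argument. A minimal, dependency-free sketch of the same idea (the names and the "compaction" logic below are illustrative stand-ins, not the package's code):

```python
# Dependency-free sketch of the closure pattern used for history_processor.
# "Deps", "max_messages", and the trimming rule are illustrative stand-ins.
from dataclasses import dataclass


@dataclass
class Deps:
    max_messages: int  # hypothetical budget carried by the dependencies


def make_processor(deps: Deps):
    def processor(messages: list[str]) -> list[str]:
        # The framework only hands us `messages`; `deps` arrives via closure.
        # Crude stand-in for compaction: keep only the most recent messages.
        return messages[-deps.max_messages:]

    return processor


process = make_processor(Deps(max_messages=2))
print(process(["a", "b", "c", "d"]))  # ['c', 'd']
```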
shotgun/agents/config/constants.py
ADDED
@@ -0,0 +1,18 @@
+"""Configuration constants for Shotgun agents."""
+
+# Field names
+API_KEY_FIELD = "api_key"
+MODEL_NAME_FIELD = "model_name"
+DEFAULT_PROVIDER_FIELD = "default_provider"
+USER_ID_FIELD = "user_id"
+CONFIG_VERSION_FIELD = "config_version"
+
+# Provider names (for consistency with data dict keys)
+OPENAI_PROVIDER = "openai"
+ANTHROPIC_PROVIDER = "anthropic"
+GOOGLE_PROVIDER = "google"
+
+# Environment variable names
+OPENAI_API_KEY_ENV = "OPENAI_API_KEY"
+ANTHROPIC_API_KEY_ENV = "ANTHROPIC_API_KEY"
+GEMINI_API_KEY_ENV = "GEMINI_API_KEY"
shotgun/agents/config/manager.py
CHANGED
@@ -1,6 +1,7 @@
 """Configuration manager for Shotgun CLI."""
 
 import json
+import os
 import uuid
 from pathlib import Path
 from typing import Any
@@ -10,6 +11,15 @@ from pydantic import SecretStr
 from shotgun.logging_config import get_logger
 from shotgun.utils import get_shotgun_home
 
+from .constants import (
+    ANTHROPIC_API_KEY_ENV,
+    ANTHROPIC_PROVIDER,
+    API_KEY_FIELD,
+    GEMINI_API_KEY_ENV,
+    GOOGLE_PROVIDER,
+    OPENAI_API_KEY_ENV,
+    OPENAI_PROVIDER,
+)
 from .models import ProviderType, ShotgunConfig
 
 logger = get_logger(__name__)
@@ -58,6 +68,22 @@ class ConfigManager:
 
             self._config = ShotgunConfig.model_validate(data)
             logger.debug("Configuration loaded successfully from %s", self.config_path)
+
+            # Check if the default provider has a key, if not find one that does
+            if not self.has_provider_key(self._config.default_provider):
+                original_default = self._config.default_provider
+                # Find first provider with a configured key
+                for provider in ProviderType:
+                    if self.has_provider_key(provider):
+                        logger.info(
+                            "Default provider %s has no API key, updating to %s",
+                            original_default.value,
+                            provider.value,
+                        )
+                        self._config.default_provider = provider
+                        self.save(self._config)
+                        break
+
             return self._config
 
         except Exception as e:
@@ -114,17 +140,25 @@
         provider_config = self._get_provider_config(config, provider_enum)
 
         # Only support api_key updates
-        if
-            api_key_value = kwargs[
+        if API_KEY_FIELD in kwargs:
+            api_key_value = kwargs[API_KEY_FIELD]
             provider_config.api_key = (
                 SecretStr(api_key_value) if api_key_value is not None else None
            )
 
         # Reject other fields
-        unsupported_fields = set(kwargs.keys()) - {
+        unsupported_fields = set(kwargs.keys()) - {API_KEY_FIELD}
         if unsupported_fields:
             raise ValueError(f"Unsupported configuration fields: {unsupported_fields}")
 
+        # If no other providers have keys configured and we just added one,
+        # set this provider as the default
+        if API_KEY_FIELD in kwargs and api_key_value is not None:
+            other_providers = [p for p in ProviderType if p != provider_enum]
+            has_other_keys = any(self.has_provider_key(p) for p in other_providers)
+            if not has_other_keys:
+                config.default_provider = provider_enum
+
         self.save(config)
 
     def clear_provider_key(self, provider: ProviderType | str) -> None:
@@ -136,11 +170,27 @@
         self.save(config)
 
     def has_provider_key(self, provider: ProviderType | str) -> bool:
-        """Check if the given provider has a non-empty API key configured.
+        """Check if the given provider has a non-empty API key configured.
+
+        This checks both the configuration file and environment variables.
+        """
         config = self.load()
         provider_enum = self._ensure_provider_enum(provider)
         provider_config = self._get_provider_config(config, provider_enum)
-
+
+        # Check config first
+        if self._provider_has_api_key(provider_config):
+            return True
+
+        # Check environment variable
+        if provider_enum == ProviderType.OPENAI:
+            return bool(os.getenv(OPENAI_API_KEY_ENV))
+        elif provider_enum == ProviderType.ANTHROPIC:
+            return bool(os.getenv(ANTHROPIC_API_KEY_ENV))
+        elif provider_enum == ProviderType.GOOGLE:
+            return bool(os.getenv(GEMINI_API_KEY_ENV))
+
+        return False
 
     def has_any_provider_key(self) -> bool:
         """Determine whether any provider has a configured API key."""
@@ -175,25 +225,27 @@
 
     def _convert_secrets_to_secretstr(self, data: dict[str, Any]) -> None:
         """Convert plain text secrets in data to SecretStr objects."""
-        for provider in [
+        for provider in [OPENAI_PROVIDER, ANTHROPIC_PROVIDER, GOOGLE_PROVIDER]:
             if provider in data and isinstance(data[provider], dict):
                 if (
-
-                    and data[provider][
+                    API_KEY_FIELD in data[provider]
+                    and data[provider][API_KEY_FIELD] is not None
                 ):
-                    data[provider][
+                    data[provider][API_KEY_FIELD] = SecretStr(
+                        data[provider][API_KEY_FIELD]
+                    )
 
     def _convert_secretstr_to_plain(self, data: dict[str, Any]) -> None:
         """Convert SecretStr objects in data to plain text for JSON serialization."""
-        for provider in [
+        for provider in [OPENAI_PROVIDER, ANTHROPIC_PROVIDER, GOOGLE_PROVIDER]:
             if provider in data and isinstance(data[provider], dict):
                 if (
-
-                    and data[provider][
+                    API_KEY_FIELD in data[provider]
+                    and data[provider][API_KEY_FIELD] is not None
                 ):
-                    if hasattr(data[provider][
-                        data[provider][
-
+                    if hasattr(data[provider][API_KEY_FIELD], "get_secret_value"):
+                        data[provider][API_KEY_FIELD] = data[provider][
+                            API_KEY_FIELD
                         ].get_secret_value()
 
     def _ensure_provider_enum(self, provider: ProviderType | str) -> ProviderType:
@@ -216,7 +268,7 @@
 
     def _provider_has_api_key(self, provider_config: Any) -> bool:
         """Return True if the provider config contains a usable API key."""
-        api_key = getattr(provider_config,
+        api_key = getattr(provider_config, API_KEY_FIELD, None)
         if api_key is None:
             return False
 
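With this change, has_provider_key treats a key that exists only in the environment as configured. A rough standalone sketch of that resolution order (the config_keys dict is a placeholder for the parsed config file, not the real ShotgunConfig):

```python
# Standalone sketch of the config-then-environment key resolution order.
# `config_keys` stands in for the parsed config file; the env var names
# match the constants introduced in constants.py.
import os

ENV_VARS = {
    "openai": "OPENAI_API_KEY",
    "anthropic": "ANTHROPIC_API_KEY",
    "google": "GEMINI_API_KEY",
}


def has_provider_key(provider: str, config_keys: dict[str, str | None]) -> bool:
    # 1. A non-empty key in the config file wins.
    if config_keys.get(provider):
        return True
    # 2. Otherwise fall back to the provider's environment variable.
    env_var = ENV_VARS.get(provider)
    return bool(env_var and os.getenv(env_var))


# True only if OPENAI_API_KEY is set in the environment.
print(has_provider_key("openai", {"openai": None}))
```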