shotgun-sh 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of shotgun-sh might be problematic; see the package's registry page for more details.
- shotgun/__init__.py +5 -0
- shotgun/agents/__init__.py +1 -0
- shotgun/agents/agent_manager.py +651 -0
- shotgun/agents/common.py +549 -0
- shotgun/agents/config/__init__.py +13 -0
- shotgun/agents/config/constants.py +17 -0
- shotgun/agents/config/manager.py +294 -0
- shotgun/agents/config/models.py +185 -0
- shotgun/agents/config/provider.py +206 -0
- shotgun/agents/conversation_history.py +106 -0
- shotgun/agents/conversation_manager.py +105 -0
- shotgun/agents/export.py +96 -0
- shotgun/agents/history/__init__.py +5 -0
- shotgun/agents/history/compaction.py +85 -0
- shotgun/agents/history/constants.py +19 -0
- shotgun/agents/history/context_extraction.py +108 -0
- shotgun/agents/history/history_building.py +104 -0
- shotgun/agents/history/history_processors.py +426 -0
- shotgun/agents/history/message_utils.py +84 -0
- shotgun/agents/history/token_counting.py +429 -0
- shotgun/agents/history/token_estimation.py +138 -0
- shotgun/agents/messages.py +35 -0
- shotgun/agents/models.py +275 -0
- shotgun/agents/plan.py +98 -0
- shotgun/agents/research.py +108 -0
- shotgun/agents/specify.py +98 -0
- shotgun/agents/tasks.py +96 -0
- shotgun/agents/tools/__init__.py +34 -0
- shotgun/agents/tools/codebase/__init__.py +28 -0
- shotgun/agents/tools/codebase/codebase_shell.py +256 -0
- shotgun/agents/tools/codebase/directory_lister.py +141 -0
- shotgun/agents/tools/codebase/file_read.py +144 -0
- shotgun/agents/tools/codebase/models.py +252 -0
- shotgun/agents/tools/codebase/query_graph.py +67 -0
- shotgun/agents/tools/codebase/retrieve_code.py +81 -0
- shotgun/agents/tools/file_management.py +218 -0
- shotgun/agents/tools/user_interaction.py +37 -0
- shotgun/agents/tools/web_search/__init__.py +60 -0
- shotgun/agents/tools/web_search/anthropic.py +144 -0
- shotgun/agents/tools/web_search/gemini.py +85 -0
- shotgun/agents/tools/web_search/openai.py +98 -0
- shotgun/agents/tools/web_search/utils.py +20 -0
- shotgun/build_constants.py +20 -0
- shotgun/cli/__init__.py +1 -0
- shotgun/cli/codebase/__init__.py +5 -0
- shotgun/cli/codebase/commands.py +202 -0
- shotgun/cli/codebase/models.py +21 -0
- shotgun/cli/config.py +275 -0
- shotgun/cli/export.py +81 -0
- shotgun/cli/models.py +10 -0
- shotgun/cli/plan.py +73 -0
- shotgun/cli/research.py +85 -0
- shotgun/cli/specify.py +69 -0
- shotgun/cli/tasks.py +78 -0
- shotgun/cli/update.py +152 -0
- shotgun/cli/utils.py +25 -0
- shotgun/codebase/__init__.py +12 -0
- shotgun/codebase/core/__init__.py +46 -0
- shotgun/codebase/core/change_detector.py +358 -0
- shotgun/codebase/core/code_retrieval.py +243 -0
- shotgun/codebase/core/ingestor.py +1497 -0
- shotgun/codebase/core/language_config.py +297 -0
- shotgun/codebase/core/manager.py +1662 -0
- shotgun/codebase/core/nl_query.py +331 -0
- shotgun/codebase/core/parser_loader.py +128 -0
- shotgun/codebase/models.py +111 -0
- shotgun/codebase/service.py +206 -0
- shotgun/logging_config.py +227 -0
- shotgun/main.py +167 -0
- shotgun/posthog_telemetry.py +158 -0
- shotgun/prompts/__init__.py +5 -0
- shotgun/prompts/agents/__init__.py +1 -0
- shotgun/prompts/agents/export.j2 +350 -0
- shotgun/prompts/agents/partials/codebase_understanding.j2 +87 -0
- shotgun/prompts/agents/partials/common_agent_system_prompt.j2 +37 -0
- shotgun/prompts/agents/partials/content_formatting.j2 +65 -0
- shotgun/prompts/agents/partials/interactive_mode.j2 +26 -0
- shotgun/prompts/agents/plan.j2 +144 -0
- shotgun/prompts/agents/research.j2 +69 -0
- shotgun/prompts/agents/specify.j2 +51 -0
- shotgun/prompts/agents/state/codebase/codebase_graphs_available.j2 +19 -0
- shotgun/prompts/agents/state/system_state.j2 +31 -0
- shotgun/prompts/agents/tasks.j2 +143 -0
- shotgun/prompts/codebase/__init__.py +1 -0
- shotgun/prompts/codebase/cypher_query_patterns.j2 +223 -0
- shotgun/prompts/codebase/cypher_system.j2 +28 -0
- shotgun/prompts/codebase/enhanced_query_context.j2 +10 -0
- shotgun/prompts/codebase/partials/cypher_rules.j2 +24 -0
- shotgun/prompts/codebase/partials/graph_schema.j2 +30 -0
- shotgun/prompts/codebase/partials/temporal_context.j2 +21 -0
- shotgun/prompts/history/__init__.py +1 -0
- shotgun/prompts/history/incremental_summarization.j2 +53 -0
- shotgun/prompts/history/summarization.j2 +46 -0
- shotgun/prompts/loader.py +140 -0
- shotgun/py.typed +0 -0
- shotgun/sdk/__init__.py +13 -0
- shotgun/sdk/codebase.py +219 -0
- shotgun/sdk/exceptions.py +17 -0
- shotgun/sdk/models.py +189 -0
- shotgun/sdk/services.py +23 -0
- shotgun/sentry_telemetry.py +87 -0
- shotgun/telemetry.py +93 -0
- shotgun/tui/__init__.py +0 -0
- shotgun/tui/app.py +116 -0
- shotgun/tui/commands/__init__.py +76 -0
- shotgun/tui/components/prompt_input.py +69 -0
- shotgun/tui/components/spinner.py +86 -0
- shotgun/tui/components/splash.py +25 -0
- shotgun/tui/components/vertical_tail.py +13 -0
- shotgun/tui/screens/chat.py +782 -0
- shotgun/tui/screens/chat.tcss +43 -0
- shotgun/tui/screens/chat_screen/__init__.py +0 -0
- shotgun/tui/screens/chat_screen/command_providers.py +219 -0
- shotgun/tui/screens/chat_screen/hint_message.py +40 -0
- shotgun/tui/screens/chat_screen/history.py +221 -0
- shotgun/tui/screens/directory_setup.py +113 -0
- shotgun/tui/screens/provider_config.py +221 -0
- shotgun/tui/screens/splash.py +31 -0
- shotgun/tui/styles.tcss +10 -0
- shotgun/tui/utils/__init__.py +5 -0
- shotgun/tui/utils/mode_progress.py +257 -0
- shotgun/utils/__init__.py +5 -0
- shotgun/utils/env_utils.py +35 -0
- shotgun/utils/file_system_utils.py +36 -0
- shotgun/utils/update_checker.py +375 -0
- shotgun_sh-0.1.0.dist-info/METADATA +466 -0
- shotgun_sh-0.1.0.dist-info/RECORD +130 -0
- shotgun_sh-0.1.0.dist-info/WHEEL +4 -0
- shotgun_sh-0.1.0.dist-info/entry_points.txt +2 -0
- shotgun_sh-0.1.0.dist-info/licenses/LICENSE +21 -0
shotgun/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""Shotgun AI Agents."""
|
|
@@ -0,0 +1,651 @@
|
|
|
1
|
+
"""Agent manager for coordinating multiple AI agents with shared message history."""
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
from collections.abc import AsyncIterable, Sequence
|
|
5
|
+
from dataclasses import dataclass, field, is_dataclass, replace
|
|
6
|
+
from typing import TYPE_CHECKING, Any, cast
|
|
7
|
+
|
|
8
|
+
if TYPE_CHECKING:
|
|
9
|
+
from shotgun.agents.conversation_history import ConversationState
|
|
10
|
+
|
|
11
|
+
from pydantic_ai import (
|
|
12
|
+
Agent,
|
|
13
|
+
DeferredToolRequests,
|
|
14
|
+
DeferredToolResults,
|
|
15
|
+
RunContext,
|
|
16
|
+
UsageLimits,
|
|
17
|
+
)
|
|
18
|
+
from pydantic_ai.agent import AgentRunResult
|
|
19
|
+
from pydantic_ai.messages import (
|
|
20
|
+
AgentStreamEvent,
|
|
21
|
+
FinalResultEvent,
|
|
22
|
+
FunctionToolCallEvent,
|
|
23
|
+
FunctionToolResultEvent,
|
|
24
|
+
ModelMessage,
|
|
25
|
+
ModelRequest,
|
|
26
|
+
ModelRequestPart,
|
|
27
|
+
ModelResponse,
|
|
28
|
+
ModelResponsePart,
|
|
29
|
+
PartDeltaEvent,
|
|
30
|
+
PartStartEvent,
|
|
31
|
+
SystemPromptPart,
|
|
32
|
+
ToolCallPart,
|
|
33
|
+
ToolCallPartDelta,
|
|
34
|
+
)
|
|
35
|
+
from textual.message import Message
|
|
36
|
+
from textual.widget import Widget
|
|
37
|
+
|
|
38
|
+
from shotgun.agents.common import add_system_prompt_message, add_system_status_message
|
|
39
|
+
from shotgun.agents.models import AgentType, FileOperation
|
|
40
|
+
from shotgun.tui.screens.chat_screen.hint_message import HintMessage
|
|
41
|
+
|
|
42
|
+
from .export import create_export_agent
|
|
43
|
+
from .history.compaction import apply_persistent_compaction
|
|
44
|
+
from .messages import AgentSystemPrompt
|
|
45
|
+
from .models import AgentDeps, AgentRuntimeOptions
|
|
46
|
+
from .plan import create_plan_agent
|
|
47
|
+
from .research import create_research_agent
|
|
48
|
+
from .specify import create_specify_agent
|
|
49
|
+
from .tasks import create_tasks_agent
|
|
50
|
+
|
|
51
|
+
# Logger for this module; named after the module so handlers can filter on it.
logger = logging.getLogger(name=__name__)
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
class MessageHistoryUpdated(Message):
    """Textual message announcing that the conversation history has changed.

    Attributes:
        messages: The updated history (model messages and UI hint messages).
        agent_type: The agent that was active when the update happened.
        file_operations: File operations from the run; an empty list when
            none were supplied.
    """

    def __init__(
        self,
        messages: list[ModelMessage | HintMessage],
        agent_type: AgentType,
        file_operations: list[FileOperation] | None = None,
    ) -> None:
        """Create the event.

        Args:
            messages: The updated message history.
            agent_type: The type of agent that triggered the update.
            file_operations: File operations recorded during this run, if any.
        """
        super().__init__()
        # A missing or empty argument is normalized to a fresh empty list.
        self.file_operations = [] if not file_operations else file_operations
        self.agent_type = agent_type
        self.messages = messages
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
class PartialResponseMessage(Message):
    """Textual message carrying a streamed, possibly incomplete, model response."""

    def __init__(
        self,
        message: ModelResponse | None,
        messages: list[ModelMessage],
        is_last: bool,
    ) -> None:
        """Create the event.

        Args:
            message: The response currently being streamed, or ``None``.
            messages: Messages recorded so far for the current run.
            is_last: Whether this is the final update for the stream.
        """
        super().__init__()
        self.is_last = is_last
        self.messages = messages
        self.message = message
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
@dataclass(slots=True)
class _PartialStreamState:
    """Mutable bookkeeping for the messages streamed during one agent run."""

    # Completed requests/responses recorded for the current run.
    messages: list[ModelRequest | ModelResponse] = field(default_factory=list)
    # Response still being assembled from stream events, or None when idle.
    current_response: ModelResponse | None = None
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
class AgentManager(Widget):
    """Manages multiple agents with shared message history.

    One agent (plus its agent-specific deps) is created per ``AgentType`` at
    construction time. All agents share two histories:

    * ``message_history`` - messages sent to the model, compacted after every
      run via ``apply_persistent_compaction``.
    * ``ui_message_history`` - what the UI shows; it may also contain
      ``HintMessage`` entries.

    Changes are reported by posting ``MessageHistoryUpdated`` events and, while
    a response is streamed, ``PartialResponseMessage`` events.
    """

    def __init__(
        self,
        deps: AgentDeps | None = None,
        initial_type: AgentType = AgentType.RESEARCH,
    ) -> None:
        """Initialize the agent manager.

        Args:
            deps: Shared agent dependencies. Although the parameter defaults to
                ``None``, it is required: it supplies the runtime options used
                to create every agent and the base for per-run dependencies.
            initial_type: The agent type that is active after construction.

        Raises:
            ValueError: If ``deps`` is ``None``.
        """
        super().__init__()
        # Hidden widget: it is used to post messages to the app, not to render.
        self.display = False

        self.deps = deps
        if self.deps is None:
            raise ValueError("AgentDeps must be provided to AgentManager")

        # Runtime options shared by every agent, derived from the shared deps.
        agent_runtime_options = AgentRuntimeOptions(
            interactive_mode=self.deps.interactive_mode,
            working_directory=self.deps.working_directory,
            is_tui_context=self.deps.is_tui_context,
            max_iterations=self.deps.max_iterations,
            queue=self.deps.queue,
            tasks=self.deps.tasks,
        )

        # Initialize all agents and store their specific deps
        self.research_agent, self.research_deps = create_research_agent(
            agent_runtime_options=agent_runtime_options
        )
        self.plan_agent, self.plan_deps = create_plan_agent(
            agent_runtime_options=agent_runtime_options
        )
        self.tasks_agent, self.tasks_deps = create_tasks_agent(
            agent_runtime_options=agent_runtime_options
        )
        self.specify_agent, self.specify_deps = create_specify_agent(
            agent_runtime_options=agent_runtime_options
        )
        self.export_agent, self.export_deps = create_export_agent(
            agent_runtime_options=agent_runtime_options
        )

        # Track current active agent
        self._current_agent_type: AgentType = initial_type

        # Shared histories: UI-facing (may include hints) and model-facing.
        self.ui_message_history: list[ModelMessage | HintMessage] = []
        self.message_history: list[ModelMessage] = []
        # File operations recorded during the most recent run.
        # (Attribute name kept as-is for compatibility with existing callers.)
        self.recently_change_files: list[FileOperation] = []
        # Streaming bookkeeping; created by ``run`` and reset to None afterwards.
        self._stream_state: _PartialStreamState | None = None

    @property
    def current_agent(self) -> Agent[AgentDeps, str | DeferredToolRequests]:
        """Get the currently active agent.

        Returns:
            The currently selected agent instance.
        """
        return self._get_agent(self._current_agent_type)

    def _get_agent(
        self, agent_type: AgentType
    ) -> Agent[AgentDeps, str | DeferredToolRequests]:
        """Get agent by type.

        Args:
            agent_type: The type of agent to retrieve.

        Returns:
            The requested agent instance.

        Raises:
            KeyError: If ``agent_type`` has no registered agent.
        """
        agent_map = {
            AgentType.RESEARCH: self.research_agent,
            AgentType.PLAN: self.plan_agent,
            AgentType.TASKS: self.tasks_agent,
            AgentType.SPECIFY: self.specify_agent,
            AgentType.EXPORT: self.export_agent,
        }
        return agent_map[agent_type]

    def _get_agent_deps(self, agent_type: AgentType) -> AgentDeps:
        """Get agent-specific deps by type.

        Args:
            agent_type: The type of agent to retrieve deps for.

        Returns:
            The agent-specific dependencies.

        Raises:
            KeyError: If ``agent_type`` has no registered deps.
        """
        deps_map = {
            AgentType.RESEARCH: self.research_deps,
            AgentType.PLAN: self.plan_deps,
            AgentType.TASKS: self.tasks_deps,
            AgentType.SPECIFY: self.specify_deps,
            AgentType.EXPORT: self.export_deps,
        }
        return deps_map[agent_type]

    def _create_merged_deps(self, agent_type: AgentType) -> AgentDeps:
        """Create merged dependencies combining shared and agent-specific deps.

        This preserves the agent's system_prompt_fn while using shared runtime state.

        Args:
            agent_type: The type of agent to create merged deps for.

        Returns:
            A copy of the shared deps whose ``system_prompt_fn`` is taken from
            the agent-specific deps.

        Raises:
            ValueError: If the shared deps are missing (guarded by ``__init__``).
        """
        agent_deps = self._get_agent_deps(agent_type)

        # Ensure shared deps is not None (should be guaranteed by __init__)
        if self.deps is None:
            raise ValueError("Shared deps is None - this should not happen")

        # Copy the shared deps, swapping in the agent's system_prompt_fn.
        return self.deps.model_copy(
            update={"system_prompt_fn": agent_deps.system_prompt_fn}
        )

    def set_agent(self, agent_type: AgentType) -> None:
        """Set the current active agent.

        Args:
            agent_type: The agent type to activate (AgentType enum or string).

        Raises:
            ValueError: If invalid agent type is provided.
        """
        try:
            self._current_agent_type = AgentType(agent_type)
        except ValueError:
            raise ValueError(
                f"Invalid agent type: {agent_type}. Must be one of: {', '.join(e.value for e in AgentType)}"
            ) from None

    @staticmethod
    def _drop_foreign_agent_prompts(
        messages: list[ModelMessage], agent_mode: AgentType
    ) -> list[ModelMessage]:
        """Remove ``AgentSystemPrompt`` parts that belong to other agent types.

        Non-request messages are kept unchanged. Requests are rebuilt from their
        remaining parts; requests left with no parts are dropped.
        """
        filtered_history: list[ModelMessage] = []
        for message in messages:
            if not isinstance(message, ModelRequest):
                filtered_history.append(message)
                continue

            filtered_parts: list[ModelRequestPart] = [
                part
                for part in message.parts
                if not isinstance(part, AgentSystemPrompt)
                or part.agent_mode == agent_mode
            ]
            if filtered_parts:
                filtered_history.append(ModelRequest(parts=filtered_parts))
        return filtered_history

    @staticmethod
    def _has_agent_system_prompt(
        messages: list[ModelMessage], agent_mode: AgentType
    ) -> bool:
        """Return True if any request carries an ``AgentSystemPrompt`` for ``agent_mode``."""
        return any(
            isinstance(part, AgentSystemPrompt) and part.agent_mode == agent_mode
            for message in messages
            if isinstance(message, ModelRequest)
            for part in message.parts
        )

    async def run(
        self,
        prompt: str | None = None,
        *,
        deps: AgentDeps | None = None,
        usage_limits: UsageLimits | None = None,
        deferred_tool_results: DeferredToolResults | None = None,
        **kwargs: Any,
    ) -> AgentRunResult[str | DeferredToolRequests]:
        """Run the current agent with automatic message history management.

        This method wraps the agent's run method, automatically injecting the
        shared message history and updating it after each run.

        Args:
            prompt: Optional prompt to send to the agent.
            deps: Optional dependencies override (defaults to manager's deps).
            usage_limits: Optional usage limits for the agent run.
            deferred_tool_results: Optional deferred tool results for continuing a conversation.
            **kwargs: Additional keyword arguments to pass to the agent.

        Returns:
            The agent run result.
        """
        # Use merged deps (shared state + agent-specific system prompt) if not provided
        if deps is None:
            deps = self._create_merged_deps(self._current_agent_type)

        # Clear file tracker before each run to track only this run's operations
        deps.file_tracker.clear()
        original_messages = self.ui_message_history.copy()

        if prompt:
            self.ui_message_history.append(ModelRequest.user_text_prompt(prompt))
            self._post_messages_updated()

        deps.agent_mode = self._current_agent_type

        # Start from the persistent history, minus other agents' system prompts.
        message_history = self._drop_foreign_agent_prompts(
            self.message_history, deps.agent_mode
        )

        # Add a system status message so the agent knows what's going on
        message_history = await add_system_status_message(deps, message_history)

        # Always ensure we have a system prompt for the agent
        # (compaction may remove it from persistent history, but agent needs it)
        if not self._has_agent_system_prompt(message_history, deps.agent_mode):
            message_history = await add_system_prompt_message(deps, message_history)

        self._stream_state = _PartialStreamState()

        model_name = ""
        if hasattr(deps, "llm_model") and deps.llm_model is not None:
            model_name = deps.llm_model.name
        # Streaming is likely not supported for gpt-5; it varies between keys.
        is_gpt5 = "gpt-5" in model_name.lower()

        try:
            result: AgentRunResult[str | DeferredToolRequests] = (
                await self.current_agent.run(
                    prompt,
                    deps=deps,
                    usage_limits=usage_limits,
                    message_history=message_history,
                    deferred_tool_results=deferred_tool_results,
                    event_stream_handler=(
                        self._handle_event_stream if not is_gpt5 else None
                    ),
                    **kwargs,
                )
            )
        finally:
            # Streaming bookkeeping is only meaningful during the run.
            self._stream_state = None

        self.ui_message_history = original_messages + cast(
            list[ModelRequest | ModelResponse | HintMessage], result.new_messages()
        )

        # Apply compaction to persistent message history to prevent cascading growth
        all_messages = result.all_messages()
        self.message_history = await apply_persistent_compaction(all_messages, deps)

        # Remember and report the files touched during this run
        file_operations = deps.file_tracker.operations.copy()
        self.recently_change_files = file_operations

        self._post_messages_updated(file_operations)

        return result

    async def _handle_event_stream(
        self,
        _ctx: RunContext[AgentDeps],
        stream: AsyncIterable[AgentStreamEvent],
    ) -> None:
        """Process streamed events and forward partial updates to the UI.

        Streamed parts are accumulated into an in-progress ``ModelResponse``
        and tool results are recorded as ``ModelRequest`` messages. When the
        stream is exhausted, the in-progress response (if any) is recorded and
        a final ``PartialResponseMessage`` with ``is_last=True`` is posted.
        Errors raised while handling an individual event are logged and
        otherwise ignored.

        Args:
            _ctx: Run context supplied by the agent run (unused).
            stream: Agent stream events to process.
        """
        state = self._stream_state
        if state is None:
            state = self._stream_state = _PartialStreamState()

        # Resume from the response that is still being streamed, if any.
        partial_parts: list[ModelResponsePart | ToolCallPartDelta] = (
            list(state.current_response.parts)
            if state.current_response is not None
            else []
        )

        async for event in stream:
            try:
                if isinstance(event, PartStartEvent):
                    index = event.index
                    if index < len(partial_parts):
                        partial_parts[index] = event.part
                    elif index == len(partial_parts):
                        partial_parts.append(event.part)
                    else:
                        logger.warning(
                            "Received PartStartEvent with out-of-bounds index",
                            extra={"index": index, "current_len": len(partial_parts)},
                        )
                        partial_parts.append(event.part)

                    self._refresh_partial_response(state, partial_parts)

                elif isinstance(event, PartDeltaEvent):
                    index = event.index
                    if index >= len(partial_parts):
                        logger.warning(
                            "Received PartDeltaEvent before corresponding start event",
                            extra={"index": index, "current_len": len(partial_parts)},
                        )
                        continue

                    try:
                        updated_part = event.delta.apply(
                            cast(ModelResponsePart, partial_parts[index])
                        )
                    except Exception:  # pragma: no cover - defensive logging
                        logger.exception(
                            "Failed to apply part delta", extra={"event": event}
                        )
                        continue

                    partial_parts[index] = updated_part
                    self._refresh_partial_response(state, partial_parts)

                elif isinstance(event, FunctionToolCallEvent):
                    tool_call_id = event.part.tool_call_id
                    existing_call_idx = next(
                        (
                            i
                            for i, part in enumerate(partial_parts)
                            if isinstance(part, ToolCallPart)
                            and part.tool_call_id == tool_call_id
                        ),
                        None,
                    )

                    if existing_call_idx is not None:
                        # Replace the streamed call with the finalized one.
                        partial_parts[existing_call_idx] = event.part
                    elif self._is_tool_call_recorded(state.messages, tool_call_id):
                        # Already part of a completed response; don't duplicate it.
                        continue
                    else:
                        # Previously unseen call: show it in the in-progress response.
                        partial_parts.append(event.part)

                    self._refresh_partial_response(state, partial_parts)

                elif isinstance(event, FunctionToolResultEvent):
                    request_message = ModelRequest(parts=[event.result])
                    state.messages.append(request_message)
                    # Special handling for ask_user: with deferred tool results
                    # we would otherwise miss the user's response in the UI.
                    if event.result.tool_name == "ask_user":
                        self.ui_message_history.append(request_message)
                        self._post_messages_updated()
                    # This is what the user responded with
                    self._post_partial_message(is_last=False)

                elif isinstance(event, FinalResultEvent):
                    pass
            except Exception:  # pragma: no cover - defensive logging
                logger.exception(
                    "Error while handling agent stream event", extra={"event": event}
                )

        final_message = state.current_response or self._build_partial_response(
            partial_parts
        )
        if final_message is not None:
            if final_message not in state.messages:
                state.messages.append(final_message)
            # Nothing is in progress any more; the final post carries only
            # the recorded messages.
            state.current_response = None
            self._post_partial_message(True)

    @staticmethod
    def _is_tool_call_recorded(
        messages: Sequence[ModelRequest | ModelResponse], tool_call_id: str
    ) -> bool:
        """Return True if a recorded ``ModelResponse`` already contains the tool call."""
        return any(
            isinstance(part, ToolCallPart) and part.tool_call_id == tool_call_id
            for message in messages
            if isinstance(message, ModelResponse)
            for part in message.parts
        )

    def _refresh_partial_response(
        self,
        state: _PartialStreamState,
        parts: list[ModelResponsePart | ToolCallPartDelta],
    ) -> None:
        """Rebuild the in-progress response from ``parts`` and post it to the UI."""
        partial_message = self._build_partial_response(parts)
        if partial_message is not None:
            state.current_response = partial_message
            self._post_partial_message(False)

    def _build_partial_response(
        self, parts: list[ModelResponsePart | ToolCallPartDelta]
    ) -> ModelResponse | None:
        """Create a ``ModelResponse`` from the currently streamed parts.

        Unapplied ``ToolCallPartDelta`` entries are skipped. Returns ``None``
        when no complete parts remain.
        """
        completed_parts = [
            part for part in parts if not isinstance(part, ToolCallPartDelta)
        ]
        if not completed_parts:
            return None
        return ModelResponse(parts=completed_parts)

    def _post_partial_message(self, is_last: bool) -> None:
        """Post the current streaming state to the UI.

        The in-progress response is sent only if it has not already been
        recorded in the state's messages; otherwise ``None`` is sent in its
        place. Does nothing when there is no stream state.
        """
        state = self._stream_state
        if state is None:
            return
        pending = state.current_response
        self.post_message(
            PartialResponseMessage(
                pending if pending not in state.messages else None,
                state.messages,
                is_last,
            )
        )

    def _post_messages_updated(
        self, file_operations: list[FileOperation] | None = None
    ) -> None:
        """Notify listeners with a copy of the UI message history."""
        self.post_message(
            MessageHistoryUpdated(
                messages=self.ui_message_history.copy(),
                agent_type=self._current_agent_type,
                file_operations=file_operations,
            )
        )

    def _filter_system_prompts(
        self, messages: list[ModelMessage | HintMessage]
    ) -> list[ModelMessage | HintMessage]:
        """Filter out system prompts from messages for UI display.

        Args:
            messages: List of messages that may contain system prompts

        Returns:
            List of messages without system prompt parts
        """
        filtered_messages: list[ModelMessage | HintMessage] = []
        for msg in messages:
            if isinstance(msg, HintMessage):
                filtered_messages.append(msg)
                continue

            parts: Sequence[ModelRequestPart] | Sequence[ModelResponsePart] | None = (
                msg.parts if hasattr(msg, "parts") else None
            )
            if not parts:
                filtered_messages.append(msg)
                continue

            non_system_parts = [
                part for part in parts if not isinstance(part, SystemPromptPart)
            ]

            if not non_system_parts:
                # Skip messages made up entirely of system prompt parts (e.g. system message)
                continue

            if len(non_system_parts) == len(parts):
                # Nothing was filtered - keep original message
                filtered_messages.append(msg)
                continue

            if is_dataclass(msg):
                filtered_messages.append(
                    # ignore types because of the convoluted Request | Response types
                    replace(msg, parts=cast(Any, non_system_parts))
                )
            else:
                filtered_messages.append(msg)
        return filtered_messages

    def get_conversation_state(self) -> "ConversationState":
        """Get the current conversation state.

        Returns:
            ConversationState object containing UI and agent messages and current type
        """
        from shotgun.agents.conversation_history import ConversationState

        return ConversationState(
            agent_messages=self.message_history.copy(),
            ui_messages=self.ui_message_history.copy(),
            agent_type=self._current_agent_type.value,
        )

    def restore_conversation_state(self, state: "ConversationState") -> None:
        """Restore conversation state from a saved state.

        Args:
            state: ConversationState object to restore
        """
        # Restore message history for agents (includes system prompts)
        self.message_history = [
            msg for msg in state.agent_messages if not isinstance(msg, HintMessage)
        ]

        # Filter out system prompts for UI display while keeping hints;
        # fall back to the agent messages when no UI messages were saved.
        ui_source = state.ui_messages or cast(
            list[ModelMessage | HintMessage], state.agent_messages
        )
        self.ui_message_history = self._filter_system_prompts(ui_source)

        # Restore agent type
        self._current_agent_type = AgentType(state.agent_type)

        # Notify listeners about the restored messages
        self._post_messages_updated()

    def add_hint_message(self, message: HintMessage) -> None:
        """Append a hint to the UI history and notify listeners."""
        self.ui_message_history.append(message)
        self._post_messages_updated()
|
|
643
|
+
|
|
644
|
+
|
|
645
|
+
# Public API of this module. AgentType is re-exported so older imports of it
# from here keep working.
__all__ = [
    "AgentManager",
    "AgentType",
    "MessageHistoryUpdated",
    "PartialResponseMessage",
]
|