shotgun-sh 0.1.0.dev14__py3-none-any.whl → 0.1.0.dev16__py3-none-any.whl

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.

Potentially problematic release.


This version of shotgun-sh might be problematic; see the package's registry page for more details.

@@ -1,16 +1,30 @@
1
1
  """Agent manager for coordinating multiple AI agents with shared message history."""
2
2
 
3
+ import logging
4
+ from collections.abc import AsyncIterable
5
+ from dataclasses import dataclass, field
3
6
  from enum import Enum
4
- from typing import Any
7
+ from typing import Any, cast
5
8
 
6
9
  from pydantic_ai import (
7
10
  Agent,
8
11
  DeferredToolRequests,
9
12
  DeferredToolResults,
13
+ RunContext,
10
14
  UsageLimits,
11
15
  )
12
16
  from pydantic_ai.agent import AgentRunResult
13
- from pydantic_ai.messages import ModelMessage, ModelRequest
17
+ from pydantic_ai.messages import (
18
+ AgentStreamEvent,
19
+ FinalResultEvent,
20
+ ModelMessage,
21
+ ModelRequest,
22
+ ModelResponse,
23
+ ModelResponsePart,
24
+ PartDeltaEvent,
25
+ PartStartEvent,
26
+ ToolCallPartDelta,
27
+ )
14
28
  from textual.message import Message
15
29
  from textual.widget import Widget
16
30
 
@@ -18,8 +32,11 @@ from .history.compaction import apply_persistent_compaction
18
32
  from .models import AgentDeps, AgentRuntimeOptions, FileOperation
19
33
  from .plan import create_plan_agent
20
34
  from .research import create_research_agent
35
+ from .specify import create_specify_agent
21
36
  from .tasks import create_tasks_agent
22
37
 
38
+ logger = logging.getLogger(__name__)
39
+
23
40
 
24
41
  class AgentType(Enum):
25
42
  """Enumeration for available agent types (for Python < 3.11)."""
@@ -27,21 +44,48 @@ class AgentType(Enum):
27
44
  RESEARCH = "research"
28
45
  PLAN = "plan"
29
46
  TASKS = "tasks"
47
+ SPECIFY = "specify"
30
48
 
31
49
 
32
50
  class MessageHistoryUpdated(Message):
33
51
  """Event posted when the message history is updated."""
34
52
 
35
- def __init__(self, messages: list[ModelMessage], agent_type: AgentType) -> None:
53
+ def __init__(
54
+ self,
55
+ messages: list[ModelMessage],
56
+ agent_type: AgentType,
57
+ file_operations: list[FileOperation] | None = None,
58
+ ) -> None:
36
59
  """Initialize the message history updated event.
37
60
 
38
61
  Args:
39
62
  messages: The updated message history.
40
63
  agent_type: The type of agent that triggered the update.
64
+ file_operations: List of file operations from this run.
41
65
  """
42
66
  super().__init__()
43
67
  self.messages = messages
44
68
  self.agent_type = agent_type
69
+ self.file_operations = file_operations or []
70
+
71
+
72
+ class PartialResponseMessage(Message):
73
+ """Event posted when a partial response is received."""
74
+
75
+ def __init__(self, message: ModelResponse | None, is_last: bool) -> None:
76
+ """Initialize the partial response message."""
77
+ super().__init__()
78
+ self.message = message
79
+ self.is_last = is_last
80
+
81
+
82
+ @dataclass(slots=True)
83
+ class _PartialStreamState:
84
+ """Tracks partial response parts while streaming a single agent run."""
85
+
86
+ parts: list[ModelResponsePart | ToolCallPartDelta] = field(default_factory=list)
87
+ latest_partial: ModelResponse | None = None
88
+ final_sent: bool = False
45
89
 
46
90
 
47
91
  class AgentManager(Widget):
@@ -58,6 +102,8 @@ class AgentManager(Widget):
58
102
  deps: Optional agent dependencies. If not provided, defaults to interactive mode.
59
103
  """
60
104
  super().__init__()
105
+ self.display = False
106
+
61
107
  # Use provided deps or create default with interactive mode
62
108
  self.deps = deps
63
109
 
@@ -73,14 +119,17 @@ class AgentManager(Widget):
73
119
  tasks=self.deps.tasks,
74
120
  )
75
121
 
76
- # Initialize all agents with the same deps
77
- self.research_agent, _ = create_research_agent(
122
+ # Initialize all agents and store their specific deps
123
+ self.research_agent, self.research_deps = create_research_agent(
78
124
  agent_runtime_options=agent_runtime_options
79
125
  )
80
- self.plan_agent, _ = create_plan_agent(
126
+ self.plan_agent, self.plan_deps = create_plan_agent(
81
127
  agent_runtime_options=agent_runtime_options
82
128
  )
83
- self.tasks_agent, _ = create_tasks_agent(
129
+ self.tasks_agent, self.tasks_deps = create_tasks_agent(
130
+ agent_runtime_options=agent_runtime_options
131
+ )
132
+ self.specify_agent, self.specify_deps = create_specify_agent(
84
133
  agent_runtime_options=agent_runtime_options
85
134
  )
86
135
 
@@ -91,6 +140,7 @@ class AgentManager(Widget):
91
140
  self.ui_message_history: list[ModelMessage] = []
92
141
  self.message_history: list[ModelMessage] = []
93
142
  self.recently_change_files: list[FileOperation] = []
143
+ self._stream_state: _PartialStreamState | None = None
94
144
 
95
145
  @property
96
146
  def current_agent(self) -> Agent[AgentDeps, str | DeferredToolRequests]:
@@ -116,9 +166,52 @@ class AgentManager(Widget):
116
166
  AgentType.RESEARCH: self.research_agent,
117
167
  AgentType.PLAN: self.plan_agent,
118
168
  AgentType.TASKS: self.tasks_agent,
169
+ AgentType.SPECIFY: self.specify_agent,
119
170
  }
120
171
  return agent_map[agent_type]
121
172
 
173
+ def _get_agent_deps(self, agent_type: AgentType) -> AgentDeps:
174
+ """Get agent-specific deps by type.
175
+
176
+ Args:
177
+ agent_type: The type of agent to retrieve deps for.
178
+
179
+ Returns:
180
+ The agent-specific dependencies.
181
+ """
182
+ deps_map = {
183
+ AgentType.RESEARCH: self.research_deps,
184
+ AgentType.PLAN: self.plan_deps,
185
+ AgentType.TASKS: self.tasks_deps,
186
+ AgentType.SPECIFY: self.specify_deps,
187
+ }
188
+ return deps_map[agent_type]
189
+
190
+ def _create_merged_deps(self, agent_type: AgentType) -> AgentDeps:
191
+ """Create merged dependencies combining shared and agent-specific deps.
192
+
193
+ This preserves the agent's system_prompt_fn while using shared runtime state.
194
+
195
+ Args:
196
+ agent_type: The type of agent to create merged deps for.
197
+
198
+ Returns:
199
+ Merged AgentDeps with agent-specific system_prompt_fn.
200
+ """
201
+ agent_deps = self._get_agent_deps(agent_type)
202
+
203
+ # Ensure shared deps is not None (should be guaranteed by __init__)
204
+ if self.deps is None:
205
+ raise ValueError("Shared deps is None - this should not happen")
206
+
207
+ # Create new deps with shared runtime state but agent's system_prompt_fn
208
+ # Use a copy of the shared deps and update the system_prompt_fn
209
+ merged_deps = self.deps.model_copy(
210
+ update={"system_prompt_fn": agent_deps.system_prompt_fn}
211
+ )
212
+
213
+ return merged_deps
214
+
122
215
  def set_agent(self, agent_type: AgentType) -> None:
123
216
  """Set the current active agent.
124
217
 
@@ -159,9 +252,9 @@ class AgentManager(Widget):
159
252
  Returns:
160
253
  The agent run result.
161
254
  """
162
- # Use manager's deps if not provided
255
+ # Use merged deps (shared state + agent-specific system prompt) if not provided
163
256
  if deps is None:
164
- deps = self.deps
257
+ deps = self._create_merged_deps(self._current_agent_type)
165
258
 
166
259
  # Ensure deps is not None
167
260
  if deps is None:
@@ -171,39 +264,176 @@ class AgentManager(Widget):
171
264
  self.ui_message_history.append(ModelRequest.user_text_prompt(prompt))
172
265
  self._post_messages_updated()
173
266
 
174
- # Run the agent with the shared message history
175
- result: AgentRunResult[
176
- str | DeferredToolRequests
177
- ] = await self.current_agent.run(
178
- prompt,
179
- deps=deps,
180
- usage_limits=usage_limits,
181
- message_history=self.message_history,
182
- deferred_tool_results=deferred_tool_results,
183
- **kwargs,
267
+ # Ensure system prompt is added to message history before running agent
268
+ from pydantic_ai.messages import SystemPromptPart
269
+
270
+ from shotgun.agents.common import add_system_prompt_message
271
+
272
+ # Start with persistent message history
273
+ message_history = self.message_history
274
+
275
+ # Check if the message history already has a system prompt
276
+ has_system_prompt = any(
277
+ hasattr(msg, "parts")
278
+ and any(isinstance(part, SystemPromptPart) for part in msg.parts)
279
+ for msg in message_history
280
+ )
281
+
282
+ # Always ensure we have a system prompt for the agent
283
+ # (compaction may remove it from persistent history, but agent needs it)
284
+ if not has_system_prompt:
285
+ message_history = await add_system_prompt_message(deps, message_history)
286
+
287
+ # Run the agent with streaming support (from origin/main)
288
+ self._stream_state = _PartialStreamState()
289
+
290
+ model_name = ""
291
+ if hasattr(deps, "llm_model") and deps.llm_model is not None:
292
+ model_name = deps.llm_model.name
293
+ is_gpt5 = ( # streaming is likely not supported for gpt5. It varies between keys.
294
+ "gpt-5" in model_name.lower()
184
295
  )
185
296
 
186
- # Update the shared message history with all messages from this run
297
+ try:
298
+ result: AgentRunResult[
299
+ str | DeferredToolRequests
300
+ ] = await self.current_agent.run(
301
+ prompt,
302
+ deps=deps,
303
+ usage_limits=usage_limits,
304
+ message_history=message_history,
305
+ deferred_tool_results=deferred_tool_results,
306
+ event_stream_handler=self._handle_event_stream if not is_gpt5 else None,
307
+ **kwargs,
308
+ )
309
+ finally:
310
+ # If the stream ended unexpectedly without a final result, clear accumulated state.
311
+ if self._stream_state is not None and not self._stream_state.final_sent:
312
+ partial_message = self._build_partial_response(self._stream_state.parts)
313
+ if partial_message is not None:
314
+ self._post_partial_message(partial_message, True)
315
+ self._stream_state = None
316
+
187
317
  self.ui_message_history = self.ui_message_history + [
188
318
  mes for mes in result.new_messages() if not isinstance(mes, ModelRequest)
189
319
  ]
190
320
 
191
321
  # Apply compaction to persistent message history to prevent cascading growth
192
- self.message_history = await apply_persistent_compaction(
193
- result.all_messages(), deps
194
- )
195
- self._post_messages_updated()
322
+ all_messages = result.all_messages()
323
+ self.message_history = await apply_persistent_compaction(all_messages, deps)
196
324
 
197
325
  # Log file operations summary if any files were modified
198
- self.recently_change_files = deps.file_tracker.operations.copy()
326
+ file_operations = deps.file_tracker.operations.copy()
327
+ self.recently_change_files = file_operations
328
+
329
+ self._post_messages_updated(file_operations)
199
330
 
200
331
  return result
201
332
 
202
- def _post_messages_updated(self) -> None:
333
+ async def _handle_event_stream(
334
+ self,
335
+ _ctx: RunContext[AgentDeps],
336
+ stream: AsyncIterable[AgentStreamEvent],
337
+ ) -> None:
338
+ """Process streamed events and forward partial updates to the UI."""
339
+
340
+ state = self._stream_state
341
+ if state is None:
342
+ state = self._stream_state = _PartialStreamState()
343
+
344
+ partial_parts = state.parts
345
+
346
+ async for event in stream:
347
+ try:
348
+ if isinstance(event, PartStartEvent):
349
+ index = event.index
350
+ if index < len(partial_parts):
351
+ partial_parts[index] = event.part
352
+ elif index == len(partial_parts):
353
+ partial_parts.append(event.part)
354
+ else:
355
+ logger.warning(
356
+ "Received PartStartEvent with out-of-bounds index",
357
+ extra={"index": index, "current_len": len(partial_parts)},
358
+ )
359
+ partial_parts.append(event.part)
360
+
361
+ partial_message = self._build_partial_response(partial_parts)
362
+ if partial_message is not None:
363
+ state.latest_partial = partial_message
364
+ self._post_partial_message(partial_message, False)
365
+
366
+ elif isinstance(event, PartDeltaEvent):
367
+ index = event.index
368
+ if index >= len(partial_parts):
369
+ logger.warning(
370
+ "Received PartDeltaEvent before corresponding start event",
371
+ extra={"index": index, "current_len": len(partial_parts)},
372
+ )
373
+ continue
374
+
375
+ try:
376
+ updated_part = event.delta.apply(
377
+ cast(ModelResponsePart, partial_parts[index])
378
+ )
379
+ except Exception: # pragma: no cover - defensive logging
380
+ logger.exception(
381
+ "Failed to apply part delta", extra={"event": event}
382
+ )
383
+ continue
384
+
385
+ partial_parts[index] = updated_part
386
+
387
+ partial_message = self._build_partial_response(partial_parts)
388
+ if partial_message is not None:
389
+ state.latest_partial = partial_message
390
+ self._post_partial_message(partial_message, False)
391
+
392
+ elif isinstance(event, FinalResultEvent):
393
+ final_message = (
394
+ state.latest_partial
395
+ or self._build_partial_response(partial_parts)
396
+ )
397
+ self._post_partial_message(final_message, True)
398
+ state.latest_partial = None
399
+ state.final_sent = True
400
+ partial_parts.clear()
401
+ self._stream_state = None
402
+ break
403
+
404
+ # Ignore other AgentStreamEvent variants (e.g. tool call notifications) for partial UI updates.
405
+
406
+ except Exception: # pragma: no cover - defensive logging
407
+ logger.exception(
408
+ "Error while handling agent stream event", extra={"event": event}
409
+ )
410
+
411
+ def _build_partial_response(
412
+ self, parts: list[ModelResponsePart | ToolCallPartDelta]
413
+ ) -> ModelResponse | None:
414
+ """Create a `ModelResponse` from the currently streamed parts."""
415
+
416
+ completed_parts = [
417
+ part for part in parts if not isinstance(part, ToolCallPartDelta)
418
+ ]
419
+ if not completed_parts:
420
+ return None
421
+ return ModelResponse(parts=list(completed_parts))
422
+
423
+ def _post_partial_message(
424
+ self, message: ModelResponse | None, is_last: bool
425
+ ) -> None:
426
+ """Post a partial message to the UI."""
427
+ self.post_message(PartialResponseMessage(message, is_last))
428
+
429
+ def _post_messages_updated(
430
+ self, file_operations: list[FileOperation] | None = None
431
+ ) -> None:
203
432
  # Post event to notify listeners of the message history update
204
433
  self.post_message(
205
434
  MessageHistoryUpdated(
206
435
  messages=self.ui_message_history.copy(),
207
436
  agent_type=self._current_agent_type,
437
+ file_operations=file_operations,
208
438
  )
209
439
  )
shotgun/agents/models.py CHANGED
@@ -1,5 +1,6 @@
1
1
  """Pydantic models for agent dependencies and configuration."""
2
2
 
3
+ import os
3
4
  from asyncio import Future, Queue
4
5
  from collections.abc import Callable
5
6
  from datetime import datetime
@@ -187,6 +188,25 @@ class FileOperationTracker(BaseModel):
187
188
 
188
189
  return "\n".join(lines)
189
190
 
191
+ def get_display_path(self) -> str | None:
192
+ """Get a single file path or common parent directory for display.
193
+
194
+ Returns:
195
+ Path string to display, or None if no files were modified
196
+ """
197
+ if not self.operations:
198
+ return None
199
+
200
+ unique_paths = list({op.file_path for op in self.operations})
201
+
202
+ if len(unique_paths) == 1:
203
+ # Single file - return its path
204
+ return unique_paths[0]
205
+
206
+ # Multiple files - find common parent directory
207
+ common_path = os.path.commonpath(unique_paths)
208
+ return common_path
209
+
190
210
 
191
211
  class AgentDeps(AgentRuntimeOptions):
192
212
  """Dependencies passed to all agents for configuration and runtime behavior."""
@@ -347,14 +347,13 @@ async def read_artifact_section(
347
347
 
348
348
  section = service.get_section(artifact_id, mode, section_number)
349
349
 
350
- # Return formatted content with title
351
- formatted_content = f"# {section.title}\n\n{section.content}"
350
+ # Return section content (already contains title header from file storage)
352
351
  logger.debug(
353
352
  "📄 Read section %d with %d characters",
354
353
  section_number,
355
354
  len(section.content),
356
355
  )
357
- return formatted_content
356
+ return section.content
358
357
 
359
358
  except Exception as e:
360
359
  error_msg = (