shotgun-sh 0.2.3.dev2__py3-none-any.whl → 0.2.11.dev1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of shotgun-sh might be problematic. Click here for more details.
- shotgun/agents/agent_manager.py +524 -58
- shotgun/agents/common.py +62 -62
- shotgun/agents/config/constants.py +0 -6
- shotgun/agents/config/manager.py +14 -3
- shotgun/agents/config/models.py +16 -0
- shotgun/agents/config/provider.py +68 -13
- shotgun/agents/context_analyzer/__init__.py +28 -0
- shotgun/agents/context_analyzer/analyzer.py +493 -0
- shotgun/agents/context_analyzer/constants.py +9 -0
- shotgun/agents/context_analyzer/formatter.py +115 -0
- shotgun/agents/context_analyzer/models.py +212 -0
- shotgun/agents/conversation_history.py +125 -2
- shotgun/agents/conversation_manager.py +24 -2
- shotgun/agents/export.py +4 -5
- shotgun/agents/history/compaction.py +9 -4
- shotgun/agents/history/context_extraction.py +93 -6
- shotgun/agents/history/history_processors.py +14 -2
- shotgun/agents/history/token_counting/anthropic.py +32 -10
- shotgun/agents/models.py +50 -2
- shotgun/agents/plan.py +4 -5
- shotgun/agents/research.py +4 -5
- shotgun/agents/specify.py +4 -5
- shotgun/agents/tasks.py +4 -5
- shotgun/agents/tools/__init__.py +0 -2
- shotgun/agents/tools/codebase/codebase_shell.py +6 -0
- shotgun/agents/tools/codebase/directory_lister.py +6 -0
- shotgun/agents/tools/codebase/file_read.py +6 -0
- shotgun/agents/tools/codebase/query_graph.py +6 -0
- shotgun/agents/tools/codebase/retrieve_code.py +6 -0
- shotgun/agents/tools/file_management.py +71 -9
- shotgun/agents/tools/registry.py +217 -0
- shotgun/agents/tools/web_search/__init__.py +24 -12
- shotgun/agents/tools/web_search/anthropic.py +24 -3
- shotgun/agents/tools/web_search/gemini.py +22 -10
- shotgun/agents/tools/web_search/openai.py +21 -12
- shotgun/api_endpoints.py +7 -3
- shotgun/build_constants.py +1 -1
- shotgun/cli/clear.py +52 -0
- shotgun/cli/compact.py +186 -0
- shotgun/cli/context.py +111 -0
- shotgun/cli/models.py +1 -0
- shotgun/cli/update.py +16 -2
- shotgun/codebase/core/manager.py +10 -1
- shotgun/llm_proxy/__init__.py +5 -2
- shotgun/llm_proxy/clients.py +12 -7
- shotgun/logging_config.py +8 -10
- shotgun/main.py +70 -10
- shotgun/posthog_telemetry.py +9 -3
- shotgun/prompts/agents/export.j2 +18 -1
- shotgun/prompts/agents/partials/common_agent_system_prompt.j2 +5 -1
- shotgun/prompts/agents/partials/interactive_mode.j2 +24 -7
- shotgun/prompts/agents/plan.j2 +1 -1
- shotgun/prompts/agents/research.j2 +1 -1
- shotgun/prompts/agents/specify.j2 +270 -3
- shotgun/prompts/agents/state/system_state.j2 +4 -0
- shotgun/prompts/agents/tasks.j2 +1 -1
- shotgun/prompts/loader.py +2 -2
- shotgun/prompts/tools/web_search.j2 +14 -0
- shotgun/sentry_telemetry.py +4 -15
- shotgun/settings.py +238 -0
- shotgun/telemetry.py +15 -32
- shotgun/tui/app.py +203 -9
- shotgun/tui/commands/__init__.py +1 -1
- shotgun/tui/components/context_indicator.py +136 -0
- shotgun/tui/components/mode_indicator.py +70 -0
- shotgun/tui/components/status_bar.py +48 -0
- shotgun/tui/containers.py +93 -0
- shotgun/tui/dependencies.py +39 -0
- shotgun/tui/protocols.py +45 -0
- shotgun/tui/screens/chat/__init__.py +5 -0
- shotgun/tui/screens/chat/chat.tcss +54 -0
- shotgun/tui/screens/chat/chat_screen.py +1110 -0
- shotgun/tui/screens/chat/codebase_index_prompt_screen.py +64 -0
- shotgun/tui/screens/chat/codebase_index_selection.py +12 -0
- shotgun/tui/screens/chat/help_text.py +39 -0
- shotgun/tui/screens/chat/prompt_history.py +48 -0
- shotgun/tui/screens/chat.tcss +11 -0
- shotgun/tui/screens/chat_screen/command_providers.py +68 -2
- shotgun/tui/screens/chat_screen/history/__init__.py +22 -0
- shotgun/tui/screens/chat_screen/history/agent_response.py +66 -0
- shotgun/tui/screens/chat_screen/history/chat_history.py +116 -0
- shotgun/tui/screens/chat_screen/history/formatters.py +115 -0
- shotgun/tui/screens/chat_screen/history/partial_response.py +43 -0
- shotgun/tui/screens/chat_screen/history/user_question.py +42 -0
- shotgun/tui/screens/confirmation_dialog.py +151 -0
- shotgun/tui/screens/model_picker.py +30 -6
- shotgun/tui/screens/pipx_migration.py +153 -0
- shotgun/tui/screens/welcome.py +24 -5
- shotgun/tui/services/__init__.py +5 -0
- shotgun/tui/services/conversation_service.py +182 -0
- shotgun/tui/state/__init__.py +7 -0
- shotgun/tui/state/processing_state.py +185 -0
- shotgun/tui/widgets/__init__.py +5 -0
- shotgun/tui/widgets/widget_coordinator.py +247 -0
- shotgun/utils/datetime_utils.py +77 -0
- shotgun/utils/file_system_utils.py +3 -2
- shotgun/utils/update_checker.py +69 -14
- shotgun_sh-0.2.11.dev1.dist-info/METADATA +129 -0
- shotgun_sh-0.2.11.dev1.dist-info/RECORD +190 -0
- {shotgun_sh-0.2.3.dev2.dist-info → shotgun_sh-0.2.11.dev1.dist-info}/entry_points.txt +1 -0
- {shotgun_sh-0.2.3.dev2.dist-info → shotgun_sh-0.2.11.dev1.dist-info}/licenses/LICENSE +1 -1
- shotgun/agents/tools/user_interaction.py +0 -37
- shotgun/tui/screens/chat.py +0 -804
- shotgun/tui/screens/chat_screen/history.py +0 -352
- shotgun_sh-0.2.3.dev2.dist-info/METADATA +0 -467
- shotgun_sh-0.2.3.dev2.dist-info/RECORD +0 -154
- {shotgun_sh-0.2.3.dev2.dist-info → shotgun_sh-0.2.11.dev1.dist-info}/WHEEL +0 -0
|
@@ -0,0 +1,212 @@
|
|
|
1
|
+
"""Pydantic models for context analysis."""
|
|
2
|
+
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
from pydantic import BaseModel, Field
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class TokenAllocation(BaseModel):
    """Per-category token tallies derived from API usage data.

    Used internally by ContextAnalyzer to track token distribution across
    different message types and tool categories. Every counter is a
    non-negative integer that defaults to zero.
    """

    # --- Message categories ---
    user: int = Field(default=0, ge=0, description="Tokens from user prompts")
    agent_responses: int = Field(
        default=0,
        ge=0,
        description="Tokens from agent text responses",
    )
    system_prompts: int = Field(
        default=0,
        ge=0,
        description="Tokens from system prompts",
    )
    system_status: int = Field(
        default=0,
        ge=0,
        description="Tokens from system status messages",
    )

    # --- Tool categories ---
    codebase_understanding: int = Field(
        default=0,
        ge=0,
        description="Tokens from codebase understanding tools",
    )
    artifact_management: int = Field(
        default=0,
        ge=0,
        description="Tokens from artifact management tools",
    )
    web_research: int = Field(
        default=0,
        ge=0,
        description="Tokens from web research tools",
    )
    unknown: int = Field(
        default=0,
        ge=0,
        description="Tokens from uncategorized tools",
    )
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class MessageTypeStats(BaseModel):
    """Message count and token total for a single message category."""

    count: int = Field(ge=0, description="Number of messages of this type")
    tokens: int = Field(ge=0, description="Total tokens consumed by this type")

    @property
    def avg_tokens(self) -> float:
        """Return the mean number of tokens per message.

        Returns:
            ``tokens / count``, or ``0.0`` when there are no messages
            (avoids dividing by zero).
        """
        if self.count <= 0:
            return 0.0
        return self.tokens / self.count
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
class ContextAnalysis(BaseModel):
    """Full breakdown of what a conversation's context is made of.

    Holds one ``MessageTypeStats`` per message/tool category plus the
    aggregate totals and capacity figures for the model in use.
    """

    # --- Per-category statistics ---
    user_messages: MessageTypeStats
    agent_responses: MessageTypeStats
    system_prompts: MessageTypeStats
    system_status: MessageTypeStats
    codebase_understanding: MessageTypeStats
    artifact_management: MessageTypeStats
    web_research: MessageTypeStats
    unknown: MessageTypeStats
    hint_messages: MessageTypeStats

    # --- Aggregates and capacity ---
    total_tokens: int = Field(ge=0, description="Total tokens including hints")
    total_messages: int = Field(
        ge=0,
        description="Total message count including hints",
    )
    context_window: int = Field(ge=0, description="Model's maximum input tokens")
    agent_context_tokens: int = Field(
        ge=0,
        description="Tokens that actually consume agent context (excluding UI-only)",
    )
    model_name: str = Field(description="Name of the model being used")
    max_usable_tokens: int = Field(
        ge=0,
        description="80% of max_input_tokens (usable limit)",
    )
    # Deliberately unbounded below: the value goes negative when over capacity.
    free_space_tokens: int = Field(
        description="Remaining tokens available (negative if over capacity)"
    )

    def get_percentage(self, stats: MessageTypeStats) -> float:
        """Return the share of agent context tokens taken by one category.

        Args:
            stats: Message type statistics to calculate percentage for

        Returns:
            ``stats.tokens`` as a percentage of ``agent_context_tokens``,
            or ``0.0`` when ``agent_context_tokens`` is zero.
        """
        denominator = self.agent_context_tokens
        if denominator <= 0:
            return 0.0
        return stats.tokens / denominator * 100
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
class ContextCompositionTelemetry(BaseModel):
    """Telemetry data for context composition tracking to PostHog.

    A flat snapshot of a ``ContextAnalysis``: usage figures, per-category
    message counts, per-category token percentages and optional compaction
    metrics. Build instances with :meth:`from_analysis`.
    """

    # Context usage
    total_messages: int = Field(ge=0)
    agent_context_tokens: int = Field(ge=0)
    context_window: int = Field(ge=0)
    max_usable_tokens: int = Field(ge=0)
    # No lower bound: ContextAnalysis.free_space_tokens is negative when the
    # conversation is over capacity, and rejecting that value would make
    # from_analysis() raise exactly when the telemetry is most interesting.
    free_space_tokens: int
    # No upper bound: usage is measured against max_usable_tokens, so an
    # over-capacity conversation legitimately reports more than 100%.
    usage_percentage: float = Field(ge=0)

    # Message type counts
    user_messages_count: int = Field(ge=0)
    agent_responses_count: int = Field(ge=0)
    system_prompts_count: int = Field(ge=0)
    system_status_count: int = Field(ge=0)
    codebase_understanding_count: int = Field(ge=0)
    artifact_management_count: int = Field(ge=0)
    web_research_count: int = Field(ge=0)
    unknown_tools_count: int = Field(ge=0)

    # Token distribution percentages
    user_messages_pct: float = Field(ge=0, le=100)
    agent_responses_pct: float = Field(ge=0, le=100)
    system_prompts_pct: float = Field(ge=0, le=100)
    system_status_pct: float = Field(ge=0, le=100)
    codebase_understanding_pct: float = Field(ge=0, le=100)
    artifact_management_pct: float = Field(ge=0, le=100)
    web_research_pct: float = Field(ge=0, le=100)
    unknown_tools_pct: float = Field(ge=0, le=100)

    # Compaction info
    compaction_occurred: bool
    messages_before_compaction: int | None = None
    messages_after_compaction: int | None = None
    compaction_reduction_pct: float | None = None

    @classmethod
    def from_analysis(
        cls,
        analysis: "ContextAnalysis",
        compaction_occurred: bool = False,
        messages_before_compaction: int | None = None,
    ) -> "ContextCompositionTelemetry":
        """Create telemetry from context analysis.

        Args:
            analysis: The context analysis to convert
            compaction_occurred: Whether message compaction occurred
            messages_before_compaction: Number of messages before compaction

        Returns:
            ContextCompositionTelemetry instance. Compaction metrics are only
            populated when ``compaction_occurred`` is true and
            ``messages_before_compaction`` is provided.
        """

        def _pct(stats: "MessageTypeStats") -> float:
            # Share of agent context tokens for one category, to 1 decimal.
            return round(analysis.get_percentage(stats), 1)

        # Hint messages are excluded from the reported message count.
        total_messages = analysis.total_messages - analysis.hint_messages.count

        usage_pct = (
            round(analysis.agent_context_tokens / analysis.max_usable_tokens * 100, 1)
            if analysis.max_usable_tokens > 0
            else 0.0
        )

        # Calculate compaction metrics
        messages_after: int | None = None
        compaction_reduction_pct: float | None = None

        if compaction_occurred and messages_before_compaction is not None:
            messages_after = total_messages
            # Guard against division by zero; the reduction stays None then.
            if messages_before_compaction > 0:
                compaction_reduction_pct = round(
                    (1 - (total_messages / messages_before_compaction)) * 100, 1
                )

        return cls(
            # Context usage
            total_messages=total_messages,
            agent_context_tokens=analysis.agent_context_tokens,
            context_window=analysis.context_window,
            max_usable_tokens=analysis.max_usable_tokens,
            free_space_tokens=analysis.free_space_tokens,
            usage_percentage=usage_pct,
            # Message type counts
            user_messages_count=analysis.user_messages.count,
            agent_responses_count=analysis.agent_responses.count,
            system_prompts_count=analysis.system_prompts.count,
            system_status_count=analysis.system_status.count,
            codebase_understanding_count=analysis.codebase_understanding.count,
            artifact_management_count=analysis.artifact_management.count,
            web_research_count=analysis.web_research.count,
            unknown_tools_count=analysis.unknown.count,
            # Token distribution percentages
            user_messages_pct=_pct(analysis.user_messages),
            agent_responses_pct=_pct(analysis.agent_responses),
            system_prompts_pct=_pct(analysis.system_prompts),
            system_status_pct=_pct(analysis.system_status),
            codebase_understanding_pct=_pct(analysis.codebase_understanding),
            artifact_management_pct=_pct(analysis.artifact_management),
            web_research_pct=_pct(analysis.web_research),
            unknown_tools_pct=_pct(analysis.unknown),
            # Compaction info
            compaction_occurred=compaction_occurred,
            messages_before_compaction=messages_before_compaction,
            messages_after_compaction=messages_after,
            compaction_reduction_pct=compaction_reduction_pct,
        )
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
class ContextAnalysisOutput(BaseModel):
    """Context analysis rendered in two forms: Markdown and a JSON-style dict."""

    # Human-readable rendering.
    markdown: str = Field(
        description="Markdown-formatted analysis for display",
    )
    # Machine-readable rendering.
    json_data: dict[str, Any] = Field(
        description="JSON representation of analysis data",
    )
|
|
@@ -1,5 +1,7 @@
|
|
|
1
1
|
"""Models and utilities for persisting TUI conversation history."""
|
|
2
2
|
|
|
3
|
+
import json
|
|
4
|
+
import logging
|
|
3
5
|
from datetime import datetime
|
|
4
6
|
from typing import Any, cast
|
|
5
7
|
|
|
@@ -7,14 +9,108 @@ from pydantic import BaseModel, ConfigDict, Field
|
|
|
7
9
|
from pydantic_ai.messages import (
|
|
8
10
|
ModelMessage,
|
|
9
11
|
ModelMessagesTypeAdapter,
|
|
12
|
+
ModelResponse,
|
|
13
|
+
ToolCallPart,
|
|
10
14
|
)
|
|
11
15
|
from pydantic_core import to_jsonable_python
|
|
12
16
|
|
|
13
17
|
from shotgun.tui.screens.chat_screen.hint_message import HintMessage
|
|
14
18
|
|
|
19
|
+
__all__ = ["HintMessage", "ConversationHistory"]
|
|
20
|
+
|
|
21
|
+
logger = logging.getLogger(__name__)
|
|
22
|
+
|
|
15
23
|
SerializedMessage = dict[str, Any]
|
|
16
24
|
|
|
17
25
|
|
|
26
|
+
def is_tool_call_complete(tool_call: ToolCallPart) -> bool:
    """Report whether a tool call carries usable (fully formed) arguments.

    Args:
        tool_call: The tool call part to validate

    Returns:
        True when the args are absent, already a dict, or a string that
        parses as JSON; False for any other type or for a string that fails
        to parse (that failure is logged at INFO level).
    """
    raw_args = tool_call.args

    # Missing args and pre-parsed dicts need no JSON validation.
    if raw_args is None or isinstance(raw_args, dict):
        return True

    # Anything that is not a string cannot be a JSON payload.
    if not isinstance(raw_args, str):
        return False

    try:
        json.loads(raw_args)
    except (json.JSONDecodeError, ValueError) as e:
        # Truncate long payloads so the log entry stays small.
        if len(raw_args) > 100:
            args_preview = raw_args[:100] + "..."
        else:
            args_preview = raw_args
        logger.info(
            "Detected incomplete tool call in validation",
            extra={
                "tool_name": tool_call.tool_name,
                "tool_call_id": tool_call.tool_call_id,
                "args_preview": args_preview,
                "error": str(e),
            },
        )
        return False
    else:
        return True
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def filter_incomplete_messages(messages: list[ModelMessage]) -> list[ModelMessage]:
    """Drop model responses that contain an incomplete tool call.

    Messages that are not ``ModelResponse`` instances are always kept. A
    ``ModelResponse`` is dropped as soon as one of its ``ToolCallPart``
    parts fails ``is_tool_call_complete``.

    Args:
        messages: List of messages to filter

    Returns:
        A new list, in the original order, containing only the messages
        whose tool calls are all complete.
    """
    kept: list[ModelMessage] = []
    # One entry per dropped message: the name of its first incomplete tool.
    dropped_tool_names: list[str] = []

    for message in messages:
        if isinstance(message, ModelResponse):
            # The lazy generator stops at the first incomplete part, so later
            # parts of the same message are never validated.
            first_incomplete = next(
                (
                    part
                    for part in message.parts
                    if isinstance(part, ToolCallPart)
                    and not is_tool_call_complete(part)
                ),
                None,
            )
            if first_incomplete is not None:
                dropped_tool_names.append(first_incomplete.tool_name)
                continue

        kept.append(message)

    # Only log when something was actually removed.
    if dropped_tool_names:
        logger.info(
            "Filtered incomplete messages before saving",
            extra={
                "filtered_count": len(dropped_tool_names),
                "total_messages": len(messages),
                "filtered_tool_names": dropped_tool_names,
            },
        )

    return kept
|
|
112
|
+
|
|
113
|
+
|
|
18
114
|
class ConversationState(BaseModel):
|
|
19
115
|
"""Represents the complete state of a conversation in memory."""
|
|
20
116
|
|
|
@@ -46,14 +142,41 @@ class ConversationHistory(BaseModel):
|
|
|
46
142
|
Args:
|
|
47
143
|
messages: List of ModelMessage objects to serialize and store
|
|
48
144
|
"""
|
|
145
|
+
# Filter out messages with incomplete tool calls to prevent corruption
|
|
146
|
+
filtered_messages = filter_incomplete_messages(messages)
|
|
147
|
+
|
|
49
148
|
# Serialize ModelMessage list to JSON-serializable format
|
|
50
149
|
self.agent_history = to_jsonable_python(
|
|
51
|
-
|
|
150
|
+
filtered_messages, fallback=lambda x: str(x), exclude_none=True
|
|
52
151
|
)
|
|
53
152
|
|
|
54
153
|
def set_ui_messages(self, messages: list[ModelMessage | HintMessage]) -> None:
|
|
55
154
|
"""Set ui_history from a list of UI messages."""
|
|
56
155
|
|
|
156
|
+
# Filter out ModelMessages with incomplete tool calls (keep all HintMessages)
|
|
157
|
+
# We need to maintain message order, so we'll check each message individually
|
|
158
|
+
filtered_messages: list[ModelMessage | HintMessage] = []
|
|
159
|
+
|
|
160
|
+
for msg in messages:
|
|
161
|
+
if isinstance(msg, HintMessage):
|
|
162
|
+
# Always keep hint messages
|
|
163
|
+
filtered_messages.append(msg)
|
|
164
|
+
elif isinstance(msg, ModelResponse):
|
|
165
|
+
# Check if this ModelResponse has incomplete tool calls
|
|
166
|
+
has_incomplete = False
|
|
167
|
+
for part in msg.parts:
|
|
168
|
+
if isinstance(part, ToolCallPart) and not is_tool_call_complete(
|
|
169
|
+
part
|
|
170
|
+
):
|
|
171
|
+
has_incomplete = True
|
|
172
|
+
break
|
|
173
|
+
|
|
174
|
+
if not has_incomplete:
|
|
175
|
+
filtered_messages.append(msg)
|
|
176
|
+
else:
|
|
177
|
+
# Keep all other ModelMessage types (ModelRequest, etc.)
|
|
178
|
+
filtered_messages.append(msg)
|
|
179
|
+
|
|
57
180
|
def _serialize_message(
|
|
58
181
|
message: ModelMessage | HintMessage,
|
|
59
182
|
) -> Any:
|
|
@@ -68,7 +191,7 @@ class ConversationHistory(BaseModel):
|
|
|
68
191
|
payload.setdefault("message_type", "model")
|
|
69
192
|
return payload
|
|
70
193
|
|
|
71
|
-
self.ui_history = [_serialize_message(msg) for msg in
|
|
194
|
+
self.ui_history = [_serialize_message(msg) for msg in filtered_messages]
|
|
72
195
|
|
|
73
196
|
def get_agent_messages(self) -> list[ModelMessage]:
|
|
74
197
|
"""Get agent_history as a list of ModelMessage objects.
|
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
"""Manager for handling conversation persistence operations."""
|
|
2
2
|
|
|
3
3
|
import json
|
|
4
|
+
import shutil
|
|
4
5
|
from pathlib import Path
|
|
5
6
|
|
|
6
7
|
from shotgun.logging_config import get_logger
|
|
@@ -77,9 +78,30 @@ class ConversationManager:
|
|
|
77
78
|
)
|
|
78
79
|
return conversation
|
|
79
80
|
|
|
80
|
-
except
|
|
81
|
+
except (json.JSONDecodeError, ValueError) as e:
|
|
82
|
+
# Handle corrupted JSON or validation errors
|
|
83
|
+
logger.error(
|
|
84
|
+
"Corrupted conversation file at %s: %s. Creating backup and starting fresh.",
|
|
85
|
+
self.conversation_path,
|
|
86
|
+
e,
|
|
87
|
+
)
|
|
88
|
+
|
|
89
|
+
# Create a backup of the corrupted file for debugging
|
|
90
|
+
backup_path = self.conversation_path.with_suffix(".json.backup")
|
|
91
|
+
try:
|
|
92
|
+
shutil.copy2(self.conversation_path, backup_path)
|
|
93
|
+
logger.info("Backed up corrupted conversation to %s", backup_path)
|
|
94
|
+
except Exception as backup_error: # pragma: no cover
|
|
95
|
+
logger.warning("Failed to backup corrupted file: %s", backup_error)
|
|
96
|
+
|
|
97
|
+
return None
|
|
98
|
+
|
|
99
|
+
except Exception as e: # pragma: no cover
|
|
100
|
+
# Catch-all for unexpected errors
|
|
81
101
|
logger.error(
|
|
82
|
-
"
|
|
102
|
+
"Unexpected error loading conversation from %s: %s",
|
|
103
|
+
self.conversation_path,
|
|
104
|
+
e,
|
|
83
105
|
)
|
|
84
106
|
return None
|
|
85
107
|
|
shotgun/agents/export.py
CHANGED
|
@@ -4,7 +4,6 @@ from functools import partial
|
|
|
4
4
|
|
|
5
5
|
from pydantic_ai import (
|
|
6
6
|
Agent,
|
|
7
|
-
DeferredToolRequests,
|
|
8
7
|
)
|
|
9
8
|
from pydantic_ai.agent import AgentRunResult
|
|
10
9
|
from pydantic_ai.messages import ModelMessage
|
|
@@ -19,14 +18,14 @@ from .common import (
|
|
|
19
18
|
create_usage_limits,
|
|
20
19
|
run_agent,
|
|
21
20
|
)
|
|
22
|
-
from .models import AgentDeps, AgentRuntimeOptions, AgentType
|
|
21
|
+
from .models import AgentDeps, AgentResponse, AgentRuntimeOptions, AgentType
|
|
23
22
|
|
|
24
23
|
logger = get_logger(__name__)
|
|
25
24
|
|
|
26
25
|
|
|
27
26
|
def create_export_agent(
|
|
28
27
|
agent_runtime_options: AgentRuntimeOptions, provider: ProviderType | None = None
|
|
29
|
-
) -> tuple[Agent[AgentDeps,
|
|
28
|
+
) -> tuple[Agent[AgentDeps, AgentResponse], AgentDeps]:
|
|
30
29
|
"""Create an export agent with file management capabilities.
|
|
31
30
|
|
|
32
31
|
Args:
|
|
@@ -50,11 +49,11 @@ def create_export_agent(
|
|
|
50
49
|
|
|
51
50
|
|
|
52
51
|
async def run_export_agent(
|
|
53
|
-
agent: Agent[AgentDeps,
|
|
52
|
+
agent: Agent[AgentDeps, AgentResponse],
|
|
54
53
|
instruction: str,
|
|
55
54
|
deps: AgentDeps,
|
|
56
55
|
message_history: list[ModelMessage] | None = None,
|
|
57
|
-
) -> AgentRunResult[
|
|
56
|
+
) -> AgentRunResult[AgentResponse]:
|
|
58
57
|
"""Export artifacts based on the given instruction.
|
|
59
58
|
|
|
60
59
|
Args:
|
|
@@ -13,7 +13,7 @@ logger = get_logger(__name__)
|
|
|
13
13
|
|
|
14
14
|
|
|
15
15
|
async def apply_persistent_compaction(
|
|
16
|
-
messages: list[ModelMessage], deps: AgentDeps
|
|
16
|
+
messages: list[ModelMessage], deps: AgentDeps, force: bool = False
|
|
17
17
|
) -> list[ModelMessage]:
|
|
18
18
|
"""Apply compaction to message history for persistent storage.
|
|
19
19
|
|
|
@@ -23,6 +23,7 @@ async def apply_persistent_compaction(
|
|
|
23
23
|
Args:
|
|
24
24
|
messages: Full message history from agent run
|
|
25
25
|
deps: Agent dependencies containing model config
|
|
26
|
+
force: If True, force compaction even if below token threshold
|
|
26
27
|
|
|
27
28
|
Returns:
|
|
28
29
|
Compacted message history that should be stored as conversation state
|
|
@@ -46,7 +47,7 @@ async def apply_persistent_compaction(
|
|
|
46
47
|
self.usage = usage
|
|
47
48
|
|
|
48
49
|
ctx = MockContext(deps, usage)
|
|
49
|
-
compacted_messages = await token_limit_compactor(ctx, messages)
|
|
50
|
+
compacted_messages = await token_limit_compactor(ctx, messages, force=force)
|
|
50
51
|
|
|
51
52
|
# Log the result for monitoring
|
|
52
53
|
original_size = len(messages)
|
|
@@ -59,17 +60,21 @@ async def apply_persistent_compaction(
|
|
|
59
60
|
f"({reduction_pct:.1f}% reduction)"
|
|
60
61
|
)
|
|
61
62
|
|
|
62
|
-
# Track persistent compaction event
|
|
63
|
+
# Track persistent compaction event with simple metrics (fast, no token counting)
|
|
63
64
|
track_event(
|
|
64
65
|
"persistent_compaction_applied",
|
|
65
66
|
{
|
|
67
|
+
# Basic compaction metrics
|
|
66
68
|
"messages_before": original_size,
|
|
67
69
|
"messages_after": compacted_size,
|
|
68
|
-
"tokens_before": estimated_tokens,
|
|
69
70
|
"reduction_percentage": round(reduction_pct, 2),
|
|
70
71
|
"agent_mode": deps.agent_mode.value
|
|
71
72
|
if hasattr(deps, "agent_mode") and deps.agent_mode
|
|
72
73
|
else "unknown",
|
|
74
|
+
# Model and provider info (no computation needed)
|
|
75
|
+
"model_name": deps.llm_model.name.value,
|
|
76
|
+
"provider": deps.llm_model.provider.value,
|
|
77
|
+
"key_provider": deps.llm_model.key_provider.value,
|
|
73
78
|
},
|
|
74
79
|
)
|
|
75
80
|
else:
|
|
@@ -1,5 +1,9 @@
|
|
|
1
1
|
"""Context extraction utilities for history processing."""
|
|
2
2
|
|
|
3
|
+
import json
|
|
4
|
+
import logging
|
|
5
|
+
import traceback
|
|
6
|
+
|
|
3
7
|
from pydantic_ai.messages import (
|
|
4
8
|
BuiltinToolCallPart,
|
|
5
9
|
BuiltinToolReturnPart,
|
|
@@ -16,6 +20,46 @@ from pydantic_ai.messages import (
|
|
|
16
20
|
UserPromptPart,
|
|
17
21
|
)
|
|
18
22
|
|
|
23
|
+
logger = logging.getLogger(__name__)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def _safely_parse_tool_args(args: dict[str, object] | str | None) -> dict[str, object]:
|
|
27
|
+
"""Safely parse tool call arguments, handling incomplete/invalid JSON.
|
|
28
|
+
|
|
29
|
+
Args:
|
|
30
|
+
args: Tool call arguments (dict, JSON string, or None)
|
|
31
|
+
|
|
32
|
+
Returns:
|
|
33
|
+
Parsed args dict, or empty dict if parsing fails
|
|
34
|
+
"""
|
|
35
|
+
if args is None:
|
|
36
|
+
return {}
|
|
37
|
+
|
|
38
|
+
if isinstance(args, dict):
|
|
39
|
+
return args
|
|
40
|
+
|
|
41
|
+
if not isinstance(args, str):
|
|
42
|
+
return {}
|
|
43
|
+
|
|
44
|
+
try:
|
|
45
|
+
parsed = json.loads(args)
|
|
46
|
+
return parsed if isinstance(parsed, dict) else {}
|
|
47
|
+
except (json.JSONDecodeError, ValueError) as e:
|
|
48
|
+
# Only log warning if it looks like JSON (starts with { or [) - incomplete JSON
|
|
49
|
+
# Plain strings are valid args and shouldn't trigger warnings
|
|
50
|
+
stripped_args = args.strip()
|
|
51
|
+
if stripped_args.startswith(("{", "[")):
|
|
52
|
+
args_preview = args[:100] + "..." if len(args) > 100 else args
|
|
53
|
+
logger.warning(
|
|
54
|
+
"Detected incomplete/invalid JSON in tool call args during parsing",
|
|
55
|
+
extra={
|
|
56
|
+
"args_preview": args_preview,
|
|
57
|
+
"error": str(e),
|
|
58
|
+
"args_length": len(args),
|
|
59
|
+
},
|
|
60
|
+
)
|
|
61
|
+
return {}
|
|
62
|
+
|
|
19
63
|
|
|
20
64
|
def extract_context_from_messages(messages: list[ModelMessage]) -> str:
|
|
21
65
|
"""Extract context from a list of messages for summarization."""
|
|
@@ -87,12 +131,55 @@ def extract_context_from_part(
|
|
|
87
131
|
return f"<ASSISTANT_TEXT>\n{message_part.content}\n</ASSISTANT_TEXT>"
|
|
88
132
|
|
|
89
133
|
elif isinstance(message_part, ToolCallPart):
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
134
|
+
# Safely parse args to avoid crashes from incomplete JSON during streaming
|
|
135
|
+
try:
|
|
136
|
+
parsed_args = _safely_parse_tool_args(message_part.args)
|
|
137
|
+
if parsed_args:
|
|
138
|
+
# Successfully parsed as dict - format nicely
|
|
139
|
+
args_str = ", ".join(f"{k}={repr(v)}" for k, v in parsed_args.items())
|
|
140
|
+
tool_call_str = f"{message_part.tool_name}({args_str})"
|
|
141
|
+
elif isinstance(message_part.args, str) and message_part.args:
|
|
142
|
+
# Non-empty string that didn't parse as JSON
|
|
143
|
+
# Check if it looks like JSON (starts with { or [) - if so, it's incomplete
|
|
144
|
+
stripped_args = message_part.args.strip()
|
|
145
|
+
if stripped_args.startswith(("{", "[")):
|
|
146
|
+
# Looks like incomplete JSON - log warning and show empty parens
|
|
147
|
+
args_preview = (
|
|
148
|
+
stripped_args[:100] + "..."
|
|
149
|
+
if len(stripped_args) > 100
|
|
150
|
+
else stripped_args
|
|
151
|
+
)
|
|
152
|
+
stack_trace = "".join(traceback.format_stack())
|
|
153
|
+
logger.warning(
|
|
154
|
+
"ToolCallPart with unparseable args encountered during context extraction",
|
|
155
|
+
extra={
|
|
156
|
+
"tool_name": message_part.tool_name,
|
|
157
|
+
"tool_call_id": message_part.tool_call_id,
|
|
158
|
+
"args_preview": args_preview,
|
|
159
|
+
"args_type": type(message_part.args).__name__,
|
|
160
|
+
"stack_trace": stack_trace,
|
|
161
|
+
},
|
|
162
|
+
)
|
|
163
|
+
tool_call_str = f"{message_part.tool_name}()"
|
|
164
|
+
else:
|
|
165
|
+
# Plain string arg - display as-is
|
|
166
|
+
tool_call_str = f"{message_part.tool_name}({message_part.args})"
|
|
167
|
+
else:
|
|
168
|
+
# No args
|
|
169
|
+
tool_call_str = f"{message_part.tool_name}()"
|
|
170
|
+
return f"<TOOL_CALL>\n{tool_call_str}\n</TOOL_CALL>"
|
|
171
|
+
except Exception as e: # pragma: no cover - defensive catch-all
|
|
172
|
+
# If anything goes wrong, log full exception with stack trace
|
|
173
|
+
logger.error(
|
|
174
|
+
"Unexpected error processing ToolCallPart",
|
|
175
|
+
exc_info=True,
|
|
176
|
+
extra={
|
|
177
|
+
"tool_name": message_part.tool_name,
|
|
178
|
+
"tool_call_id": message_part.tool_call_id,
|
|
179
|
+
"error": str(e),
|
|
180
|
+
},
|
|
181
|
+
)
|
|
182
|
+
return f"<TOOL_CALL>\n{message_part.tool_name}()\n</TOOL_CALL>"
|
|
96
183
|
|
|
97
184
|
elif isinstance(message_part, BuiltinToolCallPart):
|
|
98
185
|
return f"<BUILTIN_TOOL_CALL>\n{message_part.tool_name}\n</BUILTIN_TOOL_CALL>"
|