shotgun-sh 0.2.8.dev2__py3-none-any.whl → 0.3.3.dev1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- shotgun/agents/agent_manager.py +382 -60
- shotgun/agents/common.py +15 -9
- shotgun/agents/config/README.md +89 -0
- shotgun/agents/config/__init__.py +10 -1
- shotgun/agents/config/constants.py +0 -6
- shotgun/agents/config/manager.py +383 -82
- shotgun/agents/config/models.py +122 -18
- shotgun/agents/config/provider.py +81 -15
- shotgun/agents/config/streaming_test.py +119 -0
- shotgun/agents/context_analyzer/__init__.py +28 -0
- shotgun/agents/context_analyzer/analyzer.py +475 -0
- shotgun/agents/context_analyzer/constants.py +9 -0
- shotgun/agents/context_analyzer/formatter.py +115 -0
- shotgun/agents/context_analyzer/models.py +212 -0
- shotgun/agents/conversation/__init__.py +18 -0
- shotgun/agents/conversation/filters.py +164 -0
- shotgun/agents/conversation/history/chunking.py +278 -0
- shotgun/agents/{history → conversation/history}/compaction.py +36 -5
- shotgun/agents/{history → conversation/history}/constants.py +5 -0
- shotgun/agents/conversation/history/file_content_deduplication.py +216 -0
- shotgun/agents/{history → conversation/history}/history_processors.py +380 -8
- shotgun/agents/{history → conversation/history}/token_counting/anthropic.py +25 -1
- shotgun/agents/{history → conversation/history}/token_counting/base.py +14 -3
- shotgun/agents/{history → conversation/history}/token_counting/openai.py +11 -1
- shotgun/agents/{history → conversation/history}/token_counting/sentencepiece_counter.py +8 -0
- shotgun/agents/{history → conversation/history}/token_counting/tokenizer_cache.py +3 -1
- shotgun/agents/{history → conversation/history}/token_counting/utils.py +0 -3
- shotgun/agents/{conversation_manager.py → conversation/manager.py} +36 -20
- shotgun/agents/{conversation_history.py → conversation/models.py} +8 -92
- shotgun/agents/error/__init__.py +11 -0
- shotgun/agents/error/models.py +19 -0
- shotgun/agents/export.py +2 -2
- shotgun/agents/plan.py +2 -2
- shotgun/agents/research.py +3 -3
- shotgun/agents/runner.py +230 -0
- shotgun/agents/specify.py +2 -2
- shotgun/agents/tasks.py +2 -2
- shotgun/agents/tools/codebase/codebase_shell.py +6 -0
- shotgun/agents/tools/codebase/directory_lister.py +6 -0
- shotgun/agents/tools/codebase/file_read.py +11 -2
- shotgun/agents/tools/codebase/query_graph.py +6 -0
- shotgun/agents/tools/codebase/retrieve_code.py +6 -0
- shotgun/agents/tools/file_management.py +27 -7
- shotgun/agents/tools/registry.py +217 -0
- shotgun/agents/tools/web_search/__init__.py +8 -8
- shotgun/agents/tools/web_search/anthropic.py +8 -2
- shotgun/agents/tools/web_search/gemini.py +7 -1
- shotgun/agents/tools/web_search/openai.py +8 -2
- shotgun/agents/tools/web_search/utils.py +2 -2
- shotgun/agents/usage_manager.py +16 -11
- shotgun/api_endpoints.py +7 -3
- shotgun/build_constants.py +2 -2
- shotgun/cli/clear.py +53 -0
- shotgun/cli/compact.py +188 -0
- shotgun/cli/config.py +8 -5
- shotgun/cli/context.py +154 -0
- shotgun/cli/error_handler.py +24 -0
- shotgun/cli/export.py +34 -34
- shotgun/cli/feedback.py +4 -2
- shotgun/cli/models.py +1 -0
- shotgun/cli/plan.py +34 -34
- shotgun/cli/research.py +18 -10
- shotgun/cli/spec/__init__.py +5 -0
- shotgun/cli/spec/backup.py +81 -0
- shotgun/cli/spec/commands.py +132 -0
- shotgun/cli/spec/models.py +48 -0
- shotgun/cli/spec/pull_service.py +219 -0
- shotgun/cli/specify.py +20 -19
- shotgun/cli/tasks.py +34 -34
- shotgun/cli/update.py +16 -2
- shotgun/codebase/core/change_detector.py +5 -3
- shotgun/codebase/core/code_retrieval.py +4 -2
- shotgun/codebase/core/ingestor.py +163 -15
- shotgun/codebase/core/manager.py +13 -4
- shotgun/codebase/core/nl_query.py +1 -1
- shotgun/codebase/models.py +2 -0
- shotgun/exceptions.py +357 -0
- shotgun/llm_proxy/__init__.py +17 -0
- shotgun/llm_proxy/client.py +215 -0
- shotgun/llm_proxy/models.py +137 -0
- shotgun/logging_config.py +60 -27
- shotgun/main.py +77 -11
- shotgun/posthog_telemetry.py +38 -29
- shotgun/prompts/agents/partials/common_agent_system_prompt.j2 +28 -2
- shotgun/prompts/agents/partials/interactive_mode.j2 +3 -3
- shotgun/prompts/agents/plan.j2 +16 -0
- shotgun/prompts/agents/research.j2 +16 -3
- shotgun/prompts/agents/specify.j2 +54 -1
- shotgun/prompts/agents/state/system_state.j2 +0 -2
- shotgun/prompts/agents/tasks.j2 +16 -0
- shotgun/prompts/history/chunk_summarization.j2 +34 -0
- shotgun/prompts/history/combine_summaries.j2 +53 -0
- shotgun/sdk/codebase.py +14 -3
- shotgun/sentry_telemetry.py +163 -16
- shotgun/settings.py +243 -0
- shotgun/shotgun_web/__init__.py +67 -1
- shotgun/shotgun_web/client.py +42 -1
- shotgun/shotgun_web/constants.py +46 -0
- shotgun/shotgun_web/exceptions.py +29 -0
- shotgun/shotgun_web/models.py +390 -0
- shotgun/shotgun_web/shared_specs/__init__.py +32 -0
- shotgun/shotgun_web/shared_specs/file_scanner.py +175 -0
- shotgun/shotgun_web/shared_specs/hasher.py +83 -0
- shotgun/shotgun_web/shared_specs/models.py +71 -0
- shotgun/shotgun_web/shared_specs/upload_pipeline.py +329 -0
- shotgun/shotgun_web/shared_specs/utils.py +34 -0
- shotgun/shotgun_web/specs_client.py +703 -0
- shotgun/shotgun_web/supabase_client.py +31 -0
- shotgun/telemetry.py +10 -33
- shotgun/tui/app.py +310 -46
- shotgun/tui/commands/__init__.py +1 -1
- shotgun/tui/components/context_indicator.py +179 -0
- shotgun/tui/components/mode_indicator.py +70 -0
- shotgun/tui/components/status_bar.py +48 -0
- shotgun/tui/containers.py +91 -0
- shotgun/tui/dependencies.py +39 -0
- shotgun/tui/layout.py +5 -0
- shotgun/tui/protocols.py +45 -0
- shotgun/tui/screens/chat/__init__.py +5 -0
- shotgun/tui/screens/chat/chat.tcss +54 -0
- shotgun/tui/screens/chat/chat_screen.py +1531 -0
- shotgun/tui/screens/chat/codebase_index_prompt_screen.py +243 -0
- shotgun/tui/screens/chat/codebase_index_selection.py +12 -0
- shotgun/tui/screens/chat/help_text.py +40 -0
- shotgun/tui/screens/chat/prompt_history.py +48 -0
- shotgun/tui/screens/chat.tcss +11 -0
- shotgun/tui/screens/chat_screen/command_providers.py +91 -4
- shotgun/tui/screens/chat_screen/hint_message.py +76 -1
- shotgun/tui/screens/chat_screen/history/__init__.py +22 -0
- shotgun/tui/screens/chat_screen/history/agent_response.py +66 -0
- shotgun/tui/screens/chat_screen/history/chat_history.py +115 -0
- shotgun/tui/screens/chat_screen/history/formatters.py +115 -0
- shotgun/tui/screens/chat_screen/history/partial_response.py +43 -0
- shotgun/tui/screens/chat_screen/history/user_question.py +42 -0
- shotgun/tui/screens/confirmation_dialog.py +191 -0
- shotgun/tui/screens/directory_setup.py +45 -41
- shotgun/tui/screens/feedback.py +14 -7
- shotgun/tui/screens/github_issue.py +111 -0
- shotgun/tui/screens/model_picker.py +77 -32
- shotgun/tui/screens/onboarding.py +580 -0
- shotgun/tui/screens/pipx_migration.py +205 -0
- shotgun/tui/screens/provider_config.py +116 -35
- shotgun/tui/screens/shared_specs/__init__.py +21 -0
- shotgun/tui/screens/shared_specs/create_spec_dialog.py +273 -0
- shotgun/tui/screens/shared_specs/models.py +56 -0
- shotgun/tui/screens/shared_specs/share_specs_dialog.py +390 -0
- shotgun/tui/screens/shared_specs/upload_progress_screen.py +452 -0
- shotgun/tui/screens/shotgun_auth.py +112 -18
- shotgun/tui/screens/spec_pull.py +288 -0
- shotgun/tui/screens/welcome.py +137 -11
- shotgun/tui/services/__init__.py +5 -0
- shotgun/tui/services/conversation_service.py +187 -0
- shotgun/tui/state/__init__.py +7 -0
- shotgun/tui/state/processing_state.py +185 -0
- shotgun/tui/utils/mode_progress.py +14 -7
- shotgun/tui/widgets/__init__.py +5 -0
- shotgun/tui/widgets/widget_coordinator.py +263 -0
- shotgun/utils/file_system_utils.py +22 -2
- shotgun/utils/marketing.py +110 -0
- shotgun/utils/update_checker.py +69 -14
- shotgun_sh-0.3.3.dev1.dist-info/METADATA +472 -0
- shotgun_sh-0.3.3.dev1.dist-info/RECORD +229 -0
- {shotgun_sh-0.2.8.dev2.dist-info → shotgun_sh-0.3.3.dev1.dist-info}/WHEEL +1 -1
- {shotgun_sh-0.2.8.dev2.dist-info → shotgun_sh-0.3.3.dev1.dist-info}/entry_points.txt +1 -0
- {shotgun_sh-0.2.8.dev2.dist-info → shotgun_sh-0.3.3.dev1.dist-info}/licenses/LICENSE +1 -1
- shotgun/tui/screens/chat.py +0 -996
- shotgun/tui/screens/chat_screen/history.py +0 -335
- shotgun_sh-0.2.8.dev2.dist-info/METADATA +0 -126
- shotgun_sh-0.2.8.dev2.dist-info/RECORD +0 -155
- /shotgun/agents/{history → conversation/history}/__init__.py +0 -0
- /shotgun/agents/{history → conversation/history}/context_extraction.py +0 -0
- /shotgun/agents/{history → conversation/history}/history_building.py +0 -0
- /shotgun/agents/{history → conversation/history}/message_utils.py +0 -0
- /shotgun/agents/{history → conversation/history}/token_counting/__init__.py +0 -0
- /shotgun/agents/{history → conversation/history}/token_estimation.py +0 -0

shotgun/agents/context_analyzer/models.py
@@ -0,0 +1,212 @@
"""Pydantic models for context analysis."""

from typing import Any

from pydantic import BaseModel, Field


class TokenAllocation(BaseModel):
    """Token counts allocated from API usage data by message/tool type.

    Used internally by ContextAnalyzer to track token distribution across
    different message types and tool categories.
    """

    user: int = Field(ge=0, default=0, description="Tokens from user prompts")
    agent_responses: int = Field(
        ge=0, default=0, description="Tokens from agent text responses"
    )
    system_prompts: int = Field(
        ge=0, default=0, description="Tokens from system prompts"
    )
    system_status: int = Field(
        ge=0, default=0, description="Tokens from system status messages"
    )
    codebase_understanding: int = Field(
        ge=0, default=0, description="Tokens from codebase understanding tools"
    )
    artifact_management: int = Field(
        ge=0, default=0, description="Tokens from artifact management tools"
    )
    web_research: int = Field(
        ge=0, default=0, description="Tokens from web research tools"
    )
    unknown: int = Field(ge=0, default=0, description="Tokens from uncategorized tools")


class MessageTypeStats(BaseModel):
    """Statistics for a specific message type."""

    count: int = Field(ge=0, description="Number of messages of this type")
    tokens: int = Field(ge=0, description="Total tokens consumed by this type")

    @property
    def avg_tokens(self) -> float:
        """Calculate average tokens per message."""
        return self.tokens / self.count if self.count > 0 else 0.0


class ContextAnalysis(BaseModel):
    """Complete analysis of conversation context composition."""

    user_messages: MessageTypeStats
    agent_responses: MessageTypeStats
    system_prompts: MessageTypeStats
    system_status: MessageTypeStats
    codebase_understanding: MessageTypeStats
    artifact_management: MessageTypeStats
    web_research: MessageTypeStats
    unknown: MessageTypeStats
    hint_messages: MessageTypeStats
    total_tokens: int = Field(ge=0, description="Total tokens including hints")
    total_messages: int = Field(ge=0, description="Total message count including hints")
    context_window: int = Field(ge=0, description="Model's maximum input tokens")
    agent_context_tokens: int = Field(
        ge=0,
        description="Tokens that actually consume agent context (excluding UI-only)",
    )
    model_name: str = Field(description="Name of the model being used")
    max_usable_tokens: int = Field(
        ge=0, description="80% of max_input_tokens (usable limit)"
    )
    free_space_tokens: int = Field(
        description="Remaining tokens available (negative if over capacity)"
    )

    def get_percentage(self, stats: MessageTypeStats) -> float:
        """Calculate percentage of agent context tokens for a message type.

        Args:
            stats: Message type statistics to calculate percentage for

        Returns:
            Percentage of total agent context tokens (0-100)
        """
        return (
            (stats.tokens / self.agent_context_tokens * 100)
            if self.agent_context_tokens > 0
            else 0.0
        )


class ContextCompositionTelemetry(BaseModel):
    """Telemetry data for context composition tracking to PostHog."""

    # Context usage
    total_messages: int = Field(ge=0)
    agent_context_tokens: int = Field(ge=0)
    context_window: int = Field(ge=0)
    max_usable_tokens: int = Field(ge=0)
    free_space_tokens: int = Field(ge=0)
    usage_percentage: float = Field(ge=0, le=100)

    # Message type counts
    user_messages_count: int = Field(ge=0)
    agent_responses_count: int = Field(ge=0)
    system_prompts_count: int = Field(ge=0)
    system_status_count: int = Field(ge=0)
    codebase_understanding_count: int = Field(ge=0)
    artifact_management_count: int = Field(ge=0)
    web_research_count: int = Field(ge=0)
    unknown_tools_count: int = Field(ge=0)

    # Token distribution percentages
    user_messages_pct: float = Field(ge=0, le=100)
    agent_responses_pct: float = Field(ge=0, le=100)
    system_prompts_pct: float = Field(ge=0, le=100)
    system_status_pct: float = Field(ge=0, le=100)
    codebase_understanding_pct: float = Field(ge=0, le=100)
    artifact_management_pct: float = Field(ge=0, le=100)
    web_research_pct: float = Field(ge=0, le=100)
    unknown_tools_pct: float = Field(ge=0, le=100)

    # Compaction info
    compaction_occurred: bool
    messages_before_compaction: int | None = None
    messages_after_compaction: int | None = None
    compaction_reduction_pct: float | None = None

    @classmethod
    def from_analysis(
        cls,
        analysis: "ContextAnalysis",
        compaction_occurred: bool = False,
        messages_before_compaction: int | None = None,
    ) -> "ContextCompositionTelemetry":
        """Create telemetry from context analysis.

        Args:
            analysis: The context analysis to convert
            compaction_occurred: Whether message compaction occurred
            messages_before_compaction: Number of messages before compaction

        Returns:
            ContextCompositionTelemetry instance
        """
        total_messages = analysis.total_messages - analysis.hint_messages.count
        usage_pct = (
            round((analysis.agent_context_tokens / analysis.max_usable_tokens * 100), 1)
            if analysis.max_usable_tokens > 0
            else 0
        )

        # Calculate compaction metrics
        messages_after: int | None = None
        compaction_reduction_pct: float | None = None

        if compaction_occurred and messages_before_compaction is not None:
            messages_after = total_messages
            if messages_before_compaction > 0:
                compaction_reduction_pct = round(
                    (1 - (total_messages / messages_before_compaction)) * 100, 1
                )

        return cls(
            # Context usage
            total_messages=total_messages,
            agent_context_tokens=analysis.agent_context_tokens,
            context_window=analysis.context_window,
            max_usable_tokens=analysis.max_usable_tokens,
            free_space_tokens=analysis.free_space_tokens,
            usage_percentage=usage_pct,
            # Message type counts
            user_messages_count=analysis.user_messages.count,
            agent_responses_count=analysis.agent_responses.count,
            system_prompts_count=analysis.system_prompts.count,
            system_status_count=analysis.system_status.count,
            codebase_understanding_count=analysis.codebase_understanding.count,
            artifact_management_count=analysis.artifact_management.count,
            web_research_count=analysis.web_research.count,
            unknown_tools_count=analysis.unknown.count,
            # Token distribution percentages
            user_messages_pct=round(analysis.get_percentage(analysis.user_messages), 1),
            agent_responses_pct=round(
                analysis.get_percentage(analysis.agent_responses), 1
            ),
            system_prompts_pct=round(
                analysis.get_percentage(analysis.system_prompts), 1
            ),
            system_status_pct=round(analysis.get_percentage(analysis.system_status), 1),
            codebase_understanding_pct=round(
                analysis.get_percentage(analysis.codebase_understanding), 1
            ),
            artifact_management_pct=round(
                analysis.get_percentage(analysis.artifact_management), 1
            ),
            web_research_pct=round(analysis.get_percentage(analysis.web_research), 1),
            unknown_tools_pct=round(analysis.get_percentage(analysis.unknown), 1),
            # Compaction info
            compaction_occurred=compaction_occurred,
            messages_before_compaction=messages_before_compaction,
            messages_after_compaction=messages_after,
            compaction_reduction_pct=compaction_reduction_pct,
        )


class ContextAnalysisOutput(BaseModel):
    """Output format for context analysis with multiple representations."""

    markdown: str = Field(description="Markdown-formatted analysis for display")
    json_data: dict[str, Any] = Field(
        description="JSON representation of analysis data"
    )
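
For orientation, a minimal sketch of how these models compose. The numbers are invented, the import path is inferred from the file list above, and the stats() helper is not part of the package:

from shotgun.agents.context_analyzer.models import (
    ContextAnalysis,
    ContextCompositionTelemetry,
    MessageTypeStats,
)

def stats(count: int, tokens: int) -> MessageTypeStats:
    # Small helper for building illustrative per-type statistics.
    return MessageTypeStats(count=count, tokens=tokens)

analysis = ContextAnalysis(
    user_messages=stats(5, 2_000),
    agent_responses=stats(5, 6_000),
    system_prompts=stats(1, 3_000),
    system_status=stats(1, 500),
    codebase_understanding=stats(4, 20_000),
    artifact_management=stats(2, 4_000),
    web_research=stats(1, 4_000),
    unknown=stats(0, 0),
    hint_messages=stats(2, 500),   # UI-only hint messages
    total_tokens=40_000,
    total_messages=21,
    context_window=200_000,
    agent_context_tokens=39_500,   # 40_000 minus the UI-only hint tokens
    model_name="example-model",
    max_usable_tokens=160_000,     # 80% of the context window
    free_space_tokens=120_500,     # 160_000 - 39_500
)

print(analysis.user_messages.avg_tokens)                                    # 400.0
print(round(analysis.get_percentage(analysis.codebase_understanding), 1))   # 50.6

telemetry = ContextCompositionTelemetry.from_analysis(
    analysis, compaction_occurred=True, messages_before_compaction=30
)
print(telemetry.usage_percentage)          # 24.7 (39_500 of 160_000 usable tokens)
print(telemetry.compaction_reduction_pct)  # 36.7 (19 non-hint messages kept of 30)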

shotgun/agents/conversation/__init__.py
@@ -0,0 +1,18 @@
"""Conversation module for managing conversation history and persistence."""

from .filters import (
    filter_incomplete_messages,
    filter_orphaned_tool_responses,
    is_tool_call_complete,
)
from .manager import ConversationManager
from .models import ConversationHistory, ConversationState

__all__ = [
    "ConversationHistory",
    "ConversationManager",
    "ConversationState",
    "filter_incomplete_messages",
    "filter_orphaned_tool_responses",
    "is_tool_call_complete",
]
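
Because of these re-exports, callers can import the filters and the manager from the package root rather than from the individual submodules:

from shotgun.agents.conversation import (
    ConversationManager,
    filter_incomplete_messages,
    filter_orphaned_tool_responses,
)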

shotgun/agents/conversation/filters.py
@@ -0,0 +1,164 @@
"""Filter functions for conversation message validation."""

import json
import logging

from pydantic_ai.messages import (
    ModelMessage,
    ModelRequest,
    ModelRequestPart,
    ModelResponse,
    ToolCallPart,
    ToolReturnPart,
)

logger = logging.getLogger(__name__)


def is_tool_call_complete(tool_call: ToolCallPart) -> bool:
    """Check if a tool call has valid, complete JSON arguments.

    Args:
        tool_call: The tool call part to validate

    Returns:
        True if the tool call args are valid JSON, False otherwise
    """
    if tool_call.args is None:
        return True  # No args is valid

    if isinstance(tool_call.args, dict):
        return True  # Already parsed dict is valid

    if not isinstance(tool_call.args, str):
        return False

    # Try to parse the JSON string
    try:
        json.loads(tool_call.args)
        return True
    except (json.JSONDecodeError, ValueError) as e:
        # Log incomplete tool call detection
        args_preview = (
            tool_call.args[:100] + "..."
            if len(tool_call.args) > 100
            else tool_call.args
        )
        logger.info(
            "Detected incomplete tool call in validation",
            extra={
                "tool_name": tool_call.tool_name,
                "tool_call_id": tool_call.tool_call_id,
                "args_preview": args_preview,
                "error": str(e),
            },
        )
        return False


def filter_incomplete_messages(messages: list[ModelMessage]) -> list[ModelMessage]:
    """Filter out messages with incomplete tool calls.

    Args:
        messages: List of messages to filter

    Returns:
        List of messages with only complete tool calls
    """
    filtered: list[ModelMessage] = []
    filtered_count = 0
    filtered_tool_names: list[str] = []

    for message in messages:
        # Only check ModelResponse messages for tool calls
        if not isinstance(message, ModelResponse):
            filtered.append(message)
            continue

        # Check if any tool calls are incomplete
        has_incomplete_tool_call = False
        for part in message.parts:
            if isinstance(part, ToolCallPart) and not is_tool_call_complete(part):
                has_incomplete_tool_call = True
                filtered_tool_names.append(part.tool_name)
                break

        # Only include messages without incomplete tool calls
        if not has_incomplete_tool_call:
            filtered.append(message)
        else:
            filtered_count += 1

    # Log if any messages were filtered
    if filtered_count > 0:
        logger.info(
            "Filtered incomplete messages before saving",
            extra={
                "filtered_count": filtered_count,
                "total_messages": len(messages),
                "filtered_tool_names": filtered_tool_names,
            },
        )

    return filtered


def filter_orphaned_tool_responses(messages: list[ModelMessage]) -> list[ModelMessage]:
    """Filter out tool responses without corresponding tool calls.

    This ensures message history is valid for OpenAI API which requires
    tool responses to follow their corresponding tool calls.

    Args:
        messages: List of messages to filter

    Returns:
        List of messages with orphaned tool responses removed
    """
    # Collect all tool_call_ids from ToolCallPart in ModelResponse
    valid_tool_call_ids: set[str] = set()
    for msg in messages:
        if isinstance(msg, ModelResponse):
            for part in msg.parts:
                if isinstance(part, ToolCallPart) and part.tool_call_id:
                    valid_tool_call_ids.add(part.tool_call_id)

    # Filter out orphaned ToolReturnPart from ModelRequest
    filtered: list[ModelMessage] = []
    orphaned_count = 0
    orphaned_tool_names: list[str] = []

    for msg in messages:
        if isinstance(msg, ModelRequest):
            # Filter parts, removing orphaned ToolReturnPart
            filtered_parts: list[ModelRequestPart] = []
            request_part: ModelRequestPart
            for request_part in msg.parts:
                if isinstance(request_part, ToolReturnPart):
                    if request_part.tool_call_id in valid_tool_call_ids:
                        filtered_parts.append(request_part)
                    else:
                        # Skip orphaned tool response
                        orphaned_count += 1
                        orphaned_tool_names.append(request_part.tool_name or "unknown")
                else:
                    filtered_parts.append(request_part)

            # Only add if there are remaining parts
            if filtered_parts:
                filtered.append(ModelRequest(parts=filtered_parts))
        else:
            filtered.append(msg)

    # Log if any tool responses were filtered
    if orphaned_count > 0:
        logger.info(
            "Filtered orphaned tool responses",
            extra={
                "orphaned_count": orphaned_count,
                "total_messages": len(messages),
                "orphaned_tool_names": orphaned_tool_names,
            },
        )

    return filtered
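
A minimal sketch of the two filters chained together before persisting history. The tool name and arguments are invented, and the message construction assumes pydantic_ai's message dataclasses accept these keyword fields:

from pydantic_ai.messages import (
    ModelRequest,
    ModelResponse,
    ToolCallPart,
    ToolReturnPart,
    UserPromptPart,
)

from shotgun.agents.conversation import (
    filter_incomplete_messages,
    filter_orphaned_tool_responses,
)

messages = [
    ModelRequest(parts=[UserPromptPart(content="List the files in the repo")]),
    # A tool call whose streamed args were cut off, leaving invalid JSON.
    ModelResponse(
        parts=[
            ToolCallPart(
                tool_name="directory_lister",
                args='{"path": "/repo',
                tool_call_id="call_1",
            )
        ]
    ),
    # The matching tool return; it becomes orphaned once the call is dropped.
    ModelRequest(
        parts=[
            ToolReturnPart(
                tool_name="directory_lister",
                content="...",
                tool_call_id="call_1",
            )
        ]
    ),
]

cleaned = filter_orphaned_tool_responses(filter_incomplete_messages(messages))
# Only the user prompt survives: the truncated call is removed first, and the
# now-orphaned tool return (and its emptied ModelRequest) is stripped afterwards.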

shotgun/agents/conversation/history/chunking.py
@@ -0,0 +1,278 @@
"""Pattern-based chunking for oversized conversation compaction.

This module provides functions to break oversized conversations into logical
chunks for summarization, preserving semantic units like tool call sequences.
"""

import logging
from dataclasses import dataclass, field

from pydantic_ai.messages import (
    ModelMessage,
    ModelRequest,
    ModelResponse,
    ToolCallPart,
    ToolReturnPart,
    UserPromptPart,
)

from shotgun.agents.config.models import ModelConfig

from .constants import CHUNK_TARGET_RATIO, RETENTION_WINDOW_MESSAGES
from .token_estimation import estimate_tokens_from_messages

logger = logging.getLogger(__name__)


@dataclass
class MessageGroup:
    """A logical group of messages that must stay together.

    Examples:
        - A single user message
        - A tool call sequence: ModelResponse(ToolCallPart) -> ModelRequest(ToolReturnPart)
        - A standalone assistant response
    """

    messages: list[ModelMessage]
    is_tool_sequence: bool = False
    start_index: int = 0
    end_index: int = 0
    _token_count: int | None = field(default=None, repr=False)

    async def get_token_count(self, model_config: ModelConfig) -> int:
        """Lazily compute and cache token count for this group."""
        if self._token_count is None:
            self._token_count = await estimate_tokens_from_messages(
                self.messages, model_config
            )
        return self._token_count


@dataclass
class Chunk:
    """A chunk of message groups ready for summarization."""

    groups: list[MessageGroup]
    chunk_index: int
    total_token_estimate: int = 0

    def get_all_messages(self) -> list[ModelMessage]:
        """Flatten all messages in this chunk."""
        messages: list[ModelMessage] = []
        for group in self.groups:
            messages.extend(group.messages)
        return messages


def identify_message_groups(messages: list[ModelMessage]) -> list[MessageGroup]:
    """Identify logical message groups that must stay together.

    Rules:
    1. Tool calls must include their responses (matched by tool_call_id)
    2. User messages are individual groups
    3. Standalone assistant responses are individual groups

    Args:
        messages: The full message history

    Returns:
        List of MessageGroup objects
    """
    groups: list[MessageGroup] = []

    # Track pending tool calls that need their returns
    # Maps tool_call_id -> group index
    pending_tool_calls: dict[str, int] = {}

    for i, msg in enumerate(messages):
        if isinstance(msg, ModelResponse):
            # Check for tool calls in response
            tool_calls = [p for p in msg.parts if isinstance(p, ToolCallPart)]

            if tool_calls:
                # Start a tool sequence group
                group = MessageGroup(
                    messages=[msg],
                    is_tool_sequence=True,
                    start_index=i,
                    end_index=i,
                )
                group_idx = len(groups)
                groups.append(group)

                # Track all tool call IDs in this response
                for tc in tool_calls:
                    if tc.tool_call_id:
                        pending_tool_calls[tc.tool_call_id] = group_idx
            else:
                # Standalone assistant response (text only)
                groups.append(
                    MessageGroup(
                        messages=[msg],
                        is_tool_sequence=False,
                        start_index=i,
                        end_index=i,
                    )
                )

        elif isinstance(msg, ModelRequest):
            # Check for tool returns in request
            tool_returns = [p for p in msg.parts if isinstance(p, ToolReturnPart)]
            user_prompts = [p for p in msg.parts if isinstance(p, UserPromptPart)]

            if tool_returns:
                # Add to corresponding tool call groups
                for tr in tool_returns:
                    if tr.tool_call_id and tr.tool_call_id in pending_tool_calls:
                        group_idx = pending_tool_calls.pop(tr.tool_call_id)
                        groups[group_idx].messages.append(msg)
                        groups[group_idx].end_index = i
                # Note: orphaned tool returns are handled by filter_orphaned_tool_responses

            elif user_prompts:
                # User message - standalone group
                groups.append(
                    MessageGroup(
                        messages=[msg],
                        is_tool_sequence=False,
                        start_index=i,
                        end_index=i,
                    )
                )
            # Note: System prompts are handled separately by compaction

    logger.debug(
        f"Identified {len(groups)} message groups "
        f"({sum(1 for g in groups if g.is_tool_sequence)} tool sequences)"
    )

    return groups


async def create_chunks(
    groups: list[MessageGroup],
    model_config: ModelConfig,
    retention_window: int = RETENTION_WINDOW_MESSAGES,
) -> tuple[list[Chunk], list[ModelMessage]]:
    """Create chunks from message groups, respecting token limits.

    Args:
        groups: List of message groups from identify_message_groups()
        model_config: Model configuration for token limits
        retention_window: Number of recent groups to keep outside compaction

    Returns:
        Tuple of (chunks_to_summarize, retained_recent_messages)
    """
    max_chunk_tokens = int(model_config.max_input_tokens * CHUNK_TARGET_RATIO)

    # Handle edge case: too few groups
    if len(groups) <= retention_window:
        all_messages: list[ModelMessage] = []
        for g in groups:
            all_messages.extend(g.messages)
        return [], all_messages

    # Separate retention window from groups to chunk
    groups_to_chunk = groups[:-retention_window]
    retained_groups = groups[-retention_window:]

    # Build chunks
    chunks: list[Chunk] = []
    current_groups: list[MessageGroup] = []
    current_tokens = 0

    for group in groups_to_chunk:
        group_tokens = await group.get_token_count(model_config)

        # Handle oversized single group - becomes its own chunk
        if group_tokens > max_chunk_tokens:
            # Finish current chunk if any
            if current_groups:
                chunks.append(
                    Chunk(
                        groups=current_groups,
                        chunk_index=len(chunks),
                        total_token_estimate=current_tokens,
                    )
                )
                current_groups = []
                current_tokens = 0

            # Add oversized as its own chunk
            chunks.append(
                Chunk(
                    groups=[group],
                    chunk_index=len(chunks),
                    total_token_estimate=group_tokens,
                )
            )
            logger.warning(
                f"Oversized message group ({group_tokens:,} tokens) "
                f"added as single chunk - may need special handling"
            )
            continue

        # Would adding this group exceed limit?
        if current_tokens + group_tokens > max_chunk_tokens:
            # Finish current chunk
            if current_groups:
                chunks.append(
                    Chunk(
                        groups=current_groups,
                        chunk_index=len(chunks),
                        total_token_estimate=current_tokens,
                    )
                )
            current_groups = [group]
            current_tokens = group_tokens
        else:
            current_groups.append(group)
            current_tokens += group_tokens

    # Don't forget last chunk
    if current_groups:
        chunks.append(
            Chunk(
                groups=current_groups,
                chunk_index=len(chunks),
                total_token_estimate=current_tokens,
            )
        )

    # Extract retained messages
    retained_messages: list[ModelMessage] = []
    for g in retained_groups:
        retained_messages.extend(g.messages)

    # Update chunk indices (in case any were out of order)
    for i, chunk in enumerate(chunks):
        chunk.chunk_index = i

    logger.info(
        f"Created {len(chunks)} chunks for compaction, "
        f"retaining {len(retained_messages)} recent messages"
    )

    return chunks, retained_messages


async def chunk_messages_for_compaction(
    messages: list[ModelMessage],
    model_config: ModelConfig,
) -> tuple[list[Chunk], list[ModelMessage]]:
    """Main entry point: chunk oversized conversation for summarization.

    This function identifies logical message groups (preserving tool call sequences),
    then packs them into chunks that fit within model token limits.

    Args:
        messages: Full conversation message history
        model_config: Model configuration for token limits

    Returns:
        Tuple of (chunks_to_summarize, retention_window_messages)
    """
    groups = identify_message_groups(messages)
    return await create_chunks(groups, model_config)
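
A sketch of how the entry point could be driven from compaction code. The compact_history helper is hypothetical, the ModelConfig instance is assumed to be supplied by the caller, and the summarization steps are placeholders (the Jinja templates named in the comments appear in the file list above):

from pydantic_ai.messages import ModelMessage

from shotgun.agents.config.models import ModelConfig
from shotgun.agents.conversation.history.chunking import chunk_messages_for_compaction


async def compact_history(
    messages: list[ModelMessage], model_config: ModelConfig
) -> list[ModelMessage]:
    # Split everything outside the retention window into token-bounded chunks.
    chunks, retained = await chunk_messages_for_compaction(messages, model_config)

    summaries: list[str] = []
    for chunk in chunks:
        # Placeholder: render chunk.get_all_messages() through a summarization
        # prompt (e.g. prompts/history/chunk_summarization.j2) and keep the result.
        summaries.append(f"[summary of chunk {chunk.chunk_index}]")

    # Placeholder: a real implementation would merge the summaries
    # (e.g. prompts/history/combine_summaries.j2) into a single synthetic
    # message and prepend it to the retained recent messages.
    return retained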