shotgun-sh 0.2.6.dev1__py3-none-any.whl → 0.2.17__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (127)
  1. shotgun/agents/agent_manager.py +694 -73
  2. shotgun/agents/common.py +69 -70
  3. shotgun/agents/config/constants.py +0 -6
  4. shotgun/agents/config/manager.py +70 -35
  5. shotgun/agents/config/models.py +41 -1
  6. shotgun/agents/config/provider.py +33 -5
  7. shotgun/agents/context_analyzer/__init__.py +28 -0
  8. shotgun/agents/context_analyzer/analyzer.py +471 -0
  9. shotgun/agents/context_analyzer/constants.py +9 -0
  10. shotgun/agents/context_analyzer/formatter.py +115 -0
  11. shotgun/agents/context_analyzer/models.py +212 -0
  12. shotgun/agents/conversation_history.py +125 -2
  13. shotgun/agents/conversation_manager.py +57 -19
  14. shotgun/agents/export.py +6 -7
  15. shotgun/agents/history/compaction.py +9 -4
  16. shotgun/agents/history/context_extraction.py +93 -6
  17. shotgun/agents/history/history_processors.py +113 -5
  18. shotgun/agents/history/token_counting/anthropic.py +39 -3
  19. shotgun/agents/history/token_counting/base.py +14 -3
  20. shotgun/agents/history/token_counting/openai.py +11 -1
  21. shotgun/agents/history/token_counting/sentencepiece_counter.py +8 -0
  22. shotgun/agents/history/token_counting/tokenizer_cache.py +3 -1
  23. shotgun/agents/history/token_counting/utils.py +0 -3
  24. shotgun/agents/models.py +50 -2
  25. shotgun/agents/plan.py +6 -7
  26. shotgun/agents/research.py +7 -8
  27. shotgun/agents/specify.py +6 -7
  28. shotgun/agents/tasks.py +6 -7
  29. shotgun/agents/tools/__init__.py +0 -2
  30. shotgun/agents/tools/codebase/codebase_shell.py +6 -0
  31. shotgun/agents/tools/codebase/directory_lister.py +6 -0
  32. shotgun/agents/tools/codebase/file_read.py +11 -2
  33. shotgun/agents/tools/codebase/query_graph.py +6 -0
  34. shotgun/agents/tools/codebase/retrieve_code.py +6 -0
  35. shotgun/agents/tools/file_management.py +82 -16
  36. shotgun/agents/tools/registry.py +217 -0
  37. shotgun/agents/tools/web_search/__init__.py +8 -8
  38. shotgun/agents/tools/web_search/anthropic.py +8 -2
  39. shotgun/agents/tools/web_search/gemini.py +7 -1
  40. shotgun/agents/tools/web_search/openai.py +7 -1
  41. shotgun/agents/tools/web_search/utils.py +2 -2
  42. shotgun/agents/usage_manager.py +16 -11
  43. shotgun/api_endpoints.py +7 -3
  44. shotgun/build_constants.py +3 -3
  45. shotgun/cli/clear.py +53 -0
  46. shotgun/cli/compact.py +186 -0
  47. shotgun/cli/config.py +8 -5
  48. shotgun/cli/context.py +111 -0
  49. shotgun/cli/export.py +1 -1
  50. shotgun/cli/feedback.py +4 -2
  51. shotgun/cli/models.py +1 -0
  52. shotgun/cli/plan.py +1 -1
  53. shotgun/cli/research.py +1 -1
  54. shotgun/cli/specify.py +1 -1
  55. shotgun/cli/tasks.py +1 -1
  56. shotgun/cli/update.py +16 -2
  57. shotgun/codebase/core/change_detector.py +5 -3
  58. shotgun/codebase/core/code_retrieval.py +4 -2
  59. shotgun/codebase/core/ingestor.py +10 -8
  60. shotgun/codebase/core/manager.py +13 -4
  61. shotgun/codebase/core/nl_query.py +1 -1
  62. shotgun/exceptions.py +32 -0
  63. shotgun/logging_config.py +18 -27
  64. shotgun/main.py +73 -11
  65. shotgun/posthog_telemetry.py +37 -28
  66. shotgun/prompts/agents/export.j2 +18 -1
  67. shotgun/prompts/agents/partials/common_agent_system_prompt.j2 +5 -1
  68. shotgun/prompts/agents/partials/interactive_mode.j2 +24 -7
  69. shotgun/prompts/agents/plan.j2 +1 -1
  70. shotgun/prompts/agents/research.j2 +1 -1
  71. shotgun/prompts/agents/specify.j2 +270 -3
  72. shotgun/prompts/agents/tasks.j2 +1 -1
  73. shotgun/sentry_telemetry.py +163 -16
  74. shotgun/settings.py +238 -0
  75. shotgun/telemetry.py +18 -33
  76. shotgun/tui/app.py +243 -43
  77. shotgun/tui/commands/__init__.py +1 -1
  78. shotgun/tui/components/context_indicator.py +179 -0
  79. shotgun/tui/components/mode_indicator.py +70 -0
  80. shotgun/tui/components/status_bar.py +48 -0
  81. shotgun/tui/containers.py +91 -0
  82. shotgun/tui/dependencies.py +39 -0
  83. shotgun/tui/protocols.py +45 -0
  84. shotgun/tui/screens/chat/__init__.py +5 -0
  85. shotgun/tui/screens/chat/chat.tcss +54 -0
  86. shotgun/tui/screens/chat/chat_screen.py +1254 -0
  87. shotgun/tui/screens/chat/codebase_index_prompt_screen.py +64 -0
  88. shotgun/tui/screens/chat/codebase_index_selection.py +12 -0
  89. shotgun/tui/screens/chat/help_text.py +40 -0
  90. shotgun/tui/screens/chat/prompt_history.py +48 -0
  91. shotgun/tui/screens/chat.tcss +11 -0
  92. shotgun/tui/screens/chat_screen/command_providers.py +78 -2
  93. shotgun/tui/screens/chat_screen/history/__init__.py +22 -0
  94. shotgun/tui/screens/chat_screen/history/agent_response.py +66 -0
  95. shotgun/tui/screens/chat_screen/history/chat_history.py +115 -0
  96. shotgun/tui/screens/chat_screen/history/formatters.py +115 -0
  97. shotgun/tui/screens/chat_screen/history/partial_response.py +43 -0
  98. shotgun/tui/screens/chat_screen/history/user_question.py +42 -0
  99. shotgun/tui/screens/confirmation_dialog.py +151 -0
  100. shotgun/tui/screens/feedback.py +4 -4
  101. shotgun/tui/screens/github_issue.py +102 -0
  102. shotgun/tui/screens/model_picker.py +49 -24
  103. shotgun/tui/screens/onboarding.py +431 -0
  104. shotgun/tui/screens/pipx_migration.py +153 -0
  105. shotgun/tui/screens/provider_config.py +50 -27
  106. shotgun/tui/screens/shotgun_auth.py +2 -2
  107. shotgun/tui/screens/welcome.py +23 -12
  108. shotgun/tui/services/__init__.py +5 -0
  109. shotgun/tui/services/conversation_service.py +184 -0
  110. shotgun/tui/state/__init__.py +7 -0
  111. shotgun/tui/state/processing_state.py +185 -0
  112. shotgun/tui/utils/mode_progress.py +14 -7
  113. shotgun/tui/widgets/__init__.py +5 -0
  114. shotgun/tui/widgets/widget_coordinator.py +263 -0
  115. shotgun/utils/file_system_utils.py +22 -2
  116. shotgun/utils/marketing.py +110 -0
  117. shotgun/utils/update_checker.py +69 -14
  118. shotgun_sh-0.2.17.dist-info/METADATA +465 -0
  119. shotgun_sh-0.2.17.dist-info/RECORD +194 -0
  120. {shotgun_sh-0.2.6.dev1.dist-info → shotgun_sh-0.2.17.dist-info}/entry_points.txt +1 -0
  121. {shotgun_sh-0.2.6.dev1.dist-info → shotgun_sh-0.2.17.dist-info}/licenses/LICENSE +1 -1
  122. shotgun/agents/tools/user_interaction.py +0 -37
  123. shotgun/tui/screens/chat.py +0 -804
  124. shotgun/tui/screens/chat_screen/history.py +0 -401
  125. shotgun_sh-0.2.6.dev1.dist-info/METADATA +0 -467
  126. shotgun_sh-0.2.6.dev1.dist-info/RECORD +0 -156
  127. {shotgun_sh-0.2.6.dev1.dist-info → shotgun_sh-0.2.17.dist-info}/WHEEL +0 -0
@@ -0,0 +1,9 @@
1
+ """Tool category registry for context analysis.
2
+
3
+ This module re-exports the tool registry functionality for backward compatibility.
4
+ The actual implementation is in shotgun.agents.tools.registry.
5
+ """
6
+
7
+ from shotgun.agents.tools.registry import ToolCategory, get_tool_category
8
+
9
+ __all__ = ["ToolCategory", "get_tool_category"]
@@ -0,0 +1,115 @@
1
+ """Format context analysis for various output types."""
2
+
3
+ from typing import Any
4
+
5
+ from .models import ContextAnalysis
6
+
7
+
8
class ContextFormatter:
    """Formats context analysis for various output types."""

    @staticmethod
    def format_markdown(analysis: ContextAnalysis) -> str:
        """Format the analysis as markdown for display.

        Args:
            analysis: Context analysis to format

        Returns:
            Markdown-formatted string
        """
        lines = ["# Conversation Context Analysis", ""]

        # Top-level summary with model and usage info.
        # Guard against max_usable_tokens == 0 to avoid ZeroDivisionError.
        usage_percent = (
            (analysis.agent_context_tokens / analysis.max_usable_tokens * 100)
            if analysis.max_usable_tokens > 0
            else 0
        )
        free_percent = (
            (analysis.free_space_tokens / analysis.max_usable_tokens * 100)
            if analysis.max_usable_tokens > 0
            else 0
        )

        lines.extend(
            [
                f"Model: {analysis.model_name}",
                "",
                f"Total Context: {analysis.agent_context_tokens:,} / {analysis.max_usable_tokens:,} tokens ({usage_percent:.1f}%)",
                "",
                f"Free Space: {analysis.free_space_tokens:,} tokens ({free_percent:.1f}%)",
                "",
                "Autocompact Buffer: 500 tokens",
                "",
            ]
        )

        # Create 25-character visual bar showing proportional usage.
        # Each character represents 4% of total context. Clamp to 0-25:
        # usage can exceed 100% when the conversation is over capacity
        # (free_space_tokens is negative), and an unclamped value would
        # yield a bar longer than 25 chars plus a negative repeat count.
        filled_chars = min(25, max(0, int(usage_percent / 4)))
        empty_chars = 25 - filled_chars
        visual_bar = "●" * filled_chars + "○" * empty_chars

        lines.extend(
            [
                "## Context Composition",
                visual_bar,
                "",
            ]
        )

        # Add agent context categories only (hints are not part of agent context)
        agent_categories = [
            ("🧑 User Messages", analysis.user_messages),
            ("🤖 Agent Responses", analysis.agent_responses),
            ("📋 System Prompts", analysis.system_prompts),
            ("📊 System Status", analysis.system_status),
            ("🔍 Codebase Understanding", analysis.codebase_understanding),
            ("📦 Artifact Management", analysis.artifact_management),
            ("🌐 Web Research", analysis.web_research),
        ]

        # Only add unknown if it has content
        if analysis.unknown.count > 0:
            agent_categories.append(("⚠️ Unknown Tools", analysis.unknown))

        for label, stats in agent_categories:
            if stats.count > 0:
                percentage = analysis.get_percentage(stats)
                # Align labels to 30 characters for clean visual layout
                lines.append(
                    f"{label:<30} {percentage:>5.1f}% ({stats.count} messages, ~{stats.tokens:,} tokens)"
                )
                # Add blank line to prevent Textual's Markdown widget from reflowing
                lines.append("")

        return "\n".join(lines)

    @staticmethod
    def format_json(analysis: ContextAnalysis) -> dict[str, Any]:
        """Format the analysis as a JSON-serializable dictionary.

        Args:
            analysis: Context analysis to format

        Returns:
            Dictionary with context analysis data
        """
        # Use Pydantic's model_dump() to serialize the model
        data = analysis.model_dump()

        # Add computed summary field. Hint messages are UI-only, so they
        # are excluded from the reported message total.
        data["summary"] = {
            "total_messages": analysis.total_messages - analysis.hint_messages.count,
            "agent_context_tokens": analysis.agent_context_tokens,
            "context_window": analysis.context_window,
            "usage_percentage": round(
                (analysis.agent_context_tokens / analysis.context_window * 100)
                if analysis.context_window > 0
                else 0,
                1,
            ),
        }

        return data
@@ -0,0 +1,212 @@
1
+ """Pydantic models for context analysis."""
2
+
3
+ from typing import Any
4
+
5
+ from pydantic import BaseModel, Field
6
+
7
+
8
class TokenAllocation(BaseModel):
    """Token counts allocated from API usage data by message/tool type.

    Used internally by ContextAnalyzer to track token distribution across
    different message types and tool categories.

    All counts default to 0 and are validated as non-negative (ge=0).
    """

    user: int = Field(ge=0, default=0, description="Tokens from user prompts")
    agent_responses: int = Field(
        ge=0, default=0, description="Tokens from agent text responses"
    )
    system_prompts: int = Field(
        ge=0, default=0, description="Tokens from system prompts"
    )
    system_status: int = Field(
        ge=0, default=0, description="Tokens from system status messages"
    )
    codebase_understanding: int = Field(
        ge=0, default=0, description="Tokens from codebase understanding tools"
    )
    artifact_management: int = Field(
        ge=0, default=0, description="Tokens from artifact management tools"
    )
    web_research: int = Field(
        ge=0, default=0, description="Tokens from web research tools"
    )
    # Fallback bucket for tool calls that match no known category.
    unknown: int = Field(ge=0, default=0, description="Tokens from uncategorized tools")
35
+
36
+
37
class MessageTypeStats(BaseModel):
    """Statistics for a specific message type."""

    count: int = Field(ge=0, description="Number of messages of this type")
    tokens: int = Field(ge=0, description="Total tokens consumed by this type")

    @property
    def avg_tokens(self) -> float:
        """Average tokens per message (0.0 when there are no messages)."""
        if self.count > 0:
            return self.tokens / self.count
        return 0.0
47
+
48
+
49
class ContextAnalysis(BaseModel):
    """Complete analysis of conversation context composition."""

    # Per-category message/token statistics.
    user_messages: MessageTypeStats
    agent_responses: MessageTypeStats
    system_prompts: MessageTypeStats
    system_status: MessageTypeStats
    codebase_understanding: MessageTypeStats
    artifact_management: MessageTypeStats
    web_research: MessageTypeStats
    unknown: MessageTypeStats
    # UI-only hint messages; excluded from agent-context accounting
    # (see the total subtractions in telemetry/JSON summaries).
    hint_messages: MessageTypeStats
    total_tokens: int = Field(ge=0, description="Total tokens including hints")
    total_messages: int = Field(ge=0, description="Total message count including hints")
    context_window: int = Field(ge=0, description="Model's maximum input tokens")
    agent_context_tokens: int = Field(
        ge=0,
        description="Tokens that actually consume agent context (excluding UI-only)",
    )
    model_name: str = Field(description="Name of the model being used")
    max_usable_tokens: int = Field(
        ge=0, description="80% of max_input_tokens (usable limit)"
    )
    # Deliberately not constrained with ge=0: a negative value signals the
    # conversation is over the usable capacity.
    free_space_tokens: int = Field(
        description="Remaining tokens available (negative if over capacity)"
    )

    def get_percentage(self, stats: MessageTypeStats) -> float:
        """Calculate percentage of agent context tokens for a message type.

        Args:
            stats: Message type statistics to calculate percentage for

        Returns:
            Percentage of total agent context tokens (0-100)
        """
        # Guard against division by zero for empty conversations.
        return (
            (stats.tokens / self.agent_context_tokens * 100)
            if self.agent_context_tokens > 0
            else 0.0
        )
90
+
91
+
92
class ContextCompositionTelemetry(BaseModel):
    """Telemetry data for context composition tracking to PostHog."""

    # Context usage
    total_messages: int = Field(ge=0)
    agent_context_tokens: int = Field(ge=0)
    context_window: int = Field(ge=0)
    max_usable_tokens: int = Field(ge=0)
    free_space_tokens: int = Field(ge=0)
    usage_percentage: float = Field(ge=0, le=100)

    # Message type counts
    user_messages_count: int = Field(ge=0)
    agent_responses_count: int = Field(ge=0)
    system_prompts_count: int = Field(ge=0)
    system_status_count: int = Field(ge=0)
    codebase_understanding_count: int = Field(ge=0)
    artifact_management_count: int = Field(ge=0)
    web_research_count: int = Field(ge=0)
    unknown_tools_count: int = Field(ge=0)

    # Token distribution percentages
    user_messages_pct: float = Field(ge=0, le=100)
    agent_responses_pct: float = Field(ge=0, le=100)
    system_prompts_pct: float = Field(ge=0, le=100)
    system_status_pct: float = Field(ge=0, le=100)
    codebase_understanding_pct: float = Field(ge=0, le=100)
    artifact_management_pct: float = Field(ge=0, le=100)
    web_research_pct: float = Field(ge=0, le=100)
    unknown_tools_pct: float = Field(ge=0, le=100)

    # Compaction info
    compaction_occurred: bool
    messages_before_compaction: int | None = None
    messages_after_compaction: int | None = None
    compaction_reduction_pct: float | None = None

    @classmethod
    def from_analysis(
        cls,
        analysis: "ContextAnalysis",
        compaction_occurred: bool = False,
        messages_before_compaction: int | None = None,
    ) -> "ContextCompositionTelemetry":
        """Create telemetry from context analysis.

        Args:
            analysis: The context analysis to convert
            compaction_occurred: Whether message compaction occurred
            messages_before_compaction: Number of messages before compaction

        Returns:
            ContextCompositionTelemetry instance
        """

        def pct(stats: "MessageTypeStats") -> float:
            # Share of agent-context tokens for one category, rounded to 0.1%.
            return round(analysis.get_percentage(stats), 1)

        # Hint messages are UI-only, so exclude them from the message total.
        visible_messages = analysis.total_messages - analysis.hint_messages.count

        if analysis.max_usable_tokens > 0:
            usage_percentage = round(
                analysis.agent_context_tokens / analysis.max_usable_tokens * 100, 1
            )
        else:
            usage_percentage = 0

        # Compaction metrics are only meaningful when compaction happened
        # and we know the pre-compaction message count.
        messages_after: int | None = None
        reduction_pct: float | None = None
        if compaction_occurred and messages_before_compaction is not None:
            messages_after = visible_messages
            if messages_before_compaction > 0:
                reduction_pct = round(
                    (1 - visible_messages / messages_before_compaction) * 100, 1
                )

        return cls(
            # Context usage
            total_messages=visible_messages,
            agent_context_tokens=analysis.agent_context_tokens,
            context_window=analysis.context_window,
            max_usable_tokens=analysis.max_usable_tokens,
            free_space_tokens=analysis.free_space_tokens,
            usage_percentage=usage_percentage,
            # Message type counts
            user_messages_count=analysis.user_messages.count,
            agent_responses_count=analysis.agent_responses.count,
            system_prompts_count=analysis.system_prompts.count,
            system_status_count=analysis.system_status.count,
            codebase_understanding_count=analysis.codebase_understanding.count,
            artifact_management_count=analysis.artifact_management.count,
            web_research_count=analysis.web_research.count,
            unknown_tools_count=analysis.unknown.count,
            # Token distribution percentages
            user_messages_pct=pct(analysis.user_messages),
            agent_responses_pct=pct(analysis.agent_responses),
            system_prompts_pct=pct(analysis.system_prompts),
            system_status_pct=pct(analysis.system_status),
            codebase_understanding_pct=pct(analysis.codebase_understanding),
            artifact_management_pct=pct(analysis.artifact_management),
            web_research_pct=pct(analysis.web_research),
            unknown_tools_pct=pct(analysis.unknown),
            # Compaction info
            compaction_occurred=compaction_occurred,
            messages_before_compaction=messages_before_compaction,
            messages_after_compaction=messages_after,
            compaction_reduction_pct=reduction_pct,
        )
+ )
204
+
205
+
206
class ContextAnalysisOutput(BaseModel):
    """Output format for context analysis with multiple representations.

    Bundles the human-readable markdown rendering together with the
    machine-readable JSON payload produced for the same analysis.
    """

    markdown: str = Field(description="Markdown-formatted analysis for display")
    json_data: dict[str, Any] = Field(
        description="JSON representation of analysis data"
    )
@@ -1,5 +1,7 @@
1
1
  """Models and utilities for persisting TUI conversation history."""
2
2
 
3
+ import json
4
+ import logging
3
5
  from datetime import datetime
4
6
  from typing import Any, cast
5
7
 
@@ -7,14 +9,108 @@ from pydantic import BaseModel, ConfigDict, Field
7
9
  from pydantic_ai.messages import (
8
10
  ModelMessage,
9
11
  ModelMessagesTypeAdapter,
12
+ ModelResponse,
13
+ ToolCallPart,
10
14
  )
11
15
  from pydantic_core import to_jsonable_python
12
16
 
13
17
  from shotgun.tui.screens.chat_screen.hint_message import HintMessage
14
18
 
19
+ __all__ = ["HintMessage", "ConversationHistory"]
20
+
21
+ logger = logging.getLogger(__name__)
22
+
15
23
  SerializedMessage = dict[str, Any]
16
24
 
17
25
 
26
def is_tool_call_complete(tool_call: ToolCallPart) -> bool:
    """Check if a tool call has valid, complete JSON arguments.

    Args:
        tool_call: The tool call part to validate

    Returns:
        True if the tool call args are valid JSON, False otherwise
    """
    args = tool_call.args

    # Absent args and already-parsed dicts are both considered complete.
    if args is None or isinstance(args, dict):
        return True

    # Anything that is neither a dict nor a string cannot be valid JSON args.
    if not isinstance(args, str):
        return False

    try:
        json.loads(args)
    except (json.JSONDecodeError, ValueError) as exc:
        # Log incomplete tool call detection, truncating long payloads.
        preview = args if len(args) <= 100 else args[:100] + "..."
        logger.info(
            "Detected incomplete tool call in validation",
            extra={
                "tool_name": tool_call.tool_name,
                "tool_call_id": tool_call.tool_call_id,
                "args_preview": preview,
                "error": str(exc),
            },
        )
        return False
    return True
65
+
66
+
67
def filter_incomplete_messages(messages: list[ModelMessage]) -> list[ModelMessage]:
    """Filter out messages with incomplete tool calls.

    Args:
        messages: List of messages to filter

    Returns:
        List of messages with only complete tool calls
    """
    kept: list[ModelMessage] = []
    dropped_tool_names: list[str] = []
    dropped = 0

    for message in messages:
        # Only ModelResponse messages can carry tool calls; pass everything
        # else through untouched.
        if not isinstance(message, ModelResponse):
            kept.append(message)
            continue

        # Find the first incomplete tool call part, if any.
        incomplete = next(
            (
                part
                for part in message.parts
                if isinstance(part, ToolCallPart) and not is_tool_call_complete(part)
            ),
            None,
        )

        if incomplete is None:
            kept.append(message)
        else:
            dropped += 1
            dropped_tool_names.append(incomplete.tool_name)

    # Log if any messages were filtered.
    if dropped > 0:
        logger.info(
            "Filtered incomplete messages before saving",
            extra={
                "filtered_count": dropped,
                "total_messages": len(messages),
                "filtered_tool_names": dropped_tool_names,
            },
        )

    return kept
112
+
113
+
18
114
  class ConversationState(BaseModel):
19
115
  """Represents the complete state of a conversation in memory."""
20
116
 
@@ -46,14 +142,41 @@ class ConversationHistory(BaseModel):
46
142
  Args:
47
143
  messages: List of ModelMessage objects to serialize and store
48
144
  """
145
+ # Filter out messages with incomplete tool calls to prevent corruption
146
+ filtered_messages = filter_incomplete_messages(messages)
147
+
49
148
  # Serialize ModelMessage list to JSON-serializable format
50
149
  self.agent_history = to_jsonable_python(
51
- messages, fallback=lambda x: str(x), exclude_none=True
150
+ filtered_messages, fallback=lambda x: str(x), exclude_none=True
52
151
  )
53
152
 
54
153
  def set_ui_messages(self, messages: list[ModelMessage | HintMessage]) -> None:
55
154
  """Set ui_history from a list of UI messages."""
56
155
 
156
+ # Filter out ModelMessages with incomplete tool calls (keep all HintMessages)
157
+ # We need to maintain message order, so we'll check each message individually
158
+ filtered_messages: list[ModelMessage | HintMessage] = []
159
+
160
+ for msg in messages:
161
+ if isinstance(msg, HintMessage):
162
+ # Always keep hint messages
163
+ filtered_messages.append(msg)
164
+ elif isinstance(msg, ModelResponse):
165
+ # Check if this ModelResponse has incomplete tool calls
166
+ has_incomplete = False
167
+ for part in msg.parts:
168
+ if isinstance(part, ToolCallPart) and not is_tool_call_complete(
169
+ part
170
+ ):
171
+ has_incomplete = True
172
+ break
173
+
174
+ if not has_incomplete:
175
+ filtered_messages.append(msg)
176
+ else:
177
+ # Keep all other ModelMessage types (ModelRequest, etc.)
178
+ filtered_messages.append(msg)
179
+
57
180
  def _serialize_message(
58
181
  message: ModelMessage | HintMessage,
59
182
  ) -> Any:
@@ -68,7 +191,7 @@ class ConversationHistory(BaseModel):
68
191
  payload.setdefault("message_type", "model")
69
192
  return payload
70
193
 
71
- self.ui_history = [_serialize_message(msg) for msg in messages]
194
+ self.ui_history = [_serialize_message(msg) for msg in filtered_messages]
72
195
 
73
196
  def get_agent_messages(self) -> list[ModelMessage]:
74
197
  """Get agent_history as a list of ModelMessage objects.
@@ -1,10 +1,15 @@
1
1
  """Manager for handling conversation persistence operations."""
2
2
 
3
+ import asyncio
3
4
  import json
4
5
  from pathlib import Path
5
6
 
7
+ import aiofiles
8
+ import aiofiles.os
9
+
6
10
  from shotgun.logging_config import get_logger
7
11
  from shotgun.utils import get_shotgun_home
12
+ from shotgun.utils.file_system_utils import async_copy_file
8
13
 
9
14
  from .conversation_history import ConversationHistory
10
15
 
@@ -26,14 +31,14 @@ class ConversationManager:
26
31
  else:
27
32
  self.conversation_path = conversation_path
28
33
 
29
- def save(self, conversation: ConversationHistory) -> None:
34
+ async def save(self, conversation: ConversationHistory) -> None:
30
35
  """Save conversation history to file.
31
36
 
32
37
  Args:
33
38
  conversation: ConversationHistory to save
34
39
  """
35
40
  # Ensure directory exists
36
- self.conversation_path.parent.mkdir(parents=True, exist_ok=True)
41
+ await aiofiles.os.makedirs(self.conversation_path.parent, exist_ok=True)
37
42
 
38
43
  try:
39
44
  # Update timestamp
@@ -41,11 +46,17 @@ class ConversationManager:
41
46
 
42
47
  conversation.updated_at = datetime.now()
43
48
 
44
- # Serialize to JSON using Pydantic's model_dump
45
- data = conversation.model_dump(mode="json")
49
+ # Serialize to JSON in background thread to avoid blocking event loop
50
+ # This is crucial for large conversations (5k+ tokens)
51
+ data = await asyncio.to_thread(conversation.model_dump, mode="json")
52
+ json_content = await asyncio.to_thread(
53
+ json.dumps, data, indent=2, ensure_ascii=False
54
+ )
46
55
 
47
- with open(self.conversation_path, "w", encoding="utf-8") as f:
48
- json.dump(data, f, indent=2, ensure_ascii=False)
56
+ async with aiofiles.open(
57
+ self.conversation_path, "w", encoding="utf-8"
58
+ ) as f:
59
+ await f.write(json_content)
49
60
 
50
61
  logger.debug("Conversation saved to %s", self.conversation_path)
51
62
 
@@ -55,21 +66,26 @@ class ConversationManager:
55
66
  )
56
67
  # Don't raise - we don't want to interrupt the user's session
57
68
 
58
- def load(self) -> ConversationHistory | None:
69
+ async def load(self) -> ConversationHistory | None:
59
70
  """Load conversation history from file.
60
71
 
61
72
  Returns:
62
73
  ConversationHistory if file exists and is valid, None otherwise
63
74
  """
64
- if not self.conversation_path.exists():
75
+ if not await aiofiles.os.path.exists(self.conversation_path):
65
76
  logger.debug("No conversation history found at %s", self.conversation_path)
66
77
  return None
67
78
 
68
79
  try:
69
- with open(self.conversation_path, encoding="utf-8") as f:
70
- data = json.load(f)
71
-
72
- conversation = ConversationHistory.model_validate(data)
80
+ async with aiofiles.open(self.conversation_path, encoding="utf-8") as f:
81
+ content = await f.read()
82
+ # Deserialize JSON in background thread to avoid blocking
83
+ data = await asyncio.to_thread(json.loads, content)
84
+
85
+ # Validate model in background thread for large conversations
86
+ conversation = await asyncio.to_thread(
87
+ ConversationHistory.model_validate, data
88
+ )
73
89
  logger.debug(
74
90
  "Conversation loaded from %s with %d agent messages",
75
91
  self.conversation_path,
@@ -77,17 +93,39 @@ class ConversationManager:
77
93
  )
78
94
  return conversation
79
95
 
80
- except Exception as e:
96
+ except (json.JSONDecodeError, ValueError) as e:
97
+ # Handle corrupted JSON or validation errors
98
+ logger.error(
99
+ "Corrupted conversation file at %s: %s. Creating backup and starting fresh.",
100
+ self.conversation_path,
101
+ e,
102
+ )
103
+
104
+ # Create a backup of the corrupted file for debugging
105
+ backup_path = self.conversation_path.with_suffix(".json.backup")
106
+ try:
107
+ await async_copy_file(self.conversation_path, backup_path)
108
+ logger.info("Backed up corrupted conversation to %s", backup_path)
109
+ except Exception as backup_error: # pragma: no cover
110
+ logger.warning("Failed to backup corrupted file: %s", backup_error)
111
+
112
+ return None
113
+
114
+ except Exception as e: # pragma: no cover
115
+ # Catch-all for unexpected errors
81
116
  logger.error(
82
- "Failed to load conversation from %s: %s", self.conversation_path, e
117
+ "Unexpected error loading conversation from %s: %s",
118
+ self.conversation_path,
119
+ e,
83
120
  )
84
121
  return None
85
122
 
86
- def clear(self) -> None:
123
+ async def clear(self) -> None:
87
124
  """Delete the conversation history file."""
88
- if self.conversation_path.exists():
125
+ if await aiofiles.os.path.exists(self.conversation_path):
89
126
  try:
90
- self.conversation_path.unlink()
127
+ # Use asyncio.to_thread for unlink operation
128
+ await asyncio.to_thread(self.conversation_path.unlink)
91
129
  logger.debug(
92
130
  "Conversation history cleared at %s", self.conversation_path
93
131
  )
@@ -96,10 +134,10 @@ class ConversationManager:
96
134
  "Failed to clear conversation at %s: %s", self.conversation_path, e
97
135
  )
98
136
 
99
- def exists(self) -> bool:
137
+ async def exists(self) -> bool:
100
138
  """Check if a conversation history file exists.
101
139
 
102
140
  Returns:
103
141
  True if conversation file exists, False otherwise
104
142
  """
105
- return self.conversation_path.exists()
143
+ return await aiofiles.os.path.exists(str(self.conversation_path))