shotgun-sh 0.2.11.dev2__py3-none-any.whl → 0.2.11.dev7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of shotgun-sh might be problematic. Click here for more details.
- shotgun/agents/agent_manager.py +194 -28
- shotgun/agents/common.py +14 -8
- shotgun/agents/config/manager.py +64 -33
- shotgun/agents/config/models.py +25 -1
- shotgun/agents/config/provider.py +2 -2
- shotgun/agents/context_analyzer/analyzer.py +2 -24
- shotgun/agents/conversation_manager.py +35 -19
- shotgun/agents/export.py +2 -2
- shotgun/agents/history/history_processors.py +99 -3
- shotgun/agents/history/token_counting/anthropic.py +17 -1
- shotgun/agents/history/token_counting/base.py +14 -3
- shotgun/agents/history/token_counting/openai.py +11 -1
- shotgun/agents/history/token_counting/sentencepiece_counter.py +8 -0
- shotgun/agents/history/token_counting/tokenizer_cache.py +3 -1
- shotgun/agents/history/token_counting/utils.py +0 -3
- shotgun/agents/plan.py +2 -2
- shotgun/agents/research.py +3 -3
- shotgun/agents/specify.py +2 -2
- shotgun/agents/tasks.py +2 -2
- shotgun/agents/tools/codebase/file_read.py +5 -2
- shotgun/agents/tools/file_management.py +11 -7
- shotgun/agents/tools/web_search/__init__.py +8 -8
- shotgun/agents/tools/web_search/anthropic.py +2 -2
- shotgun/agents/tools/web_search/gemini.py +1 -1
- shotgun/agents/tools/web_search/openai.py +1 -1
- shotgun/agents/tools/web_search/utils.py +2 -2
- shotgun/agents/usage_manager.py +16 -11
- shotgun/cli/clear.py +2 -1
- shotgun/cli/compact.py +3 -3
- shotgun/cli/config.py +8 -5
- shotgun/cli/context.py +2 -2
- shotgun/cli/export.py +1 -1
- shotgun/cli/feedback.py +4 -2
- shotgun/cli/plan.py +1 -1
- shotgun/cli/research.py +1 -1
- shotgun/cli/specify.py +1 -1
- shotgun/cli/tasks.py +1 -1
- shotgun/codebase/core/change_detector.py +5 -3
- shotgun/codebase/core/code_retrieval.py +4 -2
- shotgun/codebase/core/ingestor.py +10 -8
- shotgun/codebase/core/manager.py +3 -3
- shotgun/codebase/core/nl_query.py +1 -1
- shotgun/exceptions.py +32 -0
- shotgun/logging_config.py +10 -17
- shotgun/main.py +3 -1
- shotgun/posthog_telemetry.py +14 -4
- shotgun/sentry_telemetry.py +22 -2
- shotgun/telemetry.py +3 -1
- shotgun/tui/app.py +71 -65
- shotgun/tui/components/context_indicator.py +43 -0
- shotgun/tui/containers.py +15 -17
- shotgun/tui/dependencies.py +2 -2
- shotgun/tui/screens/chat/chat_screen.py +164 -40
- shotgun/tui/screens/chat/help_text.py +16 -15
- shotgun/tui/screens/chat_screen/command_providers.py +10 -0
- shotgun/tui/screens/feedback.py +4 -4
- shotgun/tui/screens/github_issue.py +102 -0
- shotgun/tui/screens/model_picker.py +21 -20
- shotgun/tui/screens/onboarding.py +431 -0
- shotgun/tui/screens/provider_config.py +50 -27
- shotgun/tui/screens/shotgun_auth.py +2 -2
- shotgun/tui/screens/welcome.py +14 -11
- shotgun/tui/services/conversation_service.py +16 -14
- shotgun/tui/utils/mode_progress.py +14 -7
- shotgun/tui/widgets/widget_coordinator.py +15 -0
- shotgun/utils/file_system_utils.py +19 -0
- shotgun/utils/marketing.py +110 -0
- {shotgun_sh-0.2.11.dev2.dist-info → shotgun_sh-0.2.11.dev7.dist-info}/METADATA +2 -1
- {shotgun_sh-0.2.11.dev2.dist-info → shotgun_sh-0.2.11.dev7.dist-info}/RECORD +72 -68
- {shotgun_sh-0.2.11.dev2.dist-info → shotgun_sh-0.2.11.dev7.dist-info}/WHEEL +0 -0
- {shotgun_sh-0.2.11.dev2.dist-info → shotgun_sh-0.2.11.dev7.dist-info}/entry_points.txt +0 -0
- {shotgun_sh-0.2.11.dev2.dist-info → shotgun_sh-0.2.11.dev7.dist-info}/licenses/LICENSE +0 -0
shotgun/agents/agent_manager.py
CHANGED
|
@@ -58,7 +58,12 @@ from shotgun.agents.context_analyzer import (
|
|
|
58
58
|
ContextCompositionTelemetry,
|
|
59
59
|
ContextFormatter,
|
|
60
60
|
)
|
|
61
|
-
from shotgun.agents.models import
|
|
61
|
+
from shotgun.agents.models import (
|
|
62
|
+
AgentResponse,
|
|
63
|
+
AgentType,
|
|
64
|
+
FileOperation,
|
|
65
|
+
FileOperationTracker,
|
|
66
|
+
)
|
|
62
67
|
from shotgun.posthog_telemetry import track_event
|
|
63
68
|
from shotgun.tui.screens.chat_screen.hint_message import HintMessage
|
|
64
69
|
from shotgun.utils.source_detection import detect_source
|
|
@@ -169,6 +174,14 @@ class CompactionCompletedMessage(Message):
|
|
|
169
174
|
"""Event posted when conversation compaction completes."""
|
|
170
175
|
|
|
171
176
|
|
|
177
|
+
class AgentStreamingStarted(Message):
|
|
178
|
+
"""Event posted when agent starts streaming responses."""
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
class AgentStreamingCompleted(Message):
|
|
182
|
+
"""Event posted when agent finishes streaming responses."""
|
|
183
|
+
|
|
184
|
+
|
|
172
185
|
@dataclass(frozen=True)
|
|
173
186
|
class ModelConfigUpdated:
|
|
174
187
|
"""Data returned when AI model configuration changes.
|
|
@@ -222,7 +235,7 @@ class AgentManager(Widget):
|
|
|
222
235
|
self.deps = deps
|
|
223
236
|
|
|
224
237
|
# Create AgentRuntimeOptions from deps for agent creation
|
|
225
|
-
|
|
238
|
+
self._agent_runtime_options = AgentRuntimeOptions(
|
|
226
239
|
interactive_mode=self.deps.interactive_mode,
|
|
227
240
|
working_directory=self.deps.working_directory,
|
|
228
241
|
is_tui_context=self.deps.is_tui_context,
|
|
@@ -231,22 +244,18 @@ class AgentManager(Widget):
|
|
|
231
244
|
tasks=self.deps.tasks,
|
|
232
245
|
)
|
|
233
246
|
|
|
234
|
-
#
|
|
235
|
-
self.
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
self.
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
self.
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
self.
|
|
245
|
-
|
|
246
|
-
)
|
|
247
|
-
self.export_agent, self.export_deps = create_export_agent(
|
|
248
|
-
agent_runtime_options=agent_runtime_options
|
|
249
|
-
)
|
|
247
|
+
# Lazy initialization - agents created on first access
|
|
248
|
+
self._research_agent: Agent[AgentDeps, AgentResponse] | None = None
|
|
249
|
+
self._research_deps: AgentDeps | None = None
|
|
250
|
+
self._plan_agent: Agent[AgentDeps, AgentResponse] | None = None
|
|
251
|
+
self._plan_deps: AgentDeps | None = None
|
|
252
|
+
self._tasks_agent: Agent[AgentDeps, AgentResponse] | None = None
|
|
253
|
+
self._tasks_deps: AgentDeps | None = None
|
|
254
|
+
self._specify_agent: Agent[AgentDeps, AgentResponse] | None = None
|
|
255
|
+
self._specify_deps: AgentDeps | None = None
|
|
256
|
+
self._export_agent: Agent[AgentDeps, AgentResponse] | None = None
|
|
257
|
+
self._export_deps: AgentDeps | None = None
|
|
258
|
+
self._agents_initialized = False
|
|
250
259
|
|
|
251
260
|
# Track current active agent
|
|
252
261
|
self._current_agent_type: AgentType = initial_type
|
|
@@ -261,6 +270,119 @@ class AgentManager(Widget):
|
|
|
261
270
|
self._qa_questions: list[str] | None = None
|
|
262
271
|
self._qa_mode_active: bool = False
|
|
263
272
|
|
|
273
|
+
async def _ensure_agents_initialized(self) -> None:
|
|
274
|
+
"""Ensure all agents are initialized (lazy initialization)."""
|
|
275
|
+
if self._agents_initialized:
|
|
276
|
+
return
|
|
277
|
+
|
|
278
|
+
# Initialize all agents asynchronously
|
|
279
|
+
self._research_agent, self._research_deps = await create_research_agent(
|
|
280
|
+
agent_runtime_options=self._agent_runtime_options
|
|
281
|
+
)
|
|
282
|
+
self._plan_agent, self._plan_deps = await create_plan_agent(
|
|
283
|
+
agent_runtime_options=self._agent_runtime_options
|
|
284
|
+
)
|
|
285
|
+
self._tasks_agent, self._tasks_deps = await create_tasks_agent(
|
|
286
|
+
agent_runtime_options=self._agent_runtime_options
|
|
287
|
+
)
|
|
288
|
+
self._specify_agent, self._specify_deps = await create_specify_agent(
|
|
289
|
+
agent_runtime_options=self._agent_runtime_options
|
|
290
|
+
)
|
|
291
|
+
self._export_agent, self._export_deps = await create_export_agent(
|
|
292
|
+
agent_runtime_options=self._agent_runtime_options
|
|
293
|
+
)
|
|
294
|
+
self._agents_initialized = True
|
|
295
|
+
|
|
296
|
+
@property
|
|
297
|
+
def research_agent(self) -> Agent[AgentDeps, AgentResponse]:
|
|
298
|
+
"""Get research agent (must call _ensure_agents_initialized first)."""
|
|
299
|
+
if self._research_agent is None:
|
|
300
|
+
raise RuntimeError(
|
|
301
|
+
"Agents not initialized. Call _ensure_agents_initialized() first."
|
|
302
|
+
)
|
|
303
|
+
return self._research_agent
|
|
304
|
+
|
|
305
|
+
@property
|
|
306
|
+
def research_deps(self) -> AgentDeps:
|
|
307
|
+
"""Get research deps (must call _ensure_agents_initialized first)."""
|
|
308
|
+
if self._research_deps is None:
|
|
309
|
+
raise RuntimeError(
|
|
310
|
+
"Agents not initialized. Call _ensure_agents_initialized() first."
|
|
311
|
+
)
|
|
312
|
+
return self._research_deps
|
|
313
|
+
|
|
314
|
+
@property
|
|
315
|
+
def plan_agent(self) -> Agent[AgentDeps, AgentResponse]:
|
|
316
|
+
"""Get plan agent (must call _ensure_agents_initialized first)."""
|
|
317
|
+
if self._plan_agent is None:
|
|
318
|
+
raise RuntimeError(
|
|
319
|
+
"Agents not initialized. Call _ensure_agents_initialized() first."
|
|
320
|
+
)
|
|
321
|
+
return self._plan_agent
|
|
322
|
+
|
|
323
|
+
@property
|
|
324
|
+
def plan_deps(self) -> AgentDeps:
|
|
325
|
+
"""Get plan deps (must call _ensure_agents_initialized first)."""
|
|
326
|
+
if self._plan_deps is None:
|
|
327
|
+
raise RuntimeError(
|
|
328
|
+
"Agents not initialized. Call _ensure_agents_initialized() first."
|
|
329
|
+
)
|
|
330
|
+
return self._plan_deps
|
|
331
|
+
|
|
332
|
+
@property
|
|
333
|
+
def tasks_agent(self) -> Agent[AgentDeps, AgentResponse]:
|
|
334
|
+
"""Get tasks agent (must call _ensure_agents_initialized first)."""
|
|
335
|
+
if self._tasks_agent is None:
|
|
336
|
+
raise RuntimeError(
|
|
337
|
+
"Agents not initialized. Call _ensure_agents_initialized() first."
|
|
338
|
+
)
|
|
339
|
+
return self._tasks_agent
|
|
340
|
+
|
|
341
|
+
@property
|
|
342
|
+
def tasks_deps(self) -> AgentDeps:
|
|
343
|
+
"""Get tasks deps (must call _ensure_agents_initialized first)."""
|
|
344
|
+
if self._tasks_deps is None:
|
|
345
|
+
raise RuntimeError(
|
|
346
|
+
"Agents not initialized. Call _ensure_agents_initialized() first."
|
|
347
|
+
)
|
|
348
|
+
return self._tasks_deps
|
|
349
|
+
|
|
350
|
+
@property
|
|
351
|
+
def specify_agent(self) -> Agent[AgentDeps, AgentResponse]:
|
|
352
|
+
"""Get specify agent (must call _ensure_agents_initialized first)."""
|
|
353
|
+
if self._specify_agent is None:
|
|
354
|
+
raise RuntimeError(
|
|
355
|
+
"Agents not initialized. Call _ensure_agents_initialized() first."
|
|
356
|
+
)
|
|
357
|
+
return self._specify_agent
|
|
358
|
+
|
|
359
|
+
@property
|
|
360
|
+
def specify_deps(self) -> AgentDeps:
|
|
361
|
+
"""Get specify deps (must call _ensure_agents_initialized first)."""
|
|
362
|
+
if self._specify_deps is None:
|
|
363
|
+
raise RuntimeError(
|
|
364
|
+
"Agents not initialized. Call _ensure_agents_initialized() first."
|
|
365
|
+
)
|
|
366
|
+
return self._specify_deps
|
|
367
|
+
|
|
368
|
+
@property
|
|
369
|
+
def export_agent(self) -> Agent[AgentDeps, AgentResponse]:
|
|
370
|
+
"""Get export agent (must call _ensure_agents_initialized first)."""
|
|
371
|
+
if self._export_agent is None:
|
|
372
|
+
raise RuntimeError(
|
|
373
|
+
"Agents not initialized. Call _ensure_agents_initialized() first."
|
|
374
|
+
)
|
|
375
|
+
return self._export_agent
|
|
376
|
+
|
|
377
|
+
@property
|
|
378
|
+
def export_deps(self) -> AgentDeps:
|
|
379
|
+
"""Get export deps (must call _ensure_agents_initialized first)."""
|
|
380
|
+
if self._export_deps is None:
|
|
381
|
+
raise RuntimeError(
|
|
382
|
+
"Agents not initialized. Call _ensure_agents_initialized() first."
|
|
383
|
+
)
|
|
384
|
+
return self._export_deps
|
|
385
|
+
|
|
264
386
|
@property
|
|
265
387
|
def current_agent(self) -> Agent[AgentDeps, AgentResponse]:
|
|
266
388
|
"""Get the currently active agent.
|
|
@@ -412,6 +534,9 @@ class AgentManager(Widget):
|
|
|
412
534
|
Returns:
|
|
413
535
|
The agent run result.
|
|
414
536
|
"""
|
|
537
|
+
# Ensure agents are initialized before running
|
|
538
|
+
await self._ensure_agents_initialized()
|
|
539
|
+
|
|
415
540
|
logger.info(f"Running agent {self._current_agent_type.value}")
|
|
416
541
|
# Use merged deps (shared state + agent-specific system prompt) if not provided
|
|
417
542
|
if deps is None:
|
|
@@ -649,6 +774,12 @@ class AgentManager(Widget):
|
|
|
649
774
|
HintMessage(message=agent_response.response)
|
|
650
775
|
)
|
|
651
776
|
|
|
777
|
+
# Add file operation hints before questions (so they appear first in UI)
|
|
778
|
+
if file_operations:
|
|
779
|
+
file_hint = self._create_file_operation_hint(file_operations)
|
|
780
|
+
if file_hint:
|
|
781
|
+
self.ui_message_history.append(HintMessage(message=file_hint))
|
|
782
|
+
|
|
652
783
|
if len(agent_response.clarifying_questions) == 1:
|
|
653
784
|
# Single question - treat as non-blocking suggestion, DON'T enter Q&A mode
|
|
654
785
|
self.ui_message_history.append(
|
|
@@ -684,11 +815,9 @@ class AgentManager(Widget):
|
|
|
684
815
|
)
|
|
685
816
|
)
|
|
686
817
|
|
|
687
|
-
# Post UI update with hint messages
|
|
688
|
-
logger.debug(
|
|
689
|
-
|
|
690
|
-
)
|
|
691
|
-
self._post_messages_updated(file_operations)
|
|
818
|
+
# Post UI update with hint messages (file operations will be posted after compaction)
|
|
819
|
+
logger.debug("Posting UI update for Q&A mode with hint messages")
|
|
820
|
+
self._post_messages_updated([])
|
|
692
821
|
else:
|
|
693
822
|
# No clarifying questions - show the response or a default success message
|
|
694
823
|
if agent_response.response and agent_response.response.strip():
|
|
@@ -723,10 +852,9 @@ class AgentManager(Widget):
|
|
|
723
852
|
)
|
|
724
853
|
|
|
725
854
|
# Post UI update immediately so user sees the response without delay
|
|
726
|
-
|
|
727
|
-
|
|
728
|
-
)
|
|
729
|
-
self._post_messages_updated(file_operations)
|
|
855
|
+
# (file operations will be posted after compaction to avoid duplicates)
|
|
856
|
+
logger.debug("Posting immediate UI update with hint message")
|
|
857
|
+
self._post_messages_updated([])
|
|
730
858
|
|
|
731
859
|
# Apply compaction to persistent message history to prevent cascading growth
|
|
732
860
|
all_messages = result.all_messages()
|
|
@@ -780,7 +908,7 @@ class AgentManager(Widget):
|
|
|
780
908
|
|
|
781
909
|
usage = result.usage()
|
|
782
910
|
if hasattr(deps, "llm_model") and deps.llm_model is not None:
|
|
783
|
-
deps.usage_manager.add_usage(
|
|
911
|
+
await deps.usage_manager.add_usage(
|
|
784
912
|
usage, model_name=deps.llm_model.name, provider=deps.llm_model.provider
|
|
785
913
|
)
|
|
786
914
|
else:
|
|
@@ -806,6 +934,9 @@ class AgentManager(Widget):
|
|
|
806
934
|
) -> None:
|
|
807
935
|
"""Process streamed events and forward partial updates to the UI."""
|
|
808
936
|
|
|
937
|
+
# Notify UI that streaming has started
|
|
938
|
+
self.post_message(AgentStreamingStarted())
|
|
939
|
+
|
|
809
940
|
state = self._stream_state
|
|
810
941
|
if state is None:
|
|
811
942
|
state = self._stream_state = _PartialStreamState()
|
|
@@ -984,6 +1115,9 @@ class AgentManager(Widget):
|
|
|
984
1115
|
self._post_partial_message(True)
|
|
985
1116
|
state.current_response = None
|
|
986
1117
|
|
|
1118
|
+
# Notify UI that streaming has completed
|
|
1119
|
+
self.post_message(AgentStreamingCompleted())
|
|
1120
|
+
|
|
987
1121
|
def _build_partial_response(
|
|
988
1122
|
self, parts: list[ModelResponsePart | ToolCallPartDelta]
|
|
989
1123
|
) -> ModelResponse | None:
|
|
@@ -1011,6 +1145,38 @@ class AgentManager(Widget):
|
|
|
1011
1145
|
)
|
|
1012
1146
|
)
|
|
1013
1147
|
|
|
1148
|
+
def _create_file_operation_hint(
|
|
1149
|
+
self, file_operations: list[FileOperation]
|
|
1150
|
+
) -> str | None:
|
|
1151
|
+
"""Create a hint message for file operations.
|
|
1152
|
+
|
|
1153
|
+
Args:
|
|
1154
|
+
file_operations: List of file operations to create a hint for
|
|
1155
|
+
|
|
1156
|
+
Returns:
|
|
1157
|
+
Hint message string or None if no operations
|
|
1158
|
+
"""
|
|
1159
|
+
if not file_operations:
|
|
1160
|
+
return None
|
|
1161
|
+
|
|
1162
|
+
tracker = FileOperationTracker(operations=file_operations)
|
|
1163
|
+
display_path = tracker.get_display_path()
|
|
1164
|
+
|
|
1165
|
+
if not display_path:
|
|
1166
|
+
return None
|
|
1167
|
+
|
|
1168
|
+
path_obj = Path(display_path)
|
|
1169
|
+
|
|
1170
|
+
if len(file_operations) == 1:
|
|
1171
|
+
return f"📝 Modified: `{display_path}`"
|
|
1172
|
+
else:
|
|
1173
|
+
num_files = len({op.file_path for op in file_operations})
|
|
1174
|
+
if path_obj.is_dir():
|
|
1175
|
+
return f"📁 Modified {num_files} files in: `{display_path}`"
|
|
1176
|
+
else:
|
|
1177
|
+
# Common path is a file, show parent directory
|
|
1178
|
+
return f"📁 Modified {num_files} files in: `{path_obj.parent}`"
|
|
1179
|
+
|
|
1014
1180
|
def _post_messages_updated(
|
|
1015
1181
|
self, file_operations: list[FileOperation] | None = None
|
|
1016
1182
|
) -> None:
|
shotgun/agents/common.py
CHANGED
|
@@ -4,6 +4,7 @@ from collections.abc import Callable
|
|
|
4
4
|
from pathlib import Path
|
|
5
5
|
from typing import Any
|
|
6
6
|
|
|
7
|
+
import aiofiles
|
|
7
8
|
from pydantic_ai import (
|
|
8
9
|
Agent,
|
|
9
10
|
RunContext,
|
|
@@ -68,7 +69,7 @@ async def add_system_status_message(
|
|
|
68
69
|
existing_files = get_agent_existing_files(deps.agent_mode)
|
|
69
70
|
|
|
70
71
|
# Extract table of contents from the agent's markdown file
|
|
71
|
-
markdown_toc = extract_markdown_toc(deps.agent_mode)
|
|
72
|
+
markdown_toc = await extract_markdown_toc(deps.agent_mode)
|
|
72
73
|
|
|
73
74
|
# Get current datetime with timezone information
|
|
74
75
|
dt_context = get_datetime_context()
|
|
@@ -94,7 +95,7 @@ async def add_system_status_message(
|
|
|
94
95
|
return message_history
|
|
95
96
|
|
|
96
97
|
|
|
97
|
-
def create_base_agent(
|
|
98
|
+
async def create_base_agent(
|
|
98
99
|
system_prompt_fn: Callable[[RunContext[AgentDeps]], str],
|
|
99
100
|
agent_runtime_options: AgentRuntimeOptions,
|
|
100
101
|
load_codebase_understanding_tools: bool = True,
|
|
@@ -119,7 +120,7 @@ def create_base_agent(
|
|
|
119
120
|
|
|
120
121
|
# Get configured model or fall back to first available provider
|
|
121
122
|
try:
|
|
122
|
-
model_config = get_provider_model(provider)
|
|
123
|
+
model_config = await get_provider_model(provider)
|
|
123
124
|
provider_name = model_config.provider
|
|
124
125
|
logger.debug(
|
|
125
126
|
"🤖 Creating agent with configured %s model: %s",
|
|
@@ -194,7 +195,7 @@ def create_base_agent(
|
|
|
194
195
|
return agent, deps
|
|
195
196
|
|
|
196
197
|
|
|
197
|
-
def _extract_file_toc_content(
|
|
198
|
+
async def _extract_file_toc_content(
|
|
198
199
|
file_path: Path, max_depth: int | None = None, max_chars: int = 500
|
|
199
200
|
) -> str | None:
|
|
200
201
|
"""Extract TOC from a single file with depth and character limits.
|
|
@@ -211,7 +212,8 @@ def _extract_file_toc_content(
|
|
|
211
212
|
return None
|
|
212
213
|
|
|
213
214
|
try:
|
|
214
|
-
|
|
215
|
+
async with aiofiles.open(file_path, encoding="utf-8") as f:
|
|
216
|
+
content = await f.read()
|
|
215
217
|
lines = content.split("\n")
|
|
216
218
|
|
|
217
219
|
# Extract headings
|
|
@@ -257,7 +259,7 @@ def _extract_file_toc_content(
|
|
|
257
259
|
return None
|
|
258
260
|
|
|
259
261
|
|
|
260
|
-
def extract_markdown_toc(agent_mode: AgentType | None) -> str | None:
|
|
262
|
+
async def extract_markdown_toc(agent_mode: AgentType | None) -> str | None:
|
|
261
263
|
"""Extract TOCs from current and prior agents' files in the pipeline.
|
|
262
264
|
|
|
263
265
|
Shows full TOC of agent's own file and high-level summaries of prior agents'
|
|
@@ -309,7 +311,9 @@ def extract_markdown_toc(agent_mode: AgentType | None) -> str | None:
|
|
|
309
311
|
for prior_file in config.prior_files:
|
|
310
312
|
file_path = base_path / prior_file
|
|
311
313
|
# Only show # and ## headings from prior files, max 500 chars each
|
|
312
|
-
prior_toc = _extract_file_toc_content(
|
|
314
|
+
prior_toc = await _extract_file_toc_content(
|
|
315
|
+
file_path, max_depth=2, max_chars=500
|
|
316
|
+
)
|
|
313
317
|
if prior_toc:
|
|
314
318
|
# Add section with XML tags
|
|
315
319
|
toc_sections.append(
|
|
@@ -321,7 +325,9 @@ def extract_markdown_toc(agent_mode: AgentType | None) -> str | None:
|
|
|
321
325
|
# Extract TOC from own file (full detail)
|
|
322
326
|
if config.own_file:
|
|
323
327
|
own_path = base_path / config.own_file
|
|
324
|
-
own_toc = _extract_file_toc_content(
|
|
328
|
+
own_toc = await _extract_file_toc_content(
|
|
329
|
+
own_path, max_depth=None, max_chars=2000
|
|
330
|
+
)
|
|
325
331
|
if own_toc:
|
|
326
332
|
# Put own file TOC at the beginning with XML tags
|
|
327
333
|
toc_sections.insert(
|
shotgun/agents/config/manager.py
CHANGED
|
@@ -5,6 +5,8 @@ import uuid
|
|
|
5
5
|
from pathlib import Path
|
|
6
6
|
from typing import Any
|
|
7
7
|
|
|
8
|
+
import aiofiles
|
|
9
|
+
import aiofiles.os
|
|
8
10
|
from pydantic import SecretStr
|
|
9
11
|
|
|
10
12
|
from shotgun.logging_config import get_logger
|
|
@@ -48,7 +50,7 @@ class ConfigManager:
|
|
|
48
50
|
|
|
49
51
|
self._config: ShotgunConfig | None = None
|
|
50
52
|
|
|
51
|
-
def load(self, force_reload: bool = True) -> ShotgunConfig:
|
|
53
|
+
async def load(self, force_reload: bool = True) -> ShotgunConfig:
|
|
52
54
|
"""Load configuration from file.
|
|
53
55
|
|
|
54
56
|
Args:
|
|
@@ -60,18 +62,19 @@ class ConfigManager:
|
|
|
60
62
|
if self._config is not None and not force_reload:
|
|
61
63
|
return self._config
|
|
62
64
|
|
|
63
|
-
if not
|
|
65
|
+
if not await aiofiles.os.path.exists(self.config_path):
|
|
64
66
|
logger.info(
|
|
65
67
|
"Configuration file not found, creating new config at: %s",
|
|
66
68
|
self.config_path,
|
|
67
69
|
)
|
|
68
70
|
# Create new config with generated shotgun_instance_id
|
|
69
|
-
self._config = self.initialize()
|
|
71
|
+
self._config = await self.initialize()
|
|
70
72
|
return self._config
|
|
71
73
|
|
|
72
74
|
try:
|
|
73
|
-
with open(self.config_path, encoding="utf-8") as f:
|
|
74
|
-
|
|
75
|
+
async with aiofiles.open(self.config_path, encoding="utf-8") as f:
|
|
76
|
+
content = await f.read()
|
|
77
|
+
data = json.loads(content)
|
|
75
78
|
|
|
76
79
|
# Migration: Rename user_id to shotgun_instance_id (config v2 -> v3)
|
|
77
80
|
if "user_id" in data and SHOTGUN_INSTANCE_ID_FIELD not in data:
|
|
@@ -101,6 +104,12 @@ class ConfigManager:
|
|
|
101
104
|
"Existing BYOK user detected: set shown_welcome_screen=False to show welcome screen"
|
|
102
105
|
)
|
|
103
106
|
|
|
107
|
+
# Migration: Add marketing config for v3 -> v4
|
|
108
|
+
if "marketing" not in data:
|
|
109
|
+
data["marketing"] = {"messages": {}}
|
|
110
|
+
data["config_version"] = 4
|
|
111
|
+
logger.info("Migrated config v3->v4: added marketing configuration")
|
|
112
|
+
|
|
104
113
|
# Convert plain text secrets to SecretStr objects
|
|
105
114
|
self._convert_secrets_to_secretstr(data)
|
|
106
115
|
|
|
@@ -117,7 +126,7 @@ class ConfigManager:
|
|
|
117
126
|
|
|
118
127
|
if self._config.selected_model in MODEL_SPECS:
|
|
119
128
|
spec = MODEL_SPECS[self._config.selected_model]
|
|
120
|
-
if not self.has_provider_key(spec.provider):
|
|
129
|
+
if not await self.has_provider_key(spec.provider):
|
|
121
130
|
logger.info(
|
|
122
131
|
"Selected model %s provider has no API key, finding available model",
|
|
123
132
|
self._config.selected_model.value,
|
|
@@ -135,7 +144,7 @@ class ConfigManager:
|
|
|
135
144
|
# If no selected_model or it was invalid, find first available model
|
|
136
145
|
if not self._config.selected_model:
|
|
137
146
|
for provider in ProviderType:
|
|
138
|
-
if self.has_provider_key(provider):
|
|
147
|
+
if await self.has_provider_key(provider):
|
|
139
148
|
# Set to that provider's default model
|
|
140
149
|
from .models import MODEL_SPECS, ModelName
|
|
141
150
|
|
|
@@ -156,7 +165,7 @@ class ConfigManager:
|
|
|
156
165
|
break
|
|
157
166
|
|
|
158
167
|
if should_save:
|
|
159
|
-
self.save(self._config)
|
|
168
|
+
await self.save(self._config)
|
|
160
169
|
|
|
161
170
|
return self._config
|
|
162
171
|
|
|
@@ -165,10 +174,10 @@ class ConfigManager:
|
|
|
165
174
|
"Failed to load configuration from %s: %s", self.config_path, e
|
|
166
175
|
)
|
|
167
176
|
logger.info("Creating new configuration with generated shotgun_instance_id")
|
|
168
|
-
self._config = self.initialize()
|
|
177
|
+
self._config = await self.initialize()
|
|
169
178
|
return self._config
|
|
170
179
|
|
|
171
|
-
def save(self, config: ShotgunConfig | None = None) -> None:
|
|
180
|
+
async def save(self, config: ShotgunConfig | None = None) -> None:
|
|
172
181
|
"""Save configuration to file.
|
|
173
182
|
|
|
174
183
|
Args:
|
|
@@ -184,15 +193,17 @@ class ConfigManager:
|
|
|
184
193
|
)
|
|
185
194
|
|
|
186
195
|
# Ensure directory exists
|
|
187
|
-
self.config_path.parent
|
|
196
|
+
await aiofiles.os.makedirs(self.config_path.parent, exist_ok=True)
|
|
188
197
|
|
|
189
198
|
try:
|
|
190
199
|
# Convert SecretStr to plain text for JSON serialization
|
|
191
200
|
data = config.model_dump()
|
|
192
201
|
self._convert_secretstr_to_plain(data)
|
|
202
|
+
self._convert_datetime_to_isoformat(data)
|
|
193
203
|
|
|
194
|
-
|
|
195
|
-
|
|
204
|
+
json_content = json.dumps(data, indent=2, ensure_ascii=False)
|
|
205
|
+
async with aiofiles.open(self.config_path, "w", encoding="utf-8") as f:
|
|
206
|
+
await f.write(json_content)
|
|
196
207
|
|
|
197
208
|
logger.debug("Configuration saved to %s", self.config_path)
|
|
198
209
|
self._config = config
|
|
@@ -201,14 +212,16 @@ class ConfigManager:
|
|
|
201
212
|
logger.error("Failed to save configuration to %s: %s", self.config_path, e)
|
|
202
213
|
raise
|
|
203
214
|
|
|
204
|
-
def update_provider(
|
|
215
|
+
async def update_provider(
|
|
216
|
+
self, provider: ProviderType | str, **kwargs: Any
|
|
217
|
+
) -> None:
|
|
205
218
|
"""Update provider configuration.
|
|
206
219
|
|
|
207
220
|
Args:
|
|
208
221
|
provider: Provider to update
|
|
209
222
|
**kwargs: Configuration fields to update (only api_key supported)
|
|
210
223
|
"""
|
|
211
|
-
config = self.load()
|
|
224
|
+
config = await self.load()
|
|
212
225
|
|
|
213
226
|
# Get provider config and check if it's shotgun
|
|
214
227
|
provider_config, is_shotgun = self._get_provider_config_and_type(
|
|
@@ -253,11 +266,11 @@ class ConfigManager:
|
|
|
253
266
|
# This prevents the welcome screen from showing again after user has made their choice
|
|
254
267
|
config.shown_welcome_screen = True
|
|
255
268
|
|
|
256
|
-
self.save(config)
|
|
269
|
+
await self.save(config)
|
|
257
270
|
|
|
258
|
-
def clear_provider_key(self, provider: ProviderType | str) -> None:
|
|
271
|
+
async def clear_provider_key(self, provider: ProviderType | str) -> None:
|
|
259
272
|
"""Remove the API key for the given provider (LLM provider or shotgun)."""
|
|
260
|
-
config = self.load()
|
|
273
|
+
config = await self.load()
|
|
261
274
|
|
|
262
275
|
# Get provider config (shotgun or LLM provider)
|
|
263
276
|
provider_config, is_shotgun = self._get_provider_config_and_type(
|
|
@@ -270,34 +283,34 @@ class ConfigManager:
|
|
|
270
283
|
if is_shotgun and isinstance(provider_config, ShotgunAccountConfig):
|
|
271
284
|
provider_config.supabase_jwt = None
|
|
272
285
|
|
|
273
|
-
self.save(config)
|
|
286
|
+
await self.save(config)
|
|
274
287
|
|
|
275
|
-
def update_selected_model(self, model_name: "ModelName") -> None:
|
|
288
|
+
async def update_selected_model(self, model_name: "ModelName") -> None:
|
|
276
289
|
"""Update the selected model.
|
|
277
290
|
|
|
278
291
|
Args:
|
|
279
292
|
model_name: Model to select
|
|
280
293
|
"""
|
|
281
|
-
config = self.load()
|
|
294
|
+
config = await self.load()
|
|
282
295
|
config.selected_model = model_name
|
|
283
|
-
self.save(config)
|
|
296
|
+
await self.save(config)
|
|
284
297
|
|
|
285
|
-
def has_provider_key(self, provider: ProviderType | str) -> bool:
|
|
298
|
+
async def has_provider_key(self, provider: ProviderType | str) -> bool:
|
|
286
299
|
"""Check if the given provider has a non-empty API key configured.
|
|
287
300
|
|
|
288
301
|
This checks only the configuration file.
|
|
289
302
|
"""
|
|
290
303
|
# Use force_reload=False to avoid infinite loop when called from load()
|
|
291
|
-
config = self.load(force_reload=False)
|
|
304
|
+
config = await self.load(force_reload=False)
|
|
292
305
|
provider_enum = self._ensure_provider_enum(provider)
|
|
293
306
|
provider_config = self._get_provider_config(config, provider_enum)
|
|
294
307
|
|
|
295
308
|
return self._provider_has_api_key(provider_config)
|
|
296
309
|
|
|
297
|
-
def has_any_provider_key(self) -> bool:
|
|
310
|
+
async def has_any_provider_key(self) -> bool:
|
|
298
311
|
"""Determine whether any provider has a configured API key."""
|
|
299
312
|
# Use force_reload=False to avoid infinite loop when called from load()
|
|
300
|
-
config = self.load(force_reload=False)
|
|
313
|
+
config = await self.load(force_reload=False)
|
|
301
314
|
# Check LLM provider keys (BYOK)
|
|
302
315
|
has_llm_key = any(
|
|
303
316
|
self._provider_has_api_key(self._get_provider_config(config, provider))
|
|
@@ -311,7 +324,7 @@ class ConfigManager:
|
|
|
311
324
|
has_shotgun_key = self._provider_has_api_key(config.shotgun)
|
|
312
325
|
return has_llm_key or has_shotgun_key
|
|
313
326
|
|
|
314
|
-
def initialize(self) -> ShotgunConfig:
|
|
327
|
+
async def initialize(self) -> ShotgunConfig:
|
|
315
328
|
"""Initialize configuration with defaults and save to file.
|
|
316
329
|
|
|
317
330
|
Returns:
|
|
@@ -321,7 +334,7 @@ class ConfigManager:
|
|
|
321
334
|
config = ShotgunConfig(
|
|
322
335
|
shotgun_instance_id=str(uuid.uuid4()),
|
|
323
336
|
)
|
|
324
|
-
self.save(config)
|
|
337
|
+
await self.save(config)
|
|
325
338
|
logger.info(
|
|
326
339
|
"Configuration initialized at %s with shotgun_instance_id: %s",
|
|
327
340
|
self.config_path,
|
|
@@ -377,6 +390,24 @@ class ConfigManager:
|
|
|
377
390
|
SUPABASE_JWT_FIELD
|
|
378
391
|
].get_secret_value()
|
|
379
392
|
|
|
393
|
+
def _convert_datetime_to_isoformat(self, data: dict[str, Any]) -> None:
|
|
394
|
+
"""Convert datetime objects in data to ISO8601 format strings for JSON serialization."""
|
|
395
|
+
from datetime import datetime
|
|
396
|
+
|
|
397
|
+
def convert_dict(d: dict[str, Any]) -> None:
|
|
398
|
+
"""Recursively convert datetime objects in a dict."""
|
|
399
|
+
for key, value in d.items():
|
|
400
|
+
if isinstance(value, datetime):
|
|
401
|
+
d[key] = value.isoformat()
|
|
402
|
+
elif isinstance(value, dict):
|
|
403
|
+
convert_dict(value)
|
|
404
|
+
elif isinstance(value, list):
|
|
405
|
+
for item in value:
|
|
406
|
+
if isinstance(item, dict):
|
|
407
|
+
convert_dict(item)
|
|
408
|
+
|
|
409
|
+
convert_dict(data)
|
|
410
|
+
|
|
380
411
|
def _ensure_provider_enum(self, provider: ProviderType | str) -> ProviderType:
|
|
381
412
|
"""Normalize provider values to ProviderType enum."""
|
|
382
413
|
return (
|
|
@@ -440,16 +471,16 @@ class ConfigManager:
|
|
|
440
471
|
provider_enum = self._ensure_provider_enum(provider)
|
|
441
472
|
return (self._get_provider_config(config, provider_enum), False)
|
|
442
473
|
|
|
443
|
-
def get_shotgun_instance_id(self) -> str:
|
|
474
|
+
async def get_shotgun_instance_id(self) -> str:
|
|
444
475
|
"""Get the shotgun instance ID from configuration.
|
|
445
476
|
|
|
446
477
|
Returns:
|
|
447
478
|
The unique shotgun instance ID string
|
|
448
479
|
"""
|
|
449
|
-
config = self.load()
|
|
480
|
+
config = await self.load()
|
|
450
481
|
return config.shotgun_instance_id
|
|
451
482
|
|
|
452
|
-
def update_shotgun_account(
|
|
483
|
+
async def update_shotgun_account(
|
|
453
484
|
self, api_key: str | None = None, supabase_jwt: str | None = None
|
|
454
485
|
) -> None:
|
|
455
486
|
"""Update Shotgun Account configuration.
|
|
@@ -458,7 +489,7 @@ class ConfigManager:
|
|
|
458
489
|
api_key: LiteLLM proxy API key (optional)
|
|
459
490
|
supabase_jwt: Supabase authentication JWT (optional)
|
|
460
491
|
"""
|
|
461
|
-
config = self.load()
|
|
492
|
+
config = await self.load()
|
|
462
493
|
|
|
463
494
|
if api_key is not None:
|
|
464
495
|
config.shotgun.api_key = SecretStr(api_key) if api_key else None
|
|
@@ -468,7 +499,7 @@ class ConfigManager:
|
|
|
468
499
|
SecretStr(supabase_jwt) if supabase_jwt else None
|
|
469
500
|
)
|
|
470
501
|
|
|
471
|
-
self.save(config)
|
|
502
|
+
await self.save(config)
|
|
472
503
|
logger.info("Updated Shotgun Account configuration")
|
|
473
504
|
|
|
474
505
|
|