shotgun-sh 0.2.6.dev1__py3-none-any.whl → 0.2.17__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- shotgun/agents/agent_manager.py +694 -73
- shotgun/agents/common.py +69 -70
- shotgun/agents/config/constants.py +0 -6
- shotgun/agents/config/manager.py +70 -35
- shotgun/agents/config/models.py +41 -1
- shotgun/agents/config/provider.py +33 -5
- shotgun/agents/context_analyzer/__init__.py +28 -0
- shotgun/agents/context_analyzer/analyzer.py +471 -0
- shotgun/agents/context_analyzer/constants.py +9 -0
- shotgun/agents/context_analyzer/formatter.py +115 -0
- shotgun/agents/context_analyzer/models.py +212 -0
- shotgun/agents/conversation_history.py +125 -2
- shotgun/agents/conversation_manager.py +57 -19
- shotgun/agents/export.py +6 -7
- shotgun/agents/history/compaction.py +9 -4
- shotgun/agents/history/context_extraction.py +93 -6
- shotgun/agents/history/history_processors.py +113 -5
- shotgun/agents/history/token_counting/anthropic.py +39 -3
- shotgun/agents/history/token_counting/base.py +14 -3
- shotgun/agents/history/token_counting/openai.py +11 -1
- shotgun/agents/history/token_counting/sentencepiece_counter.py +8 -0
- shotgun/agents/history/token_counting/tokenizer_cache.py +3 -1
- shotgun/agents/history/token_counting/utils.py +0 -3
- shotgun/agents/models.py +50 -2
- shotgun/agents/plan.py +6 -7
- shotgun/agents/research.py +7 -8
- shotgun/agents/specify.py +6 -7
- shotgun/agents/tasks.py +6 -7
- shotgun/agents/tools/__init__.py +0 -2
- shotgun/agents/tools/codebase/codebase_shell.py +6 -0
- shotgun/agents/tools/codebase/directory_lister.py +6 -0
- shotgun/agents/tools/codebase/file_read.py +11 -2
- shotgun/agents/tools/codebase/query_graph.py +6 -0
- shotgun/agents/tools/codebase/retrieve_code.py +6 -0
- shotgun/agents/tools/file_management.py +82 -16
- shotgun/agents/tools/registry.py +217 -0
- shotgun/agents/tools/web_search/__init__.py +8 -8
- shotgun/agents/tools/web_search/anthropic.py +8 -2
- shotgun/agents/tools/web_search/gemini.py +7 -1
- shotgun/agents/tools/web_search/openai.py +7 -1
- shotgun/agents/tools/web_search/utils.py +2 -2
- shotgun/agents/usage_manager.py +16 -11
- shotgun/api_endpoints.py +7 -3
- shotgun/build_constants.py +3 -3
- shotgun/cli/clear.py +53 -0
- shotgun/cli/compact.py +186 -0
- shotgun/cli/config.py +8 -5
- shotgun/cli/context.py +111 -0
- shotgun/cli/export.py +1 -1
- shotgun/cli/feedback.py +4 -2
- shotgun/cli/models.py +1 -0
- shotgun/cli/plan.py +1 -1
- shotgun/cli/research.py +1 -1
- shotgun/cli/specify.py +1 -1
- shotgun/cli/tasks.py +1 -1
- shotgun/cli/update.py +16 -2
- shotgun/codebase/core/change_detector.py +5 -3
- shotgun/codebase/core/code_retrieval.py +4 -2
- shotgun/codebase/core/ingestor.py +10 -8
- shotgun/codebase/core/manager.py +13 -4
- shotgun/codebase/core/nl_query.py +1 -1
- shotgun/exceptions.py +32 -0
- shotgun/logging_config.py +18 -27
- shotgun/main.py +73 -11
- shotgun/posthog_telemetry.py +37 -28
- shotgun/prompts/agents/export.j2 +18 -1
- shotgun/prompts/agents/partials/common_agent_system_prompt.j2 +5 -1
- shotgun/prompts/agents/partials/interactive_mode.j2 +24 -7
- shotgun/prompts/agents/plan.j2 +1 -1
- shotgun/prompts/agents/research.j2 +1 -1
- shotgun/prompts/agents/specify.j2 +270 -3
- shotgun/prompts/agents/tasks.j2 +1 -1
- shotgun/sentry_telemetry.py +163 -16
- shotgun/settings.py +238 -0
- shotgun/telemetry.py +18 -33
- shotgun/tui/app.py +243 -43
- shotgun/tui/commands/__init__.py +1 -1
- shotgun/tui/components/context_indicator.py +179 -0
- shotgun/tui/components/mode_indicator.py +70 -0
- shotgun/tui/components/status_bar.py +48 -0
- shotgun/tui/containers.py +91 -0
- shotgun/tui/dependencies.py +39 -0
- shotgun/tui/protocols.py +45 -0
- shotgun/tui/screens/chat/__init__.py +5 -0
- shotgun/tui/screens/chat/chat.tcss +54 -0
- shotgun/tui/screens/chat/chat_screen.py +1254 -0
- shotgun/tui/screens/chat/codebase_index_prompt_screen.py +64 -0
- shotgun/tui/screens/chat/codebase_index_selection.py +12 -0
- shotgun/tui/screens/chat/help_text.py +40 -0
- shotgun/tui/screens/chat/prompt_history.py +48 -0
- shotgun/tui/screens/chat.tcss +11 -0
- shotgun/tui/screens/chat_screen/command_providers.py +78 -2
- shotgun/tui/screens/chat_screen/history/__init__.py +22 -0
- shotgun/tui/screens/chat_screen/history/agent_response.py +66 -0
- shotgun/tui/screens/chat_screen/history/chat_history.py +115 -0
- shotgun/tui/screens/chat_screen/history/formatters.py +115 -0
- shotgun/tui/screens/chat_screen/history/partial_response.py +43 -0
- shotgun/tui/screens/chat_screen/history/user_question.py +42 -0
- shotgun/tui/screens/confirmation_dialog.py +151 -0
- shotgun/tui/screens/feedback.py +4 -4
- shotgun/tui/screens/github_issue.py +102 -0
- shotgun/tui/screens/model_picker.py +49 -24
- shotgun/tui/screens/onboarding.py +431 -0
- shotgun/tui/screens/pipx_migration.py +153 -0
- shotgun/tui/screens/provider_config.py +50 -27
- shotgun/tui/screens/shotgun_auth.py +2 -2
- shotgun/tui/screens/welcome.py +23 -12
- shotgun/tui/services/__init__.py +5 -0
- shotgun/tui/services/conversation_service.py +184 -0
- shotgun/tui/state/__init__.py +7 -0
- shotgun/tui/state/processing_state.py +185 -0
- shotgun/tui/utils/mode_progress.py +14 -7
- shotgun/tui/widgets/__init__.py +5 -0
- shotgun/tui/widgets/widget_coordinator.py +263 -0
- shotgun/utils/file_system_utils.py +22 -2
- shotgun/utils/marketing.py +110 -0
- shotgun/utils/update_checker.py +69 -14
- shotgun_sh-0.2.17.dist-info/METADATA +465 -0
- shotgun_sh-0.2.17.dist-info/RECORD +194 -0
- {shotgun_sh-0.2.6.dev1.dist-info → shotgun_sh-0.2.17.dist-info}/entry_points.txt +1 -0
- {shotgun_sh-0.2.6.dev1.dist-info → shotgun_sh-0.2.17.dist-info}/licenses/LICENSE +1 -1
- shotgun/agents/tools/user_interaction.py +0 -37
- shotgun/tui/screens/chat.py +0 -804
- shotgun/tui/screens/chat_screen/history.py +0 -401
- shotgun_sh-0.2.6.dev1.dist-info/METADATA +0 -467
- shotgun_sh-0.2.6.dev1.dist-info/RECORD +0 -156
- {shotgun_sh-0.2.6.dev1.dist-info → shotgun_sh-0.2.17.dist-info}/WHEEL +0 -0
shotgun/agents/common.py
CHANGED
|
@@ -1,14 +1,12 @@
|
|
|
1
1
|
"""Common utilities for agent creation and management."""
|
|
2
2
|
|
|
3
|
-
import asyncio
|
|
4
3
|
from collections.abc import Callable
|
|
5
4
|
from pathlib import Path
|
|
6
5
|
from typing import Any
|
|
7
6
|
|
|
7
|
+
import aiofiles
|
|
8
8
|
from pydantic_ai import (
|
|
9
9
|
Agent,
|
|
10
|
-
DeferredToolRequests,
|
|
11
|
-
DeferredToolResults,
|
|
12
10
|
RunContext,
|
|
13
11
|
UsageLimits,
|
|
14
12
|
)
|
|
@@ -19,7 +17,7 @@ from pydantic_ai.messages import (
|
|
|
19
17
|
)
|
|
20
18
|
|
|
21
19
|
from shotgun.agents.config import ProviderType, get_provider_model
|
|
22
|
-
from shotgun.agents.models import AgentType
|
|
20
|
+
from shotgun.agents.models import AgentResponse, AgentType
|
|
23
21
|
from shotgun.logging_config import get_logger
|
|
24
22
|
from shotgun.prompts import PromptLoader
|
|
25
23
|
from shotgun.sdk.services import get_codebase_service
|
|
@@ -28,12 +26,10 @@ from shotgun.utils.datetime_utils import get_datetime_context
|
|
|
28
26
|
from shotgun.utils.file_system_utils import get_shotgun_base_path
|
|
29
27
|
|
|
30
28
|
from .history import token_limit_compactor
|
|
31
|
-
from .history.compaction import apply_persistent_compaction
|
|
32
29
|
from .messages import AgentSystemPrompt, SystemStatusPrompt
|
|
33
30
|
from .models import AgentDeps, AgentRuntimeOptions, PipelineConfigEntry
|
|
34
31
|
from .tools import (
|
|
35
32
|
append_file,
|
|
36
|
-
ask_user,
|
|
37
33
|
codebase_shell,
|
|
38
34
|
directory_lister,
|
|
39
35
|
file_read,
|
|
@@ -73,7 +69,7 @@ async def add_system_status_message(
|
|
|
73
69
|
existing_files = get_agent_existing_files(deps.agent_mode)
|
|
74
70
|
|
|
75
71
|
# Extract table of contents from the agent's markdown file
|
|
76
|
-
markdown_toc = extract_markdown_toc(deps.agent_mode)
|
|
72
|
+
markdown_toc = await extract_markdown_toc(deps.agent_mode)
|
|
77
73
|
|
|
78
74
|
# Get current datetime with timezone information
|
|
79
75
|
dt_context = get_datetime_context()
|
|
@@ -99,14 +95,14 @@ async def add_system_status_message(
|
|
|
99
95
|
return message_history
|
|
100
96
|
|
|
101
97
|
|
|
102
|
-
def create_base_agent(
|
|
98
|
+
async def create_base_agent(
|
|
103
99
|
system_prompt_fn: Callable[[RunContext[AgentDeps]], str],
|
|
104
100
|
agent_runtime_options: AgentRuntimeOptions,
|
|
105
101
|
load_codebase_understanding_tools: bool = True,
|
|
106
102
|
additional_tools: list[Any] | None = None,
|
|
107
103
|
provider: ProviderType | None = None,
|
|
108
104
|
agent_mode: AgentType | None = None,
|
|
109
|
-
) -> tuple[Agent[AgentDeps,
|
|
105
|
+
) -> tuple[Agent[AgentDeps, AgentResponse], AgentDeps]:
|
|
110
106
|
"""Create a base agent with common configuration.
|
|
111
107
|
|
|
112
108
|
Args:
|
|
@@ -124,7 +120,7 @@ def create_base_agent(
|
|
|
124
120
|
|
|
125
121
|
# Get configured model or fall back to first available provider
|
|
126
122
|
try:
|
|
127
|
-
model_config = get_provider_model(provider)
|
|
123
|
+
model_config = await get_provider_model(provider)
|
|
128
124
|
provider_name = model_config.provider
|
|
129
125
|
logger.debug(
|
|
130
126
|
"🤖 Creating agent with configured %s model: %s",
|
|
@@ -164,7 +160,7 @@ def create_base_agent(
|
|
|
164
160
|
|
|
165
161
|
agent = Agent(
|
|
166
162
|
model,
|
|
167
|
-
output_type=
|
|
163
|
+
output_type=AgentResponse,
|
|
168
164
|
deps_type=AgentDeps,
|
|
169
165
|
instrument=True,
|
|
170
166
|
history_processors=[history_processor],
|
|
@@ -179,11 +175,6 @@ def create_base_agent(
|
|
|
179
175
|
for tool in additional_tools or []:
|
|
180
176
|
agent.tool_plain(tool)
|
|
181
177
|
|
|
182
|
-
# Register interactive tool conditionally based on deps
|
|
183
|
-
if deps.interactive_mode:
|
|
184
|
-
agent.tool(ask_user)
|
|
185
|
-
logger.debug("📞 Interactive mode enabled - ask_user tool registered")
|
|
186
|
-
|
|
187
178
|
# Register common file management tools (always available)
|
|
188
179
|
agent.tool(write_file)
|
|
189
180
|
agent.tool(append_file)
|
|
@@ -204,7 +195,7 @@ def create_base_agent(
|
|
|
204
195
|
return agent, deps
|
|
205
196
|
|
|
206
197
|
|
|
207
|
-
def _extract_file_toc_content(
|
|
198
|
+
async def _extract_file_toc_content(
|
|
208
199
|
file_path: Path, max_depth: int | None = None, max_chars: int = 500
|
|
209
200
|
) -> str | None:
|
|
210
201
|
"""Extract TOC from a single file with depth and character limits.
|
|
@@ -221,7 +212,8 @@ def _extract_file_toc_content(
|
|
|
221
212
|
return None
|
|
222
213
|
|
|
223
214
|
try:
|
|
224
|
-
|
|
215
|
+
async with aiofiles.open(file_path, encoding="utf-8") as f:
|
|
216
|
+
content = await f.read()
|
|
225
217
|
lines = content.split("\n")
|
|
226
218
|
|
|
227
219
|
# Extract headings
|
|
@@ -267,7 +259,7 @@ def _extract_file_toc_content(
|
|
|
267
259
|
return None
|
|
268
260
|
|
|
269
261
|
|
|
270
|
-
def extract_markdown_toc(agent_mode: AgentType | None) -> str | None:
|
|
262
|
+
async def extract_markdown_toc(agent_mode: AgentType | None) -> str | None:
|
|
271
263
|
"""Extract TOCs from current and prior agents' files in the pipeline.
|
|
272
264
|
|
|
273
265
|
Shows full TOC of agent's own file and high-level summaries of prior agents'
|
|
@@ -319,22 +311,30 @@ def extract_markdown_toc(agent_mode: AgentType | None) -> str | None:
|
|
|
319
311
|
for prior_file in config.prior_files:
|
|
320
312
|
file_path = base_path / prior_file
|
|
321
313
|
# Only show # and ## headings from prior files, max 500 chars each
|
|
322
|
-
prior_toc = _extract_file_toc_content(
|
|
314
|
+
prior_toc = await _extract_file_toc_content(
|
|
315
|
+
file_path, max_depth=2, max_chars=500
|
|
316
|
+
)
|
|
323
317
|
if prior_toc:
|
|
324
318
|
# Add section with XML tags
|
|
325
319
|
toc_sections.append(
|
|
326
|
-
f'<TABLE_OF_CONTENTS file_name="{prior_file}">\n
|
|
320
|
+
f'<TABLE_OF_CONTENTS file_name="{prior_file}">\n'
|
|
321
|
+
f"{prior_toc}\n"
|
|
322
|
+
f"</TABLE_OF_CONTENTS>"
|
|
327
323
|
)
|
|
328
324
|
|
|
329
325
|
# Extract TOC from own file (full detail)
|
|
330
326
|
if config.own_file:
|
|
331
327
|
own_path = base_path / config.own_file
|
|
332
|
-
own_toc = _extract_file_toc_content(
|
|
328
|
+
own_toc = await _extract_file_toc_content(
|
|
329
|
+
own_path, max_depth=None, max_chars=2000
|
|
330
|
+
)
|
|
333
331
|
if own_toc:
|
|
334
332
|
# Put own file TOC at the beginning with XML tags
|
|
335
333
|
toc_sections.insert(
|
|
336
334
|
0,
|
|
337
|
-
f'<TABLE_OF_CONTENTS file_name="{config.own_file}">\n
|
|
335
|
+
f'<TABLE_OF_CONTENTS file_name="{config.own_file}">\n'
|
|
336
|
+
f"{own_toc}\n"
|
|
337
|
+
f"</TABLE_OF_CONTENTS>",
|
|
338
338
|
)
|
|
339
339
|
|
|
340
340
|
# Combine all sections
|
|
@@ -390,23 +390,48 @@ def get_agent_existing_files(agent_mode: AgentType | None = None) -> list[str]:
|
|
|
390
390
|
relative_path = file_path.relative_to(base_path)
|
|
391
391
|
existing_files.append(str(relative_path))
|
|
392
392
|
else:
|
|
393
|
-
# For other agents, check
|
|
394
|
-
|
|
395
|
-
|
|
396
|
-
#
|
|
397
|
-
|
|
398
|
-
|
|
399
|
-
|
|
400
|
-
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
|
|
404
|
-
|
|
405
|
-
#
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
|
|
393
|
+
# For other agents, check files/directories they have access to
|
|
394
|
+
allowed_paths_raw = AGENT_DIRECTORIES[agent_mode]
|
|
395
|
+
|
|
396
|
+
# Convert single Path/string to list of Paths for uniform handling
|
|
397
|
+
if isinstance(allowed_paths_raw, str):
|
|
398
|
+
# Special case: "*" means export agent (shouldn't reach here but handle it)
|
|
399
|
+
allowed_paths = (
|
|
400
|
+
[Path(allowed_paths_raw)] if allowed_paths_raw != "*" else []
|
|
401
|
+
)
|
|
402
|
+
elif isinstance(allowed_paths_raw, Path):
|
|
403
|
+
allowed_paths = [allowed_paths_raw]
|
|
404
|
+
else:
|
|
405
|
+
# Already a list
|
|
406
|
+
allowed_paths = allowed_paths_raw
|
|
407
|
+
|
|
408
|
+
# Check each allowed path
|
|
409
|
+
for allowed_path in allowed_paths:
|
|
410
|
+
allowed_str = str(allowed_path)
|
|
411
|
+
|
|
412
|
+
# Check if it's a directory (no .md suffix)
|
|
413
|
+
if not allowed_path.suffix or not allowed_str.endswith(".md"):
|
|
414
|
+
# It's a directory - list all files within it
|
|
415
|
+
dir_path = base_path / allowed_str
|
|
416
|
+
if dir_path.exists() and dir_path.is_dir():
|
|
417
|
+
for file_path in dir_path.rglob("*"):
|
|
418
|
+
if file_path.is_file():
|
|
419
|
+
relative_path = file_path.relative_to(base_path)
|
|
420
|
+
existing_files.append(str(relative_path))
|
|
421
|
+
else:
|
|
422
|
+
# It's a file - check if it exists
|
|
423
|
+
file_path = base_path / allowed_str
|
|
424
|
+
if file_path.exists():
|
|
425
|
+
existing_files.append(allowed_str)
|
|
426
|
+
|
|
427
|
+
# Also check for associated directory (e.g., research/ for research.md)
|
|
428
|
+
base_name = allowed_str.replace(".md", "")
|
|
429
|
+
dir_path = base_path / base_name
|
|
430
|
+
if dir_path.exists() and dir_path.is_dir():
|
|
431
|
+
for file_path in dir_path.rglob("*"):
|
|
432
|
+
if file_path.is_file():
|
|
433
|
+
relative_path = file_path.relative_to(base_path)
|
|
434
|
+
existing_files.append(str(relative_path))
|
|
410
435
|
|
|
411
436
|
return existing_files
|
|
412
437
|
|
|
@@ -476,7 +501,8 @@ async def add_system_prompt_message(
|
|
|
476
501
|
message_history = message_history or []
|
|
477
502
|
|
|
478
503
|
# Create a minimal RunContext to call the system prompt function
|
|
479
|
-
# We'll pass None for model and usage since they're not used
|
|
504
|
+
# We'll pass None for model and usage since they're not used
|
|
505
|
+
# by our system prompt functions
|
|
480
506
|
context = type(
|
|
481
507
|
"RunContext", (), {"deps": deps, "retry": 0, "model": None, "usage": None}
|
|
482
508
|
)()
|
|
@@ -500,12 +526,12 @@ async def add_system_prompt_message(
|
|
|
500
526
|
|
|
501
527
|
|
|
502
528
|
async def run_agent(
|
|
503
|
-
agent: Agent[AgentDeps,
|
|
529
|
+
agent: Agent[AgentDeps, AgentResponse],
|
|
504
530
|
prompt: str,
|
|
505
531
|
deps: AgentDeps,
|
|
506
532
|
message_history: list[ModelMessage] | None = None,
|
|
507
533
|
usage_limits: UsageLimits | None = None,
|
|
508
|
-
) -> AgentRunResult[
|
|
534
|
+
) -> AgentRunResult[AgentResponse]:
|
|
509
535
|
# Clear file tracker for new run
|
|
510
536
|
deps.file_tracker.clear()
|
|
511
537
|
logger.debug("🔧 Cleared file tracker for new agent run")
|
|
@@ -520,33 +546,6 @@ async def run_agent(
|
|
|
520
546
|
message_history=message_history,
|
|
521
547
|
)
|
|
522
548
|
|
|
523
|
-
# Apply persistent compaction to prevent cascading token growth across CLI commands
|
|
524
|
-
messages = await apply_persistent_compaction(result.all_messages(), deps)
|
|
525
|
-
while isinstance(result.output, DeferredToolRequests):
|
|
526
|
-
logger.info("got deferred tool requests")
|
|
527
|
-
await deps.queue.join()
|
|
528
|
-
requests = result.output
|
|
529
|
-
done, _ = await asyncio.wait(deps.tasks)
|
|
530
|
-
|
|
531
|
-
task_results = [task.result() for task in done]
|
|
532
|
-
task_results_by_tool_call_id = {
|
|
533
|
-
result.tool_call_id: result.answer for result in task_results
|
|
534
|
-
}
|
|
535
|
-
logger.info("got task results", task_results_by_tool_call_id)
|
|
536
|
-
results = DeferredToolResults()
|
|
537
|
-
for call in requests.calls:
|
|
538
|
-
results.calls[call.tool_call_id] = task_results_by_tool_call_id[
|
|
539
|
-
call.tool_call_id
|
|
540
|
-
]
|
|
541
|
-
result = await agent.run(
|
|
542
|
-
deps=deps,
|
|
543
|
-
usage_limits=usage_limits,
|
|
544
|
-
message_history=messages,
|
|
545
|
-
deferred_tool_results=results,
|
|
546
|
-
)
|
|
547
|
-
# Apply persistent compaction to prevent cascading token growth in multi-turn loops
|
|
548
|
-
messages = await apply_persistent_compaction(result.all_messages(), deps)
|
|
549
|
-
|
|
550
549
|
# Log file operations summary if any files were modified
|
|
551
550
|
if deps.file_tracker.operations:
|
|
552
551
|
summary = deps.file_tracker.format_summary()
|
|
@@ -24,11 +24,5 @@ ANTHROPIC_PROVIDER = ConfigSection.ANTHROPIC.value
|
|
|
24
24
|
GOOGLE_PROVIDER = ConfigSection.GOOGLE.value
|
|
25
25
|
SHOTGUN_PROVIDER = ConfigSection.SHOTGUN.value
|
|
26
26
|
|
|
27
|
-
# Environment variable names
|
|
28
|
-
OPENAI_API_KEY_ENV = "OPENAI_API_KEY"
|
|
29
|
-
ANTHROPIC_API_KEY_ENV = "ANTHROPIC_API_KEY"
|
|
30
|
-
GEMINI_API_KEY_ENV = "GEMINI_API_KEY"
|
|
31
|
-
SHOTGUN_API_KEY_ENV = "SHOTGUN_API_KEY"
|
|
32
|
-
|
|
33
27
|
# Token limits
|
|
34
28
|
MEDIUM_TEXT_8K_TOKENS = 8192 # Default max_tokens for web search requests
|
shotgun/agents/config/manager.py
CHANGED
|
@@ -5,6 +5,8 @@ import uuid
|
|
|
5
5
|
from pathlib import Path
|
|
6
6
|
from typing import Any
|
|
7
7
|
|
|
8
|
+
import aiofiles
|
|
9
|
+
import aiofiles.os
|
|
8
10
|
from pydantic import SecretStr
|
|
9
11
|
|
|
10
12
|
from shotgun.logging_config import get_logger
|
|
@@ -48,7 +50,7 @@ class ConfigManager:
|
|
|
48
50
|
|
|
49
51
|
self._config: ShotgunConfig | None = None
|
|
50
52
|
|
|
51
|
-
def load(self, force_reload: bool = True) -> ShotgunConfig:
|
|
53
|
+
async def load(self, force_reload: bool = True) -> ShotgunConfig:
|
|
52
54
|
"""Load configuration from file.
|
|
53
55
|
|
|
54
56
|
Args:
|
|
@@ -60,18 +62,19 @@ class ConfigManager:
|
|
|
60
62
|
if self._config is not None and not force_reload:
|
|
61
63
|
return self._config
|
|
62
64
|
|
|
63
|
-
if not
|
|
65
|
+
if not await aiofiles.os.path.exists(self.config_path):
|
|
64
66
|
logger.info(
|
|
65
67
|
"Configuration file not found, creating new config at: %s",
|
|
66
68
|
self.config_path,
|
|
67
69
|
)
|
|
68
70
|
# Create new config with generated shotgun_instance_id
|
|
69
|
-
self._config = self.initialize()
|
|
71
|
+
self._config = await self.initialize()
|
|
70
72
|
return self._config
|
|
71
73
|
|
|
72
74
|
try:
|
|
73
|
-
with open(self.config_path, encoding="utf-8") as f:
|
|
74
|
-
|
|
75
|
+
async with aiofiles.open(self.config_path, encoding="utf-8") as f:
|
|
76
|
+
content = await f.read()
|
|
77
|
+
data = json.loads(content)
|
|
75
78
|
|
|
76
79
|
# Migration: Rename user_id to shotgun_instance_id (config v2 -> v3)
|
|
77
80
|
if "user_id" in data and SHOTGUN_INSTANCE_ID_FIELD not in data:
|
|
@@ -101,6 +104,12 @@ class ConfigManager:
|
|
|
101
104
|
"Existing BYOK user detected: set shown_welcome_screen=False to show welcome screen"
|
|
102
105
|
)
|
|
103
106
|
|
|
107
|
+
# Migration: Add marketing config for v3 -> v4
|
|
108
|
+
if "marketing" not in data:
|
|
109
|
+
data["marketing"] = {"messages": {}}
|
|
110
|
+
data["config_version"] = 4
|
|
111
|
+
logger.info("Migrated config v3->v4: added marketing configuration")
|
|
112
|
+
|
|
104
113
|
# Convert plain text secrets to SecretStr objects
|
|
105
114
|
self._convert_secrets_to_secretstr(data)
|
|
106
115
|
|
|
@@ -117,7 +126,7 @@ class ConfigManager:
|
|
|
117
126
|
|
|
118
127
|
if self._config.selected_model in MODEL_SPECS:
|
|
119
128
|
spec = MODEL_SPECS[self._config.selected_model]
|
|
120
|
-
if not self.has_provider_key(spec.provider):
|
|
129
|
+
if not await self.has_provider_key(spec.provider):
|
|
121
130
|
logger.info(
|
|
122
131
|
"Selected model %s provider has no API key, finding available model",
|
|
123
132
|
self._config.selected_model.value,
|
|
@@ -135,14 +144,14 @@ class ConfigManager:
|
|
|
135
144
|
# If no selected_model or it was invalid, find first available model
|
|
136
145
|
if not self._config.selected_model:
|
|
137
146
|
for provider in ProviderType:
|
|
138
|
-
if self.has_provider_key(provider):
|
|
147
|
+
if await self.has_provider_key(provider):
|
|
139
148
|
# Set to that provider's default model
|
|
140
149
|
from .models import MODEL_SPECS, ModelName
|
|
141
150
|
|
|
142
151
|
# Find default model for this provider
|
|
143
152
|
provider_models = {
|
|
144
153
|
ProviderType.OPENAI: ModelName.GPT_5,
|
|
145
|
-
ProviderType.ANTHROPIC: ModelName.
|
|
154
|
+
ProviderType.ANTHROPIC: ModelName.CLAUDE_HAIKU_4_5,
|
|
146
155
|
ProviderType.GOOGLE: ModelName.GEMINI_2_5_PRO,
|
|
147
156
|
}
|
|
148
157
|
|
|
@@ -156,7 +165,7 @@ class ConfigManager:
|
|
|
156
165
|
break
|
|
157
166
|
|
|
158
167
|
if should_save:
|
|
159
|
-
self.save(self._config)
|
|
168
|
+
await self.save(self._config)
|
|
160
169
|
|
|
161
170
|
return self._config
|
|
162
171
|
|
|
@@ -165,10 +174,10 @@ class ConfigManager:
|
|
|
165
174
|
"Failed to load configuration from %s: %s", self.config_path, e
|
|
166
175
|
)
|
|
167
176
|
logger.info("Creating new configuration with generated shotgun_instance_id")
|
|
168
|
-
self._config = self.initialize()
|
|
177
|
+
self._config = await self.initialize()
|
|
169
178
|
return self._config
|
|
170
179
|
|
|
171
|
-
def save(self, config: ShotgunConfig | None = None) -> None:
|
|
180
|
+
async def save(self, config: ShotgunConfig | None = None) -> None:
|
|
172
181
|
"""Save configuration to file.
|
|
173
182
|
|
|
174
183
|
Args:
|
|
@@ -184,15 +193,17 @@ class ConfigManager:
|
|
|
184
193
|
)
|
|
185
194
|
|
|
186
195
|
# Ensure directory exists
|
|
187
|
-
self.config_path.parent
|
|
196
|
+
await aiofiles.os.makedirs(self.config_path.parent, exist_ok=True)
|
|
188
197
|
|
|
189
198
|
try:
|
|
190
199
|
# Convert SecretStr to plain text for JSON serialization
|
|
191
200
|
data = config.model_dump()
|
|
192
201
|
self._convert_secretstr_to_plain(data)
|
|
202
|
+
self._convert_datetime_to_isoformat(data)
|
|
193
203
|
|
|
194
|
-
|
|
195
|
-
|
|
204
|
+
json_content = json.dumps(data, indent=2, ensure_ascii=False)
|
|
205
|
+
async with aiofiles.open(self.config_path, "w", encoding="utf-8") as f:
|
|
206
|
+
await f.write(json_content)
|
|
196
207
|
|
|
197
208
|
logger.debug("Configuration saved to %s", self.config_path)
|
|
198
209
|
self._config = config
|
|
@@ -201,14 +212,16 @@ class ConfigManager:
|
|
|
201
212
|
logger.error("Failed to save configuration to %s: %s", self.config_path, e)
|
|
202
213
|
raise
|
|
203
214
|
|
|
204
|
-
def update_provider(
|
|
215
|
+
async def update_provider(
|
|
216
|
+
self, provider: ProviderType | str, **kwargs: Any
|
|
217
|
+
) -> None:
|
|
205
218
|
"""Update provider configuration.
|
|
206
219
|
|
|
207
220
|
Args:
|
|
208
221
|
provider: Provider to update
|
|
209
222
|
**kwargs: Configuration fields to update (only api_key supported)
|
|
210
223
|
"""
|
|
211
|
-
config = self.load()
|
|
224
|
+
config = await self.load()
|
|
212
225
|
|
|
213
226
|
# Get provider config and check if it's shotgun
|
|
214
227
|
provider_config, is_shotgun = self._get_provider_config_and_type(
|
|
@@ -243,17 +256,21 @@ class ConfigManager:
|
|
|
243
256
|
|
|
244
257
|
provider_models = {
|
|
245
258
|
ProviderType.OPENAI: ModelName.GPT_5,
|
|
246
|
-
ProviderType.ANTHROPIC: ModelName.
|
|
259
|
+
ProviderType.ANTHROPIC: ModelName.CLAUDE_HAIKU_4_5,
|
|
247
260
|
ProviderType.GOOGLE: ModelName.GEMINI_2_5_PRO,
|
|
248
261
|
}
|
|
249
262
|
if provider_enum in provider_models:
|
|
250
263
|
config.selected_model = provider_models[provider_enum]
|
|
251
264
|
|
|
252
|
-
|
|
265
|
+
# Mark welcome screen as shown when BYOK provider is configured
|
|
266
|
+
# This prevents the welcome screen from showing again after user has made their choice
|
|
267
|
+
config.shown_welcome_screen = True
|
|
268
|
+
|
|
269
|
+
await self.save(config)
|
|
253
270
|
|
|
254
|
-
def clear_provider_key(self, provider: ProviderType | str) -> None:
|
|
271
|
+
async def clear_provider_key(self, provider: ProviderType | str) -> None:
|
|
255
272
|
"""Remove the API key for the given provider (LLM provider or shotgun)."""
|
|
256
|
-
config = self.load()
|
|
273
|
+
config = await self.load()
|
|
257
274
|
|
|
258
275
|
# Get provider config (shotgun or LLM provider)
|
|
259
276
|
provider_config, is_shotgun = self._get_provider_config_and_type(
|
|
@@ -266,34 +283,34 @@ class ConfigManager:
|
|
|
266
283
|
if is_shotgun and isinstance(provider_config, ShotgunAccountConfig):
|
|
267
284
|
provider_config.supabase_jwt = None
|
|
268
285
|
|
|
269
|
-
self.save(config)
|
|
286
|
+
await self.save(config)
|
|
270
287
|
|
|
271
|
-
def update_selected_model(self, model_name: "ModelName") -> None:
|
|
288
|
+
async def update_selected_model(self, model_name: "ModelName") -> None:
|
|
272
289
|
"""Update the selected model.
|
|
273
290
|
|
|
274
291
|
Args:
|
|
275
292
|
model_name: Model to select
|
|
276
293
|
"""
|
|
277
|
-
config = self.load()
|
|
294
|
+
config = await self.load()
|
|
278
295
|
config.selected_model = model_name
|
|
279
|
-
self.save(config)
|
|
296
|
+
await self.save(config)
|
|
280
297
|
|
|
281
|
-
def has_provider_key(self, provider: ProviderType | str) -> bool:
|
|
298
|
+
async def has_provider_key(self, provider: ProviderType | str) -> bool:
|
|
282
299
|
"""Check if the given provider has a non-empty API key configured.
|
|
283
300
|
|
|
284
301
|
This checks only the configuration file.
|
|
285
302
|
"""
|
|
286
303
|
# Use force_reload=False to avoid infinite loop when called from load()
|
|
287
|
-
config = self.load(force_reload=False)
|
|
304
|
+
config = await self.load(force_reload=False)
|
|
288
305
|
provider_enum = self._ensure_provider_enum(provider)
|
|
289
306
|
provider_config = self._get_provider_config(config, provider_enum)
|
|
290
307
|
|
|
291
308
|
return self._provider_has_api_key(provider_config)
|
|
292
309
|
|
|
293
|
-
def has_any_provider_key(self) -> bool:
|
|
310
|
+
async def has_any_provider_key(self) -> bool:
|
|
294
311
|
"""Determine whether any provider has a configured API key."""
|
|
295
312
|
# Use force_reload=False to avoid infinite loop when called from load()
|
|
296
|
-
config = self.load(force_reload=False)
|
|
313
|
+
config = await self.load(force_reload=False)
|
|
297
314
|
# Check LLM provider keys (BYOK)
|
|
298
315
|
has_llm_key = any(
|
|
299
316
|
self._provider_has_api_key(self._get_provider_config(config, provider))
|
|
@@ -307,7 +324,7 @@ class ConfigManager:
|
|
|
307
324
|
has_shotgun_key = self._provider_has_api_key(config.shotgun)
|
|
308
325
|
return has_llm_key or has_shotgun_key
|
|
309
326
|
|
|
310
|
-
def initialize(self) -> ShotgunConfig:
|
|
327
|
+
async def initialize(self) -> ShotgunConfig:
|
|
311
328
|
"""Initialize configuration with defaults and save to file.
|
|
312
329
|
|
|
313
330
|
Returns:
|
|
@@ -317,7 +334,7 @@ class ConfigManager:
|
|
|
317
334
|
config = ShotgunConfig(
|
|
318
335
|
shotgun_instance_id=str(uuid.uuid4()),
|
|
319
336
|
)
|
|
320
|
-
self.save(config)
|
|
337
|
+
await self.save(config)
|
|
321
338
|
logger.info(
|
|
322
339
|
"Configuration initialized at %s with shotgun_instance_id: %s",
|
|
323
340
|
self.config_path,
|
|
@@ -373,6 +390,24 @@ class ConfigManager:
|
|
|
373
390
|
SUPABASE_JWT_FIELD
|
|
374
391
|
].get_secret_value()
|
|
375
392
|
|
|
393
|
+
def _convert_datetime_to_isoformat(self, data: dict[str, Any]) -> None:
|
|
394
|
+
"""Convert datetime objects in data to ISO8601 format strings for JSON serialization."""
|
|
395
|
+
from datetime import datetime
|
|
396
|
+
|
|
397
|
+
def convert_dict(d: dict[str, Any]) -> None:
|
|
398
|
+
"""Recursively convert datetime objects in a dict."""
|
|
399
|
+
for key, value in d.items():
|
|
400
|
+
if isinstance(value, datetime):
|
|
401
|
+
d[key] = value.isoformat()
|
|
402
|
+
elif isinstance(value, dict):
|
|
403
|
+
convert_dict(value)
|
|
404
|
+
elif isinstance(value, list):
|
|
405
|
+
for item in value:
|
|
406
|
+
if isinstance(item, dict):
|
|
407
|
+
convert_dict(item)
|
|
408
|
+
|
|
409
|
+
convert_dict(data)
|
|
410
|
+
|
|
376
411
|
def _ensure_provider_enum(self, provider: ProviderType | str) -> ProviderType:
|
|
377
412
|
"""Normalize provider values to ProviderType enum."""
|
|
378
413
|
return (
|
|
@@ -436,16 +471,16 @@ class ConfigManager:
|
|
|
436
471
|
provider_enum = self._ensure_provider_enum(provider)
|
|
437
472
|
return (self._get_provider_config(config, provider_enum), False)
|
|
438
473
|
|
|
439
|
-
def get_shotgun_instance_id(self) -> str:
|
|
474
|
+
async def get_shotgun_instance_id(self) -> str:
|
|
440
475
|
"""Get the shotgun instance ID from configuration.
|
|
441
476
|
|
|
442
477
|
Returns:
|
|
443
478
|
The unique shotgun instance ID string
|
|
444
479
|
"""
|
|
445
|
-
config = self.load()
|
|
480
|
+
config = await self.load()
|
|
446
481
|
return config.shotgun_instance_id
|
|
447
482
|
|
|
448
|
-
def update_shotgun_account(
|
|
483
|
+
async def update_shotgun_account(
|
|
449
484
|
self, api_key: str | None = None, supabase_jwt: str | None = None
|
|
450
485
|
) -> None:
|
|
451
486
|
"""Update Shotgun Account configuration.
|
|
@@ -454,7 +489,7 @@ class ConfigManager:
|
|
|
454
489
|
api_key: LiteLLM proxy API key (optional)
|
|
455
490
|
supabase_jwt: Supabase authentication JWT (optional)
|
|
456
491
|
"""
|
|
457
|
-
config = self.load()
|
|
492
|
+
config = await self.load()
|
|
458
493
|
|
|
459
494
|
if api_key is not None:
|
|
460
495
|
config.shotgun.api_key = SecretStr(api_key) if api_key else None
|
|
@@ -464,7 +499,7 @@ class ConfigManager:
|
|
|
464
499
|
SecretStr(supabase_jwt) if supabase_jwt else None
|
|
465
500
|
)
|
|
466
501
|
|
|
467
|
-
self.save(config)
|
|
502
|
+
await self.save(config)
|
|
468
503
|
logger.info("Updated Shotgun Account configuration")
|
|
469
504
|
|
|
470
505
|
|