shotgun-sh 0.2.3.dev2__py3-none-any.whl → 0.2.11.dev5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of shotgun-sh might be problematic. Click here for more details.
- shotgun/agents/agent_manager.py +664 -75
- shotgun/agents/common.py +76 -70
- shotgun/agents/config/constants.py +0 -6
- shotgun/agents/config/manager.py +78 -36
- shotgun/agents/config/models.py +41 -1
- shotgun/agents/config/provider.py +70 -15
- shotgun/agents/context_analyzer/__init__.py +28 -0
- shotgun/agents/context_analyzer/analyzer.py +471 -0
- shotgun/agents/context_analyzer/constants.py +9 -0
- shotgun/agents/context_analyzer/formatter.py +115 -0
- shotgun/agents/context_analyzer/models.py +212 -0
- shotgun/agents/conversation_history.py +125 -2
- shotgun/agents/conversation_manager.py +57 -19
- shotgun/agents/export.py +6 -7
- shotgun/agents/history/compaction.py +9 -4
- shotgun/agents/history/context_extraction.py +93 -6
- shotgun/agents/history/history_processors.py +14 -2
- shotgun/agents/history/token_counting/anthropic.py +49 -11
- shotgun/agents/history/token_counting/base.py +14 -3
- shotgun/agents/history/token_counting/openai.py +8 -0
- shotgun/agents/history/token_counting/sentencepiece_counter.py +8 -0
- shotgun/agents/history/token_counting/tokenizer_cache.py +3 -1
- shotgun/agents/history/token_counting/utils.py +0 -3
- shotgun/agents/models.py +50 -2
- shotgun/agents/plan.py +6 -7
- shotgun/agents/research.py +7 -8
- shotgun/agents/specify.py +6 -7
- shotgun/agents/tasks.py +6 -7
- shotgun/agents/tools/__init__.py +0 -2
- shotgun/agents/tools/codebase/codebase_shell.py +6 -0
- shotgun/agents/tools/codebase/directory_lister.py +6 -0
- shotgun/agents/tools/codebase/file_read.py +11 -2
- shotgun/agents/tools/codebase/query_graph.py +6 -0
- shotgun/agents/tools/codebase/retrieve_code.py +6 -0
- shotgun/agents/tools/file_management.py +82 -16
- shotgun/agents/tools/registry.py +217 -0
- shotgun/agents/tools/web_search/__init__.py +30 -18
- shotgun/agents/tools/web_search/anthropic.py +26 -5
- shotgun/agents/tools/web_search/gemini.py +23 -11
- shotgun/agents/tools/web_search/openai.py +22 -13
- shotgun/agents/tools/web_search/utils.py +2 -2
- shotgun/agents/usage_manager.py +16 -11
- shotgun/api_endpoints.py +7 -3
- shotgun/build_constants.py +1 -1
- shotgun/cli/clear.py +53 -0
- shotgun/cli/compact.py +186 -0
- shotgun/cli/config.py +8 -5
- shotgun/cli/context.py +111 -0
- shotgun/cli/export.py +1 -1
- shotgun/cli/feedback.py +4 -2
- shotgun/cli/models.py +1 -0
- shotgun/cli/plan.py +1 -1
- shotgun/cli/research.py +1 -1
- shotgun/cli/specify.py +1 -1
- shotgun/cli/tasks.py +1 -1
- shotgun/cli/update.py +16 -2
- shotgun/codebase/core/change_detector.py +5 -3
- shotgun/codebase/core/code_retrieval.py +4 -2
- shotgun/codebase/core/ingestor.py +10 -8
- shotgun/codebase/core/manager.py +13 -4
- shotgun/codebase/core/nl_query.py +1 -1
- shotgun/llm_proxy/__init__.py +5 -2
- shotgun/llm_proxy/clients.py +12 -7
- shotgun/logging_config.py +18 -27
- shotgun/main.py +73 -11
- shotgun/posthog_telemetry.py +23 -7
- shotgun/prompts/agents/export.j2 +18 -1
- shotgun/prompts/agents/partials/common_agent_system_prompt.j2 +5 -1
- shotgun/prompts/agents/partials/interactive_mode.j2 +24 -7
- shotgun/prompts/agents/plan.j2 +1 -1
- shotgun/prompts/agents/research.j2 +1 -1
- shotgun/prompts/agents/specify.j2 +270 -3
- shotgun/prompts/agents/state/system_state.j2 +4 -0
- shotgun/prompts/agents/tasks.j2 +1 -1
- shotgun/prompts/loader.py +2 -2
- shotgun/prompts/tools/web_search.j2 +14 -0
- shotgun/sentry_telemetry.py +7 -16
- shotgun/settings.py +238 -0
- shotgun/telemetry.py +18 -33
- shotgun/tui/app.py +243 -43
- shotgun/tui/commands/__init__.py +1 -1
- shotgun/tui/components/context_indicator.py +179 -0
- shotgun/tui/components/mode_indicator.py +70 -0
- shotgun/tui/components/status_bar.py +48 -0
- shotgun/tui/containers.py +91 -0
- shotgun/tui/dependencies.py +39 -0
- shotgun/tui/protocols.py +45 -0
- shotgun/tui/screens/chat/__init__.py +5 -0
- shotgun/tui/screens/chat/chat.tcss +54 -0
- shotgun/tui/screens/chat/chat_screen.py +1202 -0
- shotgun/tui/screens/chat/codebase_index_prompt_screen.py +64 -0
- shotgun/tui/screens/chat/codebase_index_selection.py +12 -0
- shotgun/tui/screens/chat/help_text.py +40 -0
- shotgun/tui/screens/chat/prompt_history.py +48 -0
- shotgun/tui/screens/chat.tcss +11 -0
- shotgun/tui/screens/chat_screen/command_providers.py +78 -2
- shotgun/tui/screens/chat_screen/history/__init__.py +22 -0
- shotgun/tui/screens/chat_screen/history/agent_response.py +66 -0
- shotgun/tui/screens/chat_screen/history/chat_history.py +116 -0
- shotgun/tui/screens/chat_screen/history/formatters.py +115 -0
- shotgun/tui/screens/chat_screen/history/partial_response.py +43 -0
- shotgun/tui/screens/chat_screen/history/user_question.py +42 -0
- shotgun/tui/screens/confirmation_dialog.py +151 -0
- shotgun/tui/screens/feedback.py +4 -4
- shotgun/tui/screens/github_issue.py +102 -0
- shotgun/tui/screens/model_picker.py +49 -24
- shotgun/tui/screens/onboarding.py +431 -0
- shotgun/tui/screens/pipx_migration.py +153 -0
- shotgun/tui/screens/provider_config.py +50 -27
- shotgun/tui/screens/shotgun_auth.py +2 -2
- shotgun/tui/screens/welcome.py +32 -10
- shotgun/tui/services/__init__.py +5 -0
- shotgun/tui/services/conversation_service.py +184 -0
- shotgun/tui/state/__init__.py +7 -0
- shotgun/tui/state/processing_state.py +185 -0
- shotgun/tui/utils/mode_progress.py +14 -7
- shotgun/tui/widgets/__init__.py +5 -0
- shotgun/tui/widgets/widget_coordinator.py +262 -0
- shotgun/utils/datetime_utils.py +77 -0
- shotgun/utils/file_system_utils.py +22 -2
- shotgun/utils/marketing.py +110 -0
- shotgun/utils/update_checker.py +69 -14
- shotgun_sh-0.2.11.dev5.dist-info/METADATA +130 -0
- shotgun_sh-0.2.11.dev5.dist-info/RECORD +193 -0
- {shotgun_sh-0.2.3.dev2.dist-info → shotgun_sh-0.2.11.dev5.dist-info}/entry_points.txt +1 -0
- {shotgun_sh-0.2.3.dev2.dist-info → shotgun_sh-0.2.11.dev5.dist-info}/licenses/LICENSE +1 -1
- shotgun/agents/tools/user_interaction.py +0 -37
- shotgun/tui/screens/chat.py +0 -804
- shotgun/tui/screens/chat_screen/history.py +0 -352
- shotgun_sh-0.2.3.dev2.dist-info/METADATA +0 -467
- shotgun_sh-0.2.3.dev2.dist-info/RECORD +0 -154
- {shotgun_sh-0.2.3.dev2.dist-info → shotgun_sh-0.2.11.dev5.dist-info}/WHEEL +0 -0
shotgun/agents/common.py
CHANGED
|
@@ -1,14 +1,12 @@
|
|
|
1
1
|
"""Common utilities for agent creation and management."""
|
|
2
2
|
|
|
3
|
-
import asyncio
|
|
4
3
|
from collections.abc import Callable
|
|
5
4
|
from pathlib import Path
|
|
6
5
|
from typing import Any
|
|
7
6
|
|
|
7
|
+
import aiofiles
|
|
8
8
|
from pydantic_ai import (
|
|
9
9
|
Agent,
|
|
10
|
-
DeferredToolRequests,
|
|
11
|
-
DeferredToolResults,
|
|
12
10
|
RunContext,
|
|
13
11
|
UsageLimits,
|
|
14
12
|
)
|
|
@@ -19,20 +17,19 @@ from pydantic_ai.messages import (
|
|
|
19
17
|
)
|
|
20
18
|
|
|
21
19
|
from shotgun.agents.config import ProviderType, get_provider_model
|
|
22
|
-
from shotgun.agents.models import AgentType
|
|
20
|
+
from shotgun.agents.models import AgentResponse, AgentType
|
|
23
21
|
from shotgun.logging_config import get_logger
|
|
24
22
|
from shotgun.prompts import PromptLoader
|
|
25
23
|
from shotgun.sdk.services import get_codebase_service
|
|
26
24
|
from shotgun.utils import ensure_shotgun_directory_exists
|
|
25
|
+
from shotgun.utils.datetime_utils import get_datetime_context
|
|
27
26
|
from shotgun.utils.file_system_utils import get_shotgun_base_path
|
|
28
27
|
|
|
29
28
|
from .history import token_limit_compactor
|
|
30
|
-
from .history.compaction import apply_persistent_compaction
|
|
31
29
|
from .messages import AgentSystemPrompt, SystemStatusPrompt
|
|
32
30
|
from .models import AgentDeps, AgentRuntimeOptions, PipelineConfigEntry
|
|
33
31
|
from .tools import (
|
|
34
32
|
append_file,
|
|
35
|
-
ask_user,
|
|
36
33
|
codebase_shell,
|
|
37
34
|
directory_lister,
|
|
38
35
|
file_read,
|
|
@@ -72,7 +69,10 @@ async def add_system_status_message(
|
|
|
72
69
|
existing_files = get_agent_existing_files(deps.agent_mode)
|
|
73
70
|
|
|
74
71
|
# Extract table of contents from the agent's markdown file
|
|
75
|
-
markdown_toc = extract_markdown_toc(deps.agent_mode)
|
|
72
|
+
markdown_toc = await extract_markdown_toc(deps.agent_mode)
|
|
73
|
+
|
|
74
|
+
# Get current datetime with timezone information
|
|
75
|
+
dt_context = get_datetime_context()
|
|
76
76
|
|
|
77
77
|
system_state = prompt_loader.render(
|
|
78
78
|
"agents/state/system_state.j2",
|
|
@@ -80,6 +80,9 @@ async def add_system_status_message(
|
|
|
80
80
|
is_tui_context=deps.is_tui_context,
|
|
81
81
|
existing_files=existing_files,
|
|
82
82
|
markdown_toc=markdown_toc,
|
|
83
|
+
current_datetime=dt_context.datetime_formatted,
|
|
84
|
+
timezone_name=dt_context.timezone_name,
|
|
85
|
+
utc_offset=dt_context.utc_offset,
|
|
83
86
|
)
|
|
84
87
|
|
|
85
88
|
message_history.append(
|
|
@@ -92,14 +95,14 @@ async def add_system_status_message(
|
|
|
92
95
|
return message_history
|
|
93
96
|
|
|
94
97
|
|
|
95
|
-
def create_base_agent(
|
|
98
|
+
async def create_base_agent(
|
|
96
99
|
system_prompt_fn: Callable[[RunContext[AgentDeps]], str],
|
|
97
100
|
agent_runtime_options: AgentRuntimeOptions,
|
|
98
101
|
load_codebase_understanding_tools: bool = True,
|
|
99
102
|
additional_tools: list[Any] | None = None,
|
|
100
103
|
provider: ProviderType | None = None,
|
|
101
104
|
agent_mode: AgentType | None = None,
|
|
102
|
-
) -> tuple[Agent[AgentDeps,
|
|
105
|
+
) -> tuple[Agent[AgentDeps, AgentResponse], AgentDeps]:
|
|
103
106
|
"""Create a base agent with common configuration.
|
|
104
107
|
|
|
105
108
|
Args:
|
|
@@ -117,7 +120,7 @@ def create_base_agent(
|
|
|
117
120
|
|
|
118
121
|
# Get configured model or fall back to first available provider
|
|
119
122
|
try:
|
|
120
|
-
model_config = get_provider_model(provider)
|
|
123
|
+
model_config = await get_provider_model(provider)
|
|
121
124
|
provider_name = model_config.provider
|
|
122
125
|
logger.debug(
|
|
123
126
|
"🤖 Creating agent with configured %s model: %s",
|
|
@@ -157,7 +160,7 @@ def create_base_agent(
|
|
|
157
160
|
|
|
158
161
|
agent = Agent(
|
|
159
162
|
model,
|
|
160
|
-
output_type=
|
|
163
|
+
output_type=AgentResponse,
|
|
161
164
|
deps_type=AgentDeps,
|
|
162
165
|
instrument=True,
|
|
163
166
|
history_processors=[history_processor],
|
|
@@ -172,11 +175,6 @@ def create_base_agent(
|
|
|
172
175
|
for tool in additional_tools or []:
|
|
173
176
|
agent.tool_plain(tool)
|
|
174
177
|
|
|
175
|
-
# Register interactive tool conditionally based on deps
|
|
176
|
-
if deps.interactive_mode:
|
|
177
|
-
agent.tool(ask_user)
|
|
178
|
-
logger.debug("📞 Interactive mode enabled - ask_user tool registered")
|
|
179
|
-
|
|
180
178
|
# Register common file management tools (always available)
|
|
181
179
|
agent.tool(write_file)
|
|
182
180
|
agent.tool(append_file)
|
|
@@ -197,7 +195,7 @@ def create_base_agent(
|
|
|
197
195
|
return agent, deps
|
|
198
196
|
|
|
199
197
|
|
|
200
|
-
def _extract_file_toc_content(
|
|
198
|
+
async def _extract_file_toc_content(
|
|
201
199
|
file_path: Path, max_depth: int | None = None, max_chars: int = 500
|
|
202
200
|
) -> str | None:
|
|
203
201
|
"""Extract TOC from a single file with depth and character limits.
|
|
@@ -214,7 +212,8 @@ def _extract_file_toc_content(
|
|
|
214
212
|
return None
|
|
215
213
|
|
|
216
214
|
try:
|
|
217
|
-
|
|
215
|
+
async with aiofiles.open(file_path, encoding="utf-8") as f:
|
|
216
|
+
content = await f.read()
|
|
218
217
|
lines = content.split("\n")
|
|
219
218
|
|
|
220
219
|
# Extract headings
|
|
@@ -260,7 +259,7 @@ def _extract_file_toc_content(
|
|
|
260
259
|
return None
|
|
261
260
|
|
|
262
261
|
|
|
263
|
-
def extract_markdown_toc(agent_mode: AgentType | None) -> str | None:
|
|
262
|
+
async def extract_markdown_toc(agent_mode: AgentType | None) -> str | None:
|
|
264
263
|
"""Extract TOCs from current and prior agents' files in the pipeline.
|
|
265
264
|
|
|
266
265
|
Shows full TOC of agent's own file and high-level summaries of prior agents'
|
|
@@ -312,22 +311,30 @@ def extract_markdown_toc(agent_mode: AgentType | None) -> str | None:
|
|
|
312
311
|
for prior_file in config.prior_files:
|
|
313
312
|
file_path = base_path / prior_file
|
|
314
313
|
# Only show # and ## headings from prior files, max 500 chars each
|
|
315
|
-
prior_toc = _extract_file_toc_content(
|
|
314
|
+
prior_toc = await _extract_file_toc_content(
|
|
315
|
+
file_path, max_depth=2, max_chars=500
|
|
316
|
+
)
|
|
316
317
|
if prior_toc:
|
|
317
318
|
# Add section with XML tags
|
|
318
319
|
toc_sections.append(
|
|
319
|
-
f'<TABLE_OF_CONTENTS file_name="{prior_file}">\n
|
|
320
|
+
f'<TABLE_OF_CONTENTS file_name="{prior_file}">\n'
|
|
321
|
+
f"{prior_toc}\n"
|
|
322
|
+
f"</TABLE_OF_CONTENTS>"
|
|
320
323
|
)
|
|
321
324
|
|
|
322
325
|
# Extract TOC from own file (full detail)
|
|
323
326
|
if config.own_file:
|
|
324
327
|
own_path = base_path / config.own_file
|
|
325
|
-
own_toc = _extract_file_toc_content(
|
|
328
|
+
own_toc = await _extract_file_toc_content(
|
|
329
|
+
own_path, max_depth=None, max_chars=2000
|
|
330
|
+
)
|
|
326
331
|
if own_toc:
|
|
327
332
|
# Put own file TOC at the beginning with XML tags
|
|
328
333
|
toc_sections.insert(
|
|
329
334
|
0,
|
|
330
|
-
f'<TABLE_OF_CONTENTS file_name="{config.own_file}">\n
|
|
335
|
+
f'<TABLE_OF_CONTENTS file_name="{config.own_file}">\n'
|
|
336
|
+
f"{own_toc}\n"
|
|
337
|
+
f"</TABLE_OF_CONTENTS>",
|
|
331
338
|
)
|
|
332
339
|
|
|
333
340
|
# Combine all sections
|
|
@@ -383,23 +390,48 @@ def get_agent_existing_files(agent_mode: AgentType | None = None) -> list[str]:
|
|
|
383
390
|
relative_path = file_path.relative_to(base_path)
|
|
384
391
|
existing_files.append(str(relative_path))
|
|
385
392
|
else:
|
|
386
|
-
# For other agents, check
|
|
387
|
-
|
|
388
|
-
|
|
389
|
-
#
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
|
|
393
|
-
|
|
394
|
-
|
|
395
|
-
|
|
396
|
-
|
|
397
|
-
|
|
398
|
-
#
|
|
399
|
-
|
|
400
|
-
|
|
401
|
-
|
|
402
|
-
|
|
393
|
+
# For other agents, check files/directories they have access to
|
|
394
|
+
allowed_paths_raw = AGENT_DIRECTORIES[agent_mode]
|
|
395
|
+
|
|
396
|
+
# Convert single Path/string to list of Paths for uniform handling
|
|
397
|
+
if isinstance(allowed_paths_raw, str):
|
|
398
|
+
# Special case: "*" means export agent (shouldn't reach here but handle it)
|
|
399
|
+
allowed_paths = (
|
|
400
|
+
[Path(allowed_paths_raw)] if allowed_paths_raw != "*" else []
|
|
401
|
+
)
|
|
402
|
+
elif isinstance(allowed_paths_raw, Path):
|
|
403
|
+
allowed_paths = [allowed_paths_raw]
|
|
404
|
+
else:
|
|
405
|
+
# Already a list
|
|
406
|
+
allowed_paths = allowed_paths_raw
|
|
407
|
+
|
|
408
|
+
# Check each allowed path
|
|
409
|
+
for allowed_path in allowed_paths:
|
|
410
|
+
allowed_str = str(allowed_path)
|
|
411
|
+
|
|
412
|
+
# Check if it's a directory (no .md suffix)
|
|
413
|
+
if not allowed_path.suffix or not allowed_str.endswith(".md"):
|
|
414
|
+
# It's a directory - list all files within it
|
|
415
|
+
dir_path = base_path / allowed_str
|
|
416
|
+
if dir_path.exists() and dir_path.is_dir():
|
|
417
|
+
for file_path in dir_path.rglob("*"):
|
|
418
|
+
if file_path.is_file():
|
|
419
|
+
relative_path = file_path.relative_to(base_path)
|
|
420
|
+
existing_files.append(str(relative_path))
|
|
421
|
+
else:
|
|
422
|
+
# It's a file - check if it exists
|
|
423
|
+
file_path = base_path / allowed_str
|
|
424
|
+
if file_path.exists():
|
|
425
|
+
existing_files.append(allowed_str)
|
|
426
|
+
|
|
427
|
+
# Also check for associated directory (e.g., research/ for research.md)
|
|
428
|
+
base_name = allowed_str.replace(".md", "")
|
|
429
|
+
dir_path = base_path / base_name
|
|
430
|
+
if dir_path.exists() and dir_path.is_dir():
|
|
431
|
+
for file_path in dir_path.rglob("*"):
|
|
432
|
+
if file_path.is_file():
|
|
433
|
+
relative_path = file_path.relative_to(base_path)
|
|
434
|
+
existing_files.append(str(relative_path))
|
|
403
435
|
|
|
404
436
|
return existing_files
|
|
405
437
|
|
|
@@ -469,7 +501,8 @@ async def add_system_prompt_message(
|
|
|
469
501
|
message_history = message_history or []
|
|
470
502
|
|
|
471
503
|
# Create a minimal RunContext to call the system prompt function
|
|
472
|
-
# We'll pass None for model and usage since they're not used
|
|
504
|
+
# We'll pass None for model and usage since they're not used
|
|
505
|
+
# by our system prompt functions
|
|
473
506
|
context = type(
|
|
474
507
|
"RunContext", (), {"deps": deps, "retry": 0, "model": None, "usage": None}
|
|
475
508
|
)()
|
|
@@ -493,12 +526,12 @@ async def add_system_prompt_message(
|
|
|
493
526
|
|
|
494
527
|
|
|
495
528
|
async def run_agent(
|
|
496
|
-
agent: Agent[AgentDeps,
|
|
529
|
+
agent: Agent[AgentDeps, AgentResponse],
|
|
497
530
|
prompt: str,
|
|
498
531
|
deps: AgentDeps,
|
|
499
532
|
message_history: list[ModelMessage] | None = None,
|
|
500
533
|
usage_limits: UsageLimits | None = None,
|
|
501
|
-
) -> AgentRunResult[
|
|
534
|
+
) -> AgentRunResult[AgentResponse]:
|
|
502
535
|
# Clear file tracker for new run
|
|
503
536
|
deps.file_tracker.clear()
|
|
504
537
|
logger.debug("🔧 Cleared file tracker for new agent run")
|
|
@@ -513,33 +546,6 @@ async def run_agent(
|
|
|
513
546
|
message_history=message_history,
|
|
514
547
|
)
|
|
515
548
|
|
|
516
|
-
# Apply persistent compaction to prevent cascading token growth across CLI commands
|
|
517
|
-
messages = await apply_persistent_compaction(result.all_messages(), deps)
|
|
518
|
-
while isinstance(result.output, DeferredToolRequests):
|
|
519
|
-
logger.info("got deferred tool requests")
|
|
520
|
-
await deps.queue.join()
|
|
521
|
-
requests = result.output
|
|
522
|
-
done, _ = await asyncio.wait(deps.tasks)
|
|
523
|
-
|
|
524
|
-
task_results = [task.result() for task in done]
|
|
525
|
-
task_results_by_tool_call_id = {
|
|
526
|
-
result.tool_call_id: result.answer for result in task_results
|
|
527
|
-
}
|
|
528
|
-
logger.info("got task results", task_results_by_tool_call_id)
|
|
529
|
-
results = DeferredToolResults()
|
|
530
|
-
for call in requests.calls:
|
|
531
|
-
results.calls[call.tool_call_id] = task_results_by_tool_call_id[
|
|
532
|
-
call.tool_call_id
|
|
533
|
-
]
|
|
534
|
-
result = await agent.run(
|
|
535
|
-
deps=deps,
|
|
536
|
-
usage_limits=usage_limits,
|
|
537
|
-
message_history=messages,
|
|
538
|
-
deferred_tool_results=results,
|
|
539
|
-
)
|
|
540
|
-
# Apply persistent compaction to prevent cascading token growth in multi-turn loops
|
|
541
|
-
messages = await apply_persistent_compaction(result.all_messages(), deps)
|
|
542
|
-
|
|
543
549
|
# Log file operations summary if any files were modified
|
|
544
550
|
if deps.file_tracker.operations:
|
|
545
551
|
summary = deps.file_tracker.format_summary()
|
|
@@ -24,11 +24,5 @@ ANTHROPIC_PROVIDER = ConfigSection.ANTHROPIC.value
|
|
|
24
24
|
GOOGLE_PROVIDER = ConfigSection.GOOGLE.value
|
|
25
25
|
SHOTGUN_PROVIDER = ConfigSection.SHOTGUN.value
|
|
26
26
|
|
|
27
|
-
# Environment variable names
|
|
28
|
-
OPENAI_API_KEY_ENV = "OPENAI_API_KEY"
|
|
29
|
-
ANTHROPIC_API_KEY_ENV = "ANTHROPIC_API_KEY"
|
|
30
|
-
GEMINI_API_KEY_ENV = "GEMINI_API_KEY"
|
|
31
|
-
SHOTGUN_API_KEY_ENV = "SHOTGUN_API_KEY"
|
|
32
|
-
|
|
33
27
|
# Token limits
|
|
34
28
|
MEDIUM_TEXT_8K_TOKENS = 8192 # Default max_tokens for web search requests
|
shotgun/agents/config/manager.py
CHANGED
|
@@ -5,6 +5,8 @@ import uuid
|
|
|
5
5
|
from pathlib import Path
|
|
6
6
|
from typing import Any
|
|
7
7
|
|
|
8
|
+
import aiofiles
|
|
9
|
+
import aiofiles.os
|
|
8
10
|
from pydantic import SecretStr
|
|
9
11
|
|
|
10
12
|
from shotgun.logging_config import get_logger
|
|
@@ -48,7 +50,7 @@ class ConfigManager:
|
|
|
48
50
|
|
|
49
51
|
self._config: ShotgunConfig | None = None
|
|
50
52
|
|
|
51
|
-
def load(self, force_reload: bool = True) -> ShotgunConfig:
|
|
53
|
+
async def load(self, force_reload: bool = True) -> ShotgunConfig:
|
|
52
54
|
"""Load configuration from file.
|
|
53
55
|
|
|
54
56
|
Args:
|
|
@@ -60,18 +62,19 @@ class ConfigManager:
|
|
|
60
62
|
if self._config is not None and not force_reload:
|
|
61
63
|
return self._config
|
|
62
64
|
|
|
63
|
-
if not
|
|
65
|
+
if not await aiofiles.os.path.exists(self.config_path):
|
|
64
66
|
logger.info(
|
|
65
67
|
"Configuration file not found, creating new config at: %s",
|
|
66
68
|
self.config_path,
|
|
67
69
|
)
|
|
68
70
|
# Create new config with generated shotgun_instance_id
|
|
69
|
-
self._config = self.initialize()
|
|
71
|
+
self._config = await self.initialize()
|
|
70
72
|
return self._config
|
|
71
73
|
|
|
72
74
|
try:
|
|
73
|
-
with open(self.config_path, encoding="utf-8") as f:
|
|
74
|
-
|
|
75
|
+
async with aiofiles.open(self.config_path, encoding="utf-8") as f:
|
|
76
|
+
content = await f.read()
|
|
77
|
+
data = json.loads(content)
|
|
75
78
|
|
|
76
79
|
# Migration: Rename user_id to shotgun_instance_id (config v2 -> v3)
|
|
77
80
|
if "user_id" in data and SHOTGUN_INSTANCE_ID_FIELD not in data:
|
|
@@ -101,6 +104,12 @@ class ConfigManager:
|
|
|
101
104
|
"Existing BYOK user detected: set shown_welcome_screen=False to show welcome screen"
|
|
102
105
|
)
|
|
103
106
|
|
|
107
|
+
# Migration: Add marketing config for v3 -> v4
|
|
108
|
+
if "marketing" not in data:
|
|
109
|
+
data["marketing"] = {"messages": {}}
|
|
110
|
+
data["config_version"] = 4
|
|
111
|
+
logger.info("Migrated config v3->v4: added marketing configuration")
|
|
112
|
+
|
|
104
113
|
# Convert plain text secrets to SecretStr objects
|
|
105
114
|
self._convert_secrets_to_secretstr(data)
|
|
106
115
|
|
|
@@ -117,7 +126,7 @@ class ConfigManager:
|
|
|
117
126
|
|
|
118
127
|
if self._config.selected_model in MODEL_SPECS:
|
|
119
128
|
spec = MODEL_SPECS[self._config.selected_model]
|
|
120
|
-
if not self.has_provider_key(spec.provider):
|
|
129
|
+
if not await self.has_provider_key(spec.provider):
|
|
121
130
|
logger.info(
|
|
122
131
|
"Selected model %s provider has no API key, finding available model",
|
|
123
132
|
self._config.selected_model.value,
|
|
@@ -135,14 +144,14 @@ class ConfigManager:
|
|
|
135
144
|
# If no selected_model or it was invalid, find first available model
|
|
136
145
|
if not self._config.selected_model:
|
|
137
146
|
for provider in ProviderType:
|
|
138
|
-
if self.has_provider_key(provider):
|
|
147
|
+
if await self.has_provider_key(provider):
|
|
139
148
|
# Set to that provider's default model
|
|
140
149
|
from .models import MODEL_SPECS, ModelName
|
|
141
150
|
|
|
142
151
|
# Find default model for this provider
|
|
143
152
|
provider_models = {
|
|
144
153
|
ProviderType.OPENAI: ModelName.GPT_5,
|
|
145
|
-
ProviderType.ANTHROPIC: ModelName.
|
|
154
|
+
ProviderType.ANTHROPIC: ModelName.CLAUDE_HAIKU_4_5,
|
|
146
155
|
ProviderType.GOOGLE: ModelName.GEMINI_2_5_PRO,
|
|
147
156
|
}
|
|
148
157
|
|
|
@@ -156,7 +165,7 @@ class ConfigManager:
|
|
|
156
165
|
break
|
|
157
166
|
|
|
158
167
|
if should_save:
|
|
159
|
-
self.save(self._config)
|
|
168
|
+
await self.save(self._config)
|
|
160
169
|
|
|
161
170
|
return self._config
|
|
162
171
|
|
|
@@ -165,10 +174,10 @@ class ConfigManager:
|
|
|
165
174
|
"Failed to load configuration from %s: %s", self.config_path, e
|
|
166
175
|
)
|
|
167
176
|
logger.info("Creating new configuration with generated shotgun_instance_id")
|
|
168
|
-
self._config = self.initialize()
|
|
177
|
+
self._config = await self.initialize()
|
|
169
178
|
return self._config
|
|
170
179
|
|
|
171
|
-
def save(self, config: ShotgunConfig | None = None) -> None:
|
|
180
|
+
async def save(self, config: ShotgunConfig | None = None) -> None:
|
|
172
181
|
"""Save configuration to file.
|
|
173
182
|
|
|
174
183
|
Args:
|
|
@@ -184,15 +193,17 @@ class ConfigManager:
|
|
|
184
193
|
)
|
|
185
194
|
|
|
186
195
|
# Ensure directory exists
|
|
187
|
-
self.config_path.parent
|
|
196
|
+
await aiofiles.os.makedirs(self.config_path.parent, exist_ok=True)
|
|
188
197
|
|
|
189
198
|
try:
|
|
190
199
|
# Convert SecretStr to plain text for JSON serialization
|
|
191
200
|
data = config.model_dump()
|
|
192
201
|
self._convert_secretstr_to_plain(data)
|
|
202
|
+
self._convert_datetime_to_isoformat(data)
|
|
193
203
|
|
|
194
|
-
|
|
195
|
-
|
|
204
|
+
json_content = json.dumps(data, indent=2, ensure_ascii=False)
|
|
205
|
+
async with aiofiles.open(self.config_path, "w", encoding="utf-8") as f:
|
|
206
|
+
await f.write(json_content)
|
|
196
207
|
|
|
197
208
|
logger.debug("Configuration saved to %s", self.config_path)
|
|
198
209
|
self._config = config
|
|
@@ -201,14 +212,16 @@ class ConfigManager:
|
|
|
201
212
|
logger.error("Failed to save configuration to %s: %s", self.config_path, e)
|
|
202
213
|
raise
|
|
203
214
|
|
|
204
|
-
def update_provider(
|
|
215
|
+
async def update_provider(
|
|
216
|
+
self, provider: ProviderType | str, **kwargs: Any
|
|
217
|
+
) -> None:
|
|
205
218
|
"""Update provider configuration.
|
|
206
219
|
|
|
207
220
|
Args:
|
|
208
221
|
provider: Provider to update
|
|
209
222
|
**kwargs: Configuration fields to update (only api_key supported)
|
|
210
223
|
"""
|
|
211
|
-
config = self.load()
|
|
224
|
+
config = await self.load()
|
|
212
225
|
|
|
213
226
|
# Get provider config and check if it's shotgun
|
|
214
227
|
provider_config, is_shotgun = self._get_provider_config_and_type(
|
|
@@ -243,50 +256,61 @@ class ConfigManager:
|
|
|
243
256
|
|
|
244
257
|
provider_models = {
|
|
245
258
|
ProviderType.OPENAI: ModelName.GPT_5,
|
|
246
|
-
ProviderType.ANTHROPIC: ModelName.
|
|
259
|
+
ProviderType.ANTHROPIC: ModelName.CLAUDE_HAIKU_4_5,
|
|
247
260
|
ProviderType.GOOGLE: ModelName.GEMINI_2_5_PRO,
|
|
248
261
|
}
|
|
249
262
|
if provider_enum in provider_models:
|
|
250
263
|
config.selected_model = provider_models[provider_enum]
|
|
251
264
|
|
|
252
|
-
|
|
265
|
+
# Mark welcome screen as shown when BYOK provider is configured
|
|
266
|
+
# This prevents the welcome screen from showing again after user has made their choice
|
|
267
|
+
config.shown_welcome_screen = True
|
|
268
|
+
|
|
269
|
+
await self.save(config)
|
|
253
270
|
|
|
254
|
-
def clear_provider_key(self, provider: ProviderType | str) -> None:
|
|
271
|
+
async def clear_provider_key(self, provider: ProviderType | str) -> None:
|
|
255
272
|
"""Remove the API key for the given provider (LLM provider or shotgun)."""
|
|
256
|
-
config = self.load()
|
|
273
|
+
config = await self.load()
|
|
257
274
|
|
|
258
275
|
# Get provider config (shotgun or LLM provider)
|
|
259
|
-
provider_config,
|
|
276
|
+
provider_config, is_shotgun = self._get_provider_config_and_type(
|
|
277
|
+
config, provider
|
|
278
|
+
)
|
|
260
279
|
|
|
261
280
|
provider_config.api_key = None
|
|
262
|
-
self.save(config)
|
|
263
281
|
|
|
264
|
-
|
|
282
|
+
# For Shotgun Account, also clear the JWT
|
|
283
|
+
if is_shotgun and isinstance(provider_config, ShotgunAccountConfig):
|
|
284
|
+
provider_config.supabase_jwt = None
|
|
285
|
+
|
|
286
|
+
await self.save(config)
|
|
287
|
+
|
|
288
|
+
async def update_selected_model(self, model_name: "ModelName") -> None:
|
|
265
289
|
"""Update the selected model.
|
|
266
290
|
|
|
267
291
|
Args:
|
|
268
292
|
model_name: Model to select
|
|
269
293
|
"""
|
|
270
|
-
config = self.load()
|
|
294
|
+
config = await self.load()
|
|
271
295
|
config.selected_model = model_name
|
|
272
|
-
self.save(config)
|
|
296
|
+
await self.save(config)
|
|
273
297
|
|
|
274
|
-
def has_provider_key(self, provider: ProviderType | str) -> bool:
|
|
298
|
+
async def has_provider_key(self, provider: ProviderType | str) -> bool:
|
|
275
299
|
"""Check if the given provider has a non-empty API key configured.
|
|
276
300
|
|
|
277
301
|
This checks only the configuration file.
|
|
278
302
|
"""
|
|
279
303
|
# Use force_reload=False to avoid infinite loop when called from load()
|
|
280
|
-
config = self.load(force_reload=False)
|
|
304
|
+
config = await self.load(force_reload=False)
|
|
281
305
|
provider_enum = self._ensure_provider_enum(provider)
|
|
282
306
|
provider_config = self._get_provider_config(config, provider_enum)
|
|
283
307
|
|
|
284
308
|
return self._provider_has_api_key(provider_config)
|
|
285
309
|
|
|
286
|
-
def has_any_provider_key(self) -> bool:
|
|
310
|
+
async def has_any_provider_key(self) -> bool:
|
|
287
311
|
"""Determine whether any provider has a configured API key."""
|
|
288
312
|
# Use force_reload=False to avoid infinite loop when called from load()
|
|
289
|
-
config = self.load(force_reload=False)
|
|
313
|
+
config = await self.load(force_reload=False)
|
|
290
314
|
# Check LLM provider keys (BYOK)
|
|
291
315
|
has_llm_key = any(
|
|
292
316
|
self._provider_has_api_key(self._get_provider_config(config, provider))
|
|
@@ -300,7 +324,7 @@ class ConfigManager:
|
|
|
300
324
|
has_shotgun_key = self._provider_has_api_key(config.shotgun)
|
|
301
325
|
return has_llm_key or has_shotgun_key
|
|
302
326
|
|
|
303
|
-
def initialize(self) -> ShotgunConfig:
|
|
327
|
+
async def initialize(self) -> ShotgunConfig:
|
|
304
328
|
"""Initialize configuration with defaults and save to file.
|
|
305
329
|
|
|
306
330
|
Returns:
|
|
@@ -310,7 +334,7 @@ class ConfigManager:
|
|
|
310
334
|
config = ShotgunConfig(
|
|
311
335
|
shotgun_instance_id=str(uuid.uuid4()),
|
|
312
336
|
)
|
|
313
|
-
self.save(config)
|
|
337
|
+
await self.save(config)
|
|
314
338
|
logger.info(
|
|
315
339
|
"Configuration initialized at %s with shotgun_instance_id: %s",
|
|
316
340
|
self.config_path,
|
|
@@ -366,6 +390,24 @@ class ConfigManager:
|
|
|
366
390
|
SUPABASE_JWT_FIELD
|
|
367
391
|
].get_secret_value()
|
|
368
392
|
|
|
393
|
+
def _convert_datetime_to_isoformat(self, data: dict[str, Any]) -> None:
|
|
394
|
+
"""Convert datetime objects in data to ISO8601 format strings for JSON serialization."""
|
|
395
|
+
from datetime import datetime
|
|
396
|
+
|
|
397
|
+
def convert_dict(d: dict[str, Any]) -> None:
|
|
398
|
+
"""Recursively convert datetime objects in a dict."""
|
|
399
|
+
for key, value in d.items():
|
|
400
|
+
if isinstance(value, datetime):
|
|
401
|
+
d[key] = value.isoformat()
|
|
402
|
+
elif isinstance(value, dict):
|
|
403
|
+
convert_dict(value)
|
|
404
|
+
elif isinstance(value, list):
|
|
405
|
+
for item in value:
|
|
406
|
+
if isinstance(item, dict):
|
|
407
|
+
convert_dict(item)
|
|
408
|
+
|
|
409
|
+
convert_dict(data)
|
|
410
|
+
|
|
369
411
|
def _ensure_provider_enum(self, provider: ProviderType | str) -> ProviderType:
|
|
370
412
|
"""Normalize provider values to ProviderType enum."""
|
|
371
413
|
return (
|
|
@@ -429,16 +471,16 @@ class ConfigManager:
|
|
|
429
471
|
provider_enum = self._ensure_provider_enum(provider)
|
|
430
472
|
return (self._get_provider_config(config, provider_enum), False)
|
|
431
473
|
|
|
432
|
-
def get_shotgun_instance_id(self) -> str:
|
|
474
|
+
async def get_shotgun_instance_id(self) -> str:
|
|
433
475
|
"""Get the shotgun instance ID from configuration.
|
|
434
476
|
|
|
435
477
|
Returns:
|
|
436
478
|
The unique shotgun instance ID string
|
|
437
479
|
"""
|
|
438
|
-
config = self.load()
|
|
480
|
+
config = await self.load()
|
|
439
481
|
return config.shotgun_instance_id
|
|
440
482
|
|
|
441
|
-
def update_shotgun_account(
|
|
483
|
+
async def update_shotgun_account(
|
|
442
484
|
self, api_key: str | None = None, supabase_jwt: str | None = None
|
|
443
485
|
) -> None:
|
|
444
486
|
"""Update Shotgun Account configuration.
|
|
@@ -447,7 +489,7 @@ class ConfigManager:
|
|
|
447
489
|
api_key: LiteLLM proxy API key (optional)
|
|
448
490
|
supabase_jwt: Supabase authentication JWT (optional)
|
|
449
491
|
"""
|
|
450
|
-
config = self.load()
|
|
492
|
+
config = await self.load()
|
|
451
493
|
|
|
452
494
|
if api_key is not None:
|
|
453
495
|
config.shotgun.api_key = SecretStr(api_key) if api_key else None
|
|
@@ -457,7 +499,7 @@ class ConfigManager:
|
|
|
457
499
|
SecretStr(supabase_jwt) if supabase_jwt else None
|
|
458
500
|
)
|
|
459
501
|
|
|
460
|
-
self.save(config)
|
|
502
|
+
await self.save(config)
|
|
461
503
|
logger.info("Updated Shotgun Account configuration")
|
|
462
504
|
|
|
463
505
|
|