shotgun-sh 0.2.11.dev1__py3-none-any.whl → 0.2.11.dev3__py3-none-any.whl
This diff compares the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their respective public registries.
Potentially problematic release.
This version of shotgun-sh might be problematic. Click here for more details.
- shotgun/agents/agent_manager.py +150 -27
- shotgun/agents/common.py +14 -8
- shotgun/agents/config/manager.py +64 -33
- shotgun/agents/config/models.py +21 -1
- shotgun/agents/config/provider.py +2 -2
- shotgun/agents/context_analyzer/analyzer.py +2 -24
- shotgun/agents/conversation_manager.py +22 -13
- shotgun/agents/export.py +2 -2
- shotgun/agents/history/token_counting/anthropic.py +17 -1
- shotgun/agents/history/token_counting/base.py +14 -3
- shotgun/agents/history/token_counting/openai.py +8 -0
- shotgun/agents/history/token_counting/sentencepiece_counter.py +8 -0
- shotgun/agents/history/token_counting/tokenizer_cache.py +3 -1
- shotgun/agents/history/token_counting/utils.py +0 -3
- shotgun/agents/plan.py +2 -2
- shotgun/agents/research.py +3 -3
- shotgun/agents/specify.py +2 -2
- shotgun/agents/tasks.py +2 -2
- shotgun/agents/tools/codebase/file_read.py +5 -2
- shotgun/agents/tools/file_management.py +11 -7
- shotgun/agents/tools/web_search/__init__.py +8 -8
- shotgun/agents/tools/web_search/anthropic.py +2 -2
- shotgun/agents/tools/web_search/gemini.py +1 -1
- shotgun/agents/tools/web_search/openai.py +1 -1
- shotgun/agents/tools/web_search/utils.py +2 -2
- shotgun/agents/usage_manager.py +16 -11
- shotgun/cli/clear.py +2 -1
- shotgun/cli/compact.py +3 -3
- shotgun/cli/config.py +8 -5
- shotgun/cli/context.py +2 -2
- shotgun/cli/export.py +1 -1
- shotgun/cli/feedback.py +4 -2
- shotgun/cli/plan.py +1 -1
- shotgun/cli/research.py +1 -1
- shotgun/cli/specify.py +1 -1
- shotgun/cli/tasks.py +1 -1
- shotgun/codebase/core/change_detector.py +5 -3
- shotgun/codebase/core/code_retrieval.py +4 -2
- shotgun/codebase/core/ingestor.py +10 -8
- shotgun/codebase/core/manager.py +3 -3
- shotgun/codebase/core/nl_query.py +1 -1
- shotgun/logging_config.py +10 -17
- shotgun/main.py +3 -1
- shotgun/posthog_telemetry.py +14 -4
- shotgun/sentry_telemetry.py +3 -1
- shotgun/telemetry.py +3 -1
- shotgun/tui/app.py +62 -51
- shotgun/tui/components/context_indicator.py +43 -0
- shotgun/tui/containers.py +15 -17
- shotgun/tui/dependencies.py +2 -2
- shotgun/tui/screens/chat/chat_screen.py +75 -15
- shotgun/tui/screens/chat/help_text.py +16 -15
- shotgun/tui/screens/feedback.py +4 -4
- shotgun/tui/screens/model_picker.py +21 -20
- shotgun/tui/screens/provider_config.py +50 -27
- shotgun/tui/screens/shotgun_auth.py +2 -2
- shotgun/tui/screens/welcome.py +14 -11
- shotgun/tui/services/conversation_service.py +8 -8
- shotgun/tui/utils/mode_progress.py +14 -7
- shotgun/tui/widgets/widget_coordinator.py +15 -0
- shotgun/utils/file_system_utils.py +19 -0
- shotgun/utils/marketing.py +110 -0
- {shotgun_sh-0.2.11.dev1.dist-info → shotgun_sh-0.2.11.dev3.dist-info}/METADATA +2 -1
- {shotgun_sh-0.2.11.dev1.dist-info → shotgun_sh-0.2.11.dev3.dist-info}/RECORD +67 -66
- {shotgun_sh-0.2.11.dev1.dist-info → shotgun_sh-0.2.11.dev3.dist-info}/WHEEL +0 -0
- {shotgun_sh-0.2.11.dev1.dist-info → shotgun_sh-0.2.11.dev3.dist-info}/entry_points.txt +0 -0
- {shotgun_sh-0.2.11.dev1.dist-info → shotgun_sh-0.2.11.dev3.dist-info}/licenses/LICENSE +0 -0
|
@@ -67,26 +67,13 @@ class ContextAnalyzer:
|
|
|
67
67
|
for msg in reversed(message_history):
|
|
68
68
|
if isinstance(msg, ModelResponse) and msg.usage:
|
|
69
69
|
last_input_tokens = msg.usage.input_tokens + msg.usage.cache_read_tokens
|
|
70
|
-
logger.debug(
|
|
71
|
-
f"[ANALYZER] Found last response with usage - "
|
|
72
|
-
f"input_tokens={msg.usage.input_tokens}, "
|
|
73
|
-
f"cache_read_tokens={msg.usage.cache_read_tokens}, "
|
|
74
|
-
f"total={last_input_tokens}"
|
|
75
|
-
)
|
|
76
70
|
break
|
|
77
71
|
|
|
78
72
|
if last_input_tokens == 0:
|
|
79
|
-
|
|
80
|
-
f"[ANALYZER] No usage data found in message history! "
|
|
81
|
-
f"message_count={len(message_history)}, "
|
|
82
|
-
f"response_count={sum(1 for m in message_history if isinstance(m, ModelResponse))}"
|
|
83
|
-
)
|
|
84
|
-
# Fallback to token estimation
|
|
85
|
-
logger.info("[ANALYZER] Falling back to token estimation")
|
|
73
|
+
# Fallback to token estimation (no logging to reduce verbosity)
|
|
86
74
|
last_input_tokens = await estimate_tokens_from_messages(
|
|
87
75
|
message_history, self.model_config
|
|
88
76
|
)
|
|
89
|
-
logger.debug(f"[ANALYZER] Estimated tokens: {last_input_tokens}")
|
|
90
77
|
|
|
91
78
|
# Step 2: Calculate total output tokens (sum across all responses)
|
|
92
79
|
for msg in message_history:
|
|
@@ -247,16 +234,7 @@ class ContextAnalyzer:
|
|
|
247
234
|
# If no content, put all in agent responses
|
|
248
235
|
agent_response_tokens = total_output_tokens
|
|
249
236
|
|
|
250
|
-
|
|
251
|
-
f"Token allocation complete: user={user_tokens}, agent_responses={agent_response_tokens}, "
|
|
252
|
-
f"system_prompts={system_prompt_tokens}, system_status={system_status_tokens}, "
|
|
253
|
-
f"codebase_understanding={codebase_understanding_tokens}, "
|
|
254
|
-
f"artifact_management={artifact_management_tokens}, web_research={web_research_tokens}, "
|
|
255
|
-
f"unknown={unknown_tokens}"
|
|
256
|
-
)
|
|
257
|
-
logger.debug(
|
|
258
|
-
f"Input tokens (from last response): {last_input_tokens}, Output tokens (sum): {total_output_tokens}"
|
|
259
|
-
)
|
|
237
|
+
# Token allocation complete (no logging to reduce verbosity)
|
|
260
238
|
|
|
261
239
|
# Create TokenAllocation model
|
|
262
240
|
return TokenAllocation(
|
|
@@ -1,11 +1,15 @@
|
|
|
1
1
|
"""Manager for handling conversation persistence operations."""
|
|
2
2
|
|
|
3
|
+
import asyncio
|
|
3
4
|
import json
|
|
4
|
-
import shutil
|
|
5
5
|
from pathlib import Path
|
|
6
6
|
|
|
7
|
+
import aiofiles
|
|
8
|
+
import aiofiles.os
|
|
9
|
+
|
|
7
10
|
from shotgun.logging_config import get_logger
|
|
8
11
|
from shotgun.utils import get_shotgun_home
|
|
12
|
+
from shotgun.utils.file_system_utils import async_copy_file
|
|
9
13
|
|
|
10
14
|
from .conversation_history import ConversationHistory
|
|
11
15
|
|
|
@@ -27,14 +31,14 @@ class ConversationManager:
|
|
|
27
31
|
else:
|
|
28
32
|
self.conversation_path = conversation_path
|
|
29
33
|
|
|
30
|
-
def save(self, conversation: ConversationHistory) -> None:
|
|
34
|
+
async def save(self, conversation: ConversationHistory) -> None:
|
|
31
35
|
"""Save conversation history to file.
|
|
32
36
|
|
|
33
37
|
Args:
|
|
34
38
|
conversation: ConversationHistory to save
|
|
35
39
|
"""
|
|
36
40
|
# Ensure directory exists
|
|
37
|
-
self.conversation_path.parent
|
|
41
|
+
await aiofiles.os.makedirs(self.conversation_path.parent, exist_ok=True)
|
|
38
42
|
|
|
39
43
|
try:
|
|
40
44
|
# Update timestamp
|
|
@@ -44,9 +48,12 @@ class ConversationManager:
|
|
|
44
48
|
|
|
45
49
|
# Serialize to JSON using Pydantic's model_dump
|
|
46
50
|
data = conversation.model_dump(mode="json")
|
|
51
|
+
json_content = json.dumps(data, indent=2, ensure_ascii=False)
|
|
47
52
|
|
|
48
|
-
with open(
|
|
49
|
-
|
|
53
|
+
async with aiofiles.open(
|
|
54
|
+
self.conversation_path, "w", encoding="utf-8"
|
|
55
|
+
) as f:
|
|
56
|
+
await f.write(json_content)
|
|
50
57
|
|
|
51
58
|
logger.debug("Conversation saved to %s", self.conversation_path)
|
|
52
59
|
|
|
@@ -56,19 +63,20 @@ class ConversationManager:
|
|
|
56
63
|
)
|
|
57
64
|
# Don't raise - we don't want to interrupt the user's session
|
|
58
65
|
|
|
59
|
-
def load(self) -> ConversationHistory | None:
|
|
66
|
+
async def load(self) -> ConversationHistory | None:
|
|
60
67
|
"""Load conversation history from file.
|
|
61
68
|
|
|
62
69
|
Returns:
|
|
63
70
|
ConversationHistory if file exists and is valid, None otherwise
|
|
64
71
|
"""
|
|
65
|
-
if not
|
|
72
|
+
if not await aiofiles.os.path.exists(self.conversation_path):
|
|
66
73
|
logger.debug("No conversation history found at %s", self.conversation_path)
|
|
67
74
|
return None
|
|
68
75
|
|
|
69
76
|
try:
|
|
70
|
-
with open(self.conversation_path, encoding="utf-8") as f:
|
|
71
|
-
|
|
77
|
+
async with aiofiles.open(self.conversation_path, encoding="utf-8") as f:
|
|
78
|
+
content = await f.read()
|
|
79
|
+
data = json.loads(content)
|
|
72
80
|
|
|
73
81
|
conversation = ConversationHistory.model_validate(data)
|
|
74
82
|
logger.debug(
|
|
@@ -89,7 +97,7 @@ class ConversationManager:
|
|
|
89
97
|
# Create a backup of the corrupted file for debugging
|
|
90
98
|
backup_path = self.conversation_path.with_suffix(".json.backup")
|
|
91
99
|
try:
|
|
92
|
-
|
|
100
|
+
await async_copy_file(self.conversation_path, backup_path)
|
|
93
101
|
logger.info("Backed up corrupted conversation to %s", backup_path)
|
|
94
102
|
except Exception as backup_error: # pragma: no cover
|
|
95
103
|
logger.warning("Failed to backup corrupted file: %s", backup_error)
|
|
@@ -105,11 +113,12 @@ class ConversationManager:
|
|
|
105
113
|
)
|
|
106
114
|
return None
|
|
107
115
|
|
|
108
|
-
def clear(self) -> None:
|
|
116
|
+
async def clear(self) -> None:
|
|
109
117
|
"""Delete the conversation history file."""
|
|
110
|
-
if
|
|
118
|
+
if await aiofiles.os.path.exists(self.conversation_path):
|
|
111
119
|
try:
|
|
112
|
-
|
|
120
|
+
# Use asyncio.to_thread for unlink operation
|
|
121
|
+
await asyncio.to_thread(self.conversation_path.unlink)
|
|
113
122
|
logger.debug(
|
|
114
123
|
"Conversation history cleared at %s", self.conversation_path
|
|
115
124
|
)
|
shotgun/agents/export.py
CHANGED
|
@@ -23,7 +23,7 @@ from .models import AgentDeps, AgentResponse, AgentRuntimeOptions, AgentType
|
|
|
23
23
|
logger = get_logger(__name__)
|
|
24
24
|
|
|
25
25
|
|
|
26
|
-
def create_export_agent(
|
|
26
|
+
async def create_export_agent(
|
|
27
27
|
agent_runtime_options: AgentRuntimeOptions, provider: ProviderType | None = None
|
|
28
28
|
) -> tuple[Agent[AgentDeps, AgentResponse], AgentDeps]:
|
|
29
29
|
"""Create an export agent with file management capabilities.
|
|
@@ -39,7 +39,7 @@ def create_export_agent(
|
|
|
39
39
|
# Use partial to create system prompt function for export agent
|
|
40
40
|
system_prompt_fn = partial(build_agent_system_prompt, "export")
|
|
41
41
|
|
|
42
|
-
agent, deps = create_base_agent(
|
|
42
|
+
agent, deps = await create_base_agent(
|
|
43
43
|
system_prompt_fn,
|
|
44
44
|
agent_runtime_options,
|
|
45
45
|
provider=provider,
|
|
@@ -72,11 +72,23 @@ class AnthropicTokenCounter(TokenCounter):
|
|
|
72
72
|
Raises:
|
|
73
73
|
RuntimeError: If API call fails
|
|
74
74
|
"""
|
|
75
|
+
# Handle empty text to avoid unnecessary API calls
|
|
76
|
+
# Anthropic API requires non-empty content, so we need a strict check
|
|
77
|
+
if not text or not text.strip():
|
|
78
|
+
return 0
|
|
79
|
+
|
|
80
|
+
# Additional validation: ensure the text has actual content
|
|
81
|
+
# Some edge cases might have only whitespace or control characters
|
|
82
|
+
cleaned_text = text.strip()
|
|
83
|
+
if not cleaned_text:
|
|
84
|
+
return 0
|
|
85
|
+
|
|
75
86
|
try:
|
|
76
87
|
# Anthropic API expects messages format and model parameter
|
|
77
88
|
# Use await with async client
|
|
78
89
|
result = await self.client.messages.count_tokens(
|
|
79
|
-
messages=[{"role": "user", "content":
|
|
90
|
+
messages=[{"role": "user", "content": cleaned_text}],
|
|
91
|
+
model=self.model_name,
|
|
80
92
|
)
|
|
81
93
|
return result.input_tokens
|
|
82
94
|
except Exception as e:
|
|
@@ -107,5 +119,9 @@ class AnthropicTokenCounter(TokenCounter):
|
|
|
107
119
|
Raises:
|
|
108
120
|
RuntimeError: If token counting fails
|
|
109
121
|
"""
|
|
122
|
+
# Handle empty message list early
|
|
123
|
+
if not messages:
|
|
124
|
+
return 0
|
|
125
|
+
|
|
110
126
|
total_text = extract_text_from_messages(messages)
|
|
111
127
|
return await self.count_tokens(total_text)
|
|
@@ -56,12 +56,23 @@ def extract_text_from_messages(messages: list[ModelMessage]) -> str:
|
|
|
56
56
|
if hasattr(message, "parts"):
|
|
57
57
|
for part in message.parts:
|
|
58
58
|
if hasattr(part, "content") and isinstance(part.content, str):
|
|
59
|
-
|
|
59
|
+
# Only add non-empty content
|
|
60
|
+
if part.content.strip():
|
|
61
|
+
text_parts.append(part.content)
|
|
60
62
|
else:
|
|
61
63
|
# Handle non-text parts (tool calls, etc.)
|
|
62
|
-
|
|
64
|
+
part_str = str(part)
|
|
65
|
+
if part_str.strip():
|
|
66
|
+
text_parts.append(part_str)
|
|
63
67
|
else:
|
|
64
68
|
# Handle messages without parts
|
|
65
|
-
|
|
69
|
+
msg_str = str(message)
|
|
70
|
+
if msg_str.strip():
|
|
71
|
+
text_parts.append(msg_str)
|
|
72
|
+
|
|
73
|
+
# If no valid text parts found, return a minimal placeholder
|
|
74
|
+
# This ensures we never send completely empty content to APIs
|
|
75
|
+
if not text_parts:
|
|
76
|
+
return "."
|
|
66
77
|
|
|
67
78
|
return "\n".join(text_parts)
|
|
@@ -57,6 +57,10 @@ class OpenAITokenCounter(TokenCounter):
|
|
|
57
57
|
Raises:
|
|
58
58
|
RuntimeError: If token counting fails
|
|
59
59
|
"""
|
|
60
|
+
# Handle empty text to avoid unnecessary encoding
|
|
61
|
+
if not text or not text.strip():
|
|
62
|
+
return 0
|
|
63
|
+
|
|
60
64
|
try:
|
|
61
65
|
return len(self.encoding.encode(text))
|
|
62
66
|
except Exception as e:
|
|
@@ -76,5 +80,9 @@ class OpenAITokenCounter(TokenCounter):
|
|
|
76
80
|
Raises:
|
|
77
81
|
RuntimeError: If token counting fails
|
|
78
82
|
"""
|
|
83
|
+
# Handle empty message list early
|
|
84
|
+
if not messages:
|
|
85
|
+
return 0
|
|
86
|
+
|
|
79
87
|
total_text = extract_text_from_messages(messages)
|
|
80
88
|
return await self.count_tokens(total_text)
|
|
@@ -88,6 +88,10 @@ class SentencePieceTokenCounter(TokenCounter):
|
|
|
88
88
|
Raises:
|
|
89
89
|
RuntimeError: If token counting fails
|
|
90
90
|
"""
|
|
91
|
+
# Handle empty text to avoid unnecessary tokenization
|
|
92
|
+
if not text or not text.strip():
|
|
93
|
+
return 0
|
|
94
|
+
|
|
91
95
|
await self._ensure_tokenizer()
|
|
92
96
|
|
|
93
97
|
if self.sp is None:
|
|
@@ -115,5 +119,9 @@ class SentencePieceTokenCounter(TokenCounter):
|
|
|
115
119
|
Raises:
|
|
116
120
|
RuntimeError: If token counting fails
|
|
117
121
|
"""
|
|
122
|
+
# Handle empty message list early
|
|
123
|
+
if not messages:
|
|
124
|
+
return 0
|
|
125
|
+
|
|
118
126
|
total_text = extract_text_from_messages(messages)
|
|
119
127
|
return await self.count_tokens(total_text)
|
|
@@ -3,6 +3,7 @@
|
|
|
3
3
|
import hashlib
|
|
4
4
|
from pathlib import Path
|
|
5
5
|
|
|
6
|
+
import aiofiles
|
|
6
7
|
import httpx
|
|
7
8
|
|
|
8
9
|
from shotgun.logging_config import get_logger
|
|
@@ -78,7 +79,8 @@ async def download_gemini_tokenizer() -> Path:
|
|
|
78
79
|
|
|
79
80
|
# Atomic write: write to temp file first, then rename
|
|
80
81
|
temp_path = cache_path.with_suffix(".tmp")
|
|
81
|
-
|
|
82
|
+
async with aiofiles.open(temp_path, "wb") as f:
|
|
83
|
+
await f.write(content)
|
|
82
84
|
temp_path.rename(cache_path)
|
|
83
85
|
|
|
84
86
|
logger.info(f"Gemini tokenizer downloaded and cached at {cache_path}")
|
|
@@ -44,9 +44,6 @@ def get_token_counter(model_config: ModelConfig) -> TokenCounter:
|
|
|
44
44
|
|
|
45
45
|
# Return cached instance if available
|
|
46
46
|
if cache_key in _token_counter_cache:
|
|
47
|
-
logger.debug(
|
|
48
|
-
f"Reusing cached token counter for {model_config.provider.value}:{model_config.name}"
|
|
49
|
-
)
|
|
50
47
|
return _token_counter_cache[cache_key]
|
|
51
48
|
|
|
52
49
|
# Create new instance and cache it
|
shotgun/agents/plan.py
CHANGED
|
@@ -23,7 +23,7 @@ from .models import AgentDeps, AgentResponse, AgentRuntimeOptions, AgentType
|
|
|
23
23
|
logger = get_logger(__name__)
|
|
24
24
|
|
|
25
25
|
|
|
26
|
-
def create_plan_agent(
|
|
26
|
+
async def create_plan_agent(
|
|
27
27
|
agent_runtime_options: AgentRuntimeOptions, provider: ProviderType | None = None
|
|
28
28
|
) -> tuple[Agent[AgentDeps, AgentResponse], AgentDeps]:
|
|
29
29
|
"""Create a plan agent with artifact management capabilities.
|
|
@@ -39,7 +39,7 @@ def create_plan_agent(
|
|
|
39
39
|
# Use partial to create system prompt function for plan agent
|
|
40
40
|
system_prompt_fn = partial(build_agent_system_prompt, "plan")
|
|
41
41
|
|
|
42
|
-
agent, deps = create_base_agent(
|
|
42
|
+
agent, deps = await create_base_agent(
|
|
43
43
|
system_prompt_fn,
|
|
44
44
|
agent_runtime_options,
|
|
45
45
|
load_codebase_understanding_tools=True,
|
shotgun/agents/research.py
CHANGED
|
@@ -26,7 +26,7 @@ from .tools import get_available_web_search_tools
|
|
|
26
26
|
logger = get_logger(__name__)
|
|
27
27
|
|
|
28
28
|
|
|
29
|
-
def create_research_agent(
|
|
29
|
+
async def create_research_agent(
|
|
30
30
|
agent_runtime_options: AgentRuntimeOptions, provider: ProviderType | None = None
|
|
31
31
|
) -> tuple[Agent[AgentDeps, AgentResponse], AgentDeps]:
|
|
32
32
|
"""Create a research agent with web search and artifact management capabilities.
|
|
@@ -41,7 +41,7 @@ def create_research_agent(
|
|
|
41
41
|
logger.debug("Initializing research agent")
|
|
42
42
|
|
|
43
43
|
# Get available web search tools based on configured API keys
|
|
44
|
-
web_search_tools = get_available_web_search_tools()
|
|
44
|
+
web_search_tools = await get_available_web_search_tools()
|
|
45
45
|
if web_search_tools:
|
|
46
46
|
logger.info(
|
|
47
47
|
"Research agent configured with %d web search tool(s)",
|
|
@@ -53,7 +53,7 @@ def create_research_agent(
|
|
|
53
53
|
# Use partial to create system prompt function for research agent
|
|
54
54
|
system_prompt_fn = partial(build_agent_system_prompt, "research")
|
|
55
55
|
|
|
56
|
-
agent, deps = create_base_agent(
|
|
56
|
+
agent, deps = await create_base_agent(
|
|
57
57
|
system_prompt_fn,
|
|
58
58
|
agent_runtime_options,
|
|
59
59
|
load_codebase_understanding_tools=True,
|
shotgun/agents/specify.py
CHANGED
|
@@ -23,7 +23,7 @@ from .models import AgentDeps, AgentResponse, AgentRuntimeOptions, AgentType
|
|
|
23
23
|
logger = get_logger(__name__)
|
|
24
24
|
|
|
25
25
|
|
|
26
|
-
def create_specify_agent(
|
|
26
|
+
async def create_specify_agent(
|
|
27
27
|
agent_runtime_options: AgentRuntimeOptions, provider: ProviderType | None = None
|
|
28
28
|
) -> tuple[Agent[AgentDeps, AgentResponse], AgentDeps]:
|
|
29
29
|
"""Create a specify agent with artifact management capabilities.
|
|
@@ -39,7 +39,7 @@ def create_specify_agent(
|
|
|
39
39
|
# Use partial to create system prompt function for specify agent
|
|
40
40
|
system_prompt_fn = partial(build_agent_system_prompt, "specify")
|
|
41
41
|
|
|
42
|
-
agent, deps = create_base_agent(
|
|
42
|
+
agent, deps = await create_base_agent(
|
|
43
43
|
system_prompt_fn,
|
|
44
44
|
agent_runtime_options,
|
|
45
45
|
load_codebase_understanding_tools=True,
|
shotgun/agents/tasks.py
CHANGED
|
@@ -23,7 +23,7 @@ from .models import AgentDeps, AgentResponse, AgentRuntimeOptions, AgentType
|
|
|
23
23
|
logger = get_logger(__name__)
|
|
24
24
|
|
|
25
25
|
|
|
26
|
-
def create_tasks_agent(
|
|
26
|
+
async def create_tasks_agent(
|
|
27
27
|
agent_runtime_options: AgentRuntimeOptions, provider: ProviderType | None = None
|
|
28
28
|
) -> tuple[Agent[AgentDeps, AgentResponse], AgentDeps]:
|
|
29
29
|
"""Create a tasks agent with file management capabilities.
|
|
@@ -39,7 +39,7 @@ def create_tasks_agent(
|
|
|
39
39
|
# Use partial to create system prompt function for tasks agent
|
|
40
40
|
system_prompt_fn = partial(build_agent_system_prompt, "tasks")
|
|
41
41
|
|
|
42
|
-
agent, deps = create_base_agent(
|
|
42
|
+
agent, deps = await create_base_agent(
|
|
43
43
|
system_prompt_fn,
|
|
44
44
|
agent_runtime_options,
|
|
45
45
|
provider=provider,
|
|
@@ -2,6 +2,7 @@
|
|
|
2
2
|
|
|
3
3
|
from pathlib import Path
|
|
4
4
|
|
|
5
|
+
import aiofiles
|
|
5
6
|
from pydantic_ai import RunContext
|
|
6
7
|
|
|
7
8
|
from shotgun.agents.models import AgentDeps
|
|
@@ -93,7 +94,8 @@ async def file_read(
|
|
|
93
94
|
# Read file contents
|
|
94
95
|
encoding_used = "utf-8"
|
|
95
96
|
try:
|
|
96
|
-
|
|
97
|
+
async with aiofiles.open(full_file_path, encoding="utf-8") as f:
|
|
98
|
+
content = await f.read()
|
|
97
99
|
size_bytes = full_file_path.stat().st_size
|
|
98
100
|
|
|
99
101
|
logger.debug(
|
|
@@ -119,7 +121,8 @@ async def file_read(
|
|
|
119
121
|
try:
|
|
120
122
|
# Try with different encoding
|
|
121
123
|
encoding_used = "latin-1"
|
|
122
|
-
|
|
124
|
+
async with aiofiles.open(full_file_path, encoding="latin-1") as f:
|
|
125
|
+
content = await f.read()
|
|
123
126
|
size_bytes = full_file_path.stat().st_size
|
|
124
127
|
|
|
125
128
|
# Detect language from file extension
|
|
@@ -6,6 +6,8 @@ These tools are restricted to the .shotgun directory for security.
|
|
|
6
6
|
from pathlib import Path
|
|
7
7
|
from typing import Literal
|
|
8
8
|
|
|
9
|
+
import aiofiles
|
|
10
|
+
import aiofiles.os
|
|
9
11
|
from pydantic_ai import RunContext
|
|
10
12
|
|
|
11
13
|
from shotgun.agents.models import AgentDeps, AgentType, FileOperationType
|
|
@@ -181,10 +183,11 @@ async def read_file(ctx: RunContext[AgentDeps], filename: str) -> str:
|
|
|
181
183
|
try:
|
|
182
184
|
file_path = _validate_shotgun_path(filename)
|
|
183
185
|
|
|
184
|
-
if not
|
|
186
|
+
if not await aiofiles.os.path.exists(file_path):
|
|
185
187
|
raise FileNotFoundError(f"File not found: {filename}")
|
|
186
188
|
|
|
187
|
-
|
|
189
|
+
async with aiofiles.open(file_path, encoding="utf-8") as f:
|
|
190
|
+
content = await f.read()
|
|
188
191
|
logger.debug("📄 Read %d characters from %s", len(content), filename)
|
|
189
192
|
return content
|
|
190
193
|
|
|
@@ -233,21 +236,22 @@ async def write_file(
|
|
|
233
236
|
else:
|
|
234
237
|
operation = (
|
|
235
238
|
FileOperationType.CREATED
|
|
236
|
-
if not
|
|
239
|
+
if not await aiofiles.os.path.exists(file_path)
|
|
237
240
|
else FileOperationType.UPDATED
|
|
238
241
|
)
|
|
239
242
|
|
|
240
243
|
# Ensure parent directory exists
|
|
241
|
-
file_path.parent
|
|
244
|
+
await aiofiles.os.makedirs(file_path.parent, exist_ok=True)
|
|
242
245
|
|
|
243
246
|
# Write content
|
|
244
247
|
if mode == "a":
|
|
245
|
-
with open(file_path, "a", encoding="utf-8") as f:
|
|
246
|
-
f.write(content)
|
|
248
|
+
async with aiofiles.open(file_path, "a", encoding="utf-8") as f:
|
|
249
|
+
await f.write(content)
|
|
247
250
|
logger.debug("📄 Appended %d characters to %s", len(content), filename)
|
|
248
251
|
result = f"Successfully appended {len(content)} characters to {filename}"
|
|
249
252
|
else:
|
|
250
|
-
|
|
253
|
+
async with aiofiles.open(file_path, "w", encoding="utf-8") as f:
|
|
254
|
+
await f.write(content)
|
|
251
255
|
logger.debug("📄 Wrote %d characters to %s", len(content), filename)
|
|
252
256
|
result = f"Successfully wrote {len(content)} characters to {filename}"
|
|
253
257
|
|
|
@@ -26,7 +26,7 @@ logger = get_logger(__name__)
|
|
|
26
26
|
WebSearchTool = Callable[[str], Awaitable[str]]
|
|
27
27
|
|
|
28
28
|
|
|
29
|
-
def get_available_web_search_tools() -> list[WebSearchTool]:
|
|
29
|
+
async def get_available_web_search_tools() -> list[WebSearchTool]:
|
|
30
30
|
"""Get list of available web search tools based on configured API keys.
|
|
31
31
|
|
|
32
32
|
Works with both Shotgun Account (via LiteLLM proxy) and BYOK (individual provider keys).
|
|
@@ -43,25 +43,25 @@ def get_available_web_search_tools() -> list[WebSearchTool]:
|
|
|
43
43
|
|
|
44
44
|
# Check if using Shotgun Account
|
|
45
45
|
config_manager = get_config_manager()
|
|
46
|
-
config = config_manager.load()
|
|
46
|
+
config = await config_manager.load()
|
|
47
47
|
has_shotgun_key = config.shotgun.api_key is not None
|
|
48
48
|
|
|
49
49
|
if has_shotgun_key:
|
|
50
50
|
logger.debug("🔑 Shotgun Account - only Gemini web search available")
|
|
51
51
|
|
|
52
52
|
# Gemini: Only search tool available for Shotgun Account
|
|
53
|
-
if is_provider_available(ProviderType.GOOGLE):
|
|
53
|
+
if await is_provider_available(ProviderType.GOOGLE):
|
|
54
54
|
logger.debug("✅ Gemini web search tool available")
|
|
55
55
|
tools.append(gemini_web_search_tool)
|
|
56
56
|
|
|
57
57
|
# Anthropic: Not available for Shotgun Account (Gemini-only for Shotgun)
|
|
58
|
-
if is_provider_available(ProviderType.ANTHROPIC):
|
|
58
|
+
if await is_provider_available(ProviderType.ANTHROPIC):
|
|
59
59
|
logger.debug(
|
|
60
60
|
"⚠️ Anthropic web search requires BYOK (Shotgun Account uses Gemini only)"
|
|
61
61
|
)
|
|
62
62
|
|
|
63
63
|
# OpenAI: Not available for Shotgun Account (Responses API incompatible with proxy)
|
|
64
|
-
if is_provider_available(ProviderType.OPENAI):
|
|
64
|
+
if await is_provider_available(ProviderType.OPENAI):
|
|
65
65
|
logger.debug(
|
|
66
66
|
"⚠️ OpenAI web search requires BYOK (Responses API not supported via proxy)"
|
|
67
67
|
)
|
|
@@ -69,15 +69,15 @@ def get_available_web_search_tools() -> list[WebSearchTool]:
|
|
|
69
69
|
# BYOK mode: Load all available tools based on individual provider keys
|
|
70
70
|
logger.debug("🔑 BYOK mode - checking all provider web search tools")
|
|
71
71
|
|
|
72
|
-
if is_provider_available(ProviderType.OPENAI):
|
|
72
|
+
if await is_provider_available(ProviderType.OPENAI):
|
|
73
73
|
logger.debug("✅ OpenAI web search tool available")
|
|
74
74
|
tools.append(openai_web_search_tool)
|
|
75
75
|
|
|
76
|
-
if is_provider_available(ProviderType.ANTHROPIC):
|
|
76
|
+
if await is_provider_available(ProviderType.ANTHROPIC):
|
|
77
77
|
logger.debug("✅ Anthropic web search tool available")
|
|
78
78
|
tools.append(anthropic_web_search_tool)
|
|
79
79
|
|
|
80
|
-
if is_provider_available(ProviderType.GOOGLE):
|
|
80
|
+
if await is_provider_available(ProviderType.GOOGLE):
|
|
81
81
|
logger.debug("✅ Gemini web search tool available")
|
|
82
82
|
tools.append(gemini_web_search_tool)
|
|
83
83
|
|
|
@@ -46,7 +46,7 @@ async def anthropic_web_search_tool(query: str) -> str:
|
|
|
46
46
|
|
|
47
47
|
# Get model configuration (supports both Shotgun and BYOK)
|
|
48
48
|
try:
|
|
49
|
-
model_config = get_provider_model(ProviderType.ANTHROPIC)
|
|
49
|
+
model_config = await get_provider_model(ProviderType.ANTHROPIC)
|
|
50
50
|
except ValueError as e:
|
|
51
51
|
error_msg = f"Anthropic API key not configured: {str(e)}"
|
|
52
52
|
logger.error("❌ %s", error_msg)
|
|
@@ -141,7 +141,7 @@ async def main() -> None:
|
|
|
141
141
|
# Check if API key is available
|
|
142
142
|
try:
|
|
143
143
|
if callable(get_provider_model):
|
|
144
|
-
model_config = get_provider_model(ProviderType.ANTHROPIC)
|
|
144
|
+
model_config = await get_provider_model(ProviderType.ANTHROPIC)
|
|
145
145
|
if not model_config.api_key:
|
|
146
146
|
raise ValueError("No API key configured")
|
|
147
147
|
except (ValueError, Exception):
|
|
@@ -46,7 +46,7 @@ async def gemini_web_search_tool(query: str) -> str:
|
|
|
46
46
|
|
|
47
47
|
# Get model configuration (supports both Shotgun and BYOK)
|
|
48
48
|
try:
|
|
49
|
-
model_config = get_provider_model(ModelName.GEMINI_2_5_FLASH)
|
|
49
|
+
model_config = await get_provider_model(ModelName.GEMINI_2_5_FLASH)
|
|
50
50
|
except ValueError as e:
|
|
51
51
|
error_msg = f"Gemini API key not configured: {str(e)}"
|
|
52
52
|
logger.error("❌ %s", error_msg)
|
|
@@ -43,7 +43,7 @@ async def openai_web_search_tool(query: str) -> str:
|
|
|
43
43
|
|
|
44
44
|
# Get API key from centralized configuration
|
|
45
45
|
try:
|
|
46
|
-
model_config = get_provider_model(ProviderType.OPENAI)
|
|
46
|
+
model_config = await get_provider_model(ProviderType.OPENAI)
|
|
47
47
|
api_key = model_config.api_key
|
|
48
48
|
except ValueError as e:
|
|
49
49
|
error_msg = f"OpenAI API key not configured: {str(e)}"
|
|
@@ -4,7 +4,7 @@ from shotgun.agents.config import get_provider_model
|
|
|
4
4
|
from shotgun.agents.config.models import ProviderType
|
|
5
5
|
|
|
6
6
|
|
|
7
|
-
def is_provider_available(provider: ProviderType) -> bool:
|
|
7
|
+
async def is_provider_available(provider: ProviderType) -> bool:
|
|
8
8
|
"""Check if a provider has API key configured.
|
|
9
9
|
|
|
10
10
|
Args:
|
|
@@ -14,7 +14,7 @@ def is_provider_available(provider: ProviderType) -> bool:
|
|
|
14
14
|
True if the provider has valid credentials configured (from config or env)
|
|
15
15
|
"""
|
|
16
16
|
try:
|
|
17
|
-
get_provider_model(provider)
|
|
17
|
+
await get_provider_model(provider)
|
|
18
18
|
return True
|
|
19
19
|
except ValueError:
|
|
20
20
|
return False
|
shotgun/agents/usage_manager.py
CHANGED
|
@@ -6,6 +6,8 @@ from logging import getLogger
|
|
|
6
6
|
from pathlib import Path
|
|
7
7
|
from typing import TypeAlias
|
|
8
8
|
|
|
9
|
+
import aiofiles
|
|
10
|
+
import aiofiles.os
|
|
9
11
|
from genai_prices import calc_price
|
|
10
12
|
from pydantic import BaseModel, Field
|
|
11
13
|
from pydantic_ai import RunUsage
|
|
@@ -48,9 +50,10 @@ class SessionUsageManager:
|
|
|
48
50
|
self._model_providers: dict[ModelName, ProviderType] = {}
|
|
49
51
|
self._usage_log: list[UsageLogEntry] = []
|
|
50
52
|
self._usage_path: Path = get_shotgun_home() / "usage.json"
|
|
51
|
-
|
|
53
|
+
# Note: restore_usage_state needs to be called asynchronously after init
|
|
54
|
+
# Caller should use: manager = SessionUsageManager(); await manager.restore_usage_state()
|
|
52
55
|
|
|
53
|
-
def add_usage(
|
|
56
|
+
async def add_usage(
|
|
54
57
|
self, usage: RunUsage, *, model_name: ModelName, provider: ProviderType
|
|
55
58
|
) -> None:
|
|
56
59
|
self.usage[model_name] += usage
|
|
@@ -58,7 +61,7 @@ class SessionUsageManager:
|
|
|
58
61
|
self._usage_log.append(
|
|
59
62
|
UsageLogEntry(model_name=model_name, usage=usage, provider=provider)
|
|
60
63
|
)
|
|
61
|
-
self.persist_usage_state()
|
|
64
|
+
await self.persist_usage_state()
|
|
62
65
|
|
|
63
66
|
def get_usage_report(self) -> dict[ModelName, RunUsage]:
|
|
64
67
|
return self.usage.copy()
|
|
@@ -78,7 +81,7 @@ class SessionUsageManager:
|
|
|
78
81
|
def build_usage_hint(self) -> str | None:
|
|
79
82
|
return format_usage_hint(self.get_usage_breakdown())
|
|
80
83
|
|
|
81
|
-
def persist_usage_state(self) -> None:
|
|
84
|
+
async def persist_usage_state(self) -> None:
|
|
82
85
|
state = UsageState(
|
|
83
86
|
usage=dict(self.usage.items()),
|
|
84
87
|
model_providers=self._model_providers.copy(),
|
|
@@ -86,23 +89,25 @@ class SessionUsageManager:
|
|
|
86
89
|
)
|
|
87
90
|
|
|
88
91
|
try:
|
|
89
|
-
self._usage_path.parent
|
|
90
|
-
|
|
91
|
-
|
|
92
|
+
await aiofiles.os.makedirs(self._usage_path.parent, exist_ok=True)
|
|
93
|
+
json_content = json.dumps(state.model_dump(mode="json"), indent=2)
|
|
94
|
+
async with aiofiles.open(self._usage_path, "w", encoding="utf-8") as f:
|
|
95
|
+
await f.write(json_content)
|
|
92
96
|
logger.debug("Usage state persisted to %s", self._usage_path)
|
|
93
97
|
except Exception as exc:
|
|
94
98
|
logger.error(
|
|
95
99
|
"Failed to persist usage state to %s: %s", self._usage_path, exc
|
|
96
100
|
)
|
|
97
101
|
|
|
98
|
-
def restore_usage_state(self) -> None:
|
|
99
|
-
if not
|
|
102
|
+
async def restore_usage_state(self) -> None:
|
|
103
|
+
if not await aiofiles.os.path.exists(self._usage_path):
|
|
100
104
|
logger.debug("No usage state file found at %s", self._usage_path)
|
|
101
105
|
return
|
|
102
106
|
|
|
103
107
|
try:
|
|
104
|
-
with self._usage_path
|
|
105
|
-
|
|
108
|
+
async with aiofiles.open(self._usage_path, encoding="utf-8") as f:
|
|
109
|
+
content = await f.read()
|
|
110
|
+
data = json.loads(content)
|
|
106
111
|
|
|
107
112
|
state = UsageState.model_validate(data)
|
|
108
113
|
except Exception as exc:
|