shotgun-sh 0.2.8.dev2__py3-none-any.whl → 0.2.17__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- shotgun/agents/agent_manager.py +354 -46
- shotgun/agents/common.py +14 -8
- shotgun/agents/config/constants.py +0 -6
- shotgun/agents/config/manager.py +66 -35
- shotgun/agents/config/models.py +41 -1
- shotgun/agents/config/provider.py +33 -5
- shotgun/agents/context_analyzer/__init__.py +28 -0
- shotgun/agents/context_analyzer/analyzer.py +471 -0
- shotgun/agents/context_analyzer/constants.py +9 -0
- shotgun/agents/context_analyzer/formatter.py +115 -0
- shotgun/agents/context_analyzer/models.py +212 -0
- shotgun/agents/conversation_history.py +2 -0
- shotgun/agents/conversation_manager.py +35 -19
- shotgun/agents/export.py +2 -2
- shotgun/agents/history/compaction.py +9 -4
- shotgun/agents/history/history_processors.py +113 -5
- shotgun/agents/history/token_counting/anthropic.py +17 -1
- shotgun/agents/history/token_counting/base.py +14 -3
- shotgun/agents/history/token_counting/openai.py +11 -1
- shotgun/agents/history/token_counting/sentencepiece_counter.py +8 -0
- shotgun/agents/history/token_counting/tokenizer_cache.py +3 -1
- shotgun/agents/history/token_counting/utils.py +0 -3
- shotgun/agents/plan.py +2 -2
- shotgun/agents/research.py +3 -3
- shotgun/agents/specify.py +2 -2
- shotgun/agents/tasks.py +2 -2
- shotgun/agents/tools/codebase/codebase_shell.py +6 -0
- shotgun/agents/tools/codebase/directory_lister.py +6 -0
- shotgun/agents/tools/codebase/file_read.py +11 -2
- shotgun/agents/tools/codebase/query_graph.py +6 -0
- shotgun/agents/tools/codebase/retrieve_code.py +6 -0
- shotgun/agents/tools/file_management.py +27 -7
- shotgun/agents/tools/registry.py +217 -0
- shotgun/agents/tools/web_search/__init__.py +8 -8
- shotgun/agents/tools/web_search/anthropic.py +8 -2
- shotgun/agents/tools/web_search/gemini.py +7 -1
- shotgun/agents/tools/web_search/openai.py +7 -1
- shotgun/agents/tools/web_search/utils.py +2 -2
- shotgun/agents/usage_manager.py +16 -11
- shotgun/api_endpoints.py +7 -3
- shotgun/build_constants.py +3 -3
- shotgun/cli/clear.py +53 -0
- shotgun/cli/compact.py +186 -0
- shotgun/cli/config.py +8 -5
- shotgun/cli/context.py +111 -0
- shotgun/cli/export.py +1 -1
- shotgun/cli/feedback.py +4 -2
- shotgun/cli/models.py +1 -0
- shotgun/cli/plan.py +1 -1
- shotgun/cli/research.py +1 -1
- shotgun/cli/specify.py +1 -1
- shotgun/cli/tasks.py +1 -1
- shotgun/cli/update.py +16 -2
- shotgun/codebase/core/change_detector.py +5 -3
- shotgun/codebase/core/code_retrieval.py +4 -2
- shotgun/codebase/core/ingestor.py +10 -8
- shotgun/codebase/core/manager.py +13 -4
- shotgun/codebase/core/nl_query.py +1 -1
- shotgun/exceptions.py +32 -0
- shotgun/logging_config.py +18 -27
- shotgun/main.py +73 -11
- shotgun/posthog_telemetry.py +37 -28
- shotgun/prompts/agents/partials/common_agent_system_prompt.j2 +3 -2
- shotgun/sentry_telemetry.py +163 -16
- shotgun/settings.py +238 -0
- shotgun/telemetry.py +10 -33
- shotgun/tui/app.py +243 -43
- shotgun/tui/commands/__init__.py +1 -1
- shotgun/tui/components/context_indicator.py +179 -0
- shotgun/tui/components/mode_indicator.py +70 -0
- shotgun/tui/components/status_bar.py +48 -0
- shotgun/tui/containers.py +91 -0
- shotgun/tui/dependencies.py +39 -0
- shotgun/tui/protocols.py +45 -0
- shotgun/tui/screens/chat/__init__.py +5 -0
- shotgun/tui/screens/chat/chat.tcss +54 -0
- shotgun/tui/screens/chat/chat_screen.py +1254 -0
- shotgun/tui/screens/chat/codebase_index_prompt_screen.py +64 -0
- shotgun/tui/screens/chat/codebase_index_selection.py +12 -0
- shotgun/tui/screens/chat/help_text.py +40 -0
- shotgun/tui/screens/chat/prompt_history.py +48 -0
- shotgun/tui/screens/chat.tcss +11 -0
- shotgun/tui/screens/chat_screen/command_providers.py +78 -2
- shotgun/tui/screens/chat_screen/history/__init__.py +22 -0
- shotgun/tui/screens/chat_screen/history/agent_response.py +66 -0
- shotgun/tui/screens/chat_screen/history/chat_history.py +115 -0
- shotgun/tui/screens/chat_screen/history/formatters.py +115 -0
- shotgun/tui/screens/chat_screen/history/partial_response.py +43 -0
- shotgun/tui/screens/chat_screen/history/user_question.py +42 -0
- shotgun/tui/screens/confirmation_dialog.py +151 -0
- shotgun/tui/screens/feedback.py +4 -4
- shotgun/tui/screens/github_issue.py +102 -0
- shotgun/tui/screens/model_picker.py +49 -24
- shotgun/tui/screens/onboarding.py +431 -0
- shotgun/tui/screens/pipx_migration.py +153 -0
- shotgun/tui/screens/provider_config.py +50 -27
- shotgun/tui/screens/shotgun_auth.py +2 -2
- shotgun/tui/screens/welcome.py +14 -11
- shotgun/tui/services/__init__.py +5 -0
- shotgun/tui/services/conversation_service.py +184 -0
- shotgun/tui/state/__init__.py +7 -0
- shotgun/tui/state/processing_state.py +185 -0
- shotgun/tui/utils/mode_progress.py +14 -7
- shotgun/tui/widgets/__init__.py +5 -0
- shotgun/tui/widgets/widget_coordinator.py +263 -0
- shotgun/utils/file_system_utils.py +22 -2
- shotgun/utils/marketing.py +110 -0
- shotgun/utils/update_checker.py +69 -14
- shotgun_sh-0.2.17.dist-info/METADATA +465 -0
- shotgun_sh-0.2.17.dist-info/RECORD +194 -0
- {shotgun_sh-0.2.8.dev2.dist-info → shotgun_sh-0.2.17.dist-info}/entry_points.txt +1 -0
- {shotgun_sh-0.2.8.dev2.dist-info → shotgun_sh-0.2.17.dist-info}/licenses/LICENSE +1 -1
- shotgun/tui/screens/chat.py +0 -996
- shotgun/tui/screens/chat_screen/history.py +0 -335
- shotgun_sh-0.2.8.dev2.dist-info/METADATA +0 -126
- shotgun_sh-0.2.8.dev2.dist-info/RECORD +0 -155
- {shotgun_sh-0.2.8.dev2.dist-info → shotgun_sh-0.2.17.dist-info}/WHEEL +0 -0
|
@@ -57,9 +57,15 @@ class OpenAITokenCounter(TokenCounter):
|
|
|
57
57
|
Raises:
|
|
58
58
|
RuntimeError: If token counting fails
|
|
59
59
|
"""
|
|
60
|
+
# Handle empty text to avoid unnecessary encoding
|
|
61
|
+
if not text or not text.strip():
|
|
62
|
+
return 0
|
|
63
|
+
|
|
60
64
|
try:
|
|
61
65
|
return len(self.encoding.encode(text))
|
|
62
|
-
except
|
|
66
|
+
except BaseException as e:
|
|
67
|
+
# Must catch BaseException to handle PanicException from tiktoken's Rust layer
|
|
68
|
+
# which can occur with extremely long texts. Regular Exception won't catch it.
|
|
63
69
|
raise RuntimeError(
|
|
64
70
|
f"Failed to count tokens for OpenAI model {self.model_name}"
|
|
65
71
|
) from e
|
|
@@ -76,5 +82,9 @@ class OpenAITokenCounter(TokenCounter):
|
|
|
76
82
|
Raises:
|
|
77
83
|
RuntimeError: If token counting fails
|
|
78
84
|
"""
|
|
85
|
+
# Handle empty message list early
|
|
86
|
+
if not messages:
|
|
87
|
+
return 0
|
|
88
|
+
|
|
79
89
|
total_text = extract_text_from_messages(messages)
|
|
80
90
|
return await self.count_tokens(total_text)
|
|
@@ -88,6 +88,10 @@ class SentencePieceTokenCounter(TokenCounter):
|
|
|
88
88
|
Raises:
|
|
89
89
|
RuntimeError: If token counting fails
|
|
90
90
|
"""
|
|
91
|
+
# Handle empty text to avoid unnecessary tokenization
|
|
92
|
+
if not text or not text.strip():
|
|
93
|
+
return 0
|
|
94
|
+
|
|
91
95
|
await self._ensure_tokenizer()
|
|
92
96
|
|
|
93
97
|
if self.sp is None:
|
|
@@ -115,5 +119,9 @@ class SentencePieceTokenCounter(TokenCounter):
|
|
|
115
119
|
Raises:
|
|
116
120
|
RuntimeError: If token counting fails
|
|
117
121
|
"""
|
|
122
|
+
# Handle empty message list early
|
|
123
|
+
if not messages:
|
|
124
|
+
return 0
|
|
125
|
+
|
|
118
126
|
total_text = extract_text_from_messages(messages)
|
|
119
127
|
return await self.count_tokens(total_text)
|
|
@@ -3,6 +3,7 @@
|
|
|
3
3
|
import hashlib
|
|
4
4
|
from pathlib import Path
|
|
5
5
|
|
|
6
|
+
import aiofiles
|
|
6
7
|
import httpx
|
|
7
8
|
|
|
8
9
|
from shotgun.logging_config import get_logger
|
|
@@ -78,7 +79,8 @@ async def download_gemini_tokenizer() -> Path:
|
|
|
78
79
|
|
|
79
80
|
# Atomic write: write to temp file first, then rename
|
|
80
81
|
temp_path = cache_path.with_suffix(".tmp")
|
|
81
|
-
|
|
82
|
+
async with aiofiles.open(temp_path, "wb") as f:
|
|
83
|
+
await f.write(content)
|
|
82
84
|
temp_path.rename(cache_path)
|
|
83
85
|
|
|
84
86
|
logger.info(f"Gemini tokenizer downloaded and cached at {cache_path}")
|
|
@@ -44,9 +44,6 @@ def get_token_counter(model_config: ModelConfig) -> TokenCounter:
|
|
|
44
44
|
|
|
45
45
|
# Return cached instance if available
|
|
46
46
|
if cache_key in _token_counter_cache:
|
|
47
|
-
logger.debug(
|
|
48
|
-
f"Reusing cached token counter for {model_config.provider.value}:{model_config.name}"
|
|
49
|
-
)
|
|
50
47
|
return _token_counter_cache[cache_key]
|
|
51
48
|
|
|
52
49
|
# Create new instance and cache it
|
shotgun/agents/plan.py
CHANGED
|
@@ -23,7 +23,7 @@ from .models import AgentDeps, AgentResponse, AgentRuntimeOptions, AgentType
|
|
|
23
23
|
logger = get_logger(__name__)
|
|
24
24
|
|
|
25
25
|
|
|
26
|
-
def create_plan_agent(
|
|
26
|
+
async def create_plan_agent(
|
|
27
27
|
agent_runtime_options: AgentRuntimeOptions, provider: ProviderType | None = None
|
|
28
28
|
) -> tuple[Agent[AgentDeps, AgentResponse], AgentDeps]:
|
|
29
29
|
"""Create a plan agent with artifact management capabilities.
|
|
@@ -39,7 +39,7 @@ def create_plan_agent(
|
|
|
39
39
|
# Use partial to create system prompt function for plan agent
|
|
40
40
|
system_prompt_fn = partial(build_agent_system_prompt, "plan")
|
|
41
41
|
|
|
42
|
-
agent, deps = create_base_agent(
|
|
42
|
+
agent, deps = await create_base_agent(
|
|
43
43
|
system_prompt_fn,
|
|
44
44
|
agent_runtime_options,
|
|
45
45
|
load_codebase_understanding_tools=True,
|
shotgun/agents/research.py
CHANGED
|
@@ -26,7 +26,7 @@ from .tools import get_available_web_search_tools
|
|
|
26
26
|
logger = get_logger(__name__)
|
|
27
27
|
|
|
28
28
|
|
|
29
|
-
def create_research_agent(
|
|
29
|
+
async def create_research_agent(
|
|
30
30
|
agent_runtime_options: AgentRuntimeOptions, provider: ProviderType | None = None
|
|
31
31
|
) -> tuple[Agent[AgentDeps, AgentResponse], AgentDeps]:
|
|
32
32
|
"""Create a research agent with web search and artifact management capabilities.
|
|
@@ -41,7 +41,7 @@ def create_research_agent(
|
|
|
41
41
|
logger.debug("Initializing research agent")
|
|
42
42
|
|
|
43
43
|
# Get available web search tools based on configured API keys
|
|
44
|
-
web_search_tools = get_available_web_search_tools()
|
|
44
|
+
web_search_tools = await get_available_web_search_tools()
|
|
45
45
|
if web_search_tools:
|
|
46
46
|
logger.info(
|
|
47
47
|
"Research agent configured with %d web search tool(s)",
|
|
@@ -53,7 +53,7 @@ def create_research_agent(
|
|
|
53
53
|
# Use partial to create system prompt function for research agent
|
|
54
54
|
system_prompt_fn = partial(build_agent_system_prompt, "research")
|
|
55
55
|
|
|
56
|
-
agent, deps = create_base_agent(
|
|
56
|
+
agent, deps = await create_base_agent(
|
|
57
57
|
system_prompt_fn,
|
|
58
58
|
agent_runtime_options,
|
|
59
59
|
load_codebase_understanding_tools=True,
|
shotgun/agents/specify.py
CHANGED
|
@@ -23,7 +23,7 @@ from .models import AgentDeps, AgentResponse, AgentRuntimeOptions, AgentType
|
|
|
23
23
|
logger = get_logger(__name__)
|
|
24
24
|
|
|
25
25
|
|
|
26
|
-
def create_specify_agent(
|
|
26
|
+
async def create_specify_agent(
|
|
27
27
|
agent_runtime_options: AgentRuntimeOptions, provider: ProviderType | None = None
|
|
28
28
|
) -> tuple[Agent[AgentDeps, AgentResponse], AgentDeps]:
|
|
29
29
|
"""Create a specify agent with artifact management capabilities.
|
|
@@ -39,7 +39,7 @@ def create_specify_agent(
|
|
|
39
39
|
# Use partial to create system prompt function for specify agent
|
|
40
40
|
system_prompt_fn = partial(build_agent_system_prompt, "specify")
|
|
41
41
|
|
|
42
|
-
agent, deps = create_base_agent(
|
|
42
|
+
agent, deps = await create_base_agent(
|
|
43
43
|
system_prompt_fn,
|
|
44
44
|
agent_runtime_options,
|
|
45
45
|
load_codebase_understanding_tools=True,
|
shotgun/agents/tasks.py
CHANGED
|
@@ -23,7 +23,7 @@ from .models import AgentDeps, AgentResponse, AgentRuntimeOptions, AgentType
|
|
|
23
23
|
logger = get_logger(__name__)
|
|
24
24
|
|
|
25
25
|
|
|
26
|
-
def create_tasks_agent(
|
|
26
|
+
async def create_tasks_agent(
|
|
27
27
|
agent_runtime_options: AgentRuntimeOptions, provider: ProviderType | None = None
|
|
28
28
|
) -> tuple[Agent[AgentDeps, AgentResponse], AgentDeps]:
|
|
29
29
|
"""Create a tasks agent with file management capabilities.
|
|
@@ -39,7 +39,7 @@ def create_tasks_agent(
|
|
|
39
39
|
# Use partial to create system prompt function for tasks agent
|
|
40
40
|
system_prompt_fn = partial(build_agent_system_prompt, "tasks")
|
|
41
41
|
|
|
42
|
-
agent, deps = create_base_agent(
|
|
42
|
+
agent, deps = await create_base_agent(
|
|
43
43
|
system_prompt_fn,
|
|
44
44
|
agent_runtime_options,
|
|
45
45
|
provider=provider,
|
|
@@ -8,6 +8,7 @@ from pathlib import Path
|
|
|
8
8
|
from pydantic_ai import RunContext
|
|
9
9
|
|
|
10
10
|
from shotgun.agents.models import AgentDeps
|
|
11
|
+
from shotgun.agents.tools.registry import ToolCategory, register_tool
|
|
11
12
|
from shotgun.logging_config import get_logger
|
|
12
13
|
|
|
13
14
|
from .models import ShellCommandResult
|
|
@@ -48,6 +49,11 @@ DANGEROUS_PATTERNS = [
|
|
|
48
49
|
]
|
|
49
50
|
|
|
50
51
|
|
|
52
|
+
@register_tool(
|
|
53
|
+
category=ToolCategory.CODEBASE_UNDERSTANDING,
|
|
54
|
+
display_text="Running shell",
|
|
55
|
+
key_arg="command",
|
|
56
|
+
)
|
|
51
57
|
async def codebase_shell(
|
|
52
58
|
ctx: RunContext[AgentDeps],
|
|
53
59
|
command: str,
|
|
@@ -5,6 +5,7 @@ from pathlib import Path
|
|
|
5
5
|
from pydantic_ai import RunContext
|
|
6
6
|
|
|
7
7
|
from shotgun.agents.models import AgentDeps
|
|
8
|
+
from shotgun.agents.tools.registry import ToolCategory, register_tool
|
|
8
9
|
from shotgun.logging_config import get_logger
|
|
9
10
|
|
|
10
11
|
from .models import DirectoryListResult
|
|
@@ -12,6 +13,11 @@ from .models import DirectoryListResult
|
|
|
12
13
|
logger = get_logger(__name__)
|
|
13
14
|
|
|
14
15
|
|
|
16
|
+
@register_tool(
|
|
17
|
+
category=ToolCategory.CODEBASE_UNDERSTANDING,
|
|
18
|
+
display_text="Listing directory",
|
|
19
|
+
key_arg="directory",
|
|
20
|
+
)
|
|
15
21
|
async def directory_lister(
|
|
16
22
|
ctx: RunContext[AgentDeps], graph_id: str, directory: str = "."
|
|
17
23
|
) -> DirectoryListResult:
|
|
@@ -2,9 +2,11 @@
|
|
|
2
2
|
|
|
3
3
|
from pathlib import Path
|
|
4
4
|
|
|
5
|
+
import aiofiles
|
|
5
6
|
from pydantic_ai import RunContext
|
|
6
7
|
|
|
7
8
|
from shotgun.agents.models import AgentDeps
|
|
9
|
+
from shotgun.agents.tools.registry import ToolCategory, register_tool
|
|
8
10
|
from shotgun.codebase.core.language_config import get_language_config
|
|
9
11
|
from shotgun.logging_config import get_logger
|
|
10
12
|
|
|
@@ -13,6 +15,11 @@ from .models import FileReadResult
|
|
|
13
15
|
logger = get_logger(__name__)
|
|
14
16
|
|
|
15
17
|
|
|
18
|
+
@register_tool(
|
|
19
|
+
category=ToolCategory.CODEBASE_UNDERSTANDING,
|
|
20
|
+
display_text="Reading file",
|
|
21
|
+
key_arg="file_path",
|
|
22
|
+
)
|
|
16
23
|
async def file_read(
|
|
17
24
|
ctx: RunContext[AgentDeps], graph_id: str, file_path: str
|
|
18
25
|
) -> FileReadResult:
|
|
@@ -87,7 +94,8 @@ async def file_read(
|
|
|
87
94
|
# Read file contents
|
|
88
95
|
encoding_used = "utf-8"
|
|
89
96
|
try:
|
|
90
|
-
|
|
97
|
+
async with aiofiles.open(full_file_path, encoding="utf-8") as f:
|
|
98
|
+
content = await f.read()
|
|
91
99
|
size_bytes = full_file_path.stat().st_size
|
|
92
100
|
|
|
93
101
|
logger.debug(
|
|
@@ -113,7 +121,8 @@ async def file_read(
|
|
|
113
121
|
try:
|
|
114
122
|
# Try with different encoding
|
|
115
123
|
encoding_used = "latin-1"
|
|
116
|
-
|
|
124
|
+
async with aiofiles.open(full_file_path, encoding="latin-1") as f:
|
|
125
|
+
content = await f.read()
|
|
117
126
|
size_bytes = full_file_path.stat().st_size
|
|
118
127
|
|
|
119
128
|
# Detect language from file extension
|
|
@@ -3,6 +3,7 @@
|
|
|
3
3
|
from pydantic_ai import RunContext
|
|
4
4
|
|
|
5
5
|
from shotgun.agents.models import AgentDeps
|
|
6
|
+
from shotgun.agents.tools.registry import ToolCategory, register_tool
|
|
6
7
|
from shotgun.codebase.models import QueryType
|
|
7
8
|
from shotgun.logging_config import get_logger
|
|
8
9
|
|
|
@@ -11,6 +12,11 @@ from .models import QueryGraphResult
|
|
|
11
12
|
logger = get_logger(__name__)
|
|
12
13
|
|
|
13
14
|
|
|
15
|
+
@register_tool(
|
|
16
|
+
category=ToolCategory.CODEBASE_UNDERSTANDING,
|
|
17
|
+
display_text="Querying code",
|
|
18
|
+
key_arg="query",
|
|
19
|
+
)
|
|
14
20
|
async def query_graph(
|
|
15
21
|
ctx: RunContext[AgentDeps], graph_id: str, query: str
|
|
16
22
|
) -> QueryGraphResult:
|
|
@@ -5,6 +5,7 @@ from pathlib import Path
|
|
|
5
5
|
from pydantic_ai import RunContext
|
|
6
6
|
|
|
7
7
|
from shotgun.agents.models import AgentDeps
|
|
8
|
+
from shotgun.agents.tools.registry import ToolCategory, register_tool
|
|
8
9
|
from shotgun.codebase.core.code_retrieval import retrieve_code_by_qualified_name
|
|
9
10
|
from shotgun.codebase.core.language_config import get_language_config
|
|
10
11
|
from shotgun.logging_config import get_logger
|
|
@@ -14,6 +15,11 @@ from .models import CodeSnippetResult
|
|
|
14
15
|
logger = get_logger(__name__)
|
|
15
16
|
|
|
16
17
|
|
|
18
|
+
@register_tool(
|
|
19
|
+
category=ToolCategory.CODEBASE_UNDERSTANDING,
|
|
20
|
+
display_text="Retrieving code",
|
|
21
|
+
key_arg="qualified_name",
|
|
22
|
+
)
|
|
17
23
|
async def retrieve_code(
|
|
18
24
|
ctx: RunContext[AgentDeps], graph_id: str, qualified_name: str
|
|
19
25
|
) -> CodeSnippetResult:
|
|
@@ -6,9 +6,12 @@ These tools are restricted to the .shotgun directory for security.
|
|
|
6
6
|
from pathlib import Path
|
|
7
7
|
from typing import Literal
|
|
8
8
|
|
|
9
|
+
import aiofiles
|
|
10
|
+
import aiofiles.os
|
|
9
11
|
from pydantic_ai import RunContext
|
|
10
12
|
|
|
11
13
|
from shotgun.agents.models import AgentDeps, AgentType, FileOperationType
|
|
14
|
+
from shotgun.agents.tools.registry import ToolCategory, register_tool
|
|
12
15
|
from shotgun.logging_config import get_logger
|
|
13
16
|
from shotgun.utils.file_system_utils import get_shotgun_base_path
|
|
14
17
|
|
|
@@ -157,6 +160,11 @@ def _validate_shotgun_path(filename: str) -> Path:
|
|
|
157
160
|
return full_path
|
|
158
161
|
|
|
159
162
|
|
|
163
|
+
@register_tool(
|
|
164
|
+
category=ToolCategory.ARTIFACT_MANAGEMENT,
|
|
165
|
+
display_text="Reading file",
|
|
166
|
+
key_arg="filename",
|
|
167
|
+
)
|
|
160
168
|
async def read_file(ctx: RunContext[AgentDeps], filename: str) -> str:
|
|
161
169
|
"""Read a file from the .shotgun directory.
|
|
162
170
|
|
|
@@ -175,10 +183,11 @@ async def read_file(ctx: RunContext[AgentDeps], filename: str) -> str:
|
|
|
175
183
|
try:
|
|
176
184
|
file_path = _validate_shotgun_path(filename)
|
|
177
185
|
|
|
178
|
-
if not
|
|
186
|
+
if not await aiofiles.os.path.exists(file_path):
|
|
179
187
|
raise FileNotFoundError(f"File not found: {filename}")
|
|
180
188
|
|
|
181
|
-
|
|
189
|
+
async with aiofiles.open(file_path, encoding="utf-8") as f:
|
|
190
|
+
content = await f.read()
|
|
182
191
|
logger.debug("📄 Read %d characters from %s", len(content), filename)
|
|
183
192
|
return content
|
|
184
193
|
|
|
@@ -188,6 +197,11 @@ async def read_file(ctx: RunContext[AgentDeps], filename: str) -> str:
|
|
|
188
197
|
return error_msg
|
|
189
198
|
|
|
190
199
|
|
|
200
|
+
@register_tool(
|
|
201
|
+
category=ToolCategory.ARTIFACT_MANAGEMENT,
|
|
202
|
+
display_text="Writing file",
|
|
203
|
+
key_arg="filename",
|
|
204
|
+
)
|
|
191
205
|
async def write_file(
|
|
192
206
|
ctx: RunContext[AgentDeps],
|
|
193
207
|
filename: str,
|
|
@@ -222,21 +236,22 @@ async def write_file(
|
|
|
222
236
|
else:
|
|
223
237
|
operation = (
|
|
224
238
|
FileOperationType.CREATED
|
|
225
|
-
if not
|
|
239
|
+
if not await aiofiles.os.path.exists(file_path)
|
|
226
240
|
else FileOperationType.UPDATED
|
|
227
241
|
)
|
|
228
242
|
|
|
229
243
|
# Ensure parent directory exists
|
|
230
|
-
file_path.parent
|
|
244
|
+
await aiofiles.os.makedirs(file_path.parent, exist_ok=True)
|
|
231
245
|
|
|
232
246
|
# Write content
|
|
233
247
|
if mode == "a":
|
|
234
|
-
with open(file_path, "a", encoding="utf-8") as f:
|
|
235
|
-
f.write(content)
|
|
248
|
+
async with aiofiles.open(file_path, "a", encoding="utf-8") as f:
|
|
249
|
+
await f.write(content)
|
|
236
250
|
logger.debug("📄 Appended %d characters to %s", len(content), filename)
|
|
237
251
|
result = f"Successfully appended {len(content)} characters to {filename}"
|
|
238
252
|
else:
|
|
239
|
-
|
|
253
|
+
async with aiofiles.open(file_path, "w", encoding="utf-8") as f:
|
|
254
|
+
await f.write(content)
|
|
240
255
|
logger.debug("📄 Wrote %d characters to %s", len(content), filename)
|
|
241
256
|
result = f"Successfully wrote {len(content)} characters to {filename}"
|
|
242
257
|
|
|
@@ -251,6 +266,11 @@ async def write_file(
|
|
|
251
266
|
return error_msg
|
|
252
267
|
|
|
253
268
|
|
|
269
|
+
@register_tool(
|
|
270
|
+
category=ToolCategory.ARTIFACT_MANAGEMENT,
|
|
271
|
+
display_text="Appending to file",
|
|
272
|
+
key_arg="filename",
|
|
273
|
+
)
|
|
254
274
|
async def append_file(ctx: RunContext[AgentDeps], filename: str, content: str) -> str:
|
|
255
275
|
"""Append content to a file in the .shotgun directory.
|
|
256
276
|
|
|
@@ -0,0 +1,217 @@
|
|
|
1
|
+
"""Tool category registry using decorators for automatic registration.
|
|
2
|
+
|
|
3
|
+
This module provides a decorator-based system for categorizing tools used by agents.
|
|
4
|
+
Tools can be decorated with @register_tool to automatically register their category,
|
|
5
|
+
which is then used by the context analyzer to break down token usage by tool type.
|
|
6
|
+
|
|
7
|
+
It also provides a display registry system for tool formatting in the TUI, allowing
|
|
8
|
+
tools to declare how they should be displayed when streaming.
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
from collections.abc import Callable
|
|
12
|
+
from enum import StrEnum
|
|
13
|
+
from typing import TypeVar, overload
|
|
14
|
+
|
|
15
|
+
import sentry_sdk
|
|
16
|
+
from pydantic import BaseModel
|
|
17
|
+
|
|
18
|
+
from shotgun.logging_config import get_logger
|
|
19
|
+
|
|
20
|
+
logger = get_logger(__name__)

# Type variable for decorated functions: bound to any callable so that the
# decorator returned by register_tool can hand back the function with its
# exact type preserved (register_tool is annotated Callable[[F], F]).
F = TypeVar("F", bound=Callable[..., object])
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class ToolCategory(StrEnum):
    """Categories for agent tools used in context analysis.

    Being a StrEnum, each member compares equal to its plain string value.
    get_tool_category returns UNKNOWN for any tool name that has not been
    registered.
    """

    # Codebase tools, e.g. file_read, query_graph, codebase_shell.
    CODEBASE_UNDERSTANDING = "codebase_understanding"
    # .shotgun artifact tools, e.g. read_file, write_file, append_file.
    ARTIFACT_MANAGEMENT = "artifact_management"
    # Presumably for web search tools — TODO confirm where it is assigned.
    WEB_RESEARCH = "web_research"
    # The synthetic 'final_result' tool registered at module import.
    AGENT_RESPONSE = "agent_response"
    # Fallback for tool names with no registered category.
    UNKNOWN = "unknown"
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class ToolDisplayConfig(BaseModel):
    """Configuration for how a tool should be displayed in the TUI.

    Instances are created by register_tool and register_tool_display and
    looked up by tool name through get_tool_display_config.

    Attributes:
        display_text: Text to show (e.g., "Reading file", "Querying code")
        key_arg: Primary argument to extract from tool args for display; may
            be an empty string when no argument is shown (as registered for
            'final_result')
        hide: Whether to completely hide this tool call from the UI
    """

    # Label rendered for the tool call.
    display_text: str
    # Name of the tool argument whose value accompanies the label.
    key_arg: str
    # True suppresses the tool call in the UI entirely (e.g. 'final_result').
    hide: bool = False
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
# Global registry mapping tool names to categories.
# Written by register_tool / register_special_tool, read by get_tool_category.
# Plain dict assignment: a later registration under the same name silently
# replaces an earlier one.
_TOOL_REGISTRY: dict[str, ToolCategory] = {}

# Global registry mapping tool names to display configs.
# Written by register_tool / register_tool_display, read by
# get_tool_display_config. Same last-write-wins behavior as above.
_TOOL_DISPLAY_REGISTRY: dict[str, ToolDisplayConfig] = {}
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
@overload
def register_tool(
    category: ToolCategory,
    display_text: str,
    key_arg: str,
) -> Callable[[F], F]: ...


@overload
def register_tool(
    category: ToolCategory,
    display_text: str,
    key_arg: str,
    *,
    hide: bool,
) -> Callable[[F], F]: ...


def register_tool(
    category: ToolCategory,
    display_text: str,
    key_arg: str,
    *,
    hide: bool = False,
) -> Callable[[F], F]:
    """Return a decorator that records a tool's category and TUI display settings.

    The decorated function is stored under its ``__name__`` in both the
    category registry and the display registry, then returned unchanged.
    Registering a second tool with the same name replaces the earlier entries.

    Args:
        category: The ToolCategory enum value for this tool
        display_text: Label shown in the TUI (e.g., "Reading file", "Querying code")
        key_arg: Name of the tool argument whose value is shown next to the
            label (e.g., "query", "filename")
        hide: When True, the tool call is hidden from the UI entirely
            (default: False)

    Returns:
        A decorator that registers the tool and returns it as-is

    Display Format:
        - key_arg value missing: only display_text is shown (e.g., "Reading file")
        - key_arg value present: "display_text: key_arg_value" is shown
          (e.g., "Reading file: foo.py")

    Example:
        @register_tool(
            category=ToolCategory.CODEBASE_UNDERSTANDING,
            display_text="Querying code",
            key_arg="query",
        )
        async def query_graph(ctx: RunContext[AgentDeps], query: str) -> str:
            ...
    """

    def _record(tool_fn: F) -> F:
        name = tool_fn.__name__

        # Category entry consulted by get_tool_category().
        _TOOL_REGISTRY[name] = category
        logger.debug(f"Registered tool '{name}' as category '{category.value}'")

        # Display entry consulted by get_tool_display_config().
        _TOOL_DISPLAY_REGISTRY[name] = ToolDisplayConfig(
            display_text=display_text, key_arg=key_arg, hide=hide
        )
        logger.debug(f"Registered display config for tool '{name}'")

        # Hand the original object back so the decorator is transparent.
        return tool_fn

    return _record


# Backwards compatibility alias
tool_category = register_tool
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
# Unknown tool names that have already been reported. The same unregistered
# name can be looked up many times (presumably once per tool call in a
# conversation), so each one is warned about and sent to Sentry only once
# per process instead of on every lookup.
_REPORTED_UNKNOWN_TOOLS: set[str] = set()


def get_tool_category(tool_name: str) -> ToolCategory:
    """Get category for a tool, logging unknown tools to Sentry.

    The first lookup of an unregistered name emits a logger warning and a
    Sentry warning message. Later lookups of the same name return UNKNOWN
    without reporting again.

    Args:
        tool_name: Name of the tool to look up

    Returns:
        ToolCategory enum value for the tool, or UNKNOWN if not registered
    """
    category = _TOOL_REGISTRY.get(tool_name)
    if category is not None:
        return category

    # Report each unknown name only once to avoid flooding logs and Sentry.
    if tool_name not in _REPORTED_UNKNOWN_TOOLS:
        _REPORTED_UNKNOWN_TOOLS.add(tool_name)
        logger.warning(f"Unknown tool encountered in context analysis: {tool_name}")
        sentry_sdk.capture_message(
            f"Unknown tool in context analysis: {tool_name}",
            level="warning",
            extras={"tool_name": tool_name},
        )
    return ToolCategory.UNKNOWN
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
def register_special_tool(tool_name: str, category: ToolCategory) -> None:
    """Record a category for a tool that is not registered via @register_tool.

    Intended for names such as 'final_result' that are not real Python
    functions and therefore cannot be decorated, but still need a category.

    Args:
        tool_name: Tool name to associate with the category
        category: ToolCategory to store for that name
    """
    _TOOL_REGISTRY.update({tool_name: category})
    message = (
        f"Registered special tool '{tool_name}' as category '{category.value}'"
    )
    logger.debug(message)
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
def get_tool_display_config(tool_name: str) -> ToolDisplayConfig | None:
    """Look up how a tool should be rendered in the TUI.

    Args:
        tool_name: Tool name to look up

    Returns:
        The registered ToolDisplayConfig, or None when no display config has
        been registered under that name
    """
    if tool_name not in _TOOL_DISPLAY_REGISTRY:
        return None
    return _TOOL_DISPLAY_REGISTRY[tool_name]
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
def register_tool_display(
    tool_name: str,
    display_text: str,
    key_arg: str,
    *,
    hide: bool = False,
) -> None:
    """Store TUI display settings for a tool that has no @register_tool decorator.

    Meant for names like 'final_result' or builtin tools that have no Python
    function of their own. Only the display registry is touched; the
    category registry is left alone.

    Args:
        tool_name: Tool name to attach the display settings to
        display_text: Label shown in the TUI (e.g., "Reading file", "Querying code")
        key_arg: Name of the tool argument whose value is shown with the label
        hide: When True, the tool call is hidden from the UI entirely
    """
    _TOOL_DISPLAY_REGISTRY[tool_name] = ToolDisplayConfig(
        display_text=display_text, key_arg=key_arg, hide=hide
    )
    logger.debug(f"Registered display config for special tool '{tool_name}'")
|
|
205
|
+
|
|
206
|
+
|
|
207
|
+
# Register special tools that don't have decorators
register_special_tool("final_result", ToolCategory.AGENT_RESPONSE)
register_tool_display("final_result", display_text="", key_arg="", hide=True)

# Register builtin tools (tools that come from Pydantic AI or model providers)
# These don't have Python function definitions but need display formatting
register_tool_display(
    "web_search",
    display_text="Searching",
    key_arg="query",
)
# Give the builtin web_search tool a category too, as final_result has above.
# Otherwise a category lookup for it (presumably from the context analyzer —
# confirm) falls through to UNKNOWN and get_tool_category reports it to Sentry.
register_special_tool("web_search", ToolCategory.WEB_RESEARCH)
|
|
@@ -26,7 +26,7 @@ logger = get_logger(__name__)
|
|
|
26
26
|
WebSearchTool = Callable[[str], Awaitable[str]]
|
|
27
27
|
|
|
28
28
|
|
|
29
|
-
def get_available_web_search_tools() -> list[WebSearchTool]:
|
|
29
|
+
async def get_available_web_search_tools() -> list[WebSearchTool]:
|
|
30
30
|
"""Get list of available web search tools based on configured API keys.
|
|
31
31
|
|
|
32
32
|
Works with both Shotgun Account (via LiteLLM proxy) and BYOK (individual provider keys).
|
|
@@ -43,25 +43,25 @@ def get_available_web_search_tools() -> list[WebSearchTool]:
|
|
|
43
43
|
|
|
44
44
|
# Check if using Shotgun Account
|
|
45
45
|
config_manager = get_config_manager()
|
|
46
|
-
config = config_manager.load()
|
|
46
|
+
config = await config_manager.load()
|
|
47
47
|
has_shotgun_key = config.shotgun.api_key is not None
|
|
48
48
|
|
|
49
49
|
if has_shotgun_key:
|
|
50
50
|
logger.debug("🔑 Shotgun Account - only Gemini web search available")
|
|
51
51
|
|
|
52
52
|
# Gemini: Only search tool available for Shotgun Account
|
|
53
|
-
if is_provider_available(ProviderType.GOOGLE):
|
|
53
|
+
if await is_provider_available(ProviderType.GOOGLE):
|
|
54
54
|
logger.debug("✅ Gemini web search tool available")
|
|
55
55
|
tools.append(gemini_web_search_tool)
|
|
56
56
|
|
|
57
57
|
# Anthropic: Not available for Shotgun Account (Gemini-only for Shotgun)
|
|
58
|
-
if is_provider_available(ProviderType.ANTHROPIC):
|
|
58
|
+
if await is_provider_available(ProviderType.ANTHROPIC):
|
|
59
59
|
logger.debug(
|
|
60
60
|
"⚠️ Anthropic web search requires BYOK (Shotgun Account uses Gemini only)"
|
|
61
61
|
)
|
|
62
62
|
|
|
63
63
|
# OpenAI: Not available for Shotgun Account (Responses API incompatible with proxy)
|
|
64
|
-
if is_provider_available(ProviderType.OPENAI):
|
|
64
|
+
if await is_provider_available(ProviderType.OPENAI):
|
|
65
65
|
logger.debug(
|
|
66
66
|
"⚠️ OpenAI web search requires BYOK (Responses API not supported via proxy)"
|
|
67
67
|
)
|
|
@@ -69,15 +69,15 @@ def get_available_web_search_tools() -> list[WebSearchTool]:
|
|
|
69
69
|
# BYOK mode: Load all available tools based on individual provider keys
|
|
70
70
|
logger.debug("🔑 BYOK mode - checking all provider web search tools")
|
|
71
71
|
|
|
72
|
-
if is_provider_available(ProviderType.OPENAI):
|
|
72
|
+
if await is_provider_available(ProviderType.OPENAI):
|
|
73
73
|
logger.debug("✅ OpenAI web search tool available")
|
|
74
74
|
tools.append(openai_web_search_tool)
|
|
75
75
|
|
|
76
|
-
if is_provider_available(ProviderType.ANTHROPIC):
|
|
76
|
+
if await is_provider_available(ProviderType.ANTHROPIC):
|
|
77
77
|
logger.debug("✅ Anthropic web search tool available")
|
|
78
78
|
tools.append(anthropic_web_search_tool)
|
|
79
79
|
|
|
80
|
-
if is_provider_available(ProviderType.GOOGLE):
|
|
80
|
+
if await is_provider_available(ProviderType.GOOGLE):
|
|
81
81
|
logger.debug("✅ Gemini web search tool available")
|
|
82
82
|
tools.append(gemini_web_search_tool)
|
|
83
83
|
|