shotgun-sh 0.1.16.dev2__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of shotgun-sh might be problematic.
- shotgun/agents/common.py +4 -5
- shotgun/agents/config/constants.py +23 -6
- shotgun/agents/config/manager.py +239 -76
- shotgun/agents/config/models.py +74 -84
- shotgun/agents/config/provider.py +174 -85
- shotgun/agents/history/compaction.py +1 -1
- shotgun/agents/history/history_processors.py +18 -9
- shotgun/agents/history/token_counting/__init__.py +31 -0
- shotgun/agents/history/token_counting/anthropic.py +89 -0
- shotgun/agents/history/token_counting/base.py +67 -0
- shotgun/agents/history/token_counting/openai.py +80 -0
- shotgun/agents/history/token_counting/sentencepiece_counter.py +119 -0
- shotgun/agents/history/token_counting/tokenizer_cache.py +90 -0
- shotgun/agents/history/token_counting/utils.py +147 -0
- shotgun/agents/history/token_estimation.py +12 -12
- shotgun/agents/llm.py +62 -0
- shotgun/agents/models.py +2 -2
- shotgun/agents/tools/web_search/__init__.py +42 -15
- shotgun/agents/tools/web_search/anthropic.py +54 -50
- shotgun/agents/tools/web_search/gemini.py +31 -20
- shotgun/agents/tools/web_search/openai.py +4 -4
- shotgun/build_constants.py +2 -2
- shotgun/cli/config.py +34 -63
- shotgun/cli/feedback.py +4 -2
- shotgun/cli/models.py +2 -2
- shotgun/codebase/core/ingestor.py +47 -8
- shotgun/codebase/core/manager.py +7 -3
- shotgun/codebase/models.py +4 -4
- shotgun/llm_proxy/__init__.py +16 -0
- shotgun/llm_proxy/clients.py +39 -0
- shotgun/llm_proxy/constants.py +8 -0
- shotgun/main.py +6 -0
- shotgun/posthog_telemetry.py +15 -11
- shotgun/sentry_telemetry.py +3 -3
- shotgun/shotgun_web/__init__.py +19 -0
- shotgun/shotgun_web/client.py +138 -0
- shotgun/shotgun_web/constants.py +17 -0
- shotgun/shotgun_web/models.py +47 -0
- shotgun/telemetry.py +7 -4
- shotgun/tui/app.py +26 -8
- shotgun/tui/screens/chat.py +2 -8
- shotgun/tui/screens/chat_screen/command_providers.py +118 -11
- shotgun/tui/screens/chat_screen/history.py +3 -1
- shotgun/tui/screens/feedback.py +2 -2
- shotgun/tui/screens/model_picker.py +327 -0
- shotgun/tui/screens/provider_config.py +118 -28
- shotgun/tui/screens/shotgun_auth.py +295 -0
- shotgun/tui/screens/welcome.py +176 -0
- shotgun/utils/env_utils.py +12 -0
- {shotgun_sh-0.1.16.dev2.dist-info → shotgun_sh-0.2.1.dist-info}/METADATA +2 -2
- {shotgun_sh-0.1.16.dev2.dist-info → shotgun_sh-0.2.1.dist-info}/RECORD +54 -37
- shotgun/agents/history/token_counting.py +0 -429
- {shotgun_sh-0.1.16.dev2.dist-info → shotgun_sh-0.2.1.dist-info}/WHEEL +0 -0
- {shotgun_sh-0.1.16.dev2.dist-info → shotgun_sh-0.2.1.dist-info}/entry_points.txt +0 -0
- {shotgun_sh-0.1.16.dev2.dist-info → shotgun_sh-0.2.1.dist-info}/licenses/LICENSE +0 -0
shotgun/agents/history/token_counting/utils.py ADDED
@@ -0,0 +1,147 @@
+"""Utility functions and cache for token counting."""
+
+from pydantic_ai.messages import ModelMessage
+
+from shotgun.agents.config.models import ModelConfig, ProviderType
+from shotgun.logging_config import get_logger
+
+from .anthropic import AnthropicTokenCounter
+from .base import TokenCounter
+from .openai import OpenAITokenCounter
+from .sentencepiece_counter import SentencePieceTokenCounter
+
+logger = get_logger(__name__)
+
+# Global cache for token counter instances (singleton pattern)
+_token_counter_cache: dict[tuple[str, str, str], TokenCounter] = {}
+
+
+def get_token_counter(model_config: ModelConfig) -> TokenCounter:
+    """Get appropriate token counter for the model provider (cached singleton).
+
+    This function ensures that every provider has a proper token counting
+    implementation without any fallbacks to estimation. Token counters are
+    cached to avoid repeated initialization overhead.
+
+    Args:
+        model_config: Model configuration with provider and credentials
+
+    Returns:
+        Cached provider-specific token counter
+
+    Raises:
+        ValueError: If provider is not supported for token counting
+        RuntimeError: If token counter initialization fails
+    """
+    # Create cache key from provider, model name, and API key
+    cache_key = (
+        model_config.provider.value,
+        model_config.name,
+        model_config.api_key[:10]
+        if model_config.api_key
+        else "no-key",  # Partial key for cache
+    )
+
+    # Return cached instance if available
+    if cache_key in _token_counter_cache:
+        logger.debug(
+            f"Reusing cached token counter for {model_config.provider.value}:{model_config.name}"
+        )
+        return _token_counter_cache[cache_key]
+
+    # Create new instance and cache it
+    logger.debug(
+        f"Creating new token counter for {model_config.provider.value}:{model_config.name}"
+    )
+
+    counter: TokenCounter
+    if model_config.provider == ProviderType.OPENAI:
+        counter = OpenAITokenCounter(model_config.name)
+    elif model_config.provider == ProviderType.ANTHROPIC:
+        counter = AnthropicTokenCounter(
+            model_config.name, model_config.api_key, model_config.key_provider
+        )
+    elif model_config.provider == ProviderType.GOOGLE:
+        # Use local SentencePiece tokenizer (100% accurate, 10-100x faster than API)
+        counter = SentencePieceTokenCounter(model_config.name)
+    else:
+        raise ValueError(
+            f"Unsupported provider for token counting: {model_config.provider}. "
+            f"Supported providers: {[p.value for p in ProviderType]}"
+        )
+
+    # Cache the instance
+    _token_counter_cache[cache_key] = counter
+    logger.debug(
+        f"Cached token counter for {model_config.provider.value}:{model_config.name}"
+    )
+
+    return counter
+
+
+async def count_tokens_from_messages(
+    messages: list[ModelMessage], model_config: ModelConfig
+) -> int:
+    """Count actual tokens from messages using provider-specific methods (async).
+
+    This replaces the old estimation approach with accurate token counting
+    using each provider's official APIs and libraries.
+
+    Args:
+        messages: List of messages to count tokens for
+        model_config: Model configuration with provider info
+
+    Returns:
+        Exact token count for the messages
+
+    Raises:
+        ValueError: If provider is not supported
+        RuntimeError: If token counting fails
+    """
+    counter = get_token_counter(model_config)
+    return await counter.count_message_tokens(messages)
+
+
+async def count_post_summary_tokens(
+    messages: list[ModelMessage], summary_index: int, model_config: ModelConfig
+) -> int:
+    """Count actual tokens from summary onwards for incremental compaction decisions (async).
+
+    Args:
+        messages: Full message history
+        summary_index: Index of the last summary message
+        model_config: Model configuration with provider info
+
+    Returns:
+        Exact token count from summary onwards
+
+    Raises:
+        ValueError: If provider is not supported
+        RuntimeError: If token counting fails
+    """
+    if summary_index >= len(messages):
+        return 0
+
+    post_summary_messages = messages[summary_index:]
+    return await count_tokens_from_messages(post_summary_messages, model_config)
+
+
+async def count_tokens_from_message_parts(
+    messages: list[ModelMessage], model_config: ModelConfig
+) -> int:
+    """Count actual tokens from message parts for summarization requests (async).
+
+    Args:
+        messages: List of messages to count tokens for
+        model_config: Model configuration with provider info
+
+    Returns:
+        Exact token count from message parts
+
+    Raises:
+        ValueError: If provider is not supported
+        RuntimeError: If token counting fails
+    """
+    # For now, use the same logic as count_tokens_from_messages
+    # This can be optimized later if needed for different counting strategies
+    return await count_tokens_from_messages(messages, model_config)
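As a usage sketch (not part of the package): the module-level cache means repeated lookups for the same provider/model/key tuple are cheap and return the same counter instance. The `ModelConfig` construction below is illustrative; real configs come from shotgun's config manager, and the field values shown are placeholders.

import asyncio

from pydantic_ai.messages import ModelMessage, ModelRequest

from shotgun.agents.config.models import ModelConfig, ProviderType
from shotgun.agents.history.token_counting.utils import (
    count_tokens_from_messages,
    get_token_counter,
)


async def demo() -> None:
    # Illustrative config; provider/name/api_key values are placeholders.
    config = ModelConfig(provider=ProviderType.OPENAI, name="gpt-4o", api_key="sk-...")

    # Second lookup hits the cache and returns the identical instance.
    assert get_token_counter(config) is get_token_counter(config)

    messages: list[ModelMessage] = [ModelRequest.user_text_prompt("Hello!")]
    print(await count_tokens_from_messages(messages, config))


asyncio.run(demo())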
shotgun/agents/history/token_estimation.py CHANGED
@@ -19,10 +19,10 @@ from .constants import INPUT_BUFFER_TOKENS, MIN_SUMMARY_TOKENS
 from .token_counting import count_tokens_from_messages as _count_tokens_from_messages
 
 
-def estimate_tokens_from_messages(
+async def estimate_tokens_from_messages(
     messages: list[ModelMessage], model_config: ModelConfig
 ) -> int:
-    """Count actual tokens from current message list.
+    """Count actual tokens from current message list (async).
 
     This provides accurate token counting for compaction decisions using
     provider-specific token counting methods instead of rough estimation.
@@ -38,13 +38,13 @@ def estimate_tokens_from_messages(
         ValueError: If provider is not supported
         RuntimeError: If token counting fails
     """
-    return _count_tokens_from_messages(messages, model_config)
+    return await _count_tokens_from_messages(messages, model_config)
 
 
-def estimate_post_summary_tokens(
+async def estimate_post_summary_tokens(
     messages: list[ModelMessage], summary_index: int, model_config: ModelConfig
 ) -> int:
-    """Count actual tokens from summary onwards for incremental compaction decisions.
+    """Count actual tokens from summary onwards for incremental compaction decisions (async).
 
     This treats the summary as a reset point and only counts tokens from the summary
     message onwards. Used to determine if incremental compaction is needed.
@@ -65,13 +65,13 @@ def estimate_post_summary_tokens(
         return 0
 
     post_summary_messages = messages[summary_index:]
-    return estimate_tokens_from_messages(post_summary_messages, model_config)
+    return await estimate_tokens_from_messages(post_summary_messages, model_config)
 
 
-def estimate_tokens_from_message_parts(
+async def estimate_tokens_from_message_parts(
     messages: list[ModelMessage], model_config: ModelConfig
 ) -> int:
-    """Count actual tokens from message parts for summarization requests.
+    """Count actual tokens from message parts for summarization requests (async).
 
     This provides accurate token counting across the codebase using
     provider-specific methods instead of character estimation.
@@ -87,14 +87,14 @@ def estimate_tokens_from_message_parts(
         ValueError: If provider is not supported
         RuntimeError: If token counting fails
     """
-    return _count_tokens_from_messages(messages, model_config)
+    return await _count_tokens_from_messages(messages, model_config)
 
 
-def calculate_max_summarization_tokens(
+async def calculate_max_summarization_tokens(
     ctx_or_model_config: Union["RunContext[AgentDeps]", ModelConfig],
     request_messages: list[ModelMessage],
 ) -> int:
-    """Calculate maximum tokens available for summarization output.
+    """Calculate maximum tokens available for summarization output (async).
 
     This ensures we use the model's full capacity while leaving room for input tokens.
 
@@ -115,7 +115,7 @@ def calculate_max_summarization_tokens(
         return MIN_SUMMARY_TOKENS
 
     # Count actual input tokens using shared utility
-    estimated_input_tokens = estimate_tokens_from_message_parts(
+    estimated_input_tokens = await estimate_tokens_from_message_parts(
        request_messages, model_config
     )
 
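Since all four helpers became coroutines, every call site must now await them. A minimal sketch of the new calling convention; the helper name and the budget check here are hypothetical, not from the package:

from shotgun.agents.history.token_estimation import estimate_post_summary_tokens


async def needs_compaction(messages, summary_index, model_config, limit: int) -> bool:
    # Hypothetical caller: awaits the now-async estimator and compares to a budget.
    tokens = await estimate_post_summary_tokens(messages, summary_index, model_config)
    return tokens > limit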
shotgun/agents/llm.py ADDED
@@ -0,0 +1,62 @@
+"""LLM request utilities for Shotgun agents."""
+
+from typing import Any
+
+from pydantic_ai.direct import model_request
+from pydantic_ai.messages import ModelMessage, ModelResponse
+from pydantic_ai.settings import ModelSettings
+
+from shotgun.agents.config.models import ModelConfig
+from shotgun.logging_config import get_logger
+
+logger = get_logger(__name__)
+
+
+async def shotgun_model_request(
+    model_config: ModelConfig,
+    messages: list[ModelMessage],
+    model_settings: ModelSettings | None = None,
+    **kwargs: Any,
+) -> ModelResponse:
+    """Model request wrapper that uses full token capacity by default.
+
+    This wrapper ensures all LLM calls in Shotgun use the maximum available
+    token capacity of each model, improving response quality and completeness.
+    The most common issue this fixes is truncated summaries that were cut off
+    at default token limits (e.g., 4096 for Claude models).
+
+    Args:
+        model_config: ModelConfig instance with model settings and API key
+        messages: Messages to send to the model
+        model_settings: Optional ModelSettings. If None, creates default with max tokens
+        **kwargs: Additional arguments passed to model_request
+
+    Returns:
+        ModelResponse from the model
+
+    Example:
+        # Uses full token capacity (e.g., 4096 for Claude, 128k for GPT-5)
+        response = await shotgun_model_request(model_config, messages)
+
+        # With custom settings
+        response = await shotgun_model_request(model_config, messages, model_settings=ModelSettings(max_tokens=1000, temperature=0.7))
+    """
+    if kwargs.get("max_tokens") is not None:
+        logger.warning(
+            "⚠️ 'max_tokens' argument is ignored in shotgun_model_request. "
+            "Set 'model_settings.max_tokens' instead."
+        )
+
+    if not model_settings:
+        model_settings = ModelSettings()
+
+    if model_settings.get("max_tokens") is None:
+        model_settings["max_tokens"] = model_config.max_output_tokens
+
+    # Make the model request with full token utilization
+    return await model_request(
+        model=model_config.model_instance,
+        messages=messages,
+        model_settings=model_settings,
+        **kwargs,
+    )
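A note on the mapping syntax above: pydantic-ai's ModelSettings is a TypedDict, which is why the wrapper reads and writes it with model_settings.get("max_tokens") and item assignment rather than attribute access.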
shotgun/agents/models.py CHANGED
@@ -4,7 +4,7 @@ import os
 from asyncio import Future, Queue
 from collections.abc import Callable
 from datetime import datetime
-from enum import
+from enum import StrEnum
 from pathlib import Path
 from typing import TYPE_CHECKING
 
@@ -116,7 +116,7 @@ class AgentRuntimeOptions(BaseModel):
     )
 
 
-class FileOperationType(
+class FileOperationType(StrEnum):
     """Types of file operations that can be tracked."""
 
     CREATED = "created"
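For reference, the switch to StrEnum means enum members are themselves strings, so existing string comparisons and serialization keep working. A quick illustration (class body abbreviated from the diff):

from enum import StrEnum  # Python 3.11+


class FileOperationType(StrEnum):
    CREATED = "created"


# StrEnum members are str instances and format as their plain values.
assert FileOperationType.CREATED == "created"
assert f"{FileOperationType.CREATED}" == "created"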
shotgun/agents/tools/web_search/__init__.py CHANGED
@@ -1,13 +1,14 @@
 """Web search tools for Pydantic AI agents.
 
 Provides web search capabilities for multiple LLM providers:
-- OpenAI: Uses Responses API with web_search tool
-- Anthropic: Uses Messages API with web_search_20250305 tool
-- Gemini: Uses grounding with Google Search
+- OpenAI: Uses Responses API with web_search tool (BYOK only)
+- Anthropic: Uses Messages API with web_search_20250305 tool (BYOK only)
+- Gemini: Uses grounding with Google Search via Pydantic AI (works with Shotgun Account)
 """
 
-from collections.abc import Callable
+from collections.abc import Awaitable, Callable
 
+from shotgun.agents.config import get_config_manager
 from shotgun.agents.config.models import ProviderType
 from shotgun.logging_config import get_logger
 
@@ -18,29 +19,55 @@ from .utils import is_provider_available
 
 logger = get_logger(__name__)
 
-# Type alias for web search tools
-WebSearchTool = Callable[[str], str]
+# Type alias for web search tools (all now async)
+WebSearchTool = Callable[[str], Awaitable[str]]
 
 
 def get_available_web_search_tools() -> list[WebSearchTool]:
     """Get list of available web search tools based on configured API keys.
 
+    When using Shotgun Account (via LiteLLM proxy):
+        Only Gemini web search is available (others use provider-specific APIs)
+
+    When using BYOK (individual provider keys):
+        All provider tools are available based on their respective keys
+
     Returns:
         List of web search tool functions that have API keys configured
     """
     tools: list[WebSearchTool] = []
 
-    if
-
-
+    # Check if using Shotgun Account
+    config_manager = get_config_manager()
+    config = config_manager.load()
+    has_shotgun_key = config.shotgun.api_key is not None
+
+    if has_shotgun_key:
+        # Shotgun Account mode: Only Gemini supports web search via LiteLLM
+        if is_provider_available(ProviderType.GOOGLE):
+            logger.info("🔑 Shotgun Account detected - using Gemini web search only")
+            logger.debug("   OpenAI and Anthropic web search require direct API keys")
+            tools.append(gemini_web_search_tool)
+        else:
+            logger.warning(
+                "⚠️ Shotgun Account configured but no Gemini key - "
+                "web search unavailable"
+            )
+    else:
+        # BYOK mode: Load all available tools based on individual provider keys
+        logger.debug("🔑 BYOK mode - checking all provider web search tools")
+
+        if is_provider_available(ProviderType.OPENAI):
+            logger.debug("✅ OpenAI web search tool available")
+            tools.append(openai_web_search_tool)
 
-
-
-
+        if is_provider_available(ProviderType.ANTHROPIC):
+            logger.debug("✅ Anthropic web search tool available")
+            tools.append(anthropic_web_search_tool)
 
-
-
-
+        if is_provider_available(ProviderType.GOOGLE):
+            logger.debug("✅ Gemini web search tool available")
+            tools.append(gemini_web_search_tool)
 
     if not tools:
         logger.warning("⚠️ No web search tools available - no API keys configured")
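A sketch of how the selector might be consumed when building an agent; the Agent wiring and model name are illustrative assumptions, though pydantic-ai does accept plain async functions as tools:

from pydantic_ai import Agent

from shotgun.agents.tools.web_search import get_available_web_search_tools

# Each tool is an async Callable[[str], Awaitable[str]], so the list can be
# passed straight to an agent. Which tools appear depends on configured keys.
agent = Agent("openai:gpt-4o", tools=get_available_web_search_tools())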
shotgun/agents/tools/web_search/anthropic.py CHANGED
@@ -1,20 +1,24 @@
 """Anthropic web search tool implementation."""
 
-import anthropic
 from opentelemetry import trace
+from pydantic_ai.messages import ModelMessage, ModelRequest, TextPart
+from pydantic_ai.settings import ModelSettings
 
 from shotgun.agents.config import get_provider_model
+from shotgun.agents.config.constants import MEDIUM_TEXT_8K_TOKENS
 from shotgun.agents.config.models import ProviderType
+from shotgun.agents.llm import shotgun_model_request
 from shotgun.logging_config import get_logger
 
 logger = get_logger(__name__)
 
 
-def anthropic_web_search_tool(query: str) -> str:
-    """Perform a web search using Anthropic's Claude API
+async def anthropic_web_search_tool(query: str) -> str:
+    """Perform a web search using Anthropic's Claude API.
 
     This tool uses Anthropic's web search capabilities to find current information
-    about the given query.
+    about the given query. Works with both Shotgun API keys (via LiteLLM proxy)
+    and direct Anthropic API keys (BYOK).
 
     Args:
         query: The search query
@@ -27,49 +31,49 @@ def anthropic_web_search_tool(query: str) -> str:
     span = trace.get_current_span()
     span.set_attribute("input.value", f"**Query:** {query}\n")
 
-    logger.debug("📡 Executing Anthropic web search with
+    logger.debug("📡 Executing Anthropic web search with prompt: %s", query)
 
-    # Get
+    # Get model configuration (supports both Shotgun and BYOK)
     try:
         model_config = get_provider_model(ProviderType.ANTHROPIC)
-        api_key = model_config.api_key
     except ValueError as e:
         error_msg = f"Anthropic API key not configured: {str(e)}"
         logger.error("❌ %s", error_msg)
         span.set_attribute("output.value", f"**Error:**\n {error_msg}\n")
         return error_msg
 
-
+    # Build the request messages
+    messages: list[ModelMessage] = [
+        ModelRequest.user_text_prompt(f"Search for: {query}")
+    ]
 
-    # Use the Messages API with web search tool
+    # Use the Messages API with web search tool
     try:
-        … (25 deleted lines not captured in this diff view)
-        if not result_text:
-            result_text = "No content returned from search"
+        response = await shotgun_model_request(
+            model_config=model_config,
+            messages=messages,
+            model_settings=ModelSettings(
+                max_tokens=MEDIUM_TEXT_8K_TOKENS,
+                # Enable Anthropic web search tool
+                extra_body={
+                    "tools": [
+                        {
+                            "type": "web_search_20250305",
+                            "name": "web_search",
+                        }
+                    ],
+                    "tool_choice": {"type": "tool", "name": "web_search"},
+                },
+            ),
+        )
+
+        # Extract text from response
+        result_text = "No content returned from search"
+        if response.parts:
+            for part in response.parts:
+                if isinstance(part, TextPart):
+                    result_text = part.content
+                    break
 
     logger.debug("📄 Anthropic web search result: %d characters", len(result_text))
     logger.debug(
@@ -88,9 +92,8 @@ def anthropic_web_search_tool(query: str) -> str:
         return error_msg
 
 
-def main() -> None:
+async def main() -> None:
     """Main function for testing the Anthropic web search tool."""
-    import os
     import sys
 
     from shotgun.logging_config import setup_logger
@@ -110,24 +113,23 @@ def main() -> None:
     # Join all arguments as the search query
     query = " ".join(sys.argv[1:])
 
-    print("🔍 Testing Anthropic Web Search
+    print("🔍 Testing Anthropic Web Search")
     print(f"📝 Query: {query}")
     print("=" * 60)
 
     # Check if API key is available
-    … (7 deleted lines not captured in this diff view)
-    print("
-    print("   Please set it with: export ANTHROPIC_API_KEY=your_key_here")
+    try:
+        if callable(get_provider_model):
+            model_config = get_provider_model(ProviderType.ANTHROPIC)
+            if not model_config.api_key:
+                raise ValueError("No API key configured")
+    except (ValueError, Exception):
+        print("❌ Error: Anthropic API key not configured")
+        print("   Please set it in your config file")
         sys.exit(1)
 
     try:
-        result = anthropic_web_search_tool(query)
+        result = await anthropic_web_search_tool(query)
         print(f"✅ Search completed! Result length: {len(result)} characters")
         print("=" * 60)
         print("📄 RESULTS:")
@@ -141,4 +143,6 @@ def main() -> None:
 
 
 if __name__ == "__main__":
-
+    import asyncio
+
+    asyncio.run(main())
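Worth noting: the provider-specific tool definition rides in ModelSettings' extra_body, which pydantic-ai passes through to the provider request payload. That pass-through is what lets the rewrite drop the direct anthropic client while still enabling the web_search_20250305 tool.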
shotgun/agents/tools/web_search/gemini.py CHANGED
@@ -1,20 +1,24 @@
 """Gemini web search tool implementation."""
 
-import google.generativeai as genai
 from opentelemetry import trace
+from pydantic_ai.messages import ModelMessage, ModelRequest
+from pydantic_ai.settings import ModelSettings
 
 from shotgun.agents.config import get_provider_model
-from shotgun.agents.config.
+from shotgun.agents.config.constants import MEDIUM_TEXT_8K_TOKENS
+from shotgun.agents.config.models import ModelName
+from shotgun.agents.llm import shotgun_model_request
 from shotgun.logging_config import get_logger
 
 logger = get_logger(__name__)
 
 
-def gemini_web_search_tool(query: str) -> str:
+async def gemini_web_search_tool(query: str) -> str:
     """Perform a web search using Google's Gemini API with grounding.
 
     This tool uses Gemini's Google Search grounding to find current information
-    about the given query.
+    about the given query. Works with both Shotgun API keys (via LiteLLM proxy)
+    and direct Gemini API keys (BYOK).
 
     Args:
         query: The search query
@@ -29,23 +33,16 @@ def gemini_web_search_tool(query: str) -> str:
 
     logger.debug("📡 Executing Gemini web search with prompt: %s", query)
 
-    # Get
+    # Get model configuration (supports both Shotgun and BYOK)
     try:
-        model_config = get_provider_model(
-        api_key = model_config.api_key
+        model_config = get_provider_model(ModelName.GEMINI_2_5_FLASH)
     except ValueError as e:
         error_msg = f"Gemini API key not configured: {str(e)}"
         logger.error("❌ %s", error_msg)
         span.set_attribute("output.value", f"**Error:**\n {error_msg}\n")
         return error_msg
 
-
-
-    # Create model without built-in tools to avoid conflict with Pydantic AI
-    # Using prompt-based search approach instead
-    model = genai.GenerativeModel("gemini-2.5-pro")  # type: ignore[attr-defined]
-
-    # Create a search-optimized prompt that leverages Gemini's knowledge
+    # Create a search-optimized prompt
     search_prompt = f"""Please provide current and accurate information about the following query:
 
 Query: {query}
@@ -56,17 +53,31 @@ Instructions:
 - Focus on current and recent information
 - Be specific and accurate in your response"""
 
-    #
+    # Build the request messages
+    messages: list[ModelMessage] = [ModelRequest.user_text_prompt(search_prompt)]
+
+    # Generate response using Pydantic AI with Google Search grounding
     try:
-        response =
-
-
+        response = await shotgun_model_request(
+            model_config=model_config,
+            messages=messages,
+            model_settings=ModelSettings(
                 temperature=0.3,
-
+                max_tokens=MEDIUM_TEXT_8K_TOKENS,
+                # Enable Google Search grounding for Gemini
+                extra_body={"tools": [{"googleSearch": {}}]},
             ),
         )
 
-
+        # Extract text from response
+        from pydantic_ai.messages import TextPart
+
+        result_text = "No content returned from search"
+        if response.parts:
+            for part in response.parts:
+                if isinstance(part, TextPart):
+                    result_text = part.content
+                    break
 
     logger.debug("📄 Gemini web search result: %d characters", len(result_text))
     logger.debug(
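Note that the rewritten Gemini tool pins ModelName.GEMINI_2_5_FLASH for grounded search rather than reusing the caller's chat model, so searches run on Flash regardless of the configured default.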
shotgun/agents/tools/web_search/openai.py CHANGED
@@ -1,6 +1,6 @@
 """OpenAI web search tool implementation."""
 
-from openai import
+from openai import AsyncOpenAI
 from opentelemetry import trace
 
 from shotgun.agents.config import get_provider_model
@@ -10,7 +10,7 @@ from shotgun.logging_config import get_logger
 logger = get_logger(__name__)
 
 
-def openai_web_search_tool(query: str) -> str:
+async def openai_web_search_tool(query: str) -> str:
     """Perform a web search and return results.
 
     This tool uses OpenAI's web search capabilities to find current information
@@ -54,8 +54,8 @@ Instructions:
 ALWAYS PROVIDE THE SOURCES (urls) TO BACK UP THE INFORMATION YOU PROVIDE.
 """
 
-    client =
-    response = client.responses.create(  # type: ignore[call-overload]
+    client = AsyncOpenAI(api_key=api_key)
+    response = await client.responses.create(  # type: ignore[call-overload]
         model="gpt-5-mini",
         input=[
             {"role": "user", "content": [{"type": "input_text", "text": prompt}]}