shotgun-sh 0.1.16.dev1__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of shotgun-sh might be problematic. Click here for more details.

Files changed (43) hide show
  1. shotgun/agents/common.py +4 -5
  2. shotgun/agents/config/constants.py +21 -5
  3. shotgun/agents/config/manager.py +171 -63
  4. shotgun/agents/config/models.py +65 -84
  5. shotgun/agents/config/provider.py +174 -85
  6. shotgun/agents/history/compaction.py +1 -1
  7. shotgun/agents/history/history_processors.py +18 -9
  8. shotgun/agents/history/token_counting/__init__.py +31 -0
  9. shotgun/agents/history/token_counting/anthropic.py +89 -0
  10. shotgun/agents/history/token_counting/base.py +67 -0
  11. shotgun/agents/history/token_counting/openai.py +80 -0
  12. shotgun/agents/history/token_counting/sentencepiece_counter.py +119 -0
  13. shotgun/agents/history/token_counting/tokenizer_cache.py +90 -0
  14. shotgun/agents/history/token_counting/utils.py +147 -0
  15. shotgun/agents/history/token_estimation.py +12 -12
  16. shotgun/agents/llm.py +62 -0
  17. shotgun/agents/models.py +2 -2
  18. shotgun/agents/tools/web_search/__init__.py +42 -15
  19. shotgun/agents/tools/web_search/anthropic.py +54 -50
  20. shotgun/agents/tools/web_search/gemini.py +31 -20
  21. shotgun/agents/tools/web_search/openai.py +4 -4
  22. shotgun/build_constants.py +2 -2
  23. shotgun/cli/config.py +28 -57
  24. shotgun/cli/models.py +2 -2
  25. shotgun/codebase/models.py +4 -4
  26. shotgun/llm_proxy/__init__.py +16 -0
  27. shotgun/llm_proxy/clients.py +39 -0
  28. shotgun/llm_proxy/constants.py +8 -0
  29. shotgun/main.py +6 -0
  30. shotgun/posthog_telemetry.py +5 -3
  31. shotgun/tui/app.py +7 -3
  32. shotgun/tui/screens/chat.py +15 -10
  33. shotgun/tui/screens/chat_screen/command_providers.py +118 -11
  34. shotgun/tui/screens/chat_screen/history.py +3 -1
  35. shotgun/tui/screens/model_picker.py +327 -0
  36. shotgun/tui/screens/provider_config.py +57 -26
  37. shotgun/utils/env_utils.py +12 -0
  38. {shotgun_sh-0.1.16.dev1.dist-info → shotgun_sh-0.2.0.dist-info}/METADATA +2 -2
  39. {shotgun_sh-0.1.16.dev1.dist-info → shotgun_sh-0.2.0.dist-info}/RECORD +42 -31
  40. shotgun/agents/history/token_counting.py +0 -429
  41. {shotgun_sh-0.1.16.dev1.dist-info → shotgun_sh-0.2.0.dist-info}/WHEEL +0 -0
  42. {shotgun_sh-0.1.16.dev1.dist-info → shotgun_sh-0.2.0.dist-info}/entry_points.txt +0 -0
  43. {shotgun_sh-0.1.16.dev1.dist-info → shotgun_sh-0.2.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,80 @@
1
+ """OpenAI token counting using tiktoken."""
2
+
3
+ from pydantic_ai.messages import ModelMessage
4
+
5
+ from shotgun.logging_config import get_logger
6
+
7
+ from .base import TokenCounter, extract_text_from_messages
8
+
9
+ logger = get_logger(__name__)
10
+
11
+
12
+ class OpenAITokenCounter(TokenCounter):
13
+ """Token counter for OpenAI models using tiktoken."""
14
+
15
+ # Official encoding mappings for OpenAI models
16
+ ENCODING_MAP = {
17
+ "gpt-5": "o200k_base",
18
+ "gpt-4o": "o200k_base",
19
+ "gpt-4": "cl100k_base",
20
+ "gpt-3.5-turbo": "cl100k_base",
21
+ }
22
+
23
+ def __init__(self, model_name: str):
24
+ """Initialize OpenAI token counter.
25
+
26
+ Args:
27
+ model_name: OpenAI model name to get correct encoding for
28
+
29
+ Raises:
30
+ RuntimeError: If encoding initialization fails
31
+ """
32
+ self.model_name = model_name
33
+
34
+ import tiktoken
35
+
36
+ try:
37
+ # Get the appropriate encoding for this model
38
+ encoding_name = self.ENCODING_MAP.get(model_name, "o200k_base")
39
+ self.encoding = tiktoken.get_encoding(encoding_name)
40
+ logger.debug(
41
+ f"Initialized OpenAI token counter with {encoding_name} encoding"
42
+ )
43
+ except Exception as e:
44
+ raise RuntimeError(
45
+ f"Failed to initialize tiktoken encoding for {model_name}"
46
+ ) from e
47
+
48
+ async def count_tokens(self, text: str) -> int:
49
+ """Count tokens using tiktoken (async).
50
+
51
+ Args:
52
+ text: Text to count tokens for
53
+
54
+ Returns:
55
+ Exact token count using tiktoken
56
+
57
+ Raises:
58
+ RuntimeError: If token counting fails
59
+ """
60
+ try:
61
+ return len(self.encoding.encode(text))
62
+ except Exception as e:
63
+ raise RuntimeError(
64
+ f"Failed to count tokens for OpenAI model {self.model_name}"
65
+ ) from e
66
+
67
+ async def count_message_tokens(self, messages: list[ModelMessage]) -> int:
68
+ """Count tokens across all messages using tiktoken (async).
69
+
70
+ Args:
71
+ messages: List of PydanticAI messages
72
+
73
+ Returns:
74
+ Total token count for all messages
75
+
76
+ Raises:
77
+ RuntimeError: If token counting fails
78
+ """
79
+ total_text = extract_text_from_messages(messages)
80
+ return await self.count_tokens(total_text)
@@ -0,0 +1,119 @@
1
+ """Gemini token counting using official SentencePiece tokenizer.
2
+
3
+ This implementation uses Google's official Gemini/Gemma tokenizer model
4
+ for 100% accurate local token counting without API calls.
5
+
6
+ Performance: 10-100x faster than API-based counting.
7
+ Accuracy: 100% match with actual Gemini API usage.
8
+
9
+ The tokenizer is downloaded on first use and cached locally for future use.
10
+ """
11
+
12
+ from typing import Any
13
+
14
+ from pydantic_ai.messages import ModelMessage
15
+
16
+ from shotgun.logging_config import get_logger
17
+
18
+ from .base import TokenCounter, extract_text_from_messages
19
+ from .tokenizer_cache import download_gemini_tokenizer, get_gemini_tokenizer_path
20
+
21
+ logger = get_logger(__name__)
22
+
23
+
24
+ class SentencePieceTokenCounter(TokenCounter):
25
+ """Token counter for Gemini models using official SentencePiece tokenizer.
26
+
27
+ This counter provides 100% accurate token counting for Gemini models
28
+ using the official tokenizer model from Google's gemma_pytorch repository.
29
+ Token counting is performed locally without any API calls, resulting in
30
+ 10-100x performance improvement over API-based methods.
31
+
32
+ The tokenizer is downloaded asynchronously on first use and cached locally.
33
+ """
34
+
35
+ def __init__(self, model_name: str):
36
+ """Initialize Gemini SentencePiece token counter.
37
+
38
+ The tokenizer is not loaded immediately - it will be downloaded and
39
+ loaded lazily on first use.
40
+
41
+ Args:
42
+ model_name: Gemini model name (used for logging)
43
+ """
44
+ self.model_name = model_name
45
+ self.sp: Any | None = None # SentencePieceProcessor, loaded lazily
46
+
47
+ async def _ensure_tokenizer(self) -> None:
48
+ """Ensure tokenizer is downloaded and loaded.
49
+
50
+ This method downloads the tokenizer on first call (if not cached)
51
+ and loads it into memory. Subsequent calls reuse the loaded tokenizer.
52
+
53
+ Raises:
54
+ RuntimeError: If tokenizer download or loading fails
55
+ """
56
+ if self.sp is not None:
57
+ # Already loaded
58
+ return
59
+
60
+ import sentencepiece as spm # type: ignore[import-untyped]
61
+
62
+ try:
63
+ # Check if already cached, otherwise download
64
+ tokenizer_path = get_gemini_tokenizer_path()
65
+ if not tokenizer_path.exists():
66
+ await download_gemini_tokenizer()
67
+
68
+ # Load the tokenizer
69
+ self.sp = spm.SentencePieceProcessor()
70
+ self.sp.load(str(tokenizer_path))
71
+ logger.debug(f"Loaded SentencePiece tokenizer for {self.model_name}")
72
+ except Exception as e:
73
+ raise RuntimeError(
74
+ f"Failed to load Gemini tokenizer for {self.model_name}"
75
+ ) from e
76
+
77
+ async def count_tokens(self, text: str) -> int:
78
+ """Count tokens using SentencePiece (async).
79
+
80
+ Downloads tokenizer on first call if not cached.
81
+
82
+ Args:
83
+ text: Text to count tokens for
84
+
85
+ Returns:
86
+ Exact token count using Gemini's tokenizer
87
+
88
+ Raises:
89
+ RuntimeError: If token counting fails
90
+ """
91
+ await self._ensure_tokenizer()
92
+
93
+ if self.sp is None:
94
+ raise RuntimeError(f"Tokenizer not initialized for {self.model_name}")
95
+
96
+ try:
97
+ tokens = self.sp.encode(text)
98
+ return len(tokens)
99
+ except Exception as e:
100
+ raise RuntimeError(
101
+ f"Failed to count tokens for Gemini model {self.model_name}"
102
+ ) from e
103
+
104
+ async def count_message_tokens(self, messages: list[ModelMessage]) -> int:
105
+ """Count tokens across all messages using SentencePiece (async).
106
+
107
+ Downloads tokenizer on first call if not cached.
108
+
109
+ Args:
110
+ messages: List of PydanticAI messages
111
+
112
+ Returns:
113
+ Total token count for all messages
114
+
115
+ Raises:
116
+ RuntimeError: If token counting fails
117
+ """
118
+ total_text = extract_text_from_messages(messages)
119
+ return await self.count_tokens(total_text)
@@ -0,0 +1,90 @@
1
+ """Async tokenizer download and caching utilities."""
2
+
3
+ import hashlib
4
+ from pathlib import Path
5
+
6
+ import httpx
7
+
8
+ from shotgun.logging_config import get_logger
9
+ from shotgun.utils.file_system_utils import get_shotgun_home
10
+
11
+ logger = get_logger(__name__)
12
+
13
+ # Gemini tokenizer constants
14
+ GEMINI_TOKENIZER_URL = "https://raw.githubusercontent.com/google/gemma_pytorch/main/tokenizer/tokenizer.model"
15
+ GEMINI_TOKENIZER_SHA256 = (
16
+ "61a7b147390c64585d6c3543dd6fc636906c9af3865a5548f27f31aee1d4c8e2"
17
+ )
18
+
19
+
20
+ def get_tokenizer_cache_dir() -> Path:
21
+ """Get the directory for cached tokenizer models.
22
+
23
+ Returns:
24
+ Path to tokenizers cache directory
25
+ """
26
+ cache_dir = get_shotgun_home() / "tokenizers"
27
+ cache_dir.mkdir(parents=True, exist_ok=True)
28
+ return cache_dir
29
+
30
+
31
+ def get_gemini_tokenizer_path() -> Path:
32
+ """Get the path where the Gemini tokenizer should be cached.
33
+
34
+ Returns:
35
+ Path to cached Gemini tokenizer
36
+ """
37
+ return get_tokenizer_cache_dir() / "gemini_tokenizer.model"
38
+
39
+
40
+ async def download_gemini_tokenizer() -> Path:
41
+ """Download and cache the official Gemini tokenizer model.
42
+
43
+ This downloads Google's official Gemini/Gemma tokenizer from the
44
+ gemma_pytorch repository and caches it locally for future use.
45
+
46
+ The download is async and non-blocking, with SHA256 verification
47
+ for security.
48
+
49
+ Returns:
50
+ Path to the cached tokenizer file
51
+
52
+ Raises:
53
+ RuntimeError: If download fails or checksum verification fails
54
+ """
55
+ cache_path = get_gemini_tokenizer_path()
56
+
57
+ # Check if already cached
58
+ if cache_path.exists():
59
+ logger.debug(f"Gemini tokenizer already cached at {cache_path}")
60
+ return cache_path
61
+
62
+ logger.info("Downloading Gemini tokenizer (4MB, first time only)...")
63
+
64
+ try:
65
+ # Download with async httpx
66
+ async with httpx.AsyncClient(timeout=30.0) as client:
67
+ response = await client.get(GEMINI_TOKENIZER_URL, follow_redirects=True)
68
+ response.raise_for_status()
69
+ content = response.content
70
+
71
+ # Verify SHA256 checksum
72
+ actual_hash = hashlib.sha256(content).hexdigest()
73
+ if actual_hash != GEMINI_TOKENIZER_SHA256:
74
+ raise RuntimeError(
75
+ f"Gemini tokenizer checksum mismatch. "
76
+ f"Expected: {GEMINI_TOKENIZER_SHA256}, got: {actual_hash}"
77
+ )
78
+
79
+ # Atomic write: write to temp file first, then rename
80
+ temp_path = cache_path.with_suffix(".tmp")
81
+ temp_path.write_bytes(content)
82
+ temp_path.rename(cache_path)
83
+
84
+ logger.info(f"Gemini tokenizer downloaded and cached at {cache_path}")
85
+ return cache_path
86
+
87
+ except httpx.HTTPError as e:
88
+ raise RuntimeError(f"Failed to download Gemini tokenizer: {e}") from e
89
+ except OSError as e:
90
+ raise RuntimeError(f"Failed to save Gemini tokenizer: {e}") from e
@@ -0,0 +1,147 @@
1
+ """Utility functions and cache for token counting."""
2
+
3
+ from pydantic_ai.messages import ModelMessage
4
+
5
+ from shotgun.agents.config.models import ModelConfig, ProviderType
6
+ from shotgun.logging_config import get_logger
7
+
8
+ from .anthropic import AnthropicTokenCounter
9
+ from .base import TokenCounter
10
+ from .openai import OpenAITokenCounter
11
+ from .sentencepiece_counter import SentencePieceTokenCounter
12
+
13
+ logger = get_logger(__name__)
14
+
15
+ # Global cache for token counter instances (singleton pattern)
16
+ _token_counter_cache: dict[tuple[str, str, str], TokenCounter] = {}
17
+
18
+
19
+ def get_token_counter(model_config: ModelConfig) -> TokenCounter:
20
+ """Get appropriate token counter for the model provider (cached singleton).
21
+
22
+ This function ensures that every provider has a proper token counting
23
+ implementation without any fallbacks to estimation. Token counters are
24
+ cached to avoid repeated initialization overhead.
25
+
26
+ Args:
27
+ model_config: Model configuration with provider and credentials
28
+
29
+ Returns:
30
+ Cached provider-specific token counter
31
+
32
+ Raises:
33
+ ValueError: If provider is not supported for token counting
34
+ RuntimeError: If token counter initialization fails
35
+ """
36
+ # Create cache key from provider, model name, and API key
37
+ cache_key = (
38
+ model_config.provider.value,
39
+ model_config.name,
40
+ model_config.api_key[:10]
41
+ if model_config.api_key
42
+ else "no-key", # Partial key for cache
43
+ )
44
+
45
+ # Return cached instance if available
46
+ if cache_key in _token_counter_cache:
47
+ logger.debug(
48
+ f"Reusing cached token counter for {model_config.provider.value}:{model_config.name}"
49
+ )
50
+ return _token_counter_cache[cache_key]
51
+
52
+ # Create new instance and cache it
53
+ logger.debug(
54
+ f"Creating new token counter for {model_config.provider.value}:{model_config.name}"
55
+ )
56
+
57
+ counter: TokenCounter
58
+ if model_config.provider == ProviderType.OPENAI:
59
+ counter = OpenAITokenCounter(model_config.name)
60
+ elif model_config.provider == ProviderType.ANTHROPIC:
61
+ counter = AnthropicTokenCounter(
62
+ model_config.name, model_config.api_key, model_config.key_provider
63
+ )
64
+ elif model_config.provider == ProviderType.GOOGLE:
65
+ # Use local SentencePiece tokenizer (100% accurate, 10-100x faster than API)
66
+ counter = SentencePieceTokenCounter(model_config.name)
67
+ else:
68
+ raise ValueError(
69
+ f"Unsupported provider for token counting: {model_config.provider}. "
70
+ f"Supported providers: {[p.value for p in ProviderType]}"
71
+ )
72
+
73
+ # Cache the instance
74
+ _token_counter_cache[cache_key] = counter
75
+ logger.debug(
76
+ f"Cached token counter for {model_config.provider.value}:{model_config.name}"
77
+ )
78
+
79
+ return counter
80
+
81
+
82
+ async def count_tokens_from_messages(
83
+ messages: list[ModelMessage], model_config: ModelConfig
84
+ ) -> int:
85
+ """Count actual tokens from messages using provider-specific methods (async).
86
+
87
+ This replaces the old estimation approach with accurate token counting
88
+ using each provider's official APIs and libraries.
89
+
90
+ Args:
91
+ messages: List of messages to count tokens for
92
+ model_config: Model configuration with provider info
93
+
94
+ Returns:
95
+ Exact token count for the messages
96
+
97
+ Raises:
98
+ ValueError: If provider is not supported
99
+ RuntimeError: If token counting fails
100
+ """
101
+ counter = get_token_counter(model_config)
102
+ return await counter.count_message_tokens(messages)
103
+
104
+
105
+ async def count_post_summary_tokens(
106
+ messages: list[ModelMessage], summary_index: int, model_config: ModelConfig
107
+ ) -> int:
108
+ """Count actual tokens from summary onwards for incremental compaction decisions (async).
109
+
110
+ Args:
111
+ messages: Full message history
112
+ summary_index: Index of the last summary message
113
+ model_config: Model configuration with provider info
114
+
115
+ Returns:
116
+ Exact token count from summary onwards
117
+
118
+ Raises:
119
+ ValueError: If provider is not supported
120
+ RuntimeError: If token counting fails
121
+ """
122
+ if summary_index >= len(messages):
123
+ return 0
124
+
125
+ post_summary_messages = messages[summary_index:]
126
+ return await count_tokens_from_messages(post_summary_messages, model_config)
127
+
128
+
129
+ async def count_tokens_from_message_parts(
130
+ messages: list[ModelMessage], model_config: ModelConfig
131
+ ) -> int:
132
+ """Count actual tokens from message parts for summarization requests (async).
133
+
134
+ Args:
135
+ messages: List of messages to count tokens for
136
+ model_config: Model configuration with provider info
137
+
138
+ Returns:
139
+ Exact token count from message parts
140
+
141
+ Raises:
142
+ ValueError: If provider is not supported
143
+ RuntimeError: If token counting fails
144
+ """
145
+ # For now, use the same logic as count_tokens_from_messages
146
+ # This can be optimized later if needed for different counting strategies
147
+ return await count_tokens_from_messages(messages, model_config)
@@ -19,10 +19,10 @@ from .constants import INPUT_BUFFER_TOKENS, MIN_SUMMARY_TOKENS
19
19
  from .token_counting import count_tokens_from_messages as _count_tokens_from_messages
20
20
 
21
21
 
22
- def estimate_tokens_from_messages(
22
+ async def estimate_tokens_from_messages(
23
23
  messages: list[ModelMessage], model_config: ModelConfig
24
24
  ) -> int:
25
- """Count actual tokens from current message list.
25
+ """Count actual tokens from current message list (async).
26
26
 
27
27
  This provides accurate token counting for compaction decisions using
28
28
  provider-specific token counting methods instead of rough estimation.
@@ -38,13 +38,13 @@ def estimate_tokens_from_messages(
38
38
  ValueError: If provider is not supported
39
39
  RuntimeError: If token counting fails
40
40
  """
41
- return _count_tokens_from_messages(messages, model_config)
41
+ return await _count_tokens_from_messages(messages, model_config)
42
42
 
43
43
 
44
- def estimate_post_summary_tokens(
44
+ async def estimate_post_summary_tokens(
45
45
  messages: list[ModelMessage], summary_index: int, model_config: ModelConfig
46
46
  ) -> int:
47
- """Count actual tokens from summary onwards for incremental compaction decisions.
47
+ """Count actual tokens from summary onwards for incremental compaction decisions (async).
48
48
 
49
49
  This treats the summary as a reset point and only counts tokens from the summary
50
50
  message onwards. Used to determine if incremental compaction is needed.
@@ -65,13 +65,13 @@ def estimate_post_summary_tokens(
65
65
  return 0
66
66
 
67
67
  post_summary_messages = messages[summary_index:]
68
- return estimate_tokens_from_messages(post_summary_messages, model_config)
68
+ return await estimate_tokens_from_messages(post_summary_messages, model_config)
69
69
 
70
70
 
71
- def estimate_tokens_from_message_parts(
71
+ async def estimate_tokens_from_message_parts(
72
72
  messages: list[ModelMessage], model_config: ModelConfig
73
73
  ) -> int:
74
- """Count actual tokens from message parts for summarization requests.
74
+ """Count actual tokens from message parts for summarization requests (async).
75
75
 
76
76
  This provides accurate token counting across the codebase using
77
77
  provider-specific methods instead of character estimation.
@@ -87,14 +87,14 @@ def estimate_tokens_from_message_parts(
87
87
  ValueError: If provider is not supported
88
88
  RuntimeError: If token counting fails
89
89
  """
90
- return _count_tokens_from_messages(messages, model_config)
90
+ return await _count_tokens_from_messages(messages, model_config)
91
91
 
92
92
 
93
- def calculate_max_summarization_tokens(
93
+ async def calculate_max_summarization_tokens(
94
94
  ctx_or_model_config: Union["RunContext[AgentDeps]", ModelConfig],
95
95
  request_messages: list[ModelMessage],
96
96
  ) -> int:
97
- """Calculate maximum tokens available for summarization output.
97
+ """Calculate maximum tokens available for summarization output (async).
98
98
 
99
99
  This ensures we use the model's full capacity while leaving room for input tokens.
100
100
 
@@ -115,7 +115,7 @@ def calculate_max_summarization_tokens(
115
115
  return MIN_SUMMARY_TOKENS
116
116
 
117
117
  # Count actual input tokens using shared utility
118
- estimated_input_tokens = estimate_tokens_from_message_parts(
118
+ estimated_input_tokens = await estimate_tokens_from_message_parts(
119
119
  request_messages, model_config
120
120
  )
121
121
 
shotgun/agents/llm.py ADDED
@@ -0,0 +1,62 @@
1
+ """LLM request utilities for Shotgun agents."""
2
+
3
+ from typing import Any
4
+
5
+ from pydantic_ai.direct import model_request
6
+ from pydantic_ai.messages import ModelMessage, ModelResponse
7
+ from pydantic_ai.settings import ModelSettings
8
+
9
+ from shotgun.agents.config.models import ModelConfig
10
+ from shotgun.logging_config import get_logger
11
+
12
+ logger = get_logger(__name__)
13
+
14
+
15
+ async def shotgun_model_request(
16
+ model_config: ModelConfig,
17
+ messages: list[ModelMessage],
18
+ model_settings: ModelSettings | None = None,
19
+ **kwargs: Any,
20
+ ) -> ModelResponse:
21
+ """Model request wrapper that uses full token capacity by default.
22
+
23
+ This wrapper ensures all LLM calls in Shotgun use the maximum available
24
+ token capacity of each model, improving response quality and completeness.
25
+ The most common issue this fixes is truncated summaries that were cut off
26
+ at default token limits (e.g., 4096 for Claude models).
27
+
28
+ Args:
29
+ model_config: ModelConfig instance with model settings and API key
30
+ messages: Messages to send to the model
31
+ model_settings: Optional ModelSettings. If None, creates default with max tokens
32
+ **kwargs: Additional arguments passed to model_request
33
+
34
+ Returns:
35
+ ModelResponse from the model
36
+
37
+ Example:
38
+ # Uses full token capacity (e.g., 4096 for Claude, 128k for GPT-5)
39
+ response = await shotgun_model_request(model_config, messages)
40
+
41
+ # With custom settings
42
+ response = await shotgun_model_request(model_config, messages, model_settings=ModelSettings(max_tokens=1000, temperature=0.7))
43
+ """
44
+ if kwargs.get("max_tokens") is not None:
45
+ logger.warning(
46
+ "⚠️ 'max_tokens' argument is ignored in shotgun_model_request. "
47
+ "Set 'model_settings.max_tokens' instead."
48
+ )
49
+
50
+ if not model_settings:
51
+ model_settings = ModelSettings()
52
+
53
+ if model_settings.get("max_tokens") is None:
54
+ model_settings["max_tokens"] = model_config.max_output_tokens
55
+
56
+ # Make the model request with full token utilization
57
+ return await model_request(
58
+ model=model_config.model_instance,
59
+ messages=messages,
60
+ model_settings=model_settings,
61
+ **kwargs,
62
+ )
shotgun/agents/models.py CHANGED
@@ -4,7 +4,7 @@ import os
4
4
  from asyncio import Future, Queue
5
5
  from collections.abc import Callable
6
6
  from datetime import datetime
7
- from enum import Enum, StrEnum
7
+ from enum import StrEnum
8
8
  from pathlib import Path
9
9
  from typing import TYPE_CHECKING
10
10
 
@@ -116,7 +116,7 @@ class AgentRuntimeOptions(BaseModel):
116
116
  )
117
117
 
118
118
 
119
- class FileOperationType(str, Enum):
119
+ class FileOperationType(StrEnum):
120
120
  """Types of file operations that can be tracked."""
121
121
 
122
122
  CREATED = "created"
@@ -1,13 +1,14 @@
1
1
  """Web search tools for Pydantic AI agents.
2
2
 
3
3
  Provides web search capabilities for multiple LLM providers:
4
- - OpenAI: Uses Responses API with web_search tool
5
- - Anthropic: Uses Messages API with web_search_20250305 tool
6
- - Gemini: Uses grounding with Google Search
4
+ - OpenAI: Uses Responses API with web_search tool (BYOK only)
5
+ - Anthropic: Uses Messages API with web_search_20250305 tool (BYOK only)
6
+ - Gemini: Uses grounding with Google Search via Pydantic AI (works with Shotgun Account)
7
7
  """
8
8
 
9
- from collections.abc import Callable
9
+ from collections.abc import Awaitable, Callable
10
10
 
11
+ from shotgun.agents.config import get_config_manager
11
12
  from shotgun.agents.config.models import ProviderType
12
13
  from shotgun.logging_config import get_logger
13
14
 
@@ -18,29 +19,55 @@ from .utils import is_provider_available
18
19
 
19
20
  logger = get_logger(__name__)
20
21
 
21
- # Type alias for web search tools
22
- WebSearchTool = Callable[[str], str]
22
+ # Type alias for web search tools (all now async)
23
+ WebSearchTool = Callable[[str], Awaitable[str]]
23
24
 
24
25
 
25
26
  def get_available_web_search_tools() -> list[WebSearchTool]:
26
27
  """Get list of available web search tools based on configured API keys.
27
28
 
29
+ When using Shotgun Account (via LiteLLM proxy):
30
+ Only Gemini web search is available (others use provider-specific APIs)
31
+
32
+ When using BYOK (individual provider keys):
33
+ All provider tools are available based on their respective keys
34
+
28
35
  Returns:
29
36
  List of web search tool functions that have API keys configured
30
37
  """
31
38
  tools: list[WebSearchTool] = []
32
39
 
33
- if is_provider_available(ProviderType.OPENAI):
34
- logger.debug("✅ OpenAI web search tool available")
35
- tools.append(openai_web_search_tool)
40
+ # Check if using Shotgun Account
41
+ config_manager = get_config_manager()
42
+ config = config_manager.load()
43
+ has_shotgun_key = config.shotgun.api_key is not None
44
+
45
+ if has_shotgun_key:
46
+ # Shotgun Account mode: Only Gemini supports web search via LiteLLM
47
+ if is_provider_available(ProviderType.GOOGLE):
48
+ logger.info("🔑 Shotgun Account detected - using Gemini web search only")
49
+ logger.debug(" OpenAI and Anthropic web search require direct API keys")
50
+ tools.append(gemini_web_search_tool)
51
+ else:
52
+ logger.warning(
53
+ "⚠️ Shotgun Account configured but no Gemini key - "
54
+ "web search unavailable"
55
+ )
56
+ else:
57
+ # BYOK mode: Load all available tools based on individual provider keys
58
+ logger.debug("🔑 BYOK mode - checking all provider web search tools")
59
+
60
+ if is_provider_available(ProviderType.OPENAI):
61
+ logger.debug("✅ OpenAI web search tool available")
62
+ tools.append(openai_web_search_tool)
36
63
 
37
- if is_provider_available(ProviderType.ANTHROPIC):
38
- logger.debug("✅ Anthropic web search tool available")
39
- tools.append(anthropic_web_search_tool)
64
+ if is_provider_available(ProviderType.ANTHROPIC):
65
+ logger.debug("✅ Anthropic web search tool available")
66
+ tools.append(anthropic_web_search_tool)
40
67
 
41
- if is_provider_available(ProviderType.GOOGLE):
42
- logger.debug("✅ Gemini web search tool available")
43
- tools.append(gemini_web_search_tool)
68
+ if is_provider_available(ProviderType.GOOGLE):
69
+ logger.debug("✅ Gemini web search tool available")
70
+ tools.append(gemini_web_search_tool)
44
71
 
45
72
  if not tools:
46
73
  logger.warning("⚠️ No web search tools available - no API keys configured")