shotgun-sh 0.1.15.dev1__py3-none-any.whl → 0.2.1.dev1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of shotgun-sh might be problematic.
- shotgun/agents/common.py +4 -5
- shotgun/agents/config/constants.py +21 -5
- shotgun/agents/config/manager.py +147 -39
- shotgun/agents/config/models.py +59 -86
- shotgun/agents/config/provider.py +164 -61
- shotgun/agents/history/compaction.py +1 -1
- shotgun/agents/history/history_processors.py +18 -9
- shotgun/agents/history/token_counting/__init__.py +31 -0
- shotgun/agents/history/token_counting/anthropic.py +89 -0
- shotgun/agents/history/token_counting/base.py +67 -0
- shotgun/agents/history/token_counting/openai.py +80 -0
- shotgun/agents/history/token_counting/sentencepiece_counter.py +119 -0
- shotgun/agents/history/token_counting/tokenizer_cache.py +90 -0
- shotgun/agents/history/token_counting/utils.py +147 -0
- shotgun/agents/history/token_estimation.py +12 -12
- shotgun/agents/llm.py +62 -0
- shotgun/agents/models.py +2 -2
- shotgun/agents/tools/web_search/__init__.py +42 -15
- shotgun/agents/tools/web_search/anthropic.py +46 -40
- shotgun/agents/tools/web_search/gemini.py +31 -20
- shotgun/agents/tools/web_search/openai.py +4 -4
- shotgun/cli/config.py +14 -55
- shotgun/cli/models.py +2 -2
- shotgun/codebase/models.py +4 -4
- shotgun/llm_proxy/__init__.py +16 -0
- shotgun/llm_proxy/clients.py +39 -0
- shotgun/llm_proxy/constants.py +8 -0
- shotgun/main.py +6 -0
- shotgun/posthog_telemetry.py +5 -3
- shotgun/tui/app.py +2 -0
- shotgun/tui/screens/chat_screen/command_providers.py +20 -0
- shotgun/tui/screens/model_picker.py +214 -0
- shotgun/tui/screens/provider_config.py +39 -26
- {shotgun_sh-0.1.15.dev1.dist-info → shotgun_sh-0.2.1.dev1.dist-info}/METADATA +2 -2
- {shotgun_sh-0.1.15.dev1.dist-info → shotgun_sh-0.2.1.dev1.dist-info}/RECORD +38 -27
- shotgun/agents/history/token_counting.py +0 -429
- {shotgun_sh-0.1.15.dev1.dist-info → shotgun_sh-0.2.1.dev1.dist-info}/WHEEL +0 -0
- {shotgun_sh-0.1.15.dev1.dist-info → shotgun_sh-0.2.1.dev1.dist-info}/entry_points.txt +0 -0
- {shotgun_sh-0.1.15.dev1.dist-info → shotgun_sh-0.2.1.dev1.dist-info}/licenses/LICENSE +0 -0
shotgun/agents/config/provider.py

@@ -12,27 +12,42 @@ from pydantic_ai.providers.google import GoogleProvider
 from pydantic_ai.providers.openai import OpenAIProvider
 from pydantic_ai.settings import ModelSettings
 
+from shotgun.llm_proxy import create_litellm_provider
 from shotgun.logging_config import get_logger
 
 from .constants import (
     ANTHROPIC_API_KEY_ENV,
     GEMINI_API_KEY_ENV,
     OPENAI_API_KEY_ENV,
+    SHOTGUN_API_KEY_ENV,
 )
 from .manager import get_config_manager
-from .models import
+from .models import (
+    MODEL_SPECS,
+    KeyProvider,
+    ModelConfig,
+    ModelName,
+    ProviderType,
+    ShotgunConfig,
+)
 
 logger = get_logger(__name__)
 
 # Global cache for Model instances (singleton pattern)
-_model_cache: dict[tuple[ProviderType,
+_model_cache: dict[tuple[ProviderType, KeyProvider, ModelName, str], Model] = {}
 
 
-def get_or_create_model(
+def get_or_create_model(
+    provider: ProviderType,
+    key_provider: "KeyProvider",
+    model_name: ModelName,
+    api_key: str,
+) -> Model:
     """Get or create a singleton Model instance.
 
     Args:
-        provider:
+        provider: Actual LLM provider (openai, anthropic, google)
+        key_provider: Authentication method (byok or shotgun)
         model_name: Name of the model
        api_key: API key for the provider
 
@@ -42,66 +57,88 @@ def get_or_create_model(provider: ProviderType, model_name: str, api_key: str) -
     Raises:
         ValueError: If provider is not supported
     """
-    cache_key = (provider, model_name, api_key)
+    cache_key = (provider, key_provider, model_name, api_key)
 
     if cache_key not in _model_cache:
-        logger.debug(
+        logger.debug(
+            "Creating new %s model instance via %s: %s",
+            provider.value,
+            key_provider.value,
+            model_name,
+        )
 
-
-
+        # Get max_tokens from MODEL_SPECS
+        if model_name in MODEL_SPECS:
+            max_tokens = MODEL_SPECS[model_name].max_output_tokens
+        else:
+            # Fallback defaults based on provider
+            max_tokens = {
+                ProviderType.OPENAI: 16_000,
+                ProviderType.ANTHROPIC: 32_000,
+                ProviderType.GOOGLE: 64_000,
+            }.get(provider, 16_000)
+
+        # Use LiteLLM proxy for Shotgun Account, native providers for BYOK
+        if key_provider == KeyProvider.SHOTGUN:
+            # Shotgun Account uses LiteLLM proxy for any model
             if model_name in MODEL_SPECS:
-
+                litellm_model_name = MODEL_SPECS[model_name].litellm_proxy_model_name
             else:
-
+                # Fallback for unmapped models
+                litellm_model_name = f"openai/{model_name.value}"
 
-
+            litellm_provider = create_litellm_provider(api_key)
             _model_cache[cache_key] = OpenAIChatModel(
-
-                provider=
+                litellm_model_name,
+                provider=litellm_provider,
                 settings=ModelSettings(max_tokens=max_tokens),
             )
-        elif
-        #
-        if
-
+        elif key_provider == KeyProvider.BYOK:
+            # Use native provider implementations with user's API keys
+            if provider == ProviderType.OPENAI:
+                openai_provider = OpenAIProvider(api_key=api_key)
+                _model_cache[cache_key] = OpenAIChatModel(
+                    model_name,
+                    provider=openai_provider,
+                    settings=ModelSettings(max_tokens=max_tokens),
+                )
+            elif provider == ProviderType.ANTHROPIC:
+                anthropic_provider = AnthropicProvider(api_key=api_key)
+                _model_cache[cache_key] = AnthropicModel(
+                    model_name,
+                    provider=anthropic_provider,
+                    settings=AnthropicModelSettings(
+                        max_tokens=max_tokens,
+                        timeout=600,  # 10 minutes timeout for large responses
+                    ),
+                )
+            elif provider == ProviderType.GOOGLE:
+                google_provider = GoogleProvider(api_key=api_key)
+                _model_cache[cache_key] = GoogleModel(
+                    model_name,
+                    provider=google_provider,
+                    settings=ModelSettings(max_tokens=max_tokens),
+                )
             else:
-
-
-            anthropic_provider = AnthropicProvider(api_key=api_key)
-            _model_cache[cache_key] = AnthropicModel(
-                model_name,
-                provider=anthropic_provider,
-                settings=AnthropicModelSettings(
-                    max_tokens=max_tokens,
-                    timeout=600,  # 10 minutes timeout for large responses
-                ),
-            )
-        elif provider == ProviderType.GOOGLE:
-            # Get max_tokens from MODEL_SPECS to use full capacity
-            if model_name in MODEL_SPECS:
-                max_tokens = MODEL_SPECS[model_name].max_output_tokens
-            else:
-                max_tokens = 64_000  # Default for Gemini models
-
-            google_provider = GoogleProvider(api_key=api_key)
-            _model_cache[cache_key] = GoogleModel(
-                model_name,
-                provider=google_provider,
-                settings=ModelSettings(max_tokens=max_tokens),
-            )
+                raise ValueError(f"Unsupported provider: {provider}")
         else:
-            raise ValueError(f"Unsupported provider: {
+            raise ValueError(f"Unsupported key provider: {key_provider}")
     else:
         logger.debug("Reusing cached %s model instance: %s", provider.value, model_name)
 
     return _model_cache[cache_key]
 
 
-def get_provider_model(
+def get_provider_model(
+    provider_or_model: ProviderType | ModelName | None = None,
+) -> ModelConfig:
     """Get a fully configured ModelConfig with API key and Model instance.
 
     Args:
-
+        provider_or_model: Either a ProviderType, ModelName, or None.
+            - If ModelName: returns that specific model with appropriate API key
+            - If ProviderType: returns default model for that provider (backward compatible)
+            - If None: uses default provider with its default model
 
     Returns:
         ModelConfig with API key configured and lazy Model instance
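
Note on the cache change above: the key now includes the authentication path, so the same model reached via BYOK and via a Shotgun Account is cached as two separate instances. A minimal standalone sketch of the pattern; the enum and factory below are illustrative stand-ins, not the package's actual classes:

```python
from enum import Enum


class KeyProvider(Enum):  # illustrative stand-in for shotgun's KeyProvider enum
    BYOK = "byok"
    SHOTGUN = "shotgun"


# Cache keyed on (provider, key_provider, model_name, api_key), mirroring _model_cache.
_cache: dict[tuple[str, KeyProvider, str, str], object] = {}


def get_or_create(provider: str, key_provider: KeyProvider, model: str, api_key: str) -> object:
    key = (provider, key_provider, model, api_key)
    if key not in _cache:
        # The real function builds an OpenAIChatModel / AnthropicModel / GoogleModel here.
        _cache[key] = object()
    return _cache[key]


# Same model and key string, different auth path -> two distinct cached instances.
a = get_or_create("anthropic", KeyProvider.BYOK, "claude-opus-4-1", "key")
b = get_or_create("anthropic", KeyProvider.SHOTGUN, "claude-opus-4-1", "key")
assert a is not b
```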
@@ -111,14 +148,58 @@ def get_provider_model(provider: ProviderType | None = None) -> ModelConfig:
     """
     config_manager = get_config_manager()
     config = config_manager.load()
-
-
-
-
-
-
-
-
+
+    # Priority 1: Check if Shotgun key exists - if so, use it for ANY model
+    shotgun_api_key = _get_api_key(config.shotgun.api_key, SHOTGUN_API_KEY_ENV)
+    if shotgun_api_key:
+        # Use selected model or default to claude-opus-4-1
+        model_name = config.selected_model or ModelName.CLAUDE_OPUS_4_1
+        if model_name not in MODEL_SPECS:
+            raise ValueError(f"Model '{model_name.value}' not found")
+        spec = MODEL_SPECS[model_name]
+
+        # Use Shotgun Account with selected model (provider = actual LLM provider)
+        return ModelConfig(
+            name=spec.name,
+            provider=spec.provider,  # Actual LLM provider (OPENAI/ANTHROPIC/GOOGLE)
+            key_provider=KeyProvider.SHOTGUN,  # Authenticated via Shotgun Account
+            max_input_tokens=spec.max_input_tokens,
+            max_output_tokens=spec.max_output_tokens,
+            api_key=shotgun_api_key,
+        )
+
+    # Priority 2: Fall back to individual provider keys
+
+    # Check if a specific model was requested
+    if isinstance(provider_or_model, ModelName):
+        # Look up the model spec
+        if provider_or_model not in MODEL_SPECS:
+            raise ValueError(f"Model '{provider_or_model.value}' not found")
+        spec = MODEL_SPECS[provider_or_model]
+        provider_enum = spec.provider
+        requested_model = provider_or_model
+    else:
+        # Convert string to ProviderType enum if needed (backward compatible)
+        if provider_or_model:
+            provider_enum = (
+                provider_or_model
+                if isinstance(provider_or_model, ProviderType)
+                else ProviderType(provider_or_model)
+            )
+        else:
+            # No provider specified - find first available provider with a key
+            provider_enum = None
+            for provider in ProviderType:
+                if _has_provider_key(config, provider):
+                    provider_enum = provider
+                    break
+
+            if provider_enum is None:
+                raise ValueError(
+                    "No provider keys configured. Set via environment variables or config."
+                )
+
+        requested_model = None  # Will use provider's default model
 
     if provider_enum == ProviderType.OPENAI:
         api_key = _get_api_key(config.openai.api_key, OPENAI_API_KEY_ENV)
@@ -127,16 +208,17 @@ def get_provider_model(provider: ProviderType | None = None) -> ModelConfig:
                 f"OpenAI API key not configured. Set via environment variable {OPENAI_API_KEY_ENV} or config."
             )
 
-        #
-        model_name =
+        # Use requested model or default to gpt-5
+        model_name = requested_model if requested_model else ModelName.GPT_5
         if model_name not in MODEL_SPECS:
-            raise ValueError(f"Model '{model_name}' not found")
+            raise ValueError(f"Model '{model_name.value}' not found")
         spec = MODEL_SPECS[model_name]
 
         # Create fully configured ModelConfig
         return ModelConfig(
             name=spec.name,
             provider=spec.provider,
+            key_provider=KeyProvider.BYOK,
             max_input_tokens=spec.max_input_tokens,
             max_output_tokens=spec.max_output_tokens,
             api_key=api_key,
@@ -149,16 +231,17 @@ def get_provider_model(provider: ProviderType | None = None) -> ModelConfig:
                 f"Anthropic API key not configured. Set via environment variable {ANTHROPIC_API_KEY_ENV} or config."
             )
 
-        #
-        model_name =
+        # Use requested model or default to claude-opus-4-1
+        model_name = requested_model if requested_model else ModelName.CLAUDE_OPUS_4_1
         if model_name not in MODEL_SPECS:
-            raise ValueError(f"Model '{model_name}' not found")
+            raise ValueError(f"Model '{model_name.value}' not found")
         spec = MODEL_SPECS[model_name]
 
         # Create fully configured ModelConfig
         return ModelConfig(
             name=spec.name,
             provider=spec.provider,
+            key_provider=KeyProvider.BYOK,
             max_input_tokens=spec.max_input_tokens,
             max_output_tokens=spec.max_output_tokens,
             api_key=api_key,
@@ -171,16 +254,17 @@ def get_provider_model(provider: ProviderType | None = None) -> ModelConfig:
                 f"Gemini API key not configured. Set via environment variable {GEMINI_API_KEY_ENV} or config."
             )
 
-        #
-        model_name =
+        # Use requested model or default to gemini-2.5-pro
+        model_name = requested_model if requested_model else ModelName.GEMINI_2_5_PRO
         if model_name not in MODEL_SPECS:
-            raise ValueError(f"Model '{model_name}' not found")
+            raise ValueError(f"Model '{model_name.value}' not found")
         spec = MODEL_SPECS[model_name]
 
         # Create fully configured ModelConfig
         return ModelConfig(
             name=spec.name,
             provider=spec.provider,
+            key_provider=KeyProvider.BYOK,
             max_input_tokens=spec.max_input_tokens,
             max_output_tokens=spec.max_output_tokens,
             api_key=api_key,
@@ -190,6 +274,25 @@ def get_provider_model(provider: ProviderType | None = None) -> ModelConfig:
     raise ValueError(f"Unsupported provider: {provider_enum}")
 
 
+def _has_provider_key(config: "ShotgunConfig", provider: ProviderType) -> bool:
+    """Check if a provider has a configured API key.
+
+    Args:
+        config: Shotgun configuration
+        provider: Provider to check
+
+    Returns:
+        True if provider has a configured API key
+    """
+    if provider == ProviderType.OPENAI:
+        return bool(_get_api_key(config.openai.api_key, OPENAI_API_KEY_ENV))
+    elif provider == ProviderType.ANTHROPIC:
+        return bool(_get_api_key(config.anthropic.api_key, ANTHROPIC_API_KEY_ENV))
+    elif provider == ProviderType.GOOGLE:
+        return bool(_get_api_key(config.google.api_key, GEMINI_API_KEY_ENV))
+    return False
+
+
 def _get_api_key(config_key: SecretStr | None, env_var: str) -> str | None:
     """Get API key from config or environment variable.
 
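
Taken together, the provider.py changes give get_provider_model a clear resolution order: a configured Shotgun key wins for any model, a requested ModelName picks its own provider, a bare ProviderType falls back to that provider's default model (gpt-5, claude-opus-4-1, or gemini-2.5-pro), and with no argument the first provider that has a key is used. A hedged usage sketch, assuming shotgun-sh is installed and at least one key is configured; the import paths mirror the file layout shown in this diff:

```python
from shotgun.agents.config.models import ModelName, ProviderType
from shotgun.agents.config.provider import get_provider_model

# Ask for a specific model; its provider is looked up in MODEL_SPECS.
cfg = get_provider_model(ModelName.GPT_5)

# Ask for a provider; its default model is used (backward-compatible path).
cfg = get_provider_model(ProviderType.ANTHROPIC)

# No argument: Shotgun key if present, otherwise the first provider with a key.
cfg = get_provider_model()
print(cfg.name, cfg.provider, cfg.key_provider)
```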
shotgun/agents/history/compaction.py

@@ -31,7 +31,7 @@ async def apply_persistent_compaction(
 
     try:
         # Count actual token usage using shared utility
-        estimated_tokens = estimate_tokens_from_messages(messages, deps.llm_model)
+        estimated_tokens = await estimate_tokens_from_messages(messages, deps.llm_model)
 
         # Create minimal usage info for compaction check
         usage = RequestUsage(
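
The one-line change above matters for callers: estimate_tokens_from_messages is now a coroutine (it can reach out to provider token-counting APIs), so every call site must await it. A hedged sketch of a post-0.2.x call site, assuming the function lives in shotgun/agents/history/token_estimation.py (that file is touched in this release, but its imports are not shown here); the message list and model config are placeholders supplied by the caller:

```python
from shotgun.agents.history.token_estimation import estimate_tokens_from_messages


async def check_budget(messages, llm_model) -> int:
    # Before 0.2.x this was a plain synchronous call; it now must be awaited.
    return await estimate_tokens_from_messages(messages, llm_model)
```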
shotgun/agents/history/history_processors.py

@@ -2,6 +2,7 @@
 
 from typing import TYPE_CHECKING, Any, Protocol
 
+from pydantic_ai import ModelSettings
 from pydantic_ai.messages import (
     ModelMessage,
     ModelRequest,
@@ -10,7 +11,7 @@ from pydantic_ai.messages import (
     UserPromptPart,
 )
 
-from shotgun.agents.
+from shotgun.agents.llm import shotgun_model_request
 from shotgun.agents.messages import AgentSystemPrompt, SystemStatusPrompt
 from shotgun.agents.models import AgentDeps
 from shotgun.logging_config import get_logger
@@ -154,7 +155,7 @@ async def token_limit_compactor(
 
     if last_summary_index is not None:
         # Check if post-summary conversation exceeds threshold for incremental compaction
-        post_summary_tokens = estimate_post_summary_tokens(
+        post_summary_tokens = await estimate_post_summary_tokens(
             messages, last_summary_index, deps.llm_model
         )
         post_summary_percentage = (
@@ -248,7 +249,7 @@ async def token_limit_compactor(
         ]
 
         # Calculate optimal max_tokens for summarization
-        max_tokens = calculate_max_summarization_tokens(
+        max_tokens = await calculate_max_summarization_tokens(
             deps.llm_model, request_messages
         )
 
@@ -261,7 +262,9 @@ async def token_limit_compactor(
         summary_response = await shotgun_model_request(
             model_config=deps.llm_model,
             messages=request_messages,
-
+            model_settings=ModelSettings(
+                max_tokens=max_tokens  # Use calculated optimal tokens for summarization
+            ),
         )
 
         log_summarization_response(summary_response, "INCREMENTAL")
@@ -328,7 +331,9 @@ async def token_limit_compactor(
 
         # Track compaction completion
         messages_after = len(compacted_messages)
-        tokens_after = estimate_tokens_from_messages(
+        tokens_after = await estimate_tokens_from_messages(
+            compacted_messages, deps.llm_model
+        )
         reduction_percentage = (
             ((messages_before - messages_after) / messages_before * 100)
             if messages_before > 0
@@ -354,7 +359,7 @@ async def token_limit_compactor(
 
     else:
         # Check if total conversation exceeds threshold for full compaction
-        total_tokens = estimate_tokens_from_messages(messages, deps.llm_model)
+        total_tokens = await estimate_tokens_from_messages(messages, deps.llm_model)
         total_percentage = (total_tokens / max_tokens) * 100 if max_tokens > 0 else 0
 
         logger.debug(
@@ -392,7 +397,9 @@ async def _full_compaction(
     ]
 
     # Calculate optimal max_tokens for summarization
-    max_tokens = calculate_max_summarization_tokens(
+    max_tokens = await calculate_max_summarization_tokens(
+        deps.llm_model, request_messages
+    )
 
     # Debug logging using shared utilities
     log_summarization_request(
@@ -403,11 +410,13 @@ async def _full_compaction(
     summary_response = await shotgun_model_request(
         model_config=deps.llm_model,
         messages=request_messages,
-
+        model_settings=ModelSettings(
+            max_tokens=max_tokens  # Use calculated optimal tokens for summarization
+        ),
     )
 
     # Calculate token reduction
-    current_tokens = estimate_tokens_from_messages(messages, deps.llm_model)
+    current_tokens = await estimate_tokens_from_messages(messages, deps.llm_model)
     summary_usage = summary_response.usage
     reduction_percentage = (
         ((current_tokens - summary_usage.output_tokens) / current_tokens) * 100
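
Both summarization call sites now pass an explicit pydantic-ai ModelSettings carrying the computed max_tokens instead of relying on the model default. A minimal sketch of that settings object on its own; the numeric value is a placeholder where the compactor computes an optimal budget:

```python
from pydantic_ai import ModelSettings

# The compactor derives this from the model's output budget; 4_096 is a placeholder.
settings = ModelSettings(max_tokens=4_096)

# The settings are then forwarded to the request, e.g. (signature per this diff):
# await shotgun_model_request(model_config=..., messages=..., model_settings=settings)
```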
shotgun/agents/history/token_counting/__init__.py (new file)

@@ -0,0 +1,31 @@
+"""Real token counting for all supported providers.
+
+This module provides accurate token counting using each provider's official
+APIs and libraries, eliminating the need for rough character-based estimation.
+"""
+
+from .anthropic import AnthropicTokenCounter
+from .base import TokenCounter, extract_text_from_messages
+from .openai import OpenAITokenCounter
+from .sentencepiece_counter import SentencePieceTokenCounter
+from .utils import (
+    count_post_summary_tokens,
+    count_tokens_from_message_parts,
+    count_tokens_from_messages,
+    get_token_counter,
+)
+
+__all__ = [
+    # Base classes
+    "TokenCounter",
+    # Counter implementations
+    "OpenAITokenCounter",
+    "AnthropicTokenCounter",
+    "SentencePieceTokenCounter",
+    # Utility functions
+    "get_token_counter",
+    "count_tokens_from_messages",
+    "count_post_summary_tokens",
+    "count_tokens_from_message_parts",
+    "extract_text_from_messages",
+]
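
This new package replaces the old single-module shotgun/agents/history/token_counting.py (removed in this release, -429 lines) with an abstract TokenCounter plus three concrete counters, so calling code can stay provider-agnostic. A hedged sketch of that interface use; the concrete counter instance and message list are supplied by the caller:

```python
from pydantic_ai.messages import ModelMessage

from shotgun.agents.history.token_counting import TokenCounter


async def messages_fit(counter: TokenCounter, messages: list[ModelMessage], budget: int) -> bool:
    # Works with OpenAITokenCounter, AnthropicTokenCounter, or SentencePieceTokenCounter:
    # all implement the async count_message_tokens contract defined by the base class.
    used = await counter.count_message_tokens(messages)
    return used <= budget
```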
shotgun/agents/history/token_counting/anthropic.py (new file)

@@ -0,0 +1,89 @@
+"""Anthropic token counting using official client."""
+
+from pydantic_ai.messages import ModelMessage
+
+from shotgun.agents.config.models import KeyProvider
+from shotgun.llm_proxy import create_anthropic_proxy_client
+from shotgun.logging_config import get_logger
+
+from .base import TokenCounter, extract_text_from_messages
+
+logger = get_logger(__name__)
+
+
+class AnthropicTokenCounter(TokenCounter):
+    """Token counter for Anthropic models using official client."""
+
+    def __init__(
+        self,
+        model_name: str,
+        api_key: str,
+        key_provider: KeyProvider = KeyProvider.BYOK,
+    ):
+        """Initialize Anthropic token counter.
+
+        Args:
+            model_name: Anthropic model name for token counting
+            api_key: API key (Anthropic for BYOK, Shotgun for proxy)
+            key_provider: Key provider type (BYOK or SHOTGUN)
+
+        Raises:
+            RuntimeError: If client initialization fails
+        """
+        self.model_name = model_name
+        import anthropic
+
+        try:
+            if key_provider == KeyProvider.SHOTGUN:
+                # Use LiteLLM proxy for Shotgun Account
+                # Proxies to Anthropic's token counting API
+                self.client = create_anthropic_proxy_client(api_key)
+                logger.debug(
+                    f"Initialized Anthropic token counter for {model_name} via LiteLLM proxy"
+                )
+            else:
+                # Direct Anthropic API for BYOK
+                self.client = anthropic.Anthropic(api_key=api_key)
+                logger.debug(
+                    f"Initialized Anthropic token counter for {model_name} via direct API"
+                )
+        except Exception as e:
+            raise RuntimeError("Failed to initialize Anthropic client") from e
+
+    async def count_tokens(self, text: str) -> int:
+        """Count tokens using Anthropic's official API (async).
+
+        Args:
+            text: Text to count tokens for
+
+        Returns:
+            Exact token count from Anthropic API
+
+        Raises:
+            RuntimeError: If API call fails
+        """
+        try:
+            # Anthropic API expects messages format and model parameter
+            result = self.client.messages.count_tokens(
+                messages=[{"role": "user", "content": text}], model=self.model_name
+            )
+            return result.input_tokens
+        except Exception as e:
+            raise RuntimeError(
+                f"Anthropic token counting API failed for {self.model_name}"
+            ) from e
+
+    async def count_message_tokens(self, messages: list[ModelMessage]) -> int:
+        """Count tokens across all messages using Anthropic API (async).
+
+        Args:
+            messages: List of PydanticAI messages
+
+        Returns:
+            Total token count for all messages
+
+        Raises:
+            RuntimeError: If token counting fails
+        """
+        total_text = extract_text_from_messages(messages)
+        return await self.count_tokens(total_text)
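
A hedged construction sketch for the counter above: the BYOK path talks to Anthropic's count_tokens endpoint directly, while the Shotgun path routes through the LiteLLM proxy client. Both need real keys to run, and the model id shown is illustrative:

```python
import asyncio

from shotgun.agents.config.models import KeyProvider
from shotgun.agents.history.token_counting import AnthropicTokenCounter


async def main() -> None:
    # BYOK: direct Anthropic client with the user's own key (placeholder below).
    byok = AnthropicTokenCounter(model_name="claude-opus-4-1", api_key="sk-ant-...")

    # Shotgun Account: same counter, but routed through the LiteLLM proxy.
    proxied = AnthropicTokenCounter(
        model_name="claude-opus-4-1",
        api_key="shotgun-account-key",  # placeholder
        key_provider=KeyProvider.SHOTGUN,
    )

    print(await byok.count_tokens("hello world"))
    print(await proxied.count_tokens("hello world"))


# asyncio.run(main())  # replace the placeholder keys before running
```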
shotgun/agents/history/token_counting/base.py (new file)

@@ -0,0 +1,67 @@
+"""Base classes and shared utilities for token counting."""
+
+from abc import ABC, abstractmethod
+
+from pydantic_ai.messages import ModelMessage
+
+
+class TokenCounter(ABC):
+    """Abstract base class for provider-specific token counting.
+
+    All methods are async to support non-blocking operations like
+    downloading tokenizer models or making API calls.
+    """
+
+    @abstractmethod
+    async def count_tokens(self, text: str) -> int:
+        """Count tokens in text using provider-specific method (async).
+
+        Args:
+            text: Text to count tokens for
+
+        Returns:
+            Exact token count as determined by the provider
+
+        Raises:
+            RuntimeError: If token counting fails
+        """
+
+    @abstractmethod
+    async def count_message_tokens(self, messages: list[ModelMessage]) -> int:
+        """Count tokens in PydanticAI message structures (async).
+
+        Args:
+            messages: List of messages to count tokens for
+
+        Returns:
+            Total token count across all messages
+
+        Raises:
+            RuntimeError: If token counting fails
+        """
+
+
+def extract_text_from_messages(messages: list[ModelMessage]) -> str:
+    """Extract all text content from messages for token counting.
+
+    Args:
+        messages: List of PydanticAI messages
+
+    Returns:
+        Combined text content from all messages
+    """
+    text_parts = []
+
+    for message in messages:
+        if hasattr(message, "parts"):
+            for part in message.parts:
+                if hasattr(part, "content") and isinstance(part.content, str):
+                    text_parts.append(part.content)
+                else:
+                    # Handle non-text parts (tool calls, etc.)
+                    text_parts.append(str(part))
+        else:
+            # Handle messages without parts
+            text_parts.append(str(message))
+
+    return "\n".join(text_parts)
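
Because TokenCounter is a plain ABC with two async methods, adding another backend is just a subclass. A rough illustrative example, assuming the shotgun package is importable; the whitespace heuristic below is for demonstration only and is nothing like the real counters:

```python
from pydantic_ai.messages import ModelMessage

from shotgun.agents.history.token_counting import TokenCounter, extract_text_from_messages


class WhitespaceTokenCounter(TokenCounter):
    """Toy counter: treats whitespace-separated words as tokens."""

    async def count_tokens(self, text: str) -> int:
        return len(text.split())

    async def count_message_tokens(self, messages: list[ModelMessage]) -> int:
        # Reuse the shared text-extraction helper, as the real counters do.
        return await self.count_tokens(extract_text_from_messages(messages))
```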