shotgun-sh 0.1.16.dev2__py3-none-any.whl → 0.2.1.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of shotgun-sh has been flagged as potentially problematic.
- shotgun/agents/common.py +4 -5
- shotgun/agents/config/constants.py +21 -5
- shotgun/agents/config/manager.py +149 -57
- shotgun/agents/config/models.py +65 -84
- shotgun/agents/config/provider.py +172 -84
- shotgun/agents/history/compaction.py +1 -1
- shotgun/agents/history/history_processors.py +18 -9
- shotgun/agents/history/token_counting/__init__.py +31 -0
- shotgun/agents/history/token_counting/anthropic.py +89 -0
- shotgun/agents/history/token_counting/base.py +67 -0
- shotgun/agents/history/token_counting/openai.py +80 -0
- shotgun/agents/history/token_counting/sentencepiece_counter.py +119 -0
- shotgun/agents/history/token_counting/tokenizer_cache.py +90 -0
- shotgun/agents/history/token_counting/utils.py +147 -0
- shotgun/agents/history/token_estimation.py +12 -12
- shotgun/agents/llm.py +62 -0
- shotgun/agents/models.py +2 -2
- shotgun/agents/tools/web_search/__init__.py +42 -15
- shotgun/agents/tools/web_search/anthropic.py +54 -50
- shotgun/agents/tools/web_search/gemini.py +31 -20
- shotgun/agents/tools/web_search/openai.py +4 -4
- shotgun/cli/config.py +14 -55
- shotgun/cli/models.py +2 -2
- shotgun/codebase/models.py +4 -4
- shotgun/llm_proxy/__init__.py +16 -0
- shotgun/llm_proxy/clients.py +39 -0
- shotgun/llm_proxy/constants.py +8 -0
- shotgun/main.py +6 -0
- shotgun/posthog_telemetry.py +5 -3
- shotgun/tui/app.py +2 -0
- shotgun/tui/screens/chat_screen/command_providers.py +20 -0
- shotgun/tui/screens/model_picker.py +215 -0
- shotgun/tui/screens/provider_config.py +39 -26
- {shotgun_sh-0.1.16.dev2.dist-info → shotgun_sh-0.2.1.dev2.dist-info}/METADATA +2 -2
- {shotgun_sh-0.1.16.dev2.dist-info → shotgun_sh-0.2.1.dev2.dist-info}/RECORD +38 -27
- shotgun/agents/history/token_counting.py +0 -429
- {shotgun_sh-0.1.16.dev2.dist-info → shotgun_sh-0.2.1.dev2.dist-info}/WHEEL +0 -0
- {shotgun_sh-0.1.16.dev2.dist-info → shotgun_sh-0.2.1.dev2.dist-info}/entry_points.txt +0 -0
- {shotgun_sh-0.1.16.dev2.dist-info → shotgun_sh-0.2.1.dev2.dist-info}/licenses/LICENSE +0 -0

shotgun/agents/config/provider.py

@@ -1,7 +1,5 @@
 """Provider management for LLM configuration."""
 
-import os
-
 from pydantic import SecretStr
 from pydantic_ai.models import Model
 from pydantic_ai.models.anthropic import AnthropicModel, AnthropicModelSettings
@@ -12,27 +10,36 @@ from pydantic_ai.providers.google import GoogleProvider
 from pydantic_ai.providers.openai import OpenAIProvider
 from pydantic_ai.settings import ModelSettings
 
+from shotgun.llm_proxy import create_litellm_provider
 from shotgun.logging_config import get_logger
 
-from .constants import (
-    ANTHROPIC_API_KEY_ENV,
-    GEMINI_API_KEY_ENV,
-    OPENAI_API_KEY_ENV,
-)
 from .manager import get_config_manager
-from .models import
+from .models import (
+    MODEL_SPECS,
+    KeyProvider,
+    ModelConfig,
+    ModelName,
+    ProviderType,
+    ShotgunConfig,
+)
 
 logger = get_logger(__name__)
 
 # Global cache for Model instances (singleton pattern)
-_model_cache: dict[tuple[ProviderType,
+_model_cache: dict[tuple[ProviderType, KeyProvider, ModelName, str], Model] = {}
 
 
-def get_or_create_model(
+def get_or_create_model(
+    provider: ProviderType,
+    key_provider: "KeyProvider",
+    model_name: ModelName,
+    api_key: str,
+) -> Model:
     """Get or create a singleton Model instance.
 
     Args:
-        provider:
+        provider: Actual LLM provider (openai, anthropic, google)
+        key_provider: Authentication method (byok or shotgun)
         model_name: Name of the model
         api_key: API key for the provider
 
@@ -42,66 +49,88 @@ def get_or_create_model(provider: ProviderType, model_name: str, api_key: str) -
     Raises:
         ValueError: If provider is not supported
     """
-    cache_key = (provider, model_name, api_key)
+    cache_key = (provider, key_provider, model_name, api_key)
 
     if cache_key not in _model_cache:
-        logger.debug(
+        logger.debug(
+            "Creating new %s model instance via %s: %s",
+            provider.value,
+            key_provider.value,
+            model_name,
+        )
 
-
-
+        # Get max_tokens from MODEL_SPECS
+        if model_name in MODEL_SPECS:
+            max_tokens = MODEL_SPECS[model_name].max_output_tokens
+        else:
+            # Fallback defaults based on provider
+            max_tokens = {
+                ProviderType.OPENAI: 16_000,
+                ProviderType.ANTHROPIC: 32_000,
+                ProviderType.GOOGLE: 64_000,
+            }.get(provider, 16_000)
+
+        # Use LiteLLM proxy for Shotgun Account, native providers for BYOK
+        if key_provider == KeyProvider.SHOTGUN:
+            # Shotgun Account uses LiteLLM proxy for any model
             if model_name in MODEL_SPECS:
-
+                litellm_model_name = MODEL_SPECS[model_name].litellm_proxy_model_name
             else:
-
+                # Fallback for unmapped models
+                litellm_model_name = f"openai/{model_name.value}"
 
-
+            litellm_provider = create_litellm_provider(api_key)
             _model_cache[cache_key] = OpenAIChatModel(
-
-                provider=
+                litellm_model_name,
+                provider=litellm_provider,
                 settings=ModelSettings(max_tokens=max_tokens),
             )
-        elif
-        #
-        if
-
+        elif key_provider == KeyProvider.BYOK:
+            # Use native provider implementations with user's API keys
+            if provider == ProviderType.OPENAI:
+                openai_provider = OpenAIProvider(api_key=api_key)
+                _model_cache[cache_key] = OpenAIChatModel(
+                    model_name,
+                    provider=openai_provider,
+                    settings=ModelSettings(max_tokens=max_tokens),
+                )
+            elif provider == ProviderType.ANTHROPIC:
+                anthropic_provider = AnthropicProvider(api_key=api_key)
+                _model_cache[cache_key] = AnthropicModel(
+                    model_name,
+                    provider=anthropic_provider,
+                    settings=AnthropicModelSettings(
+                        max_tokens=max_tokens,
+                        timeout=600,  # 10 minutes timeout for large responses
+                    ),
+                )
+            elif provider == ProviderType.GOOGLE:
+                google_provider = GoogleProvider(api_key=api_key)
+                _model_cache[cache_key] = GoogleModel(
+                    model_name,
+                    provider=google_provider,
+                    settings=ModelSettings(max_tokens=max_tokens),
+                )
            else:
-
-
-            anthropic_provider = AnthropicProvider(api_key=api_key)
-            _model_cache[cache_key] = AnthropicModel(
-                model_name,
-                provider=anthropic_provider,
-                settings=AnthropicModelSettings(
-                    max_tokens=max_tokens,
-                    timeout=600,  # 10 minutes timeout for large responses
-                ),
-            )
-        elif provider == ProviderType.GOOGLE:
-            # Get max_tokens from MODEL_SPECS to use full capacity
-            if model_name in MODEL_SPECS:
-                max_tokens = MODEL_SPECS[model_name].max_output_tokens
-            else:
-                max_tokens = 64_000  # Default for Gemini models
-
-            google_provider = GoogleProvider(api_key=api_key)
-            _model_cache[cache_key] = GoogleModel(
-                model_name,
-                provider=google_provider,
-                settings=ModelSettings(max_tokens=max_tokens),
-            )
+                raise ValueError(f"Unsupported provider: {provider}")
        else:
-            raise ValueError(f"Unsupported provider: {
+            raise ValueError(f"Unsupported key provider: {key_provider}")
    else:
        logger.debug("Reusing cached %s model instance: %s", provider.value, model_name)
 
    return _model_cache[cache_key]
 
 
-def get_provider_model(
+def get_provider_model(
+    provider_or_model: ProviderType | ModelName | None = None,
+) -> ModelConfig:
     """Get a fully configured ModelConfig with API key and Model instance.
 
     Args:
-
+        provider_or_model: Either a ProviderType, ModelName, or None.
+            - If ModelName: returns that specific model with appropriate API key
+            - If ProviderType: returns default model for that provider (backward compatible)
+            - If None: uses default provider with its default model
 
     Returns:
         ModelConfig with API key configured and lazy Model instance
@@ -111,76 +140,117 @@ def get_provider_model(provider: ProviderType | None = None) -> ModelConfig:
     """
     config_manager = get_config_manager()
     config = config_manager.load()
-
-
-
-
-
-
-
-
+
+    # Priority 1: Check if Shotgun key exists - if so, use it for ANY model
+    shotgun_api_key = _get_api_key(config.shotgun.api_key)
+    if shotgun_api_key:
+        # Use selected model or default to claude-sonnet-4-5
+        model_name = config.selected_model or ModelName.CLAUDE_SONNET_4_5
+        if model_name not in MODEL_SPECS:
+            raise ValueError(f"Model '{model_name.value}' not found")
+        spec = MODEL_SPECS[model_name]
+
+        # Use Shotgun Account with selected model (provider = actual LLM provider)
+        return ModelConfig(
+            name=spec.name,
+            provider=spec.provider,  # Actual LLM provider (OPENAI/ANTHROPIC/GOOGLE)
+            key_provider=KeyProvider.SHOTGUN,  # Authenticated via Shotgun Account
+            max_input_tokens=spec.max_input_tokens,
+            max_output_tokens=spec.max_output_tokens,
+            api_key=shotgun_api_key,
+        )
+
+    # Priority 2: Fall back to individual provider keys
+
+    # Check if a specific model was requested
+    if isinstance(provider_or_model, ModelName):
+        # Look up the model spec
+        if provider_or_model not in MODEL_SPECS:
+            raise ValueError(f"Model '{provider_or_model.value}' not found")
+        spec = MODEL_SPECS[provider_or_model]
+        provider_enum = spec.provider
+        requested_model = provider_or_model
+    else:
+        # Convert string to ProviderType enum if needed (backward compatible)
+        if provider_or_model:
+            provider_enum = (
+                provider_or_model
+                if isinstance(provider_or_model, ProviderType)
+                else ProviderType(provider_or_model)
+            )
+        else:
+            # No provider specified - find first available provider with a key
+            provider_enum = None
+            for provider in ProviderType:
+                if _has_provider_key(config, provider):
+                    provider_enum = provider
+                    break
+
+            if provider_enum is None:
+                raise ValueError(
+                    "No provider keys configured. Set via environment variables or config."
+                )
+
+        requested_model = None  # Will use provider's default model
 
     if provider_enum == ProviderType.OPENAI:
-        api_key = _get_api_key(config.openai.api_key
+        api_key = _get_api_key(config.openai.api_key)
         if not api_key:
-            raise ValueError(
-                f"OpenAI API key not configured. Set via environment variable {OPENAI_API_KEY_ENV} or config."
-            )
+            raise ValueError("OpenAI API key not configured. Set via config.")
 
-        #
-        model_name =
+        # Use requested model or default to gpt-5
+        model_name = requested_model if requested_model else ModelName.GPT_5
         if model_name not in MODEL_SPECS:
-            raise ValueError(f"Model '{model_name}' not found")
+            raise ValueError(f"Model '{model_name.value}' not found")
         spec = MODEL_SPECS[model_name]
 
         # Create fully configured ModelConfig
         return ModelConfig(
             name=spec.name,
             provider=spec.provider,
+            key_provider=KeyProvider.BYOK,
             max_input_tokens=spec.max_input_tokens,
             max_output_tokens=spec.max_output_tokens,
             api_key=api_key,
         )
 
     elif provider_enum == ProviderType.ANTHROPIC:
-        api_key = _get_api_key(config.anthropic.api_key
+        api_key = _get_api_key(config.anthropic.api_key)
         if not api_key:
-            raise ValueError(
-                f"Anthropic API key not configured. Set via environment variable {ANTHROPIC_API_KEY_ENV} or config."
-            )
+            raise ValueError("Anthropic API key not configured. Set via config.")
 
-        #
-        model_name =
+        # Use requested model or default to claude-opus-4-1
+        model_name = requested_model if requested_model else ModelName.CLAUDE_OPUS_4_1
         if model_name not in MODEL_SPECS:
-            raise ValueError(f"Model '{model_name}' not found")
+            raise ValueError(f"Model '{model_name.value}' not found")
         spec = MODEL_SPECS[model_name]
 
         # Create fully configured ModelConfig
         return ModelConfig(
             name=spec.name,
             provider=spec.provider,
+            key_provider=KeyProvider.BYOK,
             max_input_tokens=spec.max_input_tokens,
             max_output_tokens=spec.max_output_tokens,
             api_key=api_key,
         )
 
     elif provider_enum == ProviderType.GOOGLE:
-        api_key = _get_api_key(config.google.api_key
+        api_key = _get_api_key(config.google.api_key)
         if not api_key:
-            raise ValueError(
-                f"Gemini API key not configured. Set via environment variable {GEMINI_API_KEY_ENV} or config."
-            )
+            raise ValueError("Gemini API key not configured. Set via config.")
 
-        #
-        model_name =
+        # Use requested model or default to gemini-2.5-pro
+        model_name = requested_model if requested_model else ModelName.GEMINI_2_5_PRO
         if model_name not in MODEL_SPECS:
-            raise ValueError(f"Model '{model_name}' not found")
+            raise ValueError(f"Model '{model_name.value}' not found")
         spec = MODEL_SPECS[model_name]
 
         # Create fully configured ModelConfig
         return ModelConfig(
             name=spec.name,
             provider=spec.provider,
+            key_provider=KeyProvider.BYOK,
             max_input_tokens=spec.max_input_tokens,
             max_output_tokens=spec.max_output_tokens,
             api_key=api_key,
@@ -190,12 +260,30 @@ def get_provider_model(provider: ProviderType | None = None) -> ModelConfig:
         raise ValueError(f"Unsupported provider: {provider_enum}")
 
 
-def
-    """
+def _has_provider_key(config: "ShotgunConfig", provider: ProviderType) -> bool:
+    """Check if a provider has a configured API key.
+
+    Args:
+        config: Shotgun configuration
+        provider: Provider to check
+
+    Returns:
+        True if provider has a configured API key
+    """
+    if provider == ProviderType.OPENAI:
+        return bool(_get_api_key(config.openai.api_key))
+    elif provider == ProviderType.ANTHROPIC:
+        return bool(_get_api_key(config.anthropic.api_key))
+    elif provider == ProviderType.GOOGLE:
+        return bool(_get_api_key(config.google.api_key))
+    return False
+
+
+def _get_api_key(config_key: SecretStr | None) -> str | None:
+    """Get API key from config.
 
     Args:
         config_key: API key from configuration
-        env_var: Environment variable name to check
 
     Returns:
         API key string or None
@@ -203,4 +291,4 @@ def _get_api_key(config_key: SecretStr | None, env_var: str) -> str | None:
     if config_key is not None:
         return config_key.get_secret_value()
 
-    return
+    return None
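
Taken together, the provider.py changes key the model cache on (provider, key_provider, model_name, api_key) and route Shotgun Account traffic through the LiteLLM proxy while BYOK keys use the native clients. A minimal sketch of how the new signatures might be called; the placeholder API key below is an assumption, while the enum members shown are taken from this diff:

    from shotgun.agents.config.models import KeyProvider, ModelName, ProviderType
    from shotgun.agents.config.provider import get_or_create_model, get_provider_model

    # BYOK path: a user-supplied Anthropic key yields a cached native AnthropicModel.
    model = get_or_create_model(
        provider=ProviderType.ANTHROPIC,
        key_provider=KeyProvider.BYOK,
        model_name=ModelName.CLAUDE_SONNET_4_5,
        api_key="sk-ant-...",  # placeholder key
    )

    # Config-driven lookup: a configured Shotgun key takes priority for any model;
    # otherwise the first provider with a BYOK key is used.
    model_config = get_provider_model(ModelName.GPT_5)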

shotgun/agents/history/compaction.py

@@ -31,7 +31,7 @@ async def apply_persistent_compaction(
 
     try:
         # Count actual token usage using shared utility
-        estimated_tokens = estimate_tokens_from_messages(messages, deps.llm_model)
+        estimated_tokens = await estimate_tokens_from_messages(messages, deps.llm_model)
 
         # Create minimal usage info for compaction check
         usage = RequestUsage(

shotgun/agents/history/history_processors.py

@@ -2,6 +2,7 @@
 
 from typing import TYPE_CHECKING, Any, Protocol
 
+from pydantic_ai import ModelSettings
 from pydantic_ai.messages import (
     ModelMessage,
     ModelRequest,
@@ -10,7 +11,7 @@ from pydantic_ai.messages import (
     UserPromptPart,
 )
 
-from shotgun.agents.
+from shotgun.agents.llm import shotgun_model_request
 from shotgun.agents.messages import AgentSystemPrompt, SystemStatusPrompt
 from shotgun.agents.models import AgentDeps
 from shotgun.logging_config import get_logger
@@ -154,7 +155,7 @@ async def token_limit_compactor(
 
     if last_summary_index is not None:
         # Check if post-summary conversation exceeds threshold for incremental compaction
-        post_summary_tokens = estimate_post_summary_tokens(
+        post_summary_tokens = await estimate_post_summary_tokens(
             messages, last_summary_index, deps.llm_model
         )
         post_summary_percentage = (
@@ -248,7 +249,7 @@ async def token_limit_compactor(
         ]
 
         # Calculate optimal max_tokens for summarization
-        max_tokens = calculate_max_summarization_tokens(
+        max_tokens = await calculate_max_summarization_tokens(
             deps.llm_model, request_messages
         )
 
@@ -261,7 +262,9 @@ async def token_limit_compactor(
         summary_response = await shotgun_model_request(
             model_config=deps.llm_model,
             messages=request_messages,
-
+            model_settings=ModelSettings(
+                max_tokens=max_tokens  # Use calculated optimal tokens for summarization
+            ),
         )
 
         log_summarization_response(summary_response, "INCREMENTAL")
@@ -328,7 +331,9 @@ async def token_limit_compactor(
 
         # Track compaction completion
         messages_after = len(compacted_messages)
-        tokens_after = estimate_tokens_from_messages(
+        tokens_after = await estimate_tokens_from_messages(
+            compacted_messages, deps.llm_model
+        )
         reduction_percentage = (
             ((messages_before - messages_after) / messages_before * 100)
             if messages_before > 0
@@ -354,7 +359,7 @@ async def token_limit_compactor(
 
     else:
         # Check if total conversation exceeds threshold for full compaction
-        total_tokens = estimate_tokens_from_messages(messages, deps.llm_model)
+        total_tokens = await estimate_tokens_from_messages(messages, deps.llm_model)
         total_percentage = (total_tokens / max_tokens) * 100 if max_tokens > 0 else 0
 
         logger.debug(
@@ -392,7 +397,9 @@ async def _full_compaction(
     ]
 
     # Calculate optimal max_tokens for summarization
-    max_tokens = calculate_max_summarization_tokens(
+    max_tokens = await calculate_max_summarization_tokens(
+        deps.llm_model, request_messages
+    )
 
     # Debug logging using shared utilities
     log_summarization_request(
@@ -403,11 +410,13 @@ async def _full_compaction(
     summary_response = await shotgun_model_request(
         model_config=deps.llm_model,
         messages=request_messages,
-
+        model_settings=ModelSettings(
+            max_tokens=max_tokens  # Use calculated optimal tokens for summarization
+        ),
     )
 
     # Calculate token reduction
-    current_tokens = estimate_tokens_from_messages(messages, deps.llm_model)
+    current_tokens = await estimate_tokens_from_messages(messages, deps.llm_model)
     summary_usage = summary_response.usage
     reduction_percentage = (
         ((current_tokens - summary_usage.output_tokens) / current_tokens) * 100
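
Both compaction paths now await the token-estimation helpers, since token counting may call provider APIs or download tokenizer models. A rough sketch of the resulting calling pattern; the import path and the threshold helper below are assumptions for illustration, not part of this diff:

    from shotgun.agents.history.token_estimation import (  # assumed module path
        estimate_tokens_from_messages,
    )

    async def needs_compaction(messages, model_config, threshold: float = 0.8) -> bool:
        # Hypothetical helper: compare estimated usage with the model's input window.
        used = await estimate_tokens_from_messages(messages, model_config)
        return used >= threshold * model_config.max_input_tokens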

shotgun/agents/history/token_counting/__init__.py (new file)

@@ -0,0 +1,31 @@
+"""Real token counting for all supported providers.
+
+This module provides accurate token counting using each provider's official
+APIs and libraries, eliminating the need for rough character-based estimation.
+"""
+
+from .anthropic import AnthropicTokenCounter
+from .base import TokenCounter, extract_text_from_messages
+from .openai import OpenAITokenCounter
+from .sentencepiece_counter import SentencePieceTokenCounter
+from .utils import (
+    count_post_summary_tokens,
+    count_tokens_from_message_parts,
+    count_tokens_from_messages,
+    get_token_counter,
+)
+
+__all__ = [
+    # Base classes
+    "TokenCounter",
+    # Counter implementations
+    "OpenAITokenCounter",
+    "AnthropicTokenCounter",
+    "SentencePieceTokenCounter",
+    # Utility functions
+    "get_token_counter",
+    "count_tokens_from_messages",
+    "count_post_summary_tokens",
+    "count_tokens_from_message_parts",
+    "extract_text_from_messages",
+]
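
The package re-exports the counters and utility functions at the top level, so callers can stay on one import path. A hedged usage sketch; get_token_counter's signature is not shown in this diff, so the call below is an assumption:

    from shotgun.agents.history.token_counting import (
        count_tokens_from_messages,
        get_token_counter,
    )

    async def total_tokens(messages, model_config) -> int:
        counter = get_token_counter(model_config)  # assumed signature
        return await counter.count_message_tokens(messages)

    # The convenience wrapper likely covers the same flow in one call (signature assumed):
    # total = await count_tokens_from_messages(messages, model_config)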

shotgun/agents/history/token_counting/anthropic.py (new file)

@@ -0,0 +1,89 @@
+"""Anthropic token counting using official client."""
+
+from pydantic_ai.messages import ModelMessage
+
+from shotgun.agents.config.models import KeyProvider
+from shotgun.llm_proxy import create_anthropic_proxy_client
+from shotgun.logging_config import get_logger
+
+from .base import TokenCounter, extract_text_from_messages
+
+logger = get_logger(__name__)
+
+
+class AnthropicTokenCounter(TokenCounter):
+    """Token counter for Anthropic models using official client."""
+
+    def __init__(
+        self,
+        model_name: str,
+        api_key: str,
+        key_provider: KeyProvider = KeyProvider.BYOK,
+    ):
+        """Initialize Anthropic token counter.
+
+        Args:
+            model_name: Anthropic model name for token counting
+            api_key: API key (Anthropic for BYOK, Shotgun for proxy)
+            key_provider: Key provider type (BYOK or SHOTGUN)
+
+        Raises:
+            RuntimeError: If client initialization fails
+        """
+        self.model_name = model_name
+        import anthropic
+
+        try:
+            if key_provider == KeyProvider.SHOTGUN:
+                # Use LiteLLM proxy for Shotgun Account
+                # Proxies to Anthropic's token counting API
+                self.client = create_anthropic_proxy_client(api_key)
+                logger.debug(
+                    f"Initialized Anthropic token counter for {model_name} via LiteLLM proxy"
+                )
+            else:
+                # Direct Anthropic API for BYOK
+                self.client = anthropic.Anthropic(api_key=api_key)
+                logger.debug(
+                    f"Initialized Anthropic token counter for {model_name} via direct API"
+                )
+        except Exception as e:
+            raise RuntimeError("Failed to initialize Anthropic client") from e
+
+    async def count_tokens(self, text: str) -> int:
+        """Count tokens using Anthropic's official API (async).
+
+        Args:
+            text: Text to count tokens for
+
+        Returns:
+            Exact token count from Anthropic API
+
+        Raises:
+            RuntimeError: If API call fails
+        """
+        try:
+            # Anthropic API expects messages format and model parameter
+            result = self.client.messages.count_tokens(
+                messages=[{"role": "user", "content": text}], model=self.model_name
+            )
+            return result.input_tokens
+        except Exception as e:
+            raise RuntimeError(
+                f"Anthropic token counting API failed for {self.model_name}"
+            ) from e
+
+    async def count_message_tokens(self, messages: list[ModelMessage]) -> int:
+        """Count tokens across all messages using Anthropic API (async).
+
+        Args:
+            messages: List of PydanticAI messages
+
+        Returns:
+            Total token count for all messages
+
+        Raises:
+            RuntimeError: If token counting fails
+        """
+        total_text = extract_text_from_messages(messages)
+        return await self.count_tokens(total_text)
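
Given the constructor and async methods above, a BYOK counter could be exercised like this; the model id and API key below are placeholders, not values taken from this diff:

    import asyncio

    from shotgun.agents.config.models import KeyProvider
    from shotgun.agents.history.token_counting import AnthropicTokenCounter

    async def main() -> None:
        counter = AnthropicTokenCounter(
            model_name="claude-sonnet-4-5",  # placeholder model id
            api_key="sk-ant-...",  # placeholder BYOK key
            key_provider=KeyProvider.BYOK,
        )
        # count_tokens calls Anthropic's count_tokens endpoint under the hood.
        print(await counter.count_tokens("How many tokens is this sentence?"))

    asyncio.run(main())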

shotgun/agents/history/token_counting/base.py (new file)

@@ -0,0 +1,67 @@
+"""Base classes and shared utilities for token counting."""
+
+from abc import ABC, abstractmethod
+
+from pydantic_ai.messages import ModelMessage
+
+
+class TokenCounter(ABC):
+    """Abstract base class for provider-specific token counting.
+
+    All methods are async to support non-blocking operations like
+    downloading tokenizer models or making API calls.
+    """
+
+    @abstractmethod
+    async def count_tokens(self, text: str) -> int:
+        """Count tokens in text using provider-specific method (async).
+
+        Args:
+            text: Text to count tokens for
+
+        Returns:
+            Exact token count as determined by the provider
+
+        Raises:
+            RuntimeError: If token counting fails
+        """
+
+    @abstractmethod
+    async def count_message_tokens(self, messages: list[ModelMessage]) -> int:
+        """Count tokens in PydanticAI message structures (async).
+
+        Args:
+            messages: List of messages to count tokens for
+
+        Returns:
+            Total token count across all messages
+
+        Raises:
+            RuntimeError: If token counting fails
+        """
+
+
+def extract_text_from_messages(messages: list[ModelMessage]) -> str:
+    """Extract all text content from messages for token counting.
+
+    Args:
+        messages: List of PydanticAI messages
+
+    Returns:
+        Combined text content from all messages
+    """
+    text_parts = []
+
+    for message in messages:
+        if hasattr(message, "parts"):
+            for part in message.parts:
+                if hasattr(part, "content") and isinstance(part.content, str):
+                    text_parts.append(part.content)
+                else:
+                    # Handle non-text parts (tool calls, etc.)
+                    text_parts.append(str(part))
+        else:
+            # Handle messages without parts
+            text_parts.append(str(message))
+
+    return "\n".join(text_parts)
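
The abstract base fixes the async interface every counter must implement. A minimal illustrative subclass (not part of the package) that satisfies both abstract methods with a naive whitespace heuristic:

    from pydantic_ai.messages import ModelMessage

    from shotgun.agents.history.token_counting.base import (
        TokenCounter,
        extract_text_from_messages,
    )

    class WhitespaceTokenCounter(TokenCounter):
        """Toy counter: approximates tokens by splitting on whitespace."""

        async def count_tokens(self, text: str) -> int:
            return len(text.split())

        async def count_message_tokens(self, messages: list[ModelMessage]) -> int:
            # Reuse the shared extraction helper, then count the combined text.
            return await self.count_tokens(extract_text_from_messages(messages))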