shotgun-sh 0.1.16.dev2__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of shotgun-sh might be problematic.

Files changed (43)
  1. shotgun/agents/common.py +4 -5
  2. shotgun/agents/config/constants.py +21 -5
  3. shotgun/agents/config/manager.py +171 -63
  4. shotgun/agents/config/models.py +65 -84
  5. shotgun/agents/config/provider.py +174 -85
  6. shotgun/agents/history/compaction.py +1 -1
  7. shotgun/agents/history/history_processors.py +18 -9
  8. shotgun/agents/history/token_counting/__init__.py +31 -0
  9. shotgun/agents/history/token_counting/anthropic.py +89 -0
  10. shotgun/agents/history/token_counting/base.py +67 -0
  11. shotgun/agents/history/token_counting/openai.py +80 -0
  12. shotgun/agents/history/token_counting/sentencepiece_counter.py +119 -0
  13. shotgun/agents/history/token_counting/tokenizer_cache.py +90 -0
  14. shotgun/agents/history/token_counting/utils.py +147 -0
  15. shotgun/agents/history/token_estimation.py +12 -12
  16. shotgun/agents/llm.py +62 -0
  17. shotgun/agents/models.py +2 -2
  18. shotgun/agents/tools/web_search/__init__.py +42 -15
  19. shotgun/agents/tools/web_search/anthropic.py +54 -50
  20. shotgun/agents/tools/web_search/gemini.py +31 -20
  21. shotgun/agents/tools/web_search/openai.py +4 -4
  22. shotgun/build_constants.py +2 -2
  23. shotgun/cli/config.py +28 -57
  24. shotgun/cli/models.py +2 -2
  25. shotgun/codebase/models.py +4 -4
  26. shotgun/llm_proxy/__init__.py +16 -0
  27. shotgun/llm_proxy/clients.py +39 -0
  28. shotgun/llm_proxy/constants.py +8 -0
  29. shotgun/main.py +6 -0
  30. shotgun/posthog_telemetry.py +5 -3
  31. shotgun/tui/app.py +7 -3
  32. shotgun/tui/screens/chat.py +2 -8
  33. shotgun/tui/screens/chat_screen/command_providers.py +118 -11
  34. shotgun/tui/screens/chat_screen/history.py +3 -1
  35. shotgun/tui/screens/model_picker.py +327 -0
  36. shotgun/tui/screens/provider_config.py +57 -26
  37. shotgun/utils/env_utils.py +12 -0
  38. {shotgun_sh-0.1.16.dev2.dist-info → shotgun_sh-0.2.0.dist-info}/METADATA +2 -2
  39. {shotgun_sh-0.1.16.dev2.dist-info → shotgun_sh-0.2.0.dist-info}/RECORD +42 -31
  40. shotgun/agents/history/token_counting.py +0 -429
  41. {shotgun_sh-0.1.16.dev2.dist-info → shotgun_sh-0.2.0.dist-info}/WHEEL +0 -0
  42. {shotgun_sh-0.1.16.dev2.dist-info → shotgun_sh-0.2.0.dist-info}/entry_points.txt +0 -0
  43. {shotgun_sh-0.1.16.dev2.dist-info → shotgun_sh-0.2.0.dist-info}/licenses/LICENSE +0 -0
shotgun/agents/config/provider.py

@@ -1,7 +1,5 @@
 """Provider management for LLM configuration."""
 
-import os
-
 from pydantic import SecretStr
 from pydantic_ai.models import Model
 from pydantic_ai.models.anthropic import AnthropicModel, AnthropicModelSettings
@@ -12,27 +10,36 @@ from pydantic_ai.providers.google import GoogleProvider
 from pydantic_ai.providers.openai import OpenAIProvider
 from pydantic_ai.settings import ModelSettings
 
+from shotgun.llm_proxy import create_litellm_provider
 from shotgun.logging_config import get_logger
 
-from .constants import (
-    ANTHROPIC_API_KEY_ENV,
-    GEMINI_API_KEY_ENV,
-    OPENAI_API_KEY_ENV,
-)
 from .manager import get_config_manager
-from .models import MODEL_SPECS, ModelConfig, ProviderType
+from .models import (
+    MODEL_SPECS,
+    KeyProvider,
+    ModelConfig,
+    ModelName,
+    ProviderType,
+    ShotgunConfig,
+)
 
 logger = get_logger(__name__)
 
 # Global cache for Model instances (singleton pattern)
-_model_cache: dict[tuple[ProviderType, str, str], Model] = {}
+_model_cache: dict[tuple[ProviderType, KeyProvider, ModelName, str], Model] = {}
 
 
-def get_or_create_model(provider: ProviderType, model_name: str, api_key: str) -> Model:
+def get_or_create_model(
+    provider: ProviderType,
+    key_provider: "KeyProvider",
+    model_name: ModelName,
+    api_key: str,
+) -> Model:
     """Get or create a singleton Model instance.
 
     Args:
-        provider: Provider type
+        provider: Actual LLM provider (openai, anthropic, google)
+        key_provider: Authentication method (byok or shotgun)
         model_name: Name of the model
         api_key: API key for the provider
 
@@ -42,66 +49,88 @@ def get_or_create_model(provider: ProviderType, model_name: str, api_key: str) -
     Raises:
         ValueError: If provider is not supported
     """
-    cache_key = (provider, model_name, api_key)
+    cache_key = (provider, key_provider, model_name, api_key)
 
     if cache_key not in _model_cache:
-        logger.debug("Creating new %s model instance: %s", provider.value, model_name)
+        logger.debug(
+            "Creating new %s model instance via %s: %s",
+            provider.value,
+            key_provider.value,
+            model_name,
+        )
 
-        if provider == ProviderType.OPENAI:
-            # Get max_tokens from MODEL_SPECS to use full capacity
+        # Get max_tokens from MODEL_SPECS
+        if model_name in MODEL_SPECS:
+            max_tokens = MODEL_SPECS[model_name].max_output_tokens
+        else:
+            # Fallback defaults based on provider
+            max_tokens = {
+                ProviderType.OPENAI: 16_000,
+                ProviderType.ANTHROPIC: 32_000,
+                ProviderType.GOOGLE: 64_000,
+            }.get(provider, 16_000)
+
+        # Use LiteLLM proxy for Shotgun Account, native providers for BYOK
+        if key_provider == KeyProvider.SHOTGUN:
+            # Shotgun Account uses LiteLLM proxy for any model
             if model_name in MODEL_SPECS:
-                max_tokens = MODEL_SPECS[model_name].max_output_tokens
+                litellm_model_name = MODEL_SPECS[model_name].litellm_proxy_model_name
             else:
-                max_tokens = 16_000  # Default for GPT models
+                # Fallback for unmapped models
+                litellm_model_name = f"openai/{model_name.value}"
 
-            openai_provider = OpenAIProvider(api_key=api_key)
+            litellm_provider = create_litellm_provider(api_key)
             _model_cache[cache_key] = OpenAIChatModel(
-                model_name,
-                provider=openai_provider,
+                litellm_model_name,
+                provider=litellm_provider,
                 settings=ModelSettings(max_tokens=max_tokens),
             )
-        elif provider == ProviderType.ANTHROPIC:
-            # Get max_tokens from MODEL_SPECS to use full capacity
-            if model_name in MODEL_SPECS:
-                max_tokens = MODEL_SPECS[model_name].max_output_tokens
+        elif key_provider == KeyProvider.BYOK:
+            # Use native provider implementations with user's API keys
+            if provider == ProviderType.OPENAI:
+                openai_provider = OpenAIProvider(api_key=api_key)
+                _model_cache[cache_key] = OpenAIChatModel(
+                    model_name,
+                    provider=openai_provider,
+                    settings=ModelSettings(max_tokens=max_tokens),
+                )
+            elif provider == ProviderType.ANTHROPIC:
+                anthropic_provider = AnthropicProvider(api_key=api_key)
+                _model_cache[cache_key] = AnthropicModel(
+                    model_name,
+                    provider=anthropic_provider,
+                    settings=AnthropicModelSettings(
+                        max_tokens=max_tokens,
+                        timeout=600,  # 10 minutes timeout for large responses
+                    ),
+                )
+            elif provider == ProviderType.GOOGLE:
+                google_provider = GoogleProvider(api_key=api_key)
+                _model_cache[cache_key] = GoogleModel(
+                    model_name,
+                    provider=google_provider,
+                    settings=ModelSettings(max_tokens=max_tokens),
+                )
             else:
-                max_tokens = 32_000  # Default for Claude models
-
-            anthropic_provider = AnthropicProvider(api_key=api_key)
-            _model_cache[cache_key] = AnthropicModel(
-                model_name,
-                provider=anthropic_provider,
-                settings=AnthropicModelSettings(
-                    max_tokens=max_tokens,
-                    timeout=600,  # 10 minutes timeout for large responses
-                ),
-            )
-        elif provider == ProviderType.GOOGLE:
-            # Get max_tokens from MODEL_SPECS to use full capacity
-            if model_name in MODEL_SPECS:
-                max_tokens = MODEL_SPECS[model_name].max_output_tokens
-            else:
-                max_tokens = 64_000  # Default for Gemini models
-
-            google_provider = GoogleProvider(api_key=api_key)
-            _model_cache[cache_key] = GoogleModel(
-                model_name,
-                provider=google_provider,
-                settings=ModelSettings(max_tokens=max_tokens),
-            )
+                raise ValueError(f"Unsupported provider: {provider}")
         else:
-            raise ValueError(f"Unsupported provider: {provider}")
+            raise ValueError(f"Unsupported key provider: {key_provider}")
     else:
         logger.debug("Reusing cached %s model instance: %s", provider.value, model_name)
 
     return _model_cache[cache_key]
 
 
-def get_provider_model(provider: ProviderType | None = None) -> ModelConfig:
+def get_provider_model(
+    provider_or_model: ProviderType | ModelName | None = None,
+) -> ModelConfig:
     """Get a fully configured ModelConfig with API key and Model instance.
 
     Args:
-        provider: Provider to get model for. If None, uses default provider
+        provider_or_model: Either a ProviderType, ModelName, or None.
+            - If ModelName: returns that specific model with appropriate API key
+            - If ProviderType: returns default model for that provider (backward compatible)
+            - If None: uses default provider with its default model
 
     Returns:
         ModelConfig with API key configured and lazy Model instance
@@ -110,77 +139,119 @@ def get_provider_model(provider: ProviderType | None = None) -> ModelConfig:
         ValueError: If provider is not configured properly or model not found
     """
     config_manager = get_config_manager()
-    config = config_manager.load()
-    # Convert string to ProviderType enum if needed
-    provider_enum = (
-        provider
-        if isinstance(provider, ProviderType)
-        else ProviderType(provider)
-        if provider
-        else config.default_provider
-    )
+    # Use cached config for read-only access (performance)
+    config = config_manager.load(force_reload=False)
+
+    # Priority 1: Check if Shotgun key exists - if so, use it for ANY model
+    shotgun_api_key = _get_api_key(config.shotgun.api_key)
+    if shotgun_api_key:
+        # Use selected model or default to claude-sonnet-4-5
+        model_name = config.selected_model or ModelName.CLAUDE_SONNET_4_5
+        if model_name not in MODEL_SPECS:
+            raise ValueError(f"Model '{model_name.value}' not found")
+        spec = MODEL_SPECS[model_name]
+
+        # Use Shotgun Account with selected model (provider = actual LLM provider)
+        return ModelConfig(
+            name=spec.name,
+            provider=spec.provider,  # Actual LLM provider (OPENAI/ANTHROPIC/GOOGLE)
+            key_provider=KeyProvider.SHOTGUN,  # Authenticated via Shotgun Account
+            max_input_tokens=spec.max_input_tokens,
+            max_output_tokens=spec.max_output_tokens,
+            api_key=shotgun_api_key,
+        )
+
+    # Priority 2: Fall back to individual provider keys
+
+    # Check if a specific model was requested
+    if isinstance(provider_or_model, ModelName):
+        # Look up the model spec
+        if provider_or_model not in MODEL_SPECS:
+            raise ValueError(f"Model '{provider_or_model.value}' not found")
+        spec = MODEL_SPECS[provider_or_model]
+        provider_enum = spec.provider
+        requested_model = provider_or_model
+    else:
+        # Convert string to ProviderType enum if needed (backward compatible)
+        if provider_or_model:
+            provider_enum = (
+                provider_or_model
+                if isinstance(provider_or_model, ProviderType)
+                else ProviderType(provider_or_model)
+            )
+        else:
+            # No provider specified - find first available provider with a key
+            provider_enum = None
+            for provider in ProviderType:
+                if _has_provider_key(config, provider):
+                    provider_enum = provider
+                    break
+
+            if provider_enum is None:
+                raise ValueError(
+                    "No provider keys configured. Set via environment variables or config."
+                )
+
+        requested_model = None  # Will use provider's default model
 
     if provider_enum == ProviderType.OPENAI:
-        api_key = _get_api_key(config.openai.api_key, OPENAI_API_KEY_ENV)
+        api_key = _get_api_key(config.openai.api_key)
         if not api_key:
-            raise ValueError(
-                f"OpenAI API key not configured. Set via environment variable {OPENAI_API_KEY_ENV} or config."
-            )
+            raise ValueError("OpenAI API key not configured. Set via config.")
 
-        # Get model spec - hardcoded to gpt-5
-        model_name = "gpt-5"
+        # Use requested model or default to gpt-5
+        model_name = requested_model if requested_model else ModelName.GPT_5
         if model_name not in MODEL_SPECS:
-            raise ValueError(f"Model '{model_name}' not found")
+            raise ValueError(f"Model '{model_name.value}' not found")
         spec = MODEL_SPECS[model_name]
 
         # Create fully configured ModelConfig
         return ModelConfig(
             name=spec.name,
             provider=spec.provider,
+            key_provider=KeyProvider.BYOK,
             max_input_tokens=spec.max_input_tokens,
             max_output_tokens=spec.max_output_tokens,
             api_key=api_key,
         )
 
     elif provider_enum == ProviderType.ANTHROPIC:
-        api_key = _get_api_key(config.anthropic.api_key, ANTHROPIC_API_KEY_ENV)
+        api_key = _get_api_key(config.anthropic.api_key)
         if not api_key:
-            raise ValueError(
-                f"Anthropic API key not configured. Set via environment variable {ANTHROPIC_API_KEY_ENV} or config."
-            )
+            raise ValueError("Anthropic API key not configured. Set via config.")
 
-        # Get model spec - hardcoded to claude-opus-4-1
-        model_name = "claude-opus-4-1"
+        # Use requested model or default to claude-sonnet-4-5
+        model_name = requested_model if requested_model else ModelName.CLAUDE_SONNET_4_5
         if model_name not in MODEL_SPECS:
-            raise ValueError(f"Model '{model_name}' not found")
+            raise ValueError(f"Model '{model_name.value}' not found")
         spec = MODEL_SPECS[model_name]
 
         # Create fully configured ModelConfig
         return ModelConfig(
             name=spec.name,
             provider=spec.provider,
+            key_provider=KeyProvider.BYOK,
             max_input_tokens=spec.max_input_tokens,
             max_output_tokens=spec.max_output_tokens,
            api_key=api_key,
         )
 
     elif provider_enum == ProviderType.GOOGLE:
-        api_key = _get_api_key(config.google.api_key, GEMINI_API_KEY_ENV)
+        api_key = _get_api_key(config.google.api_key)
         if not api_key:
-            raise ValueError(
-                f"Gemini API key not configured. Set via environment variable {GEMINI_API_KEY_ENV} or config."
-            )
+            raise ValueError("Gemini API key not configured. Set via config.")
 
-        # Get model spec - hardcoded to gemini-2.5-pro
-        model_name = "gemini-2.5-pro"
+        # Use requested model or default to gemini-2.5-pro
+        model_name = requested_model if requested_model else ModelName.GEMINI_2_5_PRO
         if model_name not in MODEL_SPECS:
-            raise ValueError(f"Model '{model_name}' not found")
+            raise ValueError(f"Model '{model_name.value}' not found")
         spec = MODEL_SPECS[model_name]
 
         # Create fully configured ModelConfig
         return ModelConfig(
             name=spec.name,
             provider=spec.provider,
+            key_provider=KeyProvider.BYOK,
             max_input_tokens=spec.max_input_tokens,
             max_output_tokens=spec.max_output_tokens,
             api_key=api_key,
@@ -190,12 +261,30 @@ def get_provider_model(provider: ProviderType | None = None) -> ModelConfig:
         raise ValueError(f"Unsupported provider: {provider_enum}")
 
 
-def _get_api_key(config_key: SecretStr | None, env_var: str) -> str | None:
-    """Get API key from config or environment variable.
+def _has_provider_key(config: "ShotgunConfig", provider: ProviderType) -> bool:
+    """Check if a provider has a configured API key.
+
+    Args:
+        config: Shotgun configuration
+        provider: Provider to check
+
+    Returns:
+        True if provider has a configured API key
+    """
+    if provider == ProviderType.OPENAI:
+        return bool(_get_api_key(config.openai.api_key))
+    elif provider == ProviderType.ANTHROPIC:
+        return bool(_get_api_key(config.anthropic.api_key))
+    elif provider == ProviderType.GOOGLE:
+        return bool(_get_api_key(config.google.api_key))
+    return False
+
+
+def _get_api_key(config_key: SecretStr | None) -> str | None:
+    """Get API key from config.
 
     Args:
         config_key: API key from configuration
-        env_var: Environment variable name to check
 
     Returns:
         API key string or None
@@ -203,4 +292,4 @@ def _get_api_key(config_key: SecretStr | None, env_var: str) -> str | None:
     if config_key is not None:
         return config_key.get_secret_value()
 
-    return os.getenv(env_var)
+    return None
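
Note: a minimal sketch (not from the package) of how the reworked provider API might be exercised after this change. Only the names and signatures visible in the hunks above are used; the call sequence itself and the import paths are assumptions.

    from shotgun.agents.config.models import KeyProvider, ModelName
    from shotgun.agents.config.provider import get_or_create_model, get_provider_model

    # Resolve a fully configured ModelConfig: a ModelName selects that exact model,
    # a ProviderType (or None) falls back to provider defaults / the first configured key.
    model_config = get_provider_model(ModelName.CLAUDE_SONNET_4_5)

    # Build (or reuse) the cached pydantic-ai Model. With KeyProvider.SHOTGUN the call
    # is routed through the LiteLLM proxy; with KeyProvider.BYOK it uses the native SDKs.
    model = get_or_create_model(
        provider=model_config.provider,
        key_provider=model_config.key_provider,
        model_name=ModelName.CLAUDE_SONNET_4_5,
        api_key=model_config.api_key,
    )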
shotgun/agents/history/compaction.py

@@ -31,7 +31,7 @@ async def apply_persistent_compaction(
 
     try:
         # Count actual token usage using shared utility
-        estimated_tokens = estimate_tokens_from_messages(messages, deps.llm_model)
+        estimated_tokens = await estimate_tokens_from_messages(messages, deps.llm_model)
 
         # Create minimal usage info for compaction check
         usage = RequestUsage(
shotgun/agents/history/history_processors.py

@@ -2,6 +2,7 @@
 
 from typing import TYPE_CHECKING, Any, Protocol
 
+from pydantic_ai import ModelSettings
 from pydantic_ai.messages import (
     ModelMessage,
     ModelRequest,
@@ -10,7 +11,7 @@ from pydantic_ai.messages import (
     UserPromptPart,
 )
 
-from shotgun.agents.config.models import shotgun_model_request
+from shotgun.agents.llm import shotgun_model_request
 from shotgun.agents.messages import AgentSystemPrompt, SystemStatusPrompt
 from shotgun.agents.models import AgentDeps
 from shotgun.logging_config import get_logger
@@ -154,7 +155,7 @@ async def token_limit_compactor(
 
     if last_summary_index is not None:
         # Check if post-summary conversation exceeds threshold for incremental compaction
-        post_summary_tokens = estimate_post_summary_tokens(
+        post_summary_tokens = await estimate_post_summary_tokens(
             messages, last_summary_index, deps.llm_model
         )
         post_summary_percentage = (
@@ -248,7 +249,7 @@ async def token_limit_compactor(
         ]
 
         # Calculate optimal max_tokens for summarization
-        max_tokens = calculate_max_summarization_tokens(
+        max_tokens = await calculate_max_summarization_tokens(
             deps.llm_model, request_messages
         )
 
@@ -261,7 +262,9 @@ async def token_limit_compactor(
         summary_response = await shotgun_model_request(
             model_config=deps.llm_model,
             messages=request_messages,
-            max_tokens=max_tokens,  # Use calculated optimal tokens for summarization
+            model_settings=ModelSettings(
+                max_tokens=max_tokens  # Use calculated optimal tokens for summarization
+            ),
         )
 
         log_summarization_response(summary_response, "INCREMENTAL")
@@ -328,7 +331,9 @@ async def token_limit_compactor(
 
         # Track compaction completion
         messages_after = len(compacted_messages)
-        tokens_after = estimate_tokens_from_messages(compacted_messages, deps.llm_model)
+        tokens_after = await estimate_tokens_from_messages(
+            compacted_messages, deps.llm_model
+        )
         reduction_percentage = (
             ((messages_before - messages_after) / messages_before * 100)
             if messages_before > 0
@@ -354,7 +359,7 @@ async def token_limit_compactor(
 
     else:
         # Check if total conversation exceeds threshold for full compaction
-        total_tokens = estimate_tokens_from_messages(messages, deps.llm_model)
+        total_tokens = await estimate_tokens_from_messages(messages, deps.llm_model)
         total_percentage = (total_tokens / max_tokens) * 100 if max_tokens > 0 else 0
 
         logger.debug(
@@ -392,7 +397,9 @@ async def _full_compaction(
     ]
 
     # Calculate optimal max_tokens for summarization
-    max_tokens = calculate_max_summarization_tokens(deps.llm_model, request_messages)
+    max_tokens = await calculate_max_summarization_tokens(
+        deps.llm_model, request_messages
+    )
 
     # Debug logging using shared utilities
     log_summarization_request(
@@ -403,11 +410,13 @@ async def _full_compaction(
     summary_response = await shotgun_model_request(
         model_config=deps.llm_model,
         messages=request_messages,
-        max_tokens=max_tokens,  # Use calculated optimal tokens for summarization
+        model_settings=ModelSettings(
+            max_tokens=max_tokens  # Use calculated optimal tokens for summarization
+        ),
     )
 
     # Calculate token reduction
-    current_tokens = estimate_tokens_from_messages(messages, deps.llm_model)
+    current_tokens = await estimate_tokens_from_messages(messages, deps.llm_model)
     summary_usage = summary_response.usage
     reduction_percentage = (
         ((current_tokens - summary_usage.output_tokens) / current_tokens) * 100
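
Note: a small sketch (not from the package) of the calling convention after these hunks: the token-estimation helpers are now coroutines and must be awaited, and the summarization cap travels inside pydantic-ai's ModelSettings rather than a bare max_tokens keyword. The import paths and the illustrative token value are inferred from the file list above and should be treated as assumptions.

    from pydantic_ai import ModelSettings

    from shotgun.agents.history.token_estimation import estimate_tokens_from_messages
    from shotgun.agents.llm import shotgun_model_request


    async def summarize(deps, request_messages, messages):
        # Token estimation is now async and must be awaited.
        current_tokens = await estimate_tokens_from_messages(messages, deps.llm_model)

        # max_tokens is wrapped in ModelSettings instead of being passed as a keyword.
        summary_response = await shotgun_model_request(
            model_config=deps.llm_model,
            messages=request_messages,
            model_settings=ModelSettings(max_tokens=4_096),  # illustrative value
        )
        return current_tokens, summary_response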
shotgun/agents/history/token_counting/__init__.py (new file)

@@ -0,0 +1,31 @@
+"""Real token counting for all supported providers.
+
+This module provides accurate token counting using each provider's official
+APIs and libraries, eliminating the need for rough character-based estimation.
+"""
+
+from .anthropic import AnthropicTokenCounter
+from .base import TokenCounter, extract_text_from_messages
+from .openai import OpenAITokenCounter
+from .sentencepiece_counter import SentencePieceTokenCounter
+from .utils import (
+    count_post_summary_tokens,
+    count_tokens_from_message_parts,
+    count_tokens_from_messages,
+    get_token_counter,
+)
+
+__all__ = [
+    # Base classes
+    "TokenCounter",
+    # Counter implementations
+    "OpenAITokenCounter",
+    "AnthropicTokenCounter",
+    "SentencePieceTokenCounter",
+    # Utility functions
+    "get_token_counter",
+    "count_tokens_from_messages",
+    "count_post_summary_tokens",
+    "count_tokens_from_message_parts",
+    "extract_text_from_messages",
+]
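
Note: this diff does not include token_counting/utils.py, so the signatures of get_token_counter and count_tokens_from_messages are not visible here. The sketch below only illustrates how a consumer might use the re-exported API; the argument shapes and async behaviour of the helpers are assumptions.

    # Hypothetical consumer; the call shapes for get_token_counter and
    # count_tokens_from_messages are assumed (their definitions are not in this diff).
    from shotgun.agents.history.token_counting import (
        count_tokens_from_messages,
        get_token_counter,
    )


    async def usage_report(messages, model_config) -> int:
        counter = get_token_counter(model_config)  # assumed to dispatch on the model config
        per_counter = await counter.count_message_tokens(messages)
        via_helper = await count_tokens_from_messages(messages, model_config)  # assumed async
        return max(per_counter, via_helper)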
shotgun/agents/history/token_counting/anthropic.py (new file)

@@ -0,0 +1,89 @@
+"""Anthropic token counting using official client."""
+
+from pydantic_ai.messages import ModelMessage
+
+from shotgun.agents.config.models import KeyProvider
+from shotgun.llm_proxy import create_anthropic_proxy_client
+from shotgun.logging_config import get_logger
+
+from .base import TokenCounter, extract_text_from_messages
+
+logger = get_logger(__name__)
+
+
+class AnthropicTokenCounter(TokenCounter):
+    """Token counter for Anthropic models using official client."""
+
+    def __init__(
+        self,
+        model_name: str,
+        api_key: str,
+        key_provider: KeyProvider = KeyProvider.BYOK,
+    ):
+        """Initialize Anthropic token counter.
+
+        Args:
+            model_name: Anthropic model name for token counting
+            api_key: API key (Anthropic for BYOK, Shotgun for proxy)
+            key_provider: Key provider type (BYOK or SHOTGUN)
+
+        Raises:
+            RuntimeError: If client initialization fails
+        """
+        self.model_name = model_name
+        import anthropic
+
+        try:
+            if key_provider == KeyProvider.SHOTGUN:
+                # Use LiteLLM proxy for Shotgun Account
+                # Proxies to Anthropic's token counting API
+                self.client = create_anthropic_proxy_client(api_key)
+                logger.debug(
+                    f"Initialized Anthropic token counter for {model_name} via LiteLLM proxy"
+                )
+            else:
+                # Direct Anthropic API for BYOK
+                self.client = anthropic.Anthropic(api_key=api_key)
+                logger.debug(
+                    f"Initialized Anthropic token counter for {model_name} via direct API"
+                )
+        except Exception as e:
+            raise RuntimeError("Failed to initialize Anthropic client") from e
+
+    async def count_tokens(self, text: str) -> int:
+        """Count tokens using Anthropic's official API (async).
+
+        Args:
+            text: Text to count tokens for
+
+        Returns:
+            Exact token count from Anthropic API
+
+        Raises:
+            RuntimeError: If API call fails
+        """
+        try:
+            # Anthropic API expects messages format and model parameter
+            result = self.client.messages.count_tokens(
+                messages=[{"role": "user", "content": text}], model=self.model_name
+            )
+            return result.input_tokens
+        except Exception as e:
+            raise RuntimeError(
+                f"Anthropic token counting API failed for {self.model_name}"
+            ) from e
+
+    async def count_message_tokens(self, messages: list[ModelMessage]) -> int:
+        """Count tokens across all messages using Anthropic API (async).
+
+        Args:
+            messages: List of PydanticAI messages
+
+        Returns:
+            Total token count for all messages
+
+        Raises:
+            RuntimeError: If token counting fails
+        """
+        total_text = extract_text_from_messages(messages)
+        return await self.count_tokens(total_text)
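
Note: a brief usage sketch (not from the package) for the counter above; only the constructor and methods shown in this diff are used, and the model identifier and key values are placeholders.

    import asyncio

    from shotgun.agents.config.models import KeyProvider
    from shotgun.agents.history.token_counting import AnthropicTokenCounter


    async def main() -> None:
        # BYOK path: talks to Anthropic's count_tokens endpoint directly.
        counter = AnthropicTokenCounter(
            model_name="claude-sonnet-4-5",  # placeholder model id
            api_key="sk-ant-...",            # placeholder key
            key_provider=KeyProvider.BYOK,
        )
        print(await counter.count_tokens("How many tokens is this sentence?"))


    asyncio.run(main())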
shotgun/agents/history/token_counting/base.py (new file)

@@ -0,0 +1,67 @@
+"""Base classes and shared utilities for token counting."""
+
+from abc import ABC, abstractmethod
+
+from pydantic_ai.messages import ModelMessage
+
+
+class TokenCounter(ABC):
+    """Abstract base class for provider-specific token counting.
+
+    All methods are async to support non-blocking operations like
+    downloading tokenizer models or making API calls.
+    """
+
+    @abstractmethod
+    async def count_tokens(self, text: str) -> int:
+        """Count tokens in text using provider-specific method (async).
+
+        Args:
+            text: Text to count tokens for
+
+        Returns:
+            Exact token count as determined by the provider
+
+        Raises:
+            RuntimeError: If token counting fails
+        """
+
+    @abstractmethod
+    async def count_message_tokens(self, messages: list[ModelMessage]) -> int:
+        """Count tokens in PydanticAI message structures (async).
+
+        Args:
+            messages: List of messages to count tokens for
+
+        Returns:
+            Total token count across all messages
+
+        Raises:
+            RuntimeError: If token counting fails
+        """
+
+
+def extract_text_from_messages(messages: list[ModelMessage]) -> str:
+    """Extract all text content from messages for token counting.
+
+    Args:
+        messages: List of PydanticAI messages
+
+    Returns:
+        Combined text content from all messages
+    """
+    text_parts = []
+
+    for message in messages:
+        if hasattr(message, "parts"):
+            for part in message.parts:
+                if hasattr(part, "content") and isinstance(part.content, str):
+                    text_parts.append(part.content)
+                else:
+                    # Handle non-text parts (tool calls, etc.)
+                    text_parts.append(str(part))
+        else:
+            # Handle messages without parts
+            text_parts.append(str(message))
+
+    return "\n".join(text_parts)
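
Note: base.py defines the contract that new token-counting backends must satisfy. Below is a deliberately trivial illustration (not part of shotgun-sh) of a subclass implementing both abstract coroutines with a whitespace heuristic, reusing extract_text_from_messages from the hunk above.

    from pydantic_ai.messages import ModelMessage

    from shotgun.agents.history.token_counting.base import (
        TokenCounter,
        extract_text_from_messages,
    )


    class WhitespaceTokenCounter(TokenCounter):
        """Crude example counter: whitespace-separated chunks stand in for tokens."""

        async def count_tokens(self, text: str) -> int:
            return len(text.split())

        async def count_message_tokens(self, messages: list[ModelMessage]) -> int:
            return await self.count_tokens(extract_text_from_messages(messages))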