shotgun-sh 0.1.15.dev2__py3-none-any.whl → 0.2.1.dev1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of shotgun-sh might be problematic.

Files changed (40)
  1. shotgun/agents/common.py +4 -5
  2. shotgun/agents/config/constants.py +21 -5
  3. shotgun/agents/config/manager.py +147 -39
  4. shotgun/agents/config/models.py +59 -86
  5. shotgun/agents/config/provider.py +164 -61
  6. shotgun/agents/history/compaction.py +1 -1
  7. shotgun/agents/history/history_processors.py +18 -9
  8. shotgun/agents/history/token_counting/__init__.py +31 -0
  9. shotgun/agents/history/token_counting/anthropic.py +89 -0
  10. shotgun/agents/history/token_counting/base.py +67 -0
  11. shotgun/agents/history/token_counting/openai.py +80 -0
  12. shotgun/agents/history/token_counting/sentencepiece_counter.py +119 -0
  13. shotgun/agents/history/token_counting/tokenizer_cache.py +90 -0
  14. shotgun/agents/history/token_counting/utils.py +147 -0
  15. shotgun/agents/history/token_estimation.py +12 -12
  16. shotgun/agents/llm.py +62 -0
  17. shotgun/agents/models.py +2 -2
  18. shotgun/agents/tools/web_search/__init__.py +42 -15
  19. shotgun/agents/tools/web_search/anthropic.py +46 -40
  20. shotgun/agents/tools/web_search/gemini.py +31 -20
  21. shotgun/agents/tools/web_search/openai.py +4 -4
  22. shotgun/cli/config.py +14 -55
  23. shotgun/cli/feedback.py +1 -1
  24. shotgun/cli/models.py +2 -2
  25. shotgun/codebase/models.py +4 -4
  26. shotgun/llm_proxy/__init__.py +16 -0
  27. shotgun/llm_proxy/clients.py +39 -0
  28. shotgun/llm_proxy/constants.py +8 -0
  29. shotgun/main.py +6 -0
  30. shotgun/posthog_telemetry.py +5 -3
  31. shotgun/tui/app.py +3 -1
  32. shotgun/tui/screens/chat_screen/command_providers.py +20 -0
  33. shotgun/tui/screens/model_picker.py +214 -0
  34. shotgun/tui/screens/provider_config.py +39 -26
  35. {shotgun_sh-0.1.15.dev2.dist-info → shotgun_sh-0.2.1.dev1.dist-info}/METADATA +2 -2
  36. {shotgun_sh-0.1.15.dev2.dist-info → shotgun_sh-0.2.1.dev1.dist-info}/RECORD +39 -28
  37. shotgun/agents/history/token_counting.py +0 -429
  38. {shotgun_sh-0.1.15.dev2.dist-info → shotgun_sh-0.2.1.dev1.dist-info}/WHEEL +0 -0
  39. {shotgun_sh-0.1.15.dev2.dist-info → shotgun_sh-0.2.1.dev1.dist-info}/entry_points.txt +0 -0
  40. {shotgun_sh-0.1.15.dev2.dist-info → shotgun_sh-0.2.1.dev1.dist-info}/licenses/LICENSE +0 -0
shotgun/agents/config/provider.py

@@ -12,27 +12,42 @@ from pydantic_ai.providers.google import GoogleProvider
 from pydantic_ai.providers.openai import OpenAIProvider
 from pydantic_ai.settings import ModelSettings
 
+from shotgun.llm_proxy import create_litellm_provider
 from shotgun.logging_config import get_logger
 
 from .constants import (
     ANTHROPIC_API_KEY_ENV,
     GEMINI_API_KEY_ENV,
     OPENAI_API_KEY_ENV,
+    SHOTGUN_API_KEY_ENV,
 )
 from .manager import get_config_manager
-from .models import MODEL_SPECS, ModelConfig, ProviderType
+from .models import (
+    MODEL_SPECS,
+    KeyProvider,
+    ModelConfig,
+    ModelName,
+    ProviderType,
+    ShotgunConfig,
+)
 
 logger = get_logger(__name__)
 
 # Global cache for Model instances (singleton pattern)
-_model_cache: dict[tuple[ProviderType, str, str], Model] = {}
+_model_cache: dict[tuple[ProviderType, KeyProvider, ModelName, str], Model] = {}
 
 
-def get_or_create_model(provider: ProviderType, model_name: str, api_key: str) -> Model:
+def get_or_create_model(
+    provider: ProviderType,
+    key_provider: "KeyProvider",
+    model_name: ModelName,
+    api_key: str,
+) -> Model:
     """Get or create a singleton Model instance.
 
     Args:
-        provider: Provider type
+        provider: Actual LLM provider (openai, anthropic, google)
+        key_provider: Authentication method (byok or shotgun)
         model_name: Name of the model
         api_key: API key for the provider
 
@@ -42,66 +57,88 @@ def get_or_create_model(provider: ProviderType, model_name: str, api_key: str) -
     Raises:
         ValueError: If provider is not supported
     """
-    cache_key = (provider, model_name, api_key)
+    cache_key = (provider, key_provider, model_name, api_key)
 
     if cache_key not in _model_cache:
-        logger.debug("Creating new %s model instance: %s", provider.value, model_name)
+        logger.debug(
+            "Creating new %s model instance via %s: %s",
+            provider.value,
+            key_provider.value,
+            model_name,
+        )
 
-        if provider == ProviderType.OPENAI:
-            # Get max_tokens from MODEL_SPECS to use full capacity
+        # Get max_tokens from MODEL_SPECS
+        if model_name in MODEL_SPECS:
+            max_tokens = MODEL_SPECS[model_name].max_output_tokens
+        else:
+            # Fallback defaults based on provider
+            max_tokens = {
+                ProviderType.OPENAI: 16_000,
+                ProviderType.ANTHROPIC: 32_000,
+                ProviderType.GOOGLE: 64_000,
+            }.get(provider, 16_000)
+
+        # Use LiteLLM proxy for Shotgun Account, native providers for BYOK
+        if key_provider == KeyProvider.SHOTGUN:
+            # Shotgun Account uses LiteLLM proxy for any model
             if model_name in MODEL_SPECS:
-                max_tokens = MODEL_SPECS[model_name].max_output_tokens
+                litellm_model_name = MODEL_SPECS[model_name].litellm_proxy_model_name
             else:
-                max_tokens = 16_000  # Default for GPT models
+                # Fallback for unmapped models
+                litellm_model_name = f"openai/{model_name.value}"
 
-            openai_provider = OpenAIProvider(api_key=api_key)
+            litellm_provider = create_litellm_provider(api_key)
             _model_cache[cache_key] = OpenAIChatModel(
-                model_name,
-                provider=openai_provider,
+                litellm_model_name,
+                provider=litellm_provider,
                 settings=ModelSettings(max_tokens=max_tokens),
             )
-        elif provider == ProviderType.ANTHROPIC:
-            # Get max_tokens from MODEL_SPECS to use full capacity
-            if model_name in MODEL_SPECS:
-                max_tokens = MODEL_SPECS[model_name].max_output_tokens
+        elif key_provider == KeyProvider.BYOK:
+            # Use native provider implementations with user's API keys
+            if provider == ProviderType.OPENAI:
+                openai_provider = OpenAIProvider(api_key=api_key)
+                _model_cache[cache_key] = OpenAIChatModel(
+                    model_name,
+                    provider=openai_provider,
+                    settings=ModelSettings(max_tokens=max_tokens),
+                )
+            elif provider == ProviderType.ANTHROPIC:
+                anthropic_provider = AnthropicProvider(api_key=api_key)
+                _model_cache[cache_key] = AnthropicModel(
+                    model_name,
+                    provider=anthropic_provider,
+                    settings=AnthropicModelSettings(
+                        max_tokens=max_tokens,
+                        timeout=600,  # 10 minutes timeout for large responses
+                    ),
+                )
+            elif provider == ProviderType.GOOGLE:
+                google_provider = GoogleProvider(api_key=api_key)
+                _model_cache[cache_key] = GoogleModel(
+                    model_name,
+                    provider=google_provider,
+                    settings=ModelSettings(max_tokens=max_tokens),
+                )
             else:
-                max_tokens = 32_000  # Default for Claude models
-
-            anthropic_provider = AnthropicProvider(api_key=api_key)
-            _model_cache[cache_key] = AnthropicModel(
-                model_name,
-                provider=anthropic_provider,
-                settings=AnthropicModelSettings(
-                    max_tokens=max_tokens,
-                    timeout=600,  # 10 minutes timeout for large responses
-                ),
-            )
-        elif provider == ProviderType.GOOGLE:
-            # Get max_tokens from MODEL_SPECS to use full capacity
-            if model_name in MODEL_SPECS:
-                max_tokens = MODEL_SPECS[model_name].max_output_tokens
-            else:
-                max_tokens = 64_000  # Default for Gemini models
-
-            google_provider = GoogleProvider(api_key=api_key)
-            _model_cache[cache_key] = GoogleModel(
-                model_name,
-                provider=google_provider,
-                settings=ModelSettings(max_tokens=max_tokens),
-            )
+                raise ValueError(f"Unsupported provider: {provider}")
         else:
-            raise ValueError(f"Unsupported provider: {provider}")
+            raise ValueError(f"Unsupported key provider: {key_provider}")
     else:
         logger.debug("Reusing cached %s model instance: %s", provider.value, model_name)
 
     return _model_cache[cache_key]
 
 
-def get_provider_model(provider: ProviderType | None = None) -> ModelConfig:
+def get_provider_model(
+    provider_or_model: ProviderType | ModelName | None = None,
+) -> ModelConfig:
     """Get a fully configured ModelConfig with API key and Model instance.
 
     Args:
-        provider: Provider to get model for. If None, uses default provider
+        provider_or_model: Either a ProviderType, ModelName, or None.
+            - If ModelName: returns that specific model with appropriate API key
+            - If ProviderType: returns default model for that provider (backward compatible)
+            - If None: uses default provider with its default model
 
     Returns:
         ModelConfig with API key configured and lazy Model instance
@@ -111,14 +148,58 @@ def get_provider_model(provider: ProviderType | None = None) -> ModelConfig:
     """
     config_manager = get_config_manager()
     config = config_manager.load()
-    # Convert string to ProviderType enum if needed
-    provider_enum = (
-        provider
-        if isinstance(provider, ProviderType)
-        else ProviderType(provider)
-        if provider
-        else config.default_provider
-    )
+
+    # Priority 1: Check if Shotgun key exists - if so, use it for ANY model
+    shotgun_api_key = _get_api_key(config.shotgun.api_key, SHOTGUN_API_KEY_ENV)
+    if shotgun_api_key:
+        # Use selected model or default to claude-opus-4-1
+        model_name = config.selected_model or ModelName.CLAUDE_OPUS_4_1
+        if model_name not in MODEL_SPECS:
+            raise ValueError(f"Model '{model_name.value}' not found")
+        spec = MODEL_SPECS[model_name]
+
+        # Use Shotgun Account with selected model (provider = actual LLM provider)
+        return ModelConfig(
+            name=spec.name,
+            provider=spec.provider,  # Actual LLM provider (OPENAI/ANTHROPIC/GOOGLE)
+            key_provider=KeyProvider.SHOTGUN,  # Authenticated via Shotgun Account
+            max_input_tokens=spec.max_input_tokens,
+            max_output_tokens=spec.max_output_tokens,
+            api_key=shotgun_api_key,
+        )
+
+    # Priority 2: Fall back to individual provider keys
+
+    # Check if a specific model was requested
+    if isinstance(provider_or_model, ModelName):
+        # Look up the model spec
+        if provider_or_model not in MODEL_SPECS:
+            raise ValueError(f"Model '{provider_or_model.value}' not found")
+        spec = MODEL_SPECS[provider_or_model]
+        provider_enum = spec.provider
+        requested_model = provider_or_model
+    else:
+        # Convert string to ProviderType enum if needed (backward compatible)
+        if provider_or_model:
+            provider_enum = (
+                provider_or_model
+                if isinstance(provider_or_model, ProviderType)
+                else ProviderType(provider_or_model)
+            )
+        else:
+            # No provider specified - find first available provider with a key
+            provider_enum = None
+            for provider in ProviderType:
+                if _has_provider_key(config, provider):
+                    provider_enum = provider
+                    break
+
+            if provider_enum is None:
+                raise ValueError(
+                    "No provider keys configured. Set via environment variables or config."
+                )
+
+        requested_model = None  # Will use provider's default model
 
     if provider_enum == ProviderType.OPENAI:
         api_key = _get_api_key(config.openai.api_key, OPENAI_API_KEY_ENV)
@@ -127,16 +208,17 @@ def get_provider_model(provider: ProviderType | None = None) -> ModelConfig:
                 f"OpenAI API key not configured. Set via environment variable {OPENAI_API_KEY_ENV} or config."
             )
 
-        # Get model spec - hardcoded to gpt-5
-        model_name = "gpt-5"
+        # Use requested model or default to gpt-5
+        model_name = requested_model if requested_model else ModelName.GPT_5
         if model_name not in MODEL_SPECS:
-            raise ValueError(f"Model '{model_name}' not found")
+            raise ValueError(f"Model '{model_name.value}' not found")
         spec = MODEL_SPECS[model_name]
 
         # Create fully configured ModelConfig
        return ModelConfig(
             name=spec.name,
             provider=spec.provider,
+            key_provider=KeyProvider.BYOK,
             max_input_tokens=spec.max_input_tokens,
             max_output_tokens=spec.max_output_tokens,
             api_key=api_key,
@@ -149,16 +231,17 @@ def get_provider_model(provider: ProviderType | None = None) -> ModelConfig:
                 f"Anthropic API key not configured. Set via environment variable {ANTHROPIC_API_KEY_ENV} or config."
             )
 
-        # Get model spec - hardcoded to claude-opus-4-1
-        model_name = "claude-opus-4-1"
+        # Use requested model or default to claude-opus-4-1
+        model_name = requested_model if requested_model else ModelName.CLAUDE_OPUS_4_1
         if model_name not in MODEL_SPECS:
-            raise ValueError(f"Model '{model_name}' not found")
+            raise ValueError(f"Model '{model_name.value}' not found")
         spec = MODEL_SPECS[model_name]
 
         # Create fully configured ModelConfig
         return ModelConfig(
             name=spec.name,
             provider=spec.provider,
+            key_provider=KeyProvider.BYOK,
             max_input_tokens=spec.max_input_tokens,
             max_output_tokens=spec.max_output_tokens,
             api_key=api_key,
@@ -171,16 +254,17 @@ def get_provider_model(provider: ProviderType | None = None) -> ModelConfig:
                 f"Gemini API key not configured. Set via environment variable {GEMINI_API_KEY_ENV} or config."
             )
 
-        # Get model spec - hardcoded to gemini-2.5-pro
-        model_name = "gemini-2.5-pro"
+        # Use requested model or default to gemini-2.5-pro
+        model_name = requested_model if requested_model else ModelName.GEMINI_2_5_PRO
         if model_name not in MODEL_SPECS:
-            raise ValueError(f"Model '{model_name}' not found")
+            raise ValueError(f"Model '{model_name.value}' not found")
         spec = MODEL_SPECS[model_name]
 
         # Create fully configured ModelConfig
         return ModelConfig(
             name=spec.name,
             provider=spec.provider,
+            key_provider=KeyProvider.BYOK,
             max_input_tokens=spec.max_input_tokens,
             max_output_tokens=spec.max_output_tokens,
             api_key=api_key,
@@ -190,6 +274,25 @@ def get_provider_model(provider: ProviderType | None = None) -> ModelConfig:
         raise ValueError(f"Unsupported provider: {provider_enum}")
 
 
+def _has_provider_key(config: "ShotgunConfig", provider: ProviderType) -> bool:
+    """Check if a provider has a configured API key.
+
+    Args:
+        config: Shotgun configuration
+        provider: Provider to check
+
+    Returns:
+        True if provider has a configured API key
+    """
+    if provider == ProviderType.OPENAI:
+        return bool(_get_api_key(config.openai.api_key, OPENAI_API_KEY_ENV))
+    elif provider == ProviderType.ANTHROPIC:
+        return bool(_get_api_key(config.anthropic.api_key, ANTHROPIC_API_KEY_ENV))
+    elif provider == ProviderType.GOOGLE:
+        return bool(_get_api_key(config.google.api_key, GEMINI_API_KEY_ENV))
+    return False
+
+
 def _get_api_key(config_key: SecretStr | None, env_var: str) -> str | None:
     """Get API key from config or environment variable.
 
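Taken together, these hunks key the model cache on (provider, key_provider, model_name, api_key) and let get_provider_model accept either a ProviderType or a specific ModelName. A minimal usage sketch follows; it assumes ModelConfig exposes the fields passed to its constructor above (provider, key_provider, api_key), which this diff does not show explicitly:

    from shotgun.agents.config.models import ModelName
    from shotgun.agents.config.provider import get_or_create_model, get_provider_model

    # Resolve configuration: a configured Shotgun key wins for any model,
    # otherwise the first BYOK provider with a key is used.
    model_config = get_provider_model(ModelName.CLAUDE_OPUS_4_1)

    # Build (or reuse) the cached pydantic-ai Model for that configuration.
    # Attribute access below is an assumption based on the ModelConfig(...) calls in the diff.
    model = get_or_create_model(
        provider=model_config.provider,          # actual LLM provider, e.g. ANTHROPIC
        key_provider=model_config.key_provider,  # KeyProvider.BYOK or KeyProvider.SHOTGUN
        model_name=ModelName.CLAUDE_OPUS_4_1,
        api_key=model_config.api_key,
    )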
shotgun/agents/history/compaction.py

@@ -31,7 +31,7 @@ async def apply_persistent_compaction(
 
     try:
         # Count actual token usage using shared utility
-        estimated_tokens = estimate_tokens_from_messages(messages, deps.llm_model)
+        estimated_tokens = await estimate_tokens_from_messages(messages, deps.llm_model)
 
         # Create minimal usage info for compaction check
         usage = RequestUsage(
shotgun/agents/history/history_processors.py

@@ -2,6 +2,7 @@
 
 from typing import TYPE_CHECKING, Any, Protocol
 
+from pydantic_ai import ModelSettings
 from pydantic_ai.messages import (
     ModelMessage,
     ModelRequest,
@@ -10,7 +11,7 @@ from pydantic_ai.messages import (
     UserPromptPart,
 )
 
-from shotgun.agents.config.models import shotgun_model_request
+from shotgun.agents.llm import shotgun_model_request
 from shotgun.agents.messages import AgentSystemPrompt, SystemStatusPrompt
 from shotgun.agents.models import AgentDeps
 from shotgun.logging_config import get_logger
@@ -154,7 +155,7 @@ async def token_limit_compactor(
 
     if last_summary_index is not None:
         # Check if post-summary conversation exceeds threshold for incremental compaction
-        post_summary_tokens = estimate_post_summary_tokens(
+        post_summary_tokens = await estimate_post_summary_tokens(
             messages, last_summary_index, deps.llm_model
         )
         post_summary_percentage = (
@@ -248,7 +249,7 @@ async def token_limit_compactor(
         ]
 
         # Calculate optimal max_tokens for summarization
-        max_tokens = calculate_max_summarization_tokens(
+        max_tokens = await calculate_max_summarization_tokens(
             deps.llm_model, request_messages
         )
 
@@ -261,7 +262,9 @@ async def token_limit_compactor(
         summary_response = await shotgun_model_request(
             model_config=deps.llm_model,
             messages=request_messages,
-            max_tokens=max_tokens,  # Use calculated optimal tokens for summarization
+            model_settings=ModelSettings(
+                max_tokens=max_tokens  # Use calculated optimal tokens for summarization
+            ),
         )
 
         log_summarization_response(summary_response, "INCREMENTAL")
@@ -328,7 +331,9 @@ async def token_limit_compactor(
 
         # Track compaction completion
         messages_after = len(compacted_messages)
-        tokens_after = estimate_tokens_from_messages(compacted_messages, deps.llm_model)
+        tokens_after = await estimate_tokens_from_messages(
+            compacted_messages, deps.llm_model
+        )
         reduction_percentage = (
             ((messages_before - messages_after) / messages_before * 100)
             if messages_before > 0
@@ -354,7 +359,7 @@ async def token_limit_compactor(
 
     else:
         # Check if total conversation exceeds threshold for full compaction
-        total_tokens = estimate_tokens_from_messages(messages, deps.llm_model)
+        total_tokens = await estimate_tokens_from_messages(messages, deps.llm_model)
         total_percentage = (total_tokens / max_tokens) * 100 if max_tokens > 0 else 0
 
         logger.debug(
@@ -392,7 +397,9 @@ async def _full_compaction(
     ]
 
     # Calculate optimal max_tokens for summarization
-    max_tokens = calculate_max_summarization_tokens(deps.llm_model, request_messages)
+    max_tokens = await calculate_max_summarization_tokens(
+        deps.llm_model, request_messages
+    )
 
     # Debug logging using shared utilities
     log_summarization_request(
@@ -403,11 +410,13 @@ async def _full_compaction(
     summary_response = await shotgun_model_request(
         model_config=deps.llm_model,
         messages=request_messages,
-        max_tokens=max_tokens,  # Use calculated optimal tokens for summarization
+        model_settings=ModelSettings(
+            max_tokens=max_tokens  # Use calculated optimal tokens for summarization
+        ),
    )
 
     # Calculate token reduction
-    current_tokens = estimate_tokens_from_messages(messages, deps.llm_model)
+    current_tokens = await estimate_tokens_from_messages(messages, deps.llm_model)
     summary_usage = summary_response.usage
     reduction_percentage = (
         ((current_tokens - summary_usage.output_tokens) / current_tokens) * 100
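The recurring pattern in these hunks: the token-estimation helpers are now coroutines that must be awaited, and the summarization budget is passed through pydantic-ai's ModelSettings instead of a bare max_tokens keyword. A hedged sketch of a caller in that shape (the module path for calculate_max_summarization_tokens is assumed; the diff only shows its call sites):

    from pydantic_ai import ModelSettings
    from pydantic_ai.messages import ModelMessage

    from shotgun.agents.history.token_estimation import (  # assumed module path
        calculate_max_summarization_tokens,
    )
    from shotgun.agents.llm import shotgun_model_request
    from shotgun.agents.models import AgentDeps


    async def summarize(deps: AgentDeps, request_messages: list[ModelMessage]):
        # Both helpers are async in this release, so they are awaited.
        budget = await calculate_max_summarization_tokens(deps.llm_model, request_messages)
        return await shotgun_model_request(
            model_config=deps.llm_model,
            messages=request_messages,
            model_settings=ModelSettings(max_tokens=budget),
        )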
shotgun/agents/history/token_counting/__init__.py (new file)

@@ -0,0 +1,31 @@
+"""Real token counting for all supported providers.
+
+This module provides accurate token counting using each provider's official
+APIs and libraries, eliminating the need for rough character-based estimation.
+"""
+
+from .anthropic import AnthropicTokenCounter
+from .base import TokenCounter, extract_text_from_messages
+from .openai import OpenAITokenCounter
+from .sentencepiece_counter import SentencePieceTokenCounter
+from .utils import (
+    count_post_summary_tokens,
+    count_tokens_from_message_parts,
+    count_tokens_from_messages,
+    get_token_counter,
+)
+
+__all__ = [
+    # Base classes
+    "TokenCounter",
+    # Counter implementations
+    "OpenAITokenCounter",
+    "AnthropicTokenCounter",
+    "SentencePieceTokenCounter",
+    # Utility functions
+    "get_token_counter",
+    "count_tokens_from_messages",
+    "count_post_summary_tokens",
+    "count_tokens_from_message_parts",
+    "extract_text_from_messages",
+]
shotgun/agents/history/token_counting/anthropic.py (new file)

@@ -0,0 +1,89 @@
+"""Anthropic token counting using official client."""
+
+from pydantic_ai.messages import ModelMessage
+
+from shotgun.agents.config.models import KeyProvider
+from shotgun.llm_proxy import create_anthropic_proxy_client
+from shotgun.logging_config import get_logger
+
+from .base import TokenCounter, extract_text_from_messages
+
+logger = get_logger(__name__)
+
+
+class AnthropicTokenCounter(TokenCounter):
+    """Token counter for Anthropic models using official client."""
+
+    def __init__(
+        self,
+        model_name: str,
+        api_key: str,
+        key_provider: KeyProvider = KeyProvider.BYOK,
+    ):
+        """Initialize Anthropic token counter.
+
+        Args:
+            model_name: Anthropic model name for token counting
+            api_key: API key (Anthropic for BYOK, Shotgun for proxy)
+            key_provider: Key provider type (BYOK or SHOTGUN)
+
+        Raises:
+            RuntimeError: If client initialization fails
+        """
+        self.model_name = model_name
+        import anthropic
+
+        try:
+            if key_provider == KeyProvider.SHOTGUN:
+                # Use LiteLLM proxy for Shotgun Account
+                # Proxies to Anthropic's token counting API
+                self.client = create_anthropic_proxy_client(api_key)
+                logger.debug(
+                    f"Initialized Anthropic token counter for {model_name} via LiteLLM proxy"
+                )
+            else:
+                # Direct Anthropic API for BYOK
+                self.client = anthropic.Anthropic(api_key=api_key)
+                logger.debug(
+                    f"Initialized Anthropic token counter for {model_name} via direct API"
+                )
+        except Exception as e:
+            raise RuntimeError("Failed to initialize Anthropic client") from e
+
+    async def count_tokens(self, text: str) -> int:
+        """Count tokens using Anthropic's official API (async).
+
+        Args:
+            text: Text to count tokens for
+
+        Returns:
+            Exact token count from Anthropic API
+
+        Raises:
+            RuntimeError: If API call fails
+        """
+        try:
+            # Anthropic API expects messages format and model parameter
+            result = self.client.messages.count_tokens(
+                messages=[{"role": "user", "content": text}], model=self.model_name
+            )
+            return result.input_tokens
+        except Exception as e:
+            raise RuntimeError(
+                f"Anthropic token counting API failed for {self.model_name}"
+            ) from e
+
+    async def count_message_tokens(self, messages: list[ModelMessage]) -> int:
+        """Count tokens across all messages using Anthropic API (async).
+
+        Args:
+            messages: List of PydanticAI messages
+
+        Returns:
+            Total token count for all messages
+
+        Raises:
+            RuntimeError: If token counting fails
+        """
+        total_text = extract_text_from_messages(messages)
+        return await self.count_tokens(total_text)
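A short usage sketch of the counter above. The model name and key are placeholders, and the default BYOK path (direct Anthropic API) is assumed:

    import asyncio

    from shotgun.agents.history.token_counting.anthropic import AnthropicTokenCounter


    async def main() -> None:
        # Placeholder model name and API key; with the default KeyProvider.BYOK
        # the counter calls Anthropic's token counting endpoint directly.
        counter = AnthropicTokenCounter(model_name="claude-opus-4-1", api_key="sk-ant-...")
        print(await counter.count_tokens("How many tokens is this sentence?"))


    asyncio.run(main())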
shotgun/agents/history/token_counting/base.py (new file)

@@ -0,0 +1,67 @@
+"""Base classes and shared utilities for token counting."""
+
+from abc import ABC, abstractmethod
+
+from pydantic_ai.messages import ModelMessage
+
+
+class TokenCounter(ABC):
+    """Abstract base class for provider-specific token counting.
+
+    All methods are async to support non-blocking operations like
+    downloading tokenizer models or making API calls.
+    """
+
+    @abstractmethod
+    async def count_tokens(self, text: str) -> int:
+        """Count tokens in text using provider-specific method (async).
+
+        Args:
+            text: Text to count tokens for
+
+        Returns:
+            Exact token count as determined by the provider
+
+        Raises:
+            RuntimeError: If token counting fails
+        """
+
+    @abstractmethod
+    async def count_message_tokens(self, messages: list[ModelMessage]) -> int:
+        """Count tokens in PydanticAI message structures (async).
+
+        Args:
+            messages: List of messages to count tokens for
+
+        Returns:
+            Total token count across all messages
+
+        Raises:
+            RuntimeError: If token counting fails
+        """
+
+
+def extract_text_from_messages(messages: list[ModelMessage]) -> str:
+    """Extract all text content from messages for token counting.
+
+    Args:
+        messages: List of PydanticAI messages
+
+    Returns:
+        Combined text content from all messages
+    """
+    text_parts = []
+
+    for message in messages:
+        if hasattr(message, "parts"):
+            for part in message.parts:
+                if hasattr(part, "content") and isinstance(part.content, str):
+                    text_parts.append(part.content)
+                else:
+                    # Handle non-text parts (tool calls, etc.)
+                    text_parts.append(str(part))
+        else:
+            # Handle messages without parts
+            text_parts.append(str(message))
+
+    return "\n".join(text_parts)
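To illustrate the contract defined by this base class, here is a minimal, purely hypothetical counter that satisfies the ABC by splitting on whitespace; the real implementations added in this release use provider APIs or tokenizer files instead:

    from pydantic_ai.messages import ModelMessage

    from shotgun.agents.history.token_counting.base import (
        TokenCounter,
        extract_text_from_messages,
    )


    class WhitespaceTokenCounter(TokenCounter):
        """Toy counter: one token per whitespace-separated word (illustration only)."""

        async def count_tokens(self, text: str) -> int:
            return len(text.split())

        async def count_message_tokens(self, messages: list[ModelMessage]) -> int:
            return await self.count_tokens(extract_text_from_messages(messages))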