shotgun-sh 0.1.0.dev13__py3-none-any.whl → 0.1.0.dev14__py3-none-any.whl

This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.

Potentially problematic release.


Files changed (40)
  1. shotgun/agents/agent_manager.py +16 -3
  2. shotgun/agents/artifact_state.py +58 -0
  3. shotgun/agents/common.py +48 -14
  4. shotgun/agents/config/models.py +61 -0
  5. shotgun/agents/history/compaction.py +85 -0
  6. shotgun/agents/history/constants.py +19 -0
  7. shotgun/agents/history/context_extraction.py +108 -0
  8. shotgun/agents/history/history_building.py +104 -0
  9. shotgun/agents/history/history_processors.py +354 -157
  10. shotgun/agents/history/message_utils.py +46 -0
  11. shotgun/agents/history/token_counting.py +429 -0
  12. shotgun/agents/history/token_estimation.py +138 -0
  13. shotgun/agents/models.py +125 -1
  14. shotgun/agents/tools/artifact_management.py +56 -24
  15. shotgun/agents/tools/file_management.py +30 -11
  16. shotgun/agents/tools/web_search/anthropic.py +78 -17
  17. shotgun/agents/tools/web_search/gemini.py +1 -1
  18. shotgun/agents/tools/web_search/openai.py +16 -2
  19. shotgun/artifacts/manager.py +2 -1
  20. shotgun/artifacts/models.py +6 -4
  21. shotgun/codebase/core/nl_query.py +4 -4
  22. shotgun/prompts/agents/partials/artifact_system.j2 +4 -1
  23. shotgun/prompts/agents/partials/codebase_understanding.j2 +1 -2
  24. shotgun/prompts/agents/plan.j2 +9 -7
  25. shotgun/prompts/agents/research.j2 +7 -5
  26. shotgun/prompts/agents/specify.j2 +8 -7
  27. shotgun/prompts/agents/state/artifact_templates_available.j2 +18 -0
  28. shotgun/prompts/agents/state/codebase/codebase_graphs_available.j2 +3 -1
  29. shotgun/prompts/agents/state/existing_artifacts_available.j2 +23 -0
  30. shotgun/prompts/agents/state/system_state.j2 +9 -1
  31. shotgun/prompts/history/incremental_summarization.j2 +53 -0
  32. shotgun/sdk/services.py +14 -0
  33. shotgun/tui/app.py +1 -1
  34. shotgun/tui/screens/chat.py +4 -2
  35. shotgun/utils/file_system_utils.py +6 -1
  36. {shotgun_sh-0.1.0.dev13.dist-info → shotgun_sh-0.1.0.dev14.dist-info}/METADATA +2 -1
  37. {shotgun_sh-0.1.0.dev13.dist-info → shotgun_sh-0.1.0.dev14.dist-info}/RECORD +40 -29
  38. {shotgun_sh-0.1.0.dev13.dist-info → shotgun_sh-0.1.0.dev14.dist-info}/WHEEL +0 -0
  39. {shotgun_sh-0.1.0.dev13.dist-info → shotgun_sh-0.1.0.dev14.dist-info}/entry_points.txt +0 -0
  40. {shotgun_sh-0.1.0.dev13.dist-info → shotgun_sh-0.1.0.dev14.dist-info}/licenses/LICENSE +0 -0
shotgun/agents/history/token_counting.py (new file)
@@ -0,0 +1,429 @@
+"""Real token counting for all supported providers.
+
+This module provides accurate token counting using each provider's official
+APIs and libraries, eliminating the need for rough character-based estimation.
+"""
+
+from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING
+
+from pydantic_ai.messages import ModelMessage
+
+from shotgun.agents.config.models import ModelConfig, ProviderType
+from shotgun.logging_config import get_logger
+
+if TYPE_CHECKING:
+    pass
+
+logger = get_logger(__name__)
+
+# Global cache for token counter instances (singleton pattern)
+_token_counter_cache: dict[tuple[str, str, str], "TokenCounter"] = {}
+
+
+class TokenCounter(ABC):
+    """Abstract base class for provider-specific token counting."""
+
+    @abstractmethod
+    def count_tokens(self, text: str) -> int:
+        """Count tokens in text using provider-specific method.
+
+        Args:
+            text: Text to count tokens for
+
+        Returns:
+            Exact token count as determined by the provider
+
+        Raises:
+            RuntimeError: If token counting fails
+        """
+
+    @abstractmethod
+    def count_message_tokens(self, messages: list[ModelMessage]) -> int:
+        """Count tokens in PydanticAI message structures.
+
+        Args:
+            messages: List of messages to count tokens for
+
+        Returns:
+            Total token count across all messages
+
+        Raises:
+            RuntimeError: If token counting fails
+        """
+
+
+class OpenAITokenCounter(TokenCounter):
+    """Token counter for OpenAI models using tiktoken."""
+
+    # Official encoding mappings for OpenAI models
+    ENCODING_MAP = {
+        "gpt-5": "o200k_base",
+        "gpt-4o": "o200k_base",
+        "gpt-4": "cl100k_base",
+        "gpt-3.5-turbo": "cl100k_base",
+    }
+
+    def __init__(self, model_name: str):
+        """Initialize OpenAI token counter.
+
+        Args:
+            model_name: OpenAI model name to get correct encoding for
+
+        Raises:
+            RuntimeError: If encoding initialization fails
+        """
+        self.model_name = model_name
+
+        import tiktoken
+
+        try:
+            # Get the appropriate encoding for this model
+            encoding_name = self.ENCODING_MAP.get(model_name, "o200k_base")
+            self.encoding = tiktoken.get_encoding(encoding_name)
+            logger.debug(
+                f"Initialized OpenAI token counter with {encoding_name} encoding"
+            )
+        except Exception as e:
+            raise RuntimeError(
+                f"Failed to initialize tiktoken encoding for {model_name}"
+            ) from e
+
+    def count_tokens(self, text: str) -> int:
+        """Count tokens using tiktoken.
+
+        Args:
+            text: Text to count tokens for
+
+        Returns:
+            Exact token count using tiktoken
+
+        Raises:
+            RuntimeError: If token counting fails
+        """
+        try:
+            return len(self.encoding.encode(text))
+        except Exception as e:
+            raise RuntimeError(
+                f"Failed to count tokens for OpenAI model {self.model_name}"
+            ) from e
+
+    def count_message_tokens(self, messages: list[ModelMessage]) -> int:
+        """Count tokens across all messages using tiktoken.
+
+        Args:
+            messages: List of PydanticAI messages
+
+        Returns:
+            Total token count for all messages
+
+        Raises:
+            RuntimeError: If token counting fails
+        """
+        total_text = self._extract_text_from_messages(messages)
+        return self.count_tokens(total_text)
+
+    def _extract_text_from_messages(self, messages: list[ModelMessage]) -> str:
+        """Extract all text content from messages for token counting."""
+        text_parts = []
+
+        for message in messages:
+            if hasattr(message, "parts"):
+                for part in message.parts:
+                    if hasattr(part, "content") and isinstance(part.content, str):
+                        text_parts.append(part.content)
+                    else:
+                        # Handle non-text parts (tool calls, etc.)
+                        text_parts.append(str(part))
+            else:
+                # Handle messages without parts
+                text_parts.append(str(message))
+
+        return "\n".join(text_parts)
+
+
+class AnthropicTokenCounter(TokenCounter):
+    """Token counter for Anthropic models using official client."""
+
+    def __init__(self, model_name: str, api_key: str):
+        """Initialize Anthropic token counter.
+
+        Args:
+            model_name: Anthropic model name for token counting
+            api_key: Anthropic API key
+
+        Raises:
+            RuntimeError: If client initialization fails
+        """
+        self.model_name = model_name
+        import anthropic
+
+        try:
+            self.client = anthropic.Anthropic(api_key=api_key)
+            logger.debug(f"Initialized Anthropic token counter for {model_name}")
+        except Exception as e:
+            raise RuntimeError("Failed to initialize Anthropic client") from e
+
+    def count_tokens(self, text: str) -> int:
+        """Count tokens using Anthropic's official API.
+
+        Args:
+            text: Text to count tokens for
+
+        Returns:
+            Exact token count from Anthropic API
+
+        Raises:
+            RuntimeError: If API call fails
+        """
+        try:
+            # Anthropic API expects messages format and model parameter
+            result = self.client.messages.count_tokens(
+                messages=[{"role": "user", "content": text}], model=self.model_name
+            )
+            return result.input_tokens
+        except Exception as e:
+            raise RuntimeError(
+                f"Anthropic token counting API failed for {self.model_name}"
+            ) from e
+
+    def count_message_tokens(self, messages: list[ModelMessage]) -> int:
+        """Count tokens across all messages using Anthropic API.
+
+        Args:
+            messages: List of PydanticAI messages
+
+        Returns:
+            Total token count for all messages
+
+        Raises:
+            RuntimeError: If token counting fails
+        """
+        total_text = self._extract_text_from_messages(messages)
+        return self.count_tokens(total_text)
+
+    def _extract_text_from_messages(self, messages: list[ModelMessage]) -> str:
+        """Extract all text content from messages for token counting."""
+        text_parts = []
+
+        for message in messages:
+            if hasattr(message, "parts"):
+                for part in message.parts:
+                    if hasattr(part, "content") and isinstance(part.content, str):
+                        text_parts.append(part.content)
+                    else:
+                        # Handle non-text parts (tool calls, etc.)
+                        text_parts.append(str(part))
+            else:
+                # Handle messages without parts
+                text_parts.append(str(message))
+
+        return "\n".join(text_parts)
+
+
+class GoogleTokenCounter(TokenCounter):
+    """Token counter for Google models using genai API."""
+
+    def __init__(self, model_name: str, api_key: str):
+        """Initialize Google token counter.
+
+        Args:
+            model_name: Google model name
+            api_key: Google API key
+
+        Raises:
+            RuntimeError: If configuration fails
+        """
+        self.model_name = model_name
+
+        import google.generativeai as genai
+
+        try:
+            genai.configure(api_key=api_key)  # type: ignore[attr-defined]
+            self.model = genai.GenerativeModel(model_name)  # type: ignore[attr-defined]
+            logger.debug(f"Initialized Google token counter for {model_name}")
+        except Exception as e:
+            raise RuntimeError(
+                f"Failed to configure Google genai client for {model_name}"
+            ) from e
+
+    def count_tokens(self, text: str) -> int:
+        """Count tokens using Google's genai API.
+
+        Args:
+            text: Text to count tokens for
+
+        Returns:
+            Exact token count from Google API
+
+        Raises:
+            RuntimeError: If API call fails
+        """
+        try:
+            result = self.model.count_tokens(text)
+            return result.total_tokens
+        except Exception as e:
+            raise RuntimeError(
+                f"Google token counting API failed for {self.model_name}"
+            ) from e
+
+    def count_message_tokens(self, messages: list[ModelMessage]) -> int:
+        """Count tokens across all messages using Google API.
+
+        Args:
+            messages: List of PydanticAI messages
+
+        Returns:
+            Total token count for all messages
+
+        Raises:
+            RuntimeError: If token counting fails
+        """
+        total_text = self._extract_text_from_messages(messages)
+        return self.count_tokens(total_text)
+
+    def _extract_text_from_messages(self, messages: list[ModelMessage]) -> str:
+        """Extract all text content from messages for token counting."""
+        text_parts = []
+
+        for message in messages:
+            if hasattr(message, "parts"):
+                for part in message.parts:
+                    if hasattr(part, "content") and isinstance(part.content, str):
+                        text_parts.append(part.content)
+                    else:
+                        # Handle non-text parts (tool calls, etc.)
+                        text_parts.append(str(part))
+            else:
+                # Handle messages without parts
+                text_parts.append(str(message))
+
+        return "\n".join(text_parts)
+
+
+def get_token_counter(model_config: ModelConfig) -> TokenCounter:
+    """Get appropriate token counter for the model provider (cached singleton).
+
+    This function ensures that every provider has a proper token counting
+    implementation without any fallbacks to estimation. Token counters are
+    cached to avoid repeated initialization overhead.
+
+    Args:
+        model_config: Model configuration with provider and credentials
+
+    Returns:
+        Cached provider-specific token counter
+
+    Raises:
+        ValueError: If provider is not supported for token counting
+        RuntimeError: If token counter initialization fails
+    """
+    # Create cache key from provider, model name, and API key
+    cache_key = (
+        model_config.provider.value,
+        model_config.name,
+        model_config.api_key[:10]
+        if model_config.api_key
+        else "no-key",  # Partial key for cache
+    )
+
+    # Return cached instance if available
+    if cache_key in _token_counter_cache:
+        logger.debug(
+            f"Reusing cached token counter for {model_config.provider.value}:{model_config.name}"
+        )
+        return _token_counter_cache[cache_key]
+
+    # Create new instance and cache it
+    logger.debug(
+        f"Creating new token counter for {model_config.provider.value}:{model_config.name}"
+    )
+
+    counter: TokenCounter
+    if model_config.provider == ProviderType.OPENAI:
+        counter = OpenAITokenCounter(model_config.name)
+    elif model_config.provider == ProviderType.ANTHROPIC:
+        counter = AnthropicTokenCounter(model_config.name, model_config.api_key)
+    elif model_config.provider == ProviderType.GOOGLE:
+        counter = GoogleTokenCounter(model_config.name, model_config.api_key)
+    else:
+        raise ValueError(
+            f"Unsupported provider for token counting: {model_config.provider}. "
+            f"Supported providers: {[p.value for p in ProviderType]}"
+        )
+
+    # Cache the instance
+    _token_counter_cache[cache_key] = counter
+    logger.debug(
+        f"Cached token counter for {model_config.provider.value}:{model_config.name}"
+    )
+
+    return counter
+
+
+def count_tokens_from_messages(
+    messages: list[ModelMessage], model_config: ModelConfig
+) -> int:
+    """Count actual tokens from messages using provider-specific methods.
+
+    This replaces the old estimation approach with accurate token counting
+    using each provider's official APIs and libraries.
+
+    Args:
+        messages: List of messages to count tokens for
+        model_config: Model configuration with provider info
+
+    Returns:
+        Exact token count for the messages
+
+    Raises:
+        ValueError: If provider is not supported
+        RuntimeError: If token counting fails
+    """
+    counter = get_token_counter(model_config)
+    return counter.count_message_tokens(messages)
+
+
+def count_post_summary_tokens(
+    messages: list[ModelMessage], summary_index: int, model_config: ModelConfig
+) -> int:
+    """Count actual tokens from summary onwards for incremental compaction decisions.
+
+    Args:
+        messages: Full message history
+        summary_index: Index of the last summary message
+        model_config: Model configuration with provider info
+
+    Returns:
+        Exact token count from summary onwards
+
+    Raises:
+        ValueError: If provider is not supported
+        RuntimeError: If token counting fails
+    """
+    if summary_index >= len(messages):
+        return 0
+
+    post_summary_messages = messages[summary_index:]
+    return count_tokens_from_messages(post_summary_messages, model_config)
+
+
+def count_tokens_from_message_parts(
+    messages: list[ModelMessage], model_config: ModelConfig
+) -> int:
+    """Count actual tokens from message parts for summarization requests.
+
+    Args:
+        messages: List of messages to count tokens for
+        model_config: Model configuration with provider info
+
+    Returns:
+        Exact token count from message parts
+
+    Raises:
+        ValueError: If provider is not supported
+        RuntimeError: If token counting fails
+    """
+    # For now, use the same logic as count_tokens_from_messages
+    # This can be optimized later if needed for different counting strategies
+    return count_tokens_from_messages(messages, model_config)
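
For orientation only (not part of the release), a minimal sketch of how the new module might be called. The ModelConfig constructor arguments and the pydantic-ai message construction are assumptions for illustration; this diff only shows that the counters read provider, name, and api_key from the config.

from pydantic_ai.messages import ModelRequest, UserPromptPart

from shotgun.agents.config.models import ModelConfig, ProviderType
from shotgun.agents.history.token_counting import count_tokens_from_messages

# Hypothetical config; field names beyond provider/name/api_key are not shown in this diff.
config = ModelConfig(provider=ProviderType.OPENAI, name="gpt-4o", api_key="sk-...")
messages = [ModelRequest(parts=[UserPromptPart(content="Summarize the repo layout.")])]

# gpt-4o maps to the o200k_base encoding in ENCODING_MAP; the counter instance is
# cached, so repeated calls with the same provider/model/key reuse it.
print(count_tokens_from_messages(messages, config))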
shotgun/agents/history/token_estimation.py (new file)
@@ -0,0 +1,138 @@
+"""Real token counting utilities for history processing.
+
+This module provides accurate token counting using provider-specific APIs
+and libraries, replacing the old character-based estimation approach.
+"""
+
+from typing import TYPE_CHECKING, Union
+
+from pydantic_ai.messages import ModelMessage
+
+from shotgun.agents.config.models import ModelConfig
+
+if TYPE_CHECKING:
+    from pydantic_ai import RunContext
+
+    from shotgun.agents.models import AgentDeps
+
+from .constants import INPUT_BUFFER_TOKENS, MIN_SUMMARY_TOKENS
+from .token_counting import count_tokens_from_messages as _count_tokens_from_messages
+
+
+def estimate_tokens_from_messages(
+    messages: list[ModelMessage], model_config: ModelConfig
+) -> int:
+    """Count actual tokens from current message list.
+
+    This provides accurate token counting for compaction decisions using
+    provider-specific token counting methods instead of rough estimation.
+
+    Args:
+        messages: List of messages to count tokens for
+        model_config: Model configuration with provider info
+
+    Returns:
+        Exact token count using provider-specific counting
+
+    Raises:
+        ValueError: If provider is not supported
+        RuntimeError: If token counting fails
+    """
+    return _count_tokens_from_messages(messages, model_config)
+
+
+def estimate_post_summary_tokens(
+    messages: list[ModelMessage], summary_index: int, model_config: ModelConfig
+) -> int:
+    """Count actual tokens from summary onwards for incremental compaction decisions.
+
+    This treats the summary as a reset point and only counts tokens from the summary
+    message onwards. Used to determine if incremental compaction is needed.
+
+    Args:
+        messages: Full message history
+        summary_index: Index of the last summary message
+        model_config: Model configuration with provider info
+
+    Returns:
+        Exact token count from summary onwards
+
+    Raises:
+        ValueError: If provider is not supported
+        RuntimeError: If token counting fails
+    """
+    if summary_index >= len(messages):
+        return 0
+
+    post_summary_messages = messages[summary_index:]
+    return estimate_tokens_from_messages(post_summary_messages, model_config)
+
+
+def estimate_tokens_from_message_parts(
+    messages: list[ModelMessage], model_config: ModelConfig
+) -> int:
+    """Count actual tokens from message parts for summarization requests.
+
+    This provides accurate token counting across the codebase using
+    provider-specific methods instead of character estimation.
+
+    Args:
+        messages: List of messages to count tokens for
+        model_config: Model configuration with provider info
+
+    Returns:
+        Exact token count from message parts
+
+    Raises:
+        ValueError: If provider is not supported
+        RuntimeError: If token counting fails
+    """
+    return _count_tokens_from_messages(messages, model_config)
+
+
+def calculate_max_summarization_tokens(
+    ctx_or_model_config: Union["RunContext[AgentDeps]", ModelConfig],
+    request_messages: list[ModelMessage],
+) -> int:
+    """Calculate maximum tokens available for summarization output.
+
+    This ensures we use the model's full capacity while leaving room for input tokens.
+
+    Args:
+        ctx_or_model_config: RunContext or model configuration with token limits
+        request_messages: The messages that will be sent for summarization
+
+    Returns:
+        Maximum tokens available for the summarization response
+    """
+    # Support both RunContext and direct model config
+    if hasattr(ctx_or_model_config, "deps"):
+        model_config = ctx_or_model_config.deps.llm_model
+    else:
+        model_config = ctx_or_model_config
+
+    if not model_config:
+        return MIN_SUMMARY_TOKENS
+
+    # Count actual input tokens using shared utility
+    estimated_input_tokens = estimate_tokens_from_message_parts(
+        request_messages, model_config
+    )
+
+    # Add buffer for prompt overhead, system instructions, etc.
+    total_estimated_input = estimated_input_tokens + INPUT_BUFFER_TOKENS
+
+    # For models with combined token limits (like GPT), use total limit
+    # For models with separate limits (like Claude), use output limit directly
+    if hasattr(model_config, "max_total_tokens"):
+        # Combined limit model
+        available_for_output = (
+            int(model_config.max_total_tokens) - total_estimated_input
+        )
+        max_output = min(available_for_output, int(model_config.max_output_tokens))
+    else:
+        # Separate limits model - just use max_output_tokens
+        max_output = int(model_config.max_output_tokens)
+
+    # Ensure we don't go below a minimum useful amount
+    return max(MIN_SUMMARY_TOKENS, max_output)
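
For orientation only (not part of the release), a self-contained sketch of the output-budget arithmetic that calculate_max_summarization_tokens performs for a combined-limit model. The constants and token figures below are illustrative assumptions; the real INPUT_BUFFER_TOKENS and MIN_SUMMARY_TOKENS live in shotgun/agents/history/constants.py and are not shown in this diff.

# Assumed values for illustration only.
INPUT_BUFFER_TOKENS = 2_000
MIN_SUMMARY_TOKENS = 1_000
max_total_tokens = 128_000      # combined input+output limit (GPT-style config)
max_output_tokens = 16_000
counted_input_tokens = 110_000  # what estimate_tokens_from_message_parts would return

# Mirror of the function's logic: leave room for the input plus a prompt-overhead
# buffer, cap at the model's output limit, and never drop below the minimum.
available_for_output = max_total_tokens - (counted_input_tokens + INPUT_BUFFER_TOKENS)
budget = max(MIN_SUMMARY_TOKENS, min(available_for_output, max_output_tokens))
print(budget)  # 16000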