shotgun-sh 0.1.14-py3-none-any.whl → 0.2.11-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of shotgun-sh might be problematic.
- shotgun/agents/agent_manager.py +715 -75
- shotgun/agents/common.py +80 -75
- shotgun/agents/config/constants.py +21 -10
- shotgun/agents/config/manager.py +322 -97
- shotgun/agents/config/models.py +114 -84
- shotgun/agents/config/provider.py +232 -88
- shotgun/agents/context_analyzer/__init__.py +28 -0
- shotgun/agents/context_analyzer/analyzer.py +471 -0
- shotgun/agents/context_analyzer/constants.py +9 -0
- shotgun/agents/context_analyzer/formatter.py +115 -0
- shotgun/agents/context_analyzer/models.py +212 -0
- shotgun/agents/conversation_history.py +125 -2
- shotgun/agents/conversation_manager.py +57 -19
- shotgun/agents/export.py +6 -7
- shotgun/agents/history/compaction.py +10 -5
- shotgun/agents/history/context_extraction.py +93 -6
- shotgun/agents/history/history_processors.py +129 -12
- shotgun/agents/history/token_counting/__init__.py +31 -0
- shotgun/agents/history/token_counting/anthropic.py +127 -0
- shotgun/agents/history/token_counting/base.py +78 -0
- shotgun/agents/history/token_counting/openai.py +90 -0
- shotgun/agents/history/token_counting/sentencepiece_counter.py +127 -0
- shotgun/agents/history/token_counting/tokenizer_cache.py +92 -0
- shotgun/agents/history/token_counting/utils.py +144 -0
- shotgun/agents/history/token_estimation.py +12 -12
- shotgun/agents/llm.py +62 -0
- shotgun/agents/models.py +59 -4
- shotgun/agents/plan.py +6 -7
- shotgun/agents/research.py +7 -8
- shotgun/agents/specify.py +6 -7
- shotgun/agents/tasks.py +6 -7
- shotgun/agents/tools/__init__.py +0 -2
- shotgun/agents/tools/codebase/codebase_shell.py +6 -0
- shotgun/agents/tools/codebase/directory_lister.py +6 -0
- shotgun/agents/tools/codebase/file_read.py +11 -2
- shotgun/agents/tools/codebase/query_graph.py +6 -0
- shotgun/agents/tools/codebase/retrieve_code.py +6 -0
- shotgun/agents/tools/file_management.py +82 -16
- shotgun/agents/tools/registry.py +217 -0
- shotgun/agents/tools/web_search/__init__.py +55 -16
- shotgun/agents/tools/web_search/anthropic.py +76 -51
- shotgun/agents/tools/web_search/gemini.py +50 -27
- shotgun/agents/tools/web_search/openai.py +26 -17
- shotgun/agents/tools/web_search/utils.py +2 -2
- shotgun/agents/usage_manager.py +164 -0
- shotgun/api_endpoints.py +15 -0
- shotgun/cli/clear.py +53 -0
- shotgun/cli/compact.py +186 -0
- shotgun/cli/config.py +41 -67
- shotgun/cli/context.py +111 -0
- shotgun/cli/export.py +1 -1
- shotgun/cli/feedback.py +50 -0
- shotgun/cli/models.py +3 -2
- shotgun/cli/plan.py +1 -1
- shotgun/cli/research.py +1 -1
- shotgun/cli/specify.py +1 -1
- shotgun/cli/tasks.py +1 -1
- shotgun/cli/update.py +16 -2
- shotgun/codebase/core/change_detector.py +5 -3
- shotgun/codebase/core/code_retrieval.py +4 -2
- shotgun/codebase/core/ingestor.py +57 -16
- shotgun/codebase/core/manager.py +20 -7
- shotgun/codebase/core/nl_query.py +1 -1
- shotgun/codebase/models.py +4 -4
- shotgun/exceptions.py +32 -0
- shotgun/llm_proxy/__init__.py +19 -0
- shotgun/llm_proxy/clients.py +44 -0
- shotgun/llm_proxy/constants.py +15 -0
- shotgun/logging_config.py +18 -27
- shotgun/main.py +91 -12
- shotgun/posthog_telemetry.py +81 -10
- shotgun/prompts/agents/export.j2 +18 -1
- shotgun/prompts/agents/partials/common_agent_system_prompt.j2 +5 -1
- shotgun/prompts/agents/partials/interactive_mode.j2 +24 -7
- shotgun/prompts/agents/plan.j2 +1 -1
- shotgun/prompts/agents/research.j2 +1 -1
- shotgun/prompts/agents/specify.j2 +270 -3
- shotgun/prompts/agents/state/system_state.j2 +4 -0
- shotgun/prompts/agents/tasks.j2 +1 -1
- shotgun/prompts/loader.py +2 -2
- shotgun/prompts/tools/web_search.j2 +14 -0
- shotgun/sentry_telemetry.py +27 -18
- shotgun/settings.py +238 -0
- shotgun/shotgun_web/__init__.py +19 -0
- shotgun/shotgun_web/client.py +138 -0
- shotgun/shotgun_web/constants.py +21 -0
- shotgun/shotgun_web/models.py +47 -0
- shotgun/telemetry.py +24 -36
- shotgun/tui/app.py +251 -23
- shotgun/tui/commands/__init__.py +1 -1
- shotgun/tui/components/context_indicator.py +179 -0
- shotgun/tui/components/mode_indicator.py +70 -0
- shotgun/tui/components/status_bar.py +48 -0
- shotgun/tui/containers.py +91 -0
- shotgun/tui/dependencies.py +39 -0
- shotgun/tui/protocols.py +45 -0
- shotgun/tui/screens/chat/__init__.py +5 -0
- shotgun/tui/screens/chat/chat.tcss +54 -0
- shotgun/tui/screens/chat/chat_screen.py +1234 -0
- shotgun/tui/screens/chat/codebase_index_prompt_screen.py +64 -0
- shotgun/tui/screens/chat/codebase_index_selection.py +12 -0
- shotgun/tui/screens/chat/help_text.py +40 -0
- shotgun/tui/screens/chat/prompt_history.py +48 -0
- shotgun/tui/screens/chat.tcss +11 -0
- shotgun/tui/screens/chat_screen/command_providers.py +226 -11
- shotgun/tui/screens/chat_screen/history/__init__.py +22 -0
- shotgun/tui/screens/chat_screen/history/agent_response.py +66 -0
- shotgun/tui/screens/chat_screen/history/chat_history.py +116 -0
- shotgun/tui/screens/chat_screen/history/formatters.py +115 -0
- shotgun/tui/screens/chat_screen/history/partial_response.py +43 -0
- shotgun/tui/screens/chat_screen/history/user_question.py +42 -0
- shotgun/tui/screens/confirmation_dialog.py +151 -0
- shotgun/tui/screens/feedback.py +193 -0
- shotgun/tui/screens/github_issue.py +102 -0
- shotgun/tui/screens/model_picker.py +352 -0
- shotgun/tui/screens/onboarding.py +431 -0
- shotgun/tui/screens/pipx_migration.py +153 -0
- shotgun/tui/screens/provider_config.py +156 -39
- shotgun/tui/screens/shotgun_auth.py +295 -0
- shotgun/tui/screens/welcome.py +198 -0
- shotgun/tui/services/__init__.py +5 -0
- shotgun/tui/services/conversation_service.py +184 -0
- shotgun/tui/state/__init__.py +7 -0
- shotgun/tui/state/processing_state.py +185 -0
- shotgun/tui/utils/mode_progress.py +14 -7
- shotgun/tui/widgets/__init__.py +5 -0
- shotgun/tui/widgets/widget_coordinator.py +262 -0
- shotgun/utils/datetime_utils.py +77 -0
- shotgun/utils/env_utils.py +13 -0
- shotgun/utils/file_system_utils.py +22 -2
- shotgun/utils/marketing.py +110 -0
- shotgun/utils/update_checker.py +69 -14
- shotgun_sh-0.2.11.dist-info/METADATA +130 -0
- shotgun_sh-0.2.11.dist-info/RECORD +194 -0
- {shotgun_sh-0.1.14.dist-info → shotgun_sh-0.2.11.dist-info}/entry_points.txt +1 -0
- {shotgun_sh-0.1.14.dist-info → shotgun_sh-0.2.11.dist-info}/licenses/LICENSE +1 -1
- shotgun/agents/history/token_counting.py +0 -429
- shotgun/agents/tools/user_interaction.py +0 -37
- shotgun/tui/screens/chat.py +0 -797
- shotgun/tui/screens/chat_screen/history.py +0 -350
- shotgun_sh-0.1.14.dist-info/METADATA +0 -466
- shotgun_sh-0.1.14.dist-info/RECORD +0 -133
- {shotgun_sh-0.1.14.dist-info → shotgun_sh-0.2.11.dist-info}/WHEEL +0 -0
shotgun/agents/history/token_counting/openai.py
ADDED
@@ -0,0 +1,90 @@
+"""OpenAI token counting using tiktoken."""
+
+from pydantic_ai.messages import ModelMessage
+
+from shotgun.logging_config import get_logger
+
+from .base import TokenCounter, extract_text_from_messages
+
+logger = get_logger(__name__)
+
+
+class OpenAITokenCounter(TokenCounter):
+    """Token counter for OpenAI models using tiktoken."""
+
+    # Official encoding mappings for OpenAI models
+    ENCODING_MAP = {
+        "gpt-5": "o200k_base",
+        "gpt-4o": "o200k_base",
+        "gpt-4": "cl100k_base",
+        "gpt-3.5-turbo": "cl100k_base",
+    }
+
+    def __init__(self, model_name: str):
+        """Initialize OpenAI token counter.
+
+        Args:
+            model_name: OpenAI model name to get correct encoding for
+
+        Raises:
+            RuntimeError: If encoding initialization fails
+        """
+        self.model_name = model_name
+
+        import tiktoken
+
+        try:
+            # Get the appropriate encoding for this model
+            encoding_name = self.ENCODING_MAP.get(model_name, "o200k_base")
+            self.encoding = tiktoken.get_encoding(encoding_name)
+            logger.debug(
+                f"Initialized OpenAI token counter with {encoding_name} encoding"
+            )
+        except Exception as e:
+            raise RuntimeError(
+                f"Failed to initialize tiktoken encoding for {model_name}"
+            ) from e
+
+    async def count_tokens(self, text: str) -> int:
+        """Count tokens using tiktoken (async).
+
+        Args:
+            text: Text to count tokens for
+
+        Returns:
+            Exact token count using tiktoken
+
+        Raises:
+            RuntimeError: If token counting fails
+        """
+        # Handle empty text to avoid unnecessary encoding
+        if not text or not text.strip():
+            return 0
+
+        try:
+            return len(self.encoding.encode(text))
+        except BaseException as e:
+            # Must catch BaseException to handle PanicException from tiktoken's Rust layer
+            # which can occur with extremely long texts. Regular Exception won't catch it.
+            raise RuntimeError(
+                f"Failed to count tokens for OpenAI model {self.model_name}"
+            ) from e
+
+    async def count_message_tokens(self, messages: list[ModelMessage]) -> int:
+        """Count tokens across all messages using tiktoken (async).
+
+        Args:
+            messages: List of PydanticAI messages
+
+        Returns:
+            Total token count for all messages
+
+        Raises:
+            RuntimeError: If token counting fails
+        """
+        # Handle empty message list early
+        if not messages:
+            return 0
+
+        total_text = extract_text_from_messages(messages)
+        return await self.count_tokens(total_text)
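For orientation, a minimal usage sketch of the counter above (not part of the diff). It assumes shotgun-sh 0.2.11 and tiktoken are installed, and that the import path matches the file-list entry shotgun/agents/history/token_counting/openai.py.

```python
import asyncio

# Assumed import path, mirroring the new module layout shown in the file list above.
from shotgun.agents.history.token_counting.openai import OpenAITokenCounter


async def main() -> None:
    # Unknown model names fall back to the o200k_base encoding via ENCODING_MAP.get(...).
    counter = OpenAITokenCounter("gpt-4o")
    print(await counter.count_tokens("Hello from shotgun-sh 0.2.11"))


asyncio.run(main())
```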
shotgun/agents/history/token_counting/sentencepiece_counter.py
ADDED
@@ -0,0 +1,127 @@
+"""Gemini token counting using official SentencePiece tokenizer.
+
+This implementation uses Google's official Gemini/Gemma tokenizer model
+for 100% accurate local token counting without API calls.
+
+Performance: 10-100x faster than API-based counting.
+Accuracy: 100% match with actual Gemini API usage.
+
+The tokenizer is downloaded on first use and cached locally for future use.
+"""
+
+from typing import Any
+
+from pydantic_ai.messages import ModelMessage
+
+from shotgun.logging_config import get_logger
+
+from .base import TokenCounter, extract_text_from_messages
+from .tokenizer_cache import download_gemini_tokenizer, get_gemini_tokenizer_path
+
+logger = get_logger(__name__)
+
+
+class SentencePieceTokenCounter(TokenCounter):
+    """Token counter for Gemini models using official SentencePiece tokenizer.
+
+    This counter provides 100% accurate token counting for Gemini models
+    using the official tokenizer model from Google's gemma_pytorch repository.
+    Token counting is performed locally without any API calls, resulting in
+    10-100x performance improvement over API-based methods.
+
+    The tokenizer is downloaded asynchronously on first use and cached locally.
+    """
+
+    def __init__(self, model_name: str):
+        """Initialize Gemini SentencePiece token counter.
+
+        The tokenizer is not loaded immediately - it will be downloaded and
+        loaded lazily on first use.
+
+        Args:
+            model_name: Gemini model name (used for logging)
+        """
+        self.model_name = model_name
+        self.sp: Any | None = None  # SentencePieceProcessor, loaded lazily
+
+    async def _ensure_tokenizer(self) -> None:
+        """Ensure tokenizer is downloaded and loaded.
+
+        This method downloads the tokenizer on first call (if not cached)
+        and loads it into memory. Subsequent calls reuse the loaded tokenizer.
+
+        Raises:
+            RuntimeError: If tokenizer download or loading fails
+        """
+        if self.sp is not None:
+            # Already loaded
+            return
+
+        import sentencepiece as spm  # type: ignore[import-untyped]
+
+        try:
+            # Check if already cached, otherwise download
+            tokenizer_path = get_gemini_tokenizer_path()
+            if not tokenizer_path.exists():
+                await download_gemini_tokenizer()
+
+            # Load the tokenizer
+            self.sp = spm.SentencePieceProcessor()
+            self.sp.load(str(tokenizer_path))
+            logger.debug(f"Loaded SentencePiece tokenizer for {self.model_name}")
+        except Exception as e:
+            raise RuntimeError(
+                f"Failed to load Gemini tokenizer for {self.model_name}"
+            ) from e
+
+    async def count_tokens(self, text: str) -> int:
+        """Count tokens using SentencePiece (async).
+
+        Downloads tokenizer on first call if not cached.
+
+        Args:
+            text: Text to count tokens for
+
+        Returns:
+            Exact token count using Gemini's tokenizer
+
+        Raises:
+            RuntimeError: If token counting fails
+        """
+        # Handle empty text to avoid unnecessary tokenization
+        if not text or not text.strip():
+            return 0
+
+        await self._ensure_tokenizer()
+
+        if self.sp is None:
+            raise RuntimeError(f"Tokenizer not initialized for {self.model_name}")
+
+        try:
+            tokens = self.sp.encode(text)
+            return len(tokens)
+        except Exception as e:
+            raise RuntimeError(
+                f"Failed to count tokens for Gemini model {self.model_name}"
+            ) from e
+
+    async def count_message_tokens(self, messages: list[ModelMessage]) -> int:
+        """Count tokens across all messages using SentencePiece (async).
+
+        Downloads tokenizer on first call if not cached.
+
+        Args:
+            messages: List of PydanticAI messages
+
+        Returns:
+            Total token count for all messages
+
+        Raises:
+            RuntimeError: If token counting fails
+        """
+        # Handle empty message list early
+        if not messages:
+            return 0
+
+        total_text = extract_text_from_messages(messages)
+        return await self.count_tokens(total_text)
shotgun/agents/history/token_counting/tokenizer_cache.py
ADDED
@@ -0,0 +1,92 @@
+"""Async tokenizer download and caching utilities."""
+
+import hashlib
+from pathlib import Path
+
+import aiofiles
+import httpx
+
+from shotgun.logging_config import get_logger
+from shotgun.utils.file_system_utils import get_shotgun_home
+
+logger = get_logger(__name__)
+
+# Gemini tokenizer constants
+GEMINI_TOKENIZER_URL = "https://raw.githubusercontent.com/google/gemma_pytorch/main/tokenizer/tokenizer.model"
+GEMINI_TOKENIZER_SHA256 = (
+    "61a7b147390c64585d6c3543dd6fc636906c9af3865a5548f27f31aee1d4c8e2"
+)
+
+
+def get_tokenizer_cache_dir() -> Path:
+    """Get the directory for cached tokenizer models.
+
+    Returns:
+        Path to tokenizers cache directory
+    """
+    cache_dir = get_shotgun_home() / "tokenizers"
+    cache_dir.mkdir(parents=True, exist_ok=True)
+    return cache_dir
+
+
+def get_gemini_tokenizer_path() -> Path:
+    """Get the path where the Gemini tokenizer should be cached.
+
+    Returns:
+        Path to cached Gemini tokenizer
+    """
+    return get_tokenizer_cache_dir() / "gemini_tokenizer.model"
+
+
+async def download_gemini_tokenizer() -> Path:
+    """Download and cache the official Gemini tokenizer model.
+
+    This downloads Google's official Gemini/Gemma tokenizer from the
+    gemma_pytorch repository and caches it locally for future use.
+
+    The download is async and non-blocking, with SHA256 verification
+    for security.
+
+    Returns:
+        Path to the cached tokenizer file
+
+    Raises:
+        RuntimeError: If download fails or checksum verification fails
+    """
+    cache_path = get_gemini_tokenizer_path()
+
+    # Check if already cached
+    if cache_path.exists():
+        logger.debug(f"Gemini tokenizer already cached at {cache_path}")
+        return cache_path
+
+    logger.info("Downloading Gemini tokenizer (4MB, first time only)...")
+
+    try:
+        # Download with async httpx
+        async with httpx.AsyncClient(timeout=30.0) as client:
+            response = await client.get(GEMINI_TOKENIZER_URL, follow_redirects=True)
+            response.raise_for_status()
+            content = response.content
+
+        # Verify SHA256 checksum
+        actual_hash = hashlib.sha256(content).hexdigest()
+        if actual_hash != GEMINI_TOKENIZER_SHA256:
+            raise RuntimeError(
+                f"Gemini tokenizer checksum mismatch. "
+                f"Expected: {GEMINI_TOKENIZER_SHA256}, got: {actual_hash}"
+            )
+
+        # Atomic write: write to temp file first, then rename
+        temp_path = cache_path.with_suffix(".tmp")
+        async with aiofiles.open(temp_path, "wb") as f:
+            await f.write(content)
+        temp_path.rename(cache_path)
+
+        logger.info(f"Gemini tokenizer downloaded and cached at {cache_path}")
+        return cache_path
+
+    except httpx.HTTPError as e:
+        raise RuntimeError(f"Failed to download Gemini tokenizer: {e}") from e
+    except OSError as e:
+        raise RuntimeError(f"Failed to save Gemini tokenizer: {e}") from e
shotgun/agents/history/token_counting/utils.py
ADDED
@@ -0,0 +1,144 @@
+"""Utility functions and cache for token counting."""
+
+from pydantic_ai.messages import ModelMessage
+
+from shotgun.agents.config.models import ModelConfig, ProviderType
+from shotgun.logging_config import get_logger
+
+from .anthropic import AnthropicTokenCounter
+from .base import TokenCounter
+from .openai import OpenAITokenCounter
+from .sentencepiece_counter import SentencePieceTokenCounter
+
+logger = get_logger(__name__)
+
+# Global cache for token counter instances (singleton pattern)
+_token_counter_cache: dict[tuple[str, str, str], TokenCounter] = {}
+
+
+def get_token_counter(model_config: ModelConfig) -> TokenCounter:
+    """Get appropriate token counter for the model provider (cached singleton).
+
+    This function ensures that every provider has a proper token counting
+    implementation without any fallbacks to estimation. Token counters are
+    cached to avoid repeated initialization overhead.
+
+    Args:
+        model_config: Model configuration with provider and credentials
+
+    Returns:
+        Cached provider-specific token counter
+
+    Raises:
+        ValueError: If provider is not supported for token counting
+        RuntimeError: If token counter initialization fails
+    """
+    # Create cache key from provider, model name, and API key
+    cache_key = (
+        model_config.provider.value,
+        model_config.name,
+        model_config.api_key[:10]
+        if model_config.api_key
+        else "no-key",  # Partial key for cache
+    )
+
+    # Return cached instance if available
+    if cache_key in _token_counter_cache:
+        return _token_counter_cache[cache_key]
+
+    # Create new instance and cache it
+    logger.debug(
+        f"Creating new token counter for {model_config.provider.value}:{model_config.name}"
+    )
+
+    counter: TokenCounter
+    if model_config.provider == ProviderType.OPENAI:
+        counter = OpenAITokenCounter(model_config.name)
+    elif model_config.provider == ProviderType.ANTHROPIC:
+        counter = AnthropicTokenCounter(
+            model_config.name, model_config.api_key, model_config.key_provider
+        )
+    elif model_config.provider == ProviderType.GOOGLE:
+        # Use local SentencePiece tokenizer (100% accurate, 10-100x faster than API)
+        counter = SentencePieceTokenCounter(model_config.name)
+    else:
+        raise ValueError(
+            f"Unsupported provider for token counting: {model_config.provider}. "
+            f"Supported providers: {[p.value for p in ProviderType]}"
+        )
+
+    # Cache the instance
+    _token_counter_cache[cache_key] = counter
+    logger.debug(
+        f"Cached token counter for {model_config.provider.value}:{model_config.name}"
+    )
+
+    return counter
+
+
+async def count_tokens_from_messages(
+    messages: list[ModelMessage], model_config: ModelConfig
+) -> int:
+    """Count actual tokens from messages using provider-specific methods (async).
+
+    This replaces the old estimation approach with accurate token counting
+    using each provider's official APIs and libraries.
+
+    Args:
+        messages: List of messages to count tokens for
+        model_config: Model configuration with provider info
+
+    Returns:
+        Exact token count for the messages
+
+    Raises:
+        ValueError: If provider is not supported
+        RuntimeError: If token counting fails
+    """
+    counter = get_token_counter(model_config)
+    return await counter.count_message_tokens(messages)
+
+
+async def count_post_summary_tokens(
+    messages: list[ModelMessage], summary_index: int, model_config: ModelConfig
+) -> int:
+    """Count actual tokens from summary onwards for incremental compaction decisions (async).
+
+    Args:
+        messages: Full message history
+        summary_index: Index of the last summary message
+        model_config: Model configuration with provider info
+
+    Returns:
+        Exact token count from summary onwards
+
+    Raises:
+        ValueError: If provider is not supported
+        RuntimeError: If token counting fails
+    """
+    if summary_index >= len(messages):
+        return 0
+
+    post_summary_messages = messages[summary_index:]
+    return await count_tokens_from_messages(post_summary_messages, model_config)
+
+
+async def count_tokens_from_message_parts(
+    messages: list[ModelMessage], model_config: ModelConfig
+) -> int:
+    """Count actual tokens from message parts for summarization requests (async).
+
+    Args:
+        messages: List of messages to count tokens for
+        model_config: Model configuration with provider info
+
+    Returns:
+        Exact token count from message parts
+
+    Raises:
+        ValueError: If provider is not supported
+        RuntimeError: If token counting fails
+    """
+    # For now, use the same logic as count_tokens_from_messages
+    # This can be optimized later if needed for different counting strategies
+    return await count_tokens_from_messages(messages, model_config)
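A hedged sketch of how the cached dispatch above might be called (not part of the diff). The import path is inferred from the file list, and the ModelConfig instance is assumed to come from shotgun's own configuration layer, whose constructor is not shown here.

```python
from pydantic_ai.messages import ModelMessage

from shotgun.agents.config.models import ModelConfig

# Inferred import path; the package __init__ may also re-export these helpers.
from shotgun.agents.history.token_counting.utils import get_token_counter


async def report_usage(messages: list[ModelMessage], model_config: ModelConfig) -> int:
    # get_token_counter caches one counter per (provider, model, API-key prefix),
    # so repeated calls reuse the same tiktoken/Anthropic/SentencePiece instance.
    counter = get_token_counter(model_config)
    return await counter.count_message_tokens(messages)
```

count_tokens_from_messages(messages, model_config) is the one-shot equivalent of the two lines inside report_usage.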
shotgun/agents/history/token_estimation.py
CHANGED
@@ -19,10 +19,10 @@ from .constants import INPUT_BUFFER_TOKENS, MIN_SUMMARY_TOKENS
 from .token_counting import count_tokens_from_messages as _count_tokens_from_messages


-def estimate_tokens_from_messages(
+async def estimate_tokens_from_messages(
     messages: list[ModelMessage], model_config: ModelConfig
 ) -> int:
-    """Count actual tokens from current message list.
+    """Count actual tokens from current message list (async).

     This provides accurate token counting for compaction decisions using
     provider-specific token counting methods instead of rough estimation.
@@ -38,13 +38,13 @@ def estimate_tokens_from_messages(
         ValueError: If provider is not supported
         RuntimeError: If token counting fails
     """
-    return _count_tokens_from_messages(messages, model_config)
+    return await _count_tokens_from_messages(messages, model_config)


-def estimate_post_summary_tokens(
+async def estimate_post_summary_tokens(
     messages: list[ModelMessage], summary_index: int, model_config: ModelConfig
 ) -> int:
-    """Count actual tokens from summary onwards for incremental compaction decisions.
+    """Count actual tokens from summary onwards for incremental compaction decisions (async).

     This treats the summary as a reset point and only counts tokens from the summary
     message onwards. Used to determine if incremental compaction is needed.
@@ -65,13 +65,13 @@ def estimate_post_summary_tokens(
         return 0

     post_summary_messages = messages[summary_index:]
-    return estimate_tokens_from_messages(post_summary_messages, model_config)
+    return await estimate_tokens_from_messages(post_summary_messages, model_config)


-def estimate_tokens_from_message_parts(
+async def estimate_tokens_from_message_parts(
     messages: list[ModelMessage], model_config: ModelConfig
 ) -> int:
-    """Count actual tokens from message parts for summarization requests.
+    """Count actual tokens from message parts for summarization requests (async).

     This provides accurate token counting across the codebase using
     provider-specific methods instead of character estimation.
@@ -87,14 +87,14 @@ def estimate_tokens_from_message_parts(
         ValueError: If provider is not supported
         RuntimeError: If token counting fails
     """
-    return _count_tokens_from_messages(messages, model_config)
+    return await _count_tokens_from_messages(messages, model_config)


-def calculate_max_summarization_tokens(
+async def calculate_max_summarization_tokens(
     ctx_or_model_config: Union["RunContext[AgentDeps]", ModelConfig],
     request_messages: list[ModelMessage],
 ) -> int:
-    """Calculate maximum tokens available for summarization output.
+    """Calculate maximum tokens available for summarization output (async).

     This ensures we use the model's full capacity while leaving room for input tokens.

@@ -115,7 +115,7 @@ def calculate_max_summarization_tokens(
         return MIN_SUMMARY_TOKENS

     # Count actual input tokens using shared utility
-    estimated_input_tokens = estimate_tokens_from_message_parts(
+    estimated_input_tokens = await estimate_tokens_from_message_parts(
         request_messages, model_config
     )

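Because these helpers became coroutines in 0.2.11, call sites need an await. A hedged caller-side sketch (the function name and threshold below are illustrative, not taken from the diff; the import path follows shotgun/agents/history/token_estimation.py):

```python
from pydantic_ai.messages import ModelMessage

from shotgun.agents.config.models import ModelConfig
from shotgun.agents.history.token_estimation import estimate_tokens_from_messages


async def needs_compaction(
    messages: list[ModelMessage], model_config: ModelConfig, token_limit: int
) -> bool:
    # Previously a plain synchronous call; counting is now async end to end.
    used = await estimate_tokens_from_messages(messages, model_config)
    return used > token_limit
```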
shotgun/agents/llm.py
ADDED
@@ -0,0 +1,62 @@
+"""LLM request utilities for Shotgun agents."""
+
+from typing import Any
+
+from pydantic_ai.direct import model_request
+from pydantic_ai.messages import ModelMessage, ModelResponse
+from pydantic_ai.settings import ModelSettings
+
+from shotgun.agents.config.models import ModelConfig
+from shotgun.logging_config import get_logger
+
+logger = get_logger(__name__)
+
+
+async def shotgun_model_request(
+    model_config: ModelConfig,
+    messages: list[ModelMessage],
+    model_settings: ModelSettings | None = None,
+    **kwargs: Any,
+) -> ModelResponse:
+    """Model request wrapper that uses full token capacity by default.
+
+    This wrapper ensures all LLM calls in Shotgun use the maximum available
+    token capacity of each model, improving response quality and completeness.
+    The most common issue this fixes is truncated summaries that were cut off
+    at default token limits (e.g., 4096 for Claude models).
+
+    Args:
+        model_config: ModelConfig instance with model settings and API key
+        messages: Messages to send to the model
+        model_settings: Optional ModelSettings. If None, creates default with max tokens
+        **kwargs: Additional arguments passed to model_request
+
+    Returns:
+        ModelResponse from the model
+
+    Example:
+        # Uses full token capacity (e.g., 4096 for Claude, 128k for GPT-5)
+        response = await shotgun_model_request(model_config, messages)
+
+        # With custom settings
+        response = await shotgun_model_request(model_config, messages, model_settings=ModelSettings(max_tokens=1000, temperature=0.7))
+    """
+    if kwargs.get("max_tokens") is not None:
+        logger.warning(
+            "⚠️ 'max_tokens' argument is ignored in shotgun_model_request. "
+            "Set 'model_settings.max_tokens' instead."
+        )
+
+    if not model_settings:
+        model_settings = ModelSettings()
+
+    if model_settings.get("max_tokens") is None:
+        model_settings["max_tokens"] = model_config.max_output_tokens
+
+    # Make the model request with full token utilization
+    return await model_request(
+        model=model_config.model_instance,
+        messages=messages,
+        model_settings=model_settings,
+        **kwargs,
+    )
shotgun/agents/models.py
CHANGED
@@ -4,19 +4,45 @@ import os
 from asyncio import Future, Queue
 from collections.abc import Callable
 from datetime import datetime
-from enum import
+from enum import StrEnum
 from pathlib import Path
 from typing import TYPE_CHECKING

 from pydantic import BaseModel, ConfigDict, Field
 from pydantic_ai import RunContext

+from shotgun.agents.usage_manager import SessionUsageManager, get_session_usage_manager
+
 from .config.models import ModelConfig

 if TYPE_CHECKING:
     from shotgun.codebase.service import CodebaseService


+class AgentResponse(BaseModel):
+    """Structured response from an agent with optional clarifying questions.
+
+    This model provides a consistent response format for all agents:
+    - response: The main response text (can be empty if only asking questions)
+    - clarifying_questions: Optional list of questions to ask the user
+
+    When clarifying_questions is provided, the agent expects to receive
+    answers before continuing its work. This replaces the ask_questions tool.
+    """
+
+    response: str = Field(
+        description="The agent's response text. Always respond with some text summarizing what happened, whats next, etc.",
+    )
+    clarifying_questions: list[str] | None = Field(
+        default=None,
+        description="""
+        Optional list of clarifying questions to ask the user.
+        - Single question: Shown as a non-blocking suggestion (user can answer or continue with other prompts)
+        - Multiple questions (2+): Asked sequentially in Q&A mode (blocks input until all answered or cancelled)
+        """,
+    )
+
+
 class AgentType(StrEnum):
     """Enumeration for available agent types."""

@@ -71,6 +97,30 @@ class UserQuestion(BaseModel):
     )


+class MultipleUserQuestions(BaseModel):
+    """Multiple questions to ask the user sequentially."""
+
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+
+    questions: list[str] = Field(
+        description="List of questions to ask the user",
+    )
+    current_index: int = Field(
+        default=0,
+        description="Current question index being asked",
+    )
+    answers: list[str] = Field(
+        default_factory=list,
+        description="Accumulated answers from the user",
+    )
+    tool_call_id: str = Field(
+        description="Tool call id",
+    )
+    result: Future[UserAnswer] = Field(
+        description="Future that will contain all answers formatted as Q&A pairs"
+    )
+
+
 class AgentRuntimeOptions(BaseModel):
     """User interface options for agents."""

@@ -98,9 +148,9 @@ class AgentRuntimeOptions(BaseModel):
         description="Maximum number of iterations for agent loops",
     )

-    queue: Queue[UserQuestion] = Field(
+    queue: Queue[UserQuestion | MultipleUserQuestions] = Field(
         default_factory=Queue,
-        description="Queue for storing user
+        description="Queue for storing user questions (single or multiple)",
     )

     tasks: list[Future[UserAnswer]] = Field(
@@ -108,8 +158,13 @@ class AgentRuntimeOptions(BaseModel):
         description="Tasks for storing deferred tool results",
     )

+    usage_manager: SessionUsageManager = Field(
+        default_factory=get_session_usage_manager,
+        description="Usage manager for tracking usage",
+    )
+

-class FileOperationType(
+class FileOperationType(StrEnum):
     """Types of file operations that can be tracked."""

     CREATED = "created"