shotgun-sh 0.1.14__py3-none-any.whl → 0.2.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of shotgun-sh might be problematic.
- shotgun/agents/agent_manager.py +715 -75
- shotgun/agents/common.py +80 -75
- shotgun/agents/config/constants.py +21 -10
- shotgun/agents/config/manager.py +322 -97
- shotgun/agents/config/models.py +114 -84
- shotgun/agents/config/provider.py +232 -88
- shotgun/agents/context_analyzer/__init__.py +28 -0
- shotgun/agents/context_analyzer/analyzer.py +471 -0
- shotgun/agents/context_analyzer/constants.py +9 -0
- shotgun/agents/context_analyzer/formatter.py +115 -0
- shotgun/agents/context_analyzer/models.py +212 -0
- shotgun/agents/conversation_history.py +125 -2
- shotgun/agents/conversation_manager.py +57 -19
- shotgun/agents/export.py +6 -7
- shotgun/agents/history/compaction.py +10 -5
- shotgun/agents/history/context_extraction.py +93 -6
- shotgun/agents/history/history_processors.py +129 -12
- shotgun/agents/history/token_counting/__init__.py +31 -0
- shotgun/agents/history/token_counting/anthropic.py +127 -0
- shotgun/agents/history/token_counting/base.py +78 -0
- shotgun/agents/history/token_counting/openai.py +90 -0
- shotgun/agents/history/token_counting/sentencepiece_counter.py +127 -0
- shotgun/agents/history/token_counting/tokenizer_cache.py +92 -0
- shotgun/agents/history/token_counting/utils.py +144 -0
- shotgun/agents/history/token_estimation.py +12 -12
- shotgun/agents/llm.py +62 -0
- shotgun/agents/models.py +59 -4
- shotgun/agents/plan.py +6 -7
- shotgun/agents/research.py +7 -8
- shotgun/agents/specify.py +6 -7
- shotgun/agents/tasks.py +6 -7
- shotgun/agents/tools/__init__.py +0 -2
- shotgun/agents/tools/codebase/codebase_shell.py +6 -0
- shotgun/agents/tools/codebase/directory_lister.py +6 -0
- shotgun/agents/tools/codebase/file_read.py +11 -2
- shotgun/agents/tools/codebase/query_graph.py +6 -0
- shotgun/agents/tools/codebase/retrieve_code.py +6 -0
- shotgun/agents/tools/file_management.py +82 -16
- shotgun/agents/tools/registry.py +217 -0
- shotgun/agents/tools/web_search/__init__.py +55 -16
- shotgun/agents/tools/web_search/anthropic.py +76 -51
- shotgun/agents/tools/web_search/gemini.py +50 -27
- shotgun/agents/tools/web_search/openai.py +26 -17
- shotgun/agents/tools/web_search/utils.py +2 -2
- shotgun/agents/usage_manager.py +164 -0
- shotgun/api_endpoints.py +15 -0
- shotgun/cli/clear.py +53 -0
- shotgun/cli/compact.py +186 -0
- shotgun/cli/config.py +41 -67
- shotgun/cli/context.py +111 -0
- shotgun/cli/export.py +1 -1
- shotgun/cli/feedback.py +50 -0
- shotgun/cli/models.py +3 -2
- shotgun/cli/plan.py +1 -1
- shotgun/cli/research.py +1 -1
- shotgun/cli/specify.py +1 -1
- shotgun/cli/tasks.py +1 -1
- shotgun/cli/update.py +16 -2
- shotgun/codebase/core/change_detector.py +5 -3
- shotgun/codebase/core/code_retrieval.py +4 -2
- shotgun/codebase/core/ingestor.py +57 -16
- shotgun/codebase/core/manager.py +20 -7
- shotgun/codebase/core/nl_query.py +1 -1
- shotgun/codebase/models.py +4 -4
- shotgun/exceptions.py +32 -0
- shotgun/llm_proxy/__init__.py +19 -0
- shotgun/llm_proxy/clients.py +44 -0
- shotgun/llm_proxy/constants.py +15 -0
- shotgun/logging_config.py +18 -27
- shotgun/main.py +91 -12
- shotgun/posthog_telemetry.py +81 -10
- shotgun/prompts/agents/export.j2 +18 -1
- shotgun/prompts/agents/partials/common_agent_system_prompt.j2 +5 -1
- shotgun/prompts/agents/partials/interactive_mode.j2 +24 -7
- shotgun/prompts/agents/plan.j2 +1 -1
- shotgun/prompts/agents/research.j2 +1 -1
- shotgun/prompts/agents/specify.j2 +270 -3
- shotgun/prompts/agents/state/system_state.j2 +4 -0
- shotgun/prompts/agents/tasks.j2 +1 -1
- shotgun/prompts/loader.py +2 -2
- shotgun/prompts/tools/web_search.j2 +14 -0
- shotgun/sentry_telemetry.py +27 -18
- shotgun/settings.py +238 -0
- shotgun/shotgun_web/__init__.py +19 -0
- shotgun/shotgun_web/client.py +138 -0
- shotgun/shotgun_web/constants.py +21 -0
- shotgun/shotgun_web/models.py +47 -0
- shotgun/telemetry.py +24 -36
- shotgun/tui/app.py +251 -23
- shotgun/tui/commands/__init__.py +1 -1
- shotgun/tui/components/context_indicator.py +179 -0
- shotgun/tui/components/mode_indicator.py +70 -0
- shotgun/tui/components/status_bar.py +48 -0
- shotgun/tui/containers.py +91 -0
- shotgun/tui/dependencies.py +39 -0
- shotgun/tui/protocols.py +45 -0
- shotgun/tui/screens/chat/__init__.py +5 -0
- shotgun/tui/screens/chat/chat.tcss +54 -0
- shotgun/tui/screens/chat/chat_screen.py +1234 -0
- shotgun/tui/screens/chat/codebase_index_prompt_screen.py +64 -0
- shotgun/tui/screens/chat/codebase_index_selection.py +12 -0
- shotgun/tui/screens/chat/help_text.py +40 -0
- shotgun/tui/screens/chat/prompt_history.py +48 -0
- shotgun/tui/screens/chat.tcss +11 -0
- shotgun/tui/screens/chat_screen/command_providers.py +226 -11
- shotgun/tui/screens/chat_screen/history/__init__.py +22 -0
- shotgun/tui/screens/chat_screen/history/agent_response.py +66 -0
- shotgun/tui/screens/chat_screen/history/chat_history.py +116 -0
- shotgun/tui/screens/chat_screen/history/formatters.py +115 -0
- shotgun/tui/screens/chat_screen/history/partial_response.py +43 -0
- shotgun/tui/screens/chat_screen/history/user_question.py +42 -0
- shotgun/tui/screens/confirmation_dialog.py +151 -0
- shotgun/tui/screens/feedback.py +193 -0
- shotgun/tui/screens/github_issue.py +102 -0
- shotgun/tui/screens/model_picker.py +352 -0
- shotgun/tui/screens/onboarding.py +431 -0
- shotgun/tui/screens/pipx_migration.py +153 -0
- shotgun/tui/screens/provider_config.py +156 -39
- shotgun/tui/screens/shotgun_auth.py +295 -0
- shotgun/tui/screens/welcome.py +198 -0
- shotgun/tui/services/__init__.py +5 -0
- shotgun/tui/services/conversation_service.py +184 -0
- shotgun/tui/state/__init__.py +7 -0
- shotgun/tui/state/processing_state.py +185 -0
- shotgun/tui/utils/mode_progress.py +14 -7
- shotgun/tui/widgets/__init__.py +5 -0
- shotgun/tui/widgets/widget_coordinator.py +262 -0
- shotgun/utils/datetime_utils.py +77 -0
- shotgun/utils/env_utils.py +13 -0
- shotgun/utils/file_system_utils.py +22 -2
- shotgun/utils/marketing.py +110 -0
- shotgun/utils/update_checker.py +69 -14
- shotgun_sh-0.2.11.dist-info/METADATA +130 -0
- shotgun_sh-0.2.11.dist-info/RECORD +194 -0
- {shotgun_sh-0.1.14.dist-info → shotgun_sh-0.2.11.dist-info}/entry_points.txt +1 -0
- {shotgun_sh-0.1.14.dist-info → shotgun_sh-0.2.11.dist-info}/licenses/LICENSE +1 -1
- shotgun/agents/history/token_counting.py +0 -429
- shotgun/agents/tools/user_interaction.py +0 -37
- shotgun/tui/screens/chat.py +0 -797
- shotgun/tui/screens/chat_screen/history.py +0 -350
- shotgun_sh-0.1.14.dist-info/METADATA +0 -466
- shotgun_sh-0.1.14.dist-info/RECORD +0 -133
- {shotgun_sh-0.1.14.dist-info → shotgun_sh-0.2.11.dist-info}/WHEEL +0 -0
shotgun/agents/history/token_counting.py
@@ -1,429 +0,0 @@
-"""Real token counting for all supported providers.
-
-This module provides accurate token counting using each provider's official
-APIs and libraries, eliminating the need for rough character-based estimation.
-"""
-
-from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING
-
-from pydantic_ai.messages import ModelMessage
-
-from shotgun.agents.config.models import ModelConfig, ProviderType
-from shotgun.logging_config import get_logger
-
-if TYPE_CHECKING:
-    pass
-
-logger = get_logger(__name__)
-
-# Global cache for token counter instances (singleton pattern)
-_token_counter_cache: dict[tuple[str, str, str], "TokenCounter"] = {}
-
-
-class TokenCounter(ABC):
-    """Abstract base class for provider-specific token counting."""
-
-    @abstractmethod
-    def count_tokens(self, text: str) -> int:
-        """Count tokens in text using provider-specific method.
-
-        Args:
-            text: Text to count tokens for
-
-        Returns:
-            Exact token count as determined by the provider
-
-        Raises:
-            RuntimeError: If token counting fails
-        """
-
-    @abstractmethod
-    def count_message_tokens(self, messages: list[ModelMessage]) -> int:
-        """Count tokens in PydanticAI message structures.
-
-        Args:
-            messages: List of messages to count tokens for
-
-        Returns:
-            Total token count across all messages
-
-        Raises:
-            RuntimeError: If token counting fails
-        """
-
-
-class OpenAITokenCounter(TokenCounter):
-    """Token counter for OpenAI models using tiktoken."""
-
-    # Official encoding mappings for OpenAI models
-    ENCODING_MAP = {
-        "gpt-5": "o200k_base",
-        "gpt-4o": "o200k_base",
-        "gpt-4": "cl100k_base",
-        "gpt-3.5-turbo": "cl100k_base",
-    }
-
-    def __init__(self, model_name: str):
-        """Initialize OpenAI token counter.
-
-        Args:
-            model_name: OpenAI model name to get correct encoding for
-
-        Raises:
-            RuntimeError: If encoding initialization fails
-        """
-        self.model_name = model_name
-
-        import tiktoken
-
-        try:
-            # Get the appropriate encoding for this model
-            encoding_name = self.ENCODING_MAP.get(model_name, "o200k_base")
-            self.encoding = tiktoken.get_encoding(encoding_name)
-            logger.debug(
-                f"Initialized OpenAI token counter with {encoding_name} encoding"
-            )
-        except Exception as e:
-            raise RuntimeError(
-                f"Failed to initialize tiktoken encoding for {model_name}"
-            ) from e
-
-    def count_tokens(self, text: str) -> int:
-        """Count tokens using tiktoken.
-
-        Args:
-            text: Text to count tokens for
-
-        Returns:
-            Exact token count using tiktoken
-
-        Raises:
-            RuntimeError: If token counting fails
-        """
-        try:
-            return len(self.encoding.encode(text))
-        except Exception as e:
-            raise RuntimeError(
-                f"Failed to count tokens for OpenAI model {self.model_name}"
-            ) from e
-
-    def count_message_tokens(self, messages: list[ModelMessage]) -> int:
-        """Count tokens across all messages using tiktoken.
-
-        Args:
-            messages: List of PydanticAI messages
-
-        Returns:
-            Total token count for all messages
-
-        Raises:
-            RuntimeError: If token counting fails
-        """
-        total_text = self._extract_text_from_messages(messages)
-        return self.count_tokens(total_text)
-
-    def _extract_text_from_messages(self, messages: list[ModelMessage]) -> str:
-        """Extract all text content from messages for token counting."""
-        text_parts = []
-
-        for message in messages:
-            if hasattr(message, "parts"):
-                for part in message.parts:
-                    if hasattr(part, "content") and isinstance(part.content, str):
-                        text_parts.append(part.content)
-                    else:
-                        # Handle non-text parts (tool calls, etc.)
-                        text_parts.append(str(part))
-            else:
-                # Handle messages without parts
-                text_parts.append(str(message))
-
-        return "\n".join(text_parts)
-
-
-class AnthropicTokenCounter(TokenCounter):
-    """Token counter for Anthropic models using official client."""
-
-    def __init__(self, model_name: str, api_key: str):
-        """Initialize Anthropic token counter.
-
-        Args:
-            model_name: Anthropic model name for token counting
-            api_key: Anthropic API key
-
-        Raises:
-            RuntimeError: If client initialization fails
-        """
-        self.model_name = model_name
-        import anthropic
-
-        try:
-            self.client = anthropic.Anthropic(api_key=api_key)
-            logger.debug(f"Initialized Anthropic token counter for {model_name}")
-        except Exception as e:
-            raise RuntimeError("Failed to initialize Anthropic client") from e
-
-    def count_tokens(self, text: str) -> int:
-        """Count tokens using Anthropic's official API.
-
-        Args:
-            text: Text to count tokens for
-
-        Returns:
-            Exact token count from Anthropic API
-
-        Raises:
-            RuntimeError: If API call fails
-        """
-        try:
-            # Anthropic API expects messages format and model parameter
-            result = self.client.messages.count_tokens(
-                messages=[{"role": "user", "content": text}], model=self.model_name
-            )
-            return result.input_tokens
-        except Exception as e:
-            raise RuntimeError(
-                f"Anthropic token counting API failed for {self.model_name}"
-            ) from e
-
-    def count_message_tokens(self, messages: list[ModelMessage]) -> int:
-        """Count tokens across all messages using Anthropic API.
-
-        Args:
-            messages: List of PydanticAI messages
-
-        Returns:
-            Total token count for all messages
-
-        Raises:
-            RuntimeError: If token counting fails
-        """
-        total_text = self._extract_text_from_messages(messages)
-        return self.count_tokens(total_text)
-
-    def _extract_text_from_messages(self, messages: list[ModelMessage]) -> str:
-        """Extract all text content from messages for token counting."""
-        text_parts = []
-
-        for message in messages:
-            if hasattr(message, "parts"):
-                for part in message.parts:
-                    if hasattr(part, "content") and isinstance(part.content, str):
-                        text_parts.append(part.content)
-                    else:
-                        # Handle non-text parts (tool calls, etc.)
-                        text_parts.append(str(part))
-            else:
-                # Handle messages without parts
-                text_parts.append(str(message))
-
-        return "\n".join(text_parts)
-
-
-class GoogleTokenCounter(TokenCounter):
-    """Token counter for Google models using genai API."""
-
-    def __init__(self, model_name: str, api_key: str):
-        """Initialize Google token counter.
-
-        Args:
-            model_name: Google model name
-            api_key: Google API key
-
-        Raises:
-            RuntimeError: If configuration fails
-        """
-        self.model_name = model_name
-
-        import google.generativeai as genai
-
-        try:
-            genai.configure(api_key=api_key)  # type: ignore[attr-defined]
-            self.model = genai.GenerativeModel(model_name)  # type: ignore[attr-defined]
-            logger.debug(f"Initialized Google token counter for {model_name}")
-        except Exception as e:
-            raise RuntimeError(
-                f"Failed to configure Google genai client for {model_name}"
-            ) from e
-
-    def count_tokens(self, text: str) -> int:
-        """Count tokens using Google's genai API.
-
-        Args:
-            text: Text to count tokens for
-
-        Returns:
-            Exact token count from Google API
-
-        Raises:
-            RuntimeError: If API call fails
-        """
-        try:
-            result = self.model.count_tokens(text)
-            return result.total_tokens
-        except Exception as e:
-            raise RuntimeError(
-                f"Google token counting API failed for {self.model_name}"
-            ) from e
-
-    def count_message_tokens(self, messages: list[ModelMessage]) -> int:
-        """Count tokens across all messages using Google API.
-
-        Args:
-            messages: List of PydanticAI messages
-
-        Returns:
-            Total token count for all messages
-
-        Raises:
-            RuntimeError: If token counting fails
-        """
-        total_text = self._extract_text_from_messages(messages)
-        return self.count_tokens(total_text)
-
-    def _extract_text_from_messages(self, messages: list[ModelMessage]) -> str:
-        """Extract all text content from messages for token counting."""
-        text_parts = []
-
-        for message in messages:
-            if hasattr(message, "parts"):
-                for part in message.parts:
-                    if hasattr(part, "content") and isinstance(part.content, str):
-                        text_parts.append(part.content)
-                    else:
-                        # Handle non-text parts (tool calls, etc.)
-                        text_parts.append(str(part))
-            else:
-                # Handle messages without parts
-                text_parts.append(str(message))
-
-        return "\n".join(text_parts)
-
-
-def get_token_counter(model_config: ModelConfig) -> TokenCounter:
-    """Get appropriate token counter for the model provider (cached singleton).
-
-    This function ensures that every provider has a proper token counting
-    implementation without any fallbacks to estimation. Token counters are
-    cached to avoid repeated initialization overhead.
-
-    Args:
-        model_config: Model configuration with provider and credentials
-
-    Returns:
-        Cached provider-specific token counter
-
-    Raises:
-        ValueError: If provider is not supported for token counting
-        RuntimeError: If token counter initialization fails
-    """
-    # Create cache key from provider, model name, and API key
-    cache_key = (
-        model_config.provider.value,
-        model_config.name,
-        model_config.api_key[:10]
-        if model_config.api_key
-        else "no-key",  # Partial key for cache
-    )
-
-    # Return cached instance if available
-    if cache_key in _token_counter_cache:
-        logger.debug(
-            f"Reusing cached token counter for {model_config.provider.value}:{model_config.name}"
-        )
-        return _token_counter_cache[cache_key]
-
-    # Create new instance and cache it
-    logger.debug(
-        f"Creating new token counter for {model_config.provider.value}:{model_config.name}"
-    )
-
-    counter: TokenCounter
-    if model_config.provider == ProviderType.OPENAI:
-        counter = OpenAITokenCounter(model_config.name)
-    elif model_config.provider == ProviderType.ANTHROPIC:
-        counter = AnthropicTokenCounter(model_config.name, model_config.api_key)
-    elif model_config.provider == ProviderType.GOOGLE:
-        counter = GoogleTokenCounter(model_config.name, model_config.api_key)
-    else:
-        raise ValueError(
-            f"Unsupported provider for token counting: {model_config.provider}. "
-            f"Supported providers: {[p.value for p in ProviderType]}"
-        )
-
-    # Cache the instance
-    _token_counter_cache[cache_key] = counter
-    logger.debug(
-        f"Cached token counter for {model_config.provider.value}:{model_config.name}"
-    )
-
-    return counter
-
-
-def count_tokens_from_messages(
-    messages: list[ModelMessage], model_config: ModelConfig
-) -> int:
-    """Count actual tokens from messages using provider-specific methods.
-
-    This replaces the old estimation approach with accurate token counting
-    using each provider's official APIs and libraries.
-
-    Args:
-        messages: List of messages to count tokens for
-        model_config: Model configuration with provider info
-
-    Returns:
-        Exact token count for the messages
-
-    Raises:
-        ValueError: If provider is not supported
-        RuntimeError: If token counting fails
-    """
-    counter = get_token_counter(model_config)
-    return counter.count_message_tokens(messages)
-
-
-def count_post_summary_tokens(
-    messages: list[ModelMessage], summary_index: int, model_config: ModelConfig
-) -> int:
-    """Count actual tokens from summary onwards for incremental compaction decisions.
-
-    Args:
-        messages: Full message history
-        summary_index: Index of the last summary message
-        model_config: Model configuration with provider info
-
-    Returns:
-        Exact token count from summary onwards
-
-    Raises:
-        ValueError: If provider is not supported
-        RuntimeError: If token counting fails
-    """
-    if summary_index >= len(messages):
-        return 0
-
-    post_summary_messages = messages[summary_index:]
-    return count_tokens_from_messages(post_summary_messages, model_config)
-
-
-def count_tokens_from_message_parts(
-    messages: list[ModelMessage], model_config: ModelConfig
-) -> int:
-    """Count actual tokens from message parts for summarization requests.
-
-    Args:
-        messages: List of messages to count tokens for
-        model_config: Model configuration with provider info
-
-    Returns:
-        Exact token count from message parts
-
-    Raises:
-        ValueError: If provider is not supported
-        RuntimeError: If token counting fails
-    """
-    # For now, use the same logic as count_tokens_from_messages
-    # This can be optimized later if needed for different counting strategies
-    return count_tokens_from_messages(messages, model_config)
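
For reference, the rest of the codebase consumed this removed module through count_tokens_from_messages. Below is a minimal sketch of that call path, assuming ModelConfig can be constructed directly with the provider, name, and api_key fields the module reads; the config values are hypothetical, and the OpenAI path counts locally via tiktoken, so no API key is needed:

from pydantic_ai.messages import ModelMessage, ModelRequest, UserPromptPart

from shotgun.agents.config.models import ModelConfig, ProviderType
from shotgun.agents.history.token_counting import count_tokens_from_messages

# Hypothetical config; the module only reads .provider, .name, and .api_key.
config = ModelConfig(provider=ProviderType.OPENAI, name="gpt-4o", api_key=None)

# One user message wrapped in PydanticAI's request/part structure.
messages: list[ModelMessage] = [
    ModelRequest(parts=[UserPromptPart(content="Summarize the repo layout.")])
]

# Delegates to the cached OpenAITokenCounter, which encodes with tiktoken.
print(count_tokens_from_messages(messages, config))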
shotgun/agents/tools/user_interaction.py
@@ -1,37 +0,0 @@
-"""User interaction tools for Pydantic AI agents."""
-
-from asyncio import get_running_loop
-
-from pydantic_ai import CallDeferred, RunContext
-
-from shotgun.agents.models import AgentDeps, UserQuestion
-from shotgun.logging_config import get_logger
-
-logger = get_logger(__name__)
-
-
-async def ask_user(ctx: RunContext[AgentDeps], question: str) -> str:
-    """Ask the human a question and return the answer.
-
-
-    Args:
-        question: The question to ask the user with a clear CTA at the end. Needs to be readable, clear, and easy to understand. Use Markdown formatting. Make key phrases and words stand out.
-
-    Returns:
-        The user's response as a string
-    """
-    tool_call_id = ctx.tool_call_id
-    assert tool_call_id is not None  # noqa: S101
-
-    try:
-        logger.debug("\nš %s\n", question)
-        future = get_running_loop().create_future()
-        await ctx.deps.queue.put(
-            UserQuestion(question=question, tool_call_id=tool_call_id, result=future)
-        )
-        ctx.deps.tasks.append(future)
-        raise CallDeferred(question)
-
-    except (EOFError, KeyboardInterrupt):
-        logger.warning("User input interrupted or unavailable")
-        return "User input not available or interrupted"