shotgun-sh 0.1.16.dev2__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of shotgun-sh might be problematic. Click here for more details.

Files changed (55) hide show
  1. shotgun/agents/common.py +4 -5
  2. shotgun/agents/config/constants.py +23 -6
  3. shotgun/agents/config/manager.py +239 -76
  4. shotgun/agents/config/models.py +74 -84
  5. shotgun/agents/config/provider.py +174 -85
  6. shotgun/agents/history/compaction.py +1 -1
  7. shotgun/agents/history/history_processors.py +18 -9
  8. shotgun/agents/history/token_counting/__init__.py +31 -0
  9. shotgun/agents/history/token_counting/anthropic.py +89 -0
  10. shotgun/agents/history/token_counting/base.py +67 -0
  11. shotgun/agents/history/token_counting/openai.py +80 -0
  12. shotgun/agents/history/token_counting/sentencepiece_counter.py +119 -0
  13. shotgun/agents/history/token_counting/tokenizer_cache.py +90 -0
  14. shotgun/agents/history/token_counting/utils.py +147 -0
  15. shotgun/agents/history/token_estimation.py +12 -12
  16. shotgun/agents/llm.py +62 -0
  17. shotgun/agents/models.py +2 -2
  18. shotgun/agents/tools/web_search/__init__.py +42 -15
  19. shotgun/agents/tools/web_search/anthropic.py +54 -50
  20. shotgun/agents/tools/web_search/gemini.py +31 -20
  21. shotgun/agents/tools/web_search/openai.py +4 -4
  22. shotgun/build_constants.py +2 -2
  23. shotgun/cli/config.py +34 -63
  24. shotgun/cli/feedback.py +4 -2
  25. shotgun/cli/models.py +2 -2
  26. shotgun/codebase/core/ingestor.py +47 -8
  27. shotgun/codebase/core/manager.py +7 -3
  28. shotgun/codebase/models.py +4 -4
  29. shotgun/llm_proxy/__init__.py +16 -0
  30. shotgun/llm_proxy/clients.py +39 -0
  31. shotgun/llm_proxy/constants.py +8 -0
  32. shotgun/main.py +6 -0
  33. shotgun/posthog_telemetry.py +15 -11
  34. shotgun/sentry_telemetry.py +3 -3
  35. shotgun/shotgun_web/__init__.py +19 -0
  36. shotgun/shotgun_web/client.py +138 -0
  37. shotgun/shotgun_web/constants.py +17 -0
  38. shotgun/shotgun_web/models.py +47 -0
  39. shotgun/telemetry.py +7 -4
  40. shotgun/tui/app.py +26 -8
  41. shotgun/tui/screens/chat.py +2 -8
  42. shotgun/tui/screens/chat_screen/command_providers.py +118 -11
  43. shotgun/tui/screens/chat_screen/history.py +3 -1
  44. shotgun/tui/screens/feedback.py +2 -2
  45. shotgun/tui/screens/model_picker.py +327 -0
  46. shotgun/tui/screens/provider_config.py +118 -28
  47. shotgun/tui/screens/shotgun_auth.py +295 -0
  48. shotgun/tui/screens/welcome.py +176 -0
  49. shotgun/utils/env_utils.py +12 -0
  50. {shotgun_sh-0.1.16.dev2.dist-info → shotgun_sh-0.2.1.dist-info}/METADATA +2 -2
  51. {shotgun_sh-0.1.16.dev2.dist-info → shotgun_sh-0.2.1.dist-info}/RECORD +54 -37
  52. shotgun/agents/history/token_counting.py +0 -429
  53. {shotgun_sh-0.1.16.dev2.dist-info → shotgun_sh-0.2.1.dist-info}/WHEEL +0 -0
  54. {shotgun_sh-0.1.16.dev2.dist-info → shotgun_sh-0.2.1.dist-info}/entry_points.txt +0 -0
  55. {shotgun_sh-0.1.16.dev2.dist-info → shotgun_sh-0.2.1.dist-info}/licenses/LICENSE +0 -0
@@ -2,6 +2,7 @@
2
2
 
3
3
  from typing import TYPE_CHECKING, Any, Protocol
4
4
 
5
+ from pydantic_ai import ModelSettings
5
6
  from pydantic_ai.messages import (
6
7
  ModelMessage,
7
8
  ModelRequest,
@@ -10,7 +11,7 @@ from pydantic_ai.messages import (
10
11
  UserPromptPart,
11
12
  )
12
13
 
13
- from shotgun.agents.config.models import shotgun_model_request
14
+ from shotgun.agents.llm import shotgun_model_request
14
15
  from shotgun.agents.messages import AgentSystemPrompt, SystemStatusPrompt
15
16
  from shotgun.agents.models import AgentDeps
16
17
  from shotgun.logging_config import get_logger
@@ -154,7 +155,7 @@ async def token_limit_compactor(
154
155
 
155
156
  if last_summary_index is not None:
156
157
  # Check if post-summary conversation exceeds threshold for incremental compaction
157
- post_summary_tokens = estimate_post_summary_tokens(
158
+ post_summary_tokens = await estimate_post_summary_tokens(
158
159
  messages, last_summary_index, deps.llm_model
159
160
  )
160
161
  post_summary_percentage = (
@@ -248,7 +249,7 @@ async def token_limit_compactor(
248
249
  ]
249
250
 
250
251
  # Calculate optimal max_tokens for summarization
251
- max_tokens = calculate_max_summarization_tokens(
252
+ max_tokens = await calculate_max_summarization_tokens(
252
253
  deps.llm_model, request_messages
253
254
  )
254
255
 
@@ -261,7 +262,9 @@ async def token_limit_compactor(
261
262
  summary_response = await shotgun_model_request(
262
263
  model_config=deps.llm_model,
263
264
  messages=request_messages,
264
- max_tokens=max_tokens, # Use calculated optimal tokens for summarization
265
+ model_settings=ModelSettings(
266
+ max_tokens=max_tokens # Use calculated optimal tokens for summarization
267
+ ),
265
268
  )
266
269
 
267
270
  log_summarization_response(summary_response, "INCREMENTAL")
@@ -328,7 +331,9 @@ async def token_limit_compactor(
328
331
 
329
332
  # Track compaction completion
330
333
  messages_after = len(compacted_messages)
331
- tokens_after = estimate_tokens_from_messages(compacted_messages, deps.llm_model)
334
+ tokens_after = await estimate_tokens_from_messages(
335
+ compacted_messages, deps.llm_model
336
+ )
332
337
  reduction_percentage = (
333
338
  ((messages_before - messages_after) / messages_before * 100)
334
339
  if messages_before > 0
@@ -354,7 +359,7 @@ async def token_limit_compactor(
354
359
 
355
360
  else:
356
361
  # Check if total conversation exceeds threshold for full compaction
357
- total_tokens = estimate_tokens_from_messages(messages, deps.llm_model)
362
+ total_tokens = await estimate_tokens_from_messages(messages, deps.llm_model)
358
363
  total_percentage = (total_tokens / max_tokens) * 100 if max_tokens > 0 else 0
359
364
 
360
365
  logger.debug(
@@ -392,7 +397,9 @@ async def _full_compaction(
392
397
  ]
393
398
 
394
399
  # Calculate optimal max_tokens for summarization
395
- max_tokens = calculate_max_summarization_tokens(deps.llm_model, request_messages)
400
+ max_tokens = await calculate_max_summarization_tokens(
401
+ deps.llm_model, request_messages
402
+ )
396
403
 
397
404
  # Debug logging using shared utilities
398
405
  log_summarization_request(
@@ -403,11 +410,13 @@ async def _full_compaction(
403
410
  summary_response = await shotgun_model_request(
404
411
  model_config=deps.llm_model,
405
412
  messages=request_messages,
406
- max_tokens=max_tokens, # Use calculated optimal tokens for summarization
413
+ model_settings=ModelSettings(
414
+ max_tokens=max_tokens # Use calculated optimal tokens for summarization
415
+ ),
407
416
  )
408
417
 
409
418
  # Calculate token reduction
410
- current_tokens = estimate_tokens_from_messages(messages, deps.llm_model)
419
+ current_tokens = await estimate_tokens_from_messages(messages, deps.llm_model)
411
420
  summary_usage = summary_response.usage
412
421
  reduction_percentage = (
413
422
  ((current_tokens - summary_usage.output_tokens) / current_tokens) * 100
@@ -0,0 +1,31 @@
1
+ """Real token counting for all supported providers.
2
+
3
+ This module provides accurate token counting using each provider's official
4
+ APIs and libraries, eliminating the need for rough character-based estimation.
5
+ """
6
+
7
+ from .anthropic import AnthropicTokenCounter
8
+ from .base import TokenCounter, extract_text_from_messages
9
+ from .openai import OpenAITokenCounter
10
+ from .sentencepiece_counter import SentencePieceTokenCounter
11
+ from .utils import (
12
+ count_post_summary_tokens,
13
+ count_tokens_from_message_parts,
14
+ count_tokens_from_messages,
15
+ get_token_counter,
16
+ )
17
+
18
+ __all__ = [
19
+ # Base classes
20
+ "TokenCounter",
21
+ # Counter implementations
22
+ "OpenAITokenCounter",
23
+ "AnthropicTokenCounter",
24
+ "SentencePieceTokenCounter",
25
+ # Utility functions
26
+ "get_token_counter",
27
+ "count_tokens_from_messages",
28
+ "count_post_summary_tokens",
29
+ "count_tokens_from_message_parts",
30
+ "extract_text_from_messages",
31
+ ]
@@ -0,0 +1,89 @@
1
+ """Anthropic token counting using official client."""
2
+
3
+ from pydantic_ai.messages import ModelMessage
4
+
5
+ from shotgun.agents.config.models import KeyProvider
6
+ from shotgun.llm_proxy import create_anthropic_proxy_client
7
+ from shotgun.logging_config import get_logger
8
+
9
+ from .base import TokenCounter, extract_text_from_messages
10
+
11
+ logger = get_logger(__name__)
12
+
13
+
14
+ class AnthropicTokenCounter(TokenCounter):
15
+ """Token counter for Anthropic models using official client."""
16
+
17
+ def __init__(
18
+ self,
19
+ model_name: str,
20
+ api_key: str,
21
+ key_provider: KeyProvider = KeyProvider.BYOK,
22
+ ):
23
+ """Initialize Anthropic token counter.
24
+
25
+ Args:
26
+ model_name: Anthropic model name for token counting
27
+ api_key: API key (Anthropic for BYOK, Shotgun for proxy)
28
+ key_provider: Key provider type (BYOK or SHOTGUN)
29
+
30
+ Raises:
31
+ RuntimeError: If client initialization fails
32
+ """
33
+ self.model_name = model_name
34
+ import anthropic
35
+
36
+ try:
37
+ if key_provider == KeyProvider.SHOTGUN:
38
+ # Use LiteLLM proxy for Shotgun Account
39
+ # Proxies to Anthropic's token counting API
40
+ self.client = create_anthropic_proxy_client(api_key)
41
+ logger.debug(
42
+ f"Initialized Anthropic token counter for {model_name} via LiteLLM proxy"
43
+ )
44
+ else:
45
+ # Direct Anthropic API for BYOK
46
+ self.client = anthropic.Anthropic(api_key=api_key)
47
+ logger.debug(
48
+ f"Initialized Anthropic token counter for {model_name} via direct API"
49
+ )
50
+ except Exception as e:
51
+ raise RuntimeError("Failed to initialize Anthropic client") from e
52
+
53
+ async def count_tokens(self, text: str) -> int:
54
+ """Count tokens using Anthropic's official API (async).
55
+
56
+ Args:
57
+ text: Text to count tokens for
58
+
59
+ Returns:
60
+ Exact token count from Anthropic API
61
+
62
+ Raises:
63
+ RuntimeError: If API call fails
64
+ """
65
+ try:
66
+ # Anthropic API expects messages format and model parameter
67
+ result = self.client.messages.count_tokens(
68
+ messages=[{"role": "user", "content": text}], model=self.model_name
69
+ )
70
+ return result.input_tokens
71
+ except Exception as e:
72
+ raise RuntimeError(
73
+ f"Anthropic token counting API failed for {self.model_name}"
74
+ ) from e
75
+
76
+ async def count_message_tokens(self, messages: list[ModelMessage]) -> int:
77
+ """Count tokens across all messages using Anthropic API (async).
78
+
79
+ Args:
80
+ messages: List of PydanticAI messages
81
+
82
+ Returns:
83
+ Total token count for all messages
84
+
85
+ Raises:
86
+ RuntimeError: If token counting fails
87
+ """
88
+ total_text = extract_text_from_messages(messages)
89
+ return await self.count_tokens(total_text)
@@ -0,0 +1,67 @@
1
+ """Base classes and shared utilities for token counting."""
2
+
3
+ from abc import ABC, abstractmethod
4
+
5
+ from pydantic_ai.messages import ModelMessage
6
+
7
+
8
+ class TokenCounter(ABC):
9
+ """Abstract base class for provider-specific token counting.
10
+
11
+ All methods are async to support non-blocking operations like
12
+ downloading tokenizer models or making API calls.
13
+ """
14
+
15
+ @abstractmethod
16
+ async def count_tokens(self, text: str) -> int:
17
+ """Count tokens in text using provider-specific method (async).
18
+
19
+ Args:
20
+ text: Text to count tokens for
21
+
22
+ Returns:
23
+ Exact token count as determined by the provider
24
+
25
+ Raises:
26
+ RuntimeError: If token counting fails
27
+ """
28
+
29
+ @abstractmethod
30
+ async def count_message_tokens(self, messages: list[ModelMessage]) -> int:
31
+ """Count tokens in PydanticAI message structures (async).
32
+
33
+ Args:
34
+ messages: List of messages to count tokens for
35
+
36
+ Returns:
37
+ Total token count across all messages
38
+
39
+ Raises:
40
+ RuntimeError: If token counting fails
41
+ """
42
+
43
+
44
+ def extract_text_from_messages(messages: list[ModelMessage]) -> str:
45
+ """Extract all text content from messages for token counting.
46
+
47
+ Args:
48
+ messages: List of PydanticAI messages
49
+
50
+ Returns:
51
+ Combined text content from all messages
52
+ """
53
+ text_parts = []
54
+
55
+ for message in messages:
56
+ if hasattr(message, "parts"):
57
+ for part in message.parts:
58
+ if hasattr(part, "content") and isinstance(part.content, str):
59
+ text_parts.append(part.content)
60
+ else:
61
+ # Handle non-text parts (tool calls, etc.)
62
+ text_parts.append(str(part))
63
+ else:
64
+ # Handle messages without parts
65
+ text_parts.append(str(message))
66
+
67
+ return "\n".join(text_parts)
@@ -0,0 +1,80 @@
1
+ """OpenAI token counting using tiktoken."""
2
+
3
+ from pydantic_ai.messages import ModelMessage
4
+
5
+ from shotgun.logging_config import get_logger
6
+
7
+ from .base import TokenCounter, extract_text_from_messages
8
+
9
+ logger = get_logger(__name__)
10
+
11
+
12
+ class OpenAITokenCounter(TokenCounter):
13
+ """Token counter for OpenAI models using tiktoken."""
14
+
15
+ # Official encoding mappings for OpenAI models
16
+ ENCODING_MAP = {
17
+ "gpt-5": "o200k_base",
18
+ "gpt-4o": "o200k_base",
19
+ "gpt-4": "cl100k_base",
20
+ "gpt-3.5-turbo": "cl100k_base",
21
+ }
22
+
23
+ def __init__(self, model_name: str):
24
+ """Initialize OpenAI token counter.
25
+
26
+ Args:
27
+ model_name: OpenAI model name to get correct encoding for
28
+
29
+ Raises:
30
+ RuntimeError: If encoding initialization fails
31
+ """
32
+ self.model_name = model_name
33
+
34
+ import tiktoken
35
+
36
+ try:
37
+ # Get the appropriate encoding for this model
38
+ encoding_name = self.ENCODING_MAP.get(model_name, "o200k_base")
39
+ self.encoding = tiktoken.get_encoding(encoding_name)
40
+ logger.debug(
41
+ f"Initialized OpenAI token counter with {encoding_name} encoding"
42
+ )
43
+ except Exception as e:
44
+ raise RuntimeError(
45
+ f"Failed to initialize tiktoken encoding for {model_name}"
46
+ ) from e
47
+
48
+ async def count_tokens(self, text: str) -> int:
49
+ """Count tokens using tiktoken (async).
50
+
51
+ Args:
52
+ text: Text to count tokens for
53
+
54
+ Returns:
55
+ Exact token count using tiktoken
56
+
57
+ Raises:
58
+ RuntimeError: If token counting fails
59
+ """
60
+ try:
61
+ return len(self.encoding.encode(text))
62
+ except Exception as e:
63
+ raise RuntimeError(
64
+ f"Failed to count tokens for OpenAI model {self.model_name}"
65
+ ) from e
66
+
67
+ async def count_message_tokens(self, messages: list[ModelMessage]) -> int:
68
+ """Count tokens across all messages using tiktoken (async).
69
+
70
+ Args:
71
+ messages: List of PydanticAI messages
72
+
73
+ Returns:
74
+ Total token count for all messages
75
+
76
+ Raises:
77
+ RuntimeError: If token counting fails
78
+ """
79
+ total_text = extract_text_from_messages(messages)
80
+ return await self.count_tokens(total_text)
@@ -0,0 +1,119 @@
1
+ """Gemini token counting using official SentencePiece tokenizer.
2
+
3
+ This implementation uses Google's official Gemini/Gemma tokenizer model
4
+ for 100% accurate local token counting without API calls.
5
+
6
+ Performance: 10-100x faster than API-based counting.
7
+ Accuracy: 100% match with actual Gemini API usage.
8
+
9
+ The tokenizer is downloaded on first use and cached locally for future use.
10
+ """
11
+
12
+ from typing import Any
13
+
14
+ from pydantic_ai.messages import ModelMessage
15
+
16
+ from shotgun.logging_config import get_logger
17
+
18
+ from .base import TokenCounter, extract_text_from_messages
19
+ from .tokenizer_cache import download_gemini_tokenizer, get_gemini_tokenizer_path
20
+
21
+ logger = get_logger(__name__)
22
+
23
+
24
+ class SentencePieceTokenCounter(TokenCounter):
25
+ """Token counter for Gemini models using official SentencePiece tokenizer.
26
+
27
+ This counter provides 100% accurate token counting for Gemini models
28
+ using the official tokenizer model from Google's gemma_pytorch repository.
29
+ Token counting is performed locally without any API calls, resulting in
30
+ 10-100x performance improvement over API-based methods.
31
+
32
+ The tokenizer is downloaded asynchronously on first use and cached locally.
33
+ """
34
+
35
+ def __init__(self, model_name: str):
36
+ """Initialize Gemini SentencePiece token counter.
37
+
38
+ The tokenizer is not loaded immediately - it will be downloaded and
39
+ loaded lazily on first use.
40
+
41
+ Args:
42
+ model_name: Gemini model name (used for logging)
43
+ """
44
+ self.model_name = model_name
45
+ self.sp: Any | None = None # SentencePieceProcessor, loaded lazily
46
+
47
+ async def _ensure_tokenizer(self) -> None:
48
+ """Ensure tokenizer is downloaded and loaded.
49
+
50
+ This method downloads the tokenizer on first call (if not cached)
51
+ and loads it into memory. Subsequent calls reuse the loaded tokenizer.
52
+
53
+ Raises:
54
+ RuntimeError: If tokenizer download or loading fails
55
+ """
56
+ if self.sp is not None:
57
+ # Already loaded
58
+ return
59
+
60
+ import sentencepiece as spm # type: ignore[import-untyped]
61
+
62
+ try:
63
+ # Check if already cached, otherwise download
64
+ tokenizer_path = get_gemini_tokenizer_path()
65
+ if not tokenizer_path.exists():
66
+ await download_gemini_tokenizer()
67
+
68
+ # Load the tokenizer
69
+ self.sp = spm.SentencePieceProcessor()
70
+ self.sp.load(str(tokenizer_path))
71
+ logger.debug(f"Loaded SentencePiece tokenizer for {self.model_name}")
72
+ except Exception as e:
73
+ raise RuntimeError(
74
+ f"Failed to load Gemini tokenizer for {self.model_name}"
75
+ ) from e
76
+
77
+ async def count_tokens(self, text: str) -> int:
78
+ """Count tokens using SentencePiece (async).
79
+
80
+ Downloads tokenizer on first call if not cached.
81
+
82
+ Args:
83
+ text: Text to count tokens for
84
+
85
+ Returns:
86
+ Exact token count using Gemini's tokenizer
87
+
88
+ Raises:
89
+ RuntimeError: If token counting fails
90
+ """
91
+ await self._ensure_tokenizer()
92
+
93
+ if self.sp is None:
94
+ raise RuntimeError(f"Tokenizer not initialized for {self.model_name}")
95
+
96
+ try:
97
+ tokens = self.sp.encode(text)
98
+ return len(tokens)
99
+ except Exception as e:
100
+ raise RuntimeError(
101
+ f"Failed to count tokens for Gemini model {self.model_name}"
102
+ ) from e
103
+
104
+ async def count_message_tokens(self, messages: list[ModelMessage]) -> int:
105
+ """Count tokens across all messages using SentencePiece (async).
106
+
107
+ Downloads tokenizer on first call if not cached.
108
+
109
+ Args:
110
+ messages: List of PydanticAI messages
111
+
112
+ Returns:
113
+ Total token count for all messages
114
+
115
+ Raises:
116
+ RuntimeError: If token counting fails
117
+ """
118
+ total_text = extract_text_from_messages(messages)
119
+ return await self.count_tokens(total_text)
@@ -0,0 +1,90 @@
1
+ """Async tokenizer download and caching utilities."""
2
+
3
+ import hashlib
4
+ from pathlib import Path
5
+
6
+ import httpx
7
+
8
+ from shotgun.logging_config import get_logger
9
+ from shotgun.utils.file_system_utils import get_shotgun_home
10
+
11
+ logger = get_logger(__name__)
12
+
13
+ # Gemini tokenizer constants
14
+ GEMINI_TOKENIZER_URL = "https://raw.githubusercontent.com/google/gemma_pytorch/main/tokenizer/tokenizer.model"
15
+ GEMINI_TOKENIZER_SHA256 = (
16
+ "61a7b147390c64585d6c3543dd6fc636906c9af3865a5548f27f31aee1d4c8e2"
17
+ )
18
+
19
+
20
+ def get_tokenizer_cache_dir() -> Path:
21
+ """Get the directory for cached tokenizer models.
22
+
23
+ Returns:
24
+ Path to tokenizers cache directory
25
+ """
26
+ cache_dir = get_shotgun_home() / "tokenizers"
27
+ cache_dir.mkdir(parents=True, exist_ok=True)
28
+ return cache_dir
29
+
30
+
31
+ def get_gemini_tokenizer_path() -> Path:
32
+ """Get the path where the Gemini tokenizer should be cached.
33
+
34
+ Returns:
35
+ Path to cached Gemini tokenizer
36
+ """
37
+ return get_tokenizer_cache_dir() / "gemini_tokenizer.model"
38
+
39
+
40
+ async def download_gemini_tokenizer() -> Path:
41
+ """Download and cache the official Gemini tokenizer model.
42
+
43
+ This downloads Google's official Gemini/Gemma tokenizer from the
44
+ gemma_pytorch repository and caches it locally for future use.
45
+
46
+ The download is async and non-blocking, with SHA256 verification
47
+ for security.
48
+
49
+ Returns:
50
+ Path to the cached tokenizer file
51
+
52
+ Raises:
53
+ RuntimeError: If download fails or checksum verification fails
54
+ """
55
+ cache_path = get_gemini_tokenizer_path()
56
+
57
+ # Check if already cached
58
+ if cache_path.exists():
59
+ logger.debug(f"Gemini tokenizer already cached at {cache_path}")
60
+ return cache_path
61
+
62
+ logger.info("Downloading Gemini tokenizer (4MB, first time only)...")
63
+
64
+ try:
65
+ # Download with async httpx
66
+ async with httpx.AsyncClient(timeout=30.0) as client:
67
+ response = await client.get(GEMINI_TOKENIZER_URL, follow_redirects=True)
68
+ response.raise_for_status()
69
+ content = response.content
70
+
71
+ # Verify SHA256 checksum
72
+ actual_hash = hashlib.sha256(content).hexdigest()
73
+ if actual_hash != GEMINI_TOKENIZER_SHA256:
74
+ raise RuntimeError(
75
+ f"Gemini tokenizer checksum mismatch. "
76
+ f"Expected: {GEMINI_TOKENIZER_SHA256}, got: {actual_hash}"
77
+ )
78
+
79
+ # Atomic write: write to temp file first, then rename
80
+ temp_path = cache_path.with_suffix(".tmp")
81
+ temp_path.write_bytes(content)
82
+ temp_path.rename(cache_path)
83
+
84
+ logger.info(f"Gemini tokenizer downloaded and cached at {cache_path}")
85
+ return cache_path
86
+
87
+ except httpx.HTTPError as e:
88
+ raise RuntimeError(f"Failed to download Gemini tokenizer: {e}") from e
89
+ except OSError as e:
90
+ raise RuntimeError(f"Failed to save Gemini tokenizer: {e}") from e