tokenator 0.1.15__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff shows the changes between publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
tokenator/__init__.py CHANGED
@@ -3,11 +3,18 @@
  import logging
  from .openai.client_openai import tokenator_openai
  from .anthropic.client_anthropic import tokenator_anthropic
+ from .gemini.client_gemini import tokenator_gemini
  from . import usage
  from .utils import get_default_db_path
  from .usage import TokenUsageService

  usage = TokenUsageService()  # noqa: F811
- __all__ = ["tokenator_openai", "tokenator_anthropic", "usage", "get_default_db_path"]
+ __all__ = [
+     "tokenator_openai",
+     "tokenator_anthropic",
+     "tokenator_gemini",
+     "usage",
+     "get_default_db_path",
+ ]

  logger = logging.getLogger(__name__)
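With this change, tokenator_gemini is importable from the package root, alongside the existing OpenAI and Anthropic wrappers:

    from tokenator import tokenator_gemini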
tokenator/base_wrapper.py CHANGED
@@ -112,7 +112,10 @@ class BaseWrapper:
          try:
              self._log_usage_impl(token_usage_stats, session, execution_id)
              session.commit()
-             logger.debug("Successfully committed token usage for execution_id: %s", execution_id)
+             logger.debug(
+                 "Successfully committed token usage for execution_id: %s",
+                 execution_id,
+             )
          except Exception as e:
              logger.error("Failed to log token usage: %s", str(e))
              session.rollback()
tokenator/gemini/__init__.py ADDED
@@ -0,0 +1,5 @@
+ """Gemini client wrapper with token usage tracking."""
+
+ from .client_gemini import tokenator_gemini
+
+ __all__ = ["tokenator_gemini"]
tokenator/gemini/client_gemini.py ADDED
@@ -0,0 +1,230 @@
+ """Gemini client wrapper with token usage tracking."""
+
+ from typing import Any, Optional, Iterator, AsyncIterator
+ import logging
+
+ from google import genai
+ from google.genai.types import GenerateContentResponse
+
+ from ..models import (
+     TokenMetrics,
+     TokenUsageStats,
+ )
+ from ..base_wrapper import BaseWrapper, ResponseType
+ from .stream_interceptors import (
+     GeminiAsyncStreamInterceptor,
+     GeminiSyncStreamInterceptor,
+ )
+ from ..state import is_tokenator_enabled
+
+ logger = logging.getLogger(__name__)
+
+
+ def _create_usage_callback(execution_id, log_usage_fn):
+     """Create a callback that logs usage statistics gathered from stream chunks."""
+
+     def usage_callback(chunks):
+         if not chunks:
+             return
+
+         # Skip if tokenator is disabled
+         if not is_tokenator_enabled:
+             logger.debug("Tokenator is disabled - skipping stream usage logging")
+             return
+
+         logger.debug("Processing stream usage for execution_id: %s", execution_id)
+
+         # Build usage_data from the first chunk's model
+         usage_data = TokenUsageStats(
+             model=chunks[0].model_version,
+             usage=TokenMetrics(),
+         )
+
+         # Only take usage from the last chunk, as it contains complete usage info
+         last_chunk = chunks[-1]
+         if last_chunk.usage_metadata:
+             usage_data.usage.prompt_tokens = (
+                 last_chunk.usage_metadata.prompt_token_count
+             )
+             usage_data.usage.completion_tokens = (
+                 last_chunk.usage_metadata.candidates_token_count or 0
+             )
+             usage_data.usage.total_tokens = last_chunk.usage_metadata.total_token_count
+             log_usage_fn(usage_data, execution_id=execution_id)
+
+     return usage_callback
+
+
+ class BaseGeminiWrapper(BaseWrapper):
+     def __init__(self, client, db_path=None, provider: str = "gemini"):
+         super().__init__(client, db_path)
+         self.provider = provider
+         self._async_wrapper = None
+
+     def _process_response_usage(
+         self, response: ResponseType
+     ) -> Optional[TokenUsageStats]:
+         """Process and log usage statistics from a response."""
+         try:
+             if isinstance(response, GenerateContentResponse):
+                 if response.usage_metadata is None:
+                     return None
+                 usage = TokenMetrics(
+                     prompt_tokens=response.usage_metadata.prompt_token_count,
+                     completion_tokens=response.usage_metadata.candidates_token_count,
+                     total_tokens=response.usage_metadata.total_token_count,
+                 )
+                 return TokenUsageStats(model=response.model_version, usage=usage)
+
+             elif isinstance(response, dict):
+                 usage_dict = response.get("usage_metadata")
+                 if not usage_dict:
+                     return None
+                 usage = TokenMetrics(
+                     prompt_tokens=usage_dict.get("prompt_token_count", 0),
+                     completion_tokens=usage_dict.get("candidates_token_count", 0),
+                     total_tokens=usage_dict.get("total_token_count", 0),
+                 )
+                 return TokenUsageStats(
+                     model=response.get("model", "unknown"), usage=usage
+                 )
+         except Exception as e:
+             logger.warning("Failed to process usage stats: %s", str(e))
+             return None
+         return None
+
+     @property
+     def chat(self):
+         return self
+
+     @property
+     def chats(self):
+         return self
+
+     @property
+     def models(self):
+         return self
+
+     @property
+     def aio(self):
+         if self._async_wrapper is None:
+             self._async_wrapper = AsyncGeminiWrapper(self)
+         return self._async_wrapper
+
+     def count_tokens(self, *args: Any, **kwargs: Any):
+         return self.client.models.count_tokens(*args, **kwargs)
+
+
+ class AsyncGeminiWrapper:
+     """Async wrapper for the Gemini client, matching the official SDK structure."""
+
+     def __init__(self, wrapper: BaseGeminiWrapper):
+         self.wrapper = wrapper
+         self._models = None
+
+     @property
+     def models(self):
+         if self._models is None:
+             self._models = AsyncModelsWrapper(self.wrapper)
+         return self._models
+
+
+ class AsyncModelsWrapper:
+     """Async wrapper for models, matching the official SDK structure."""
+
+     def __init__(self, wrapper: BaseGeminiWrapper):
+         self.wrapper = wrapper
+
+     async def generate_content(
+         self, *args: Any, **kwargs: Any
+     ) -> GenerateContentResponse:
+         """Async method for generate_content."""
+         execution_id = kwargs.pop("execution_id", None)
+         return await self.wrapper.generate_content_async(
+             *args, execution_id=execution_id, **kwargs
+         )
+
+     async def generate_content_stream(
+         self, *args: Any, **kwargs: Any
+     ) -> AsyncIterator[GenerateContentResponse]:
+         """Async method for generate_content_stream."""
+         execution_id = kwargs.pop("execution_id", None)
+         return await self.wrapper.generate_content_stream_async(
+             *args, execution_id=execution_id, **kwargs
+         )
+
+
+ class GeminiWrapper(BaseGeminiWrapper):
+     def generate_content(
+         self, *args: Any, execution_id: Optional[str] = None, **kwargs: Any
+     ) -> GenerateContentResponse:
+         """Generate content and log token usage."""
+         logger.debug("Generating content with args: %s, kwargs: %s", args, kwargs)
+
+         response = self.client.models.generate_content(*args, **kwargs)
+         usage_data = self._process_response_usage(response)
+         if usage_data:
+             self._log_usage(usage_data, execution_id=execution_id)
+
+         return response
+
+     def generate_content_stream(
+         self, *args: Any, execution_id: Optional[str] = None, **kwargs: Any
+     ) -> Iterator[GenerateContentResponse]:
+         """Generate content with streaming and log token usage."""
+         logger.debug(
+             "Generating content stream with args: %s, kwargs: %s", args, kwargs
+         )
+
+         base_stream = self.client.models.generate_content_stream(*args, **kwargs)
+         return GeminiSyncStreamInterceptor(
+             base_stream=base_stream,
+             usage_callback=_create_usage_callback(execution_id, self._log_usage),
+         )
+
+     async def generate_content_async(
+         self, *args: Any, execution_id: Optional[str] = None, **kwargs: Any
+     ) -> GenerateContentResponse:
+         """Generate content asynchronously and log token usage."""
+         logger.debug("Generating content async with args: %s, kwargs: %s", args, kwargs)
+
+         response = await self.client.aio.models.generate_content(*args, **kwargs)
+         usage_data = self._process_response_usage(response)
+         if usage_data:
+             self._log_usage(usage_data, execution_id=execution_id)
+
+         return response
+
+     async def generate_content_stream_async(
+         self, *args: Any, execution_id: Optional[str] = None, **kwargs: Any
+     ) -> AsyncIterator[GenerateContentResponse]:
+         """Generate content with async streaming and log token usage."""
+         logger.debug(
+             "Generating content stream async with args: %s, kwargs: %s", args, kwargs
+         )
+
+         base_stream = await self.client.aio.models.generate_content_stream(
+             *args, **kwargs
+         )
+         return GeminiAsyncStreamInterceptor(
+             base_stream=base_stream,
+             usage_callback=_create_usage_callback(execution_id, self._log_usage),
+         )
+
+
+ def tokenator_gemini(
+     client: genai.Client,
+     db_path: Optional[str] = None,
+     provider: str = "gemini",
+ ) -> GeminiWrapper:
+     """Create a token-tracking wrapper for a Gemini client.
+
+     Args:
+         client: Gemini client instance
+         db_path: Optional path to SQLite database for token tracking
+         provider: Provider name, defaults to "gemini"
+     """
+     if not isinstance(client, genai.Client):
+         raise ValueError("Client must be an instance of genai.Client")
+
+     return GeminiWrapper(client=client, db_path=db_path, provider=provider)
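For context, a minimal usage sketch (not part of the diff; the API key and model name below are placeholders). The wrapper mirrors the google-genai client surface, and execution_id is the tokenator-specific keyword that each wrapper method pops before delegating to the SDK:

    from google import genai
    from tokenator import tokenator_gemini

    # Wrap a regular google-genai client; usage is recorded to tokenator's SQLite store.
    client = tokenator_gemini(genai.Client(api_key="YOUR_API_KEY"))

    # Non-streaming: usage is read from response.usage_metadata after the call.
    response = client.models.generate_content(
        model="gemini-2.0-flash",  # placeholder model name
        contents="Hello, Gemini!",
        execution_id="demo-run-1",  # optional grouping key, popped before the SDK call
    )
    print(response.text)

    # Streaming: chunks pass through unchanged; usage is logged once the stream
    # is exhausted, from the last chunk's usage_metadata.
    for chunk in client.models.generate_content_stream(
        model="gemini-2.0-flash",
        contents="Hello again!",
    ):
        print(chunk.text, end="")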
tokenator/gemini/stream_interceptors.py ADDED
@@ -0,0 +1,77 @@
+ """Stream interceptors for Gemini responses."""
+
+ import logging
+ from typing import AsyncIterator, Callable, List, Optional, TypeVar, Iterator
+
+
+ logger = logging.getLogger(__name__)
+
+ _T = TypeVar("_T")  # GenerateContentResponse
+
+
+ class GeminiAsyncStreamInterceptor(AsyncIterator[_T]):
+     """
+     A wrapper around a Gemini async stream that intercepts each chunk
+     to handle usage or logging logic.
+     """
+
+     def __init__(
+         self,
+         base_stream: AsyncIterator[_T],
+         usage_callback: Optional[Callable[[List[_T]], None]] = None,
+     ):
+         self._base_stream = base_stream
+         self._usage_callback = usage_callback
+         self._chunks: List[_T] = []
+
+     def __aiter__(self) -> AsyncIterator[_T]:
+         """Return self as async iterator."""
+         return self
+
+     async def __anext__(self) -> _T:
+         """Get the next chunk and track it."""
+         try:
+             chunk = await self._base_stream.__anext__()
+         except StopAsyncIteration:
+             # Once the base stream is fully consumed, do the final usage logging.
+             if self._usage_callback and self._chunks:
+                 self._usage_callback(self._chunks)
+             raise
+
+         # Intercept each chunk
+         self._chunks.append(chunk)
+         return chunk
+
+
+ class GeminiSyncStreamInterceptor(Iterator[_T]):
+     """
+     A wrapper around a Gemini sync stream that intercepts each chunk
+     to handle usage or logging logic.
+     """
+
+     def __init__(
+         self,
+         base_stream: Iterator[_T],
+         usage_callback: Optional[Callable[[List[_T]], None]] = None,
+     ):
+         self._base_stream = base_stream
+         self._usage_callback = usage_callback
+         self._chunks: List[_T] = []
+
+     def __iter__(self) -> Iterator[_T]:
+         """Return self as iterator."""
+         return self
+
+     def __next__(self) -> _T:
+         """Get the next chunk and track it."""
+         try:
+             chunk = next(self._base_stream)
+         except StopIteration:
+             # Once the base stream is fully consumed, do the final usage logging.
+             if self._usage_callback and self._chunks:
+                 self._usage_callback(self._chunks)
+             raise
+
+         # Intercept each chunk
+         self._chunks.append(chunk)
+         return chunk
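The interceptors buffer every chunk and fire the callback exactly once, when the base stream raises StopIteration (or StopAsyncIteration); a stream abandoned part-way through therefore logs nothing. A toy sketch with a plain iterator standing in for a Gemini chunk stream (illustrative only):

    def report(chunks):
        print(f"stream finished with {len(chunks)} chunks")

    stream = GeminiSyncStreamInterceptor(
        base_stream=iter(["a", "b", "c"]),  # stand-in for a Gemini stream
        usage_callback=report,
    )
    for item in stream:
        pass  # chunks are yielded to the caller unchanged
    # prints: stream finished with 3 chunks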