tokenator 0.1.15__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tokenator/__init__.py +8 -1
- tokenator/base_wrapper.py +4 -1
- tokenator/gemini/__init__.py +5 -0
- tokenator/gemini/client_gemini.py +230 -0
- tokenator/gemini/stream_interceptors.py +77 -0
- tokenator/usage.py +464 -377
- tokenator/utils.py +7 -4
- {tokenator-0.1.15.dist-info → tokenator-0.2.0.dist-info}/METADATA +63 -6
- {tokenator-0.1.15.dist-info → tokenator-0.2.0.dist-info}/RECORD +11 -8
- {tokenator-0.1.15.dist-info → tokenator-0.2.0.dist-info}/WHEEL +1 -1
- {tokenator-0.1.15.dist-info → tokenator-0.2.0.dist-info}/LICENSE +0 -0
tokenator/__init__.py
CHANGED
@@ -3,11 +3,18 @@
 import logging
 from .openai.client_openai import tokenator_openai
 from .anthropic.client_anthropic import tokenator_anthropic
+from .gemini.client_gemini import tokenator_gemini
 from . import usage
 from .utils import get_default_db_path
 from .usage import TokenUsageService
 
 usage = TokenUsageService()  # noqa: F811
-__all__ = [
+__all__ = [
+    "tokenator_openai",
+    "tokenator_anthropic",
+    "tokenator_gemini",
+    "usage",
+    "get_default_db_path",
+]
 
 logger = logging.getLogger(__name__)
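With the expanded `__all__`, the Gemini wrapper is exported from the package root alongside the existing wrappers, so all three can be imported directly (a one-line sketch of the new import surface):

from tokenator import tokenator_openai, tokenator_anthropic, tokenator_gemini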
tokenator/base_wrapper.py
CHANGED
@@ -112,7 +112,10 @@ class BaseWrapper:
         try:
             self._log_usage_impl(token_usage_stats, session, execution_id)
             session.commit()
-            logger.debug(
+            logger.debug(
+                "Successfully committed token usage for execution_id: %s",
+                execution_id,
+            )
         except Exception as e:
             logger.error("Failed to log token usage: %s", str(e))
             session.rollback()
tokenator/gemini/client_gemini.py
ADDED
@@ -0,0 +1,230 @@
+"""Gemini client wrapper with token usage tracking."""
+
+from typing import Any, Optional, Iterator, AsyncIterator
+import logging
+
+from google import genai
+from google.genai.types import GenerateContentResponse
+
+from ..models import (
+    TokenMetrics,
+    TokenUsageStats,
+)
+from ..base_wrapper import BaseWrapper, ResponseType
+from .stream_interceptors import (
+    GeminiAsyncStreamInterceptor,
+    GeminiSyncStreamInterceptor,
+)
+from ..state import is_tokenator_enabled
+
+logger = logging.getLogger(__name__)
+
+
+def _create_usage_callback(execution_id, log_usage_fn):
+    """Creates a callback function for processing usage statistics from stream chunks."""
+
+    def usage_callback(chunks):
+        if not chunks:
+            return
+
+        # Skip if tokenator is disabled
+        if not is_tokenator_enabled:
+            logger.debug("Tokenator is disabled - skipping stream usage logging")
+            return
+
+        logger.debug("Processing stream usage for execution_id: %s", execution_id)
+
+        # Build usage_data from the first chunk's model
+        usage_data = TokenUsageStats(
+            model=chunks[0].model_version,
+            usage=TokenMetrics(),
+        )
+
+        # Only take usage from the last chunk as it contains complete usage info
+        last_chunk = chunks[-1]
+        if last_chunk.usage_metadata:
+            usage_data.usage.prompt_tokens = (
+                last_chunk.usage_metadata.prompt_token_count
+            )
+            usage_data.usage.completion_tokens = (
+                last_chunk.usage_metadata.candidates_token_count or 0
+            )
+            usage_data.usage.total_tokens = last_chunk.usage_metadata.total_token_count
+        log_usage_fn(usage_data, execution_id=execution_id)
+
+    return usage_callback
+
+
+class BaseGeminiWrapper(BaseWrapper):
+    def __init__(self, client, db_path=None, provider: str = "gemini"):
+        super().__init__(client, db_path)
+        self.provider = provider
+        self._async_wrapper = None
+
+    def _process_response_usage(
+        self, response: ResponseType
+    ) -> Optional[TokenUsageStats]:
+        """Process and log usage statistics from a response."""
+        try:
+            if isinstance(response, GenerateContentResponse):
+                if response.usage_metadata is None:
+                    return None
+                usage = TokenMetrics(
+                    prompt_tokens=response.usage_metadata.prompt_token_count,
+                    completion_tokens=response.usage_metadata.candidates_token_count,
+                    total_tokens=response.usage_metadata.total_token_count,
+                )
+                return TokenUsageStats(model=response.model_version, usage=usage)
+
+            elif isinstance(response, dict):
+                usage_dict = response.get("usage_metadata")
+                if not usage_dict:
+                    return None
+                usage = TokenMetrics(
+                    prompt_tokens=usage_dict.get("prompt_token_count", 0),
+                    completion_tokens=usage_dict.get("candidates_token_count", 0),
+                    total_tokens=usage_dict.get("total_token_count", 0),
+                )
+                return TokenUsageStats(
+                    model=response.get("model", "unknown"), usage=usage
+                )
+        except Exception as e:
+            logger.warning("Failed to process usage stats: %s", str(e))
+            return None
+        return None
+
+    @property
+    def chat(self):
+        return self
+
+    @property
+    def chats(self):
+        return self
+
+    @property
+    def models(self):
+        return self
+
+    @property
+    def aio(self):
+        if self._async_wrapper is None:
+            self._async_wrapper = AsyncGeminiWrapper(self)
+        return self._async_wrapper
+
+    def count_tokens(self, *args: Any, **kwargs: Any):
+        return self.client.models.count_tokens(*args, **kwargs)
+
+
+class AsyncGeminiWrapper:
+    """Async wrapper for Gemini client to match the official SDK structure."""
+
+    def __init__(self, wrapper: BaseGeminiWrapper):
+        self.wrapper = wrapper
+        self._models = None
+
+    @property
+    def models(self):
+        if self._models is None:
+            self._models = AsyncModelsWrapper(self.wrapper)
+        return self._models
+
+
+class AsyncModelsWrapper:
+    """Async wrapper for models to match the official SDK structure."""
+
+    def __init__(self, wrapper: BaseGeminiWrapper):
+        self.wrapper = wrapper
+
+    async def generate_content(
+        self, *args: Any, **kwargs: Any
+    ) -> GenerateContentResponse:
+        """Async method for generate_content."""
+        execution_id = kwargs.pop("execution_id", None)
+        return await self.wrapper.generate_content_async(
+            *args, execution_id=execution_id, **kwargs
+        )
+
+    async def generate_content_stream(
+        self, *args: Any, **kwargs: Any
+    ) -> AsyncIterator[GenerateContentResponse]:
+        """Async method for generate_content_stream."""
+        execution_id = kwargs.pop("execution_id", None)
+        return await self.wrapper.generate_content_stream_async(
+            *args, execution_id=execution_id, **kwargs
+        )
+
+
+class GeminiWrapper(BaseGeminiWrapper):
+    def generate_content(
+        self, *args: Any, execution_id: Optional[str] = None, **kwargs: Any
+    ) -> GenerateContentResponse:
+        """Generate content and log token usage."""
+        logger.debug("Generating content with args: %s, kwargs: %s", args, kwargs)
+
+        response = self.client.models.generate_content(*args, **kwargs)
+        usage_data = self._process_response_usage(response)
+        if usage_data:
+            self._log_usage(usage_data, execution_id=execution_id)
+
+        return response
+
+    def generate_content_stream(
+        self, *args: Any, execution_id: Optional[str] = None, **kwargs: Any
+    ) -> Iterator[GenerateContentResponse]:
+        """Generate content with streaming and log token usage."""
+        logger.debug(
+            "Generating content stream with args: %s, kwargs: %s", args, kwargs
+        )
+
+        base_stream = self.client.models.generate_content_stream(*args, **kwargs)
+        return GeminiSyncStreamInterceptor(
+            base_stream=base_stream,
+            usage_callback=_create_usage_callback(execution_id, self._log_usage),
+        )
+
+    async def generate_content_async(
+        self, *args: Any, execution_id: Optional[str] = None, **kwargs: Any
+    ) -> GenerateContentResponse:
+        """Generate content asynchronously and log token usage."""
+        logger.debug("Generating content async with args: %s, kwargs: %s", args, kwargs)
+
+        response = await self.client.aio.models.generate_content(*args, **kwargs)
+        usage_data = self._process_response_usage(response)
+        if usage_data:
+            self._log_usage(usage_data, execution_id=execution_id)
+
+        return response
+
+    async def generate_content_stream_async(
+        self, *args: Any, execution_id: Optional[str] = None, **kwargs: Any
+    ) -> AsyncIterator[GenerateContentResponse]:
+        """Generate content with async streaming and log token usage."""
+        logger.debug(
+            "Generating content stream async with args: %s, kwargs: %s", args, kwargs
+        )
+
+        base_stream = await self.client.aio.models.generate_content_stream(
+            *args, **kwargs
+        )
+        return GeminiAsyncStreamInterceptor(
+            base_stream=base_stream,
+            usage_callback=_create_usage_callback(execution_id, self._log_usage),
+        )
+
+
+def tokenator_gemini(
+    client: genai.Client,
+    db_path: Optional[str] = None,
+    provider: str = "gemini",
+) -> GeminiWrapper:
+    """Create a token-tracking wrapper for a Gemini client.
+
+    Args:
+        client: Gemini client instance
+        db_path: Optional path to SQLite database for token tracking
+        provider: Provider name, defaults to "gemini"
+    """
+    if not isinstance(client, genai.Client):
+        raise ValueError("Client must be an instance of genai.Client")
+
+    return GeminiWrapper(client=client, db_path=db_path, provider=provider)
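Taken together, GeminiWrapper mirrors the google-genai client surface (.models, .aio.models) and logs token usage after each call. A minimal usage sketch, assuming a genai.Client configured via environment variables; the model name, prompt, and execution_id value below are illustrative and not part of this diff:

from google import genai

from tokenator import tokenator_gemini

# Wrap the official client; db_path is optional and falls back to tokenator's default SQLite path.
client = tokenator_gemini(genai.Client())

# Non-streaming call: usage_metadata is read from the response and logged immediately.
response = client.models.generate_content(
    model="gemini-2.0-flash",  # illustrative model name
    contents="Write a haiku about tokens",
    execution_id="haiku-run-1",  # optional tag stored alongside the usage record
)
print(response.text)

# Streaming call: chunks pass through GeminiSyncStreamInterceptor, and the usage
# carried by the last chunk is logged once the stream is fully consumed.
for chunk in client.models.generate_content_stream(
    model="gemini-2.0-flash",
    contents="Count to five",
):
    print(chunk.text, end="")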
tokenator/gemini/stream_interceptors.py
ADDED
@@ -0,0 +1,77 @@
+"""Stream interceptors for Gemini responses."""
+
+import logging
+from typing import AsyncIterator, Callable, List, Optional, TypeVar, Iterator
+
+
+logger = logging.getLogger(__name__)
+
+_T = TypeVar("_T")  # GenerateContentResponse
+
+
+class GeminiAsyncStreamInterceptor(AsyncIterator[_T]):
+    """
+    A wrapper around Gemini async stream that intercepts each chunk to handle usage or
+    logging logic.
+    """
+
+    def __init__(
+        self,
+        base_stream: AsyncIterator[_T],
+        usage_callback: Optional[Callable[[List[_T]], None]] = None,
+    ):
+        self._base_stream = base_stream
+        self._usage_callback = usage_callback
+        self._chunks: List[_T] = []
+
+    def __aiter__(self) -> AsyncIterator[_T]:
+        """Return self as async iterator."""
+        return self
+
+    async def __anext__(self) -> _T:
+        """Get next chunk and track it."""
+        try:
+            chunk = await self._base_stream.__anext__()
+        except StopAsyncIteration:
+            # Once the base stream is fully consumed, we can do final usage/logging.
+            if self._usage_callback and self._chunks:
+                self._usage_callback(self._chunks)
+            raise
+
+        # Intercept each chunk
+        self._chunks.append(chunk)
+        return chunk
+
+
+class GeminiSyncStreamInterceptor(Iterator[_T]):
+    """
+    A wrapper around Gemini sync stream that intercepts each chunk to handle usage or
+    logging logic.
+    """
+
+    def __init__(
+        self,
+        base_stream: Iterator[_T],
+        usage_callback: Optional[Callable[[List[_T]], None]] = None,
+    ):
+        self._base_stream = base_stream
+        self._usage_callback = usage_callback
+        self._chunks: List[_T] = []
+
+    def __iter__(self) -> Iterator[_T]:
+        """Return self as iterator."""
+        return self
+
+    def __next__(self) -> _T:
+        """Get next chunk and track it."""
+        try:
+            chunk = next(self._base_stream)
+        except StopIteration:
+            # Once the base stream is fully consumed, we can do final usage/logging.
+            if self._usage_callback and self._chunks:
+                self._usage_callback(self._chunks)
+            raise
+
+        # Intercept each chunk
+        self._chunks.append(chunk)
+        return chunk
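Both interceptors are intentionally generic: they buffer every chunk they yield and invoke the usage callback only when the underlying stream is exhausted, so a stream that is abandoned early logs nothing. A small self-contained sketch of that behaviour, using a plain iterator of strings in place of real GenerateContentResponse chunks:

from tokenator.gemini.stream_interceptors import GeminiSyncStreamInterceptor

calls = []
stream = GeminiSyncStreamInterceptor(
    base_stream=iter(["chunk-1", "chunk-2"]),  # stand-in for Gemini response chunks
    usage_callback=lambda chunks: calls.append(len(chunks)),
)

for _ in stream:  # chunks pass through unchanged while being buffered internally
    pass

print(calls)  # [2] -- the callback ran exactly once, after the stream was exhausted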