tokenator-0.1.8-py3-none-any.whl → tokenator-0.1.10-py3-none-any.whl
- tokenator/__init__.py +3 -3
- tokenator/anthropic/client_anthropic.py +155 -0
- tokenator/anthropic/stream_interceptors.py +146 -0
- tokenator/base_wrapper.py +26 -13
- tokenator/create_migrations.py +6 -5
- tokenator/migrations/env.py +5 -4
- tokenator/migrations/versions/f6f1f2437513_initial_migration.py +25 -23
- tokenator/migrations.py +9 -6
- tokenator/models.py +15 -4
- tokenator/openai/client_openai.py +163 -0
- tokenator/openai/stream_interceptors.py +146 -0
- tokenator/schemas.py +26 -27
- tokenator/usage.py +114 -47
- tokenator/utils.py +14 -9
- {tokenator-0.1.8.dist-info → tokenator-0.1.10.dist-info}/METADATA +40 -13
- tokenator-0.1.10.dist-info/RECORD +19 -0
- tokenator/client_anthropic.py +0 -148
- tokenator/client_openai.py +0 -151
- tokenator-0.1.8.dist-info/RECORD +0 -17
- {tokenator-0.1.8.dist-info → tokenator-0.1.10.dist-info}/LICENSE +0 -0
- {tokenator-0.1.8.dist-info → tokenator-0.1.10.dist-info}/WHEEL +0 -0
tokenator/openai/client_openai.py
ADDED
@@ -0,0 +1,163 @@
"""OpenAI client wrapper with token usage tracking."""

from typing import Any, Optional, Union, overload, Iterator, AsyncIterator
import logging

from openai import AsyncOpenAI, OpenAI
from openai.types.chat import ChatCompletion, ChatCompletionChunk

from ..models import Usage, TokenUsageStats
from ..base_wrapper import BaseWrapper, ResponseType
from .stream_interceptors import OpenAIAsyncStreamInterceptor, OpenAISyncStreamInterceptor

logger = logging.getLogger(__name__)


class BaseOpenAIWrapper(BaseWrapper):
    provider = "openai"

    def _process_response_usage(
        self, response: ResponseType
    ) -> Optional[TokenUsageStats]:
        """Process and log usage statistics from a response."""
        try:
            if isinstance(response, ChatCompletion):
                if response.usage is None:
                    return None
                usage = Usage(
                    prompt_tokens=response.usage.prompt_tokens,
                    completion_tokens=response.usage.completion_tokens,
                    total_tokens=response.usage.total_tokens,
                )
                return TokenUsageStats(model=response.model, usage=usage)

            elif isinstance(response, dict):
                usage_dict = response.get("usage")
                if not usage_dict:
                    return None
                usage = Usage(
                    prompt_tokens=usage_dict.get("prompt_tokens", 0),
                    completion_tokens=usage_dict.get("completion_tokens", 0),
                    total_tokens=usage_dict.get("total_tokens", 0),
                )
                return TokenUsageStats(
                    model=response.get("model", "unknown"), usage=usage
                )
        except Exception as e:
            logger.warning("Failed to process usage stats: %s", str(e))
            return None
        return None

    @property
    def chat(self):
        return self

    @property
    def completions(self):
        return self


def _create_usage_callback(execution_id, log_usage_fn):
    """Creates a callback function for processing usage statistics from stream chunks."""

    def usage_callback(chunks):
        if not chunks:
            return
        # Build usage_data from the first chunk's model
        usage_data = TokenUsageStats(
            model=chunks[0].model,
            usage=Usage(),
        )
        # Sum up usage from all chunks
        has_usage = False
        for ch in chunks:
            if ch.usage:
                has_usage = True
                usage_data.usage.prompt_tokens += ch.usage.prompt_tokens
                usage_data.usage.completion_tokens += ch.usage.completion_tokens
                usage_data.usage.total_tokens += ch.usage.total_tokens

        if has_usage:
            log_usage_fn(usage_data, execution_id=execution_id)

    return usage_callback


class OpenAIWrapper(BaseOpenAIWrapper):
    def create(
        self, *args: Any, execution_id: Optional[str] = None, **kwargs: Any
    ) -> Union[ChatCompletion, Iterator[ChatCompletion]]:
        """Create a chat completion and log token usage."""
        logger.debug("Creating chat completion with args: %s, kwargs: %s", args, kwargs)

        if kwargs.get("stream", False):
            base_stream = self.client.chat.completions.create(*args, **kwargs)
            return OpenAISyncStreamInterceptor(
                base_stream=base_stream,
                usage_callback=_create_usage_callback(execution_id, self._log_usage),
            )

        response = self.client.chat.completions.create(*args, **kwargs)
        usage_data = self._process_response_usage(response)
        if usage_data:
            self._log_usage(usage_data, execution_id=execution_id)

        return response


class AsyncOpenAIWrapper(BaseOpenAIWrapper):
    async def create(
        self, *args: Any, execution_id: Optional[str] = None, **kwargs: Any
    ) -> Union[ChatCompletion, AsyncIterator[ChatCompletionChunk]]:
        """
        Create a chat completion and log token usage.
        """
        logger.debug("Creating chat completion with args: %s, kwargs: %s", args, kwargs)

        # If user wants a stream, return an interceptor
        if kwargs.get("stream", False):
            base_stream = await self.client.chat.completions.create(*args, **kwargs)
            return OpenAIAsyncStreamInterceptor(
                base_stream=base_stream,
                usage_callback=_create_usage_callback(execution_id, self._log_usage),
            )

        # Non-streaming path remains unchanged
        response = await self.client.chat.completions.create(*args, **kwargs)
        usage_data = self._process_response_usage(response)
        if usage_data:
            self._log_usage(usage_data, execution_id=execution_id)
        return response


@overload
def tokenator_openai(
    client: OpenAI,
    db_path: Optional[str] = None,
) -> OpenAIWrapper: ...


@overload
def tokenator_openai(
    client: AsyncOpenAI,
    db_path: Optional[str] = None,
) -> AsyncOpenAIWrapper: ...


def tokenator_openai(
    client: Union[OpenAI, AsyncOpenAI],
    db_path: Optional[str] = None,
) -> Union[OpenAIWrapper, AsyncOpenAIWrapper]:
    """Create a token-tracking wrapper for an OpenAI client.

    Args:
        client: OpenAI or AsyncOpenAI client instance
        db_path: Optional path to SQLite database for token tracking
    """
    if isinstance(client, OpenAI):
        return OpenAIWrapper(client=client, db_path=db_path)

    if isinstance(client, AsyncOpenAI):
        return AsyncOpenAIWrapper(client=client, db_path=db_path)

    raise ValueError("Client must be an instance of OpenAI or AsyncOpenAI")
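For orientation, a minimal usage sketch of the new wrapper. It assumes the top-level re-export touched in tokenator/__init__.py, a configured OPENAI_API_KEY, and an illustrative model name; the wrapped client is used like a plain OpenAI client.

from openai import OpenAI
from tokenator import tokenator_openai  # assumes the package-level re-export

client = tokenator_openai(OpenAI())  # optionally: tokenator_openai(OpenAI(), db_path="tokens.db")

# Non-streaming: usage is read from response.usage and logged after the call.
response = client.chat.completions.create(
    model="gpt-4o-mini",  # illustrative model name
    messages=[{"role": "user", "content": "Hello"}],
)

# Streaming: the sync interceptor collects chunks and logs usage once the
# stream is exhausted; include_usage asks OpenAI to emit a final usage chunk.
stream = client.chat.completions.create(
    model="gpt-4o-mini",
    messages=[{"role": "user", "content": "Hello"}],
    stream=True,
    stream_options={"include_usage": True},
)
for chunk in stream:
    pass  # consume the stream; usage is logged when iteration finishes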
tokenator/openai/stream_interceptors.py
ADDED
@@ -0,0 +1,146 @@
import logging
from typing import AsyncIterator, Callable, List, Optional, TypeVar, Iterator

from openai import AsyncStream, Stream

logger = logging.getLogger(__name__)

_T = TypeVar("_T")  # or you might specifically do _T = ChatCompletionChunk


class OpenAIAsyncStreamInterceptor(AsyncStream[_T]):
    """
    A wrapper around openai.AsyncStream that delegates all functionality
    to the 'base_stream' but intercepts each chunk to handle usage or
    logging logic. This preserves .response and other methods.

    You can store aggregated usage in a local list and process it when
    the stream ends (StopAsyncIteration).
    """

    def __init__(
        self,
        base_stream: AsyncStream[_T],
        usage_callback: Optional[Callable[[List[_T]], None]] = None,
    ):
        # We do NOT call super().__init__() because openai.AsyncStream
        # expects constructor parameters we don't want to re-initialize.
        # Instead, we just store the base_stream and delegate everything to it.
        self._base_stream = base_stream
        self._usage_callback = usage_callback
        self._chunks: List[_T] = []

    @property
    def response(self):
        """Expose the original stream's 'response' so user code can do stream.response, etc."""
        return self._base_stream.response

    def __aiter__(self) -> AsyncIterator[_T]:
        """
        Called when we do 'async for chunk in wrapped_stream:'
        We simply return 'self'. Then __anext__ does the rest.
        """
        return self

    async def __anext__(self) -> _T:
        """
        Intercept iteration. We pull the next chunk from the base_stream.
        If it's the end, do any final usage logging, then raise StopAsyncIteration.
        Otherwise, we can accumulate usage info or do whatever we need with the chunk.
        """
        try:
            chunk = await self._base_stream.__anext__()
        except StopAsyncIteration:
            # Once the base stream is fully consumed, we can do final usage/logging.
            if self._usage_callback and self._chunks:
                self._usage_callback(self._chunks)
            raise

        # Intercept each chunk
        self._chunks.append(chunk)
        return chunk

    async def __aenter__(self) -> "OpenAIAsyncStreamInterceptor[_T]":
        """Support async with ... : usage."""
        await self._base_stream.__aenter__()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """
        Ensure we propagate __aexit__ to the base stream,
        so connections are properly closed.
        """
        return await self._base_stream.__aexit__(exc_type, exc_val, exc_tb)

    async def close(self) -> None:
        """Delegate close to the base_stream."""
        await self._base_stream.close()


class OpenAISyncStreamInterceptor(Stream[_T]):
    """
    A wrapper around openai.Stream that delegates all functionality
    to the 'base_stream' but intercepts each chunk to handle usage or
    logging logic. This preserves .response and other methods.

    You can store aggregated usage in a local list and process it when
    the stream ends (StopIteration).
    """

    def __init__(
        self,
        base_stream: Stream[_T],
        usage_callback: Optional[Callable[[List[_T]], None]] = None,
    ):
        # We do NOT call super().__init__() because openai.Stream
        # expects constructor parameters we don't want to re-initialize.
        # Instead, we just store the base_stream and delegate everything to it.
        self._base_stream = base_stream
        self._usage_callback = usage_callback
        self._chunks: List[_T] = []

    @property
    def response(self):
        """Expose the original stream's 'response' so user code can do stream.response, etc."""
        return self._base_stream.response

    def __iter__(self) -> Iterator[_T]:
        """
        Called when we do 'for chunk in wrapped_stream:'
        We simply return 'self'. Then __next__ does the rest.
        """
        return self

    def __next__(self) -> _T:
        """
        Intercept iteration. We pull the next chunk from the base_stream.
        If it's the end, do any final usage logging, then raise StopIteration.
        Otherwise, we can accumulate usage info or do whatever we need with the chunk.
        """
        try:
            chunk = self._base_stream.__next__()
        except StopIteration:
            # Once the base stream is fully consumed, we can do final usage/logging.
            if self._usage_callback and self._chunks:
                self._usage_callback(self._chunks)
            raise

        # Intercept each chunk
        self._chunks.append(chunk)
        return chunk

    def __enter__(self) -> "OpenAISyncStreamInterceptor[_T]":
        """Support with ... : usage."""
        self._base_stream.__enter__()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """
        Ensure we propagate __exit__ to the base stream,
        so connections are properly closed.
        """
        return self._base_stream.__exit__(exc_type, exc_val, exc_tb)

    def close(self) -> None:
        """Delegate close to the base_stream."""
        self._base_stream.close()
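The interceptors above never call the parent constructor; they only hold a reference to the wrapped stream and replay iteration through it, firing the callback once after exhaustion. A small self-contained sketch of that behaviour (the plain Python iterator below is a stand-in for an openai.Stream of ChatCompletionChunk objects, so the types are illustrative only):

from tokenator.openai.stream_interceptors import OpenAISyncStreamInterceptor

def report(chunks):
    # Called exactly once, after the wrapped stream raises StopIteration.
    print(f"callback saw {len(chunks)} chunks")

fake_stream = iter(["chunk-1", "chunk-2", "chunk-3"])  # stand-in, not a real openai.Stream

wrapped = OpenAISyncStreamInterceptor(base_stream=fake_stream, usage_callback=report)
for chunk in wrapped:
    print(chunk)
# prints chunk-1, chunk-2, chunk-3, then "callback saw 3 chunks"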
tokenator/schemas.py
CHANGED
@@ -1,26 +1,22 @@
 """SQLAlchemy models for tokenator."""

-import uuid
 from datetime import datetime
-import os

-from sqlalchemy import create_engine, Column, Integer, String, DateTime,
+from sqlalchemy import create_engine, Column, Integer, String, DateTime, Index
 from sqlalchemy.orm import sessionmaker, scoped_session, declarative_base

 from .utils import get_default_db_path

 Base = declarative_base()

+
 def get_engine(db_path: str = None):
     """Create SQLAlchemy engine with the given database path."""
     if db_path is None:
-        try:
-            import google.colab  # type: ignore
-            db_path = '/content/tokenator.db'
-        except ImportError:
-            db_path = get_default_db_path()
+        db_path = get_default_db_path()
     return create_engine(f"sqlite:///{db_path}", echo=False)

+
 def get_session(db_path: str = None):
     """Create a thread-safe session factory."""
     engine = get_engine(db_path)
@@ -28,39 +24,42 @@ def get_session(db_path: str = None):
     session_factory = sessionmaker(bind=engine)
     return scoped_session(session_factory)

+
 class TokenUsage(Base):
     """Model for tracking token usage."""
-
+
     __tablename__ = "token_usage"
-
+
     id = Column(Integer, primary_key=True)
     execution_id = Column(String, nullable=False)
     provider = Column(String, nullable=False)
     model = Column(String, nullable=False)
     created_at = Column(DateTime, nullable=False, default=datetime.now)
-    updated_at = Column(
+    updated_at = Column(
+        DateTime, nullable=False, default=datetime.now, onupdate=datetime.now
+    )
     prompt_tokens = Column(Integer, nullable=False)
     completion_tokens = Column(Integer, nullable=False)
     total_tokens = Column(Integer, nullable=False)
-
+
     # Create indexes
     __table_args__ = (
-        Index(
-        Index(
-        Index(
-        Index(
+        Index("idx_created_at", "created_at"),
+        Index("idx_execution_id", "execution_id"),
+        Index("idx_provider", "provider"),
+        Index("idx_model", "model"),
     )
-
+
     def to_dict(self):
         """Convert model instance to dictionary."""
         return {
-
-
-
-
-
-
-
-
-
-        }
+            "id": self.id,
+            "execution_id": self.execution_id,
+            "provider": self.provider,
+            "model": self.model,
+            "created_at": self.created_at,
+            "updated_at": self.updated_at,
+            "prompt_tokens": self.prompt_tokens,
+            "completion_tokens": self.completion_tokens,
+            "total_tokens": self.total_tokens,
+        }
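Finally, a brief sketch of reading the table defined above with the module's own session helper. The query is illustrative only; tokenator/usage.py remains the supported reporting interface, and the default database location comes from get_default_db_path().

from tokenator.schemas import TokenUsage, get_session

Session = get_session()   # thread-safe scoped_session factory
session = Session()

recent = (
    session.query(TokenUsage)
    .filter(TokenUsage.provider == "openai")
    .order_by(TokenUsage.created_at.desc())
    .limit(5)
    .all()
)
for row in recent:
    print(row.to_dict())  # uses the rebuilt to_dict() shown in the diff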