tokenator 0.1.9__py3-none-any.whl → 0.1.11__py3-none-any.whl
- tokenator/__init__.py +2 -2
- tokenator/anthropic/client_anthropic.py +154 -0
- tokenator/anthropic/stream_interceptors.py +146 -0
- tokenator/base_wrapper.py +26 -13
- tokenator/create_migrations.py +6 -5
- tokenator/migrations/env.py +5 -4
- tokenator/migrations/versions/f6f1f2437513_initial_migration.py +25 -23
- tokenator/migrations.py +9 -6
- tokenator/models.py +15 -4
- tokenator/openai/client_openai.py +66 -70
- tokenator/openai/stream_interceptors.py +146 -0
- tokenator/schemas.py +26 -27
- tokenator/usage.py +114 -47
- tokenator/utils.py +14 -9
- {tokenator-0.1.9.dist-info → tokenator-0.1.11.dist-info}/METADATA +72 -17
- tokenator-0.1.11.dist-info/RECORD +19 -0
- {tokenator-0.1.9.dist-info → tokenator-0.1.11.dist-info}/WHEEL +1 -1
- tokenator/client_anthropic.py +0 -148
- tokenator/openai/AsyncStreamInterceptor.py +0 -78
- tokenator-0.1.9.dist-info/RECORD +0 -18
- {tokenator-0.1.9.dist-info → tokenator-0.1.11.dist-info}/LICENSE +0 -0
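Taken together, the changed files show the streaming logic moving out of the client wrappers into dedicated `stream_interceptors` modules, with the Anthropic wrapper relocated under `tokenator/anthropic/`. A minimal usage sketch of the public entry point, based on the `tokenator_openai` signature in the diff below; the model name, prompt, and db path are placeholders, and it assumes `tokenator_openai` is importable from the package root and that the wrapper proxies the usual `client.chat.completions.create` chain (the `completions` property in the diff suggests it does):

    from openai import OpenAI
    from tokenator import tokenator_openai  # assumed re-export; otherwise import from tokenator.openai.client_openai

    # Wrap an existing client; db_path is optional and falls back to the package default.
    client = tokenator_openai(OpenAI(), db_path="tokens.db")

    # Non-streaming: usage is read from response.usage and logged to SQLite.
    response = client.chat.completions.create(
        model="gpt-4o-mini",  # placeholder model name
        messages=[{"role": "user", "content": "hello"}],
    )

    # Streaming: create() returns a stream interceptor that logs usage only after
    # the stream has been fully consumed. Requesting include_usage makes the final
    # chunk carry a usage payload.
    for chunk in client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": "hello"}],
        stream=True,
        stream_options={"include_usage": True},
    ):
        pass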
tokenator/openai/client_openai.py
CHANGED
@@ -1,21 +1,24 @@
 """OpenAI client wrapper with token usage tracking."""
 
-from typing import Any,
+from typing import Any, Optional, Union, overload, Iterator, AsyncIterator
 import logging
 
-from openai import AsyncOpenAI,
+from openai import AsyncOpenAI, OpenAI
 from openai.types.chat import ChatCompletion, ChatCompletionChunk
 
 from ..models import Usage, TokenUsageStats
 from ..base_wrapper import BaseWrapper, ResponseType
-from .
+from .stream_interceptors import OpenAIAsyncStreamInterceptor, OpenAISyncStreamInterceptor
 
 logger = logging.getLogger(__name__)
 
+
 class BaseOpenAIWrapper(BaseWrapper):
     provider = "openai"
 
-    def _process_response_usage(
+    def _process_response_usage(
+        self, response: ResponseType
+    ) -> Optional[TokenUsageStats]:
         """Process and log usage statistics from a response."""
         try:
             if isinstance(response, ChatCompletion):
@@ -27,19 +30,18 @@ class BaseOpenAIWrapper(BaseWrapper):
                     total_tokens=response.usage.total_tokens,
                 )
                 return TokenUsageStats(model=response.model, usage=usage)
-
+
             elif isinstance(response, dict):
-                usage_dict = response.get(
+                usage_dict = response.get("usage")
                 if not usage_dict:
                     return None
                 usage = Usage(
-                    prompt_tokens=usage_dict.get(
-                    completion_tokens=usage_dict.get(
-                    total_tokens=usage_dict.get(
+                    prompt_tokens=usage_dict.get("prompt_tokens", 0),
+                    completion_tokens=usage_dict.get("completion_tokens", 0),
+                    total_tokens=usage_dict.get("total_tokens", 0),
                 )
                 return TokenUsageStats(
-                    model=response.get(
-                    usage=usage
+                    model=response.get("model", "unknown"), usage=usage
                 )
         except Exception as e:
             logger.warning("Failed to process usage stats: %s", str(e))
@@ -54,45 +56,58 @@ class BaseOpenAIWrapper(BaseWrapper):
     def completions(self):
         return self
 
+
+def _create_usage_callback(execution_id, log_usage_fn):
+    """Creates a callback function for processing usage statistics from stream chunks."""
+
+    def usage_callback(chunks):
+        if not chunks:
+            return
+        # Build usage_data from the first chunk's model
+        usage_data = TokenUsageStats(
+            model=chunks[0].model,
+            usage=Usage(),
+        )
+        # Sum up usage from all chunks
+        has_usage = False
+        for ch in chunks:
+            if ch.usage:
+                has_usage = True
+                usage_data.usage.prompt_tokens += ch.usage.prompt_tokens
+                usage_data.usage.completion_tokens += ch.usage.completion_tokens
+                usage_data.usage.total_tokens += ch.usage.total_tokens
+
+        if has_usage:
+            log_usage_fn(usage_data, execution_id=execution_id)
+
+    return usage_callback
+
+
 class OpenAIWrapper(BaseOpenAIWrapper):
-    def create(
+    def create(
+        self, *args: Any, execution_id: Optional[str] = None, **kwargs: Any
+    ) -> Union[ChatCompletion, Iterator[ChatCompletion]]:
         """Create a chat completion and log token usage."""
         logger.debug("Creating chat completion with args: %s, kwargs: %s", args, kwargs)
-
+
+        if kwargs.get("stream", False):
+            base_stream = self.client.chat.completions.create(*args, **kwargs)
+            return OpenAISyncStreamInterceptor(
+                base_stream=base_stream,
+                usage_callback=_create_usage_callback(execution_id, self._log_usage),
+            )
+
         response = self.client.chat.completions.create(*args, **kwargs)
-
-        if
-            usage_data = self._process_response_usage(response)
-            if usage_data:
-                self._log_usage(usage_data, execution_id=execution_id)
-            return response
-
-        return self._wrap_streaming_response(response, execution_id)
-
-    def _wrap_streaming_response(self, response_iter: Stream[ChatCompletionChunk], execution_id: Optional[str]) -> Iterator[ChatCompletionChunk]:
-        """Wrap streaming response to capture final usage stats"""
-        chunks_with_usage = []
-        for chunk in response_iter:
-            if isinstance(chunk, ChatCompletionChunk) and chunk.usage is not None:
-                chunks_with_usage.append(chunk)
-            yield chunk
-
-        if len(chunks_with_usage) > 0:
-            usage_data: TokenUsageStats = TokenUsageStats(model=chunks_with_usage[0].model, usage=Usage())
-            for chunk in chunks_with_usage:
-                usage_data.usage.prompt_tokens += chunk.usage.prompt_tokens
-                usage_data.usage.completion_tokens += chunk.usage.completion_tokens
-                usage_data.usage.total_tokens += chunk.usage.total_tokens
-
+        usage_data = self._process_response_usage(response)
+        if usage_data:
             self._log_usage(usage_data, execution_id=execution_id)
-
+
+        return response
+
 
 class AsyncOpenAIWrapper(BaseOpenAIWrapper):
     async def create(
-        self,
-        *args: Any,
-        execution_id: Optional[str] = None,
-        **kwargs: Any
+        self, *args: Any, execution_id: Optional[str] = None, **kwargs: Any
     ) -> Union[ChatCompletion, AsyncIterator[ChatCompletionChunk]]:
         """
         Create a chat completion and log token usage.
@@ -102,32 +117,9 @@ class AsyncOpenAIWrapper(BaseOpenAIWrapper):
         # If user wants a stream, return an interceptor
         if kwargs.get("stream", False):
             base_stream = await self.client.chat.completions.create(*args, **kwargs)
-
-            # Define a callback that will get called once the stream ends
-            def usage_callback(chunks):
-                # Mimic your old logic to gather usage from chunk.usage
-                # e.g. ChatCompletionChunk.usage
-                # Then call self._log_usage(...)
-                if not chunks:
-                    return
-                # Build usage_data from the first chunk's model
-                usage_data = TokenUsageStats(
-                    model=chunks[0].model,
-                    usage=Usage(),
-                )
-                # Sum up usage from all chunks
-                for ch in chunks:
-                    if ch.usage:
-                        usage_data.usage.prompt_tokens += ch.usage.prompt_tokens
-                        usage_data.usage.completion_tokens += ch.usage.completion_tokens
-                        usage_data.usage.total_tokens += ch.usage.total_tokens
-
-                self._log_usage(usage_data, execution_id=execution_id)
-
-            # Return the interceptor that wraps the real AsyncStream
-            return AsyncStreamInterceptor(
+            return OpenAIAsyncStreamInterceptor(
                 base_stream=base_stream,
-                usage_callback=
+                usage_callback=_create_usage_callback(execution_id, self._log_usage),
             )
 
         # Non-streaming path remains unchanged
@@ -136,32 +128,36 @@ class AsyncOpenAIWrapper(BaseOpenAIWrapper):
         if usage_data:
             self._log_usage(usage_data, execution_id=execution_id)
         return response
+
+
 @overload
 def tokenator_openai(
     client: OpenAI,
     db_path: Optional[str] = None,
 ) -> OpenAIWrapper: ...
 
+
 @overload
 def tokenator_openai(
     client: AsyncOpenAI,
     db_path: Optional[str] = None,
 ) -> AsyncOpenAIWrapper: ...
 
+
 def tokenator_openai(
     client: Union[OpenAI, AsyncOpenAI],
     db_path: Optional[str] = None,
 ) -> Union[OpenAIWrapper, AsyncOpenAIWrapper]:
     """Create a token-tracking wrapper for an OpenAI client.
-
+
     Args:
         client: OpenAI or AsyncOpenAI client instance
         db_path: Optional path to SQLite database for token tracking
     """
     if isinstance(client, OpenAI):
         return OpenAIWrapper(client=client, db_path=db_path)
-
+
     if isinstance(client, AsyncOpenAI):
         return AsyncOpenAIWrapper(client=client, db_path=db_path)
-
+
     raise ValueError("Client must be an instance of OpenAI or AsyncOpenAI")
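Both wrappers now delegate stream accounting to the shared `_create_usage_callback` helper: the interceptor collects chunks as they are yielded and hands the full list to the callback once the stream is exhausted, and usage is only logged if at least one chunk carried a `usage` payload. A hedged sketch of that aggregation using hypothetical stand-in chunk classes; the helper is private, so importing it directly is for illustration only:

    from tokenator.openai.client_openai import _create_usage_callback  # private helper, imported only for this sketch

    class FakeUsage:
        def __init__(self, prompt, completion):
            self.prompt_tokens = prompt
            self.completion_tokens = completion
            self.total_tokens = prompt + completion

    class FakeChunk:
        def __init__(self, usage=None):
            self.model = "gpt-4o-mini"  # placeholder model name
            self.usage = usage

    logged = []
    callback = _create_usage_callback(
        "exec-123",  # hypothetical execution id
        lambda stats, execution_id=None: logged.append((execution_id, stats)),
    )

    # Only the last chunk carries usage, mirroring OpenAI's streaming behaviour
    # when stream_options={"include_usage": True} is requested.
    callback([FakeChunk(), FakeChunk(), FakeChunk(FakeUsage(12, 7))])
    assert logged[0][1].usage.total_tokens == 19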
tokenator/openai/stream_interceptors.py
ADDED
@@ -0,0 +1,146 @@
+import logging
+from typing import AsyncIterator, Callable, List, Optional, TypeVar, Iterator
+
+from openai import AsyncStream, Stream
+
+logger = logging.getLogger(__name__)
+
+_T = TypeVar("_T")  # or you might specifically do _T = ChatCompletionChunk
+
+
+class OpenAIAsyncStreamInterceptor(AsyncStream[_T]):
+    """
+    A wrapper around openai.AsyncStream that delegates all functionality
+    to the 'base_stream' but intercepts each chunk to handle usage or
+    logging logic. This preserves .response and other methods.
+
+    You can store aggregated usage in a local list and process it when
+    the stream ends (StopAsyncIteration).
+    """
+
+    def __init__(
+        self,
+        base_stream: AsyncStream[_T],
+        usage_callback: Optional[Callable[[List[_T]], None]] = None,
+    ):
+        # We do NOT call super().__init__() because openai.AsyncStream
+        # expects constructor parameters we don't want to re-initialize.
+        # Instead, we just store the base_stream and delegate everything to it.
+        self._base_stream = base_stream
+        self._usage_callback = usage_callback
+        self._chunks: List[_T] = []
+
+    @property
+    def response(self):
+        """Expose the original stream's 'response' so user code can do stream.response, etc."""
+        return self._base_stream.response
+
+    def __aiter__(self) -> AsyncIterator[_T]:
+        """
+        Called when we do 'async for chunk in wrapped_stream:'
+        We simply return 'self'. Then __anext__ does the rest.
+        """
+        return self
+
+    async def __anext__(self) -> _T:
+        """
+        Intercept iteration. We pull the next chunk from the base_stream.
+        If it's the end, do any final usage logging, then raise StopAsyncIteration.
+        Otherwise, we can accumulate usage info or do whatever we need with the chunk.
+        """
+        try:
+            chunk = await self._base_stream.__anext__()
+        except StopAsyncIteration:
+            # Once the base stream is fully consumed, we can do final usage/logging.
+            if self._usage_callback and self._chunks:
+                self._usage_callback(self._chunks)
+            raise
+
+        # Intercept each chunk
+        self._chunks.append(chunk)
+        return chunk
+
+    async def __aenter__(self) -> "OpenAIAsyncStreamInterceptor[_T]":
+        """Support async with ... : usage."""
+        await self._base_stream.__aenter__()
+        return self
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb):
+        """
+        Ensure we propagate __aexit__ to the base stream,
+        so connections are properly closed.
+        """
+        return await self._base_stream.__aexit__(exc_type, exc_val, exc_tb)
+
+    async def close(self) -> None:
+        """Delegate close to the base_stream."""
+        await self._base_stream.close()
+
+
+class OpenAISyncStreamInterceptor(Stream[_T]):
+    """
+    A wrapper around openai.Stream that delegates all functionality
+    to the 'base_stream' but intercepts each chunk to handle usage or
+    logging logic. This preserves .response and other methods.
+
+    You can store aggregated usage in a local list and process it when
+    the stream ends (StopIteration).
+    """
+
+    def __init__(
+        self,
+        base_stream: Stream[_T],
+        usage_callback: Optional[Callable[[List[_T]], None]] = None,
+    ):
+        # We do NOT call super().__init__() because openai.Stream
+        # expects constructor parameters we don't want to re-initialize.
+        # Instead, we just store the base_stream and delegate everything to it.
+        self._base_stream = base_stream
+        self._usage_callback = usage_callback
+        self._chunks: List[_T] = []
+
+    @property
+    def response(self):
+        """Expose the original stream's 'response' so user code can do stream.response, etc."""
+        return self._base_stream.response
+
+    def __iter__(self) -> Iterator[_T]:
+        """
+        Called when we do 'for chunk in wrapped_stream:'
+        We simply return 'self'. Then __next__ does the rest.
+        """
+        return self
+
+    def __next__(self) -> _T:
+        """
+        Intercept iteration. We pull the next chunk from the base_stream.
+        If it's the end, do any final usage logging, then raise StopIteration.
+        Otherwise, we can accumulate usage info or do whatever we need with the chunk.
+        """
+        try:
+            chunk = self._base_stream.__next__()
+        except StopIteration:
+            # Once the base stream is fully consumed, we can do final usage/logging.
+            if self._usage_callback and self._chunks:
+                self._usage_callback(self._chunks)
+            raise
+
+        # Intercept each chunk
+        self._chunks.append(chunk)
+        return chunk
+
+    def __enter__(self) -> "OpenAISyncStreamInterceptor[_T]":
+        """Support with ... : usage."""
+        self._base_stream.__enter__()
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        """
+        Ensure we propagate __exit__ to the base stream,
+        so connections are properly closed.
+        """
+        return self._base_stream.__exit__(exc_type, exc_val, exc_tb)
+
+    def close(self) -> None:
+        """Delegate close to the base_stream."""
+        self._base_stream.close()
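The new interceptors deliberately skip `super().__init__()` and instead delegate every call to the wrapped SDK stream, buffering chunks and firing the usage callback only when iteration ends. The same delegation pattern can be sketched independently of the OpenAI types; the class and parameter names below are illustrative, not part of tokenator:

    from typing import Callable, Iterator, List, Optional, TypeVar

    T = TypeVar("T")

    class InterceptingIterator:
        """Generic sketch of the pattern: buffer every item and fire a
        completion callback once the underlying iterator is exhausted."""

        def __init__(self, base: Iterator[T], on_complete: Optional[Callable[[List[T]], None]] = None):
            self._base = base
            self._on_complete = on_complete
            self._items: List[T] = []

        def __iter__(self):
            return self

        def __next__(self) -> T:
            try:
                item = next(self._base)
            except StopIteration:
                # The callback only runs after the consumer drains the stream.
                if self._on_complete and self._items:
                    self._on_complete(self._items)
                raise
            self._items.append(item)
            return item

    seen = []
    for _ in InterceptingIterator(iter([1, 2, 3]), on_complete=seen.extend):
        pass
    assert seen == [1, 2, 3]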
tokenator/schemas.py
CHANGED
@@ -1,26 +1,22 @@
 """SQLAlchemy models for tokenator."""
 
-import uuid
 from datetime import datetime
-import os
 
-from sqlalchemy import create_engine, Column, Integer, String, DateTime,
+from sqlalchemy import create_engine, Column, Integer, String, DateTime, Index
 from sqlalchemy.orm import sessionmaker, scoped_session, declarative_base
 
 from .utils import get_default_db_path
 
 Base = declarative_base()
 
+
 def get_engine(db_path: str = None):
     """Create SQLAlchemy engine with the given database path."""
     if db_path is None:
-
-            import google.colab  # type: ignore
-            db_path = '/content/tokenator.db'
-        except ImportError:
-            db_path = get_default_db_path()
+        db_path = get_default_db_path()
     return create_engine(f"sqlite:///{db_path}", echo=False)
 
+
 def get_session(db_path: str = None):
     """Create a thread-safe session factory."""
     engine = get_engine(db_path)
@@ -28,39 +24,42 @@ def get_session(db_path: str = None):
     session_factory = sessionmaker(bind=engine)
     return scoped_session(session_factory)
 
+
 class TokenUsage(Base):
     """Model for tracking token usage."""
-
+
     __tablename__ = "token_usage"
-
+
     id = Column(Integer, primary_key=True)
     execution_id = Column(String, nullable=False)
     provider = Column(String, nullable=False)
     model = Column(String, nullable=False)
     created_at = Column(DateTime, nullable=False, default=datetime.now)
-    updated_at = Column(
+    updated_at = Column(
+        DateTime, nullable=False, default=datetime.now, onupdate=datetime.now
+    )
     prompt_tokens = Column(Integer, nullable=False)
     completion_tokens = Column(Integer, nullable=False)
     total_tokens = Column(Integer, nullable=False)
-
+
     # Create indexes
     __table_args__ = (
-        Index(
-        Index(
-        Index(
-        Index(
+        Index("idx_created_at", "created_at"),
+        Index("idx_execution_id", "execution_id"),
+        Index("idx_provider", "provider"),
+        Index("idx_model", "model"),
     )
-
+
     def to_dict(self):
         """Convert model instance to dictionary."""
         return {
-
-
-
-
-
-
-
-
-
-        }
+            "id": self.id,
+            "execution_id": self.execution_id,
+            "provider": self.provider,
+            "model": self.model,
+            "created_at": self.created_at,
+            "updated_at": self.updated_at,
+            "prompt_tokens": self.prompt_tokens,
+            "completion_tokens": self.completion_tokens,
+            "total_tokens": self.total_tokens,
+        }
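With the hard-coded Colab branch gone, `get_engine` always falls back to `get_default_db_path()`, and the new indexes cover the columns usage queries typically filter on. Reading logged rows back might look like the sketch below, assuming the database has already been initialized by tokenator's migrations; the provider filter and limit are placeholders:

    from tokenator.schemas import TokenUsage, get_session

    session = get_session()  # scoped_session; uses the default SQLite path unless db_path is given
    try:
        rows = (
            session.query(TokenUsage)
            .filter(TokenUsage.provider == "openai")
            .order_by(TokenUsage.created_at.desc())
            .limit(10)
            .all()
        )
        for row in rows:
            print(row.to_dict())
    finally:
        session.remove()  # release the scoped session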