tokenator 0.1.8__py3-none-any.whl → 0.1.10__py3-none-any.whl

@@ -0,0 +1,163 @@
+ """OpenAI client wrapper with token usage tracking."""
+
+ from typing import Any, Optional, Union, overload, Iterator, AsyncIterator
+ import logging
+
+ from openai import AsyncOpenAI, OpenAI
+ from openai.types.chat import ChatCompletion, ChatCompletionChunk
+
+ from ..models import Usage, TokenUsageStats
+ from ..base_wrapper import BaseWrapper, ResponseType
+ from .stream_interceptors import OpenAIAsyncStreamInterceptor, OpenAISyncStreamInterceptor
+
+ logger = logging.getLogger(__name__)
+
+
+ class BaseOpenAIWrapper(BaseWrapper):
+     provider = "openai"
+
+     def _process_response_usage(
+         self, response: ResponseType
+     ) -> Optional[TokenUsageStats]:
+         """Process and log usage statistics from a response."""
+         try:
+             if isinstance(response, ChatCompletion):
+                 if response.usage is None:
+                     return None
+                 usage = Usage(
+                     prompt_tokens=response.usage.prompt_tokens,
+                     completion_tokens=response.usage.completion_tokens,
+                     total_tokens=response.usage.total_tokens,
+                 )
+                 return TokenUsageStats(model=response.model, usage=usage)
+
+             elif isinstance(response, dict):
+                 usage_dict = response.get("usage")
+                 if not usage_dict:
+                     return None
+                 usage = Usage(
+                     prompt_tokens=usage_dict.get("prompt_tokens", 0),
+                     completion_tokens=usage_dict.get("completion_tokens", 0),
+                     total_tokens=usage_dict.get("total_tokens", 0),
+                 )
+                 return TokenUsageStats(
+                     model=response.get("model", "unknown"), usage=usage
+                 )
+         except Exception as e:
+             logger.warning("Failed to process usage stats: %s", str(e))
+             return None
+         return None
+
+     @property
+     def chat(self):
+         return self
+
+     @property
+     def completions(self):
+         return self
+
+
+ def _create_usage_callback(execution_id, log_usage_fn):
+     """Creates a callback function for processing usage statistics from stream chunks."""
+
+     def usage_callback(chunks):
+         if not chunks:
+             return
+         # Build usage_data from the first chunk's model
+         usage_data = TokenUsageStats(
+             model=chunks[0].model,
+             usage=Usage(),
+         )
+         # Sum up usage from all chunks
+         has_usage = False
+         for ch in chunks:
+             if ch.usage:
+                 has_usage = True
+                 usage_data.usage.prompt_tokens += ch.usage.prompt_tokens
+                 usage_data.usage.completion_tokens += ch.usage.completion_tokens
+                 usage_data.usage.total_tokens += ch.usage.total_tokens
+
+         if has_usage:
+             log_usage_fn(usage_data, execution_id=execution_id)
+
+     return usage_callback
+
+
+ class OpenAIWrapper(BaseOpenAIWrapper):
+     def create(
+         self, *args: Any, execution_id: Optional[str] = None, **kwargs: Any
+     ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
+         """Create a chat completion and log token usage."""
+         logger.debug("Creating chat completion with args: %s, kwargs: %s", args, kwargs)
+
+         if kwargs.get("stream", False):
+             base_stream = self.client.chat.completions.create(*args, **kwargs)
+             return OpenAISyncStreamInterceptor(
+                 base_stream=base_stream,
+                 usage_callback=_create_usage_callback(execution_id, self._log_usage),
+             )
+
+         response = self.client.chat.completions.create(*args, **kwargs)
+         usage_data = self._process_response_usage(response)
+         if usage_data:
+             self._log_usage(usage_data, execution_id=execution_id)
+
+         return response
+
+
+ class AsyncOpenAIWrapper(BaseOpenAIWrapper):
+     async def create(
+         self, *args: Any, execution_id: Optional[str] = None, **kwargs: Any
+     ) -> Union[ChatCompletion, AsyncIterator[ChatCompletionChunk]]:
+         """
+         Create a chat completion and log token usage.
+         """
+         logger.debug("Creating chat completion with args: %s, kwargs: %s", args, kwargs)
+
+         # If user wants a stream, return an interceptor
+         if kwargs.get("stream", False):
+             base_stream = await self.client.chat.completions.create(*args, **kwargs)
+             return OpenAIAsyncStreamInterceptor(
+                 base_stream=base_stream,
+                 usage_callback=_create_usage_callback(execution_id, self._log_usage),
+             )
+
+         # Non-streaming path remains unchanged
+         response = await self.client.chat.completions.create(*args, **kwargs)
+         usage_data = self._process_response_usage(response)
+         if usage_data:
+             self._log_usage(usage_data, execution_id=execution_id)
+         return response
+
+
+ @overload
+ def tokenator_openai(
+     client: OpenAI,
+     db_path: Optional[str] = None,
+ ) -> OpenAIWrapper: ...
+
+
+ @overload
+ def tokenator_openai(
+     client: AsyncOpenAI,
+     db_path: Optional[str] = None,
+ ) -> AsyncOpenAIWrapper: ...
+
+
+ def tokenator_openai(
+     client: Union[OpenAI, AsyncOpenAI],
+     db_path: Optional[str] = None,
+ ) -> Union[OpenAIWrapper, AsyncOpenAIWrapper]:
+     """Create a token-tracking wrapper for an OpenAI client.
+
+     Args:
+         client: OpenAI or AsyncOpenAI client instance
+         db_path: Optional path to SQLite database for token tracking
+     """
+     if isinstance(client, OpenAI):
+         return OpenAIWrapper(client=client, db_path=db_path)
+
+     if isinstance(client, AsyncOpenAI):
+         return AsyncOpenAIWrapper(client=client, db_path=db_path)
+
+     raise ValueError("Client must be an instance of OpenAI or AsyncOpenAI")
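
For orientation, a minimal usage sketch of the factory above. It assumes `tokenator_openai` is re-exported from the package root and that an `OPENAI_API_KEY` is set in the environment; the model name and prompt are placeholders:

    from openai import OpenAI
    from tokenator import tokenator_openai

    client = tokenator_openai(OpenAI())  # or: tokenator_openai(OpenAI(), db_path="usage.db")

    # Non-streaming: usage is read from response.usage and logged immediately.
    response = client.chat.completions.create(
        model="gpt-4o-mini",  # placeholder model name
        messages=[{"role": "user", "content": "Hello!"}],
    )

    # Streaming: the interceptor logs usage once the stream is fully consumed.
    # OpenAI only emits a final usage chunk when stream_options requests it.
    stream = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": "Hello!"}],
        stream=True,
        stream_options={"include_usage": True},
    )
    for chunk in stream:
        pass  # consume fully so the usage callback fires

Because the `chat` and `completions` properties both return `self`, the familiar `client.chat.completions.create(...)` call path resolves to the wrapper's `create` method above.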
@@ -0,0 +1,146 @@
+ import logging
+ from typing import AsyncIterator, Callable, List, Optional, TypeVar, Iterator
+
+ from openai import AsyncStream, Stream
+
+ logger = logging.getLogger(__name__)
+
+ _T = TypeVar("_T")  # or you might specifically do _T = ChatCompletionChunk
+
+
+ class OpenAIAsyncStreamInterceptor(AsyncStream[_T]):
+     """
+     A wrapper around openai.AsyncStream that delegates all functionality
+     to the 'base_stream' but intercepts each chunk to handle usage or
+     logging logic. This preserves .response and other methods.
+
+     You can store aggregated usage in a local list and process it when
+     the stream ends (StopAsyncIteration).
+     """
+
+     def __init__(
+         self,
+         base_stream: AsyncStream[_T],
+         usage_callback: Optional[Callable[[List[_T]], None]] = None,
+     ):
+         # We do NOT call super().__init__() because openai.AsyncStream
+         # expects constructor parameters we don't want to re-initialize.
+         # Instead, we just store the base_stream and delegate everything to it.
+         self._base_stream = base_stream
+         self._usage_callback = usage_callback
+         self._chunks: List[_T] = []
+
+     @property
+     def response(self):
+         """Expose the original stream's 'response' so user code can do stream.response, etc."""
+         return self._base_stream.response
+
+     def __aiter__(self) -> AsyncIterator[_T]:
+         """
+         Called when we do 'async for chunk in wrapped_stream:'
+         We simply return 'self'. Then __anext__ does the rest.
+         """
+         return self
+
+     async def __anext__(self) -> _T:
+         """
+         Intercept iteration. We pull the next chunk from the base_stream.
+         If it's the end, do any final usage logging, then raise StopAsyncIteration.
+         Otherwise, we can accumulate usage info or do whatever we need with the chunk.
+         """
+         try:
+             chunk = await self._base_stream.__anext__()
+         except StopAsyncIteration:
+             # Once the base stream is fully consumed, we can do final usage/logging.
+             if self._usage_callback and self._chunks:
+                 self._usage_callback(self._chunks)
+             raise
+
+         # Intercept each chunk
+         self._chunks.append(chunk)
+         return chunk
+
+     async def __aenter__(self) -> "OpenAIAsyncStreamInterceptor[_T]":
+         """Support 'async with ...:' usage."""
+         await self._base_stream.__aenter__()
+         return self
+
+     async def __aexit__(self, exc_type, exc_val, exc_tb):
+         """
+         Ensure we propagate __aexit__ to the base stream,
+         so connections are properly closed.
+         """
+         return await self._base_stream.__aexit__(exc_type, exc_val, exc_tb)
+
+     async def close(self) -> None:
+         """Delegate close to the base_stream."""
+         await self._base_stream.close()
+
+
+ class OpenAISyncStreamInterceptor(Stream[_T]):
+     """
+     A wrapper around openai.Stream that delegates all functionality
+     to the 'base_stream' but intercepts each chunk to handle usage or
+     logging logic. This preserves .response and other methods.
+
+     You can store aggregated usage in a local list and process it when
+     the stream ends (StopIteration).
+     """
+
+     def __init__(
+         self,
+         base_stream: Stream[_T],
+         usage_callback: Optional[Callable[[List[_T]], None]] = None,
+     ):
+         # We do NOT call super().__init__() because openai.Stream
+         # expects constructor parameters we don't want to re-initialize.
+         # Instead, we just store the base_stream and delegate everything to it.
+         self._base_stream = base_stream
+         self._usage_callback = usage_callback
+         self._chunks: List[_T] = []
+
+     @property
+     def response(self):
+         """Expose the original stream's 'response' so user code can do stream.response, etc."""
+         return self._base_stream.response
+
+     def __iter__(self) -> Iterator[_T]:
+         """
+         Called when we do 'for chunk in wrapped_stream:'
+         We simply return 'self'. Then __next__ does the rest.
+         """
+         return self
+
+     def __next__(self) -> _T:
+         """
+         Intercept iteration. We pull the next chunk from the base_stream.
+         If it's the end, do any final usage logging, then raise StopIteration.
+         Otherwise, we can accumulate usage info or do whatever we need with the chunk.
+         """
+         try:
+             chunk = self._base_stream.__next__()
+         except StopIteration:
+             # Once the base stream is fully consumed, we can do final usage/logging.
+             if self._usage_callback and self._chunks:
+                 self._usage_callback(self._chunks)
+             raise
+
+         # Intercept each chunk
+         self._chunks.append(chunk)
+         return chunk
+
+     def __enter__(self) -> "OpenAISyncStreamInterceptor[_T]":
+         """Support 'with ...:' usage."""
+         self._base_stream.__enter__()
+         return self
+
+     def __exit__(self, exc_type, exc_val, exc_tb):
+         """
+         Ensure we propagate __exit__ to the base stream,
+         so connections are properly closed.
+         """
+         return self._base_stream.__exit__(exc_type, exc_val, exc_tb)
+
+     def close(self) -> None:
+         """Delegate close to the base_stream."""
+         self._base_stream.close()
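
The two interceptor classes are deliberately thin proxies. The core pattern, stripped of the OpenAI types, is an iterator that passes items through, remembers them, and fires a callback on exhaustion. A standalone sketch, not part of the package:

    from typing import Callable, Iterator, List, Optional, TypeVar

    T = TypeVar("T")

    class InterceptingIterator:
        """Toy version of the sync interceptor: yield through, collect, then report."""

        def __init__(
            self,
            base: Iterator[T],
            on_done: Optional[Callable[[List[T]], None]] = None,
        ):
            self._base = base
            self._on_done = on_done
            self._seen: List[T] = []

        def __iter__(self) -> "InterceptingIterator":
            return self

        def __next__(self) -> T:
            try:
                item = next(self._base)
            except StopIteration:
                # Runs only when the consumer drains the stream completely.
                if self._on_done and self._seen:
                    self._on_done(self._seen)
                raise
            self._seen.append(item)
            return item

    wrapped = InterceptingIterator(iter([1, 2, 3]), on_done=lambda seen: print("saw", seen))
    assert list(wrapped) == [1, 2, 3]  # prints "saw [1, 2, 3]" as the list is drained

One consequence worth noting: if the caller abandons the stream early, breaking out of the loop without exhausting it, the callback never fires and that request's token usage goes unlogged.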
tokenator/schemas.py CHANGED
@@ -1,26 +1,22 @@
  """SQLAlchemy models for tokenator."""
 
- import uuid
  from datetime import datetime
- import os
 
- from sqlalchemy import create_engine, Column, Integer, String, DateTime, Float, Index
+ from sqlalchemy import create_engine, Column, Integer, String, DateTime, Index
  from sqlalchemy.orm import sessionmaker, scoped_session, declarative_base
 
  from .utils import get_default_db_path
 
  Base = declarative_base()
 
+
  def get_engine(db_path: str = None):
      """Create SQLAlchemy engine with the given database path."""
      if db_path is None:
-         try:
-             import google.colab  # type: ignore
-             db_path = '/content/tokenator.db'
-         except ImportError:
-             db_path = get_default_db_path()
+         db_path = get_default_db_path()
      return create_engine(f"sqlite:///{db_path}", echo=False)
 
+
  def get_session(db_path: str = None):
      """Create a thread-safe session factory."""
      engine = get_engine(db_path)
@@ -28,39 +24,42 @@ def get_session(db_path: str = None):
      session_factory = sessionmaker(bind=engine)
      return scoped_session(session_factory)
 
+
  class TokenUsage(Base):
      """Model for tracking token usage."""
-
+
      __tablename__ = "token_usage"
-
+
      id = Column(Integer, primary_key=True)
      execution_id = Column(String, nullable=False)
      provider = Column(String, nullable=False)
      model = Column(String, nullable=False)
      created_at = Column(DateTime, nullable=False, default=datetime.now)
-     updated_at = Column(DateTime, nullable=False, default=datetime.now, onupdate=datetime.now)
+     updated_at = Column(
+         DateTime, nullable=False, default=datetime.now, onupdate=datetime.now
+     )
      prompt_tokens = Column(Integer, nullable=False)
      completion_tokens = Column(Integer, nullable=False)
      total_tokens = Column(Integer, nullable=False)
-
+
      # Create indexes
      __table_args__ = (
-         Index('idx_created_at', 'created_at'),
-         Index('idx_execution_id', 'execution_id'),
-         Index('idx_provider', 'provider'),
-         Index('idx_model', 'model'),
+         Index("idx_created_at", "created_at"),
+         Index("idx_execution_id", "execution_id"),
+         Index("idx_provider", "provider"),
+         Index("idx_model", "model"),
      )
-
+
      def to_dict(self):
          """Convert model instance to dictionary."""
          return {
-             'id': self.id,
-             'execution_id': self.execution_id,
-             'provider': self.provider,
-             'model': self.model,
-             'created_at': self.created_at,
-             'updated_at': self.updated_at,
-             'prompt_tokens': self.prompt_tokens,
-             'completion_tokens': self.completion_tokens,
-             'total_tokens': self.total_tokens
-         }
+             "id": self.id,
+             "execution_id": self.execution_id,
+             "provider": self.provider,
+             "model": self.model,
+             "created_at": self.created_at,
+             "updated_at": self.updated_at,
+             "prompt_tokens": self.prompt_tokens,
+             "completion_tokens": self.completion_tokens,
+             "total_tokens": self.total_tokens,
+         }
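
Once usage has been recorded, the schema above can be queried directly. A sketch, assuming the `tokenator.schemas` import path shown in the diff header and a database that already contains rows:

    from tokenator.schemas import TokenUsage, get_session

    Session = get_session()  # scoped_session; proxies a thread-local Session
    for row in (
        Session.query(TokenUsage)
        .filter(TokenUsage.provider == "openai")
        .order_by(TokenUsage.created_at.desc())
        .limit(5)
    ):
        print(row.to_dict())
    Session.remove()  # release the thread-local session

The `idx_provider` and `idx_created_at` indexes declared in `__table_args__` are what keep exactly this kind of filter-and-sort query cheap.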