tokenator 0.1.9__py3-none-any.whl → 0.1.11__py3-none-any.whl

@@ -1,21 +1,24 @@
  """OpenAI client wrapper with token usage tracking."""

- from typing import Any, Dict, Optional, TypeVar, Union, overload, Iterator, AsyncIterator
+ from typing import Any, Optional, Union, overload, Iterator, AsyncIterator
  import logging

- from openai import AsyncOpenAI, AsyncStream, OpenAI, Stream
+ from openai import AsyncOpenAI, OpenAI
  from openai.types.chat import ChatCompletion, ChatCompletionChunk

  from ..models import Usage, TokenUsageStats
  from ..base_wrapper import BaseWrapper, ResponseType
- from .AsyncStreamInterceptor import AsyncStreamInterceptor
+ from .stream_interceptors import OpenAIAsyncStreamInterceptor, OpenAISyncStreamInterceptor

  logger = logging.getLogger(__name__)

+
  class BaseOpenAIWrapper(BaseWrapper):
      provider = "openai"

-     def _process_response_usage(self, response: ResponseType) -> Optional[TokenUsageStats]:
+     def _process_response_usage(
+         self, response: ResponseType
+     ) -> Optional[TokenUsageStats]:
          """Process and log usage statistics from a response."""
          try:
              if isinstance(response, ChatCompletion):
@@ -27,19 +30,18 @@ class BaseOpenAIWrapper(BaseWrapper):
                      total_tokens=response.usage.total_tokens,
                  )
                  return TokenUsageStats(model=response.model, usage=usage)
-
+
              elif isinstance(response, dict):
-                 usage_dict = response.get('usage')
+                 usage_dict = response.get("usage")
                  if not usage_dict:
                      return None
                  usage = Usage(
-                     prompt_tokens=usage_dict.get('prompt_tokens', 0),
-                     completion_tokens=usage_dict.get('completion_tokens', 0),
-                     total_tokens=usage_dict.get('total_tokens', 0)
+                     prompt_tokens=usage_dict.get("prompt_tokens", 0),
+                     completion_tokens=usage_dict.get("completion_tokens", 0),
+                     total_tokens=usage_dict.get("total_tokens", 0),
                  )
                  return TokenUsageStats(
-                     model=response.get('model', 'unknown'),
-                     usage=usage
+                     model=response.get("model", "unknown"), usage=usage
                  )
          except Exception as e:
              logger.warning("Failed to process usage stats: %s", str(e))
@@ -54,45 +56,58 @@ class BaseOpenAIWrapper(BaseWrapper):
      def completions(self):
          return self

+
+ def _create_usage_callback(execution_id, log_usage_fn):
+     """Creates a callback function for processing usage statistics from stream chunks."""
+
+     def usage_callback(chunks):
+         if not chunks:
+             return
+         # Build usage_data from the first chunk's model
+         usage_data = TokenUsageStats(
+             model=chunks[0].model,
+             usage=Usage(),
+         )
+         # Sum up usage from all chunks
+         has_usage = False
+         for ch in chunks:
+             if ch.usage:
+                 has_usage = True
+                 usage_data.usage.prompt_tokens += ch.usage.prompt_tokens
+                 usage_data.usage.completion_tokens += ch.usage.completion_tokens
+                 usage_data.usage.total_tokens += ch.usage.total_tokens
+
+         if has_usage:
+             log_usage_fn(usage_data, execution_id=execution_id)
+
+     return usage_callback
+
+
  class OpenAIWrapper(BaseOpenAIWrapper):
-     def create(self, *args: Any, execution_id: Optional[str] = None, **kwargs: Any) -> Union[ChatCompletion, Iterator[ChatCompletion]]:
+     def create(
+         self, *args: Any, execution_id: Optional[str] = None, **kwargs: Any
+     ) -> Union[ChatCompletion, Iterator[ChatCompletion]]:
          """Create a chat completion and log token usage."""
          logger.debug("Creating chat completion with args: %s, kwargs: %s", args, kwargs)
-
+
+         if kwargs.get("stream", False):
+             base_stream = self.client.chat.completions.create(*args, **kwargs)
+             return OpenAISyncStreamInterceptor(
+                 base_stream=base_stream,
+                 usage_callback=_create_usage_callback(execution_id, self._log_usage),
+             )
+
          response = self.client.chat.completions.create(*args, **kwargs)
-
-         if not kwargs.get('stream', False):
-             usage_data = self._process_response_usage(response)
-             if usage_data:
-                 self._log_usage(usage_data, execution_id=execution_id)
-             return response
-
-         return self._wrap_streaming_response(response, execution_id)
-
-     def _wrap_streaming_response(self, response_iter: Stream[ChatCompletionChunk], execution_id: Optional[str]) -> Iterator[ChatCompletionChunk]:
-         """Wrap streaming response to capture final usage stats"""
-         chunks_with_usage = []
-         for chunk in response_iter:
-             if isinstance(chunk, ChatCompletionChunk) and chunk.usage is not None:
-                 chunks_with_usage.append(chunk)
-             yield chunk
-
-         if len(chunks_with_usage) > 0:
-             usage_data: TokenUsageStats = TokenUsageStats(model=chunks_with_usage[0].model, usage=Usage())
-             for chunk in chunks_with_usage:
-                 usage_data.usage.prompt_tokens += chunk.usage.prompt_tokens
-                 usage_data.usage.completion_tokens += chunk.usage.completion_tokens
-                 usage_data.usage.total_tokens += chunk.usage.total_tokens
-
+         usage_data = self._process_response_usage(response)
+         if usage_data:
              self._log_usage(usage_data, execution_id=execution_id)
-
+
+         return response
+

  class AsyncOpenAIWrapper(BaseOpenAIWrapper):
      async def create(
-         self,
-         *args: Any,
-         execution_id: Optional[str] = None,
-         **kwargs: Any
+         self, *args: Any, execution_id: Optional[str] = None, **kwargs: Any
      ) -> Union[ChatCompletion, AsyncIterator[ChatCompletionChunk]]:
          """
          Create a chat completion and log token usage.
@@ -102,32 +117,9 @@ class AsyncOpenAIWrapper(BaseOpenAIWrapper):
          # If user wants a stream, return an interceptor
          if kwargs.get("stream", False):
              base_stream = await self.client.chat.completions.create(*args, **kwargs)
-
-             # Define a callback that will get called once the stream ends
-             def usage_callback(chunks):
-                 # Mimic your old logic to gather usage from chunk.usage
-                 # e.g. ChatCompletionChunk.usage
-                 # Then call self._log_usage(...)
-                 if not chunks:
-                     return
-                 # Build usage_data from the first chunk's model
-                 usage_data = TokenUsageStats(
-                     model=chunks[0].model,
-                     usage=Usage(),
-                 )
-                 # Sum up usage from all chunks
-                 for ch in chunks:
-                     if ch.usage:
-                         usage_data.usage.prompt_tokens += ch.usage.prompt_tokens
-                         usage_data.usage.completion_tokens += ch.usage.completion_tokens
-                         usage_data.usage.total_tokens += ch.usage.total_tokens
-
-                 self._log_usage(usage_data, execution_id=execution_id)
-
-             # Return the interceptor that wraps the real AsyncStream
-             return AsyncStreamInterceptor(
+             return OpenAIAsyncStreamInterceptor(
                  base_stream=base_stream,
-                 usage_callback=usage_callback,
+                 usage_callback=_create_usage_callback(execution_id, self._log_usage),
              )

          # Non-streaming path remains unchanged
@@ -136,32 +128,36 @@ class AsyncOpenAIWrapper(BaseOpenAIWrapper):
          if usage_data:
              self._log_usage(usage_data, execution_id=execution_id)
          return response
+
+
  @overload
  def tokenator_openai(
      client: OpenAI,
      db_path: Optional[str] = None,
  ) -> OpenAIWrapper: ...

+
  @overload
  def tokenator_openai(
      client: AsyncOpenAI,
      db_path: Optional[str] = None,
  ) -> AsyncOpenAIWrapper: ...

+
  def tokenator_openai(
      client: Union[OpenAI, AsyncOpenAI],
      db_path: Optional[str] = None,
  ) -> Union[OpenAIWrapper, AsyncOpenAIWrapper]:
      """Create a token-tracking wrapper for an OpenAI client.
-
+
      Args:
          client: OpenAI or AsyncOpenAI client instance
          db_path: Optional path to SQLite database for token tracking
      """
      if isinstance(client, OpenAI):
          return OpenAIWrapper(client=client, db_path=db_path)
-
+
      if isinstance(client, AsyncOpenAI):
          return AsyncOpenAIWrapper(client=client, db_path=db_path)
-
+
      raise ValueError("Client must be an instance of OpenAI or AsyncOpenAI")
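
For orientation, here is a minimal usage sketch of the new sync streaming path (not part of the diff). It assumes tokenator_openai is exported at the package root and that the wrapper forwards client.chat.completions.create, as the completions property above suggests; the model name and the stream_options flag are illustrative.

# Hedged sketch: wrap a sync client, stream, and let tokenator log usage.
from openai import OpenAI
from tokenator import tokenator_openai  # assumed top-level export

client = tokenator_openai(OpenAI())

stream = client.chat.completions.create(
    model="gpt-4o-mini",  # placeholder model name
    messages=[{"role": "user", "content": "hello"}],
    stream=True,
    stream_options={"include_usage": True},  # ask OpenAI for a final usage chunk
)
for chunk in stream:
    pass  # consume as usual; usage is logged to SQLite once the stream ends
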
tokenator/openai/stream_interceptors.py ADDED
@@ -0,0 +1,146 @@
+ import logging
+ from typing import AsyncIterator, Callable, List, Optional, TypeVar, Iterator
+
+ from openai import AsyncStream, Stream
+
+ logger = logging.getLogger(__name__)
+
+ _T = TypeVar("_T")  # or you might specifically do _T = ChatCompletionChunk
+
+
+ class OpenAIAsyncStreamInterceptor(AsyncStream[_T]):
+     """
+     A wrapper around openai.AsyncStream that delegates all functionality
+     to the 'base_stream' but intercepts each chunk to handle usage or
+     logging logic. This preserves .response and other methods.
+
+     You can store aggregated usage in a local list and process it when
+     the stream ends (StopAsyncIteration).
+     """
+
+     def __init__(
+         self,
+         base_stream: AsyncStream[_T],
+         usage_callback: Optional[Callable[[List[_T]], None]] = None,
+     ):
+         # We do NOT call super().__init__() because openai.AsyncStream
+         # expects constructor parameters we don't want to re-initialize.
+         # Instead, we just store the base_stream and delegate everything to it.
+         self._base_stream = base_stream
+         self._usage_callback = usage_callback
+         self._chunks: List[_T] = []
+
+     @property
+     def response(self):
+         """Expose the original stream's 'response' so user code can do stream.response, etc."""
+         return self._base_stream.response
+
+     def __aiter__(self) -> AsyncIterator[_T]:
+         """
+         Called when we do 'async for chunk in wrapped_stream:'.
+         We simply return 'self'. Then __anext__ does the rest.
+         """
+         return self
+
+     async def __anext__(self) -> _T:
+         """
+         Intercept iteration. We pull the next chunk from the base_stream.
+         If it's the end, do any final usage logging, then raise StopAsyncIteration.
+         Otherwise, we can accumulate usage info or do whatever we need with the chunk.
+         """
+         try:
+             chunk = await self._base_stream.__anext__()
+         except StopAsyncIteration:
+             # Once the base stream is fully consumed, we can do final usage/logging.
+             if self._usage_callback and self._chunks:
+                 self._usage_callback(self._chunks)
+             raise
+
+         # Intercept each chunk
+         self._chunks.append(chunk)
+         return chunk
+
+     async def __aenter__(self) -> "OpenAIAsyncStreamInterceptor[_T]":
+         """Support 'async with ...:' usage."""
+         await self._base_stream.__aenter__()
+         return self
+
+     async def __aexit__(self, exc_type, exc_val, exc_tb):
+         """
+         Ensure we propagate __aexit__ to the base stream,
+         so connections are properly closed.
+         """
+         return await self._base_stream.__aexit__(exc_type, exc_val, exc_tb)
+
+     async def close(self) -> None:
+         """Delegate close to the base_stream."""
+         await self._base_stream.close()
+
+
+ class OpenAISyncStreamInterceptor(Stream[_T]):
+     """
+     A wrapper around openai.Stream that delegates all functionality
+     to the 'base_stream' but intercepts each chunk to handle usage or
+     logging logic. This preserves .response and other methods.
+
+     You can store aggregated usage in a local list and process it when
+     the stream ends (StopIteration).
+     """
+
+     def __init__(
+         self,
+         base_stream: Stream[_T],
+         usage_callback: Optional[Callable[[List[_T]], None]] = None,
+     ):
+         # We do NOT call super().__init__() because openai.Stream
+         # expects constructor parameters we don't want to re-initialize.
+         # Instead, we just store the base_stream and delegate everything to it.
+         self._base_stream = base_stream
+         self._usage_callback = usage_callback
+         self._chunks: List[_T] = []
+
+     @property
+     def response(self):
+         """Expose the original stream's 'response' so user code can do stream.response, etc."""
+         return self._base_stream.response
+
+     def __iter__(self) -> Iterator[_T]:
+         """
+         Called when we do 'for chunk in wrapped_stream:'.
+         We simply return 'self'. Then __next__ does the rest.
+         """
+         return self
+
+     def __next__(self) -> _T:
+         """
+         Intercept iteration. We pull the next chunk from the base_stream.
+         If it's the end, do any final usage logging, then raise StopIteration.
+         Otherwise, we can accumulate usage info or do whatever we need with the chunk.
+         """
+         try:
+             chunk = self._base_stream.__next__()
+         except StopIteration:
+             # Once the base stream is fully consumed, we can do final usage/logging.
+             if self._usage_callback and self._chunks:
+                 self._usage_callback(self._chunks)
+             raise
+
+         # Intercept each chunk
+         self._chunks.append(chunk)
+         return chunk
+
+     def __enter__(self) -> "OpenAISyncStreamInterceptor[_T]":
+         """Support 'with ...:' usage."""
+         self._base_stream.__enter__()
+         return self
+
+     def __exit__(self, exc_type, exc_val, exc_tb):
+         """
+         Ensure we propagate __exit__ to the base stream,
+         so connections are properly closed.
+         """
+         return self._base_stream.__exit__(exc_type, exc_val, exc_tb)
+
+     def close(self) -> None:
+         """Delegate close to the base_stream."""
+         self._base_stream.close()
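
Because the interceptors are plain delegating iterators, they can also be exercised standalone. A hedged sketch under that reading (openai_client, the model name, and print_usage are hypothetical; only the final chunk normally carries .usage):

# Sketch: wrap an existing sync stream with a custom usage callback.
def print_usage(chunks):
    # Only chunks that carry .usage (typically the last one) contribute.
    total = sum(ch.usage.total_tokens for ch in chunks if ch.usage)
    print(f"total tokens: {total}")

raw_stream = openai_client.chat.completions.create(
    model="gpt-4o-mini", messages=[{"role": "user", "content": "hi"}], stream=True
)
stream = OpenAISyncStreamInterceptor(base_stream=raw_stream, usage_callback=print_usage)
for chunk in stream:  # __next__ buffers each chunk into _chunks
    pass              # print_usage fires when the base stream raises StopIteration
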
tokenator/schemas.py CHANGED
@@ -1,26 +1,22 @@
  """SQLAlchemy models for tokenator."""

- import uuid
  from datetime import datetime
- import os

- from sqlalchemy import create_engine, Column, Integer, String, DateTime, Float, Index
+ from sqlalchemy import create_engine, Column, Integer, String, DateTime, Index
  from sqlalchemy.orm import sessionmaker, scoped_session, declarative_base

  from .utils import get_default_db_path

  Base = declarative_base()

+
  def get_engine(db_path: str = None):
      """Create SQLAlchemy engine with the given database path."""
      if db_path is None:
-         try:
-             import google.colab  # type: ignore
-             db_path = '/content/tokenator.db'
-         except ImportError:
-             db_path = get_default_db_path()
+         db_path = get_default_db_path()
      return create_engine(f"sqlite:///{db_path}", echo=False)

+
  def get_session(db_path: str = None):
      """Create a thread-safe session factory."""
      engine = get_engine(db_path)
@@ -28,39 +24,42 @@ def get_session(db_path: str = None):
      session_factory = sessionmaker(bind=engine)
      return scoped_session(session_factory)

+
  class TokenUsage(Base):
      """Model for tracking token usage."""
-
+
      __tablename__ = "token_usage"
-
+
      id = Column(Integer, primary_key=True)
      execution_id = Column(String, nullable=False)
      provider = Column(String, nullable=False)
      model = Column(String, nullable=False)
      created_at = Column(DateTime, nullable=False, default=datetime.now)
-     updated_at = Column(DateTime, nullable=False, default=datetime.now, onupdate=datetime.now)
+     updated_at = Column(
+         DateTime, nullable=False, default=datetime.now, onupdate=datetime.now
+     )
      prompt_tokens = Column(Integer, nullable=False)
      completion_tokens = Column(Integer, nullable=False)
      total_tokens = Column(Integer, nullable=False)
-
+
      # Create indexes
      __table_args__ = (
-         Index('idx_created_at', 'created_at'),
-         Index('idx_execution_id', 'execution_id'),
-         Index('idx_provider', 'provider'),
-         Index('idx_model', 'model'),
+         Index("idx_created_at", "created_at"),
+         Index("idx_execution_id", "execution_id"),
+         Index("idx_provider", "provider"),
+         Index("idx_model", "model"),
      )
-
+
      def to_dict(self):
          """Convert model instance to dictionary."""
          return {
-             'id': self.id,
-             'execution_id': self.execution_id,
-             'provider': self.provider,
-             'model': self.model,
-             'created_at': self.created_at,
-             'updated_at': self.updated_at,
-             'prompt_tokens': self.prompt_tokens,
-             'completion_tokens': self.completion_tokens,
-             'total_tokens': self.total_tokens
-         }
+             "id": self.id,
+             "execution_id": self.execution_id,
+             "provider": self.provider,
+             "model": self.model,
+             "created_at": self.created_at,
+             "updated_at": self.updated_at,
+             "prompt_tokens": self.prompt_tokens,
+             "completion_tokens": self.completion_tokens,
+             "total_tokens": self.total_tokens,
+         }
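
Reading logged rows back goes through the session factory defined above. A minimal sketch, assuming the default database path and that the table already exists:

# Sketch: query recent usage rows via the scoped session.
from tokenator.schemas import TokenUsage, get_session

Session = get_session()  # scoped_session bound to the SQLite database
try:
    rows = (
        Session.query(TokenUsage)
        .order_by(TokenUsage.created_at.desc())
        .limit(5)
        .all()
    )
    for row in rows:
        print(row.to_dict())
finally:
    Session.remove()  # release the thread-local session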