tokenator 0.1.13__py3-none-any.whl → 0.1.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tokenator/__init__.py CHANGED
@@ -5,14 +5,9 @@ from .openai.client_openai import tokenator_openai
  from .anthropic.client_anthropic import tokenator_anthropic
  from . import usage
  from .utils import get_default_db_path
- from .migrations import check_and_run_migrations
+ from .usage import TokenUsageService

- __version__ = "0.1.0"
+ usage = TokenUsageService()  # noqa: F811
  __all__ = ["tokenator_openai", "tokenator_anthropic", "usage", "get_default_db_path"]

  logger = logging.getLogger(__name__)
-
- try:
-     check_and_run_migrations()
- except Exception as e:
-     logger.warning(f"Failed to run migrations, but continuing anyway: {e}")
tokenator/anthropic/client_anthropic.py CHANGED
@@ -6,9 +6,13 @@ import logging
  from anthropic import Anthropic, AsyncAnthropic
  from anthropic.types import Message, RawMessageStartEvent, RawMessageDeltaEvent

- from ..models import Usage, TokenUsageStats
+ from ..models import PromptTokenDetails, TokenMetrics, TokenUsageStats
  from ..base_wrapper import BaseWrapper, ResponseType
- from .stream_interceptors import AnthropicAsyncStreamInterceptor, AnthropicSyncStreamInterceptor
+ from .stream_interceptors import (
+     AnthropicAsyncStreamInterceptor,
+     AnthropicSyncStreamInterceptor,
+ )
+ from ..state import is_tokenator_enabled

  logger = logging.getLogger(__name__)

@@ -24,28 +28,46 @@ class BaseAnthropicWrapper(BaseWrapper):
              if isinstance(response, Message):
                  if not hasattr(response, "usage"):
                      return None
-                 usage = Usage(
-                     prompt_tokens=response.usage.input_tokens,
+                 usage = TokenMetrics(
+                     prompt_tokens=response.usage.input_tokens
+                     + (getattr(response.usage, "cache_creation_input_tokens", 0) or 0),
                      completion_tokens=response.usage.output_tokens,
                      total_tokens=response.usage.input_tokens
                      + response.usage.output_tokens,
+                     prompt_tokens_details=PromptTokenDetails(
+                         cached_input_tokens=getattr(
+                             response.usage, "cache_read_input_tokens", None
+                         ),
+                         cached_creation_tokens=getattr(
+                             response.usage, "cache_creation_input_tokens", None
+                         ),
+                     ),
                  )
                  return TokenUsageStats(model=response.model, usage=usage)
              elif isinstance(response, dict):
                  usage_dict = response.get("usage")
                  if not usage_dict:
                      return None
-                 usage = Usage(
-                     prompt_tokens=usage_dict.get("input_tokens", 0),
+                 usage = TokenMetrics(
+                     prompt_tokens=usage_dict.get("input_tokens", 0)
+                     + (getattr(usage_dict, "cache_creation_input_tokens", 0) or 0),
                      completion_tokens=usage_dict.get("output_tokens", 0),
                      total_tokens=usage_dict.get("input_tokens", 0)
                      + usage_dict.get("output_tokens", 0),
+                     prompt_tokens_details=PromptTokenDetails(
+                         cached_input_tokens=getattr(
+                             usage_dict, "cache_read_input_tokens", None
+                         ),
+                         cached_creation_tokens=getattr(
+                             usage_dict, "cache_creation_input_tokens", None
+                         ),
+                     ),
                  )
                  return TokenUsageStats(
                      model=response.get("model", "unknown"), usage=usage
                  )
          except Exception as e:
-             logger.warning("Failed to process usage stats: %s", str(e))
+             logger.warning("Failed to process usage stats: %s", str(e), exc_info=True)
              return None
          return None

@@ -56,15 +78,23 @@ class BaseAnthropicWrapper(BaseWrapper):

  def _create_usage_callback(execution_id, log_usage_fn):
      """Creates a callback function for processing usage statistics from stream chunks."""
+
      def usage_callback(chunks):
          if not chunks:
              return
-
+
+         # Skip if tokenator is disabled
+         if not is_tokenator_enabled:
+             logger.debug("Tokenator is disabled - skipping stream usage logging")
+             return
+
          usage_data = TokenUsageStats(
-             model=chunks[0].message.model if isinstance(chunks[0], RawMessageStartEvent) else "",
-             usage=Usage(),
+             model=chunks[0].message.model
+             if isinstance(chunks[0], RawMessageStartEvent)
+             else "",
+             usage=TokenMetrics(),
          )
-
+
          for chunk in chunks:
              if isinstance(chunk, RawMessageStartEvent):
                  usage_data.model = chunk.message.model
@@ -72,8 +102,10 @@ def _create_usage_callback(execution_id, log_usage_fn):
                  usage_data.usage.completion_tokens += chunk.message.usage.output_tokens
              elif isinstance(chunk, RawMessageDeltaEvent):
                  usage_data.usage.completion_tokens += chunk.usage.output_tokens
-
-         usage_data.usage.total_tokens = usage_data.usage.prompt_tokens + usage_data.usage.completion_tokens
+
+         usage_data.usage.total_tokens = (
+             usage_data.usage.prompt_tokens + usage_data.usage.completion_tokens
+         )
          log_usage_fn(usage_data, execution_id=execution_id)

      return usage_callback
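To illustrate what the callback above computes, here is a plain-Python analogue of the aggregation, using hypothetical event dicts in place of the real anthropic event types:

```python
# Hypothetical event dicts standing in for RawMessageStartEvent /
# RawMessageDeltaEvent; only the arithmetic mirrors the callback above.
events = [
    {"type": "message_start", "model": "claude-3-5-sonnet-20241022",
     "input_tokens": 12, "output_tokens": 1},
    {"type": "message_delta", "output_tokens": 41},
]

# Prompt tokens come from the message_start event; completion tokens
# accumulate across the start event and every delta event.
prompt_tokens = sum(e.get("input_tokens", 0) for e in events)
completion_tokens = sum(e.get("output_tokens", 0) for e in events)
total_tokens = prompt_tokens + completion_tokens  # 12 + 42 = 54
```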
@@ -84,7 +116,9 @@ class AnthropicWrapper(BaseAnthropicWrapper):
          self, *args: Any, execution_id: Optional[str] = None, **kwargs: Any
      ) -> Union[Message, Iterator[Message]]:
          """Create a message completion and log token usage."""
-         logger.debug("Creating message completion with args: %s, kwargs: %s", args, kwargs)
+         logger.debug(
+             "Creating message completion with args: %s, kwargs: %s", args, kwargs
+         )

          if kwargs.get("stream", False):
              base_stream = self.client.messages.create(*args, **kwargs)
@@ -105,7 +139,9 @@ class AsyncAnthropicWrapper(BaseAnthropicWrapper):
          self, *args: Any, execution_id: Optional[str] = None, **kwargs: Any
      ) -> Union[Message, AsyncIterator[Message]]:
          """Create a message completion and log token usage."""
-         logger.debug("Creating message completion with args: %s, kwargs: %s", args, kwargs)
+         logger.debug(
+             "Creating message completion with args: %s, kwargs: %s", args, kwargs
+         )

          if kwargs.get("stream", False):
              base_stream = await self.client.messages.create(*args, **kwargs)
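The accounting change worth noting in this file: prompt_tokens now folds cache-creation tokens into the prompt total, while cache reads and creations are also broken out in prompt_tokens_details. A small arithmetic sketch with made-up numbers:

```python
# Made-up usage numbers illustrating the new Anthropic accounting above.
input_tokens = 10
cache_creation_input_tokens = 2048
cache_read_input_tokens = 0

# prompt_tokens now includes tokens spent creating a prompt cache entry.
prompt_tokens = input_tokens + (cache_creation_input_tokens or 0)  # 2058
# prompt_tokens_details would separately carry:
#   cached_input_tokens=0, cached_creation_tokens=2048
```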
tokenator/base_wrapper.py CHANGED
@@ -7,6 +7,9 @@ import uuid

  from .models import TokenUsageStats
  from .schemas import get_session, TokenUsage
+ from . import state
+
+ from .migrations import check_and_run_migrations

  logger = logging.getLogger(__name__)

@@ -16,17 +19,30 @@ ResponseType = TypeVar("ResponseType")
  class BaseWrapper:
      def __init__(self, client: Any, db_path: Optional[str] = None):
          """Initialize the base wrapper."""
-         self.client = client
+         state.is_tokenator_enabled = True
+         try:
+             self.client = client

-         if db_path:
-             Path(db_path).parent.mkdir(parents=True, exist_ok=True)
-             logger.info("Created database directory at: %s", Path(db_path).parent)
+             if db_path:
+                 Path(db_path).parent.mkdir(parents=True, exist_ok=True)
+                 logger.info("Created database directory at: %s", Path(db_path).parent)
+                 state.db_path = db_path  # Store db_path in state

-         self.Session = get_session(db_path)
+             else:
+                 state.db_path = None  # Use default path

-         logger.debug(
-             "Initializing %s with db_path: %s", self.__class__.__name__, db_path
-         )
+             self.Session = get_session()
+
+             logger.debug(
+                 "Initializing %s with db_path: %s", self.__class__.__name__, db_path
+             )
+
+             check_and_run_migrations(db_path)
+         except Exception as e:
+             state.is_tokenator_enabled = False
+             logger.warning(
+                 f"Tokenator initialization failed. Usage tracking will be disabled. Error: {e}"
+             )

      def _log_usage_impl(
          self, token_usage_stats: TokenUsageStats, session, execution_id: str
@@ -42,9 +58,33 @@ class BaseWrapper:
              execution_id=execution_id,
              provider=self.provider,
              model=token_usage_stats.model,
+             total_cost=0,  # This needs to be calculated based on your rates
              prompt_tokens=token_usage_stats.usage.prompt_tokens,
              completion_tokens=token_usage_stats.usage.completion_tokens,
              total_tokens=token_usage_stats.usage.total_tokens,
+             # Prompt details
+             prompt_cached_input_tokens=token_usage_stats.usage.prompt_tokens_details.cached_input_tokens
+             if token_usage_stats.usage.prompt_tokens_details
+             else None,
+             prompt_cached_creation_tokens=token_usage_stats.usage.prompt_tokens_details.cached_creation_tokens
+             if token_usage_stats.usage.prompt_tokens_details
+             else None,
+             prompt_audio_tokens=token_usage_stats.usage.prompt_tokens_details.audio_tokens
+             if token_usage_stats.usage.prompt_tokens_details
+             else None,
+             # Completion details
+             completion_audio_tokens=token_usage_stats.usage.completion_tokens_details.audio_tokens
+             if token_usage_stats.usage.completion_tokens_details
+             else None,
+             completion_reasoning_tokens=token_usage_stats.usage.completion_tokens_details.reasoning_tokens
+             if token_usage_stats.usage.completion_tokens_details
+             else None,
+             completion_accepted_prediction_tokens=token_usage_stats.usage.completion_tokens_details.accepted_prediction_tokens
+             if token_usage_stats.usage.completion_tokens_details
+             else None,
+             completion_rejected_prediction_tokens=token_usage_stats.usage.completion_tokens_details.rejected_prediction_tokens
+             if token_usage_stats.usage.completion_tokens_details
+             else None,
          )
          session.add(token_usage)
          logger.debug(
@@ -59,14 +99,20 @@ class BaseWrapper:
          self, token_usage_stats: TokenUsageStats, execution_id: Optional[str] = None
      ):
          """Log token usage to database."""
+         if not state.is_tokenator_enabled:
+             logger.debug("Tokenator is disabled - skipping usage logging")
+             return
+
          if not execution_id:
              execution_id = str(uuid.uuid4())

+         logger.debug("Starting token usage logging for execution_id: %s", execution_id)
          session = self.Session()
          try:
              try:
                  self._log_usage_impl(token_usage_stats, session, execution_id)
                  session.commit()
+                 logger.debug("Successfully committed token usage for execution_id: %s", execution_id)
              except Exception as e:
                  logger.error("Failed to log token usage: %s", str(e))
                  session.rollback()
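The pattern introduced in this file — flip a global flag off when initialization fails, then have every logging path check it — can be reduced to a few lines. A minimal sketch of the same fail-open idea, independent of tokenator's actual classes:

```python
# Minimal sketch of the fail-open pattern used by BaseWrapper above.
import logging

logger = logging.getLogger(__name__)
_enabled = True  # module-level flag, like state.is_tokenator_enabled


def init_tracking() -> None:
    global _enabled
    try:
        ...  # set up the DB session, run migrations
    except Exception as e:
        _enabled = False  # degrade gracefully: the wrapped client still works
        logger.warning("tracking disabled: %s", e)


def log_usage(record: dict) -> None:
    if not _enabled:
        return  # every logging path checks the flag first
    ...  # write the record
```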
tokenator/migrations/versions/f028b8155fed_adding_detailed_input_and_output_token_schema.py ADDED
@@ -0,0 +1,64 @@
+ """Adding detailed input and output token schema
+
+ Revision ID: f028b8155fed
+ Revises: f6f1f2437513
+ Create Date: 2025-01-19 15:41:12.715623
+
+ """
+
+ from typing import Sequence, Union
+
+ from alembic import op
+ import sqlalchemy as sa
+
+
+ # revision identifiers, used by Alembic.
+ revision: str = "f028b8155fed"
+ down_revision: Union[str, None] = "f6f1f2437513"
+ branch_labels: Union[str, Sequence[str], None] = None
+ depends_on: Union[str, Sequence[str], None] = None
+
+
+ def upgrade() -> None:
+     # ### commands auto generated by Alembic - please adjust! ###
+     op.add_column("token_usage", sa.Column("total_cost", sa.Integer(), nullable=False))
+     op.add_column(
+         "token_usage",
+         sa.Column("prompt_cached_input_tokens", sa.Integer(), nullable=True),
+     )
+     op.add_column(
+         "token_usage",
+         sa.Column("prompt_cached_creation_tokens", sa.Integer(), nullable=True),
+     )
+     op.add_column(
+         "token_usage", sa.Column("prompt_audio_tokens", sa.Integer(), nullable=True)
+     )
+     op.add_column(
+         "token_usage", sa.Column("completion_audio_tokens", sa.Integer(), nullable=True)
+     )
+     op.add_column(
+         "token_usage",
+         sa.Column("completion_reasoning_tokens", sa.Integer(), nullable=True),
+     )
+     op.add_column(
+         "token_usage",
+         sa.Column("completion_accepted_prediction_tokens", sa.Integer(), nullable=True),
+     )
+     op.add_column(
+         "token_usage",
+         sa.Column("completion_rejected_prediction_tokens", sa.Integer(), nullable=True),
+     )
+     # ### end Alembic commands ###
+
+
+ def downgrade() -> None:
+     # ### commands auto generated by Alembic - please adjust! ###
+     op.drop_column("token_usage", "completion_rejected_prediction_tokens")
+     op.drop_column("token_usage", "completion_accepted_prediction_tokens")
+     op.drop_column("token_usage", "completion_reasoning_tokens")
+     op.drop_column("token_usage", "completion_audio_tokens")
+     op.drop_column("token_usage", "prompt_audio_tokens")
+     op.drop_column("token_usage", "prompt_cached_creation_tokens")
+     op.drop_column("token_usage", "prompt_cached_input_tokens")
+     op.drop_column("token_usage", "total_cost")
+     # ### end Alembic commands ###
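To confirm the migration applied to a local database, the new columns are visible through SQLite's table_info pragma. A hedged sketch (the path is a placeholder; tokenator resolves its default via get_default_db_path()):

```python
import sqlite3

# Placeholder path; substitute your actual tokenator database file.
conn = sqlite3.connect("/path/to/usage.db")
cols = [row[1] for row in conn.execute("PRAGMA table_info(token_usage)")]
assert "total_cost" in cols
assert "completion_reasoning_tokens" in cols
conn.close()
```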
tokenator/models.py CHANGED
@@ -1,17 +1,44 @@
  from pydantic import BaseModel, Field
- from typing import List
+ from typing import List, Optional


  class TokenRate(BaseModel):
      prompt: float = Field(..., description="Cost per prompt token")
      completion: float = Field(..., description="Cost per completion token")
+     prompt_audio: Optional[float] = Field(
+         None, description="Cost per audio prompt token"
+     )
+     completion_audio: Optional[float] = Field(
+         None, description="Cost per audio completion token"
+     )
+     prompt_cached_input: Optional[float] = Field(
+         None, description="Cost per cached prompt input token"
+     )
+     prompt_cached_creation: Optional[float] = Field(
+         None, description="Cost per cached prompt creation token"
+     )
+
+
+ class PromptTokenDetails(BaseModel):
+     cached_input_tokens: Optional[int] = None
+     cached_creation_tokens: Optional[int] = None
+     audio_tokens: Optional[int] = None
+
+
+ class CompletionTokenDetails(BaseModel):
+     reasoning_tokens: Optional[int] = None
+     audio_tokens: Optional[int] = None
+     accepted_prediction_tokens: Optional[int] = None
+     rejected_prediction_tokens: Optional[int] = None


  class TokenMetrics(BaseModel):
-     total_cost: float = Field(..., description="Total cost in USD")
-     total_tokens: int = Field(..., description="Total tokens used")
-     prompt_tokens: int = Field(..., description="Number of prompt tokens")
-     completion_tokens: int = Field(..., description="Number of completion tokens")
+     total_cost: float = Field(default=0, description="Total cost in USD")
+     total_tokens: int = Field(default=0, description="Total tokens used")
+     prompt_tokens: int = Field(default=0, description="Number of prompt tokens")
+     completion_tokens: int = Field(default=0, description="Number of completion tokens")
+     prompt_tokens_details: Optional[PromptTokenDetails] = None
+     completion_tokens_details: Optional[CompletionTokenDetails] = None


  class ModelUsage(TokenMetrics):
@@ -31,12 +58,6 @@ class TokenUsageReport(TokenMetrics):
      )


- class Usage(BaseModel):
-     prompt_tokens: int = 0
-     completion_tokens: int = 0
-     total_tokens: int = 0
-
-
  class TokenUsageStats(BaseModel):
      model: str
-     usage: Usage
+     usage: TokenMetrics
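With Usage gone and every TokenMetrics field now defaulted, stream callbacks can start from an empty TokenMetrics() and accumulate into it, which is exactly what the interceptor callbacks above rely on. A short sketch using the classes as defined in this diff:

```python
from tokenator.models import TokenMetrics, PromptTokenDetails

# Starting empty is valid now that fields default to 0 / None;
# the old Usage-to-TokenMetrics merge made the core fields optional.
m = TokenMetrics()
m.prompt_tokens += 2058
m.completion_tokens += 42
m.total_tokens = m.prompt_tokens + m.completion_tokens

m.prompt_tokens_details = PromptTokenDetails(cached_creation_tokens=2048)
print(m.total_cost)  # 0 — defaulted, no longer a required field
```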
tokenator/openai/client_openai.py CHANGED
@@ -6,9 +6,18 @@ import logging
  from openai import AsyncOpenAI, OpenAI
  from openai.types.chat import ChatCompletion, ChatCompletionChunk

- from ..models import Usage, TokenUsageStats
+ from ..models import (
+     TokenMetrics,
+     TokenUsageStats,
+     PromptTokenDetails,
+     CompletionTokenDetails,
+ )
  from ..base_wrapper import BaseWrapper, ResponseType
- from .stream_interceptors import OpenAIAsyncStreamInterceptor, OpenAISyncStreamInterceptor
+ from .stream_interceptors import (
+     OpenAIAsyncStreamInterceptor,
+     OpenAISyncStreamInterceptor,
+ )
+ from ..state import is_tokenator_enabled

  logger = logging.getLogger(__name__)

@@ -26,18 +35,49 @@ class BaseOpenAIWrapper(BaseWrapper):
          if isinstance(response, ChatCompletion):
              if response.usage is None:
                  return None
-             usage = Usage(
+             usage = TokenMetrics(
                  prompt_tokens=response.usage.prompt_tokens,
                  completion_tokens=response.usage.completion_tokens,
                  total_tokens=response.usage.total_tokens,
+                 prompt_tokens_details=PromptTokenDetails(
+                     cached_input_tokens=getattr(
+                         response.usage.prompt_tokens_details, "cached_tokens", None
+                     ),
+                     audio_tokens=getattr(
+                         response.usage.prompt_tokens_details, "audio_tokens", None
+                     ),
+                 ),
+                 completion_tokens_details=CompletionTokenDetails(
+                     reasoning_tokens=getattr(
+                         response.usage.completion_tokens_details,
+                         "reasoning_tokens",
+                         None,
+                     ),
+                     audio_tokens=getattr(
+                         response.usage.completion_tokens_details,
+                         "audio_tokens",
+                         None,
+                     ),
+                     accepted_prediction_tokens=getattr(
+                         response.usage.completion_tokens_details,
+                         "accepted_prediction_tokens",
+                         None,
+                     ),
+                     rejected_prediction_tokens=getattr(
+                         response.usage.completion_tokens_details,
+                         "rejected_prediction_tokens",
+                         None,
+                     ),
+                 ),
              )
+
              return TokenUsageStats(model=response.model, usage=usage)

          elif isinstance(response, dict):
              usage_dict = response.get("usage")
              if not usage_dict:
                  return None
-             usage = Usage(
+             usage = TokenMetrics(
                  prompt_tokens=usage_dict.get("prompt_tokens", 0),
                  completion_tokens=usage_dict.get("completion_tokens", 0),
                  total_tokens=usage_dict.get("total_tokens", 0),
@@ -58,6 +98,10 @@ class BaseOpenAIWrapper(BaseWrapper):
      def completions(self):
          return self

+     @property
+     def beta(self):
+         return self
+

  def _create_usage_callback(execution_id, log_usage_fn):
      """Creates a callback function for processing usage statistics from stream chunks."""
@@ -65,10 +109,18 @@ def _create_usage_callback(execution_id, log_usage_fn):
      def usage_callback(chunks):
          if not chunks:
              return
+
+         # Skip if tokenator is disabled
+         if not is_tokenator_enabled:
+             logger.debug("Tokenator is disabled - skipping stream usage logging")
+             return
+
+         logger.debug("Processing stream usage for execution_id: %s", execution_id)
+
          # Build usage_data from the first chunk's model
          usage_data = TokenUsageStats(
              model=chunks[0].model,
-             usage=Usage(),
+             usage=TokenMetrics(),
          )
          # Sum up usage from all chunks
          has_usage = False
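One caveat when exercising this streaming path: the OpenAI API only emits a usage chunk when asked for one. This is an assumption about the upstream API, not something this diff changes:

```python
# Assumption about the OpenAI API (not part of this diff): usage arrives in
# the final chunk only when stream_options requests it.
from openai import OpenAI
from tokenator import tokenator_openai

client = tokenator_openai(OpenAI())
stream = client.chat.completions.create(
    model="gpt-4o-mini",  # placeholder model
    messages=[{"role": "user", "content": "hi"}],
    stream=True,
    stream_options={"include_usage": True},  # ask OpenAI to send a usage chunk
)
for chunk in stream:
    pass  # the interceptor records usage once the stream is exhausted
```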
@@ -106,6 +158,26 @@ class OpenAIWrapper(BaseOpenAIWrapper):

          return response

+     def parse(
+         self, *args: Any, execution_id: Optional[str] = None, **kwargs: Any
+     ) -> Union[ChatCompletion, Iterator[ChatCompletion]]:
+         """Create a chat completion parse and log token usage."""
+         logger.debug("Creating chat completion with args: %s, kwargs: %s", args, kwargs)
+
+         if kwargs.get("stream", False):
+             base_stream = self.client.beta.chat.completions.parse(*args, **kwargs)
+             return OpenAISyncStreamInterceptor(
+                 base_stream=base_stream,
+                 usage_callback=_create_usage_callback(execution_id, self._log_usage),
+             )
+
+         response = self.client.beta.chat.completions.parse(*args, **kwargs)
+         usage_data = self._process_response_usage(response)
+         if usage_data:
+             self._log_usage(usage_data, execution_id=execution_id)
+
+         return response
+

  class AsyncOpenAIWrapper(BaseOpenAIWrapper):
      async def create(
@@ -131,6 +203,26 @@ class AsyncOpenAIWrapper(BaseOpenAIWrapper):
              self._log_usage(usage_data, execution_id=execution_id)
          return response

+     async def parse(
+         self, *args: Any, execution_id: Optional[str] = None, **kwargs: Any
+     ) -> Union[ChatCompletion, AsyncIterator[ChatCompletionChunk]]:
+         """Create a chat completion parse and log token usage."""
+         logger.debug("Creating chat completion with args: %s, kwargs: %s", args, kwargs)
+
+         if kwargs.get("stream", False):
+             base_stream = await self.client.beta.chat.completions.parse(*args, **kwargs)
+             return OpenAIAsyncStreamInterceptor(
+                 base_stream=base_stream,
+                 usage_callback=_create_usage_callback(execution_id, self._log_usage),
+             )
+
+         response = await self.client.beta.chat.completions.parse(*args, **kwargs)
+         usage_data = self._process_response_usage(response)
+         if usage_data:
+             self._log_usage(usage_data, execution_id=execution_id)
+
+         return response
+

  @overload
  def tokenator_openai(
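Combined with the new beta property, the parse() methods mean structured-output calls are now tracked like create() calls. A hedged usage sketch (the model name and schema are placeholders; the parse call is OpenAI's beta structured-outputs interface, which the wrapper routes through its own parse()):

```python
from pydantic import BaseModel
from openai import OpenAI
from tokenator import tokenator_openai


class Answer(BaseModel):
    value: int


client = tokenator_openai(OpenAI())
resp = client.beta.chat.completions.parse(  # routed via the wrapper's parse()
    model="gpt-4o-mini",                    # placeholder model
    messages=[{"role": "user", "content": "2 + 2 = ?"}],
    response_format=Answer,
)
# Token usage for this call is logged to the tokenator DB as usual.
```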
tokenator/schemas.py CHANGED
@@ -1,25 +1,27 @@
  """SQLAlchemy models for tokenator."""

  from datetime import datetime
+ from typing import Optional

  from sqlalchemy import create_engine, Column, Integer, String, DateTime, Index
  from sqlalchemy.orm import sessionmaker, scoped_session, declarative_base

  from .utils import get_default_db_path
+ from . import state  # Import state to access db_path

  Base = declarative_base()


- def get_engine(db_path: str = None):
+ def get_engine(db_path: Optional[str] = None):
      """Create SQLAlchemy engine with the given database path."""
      if db_path is None:
-         db_path = get_default_db_path()
+         db_path = state.db_path or get_default_db_path()  # Use state.db_path if set
      return create_engine(f"sqlite:///{db_path}", echo=False)


- def get_session(db_path: str = None):
+ def get_session():
      """Create a thread-safe session factory."""
-     engine = get_engine(db_path)
+     engine = get_engine()
      # Base.metadata.create_all(engine)
      session_factory = sessionmaker(bind=engine)
      return scoped_session(session_factory)
@@ -38,28 +40,28 @@ class TokenUsage(Base):
      updated_at = Column(
          DateTime, nullable=False, default=datetime.now, onupdate=datetime.now
      )
+
+     # Core metrics (mandatory)
+     total_cost = Column(Integer, nullable=False)
      prompt_tokens = Column(Integer, nullable=False)
      completion_tokens = Column(Integer, nullable=False)
      total_tokens = Column(Integer, nullable=False)

-     # Create indexes
+     # Prompt token details (optional)
+     prompt_cached_input_tokens = Column(Integer, nullable=True)
+     prompt_cached_creation_tokens = Column(Integer, nullable=True)
+     prompt_audio_tokens = Column(Integer, nullable=True)
+
+     # Completion token details (optional)
+     completion_audio_tokens = Column(Integer, nullable=True)
+     completion_reasoning_tokens = Column(Integer, nullable=True)
+     completion_accepted_prediction_tokens = Column(Integer, nullable=True)
+     completion_rejected_prediction_tokens = Column(Integer, nullable=True)
+
+     # Keep existing indexes
      __table_args__ = (
          Index("idx_created_at", "created_at"),
          Index("idx_execution_id", "execution_id"),
          Index("idx_provider", "provider"),
          Index("idx_model", "model"),
      )
-
-     def to_dict(self):
-         """Convert model instance to dictionary."""
-         return {
-             "id": self.id,
-             "execution_id": self.execution_id,
-             "provider": self.provider,
-             "model": self.model,
-             "created_at": self.created_at,
-             "updated_at": self.updated_at,
-             "prompt_tokens": self.prompt_tokens,
-             "completion_tokens": self.completion_tokens,
-             "total_tokens": self.total_tokens,
-         }
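Note that get_session() no longer accepts a path; the engine resolves it from module state instead. A short sketch of the resolution order, assuming tokenator's internals as shown in this diff:

```python
from tokenator import state
from tokenator.schemas import get_session

state.db_path = "/tmp/tokenator.db"  # normally set by BaseWrapper.__init__
Session = get_session()  # engine path: state.db_path, else get_default_db_path()
session = Session()      # thread-safe scoped session, as before
```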
tokenator/state.py ADDED
@@ -0,0 +1,12 @@
+ """Global state for tokenator."""
+
+ import logging
+ from typing import Optional
+
+ logger = logging.getLogger(__name__)
+
+ # Global flag to track if tokenator is properly initialized
+ is_tokenator_enabled = True
+
+ # Store the database path
+ db_path: Optional[str] = None
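Since the flag is plain module state, it can in principle be toggled at runtime to pause tracking. This is an inferred use (the wrappers shown above check the flag before every logging call), not a documented package feature:

```python
from tokenator import state

state.is_tokenator_enabled = False  # wrappers skip all usage logging
# ... make calls without recording usage ...
state.is_tokenator_enabled = True   # resume tracking
```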