tokenator 0.1.9__py3-none-any.whl → 0.1.11__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- tokenator/__init__.py +2 -2
- tokenator/anthropic/client_anthropic.py +154 -0
- tokenator/anthropic/stream_interceptors.py +146 -0
- tokenator/base_wrapper.py +26 -13
- tokenator/create_migrations.py +6 -5
- tokenator/migrations/env.py +5 -4
- tokenator/migrations/versions/f6f1f2437513_initial_migration.py +25 -23
- tokenator/migrations.py +9 -6
- tokenator/models.py +15 -4
- tokenator/openai/client_openai.py +66 -70
- tokenator/openai/stream_interceptors.py +146 -0
- tokenator/schemas.py +26 -27
- tokenator/usage.py +114 -47
- tokenator/utils.py +14 -9
- {tokenator-0.1.9.dist-info → tokenator-0.1.11.dist-info}/METADATA +72 -17
- tokenator-0.1.11.dist-info/RECORD +19 -0
- {tokenator-0.1.9.dist-info → tokenator-0.1.11.dist-info}/WHEEL +1 -1
- tokenator/client_anthropic.py +0 -148
- tokenator/openai/AsyncStreamInterceptor.py +0 -78
- tokenator-0.1.9.dist-info/RECORD +0 -18
- {tokenator-0.1.9.dist-info → tokenator-0.1.11.dist-info}/LICENSE +0 -0
tokenator/__init__.py
CHANGED
@@ -2,7 +2,7 @@
|
|
2
2
|
|
3
3
|
import logging
|
4
4
|
from .openai.client_openai import tokenator_openai
|
5
|
-
from .client_anthropic import tokenator_anthropic
|
5
|
+
from .anthropic.client_anthropic import tokenator_anthropic
|
6
6
|
from . import usage
|
7
7
|
from .utils import get_default_db_path
|
8
8
|
from .migrations import check_and_run_migrations
|
@@ -15,4 +15,4 @@ logger = logging.getLogger(__name__)
|
|
15
15
|
try:
|
16
16
|
check_and_run_migrations()
|
17
17
|
except Exception as e:
|
18
|
-
logger.warning(f"Failed to run migrations, but continuing anyway: {e}")
|
18
|
+
logger.warning(f"Failed to run migrations, but continuing anyway: {e}")
|
@@ -0,0 +1,154 @@
|
|
1
|
+
"""Anthropic client wrapper with token usage tracking."""
|
2
|
+
|
3
|
+
from typing import Any, Optional, Union, overload, Iterator, AsyncIterator
|
4
|
+
import logging
|
5
|
+
|
6
|
+
from anthropic import Anthropic, AsyncAnthropic
|
7
|
+
from anthropic.types import Message, RawMessageStartEvent, RawMessageDeltaEvent
|
8
|
+
|
9
|
+
from ..models import Usage, TokenUsageStats
|
10
|
+
from ..base_wrapper import BaseWrapper, ResponseType
|
11
|
+
from .stream_interceptors import AnthropicAsyncStreamInterceptor, AnthropicSyncStreamInterceptor
|
12
|
+
|
13
|
+
logger = logging.getLogger(__name__)
|
14
|
+
|
15
|
+
|
16
|
+
class BaseAnthropicWrapper(BaseWrapper):
|
17
|
+
provider = "anthropic"
|
18
|
+
|
19
|
+
def _process_response_usage(
|
20
|
+
self, response: ResponseType
|
21
|
+
) -> Optional[TokenUsageStats]:
|
22
|
+
"""Process and log usage statistics from a response."""
|
23
|
+
try:
|
24
|
+
if isinstance(response, Message):
|
25
|
+
if not hasattr(response, "usage"):
|
26
|
+
return None
|
27
|
+
usage = Usage(
|
28
|
+
prompt_tokens=response.usage.input_tokens,
|
29
|
+
completion_tokens=response.usage.output_tokens,
|
30
|
+
total_tokens=response.usage.input_tokens
|
31
|
+
+ response.usage.output_tokens,
|
32
|
+
)
|
33
|
+
return TokenUsageStats(model=response.model, usage=usage)
|
34
|
+
elif isinstance(response, dict):
|
35
|
+
usage_dict = response.get("usage")
|
36
|
+
if not usage_dict:
|
37
|
+
return None
|
38
|
+
usage = Usage(
|
39
|
+
prompt_tokens=usage_dict.get("input_tokens", 0),
|
40
|
+
completion_tokens=usage_dict.get("output_tokens", 0),
|
41
|
+
total_tokens=usage_dict.get("input_tokens", 0)
|
42
|
+
+ usage_dict.get("output_tokens", 0),
|
43
|
+
)
|
44
|
+
return TokenUsageStats(
|
45
|
+
model=response.get("model", "unknown"), usage=usage
|
46
|
+
)
|
47
|
+
except Exception as e:
|
48
|
+
logger.warning("Failed to process usage stats: %s", str(e))
|
49
|
+
return None
|
50
|
+
return None
|
51
|
+
|
52
|
+
@property
|
53
|
+
def messages(self):
|
54
|
+
return self
|
55
|
+
|
56
|
+
|
57
|
+
def _create_usage_callback(execution_id, log_usage_fn):
|
58
|
+
"""Creates a callback function for processing usage statistics from stream chunks."""
|
59
|
+
def usage_callback(chunks):
|
60
|
+
if not chunks:
|
61
|
+
return
|
62
|
+
|
63
|
+
usage_data = TokenUsageStats(
|
64
|
+
model=chunks[0].message.model if isinstance(chunks[0], RawMessageStartEvent) else "",
|
65
|
+
usage=Usage(),
|
66
|
+
)
|
67
|
+
|
68
|
+
for chunk in chunks:
|
69
|
+
if isinstance(chunk, RawMessageStartEvent):
|
70
|
+
usage_data.model = chunk.message.model
|
71
|
+
usage_data.usage.prompt_tokens += chunk.message.usage.input_tokens
|
72
|
+
usage_data.usage.completion_tokens += chunk.message.usage.output_tokens
|
73
|
+
elif isinstance(chunk, RawMessageDeltaEvent):
|
74
|
+
usage_data.usage.completion_tokens += chunk.usage.output_tokens
|
75
|
+
|
76
|
+
usage_data.usage.total_tokens = usage_data.usage.prompt_tokens + usage_data.usage.completion_tokens
|
77
|
+
log_usage_fn(usage_data, execution_id=execution_id)
|
78
|
+
|
79
|
+
return usage_callback
|
80
|
+
|
81
|
+
|
82
|
+
class AnthropicWrapper(BaseAnthropicWrapper):
|
83
|
+
def create(
|
84
|
+
self, *args: Any, execution_id: Optional[str] = None, **kwargs: Any
|
85
|
+
) -> Union[Message, Iterator[Message]]:
|
86
|
+
"""Create a message completion and log token usage."""
|
87
|
+
logger.debug("Creating message completion with args: %s, kwargs: %s", args, kwargs)
|
88
|
+
|
89
|
+
if kwargs.get("stream", False):
|
90
|
+
base_stream = self.client.messages.create(*args, **kwargs)
|
91
|
+
return AnthropicSyncStreamInterceptor(
|
92
|
+
base_stream=base_stream,
|
93
|
+
usage_callback=_create_usage_callback(execution_id, self._log_usage),
|
94
|
+
)
|
95
|
+
|
96
|
+
response = self.client.messages.create(*args, **kwargs)
|
97
|
+
usage_data = self._process_response_usage(response)
|
98
|
+
if usage_data:
|
99
|
+
self._log_usage(usage_data, execution_id=execution_id)
|
100
|
+
return response
|
101
|
+
|
102
|
+
|
103
|
+
class AsyncAnthropicWrapper(BaseAnthropicWrapper):
|
104
|
+
async def create(
|
105
|
+
self, *args: Any, execution_id: Optional[str] = None, **kwargs: Any
|
106
|
+
) -> Union[Message, AsyncIterator[Message]]:
|
107
|
+
"""Create a message completion and log token usage."""
|
108
|
+
logger.debug("Creating message completion with args: %s, kwargs: %s", args, kwargs)
|
109
|
+
|
110
|
+
if kwargs.get("stream", False):
|
111
|
+
base_stream = await self.client.messages.create(*args, **kwargs)
|
112
|
+
return AnthropicAsyncStreamInterceptor(
|
113
|
+
base_stream=base_stream,
|
114
|
+
usage_callback=_create_usage_callback(execution_id, self._log_usage),
|
115
|
+
)
|
116
|
+
|
117
|
+
response = await self.client.messages.create(*args, **kwargs)
|
118
|
+
usage_data = self._process_response_usage(response)
|
119
|
+
if usage_data:
|
120
|
+
self._log_usage(usage_data, execution_id=execution_id)
|
121
|
+
return response
|
122
|
+
|
123
|
+
|
124
|
+
@overload
|
125
|
+
def tokenator_anthropic(
|
126
|
+
client: Anthropic,
|
127
|
+
db_path: Optional[str] = None,
|
128
|
+
) -> AnthropicWrapper: ...
|
129
|
+
|
130
|
+
|
131
|
+
@overload
|
132
|
+
def tokenator_anthropic(
|
133
|
+
client: AsyncAnthropic,
|
134
|
+
db_path: Optional[str] = None,
|
135
|
+
) -> AsyncAnthropicWrapper: ...
|
136
|
+
|
137
|
+
|
138
|
+
def tokenator_anthropic(
|
139
|
+
client: Union[Anthropic, AsyncAnthropic],
|
140
|
+
db_path: Optional[str] = None,
|
141
|
+
) -> Union[AnthropicWrapper, AsyncAnthropicWrapper]:
|
142
|
+
"""Create a token-tracking wrapper for an Anthropic client.
|
143
|
+
|
144
|
+
Args:
|
145
|
+
client: Anthropic or AsyncAnthropic client instance
|
146
|
+
db_path: Optional path to SQLite database for token tracking
|
147
|
+
"""
|
148
|
+
if isinstance(client, Anthropic):
|
149
|
+
return AnthropicWrapper(client=client, db_path=db_path)
|
150
|
+
|
151
|
+
if isinstance(client, AsyncAnthropic):
|
152
|
+
return AsyncAnthropicWrapper(client=client, db_path=db_path)
|
153
|
+
|
154
|
+
raise ValueError("Client must be an instance of Anthropic or AsyncAnthropic")
|
@@ -0,0 +1,146 @@
|
|
1
|
+
import logging
|
2
|
+
from typing import AsyncIterator, Callable, List, Optional, TypeVar, Iterator
|
3
|
+
|
4
|
+
from anthropic import AsyncStream, Stream
|
5
|
+
|
6
|
+
logger = logging.getLogger(__name__)
|
7
|
+
|
8
|
+
_T = TypeVar("_T")
|
9
|
+
|
10
|
+
|
11
|
+
class AnthropicAsyncStreamInterceptor(AsyncStream[_T]):
|
12
|
+
"""
|
13
|
+
A wrapper around anthropic.AsyncStream that delegates all functionality
|
14
|
+
to the 'base_stream' but intercepts each chunk to handle usage or
|
15
|
+
logging logic. This preserves .response and other methods.
|
16
|
+
|
17
|
+
You can store aggregated usage in a local list and process it when
|
18
|
+
the stream ends (StopAsyncIteration).
|
19
|
+
"""
|
20
|
+
|
21
|
+
def __init__(
|
22
|
+
self,
|
23
|
+
base_stream: AsyncStream[_T],
|
24
|
+
usage_callback: Optional[Callable[[List[_T]], None]] = None,
|
25
|
+
):
|
26
|
+
# We do NOT call super().__init__() because anthropic.AsyncStream
|
27
|
+
# expects constructor parameters we don't want to re-initialize.
|
28
|
+
# Instead, we just store the base_stream and delegate everything to it.
|
29
|
+
self._base_stream = base_stream
|
30
|
+
self._usage_callback = usage_callback
|
31
|
+
self._chunks: List[_T] = []
|
32
|
+
|
33
|
+
@property
|
34
|
+
def response(self):
|
35
|
+
"""Expose the original stream's 'response' so user code can do stream.response, etc."""
|
36
|
+
return self._base_stream.response
|
37
|
+
|
38
|
+
def __aiter__(self) -> AsyncIterator[_T]:
|
39
|
+
"""
|
40
|
+
Called when we do 'async for chunk in wrapped_stream:'
|
41
|
+
We simply return 'self'. Then __anext__ does the rest.
|
42
|
+
"""
|
43
|
+
return self
|
44
|
+
|
45
|
+
async def __anext__(self) -> _T:
|
46
|
+
"""
|
47
|
+
Intercept iteration. We pull the next chunk from the base_stream.
|
48
|
+
If it's the end, do any final usage logging, then raise StopAsyncIteration.
|
49
|
+
Otherwise, we can accumulate usage info or do whatever we need with the chunk.
|
50
|
+
"""
|
51
|
+
try:
|
52
|
+
chunk = await self._base_stream.__anext__()
|
53
|
+
except StopAsyncIteration:
|
54
|
+
# Once the base stream is fully consumed, we can do final usage/logging.
|
55
|
+
if self._usage_callback and self._chunks:
|
56
|
+
self._usage_callback(self._chunks)
|
57
|
+
raise
|
58
|
+
|
59
|
+
# Intercept each chunk
|
60
|
+
self._chunks.append(chunk)
|
61
|
+
return chunk
|
62
|
+
|
63
|
+
async def __aenter__(self) -> "AnthropicAsyncStreamInterceptor[_T]":
|
64
|
+
"""Support async with ... : usage."""
|
65
|
+
await self._base_stream.__aenter__()
|
66
|
+
return self
|
67
|
+
|
68
|
+
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
69
|
+
"""
|
70
|
+
Ensure we propagate __aexit__ to the base stream,
|
71
|
+
so connections are properly closed.
|
72
|
+
"""
|
73
|
+
return await self._base_stream.__aexit__(exc_type, exc_val, exc_tb)
|
74
|
+
|
75
|
+
async def close(self) -> None:
|
76
|
+
"""Delegate close to the base_stream."""
|
77
|
+
await self._base_stream.close()
|
78
|
+
|
79
|
+
|
80
|
+
class AnthropicSyncStreamInterceptor(Stream[_T]):
|
81
|
+
"""
|
82
|
+
A wrapper around anthropic.Stream that delegates all functionality
|
83
|
+
to the 'base_stream' but intercepts each chunk to handle usage or
|
84
|
+
logging logic. This preserves .response and other methods.
|
85
|
+
|
86
|
+
You can store aggregated usage in a local list and process it when
|
87
|
+
the stream ends (StopIteration).
|
88
|
+
"""
|
89
|
+
|
90
|
+
def __init__(
|
91
|
+
self,
|
92
|
+
base_stream: Stream[_T],
|
93
|
+
usage_callback: Optional[Callable[[List[_T]], None]] = None,
|
94
|
+
):
|
95
|
+
# We do NOT call super().__init__() because anthropic.Stream
|
96
|
+
# expects constructor parameters we don't want to re-initialize.
|
97
|
+
# Instead, we just store the base_stream and delegate everything to it.
|
98
|
+
self._base_stream = base_stream
|
99
|
+
self._usage_callback = usage_callback
|
100
|
+
self._chunks: List[_T] = []
|
101
|
+
|
102
|
+
@property
|
103
|
+
def response(self):
|
104
|
+
"""Expose the original stream's 'response' so user code can do stream.response, etc."""
|
105
|
+
return self._base_stream.response
|
106
|
+
|
107
|
+
def __iter__(self) -> Iterator[_T]:
|
108
|
+
"""
|
109
|
+
Called when we do 'for chunk in wrapped_stream:'
|
110
|
+
We simply return 'self'. Then __next__ does the rest.
|
111
|
+
"""
|
112
|
+
return self
|
113
|
+
|
114
|
+
def __next__(self) -> _T:
|
115
|
+
"""
|
116
|
+
Intercept iteration. We pull the next chunk from the base_stream.
|
117
|
+
If it's the end, do any final usage logging, then raise StopIteration.
|
118
|
+
Otherwise, we can accumulate usage info or do whatever we need with the chunk.
|
119
|
+
"""
|
120
|
+
try:
|
121
|
+
chunk = self._base_stream.__next__()
|
122
|
+
except StopIteration:
|
123
|
+
# Once the base stream is fully consumed, we can do final usage/logging.
|
124
|
+
if self._usage_callback and self._chunks:
|
125
|
+
self._usage_callback(self._chunks)
|
126
|
+
raise
|
127
|
+
|
128
|
+
# Intercept each chunk
|
129
|
+
self._chunks.append(chunk)
|
130
|
+
return chunk
|
131
|
+
|
132
|
+
def __enter__(self) -> "AnthropicSyncStreamInterceptor[_T]":
|
133
|
+
"""Support with ... : usage."""
|
134
|
+
self._base_stream.__enter__()
|
135
|
+
return self
|
136
|
+
|
137
|
+
def __exit__(self, exc_type, exc_val, exc_tb):
|
138
|
+
"""
|
139
|
+
Ensure we propagate __exit__ to the base stream,
|
140
|
+
so connections are properly closed.
|
141
|
+
"""
|
142
|
+
return self._base_stream.__exit__(exc_type, exc_val, exc_tb)
|
143
|
+
|
144
|
+
async def close(self) -> None:
|
145
|
+
"""Delegate close to the base_stream."""
|
146
|
+
self._base_stream.close()
|
tokenator/base_wrapper.py
CHANGED
@@ -1,16 +1,17 @@
|
|
1
1
|
"""Base wrapper class for token usage tracking."""
|
2
2
|
|
3
3
|
from pathlib import Path
|
4
|
-
from typing import Any,
|
4
|
+
from typing import Any, Optional, TypeVar
|
5
5
|
import logging
|
6
6
|
import uuid
|
7
7
|
|
8
|
-
from .models import
|
8
|
+
from .models import TokenUsageStats
|
9
9
|
from .schemas import get_session, TokenUsage
|
10
10
|
|
11
11
|
logger = logging.getLogger(__name__)
|
12
12
|
|
13
|
-
ResponseType = TypeVar(
|
13
|
+
ResponseType = TypeVar("ResponseType")
|
14
|
+
|
14
15
|
|
15
16
|
class BaseWrapper:
|
16
17
|
def __init__(self, client: Any, db_path: Optional[str] = None):
|
@@ -22,13 +23,20 @@ class BaseWrapper:
|
|
22
23
|
logger.info("Created database directory at: %s", Path(db_path).parent)
|
23
24
|
|
24
25
|
self.Session = get_session(db_path)
|
25
|
-
|
26
|
-
logger.debug("Initializing %s with db_path: %s",
|
27
|
-
self.__class__.__name__, db_path)
|
28
26
|
|
29
|
-
|
27
|
+
logger.debug(
|
28
|
+
"Initializing %s with db_path: %s", self.__class__.__name__, db_path
|
29
|
+
)
|
30
|
+
|
31
|
+
def _log_usage_impl(
|
32
|
+
self, token_usage_stats: TokenUsageStats, session, execution_id: str
|
33
|
+
) -> None:
|
30
34
|
"""Implementation of token usage logging."""
|
31
|
-
logger.debug(
|
35
|
+
logger.debug(
|
36
|
+
"Logging usage for model %s: %s",
|
37
|
+
token_usage_stats.model,
|
38
|
+
token_usage_stats.usage.model_dump(),
|
39
|
+
)
|
32
40
|
try:
|
33
41
|
token_usage = TokenUsage(
|
34
42
|
execution_id=execution_id,
|
@@ -36,15 +44,20 @@ class BaseWrapper:
|
|
36
44
|
model=token_usage_stats.model,
|
37
45
|
prompt_tokens=token_usage_stats.usage.prompt_tokens,
|
38
46
|
completion_tokens=token_usage_stats.usage.completion_tokens,
|
39
|
-
total_tokens=token_usage_stats.usage.total_tokens
|
47
|
+
total_tokens=token_usage_stats.usage.total_tokens,
|
40
48
|
)
|
41
49
|
session.add(token_usage)
|
42
|
-
logger.
|
43
|
-
|
50
|
+
logger.debug(
|
51
|
+
"Logged token usage: model=%s, total_tokens=%d",
|
52
|
+
token_usage_stats.model,
|
53
|
+
token_usage_stats.usage.total_tokens,
|
54
|
+
)
|
44
55
|
except Exception as e:
|
45
56
|
logger.error("Failed to log token usage: %s", str(e))
|
46
57
|
|
47
|
-
def _log_usage(
|
58
|
+
def _log_usage(
|
59
|
+
self, token_usage_stats: TokenUsageStats, execution_id: Optional[str] = None
|
60
|
+
):
|
48
61
|
"""Log token usage to database."""
|
49
62
|
if not execution_id:
|
50
63
|
execution_id = str(uuid.uuid4())
|
@@ -58,4 +71,4 @@ class BaseWrapper:
|
|
58
71
|
logger.error("Failed to log token usage: %s", str(e))
|
59
72
|
session.rollback()
|
60
73
|
finally:
|
61
|
-
session.close()
|
74
|
+
session.close()
|
tokenator/create_migrations.py
CHANGED
@@ -1,25 +1,26 @@
|
|
1
1
|
"""Development utilities for tokenator."""
|
2
2
|
|
3
|
-
import os
|
4
3
|
import sys
|
5
4
|
from pathlib import Path
|
6
5
|
from alembic import command
|
7
6
|
from tokenator.migrations import get_alembic_config
|
8
7
|
|
8
|
+
|
9
9
|
def create_migration(message: str):
|
10
10
|
"""Create a new migration based on model changes."""
|
11
11
|
config = get_alembic_config()
|
12
|
-
|
12
|
+
|
13
13
|
# Get the migrations directory
|
14
14
|
migrations_dir = Path(__file__).parent / "migrations" / "versions"
|
15
15
|
migrations_dir.mkdir(parents=True, exist_ok=True)
|
16
|
-
|
16
|
+
|
17
17
|
# Generate migration with custom message
|
18
|
-
command.revision(config, autogenerate=True, message=message)
|
18
|
+
command.revision(config, autogenerate=True, message=message)
|
19
|
+
|
19
20
|
|
20
21
|
if __name__ == "__main__":
|
21
22
|
if len(sys.argv) > 1:
|
22
23
|
msg = " ".join(sys.argv[1:])
|
23
24
|
else:
|
24
25
|
msg = "auto generated migration"
|
25
|
-
create_migration(msg)
|
26
|
+
create_migration(msg)
|
tokenator/migrations/env.py
CHANGED
@@ -18,6 +18,7 @@ if config.config_file_name is not None:
|
|
18
18
|
# add your model's MetaData object here
|
19
19
|
target_metadata = Base.metadata
|
20
20
|
|
21
|
+
|
21
22
|
def run_migrations_offline() -> None:
|
22
23
|
"""Run migrations in 'offline' mode."""
|
23
24
|
url = config.get_main_option("sqlalchemy.url")
|
@@ -31,6 +32,7 @@ def run_migrations_offline() -> None:
|
|
31
32
|
with context.begin_transaction():
|
32
33
|
context.run_migrations()
|
33
34
|
|
35
|
+
|
34
36
|
def run_migrations_online() -> None:
|
35
37
|
"""Run migrations in 'online' mode."""
|
36
38
|
connectable = engine_from_config(
|
@@ -40,14 +42,13 @@ def run_migrations_online() -> None:
|
|
40
42
|
)
|
41
43
|
|
42
44
|
with connectable.connect() as connection:
|
43
|
-
context.configure(
|
44
|
-
connection=connection, target_metadata=target_metadata
|
45
|
-
)
|
45
|
+
context.configure(connection=connection, target_metadata=target_metadata)
|
46
46
|
|
47
47
|
with context.begin_transaction():
|
48
48
|
context.run_migrations()
|
49
49
|
|
50
|
+
|
50
51
|
if context.is_offline_mode():
|
51
52
|
run_migrations_offline()
|
52
53
|
else:
|
53
|
-
run_migrations_online()
|
54
|
+
run_migrations_online()
|
@@ -1,10 +1,11 @@
|
|
1
1
|
"""Initial migration
|
2
2
|
|
3
3
|
Revision ID: f6f1f2437513
|
4
|
-
Revises:
|
4
|
+
Revises:
|
5
5
|
Create Date: 2024-12-21 17:33:27.187221
|
6
6
|
|
7
7
|
"""
|
8
|
+
|
8
9
|
from typing import Sequence, Union
|
9
10
|
|
10
11
|
from alembic import op
|
@@ -12,7 +13,7 @@ import sqlalchemy as sa
|
|
12
13
|
|
13
14
|
|
14
15
|
# revision identifiers, used by Alembic.
|
15
|
-
revision: str =
|
16
|
+
revision: str = "f6f1f2437513"
|
16
17
|
down_revision: Union[str, None] = None
|
17
18
|
branch_labels: Union[str, Sequence[str], None] = None
|
18
19
|
depends_on: Union[str, Sequence[str], None] = None
|
@@ -20,30 +21,31 @@ depends_on: Union[str, Sequence[str], None] = None
|
|
20
21
|
|
21
22
|
def upgrade() -> None:
|
22
23
|
# ### commands auto generated by Alembic - please adjust! ###
|
23
|
-
op.create_table(
|
24
|
-
|
25
|
-
|
26
|
-
|
27
|
-
|
28
|
-
|
29
|
-
|
30
|
-
|
31
|
-
|
32
|
-
|
33
|
-
|
24
|
+
op.create_table(
|
25
|
+
"token_usage",
|
26
|
+
sa.Column("id", sa.Integer(), nullable=False),
|
27
|
+
sa.Column("execution_id", sa.String(), nullable=False),
|
28
|
+
sa.Column("provider", sa.String(), nullable=False),
|
29
|
+
sa.Column("model", sa.String(), nullable=False),
|
30
|
+
sa.Column("created_at", sa.DateTime(), nullable=False),
|
31
|
+
sa.Column("updated_at", sa.DateTime(), nullable=False),
|
32
|
+
sa.Column("prompt_tokens", sa.Integer(), nullable=False),
|
33
|
+
sa.Column("completion_tokens", sa.Integer(), nullable=False),
|
34
|
+
sa.Column("total_tokens", sa.Integer(), nullable=False),
|
35
|
+
sa.PrimaryKeyConstraint("id"),
|
34
36
|
)
|
35
|
-
op.create_index(
|
36
|
-
op.create_index(
|
37
|
-
op.create_index(
|
38
|
-
op.create_index(
|
37
|
+
op.create_index("idx_created_at", "token_usage", ["created_at"], unique=False)
|
38
|
+
op.create_index("idx_execution_id", "token_usage", ["execution_id"], unique=False)
|
39
|
+
op.create_index("idx_model", "token_usage", ["model"], unique=False)
|
40
|
+
op.create_index("idx_provider", "token_usage", ["provider"], unique=False)
|
39
41
|
# ### end Alembic commands ###
|
40
42
|
|
41
43
|
|
42
44
|
def downgrade() -> None:
|
43
45
|
# ### commands auto generated by Alembic - please adjust! ###
|
44
|
-
op.drop_index(
|
45
|
-
op.drop_index(
|
46
|
-
op.drop_index(
|
47
|
-
op.drop_index(
|
48
|
-
op.drop_table(
|
49
|
-
# ### end Alembic commands ###
|
46
|
+
op.drop_index("idx_provider", table_name="token_usage")
|
47
|
+
op.drop_index("idx_model", table_name="token_usage")
|
48
|
+
op.drop_index("idx_execution_id", table_name="token_usage")
|
49
|
+
op.drop_index("idx_created_at", table_name="token_usage")
|
50
|
+
op.drop_table("token_usage")
|
51
|
+
# ### end Alembic commands ###
|
tokenator/migrations.py
CHANGED
@@ -6,6 +6,7 @@ from alembic.config import Config
|
|
6
6
|
from alembic import command
|
7
7
|
from .utils import get_default_db_path
|
8
8
|
|
9
|
+
|
9
10
|
def get_alembic_config(db_path: str = None) -> Config:
|
10
11
|
"""Get Alembic config for migrations."""
|
11
12
|
if db_path is None:
|
@@ -13,27 +14,29 @@ def get_alembic_config(db_path: str = None) -> Config:
|
|
13
14
|
|
14
15
|
# Get the directory containing this file
|
15
16
|
migrations_dir = Path(__file__).parent / "migrations"
|
16
|
-
|
17
|
+
|
17
18
|
# Create Config object
|
18
19
|
config = Config()
|
19
20
|
config.set_main_option("script_location", str(migrations_dir))
|
20
21
|
config.set_main_option("sqlalchemy.url", f"sqlite:///{db_path}")
|
21
|
-
|
22
|
+
|
22
23
|
return config
|
23
24
|
|
25
|
+
|
24
26
|
def check_and_run_migrations(db_path: str = None):
|
25
27
|
"""Check and run any pending database migrations."""
|
26
28
|
if db_path is None:
|
27
29
|
db_path = get_default_db_path()
|
28
|
-
|
30
|
+
|
29
31
|
dirname = os.path.dirname(db_path)
|
30
32
|
if dirname:
|
31
33
|
os.makedirs(dirname, exist_ok=True)
|
32
|
-
|
34
|
+
|
33
35
|
# Initialize database
|
34
36
|
import sqlite3
|
37
|
+
|
35
38
|
conn = sqlite3.connect(db_path)
|
36
39
|
conn.close()
|
37
|
-
|
40
|
+
|
38
41
|
config = get_alembic_config(db_path)
|
39
|
-
command.upgrade(config, "head")
|
42
|
+
command.upgrade(config, "head")
|
tokenator/models.py
CHANGED
@@ -1,31 +1,42 @@
|
|
1
1
|
from pydantic import BaseModel, Field
|
2
|
-
from typing import
|
2
|
+
from typing import List
|
3
|
+
|
3
4
|
|
4
5
|
class TokenRate(BaseModel):
|
5
6
|
prompt: float = Field(..., description="Cost per prompt token")
|
6
7
|
completion: float = Field(..., description="Cost per completion token")
|
7
8
|
|
9
|
+
|
8
10
|
class TokenMetrics(BaseModel):
|
9
11
|
total_cost: float = Field(..., description="Total cost in USD")
|
10
12
|
total_tokens: int = Field(..., description="Total tokens used")
|
11
13
|
prompt_tokens: int = Field(..., description="Number of prompt tokens")
|
12
14
|
completion_tokens: int = Field(..., description="Number of completion tokens")
|
13
15
|
|
16
|
+
|
14
17
|
class ModelUsage(TokenMetrics):
|
15
18
|
model: str = Field(..., description="Model name")
|
16
19
|
|
20
|
+
|
17
21
|
class ProviderUsage(TokenMetrics):
|
18
22
|
provider: str = Field(..., description="Provider name")
|
19
|
-
models: List[ModelUsage] = Field(
|
23
|
+
models: List[ModelUsage] = Field(
|
24
|
+
default_factory=list, description="Usage breakdown by model"
|
25
|
+
)
|
26
|
+
|
20
27
|
|
21
28
|
class TokenUsageReport(TokenMetrics):
|
22
|
-
providers: List[ProviderUsage] = Field(
|
29
|
+
providers: List[ProviderUsage] = Field(
|
30
|
+
default_factory=list, description="Usage breakdown by provider"
|
31
|
+
)
|
32
|
+
|
23
33
|
|
24
34
|
class Usage(BaseModel):
|
25
35
|
prompt_tokens: int = 0
|
26
36
|
completion_tokens: int = 0
|
27
37
|
total_tokens: int = 0
|
28
38
|
|
39
|
+
|
29
40
|
class TokenUsageStats(BaseModel):
|
30
41
|
model: str
|
31
|
-
usage: Usage
|
42
|
+
usage: Usage
|