tokenator 0.1.13__py3-none-any.whl → 0.1.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tokenator/__init__.py +2 -7
- tokenator/anthropic/client_anthropic.py +51 -15
- tokenator/base_wrapper.py +54 -8
- tokenator/migrations/versions/f028b8155fed_adding_detailed_input_and_output_token_.py +64 -0
- tokenator/models.py +33 -12
- tokenator/openai/client_openai.py +97 -5
- tokenator/schemas.py +21 -19
- tokenator/state.py +12 -0
- tokenator/usage.py +466 -232
- tokenator/utils.py +14 -1
- {tokenator-0.1.13.dist-info → tokenator-0.1.15.dist-info}/METADATA +5 -4
- tokenator-0.1.15.dist-info/RECORD +21 -0
- tokenator-0.1.13.dist-info/RECORD +0 -19
- {tokenator-0.1.13.dist-info → tokenator-0.1.15.dist-info}/LICENSE +0 -0
- {tokenator-0.1.13.dist-info → tokenator-0.1.15.dist-info}/WHEEL +0 -0
tokenator/__init__.py
CHANGED
@@ -5,14 +5,9 @@ from .openai.client_openai import tokenator_openai
|
|
5
5
|
from .anthropic.client_anthropic import tokenator_anthropic
|
6
6
|
from . import usage
|
7
7
|
from .utils import get_default_db_path
|
8
|
-
from .
|
8
|
+
from .usage import TokenUsageService
|
9
9
|
|
10
|
-
|
10
|
+
usage = TokenUsageService() # noqa: F811
|
11
11
|
__all__ = ["tokenator_openai", "tokenator_anthropic", "usage", "get_default_db_path"]
|
12
12
|
|
13
13
|
logger = logging.getLogger(__name__)
|
14
|
-
|
15
|
-
try:
|
16
|
-
check_and_run_migrations()
|
17
|
-
except Exception as e:
|
18
|
-
logger.warning(f"Failed to run migrations, but continuing anyway: {e}")
|
@@ -6,9 +6,13 @@ import logging
|
|
6
6
|
from anthropic import Anthropic, AsyncAnthropic
|
7
7
|
from anthropic.types import Message, RawMessageStartEvent, RawMessageDeltaEvent
|
8
8
|
|
9
|
-
from ..models import
|
9
|
+
from ..models import PromptTokenDetails, TokenMetrics, TokenUsageStats
|
10
10
|
from ..base_wrapper import BaseWrapper, ResponseType
|
11
|
-
from .stream_interceptors import
|
11
|
+
from .stream_interceptors import (
|
12
|
+
AnthropicAsyncStreamInterceptor,
|
13
|
+
AnthropicSyncStreamInterceptor,
|
14
|
+
)
|
15
|
+
from ..state import is_tokenator_enabled
|
12
16
|
|
13
17
|
logger = logging.getLogger(__name__)
|
14
18
|
|
@@ -24,28 +28,46 @@ class BaseAnthropicWrapper(BaseWrapper):
|
|
24
28
|
if isinstance(response, Message):
|
25
29
|
if not hasattr(response, "usage"):
|
26
30
|
return None
|
27
|
-
usage =
|
28
|
-
prompt_tokens=response.usage.input_tokens
|
31
|
+
usage = TokenMetrics(
|
32
|
+
prompt_tokens=response.usage.input_tokens
|
33
|
+
+ (getattr(response.usage, "cache_creation_input_tokens", 0) or 0),
|
29
34
|
completion_tokens=response.usage.output_tokens,
|
30
35
|
total_tokens=response.usage.input_tokens
|
31
36
|
+ response.usage.output_tokens,
|
37
|
+
prompt_tokens_details=PromptTokenDetails(
|
38
|
+
cached_input_tokens=getattr(
|
39
|
+
response.usage, "cache_read_input_tokens", None
|
40
|
+
),
|
41
|
+
cached_creation_tokens=getattr(
|
42
|
+
response.usage, "cache_creation_input_tokens", None
|
43
|
+
),
|
44
|
+
),
|
32
45
|
)
|
33
46
|
return TokenUsageStats(model=response.model, usage=usage)
|
34
47
|
elif isinstance(response, dict):
|
35
48
|
usage_dict = response.get("usage")
|
36
49
|
if not usage_dict:
|
37
50
|
return None
|
38
|
-
usage =
|
39
|
-
prompt_tokens=usage_dict.get("input_tokens", 0)
|
51
|
+
usage = TokenMetrics(
|
52
|
+
prompt_tokens=usage_dict.get("input_tokens", 0)
|
53
|
+
+ (getattr(usage_dict, "cache_creation_input_tokens", 0) or 0),
|
40
54
|
completion_tokens=usage_dict.get("output_tokens", 0),
|
41
55
|
total_tokens=usage_dict.get("input_tokens", 0)
|
42
56
|
+ usage_dict.get("output_tokens", 0),
|
57
|
+
prompt_tokens_details=PromptTokenDetails(
|
58
|
+
cached_input_tokens=getattr(
|
59
|
+
usage_dict, "cache_read_input_tokens", None
|
60
|
+
),
|
61
|
+
cached_creation_tokens=getattr(
|
62
|
+
usage_dict, "cache_creation_input_tokens", None
|
63
|
+
),
|
64
|
+
),
|
43
65
|
)
|
44
66
|
return TokenUsageStats(
|
45
67
|
model=response.get("model", "unknown"), usage=usage
|
46
68
|
)
|
47
69
|
except Exception as e:
|
48
|
-
logger.warning("Failed to process usage stats: %s", str(e))
|
70
|
+
logger.warning("Failed to process usage stats: %s", str(e), exc_info=True)
|
49
71
|
return None
|
50
72
|
return None
|
51
73
|
|
@@ -56,15 +78,23 @@ class BaseAnthropicWrapper(BaseWrapper):
|
|
56
78
|
|
57
79
|
def _create_usage_callback(execution_id, log_usage_fn):
|
58
80
|
"""Creates a callback function for processing usage statistics from stream chunks."""
|
81
|
+
|
59
82
|
def usage_callback(chunks):
|
60
83
|
if not chunks:
|
61
84
|
return
|
62
|
-
|
85
|
+
|
86
|
+
# Skip if tokenator is disabled
|
87
|
+
if not is_tokenator_enabled:
|
88
|
+
logger.debug("Tokenator is disabled - skipping stream usage logging")
|
89
|
+
return
|
90
|
+
|
63
91
|
usage_data = TokenUsageStats(
|
64
|
-
model=chunks[0].message.model
|
65
|
-
|
92
|
+
model=chunks[0].message.model
|
93
|
+
if isinstance(chunks[0], RawMessageStartEvent)
|
94
|
+
else "",
|
95
|
+
usage=TokenMetrics(),
|
66
96
|
)
|
67
|
-
|
97
|
+
|
68
98
|
for chunk in chunks:
|
69
99
|
if isinstance(chunk, RawMessageStartEvent):
|
70
100
|
usage_data.model = chunk.message.model
|
@@ -72,8 +102,10 @@ def _create_usage_callback(execution_id, log_usage_fn):
|
|
72
102
|
usage_data.usage.completion_tokens += chunk.message.usage.output_tokens
|
73
103
|
elif isinstance(chunk, RawMessageDeltaEvent):
|
74
104
|
usage_data.usage.completion_tokens += chunk.usage.output_tokens
|
75
|
-
|
76
|
-
usage_data.usage.total_tokens =
|
105
|
+
|
106
|
+
usage_data.usage.total_tokens = (
|
107
|
+
usage_data.usage.prompt_tokens + usage_data.usage.completion_tokens
|
108
|
+
)
|
77
109
|
log_usage_fn(usage_data, execution_id=execution_id)
|
78
110
|
|
79
111
|
return usage_callback
|
@@ -84,7 +116,9 @@ class AnthropicWrapper(BaseAnthropicWrapper):
|
|
84
116
|
self, *args: Any, execution_id: Optional[str] = None, **kwargs: Any
|
85
117
|
) -> Union[Message, Iterator[Message]]:
|
86
118
|
"""Create a message completion and log token usage."""
|
87
|
-
logger.debug(
|
119
|
+
logger.debug(
|
120
|
+
"Creating message completion with args: %s, kwargs: %s", args, kwargs
|
121
|
+
)
|
88
122
|
|
89
123
|
if kwargs.get("stream", False):
|
90
124
|
base_stream = self.client.messages.create(*args, **kwargs)
|
@@ -105,7 +139,9 @@ class AsyncAnthropicWrapper(BaseAnthropicWrapper):
|
|
105
139
|
self, *args: Any, execution_id: Optional[str] = None, **kwargs: Any
|
106
140
|
) -> Union[Message, AsyncIterator[Message]]:
|
107
141
|
"""Create a message completion and log token usage."""
|
108
|
-
logger.debug(
|
142
|
+
logger.debug(
|
143
|
+
"Creating message completion with args: %s, kwargs: %s", args, kwargs
|
144
|
+
)
|
109
145
|
|
110
146
|
if kwargs.get("stream", False):
|
111
147
|
base_stream = await self.client.messages.create(*args, **kwargs)
|
tokenator/base_wrapper.py
CHANGED
@@ -7,6 +7,9 @@ import uuid
|
|
7
7
|
|
8
8
|
from .models import TokenUsageStats
|
9
9
|
from .schemas import get_session, TokenUsage
|
10
|
+
from . import state
|
11
|
+
|
12
|
+
from .migrations import check_and_run_migrations
|
10
13
|
|
11
14
|
logger = logging.getLogger(__name__)
|
12
15
|
|
@@ -16,17 +19,30 @@ ResponseType = TypeVar("ResponseType")
|
|
16
19
|
class BaseWrapper:
|
17
20
|
def __init__(self, client: Any, db_path: Optional[str] = None):
|
18
21
|
"""Initialize the base wrapper."""
|
19
|
-
|
22
|
+
state.is_tokenator_enabled = True
|
23
|
+
try:
|
24
|
+
self.client = client
|
20
25
|
|
21
|
-
|
22
|
-
|
23
|
-
|
26
|
+
if db_path:
|
27
|
+
Path(db_path).parent.mkdir(parents=True, exist_ok=True)
|
28
|
+
logger.info("Created database directory at: %s", Path(db_path).parent)
|
29
|
+
state.db_path = db_path # Store db_path in state
|
24
30
|
|
25
|
-
|
31
|
+
else:
|
32
|
+
state.db_path = None # Use default path
|
26
33
|
|
27
|
-
|
28
|
-
|
29
|
-
|
34
|
+
self.Session = get_session()
|
35
|
+
|
36
|
+
logger.debug(
|
37
|
+
"Initializing %s with db_path: %s", self.__class__.__name__, db_path
|
38
|
+
)
|
39
|
+
|
40
|
+
check_and_run_migrations(db_path)
|
41
|
+
except Exception as e:
|
42
|
+
state.is_tokenator_enabled = False
|
43
|
+
logger.warning(
|
44
|
+
f"Tokenator initialization failed. Usage tracking will be disabled. Error: {e}"
|
45
|
+
)
|
30
46
|
|
31
47
|
def _log_usage_impl(
|
32
48
|
self, token_usage_stats: TokenUsageStats, session, execution_id: str
|
@@ -42,9 +58,33 @@ class BaseWrapper:
|
|
42
58
|
execution_id=execution_id,
|
43
59
|
provider=self.provider,
|
44
60
|
model=token_usage_stats.model,
|
61
|
+
total_cost=0, # This needs to be calculated based on your rates
|
45
62
|
prompt_tokens=token_usage_stats.usage.prompt_tokens,
|
46
63
|
completion_tokens=token_usage_stats.usage.completion_tokens,
|
47
64
|
total_tokens=token_usage_stats.usage.total_tokens,
|
65
|
+
# Prompt details
|
66
|
+
prompt_cached_input_tokens=token_usage_stats.usage.prompt_tokens_details.cached_input_tokens
|
67
|
+
if token_usage_stats.usage.prompt_tokens_details
|
68
|
+
else None,
|
69
|
+
prompt_cached_creation_tokens=token_usage_stats.usage.prompt_tokens_details.cached_creation_tokens
|
70
|
+
if token_usage_stats.usage.prompt_tokens_details
|
71
|
+
else None,
|
72
|
+
prompt_audio_tokens=token_usage_stats.usage.prompt_tokens_details.audio_tokens
|
73
|
+
if token_usage_stats.usage.prompt_tokens_details
|
74
|
+
else None,
|
75
|
+
# Completion details
|
76
|
+
completion_audio_tokens=token_usage_stats.usage.completion_tokens_details.audio_tokens
|
77
|
+
if token_usage_stats.usage.completion_tokens_details
|
78
|
+
else None,
|
79
|
+
completion_reasoning_tokens=token_usage_stats.usage.completion_tokens_details.reasoning_tokens
|
80
|
+
if token_usage_stats.usage.completion_tokens_details
|
81
|
+
else None,
|
82
|
+
completion_accepted_prediction_tokens=token_usage_stats.usage.completion_tokens_details.accepted_prediction_tokens
|
83
|
+
if token_usage_stats.usage.completion_tokens_details
|
84
|
+
else None,
|
85
|
+
completion_rejected_prediction_tokens=token_usage_stats.usage.completion_tokens_details.rejected_prediction_tokens
|
86
|
+
if token_usage_stats.usage.completion_tokens_details
|
87
|
+
else None,
|
48
88
|
)
|
49
89
|
session.add(token_usage)
|
50
90
|
logger.debug(
|
@@ -59,14 +99,20 @@ class BaseWrapper:
|
|
59
99
|
self, token_usage_stats: TokenUsageStats, execution_id: Optional[str] = None
|
60
100
|
):
|
61
101
|
"""Log token usage to database."""
|
102
|
+
if not state.is_tokenator_enabled:
|
103
|
+
logger.debug("Tokenator is disabled - skipping usage logging")
|
104
|
+
return
|
105
|
+
|
62
106
|
if not execution_id:
|
63
107
|
execution_id = str(uuid.uuid4())
|
64
108
|
|
109
|
+
logger.debug("Starting token usage logging for execution_id: %s", execution_id)
|
65
110
|
session = self.Session()
|
66
111
|
try:
|
67
112
|
try:
|
68
113
|
self._log_usage_impl(token_usage_stats, session, execution_id)
|
69
114
|
session.commit()
|
115
|
+
logger.debug("Successfully committed token usage for execution_id: %s", execution_id)
|
70
116
|
except Exception as e:
|
71
117
|
logger.error("Failed to log token usage: %s", str(e))
|
72
118
|
session.rollback()
|
@@ -0,0 +1,64 @@
|
|
1
|
+
"""Adding detailed input and output token schema
|
2
|
+
|
3
|
+
Revision ID: f028b8155fed
|
4
|
+
Revises: f6f1f2437513
|
5
|
+
Create Date: 2025-01-19 15:41:12.715623
|
6
|
+
|
7
|
+
"""
|
8
|
+
|
9
|
+
from typing import Sequence, Union
|
10
|
+
|
11
|
+
from alembic import op
|
12
|
+
import sqlalchemy as sa
|
13
|
+
|
14
|
+
|
15
|
+
# revision identifiers, used by Alembic.
|
16
|
+
revision: str = "f028b8155fed"
|
17
|
+
down_revision: Union[str, None] = "f6f1f2437513"
|
18
|
+
branch_labels: Union[str, Sequence[str], None] = None
|
19
|
+
depends_on: Union[str, Sequence[str], None] = None
|
20
|
+
|
21
|
+
|
22
|
+
def upgrade() -> None:
|
23
|
+
# ### commands auto generated by Alembic - please adjust! ###
|
24
|
+
op.add_column("token_usage", sa.Column("total_cost", sa.Integer(), nullable=False))
|
25
|
+
op.add_column(
|
26
|
+
"token_usage",
|
27
|
+
sa.Column("prompt_cached_input_tokens", sa.Integer(), nullable=True),
|
28
|
+
)
|
29
|
+
op.add_column(
|
30
|
+
"token_usage",
|
31
|
+
sa.Column("prompt_cached_creation_tokens", sa.Integer(), nullable=True),
|
32
|
+
)
|
33
|
+
op.add_column(
|
34
|
+
"token_usage", sa.Column("prompt_audio_tokens", sa.Integer(), nullable=True)
|
35
|
+
)
|
36
|
+
op.add_column(
|
37
|
+
"token_usage", sa.Column("completion_audio_tokens", sa.Integer(), nullable=True)
|
38
|
+
)
|
39
|
+
op.add_column(
|
40
|
+
"token_usage",
|
41
|
+
sa.Column("completion_reasoning_tokens", sa.Integer(), nullable=True),
|
42
|
+
)
|
43
|
+
op.add_column(
|
44
|
+
"token_usage",
|
45
|
+
sa.Column("completion_accepted_prediction_tokens", sa.Integer(), nullable=True),
|
46
|
+
)
|
47
|
+
op.add_column(
|
48
|
+
"token_usage",
|
49
|
+
sa.Column("completion_rejected_prediction_tokens", sa.Integer(), nullable=True),
|
50
|
+
)
|
51
|
+
# ### end Alembic commands ###
|
52
|
+
|
53
|
+
|
54
|
+
def downgrade() -> None:
|
55
|
+
# ### commands auto generated by Alembic - please adjust! ###
|
56
|
+
op.drop_column("token_usage", "completion_rejected_prediction_tokens")
|
57
|
+
op.drop_column("token_usage", "completion_accepted_prediction_tokens")
|
58
|
+
op.drop_column("token_usage", "completion_reasoning_tokens")
|
59
|
+
op.drop_column("token_usage", "completion_audio_tokens")
|
60
|
+
op.drop_column("token_usage", "prompt_audio_tokens")
|
61
|
+
op.drop_column("token_usage", "prompt_cached_creation_tokens")
|
62
|
+
op.drop_column("token_usage", "prompt_cached_input_tokens")
|
63
|
+
op.drop_column("token_usage", "total_cost")
|
64
|
+
# ### end Alembic commands ###
|
tokenator/models.py
CHANGED
@@ -1,17 +1,44 @@
|
|
1
1
|
from pydantic import BaseModel, Field
|
2
|
-
from typing import List
|
2
|
+
from typing import List, Optional
|
3
3
|
|
4
4
|
|
5
5
|
class TokenRate(BaseModel):
|
6
6
|
prompt: float = Field(..., description="Cost per prompt token")
|
7
7
|
completion: float = Field(..., description="Cost per completion token")
|
8
|
+
prompt_audio: Optional[float] = Field(
|
9
|
+
None, description="Cost per audio prompt token"
|
10
|
+
)
|
11
|
+
completion_audio: Optional[float] = Field(
|
12
|
+
None, description="Cost per audio completion token"
|
13
|
+
)
|
14
|
+
prompt_cached_input: Optional[float] = Field(
|
15
|
+
None, description="Cost per cached prompt input token"
|
16
|
+
)
|
17
|
+
prompt_cached_creation: Optional[float] = Field(
|
18
|
+
None, description="Cost per cached prompt creation token"
|
19
|
+
)
|
20
|
+
|
21
|
+
|
22
|
+
class PromptTokenDetails(BaseModel):
|
23
|
+
cached_input_tokens: Optional[int] = None
|
24
|
+
cached_creation_tokens: Optional[int] = None
|
25
|
+
audio_tokens: Optional[int] = None
|
26
|
+
|
27
|
+
|
28
|
+
class CompletionTokenDetails(BaseModel):
|
29
|
+
reasoning_tokens: Optional[int] = None
|
30
|
+
audio_tokens: Optional[int] = None
|
31
|
+
accepted_prediction_tokens: Optional[int] = None
|
32
|
+
rejected_prediction_tokens: Optional[int] = None
|
8
33
|
|
9
34
|
|
10
35
|
class TokenMetrics(BaseModel):
|
11
|
-
total_cost: float = Field(
|
12
|
-
total_tokens: int = Field(
|
13
|
-
prompt_tokens: int = Field(
|
14
|
-
completion_tokens: int = Field(
|
36
|
+
total_cost: float = Field(default=0, description="Total cost in USD")
|
37
|
+
total_tokens: int = Field(default=0, description="Total tokens used")
|
38
|
+
prompt_tokens: int = Field(default=0, description="Number of prompt tokens")
|
39
|
+
completion_tokens: int = Field(default=0, description="Number of completion tokens")
|
40
|
+
prompt_tokens_details: Optional[PromptTokenDetails] = None
|
41
|
+
completion_tokens_details: Optional[CompletionTokenDetails] = None
|
15
42
|
|
16
43
|
|
17
44
|
class ModelUsage(TokenMetrics):
|
@@ -31,12 +58,6 @@ class TokenUsageReport(TokenMetrics):
|
|
31
58
|
)
|
32
59
|
|
33
60
|
|
34
|
-
class Usage(BaseModel):
|
35
|
-
prompt_tokens: int = 0
|
36
|
-
completion_tokens: int = 0
|
37
|
-
total_tokens: int = 0
|
38
|
-
|
39
|
-
|
40
61
|
class TokenUsageStats(BaseModel):
|
41
62
|
model: str
|
42
|
-
usage:
|
63
|
+
usage: TokenMetrics
|
@@ -6,9 +6,18 @@ import logging
|
|
6
6
|
from openai import AsyncOpenAI, OpenAI
|
7
7
|
from openai.types.chat import ChatCompletion, ChatCompletionChunk
|
8
8
|
|
9
|
-
from ..models import
|
9
|
+
from ..models import (
|
10
|
+
TokenMetrics,
|
11
|
+
TokenUsageStats,
|
12
|
+
PromptTokenDetails,
|
13
|
+
CompletionTokenDetails,
|
14
|
+
)
|
10
15
|
from ..base_wrapper import BaseWrapper, ResponseType
|
11
|
-
from .stream_interceptors import
|
16
|
+
from .stream_interceptors import (
|
17
|
+
OpenAIAsyncStreamInterceptor,
|
18
|
+
OpenAISyncStreamInterceptor,
|
19
|
+
)
|
20
|
+
from ..state import is_tokenator_enabled
|
12
21
|
|
13
22
|
logger = logging.getLogger(__name__)
|
14
23
|
|
@@ -26,18 +35,49 @@ class BaseOpenAIWrapper(BaseWrapper):
|
|
26
35
|
if isinstance(response, ChatCompletion):
|
27
36
|
if response.usage is None:
|
28
37
|
return None
|
29
|
-
usage =
|
38
|
+
usage = TokenMetrics(
|
30
39
|
prompt_tokens=response.usage.prompt_tokens,
|
31
40
|
completion_tokens=response.usage.completion_tokens,
|
32
41
|
total_tokens=response.usage.total_tokens,
|
42
|
+
prompt_tokens_details=PromptTokenDetails(
|
43
|
+
cached_input_tokens=getattr(
|
44
|
+
response.usage.prompt_tokens_details, "cached_tokens", None
|
45
|
+
),
|
46
|
+
audio_tokens=getattr(
|
47
|
+
response.usage.prompt_tokens_details, "audio_tokens", None
|
48
|
+
),
|
49
|
+
),
|
50
|
+
completion_tokens_details=CompletionTokenDetails(
|
51
|
+
reasoning_tokens=getattr(
|
52
|
+
response.usage.completion_tokens_details,
|
53
|
+
"reasoning_tokens",
|
54
|
+
None,
|
55
|
+
),
|
56
|
+
audio_tokens=getattr(
|
57
|
+
response.usage.completion_tokens_details,
|
58
|
+
"audio_tokens",
|
59
|
+
None,
|
60
|
+
),
|
61
|
+
accepted_prediction_tokens=getattr(
|
62
|
+
response.usage.completion_tokens_details,
|
63
|
+
"accepted_prediction_tokens",
|
64
|
+
None,
|
65
|
+
),
|
66
|
+
rejected_prediction_tokens=getattr(
|
67
|
+
response.usage.completion_tokens_details,
|
68
|
+
"rejected_prediction_tokens",
|
69
|
+
None,
|
70
|
+
),
|
71
|
+
),
|
33
72
|
)
|
73
|
+
|
34
74
|
return TokenUsageStats(model=response.model, usage=usage)
|
35
75
|
|
36
76
|
elif isinstance(response, dict):
|
37
77
|
usage_dict = response.get("usage")
|
38
78
|
if not usage_dict:
|
39
79
|
return None
|
40
|
-
usage =
|
80
|
+
usage = TokenMetrics(
|
41
81
|
prompt_tokens=usage_dict.get("prompt_tokens", 0),
|
42
82
|
completion_tokens=usage_dict.get("completion_tokens", 0),
|
43
83
|
total_tokens=usage_dict.get("total_tokens", 0),
|
@@ -58,6 +98,10 @@ class BaseOpenAIWrapper(BaseWrapper):
|
|
58
98
|
def completions(self):
|
59
99
|
return self
|
60
100
|
|
101
|
+
@property
|
102
|
+
def beta(self):
|
103
|
+
return self
|
104
|
+
|
61
105
|
|
62
106
|
def _create_usage_callback(execution_id, log_usage_fn):
|
63
107
|
"""Creates a callback function for processing usage statistics from stream chunks."""
|
@@ -65,10 +109,18 @@ def _create_usage_callback(execution_id, log_usage_fn):
|
|
65
109
|
def usage_callback(chunks):
|
66
110
|
if not chunks:
|
67
111
|
return
|
112
|
+
|
113
|
+
# Skip if tokenator is disabled
|
114
|
+
if not is_tokenator_enabled:
|
115
|
+
logger.debug("Tokenator is disabled - skipping stream usage logging")
|
116
|
+
return
|
117
|
+
|
118
|
+
logger.debug("Processing stream usage for execution_id: %s", execution_id)
|
119
|
+
|
68
120
|
# Build usage_data from the first chunk's model
|
69
121
|
usage_data = TokenUsageStats(
|
70
122
|
model=chunks[0].model,
|
71
|
-
usage=
|
123
|
+
usage=TokenMetrics(),
|
72
124
|
)
|
73
125
|
# Sum up usage from all chunks
|
74
126
|
has_usage = False
|
@@ -106,6 +158,26 @@ class OpenAIWrapper(BaseOpenAIWrapper):
|
|
106
158
|
|
107
159
|
return response
|
108
160
|
|
161
|
+
def parse(
|
162
|
+
self, *args: Any, execution_id: Optional[str] = None, **kwargs: Any
|
163
|
+
) -> Union[ChatCompletion, Iterator[ChatCompletion]]:
|
164
|
+
"""Create a chat completion parse and log token usage."""
|
165
|
+
logger.debug("Creating chat completion with args: %s, kwargs: %s", args, kwargs)
|
166
|
+
|
167
|
+
if kwargs.get("stream", False):
|
168
|
+
base_stream = self.client.beta.chat.completions.parse(*args, **kwargs)
|
169
|
+
return OpenAISyncStreamInterceptor(
|
170
|
+
base_stream=base_stream,
|
171
|
+
usage_callback=_create_usage_callback(execution_id, self._log_usage),
|
172
|
+
)
|
173
|
+
|
174
|
+
response = self.client.beta.chat.completions.parse(*args, **kwargs)
|
175
|
+
usage_data = self._process_response_usage(response)
|
176
|
+
if usage_data:
|
177
|
+
self._log_usage(usage_data, execution_id=execution_id)
|
178
|
+
|
179
|
+
return response
|
180
|
+
|
109
181
|
|
110
182
|
class AsyncOpenAIWrapper(BaseOpenAIWrapper):
|
111
183
|
async def create(
|
@@ -131,6 +203,26 @@ class AsyncOpenAIWrapper(BaseOpenAIWrapper):
|
|
131
203
|
self._log_usage(usage_data, execution_id=execution_id)
|
132
204
|
return response
|
133
205
|
|
206
|
+
async def parse(
|
207
|
+
self, *args: Any, execution_id: Optional[str] = None, **kwargs: Any
|
208
|
+
) -> Union[ChatCompletion, AsyncIterator[ChatCompletionChunk]]:
|
209
|
+
"""Create a chat completion parse and log token usage."""
|
210
|
+
logger.debug("Creating chat completion with args: %s, kwargs: %s", args, kwargs)
|
211
|
+
|
212
|
+
if kwargs.get("stream", False):
|
213
|
+
base_stream = await self.client.beta.chat.completions.parse(*args, **kwargs)
|
214
|
+
return OpenAIAsyncStreamInterceptor(
|
215
|
+
base_stream=base_stream,
|
216
|
+
usage_callback=_create_usage_callback(execution_id, self._log_usage),
|
217
|
+
)
|
218
|
+
|
219
|
+
response = await self.client.beta.chat.completions.parse(*args, **kwargs)
|
220
|
+
usage_data = self._process_response_usage(response)
|
221
|
+
if usage_data:
|
222
|
+
self._log_usage(usage_data, execution_id=execution_id)
|
223
|
+
|
224
|
+
return response
|
225
|
+
|
134
226
|
|
135
227
|
@overload
|
136
228
|
def tokenator_openai(
|
tokenator/schemas.py
CHANGED
@@ -1,25 +1,27 @@
|
|
1
1
|
"""SQLAlchemy models for tokenator."""
|
2
2
|
|
3
3
|
from datetime import datetime
|
4
|
+
from typing import Optional
|
4
5
|
|
5
6
|
from sqlalchemy import create_engine, Column, Integer, String, DateTime, Index
|
6
7
|
from sqlalchemy.orm import sessionmaker, scoped_session, declarative_base
|
7
8
|
|
8
9
|
from .utils import get_default_db_path
|
10
|
+
from . import state # Import state to access db_path
|
9
11
|
|
10
12
|
Base = declarative_base()
|
11
13
|
|
12
14
|
|
13
|
-
def get_engine(db_path: str = None):
|
15
|
+
def get_engine(db_path: Optional[str] = None):
|
14
16
|
"""Create SQLAlchemy engine with the given database path."""
|
15
17
|
if db_path is None:
|
16
|
-
db_path = get_default_db_path()
|
18
|
+
db_path = state.db_path or get_default_db_path() # Use state.db_path if set
|
17
19
|
return create_engine(f"sqlite:///{db_path}", echo=False)
|
18
20
|
|
19
21
|
|
20
|
-
def get_session(
|
22
|
+
def get_session():
|
21
23
|
"""Create a thread-safe session factory."""
|
22
|
-
engine = get_engine(
|
24
|
+
engine = get_engine()
|
23
25
|
# Base.metadata.create_all(engine)
|
24
26
|
session_factory = sessionmaker(bind=engine)
|
25
27
|
return scoped_session(session_factory)
|
@@ -38,28 +40,28 @@ class TokenUsage(Base):
|
|
38
40
|
updated_at = Column(
|
39
41
|
DateTime, nullable=False, default=datetime.now, onupdate=datetime.now
|
40
42
|
)
|
43
|
+
|
44
|
+
# Core metrics (mandatory)
|
45
|
+
total_cost = Column(Integer, nullable=False)
|
41
46
|
prompt_tokens = Column(Integer, nullable=False)
|
42
47
|
completion_tokens = Column(Integer, nullable=False)
|
43
48
|
total_tokens = Column(Integer, nullable=False)
|
44
49
|
|
45
|
-
#
|
50
|
+
# Prompt token details (optional)
|
51
|
+
prompt_cached_input_tokens = Column(Integer, nullable=True)
|
52
|
+
prompt_cached_creation_tokens = Column(Integer, nullable=True)
|
53
|
+
prompt_audio_tokens = Column(Integer, nullable=True)
|
54
|
+
|
55
|
+
# Completion token details (optional)
|
56
|
+
completion_audio_tokens = Column(Integer, nullable=True)
|
57
|
+
completion_reasoning_tokens = Column(Integer, nullable=True)
|
58
|
+
completion_accepted_prediction_tokens = Column(Integer, nullable=True)
|
59
|
+
completion_rejected_prediction_tokens = Column(Integer, nullable=True)
|
60
|
+
|
61
|
+
# Keep existing indexes
|
46
62
|
__table_args__ = (
|
47
63
|
Index("idx_created_at", "created_at"),
|
48
64
|
Index("idx_execution_id", "execution_id"),
|
49
65
|
Index("idx_provider", "provider"),
|
50
66
|
Index("idx_model", "model"),
|
51
67
|
)
|
52
|
-
|
53
|
-
def to_dict(self):
|
54
|
-
"""Convert model instance to dictionary."""
|
55
|
-
return {
|
56
|
-
"id": self.id,
|
57
|
-
"execution_id": self.execution_id,
|
58
|
-
"provider": self.provider,
|
59
|
-
"model": self.model,
|
60
|
-
"created_at": self.created_at,
|
61
|
-
"updated_at": self.updated_at,
|
62
|
-
"prompt_tokens": self.prompt_tokens,
|
63
|
-
"completion_tokens": self.completion_tokens,
|
64
|
-
"total_tokens": self.total_tokens,
|
65
|
-
}
|
tokenator/state.py
ADDED
@@ -0,0 +1,12 @@
|
|
1
|
+
"""Global state for tokenator."""
|
2
|
+
|
3
|
+
import logging
|
4
|
+
from typing import Optional
|
5
|
+
|
6
|
+
logger = logging.getLogger(__name__)
|
7
|
+
|
8
|
+
# Global flag to track if tokenator is properly initialized
|
9
|
+
is_tokenator_enabled = True
|
10
|
+
|
11
|
+
# Store the database path
|
12
|
+
db_path: Optional[str] = None
|