tokenator 0.1.9__py3-none-any.whl → 0.1.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tokenator/__init__.py +2 -2
- tokenator/anthropic/client_anthropic.py +155 -0
- tokenator/anthropic/stream_interceptors.py +146 -0
- tokenator/base_wrapper.py +26 -13
- tokenator/create_migrations.py +6 -5
- tokenator/migrations/env.py +5 -4
- tokenator/migrations/versions/f6f1f2437513_initial_migration.py +25 -23
- tokenator/migrations.py +9 -6
- tokenator/models.py +15 -4
- tokenator/openai/client_openai.py +66 -70
- tokenator/openai/stream_interceptors.py +146 -0
- tokenator/schemas.py +26 -27
- tokenator/usage.py +114 -47
- tokenator/utils.py +14 -9
- {tokenator-0.1.9.dist-info → tokenator-0.1.10.dist-info}/METADATA +16 -14
- tokenator-0.1.10.dist-info/RECORD +19 -0
- tokenator/client_anthropic.py +0 -148
- tokenator/openai/AsyncStreamInterceptor.py +0 -78
- tokenator-0.1.9.dist-info/RECORD +0 -18
- {tokenator-0.1.9.dist-info → tokenator-0.1.10.dist-info}/LICENSE +0 -0
- {tokenator-0.1.9.dist-info → tokenator-0.1.10.dist-info}/WHEEL +0 -0
tokenator/usage.py
CHANGED
@@ -1,9 +1,8 @@
 """Cost analysis functions for token usage."""
 
-from datetime import datetime, timedelta
+from datetime import datetime, timedelta
 from typing import Dict, Optional, Union
 
-from sqlalchemy import and_
 
 from .schemas import get_session, TokenUsage
 from .models import TokenRate, TokenUsageReport, ModelUsage, ProviderUsage
@@ -13,48 +12,63 @@ import logging
 
 logger = logging.getLogger(__name__)
 
+
 def _get_model_costs() -> Dict[str, TokenRate]:
     url = "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json"
     response = requests.get(url)
     data = response.json()
-
+
     return {
         model: TokenRate(
             prompt=info["input_cost_per_token"],
-            completion=info["output_cost_per_token"]
+            completion=info["output_cost_per_token"],
         )
         for model, info in data.items()
         if "input_cost_per_token" in info and "output_cost_per_token" in info
     }
 
+
 MODEL_COSTS = _get_model_costs()
 
-def _calculate_cost(usages: list[TokenUsage], provider: Optional[str] = None) -> TokenUsageReport:
+
+def _calculate_cost(
+    usages: list[TokenUsage], provider: Optional[str] = None
+) -> TokenUsageReport:
     """Calculate cost from token usage records."""
     # Group usages by provider and model
     provider_model_usages: Dict[str, Dict[str, list[TokenUsage]]] = {}
 
     print(f"usages: {len(usages)}")
-
+
     for usage in usages:
         if usage.model not in MODEL_COSTS:
             continue
-
+
         provider = usage.provider
         if provider not in provider_model_usages:
             provider_model_usages[provider] = {}
-
+
         if usage.model not in provider_model_usages[provider]:
             provider_model_usages[provider][usage.model] = []
-
+
         provider_model_usages[provider][usage.model].append(usage)
 
     # Calculate totals for each level
     providers_list = []
-    total_metrics = {"total_cost": 0.0, "total_tokens": 0, "prompt_tokens": 0, "completion_tokens": 0}
+    total_metrics = {
+        "total_cost": 0.0,
+        "total_tokens": 0,
+        "prompt_tokens": 0,
+        "completion_tokens": 0,
+    }
 
     for provider, model_usages in provider_model_usages.items():
-        provider_metrics = {"total_cost": 0.0, "total_tokens": 0, "prompt_tokens": 0, "completion_tokens": 0}
+        provider_metrics = {
+            "total_cost": 0.0,
+            "total_tokens": 0,
+            "prompt_tokens": 0,
+            "completion_tokens": 0,
+        }
         models_list = []
 
         for model, usages in model_usages.items():
@@ -67,17 +81,21 @@ def _calculate_cost(usages: list[TokenUsage], provider: Optional[str] = None) -> TokenUsageReport:
                 model_prompt += usage.prompt_tokens
                 model_completion += usage.completion_tokens
                 model_total += usage.total_tokens
-
-                model_cost += usage.prompt_tokens * MODEL_COSTS[usage.model].prompt
-                model_cost += (usage.completion_tokens * MODEL_COSTS[usage.model].completion)
-
-            models_list.append(ModelUsage(
-                model=model,
-                total_cost=round(model_cost, 6),
-                total_tokens=model_total,
-                prompt_tokens=model_prompt,
-                completion_tokens=model_completion
-            ))
+
+                model_cost += usage.prompt_tokens * MODEL_COSTS[usage.model].prompt
+                model_cost += (
+                    usage.completion_tokens * MODEL_COSTS[usage.model].completion
+                )
+
+            models_list.append(
+                ModelUsage(
+                    model=model,
+                    total_cost=round(model_cost, 6),
+                    total_tokens=model_total,
+                    prompt_tokens=model_prompt,
+                    completion_tokens=model_completion,
+                )
+            )
 
             # Add to provider totals
             provider_metrics["total_cost"] += model_cost
@@ -85,11 +103,16 @@ def _calculate_cost(usages: list[TokenUsage], provider: Optional[str] = None) -> TokenUsageReport:
         provider_metrics["prompt_tokens"] += model_prompt
         provider_metrics["completion_tokens"] += model_completion
 
-        providers_list.append(ProviderUsage(
-            provider=provider,
-            models=models_list,
-            **{k: (round(v, 6) if k == "total_cost" else v) for k, v in provider_metrics.items()}
-        ))
+        providers_list.append(
+            ProviderUsage(
+                provider=provider,
+                models=models_list,
+                **{
+                    k: (round(v, 6) if k == "total_cost" else v)
+                    for k, v in provider_metrics.items()
+                },
+            )
+        )
 
     # Add to grand totals
     for key in total_metrics:
@@ -97,76 +120,110 @@ def _calculate_cost(usages: list[TokenUsage], provider: Optional[str] = None) -> TokenUsageReport:
 
     return TokenUsageReport(
         providers=providers_list,
-        **{k: (round(v, 6) if k == "total_cost" else v) for k, v in total_metrics.items()}
+        **{
+            k: (round(v, 6) if k == "total_cost" else v)
+            for k, v in total_metrics.items()
+        },
     )
 
-
-
-def _query_usage(start_date: datetime, end_date: datetime, provider: Optional[str] = None, model: Optional[str] = None) -> TokenUsageReport:
+
+
+def _query_usage(
+    start_date: datetime,
+    end_date: datetime,
+    provider: Optional[str] = None,
+    model: Optional[str] = None,
+) -> TokenUsageReport:
     """Query token usage for a specific time period."""
     session = get_session()()
     try:
         query = session.query(TokenUsage).filter(
             TokenUsage.created_at.between(start_date, end_date)
         )
-
+
         if provider:
             query = query.filter(TokenUsage.provider == provider)
         if model:
             query = query.filter(TokenUsage.model == model)
-
+
         usages = query.all()
         return _calculate_cost(usages, provider or "all")
     finally:
         session.close()
 
-def last_hour(provider: Optional[str] = None, model: Optional[str] = None) -> TokenUsageReport:
+
+def last_hour(
+    provider: Optional[str] = None, model: Optional[str] = None
+) -> TokenUsageReport:
     """Get cost analysis for the last hour."""
-    logger.debug(f"Getting cost analysis for last hour (provider={provider}, model={model})")
+    logger.debug(
+        f"Getting cost analysis for last hour (provider={provider}, model={model})"
+    )
     end = datetime.now()
     start = end - timedelta(hours=1)
     return _query_usage(start, end, provider, model)
 
-def last_day(provider: Optional[str] = None, model: Optional[str] = None) -> TokenUsageReport:
+
+def last_day(
+    provider: Optional[str] = None, model: Optional[str] = None
+) -> TokenUsageReport:
     """Get cost analysis for the last 24 hours."""
-    logger.debug(f"Getting cost analysis for last 24 hours (provider={provider}, model={model})")
+    logger.debug(
+        f"Getting cost analysis for last 24 hours (provider={provider}, model={model})"
+    )
     end = datetime.now()
     start = end - timedelta(days=1)
     return _query_usage(start, end, provider, model)
 
-def last_week(provider: Optional[str] = None, model: Optional[str] = None) -> TokenUsageReport:
+
+def last_week(
+    provider: Optional[str] = None, model: Optional[str] = None
+) -> TokenUsageReport:
     """Get cost analysis for the last 7 days."""
-    logger.debug(f"Getting cost analysis for last 7 days (provider={provider}, model={model})")
+    logger.debug(
+        f"Getting cost analysis for last 7 days (provider={provider}, model={model})"
+    )
     end = datetime.now()
     start = end - timedelta(weeks=1)
     return _query_usage(start, end, provider, model)
 
-def last_month(provider: Optional[str] = None, model: Optional[str] = None) -> TokenUsageReport:
+
+def last_month(
+    provider: Optional[str] = None, model: Optional[str] = None
+) -> TokenUsageReport:
     """Get cost analysis for the last 30 days."""
-    logger.debug(f"Getting cost analysis for last 30 days (provider={provider}, model={model})")
+    logger.debug(
+        f"Getting cost analysis for last 30 days (provider={provider}, model={model})"
+    )
     end = datetime.now()
     start = end - timedelta(days=30)
     return _query_usage(start, end, provider, model)
 
+
 def between(
     start_date: Union[datetime, str],
     end_date: Union[datetime, str],
     provider: Optional[str] = None,
-    model: Optional[str] = None
+    model: Optional[str] = None,
 ) -> TokenUsageReport:
     """Get cost analysis between two dates.
-
+
     Args:
         start_date: datetime object or string (format: YYYY-MM-DD or YYYY-MM-DD HH:MM:SS)
         end_date: datetime object or string (format: YYYY-MM-DD or YYYY-MM-DD HH:MM:SS)
     """
-    logger.debug(f"Getting cost analysis between {start_date} and {end_date} (provider={provider}, model={model})")
-
+    logger.debug(
+        f"Getting cost analysis between {start_date} and {end_date} (provider={provider}, model={model})"
+    )
+
     if isinstance(start_date, str):
         try:
             start = datetime.strptime(start_date, "%Y-%m-%d %H:%M:%S")
         except ValueError:
+            logger.warning(
+                f"Date-only string provided for start_date: {start_date}. Setting time to 00:00:00"
+            )
             start = datetime.strptime(start_date, "%Y-%m-%d")
+
     else:
         start = start_date
 
@@ -174,12 +231,20 @@ def between(
         try:
             end = datetime.strptime(end_date, "%Y-%m-%d %H:%M:%S")
         except ValueError:
-            end = datetime.strptime(end_date, "%Y-%m-%d")
+            logger.warning(
+                f"Date-only string provided for end_date: {end_date}. Setting time to 23:59:59"
+            )
+            end = (
+                datetime.strptime(end_date, "%Y-%m-%d")
+                + timedelta(days=1)
+                - timedelta(seconds=1)
+            )
     else:
         end = end_date
 
     return _query_usage(start, end, provider, model)
 
+
 def for_execution(execution_id: str) -> TokenUsageReport:
     """Get cost analysis for a specific execution."""
    logger.debug(f"Getting cost analysis for execution_id={execution_id}")
@@ -187,6 +252,7 @@ def for_execution(execution_id: str) -> TokenUsageReport:
     query = session.query(TokenUsage).filter(TokenUsage.execution_id == execution_id)
     return _calculate_cost(query.all())
 
+
 def last_execution() -> TokenUsageReport:
     """Get cost analysis for the last execution_id."""
     logger.debug("Getting cost analysis for last execution")
@@ -194,9 +260,10 @@ def last_execution() -> TokenUsageReport:
     query = session.query(TokenUsage).order_by(TokenUsage.created_at.desc()).first()
     return for_execution(query.execution_id)
 
+
 def all_time() -> TokenUsageReport:
     """Get cost analysis for all time."""
     logger.warning("Getting cost analysis for all time. This may take a while...")
     session = get_session()()
     query = session.query(TokenUsage).all()
-    return for_execution(query.execution_id)
+    return for_execution(query.execution_id)
tokenator/utils.py
CHANGED
@@ -4,27 +4,29 @@ import os
 import platform
 import logging
 from pathlib import Path
-from typing import Optional
 
 logger = logging.getLogger(__name__)
 
+
 def is_colab() -> bool:
     """Check if running in Google Colab."""
     try:
-        import google.colab
-        return True
+        from importlib.util import find_spec
+
+        return find_spec("google.colab") is not None
     except ImportError:
         return False
 
+
 def get_default_db_path() -> str:
     """Get the platform-specific default database path."""
     try:
         if is_colab():
             # Use in-memory database for Colab
             return "usage.db"
-
+
         system = platform.system().lower()
-
+
         if system == "linux" or system == "darwin":
             # Follow XDG Base Directory Specification
             xdg_data_home = os.environ.get("XDG_DATA_HOME", "")
@@ -39,18 +41,21 @@ def get_default_db_path() -> str:
             db_path = os.path.join(local_app_data, "tokenator", "usage.db")
         else:
             db_path = os.path.join(str(Path.home()), ".tokenator", "usage.db")
-
+
         # Create directory if it doesn't exist
         os.makedirs(os.path.dirname(db_path), exist_ok=True)
         return db_path
     except (OSError, IOError) as e:
         # Fallback to current directory if we can't create the default path
         fallback_path = os.path.join(os.getcwd(), "tokenator_usage.db")
-        logger.warning(f"Could not create default db path, falling back to {fallback_path}. Error: {e}")
-        return fallback_path
+        logger.warning(
+            f"Could not create default db path, falling back to {fallback_path}. Error: {e}"
+        )
+        return fallback_path
+
 
 __all__ = [
     "get_default_db_path",
     "is_colab",
     # ... other exports ...
-]
+]
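The notable change here is `is_colab()`: it now checks for `google.colab` with `importlib.util.find_spec` rather than importing the module outright (the removed lines are lost in this diff view; the plain-import form shown above is an assumption). A standalone sketch of the new check; `find_spec` raises `ModuleNotFoundError` (a subclass of `ImportError`) when the parent `google` package is absent, which is why the `except` clause survives:

```python
from importlib.util import find_spec


def is_colab() -> bool:
    """Detect google.colab without importing it (no import side effects)."""
    try:
        return find_spec("google.colab") is not None
    except ImportError:
        # Raised when even the parent "google" package is not installed.
        return False


print(is_colab())  # False outside Colab
```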
{tokenator-0.1.9.dist-info → tokenator-0.1.10.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: tokenator
-Version: 0.1.9
+Version: 0.1.10
 Summary: Token usage tracking wrapper for LLMs
 License: MIT
 Author: Ujjwal Maheshwari
@@ -60,29 +60,31 @@ response = client.chat.completions.create(
 ### Cost Analysis
 
 ```python
-from tokenator import cost
+from tokenator import usage
 
 # Get usage for different time periods
-cost.last_hour()
-cost.last_day()
-cost.last_week()
-cost.last_month()
+usage.last_hour()
+usage.last_day()
+usage.last_week()
+usage.last_month()
 
 # Custom date range
-cost.between("2024-03-01", "2024-03-15")
+usage.between("2024-03-01", "2024-03-15")
 
 # Get usage for different LLM providers
-cost.last_day("openai")
-cost.last_day("anthropic")
-cost.last_day("google")
+usage.last_day("openai")
+usage.last_day("anthropic")
+usage.last_day("google")
 ```
 
-### Example `cost` object
+### Example `usage` object
 
-```
-print(cost.last_hour().model_dump_json(indent=4))
+```python
+print(cost.last_hour().model_dump_json(indent=4))
+```
 
-{
+```json
+{
 "total_cost": 0.0004,
 "total_tokens": 79,
 "prompt_tokens": 52,
tokenator-0.1.10.dist-info/RECORD
ADDED
@@ -0,0 +1,19 @@
+tokenator/__init__.py,sha256=bIAPyGAvWreS2i_5tzxJEyX9JlZgAUNxzVk1iHNUhvU,593
+tokenator/anthropic/client_anthropic.py,sha256=fcKxGsLex99II-WD9SVNI5QVzH0IEWRmVLjyvZd9wKs,5936
+tokenator/anthropic/stream_interceptors.py,sha256=4VHC_-WkG3Pa10YizmFLrHcbz0Tm2MR_YB5-uohKp5A,5221
+tokenator/base_wrapper.py,sha256=VYSkQB1MEudgzBX60T-VAMsNg4fFx7IRzpadzjm4klE,2466
+tokenator/create_migrations.py,sha256=k9IHiGK21dLTA8MYNsuhO0-kUVIcMSViMFYtY4WU2Rw,730
+tokenator/migrations/env.py,sha256=JoF5MJ4ae0wJW5kdBHuFlG3ZqeCCDvbMcU8fNA_a6hM,1396
+tokenator/migrations/script.py.mako,sha256=nJL-tbLQE0Qy4P9S4r4ntNAcikPtoFUlvXe6xvm9ot8,635
+tokenator/migrations/versions/f6f1f2437513_initial_migration.py,sha256=4cveHkwSxs-hxOPCm81YfvGZTkJJ2ClAFmyL98-1VCo,1910
+tokenator/migrations.py,sha256=YAf9gZmDzAq36PWWXPtdUQoJFYPXtIDzflC79H6gcJg,1114
+tokenator/models.py,sha256=MhYwCvmqposUNDRxFZNAVnzCqBTHxNL3Hp0MNFXM5ck,1201
+tokenator/openai/client_openai.py,sha256=Umfxha3BhBFU_JebPjyuaUZEZuPqJWQo1xTCuAy3R24,5691
+tokenator/openai/stream_interceptors.py,sha256=ez1MnjRZW_rEalv2SIPAvrU9oMD6OJoD9vht-057fDM,5243
+tokenator/schemas.py,sha256=Ye8hqZlrm3Gh2FyvOVX-hWCpKynWxS58QQRQMfDtIAQ,2114
+tokenator/usage.py,sha256=eTWfcRrTLop-30FmwHpi7_GwCJxU6Qfji374hG1Qptw,8476
+tokenator/utils.py,sha256=xg9l2GV1yJL1BlxKL1r8CboABWDslf3G5rGQEJSjFrE,1973
+tokenator-0.1.10.dist-info/LICENSE,sha256=wdG-B6-ODk8RQ4jq5uXSn0w1UWTzCH_MMyvh7AwtGns,1074
+tokenator-0.1.10.dist-info/METADATA,sha256=ryILkOYlq8V8219sVmK0xUeEEw51msw_FCoF_3VJ_k8,3108
+tokenator-0.1.10.dist-info/WHEEL,sha256=Nq82e9rUAnEjt98J6MlVmMCZb-t9cYE2Ir1kpBmnWfs,88
+tokenator-0.1.10.dist-info/RECORD,,
tokenator/client_anthropic.py
DELETED
@@ -1,148 +0,0 @@
-"""Anthropic client wrapper with token usage tracking."""
-
-from typing import Any, Dict, Optional, TypeVar, Union, overload, Iterator, AsyncIterator
-import logging
-
-from anthropic import Anthropic, AsyncAnthropic
-from anthropic.types import Message, RawMessageStartEvent, RawMessageDeltaEvent
-
-from .models import Usage, TokenUsageStats
-from .base_wrapper import BaseWrapper, ResponseType
-
-logger = logging.getLogger(__name__)
-
-class BaseAnthropicWrapper(BaseWrapper):
-    provider = "anthropic"
-
-    def _process_response_usage(self, response: ResponseType) -> Optional[TokenUsageStats]:
-        """Process and log usage statistics from a response."""
-        try:
-            if isinstance(response, Message):
-                if not hasattr(response, 'usage'):
-                    return None
-                usage = Usage(
-                    prompt_tokens=response.usage.input_tokens,
-                    completion_tokens=response.usage.output_tokens,
-                    total_tokens=response.usage.input_tokens + response.usage.output_tokens
-                )
-                return TokenUsageStats(model=response.model, usage=usage)
-            elif isinstance(response, dict):
-                usage_dict = response.get('usage')
-                if not usage_dict:
-                    return None
-                usage = Usage(
-                    prompt_tokens=usage_dict.get('input_tokens', 0),
-                    completion_tokens=usage_dict.get('output_tokens', 0),
-                    total_tokens=usage_dict.get('input_tokens', 0) + usage_dict.get('output_tokens', 0)
-                )
-                return TokenUsageStats(
-                    model=response.get('model', 'unknown'),
-                    usage=usage
-                )
-        except Exception as e:
-            logger.warning("Failed to process usage stats: %s", str(e))
-            return None
-        return None
-
-    @property
-    def messages(self):
-        return self
-
-class AnthropicWrapper(BaseAnthropicWrapper):
-    def create(self, *args: Any, execution_id: Optional[str] = None, **kwargs: Any) -> Union[Message, Iterator[Message]]:
-        """Create a message completion and log token usage."""
-        logger.debug("Creating message completion with args: %s, kwargs: %s", args, kwargs)
-
-        response = self.client.messages.create(*args, **kwargs)
-
-        if not kwargs.get('stream', False):
-            usage_data = self._process_response_usage(response)
-            if usage_data:
-                self._log_usage(usage_data, execution_id=execution_id)
-            return response
-
-        return self._wrap_streaming_response(response, execution_id)
-
-    def _wrap_streaming_response(self, response_iter: Iterator[Message], execution_id: Optional[str]) -> Iterator[Message]:
-        """Wrap streaming response to capture final usage stats"""
-        usage_data: TokenUsageStats = TokenUsageStats(model="", usage=Usage())
-        for chunk in response_iter:
-            if isinstance(chunk, RawMessageStartEvent):
-                usage_data.model = chunk.message.model
-                usage_data.usage.prompt_tokens = chunk.message.usage.input_tokens
-                usage_data.usage.completion_tokens = chunk.message.usage.output_tokens
-                usage_data.usage.total_tokens = chunk.message.usage.input_tokens + chunk.message.usage.output_tokens
-
-            elif isinstance(chunk, RawMessageDeltaEvent):
-                usage_data.usage.prompt_tokens += chunk.usage.input_tokens
-                usage_data.usage.completion_tokens += chunk.usage.output_tokens
-                usage_data.usage.total_tokens += chunk.usage.input_tokens + chunk.usage.output_tokens
-
-            yield chunk
-
-        self._log_usage(usage_data, execution_id=execution_id)
-
-class AsyncAnthropicWrapper(BaseAnthropicWrapper):
-    async def create(self, *args: Any, execution_id: Optional[str] = None, **kwargs: Any) -> Union[Message, AsyncIterator[Message]]:
-        """Create a message completion and log token usage."""
-        logger.debug("Creating message completion with args: %s, kwargs: %s", args, kwargs)
-
-        if kwargs.get('stream', False):
-            response = await self.client.messages.create(*args, **kwargs)
-            return self._wrap_streaming_response(response, execution_id)
-
-        response = await self.client.messages.create(*args, **kwargs)
-        usage_data = self._process_response_usage(response)
-        if usage_data:
-            self._log_usage(usage_data, execution_id=execution_id)
-        return response
-
-    async def _wrap_streaming_response(self, response_iter: AsyncIterator[Message], execution_id: Optional[str]) -> AsyncIterator[Message]:
-        """Wrap streaming response to capture final usage stats"""
-        usage_data: TokenUsageStats = TokenUsageStats(model="", usage=Usage())
-        async for chunk in response_iter:
-            if isinstance(chunk, RawMessageStartEvent):
-                usage_data.model = chunk.message.model
-                usage_data.usage.prompt_tokens = chunk.message.usage.input_tokens
-                usage_data.usage.completion_tokens = chunk.message.usage.output_tokens
-                usage_data.usage.total_tokens = chunk.message.usage.input_tokens + chunk.message.usage.output_tokens
-
-            elif isinstance(chunk, RawMessageDeltaEvent):
-                usage_data.usage.prompt_tokens += chunk.usage.input_tokens
-                usage_data.usage.completion_tokens += chunk.usage.output_tokens
-                usage_data.usage.total_tokens += chunk.usage.input_tokens + chunk.usage.output_tokens
-
-            yield chunk
-
-
-        self._log_usage(usage_data, execution_id=execution_id)
-
-@overload
-def tokenator_anthropic(
-    client: Anthropic,
-    db_path: Optional[str] = None,
-) -> AnthropicWrapper: ...
-
-@overload
-def tokenator_anthropic(
-    client: AsyncAnthropic,
-    db_path: Optional[str] = None,
-) -> AsyncAnthropicWrapper: ...
-
-def tokenator_anthropic(
-    client: Union[Anthropic, AsyncAnthropic],
-    db_path: Optional[str] = None,
-) -> Union[AnthropicWrapper, AsyncAnthropicWrapper]:
-    """Create a token-tracking wrapper for an Anthropic client.
-
-    Args:
-        client: Anthropic or AsyncAnthropic client instance
-        db_path: Optional path to SQLite database for token tracking
-    """
-    if isinstance(client, Anthropic):
-        return AnthropicWrapper(client=client, db_path=db_path)
-
-    if isinstance(client, AsyncAnthropic):
-        return AsyncAnthropicWrapper(client=client, db_path=db_path)
-
-    raise ValueError("Client must be an instance of Anthropic or AsyncAnthropic")
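Per the file list above, this module moved to `tokenator/anthropic/client_anthropic.py` rather than disappearing. A usage sketch of the `tokenator_anthropic` entry point defined here, assuming it remains re-exported at the package top level; the model name is illustrative:

```python
from anthropic import Anthropic

from tokenator import tokenator_anthropic  # assumed top-level re-export

# Wrap the client; completed requests are logged to the tokenator SQLite db.
client = tokenator_anthropic(Anthropic())
response = client.messages.create(
    model="claude-3-5-sonnet-20241022",  # illustrative model name
    max_tokens=100,
    messages=[{"role": "user", "content": "hello"}],
)
print(response.usage)  # usage is also recorded by the wrapper
```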
tokenator/openai/AsyncStreamInterceptor.py
DELETED
@@ -1,78 +0,0 @@
-import logging
-from typing import AsyncIterator, Callable, Generic, List, Optional, TypeVar
-
-from openai import AsyncStream, AsyncOpenAI
-from openai.types.chat import ChatCompletionChunk
-
-logger = logging.getLogger(__name__)
-
-_T = TypeVar("_T")  # or you might specifically do _T = ChatCompletionChunk
-
-
-class AsyncStreamInterceptor(AsyncStream[_T]):
-    """
-    A wrapper around openai.AsyncStream that delegates all functionality
-    to the 'base_stream' but intercepts each chunk to handle usage or
-    logging logic. This preserves .response and other methods.
-
-    You can store aggregated usage in a local list and process it when
-    the stream ends (StopAsyncIteration).
-    """
-
-    def __init__(
-        self,
-        base_stream: AsyncStream[_T],
-        usage_callback: Optional[Callable[[List[_T]], None]] = None,
-    ):
-        # We do NOT call super().__init__() because openai.AsyncStream
-        # expects constructor parameters we don't want to re-initialize.
-        # Instead, we just store the base_stream and delegate everything to it.
-        self._base_stream = base_stream
-        self._usage_callback = usage_callback
-        self._chunks: List[_T] = []
-
-    @property
-    def response(self):
-        """Expose the original stream's 'response' so user code can do stream.response, etc."""
-        return self._base_stream.response
-
-    def __aiter__(self) -> AsyncIterator[_T]:
-        """
-        Called when we do 'async for chunk in wrapped_stream:'
-        We simply return 'self'. Then __anext__ does the rest.
-        """
-        return self
-
-    async def __anext__(self) -> _T:
-        """
-        Intercept iteration. We pull the next chunk from the base_stream.
-        If it's the end, do any final usage logging, then raise StopAsyncIteration.
-        Otherwise, we can accumulate usage info or do whatever we need with the chunk.
-        """
-        try:
-            chunk = await self._base_stream.__anext__()
-        except StopAsyncIteration:
-            # Once the base stream is fully consumed, we can do final usage/logging.
-            if self._usage_callback and self._chunks:
-                self._usage_callback(self._chunks)
-            raise
-
-        # Intercept each chunk
-        self._chunks.append(chunk)
-        return chunk
-
-    async def __aenter__(self) -> "AsyncStreamInterceptor[_T]":
-        """Support async with ... : usage."""
-        await self._base_stream.__aenter__()
-        return self
-
-    async def __aexit__(self, exc_type, exc_val, exc_tb):
-        """
-        Ensure we propagate __aexit__ to the base stream,
-        so connections are properly closed.
-        """
-        return await self._base_stream.__aexit__(exc_type, exc_val, exc_tb)
-
-    async def close(self) -> None:
-        """Delegate close to the base_stream."""
-        await self._base_stream.close()