tokenator 0.1.8__py3-none-any.whl → 0.1.10__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- tokenator/__init__.py +3 -3
- tokenator/anthropic/client_anthropic.py +155 -0
- tokenator/anthropic/stream_interceptors.py +146 -0
- tokenator/base_wrapper.py +26 -13
- tokenator/create_migrations.py +6 -5
- tokenator/migrations/env.py +5 -4
- tokenator/migrations/versions/f6f1f2437513_initial_migration.py +25 -23
- tokenator/migrations.py +9 -6
- tokenator/models.py +15 -4
- tokenator/openai/client_openai.py +163 -0
- tokenator/openai/stream_interceptors.py +146 -0
- tokenator/schemas.py +26 -27
- tokenator/usage.py +114 -47
- tokenator/utils.py +14 -9
- {tokenator-0.1.8.dist-info → tokenator-0.1.10.dist-info}/METADATA +40 -13
- tokenator-0.1.10.dist-info/RECORD +19 -0
- tokenator/client_anthropic.py +0 -148
- tokenator/client_openai.py +0 -151
- tokenator-0.1.8.dist-info/RECORD +0 -17
- {tokenator-0.1.8.dist-info → tokenator-0.1.10.dist-info}/LICENSE +0 -0
- {tokenator-0.1.8.dist-info → tokenator-0.1.10.dist-info}/WHEEL +0 -0
tokenator/usage.py
CHANGED
@@ -1,9 +1,8 @@
|
|
1
1
|
"""Cost analysis functions for token usage."""
|
2
2
|
|
3
|
-
from datetime import datetime, timedelta
|
3
|
+
from datetime import datetime, timedelta
|
4
4
|
from typing import Dict, Optional, Union
|
5
5
|
|
6
|
-
from sqlalchemy import and_
|
7
6
|
|
8
7
|
from .schemas import get_session, TokenUsage
|
9
8
|
from .models import TokenRate, TokenUsageReport, ModelUsage, ProviderUsage
|
@@ -13,48 +12,63 @@ import logging
|
|
13
12
|
|
14
13
|
logger = logging.getLogger(__name__)
|
15
14
|
|
15
|
+
|
16
16
|
def _get_model_costs() -> Dict[str, TokenRate]:
|
17
17
|
url = "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json"
|
18
18
|
response = requests.get(url)
|
19
19
|
data = response.json()
|
20
|
-
|
20
|
+
|
21
21
|
return {
|
22
22
|
model: TokenRate(
|
23
23
|
prompt=info["input_cost_per_token"],
|
24
|
-
completion=info["output_cost_per_token"]
|
24
|
+
completion=info["output_cost_per_token"],
|
25
25
|
)
|
26
26
|
for model, info in data.items()
|
27
27
|
if "input_cost_per_token" in info and "output_cost_per_token" in info
|
28
28
|
}
|
29
29
|
|
30
|
+
|
30
31
|
MODEL_COSTS = _get_model_costs()
|
31
32
|
|
32
|
-
|
33
|
+
|
34
|
+
def _calculate_cost(
|
35
|
+
usages: list[TokenUsage], provider: Optional[str] = None
|
36
|
+
) -> TokenUsageReport:
|
33
37
|
"""Calculate cost from token usage records."""
|
34
38
|
# Group usages by provider and model
|
35
39
|
provider_model_usages: Dict[str, Dict[str, list[TokenUsage]]] = {}
|
36
40
|
|
37
41
|
print(f"usages: {len(usages)}")
|
38
|
-
|
42
|
+
|
39
43
|
for usage in usages:
|
40
44
|
if usage.model not in MODEL_COSTS:
|
41
45
|
continue
|
42
|
-
|
46
|
+
|
43
47
|
provider = usage.provider
|
44
48
|
if provider not in provider_model_usages:
|
45
49
|
provider_model_usages[provider] = {}
|
46
|
-
|
50
|
+
|
47
51
|
if usage.model not in provider_model_usages[provider]:
|
48
52
|
provider_model_usages[provider][usage.model] = []
|
49
|
-
|
53
|
+
|
50
54
|
provider_model_usages[provider][usage.model].append(usage)
|
51
55
|
|
52
56
|
# Calculate totals for each level
|
53
57
|
providers_list = []
|
54
|
-
total_metrics = {
|
58
|
+
total_metrics = {
|
59
|
+
"total_cost": 0.0,
|
60
|
+
"total_tokens": 0,
|
61
|
+
"prompt_tokens": 0,
|
62
|
+
"completion_tokens": 0,
|
63
|
+
}
|
55
64
|
|
56
65
|
for provider, model_usages in provider_model_usages.items():
|
57
|
-
provider_metrics = {
|
66
|
+
provider_metrics = {
|
67
|
+
"total_cost": 0.0,
|
68
|
+
"total_tokens": 0,
|
69
|
+
"prompt_tokens": 0,
|
70
|
+
"completion_tokens": 0,
|
71
|
+
}
|
58
72
|
models_list = []
|
59
73
|
|
60
74
|
for model, usages in model_usages.items():
|
@@ -67,17 +81,21 @@ def _calculate_cost(usages: list[TokenUsage], provider: Optional[str] = None) ->
|
|
67
81
|
model_prompt += usage.prompt_tokens
|
68
82
|
model_completion += usage.completion_tokens
|
69
83
|
model_total += usage.total_tokens
|
70
|
-
|
71
|
-
model_cost +=
|
72
|
-
model_cost += (
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
84
|
+
|
85
|
+
model_cost += usage.prompt_tokens * MODEL_COSTS[usage.model].prompt
|
86
|
+
model_cost += (
|
87
|
+
usage.completion_tokens * MODEL_COSTS[usage.model].completion
|
88
|
+
)
|
89
|
+
|
90
|
+
models_list.append(
|
91
|
+
ModelUsage(
|
92
|
+
model=model,
|
93
|
+
total_cost=round(model_cost, 6),
|
94
|
+
total_tokens=model_total,
|
95
|
+
prompt_tokens=model_prompt,
|
96
|
+
completion_tokens=model_completion,
|
97
|
+
)
|
98
|
+
)
|
81
99
|
|
82
100
|
# Add to provider totals
|
83
101
|
provider_metrics["total_cost"] += model_cost
|
@@ -85,11 +103,16 @@ def _calculate_cost(usages: list[TokenUsage], provider: Optional[str] = None) ->
|
|
85
103
|
provider_metrics["prompt_tokens"] += model_prompt
|
86
104
|
provider_metrics["completion_tokens"] += model_completion
|
87
105
|
|
88
|
-
providers_list.append(
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
|
106
|
+
providers_list.append(
|
107
|
+
ProviderUsage(
|
108
|
+
provider=provider,
|
109
|
+
models=models_list,
|
110
|
+
**{
|
111
|
+
k: (round(v, 6) if k == "total_cost" else v)
|
112
|
+
for k, v in provider_metrics.items()
|
113
|
+
},
|
114
|
+
)
|
115
|
+
)
|
93
116
|
|
94
117
|
# Add to grand totals
|
95
118
|
for key in total_metrics:
|
@@ -97,76 +120,110 @@ def _calculate_cost(usages: list[TokenUsage], provider: Optional[str] = None) ->
|
|
97
120
|
|
98
121
|
return TokenUsageReport(
|
99
122
|
providers=providers_list,
|
100
|
-
**{
|
123
|
+
**{
|
124
|
+
k: (round(v, 6) if k == "total_cost" else v)
|
125
|
+
for k, v in total_metrics.items()
|
126
|
+
},
|
101
127
|
)
|
102
128
|
|
103
|
-
|
104
|
-
|
105
|
-
|
129
|
+
|
130
|
+
def _query_usage(
|
131
|
+
start_date: datetime,
|
132
|
+
end_date: datetime,
|
133
|
+
provider: Optional[str] = None,
|
134
|
+
model: Optional[str] = None,
|
135
|
+
) -> TokenUsageReport:
|
106
136
|
"""Query token usage for a specific time period."""
|
107
137
|
session = get_session()()
|
108
138
|
try:
|
109
139
|
query = session.query(TokenUsage).filter(
|
110
140
|
TokenUsage.created_at.between(start_date, end_date)
|
111
141
|
)
|
112
|
-
|
142
|
+
|
113
143
|
if provider:
|
114
144
|
query = query.filter(TokenUsage.provider == provider)
|
115
145
|
if model:
|
116
146
|
query = query.filter(TokenUsage.model == model)
|
117
|
-
|
147
|
+
|
118
148
|
usages = query.all()
|
119
149
|
return _calculate_cost(usages, provider or "all")
|
120
150
|
finally:
|
121
151
|
session.close()
|
122
152
|
|
123
|
-
|
153
|
+
|
154
|
+
def last_hour(
|
155
|
+
provider: Optional[str] = None, model: Optional[str] = None
|
156
|
+
) -> TokenUsageReport:
|
124
157
|
"""Get cost analysis for the last hour."""
|
125
|
-
logger.debug(
|
158
|
+
logger.debug(
|
159
|
+
f"Getting cost analysis for last hour (provider={provider}, model={model})"
|
160
|
+
)
|
126
161
|
end = datetime.now()
|
127
162
|
start = end - timedelta(hours=1)
|
128
163
|
return _query_usage(start, end, provider, model)
|
129
164
|
|
130
|
-
|
165
|
+
|
166
|
+
def last_day(
|
167
|
+
provider: Optional[str] = None, model: Optional[str] = None
|
168
|
+
) -> TokenUsageReport:
|
131
169
|
"""Get cost analysis for the last 24 hours."""
|
132
|
-
logger.debug(
|
170
|
+
logger.debug(
|
171
|
+
f"Getting cost analysis for last 24 hours (provider={provider}, model={model})"
|
172
|
+
)
|
133
173
|
end = datetime.now()
|
134
174
|
start = end - timedelta(days=1)
|
135
175
|
return _query_usage(start, end, provider, model)
|
136
176
|
|
137
|
-
|
177
|
+
|
178
|
+
def last_week(
|
179
|
+
provider: Optional[str] = None, model: Optional[str] = None
|
180
|
+
) -> TokenUsageReport:
|
138
181
|
"""Get cost analysis for the last 7 days."""
|
139
|
-
logger.debug(
|
182
|
+
logger.debug(
|
183
|
+
f"Getting cost analysis for last 7 days (provider={provider}, model={model})"
|
184
|
+
)
|
140
185
|
end = datetime.now()
|
141
186
|
start = end - timedelta(weeks=1)
|
142
187
|
return _query_usage(start, end, provider, model)
|
143
188
|
|
144
|
-
|
189
|
+
|
190
|
+
def last_month(
|
191
|
+
provider: Optional[str] = None, model: Optional[str] = None
|
192
|
+
) -> TokenUsageReport:
|
145
193
|
"""Get cost analysis for the last 30 days."""
|
146
|
-
logger.debug(
|
194
|
+
logger.debug(
|
195
|
+
f"Getting cost analysis for last 30 days (provider={provider}, model={model})"
|
196
|
+
)
|
147
197
|
end = datetime.now()
|
148
198
|
start = end - timedelta(days=30)
|
149
199
|
return _query_usage(start, end, provider, model)
|
150
200
|
|
201
|
+
|
151
202
|
def between(
|
152
203
|
start_date: Union[datetime, str],
|
153
204
|
end_date: Union[datetime, str],
|
154
205
|
provider: Optional[str] = None,
|
155
|
-
model: Optional[str] = None
|
206
|
+
model: Optional[str] = None,
|
156
207
|
) -> TokenUsageReport:
|
157
208
|
"""Get cost analysis between two dates.
|
158
|
-
|
209
|
+
|
159
210
|
Args:
|
160
211
|
start_date: datetime object or string (format: YYYY-MM-DD or YYYY-MM-DD HH:MM:SS)
|
161
212
|
end_date: datetime object or string (format: YYYY-MM-DD or YYYY-MM-DD HH:MM:SS)
|
162
213
|
"""
|
163
|
-
logger.debug(
|
164
|
-
|
214
|
+
logger.debug(
|
215
|
+
f"Getting cost analysis between {start_date} and {end_date} (provider={provider}, model={model})"
|
216
|
+
)
|
217
|
+
|
165
218
|
if isinstance(start_date, str):
|
166
219
|
try:
|
167
220
|
start = datetime.strptime(start_date, "%Y-%m-%d %H:%M:%S")
|
168
221
|
except ValueError:
|
222
|
+
logger.warning(
|
223
|
+
f"Date-only string provided for start_date: {start_date}. Setting time to 00:00:00"
|
224
|
+
)
|
169
225
|
start = datetime.strptime(start_date, "%Y-%m-%d")
|
226
|
+
|
170
227
|
else:
|
171
228
|
start = start_date
|
172
229
|
|
@@ -174,12 +231,20 @@ def between(
|
|
174
231
|
try:
|
175
232
|
end = datetime.strptime(end_date, "%Y-%m-%d %H:%M:%S")
|
176
233
|
except ValueError:
|
177
|
-
|
234
|
+
logger.warning(
|
235
|
+
f"Date-only string provided for end_date: {end_date}. Setting time to 23:59:59"
|
236
|
+
)
|
237
|
+
end = (
|
238
|
+
datetime.strptime(end_date, "%Y-%m-%d")
|
239
|
+
+ timedelta(days=1)
|
240
|
+
- timedelta(seconds=1)
|
241
|
+
)
|
178
242
|
else:
|
179
243
|
end = end_date
|
180
244
|
|
181
245
|
return _query_usage(start, end, provider, model)
|
182
246
|
|
247
|
+
|
183
248
|
def for_execution(execution_id: str) -> TokenUsageReport:
|
184
249
|
"""Get cost analysis for a specific execution."""
|
185
250
|
logger.debug(f"Getting cost analysis for execution_id={execution_id}")
|
@@ -187,6 +252,7 @@ def for_execution(execution_id: str) -> TokenUsageReport:
|
|
187
252
|
query = session.query(TokenUsage).filter(TokenUsage.execution_id == execution_id)
|
188
253
|
return _calculate_cost(query.all())
|
189
254
|
|
255
|
+
|
190
256
|
def last_execution() -> TokenUsageReport:
|
191
257
|
"""Get cost analysis for the last execution_id."""
|
192
258
|
logger.debug("Getting cost analysis for last execution")
|
@@ -194,9 +260,10 @@ def last_execution() -> TokenUsageReport:
|
|
194
260
|
query = session.query(TokenUsage).order_by(TokenUsage.created_at.desc()).first()
|
195
261
|
return for_execution(query.execution_id)
|
196
262
|
|
263
|
+
|
197
264
|
def all_time() -> TokenUsageReport:
|
198
265
|
"""Get cost analysis for all time."""
|
199
266
|
logger.warning("Getting cost analysis for all time. This may take a while...")
|
200
267
|
session = get_session()()
|
201
268
|
query = session.query(TokenUsage).all()
|
202
|
-
return for_execution(query.execution_id)
|
269
|
+
return for_execution(query.execution_id)
|
tokenator/utils.py
CHANGED
@@ -4,27 +4,29 @@ import os
|
|
4
4
|
import platform
|
5
5
|
import logging
|
6
6
|
from pathlib import Path
|
7
|
-
from typing import Optional
|
8
7
|
|
9
8
|
logger = logging.getLogger(__name__)
|
10
9
|
|
10
|
+
|
11
11
|
def is_colab() -> bool:
|
12
12
|
"""Check if running in Google Colab."""
|
13
13
|
try:
|
14
|
-
|
15
|
-
|
14
|
+
from importlib.util import find_spec
|
15
|
+
|
16
|
+
return find_spec("google.colab") is not None
|
16
17
|
except ImportError:
|
17
18
|
return False
|
18
19
|
|
20
|
+
|
19
21
|
def get_default_db_path() -> str:
|
20
22
|
"""Get the platform-specific default database path."""
|
21
23
|
try:
|
22
24
|
if is_colab():
|
23
25
|
# Use in-memory database for Colab
|
24
26
|
return "usage.db"
|
25
|
-
|
27
|
+
|
26
28
|
system = platform.system().lower()
|
27
|
-
|
29
|
+
|
28
30
|
if system == "linux" or system == "darwin":
|
29
31
|
# Follow XDG Base Directory Specification
|
30
32
|
xdg_data_home = os.environ.get("XDG_DATA_HOME", "")
|
@@ -39,18 +41,21 @@ def get_default_db_path() -> str:
|
|
39
41
|
db_path = os.path.join(local_app_data, "tokenator", "usage.db")
|
40
42
|
else:
|
41
43
|
db_path = os.path.join(str(Path.home()), ".tokenator", "usage.db")
|
42
|
-
|
44
|
+
|
43
45
|
# Create directory if it doesn't exist
|
44
46
|
os.makedirs(os.path.dirname(db_path), exist_ok=True)
|
45
47
|
return db_path
|
46
48
|
except (OSError, IOError) as e:
|
47
49
|
# Fallback to current directory if we can't create the default path
|
48
50
|
fallback_path = os.path.join(os.getcwd(), "tokenator_usage.db")
|
49
|
-
logger.warning(
|
50
|
-
|
51
|
+
logger.warning(
|
52
|
+
f"Could not create default db path, falling back to {fallback_path}. Error: {e}"
|
53
|
+
)
|
54
|
+
return fallback_path
|
55
|
+
|
51
56
|
|
52
57
|
__all__ = [
|
53
58
|
"get_default_db_path",
|
54
59
|
"is_colab",
|
55
60
|
# ... other exports ...
|
56
|
-
]
|
61
|
+
]
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: tokenator
|
3
|
-
Version: 0.1.
|
3
|
+
Version: 0.1.10
|
4
4
|
Summary: Token usage tracking wrapper for LLMs
|
5
5
|
License: MIT
|
6
6
|
Author: Ujjwal Maheshwari
|
@@ -27,7 +27,7 @@ Have you ever wondered about :
|
|
27
27
|
- How much does it cost to do run a complex AI workflow with multiple LLM providers?
|
28
28
|
- How much money did I spent today on development?
|
29
29
|
|
30
|
-
Afraid not, tokenator is here! With tokenator's easy to use API, you can start tracking LLM usage in a matter of minutes
|
30
|
+
Afraid not, tokenator is here! With tokenator's easy to use API, you can start tracking LLM usage in a matter of minutes.
|
31
31
|
|
32
32
|
Get started with just 3 lines of code!
|
33
33
|
|
@@ -60,27 +60,54 @@ response = client.chat.completions.create(
|
|
60
60
|
### Cost Analysis
|
61
61
|
|
62
62
|
```python
|
63
|
-
from tokenator import
|
63
|
+
from tokenator import usage
|
64
64
|
|
65
65
|
# Get usage for different time periods
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
|
66
|
+
usage.last_hour()
|
67
|
+
usage.last_day()
|
68
|
+
usage.last_week()
|
69
|
+
usage.last_month()
|
70
70
|
|
71
71
|
# Custom date range
|
72
|
-
|
72
|
+
usage.between("2024-03-01", "2024-03-15")
|
73
73
|
|
74
74
|
# Get usage for different LLM providers
|
75
|
-
|
76
|
-
|
77
|
-
|
75
|
+
usage.last_day("openai")
|
76
|
+
usage.last_day("anthropic")
|
77
|
+
usage.last_day("google")
|
78
78
|
```
|
79
79
|
|
80
|
-
### Example `
|
80
|
+
### Example `usage` object
|
81
81
|
|
82
|
-
```
|
82
|
+
```python
|
83
|
+
print(cost.last_hour().model_dump_json(indent=4))
|
84
|
+
```
|
83
85
|
|
86
|
+
```json
|
87
|
+
{
|
88
|
+
"total_cost": 0.0004,
|
89
|
+
"total_tokens": 79,
|
90
|
+
"prompt_tokens": 52,
|
91
|
+
"completion_tokens": 27,
|
92
|
+
"providers": [
|
93
|
+
{
|
94
|
+
"total_cost": 0.0004,
|
95
|
+
"total_tokens": 79,
|
96
|
+
"prompt_tokens": 52,
|
97
|
+
"completion_tokens": 27,
|
98
|
+
"provider": "openai",
|
99
|
+
"models": [
|
100
|
+
{
|
101
|
+
"total_cost": 0.0004,
|
102
|
+
"total_tokens": 79,
|
103
|
+
"prompt_tokens": 52,
|
104
|
+
"completion_tokens": 27,
|
105
|
+
"model": "gpt-4o-2024-08-06"
|
106
|
+
}
|
107
|
+
]
|
108
|
+
}
|
109
|
+
]
|
110
|
+
}
|
84
111
|
```
|
85
112
|
|
86
113
|
## Features
|
@@ -0,0 +1,19 @@
|
|
1
|
+
tokenator/__init__.py,sha256=bIAPyGAvWreS2i_5tzxJEyX9JlZgAUNxzVk1iHNUhvU,593
|
2
|
+
tokenator/anthropic/client_anthropic.py,sha256=fcKxGsLex99II-WD9SVNI5QVzH0IEWRmVLjyvZd9wKs,5936
|
3
|
+
tokenator/anthropic/stream_interceptors.py,sha256=4VHC_-WkG3Pa10YizmFLrHcbz0Tm2MR_YB5-uohKp5A,5221
|
4
|
+
tokenator/base_wrapper.py,sha256=VYSkQB1MEudgzBX60T-VAMsNg4fFx7IRzpadzjm4klE,2466
|
5
|
+
tokenator/create_migrations.py,sha256=k9IHiGK21dLTA8MYNsuhO0-kUVIcMSViMFYtY4WU2Rw,730
|
6
|
+
tokenator/migrations/env.py,sha256=JoF5MJ4ae0wJW5kdBHuFlG3ZqeCCDvbMcU8fNA_a6hM,1396
|
7
|
+
tokenator/migrations/script.py.mako,sha256=nJL-tbLQE0Qy4P9S4r4ntNAcikPtoFUlvXe6xvm9ot8,635
|
8
|
+
tokenator/migrations/versions/f6f1f2437513_initial_migration.py,sha256=4cveHkwSxs-hxOPCm81YfvGZTkJJ2ClAFmyL98-1VCo,1910
|
9
|
+
tokenator/migrations.py,sha256=YAf9gZmDzAq36PWWXPtdUQoJFYPXtIDzflC79H6gcJg,1114
|
10
|
+
tokenator/models.py,sha256=MhYwCvmqposUNDRxFZNAVnzCqBTHxNL3Hp0MNFXM5ck,1201
|
11
|
+
tokenator/openai/client_openai.py,sha256=Umfxha3BhBFU_JebPjyuaUZEZuPqJWQo1xTCuAy3R24,5691
|
12
|
+
tokenator/openai/stream_interceptors.py,sha256=ez1MnjRZW_rEalv2SIPAvrU9oMD6OJoD9vht-057fDM,5243
|
13
|
+
tokenator/schemas.py,sha256=Ye8hqZlrm3Gh2FyvOVX-hWCpKynWxS58QQRQMfDtIAQ,2114
|
14
|
+
tokenator/usage.py,sha256=eTWfcRrTLop-30FmwHpi7_GwCJxU6Qfji374hG1Qptw,8476
|
15
|
+
tokenator/utils.py,sha256=xg9l2GV1yJL1BlxKL1r8CboABWDslf3G5rGQEJSjFrE,1973
|
16
|
+
tokenator-0.1.10.dist-info/LICENSE,sha256=wdG-B6-ODk8RQ4jq5uXSn0w1UWTzCH_MMyvh7AwtGns,1074
|
17
|
+
tokenator-0.1.10.dist-info/METADATA,sha256=ryILkOYlq8V8219sVmK0xUeEEw51msw_FCoF_3VJ_k8,3108
|
18
|
+
tokenator-0.1.10.dist-info/WHEEL,sha256=Nq82e9rUAnEjt98J6MlVmMCZb-t9cYE2Ir1kpBmnWfs,88
|
19
|
+
tokenator-0.1.10.dist-info/RECORD,,
|
tokenator/client_anthropic.py
DELETED
@@ -1,148 +0,0 @@
|
|
1
|
-
"""Anthropic client wrapper with token usage tracking."""
|
2
|
-
|
3
|
-
from typing import Any, Dict, Optional, TypeVar, Union, overload, Iterator, AsyncIterator
|
4
|
-
import logging
|
5
|
-
|
6
|
-
from anthropic import Anthropic, AsyncAnthropic
|
7
|
-
from anthropic.types import Message, RawMessageStartEvent, RawMessageDeltaEvent
|
8
|
-
|
9
|
-
from .models import Usage, TokenUsageStats
|
10
|
-
from .base_wrapper import BaseWrapper, ResponseType
|
11
|
-
|
12
|
-
logger = logging.getLogger(__name__)
|
13
|
-
|
14
|
-
class BaseAnthropicWrapper(BaseWrapper):
|
15
|
-
provider = "anthropic"
|
16
|
-
|
17
|
-
def _process_response_usage(self, response: ResponseType) -> Optional[TokenUsageStats]:
|
18
|
-
"""Process and log usage statistics from a response."""
|
19
|
-
try:
|
20
|
-
if isinstance(response, Message):
|
21
|
-
if not hasattr(response, 'usage'):
|
22
|
-
return None
|
23
|
-
usage = Usage(
|
24
|
-
prompt_tokens=response.usage.input_tokens,
|
25
|
-
completion_tokens=response.usage.output_tokens,
|
26
|
-
total_tokens=response.usage.input_tokens + response.usage.output_tokens
|
27
|
-
)
|
28
|
-
return TokenUsageStats(model=response.model, usage=usage)
|
29
|
-
elif isinstance(response, dict):
|
30
|
-
usage_dict = response.get('usage')
|
31
|
-
if not usage_dict:
|
32
|
-
return None
|
33
|
-
usage = Usage(
|
34
|
-
prompt_tokens=usage_dict.get('input_tokens', 0),
|
35
|
-
completion_tokens=usage_dict.get('output_tokens', 0),
|
36
|
-
total_tokens=usage_dict.get('input_tokens', 0) + usage_dict.get('output_tokens', 0)
|
37
|
-
)
|
38
|
-
return TokenUsageStats(
|
39
|
-
model=response.get('model', 'unknown'),
|
40
|
-
usage=usage
|
41
|
-
)
|
42
|
-
except Exception as e:
|
43
|
-
logger.warning("Failed to process usage stats: %s", str(e))
|
44
|
-
return None
|
45
|
-
return None
|
46
|
-
|
47
|
-
@property
|
48
|
-
def messages(self):
|
49
|
-
return self
|
50
|
-
|
51
|
-
class AnthropicWrapper(BaseAnthropicWrapper):
|
52
|
-
def create(self, *args: Any, execution_id: Optional[str] = None, **kwargs: Any) -> Union[Message, Iterator[Message]]:
|
53
|
-
"""Create a message completion and log token usage."""
|
54
|
-
logger.debug("Creating message completion with args: %s, kwargs: %s", args, kwargs)
|
55
|
-
|
56
|
-
response = self.client.messages.create(*args, **kwargs)
|
57
|
-
|
58
|
-
if not kwargs.get('stream', False):
|
59
|
-
usage_data = self._process_response_usage(response)
|
60
|
-
if usage_data:
|
61
|
-
self._log_usage(usage_data, execution_id=execution_id)
|
62
|
-
return response
|
63
|
-
|
64
|
-
return self._wrap_streaming_response(response, execution_id)
|
65
|
-
|
66
|
-
def _wrap_streaming_response(self, response_iter: Iterator[Message], execution_id: Optional[str]) -> Iterator[Message]:
|
67
|
-
"""Wrap streaming response to capture final usage stats"""
|
68
|
-
usage_data: TokenUsageStats = TokenUsageStats(model="", usage=Usage())
|
69
|
-
for chunk in response_iter:
|
70
|
-
if isinstance(chunk, RawMessageStartEvent):
|
71
|
-
usage_data.model = chunk.message.model
|
72
|
-
usage_data.usage.prompt_tokens = chunk.message.usage.input_tokens
|
73
|
-
usage_data.usage.completion_tokens = chunk.message.usage.output_tokens
|
74
|
-
usage_data.usage.total_tokens = chunk.message.usage.input_tokens + chunk.message.usage.output_tokens
|
75
|
-
|
76
|
-
elif isinstance(chunk, RawMessageDeltaEvent):
|
77
|
-
usage_data.usage.prompt_tokens += chunk.usage.input_tokens
|
78
|
-
usage_data.usage.completion_tokens += chunk.usage.output_tokens
|
79
|
-
usage_data.usage.total_tokens += chunk.usage.input_tokens + chunk.usage.output_tokens
|
80
|
-
|
81
|
-
yield chunk
|
82
|
-
|
83
|
-
self._log_usage(usage_data, execution_id=execution_id)
|
84
|
-
|
85
|
-
class AsyncAnthropicWrapper(BaseAnthropicWrapper):
|
86
|
-
async def create(self, *args: Any, execution_id: Optional[str] = None, **kwargs: Any) -> Union[Message, AsyncIterator[Message]]:
|
87
|
-
"""Create a message completion and log token usage."""
|
88
|
-
logger.debug("Creating message completion with args: %s, kwargs: %s", args, kwargs)
|
89
|
-
|
90
|
-
if kwargs.get('stream', False):
|
91
|
-
response = await self.client.messages.create(*args, **kwargs)
|
92
|
-
return self._wrap_streaming_response(response, execution_id)
|
93
|
-
|
94
|
-
response = await self.client.messages.create(*args, **kwargs)
|
95
|
-
usage_data = self._process_response_usage(response)
|
96
|
-
if usage_data:
|
97
|
-
self._log_usage(usage_data, execution_id=execution_id)
|
98
|
-
return response
|
99
|
-
|
100
|
-
async def _wrap_streaming_response(self, response_iter: AsyncIterator[Message], execution_id: Optional[str]) -> AsyncIterator[Message]:
|
101
|
-
"""Wrap streaming response to capture final usage stats"""
|
102
|
-
usage_data: TokenUsageStats = TokenUsageStats(model="", usage=Usage())
|
103
|
-
async for chunk in response_iter:
|
104
|
-
if isinstance(chunk, RawMessageStartEvent):
|
105
|
-
usage_data.model = chunk.message.model
|
106
|
-
usage_data.usage.prompt_tokens = chunk.message.usage.input_tokens
|
107
|
-
usage_data.usage.completion_tokens = chunk.message.usage.output_tokens
|
108
|
-
usage_data.usage.total_tokens = chunk.message.usage.input_tokens + chunk.message.usage.output_tokens
|
109
|
-
|
110
|
-
elif isinstance(chunk, RawMessageDeltaEvent):
|
111
|
-
usage_data.usage.prompt_tokens += chunk.usage.input_tokens
|
112
|
-
usage_data.usage.completion_tokens += chunk.usage.output_tokens
|
113
|
-
usage_data.usage.total_tokens += chunk.usage.input_tokens + chunk.usage.output_tokens
|
114
|
-
|
115
|
-
yield chunk
|
116
|
-
|
117
|
-
|
118
|
-
self._log_usage(usage_data, execution_id=execution_id)
|
119
|
-
|
120
|
-
@overload
|
121
|
-
def tokenator_anthropic(
|
122
|
-
client: Anthropic,
|
123
|
-
db_path: Optional[str] = None,
|
124
|
-
) -> AnthropicWrapper: ...
|
125
|
-
|
126
|
-
@overload
|
127
|
-
def tokenator_anthropic(
|
128
|
-
client: AsyncAnthropic,
|
129
|
-
db_path: Optional[str] = None,
|
130
|
-
) -> AsyncAnthropicWrapper: ...
|
131
|
-
|
132
|
-
def tokenator_anthropic(
|
133
|
-
client: Union[Anthropic, AsyncAnthropic],
|
134
|
-
db_path: Optional[str] = None,
|
135
|
-
) -> Union[AnthropicWrapper, AsyncAnthropicWrapper]:
|
136
|
-
"""Create a token-tracking wrapper for an Anthropic client.
|
137
|
-
|
138
|
-
Args:
|
139
|
-
client: Anthropic or AsyncAnthropic client instance
|
140
|
-
db_path: Optional path to SQLite database for token tracking
|
141
|
-
"""
|
142
|
-
if isinstance(client, Anthropic):
|
143
|
-
return AnthropicWrapper(client=client, db_path=db_path)
|
144
|
-
|
145
|
-
if isinstance(client, AsyncAnthropic):
|
146
|
-
return AsyncAnthropicWrapper(client=client, db_path=db_path)
|
147
|
-
|
148
|
-
raise ValueError("Client must be an instance of Anthropic or AsyncAnthropic")
|