tokenator-0.1.9-py3-none-any.whl → tokenator-0.1.10-py3-none-any.whl

tokenator/usage.py CHANGED
@@ -1,9 +1,8 @@
  """Cost analysis functions for token usage."""
 
- from datetime import datetime, timedelta, timezone
+ from datetime import datetime, timedelta
  from typing import Dict, Optional, Union
 
- from sqlalchemy import and_
 
  from .schemas import get_session, TokenUsage
  from .models import TokenRate, TokenUsageReport, ModelUsage, ProviderUsage
@@ -13,48 +12,63 @@ import logging
 
  logger = logging.getLogger(__name__)
 
+
  def _get_model_costs() -> Dict[str, TokenRate]:
      url = "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json"
      response = requests.get(url)
      data = response.json()
-
+
      return {
          model: TokenRate(
              prompt=info["input_cost_per_token"],
-             completion=info["output_cost_per_token"]
+             completion=info["output_cost_per_token"],
          )
          for model, info in data.items()
          if "input_cost_per_token" in info and "output_cost_per_token" in info
      }
 
+
  MODEL_COSTS = _get_model_costs()
 
- def _calculate_cost(usages: list[TokenUsage], provider: Optional[str] = None) -> TokenUsageReport:
+
+ def _calculate_cost(
+     usages: list[TokenUsage], provider: Optional[str] = None
+ ) -> TokenUsageReport:
      """Calculate cost from token usage records."""
      # Group usages by provider and model
      provider_model_usages: Dict[str, Dict[str, list[TokenUsage]]] = {}
 
      print(f"usages: {len(usages)}")
-
+
      for usage in usages:
          if usage.model not in MODEL_COSTS:
              continue
-
+
          provider = usage.provider
          if provider not in provider_model_usages:
              provider_model_usages[provider] = {}
-
+
          if usage.model not in provider_model_usages[provider]:
              provider_model_usages[provider][usage.model] = []
-
+
          provider_model_usages[provider][usage.model].append(usage)
 
      # Calculate totals for each level
      providers_list = []
-     total_metrics = {"total_cost": 0.0, "total_tokens": 0, "prompt_tokens": 0, "completion_tokens": 0}
+     total_metrics = {
+         "total_cost": 0.0,
+         "total_tokens": 0,
+         "prompt_tokens": 0,
+         "completion_tokens": 0,
+     }
 
      for provider, model_usages in provider_model_usages.items():
-         provider_metrics = {"total_cost": 0.0, "total_tokens": 0, "prompt_tokens": 0, "completion_tokens": 0}
+         provider_metrics = {
+             "total_cost": 0.0,
+             "total_tokens": 0,
+             "prompt_tokens": 0,
+             "completion_tokens": 0,
+         }
          models_list = []
 
          for model, usages in model_usages.items():
@@ -67,17 +81,21 @@ def _calculate_cost(usages: list[TokenUsage], provider: Optional[str] = None) ->
                  model_prompt += usage.prompt_tokens
                  model_completion += usage.completion_tokens
                  model_total += usage.total_tokens
-
-                 model_cost += (usage.prompt_tokens * MODEL_COSTS[usage.model].prompt)
-                 model_cost += (usage.completion_tokens * MODEL_COSTS[usage.model].completion)
-
-             models_list.append(ModelUsage(
-                 model=model,
-                 total_cost=round(model_cost, 6),
-                 total_tokens=model_total,
-                 prompt_tokens=model_prompt,
-                 completion_tokens=model_completion
-             ))
+
+                 model_cost += usage.prompt_tokens * MODEL_COSTS[usage.model].prompt
+                 model_cost += (
+                     usage.completion_tokens * MODEL_COSTS[usage.model].completion
+                 )
+
+             models_list.append(
+                 ModelUsage(
+                     model=model,
+                     total_cost=round(model_cost, 6),
+                     total_tokens=model_total,
+                     prompt_tokens=model_prompt,
+                     completion_tokens=model_completion,
+                 )
+             )
 
              # Add to provider totals
              provider_metrics["total_cost"] += model_cost
@@ -85,11 +103,16 @@ def _calculate_cost(usages: list[TokenUsage], provider: Optional[str] = None) ->
          provider_metrics["prompt_tokens"] += model_prompt
          provider_metrics["completion_tokens"] += model_completion
 
-         providers_list.append(ProviderUsage(
-             provider=provider,
-             models=models_list,
-             **{k: (round(v, 6) if k == "total_cost" else v) for k, v in provider_metrics.items()}
-         ))
+         providers_list.append(
+             ProviderUsage(
+                 provider=provider,
+                 models=models_list,
+                 **{
+                     k: (round(v, 6) if k == "total_cost" else v)
+                     for k, v in provider_metrics.items()
+                 },
+             )
+         )
 
          # Add to grand totals
          for key in total_metrics:
@@ -97,76 +120,110 @@ def _calculate_cost(usages: list[TokenUsage], provider: Optional[str] = None) ->
 
      return TokenUsageReport(
          providers=providers_list,
-         **{k: (round(v, 6) if k == "total_cost" else v) for k, v in total_metrics.items()}
+         **{
+             k: (round(v, 6) if k == "total_cost" else v)
+             for k, v in total_metrics.items()
+         },
      )
 
- def _query_usage(start_date: datetime, end_date: datetime,
-                  provider: Optional[str] = None,
-                  model: Optional[str] = None) -> TokenUsageReport:
+
+ def _query_usage(
+     start_date: datetime,
+     end_date: datetime,
+     provider: Optional[str] = None,
+     model: Optional[str] = None,
+ ) -> TokenUsageReport:
      """Query token usage for a specific time period."""
      session = get_session()()
      try:
          query = session.query(TokenUsage).filter(
              TokenUsage.created_at.between(start_date, end_date)
          )
-
+
          if provider:
              query = query.filter(TokenUsage.provider == provider)
          if model:
              query = query.filter(TokenUsage.model == model)
-
+
          usages = query.all()
          return _calculate_cost(usages, provider or "all")
      finally:
          session.close()
 
- def last_hour(provider: Optional[str] = None, model: Optional[str] = None) -> TokenUsageReport:
+
+ def last_hour(
+     provider: Optional[str] = None, model: Optional[str] = None
+ ) -> TokenUsageReport:
      """Get cost analysis for the last hour."""
-     logger.debug(f"Getting cost analysis for last hour (provider={provider}, model={model})")
+     logger.debug(
+         f"Getting cost analysis for last hour (provider={provider}, model={model})"
+     )
      end = datetime.now()
      start = end - timedelta(hours=1)
      return _query_usage(start, end, provider, model)
 
- def last_day(provider: Optional[str] = None, model: Optional[str] = None) -> TokenUsageReport:
+
+ def last_day(
+     provider: Optional[str] = None, model: Optional[str] = None
+ ) -> TokenUsageReport:
      """Get cost analysis for the last 24 hours."""
-     logger.debug(f"Getting cost analysis for last 24 hours (provider={provider}, model={model})")
+     logger.debug(
+         f"Getting cost analysis for last 24 hours (provider={provider}, model={model})"
+     )
      end = datetime.now()
      start = end - timedelta(days=1)
      return _query_usage(start, end, provider, model)
 
- def last_week(provider: Optional[str] = None, model: Optional[str] = None) -> TokenUsageReport:
+
+ def last_week(
+     provider: Optional[str] = None, model: Optional[str] = None
+ ) -> TokenUsageReport:
      """Get cost analysis for the last 7 days."""
-     logger.debug(f"Getting cost analysis for last 7 days (provider={provider}, model={model})")
+     logger.debug(
+         f"Getting cost analysis for last 7 days (provider={provider}, model={model})"
+     )
      end = datetime.now()
      start = end - timedelta(weeks=1)
      return _query_usage(start, end, provider, model)
 
- def last_month(provider: Optional[str] = None, model: Optional[str] = None) -> TokenUsageReport:
+
+ def last_month(
+     provider: Optional[str] = None, model: Optional[str] = None
+ ) -> TokenUsageReport:
      """Get cost analysis for the last 30 days."""
-     logger.debug(f"Getting cost analysis for last 30 days (provider={provider}, model={model})")
+     logger.debug(
+         f"Getting cost analysis for last 30 days (provider={provider}, model={model})"
+     )
      end = datetime.now()
      start = end - timedelta(days=30)
      return _query_usage(start, end, provider, model)
 
+
  def between(
      start_date: Union[datetime, str],
      end_date: Union[datetime, str],
      provider: Optional[str] = None,
-     model: Optional[str] = None
+     model: Optional[str] = None,
  ) -> TokenUsageReport:
      """Get cost analysis between two dates.
-
+
      Args:
          start_date: datetime object or string (format: YYYY-MM-DD or YYYY-MM-DD HH:MM:SS)
          end_date: datetime object or string (format: YYYY-MM-DD or YYYY-MM-DD HH:MM:SS)
      """
-     logger.debug(f"Getting cost analysis between {start_date} and {end_date} (provider={provider}, model={model})")
-
+     logger.debug(
+         f"Getting cost analysis between {start_date} and {end_date} (provider={provider}, model={model})"
+     )
+
      if isinstance(start_date, str):
          try:
              start = datetime.strptime(start_date, "%Y-%m-%d %H:%M:%S")
          except ValueError:
+             logger.warning(
+                 f"Date-only string provided for start_date: {start_date}. Setting time to 00:00:00"
+             )
              start = datetime.strptime(start_date, "%Y-%m-%d")
+
      else:
          start = start_date
 
@@ -174,12 +231,20 @@ def between(
      try:
          end = datetime.strptime(end_date, "%Y-%m-%d %H:%M:%S")
      except ValueError:
-             end = datetime.strptime(end_date, "%Y-%m-%d") + timedelta(days=1)  # Include the end date
+             logger.warning(
+                 f"Date-only string provided for end_date: {end_date}. Setting time to 23:59:59"
+             )
+             end = (
+                 datetime.strptime(end_date, "%Y-%m-%d")
+                 + timedelta(days=1)
+                 - timedelta(seconds=1)
+             )
  else:
      end = end_date
 
  return _query_usage(start, end, provider, model)
 
+
  def for_execution(execution_id: str) -> TokenUsageReport:
      """Get cost analysis for a specific execution."""
      logger.debug(f"Getting cost analysis for execution_id={execution_id}")
@@ -187,6 +252,7 @@ def for_execution(execution_id: str) -> TokenUsageReport:
      query = session.query(TokenUsage).filter(TokenUsage.execution_id == execution_id)
      return _calculate_cost(query.all())
 
+
  def last_execution() -> TokenUsageReport:
      """Get cost analysis for the last execution_id."""
      logger.debug("Getting cost analysis for last execution")
@@ -194,9 +260,10 @@ def last_execution() -> TokenUsageReport:
      query = session.query(TokenUsage).order_by(TokenUsage.created_at.desc()).first()
      return for_execution(query.execution_id)
 
+
  def all_time() -> TokenUsageReport:
      """Get cost analysis for all time."""
      logger.warning("Getting cost analysis for all time. This may take a while...")
      session = get_session()()
      query = session.query(TokenUsage).all()
-     return for_execution(query.execution_id)
+     return for_execution(query.execution_id)
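
Substantively, the one behavior change in this file is in `between` (the rest reads as an autoformatter pass, Black-style trailing commas and split signatures, plus dropped unused imports): a date-only `end_date` string used to be bumped to midnight of the following day, while it now resolves to 23:59:59 of the named day, and both date-only branches now log a warning. A minimal sketch of the new end-date parsing; the standalone `_parse_end_date` helper is hypothetical, in the package this logic sits inline in `usage.between`:

```python
from datetime import datetime, timedelta

def _parse_end_date(end_date: str) -> datetime:
    """Hypothetical helper mirroring the new end-date handling in between()."""
    try:
        # Full timestamps are used as-is.
        return datetime.strptime(end_date, "%Y-%m-%d %H:%M:%S")
    except ValueError:
        # Date-only strings now resolve to the last second of the named day,
        # keeping the end date inclusive without spilling into the next day.
        return (
            datetime.strptime(end_date, "%Y-%m-%d")
            + timedelta(days=1)
            - timedelta(seconds=1)
        )

print(_parse_end_date("2024-03-15"))  # 2024-03-15 23:59:59
```

Both versions keep the end date inclusive; the new one just stops counting usage recorded at exactly 00:00:00 of the following day.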
tokenator/utils.py CHANGED
@@ -4,27 +4,29 @@ import os
  import platform
  import logging
  from pathlib import Path
- from typing import Optional
 
  logger = logging.getLogger(__name__)
 
+
  def is_colab() -> bool:
      """Check if running in Google Colab."""
      try:
-         import google.colab  # type: ignore
-         return True
+         from importlib.util import find_spec
+
+         return find_spec("google.colab") is not None
      except ImportError:
          return False
 
+
  def get_default_db_path() -> str:
      """Get the platform-specific default database path."""
      try:
          if is_colab():
              # Use in-memory database for Colab
              return "usage.db"
-
+
          system = platform.system().lower()
-
+
          if system == "linux" or system == "darwin":
              # Follow XDG Base Directory Specification
              xdg_data_home = os.environ.get("XDG_DATA_HOME", "")
@@ -39,18 +41,21 @@ def get_default_db_path() -> str:
              db_path = os.path.join(local_app_data, "tokenator", "usage.db")
          else:
              db_path = os.path.join(str(Path.home()), ".tokenator", "usage.db")
-
+
          # Create directory if it doesn't exist
          os.makedirs(os.path.dirname(db_path), exist_ok=True)
          return db_path
      except (OSError, IOError) as e:
          # Fallback to current directory if we can't create the default path
          fallback_path = os.path.join(os.getcwd(), "tokenator_usage.db")
-         logger.warning(f"Could not create default db path, falling back to {fallback_path}. Error: {e}")
-         return fallback_path
+         logger.warning(
+             f"Could not create default db path, falling back to {fallback_path}. Error: {e}"
+         )
+         return fallback_path
+
 
  __all__ = [
      "get_default_db_path",
      "is_colab",
      # ... other exports ...
- ]
+ ]
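
The `is_colab` rewrite swaps a real `import google.colab` for `importlib.util.find_spec`, which asks the import system whether the module can be found without executing it. A short sketch contrasting the two approaches; the function names are illustrative, not part of tokenator's API:

```python
from importlib.util import find_spec

def is_colab_old() -> bool:
    """0.1.9 behavior: actually imports the package, running its side effects."""
    try:
        import google.colab  # noqa: F401
        return True
    except ImportError:
        return False

def is_colab_new() -> bool:
    """0.1.10 behavior: only checks that the module is findable; no module code runs."""
    try:
        return find_spec("google.colab") is not None
    except ImportError:
        # find_spec imports the parent package, so a missing 'google'
        # package still surfaces as an ImportError.
        return False

print(is_colab_new())  # False anywhere outside Google Colab
```

`find_spec` returns `None` when the parent package is installed but the submodule is absent, so the `is not None` check covers both outcomes.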
{tokenator-0.1.9.dist-info → tokenator-0.1.10.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: tokenator
- Version: 0.1.9
+ Version: 0.1.10
  Summary: Token usage tracking wrapper for LLMs
  License: MIT
  Author: Ujjwal Maheshwari
@@ -60,29 +60,31 @@ response = client.chat.completions.create(
  ### Cost Analysis
 
  ```python
- from tokenator import cost
+ from tokenator import usage
 
  # Get usage for different time periods
- cost.last_hour()
- cost.last_day()
- cost.last_week()
- cost.last_month()
+ usage.last_hour()
+ usage.last_day()
+ usage.last_week()
+ usage.last_month()
 
  # Custom date range
- cost.between("2024-03-01", "2024-03-15")
+ usage.between("2024-03-01", "2024-03-15")
 
  # Get usage for different LLM providers
- cost.last_day("openai")
- cost.last_day("anthropic")
- cost.last_day("google")
+ usage.last_day("openai")
+ usage.last_day("anthropic")
+ usage.last_day("google")
  ```
 
- ### Example `cost` object
+ ### Example `usage` object
 
- ```json
- # print(cost.last_hour().model_dump_json(indent=4))
+ ```python
+ print(cost.last_hour().model_dump_json(indent=4))
+ ```
 
- usage : {
+ ```json
+ {
      "total_cost": 0.0004,
      "total_tokens": 79,
      "prompt_tokens": 52,
tokenator-0.1.10.dist-info/RECORD ADDED
@@ -0,0 +1,19 @@
+ tokenator/__init__.py,sha256=bIAPyGAvWreS2i_5tzxJEyX9JlZgAUNxzVk1iHNUhvU,593
+ tokenator/anthropic/client_anthropic.py,sha256=fcKxGsLex99II-WD9SVNI5QVzH0IEWRmVLjyvZd9wKs,5936
+ tokenator/anthropic/stream_interceptors.py,sha256=4VHC_-WkG3Pa10YizmFLrHcbz0Tm2MR_YB5-uohKp5A,5221
+ tokenator/base_wrapper.py,sha256=VYSkQB1MEudgzBX60T-VAMsNg4fFx7IRzpadzjm4klE,2466
+ tokenator/create_migrations.py,sha256=k9IHiGK21dLTA8MYNsuhO0-kUVIcMSViMFYtY4WU2Rw,730
+ tokenator/migrations/env.py,sha256=JoF5MJ4ae0wJW5kdBHuFlG3ZqeCCDvbMcU8fNA_a6hM,1396
+ tokenator/migrations/script.py.mako,sha256=nJL-tbLQE0Qy4P9S4r4ntNAcikPtoFUlvXe6xvm9ot8,635
+ tokenator/migrations/versions/f6f1f2437513_initial_migration.py,sha256=4cveHkwSxs-hxOPCm81YfvGZTkJJ2ClAFmyL98-1VCo,1910
+ tokenator/migrations.py,sha256=YAf9gZmDzAq36PWWXPtdUQoJFYPXtIDzflC79H6gcJg,1114
+ tokenator/models.py,sha256=MhYwCvmqposUNDRxFZNAVnzCqBTHxNL3Hp0MNFXM5ck,1201
+ tokenator/openai/client_openai.py,sha256=Umfxha3BhBFU_JebPjyuaUZEZuPqJWQo1xTCuAy3R24,5691
+ tokenator/openai/stream_interceptors.py,sha256=ez1MnjRZW_rEalv2SIPAvrU9oMD6OJoD9vht-057fDM,5243
+ tokenator/schemas.py,sha256=Ye8hqZlrm3Gh2FyvOVX-hWCpKynWxS58QQRQMfDtIAQ,2114
+ tokenator/usage.py,sha256=eTWfcRrTLop-30FmwHpi7_GwCJxU6Qfji374hG1Qptw,8476
+ tokenator/utils.py,sha256=xg9l2GV1yJL1BlxKL1r8CboABWDslf3G5rGQEJSjFrE,1973
+ tokenator-0.1.10.dist-info/LICENSE,sha256=wdG-B6-ODk8RQ4jq5uXSn0w1UWTzCH_MMyvh7AwtGns,1074
+ tokenator-0.1.10.dist-info/METADATA,sha256=ryILkOYlq8V8219sVmK0xUeEEw51msw_FCoF_3VJ_k8,3108
+ tokenator-0.1.10.dist-info/WHEEL,sha256=Nq82e9rUAnEjt98J6MlVmMCZb-t9cYE2Ir1kpBmnWfs,88
+ tokenator-0.1.10.dist-info/RECORD,,
tokenator/client_anthropic.py DELETED
@@ -1,148 +0,0 @@
- """Anthropic client wrapper with token usage tracking."""
-
- from typing import Any, Dict, Optional, TypeVar, Union, overload, Iterator, AsyncIterator
- import logging
-
- from anthropic import Anthropic, AsyncAnthropic
- from anthropic.types import Message, RawMessageStartEvent, RawMessageDeltaEvent
-
- from .models import Usage, TokenUsageStats
- from .base_wrapper import BaseWrapper, ResponseType
-
- logger = logging.getLogger(__name__)
-
- class BaseAnthropicWrapper(BaseWrapper):
-     provider = "anthropic"
-
-     def _process_response_usage(self, response: ResponseType) -> Optional[TokenUsageStats]:
-         """Process and log usage statistics from a response."""
-         try:
-             if isinstance(response, Message):
-                 if not hasattr(response, 'usage'):
-                     return None
-                 usage = Usage(
-                     prompt_tokens=response.usage.input_tokens,
-                     completion_tokens=response.usage.output_tokens,
-                     total_tokens=response.usage.input_tokens + response.usage.output_tokens
-                 )
-                 return TokenUsageStats(model=response.model, usage=usage)
-             elif isinstance(response, dict):
-                 usage_dict = response.get('usage')
-                 if not usage_dict:
-                     return None
-                 usage = Usage(
-                     prompt_tokens=usage_dict.get('input_tokens', 0),
-                     completion_tokens=usage_dict.get('output_tokens', 0),
-                     total_tokens=usage_dict.get('input_tokens', 0) + usage_dict.get('output_tokens', 0)
-                 )
-                 return TokenUsageStats(
-                     model=response.get('model', 'unknown'),
-                     usage=usage
-                 )
-         except Exception as e:
-             logger.warning("Failed to process usage stats: %s", str(e))
-             return None
-         return None
-
-     @property
-     def messages(self):
-         return self
-
- class AnthropicWrapper(BaseAnthropicWrapper):
-     def create(self, *args: Any, execution_id: Optional[str] = None, **kwargs: Any) -> Union[Message, Iterator[Message]]:
-         """Create a message completion and log token usage."""
-         logger.debug("Creating message completion with args: %s, kwargs: %s", args, kwargs)
-
-         response = self.client.messages.create(*args, **kwargs)
-
-         if not kwargs.get('stream', False):
-             usage_data = self._process_response_usage(response)
-             if usage_data:
-                 self._log_usage(usage_data, execution_id=execution_id)
-             return response
-
-         return self._wrap_streaming_response(response, execution_id)
-
-     def _wrap_streaming_response(self, response_iter: Iterator[Message], execution_id: Optional[str]) -> Iterator[Message]:
-         """Wrap streaming response to capture final usage stats"""
-         usage_data: TokenUsageStats = TokenUsageStats(model="", usage=Usage())
-         for chunk in response_iter:
-             if isinstance(chunk, RawMessageStartEvent):
-                 usage_data.model = chunk.message.model
-                 usage_data.usage.prompt_tokens = chunk.message.usage.input_tokens
-                 usage_data.usage.completion_tokens = chunk.message.usage.output_tokens
-                 usage_data.usage.total_tokens = chunk.message.usage.input_tokens + chunk.message.usage.output_tokens
-
-             elif isinstance(chunk, RawMessageDeltaEvent):
-                 usage_data.usage.prompt_tokens += chunk.usage.input_tokens
-                 usage_data.usage.completion_tokens += chunk.usage.output_tokens
-                 usage_data.usage.total_tokens += chunk.usage.input_tokens + chunk.usage.output_tokens
-
-             yield chunk
-
-         self._log_usage(usage_data, execution_id=execution_id)
-
- class AsyncAnthropicWrapper(BaseAnthropicWrapper):
-     async def create(self, *args: Any, execution_id: Optional[str] = None, **kwargs: Any) -> Union[Message, AsyncIterator[Message]]:
-         """Create a message completion and log token usage."""
-         logger.debug("Creating message completion with args: %s, kwargs: %s", args, kwargs)
-
-         if kwargs.get('stream', False):
-             response = await self.client.messages.create(*args, **kwargs)
-             return self._wrap_streaming_response(response, execution_id)
-
-         response = await self.client.messages.create(*args, **kwargs)
-         usage_data = self._process_response_usage(response)
-         if usage_data:
-             self._log_usage(usage_data, execution_id=execution_id)
-         return response
-
-     async def _wrap_streaming_response(self, response_iter: AsyncIterator[Message], execution_id: Optional[str]) -> AsyncIterator[Message]:
-         """Wrap streaming response to capture final usage stats"""
-         usage_data: TokenUsageStats = TokenUsageStats(model="", usage=Usage())
-         async for chunk in response_iter:
-             if isinstance(chunk, RawMessageStartEvent):
-                 usage_data.model = chunk.message.model
-                 usage_data.usage.prompt_tokens = chunk.message.usage.input_tokens
-                 usage_data.usage.completion_tokens = chunk.message.usage.output_tokens
-                 usage_data.usage.total_tokens = chunk.message.usage.input_tokens + chunk.message.usage.output_tokens
-
-             elif isinstance(chunk, RawMessageDeltaEvent):
-                 usage_data.usage.prompt_tokens += chunk.usage.input_tokens
-                 usage_data.usage.completion_tokens += chunk.usage.output_tokens
-                 usage_data.usage.total_tokens += chunk.usage.input_tokens + chunk.usage.output_tokens
-
-             yield chunk
-
-
-         self._log_usage(usage_data, execution_id=execution_id)
-
- @overload
- def tokenator_anthropic(
-     client: Anthropic,
-     db_path: Optional[str] = None,
- ) -> AnthropicWrapper: ...
-
- @overload
- def tokenator_anthropic(
-     client: AsyncAnthropic,
-     db_path: Optional[str] = None,
- ) -> AsyncAnthropicWrapper: ...
-
- def tokenator_anthropic(
-     client: Union[Anthropic, AsyncAnthropic],
-     db_path: Optional[str] = None,
- ) -> Union[AnthropicWrapper, AsyncAnthropicWrapper]:
-     """Create a token-tracking wrapper for an Anthropic client.
-
-     Args:
-         client: Anthropic or AsyncAnthropic client instance
-         db_path: Optional path to SQLite database for token tracking
-     """
-     if isinstance(client, Anthropic):
-         return AnthropicWrapper(client=client, db_path=db_path)
-
-     if isinstance(client, AsyncAnthropic):
-         return AsyncAnthropicWrapper(client=client, db_path=db_path)
-
-     raise ValueError("Client must be an instance of Anthropic or AsyncAnthropic")
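
This module is moved rather than dropped: the new RECORD lists it at `tokenator/anthropic/client_anthropic.py`, with the streaming logic split out into `tokenator/anthropic/stream_interceptors.py`. Assuming the public entry point survived the move unchanged, usage would still look roughly like this; the import path and model name are assumptions, not shown in this diff:

```python
from anthropic import Anthropic
from tokenator import tokenator_anthropic  # assumed re-export from tokenator/__init__.py

# Wrap the client; token counts are logged to tokenator's SQLite db as a side effect.
client = tokenator_anthropic(Anthropic())

response = client.messages.create(
    model="claude-3-5-sonnet-20241022",  # any Messages API model
    max_tokens=64,
    messages=[{"role": "user", "content": "Say hello."}],
)
print(response.usage.input_tokens, response.usage.output_tokens)
```

The deleted OpenAI stream interceptor below gets the same treatment: the new RECORD places it at `tokenator/openai/stream_interceptors.py` (its 0.1.9 path is not shown in this diff).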
@@ -1,78 +0,0 @@
- import logging
- from typing import AsyncIterator, Callable, Generic, List, Optional, TypeVar
-
- from openai import AsyncStream, AsyncOpenAI
- from openai.types.chat import ChatCompletionChunk
-
- logger = logging.getLogger(__name__)
-
- _T = TypeVar("_T")  # or you might specifically do _T = ChatCompletionChunk
-
-
- class AsyncStreamInterceptor(AsyncStream[_T]):
-     """
-     A wrapper around openai.AsyncStream that delegates all functionality
-     to the 'base_stream' but intercepts each chunk to handle usage or
-     logging logic. This preserves .response and other methods.
-
-     You can store aggregated usage in a local list and process it when
-     the stream ends (StopAsyncIteration).
-     """
-
-     def __init__(
-         self,
-         base_stream: AsyncStream[_T],
-         usage_callback: Optional[Callable[[List[_T]], None]] = None,
-     ):
-         # We do NOT call super().__init__() because openai.AsyncStream
-         # expects constructor parameters we don't want to re-initialize.
-         # Instead, we just store the base_stream and delegate everything to it.
-         self._base_stream = base_stream
-         self._usage_callback = usage_callback
-         self._chunks: List[_T] = []
-
-     @property
-     def response(self):
-         """Expose the original stream's 'response' so user code can do stream.response, etc."""
-         return self._base_stream.response
-
-     def __aiter__(self) -> AsyncIterator[_T]:
-         """
-         Called when we do 'async for chunk in wrapped_stream:'
-         We simply return 'self'. Then __anext__ does the rest.
-         """
-         return self
-
-     async def __anext__(self) -> _T:
-         """
-         Intercept iteration. We pull the next chunk from the base_stream.
-         If it's the end, do any final usage logging, then raise StopAsyncIteration.
-         Otherwise, we can accumulate usage info or do whatever we need with the chunk.
-         """
-         try:
-             chunk = await self._base_stream.__anext__()
-         except StopAsyncIteration:
-             # Once the base stream is fully consumed, we can do final usage/logging.
-             if self._usage_callback and self._chunks:
-                 self._usage_callback(self._chunks)
-             raise
-
-         # Intercept each chunk
-         self._chunks.append(chunk)
-         return chunk
-
-     async def __aenter__(self) -> "AsyncStreamInterceptor[_T]":
-         """Support async with ... : usage."""
-         await self._base_stream.__aenter__()
-         return self
-
-     async def __aexit__(self, exc_type, exc_val, exc_tb):
-         """
-         Ensure we propagate __aexit__ to the base stream,
-         so connections are properly closed.
-         """
-         return await self._base_stream.__aexit__(exc_type, exc_val, exc_tb)
-
-     async def close(self) -> None:
-         """Delegate close to the base_stream."""
-         await self._base_stream.close()
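
The interceptor's core pattern, delegating each `__anext__` to the wrapped stream and firing a usage callback exactly once when the stream is exhausted, does not depend on OpenAI's types. A self-contained sketch of the idea; all names here are illustrative, not tokenator's API:

```python
import asyncio
from typing import AsyncIterator, Callable, List, Optional, TypeVar

T = TypeVar("T")

class InterceptedStream:
    """Simplified stand-in for AsyncStreamInterceptor: same pattern, no openai dependency."""

    def __init__(
        self,
        base: AsyncIterator[T],
        on_done: Optional[Callable[[List[T]], None]] = None,
    ):
        self._base = base
        self._on_done = on_done
        self._chunks: List[T] = []

    def __aiter__(self) -> "InterceptedStream":
        return self

    async def __anext__(self) -> T:
        try:
            chunk = await self._base.__anext__()
        except StopAsyncIteration:
            # Fire the callback exactly once, when the stream is fully drained.
            if self._on_done and self._chunks:
                self._on_done(self._chunks)
            raise
        self._chunks.append(chunk)
        return chunk

async def demo() -> None:
    async def fake_stream() -> AsyncIterator[str]:
        for token in ["hel", "lo"]:
            yield token

    stream = InterceptedStream(fake_stream(), on_done=lambda c: print("saw", len(c), "chunks"))
    async for chunk in stream:
        print(chunk)

asyncio.run(demo())  # prints hel, lo, then "saw 2 chunks"
```

The real class additionally proxies `.response`, the async context-manager protocol, and `close()` to the wrapped `AsyncStream`, so it can be dropped in wherever the raw stream was used.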