llm-cost-guard 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. llm_cost_guard/__init__.py +39 -0
  2. llm_cost_guard/backends/__init__.py +52 -0
  3. llm_cost_guard/backends/base.py +121 -0
  4. llm_cost_guard/backends/memory.py +265 -0
  5. llm_cost_guard/backends/sqlite.py +425 -0
  6. llm_cost_guard/budget.py +306 -0
  7. llm_cost_guard/cli.py +464 -0
  8. llm_cost_guard/clients/__init__.py +11 -0
  9. llm_cost_guard/clients/anthropic.py +231 -0
  10. llm_cost_guard/clients/openai.py +262 -0
  11. llm_cost_guard/exceptions.py +71 -0
  12. llm_cost_guard/integrations/__init__.py +12 -0
  13. llm_cost_guard/integrations/cache.py +189 -0
  14. llm_cost_guard/integrations/langchain.py +257 -0
  15. llm_cost_guard/models.py +123 -0
  16. llm_cost_guard/pricing/__init__.py +7 -0
  17. llm_cost_guard/pricing/anthropic.yaml +88 -0
  18. llm_cost_guard/pricing/bedrock.yaml +215 -0
  19. llm_cost_guard/pricing/loader.py +221 -0
  20. llm_cost_guard/pricing/openai.yaml +148 -0
  21. llm_cost_guard/pricing/vertex.yaml +133 -0
  22. llm_cost_guard/providers/__init__.py +69 -0
  23. llm_cost_guard/providers/anthropic.py +115 -0
  24. llm_cost_guard/providers/base.py +72 -0
  25. llm_cost_guard/providers/bedrock.py +135 -0
  26. llm_cost_guard/providers/openai.py +110 -0
  27. llm_cost_guard/rate_limit.py +233 -0
  28. llm_cost_guard/span.py +143 -0
  29. llm_cost_guard/tokenizers/__init__.py +7 -0
  30. llm_cost_guard/tokenizers/base.py +207 -0
  31. llm_cost_guard/tracker.py +718 -0
  32. llm_cost_guard-0.1.0.dist-info/METADATA +357 -0
  33. llm_cost_guard-0.1.0.dist-info/RECORD +36 -0
  34. llm_cost_guard-0.1.0.dist-info/WHEEL +4 -0
  35. llm_cost_guard-0.1.0.dist-info/entry_points.txt +2 -0
  36. llm_cost_guard-0.1.0.dist-info/licenses/LICENSE +21 -0
llm_cost_guard/cli.py ADDED
@@ -0,0 +1,464 @@
1
+ """
2
+ Command-line interface for LLM Cost Guard.
3
+ """
4
+
5
+ import argparse
6
+ import json
7
+ import sys
8
+ from datetime import datetime, timedelta
9
+ from typing import Any, Dict, Optional
10
+
11
+ from llm_cost_guard import CostTracker
12
+ from llm_cost_guard.pricing.loader import PricingLoader
13
+
14
+
15
def create_parser() -> argparse.ArgumentParser:
    """Build the argument parser for the llm-cost-guard CLI.

    Returns:
        A parser with one subcommand per supported operation: status,
        health, report, pricing-status, update-pricing, export,
        validate-config, and models.
    """
    parser = argparse.ArgumentParser(
        prog="llm-cost-guard",
        description="LLM Cost Guard - Real-time cost tracking and budget enforcement for LLM applications",
    )

    subparsers = parser.add_subparsers(dest="command", help="Available commands")

    # status command
    status_parser = subparsers.add_parser("status", help="View current costs and budget status")
    status_parser.add_argument(
        "--backend",
        default="memory",
        help="Backend URL (default: memory)",
    )

    # health command
    health_parser = subparsers.add_parser("health", help="Check tracker health")
    health_parser.add_argument(
        "--backend",
        default="memory",
        help="Backend URL (default: memory)",
    )

    # report command
    report_parser = subparsers.add_parser("report", help="Generate cost report")
    report_parser.add_argument(
        "--period",
        choices=["day", "week", "month"],
        default="day",
        help="Report period (default: day)",
    )
    report_parser.add_argument(
        "--group-by",
        nargs="+",
        help="Group by fields (e.g., --group-by model provider)",
    )
    report_parser.add_argument(
        "--format",
        choices=["text", "json", "csv"],
        default="text",
        help="Output format (default: text)",
    )
    report_parser.add_argument(
        "--backend",
        default="memory",
        help="Backend URL (default: memory)",
    )

    # pricing-status command — takes no options, so the parser object is not
    # kept (the original bound it to an unused local).
    subparsers.add_parser("pricing-status", help="Check pricing data status")

    # update-pricing command — also option-less.
    subparsers.add_parser("update-pricing", help="Update pricing data")

    # export command
    export_parser = subparsers.add_parser("export", help="Export cost data")
    export_parser.add_argument(
        "--format",
        choices=["json", "csv"],
        default="json",
        help="Export format (default: json)",
    )
    export_parser.add_argument(
        "--output",
        "-o",
        help="Output file path (default: stdout)",
    )
    export_parser.add_argument(
        "--start-date",
        help="Start date (ISO format)",
    )
    export_parser.add_argument(
        "--end-date",
        help="End date (ISO format)",
    )
    export_parser.add_argument(
        "--backend",
        default="memory",
        help="Backend URL (default: memory)",
    )

    # validate-config command
    validate_parser = subparsers.add_parser("validate-config", help="Validate configuration")
    validate_parser.add_argument(
        "--config",
        "-c",
        help="Configuration file path",
    )

    # models command
    models_parser = subparsers.add_parser("models", help="List supported models and pricing")
    models_parser.add_argument(
        "--provider",
        help="Filter by provider (e.g., openai, anthropic, bedrock)",
    )

    return parser
114
+
115
+
116
def cmd_status(args: argparse.Namespace) -> int:
    """Print today's cost summary plus a short health readout; returns 0."""
    tracker = CostTracker(backend=args.backend)

    try:
        summary = tracker.daily_report()

        banner = "=" * 50
        print(banner)
        print("LLM Cost Guard - Status")
        print(banner)
        print()

        today = datetime.now().strftime('%Y-%m-%d')
        success_pct = summary.successful_calls / max(1, summary.total_calls) * 100
        print(f"Today's Summary ({today}):")
        print(f" Total Cost: ${summary.total_cost:.4f}")
        print(f" Total Calls: {summary.total_calls}")
        print(f" Input Tokens: {summary.total_input_tokens:,}")
        print(f" Output Tokens: {summary.total_output_tokens:,}")
        print(f" Success Rate: {success_pct:.1f}%")

        # Cache lines only appear when at least one hit was recorded.
        if summary.cache_hits > 0:
            print(f" Cache Hits: {summary.cache_hits}")
            print(f" Cache Savings: ${summary.cache_savings:.4f}")

        print()

        status = tracker.health_check()
        print("Health Status:")
        print(f" Backend: {'✓ Connected' if status.backend_connected else '✗ Disconnected'}")
        print(f" Pricing: {'✓ Fresh' if status.pricing_fresh else '⚠ Stale'}")

        if status.errors:
            print(" Errors:")
            for problem in status.errors:
                print(f" - {problem}")

        return 0

    finally:
        # Always release backend resources, even if a report call raised.
        tracker.close()
156
+
157
+
158
def cmd_health(args: argparse.Namespace) -> int:
    """Run the tracker's health check; exit code 0 when healthy, 1 otherwise."""
    tracker = CostTracker(backend=args.backend)

    try:
        result = tracker.health_check()

        print("LLM Cost Guard - Health Check")
        print("-" * 40)

        # Each row: label, text when OK, text when not OK, the flag itself.
        rows = (
            ("Overall", "✓ Healthy", "✗ Unhealthy", result.healthy),
            ("Backend", "✓ Connected", "✗ Disconnected", result.backend_connected),
            ("Pricing", "✓ Fresh", "⚠ Stale", result.pricing_fresh),
        )
        for label, good, bad, flag in rows:
            print(f"{label}: {good if flag else bad}")

        if result.last_record_time:
            print(f"Last Record: {result.last_record_time.isoformat()}")

        if result.pricing_last_updated:
            print(f"Pricing Updated: {result.pricing_last_updated.isoformat()}")

        if result.errors:
            print("\nErrors:")
            for problem in result.errors:
                print(f" - {problem}")

        return 0 if result.healthy else 1

    finally:
        tracker.close()
186
+
187
+
188
def cmd_report(args: argparse.Namespace) -> int:
    """Render a cost report for the chosen period in text, JSON, or CSV."""
    tracker = CostTracker(backend=args.backend)

    try:
        # Resolve the start of the reporting window from today's midnight.
        now = datetime.now()
        midnight = now.replace(hour=0, minute=0, second=0, microsecond=0)
        if args.period == "day":
            start = midnight
        elif args.period == "week":
            start = midnight - timedelta(days=now.weekday())
        else:  # month
            start = midnight.replace(day=1)

        report = tracker.get_costs(
            start_date=start.isoformat(),
            group_by=args.group_by,
        )

        if args.format == "json":
            payload = {
                "period": args.period,
                "start_date": start.isoformat(),
                "end_date": now.isoformat(),
                "total_cost": report.total_cost,
                "total_calls": report.total_calls,
                "total_input_tokens": report.total_input_tokens,
                "total_output_tokens": report.total_output_tokens,
                "successful_calls": report.successful_calls,
                "failed_calls": report.failed_calls,
            }
            if report.grouped_data:
                payload["groups"] = report.grouped_data.get("groups", [])
            print(json.dumps(payload, indent=2))

        elif args.format == "csv":
            rows = (report.grouped_data or {}).get("groups") or []
            if rows:
                # Columns come from the first group's keys.
                columns = list(rows[0].keys())
                print(",".join(columns))
                for row in rows:
                    print(",".join(str(row.get(col, "")) for col in columns))
            else:
                # Ungrouped fallback: a single summary row.
                print("date,cost,calls,input_tokens,output_tokens")
                print(
                    f"{start.strftime('%Y-%m-%d')},{report.total_cost:.4f},"
                    f"{report.total_calls},{report.total_input_tokens},{report.total_output_tokens}"
                )

        else:  # text
            print(f"\nCost Report - {args.period.title()}")
            print(f"Period: {start.strftime('%Y-%m-%d')} to {now.strftime('%Y-%m-%d')}")
            print("=" * 60)
            print(f"Total Cost: ${report.total_cost:.4f}")
            print(f"Total Calls: {report.total_calls}")
            print(f"Input Tokens: {report.total_input_tokens:,}")
            print(f"Output Tokens: {report.total_output_tokens:,}")
            print(f"Successful: {report.successful_calls}")
            print(f"Failed: {report.failed_calls}")

            groups = (report.grouped_data or {}).get("groups")
            if groups:
                keys = args.group_by or []
                print(f"\nBy {', '.join(keys)}:")
                print("-" * 60)
                for row in groups:
                    # Only include group-by keys actually present in the row.
                    label = ", ".join(f"{k}={row[k]}" for k in keys if k in row)
                    cost = row.get("cost", 0)
                    calls = row.get("calls", 0)
                    print(f" {label}: ${cost:.4f} ({calls} calls)")

        return 0

    finally:
        tracker.close()
271
+
272
+
273
def cmd_pricing_status(args: argparse.Namespace) -> int:
    """Show when pricing data was last loaded and whether it is stale."""
    loader = PricingLoader()

    print("LLM Cost Guard - Pricing Status")
    print("-" * 40)

    updated = loader.last_updated
    if updated:
        print(f"Last Updated: {updated.isoformat()}")
    else:
        print("Last Updated: Never")

    print(f"Stale: {'Yes ⚠' if loader.is_stale else 'No ✓'}")
    print(f"Very Stale: {'Yes ✗' if loader.is_very_stale else 'No ✓'}")

    print("\nProvider Versions:")
    for provider_name, version in loader.pricing_version.items():
        print(f" {provider_name}: {version}")

    return 0
293
+
294
+
295
def cmd_update_pricing(args: argparse.Namespace) -> int:
    """Reload pricing data from the bundled YAML files; returns 0."""
    print("Refreshing pricing data from local files...")

    loader = PricingLoader()
    loader.refresh()

    stamp = loader.last_updated.isoformat() if loader.last_updated else "Never"
    print("✓ Pricing data refreshed")
    print(f" Last updated: {stamp}")

    return 0
306
+
307
+
308
def cmd_export(args: argparse.Namespace) -> int:
    """Handle the export command.

    Dumps raw usage records from the backend as JSON or CSV, to stdout or
    to the path given by ``--output``.

    Returns:
        0 on success, 1 when ``--start-date`` / ``--end-date`` are not
        valid ISO-format dates.
    """
    # Parse the date filters before touching the backend so a bad value
    # fails fast with a message instead of an unhandled ValueError traceback
    # (and without creating a tracker we would then have to close).
    try:
        start = datetime.fromisoformat(args.start_date) if args.start_date else None
        end = datetime.fromisoformat(args.end_date) if args.end_date else None
    except ValueError as e:
        print(f"✗ Invalid date: {e}")
        return 1

    tracker = CostTracker(backend=args.backend)

    try:
        # NOTE(review): reaches into the tracker's private backend;
        # presumably CostTracker has no public record-listing API yet —
        # confirm before promoting this to a supported path.
        records = tracker._backend.get_records(start_date=start, end_date=end)

        if args.format == "json":
            data = [
                {
                    "timestamp": r.timestamp.isoformat(),
                    "provider": r.provider,
                    "model": r.model,
                    "input_tokens": r.input_tokens,
                    "output_tokens": r.output_tokens,
                    "input_cost": r.input_cost,
                    "output_cost": r.output_cost,
                    "total_cost": r.total_cost,
                    "latency_ms": r.latency_ms,
                    "success": r.success,
                    "error_type": r.error_type,
                    "cached": r.cached,
                    "tags": r.tags,
                }
                for r in records
            ]
            output = json.dumps(data, indent=2)

        else:  # csv
            # NOTE(review): values are not CSV-escaped; a comma inside a
            # model or provider name would corrupt the row.
            lines = [
                "timestamp,provider,model,input_tokens,output_tokens,total_cost,latency_ms,success"
            ]
            for r in records:
                lines.append(
                    f"{r.timestamp.isoformat()},{r.provider},{r.model},"
                    f"{r.input_tokens},{r.output_tokens},{r.total_cost:.6f},"
                    f"{r.latency_ms},{r.success}"
                )
            output = "\n".join(lines)

        if args.output:
            with open(args.output, "w") as f:
                f.write(output)
            print(f"Exported {len(records)} records to {args.output}")
        else:
            print(output)

        return 0

    finally:
        tracker.close()
363
+
364
+
365
def cmd_validate_config(args: argparse.Namespace) -> int:
    """Handle the validate-config command.

    Loads the YAML file given by ``--config`` and runs lightweight
    structural checks on its ``budgets`` and ``rate_limits`` sections.

    Returns:
        0 when the file parses (field warnings are printed but non-fatal),
        1 when no file was given, PyYAML is missing, the file is absent,
        or the YAML is invalid.
    """
    if not args.config:
        print("No configuration file specified.")
        print("Use --config to specify a configuration file.")
        return 1

    # Import up front. The original imported yaml inside the big try block,
    # so when PyYAML was missing the resulting ImportError made Python
    # evaluate ``except yaml.YAMLError`` with ``yaml`` unbound, crashing
    # with NameError instead of a useful message.
    try:
        import yaml
    except ImportError:
        print("✗ PyYAML is required for config validation. Install with: pip install pyyaml")
        return 1

    try:
        with open(args.config, "r") as f:
            config = yaml.safe_load(f)

        # An empty file parses to None; treat it as an empty mapping so the
        # membership checks below don't raise TypeError.
        if config is None:
            config = {}

        print(f"Configuration file: {args.config}")
        print("-" * 40)

        # Validate budgets: each entry needs at least a name and a limit.
        if "budgets" in config:
            print(f"✓ Found {len(config['budgets'])} budget(s)")
            for budget in config["budgets"]:
                if "name" not in budget:
                    print(" ✗ Budget missing 'name' field")
                if "limit" not in budget:
                    print(f" ✗ Budget '{budget.get('name', 'unknown')}' missing 'limit' field")

        # Validate rate limits (presence only; shape is not checked here).
        if "rate_limits" in config:
            print(f"✓ Found {len(config['rate_limits'])} rate limit(s)")

        print("\n✓ Configuration is valid")
        return 0

    except FileNotFoundError:
        print(f"✗ Configuration file not found: {args.config}")
        return 1
    except yaml.YAMLError as e:
        print(f"✗ Invalid YAML: {e}")
        return 1
    except Exception as e:
        # Catch-all boundary for a CLI command: report and signal failure.
        print(f"✗ Error validating configuration: {e}")
        return 1
406
+
407
+
408
def cmd_models(args: argparse.Namespace) -> int:
    """List known models per provider with their per-1K-token prices."""
    loader = PricingLoader()
    catalog = loader.get_all_models(args.provider)

    print("Supported Models and Pricing")
    print("=" * 70)

    for provider_name, model_names in sorted(catalog.items()):
        print(f"\n{provider_name.upper()}")
        print("-" * 70)

        for model_name in sorted(model_names):
            try:
                price = loader.get_pricing(provider_name, model_name)
            except Exception:
                # A model can be listed without a resolvable price entry.
                print(f" {model_name:40} (pricing unavailable)")
            else:
                print(
                    f" {model_name:40} "
                    f"Input: ${price.input_cost_per_1k:.6f}/1K "
                    f"Output: ${price.output_cost_per_1k:.6f}/1K"
                )

    return 0
433
+
434
+
435
def main() -> int:
    """CLI entry point: parse argv and dispatch to the matching handler."""
    parser = create_parser()
    args = parser.parse_args()

    # No subcommand: show usage and exit successfully.
    if not args.command:
        parser.print_help()
        return 0

    dispatch = {
        "status": cmd_status,
        "health": cmd_health,
        "report": cmd_report,
        "pricing-status": cmd_pricing_status,
        "update-pricing": cmd_update_pricing,
        "export": cmd_export,
        "validate-config": cmd_validate_config,
        "models": cmd_models,
    }

    handler = dispatch.get(args.command)
    if handler is None:
        print(f"Unknown command: {args.command}")
        return 1
    return handler(args)
461
+
462
+
463
# Allow invoking the CLI directly (python -m / script execution); the
# handler's return value becomes the process exit code.
if __name__ == "__main__":
    sys.exit(main())
@@ -0,0 +1,11 @@
1
+ """
2
+ Wrapped LLM clients with automatic cost tracking.
3
+ """
4
+
5
+ from llm_cost_guard.clients.openai import TrackedOpenAI
6
+ from llm_cost_guard.clients.anthropic import TrackedAnthropic
7
+
8
+ __all__ = [
9
+ "TrackedOpenAI",
10
+ "TrackedAnthropic",
11
+ ]
@@ -0,0 +1,231 @@
1
+ """
2
+ Wrapped Anthropic client with automatic cost tracking.
3
+ """
4
+
5
+ import time
6
+ from typing import Any, Dict, Optional, TYPE_CHECKING
7
+
8
+ if TYPE_CHECKING:
9
+ from llm_cost_guard import CostTracker
10
+
11
+
12
class TrackedAnthropic:
    """Anthropic client wrapper that records cost for every call.

    Example::

        from llm_cost_guard import CostTracker
        from llm_cost_guard.clients import TrackedAnthropic

        tracker = CostTracker()
        client = TrackedAnthropic(tracker=tracker)

        response = client.messages.create(
            model="claude-3-5-sonnet-20241022",
            messages=[{"role": "user", "content": "Hello!"}]
        )
        # usage and cost are recorded automatically
    """

    def __init__(
        self,
        tracker: "CostTracker",
        client: Optional[Any] = None,
        tags: Optional[Dict[str, str]] = None,
        **anthropic_kwargs: Any,
    ):
        """Wrap an Anthropic client, building one on demand if not supplied.

        Args:
            tracker: CostTracker that receives one record per call.
            client: Existing Anthropic client to wrap; when omitted a new
                one is constructed from ``anthropic_kwargs``.
            tags: Default tags attached to every tracked call.
            **anthropic_kwargs: Forwarded to the ``Anthropic(...)`` constructor.

        Raises:
            ImportError: if the ``anthropic`` package is not installed.
        """
        try:
            from anthropic import Anthropic
        except ImportError:
            raise ImportError(
                "Anthropic is required for this client. "
                "Install with: pip install llm-cost-guard[anthropic]"
            )

        self._tracker = tracker
        self._default_tags = tags or {}
        self._client = client or Anthropic(**anthropic_kwargs)

        # Mirror the SDK surface: expose the wrapped resource as ``.messages``.
        self.messages = _TrackedMessages(self._client.messages, self._tracker, self._default_tags)

    def close(self) -> None:
        """Release the underlying Anthropic client's resources."""
        self._client.close()

    def __enter__(self):
        return self

    def __exit__(self, *args):
        # Context-manager exit simply closes the wrapped client.
        self.close()
70
+
71
+
72
class _TrackedMessages:
    """Wrapped messages API.

    Proxies the Anthropic SDK's ``messages`` resource and records token
    usage with the tracker after each successful call. Note: failed
    ``create`` calls are NOT recorded — when the underlying call raises,
    ``response`` stays ``None`` and the ``finally`` block skips
    ``_record_response``, so the captured ``success``/``error_type``
    values are never persisted.
    """

    def __init__(self, messages: Any, tracker: "CostTracker", default_tags: Dict[str, str]):
        # The underlying SDK ``messages`` resource being wrapped.
        self._messages = messages
        self._tracker = tracker
        # Applied to every call; per-call ``tags`` are merged on top.
        self._default_tags = default_tags

    def create(
        self,
        *,
        tags: Optional[Dict[str, str]] = None,
        **kwargs: Any,
    ) -> Any:
        """Create a message with tracking.

        Args:
            tags: Extra tags merged over the client's default tags.
            **kwargs: Forwarded verbatim to ``messages.create``; ``model``
                is also read here as a fallback for the recorded model name.

        Returns:
            The untouched Anthropic response object.

        Raises:
            Whatever the underlying ``messages.create`` raises (re-raised
            unchanged after noting the error type).
        """
        start_time = time.time()
        success = True
        error_type = None
        response = None

        try:
            response = self._messages.create(**kwargs)
            return response
        except Exception as e:
            success = False
            error_type = type(e).__name__
            raise
        finally:
            # Runs on both the return and the raise path, so latency covers
            # the full call either way.
            latency_ms = int((time.time() - start_time) * 1000)

            # ``response`` is only set on success; failures fall through
            # without being recorded (see class docstring).
            if response is not None:
                self._record_response(
                    response, kwargs.get("model"), tags, success, error_type, latency_ms
                )

    def _record_response(
        self,
        response: Any,
        model_hint: Optional[str],
        tags: Optional[Dict[str, str]],
        success: bool,
        error_type: Optional[str],
        latency_ms: int,
    ) -> None:
        """Record the response with the tracker.

        Extracts usage/model via ``AnthropicProvider`` and emits a single
        ``tracker.record`` call tagged with the merged tag set.
        """
        # Local import — presumably to avoid a module-level cycle with the
        # providers package; confirm before hoisting to the top of the file.
        from llm_cost_guard.providers.anthropic import AnthropicProvider

        provider = AnthropicProvider()
        usage = provider.extract_usage(response)
        model = provider.extract_model(response)

        # Fall back to the model the caller requested when the response did
        # not identify one.
        if model == "unknown" and model_hint:
            model = model_hint

        # Merge: per-call tags win over the client-wide defaults.
        all_tags = dict(self._default_tags)
        if tags:
            all_tags.update(tags)

        self._tracker.record(
            provider="anthropic",
            model=model,
            input_tokens=usage.input_tokens,
            output_tokens=usage.output_tokens,
            tags=all_tags,
            success=success,
            error_type=error_type,
            latency_ms=latency_ms,
            cached_tokens=usage.cached_tokens,
        )

    def stream(
        self,
        *,
        tags: Optional[Dict[str, str]] = None,
        **kwargs: Any,
    ) -> Any:
        """Create a streaming message with tracking.

        Returns a ``_TrackedStream`` wrapping the SDK stream; usage is
        recorded when the stream's context manager exits, not here.
        """
        # For streaming, we wrap the stream to track after completion
        start_time = time.time()

        stream = self._messages.stream(**kwargs)

        # Wrap in a tracking context
        return _TrackedStream(
            stream,
            self._tracker,
            self._default_tags,
            tags,
            kwargs.get("model"),
            start_time,
        )
163
+
164
+
165
+ class _TrackedStream:
166
+ """Wrapper for streaming responses."""
167
+
168
+ def __init__(
169
+ self,
170
+ stream: Any,
171
+ tracker: "CostTracker",
172
+ default_tags: Dict[str, str],
173
+ tags: Optional[Dict[str, str]],
174
+ model_hint: Optional[str],
175
+ start_time: float,
176
+ ):
177
+ self._stream = stream
178
+ self._tracker = tracker
179
+ self._default_tags = default_tags
180
+ self._tags = tags
181
+ self._model_hint = model_hint
182
+ self._start_time = start_time
183
+ self._input_tokens = 0
184
+ self._output_tokens = 0
185
+ self._model = "unknown"
186
+
187
+ def __enter__(self):
188
+ self._stream.__enter__()
189
+ return self
190
+
191
+ def __exit__(self, *args):
192
+ self._stream.__exit__(*args)
193
+
194
+ # Record the call
195
+ latency_ms = int((time.time() - self._start_time) * 1000)
196
+
197
+ all_tags = dict(self._default_tags)
198
+ if self._tags:
199
+ all_tags.update(self._tags)
200
+
201
+ model = self._model if self._model != "unknown" else (self._model_hint or "unknown")
202
+
203
+ self._tracker.record(
204
+ provider="anthropic",
205
+ model=model,
206
+ input_tokens=self._input_tokens,
207
+ output_tokens=self._output_tokens,
208
+ tags=all_tags,
209
+ success=True,
210
+ latency_ms=latency_ms,
211
+ )
212
+
213
+ def __iter__(self):
214
+ for event in self._stream:
215
+ # Track usage from events
216
+ self._handle_event(event)
217
+ yield event
218
+
219
+ def _handle_event(self, event: Any) -> None:
220
+ """Extract usage info from streaming event."""
221
+ event_type = getattr(event, "type", "")
222
+
223
+ if event_type == "message_start":
224
+ if hasattr(event, "message"):
225
+ self._model = getattr(event.message, "model", self._model)
226
+ if hasattr(event.message, "usage"):
227
+ self._input_tokens = getattr(event.message.usage, "input_tokens", 0)
228
+
229
+ elif event_type == "message_delta":
230
+ if hasattr(event, "usage"):
231
+ self._output_tokens = getattr(event.usage, "output_tokens", 0)