llm-cost-guard 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. llm_cost_guard/__init__.py +39 -0
  2. llm_cost_guard/backends/__init__.py +52 -0
  3. llm_cost_guard/backends/base.py +121 -0
  4. llm_cost_guard/backends/memory.py +265 -0
  5. llm_cost_guard/backends/sqlite.py +425 -0
  6. llm_cost_guard/budget.py +306 -0
  7. llm_cost_guard/cli.py +464 -0
  8. llm_cost_guard/clients/__init__.py +11 -0
  9. llm_cost_guard/clients/anthropic.py +231 -0
  10. llm_cost_guard/clients/openai.py +262 -0
  11. llm_cost_guard/exceptions.py +71 -0
  12. llm_cost_guard/integrations/__init__.py +12 -0
  13. llm_cost_guard/integrations/cache.py +189 -0
  14. llm_cost_guard/integrations/langchain.py +257 -0
  15. llm_cost_guard/models.py +123 -0
  16. llm_cost_guard/pricing/__init__.py +7 -0
  17. llm_cost_guard/pricing/anthropic.yaml +88 -0
  18. llm_cost_guard/pricing/bedrock.yaml +215 -0
  19. llm_cost_guard/pricing/loader.py +221 -0
  20. llm_cost_guard/pricing/openai.yaml +148 -0
  21. llm_cost_guard/pricing/vertex.yaml +133 -0
  22. llm_cost_guard/providers/__init__.py +69 -0
  23. llm_cost_guard/providers/anthropic.py +115 -0
  24. llm_cost_guard/providers/base.py +72 -0
  25. llm_cost_guard/providers/bedrock.py +135 -0
  26. llm_cost_guard/providers/openai.py +110 -0
  27. llm_cost_guard/rate_limit.py +233 -0
  28. llm_cost_guard/span.py +143 -0
  29. llm_cost_guard/tokenizers/__init__.py +7 -0
  30. llm_cost_guard/tokenizers/base.py +207 -0
  31. llm_cost_guard/tracker.py +718 -0
  32. llm_cost_guard-0.1.0.dist-info/METADATA +357 -0
  33. llm_cost_guard-0.1.0.dist-info/RECORD +36 -0
  34. llm_cost_guard-0.1.0.dist-info/WHEEL +4 -0
  35. llm_cost_guard-0.1.0.dist-info/entry_points.txt +2 -0
  36. llm_cost_guard-0.1.0.dist-info/licenses/LICENSE +21 -0
llm_cost_guard/providers/base.py ADDED
@@ -0,0 +1,72 @@
+ """
+ Base provider interface for LLM Cost Guard.
+ """
+
+ from abc import ABC, abstractmethod
+ from typing import Any, Dict, Optional
+
+ from llm_cost_guard.models import UsageData
+
+
+ class Provider(ABC):
+     """Abstract base class for LLM providers."""
+
+     @property
+     @abstractmethod
+     def name(self) -> str:
+         """Get the provider name."""
+         pass
+
+     @abstractmethod
+     def extract_usage(self, response: Any) -> UsageData:
+         """
+         Extract token usage from an API response.
+
+         Args:
+             response: The raw API response object
+
+         Returns:
+             UsageData with token counts
+         """
+         pass
+
+     @abstractmethod
+     def extract_model(self, response: Any) -> str:
+         """
+         Extract the model name from an API response.
+
+         Args:
+             response: The raw API response object
+
+         Returns:
+             Model name string
+         """
+         pass
+
+     def extract_cached_tokens(self, response: Any) -> int:
+         """
+         Extract cached token count from an API response.
+
+         Args:
+             response: The raw API response object
+
+         Returns:
+             Number of cached tokens (0 if not applicable)
+         """
+         return 0
+
+     def supports_streaming(self) -> bool:
+         """Check if this provider supports streaming."""
+         return True
+
+     def normalize_model_name(self, model: str) -> str:
+         """
+         Normalize model name for consistent pricing lookup.
+
+         Args:
+             model: Raw model name from API
+
+         Returns:
+             Normalized model name
+         """
+         return model
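For orientation, here is a minimal sketch of how a concrete provider could subclass this interface. The EchoProvider name and the dict shape it reads are illustrative only, not part of the package:

from typing import Any

from llm_cost_guard.models import UsageData
from llm_cost_guard.providers.base import Provider


class EchoProvider(Provider):
    # Hypothetical provider that reads token counts from a plain dict response.

    @property
    def name(self) -> str:
        return "echo"

    def extract_usage(self, response: Any) -> UsageData:
        usage = response.get("usage", {}) if isinstance(response, dict) else {}
        input_tokens = usage.get("input", 0)
        output_tokens = usage.get("output", 0)
        return UsageData(
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            total_tokens=input_tokens + output_tokens,
        )

    def extract_model(self, response: Any) -> str:
        if isinstance(response, dict):
            return response.get("model", "unknown")
        return "unknown"

The optional hooks (extract_cached_tokens, supports_streaming, normalize_model_name) keep their base-class defaults unless a provider overrides them.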
llm_cost_guard/providers/bedrock.py ADDED
@@ -0,0 +1,135 @@
+ """
+ AWS Bedrock provider for LLM Cost Guard.
+ """
+
+ import json
+ from typing import Any
+
+ from llm_cost_guard.models import UsageData
+ from llm_cost_guard.providers.base import Provider
+
+
+ class BedrockProvider(Provider):
+     """AWS Bedrock API provider."""
+
+     @property
+     def name(self) -> str:
+         return "bedrock"
+
+     def extract_usage(self, response: Any) -> UsageData:
+         """Extract token usage from a Bedrock API response."""
+         usage = UsageData()
+
+         # Handle dictionary response (from boto3)
+         if isinstance(response, dict):
+             # Check for standard usage field
+             if "usage" in response:
+                 usage_data = response["usage"]
+                 usage.input_tokens = usage_data.get("inputTokens", 0)
+                 usage.output_tokens = usage_data.get("outputTokens", 0)
+                 usage.total_tokens = usage_data.get("totalTokens", 0)
+                 return usage
+
+             # Check for Bedrock response metadata
+             metadata = response.get("ResponseMetadata", {})
+             headers = metadata.get("HTTPHeaders", {})
+
+             # Bedrock includes token counts in headers for some models
+             if "x-amzn-bedrock-input-token-count" in headers:
+                 usage.input_tokens = int(headers["x-amzn-bedrock-input-token-count"])
+             if "x-amzn-bedrock-output-token-count" in headers:
+                 usage.output_tokens = int(headers["x-amzn-bedrock-output-token-count"])
+
+             # Also check the body for Claude-style responses
+             body = response.get("body", {})
+             if isinstance(body, bytes):
+                 try:
+                     body = json.loads(body.decode("utf-8"))
+                 except (json.JSONDecodeError, UnicodeDecodeError):
+                     body = {}
+
+             if isinstance(body, dict) and "usage" in body:
+                 body_usage = body["usage"]
+                 usage.input_tokens = body_usage.get(
+                     "input_tokens", body_usage.get("inputTokens", usage.input_tokens)
+                 )
+                 usage.output_tokens = body_usage.get(
+                     "output_tokens", body_usage.get("outputTokens", usage.output_tokens)
+                 )
+
+             usage.total_tokens = usage.input_tokens + usage.output_tokens
+
+         return usage
+
+     def extract_model(self, response: Any) -> str:
+         """Extract the model name from a Bedrock API response."""
+         if isinstance(response, dict):
+             # Check for model ID in various locations
+             if "modelId" in response:
+                 return response["modelId"]
+
+             # Check response metadata
+             metadata = response.get("ResponseMetadata", {})
+             if "modelId" in metadata:
+                 return metadata["modelId"]
+
+         return "unknown"
+
+     def normalize_model_name(self, model: str) -> str:
+         """Normalize Bedrock model name for pricing lookup."""
+         # Bedrock model IDs include version info that should be preserved
+         # e.g., "anthropic.claude-3-sonnet-20240229-v1:0"
+         return model
+
+
+ class BedrockStreamingHandler:
+     """Handler for streaming Bedrock responses."""
+
+     def __init__(self, model_id: str = "unknown"):
+         self.input_tokens = 0
+         self.output_tokens = 0
+         self.model = model_id
+         self._chunks: list = []
+
+     def handle_chunk(self, chunk: Any) -> None:
+         """Process a streaming chunk."""
+         if isinstance(chunk, dict):
+             # Handle Bedrock's streaming chunk format
+             if "chunk" in chunk:
+                 chunk_data = chunk["chunk"]
+                 if "bytes" in chunk_data:
+                     try:
+                         parsed = json.loads(chunk_data["bytes"].decode("utf-8"))
+                         self._process_parsed_chunk(parsed)
+                     except (json.JSONDecodeError, UnicodeDecodeError, AttributeError):
+                         pass
+
+             # Some responses have direct usage info
+             if "usage" in chunk:
+                 usage = chunk["usage"]
+                 self.input_tokens = usage.get("inputTokens", self.input_tokens)
+                 self.output_tokens = usage.get("outputTokens", self.output_tokens)
+
+         self._chunks.append(chunk)
+
+     def _process_parsed_chunk(self, parsed: dict) -> None:
+         """Process a parsed JSON chunk."""
+         # Claude on Bedrock format
+         if "usage" in parsed:
+             usage = parsed["usage"]
+             self.input_tokens = usage.get("input_tokens", self.input_tokens)
+             self.output_tokens = usage.get("output_tokens", self.output_tokens)
+
+         # Llama/Titan format
+         if "amazon-bedrock-invocationMetrics" in parsed:
+             metrics = parsed["amazon-bedrock-invocationMetrics"]
+             self.input_tokens = metrics.get("inputTokenCount", self.input_tokens)
+             self.output_tokens = metrics.get("outputTokenCount", self.output_tokens)
+
+     def get_usage(self) -> UsageData:
+         """Get final usage data."""
+         return UsageData(
+             input_tokens=self.input_tokens,
+             output_tokens=self.output_tokens,
+             total_tokens=self.input_tokens + self.output_tokens,
+         )
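As a quick sanity check, here is how extract_usage would behave on a boto3-style invoke_model response whose token counts arrive only in HTTP headers; the literal values below are made up for illustration:

from llm_cost_guard.providers.bedrock import BedrockProvider

provider = BedrockProvider()

response = {
    "ResponseMetadata": {
        "HTTPHeaders": {
            "x-amzn-bedrock-input-token-count": "42",
            "x-amzn-bedrock-output-token-count": "128",
        }
    },
}

usage = provider.extract_usage(response)
# Header counts are parsed as ints, then summed:
# usage.input_tokens == 42, usage.output_tokens == 128, usage.total_tokens == 170

A Claude-style body carrying its own "usage" dict would take precedence over the header values, since the body lookup falls back to the header-derived counts only when its keys are absent.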
llm_cost_guard/providers/openai.py ADDED
@@ -0,0 +1,110 @@
+ """
+ OpenAI provider for LLM Cost Guard.
+ """
+
+ from typing import Any, Optional
+
+ from llm_cost_guard.models import UsageData
+ from llm_cost_guard.providers.base import Provider
+
+
+ class OpenAIProvider(Provider):
+     """OpenAI API provider."""
+
+     @property
+     def name(self) -> str:
+         return "openai"
+
+     def extract_usage(self, response: Any) -> UsageData:
+         """Extract token usage from an OpenAI API response."""
+         usage = UsageData()
+
+         # Handle dictionary response
+         if isinstance(response, dict):
+             usage_data = response.get("usage", {})
+             usage.input_tokens = usage_data.get("prompt_tokens", 0)
+             usage.output_tokens = usage_data.get("completion_tokens", 0)
+             usage.total_tokens = usage_data.get("total_tokens", 0)
+
+             # Check for cached tokens (prompt caching)
+             prompt_tokens_details = usage_data.get("prompt_tokens_details", {})
+             if prompt_tokens_details:
+                 usage.cached_tokens = prompt_tokens_details.get("cached_tokens", 0)
+
+             return usage
+
+         # Handle OpenAI client response object
+         if hasattr(response, "usage") and response.usage is not None:
+             usage.input_tokens = getattr(response.usage, "prompt_tokens", 0) or 0
+             usage.output_tokens = getattr(response.usage, "completion_tokens", 0) or 0
+             usage.total_tokens = getattr(response.usage, "total_tokens", 0) or 0
+
+             # Check for prompt caching details
+             if hasattr(response.usage, "prompt_tokens_details"):
+                 details = response.usage.prompt_tokens_details
+                 if details and hasattr(details, "cached_tokens"):
+                     usage.cached_tokens = details.cached_tokens or 0
+
+         return usage
+
+     def extract_model(self, response: Any) -> str:
+         """Extract the model name from an OpenAI API response."""
+         if isinstance(response, dict):
+             return response.get("model", "unknown")
+
+         if hasattr(response, "model"):
+             return response.model or "unknown"
+
+         return "unknown"
+
+     def extract_cached_tokens(self, response: Any) -> int:
+         """Extract cached token count from an OpenAI API response."""
+         usage = self.extract_usage(response)
+         return usage.cached_tokens
+
+     def normalize_model_name(self, model: str) -> str:
+         """Normalize OpenAI model name."""
+         # Remove date suffixes for pricing lookup
+         # e.g., "gpt-4-0613" -> "gpt-4"
+         parts = model.split("-")
+         if len(parts) >= 3 and parts[-1].isdigit() and len(parts[-1]) >= 4:
+             return "-".join(parts[:-1])
+         return model
+
+
+ class OpenAIStreamingHandler:
+     """Handler for streaming OpenAI responses."""
+
+     def __init__(self):
+         self.input_tokens = 0
+         self.output_tokens = 0
+         self.model = "unknown"
+         self._chunk_count = 0
+
+     def handle_chunk(self, chunk: Any) -> None:
+         """Process a streaming chunk."""
+         self._chunk_count += 1
+
+         # Extract model from first chunk
+         if self._chunk_count == 1:
+             if isinstance(chunk, dict):
+                 self.model = chunk.get("model", self.model)
+             elif hasattr(chunk, "model"):
+                 self.model = chunk.model or self.model
+
+         # Some providers include usage in the final chunk
+         if isinstance(chunk, dict):
+             if "usage" in chunk and chunk["usage"]:
+                 self.input_tokens = chunk["usage"].get("prompt_tokens", 0)
+                 self.output_tokens = chunk["usage"].get("completion_tokens", 0)
+         elif hasattr(chunk, "usage") and chunk.usage:
+             self.input_tokens = getattr(chunk.usage, "prompt_tokens", 0) or 0
+             self.output_tokens = getattr(chunk.usage, "completion_tokens", 0) or 0
+
+     def get_usage(self) -> UsageData:
+         """Get final usage data."""
+         return UsageData(
+             input_tokens=self.input_tokens,
+             output_tokens=self.output_tokens,
+             total_tokens=self.input_tokens + self.output_tokens,
+         )
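For illustration, extract_usage and normalize_model_name compose like this on a dict-shaped Chat Completions response (values invented):

from llm_cost_guard.providers.openai import OpenAIProvider

provider = OpenAIProvider()

response = {
    "model": "gpt-4-0613",
    "usage": {
        "prompt_tokens": 100,
        "completion_tokens": 20,
        "total_tokens": 120,
        "prompt_tokens_details": {"cached_tokens": 80},
    },
}

usage = provider.extract_usage(response)
# usage.input_tokens == 100, usage.output_tokens == 20, usage.cached_tokens == 80

model = provider.normalize_model_name(provider.extract_model(response))
# The trailing all-digit date suffix is stripped: "gpt-4-0613" -> "gpt-4"

Note the suffix heuristic only fires when the last dash-separated part is four or more digits, so names like "gpt-4o" or "gpt-3.5-turbo" pass through unchanged.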
llm_cost_guard/rate_limit.py ADDED
@@ -0,0 +1,233 @@
+ """
+ Rate limiting for LLM Cost Guard.
+ """
+
+ from dataclasses import dataclass
+ from datetime import datetime, timedelta
+ from typing import Dict, List, Literal, Optional, Tuple
+ import threading
+ import time
+
+ RateLimitPeriod = Literal["second", "minute", "hour"]
+ RateLimitScope = Literal["global", "model", "provider"]
+
+
+ @dataclass
+ class RateLimit:
+     """Rate limit configuration."""
+
+     name: str
+     limit: int
+     period: RateLimitPeriod = "minute"
+     scope: str = "global"  # "global", "model", "provider", or "tag:key_name"
+
+
+ class SlidingWindowCounter:
+     """Sliding window rate limiter implementation."""
+
+     def __init__(self, window_size_seconds: float, limit: int):
+         self._window_size = window_size_seconds
+         self._limit = limit
+         self._requests: List[float] = []
+         self._lock = threading.Lock()
+
+     def _cleanup(self, now: float) -> None:
+         """Remove expired entries."""
+         cutoff = now - self._window_size
+         self._requests = [t for t in self._requests if t > cutoff]
+
+     def check(self) -> Tuple[bool, int, Optional[float]]:
+         """
+         Check if a request would be allowed.
+         Returns (allowed, current_count, retry_after_seconds).
+         """
+         now = time.time()
+         with self._lock:
+             self._cleanup(now)
+             current = len(self._requests)
+
+             if current >= self._limit:
+                 # Calculate when the oldest request will expire
+                 if self._requests:
+                     retry_after = self._requests[0] + self._window_size - now
+                     return False, current, max(0.0, retry_after)
+                 return False, current, self._window_size
+
+             return True, current, None
+
+     def record(self) -> bool:
+         """
+         Record a request. Returns True if allowed, False if rate limited.
+         """
+         now = time.time()
+         with self._lock:
+             self._cleanup(now)
+
+             if len(self._requests) >= self._limit:
+                 return False
+
+             self._requests.append(now)
+             return True
+
+     def get_count(self) -> int:
+         """Get current request count in the window."""
+         now = time.time()
+         with self._lock:
+             self._cleanup(now)
+             return len(self._requests)
+
+     def reset(self) -> None:
+         """Reset the counter."""
+         with self._lock:
+             self._requests = []
+
+
+ class RateLimiter:
+     """Manages rate limiting across multiple limits and scopes."""
+
+     def __init__(self, rate_limits: Optional[List[RateLimit]] = None):
+         self._rate_limits = rate_limits or []
+         # Key: (limit_name, scope_value) -> SlidingWindowCounter
+         self._counters: Dict[Tuple[str, str], SlidingWindowCounter] = {}
+         self._lock = threading.Lock()
+
+     def _get_period_seconds(self, period: RateLimitPeriod) -> float:
+         """Get period in seconds."""
+         periods = {
+             "second": 1.0,
+             "minute": 60.0,
+             "hour": 3600.0,
+         }
+         return periods.get(period, 60.0)
+
+     def _get_scope_key(
+         self,
+         rate_limit: RateLimit,
+         model: Optional[str] = None,
+         provider: Optional[str] = None,
+         tags: Optional[Dict[str, str]] = None,
+     ) -> str:
+         """Get the scope key for a rate limit."""
+         scope = rate_limit.scope
+         tags = tags or {}
+
+         if scope == "global":
+             return "global"
+         elif scope == "model":
+             return f"model:{model or 'unknown'}"
+         elif scope == "provider":
+             return f"provider:{provider or 'unknown'}"
+         elif scope.startswith("tag:"):
+             tag_key = scope[4:]
+             tag_value = tags.get(tag_key, "unknown")
+             return f"tag:{tag_key}:{tag_value}"
+         return "global"
+
+     def _get_counter(
+         self,
+         rate_limit: RateLimit,
+         scope_key: str,
+     ) -> SlidingWindowCounter:
+         """Get or create a counter for a rate limit and scope."""
+         key = (rate_limit.name, scope_key)
+
+         with self._lock:
+             if key not in self._counters:
+                 window_size = self._get_period_seconds(rate_limit.period)
+                 self._counters[key] = SlidingWindowCounter(window_size, rate_limit.limit)
+             return self._counters[key]
+
+     def add_rate_limit(self, rate_limit: RateLimit) -> None:
+         """Add a new rate limit."""
+         self._rate_limits.append(rate_limit)
+
+     def remove_rate_limit(self, name: str) -> bool:
+         """Remove a rate limit by name."""
+         for i, rl in enumerate(self._rate_limits):
+             if rl.name == name:
+                 del self._rate_limits[i]
+                 # Remove associated counters
+                 with self._lock:
+                     keys_to_remove = [k for k in self._counters if k[0] == name]
+                     for k in keys_to_remove:
+                         del self._counters[k]
+                 return True
+         return False
+
+     def get_rate_limit(self, name: str) -> Optional[RateLimit]:
+         """Get a rate limit by name."""
+         for rl in self._rate_limits:
+             if rl.name == name:
+                 return rl
+         return None
+
+     def check(
+         self,
+         model: Optional[str] = None,
+         provider: Optional[str] = None,
+         tags: Optional[Dict[str, str]] = None,
+     ) -> List[Tuple[RateLimit, int, Optional[float]]]:
+         """
+         Check all rate limits.
+         Returns list of (rate_limit, current_count, retry_after) for exceeded limits.
+         """
+         exceeded = []
+
+         for rate_limit in self._rate_limits:
+             scope_key = self._get_scope_key(rate_limit, model, provider, tags)
+             counter = self._get_counter(rate_limit, scope_key)
+             allowed, current, retry_after = counter.check()
+
+             if not allowed:
+                 exceeded.append((rate_limit, current, retry_after))
+
+         return exceeded
+
+     def record(
+         self,
+         model: Optional[str] = None,
+         provider: Optional[str] = None,
+         tags: Optional[Dict[str, str]] = None,
+     ) -> List[Tuple[RateLimit, int, Optional[float]]]:
+         """
+         Record a request against all applicable rate limits.
+         Returns list of (rate_limit, current_count, retry_after) for limits that are now exceeded.
+         """
+         exceeded = []
+
+         for rate_limit in self._rate_limits:
+             scope_key = self._get_scope_key(rate_limit, model, provider, tags)
+             counter = self._get_counter(rate_limit, scope_key)
+
+             if not counter.record():
+                 _, current, retry_after = counter.check()
+                 exceeded.append((rate_limit, current, retry_after))
+
+         return exceeded
+
+     def get_remaining(
+         self,
+         name: str,
+         model: Optional[str] = None,
+         provider: Optional[str] = None,
+         tags: Optional[Dict[str, str]] = None,
+     ) -> int:
+         """Get remaining requests for a rate limit."""
+         rate_limit = self.get_rate_limit(name)
+         if rate_limit is None:
+             return 0
+
+         scope_key = self._get_scope_key(rate_limit, model, provider, tags)
+         counter = self._get_counter(rate_limit, scope_key)
+         return max(0, rate_limit.limit - counter.get_count())
+
+     def reset(self, name: Optional[str] = None) -> None:
+         """Reset counters for a specific rate limit or all rate limits."""
+         with self._lock:
+             if name:
+                 keys_to_reset = [k for k in self._counters if k[0] == name]
+                 for k in keys_to_reset:
+                     self._counters[k].reset()
+             else:
+                 for counter in self._counters.values():
+                     counter.reset()
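A short usage sketch of the limiter (names and limit values invented): two requests per minute per model, with the third call in the same window reported as exceeded:

from llm_cost_guard.rate_limit import RateLimit, RateLimiter

limiter = RateLimiter([
    RateLimit(name="per-model", limit=2, period="minute", scope="model"),
])

limiter.record(model="gpt-4")                 # [] -- allowed
limiter.record(model="gpt-4")                 # [] -- allowed
exceeded = limiter.record(model="gpt-4")      # third call in the window
for rate_limit, current, retry_after in exceeded:
    print(rate_limit.name, current, retry_after)  # per-model 2 <seconds until a slot frees>

limiter.get_remaining("per-model", model="gpt-4")   # 0
limiter.get_remaining("per-model", model="claude")  # 2 -- separate scope key

Because SlidingWindowCounter.record() takes the slot under the counter's lock, checking and consuming cannot race between threads, which is why record() is preferable to a separate check()/record() pair.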
llm_cost_guard/span.py ADDED
@@ -0,0 +1,143 @@
+ """
+ Hierarchical tracking spans for LLM Cost Guard.
+ """
+
+ from dataclasses import dataclass, field
+ from datetime import datetime
+ from typing import Any, Dict, List, Optional, Set, TYPE_CHECKING
+ import contextvars
+ import threading
+ import uuid
+
+ if TYPE_CHECKING:
+     from llm_cost_guard.models import CostRecord
+
+
+ # Context variable to track the current span
+ _current_span: contextvars.ContextVar[Optional["Span"]] = contextvars.ContextVar(
+     "current_span", default=None
+ )
+
+
+ @dataclass
+ class Span:
+     """Hierarchical tracking span for grouping multiple LLM calls."""
+
+     name: str
+     span_id: str = field(default_factory=lambda: str(uuid.uuid4()))
+     parent_id: Optional[str] = None
+     start_time: Optional[datetime] = None
+     end_time: Optional[datetime] = None
+     tags: Dict[str, str] = field(default_factory=dict)
+     metadata: Dict[str, Any] = field(default_factory=dict)
+
+     # Aggregated metrics
+     total_cost: float = 0.0
+     total_input_tokens: int = 0
+     total_output_tokens: int = 0
+     call_count: int = 0
+     models_used: Set[str] = field(default_factory=set)
+
+     # Hierarchy
+     children: List["Span"] = field(default_factory=list)
+     _records: List["CostRecord"] = field(default_factory=list)
+     _lock: threading.Lock = field(default_factory=threading.Lock)
+
+     # For context management
+     _token: Optional[contextvars.Token] = field(default=None, repr=False)
+     _previous_span: Optional["Span"] = field(default=None, repr=False)
+
+     def __enter__(self) -> "Span":
+         """Enter the span context."""
+         self.start_time = datetime.now()
+
+         # Save previous span and set this as current
+         self._previous_span = _current_span.get()
+         self._token = _current_span.set(self)
+
+         # If there's a parent span, register as child
+         if self._previous_span is not None:
+             self.parent_id = self._previous_span.span_id
+             self._previous_span._add_child(self)
+
+         return self
+
+     def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
+         """Exit the span context."""
+         self.end_time = datetime.now()
+
+         # Restore previous span
+         if self._token is not None:
+             _current_span.reset(self._token)
+
+         # Propagate costs to parent
+         if self._previous_span is not None:
+             self._previous_span._propagate_child_costs(self)
+
+     def _add_child(self, child: "Span") -> None:
+         """Add a child span."""
+         with self._lock:
+             self.children.append(child)
+
+     def _propagate_child_costs(self, child: "Span") -> None:
+         """Propagate child span costs to this span."""
+         with self._lock:
+             self.total_cost += child.total_cost
+             self.total_input_tokens += child.total_input_tokens
+             self.total_output_tokens += child.total_output_tokens
+             self.call_count += child.call_count
+             self.models_used.update(child.models_used)
+
+     def record_call(
+         self,
+         cost: float,
+         input_tokens: int,
+         output_tokens: int,
+         model: str,
+         record: Optional["CostRecord"] = None,
+     ) -> None:
+         """Record an LLM call in this span."""
+         with self._lock:
+             self.total_cost += cost
+             self.total_input_tokens += input_tokens
+             self.total_output_tokens += output_tokens
+             self.call_count += 1
+             self.models_used.add(model)
+             if record is not None:
+                 self._records.append(record)
+
+     @property
+     def duration_ms(self) -> Optional[int]:
+         """Get the span duration in milliseconds."""
+         if self.start_time is None or self.end_time is None:
+             return None
+         delta = self.end_time - self.start_time
+         return int(delta.total_seconds() * 1000)
+
+     @property
+     def records(self) -> List["CostRecord"]:
+         """Get all records in this span."""
+         return list(self._records)
+
+     def to_dict(self) -> Dict[str, Any]:
+         """Convert span to dictionary."""
+         return {
+             "name": self.name,
+             "span_id": self.span_id,
+             "parent_id": self.parent_id,
+             "start_time": self.start_time.isoformat() if self.start_time else None,
+             "end_time": self.end_time.isoformat() if self.end_time else None,
+             "duration_ms": self.duration_ms,
+             "tags": self.tags,
+             "total_cost": self.total_cost,
+             "total_input_tokens": self.total_input_tokens,
+             "total_output_tokens": self.total_output_tokens,
+             "call_count": self.call_count,
+             "models_used": list(self.models_used),
+             "children": [child.to_dict() for child in self.children],
+         }
+
+
+ def get_current_span() -> Optional[Span]:
+     """Get the current active span, if any."""
+     return _current_span.get()
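Finally, a brief sketch of nested spans (names and numbers invented): costs recorded on a child roll up to its parent when the child's context exits:

from llm_cost_guard.span import Span, get_current_span

with Span(name="pipeline") as root:
    with Span(name="summarize") as child:
        child.record_call(cost=0.002, input_tokens=900, output_tokens=150, model="gpt-4")
    # Child totals were propagated on __exit__:
    print(root.total_cost, root.call_count)   # 0.002 1
    print(root.children[0].name)              # summarize

print(get_current_span())  # None -- no span is active outside the with blocks

Propagation happens only at __exit__, so a parent's totals do not include still-open children.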