llm-cost-guard 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llm_cost_guard/__init__.py +39 -0
- llm_cost_guard/backends/__init__.py +52 -0
- llm_cost_guard/backends/base.py +121 -0
- llm_cost_guard/backends/memory.py +265 -0
- llm_cost_guard/backends/sqlite.py +425 -0
- llm_cost_guard/budget.py +306 -0
- llm_cost_guard/cli.py +464 -0
- llm_cost_guard/clients/__init__.py +11 -0
- llm_cost_guard/clients/anthropic.py +231 -0
- llm_cost_guard/clients/openai.py +262 -0
- llm_cost_guard/exceptions.py +71 -0
- llm_cost_guard/integrations/__init__.py +12 -0
- llm_cost_guard/integrations/cache.py +189 -0
- llm_cost_guard/integrations/langchain.py +257 -0
- llm_cost_guard/models.py +123 -0
- llm_cost_guard/pricing/__init__.py +7 -0
- llm_cost_guard/pricing/anthropic.yaml +88 -0
- llm_cost_guard/pricing/bedrock.yaml +215 -0
- llm_cost_guard/pricing/loader.py +221 -0
- llm_cost_guard/pricing/openai.yaml +148 -0
- llm_cost_guard/pricing/vertex.yaml +133 -0
- llm_cost_guard/providers/__init__.py +69 -0
- llm_cost_guard/providers/anthropic.py +115 -0
- llm_cost_guard/providers/base.py +72 -0
- llm_cost_guard/providers/bedrock.py +135 -0
- llm_cost_guard/providers/openai.py +110 -0
- llm_cost_guard/rate_limit.py +233 -0
- llm_cost_guard/span.py +143 -0
- llm_cost_guard/tokenizers/__init__.py +7 -0
- llm_cost_guard/tokenizers/base.py +207 -0
- llm_cost_guard/tracker.py +718 -0
- llm_cost_guard-0.1.0.dist-info/METADATA +357 -0
- llm_cost_guard-0.1.0.dist-info/RECORD +36 -0
- llm_cost_guard-0.1.0.dist-info/WHEEL +4 -0
- llm_cost_guard-0.1.0.dist-info/entry_points.txt +2 -0
- llm_cost_guard-0.1.0.dist-info/licenses/LICENSE +21 -0
llm_cost_guard/providers/base.py
ADDED
@@ -0,0 +1,72 @@
"""
Base provider interface for LLM Cost Guard.
"""

from abc import ABC, abstractmethod
from typing import Any, Dict, Optional

from llm_cost_guard.models import UsageData


class Provider(ABC):
    """Abstract base class for LLM providers."""

    @property
    @abstractmethod
    def name(self) -> str:
        """Get the provider name."""
        pass

    @abstractmethod
    def extract_usage(self, response: Any) -> UsageData:
        """
        Extract token usage from an API response.

        Args:
            response: The raw API response object

        Returns:
            UsageData with token counts
        """
        pass

    @abstractmethod
    def extract_model(self, response: Any) -> str:
        """
        Extract the model name from an API response.

        Args:
            response: The raw API response object

        Returns:
            Model name string
        """
        pass

    def extract_cached_tokens(self, response: Any) -> int:
        """
        Extract cached token count from an API response.

        Args:
            response: The raw API response object

        Returns:
            Number of cached tokens (0 if not applicable)
        """
        return 0

    def supports_streaming(self) -> bool:
        """Check if this provider supports streaming."""
        return True

    def normalize_model_name(self, model: str) -> str:
        """
        Normalize model name for consistent pricing lookup.

        Args:
            model: Raw model name from API

        Returns:
            Normalized model name
        """
        return model
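
For orientation, a minimal sketch of what a concrete implementation of this interface could look like. It assumes only the abstract methods above and the UsageData keyword constructor used by the streaming handlers later in this diff; the DictProvider class and its toy response shape are hypothetical and not part of the wheel.

from typing import Any

from llm_cost_guard.models import UsageData
from llm_cost_guard.providers.base import Provider


class DictProvider(Provider):
    """Toy provider that reads usage out of plain dict responses."""

    @property
    def name(self) -> str:
        return "dict"

    def extract_usage(self, response: Any) -> UsageData:
        # Pull the nested usage block; fall back to zeros if it is missing.
        usage = response.get("usage", {}) if isinstance(response, dict) else {}
        return UsageData(
            input_tokens=usage.get("input_tokens", 0),
            output_tokens=usage.get("output_tokens", 0),
            total_tokens=usage.get("total_tokens", 0),
        )

    def extract_model(self, response: Any) -> str:
        return response.get("model", "unknown") if isinstance(response, dict) else "unknown"


provider = DictProvider()
usage = provider.extract_usage(
    {"model": "toy-1", "usage": {"input_tokens": 10, "output_tokens": 3, "total_tokens": 13}}
)
print(provider.name, usage.total_tokens)  # dict 13

Only name, extract_usage, and extract_model are abstract; extract_cached_tokens, supports_streaming, and normalize_model_name fall back to the defaults above.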
llm_cost_guard/providers/bedrock.py
ADDED
@@ -0,0 +1,135 @@
"""
AWS Bedrock provider for LLM Cost Guard.
"""

import json
from typing import Any

from llm_cost_guard.models import UsageData
from llm_cost_guard.providers.base import Provider


class BedrockProvider(Provider):
    """AWS Bedrock API provider."""

    @property
    def name(self) -> str:
        return "bedrock"

    def extract_usage(self, response: Any) -> UsageData:
        """Extract token usage from a Bedrock API response."""
        usage = UsageData()

        # Handle dictionary response (from boto3)
        if isinstance(response, dict):
            # Check for standard usage field
            if "usage" in response:
                usage_data = response["usage"]
                usage.input_tokens = usage_data.get("inputTokens", 0)
                usage.output_tokens = usage_data.get("outputTokens", 0)
                usage.total_tokens = usage_data.get("totalTokens", 0)
                return usage

            # Check for Bedrock response metadata
            metadata = response.get("ResponseMetadata", {})
            headers = metadata.get("HTTPHeaders", {})

            # Bedrock includes token counts in headers for some models
            if "x-amzn-bedrock-input-token-count" in headers:
                usage.input_tokens = int(headers["x-amzn-bedrock-input-token-count"])
            if "x-amzn-bedrock-output-token-count" in headers:
                usage.output_tokens = int(headers["x-amzn-bedrock-output-token-count"])

            # Also check the body for Claude-style responses
            body = response.get("body", {})
            if isinstance(body, bytes):
                try:
                    body = json.loads(body.decode("utf-8"))
                except (json.JSONDecodeError, UnicodeDecodeError):
                    body = {}

            if isinstance(body, dict) and "usage" in body:
                body_usage = body["usage"]
                usage.input_tokens = body_usage.get(
                    "input_tokens", body_usage.get("inputTokens", usage.input_tokens)
                )
                usage.output_tokens = body_usage.get(
                    "output_tokens", body_usage.get("outputTokens", usage.output_tokens)
                )

            usage.total_tokens = usage.input_tokens + usage.output_tokens

        return usage

    def extract_model(self, response: Any) -> str:
        """Extract the model name from a Bedrock API response."""
        if isinstance(response, dict):
            # Check for model ID in various locations
            if "modelId" in response:
                return response["modelId"]

            # Check response metadata
            metadata = response.get("ResponseMetadata", {})
            if "modelId" in metadata:
                return metadata["modelId"]

        return "unknown"

    def normalize_model_name(self, model: str) -> str:
        """Normalize Bedrock model name for pricing lookup."""
        # Bedrock model IDs include version info that should be preserved
        # e.g., "anthropic.claude-3-sonnet-20240229-v1:0"
        return model


class BedrockStreamingHandler:
    """Handler for streaming Bedrock responses."""

    def __init__(self, model_id: str = "unknown"):
        self.input_tokens = 0
        self.output_tokens = 0
        self.model = model_id
        self._chunks: list = []

    def handle_chunk(self, chunk: Any) -> None:
        """Process a streaming chunk."""
        if isinstance(chunk, dict):
            # Handle Bedrock's streaming chunk format
            if "chunk" in chunk:
                chunk_data = chunk["chunk"]
                if "bytes" in chunk_data:
                    try:
                        parsed = json.loads(chunk_data["bytes"].decode("utf-8"))
                        self._process_parsed_chunk(parsed)
                    except (json.JSONDecodeError, UnicodeDecodeError, AttributeError):
                        pass

            # Some responses have direct usage info
            if "usage" in chunk:
                usage = chunk["usage"]
                self.input_tokens = usage.get("inputTokens", self.input_tokens)
                self.output_tokens = usage.get("outputTokens", self.output_tokens)

        self._chunks.append(chunk)

    def _process_parsed_chunk(self, parsed: dict) -> None:
        """Process a parsed JSON chunk."""
        # Claude on Bedrock format
        if "usage" in parsed:
            usage = parsed["usage"]
            self.input_tokens = usage.get("input_tokens", self.input_tokens)
            self.output_tokens = usage.get("output_tokens", self.output_tokens)

        # Llama/Titan format
        if "amazon-bedrock-invocationMetrics" in parsed:
            metrics = parsed["amazon-bedrock-invocationMetrics"]
            self.input_tokens = metrics.get("inputTokenCount", self.input_tokens)
            self.output_tokens = metrics.get("outputTokenCount", self.output_tokens)

    def get_usage(self) -> UsageData:
        """Get final usage data."""
        return UsageData(
            input_tokens=self.input_tokens,
            output_tokens=self.output_tokens,
            total_tokens=self.input_tokens + self.output_tokens,
        )
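
A brief sketch of the dictionary path through BedrockProvider, assuming a Converse-style boto3 response dict with the top-level "usage" field handled above; the model ID and token counts are illustrative values only.

from llm_cost_guard.providers.bedrock import BedrockProvider

provider = BedrockProvider()

# Parsed boto3 response with the camelCase usage keys read by extract_usage.
response = {
    "modelId": "anthropic.claude-3-sonnet-20240229-v1:0",
    "usage": {"inputTokens": 512, "outputTokens": 128, "totalTokens": 640},
}

usage = provider.extract_usage(response)
print(provider.extract_model(response))                             # anthropic.claude-3-sonnet-20240229-v1:0
print(usage.input_tokens, usage.output_tokens, usage.total_tokens)  # 512 128 640

For streaming invocations, BedrockStreamingHandler.handle_chunk accepts the raw event dicts and get_usage() returns the accumulated UsageData once the stream is exhausted.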
llm_cost_guard/providers/openai.py
ADDED
@@ -0,0 +1,110 @@
"""
OpenAI provider for LLM Cost Guard.
"""

from typing import Any, Optional

from llm_cost_guard.models import UsageData
from llm_cost_guard.providers.base import Provider


class OpenAIProvider(Provider):
    """OpenAI API provider."""

    @property
    def name(self) -> str:
        return "openai"

    def extract_usage(self, response: Any) -> UsageData:
        """Extract token usage from an OpenAI API response."""
        usage = UsageData()

        # Handle dictionary response
        if isinstance(response, dict):
            usage_data = response.get("usage", {})
            usage.input_tokens = usage_data.get("prompt_tokens", 0)
            usage.output_tokens = usage_data.get("completion_tokens", 0)
            usage.total_tokens = usage_data.get("total_tokens", 0)

            # Check for cached tokens (prompt caching)
            prompt_tokens_details = usage_data.get("prompt_tokens_details", {})
            if prompt_tokens_details:
                usage.cached_tokens = prompt_tokens_details.get("cached_tokens", 0)

            return usage

        # Handle OpenAI client response object
        if hasattr(response, "usage") and response.usage is not None:
            usage.input_tokens = getattr(response.usage, "prompt_tokens", 0) or 0
            usage.output_tokens = getattr(response.usage, "completion_tokens", 0) or 0
            usage.total_tokens = getattr(response.usage, "total_tokens", 0) or 0

            # Check for prompt caching details
            if hasattr(response.usage, "prompt_tokens_details"):
                details = response.usage.prompt_tokens_details
                if details and hasattr(details, "cached_tokens"):
                    usage.cached_tokens = details.cached_tokens or 0

        return usage

    def extract_model(self, response: Any) -> str:
        """Extract the model name from an OpenAI API response."""
        if isinstance(response, dict):
            return response.get("model", "unknown")

        if hasattr(response, "model"):
            return response.model or "unknown"

        return "unknown"

    def extract_cached_tokens(self, response: Any) -> int:
        """Extract cached token count from an OpenAI API response."""
        usage = self.extract_usage(response)
        return usage.cached_tokens

    def normalize_model_name(self, model: str) -> str:
        """Normalize OpenAI model name."""
        # Remove date suffixes for pricing lookup
        # e.g., "gpt-4-0613" -> "gpt-4"
        parts = model.split("-")
        if len(parts) >= 3 and parts[-1].isdigit() and len(parts[-1]) >= 4:
            return "-".join(parts[:-1])
        return model


class OpenAIStreamingHandler:
    """Handler for streaming OpenAI responses."""

    def __init__(self):
        self.input_tokens = 0
        self.output_tokens = 0
        self.model = "unknown"
        self._chunk_count = 0

    def handle_chunk(self, chunk: Any) -> None:
        """Process a streaming chunk."""
        self._chunk_count += 1

        # Extract model from first chunk
        if self._chunk_count == 1:
            if isinstance(chunk, dict):
                self.model = chunk.get("model", self.model)
            elif hasattr(chunk, "model"):
                self.model = chunk.model or self.model

        # Some providers include usage in the final chunk
        if isinstance(chunk, dict):
            if "usage" in chunk and chunk["usage"]:
                self.input_tokens = chunk["usage"].get("prompt_tokens", 0)
                self.output_tokens = chunk["usage"].get("completion_tokens", 0)
        elif hasattr(chunk, "usage") and chunk.usage:
            self.input_tokens = getattr(chunk.usage, "prompt_tokens", 0) or 0
            self.output_tokens = getattr(chunk.usage, "completion_tokens", 0) or 0

    def get_usage(self) -> UsageData:
        """Get final usage data."""
        return UsageData(
            input_tokens=self.input_tokens,
            output_tokens=self.output_tokens,
            total_tokens=self.input_tokens + self.output_tokens,
        )
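
A sketch of the dictionary branch of OpenAIProvider.extract_usage, including the prompt-caching detail, plus the date-suffix stripping done by normalize_model_name; all numbers are illustrative.

from llm_cost_guard.providers.openai import OpenAIProvider

provider = OpenAIProvider()

response = {
    "model": "gpt-4-0613",
    "usage": {
        "prompt_tokens": 1000,
        "completion_tokens": 200,
        "total_tokens": 1200,
        "prompt_tokens_details": {"cached_tokens": 600},
    },
}

usage = provider.extract_usage(response)
print(usage.input_tokens, usage.output_tokens, usage.cached_tokens)     # 1000 200 600
print(provider.normalize_model_name(provider.extract_model(response)))  # gpt-4

The same extract_usage call also accepts the OpenAI client's response objects, reading the identically named attributes via getattr.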
llm_cost_guard/rate_limit.py
ADDED
@@ -0,0 +1,233 @@
"""
Rate limiting for LLM Cost Guard.
"""

from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import Dict, List, Literal, Optional, Tuple
import threading
import time

RateLimitPeriod = Literal["second", "minute", "hour"]
RateLimitScope = Literal["global", "model", "provider"]


@dataclass
class RateLimit:
    """Rate limit configuration."""

    name: str
    limit: int
    period: RateLimitPeriod = "minute"
    scope: str = "global"  # "global", "model", "provider", or "tag:key_name"


class SlidingWindowCounter:
    """Sliding window rate limiter implementation."""

    def __init__(self, window_size_seconds: float, limit: int):
        self._window_size = window_size_seconds
        self._limit = limit
        self._requests: List[float] = []
        self._lock = threading.Lock()

    def _cleanup(self, now: float) -> None:
        """Remove expired entries."""
        cutoff = now - self._window_size
        self._requests = [t for t in self._requests if t > cutoff]

    def check(self) -> Tuple[bool, int, Optional[float]]:
        """
        Check if a request would be allowed.
        Returns (allowed, current_count, retry_after_seconds).
        """
        now = time.time()
        with self._lock:
            self._cleanup(now)
            current = len(self._requests)

            if current >= self._limit:
                # Calculate when the oldest request will expire
                if self._requests:
                    retry_after = self._requests[0] + self._window_size - now
                    return False, current, max(0.0, retry_after)
                return False, current, self._window_size

            return True, current, None

    def record(self) -> bool:
        """
        Record a request. Returns True if allowed, False if rate limited.
        """
        now = time.time()
        with self._lock:
            self._cleanup(now)

            if len(self._requests) >= self._limit:
                return False

            self._requests.append(now)
            return True

    def get_count(self) -> int:
        """Get current request count in the window."""
        now = time.time()
        with self._lock:
            self._cleanup(now)
            return len(self._requests)

    def reset(self) -> None:
        """Reset the counter."""
        with self._lock:
            self._requests = []


class RateLimiter:
    """Manages rate limiting across multiple limits and scopes."""

    def __init__(self, rate_limits: Optional[List[RateLimit]] = None):
        self._rate_limits = rate_limits or []
        # Key: (limit_name, scope_value) -> SlidingWindowCounter
        self._counters: Dict[Tuple[str, str], SlidingWindowCounter] = {}
        self._lock = threading.Lock()

    def _get_period_seconds(self, period: RateLimitPeriod) -> float:
        """Get period in seconds."""
        periods = {
            "second": 1.0,
            "minute": 60.0,
            "hour": 3600.0,
        }
        return periods.get(period, 60.0)

    def _get_scope_key(
        self,
        rate_limit: RateLimit,
        model: Optional[str] = None,
        provider: Optional[str] = None,
        tags: Optional[Dict[str, str]] = None,
    ) -> str:
        """Get the scope key for a rate limit."""
        scope = rate_limit.scope
        tags = tags or {}

        if scope == "global":
            return "global"
        elif scope == "model":
            return f"model:{model or 'unknown'}"
        elif scope == "provider":
            return f"provider:{provider or 'unknown'}"
        elif scope.startswith("tag:"):
            tag_key = scope[4:]
            tag_value = tags.get(tag_key, "unknown")
            return f"tag:{tag_key}:{tag_value}"
        return "global"

    def _get_counter(
        self,
        rate_limit: RateLimit,
        scope_key: str,
    ) -> SlidingWindowCounter:
        """Get or create a counter for a rate limit and scope."""
        key = (rate_limit.name, scope_key)

        with self._lock:
            if key not in self._counters:
                window_size = self._get_period_seconds(rate_limit.period)
                self._counters[key] = SlidingWindowCounter(window_size, rate_limit.limit)
            return self._counters[key]

    def add_rate_limit(self, rate_limit: RateLimit) -> None:
        """Add a new rate limit."""
        self._rate_limits.append(rate_limit)

    def remove_rate_limit(self, name: str) -> bool:
        """Remove a rate limit by name."""
        for i, rl in enumerate(self._rate_limits):
            if rl.name == name:
                del self._rate_limits[i]
                # Remove associated counters
                with self._lock:
                    keys_to_remove = [k for k in self._counters if k[0] == name]
                    for k in keys_to_remove:
                        del self._counters[k]
                return True
        return False

    def get_rate_limit(self, name: str) -> Optional[RateLimit]:
        """Get a rate limit by name."""
        for rl in self._rate_limits:
            if rl.name == name:
                return rl
        return None

    def check(
        self,
        model: Optional[str] = None,
        provider: Optional[str] = None,
        tags: Optional[Dict[str, str]] = None,
    ) -> List[Tuple[RateLimit, int, Optional[float]]]:
        """
        Check all rate limits.
        Returns list of (rate_limit, current_count, retry_after) for exceeded limits.
        """
        exceeded = []

        for rate_limit in self._rate_limits:
            scope_key = self._get_scope_key(rate_limit, model, provider, tags)
            counter = self._get_counter(rate_limit, scope_key)
            allowed, current, retry_after = counter.check()

            if not allowed:
                exceeded.append((rate_limit, current, retry_after))

        return exceeded

    def record(
        self,
        model: Optional[str] = None,
        provider: Optional[str] = None,
        tags: Optional[Dict[str, str]] = None,
    ) -> List[Tuple[RateLimit, int, Optional[float]]]:
        """
        Record a request against all applicable rate limits.
        Returns list of (rate_limit, current_count, retry_after) for limits that are now exceeded.
        """
        exceeded = []

        for rate_limit in self._rate_limits:
            scope_key = self._get_scope_key(rate_limit, model, provider, tags)
            counter = self._get_counter(rate_limit, scope_key)

            if not counter.record():
                _, current, retry_after = counter.check()
                exceeded.append((rate_limit, current, retry_after))

        return exceeded

    def get_remaining(
        self,
        name: str,
        model: Optional[str] = None,
        provider: Optional[str] = None,
        tags: Optional[Dict[str, str]] = None,
    ) -> int:
        """Get remaining requests for a rate limit."""
        rate_limit = self.get_rate_limit(name)
        if rate_limit is None:
            return 0

        scope_key = self._get_scope_key(rate_limit, model, provider, tags)
        counter = self._get_counter(rate_limit, scope_key)
        return max(0, rate_limit.limit - counter.get_count())

    def reset(self, name: Optional[str] = None) -> None:
        """Reset counters for a specific rate limit or all rate limits."""
        with self._lock:
            if name:
                keys_to_reset = [k for k in self._counters if k[0] == name]
                for k in keys_to_reset:
                    self._counters[k].reset()
            else:
                for counter in self._counters.values():
                    counter.reset()
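
A hedged usage sketch of RateLimiter combining a per-model and a global rule, relying only on the record/check/get_remaining semantics defined above; the rule names and limits are made up for the example.

from llm_cost_guard.rate_limit import RateLimit, RateLimiter

limiter = RateLimiter([
    RateLimit(name="model-rpm", limit=60, period="minute", scope="model"),
    RateLimit(name="global-rps", limit=5, period="second", scope="global"),
])

# Record one request; the returned list names any rules that are now exceeded.
exceeded = limiter.record(model="gpt-4o", provider="openai")
for rule, current, retry_after in exceeded:
    print(f"{rule.name}: {current} in window, retry after {retry_after}s")

print(limiter.get_remaining("model-rpm", model="gpt-4o"))  # 59 after the first call

Because counters are keyed by (rule name, scope key), "model-rpm" tracks each model independently while "global-rps" shares a single window across all calls.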
llm_cost_guard/span.py
ADDED
@@ -0,0 +1,143 @@
"""
Hierarchical tracking spans for LLM Cost Guard.
"""

from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, Dict, List, Optional, Set, TYPE_CHECKING
import contextvars
import threading
import uuid

if TYPE_CHECKING:
    from llm_cost_guard.models import CostRecord


# Context variable to track the current span
_current_span: contextvars.ContextVar[Optional["Span"]] = contextvars.ContextVar(
    "current_span", default=None
)


@dataclass
class Span:
    """Hierarchical tracking span for grouping multiple LLM calls."""

    name: str
    span_id: str = field(default_factory=lambda: str(uuid.uuid4()))
    parent_id: Optional[str] = None
    start_time: Optional[datetime] = None
    end_time: Optional[datetime] = None
    tags: Dict[str, str] = field(default_factory=dict)
    metadata: Dict[str, Any] = field(default_factory=dict)

    # Aggregated metrics
    total_cost: float = 0.0
    total_input_tokens: int = 0
    total_output_tokens: int = 0
    call_count: int = 0
    models_used: Set[str] = field(default_factory=set)

    # Hierarchy
    children: List["Span"] = field(default_factory=list)
    _records: List["CostRecord"] = field(default_factory=list)
    _lock: threading.Lock = field(default_factory=threading.Lock)

    # For context management
    _token: Optional[contextvars.Token] = field(default=None, repr=False)
    _previous_span: Optional["Span"] = field(default=None, repr=False)

    def __enter__(self) -> "Span":
        """Enter the span context."""
        self.start_time = datetime.now()

        # Save previous span and set this as current
        self._previous_span = _current_span.get()
        self._token = _current_span.set(self)

        # If there's a parent span, register as child
        if self._previous_span is not None:
            self.parent_id = self._previous_span.span_id
            self._previous_span._add_child(self)

        return self

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        """Exit the span context."""
        self.end_time = datetime.now()

        # Restore previous span
        if self._token is not None:
            _current_span.reset(self._token)

        # Propagate costs to parent
        if self._previous_span is not None:
            self._previous_span._propagate_child_costs(self)

    def _add_child(self, child: "Span") -> None:
        """Add a child span."""
        with self._lock:
            self.children.append(child)

    def _propagate_child_costs(self, child: "Span") -> None:
        """Propagate child span costs to this span."""
        with self._lock:
            self.total_cost += child.total_cost
            self.total_input_tokens += child.total_input_tokens
            self.total_output_tokens += child.total_output_tokens
            self.call_count += child.call_count
            self.models_used.update(child.models_used)

    def record_call(
        self,
        cost: float,
        input_tokens: int,
        output_tokens: int,
        model: str,
        record: Optional["CostRecord"] = None,
    ) -> None:
        """Record an LLM call in this span."""
        with self._lock:
            self.total_cost += cost
            self.total_input_tokens += input_tokens
            self.total_output_tokens += output_tokens
            self.call_count += 1
            self.models_used.add(model)
            if record is not None:
                self._records.append(record)

    @property
    def duration_ms(self) -> Optional[int]:
        """Get the span duration in milliseconds."""
        if self.start_time is None or self.end_time is None:
            return None
        delta = self.end_time - self.start_time
        return int(delta.total_seconds() * 1000)

    @property
    def records(self) -> List["CostRecord"]:
        """Get all records in this span."""
        return list(self._records)

    def to_dict(self) -> Dict[str, Any]:
        """Convert span to dictionary."""
        return {
            "name": self.name,
            "span_id": self.span_id,
            "parent_id": self.parent_id,
            "start_time": self.start_time.isoformat() if self.start_time else None,
            "end_time": self.end_time.isoformat() if self.end_time else None,
            "duration_ms": self.duration_ms,
            "tags": self.tags,
            "total_cost": self.total_cost,
            "total_input_tokens": self.total_input_tokens,
            "total_output_tokens": self.total_output_tokens,
            "call_count": self.call_count,
            "models_used": list(self.models_used),
            "children": [child.to_dict() for child in self.children],
        }


def get_current_span() -> Optional[Span]:
    """Get the current active span, if any."""
    return _current_span.get()
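
A short sketch of nesting spans as context managers; record_call is invoked by hand here purely for illustration (in the package it would presumably be driven by the tracker), and the cost and token figures are arbitrary.

from llm_cost_guard.span import Span, get_current_span

with Span("handle-request", tags={"request_id": "demo"}) as root:
    with Span("summarize") as child:
        child.record_call(cost=0.0042, input_tokens=900, output_tokens=150, model="gpt-4o-mini")
    # The child has exited, so the context variable points at the root again.
    print(get_current_span() is root)  # True

# Child totals were propagated to the parent on exit.
print(root.call_count, root.total_cost)        # 1 0.0042
print(root.to_dict()["children"][0]["name"])   # summarize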