zae-limiter 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,130 @@
+ """
+ zae-limiter: Rate limiting library backed by DynamoDB.
+
+ This library provides a token bucket rate limiter with:
+ - Multiple limits per entity/resource
+ - Two-level hierarchy (parent/child entities)
+ - Cascade mode (consume from entity + parent)
+ - Stored limit configs
+ - Usage analytics via Lambda aggregator
+
+ Example:
+     from zae_limiter import RateLimiter, Limit, FailureMode
+
+     limiter = RateLimiter(
+         table_name="rate_limits",
+         region="us-east-1",
+         create_table=True,
+     )
+
+     async with limiter.acquire(
+         entity_id="key-abc",
+         resource="gpt-4",
+         limits=[
+             Limit.per_minute("rpm", 100),
+             Limit.per_minute("tpm", 10_000),
+         ],
+         consume={"rpm": 1, "tpm": 500},
+     ) as lease:
+         response = await llm_call()
+         await lease.adjust(tpm=response.usage.total_tokens - 500)
+ """
+
+ from .exceptions import (
+     EntityExistsError,
+     EntityNotFoundError,
+     IncompatibleSchemaError,
+     InfrastructureNotFoundError,
+     RateLimitError,
+     RateLimiterUnavailable,
+     RateLimitExceeded,
+     StackAlreadyExistsError,
+     StackCreationError,
+     VersionError,
+     VersionMismatchError,
+ )
+ from .lease import Lease, SyncLease
+ from .models import (
+     BucketState,
+     Entity,
+     EntityCapacity,
+     Limit,
+     LimitName,
+     LimitStatus,
+     ResourceCapacity,
+     UsageSnapshot,
+ )
+
+ # RateLimiter, SyncRateLimiter, FailureMode, and StackManager are imported
+ # lazily via __getattr__ to avoid loading aioboto3 for Lambda functions
+ # that only need boto3
+
+ try:
+     from ._version import __version__  # type: ignore[import-untyped]
+ except ImportError:
+     __version__ = "0.0.0+unknown"
+
+ __all__ = [
+     # Version
+     "__version__",
+     # Main classes
+     "RateLimiter",
+     "SyncRateLimiter",
+     "Lease",
+     "SyncLease",
+     "StackManager",
+     # Models
+     "Limit",
+     "LimitName",
+     "Entity",
+     "LimitStatus",
+     "BucketState",
+     "UsageSnapshot",
+     "ResourceCapacity",
+     "EntityCapacity",
+     # Enums
+     "FailureMode",
+     # Exceptions
+     "RateLimitError",
+     "RateLimitExceeded",
+     "RateLimiterUnavailable",
+     "EntityNotFoundError",
+     "EntityExistsError",
+     "StackCreationError",
+     "StackAlreadyExistsError",
+     "VersionError",
+     "VersionMismatchError",
+     "IncompatibleSchemaError",
+     "InfrastructureNotFoundError",
+ ]
+
+
+ def __getattr__(name: str) -> type:
+     """
+     Lazy import for modules with heavy dependencies.
+
+     This allows the package to be imported without loading aioboto3,
+     which is critical for Lambda functions that only need boto3.
+
+     The aggregator Lambda function imports the handler, which would otherwise
+     trigger loading of the entire package. By making RateLimiter and
+     StackManager lazy imports, we avoid loading aioboto3 (not available in
+     the Lambda runtime) while maintaining backward compatibility for regular usage.
+     """
+     if name == "RateLimiter":
+         from .limiter import RateLimiter
+
+         return RateLimiter
+     if name == "SyncRateLimiter":
+         from .limiter import SyncRateLimiter
+
+         return SyncRateLimiter
+     if name == "FailureMode":
+         from .limiter import FailureMode
+
+         return FailureMode
+     if name == "StackManager":
+         from .infra.stack_manager import StackManager
+
+         return StackManager
+     raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
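
A note on the hunk above (the package __init__): the module-level __getattr__ is a PEP 562 lazy-import hook. The sketch below shows the intended effect; it assumes the package is installed and that zae_limiter.limiter is the module that imports aioboto3 at import time, which the comments imply but the diff itself does not show.

    import sys

    import zae_limiter

    # The eager imports (.exceptions, .lease, .models) do not pull in the async stack.
    print("aioboto3" in sys.modules)  # expected: False

    # Attribute access falls through to __getattr__, which imports .limiter on demand.
    RateLimiter = zae_limiter.RateLimiter
    print("aioboto3" in sys.modules)  # expected: True, once .limiter has been loaded
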
@@ -0,0 +1,11 @@
+ """Lambda aggregator for usage snapshots."""
+
+ from .handler import handler
+ from .processor import ConsumptionDelta, ProcessResult, process_stream_records
+
+ __all__ = [
+     "handler",
+     "process_stream_records",
+     "ProcessResult",
+     "ConsumptionDelta",
+ ]
@@ -0,0 +1,54 @@
+ """Lambda handler for DynamoDB Stream events."""
+
+ import os
+ from typing import Any
+
+ from .processor import process_stream_records
+
+ # Configuration from environment
+ TABLE_NAME = os.environ.get("TABLE_NAME", "rate_limits")
+ SNAPSHOT_WINDOWS = os.environ.get("SNAPSHOT_WINDOWS", "hourly,daily").split(",")
+ SNAPSHOT_TTL_DAYS = int(os.environ.get("SNAPSHOT_TTL_DAYS", "90"))
+
+
+ def handler(event: dict[str, Any], context: Any) -> dict[str, Any]:
+     """
+     Lambda handler for DynamoDB Stream events.
+
+     Processes bucket changes and updates usage snapshots.
+
+     Environment variables:
+         TABLE_NAME: DynamoDB table name (default: rate_limits)
+         SNAPSHOT_WINDOWS: Comma-separated windows (default: hourly,daily)
+         SNAPSHOT_TTL_DAYS: TTL for snapshots in days (default: 90)
+
+     Args:
+         event: DynamoDB Stream event
+         context: Lambda context
+
+     Returns:
+         Processing result summary
+     """
+     records = event.get("Records", [])
+
+     if not records:
+         return {
+             "statusCode": 200,
+             "body": {"processed": 0, "snapshots_updated": 0, "errors": []},
+         }
+
+     result = process_stream_records(
+         records=records,
+         table_name=TABLE_NAME,
+         windows=SNAPSHOT_WINDOWS,
+         ttl_days=SNAPSHOT_TTL_DAYS,
+     )
+
+     return {
+         "statusCode": 200,
+         "body": {
+             "processed": result.processed_count,
+             "snapshots_updated": result.snapshots_updated,
+             "errors": result.errors,
+         },
+     }
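
The handler above short-circuits when the event carries no records, which makes it easy to smoke-test locally without AWS credentials. A minimal sketch, assuming the package (plus boto3, which the processor module imports) is installed and the aggregator subpackage is importable as zae_limiter.aggregator:

    from zae_limiter.aggregator import handler

    # An empty DynamoDB Streams event: the handler returns immediately and
    # never creates a boto3 resource or touches the table.
    result = handler({"Records": []}, context=None)
    print(result)
    # {'statusCode': 200, 'body': {'processed': 0, 'snapshots_updated': 0, 'errors': []}}
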
@@ -0,0 +1,270 @@
+ """DynamoDB Stream processor for usage aggregation."""
+
+ import logging
+ from dataclasses import dataclass
+ from datetime import UTC, datetime, timedelta
+ from typing import Any
+
+ import boto3  # type: ignore[import-untyped]
+
+ from ..schema import SK_BUCKET, gsi2_pk_resource, gsi2_sk_usage, pk_entity, sk_usage
+
+ logger = logging.getLogger(__name__)
+
+
+ @dataclass
+ class ProcessResult:
+     """Result of processing stream records."""
+
+     processed_count: int
+     snapshots_updated: int
+     errors: list[str]
+
+
+ @dataclass
+ class ConsumptionDelta:
+     """Consumption delta extracted from stream record."""
+
+     entity_id: str
+     resource: str
+     limit_name: str
+     tokens_delta: int  # positive = consumed, negative = refilled/returned
+     timestamp_ms: int
+
+
+ def process_stream_records(
+     records: list[dict[str, Any]],
+     table_name: str,
+     windows: list[str],
+     ttl_days: int = 90,
+ ) -> ProcessResult:
+     """
+     Process DynamoDB stream records and update usage snapshots.
+
+     1. Filter for BUCKET records (MODIFY events)
+     2. Extract consumption deltas from old/new images
+     3. Aggregate into hourly/daily snapshot records
+     4. Write updates using atomic ADD operations
+
+     Args:
+         records: DynamoDB stream records
+         table_name: Target table name
+         windows: List of window types ("hourly", "daily")
+         ttl_days: TTL for snapshot records
+
+     Returns:
+         ProcessResult with counts and errors
+     """
+     dynamodb = boto3.resource("dynamodb")
+     table = dynamodb.Table(table_name)
+
+     deltas: list[ConsumptionDelta] = []
+     errors: list[str] = []
+
+     # Extract deltas from records
+     for record in records:
+         if record.get("eventName") != "MODIFY":
+             continue
+
+         try:
+             delta = extract_delta(record)
+             if delta:
+                 deltas.append(delta)
+         except Exception as e:
+             error_msg = f"Error processing record: {e}"
+             logger.warning(error_msg)
+             errors.append(error_msg)
+
+     if not deltas:
+         return ProcessResult(len(records), 0, errors)
+
+     # Update snapshots
+     snapshots_updated = 0
+     for delta in deltas:
+         for window in windows:
+             try:
+                 update_snapshot(table, delta, window, ttl_days)
+                 snapshots_updated += 1
+             except Exception as e:
+                 error_msg = f"Error updating snapshot: {e}"
+                 logger.warning(error_msg)
+                 errors.append(error_msg)
+
+     return ProcessResult(len(records), snapshots_updated, errors)
+
+
+ def extract_delta(record: dict[str, Any]) -> ConsumptionDelta | None:
+     """
+     Extract consumption delta from a stream record.
+
+     Only processes BUCKET records whose token count changed (consumption
+     or refund); records with no net token change are skipped.
+
+     Args:
+         record: DynamoDB stream record
+
+     Returns:
+         ConsumptionDelta if the token count changed, None otherwise
+     """
+     dynamodb_data = record.get("dynamodb", {})
+     new_image = dynamodb_data.get("NewImage", {})
+     old_image = dynamodb_data.get("OldImage", {})
+
+     # Only process bucket records
+     sk = new_image.get("SK", {}).get("S", "")
+     if not sk.startswith(SK_BUCKET):
+         return None
+
+     # Parse key: #BUCKET#{resource}#{limit_name}
+     parts = sk[len(SK_BUCKET) :].split("#", 1)
+     if len(parts) != 2:
+         return None
+
+     resource, limit_name = parts
+     entity_id = new_image.get("entity_id", {}).get("S", "")
+
+     if not entity_id:
+         return None
+
+     # Extract token values from data map
+     new_data = new_image.get("data", {}).get("M", {})
+     old_data = old_image.get("data", {}).get("M", {})
+
+     new_tokens = int(new_data.get("tokens_milli", {}).get("N", "0"))
+     old_tokens = int(old_data.get("tokens_milli", {}).get("N", "0"))
+     new_refill_ms = int(new_data.get("last_refill_ms", {}).get("N", "0"))
+
+     # Calculate delta: old - new = amount consumed
+     # (tokens decrease when consumed)
+     tokens_delta = old_tokens - new_tokens
+
+     # We track all changes (consumption and refunds) but skip updates
+     # with no net token change (e.g. refill bookkeeping only)
+     if tokens_delta == 0:
+         return None
+
+     return ConsumptionDelta(
+         entity_id=entity_id,
+         resource=resource,
+         limit_name=limit_name,
+         tokens_delta=tokens_delta,  # positive = consumed, negative = returned
+         timestamp_ms=new_refill_ms,
+     )
+
+
+ def get_window_key(timestamp_ms: int, window: str) -> str:
+     """
+     Get the window key (ISO timestamp) for a given timestamp.
+
+     Args:
+         timestamp_ms: Epoch milliseconds
+         window: Window type ("hourly", "daily", "monthly")
+
+     Returns:
+         ISO timestamp string for the window start
+     """
+     dt = datetime.fromtimestamp(timestamp_ms / 1000, tz=UTC)
+
+     if window == "hourly":
+         return dt.strftime("%Y-%m-%dT%H:00:00Z")
+     elif window == "daily":
+         return dt.strftime("%Y-%m-%dT00:00:00Z")
+     elif window == "monthly":
+         return dt.strftime("%Y-%m-01T00:00:00Z")
+     else:
+         raise ValueError(f"Unknown window type: {window}")
+
+
+ def get_window_end(window_key: str, window: str) -> str:
+     """
+     Get the window end timestamp.
+
+     Args:
+         window_key: Window start (ISO timestamp)
+         window: Window type
+
+     Returns:
+         ISO timestamp string for the window end
+     """
+     dt = datetime.fromisoformat(window_key.replace("Z", "+00:00"))
+
+     if window == "hourly":
+         end_dt = dt.replace(minute=59, second=59)
+     elif window == "daily":
+         end_dt = dt.replace(hour=23, minute=59, second=59)
+     elif window == "monthly":
+         # Last day of month
+         if dt.month == 12:
+             end_dt = dt.replace(year=dt.year + 1, month=1, day=1) - timedelta(seconds=1)
+         else:
+             end_dt = dt.replace(month=dt.month + 1, day=1) - timedelta(seconds=1)
+     else:
+         end_dt = dt
+
+     return end_dt.strftime("%Y-%m-%dT%H:%M:%SZ")
+
+
+ def calculate_snapshot_ttl(ttl_days: int) -> int:
+     """Calculate TTL epoch seconds."""
+     return int(datetime.now(UTC).timestamp()) + (ttl_days * 86400)
+
+
+ def update_snapshot(
+     table: Any,
+     delta: ConsumptionDelta,
+     window: str,
+     ttl_days: int,
+ ) -> None:
+     """
+     Update a usage snapshot record atomically.
+
+     Uses DynamoDB ADD operation to increment counters, creating
+     the record if it doesn't exist.
+
+     Args:
+         table: boto3 Table resource
+         delta: Consumption delta to record
+         window: Window type
+         ttl_days: TTL in days
+     """
+     window_key = get_window_key(delta.timestamp_ms, window)
+
+     # Convert millitokens to tokens for storage
+     tokens_delta = delta.tokens_delta // 1000
+
+     # Build update expression
+     # We use ADD for atomic increments and SET for metadata
+     table.update_item(
+         Key={
+             "PK": pk_entity(delta.entity_id),
+             "SK": sk_usage(delta.resource, window_key),
+         },
+         UpdateExpression="""
+             SET entity_id = :entity_id,
+                 #data.#resource = :resource,
+                 #data.#window = :window,
+                 #data.window_start = :window_start,
+                 GSI2PK = :gsi2pk,
+                 GSI2SK = :gsi2sk,
+                 #ttl = :ttl
+             ADD #data.#limit_name :delta,
+                 #data.total_events :one
+         """,
+         ExpressionAttributeNames={
+             "#data": "data",
+             "#resource": "resource",
+             "#window": "window",
+             "#limit_name": delta.limit_name,
+             "#ttl": "ttl",
+         },
+         ExpressionAttributeValues={
+             ":entity_id": delta.entity_id,
+             ":resource": delta.resource,
+             ":window": window,
+             ":window_start": window_key,
+             ":gsi2pk": gsi2_pk_resource(delta.resource),
+             ":gsi2sk": gsi2_sk_usage(window_key, delta.entity_id),
+             ":ttl": calculate_snapshot_ttl(ttl_days),
+             ":delta": tokens_delta,
+             ":one": 1,
+         },
+     )
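
To make the record-to-snapshot flow concrete, the sketch below drives extract_delta and get_window_key with a hand-built MODIFY record in DynamoDB Streams wire format. It assumes the module is importable as zae_limiter.aggregator.processor and that SK_BUCKET is the "#BUCKET#" prefix mentioned in the parsing comment; both are inferred from the diff rather than confirmed.

    from zae_limiter.aggregator.processor import extract_delta, get_window_key

    # Bucket for entity "key-abc" / resource "gpt-4" / limit "tpm" drops from
    # 10_000_000 to 9_500_000 millitokens, i.e. 500 tokens consumed.
    record = {
        "eventName": "MODIFY",
        "dynamodb": {
            "OldImage": {
                "SK": {"S": "#BUCKET#gpt-4#tpm"},
                "entity_id": {"S": "key-abc"},
                "data": {"M": {
                    "tokens_milli": {"N": "10000000"},
                    "last_refill_ms": {"N": "1700000000000"},
                }},
            },
            "NewImage": {
                "SK": {"S": "#BUCKET#gpt-4#tpm"},
                "entity_id": {"S": "key-abc"},
                "data": {"M": {
                    "tokens_milli": {"N": "9500000"},
                    "last_refill_ms": {"N": "1700000000000"},
                }},
            },
        },
    }

    delta = extract_delta(record)
    print(delta.tokens_delta)  # 500000 millitokens; update_snapshot stores 500 tokens
    print(get_window_key(delta.timestamp_ms, "hourly"))  # "2023-11-14T22:00:00Z"

update_snapshot would then write this delta once per configured window with a single ADD, so concurrent Lambda invocations increment the same counters without clobbering each other.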