zae-limiter 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
zae_limiter/lease.py ADDED
@@ -0,0 +1,196 @@
1
+ """Lease management for rate limit acquisitions."""
2
+
3
+ import asyncio
4
+ import time
5
+ from dataclasses import dataclass, field
6
+ from typing import TYPE_CHECKING, Any
7
+
8
+ from .bucket import calculate_available, force_consume, try_consume
9
+ from .exceptions import RateLimitExceeded
10
+ from .models import BucketState, Limit, LimitStatus
11
+
12
+ if TYPE_CHECKING:
13
+ from .repository import Repository
14
+
15
+
16
@dataclass
class LeaseEntry:
    """One bucket (entity + resource + limit) taking part in a lease.

    ``state`` is the working copy of the bucket that ``Lease`` updates in
    memory; ``consumed`` accumulates the net amount taken through the lease
    (``Lease.adjust`` may add negative deltas).
    """

    entity_id: str
    resource: str
    limit: Limit
    state: BucketState
    # Net amount consumed from this bucket over the lifetime of the lease.
    consumed: int = field(default=0)
25
+
26
+
27
@dataclass
class Lease:
    """
    Manages an active rate limit acquisition.

    Tracks consumption across multiple entities/limits and handles
    rollback on exception.

    Consumption is applied to the in-memory ``BucketState`` held by each
    ``LeaseEntry``; nothing is written to the repository until ``_commit``
    runs. Rollback simply skips that write. Once the lease has been
    committed or rolled back, ``consume``/``adjust``/``release`` raise
    ``RuntimeError``.
    """

    repository: "Repository"
    entries: list[LeaseEntry] = field(default_factory=list)
    _committed: bool = False
    _rolled_back: bool = False

    def _ensure_active(self) -> None:
        """Raise RuntimeError if the lease was already committed or rolled back."""
        if self._committed or self._rolled_back:
            raise RuntimeError("Lease is no longer active")

    @property
    def consumed(self) -> dict[str, int]:
        """Total consumed amounts by limit name.

        Entries that share a limit name (e.g. the same limit applied to
        several entities) are summed together.
        """
        result: dict[str, int] = {}
        for entry in self.entries:
            name = entry.limit.name
            result[name] = result.get(name, 0) + entry.consumed
        return result

    async def consume(self, **amounts: int) -> None:
        """
        Consume additional capacity from the buckets.

        Every affected bucket is checked before any is modified, so the
        operation is all-or-nothing: if one bucket lacks capacity, no local
        state changes.

        Args:
            **amounts: Mapping of limit_name -> amount to consume. Amounts
                <= 0, and names that match no entry, are ignored.

        Raises:
            RuntimeError: If the lease was already committed or rolled back.
            RateLimitExceeded: If any bucket has insufficient capacity. The
                exception carries a status for every entry in the lease.
        """
        self._ensure_active()

        now_ms = int(time.time() * 1000)
        statuses: list[LimitStatus] = []
        # (entry, new_tokens_milli, new_last_refill_ms, amount) for passing checks
        updates: list[tuple[LeaseEntry, int, int, int]] = []
        # Entries with nothing requested; reported after the consumed ones.
        idle: list[LeaseEntry] = []

        # Check all limits first, without mutating any bucket state.
        for entry in self.entries:
            amount = amounts.get(entry.limit.name, 0)
            if amount <= 0:
                idle.append(entry)
                continue

            result = try_consume(entry.state, amount, now_ms)
            statuses.append(
                LimitStatus(
                    entity_id=entry.entity_id,
                    resource=entry.resource,
                    limit_name=entry.limit.name,
                    limit=entry.limit,
                    available=result.available,
                    requested=amount,
                    exceeded=not result.success,
                    retry_after_seconds=result.retry_after_seconds,
                )
            )
            if result.success:
                updates.append(
                    (entry, result.new_tokens_milli, result.new_last_refill_ms, amount)
                )

        # Include every entry that was skipped above (absent, zero or negative
        # amount) so the statuses give full visibility. Filtering on the keys
        # of ``amounts`` would drop limits passed with amount 0.
        for entry in idle:
            statuses.append(
                LimitStatus(
                    entity_id=entry.entity_id,
                    resource=entry.resource,
                    limit_name=entry.limit.name,
                    limit=entry.limit,
                    available=calculate_available(entry.state, now_ms),
                    requested=0,
                    exceeded=False,
                    retry_after_seconds=0.0,
                )
            )

        if any(status.exceeded for status in statuses):
            raise RateLimitExceeded(statuses)

        # Apply updates to local state (will be persisted on commit).
        for entry, new_tokens, new_refill, amount in updates:
            entry.state.tokens_milli = new_tokens
            entry.state.last_refill_ms = new_refill
            entry.consumed += amount

    async def adjust(self, **amounts: int) -> None:
        """
        Adjust consumption by delta (positive or negative).

        Never raises for insufficient capacity - allows bucket to go negative.
        Use for post-hoc reconciliation (e.g., LLM token counts).

        Args:
            **amounts: Mapping of limit_name -> delta (positive = consume more).
                Zero deltas and names that match no entry are ignored.

        Raises:
            RuntimeError: If the lease was already committed or rolled back.
        """
        self._ensure_active()

        now_ms = int(time.time() * 1000)

        for entry in self.entries:
            amount = amounts.get(entry.limit.name, 0)
            if amount == 0:
                continue

            new_tokens, new_refill = force_consume(entry.state, amount, now_ms)
            entry.state.tokens_milli = new_tokens
            entry.state.last_refill_ms = new_refill
            entry.consumed += amount

    async def release(self, **amounts: int) -> None:
        """
        Return unused capacity to bucket.

        Convenience wrapper for adjust() with negated values.

        Args:
            **amounts: Mapping of limit_name -> amount to return

        Raises:
            RuntimeError: If the lease was already committed or rolled back
                (raised by ``adjust``).
        """
        negated = {name: -amount for name, amount in amounts.items()}
        await self.adjust(**negated)

    async def _commit(self) -> None:
        """Persist the final bucket states to DynamoDB.

        No-op if the lease was already committed or rolled back.
        """
        if self._committed or self._rolled_back:
            return

        # Mark committed before awaiting so a re-entrant call made while the
        # write is in flight returns early instead of writing twice.
        self._committed = True

        items: list[dict[str, Any]] = [
            self.repository.build_bucket_put_item(entry.state)
            for entry in self.entries
        ]
        # DynamoDB TransactWriteItems rejects an empty item list; with no
        # entries there is nothing to persist, so skip the round trip.
        if not items:
            return

        await self.repository.transact_write(items)

    async def _rollback(self) -> None:
        """Rollback is implicit - we just don't commit."""
        self._rolled_back = True
172
+
173
+
174
class SyncLease:
    """Blocking facade over an async ``Lease``.

    Each mutator drives the matching ``Lease`` coroutine to completion on
    the event loop supplied at construction time.
    """

    def __init__(self, lease: Lease, loop: asyncio.AbstractEventLoop) -> None:
        self._lease = lease
        self._loop = loop

    def _run(self, coro: Any) -> None:
        """Block until ``coro`` finishes on the wrapped event loop."""
        self._loop.run_until_complete(coro)

    @property
    def consumed(self) -> dict[str, int]:
        """Total consumed amounts by limit name (read from the async lease)."""
        return self._lease.consumed

    def consume(self, **amounts: int) -> None:
        """Consume additional capacity from the buckets."""
        self._run(self._lease.consume(**amounts))

    def adjust(self, **amounts: int) -> None:
        """Adjust consumption by delta."""
        self._run(self._lease.adjust(**amounts))

    def release(self, **amounts: int) -> None:
        """Return unused capacity to bucket."""
        self._run(self._lease.release(**amounts))