rollgate 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rollgate/dedup.py ADDED
@@ -0,0 +1,172 @@
1
+ """
2
+ Request deduplication to prevent duplicate inflight requests.
3
+ """
4
+
5
+ import asyncio
6
+ import time
7
+ from dataclasses import dataclass, field
8
+ from typing import Dict, Optional, Any, Callable, Awaitable, TypeVar, Generic
9
+
10
# Type variable for the value produced by a deduplicated request function.
T = TypeVar("T")
11
+
12
+
13
@dataclass
class DedupConfig:
    """Settings that control request deduplication.

    Attributes:
        enabled: Turns request deduplication on or off.
        ttl_ms: Time-to-live, in milliseconds, for inflight request tracking
            (default: 5000, i.e. 5 seconds).
    """

    enabled: bool = True
    ttl_ms: int = 5000
22
+
23
+
24
# Module-level default settings; used as the default `config` argument of
# RequestDeduplicator.__init__.
DEFAULT_DEDUP_CONFIG = DedupConfig()
25
+
26
+
27
@dataclass
class InflightRequest(Generic[T]):
    """Bookkeeping record for one request that is currently running.

    Attributes:
        future: Future through which the outcome of the request is shared.
        timestamp: Creation time of the record, as stored by the deduplicator.
        key: Deduplication key the request was registered under.
    """

    future: asyncio.Future
    timestamp: float
    key: str
34
+
35
+
36
class RequestDeduplicator:
    """
    Deduplicates concurrent identical requests.

    When multiple callers request the same resource simultaneously,
    only one actual request is made and the result is shared.

    The first caller for a key (the "owner") runs the request; callers that
    arrive while it is inflight (the "waiters") await the owner's outcome.
    If the owner fails, waiters receive the same exception. If the owner is
    cancelled, waiters are cancelled as well instead of waiting forever.

    Example:
        ```python
        dedup = RequestDeduplicator()

        # These concurrent calls will result in only one actual fetch
        async with asyncio.TaskGroup() as tg:
            tg.create_task(dedup.dedupe("flags", fetch_flags))
            tg.create_task(dedup.dedupe("flags", fetch_flags))
            tg.create_task(dedup.dedupe("flags", fetch_flags))
        ```
    """

    def __init__(self, config: DedupConfig = DEFAULT_DEDUP_CONFIG):
        """
        Initialize the deduplicator.

        Args:
            config: Deduplication configuration
        """
        self._config = config
        self._inflight: Dict[str, InflightRequest] = {}
        self._lock = asyncio.Lock()

        # Statistics
        self._total_requests = 0
        self._deduplicated_requests = 0

    @staticmethod
    def _mark_exception_retrieved(future: asyncio.Future) -> None:
        """Read a finished future's exception so asyncio does not log it as unretrieved."""
        if not future.cancelled():
            future.exception()

    async def dedupe(
        self,
        key: str,
        request_fn: Callable[[], Awaitable[T]],
    ) -> T:
        """
        Execute a request with deduplication.

        If an identical request (by key) is already inflight,
        wait for its result instead of making a new request.

        Args:
            key: Unique key for this request type
            request_fn: Async function to execute if no inflight request exists

        Returns:
            Result of the request

        Raises:
            Exception: Whatever ``request_fn`` raises; waiters sharing the
                request receive the same exception.
            asyncio.CancelledError: If this call is cancelled, or if the
                owner of the shared request was cancelled.
        """
        if not self._config.enabled:
            return await request_fn()

        # Only bookkeeping happens under the lock. Nothing is awaited while it
        # is held, so a slow request never blocks calls for other keys.
        async with self._lock:
            self._total_requests += 1
            self._cleanup_expired()

            entry = self._inflight.get(key)
            if entry is None:
                shared: asyncio.Future = asyncio.get_running_loop().create_future()
                # The owner re-raises failures itself; without this callback a
                # failed request with no waiters would be logged by asyncio as
                # "Future exception was never retrieved".
                shared.add_done_callback(self._mark_exception_retrieved)
                self._inflight[key] = InflightRequest(
                    future=shared,
                    # Monotonic clock: TTL must not be affected by wall-clock jumps.
                    timestamp=time.monotonic(),
                    key=key,
                )
                is_owner = True
            else:
                self._deduplicated_requests += 1
                shared = entry.future
                is_owner = False

        if not is_owner:
            # shield() keeps the cancellation of one waiter from cancelling
            # the future that every other waiter is sharing.
            return await asyncio.shield(shared)

        try:
            result = await request_fn()
        except BaseException as exc:
            if not shared.done():
                is_plain_error = isinstance(exc, Exception) and not isinstance(
                    exc, asyncio.CancelledError
                )
                if is_plain_error:
                    # Propagate error to all waiters
                    shared.set_exception(exc)
                else:
                    # Cancellation / interpreter shutdown: release the waiters
                    # instead of leaving them pending forever.
                    shared.cancel()
            raise
        else:
            # Set result for all waiters
            if not shared.done():
                shared.set_result(result)
            return result
        finally:
            # Synchronous on purpose: there is no await here, so the entry is
            # removed even while this task is being cancelled, and the check
            # and delete cannot interleave with another coroutine.
            current = self._inflight.get(key)
            if current is not None and current.future is shared:
                del self._inflight[key]

    def _cleanup_expired(self) -> None:
        """Stop tracking inflight requests older than the configured TTL.

        Only the tracking entry is dropped; the underlying request keeps
        running and callers already waiting on it still get its outcome.
        """
        now = time.monotonic()
        ttl_seconds = self._config.ttl_ms / 1000
        expired_keys = [
            key
            for key, req in self._inflight.items()
            if now - req.timestamp > ttl_seconds
        ]
        for key in expired_keys:
            del self._inflight[key]

    @property
    def inflight_count(self) -> int:
        """Get number of currently inflight requests."""
        return len(self._inflight)

    def get_stats(self) -> Dict[str, Any]:
        """
        Get deduplication statistics.

        Returns:
            Dictionary with total_requests, deduplicated_requests, dedup_rate
            and inflight_count
        """
        total = self._total_requests
        deduped = self._deduplicated_requests
        return {
            "total_requests": total,
            "deduplicated_requests": deduped,
            "dedup_rate": deduped / total if total > 0 else 0,
            "inflight_count": len(self._inflight),
        }

    def reset_stats(self) -> None:
        """Reset statistics counters."""
        self._total_requests = 0
        self._deduplicated_requests = 0

    async def clear(self) -> None:
        """Clear all inflight requests."""
        async with self._lock:
            self._inflight.clear()
rollgate/errors.py ADDED
@@ -0,0 +1,162 @@
1
+ """
2
+ Error types for Rollgate SDK.
3
+
4
+ Provides structured error handling with categories for better error management.
5
+ """
6
+
7
+ from enum import Enum
8
+ from typing import Optional
9
+
10
+
11
class ErrorCategory(str, Enum):
    """Closed set of labels used to classify SDK errors.

    Members subclass ``str``, so each one compares equal to its string value.
    """

    AUTH = "auth"
    NETWORK = "network"
    RATE_LIMIT = "rate_limit"
    VALIDATION = "validation"
    NOT_FOUND = "not_found"
    INTERNAL = "internal"
    UNKNOWN = "unknown"
21
+
22
+
23
class RollgateError(Exception):
    """Base exception for all Rollgate SDK errors.

    Attributes:
        message: Human-readable description of the failure.
        category: Classification of the error.
        status_code: HTTP status code associated with the error, if any.
        retryable: Whether the error is flagged as retryable.
    """

    def __init__(
        self,
        message: str,
        category: ErrorCategory = ErrorCategory.UNKNOWN,
        status_code: Optional[int] = None,
        retryable: bool = False,
    ):
        """
        Initialize the error.

        Args:
            message: Human-readable description of the failure
            category: Classification of the error
            status_code: Optional HTTP status code
            retryable: Whether the error is flagged as retryable
        """
        super().__init__(message)
        self.message = message
        self.category = category
        self.status_code = status_code
        self.retryable = retryable

    def __repr__(self) -> str:
        # Formatting a (str, Enum) member through an f-string yields the bare
        # value on older Python versions but "ErrorCategory.X" on newer ones.
        # Render the value explicitly so the repr is the same everywhere.
        # getattr() keeps this working if a plain string was passed instead.
        category = getattr(self.category, "value", self.category)
        return f"{self.__class__.__name__}(message={self.message!r}, category={category})"
41
+
42
+
43
class AuthenticationError(RollgateError):
    """Authentication failure (401/403); flagged as not retryable."""

    def __init__(self, message: str = "Authentication failed", status_code: int = 401):
        # Positional order follows RollgateError: message, category, status_code, retryable.
        super().__init__(message, ErrorCategory.AUTH, status_code, False)
53
+
54
+
55
class NetworkError(RollgateError):
    """Network-level failure; carries no status code and is flagged retryable."""

    def __init__(self, message: str = "Network error"):
        # Positional order follows RollgateError: message, category, status_code, retryable.
        super().__init__(message, ErrorCategory.NETWORK, None, True)
65
+
66
+
67
class RateLimitError(RollgateError):
    """Rate limiting failure (429); flagged as retryable.

    Attributes:
        retry_after: Optional retry hint supplied by the caller, stored as given.
    """

    def __init__(
        self,
        message: str = "Rate limit exceeded",
        retry_after: Optional[int] = None,
    ):
        # Positional order follows RollgateError: message, category, status_code, retryable.
        super().__init__(message, ErrorCategory.RATE_LIMIT, 429, True)
        self.retry_after = retry_after
82
+
83
+
84
class ValidationError(RollgateError):
    """Validation failure (400); flagged as not retryable."""

    def __init__(self, message: str = "Validation error"):
        # Positional order follows RollgateError: message, category, status_code, retryable.
        super().__init__(message, ErrorCategory.VALIDATION, 400, False)
94
+
95
+
96
class NotFoundError(RollgateError):
    """Missing resource (404); flagged as not retryable."""

    def __init__(self, message: str = "Resource not found"):
        # Positional order follows RollgateError: message, category, status_code, retryable.
        super().__init__(message, ErrorCategory.NOT_FOUND, 404, False)
106
+
107
+
108
class InternalError(RollgateError):
    """Server-side failure (5xx); flagged as retryable."""

    def __init__(self, message: str = "Internal server error", status_code: int = 500):
        # Positional order follows RollgateError: message, category, status_code, retryable.
        super().__init__(message, ErrorCategory.INTERNAL, status_code, True)
118
+
119
+
120
def classify_error(error: Exception, status_code: Optional[int] = None) -> RollgateError:
    """
    Classify an exception into a RollgateError.

    Classification order:
      1. A RollgateError is returned unchanged.
      2. A known HTTP status code decides the category. A status code means
         a response was received, so it outranks message heuristics.
      3. Builtin connection/timeout exceptions, or a message containing a
         network-related keyword, become a NetworkError.
      4. Anything else becomes a generic RollgateError that keeps the
         status code.

    The original exception is attached as ``__cause__`` of the result.

    Args:
        error: The original exception
        status_code: Optional HTTP status code

    Returns:
        A classified RollgateError
    """
    if isinstance(error, RollgateError):
        return error

    # Some exceptions (e.g. a bare TimeoutError()) stringify to ""; fall back
    # to the class name so the resulting message is never empty.
    message = str(error) or type(error).__name__
    # A falsy status (None or 0) means "no HTTP status available".
    code = status_code or None

    classified: RollgateError
    if code in (401, 403):
        classified = AuthenticationError(message, code)
    elif code == 404:
        classified = NotFoundError(message)
    elif code == 429:
        classified = RateLimitError(message)
    elif code == 400:
        classified = ValidationError(message)
    elif code is not None and 500 <= code < 600:
        classified = InternalError(message, code)
    else:
        network_indicators = (
            "connection",
            "timeout",
            "econnrefused",
            "etimedout",
            "enotfound",
            "network",
            "dns",
        )
        lowered = message.lower()
        looks_like_network = isinstance(error, (ConnectionError, TimeoutError)) or any(
            indicator in lowered for indicator in network_indicators
        )
        if looks_like_network:
            classified = NetworkError(message)
        else:
            # Keep the status code for unmapped statuses (e.g. 409, 422).
            classified = RollgateError(message, status_code=code)

    # Preserve the original exception (type and traceback) for debugging.
    classified.__cause__ = error
    return classified
rollgate/evaluate.py ADDED
@@ -0,0 +1,345 @@
1
+ """
2
+ Client-side flag evaluation logic.
3
+ Mirrors the server-side evaluation for consistency.
4
+ """
5
+
6
+ import hashlib
7
+ import re
8
+ from dataclasses import dataclass, field
9
+ from typing import Dict, List, Optional, Any, Union
10
+
11
+
12
@dataclass
class Condition:
    """A single targeting condition.

    Attributes:
        attribute: Name of the user attribute to inspect.
        operator: Name of the comparison operator to apply.
        value: Value the attribute is compared against.
    """

    attribute: str
    operator: str
    value: str
18
+
19
+
20
@dataclass
class TargetingRule:
    """A targeting rule made of conditions.

    Attributes:
        id: Identifier of the rule.
        enabled: Whether the rule takes part in evaluation.
        rollout: Rollout percentage applied to users matching the rule.
        conditions: Conditions belonging to the rule (empty by default).
        name: Optional display name.
    """

    id: str
    enabled: bool
    rollout: int
    conditions: List[Condition] = field(default_factory=list)
    name: Optional[str] = None
28
+
29
+
30
@dataclass
class FlagRule:
    """A feature flag together with its targeting data.

    Attributes:
        key: Flag key.
        enabled: Whether the flag is switched on.
        rollout: Default rollout percentage for the flag.
        target_users: User ids listed explicitly for the flag (empty by default).
        rules: Targeting rules attached to the flag (empty by default).
    """

    key: str
    enabled: bool
    rollout: int
    target_users: List[str] = field(default_factory=list)
    rules: List[TargetingRule] = field(default_factory=list)
38
+
39
+
40
@dataclass
class RulesPayload:
    """Rules response from the API.

    Attributes:
        version: Version string of the rule set.
        flags: Flag rules indexed by flag key (empty by default).
    """

    version: str
    flags: Dict[str, FlagRule] = field(default_factory=dict)
45
+
46
+
47
@dataclass
class EvaluationResult:
    """Outcome of evaluating a flag.

    Attributes:
        enabled: Whether the flag evaluated to on.
        value: Value associated with the evaluation.
        variation_id: Optional identifier of the variation.
    """

    enabled: bool
    value: Any
    variation_id: Optional[str] = None
53
+
54
+
55
@dataclass
class UserContext:
    """User information used for targeting.

    Attributes:
        id: User identifier.
        email: Optional email address.
        attributes: Optional mapping of additional attributes.
    """

    id: str
    email: Optional[str] = None
    attributes: Optional[Dict[str, Any]] = None
61
+
62
+
63
def evaluate_flag(rule: FlagRule, user: Optional[UserContext]) -> bool:
    """
    Decide whether a flag is on for the given user, using client-side rules.

    Order of precedence:
      1. A disabled flag is always off.
      2. A user whose id is in the flag's target list gets the flag.
      3. The first enabled targeting rule the user matches decides the
         outcome through that rule's rollout percentage.
      4. Otherwise the flag's own rollout percentage applies.
    """
    if not rule.enabled:
        return False

    # Explicit targeting wins over every percentage.
    if user and user.id and rule.target_users and user.id in rule.target_users:
        return True

    if user and rule.rules:
        matched = next(
            (
                candidate
                for candidate in rule.rules
                if candidate.enabled and _matches_rule(candidate, user)
            ),
            None,
        )
        if matched is not None:
            percentage = matched.rollout
            if percentage >= 100:
                return True
            if percentage <= 0:
                return False
            return _is_in_rollout(rule.key, user.id, percentage)

    # No targeting rule applied: fall back to the flag-level rollout.
    percentage = rule.rollout
    if percentage >= 100:
        return True
    # A partial rollout needs a user id to hash on.
    if percentage <= 0 or not user or not user.id:
        return False
    return _is_in_rollout(rule.key, user.id, percentage)
102
+
103
+
104
def _matches_rule(rule: TargetingRule, user: UserContext) -> bool:
    """
    Tell whether a user satisfies a targeting rule.

    Conditions are combined with AND; a rule without conditions never matches.
    """
    conditions = rule.conditions
    if not conditions:
        return False
    return all(_matches_condition(item, user) for item in conditions)
116
+
117
+
118
def _matches_condition(condition: Condition, user: UserContext) -> bool:
    """Check if a user matches a single condition.

    String operators compare case-insensitively. Numeric, regex and semver
    operators work on the original attribute value. Unknown operators never
    match.

    Args:
        condition: The condition to test
        user: User context supplying the attribute value

    Returns:
        True if the user's attribute satisfies the condition
    """
    attr_value = _get_attribute_value(condition.attribute, user)
    exists = attr_value is not None and str(attr_value) != ""
    operator = condition.operator

    # Handle is_set / is_not_set operators first
    if operator == "is_set":
        return exists
    if operator == "is_not_set":
        return not exists

    # For other operators, if attribute doesn't exist, condition fails
    if not exists:
        return False

    # Rules built from JSON may carry non-string values (numbers, booleans,
    # null, arrays). Coerce once so the string operations below cannot raise.
    raw_value = condition.value
    expected = "" if raw_value is None else str(raw_value)
    actual = str(attr_value)

    value = actual.lower()
    cond_value = expected.lower()

    if operator == "equals":
        return value == cond_value
    if operator == "not_equals":
        return value != cond_value
    if operator == "contains":
        return cond_value in value
    if operator == "not_contains":
        return cond_value not in value
    if operator == "starts_with":
        return value.startswith(cond_value)
    if operator == "ends_with":
        return value.endswith(cond_value)

    if operator in ("in", "not_in"):
        if isinstance(raw_value, (list, tuple, set, frozenset)):
            # A JSON array is taken element by element rather than stringified.
            options = [str(item).strip().lower() for item in raw_value]
        else:
            options = [item.strip().lower() for item in expected.split(",")]
        found = value in options
        return found if operator == "in" else not found

    numeric_ops = {
        "greater_than": ">",
        "greater_equal": ">=",
        "less_than": "<",
        "less_equal": "<=",
    }
    if operator in numeric_ops:
        return _compare_numeric(attr_value, expected, numeric_ops[operator])

    if operator == "regex":
        # NOTE(review): the pattern comes from the rules payload; a
        # pathological pattern could be slow to match. Confirm the payload
        # source is trusted.
        try:
            return bool(re.match(expected, actual))
        except re.error:
            return False

    semver_ops = {"semver_gt": ">", "semver_lt": "<", "semver_eq": "="}
    if operator in semver_ops:
        return _compare_semver(actual, expected, semver_ops[operator])

    return False
175
+
176
+
177
def _get_attribute_value(attribute: str, user: UserContext) -> Any:
    """Look up an attribute on the user context.

    "id" and "email" map to the dedicated fields; any other name is looked up
    in the user's attributes mapping. Returns None when nothing is found.
    """
    if user is None:
        return None
    if attribute == "id":
        return user.id
    if attribute == "email":
        return user.email
    extra = user.attributes
    return extra.get(attribute) if extra else None
188
+
189
+
190
def _compare_numeric(attr_val: Any, cond_val: str, op: str) -> bool:
    """Compare two values numerically.

    Both sides are converted to float; if either conversion fails, or the
    operator is not one of ">", ">=", "<", "<=", the result is False.
    """
    try:
        left = float(str(attr_val))
        right = float(cond_val)
    except (ValueError, TypeError):
        return False

    outcomes = {
        ">": left > right,
        ">=": left >= right,
        "<": left < right,
        "<=": left <= right,
    }
    return outcomes.get(op, False)
208
+
209
+
210
def _compare_semver(attr_val: str, cond_val: str, op: str) -> bool:
    """Compare two dotted version strings.

    Returns False when either version cannot be parsed. The shorter version
    is padded with zeros, so "1.2" and "1.2.0" compare as equal.
    """
    left = _parse_version(attr_val)
    right = _parse_version(cond_val)
    if left is None or right is None:
        return False

    # Bring both to the same number of components.
    width = max(len(left), len(right))
    left = left + [0] * (width - len(left))
    right = right + [0] * (width - len(right))

    # The first differing component decides the ordering.
    for mine, theirs in zip(left, right):
        if mine > theirs:
            return op in (">", ">=")
        if mine < theirs:
            return op in ("<", "<=")

    # All components equal.
    return op in ("=", ">=", "<=")
232
+
233
+
234
def _parse_version(v: str) -> Optional[List[int]]:
    """Turn a dotted version string into a list of ints.

    Leading "v" characters are dropped. Returns None if any dot-separated
    component is not an integer.
    """
    numbers: List[int] = []
    for piece in v.lstrip("v").split("."):
        try:
            numbers.append(int(piece))
        except ValueError:
            return None
    return numbers
242
+
243
+
244
def _is_in_rollout(flag_key: str, user_id: str, percentage: int) -> bool:
    """
    Place a user in or out of a percentage rollout deterministically.

    The SHA-256 digest of "flagKey:userId" is reduced to a bucket in 0-99,
    so the same user always gets the same answer for a given flag. The user
    is in the rollout when the bucket is below ``percentage``.
    """
    digest = hashlib.sha256(f"{flag_key}:{user_id}".encode("utf-8")).digest()
    # First 4 bytes read as a big-endian unsigned int, folded into 0-99.
    bucket = int.from_bytes(digest[:4], byteorder="big") % 100
    return bucket < percentage
256
+
257
+
258
def evaluate_all_flags(
    rules: Dict[str, FlagRule],
    user: Optional[UserContext]
) -> Dict[str, bool]:
    """Evaluate every flag in ``rules`` for the user and return results by key."""
    outcomes: Dict[str, bool] = {}
    for flag_key, flag_rule in rules.items():
        outcomes[flag_key] = evaluate_flag(flag_rule, user)
    return outcomes
264
+
265
+
266
class LocalEvaluator:
    """
    Local evaluator for client-side flag evaluation.

    Example:
        ```python
        evaluator = LocalEvaluator()
        evaluator.set_rules(rules_payload)

        user = UserContext(id="user-123", email="user@example.com")
        enabled = evaluator.evaluate("my-feature", user, default_value=False)
        ```
    """

    def __init__(self):
        """Initialize the local evaluator."""
        self._rules: Dict[str, FlagRule] = {}
        self._version: str = ""

    def set_rules(self, payload: RulesPayload) -> None:
        """Set the rules for local evaluation."""
        self._rules = payload.flags
        self._version = payload.version

    @staticmethod
    def _build_targeting_rule(rule_data: Dict[str, Any]) -> TargetingRule:
        """Build one TargetingRule from its dictionary form."""
        conditions = [
            Condition(
                attribute=c.get("attribute", ""),
                operator=c.get("operator", ""),
                value=c.get("value", ""),
            )
            # "or []" tolerates an explicit JSON null.
            for c in rule_data.get("conditions") or []
        ]
        return TargetingRule(
            id=rule_data.get("id", ""),
            name=rule_data.get("name"),
            enabled=rule_data.get("enabled", False),
            rollout=rule_data.get("rollout", 0),
            conditions=conditions,
        )

    def set_rules_from_dict(self, data: Dict[str, Any]) -> None:
        """Set rules from a dictionary (e.g., from JSON).

        The payload is parsed completely before anything is stored, so a
        malformed payload raises without leaving the evaluator holding a
        partial rule set. Explicit nulls for "flags", "rules", "conditions"
        and "targetUsers" are treated as empty.

        Args:
            data: Dictionary with optional "version" and "flags" entries
        """
        parsed: Dict[str, FlagRule] = {}

        for key, flag_data in (data.get("flags") or {}).items():
            flag_data = flag_data or {}
            parsed[key] = FlagRule(
                key=key,
                enabled=flag_data.get("enabled", False),
                rollout=flag_data.get("rollout", 0),
                target_users=flag_data.get("targetUsers") or [],
                rules=[
                    self._build_targeting_rule(rule_data)
                    for rule_data in flag_data.get("rules") or []
                ],
            )

        # Swap rules and version together, only after parsing succeeded.
        self._rules = parsed
        self._version = data.get("version", "")

    @property
    def version(self) -> str:
        """Get the current rules version."""
        return self._version

    def evaluate(
        self,
        flag_key: str,
        user: Optional[UserContext],
        default_value: bool = False
    ) -> bool:
        """Evaluate a single flag.

        Args:
            flag_key: Key of the flag to evaluate
            user: User context, or None
            default_value: Value returned when the flag is unknown

        Returns:
            The evaluation result, or ``default_value`` if the flag has no rule
        """
        rule = self._rules.get(flag_key)
        if rule is None:
            return default_value
        return evaluate_flag(rule, user)

    def evaluate_all(self, user: Optional[UserContext]) -> Dict[str, bool]:
        """Evaluate all flags."""
        return evaluate_all_flags(self._rules, user)

    def has_flag(self, flag_key: str) -> bool:
        """Check if a flag exists."""
        return flag_key in self._rules