rollgate 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rollgate/client.py ADDED
@@ -0,0 +1,562 @@
1
+ """
2
+ Rollgate client for feature flag evaluation.
3
+ """
4
+
5
import asyncio
import copy
import json
import logging
import time
from dataclasses import dataclass, field
from email.utils import parsedate_to_datetime
from typing import Optional, Dict, Callable, Any, List

import httpx
from httpx_sse import aconnect_sse

from rollgate.cache import FlagCache, CacheConfig, DEFAULT_CACHE_CONFIG
from rollgate.circuit_breaker import (
    CircuitBreaker,
    CircuitBreakerConfig,
    CircuitOpenError,
    CircuitState,
    DEFAULT_CIRCUIT_BREAKER_CONFIG,
)
from rollgate.retry import (
    RetryConfig,
    DEFAULT_RETRY_CONFIG,
    fetch_with_retry,
)
from rollgate.errors import (
    RollgateError,
    AuthenticationError,
    NetworkError,
    RateLimitError,
    classify_error,
)
from rollgate.reasons import (
    EvaluationReason,
    EvaluationDetail,
    EvaluationReasonKind,
    EvaluationErrorKind,
    fallthrough_reason,
    error_reason,
    unknown_reason,
)
43
+
44
# Single named logger shared by the client; all diagnostics in this module
# (polling/SSE errors, fetch failures, callback errors) are emitted through it.
logger = logging.getLogger("rollgate")
45
+
46
+
47
@dataclass
class UserContext:
    """Identity of the user that flags are evaluated for.

    Attributes:
        id: Identifier of the user. The client sends it to the API as the
            ``user_id`` query parameter.
        email: Optional e-mail address of the user.
        attributes: Optional mapping of additional targeting attributes.
    """

    id: str
    email: Optional[str] = field(default=None)
    attributes: Optional[Dict[str, Any]] = field(default=None)
54
+
55
+
56
@dataclass
class RollgateConfig:
    """Configuration for :class:`RollgateClient`.

    The nested ``retry``, ``circuit_breaker`` and ``cache`` settings default to
    deep copies of the package-level ``DEFAULT_*`` objects. Each config instance
    therefore owns its own settings: mutating them never changes another config
    or the shared defaults (previously every instance aliased the same objects).
    """

    api_key: str
    """API key for authentication."""

    base_url: str = "https://api.rollgate.io"
    """Base URL for the API."""

    refresh_interval_ms: int = 30000
    """Polling interval in milliseconds (default: 30s). Set to 0 to disable.

    Only used when ``enable_streaming`` is False."""

    enable_streaming: bool = False
    """Use SSE for real-time updates (default: False)."""

    timeout_ms: int = 5000
    """Request timeout in milliseconds."""

    retry: RetryConfig = field(
        default_factory=lambda: copy.deepcopy(DEFAULT_RETRY_CONFIG)
    )
    """Retry configuration."""

    circuit_breaker: CircuitBreakerConfig = field(
        default_factory=lambda: copy.deepcopy(DEFAULT_CIRCUIT_BREAKER_CONFIG)
    )
    """Circuit breaker configuration."""

    cache: CacheConfig = field(
        default_factory=lambda: copy.deepcopy(DEFAULT_CACHE_CONFIG)
    )
    """Cache configuration."""
85
+
86
+
87
class RollgateClient:
    """
    Rollgate feature flag client.

    On :meth:`init` the client loads any cached flags and fetches fresh flags
    from ``{base_url}/api/v1/sdk/flags``. It then keeps them current, either
    through an SSE stream (``enable_streaming``) or by polling every
    ``refresh_interval_ms``. Every fetch runs through the circuit breaker and
    the retry policy. When a fetch fails or the circuit is open, the cached
    flags are used instead.

    Events that can be subscribed to with :meth:`on`: ``ready``,
    ``flags_updated``, ``flags_stale``, ``flag_changed``, ``error``,
    ``circuit_open``, ``circuit_closed`` and ``circuit_half_open``.

    Example:
        ```python
        client = RollgateClient(RollgateConfig(api_key="your-api-key"))
        await client.init()

        if client.is_enabled("my-feature"):
            # Feature is enabled
            pass

        await client.close()
        ```
    """

    def __init__(self, config: RollgateConfig, http_client: Optional[httpx.AsyncClient] = None):
        """
        Initialize the Rollgate client.

        No network I/O happens here; call :meth:`init` to fetch flags.

        Args:
            config: Client configuration
            http_client: Optional pre-configured httpx.AsyncClient for connection
                reuse. A client passed in is NOT closed by :meth:`close`.
        """
        self._config = config
        self._flags: Dict[str, bool] = {}
        self._flag_reasons: Dict[str, EvaluationReason] = {}
        self._initialized = False
        self._user_context: Optional[UserContext] = None
        self._circuit_breaker = CircuitBreaker(config.circuit_breaker)
        self._cache = FlagCache(config.cache)
        self._last_etag: Optional[str] = None
        self._poll_task: Optional[asyncio.Task] = None
        self._sse_task: Optional[asyncio.Task] = None
        self._closing = False

        # Use the caller's HTTP client when given (the caller keeps ownership);
        # otherwise create a private one that close() shuts down.
        if http_client is not None:
            self._http_client = http_client
            self._owns_http_client = False
        else:
            self._http_client = httpx.AsyncClient(
                timeout=config.timeout_ms / 1000,
                limits=httpx.Limits(
                    max_connections=10,
                    max_keepalive_connections=5,
                    keepalive_expiry=5,
                ),
            )
            self._owns_http_client = True

        # Event name -> registered callbacks. Only these names are accepted by on().
        self._callbacks: Dict[str, List[Callable]] = {
            "ready": [],
            "flags_updated": [],
            "flags_stale": [],
            "flag_changed": [],
            "error": [],
            "circuit_open": [],
            "circuit_closed": [],
            "circuit_half_open": [],
        }

        # Re-emit circuit breaker state changes as client events. Every forwarder
        # accepts and passes on positional arguments, so the breaker may supply
        # context for any of its events without raising TypeError here.
        self._circuit_breaker.on(
            "circuit_open", lambda *args: self._emit("circuit_open", *args)
        )
        self._circuit_breaker.on(
            "circuit_closed", lambda *args: self._emit("circuit_closed", *args)
        )
        self._circuit_breaker.on(
            "circuit_half_open", lambda *args: self._emit("circuit_half_open", *args)
        )

    # ------------------------------------------------------------------ events

    def on(self, event: str, callback: Callable) -> "RollgateClient":
        """
        Register an event callback.

        Unknown event names are ignored (with a warning) and the callback is
        not registered.

        Args:
            event: Event name
            callback: Callback function

        Returns:
            Self for chaining
        """
        callbacks = self._callbacks.get(event)
        if callbacks is None:
            logger.warning("Ignoring callback for unknown event %r", event)
        else:
            callbacks.append(callback)
        return self

    def off(self, event: str, callback: Callable) -> "RollgateClient":
        """
        Remove an event callback.

        Args:
            event: Event name
            callback: Callback function

        Returns:
            Self for chaining
        """
        callbacks = self._callbacks.get(event)
        if callbacks is not None and callback in callbacks:
            callbacks.remove(callback)
        return self

    def _emit(self, event: str, *args) -> None:
        """Invoke every callback registered for ``event`` with ``args``.

        Iterates over a snapshot of the callback list, so a callback that calls
        on()/off() does not cause other callbacks to be skipped. Exceptions
        raised by callbacks are logged and never propagate.
        """
        for callback in list(self._callbacks.get(event, ())):
            try:
                callback(*args)
            except Exception as e:
                logger.warning("Error in event callback: %s", e)

    # --------------------------------------------------------------- lifecycle

    async def init(self, user: Optional[UserContext] = None) -> None:
        """
        Initialize the client and fetch initial flags.

        Cached flags, if any, are applied first. Then a fresh fetch is made, the
        background refresh (SSE or polling) is started and ``ready`` is emitted.
        Calling init() again does not start duplicate background loops.

        Args:
            user: Optional user context for targeting
        """
        self._user_context = user
        # Any stored ETag may belong to a different user context.
        self._last_etag = None

        # Try to load cached flags first
        self._cache.load()
        cached = self._cache.get()
        if cached:
            self._flags = cached.flags.copy()
            if cached.stale:
                self._emit("flags_stale", self.get_all_flags())

        # Fetch fresh flags
        await self._fetch_flags()
        self._initialized = True

        # Start background refresh
        if self._config.enable_streaming:
            # Restart instead of duplicating: a repeated init() may carry a new user.
            await self._cancel_task(self._sse_task)
            self._sse_task = asyncio.create_task(self._start_streaming())
        elif self._config.refresh_interval_ms > 0:
            if self._poll_task is None or self._poll_task.done():
                self._poll_task = asyncio.create_task(self._start_polling())

        self._emit("ready")

    async def _start_polling(self) -> None:
        """Refetch flags every ``refresh_interval_ms`` until the client closes."""
        interval = self._config.refresh_interval_ms / 1000
        while not self._closing:
            await asyncio.sleep(interval)
            if self._closing:
                break
            try:
                await self._fetch_flags()
            except Exception as e:
                # _fetch_flags handles fetch errors itself; this only keeps the
                # loop alive if something unexpected escapes it.
                logger.warning("Polling error: %s", e)

    async def _start_streaming(self) -> None:
        """Consume the SSE stream for real-time updates until the client closes.

        Reconnects with exponential backoff (1s doubling up to 30s) after
        errors. It also waits before reconnecting after the server closes the
        stream cleanly, so an immediately-closing server cannot cause a tight
        reconnect loop.
        """
        url = f"{self._config.base_url}/api/v1/sdk/stream"
        headers = {
            "Authorization": f"Bearer {self._config.api_key}",
        }
        # Keep the configured connect/write/pool timeouts but disable the read
        # timeout: an SSE stream can legitimately stay silent longer than
        # timeout_ms, and a read timeout would tear it down each time.
        timeout = httpx.Timeout(self._config.timeout_ms / 1000, read=None)

        backoff = 1.0  # Start at 1 second (matches Go SDK)
        max_backoff = 30.0

        while not self._closing:
            # Rebuilt per connection so reconnects use the current user context.
            params = {}
            if self._user_context:
                params["user_id"] = self._user_context.id
            try:
                async with aconnect_sse(
                    self._http_client,
                    "GET",
                    url,
                    params=params,
                    headers=headers,
                    timeout=timeout,
                ) as event_source:
                    # Surface HTTP errors (e.g. 401) as errors with a clear status.
                    event_source.response.raise_for_status()
                    # Reset backoff on successful connection
                    backoff = 1.0
                    async for event in event_source.aiter_sse():
                        if self._closing:
                            break
                        if event.data:
                            self._handle_stream_message(event.data)
            except Exception as e:
                if not self._closing:
                    logger.warning(
                        "SSE connection error, reconnecting in %ss: %s", backoff, e
                    )
                    await asyncio.sleep(backoff)
                    # Exponential backoff with cap
                    backoff = min(backoff * 2, max_backoff)
            else:
                if not self._closing:
                    await asyncio.sleep(backoff)

    def _handle_stream_message(self, raw: str) -> None:
        """Apply one SSE message payload; never raises.

        Only a message whose JSON object contains a ``"flags"`` object updates
        the flags (and the cache). Other messages, such as keep-alives, must
        not wipe the current flags.
        """
        try:
            data = json.loads(raw)
            if not isinstance(data, dict) or not isinstance(data.get("flags"), dict):
                logger.debug("Ignoring SSE message without a flags payload")
                return
            new_flags = data["flags"]
            if "reasons" in data:
                self._flag_reasons = self._parse_reasons(data["reasons"])
            # Keep the fallback cache as fresh as the streamed state.
            self._cache.set("flags", new_flags)
            self._update_flags(new_flags)
        except Exception as e:
            logger.warning("Failed to parse SSE message: %s", e)

    # ---------------------------------------------------------------- fetching

    async def _fetch_flags(self) -> None:
        """Fetch all flags for the current user and apply them.

        Fetch failures are not raised. They are classified, emitted as an
        ``error`` event and answered with the cached-flag fallback.
        """
        url = f"{self._config.base_url}/api/v1/sdk/flags"
        params = {"withReasons": "true"}
        if self._user_context:
            params["user_id"] = self._user_context.id

        # Check if circuit breaker allows the request
        if not self._circuit_breaker.is_allowing_requests():
            logger.warning("Circuit breaker is open, using cached flags")
            self._use_cached_fallback()
            return

        try:
            # Execute through circuit breaker with retry
            result = await self._circuit_breaker.execute(
                lambda: self._do_fetch_flags(url, params)
            )

            if result is None:
                # 304 Not Modified: current flags are still valid.
                return

            # Update cache and flags
            self._cache.set("flags", result)
            self._update_flags(result)

        except CircuitOpenError:
            logger.warning("Circuit breaker is open")
            self._use_cached_fallback()
        except Exception as e:
            classified = classify_error(e)
            logger.error("Error fetching flags: %s", classified.message)
            self._emit("error", classified)
            self._use_cached_fallback()

    async def _do_fetch_flags(self, url: str, params: Dict) -> Optional[Dict[str, bool]]:
        """Execute the actual fetch with retry.

        Returns:
            The flag mapping, or None when the server answered 304.

        Raises:
            The final error reported by the retry helper, or a generic
            Exception when it reports failure without one.
        """
        result = await fetch_with_retry(
            lambda: self._single_fetch(url, params),
            self._config.retry,
        )

        if not result.success:
            raise result.error or Exception("Fetch failed")

        return result.data

    async def _single_fetch(self, url: str, params: Dict) -> Optional[Dict[str, bool]]:
        """Perform one GET against the flags endpoint.

        Sends ``If-None-Match`` when an ETag is known. On success, stores the
        evaluation reasons and the new ETag.

        Returns:
            The flag mapping, or None on 304 Not Modified.

        Raises:
            AuthenticationError, RateLimitError, RollgateError: see
            :meth:`_raise_for_error_status`.
        """
        headers = {
            "Authorization": f"Bearer {self._config.api_key}",
            "Content-Type": "application/json",
        }
        if self._last_etag:
            headers["If-None-Match"] = self._last_etag

        response = await self._http_client.get(url, params=params, headers=headers)

        # Handle 304 Not Modified
        if response.status_code == 304:
            return None

        self._raise_for_error_status(response)

        # Parse the body BEFORE remembering the ETag. Otherwise an unparseable
        # body would leave an ETag for data that was never applied, and every
        # later request would get 304 and never recover.
        data = response.json()
        flags = data.get("flags") or {}
        # Replace reasons wholesale so reasons from an earlier response (or an
        # earlier user) never linger when the server omits them.
        self._flag_reasons = self._parse_reasons(data.get("reasons") or {})

        # Store ETag for conditional requests
        etag = response.headers.get("ETag")
        if etag:
            self._last_etag = etag

        return flags

    def _raise_for_error_status(self, response: httpx.Response) -> None:
        """Raise the SDK error matching a non-success HTTP response.

        401/403 raise AuthenticationError, 429 raises RateLimitError, 5xx raise
        a retryable RollgateError, and any other non-2xx status raises a
        RollgateError.
        """
        status = response.status_code
        if status == 401 or status == 403:
            raise AuthenticationError(
                f"Authentication failed: {status}",
                status,
            )

        if status == 429:
            raise RateLimitError(
                "Rate limit exceeded",
                retry_after=self._parse_retry_after(response.headers.get("Retry-After")),
            )

        if status >= 500:
            raise RollgateError(
                f"Server error: {status}",
                status_code=status,
                retryable=True,
            )

        if not response.is_success:
            raise RollgateError(
                f"Request failed: {status}",
                status_code=status,
            )

    @staticmethod
    def _parse_retry_after(value: Optional[str]) -> Optional[int]:
        """Convert a ``Retry-After`` header value to whole seconds.

        Accepts both forms allowed by HTTP: delay-seconds and an HTTP-date.
        Returns None when the header is missing or unparseable, so a malformed
        value can no longer mask the rate-limit error with a ValueError.
        """
        if not value:
            return None
        value = value.strip()
        try:
            return max(0, int(value))
        except ValueError:
            pass
        try:
            when = parsedate_to_datetime(value)
        except (TypeError, ValueError, IndexError):
            return None
        return max(0, int(when.timestamp() - time.time()))

    @staticmethod
    def _parse_reasons(raw: Any) -> Dict[str, EvaluationReason]:
        """Convert the server's ``reasons`` object into EvaluationReason values.

        An entry this SDK cannot interpret (for example a reason kind added in
        a newer server) becomes ``unknown_reason()`` instead of failing the
        whole fetch.
        """
        if not isinstance(raw, dict):
            return {}
        reasons: Dict[str, EvaluationReason] = {}
        for key, value in raw.items():
            try:
                reasons[key] = EvaluationReason(
                    kind=EvaluationReasonKind(value.get("kind", "UNKNOWN")),
                    rule_id=value.get("ruleId"),
                    rule_index=value.get("ruleIndex"),
                    in_rollout=value.get("inRollout"),
                    error_kind=(
                        EvaluationErrorKind(value["errorKind"])
                        if value.get("errorKind")
                        else None
                    ),
                )
            except (ValueError, AttributeError):
                reasons[key] = unknown_reason()
        return reasons

    def _update_flags(self, new_flags: Dict[str, bool]) -> None:
        """Replace the current flags and emit change events.

        Emits ``flag_changed(key, new_value, old_value)`` for every added or
        modified flag. Flags that disappeared are reported with ``new_value``
        None. Finally emits ``flags_updated`` with a copy of all flags.
        """
        old_flags = self._flags
        self._flags = dict(new_flags)

        for key, value in self._flags.items():
            old_value = old_flags.get(key)
            if old_value != value:
                self._emit("flag_changed", key, value, old_value)

        for key in [k for k in old_flags if k not in self._flags]:
            self._emit("flag_changed", key, None, old_flags[key])

        self._emit("flags_updated", self.get_all_flags())

    def _use_cached_fallback(self) -> None:
        """Apply cached flags, if any, and emit ``flags_stale`` when they are stale."""
        cached = self._cache.get()
        if cached:
            self._update_flags(cached.flags)
            if cached.stale:
                self._emit("flags_stale", self.get_all_flags())

    # -------------------------------------------------------------- evaluation

    def is_enabled(self, flag_key: str, default_value: bool = False) -> bool:
        """
        Check if a flag is enabled.

        Args:
            flag_key: The flag key to check
            default_value: Default value if flag not found

        Returns:
            True if the flag is enabled
        """
        return self.is_enabled_detail(flag_key, default_value).value

    def is_enabled_detail(
        self, flag_key: str, default_value: bool = False
    ) -> EvaluationDetail[bool]:
        """
        Check if a flag is enabled with evaluation reason.

        Args:
            flag_key: The flag key to check
            default_value: Default value if flag not found

        Returns:
            EvaluationDetail containing the value and reason. Before init() has
            completed, the value is ``default_value`` with a CLIENT_NOT_READY
            error reason. For an unknown flag, it is ``default_value`` with an
            unknown reason.
        """
        if not self._initialized:
            logger.warning("Client not initialized. Call init() first.")
            return EvaluationDetail(
                value=default_value,
                reason=error_reason(EvaluationErrorKind.CLIENT_NOT_READY),
            )

        if flag_key not in self._flags:
            return EvaluationDetail(
                value=default_value,
                reason=unknown_reason(),
            )

        value = self._flags[flag_key]
        # Use stored reason from server, or FALLTHROUGH as default
        stored_reason = self._flag_reasons.get(flag_key)
        return EvaluationDetail(
            value=value,
            reason=stored_reason if stored_reason else fallthrough_reason(in_rollout=value),
        )

    def bool_variation_detail(
        self, flag_key: str, default_value: bool = False
    ) -> EvaluationDetail[bool]:
        """
        Alias for is_enabled_detail for LaunchDarkly compatibility.

        Args:
            flag_key: The flag key to check
            default_value: Default value if flag not found

        Returns:
            EvaluationDetail containing the value and reason
        """
        return self.is_enabled_detail(flag_key, default_value)

    def get_all_flags(self) -> Dict[str, bool]:
        """
        Get all flags as a dictionary.

        Returns:
            A copy of the flag keys to boolean values mapping
        """
        return self._flags.copy()

    # ---------------------------------------------------------- user context

    async def identify(self, user: UserContext) -> None:
        """
        Update user context and re-fetch flags.

        Also restarts an active SSE stream, so that streamed updates belong to
        the new user.

        Args:
            user: New user context
        """
        self._user_context = user
        # The stored ETag belongs to the previous user's response; sending it
        # could produce a 304 that keeps the previous user's flags.
        self._last_etag = None
        await self._fetch_flags()
        await self._restart_streaming()

    async def reset(self) -> None:
        """Clear user context, re-fetch flags and restart an active SSE stream."""
        self._user_context = None
        self._last_etag = None
        await self._fetch_flags()
        await self._restart_streaming()

    async def refresh(self) -> None:
        """Force refresh flags."""
        await self._fetch_flags()

    async def _restart_streaming(self) -> None:
        """Reconnect the SSE stream (if one is running) with the current user."""
        if self._sse_task is None or self._closing:
            return
        await self._cancel_task(self._sse_task)
        self._sse_task = asyncio.create_task(self._start_streaming())

    # ------------------------------------------------------ circuit and cache

    @property
    def circuit_state(self) -> CircuitState:
        """Get current circuit breaker state."""
        return self._circuit_breaker.state

    def get_circuit_stats(self):
        """Get circuit breaker statistics."""
        return self._circuit_breaker.get_stats()

    def reset_circuit(self) -> None:
        """Force reset the circuit breaker."""
        self._circuit_breaker.force_reset()

    def get_cache_stats(self):
        """Get cache statistics."""
        return self._cache.get_stats()

    def get_cache_hit_rate(self) -> float:
        """Get cache hit rate."""
        return self._cache.get_hit_rate()

    def clear_cache(self) -> None:
        """Clear the cache."""
        self._cache.clear()

    # ---------------------------------------------------------------- teardown

    @staticmethod
    async def _cancel_task(task: Optional[asyncio.Task]) -> None:
        """Cancel ``task`` (if any) and wait for it to finish."""
        if task is None:
            return
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass

    async def close(self) -> None:
        """Close the client and cleanup resources.

        Stops background tasks, closes the HTTP client only if this client
        created it, closes the cache and drops all registered callbacks.
        """
        self._closing = True

        # Cancel background tasks
        await self._cancel_task(self._poll_task)
        self._poll_task = None
        await self._cancel_task(self._sse_task)
        self._sse_task = None

        # Close HTTP client only if we own it
        if self._owns_http_client and self._http_client is not None:
            await self._http_client.aclose()

        # Close cache (also clears its callbacks)
        self._cache.close()

        # Clear all callbacks to prevent memory leaks
        self._circuit_breaker.clear_callbacks()
        for callbacks in self._callbacks.values():
            callbacks.clear()

    async def __aenter__(self) -> "RollgateClient":
        """Async context manager entry: runs :meth:`init` without a user."""
        await self.init()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
        """Async context manager exit: runs :meth:`close`."""
        await self.close()