rollgate-1.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rollgate/retry.py ADDED
@@ -0,0 +1,177 @@
+ """
+ Retry utility with exponential backoff and jitter.
+ """
+
+ import asyncio
+ import random
+ from dataclasses import dataclass, field
+ from typing import Callable, TypeVar, Optional, Awaitable, Generic
+
+ from rollgate.errors import RollgateError
+
+ T = TypeVar("T")
+
+
+ @dataclass
+ class RetryConfig:
+     """Configuration for retry behavior."""
+
+     max_retries: int = 3
+     """Maximum number of retry attempts."""
+
+     base_delay_ms: int = 100
+     """Base delay in milliseconds."""
+
+     max_delay_ms: int = 10000
+     """Maximum delay in milliseconds."""
+
+     jitter_factor: float = 0.1
+     """Jitter factor 0-1 to randomize delays."""
+
+
+ @dataclass
+ class RetryResult(Generic[T]):
+     """Result of a retry operation."""
+
+     success: bool
+     data: Optional[T] = None
+     error: Optional[Exception] = None
+     attempts: int = 1
+
+
+ DEFAULT_RETRY_CONFIG = RetryConfig()
+
+
+ def calculate_backoff(attempt: int, config: RetryConfig) -> float:
+     """
+     Calculate backoff delay with exponential increase and jitter.
+
+     Args:
+         attempt: Current attempt number (0-indexed)
+         config: Retry configuration
+
+     Returns:
+         Delay in seconds
+     """
+     # Exponential: base_delay * 2^attempt
+     exponential_delay = config.base_delay_ms * (2**attempt)
+
+     # Cap at max_delay
+     capped_delay = min(exponential_delay, config.max_delay_ms)
+
+     # Add jitter: random value between -jitter and +jitter
+     jitter = capped_delay * config.jitter_factor * (random.random() * 2 - 1)
+
+     delay_ms = max(0, capped_delay + jitter)
+     return delay_ms / 1000.0  # Convert to seconds
+
+
+ def is_retryable_error(error: Exception) -> bool:
+     """
+     Check if an error is retryable.
+
+     Args:
+         error: The exception to check
+
+     Returns:
+         True if the error should be retried
+     """
+     if isinstance(error, RollgateError):
+         return error.retryable
+
+     message = str(error).lower()
+
+     # Network errors (always retry)
+     network_indicators = [
+         "econnrefused",
+         "etimedout",
+         "enotfound",
+         "econnreset",
+         "network",
+         "connection",
+         "timeout",
+         "dns",
+     ]
+     if any(indicator in message for indicator in network_indicators):
+         return True
+
+     # HTTP 5xx errors (server issues, retry)
+     server_errors = ["500", "502", "503", "504"]
+     if any(code in message for code in server_errors):
+         return True
+
+     # Rate limiting (retry with backoff)
+     if "429" in message or "too many requests" in message:
+         return True
+
+     # HTTP 4xx errors (client errors, don't retry)
+     client_errors = ["400", "401", "403", "404"]
+     if any(code in message for code in client_errors):
+         return False
+
+     return False
+
+
+ async def fetch_with_retry(
+     fn: Callable[[], Awaitable[T]],
+     config: Optional[RetryConfig] = None,
+ ) -> RetryResult[T]:
+     """
+     Execute an async function with retry logic and exponential backoff.
+
+     Args:
+         fn: Async function to execute
+         config: Retry configuration
+
+     Returns:
+         RetryResult with success status and data/error
+     """
+     cfg = config or DEFAULT_RETRY_CONFIG
+     last_error: Optional[Exception] = None
+
+     for attempt in range(cfg.max_retries + 1):
+         try:
+             data = await fn()
+             return RetryResult(success=True, data=data, attempts=attempt + 1)
+         except Exception as error:
+             last_error = error
+
+             # Don't retry non-retryable errors
+             if not is_retryable_error(error):
+                 return RetryResult(success=False, error=error, attempts=attempt + 1)
+
+             # Don't sleep after the last attempt
+             if attempt < cfg.max_retries:
+                 delay = calculate_backoff(attempt, cfg)
+                 await asyncio.sleep(delay)
+
+     return RetryResult(
+         success=False,
+         error=last_error or Exception("Retry exhausted"),
+         attempts=cfg.max_retries + 1,
+     )
+
+
+ async def retry_async(
+     fn: Callable[[], Awaitable[T]],
+     config: Optional[RetryConfig] = None,
+ ) -> T:
+     """
+     Execute an async function with retry, raising on failure.
+
+     Args:
+         fn: Async function to execute
+         config: Retry configuration
+
+     Returns:
+         Result of the function
+
+     Raises:
+         Exception: The last error if all retries fail
+     """
+     result = await fetch_with_retry(fn, config)
+
+     if not result.success:
+         raise result.error  # type: ignore
+
+     return result.data  # type: ignore
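
Taken together, `fetch_with_retry` and `retry_async` wrap any awaitable: the first reports the outcome as a `RetryResult`, the second returns the value directly and re-raises the last error once retries are exhausted. A minimal usage sketch (the `flaky_call` coroutine is a hypothetical placeholder, not part of the package):

```python
import asyncio

from rollgate.retry import RetryConfig, fetch_with_retry, retry_async


async def flaky_call() -> str:
    # Hypothetical stand-in for the network call you want to guard.
    return "ok"


async def main() -> None:
    config = RetryConfig(max_retries=5, base_delay_ms=200, max_delay_ms=5000)

    # fetch_with_retry reports the outcome instead of raising.
    result = await fetch_with_retry(flaky_call, config)
    print(result.success, result.attempts, result.data)

    # retry_async returns the value and re-raises on failure.
    print(await retry_async(flaky_call, config))


asyncio.run(main())
```
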
rollgate/tracing.py ADDED
@@ -0,0 +1,434 @@
+ """
+ W3C Trace Context support for distributed tracing.
+ Implements traceparent header format for request correlation.
+ """
+
+ import re
+ import time
+ import secrets
+ from dataclasses import dataclass, field
+ from typing import Dict, Optional, List
+ from contextlib import contextmanager
+
+
+ # W3C Trace Context header names
+ HEADER_TRACEPARENT = "traceparent"
+ HEADER_TRACESTATE = "tracestate"
+ HEADER_TRACE_ID = "x-trace-id"
+ HEADER_SPAN_ID = "x-span-id"
+ HEADER_REQUEST_ID = "x-request-id"
+
+ # W3C traceparent format: version-trace_id-parent_id-flags
+ # Example: 00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01
+ TRACEPARENT_REGEX = re.compile(
+     r"^([0-9a-f]{2})-([0-9a-f]{32})-([0-9a-f]{16})-([0-9a-f]{2})$"
+ )
+
+
+ def generate_trace_id() -> str:
+     """Generate a 32-character hex trace ID."""
+     return secrets.token_hex(16)
+
+
+ def generate_span_id() -> str:
+     """Generate a 16-character hex span ID."""
+     return secrets.token_hex(8)
+
+
+ def generate_request_id() -> str:
+     """Generate a unique request ID."""
+     return f"req_{secrets.token_hex(12)}"
+
+
+ @dataclass
+ class TraceContext:
+     """
+     Represents W3C Trace Context for distributed tracing.
+
+     The traceparent header format is:
+         {version}-{trace-id}-{parent-id}-{flags}
+
+     Example:
+         00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01
+
+     Where:
+     - version: 2 hex chars (always "00" for current version)
+     - trace-id: 32 hex chars
+     - parent-id: 16 hex chars
+     - flags: 2 hex chars (01 = sampled)
+     """
+
+     trace_id: str = field(default_factory=generate_trace_id)
+     """32-character hex trace ID."""
+
+     span_id: str = field(default_factory=generate_span_id)
+     """16-character hex span ID."""
+
+     parent_id: Optional[str] = None
+     """Parent span ID for nested spans."""
+
+     request_id: str = field(default_factory=generate_request_id)
+     """Human-readable request ID."""
+
+     sampled: bool = True
+     """Whether this trace should be sampled."""
+
+     def get_headers(self) -> Dict[str, str]:
+         """
+         Get headers to propagate trace context.
+
+         Returns:
+             Dictionary of headers to add to outgoing requests
+         """
+         flags = "01" if self.sampled else "00"
+         traceparent = f"00-{self.trace_id}-{self.span_id}-{flags}"
+
+         return {
+             HEADER_TRACEPARENT: traceparent,
+             HEADER_TRACE_ID: self.trace_id,
+             HEADER_SPAN_ID: self.span_id,
+             HEADER_REQUEST_ID: self.request_id,
+         }
+
+     def create_child(self) -> "TraceContext":
+         """
+         Create a child span context.
+
+         Returns:
+             New TraceContext with same trace_id but new span_id
+         """
+         return TraceContext(
+             trace_id=self.trace_id,
+             span_id=generate_span_id(),
+             parent_id=self.span_id,
+             request_id=self.request_id,
+             sampled=self.sampled,
+         )
+
+     @classmethod
+     def from_traceparent(cls, traceparent: str) -> Optional["TraceContext"]:
+         """
+         Parse a traceparent header.
+
+         Args:
+             traceparent: W3C traceparent header value
+
+         Returns:
+             TraceContext if valid, None if invalid
+         """
+         match = TRACEPARENT_REGEX.match(traceparent.lower())
+         if not match:
+             return None
+
+         version, trace_id, parent_id, flags = match.groups()
+
+         # Only support version 00
+         if version != "00":
+             return None
+
+         return cls(
+             trace_id=trace_id,
+             span_id=generate_span_id(),
+             parent_id=parent_id,
+             sampled=flags == "01",
+         )
+
+     @classmethod
+     def from_headers(cls, headers: Dict[str, str]) -> Optional["TraceContext"]:
+         """
+         Extract trace context from request headers.
+
+         Args:
+             headers: Request headers (case-insensitive)
+
+         Returns:
+             TraceContext if found, None otherwise
+         """
+         # Normalize header names to lowercase
+         normalized = {k.lower(): v for k, v in headers.items()}
+
+         # Try W3C traceparent first
+         traceparent = normalized.get(HEADER_TRACEPARENT)
+         if traceparent:
+             ctx = cls.from_traceparent(traceparent)
+             if ctx:
+                 # Preserve request ID if present
+                 request_id = normalized.get(HEADER_REQUEST_ID)
+                 if request_id:
+                     ctx.request_id = request_id
+                 return ctx
+
+         # Fall back to custom headers
+         trace_id = normalized.get(HEADER_TRACE_ID)
+         span_id = normalized.get(HEADER_SPAN_ID)
+         request_id = normalized.get(HEADER_REQUEST_ID)
+
+         if trace_id:
+             return cls(
+                 trace_id=trace_id,
+                 span_id=span_id or generate_span_id(),
+                 request_id=request_id or generate_request_id(),
+             )
+
+         return None
+
+
+ @dataclass
+ class RequestTrace:
+     """
+     Tracks timing for a single request.
+
+     Example:
+         ```python
+         trace = RequestTrace(
+             context=TraceContext(),
+             endpoint="/api/v1/flags",
+         )
+         trace.start()
+         # ... make request ...
+         trace.finish(200)
+         print(f"Latency: {trace.latency_ms}ms")
+         ```
+     """
+
+     context: TraceContext
+     """Trace context for this request."""
+
+     endpoint: str
+     """API endpoint being called."""
+
+     start_time: float = 0
+     """Start timestamp (seconds since epoch)."""
+
+     end_time: float = 0
+     """End timestamp (seconds since epoch)."""
+
+     status_code: int = 0
+     """HTTP status code."""
+
+     error: Optional[str] = None
+     """Error message if request failed."""
+
+     def start(self) -> None:
+         """Mark request start time."""
+         self.start_time = time.time()
+
+     def finish(self, status_code: int, error: Optional[str] = None) -> None:
+         """
+         Mark request end time.
+
+         Args:
+             status_code: HTTP status code
+             error: Error message if failed
+         """
+         self.end_time = time.time()
+         self.status_code = status_code
+         self.error = error
+
+     @property
+     def latency_ms(self) -> float:
+         """Get request latency in milliseconds."""
+         if self.end_time == 0 or self.start_time == 0:
+             return 0
+         return (self.end_time - self.start_time) * 1000
+
+     @property
+     def success(self) -> bool:
+         """Check if request was successful."""
+         return 200 <= self.status_code < 400 and self.error is None
+
+
+ class TracingManager:
+     """
+     Manages trace contexts for the SDK.
+
+     Thread-safe management of trace context propagation.
+
+     Example:
+         ```python
+         tracer = TracingManager()
+
+         # Create a new trace
+         ctx = tracer.create_context()
+
+         # Track a request
+         with tracer.trace_request("/api/v1/flags") as trace:
+             response = await client.get(url, headers=trace.context.get_headers())
+             trace.finish(response.status_code)
+         ```
+     """
+
+     def __init__(self, enabled: bool = True, sample_rate: float = 1.0):
+         """
+         Initialize tracing manager.
+
+         Args:
+             enabled: Whether tracing is enabled
+             sample_rate: Fraction of requests to sample (0.0 to 1.0)
+         """
+         self._enabled = enabled
+         self._sample_rate = sample_rate
+         self._traces: List[RequestTrace] = []
+         self._max_traces = 1000
+
+     @property
+     def enabled(self) -> bool:
+         """Check if tracing is enabled."""
+         return self._enabled
+
+     @enabled.setter
+     def enabled(self, value: bool) -> None:
+         """Enable or disable tracing."""
+         self._enabled = value
+
+     def create_context(
+         self,
+         parent: Optional[TraceContext] = None,
+     ) -> TraceContext:
+         """
+         Create a new trace context.
+
+         Args:
+             parent: Optional parent context for nested spans
+
+         Returns:
+             New TraceContext
+         """
+         if parent:
+             return parent.create_child()
+
+         # Determine if this trace should be sampled
+         sampled = self._enabled and (secrets.randbelow(100) / 100 < self._sample_rate)
+
+         return TraceContext(sampled=sampled)
+
+     def extract_context(self, headers: Dict[str, str]) -> Optional[TraceContext]:
+         """
+         Extract trace context from headers.
+
+         Args:
+             headers: Request headers
+
+         Returns:
+             TraceContext if found
+         """
+         if not self._enabled:
+             return None
+         return TraceContext.from_headers(headers)
+
+     def inject_headers(
+         self,
+         headers: Dict[str, str],
+         context: Optional[TraceContext] = None,
+     ) -> Dict[str, str]:
+         """
+         Inject trace context into headers.
+
+         Args:
+             headers: Existing headers
+             context: Trace context (creates new if None)
+
+         Returns:
+             Headers with trace context added
+         """
+         if not self._enabled:
+             return headers
+
+         ctx = context or self.create_context()
+         result = dict(headers)
+         result.update(ctx.get_headers())
+         return result
+
+     @contextmanager
+     def trace_request(
+         self,
+         endpoint: str,
+         parent: Optional[TraceContext] = None,
+     ):
+         """
+         Context manager for tracing a request.
+
+         Args:
+             endpoint: API endpoint
+             parent: Optional parent context
+
+         Yields:
+             RequestTrace to record timing
+         """
+         ctx = self.create_context(parent)
+         trace = RequestTrace(context=ctx, endpoint=endpoint)
+         trace.start()
+
+         try:
+             yield trace
+         finally:
+             # Store trace if enabled
+             if self._enabled and trace.end_time > 0:
+                 self._traces.append(trace)
+                 if len(self._traces) > self._max_traces:
+                     self._traces.pop(0)
+
+     def get_recent_traces(self, limit: int = 100) -> List[RequestTrace]:
+         """
+         Get recent request traces.
+
+         Args:
+             limit: Maximum number of traces to return
+
+         Returns:
+             List of recent RequestTrace objects
+         """
+         return self._traces[-limit:]
+
+     def clear_traces(self) -> None:
+         """Clear all stored traces."""
+         self._traces = []
+
+     def get_stats(self) -> Dict[str, float]:
+         """
+         Get tracing statistics.
+
+         Returns:
+             Dictionary with trace_count, avg_latency_ms, error_rate
+         """
+         traces = self._traces
+         if not traces:
+             return {
+                 "trace_count": 0,
+                 "avg_latency_ms": 0,
+                 "error_rate": 0,
+             }
+
+         completed = [t for t in traces if t.end_time > 0]
+         if not completed:
+             return {
+                 "trace_count": len(traces),
+                 "avg_latency_ms": 0,
+                 "error_rate": 0,
+             }
+
+         total_latency = sum(t.latency_ms for t in completed)
+         errors = sum(1 for t in completed if not t.success)
+
+         return {
+             "trace_count": len(traces),
+             "avg_latency_ms": total_latency / len(completed),
+             "error_rate": errors / len(completed),
+         }
+
+
+ # Global tracer instance
+ _global_tracer: Optional[TracingManager] = None
+
+
+ def get_tracer() -> TracingManager:
+     """Get or create the global tracer instance."""
+     global _global_tracer
+     if _global_tracer is None:
+         _global_tracer = TracingManager()
+     return _global_tracer
+
+
+ def create_tracer(enabled: bool = True, sample_rate: float = 1.0) -> TracingManager:
+     """Create a new tracer instance."""
+     return TracingManager(enabled=enabled, sample_rate=sample_rate)
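
For completeness, a short end-to-end sketch of the tracing API above: a context is injected into outgoing headers, recovered on the receiving side, and a request is timed with `trace_request`. The `time.sleep` stands in for a real HTTP call, since the SDK's transport layer is not part of this diff:

```python
import time

from rollgate.tracing import TraceContext, get_tracer

tracer = get_tracer()

# inject_headers adds the W3C traceparent plus the custom
# x-trace-id / x-span-id / x-request-id headers.
headers = tracer.inject_headers({"accept": "application/json"})
print(headers["traceparent"])  # e.g. 00-<32 hex trace id>-<16 hex span id>-01

# On the receiving side, recover the context from the headers.
incoming = TraceContext.from_headers(headers)
assert incoming is not None and incoming.trace_id == headers["x-trace-id"]

# Time a "request" as a child span of the incoming context.
with tracer.trace_request("/api/v1/flags", parent=incoming) as trace:
    time.sleep(0.01)  # placeholder for the actual call
    trace.finish(200)

print(tracer.get_stats())  # keys: trace_count, avg_latency_ms, error_rate
```
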