amazon-ads-mcp 0.2.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (82)
  1. amazon_ads_mcp/__init__.py +11 -0
  2. amazon_ads_mcp/auth/__init__.py +33 -0
  3. amazon_ads_mcp/auth/base.py +211 -0
  4. amazon_ads_mcp/auth/hooks.py +172 -0
  5. amazon_ads_mcp/auth/manager.py +791 -0
  6. amazon_ads_mcp/auth/oauth_state_store.py +277 -0
  7. amazon_ads_mcp/auth/providers/__init__.py +14 -0
  8. amazon_ads_mcp/auth/providers/direct.py +393 -0
  9. amazon_ads_mcp/auth/providers/example_auth0.py.example +216 -0
  10. amazon_ads_mcp/auth/providers/openbridge.py +512 -0
  11. amazon_ads_mcp/auth/registry.py +146 -0
  12. amazon_ads_mcp/auth/secure_token_store.py +297 -0
  13. amazon_ads_mcp/auth/token_store.py +723 -0
  14. amazon_ads_mcp/config/__init__.py +5 -0
  15. amazon_ads_mcp/config/sampling.py +111 -0
  16. amazon_ads_mcp/config/settings.py +366 -0
  17. amazon_ads_mcp/exceptions.py +314 -0
  18. amazon_ads_mcp/middleware/__init__.py +11 -0
  19. amazon_ads_mcp/middleware/authentication.py +1474 -0
  20. amazon_ads_mcp/middleware/caching.py +177 -0
  21. amazon_ads_mcp/middleware/oauth.py +175 -0
  22. amazon_ads_mcp/middleware/sampling.py +112 -0
  23. amazon_ads_mcp/models/__init__.py +320 -0
  24. amazon_ads_mcp/models/amc_models.py +837 -0
  25. amazon_ads_mcp/models/api_responses.py +847 -0
  26. amazon_ads_mcp/models/base_models.py +215 -0
  27. amazon_ads_mcp/models/builtin_responses.py +496 -0
  28. amazon_ads_mcp/models/dsp_models.py +556 -0
  29. amazon_ads_mcp/models/stores_brands.py +610 -0
  30. amazon_ads_mcp/server/__init__.py +6 -0
  31. amazon_ads_mcp/server/__main__.py +6 -0
  32. amazon_ads_mcp/server/builtin_prompts.py +269 -0
  33. amazon_ads_mcp/server/builtin_tools.py +962 -0
  34. amazon_ads_mcp/server/file_routes.py +547 -0
  35. amazon_ads_mcp/server/html_templates.py +149 -0
  36. amazon_ads_mcp/server/mcp_server.py +327 -0
  37. amazon_ads_mcp/server/openapi_utils.py +158 -0
  38. amazon_ads_mcp/server/sampling_handler.py +251 -0
  39. amazon_ads_mcp/server/server_builder.py +751 -0
  40. amazon_ads_mcp/server/sidecar_loader.py +178 -0
  41. amazon_ads_mcp/server/transform_executor.py +827 -0
  42. amazon_ads_mcp/tools/__init__.py +22 -0
  43. amazon_ads_mcp/tools/cache_management.py +105 -0
  44. amazon_ads_mcp/tools/download_tools.py +267 -0
  45. amazon_ads_mcp/tools/identity.py +236 -0
  46. amazon_ads_mcp/tools/oauth.py +598 -0
  47. amazon_ads_mcp/tools/profile.py +150 -0
  48. amazon_ads_mcp/tools/profile_listing.py +285 -0
  49. amazon_ads_mcp/tools/region.py +320 -0
  50. amazon_ads_mcp/tools/region_identity.py +175 -0
  51. amazon_ads_mcp/utils/__init__.py +6 -0
  52. amazon_ads_mcp/utils/async_compat.py +215 -0
  53. amazon_ads_mcp/utils/errors.py +452 -0
  54. amazon_ads_mcp/utils/export_content_type_resolver.py +249 -0
  55. amazon_ads_mcp/utils/export_download_handler.py +579 -0
  56. amazon_ads_mcp/utils/header_resolver.py +81 -0
  57. amazon_ads_mcp/utils/http/__init__.py +56 -0
  58. amazon_ads_mcp/utils/http/circuit_breaker.py +127 -0
  59. amazon_ads_mcp/utils/http/client_manager.py +329 -0
  60. amazon_ads_mcp/utils/http/request.py +207 -0
  61. amazon_ads_mcp/utils/http/resilience.py +512 -0
  62. amazon_ads_mcp/utils/http/resilient_client.py +195 -0
  63. amazon_ads_mcp/utils/http/retry.py +76 -0
  64. amazon_ads_mcp/utils/http_client.py +873 -0
  65. amazon_ads_mcp/utils/media/__init__.py +21 -0
  66. amazon_ads_mcp/utils/media/negotiator.py +243 -0
  67. amazon_ads_mcp/utils/media/types.py +199 -0
  68. amazon_ads_mcp/utils/openapi/__init__.py +16 -0
  69. amazon_ads_mcp/utils/openapi/json.py +55 -0
  70. amazon_ads_mcp/utils/openapi/loader.py +263 -0
  71. amazon_ads_mcp/utils/openapi/refs.py +46 -0
  72. amazon_ads_mcp/utils/region_config.py +200 -0
  73. amazon_ads_mcp/utils/response_wrapper.py +171 -0
  74. amazon_ads_mcp/utils/sampling_helpers.py +156 -0
  75. amazon_ads_mcp/utils/sampling_wrapper.py +173 -0
  76. amazon_ads_mcp/utils/security.py +630 -0
  77. amazon_ads_mcp/utils/tool_naming.py +137 -0
  78. amazon_ads_mcp-0.2.7.dist-info/METADATA +664 -0
  79. amazon_ads_mcp-0.2.7.dist-info/RECORD +82 -0
  80. amazon_ads_mcp-0.2.7.dist-info/WHEEL +4 -0
  81. amazon_ads_mcp-0.2.7.dist-info/entry_points.txt +3 -0
  82. amazon_ads_mcp-0.2.7.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,512 @@
1
+ """Enhanced resilience patterns for Amazon Ads API HTTP operations.
2
+
3
+ This module provides production-ready implementations of resilience patterns
4
+ following Amazon's guidance for API interactions:
5
+ - Exponential backoff with full jitter
6
+ - Retry-After header support
7
+ - Token bucket rate limiting
8
+ - Circuit breaker integration
9
+ - Per-endpoint/region awareness
10
+ - Comprehensive metrics
11
+
12
+ All implementations follow Amazon's recommended defaults and best practices.
13
+ """
14
+
15
+ import asyncio
16
+ import logging
17
+ import random
18
+ import time
19
+ from collections import defaultdict
20
+ from dataclasses import dataclass, field
21
+ from datetime import datetime
22
+ from enum import Enum
23
+ from functools import wraps
24
+ from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar
25
+ from urllib.parse import urlparse
26
+
27
+ import httpx
28
+
29
+ from ..region_config import RegionConfig
30
+
31
# Module-level logger for resilience events (retries, throttles, breaker state changes).
logger = logging.getLogger(__name__)

# Return type of the callables wrapped by ResilientRetry.
T = TypeVar("T")
34
+
35
+
36
class MetricsCollector:
    """Collects metrics for monitoring and alerting.

    Values are grouped by category (``"counters"``, ``"gauges"`` and
    ``"histograms"``) and keyed by a dotted metric name that embeds the
    endpoint (plus the region for throttles). Despite its name, the
    ``"histograms"`` category keeps only the most recent observation per key,
    not a distribution.
    """

    def __init__(self):
        # category -> metric name -> value; unseen counters start at 0.
        self._metrics: Dict[str, Dict[str, Any]] = defaultdict(lambda: defaultdict(int))
        # NOTE(review): not read anywhere in this module.
        self._start_time = time.time()

    def record_throttle(self, endpoint: str, region: str) -> None:
        """Count a 429 throttle response for *endpoint* in *region*."""
        key = f"ads_api_throttles_total.{endpoint}.{region}"
        self._metrics["counters"][key] += 1
        logger.debug(f"Throttle recorded: {endpoint} in {region}")

    def record_retry(self, endpoint: str, attempt: int, delay: float) -> None:
        """Count a retry attempt and store its delay (last value only)."""
        self._metrics["counters"][f"retry_attempts_total.{endpoint}"] += 1
        self._metrics["histograms"][f"retry_delay_seconds.{endpoint}"] = delay
        logger.debug(f"Retry {attempt} for {endpoint} with {delay:.2f}s delay")

    def record_retry_after(self, endpoint: str, delay: float) -> None:
        """Store the most recently observed Retry-After delay in seconds."""
        self._metrics["gauges"][f"retry_after_seconds.{endpoint}"] = delay
        logger.info(f"Retry-After header: {delay}s for {endpoint}")

    def record_circuit_state(self, endpoint: str, state: str) -> None:
        """Store the latest circuit breaker state name for *endpoint*."""
        self._metrics["gauges"][f"circuit_breaker_state.{endpoint}"] = state
        logger.info(f"Circuit breaker {state} for {endpoint}")

    def record_queue_wait(self, endpoint: str, wait_time: float) -> None:
        """Store the latest token-bucket wait; warn when it exceeds 5 seconds."""
        self._metrics["histograms"][f"queue_wait_seconds.{endpoint}"] = wait_time
        if wait_time > 5.0:
            logger.warning(f"Long queue wait: {wait_time:.2f}s for {endpoint}")

    def record_success_after_retry(self, endpoint: str, attempts: int) -> None:
        """Count a call that succeeded after *attempts* attempts."""
        self._metrics["counters"][f"success_after_retry.{endpoint}"] += 1
        logger.info(f"Success after {attempts} attempts for {endpoint}")

    def get_metrics(self) -> Dict[str, Dict[str, Any]]:
        """Return a snapshot of all collected metrics.

        Both levels are copied, so recordings made after this call do not
        change the snapshot and mutating the snapshot does not affect the
        collector. Inner mappings remain ``defaultdict(int)`` so reading a
        metric that was never recorded still yields 0.
        """
        return {
            category: defaultdict(int, values)
            for category, values in self._metrics.items()
        }

    def reset(self) -> None:
        """Discard all collected metrics and restart the collection clock."""
        self._metrics.clear()
        self._start_time = time.time()
79
+
80
+
81
# Global metrics instance shared by every circuit breaker, token bucket and
# retry decorator in this module (resilient_client.py imports this same object).
metrics = MetricsCollector()
83
+
84
+
85
class CircuitState(Enum):
    """Lifecycle states a CircuitBreaker moves through."""

    CLOSED = "closed"  # requests flow normally
    OPEN = "open"  # requests are blocked until the recovery timeout elapses
    HALF_OPEN = "half_open"  # trial requests decide whether to close or reopen
91
+
92
+
93
@dataclass
class CircuitBreaker:
    """Per-endpoint circuit breaker that reports state changes to ``metrics``.

    - CLOSED: requests flow; a success resets ``failure_count`` and reaching
      ``failure_threshold`` failures opens the circuit.
    - OPEN: requests are blocked until ``recovery_timeout`` seconds have passed
      since the last failure, after which the breaker moves to HALF_OPEN.
    - HALF_OPEN: requests flow; ``half_open_requests`` successes close the
      circuit and any failure reopens it.
    """

    failure_threshold: int = 5
    recovery_timeout: float = 60.0  # seconds spent OPEN before probing again
    half_open_requests: int = 3  # successes needed in HALF_OPEN to close
    state: CircuitState = CircuitState.CLOSED
    failure_count: int = 0
    success_count: int = 0
    last_failure_time: Optional[float] = None  # time.time() of the latest failure
    endpoint: str = ""

    def is_open(self) -> bool:
        """Return True while requests must be blocked.

        An OPEN breaker whose recovery timeout has elapsed is moved to
        HALF_OPEN (with its success counter cleared) and lets the request through.
        """
        if self.state != CircuitState.OPEN:
            return False

        cooled_down = bool(self.last_failure_time) and (
            time.time() - self.last_failure_time >= self.recovery_timeout
        )
        if not cooled_down:
            return True

        self.state = CircuitState.HALF_OPEN
        self.success_count = 0
        metrics.record_circuit_state(self.endpoint, "half_open")
        logger.info(f"Circuit entering HALF_OPEN for {self.endpoint}")
        return False

    def record_success(self) -> None:
        """Register a successful request and close the circuit when warranted."""
        if self.state == CircuitState.CLOSED:
            self.failure_count = 0
            return
        if self.state != CircuitState.HALF_OPEN:
            return

        self.success_count += 1
        if self.success_count < self.half_open_requests:
            return

        self.state = CircuitState.CLOSED
        self.failure_count = 0
        metrics.record_circuit_state(self.endpoint, "closed")
        logger.info(f"Circuit CLOSED for {self.endpoint}")

    def record_failure(self) -> None:
        """Register a failed request and open the circuit when warranted."""
        self.failure_count += 1
        self.last_failure_time = time.time()

        if self.state == CircuitState.HALF_OPEN:
            reason = "failed in HALF_OPEN"
        elif self.failure_count >= self.failure_threshold:
            reason = "threshold reached"
        else:
            return

        self.state = CircuitState.OPEN
        metrics.record_circuit_state(self.endpoint, "open")
        logger.warning(f"Circuit OPEN for {self.endpoint} ({reason})")
146
+
147
+
148
# Per-endpoint circuit breakers, keyed by the endpoint string passed to
# get_circuit_breaker() (callers in this package pass the endpoint family).
# Module-level, so every caller in the process shares the same breakers.
circuit_breakers: Dict[str, CircuitBreaker] = {}
150
+
151
+
152
def get_circuit_breaker(endpoint: str) -> CircuitBreaker:
    """Return the shared breaker for *endpoint*, creating it on first use."""
    breaker = circuit_breakers.get(endpoint)
    if breaker is None:
        breaker = CircuitBreaker(endpoint=endpoint)
        circuit_breakers[endpoint] = breaker
    return breaker
157
+
158
+
159
@dataclass
class TokenBucket:
    """Token bucket for rate limiting with per-endpoint TPS.

    The bucket holds at most ``capacity`` tokens and refills continuously at
    ``capacity`` tokens per second, so ``capacity`` is both the sustained rate
    (TPS) and the maximum burst size.
    """

    capacity: float  # TPS for this endpoint; also the maximum burst
    tokens: float  # tokens currently available (fractional while refilling)
    last_refill: float = field(default_factory=time.time)  # time.time() of last refill
    # NOTE(review): nothing in this module appends to ``queue``, so the
    # back-pressure check in acquire() can never trigger; confirm whether
    # waiters were meant to register here.
    queue: List[asyncio.Future] = field(default_factory=list)
    endpoint: str = ""
    region: str = ""

    def refill(self) -> None:
        """Add the tokens accrued since the last refill, capped at ``capacity``."""
        now = time.time()
        elapsed = now - self.last_refill
        tokens_to_add = elapsed * self.capacity
        self.tokens = min(self.capacity, self.tokens + tokens_to_add)
        self.last_refill = now

    async def acquire(self, timeout: Optional[float] = None) -> bool:
        """Take one token, sleeping until one becomes available.

        :param timeout: Maximum seconds to wait. ``None`` waits indefinitely;
            ``0`` (or a negative value) makes a single non-blocking attempt.
        :return: True once a token was taken, False if the timeout elapsed first.
        :raises Exception: If more than 100 waiters are registered in ``queue``.
        """
        start_time = time.time()
        # ``is not None`` so that timeout=0 means "do not wait" rather than
        # being treated as falsy, i.e. "wait forever".
        deadline = start_time + timeout if timeout is not None else None

        while True:
            self.refill()

            if self.tokens >= 1:
                self.tokens -= 1
                wait_time = time.time() - start_time
                if wait_time > 0.01:  # Only record meaningful waits
                    metrics.record_queue_wait(self.endpoint, wait_time)
                return True

            if deadline is not None and time.time() >= deadline:
                logger.warning(f"Token acquisition timeout for {self.endpoint}")
                return False

            # Check queue depth for back-pressure
            if len(self.queue) > 100:
                logger.error(f"Queue depth exceeded for {self.endpoint}, failing fast")
                raise Exception(f"Rate limit queue full for {self.endpoint}")

            # Sleep about one token interval, jittered +/-50% so concurrent
            # waiters spread out, capped at 1 second and at the deadline.
            wait_time = min((1.0 / self.capacity) * random.uniform(0.5, 1.5), 1.0)
            if deadline is not None:
                wait_time = min(wait_time, max(0, deadline - time.time()))

            await asyncio.sleep(wait_time)
211
+
212
+
213
# Per-endpoint-region token buckets, keyed by (endpoint family, region).
# Created lazily by get_token_bucket(); nothing in this module removes entries.
token_buckets: Dict[Tuple[str, str], TokenBucket] = {}

# Default TPS limits per endpoint family (from Amazon docs).
# NOTE(review): these values are hard-coded; confirm them against the current
# Amazon Ads API rate-limit documentation. "default" applies to any URL that
# get_endpoint_family() does not classify.
DEFAULT_TPS_LIMITS = {
    "/v2/campaigns": 10,
    "/v2/ad-groups": 10,
    "/v2/keywords": 10,
    "/v2/product-ads": 10,
    "/v2/profiles": 5,
    "/reporting": 2,
    "/amc": 1,
    "/exports": 1,
    "default": 5,
}
228
+
229
+
230
def get_endpoint_family(url: str) -> str:
    """Map *url* to the endpoint family used for rate limiting and breakers.

    The URL path is lower-cased and checked for each marker in order; the
    first marker found as a substring wins. Unmatched paths map to "default".
    """
    lowered_path = urlparse(url).path.lower()

    ordered_markers = (
        ("/v2/campaigns", "/v2/campaigns"),
        ("/v2/ad-groups", "/v2/ad-groups"),
        ("/v2/keywords", "/v2/keywords"),
        ("/v2/product-ads", "/v2/product-ads"),
        ("/v2/profiles", "/v2/profiles"),
        ("/reporting", "/reporting"),
        ("/amc/", "/amc"),  # requires the trailing slash
        ("/exports", "/exports"),
    )
    for marker, family in ordered_markers:
        if marker in lowered_path:
            return family
    return "default"
253
+
254
+
255
def get_region_from_url(url: str) -> str:
    """Return the region identifier for *url*.

    Delegates to ``RegionConfig.get_region_from_url``.
    """
    region = RegionConfig.get_region_from_url(url)
    return region
258
+
259
+
260
def get_token_bucket(url: str, tps_override: Optional[float] = None) -> TokenBucket:
    """Return the shared token bucket for the URL's endpoint family and region.

    Buckets live in ``token_buckets`` keyed by ``(endpoint_family, region)``
    and are created on first use. ``tps_override`` is consulted only when the
    bucket is created; a falsy value (None or 0) falls back to the family's
    entry in ``DEFAULT_TPS_LIMITS``, then to its ``"default"`` entry.
    """
    family = get_endpoint_family(url)
    region = get_region_from_url(url)
    bucket_key = (family, region)

    bucket = token_buckets.get(bucket_key)
    if bucket is not None:
        return bucket

    if tps_override:
        tps = tps_override
    else:
        tps = DEFAULT_TPS_LIMITS.get(family, DEFAULT_TPS_LIMITS["default"])

    bucket = TokenBucket(capacity=tps, tokens=tps, endpoint=family, region=region)
    token_buckets[bucket_key] = bucket
    logger.info(f"Created token bucket: {family}/{region} with {tps} TPS")
    return bucket
276
+
277
+
278
def parse_retry_after(response: httpx.Response) -> Optional[float]:
    """Parse the Retry-After header of *response* into a delay in seconds.

    Supports both delta-seconds and HTTP-date formats:

    - delta-seconds (digits only, e.g. ``"120"``) are returned as a float;
    - an HTTP-date (e.g. ``"Wed, 21 Oct 2015 07:28:00 GMT"``) yields the
      seconds remaining until that instant, clamped to 0.0.

    HTTP-dates are defined in GMT. A value that parses to a naive datetime
    (``parsedate_to_datetime`` does this for a ``-0000`` zone or a missing
    zone) is therefore interpreted as UTC rather than compared with local time.

    :param response: Response whose headers are inspected.
    :return: The delay in seconds, or None when the header is missing, empty
        or unparseable (unparseable values are logged as a warning).
    """
    from datetime import timezone
    from email.utils import parsedate_to_datetime

    retry_after = response.headers.get("retry-after", "").strip()
    if not retry_after:
        return None

    try:
        # Try delta-seconds first
        if retry_after.isdigit():
            delay = float(retry_after)
            logger.debug(f"Parsed Retry-After as delta-seconds: {delay}")
            return delay

        # Fall back to HTTP-date format
        retry_date = parsedate_to_datetime(retry_after)
        if retry_date.tzinfo is None:
            # Naive result: treat as GMT/UTC; local time would be off by the UTC offset.
            retry_date = retry_date.replace(tzinfo=timezone.utc)
        delay = (retry_date - datetime.now(timezone.utc)).total_seconds()
        delay = max(0.0, delay)  # A date in the past means "retry now"
        logger.debug(f"Parsed Retry-After as HTTP-date: {delay}s")
        return delay
    except Exception as e:
        logger.warning(f"Failed to parse Retry-After header '{retry_after}': {e}")
        return None
305
+
306
+
307
def should_retry_status(status_code: int) -> bool:
    """Return True for transient HTTP statuses worth retrying.

    Retryable: 408 Request Timeout, 429 Too Many Requests, and the gateway
    errors 502, 503 and 504. Everything else, including 500, is not retried.
    """
    transient_statuses = {408, 429, 502, 503, 504}
    return status_code in transient_statuses
311
+
312
+
313
def is_idempotent_request(request: httpx.Request) -> bool:
    """Return True if *request* can safely be sent more than once.

    GET, HEAD, PUT and DELETE are treated as idempotent. A POST qualifies only
    when it carries an ``idempotency-key`` or ``x-amzn-idempotency-key`` header.
    """
    verb = request.method.upper()

    if verb == "POST":
        header_map = request.headers
        return any(
            name in header_map
            for name in ("idempotency-key", "x-amzn-idempotency-key")
        )

    return verb in {"GET", "HEAD", "PUT", "DELETE"}
328
+
329
+
330
class ResilientRetry:
    """Async retry decorator combining backoff, Retry-After, rate limiting and breakers.

    Failures are retried when they are transient: network errors, timeouts,
    and ``httpx.HTTPStatusError`` with a status accepted by
    ``should_retry_status``. Other 4xx statuses are retried only for
    idempotent requests (see ``is_idempotent_request``).

    When a positional argument of the wrapped call is an ``httpx.Request``,
    its URL selects the per-endpoint circuit breaker and the
    per-endpoint/region token bucket. Without one, both are skipped and
    metrics are labelled ``"unknown"``.
    """

    def __init__(
        self,
        max_attempts: int = 5,
        initial_delay: float = 1.0,
        max_delay: float = 60.0,
        backoff_multiplier: float = 2.0,
        total_timeout: float = 180.0,  # 3 minutes default
        use_circuit_breaker: bool = True,
        use_rate_limiter: bool = True,
        interactive: bool = False,  # True for user-facing, False for batch
    ):
        """Configure the retry policy.

        :param max_attempts: Maximum number of calls, including the first;
            capped at 5 when ``interactive`` is True.
        :param initial_delay: Upper bound in seconds of the first backoff delay.
        :param max_delay: Cap in seconds on the backoff delay.
        :param backoff_multiplier: Factor applied to the backoff bound after each retry.
        :param total_timeout: Time budget in seconds, checked before each attempt
            and used to cap rate-limit waits and backoff sleeps.
        :param use_circuit_breaker: Consult and update the per-endpoint circuit breaker.
        :param use_rate_limiter: Take a token-bucket token before every attempt.
        :param interactive: Cap attempts for user-facing calls.
        """
        self.max_attempts = max_attempts if not interactive else min(5, max_attempts)
        self.initial_delay = initial_delay
        self.max_delay = max_delay
        self.backoff_multiplier = backoff_multiplier
        self.total_timeout = total_timeout
        self.use_circuit_breaker = use_circuit_breaker
        self.use_rate_limiter = use_rate_limiter

    @staticmethod
    def _classify_failure(
        exc: Exception,
        request: Optional[httpx.Request],
        endpoint: str,
        region: str,
    ) -> Tuple[bool, bool, Optional[float]]:
        """Decide how to treat a failed attempt; records throttle/Retry-After metrics.

        :return: ``(should_retry, counts_as_breaker_failure, retry_after_delay)``.
        """
        if not isinstance(exc, httpx.HTTPStatusError):
            # Network errors and timeouts are transient and reflect endpoint health.
            return True, True, None

        status_code = exc.response.status_code

        if should_retry_status(status_code):
            if status_code == 429:
                metrics.record_throttle(endpoint, region)
            retry_after_delay = parse_retry_after(exc.response)
            if retry_after_delay:
                metrics.record_retry_after(endpoint, retry_after_delay)
            return True, True, retry_after_delay

        if 400 <= status_code < 500:
            # Client errors (400/401/403/404/...) describe the request, not the
            # endpoint's health, so they must not trip the shared breaker.
            # Otherwise a single retried 404 GET (5 attempts) reaches the
            # default failure_threshold of 5 and blocks the whole endpoint family.
            retry = request is not None and is_idempotent_request(request)
            if retry:
                logger.debug(f"Retrying idempotent request despite {status_code}")
            return retry, False, None

        # Other statuses (e.g. 500) are not retried but still count against the breaker.
        return False, True, None

    def __call__(self, func: Callable[..., T]) -> Callable[..., T]:
        @wraps(func)
        async def wrapper(*args, **kwargs) -> T:
            start_time = time.time()
            last_exception: Optional[Exception] = None
            current_delay = self.initial_delay

            # The first httpx.Request positional argument, if any, identifies
            # the endpoint for breaker / rate-limit bookkeeping.
            request = next((a for a in args if isinstance(a, httpx.Request)), None)
            url = str(request.url) if request is not None else ""

            endpoint = get_endpoint_family(url) if url else "unknown"
            region = get_region_from_url(url) if url else "unknown"
            breaker = (
                get_circuit_breaker(endpoint)
                if self.use_circuit_breaker and url
                else None
            )

            for attempt in range(1, self.max_attempts + 1):
                # Checked before every attempt, not only the first, so retries
                # stop once the shared breaker opens (from this call or others).
                if breaker is not None and breaker.is_open():
                    raise Exception(
                        f"Circuit breaker OPEN for {endpoint}"
                    ) from last_exception

                # Check total timeout budget
                elapsed = time.time() - start_time
                if elapsed >= self.total_timeout:
                    logger.error(f"Total retry budget exhausted ({elapsed:.1f}s)")
                    raise Exception(
                        f"Retry timeout after {elapsed:.1f}s"
                    ) from last_exception

                # Rate limiting: one token per attempt, bounded by remaining budget
                if self.use_rate_limiter and url:
                    bucket = get_token_bucket(url)
                    acquired = await bucket.acquire(
                        timeout=self.total_timeout - elapsed
                    )
                    if not acquired:
                        raise Exception(
                            f"Rate limit timeout for {endpoint}"
                        ) from last_exception

                try:
                    result = await func(*args, **kwargs)
                except (
                    httpx.HTTPStatusError,
                    httpx.RequestError,
                    httpx.TimeoutException,
                ) as e:
                    last_exception = e
                    should_retry, breaker_failure, retry_after_delay = (
                        self._classify_failure(e, request, endpoint, region)
                    )

                    if breaker is not None and breaker_failure:
                        breaker.record_failure()

                    if not should_retry or attempt >= self.max_attempts:
                        logger.error(f"Request failed after {attempt} attempts: {e}")
                        raise

                    if breaker is not None and breaker.is_open():
                        # This failure tripped the breaker: give up now instead
                        # of sleeping only to fail the next pre-attempt check.
                        logger.error(
                            f"Circuit breaker OPEN for {endpoint}; abandoning "
                            f"retries after {attempt} attempts: {e}"
                        )
                        raise

                    # Ensure the sleep fits the remaining budget, leaving 1s
                    # for the next request.
                    remaining_time = self.total_timeout - (time.time() - start_time)
                    sleep_budget = remaining_time - 1
                    if sleep_budget <= 0:
                        logger.error("No time left in retry budget")
                        raise

                    if retry_after_delay:
                        # Honor Retry-After plus up to 10% (max 5s) jitter
                        delay = retry_after_delay + random.uniform(
                            0, min(retry_after_delay * 0.1, 5)
                        )
                    else:
                        # Full jitter exponential backoff
                        delay = random.uniform(0, min(current_delay, self.max_delay))
                        current_delay = min(
                            current_delay * self.backoff_multiplier,
                            self.max_delay,
                        )
                    delay = min(delay, sleep_budget)

                    metrics.record_retry(endpoint, attempt, delay)
                    logger.info(
                        f"Retry {attempt}/{self.max_attempts} after {delay:.2f}s for {endpoint}"
                    )
                    await asyncio.sleep(delay)
                else:
                    if breaker is not None:
                        breaker.record_success()
                    if attempt > 1:
                        metrics.record_success_after_retry(endpoint, attempt)
                    return result

            # Only reachable when max_attempts < 1 (the loop body never ran).
            if last_exception:
                raise last_exception
            raise Exception("Retry logic error")

        return wrapper

    @classmethod
    def for_interactive(cls) -> "ResilientRetry":
        """Create retry config optimized for interactive/user-facing requests.

        Up to 5 attempts within a 2-minute (120 s) budget.
        """
        return cls(
            max_attempts=5,
            total_timeout=120,  # 2 minutes
            interactive=True,
        )

    @classmethod
    def for_batch(cls) -> "ResilientRetry":
        """Create retry config optimized for batch/background operations.

        Up to 10 attempts within a 5-minute (300 s) budget.
        """
        return cls(
            max_attempts=10,
            total_timeout=300,  # 5 minutes
            interactive=False,
        )
507
+
508
+
509
# Export convenience decorators, each instantiated once at import time:
# - resilient_retry: default policy (5 attempts, 180 s budget)
# - interactive_retry: up to 5 attempts within 120 s
# - batch_retry: up to 10 attempts within 300 s
resilient_retry = ResilientRetry()
interactive_retry = ResilientRetry.for_interactive()
batch_retry = ResilientRetry.for_batch()
@@ -0,0 +1,195 @@
1
+ """Resilient HTTP client with integrated retry, rate limiting, and circuit breaking.
2
+
3
+ This module provides a drop-in replacement for AuthenticatedClient that includes
4
+ all resilience patterns recommended by Amazon for API interactions.
5
+ """
6
+
7
+ import logging
8
+ from typing import Any, Dict
9
+
10
+ import httpx
11
+
12
+ from ...utils.http_client import AuthenticatedClient
13
+ from .resilience import (
14
+ ResilientRetry,
15
+ get_circuit_breaker,
16
+ get_endpoint_family,
17
+ get_token_bucket,
18
+ metrics,
19
+ )
20
+
21
# Module-level logger for client initialisation, fail-fast and metrics-reset events.
logger = logging.getLogger(__name__)
22
+
23
+
24
class ResilientAuthenticatedClient(AuthenticatedClient):
    """Enhanced authenticated client with built-in resilience patterns.

    This client extends AuthenticatedClient to add:
    - Automatic retry with exponential backoff and jitter
    - Retry-After header support
    - Token bucket rate limiting per endpoint/region
    - Circuit breaker protection
    - Comprehensive metrics collection
    - Total retry budget enforcement

    It's a drop-in replacement that requires no code changes.

    NOTE(review): retries are triggered by exceptions only. httpx's ``send``
    returns 4xx/5xx responses without raising, so 429/5xx responses are
    retried (and Retry-After honoured) only if ``AuthenticatedClient.send``
    raises ``httpx.HTTPStatusError`` for them; confirm in http_client.py.
    """

    def __init__(
        self,
        *args,
        enable_rate_limiting: bool = True,
        enable_circuit_breaker: bool = True,
        interactive_mode: bool = False,
        **kwargs,
    ):
        """Initialize resilient client with configurable features.

        :param enable_rate_limiting: Enable token bucket rate limiting
        :param enable_circuit_breaker: Enable circuit breaker protection
        :param interactive_mode: Optimize for interactive (True) or batch (False)
        :param args: Positional arguments for parent class
        :param kwargs: Keyword arguments for parent class
        """
        super().__init__(*args, **kwargs)
        self.enable_rate_limiting = enable_rate_limiting
        self.enable_circuit_breaker = enable_circuit_breaker
        self.interactive_mode = interactive_mode

        # Configure retry based on mode
        if interactive_mode:
            self.retry_decorator = ResilientRetry.for_interactive()
        else:
            self.retry_decorator = ResilientRetry.for_batch()

        logger.info(
            f"ResilientAuthenticatedClient initialized: "
            f"rate_limiting={enable_rate_limiting}, "
            f"circuit_breaker={enable_circuit_breaker}, "
            f"mode={'interactive' if interactive_mode else 'batch'}"
        )

    async def send(self, request: httpx.Request, **kwargs) -> httpx.Response:
        """Send request with resilience patterns applied.

        Wraps the parent send method with:
        1. Circuit breaker checking (fail fast while the breaker is open)
        2. Pre-request rate limiting via token bucket (one token per call)
        3. Retry logic with backoff and jitter
        4. Circuit breaker success/failure bookkeeping for the whole call

        :param request: The HTTP request to send
        :param kwargs: Additional arguments for send
        :return: The HTTP response
        :raises: Various HTTP exceptions after exhausting retries, or a plain
            ``Exception`` when the breaker is open or rate limiting times out
        """
        url = str(request.url)
        endpoint = get_endpoint_family(url)
        breaker = get_circuit_breaker(endpoint) if self.enable_circuit_breaker else None

        # Pre-flight circuit breaker check
        if breaker is not None and breaker.is_open():
            logger.warning(f"Circuit breaker OPEN for {endpoint}, failing fast")
            raise Exception(f"Circuit breaker is OPEN for {endpoint}")

        # Apply rate limiting before sending
        if self.enable_rate_limiting:
            bucket = get_token_bucket(url)
            # Use shorter timeout for interactive mode
            timeout = 30.0 if self.interactive_mode else 120.0
            acquired = await bucket.acquire(timeout=timeout)
            if not acquired:
                logger.error(f"Rate limit timeout for {endpoint}")
                raise Exception(f"Rate limit acquisition timeout for {endpoint}")

        # The wrapped coroutine takes no arguments, so the decorator finds no
        # httpx.Request: its own breaker/rate-limiter stay inactive (the checks
        # above cover them) and its retry metrics are labelled "unknown".
        # Retries re-send this same request without taking another token.
        @self.retry_decorator
        async def send_with_retry():
            return await super(ResilientAuthenticatedClient, self).send(
                request, **kwargs
            )

        try:
            response = await send_with_retry()
        except Exception:
            # Record failure in circuit breaker
            if breaker is not None:
                breaker.record_failure()
            raise

        # Record success in circuit breaker
        if breaker is not None:
            breaker.record_success()
        return response

    async def request(self, method: str, url: str, **kwargs) -> httpx.Response:
        """Make HTTP request with resilience patterns.

        Mirrors ``httpx.AsyncClient.request``: request-building options go to
        ``build_request``, while ``auth`` and ``follow_redirects`` are
        forwarded to :meth:`send`. httpx's ``get()``/``post()``/... helpers
        always pass those two keywords, and ``build_request`` rejects them with
        a TypeError.
        NOTE(review): assumes AuthenticatedClient derives from httpx.AsyncClient.

        :param method: HTTP method
        :param url: Request URL
        :param kwargs: Additional request parameters
        :return: HTTP response
        """
        send_kwargs = {}
        for send_only_key in ("auth", "follow_redirects"):
            if send_only_key in kwargs:
                send_kwargs[send_only_key] = kwargs.pop(send_only_key)

        # Build request
        request = self.build_request(method, url, **kwargs)

        # Send with resilience
        return await self.send(request, **send_kwargs)

    def get_metrics(self) -> Dict[str, Any]:
        """Get collected resilience metrics.

        Returns metrics including:
        - Throttle counts per endpoint/region
        - Retry attempts and delays
        - Circuit breaker states
        - Queue wait times
        - Success after retry counts

        :return: Dictionary of collected metrics
        """
        return metrics.get_metrics()

    def reset_metrics(self) -> None:
        """Reset all collected metrics.

        The shared collector is cleared in place rather than replaced.
        ``metrics`` here is a name imported from resilience.py. Rebinding it
        would leave resilience.py recording into the old instance while
        ``get_metrics()`` reported an orphaned, permanently empty one.

        Useful for testing or periodic metric collection.
        """
        metrics._metrics.clear()
        logger.info("Metrics reset")
171
+
172
+
173
def create_resilient_client(
    auth_manager=None,
    media_registry=None,
    header_resolver=None,
    interactive: bool = False,
    **kwargs,
) -> ResilientAuthenticatedClient:
    """Build a ResilientAuthenticatedClient from its usual collaborators.

    :param auth_manager: Authentication manager handed to the client
    :param media_registry: Media type registry handed to the client
    :param header_resolver: Header name resolver handed to the client
    :param interactive: Forwarded as ``interactive_mode``; True selects the
        interactive retry policy, False the batch one
    :param kwargs: Any further keyword arguments for the client constructor
    :return: A newly constructed resilient client
    """
    client = ResilientAuthenticatedClient(
        interactive_mode=interactive,
        header_resolver=header_resolver,
        media_registry=media_registry,
        auth_manager=auth_manager,
        **kwargs,
    )
    return client