texas-grocery-mcp 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- texas_grocery_mcp/__init__.py +3 -0
- texas_grocery_mcp/auth/__init__.py +5 -0
- texas_grocery_mcp/auth/browser_refresh.py +1629 -0
- texas_grocery_mcp/auth/credentials.py +337 -0
- texas_grocery_mcp/auth/session.py +767 -0
- texas_grocery_mcp/clients/__init__.py +5 -0
- texas_grocery_mcp/clients/graphql.py +2400 -0
- texas_grocery_mcp/models/__init__.py +54 -0
- texas_grocery_mcp/models/cart.py +60 -0
- texas_grocery_mcp/models/coupon.py +44 -0
- texas_grocery_mcp/models/errors.py +43 -0
- texas_grocery_mcp/models/health.py +41 -0
- texas_grocery_mcp/models/product.py +274 -0
- texas_grocery_mcp/models/store.py +77 -0
- texas_grocery_mcp/observability/__init__.py +6 -0
- texas_grocery_mcp/observability/health.py +141 -0
- texas_grocery_mcp/observability/logging.py +73 -0
- texas_grocery_mcp/reliability/__init__.py +23 -0
- texas_grocery_mcp/reliability/cache.py +116 -0
- texas_grocery_mcp/reliability/circuit_breaker.py +138 -0
- texas_grocery_mcp/reliability/retry.py +96 -0
- texas_grocery_mcp/reliability/throttle.py +113 -0
- texas_grocery_mcp/server.py +211 -0
- texas_grocery_mcp/services/__init__.py +5 -0
- texas_grocery_mcp/services/geocoding.py +227 -0
- texas_grocery_mcp/state.py +166 -0
- texas_grocery_mcp/tools/__init__.py +5 -0
- texas_grocery_mcp/tools/cart.py +821 -0
- texas_grocery_mcp/tools/coupon.py +381 -0
- texas_grocery_mcp/tools/product.py +437 -0
- texas_grocery_mcp/tools/session.py +486 -0
- texas_grocery_mcp/tools/store.py +353 -0
- texas_grocery_mcp/utils/__init__.py +5 -0
- texas_grocery_mcp/utils/config.py +146 -0
- texas_grocery_mcp/utils/secure_file.py +123 -0
- texas_grocery_mcp-0.1.0.dist-info/METADATA +296 -0
- texas_grocery_mcp-0.1.0.dist-info/RECORD +40 -0
- texas_grocery_mcp-0.1.0.dist-info/WHEEL +4 -0
- texas_grocery_mcp-0.1.0.dist-info/entry_points.txt +2 -0
- texas_grocery_mcp-0.1.0.dist-info/licenses/LICENSE +21 -0
texas_grocery_mcp/observability/health.py
@@ -0,0 +1,141 @@
"""Health check endpoints."""

from datetime import UTC, datetime
from typing import Any, Literal, cast

import structlog

from texas_grocery_mcp.models.health import (
    CircuitBreakerStatus,
    ComponentHealth,
    HealthResponse,
)

logger = structlog.get_logger()


def health_live() -> dict[str, str]:
    """Liveness probe - is the process running?

    Returns a simple alive status. Use for Kubernetes liveness probes.
    """
    return {"status": "alive"}


def _check_redis_health_sync(redis_url: str) -> ComponentHealth:
    """Check Redis connectivity (sync version).

    Args:
        redis_url: Redis connection URL

    Returns:
        ComponentHealth with status and optional message
    """
    try:
        import redis

        # Parse URL and connect with timeout
        client = redis.from_url(  # type: ignore[no-untyped-call]
            redis_url,
            socket_connect_timeout=2.0,
            socket_timeout=2.0,
        )

        try:
            # Ping to verify connectivity
            client.ping()

            # Get basic info for health details
            info = client.info(section="server")
            redis_version = info.get("redis_version", "unknown")

            return ComponentHealth(
                status="up",
                message=f"Redis {redis_version}",
            )

        finally:
            client.close()

    except ImportError:
        return ComponentHealth(
            status="up",
            message="Redis client not installed (optional dependency)",
        )
    except Exception as e:
        logger.warning("Redis health check failed", error=str(e))
        return ComponentHealth(
            status="down",
            message=f"Connection failed: {str(e)}",
        )


def health_ready() -> dict[str, Any]:
    """Readiness probe - can the server handle requests?

    Returns detailed component health. Use for Kubernetes readiness probes.
    """
    components: dict[str, ComponentHealth] = {}
    circuit_breakers: dict[str, CircuitBreakerStatus] = {}
    overall_status: Literal["healthy", "degraded", "unhealthy"] = "healthy"

    # Check GraphQL API status
    try:
        from texas_grocery_mcp.clients.graphql import HEBGraphQLClient

        client = HEBGraphQLClient()
        cb_status = client.circuit_breaker.get_status()

        state_raw = cb_status.get("state")
        state: Literal["closed", "open", "half_open"] = "closed"
        if isinstance(state_raw, str) and state_raw in ("closed", "open", "half_open"):
            state = cast(Literal["closed", "open", "half_open"], state_raw)

        failures_raw = cb_status.get("failure_count", 0)
        failures = int(failures_raw) if isinstance(failures_raw, int) else 0

        if state == "open":
            components["graphql_api"] = ComponentHealth(
                status="down", message="Circuit breaker open"
            )
            overall_status = "degraded"
        else:
            components["graphql_api"] = ComponentHealth(status="up")

        circuit_breakers["heb_graphql"] = CircuitBreakerStatus(
            state=state,
            failures=failures,
        )
    except Exception as e:
        components["graphql_api"] = ComponentHealth(
            status="down", message=str(e)
        )
        overall_status = "unhealthy"

    # Check cache status (if configured)
    try:
        from texas_grocery_mcp.utils.config import get_settings

        settings = get_settings()
        if settings.redis_url:
            # Actually check Redis connectivity
            cache_health = _check_redis_health_sync(settings.redis_url)
            components["cache"] = cache_health

            if cache_health.status == "down" and overall_status == "healthy":
                overall_status = "degraded"
        else:
            components["cache"] = ComponentHealth(
                status="up", message="Not configured (using in-memory)"
            )
    except Exception as e:
        components["cache"] = ComponentHealth(status="down", message=str(e))
        if overall_status == "healthy":
            overall_status = "degraded"

    return HealthResponse(
        status=overall_status,
        timestamp=datetime.now(UTC).isoformat(),
        components=components,
        circuit_breakers=circuit_breakers,
    ).model_dump()
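Note: health_live() and health_ready() return plain dicts (the Pydantic response is dumped with model_dump()), so a probe handler can consume them directly. A minimal sketch, assuming only the field names visible above (status, components, message); the real wiring in server.py is not part of this excerpt.

# Minimal sketch: consuming the readiness report.
from texas_grocery_mcp.observability.health import health_live, health_ready

assert health_live() == {"status": "alive"}

report = health_ready()
if report["status"] != "healthy":
    # e.g. fail a Kubernetes readiness probe and surface the degraded components
    for name, component in report["components"].items():
        print(name, component["status"], component.get("message"))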
texas_grocery_mcp/observability/logging.py
@@ -0,0 +1,73 @@
"""Structured JSON logging configuration."""

import logging
import sys
from collections.abc import MutableMapping
from typing import Any

import structlog
from structlog.types import Processor

from texas_grocery_mcp.utils.config import get_settings


def add_timestamp(
    logger: Any, method_name: str, event_dict: MutableMapping[str, Any]
) -> MutableMapping[str, Any]:
    """Add ISO timestamp to log entry."""
    from datetime import UTC, datetime

    event_dict["timestamp"] = datetime.now(UTC).isoformat()
    return event_dict


def configure_logging(log_level: str | None = None) -> None:
    """Configure structured JSON logging.

    Logs to stderr to keep stdout clean for MCP protocol.
    """
    settings = get_settings()
    level = log_level or settings.log_level

    # Shared processors for all loggers
    shared_processors: list[Processor] = [
        structlog.contextvars.merge_contextvars,
        structlog.stdlib.add_log_level,
        add_timestamp,
        structlog.processors.StackInfoRenderer(),
        structlog.processors.format_exc_info,
    ]

    # Configure structlog
    structlog.configure(
        processors=[
            *shared_processors,
            structlog.stdlib.ProcessorFormatter.wrap_for_formatter,
        ],
        wrapper_class=structlog.stdlib.BoundLogger,
        context_class=dict,
        logger_factory=structlog.stdlib.LoggerFactory(),
        cache_logger_on_first_use=True,
    )

    # Configure stdlib logging
    formatter = structlog.stdlib.ProcessorFormatter(
        foreign_pre_chain=shared_processors,
        processors=[
            structlog.stdlib.ProcessorFormatter.remove_processors_meta,
            structlog.processors.JSONRenderer(),
        ],
    )

    handler = logging.StreamHandler(sys.stderr)
    handler.setFormatter(formatter)

    root_logger = logging.getLogger()
    root_logger.handlers.clear()
    root_logger.addHandler(handler)
    root_logger.setLevel(getattr(logging, level.upper()))


def get_logger(name: str) -> structlog.stdlib.BoundLogger:
    """Get a structured logger instance."""
    return structlog.get_logger(name)  # type: ignore[no-any-return]
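Note: configure_logging() routes everything through the root stdlib handler, so both structlog and plain logging calls come out as one JSON stream on stderr. A minimal startup sketch, assuming a valid level string such as "INFO":

# Minimal sketch: configure once at startup, then log from anywhere.
from texas_grocery_mcp.observability.logging import configure_logging, get_logger

configure_logging(log_level="INFO")  # falls back to settings.log_level when omitted

log = get_logger(__name__)
log.info("server_starting", transport="stdio")  # JSON on stderr; stdout stays free for MCP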
texas_grocery_mcp/reliability/__init__.py
@@ -0,0 +1,23 @@
"""Reliability patterns for production resilience."""

from texas_grocery_mcp.reliability.cache import TTLCache
from texas_grocery_mcp.reliability.circuit_breaker import (
    CircuitBreaker,
    CircuitBreakerConfig,
    CircuitBreakerOpenError,
    CircuitState,
)
from texas_grocery_mcp.reliability.retry import RetryConfig, with_retry
from texas_grocery_mcp.reliability.throttle import ThrottleConfig, Throttler

__all__ = [
    "CircuitBreaker",
    "CircuitBreakerConfig",
    "CircuitBreakerOpenError",
    "CircuitState",
    "RetryConfig",
    "TTLCache",
    "with_retry",
    "ThrottleConfig",
    "Throttler",
]
texas_grocery_mcp/reliability/cache.py
@@ -0,0 +1,116 @@
"""Simple TTL cache for API responses."""

from datetime import datetime, timedelta
from typing import Generic, TypeVar

import structlog

logger = structlog.get_logger()

T = TypeVar("T")


class TTLCache(Generic[T]):
    """Simple in-memory cache with time-to-live expiration.

    Thread-safe for basic operations. Uses a dict for storage with
    (value, timestamp) tuples.

    Example:
        cache = TTLCache[ProductDetails](ttl_hours=24)
        cache.set("127074", product_details)
        details = cache.get("127074")  # Returns None if expired
    """

    def __init__(self, ttl_hours: int = 24, max_size: int = 1000):
        """Initialize the cache.

        Args:
            ttl_hours: Time-to-live in hours before entries expire
            max_size: Maximum number of entries to store
        """
        self._cache: dict[str, tuple[T, datetime]] = {}
        self._ttl = timedelta(hours=ttl_hours)
        self._max_size = max_size

    def get(self, key: str) -> T | None:
        """Get a value from the cache.

        Args:
            key: Cache key

        Returns:
            Cached value if present and not expired, None otherwise
        """
        if key not in self._cache:
            return None

        value, cached_at = self._cache[key]
        if datetime.now() - cached_at >= self._ttl:
            # Expired - remove and return None
            del self._cache[key]
            logger.debug("Cache entry expired", key=key)
            return None

        logger.debug("Cache hit", key=key)
        return value

    def set(self, key: str, value: T) -> None:
        """Store a value in the cache.

        Args:
            key: Cache key
            value: Value to cache
        """
        # Evict oldest entries if at max size
        if len(self._cache) >= self._max_size and key not in self._cache:
            self._evict_oldest()

        self._cache[key] = (value, datetime.now())
        logger.debug("Cache set", key=key)

    def invalidate(self, key: str) -> None:
        """Remove a specific entry from the cache.

        Args:
            key: Cache key to remove
        """
        if key in self._cache:
            del self._cache[key]
            logger.debug("Cache entry invalidated", key=key)

    def clear(self) -> None:
        """Remove all entries from the cache."""
        count = len(self._cache)
        self._cache.clear()
        logger.info("Cache cleared", entries_removed=count)

    def _evict_oldest(self) -> None:
        """Remove the oldest entry from the cache."""
        if not self._cache:
            return

        oldest_key = min(self._cache, key=lambda k: self._cache[k][1])
        del self._cache[oldest_key]
        logger.debug("Cache evicted oldest entry", key=oldest_key)

    @property
    def size(self) -> int:
        """Current number of entries in the cache."""
        return len(self._cache)

    def stats(self) -> dict[str, int | float]:
        """Get cache statistics.

        Returns:
            Dict with cache stats (size, ttl_hours, max_size)
        """
        now = datetime.now()
        valid_count = sum(1 for _, (_, ts) in self._cache.items() if now - ts < self._ttl)
        return {
            "size": len(self._cache),
            "valid_entries": valid_count,
            "expired_entries": len(self._cache) - valid_count,
            "ttl_hours": self._ttl.total_seconds() / 3600,
            "max_size": self._max_size,
        }
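Note: the class docstring shows the happy path; the sketch below exercises the size-based eviction and stats() described above. The dict value type is illustrative only; any value type works with the generic parameter.

# Minimal sketch: expiry window plus oldest-entry eviction at max_size.
from texas_grocery_mcp.reliability.cache import TTLCache

cache: TTLCache[dict] = TTLCache(ttl_hours=1, max_size=2)
cache.set("a", {"price": 1.99})
cache.set("b", {"price": 2.49})
cache.set("c", {"price": 0.99})  # at max_size, so the oldest entry ("a") is evicted

assert cache.get("a") is None    # evicted
assert cache.size == 2
print(cache.stats())             # size, valid_entries, expired_entries, ttl_hours, max_size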
texas_grocery_mcp/reliability/circuit_breaker.py
@@ -0,0 +1,138 @@
"""Circuit breaker pattern for preventing cascading failures."""

import time
from dataclasses import dataclass
from enum import Enum
from threading import Lock

import structlog

logger = structlog.get_logger()


class CircuitState(Enum):
    """Circuit breaker states."""

    CLOSED = "closed"
    OPEN = "open"
    HALF_OPEN = "half_open"


class CircuitBreakerOpenError(Exception):
    """Raised when circuit breaker is open."""

    def __init__(self, name: str, retry_after: float):
        self.name = name
        self.retry_after = retry_after
        super().__init__(f"Circuit breaker '{name}' is open. Retry after {retry_after:.1f}s")


@dataclass
class CircuitBreakerConfig:
    """Configuration for circuit breaker."""

    failure_threshold: int = 5
    recovery_timeout: float = 30.0
    half_open_max_calls: int = 3


class CircuitBreaker:
    """Circuit breaker for external service calls.

    Prevents cascading failures by "tripping" after repeated failures
    and allowing the system to recover.
    """

    def __init__(self, name: str, config: CircuitBreakerConfig | None = None):
        self.name = name
        self.config = config or CircuitBreakerConfig()
        self._state = CircuitState.CLOSED
        self._failure_count = 0
        self._success_count = 0
        self._last_failure_time: float | None = None
        self._half_open_calls = 0
        self._lock = Lock()

    @property
    def state(self) -> CircuitState:
        """Get current circuit state, checking for transition to half-open."""
        with self._lock:
            if self._state == CircuitState.OPEN and self._should_attempt_recovery():
                self._state = CircuitState.HALF_OPEN
                self._half_open_calls = 0
                logger.info("Circuit breaker transitioning to half-open", name=self.name)
            return self._state

    @property
    def failure_count(self) -> int:
        """Get current failure count."""
        return self._failure_count

    def _should_attempt_recovery(self) -> bool:
        """Check if enough time has passed to attempt recovery."""
        if self._last_failure_time is None:
            return False
        return time.time() - self._last_failure_time >= self.config.recovery_timeout

    def check(self) -> None:
        """Check whether a request should be allowed.

        Raises:
            CircuitBreakerOpenError: If circuit is open and recovery timeout hasn't elapsed.
        """
        current_state = self.state  # This may transition to half-open

        if current_state == CircuitState.OPEN:
            retry_after = (
                self.config.recovery_timeout
                - (time.time() - (self._last_failure_time or 0))
            )
            raise CircuitBreakerOpenError(self.name, max(0, retry_after))

        if current_state == CircuitState.HALF_OPEN:
            with self._lock:
                if self._half_open_calls >= self.config.half_open_max_calls:
                    raise CircuitBreakerOpenError(self.name, self.config.recovery_timeout)
                self._half_open_calls += 1

    def record_success(self) -> None:
        """Record a successful call."""
        with self._lock:
            if self._state == CircuitState.HALF_OPEN:
                self._success_count += 1
                if self._success_count >= self.config.half_open_max_calls:
                    self._state = CircuitState.CLOSED
                    self._failure_count = 0
                    self._success_count = 0
                    logger.info("Circuit breaker closed", name=self.name)
            else:
                self._failure_count = 0
                self._success_count = 0

    def record_failure(self) -> None:
        """Record a failed call."""
        with self._lock:
            self._failure_count += 1
            self._last_failure_time = time.time()

            if self._state == CircuitState.HALF_OPEN:
                self._state = CircuitState.OPEN
                logger.warning(
                    "Circuit breaker reopened from half-open",
                    name=self.name,
                )
            elif self._failure_count >= self.config.failure_threshold:
                self._state = CircuitState.OPEN
                logger.warning(
                    "Circuit breaker opened",
                    name=self.name,
                    failure_count=self._failure_count,
                )

    def get_status(self) -> dict[str, str | int]:
        """Get circuit breaker status for health checks."""
        return {
            "name": self.name,
            "state": self.state.value,
            "failure_count": self._failure_count,
        }
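Note: the breaker is manual rather than a decorator; callers run check() before the guarded call and then report the outcome. A minimal async sketch around a hypothetical fetch() coroutine (the real guarded call lives in clients/graphql.py and is not shown here):

# Minimal sketch: guarding an outbound call with CircuitBreaker.
import asyncio

from texas_grocery_mcp.reliability.circuit_breaker import (
    CircuitBreaker,
    CircuitBreakerConfig,
)

breaker = CircuitBreaker("heb_graphql", CircuitBreakerConfig(failure_threshold=5))


async def fetch(query: str) -> dict:
    # Hypothetical stand-in for the real network request.
    await asyncio.sleep(0)
    return {"data": query}


async def guarded_fetch(query: str) -> dict:
    breaker.check()  # raises CircuitBreakerOpenError while the circuit is open
    try:
        result = await fetch(query)
    except Exception:
        breaker.record_failure()
        raise
    breaker.record_success()
    return result


asyncio.run(guarded_fetch("{ store { id } }"))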
texas_grocery_mcp/reliability/retry.py
@@ -0,0 +1,96 @@
"""Retry logic with exponential backoff and jitter."""

import asyncio
import random
from collections.abc import Awaitable, Callable
from dataclasses import dataclass, field
from functools import wraps
from typing import ParamSpec, TypeVar

import structlog

logger = structlog.get_logger()

P = ParamSpec("P")
T = TypeVar("T")

# Exceptions that should trigger retry
RETRYABLE_EXCEPTIONS: tuple[type[Exception], ...] = (
    ConnectionError,
    TimeoutError,
    OSError,
)


@dataclass
class RetryConfig:
    """Configuration for retry behavior."""

    max_attempts: int = 3
    base_delay: float = 1.0
    max_delay: float = 30.0
    exponential_base: float = 2.0
    jitter: bool = True
    retryable_exceptions: tuple[type[Exception], ...] = field(
        default_factory=lambda: RETRYABLE_EXCEPTIONS
    )


def calculate_delay(attempt: int, config: RetryConfig) -> float:
    """Calculate delay for next retry attempt."""
    delay = config.base_delay * (config.exponential_base ** (attempt - 1))
    delay = min(delay, config.max_delay)

    if config.jitter:
        delay = delay * (0.5 + random.random())

    return delay


def with_retry(
    config: RetryConfig | None = None,
) -> Callable[[Callable[P, Awaitable[T]]], Callable[P, Awaitable[T]]]:
    """Decorator for async functions with retry logic.

    Args:
        config: Retry configuration. Uses defaults if not provided.

    Returns:
        Decorated function that retries on transient failures.
    """
    if config is None:
        config = RetryConfig()

    def decorator(func: Callable[P, Awaitable[T]]) -> Callable[P, Awaitable[T]]:
        @wraps(func)
        async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
            last_exception: Exception | None = None

            for attempt in range(1, config.max_attempts + 1):
                try:
                    return await func(*args, **kwargs)
                except config.retryable_exceptions as e:
                    last_exception = e
                    if attempt < config.max_attempts:
                        delay = calculate_delay(attempt, config)
                        logger.warning(
                            "Retry attempt",
                            function=func.__name__,
                            attempt=attempt,
                            max_attempts=config.max_attempts,
                            delay=delay,
                            error=str(e),
                        )
                        await asyncio.sleep(delay)
                except Exception:
                    # Non-retryable exception, raise immediately
                    raise

            # Exhausted all retries
            if last_exception:
                raise last_exception
            raise RuntimeError("Unexpected retry state")

        return wrapper

    return decorator
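Note: with_retry is a decorator factory, so it is applied with parentheses even when using the defaults, and only exceptions in retryable_exceptions are retried. A minimal sketch with a tightened config; the decorated coroutine here is a hypothetical stand-in:

# Minimal sketch: retrying a flaky async call with exponential backoff.
import asyncio

from texas_grocery_mcp.reliability.retry import RetryConfig, with_retry


@with_retry(RetryConfig(max_attempts=4, base_delay=0.5, max_delay=5.0))
async def fetch_store_hours(store_id: str) -> dict:
    # Hypothetical stand-in; a real call would hit the network and may raise
    # ConnectionError / TimeoutError, both of which are in RETRYABLE_EXCEPTIONS.
    await asyncio.sleep(0)
    return {"store_id": store_id, "open": True}


asyncio.run(fetch_store_hours("790"))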
texas_grocery_mcp/reliability/throttle.py
@@ -0,0 +1,113 @@
"""Request throttling to prevent rate limiting."""

import asyncio
import random
import time
from dataclasses import dataclass
from types import TracebackType

import structlog

logger = structlog.get_logger()


@dataclass
class ThrottleConfig:
    """Configuration for request throttling."""

    max_concurrent: int = 3
    """Maximum concurrent requests allowed."""

    min_delay_ms: int = 200
    """Minimum delay between requests in milliseconds."""

    jitter_ms: int = 200
    """Random jitter added to delay (0 to jitter_ms)."""

    enabled: bool = True
    """Whether throttling is enabled."""


class Throttler:
    """
    Semaphore-based request throttler with delay and jitter.

    Limits concurrent requests and enforces minimum delays between requests
    to prevent overwhelming external APIs with burst traffic.

    Usage:
        throttler = Throttler(ThrottleConfig(max_concurrent=3))

        async with throttler:
            await make_request()
    """

    def __init__(self, config: ThrottleConfig, name: str = "default"):
        """Initialize throttler with configuration.

        Args:
            config: Throttling configuration
            name: Name for logging purposes
        """
        self._config = config
        self._name = name
        self._semaphore = asyncio.Semaphore(config.max_concurrent)
        self._last_request_time: float = 0
        self._lock = asyncio.Lock()

    @property
    def config(self) -> ThrottleConfig:
        """Get the throttle configuration."""
        return self._config

    async def __aenter__(self) -> "Throttler":
        """Acquire throttle slot, waiting if necessary."""
        if not self._config.enabled:
            return self

        await self._semaphore.acquire()

        try:
            await self._wait_for_delay()
        except Exception:
            self._semaphore.release()
            raise

        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Release throttle slot and record request time."""
        if not self._config.enabled:
            return

        async with self._lock:
            self._last_request_time = time.monotonic()

        self._semaphore.release()

    async def _wait_for_delay(self) -> None:
        """Wait for minimum delay + jitter since last request."""
        async with self._lock:
            now = time.monotonic()
            min_delay = self._config.min_delay_ms / 1000.0
            elapsed = now - self._last_request_time

            wait_time = min_delay - elapsed if elapsed < min_delay else 0

            # Add jitter to make traffic pattern less predictable
            if self._config.jitter_ms > 0:
                jitter = random.uniform(0, self._config.jitter_ms / 1000.0)
                wait_time += jitter

            if wait_time > 0:
                logger.debug(
                    "Throttling request",
                    throttler=self._name,
                    wait_ms=int(wait_time * 1000),
                )
                await asyncio.sleep(wait_time)