pocketsmith-mcp 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. pocketsmith_mcp/__init__.py +8 -0
  2. pocketsmith_mcp/__main__.py +19 -0
  3. pocketsmith_mcp/client/__init__.py +14 -0
  4. pocketsmith_mcp/client/api_client.py +269 -0
  5. pocketsmith_mcp/client/circuit_breaker.py +179 -0
  6. pocketsmith_mcp/client/rate_limiter.py +106 -0
  7. pocketsmith_mcp/client/retry.py +106 -0
  8. pocketsmith_mcp/config.py +110 -0
  9. pocketsmith_mcp/errors.py +87 -0
  10. pocketsmith_mcp/logger.py +69 -0
  11. pocketsmith_mcp/models/__init__.py +24 -0
  12. pocketsmith_mcp/models/account.py +177 -0
  13. pocketsmith_mcp/models/attachment.py +81 -0
  14. pocketsmith_mcp/models/category.py +90 -0
  15. pocketsmith_mcp/models/common.py +65 -0
  16. pocketsmith_mcp/models/event.py +81 -0
  17. pocketsmith_mcp/models/institution.py +31 -0
  18. pocketsmith_mcp/models/transaction.py +94 -0
  19. pocketsmith_mcp/models/user.py +73 -0
  20. pocketsmith_mcp/server.py +69 -0
  21. pocketsmith_mcp/tools/__init__.py +40 -0
  22. pocketsmith_mcp/tools/accounts.py +122 -0
  23. pocketsmith_mcp/tools/attachments.py +149 -0
  24. pocketsmith_mcp/tools/budgeting.py +169 -0
  25. pocketsmith_mcp/tools/categories.py +183 -0
  26. pocketsmith_mcp/tools/events.py +195 -0
  27. pocketsmith_mcp/tools/institutions.py +143 -0
  28. pocketsmith_mcp/tools/labels.py +56 -0
  29. pocketsmith_mcp/tools/transaction_accounts.py +117 -0
  30. pocketsmith_mcp/tools/transactions.py +241 -0
  31. pocketsmith_mcp/tools/users.py +101 -0
  32. pocketsmith_mcp/tools/utilities.py +52 -0
  33. pocketsmith_mcp-1.0.0.dist-info/METADATA +365 -0
  34. pocketsmith_mcp-1.0.0.dist-info/RECORD +37 -0
  35. pocketsmith_mcp-1.0.0.dist-info/WHEEL +4 -0
  36. pocketsmith_mcp-1.0.0.dist-info/entry_points.txt +2 -0
  37. pocketsmith_mcp-1.0.0.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,8 @@
1
+ """PocketSmith MCP Server - Production-ready MCP server for PocketSmith API."""
2
+
3
# Package metadata.
__version__ = "1.0.0"
__author__ = "PocketSmith MCP"

# Re-exported so callers can write ``from pocketsmith_mcp import create_server``.
from pocketsmith_mcp.server import create_server

# Names exported by ``from pocketsmith_mcp import *``.
__all__ = ["create_server", "__version__"]
@@ -0,0 +1,19 @@
1
+ """Entry point for pocketsmith-mcp server.
2
+
3
+ This module allows running the server as:
4
+ python -m pocketsmith_mcp
5
+ uvx pocketsmith-mcp
6
+ uv run pocketsmith-mcp
7
+ """
8
+
9
+ from pocketsmith_mcp.server import get_server
10
+
11
+
12
def main() -> None:
    """Build the PocketSmith MCP server via ``get_server`` and run it.

    Invoked by ``python -m pocketsmith_mcp`` through the ``__main__`` guard
    below; presumably also the target of the ``pocketsmith-mcp`` console
    script mentioned in the module docstring.
    """
    get_server().run()


if __name__ == "__main__":
    main()
@@ -0,0 +1,14 @@
1
+ """PocketSmith API client with retry, rate limiting, and circuit breaker."""
2
+
3
+ from pocketsmith_mcp.client.api_client import PocketSmithClient
4
+ from pocketsmith_mcp.client.circuit_breaker import CircuitBreaker, CircuitState
5
+ from pocketsmith_mcp.client.rate_limiter import RateLimiter
6
+ from pocketsmith_mcp.client.retry import retry_with_backoff
7
+
8
# Public API of ``pocketsmith_mcp.client``: the names re-exported above from
# the api_client, circuit_breaker, rate_limiter and retry submodules.
__all__ = [
    "CircuitBreaker",
    "CircuitState",
    "PocketSmithClient",
    "RateLimiter",
    "retry_with_backoff",
]
@@ -0,0 +1,269 @@
1
+ """Async HTTP client for PocketSmith API with retry, rate limiting, circuit breaker."""
2
+
3
+ from typing import Any
4
+
5
+ import httpx
6
+
7
+ from pocketsmith_mcp.client.circuit_breaker import CircuitBreaker
8
+ from pocketsmith_mcp.client.rate_limiter import RateLimiter
9
+ from pocketsmith_mcp.client.retry import retry_with_backoff
10
+ from pocketsmith_mcp.errors import APIError, AuthError, CircuitBreakerOpenError, RateLimitError
11
+ from pocketsmith_mcp.logger import get_logger
12
+
13
# Module-level logger obtained from the package's ``get_logger`` helper.
logger = get_logger("api_client")
14
+
15
+
16
+ class PocketSmithClient:
17
+ """
18
+ Production-ready async client for PocketSmith API v2.
19
+
20
+ Features:
21
+ - Rate limiting (token bucket algorithm)
22
+ - Retry with exponential backoff and jitter
23
+ - Circuit breaker for fault tolerance
24
+ - Comprehensive error handling
25
+ """
26
+
27
+ BASE_URL = "https://api.pocketsmith.com/v2"
28
+
29
+ def __init__(
30
+ self,
31
+ api_key: str,
32
+ base_url: str | None = None,
33
+ timeout: float = 30.0,
34
+ max_retries: int = 3,
35
+ rate_limit_per_minute: int = 60,
36
+ ):
37
+ """
38
+ Initialize the PocketSmith API client.
39
+
40
+ Args:
41
+ api_key: PocketSmith API key (X-Developer-Key)
42
+ base_url: API base URL (default: https://api.pocketsmith.com/v2)
43
+ timeout: Request timeout in seconds
44
+ max_retries: Maximum retry attempts for failed requests
45
+ rate_limit_per_minute: Maximum requests per minute
46
+ """
47
+ if not api_key:
48
+ raise ValueError("api_key is required")
49
+
50
+ self.api_key = api_key
51
+ self.base_url = base_url or self.BASE_URL
52
+ self.timeout = timeout
53
+ self.max_retries = max_retries
54
+
55
+ self._client = httpx.AsyncClient(
56
+ base_url=self.base_url,
57
+ headers={
58
+ "X-Developer-Key": api_key,
59
+ "Content-Type": "application/json",
60
+ "Accept": "application/json",
61
+ },
62
+ timeout=timeout,
63
+ )
64
+
65
+ self._rate_limiter = RateLimiter(
66
+ tokens_per_interval=rate_limit_per_minute,
67
+ interval_seconds=60,
68
+ )
69
+
70
+ self._circuit_breaker = CircuitBreaker(
71
+ failure_threshold=5,
72
+ reset_timeout_seconds=60,
73
+ )
74
+
75
+ async def _request(
76
+ self,
77
+ method: str,
78
+ path: str,
79
+ params: dict[str, Any] | None = None,
80
+ json_data: dict[str, Any] | None = None,
81
+ ) -> dict[str, Any] | list[Any]:
82
+ """
83
+ Make an authenticated API request with retry, rate limiting, and circuit breaker.
84
+
85
+ Args:
86
+ method: HTTP method (GET, POST, PUT, DELETE)
87
+ path: API endpoint path
88
+ params: Query parameters
89
+ json_data: JSON request body
90
+
91
+ Returns:
92
+ Parsed JSON response
93
+
94
+ Raises:
95
+ AuthError: Authentication failed (401)
96
+ RateLimitError: Rate limit exceeded (429)
97
+ APIError: Other API errors
98
+ CircuitBreakerOpenError: Circuit breaker is open
99
+ """
100
+ # Check circuit breaker
101
+ if not self._circuit_breaker.can_execute():
102
+ raise CircuitBreakerOpenError()
103
+
104
+ # Rate limiting
105
+ await self._rate_limiter.acquire()
106
+
107
+ async def execute_request() -> dict[str, Any] | list[Any]:
108
+ # Clean up params - remove None values
109
+ clean_params = None
110
+ if params:
111
+ clean_params = {k: v for k, v in params.items() if v is not None}
112
+
113
+ logger.debug(f"Request: {method} {path} params={clean_params}")
114
+
115
+ response = await self._client.request(
116
+ method=method,
117
+ url=path,
118
+ params=clean_params,
119
+ json=json_data,
120
+ )
121
+
122
+ logger.debug(f"Response: {response.status_code}")
123
+
124
+ # Handle errors
125
+ if response.status_code == 401:
126
+ raise AuthError("Invalid API key")
127
+
128
+ if response.status_code == 429:
129
+ retry_after = response.headers.get("Retry-After", "60")
130
+ raise RateLimitError(
131
+ f"Rate limit exceeded. Retry after {retry_after}s",
132
+ retry_after=int(retry_after),
133
+ )
134
+
135
+ if response.status_code >= 500:
136
+ self._circuit_breaker.record_failure()
137
+ raise APIError(
138
+ f"Server error: {response.status_code}",
139
+ status_code=response.status_code,
140
+ response_body=response.text,
141
+ )
142
+
143
+ if response.status_code >= 400:
144
+ error_body = response.text
145
+ try:
146
+ error_json = response.json()
147
+ if "error" in error_json:
148
+ error_body = error_json["error"]
149
+ except Exception:
150
+ pass
151
+ raise APIError(
152
+ f"Client error: {response.status_code}",
153
+ status_code=response.status_code,
154
+ response_body=error_body,
155
+ )
156
+
157
+ # Record success
158
+ self._circuit_breaker.record_success()
159
+
160
+ # Handle empty responses
161
+ if response.status_code == 204:
162
+ return {}
163
+
164
+ result: dict[str, Any] | list[Any] = response.json()
165
+ return result
166
+
167
+ # Retry with backoff for retryable errors
168
+ return await retry_with_backoff(
169
+ execute_request,
170
+ max_attempts=self.max_retries,
171
+ base_delay=1.0,
172
+ max_delay=30.0,
173
+ retryable_errors=(httpx.TimeoutException, httpx.NetworkError),
174
+ )
175
+
176
+ async def get(
177
+ self,
178
+ path: str,
179
+ params: dict[str, Any] | None = None,
180
+ ) -> dict[str, Any] | list[Any]:
181
+ """
182
+ Make a GET request.
183
+
184
+ Args:
185
+ path: API endpoint path
186
+ params: Query parameters
187
+
188
+ Returns:
189
+ Parsed JSON response
190
+ """
191
+ return await self._request("GET", path, params=params)
192
+
193
+ async def post(
194
+ self,
195
+ path: str,
196
+ json_data: dict[str, Any] | None = None,
197
+ ) -> dict[str, Any] | list[Any]:
198
+ """
199
+ Make a POST request.
200
+
201
+ Args:
202
+ path: API endpoint path
203
+ json_data: JSON request body
204
+
205
+ Returns:
206
+ Parsed JSON response
207
+ """
208
+ return await self._request("POST", path, json_data=json_data)
209
+
210
+ async def put(
211
+ self,
212
+ path: str,
213
+ json_data: dict[str, Any] | None = None,
214
+ ) -> dict[str, Any] | list[Any]:
215
+ """
216
+ Make a PUT request.
217
+
218
+ Args:
219
+ path: API endpoint path
220
+ json_data: JSON request body
221
+
222
+ Returns:
223
+ Parsed JSON response
224
+ """
225
+ return await self._request("PUT", path, json_data=json_data)
226
+
227
+ async def delete(self, path: str) -> dict[str, Any] | list[Any]:
228
+ """
229
+ Make a DELETE request.
230
+
231
+ Args:
232
+ path: API endpoint path
233
+
234
+ Returns:
235
+ Parsed JSON response (usually empty)
236
+ """
237
+ return await self._request("DELETE", path)
238
+
239
+ async def close(self) -> None:
240
+ """Close the HTTP client."""
241
+ await self._client.aclose()
242
+
243
+ async def __aenter__(self) -> "PocketSmithClient":
244
+ """Async context manager entry."""
245
+ return self
246
+
247
+ async def __aexit__(
248
+ self,
249
+ exc_type: type[BaseException] | None,
250
+ exc_val: BaseException | None,
251
+ exc_tb: Any,
252
+ ) -> None:
253
+ """Async context manager exit."""
254
+ await self.close()
255
+
256
+ def get_stats(self) -> dict[str, Any]:
257
+ """
258
+ Get client statistics.
259
+
260
+ Returns:
261
+ Dictionary with rate limiter and circuit breaker stats
262
+ """
263
+ return {
264
+ "rate_limiter": {
265
+ "available_tokens": self._rate_limiter.available_tokens,
266
+ "max_tokens": self._rate_limiter.max_tokens,
267
+ },
268
+ "circuit_breaker": self._circuit_breaker.get_stats(),
269
+ }
@@ -0,0 +1,179 @@
1
+ """Circuit breaker pattern for fault tolerance."""
2
+
3
+ import time
4
+ from enum import Enum
5
+ from threading import Lock
6
+ from typing import Any
7
+
8
+ from pocketsmith_mcp.logger import get_logger
9
+
10
# Module-level logger obtained from the package's ``get_logger`` helper.
logger = get_logger("circuit_breaker")
11
+
12
+
13
class CircuitState(str, Enum):
    """Circuit breaker states.

    Mixes in ``str`` so members compare equal to their string values
    (e.g. ``CircuitState.OPEN == "open"``).
    """

    CLOSED = "closed"  # Normal operation: all calls pass through
    OPEN = "open"  # Blocking all calls until the reset timeout elapses
    HALF_OPEN = "half_open"  # Testing if service recovered with limited probe calls
19
+
20
+
21
class CircuitBreaker:
    """
    Circuit breaker for external service calls.

    Implements the circuit breaker pattern to prevent cascading failures
    when an external service is unhealthy.

    States:
    - CLOSED: Normal operation, all calls pass through
    - OPEN: Service is unhealthy, all calls fail immediately
    - HALF_OPEN: Testing if service recovered, limited calls allowed

    The OPEN -> HALF_OPEN transition happens lazily. It is applied the next
    time the state is inspected (``state``, ``can_execute``, ``get_stats``)
    once ``reset_timeout_seconds`` have passed since the last recorded
    failure. Every call admitted in HALF_OPEN must be reported via
    ``record_success`` or ``record_failure``. An unreported probe keeps its
    slot, and once all slots are taken ``can_execute`` returns False.
    All mutable state is guarded by a ``threading.Lock``.
    """

    def __init__(
        self,
        failure_threshold: int = 5,
        reset_timeout_seconds: float = 60.0,
        half_open_max_calls: int = 1,
    ):
        """
        Initialize the circuit breaker.

        Args:
            failure_threshold: Number of failures (without an intervening
                success) before opening the circuit
            reset_timeout_seconds: Time to wait after the last failure before
                testing recovery
            half_open_max_calls: Number of test calls allowed in half-open state

        Raises:
            ValueError: If any argument is out of range.
        """
        if failure_threshold < 1:
            raise ValueError("failure_threshold must be at least 1")
        if reset_timeout_seconds <= 0:
            raise ValueError("reset_timeout_seconds must be positive")
        if half_open_max_calls < 1:
            raise ValueError("half_open_max_calls must be at least 1")

        self.failure_threshold = failure_threshold
        self.reset_timeout_seconds = reset_timeout_seconds
        self.half_open_max_calls = half_open_max_calls

        self._state = CircuitState.CLOSED
        self._failures = 0
        self._successes = 0
        self._last_failure_time: float = 0.0
        self._half_open_calls = 0
        self._lock = Lock()

    @property
    def state(self) -> CircuitState:
        """Get the current circuit state, applying any pending timeout transition."""
        with self._lock:
            self._check_state_transition()
            return self._state

    @property
    def failures(self) -> int:
        """Get the current failure count (reset to 0 by any success)."""
        with self._lock:
            return self._failures

    @property
    def is_closed(self) -> bool:
        """Check if the circuit is closed (normal operation)."""
        return self.state == CircuitState.CLOSED

    @property
    def is_open(self) -> bool:
        """Check if the circuit is open (blocking calls)."""
        return self.state == CircuitState.OPEN

    def can_execute(self) -> bool:
        """
        Check if the circuit allows execution.

        In HALF_OPEN, each True result consumes one of the
        ``half_open_max_calls`` probe slots.

        Returns:
            True if a call can be made, False if blocked
        """
        with self._lock:
            self._check_state_transition()

            if self._state == CircuitState.CLOSED:
                return True

            if self._state == CircuitState.OPEN:
                return False

            # HALF_OPEN: Allow limited test calls
            if self._half_open_calls < self.half_open_max_calls:
                self._half_open_calls += 1
                return True
            return False

    def record_success(self) -> None:
        """Record a successful call; closes the circuit if it was half-open."""
        with self._lock:
            self._successes += 1

            if self._state == CircuitState.HALF_OPEN:
                # Service recovered, close the circuit
                logger.info("Circuit breaker: Service recovered, closing circuit")
                self._state = CircuitState.CLOSED

            # Reset failure count on success
            self._failures = 0
            self._half_open_calls = 0

    def record_failure(self) -> None:
        """
        Record a failed call.

        A failure in HALF_OPEN reopens the circuit. In CLOSED, reaching
        ``failure_threshold`` opens it. Every failure restarts the reset
        timeout, including one recorded while already OPEN.
        """
        with self._lock:
            self._failures += 1
            self._last_failure_time = time.monotonic()

            if self._state == CircuitState.HALF_OPEN:
                # Test call failed, reopen circuit
                logger.warning("Circuit breaker: Test call failed, reopening circuit")
                self._state = CircuitState.OPEN
                return

            # Only a CLOSED circuit needs to trip; an already-OPEN one must
            # not re-log the "opening" warning on every further failure.
            if self._state == CircuitState.CLOSED and self._failures >= self.failure_threshold:
                logger.warning(
                    f"Circuit breaker: {self._failures} failures reached threshold, "
                    f"opening circuit for {self.reset_timeout_seconds}s"
                )
                self._state = CircuitState.OPEN

    def _check_state_transition(self) -> None:
        """Move OPEN -> HALF_OPEN once the reset timeout has elapsed.

        Every caller in this class holds ``self._lock`` when calling this.
        """
        if self._state == CircuitState.OPEN:
            elapsed = time.monotonic() - self._last_failure_time
            if elapsed >= self.reset_timeout_seconds:
                logger.info("Circuit breaker: Reset timeout elapsed, entering half-open state")
                self._state = CircuitState.HALF_OPEN
                self._half_open_calls = 0

    def reset(self) -> None:
        """Reset the circuit breaker to initial state."""
        with self._lock:
            self._state = CircuitState.CLOSED
            self._failures = 0
            self._successes = 0
            self._last_failure_time = 0.0
            self._half_open_calls = 0
            logger.info("Circuit breaker: Reset to closed state")

    def force_open(self) -> None:
        """Force the circuit to open state, starting a fresh reset timeout."""
        with self._lock:
            self._state = CircuitState.OPEN
            self._last_failure_time = time.monotonic()
            logger.warning("Circuit breaker: Forced to open state")

    def get_stats(self) -> dict[str, Any]:
        """Get circuit breaker statistics.

        The reported state reflects any pending OPEN -> HALF_OPEN
        transition, matching what ``state`` would return.
        """
        with self._lock:
            self._check_state_transition()
            return {
                "state": self._state.value,
                "failures": self._failures,
                "successes": self._successes,
                "failure_threshold": self.failure_threshold,
                "reset_timeout_seconds": self.reset_timeout_seconds,
            }
@@ -0,0 +1,106 @@
1
+ """Token bucket rate limiter for API calls."""
2
+
3
+ import asyncio
4
+ import time
5
+
6
+
7
+ class RateLimiter:
8
+ """
9
+ Token bucket rate limiter with async support.
10
+
11
+ Implements a token bucket algorithm that allows a certain number of
12
+ requests per time interval. Tokens are refilled continuously based
13
+ on elapsed time.
14
+ """
15
+
16
+ def __init__(
17
+ self,
18
+ tokens_per_interval: int,
19
+ interval_seconds: float,
20
+ initial_tokens: int | None = None,
21
+ ):
22
+ """
23
+ Initialize the rate limiter.
24
+
25
+ Args:
26
+ tokens_per_interval: Number of tokens to add per interval
27
+ interval_seconds: Length of the interval in seconds
28
+ initial_tokens: Initial number of tokens (defaults to tokens_per_interval)
29
+ """
30
+ if tokens_per_interval <= 0:
31
+ raise ValueError("tokens_per_interval must be positive")
32
+ if interval_seconds <= 0:
33
+ raise ValueError("interval_seconds must be positive")
34
+
35
+ self.tokens_per_interval = tokens_per_interval
36
+ self.interval_seconds = interval_seconds
37
+ self.tokens = float(initial_tokens if initial_tokens is not None else tokens_per_interval)
38
+ self.max_tokens = float(tokens_per_interval)
39
+ self.last_refill = time.monotonic()
40
+ self._lock = asyncio.Lock()
41
+
42
+ async def acquire(self, tokens: int = 1) -> None:
43
+ """
44
+ Acquire tokens, waiting if necessary.
45
+
46
+ Args:
47
+ tokens: Number of tokens to acquire (default: 1)
48
+
49
+ Raises:
50
+ ValueError: If tokens requested exceeds max_tokens
51
+ """
52
+ if tokens > self.max_tokens:
53
+ raise ValueError(f"Cannot acquire {tokens} tokens (max: {self.max_tokens})")
54
+
55
+ async with self._lock:
56
+ self._refill()
57
+
58
+ if self.tokens >= tokens:
59
+ self.tokens -= tokens
60
+ return
61
+
62
+ # Calculate wait time until we have enough tokens
63
+ tokens_needed = tokens - self.tokens
64
+ wait_time = (tokens_needed / self.tokens_per_interval) * self.interval_seconds
65
+
66
+ await asyncio.sleep(wait_time)
67
+ self._refill()
68
+ self.tokens -= tokens
69
+
70
+ def try_acquire(self, tokens: int = 1) -> bool:
71
+ """
72
+ Try to acquire tokens without waiting.
73
+
74
+ Args:
75
+ tokens: Number of tokens to acquire (default: 1)
76
+
77
+ Returns:
78
+ True if tokens were acquired, False otherwise
79
+ """
80
+ self._refill()
81
+
82
+ if self.tokens >= tokens:
83
+ self.tokens -= tokens
84
+ return True
85
+ return False
86
+
87
+ def _refill(self) -> None:
88
+ """Refill tokens based on elapsed time."""
89
+ now = time.monotonic()
90
+ elapsed = now - self.last_refill
91
+
92
+ # Calculate tokens to add based on elapsed time
93
+ tokens_to_add = (elapsed / self.interval_seconds) * self.tokens_per_interval
94
+ self.tokens = min(self.max_tokens, self.tokens + tokens_to_add)
95
+ self.last_refill = now
96
+
97
+ @property
98
+ def available_tokens(self) -> float:
99
+ """Get the current number of available tokens."""
100
+ self._refill()
101
+ return self.tokens
102
+
103
+ def reset(self) -> None:
104
+ """Reset the rate limiter to full capacity."""
105
+ self.tokens = self.max_tokens
106
+ self.last_refill = time.monotonic()
@@ -0,0 +1,106 @@
1
+ """Exponential backoff retry with jitter."""
2
+
3
+ import asyncio
4
+ import random
5
+ from collections.abc import Awaitable, Callable
6
+ from typing import TypeVar
7
+
8
+ from pocketsmith_mcp.logger import get_logger
9
+
10
# Result type of the coroutine function passed to ``retry_with_backoff``.
T = TypeVar("T")
# Module-level logger obtained from the package's ``get_logger`` helper.
logger = get_logger("retry")
12
+
13
+
14
async def retry_with_backoff(
    func: Callable[[], Awaitable[T]],
    max_attempts: int = 3,
    base_delay: float = 1.0,
    max_delay: float = 30.0,
    jitter_factor: float = 0.2,
    retryable_errors: tuple[type[Exception], ...] = (Exception,),
    on_retry: Callable[[Exception, int], None] | None = None,
) -> T:
    """
    Await ``func()`` until it succeeds, backing off exponentially between tries.

    Only exceptions that are instances of ``retryable_errors`` trigger another
    attempt; any other exception propagates immediately. Before the retry that
    follows attempt ``k``, the coroutine sleeps ``calculate_delay(k, ...)``
    seconds. That is ``min(base_delay * 2**(k-1), max_delay)`` plus a random
    jitter of up to ``jitter_factor`` times that amount, so a single sleep may
    exceed ``max_delay`` by at most the jitter.

    Args:
        func: Zero-argument coroutine function to call
        max_attempts: Total number of attempts, including the first (default: 3)
        base_delay: Delay in seconds before the first retry, before jitter (default: 1.0)
        max_delay: Cap in seconds on the pre-jitter delay (default: 30.0)
        jitter_factor: Fraction (0.0-1.0) of the delay added as random jitter (default: 0.2)
        retryable_errors: Exception types that trigger a retry (default: all)
        on_retry: Optional callback invoked before each sleep with (exception, attempt)

    Returns:
        Whatever the first successful ``func()`` call returns.

    Raises:
        ValueError: If any numeric argument is out of range.
        Exception: The retryable error from the final attempt once all attempts fail.
    """
    if max_attempts < 1:
        raise ValueError("max_attempts must be at least 1")
    if base_delay <= 0:
        raise ValueError("base_delay must be positive")
    if max_delay <= 0:
        raise ValueError("max_delay must be positive")
    if not 0 <= jitter_factor <= 1:
        raise ValueError("jitter_factor must be between 0 and 1")

    latest_failure: Exception = Exception("No attempts made")

    for attempt in range(1, max_attempts + 1):
        try:
            return await func()
        except retryable_errors as exc:
            latest_failure = exc

            if attempt < max_attempts:
                # Same exponential-plus-jitter formula the module exposes.
                pause = calculate_delay(
                    attempt,
                    base_delay=base_delay,
                    max_delay=max_delay,
                    jitter_factor=jitter_factor,
                )
                logger.info(
                    f"Attempt {attempt}/{max_attempts} failed: {exc}. "
                    f"Retrying in {pause:.2f}s"
                )
                if on_retry:
                    on_retry(exc, attempt)
                await asyncio.sleep(pause)
            else:
                logger.warning(
                    f"All {max_attempts} attempts failed. Last error: {exc}"
                )

    raise latest_failure
82
+
83
+
84
def calculate_delay(
    attempt: int,
    base_delay: float = 1.0,
    max_delay: float = 30.0,
    jitter_factor: float = 0.2,
) -> float:
    """
    Calculate the backoff delay to wait after a given attempt number.

    The base delay doubles with each attempt (``base_delay * 2**(attempt-1)``)
    and is capped at ``max_delay``. A random jitter of
    ``delay * jitter_factor * random.random()`` is then added on top. The
    result is therefore at least the capped delay and less than
    ``delay * (1 + jitter_factor)`` (exactly the capped delay when
    ``jitter_factor`` is 0).

    Args:
        attempt: Current attempt number (1-based)
        base_delay: Base delay in seconds
        max_delay: Cap in seconds on the pre-jitter delay
        jitter_factor: Jitter factor (0.0-1.0)

    Returns:
        Calculated delay in seconds
    """
    delay: float
    try:
        delay = min(base_delay * (2 ** (attempt - 1)), max_delay)
    except OverflowError:
        # 2 ** (attempt - 1) no longer converts to a float for very large
        # attempt numbers; treat the backoff as saturated at max_delay
        # instead of crashing.
        delay = max_delay
    rand_val: float = random.random()
    jitter = delay * jitter_factor * rand_val
    total_delay: float = delay + jitter
    return total_delay
+ return total_delay