pocketsmith-mcp 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pocketsmith_mcp/__init__.py +8 -0
- pocketsmith_mcp/__main__.py +19 -0
- pocketsmith_mcp/client/__init__.py +14 -0
- pocketsmith_mcp/client/api_client.py +269 -0
- pocketsmith_mcp/client/circuit_breaker.py +179 -0
- pocketsmith_mcp/client/rate_limiter.py +106 -0
- pocketsmith_mcp/client/retry.py +106 -0
- pocketsmith_mcp/config.py +110 -0
- pocketsmith_mcp/errors.py +87 -0
- pocketsmith_mcp/logger.py +69 -0
- pocketsmith_mcp/models/__init__.py +24 -0
- pocketsmith_mcp/models/account.py +177 -0
- pocketsmith_mcp/models/attachment.py +81 -0
- pocketsmith_mcp/models/category.py +90 -0
- pocketsmith_mcp/models/common.py +65 -0
- pocketsmith_mcp/models/event.py +81 -0
- pocketsmith_mcp/models/institution.py +31 -0
- pocketsmith_mcp/models/transaction.py +94 -0
- pocketsmith_mcp/models/user.py +73 -0
- pocketsmith_mcp/server.py +69 -0
- pocketsmith_mcp/tools/__init__.py +40 -0
- pocketsmith_mcp/tools/accounts.py +122 -0
- pocketsmith_mcp/tools/attachments.py +149 -0
- pocketsmith_mcp/tools/budgeting.py +169 -0
- pocketsmith_mcp/tools/categories.py +183 -0
- pocketsmith_mcp/tools/events.py +195 -0
- pocketsmith_mcp/tools/institutions.py +143 -0
- pocketsmith_mcp/tools/labels.py +56 -0
- pocketsmith_mcp/tools/transaction_accounts.py +117 -0
- pocketsmith_mcp/tools/transactions.py +241 -0
- pocketsmith_mcp/tools/users.py +101 -0
- pocketsmith_mcp/tools/utilities.py +52 -0
- pocketsmith_mcp-1.0.0.dist-info/METADATA +365 -0
- pocketsmith_mcp-1.0.0.dist-info/RECORD +37 -0
- pocketsmith_mcp-1.0.0.dist-info/WHEEL +4 -0
- pocketsmith_mcp-1.0.0.dist-info/entry_points.txt +2 -0
- pocketsmith_mcp-1.0.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
"""Entry point for pocketsmith-mcp server.
|
|
2
|
+
|
|
3
|
+
This module allows running the server as:
|
|
4
|
+
python -m pocketsmith_mcp
|
|
5
|
+
uvx pocketsmith-mcp
|
|
6
|
+
uv run pocketsmith-mcp
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from pocketsmith_mcp.server import get_server
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def main() -> None:
    """Build the PocketSmith MCP server via ``get_server()`` and run it."""
    get_server().run()


if __name__ == "__main__":
    main()
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
"""PocketSmith API client with retry, rate limiting, and circuit breaker."""
|
|
2
|
+
|
|
3
|
+
from pocketsmith_mcp.client.api_client import PocketSmithClient
|
|
4
|
+
from pocketsmith_mcp.client.circuit_breaker import CircuitBreaker, CircuitState
|
|
5
|
+
from pocketsmith_mcp.client.rate_limiter import RateLimiter
|
|
6
|
+
from pocketsmith_mcp.client.retry import retry_with_backoff
|
|
7
|
+
|
|
8
|
+
# Public names re-exported by the ``pocketsmith_mcp.client`` package.
__all__ = [
    "CircuitBreaker", "CircuitState", "PocketSmithClient",
    "RateLimiter", "retry_with_backoff",
]
|
|
@@ -0,0 +1,269 @@
|
|
|
1
|
+
"""Async HTTP client for PocketSmith API with retry, rate limiting, circuit breaker."""
|
|
2
|
+
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
import httpx
|
|
6
|
+
|
|
7
|
+
from pocketsmith_mcp.client.circuit_breaker import CircuitBreaker
|
|
8
|
+
from pocketsmith_mcp.client.rate_limiter import RateLimiter
|
|
9
|
+
from pocketsmith_mcp.client.retry import retry_with_backoff
|
|
10
|
+
from pocketsmith_mcp.errors import APIError, AuthError, CircuitBreakerOpenError, RateLimitError
|
|
11
|
+
from pocketsmith_mcp.logger import get_logger
|
|
12
|
+
|
|
13
|
+
logger = get_logger("api_client")
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class PocketSmithClient:
    """
    Production-ready async client for PocketSmith API v2.

    Features:
    - Rate limiting (token bucket algorithm), applied to every HTTP attempt
      including retries
    - Retry with exponential backoff and jitter for timeouts/network errors
    - Circuit breaker for fault tolerance
    - Comprehensive error handling

    Every request admitted by the circuit breaker reports an outcome back to
    it: transport errors and 5xx responses are recorded as failures, any
    other HTTP response is recorded as a success (the service answered).
    Leaving some outcomes unreported would leave a half-open probe slot
    consumed without a result, so the breaker could not decide whether the
    service has recovered.
    """

    BASE_URL = "https://api.pocketsmith.com/v2"

    # Seconds reported in RateLimitError when a 429 response has no usable
    # numeric Retry-After header.
    DEFAULT_RETRY_AFTER = 60

    def __init__(
        self,
        api_key: str,
        base_url: str | None = None,
        timeout: float = 30.0,
        max_retries: int = 3,
        rate_limit_per_minute: int = 60,
    ):
        """
        Initialize the PocketSmith API client.

        Args:
            api_key: PocketSmith API key (X-Developer-Key)
            base_url: API base URL (default: https://api.pocketsmith.com/v2)
            timeout: Request timeout in seconds
            max_retries: Maximum retry attempts for failed requests
            rate_limit_per_minute: Maximum requests per minute

        Raises:
            ValueError: If api_key is empty
        """
        if not api_key:
            raise ValueError("api_key is required")

        self.api_key = api_key
        self.base_url = base_url or self.BASE_URL
        self.timeout = timeout
        self.max_retries = max_retries

        self._client = httpx.AsyncClient(
            base_url=self.base_url,
            headers={
                "X-Developer-Key": api_key,
                "Content-Type": "application/json",
                "Accept": "application/json",
            },
            timeout=timeout,
        )

        self._rate_limiter = RateLimiter(
            tokens_per_interval=rate_limit_per_minute,
            interval_seconds=60,
        )

        self._circuit_breaker = CircuitBreaker(
            failure_threshold=5,
            reset_timeout_seconds=60,
        )

    @classmethod
    def _parse_retry_after(cls, value: str | None) -> int:
        """
        Convert a Retry-After header value to whole seconds.

        Only the delay-seconds form is interpreted. A missing header, or any
        value that is not an integer (including the HTTP-date form), falls
        back to DEFAULT_RETRY_AFTER instead of raising. Negative values are
        clamped to 0.
        """
        if value is None:
            return cls.DEFAULT_RETRY_AFTER
        try:
            seconds = int(value.strip())
        except ValueError:
            return cls.DEFAULT_RETRY_AFTER
        return max(seconds, 0)

    def _check_response(self, response: httpx.Response) -> None:
        """
        Report a response to the circuit breaker and raise for error statuses.

        Args:
            response: The HTTP response to inspect

        Raises:
            AuthError: Authentication failed (401)
            RateLimitError: Rate limit exceeded (429)
            APIError: Server errors (>= 500) and other client errors (>= 400)
        """
        status = response.status_code

        if status >= 500:
            self._circuit_breaker.record_failure()
            raise APIError(
                f"Server error: {status}",
                status_code=status,
                response_body=response.text,
            )

        # The service answered, so it is healthy from the breaker's point of
        # view. Record this before raising for 4xx so that a half-open probe
        # is always resolved.
        self._circuit_breaker.record_success()

        if status == 401:
            raise AuthError("Invalid API key")

        if status == 429:
            retry_after = self._parse_retry_after(response.headers.get("Retry-After"))
            raise RateLimitError(
                f"Rate limit exceeded. Retry after {retry_after}s",
                retry_after=retry_after,
            )

        if status >= 400:
            error_body: Any = response.text
            try:
                error_json = response.json()
            except ValueError:  # body is not JSON; keep the raw text
                error_json = None
            if isinstance(error_json, dict) and "error" in error_json:
                error_body = error_json["error"]
            raise APIError(
                f"Client error: {status}",
                status_code=status,
                response_body=error_body,
            )

    async def _request(
        self,
        method: str,
        path: str,
        params: dict[str, Any] | None = None,
        json_data: dict[str, Any] | None = None,
    ) -> dict[str, Any] | list[Any]:
        """
        Make an authenticated API request with retry, rate limiting, and circuit breaker.

        Args:
            method: HTTP method (GET, POST, PUT, DELETE)
            path: API endpoint path
            params: Query parameters (entries whose value is None are dropped)
            json_data: JSON request body

        Returns:
            Parsed JSON response, or {} for 204 / empty-body responses

        Raises:
            AuthError: Authentication failed (401)
            RateLimitError: Rate limit exceeded (429)
            APIError: Other API errors
            CircuitBreakerOpenError: Circuit breaker is open
            httpx.TransportError: Network/timeout failures after retries
        """
        if not self._circuit_breaker.can_execute():
            raise CircuitBreakerOpenError()

        # Omit optional query parameters that were left unset.
        clean_params = None
        if params:
            clean_params = {k: v for k, v in params.items() if v is not None}

        async def execute_request() -> dict[str, Any] | list[Any]:
            # Every HTTP attempt, including retries, consumes a rate-limit token.
            await self._rate_limiter.acquire()

            logger.debug(f"Request: {method} {path} params={clean_params}")

            response = await self._client.request(
                method=method,
                url=path,
                params=clean_params,
                json=json_data,
            )

            logger.debug(f"Response: {response.status_code}")

            self._check_response(response)

            # No Content, or a success response with an empty body.
            if response.status_code == 204 or not response.content:
                return {}

            result: dict[str, Any] | list[Any] = response.json()
            return result

        try:
            # Only timeouts and network errors are retried; HTTP error
            # statuses are raised straight through.
            return await retry_with_backoff(
                execute_request,
                max_attempts=self.max_retries,
                base_delay=1.0,
                max_delay=30.0,
                retryable_errors=(httpx.TimeoutException, httpx.NetworkError),
            )
        except httpx.TransportError:
            # Retries exhausted, or a non-retryable transport error: the
            # service could not be reached, so count one failure for it.
            self._circuit_breaker.record_failure()
            raise

    async def get(
        self,
        path: str,
        params: dict[str, Any] | None = None,
    ) -> dict[str, Any] | list[Any]:
        """
        Make a GET request.

        Args:
            path: API endpoint path
            params: Query parameters

        Returns:
            Parsed JSON response
        """
        return await self._request("GET", path, params=params)

    async def post(
        self,
        path: str,
        json_data: dict[str, Any] | None = None,
    ) -> dict[str, Any] | list[Any]:
        """
        Make a POST request.

        Args:
            path: API endpoint path
            json_data: JSON request body

        Returns:
            Parsed JSON response
        """
        return await self._request("POST", path, json_data=json_data)

    async def put(
        self,
        path: str,
        json_data: dict[str, Any] | None = None,
    ) -> dict[str, Any] | list[Any]:
        """
        Make a PUT request.

        Args:
            path: API endpoint path
            json_data: JSON request body

        Returns:
            Parsed JSON response
        """
        return await self._request("PUT", path, json_data=json_data)

    async def delete(self, path: str) -> dict[str, Any] | list[Any]:
        """
        Make a DELETE request.

        Args:
            path: API endpoint path

        Returns:
            Parsed JSON response ({} for 204 / empty bodies)
        """
        return await self._request("DELETE", path)

    async def close(self) -> None:
        """Close the underlying httpx.AsyncClient."""
        await self._client.aclose()

    async def __aenter__(self) -> "PocketSmithClient":
        """Async context manager entry."""
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: Any,
    ) -> None:
        """Async context manager exit; closes the HTTP client."""
        await self.close()

    def get_stats(self) -> dict[str, Any]:
        """
        Get client statistics.

        Returns:
            Dictionary with rate limiter and circuit breaker stats
        """
        return {
            "rate_limiter": {
                "available_tokens": self._rate_limiter.available_tokens,
                "max_tokens": self._rate_limiter.max_tokens,
            },
            "circuit_breaker": self._circuit_breaker.get_stats(),
        }
|
|
@@ -0,0 +1,179 @@
|
|
|
1
|
+
"""Circuit breaker pattern for fault tolerance."""
|
|
2
|
+
|
|
3
|
+
import time
|
|
4
|
+
from enum import Enum
|
|
5
|
+
from threading import Lock
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
from pocketsmith_mcp.logger import get_logger
|
|
9
|
+
|
|
10
|
+
logger = get_logger("circuit_breaker")
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class CircuitState(str, Enum):
    """
    Circuit breaker states.

    - CLOSED: normal operation
    - OPEN: blocking all calls
    - HALF_OPEN: testing whether the service recovered
    """

    CLOSED = "closed"
    OPEN = "open"
    HALF_OPEN = "half_open"
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class CircuitBreaker:
    """
    Circuit breaker for external service calls.

    Implements the circuit breaker pattern to prevent cascading failures
    when an external service is unhealthy.

    States:
    - CLOSED: Normal operation, all calls pass through
    - OPEN: Service is unhealthy, all calls fail immediately
    - HALF_OPEN: Testing if service recovered, limited calls allowed

    Callers should report the outcome of every call allowed by
    ``can_execute()`` via ``record_success()`` or ``record_failure()``.
    If the half-open probe slots are all consumed and no outcome arrives
    within ``reset_timeout_seconds`` of the last probe, the slots are
    reclaimed so the breaker cannot stay wedged in HALF_OPEN.

    All mutable state is guarded by a ``threading.Lock``.
    """

    def __init__(
        self,
        failure_threshold: int = 5,
        reset_timeout_seconds: float = 60.0,
        half_open_max_calls: int = 1,
    ):
        """
        Initialize the circuit breaker.

        Args:
            failure_threshold: Number of failures before opening circuit
            reset_timeout_seconds: Time to wait before testing recovery
            half_open_max_calls: Number of test calls allowed in half-open state

        Raises:
            ValueError: If any argument is out of range
        """
        if failure_threshold < 1:
            raise ValueError("failure_threshold must be at least 1")
        if reset_timeout_seconds <= 0:
            raise ValueError("reset_timeout_seconds must be positive")
        if half_open_max_calls < 1:
            raise ValueError("half_open_max_calls must be at least 1")

        self.failure_threshold = failure_threshold
        self.reset_timeout_seconds = reset_timeout_seconds
        self.half_open_max_calls = half_open_max_calls

        self._state = CircuitState.CLOSED
        self._failures = 0
        self._successes = 0
        self._last_failure_time: float = 0.0
        self._half_open_calls = 0
        # time.monotonic() when the most recent half-open probe was granted.
        self._last_probe_time: float = 0.0
        self._lock = Lock()

    @property
    def state(self) -> CircuitState:
        """Get the current circuit state (applying any time-based transition)."""
        with self._lock:
            self._check_state_transition()
            return self._state

    @property
    def failures(self) -> int:
        """Get the current failure count."""
        with self._lock:
            return self._failures

    @property
    def is_closed(self) -> bool:
        """Check if the circuit is closed (normal operation)."""
        return self.state == CircuitState.CLOSED

    @property
    def is_open(self) -> bool:
        """Check if the circuit is open (blocking calls)."""
        return self.state == CircuitState.OPEN

    def can_execute(self) -> bool:
        """
        Check if the circuit allows execution.

        In HALF_OPEN, each True result consumes one of the
        ``half_open_max_calls`` probe slots.

        Returns:
            True if a call can be made, False if blocked
        """
        with self._lock:
            self._check_state_transition()

            if self._state == CircuitState.CLOSED:
                return True

            if self._state == CircuitState.OPEN:
                return False

            # HALF_OPEN: Allow limited test calls
            if self._half_open_calls < self.half_open_max_calls:
                self._half_open_calls += 1
                self._last_probe_time = time.monotonic()
                return True
            return False

    def record_success(self) -> None:
        """Record a successful call; closes the circuit if it was half-open."""
        with self._lock:
            self._successes += 1

            if self._state == CircuitState.HALF_OPEN:
                # Service recovered, close the circuit
                logger.info("Circuit breaker: Service recovered, closing circuit")
                self._state = CircuitState.CLOSED

            # Reset failure count on success
            self._failures = 0
            self._half_open_calls = 0

    def record_failure(self) -> None:
        """Record a failed call; may open (or re-open) the circuit."""
        with self._lock:
            self._failures += 1
            self._last_failure_time = time.monotonic()

            if self._state == CircuitState.HALF_OPEN:
                # Test call failed, reopen circuit
                logger.warning("Circuit breaker: Test call failed, reopening circuit")
                self._state = CircuitState.OPEN
                return

            if self._failures >= self.failure_threshold:
                # Too many failures, open circuit
                logger.warning(
                    f"Circuit breaker: {self._failures} failures reached threshold, "
                    f"opening circuit for {self.reset_timeout_seconds}s"
                )
                self._state = CircuitState.OPEN

    def _check_state_transition(self) -> None:
        """
        Apply time-based transitions. Caller must hold ``self._lock``.

        - OPEN -> HALF_OPEN once reset_timeout_seconds have passed since the
          last failure.
        - HALF_OPEN with every probe slot consumed and no reported outcome
          for reset_timeout_seconds since the last probe: release the slots
          so a new probe can be attempted.
        """
        now = time.monotonic()
        if self._state == CircuitState.OPEN:
            if now - self._last_failure_time >= self.reset_timeout_seconds:
                logger.info("Circuit breaker: Reset timeout elapsed, entering half-open state")
                self._state = CircuitState.HALF_OPEN
                self._half_open_calls = 0
        elif (
            self._state == CircuitState.HALF_OPEN
            and self._half_open_calls >= self.half_open_max_calls
            and now - self._last_probe_time >= self.reset_timeout_seconds
        ):
            logger.warning(
                "Circuit breaker: Half-open probe result never reported, allowing a new probe"
            )
            self._half_open_calls = 0

    def reset(self) -> None:
        """Reset the circuit breaker to initial state."""
        with self._lock:
            self._state = CircuitState.CLOSED
            self._failures = 0
            self._successes = 0
            self._last_failure_time = 0.0
            self._half_open_calls = 0
            self._last_probe_time = 0.0
            logger.info("Circuit breaker: Reset to closed state")

    def force_open(self) -> None:
        """Force the circuit to open state (the reset timeout starts now)."""
        with self._lock:
            self._state = CircuitState.OPEN
            self._last_failure_time = time.monotonic()
            logger.warning("Circuit breaker: Forced to open state")

    def get_stats(self) -> dict[str, Any]:
        """Get circuit breaker statistics (raw state, no time-based transition)."""
        with self._lock:
            return {
                "state": self._state.value,
                "failures": self._failures,
                "successes": self._successes,
                "failure_threshold": self.failure_threshold,
                "reset_timeout_seconds": self.reset_timeout_seconds,
            }
|
|
@@ -0,0 +1,106 @@
|
|
|
1
|
+
"""Token bucket rate limiter for API calls."""
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import time
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class RateLimiter:
|
|
8
|
+
"""
|
|
9
|
+
Token bucket rate limiter with async support.
|
|
10
|
+
|
|
11
|
+
Implements a token bucket algorithm that allows a certain number of
|
|
12
|
+
requests per time interval. Tokens are refilled continuously based
|
|
13
|
+
on elapsed time.
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
def __init__(
|
|
17
|
+
self,
|
|
18
|
+
tokens_per_interval: int,
|
|
19
|
+
interval_seconds: float,
|
|
20
|
+
initial_tokens: int | None = None,
|
|
21
|
+
):
|
|
22
|
+
"""
|
|
23
|
+
Initialize the rate limiter.
|
|
24
|
+
|
|
25
|
+
Args:
|
|
26
|
+
tokens_per_interval: Number of tokens to add per interval
|
|
27
|
+
interval_seconds: Length of the interval in seconds
|
|
28
|
+
initial_tokens: Initial number of tokens (defaults to tokens_per_interval)
|
|
29
|
+
"""
|
|
30
|
+
if tokens_per_interval <= 0:
|
|
31
|
+
raise ValueError("tokens_per_interval must be positive")
|
|
32
|
+
if interval_seconds <= 0:
|
|
33
|
+
raise ValueError("interval_seconds must be positive")
|
|
34
|
+
|
|
35
|
+
self.tokens_per_interval = tokens_per_interval
|
|
36
|
+
self.interval_seconds = interval_seconds
|
|
37
|
+
self.tokens = float(initial_tokens if initial_tokens is not None else tokens_per_interval)
|
|
38
|
+
self.max_tokens = float(tokens_per_interval)
|
|
39
|
+
self.last_refill = time.monotonic()
|
|
40
|
+
self._lock = asyncio.Lock()
|
|
41
|
+
|
|
42
|
+
async def acquire(self, tokens: int = 1) -> None:
|
|
43
|
+
"""
|
|
44
|
+
Acquire tokens, waiting if necessary.
|
|
45
|
+
|
|
46
|
+
Args:
|
|
47
|
+
tokens: Number of tokens to acquire (default: 1)
|
|
48
|
+
|
|
49
|
+
Raises:
|
|
50
|
+
ValueError: If tokens requested exceeds max_tokens
|
|
51
|
+
"""
|
|
52
|
+
if tokens > self.max_tokens:
|
|
53
|
+
raise ValueError(f"Cannot acquire {tokens} tokens (max: {self.max_tokens})")
|
|
54
|
+
|
|
55
|
+
async with self._lock:
|
|
56
|
+
self._refill()
|
|
57
|
+
|
|
58
|
+
if self.tokens >= tokens:
|
|
59
|
+
self.tokens -= tokens
|
|
60
|
+
return
|
|
61
|
+
|
|
62
|
+
# Calculate wait time until we have enough tokens
|
|
63
|
+
tokens_needed = tokens - self.tokens
|
|
64
|
+
wait_time = (tokens_needed / self.tokens_per_interval) * self.interval_seconds
|
|
65
|
+
|
|
66
|
+
await asyncio.sleep(wait_time)
|
|
67
|
+
self._refill()
|
|
68
|
+
self.tokens -= tokens
|
|
69
|
+
|
|
70
|
+
def try_acquire(self, tokens: int = 1) -> bool:
|
|
71
|
+
"""
|
|
72
|
+
Try to acquire tokens without waiting.
|
|
73
|
+
|
|
74
|
+
Args:
|
|
75
|
+
tokens: Number of tokens to acquire (default: 1)
|
|
76
|
+
|
|
77
|
+
Returns:
|
|
78
|
+
True if tokens were acquired, False otherwise
|
|
79
|
+
"""
|
|
80
|
+
self._refill()
|
|
81
|
+
|
|
82
|
+
if self.tokens >= tokens:
|
|
83
|
+
self.tokens -= tokens
|
|
84
|
+
return True
|
|
85
|
+
return False
|
|
86
|
+
|
|
87
|
+
def _refill(self) -> None:
|
|
88
|
+
"""Refill tokens based on elapsed time."""
|
|
89
|
+
now = time.monotonic()
|
|
90
|
+
elapsed = now - self.last_refill
|
|
91
|
+
|
|
92
|
+
# Calculate tokens to add based on elapsed time
|
|
93
|
+
tokens_to_add = (elapsed / self.interval_seconds) * self.tokens_per_interval
|
|
94
|
+
self.tokens = min(self.max_tokens, self.tokens + tokens_to_add)
|
|
95
|
+
self.last_refill = now
|
|
96
|
+
|
|
97
|
+
@property
|
|
98
|
+
def available_tokens(self) -> float:
|
|
99
|
+
"""Get the current number of available tokens."""
|
|
100
|
+
self._refill()
|
|
101
|
+
return self.tokens
|
|
102
|
+
|
|
103
|
+
def reset(self) -> None:
|
|
104
|
+
"""Reset the rate limiter to full capacity."""
|
|
105
|
+
self.tokens = self.max_tokens
|
|
106
|
+
self.last_refill = time.monotonic()
|
|
@@ -0,0 +1,106 @@
|
|
|
1
|
+
"""Exponential backoff retry with jitter."""
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import random
|
|
5
|
+
from collections.abc import Awaitable, Callable
|
|
6
|
+
from typing import TypeVar
|
|
7
|
+
|
|
8
|
+
from pocketsmith_mcp.logger import get_logger
|
|
9
|
+
|
|
10
|
+
T = TypeVar("T")
|
|
11
|
+
logger = get_logger("retry")
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
async def retry_with_backoff(
    func: Callable[[], Awaitable[T]],
    max_attempts: int = 3,
    base_delay: float = 1.0,
    max_delay: float = 30.0,
    jitter_factor: float = 0.2,
    retryable_errors: tuple[type[Exception], ...] = (Exception,),
    on_retry: Callable[[Exception, int], None] | None = None,
) -> T:
    """
    Retry an async function with exponential backoff and jitter.

    The wait after failed attempt ``n`` (1-based) comes from
    ``calculate_delay(n, base_delay, max_delay, jitter_factor)``. Exceptions
    not listed in ``retryable_errors`` propagate immediately without retry.

    Args:
        func: Async function to retry (no arguments)
        max_attempts: Maximum number of attempts (default: 3)
        base_delay: Base delay in seconds (default: 1.0)
        max_delay: Maximum delay in seconds (default: 30.0)
        jitter_factor: Jitter factor (0.0-1.0) to randomize delay (default: 0.2)
        retryable_errors: Tuple of exception types to retry (default: all)
        on_retry: Optional callback called on each retry with (exception, attempt)

    Returns:
        Result of the function

    Raises:
        ValueError: If any of the tuning arguments is out of range
        The last exception if all retries fail
    """
    if max_attempts < 1:
        raise ValueError("max_attempts must be at least 1")
    if base_delay <= 0:
        raise ValueError("base_delay must be positive")
    if max_delay <= 0:
        raise ValueError("max_delay must be positive")
    if not 0 <= jitter_factor <= 1:
        raise ValueError("jitter_factor must be between 0 and 1")

    # Placeholder only; the loop runs at least once because max_attempts >= 1.
    last_error: Exception = Exception("No attempts made")

    for attempt in range(1, max_attempts + 1):
        try:
            return await func()
        except retryable_errors as e:
            last_error = e

            if attempt == max_attempts:
                logger.warning(
                    f"All {max_attempts} attempts failed. Last error: {e}"
                )
                break

            # Shared backoff formula: capped exponential plus random jitter.
            total_delay = calculate_delay(
                attempt,
                base_delay=base_delay,
                max_delay=max_delay,
                jitter_factor=jitter_factor,
            )

            logger.info(
                f"Attempt {attempt}/{max_attempts} failed: {e}. "
                f"Retrying in {total_delay:.2f}s"
            )

            if on_retry:
                on_retry(e, attempt)

            await asyncio.sleep(total_delay)

    raise last_error
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def calculate_delay(
    attempt: int,
    base_delay: float = 1.0,
    max_delay: float = 30.0,
    jitter_factor: float = 0.2,
) -> float:
    """
    Calculate the backoff delay for a given attempt number.

    The delay is ``min(base_delay * 2 ** (attempt - 1), max_delay)`` plus a
    random jitter in ``[0, delay * jitter_factor)``.

    Args:
        attempt: Current attempt number (1-based)
        base_delay: Base delay in seconds
        max_delay: Maximum delay in seconds
        jitter_factor: Jitter factor (0.0-1.0)

    Returns:
        Calculated delay in seconds
    """
    try:
        delay = min(base_delay * (2 ** (attempt - 1)), max_delay)
    except OverflowError:
        # For very large attempts (~1025+), 2 ** (attempt - 1) cannot be
        # converted to float. For any positive base_delay the exponential
        # term is far beyond max_delay by then, so the cap applies.
        delay = max_delay
    jitter = delay * jitter_factor * random.random()
    return delay + jitter
|