taipanstack 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- taipanstack/__init__.py +53 -0
- taipanstack/config/__init__.py +25 -0
- taipanstack/config/generators.py +357 -0
- taipanstack/config/models.py +316 -0
- taipanstack/config/version_config.py +227 -0
- taipanstack/core/__init__.py +47 -0
- taipanstack/core/compat.py +329 -0
- taipanstack/core/optimizations.py +392 -0
- taipanstack/core/result.py +199 -0
- taipanstack/security/__init__.py +55 -0
- taipanstack/security/decorators.py +369 -0
- taipanstack/security/guards.py +362 -0
- taipanstack/security/sanitizers.py +321 -0
- taipanstack/security/validators.py +342 -0
- taipanstack/utils/__init__.py +24 -0
- taipanstack/utils/circuit_breaker.py +268 -0
- taipanstack/utils/filesystem.py +417 -0
- taipanstack/utils/logging.py +328 -0
- taipanstack/utils/metrics.py +272 -0
- taipanstack/utils/retry.py +300 -0
- taipanstack/utils/subprocess.py +344 -0
- taipanstack-0.1.0.dist-info/METADATA +350 -0
- taipanstack-0.1.0.dist-info/RECORD +25 -0
- taipanstack-0.1.0.dist-info/WHEEL +4 -0
- taipanstack-0.1.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,342 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Input validators for type-safe validation.
|
|
3
|
+
|
|
4
|
+
Provides validation functions for common input types like email,
|
|
5
|
+
project names, URLs, etc. All validators raise ValueError on invalid input.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import re
|
|
11
|
+
from ipaddress import IPv4Address, IPv6Address, ip_address
|
|
12
|
+
from typing import Literal
|
|
13
|
+
from urllib.parse import urlparse
|
|
14
|
+
|
|
15
|
+
# Named constants instead of inline magic numbers (ruff PLR2004).

# Interpreter versions accepted by validate_python_version().
PYTHON_MAJOR_VERSION: int = 3
MIN_PYTHON_MINOR_VERSION: int = 10

# Maximum lengths, in characters, of the parts around "@" in validate_email().
MAX_EMAIL_LOCAL_LENGTH: int = 64
MAX_EMAIL_DOMAIN_LENGTH: int = 255

# Port bounds used by validate_port(); ports below the second are privileged.
MAX_PORT_NUMBER: int = 65535
MIN_PRIVILEGED_PORT: int = 1024

# Hosts that validate_url() accepts even though they have no TLD.
LOCALHOST_DOMAINS: tuple[str, ...] = ("localhost", "127.0.0.1", "::1")
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def validate_project_name(
    name: str,
    *,
    max_length: int = 100,
    allow_hyphen: bool = True,
    allow_underscore: bool = True,
) -> str:
    """Validate a project name.

    A valid name starts with an ASCII letter, continues with ASCII letters,
    digits and (if enabled) hyphens and/or underscores, and is not one of a
    fixed set of reserved names (compared case-insensitively).

    Args:
        name: The project name to validate.
        max_length: Maximum allowed length.
        allow_hyphen: Allow hyphens in name.
        allow_underscore: Allow underscores in name.

    Returns:
        The validated project name, unchanged.

    Raises:
        ValueError: If the name is empty, too long, does not start with a
            letter, contains disallowed characters, or is reserved.

    Example:
        >>> validate_project_name("my_project")
        'my_project'
        >>> validate_project_name("123project")
        Traceback (most recent call last):
            ...
        ValueError: Project name must start with a letter

    """
    if not name:
        msg = "Project name cannot be empty"
        raise ValueError(msg)

    if len(name) > max_length:
        msg = f"Project name exceeds maximum length of {max_length}"
        raise ValueError(msg)

    # Character class for everything after the first letter. The hyphen is
    # escaped so it is always a literal, never a range operator.
    allowed = r"a-zA-Z0-9"
    if allow_hyphen:
        allowed += r"\-"
    if allow_underscore:
        allowed += r"_"

    # fullmatch rather than match + "$": "$" also matches just before a
    # trailing newline, which let names such as "my_project\n" (or
    # "test\n", dodging the reserved check below) through.
    if re.fullmatch(f"[a-zA-Z][{allowed}]*", name) is None:
        first = name[0]
        # Mirror the pattern: only an ASCII letter may come first
        # (str.isalpha() alone is also true for e.g. "é").
        if not (first.isascii() and first.isalpha()):
            msg = "Project name must start with a letter"
            raise ValueError(msg)
        hyphen_msg = ", hyphens" if allow_hyphen else ""
        underscore_msg = ", underscores" if allow_underscore else ""
        msg = (
            "Project name contains invalid characters. "
            f"Allowed: letters, numbers{hyphen_msg}{underscore_msg}"
        )
        raise ValueError(msg)

    # Names rejected regardless of case.
    reserved = {
        "test",
        "tests",
        "src",
        "lib",
        "bin",
        "build",
        "dist",
        "setup",
        "config",
        "settings",
        "core",
        "main",
        "app",
        "site-packages",
    }

    if name.lower() in reserved:
        msg = f"Project name '{name}' is reserved"
        raise ValueError(msg)

    return name
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
def validate_python_version(version: str) -> str:
    """Validate Python version string.

    The string must be exactly ``MAJOR.MINOR`` using ASCII digits, with
    ``MAJOR == PYTHON_MAJOR_VERSION`` and ``MINOR >= MIN_PYTHON_MINOR_VERSION``.

    Args:
        version: Version string like "3.12" or "3.10".

    Returns:
        The validated version string, unchanged.

    Raises:
        ValueError: If version format is invalid or unsupported.

    """
    # fullmatch + ASCII digits: "$" would also accept a trailing newline
    # ("3.12\n" was returned as-is), and "\d" accepts non-ASCII digits.
    if re.fullmatch(r"[0-9]+\.[0-9]+", version) is None:
        msg = f"Invalid version format: '{version}'. Use 'X.Y' format (e.g., '3.12')"
        raise ValueError(msg)

    # Safe: the pattern guarantees exactly one "." between two digit runs.
    major_text, minor_text = version.split(".")
    major, minor = int(major_text), int(minor_text)

    if major != PYTHON_MAJOR_VERSION:
        msg = f"Only Python 3.x is supported, got {major}.x"
        raise ValueError(msg)

    if minor < MIN_PYTHON_MINOR_VERSION:
        msg = (
            f"Python 3.{minor} is not supported. "
            f"Minimum is 3.{MIN_PYTHON_MINOR_VERSION}"
        )
        raise ValueError(msg)

    return version
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
def validate_email(email: str) -> str:
    """Validate email address format.

    Uses a pragmatic pattern that covers common addresses without being
    overly strict. The local part is an unquoted run of letters, digits and
    ``._%+-``, followed by a single ``@``. The domain ends in an alphabetic TLD
    of at least two letters. Quoted local parts and IP-literal domains are
    not supported.

    Args:
        email: The email address to validate.

    Returns:
        The validated email address, unchanged.

    Raises:
        ValueError: If the address is empty, malformed (including leading,
            trailing or doubled dots), or exceeds the length limits.

    """
    if not email:
        msg = "Email cannot be empty"
        raise ValueError(msg)

    # RFC 5322 compliant pattern (simplified). fullmatch rather than
    # match + "$": "$" also matches just before a trailing newline, so
    # "user@example.com\n" used to pass, which is unsafe if the address is
    # later written into a mail header.
    pattern = r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}"

    if re.fullmatch(pattern, email) is None:
        msg = f"Invalid email format: {email}"
        raise ValueError(msg)

    # Neither character class contains "@", so there is exactly one.
    local, domain = email.rsplit("@", 1)

    # Unquoted local parts and domain names are dot-separated atoms/labels;
    # a leading, trailing or doubled dot leaves an empty one.
    if "" in local.split(".") or "" in domain.split("."):
        msg = f"Invalid email format: {email}"
        raise ValueError(msg)

    if len(local) > MAX_EMAIL_LOCAL_LENGTH:
        msg = f"Email local part exceeds {MAX_EMAIL_LOCAL_LENGTH} characters"
        raise ValueError(msg)

    if len(domain) > MAX_EMAIL_DOMAIN_LENGTH:
        msg = f"Email domain exceeds {MAX_EMAIL_DOMAIN_LENGTH} characters"
        raise ValueError(msg)

    return email
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
def validate_url(
    url: str,
    *,
    allowed_schemes: tuple[str, ...] = ("http", "https"),
    require_tld: bool = True,
) -> str:
    """Validate URL format and scheme.

    Args:
        url: The URL to validate.
        allowed_schemes: Tuple of allowed URL schemes (compared against the
            scheme as returned by ``urlparse``, which is lowercase).
        require_tld: Whether to require a TLD (a dot, not at the end) in the
            host name. Hosts listed in ``LOCALHOST_DOMAINS`` are exempt.

    Returns:
        The validated URL, unchanged.

    Raises:
        ValueError: If URL format is invalid, the scheme is not allowed, no
            host is present, or a required TLD is missing.

    """
    if not url:
        msg = "URL cannot be empty"
        raise ValueError(msg)

    try:
        parsed = urlparse(url)
        # .hostname strips userinfo ("user:pass@"), the port and IPv6
        # brackets, and lowercases the result. Splitting netloc on ":"
        # turned "user:pass@example.com" into "user" and "[::1]" into "[".
        hostname = parsed.hostname
    except ValueError as e:
        msg = f"Invalid URL format: {e}"
        raise ValueError(msg) from e

    if not parsed.scheme:
        msg = "URL must have a scheme (e.g., https://)"
        raise ValueError(msg)

    if parsed.scheme not in allowed_schemes:
        msg = f"URL scheme '{parsed.scheme}' is not allowed. Allowed: {allowed_schemes}"
        raise ValueError(msg)

    # A netloc such as ":80" or "user@" is non-empty but names no host.
    if not parsed.netloc or not hostname:
        msg = "URL must have a domain"
        raise ValueError(msg)

    if require_tld:
        # Check for TLD (at least one dot, not trailing)
        has_no_tld = "." not in hostname or hostname.endswith(".")
        if has_no_tld and hostname not in LOCALHOST_DOMAINS:
            msg = f"URL domain must have a TLD: {hostname}"
            raise ValueError(msg)

    return url
|
|
239
|
+
|
|
240
|
+
|
|
241
|
+
def validate_ip_address(
|
|
242
|
+
ip: str,
|
|
243
|
+
*,
|
|
244
|
+
version: Literal["v4", "v6", "any"] = "any",
|
|
245
|
+
allow_private: bool = True,
|
|
246
|
+
) -> IPv4Address | IPv6Address:
|
|
247
|
+
"""Validate IP address.
|
|
248
|
+
|
|
249
|
+
Args:
|
|
250
|
+
ip: The IP address string to validate.
|
|
251
|
+
version: IP version to allow ('v4', 'v6', or 'any').
|
|
252
|
+
allow_private: Whether to allow private/internal IPs.
|
|
253
|
+
|
|
254
|
+
Returns:
|
|
255
|
+
The validated IP address object.
|
|
256
|
+
|
|
257
|
+
Raises:
|
|
258
|
+
ValueError: If IP address is invalid.
|
|
259
|
+
|
|
260
|
+
"""
|
|
261
|
+
try:
|
|
262
|
+
addr = ip_address(ip)
|
|
263
|
+
except ValueError as e:
|
|
264
|
+
msg = f"Invalid IP address: {ip}"
|
|
265
|
+
raise ValueError(msg) from e
|
|
266
|
+
|
|
267
|
+
if version == "v4" and not isinstance(addr, IPv4Address):
|
|
268
|
+
msg = f"Expected IPv4 address, got IPv6: {ip}"
|
|
269
|
+
raise ValueError(msg)
|
|
270
|
+
|
|
271
|
+
if version == "v6" and not isinstance(addr, IPv6Address):
|
|
272
|
+
msg = f"Expected IPv6 address, got IPv4: {ip}"
|
|
273
|
+
raise ValueError(msg)
|
|
274
|
+
|
|
275
|
+
if not allow_private and addr.is_private:
|
|
276
|
+
msg = f"Private IP addresses are not allowed: {ip}"
|
|
277
|
+
raise ValueError(msg)
|
|
278
|
+
|
|
279
|
+
return addr
|
|
280
|
+
|
|
281
|
+
|
|
282
|
+
def validate_port(
|
|
283
|
+
port: int | str,
|
|
284
|
+
*,
|
|
285
|
+
allow_privileged: bool = False,
|
|
286
|
+
) -> int:
|
|
287
|
+
"""Validate port number.
|
|
288
|
+
|
|
289
|
+
Args:
|
|
290
|
+
port: The port number to validate.
|
|
291
|
+
allow_privileged: Whether to allow ports below 1024.
|
|
292
|
+
|
|
293
|
+
Returns:
|
|
294
|
+
The validated port number.
|
|
295
|
+
|
|
296
|
+
Raises:
|
|
297
|
+
ValueError: If port is invalid.
|
|
298
|
+
|
|
299
|
+
"""
|
|
300
|
+
try:
|
|
301
|
+
port_int = int(port)
|
|
302
|
+
except ValueError as e:
|
|
303
|
+
msg = f"Invalid port number: {port}"
|
|
304
|
+
raise ValueError(msg) from e
|
|
305
|
+
|
|
306
|
+
if port_int < 0 or port_int > MAX_PORT_NUMBER:
|
|
307
|
+
msg = f"Port must be between 0 and {MAX_PORT_NUMBER}: {port_int}"
|
|
308
|
+
raise ValueError(msg)
|
|
309
|
+
|
|
310
|
+
if not allow_privileged and port_int < MIN_PRIVILEGED_PORT:
|
|
311
|
+
msg = f"Privileged ports (< {MIN_PRIVILEGED_PORT}) are not allowed: {port_int}"
|
|
312
|
+
raise ValueError(msg)
|
|
313
|
+
|
|
314
|
+
return port_int
|
|
315
|
+
|
|
316
|
+
|
|
317
|
+
def validate_semver(version: str) -> tuple[int, int, int]:
    """Validate semantic version string.

    Accepts ``MAJOR.MINOR.PATCH`` (ASCII digits) with optional
    ``-prerelease`` and ``+build`` suffixes. Each suffix is a dot-separated
    list of non-empty identifiers made of ASCII letters, digits and hyphens,
    as in SemVer 2.0.0. A single leading ``v`` or ``V`` is allowed. Leading
    zeros in the numeric parts are tolerated.

    Args:
        version: Version string like "1.2.3" or "v1.2.3".

    Returns:
        Tuple of (major, minor, patch).

    Raises:
        ValueError: If version format is invalid.

    """
    # Remove a single leading 'v'/'V'; lstrip("vV") also accepted "vv1.2.3".
    if version[:1] in ("v", "V"):
        version = version[1:]

    identifier = r"[0-9A-Za-z-]+"
    dotted = rf"{identifier}(?:\.{identifier})*"
    pattern = rf"([0-9]+)\.([0-9]+)\.([0-9]+)(?:-{dotted})?(?:\+{dotted})?"

    # fullmatch: with re.match, "$" would also accept a trailing newline.
    found = re.fullmatch(pattern, version)

    if found is None:
        msg = f"Invalid semantic version: {version}. Expected format: X.Y.Z"
        raise ValueError(msg)

    # Only the three numeric parts are capturing groups.
    major, minor, patch = (int(part) for part in found.groups())

    return major, minor, patch
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
"""Utility modules for Stack."""
|
|
2
|
+
|
|
3
|
+
from taipanstack.utils.filesystem import ensure_dir, safe_read, safe_write
|
|
4
|
+
from taipanstack.utils.logging import get_logger, setup_logging
|
|
5
|
+
from taipanstack.utils.retry import Retrier, RetryConfig, RetryError, retry
|
|
6
|
+
from taipanstack.utils.subprocess import SafeCommandResult, run_safe_command
|
|
7
|
+
|
|
8
|
+
__all__ = [
    # Kept in sorted order; the trailing comment names the submodule each
    # symbol is re-exported from.
    "Retrier",  # retry
    "RetryConfig",  # retry
    "RetryError",  # retry
    "SafeCommandResult",  # subprocess
    "ensure_dir",  # filesystem
    "get_logger",  # logging
    "retry",  # retry
    "run_safe_command",  # subprocess
    "safe_read",  # filesystem
    "safe_write",  # filesystem
    "setup_logging",  # logging
]
|
|
@@ -0,0 +1,268 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Circuit Breaker pattern implementation.
|
|
3
|
+
|
|
4
|
+
Provides protection against cascading failures by temporarily
|
|
5
|
+
blocking calls to a failing service. Compatible with any
|
|
6
|
+
Python framework.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from __future__ import annotations
|
|
10
|
+
|
|
11
|
+
import functools
import inspect
import logging
import threading
import time
from collections.abc import Callable
from dataclasses import dataclass, field
from enum import Enum
from typing import ParamSpec, TypeVar
|
|
19
|
+
|
|
20
|
+
P = ParamSpec("P")
|
|
21
|
+
R = TypeVar("R")
|
|
22
|
+
|
|
23
|
+
logger = logging.getLogger("taipanstack.utils.circuit_breaker")
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class CircuitState(Enum):
    """States of the circuit breaker.

    ``CircuitBreaker`` moves between them as follows:

    - CLOSED -> OPEN once ``failure_threshold`` failures have been recorded.
    - OPEN -> HALF_OPEN on the first call made at least ``timeout`` seconds
      after the last recorded failure.
    - HALF_OPEN -> CLOSED after ``success_threshold`` successes.
    - HALF_OPEN -> OPEN on any recorded failure.

    ``CircuitBreaker.reset()`` forces CLOSED from any state.
    """

    CLOSED = "closed"  # Normal operation, requests flow through
    OPEN = "open"  # Circuit is tripped, requests are blocked
    HALF_OPEN = "half_open"  # Testing if service has recovered
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class CircuitBreakerError(Exception):
    """Signal that a call was rejected because the circuit breaker is open.

    The circuit state observed when the call was rejected is exposed as
    ``state`` so handlers can inspect it without querying the breaker again.
    """

    def __init__(self, message: str, state: CircuitState) -> None:
        """Build the error.

        Args:
            message: Human-readable description; becomes the exception text.
            state: Circuit state at the moment the call was rejected.

        """
        super().__init__(message)
        self.state = state
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
@dataclass
class CircuitBreakerConfig:
    """Configuration for circuit breaker behavior.

    Values are stored as given; this class performs no range validation.

    Attributes:
        failure_threshold: Number of failures before opening circuit.
            Failures are counted while closed, and any success while closed
            resets the count.
        success_threshold: Successes needed in half-open to close.
        timeout: Seconds, measured from the last recorded failure, before an
            open circuit lets a trial call through (entering half-open).
        excluded_exceptions: Exceptions that don't count as failures. They
            still propagate to the caller but neither trip nor reset the
            circuit.

    """

    failure_threshold: int = 5  # failures (while closed) that open the circuit
    success_threshold: int = 2  # half-open successes that close the circuit
    timeout: float = 30.0  # seconds an open circuit waits before half-open
    excluded_exceptions: tuple[type[Exception], ...] = ()  # ignored exception types
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
@dataclass
class CircuitBreakerState:
    """Internal state tracking for circuit breaker.

    Every mutation performed by ``CircuitBreaker`` happens while holding
    ``lock``.
    """

    # Current position in the CLOSED / OPEN / HALF_OPEN cycle.
    state: CircuitState = CircuitState.CLOSED
    # Failures recorded since the last reset (a success while CLOSED resets it).
    failure_count: int = 0
    # Successes recorded since the circuit last entered HALF_OPEN.
    success_count: int = 0
    # time.monotonic() value of the most recently recorded failure.
    last_failure_time: float = 0.0
    # Guards the fields above. Excluded from repr/eq: the lock is not part of
    # the logical state, and since locks only compare equal to themselves,
    # including it made two otherwise-identical states always unequal.
    lock: threading.Lock = field(
        default_factory=threading.Lock,
        repr=False,
        compare=False,
    )
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
class CircuitBreaker:
|
|
79
|
+
"""Circuit breaker implementation.
|
|
80
|
+
|
|
81
|
+
Monitors function calls and opens the circuit when too many
|
|
82
|
+
failures occur, preventing further calls until the service
|
|
83
|
+
recovers.
|
|
84
|
+
|
|
85
|
+
Example:
|
|
86
|
+
>>> breaker = CircuitBreaker(failure_threshold=3)
|
|
87
|
+
>>> @breaker
|
|
88
|
+
... def call_external_api():
|
|
89
|
+
... return requests.get("https://api.example.com")
|
|
90
|
+
|
|
91
|
+
"""
|
|
92
|
+
|
|
93
|
+
def __init__(
|
|
94
|
+
self,
|
|
95
|
+
*,
|
|
96
|
+
failure_threshold: int = 5,
|
|
97
|
+
success_threshold: int = 2,
|
|
98
|
+
timeout: float = 30.0,
|
|
99
|
+
excluded_exceptions: tuple[type[Exception], ...] = (),
|
|
100
|
+
name: str = "default",
|
|
101
|
+
) -> None:
|
|
102
|
+
"""Initialize CircuitBreaker.
|
|
103
|
+
|
|
104
|
+
Args:
|
|
105
|
+
failure_threshold: Failures before opening circuit.
|
|
106
|
+
success_threshold: Successes to close from half-open.
|
|
107
|
+
timeout: Seconds before attempting half-open.
|
|
108
|
+
excluded_exceptions: Exceptions that don't trip circuit.
|
|
109
|
+
name: Name for logging/identification.
|
|
110
|
+
|
|
111
|
+
"""
|
|
112
|
+
self.config = CircuitBreakerConfig(
|
|
113
|
+
failure_threshold=failure_threshold,
|
|
114
|
+
success_threshold=success_threshold,
|
|
115
|
+
timeout=timeout,
|
|
116
|
+
excluded_exceptions=excluded_exceptions,
|
|
117
|
+
)
|
|
118
|
+
self.name = name
|
|
119
|
+
self._state = CircuitBreakerState()
|
|
120
|
+
|
|
121
|
+
@property
|
|
122
|
+
def state(self) -> CircuitState:
|
|
123
|
+
"""Get current circuit state."""
|
|
124
|
+
return self._state.state
|
|
125
|
+
|
|
126
|
+
@property
|
|
127
|
+
def failure_count(self) -> int:
|
|
128
|
+
"""Get current failure count."""
|
|
129
|
+
return self._state.failure_count
|
|
130
|
+
|
|
131
|
+
def _should_attempt(self) -> bool:
|
|
132
|
+
"""Check if a call should be attempted."""
|
|
133
|
+
with self._state.lock:
|
|
134
|
+
match self._state.state:
|
|
135
|
+
case CircuitState.CLOSED:
|
|
136
|
+
return True
|
|
137
|
+
|
|
138
|
+
case CircuitState.OPEN:
|
|
139
|
+
# Check if timeout has passed
|
|
140
|
+
elapsed = time.monotonic() - self._state.last_failure_time
|
|
141
|
+
if elapsed >= self.config.timeout:
|
|
142
|
+
self._state.state = CircuitState.HALF_OPEN
|
|
143
|
+
self._state.success_count = 0
|
|
144
|
+
logger.info("Circuit %s entering half-open state", self.name)
|
|
145
|
+
return True
|
|
146
|
+
return False
|
|
147
|
+
|
|
148
|
+
case CircuitState.HALF_OPEN:
|
|
149
|
+
# Allow limited attempts
|
|
150
|
+
return True
|
|
151
|
+
|
|
152
|
+
def _record_success(self) -> None:
|
|
153
|
+
"""Record a successful call."""
|
|
154
|
+
with self._state.lock:
|
|
155
|
+
match self._state.state:
|
|
156
|
+
case CircuitState.HALF_OPEN:
|
|
157
|
+
self._state.success_count += 1
|
|
158
|
+
if self._state.success_count >= self.config.success_threshold:
|
|
159
|
+
self._state.state = CircuitState.CLOSED
|
|
160
|
+
self._state.failure_count = 0
|
|
161
|
+
logger.info("Circuit %s closed after recovery", self.name)
|
|
162
|
+
|
|
163
|
+
case CircuitState.CLOSED:
|
|
164
|
+
# Reset failure count on success
|
|
165
|
+
self._state.failure_count = 0
|
|
166
|
+
|
|
167
|
+
case CircuitState.OPEN:
|
|
168
|
+
pass # Should not happen, but handle gracefully
|
|
169
|
+
|
|
170
|
+
def _record_failure(self, exc: Exception) -> None:
|
|
171
|
+
"""Record a failed call."""
|
|
172
|
+
# Check if exception should be excluded
|
|
173
|
+
if isinstance(exc, self.config.excluded_exceptions):
|
|
174
|
+
return
|
|
175
|
+
|
|
176
|
+
with self._state.lock:
|
|
177
|
+
self._state.failure_count += 1
|
|
178
|
+
self._state.last_failure_time = time.monotonic()
|
|
179
|
+
|
|
180
|
+
match self._state.state:
|
|
181
|
+
case CircuitState.HALF_OPEN:
|
|
182
|
+
# Any failure in half-open reopens circuit
|
|
183
|
+
self._state.state = CircuitState.OPEN
|
|
184
|
+
logger.warning(
|
|
185
|
+
"Circuit %s reopened after failure in half-open",
|
|
186
|
+
self.name,
|
|
187
|
+
)
|
|
188
|
+
|
|
189
|
+
case CircuitState.CLOSED:
|
|
190
|
+
if self._state.failure_count >= self.config.failure_threshold:
|
|
191
|
+
self._state.state = CircuitState.OPEN
|
|
192
|
+
logger.warning(
|
|
193
|
+
"Circuit %s opened after %d failures",
|
|
194
|
+
self.name,
|
|
195
|
+
self._state.failure_count,
|
|
196
|
+
)
|
|
197
|
+
|
|
198
|
+
case CircuitState.OPEN:
|
|
199
|
+
pass # Already open, nothing to do
|
|
200
|
+
|
|
201
|
+
def reset(self) -> None:
|
|
202
|
+
"""Reset circuit breaker to closed state."""
|
|
203
|
+
with self._state.lock:
|
|
204
|
+
self._state.state = CircuitState.CLOSED
|
|
205
|
+
self._state.failure_count = 0
|
|
206
|
+
self._state.success_count = 0
|
|
207
|
+
logger.info("Circuit %s manually reset", self.name)
|
|
208
|
+
|
|
209
|
+
def __call__(self, func: Callable[P, R]) -> Callable[P, R]:
|
|
210
|
+
"""Decorate a function with circuit breaker protection."""
|
|
211
|
+
|
|
212
|
+
@functools.wraps(func)
|
|
213
|
+
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
|
214
|
+
if not self._should_attempt():
|
|
215
|
+
raise CircuitBreakerError(
|
|
216
|
+
f"Circuit {self.name} is open",
|
|
217
|
+
state=self._state.state,
|
|
218
|
+
)
|
|
219
|
+
|
|
220
|
+
try:
|
|
221
|
+
result = func(*args, **kwargs)
|
|
222
|
+
self._record_success()
|
|
223
|
+
return result
|
|
224
|
+
except Exception as e:
|
|
225
|
+
self._record_failure(e)
|
|
226
|
+
raise
|
|
227
|
+
|
|
228
|
+
return wrapper
|
|
229
|
+
|
|
230
|
+
|
|
231
|
+
def circuit_breaker(
    *,
    failure_threshold: int = 5,
    success_threshold: int = 2,
    timeout: float = 30.0,
    excluded_exceptions: tuple[type[Exception], ...] = (),
    name: str | None = None,
) -> Callable[[Callable[P, R]], Callable[P, R]]:
    """Decorator to apply circuit breaker pattern.

    Each decorated function gets its own :class:`CircuitBreaker`, so
    failures of one function never open another function's circuit.

    Args:
        failure_threshold: Failures before opening circuit.
        success_threshold: Successes to close from half-open.
        timeout: Seconds before attempting half-open.
        excluded_exceptions: Exceptions that don't trip circuit.
        name: Optional name for the circuit, used in log messages and in the
            CircuitBreakerError message. Defaults to the decorated callable's
            ``__name__``, or its type name for callables that have none
            (e.g. ``functools.partial`` objects or callable instances).

    Returns:
        A decorator that wraps a function with circuit breaker protection.

    Example:
        >>> @circuit_breaker(failure_threshold=3, timeout=60)
        ... def call_api(endpoint: str) -> dict:
        ...     return requests.get(endpoint).json()

    """

    def decorator(func: Callable[P, R]) -> Callable[P, R]:
        # functools.partial objects and callable instances have no __name__;
        # accessing it directly raised AttributeError at decoration time.
        default_name = getattr(func, "__name__", type(func).__name__)
        breaker = CircuitBreaker(
            failure_threshold=failure_threshold,
            success_threshold=success_threshold,
            timeout=timeout,
            excluded_exceptions=excluded_exceptions,
            name=name or default_name,
        )
        return breaker(func)

    return decorator
|