mlenvdoctor 0.1.0__py3-none-any.whl → 0.1.2__py3-none-any.whl
- mlenvdoctor/__init__.py +15 -1
- mlenvdoctor/cli.py +80 -30
- mlenvdoctor/config.py +169 -0
- mlenvdoctor/constants.py +63 -0
- mlenvdoctor/diagnose.py +146 -46
- mlenvdoctor/dockerize.py +3 -6
- mlenvdoctor/exceptions.py +51 -0
- mlenvdoctor/export.py +290 -0
- mlenvdoctor/fix.py +19 -13
- mlenvdoctor/gpu.py +15 -9
- mlenvdoctor/icons.py +100 -0
- mlenvdoctor/logger.py +81 -0
- mlenvdoctor/parallel.py +115 -0
- mlenvdoctor/retry.py +92 -0
- mlenvdoctor/utils.py +79 -22
- mlenvdoctor/validators.py +217 -0
- {mlenvdoctor-0.1.0.dist-info → mlenvdoctor-0.1.2.dist-info}/METADATA +3 -2
- mlenvdoctor-0.1.2.dist-info/RECORD +21 -0
- mlenvdoctor-0.1.0.dist-info/RECORD +0 -12
- {mlenvdoctor-0.1.0.dist-info → mlenvdoctor-0.1.2.dist-info}/WHEEL +0 -0
- {mlenvdoctor-0.1.0.dist-info → mlenvdoctor-0.1.2.dist-info}/entry_points.txt +0 -0
- {mlenvdoctor-0.1.0.dist-info → mlenvdoctor-0.1.2.dist-info}/licenses/LICENSE +0 -0
mlenvdoctor/fix.py
CHANGED
@@ -1,14 +1,11 @@
 """Auto-fix and requirements generation for ML Environment Doctor."""

-import subprocess
 import sys
 from pathlib import Path
-from typing import
+from typing import Optional

-from
-from
-
-from .diagnose import DiagnosticIssue, diagnose_env
+from .diagnose import diagnose_env
+from .icons import icon_wrench
 from .utils import (
     check_command_exists,
     console,
@@ -39,7 +36,9 @@ ML_STACKS = {
 }


-def generate_requirements_txt(
+def generate_requirements_txt(
+    stack: str = "trl-peft", output_file: str = "requirements-mlenvdoctor.txt"
+) -> Path:
     """Generate requirements.txt file."""
     if stack not in ML_STACKS:
         print_error(f"Unknown stack: {stack}. Available: {list(ML_STACKS.keys())}")
@@ -61,7 +60,9 @@ def generate_requirements_txt(stack: str = "trl-peft", output_file: str = "requi
         content = "# Standard PyTorch (CPU or CUDA)\n\n"
     except ImportError:
         content = "# PyTorch installation\n"
-        content +=
+        content += (
+            "# For CUDA: pip install torch --index-url https://download.pytorch.org/whl/cu124\n"
+        )
         content += "# For CPU: pip install torch\n\n"

     content += "\n".join(requirements)
@@ -72,7 +73,9 @@ def generate_requirements_txt(stack: str = "trl-peft", output_file: str = "requi
     return output_path


-def generate_conda_env(
+def generate_conda_env(
+    stack: str = "trl-peft", output_file: str = "environment-mlenvdoctor.yml"
+) -> Path:
     """Generate conda environment file."""
     if stack not in ML_STACKS:
         print_error(f"Unknown stack: {stack}. Available: {list(ML_STACKS.keys())}")
@@ -197,7 +200,11 @@ def create_virtualenv(env_name: str = ".venv") -> Optional[Path]:

         venv.create(env_path, with_pip=True)
         print_success(f"Virtual environment created: {env_name}")
-
+        if sys.platform == "win32":
+            activate_cmd = r".venv\Scripts\activate"
+        else:
+            activate_cmd = "source .venv/bin/activate"
+        print_info(f"Activate with: {activate_cmd}")
         return env_path
     except Exception as e:
         print_error(f"Failed to create virtual environment: {e}")
@@ -206,7 +213,7 @@ def create_virtualenv(env_name: str = ".venv") -> Optional[Path]:

 def auto_fix(use_conda: bool = False, create_venv: bool = False, stack: str = "trl-peft") -> bool:
     """Auto-fix environment issues based on diagnostics."""
-    console.print("[bold blue]
+    console.print(f"[bold blue]{icon_wrench()} Running Auto-Fix...[/bold blue]\n")

     # Run diagnostics
     issues = diagnose_env(full=False)
@@ -243,7 +250,6 @@ def auto_fix(use_conda: bool = False, create_venv: bool = False, stack: str = "t
         if install.lower() in ["y", "yes"]:
             return install_requirements(str(req_file), use_conda=use_conda)
         else:
-            print_info(
+            print_info("Requirements file generated. Install manually with:")
             console.print(f"[cyan] pip install -r {req_file}[/cyan]")
             return True
-
mlenvdoctor/gpu.py
CHANGED
@@ -1,19 +1,24 @@
 """GPU benchmarks and smoke tests for ML Environment Doctor."""

 import time
-from typing import Dict
+from typing import Dict

 try:
     import torch
 except ImportError:
     torch = None  # type: ignore

+import sys
+
 from rich.console import Console
-from rich.progress import Progress, SpinnerColumn, TextColumn

 from .utils import print_error, print_info, print_success

-console
+# Configure console for Windows compatibility
+if sys.platform == "win32":
+    console = Console(legacy_windows=True, force_terminal=True)
+else:
+    console = Console()


 def benchmark_gpu_ops() -> Dict[str, float]:
@@ -78,7 +83,8 @@ def smoke_test_lora() -> bool:
         if tokenizer.pad_token is None:
             tokenizer.pad_token = tokenizer.eos_token
         model = AutoModelForCausalLM.from_pretrained(
-            model_name,
+            model_name,
+            torch_dtype=torch.float16 if device.type == "cuda" else torch.float32,
         ).to(device)

         # Configure LoRA
@@ -101,8 +107,7 @@ def smoke_test_lora() -> bool:
         # Forward pass
         with console.status("[bold green]Running forward pass..."):
             with torch.no_grad():
-
-                loss = outputs.loss if hasattr(outputs, "loss") else None
+                _ = model(**inputs)

         print_success("LoRA smoke test passed!")
         return True
@@ -146,7 +151,9 @@ def test_model(model_name: str = "tinyllama") -> bool:
         # Estimate memory requirements (rough)
         if "7b" in actual_model_name.lower() or "7B" in actual_model_name:
             if free_gb < 16:
-                print_error(
+                print_error(
+                    f"Insufficient GPU memory: {free_gb:.1f}GB free, need ~16GB for 7B model"
+                )
                 return False

         with console.status(f"[bold green]Loading {actual_model_name}..."):
@@ -166,7 +173,7 @@ def test_model(model_name: str = "tinyllama") -> bool:
         inputs = tokenizer(dummy_text, return_tensors="pt").to(device)

         with torch.no_grad():
-
+            _ = model(**inputs)

         print_success(f"Model {actual_model_name} loaded and tested successfully!")
         return True
@@ -181,4 +188,3 @@ def test_model(model_name: str = "tinyllama") -> bool:
     except Exception as e:
         print_error(f"Model test error: {e}")
         return False
-
mlenvdoctor/icons.py
ADDED
@@ -0,0 +1,100 @@
+"""Safe emoji/icon handling for cross-platform compatibility."""
+
+import sys
+from typing import Literal
+
+# Check if we can safely use emojis
+_USE_EMOJIS = True
+
+try:
+    # Try to write an emoji to see if it works
+    if sys.platform == "win32":
+        import io
+
+        # Test if console supports UTF-8
+        test_output = io.StringIO()
+        try:
+            test_output.write("🔍")
+            test_output.getvalue()
+        except (UnicodeEncodeError, UnicodeError):
+            _USE_EMOJIS = False
+except Exception:
+    _USE_EMOJIS = False
+
+
+def get_icon(icon_name: Literal["search", "check", "cross", "warning", "info", "wrench", "whale", "test"]) -> str:
+    """
+    Get a safe icon/emoji for the current platform.
+
+    Args:
+        icon_name: Name of the icon to get
+
+    Returns:
+        Emoji if supported, ASCII alternative otherwise
+    """
+    if _USE_EMOJIS:
+        icons = {
+            "search": "🔍",
+            "check": "✅",
+            "cross": "❌",
+            "warning": "⚠️",
+            "info": "ℹ️",
+            "wrench": "🔧",
+            "whale": "🐳",
+            "test": "🧪",
+        }
+    else:
+        # ASCII alternatives
+        icons = {
+            "search": "[*]",
+            "check": "[OK]",
+            "cross": "[X]",
+            "warning": "[!]",
+            "info": "[i]",
+            "wrench": "[FIX]",
+            "whale": "[DOCKER]",
+            "test": "[TEST]",
+        }
+
+    return icons.get(icon_name, "")
+
+
+# Convenience functions
+def icon_search() -> str:
+    """Get search icon."""
+    return get_icon("search")
+
+
+def icon_check() -> str:
+    """Get check icon."""
+    return get_icon("check")
+
+
+def icon_cross() -> str:
+    """Get cross/error icon."""
+    return get_icon("cross")
+
+
+def icon_warning() -> str:
+    """Get warning icon."""
+    return get_icon("warning")
+
+
+def icon_info() -> str:
+    """Get info icon."""
+    return get_icon("info")
+
+
+def icon_wrench() -> str:
+    """Get wrench/fix icon."""
+    return get_icon("wrench")
+
+
+def icon_whale() -> str:
+    """Get whale/docker icon."""
+    return get_icon("whale")
+
+
+def icon_test() -> str:
+    """Get test icon."""
+    return get_icon("test")
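The module resolves emoji support once at import time, so callers just ask for an icon. A minimal usage sketch, assuming the package is importable as `mlenvdoctor` (the `Console` setup here is illustrative, not part of this file):

    # Hypothetical usage sketch; mirrors the auto_fix() banner change in fix.py above.
    from rich.console import Console

    from mlenvdoctor.icons import icon_check, icon_wrench

    console = Console()
    console.print(f"[bold blue]{icon_wrench()} Running Auto-Fix...[/bold blue]")
    console.print(f"{icon_check()} Done")  # renders as "[OK] Done" where emoji output is unsafe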
mlenvdoctor/logger.py
ADDED
@@ -0,0 +1,81 @@
+"""Logging configuration for ML Environment Doctor."""
+
+import logging
+import sys
+from pathlib import Path
+from typing import Optional
+
+from rich.console import Console
+from rich.logging import RichHandler
+
+from .utils import get_home_config_dir
+
+console = Console()
+
+
+def setup_logger(
+    name: str = "mlenvdoctor",
+    level: str = "INFO",
+    log_file: Optional[Path] = None,
+    enable_rich: bool = True,
+) -> logging.Logger:
+    """
+    Set up logger with Rich console handler and optional file handler.
+
+    Args:
+        name: Logger name
+        level: Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)
+        log_file: Optional path to log file
+        enable_rich: Use Rich handler for console output
+
+    Returns:
+        Configured logger instance
+    """
+    logger = logging.getLogger(name)
+    logger.setLevel(getattr(logging, level.upper()))
+
+    # Remove existing handlers to avoid duplicates
+    logger.handlers.clear()
+
+    # Console handler with Rich formatting
+    if enable_rich:
+        console_handler = RichHandler(
+            console=console,
+            show_time=True,
+            show_path=False,
+            rich_tracebacks=True,
+            tracebacks_show_locals=False,
+        )
+    else:
+        console_handler = logging.StreamHandler(sys.stdout)
+        console_handler.setFormatter(
+            logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
+        )
+
+    console_handler.setLevel(getattr(logging, level.upper()))
+    logger.addHandler(console_handler)
+
+    # File handler if log file specified
+    if log_file:
+        log_file.parent.mkdir(parents=True, exist_ok=True)
+        file_handler = logging.FileHandler(log_file, encoding="utf-8")
+        file_handler.setLevel(logging.DEBUG)  # Always log everything to file
+        file_handler.setFormatter(
+            logging.Formatter(
+                "%(asctime)s - %(name)s - %(levelname)s - %(funcName)s:%(lineno)d - %(message)s"
+            )
+        )
+        logger.addHandler(file_handler)
+
+    return logger
+
+
+def get_default_log_file() -> Path:
+    """Get default log file path."""
+    log_dir = get_home_config_dir() / "logs"
+    log_dir.mkdir(parents=True, exist_ok=True)
+    return log_dir / "mlenvdoctor.log"
+
+
+# Default logger instance
+logger = setup_logger()
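A minimal sketch of opting into file logging with these helpers; the call pattern is an assumption, but `setup_logger` and `get_default_log_file` are the functions defined above:

    # Hypothetical usage sketch for the new logging module.
    from mlenvdoctor.logger import get_default_log_file, setup_logger

    # Console shows INFO and above; the file handler always records DEBUG and above.
    logger = setup_logger(level="INFO", log_file=get_default_log_file())
    logger.info("environment scan started")
    logger.debug("reaches the log file only, not the console")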
mlenvdoctor/parallel.py
ADDED
@@ -0,0 +1,114 @@
+"""Parallel execution utilities for independent operations."""
+
+import concurrent.futures
+from typing import Callable, Iterable, List, TypeVar
+
+from .logger import logger
+
+T = TypeVar("T")
+R = TypeVar("R")
+
+
+def run_parallel(
+    func: Callable[[T], R],
+    items: Iterable[T],
+    max_workers: int = 4,
+    timeout: float | None = None,
+) -> List[R]:
+    """
+    Run a function in parallel on multiple items.
+
+    Args:
+        func: Function to execute
+        items: Iterable of items to process
+        max_workers: Maximum number of parallel workers
+        timeout: Maximum time to wait for all tasks (None = no timeout)
+
+    Returns:
+        List of results, in completion order (not necessarily input order)
+
+    Example:
+        def check_library(name: str) -> bool:
+            return importlib.util.find_spec(name) is not None
+
+        results = run_parallel(check_library, ["torch", "transformers", "peft"])
+    """
+    items_list = list(items)
+
+    if not items_list:
+        return []
+
+    # Use ThreadPoolExecutor for I/O-bound operations
+    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
+        # Submit all tasks
+        future_to_item = {executor.submit(func, item): item for item in items_list}
+
+        results: List[R] = []
+        completed = 0
+
+        # Process completed tasks
+        for future in concurrent.futures.as_completed(future_to_item, timeout=timeout):
+            item = future_to_item[future]
+            try:
+                result = future.result()
+                results.append(result)
+                completed += 1
+            except Exception as e:
+                logger.error(f"Error processing {item}: {e}")
+                # Re-raise to maintain error behavior
+                raise
+
+        if completed != len(items_list):
+            raise RuntimeError(f"Only {completed}/{len(items_list)} tasks completed")
+
+        return results
+
+
+def run_parallel_with_results(
+    func: Callable[[T], R],
+    items: Iterable[T],
+    max_workers: int = 4,
+    timeout: float | None = None,
+) -> List[tuple[T, R | Exception]]:
+    """
+    Run a function in parallel and return results with original items.
+
+    Unlike run_parallel, this catches exceptions and returns them as results.
+
+    Args:
+        func: Function to execute
+        items: Iterable of items to process
+        max_workers: Maximum number of parallel workers
+        timeout: Maximum time to wait for all tasks
+
+    Returns:
+        List of (item, result_or_exception) tuples
+
+    Example:
+        def check_library(name: str) -> bool:
+            if name == "bad":
+                raise ValueError("Bad library")
+            return True
+
+        results = run_parallel_with_results(check_library, ["torch", "bad", "peft"])
+        # Returns: [("torch", True), ("bad", ValueError(...)), ("peft", True)]
+    """
+    items_list = list(items)
+
+    if not items_list:
+        return []
+
+    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
+        future_to_item = {executor.submit(func, item): item for item in items_list}
+
+        results: List[tuple[T, R | Exception]] = []
+
+        for future in concurrent.futures.as_completed(future_to_item, timeout=timeout):
+            item = future_to_item[future]
+            try:
+                result = future.result()
+                results.append((item, result))
+            except Exception as e:
+                results.append((item, e))
+
+        return results
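Since `run_parallel` re-raises the first failure while `run_parallel_with_results` returns exceptions inline, the latter suits diagnostics that should keep going. A minimal sketch adapted from the docstring examples above (`check_library` is the hypothetical checker from those docstrings):

    # Hypothetical usage sketch, adapted from the docstrings above.
    import importlib.util

    from mlenvdoctor.parallel import run_parallel_with_results

    def check_library(name: str) -> bool:
        # find_spec returns None when the module cannot be imported
        return importlib.util.find_spec(name) is not None

    for lib, result in run_parallel_with_results(check_library, ["torch", "transformers", "peft"]):
        # result is either the bool or the exception the check raised
        print(f"{lib}: {result}")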
mlenvdoctor/retry.py
ADDED
@@ -0,0 +1,92 @@
+"""Retry logic for transient failures."""
+
+import functools
+import time
+from typing import Any, Callable, Optional, TypeVar
+
+from .exceptions import DiagnosticError
+from .logger import logger
+
+T = TypeVar("T")
+
+
+def retry(
+    max_attempts: int = 3,
+    delay: float = 1.0,
+    backoff: float = 2.0,
+    exceptions: tuple[type[Exception], ...] = (Exception,),
+    on_retry: Optional[Callable[[Exception, int], None]] = None,
+) -> Callable[[Callable[..., T]], Callable[..., T]]:
+    """
+    Decorator to retry a function on failure.
+
+    Args:
+        max_attempts: Maximum number of retry attempts
+        delay: Initial delay between retries in seconds
+        backoff: Multiplier for delay after each retry
+        exceptions: Tuple of exceptions to catch and retry
+        on_retry: Optional callback called on each retry
+
+    Returns:
+        Decorated function with retry logic
+
+    Example:
+        @retry(max_attempts=3, delay=1.0, exceptions=(ConnectionError,))
+        def fetch_data():
+            ...
+    """
+
+    def decorator(func: Callable[..., T]) -> Callable[..., T]:
+        @functools.wraps(func)
+        def wrapper(*args: Any, **kwargs: Any) -> T:
+            current_delay = delay
+            last_exception: Optional[Exception] = None
+
+            for attempt in range(1, max_attempts + 1):
+                try:
+                    return func(*args, **kwargs)
+                except exceptions as e:
+                    last_exception = e
+                    if attempt < max_attempts:
+                        logger.warning(
+                            f"{func.__name__} failed (attempt {attempt}/{max_attempts}): {e}. "
+                            f"Retrying in {current_delay:.1f}s..."
+                        )
+                        if on_retry:
+                            on_retry(e, attempt)
+                        time.sleep(current_delay)
+                        current_delay *= backoff
+                    else:
+                        logger.error(f"{func.__name__} failed after {max_attempts} attempts: {e}")
+
+            # All attempts failed
+            if last_exception:
+                raise DiagnosticError(
+                    f"{func.__name__} failed after {max_attempts} attempts",
+                    f"Last error: {last_exception}",
+                ) from last_exception
+
+            # Should never reach here, but satisfy type checker
+            raise RuntimeError("Retry logic failed unexpectedly")
+
+        return wrapper
+
+    return decorator
+
+
+def retry_network(func: Callable[..., T]) -> Callable[..., T]:
+    """
+    Decorator specifically for network operations.
+
+    Retries on network-related exceptions with exponential backoff.
+    """
+    return retry(
+        max_attempts=3,
+        delay=1.0,
+        backoff=2.0,
+        exceptions=(
+            ConnectionError,
+            TimeoutError,
+            OSError,
+        ),
+    )(func)
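A minimal sketch of applying `retry_network`; `fetch_url` is a hypothetical function, and `urllib` errors are subclasses of `OSError`, so they fall under the retried exceptions:

    # Hypothetical usage sketch for the retry decorators above.
    from urllib.request import urlopen

    from mlenvdoctor.retry import retry_network

    @retry_network  # 3 attempts, 1s initial delay, 2x backoff on ConnectionError/TimeoutError/OSError
    def fetch_url(url: str) -> bytes:
        with urlopen(url, timeout=10) as resp:
            return resp.read()

    # Exhausted retries raise DiagnosticError chained to the last network error.
    data = fetch_url("https://pypi.org/simple/")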