mlenvdoctor 0.1.1__py3-none-any.whl → 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
mlenvdoctor/retry.py ADDED
@@ -0,0 +1,92 @@
1
+ """Retry logic for transient failures."""
2
+
3
+ import functools
4
+ import time
5
+ from typing import Any, Callable, Optional, TypeVar
6
+
7
+ from .exceptions import DiagnosticError
8
+ from .logger import logger
9
+
10
+ T = TypeVar("T")
11
+
12
+
13
def retry(
    max_attempts: int = 3,
    delay: float = 1.0,
    backoff: float = 2.0,
    exceptions: "tuple[type[Exception], ...]" = (Exception,),
    on_retry: Optional[Callable[[Exception, int], None]] = None,
) -> Callable[[Callable[..., T]], Callable[..., T]]:
    """
    Decorator to retry a function on failure.

    The wrapped function is called up to ``max_attempts`` times in total.
    After each failed call except the last one, a warning is logged,
    ``on_retry`` (if given) is invoked, and the wrapper sleeps for the
    current delay before multiplying that delay by ``backoff``.

    Args:
        max_attempts: Total number of calls to make, including the first
            one (must be at least 1)
        delay: Initial delay between attempts in seconds
        backoff: Multiplier applied to the delay after each failed attempt
        exceptions: Tuple of exception types to catch and retry; any other
            exception propagates immediately
        on_retry: Optional callback invoked as ``on_retry(exc, attempt)``
            before each sleep

    Returns:
        Decorated function with retry logic

    Raises:
        ValueError: If ``max_attempts`` is smaller than 1
        DiagnosticError: (from the wrapped call) when every attempt failed;
            the last caught exception is chained as ``__cause__``

    Example:
        @retry(max_attempts=3, delay=1.0, exceptions=(ConnectionError,))
        def fetch_data():
            ...
    """
    # NOTE: the ``exceptions`` annotation is quoted on purpose. Subscripting
    # the builtin ``tuple``/``type`` at definition time fails on Python 3.8,
    # which the package metadata still declares as supported.
    if max_attempts < 1:
        # Fail fast: with no attempts the wrapped function would never run.
        raise ValueError(f"max_attempts must be at least 1, got {max_attempts}")

    def decorator(func: Callable[..., T]) -> Callable[..., T]:
        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> T:
            current_delay = delay
            last_exception: Optional[Exception] = None

            for attempt in range(1, max_attempts + 1):
                try:
                    return func(*args, **kwargs)
                except exceptions as e:
                    last_exception = e
                    if attempt < max_attempts:
                        logger.warning(
                            f"{func.__name__} failed (attempt {attempt}/{max_attempts}): {e}. "
                            f"Retrying in {current_delay:.1f}s..."
                        )
                        if on_retry:
                            on_retry(e, attempt)
                        time.sleep(current_delay)
                        current_delay *= backoff
                    else:
                        logger.error(f"{func.__name__} failed after {max_attempts} attempts: {e}")

            # The loop runs at least once (max_attempts >= 1) and every
            # iteration either returns or records the caught exception, so
            # reaching this point means all attempts failed.
            raise DiagnosticError(
                f"{func.__name__} failed after {max_attempts} attempts",
                f"Last error: {last_exception}",
            ) from last_exception

        return wrapper

    return decorator
75
+
76
+
77
def retry_network(func: Callable[..., T]) -> Callable[..., T]:
    """
    Wrap ``func`` with the retry policy used for network operations.

    The call is attempted up to 3 times, starting with a 1 second pause
    that doubles after each failure. ConnectionError, TimeoutError and
    OSError trigger a retry.
    """
    network_errors = (
        ConnectionError,
        TimeoutError,
        OSError,
    )
    with_network_retry = retry(
        max_attempts=3,
        delay=1.0,
        backoff=2.0,
        exceptions=network_errors,
    )
    return with_network_retry(func)
mlenvdoctor/utils.py CHANGED
@@ -5,10 +5,20 @@ import sys
5
5
  from pathlib import Path
6
6
  from typing import List, Optional, Tuple
7
7
 
8
+ import sys
9
+
8
10
  from rich.console import Console
9
11
  from rich.progress import Progress, SpinnerColumn, TextColumn
10
12
 
11
- console = Console()
13
+ from .exceptions import DiagnosticError
14
+ from .icons import icon_check, icon_cross, icon_info, icon_warning
15
+
16
+ # Configure console for Windows compatibility
17
+ if sys.platform == "win32":
18
+ # Use legacy Windows renderer if needed
19
+ console = Console(legacy_windows=True, force_terminal=True)
20
+ else:
21
+ console = Console()
12
22
 
13
23
 
14
24
  def run_command(
@@ -16,8 +26,29 @@ def run_command(
16
26
  capture_output: bool = True,
17
27
  check: bool = False,
18
28
  timeout: Optional[int] = 30,
19
- ) -> subprocess.CompletedProcess:
20
- """Run a shell command with error handling."""
29
+ ) -> subprocess.CompletedProcess[str]:
30
+ """
31
+ Run a shell command with error handling and input validation.
32
+
33
+ Args:
34
+ cmd: Command and arguments as a list
35
+ capture_output: Whether to capture stdout/stderr
36
+ check: Whether to raise on non-zero exit code
37
+ timeout: Command timeout in seconds
38
+
39
+ Returns:
40
+ CompletedProcess with command result
41
+
42
+ Raises:
43
+ DiagnosticError: For command execution errors
44
+ ConfigurationError: For invalid input
45
+ """
46
+ from .validators import sanitize_command, validate_timeout
47
+
48
+ # Validate and sanitize inputs
49
+ cmd = sanitize_command(cmd)
50
+ timeout = validate_timeout(timeout)
51
+
21
52
  try:
22
53
  result = subprocess.run(
23
54
  cmd,
@@ -27,29 +58,58 @@ def run_command(
27
58
  timeout=timeout,
28
59
  )
29
60
  return result
30
- except subprocess.TimeoutExpired:
31
- console.print(f"[red]Command timed out: {' '.join(cmd)}[/red]")
32
- raise
33
- except FileNotFoundError:
34
- console.print(f"[red]Command not found: {cmd[0]}[/red]")
35
- raise
61
+ except subprocess.TimeoutExpired as e:
62
+ error_msg = f"Command timed out after {timeout}s: {' '.join(cmd)}"
63
+ console.print(f"[red]{error_msg}[/red]")
64
+ raise DiagnosticError(
65
+ error_msg,
66
+ "Try increasing timeout or check if the command is hanging",
67
+ ) from e
68
+ except FileNotFoundError as e:
69
+ error_msg = f"Command not found: {cmd[0]}"
70
+ console.print(f"[red]{error_msg}[/red]")
71
+ raise DiagnosticError(
72
+ error_msg,
73
+ f"Install {cmd[0]} or ensure it's in your PATH",
74
+ ) from e
36
75
  except subprocess.CalledProcessError as e:
37
76
  if not check:
38
- return e # type: ignore
77
+ # Return the exception as if it were a result
78
+ # This maintains backward compatibility but is type-unsafe
79
+ return subprocess.CompletedProcess( # type: ignore[return-value]
80
+ cmd, e.returncode, e.stdout, e.stderr
81
+ )
39
82
  raise
40
83
 
41
84
 
42
85
def check_command_exists(cmd: str) -> bool:
    """
    Check if a command exists in PATH.

    The lookup is done in-process with ``shutil.which`` instead of spawning
    ``which``/``where``, so it does not depend on those helper programs
    being installed and never starts a subprocess.

    Args:
        cmd: Command name to check

    Returns:
        True if command exists and is executable, False otherwise
        (including when ``cmd`` is not a non-blank string)
    """
    # Local import keeps this block self-contained; shutil is stdlib.
    import shutil

    if not isinstance(cmd, str) or not cmd.strip():
        return False

    try:
        return shutil.which(cmd) is not None
    except (OSError, ValueError):
        # Best-effort probe: an unusable name counts as "not found".
        return False
54
114
 
55
115
 
@@ -63,22 +123,22 @@ def get_home_config_dir() -> Path:
63
123
 
64
124
def print_success(message: str) -> None:
    """Print ``message`` in green, prefixed with the check icon."""
    markup = f"[green]{icon_check()} {message}[/green]"
    console.print(markup)
67
127
 
68
128
 
69
129
def print_error(message: str) -> None:
    """Print ``message`` in red, prefixed with the cross icon."""
    markup = f"[red]{icon_cross()} {message}[/red]"
    console.print(markup)
72
132
 
73
133
 
74
134
def print_warning(message: str) -> None:
    """Print ``message`` in yellow, prefixed with the warning icon."""
    markup = f"[yellow]{icon_warning()} {message}[/yellow]"
    console.print(markup)
77
137
 
78
138
 
79
139
def print_info(message: str) -> None:
    """Print ``message`` in blue, prefixed with the info icon."""
    markup = f"[blue]{icon_info()} {message}[/blue]"
    console.print(markup)
82
142
 
83
143
 
84
144
  def with_spinner(message: str):
@@ -0,0 +1,217 @@
1
+ """Input validation and sanitization for ML Environment Doctor."""
2
+
3
+ import re
4
+ from pathlib import Path
5
+ from typing import Optional
6
+
7
+ from .exceptions import ConfigurationError
8
+
9
+
10
def validate_model_name(model_name: str) -> str:
    """
    Check a model name and return it in normalised form.

    Surrounding whitespace is stripped before the character check and the
    accepted name is returned lower-cased.

    Args:
        model_name: Model name to validate

    Returns:
        The stripped, lower-cased model name

    Raises:
        ConfigurationError: If the name is empty, is not a string, or
            contains anything other than letters, digits, dots,
            underscores and hyphens
    """
    if not (model_name and isinstance(model_name, str)):
        raise ConfigurationError(
            "Model name must be a non-empty string",
            "Use a valid model name like 'tinyllama', 'gpt2', or 'mistral-7b'",
        )

    trimmed = model_name.strip()

    # Allow-list check: one or more of letters, digits, '.', '_' and '-'.
    allowed = re.match(r"^[a-zA-Z0-9._-]+$", trimmed)
    if allowed is None:
        raise ConfigurationError(
            f"Invalid model name: {trimmed}",
            "Model name can only contain letters, numbers, dots, underscores, and hyphens",
        )

    return trimmed.lower()
40
+
41
+
42
def validate_file_path(file_path: Path, must_exist: bool = False, must_be_file: bool = False) -> Path:
    """
    Validate and sanitize file path.

    Args:
        file_path: Path to validate (a ``Path`` or a string)
        must_exist: Whether the path must exist
        must_be_file: Whether the path must be a file

    Returns:
        Resolved, absolute path

    Raises:
        ConfigurationError: If ``file_path`` is neither a ``Path`` nor a
            string, if a ``..`` component survives resolution, or if the
            ``must_exist`` / ``must_be_file`` requirement is not met
    """
    if not isinstance(file_path, (Path, str)):
        raise ConfigurationError(
            "File path must be a Path object or string",
            "Use pathlib.Path or a valid string path",
        )

    path = Path(file_path).resolve()

    # Check for path traversal on whole components only. A substring test
    # on the string form would also reject legitimate names that merely
    # contain two dots, such as "foo..bar.txt".
    if ".." in path.parts:
        raise ConfigurationError(
            "Invalid path: contains '..'",
            "Use absolute paths or relative paths without '..'",
        )

    if must_exist and not path.exists():
        raise ConfigurationError(
            f"Path does not exist: {path}",
            "Ensure the file or directory exists",
        )

    if must_be_file and not path.is_file():
        raise ConfigurationError(
            f"Path is not a file: {path}",
            "Provide a valid file path",
        )

    return path
88
+
89
+
90
def validate_log_level(level: str) -> str:
    """
    Validate logging level.

    The comparison is case-insensitive; the upper-cased name is returned.

    Args:
        level: Log level to validate

    Returns:
        Validated log level in upper case

    Raises:
        ConfigurationError: If level is invalid
    """
    # An ordered tuple (not a set) keeps the hint in the error message
    # stable between runs.
    valid_levels = ("DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL")
    level_upper = str(level).upper()

    if level_upper not in valid_levels:
        raise ConfigurationError(
            f"Invalid log level: {level}",
            f"Use one of: {', '.join(valid_levels)}",
        )

    return level_upper
113
+
114
+
115
def validate_stack_name(stack: str) -> str:
    """
    Validate ML stack name.

    The comparison is case-insensitive; the lower-cased name is returned.

    Args:
        stack: Stack name to validate

    Returns:
        Validated stack name in lower case

    Raises:
        ConfigurationError: If stack is invalid
    """
    # An ordered tuple (not a set) keeps the hint in the error message
    # stable between runs.
    valid_stacks = ("trl-peft", "minimal")
    stack_lower = str(stack).lower()

    if stack_lower not in valid_stacks:
        raise ConfigurationError(
            f"Invalid stack: {stack}",
            f"Use one of: {', '.join(valid_stacks)}",
        )

    return stack_lower
138
+
139
+
140
def sanitize_command(cmd: "list[str]") -> "list[str]":
    """
    Sanitize command arguments to prevent injection.

    Every argument must be a string and must not contain any of the shell
    operators ``;``, ``&&``, ``||``, backtick, ``$(``, ``<``, ``>`` or ``|``.
    The check is a plain substring match, so an argument such as
    ``torch>=2.0`` is rejected as well.

    Args:
        cmd: Command and arguments list

    Returns:
        A new list holding the same arguments, in order

    Raises:
        ConfigurationError: If ``cmd`` is not a non-empty list, an argument
            is not a string, or an argument contains a dangerous pattern
    """
    # NOTE: the annotations are quoted on purpose. Subscripting the builtin
    # ``list`` at definition time fails on Python 3.8, which the package
    # metadata still declares as supported.
    if not isinstance(cmd, list) or not cmd:
        raise ConfigurationError(
            "Command must be a non-empty list",
            "Provide command as a list of strings",
        )

    # Built once instead of once per argument; the order decides which
    # pattern is reported first.
    dangerous_patterns = (";", "&&", "||", "`", "$(", "<", ">", "|")

    sanitized = []
    for arg in cmd:
        if not isinstance(arg, str):
            raise ConfigurationError(
                "All command arguments must be strings",
                "Convert all arguments to strings",
            )

        # Check for command injection patterns
        for pattern in dangerous_patterns:
            if pattern in arg:
                raise ConfigurationError(
                    f"Dangerous pattern detected in command: {pattern}",
                    "Do not use shell operators in command arguments",
                )

        sanitized.append(arg)

    return sanitized
179
+
180
+
181
def validate_timeout(
    timeout: Optional[int],
    min_timeout: int = 1,
    max_timeout: int = 3600,
) -> Optional[int]:
    """
    Check that a timeout lies within the allowed range.

    ``None`` means "no timeout" and is passed through unchanged.

    Args:
        timeout: Timeout in seconds, or None
        min_timeout: Smallest accepted value
        max_timeout: Largest accepted value

    Returns:
        The timeout unchanged when it is acceptable

    Raises:
        ConfigurationError: If the timeout is not an integer or lies
            outside ``min_timeout``..``max_timeout``
    """
    if timeout is None:
        return None

    if not isinstance(timeout, int):
        raise ConfigurationError(
            f"Timeout must be an integer, got {type(timeout)}",
            "Provide timeout as an integer number of seconds",
        )

    # Accept in-range values first, then work out which bound was broken.
    if min_timeout <= timeout <= max_timeout:
        return timeout

    if timeout < min_timeout:
        raise ConfigurationError(
            f"Timeout too small: {timeout}s (minimum: {min_timeout}s)",
            f"Increase timeout to at least {min_timeout} seconds",
        )

    raise ConfigurationError(
        f"Timeout too large: {timeout}s (maximum: {max_timeout}s)",
        f"Decrease timeout to at most {max_timeout} seconds",
    )
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: mlenvdoctor
3
- Version: 0.1.1
3
+ Version: 0.1.2
4
4
  Summary: Diagnose & fix ML environments for LLM fine-tuning
5
5
  Author: ML Environment Doctor Contributors
6
6
  License: MIT
@@ -20,6 +20,7 @@ Requires-Python: >=3.8
20
20
  Requires-Dist: packaging>=23.0
21
21
  Requires-Dist: psutil>=5.9.0
22
22
  Requires-Dist: rich>=13.0.0
23
+ Requires-Dist: tomli>=2.0.0; python_version < '3.11'
23
24
  Requires-Dist: typer>=0.9.0
24
25
  Provides-Extra: dev
25
26
  Requires-Dist: black>=23.0.0; extra == 'dev'
@@ -34,7 +35,7 @@ Description-Content-Type: text/markdown
34
35
 
35
36
  [![Python 3.8+](https://img.shields.io/badge/python-3.8+-blue.svg)](https://www.python.org/downloads/)
36
37
  [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
37
- [![PyPI](https://img.shields.io/pypi/v/mlenvdoctor.svg)](https://pypi.org/project/mlenvdoctor/)
38
+ [![PyPI](https://img.shields.io/pypi/v/mlenvdoctor.svg)](https://pypi.org/project/mlenvdoctor/)
38
39
 
39
40
  > **Single command fixes 90% of "my torch.cuda.is_available() is False" issues.**
40
41
 
@@ -0,0 +1,21 @@
1
+ mlenvdoctor/__init__.py,sha256=igtJXQ-DiuG4_2BcfW64KPxvhCzSeuXUNUVqDzuLDI8,327
2
+ mlenvdoctor/cli.py,sha256=jvc1MWNciyRGedYCEKSCZeqXfX4wf6xe38Vx44k86xc,7145
3
+ mlenvdoctor/config.py,sha256=c7WUd8XUAo8nz-Ri4TIldc-_p3wiyDNqOsL-MhV9jbI,4770
4
+ mlenvdoctor/constants.py,sha256=ZoYf6-gqmw88CGENyYVur_hCmdcUm0IR5N9pUdOLjJY,1980
5
+ mlenvdoctor/diagnose.py,sha256=7lsAVOPeK2I3nmlzMKmtniyEh1UGv76bJqlqSpjvPL8,20553
6
+ mlenvdoctor/dockerize.py,sha256=AC8HX5sRkSFAM0O0caBnKW4HAdS49MVmMcsplKEDXI4,5562
7
+ mlenvdoctor/exceptions.py,sha256=8wzZE-In0zimXJ3omUA3YmFeklcseO53kqo0SpB58DM,1069
8
+ mlenvdoctor/export.py,sha256=CsuLpbpR2OVs-bud27K8Xv28KCI9veNPIWCsznyFmaw,8671
9
+ mlenvdoctor/fix.py,sha256=fXS4uxBN-FWRFKowmOsdPYI7bnY8jnYxG7UADpJ1hwc,8989
10
+ mlenvdoctor/gpu.py,sha256=sMFgtF4pt-dpOr6IDxvm6f0ChfmCf58E8mApIH6jvAs,6295
11
+ mlenvdoctor/icons.py,sha256=vu35SuBxlZu55rUqldgQfX4UeHeENt082ighZKnzHZY,2289
12
+ mlenvdoctor/logger.py,sha256=OKJQjcdOspARokBcIDyCri14gS_7MlrCt2B3znKc34Y,2377
13
+ mlenvdoctor/parallel.py,sha256=HVJmu8t4k2-XeoQPHlLmhInLOcXjUTiP0VM3THOkAmE,3630
14
+ mlenvdoctor/retry.py,sha256=ZH-KWe6BPNK1sUNFSK2uwdobZ_Z77fxMUfSrSpInt3c,2993
15
+ mlenvdoctor/utils.py,sha256=ehohh-iRLe2qkOMxj5v9yTWONf5gWSdY6CvfrRttTlg,4862
16
+ mlenvdoctor/validators.py,sha256=Kz1FcJM4Cym-S_z5vTocv0cxzKOEgvZIqv8C8c1gSzY,6109
17
+ mlenvdoctor-0.1.2.dist-info/METADATA,sha256=7Rlpv9kjMHQWPBsnZR1gbaUzocWNXNJ3wdGy5UXaNQU,8942
18
+ mlenvdoctor-0.1.2.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
19
+ mlenvdoctor-0.1.2.dist-info/entry_points.txt,sha256=Y-WH-ANeiTdECIaqi_EB3ZEf_kACkvsYBHnNhXsCI4k,52
20
+ mlenvdoctor-0.1.2.dist-info/licenses/LICENSE,sha256=rGHdyWGvGWYnEFlthqtB-RtRCTa7WaAOElom5qD-nHw,1114
21
+ mlenvdoctor-0.1.2.dist-info/RECORD,,
@@ -1,12 +0,0 @@
1
- mlenvdoctor/__init__.py,sha256=vYK9Wp5kAcHKL7njV76xMfPA24ILdxJvIhoumF6-Sz4,110
2
- mlenvdoctor/cli.py,sha256=aQ2rpjxsfMDwYAVKsn7cHSar42AZJTrr1tSTl-iM0L4,5488
3
- mlenvdoctor/diagnose.py,sha256=xa3aqCornGApMJkEWQNIGHwNBRhGA3ud1hBQ6wIVhVQ,17099
4
- mlenvdoctor/dockerize.py,sha256=AC8HX5sRkSFAM0O0caBnKW4HAdS49MVmMcsplKEDXI4,5562
5
- mlenvdoctor/fix.py,sha256=P4Qce41LLgjaHugbMFFSg7ldfsSSNDFBz5_T_YA9mig,8945
6
- mlenvdoctor/gpu.py,sha256=iuiLAW8lZLBpuUL1yapOkx5VYLtY_i1SwK9cE5koZTE,6129
7
- mlenvdoctor/utils.py,sha256=2gtbiJogEI33IpOLHGEfks6b7Jd1Y7pyfojW9wpYsjU,2893
8
- mlenvdoctor-0.1.1.dist-info/METADATA,sha256=NQHbKeu7KnZfHWPJcqQEYUXZqIRL7ORCesdMyyqFKU8,8887
9
- mlenvdoctor-0.1.1.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
10
- mlenvdoctor-0.1.1.dist-info/entry_points.txt,sha256=Y-WH-ANeiTdECIaqi_EB3ZEf_kACkvsYBHnNhXsCI4k,52
11
- mlenvdoctor-0.1.1.dist-info/licenses/LICENSE,sha256=rGHdyWGvGWYnEFlthqtB-RtRCTa7WaAOElom5qD-nHw,1114
12
- mlenvdoctor-0.1.1.dist-info/RECORD,,