mlenvdoctor 0.1.0__py3-none-any.whl → 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
mlenvdoctor/fix.py CHANGED
@@ -1,14 +1,11 @@
  """Auto-fix and requirements generation for ML Environment Doctor."""

- import subprocess
  import sys
  from pathlib import Path
- from typing import List, Optional
+ from typing import Optional

- from rich.console import Console
- from rich.progress import Progress, SpinnerColumn, TextColumn
-
- from .diagnose import DiagnosticIssue, diagnose_env
+ from .diagnose import diagnose_env
+ from .icons import icon_wrench
  from .utils import (
      check_command_exists,
      console,
@@ -39,7 +36,9 @@ ML_STACKS = {
  }


- def generate_requirements_txt(stack: str = "trl-peft", output_file: str = "requirements-mlenvdoctor.txt") -> Path:
+ def generate_requirements_txt(
+     stack: str = "trl-peft", output_file: str = "requirements-mlenvdoctor.txt"
+ ) -> Path:
      """Generate requirements.txt file."""
      if stack not in ML_STACKS:
          print_error(f"Unknown stack: {stack}. Available: {list(ML_STACKS.keys())}")
@@ -61,7 +60,9 @@ def generate_requirements_txt(stack: str = "trl-peft", output_file: str = "requi
          content = "# Standard PyTorch (CPU or CUDA)\n\n"
      except ImportError:
          content = "# PyTorch installation\n"
-         content += "# For CUDA: pip install torch --index-url https://download.pytorch.org/whl/cu124\n"
+         content += (
+             "# For CUDA: pip install torch --index-url https://download.pytorch.org/whl/cu124\n"
+         )
          content += "# For CPU: pip install torch\n\n"

      content += "\n".join(requirements)
@@ -72,7 +73,9 @@ def generate_requirements_txt(stack: str = "trl-peft", output_file: str = "requi
      return output_path


- def generate_conda_env(stack: str = "trl-peft", output_file: str = "environment-mlenvdoctor.yml") -> Path:
+ def generate_conda_env(
+     stack: str = "trl-peft", output_file: str = "environment-mlenvdoctor.yml"
+ ) -> Path:
      """Generate conda environment file."""
      if stack not in ML_STACKS:
          print_error(f"Unknown stack: {stack}. Available: {list(ML_STACKS.keys())}")
@@ -197,7 +200,11 @@ def create_virtualenv(env_name: str = ".venv") -> Optional[Path]:

          venv.create(env_path, with_pip=True)
          print_success(f"Virtual environment created: {env_name}")
-         print_info(f"Activate with: {'.venv\\Scripts\\activate' if sys.platform == 'win32' else 'source .venv/bin/activate'}")
+         if sys.platform == "win32":
+             activate_cmd = r".venv\Scripts\activate"
+         else:
+             activate_cmd = "source .venv/bin/activate"
+         print_info(f"Activate with: {activate_cmd}")
          return env_path
      except Exception as e:
          print_error(f"Failed to create virtual environment: {e}")
@@ -206,7 +213,7 @@ def create_virtualenv(env_name: str = ".venv") -> Optional[Path]:

  def auto_fix(use_conda: bool = False, create_venv: bool = False, stack: str = "trl-peft") -> bool:
      """Auto-fix environment issues based on diagnostics."""
-     console.print("[bold blue]🔧 Running Auto-Fix...[/bold blue]\n")
+     console.print(f"[bold blue]{icon_wrench()} Running Auto-Fix...[/bold blue]\n")

      # Run diagnostics
      issues = diagnose_env(full=False)
@@ -243,7 +250,6 @@ def auto_fix(use_conda: bool = False, create_venv: bool = False, stack: str = "t
      if install.lower() in ["y", "yes"]:
          return install_requirements(str(req_file), use_conda=use_conda)
      else:
-         print_info(f"Requirements file generated. Install manually with:")
+         print_info("Requirements file generated. Install manually with:")
          console.print(f"[cyan] pip install -r {req_file}[/cyan]")
      return True
-
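Note: a minimal usage sketch of the reworked fix helpers (illustrative only, not part of the diff; the caller code is assumed, while the function names, defaults, and the "trl-peft" stack key come from fix.py above):

    from mlenvdoctor.fix import auto_fix, generate_requirements_txt

    # Write requirements-mlenvdoctor.txt for the default "trl-peft" stack
    req_file = generate_requirements_txt(stack="trl-peft")
    print(f"Wrote {req_file}")

    # Diagnose the environment, optionally creating a .venv, and apply fixes
    ok = auto_fix(use_conda=False, create_venv=True, stack="trl-peft")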
mlenvdoctor/gpu.py CHANGED
@@ -1,19 +1,24 @@
  """GPU benchmarks and smoke tests for ML Environment Doctor."""

  import time
- from typing import Dict, List, Optional
+ from typing import Dict

  try:
      import torch
  except ImportError:
      torch = None  # type: ignore

+ import sys
+
  from rich.console import Console
- from rich.progress import Progress, SpinnerColumn, TextColumn

  from .utils import print_error, print_info, print_success

- console = Console()
+ # Configure console for Windows compatibility
+ if sys.platform == "win32":
+     console = Console(legacy_windows=True, force_terminal=True)
+ else:
+     console = Console()


  def benchmark_gpu_ops() -> Dict[str, float]:
@@ -78,7 +83,8 @@ def smoke_test_lora() -> bool:
          if tokenizer.pad_token is None:
              tokenizer.pad_token = tokenizer.eos_token
          model = AutoModelForCausalLM.from_pretrained(
-             model_name, torch_dtype=torch.float16 if device.type == "cuda" else torch.float32
+             model_name,
+             torch_dtype=torch.float16 if device.type == "cuda" else torch.float32,
          ).to(device)

          # Configure LoRA
@@ -101,8 +107,7 @@ def smoke_test_lora() -> bool:
          # Forward pass
          with console.status("[bold green]Running forward pass..."):
              with torch.no_grad():
-                 outputs = model(**inputs)
-                 loss = outputs.loss if hasattr(outputs, "loss") else None
+                 _ = model(**inputs)

          print_success("LoRA smoke test passed!")
          return True
@@ -146,7 +151,9 @@ def test_model(model_name: str = "tinyllama") -> bool:
          # Estimate memory requirements (rough)
          if "7b" in actual_model_name.lower() or "7B" in actual_model_name:
              if free_gb < 16:
-                 print_error(f"Insufficient GPU memory: {free_gb:.1f}GB free, need ~16GB for 7B model")
+                 print_error(
+                     f"Insufficient GPU memory: {free_gb:.1f}GB free, need ~16GB for 7B model"
+                 )
                  return False

      with console.status(f"[bold green]Loading {actual_model_name}..."):
@@ -166,7 +173,7 @@ def test_model(model_name: str = "tinyllama") -> bool:
          inputs = tokenizer(dummy_text, return_tensors="pt").to(device)

          with torch.no_grad():
-             outputs = model(**inputs)
+             _ = model(**inputs)

          print_success(f"Model {actual_model_name} loaded and tested successfully!")
          return True
@@ -181,4 +188,3 @@ def test_model(model_name: str = "tinyllama") -> bool:
      except Exception as e:
          print_error(f"Model test error: {e}")
          return False
-
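Note: a hedged sketch of how the gpu helpers might be called (the caller is assumed; the function names and return types come from gpu.py above):

    from mlenvdoctor.gpu import benchmark_gpu_ops, smoke_test_lora, test_model

    timings = benchmark_gpu_ops()      # Dict[str, float] of benchmark timings
    if smoke_test_lora():              # True if the LoRA smoke test passes
        test_model("tinyllama")        # loads a small model and runs one forward pass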
mlenvdoctor/icons.py ADDED
@@ -0,0 +1,100 @@
+ """Safe emoji/icon handling for cross-platform compatibility."""
+
+ import sys
+ from typing import Literal
+
+ # Check if we can safely use emojis
+ _USE_EMOJIS = True
+
+ try:
+     # Try to write an emoji to see if it works
+     if sys.platform == "win32":
+         import io
+
+         # Test if console supports UTF-8
+         test_output = io.StringIO()
+         try:
+             test_output.write("🔍")
+             test_output.getvalue()
+         except (UnicodeEncodeError, UnicodeError):
+             _USE_EMOJIS = False
+ except Exception:
+     _USE_EMOJIS = False
+
+
+ def get_icon(icon_name: Literal["search", "check", "cross", "warning", "info", "wrench", "whale", "test"]) -> str:
+     """
+     Get a safe icon/emoji for the current platform.
+
+     Args:
+         icon_name: Name of the icon to get
+
+     Returns:
+         Emoji if supported, ASCII alternative otherwise
+     """
+     if _USE_EMOJIS:
+         icons = {
+             "search": "🔍",
+             "check": "✅",
+             "cross": "❌",
+             "warning": "⚠️",
+             "info": "ℹ️",
+             "wrench": "🔧",
+             "whale": "🐳",
+             "test": "🧪",
+         }
+     else:
+         # ASCII alternatives
+         icons = {
+             "search": "[*]",
+             "check": "[OK]",
+             "cross": "[X]",
+             "warning": "[!]",
+             "info": "[i]",
+             "wrench": "[FIX]",
+             "whale": "[DOCKER]",
+             "test": "[TEST]",
+         }
+
+     return icons.get(icon_name, "")
+
+
+ # Convenience functions
+ def icon_search() -> str:
+     """Get search icon."""
+     return get_icon("search")
+
+
+ def icon_check() -> str:
+     """Get check icon."""
+     return get_icon("check")
+
+
+ def icon_cross() -> str:
+     """Get cross/error icon."""
+     return get_icon("cross")
+
+
+ def icon_warning() -> str:
+     """Get warning icon."""
+     return get_icon("warning")
+
+
+ def icon_info() -> str:
+     """Get info icon."""
+     return get_icon("info")
+
+
+ def icon_wrench() -> str:
+     """Get wrench/fix icon."""
+     return get_icon("wrench")
+
+
+ def icon_whale() -> str:
+     """Get whale/docker icon."""
+     return get_icon("whale")
+
+
+ def icon_test() -> str:
+     """Get test icon."""
+     return get_icon("test")
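Note: a minimal usage sketch of the new icons module (illustrative; get_icon, icon_wrench, and the icon names are defined in icons.py above):

    from mlenvdoctor.icons import get_icon, icon_wrench

    # Emoji where the console supports them, ASCII fallbacks otherwise
    print(f"{icon_wrench()} Running Auto-Fix...")   # "🔧 ..." or "[FIX] ..."
    print(f"{get_icon('check')} Environment OK")    # "✅ ..." or "[OK] ..."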
mlenvdoctor/logger.py ADDED
@@ -0,0 +1,81 @@
+ """Logging configuration for ML Environment Doctor."""
+
+ import logging
+ import sys
+ from pathlib import Path
+ from typing import Optional
+
+ from rich.console import Console
+ from rich.logging import RichHandler
+
+ from .utils import get_home_config_dir
+
+ console = Console()
+
+
+ def setup_logger(
+     name: str = "mlenvdoctor",
+     level: str = "INFO",
+     log_file: Optional[Path] = None,
+     enable_rich: bool = True,
+ ) -> logging.Logger:
+     """
+     Set up logger with Rich console handler and optional file handler.
+
+     Args:
+         name: Logger name
+         level: Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)
+         log_file: Optional path to log file
+         enable_rich: Use Rich handler for console output
+
+     Returns:
+         Configured logger instance
+     """
+     logger = logging.getLogger(name)
+     logger.setLevel(getattr(logging, level.upper()))
+
+     # Remove existing handlers to avoid duplicates
+     logger.handlers.clear()
+
+     # Console handler with Rich formatting
+     if enable_rich:
+         console_handler = RichHandler(
+             console=console,
+             show_time=True,
+             show_path=False,
+             rich_tracebacks=True,
+             tracebacks_show_locals=False,
+         )
+     else:
+         console_handler = logging.StreamHandler(sys.stdout)
+         console_handler.setFormatter(
+             logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
+         )
+
+     console_handler.setLevel(getattr(logging, level.upper()))
+     logger.addHandler(console_handler)
+
+     # File handler if log file specified
+     if log_file:
+         log_file.parent.mkdir(parents=True, exist_ok=True)
+         file_handler = logging.FileHandler(log_file, encoding="utf-8")
+         file_handler.setLevel(logging.DEBUG)  # Always log everything to file
+         file_handler.setFormatter(
+             logging.Formatter(
+                 "%(asctime)s - %(name)s - %(levelname)s - %(funcName)s:%(lineno)d - %(message)s"
+             )
+         )
+         logger.addHandler(file_handler)
+
+     return logger
+
+
+ def get_default_log_file() -> Path:
+     """Get default log file path."""
+     log_dir = get_home_config_dir() / "logs"
+     log_dir.mkdir(parents=True, exist_ok=True)
+     return log_dir / "mlenvdoctor.log"
+
+
+ # Default logger instance
+ logger = setup_logger()
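Note: a minimal sketch of the new logging setup (illustrative; setup_logger and get_default_log_file are defined in logger.py above):

    from mlenvdoctor.logger import get_default_log_file, setup_logger

    # Rich console output plus a DEBUG-level file under the config dir's logs/ folder
    log = setup_logger(level="INFO", log_file=get_default_log_file())
    log.info("diagnostics started")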
@@ -0,0 +1,115 @@
+ """Parallel execution utilities for independent operations."""
+
+ import concurrent.futures
+ from typing import Callable, Iterable, List, TypeVar
+
+ from .logger import logger
+
+ T = TypeVar("T")
+ R = TypeVar("R")
+
+
+ def run_parallel(
+     func: Callable[[T], R],
+     items: Iterable[T],
+     max_workers: int = 4,
+     timeout: float | None = None,
+ ) -> List[R]:
+     """
+     Run a function in parallel on multiple items.
+
+     Args:
+         func: Function to execute
+         items: Iterable of items to process
+         items_list: List of items to process
+         max_workers: Maximum number of parallel workers
+         timeout: Maximum time to wait for all tasks (None = no timeout)
+
+     Returns:
+         List of results in the same order as input items
+
+     Example:
+         def check_library(name: str) -> bool:
+             return importlib.util.find_spec(name) is not None
+
+         results = run_parallel(check_library, ["torch", "transformers", "peft"])
+     """
+     items_list = list(items)
+
+     if not items_list:
+         return []
+
+     # Use ThreadPoolExecutor for I/O-bound operations
+     with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
+         # Submit all tasks
+         future_to_item = {executor.submit(func, item): item for item in items_list}
+
+         results: List[R] = []
+         completed = 0
+
+         # Process completed tasks
+         for future in concurrent.futures.as_completed(future_to_item, timeout=timeout):
+             item = future_to_item[future]
+             try:
+                 result = future.result()
+                 results.append(result)
+                 completed += 1
+             except Exception as e:
+                 logger.error(f"Error processing {item}: {e}")
+                 # Re-raise to maintain error behavior
+                 raise
+
+         if completed != len(items_list):
+             raise RuntimeError(f"Only {completed}/{len(items_list)} tasks completed")
+
+     return results
+
+
+ def run_parallel_with_results(
+     func: Callable[[T], R],
+     items: Iterable[T],
+     max_workers: int = 4,
+     timeout: float | None = None,
+ ) -> List[tuple[T, R | Exception]]:
+     """
+     Run a function in parallel and return results with original items.
+
+     Unlike run_parallel, this catches exceptions and returns them as results.
+
+     Args:
+         func: Function to execute
+         items: Iterable of items to process
+         max_workers: Maximum number of parallel workers
+         timeout: Maximum time to wait for all tasks
+
+     Returns:
+         List of (item, result_or_exception) tuples
+
+     Example:
+         def check_library(name: str) -> bool:
+             if name == "bad":
+                 raise ValueError("Bad library")
+             return True
+
+         results = run_parallel_with_results(check_library, ["torch", "bad", "peft"])
+         # Returns: [("torch", True), ("bad", ValueError(...)), ("peft", True)]
+     """
+     items_list = list(items)
+
+     if not items_list:
+         return []
+
+     with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
+         future_to_item = {executor.submit(func, item): item for item in items_list}
+
+         results: List[tuple[T, R | Exception]] = []
+
+         for future in concurrent.futures.as_completed(future_to_item, timeout=timeout):
+             item = future_to_item[future]
+             try:
+                 result = future.result()
+                 results.append((item, result))
+             except Exception as e:
+                 results.append((item, e))
+
+     return results
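Note: an illustrative caller for the exception-capturing variant (check_library is a hypothetical probe, and the import path is omitted because the file header for this added module is missing from the diff; run_parallel_with_results itself is defined above):

    import importlib.util

    def check_library(name: str) -> bool:
        return importlib.util.find_spec(name) is not None

    # run_parallel_with_results returns (item, result_or_exception) tuples
    for lib, result in run_parallel_with_results(check_library, ["torch", "transformers", "peft"]):
        if isinstance(result, Exception):
            print(f"{lib}: check failed ({result})")
        else:
            print(f"{lib}: {'installed' if result else 'missing'}")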
mlenvdoctor/retry.py ADDED
@@ -0,0 +1,92 @@
+ """Retry logic for transient failures."""
+
+ import functools
+ import time
+ from typing import Any, Callable, Optional, TypeVar
+
+ from .exceptions import DiagnosticError
+ from .logger import logger
+
+ T = TypeVar("T")
+
+
+ def retry(
+     max_attempts: int = 3,
+     delay: float = 1.0,
+     backoff: float = 2.0,
+     exceptions: tuple[type[Exception], ...] = (Exception,),
+     on_retry: Optional[Callable[[Exception, int], None]] = None,
+ ) -> Callable[[Callable[..., T]], Callable[..., T]]:
+     """
+     Decorator to retry a function on failure.
+
+     Args:
+         max_attempts: Maximum number of retry attempts
+         delay: Initial delay between retries in seconds
+         backoff: Multiplier for delay after each retry
+         exceptions: Tuple of exceptions to catch and retry
+         on_retry: Optional callback called on each retry
+
+     Returns:
+         Decorated function with retry logic
+
+     Example:
+         @retry(max_attempts=3, delay=1.0, exceptions=(ConnectionError,))
+         def fetch_data():
+             ...
+     """
+
+     def decorator(func: Callable[..., T]) -> Callable[..., T]:
+         @functools.wraps(func)
+         def wrapper(*args: Any, **kwargs: Any) -> T:
+             current_delay = delay
+             last_exception: Optional[Exception] = None
+
+             for attempt in range(1, max_attempts + 1):
+                 try:
+                     return func(*args, **kwargs)
+                 except exceptions as e:
+                     last_exception = e
+                     if attempt < max_attempts:
+                         logger.warning(
+                             f"{func.__name__} failed (attempt {attempt}/{max_attempts}): {e}. "
+                             f"Retrying in {current_delay:.1f}s..."
+                         )
+                         if on_retry:
+                             on_retry(e, attempt)
+                         time.sleep(current_delay)
+                         current_delay *= backoff
+                     else:
+                         logger.error(f"{func.__name__} failed after {max_attempts} attempts: {e}")
+
+             # All attempts failed
+             if last_exception:
+                 raise DiagnosticError(
+                     f"{func.__name__} failed after {max_attempts} attempts",
+                     f"Last error: {last_exception}",
+                 ) from last_exception
+
+             # Should never reach here, but satisfy type checker
+             raise RuntimeError("Retry logic failed unexpectedly")
+
+         return wrapper
+
+     return decorator
+
+
+ def retry_network(func: Callable[..., T]) -> Callable[..., T]:
+     """
+     Decorator specifically for network operations.
+
+     Retries on network-related exceptions with exponential backoff.
+     """
+     return retry(
+         max_attempts=3,
+         delay=1.0,
+         backoff=2.0,
+         exceptions=(
+             ConnectionError,
+             TimeoutError,
+             OSError,
+         ),
+     )(func)
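Note: a hedged example of the network-retry decorator (fetch_index_page is hypothetical; retry_network is defined above and retries on ConnectionError, TimeoutError, and OSError):

    from urllib.request import urlopen

    from mlenvdoctor.retry import retry_network

    @retry_network
    def fetch_index_page(url: str) -> bytes:
        # urllib raises URLError (an OSError subclass) on transient network failures,
        # so retry_network retries up to 3 times with exponential backoff
        with urlopen(url, timeout=5) as resp:
            return resp.read()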