mlenvdoctor 0.1.1__py3-none-any.whl → 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
mlenvdoctor/export.py ADDED
@@ -0,0 +1,290 @@
+ """Export functionality for diagnostic results."""
+
+ import json
+ from datetime import datetime
+ from pathlib import Path
+ from typing import Any, Dict, List, Optional
+
+ from .diagnose import DiagnosticIssue
+
+
+ def issue_to_dict(issue: DiagnosticIssue) -> Dict[str, Any]:
+     """Convert DiagnosticIssue to dictionary."""
+     return {
+         "name": issue.name,
+         "status": issue.status,
+         "severity": issue.severity,
+         "fix": issue.fix,
+         "details": issue.details,
+     }
+
+
+ def export_json(
+     issues: List[DiagnosticIssue],
+     output_file: Optional[Path] = None,
+     include_metadata: bool = True,
+ ) -> Path:
+     """
+     Export diagnostic results to JSON.
+
+     Args:
+         issues: List of diagnostic issues
+         output_file: Output file path (default: diagnostic-results.json)
+         include_metadata: Include metadata (timestamp, version, etc.)
+
+     Returns:
+         Path to exported file
+     """
+     if output_file is None:
+         output_file = Path("diagnostic-results.json")
+
+     # Convert issues to dictionaries
+     issues_data = [issue_to_dict(issue) for issue in issues]
+
+     # Calculate summary
+     critical_count = sum(
+         1 for i in issues if i.severity == "critical" and "FAIL" in i.status
+     )
+     warning_count = sum(
+         1 for i in issues if i.severity == "warning" and ("WARN" in i.status or "FAIL" in i.status)
+     )
+     pass_count = sum(1 for i in issues if "PASS" in i.status)
+
+     # Build export data
+     export_data: Dict[str, Any] = {
+         "issues": issues_data,
+         "summary": {
+             "total": len(issues),
+             "passed": pass_count,
+             "warnings": warning_count,
+             "critical": critical_count,
+         },
+     }
+
+     # Add metadata if requested
+     if include_metadata:
+         from . import __version__
+
+         export_data["metadata"] = {
+             "version": __version__,
+             "timestamp": datetime.now().isoformat(),
+             "tool": "mlenvdoctor",
+         }
+
+     # Write to file
+     output_file.write_text(json.dumps(export_data, indent=2, ensure_ascii=False), encoding="utf-8")
+
+     return output_file
+
+
+ def export_csv(issues: List[DiagnosticIssue], output_file: Optional[Path] = None) -> Path:
+     """
+     Export diagnostic results to CSV.
+
+     Args:
+         issues: List of diagnostic issues
+         output_file: Output file path (default: diagnostic-results.csv)
+
+     Returns:
+         Path to exported file
+     """
+     import csv
+
+     if output_file is None:
+         output_file = Path("diagnostic-results.csv")
+
+     with output_file.open("w", newline="", encoding="utf-8") as f:
+         writer = csv.writer(f)
+         writer.writerow(["Issue", "Status", "Severity", "Fix", "Details"])
+
+         for issue in issues:
+             writer.writerow(
+                 [
+                     issue.name,
+                     issue.status,
+                     issue.severity,
+                     issue.fix,
+                     issue.details or "",
+                 ]
+             )
+
+     return output_file
+
+
+ def export_html(issues: List[DiagnosticIssue], output_file: Optional[Path] = None) -> Path:
+     """
+     Export diagnostic results to HTML report.
+
+     Args:
+         issues: List of diagnostic issues
+         output_file: Output file path (default: diagnostic-results.html)
+
+     Returns:
+         Path to exported file
+     """
+     if output_file is None:
+         output_file = Path("diagnostic-results.html")
+
+     # Calculate summary
+     critical_count = sum(
+         1 for i in issues if i.severity == "critical" and "FAIL" in i.status
+     )
+     warning_count = sum(
+         1 for i in issues if i.severity == "warning" and ("WARN" in i.status or "FAIL" in i.status)
+     )
+     pass_count = sum(1 for i in issues if "PASS" in i.status)
+
+     from . import __version__
+
+     html_content = f"""<!DOCTYPE html>
+ <html lang="en">
+ <head>
+     <meta charset="UTF-8">
+     <meta name="viewport" content="width=device-width, initial-scale=1.0">
+     <title>ML Environment Doctor - Diagnostic Report</title>
+     <style>
+         body {{
+             font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, Cantarell, sans-serif;
+             max-width: 1200px;
+             margin: 0 auto;
+             padding: 20px;
+             background-color: #f5f5f5;
+         }}
+         .header {{
+             background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
+             color: white;
+             padding: 30px;
+             border-radius: 10px;
+             margin-bottom: 20px;
+         }}
+         .summary {{
+             display: grid;
+             grid-template-columns: repeat(auto-fit, minmax(200px, 1fr));
+             gap: 15px;
+             margin-bottom: 30px;
+         }}
+         .summary-card {{
+             background: white;
+             padding: 20px;
+             border-radius: 8px;
+             box-shadow: 0 2px 4px rgba(0,0,0,0.1);
+         }}
+         .summary-card h3 {{
+             margin: 0 0 10px 0;
+             font-size: 14px;
+             color: #666;
+         }}
+         .summary-card .value {{
+             font-size: 32px;
+             font-weight: bold;
+         }}
+         .passed {{ color: #10b981; }}
+         .warning {{ color: #f59e0b; }}
+         .critical {{ color: #ef4444; }}
+         table {{
+             width: 100%;
+             background: white;
+             border-collapse: collapse;
+             border-radius: 8px;
+             overflow: hidden;
+             box-shadow: 0 2px 4px rgba(0,0,0,0.1);
+         }}
+         th {{
+             background-color: #667eea;
+             color: white;
+             padding: 15px;
+             text-align: left;
+             font-weight: 600;
+         }}
+         td {{
+             padding: 12px 15px;
+             border-bottom: 1px solid #e5e7eb;
+         }}
+         tr:hover {{
+             background-color: #f9fafb;
+         }}
+         .status-pass {{ color: #10b981; font-weight: 600; }}
+         .status-fail {{ color: #ef4444; font-weight: 600; }}
+         .status-warn {{ color: #f59e0b; font-weight: 600; }}
+         .severity-critical {{ background-color: #fee2e2; }}
+         .severity-warning {{ background-color: #fef3c7; }}
+         .severity-info {{ background-color: #dbeafe; }}
+         .footer {{
+             margin-top: 30px;
+             text-align: center;
+             color: #666;
+             font-size: 12px;
+         }}
+     </style>
+ </head>
+ <body>
+     <div class="header">
+         <h1>ML Environment Doctor</h1>
+         <p>Diagnostic Report - Generated on {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}</p>
+         <p style="font-size: 14px; opacity: 0.9;">Version {__version__}</p>
+     </div>
+
+     <div class="summary">
+         <div class="summary-card">
+             <h3>Total Checks</h3>
+             <div class="value">{len(issues)}</div>
+         </div>
+         <div class="summary-card">
+             <h3>Passed</h3>
+             <div class="value passed">{pass_count}</div>
+         </div>
+         <div class="summary-card">
+             <h3>Warnings</h3>
+             <div class="value warning">{warning_count}</div>
+         </div>
+         <div class="summary-card">
+             <h3>Critical Issues</h3>
+             <div class="value critical">{critical_count}</div>
+         </div>
+     </div>
+
+     <table>
+         <thead>
+             <tr>
+                 <th>Issue</th>
+                 <th>Status</th>
+                 <th>Severity</th>
+                 <th>Fix</th>
+                 <th>Details</th>
+             </tr>
+         </thead>
+         <tbody>
+ """
+
+     for issue in issues:
+         status_class = "status-pass"
+         if "FAIL" in issue.status:
+             status_class = "status-fail"
+         elif "WARN" in issue.status:
+             status_class = "status-warn"
+
+         severity_class = f"severity-{issue.severity}"
+
+         html_content += f"""
+         <tr class="{severity_class}">
+             <td><strong>{issue.name}</strong></td>
+             <td class="{status_class}">{issue.status}</td>
+             <td>{issue.severity.upper()}</td>
+             <td>{issue.fix or '-'}</td>
+             <td>{issue.details or '-'}</td>
+         </tr>
+         """
+
+     html_content += """
+         </tbody>
+     </table>
+
+     <div class="footer">
+         <p>Generated by ML Environment Doctor | <a href="https://github.com/dheena731/ml_env_doctor">GitHub</a></p>
+     </div>
+ </body>
+ </html>
+ """
+
+     output_file.write_text(html_content, encoding="utf-8")
+     return output_file
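For orientation, a minimal usage sketch of the new export helpers follows. The DiagnosticIssue constructor call is an assumption inferred from the fields issue_to_dict reads (name, status, severity, fix, details); the actual class definition lives in mlenvdoctor/diagnose.py, which is not part of this diff.

from pathlib import Path

from mlenvdoctor.diagnose import DiagnosticIssue
from mlenvdoctor.export import export_csv, export_html, export_json

# Hypothetical issue; field names mirror those read by issue_to_dict
issue = DiagnosticIssue(
    name="CUDA available",
    status="FAIL",
    severity="critical",
    fix="Install a CUDA-enabled PyTorch build",
    details="torch.cuda.is_available() returned False",
)

export_json([issue], Path("report.json"))   # structured data plus summary and metadata
export_csv([issue], Path("report.csv"))     # flat, spreadsheet-friendly view
export_html([issue], Path("report.html"))   # self-contained styled report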
mlenvdoctor/fix.py CHANGED
@@ -5,6 +5,7 @@ from pathlib import Path
  from typing import Optional

  from .diagnose import diagnose_env
+ from .icons import icon_wrench
  from .utils import (
      check_command_exists,
      console,
@@ -212,7 +213,7 @@ def create_virtualenv(env_name: str = ".venv") -> Optional[Path]:

  def auto_fix(use_conda: bool = False, create_venv: bool = False, stack: str = "trl-peft") -> bool:
      """Auto-fix environment issues based on diagnostics."""
-     console.print("[bold blue]🔧 Running Auto-Fix...[/bold blue]\n")
+     console.print(f"[bold blue]{icon_wrench()} Running Auto-Fix...[/bold blue]\n")

      # Run diagnostics
      issues = diagnose_env(full=False)
mlenvdoctor/gpu.py CHANGED
@@ -8,11 +8,17 @@ try:
  except ImportError:
      torch = None  # type: ignore

+ import sys
+
  from rich.console import Console

  from .utils import print_error, print_info, print_success

- console = Console()
+ # Configure console for Windows compatibility
+ if sys.platform == "win32":
+     console = Console(legacy_windows=True, force_terminal=True)
+ else:
+     console = Console()


  def benchmark_gpu_ops() -> Dict[str, float]:
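The platform-gated Console construction above is the pattern other modules would copy when printing Rich output on Windows. A minimal sketch of it factored into a helper, assuming only the rich dependency the package already uses (make_console is a hypothetical name, not part of this diff):

import sys

from rich.console import Console


def make_console() -> Console:
    """Hypothetical helper wrapping the platform gate used in gpu.py."""
    if sys.platform == "win32":
        # legacy_windows renders through the older console API;
        # force_terminal keeps styling on even when output is redirected
        return Console(legacy_windows=True, force_terminal=True)
    return Console()


console = make_console()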
mlenvdoctor/icons.py ADDED
@@ -0,0 +1,100 @@
+ """Safe emoji/icon handling for cross-platform compatibility."""
+
+ import sys
+ from typing import Literal
+
+ # Check if we can safely use emojis
+ _USE_EMOJIS = True
+
+ try:
+     if sys.platform == "win32":
+         # Test whether the console encoding can represent an emoji.
+         # (Writing to io.StringIO never raises, since StringIO stores
+         # text without encoding it, so encode against stdout instead.)
+         encoding = getattr(sys.stdout, "encoding", None) or "ascii"
+         try:
+             "🔍".encode(encoding)
+         except (UnicodeEncodeError, UnicodeError, LookupError):
+             _USE_EMOJIS = False
+ except Exception:
+     _USE_EMOJIS = False
+
+
+ def get_icon(icon_name: Literal["search", "check", "cross", "warning", "info", "wrench", "whale", "test"]) -> str:
+     """
+     Get a safe icon/emoji for the current platform.
+
+     Args:
+         icon_name: Name of the icon to get
+
+     Returns:
+         Emoji if supported, ASCII alternative otherwise
+     """
+     if _USE_EMOJIS:
+         icons = {
+             "search": "🔍",
+             "check": "✅",
+             "cross": "❌",
+             "warning": "⚠️",
+             "info": "ℹ️",
+             "wrench": "🔧",
+             "whale": "🐳",
+             "test": "🧪",
+         }
+     else:
+         # ASCII alternatives
+         icons = {
+             "search": "[*]",
+             "check": "[OK]",
+             "cross": "[X]",
+             "warning": "[!]",
+             "info": "[i]",
+             "wrench": "[FIX]",
+             "whale": "[DOCKER]",
+             "test": "[TEST]",
+         }
+
+     return icons.get(icon_name, "")
+
+
+ # Convenience functions
+ def icon_search() -> str:
+     """Get search icon."""
+     return get_icon("search")
+
+
+ def icon_check() -> str:
+     """Get check icon."""
+     return get_icon("check")
+
+
+ def icon_cross() -> str:
+     """Get cross/error icon."""
+     return get_icon("cross")
+
+
+ def icon_warning() -> str:
+     """Get warning icon."""
+     return get_icon("warning")
+
+
+ def icon_info() -> str:
+     """Get info icon."""
+     return get_icon("info")
+
+
+ def icon_wrench() -> str:
+     """Get wrench/fix icon."""
+     return get_icon("wrench")
+
+
+ def icon_whale() -> str:
+     """Get whale/docker icon."""
+     return get_icon("whale")
+
+
+ def icon_test() -> str:
+     """Get test icon."""
+     return get_icon("test")
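A short sketch of how the icon helpers degrade: the emoji is returned when the console supports it, and the bracketed ASCII fallback otherwise. The comments show both possible outputs.

from mlenvdoctor.icons import get_icon, icon_check, icon_cross

print(f"{icon_check()} torch found")        # "✅ torch found" or "[OK] torch found"
print(f"{icon_cross()} CUDA missing")       # "❌ CUDA missing" or "[X] CUDA missing"
print(get_icon("whale"), "building image")  # "🐳 building image" or "[DOCKER] building image"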
mlenvdoctor/logger.py ADDED
@@ -0,0 +1,81 @@
+ """Logging configuration for ML Environment Doctor."""
+
+ import logging
+ import sys
+ from pathlib import Path
+ from typing import Optional
+
+ from rich.console import Console
+ from rich.logging import RichHandler
+
+ from .utils import get_home_config_dir
+
+ console = Console()
+
+
+ def setup_logger(
+     name: str = "mlenvdoctor",
+     level: str = "INFO",
+     log_file: Optional[Path] = None,
+     enable_rich: bool = True,
+ ) -> logging.Logger:
+     """
+     Set up logger with Rich console handler and optional file handler.
+
+     Args:
+         name: Logger name
+         level: Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)
+         log_file: Optional path to log file
+         enable_rich: Use Rich handler for console output
+
+     Returns:
+         Configured logger instance
+     """
+     logger = logging.getLogger(name)
+     logger.setLevel(getattr(logging, level.upper()))
+
+     # Remove existing handlers to avoid duplicates
+     logger.handlers.clear()
+
+     # Console handler with Rich formatting
+     if enable_rich:
+         console_handler = RichHandler(
+             console=console,
+             show_time=True,
+             show_path=False,
+             rich_tracebacks=True,
+             tracebacks_show_locals=False,
+         )
+     else:
+         console_handler = logging.StreamHandler(sys.stdout)
+         console_handler.setFormatter(
+             logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
+         )
+
+     console_handler.setLevel(getattr(logging, level.upper()))
+     logger.addHandler(console_handler)
+
+     # File handler if log file specified
+     if log_file:
+         log_file.parent.mkdir(parents=True, exist_ok=True)
+         file_handler = logging.FileHandler(log_file, encoding="utf-8")
+         file_handler.setLevel(logging.DEBUG)  # Always log everything to file
+         file_handler.setFormatter(
+             logging.Formatter(
+                 "%(asctime)s - %(name)s - %(levelname)s - %(funcName)s:%(lineno)d - %(message)s"
+             )
+         )
+         logger.addHandler(file_handler)
+
+     return logger
+
+
+ def get_default_log_file() -> Path:
+     """Get default log file path."""
+     log_dir = get_home_config_dir() / "logs"
+     log_dir.mkdir(parents=True, exist_ok=True)
+     return log_dir / "mlenvdoctor.log"
+
+
+ # Default logger instance
+ logger = setup_logger()
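A usage sketch for the logger module. Note that setup_logger applies the level to the logger itself, so DEBUG records only reach the file handler when level="DEBUG" is passed:

from mlenvdoctor.logger import get_default_log_file, setup_logger

# Rich console output plus a persistent file under the user config directory
log = setup_logger(level="DEBUG", log_file=get_default_log_file())
log.info("diagnostics started")
log.debug("written to mlenvdoctor.log; also shown on the console at this level")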
mlenvdoctor/parallel.py ADDED
@@ -0,0 +1,115 @@
+ """Parallel execution utilities for independent operations."""
+
+ import concurrent.futures
+ from typing import Callable, Iterable, List, TypeVar
+
+ from .logger import logger
+
+ T = TypeVar("T")
+ R = TypeVar("R")
+
+
+ def run_parallel(
+     func: Callable[[T], R],
+     items: Iterable[T],
+     max_workers: int = 4,
+     timeout: float | None = None,
+ ) -> List[R]:
+     """
+     Run a function in parallel on multiple items.
+
+     Args:
+         func: Function to execute
+         items: Iterable of items to process
+         max_workers: Maximum number of parallel workers
+         timeout: Maximum time to wait for all tasks (None = no timeout)
+
+     Returns:
+         List of results in the same order as input items
+
+     Example:
+         def check_library(name: str) -> bool:
+             return importlib.util.find_spec(name) is not None
+
+         results = run_parallel(check_library, ["torch", "transformers", "peft"])
+     """
+     items_list = list(items)
+
+     if not items_list:
+         return []
+
+     # Use ThreadPoolExecutor for I/O-bound operations
+     with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
+         # Submit all tasks, recording each item's input position so results
+         # come back in input order rather than completion order
+         future_to_index = {
+             executor.submit(func, item): index for index, item in enumerate(items_list)
+         }
+
+         results: List[R | None] = [None] * len(items_list)
+         completed = 0
+
+         # Process completed tasks
+         for future in concurrent.futures.as_completed(future_to_index, timeout=timeout):
+             index = future_to_index[future]
+             try:
+                 results[index] = future.result()
+                 completed += 1
+             except Exception as e:
+                 logger.error(f"Error processing {items_list[index]}: {e}")
+                 # Re-raise to maintain error behavior
+                 raise
+
+     if completed != len(items_list):
+         raise RuntimeError(f"Only {completed}/{len(items_list)} tasks completed")
+
+     return results  # type: ignore[return-value]
+
+
+ def run_parallel_with_results(
+     func: Callable[[T], R],
+     items: Iterable[T],
+     max_workers: int = 4,
+     timeout: float | None = None,
+ ) -> List[tuple[T, R | Exception]]:
+     """
+     Run a function in parallel and return results with original items.
+
+     Unlike run_parallel, this catches exceptions and returns them as results.
+
+     Args:
+         func: Function to execute
+         items: Iterable of items to process
+         max_workers: Maximum number of parallel workers
+         timeout: Maximum time to wait for all tasks
+
+     Returns:
+         List of (item, result_or_exception) tuples, in input order
+
+     Example:
+         def check_library(name: str) -> bool:
+             if name == "bad":
+                 raise ValueError("Bad library")
+             return True
+
+         results = run_parallel_with_results(check_library, ["torch", "bad", "peft"])
+         # Returns: [("torch", True), ("bad", ValueError(...)), ("peft", True)]
+     """
+     items_list = list(items)
+
+     if not items_list:
+         return []
+
+     with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
+         future_to_index = {
+             executor.submit(func, item): index for index, item in enumerate(items_list)
+         }
+
+         results: List[tuple[T, R | Exception]] = [None] * len(items_list)  # type: ignore[list-item]
+
+         for future in concurrent.futures.as_completed(future_to_index, timeout=timeout):
+             index = future_to_index[future]
+             item = items_list[index]
+             try:
+                 results[index] = (item, future.result())
+             except Exception as e:
+                 results[index] = (item, e)
+
+     return results
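A usage sketch combining both helpers, assuming the module is importable as mlenvdoctor.parallel; importlib is used exactly as in the docstring examples:

import importlib.util

from mlenvdoctor.parallel import run_parallel, run_parallel_with_results


def has_module(name: str) -> bool:
    return importlib.util.find_spec(name) is not None


names = ["torch", "transformers", "peft"]

# Results come back in input order, so they can be zipped with the names
for name, ok in zip(names, run_parallel(has_module, names, max_workers=3)):
    print(name, "available" if ok else "missing")

# The *_with_results variant keeps going past failures and returns each
# item paired with either its result or the exception it raised
pairs = run_parallel_with_results(has_module, names, timeout=30.0)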