mlx-stack 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlx_stack/__init__.py +5 -0
- mlx_stack/_version.py +24 -0
- mlx_stack/cli/__init__.py +5 -0
- mlx_stack/cli/bench.py +221 -0
- mlx_stack/cli/config.py +166 -0
- mlx_stack/cli/down.py +109 -0
- mlx_stack/cli/init.py +180 -0
- mlx_stack/cli/install.py +165 -0
- mlx_stack/cli/logs.py +234 -0
- mlx_stack/cli/main.py +187 -0
- mlx_stack/cli/models.py +304 -0
- mlx_stack/cli/profile.py +65 -0
- mlx_stack/cli/pull.py +134 -0
- mlx_stack/cli/recommend.py +397 -0
- mlx_stack/cli/status.py +111 -0
- mlx_stack/cli/up.py +163 -0
- mlx_stack/cli/watch.py +252 -0
- mlx_stack/core/__init__.py +1 -0
- mlx_stack/core/benchmark.py +1182 -0
- mlx_stack/core/catalog.py +560 -0
- mlx_stack/core/config.py +471 -0
- mlx_stack/core/deps.py +323 -0
- mlx_stack/core/hardware.py +304 -0
- mlx_stack/core/launchd.py +531 -0
- mlx_stack/core/litellm_gen.py +188 -0
- mlx_stack/core/log_rotation.py +231 -0
- mlx_stack/core/log_viewer.py +386 -0
- mlx_stack/core/models.py +639 -0
- mlx_stack/core/paths.py +79 -0
- mlx_stack/core/process.py +887 -0
- mlx_stack/core/pull.py +815 -0
- mlx_stack/core/scoring.py +611 -0
- mlx_stack/core/stack_down.py +317 -0
- mlx_stack/core/stack_init.py +524 -0
- mlx_stack/core/stack_status.py +229 -0
- mlx_stack/core/stack_up.py +856 -0
- mlx_stack/core/watchdog.py +744 -0
- mlx_stack/data/__init__.py +1 -0
- mlx_stack/data/catalog/__init__.py +1 -0
- mlx_stack/data/catalog/deepseek-r1-32b.yaml +46 -0
- mlx_stack/data/catalog/deepseek-r1-8b.yaml +45 -0
- mlx_stack/data/catalog/gemma3-12b.yaml +45 -0
- mlx_stack/data/catalog/gemma3-27b.yaml +45 -0
- mlx_stack/data/catalog/gemma3-4b.yaml +45 -0
- mlx_stack/data/catalog/llama3.3-8b.yaml +44 -0
- mlx_stack/data/catalog/nemotron-49b.yaml +41 -0
- mlx_stack/data/catalog/nemotron-8b.yaml +44 -0
- mlx_stack/data/catalog/qwen3-8b.yaml +45 -0
- mlx_stack/data/catalog/qwen3.5-0.8b.yaml +45 -0
- mlx_stack/data/catalog/qwen3.5-14b.yaml +46 -0
- mlx_stack/data/catalog/qwen3.5-32b.yaml +45 -0
- mlx_stack/data/catalog/qwen3.5-3b.yaml +44 -0
- mlx_stack/data/catalog/qwen3.5-72b.yaml +42 -0
- mlx_stack/data/catalog/qwen3.5-8b.yaml +45 -0
- mlx_stack/py.typed +1 -0
- mlx_stack/utils/__init__.py +1 -0
- mlx_stack-0.1.0.dist-info/METADATA +397 -0
- mlx_stack-0.1.0.dist-info/RECORD +61 -0
- mlx_stack-0.1.0.dist-info/WHEEL +4 -0
- mlx_stack-0.1.0.dist-info/entry_points.txt +2 -0
- mlx_stack-0.1.0.dist-info/licenses/LICENSE +21 -0
mlx_stack/core/deps.py
ADDED
|
@@ -0,0 +1,323 @@
|
|
|
1
|
+
"""Dependency management module for mlx-stack.
|
|
2
|
+
|
|
3
|
+
Checks for and auto-installs pinned versions of vllm-mlx and litellm
|
|
4
|
+
as uv tools. Performs PATH lookup to detect installed tools, auto-installs
|
|
5
|
+
via ``uv tool install <tool>==<version>`` when missing, shows progress
|
|
6
|
+
during installation, verifies post-install availability, detects version
|
|
7
|
+
mismatches with warnings, and provides clear manual install instructions
|
|
8
|
+
on failure.
|
|
9
|
+
|
|
10
|
+
Read-only commands (profile, config, recommend, models) do NOT trigger
|
|
11
|
+
dependency checks.
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
from __future__ import annotations
|
|
15
|
+
|
|
16
|
+
import os
|
|
17
|
+
import re
|
|
18
|
+
import shutil
|
|
19
|
+
import subprocess
|
|
20
|
+
from dataclasses import dataclass
|
|
21
|
+
|
|
22
|
+
from rich.console import Console
|
|
23
|
+
|
|
24
|
+
# --------------------------------------------------------------------------- #
|
|
25
|
+
# Pinned versions
|
|
26
|
+
# --------------------------------------------------------------------------- #
|
|
27
|
+
|
|
28
|
+
PINNED_VERSIONS: dict[str, str] = {
|
|
29
|
+
"vllm-mlx": "0.2.6",
|
|
30
|
+
"litellm": "1.67.2",
|
|
31
|
+
}
|
|
32
|
+
|
|
33
|
+
# Map tool name to the CLI binary name used for PATH lookup
|
|
34
|
+
_TOOL_BINARY_MAP: dict[str, str] = {
|
|
35
|
+
"vllm-mlx": "vllm-mlx",
|
|
36
|
+
"litellm": "litellm",
|
|
37
|
+
}
|
|
38
|
+
|
|
39
|
+
# --------------------------------------------------------------------------- #
|
|
40
|
+
# Exceptions
|
|
41
|
+
# --------------------------------------------------------------------------- #
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class DependencyError(Exception):
    """Raised when a required dependency cannot be found or installed."""


class DependencyInstallError(DependencyError):
    """Raised when ``uv tool install`` fails."""


class DependencyVersionMismatchWarning(UserWarning):
    """Category describing an installed-vs-pinned version mismatch.

    This module does not raise it; mismatches are reported on the console.
    It subclasses :class:`UserWarning` (and therefore still
    :class:`Exception`) so it is a valid ``category`` for
    :func:`warnings.warn` and :func:`warnings.filterwarnings`.
    """
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
# --------------------------------------------------------------------------- #
|
|
57
|
+
# Data classes
|
|
58
|
+
# --------------------------------------------------------------------------- #
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
@dataclass(frozen=True)
class DependencyStatus:
    """Immutable snapshot of one managed dependency's state.

    Attributes:
        name: Package name, e.g. ``vllm-mlx``.
        pinned_version: The version this module installs and expects.
        installed: ``True`` when the tool's executable was found on PATH.
        installed_version: Version detected for the tool, or ``None`` when
            it could not be determined.
        version_match: ``True``/``False`` comparing the detected version to
            the pinned one; ``None`` when no version was detected.
    """

    # What we manage and which version we want.
    name: str
    pinned_version: str
    # What was actually found on this machine.
    installed: bool
    installed_version: str | None
    version_match: bool | None
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
# --------------------------------------------------------------------------- #
|
|
81
|
+
# Internal helpers
|
|
82
|
+
# --------------------------------------------------------------------------- #
|
|
83
|
+
|
|
84
|
+
# Shared Rich console used for install progress and mismatch warnings;
# it writes to stderr.
_console: Console = Console(stderr=True)
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def _find_binary(tool: str) -> str | None:
    """Locate *tool*'s executable on PATH.

    The executable name comes from ``_TOOL_BINARY_MAP``; tools missing from
    that map are assumed to use their package name. Lookup is done with
    :func:`shutil.which`.

    Returns:
        The absolute path of the executable, or ``None`` if it is not on PATH.
    """
    executable = _TOOL_BINARY_MAP[tool] if tool in _TOOL_BINARY_MAP else tool
    return shutil.which(executable)
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def _get_installed_version(tool: str) -> str | None:
|
|
97
|
+
"""Attempt to determine the installed version of a uv tool.
|
|
98
|
+
|
|
99
|
+
Runs ``uv tool list`` and parses the output for the tool name and
|
|
100
|
+
version. Returns ``None`` when the tool is not found or the output
|
|
101
|
+
cannot be parsed.
|
|
102
|
+
"""
|
|
103
|
+
uv_path = shutil.which("uv")
|
|
104
|
+
if uv_path is None:
|
|
105
|
+
return None
|
|
106
|
+
|
|
107
|
+
try:
|
|
108
|
+
env = {**os.environ, "NO_COLOR": "1"}
|
|
109
|
+
result = subprocess.run(
|
|
110
|
+
[uv_path, "tool", "list"],
|
|
111
|
+
capture_output=True,
|
|
112
|
+
text=True,
|
|
113
|
+
timeout=30,
|
|
114
|
+
env=env,
|
|
115
|
+
)
|
|
116
|
+
except (subprocess.TimeoutExpired, OSError):
|
|
117
|
+
return None
|
|
118
|
+
|
|
119
|
+
if result.returncode != 0:
|
|
120
|
+
return None
|
|
121
|
+
|
|
122
|
+
# Parse lines like "vllm-mlx v0.2.6" or "litellm v1.67.2"
|
|
123
|
+
for line in result.stdout.splitlines():
|
|
124
|
+
# Match "<tool> v<version>" pattern
|
|
125
|
+
pattern = rf"^{re.escape(tool)}\s+v?(\S+)"
|
|
126
|
+
match = re.match(pattern, line.strip())
|
|
127
|
+
if match:
|
|
128
|
+
return match.group(1)
|
|
129
|
+
|
|
130
|
+
return None
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
def _install_tool(tool: str, version: str) -> None:
    """Run ``uv tool install <tool>==<version>`` and report the outcome.

    Progress and success messages are printed on the module's stderr
    console. The install is given up to 300 seconds.

    Args:
        tool: Package name (e.g. ``vllm-mlx``).
        version: Exact version to pin in the install spec.

    Raises:
        DependencyError: ``uv`` could not be found on PATH.
        DependencyInstallError: The install timed out, could not be
            started, or exited with a non-zero status.
    """
    uv = shutil.which("uv")
    if uv is None:
        raise DependencyError(
            "uv is not available on PATH. "
            "Install it from https://docs.astral.sh/uv/ and try again."
        )

    spec = f"{tool}=={version}"
    display = f"uv tool install {spec}"
    hint = f"Install manually with: {display}"

    _console.print(f"[cyan]Installing {tool} v{version}...[/cyan]")

    try:
        proc = subprocess.run(
            [uv, "tool", "install", spec],
            capture_output=True,
            text=True,
            timeout=300,
        )
    except subprocess.TimeoutExpired:
        raise DependencyInstallError(
            f"Installation timed out: {display}\n\n{hint}"
        ) from None
    except OSError as err:
        raise DependencyInstallError(
            f"Failed to run: {display}\nError: {err}\n\n{hint}"
        ) from None

    if proc.returncode:
        raise DependencyInstallError(
            f"Failed to install {tool}:\n"
            f" Command: {display}\n"
            f" Error: {proc.stderr.strip()}\n\n"
            f"{hint}"
        )

    _console.print(f"[green]✓ {tool} v{version} installed successfully.[/green]")
|
|
190
|
+
|
|
191
|
+
|
|
192
|
+
def _verify_post_install(tool: str) -> bool:
    """Report whether *tool*'s executable can now be found on PATH.

    Intended to be called right after an install attempt.

    Returns:
        ``True`` when the PATH lookup succeeds, ``False`` otherwise.
    """
    located = _find_binary(tool)
    return located is not None
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
# --------------------------------------------------------------------------- #
|
|
201
|
+
# Public API
|
|
202
|
+
# --------------------------------------------------------------------------- #
|
|
203
|
+
|
|
204
|
+
|
|
205
|
+
def check_dependency(tool: str) -> DependencyStatus:
    """Inspect a single dependency without installing anything.

    The executable is looked up on PATH. Only if it is found is the
    installed version queried and compared with the pinned one.

    Args:
        tool: Package name; must be a key of :data:`PINNED_VERSIONS`.

    Returns:
        A :class:`DependencyStatus`. ``installed_version`` and
        ``version_match`` are ``None`` when the executable is missing or its
        version could not be detected.

    Raises:
        ValueError: If *tool* is not a known dependency.
    """
    if tool not in PINNED_VERSIONS:
        known = ", ".join(sorted(PINNED_VERSIONS))
        raise ValueError(f"Unknown dependency '{tool}'. Known: {known}")

    pinned = PINNED_VERSIONS[tool]
    on_path = _find_binary(tool) is not None

    # Only ask uv for a version when the executable is actually present.
    detected = _get_installed_version(tool) if on_path else None
    matches = None if detected is None else detected == pinned

    return DependencyStatus(
        name=tool,
        pinned_version=pinned,
        installed=on_path,
        installed_version=detected,
        version_match=matches,
    )
|
|
242
|
+
|
|
243
|
+
|
|
244
|
+
def ensure_dependency(tool: str) -> DependencyStatus:
    """Make sure a dependency is installed, installing it if necessary.

    Steps:
        1. Check whether the tool's executable is on PATH.
        2. If it is missing, run ``uv tool install <tool>==<version>``.
        3. Confirm the executable is now on PATH, then re-check the status.
        4. Print a warning (without failing) if the detected version differs
           from the pinned one.

    Args:
        tool: Package name; must be a key of :data:`PINNED_VERSIONS`.

    Returns:
        The :class:`DependencyStatus` after any installation.

    Raises:
        ValueError: If *tool* is not a known dependency.
        DependencyError: If ``uv`` is not available.
        DependencyInstallError: If the install fails or the executable is
            still not on PATH afterwards.
    """
    status = check_dependency(tool)

    if not status.installed:
        pinned = status.pinned_version
        _install_tool(tool, pinned)

        if not _verify_post_install(tool):
            manual = f"uv tool install {tool}=={pinned}"
            raise DependencyInstallError(
                f"{tool} was not found on PATH after installation.\n"
                "This may be because the uv tool bin directory is not in your PATH.\n\n"
                "Try running:\n"
                f" {manual}\n"
                ' export PATH="$HOME/.local/bin:$PATH"'
            )

        # Refresh so the returned status reflects the fresh install.
        status = check_dependency(tool)

    if status.installed and status.version_match is False:
        _warn_version_mismatch(tool, status)

    return status
|
|
289
|
+
|
|
290
|
+
|
|
291
|
+
def ensure_all_dependencies() -> list[DependencyStatus]:
    """Run :func:`ensure_dependency` for every tool in :data:`PINNED_VERSIONS`.

    Tools are processed in the mapping's order. The first failure propagates
    immediately, so later tools are not attempted.

    Returns:
        One :class:`DependencyStatus` per managed dependency, in order.

    Raises:
        DependencyError: If ``uv`` is unavailable.
        DependencyInstallError: If any auto-install fails.
    """
    return [ensure_dependency(tool) for tool in PINNED_VERSIONS]
|
|
307
|
+
|
|
308
|
+
|
|
309
|
+
def _warn_version_mismatch(tool: str, status: DependencyStatus) -> None:
    """Print a yellow Rich warning that *tool* is not at its pinned version.

    The message includes the command that reinstalls the pinned version.

    Args:
        tool: The package name.
        status: Status carrying the pinned and detected versions.
    """
    expected = status.pinned_version
    found = status.installed_version or "unknown"
    fix = f"uv tool install {tool}=={expected}"

    headline = (
        f"[yellow]⚠ {tool} version mismatch: "
        f"installed v{found}, expected v{expected}.[/yellow]"
    )
    _console.print(f"{headline}\n Upgrade/downgrade with: {fix}")
|
|
@@ -0,0 +1,304 @@
|
|
|
1
|
+
"""Hardware detection module for Apple Silicon Macs.
|
|
2
|
+
|
|
3
|
+
Detects chip model, GPU core count, unified memory, and memory bandwidth.
|
|
4
|
+
Uses a lookup table of 17 known M-series variants (M1 through M5) and
|
|
5
|
+
estimates bandwidth for unknown chips.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import json
|
|
11
|
+
import re
|
|
12
|
+
import subprocess
|
|
13
|
+
from dataclasses import dataclass
|
|
14
|
+
from typing import Any
|
|
15
|
+
|
|
16
|
+
from mlx_stack.core.paths import ensure_data_home, get_profile_path
|
|
17
|
+
|
|
18
|
+
# --------------------------------------------------------------------------- #
|
|
19
|
+
# Lookup table — 17 known Apple Silicon M-series variants
|
|
20
|
+
# bandwidth_gbps values sourced from Apple published specs
|
|
21
|
+
# --------------------------------------------------------------------------- #
|
|
22
|
+
# Chip brand string (as reported by sysctl) -> spec dict. Built from a
# flat (name, bandwidth GB/s) table to keep the data easy to scan.
# NOTE(review): lookup is by chip name only; some chips may ship in
# variants with different bandwidth — confirm the listed values against
# Apple's per-configuration specs.
CHIP_SPECS: dict[str, dict[str, float | int]] = {
    chip: {"bandwidth_gbps": gbps}
    for chip, gbps in (
        ("Apple M1", 68.25),
        ("Apple M1 Pro", 200.0),
        ("Apple M1 Max", 400.0),
        ("Apple M1 Ultra", 800.0),
        ("Apple M2", 100.0),
        ("Apple M2 Pro", 200.0),
        ("Apple M2 Max", 400.0),
        ("Apple M2 Ultra", 800.0),
        ("Apple M3", 100.0),
        ("Apple M3 Pro", 150.0),
        ("Apple M3 Max", 400.0),
        ("Apple M3 Ultra", 800.0),
        ("Apple M4", 120.0),
        ("Apple M4 Pro", 273.0),
        ("Apple M4 Max", 546.0),
        ("Apple M4 Ultra", 819.2),
        ("Apple M5 Max", 546.0),
    )
}
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class HardwareError(Exception):
    """Hardware could not be detected, or the machine is not supported.

    Raised for failing/missing system commands, unparseable output, and
    non-Apple-Silicon CPUs.
    """
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
@dataclass(frozen=True)
class HardwareProfile:
    """Immutable description of the machine's relevant hardware.

    ``is_estimate`` is ``True`` when ``bandwidth_gbps`` came from a
    heuristic rather than the known-chip table.
    """

    chip: str
    gpu_cores: int
    memory_gb: int
    bandwidth_gbps: float
    is_estimate: bool

    @property
    def profile_id(self) -> str:
        """Short identifier ``<chip>-<memory_gb>``, e.g. ``'m4-pro-64'``.

        The ``'Apple '`` prefix is dropped, the rest is lowercased, and
        spaces become hyphens.
        """
        slug = "-".join(self.chip.removeprefix("Apple ").lower().split(" "))
        return f"{slug}-{self.memory_gb}"

    def to_dict(self) -> dict[str, Any]:
        """Return a JSON-friendly dict of the profile plus ``profile_id``.

        ``is_estimate`` is intentionally not included.
        """
        exported = ("chip", "gpu_cores", "memory_gb", "bandwidth_gbps", "profile_id")
        return {key: getattr(self, key) for key in exported}
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
# --------------------------------------------------------------------------- #
|
|
76
|
+
# System command wrappers (mockable for tests)
|
|
77
|
+
# --------------------------------------------------------------------------- #
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def _run_sysctl(key: str) -> str:
    """Return the whitespace-stripped value of ``sysctl -n <key>``.

    The command is given 10 seconds to complete.

    Raises:
        HardwareError: If sysctl is missing, times out, or exits non-zero.
    """
    try:
        completed = subprocess.run(
            ["sysctl", "-n", key],
            capture_output=True,
            text=True,
            timeout=10,
        )
    except FileNotFoundError:
        raise HardwareError(
            "sysctl command not found — are you running on macOS?"
        ) from None
    except subprocess.TimeoutExpired:
        raise HardwareError(f"sysctl timed out reading key '{key}'") from None

    if completed.returncode != 0:
        raise HardwareError(
            f"sysctl failed for key '{key}': {completed.stderr.strip()}"
        )
    return completed.stdout.strip()
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
def _run_system_profiler() -> str:
    """Return the raw stdout of ``system_profiler SPDisplaysDataType``.

    The command is given 30 seconds to complete. The output is returned
    unstripped.

    Raises:
        HardwareError: If system_profiler is missing, times out, or exits
            non-zero.
    """
    try:
        completed = subprocess.run(
            ["system_profiler", "SPDisplaysDataType"],
            capture_output=True,
            text=True,
            timeout=30,
        )
    except FileNotFoundError:
        raise HardwareError(
            "system_profiler command not found — are you running on macOS?"
        ) from None
    except subprocess.TimeoutExpired:
        raise HardwareError("system_profiler timed out") from None

    if completed.returncode != 0:
        raise HardwareError(f"system_profiler failed: {completed.stderr.strip()}")
    return completed.stdout
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
# --------------------------------------------------------------------------- #
|
|
131
|
+
# Detection functions
|
|
132
|
+
# --------------------------------------------------------------------------- #
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
def detect_chip() -> str:
    """Read the CPU brand string and confirm it is an Apple M-series chip.

    Returns:
        The brand string as reported by sysctl, e.g. ``'Apple M4 Pro'``.

    Raises:
        HardwareError: If the brand string is empty/unreadable or does not
            start with ``Apple M<digit>``.
    """
    brand = _run_sysctl("machdep.cpu.brand_string")
    if not brand:
        raise HardwareError("Could not read CPU brand string from sysctl")

    # re.match anchors at the start: "Apple M1", "Apple M4 Pro", ...
    if re.match(r"Apple M\d", brand) is None:
        raise HardwareError("mlx-stack requires Apple Silicon (M1 or later)")

    return brand
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
def detect_memory_gb() -> int:
    """Return total unified memory in whole GiB (from ``sysctl hw.memsize``).

    The byte count is floor-divided by 1024**3.

    Raises:
        HardwareError: If sysctl fails or the value is not an integer.
    """
    raw = _run_sysctl("hw.memsize")
    try:
        total_bytes = int(raw)
    except ValueError:
        raise HardwareError(f"Unexpected hw.memsize value: {raw!r}") from None

    bytes_per_gib = 1024 ** 3
    return total_bytes // bytes_per_gib
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
def detect_gpu_cores() -> int:
    """Return the GPU core count parsed from ``system_profiler`` output.

    Uses the first ``Total Number of Cores: <n>`` line found.

    Raises:
        HardwareError: If system_profiler fails or no such line is present.
    """
    report = _run_system_profiler()
    found = re.search(r"Total Number of Cores:\s*(\d+)", report)
    if found is None:
        raise HardwareError(
            "Could not determine GPU core count from system_profiler output"
        )
    return int(found.group(1))
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
def estimate_bandwidth(memory_gb: int) -> float:
    """Guess memory bandwidth (GB/s) for a chip missing from the lookup table.

    The heuristic assumes larger memory configurations have higher
    bandwidth. It is only a rough figure — run `mlx-stack bench --save` for
    measured numbers.

    Args:
        memory_gb: Total unified memory in GB.

    Returns:
        The bandwidth of the first tier whose memory ceiling is >= *memory_gb*,
        or 800.0 above the largest tier.
    """
    # (inclusive memory ceiling in GB, estimated bandwidth in GB/s)
    tiers = (
        (8, 68.0),
        (16, 100.0),
        (32, 200.0),
        (64, 400.0),
        (128, 546.0),
    )
    for ceiling_gb, gbps in tiers:
        if memory_gb <= ceiling_gb:
            return gbps
    return 800.0
|
|
220
|
+
|
|
221
|
+
|
|
222
|
+
def lookup_bandwidth(chip: str, memory_gb: int) -> tuple[float, bool]:
    """Resolve memory bandwidth for *chip*, estimating if it is unknown.

    Args:
        chip: Full chip name as reported by sysctl (e.g. 'Apple M4 Pro').
        memory_gb: Total unified memory in GB, used only for the estimate.

    Returns:
        ``(bandwidth_gbps, is_estimate)``. ``is_estimate`` is ``False`` when
        the value came from ``CHIP_SPECS`` and ``True`` when it came from
        :func:`estimate_bandwidth`.
    """
    entry = CHIP_SPECS.get(chip)
    if entry is None:
        return estimate_bandwidth(memory_gb), True
    return float(entry["bandwidth_gbps"]), False
|
|
237
|
+
|
|
238
|
+
|
|
239
|
+
# --------------------------------------------------------------------------- #
|
|
240
|
+
# Main detection entry point
|
|
241
|
+
# --------------------------------------------------------------------------- #
|
|
242
|
+
|
|
243
|
+
|
|
244
|
+
def detect_hardware() -> HardwareProfile:
    """Probe the machine and assemble a :class:`HardwareProfile`.

    Detection runs in this order: chip, memory, GPU cores, then bandwidth
    (table lookup, falling back to an estimate for unknown chips). The first
    failing step aborts the whole detection.

    Returns:
        The detected profile.

    Raises:
        HardwareError: If any probe fails or the hardware is unsupported.
    """
    chip_name = detect_chip()
    total_memory = detect_memory_gb()
    core_count = detect_gpu_cores()
    bandwidth, estimated = lookup_bandwidth(chip_name, total_memory)

    return HardwareProfile(
        chip=chip_name,
        gpu_cores=core_count,
        memory_gb=total_memory,
        bandwidth_gbps=bandwidth,
        is_estimate=estimated,
    )
|
|
269
|
+
|
|
270
|
+
|
|
271
|
+
def save_profile(profile: HardwareProfile) -> None:
    """Persist *profile* as indented JSON at the profile path.

    The data directory is created first if needed, and any existing profile
    file is overwritten. The file ends with a trailing newline.

    Args:
        profile: The hardware profile to save.
    """
    ensure_data_home()
    target = get_profile_path()
    serialized = json.dumps(profile.to_dict(), indent=2)
    target.write_text(serialized + "\n")
|
|
282
|
+
|
|
283
|
+
|
|
284
|
+
def load_profile() -> HardwareProfile | None:
    """Load the saved hardware profile from disk, if it exists.

    Returns:
        A HardwareProfile built from the saved JSON. Returns ``None`` when
        the file is missing or unreadable, is not valid UTF-8/JSON, is not a
        JSON object, or lacks a required key. Loaded profiles always have
        ``is_estimate=False``, because saved profiles are considered
        authoritative.
    """
    profile_path = get_profile_path()

    try:
        # EAFP rather than exists()+read: also covers the file disappearing
        # in between, permission errors, and a directory at this path.
        data = json.loads(profile_path.read_text(encoding="utf-8"))
    except (OSError, ValueError):
        # ValueError covers both json.JSONDecodeError and UnicodeDecodeError.
        return None

    try:
        return HardwareProfile(
            chip=data["chip"],
            gpu_cores=data["gpu_cores"],
            memory_gb=data["memory_gb"],
            bandwidth_gbps=data["bandwidth_gbps"],
            is_estimate=False,  # saved profiles are considered authoritative
        )
    except (KeyError, TypeError):
        # KeyError: a field is missing; TypeError: the top-level JSON value
        # is not an object (e.g. a list or a string).
        return None
|