mlx-stack 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlx_stack/__init__.py +5 -0
- mlx_stack/_version.py +24 -0
- mlx_stack/cli/__init__.py +5 -0
- mlx_stack/cli/bench.py +221 -0
- mlx_stack/cli/config.py +166 -0
- mlx_stack/cli/down.py +109 -0
- mlx_stack/cli/init.py +180 -0
- mlx_stack/cli/install.py +165 -0
- mlx_stack/cli/logs.py +234 -0
- mlx_stack/cli/main.py +187 -0
- mlx_stack/cli/models.py +304 -0
- mlx_stack/cli/profile.py +65 -0
- mlx_stack/cli/pull.py +134 -0
- mlx_stack/cli/recommend.py +397 -0
- mlx_stack/cli/status.py +111 -0
- mlx_stack/cli/up.py +163 -0
- mlx_stack/cli/watch.py +252 -0
- mlx_stack/core/__init__.py +1 -0
- mlx_stack/core/benchmark.py +1182 -0
- mlx_stack/core/catalog.py +560 -0
- mlx_stack/core/config.py +471 -0
- mlx_stack/core/deps.py +323 -0
- mlx_stack/core/hardware.py +304 -0
- mlx_stack/core/launchd.py +531 -0
- mlx_stack/core/litellm_gen.py +188 -0
- mlx_stack/core/log_rotation.py +231 -0
- mlx_stack/core/log_viewer.py +386 -0
- mlx_stack/core/models.py +639 -0
- mlx_stack/core/paths.py +79 -0
- mlx_stack/core/process.py +887 -0
- mlx_stack/core/pull.py +815 -0
- mlx_stack/core/scoring.py +611 -0
- mlx_stack/core/stack_down.py +317 -0
- mlx_stack/core/stack_init.py +524 -0
- mlx_stack/core/stack_status.py +229 -0
- mlx_stack/core/stack_up.py +856 -0
- mlx_stack/core/watchdog.py +744 -0
- mlx_stack/data/__init__.py +1 -0
- mlx_stack/data/catalog/__init__.py +1 -0
- mlx_stack/data/catalog/deepseek-r1-32b.yaml +46 -0
- mlx_stack/data/catalog/deepseek-r1-8b.yaml +45 -0
- mlx_stack/data/catalog/gemma3-12b.yaml +45 -0
- mlx_stack/data/catalog/gemma3-27b.yaml +45 -0
- mlx_stack/data/catalog/gemma3-4b.yaml +45 -0
- mlx_stack/data/catalog/llama3.3-8b.yaml +44 -0
- mlx_stack/data/catalog/nemotron-49b.yaml +41 -0
- mlx_stack/data/catalog/nemotron-8b.yaml +44 -0
- mlx_stack/data/catalog/qwen3-8b.yaml +45 -0
- mlx_stack/data/catalog/qwen3.5-0.8b.yaml +45 -0
- mlx_stack/data/catalog/qwen3.5-14b.yaml +46 -0
- mlx_stack/data/catalog/qwen3.5-32b.yaml +45 -0
- mlx_stack/data/catalog/qwen3.5-3b.yaml +44 -0
- mlx_stack/data/catalog/qwen3.5-72b.yaml +42 -0
- mlx_stack/data/catalog/qwen3.5-8b.yaml +45 -0
- mlx_stack/py.typed +1 -0
- mlx_stack/utils/__init__.py +1 -0
- mlx_stack-0.1.0.dist-info/METADATA +397 -0
- mlx_stack-0.1.0.dist-info/RECORD +61 -0
- mlx_stack-0.1.0.dist-info/WHEEL +4 -0
- mlx_stack-0.1.0.dist-info/entry_points.txt +2 -0
- mlx_stack-0.1.0.dist-info/licenses/LICENSE +21 -0
mlx_stack/core/pull.py
ADDED
|
@@ -0,0 +1,815 @@
|
|
|
1
|
+
"""Model download and inventory management for mlx-stack.
|
|
2
|
+
|
|
3
|
+
Resolves catalog ID + quant to HuggingFace source, prefers mlx-community
|
|
4
|
+
pre-converted weights with fallback to mlx_lm conversion. Checks disk
|
|
5
|
+
space before download, shows progress, tracks inventory in models.json,
|
|
6
|
+
detects duplicates, handles network errors with automatic retry, and
|
|
7
|
+
cleans up partial downloads on failure.
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from __future__ import annotations
|
|
11
|
+
|
|
12
|
+
import json
|
|
13
|
+
import shutil
|
|
14
|
+
import subprocess
|
|
15
|
+
import time
|
|
16
|
+
from dataclasses import asdict, dataclass
|
|
17
|
+
from datetime import datetime, timezone
|
|
18
|
+
from pathlib import Path
|
|
19
|
+
from typing import Any
|
|
20
|
+
|
|
21
|
+
from rich.console import Console
|
|
22
|
+
|
|
23
|
+
from mlx_stack.core.catalog import CatalogEntry, QuantSource, get_entry_by_id, load_catalog
|
|
24
|
+
from mlx_stack.core.config import ConfigCorruptError, get_value
|
|
25
|
+
from mlx_stack.core.paths import ensure_data_home, get_data_home
|
|
26
|
+
|
|
27
|
+
# --------------------------------------------------------------------------- #
|
|
28
|
+
# HuggingFace CLI binary resolution
|
|
29
|
+
# --------------------------------------------------------------------------- #
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def _resolve_hf_cli() -> str:
    """Pick the HuggingFace CLI executable name to invoke.

    Candidates are probed on ``PATH`` with :func:`shutil.which`, trying the
    modern ``hf`` name first and the legacy ``huggingface-cli`` name second.

    Returns:
        The first candidate found on ``PATH``. When neither is installed,
        ``"hf"`` is returned anyway so that the caller's attempt to spawn
        it fails and can be reported.
    """
    for candidate in ("hf", "huggingface-cli"):
        if shutil.which(candidate):
            return candidate
    # Nothing on PATH: hand back the modern default and let the caller fail.
    return "hf"
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
# --------------------------------------------------------------------------- #
|
|
52
|
+
# Exceptions
|
|
53
|
+
# --------------------------------------------------------------------------- #
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
class PullError(Exception):
    """Base error for failures during model pull operations."""
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
class DiskSpaceError(PullError):
    """Signals that there is not enough free disk space for a pull."""
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
class DownloadError(PullError):
    """Signals that downloading a model failed."""
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
class ConversionError(PullError):
    """Signals that converting a model with mlx_lm failed."""
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
class InvalidModelError(PullError):
    """Signals that a requested model ID does not exist in the catalog."""
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
# --------------------------------------------------------------------------- #
|
|
77
|
+
# Data classes
|
|
78
|
+
# --------------------------------------------------------------------------- #
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
@dataclass(frozen=True)
class ModelInventoryEntry:
    """Immutable record for one model in the local inventory (models.json).

    Attributes:
        model_id: Catalog model ID.
        name: Model name.
        quant: Quantization level.
        source_type: Either "mlx-community" or "converted".
        hf_repo: HuggingFace repo the model came from.
        local_path: Location of the model directory, stored as a string.
        disk_size_gb: Size of the model in GB.
        downloaded_at: ISO 8601 timestamp of the download.
    """

    model_id: str
    name: str
    quant: str
    source_type: str
    hf_repo: str
    local_path: str
    disk_size_gb: float
    downloaded_at: str
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
@dataclass(frozen=True)
class PullResult:
    """Immutable summary of what a model pull operation produced.

    Attributes:
        model_id: Catalog model ID that was requested.
        name: Model name.
        quant: Quantization level that was pulled.
        source_type: Source type of the model.
        local_path: Directory holding the model on disk.
        already_existed: Whether the model was already present, so that
            nothing was downloaded.
        disk_size_gb: Size of the model in GB.
    """

    model_id: str
    name: str
    quant: str
    source_type: str
    local_path: Path
    already_existed: bool
    disk_size_gb: float
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
# --------------------------------------------------------------------------- #
|
|
109
|
+
# Models directory resolution
|
|
110
|
+
# --------------------------------------------------------------------------- #
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
def get_models_directory() -> Path:
    """Resolve the models directory from config.

    Uses the ``model-dir`` config value (with ``~`` expanded) when it is
    set to a non-empty value. Any failure to read the config, and an unset
    or blank value, falls back to ``<data home>/models``.

    Returns:
        Path to the models directory.
    """
    try:
        configured = get_value("model-dir")
        # Guard before stringifying: str(None) would yield a directory
        # literally named "None", and "" would resolve to the CWD.
        if configured is not None and str(configured).strip():
            return Path(str(configured)).expanduser()
    except (ConfigCorruptError, Exception):
        # Deliberate best effort: config trouble must not block a pull.
        pass
    return get_data_home() / "models"
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
# --------------------------------------------------------------------------- #
|
|
130
|
+
# Inventory management (models.json)
|
|
131
|
+
# --------------------------------------------------------------------------- #
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
def _get_inventory_path() -> Path:
    """Build the location of ``models.json`` inside the data home directory."""
    data_home = get_data_home()
    return data_home / "models.json"
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
def load_inventory() -> list[dict[str, Any]]:
    """Load the model inventory from models.json.

    A missing, unreadable, non-UTF-8, or malformed file is treated as an
    empty inventory rather than an error. The same applies when the JSON
    document is not a list.

    Returns:
        List of model inventory entries as dicts. Items in the file that
        are not JSON objects are dropped.
    """
    path = _get_inventory_path()
    try:
        data = json.loads(path.read_text(encoding="utf-8"))
    except (OSError, ValueError):
        # OSError: file missing or unreadable. ValueError covers both
        # json.JSONDecodeError and UnicodeDecodeError (non-UTF-8 bytes).
        return []

    if not isinstance(data, list):
        return []

    # Consumers call ``.get()`` on every entry, so a stray scalar in the
    # file must not reach them.
    return [item for item in data if isinstance(item, dict)]
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
def save_inventory(entries: list[dict[str, Any]]) -> None:
    """Save the model inventory to models.json.

    The content is written to a sibling temporary file and then renamed
    over the real file, so an interrupted write cannot leave a truncated
    models.json behind.

    Args:
        entries: List of model inventory entries as dicts.

    Raises:
        OSError: If the file cannot be written or renamed into place.
    """
    ensure_data_home()
    path = _get_inventory_path()
    content = json.dumps(entries, indent=2, sort_keys=False)

    tmp_path = path.with_name(path.name + ".tmp")
    try:
        tmp_path.write_text(content + "\n", encoding="utf-8")
        tmp_path.replace(path)
    except OSError:
        # Do not leave the temp file behind; the original error matters more
        # than a failed cleanup.
        try:
            tmp_path.unlink()
        except OSError:
            pass
        raise
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
def add_to_inventory(entry: ModelInventoryEntry) -> None:
    """Record a model in the inventory, superseding any earlier record.

    An existing record with the same ``model_id`` and ``quant`` is dropped
    before the new one is appended, and the result is saved.

    Args:
        entry: The inventory entry to add.
    """
    kept = [
        existing
        for existing in load_inventory()
        if existing.get("model_id") != entry.model_id
        or existing.get("quant") != entry.quant
    ]
    kept.append(asdict(entry))
    save_inventory(kept)
|
|
190
|
+
|
|
191
|
+
|
|
192
|
+
def find_in_inventory(model_id: str, quant: str) -> dict[str, Any] | None:
    """Look up a model + quant combination in the inventory.

    Args:
        model_id: The catalog model ID.
        quant: The quantization level.

    Returns:
        The first inventory entry dict whose ``model_id`` and ``quant``
        both match, or None when there is no such entry.
    """
    matches = (
        record
        for record in load_inventory()
        if record.get("model_id") == model_id and record.get("quant") == quant
    )
    return next(matches, None)
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
# --------------------------------------------------------------------------- #
|
|
210
|
+
# Source resolution
|
|
211
|
+
# --------------------------------------------------------------------------- #
|
|
212
|
+
|
|
213
|
+
|
|
214
|
+
def resolve_source(
    entry: CatalogEntry,
    quant: str,
) -> tuple[QuantSource, str]:
    """Look up the download source for a model + quant combination.

    The source registered under ``quant`` in ``entry.sources`` is returned
    together with a label derived from its ``convert_from`` flag.

    Args:
        entry: The catalog entry.
        quant: The quantization level.

    Returns:
        Tuple of (QuantSource, source_type). source_type is "converted"
        when the source's ``convert_from`` is truthy, "mlx-community"
        otherwise.

    Raises:
        PullError: If the quant is not available for the model.
    """
    sources = entry.sources
    if quant not in sources:
        choices = ", ".join(sorted(sources.keys()))
        raise PullError(
            f"Quantization '{quant}' is not available for {entry.name}. "
            f"Available: {choices}"
        )

    chosen = sources[quant]
    kind = "converted" if chosen.convert_from else "mlx-community"
    return chosen, kind
|
|
248
|
+
|
|
249
|
+
|
|
250
|
+
# --------------------------------------------------------------------------- #
|
|
251
|
+
# Disk space check
|
|
252
|
+
# --------------------------------------------------------------------------- #
|
|
253
|
+
|
|
254
|
+
|
|
255
|
+
def check_disk_space(
    models_dir: Path,
    required_gb: float,
) -> tuple[bool, float]:
    """Check if there is enough disk space for the download.

    The directory is created if missing so that its filesystem can be
    queried. A 20% safety margin is added on top of ``required_gb``.

    Args:
        models_dir: The directory where the model will be stored.
        required_gb: Required disk space in GB.

    Returns:
        Tuple of (has_space, available_gb), with available_gb rounded to
        one decimal. If the directory cannot be created or queried, the
        check is skipped and ``(True, 0.0)`` is returned.
    """
    try:
        # Inside the try: failing to create the directory is just another
        # way of being unable to check, and must not escape as a raw OSError.
        models_dir.mkdir(parents=True, exist_ok=True)
        usage = shutil.disk_usage(models_dir)
    except OSError:
        # If we can't check, allow the download
        return True, 0.0

    available_gb = usage.free / (1024**3)
    # Add 20% buffer for safety
    return available_gb >= required_gb * 1.2, round(available_gb, 1)
|
|
279
|
+
|
|
280
|
+
|
|
281
|
+
# --------------------------------------------------------------------------- #
|
|
282
|
+
# Model local path determination
|
|
283
|
+
# --------------------------------------------------------------------------- #
|
|
284
|
+
|
|
285
|
+
|
|
286
|
+
def get_model_local_path(models_dir: Path, hf_repo: str) -> Path:
    """Determine the local path for a model based on its HF repo name.

    Args:
        models_dir: The models directory.
        hf_repo: The HuggingFace repo name (e.g., "mlx-community/Qwen3.5-0.8B-4bit").

    Returns:
        ``models_dir`` joined with the last path component of ``hf_repo``.

    Raises:
        PullError: If no usable directory name can be derived (empty repo,
            or a final component of "." or "..").
    """
    # Drop trailing slashes first; otherwise the last component is "" and
    # the result would be models_dir itself.
    repo_name = hf_repo.rstrip("/").rsplit("/", 1)[-1]
    if repo_name in ("", ".", ".."):
        # Such a name would resolve to models_dir or its parent, which later
        # cleanup steps remove recursively.
        msg = f"Cannot derive a model directory name from HuggingFace repo '{hf_repo}'."
        raise PullError(msg)
    return models_dir / repo_name
|
|
299
|
+
|
|
300
|
+
|
|
301
|
+
def is_model_downloaded(model_path: Path) -> bool:
    """Report whether a model directory is present and non-empty.

    Args:
        model_path: Path to the model directory.

    Returns:
        True if ``model_path`` is an existing directory holding at least
        one entry. False if it is missing, is not a directory, is empty,
        or cannot be listed.
    """
    if not (model_path.exists() and model_path.is_dir()):
        return False
    try:
        first_entry = next(model_path.iterdir(), None)
    except OSError:
        return False
    return first_entry is not None
|
|
317
|
+
|
|
318
|
+
|
|
319
|
+
# --------------------------------------------------------------------------- #
|
|
320
|
+
# Download with retry
|
|
321
|
+
# --------------------------------------------------------------------------- #
|
|
322
|
+
|
|
323
|
+
|
|
324
|
+
def _filter_traceback(output: str) -> str:
    """Reduce output containing a Python traceback to its readable parts.

    Text before a traceback is kept. The traceback header and the indented
    or blank lines that follow it are dropped. The first non-indented,
    non-blank line after them (the exception line) is kept. Kept lines are
    stripped of surrounding whitespace and blank lines are omitted.

    Args:
        output: Raw output that may contain traceback lines.

    Returns:
        The filtered message. Output without a traceback is returned
        stripped; if filtering would leave nothing, the stripped output is
        returned instead.
    """
    marker = "Traceback (most recent call last)"
    text = output.strip()
    rows = text.splitlines()
    if not rows:
        return output

    if not any(row.strip().startswith(marker) for row in rows):
        return text

    kept: list[str] = []
    skipping = False
    for row in rows:
        body = row.strip()
        if body.startswith(marker):
            # Header line: start discarding the frames that follow.
            skipping = True
        elif skipping:
            # Frames and code context are indented; the exception line is not.
            if body and not row.startswith((" ", "\t")):
                kept.append(body)
                skipping = False
        elif body:
            kept.append(body)

    return "\n".join(kept) if kept else text
|
|
375
|
+
|
|
376
|
+
|
|
377
|
+
def _run_download(
    hf_repo: str,
    local_dir: Path,
    console: Console,
) -> None:
    """Run the HuggingFace CLI download command with real-time output.

    Resolves the CLI binary via :func:`_resolve_hf_cli` and spawns it with
    stderr merged into stdout, so everything the CLI prints is streamed to
    the console line by line. Traceback blocks are hidden from the console
    but still captured for the error message. A watchdog timer kills the
    process if it runs longer than one hour, even when the process is
    silent and the read loop is blocked.

    Args:
        hf_repo: The HuggingFace repo to download.
        local_dir: The local directory to download to.
        console: Rich console for output.

    Raises:
        DownloadError: If the CLI cannot be started, the download times
            out, or the process exits with a non-zero status.
    """
    import threading  # local import: only needed for the timeout watchdog

    timeout_seconds = 3600
    traceback_header = "Traceback (most recent call last)"

    hf_binary = _resolve_hf_cli()
    cmd = [hf_binary, "download", hf_repo, "--local-dir", str(local_dir)]

    try:
        proc = subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
        )
    except FileNotFoundError:
        msg = (
            "HuggingFace CLI not found (tried 'hf' and 'huggingface-cli').\n"
            "Install huggingface_hub:\n"
            " pip install 'huggingface_hub[cli]'\n"
            "Or: uv pip install 'huggingface_hub[cli]'"
        )
        raise DownloadError(msg) from None
    except OSError as exc:
        msg = f"Failed to start download: {exc}"
        raise DownloadError(msg) from None

    assert proc.stdout is not None  # guaranteed by stdout=PIPE

    timed_out = threading.Event()

    def _kill_on_timeout() -> None:
        # Only flag a timeout if the process is really still running, so a
        # download finishing right at the deadline is not misreported.
        if proc.poll() is None:
            timed_out.set()
            proc.kill()

    # Iterating the pipe blocks until EOF, so a timeout on the later
    # wait() would never fire; the timer enforces the limit independently.
    watchdog = threading.Timer(timeout_seconds, _kill_on_timeout)
    watchdog.daemon = True
    watchdog.start()

    captured_lines: list[str] = []
    in_traceback = False
    try:
        with proc.stdout as stream:
            for line in stream:
                stripped = line.rstrip("\n")
                if not stripped:
                    continue

                captured_lines.append(stripped)

                # Detect start of a traceback block
                if stripped.strip().startswith(traceback_header):
                    in_traceback = True
                    continue

                if in_traceback:
                    # Inside traceback: suppress indented frame/code lines
                    if stripped.startswith((" ", "\t")):
                        continue
                    # The first non-indented line is the exception message;
                    # it is suppressed too and reported via the raised error.
                    in_traceback = False
                    continue

                # Normal line — show to user
                console.print(f" {stripped}")

        proc.wait()
    finally:
        watchdog.cancel()
        # Reached with a live child only if the loop was interrupted
        # (e.g. KeyboardInterrupt); do not leave it running.
        if proc.poll() is None:
            proc.kill()
            proc.wait()

    if timed_out.is_set():
        msg = "Download timed out after 1 hour."
        raise DownloadError(msg)

    if proc.returncode != 0:
        raw_output = "\n".join(captured_lines)
        clean_error = _filter_traceback(raw_output)
        msg = f"Download failed for {hf_repo}:\n{clean_error}"
        raise DownloadError(msg)
|
|
473
|
+
|
|
474
|
+
|
|
475
|
+
def download_model(
    hf_repo: str,
    local_dir: Path,
    console: Console,
    max_retries: int = 2,
) -> None:
    """Download a model from HuggingFace with automatic retry.

    Creates ``local_dir`` and runs the download up to ``max_retries``
    times, pausing two seconds between attempts. If every attempt fails,
    ``local_dir`` is removed and the last error is re-raised with
    troubleshooting hints appended.

    Args:
        hf_repo: The HuggingFace repo to download.
        local_dir: The local directory to download to.
        console: Rich console for output.
        max_retries: Maximum number of attempts (default 2 = 1 retry).
            Values below 1 are treated as 1, so at least one attempt is
            always made.

    Raises:
        DownloadError: If all download attempts fail.
    """
    # Without this clamp a value < 1 would skip the loop entirely and go
    # straight to deleting local_dir without ever trying to download.
    attempts = max(1, max_retries)

    local_dir.mkdir(parents=True, exist_ok=True)

    last_error: DownloadError | None = None
    for attempt in range(1, attempts + 1):
        suffix = f" (attempt {attempt}/{attempts})" if attempt > 1 else ""
        console.print(f"[cyan]Downloading {hf_repo}...[/cyan]" + suffix)
        try:
            _run_download(hf_repo, local_dir, console)
        except DownloadError as exc:
            last_error = exc
            if attempt < attempts:
                console.print(
                    f"[yellow]Download attempt {attempt} failed. "
                    f"Retrying...[/yellow]"
                )
                time.sleep(2)  # Brief pause before retry
        else:
            console.print("[green]✓ Download complete.[/green]")
            return

    # All attempts failed — clean up partial download
    _cleanup_partial(local_dir)

    msg = (
        f"{last_error}\n\n"
        "Check your network connection and HuggingFace authentication.\n"
        "Set HF_TOKEN environment variable if the model requires authentication."
    )
    raise DownloadError(msg) from last_error
|
|
525
|
+
|
|
526
|
+
|
|
527
|
+
# --------------------------------------------------------------------------- #
|
|
528
|
+
# MLX conversion
|
|
529
|
+
# --------------------------------------------------------------------------- #
|
|
530
|
+
|
|
531
|
+
|
|
532
|
+
def convert_model(
    hf_repo: str,
    local_dir: Path,
    quant: str,
    console: Console,
) -> None:
    """Convert a model using mlx_lm.

    Runs ``python3 -m mlx_lm.convert`` on ``hf_repo`` and writes the
    result to ``local_dir``. For "int4" and "int8" the model is quantized
    to 4 or 8 bits; any other quant value runs the conversion without
    quantization flags. On any failure ``local_dir`` is removed.

    Args:
        hf_repo: The HuggingFace repo of the base model (e.g., "Qwen/Qwen3.5-8B").
        local_dir: The directory to write converted model to.
        quant: The quantization level (int4, int8).
        console: Rich console for output.

    Raises:
        ConversionError: If the converter cannot be started, times out
            after two hours, or exits with a non-zero status.
    """
    # Map our quant names to mlx_lm bit widths; None means "no quantization".
    quant_bits = {"int4": "4", "int8": "8"}.get(quant)

    console.print(f"[cyan]Converting {hf_repo} to {quant}...[/cyan]")
    console.print("[dim]This may take several minutes.[/dim]")

    # mlx_lm.convert refuses to write to an --mlx-path that already exists,
    # so only the parent is created here. An empty leftover directory is
    # removed; a non-empty one is left for mlx_lm to report.
    local_dir.parent.mkdir(parents=True, exist_ok=True)
    if local_dir.is_dir():
        try:
            local_dir.rmdir()  # succeeds only when the directory is empty
        except OSError:
            pass

    cmd = [
        "python3",
        "-m",
        "mlx_lm.convert",
        "--hf-path",
        hf_repo,
        "--mlx-path",
        str(local_dir),
    ]
    if quant_bits:
        cmd += ["-q", "--q-bits", quant_bits]

    try:
        result = subprocess.run(
            cmd,
            capture_output=True,
            text=True,
            timeout=7200,  # 2 hour timeout for large conversions
        )
    except FileNotFoundError:
        _cleanup_partial(local_dir)
        msg = (
            "mlx_lm not found. Install it with:\n"
            " pip install mlx_lm\n"
            "Or: uv pip install mlx_lm"
        )
        raise ConversionError(msg) from None
    except subprocess.TimeoutExpired:
        _cleanup_partial(local_dir)
        msg = "Conversion timed out after 2 hours."
        raise ConversionError(msg) from None
    except OSError as exc:
        _cleanup_partial(local_dir)
        msg = f"Failed to start conversion: {exc}"
        raise ConversionError(msg) from None

    if result.returncode != 0:
        # Prefer stderr, but do not report an empty message when the
        # converter only wrote to stdout.
        details = result.stderr.strip() or result.stdout.strip()
        _cleanup_partial(local_dir)
        msg = f"Conversion failed for {hf_repo}:\n{details}"
        raise ConversionError(msg)

    console.print("[green]✓ Conversion complete.[/green]")
|
|
620
|
+
|
|
621
|
+
|
|
622
|
+
# --------------------------------------------------------------------------- #
|
|
623
|
+
# Cleanup
|
|
624
|
+
# --------------------------------------------------------------------------- #
|
|
625
|
+
|
|
626
|
+
|
|
627
|
+
def _cleanup_partial(local_dir: Path) -> None:
    """Delete the directory left behind by a partial or failed download.

    Does nothing when the path does not exist. Removal is best effort:
    an ``OSError`` raised while deleting is ignored.

    Args:
        local_dir: The directory to remove.
    """
    if not local_dir.exists():
        return
    try:
        shutil.rmtree(local_dir)
    except OSError:
        # Cleanup must never mask the failure that triggered it.
        return
|
|
638
|
+
|
|
639
|
+
|
|
640
|
+
# --------------------------------------------------------------------------- #
|
|
641
|
+
# Quant validation
|
|
642
|
+
# --------------------------------------------------------------------------- #
|
|
643
|
+
|
|
644
|
+
VALID_QUANTS = {"int4", "int8", "bf16"}


def validate_quant(quant: str) -> str:
    """Check a quantization value against ``VALID_QUANTS``.

    Args:
        quant: The quantization string to validate.

    Returns:
        ``quant`` unchanged when it is one of the accepted values.

    Raises:
        PullError: If the quant is not valid; the message lists the
            accepted values in sorted order.
    """
    if quant in VALID_QUANTS:
        return quant
    valid = ", ".join(sorted(VALID_QUANTS))
    raise PullError(f"Invalid quantization '{quant}'. Valid values: {valid}")
|
|
664
|
+
|
|
665
|
+
|
|
666
|
+
# --------------------------------------------------------------------------- #
|
|
667
|
+
# Main pull orchestrator
|
|
668
|
+
# --------------------------------------------------------------------------- #
|
|
669
|
+
|
|
670
|
+
|
|
671
|
+
def pull_model(
    model_id: str,
    quant: str | None = None,
    force: bool = False,
    console: Console | None = None,
    catalog: list[CatalogEntry] | None = None,
) -> PullResult:
    """Pull (download) a model from the catalog.

    Orchestrates the full pull workflow:
    1. Resolve model from catalog
    2. Determine quant (from flag or config default)
    3. Resolve source (mlx-community or convert_from)
    4. Resolve the models directory and local path
    5. Check for existing download (duplicate detection)
    6. Check disk space
    7. Display a summary
    8. Download or convert
    9. Update inventory

    Args:
        model_id: The catalog model ID (e.g., "qwen3.5-8b").
        quant: Quantization override (None uses config default, falling
            back to "int4" if the config cannot be read).
        force: If True, re-download even if model exists.
        console: Rich console for output (creates one if None).
        catalog: Pre-loaded catalog (loads from package if None).

    Returns:
        PullResult with details of the completed pull. When the model
        directory already exists with content and ``force`` is False,
        nothing is downloaded and ``already_existed`` is True.

    Raises:
        InvalidModelError: If the model ID is not in the catalog.
        PullError: If the quant is invalid or unavailable.
        DiskSpaceError: If insufficient disk space.
        DownloadError: If download fails after retries.
        ConversionError: If mlx_lm conversion fails.
    """
    if console is None:
        console = Console()

    # 1. Load catalog and resolve model
    if catalog is None:
        catalog = load_catalog()

    entry = get_entry_by_id(catalog, model_id)
    if entry is None:
        msg = (
            f"Model '{model_id}' not found in catalog.\n"
            "Run 'mlx-stack models --catalog' to see available models."
        )
        raise InvalidModelError(msg)

    # 2. Determine quantization
    if quant is None:
        try:
            quant = str(get_value("default-quant"))
        except Exception:
            # Best effort: an unreadable config falls back to the default.
            quant = "int4"

    quant = validate_quant(quant)

    # 3. Resolve source
    source, source_type = resolve_source(entry, quant)

    # 4. Get models directory and local path
    models_dir = get_models_directory()
    local_path = get_model_local_path(models_dir, source.hf_repo)

    # 5. Check for existing download (duplicate detection).
    # Content on disk alone decides this; the inventory is not consulted,
    # because its answer could never change the outcome here.
    if not force and is_model_downloaded(local_path):
        console.print(
            f"[yellow]Model '{entry.name}' ({quant}) already exists at "
            f"{local_path}.[/yellow]\n"
            "Use --force to re-download."
        )
        return PullResult(
            model_id=model_id,
            name=entry.name,
            quant=quant,
            source_type=source_type,
            local_path=local_path,
            already_existed=True,
            disk_size_gb=source.disk_size_gb,
        )

    # 6. Check disk space
    has_space, available_gb = check_disk_space(models_dir, source.disk_size_gb)
    if not has_space:
        msg = (
            f"Insufficient disk space for {entry.name} ({quant}).\n"
            f"Required: {source.disk_size_gb:.1f} GB (+ 20% buffer)\n"
            f"Available: {available_gb:.1f} GB"
        )
        raise DiskSpaceError(msg)

    # 7. Display info
    console.print()
    console.print(f"[bold cyan]Pulling {entry.name}[/bold cyan]")
    console.print(f" Quantization: {quant}")
    console.print(f" Source: {source.hf_repo}")
    console.print(f" Type: {source_type}")
    console.print(f" Estimated size: {source.disk_size_gb:.1f} GB")
    console.print(f" Destination: {local_path}")
    console.print()

    # 8. Download or convert
    if force and local_path.exists():
        console.print("[yellow]Removing existing download (--force)...[/yellow]")
        _cleanup_partial(local_path)

    if source_type == "mlx-community":
        download_model(source.hf_repo, local_path, console)
    else:
        # convert_from — need mlx_lm conversion
        convert_model(source.hf_repo, local_path, quant, console)

    # 9. Update inventory
    inv = ModelInventoryEntry(
        model_id=model_id,
        name=entry.name,
        quant=quant,
        source_type=source_type,
        hf_repo=source.hf_repo,
        local_path=str(local_path),
        disk_size_gb=source.disk_size_gb,
        downloaded_at=datetime.now(timezone.utc).isoformat(),
    )
    add_to_inventory(inv)

    console.print()
    console.print(
        f"[bold green]✓ {entry.name} ({quant}) is ready.[/bold green]"
    )
    console.print(f" Location: {local_path}")

    return PullResult(
        model_id=model_id,
        name=entry.name,
        quant=quant,
        source_type=source_type,
        local_path=local_path,
        already_existed=False,
        disk_size_gb=source.disk_size_gb,
    )
|