mlx-stack 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61) hide show
  1. mlx_stack/__init__.py +5 -0
  2. mlx_stack/_version.py +24 -0
  3. mlx_stack/cli/__init__.py +5 -0
  4. mlx_stack/cli/bench.py +221 -0
  5. mlx_stack/cli/config.py +166 -0
  6. mlx_stack/cli/down.py +109 -0
  7. mlx_stack/cli/init.py +180 -0
  8. mlx_stack/cli/install.py +165 -0
  9. mlx_stack/cli/logs.py +234 -0
  10. mlx_stack/cli/main.py +187 -0
  11. mlx_stack/cli/models.py +304 -0
  12. mlx_stack/cli/profile.py +65 -0
  13. mlx_stack/cli/pull.py +134 -0
  14. mlx_stack/cli/recommend.py +397 -0
  15. mlx_stack/cli/status.py +111 -0
  16. mlx_stack/cli/up.py +163 -0
  17. mlx_stack/cli/watch.py +252 -0
  18. mlx_stack/core/__init__.py +1 -0
  19. mlx_stack/core/benchmark.py +1182 -0
  20. mlx_stack/core/catalog.py +560 -0
  21. mlx_stack/core/config.py +471 -0
  22. mlx_stack/core/deps.py +323 -0
  23. mlx_stack/core/hardware.py +304 -0
  24. mlx_stack/core/launchd.py +531 -0
  25. mlx_stack/core/litellm_gen.py +188 -0
  26. mlx_stack/core/log_rotation.py +231 -0
  27. mlx_stack/core/log_viewer.py +386 -0
  28. mlx_stack/core/models.py +639 -0
  29. mlx_stack/core/paths.py +79 -0
  30. mlx_stack/core/process.py +887 -0
  31. mlx_stack/core/pull.py +815 -0
  32. mlx_stack/core/scoring.py +611 -0
  33. mlx_stack/core/stack_down.py +317 -0
  34. mlx_stack/core/stack_init.py +524 -0
  35. mlx_stack/core/stack_status.py +229 -0
  36. mlx_stack/core/stack_up.py +856 -0
  37. mlx_stack/core/watchdog.py +744 -0
  38. mlx_stack/data/__init__.py +1 -0
  39. mlx_stack/data/catalog/__init__.py +1 -0
  40. mlx_stack/data/catalog/deepseek-r1-32b.yaml +46 -0
  41. mlx_stack/data/catalog/deepseek-r1-8b.yaml +45 -0
  42. mlx_stack/data/catalog/gemma3-12b.yaml +45 -0
  43. mlx_stack/data/catalog/gemma3-27b.yaml +45 -0
  44. mlx_stack/data/catalog/gemma3-4b.yaml +45 -0
  45. mlx_stack/data/catalog/llama3.3-8b.yaml +44 -0
  46. mlx_stack/data/catalog/nemotron-49b.yaml +41 -0
  47. mlx_stack/data/catalog/nemotron-8b.yaml +44 -0
  48. mlx_stack/data/catalog/qwen3-8b.yaml +45 -0
  49. mlx_stack/data/catalog/qwen3.5-0.8b.yaml +45 -0
  50. mlx_stack/data/catalog/qwen3.5-14b.yaml +46 -0
  51. mlx_stack/data/catalog/qwen3.5-32b.yaml +45 -0
  52. mlx_stack/data/catalog/qwen3.5-3b.yaml +44 -0
  53. mlx_stack/data/catalog/qwen3.5-72b.yaml +42 -0
  54. mlx_stack/data/catalog/qwen3.5-8b.yaml +45 -0
  55. mlx_stack/py.typed +1 -0
  56. mlx_stack/utils/__init__.py +1 -0
  57. mlx_stack-0.1.0.dist-info/METADATA +397 -0
  58. mlx_stack-0.1.0.dist-info/RECORD +61 -0
  59. mlx_stack-0.1.0.dist-info/WHEEL +4 -0
  60. mlx_stack-0.1.0.dist-info/entry_points.txt +2 -0
  61. mlx_stack-0.1.0.dist-info/licenses/LICENSE +21 -0
mlx_stack/core/pull.py ADDED
@@ -0,0 +1,815 @@
1
+ """Model download and inventory management for mlx-stack.
2
+
3
+ Resolves catalog ID + quant to HuggingFace source, prefers mlx-community
4
+ pre-converted weights with fallback to mlx_lm conversion. Checks disk
5
+ space before download, shows progress, tracks inventory in models.json,
6
+ detects duplicates, handles network errors with automatic retry, and
7
+ cleans up partial downloads on failure.
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import json
13
+ import shutil
14
+ import subprocess
15
+ import time
16
+ from dataclasses import asdict, dataclass
17
+ from datetime import datetime, timezone
18
+ from pathlib import Path
19
+ from typing import Any
20
+
21
+ from rich.console import Console
22
+
23
+ from mlx_stack.core.catalog import CatalogEntry, QuantSource, get_entry_by_id, load_catalog
24
+ from mlx_stack.core.config import ConfigCorruptError, get_value
25
+ from mlx_stack.core.paths import ensure_data_home, get_data_home
26
+
27
+ # --------------------------------------------------------------------------- #
28
+ # HuggingFace CLI binary resolution
29
+ # --------------------------------------------------------------------------- #
30
+
31
+
32
+ def _resolve_hf_cli() -> str:
33
+ """Resolve the HuggingFace CLI binary name.
34
+
35
+ Modern huggingface_hub versions install the CLI as ``hf`` rather than
36
+ ``huggingface-cli``. We try ``hf`` first (via :func:`shutil.which`)
37
+ and fall back to ``huggingface-cli`` for older installations.
38
+
39
+ Returns:
40
+ The binary name that is available on ``PATH``, preferring ``hf``.
41
+ """
42
+ if shutil.which("hf"):
43
+ return "hf"
44
+ if shutil.which("huggingface-cli"):
45
+ return "huggingface-cli"
46
+ # Neither found — return "hf" (the modern default) so the caller
47
+ # raises a helpful FileNotFoundError.
48
+ return "hf"
49
+
50
+
51
+ # --------------------------------------------------------------------------- #
52
+ # Exceptions
53
+ # --------------------------------------------------------------------------- #
54
+
55
+
56
class PullError(Exception):
    """Base class for every failure raised by the model pull workflow."""


class DiskSpaceError(PullError):
    """The target volume does not have enough free space for the model."""


class DownloadError(PullError):
    """Fetching model files from HuggingFace did not succeed."""


class ConversionError(PullError):
    """Converting a model with mlx_lm did not succeed."""


class InvalidModelError(PullError):
    """The requested model ID has no entry in the catalog."""
73
+ """Raised when the model ID is not found in the catalog."""
74
+
75
+
76
+ # --------------------------------------------------------------------------- #
77
+ # Data classes
78
+ # --------------------------------------------------------------------------- #
79
+
80
+
81
@dataclass(frozen=True)
class ModelInventoryEntry:
    """Immutable record describing one model tracked in models.json.

    Instances are serialised with :func:`dataclasses.asdict`, so the field
    names double as the JSON keys of the inventory file.
    """

    # Catalog identifier of the model.
    model_id: str
    # Human-readable model name taken from the catalog entry.
    name: str
    # Quantization level the model was pulled at.
    quant: str
    # Either "mlx-community" or "converted".
    source_type: str
    # HuggingFace repository the model was obtained from.
    hf_repo: str
    # Directory on disk holding the model, stored as a string.
    local_path: str
    # Model size in GB.
    disk_size_gb: float
    # ISO 8601 timestamp of when the pull completed.
    downloaded_at: str
93
+
94
+
95
@dataclass(frozen=True)
class PullResult:
    """Immutable summary handed back to the caller of a pull operation."""

    # Catalog identifier of the model that was requested.
    model_id: str
    # Human-readable model name taken from the catalog entry.
    name: str
    # Quantization level that was resolved for this pull.
    quant: str
    # Either "mlx-community" or "converted".
    source_type: str
    # Directory on disk where the model lives.
    local_path: Path
    # True when the model was already on disk and nothing was fetched.
    already_existed: bool
    # Model size in GB.
    disk_size_gb: float
106
+
107
+
108
+ # --------------------------------------------------------------------------- #
109
+ # Models directory resolution
110
+ # --------------------------------------------------------------------------- #
111
+
112
+
113
def get_models_directory() -> Path:
    """Return the directory in which downloaded models are stored.

    The ``model-dir`` config value is used (with ``~`` expanded) when it can
    be read. If reading it fails for any reason, the ``models`` directory
    under the data home is used instead.

    Returns:
        Path to the models directory.
    """
    try:
        configured = get_value("model-dir")
        models_dir = Path(str(configured)).expanduser()
    except (ConfigCorruptError, Exception):
        # Any config problem falls back to the default location.
        models_dir = get_data_home() / "models"
    return models_dir
127
+
128
+
129
+ # --------------------------------------------------------------------------- #
130
+ # Inventory management (models.json)
131
+ # --------------------------------------------------------------------------- #
132
+
133
+
134
def _get_inventory_path() -> Path:
    """Locate the inventory file, ``models.json``, inside the data home."""
    data_home = get_data_home()
    return data_home / "models.json"
137
+
138
+
139
def load_inventory() -> list[dict[str, Any]]:
    """Load the model inventory from models.json.

    The inventory is treated as best-effort state: a missing, unreadable,
    wrongly encoded, or malformed file yields an empty inventory rather than
    an error.

    Returns:
        List of model inventory entries as dicts. Items in the file that are
        not JSON objects are dropped so callers can rely on ``.get()``.
    """
    path = _get_inventory_path()

    try:
        data = json.loads(path.read_text(encoding="utf-8"))
    except (OSError, ValueError):
        # OSError covers a missing/unreadable file; ValueError covers both
        # json.JSONDecodeError and UnicodeDecodeError (invalid UTF-8 bytes).
        return []

    if not isinstance(data, list):
        return []

    return [item for item in data if isinstance(item, dict)]
158
+
159
+
160
def save_inventory(entries: list[dict[str, Any]]) -> None:
    """Save the model inventory to models.json.

    The content is written to a sibling temporary file first and then moved
    over the real file, so an interrupted write cannot leave a truncated
    models.json behind (which would otherwise be read back as an empty
    inventory).

    Args:
        entries: List of model inventory entries as dicts.

    Raises:
        OSError: If the inventory file cannot be written.
    """
    ensure_data_home()
    path = _get_inventory_path()
    content = json.dumps(entries, indent=2, sort_keys=False)

    # Same directory as the target so the final rename stays on one filesystem.
    tmp_path = path.with_name(path.name + ".tmp")
    try:
        tmp_path.write_text(content + "\n", encoding="utf-8")
        tmp_path.replace(path)
    except OSError:
        # Do not leave a stray temp file next to the inventory.
        try:
            tmp_path.unlink()
        except OSError:
            pass
        raise
170
+
171
+
172
def add_to_inventory(entry: ModelInventoryEntry) -> None:
    """Record a model in the inventory, superseding any previous record.

    Existing records whose ``model_id`` and ``quant`` both match the new
    entry are discarded; the new entry is then appended and the inventory is
    written back.

    Args:
        entry: The inventory entry to add.
    """
    new_key = (entry.model_id, entry.quant)

    kept: list[dict[str, Any]] = []
    for existing in load_inventory():
        if (existing.get("model_id"), existing.get("quant")) == new_key:
            continue  # superseded by the entry being added
        kept.append(existing)

    kept.append(asdict(entry))
    save_inventory(kept)
190
+
191
+
192
def find_in_inventory(model_id: str, quant: str) -> dict[str, Any] | None:
    """Look up the inventory record for a model/quant pair.

    Args:
        model_id: The catalog model ID.
        quant: The quantization level.

    Returns:
        The first inventory dict whose ``model_id`` and ``quant`` both match,
        or None when there is no such record.
    """
    matches = (
        record
        for record in load_inventory()
        if record.get("model_id") == model_id and record.get("quant") == quant
    )
    return next(matches, None)
207
+
208
+
209
+ # --------------------------------------------------------------------------- #
210
+ # Source resolution
211
+ # --------------------------------------------------------------------------- #
212
+
213
+
214
def resolve_source(
    entry: CatalogEntry,
    quant: str,
) -> tuple[QuantSource, str]:
    """Find where a model at a given quantization should be obtained from.

    The catalog entry's ``sources`` mapping is consulted for ``quant``. A
    source flagged with ``convert_from`` must be converted locally; any other
    source is treated as pre-converted mlx-community weights.

    Args:
        entry: The catalog entry.
        quant: The quantization level.

    Returns:
        Tuple of (QuantSource, source_type) where source_type is
        "converted" or "mlx-community".

    Raises:
        PullError: If the quant is not available for the model.
    """
    sources = entry.sources

    if quant in sources:
        source = sources[quant]
        source_type = "converted" if source.convert_from else "mlx-community"
        return source, source_type

    available = ", ".join(sorted(sources.keys()))
    raise PullError(
        f"Quantization '{quant}' is not available for {entry.name}. "
        f"Available: {available}"
    )
248
+
249
+
250
+ # --------------------------------------------------------------------------- #
251
+ # Disk space check
252
+ # --------------------------------------------------------------------------- #
253
+
254
+
255
def check_disk_space(
    models_dir: Path,
    required_gb: float,
) -> tuple[bool, float]:
    """Report whether the models volume can hold a download of a given size.

    The directory is created first so that its free space can be queried.
    The requirement is padded by 20% as a safety margin.

    Args:
        models_dir: The directory where the model will be stored.
        required_gb: Required disk space in GB.

    Returns:
        Tuple of (has_space, available_gb), with available_gb rounded to one
        decimal. If free space cannot be determined the download is allowed
        and ``(True, 0.0)`` is returned.
    """
    # The path must exist before its free space can be queried.
    models_dir.mkdir(parents=True, exist_ok=True)

    try:
        free_bytes = shutil.disk_usage(models_dir).free
    except OSError:
        # Unable to measure: do not block the download.
        return True, 0.0

    available_gb = free_bytes / (1024**3)
    needed_gb = required_gb * 1.2  # 20% safety buffer
    return available_gb >= needed_gb, round(available_gb, 1)
279
+
280
+
281
+ # --------------------------------------------------------------------------- #
282
+ # Model local path determination
283
+ # --------------------------------------------------------------------------- #
284
+
285
+
286
def get_model_local_path(models_dir: Path, hf_repo: str) -> Path:
    """Determine the local path for a model based on its HF repo name.

    The last path component of the repo ID is used as the directory name.
    Trailing slashes are ignored so that ``"org/name/"`` maps to ``name``
    instead of collapsing to ``models_dir`` itself, which callers would
    otherwise treat (and possibly delete) as the model directory.

    Args:
        models_dir: The models directory.
        hf_repo: The HuggingFace repo name (e.g., "mlx-community/Qwen3.5-0.8B-4bit").

    Returns:
        The local path for the model directory, always a child of models_dir.

    Raises:
        PullError: If no usable directory name can be derived from hf_repo.
    """
    # rsplit already returns the whole string when there is no "/".
    repo_name = hf_repo.rstrip("/").rsplit("/", 1)[-1]

    # Empty, "." and ".." would resolve to models_dir or its parent.
    if repo_name in ("", ".", ".."):
        msg = f"Cannot derive a local directory name from HuggingFace repo '{hf_repo}'."
        raise PullError(msg)

    return models_dir / repo_name
299
+
300
+
301
def is_model_downloaded(model_path: Path) -> bool:
    """Tell whether a model directory is present and non-empty.

    Only the presence of at least one directory entry is checked; the
    contents are not inspected.

    Args:
        model_path: Path to the model directory.

    Returns:
        True if the path is an existing directory with at least one entry;
        False otherwise, including when the directory cannot be listed.
    """
    if not (model_path.exists() and model_path.is_dir()):
        return False

    try:
        first_entry = next(model_path.iterdir(), None)
    except OSError:
        return False
    return first_entry is not None
317
+
318
+
319
+ # --------------------------------------------------------------------------- #
320
+ # Download with retry
321
+ # --------------------------------------------------------------------------- #
322
+
323
+
324
+ def _filter_traceback(output: str) -> str:
325
+ """Filter Python traceback lines from output, returning clean error message.
326
+
327
+ Extracts the meaningful error message from output that may contain
328
+ a full Python traceback. Removes traceback header, frame lines, and
329
+ code context lines, keeping only pre-traceback content and the final
330
+ exception line.
331
+
332
+ Args:
333
+ output: Raw output that may contain traceback lines.
334
+
335
+ Returns:
336
+ The filtered, human-readable error message.
337
+ """
338
+ lines = output.strip().splitlines()
339
+ if not lines:
340
+ return output
341
+
342
+ # Check if the output contains a traceback
343
+ has_traceback = any(
344
+ line.strip().startswith("Traceback (most recent call last)")
345
+ for line in lines
346
+ )
347
+
348
+ if not has_traceback:
349
+ return output.strip()
350
+
351
+ # Walk through lines:
352
+ # - Keep lines before the traceback
353
+ # - Skip the traceback header and all indented frame/code lines
354
+ # - Keep the final exception line (first non-indented line after frames)
355
+ meaningful_lines: list[str] = []
356
+ in_traceback = False
357
+ for line in lines:
358
+ stripped = line.strip()
359
+ if stripped.startswith("Traceback (most recent call last)"):
360
+ in_traceback = True
361
+ continue
362
+ if in_traceback:
363
+ # Inside traceback: skip lines that start with whitespace
364
+ # (frame references like ' File "..."' and code context lines)
365
+ if line.startswith((" ", "\t")) or stripped == "":
366
+ continue
367
+ # First non-indented, non-empty line is the exception message
368
+ meaningful_lines.append(stripped)
369
+ in_traceback = False
370
+ continue
371
+ if stripped:
372
+ meaningful_lines.append(stripped)
373
+
374
+ return "\n".join(meaningful_lines) if meaningful_lines else output.strip()
375
+
376
+
377
def _run_download(
    hf_repo: str,
    local_dir: Path,
    console: Console,
) -> None:
    """Run the HuggingFace CLI download command with real-time output.

    Resolves the CLI binary via :func:`_resolve_hf_cli`, spawns it with
    stderr merged into stdout, and echoes each non-blank output line to the
    console as it arrives. Traceback blocks are hidden from the console but
    still captured so the error message can be built on failure.

    The child process is always reaped and its pipe closed before this
    function returns or raises; if streaming is interrupted while the child
    is still running, the child is killed.

    Args:
        hf_repo: The HuggingFace repo to download.
        local_dir: The local directory to download to.
        console: Rich console for output.

    Raises:
        DownloadError: If the CLI cannot be started, times out, or exits
            with a non-zero status.
    """
    cmd = [
        _resolve_hf_cli(),
        "download",
        hf_repo,
        "--local-dir",
        str(local_dir),
    ]

    try:
        proc = subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
        )
    except FileNotFoundError:
        msg = (
            "HuggingFace CLI not found (tried 'hf' and 'huggingface-cli').\n"
            "Install huggingface_hub:\n"
            "  pip install 'huggingface_hub[cli]'\n"
            "Or: uv pip install 'huggingface_hub[cli]'"
        )
        raise DownloadError(msg) from None
    except OSError as exc:
        msg = f"Failed to start download: {exc}"
        raise DownloadError(msg) from None

    assert proc.stdout is not None  # guaranteed by stdout=PIPE
    captured_lines: list[str] = []
    in_traceback = False
    try:
        for line in proc.stdout:
            stripped = line.rstrip("\n")
            if not stripped:
                continue

            # Everything is captured, even lines hidden from the console.
            captured_lines.append(stripped)

            # Detect start of a traceback block
            if stripped.strip().startswith("Traceback (most recent call last)"):
                in_traceback = True
                continue

            if in_traceback:
                # Inside traceback: suppress indented frame/code lines
                if stripped.startswith((" ", "\t")):
                    continue
                # First non-indented line after frames is the exception
                # message; suppress it too (it's the error summary)
                in_traceback = False
                continue

            # Normal line: show to user
            console.print(f"  {stripped}")

        # NOTE: the loop above only ends once the child closes its output,
        # so this timeout bounds the wait *after* streaming has finished,
        # not the overall download duration.
        proc.wait(timeout=3600)
    except subprocess.TimeoutExpired:
        msg = "Download timed out after 1 hour."
        raise DownloadError(msg) from None
    finally:
        # Runs on success, timeout, and unexpected errors alike: never leave
        # a running child or an open pipe behind.
        if proc.poll() is None:
            proc.kill()
            proc.wait()
        proc.stdout.close()

    if proc.returncode != 0:
        raw_output = "\n".join(captured_lines)
        clean_error = _filter_traceback(raw_output)
        msg = f"Download failed for {hf_repo}:\n{clean_error}"
        raise DownloadError(msg)
473
+
474
+
475
def download_model(
    hf_repo: str,
    local_dir: Path,
    console: Console,
    max_retries: int = 2,
) -> None:
    """Download a model from HuggingFace with automatic retry.

    Each failed attempt is followed by a short pause and another attempt
    until the attempt budget is used up. When every attempt has failed the
    partially downloaded directory is removed.

    Args:
        hf_repo: The HuggingFace repo to download.
        local_dir: The local directory to download to.
        console: Rich console for output.
        max_retries: Maximum number of attempts (default 2 = 1 retry).
            Values below 1 are treated as 1 so that a download is always
            attempted.

    Raises:
        DownloadError: If all download attempts fail. The last attempt's
            error is attached as the cause.
    """
    # Guarantee at least one attempt; otherwise the loop would be skipped
    # and the target directory deleted without ever trying.
    attempts = max(1, max_retries)

    local_dir.mkdir(parents=True, exist_ok=True)

    last_error: DownloadError | None = None
    for attempt in range(1, attempts + 1):
        console.print(
            f"[cyan]Downloading {hf_repo}...[/cyan]"
            + (f" (attempt {attempt}/{attempts})" if attempt > 1 else "")
        )
        try:
            _run_download(hf_repo, local_dir, console)
        except DownloadError as exc:
            last_error = exc
            if attempt < attempts:
                console.print(
                    f"[yellow]Download attempt {attempt} failed. "
                    f"Retrying...[/yellow]"
                )
                time.sleep(2)  # Brief pause before retry
            continue

        console.print("[green]✓ Download complete.[/green]")
        return

    # All attempts failed: clean up partial download
    _cleanup_partial(local_dir)

    msg = (
        f"{last_error}\n\n"
        "Check your network connection and HuggingFace authentication.\n"
        "Set HF_TOKEN environment variable if the model requires authentication."
    )
    raise DownloadError(msg) from last_error
525
+
526
+
527
+ # --------------------------------------------------------------------------- #
528
+ # MLX conversion
529
+ # --------------------------------------------------------------------------- #
530
+
531
+
532
def convert_model(
    hf_repo: str,
    local_dir: Path,
    quant: str,
    console: Console,
) -> None:
    """Convert a model using mlx_lm.

    Runs ``mlx_lm.convert`` as a subprocess of the interpreter that is
    running this code, so the mlx_lm installed alongside mlx-stack is the
    one used. ``int4``/``int8`` request a quantized conversion; any other
    quant value runs the conversion without quantization flags.

    Args:
        hf_repo: The HuggingFace repo of the base model (e.g., "Qwen/Qwen3.5-8B").
        local_dir: The directory to write converted model to.
        quant: The quantization level (int4, int8).
        console: Rich console for output.

    Raises:
        ConversionError: If conversion cannot be started, times out, or
            exits with a non-zero status.
    """
    import sys  # local import: only needed to locate the running interpreter

    # Map our quant names to mlx_lm bit widths; None means "no quantization".
    quant_bits = {"int4": "4", "int8": "8"}.get(quant)

    console.print(f"[cyan]Converting {hf_repo} to {quant}...[/cyan]")
    console.print("[dim]This may take several minutes.[/dim]")

    # mlx_lm.convert is expected to refuse an output path that already
    # exists (NOTE(review): confirm for the pinned mlx_lm version), so only
    # the parent is created here and a leftover *empty* target is removed.
    local_dir.parent.mkdir(parents=True, exist_ok=True)
    try:
        local_dir.rmdir()  # succeeds only for an existing, empty directory
    except OSError:
        pass

    cmd = [
        sys.executable or "python3",
        "-m",
        "mlx_lm.convert",
        "--hf-path",
        hf_repo,
        "--mlx-path",
        str(local_dir),
    ]
    if quant_bits:
        cmd += ["-q", "--q-bits", quant_bits]

    install_hint = (
        "mlx_lm not found. Install it with:\n"
        "  pip install mlx_lm\n"
        "Or: uv pip install mlx_lm"
    )

    try:
        result = subprocess.run(
            cmd,
            capture_output=True,
            text=True,
            timeout=7200,  # 2 hour timeout for large conversions
        )
    except FileNotFoundError:
        _cleanup_partial(local_dir)
        raise ConversionError(install_hint) from None
    except subprocess.TimeoutExpired:
        _cleanup_partial(local_dir)
        msg = "Conversion timed out after 2 hours."
        raise ConversionError(msg) from None
    except OSError as exc:
        _cleanup_partial(local_dir)
        msg = f"Failed to start conversion: {exc}"
        raise ConversionError(msg) from None

    if result.returncode != 0:
        stderr = result.stderr.strip()
        _cleanup_partial(local_dir)
        # The interpreter started but could not import the module: surface
        # the install hint instead of a bare "No module named" message.
        if "No module named" in stderr and "mlx_lm" in stderr:
            raise ConversionError(install_hint)
        msg = f"Conversion failed for {hf_repo}:\n{stderr}"
        raise ConversionError(msg)

    console.print("[green]✓ Conversion complete.[/green]")
620
+
621
+
622
+ # --------------------------------------------------------------------------- #
623
+ # Cleanup
624
+ # --------------------------------------------------------------------------- #
625
+
626
+
627
+ def _cleanup_partial(local_dir: Path) -> None:
628
+ """Remove a partial/failed download directory.
629
+
630
+ Args:
631
+ local_dir: The directory to remove.
632
+ """
633
+ if local_dir.exists():
634
+ try:
635
+ shutil.rmtree(local_dir)
636
+ except OSError:
637
+ pass
638
+
639
+
640
+ # --------------------------------------------------------------------------- #
641
+ # Quant validation
642
+ # --------------------------------------------------------------------------- #
643
+
644
# Quantization names accepted by the pull workflow.
VALID_QUANTS = {"int4", "int8", "bf16"}


def validate_quant(quant: str) -> str:
    """Ensure a quantization name is one of :data:`VALID_QUANTS`.

    Args:
        quant: The quantization string to validate.

    Returns:
        The same quant string, unchanged, when it is valid.

    Raises:
        PullError: If the quant is not valid. The message lists the accepted
            values in sorted order.
    """
    if quant in VALID_QUANTS:
        return quant

    valid = ", ".join(sorted(VALID_QUANTS))
    raise PullError(f"Invalid quantization '{quant}'. Valid values: {valid}")
664
+
665
+
666
+ # --------------------------------------------------------------------------- #
667
+ # Main pull orchestrator
668
+ # --------------------------------------------------------------------------- #
669
+
670
+
671
def pull_model(
    model_id: str,
    quant: str | None = None,
    force: bool = False,
    console: Console | None = None,
    catalog: list[CatalogEntry] | None = None,
) -> PullResult:
    """Pull (download) a model from the catalog.

    Orchestrates the full pull workflow:
    1. Resolve model from catalog
    2. Determine quant (from flag or config default)
    3. Resolve source (mlx-community or convert_from)
    4. Work out the models directory and the model's local path
    5. Return early if the model is already on disk (unless ``force``)
    6. Check disk space
    7. Download or convert (removing an existing copy first under ``force``)
    8. Update inventory

    Args:
        model_id: The catalog model ID (e.g., "qwen3.5-8b").
        quant: Quantization override (None uses config default).
        force: If True, re-download even if model exists.
        console: Rich console for output (creates one if None).
        catalog: Pre-loaded catalog (loads from package if None).

    Returns:
        PullResult with details of the completed pull. ``already_existed``
        is True when the early return in step 5 was taken.

    Raises:
        InvalidModelError: If the model ID is not in the catalog.
        PullError: If the quant is invalid or unavailable.
        DiskSpaceError: If insufficient disk space.
        DownloadError: If download fails after retries.
        ConversionError: If mlx_lm conversion fails.
    """
    if console is None:
        console = Console()

    # 1. Load catalog and resolve model
    if catalog is None:
        catalog = load_catalog()

    entry = get_entry_by_id(catalog, model_id)
    if entry is None:
        msg = (
            f"Model '{model_id}' not found in catalog.\n"
            "Run 'mlx-stack models --catalog' to see available models."
        )
        raise InvalidModelError(msg)

    # 2. Determine quantization
    if quant is None:
        try:
            quant = str(get_value("default-quant"))
        except Exception:
            # Any config problem falls back to the built-in default.
            quant = "int4"

    quant = validate_quant(quant)

    # 3. Resolve source
    source, source_type = resolve_source(entry, quant)

    # 4. Get models directory and local path
    # NOTE(review): local_path depends only on source.hf_repo. If two quants
    # of a model share one repo (e.g. both converted from the same base),
    # they would map to the same directory; confirm against the catalog.
    models_dir = get_models_directory()
    local_path = get_model_local_path(models_dir, source.hf_repo)

    # 5. Check for existing download (duplicate detection). Presence of
    # files on disk is the sole criterion; the inventory is not consulted.
    if not force and is_model_downloaded(local_path):
        console.print(
            f"[yellow]Model '{entry.name}' ({quant}) already exists at "
            f"{local_path}.[/yellow]\n"
            "Use --force to re-download."
        )
        return PullResult(
            model_id=model_id,
            name=entry.name,
            quant=quant,
            source_type=source_type,
            local_path=local_path,
            already_existed=True,
            disk_size_gb=source.disk_size_gb,
        )

    # 6. Check disk space
    has_space, available_gb = check_disk_space(models_dir, source.disk_size_gb)
    if not has_space:
        msg = (
            f"Insufficient disk space for {entry.name} ({quant}).\n"
            f"Required: {source.disk_size_gb:.1f} GB (+ 20% buffer)\n"
            f"Available: {available_gb:.1f} GB"
        )
        raise DiskSpaceError(msg)

    # Display info
    console.print()
    console.print(f"[bold cyan]Pulling {entry.name}[/bold cyan]")
    console.print(f"  Quantization: {quant}")
    console.print(f"  Source: {source.hf_repo}")
    console.print(f"  Type: {source_type}")
    console.print(f"  Estimated size: {source.disk_size_gb:.1f} GB")
    console.print(f"  Destination: {local_path}")
    console.print()

    # 7. Download or convert
    if force and local_path.exists():
        console.print("[yellow]Removing existing download (--force)...[/yellow]")
        _cleanup_partial(local_path)

    if source_type == "mlx-community":
        download_model(source.hf_repo, local_path, console)
    else:
        # convert_from: the weights must be produced by mlx_lm conversion
        convert_model(source.hf_repo, local_path, quant, console)

    # 8. Update inventory
    inv = ModelInventoryEntry(
        model_id=model_id,
        name=entry.name,
        quant=quant,
        source_type=source_type,
        hf_repo=source.hf_repo,
        local_path=str(local_path),
        disk_size_gb=source.disk_size_gb,
        downloaded_at=datetime.now(timezone.utc).isoformat(),
    )
    add_to_inventory(inv)

    console.print()
    console.print(
        f"[bold green]✓ {entry.name} ({quant}) is ready.[/bold green]"
    )
    console.print(f"  Location: {local_path}")

    return PullResult(
        model_id=model_id,
        name=entry.name,
        quant=quant,
        source_type=source_type,
        local_path=local_path,
        already_existed=False,
        disk_size_gb=source.disk_size_gb,
    )