mlx-stack 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlx_stack/__init__.py +5 -0
- mlx_stack/_version.py +24 -0
- mlx_stack/cli/__init__.py +5 -0
- mlx_stack/cli/bench.py +221 -0
- mlx_stack/cli/config.py +166 -0
- mlx_stack/cli/down.py +109 -0
- mlx_stack/cli/init.py +180 -0
- mlx_stack/cli/install.py +165 -0
- mlx_stack/cli/logs.py +234 -0
- mlx_stack/cli/main.py +187 -0
- mlx_stack/cli/models.py +304 -0
- mlx_stack/cli/profile.py +65 -0
- mlx_stack/cli/pull.py +134 -0
- mlx_stack/cli/recommend.py +397 -0
- mlx_stack/cli/status.py +111 -0
- mlx_stack/cli/up.py +163 -0
- mlx_stack/cli/watch.py +252 -0
- mlx_stack/core/__init__.py +1 -0
- mlx_stack/core/benchmark.py +1182 -0
- mlx_stack/core/catalog.py +560 -0
- mlx_stack/core/config.py +471 -0
- mlx_stack/core/deps.py +323 -0
- mlx_stack/core/hardware.py +304 -0
- mlx_stack/core/launchd.py +531 -0
- mlx_stack/core/litellm_gen.py +188 -0
- mlx_stack/core/log_rotation.py +231 -0
- mlx_stack/core/log_viewer.py +386 -0
- mlx_stack/core/models.py +639 -0
- mlx_stack/core/paths.py +79 -0
- mlx_stack/core/process.py +887 -0
- mlx_stack/core/pull.py +815 -0
- mlx_stack/core/scoring.py +611 -0
- mlx_stack/core/stack_down.py +317 -0
- mlx_stack/core/stack_init.py +524 -0
- mlx_stack/core/stack_status.py +229 -0
- mlx_stack/core/stack_up.py +856 -0
- mlx_stack/core/watchdog.py +744 -0
- mlx_stack/data/__init__.py +1 -0
- mlx_stack/data/catalog/__init__.py +1 -0
- mlx_stack/data/catalog/deepseek-r1-32b.yaml +46 -0
- mlx_stack/data/catalog/deepseek-r1-8b.yaml +45 -0
- mlx_stack/data/catalog/gemma3-12b.yaml +45 -0
- mlx_stack/data/catalog/gemma3-27b.yaml +45 -0
- mlx_stack/data/catalog/gemma3-4b.yaml +45 -0
- mlx_stack/data/catalog/llama3.3-8b.yaml +44 -0
- mlx_stack/data/catalog/nemotron-49b.yaml +41 -0
- mlx_stack/data/catalog/nemotron-8b.yaml +44 -0
- mlx_stack/data/catalog/qwen3-8b.yaml +45 -0
- mlx_stack/data/catalog/qwen3.5-0.8b.yaml +45 -0
- mlx_stack/data/catalog/qwen3.5-14b.yaml +46 -0
- mlx_stack/data/catalog/qwen3.5-32b.yaml +45 -0
- mlx_stack/data/catalog/qwen3.5-3b.yaml +44 -0
- mlx_stack/data/catalog/qwen3.5-72b.yaml +42 -0
- mlx_stack/data/catalog/qwen3.5-8b.yaml +45 -0
- mlx_stack/py.typed +1 -0
- mlx_stack/utils/__init__.py +1 -0
- mlx_stack-0.1.0.dist-info/METADATA +397 -0
- mlx_stack-0.1.0.dist-info/RECORD +61 -0
- mlx_stack-0.1.0.dist-info/WHEEL +4 -0
- mlx_stack-0.1.0.dist-info/entry_points.txt +2 -0
- mlx_stack-0.1.0.dist-info/licenses/LICENSE +21 -0
mlx_stack/core/models.py
ADDED
|
@@ -0,0 +1,639 @@
|
|
|
1
|
+
"""Local model scanning and catalog listing for mlx-stack.
|
|
2
|
+
|
|
3
|
+
Scans the configured model directory for locally downloaded models,
|
|
4
|
+
determines disk size, quantization, and source type. Identifies active
|
|
5
|
+
stack models. Provides catalog listing with hardware-specific benchmark
|
|
6
|
+
data for the --catalog view.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from __future__ import annotations
|
|
10
|
+
|
|
11
|
+
import os
|
|
12
|
+
from dataclasses import dataclass
|
|
13
|
+
from pathlib import Path
|
|
14
|
+
from typing import Any
|
|
15
|
+
|
|
16
|
+
import yaml
|
|
17
|
+
|
|
18
|
+
from mlx_stack.core.catalog import CatalogEntry, load_catalog
|
|
19
|
+
from mlx_stack.core.config import ConfigCorruptError, get_value
|
|
20
|
+
from mlx_stack.core.hardware import HardwareProfile, load_profile
|
|
21
|
+
from mlx_stack.core.paths import get_data_home, get_stacks_dir
|
|
22
|
+
|
|
23
|
+
# --------------------------------------------------------------------------- #
|
|
24
|
+
# Exceptions
|
|
25
|
+
# --------------------------------------------------------------------------- #
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class ModelsError(Exception):
    """Error type for failures in model listing operations."""
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
# --------------------------------------------------------------------------- #
|
|
33
|
+
# Data classes
|
|
34
|
+
# --------------------------------------------------------------------------- #
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
@dataclass(frozen=True)
class LocalModel:
    """Immutable record for one model directory discovered on disk.

    Attributes:
        name: Name of the model directory.
        path: Location of the model directory.
        disk_size_bytes: Total size of the files under ``path``, in bytes.
        quant: Detected quantization label, or "unknown".
        source_type: Origin label: "mlx-community", "converted", or "unknown".
        is_active: Whether the active stack references this model.
        catalog_name: Human-readable catalog name when the directory was
            matched to a catalog entry, otherwise None.
    """

    name: str
    path: Path
    disk_size_bytes: int
    quant: str
    source_type: str
    is_active: bool
    catalog_name: str | None
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
@dataclass(frozen=True)
class CatalogModel:
    """Immutable display row for one catalog entry in --catalog mode.

    Attributes:
        id: Catalog identifier of the model.
        name: Human-readable model name.
        family: Model family label.
        params_b: Parameter count figure taken from the catalog entry.
        quants: Quantization labels for which the catalog lists a source.
        gen_tps: Generation throughput figure, or None when unavailable.
        memory_gb: Memory figure, or None when unavailable.
        prompt_tps: Prompt throughput figure, or None when unavailable.
        is_estimated: Whether the benchmark figures were estimated rather
            than read directly for the current hardware profile.
        is_local: Whether a matching model directory exists locally.
        tags: Tags attached to the catalog entry.
    """

    id: str
    name: str
    family: str
    params_b: float
    quants: list[str]
    gen_tps: float | None
    memory_gb: float | None
    prompt_tps: float | None
    is_estimated: bool
    is_local: bool
    tags: list[str]
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
# --------------------------------------------------------------------------- #
|
|
68
|
+
# Model directory resolution
|
|
69
|
+
# --------------------------------------------------------------------------- #
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def get_models_directory() -> Path:
    """Resolve the models directory from config.

    Uses the ``model-dir`` config value when it is set to a non-blank
    value, expanding a leading ``~``. Falls back to ``<data home>/models``
    when the value is missing or blank, or when reading or expanding it
    fails for any reason.

    Returns:
        Path to the models directory.
    """
    try:
        configured = get_value("model-dir")
        # An unset value must not be stringified: str(None) would yield the
        # bogus relative path "None", and "" would mean the current directory.
        if configured is not None and str(configured).strip():
            return Path(str(configured)).expanduser()
    except (ConfigCorruptError, Exception):
        # Deliberately broad: any config failure means "use the default".
        pass
    return get_data_home() / "models"
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
# --------------------------------------------------------------------------- #
|
|
89
|
+
# Active stack model detection
|
|
90
|
+
# --------------------------------------------------------------------------- #
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def _load_active_stack() -> dict[str, Any] | None:
    """Read and parse ``default.yaml`` from the stacks directory.

    Returns:
        The parsed YAML mapping, or None when the file is absent, cannot be
        read, is not valid YAML, or does not contain a mapping at top level.
    """
    default_stack_file = get_stacks_dir() / "default.yaml"
    if not default_stack_file.exists():
        return None

    try:
        parsed = yaml.safe_load(default_stack_file.read_text(encoding="utf-8"))
    except (OSError, yaml.YAMLError):
        # Unreadable or malformed stack file is treated as "no stack".
        return None

    return parsed if isinstance(parsed, dict) else None
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def _get_active_stack_models(stack: dict[str, Any] | None) -> list[dict[str, str]]:
|
|
115
|
+
"""Extract model info from the active stack definition.
|
|
116
|
+
|
|
117
|
+
Returns a list (not a dict) because the same model_id could appear
|
|
118
|
+
with different quants in different tiers. Each entry contains
|
|
119
|
+
model_id, quant, source, and tier name.
|
|
120
|
+
|
|
121
|
+
Returns:
|
|
122
|
+
List of dicts with keys: model_id, quant, source, tier.
|
|
123
|
+
"""
|
|
124
|
+
if stack is None:
|
|
125
|
+
return []
|
|
126
|
+
|
|
127
|
+
tiers = stack.get("tiers", [])
|
|
128
|
+
result: list[dict[str, str]] = []
|
|
129
|
+
for tier in tiers:
|
|
130
|
+
model_id = tier.get("model", "")
|
|
131
|
+
if model_id:
|
|
132
|
+
result.append({
|
|
133
|
+
"model_id": model_id,
|
|
134
|
+
"quant": tier.get("quant", ""),
|
|
135
|
+
"source": tier.get("source", ""),
|
|
136
|
+
"tier": tier.get("name", ""),
|
|
137
|
+
})
|
|
138
|
+
return result
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
# --------------------------------------------------------------------------- #
|
|
142
|
+
# Local model scanning
|
|
143
|
+
# --------------------------------------------------------------------------- #
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
def _compute_dir_size(path: Path) -> int:
|
|
147
|
+
"""Compute the total size of all files in a directory recursively.
|
|
148
|
+
|
|
149
|
+
Args:
|
|
150
|
+
path: Directory path.
|
|
151
|
+
|
|
152
|
+
Returns:
|
|
153
|
+
Total size in bytes.
|
|
154
|
+
"""
|
|
155
|
+
total = 0
|
|
156
|
+
try:
|
|
157
|
+
for dirpath, _dirnames, filenames in os.walk(path):
|
|
158
|
+
for f in filenames:
|
|
159
|
+
fp = Path(dirpath) / f
|
|
160
|
+
try:
|
|
161
|
+
total += fp.stat().st_size
|
|
162
|
+
except OSError:
|
|
163
|
+
continue
|
|
164
|
+
except OSError:
|
|
165
|
+
pass
|
|
166
|
+
return total
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
def _detect_quant(model_dir: Path, dirname: str) -> str:
|
|
170
|
+
"""Detect quantization level from directory name or contents.
|
|
171
|
+
|
|
172
|
+
Heuristic: looks for common patterns in directory names like
|
|
173
|
+
'4bit', '8bit', 'int4', 'int8', 'bf16'.
|
|
174
|
+
|
|
175
|
+
Args:
|
|
176
|
+
model_dir: Path to the model directory.
|
|
177
|
+
dirname: Directory name.
|
|
178
|
+
|
|
179
|
+
Returns:
|
|
180
|
+
Detected quantization string, or "unknown".
|
|
181
|
+
"""
|
|
182
|
+
lower = dirname.lower()
|
|
183
|
+
if "4bit" in lower or "int4" in lower or "-4bit" in lower:
|
|
184
|
+
return "int4"
|
|
185
|
+
if "8bit" in lower or "int8" in lower or "-8bit" in lower:
|
|
186
|
+
return "int8"
|
|
187
|
+
if "bf16" in lower or "bfloat16" in lower:
|
|
188
|
+
return "bf16"
|
|
189
|
+
if "fp16" in lower or "f16" in lower:
|
|
190
|
+
return "bf16" # Approximate — fp16 treated as bf16 for display
|
|
191
|
+
|
|
192
|
+
# Check for config.json with quantization info
|
|
193
|
+
config_path = model_dir / "config.json"
|
|
194
|
+
if config_path.exists():
|
|
195
|
+
try:
|
|
196
|
+
import json
|
|
197
|
+
|
|
198
|
+
config_data = json.loads(config_path.read_text(encoding="utf-8"))
|
|
199
|
+
quant_config = config_data.get("quantization_config", {})
|
|
200
|
+
bits = quant_config.get("bits")
|
|
201
|
+
if bits == 4:
|
|
202
|
+
return "int4"
|
|
203
|
+
if bits == 8:
|
|
204
|
+
return "int8"
|
|
205
|
+
except (OSError, ValueError, KeyError):
|
|
206
|
+
pass
|
|
207
|
+
|
|
208
|
+
return "unknown"
|
|
209
|
+
|
|
210
|
+
|
|
211
|
+
def _detect_source_type(dirname: str) -> str:
|
|
212
|
+
"""Detect the source type from directory name.
|
|
213
|
+
|
|
214
|
+
Args:
|
|
215
|
+
dirname: Directory name.
|
|
216
|
+
|
|
217
|
+
Returns:
|
|
218
|
+
Source type string: "mlx-community", "converted", or "unknown".
|
|
219
|
+
"""
|
|
220
|
+
lower = dirname.lower()
|
|
221
|
+
if "mlx-community" in lower or lower.startswith("mlx-community"):
|
|
222
|
+
return "mlx-community"
|
|
223
|
+
# Directories from HF usually have org/repo pattern replaced with --
|
|
224
|
+
if "--" in dirname:
|
|
225
|
+
parts = dirname.split("--", 1)
|
|
226
|
+
if parts[0].lower() == "mlx-community":
|
|
227
|
+
return "mlx-community"
|
|
228
|
+
return "unknown"
|
|
229
|
+
|
|
230
|
+
|
|
231
|
+
def _match_to_catalog(dirname: str, catalog: list[CatalogEntry]) -> CatalogEntry | None:
|
|
232
|
+
"""Try to match a local directory name to a catalog entry.
|
|
233
|
+
|
|
234
|
+
Matches by checking if the HF repo name (last part) is in the dirname.
|
|
235
|
+
|
|
236
|
+
Args:
|
|
237
|
+
dirname: Local directory name.
|
|
238
|
+
catalog: Loaded catalog entries.
|
|
239
|
+
|
|
240
|
+
Returns:
|
|
241
|
+
Matching CatalogEntry or None.
|
|
242
|
+
"""
|
|
243
|
+
lower = dirname.lower()
|
|
244
|
+
for entry in catalog:
|
|
245
|
+
for _quant, source in entry.sources.items():
|
|
246
|
+
# Extract repo name from hf_repo (after the /)
|
|
247
|
+
repo_name = (
|
|
248
|
+
source.hf_repo.rsplit("/", 1)[-1] if "/" in source.hf_repo else source.hf_repo
|
|
249
|
+
)
|
|
250
|
+
if repo_name.lower() == lower:
|
|
251
|
+
return entry
|
|
252
|
+
return None
|
|
253
|
+
|
|
254
|
+
|
|
255
|
+
def scan_local_models(
    models_dir: Path | None = None,
    catalog: list[CatalogEntry] | None = None,
    stack: dict[str, Any] | None = None,
) -> list[LocalModel]:
    """Scan the models directory for locally downloaded models.

    Every immediate subdirectory of the models directory is reported as a
    model, with its disk size, detected quantization, source type, catalog
    name (when matched), and whether the active stack references it.

    Args:
        models_dir: Override models directory (uses config if None).
        catalog: Pre-loaded catalog (loads from package if None).
        stack: Pre-loaded stack definition (loads from file if None).

    Returns:
        List of LocalModel entries found on disk, sorted by path. Empty
        when the models directory is missing, is not a directory, or
        cannot be listed.
    """
    if models_dir is None:
        models_dir = get_models_directory()

    if not models_dir.is_dir():
        return []

    if catalog is None:
        try:
            catalog = load_catalog()
        except Exception:
            # Best-effort: without a catalog, models are still listed,
            # just without catalog names.
            catalog = []

    if stack is None:
        stack = _load_active_stack()
    active_stack_entries = _get_active_stack_models(stack)

    try:
        entries = sorted(models_dir.iterdir())
    except OSError:
        return []

    results: list[LocalModel] = []
    for item in entries:
        if not item.is_dir():
            continue

        dirname = item.name
        quant = _detect_quant(item, dirname)
        catalog_entry = _match_to_catalog(dirname, catalog)

        results.append(
            LocalModel(
                name=dirname,
                path=item,
                disk_size_bytes=_compute_dir_size(item),
                quant=quant,
                source_type=_detect_source_type(dirname),
                is_active=_is_active_in_stack(
                    dirname, quant, catalog_entry, active_stack_entries
                ),
                catalog_name=catalog_entry.name if catalog_entry else None,
            )
        )

    return results


def _is_active_in_stack(
    dirname: str,
    local_quant: str,
    catalog_entry: CatalogEntry | None,
    stack_entries: list[dict[str, str]],
) -> bool:
    """Decide whether any stack entry refers to the given local model.

    A stack entry matches when one of the following holds:

    1. The last path segment of its source equals ``dirname`` and the
       quants are compatible.
    2. Its model ID equals ``dirname`` and the quants are compatible.
    3. Its model ID equals the matched catalog entry's ID, and either the
       catalog source listed for the stack's quant has a repo name equal to
       ``dirname``, or the catalog lists no source for that quant and the
       quants are compatible.

    Quants are compatible when the stack entry names no quant, the local
    quant is "unknown", or both are equal.

    Args:
        dirname: Name of the local model directory.
        local_quant: Quantization detected for the local model.
        catalog_entry: Catalog entry matched to the directory, if any.
        stack_entries: Entries extracted from the active stack.

    Returns:
        True when at least one stack entry matches.
    """
    for stack_entry in stack_entries:
        # ``or ""`` guards against explicit nulls in the stack file.
        stack_source = stack_entry.get("source") or ""
        stack_quant = stack_entry.get("quant") or ""
        stack_model_id = stack_entry.get("model_id") or ""
        source_dir = stack_source.rsplit("/", 1)[-1]

        quant_ok = (
            not stack_quant or local_quant == "unknown" or stack_quant == local_quant
        )

        if quant_ok and (
            (source_dir and source_dir == dirname) or stack_model_id == dirname
        ):
            return True

        if catalog_entry and catalog_entry.id == stack_model_id:
            if stack_quant in catalog_entry.sources:
                expected_repo = catalog_entry.sources[stack_quant].hf_repo
                if expected_repo.rsplit("/", 1)[-1] == dirname:
                    return True
            elif quant_ok:
                # The catalog has no source for this quant: accept a
                # catalog match whose quant is compatible.
                return True

    return False
|
|
365
|
+
|
|
366
|
+
|
|
367
|
+
# --------------------------------------------------------------------------- #
|
|
368
|
+
# Remote-only stack models (in stack but not downloaded)
|
|
369
|
+
# --------------------------------------------------------------------------- #
|
|
370
|
+
|
|
371
|
+
|
|
372
|
+
def get_remote_stack_models(
    local_models: list[LocalModel],
    stack: dict[str, Any] | None = None,
    catalog: list[CatalogEntry] | None = None,
) -> list[dict[str, Any]]:
    """Find models referenced in the active stack but not downloaded locally.

    A stack model counts as locally available when some local model matches
    it by source directory name, by model ID used as directory name, or by
    catalog entry, each time taking the quantization into account.

    Args:
        local_models: Already-scanned local models.
        stack: Pre-loaded stack definition (loads from file if None).
        catalog: Pre-loaded catalog (loads from package if None).

    Returns:
        One dict per remote-only stack model with keys: model_id, tier,
        quant, source (always "remote"), catalog_name (falls back to the
        model ID), and est_size_gb (None when the catalog has no source for
        the stack's quant).
    """
    if stack is None:
        stack = _load_active_stack()
    if stack is None:
        return []

    if catalog is None:
        try:
            catalog = load_catalog()
        except Exception:
            # Best-effort: without a catalog only direct name matches work.
            catalog = []

    remote_models: list[dict[str, Any]] = []
    for stack_entry in _get_active_stack_models(stack):
        model_id = stack_entry["model_id"]
        stack_quant = stack_entry.get("quant", "int4")
        # ``or ""`` guards against an explicit null source in the stack file.
        stack_source = stack_entry.get("source") or ""
        source_dir = stack_source.rsplit("/", 1)[-1]

        # The catalog lookup depends only on the stack entry, so it is done
        # once here rather than once per local model.
        cat_entry = _find_catalog_entry(catalog, model_id)
        quant_source = None
        if cat_entry and stack_quant in cat_entry.sources:
            quant_source = cat_entry.sources[stack_quant]
        expected_dir = (
            quant_source.hf_repo.rsplit("/", 1)[-1] if quant_source is not None else None
        )

        is_local = False
        for lm in local_models:
            quant_ok = (
                not stack_quant or lm.quant == "unknown" or stack_quant == lm.quant
            )

            # Direct match on the source directory name or the model ID.
            if quant_ok and (
                (source_dir and source_dir == lm.name) or model_id == lm.name
            ):
                is_local = True
                break

            # Match through the catalog entry shared by both sides.
            if cat_entry and lm.catalog_name and cat_entry.name == lm.catalog_name:
                if quant_source is not None:
                    if expected_dir == lm.name:
                        is_local = True
                        break
                elif quant_ok:
                    is_local = True
                    break

        if is_local:
            continue

        remote_models.append(
            {
                "model_id": model_id,
                "tier": stack_entry.get("tier", ""),
                "quant": stack_quant,
                "source": "remote",
                "catalog_name": (cat_entry.name if cat_entry else None) or model_id,
                "est_size_gb": (
                    quant_source.disk_size_gb if quant_source is not None else None
                ),
            }
        )

    return remote_models
|
|
471
|
+
|
|
472
|
+
|
|
473
|
+
def _find_catalog_entry(catalog: list[CatalogEntry], model_id: str) -> CatalogEntry | None:
|
|
474
|
+
"""Find a catalog entry by model ID."""
|
|
475
|
+
for entry in catalog:
|
|
476
|
+
if entry.id == model_id:
|
|
477
|
+
return entry
|
|
478
|
+
return None
|
|
479
|
+
|
|
480
|
+
|
|
481
|
+
# --------------------------------------------------------------------------- #
|
|
482
|
+
# Catalog listing with hardware-specific data
|
|
483
|
+
# --------------------------------------------------------------------------- #
|
|
484
|
+
|
|
485
|
+
|
|
486
|
+
def list_catalog_models(
    catalog: list[CatalogEntry] | None = None,
    profile: HardwareProfile | None = None,
    local_models: list[LocalModel] | None = None,
) -> list[CatalogModel]:
    """Build a catalog listing with hardware-specific benchmark data.

    For each catalog entry, benchmark figures are taken from the entry's
    benchmark for the current hardware profile when one exists; otherwise
    they are estimated from another benchmark of the same entry. Without a
    hardware profile the figures are left as None.

    Args:
        catalog: Pre-loaded catalog (loads from package if None).
        profile: Hardware profile for benchmark lookup (loads from file if None).
        local_models: Pre-scanned local models for local availability check.

    Returns:
        List of CatalogModel entries for display, in catalog order.
    """
    if catalog is None:
        catalog = load_catalog()

    if profile is None:
        profile = load_profile()

    if local_models is None:
        local_models = scan_local_models(catalog=catalog)

    # Build both lookup tables in one pass over the catalog so that the
    # per-local-model work below is a pair of dict lookups instead of two
    # full catalog scans.
    id_by_name: dict[str, str] = {}
    ids_by_repo_dir: dict[str, set[str]] = {}
    for entry in catalog:
        # First entry wins when several entries share a display name.
        id_by_name.setdefault(entry.name, entry.id)
        for source in entry.sources.values():
            repo_dir = source.hf_repo.rsplit("/", 1)[-1]
            ids_by_repo_dir.setdefault(repo_dir, set()).add(entry.id)

    # IDs of catalog entries that are available locally, found either via
    # the catalog name recorded on the local model or via a directory name
    # equal to one of the entry's source repo names.
    local_catalog_ids: set[str] = set()
    for lm in local_models:
        if lm.catalog_name and lm.catalog_name in id_by_name:
            local_catalog_ids.add(id_by_name[lm.catalog_name])
        local_catalog_ids.update(ids_by_repo_dir.get(lm.name, ()))

    results: list[CatalogModel] = []
    for entry in catalog:
        gen_tps: float | None = None
        memory_gb: float | None = None
        prompt_tps: float | None = None
        is_estimated = False

        if profile is not None:
            profile_id = profile.profile_id
            if profile_id in entry.benchmarks:
                bench = entry.benchmarks[profile_id]
                gen_tps = bench.gen_tps
                memory_gb = bench.memory_gb
                prompt_tps = bench.prompt_tps
            else:
                # No benchmark for this hardware: fall back to estimation.
                estimated = _estimate_benchmark(entry, profile)
                if estimated:
                    gen_tps, memory_gb, prompt_tps = estimated
                    is_estimated = True

        results.append(
            CatalogModel(
                id=entry.id,
                name=entry.name,
                family=entry.family,
                params_b=entry.params_b,
                quants=sorted(entry.sources),
                gen_tps=gen_tps,
                memory_gb=memory_gb,
                prompt_tps=prompt_tps,
                is_estimated=is_estimated,
                is_local=entry.id in local_catalog_ids,
                tags=entry.tags,
            )
        )

    return results
|
|
570
|
+
|
|
571
|
+
|
|
572
|
+
def _estimate_benchmark(
    entry: CatalogEntry,
    profile: HardwareProfile,
) -> tuple[float, float, float] | None:
    """Derive benchmark figures for a profile that has no benchmark of its own.

    Picks the entry's first benchmark whose profile ID appears in the
    reference profile table, and scales its throughput figures by the ratio
    of the target bandwidth to the table value for that reference profile.
    If no benchmark belongs to a reference profile, the entry's first
    benchmark is returned unscaled.

    Args:
        entry: Catalog entry with benchmark data.
        profile: Hardware profile with bandwidth info.

    Returns:
        Tuple of (gen_tps, memory_gb, prompt_tps), or None when the entry
        has no benchmarks at all.
    """
    from mlx_stack.core.scoring import _REFERENCE_PROFILES

    benchmarks = entry.benchmarks
    if not benchmarks:
        return None

    # First benchmark recorded for a known reference profile, if any.
    reference = next(
        ((pid, bench) for pid, bench in benchmarks.items() if pid in _REFERENCE_PROFILES),
        None,
    )
    if reference is None or reference[0] is None or reference[1] is None:
        # Nothing usable among the reference profiles: take the first one.
        first_pid = next(iter(benchmarks))
        reference = (first_pid, benchmarks[first_pid])
    ref_profile_id, ref_bench = reference

    ref_bw = _REFERENCE_PROFILES.get(ref_profile_id)
    if ref_bw is None:
        # No reference bandwidth to scale against: report the figures as-is.
        return ref_bench.gen_tps, ref_bench.memory_gb, ref_bench.prompt_tps

    scale = profile.bandwidth_gbps / ref_bw
    scaled_gen = round(ref_bench.gen_tps * scale, 1)
    scaled_prompt = round(ref_bench.prompt_tps * scale, 1)
    # Memory usage is hardware-independent, so it is passed through unscaled.
    return (scaled_gen, ref_bench.memory_gb, scaled_prompt)
|
|
617
|
+
|
|
618
|
+
|
|
619
|
+
# --------------------------------------------------------------------------- #
|
|
620
|
+
# Human-readable size formatting
|
|
621
|
+
# --------------------------------------------------------------------------- #
|
|
622
|
+
|
|
623
|
+
|
|
624
|
+
def format_size(size_bytes: int) -> str:
    """Format a byte count into a human-readable size string.

    Uses decimal units (1 KB = 1000 B). GB values carry one decimal; MB and
    KB values are rounded to whole numbers. A value that rounds up to 1000
    of its unit is promoted to the next unit, so the result is never
    '1000 KB' or '1000 MB'.

    Args:
        size_bytes: Size in bytes.

    Returns:
        Formatted string like '5.2 GB', '850 MB', '120 KB', '12 B'.
    """
    if size_bytes >= 1_000_000_000:
        return f"{size_bytes / 1_000_000_000:.1f} GB"
    if size_bytes >= 1_000_000:
        megabytes = round(size_bytes / 1_000_000)
        # e.g. 999_999_999 B rounds to 1000 MB; show it as 1.0 GB instead.
        return "1.0 GB" if megabytes >= 1_000 else f"{megabytes} MB"
    if size_bytes >= 1_000:
        kilobytes = round(size_bytes / 1_000)
        # e.g. 999_999 B rounds to 1000 KB; show it as 1 MB instead.
        return "1 MB" if kilobytes >= 1_000 else f"{kilobytes} KB"
    return f"{size_bytes} B"
|