mlx-stack 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlx_stack/__init__.py +5 -0
- mlx_stack/_version.py +24 -0
- mlx_stack/cli/__init__.py +5 -0
- mlx_stack/cli/bench.py +221 -0
- mlx_stack/cli/config.py +166 -0
- mlx_stack/cli/down.py +109 -0
- mlx_stack/cli/init.py +180 -0
- mlx_stack/cli/install.py +165 -0
- mlx_stack/cli/logs.py +234 -0
- mlx_stack/cli/main.py +187 -0
- mlx_stack/cli/models.py +304 -0
- mlx_stack/cli/profile.py +65 -0
- mlx_stack/cli/pull.py +134 -0
- mlx_stack/cli/recommend.py +397 -0
- mlx_stack/cli/status.py +111 -0
- mlx_stack/cli/up.py +163 -0
- mlx_stack/cli/watch.py +252 -0
- mlx_stack/core/__init__.py +1 -0
- mlx_stack/core/benchmark.py +1182 -0
- mlx_stack/core/catalog.py +560 -0
- mlx_stack/core/config.py +471 -0
- mlx_stack/core/deps.py +323 -0
- mlx_stack/core/hardware.py +304 -0
- mlx_stack/core/launchd.py +531 -0
- mlx_stack/core/litellm_gen.py +188 -0
- mlx_stack/core/log_rotation.py +231 -0
- mlx_stack/core/log_viewer.py +386 -0
- mlx_stack/core/models.py +639 -0
- mlx_stack/core/paths.py +79 -0
- mlx_stack/core/process.py +887 -0
- mlx_stack/core/pull.py +815 -0
- mlx_stack/core/scoring.py +611 -0
- mlx_stack/core/stack_down.py +317 -0
- mlx_stack/core/stack_init.py +524 -0
- mlx_stack/core/stack_status.py +229 -0
- mlx_stack/core/stack_up.py +856 -0
- mlx_stack/core/watchdog.py +744 -0
- mlx_stack/data/__init__.py +1 -0
- mlx_stack/data/catalog/__init__.py +1 -0
- mlx_stack/data/catalog/deepseek-r1-32b.yaml +46 -0
- mlx_stack/data/catalog/deepseek-r1-8b.yaml +45 -0
- mlx_stack/data/catalog/gemma3-12b.yaml +45 -0
- mlx_stack/data/catalog/gemma3-27b.yaml +45 -0
- mlx_stack/data/catalog/gemma3-4b.yaml +45 -0
- mlx_stack/data/catalog/llama3.3-8b.yaml +44 -0
- mlx_stack/data/catalog/nemotron-49b.yaml +41 -0
- mlx_stack/data/catalog/nemotron-8b.yaml +44 -0
- mlx_stack/data/catalog/qwen3-8b.yaml +45 -0
- mlx_stack/data/catalog/qwen3.5-0.8b.yaml +45 -0
- mlx_stack/data/catalog/qwen3.5-14b.yaml +46 -0
- mlx_stack/data/catalog/qwen3.5-32b.yaml +45 -0
- mlx_stack/data/catalog/qwen3.5-3b.yaml +44 -0
- mlx_stack/data/catalog/qwen3.5-72b.yaml +42 -0
- mlx_stack/data/catalog/qwen3.5-8b.yaml +45 -0
- mlx_stack/py.typed +1 -0
- mlx_stack/utils/__init__.py +1 -0
- mlx_stack-0.1.0.dist-info/METADATA +397 -0
- mlx_stack-0.1.0.dist-info/RECORD +61 -0
- mlx_stack-0.1.0.dist-info/WHEEL +4 -0
- mlx_stack-0.1.0.dist-info/entry_points.txt +2 -0
- mlx_stack-0.1.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,560 @@
|
|
|
1
|
+
"""Model catalog system for mlx-stack.
|
|
2
|
+
|
|
3
|
+
Loads, validates, and queries the curated catalog of MLX-compatible models
|
|
4
|
+
shipped as YAML data files with the package. Each catalog entry describes
|
|
5
|
+
a model's identity, architecture, quantization sources, capabilities,
|
|
6
|
+
quality scores, hardware benchmarks, and tags.
|
|
7
|
+
|
|
8
|
+
Uses importlib.resources for accessing package data files.
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
from __future__ import annotations
|
|
12
|
+
|
|
13
|
+
import importlib.resources
|
|
14
|
+
from dataclasses import dataclass, field
|
|
15
|
+
from typing import Any
|
|
16
|
+
|
|
17
|
+
import yaml
|
|
18
|
+
|
|
19
|
+
# --------------------------------------------------------------------------- #
|
|
20
|
+
# Exceptions
|
|
21
|
+
# --------------------------------------------------------------------------- #
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class CatalogError(Exception):
    """Raised when catalog loading or validation fails.

    The validation, parsing, and loading routines in this module report
    problems with catalog YAML files through this single exception type,
    with the offending file name or model id included in the message.
    """
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
# --------------------------------------------------------------------------- #
|
|
29
|
+
# Schema — required fields and their expected types
|
|
30
|
+
# --------------------------------------------------------------------------- #
|
|
31
|
+
|
|
32
|
+
# Top-level required fields and their types
|
|
33
|
+
_REQUIRED_FIELDS: dict[str, type | tuple[type, ...]] = {
    # Each value is handed directly to isinstance() during validation, so a
    # tuple means "any one of these types". Keys are checked in this order.
    "id": str,
    "name": str,
    "family": str,
    "params_b": (int, float),
    "architecture": str,
    "min_mlx_lm_version": str,
    "sources": dict,
    "capabilities": dict,
    "quality": dict,
    "benchmarks": dict,
    "tags": list,
}

# Keys that must be present in an entry's "capabilities" mapping
# (presence only; the parser fields may hold an empty/None value)
_REQUIRED_CAPABILITIES: set[str] = {
    "tool_calling",
    "tool_call_parser",
    "thinking",
    "reasoning_parser",
    "vision",
}

# Keys that must be present, with numeric values, in an entry's "quality" mapping
_REQUIRED_QUALITY_FIELDS: set[str] = {
    "overall",
    "coding",
    "reasoning",
    "instruction_following",
}

# The only quantization names accepted as keys of an entry's "sources" mapping
_VALID_QUANTS: set[str] = {"int4", "int8", "bf16"}

# Keys that every per-quantization source mapping must provide
_REQUIRED_SOURCE_FIELDS: set[str] = {"hf_repo", "disk_size_gb"}
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
# --------------------------------------------------------------------------- #
|
|
72
|
+
# Data classes
|
|
73
|
+
# --------------------------------------------------------------------------- #
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
@dataclass(frozen=True)
class QuantSource:
    """A quantization source for a model.

    One instance is created per key of a catalog entry's ``sources`` mapping
    (the key being the quantization name). Instances are immutable.
    """

    # Repository identifier copied verbatim from the YAML "hf_repo" field
    # (presumably a Hugging Face repo id -- confirm against the catalog files).
    hf_repo: str
    # Size taken from the YAML "disk_size_gb" field, coerced to float on parse.
    disk_size_gb: float
    # Truthiness of the optional YAML "convert_from" field; False when absent.
    convert_from: bool = False
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
@dataclass(frozen=True)
class Capabilities:
    """Model capabilities.

    Built from the ``capabilities`` mapping of a catalog entry. The three
    boolean flags are the ones usable as filters in ``query_by_capability``.
    Instances are immutable.
    """

    # Whether the model is flagged as supporting tool calling.
    tool_calling: bool
    # Parser name from the YAML, or None when the field is missing or falsy.
    tool_call_parser: str | None
    # Whether the model is flagged as a "thinking" model.
    thinking: bool
    # Parser name from the YAML, or None when the field is missing or falsy.
    reasoning_parser: str | None
    # Whether the model is flagged as supporting vision input.
    vision: bool
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
@dataclass(frozen=True)
class QualityScores:
    """Model quality scores (0–100 scale).

    Built from the ``quality`` mapping of a catalog entry; each value is
    converted with ``int()`` when the entry is parsed. The 0–100 range is a
    convention of the catalog data and is not enforced by this class.
    """

    # Aggregate score across categories.
    overall: int
    # Score for coding tasks.
    coding: int
    # Score for reasoning tasks.
    reasoning: int
    # Score for instruction following.
    instruction_following: int
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
@dataclass(frozen=True)
class BenchmarkResult:
    """Benchmark data for a specific hardware profile.

    One instance is created per key of a catalog entry's ``benchmarks``
    mapping (the key identifying the hardware profile). All values are
    coerced to float when the entry is parsed. Instances are immutable.
    """

    # Prompt-processing throughput (presumably tokens per second -- confirm).
    prompt_tps: float
    # Generation throughput (presumably tokens per second -- confirm).
    gen_tps: float
    # Memory figure in GB as recorded in the catalog file.
    memory_gb: float
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
@dataclass(frozen=True)
class CatalogEntry:
    """A single model entry in the catalog.

    Produced from one validated YAML file. The dataclass is frozen, so
    attributes cannot be rebound, but the ``sources``, ``benchmarks`` and
    ``tags`` containers themselves remain ordinary mutable dict/list objects.
    """

    # Identifier used for exact-match lookup by get_entry_by_id.
    id: str
    # Model name as given in the catalog file.
    name: str
    # Model family; used for case-insensitive filtering and as primary sort key.
    family: str
    # Value of the YAML "params_b" field as a float (presumably the parameter
    # count in billions -- confirm); secondary sort key within a family.
    params_b: float
    # Architecture string as given in the catalog file.
    architecture: str
    # Version string from the YAML "min_mlx_lm_version" field.
    min_mlx_lm_version: str
    # Quantization name ("int4", "int8" or "bf16") -> download source.
    sources: dict[str, QuantSource]
    # Capability flags and parser names.
    capabilities: Capabilities
    # Quality scores.
    quality: QualityScores
    # Hardware profile key -> benchmark figures; may be empty.
    benchmarks: dict[str, BenchmarkResult]
    # Free-form string tags; matched case-insensitively by query_by_tag.
    tags: list[str] = field(default_factory=list)
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
# --------------------------------------------------------------------------- #
|
|
133
|
+
# Validation
|
|
134
|
+
# --------------------------------------------------------------------------- #
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
def _is_real_number(value: Any) -> bool:
    """Return True for int/float values, excluding bool.

    ``bool`` is a subclass of ``int``, so a plain ``isinstance`` check would
    let a YAML ``true``/``false`` through wherever a number is required.
    """
    return isinstance(value, (int, float)) and not isinstance(value, bool)


def _validate_entry(data: dict[str, Any], filename: str) -> None:
    """Validate a single catalog entry against the expected schema.

    Checks, in order: presence and type of every top-level required field,
    the per-quantization ``sources`` mappings, the ``capabilities`` mapping,
    the ``quality`` scores, the ``benchmarks`` mappings, and the ``tags``
    list. Validation stops at the first problem found.

    Booleans are never accepted where a number is required, and the
    ``tool_calling`` / ``thinking`` / ``vision`` capability flags must be real
    booleans; otherwise a value such as the string ``"false"`` would later be
    coerced to ``True`` when the entry is parsed.

    Args:
        data: Parsed YAML data for a catalog entry.
        filename: The filename, used in error messages.

    Raises:
        CatalogError: If validation fails.
    """
    # Check required top-level fields
    for field_name, expected_type in _REQUIRED_FIELDS.items():
        if field_name not in data:
            msg = f"Catalog file '{filename}': missing required field '{field_name}'"
            raise CatalogError(msg)
        value = data[field_name]
        accepted = expected_type if isinstance(expected_type, tuple) else (expected_type,)
        # A bool satisfies isinstance(value, int); reject it unless bool is
        # explicitly one of the accepted types.
        is_stray_bool = isinstance(value, bool) and bool not in accepted
        if not isinstance(value, accepted) or is_stray_bool:
            msg = (
                f"Catalog file '{filename}': field '{field_name}' has wrong type "
                f"(expected {expected_type}, got {type(value).__name__})"
            )
            raise CatalogError(msg)

    # Validate sources — each quant must have required fields
    sources = data["sources"]
    if not sources:
        msg = f"Catalog file '{filename}': 'sources' must not be empty"
        raise CatalogError(msg)
    for quant, source_data in sources.items():
        if quant not in _VALID_QUANTS:
            msg = (
                f"Catalog file '{filename}': invalid quantization '{quant}' "
                f"(valid: {', '.join(sorted(_VALID_QUANTS))})"
            )
            raise CatalogError(msg)
        if not isinstance(source_data, dict):
            msg = f"Catalog file '{filename}': source for quant '{quant}' must be a mapping"
            raise CatalogError(msg)
        for req_field in _REQUIRED_SOURCE_FIELDS:
            if req_field not in source_data:
                msg = (
                    f"Catalog file '{filename}': source '{quant}' missing "
                    f"required field '{req_field}'"
                )
                raise CatalogError(msg)
        # Validate disk_size_gb is numeric
        disk_size = source_data["disk_size_gb"]
        if not _is_real_number(disk_size):
            msg = (
                f"Catalog file '{filename}': source '{quant}' field 'disk_size_gb' "
                f"must be numeric, got {type(disk_size).__name__}"
            )
            raise CatalogError(msg)

    # Validate capabilities — all keys present, boolean flags really boolean
    caps = data["capabilities"]
    for cap_field in _REQUIRED_CAPABILITIES:
        if cap_field not in caps:
            msg = (
                f"Catalog file '{filename}': capabilities missing "
                f"required field '{cap_field}'"
            )
            raise CatalogError(msg)
    for flag in ("tool_calling", "thinking", "vision"):
        flag_value = caps[flag]
        if not isinstance(flag_value, bool):
            msg = (
                f"Catalog file '{filename}': capability '{flag}' "
                f"must be a boolean, got {type(flag_value).__name__}"
            )
            raise CatalogError(msg)

    # Validate quality scores
    quality = data["quality"]
    for q_field in _REQUIRED_QUALITY_FIELDS:
        if q_field not in quality:
            msg = (
                f"Catalog file '{filename}': quality missing "
                f"required field '{q_field}'"
            )
            raise CatalogError(msg)
        q_value = quality[q_field]
        if not _is_real_number(q_value):
            msg = (
                f"Catalog file '{filename}': quality field '{q_field}' "
                f"must be numeric, got {type(q_value).__name__}"
            )
            raise CatalogError(msg)

    # Validate benchmarks — each entry must have prompt_tps, gen_tps, memory_gb
    # (an empty benchmarks mapping is allowed)
    benchmarks = data["benchmarks"]
    for hw_key, bench_data in benchmarks.items():
        if not isinstance(bench_data, dict):
            msg = (
                f"Catalog file '{filename}': benchmark entry '{hw_key}' "
                f"must be a mapping"
            )
            raise CatalogError(msg)
        for req_field in ("prompt_tps", "gen_tps", "memory_gb"):
            if req_field not in bench_data:
                msg = (
                    f"Catalog file '{filename}': benchmark '{hw_key}' missing "
                    f"required field '{req_field}'"
                )
                raise CatalogError(msg)
            bench_value = bench_data[req_field]
            if not _is_real_number(bench_value):
                msg = (
                    f"Catalog file '{filename}': benchmark '{hw_key}' field "
                    f"'{req_field}' must be numeric, got {type(bench_value).__name__}"
                )
                raise CatalogError(msg)

    # Validate tags is a list of strings
    for tag in data["tags"]:
        if not isinstance(tag, str):
            msg = f"Catalog file '{filename}': tags must be strings, got {type(tag).__name__}"
            raise CatalogError(msg)
|
|
246
|
+
|
|
247
|
+
|
|
248
|
+
# --------------------------------------------------------------------------- #
|
|
249
|
+
# Parsing
|
|
250
|
+
# --------------------------------------------------------------------------- #
|
|
251
|
+
|
|
252
|
+
|
|
253
|
+
def _parse_entry(data: dict[str, Any]) -> CatalogEntry:
    """Parse a validated dictionary into a CatalogEntry.

    Numeric fields are coerced with ``float()``; quality scores are coerced
    with ``int()``, which truncates any fractional part toward zero. Parser
    names in ``capabilities`` become ``None`` when missing or falsy.

    Args:
        data: Validated YAML data.

    Returns:
        A CatalogEntry instance.

    Raises:
        CatalogError: If type coercion fails for any nested field value,
            including values too large or non-finite to convert and keys
            that are unexpectedly missing.
    """
    # OverflowError covers int(float("inf")) and float(<huge int>), which pass
    # the isinstance-based validation; KeyError covers callers that skipped
    # validation. Both must surface as CatalogError, as documented.
    coercion_errors = (ValueError, TypeError, OverflowError, KeyError)

    model_id = data.get("id", "<unknown>")

    # Parse sources
    sources: dict[str, QuantSource] = {}
    for quant, source_data in data["sources"].items():
        try:
            sources[quant] = QuantSource(
                hf_repo=source_data["hf_repo"],
                disk_size_gb=float(source_data["disk_size_gb"]),
                convert_from=bool(source_data.get("convert_from", False)),
            )
        except coercion_errors as exc:
            msg = (
                f"Catalog entry '{model_id}': invalid value in source '{quant}': {exc}"
            )
            raise CatalogError(msg) from None

    # Parse capabilities
    caps_data = data["capabilities"]
    try:
        capabilities = Capabilities(
            tool_calling=bool(caps_data["tool_calling"]),
            # `or None` normalizes empty strings and other falsy values
            tool_call_parser=caps_data.get("tool_call_parser") or None,
            thinking=bool(caps_data["thinking"]),
            reasoning_parser=caps_data.get("reasoning_parser") or None,
            vision=bool(caps_data["vision"]),
        )
    except coercion_errors as exc:
        msg = f"Catalog entry '{model_id}': invalid value in capabilities: {exc}"
        raise CatalogError(msg) from None

    # Parse quality scores
    q_data = data["quality"]
    try:
        quality = QualityScores(
            overall=int(q_data["overall"]),
            coding=int(q_data["coding"]),
            reasoning=int(q_data["reasoning"]),
            instruction_following=int(q_data["instruction_following"]),
        )
    except coercion_errors as exc:
        msg = f"Catalog entry '{model_id}': invalid value in quality scores: {exc}"
        raise CatalogError(msg) from None

    # Parse benchmarks
    benchmarks: dict[str, BenchmarkResult] = {}
    for hw_key, bench_data in data["benchmarks"].items():
        try:
            benchmarks[hw_key] = BenchmarkResult(
                prompt_tps=float(bench_data["prompt_tps"]),
                gen_tps=float(bench_data["gen_tps"]),
                memory_gb=float(bench_data["memory_gb"]),
            )
        except coercion_errors as exc:
            msg = (
                f"Catalog entry '{model_id}': invalid value in "
                f"benchmark '{hw_key}': {exc}"
            )
            raise CatalogError(msg) from None

    try:
        return CatalogEntry(
            id=str(data["id"]),
            name=str(data["name"]),
            family=str(data["family"]),
            params_b=float(data["params_b"]),
            architecture=str(data["architecture"]),
            min_mlx_lm_version=str(data["min_mlx_lm_version"]),
            sources=sources,
            capabilities=capabilities,
            quality=quality,
            benchmarks=benchmarks,
            # Copy so the entry does not share the list owned by `data`
            tags=list(data.get("tags", [])),
        )
    except coercion_errors as exc:
        msg = f"Catalog entry '{model_id}': invalid top-level field value: {exc}"
        raise CatalogError(msg) from None
|
|
342
|
+
|
|
343
|
+
|
|
344
|
+
# --------------------------------------------------------------------------- #
|
|
345
|
+
# Loading
|
|
346
|
+
# --------------------------------------------------------------------------- #
|
|
347
|
+
|
|
348
|
+
|
|
349
|
+
def load_catalog() -> list[CatalogEntry]:
    """Load all catalog entries from the shipped YAML data files.

    Uses importlib.resources to locate the ``mlx_stack.data.catalog`` package
    within the installed distribution. Every item in it whose name ends with
    ``.yaml`` is read as UTF-8, parsed with ``yaml.safe_load``, validated, and
    converted into a CatalogEntry. Files are processed in name order, and
    loading stops at the first file that fails.

    Returns:
        A list of CatalogEntry instances, sorted by family then params_b.

    Raises:
        CatalogError: If the catalog package cannot be located or listed, if
            it contains no YAML files, or if any file cannot be read, is not
            valid UTF-8, is not valid YAML, is not a mapping, or fails
            validation.
    """
    try:
        catalog_pkg = importlib.resources.files("mlx_stack.data.catalog")
    except (ModuleNotFoundError, TypeError) as exc:
        msg = f"Could not locate catalog data directory: {exc}"
        raise CatalogError(msg) from None

    try:
        yaml_files: list[Any] = [
            item
            for item in catalog_pkg.iterdir()
            if hasattr(item, "name") and item.name.endswith(".yaml")
        ]
    except (OSError, TypeError) as exc:
        msg = f"Could not read catalog directory: {exc}"
        raise CatalogError(msg) from None

    if not yaml_files:
        msg = "No catalog YAML files found — catalog directory is empty"
        raise CatalogError(msg)

    entries: list[CatalogEntry] = []
    for yaml_file in sorted(yaml_files, key=lambda f: f.name):
        filename = yaml_file.name
        try:
            content = yaml_file.read_text(encoding="utf-8")
        except (OSError, UnicodeDecodeError) as exc:
            # UnicodeDecodeError is a ValueError, not an OSError; without it a
            # mis-encoded file would escape the CatalogError contract.
            msg = f"Could not read catalog file '{filename}': {exc}"
            raise CatalogError(msg) from None

        try:
            data = yaml.safe_load(content)
        except yaml.YAMLError as exc:
            msg = f"Catalog file '{filename}' contains invalid YAML: {exc}"
            raise CatalogError(msg) from None

        # An empty file yields None and a top-level list yields list; both
        # are rejected here before schema validation.
        if not isinstance(data, dict):
            actual_type = type(data).__name__
            msg = (
                f"Catalog file '{filename}' must contain a YAML mapping, "
                f"got {actual_type}"
            )
            raise CatalogError(msg)

        _validate_entry(data, filename)
        entries.append(_parse_entry(data))

    # Sort by family, then by params_b ascending
    entries.sort(key=lambda e: (e.family, e.params_b))

    return entries
|
|
412
|
+
|
|
413
|
+
|
|
414
|
+
def load_catalog_from_directory(directory: str) -> list[CatalogEntry]:
    """Load catalog entries from an arbitrary directory.

    This is useful for testing with custom catalog files. Every ``*.yaml``
    file directly inside the directory is read as UTF-8, parsed with
    ``yaml.safe_load``, validated, and converted into a CatalogEntry. Files
    are processed in sorted path order, and loading stops at the first file
    that fails.

    Args:
        directory: Path to a directory containing YAML catalog files.

    Returns:
        A list of CatalogEntry instances, sorted by family then params_b.

    Raises:
        CatalogError: If the directory does not exist, contains no YAML
            files, or if any file cannot be read, is not valid UTF-8, is not
            valid YAML, is not a mapping, or fails validation.
    """
    from pathlib import Path

    catalog_dir = Path(directory)

    if not catalog_dir.is_dir():
        msg = f"Catalog directory not found: {directory}"
        raise CatalogError(msg)

    yaml_files = sorted(catalog_dir.glob("*.yaml"))

    if not yaml_files:
        msg = f"No catalog YAML files found in '{directory}'"
        raise CatalogError(msg)

    entries: list[CatalogEntry] = []

    for yaml_file in yaml_files:
        filename = yaml_file.name
        try:
            content = yaml_file.read_text(encoding="utf-8")
        except (OSError, UnicodeDecodeError) as exc:
            # UnicodeDecodeError is a ValueError, not an OSError; without it a
            # mis-encoded file would escape the CatalogError contract.
            msg = f"Could not read catalog file '{filename}': {exc}"
            raise CatalogError(msg) from None

        try:
            data = yaml.safe_load(content)
        except yaml.YAMLError as exc:
            msg = f"Catalog file '{filename}' contains invalid YAML: {exc}"
            raise CatalogError(msg) from None

        # An empty file yields None and a top-level list yields list; both
        # are rejected here before schema validation.
        if not isinstance(data, dict):
            actual_type = type(data).__name__
            msg = (
                f"Catalog file '{filename}' must contain a YAML mapping, "
                f"got {actual_type}"
            )
            raise CatalogError(msg)

        _validate_entry(data, filename)
        entries.append(_parse_entry(data))

    # Sort by family, then by params_b ascending
    entries.sort(key=lambda e: (e.family, e.params_b))

    return entries
|
|
472
|
+
|
|
473
|
+
|
|
474
|
+
# --------------------------------------------------------------------------- #
|
|
475
|
+
# Querying
|
|
476
|
+
# --------------------------------------------------------------------------- #
|
|
477
|
+
|
|
478
|
+
|
|
479
|
+
def get_entry_by_id(catalog: list[CatalogEntry], model_id: str) -> CatalogEntry | None:
    """Find the catalog entry whose ``id`` equals *model_id*.

    The comparison is exact (case-sensitive). If several entries share the
    same id, the one appearing first in *catalog* is returned.

    Args:
        catalog: The loaded catalog.
        model_id: The model ID to look up.

    Returns:
        The matching CatalogEntry, or None if not found.
    """
    candidates = (candidate for candidate in catalog if candidate.id == model_id)
    return next(candidates, None)
|
|
493
|
+
|
|
494
|
+
|
|
495
|
+
def query_by_family(catalog: list[CatalogEntry], family: str) -> list[CatalogEntry]:
    """Select the entries belonging to a model family.

    Family names are compared after lower-casing both sides, so the match is
    case-insensitive but must cover the whole name (no prefix matching).
    The relative order of *catalog* is preserved.

    Args:
        catalog: The loaded catalog.
        family: The family name to filter by.

    Returns:
        A list of matching CatalogEntry instances.
    """
    wanted = family.lower()
    selected: list[CatalogEntry] = []
    for candidate in catalog:
        if candidate.family.lower() == wanted:
            selected.append(candidate)
    return selected
|
|
507
|
+
|
|
508
|
+
|
|
509
|
+
def query_by_tag(catalog: list[CatalogEntry], tag: str) -> list[CatalogEntry]:
    """Select the entries carrying a given tag.

    Both the requested tag and each entry's tags are lower-cased before
    comparison, so the match is case-insensitive but must cover the whole
    tag. The relative order of *catalog* is preserved.

    Args:
        catalog: The loaded catalog.
        tag: The tag to filter by.

    Returns:
        A list of matching CatalogEntry instances.
    """
    needle = tag.lower()
    tagged: list[CatalogEntry] = []
    for candidate in catalog:
        normalized = {label.lower() for label in candidate.tags}
        if needle in normalized:
            tagged.append(candidate)
    return tagged
|
|
521
|
+
|
|
522
|
+
|
|
523
|
+
def query_by_capability(
    catalog: list[CatalogEntry],
    **capabilities: bool,
) -> list[CatalogEntry]:
    """Select the entries whose capability flags match every given filter.

    Accepted filter names are ``tool_calling``, ``thinking`` and ``vision``.
    An entry is kept only when each named flag on ``entry.capabilities``
    equals the requested value; with no filters, every entry is kept. All
    filter names are checked before any entry is examined.

    Args:
        catalog: The loaded catalog.
        **capabilities: Capability flags to filter by (e.g., tool_calling=True).

    Returns:
        A list of matching CatalogEntry instances.

    Raises:
        ValueError: If an invalid capability name is given.
    """
    allowed = {"tool_calling", "thinking", "vision"}

    # Report the first unsupported filter name, in the order given
    unknown = next((name for name in capabilities if name not in allowed), None)
    if unknown is not None:
        raise ValueError(
            f"Invalid capability filter '{unknown}' "
            f"(valid: {', '.join(sorted(allowed))})"
        )

    return [
        entry
        for entry in catalog
        if all(
            getattr(entry.capabilities, name) == wanted
            for name, wanted in capabilities.items()
        )
    ]
|