cortex_llm-1.0.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cortex/__init__.py +73 -0
- cortex/__main__.py +83 -0
- cortex/config.py +329 -0
- cortex/conversation_manager.py +468 -0
- cortex/fine_tuning/__init__.py +8 -0
- cortex/fine_tuning/dataset.py +332 -0
- cortex/fine_tuning/mlx_lora_trainer.py +502 -0
- cortex/fine_tuning/trainer.py +957 -0
- cortex/fine_tuning/wizard.py +707 -0
- cortex/gpu_validator.py +467 -0
- cortex/inference_engine.py +727 -0
- cortex/metal/__init__.py +275 -0
- cortex/metal/gpu_validator.py +177 -0
- cortex/metal/memory_pool.py +886 -0
- cortex/metal/mlx_accelerator.py +678 -0
- cortex/metal/mlx_converter.py +638 -0
- cortex/metal/mps_optimizer.py +417 -0
- cortex/metal/optimizer.py +665 -0
- cortex/metal/performance_profiler.py +364 -0
- cortex/model_downloader.py +130 -0
- cortex/model_manager.py +2187 -0
- cortex/quantization/__init__.py +5 -0
- cortex/quantization/dynamic_quantizer.py +736 -0
- cortex/template_registry/__init__.py +15 -0
- cortex/template_registry/auto_detector.py +144 -0
- cortex/template_registry/config_manager.py +234 -0
- cortex/template_registry/interactive.py +260 -0
- cortex/template_registry/registry.py +347 -0
- cortex/template_registry/template_profiles/__init__.py +5 -0
- cortex/template_registry/template_profiles/base.py +142 -0
- cortex/template_registry/template_profiles/complex/__init__.py +5 -0
- cortex/template_registry/template_profiles/complex/reasoning.py +263 -0
- cortex/template_registry/template_profiles/standard/__init__.py +9 -0
- cortex/template_registry/template_profiles/standard/alpaca.py +73 -0
- cortex/template_registry/template_profiles/standard/chatml.py +82 -0
- cortex/template_registry/template_profiles/standard/gemma.py +103 -0
- cortex/template_registry/template_profiles/standard/llama.py +87 -0
- cortex/template_registry/template_profiles/standard/simple.py +65 -0
- cortex/ui/__init__.py +120 -0
- cortex/ui/cli.py +1685 -0
- cortex/ui/markdown_render.py +185 -0
- cortex/ui/terminal_app.py +534 -0
- cortex_llm-1.0.0.dist-info/METADATA +275 -0
- cortex_llm-1.0.0.dist-info/RECORD +48 -0
- cortex_llm-1.0.0.dist-info/WHEEL +5 -0
- cortex_llm-1.0.0.dist-info/entry_points.txt +2 -0
- cortex_llm-1.0.0.dist-info/licenses/LICENSE +21 -0
- cortex_llm-1.0.0.dist-info/top_level.txt +1 -0
cortex/metal/mlx_converter.py
@@ -0,0 +1,638 @@
"""MLX model converter for optimal Apple Silicon performance."""

import json
import logging
from pathlib import Path
from typing import Dict, Any, Optional, Tuple, Callable, Union
from dataclasses import dataclass
from enum import Enum
import hashlib
import time

import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_map_with_path
from huggingface_hub import snapshot_download

# Configure logging
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Import MLX LM functions safely
try:
    from mlx_lm import load
    from mlx_lm import utils as mlx_utils
except ImportError:
    load = None
    mlx_utils = None


class ConversionFormat(Enum):
    """Supported conversion formats."""
    HUGGINGFACE = "huggingface"
    SAFETENSORS = "safetensors"
    PYTORCH = "pytorch"
    GGUF = "gguf"


class QuantizationRecipe(Enum):
    """Predefined quantization recipes for different use cases."""
    SPEED_4BIT = "4bit"  # Maximum speed, 75% size reduction
    BALANCED_5BIT = "5bit"  # Balance between speed and quality
    QUALITY_8BIT = "8bit"  # Higher quality, 50% size reduction
    MIXED_PRECISION = "mixed"  # Custom per-layer quantization
    NONE = "none"  # No quantization


@dataclass
class ConversionConfig:
    """Configuration for model conversion."""
    source_format: ConversionFormat = ConversionFormat.HUGGINGFACE
    quantization: QuantizationRecipe = QuantizationRecipe.SPEED_4BIT
    group_size: int = 64  # Quantization group size
    mixed_precision_config: Optional[Dict[str, Any]] = None
    cache_converted: bool = True
    validate_conversion: bool = True
    use_amx: bool = True  # Enable AMX optimizations
    compile_model: bool = True  # JIT compile for performance


class MLXConverter:
    """Convert models to MLX format with optimal quantization."""

    def __init__(self, cache_dir: Optional[Path] = None):
        """Initialize MLX converter."""
        self.cache_dir = cache_dir or Path.home() / ".cortex" / "mlx_models"
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self.conversion_cache = self.cache_dir / "conversion_cache.json"
        self._load_conversion_cache()

        logger.info(f"MLX Converter initialized with cache dir: {self.cache_dir}")
        logger.info(f"MLX LM available: {mlx_utils is not None and load is not None}")

    def _load_conversion_cache(self) -> None:
        """Load conversion cache metadata."""
        if self.conversion_cache.exists():
            with open(self.conversion_cache) as f:
                self.cache_metadata = json.load(f)
        else:
            self.cache_metadata = {}

    def _save_conversion_cache(self) -> None:
        """Save conversion cache metadata."""
        with open(self.conversion_cache, 'w') as f:
            json.dump(self.cache_metadata, f, indent=2)

    def convert_model(
        self,
        source_path: str,
        output_name: Optional[str] = None,
        config: Optional[ConversionConfig] = None
    ) -> Tuple[bool, str, Optional[Path]]:
        """
        Convert a model to MLX format with optimal settings.

        Args:
            source_path: Path to source model (HF repo ID or local path)
            output_name: Name for converted model
            config: Conversion configuration

        Returns:
            Tuple of (success, message, output_path)
        """
        config = config or ConversionConfig()

        # Generate output name if not provided
        if not output_name:
            if "/" in source_path and not source_path.startswith("/"):
                # HuggingFace repo ID (e.g., "meta-llama/Llama-2-7b")
                output_name = source_path.replace("/", "_")
            else:
                # Local path - use just the model directory name
                output_name = Path(source_path).name

        # Add quantization suffix
        if config.quantization != QuantizationRecipe.NONE:
            output_name = f"{output_name}_{config.quantization.value}"

        output_path = self.cache_dir / output_name
        source_ref = self._get_source_ref(source_path)

        # Check if already converted
        cache_key = self._get_cache_key(source_path, config)
        if cache_key in self.cache_metadata and output_path.exists():
            valid, reason = self._validate_existing_output(output_path, config, source_ref)
            if valid:
                logger.info(f"Model already converted, using cached version at {output_path}")
                return True, f"Model already converted at {output_path}", output_path
            return False, f"Cached MLX output is invalid: {reason}. Please delete {output_path} and retry.", None

        if output_path.exists():
            valid, reason = self._validate_existing_output(output_path, config, source_ref)
            if valid:
                logger.info(f"Found existing MLX model at {output_path}, using it")
                self.cache_metadata[cache_key] = {
                    "output_path": str(output_path),
                    "timestamp": time.time(),
                    "config": {
                        "quantization": config.quantization.value,
                        "group_size": config.group_size
                    }
                }
                self._save_conversion_cache()
                self._write_conversion_metadata(output_path, source_ref, config)
                return True, f"Model already converted at {output_path}", output_path
            return False, (
                f"Output path already exists but does not match requested conversion: {reason}. "
                f"Please delete {output_path} or choose a different output name."
            ), None

        logger.info(f"Starting MLX conversion for {source_path}")
        logger.info(f"Config: quantization={config.quantization.value}, AMX={config.use_amx}, compile={config.compile_model}")

        try:
            # Download if HuggingFace repo
            if "/" in source_path and not Path(source_path).exists():
                logger.info(f"Downloading model from HuggingFace: {source_path}")
                print(f"Downloading model from HuggingFace: {source_path}")
                local_path = self._download_from_hub(source_path)
                logger.info(f"Downloaded to: {local_path}")
            else:
                local_path = Path(source_path)
                logger.info(f"Using local model at: {local_path}")

            # Detect format and convert
            if config.source_format == ConversionFormat.GGUF:
                success, msg, converted_path = self._convert_gguf(
                    local_path, output_path, config
                )
            else:
                success, msg, converted_path = self._convert_transformers(
                    local_path, output_path, config
                )

            if success:
                # Update cache
                self.cache_metadata[cache_key] = {
                    "output_path": str(converted_path),
                    "timestamp": time.time(),
                    "config": {
                        "quantization": config.quantization.value,
                        "group_size": config.group_size
                    }
                }
                self._save_conversion_cache()
                self._write_conversion_metadata(converted_path, source_ref, config)
                logger.info(f"Conversion successful, cached at: {converted_path}")
            else:
                logger.error(f"Conversion failed: {msg}")

            return success, msg, converted_path

        except Exception as e:
            logger.error(f"Conversion failed with exception: {str(e)}")
            return False, f"Conversion failed: {str(e)}", None

    def _download_from_hub(self, repo_id: str) -> Path:
        """Download model from HuggingFace Hub."""
        download_dir = self.cache_dir / "downloads" / repo_id.replace("/", "_")

        if not download_dir.exists():
            snapshot_download(
                repo_id=repo_id,
                local_dir=download_dir,
                local_dir_use_symlinks=False
            )

        return download_dir

    def _requires_sentencepiece(self, model_path: Path) -> bool:
        """Return True if the model likely needs SentencePiece."""
        # If a fast tokenizer is present, SentencePiece should not be required.
        if (model_path / "tokenizer.json").exists():
            return False

        sp_files = [
            "tokenizer.model",
            "sentencepiece.model",
            "sentencepiece.bpe.model",
            "spiece.model",
        ]
        if any((model_path / name).exists() for name in sp_files):
            return True

        config_path = model_path / "tokenizer_config.json"
        if config_path.exists():
            try:
                with open(config_path) as f:
                    cfg = json.load(f)
                tokenizer_class = str(cfg.get("tokenizer_class", "")).lower()
                if any(key in tokenizer_class for key in ["sentencepiece", "llama", "t5", "gemma", "mistral"]):
                    return True
                if any(key in cfg for key in ["sp_model", "spiece_model_file", "sentencepiece_model"]):
                    return True
            except Exception:
                pass

        return False

    def _ensure_sentencepiece(self, model_path: Path) -> Optional[str]:
        """Return an error message if SentencePiece is required but missing."""
        if not self._requires_sentencepiece(model_path):
            return None
        try:
            import sentencepiece  # noqa: F401
        except Exception:
            return (
                "SentencePiece tokenizer detected but the 'sentencepiece' package is not installed. "
                "Install it with: pip install sentencepiece (if build fails, ensure cmake is installed)."
            )
        return None

    def _normalize_hf_repo(self, hf_repo: Any) -> Optional[str]:
        """Normalize HF repo metadata for model card creation."""
        if hf_repo is None:
            return None
        if isinstance(hf_repo, (str, Path)):
            return str(hf_repo)
        if isinstance(hf_repo, (list, tuple)):
            cleaned = [str(x) for x in hf_repo if isinstance(x, (str, Path)) and str(x).strip()]
            if len(cleaned) == 1:
                logger.warning("base_model is a list; using the single entry for model card creation")
                return cleaned[0]
            logger.warning("base_model is a list with multiple entries; skipping model card creation")
            return None
        logger.warning(f"Unexpected base_model type {type(hf_repo)}, skipping model card creation")
        return None

    def _get_source_ref(self, source_path: str) -> str:
        """Normalize source reference for cache validation."""
        if "/" in source_path and not Path(source_path).exists():
            return source_path
        return str(Path(source_path).expanduser().resolve())

    def _write_conversion_metadata(
        self,
        output_path: Path,
        source_ref: str,
        config: ConversionConfig
    ) -> None:
        """Write conversion metadata for traceability."""
        metadata_path = output_path / "conversion.json"
        metadata = {
            "source_ref": source_ref,
            "source_format": config.source_format.value,
            "quantization": config.quantization.value,
            "group_size": config.group_size,
            "timestamp": time.time(),
        }
        try:
            with open(metadata_path, "w") as f:
                json.dump(metadata, f, indent=2)
        except Exception as e:
            logger.warning("Failed to write conversion metadata: %s", e)

    def _validate_existing_output(
        self,
        model_path: Path,
        config: ConversionConfig,
        source_ref: str
    ) -> Tuple[bool, str]:
        """Validate an existing MLX output for completeness and config match."""
        if not model_path.exists():
            return False, "output path does not exist"
        if not model_path.is_dir():
            return False, "output path is not a directory"

        config_path = model_path / "config.json"
        if not config_path.exists():
            return False, "missing config.json"
        if not any(model_path.glob("model*.safetensors")):
            return False, "missing model*.safetensors"

        try:
            with open(config_path) as f:
                model_config = json.load(f)
        except Exception as e:
            return False, f"invalid config.json: {e}"

        metadata_path = model_path / "conversion.json"
        if metadata_path.exists():
            try:
                with open(metadata_path) as f:
                    metadata = json.load(f)
                if metadata.get("source_ref") != source_ref:
                    return False, "conversion source mismatch"
            except Exception as e:
                return False, f"invalid conversion metadata: {e}"
        else:
            logger.warning("Conversion metadata missing for %s; proceeding with structural validation only", model_path)

        if config.quantization == QuantizationRecipe.NONE:
            if "quantization_config" in model_config or "quantization" in model_config:
                return False, "expected unquantized model but output is quantized"
            return True, "valid unquantized model"

        quant_cfg = model_config.get("quantization_config") or model_config.get("quantization")
        if quant_cfg is None:
            return False, "missing quantization config"

        if config.quantization == QuantizationRecipe.MIXED_PRECISION:
            if not isinstance(quant_cfg, dict):
                return False, "invalid mixed-precision config"
            return True, "valid mixed-precision model"

        expected_bits = self._get_quantization_bits(config.quantization)
        if isinstance(quant_cfg, dict):
            bits = quant_cfg.get("bits")
            group_size = quant_cfg.get("group_size")
            if bits != expected_bits:
                return False, f"quantization bits mismatch (expected {expected_bits}, got {bits})"
            if group_size != config.group_size:
                return False, f"quantization group size mismatch (expected {config.group_size}, got {group_size})"
        else:
            return False, "invalid quantization config format"

        return True, "valid quantized model"

    def _convert_transformers(
        self,
        source_path: Path,
        output_path: Path,
        config: ConversionConfig
    ) -> Tuple[bool, str, Path]:
        """Convert Transformers/SafeTensors model to MLX."""
        try:
            if mlx_utils is None:
                logger.warning("MLX LM library not available for conversion")
                return False, "MLX LM library not available for conversion", None

            logger.info(f"Converting {source_path} to MLX format")
            logger.info(f"Quantization: {config.quantization.value}, bits: {self._get_quantization_bits(config.quantization)}")
            print(f"Converting {source_path} to MLX format...")

            missing_dep = self._ensure_sentencepiece(source_path)
            if missing_dep:
                logger.error(missing_dep)
                return False, missing_dep, None

            # Build quantization configuration
            quantize_config = self._build_quantization_config(config)

            model_path, hf_repo = mlx_utils.get_model_path(str(source_path))
            model, model_config, tokenizer = mlx_utils.fetch_from_hub(
                model_path, lazy=True, trust_remote_code=False
            )

            dtype = model_config.get("torch_dtype", None)
            if dtype in ["float16", "bfloat16", "float32"]:
                print("[INFO] Using dtype:", dtype)
                dtype = getattr(mx, dtype)
                cast_predicate = getattr(model, "cast_predicate", lambda _: True)

                def set_dtype(k, v):
                    if cast_predicate(k) and mx.issubdtype(v.dtype, mx.floating):
                        return v.astype(dtype)
                    return v

                model.update(tree_map_with_path(set_dtype, model.parameters()))

            if config.quantization != QuantizationRecipe.NONE:
                quant_predicate = None
                if quantize_config and "quant_predicate" in quantize_config:
                    quant_predicate = quantize_config["quant_predicate"]
                model, model_config = mlx_utils.quantize_model(
                    model,
                    model_config,
                    config.group_size,
                    self._get_quantization_bits(config.quantization),
                    mode="affine",
                    quant_predicate=quant_predicate,
                )

            normalized_hf_repo = self._normalize_hf_repo(hf_repo)
            mlx_utils.save(output_path, model_path, model, tokenizer, model_config, hf_repo=normalized_hf_repo)
            logger.info("MLX conversion completed")

            # Apply AMX optimizations if enabled
            if config.use_amx:
                logger.info("Applying AMX optimizations")
                self._apply_amx_optimizations(output_path)

            # Validate conversion if requested
            if config.validate_conversion:
                logger.info("Validating converted model")
                if not self._validate_model(output_path):
                    logger.error("Model validation failed")
                    return False, "Validation failed", None
                logger.info("Model validation successful")

            logger.info(f"Successfully converted model to {output_path}")
            return True, f"Successfully converted to {output_path}", output_path

        except Exception as e:
            logger.error(f"Transformers conversion failed: {str(e)}")
            return False, f"Transformers conversion failed: {str(e)}", None

    def _convert_gguf(
        self,
        source_path: Path,
        output_path: Path,
        config: ConversionConfig
    ) -> Tuple[bool, str, Path]:
        """Convert GGUF model to MLX (via HuggingFace intermediate)."""
        try:
            # GGUF -> HF conversion requires llama.cpp tools
            # For now, we'll return an informative message
            return False, (
                "GGUF to MLX conversion requires intermediate HuggingFace format. "
                "Please use 'convert_hf_to_gguf.py' in reverse or download "
                "the HuggingFace version of this model."
            ), None

        except Exception as e:
            return False, f"GGUF conversion failed: {str(e)}", None

    def _build_quantization_config(
        self,
        config: ConversionConfig
    ) -> Dict[str, Any]:
        """Build quantization configuration for MLX quantization."""
        quant_config = {}

        if config.quantization == QuantizationRecipe.MIXED_PRECISION:
            # Build mixed precision predicate
            quant_config["quant_predicate"] = self._build_mixed_precision_predicate(
                config.mixed_precision_config
            )

        return quant_config

    def _build_mixed_precision_predicate(
        self,
        mixed_config: Optional[Dict[str, Any]]
    ) -> Callable:
        """Build mixed precision quantization predicate."""
        mixed_config = mixed_config or {}

        # Default: higher precision for critical layers
        critical_layers = mixed_config.get("critical_layers", [
            "lm_head", "embed_tokens", "wte", "wpe"
        ])
        critical_bits = mixed_config.get("critical_bits", 6)
        standard_bits = mixed_config.get("standard_bits", 4)

        logger.info(f"Mixed precision config: critical={critical_bits}bit, standard={standard_bits}bit")
        logger.info(f"Critical layers: {critical_layers}")

        def predicate(layer_path: str, layer: nn.Module, model_config: Dict) -> Union[bool, Dict]:
            """Determine quantization for each layer."""
            # Critical layers get higher precision
            for critical in critical_layers:
                if critical in layer_path:
                    return {"bits": critical_bits, "group_size": 64}

            # Attention layers can use standard quantization
            if any(x in layer_path for x in ["q_proj", "k_proj", "v_proj", "o_proj"]):
                return {"bits": standard_bits, "group_size": 64}

            # FFN layers
            if any(x in layer_path for x in ["gate_proj", "up_proj", "down_proj"]):
                return {"bits": standard_bits, "group_size": 64}

            # Skip quantization for other layers
            return False

        return predicate

    def _get_quantization_bits(self, recipe: QuantizationRecipe) -> int:
        """Get quantization bits for recipe."""
        mapping = {
            QuantizationRecipe.SPEED_4BIT: 4,
            QuantizationRecipe.BALANCED_5BIT: 5,
            QuantizationRecipe.QUALITY_8BIT: 8,
            QuantizationRecipe.MIXED_PRECISION: 4,  # Default for mixed
            QuantizationRecipe.NONE: 16
        }
        return mapping.get(recipe, 16)

    def _apply_amx_optimizations(self, model_path: Path) -> None:
        """Apply AMX-specific optimizations to converted model."""
        try:
            # Load model config
            config_path = model_path / "config.json"
            if config_path.exists():
                with open(config_path) as f:
                    config = json.load(f)

                # Add AMX optimization flags
                config["amx_optimized"] = True
                config["use_fused_attention"] = True
                config["operation_fusion"] = True

                logger.info("AMX optimization flags added to model config")

                # Save updated config
                with open(config_path, 'w') as f:
                    json.dump(config, f, indent=2)
        except Exception as e:
            logger.warning(f"Could not apply AMX optimizations: {e}")
            print(f"Warning: Could not apply AMX optimizations: {e}")

    def _validate_model(self, model_path: Path) -> bool:
        """Validate converted model loads correctly."""
        try:
            if load is None:
                logger.warning("Can't validate model without mlx_lm, assuming success")
                return True

            logger.debug(f"Loading model for validation: {model_path}")
            # Try loading the model
            model, tokenizer = load(str(model_path))

            # Test a simple forward pass
            test_input = "Hello, world!"
            tokens = tokenizer.encode(test_input)

            # Just verify model can process tokens
            mx.eval(model.parameters())

            logger.debug("Model validation passed")
            return True
        except Exception as e:
            logger.error(f"Model validation failed: {e}")
            print(f"Validation failed: {e}")
            return False

    def _get_cache_key(self, source_path: str, config: ConversionConfig) -> str:
        """Generate cache key for conversion."""
        key_parts = [
            source_path,
            config.quantization.value,
            str(config.group_size)
        ]
        key_string = "_".join(key_parts)
        return hashlib.md5(key_string.encode()).hexdigest()

    def list_converted_models(self) -> Dict[str, Any]:
        """List all converted models in cache."""
        models = {}

        for model_dir in self.cache_dir.iterdir():
            if model_dir.is_dir() and (model_dir / "config.json").exists():
                config_path = model_dir / "config.json"
                with open(config_path) as f:
                    config = json.load(f)

                # Calculate model size
                total_size = sum(
                    f.stat().st_size for f in model_dir.rglob("*") if f.is_file()
                )

                models[model_dir.name] = {
                    "path": str(model_dir),
                    "size_gb": total_size / (1024**3),
                    "quantization": config.get("quantization_config", {}).get("bits", "none"),
                    "amx_optimized": config.get("amx_optimized", False)
                }

        return models

    def optimize_for_chip(self, model_path: Path, chip: str) -> None:
        """Optimize model for specific Apple Silicon chip."""
        chip_configs = {
            "m1": {"batch_size": 4, "prefetch": 2},
            "m1_pro": {"batch_size": 6, "prefetch": 3},
            "m1_max": {"batch_size": 8, "prefetch": 4},
            "m1_ultra": {"batch_size": 16, "prefetch": 8},
            "m2": {"batch_size": 6, "prefetch": 3},
            "m2_pro": {"batch_size": 8, "prefetch": 4},
            "m2_max": {"batch_size": 12, "prefetch": 6},
            "m2_ultra": {"batch_size": 24, "prefetch": 12},
            "m3": {"batch_size": 8, "prefetch": 4},
            "m3_pro": {"batch_size": 12, "prefetch": 6},
            "m3_max": {"batch_size": 16, "prefetch": 8},
            "m4": {"batch_size": 12, "prefetch": 6},
            "m4_pro": {"batch_size": 16, "prefetch": 8},
            "m4_max": {"batch_size": 24, "prefetch": 12}
        }

        if chip.lower() in chip_configs:
            config = chip_configs[chip.lower()]

            # Update model config with chip-specific settings
            config_path = model_path / "config.json"
            if config_path.exists():
                with open(config_path) as f:
                    model_config = json.load(f)

                model_config["chip_optimization"] = {
                    "chip": chip,
                    "batch_size": config["batch_size"],
                    "prefetch_size": config["prefetch"]
                }

                logger.info(f"Optimized model for {chip.upper()} chip: batch_size={config['batch_size']}, prefetch={config['prefetch']}")

                with open(config_path, 'w') as f:
                    json.dump(model_config, f, indent=2)