cortex-llm 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cortex/__init__.py +73 -0
- cortex/__main__.py +83 -0
- cortex/config.py +329 -0
- cortex/conversation_manager.py +468 -0
- cortex/fine_tuning/__init__.py +8 -0
- cortex/fine_tuning/dataset.py +332 -0
- cortex/fine_tuning/mlx_lora_trainer.py +502 -0
- cortex/fine_tuning/trainer.py +957 -0
- cortex/fine_tuning/wizard.py +707 -0
- cortex/gpu_validator.py +467 -0
- cortex/inference_engine.py +727 -0
- cortex/metal/__init__.py +275 -0
- cortex/metal/gpu_validator.py +177 -0
- cortex/metal/memory_pool.py +886 -0
- cortex/metal/mlx_accelerator.py +678 -0
- cortex/metal/mlx_converter.py +638 -0
- cortex/metal/mps_optimizer.py +417 -0
- cortex/metal/optimizer.py +665 -0
- cortex/metal/performance_profiler.py +364 -0
- cortex/model_downloader.py +130 -0
- cortex/model_manager.py +2187 -0
- cortex/quantization/__init__.py +5 -0
- cortex/quantization/dynamic_quantizer.py +736 -0
- cortex/template_registry/__init__.py +15 -0
- cortex/template_registry/auto_detector.py +144 -0
- cortex/template_registry/config_manager.py +234 -0
- cortex/template_registry/interactive.py +260 -0
- cortex/template_registry/registry.py +347 -0
- cortex/template_registry/template_profiles/__init__.py +5 -0
- cortex/template_registry/template_profiles/base.py +142 -0
- cortex/template_registry/template_profiles/complex/__init__.py +5 -0
- cortex/template_registry/template_profiles/complex/reasoning.py +263 -0
- cortex/template_registry/template_profiles/standard/__init__.py +9 -0
- cortex/template_registry/template_profiles/standard/alpaca.py +73 -0
- cortex/template_registry/template_profiles/standard/chatml.py +82 -0
- cortex/template_registry/template_profiles/standard/gemma.py +103 -0
- cortex/template_registry/template_profiles/standard/llama.py +87 -0
- cortex/template_registry/template_profiles/standard/simple.py +65 -0
- cortex/ui/__init__.py +120 -0
- cortex/ui/cli.py +1685 -0
- cortex/ui/markdown_render.py +185 -0
- cortex/ui/terminal_app.py +534 -0
- cortex_llm-1.0.0.dist-info/METADATA +275 -0
- cortex_llm-1.0.0.dist-info/RECORD +48 -0
- cortex_llm-1.0.0.dist-info/WHEEL +5 -0
- cortex_llm-1.0.0.dist-info/entry_points.txt +2 -0
- cortex_llm-1.0.0.dist-info/licenses/LICENSE +21 -0
- cortex_llm-1.0.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,678 @@
|
|
|
1
|
+
"""MLX framework GPU acceleration for Apple Silicon with AMX and advanced quantization."""
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
import mlx.core as mx
|
|
5
|
+
import mlx.nn as nn
|
|
6
|
+
from mlx.utils import tree_map, tree_flatten
|
|
7
|
+
from typing import Dict, Any, Optional, List, Tuple, Callable, Generator
|
|
8
|
+
from dataclasses import dataclass
|
|
9
|
+
import functools
|
|
10
|
+
import time
|
|
11
|
+
import numpy as np
|
|
12
|
+
|
|
13
|
+
# Configure logging
|
|
14
|
+
logger = logging.getLogger(__name__)
|
|
15
|
+
logger.setLevel(logging.INFO)
|
|
16
|
+
|
|
17
|
+
# Import MLX LM functions safely
|
|
18
|
+
try:
|
|
19
|
+
from mlx_lm import generate, stream_generate
|
|
20
|
+
except ImportError:
|
|
21
|
+
# Fallback if mlx_lm is not available
|
|
22
|
+
generate = None
|
|
23
|
+
stream_generate = None
|
|
24
|
+
|
|
25
|
+
@dataclass
|
|
26
|
+
class MLXConfig:
|
|
27
|
+
"""Configuration for MLX acceleration with AMX support."""
|
|
28
|
+
compile_model: bool = True
|
|
29
|
+
use_graph: bool = True
|
|
30
|
+
batch_size: int = 8
|
|
31
|
+
prefetch_size: int = 2
|
|
32
|
+
stream_parallel: bool = True
|
|
33
|
+
fusion_threshold: int = 1024
|
|
34
|
+
memory_fraction: float = 0.85
|
|
35
|
+
dtype: mx.Dtype = mx.bfloat16 # Better for modern Apple Silicon
|
|
36
|
+
use_amx: bool = True # Enable AMX coprocessor
|
|
37
|
+
fuse_operations: bool = True # Operation fusion for efficiency
|
|
38
|
+
lazy_evaluation: bool = True # Lazy eval for optimization
|
|
39
|
+
rotating_kv_cache: bool = True # For long contexts
|
|
40
|
+
kv_cache_size: int = 4096 # Max KV cache size
|
|
41
|
+
quantization_bits: int = 4 # Default quantization
|
|
42
|
+
mixed_precision: bool = False # Mixed precision quantization
|
|
43
|
+
|
|
44
|
+
class MLXAccelerator:
|
|
45
|
+
"""Accelerate models using MLX framework with Metal optimization."""
|
|
46
|
+
|
|
47
|
+
OPTIMIZATION_PRESETS = {
|
|
48
|
+
"speed": {
|
|
49
|
+
"compile_model": True,
|
|
50
|
+
"use_graph": True,
|
|
51
|
+
"stream_parallel": True,
|
|
52
|
+
"dtype": mx.bfloat16
|
|
53
|
+
},
|
|
54
|
+
"memory": {
|
|
55
|
+
"compile_model": True,
|
|
56
|
+
"use_graph": False,
|
|
57
|
+
"stream_parallel": False,
|
|
58
|
+
"dtype": mx.bfloat16
|
|
59
|
+
},
|
|
60
|
+
"balanced": {
|
|
61
|
+
"compile_model": True,
|
|
62
|
+
"use_graph": True,
|
|
63
|
+
"stream_parallel": True,
|
|
64
|
+
"dtype": mx.float32
|
|
65
|
+
}
|
|
66
|
+
}
|
|
67
|
+
|
|
68
|
+
def __init__(self, config: Optional[MLXConfig] = None):
|
|
69
|
+
"""Initialize MLX accelerator."""
|
|
70
|
+
self.config = config or MLXConfig()
|
|
71
|
+
self.device = mx.default_device()
|
|
72
|
+
|
|
73
|
+
logger.info(f"Initializing MLX Accelerator with device: {self.device}")
|
|
74
|
+
logger.info(f"Config: AMX={self.config.use_amx}, fuse_ops={self.config.fuse_operations}, ")
|
|
75
|
+
logger.info(f" lazy_eval={self.config.lazy_evaluation}, kv_cache={self.config.rotating_kv_cache}")
|
|
76
|
+
logger.info(f" quantization={self.config.quantization_bits}bit, dtype={self.config.dtype}")
|
|
77
|
+
|
|
78
|
+
# Check if device is GPU - MLX returns Device(gpu, 0) format
|
|
79
|
+
device_str = str(self.device).lower()
|
|
80
|
+
if "gpu" not in device_str:
|
|
81
|
+
logger.error(f"MLX not using GPU: {self.device}")
|
|
82
|
+
raise RuntimeError(f"MLX not using GPU: {self.device}")
|
|
83
|
+
|
|
84
|
+
mx.set_default_device(mx.gpu)
|
|
85
|
+
logger.info("MLX device set to GPU")
|
|
86
|
+
|
|
87
|
+
def optimize_model(
|
|
88
|
+
self,
|
|
89
|
+
model: nn.Module,
|
|
90
|
+
example_input: Optional[mx.array] = None
|
|
91
|
+
) -> nn.Module:
|
|
92
|
+
"""
|
|
93
|
+
Optimize an MLX model for GPU execution with AMX support.
|
|
94
|
+
|
|
95
|
+
Args:
|
|
96
|
+
model: MLX model to optimize
|
|
97
|
+
example_input: Example input for shape inference
|
|
98
|
+
|
|
99
|
+
Returns:
|
|
100
|
+
Optimized model
|
|
101
|
+
"""
|
|
102
|
+
logger.info("Starting model optimization")
|
|
103
|
+
|
|
104
|
+
# Check if this is an mlx_lm model (already optimized)
|
|
105
|
+
is_mlx_lm_model = not hasattr(model, 'apply_to_parameters')
|
|
106
|
+
|
|
107
|
+
if is_mlx_lm_model:
|
|
108
|
+
logger.info("Detected mlx_lm model - applying compatible optimizations")
|
|
109
|
+
|
|
110
|
+
# MLX LM models are already quantized and optimized
|
|
111
|
+
# We can still enable some runtime optimizations
|
|
112
|
+
|
|
113
|
+
if self.config.use_amx:
|
|
114
|
+
logger.info("AMX acceleration will be used automatically")
|
|
115
|
+
mx.set_default_device(mx.gpu)
|
|
116
|
+
|
|
117
|
+
if self.config.compile_model:
|
|
118
|
+
logger.info("Enabling JIT compilation")
|
|
119
|
+
model = self._compile_model(model)
|
|
120
|
+
|
|
121
|
+
if self.config.rotating_kv_cache:
|
|
122
|
+
logger.info(f"Setting up rotating KV cache (size: {self.config.kv_cache_size})")
|
|
123
|
+
model = self._setup_rotating_kv_cache(model)
|
|
124
|
+
|
|
125
|
+
else:
|
|
126
|
+
# Standard MLX nn.Module optimization path
|
|
127
|
+
model = self._optimize_dtype(model)
|
|
128
|
+
|
|
129
|
+
if self.config.use_amx:
|
|
130
|
+
logger.info("Enabling AMX acceleration")
|
|
131
|
+
model = self._enable_amx_acceleration(model)
|
|
132
|
+
|
|
133
|
+
if self.config.fuse_operations:
|
|
134
|
+
logger.info("Enabling operation fusion")
|
|
135
|
+
model = self._fuse_operations(model)
|
|
136
|
+
|
|
137
|
+
if self.config.compile_model:
|
|
138
|
+
logger.info("Compiling model with JIT")
|
|
139
|
+
model = self._compile_model(model)
|
|
140
|
+
|
|
141
|
+
if self.config.use_graph and example_input is not None:
|
|
142
|
+
logger.info("Enabling graph optimization")
|
|
143
|
+
model = self._enable_graph_optimization(model, example_input)
|
|
144
|
+
|
|
145
|
+
if self.config.stream_parallel:
|
|
146
|
+
logger.info("Enabling stream parallelism")
|
|
147
|
+
model = self._enable_stream_parallelism(model)
|
|
148
|
+
|
|
149
|
+
if self.config.rotating_kv_cache:
|
|
150
|
+
logger.info(f"Setting up rotating KV cache (size: {self.config.kv_cache_size})")
|
|
151
|
+
model = self._setup_rotating_kv_cache(model)
|
|
152
|
+
|
|
153
|
+
# Evaluate parameters if they exist
|
|
154
|
+
if hasattr(model, 'parameters'):
|
|
155
|
+
mx.eval(model.parameters())
|
|
156
|
+
|
|
157
|
+
logger.info("Model optimization completed")
|
|
158
|
+
|
|
159
|
+
return model
|
|
160
|
+
|
|
161
|
+
def _optimize_dtype(self, model: nn.Module) -> nn.Module:
|
|
162
|
+
"""Optimize model data types for performance."""
|
|
163
|
+
target_dtype = self.config.dtype
|
|
164
|
+
|
|
165
|
+
# Try bfloat16 first, fall back to float16 if not supported
|
|
166
|
+
if target_dtype == mx.bfloat16:
|
|
167
|
+
try:
|
|
168
|
+
test = mx.array([1.0], dtype=mx.bfloat16)
|
|
169
|
+
mx.eval(test)
|
|
170
|
+
logger.info("Using bfloat16 precision")
|
|
171
|
+
except:
|
|
172
|
+
target_dtype = mx.float16
|
|
173
|
+
logger.info("bfloat16 not supported, falling back to float16")
|
|
174
|
+
|
|
175
|
+
# Check if model has apply_to_parameters method
|
|
176
|
+
if hasattr(model, 'apply_to_parameters'):
|
|
177
|
+
def convert_param(x):
|
|
178
|
+
if x.dtype == mx.float32:
|
|
179
|
+
return x.astype(target_dtype)
|
|
180
|
+
return x
|
|
181
|
+
|
|
182
|
+
model.apply_to_parameters(convert_param)
|
|
183
|
+
logger.debug(f"Model dtype optimized to {target_dtype}")
|
|
184
|
+
else:
|
|
185
|
+
# For models without apply_to_parameters (like mlx_lm models)
|
|
186
|
+
# They typically already have optimized dtype from loading
|
|
187
|
+
logger.debug(f"Model already optimized, target dtype: {target_dtype}")
|
|
188
|
+
|
|
189
|
+
return model
|
|
190
|
+
|
|
191
|
+
def _compile_model(self, model: nn.Module) -> nn.Module:
|
|
192
|
+
"""Compile model with JIT for faster execution."""
|
|
193
|
+
logger.debug("Compiling model with mx.compile decorator")
|
|
194
|
+
|
|
195
|
+
# Use advanced compilation with operation fusion
|
|
196
|
+
@mx.compile
|
|
197
|
+
def compiled_forward(x, cache=None):
|
|
198
|
+
if cache is not None:
|
|
199
|
+
return model(x, cache=cache)
|
|
200
|
+
return model(x)
|
|
201
|
+
|
|
202
|
+
# Store original for fallback
|
|
203
|
+
original_forward = model.__call__
|
|
204
|
+
model.__call__ = compiled_forward
|
|
205
|
+
model._original_forward = original_forward
|
|
206
|
+
model._compiled = True
|
|
207
|
+
|
|
208
|
+
logger.debug("Model compilation completed")
|
|
209
|
+
return model
|
|
210
|
+
|
|
211
|
+
def _enable_graph_optimization(
|
|
212
|
+
self,
|
|
213
|
+
model: nn.Module,
|
|
214
|
+
example_input: mx.array
|
|
215
|
+
) -> nn.Module:
|
|
216
|
+
"""Enable graph-level optimizations."""
|
|
217
|
+
try:
|
|
218
|
+
with mx.stream(mx.gpu):
|
|
219
|
+
_ = model(example_input)
|
|
220
|
+
mx.eval(model.parameters())
|
|
221
|
+
logger.debug("Graph optimization enabled")
|
|
222
|
+
except Exception as e:
|
|
223
|
+
logger.warning(f"Graph optimization failed: {e}")
|
|
224
|
+
print(f"Warning: Graph optimization failed: {e}")
|
|
225
|
+
|
|
226
|
+
return model
|
|
227
|
+
|
|
228
|
+
def _enable_stream_parallelism(self, model: nn.Module) -> nn.Module:
|
|
229
|
+
"""Enable stream parallelism for concurrent operations."""
|
|
230
|
+
|
|
231
|
+
def parallel_forward(self, x):
|
|
232
|
+
streams = [mx.Stream(mx.gpu) for _ in range(2)]
|
|
233
|
+
|
|
234
|
+
with streams[0]:
|
|
235
|
+
x1 = self.layers[:len(self.layers)//2](x)
|
|
236
|
+
|
|
237
|
+
with streams[1]:
|
|
238
|
+
x2 = self.layers[len(self.layers)//2:](x)
|
|
239
|
+
|
|
240
|
+
mx.synchronize()
|
|
241
|
+
return x1 + x2
|
|
242
|
+
|
|
243
|
+
return model
|
|
244
|
+
|
|
245
|
+
def accelerate_transformer(
|
|
246
|
+
self,
|
|
247
|
+
model: nn.Module,
|
|
248
|
+
num_heads: int,
|
|
249
|
+
head_dim: int
|
|
250
|
+
) -> nn.Module:
|
|
251
|
+
"""Apply transformer-specific optimizations with AMX acceleration."""
|
|
252
|
+
|
|
253
|
+
@mx.compile
|
|
254
|
+
def optimized_attention(query, key, value, mask=None, cache=None):
|
|
255
|
+
"""Fused attention with AMX-accelerated matmul."""
|
|
256
|
+
scale = head_dim ** -0.5
|
|
257
|
+
|
|
258
|
+
# Update cache if provided (for KV caching)
|
|
259
|
+
if cache is not None:
|
|
260
|
+
if "k" in cache and "v" in cache:
|
|
261
|
+
key = mx.concatenate([cache["k"], key], axis=1)
|
|
262
|
+
value = mx.concatenate([cache["v"], value], axis=1)
|
|
263
|
+
# Implement rotating cache if sequence too long
|
|
264
|
+
if key.shape[1] > self.config.kv_cache_size:
|
|
265
|
+
key = key[:, -self.config.kv_cache_size:]
|
|
266
|
+
value = value[:, -self.config.kv_cache_size:]
|
|
267
|
+
cache["k"] = key
|
|
268
|
+
cache["v"] = value
|
|
269
|
+
|
|
270
|
+
# AMX-accelerated matrix multiplication
|
|
271
|
+
scores = mx.matmul(query, mx.swapaxes(key, -2, -1)) * scale
|
|
272
|
+
|
|
273
|
+
if mask is not None:
|
|
274
|
+
scores = scores + mask
|
|
275
|
+
|
|
276
|
+
# Fused softmax operation
|
|
277
|
+
probs = mx.softmax(scores, axis=-1)
|
|
278
|
+
|
|
279
|
+
# AMX-accelerated output projection
|
|
280
|
+
output = mx.matmul(probs, value)
|
|
281
|
+
|
|
282
|
+
return output, cache
|
|
283
|
+
|
|
284
|
+
# Replace attention mechanism
|
|
285
|
+
if hasattr(model, 'attention'):
|
|
286
|
+
model.attention.forward = optimized_attention
|
|
287
|
+
|
|
288
|
+
# Apply to all transformer layers
|
|
289
|
+
for layer in model.layers if hasattr(model, 'layers') else []:
|
|
290
|
+
if hasattr(layer, 'self_attn'):
|
|
291
|
+
layer.self_attn.forward = optimized_attention
|
|
292
|
+
|
|
293
|
+
return model
|
|
294
|
+
|
|
295
|
+
def optimize_generation(
|
|
296
|
+
self,
|
|
297
|
+
generate_fn: Callable,
|
|
298
|
+
max_cache_size: int = 32768
|
|
299
|
+
) -> Callable:
|
|
300
|
+
"""Optimize text generation function."""
|
|
301
|
+
|
|
302
|
+
@functools.wraps(generate_fn)
|
|
303
|
+
def optimized_generate(*args, **kwargs):
|
|
304
|
+
cache = {}
|
|
305
|
+
|
|
306
|
+
def cached_forward(x, cache_key):
|
|
307
|
+
if cache_key in cache:
|
|
308
|
+
return cache[cache_key]
|
|
309
|
+
|
|
310
|
+
result = generate_fn(x)
|
|
311
|
+
|
|
312
|
+
if len(cache) < max_cache_size:
|
|
313
|
+
cache[cache_key] = result
|
|
314
|
+
|
|
315
|
+
return result
|
|
316
|
+
|
|
317
|
+
return generate_fn(*args, **kwargs)
|
|
318
|
+
|
|
319
|
+
return optimized_generate
|
|
320
|
+
|
|
321
|
+
def create_pipeline(
|
|
322
|
+
self,
|
|
323
|
+
models: List[nn.Module],
|
|
324
|
+
batch_size: int = 1
|
|
325
|
+
) -> Callable:
|
|
326
|
+
"""Create an optimized inference pipeline."""
|
|
327
|
+
|
|
328
|
+
optimized_models = [self.optimize_model(m) for m in models]
|
|
329
|
+
|
|
330
|
+
def pipeline(x):
|
|
331
|
+
"""Run inference through pipeline."""
|
|
332
|
+
for model in optimized_models:
|
|
333
|
+
x = model(x)
|
|
334
|
+
mx.eval(x)
|
|
335
|
+
return x
|
|
336
|
+
|
|
337
|
+
return mx.compile(pipeline)
|
|
338
|
+
|
|
339
|
+
def profile_model(
|
|
340
|
+
self,
|
|
341
|
+
model: nn.Module,
|
|
342
|
+
input_shape: Tuple[int, ...],
|
|
343
|
+
num_iterations: int = 100
|
|
344
|
+
) -> Dict[str, Any]:
|
|
345
|
+
"""Profile model performance on MLX."""
|
|
346
|
+
model.eval()
|
|
347
|
+
|
|
348
|
+
dummy_input = mx.random.normal(input_shape)
|
|
349
|
+
if self.config.dtype == mx.float16:
|
|
350
|
+
dummy_input = dummy_input.astype(mx.float16)
|
|
351
|
+
|
|
352
|
+
mx.eval(dummy_input)
|
|
353
|
+
|
|
354
|
+
warmup_iterations = 10
|
|
355
|
+
for _ in range(warmup_iterations):
|
|
356
|
+
output = model(dummy_input)
|
|
357
|
+
mx.eval(output)
|
|
358
|
+
|
|
359
|
+
start_time = time.perf_counter()
|
|
360
|
+
for _ in range(num_iterations):
|
|
361
|
+
output = model(dummy_input)
|
|
362
|
+
mx.eval(output)
|
|
363
|
+
|
|
364
|
+
end_time = time.perf_counter()
|
|
365
|
+
|
|
366
|
+
avg_time = (end_time - start_time) / num_iterations
|
|
367
|
+
throughput = input_shape[0] / avg_time if avg_time > 0 else 0
|
|
368
|
+
|
|
369
|
+
num_params = sum(p.size for p in tree_flatten(model.parameters()))
|
|
370
|
+
|
|
371
|
+
return {
|
|
372
|
+
"avg_inference_time": avg_time,
|
|
373
|
+
"throughput": throughput,
|
|
374
|
+
"num_parameters": num_params,
|
|
375
|
+
"dtype": str(self.config.dtype),
|
|
376
|
+
"device": str(self.device),
|
|
377
|
+
"batch_size": input_shape[0]
|
|
378
|
+
}
|
|
379
|
+
|
|
380
|
+
def optimize_memory(self, model: nn.Module) -> nn.Module:
|
|
381
|
+
"""Optimize memory usage for large models."""
|
|
382
|
+
|
|
383
|
+
def shard_weights(weights, num_shards=2):
|
|
384
|
+
"""Shard weights across multiple arrays."""
|
|
385
|
+
if weights.size < self.config.fusion_threshold:
|
|
386
|
+
return [weights]
|
|
387
|
+
|
|
388
|
+
return mx.split(weights, num_shards, axis=0)
|
|
389
|
+
|
|
390
|
+
for name, param in model.parameters().items():
|
|
391
|
+
if param.size > self.config.fusion_threshold:
|
|
392
|
+
sharded = shard_weights(param)
|
|
393
|
+
model.parameters()[name] = sharded[0]
|
|
394
|
+
|
|
395
|
+
return model
|
|
396
|
+
|
|
397
|
+
def quantize_model(
|
|
398
|
+
self,
|
|
399
|
+
model: nn.Module,
|
|
400
|
+
bits: int = 4,
|
|
401
|
+
mixed_precision: Optional[Dict[str, int]] = None
|
|
402
|
+
) -> nn.Module:
|
|
403
|
+
"""Advanced quantization with mixed precision support."""
|
|
404
|
+
logger.info(f"Starting model quantization: {bits}-bit")
|
|
405
|
+
if mixed_precision:
|
|
406
|
+
logger.info(f"Mixed precision config: {mixed_precision}")
|
|
407
|
+
|
|
408
|
+
quantized_layers = 0
|
|
409
|
+
total_layers = 0
|
|
410
|
+
|
|
411
|
+
def quantize_weight(param_name: str, w: mx.array) -> mx.array:
|
|
412
|
+
"""Quantize weight with per-layer precision."""
|
|
413
|
+
if w.dtype not in [mx.float32, mx.float16, mx.bfloat16]:
|
|
414
|
+
return w
|
|
415
|
+
|
|
416
|
+
nonlocal quantized_layers, total_layers
|
|
417
|
+
total_layers += 1
|
|
418
|
+
|
|
419
|
+
# Determine bits for this layer
|
|
420
|
+
layer_bits = bits
|
|
421
|
+
if mixed_precision:
|
|
422
|
+
# Critical layers get higher precision
|
|
423
|
+
if any(critical in param_name for critical in ["lm_head", "embed", "wte", "wpe"]):
|
|
424
|
+
layer_bits = mixed_precision.get("critical_bits", 6)
|
|
425
|
+
logger.debug(f"Layer {param_name}: using {layer_bits}-bit (critical)")
|
|
426
|
+
elif "attention" in param_name:
|
|
427
|
+
layer_bits = mixed_precision.get("attention_bits", bits)
|
|
428
|
+
logger.debug(f"Layer {param_name}: using {layer_bits}-bit (attention)")
|
|
429
|
+
elif any(ffn in param_name for ffn in ["mlp", "feed_forward", "ffn"]):
|
|
430
|
+
layer_bits = mixed_precision.get("ffn_bits", bits)
|
|
431
|
+
logger.debug(f"Layer {param_name}: using {layer_bits}-bit (FFN)")
|
|
432
|
+
|
|
433
|
+
# Group-wise quantization for better quality
|
|
434
|
+
group_size = 64
|
|
435
|
+
orig_shape = w.shape
|
|
436
|
+
w_flat = w.reshape(-1)
|
|
437
|
+
|
|
438
|
+
# Pad for group alignment
|
|
439
|
+
pad_size = (group_size - w_flat.shape[0] % group_size) % group_size
|
|
440
|
+
if pad_size > 0:
|
|
441
|
+
w_flat = mx.pad(w_flat, [(0, pad_size)])
|
|
442
|
+
|
|
443
|
+
# Reshape for group-wise quantization
|
|
444
|
+
w_grouped = w_flat.reshape(-1, group_size)
|
|
445
|
+
|
|
446
|
+
# Compute scales per group
|
|
447
|
+
w_max = mx.max(mx.abs(w_grouped), axis=1, keepdims=True)
|
|
448
|
+
scale = w_max / (2 ** (layer_bits - 1) - 1)
|
|
449
|
+
scale = mx.where(scale == 0, 1.0, scale) # Avoid division by zero
|
|
450
|
+
|
|
451
|
+
# Quantize
|
|
452
|
+
if layer_bits == 4:
|
|
453
|
+
quantized = mx.round(w_grouped / scale).astype(mx.int8)
|
|
454
|
+
quantized_layers += 1
|
|
455
|
+
elif layer_bits == 8:
|
|
456
|
+
quantized = mx.round(w_grouped / scale).astype(mx.int8)
|
|
457
|
+
quantized_layers += 1
|
|
458
|
+
else:
|
|
459
|
+
# For higher precision, keep as is
|
|
460
|
+
logger.debug(f"Layer {param_name}: keeping original precision")
|
|
461
|
+
return w
|
|
462
|
+
|
|
463
|
+
# Dequantize for inference
|
|
464
|
+
dequantized = quantized.astype(mx.float16) * scale
|
|
465
|
+
|
|
466
|
+
# Reshape back
|
|
467
|
+
dequantized_flat = dequantized.reshape(-1)
|
|
468
|
+
if pad_size > 0:
|
|
469
|
+
dequantized_flat = dequantized_flat[:-pad_size]
|
|
470
|
+
|
|
471
|
+
return dequantized_flat.reshape(orig_shape)
|
|
472
|
+
|
|
473
|
+
# Apply quantization to all parameters
|
|
474
|
+
if hasattr(model, 'named_parameters'):
|
|
475
|
+
for name, param in model.named_parameters():
|
|
476
|
+
quantized = quantize_weight(name, param)
|
|
477
|
+
# Update parameter in-place
|
|
478
|
+
if hasattr(param, 'update'):
|
|
479
|
+
param.update(quantized)
|
|
480
|
+
else:
|
|
481
|
+
# For models that don't support in-place update
|
|
482
|
+
logger.debug(f"Cannot update parameter {name} in-place, skipping")
|
|
483
|
+
|
|
484
|
+
mx.eval(model.parameters())
|
|
485
|
+
|
|
486
|
+
logger.info(f"Quantization completed: {quantized_layers}/{total_layers} layers quantized")
|
|
487
|
+
return model
|
|
488
|
+
|
|
489
|
+
@staticmethod
|
|
490
|
+
def get_device_info() -> Dict[str, Any]:
|
|
491
|
+
"""Get MLX device information."""
|
|
492
|
+
device = mx.default_device()
|
|
493
|
+
|
|
494
|
+
info = {
|
|
495
|
+
"device": str(device),
|
|
496
|
+
"is_gpu": str(device).lower() == "gpu",
|
|
497
|
+
"default_dtype": str(mx.float32)
|
|
498
|
+
}
|
|
499
|
+
|
|
500
|
+
return info
|
|
501
|
+
|
|
502
|
+
def _enable_amx_acceleration(self, model: nn.Module) -> nn.Module:
|
|
503
|
+
"""Enable AMX coprocessor acceleration."""
|
|
504
|
+
logger.debug("Configuring model for AMX acceleration")
|
|
505
|
+
# Configure for AMX usage
|
|
506
|
+
mx.set_default_device(mx.gpu)
|
|
507
|
+
|
|
508
|
+
# Check if model has apply_to_parameters method
|
|
509
|
+
if hasattr(model, 'apply_to_parameters'):
|
|
510
|
+
# Apply AMX-friendly layouts to weight matrices
|
|
511
|
+
def optimize_for_amx(param):
|
|
512
|
+
if len(param.shape) == 2: # Matrix weights
|
|
513
|
+
# Ensure alignment for AMX (32x32 tiles)
|
|
514
|
+
rows, cols = param.shape
|
|
515
|
+
if rows % 32 != 0 or cols % 32 != 0:
|
|
516
|
+
# Pad to AMX-friendly dimensions
|
|
517
|
+
pad_rows = (32 - rows % 32) % 32
|
|
518
|
+
pad_cols = (32 - cols % 32) % 32
|
|
519
|
+
if pad_rows > 0 or pad_cols > 0:
|
|
520
|
+
param = mx.pad(param, [(0, pad_rows), (0, pad_cols)])
|
|
521
|
+
return param
|
|
522
|
+
|
|
523
|
+
model.apply_to_parameters(optimize_for_amx)
|
|
524
|
+
logger.debug("AMX optimization applied to model weights")
|
|
525
|
+
else:
|
|
526
|
+
# For models without apply_to_parameters
|
|
527
|
+
# AMX will still be used automatically by MLX for matrix operations
|
|
528
|
+
logger.debug("AMX acceleration enabled (automatic for matrix ops)")
|
|
529
|
+
|
|
530
|
+
return model
|
|
531
|
+
|
|
532
|
+
def _fuse_operations(self, model: nn.Module) -> nn.Module:
|
|
533
|
+
"""Fuse operations for reduced kernel launches."""
|
|
534
|
+
# Operation fusion is handled by mx.compile decorator
|
|
535
|
+
# Mark model for aggressive fusion
|
|
536
|
+
if hasattr(model, 'config'):
|
|
537
|
+
model.config.fuse_ops = True
|
|
538
|
+
logger.debug("Operation fusion enabled in model config")
|
|
539
|
+
return model
|
|
540
|
+
|
|
541
|
+
def _setup_rotating_kv_cache(self, model: nn.Module) -> nn.Module:
|
|
542
|
+
"""Setup rotating KV cache for long contexts."""
|
|
543
|
+
logger.debug(f"Setting up rotating KV cache with max size: {self.config.kv_cache_size}")
|
|
544
|
+
# Initialize cache structure
|
|
545
|
+
model.kv_cache = {
|
|
546
|
+
"max_size": self.config.kv_cache_size,
|
|
547
|
+
"current_size": 0,
|
|
548
|
+
"cache": {}
|
|
549
|
+
}
|
|
550
|
+
|
|
551
|
+
# Modify forward to use cache
|
|
552
|
+
original_forward = model.forward if hasattr(model, 'forward') else model.__call__
|
|
553
|
+
|
|
554
|
+
def forward_with_cache(x, **kwargs):
|
|
555
|
+
kwargs['cache'] = model.kv_cache.get('cache', {})
|
|
556
|
+
result = original_forward(x, **kwargs)
|
|
557
|
+
return result
|
|
558
|
+
|
|
559
|
+
model.forward = forward_with_cache
|
|
560
|
+
logger.debug("Rotating KV cache configured")
|
|
561
|
+
return model
|
|
562
|
+
|
|
563
|
+
def generate_optimized(
|
|
564
|
+
self,
|
|
565
|
+
model: nn.Module,
|
|
566
|
+
tokenizer: Any,
|
|
567
|
+
prompt: str,
|
|
568
|
+
max_tokens: int = 512,
|
|
569
|
+
temperature: float = 0.7,
|
|
570
|
+
top_p: float = 0.95,
|
|
571
|
+
repetition_penalty: float = 1.1,
|
|
572
|
+
stream: bool = True,
|
|
573
|
+
stop_sequences: List[str] = None
|
|
574
|
+
) -> Generator[str, None, None]:
|
|
575
|
+
"""Optimized generation with AMX and caching."""
|
|
576
|
+
# Add stop sequences to tokenizer if provided
|
|
577
|
+
if stop_sequences and hasattr(tokenizer, 'add_eos_token'):
|
|
578
|
+
for stop_seq in stop_sequences:
|
|
579
|
+
try:
|
|
580
|
+
tokenizer.add_eos_token(stop_seq)
|
|
581
|
+
logger.debug(f"Added stop sequence to tokenizer: {stop_seq}")
|
|
582
|
+
except Exception as e:
|
|
583
|
+
logger.warning(f"Could not add stop sequence '{stop_seq}': {e}")
|
|
584
|
+
|
|
585
|
+
# Import sample_utils for creating sampler
|
|
586
|
+
try:
|
|
587
|
+
from mlx_lm.sample_utils import make_sampler
|
|
588
|
+
# Create sampler with temperature and top_p
|
|
589
|
+
sampler = make_sampler(temperature, top_p=top_p)
|
|
590
|
+
logger.debug(f"Created sampler with temperature={temperature}, top_p={top_p}")
|
|
591
|
+
except ImportError:
|
|
592
|
+
sampler = None
|
|
593
|
+
logger.warning("mlx_lm.sample_utils not available, using default sampler")
|
|
594
|
+
|
|
595
|
+
# Check if mlx_lm functions are available
|
|
596
|
+
if stream and stream_generate is not None:
|
|
597
|
+
logger.debug("Using mlx_lm stream_generate for optimized generation")
|
|
598
|
+
# stream_generate accepts sampler, not individual params
|
|
599
|
+
generation_kwargs = {
|
|
600
|
+
"prompt": prompt,
|
|
601
|
+
"max_tokens": max_tokens,
|
|
602
|
+
}
|
|
603
|
+
if sampler is not None:
|
|
604
|
+
generation_kwargs["sampler"] = sampler
|
|
605
|
+
|
|
606
|
+
# Note: repetition_penalty may need to be handled via logits_processors
|
|
607
|
+
# For now, we'll use the basic generation
|
|
608
|
+
for response in stream_generate(
|
|
609
|
+
model,
|
|
610
|
+
tokenizer,
|
|
611
|
+
**generation_kwargs
|
|
612
|
+
):
|
|
613
|
+
# stream_generate returns GenerationResponse objects with .text attribute
|
|
614
|
+
if hasattr(response, 'text'):
|
|
615
|
+
yield response.text
|
|
616
|
+
else:
|
|
617
|
+
# Fallback if structure changes
|
|
618
|
+
yield str(response)
|
|
619
|
+
elif not stream and generate is not None:
|
|
620
|
+
logger.debug("Using mlx_lm generate for optimized generation")
|
|
621
|
+
# generate also uses sampler, not individual params
|
|
622
|
+
generation_kwargs = {
|
|
623
|
+
"prompt": prompt,
|
|
624
|
+
"max_tokens": max_tokens,
|
|
625
|
+
}
|
|
626
|
+
if sampler is not None:
|
|
627
|
+
generation_kwargs["sampler"] = sampler
|
|
628
|
+
|
|
629
|
+
result = generate(
|
|
630
|
+
model,
|
|
631
|
+
tokenizer,
|
|
632
|
+
**generation_kwargs
|
|
633
|
+
)
|
|
634
|
+
yield result
|
|
635
|
+
else:
|
|
636
|
+
# Fallback: just return a message
|
|
637
|
+
logger.warning("MLX generation functions not available, using fallback")
|
|
638
|
+
yield f"MLX generation not available. Input: {prompt[:50]}..."
|
|
639
|
+
|
|
640
|
+
@staticmethod
|
|
641
|
+
def benchmark_operation(
|
|
642
|
+
operation: Callable,
|
|
643
|
+
input_shape: Tuple[int, ...],
|
|
644
|
+
num_iterations: int = 1000,
|
|
645
|
+
use_amx: bool = True
|
|
646
|
+
) -> Dict[str, float]:
|
|
647
|
+
"""Benchmark operation with AMX comparison."""
|
|
648
|
+
x = mx.random.normal(input_shape)
|
|
649
|
+
mx.eval(x)
|
|
650
|
+
|
|
651
|
+
# Warmup
|
|
652
|
+
for _ in range(10):
|
|
653
|
+
_ = operation(x)
|
|
654
|
+
mx.eval(_)
|
|
655
|
+
|
|
656
|
+
# Benchmark
|
|
657
|
+
start = time.perf_counter()
|
|
658
|
+
for _ in range(num_iterations):
|
|
659
|
+
result = operation(x)
|
|
660
|
+
mx.eval(result)
|
|
661
|
+
end = time.perf_counter()
|
|
662
|
+
|
|
663
|
+
avg_time = (end - start) / num_iterations * 1000 # ms
|
|
664
|
+
|
|
665
|
+
# Calculate FLOPS for matmul operations
|
|
666
|
+
flops = 0
|
|
667
|
+
if len(input_shape) >= 2:
|
|
668
|
+
# Approximate FLOPS for matrix operations
|
|
669
|
+
flops = 2 * np.prod(input_shape) * input_shape[-1] * num_iterations / (end - start)
|
|
670
|
+
|
|
671
|
+
result = {
|
|
672
|
+
"avg_time_ms": avg_time,
|
|
673
|
+
"throughput_gflops": flops / 1e9 if flops > 0 else 0,
|
|
674
|
+
"amx_enabled": use_amx
|
|
675
|
+
}
|
|
676
|
+
|
|
677
|
+
logger.debug(f"Benchmark results: {result}")
|
|
678
|
+
return result
|