cortex-llm 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. cortex/__init__.py +73 -0
  2. cortex/__main__.py +83 -0
  3. cortex/config.py +329 -0
  4. cortex/conversation_manager.py +468 -0
  5. cortex/fine_tuning/__init__.py +8 -0
  6. cortex/fine_tuning/dataset.py +332 -0
  7. cortex/fine_tuning/mlx_lora_trainer.py +502 -0
  8. cortex/fine_tuning/trainer.py +957 -0
  9. cortex/fine_tuning/wizard.py +707 -0
  10. cortex/gpu_validator.py +467 -0
  11. cortex/inference_engine.py +727 -0
  12. cortex/metal/__init__.py +275 -0
  13. cortex/metal/gpu_validator.py +177 -0
  14. cortex/metal/memory_pool.py +886 -0
  15. cortex/metal/mlx_accelerator.py +678 -0
  16. cortex/metal/mlx_converter.py +638 -0
  17. cortex/metal/mps_optimizer.py +417 -0
  18. cortex/metal/optimizer.py +665 -0
  19. cortex/metal/performance_profiler.py +364 -0
  20. cortex/model_downloader.py +130 -0
  21. cortex/model_manager.py +2187 -0
  22. cortex/quantization/__init__.py +5 -0
  23. cortex/quantization/dynamic_quantizer.py +736 -0
  24. cortex/template_registry/__init__.py +15 -0
  25. cortex/template_registry/auto_detector.py +144 -0
  26. cortex/template_registry/config_manager.py +234 -0
  27. cortex/template_registry/interactive.py +260 -0
  28. cortex/template_registry/registry.py +347 -0
  29. cortex/template_registry/template_profiles/__init__.py +5 -0
  30. cortex/template_registry/template_profiles/base.py +142 -0
  31. cortex/template_registry/template_profiles/complex/__init__.py +5 -0
  32. cortex/template_registry/template_profiles/complex/reasoning.py +263 -0
  33. cortex/template_registry/template_profiles/standard/__init__.py +9 -0
  34. cortex/template_registry/template_profiles/standard/alpaca.py +73 -0
  35. cortex/template_registry/template_profiles/standard/chatml.py +82 -0
  36. cortex/template_registry/template_profiles/standard/gemma.py +103 -0
  37. cortex/template_registry/template_profiles/standard/llama.py +87 -0
  38. cortex/template_registry/template_profiles/standard/simple.py +65 -0
  39. cortex/ui/__init__.py +120 -0
  40. cortex/ui/cli.py +1685 -0
  41. cortex/ui/markdown_render.py +185 -0
  42. cortex/ui/terminal_app.py +534 -0
  43. cortex_llm-1.0.0.dist-info/METADATA +275 -0
  44. cortex_llm-1.0.0.dist-info/RECORD +48 -0
  45. cortex_llm-1.0.0.dist-info/WHEEL +5 -0
  46. cortex_llm-1.0.0.dist-info/entry_points.txt +2 -0
  47. cortex_llm-1.0.0.dist-info/licenses/LICENSE +21 -0
  48. cortex_llm-1.0.0.dist-info/top_level.txt +1 -0
cortex/model_manager.py
@@ -0,0 +1,2187 @@
1
+ """Model management for GPU-accelerated inference."""
2
+
3
+ import os
4
+ import sys
5
+ import logging
6
+ from pathlib import Path
7
+ from typing import Dict, Any, Optional, List, Tuple
8
+ from dataclasses import dataclass
9
+ from enum import Enum
10
+ import hashlib
11
+ import json
12
+ import shutil
13
+ import struct
14
+ from datetime import datetime
15
+
16
+ # Configure logging
17
+ logger = logging.getLogger(__name__)
18
+ logger.setLevel(logging.INFO)
19
+
20
+ import torch
21
+ import mlx.core as mx
22
+ import mlx.nn as nn
23
+ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
24
+
25
+ # Import MLX LM functions safely
26
+ try:
27
+ from mlx_lm import load as mlx_load
28
+ except ImportError:
29
+ mlx_load = None
30
+
31
+ from cortex.config import Config
32
+ from cortex.gpu_validator import GPUValidator
33
+ from cortex.metal.memory_pool import MemoryPool, AllocationStrategy
34
+ from cortex.metal.mlx_converter import MLXConverter, ConversionConfig, QuantizationRecipe, ConversionFormat
35
+ from cortex.metal.mlx_accelerator import MLXAccelerator, MLXConfig
36
+ from cortex.quantization.dynamic_quantizer import DynamicQuantizer, QuantizationConfig, QuantizationMode
37
+
38
+ # Configure tokenizer parallelism for optimal performance
39
+ # Enable parallelism for better tokenization speed
40
+ # This is safe with threading (not multiprocessing)
41
+ if os.environ.get('TOKENIZERS_PARALLELISM') is None:
42
+ # Only set a default when the user has not configured it explicitly
43
+ # (parallel tokenization here uses threads, not multiprocessing, so it is safe)
44
+ os.environ['TOKENIZERS_PARALLELISM'] = 'true'
45
+
46
+ # Set optimal number of threads for tokenizer
47
+ if os.environ.get('RAYON_NUM_THREADS') is None:
48
+ # Use half of available CPU cores for tokenizer threads
49
+ import multiprocessing
50
+ num_cores = multiprocessing.cpu_count()
51
+ os.environ['RAYON_NUM_THREADS'] = str(max(1, num_cores // 2))
52
+
53
+
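The two defaults above are applied only when the variables are unset, so a caller can pin its own values before the first import. A minimal sketch; the chosen values are illustrative assumptions, not recommendations from the package:

import os

# Must run before `import cortex.model_manager` (and before the tokenizers
# backend spins up its thread pool) for the overrides to take effect.
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
os.environ.setdefault("RAYON_NUM_THREADS", "4")

import cortex.model_manager  # noqa: E402 - import deliberately follows env setup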
54
+ class ModelFormat(Enum):
55
+ """Supported model formats."""
56
+ GGUF = "gguf"
57
+ MLX = "mlx"
58
+ SAFETENSORS = "safetensors"
59
+ PYTORCH = "pytorch"
60
+ QUANTIZED = "quantized" # For GPTQ/AWQ/etc
61
+ UNKNOWN = "unknown"
62
+
63
+
64
+ class QuantizationType(Enum):
65
+ """Supported quantization types."""
66
+ NONE = "none"
67
+ INT4 = "int4"
68
+ INT8 = "int8"
69
+ GPTQ = "gptq"
70
+ AWQ = "awq"
71
+ Q4_K_M = "Q4_K_M"
72
+ Q5_K_M = "Q5_K_M"
73
+ Q6_K = "Q6_K"
74
+ Q8_0 = "Q8_0"
75
+ FP16 = "FP16"
76
+ FP32 = "FP32"
77
+
78
+
79
+ @dataclass
80
+ class ModelInfo:
81
+ """Information about a loaded model."""
82
+ name: str
83
+ path: Path
84
+ format: ModelFormat
85
+ quantization: QuantizationType
86
+ size_bytes: int
87
+ parameters: int
88
+ context_length: int
89
+ loaded_at: datetime
90
+ gpu_memory_used: int
91
+ tokenizer_path: Optional[Path]
92
+ config: Dict[str, Any]
93
+
94
+ @property
95
+ def size_gb(self) -> float:
96
+ """Get model size in GB."""
97
+ return self.size_bytes / (1024 ** 3)
98
+
99
+ def to_dict(self) -> Dict[str, Any]:
100
+ """Convert to dictionary."""
101
+ return {
102
+ 'name': self.name,
103
+ 'path': str(self.path),
104
+ 'format': self.format.value,
105
+ 'quantization': self.quantization.value,
106
+ 'size_gb': self.size_gb,
107
+ 'parameters': self.parameters,
108
+ 'context_length': self.context_length,
109
+ 'loaded_at': self.loaded_at.isoformat(),
110
+ 'gpu_memory_used': self.gpu_memory_used
111
+ }
112
+
113
+
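A minimal sketch of constructing a ModelInfo and serializing it with the helpers above; every concrete value below is an illustrative assumption, not a package default:

from datetime import datetime
from pathlib import Path

from cortex.model_manager import ModelFormat, ModelInfo, QuantizationType

info = ModelInfo(
    name="example-8b",                                    # hypothetical model name
    path=Path.home() / ".cortex" / "mlx_models" / "example-8b_4bit",
    format=ModelFormat.MLX,
    quantization=QuantizationType.INT4,
    size_bytes=4_600_000_000,
    parameters=8_000_000_000,
    context_length=8192,
    loaded_at=datetime.now(),
    gpu_memory_used=5_100_000_000,
    tokenizer_path=None,
    config={},
)
print(f"{info.size_gb:.2f} GB")        # size_bytes / 1024**3, about 4.28 GB here
print(info.to_dict()["quantization"])  # "int4"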
114
+ class ModelManager:
115
+ """Manage model loading and GPU memory allocation."""
116
+
117
+ def __init__(
118
+ self,
119
+ config: Config,
120
+ gpu_validator: GPUValidator,
121
+ memory_pool: Optional[MemoryPool] = None
122
+ ):
123
+ """Initialize model manager."""
124
+ self.config = config
125
+ self.gpu_validator = gpu_validator
126
+ self.memory_pool = memory_pool
127
+ self.loaded_models: Dict[str, ModelInfo] = {}
128
+ self.current_model: Optional[str] = None
129
+ self.model_cache: Dict[str, Any] = {}
130
+ self.tokenizers: Dict[str, Any] = {}
131
+
132
+ # Initialize quantizer for memory-efficient loading
133
+ self.quantizer = DynamicQuantizer(QuantizationConfig(
134
+ mode=QuantizationMode.DYNAMIC,
135
+ per_channel=True,
136
+ cache_quantized=True
137
+ ))
138
+
139
+ # Initialize MLX converter for native conversion
140
+ # Use a consistent cache directory
141
+ mlx_cache_dir = Path.home() / ".cortex" / "mlx_models"
142
+ self.mlx_converter = MLXConverter(
143
+ cache_dir=mlx_cache_dir
144
+ )
145
+
146
+ # Initialize MLX accelerator for optimizations
147
+ self.mlx_accelerator = None
148
+ self._mlx_init_error: Optional[str] = None
149
+ try:
150
+ self.mlx_accelerator = MLXAccelerator(MLXConfig(
151
+ compile_model=True,
152
+ use_amx=True,
153
+ fuse_operations=True,
154
+ rotating_kv_cache=True,
155
+ quantization_bits=4
156
+ ))
157
+ except Exception as e:
158
+ self._mlx_init_error = str(e)
159
+ logger.warning("MLX accelerator initialization failed: %s", e, exc_info=True)
160
+
161
+ self._setup_directories()
162
+ self._initialize_memory_pool()
163
+
164
+ def __del__(self):
165
+ """Clean up resources on deletion."""
166
+ try:
167
+ # Unload all models properly
168
+ for model_name in list(self.loaded_models.keys()):
169
+ self.unload_model(model_name)
170
+ except Exception:
171
+ pass # Ignore errors during cleanup
172
+
173
+ def _setup_directories(self) -> None:
174
+ """Create necessary directories."""
175
+ self.config.model.model_path.expanduser().mkdir(parents=True, exist_ok=True)
176
+ self.config.model.model_cache_dir.expanduser().mkdir(parents=True, exist_ok=True)
177
+ self.config.model.quantization_cache.expanduser().mkdir(parents=True, exist_ok=True)
178
+
179
+ def _initialize_memory_pool(self) -> None:
180
+ """Initialize memory pool if needed."""
181
+ # Skip if already provided by InferenceEngine to avoid duplication
182
+ if self.memory_pool is not None:
183
+ return
184
+
185
+ if self.config.gpu.force_gpu:
186
+ try:
187
+ # Only create if not already provided
188
+ self.memory_pool = MemoryPool(
189
+ pool_size=None,
190
+ strategy=AllocationStrategy.UNIFIED,
191
+ device="mps" if torch.backends.mps.is_available() else "mlx",
192
+ auto_size=True,
193
+ silent=True # Suppress message since InferenceEngine also creates a pool
194
+ )
195
+ except Exception as e:
196
+ # InferenceEngine likely has its own pool; log for visibility.
197
+ logger.debug("Memory pool initialization skipped: %s", e)
198
+
199
+ def _prefer_speed_quantization(self) -> bool:
200
+ """Return True when config prefers maximum speed over quality."""
201
+ level = getattr(self.config.gpu, "gpu_optimization_level", "maximum")
202
+ level = str(level).lower().strip()
203
+ return level in {"maximum", "max", "speed", "fast", "performance"}
204
+
205
+ def load_model(
206
+ self,
207
+ model_path: str,
208
+ model_name: Optional[str] = None,
209
+ force_reload: bool = False,
210
+ convert_to_mlx: bool = False,
211
+ quantization: Optional[str] = None
212
+ ) -> Tuple[bool, str]:
213
+ """
214
+ Load a model to GPU memory with optional MLX conversion.
215
+
216
+ Args:
217
+ model_path: Path to model file or directory (or HF repo ID)
218
+ model_name: Optional name for the model
219
+ force_reload: Force reload even if already loaded
220
+ convert_to_mlx: Convert to MLX format for better performance
221
+ quantization: Quantization recipe ('4bit', '5bit', '8bit', 'mixed')
222
+
223
+ Returns:
224
+ Tuple of (success, message)
225
+ """
226
+ # Check if it's a HuggingFace repo ID
227
+ is_hf_repo = "/" in model_path and not Path(model_path).exists()
228
+
229
+ # Auto-enable MLX conversion if MLX backend is enabled in config
230
+ if hasattr(self.config, 'gpu') and hasattr(self.config.gpu, 'mlx_backend'):
231
+ if self.config.gpu.mlx_backend:
232
+ logger.info("MLX backend enabled in config, auto-converting models to MLX format")
233
+ convert_to_mlx = True
234
+
235
+ # Handle MLX conversion for HF models or local models
236
+ if convert_to_mlx or is_hf_repo:
237
+ # Check if this is a cached MLX model name
238
+ if "_4bit" in model_path or "_5bit" in model_path or "_8bit" in model_path or "_none" in model_path:
239
+ # This might be a cached MLX model, check if it exists
240
+ mlx_cache_dir = Path.home() / ".cortex" / "mlx_models"
241
+ cache_path = mlx_cache_dir / Path(model_path).name
242
+
243
+ if cache_path.exists() and cache_path.is_dir():
244
+ logger.info(f"Loading cached MLX model from {cache_path}")
245
+ success, result = self._load_mlx(cache_path, model_name or Path(model_path).name, {
246
+ 'format': ModelFormat.MLX,
247
+ 'quantization': QuantizationType.INT4 if "_4bit" in model_path else QuantizationType.INT8,
248
+ 'reason': 'Cached MLX model'
249
+ })
250
+
251
+ if success:
252
+ self.current_model = model_name or Path(model_path).name
253
+ # When loading from cache, don't update the config - it already has the right path
254
+ return True, f"Successfully loaded MLX model '{model_name or Path(model_path).name}'"
255
+ else:
256
+ return False, f"Failed to load MLX model: {result}"
257
+ else:
258
+ # Cached model not found, need to reconvert from original
259
+ # Try to extract original path from the cached name
260
+ base_name = model_path.replace("_4bit", "").replace("_5bit", "").replace("_8bit", "").replace("_none", "")
261
+
262
+ # Try to find the original model
263
+ if base_name.startswith("_Users_"):
264
+ # This is a local model path encoded in the name
265
+ original_path = "/" + base_name[1:].replace("_", "/")
266
+ original_path = Path(original_path).expanduser()
267
+
268
+ if original_path.exists():
269
+ logger.info(f"Found original model at {original_path}, will convert")
270
+ model_path = str(original_path)
271
+ # Continue with normal conversion flow
272
+ else:
273
+ return False, f"Cached MLX model not found and original model not found at {original_path}"
274
+ else:
275
+ return False, f"Cached MLX model not found at {cache_path}"
276
+
277
+ # Check if model is already in MLX format by looking ahead
278
+ test_path = Path(model_path).expanduser().resolve()
279
+ if test_path.exists() and test_path.is_dir():
280
+ # Check if it's in the mlx_models directory - if so, it's already converted
281
+ mlx_models_dir = Path.home() / ".cortex" / "mlx_models"
282
+ # Use proper path comparison
283
+ try:
284
+ is_in_mlx_dir = test_path.is_relative_to(mlx_models_dir)
285
+ except (ValueError, AttributeError):
286
+ # Fallback for older Python versions
287
+ is_in_mlx_dir = str(mlx_models_dir.resolve()) in str(test_path.resolve())
288
+
289
+ # Check for MLX format markers - include adapter files for fine-tuned models
290
+ has_mlx_weights = (test_path / 'weights.npz').exists() or (test_path / 'model.safetensors').exists()
291
+ has_config = (test_path / 'config.json').exists()
292
+ has_adapter = (test_path / 'adapter.safetensors').exists()
293
+ has_fine_tuned_marker = (test_path / 'fine_tuned.marker').exists()
294
+
295
+ # A model is MLX format if:
296
+ # 1. It's in the mlx_models directory, OR
297
+ # 2. It has MLX weights and config, OR
298
+ # 3. It's a fine-tuned model with adapters
299
+ if is_in_mlx_dir or (has_mlx_weights and has_config) or has_fine_tuned_marker or has_adapter:
300
+ # Already MLX format, skip conversion
301
+ logger.info(f"Model at {model_path} is already in MLX format, skipping conversion")
302
+ path = test_path
303
+
304
+ # Check if it's a fine-tuned model
305
+ format_info = {
306
+ 'format': ModelFormat.MLX,
307
+ 'quantization': QuantizationType.NONE, # Will be detected from model
308
+ 'reason': 'Existing MLX model'
309
+ }
310
+
311
+ # Check config for fine-tuning markers
312
+ config_path = path / "config.json"
313
+ if config_path.exists():
314
+ try:
315
+ with open(config_path, 'r') as f:
316
+ config = json.load(f)
317
+ if config.get('fine_tuned') or config.get('lora_adapter'):
318
+ format_info['is_fine_tuned'] = True
319
+ format_info['has_lora_adapter'] = True
320
+ logger.info("Detected fine-tuned model with LoRA adapters")
321
+ except Exception as e:
322
+ logger.warning(f"Could not read config to check fine-tuning status: {e}")
323
+
324
+ # Check for adapter files
325
+ if (path / "adapter.safetensors").exists() or (path / "adapter_config.json").exists():
326
+ format_info['has_lora_adapter'] = True
327
+ logger.info("Found LoRA adapter files")
328
+
329
+ # Load the existing MLX model directly
330
+ success, result = self._load_mlx(path, model_name or path.name, format_info)
331
+
332
+ if success:
333
+ self.current_model = model_name or path.name
334
+ # Store the original model path
335
+ self.config.update_last_used_model(str(model_path))
336
+ return True, f"Successfully loaded MLX model '{model_name or path.name}'"
337
+ else:
338
+ return False, f"Failed to load MLX model: {result}"
339
+ else:
340
+ # Needs conversion
341
+ # Determine quantization recipe - smart default based on model size
342
+ # First check model name for size hints
343
+ model_name_lower = str(test_path).lower()
344
+ if any(size in model_name_lower for size in ['270m', '350m', '500m']):
345
+ quant_recipe = QuantizationRecipe.NONE
346
+ logger.info(f"Very small model detected from name ({test_path.name}), skipping quantization")
347
+ print(f"Note: Very small model detected ({test_path.name}), skipping quantization for quality")
348
+ elif any(size in model_name_lower for size in ['1b', '2b', '3b']):
349
+ quant_recipe = QuantizationRecipe.QUALITY_8BIT
350
+ logger.info(f"Small model detected from name ({test_path.name}), using 8-bit quantization")
351
+ print(f"Note: Small model detected ({test_path.name}), using 8-bit quantization")
352
+ else:
353
+ # Use smart parameter detection for quantization decisions
354
+ try:
355
+ model_size_gb = self._get_model_size(test_path) / (1024**3)
356
+
357
+ # Use accurate parameter detection (returns billions)
358
+ actual_params_b = self.get_model_parameters_smart(test_path)
359
+ params_billions = float(actual_params_b) if actual_params_b is not None else (model_size_gb / 2.2)
360
+
361
+ logger.info(f"Model size: {model_size_gb:.2f}GB, parameters: {params_billions:.2f}B")
362
+
363
+ # For very small models like Gemma-270M, avoid quantization
364
+ if params_billions < 0.5:
365
+ quant_recipe = QuantizationRecipe.NONE
366
+ logger.info(f"Very small model detected ({self._format_param_count(params_billions)} params), skipping quantization")
367
+ print(f"Note: Very small model detected ({self._format_param_count(params_billions)} params), skipping quantization")
368
+ elif params_billions < 1.0:
369
+ quant_recipe = QuantizationRecipe.QUALITY_8BIT # Prefer 8bit for small models
370
+ logger.info(f"Small model detected ({self._format_param_count(params_billions)} params), using 8-bit quantization")
371
+ print(f"Note: Small model detected ({self._format_param_count(params_billions)} params), using 8-bit quantization")
372
+ else:
373
+ quant_recipe = QuantizationRecipe.SPEED_4BIT # Default for larger models
374
+ except Exception as e:
375
+ logger.warning(f"Could not estimate model parameters: {e}, defaulting to 4-bit")
376
+ quant_recipe = QuantizationRecipe.SPEED_4BIT # Fallback
377
+
378
+ if quantization:
379
+ quant_map = {
380
+ "4bit": QuantizationRecipe.SPEED_4BIT,
381
+ "5bit": QuantizationRecipe.BALANCED_5BIT,
382
+ "8bit": QuantizationRecipe.QUALITY_8BIT,
383
+ "mixed": QuantizationRecipe.MIXED_PRECISION,
384
+ "none": QuantizationRecipe.NONE
385
+ }
386
+ quant_recipe = quant_map.get(quantization, quant_recipe) # Use smart default if invalid
387
+ elif self._prefer_speed_quantization() and quant_recipe != QuantizationRecipe.NONE:
388
+ if quant_recipe != QuantizationRecipe.SPEED_4BIT:
389
+ logger.info("Max optimization enabled, using 4-bit quantization for best tokens/sec")
390
+ print("Note: Max optimization enabled, using 4-bit quantization for best tokens/sec")
391
+ quant_recipe = QuantizationRecipe.SPEED_4BIT
392
+
393
+ # Determine source format
394
+ source_format = ConversionFormat.HUGGINGFACE if is_hf_repo else ConversionFormat.SAFETENSORS
395
+
396
+ # For local SafeTensors models, ensure correct format detection
397
+ if not is_hf_repo and test_path.exists():
398
+ if any(f.suffix == '.safetensors' for f in test_path.glob('*.safetensors')):
399
+ source_format = ConversionFormat.SAFETENSORS
400
+ logger.info("Detected SafeTensors format for conversion to MLX")
401
+
402
+ # Convert to MLX
403
+ conversion_config = ConversionConfig(
404
+ source_format=source_format,
405
+ quantization=quant_recipe,
406
+ use_amx=True,
407
+ compile_model=True
408
+ )
409
+
410
+ logger.info(f"Converting model to MLX format with {quant_recipe.name} quantization...")
411
+ print(f"Converting model to MLX format for optimal performance...")
412
+
413
+ success, msg, mlx_path = self.mlx_converter.convert_model(
414
+ model_path,
415
+ output_name=model_name,
416
+ config=conversion_config
417
+ )
418
+
419
+ if not success:
420
+ return False, f"MLX conversion failed: {msg}"
421
+
422
+ # Update path to converted model
423
+ path = mlx_path
424
+ print(f"✓ Model converted to MLX format at {mlx_path}")
425
+ logger.info(f"Successfully converted model to MLX format at {mlx_path}")
426
+
427
+ # Now load the converted MLX model directly
428
+ success, result = self._load_mlx(path, model_name or path.name, {
429
+ 'format': ModelFormat.MLX,
430
+ 'quantization': self._get_quantization_type_from_recipe(quant_recipe),
431
+ 'reason': 'MLX converted model'
432
+ })
433
+
434
+ if success:
435
+ self.current_model = model_name or path.name
436
+ # Store the original model path
437
+ self.config.update_last_used_model(str(model_path))
438
+ return True, f"Successfully loaded MLX-converted model '{model_name or path.name}'"
439
+ else:
440
+ return False, f"Failed to load converted MLX model: {result}"
441
+ else:
442
+ # HuggingFace repo or non-existent path - needs conversion
443
+ # Determine quantization recipe - smart default for HF models too
444
+ # For HF models, we don't know the exact size, but we can use heuristics
445
+ model_name_lower = model_path.lower()
446
+ if any(size in model_name_lower for size in ['270m', '350m', '500m']):
447
+ quant_recipe = QuantizationRecipe.NONE # Very small models
448
+ logger.info(f"Very small model detected from name ({model_path}), skipping quantization")
449
+ elif any(size in model_name_lower for size in ['1b', '2b', '3b']):
450
+ quant_recipe = QuantizationRecipe.QUALITY_8BIT # Small models
451
+ logger.info(f"Small model detected from name ({model_path}), using 8-bit quantization")
452
+ else:
453
+ quant_recipe = QuantizationRecipe.SPEED_4BIT # Default for larger models
454
+
455
+ if quantization:
456
+ quant_map = {
457
+ "4bit": QuantizationRecipe.SPEED_4BIT,
458
+ "5bit": QuantizationRecipe.BALANCED_5BIT,
459
+ "8bit": QuantizationRecipe.QUALITY_8BIT,
460
+ "mixed": QuantizationRecipe.MIXED_PRECISION,
461
+ "none": QuantizationRecipe.NONE
462
+ }
463
+ quant_recipe = quant_map.get(quantization, quant_recipe)
464
+ elif self._prefer_speed_quantization() and quant_recipe != QuantizationRecipe.NONE:
465
+ if quant_recipe != QuantizationRecipe.SPEED_4BIT:
466
+ logger.info("Max optimization enabled, using 4-bit quantization for best tokens/sec")
467
+ print("Note: Max optimization enabled, using 4-bit quantization for best tokens/sec")
468
+ quant_recipe = QuantizationRecipe.SPEED_4BIT
469
+
470
+ # Convert to MLX
471
+ conversion_config = ConversionConfig(
472
+ source_format=ConversionFormat.HUGGINGFACE if is_hf_repo else ConversionFormat.SAFETENSORS,
473
+ quantization=quant_recipe,
474
+ use_amx=True,
475
+ compile_model=True
476
+ )
477
+
478
+ logger.info(f"Converting HF model to MLX format with {quant_recipe.name} quantization...")
479
+ print(f"Downloading and converting model to MLX format...")
480
+
481
+ success, msg, mlx_path = self.mlx_converter.convert_model(
482
+ model_path,
483
+ output_name=model_name,
484
+ config=conversion_config
485
+ )
486
+
487
+ if not success:
488
+ return False, f"MLX conversion failed: {msg}"
489
+
490
+ # Update path to converted model
491
+ path = mlx_path
492
+ print(f"✓ Model converted to MLX format at {mlx_path}")
493
+ logger.info(f"Successfully converted model to MLX format at {mlx_path}")
494
+
495
+ # Now load the converted MLX model directly
496
+ success, result = self._load_mlx(path, model_name or path.name, {
497
+ 'format': ModelFormat.MLX,
498
+ 'quantization': self._get_quantization_type_from_recipe(quant_recipe),
499
+ 'reason': 'MLX converted model'
500
+ })
501
+
502
+ if success:
503
+ self.current_model = model_name or path.name
504
+ # Store the original model path
505
+ self.config.update_last_used_model(str(model_path))
506
+ return True, f"Successfully loaded MLX-converted model '{model_name or path.name}'"
507
+ else:
508
+ return False, f"Failed to load converted MLX model: {result}"
509
+ else:
510
+ # Validate inputs - properly expand home directory
511
+ path = Path(model_path).expanduser().resolve()
512
+ if not path.exists():
513
+ # Try mlx-community models
514
+ if model_path.startswith("mlx-community/"):
515
+ # Download from HuggingFace mlx-community
516
+ success, msg, mlx_path = self.mlx_converter.convert_model(
517
+ model_path,
518
+ output_name=model_name,
519
+ config=ConversionConfig(quantization=QuantizationRecipe.NONE)
520
+ )
521
+ if success:
522
+ path = mlx_path
523
+ else:
524
+ return False, f"Failed to download MLX model: {msg}"
525
+ else:
526
+ return False, f"Model path does not exist: {model_path}"
527
+
528
+ # Use the full name for directories, stem for files
529
+ if path.is_dir():
530
+ model_name = model_name or path.name
531
+ else:
532
+ model_name = model_name or path.stem
533
+
534
+ # Check if already loaded
535
+ if model_name in self.loaded_models and not force_reload:
536
+ self.current_model = model_name
537
+ # Still update last used model even if already loaded
538
+ self.config.update_last_used_model(model_name)
539
+ return True, f"Model '{model_name}' already loaded"
540
+
541
+ # Check memory constraints
542
+ if len(self.loaded_models) >= self.config.model.max_loaded_models:
543
+ oldest = min(self.loaded_models.items(), key=lambda x: x[1].loaded_at)[0]
544
+ self.unload_model(oldest)
545
+
546
+ # Detect format and load
547
+ format_info = self._detect_format(path)
548
+ if format_info['format'] == ModelFormat.UNKNOWN:
549
+ return False, f"Unknown model format: {format_info['reason']}"
550
+
551
+ # Check GPU compatibility and determine if quantization is needed
552
+ model_size_gb = self._get_model_size(path) / (1024**3)
553
+ can_load, message = self.gpu_validator.verify_model_compatibility(model_size_gb)
554
+
555
+ # Determine if we need quantization (only for non-quantized models)
556
+ needs_quantization = False
557
+ quantization_mode = None
558
+
559
+ # Only apply dynamic quantization to non-quantized SafeTensors/PyTorch models
560
+ can_apply_quantization = (
561
+ format_info['format'] in [ModelFormat.SAFETENSORS, ModelFormat.PYTORCH] and
562
+ format_info['quantization'] == QuantizationType.NONE
563
+ )
564
+
565
+ if not can_load and can_apply_quantization:
566
+ # Check if quantization would help
567
+ gpu_status = self.gpu_validator.get_gpu_memory_status()
568
+ available_gb = gpu_status['available_gb']
569
+
570
+ # DEBUG: Uncomment to see memory calculations for quantization decisions
571
+ # print(f"DEBUG: Model size on disk: {model_size_gb:.1f}GB, Available memory: {available_gb:.1f}GB")
572
+
573
+ # Estimate if INT8 quantization would fit
574
+ # The "required" figure printed below reuses gpu_validator's 3.5x multiplier for consistency;
575
+ # INT8 is more stable than INT4, so prefer it when it fits
576
+ estimated_int8_size = model_size_gb * 0.5 * 2.5 # 50% reduction + 150% overhead (less conservative for INT8)
577
+
578
+ # DEBUG: Uncomment to see quantization size estimates
579
+ # print(f"DEBUG: INT8 estimated size: {estimated_int8_size:.1f}GB")
580
+
581
+ if estimated_int8_size <= available_gb:
582
+ needs_quantization = True
583
+ quantization_mode = 'int8'
584
+ required_with_overhead = model_size_gb * 3.5
585
+ print(f"Model requires {required_with_overhead:.1f}GB (including overhead), only {available_gb:.1f}GB available.")
586
+ print(f"Will apply INT8 quantization to reduce to ~{estimated_int8_size:.1f}GB")
587
+ else:
588
+ # Try INT4 as last resort
589
+ estimated_int4_size = model_size_gb * 0.25 * 3.5 # 75% reduction + 250% overhead
590
+
591
+ # DEBUG: Uncomment to see INT4 quantization estimates
592
+ # print(f"DEBUG: INT4 estimated size: {estimated_int4_size:.1f}GB")
593
+
594
+ if estimated_int4_size <= available_gb:
595
+ needs_quantization = True
596
+ quantization_mode = 'int4'
597
+ required_with_overhead = model_size_gb * 3.5
598
+ print(f"Model requires {required_with_overhead:.1f}GB (including overhead), only {available_gb:.1f}GB available.")
599
+ print(f"Will apply INT4 quantization to reduce to ~{estimated_int4_size:.1f}GB")
600
+ else:
601
+ return False, f"Model too large even with quantization: {message}"
602
+ elif not can_load:
603
+ # Can't apply quantization to this format
604
+ return False, f"GPU incompatible: {message}"
605
+
606
+ # Load based on format
607
+ loader_map = {
608
+ ModelFormat.MLX: self._load_mlx,
609
+ ModelFormat.GGUF: self._load_gguf,
610
+ ModelFormat.SAFETENSORS: self._load_safetensors,
611
+ ModelFormat.PYTORCH: self._load_pytorch,
612
+ ModelFormat.QUANTIZED: self._load_quantized
613
+ }
614
+
615
+ loader = loader_map.get(format_info['format'])
616
+ if not loader:
617
+ return False, f"No loader for format: {format_info['format'].value}"
618
+
619
+ try:
620
+ # Pass quantization info to loaders that support it
621
+ if format_info['format'] in [ModelFormat.SAFETENSORS, ModelFormat.PYTORCH]:
622
+ success, result = loader(path, model_name, format_info, needs_quantization, quantization_mode)
623
+ else:
624
+ success, result = loader(path, model_name, format_info)
625
+
626
+ if success:
627
+ self.current_model = model_name
628
+ # Save the last used model to config
629
+ self.config.update_last_used_model(model_name)
630
+ return True, f"Successfully loaded '{model_name}'"
631
+ else:
632
+ return False, result
633
+ except Exception as e:
634
+ return False, f"Error loading model: {str(e)}"
635
+
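For reference, a hedged sketch of calling load_model as documented above, followed by the memory arithmetic the loader applies when deciding whether dynamic quantization could make an otherwise-too-large model fit. The 3.5x, 0.5 x 2.5 and 0.25 x 3.5 factors come from the code; the 14 GB on-disk size and the model path are assumed examples, and the Config/GPUValidator wiring may differ in practice:

from cortex.config import Config
from cortex.gpu_validator import GPUValidator
from cortex.model_manager import ModelManager

# Assumed wiring; constructor arguments for Config and GPUValidator may differ.
manager = ModelManager(Config(), GPUValidator())
ok, msg = manager.load_model(
    "~/models/Example-8B",    # local path or HuggingFace repo id (hypothetical)
    quantization="4bit",      # one of: "4bit", "5bit", "8bit", "mixed", "none"
)
print(ok, msg)

# Overhead estimates for an assumed 14 GB FP16 model on disk:
disk_gb = 14.0
print(disk_gb * 3.5)         # 49.0  GB needed unquantized, including overhead
print(disk_gb * 0.5 * 2.5)   # 17.5  GB estimated after INT8 dynamic quantization
print(disk_gb * 0.25 * 3.5)  # 12.25 GB estimated after INT4 dynamic quantization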
636
+ def _detect_format(self, path: Path) -> Dict[str, Any]:
637
+ """
638
+ Detect model format from path.
639
+
640
+ Returns:
641
+ Dict with 'format', 'quantization', and 'reason'
642
+ """
643
+ # Check for specific file types
644
+ if path.is_file():
645
+ if path.suffix.lower() == '.gguf':
646
+ return {
647
+ 'format': ModelFormat.GGUF,
648
+ 'quantization': QuantizationType.Q4_K_M,
649
+ 'reason': 'GGUF file'
650
+ }
651
+
652
+ # Check for directory-based formats
653
+ if path.is_dir():
654
+ # Check for MLX format - support both regular MLX and fine-tuned MLX models
655
+ has_config = (path / 'config.json').exists()
656
+ has_weights_npz = (path / 'weights.npz').exists()
657
+ has_safetensors = any(path.glob('*.safetensors'))
658
+ has_fine_tuned_marker = (path / 'fine_tuned.marker').exists()
659
+
660
+ if has_config and (has_weights_npz or has_safetensors):
661
+ # Detect if it's a fine-tuned model with LoRA adapters
662
+ has_adapter = (path / 'adapter.safetensors').exists()
663
+
664
+ return {
665
+ 'format': ModelFormat.MLX,
666
+ 'quantization': QuantizationType.NONE,
667
+ 'reason': 'MLX model with LoRA adapters' if has_adapter else 'MLX model (weights + config)',
668
+ 'has_lora_adapter': has_adapter,
669
+ 'is_fine_tuned': has_fine_tuned_marker
670
+ }
671
+
672
+ # Check for SafeTensors format
673
+ safetensor_files = list(path.glob('*.safetensors'))
674
+ if safetensor_files and (path / 'config.json').exists():
675
+ # Check if it's quantized by looking at the config
676
+ quantization = self._detect_quantization(path)
677
+ if quantization in [QuantizationType.GPTQ, QuantizationType.AWQ, QuantizationType.INT4, QuantizationType.INT8]:
678
+ return {
679
+ 'format': ModelFormat.QUANTIZED,
680
+ 'quantization': quantization,
681
+ 'reason': f'Quantized model ({quantization.value})'
682
+ }
683
+ else:
684
+ return {
685
+ 'format': ModelFormat.SAFETENSORS,
686
+ 'quantization': quantization,
687
+ 'reason': 'SafeTensors model'
688
+ }
689
+
690
+ # Check for PyTorch format
691
+ if (path / 'pytorch_model.bin').exists() or list(path.glob('pytorch_model*.bin')):
692
+ return {
693
+ 'format': ModelFormat.PYTORCH,
694
+ 'quantization': QuantizationType.NONE,
695
+ 'reason': 'PyTorch model'
696
+ }
697
+
698
+ return {
699
+ 'format': ModelFormat.UNKNOWN,
700
+ 'quantization': QuantizationType.NONE,
701
+ 'reason': 'No recognized model files found'
702
+ }
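A compact summary of the layouts _detect_format distinguishes, per the file-presence checks above; the example path is hypothetical and `manager` is assumed to be a ModelManager instance as in the earlier sketch:

from pathlib import Path

# file with a .gguf suffix                          -> ModelFormat.GGUF
# dir with config.json + weights.npz/*.safetensors  -> ModelFormat.MLX
#   (this check runs first, so it also captures plain SafeTensors directories;
#    the SAFETENSORS/QUANTIZED branches apply only when it does not match)
# dir with pytorch_model*.bin                       -> ModelFormat.PYTORCH
# anything else                                     -> ModelFormat.UNKNOWN
info = manager._detect_format(Path("~/models/example-model").expanduser())
print(info["format"], info["quantization"], info["reason"])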
703
+
704
+ def _detect_quantization(self, path: Path) -> QuantizationType:
705
+ """Detect quantization type from model files."""
706
+ # Check config.json for quantization info
707
+ config_path = path / 'config.json'
708
+ if config_path.exists():
709
+ try:
710
+ with open(config_path) as f:
711
+ config = json.load(f)
712
+
713
+ # Check for quantization config
714
+ if 'quantization_config' in config:
715
+ quant_config = config['quantization_config']
716
+ if 'quant_method' in quant_config:
717
+ method = quant_config['quant_method'].upper()
718
+ if 'GPTQ' in method:
719
+ return QuantizationType.GPTQ
720
+ elif 'AWQ' in method:
721
+ return QuantizationType.AWQ
722
+ if 'bits' in quant_config:
723
+ bits = quant_config['bits']
724
+ if bits == 4:
725
+ return QuantizationType.INT4
726
+ elif bits == 8:
727
+ return QuantizationType.INT8
728
+
729
+ # Check model name for hints (be careful with model size indicators like 4B = 4 billion)
730
+ model_name = str(path.name).upper()
731
+ # Only detect as quantized if explicitly mentioned
732
+ if '4BIT' in model_name or 'INT4' in model_name or 'GPTQ-4' in model_name:
733
+ return QuantizationType.INT4
734
+ elif '8BIT' in model_name or 'INT8' in model_name or 'GPTQ-8' in model_name:
735
+ return QuantizationType.INT8
736
+ elif 'GPTQ' in model_name:
737
+ return QuantizationType.GPTQ
738
+ elif 'AWQ' in model_name:
739
+ return QuantizationType.AWQ
740
+ except Exception:
741
+ pass
742
+
743
+ # Check for quantization-specific files
744
+ safetensor_files = list(path.glob('*.safetensors'))
745
+ if safetensor_files:
746
+ # Load one file to check for quantization tensors
747
+ try:
748
+ from safetensors.torch import load_file
749
+ sample = load_file(safetensor_files[0], device='cpu')
750
+ # Check for GPTQ/AWQ specific tensors
751
+ has_scales = any('.scales' in k for k in sample.keys())
752
+ has_qweight = any('.qweight' in k for k in sample.keys())
753
+
754
+ if has_scales or has_qweight:
755
+ return QuantizationType.GPTQ
756
+ except Exception:
757
+ pass
758
+
759
+ return QuantizationType.NONE
760
+
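For illustration, a config.json fragment of the shape _detect_quantization inspects; the exact keys a real GPTQ export writes may differ, so treat this as an assumed example:

import json
from pathlib import Path

demo_dir = Path("/tmp/demo-gptq-model")          # hypothetical location
demo_dir.mkdir(parents=True, exist_ok=True)
demo_dir.joinpath("config.json").write_text(json.dumps({
    "model_type": "llama",
    "quantization_config": {"quant_method": "gptq", "bits": 4, "group_size": 128},
}))
# With this file in place the detector returns QuantizationType.GPTQ: the
# quant_method check above runs before the bits == 4 -> INT4 check.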
761
+ def _get_quantization_type_from_recipe(self, recipe: QuantizationRecipe) -> QuantizationType:
762
+ """Convert MLX quantization recipe to QuantizationType."""
763
+ recipe_to_type = {
764
+ QuantizationRecipe.SPEED_4BIT: QuantizationType.INT4,
765
+ QuantizationRecipe.BALANCED_5BIT: QuantizationType.INT4, # Closest match
766
+ QuantizationRecipe.QUALITY_8BIT: QuantizationType.INT8,
767
+ QuantizationRecipe.MIXED_PRECISION: QuantizationType.INT4,
768
+ QuantizationRecipe.NONE: QuantizationType.NONE
769
+ }
770
+ return recipe_to_type.get(recipe, QuantizationType.INT4)
771
+
772
+ def _format_param_count(self, params_b: Optional[float]) -> str:
773
+ """Format a parameter count in billions as a human-readable string (M/B/T)."""
774
+ try:
775
+ if params_b is None:
776
+ return "unknown"
777
+ # Trillions
778
+ if params_b >= 1000:
779
+ return f"{params_b / 1000:.1f}T"
780
+ # Billions
781
+ if params_b >= 1:
782
+ return f"{params_b:.1f}B"
783
+ # Millions (10M - 999M)
784
+ if params_b >= 0.01:
785
+ return f"{params_b * 1000:.0f}M"
786
+ # Low millions (1M - 9.9M)
787
+ if params_b >= 0.001:
788
+ return f"{params_b * 1000:.1f}M"
789
+ # Thousands
790
+ if params_b > 0:
791
+ return f"{params_b * 1e6:.0f}K"
792
+ return "0"
793
+ except Exception:
794
+ # Fallback formatting
795
+ try:
796
+ return f"{float(params_b):.2f}B"
797
+ except Exception:
798
+ return "unknown"
799
+
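A few worked examples of _format_param_count, whose input is a count in billions; `manager` is again assumed to be a ModelManager instance:

for value in (0.00027, 0.0042, 0.27, 7.0, 1800.0, None):
    print(value, "->", manager._format_param_count(value))
# Expected output per the thresholds above:
# 0.00027 -> 270K, 0.0042 -> 4.2M, 0.27 -> 270M, 7.0 -> 7.0B, 1800.0 -> 1.8T,
# None -> unknown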
800
+ def _get_model_size(self, path: Path) -> int:
801
+ """Get total size of model files in bytes."""
802
+ if path.is_file():
803
+ return path.stat().st_size
804
+ elif path.is_dir():
805
+ total = 0
806
+ # Check if this is a fine-tuned model with LoRA adapters
807
+ is_finetuned = (path / 'fine_tuned.marker').exists() or (path / 'adapter.safetensors').exists()
808
+
809
+ for file in path.rglob('*'):
810
+ if file.is_file():
811
+ # Skip training checkpoint files for fine-tuned models
812
+ if is_finetuned and file.name.endswith('_adapters.safetensors'):
813
+ # These are intermediate checkpoints, not needed for inference
814
+ continue
815
+ # Skip cache and git files
816
+ if '/.cache/' in str(file) or '/.git/' in str(file):
817
+ continue
818
+ total += file.stat().st_size
819
+ return total
820
+ return 0
821
+
822
+ def _load_mlx(self, path: Path, model_name: str, format_info: Dict) -> Tuple[bool, str]:
823
+ """Load MLX format model with optimizations and LoRA adapter support."""
824
+ try:
825
+ if mlx_load is None:
826
+ return False, "MLX LM library not available. Install with: pip install mlx-lm"
827
+
828
+ # Check if this is a fine-tuned model with LoRA adapters
829
+ has_adapter = format_info.get('has_lora_adapter', False)
830
+ is_fine_tuned = format_info.get('is_fine_tuned', False)
831
+
832
+ if has_adapter or is_fine_tuned:
833
+ logger.info(f"Loading fine-tuned MLX model with LoRA adapters: {model_name}")
834
+
835
+ # For fine-tuned models, we need to load the base model and apply adapters
836
+ try:
837
+ # Probe for MLX LoRA tooling; the import is only a capability check, since the adapters are already merged into the saved weights
838
+ from mlx_lm.tuner.utils import apply_lora_layers
839
+
840
+ # Load the model (should include merged weights)
841
+ model, tokenizer = mlx_load(str(path))
842
+
843
+ # The model should already have adapters merged since we saved it that way
844
+ logger.info("Fine-tuned MLX model loaded with integrated LoRA weights")
845
+
846
+ except ImportError:
847
+ # Fallback to regular loading
848
+ logger.warning("MLX LoRA utilities not available, loading as regular MLX model")
849
+ model, tokenizer = mlx_load(str(path))
850
+ else:
851
+ # Regular MLX model loading
852
+ model, tokenizer = mlx_load(str(path))
853
+
854
+ # Apply MLX accelerator optimizations if available
855
+ if self.mlx_accelerator:
856
+ # Silently apply optimizations - details shown in CLI
857
+ logger.info("Applying MLX optimizations (AMX, operation fusion)...")
858
+ model = self.mlx_accelerator.optimize_model(model)
859
+
860
+ # MLX LM models are already quantized during conversion
861
+ # No need to apply additional quantization
862
+ logger.info("MLX model already optimized with quantization")
863
+
864
+ # Evaluate parameters if they exist
865
+ if hasattr(model, 'parameters'):
866
+ mx.eval(model.parameters())
867
+
868
+ self.model_cache[model_name] = model
869
+ self.tokenizers[model_name] = tokenizer
870
+
871
+ # Load config
872
+ config = {}
873
+ config_path = path / 'config.json'
874
+ if config_path.exists():
875
+ with open(config_path) as f:
876
+ config = json.load(f)
877
+
878
+ # Create model info with accurate parameter detection
879
+ parameters = self.get_model_parameters_smart(path)
880
+
881
+ model_info = ModelInfo(
882
+ name=model_name,
883
+ path=path,
884
+ format=ModelFormat.MLX,
885
+ quantization=format_info['quantization'],
886
+ size_bytes=self._get_model_size(path),
887
+ parameters=parameters,
888
+ context_length=config.get('max_position_embeddings', 4096),
889
+ loaded_at=datetime.now(),
890
+ gpu_memory_used=self._estimate_mlx_memory(model),
891
+ tokenizer_path=path / 'tokenizer.json',
892
+ config=config
893
+ )
894
+
895
+ self.loaded_models[model_name] = model_info
896
+ return True, "MLX model loaded successfully"
897
+
898
+ except Exception as e:
899
+ return False, f"Failed to load MLX model: {str(e)}"
900
+
901
+ def _load_gguf(self, path: Path, model_name: str, format_info: Dict) -> Tuple[bool, str]:
902
+ """Load GGUF format model using llama-cpp-python."""
903
+ try:
904
+ from llama_cpp import Llama
905
+
906
+ print("Loading GGUF model with llama.cpp...")
907
+
908
+ # Determine optimal parameters for Apple Silicon
909
+ n_gpu_layers = -1 # Use all layers on GPU
910
+ n_ctx = 4096 # Context size
911
+ n_batch = 512 # Batch size for prompt processing
912
+
913
+ # Load the model
914
+ model = Llama(
915
+ model_path=str(path),
916
+ n_gpu_layers=n_gpu_layers,
917
+ n_ctx=n_ctx,
918
+ n_batch=n_batch,
919
+ n_threads=8, # Use 8 threads for CPU operations
920
+ use_mlock=True, # Lock model in RAM
921
+ verbose=False
922
+ )
923
+
924
+ # Create a simple tokenizer wrapper for compatibility
925
+ class GGUFTokenizer:
926
+ def __init__(self, model):
927
+ self.model = model
928
+ self.pad_token = None
929
+ self.eos_token = None
930
+
931
+ def encode(self, text):
932
+ return self.model.tokenize(text.encode('utf-8'))
933
+
934
+ def decode(self, tokens):
935
+ return self.model.detokenize(tokens).decode('utf-8')
936
+
937
+ tokenizer = GGUFTokenizer(model)
938
+
939
+ self.model_cache[model_name] = model
940
+ self.tokenizers[model_name] = tokenizer
941
+
942
+ # Create model info
943
+ # Get model parameters from the model object if available
944
+ try:
945
+ # Try to get parameters from model metadata
946
+ n_params = getattr(model, 'n_params', 0)
947
+ if n_params == 0:
948
+ # Estimate based on model size
949
+ n_params = self._get_model_size(path) // 2 # Rough estimate
950
+ except Exception:
951
+ n_params = self._get_model_size(path) // 2
952
+
953
+ model_info = ModelInfo(
954
+ name=model_name,
955
+ path=path,
956
+ format=ModelFormat.GGUF,
957
+ quantization=format_info['quantization'],
958
+ size_bytes=self._get_model_size(path),
959
+ parameters=n_params,
960
+ context_length=n_ctx,
961
+ loaded_at=datetime.now(),
962
+ gpu_memory_used=self._get_model_size(path), # GGUF loads full model
963
+ tokenizer_path=None,
964
+ config={'n_ctx': n_ctx, 'n_gpu_layers': n_gpu_layers}
965
+ )
966
+
967
+ self.loaded_models[model_name] = model_info
968
+ return True, "GGUF model loaded successfully"
969
+
970
+ except ImportError:
971
+ return False, "llama-cpp-python not installed. Install with: pip install llama-cpp-python"
972
+ except Exception as e:
973
+ return False, f"Failed to load GGUF model: {str(e)}"
974
+
975
+ def _load_safetensors(
976
+ self,
977
+ path: Path,
978
+ model_name: str,
979
+ format_info: Dict,
980
+ needs_quantization: bool = False,
981
+ quantization_mode: Optional[str] = None
982
+ ) -> Tuple[bool, str]:
983
+ """Load standard SafeTensors model with optional quantization."""
984
+ try:
985
+ device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
986
+
987
+ # Check for cached quantized model first BEFORE loading anything
988
+ if needs_quantization and self.quantizer.config.cache_quantized:
989
+ cached = self.quantizer.load_cached_model(path, self.quantizer.config)
990
+ if cached:
991
+ print(f"Loading cached quantized model...")
992
+ # Load to CPU first with minimal memory usage
993
+ print(f"Creating model structure...")
994
+ with torch.device('cpu'):
995
+ model = AutoModelForCausalLM.from_pretrained(
996
+ str(path),
997
+ torch_dtype=torch.float16,
998
+ low_cpu_mem_usage=True,
999
+ trust_remote_code=True,
1000
+ device_map={'': 'cpu'} # Force CPU loading
1001
+ )
1002
+
1003
+ print(f"Applying cached quantized weights...")
1004
+ model.load_state_dict(cached[0])
1005
+
1006
+ # Now move to target device
1007
+ print(f"Moving to {device}...")
1008
+ model = model.to(device)
1009
+ quantization_info = cached[1]
1010
+ print(f"Quantized model loaded from cache")
1011
+ else:
1012
+ # Load and quantize
1013
+ print(f"Loading model for quantization...")
1014
+ model = AutoModelForCausalLM.from_pretrained(
1015
+ str(path),
1016
+ torch_dtype=torch.float16,
1017
+ low_cpu_mem_usage=True,
1018
+ trust_remote_code=True
1019
+ )
1020
+
1021
+ if needs_quantization:
1022
+ print(f"Applying {quantization_mode} quantization...")
1023
+ gpu_status = self.gpu_validator.get_gpu_memory_status()
1024
+ model, quantization_info = self.quantizer.quantize_model(
1025
+ model,
1026
+ target_dtype=quantization_mode,
1027
+ available_memory_gb=gpu_status['available_gb'],
1028
+ model_size_gb=self._get_model_size(path) / (1024**3),
1029
+ target_device=device # Pass the target device (MPS)
1030
+ )
1031
+
1032
+ # Cache the quantized model
1033
+ if self.quantizer.config.cache_quantized:
1034
+ cache_path = self.quantizer.cache_quantized_model(model, path, quantization_info)
1035
+ print(f"Cached quantized model for faster future loads")
1036
+
1037
+ model = model.to(device)
1038
+ else:
1039
+ # Normal loading without quantization
1040
+ model = AutoModelForCausalLM.from_pretrained(
1041
+ str(path),
1042
+ torch_dtype=torch.float16,
1043
+ low_cpu_mem_usage=True,
1044
+ trust_remote_code=True
1045
+ )
1046
+ model = model.to(device)
1047
+ model.eval() # Set model to evaluation mode
1048
+
1049
+ # Load tokenizer
1050
+ tokenizer = AutoTokenizer.from_pretrained(str(path), use_fast=True)
1051
+ if tokenizer.pad_token is None:
1052
+ tokenizer.pad_token = tokenizer.eos_token
1053
+
1054
+ self.model_cache[model_name] = model
1055
+ self.tokenizers[model_name] = tokenizer
1056
+
1057
+ # Load config
1058
+ config = {}
1059
+ config_path = path / 'config.json'
1060
+ if config_path.exists():
1061
+ with open(config_path) as f:
1062
+ config = json.load(f)
1063
+
1064
+ # Create model info with quantization details if applicable
1065
+ if needs_quantization:
1066
+ # Update quantization type based on what was applied
1067
+ actual_quantization = QuantizationType.INT8 if quantization_mode == 'int8' else QuantizationType.INT4
1068
+ # Calculate actual memory used after quantization
1069
+ memory_used = sum(
1070
+ p.numel() * 1 if hasattr(p, 'numel') else 0 # assume ~1 byte per element after quantization (vs 2 for FP16)
1071
+ for p in model.parameters()
1072
+ )
1073
+ else:
1074
+ actual_quantization = format_info['quantization']
1075
+ memory_used = sum(p.element_size() * p.numel() for p in model.parameters())
1076
+
1077
+ # Use smart parameter detection instead of counting loaded model parameters
1078
+ parameters = self.get_model_parameters_smart(path)
1079
+
1080
+ model_info = ModelInfo(
1081
+ name=model_name,
1082
+ path=path,
1083
+ format=ModelFormat.SAFETENSORS,
1084
+ quantization=actual_quantization,
1085
+ size_bytes=self._get_model_size(path),
1086
+ parameters=parameters,
1087
+ context_length=config.get('max_position_embeddings', 4096),
1088
+ loaded_at=datetime.now(),
1089
+ gpu_memory_used=memory_used,
1090
+ tokenizer_path=path / 'tokenizer.json',
1091
+ config=config
1092
+ )
1093
+
1094
+ self.loaded_models[model_name] = model_info
1095
+ return True, "SafeTensors model loaded successfully"
1096
+
1097
+ except Exception as e:
1098
+ return False, f"Failed to load SafeTensors model: {str(e)}"
1099
+
1100
+ def _load_pytorch(
1101
+ self,
1102
+ path: Path,
1103
+ model_name: str,
1104
+ format_info: Dict,
1105
+ needs_quantization: bool = False,
1106
+ quantization_mode: Optional[str] = None
1107
+ ) -> Tuple[bool, str]:
1108
+ """Load PyTorch format model with optional quantization."""
1109
+ try:
1110
+ device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
1111
+
1112
+ # Check for cached quantized model first
1113
+ if needs_quantization and self.quantizer.config.cache_quantized:
1114
+ cached = self.quantizer.load_cached_model(path, self.quantizer.config)
1115
+ if cached:
1116
+ print(f"Loading cached quantized model...")
1117
+ model = AutoModelForCausalLM.from_pretrained(
1118
+ str(path),
1119
+ torch_dtype=torch.float16,
1120
+ low_cpu_mem_usage=True,
1121
+ trust_remote_code=True
1122
+ )
1123
+ model.load_state_dict(cached[0])
1124
+ model = model.to(device)
1125
+ quantization_info = cached[1]
1126
+ else:
1127
+ # Load and quantize
1128
+ print(f"Loading model for quantization...")
1129
+ model = AutoModelForCausalLM.from_pretrained(
1130
+ str(path),
1131
+ torch_dtype=torch.float16,
1132
+ low_cpu_mem_usage=True,
1133
+ trust_remote_code=True
1134
+ )
1135
+
1136
+ if needs_quantization:
1137
+ print(f"Applying {quantization_mode} quantization...")
1138
+ gpu_status = self.gpu_validator.get_gpu_memory_status()
1139
+ model, quantization_info = self.quantizer.quantize_model(
1140
+ model,
1141
+ target_dtype=quantization_mode,
1142
+ available_memory_gb=gpu_status['available_gb'],
1143
+ model_size_gb=self._get_model_size(path) / (1024**3),
1144
+ target_device=device # Pass the target device (MPS)
1145
+ )
1146
+
1147
+ # Cache the quantized model
1148
+ if self.quantizer.config.cache_quantized:
1149
+ cache_path = self.quantizer.cache_quantized_model(model, path, quantization_info)
1150
+ print(f"Cached quantized model for faster future loads")
1151
+
1152
+ model = model.to(device)
1153
+ else:
1154
+ # Normal loading without quantization
1155
+ model = AutoModelForCausalLM.from_pretrained(
1156
+ str(path),
1157
+ torch_dtype=torch.float16,
1158
+ low_cpu_mem_usage=True,
1159
+ trust_remote_code=True
1160
+ )
1161
+ model = model.to(device)
1162
+ model.eval() # Set model to evaluation mode
1163
+
1164
+ # Load tokenizer
1165
+ tokenizer = AutoTokenizer.from_pretrained(str(path), use_fast=True)
1166
+ if tokenizer.pad_token is None:
1167
+ tokenizer.pad_token = tokenizer.eos_token
1168
+
1169
+ self.model_cache[model_name] = model
1170
+ self.tokenizers[model_name] = tokenizer
1171
+
1172
+ # Get config
1173
+ config = model.config.to_dict() if hasattr(model, 'config') else {}
1174
+
1175
+ # Create model info with quantization details if applicable
1176
+ if needs_quantization:
1177
+ # Update quantization type based on what was applied
1178
+ actual_quantization = QuantizationType.INT8 if quantization_mode == 'int8' else QuantizationType.INT4
1179
+ # Calculate actual memory used after quantization
1180
+ memory_used = sum(
1181
+ p.numel() * 1 if hasattr(p, 'numel') else 0 # assume ~1 byte per element after quantization (vs 2 for FP16)
1182
+ for p in model.parameters()
1183
+ )
1184
+ else:
1185
+ actual_quantization = format_info['quantization']
1186
+ memory_used = sum(p.element_size() * p.numel() for p in model.parameters())
1187
+
1188
+ # Use smart parameter detection instead of counting loaded model parameters
1189
+ parameters = self.get_model_parameters_smart(path)
1190
+
1191
+ model_info = ModelInfo(
1192
+ name=model_name,
1193
+ path=path,
1194
+ format=ModelFormat.PYTORCH,
1195
+ quantization=actual_quantization,
1196
+ size_bytes=self._get_model_size(path),
1197
+ parameters=parameters,
1198
+ context_length=config.get('max_position_embeddings', 4096),
1199
+ loaded_at=datetime.now(),
1200
+ gpu_memory_used=memory_used,
1201
+ tokenizer_path=None,
1202
+ config=config
1203
+ )
1204
+
1205
+ self.loaded_models[model_name] = model_info
1206
+ return True, "PyTorch model loaded successfully"
1207
+
1208
+ except Exception as e:
1209
+ return False, f"Failed to load PyTorch model: {str(e)}"
1210
+
1211
+ def _load_quantized(self, path: Path, model_name: str, format_info: Dict) -> Tuple[bool, str]:
1212
+ """Load quantized model (GPTQ/AWQ/etc) using appropriate libraries."""
1213
+ quant_type = format_info['quantization']
1214
+ device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
1215
+
1216
+ print(f"Loading {quant_type.value} quantized model...")
1217
+
1218
+ # Try different quantization libraries based on detected type
1219
+ if quant_type in [QuantizationType.GPTQ, QuantizationType.INT4]:
1220
+ # Try GPTQ loader first
1221
+ try:
1222
+ from auto_gptq import AutoGPTQForCausalLM
1223
+
1224
+ model = AutoGPTQForCausalLM.from_quantized(
1225
+ str(path),
1226
+ device="cuda:0" if torch.cuda.is_available() else "cpu", # GPTQ doesn't support MPS directly
1227
+ use_safetensors=True,
1228
+ trust_remote_code=True,
1229
+ inject_fused_attention=False, # Disable for compatibility
1230
+ inject_fused_mlp=False
1231
+ )
1232
+
1233
+ # Move to MPS if needed
1234
+ if device.type == "mps" and not torch.cuda.is_available():
1235
+ model = model.cpu() # GPTQ models may need CPU fallback on Mac
1236
+ print("Note: GPTQ model loaded on CPU (MPS not fully supported)")
1237
+
1238
+ tokenizer = AutoTokenizer.from_pretrained(str(path), use_fast=True)
1239
+ if tokenizer.pad_token is None:
1240
+ tokenizer.pad_token = tokenizer.eos_token
1241
+
1242
+ self.model_cache[model_name] = model
1243
+ self.tokenizers[model_name] = tokenizer
1244
+
1245
+ config = self._load_config(path)
1246
+ model_info = self._create_model_info(
1247
+ model_name, path, ModelFormat.QUANTIZED, quant_type,
1248
+ model, config
1249
+ )
1250
+
1251
+ self.loaded_models[model_name] = model_info
1252
+ return True, f"GPTQ quantized model loaded successfully"
1253
+
1254
+ except ImportError:
1255
+ print("GPTQ library not available, trying alternative methods...")
1256
+ except Exception as e:
1257
+ print(f"GPTQ loading failed: {str(e)[:100]}")
1258
+
1259
+ if quant_type == QuantizationType.AWQ:
1260
+ # Try AWQ loader
1261
+ try:
1262
+ from awq import AutoAWQForCausalLM
1263
+
1264
+ model = AutoAWQForCausalLM.from_quantized(
1265
+ str(path),
1266
+ fuse_layers=False, # Disable for compatibility
1267
+ trust_remote_code=True
1268
+ )
1269
+
1270
+ tokenizer = AutoTokenizer.from_pretrained(str(path), use_fast=True)
1271
+ if tokenizer.pad_token is None:
1272
+ tokenizer.pad_token = tokenizer.eos_token
1273
+
1274
+ self.model_cache[model_name] = model
1275
+ self.tokenizers[model_name] = tokenizer
1276
+
1277
+ config = self._load_config(path)
1278
+ model_info = self._create_model_info(
1279
+ model_name, path, ModelFormat.QUANTIZED, quant_type,
1280
+ model, config
1281
+ )
1282
+
1283
+ self.loaded_models[model_name] = model_info
1284
+ return True, f"AWQ quantized model loaded successfully"
1285
+
1286
+ except ImportError:
1287
+ print("AWQ library not available, trying alternative methods...")
1288
+ except Exception as e:
1289
+ print(f"AWQ loading failed: {str(e)[:100]}")
1290
+
1291
+ # Try using accelerate for general quantized models
1292
+ try:
1293
+ from accelerate import init_empty_weights, load_checkpoint_and_dispatch
1294
+ from transformers import AutoConfig
1295
+
1296
+ print("Attempting to load with accelerate library...")
1297
+
1298
+ # Load config first
1299
+ config = AutoConfig.from_pretrained(str(path), trust_remote_code=True)
1300
+
1301
+ # Initialize model with empty weights
1302
+ with init_empty_weights():
1303
+ model = AutoModelForCausalLM.from_config(
1304
+ config,
1305
+ torch_dtype=torch.float16,
1306
+ trust_remote_code=True
1307
+ )
1308
+
1309
+ # Determine the checkpoint files
1310
+ checkpoint_files = list(path.glob("*.safetensors"))
1311
+ if not checkpoint_files:
1312
+ checkpoint_files = list(path.glob("pytorch_model*.bin"))
1313
+
1314
+ if not checkpoint_files:
1315
+ raise ValueError("No model files found")
1316
+
1317
+ # Create proper device map for MPS
1318
+ if device.type == "mps":
1319
+ device_map = {"": "cpu"} # Load to CPU first for MPS compatibility
1320
+ else:
1321
+ device_map = "auto"
1322
+
1323
+ # Load and dispatch to device
1324
+ model = load_checkpoint_and_dispatch(
1325
+ model,
1326
+ checkpoint=str(path), # Directory containing the model files
1327
+ device_map=device_map,
1328
+ dtype=torch.float16,
1329
+ offload_folder=str(self.config.model.model_cache_dir / "offload")
1330
+ )
1331
+
1332
+ # Move to MPS if needed
1333
+ if device.type == "mps" and device_map == {"": "cpu"}:
1334
+ model = model.to(device)
1335
+
1336
+ tokenizer = AutoTokenizer.from_pretrained(str(path))
1337
+ if tokenizer.pad_token is None:
1338
+ tokenizer.pad_token = tokenizer.eos_token
1339
+
1340
+ self.model_cache[model_name] = model
1341
+ self.tokenizers[model_name] = tokenizer
1342
+
1343
+ config_dict = self._load_config(path)
1344
+ model_info = self._create_model_info(
1345
+ model_name, path, ModelFormat.QUANTIZED, quant_type,
1346
+ model, config_dict
1347
+ )
1348
+
1349
+ self.loaded_models[model_name] = model_info
1350
+ return True, f"Quantized model loaded with accelerate"
1351
+
1352
+ except Exception as e:
1353
+ print(f"Accelerate loading failed: {str(e)[:100]}")
1354
+
1355
+ # Try bitsandbytes for 4-bit/8-bit models
1356
+ try:
1357
+ from transformers import BitsAndBytesConfig
1358
+
1359
+ print("Attempting to load with bitsandbytes quantization...")
1360
+
1361
+ bnb_config = BitsAndBytesConfig(
1362
+ load_in_4bit=True if quant_type == QuantizationType.INT4 else False,
1363
+ load_in_8bit=True if quant_type == QuantizationType.INT8 else False,
1364
+ bnb_4bit_compute_dtype=torch.float16,
1365
+ bnb_4bit_use_double_quant=True,
1366
+ bnb_4bit_quant_type="nf4"
1367
+ )
1368
+
1369
+ model = AutoModelForCausalLM.from_pretrained(
1370
+ str(path),
1371
+ quantization_config=bnb_config,
1372
+ torch_dtype=torch.float16,
1373
+ trust_remote_code=True,
1374
+ device_map="auto" if torch.cuda.is_available() else {"": device}
1375
+ )
1376
+
1377
+ tokenizer = AutoTokenizer.from_pretrained(str(path))
1378
+ if tokenizer.pad_token is None:
1379
+ tokenizer.pad_token = tokenizer.eos_token
1380
+
1381
+ self.model_cache[model_name] = model
1382
+ self.tokenizers[model_name] = tokenizer
1383
+
1384
+ config = self._load_config(path)
1385
+ model_info = self._create_model_info(
1386
+ model_name, path, ModelFormat.QUANTIZED, quant_type,
1387
+ model, config
1388
+ )
1389
+
1390
+ self.loaded_models[model_name] = model_info
1391
+ return True, f"Quantized model loaded with bitsandbytes"
1392
+
1393
+ except Exception as e:
1394
+ print(f"Bitsandbytes loading failed: {str(e)[:100]}")
1395
+
1396
+ # If all methods fail, provide guidance
1397
+ return False, f"Failed to load {quant_type.value} quantized model. The model format may not be compatible with Apple Silicon."
1398
+
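The accelerate path above is the standard meta-device pattern: build the module skeleton with init_empty_weights() so no memory is allocated, then materialize and place the shards with load_checkpoint_and_dispatch. A minimal standalone sketch of that pattern, with a placeholder checkpoint path that is not taken from this package:

import torch
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
from transformers import AutoConfig, AutoModelForCausalLM

model_dir = "/path/to/checkpoint"  # placeholder directory containing config + shards

config = AutoConfig.from_pretrained(model_dir)
with init_empty_weights():
    # Parameters live on the "meta" device here: zero bytes allocated.
    model = AutoModelForCausalLM.from_config(config, torch_dtype=torch.float16)

# Stream shards from disk and place layers according to the device map.
model = load_checkpoint_and_dispatch(
    model, checkpoint=model_dir, device_map="auto", dtype=torch.float16
)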
1399
+ def _load_config(self, path: Path) -> Dict[str, Any]:
1400
+ """Load config.json from model path."""
1401
+ config_path = path / 'config.json'
1402
+ if config_path.exists():
1403
+ with open(config_path) as f:
1404
+ return json.load(f)
1405
+ return {}
1406
+
1407
+ def _create_model_info(
1408
+ self,
1409
+ model_name: str,
1410
+ path: Path,
1411
+ format: ModelFormat,
1412
+ quantization: QuantizationType,
1413
+ model: Any,
1414
+ config: Dict[str, Any]
1415
+ ) -> ModelInfo:
1416
+ """Create ModelInfo object for a loaded model."""
1417
+ # Use smart parameter detection rather than counting the loaded model's parameters
1418
+ parameters = self.get_model_parameters_smart(path)
1419
+
1420
+ # Calculate memory usage
1421
+ if hasattr(model, 'parameters'):
1422
+ try:
1423
+ memory_used = sum(p.element_size() * p.numel() for p in model.parameters())
1424
+ except Exception:
1425
+ memory_used = self._get_model_size(path)
1426
+ else:
1427
+ memory_used = self._get_model_size(path)
1428
+
1429
+ return ModelInfo(
1430
+ name=model_name,
1431
+ path=path,
1432
+ format=format,
1433
+ quantization=quantization,
1434
+ size_bytes=self._get_model_size(path),
1435
+ parameters=parameters,
1436
+ context_length=config.get('max_position_embeddings', 4096),
1437
+ loaded_at=datetime.now(),
1438
+ gpu_memory_used=memory_used,
1439
+ tokenizer_path=path / 'tokenizer.json' if (path / 'tokenizer.json').exists() else None,
1440
+ config=config
1441
+ )
1442
+
1443
+ def _count_mlx_parameters(self, model: Any) -> int:
1444
+ """Count parameters in MLX model."""
1445
+ try:
1446
+ if hasattr(model, 'num_parameters'):
1447
+ return model.num_parameters()
1448
+ elif hasattr(model, 'parameters'):
1449
+ params = model.parameters()
1450
+ if isinstance(params, dict):
1451
+ return sum(p.size for p in params.values())
1452
+ return 0
1453
+ except Exception:
1454
+ return 0
1455
+
1456
+ def _estimate_mlx_memory(self, model: Any) -> int:
1457
+ """Estimate memory usage of MLX model."""
1458
+ try:
1459
+ if hasattr(model, 'parameters'):
1460
+ params = model.parameters()
1461
+ if isinstance(params, dict):
1462
+ return sum(p.nbytes if hasattr(p, 'nbytes') else 0 for p in params.values())
1463
+ return 0
1464
+ except Exception:
1465
+ return 0
1466
+
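MLX modules usually return a nested dictionary from parameters(), so the two helpers above may only see the top level of that tree. A hedged alternative, assuming mlx.utils.tree_flatten is available (it flattens nested containers into (name, array) pairs); verify against the installed MLX version:

from mlx.utils import tree_flatten

def count_mlx_parameters(model) -> int:
    # Flatten the nested parameter tree, then sum element counts.
    return sum(array.size for _, array in tree_flatten(model.parameters()))

def estimate_mlx_memory(model) -> int:
    # nbytes reflects the array's stored dtype.
    return sum(getattr(array, "nbytes", 0) for _, array in tree_flatten(model.parameters()))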
1467
+ def unload_model(self, model_name: str) -> Tuple[bool, str]:
1468
+ """Unload a model from memory."""
1469
+ if model_name not in self.loaded_models:
1470
+ return False, f"Model '{model_name}' not loaded"
1471
+
1472
+ try:
1473
+ # Special cleanup for GGUF models (llama-cpp-python)
1474
+ if model_name in self.model_cache:
1475
+ model = self.model_cache[model_name]
1476
+ model_info = self.loaded_models.get(model_name)
1477
+
1478
+ # Clean up GGUF models properly to avoid memory leaks
1479
+ if model_info and model_info.format == ModelFormat.GGUF:
1480
+ try:
1481
+ # llama-cpp-python models have a close() method
1482
+ if hasattr(model, 'close'):
1483
+ model.close()
1484
+ # Also try to explicitly delete the model
1485
+ del model
1486
+ except Exception as e:
1487
+ print(f"Warning: Error closing GGUF model: {e}")
1488
+
1489
+ # Remove from cache
1490
+ del self.model_cache[model_name]
1491
+
1492
+ # Remove tokenizer
1493
+ if model_name in self.tokenizers:
1494
+ del self.tokenizers[model_name]
1495
+
1496
+ # Remove model info
1497
+ del self.loaded_models[model_name]
1498
+
1499
+ # Update current model
1500
+ if self.current_model == model_name:
1501
+ self.current_model = None
1502
+
1503
+ # Clear GPU cache
1504
+ if torch.cuda.is_available():
1505
+ torch.cuda.empty_cache()
1506
+ elif torch.backends.mps.is_available():
1507
+ # Clear MPS cache on Apple Silicon
1508
+ try:
1509
+ torch.mps.empty_cache()
1510
+ except Exception:
1511
+ pass # MPS cache clearing might not be available in all versions
1512
+
1513
+ # Force garbage collection for thorough cleanup
1514
+ import gc
1515
+ gc.collect()
1516
+
1517
+ return True, f"Model '{model_name}' unloaded"
1518
+
1519
+ except Exception as e:
1520
+ return False, f"Error unloading model: {str(e)}"
1521
+
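The cleanup above combines three steps that each matter on their own: drop Python references, clear the backend allocator cache, and force a garbage-collection pass. A compact, backend-agnostic sketch of that idiom (hedged: torch.mps.empty_cache exists only in recent PyTorch releases):

import gc
import torch

def release_accelerator_memory() -> None:
    # Collect unreachable objects first so cached allocations can actually be freed.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    elif torch.backends.mps.is_available():
        try:
            torch.mps.empty_cache()  # not present in older PyTorch builds
        except AttributeError:
            pass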
1522
+ def list_models(self) -> List[Dict[str, Any]]:
1523
+ """List all loaded models."""
1524
+ return [model.to_dict() for model in self.loaded_models.values()]
1525
+
1526
+ def discover_available_models(self) -> List[Dict[str, Any]]:
1527
+ """Discover all available models including MLX converted ones."""
1528
+ available = []
1529
+ model_path = self.config.model.model_path.expanduser().resolve()
1530
+
1531
+ if not model_path.exists():
1532
+ return available
1533
+
1534
+ # Also check MLX models directory for fine-tuned models
1535
+ mlx_path = Path.home() / ".cortex" / "mlx_models"
1536
+
1537
+ # First, get all MLX converted models to check for optimized versions
1538
+ mlx_models = self.mlx_converter.list_converted_models()
1539
+ mlx_cache_map = {} # Map original paths to MLX versions
1540
+
1541
+ # Build a map of original model paths to their MLX versions
1542
+ for name, info in mlx_models.items():
1543
+ # Extract original path from MLX model name
1544
+ if name.startswith("_Users_") and ("_4bit" in name or "_5bit" in name or "_8bit" in name):
1545
+ base_name = name.replace("_4bit", "").replace("_5bit", "").replace("_8bit", "")
1546
+ original_path = "/" + base_name[1:].replace("_", "/")
1547
+ mlx_cache_map[original_path] = {
1548
+ 'mlx_name': name,
1549
+ 'mlx_path': info['path'],
1550
+ 'quantization': info.get('quantization', 4),
1551
+ 'size_gb': info.get('size_gb', 0)
1552
+ }
1553
+
1554
+ # Search for models in the models directory
1555
+ for item in model_path.iterdir():
1556
+ if item.is_file() and item.suffix == '.gguf':
1557
+ # GGUF file
1558
+ size_gb = item.stat().st_size / (1024**3)
1559
+ available.append({
1560
+ 'name': item.stem,
1561
+ 'path': str(item),
1562
+ 'relative_path': item.name,
1563
+ 'format': 'GGUF',
1564
+ 'size_gb': round(size_gb, 2),
1565
+ 'mlx_optimized': False,
1566
+ 'mlx_available': False
1567
+ })
1568
+ elif item.is_dir():
1569
+ # Check if it's a model directory
1570
+ format_info = self._detect_format(item)
1571
+ if format_info['format'] != ModelFormat.UNKNOWN:
1572
+ # Check if this model has an MLX optimized version
1573
+ full_path = str(item.resolve())
1574
+ has_mlx = full_path in mlx_cache_map
1575
+
1576
+ if has_mlx:
1577
+ # Use the MLX version's info
1578
+ mlx_info = mlx_cache_map[full_path]
1579
+ available.append({
1580
+ 'name': item.name,
1581
+ 'path': str(item), # Keep original path for compatibility
1582
+ 'relative_path': item.name,
1583
+ 'format': 'MLX-optimized',
1584
+ 'size_gb': round(mlx_info['size_gb'], 2) if mlx_info['size_gb'] > 0 else round(self._get_model_size(item) / (1024**3), 2),
1585
+ 'mlx_optimized': True,
1586
+ 'mlx_path': mlx_info['mlx_path'],
1587
+ 'mlx_name': mlx_info['mlx_name'],
1588
+ 'quantization': f"{mlx_info['quantization']}-bit",
1589
+ 'original_format': format_info['format'].value.upper()
1590
+ })
1591
+ else:
1592
+ # Regular model without MLX optimization
1593
+ size_gb = self._get_model_size(item) / (1024**3)
1594
+ format_str = format_info['format'].value.upper()
1595
+ if format_info['quantization'] != QuantizationType.NONE:
1596
+ format_str = f"{format_str} ({format_info['quantization'].value})"
1597
+
1598
+ available.append({
1599
+ 'name': item.name,
1600
+ 'path': str(item),
1601
+ 'relative_path': item.name,
1602
+ 'format': format_str,
1603
+ 'size_gb': round(size_gb, 2),
1604
+ 'mlx_optimized': False,
1605
+ 'mlx_available': format_info['format'] in [
1606
+ ModelFormat.SAFETENSORS,
1607
+ ModelFormat.PYTORCH
1608
+ ]
1609
+ })
1610
+
1611
+ # Add fine-tuned models from MLX directory that aren't already included
1612
+ if mlx_path.exists():
1613
+ for item in mlx_path.iterdir():
1614
+ if item.is_dir():
1615
+ # Check if it's a fine-tuned model that's not already in the list
1616
+ if (item / 'fine_tuned.marker').exists():
1617
+ # This is a fine-tuned model
1618
+ already_added = any(model['name'] == item.name for model in available)
1619
+ if not already_added:
1620
+ size_gb = self._get_model_size(item) / (1024**3)
1621
+
1622
+ # Read metadata from marker file
1623
+ base_model = "Unknown"
1624
+ try:
1625
+ with open(item / 'fine_tuned.marker', 'r') as f:
1626
+ content = f.read()
1627
+ for line in content.split('\n'):
1628
+ if 'LoRA fine-tuned version of' in line:
1629
+ base_model = line.split('LoRA fine-tuned version of ')[-1].strip()
1630
+ break
1631
+ except Exception:
1632
+ pass
1633
+
1634
+ available.append({
1635
+ 'name': item.name,
1636
+ 'path': str(item),
1637
+ 'relative_path': item.name,
1638
+ 'format': 'MLX Fine-tuned',
1639
+ 'size_gb': round(size_gb, 2),
1640
+ 'mlx_optimized': True,
1641
+ 'is_fine_tuned': True,
1642
+ 'base_model': base_model,
1643
+ 'fine_tuning_method': 'LoRA'
1644
+ })
1645
+
1646
+ # Sort by name
1647
+ available.sort(key=lambda x: x['name'].lower())
1648
+ return available
1649
+
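A hypothetical rendering of the dictionaries this method returns; manager stands in for a ModelManager instance constructed elsewhere in the package:

def print_available_models(manager) -> None:
    for entry in manager.discover_available_models():
        label = entry['format']
        if entry.get('mlx_optimized') and 'quantization' in entry:
            label += f" ({entry['quantization']})"
        print(f"{entry['name']:<40} {label:<24} {entry['size_gb']:.2f} GB")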
1650
+ def get_current_model(self) -> Optional[ModelInfo]:
1651
+ """Get currently active model."""
1652
+ if self.current_model:
1653
+ return self.loaded_models.get(self.current_model)
1654
+ return None
1655
+
1656
+ def switch_model(self, model_name: str) -> Tuple[bool, str]:
1657
+ """Switch to a different loaded model."""
1658
+ if model_name not in self.loaded_models:
1659
+ return False, f"Model '{model_name}' not loaded"
1660
+
1661
+ self.current_model = model_name
1662
+ return True, f"Switched to model '{model_name}'"
1663
+
1664
+ def get_memory_status(self) -> Dict[str, Any]:
1665
+ """Get GPU memory status with MLX details."""
1666
+ status = self.gpu_validator.get_gpu_memory_status()
1667
+
1668
+ total_model_memory = sum(
1669
+ model.gpu_memory_used for model in self.loaded_models.values()
1670
+ )
1671
+
1672
+ status['models_loaded'] = len(self.loaded_models)
1673
+ status['model_memory_gb'] = total_model_memory / (1024**3)
1674
+ status['current_model'] = self.current_model
1675
+
1676
+ # Add MLX-specific info
1677
+ mlx_models = [m for m in self.loaded_models.values() if m.format == ModelFormat.MLX]
1678
+ if mlx_models:
1679
+ status['mlx_models'] = len(mlx_models)
1680
+ status['mlx_memory_gb'] = sum(m.gpu_memory_used for m in mlx_models) / (1024**3)
1681
+
1682
+ if self.memory_pool:
1683
+ pool_stats = self.memory_pool.get_stats()
1684
+ status['memory_pool'] = {
1685
+ 'allocated_gb': pool_stats['allocated_memory'] / (1024**3),
1686
+ 'free_gb': pool_stats['free_memory'] / (1024**3),
1687
+ 'fragmentation': pool_stats['fragmentation'],
1688
+ 'total_blocks': pool_stats['total_blocks'],
1689
+ 'zero_copy_enabled': pool_stats.get('zero_copy', False)
1690
+ }
1691
+
1692
+ # Add MLX accelerator status
1693
+ if self.mlx_accelerator:
1694
+ status['mlx_acceleration'] = {
1695
+ 'amx_enabled': self.mlx_accelerator.config.use_amx,
1696
+ 'operation_fusion': self.mlx_accelerator.config.fuse_operations,
1697
+ 'kv_cache_size': self.mlx_accelerator.config.kv_cache_size,
1698
+ 'quantization_bits': self.mlx_accelerator.config.quantization_bits
1699
+ }
1700
+
1701
+ return status
1702
+
1703
+ def detect_model_parameters(self, model_path: Path) -> Optional[int]:
1704
+ """
1705
+ Detect the actual number of parameters in a model.
1706
+
1707
+ Uses proper parameter counting that handles quantization, LoRA adapters,
1708
+ and non-weight files correctly.
1709
+
1710
+ Returns:
1711
+ Number of parameters, or None if detection fails
1712
+ """
1713
+ try:
1714
+ # Check cache first
1715
+ cache_key = f"{model_path.resolve()}:{model_path.stat().st_mtime}"
1716
+ cached_result = self._get_cached_parameter_count(cache_key)
1717
+ if cached_result is not None:
1718
+ logger.debug(f"Using cached parameter count: {cached_result:,}")
1719
+ return cached_result
1720
+
1721
+ # Detect model format and apply appropriate detection method
1722
+ format_info = self._detect_format(model_path)
1723
+ param_count = None
1724
+
1725
+ if format_info['format'] == ModelFormat.SAFETENSORS:
1726
+ param_count = self._detect_safetensors_parameters(model_path)
1727
+ elif format_info['format'] == ModelFormat.MLX:
1728
+ param_count = self._detect_mlx_parameters(model_path)
1729
+ elif format_info['format'] == ModelFormat.PYTORCH:
1730
+ param_count = self._detect_pytorch_parameters(model_path)
1731
+
1732
+ # Fallback to config.json analysis
1733
+ if param_count is None:
1734
+ param_count = self._detect_config_parameters(model_path)
1735
+
1736
+ # Cache the result if successful
1737
+ if param_count is not None:
1738
+ self._cache_parameter_count(cache_key, param_count)
1739
+ logger.info(f"Detected {param_count:,} parameters in {model_path.name}")
1740
+ else:
1741
+ logger.warning(f"Could not detect parameters for {model_path.name}")
1742
+
1743
+ return param_count
1744
+
1745
+ except Exception as e:
1746
+ logger.warning(f"Parameter detection failed for {model_path}: {e}")
1747
+ return None
1748
+
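A small usage sketch; manager is an assumed ModelManager instance and the path is a placeholder. Repeat calls on an unchanged directory are served from the mtime-keyed cache below.

from pathlib import Path

def report_parameters(manager, model_dir: Path) -> None:
    count = manager.detect_model_parameters(model_dir)
    if count is None:
        print(f"{model_dir.name}: detection failed; size-based estimation will be used")
    else:
        print(f"{model_dir.name}: {count / 1e9:.2f}B parameters")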
1749
+ def _get_cached_parameter_count(self, cache_key: str) -> Optional[int]:
1750
+ """Get cached parameter count."""
1751
+ cache_file = self.config.model.model_cache_dir / "parameter_counts.json"
1752
+ if not cache_file.exists():
1753
+ return None
1754
+
1755
+ try:
1756
+ with open(cache_file, 'r') as f:
1757
+ cache = json.load(f)
1758
+ return cache.get(cache_key)
1759
+ except Exception:
1760
+ return None
1761
+
1762
+ def _cache_parameter_count(self, cache_key: str, param_count: int) -> None:
1763
+ """Cache parameter count for faster future lookups."""
1764
+ cache_file = self.config.model.model_cache_dir / "parameter_counts.json"
1765
+
1766
+ # Load existing cache
1767
+ cache = {}
1768
+ if cache_file.exists():
1769
+ try:
1770
+ with open(cache_file, 'r') as f:
1771
+ cache = json.load(f)
1772
+ except Exception:
1773
+ pass
1774
+
1775
+ # Update cache
1776
+ cache[cache_key] = param_count
1777
+
1778
+ # Keep the cache bounded (retain at most 100 entries)
1779
+ if len(cache) > 100:
1780
+ sorted_items = sorted(cache.items(), key=lambda x: x[0])[-100:]
1781
+ cache = dict(sorted_items)
1782
+
1783
+ # Save cache
1784
+ try:
1785
+ cache_file.parent.mkdir(parents=True, exist_ok=True)
1786
+ with open(cache_file, 'w') as f:
1787
+ json.dump(cache, f)
1788
+ except Exception as e:
1789
+ logger.warning(f"Failed to cache parameter count: {e}")
1790
+
1791
+ def _detect_safetensors_parameters(self, model_path: Path) -> Optional[int]:
1792
+ """Detect parameters by reading SafeTensors headers."""
1793
+ try:
1794
+ safetensor_files = list(model_path.glob("*.safetensors"))
1795
+ if not safetensor_files:
1796
+ return None
1797
+
1798
+ total_params = 0
1799
+
1800
+ for st_file in safetensor_files:
1801
+ # Skip adapter files for base model parameter counting
1802
+ if "adapter" in st_file.name.lower():
1803
+ continue
1804
+
1805
+ # Read SafeTensors header to get tensor shapes without loading weights
1806
+ params = self._read_safetensors_header(st_file)
1807
+ if params is not None:
1808
+ total_params += params
1809
+ else:
1810
+ logger.warning(f"Could not read SafeTensors header from {st_file.name}")
1811
+ return None
1812
+
1813
+ return total_params if total_params > 0 else None
1814
+
1815
+ except Exception as e:
1816
+ logger.warning(f"SafeTensors parameter detection failed: {e}")
1817
+ return None
1818
+
1819
+ def _read_safetensors_header(self, file_path: Path) -> Optional[int]:
1820
+ """Read parameter count from SafeTensors file header without loading the full file."""
1821
+ try:
1822
+ with open(file_path, 'rb') as f:
1823
+ # Read the header length (first 8 bytes)
1824
+ header_size_bytes = f.read(8)
1825
+ if len(header_size_bytes) < 8:
1826
+ return None
1827
+
1828
+ header_size = struct.unpack('<Q', header_size_bytes)[0]
1829
+
1830
+ # Read the header JSON
1831
+ header_json = f.read(header_size).decode('utf-8')
1832
+ header = json.loads(header_json)
1833
+
1834
+ # Count parameters from tensor shapes
1835
+ total_params = 0
1836
+ for tensor_name, tensor_info in header.items():
1837
+ if tensor_name == "__metadata__":
1838
+ continue
1839
+
1840
+ # Skip non-parameter tensors (buffers, etc.)
1841
+ if self._is_parameter_tensor(tensor_name):
1842
+ shape = tensor_info.get('shape', [])
1843
+ if shape:
1844
+ tensor_params = 1
1845
+ for dim in shape:
1846
+ tensor_params *= dim
1847
+ total_params += tensor_params
1848
+
1849
+ return total_params
1850
+
1851
+ except Exception as e:
1852
+ logger.debug(f"Failed to read SafeTensors header from {file_path}: {e}")
1853
+ return None
1854
+
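The layout read here is the published SafeTensors container format: an unsigned 64-bit little-endian header length followed by a JSON table of dtypes, shapes and byte offsets. If the safetensors package is installed, the same shapes can be read through its lazy API; a hedged cross-check sketch (method names as of safetensors 0.3+, verify against the installed version):

from safetensors import safe_open

def count_params_lazily(file_path: str) -> int:
    total = 0
    with safe_open(file_path, framework="pt") as f:
        for name in f.keys():
            shape = f.get_slice(name).get_shape()  # shapes only; weights stay on disk
            n = 1
            for dim in shape:
                n *= dim
            total += n
    return total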
1855
+ def _is_parameter_tensor(self, tensor_name: str) -> bool:
1856
+ """Check if a tensor name represents a model parameter (not a buffer)."""
1857
+ # Common parameter patterns
1858
+ param_patterns = [
1859
+ 'weight', 'bias', 'embeddings', 'lm_head',
1860
+ 'q_proj', 'k_proj', 'v_proj', 'o_proj',
1861
+ 'gate_proj', 'up_proj', 'down_proj',
1862
+ 'fc1', 'fc2', 'mlp', 'attention'
1863
+ ]
1864
+
1865
+ # Common non-parameter patterns (buffers)
1866
+ non_param_patterns = [
1867
+ 'position_ids', 'attention_mask', 'token_type_ids',
1868
+ 'freqs_cos', 'freqs_sin', 'inv_freq'
1869
+ ]
1870
+
1871
+ tensor_lower = tensor_name.lower()
1872
+
1873
+ # Check non-parameter patterns first
1874
+ for pattern in non_param_patterns:
1875
+ if pattern in tensor_lower:
1876
+ return False
1877
+
1878
+ # Check parameter patterns
1879
+ for pattern in param_patterns:
1880
+ if pattern in tensor_lower:
1881
+ return True
1882
+
1883
+ # Default: assume it's a parameter if it contains common layer indicators
1884
+ return any(indicator in tensor_lower for indicator in ['layer', 'block', 'transformer'])
1885
+
1886
+ def _detect_mlx_parameters(self, model_path: Path) -> Optional[int]:
1887
+ """Detect parameters in MLX models by inspecting weights.npz or using config."""
1888
+ try:
1889
+ # First try to read from weights.npz directly
1890
+ weights_file = model_path / "weights.npz"
1891
+ if weights_file.exists():
1892
+ import numpy as np
1893
+
1894
+ # Load the weights file
1895
+ weights = np.load(weights_file)
1896
+ total_params = 0
1897
+
1898
+ for array_name in weights.files:
1899
+ if self._is_parameter_tensor(array_name):
1900
+ array = weights[array_name]
1901
+ total_params += array.size
1902
+
1903
+ return total_params if total_params > 0 else None
1904
+
1905
+ # Fallback to checking for SafeTensors in MLX directory
1906
+ safetensor_files = list(model_path.glob("*.safetensors"))
1907
+ if safetensor_files:
1908
+ return self._detect_safetensors_parameters(model_path)
1909
+
1910
+ return None
1911
+
1912
+ except Exception as e:
1913
+ logger.warning(f"MLX parameter detection failed: {e}")
1914
+ return None
1915
+
1916
+ def _detect_pytorch_parameters(self, model_path: Path) -> Optional[int]:
1917
+ """Detect parameters in PyTorch models."""
1918
+ try:
1919
+ # For PyTorch models, we need to load the config to get architecture info
1920
+ # as loading the full model would be too expensive
1921
+ return self._detect_config_parameters(model_path)
1922
+
1923
+ except Exception as e:
1924
+ logger.warning(f"PyTorch parameter detection failed: {e}")
1925
+ return None
1926
+
1927
+ def _detect_config_parameters(self, model_path: Path) -> Optional[int]:
1928
+ """Detect parameters by analyzing config.json and calculating from architecture."""
1929
+ try:
1930
+ config_path = model_path / "config.json"
1931
+ if not config_path.exists():
1932
+ return None
1933
+
1934
+ with open(config_path, 'r') as f:
1935
+ config = json.load(f)
1936
+
1937
+ # Check for directly specified parameter count
1938
+ if 'num_parameters' in config:
1939
+ return int(config['num_parameters'])
1940
+
1941
+ # Calculate from architecture parameters
1942
+ model_type = config.get('model_type', '').lower()
1943
+
1944
+ if model_type in ['llama', 'gemma', 'mistral', 'qwen']:
1945
+ return self._calculate_llama_parameters(config)
1946
+ elif model_type in ['gpt', 'gpt2', 'gpt_neo', 'gpt_neox']:
1947
+ return self._calculate_gpt_parameters(config)
1948
+ elif model_type in ['bert', 'roberta', 'distilbert']:
1949
+ return self._calculate_bert_parameters(config)
1950
+ else:
1951
+ # Generic calculation for transformer models
1952
+ return self._calculate_generic_transformer_parameters(config)
1953
+
1954
+ except Exception as e:
1955
+ logger.warning(f"Config parameter detection failed: {e}")
1956
+ return None
1957
+
1958
+ def _calculate_llama_parameters(self, config: Dict) -> Optional[int]:
1959
+ """Calculate parameters for Llama-style models (including Gemma)."""
1960
+ try:
1961
+ vocab_size = config.get('vocab_size', 32000)
1962
+ hidden_size = config.get('hidden_size', 4096)
1963
+ intermediate_size = config.get('intermediate_size', 11008)
1964
+ num_layers = config.get('num_hidden_layers', 32)
1965
+ num_attention_heads = config.get('num_attention_heads', 32)
1966
+
1967
+ # Check if this is a Gemma model for special handling
1968
+ is_gemma = config.get('model_type', '').lower() == 'gemma'
1969
+
1970
+ # Embedding layer
1971
+ embedding_params = vocab_size * hidden_size
1972
+
1973
+ # Each transformer layer:
1974
+ if is_gemma:
1975
+ # Gemma uses grouped query attention with fewer k/v heads
1976
+ num_key_value_heads = config.get('num_key_value_heads', num_attention_heads // 4)
1977
+ head_dim = hidden_size // num_attention_heads
1978
+
1979
+ # Attention projections
1980
+ q_proj_params = hidden_size * hidden_size # Full size
1981
+ k_proj_params = hidden_size * (num_key_value_heads * head_dim) # Reduced
1982
+ v_proj_params = hidden_size * (num_key_value_heads * head_dim) # Reduced
1983
+ o_proj_params = hidden_size * hidden_size # Full size
1984
+
1985
+ attention_params = q_proj_params + k_proj_params + v_proj_params + o_proj_params
1986
+ else:
1987
+ # Standard Llama: q, k, v, o projections all full size
1988
+ attention_params = 4 * (hidden_size * hidden_size)
1989
+
1990
+ # Feed-forward: gate_proj, up_proj, down_proj
1991
+ ff_params = 2 * (hidden_size * intermediate_size) + (intermediate_size * hidden_size)
1992
+
1993
+ # Layer norms (2 per layer)
1994
+ ln_params = 2 * hidden_size
1995
+
1996
+ layer_params = attention_params + ff_params + ln_params
1997
+ transformer_params = num_layers * layer_params
1998
+
1999
+ # Final layer norm
2000
+ final_ln_params = hidden_size
2001
+
2002
+ # LM head - check if tied to embeddings (common in smaller models)
2003
+ tie_word_embeddings = config.get('tie_word_embeddings', True) # Default True for most models
2004
+ if tie_word_embeddings:
2005
+ lm_head_params = 0 # Tied to embeddings, don't double count
2006
+ else:
2007
+ lm_head_params = vocab_size * hidden_size
2008
+
2009
+ total = embedding_params + transformer_params + final_ln_params + lm_head_params
2010
+
2011
+ return total
2012
+
2013
+ except Exception as e:
2014
+ logger.warning(f"Llama parameter calculation failed: {e}")
2015
+ return None
2016
+
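As a sanity check, plugging the published Llama-2-7B configuration into this formula (vocab_size 32000, hidden_size 4096, intermediate_size 11008, 32 layers, full multi-head attention, untied embeddings) reproduces the model's 6.74B parameter count:

vocab_size, hidden, intermediate, layers = 32000, 4096, 11008, 32

embedding = vocab_size * hidden                      # 131,072,000
attention = 4 * hidden * hidden                      # 67,108,864 per layer
feed_forward = 3 * hidden * intermediate             # 135,266,304 per layer
layer_norms = 2 * hidden                             # 8,192 per layer
per_layer = attention + feed_forward + layer_norms   # 202,383,360
lm_head = vocab_size * hidden                        # untied, so counted separately

total = embedding + layers * per_layer + hidden + lm_head
print(f"{total:,}")  # 6,738,415,616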
2017
+ def _calculate_gpt_parameters(self, config: Dict) -> Optional[int]:
2018
+ """Calculate parameters for GPT-style models."""
2019
+ try:
2020
+ vocab_size = config.get('vocab_size', 50257)
2021
+ n_embd = config.get('n_embd', config.get('hidden_size', 768))
2022
+ n_layer = config.get('n_layer', config.get('num_hidden_layers', 12))
2023
+ n_head = config.get('n_head', config.get('num_attention_heads', 12))
2024
+
2025
+ # Token + position embeddings
2026
+ max_position_embeddings = config.get('n_positions', config.get('max_position_embeddings', 1024))
2027
+ embedding_params = vocab_size * n_embd + max_position_embeddings * n_embd
2028
+
2029
+ # Each transformer block
2030
+ # - Attention: qkv projection + output projection
2031
+ attention_params = 4 * (n_embd * n_embd)
2032
+
2033
+ # - MLP: typically 4x expansion
2034
+ mlp_size = config.get('n_inner') or 4 * n_embd  # n_inner is often stored as null
2035
+ mlp_params = n_embd * mlp_size + mlp_size * n_embd
2036
+
2037
+ # - Layer norms
2038
+ ln_params = 2 * n_embd
2039
+
2040
+ block_params = attention_params + mlp_params + ln_params
2041
+ transformer_params = n_layer * block_params
2042
+
2043
+ # Final layer norm + LM head
2044
+ final_ln_params = n_embd
2045
+ lm_head_params = vocab_size * n_embd
2046
+
2047
+ total = embedding_params + transformer_params + final_ln_params + lm_head_params
2048
+
2049
+ return total
2050
+
2051
+ except Exception as e:
2052
+ logger.warning(f"GPT parameter calculation failed: {e}")
2053
+ return None
2054
+
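For GPT-2 small (vocab_size 50257, n_embd 768, n_layer 12, n_positions 1024, 4x MLP) the same arithmetic gives about 163M, noticeably above the commonly quoted 124M: GPT-2 ties its LM head to the token embeddings while this formula counts the head separately, and bias terms are ignored, so the result should be read as an estimate:

vocab, n_embd, n_layer, n_pos = 50257, 768, 12, 1024

embeddings = vocab * n_embd + n_pos * n_embd    # token + position tables
attention = 4 * n_embd * n_embd                 # qkv + output projections
mlp = 2 * n_embd * (4 * n_embd)                 # up- and down-projections
norms = 2 * n_embd
blocks = n_layer * (attention + mlp + norms)

total = embeddings + blocks + n_embd + vocab * n_embd
print(f"{total:,}")  # 162,935,040 (untied head; biases not counted)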
2055
+ def _calculate_bert_parameters(self, config: Dict) -> Optional[int]:
2056
+ """Calculate parameters for BERT-style models."""
2057
+ try:
2058
+ vocab_size = config.get('vocab_size', 30522)
2059
+ hidden_size = config.get('hidden_size', 768)
2060
+ num_hidden_layers = config.get('num_hidden_layers', 12)
2061
+ intermediate_size = config.get('intermediate_size', 3072)
2062
+ max_position_embeddings = config.get('max_position_embeddings', 512)
2063
+ type_vocab_size = config.get('type_vocab_size', 2)
2064
+
2065
+ # Embeddings: token + position + token_type
2066
+ embedding_params = (vocab_size * hidden_size +
2067
+ max_position_embeddings * hidden_size +
2068
+ type_vocab_size * hidden_size)
2069
+
2070
+ # Each encoder layer
2071
+ # - Self-attention
2072
+ attention_params = 4 * (hidden_size * hidden_size)
2073
+
2074
+ # - Feed-forward
2075
+ ff_params = hidden_size * intermediate_size + intermediate_size * hidden_size
2076
+
2077
+ # - Layer norms
2078
+ ln_params = 2 * hidden_size
2079
+
2080
+ layer_params = attention_params + ff_params + ln_params
2081
+ encoder_params = num_hidden_layers * layer_params
2082
+
2083
+ # Pooler (optional)
2084
+ pooler_params = hidden_size * hidden_size
2085
+
2086
+ total = embedding_params + encoder_params + pooler_params
2087
+
2088
+ return total
2089
+
2090
+ except Exception as e:
2091
+ logger.warning(f"BERT parameter calculation failed: {e}")
2092
+ return None
2093
+
2094
+ def _calculate_generic_transformer_parameters(self, config: Dict) -> Optional[int]:
2095
+ """Generic parameter calculation for transformer models."""
2096
+ try:
2097
+ # Try to extract common parameters
2098
+ vocab_size = config.get('vocab_size', 32000)
2099
+ hidden_size = config.get('hidden_size', config.get('n_embd', config.get('d_model', 512)))
2100
+ num_layers = config.get('num_hidden_layers', config.get('n_layer', config.get('num_layers', 6)))
2101
+
2102
+ if hidden_size is None or num_layers is None:
2103
+ return None
2104
+
2105
+ # Very rough estimation for generic transformers
2106
+ # Embeddings + layers + head
2107
+ embedding_params = vocab_size * hidden_size
2108
+
2109
+ # Each layer: attention + ffn + norms (rough 6x hidden_size^2 per layer)
2110
+ layer_params = 6 * (hidden_size * hidden_size)
2111
+ transformer_params = num_layers * layer_params
2112
+
2113
+ # Output head
2114
+ head_params = vocab_size * hidden_size
2115
+
2116
+ total = embedding_params + transformer_params + head_params
2117
+
2118
+ logger.info(f"Generic parameter estimation: {total:,} parameters")
2119
+ return total
2120
+
2121
+ except Exception as e:
2122
+ logger.warning(f"Generic parameter calculation failed: {e}")
2123
+ return None
2124
+
2125
+ def get_model_parameters_smart(self, model_path: Path) -> float:
2126
+ """Get model parameters in billions with smart detection, fallback to size estimation."""
2127
+ # Try accurate parameter detection first
2128
+ param_count = self.detect_model_parameters(model_path)
2129
+
2130
+ if param_count is not None:
2131
+ return param_count / 1e9 # Convert to billions
2132
+
2133
+ # Fallback to improved size-based estimation
2134
+ logger.warning(f"Falling back to size-based parameter estimation for {model_path.name}")
2135
+
2136
+ size_bytes = self._get_model_size(model_path)
2137
+ size_gb = size_bytes / (1024**3)
2138
+
2139
+ # Improved estimation that considers file overhead
2140
+ # Only count actual weight files, not tokenizer configs etc.
2141
+ weight_size = self._estimate_weight_file_size(model_path)
2142
+ weight_size_gb = weight_size / (1024**3)
2143
+
2144
+ # Use weight size if significantly different from total size
2145
+ if weight_size_gb < size_gb * 0.8: # If weight files are <80% of total
2146
+ size_gb = weight_size_gb
2147
+ logger.info(f"Using weight-only size: {size_gb:.2f}GB (total: {size_bytes / (1024**3):.2f}GB)")
2148
+
2149
+ # Better default estimation: 2.2 bytes per parameter (accounts for some overhead)
2150
+ estimated_params_b = size_gb / 2.2  # GiB / bytes-per-parameter ≈ parameters in billions
2151
+
2152
+ logger.info(f"Estimated {estimated_params_b:.2f}B parameters from {size_gb:.2f}GB model size")
2153
+ return estimated_params_b
2154
+
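To make the fallback concrete: for roughly 13.5 GB of fp16 weight files (a 7B-class checkpoint) the heuristic returns about 5.7B, somewhat below the true count, since fp16 stores 2 bytes per parameter and the divisor also absorbs GiB-vs-GB rounding; fp32 or 4-bit files skew the estimate in opposite directions, which is why the header-based detection above is preferred:

size_bytes = 13_500_000_000          # e.g. fp16 shards of a 7B-class model
size_gib = size_bytes / (1024 ** 3)  # ~12.6 GiB, as computed in get_model_parameters_smart
estimated_b = size_gib / 2.2         # ~5.7 (billions of parameters)

true_b = size_bytes / 2 / 1e9        # fp16: 2 bytes per parameter -> ~6.75
print(f"estimated {estimated_b:.1f}B vs ~{true_b:.2f}B actual")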
2155
+ def _estimate_weight_file_size(self, model_path: Path) -> int:
2156
+ """Estimate size of actual weight files, excluding configs and tokenizers."""
2157
+ if model_path.is_file():
2158
+ return model_path.stat().st_size
2159
+
2160
+ weight_patterns = [
2161
+ '*.safetensors', '*.bin', '*.npz',
2162
+ 'pytorch_model*.bin', 'model*.safetensors'
2163
+ ]
2164
+
2165
+ non_weight_patterns = [
2166
+ 'tokenizer*', 'vocab*', 'merges.txt', 'config.json',
2167
+ 'generation_config.json', 'special_tokens_map.json',
2168
+ 'tokenizer_config.json', 'added_tokens.json'
2169
+ ]
2170
+
2171
+ total_weight_size = 0
2172
+
2173
+ for file_path in model_path.rglob('*'):
2174
+ if file_path.is_file():
2175
+ # Check if it matches weight patterns
2176
+ is_weight = any(file_path.match(pattern) for pattern in weight_patterns)
2177
+
2178
+ # Exclude non-weight files
2179
+ is_non_weight = any(file_path.match(pattern) for pattern in non_weight_patterns)
2180
+
2181
+ if is_weight and not is_non_weight:
2182
+ total_weight_size += file_path.stat().st_size
2183
+ elif not is_non_weight and file_path.suffix in ['.safetensors', '.bin', '.npz']:
2184
+ # Include other tensor files that don't match specific patterns
2185
+ total_weight_size += file_path.stat().st_size
2186
+
2187
+ return total_weight_size
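A hypothetical way to see how much of a model directory the weight files account for; manager is an assumed ModelManager instance and the path is a placeholder:

from pathlib import Path

def show_weight_fraction(manager, model_dir: Path) -> None:
    total = sum(p.stat().st_size for p in model_dir.rglob('*') if p.is_file())
    weights = manager._estimate_weight_file_size(model_dir)
    if total:
        print(f"weights are {weights / total:.0%} of {total / (1024 ** 3):.2f} GiB on disk")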