cortex-llm 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cortex/__init__.py +73 -0
- cortex/__main__.py +83 -0
- cortex/config.py +329 -0
- cortex/conversation_manager.py +468 -0
- cortex/fine_tuning/__init__.py +8 -0
- cortex/fine_tuning/dataset.py +332 -0
- cortex/fine_tuning/mlx_lora_trainer.py +502 -0
- cortex/fine_tuning/trainer.py +957 -0
- cortex/fine_tuning/wizard.py +707 -0
- cortex/gpu_validator.py +467 -0
- cortex/inference_engine.py +727 -0
- cortex/metal/__init__.py +275 -0
- cortex/metal/gpu_validator.py +177 -0
- cortex/metal/memory_pool.py +886 -0
- cortex/metal/mlx_accelerator.py +678 -0
- cortex/metal/mlx_converter.py +638 -0
- cortex/metal/mps_optimizer.py +417 -0
- cortex/metal/optimizer.py +665 -0
- cortex/metal/performance_profiler.py +364 -0
- cortex/model_downloader.py +130 -0
- cortex/model_manager.py +2187 -0
- cortex/quantization/__init__.py +5 -0
- cortex/quantization/dynamic_quantizer.py +736 -0
- cortex/template_registry/__init__.py +15 -0
- cortex/template_registry/auto_detector.py +144 -0
- cortex/template_registry/config_manager.py +234 -0
- cortex/template_registry/interactive.py +260 -0
- cortex/template_registry/registry.py +347 -0
- cortex/template_registry/template_profiles/__init__.py +5 -0
- cortex/template_registry/template_profiles/base.py +142 -0
- cortex/template_registry/template_profiles/complex/__init__.py +5 -0
- cortex/template_registry/template_profiles/complex/reasoning.py +263 -0
- cortex/template_registry/template_profiles/standard/__init__.py +9 -0
- cortex/template_registry/template_profiles/standard/alpaca.py +73 -0
- cortex/template_registry/template_profiles/standard/chatml.py +82 -0
- cortex/template_registry/template_profiles/standard/gemma.py +103 -0
- cortex/template_registry/template_profiles/standard/llama.py +87 -0
- cortex/template_registry/template_profiles/standard/simple.py +65 -0
- cortex/ui/__init__.py +120 -0
- cortex/ui/cli.py +1685 -0
- cortex/ui/markdown_render.py +185 -0
- cortex/ui/terminal_app.py +534 -0
- cortex_llm-1.0.0.dist-info/METADATA +275 -0
- cortex_llm-1.0.0.dist-info/RECORD +48 -0
- cortex_llm-1.0.0.dist-info/WHEEL +5 -0
- cortex_llm-1.0.0.dist-info/entry_points.txt +2 -0
- cortex_llm-1.0.0.dist-info/licenses/LICENSE +21 -0
- cortex_llm-1.0.0.dist-info/top_level.txt +1 -0
cortex/model_manager.py
ADDED
@@ -0,0 +1,2187 @@
"""Model management for GPU-accelerated inference."""

import os
import sys
import logging
from pathlib import Path
from typing import Dict, Any, Optional, List, Tuple
from dataclasses import dataclass
from enum import Enum
import hashlib
import json
import shutil
import struct
from datetime import datetime

# Configure logging
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

import torch
import mlx.core as mx
import mlx.nn as nn
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig

# Import MLX LM functions safely
try:
    from mlx_lm import load as mlx_load
except ImportError:
    mlx_load = None

from cortex.config import Config
from cortex.gpu_validator import GPUValidator
from cortex.metal.memory_pool import MemoryPool, AllocationStrategy
from cortex.metal.mlx_converter import MLXConverter, ConversionConfig, QuantizationRecipe, ConversionFormat
from cortex.metal.mlx_accelerator import MLXAccelerator, MLXConfig
from cortex.quantization.dynamic_quantizer import DynamicQuantizer, QuantizationConfig, QuantizationMode

# Configure tokenizer parallelism for optimal performance
# Enable parallelism for better tokenization speed
# This is safe with threading (not multiprocessing)
if os.environ.get('TOKENIZERS_PARALLELISM') is None:
    # Enable tokenizer parallelism for better performance
    # Safe with threading, improves tokenization speed
    os.environ['TOKENIZERS_PARALLELISM'] = 'true'

# Set optimal number of threads for tokenizer
if os.environ.get('RAYON_NUM_THREADS') is None:
    # Use half of available CPU cores for tokenizer threads
    import multiprocessing
    num_cores = multiprocessing.cpu_count()
    os.environ['RAYON_NUM_THREADS'] = str(max(1, num_cores // 2))
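# Reviewer note (illustrative, not part of the package source): on a
# hypothetical 8-core machine with neither variable set by the user, the
# defaults above work out to
#   TOKENIZERS_PARALLELISM -> "true"
#   RAYON_NUM_THREADS      -> str(max(1, 8 // 2)) == "4"
# User-provided values are never overridden because both assignments are
# guarded by the `is None` checks.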


class ModelFormat(Enum):
    """Supported model formats."""
    GGUF = "gguf"
    MLX = "mlx"
    SAFETENSORS = "safetensors"
    PYTORCH = "pytorch"
    QUANTIZED = "quantized"  # For GPTQ/AWQ/etc
    UNKNOWN = "unknown"


class QuantizationType(Enum):
    """Supported quantization types."""
    NONE = "none"
    INT4 = "int4"
    INT8 = "int8"
    GPTQ = "gptq"
    AWQ = "awq"
    Q4_K_M = "Q4_K_M"
    Q5_K_M = "Q5_K_M"
    Q6_K = "Q6_K"
    Q8_0 = "Q8_0"
    FP16 = "FP16"
    FP32 = "FP32"


@dataclass
class ModelInfo:
    """Information about a loaded model."""
    name: str
    path: Path
    format: ModelFormat
    quantization: QuantizationType
    size_bytes: int
    parameters: int
    context_length: int
    loaded_at: datetime
    gpu_memory_used: int
    tokenizer_path: Optional[Path]
    config: Dict[str, Any]

    @property
    def size_gb(self) -> float:
        """Get model size in GB."""
        return self.size_bytes / (1024 ** 3)

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary."""
        return {
            'name': self.name,
            'path': str(self.path),
            'format': self.format.value,
            'quantization': self.quantization.value,
            'size_gb': self.size_gb,
            'parameters': self.parameters,
            'context_length': self.context_length,
            'loaded_at': self.loaded_at.isoformat(),
            'gpu_memory_used': self.gpu_memory_used
        }
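    # Reviewer note (illustrative, not part of the package source): for a
    # hypothetical 4-bit MLX model of about 4.2 GB, `to_dict()` would yield
    # roughly
    #   {'name': 'my-model', 'path': '/models/my-model', 'format': 'mlx',
    #    'quantization': 'int4', 'size_gb': 4.2, 'parameters': 8000000000,
    #    'context_length': 4096, 'loaded_at': '2025-01-01T12:00:00',
    #    'gpu_memory_used': 4500000000}
    # (all values are placeholders). Note that `tokenizer_path` and `config`
    # are intentionally omitted from the dictionary form.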


class ModelManager:
    """Manage model loading and GPU memory allocation."""

    def __init__(
        self,
        config: Config,
        gpu_validator: GPUValidator,
        memory_pool: Optional[MemoryPool] = None
    ):
        """Initialize model manager."""
        self.config = config
        self.gpu_validator = gpu_validator
        self.memory_pool = memory_pool
        self.loaded_models: Dict[str, ModelInfo] = {}
        self.current_model: Optional[str] = None
        self.model_cache: Dict[str, Any] = {}
        self.tokenizers: Dict[str, Any] = {}

        # Initialize quantizer for memory-efficient loading
        self.quantizer = DynamicQuantizer(QuantizationConfig(
            mode=QuantizationMode.DYNAMIC,
            per_channel=True,
            cache_quantized=True
        ))

        # Initialize MLX converter for native conversion
        # Use a consistent cache directory
        mlx_cache_dir = Path.home() / ".cortex" / "mlx_models"
        self.mlx_converter = MLXConverter(
            cache_dir=mlx_cache_dir
        )

        # Initialize MLX accelerator for optimizations
        self.mlx_accelerator = None
        self._mlx_init_error: Optional[str] = None
        try:
            self.mlx_accelerator = MLXAccelerator(MLXConfig(
                compile_model=True,
                use_amx=True,
                fuse_operations=True,
                rotating_kv_cache=True,
                quantization_bits=4
            ))
        except Exception as e:
            self._mlx_init_error = str(e)
            logger.warning("MLX accelerator initialization failed: %s", e, exc_info=True)

        self._setup_directories()
        self._initialize_memory_pool()
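    # Reviewer note (illustrative sketch, not part of the package source):
    # typical construction as suggested by the signature above; the default
    # `Config()` and `GPUValidator()` constructor calls are assumptions about
    # those classes, not documented API.
    #
    #   config = Config()
    #   manager = ModelManager(config, GPUValidator())
    #   ok, msg = manager.load_model("~/models/my-model", quantization="4bit")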

    def __del__(self):
        """Clean up resources on deletion."""
        try:
            # Unload all models properly
            for model_name in list(self.loaded_models.keys()):
                self.unload_model(model_name)
        except:
            pass  # Ignore errors during cleanup

    def _setup_directories(self) -> None:
        """Create necessary directories."""
        self.config.model.model_path.expanduser().mkdir(parents=True, exist_ok=True)
        self.config.model.model_cache_dir.expanduser().mkdir(parents=True, exist_ok=True)
        self.config.model.quantization_cache.expanduser().mkdir(parents=True, exist_ok=True)

    def _initialize_memory_pool(self) -> None:
        """Initialize memory pool if needed."""
        # Skip if already provided by InferenceEngine to avoid duplication
        if self.memory_pool is not None:
            return

        if self.config.gpu.force_gpu:
            try:
                # Only create if not already provided
                self.memory_pool = MemoryPool(
                    pool_size=None,
                    strategy=AllocationStrategy.UNIFIED,
                    device="mps" if torch.backends.mps.is_available() else "mlx",
                    auto_size=True,
                    silent=True  # Suppress message since InferenceEngine also creates a pool
                )
            except Exception as e:
                # InferenceEngine likely has its own pool; log for visibility.
                logger.debug("Memory pool initialization skipped: %s", e)

    def _prefer_speed_quantization(self) -> bool:
        """Return True when config prefers maximum speed over quality."""
        level = getattr(self.config.gpu, "gpu_optimization_level", "maximum")
        level = str(level).lower().strip()
        return level in {"maximum", "max", "speed", "fast", "performance"}

    def load_model(
        self,
        model_path: str,
        model_name: Optional[str] = None,
        force_reload: bool = False,
        convert_to_mlx: bool = False,
        quantization: Optional[str] = None
    ) -> Tuple[bool, str]:
        """
        Load a model to GPU memory with optional MLX conversion.

        Args:
            model_path: Path to model file or directory (or HF repo ID)
            model_name: Optional name for the model
            force_reload: Force reload even if already loaded
            convert_to_mlx: Convert to MLX format for better performance
            quantization: Quantization recipe ('4bit', '5bit', '8bit', 'mixed')

        Returns:
            Tuple of (success, message)
        """
        # Check if it's a HuggingFace repo ID
        is_hf_repo = "/" in model_path and not Path(model_path).exists()

        # Auto-enable MLX conversion if MLX backend is enabled in config
        if hasattr(self.config, 'gpu') and hasattr(self.config.gpu, 'mlx_backend'):
            if self.config.gpu.mlx_backend:
                logger.info("MLX backend enabled in config, auto-converting models to MLX format")
                convert_to_mlx = True

        # Handle MLX conversion for HF models or local models
        if convert_to_mlx or is_hf_repo:
            # Check if this is a cached MLX model name
            if "_4bit" in model_path or "_5bit" in model_path or "_8bit" in model_path or "_none" in model_path:
                # This might be a cached MLX model, check if it exists
                mlx_cache_dir = Path.home() / ".cortex" / "mlx_models"
                cache_path = mlx_cache_dir / Path(model_path).name

                if cache_path.exists() and cache_path.is_dir():
                    logger.info(f"Loading cached MLX model from {cache_path}")
                    success, result = self._load_mlx(cache_path, model_name or Path(model_path).name, {
                        'format': ModelFormat.MLX,
                        'quantization': QuantizationType.INT4 if "_4bit" in model_path else QuantizationType.INT8,
                        'reason': 'Cached MLX model'
                    })

                    if success:
                        self.current_model = model_name or Path(model_path).name
                        # When loading from cache, don't update the config - it already has the right path
                        return True, f"Successfully loaded MLX model '{model_name or Path(model_path).name}'"
                    else:
                        return False, f"Failed to load MLX model: {result}"
                else:
                    # Cached model not found, need to reconvert from original
                    # Try to extract original path from the cached name
                    base_name = model_path.replace("_4bit", "").replace("_5bit", "").replace("_8bit", "").replace("_none", "")

                    # Try to find the original model
                    if base_name.startswith("_Users_"):
                        # This is a local model path encoded in the name
                        original_path = "/" + base_name[1:].replace("_", "/")
                        original_path = Path(original_path).expanduser()

                        if original_path.exists():
                            logger.info(f"Found original model at {original_path}, will convert")
                            model_path = str(original_path)
                            # Continue with normal conversion flow
                        else:
                            return False, f"Cached MLX model not found and original model not found at {original_path}"
                    else:
                        return False, f"Cached MLX model not found at {cache_path}"

            # Check if model is already in MLX format by looking ahead
            test_path = Path(model_path).expanduser().resolve()
            if test_path.exists() and test_path.is_dir():
                # Check if it's in the mlx_models directory - if so, it's already converted
                mlx_models_dir = Path.home() / ".cortex" / "mlx_models"
                # Use proper path comparison
                try:
                    is_in_mlx_dir = test_path.is_relative_to(mlx_models_dir)
                except (ValueError, AttributeError):
                    # Fallback for older Python versions
                    is_in_mlx_dir = str(mlx_models_dir.resolve()) in str(test_path.resolve())

                # Check for MLX format markers - include adapter files for fine-tuned models
                has_mlx_weights = (test_path / 'weights.npz').exists() or (test_path / 'model.safetensors').exists()
                has_config = (test_path / 'config.json').exists()
                has_adapter = (test_path / 'adapter.safetensors').exists()
                has_fine_tuned_marker = (test_path / 'fine_tuned.marker').exists()

                # A model is MLX format if:
                # 1. It's in the mlx_models directory, OR
                # 2. It has MLX weights and config, OR
                # 3. It's a fine-tuned model with adapters
                if is_in_mlx_dir or (has_mlx_weights and has_config) or has_fine_tuned_marker or has_adapter:
                    # Already MLX format, skip conversion
                    logger.info(f"Model at {model_path} is already in MLX format, skipping conversion")
                    path = test_path

                    # Check if it's a fine-tuned model
                    format_info = {
                        'format': ModelFormat.MLX,
                        'quantization': QuantizationType.NONE,  # Will be detected from model
                        'reason': 'Existing MLX model'
                    }

                    # Check config for fine-tuning markers
                    config_path = path / "config.json"
                    if config_path.exists():
                        try:
                            with open(config_path, 'r') as f:
                                config = json.load(f)
                                if config.get('fine_tuned') or config.get('lora_adapter'):
                                    format_info['is_fine_tuned'] = True
                                    format_info['has_lora_adapter'] = True
                                    logger.info("Detected fine-tuned model with LoRA adapters")
                        except Exception as e:
                            logger.warning(f"Could not read config to check fine-tuning status: {e}")

                    # Check for adapter files
                    if (path / "adapter.safetensors").exists() or (path / "adapter_config.json").exists():
                        format_info['has_lora_adapter'] = True
                        logger.info("Found LoRA adapter files")

                    # Load the existing MLX model directly
                    success, result = self._load_mlx(path, model_name or path.name, format_info)

                    if success:
                        self.current_model = model_name or path.name
                        # Store the original model path
                        self.config.update_last_used_model(str(model_path))
                        return True, f"Successfully loaded MLX model '{model_name or path.name}'"
                    else:
                        return False, f"Failed to load MLX model: {result}"
                else:
                    # Needs conversion
                    # Determine quantization recipe - smart default based on model size
                    # First check model name for size hints
                    model_name_lower = str(test_path).lower()
                    if any(size in model_name_lower for size in ['270m', '350m', '500m']):
                        quant_recipe = QuantizationRecipe.NONE
                        logger.info(f"Very small model detected from name ({test_path.name}), skipping quantization")
                        print(f"Note: Very small model detected ({test_path.name}), skipping quantization for quality")
                    elif any(size in model_name_lower for size in ['1b', '2b', '3b']):
                        quant_recipe = QuantizationRecipe.QUALITY_8BIT
                        logger.info(f"Small model detected from name ({test_path.name}), using 8-bit quantization")
                        print(f"Note: Small model detected ({test_path.name}), using 8-bit quantization")
                    else:
                        # Use smart parameter detection for quantization decisions
                        try:
                            model_size_gb = self._get_model_size(test_path) / (1024**3)

                            # Use accurate parameter detection (returns billions)
                            actual_params_b = self.get_model_parameters_smart(test_path)
                            params_billions = float(actual_params_b) if actual_params_b is not None else (model_size_gb / 2.2)

                            logger.info(f"Model size: {model_size_gb:.2f}GB, parameters: {params_billions:.2f}B")

                            # For very small models like Gemma-270M, avoid quantization
                            if params_billions < 0.5:
                                quant_recipe = QuantizationRecipe.NONE
                                logger.info(f"Very small model detected ({self._format_param_count(params_billions)} params), skipping quantization")
                                print(f"Note: Very small model detected ({self._format_param_count(params_billions)} params), skipping quantization")
                            elif params_billions < 1.0:
                                quant_recipe = QuantizationRecipe.QUALITY_8BIT  # Prefer 8bit for small models
                                logger.info(f"Small model detected ({self._format_param_count(params_billions)} params), using 8-bit quantization")
                                print(f"Note: Small model detected ({self._format_param_count(params_billions)} params), using 8-bit quantization")
                            else:
                                quant_recipe = QuantizationRecipe.SPEED_4BIT  # Default for larger models
                        except Exception as e:
                            logger.warning(f"Could not estimate model parameters: {e}, defaulting to 4-bit")
                            quant_recipe = QuantizationRecipe.SPEED_4BIT  # Fallback

                    if quantization:
                        quant_map = {
                            "4bit": QuantizationRecipe.SPEED_4BIT,
                            "5bit": QuantizationRecipe.BALANCED_5BIT,
                            "8bit": QuantizationRecipe.QUALITY_8BIT,
                            "mixed": QuantizationRecipe.MIXED_PRECISION,
                            "none": QuantizationRecipe.NONE
                        }
                        quant_recipe = quant_map.get(quantization, quant_recipe)  # Use smart default if invalid
                    elif self._prefer_speed_quantization() and quant_recipe != QuantizationRecipe.NONE:
                        if quant_recipe != QuantizationRecipe.SPEED_4BIT:
                            logger.info("Max optimization enabled, using 4-bit quantization for best tokens/sec")
                            print("Note: Max optimization enabled, using 4-bit quantization for best tokens/sec")
                        quant_recipe = QuantizationRecipe.SPEED_4BIT

                    # Determine source format
                    source_format = ConversionFormat.HUGGINGFACE if is_hf_repo else ConversionFormat.SAFETENSORS

                    # For local SafeTensors models, ensure correct format detection
                    if not is_hf_repo and test_path.exists():
                        if any(f.suffix == '.safetensors' for f in test_path.glob('*.safetensors')):
                            source_format = ConversionFormat.SAFETENSORS
                            logger.info("Detected SafeTensors format for conversion to MLX")

                    # Convert to MLX
                    conversion_config = ConversionConfig(
                        source_format=source_format,
                        quantization=quant_recipe,
                        use_amx=True,
                        compile_model=True
                    )

                    logger.info(f"Converting model to MLX format with {quant_recipe.name} quantization...")
                    print(f"Converting model to MLX format for optimal performance...")

                    success, msg, mlx_path = self.mlx_converter.convert_model(
                        model_path,
                        output_name=model_name,
                        config=conversion_config
                    )

                    if not success:
                        return False, f"MLX conversion failed: {msg}"

                    # Update path to converted model
                    path = mlx_path
                    print(f"✓ Model converted to MLX format at {mlx_path}")
                    logger.info(f"Successfully converted model to MLX format at {mlx_path}")

                    # Now load the converted MLX model directly
                    success, result = self._load_mlx(path, model_name or path.name, {
                        'format': ModelFormat.MLX,
                        'quantization': self._get_quantization_type_from_recipe(quant_recipe),
                        'reason': 'MLX converted model'
                    })

                    if success:
                        self.current_model = model_name or path.name
                        # Store the original model path
                        self.config.update_last_used_model(str(model_path))
                        return True, f"Successfully loaded MLX-converted model '{model_name or path.name}'"
                    else:
                        return False, f"Failed to load converted MLX model: {result}"
            else:
                # HuggingFace repo or non-existent path - needs conversion
                # Determine quantization recipe - smart default for HF models too
                # For HF models, we don't know the exact size, but we can use heuristics
                model_name_lower = model_path.lower()
                if any(size in model_name_lower for size in ['270m', '350m', '500m']):
                    quant_recipe = QuantizationRecipe.NONE  # Very small models
                    logger.info(f"Very small model detected from name ({model_path}), skipping quantization")
                elif any(size in model_name_lower for size in ['1b', '2b', '3b']):
                    quant_recipe = QuantizationRecipe.QUALITY_8BIT  # Small models
                    logger.info(f"Small model detected from name ({model_path}), using 8-bit quantization")
                else:
                    quant_recipe = QuantizationRecipe.SPEED_4BIT  # Default for larger models

                if quantization:
                    quant_map = {
                        "4bit": QuantizationRecipe.SPEED_4BIT,
                        "5bit": QuantizationRecipe.BALANCED_5BIT,
                        "8bit": QuantizationRecipe.QUALITY_8BIT,
                        "mixed": QuantizationRecipe.MIXED_PRECISION,
                        "none": QuantizationRecipe.NONE
                    }
                    quant_recipe = quant_map.get(quantization, quant_recipe)
                elif self._prefer_speed_quantization() and quant_recipe != QuantizationRecipe.NONE:
                    if quant_recipe != QuantizationRecipe.SPEED_4BIT:
                        logger.info("Max optimization enabled, using 4-bit quantization for best tokens/sec")
                        print("Note: Max optimization enabled, using 4-bit quantization for best tokens/sec")
                    quant_recipe = QuantizationRecipe.SPEED_4BIT

                # Convert to MLX
                conversion_config = ConversionConfig(
                    source_format=ConversionFormat.HUGGINGFACE if is_hf_repo else ConversionFormat.SAFETENSORS,
                    quantization=quant_recipe,
                    use_amx=True,
                    compile_model=True
                )

                logger.info(f"Converting HF model to MLX format with {quant_recipe.name} quantization...")
                print(f"Downloading and converting model to MLX format...")

                success, msg, mlx_path = self.mlx_converter.convert_model(
                    model_path,
                    output_name=model_name,
                    config=conversion_config
                )

                if not success:
                    return False, f"MLX conversion failed: {msg}"

                # Update path to converted model
                path = mlx_path
                print(f"✓ Model converted to MLX format at {mlx_path}")
                logger.info(f"Successfully converted model to MLX format at {mlx_path}")

                # Now load the converted MLX model directly
                success, result = self._load_mlx(path, model_name or path.name, {
                    'format': ModelFormat.MLX,
                    'quantization': self._get_quantization_type_from_recipe(quant_recipe),
                    'reason': 'MLX converted model'
                })

                if success:
                    self.current_model = model_name or path.name
                    # Store the original model path
                    self.config.update_last_used_model(str(model_path))
                    return True, f"Successfully loaded MLX-converted model '{model_name or path.name}'"
                else:
                    return False, f"Failed to load converted MLX model: {result}"
        else:
            # Validate inputs - properly expand home directory
            path = Path(model_path).expanduser().resolve()
            if not path.exists():
                # Try mlx-community models
                if model_path.startswith("mlx-community/"):
                    # Download from HuggingFace mlx-community
                    success, msg, mlx_path = self.mlx_converter.convert_model(
                        model_path,
                        output_name=model_name,
                        config=ConversionConfig(quantization=QuantizationRecipe.NONE)
                    )
                    if success:
                        path = mlx_path
                    else:
                        return False, f"Failed to download MLX model: {msg}"
                else:
                    return False, f"Model path does not exist: {model_path}"

            # Use the full name for directories, stem for files
            if path.is_dir():
                model_name = model_name or path.name
            else:
                model_name = model_name or path.stem

            # Check if already loaded
            if model_name in self.loaded_models and not force_reload:
                self.current_model = model_name
                # Still update last used model even if already loaded
                self.config.update_last_used_model(model_name)
                return True, f"Model '{model_name}' already loaded"

            # Check memory constraints
            if len(self.loaded_models) >= self.config.model.max_loaded_models:
                oldest = min(self.loaded_models.items(), key=lambda x: x[1].loaded_at)[0]
                self.unload_model(oldest)

            # Detect format and load
            format_info = self._detect_format(path)
            if format_info['format'] == ModelFormat.UNKNOWN:
                return False, f"Unknown model format: {format_info['reason']}"

            # Check GPU compatibility and determine if quantization is needed
            model_size_gb = self._get_model_size(path) / (1024**3)
            can_load, message = self.gpu_validator.verify_model_compatibility(model_size_gb)

            # Determine if we need quantization (only for non-quantized models)
            needs_quantization = False
            quantization_mode = None

            # Only apply dynamic quantization to non-quantized SafeTensors/PyTorch models
            can_apply_quantization = (
                format_info['format'] in [ModelFormat.SAFETENSORS, ModelFormat.PYTORCH] and
                format_info['quantization'] == QuantizationType.NONE
            )

            if not can_load and can_apply_quantization:
                # Check if quantization would help
                gpu_status = self.gpu_validator.get_gpu_memory_status()
                available_gb = gpu_status['available_gb']

                # DEBUG: Uncomment to see memory calculations for quantization decisions
                # print(f"DEBUG: Model size on disk: {model_size_gb:.1f}GB, Available memory: {available_gb:.1f}GB")

                # Estimate if INT8 quantization would fit
                # Use same 3.5x multiplier as gpu_validator for consistency
                # INT8 is more stable than INT4, so prefer it when possible
                estimated_int8_size = model_size_gb * 0.5 * 2.5  # 50% reduction + 150% overhead (less conservative for INT8)

                # DEBUG: Uncomment to see quantization size estimates
                # print(f"DEBUG: INT8 estimated size: {estimated_int8_size:.1f}GB")

                if estimated_int8_size <= available_gb:
                    needs_quantization = True
                    quantization_mode = 'int8'
                    required_with_overhead = model_size_gb * 3.5
                    print(f"Model requires {required_with_overhead:.1f}GB (including overhead), only {available_gb:.1f}GB available.")
                    print(f"Will apply INT8 quantization to reduce to ~{estimated_int8_size:.1f}GB")
                else:
                    # Try INT4 as last resort
                    estimated_int4_size = model_size_gb * 0.25 * 3.5  # 75% reduction + 250% overhead

                    # DEBUG: Uncomment to see INT4 quantization estimates
                    # print(f"DEBUG: INT4 estimated size: {estimated_int4_size:.1f}GB")

                    if estimated_int4_size <= available_gb:
                        needs_quantization = True
                        quantization_mode = 'int4'
                        required_with_overhead = model_size_gb * 3.5
                        print(f"Model requires {required_with_overhead:.1f}GB (including overhead), only {available_gb:.1f}GB available.")
                        print(f"Will apply INT4 quantization to reduce to ~{estimated_int4_size:.1f}GB")
                    else:
                        return False, f"Model too large even with quantization: {message}"
            elif not can_load:
                # Can't apply quantization to this format
                return False, f"GPU incompatible: {message}"

            # Load based on format
            loader_map = {
                ModelFormat.MLX: self._load_mlx,
                ModelFormat.GGUF: self._load_gguf,
                ModelFormat.SAFETENSORS: self._load_safetensors,
                ModelFormat.PYTORCH: self._load_pytorch,
                ModelFormat.QUANTIZED: self._load_quantized
            }

            loader = loader_map.get(format_info['format'])
            if not loader:
                return False, f"No loader for format: {format_info['format'].value}"

            try:
                # Pass quantization info to loaders that support it
                if format_info['format'] in [ModelFormat.SAFETENSORS, ModelFormat.PYTORCH]:
                    success, result = loader(path, model_name, format_info, needs_quantization, quantization_mode)
                else:
                    success, result = loader(path, model_name, format_info)

                if success:
                    self.current_model = model_name
                    # Save the last used model to config
                    self.config.update_last_used_model(model_name)
                    return True, f"Successfully loaded '{model_name}'"
                else:
                    return False, result
            except Exception as e:
                return False, f"Error loading model: {str(e)}"

    def _detect_format(self, path: Path) -> Dict[str, Any]:
        """
        Detect model format from path.

        Returns:
            Dict with 'format', 'quantization', and 'reason'
        """
        # Check for specific file types
        if path.is_file():
            if path.suffix.lower() == '.gguf':
                return {
                    'format': ModelFormat.GGUF,
                    'quantization': QuantizationType.Q4_K_M,
                    'reason': 'GGUF file'
                }

        # Check for directory-based formats
        if path.is_dir():
            # Check for MLX format - support both regular MLX and fine-tuned MLX models
            has_config = (path / 'config.json').exists()
            has_weights_npz = (path / 'weights.npz').exists()
            has_safetensors = any(path.glob('*.safetensors'))
            has_fine_tuned_marker = (path / 'fine_tuned.marker').exists()

            if has_config and (has_weights_npz or has_safetensors):
                # Detect if it's a fine-tuned model with LoRA adapters
                has_adapter = (path / 'adapter.safetensors').exists()

                return {
                    'format': ModelFormat.MLX,
                    'quantization': QuantizationType.NONE,
                    'reason': 'MLX model with LoRA adapters' if has_adapter else 'MLX model (weights + config)',
                    'has_lora_adapter': has_adapter,
                    'is_fine_tuned': has_fine_tuned_marker
                }

            # Check for SafeTensors format
            safetensor_files = list(path.glob('*.safetensors'))
            if safetensor_files and (path / 'config.json').exists():
                # Check if it's quantized by looking at the config
                quantization = self._detect_quantization(path)
                if quantization in [QuantizationType.GPTQ, QuantizationType.AWQ, QuantizationType.INT4, QuantizationType.INT8]:
                    return {
                        'format': ModelFormat.QUANTIZED,
                        'quantization': quantization,
                        'reason': f'Quantized model ({quantization.value})'
                    }
                else:
                    return {
                        'format': ModelFormat.SAFETENSORS,
                        'quantization': quantization,
                        'reason': 'SafeTensors model'
                    }

            # Check for PyTorch format
            if (path / 'pytorch_model.bin').exists() or list(path.glob('pytorch_model*.bin')):
                return {
                    'format': ModelFormat.PYTORCH,
                    'quantization': QuantizationType.NONE,
                    'reason': 'PyTorch model'
                }

        return {
            'format': ModelFormat.UNKNOWN,
            'quantization': QuantizationType.NONE,
            'reason': 'No recognized model files found'
        }
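    # Reviewer note (illustrative, not part of the package source): for a
    # hypothetical file `~/models/mistral-7b.Q4_K_M.gguf` the method above
    # would return
    #   {'format': ModelFormat.GGUF, 'quantization': QuantizationType.Q4_K_M,
    #    'reason': 'GGUF file'}
    # while an unrecognized path falls through to the UNKNOWN entry at the end.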

    def _detect_quantization(self, path: Path) -> QuantizationType:
        """Detect quantization type from model files."""
        # Check config.json for quantization info
        config_path = path / 'config.json'
        if config_path.exists():
            try:
                with open(config_path) as f:
                    config = json.load(f)

                # Check for quantization config
                if 'quantization_config' in config:
                    quant_config = config['quantization_config']
                    if 'quant_method' in quant_config:
                        method = quant_config['quant_method'].upper()
                        if 'GPTQ' in method:
                            return QuantizationType.GPTQ
                        elif 'AWQ' in method:
                            return QuantizationType.AWQ
                    if 'bits' in quant_config:
                        bits = quant_config['bits']
                        if bits == 4:
                            return QuantizationType.INT4
                        elif bits == 8:
                            return QuantizationType.INT8

                # Check model name for hints (be careful with model size indicators like 4B = 4 billion)
                model_name = str(path.name).upper()
                # Only detect as quantized if explicitly mentioned
                if '4BIT' in model_name or 'INT4' in model_name or 'GPTQ-4' in model_name:
                    return QuantizationType.INT4
                elif '8BIT' in model_name or 'INT8' in model_name or 'GPTQ-8' in model_name:
                    return QuantizationType.INT8
                elif 'GPTQ' in model_name:
                    return QuantizationType.GPTQ
                elif 'AWQ' in model_name:
                    return QuantizationType.AWQ
            except:
                pass

        # Check for quantization-specific files
        safetensor_files = list(path.glob('*.safetensors'))
        if safetensor_files:
            # Load one file to check for quantization tensors
            try:
                from safetensors.torch import load_file
                sample = load_file(safetensor_files[0], device='cpu')
                # Check for GPTQ/AWQ specific tensors
                has_scales = any('.scales' in k for k in sample.keys())
                has_qweight = any('.qweight' in k for k in sample.keys())

                if has_scales or has_qweight:
                    return QuantizationType.GPTQ
            except:
                pass

        return QuantizationType.NONE

    def _get_quantization_type_from_recipe(self, recipe: QuantizationRecipe) -> QuantizationType:
        """Convert MLX quantization recipe to QuantizationType."""
        recipe_to_type = {
            QuantizationRecipe.SPEED_4BIT: QuantizationType.INT4,
            QuantizationRecipe.BALANCED_5BIT: QuantizationType.INT4,  # Closest match
            QuantizationRecipe.QUALITY_8BIT: QuantizationType.INT8,
            QuantizationRecipe.MIXED_PRECISION: QuantizationType.INT4,
            QuantizationRecipe.NONE: QuantizationType.NONE
        }
        return recipe_to_type.get(recipe, QuantizationType.INT4)

    def _format_param_count(self, params_b: Optional[float]) -> str:
        """Format a parameter count in billions as a human-readable string (M/B/T)."""
        try:
            if params_b is None:
                return "unknown"
            # Trillions
            if params_b >= 1000:
                return f"{params_b / 1000:.1f}T"
            # Billions
            if params_b >= 1:
                return f"{params_b:.1f}B"
            # Millions (10M - 999M)
            if params_b >= 0.01:
                return f"{params_b * 1000:.0f}M"
            # Low millions (1M - 9.9M)
            if params_b >= 0.001:
                return f"{params_b * 1000:.1f}M"
            # Thousands
            if params_b > 0:
                return f"{params_b * 1e6:.0f}K"
            return "0"
        except Exception:
            # Fallback formatting
            try:
                return f"{float(params_b):.2f}B"
            except Exception:
                return "unknown"

    def _get_model_size(self, path: Path) -> int:
        """Get total size of model files in bytes."""
        if path.is_file():
            return path.stat().st_size
        elif path.is_dir():
            total = 0
            # Check if this is a fine-tuned model with LoRA adapters
            is_finetuned = (path / 'fine_tuned.marker').exists() or (path / 'adapter.safetensors').exists()

            for file in path.rglob('*'):
                if file.is_file():
                    # Skip training checkpoint files for fine-tuned models
                    if is_finetuned and file.name.endswith('_adapters.safetensors'):
                        # These are intermediate checkpoints, not needed for inference
                        continue
                    # Skip cache and git files
                    if '/.cache/' in str(file) or '/.git/' in str(file):
                        continue
                    total += file.stat().st_size
            return total
        return 0

    def _load_mlx(self, path: Path, model_name: str, format_info: Dict) -> Tuple[bool, str]:
        """Load MLX format model with optimizations and LoRA adapter support."""
        try:
            if mlx_load is None:
                return False, "MLX LM library not available. Install with: pip install mlx-lm"

            # Check if this is a fine-tuned model with LoRA adapters
            has_adapter = format_info.get('has_lora_adapter', False)
            is_fine_tuned = format_info.get('is_fine_tuned', False)

            if has_adapter or is_fine_tuned:
                logger.info(f"Loading fine-tuned MLX model with LoRA adapters: {model_name}")

                # For fine-tuned models, we need to load the base model and apply adapters
                try:
                    # Try to load with adapter integration
                    from mlx_lm.tuner.utils import apply_lora_layers

                    # Load the model (should include merged weights)
                    model, tokenizer = mlx_load(str(path))

                    # The model should already have adapters merged since we saved it that way
                    logger.info("Fine-tuned MLX model loaded with integrated LoRA weights")

                except ImportError:
                    # Fallback to regular loading
                    logger.warning("MLX LoRA utilities not available, loading as regular MLX model")
                    model, tokenizer = mlx_load(str(path))
            else:
                # Regular MLX model loading
                model, tokenizer = mlx_load(str(path))

            # Apply MLX accelerator optimizations if available
            if self.mlx_accelerator:
                # Silently apply optimizations - details shown in CLI
                logger.info("Applying MLX optimizations (AMX, operation fusion)...")
                model = self.mlx_accelerator.optimize_model(model)

            # MLX LM models are already quantized during conversion
            # No need to apply additional quantization
            logger.info("MLX model already optimized with quantization")

            # Evaluate parameters if they exist
            if hasattr(model, 'parameters'):
                mx.eval(model.parameters())

            self.model_cache[model_name] = model
            self.tokenizers[model_name] = tokenizer

            # Load config
            config = {}
            config_path = path / 'config.json'
            if config_path.exists():
                with open(config_path) as f:
                    config = json.load(f)

            # Create model info with accurate parameter detection
            parameters = self.get_model_parameters_smart(path)

            model_info = ModelInfo(
                name=model_name,
                path=path,
                format=ModelFormat.MLX,
                quantization=format_info['quantization'],
                size_bytes=self._get_model_size(path),
                parameters=parameters,
                context_length=config.get('max_position_embeddings', 4096),
                loaded_at=datetime.now(),
                gpu_memory_used=self._estimate_mlx_memory(model),
                tokenizer_path=path / 'tokenizer.json',
                config=config
            )

            self.loaded_models[model_name] = model_info
            return True, "MLX model loaded successfully"

        except Exception as e:
            return False, f"Failed to load MLX model: {str(e)}"

    def _load_gguf(self, path: Path, model_name: str, format_info: Dict) -> Tuple[bool, str]:
        """Load GGUF format model using llama-cpp-python."""
        try:
            from llama_cpp import Llama

            print("Loading GGUF model with llama.cpp...")

            # Determine optimal parameters for Apple Silicon
            n_gpu_layers = -1  # Use all layers on GPU
            n_ctx = 4096  # Context size
            n_batch = 512  # Batch size for prompt processing

            # Load the model
            model = Llama(
                model_path=str(path),
                n_gpu_layers=n_gpu_layers,
                n_ctx=n_ctx,
                n_batch=n_batch,
                n_threads=8,  # Use 8 threads for CPU operations
                use_mlock=True,  # Lock model in RAM
                verbose=False
            )

            # Create a simple tokenizer wrapper for compatibility
            class GGUFTokenizer:
                def __init__(self, model):
                    self.model = model
                    self.pad_token = None
                    self.eos_token = None

                def encode(self, text):
                    return self.model.tokenize(text.encode('utf-8'))

                def decode(self, tokens):
                    return self.model.detokenize(tokens).decode('utf-8')

            tokenizer = GGUFTokenizer(model)
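            # Reviewer note (illustrative, not part of the package source): the
            # wrapper above exposes only a minimal tokenizer interface around
            # llama_cpp's tokenize/detokenize, e.g. a round trip such as
            #   tokens = tokenizer.encode("hello")   # -> model.tokenize(b"hello")
            #   text = tokenizer.decode(tokens)      # -> model.detokenize(tokens)
            # with pad_token/eos_token left as None placeholders.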
|
|
938
|
+
|
|
939
|
+
self.model_cache[model_name] = model
|
|
940
|
+
self.tokenizers[model_name] = tokenizer
|
|
941
|
+
|
|
942
|
+
# Create model info
|
|
943
|
+
# Get model parameters from the model object if available
|
|
944
|
+
try:
|
|
945
|
+
# Try to get parameters from model metadata
|
|
946
|
+
n_params = getattr(model, 'n_params', 0)
|
|
947
|
+
if n_params == 0:
|
|
948
|
+
# Estimate based on model size
|
|
949
|
+
n_params = self._get_model_size(path) // 2 # Rough estimate
|
|
950
|
+
except:
|
|
951
|
+
n_params = self._get_model_size(path) // 2
|
|
952
|
+
|
|
953
|
+
model_info = ModelInfo(
|
|
954
|
+
name=model_name,
|
|
955
|
+
path=path,
|
|
956
|
+
format=ModelFormat.GGUF,
|
|
957
|
+
quantization=format_info['quantization'],
|
|
958
|
+
size_bytes=self._get_model_size(path),
|
|
959
|
+
parameters=n_params,
|
|
960
|
+
context_length=n_ctx,
|
|
961
|
+
loaded_at=datetime.now(),
|
|
962
|
+
gpu_memory_used=self._get_model_size(path), # GGUF loads full model
|
|
963
|
+
tokenizer_path=None,
|
|
964
|
+
config={'n_ctx': n_ctx, 'n_gpu_layers': n_gpu_layers}
|
|
965
|
+
)
|
|
966
|
+
|
|
967
|
+
self.loaded_models[model_name] = model_info
|
|
968
|
+
return True, "GGUF model loaded successfully"
|
|
969
|
+
|
|
970
|
+
except ImportError:
|
|
971
|
+
return False, "llama-cpp-python not installed. Install with: pip install llama-cpp-python"
|
|
972
|
+
except Exception as e:
|
|
973
|
+
return False, f"Failed to load GGUF model: {str(e)}"
|
|
974
|
+
|
|
975
|
+
def _load_safetensors(
|
|
976
|
+
self,
|
|
977
|
+
path: Path,
|
|
978
|
+
model_name: str,
|
|
979
|
+
format_info: Dict,
|
|
980
|
+
needs_quantization: bool = False,
|
|
981
|
+
quantization_mode: Optional[str] = None
|
|
982
|
+
) -> Tuple[bool, str]:
|
|
983
|
+
"""Load standard SafeTensors model with optional quantization."""
|
|
984
|
+
try:
|
|
985
|
+
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
|
|
986
|
+
|
|
987
|
+
# Check for cached quantized model first BEFORE loading anything
|
|
988
|
+
if needs_quantization and self.quantizer.config.cache_quantized:
|
|
989
|
+
cached = self.quantizer.load_cached_model(path, self.quantizer.config)
|
|
990
|
+
if cached:
|
|
991
|
+
print(f"Loading cached quantized model...")
|
|
992
|
+
# Load to CPU first with minimal memory usage
|
|
993
|
+
print(f"Creating model structure...")
|
|
994
|
+
with torch.device('cpu'):
|
|
995
|
+
model = AutoModelForCausalLM.from_pretrained(
|
|
996
|
+
str(path),
|
|
997
|
+
torch_dtype=torch.float16,
|
|
998
|
+
low_cpu_mem_usage=True,
|
|
999
|
+
trust_remote_code=True,
|
|
1000
|
+
device_map={'': 'cpu'} # Force CPU loading
|
|
1001
|
+
)
|
|
1002
|
+
|
|
1003
|
+
print(f"Applying cached quantized weights...")
|
|
1004
|
+
model.load_state_dict(cached[0])
|
|
1005
|
+
|
|
1006
|
+
# Now move to target device
|
|
1007
|
+
print(f"Moving to {device}...")
|
|
1008
|
+
model = model.to(device)
|
|
1009
|
+
quantization_info = cached[1]
|
|
1010
|
+
print(f"Quantized model loaded from cache")
|
|
1011
|
+
else:
|
|
1012
|
+
# Load and quantize
|
|
1013
|
+
print(f"Loading model for quantization...")
|
|
1014
|
+
model = AutoModelForCausalLM.from_pretrained(
|
|
1015
|
+
str(path),
|
|
1016
|
+
torch_dtype=torch.float16,
|
|
1017
|
+
low_cpu_mem_usage=True,
|
|
1018
|
+
trust_remote_code=True
|
|
1019
|
+
)
|
|
1020
|
+
|
|
1021
|
+
if needs_quantization:
|
|
1022
|
+
print(f"Applying {quantization_mode} quantization...")
|
|
1023
|
+
gpu_status = self.gpu_validator.get_gpu_memory_status()
|
|
1024
|
+
model, quantization_info = self.quantizer.quantize_model(
|
|
1025
|
+
model,
|
|
1026
|
+
target_dtype=quantization_mode,
|
|
1027
|
+
available_memory_gb=gpu_status['available_gb'],
|
|
1028
|
+
model_size_gb=self._get_model_size(path) / (1024**3),
|
|
1029
|
+
target_device=device # Pass the target device (MPS)
|
|
1030
|
+
)
|
|
1031
|
+
|
|
1032
|
+
# Cache the quantized model
|
|
1033
|
+
if self.quantizer.config.cache_quantized:
|
|
1034
|
+
cache_path = self.quantizer.cache_quantized_model(model, path, quantization_info)
|
|
1035
|
+
print(f"Cached quantized model for faster future loads")
|
|
1036
|
+
|
|
1037
|
+
model = model.to(device)
|
|
1038
|
+
else:
|
|
1039
|
+
# Normal loading without quantization
|
|
1040
|
+
model = AutoModelForCausalLM.from_pretrained(
|
|
1041
|
+
str(path),
|
|
1042
|
+
torch_dtype=torch.float16,
|
|
1043
|
+
low_cpu_mem_usage=True,
|
|
1044
|
+
trust_remote_code=True
|
|
1045
|
+
)
|
|
1046
|
+
model = model.to(device)
|
|
1047
|
+
model.eval() # Set model to evaluation mode
|
|
1048
|
+
|
|
1049
|
+
# Load tokenizer
|
|
1050
|
+
tokenizer = AutoTokenizer.from_pretrained(str(path), use_fast=True)
|
|
1051
|
+
if tokenizer.pad_token is None:
|
|
1052
|
+
tokenizer.pad_token = tokenizer.eos_token
|
|
1053
|
+
|
|
1054
|
+
self.model_cache[model_name] = model
|
|
1055
|
+
self.tokenizers[model_name] = tokenizer
|
|
1056
|
+
|
|
1057
|
+
# Load config
|
|
1058
|
+
config = {}
|
|
1059
|
+
config_path = path / 'config.json'
|
|
1060
|
+
if config_path.exists():
|
|
1061
|
+
with open(config_path) as f:
|
|
1062
|
+
config = json.load(f)
|
|
1063
|
+
|
|
1064
|
+
# Create model info with quantization details if applicable
|
|
1065
|
+
if needs_quantization:
|
|
1066
|
+
# Update quantization type based on what was applied
|
|
1067
|
+
actual_quantization = QuantizationType.INT8 if quantization_mode == 'int8' else QuantizationType.INT4
|
|
1068
|
+
# Calculate actual memory used after quantization
|
|
1069
|
+
memory_used = sum(
|
|
1070
|
+
p.numel() * 1 if hasattr(p, 'numel') else 0 # Quantized uses less bytes per element
|
|
1071
|
+
for p in model.parameters()
|
|
1072
|
+
)
|
|
1073
|
+
else:
|
|
1074
|
+
actual_quantization = format_info['quantization']
|
|
1075
|
+
memory_used = sum(p.element_size() * p.numel() for p in model.parameters())
|
|
1076
|
+
|
|
1077
|
+
# Use smart parameter detection instead of counting loaded model parameters
|
|
1078
|
+
parameters = self.get_model_parameters_smart(path)
|
|
1079
|
+
|
|
1080
|
+
model_info = ModelInfo(
|
|
1081
|
+
name=model_name,
|
|
1082
|
+
path=path,
|
|
1083
|
+
format=ModelFormat.SAFETENSORS,
|
|
1084
|
+
quantization=actual_quantization,
|
|
1085
|
+
size_bytes=self._get_model_size(path),
|
|
1086
|
+
parameters=parameters,
|
|
1087
|
+
context_length=config.get('max_position_embeddings', 4096),
|
|
1088
|
+
loaded_at=datetime.now(),
|
|
1089
|
+
gpu_memory_used=memory_used,
|
|
1090
|
+
tokenizer_path=path / 'tokenizer.json',
|
|
1091
|
+
config=config
|
|
1092
|
+
)
|
|
1093
|
+
|
|
1094
|
+
self.loaded_models[model_name] = model_info
|
|
1095
|
+
return True, "SafeTensors model loaded successfully"
|
|
1096
|
+
|
|
1097
|
+
except Exception as e:
|
|
1098
|
+
return False, f"Failed to load SafeTensors model: {str(e)}"
|
|
1099
|
+
|
|
1100
|
+
def _load_pytorch(
|
|
1101
|
+
self,
|
|
1102
|
+
path: Path,
|
|
1103
|
+
model_name: str,
|
|
1104
|
+
format_info: Dict,
|
|
1105
|
+
needs_quantization: bool = False,
|
|
1106
|
+
quantization_mode: Optional[str] = None
|
|
1107
|
+
) -> Tuple[bool, str]:
|
|
1108
|
+
"""Load PyTorch format model with optional quantization."""
|
|
1109
|
+
try:
|
|
1110
|
+
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
|
|
1111
|
+
|
|
1112
|
+
# Check for cached quantized model first
|
|
1113
|
+
if needs_quantization and self.quantizer.config.cache_quantized:
|
|
1114
|
+
cached = self.quantizer.load_cached_model(path, self.quantizer.config)
|
|
1115
|
+
if cached:
|
|
1116
|
+
print(f"Loading cached quantized model...")
|
|
1117
|
+
model = AutoModelForCausalLM.from_pretrained(
|
|
1118
|
+
str(path),
|
|
1119
|
+
torch_dtype=torch.float16,
|
|
1120
|
+
low_cpu_mem_usage=True,
|
|
1121
|
+
                        trust_remote_code=True
                    )
                    model.load_state_dict(cached[0])
                    model = model.to(device)
                    quantization_info = cached[1]
                else:
                    # Load and quantize
                    print(f"Loading model for quantization...")
                    model = AutoModelForCausalLM.from_pretrained(
                        str(path),
                        torch_dtype=torch.float16,
                        low_cpu_mem_usage=True,
                        trust_remote_code=True
                    )

                    if needs_quantization:
                        print(f"Applying {quantization_mode} quantization...")
                        gpu_status = self.gpu_validator.get_gpu_memory_status()
                        model, quantization_info = self.quantizer.quantize_model(
                            model,
                            target_dtype=quantization_mode,
                            available_memory_gb=gpu_status['available_gb'],
                            model_size_gb=self._get_model_size(path) / (1024**3),
                            target_device=device  # Pass the target device (MPS)
                        )

                        # Cache the quantized model
                        if self.quantizer.config.cache_quantized:
                            cache_path = self.quantizer.cache_quantized_model(model, path, quantization_info)
                            print(f"Cached quantized model for faster future loads")

                        model = model.to(device)
            else:
                # Normal loading without quantization
                model = AutoModelForCausalLM.from_pretrained(
                    str(path),
                    torch_dtype=torch.float16,
                    low_cpu_mem_usage=True,
                    trust_remote_code=True
                )
                model = model.to(device)
            model.eval()  # Set model to evaluation mode

            # Load tokenizer
            tokenizer = AutoTokenizer.from_pretrained(str(path), use_fast=True)
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token

            self.model_cache[model_name] = model
            self.tokenizers[model_name] = tokenizer

            # Get config
            config = model.config.to_dict() if hasattr(model, 'config') else {}

            # Create model info with quantization details if applicable
            if needs_quantization:
                # Update quantization type based on what was applied
                actual_quantization = QuantizationType.INT8 if quantization_mode == 'int8' else QuantizationType.INT4
                # Calculate actual memory used after quantization
                memory_used = sum(
                    p.numel() * 1 if hasattr(p, 'numel') else 0  # quantized weights use fewer bytes per element
                    for p in model.parameters()
                )
            else:
                actual_quantization = format_info['quantization']
                memory_used = sum(p.element_size() * p.numel() for p in model.parameters())

            # Use smart parameter detection instead of counting loaded model parameters
            parameters = self.get_model_parameters_smart(path)

            model_info = ModelInfo(
                name=model_name,
                path=path,
                format=ModelFormat.PYTORCH,
                quantization=actual_quantization,
                size_bytes=self._get_model_size(path),
                parameters=parameters,
                context_length=config.get('max_position_embeddings', 4096),
                loaded_at=datetime.now(),
                gpu_memory_used=memory_used,
                tokenizer_path=None,
                config=config
            )

            self.loaded_models[model_name] = model_info
            return True, "PyTorch model loaded successfully"

        except Exception as e:
            return False, f"Failed to load PyTorch model: {str(e)}"

    def _load_quantized(self, path: Path, model_name: str, format_info: Dict) -> Tuple[bool, str]:
        """Load quantized model (GPTQ/AWQ/etc) using appropriate libraries."""
        quant_type = format_info['quantization']
        device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

        print(f"Loading {quant_type.value} quantized model...")

        # Try different quantization libraries based on detected type
        if quant_type in [QuantizationType.GPTQ, QuantizationType.INT4]:
            # Try GPTQ loader first
            try:
                from auto_gptq import AutoGPTQForCausalLM

                model = AutoGPTQForCausalLM.from_quantized(
                    str(path),
                    device="cuda:0" if torch.cuda.is_available() else "cpu",  # GPTQ doesn't support MPS directly
                    use_safetensors=True,
                    trust_remote_code=True,
                    inject_fused_attention=False,  # Disable for compatibility
                    inject_fused_mlp=False
                )

                # Move to MPS if needed
                if device.type == "mps" and not torch.cuda.is_available():
                    model = model.cpu()  # GPTQ models may need CPU fallback on Mac
                    print("Note: GPTQ model loaded on CPU (MPS not fully supported)")

                tokenizer = AutoTokenizer.from_pretrained(str(path), use_fast=True)
                if tokenizer.pad_token is None:
                    tokenizer.pad_token = tokenizer.eos_token

                self.model_cache[model_name] = model
                self.tokenizers[model_name] = tokenizer

                config = self._load_config(path)
                model_info = self._create_model_info(
                    model_name, path, ModelFormat.QUANTIZED, quant_type,
                    model, config
                )

                self.loaded_models[model_name] = model_info
                return True, "GPTQ quantized model loaded successfully"

            except ImportError:
                print("GPTQ library not available, trying alternative methods...")
            except Exception as e:
                print(f"GPTQ loading failed: {str(e)[:100]}")

        if quant_type == QuantizationType.AWQ:
            # Try AWQ loader
            try:
                from awq import AutoAWQForCausalLM

                model = AutoAWQForCausalLM.from_quantized(
                    str(path),
                    fuse_layers=False,  # Disable for compatibility
                    trust_remote_code=True
                )

                tokenizer = AutoTokenizer.from_pretrained(str(path), use_fast=True)
                if tokenizer.pad_token is None:
                    tokenizer.pad_token = tokenizer.eos_token

                self.model_cache[model_name] = model
                self.tokenizers[model_name] = tokenizer

                config = self._load_config(path)
                model_info = self._create_model_info(
                    model_name, path, ModelFormat.QUANTIZED, quant_type,
                    model, config
                )

                self.loaded_models[model_name] = model_info
                return True, "AWQ quantized model loaded successfully"

            except ImportError:
                print("AWQ library not available, trying alternative methods...")
            except Exception as e:
                print(f"AWQ loading failed: {str(e)[:100]}")

        # Try using accelerate for general quantized models
        try:
            from accelerate import init_empty_weights, load_checkpoint_and_dispatch
            from transformers import AutoConfig

            print("Attempting to load with accelerate library...")

            # Load config first
            config = AutoConfig.from_pretrained(str(path), trust_remote_code=True)

            # Initialize model with empty weights
            with init_empty_weights():
                model = AutoModelForCausalLM.from_config(
                    config,
                    torch_dtype=torch.float16,
                    trust_remote_code=True
                )

            # Determine the checkpoint files
            checkpoint_files = list(path.glob("*.safetensors"))
            if not checkpoint_files:
                checkpoint_files = list(path.glob("pytorch_model*.bin"))

            if not checkpoint_files:
                raise ValueError("No model files found")

            # Create proper device map for MPS
            if device.type == "mps":
                device_map = {"": "cpu"}  # Load to CPU first for MPS compatibility
            else:
                device_map = "auto"

            # Load and dispatch to device
            model = load_checkpoint_and_dispatch(
                model,
                checkpoint=str(path),  # Directory containing the model files
                device_map=device_map,
                dtype=torch.float16,
                offload_folder=str(self.config.model.model_cache_dir / "offload")
            )

            # Move to MPS if needed
            if device.type == "mps" and device_map == {"": "cpu"}:
                model = model.to(device)

            tokenizer = AutoTokenizer.from_pretrained(str(path))
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token

            self.model_cache[model_name] = model
            self.tokenizers[model_name] = tokenizer

            config_dict = self._load_config(path)
            model_info = self._create_model_info(
                model_name, path, ModelFormat.QUANTIZED, quant_type,
                model, config_dict
            )

            self.loaded_models[model_name] = model_info
            return True, "Quantized model loaded with accelerate"

        except Exception as e:
            print(f"Accelerate loading failed: {str(e)[:100]}")

        # Try bitsandbytes for 4-bit/8-bit models
        try:
            from transformers import BitsAndBytesConfig

            print("Attempting to load with bitsandbytes quantization...")

            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True if quant_type == QuantizationType.INT4 else False,
                load_in_8bit=True if quant_type == QuantizationType.INT8 else False,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4"
            )

            model = AutoModelForCausalLM.from_pretrained(
                str(path),
                quantization_config=bnb_config,
                torch_dtype=torch.float16,
                trust_remote_code=True,
                device_map="auto" if torch.cuda.is_available() else {"": device}
            )

            tokenizer = AutoTokenizer.from_pretrained(str(path))
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token

            self.model_cache[model_name] = model
            self.tokenizers[model_name] = tokenizer

            config = self._load_config(path)
            model_info = self._create_model_info(
                model_name, path, ModelFormat.QUANTIZED, quant_type,
                model, config
            )

            self.loaded_models[model_name] = model_info
            return True, "Quantized model loaded with bitsandbytes"

        except Exception as e:
            print(f"Bitsandbytes loading failed: {str(e)[:100]}")

        # If all methods fail, provide guidance
        return False, f"Failed to load {quant_type.value} quantized model. The model format may not be compatible with Apple Silicon."

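Aside: the loader above works through the backends in a fixed order (AutoGPTQ, AutoAWQ, accelerate's empty-weights dispatch, then bitsandbytes) and only reports failure once all of them have been tried. The snippet below is a hypothetical convenience wrapper, not part of cortex; it assumes an already constructed ModelManager instance and a made-up local checkpoint path, and it reuses the same `_detect_format` helper the class itself calls.

from pathlib import Path

def try_load_quantized(manager, model_dir: str, name: str) -> None:
    # Resolve the user-supplied directory and reuse the manager's own format detection.
    path = Path(model_dir).expanduser()
    format_info = manager._detect_format(path)
    ok, message = manager._load_quantized(path, name, format_info)
    print(f"{name}: {'loaded' if ok else 'failed'} - {message}")

# try_load_quantized(manager, "~/models/Mistral-7B-GPTQ", "mistral-7b-gptq")
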
    def _load_config(self, path: Path) -> Dict[str, Any]:
        """Load config.json from model path."""
        config_path = path / 'config.json'
        if config_path.exists():
            with open(config_path) as f:
                return json.load(f)
        return {}

    def _create_model_info(
        self,
        model_name: str,
        path: Path,
        format: ModelFormat,
        quantization: QuantizationType,
        model: Any,
        config: Dict[str, Any]
    ) -> ModelInfo:
        """Create ModelInfo object for a loaded model."""
        # Use smart parameter detection instead of loading model parameters
        parameters = self.get_model_parameters_smart(path)

        # Calculate memory usage
        if hasattr(model, 'parameters'):
            try:
                memory_used = sum(p.element_size() * p.numel() for p in model.parameters())
            except:
                memory_used = self._get_model_size(path)
        else:
            memory_used = self._get_model_size(path)

        return ModelInfo(
            name=model_name,
            path=path,
            format=format,
            quantization=quantization,
            size_bytes=self._get_model_size(path),
            parameters=parameters,
            context_length=config.get('max_position_embeddings', 4096),
            loaded_at=datetime.now(),
            gpu_memory_used=memory_used,
            tokenizer_path=path / 'tokenizer.json' if (path / 'tokenizer.json').exists() else None,
            config=config
        )

    def _count_mlx_parameters(self, model: Any) -> int:
        """Count parameters in MLX model."""
        try:
            if hasattr(model, 'num_parameters'):
                return model.num_parameters()
            elif hasattr(model, 'parameters'):
                params = model.parameters()
                if isinstance(params, dict):
                    return sum(p.size for p in params.values())
            return 0
        except:
            return 0

    def _estimate_mlx_memory(self, model: Any) -> int:
        """Estimate memory usage of MLX model."""
        try:
            if hasattr(model, 'parameters'):
                params = model.parameters()
                if isinstance(params, dict):
                    return sum(p.nbytes if hasattr(p, 'nbytes') else 0 for p in params.values())
            return 0
        except:
            return 0

    def unload_model(self, model_name: str) -> Tuple[bool, str]:
        """Unload a model from memory."""
        if model_name not in self.loaded_models:
            return False, f"Model '{model_name}' not loaded"

        try:
            # Special cleanup for GGUF models (llama-cpp-python)
            if model_name in self.model_cache:
                model = self.model_cache[model_name]
                model_info = self.loaded_models.get(model_name)

                # Clean up GGUF models properly to avoid memory leaks
                if model_info and model_info.format == ModelFormat.GGUF:
                    try:
                        # llama-cpp-python models have a close() method
                        if hasattr(model, 'close'):
                            model.close()
                        # Also try to explicitly delete the model
                        del model
                    except Exception as e:
                        print(f"Warning: Error closing GGUF model: {e}")

                # Remove from cache
                del self.model_cache[model_name]

            # Remove tokenizer
            if model_name in self.tokenizers:
                del self.tokenizers[model_name]

            # Remove model info
            del self.loaded_models[model_name]

            # Update current model
            if self.current_model == model_name:
                self.current_model = None

            # Clear GPU cache
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            elif torch.backends.mps.is_available():
                # Clear MPS cache on Apple Silicon
                try:
                    torch.mps.empty_cache()
                except:
                    pass  # MPS cache clearing might not be available in all versions

            # Force garbage collection for thorough cleanup
            import gc
            gc.collect()

            return True, f"Model '{model_name}' unloaded"

        except Exception as e:
            return False, f"Error unloading model: {str(e)}"

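Aside: the unload path above releases memory in a deliberate order: close llama-cpp handles first, then drop the Python references, then empty the device cache, then force a garbage-collection pass. A minimal standalone sketch of the same release pattern is shown below, assuming PyTorch 2.x; the `torch.mps.empty_cache()` call is guarded because older builds do not expose it.

import gc
import torch

def release_model(model) -> None:
    # Drop the local strong reference so the allocator can reclaim the tensors.
    del model
    # Collect cyclic references before asking the backend to trim its cache.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    elif torch.backends.mps.is_available():
        try:
            torch.mps.empty_cache()  # not available in every PyTorch build
        except AttributeError:
            pass
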
    def list_models(self) -> List[Dict[str, Any]]:
        """List all loaded models."""
        return [model.to_dict() for model in self.loaded_models.values()]

    def discover_available_models(self) -> List[Dict[str, Any]]:
        """Discover all available models including MLX converted ones."""
        available = []
        model_path = self.config.model.model_path.expanduser().resolve()

        if not model_path.exists():
            return available

        # Also check MLX models directory for fine-tuned models
        mlx_path = Path.home() / ".cortex" / "mlx_models"

        # First, get all MLX converted models to check for optimized versions
        mlx_models = self.mlx_converter.list_converted_models()
        mlx_cache_map = {}  # Map original paths to MLX versions

        # Build a map of original model paths to their MLX versions
        for name, info in mlx_models.items():
            # Extract original path from MLX model name
            if name.startswith("_Users_") and ("_4bit" in name or "_5bit" in name or "_8bit" in name):
                base_name = name.replace("_4bit", "").replace("_5bit", "").replace("_8bit", "")
                original_path = "/" + base_name[1:].replace("_", "/")
                mlx_cache_map[original_path] = {
                    'mlx_name': name,
                    'mlx_path': info['path'],
                    'quantization': info.get('quantization', 4),
                    'size_gb': info.get('size_gb', 0)
                }

        # Search for models in the models directory
        for item in model_path.iterdir():
            if item.is_file() and item.suffix == '.gguf':
                # GGUF file
                size_gb = item.stat().st_size / (1024**3)
                available.append({
                    'name': item.stem,
                    'path': str(item),
                    'relative_path': item.name,
                    'format': 'GGUF',
                    'size_gb': round(size_gb, 2),
                    'mlx_optimized': False,
                    'mlx_available': False
                })
            elif item.is_dir():
                # Check if it's a model directory
                format_info = self._detect_format(item)
                if format_info['format'] != ModelFormat.UNKNOWN:
                    # Check if this model has an MLX optimized version
                    full_path = str(item.resolve())
                    has_mlx = full_path in mlx_cache_map

                    if has_mlx:
                        # Use the MLX version's info
                        mlx_info = mlx_cache_map[full_path]
                        available.append({
                            'name': item.name,
                            'path': str(item),  # Keep original path for compatibility
                            'relative_path': item.name,
                            'format': 'MLX-optimized',
                            'size_gb': round(mlx_info['size_gb'], 2) if mlx_info['size_gb'] > 0 else round(self._get_model_size(item) / (1024**3), 2),
                            'mlx_optimized': True,
                            'mlx_path': mlx_info['mlx_path'],
                            'mlx_name': mlx_info['mlx_name'],
                            'quantization': f"{mlx_info['quantization']}-bit",
                            'original_format': format_info['format'].value.upper()
                        })
                    else:
                        # Regular model without MLX optimization
                        size_gb = self._get_model_size(item) / (1024**3)
                        format_str = format_info['format'].value.upper()
                        if format_info['quantization'] != QuantizationType.NONE:
                            format_str = f"{format_str} ({format_info['quantization'].value})"

                        available.append({
                            'name': item.name,
                            'path': str(item),
                            'relative_path': item.name,
                            'format': format_str,
                            'size_gb': round(size_gb, 2),
                            'mlx_optimized': False,
                            'mlx_available': format_info['format'] in [
                                ModelFormat.SAFETENSORS,
                                ModelFormat.PYTORCH
                            ]
                        })

        # Add fine-tuned models from MLX directory that aren't already included
        if mlx_path.exists():
            for item in mlx_path.iterdir():
                if item.is_dir():
                    # Check if it's a fine-tuned model that's not already in the list
                    if (item / 'fine_tuned.marker').exists():
                        # This is a fine-tuned model
                        already_added = any(model['name'] == item.name for model in available)
                        if not already_added:
                            size_gb = self._get_model_size(item) / (1024**3)

                            # Read metadata from marker file
                            base_model = "Unknown"
                            try:
                                with open(item / 'fine_tuned.marker', 'r') as f:
                                    content = f.read()
                                    for line in content.split('\n'):
                                        if 'LoRA fine-tuned version of' in line:
                                            base_model = line.split('LoRA fine-tuned version of ')[-1].strip()
                                            break
                            except:
                                pass

                            available.append({
                                'name': item.name,
                                'path': str(item),
                                'relative_path': item.name,
                                'format': 'MLX Fine-tuned',
                                'size_gb': round(size_gb, 2),
                                'mlx_optimized': True,
                                'is_fine_tuned': True,
                                'base_model': base_model,
                                'fine_tuning_method': 'LoRA'
                            })

        # Sort by name
        available.sort(key=lambda x: x['name'].lower())
        return available

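Aside: each discovered model is reported as a plain dict with the keys built above. The example below shows the shape produced for an un-optimized safetensors checkout; the model name, path, and size are invented purely for illustration.

# Shape of one discover_available_models() entry (values are illustrative):
example_entry = {
    'name': 'Llama-3-8B-Instruct',
    'path': '/Users/me/models/Llama-3-8B-Instruct',
    'relative_path': 'Llama-3-8B-Instruct',
    'format': 'SAFETENSORS',
    'size_gb': 15.02,
    'mlx_optimized': False,
    'mlx_available': True,
}
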
    def get_current_model(self) -> Optional[ModelInfo]:
        """Get currently active model."""
        if self.current_model:
            return self.loaded_models.get(self.current_model)
        return None

    def switch_model(self, model_name: str) -> Tuple[bool, str]:
        """Switch to a different loaded model."""
        if model_name not in self.loaded_models:
            return False, f"Model '{model_name}' not loaded"

        self.current_model = model_name
        return True, f"Switched to model '{model_name}'"

    def get_memory_status(self) -> Dict[str, Any]:
        """Get GPU memory status with MLX details."""
        status = self.gpu_validator.get_gpu_memory_status()

        total_model_memory = sum(
            model.gpu_memory_used for model in self.loaded_models.values()
        )

        status['models_loaded'] = len(self.loaded_models)
        status['model_memory_gb'] = total_model_memory / (1024**3)
        status['current_model'] = self.current_model

        # Add MLX-specific info
        mlx_models = [m for m in self.loaded_models.values() if m.format == ModelFormat.MLX]
        if mlx_models:
            status['mlx_models'] = len(mlx_models)
            status['mlx_memory_gb'] = sum(m.gpu_memory_used for m in mlx_models) / (1024**3)

        if self.memory_pool:
            pool_stats = self.memory_pool.get_stats()
            status['memory_pool'] = {
                'allocated_gb': pool_stats['allocated_memory'] / (1024**3),
                'free_gb': pool_stats['free_memory'] / (1024**3),
                'fragmentation': pool_stats['fragmentation'],
                'total_blocks': pool_stats['total_blocks'],
                'zero_copy_enabled': pool_stats.get('zero_copy', False)
            }

        # Add MLX accelerator status
        if self.mlx_accelerator:
            status['mlx_acceleration'] = {
                'amx_enabled': self.mlx_accelerator.config.use_amx,
                'operation_fusion': self.mlx_accelerator.config.fuse_operations,
                'kv_cache_size': self.mlx_accelerator.config.kv_cache_size,
                'quantization_bits': self.mlx_accelerator.config.quantization_bits
            }

        return status

    def detect_model_parameters(self, model_path: Path) -> Optional[int]:
        """
        Detect the actual number of parameters in a model.

        Uses proper parameter counting that handles quantization, LoRA adapters,
        and non-weight files correctly.

        Returns:
            Number of parameters, or None if detection fails
        """
        try:
            # Check cache first
            cache_key = f"{model_path.resolve()}:{model_path.stat().st_mtime}"
            cached_result = self._get_cached_parameter_count(cache_key)
            if cached_result is not None:
                logger.debug(f"Using cached parameter count: {cached_result:,}")
                return cached_result

            # Detect model format and apply appropriate detection method
            format_info = self._detect_format(model_path)
            param_count = None

            if format_info['format'] == ModelFormat.SAFETENSORS:
                param_count = self._detect_safetensors_parameters(model_path)
            elif format_info['format'] == ModelFormat.MLX:
                param_count = self._detect_mlx_parameters(model_path)
            elif format_info['format'] == ModelFormat.PYTORCH:
                param_count = self._detect_pytorch_parameters(model_path)

            # Fallback to config.json analysis
            if param_count is None:
                param_count = self._detect_config_parameters(model_path)

            # Cache the result if successful
            if param_count is not None:
                self._cache_parameter_count(cache_key, param_count)
                logger.info(f"Detected {param_count:,} parameters in {model_path.name}")
            else:
                logger.warning(f"Could not detect parameters for {model_path.name}")

            return param_count

        except Exception as e:
            logger.warning(f"Parameter detection failed for {model_path}: {e}")
            return None

    def _get_cached_parameter_count(self, cache_key: str) -> Optional[int]:
        """Get cached parameter count."""
        cache_file = self.config.model.model_cache_dir / "parameter_counts.json"
        if not cache_file.exists():
            return None

        try:
            with open(cache_file, 'r') as f:
                cache = json.load(f)
                return cache.get(cache_key)
        except:
            return None

    def _cache_parameter_count(self, cache_key: str, param_count: int) -> None:
        """Cache parameter count for faster future lookups."""
        cache_file = self.config.model.model_cache_dir / "parameter_counts.json"

        # Load existing cache
        cache = {}
        if cache_file.exists():
            try:
                with open(cache_file, 'r') as f:
                    cache = json.load(f)
            except:
                pass

        # Update cache
        cache[cache_key] = param_count

        # Keep only recent entries (last 100)
        if len(cache) > 100:
            sorted_items = sorted(cache.items(), key=lambda x: x[0])[-100:]
            cache = dict(sorted_items)

        # Save cache
        try:
            cache_file.parent.mkdir(parents=True, exist_ok=True)
            with open(cache_file, 'w') as f:
                json.dump(cache, f)
        except Exception as e:
            logger.warning(f"Failed to cache parameter count: {e}")

    def _detect_safetensors_parameters(self, model_path: Path) -> Optional[int]:
        """Detect parameters by reading SafeTensors headers."""
        try:
            safetensor_files = list(model_path.glob("*.safetensors"))
            if not safetensor_files:
                return None

            total_params = 0

            for st_file in safetensor_files:
                # Skip adapter files for base model parameter counting
                if "adapter" in st_file.name.lower():
                    continue

                # Read SafeTensors header to get tensor shapes without loading weights
                params = self._read_safetensors_header(st_file)
                if params is not None:
                    total_params += params
                else:
                    logger.warning(f"Could not read SafeTensors header from {st_file.name}")
                    return None

            return total_params if total_params > 0 else None

        except Exception as e:
            logger.warning(f"SafeTensors parameter detection failed: {e}")
            return None

    def _read_safetensors_header(self, file_path: Path) -> Optional[int]:
        """Read parameter count from SafeTensors file header without loading the full file."""
        try:
            with open(file_path, 'rb') as f:
                # Read the header length (first 8 bytes)
                header_size_bytes = f.read(8)
                if len(header_size_bytes) < 8:
                    return None

                header_size = struct.unpack('<Q', header_size_bytes)[0]

                # Read the header JSON
                header_json = f.read(header_size).decode('utf-8')
                header = json.loads(header_json)

                # Count parameters from tensor shapes
                total_params = 0
                for tensor_name, tensor_info in header.items():
                    if tensor_name == "__metadata__":
                        continue

                    # Skip non-parameter tensors (buffers, etc.)
                    if self._is_parameter_tensor(tensor_name):
                        shape = tensor_info.get('shape', [])
                        if shape:
                            tensor_params = 1
                            for dim in shape:
                                tensor_params *= dim
                            total_params += tensor_params

                return total_params

        except Exception as e:
            logger.debug(f"Failed to read SafeTensors header from {file_path}: {e}")
            return None

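Aside: the header trick above relies on the SafeTensors layout, an 8-byte little-endian length prefix followed by a JSON header that lists each tensor's dtype, shape, and data offsets, so shapes can be read without mapping the weights. Below is a standalone sketch for spot-checking a single shard from the command line; unlike the method above it counts every tensor and applies no parameter/buffer filtering, and the file name in the usage comment is hypothetical.

import json
import struct
import sys

def count_safetensors_elements(filename: str) -> int:
    with open(filename, 'rb') as f:
        # The first 8 bytes hold the header length as a little-endian uint64.
        header_size = struct.unpack('<Q', f.read(8))[0]
        header = json.loads(f.read(header_size))
    total = 0
    for name, info in header.items():
        if name == "__metadata__":
            continue
        count = 1
        for dim in info["shape"]:
            count *= dim
        total += count
    return total

# Usage (hypothetical shard name):
#   python check_params.py model-00001-of-00004.safetensors
if __name__ == "__main__":
    print(f"{count_safetensors_elements(sys.argv[1]):,} elements")
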
    def _is_parameter_tensor(self, tensor_name: str) -> bool:
        """Check if a tensor name represents a model parameter (not a buffer)."""
        # Common parameter patterns
        param_patterns = [
            'weight', 'bias', 'embeddings', 'lm_head',
            'q_proj', 'k_proj', 'v_proj', 'o_proj',
            'gate_proj', 'up_proj', 'down_proj',
            'fc1', 'fc2', 'mlp', 'attention'
        ]

        # Common non-parameter patterns (buffers)
        non_param_patterns = [
            'position_ids', 'attention_mask', 'token_type_ids',
            'freqs_cos', 'freqs_sin', 'inv_freq'
        ]

        tensor_lower = tensor_name.lower()

        # Check non-parameter patterns first
        for pattern in non_param_patterns:
            if pattern in tensor_lower:
                return False

        # Check parameter patterns
        for pattern in param_patterns:
            if pattern in tensor_lower:
                return True

        # Default: assume it's a parameter if it contains common layer indicators
        return any(indicator in tensor_lower for indicator in ['layer', 'block', 'transformer'])

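Aside: a doctest-style illustration of the heuristic, assuming a ModelManager instance `mm`; the tensor names are typical Llama-style keys used purely as examples.

# >>> mm._is_parameter_tensor("model.layers.0.self_attn.q_proj.weight")
# True      # matches the 'weight' / 'q_proj' parameter patterns
# >>> mm._is_parameter_tensor("model.layers.0.self_attn.rotary_emb.inv_freq")
# False     # 'inv_freq' is in the buffer (non-parameter) list
# >>> mm._is_parameter_tensor("lm_head.weight")
# True
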
    def _detect_mlx_parameters(self, model_path: Path) -> Optional[int]:
        """Detect parameters in MLX models by inspecting weights.npz or using config."""
        try:
            # First try to read from weights.npz directly
            weights_file = model_path / "weights.npz"
            if weights_file.exists():
                import numpy as np

                # Load the weights file
                weights = np.load(weights_file)
                total_params = 0

                for array_name in weights.files:
                    if self._is_parameter_tensor(array_name):
                        array = weights[array_name]
                        total_params += array.size

                return total_params if total_params > 0 else None

            # Fallback to checking for SafeTensors in MLX directory
            safetensor_files = list(model_path.glob("*.safetensors"))
            if safetensor_files:
                return self._detect_safetensors_parameters(model_path)

            return None

        except Exception as e:
            logger.warning(f"MLX parameter detection failed: {e}")
            return None

    def _detect_pytorch_parameters(self, model_path: Path) -> Optional[int]:
        """Detect parameters in PyTorch models."""
        try:
            # For PyTorch models we read the config to get architecture info,
            # as loading the full model would be too expensive
            return self._detect_config_parameters(model_path)

        except Exception as e:
            logger.warning(f"PyTorch parameter detection failed: {e}")
            return None

    def _detect_config_parameters(self, model_path: Path) -> Optional[int]:
        """Detect parameters by analyzing config.json and calculating from architecture."""
        try:
            config_path = model_path / "config.json"
            if not config_path.exists():
                return None

            with open(config_path, 'r') as f:
                config = json.load(f)

            # Check for directly specified parameter count
            if 'num_parameters' in config:
                return int(config['num_parameters'])

            # Calculate from architecture parameters
            model_type = config.get('model_type', '').lower()

            if model_type in ['llama', 'gemma', 'mistral', 'qwen']:
                return self._calculate_llama_parameters(config)
            elif model_type in ['gpt', 'gpt2', 'gpt_neo', 'gpt_neox']:
                return self._calculate_gpt_parameters(config)
            elif model_type in ['bert', 'roberta', 'distilbert']:
                return self._calculate_bert_parameters(config)
            else:
                # Generic calculation for transformer models
                return self._calculate_generic_transformer_parameters(config)

        except Exception as e:
            logger.warning(f"Config parameter detection failed: {e}")
            return None

    def _calculate_llama_parameters(self, config: Dict) -> Optional[int]:
        """Calculate parameters for Llama-style models (including Gemma)."""
        try:
            vocab_size = config.get('vocab_size', 32000)
            hidden_size = config.get('hidden_size', 4096)
            intermediate_size = config.get('intermediate_size', 11008)
            num_layers = config.get('num_hidden_layers', 32)
            num_attention_heads = config.get('num_attention_heads', 32)

            # Check if this is a Gemma model for special handling
            is_gemma = config.get('model_type', '').lower() == 'gemma'

            # Embedding layer
            embedding_params = vocab_size * hidden_size

            # Each transformer layer:
            if is_gemma:
                # Gemma uses grouped query attention with fewer k/v heads
                num_key_value_heads = config.get('num_key_value_heads', num_attention_heads // 4)
                head_dim = hidden_size // num_attention_heads

                # Attention projections
                q_proj_params = hidden_size * hidden_size  # Full size
                k_proj_params = hidden_size * (num_key_value_heads * head_dim)  # Reduced
                v_proj_params = hidden_size * (num_key_value_heads * head_dim)  # Reduced
                o_proj_params = hidden_size * hidden_size  # Full size

                attention_params = q_proj_params + k_proj_params + v_proj_params + o_proj_params
            else:
                # Standard Llama: q, k, v, o projections all full size
                attention_params = 4 * (hidden_size * hidden_size)

            # Feed-forward: gate_proj, up_proj, down_proj
            ff_params = 2 * (hidden_size * intermediate_size) + (intermediate_size * hidden_size)

            # Layer norms (2 per layer)
            ln_params = 2 * hidden_size

            layer_params = attention_params + ff_params + ln_params
            transformer_params = num_layers * layer_params

            # Final layer norm
            final_ln_params = hidden_size

            # LM head - check if tied to embeddings (common in smaller models)
            tie_word_embeddings = config.get('tie_word_embeddings', True)  # Default True for most models
            if tie_word_embeddings:
                lm_head_params = 0  # Tied to embeddings, don't double count
            else:
                lm_head_params = vocab_size * hidden_size

            total = embedding_params + transformer_params + final_ln_params + lm_head_params

            return total

        except Exception as e:
            logger.warning(f"Llama parameter calculation failed: {e}")
            return None

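Aside: as a sanity check of the formula above, plugging in the published Llama-2-7B hyperparameters (with an untied LM head) reproduces the advertised parameter count. The config dict below is hand-written for illustration, not read from any file.

# Worked example with Llama-2-7B-style hyperparameters:
llama2_7b_cfg = {
    'model_type': 'llama',
    'vocab_size': 32000,
    'hidden_size': 4096,
    'intermediate_size': 11008,
    'num_hidden_layers': 32,
    'num_attention_heads': 32,
    'tie_word_embeddings': False,  # Llama 2 keeps a separate lm_head
}
# Per layer: 4 * 4096^2 (attention) + 3 * 4096 * 11008 (SwiGLU MLP) + 2 * 4096 (norms)
#          = 67,108,864 + 135,266,304 + 8,192 = 202,383,360
# Total:   32,000 * 4,096 (embeddings) + 32 * 202,383,360 + 4,096 (final norm)
#          + 32,000 * 4,096 (lm_head) = 6,738,415,616, i.e. about 6.74B.
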
    def _calculate_gpt_parameters(self, config: Dict) -> Optional[int]:
        """Calculate parameters for GPT-style models."""
        try:
            vocab_size = config.get('vocab_size', 50257)
            n_embd = config.get('n_embd', config.get('hidden_size', 768))
            n_layer = config.get('n_layer', config.get('num_hidden_layers', 12))
            n_head = config.get('n_head', config.get('num_attention_heads', 12))

            # Token + position embeddings
            max_position_embeddings = config.get('n_positions', config.get('max_position_embeddings', 1024))
            embedding_params = vocab_size * n_embd + max_position_embeddings * n_embd

            # Each transformer block
            # - Attention: qkv projection + output projection
            attention_params = 4 * (n_embd * n_embd)

            # - MLP: typically 4x expansion
            mlp_size = config.get('n_inner', 4 * n_embd)
            mlp_params = n_embd * mlp_size + mlp_size * n_embd

            # - Layer norms
            ln_params = 2 * n_embd

            block_params = attention_params + mlp_params + ln_params
            transformer_params = n_layer * block_params

            # Final layer norm + LM head
            final_ln_params = n_embd
            lm_head_params = vocab_size * n_embd

            total = embedding_params + transformer_params + final_ln_params + lm_head_params

            return total

        except Exception as e:
            logger.warning(f"GPT parameter calculation failed: {e}")
            return None

    def _calculate_bert_parameters(self, config: Dict) -> Optional[int]:
        """Calculate parameters for BERT-style models."""
        try:
            vocab_size = config.get('vocab_size', 30522)
            hidden_size = config.get('hidden_size', 768)
            num_hidden_layers = config.get('num_hidden_layers', 12)
            intermediate_size = config.get('intermediate_size', 3072)
            max_position_embeddings = config.get('max_position_embeddings', 512)
            type_vocab_size = config.get('type_vocab_size', 2)

            # Embeddings: token + position + token_type
            embedding_params = (vocab_size * hidden_size +
                                max_position_embeddings * hidden_size +
                                type_vocab_size * hidden_size)

            # Each encoder layer
            # - Self-attention
            attention_params = 4 * (hidden_size * hidden_size)

            # - Feed-forward
            ff_params = hidden_size * intermediate_size + intermediate_size * hidden_size

            # - Layer norms
            ln_params = 2 * hidden_size

            layer_params = attention_params + ff_params + ln_params
            encoder_params = num_hidden_layers * layer_params

            # Pooler (optional)
            pooler_params = hidden_size * hidden_size

            total = embedding_params + encoder_params + pooler_params

            return total

        except Exception as e:
            logger.warning(f"BERT parameter calculation failed: {e}")
            return None

    def _calculate_generic_transformer_parameters(self, config: Dict) -> Optional[int]:
        """Generic parameter calculation for transformer models."""
        try:
            # Try to extract common parameters
            vocab_size = config.get('vocab_size', 32000)
            hidden_size = config.get('hidden_size', config.get('n_embd', config.get('d_model', 512)))
            num_layers = config.get('num_hidden_layers', config.get('n_layer', config.get('num_layers', 6)))

            if hidden_size is None or num_layers is None:
                return None

            # Very rough estimation for generic transformers
            # Embeddings + layers + head
            embedding_params = vocab_size * hidden_size

            # Each layer: attention + ffn + norms (rough 6x hidden_size^2 per layer)
            layer_params = 6 * (hidden_size * hidden_size)
            transformer_params = num_layers * layer_params

            # Output head
            head_params = vocab_size * hidden_size

            total = embedding_params + transformer_params + head_params

            logger.info(f"Generic parameter estimation: {total:,} parameters")
            return total

        except Exception as e:
            logger.warning(f"Generic parameter calculation failed: {e}")
            return None

    def get_model_parameters_smart(self, model_path: Path) -> float:
        """Get model parameters in billions with smart detection, fallback to size estimation."""
        # Try accurate parameter detection first
        param_count = self.detect_model_parameters(model_path)

        if param_count is not None:
            return param_count / 1e9  # Convert to billions

        # Fallback to improved size-based estimation
        logger.warning(f"Falling back to size-based parameter estimation for {model_path.name}")

        size_bytes = self._get_model_size(model_path)
        size_gb = size_bytes / (1024**3)

        # Improved estimation that considers file overhead
        # Only count actual weight files, not tokenizer configs etc.
        weight_size = self._estimate_weight_file_size(model_path)
        weight_size_gb = weight_size / (1024**3)

        # Use weight size if significantly different from total size
        if weight_size_gb < size_gb * 0.8:  # If weight files are <80% of total
            size_gb = weight_size_gb
            logger.info(f"Using weight-only size: {size_gb:.2f}GB (total: {size_bytes / (1024**3):.2f}GB)")

        # Better default estimation: 2.2 bytes per parameter (accounts for some overhead)
        estimated_params_b = size_gb / 2.2  # Already in billions

        logger.info(f"Estimated {estimated_params_b:.2f}B parameters from {size_gb:.2f}GB model size")
        return estimated_params_b

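Aside: a quick worked example of the size-based fallback, with an invented shard total; the 2.2 bytes-per-parameter divisor is the heuristic defined just above and is meant for memory planning, not exact counts.

weight_size_gb = 14.2  # hypothetical: sum of *.safetensors shards in GB
estimated_params_b = weight_size_gb / 2.2
print(f"~{estimated_params_b:.2f}B parameters")  # ~6.45B
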
    def _estimate_weight_file_size(self, model_path: Path) -> int:
        """Estimate size of actual weight files, excluding configs and tokenizers."""
        if model_path.is_file():
            return model_path.stat().st_size

        weight_patterns = [
            '*.safetensors', '*.bin', '*.npz',
            'pytorch_model*.bin', 'model*.safetensors'
        ]

        non_weight_patterns = [
            'tokenizer*', 'vocab*', 'merges.txt', 'config.json',
            'generation_config.json', 'special_tokens_map.json',
            'tokenizer_config.json', 'added_tokens.json'
        ]

        total_weight_size = 0

        for file_path in model_path.rglob('*'):
            if file_path.is_file():
                # Check if it matches weight patterns
                is_weight = any(file_path.match(pattern) for pattern in weight_patterns)

                # Exclude non-weight files
                is_non_weight = any(file_path.match(pattern) for pattern in non_weight_patterns)

                if is_weight and not is_non_weight:
                    total_weight_size += file_path.stat().st_size
                elif not is_non_weight and file_path.suffix in ['.safetensors', '.bin', '.npz']:
                    # Include other tensor files that don't match specific patterns
                    total_weight_size += file_path.stat().st_size

        return total_weight_size

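Aside: a small hypothetical sanity check built on this helper compares the weight-only total against the full directory size; a clean checkout should be dominated by weight shards, while a low ratio suggests caches or adapter artifacts are inflating the directory. The path in the usage comment is made up, and `manager` is assumed to be a ModelManager instance.

from pathlib import Path

def weight_fraction(manager, model_dir: str) -> float:
    # Ratio of weight-file bytes to all bytes under the model directory.
    path = Path(model_dir).expanduser()
    weights = manager._estimate_weight_file_size(path)
    total = sum(p.stat().st_size for p in path.rglob('*') if p.is_file())
    return weights / total if total else 0.0

# weight_fraction(manager, "~/models/Llama-3-8B-Instruct")  # e.g. ~0.99 for a clean checkout
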