cortex-llm 1.0.0 (py3-none-any.whl)
This diff shows the contents of a publicly available package version as released to one of the supported registries. It is provided for informational purposes only and reflects the package as it appears in its public registry.
- cortex/__init__.py +73 -0
- cortex/__main__.py +83 -0
- cortex/config.py +329 -0
- cortex/conversation_manager.py +468 -0
- cortex/fine_tuning/__init__.py +8 -0
- cortex/fine_tuning/dataset.py +332 -0
- cortex/fine_tuning/mlx_lora_trainer.py +502 -0
- cortex/fine_tuning/trainer.py +957 -0
- cortex/fine_tuning/wizard.py +707 -0
- cortex/gpu_validator.py +467 -0
- cortex/inference_engine.py +727 -0
- cortex/metal/__init__.py +275 -0
- cortex/metal/gpu_validator.py +177 -0
- cortex/metal/memory_pool.py +886 -0
- cortex/metal/mlx_accelerator.py +678 -0
- cortex/metal/mlx_converter.py +638 -0
- cortex/metal/mps_optimizer.py +417 -0
- cortex/metal/optimizer.py +665 -0
- cortex/metal/performance_profiler.py +364 -0
- cortex/model_downloader.py +130 -0
- cortex/model_manager.py +2187 -0
- cortex/quantization/__init__.py +5 -0
- cortex/quantization/dynamic_quantizer.py +736 -0
- cortex/template_registry/__init__.py +15 -0
- cortex/template_registry/auto_detector.py +144 -0
- cortex/template_registry/config_manager.py +234 -0
- cortex/template_registry/interactive.py +260 -0
- cortex/template_registry/registry.py +347 -0
- cortex/template_registry/template_profiles/__init__.py +5 -0
- cortex/template_registry/template_profiles/base.py +142 -0
- cortex/template_registry/template_profiles/complex/__init__.py +5 -0
- cortex/template_registry/template_profiles/complex/reasoning.py +263 -0
- cortex/template_registry/template_profiles/standard/__init__.py +9 -0
- cortex/template_registry/template_profiles/standard/alpaca.py +73 -0
- cortex/template_registry/template_profiles/standard/chatml.py +82 -0
- cortex/template_registry/template_profiles/standard/gemma.py +103 -0
- cortex/template_registry/template_profiles/standard/llama.py +87 -0
- cortex/template_registry/template_profiles/standard/simple.py +65 -0
- cortex/ui/__init__.py +120 -0
- cortex/ui/cli.py +1685 -0
- cortex/ui/markdown_render.py +185 -0
- cortex/ui/terminal_app.py +534 -0
- cortex_llm-1.0.0.dist-info/METADATA +275 -0
- cortex_llm-1.0.0.dist-info/RECORD +48 -0
- cortex_llm-1.0.0.dist-info/WHEEL +5 -0
- cortex_llm-1.0.0.dist-info/entry_points.txt +2 -0
- cortex_llm-1.0.0.dist-info/licenses/LICENSE +21 -0
- cortex_llm-1.0.0.dist-info/top_level.txt +1 -0
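
A wheel is a standard ZIP archive, so the listing above can be reproduced from a locally downloaded copy. A minimal sketch, assuming the wheel sits in the current directory under the filename implied by the dist-info entries above:

import zipfile

# List every file shipped in the wheel together with its uncompressed size.
# The filename below is an assumption based on the dist-info name in the listing.
with zipfile.ZipFile("cortex_llm-1.0.0-py3-none-any.whl") as whl:
    for info in whl.infolist():
        print(f"{info.filename}  ({info.file_size} bytes)")
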
cortex/gpu_validator.py
ADDED
@@ -0,0 +1,467 @@
"""GPU validation for Metal/MPS support on Apple Silicon."""

import sys
import platform
import subprocess
from typing import Dict, Any, Optional, Tuple
from dataclasses import dataclass
import psutil

@dataclass
class GPUInfo:
    """GPU information and capabilities."""
    has_metal: bool
    has_mps: bool
    has_mlx: bool
    gpu_cores: int
    total_memory: int
    available_memory: int
    metal_version: Optional[str]
    chip_name: str
    unified_memory: bool
    is_apple_silicon: bool

    # MSL v4 capabilities
    gpu_family: str  # apple5 (M1), apple6 (M2), apple7 (M3), apple8 (M4)
    supports_bfloat16: bool
    supports_simdgroup_matrix: bool
    supports_mpp: bool
    supports_tile_functions: bool
    supports_atomic_float: bool
    supports_fast_math: bool
    supports_function_constants: bool
    max_threads_per_threadgroup: int

    @property
    def is_valid(self) -> bool:
        """Check if GPU meets requirements."""
        # Minimum requirements for production
        min_memory = 4 * 1024 * 1024 * 1024  # 4GB minimum for small models
        min_cores = 8  # M1 and above have at least 8 cores

        return (
            self.has_metal and
            self.has_mps and
            self.has_mlx and
            self.gpu_cores >= min_cores and
            self.available_memory >= min_memory
        )

    def get_validation_errors(self) -> list[str]:
        """Get list of validation errors."""
        errors = []
        min_memory = 4 * 1024 * 1024 * 1024  # 4GB minimum
        min_cores = 8

        if not self.has_metal:
            errors.append("Metal support not available")
        if not self.has_mps:
            errors.append("Metal Performance Shaders (MPS) not available")
        if not self.has_mlx:
            errors.append("MLX framework not available (install with: pip install mlx)")
        if self.gpu_cores < min_cores:
            errors.append(f"Insufficient GPU cores: {self.gpu_cores} (need {min_cores})")
        if self.available_memory < min_memory:
            memory_gb = self.available_memory / (1024 * 1024 * 1024)
            min_memory_gb = min_memory / (1024 * 1024 * 1024)
            errors.append(f"Insufficient GPU memory: {memory_gb:.1f}GB (need {min_memory_gb:.1f}GB)")
        return errors

class GPUValidator:
    """Validate GPU capabilities for Cortex."""

    def __init__(self, config=None):
        """Initialize GPU validator."""
        self.config = config  # Store config if provided
        self.gpu_info: Optional[GPUInfo] = None
        self._torch_available = False
        self._mlx_available = False
        self._validate_imports()

    def _validate_imports(self) -> None:
        """Validate that required GPU libraries are available."""
        try:
            import torch
            self._torch_available = True
        except ImportError:
            self._torch_available = False

        try:
            import mlx.core as mx
            self._mlx_available = True
        except ImportError:
            self._mlx_available = False

    def validate(self) -> Tuple[bool, Optional[GPUInfo], list[str]]:
        """
        Validate GPU support.

        Returns:
            Tuple of (is_valid, gpu_info, errors)
        """
        errors = []

        if platform.system().lower() != "darwin":
            errors.append(f"macOS required, found {platform.system()}")
            return False, None, errors

        if platform.machine() != "arm64":
            errors.append(f"ARM64 architecture required, found {platform.machine()}")
            return False, None, errors

        self.gpu_info = self._get_gpu_info()

        if not self.gpu_info.is_valid:
            errors.extend(self.gpu_info.get_validation_errors())
            return False, self.gpu_info, errors

        return True, self.gpu_info, []

    def _get_gpu_info(self) -> GPUInfo:
        """Get GPU information from system."""
        chip_name = self._get_chip_name()
        gpu_cores = self._get_gpu_cores(chip_name)
        memory_info = self._get_memory_info()

        has_metal = self._check_metal_support()
        has_mps = self._check_mps_support()
        has_mlx = self._check_mlx_support()
        metal_version = self._get_metal_version()

        # Detect GPU family and capabilities
        gpu_family = self._detect_gpu_family(chip_name)
        capabilities = self._detect_msl_capabilities(gpu_family)

        return GPUInfo(
            has_metal=has_metal,
            has_mps=has_mps,
            has_mlx=has_mlx,
            gpu_cores=gpu_cores,
            total_memory=memory_info['total'],
            available_memory=memory_info['available'],
            metal_version=metal_version,
            chip_name=chip_name,
            unified_memory=True,
            is_apple_silicon=any(chip in chip_name for chip in ["M1", "M2", "M3", "M4"]),
            gpu_family=gpu_family,
            supports_bfloat16=capabilities['bfloat16'],
            supports_simdgroup_matrix=capabilities['simdgroup_matrix'],
            supports_mpp=capabilities['mpp'],
            supports_tile_functions=capabilities['tile_functions'],
            supports_atomic_float=capabilities['atomic_float'],
            supports_fast_math=capabilities['fast_math'],
            supports_function_constants=capabilities['function_constants'],
            max_threads_per_threadgroup=capabilities['max_threads_per_threadgroup']
        )

    def _get_chip_name(self) -> str:
        """Get Apple Silicon chip name."""
        try:
            result = subprocess.run(
                ["sysctl", "-n", "machdep.cpu.brand_string"],
                capture_output=True,
                text=True,
                check=True
            )
            return result.stdout.strip()
        except subprocess.CalledProcessError:
            return "Unknown"

    def _get_gpu_cores(self, chip_name: str) -> int:
        """Get number of GPU cores based on chip."""
        # Most specific names first so "M4 Pro"/"M4 Max" are not matched
        # by the base "M4" entry during the substring scan below.
        gpu_core_map = {
            "M4 Max": 40,
            "M4 Pro": 20,
            "M4": 16,
            "M3 Max": 40,
            "M3 Pro": 18,
            "M3": 10,
            "M2 Max": 38,
            "M2 Pro": 19,
            "M2": 10,
            "M1 Max": 32,
            "M1 Pro": 16,
            "M1": 8,
        }

        for chip, cores in gpu_core_map.items():
            if chip in chip_name:
                return cores

        return 0

    def _get_memory_info(self) -> Dict[str, int]:
        """Get memory information."""
        vm = psutil.virtual_memory()
        return {
            'total': vm.total,
            'available': vm.available,
            'used': vm.used,
            'percent': vm.percent
        }

    def _check_metal_support(self) -> bool:
        """Check if Metal is supported."""
        try:
            result = subprocess.run(
                ["system_profiler", "SPDisplaysDataType"],
                capture_output=True,
                text=True,
                check=True
            )
            return "Metal" in result.stdout
        except subprocess.CalledProcessError:
            return False

    def _check_mps_support(self) -> bool:
        """Check if MPS (Metal Performance Shaders) is available."""
        if not self._torch_available:
            return False

        try:
            import torch
            return torch.backends.mps.is_available()
        except Exception:
            return False

    def _check_mlx_support(self) -> bool:
        """Check if MLX is available."""
        if not self._mlx_available:
            return False

        try:
            import mlx.core as mx
            device = mx.default_device()
            return "gpu" in str(device).lower()
        except Exception:
            return False

    def _get_metal_version(self) -> Optional[str]:
        """Get Metal version."""
        try:
            result = subprocess.run(
                ["xcrun", "--show-sdk-version"],
                capture_output=True,
                text=True,
                check=True
            )
            sdk_version = result.stdout.strip()

            if float(sdk_version.split('.')[0]) >= 14:
                return "Metal 3"
            else:
                return "Metal 2"
        except Exception:
            return None

    def _detect_gpu_family(self, chip_name: str) -> str:
        """
        Detect GPU family based on chip name.

        Returns:
            GPU family identifier (apple5, apple6, apple7, apple8)
        """
        # Map chip names to GPU families
        if "M4" in chip_name:
            return "apple8"
        elif "M3" in chip_name:
            return "apple7"
        elif "M2" in chip_name:
            return "apple6"
        elif "M1" in chip_name:
            return "apple5"
        else:
            # Default to M1 capabilities for unknown chips
            return "apple5"

    def _detect_msl_capabilities(self, gpu_family: str) -> Dict[str, Any]:
        """
        Detect MSL v4 capabilities based on GPU family.

        Args:
            gpu_family: GPU family identifier

        Returns:
            Dictionary of capabilities
        """
        # Base capabilities (M1 and all Apple Silicon)
        capabilities = {
            'bfloat16': False,
            'simdgroup_matrix': True,  # All Apple Silicon supports this
            'mpp': False,
            'tile_functions': False,
            'atomic_float': True,
            'fast_math': True,
            'function_constants': True,
            'max_threads_per_threadgroup': 1024
        }

        # M2 and later capabilities
        if gpu_family in ["apple6", "apple7", "apple8"]:
            capabilities.update({
                'bfloat16': True,
                'mpp': True,
                'max_threads_per_threadgroup': 1024
            })

        # M3 and later capabilities
        if gpu_family in ["apple7", "apple8"]:
            capabilities.update({
                'tile_functions': True,
                'max_threads_per_threadgroup': 1024
            })

        return capabilities

    def get_optimal_dtype(self) -> str:
        """
        Get optimal data type for current hardware.

        Returns:
            'bfloat16' for M2+, 'float16' for M1
        """
        if not self.gpu_info:
            self.validate()

        if self.gpu_info and self.gpu_info.supports_bfloat16:
            return 'bfloat16'
        else:
            return 'float16'

    def check_bfloat16_support(self) -> bool:
        """
        Check if current hardware supports bfloat16.

        Returns:
            True if bfloat16 is supported
        """
        if not self.gpu_info:
            self.validate()

        return self.gpu_info.supports_bfloat16 if self.gpu_info else False

    def print_gpu_info(self) -> None:
        """Print GPU information."""
        if not self.gpu_info:
            self.validate()

        if not self.gpu_info:
            print("❌ Unable to get GPU information")
            return

        print("🖥️ GPU Information:")
        print(f" Chip: {self.gpu_info.chip_name}")
        print(f" GPU Family: {self.gpu_info.gpu_family}")
        print(f" GPU Cores: {self.gpu_info.gpu_cores}")
        print(f" Total Memory: {self.gpu_info.total_memory / (1024**3):.1f} GB")
        print(f" Available Memory: {self.gpu_info.available_memory / (1024**3):.1f} GB")
        print(f" Metal: {'✅' if self.gpu_info.has_metal else '❌'}")
        print(f" MPS: {'✅' if self.gpu_info.has_mps else '❌'}")
        print(f" MLX: {'✅' if self.gpu_info.has_mlx else '❌'}")
        print(f" Metal Version: {self.gpu_info.metal_version or 'Unknown'}")
        print(f" Unified Memory: {'✅' if self.gpu_info.unified_memory else '❌'}")
        print(f" Apple Silicon: {'✅' if self.gpu_info.is_apple_silicon else '❌'}")

        print("\n📊 MSL v4 Capabilities:")
        print(f" bfloat16: {'✅' if self.gpu_info.supports_bfloat16 else '❌'}")
        print(f" SIMD-group matrices: {'✅' if self.gpu_info.supports_simdgroup_matrix else '❌'}")
        print(f" MPP operations: {'✅' if self.gpu_info.supports_mpp else '❌'}")
        print(f" Tile functions: {'✅' if self.gpu_info.supports_tile_functions else '❌'}")
        print(f" Atomic float: {'✅' if self.gpu_info.supports_atomic_float else '❌'}")
        print(f" Fast math: {'✅' if self.gpu_info.supports_fast_math else '❌'}")
        print(f" Function constants: {'✅' if self.gpu_info.supports_function_constants else '❌'}")
        print(f" Max threads/threadgroup: {self.gpu_info.max_threads_per_threadgroup}")
        print(f" Optimal dtype: {self.get_optimal_dtype()}")

        if self.gpu_info.is_valid:
            print("\n✅ GPU meets all requirements for Cortex")
        else:
            print("\n❌ GPU does not meet requirements:")
            for error in self.gpu_info.get_validation_errors():
                print(f" • {error}")

    def ensure_gpu_available(self) -> None:
        """Ensure GPU is available, warning (rather than exiting) when validation fails."""
        is_valid, gpu_info, errors = self.validate()

        if not is_valid:
            print("⚠️ GPU validation warnings:")
            for error in errors:
                print(f" • {error}")
            print("\nNote: Cortex is optimized for Apple Silicon with unified memory architecture.")
            print("Performance may be limited with current configuration.")
            # Don't exit for testing purposes
            # sys.exit(1)

        if gpu_info and not gpu_info.is_apple_silicon:
            print(f"⚠️ Warning: Detected {gpu_info.chip_name} - Apple Silicon recommended")
            print(" Performance may not match specifications in PRD")

    def get_gpu_memory_status(self) -> Dict[str, Any]:
        """Get current GPU memory status."""
        if not self.gpu_info:
            self.validate()

        if not self.gpu_info:
            return {
                'available': False,
                'error': 'GPU info not available'
            }

        memory_info = self._get_memory_info()

        return {
            'available': True,
            'total_gb': memory_info['total'] / (1024**3),
            'available_gb': memory_info['available'] / (1024**3),
            'used_gb': memory_info['used'] / (1024**3),
            'percent_used': memory_info['percent'],
            'can_load_model': memory_info['available'] >= 20 * 1024**3
        }

    def verify_model_compatibility(self, model_size_gb: float) -> Tuple[bool, str]:
        """
        Verify if a model can be loaded on GPU.

        Args:
            model_size_gb: Model size in gigabytes

        Returns:
            Tuple of (can_load, message)
        """
        memory_status = self.get_gpu_memory_status()

        if not memory_status['available']:
            return False, memory_status.get('error', 'GPU not available')

        available_gb = memory_status['available_gb']

        # Add overhead for KV cache, activations, and loading overhead.
        # Some models expand significantly in memory during loading:
        # - Sharded models may temporarily duplicate memory
        # - KV cache and attention buffers add overhead
        # - Qwen models observed using 3.6x disk size in memory
        # Use a conservative multiplier to ensure successful loading.
        required_gb = model_size_gb * 3.5  # 250% overhead for safety

        if required_gb > available_gb:
            return False, f"Model requires {required_gb:.1f}GB (including overhead), only {available_gb:.1f}GB available"

        return True, f"Model can be loaded ({required_gb:.1f}GB required / {available_gb:.1f}GB available)"

def main():
    """Main function for testing GPU validation."""
    validator = GPUValidator()
    validator.print_gpu_info()

    is_valid, gpu_info, errors = validator.validate()

    if is_valid:
        print("\n✅ System ready for Cortex")
    else:
        print("\n❌ System not ready for Cortex")
        for error in errors:
            print(f" • {error}")

if __name__ == "__main__":
    main()
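
For orientation, a short sketch of how a caller might drive this module's public API as defined above; the 7.0 GB model size is an arbitrary example, not a value taken from the package:

from cortex.gpu_validator import GPUValidator

validator = GPUValidator()
is_valid, gpu_info, errors = validator.validate()

if is_valid:
    # 'bfloat16' on M2-class chips and newer, 'float16' on M1 (see get_optimal_dtype).
    print("Optimal dtype:", validator.get_optimal_dtype())
    # verify_model_compatibility applies the 3.5x loading-overhead multiplier.
    can_load, message = validator.verify_model_compatibility(model_size_gb=7.0)
    print(message)
else:
    for error in errors:
        print(f" • {error}")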