cortex_llm-1.0.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cortex/__init__.py +73 -0
- cortex/__main__.py +83 -0
- cortex/config.py +329 -0
- cortex/conversation_manager.py +468 -0
- cortex/fine_tuning/__init__.py +8 -0
- cortex/fine_tuning/dataset.py +332 -0
- cortex/fine_tuning/mlx_lora_trainer.py +502 -0
- cortex/fine_tuning/trainer.py +957 -0
- cortex/fine_tuning/wizard.py +707 -0
- cortex/gpu_validator.py +467 -0
- cortex/inference_engine.py +727 -0
- cortex/metal/__init__.py +275 -0
- cortex/metal/gpu_validator.py +177 -0
- cortex/metal/memory_pool.py +886 -0
- cortex/metal/mlx_accelerator.py +678 -0
- cortex/metal/mlx_converter.py +638 -0
- cortex/metal/mps_optimizer.py +417 -0
- cortex/metal/optimizer.py +665 -0
- cortex/metal/performance_profiler.py +364 -0
- cortex/model_downloader.py +130 -0
- cortex/model_manager.py +2187 -0
- cortex/quantization/__init__.py +5 -0
- cortex/quantization/dynamic_quantizer.py +736 -0
- cortex/template_registry/__init__.py +15 -0
- cortex/template_registry/auto_detector.py +144 -0
- cortex/template_registry/config_manager.py +234 -0
- cortex/template_registry/interactive.py +260 -0
- cortex/template_registry/registry.py +347 -0
- cortex/template_registry/template_profiles/__init__.py +5 -0
- cortex/template_registry/template_profiles/base.py +142 -0
- cortex/template_registry/template_profiles/complex/__init__.py +5 -0
- cortex/template_registry/template_profiles/complex/reasoning.py +263 -0
- cortex/template_registry/template_profiles/standard/__init__.py +9 -0
- cortex/template_registry/template_profiles/standard/alpaca.py +73 -0
- cortex/template_registry/template_profiles/standard/chatml.py +82 -0
- cortex/template_registry/template_profiles/standard/gemma.py +103 -0
- cortex/template_registry/template_profiles/standard/llama.py +87 -0
- cortex/template_registry/template_profiles/standard/simple.py +65 -0
- cortex/ui/__init__.py +120 -0
- cortex/ui/cli.py +1685 -0
- cortex/ui/markdown_render.py +185 -0
- cortex/ui/terminal_app.py +534 -0
- cortex_llm-1.0.0.dist-info/METADATA +275 -0
- cortex_llm-1.0.0.dist-info/RECORD +48 -0
- cortex_llm-1.0.0.dist-info/WHEEL +5 -0
- cortex_llm-1.0.0.dist-info/entry_points.txt +2 -0
- cortex_llm-1.0.0.dist-info/licenses/LICENSE +21 -0
- cortex_llm-1.0.0.dist-info/top_level.txt +1 -0

cortex/metal/optimizer.py
@@ -0,0 +1,665 @@
"""Unified Metal optimization interface for Apple Silicon LLM inference.

This module provides a simple, effective interface for accelerating LLM inference
on Apple Silicon using the most appropriate backend (MLX or MPS).
"""

import os
import sys
from typing import Dict, Any, Optional, Union, Tuple, Callable
from dataclasses import dataclass
from enum import Enum
import logging
import warnings

import torch
import numpy as np

# Configure logging
logger = logging.getLogger(__name__)

# Add parent directory to path for imports
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from gpu_validator import GPUValidator


class Backend(Enum):
    """Available acceleration backends."""
    MLX = "mlx"
    MPS = "mps"
    CPU = "cpu"


@dataclass
class OptimizationConfig:
    """Configuration for Metal optimization."""
    backend: Backend = Backend.MLX
    dtype: str = "auto"  # auto, float32, float16, bfloat16
    batch_size: int = 1
    max_memory_gb: Optional[float] = None  # None = auto-detect
    use_quantization: bool = True
    quantization_bits: int = 8  # 4 or 8
    compile_model: bool = True
    use_kv_cache: bool = True
    kv_cache_size: int = 2048
    enable_profiling: bool = False
    fallback_to_cpu: bool = True

    def validate(self) -> bool:
        """Validate configuration settings."""
        if self.quantization_bits not in [4, 8]:
            raise ValueError(f"Quantization bits must be 4 or 8, got {self.quantization_bits}")

        if self.batch_size < 1:
            raise ValueError(f"Batch size must be >= 1, got {self.batch_size}")

        if self.kv_cache_size < 0:
            raise ValueError(f"KV cache size must be >= 0, got {self.kv_cache_size}")

        return True


class MetalOptimizer:
    """Unified Metal optimizer for LLM inference on Apple Silicon."""

    def __init__(self, config: Optional[OptimizationConfig] = None):
        """Initialize the Metal optimizer.

        Args:
            config: Optimization configuration. Uses defaults if None.
        """
        self.config = config or OptimizationConfig()
        self.config.validate()

        # Initialize GPU validator
        self.gpu_validator = GPUValidator()
        self.gpu_validator.validate()

        # Detect best backend
        self.backend = self._select_backend()
        self.device = self._get_device()

        # Initialize backend-specific components
        self._backend_optimizer = None
        self._initialize_backend()

        logger.info(f"MetalOptimizer initialized with backend: {self.backend.value}")
        logger.info(f"GPU: {self.gpu_validator.get_gpu_family()}, "
                    f"bfloat16: {self.gpu_validator.check_bfloat16_support()}")

    def _select_backend(self) -> Backend:
        """Select the best available backend.

        Returns:
            Selected backend based on availability and configuration.
        """
        if self.config.backend != Backend.MLX:
            # User explicitly selected a backend
            return self._validate_backend(self.config.backend)

        # Auto-select best backend
        try:
            import mlx.core as mx
            # MLX is available and preferred
            return Backend.MLX
        except ImportError:
            logger.warning("MLX not available, falling back to MPS")

        if torch.backends.mps.is_available():
            return Backend.MPS

        if self.config.fallback_to_cpu:
            logger.warning("No GPU acceleration available, using CPU")
            return Backend.CPU

        raise RuntimeError("No suitable backend available and CPU fallback disabled")

    def _validate_backend(self, backend: Backend) -> Backend:
        """Validate that the requested backend is available.

        Args:
            backend: Requested backend

        Returns:
            The backend if available

        Raises:
            RuntimeError: If backend not available
        """
        if backend == Backend.MLX:
            try:
                import mlx.core as mx
                return backend
            except ImportError:
                raise RuntimeError("MLX backend requested but not installed")

        elif backend == Backend.MPS:
            if not torch.backends.mps.is_available():
                raise RuntimeError("MPS backend requested but not available")
            return backend

        elif backend == Backend.CPU:
            return backend

        raise ValueError(f"Unknown backend: {backend}")

    def _get_device(self) -> Union[str, torch.device]:
        """Get the appropriate device for the selected backend.

        Returns:
            Device object or string
        """
        if self.backend == Backend.MPS:
            return torch.device("mps")
        elif self.backend == Backend.CPU:
            return torch.device("cpu")
        else:
            return "gpu"  # MLX uses string device names

    def _initialize_backend(self) -> None:
        """Initialize the backend-specific optimizer."""
        if self.backend == Backend.MLX:
            self._initialize_mlx()
        elif self.backend == Backend.MPS:
            self._initialize_mps()
        else:
            # CPU backend needs no special initialization
            pass

    def _initialize_mlx(self) -> None:
        """Initialize MLX backend."""
        try:
            from mlx_accelerator import MLXAccelerator, MLXConfig

            mlx_config = MLXConfig(
                compile_model=self.config.compile_model,
                batch_size=self.config.batch_size,
                dtype=self._get_mlx_dtype()
            )

            self._backend_optimizer = MLXAccelerator(mlx_config)

        except ImportError as e:
            logger.error(f"Failed to initialize MLX backend: {e}")
            raise RuntimeError("MLX initialization failed")

    def _initialize_mps(self) -> None:
        """Initialize MPS backend."""
        try:
            from mps_optimizer import MPSOptimizer, MPSConfig

            mps_config = MPSConfig(
                use_fp16=(self.config.dtype in ["float16", "auto"]),
                use_bfloat16=self._should_use_bfloat16(),
                max_batch_size=self.config.batch_size,
                optimize_memory=True
            )

            self._backend_optimizer = MPSOptimizer(mps_config)

        except ImportError as e:
            logger.error(f"Failed to initialize MPS backend: {e}")
            raise RuntimeError("MPS initialization failed")

    def _get_mlx_dtype(self):
        """Get MLX dtype based on configuration and hardware."""
        if self.config.dtype == "auto":
            # Auto-select based on hardware
            if self.gpu_validator.check_bfloat16_support():
                import mlx.core as mx
                return mx.bfloat16
            else:
                import mlx.core as mx
                return mx.float16

        elif self.config.dtype == "float32":
            import mlx.core as mx
            return mx.float32
        elif self.config.dtype == "float16":
            import mlx.core as mx
            return mx.float16
        elif self.config.dtype == "bfloat16":
            import mlx.core as mx
            return mx.bfloat16
        else:
            raise ValueError(f"Unknown dtype: {self.config.dtype}")

    def _should_use_bfloat16(self) -> bool:
        """Determine if bfloat16 should be used."""
        if self.config.dtype == "bfloat16":
            return True

        if self.config.dtype == "auto":
            return self.gpu_validator.check_bfloat16_support()

        return False

    def optimize_model(
        self,
        model: Any,
        model_type: str = "auto"
    ) -> Tuple[Any, Dict[str, Any]]:
        """Optimize a model for inference.

        Args:
            model: Model to optimize (PyTorch, MLX, or Hugging Face)
            model_type: Type of model ("pytorch", "mlx", "transformers", "auto")

        Returns:
            Tuple of (optimized_model, optimization_info)
        """
        if model_type == "auto":
            model_type = self._detect_model_type(model)

        logger.info(f"Optimizing {model_type} model with {self.backend.value} backend")

        optimization_info = {
            "backend": self.backend.value,
            "device": str(self.device),
            "dtype": self.config.dtype,
            "quantization": self.config.use_quantization,
            "quantization_bits": self.config.quantization_bits if self.config.use_quantization else None,
            "gpu_family": self.gpu_validator.get_gpu_family(),
            "optimizations_applied": []
        }

        if self.backend == Backend.MLX:
            optimized_model = self._optimize_mlx_model(model, model_type)
            optimization_info["optimizations_applied"].extend([
                "mlx_compilation",
                "dtype_optimization",
                "memory_layout"
            ])

        elif self.backend == Backend.MPS:
            optimized_model = self._optimize_mps_model(model, model_type)
            optimization_info["optimizations_applied"].extend([
                "mps_optimization",
                "dtype_conversion",
                "memory_optimization"
            ])

        else:
            # CPU - minimal optimization
            optimized_model = model
            optimization_info["optimizations_applied"].append("none")

        # Apply quantization if requested
        if self.config.use_quantization:
            optimized_model = self._apply_quantization(
                optimized_model,
                self.config.quantization_bits
            )
            optimization_info["optimizations_applied"].append(
                f"int{self.config.quantization_bits}_quantization"
            )

        return optimized_model, optimization_info

    def _detect_model_type(self, model: Any) -> str:
        """Detect the type of model.

        Args:
            model: Model to detect

        Returns:
            Model type string
        """
        # Check for PyTorch model
        if hasattr(model, "parameters") and hasattr(model, "forward"):
            return "pytorch"

        # Check for MLX model
        if hasattr(model, "apply_to_parameters"):
            return "mlx"

        # Check for Hugging Face transformers
        if hasattr(model, "config") and hasattr(model, "forward"):
            return "transformers"

        return "unknown"

    def _optimize_mlx_model(self, model: Any, model_type: str) -> Any:
        """Optimize model using MLX backend.

        Args:
            model: Model to optimize
            model_type: Type of model

        Returns:
            Optimized model
        """
        if not self._backend_optimizer:
            raise RuntimeError("MLX backend not initialized")

        if model_type == "pytorch":
            # Convert PyTorch model to MLX
            logger.warning("PyTorch to MLX conversion not yet implemented")
            return model

        elif model_type == "mlx":
            # Already MLX model, just optimize
            return self._backend_optimizer.optimize_model(model)

        elif model_type == "transformers":
            # Convert Hugging Face model to MLX
            logger.warning("Transformers to MLX conversion not yet implemented")
            return model

        return model

    def _optimize_mps_model(self, model: Any, model_type: str) -> Any:
        """Optimize model using MPS backend.

        Args:
            model: Model to optimize
            model_type: Type of model

        Returns:
            Optimized model
        """
        if not self._backend_optimizer:
            raise RuntimeError("MPS backend not initialized")

        if model_type in ["pytorch", "transformers"]:
            # MPS works with PyTorch models
            return self._backend_optimizer.optimize_model(model)

        elif model_type == "mlx":
            # Cannot use MLX model with MPS
            logger.error("Cannot use MLX model with MPS backend")
            raise ValueError("MLX models not compatible with MPS backend")

        return model

    def _apply_quantization(self, model: Any, bits: int) -> Any:
        """Apply quantization to the model.

        Args:
            model: Model to quantize
            bits: Quantization bits (4 or 8)

        Returns:
            Quantized model
        """
        if self.backend == Backend.MLX:
            # Use MLX quantization
            if hasattr(self._backend_optimizer, "quantize_model"):
                return self._backend_optimizer.quantize_model(model, bits)

        elif self.backend == Backend.MPS:
            # Use custom quantization for PyTorch
            try:
                from quantization.dynamic_quantizer import DynamicQuantizer, QuantizationConfig

                quantizer = DynamicQuantizer(
                    QuantizationConfig(
                        mode="int8" if bits == 8 else "int4",
                        device=self.device
                    )
                )

                quantized_model, _ = quantizer.quantize_model(model)
                return quantized_model

            except ImportError:
                logger.warning("Quantization module not available")
                return model

        return model

    def create_inference_session(
        self,
        model: Any,
        tokenizer: Optional[Any] = None
    ) -> 'InferenceSession':
        """Create an optimized inference session.

        Args:
            model: Model for inference
            tokenizer: Optional tokenizer

        Returns:
            InferenceSession object
        """
        optimized_model, info = self.optimize_model(model)

        return InferenceSession(
            model=optimized_model,
            tokenizer=tokenizer,
            optimizer=self,
            optimization_info=info
        )

    def get_memory_usage(self) -> Dict[str, float]:
        """Get current memory usage statistics.

        Returns:
            Dictionary with memory statistics in GB
        """
        import psutil

        vm = psutil.virtual_memory()
        stats = {
            "total_gb": vm.total / (1024**3),
            "available_gb": vm.available / (1024**3),
            "used_gb": vm.used / (1024**3),
            "percent_used": vm.percent
        }

        if self.backend == Backend.MPS:
            if hasattr(torch.mps, "current_allocated_memory"):
                stats["mps_allocated_gb"] = torch.mps.current_allocated_memory() / (1024**3)

        return stats

    def profile_inference(
        self,
        model: Any,
        input_shape: Tuple[int, ...],
        num_iterations: int = 100
    ) -> Dict[str, Any]:
        """Profile model inference performance.

        Args:
            model: Model to profile
            input_shape: Shape of input tensor
            num_iterations: Number of iterations for profiling

        Returns:
            Profiling results
        """
        if self.backend == Backend.MLX and self._backend_optimizer:
            return self._backend_optimizer.profile_model(
                model, input_shape, num_iterations
            )

        elif self.backend == Backend.MPS and self._backend_optimizer:
            return self._backend_optimizer.profile_model(
                model, input_shape, num_iterations
            )

        # Basic CPU profiling
        import time

        if self.backend == Backend.CPU:
            dummy_input = torch.randn(input_shape)
        else:
            dummy_input = torch.randn(input_shape).to(self.device)

        # Warmup
        for _ in range(10):
            with torch.no_grad():
                _ = model(dummy_input)

        # Profile
        start_time = time.perf_counter()
        for _ in range(num_iterations):
            with torch.no_grad():
                _ = model(dummy_input)

        end_time = time.perf_counter()
        avg_time = (end_time - start_time) / num_iterations

        return {
            "backend": self.backend.value,
            "avg_inference_time": avg_time,
            "throughput": input_shape[0] / avg_time,
            "device": str(self.device),
            "iterations": num_iterations
        }

    def cleanup(self) -> None:
        """Clean up resources."""
        if self.backend == Backend.MPS:
            torch.mps.empty_cache()

        if hasattr(self._backend_optimizer, "cleanup"):
            self._backend_optimizer.cleanup()


class InferenceSession:
    """Optimized inference session for LLM models."""

    def __init__(
        self,
        model: Any,
        tokenizer: Optional[Any],
        optimizer: MetalOptimizer,
        optimization_info: Dict[str, Any]
    ):
        """Initialize inference session.

        Args:
            model: Optimized model
            tokenizer: Optional tokenizer
            optimizer: MetalOptimizer instance
            optimization_info: Optimization information
        """
        self.model = model
        self.tokenizer = tokenizer
        self.optimizer = optimizer
        self.optimization_info = optimization_info
        self.generation_config = {}

    def generate(
        self,
        prompt: Union[str, torch.Tensor],
        max_tokens: int = 100,
        temperature: float = 0.7,
        top_p: float = 0.9,
        **kwargs
    ) -> Union[str, torch.Tensor]:
        """Generate text from prompt.

        Args:
            prompt: Input prompt (string or tensor)
            max_tokens: Maximum tokens to generate
            temperature: Sampling temperature
            top_p: Nucleus sampling parameter
            **kwargs: Additional generation parameters

        Returns:
            Generated text or tokens
        """
        # This is a simplified interface - actual implementation
        # would depend on the model type and backend

        if isinstance(prompt, str) and self.tokenizer:
            # Encode prompt
            input_ids = self.tokenizer.encode(prompt, return_tensors="pt")

            if self.optimizer.backend == Backend.MPS:
                input_ids = input_ids.to(self.optimizer.device)
        else:
            input_ids = prompt

        # Generate based on backend
        if self.optimizer.backend == Backend.MLX:
            # MLX generation (simplified)
            output = self._generate_mlx(input_ids, max_tokens, temperature, top_p)
        else:
            # PyTorch generation
            output = self._generate_pytorch(input_ids, max_tokens, temperature, top_p, **kwargs)

        # Decode if tokenizer available
        if self.tokenizer and not isinstance(output, str):
            return self.tokenizer.decode(output[0], skip_special_tokens=True)

        return output

    def _generate_mlx(
        self,
        input_ids: Any,
        max_tokens: int,
        temperature: float,
        top_p: float
    ) -> Any:
        """Generate using MLX backend."""
        # Simplified MLX generation
        # Actual implementation would use MLX-specific generation
        return input_ids

    def _generate_pytorch(
        self,
        input_ids: torch.Tensor,
        max_tokens: int,
        temperature: float,
        top_p: float,
        **kwargs
    ) -> torch.Tensor:
        """Generate using PyTorch backend."""
        # Use model's generate method if available
        if hasattr(self.model, "generate"):
            return self.model.generate(
                input_ids,
                max_new_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                do_sample=temperature > 0,
                **kwargs
            )

        # Simple generation loop (fallback)
        generated = input_ids
        for _ in range(max_tokens):
            with torch.no_grad():
                outputs = self.model(generated)
                logits = outputs.logits if hasattr(outputs, "logits") else outputs
                next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
                generated = torch.cat([generated, next_token], dim=1)

        return generated

    def stream_generate(
        self,
        prompt: str,
        max_tokens: int = 100,
        callback: Optional[Callable] = None,
        **kwargs
    ):
        """Stream generation token by token.

        Args:
            prompt: Input prompt
            max_tokens: Maximum tokens to generate
            callback: Optional callback for each token
            **kwargs: Additional parameters

        Yields:
            Generated tokens
        """
        # This would implement streaming generation
        # For now, just yield the final result
        result = self.generate(prompt, max_tokens, **kwargs)
        yield result

    def get_info(self) -> Dict[str, Any]:
        """Get session information.

        Returns:
            Session information dictionary
        """
        info = self.optimization_info.copy()
        info["memory_usage"] = self.optimizer.get_memory_usage()
        return info