cortex_llm-1.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. cortex/__init__.py +73 -0
  2. cortex/__main__.py +83 -0
  3. cortex/config.py +329 -0
  4. cortex/conversation_manager.py +468 -0
  5. cortex/fine_tuning/__init__.py +8 -0
  6. cortex/fine_tuning/dataset.py +332 -0
  7. cortex/fine_tuning/mlx_lora_trainer.py +502 -0
  8. cortex/fine_tuning/trainer.py +957 -0
  9. cortex/fine_tuning/wizard.py +707 -0
  10. cortex/gpu_validator.py +467 -0
  11. cortex/inference_engine.py +727 -0
  12. cortex/metal/__init__.py +275 -0
  13. cortex/metal/gpu_validator.py +177 -0
  14. cortex/metal/memory_pool.py +886 -0
  15. cortex/metal/mlx_accelerator.py +678 -0
  16. cortex/metal/mlx_converter.py +638 -0
  17. cortex/metal/mps_optimizer.py +417 -0
  18. cortex/metal/optimizer.py +665 -0
  19. cortex/metal/performance_profiler.py +364 -0
  20. cortex/model_downloader.py +130 -0
  21. cortex/model_manager.py +2187 -0
  22. cortex/quantization/__init__.py +5 -0
  23. cortex/quantization/dynamic_quantizer.py +736 -0
  24. cortex/template_registry/__init__.py +15 -0
  25. cortex/template_registry/auto_detector.py +144 -0
  26. cortex/template_registry/config_manager.py +234 -0
  27. cortex/template_registry/interactive.py +260 -0
  28. cortex/template_registry/registry.py +347 -0
  29. cortex/template_registry/template_profiles/__init__.py +5 -0
  30. cortex/template_registry/template_profiles/base.py +142 -0
  31. cortex/template_registry/template_profiles/complex/__init__.py +5 -0
  32. cortex/template_registry/template_profiles/complex/reasoning.py +263 -0
  33. cortex/template_registry/template_profiles/standard/__init__.py +9 -0
  34. cortex/template_registry/template_profiles/standard/alpaca.py +73 -0
  35. cortex/template_registry/template_profiles/standard/chatml.py +82 -0
  36. cortex/template_registry/template_profiles/standard/gemma.py +103 -0
  37. cortex/template_registry/template_profiles/standard/llama.py +87 -0
  38. cortex/template_registry/template_profiles/standard/simple.py +65 -0
  39. cortex/ui/__init__.py +120 -0
  40. cortex/ui/cli.py +1685 -0
  41. cortex/ui/markdown_render.py +185 -0
  42. cortex/ui/terminal_app.py +534 -0
  43. cortex_llm-1.0.0.dist-info/METADATA +275 -0
  44. cortex_llm-1.0.0.dist-info/RECORD +48 -0
  45. cortex_llm-1.0.0.dist-info/WHEEL +5 -0
  46. cortex_llm-1.0.0.dist-info/entry_points.txt +2 -0
  47. cortex_llm-1.0.0.dist-info/licenses/LICENSE +21 -0
  48. cortex_llm-1.0.0.dist-info/top_level.txt +1 -0
cortex/metal/mps_optimizer.py
@@ -0,0 +1,417 @@
+ """MPS backend optimization for PyTorch models on Metal."""
+
+ import torch
+ import torch.nn as nn
+ from typing import Dict, Any, Optional, List, Tuple
+ from dataclasses import dataclass
+ import warnings
+ import sys
+ import os
+
+ # Add the package root to sys.path so gpu_validator can be imported
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+ from gpu_validator import GPUValidator
+
+ @dataclass
+ class MPSConfig:
+     """Configuration for MPS optimization."""
+     use_channels_last: bool = True
+     use_fp16: Optional[bool] = None  # None = auto-detect based on hardware
+     use_bfloat16: Optional[bool] = None  # None = auto-detect based on hardware
+     use_jit: bool = False  # JIT not fully supported on MPS yet
+     use_graph_mode: bool = True
+     fuse_operations: bool = True
+     optimize_memory: bool = True
+     max_batch_size: int = 8
+     prefetch_factor: int = 2
+     num_workers: int = 0  # MPS works best with the main thread
+     auto_detect_dtype: bool = True  # Automatically select the best dtype for the hardware
+
+ class MPSOptimizer:
+     """Optimize PyTorch models for the Metal Performance Shaders backend."""
+
+     FUSED_OPERATIONS = {
+         "conv_bn_relu": ["Conv2d", "BatchNorm2d", "ReLU"],
+         "linear_relu": ["Linear", "ReLU"],
+         "linear_gelu": ["Linear", "GELU"],
+         "layer_norm_linear": ["LayerNorm", "Linear"],
+     }
+
+     def __init__(self, config: Optional[MPSConfig] = None):
+         """Initialize MPS optimizer."""
+         self.config = config or MPSConfig()
+
+         if not torch.backends.mps.is_available():
+             raise RuntimeError("MPS backend not available")
+
+         if not torch.backends.mps.is_built():
+             raise RuntimeError("PyTorch not built with MPS support")
+
+         self.device = torch.device("mps")
+
+         # Initialize GPU validator for hardware detection
+         self.gpu_validator = GPUValidator()
+         self.gpu_validator.validate()
+
+         # Auto-detect optimal dtype if enabled
+         if (
+             self.config.auto_detect_dtype
+             and self.config.use_bfloat16 is None
+             and self.config.use_fp16 is None
+         ):
+             self._auto_detect_dtype()
+         elif self.config.use_bfloat16 is None and self.config.use_fp16 is None:
+             # Default to fp16 if auto-detect is disabled but no dtype is specified
+             self.config.use_bfloat16 = False
+             self.config.use_fp16 = True
+
+     def _auto_detect_dtype(self) -> None:
+         """Automatically detect optimal dtype based on hardware."""
+         if self.gpu_validator.check_bfloat16_support():
+             # Check if PyTorch supports bfloat16 on MPS via a tiny smoke test
+             try:
+                 torch.tensor([1.0], dtype=torch.bfloat16, device=self.device)
+                 self.config.use_bfloat16 = True
+                 self.config.use_fp16 = False  # Prefer bfloat16 over fp16
+             except (RuntimeError, TypeError):
+                 # PyTorch doesn't support bfloat16 on MPS yet, fall back to fp16
+                 self.config.use_bfloat16 = False
+                 self.config.use_fp16 = True
+         else:
+             # Hardware doesn't support bfloat16, use fp16
+             self.config.use_bfloat16 = False
+             self.config.use_fp16 = True
+
+     def optimize_model(
+         self,
+         model: nn.Module,
+         example_input: Optional[torch.Tensor] = None
+     ) -> nn.Module:
+         """
+         Optimize a PyTorch model for the MPS backend.
+
+         Args:
+             model: PyTorch model to optimize
+             example_input: Example input for shape inference
+
+         Returns:
+             Optimized model
+         """
+         model = model.to(self.device)
+
+         if self.config.use_channels_last:
+             model = self._convert_to_channels_last(model)
+
+         # Convert to optimal dtype (bfloat16 or fp16)
+         if self.config.use_bfloat16 or self.config.use_fp16:
+             model = self._convert_dtype(model)
+
+         if self.config.fuse_operations:
+             model = self._fuse_operations(model)
+
+         if self.config.optimize_memory:
+             model = self._optimize_memory_layout(model)
+
+         if self.config.use_graph_mode and example_input is not None:
+             model = self._enable_graph_mode(model, example_input)
+
+         model.eval()
+
+         return model
+
+     def _convert_to_channels_last(self, model: nn.Module) -> nn.Module:
+         """Convert model to channels_last memory format for better performance."""
+         def convert_layer(module):
+             if isinstance(module, (nn.Conv2d, nn.BatchNorm2d)):
+                 # nn.Module.to() converts in place and returns the module itself
+                 module.to(memory_format=torch.channels_last)
+             return module
+
+         model.apply(convert_layer)
+         return model
+
+     def _convert_dtype(self, model: nn.Module) -> nn.Module:
+         """
+         Convert model to optimal dtype (bfloat16 or fp16) for faster computation.
+
+         Args:
+             model: Model to convert
+
+         Returns:
+             Converted model
+         """
+         # Determine target dtype
+         target_dtype = torch.bfloat16 if self.config.use_bfloat16 else torch.float16
+
+         def should_convert(module):
+             # These layers should stay in float32 for numerical stability
+             exclude_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.LayerNorm)
+             return not isinstance(module, exclude_types)
+
+         for name, module in model.named_modules():
+             if should_convert(module):
+                 if hasattr(module, 'weight') and module.weight is not None:
+                     if module.weight.dtype == torch.float32:
+                         # nn.Module.to() converts parameters and buffers in place
+                         module.to(dtype=target_dtype)
+
+         return model
+
+     def get_optimal_dtype(self) -> torch.dtype:
+         """
+         Get the optimal dtype for the current hardware.
+
+         Returns:
+             torch.bfloat16, torch.float16, or torch.float32
+         """
+         if self.config.use_bfloat16:
+             return torch.bfloat16
+         elif self.config.use_fp16:
+             return torch.float16
+         else:
+             return torch.float32
+
+     def _fuse_operations(self, model: nn.Module) -> nn.Module:
+         """Fuse compatible operations for better performance."""
+         # Note: Conv-BN fusion requires specific module pairs, not the whole model,
+         # so only Sequential children are scanned for known fusion patterns.
+         for name, module in model.named_children():
+             if isinstance(module, nn.Sequential):
+                 fused = self._try_fuse_sequential(module)
+                 if fused is not None:
+                     setattr(model, name, fused)
+
+         return model
+
+     def _try_fuse_sequential(self, sequential: nn.Sequential) -> Optional[nn.Module]:
+         """Try to fuse operations in a sequential module."""
+         layers = list(sequential.children())
+
+         if len(layers) < 2:
+             return None
+
+         for pattern_name, pattern in self.FUSED_OPERATIONS.items():
+             if self._matches_pattern(layers, pattern):
+                 return self._create_fused_module(layers, pattern_name)
+
+         return None
+
+     def _matches_pattern(self, layers: List[nn.Module], pattern: List[str]) -> bool:
+         """Check if layers match a fusion pattern."""
+         # Require an exact length match: a prefix match would silently drop
+         # the trailing layers when the fused module is created.
+         if len(layers) != len(pattern):
+             return False
+
+         for i, expected_type in enumerate(pattern):
+             if not hasattr(nn, expected_type):
+                 return False
+             expected_class = getattr(nn, expected_type)
+             if not isinstance(layers[i], expected_class):
+                 return False
+
+         return True
+
+     def _create_fused_module(self, layers: List[nn.Module], pattern_name: str) -> nn.Module:
+         """Create a fused module from layers."""
+         if pattern_name == "conv_bn_relu":
+             return ConvBNReLU(layers[0], layers[1])
+         elif pattern_name == "linear_relu":
+             return LinearReLU(layers[0])
+         elif pattern_name == "linear_gelu":
+             return LinearGELU(layers[0])
+         else:
+             return nn.Sequential(*layers)
+
+     def _optimize_memory_layout(self, model: nn.Module) -> nn.Module:
+         """Optimize memory layout for MPS."""
+         def optimize_layer(module):
+             # Make non-contiguous parameters contiguous in place; assigning a
+             # plain tensor to module.weight would fail since it is an nn.Parameter.
+             if hasattr(module, 'weight'):
+                 if module.weight is not None and not module.weight.is_contiguous():
+                     module.weight.data = module.weight.data.contiguous()
+
+             if hasattr(module, 'bias'):
+                 if module.bias is not None and not module.bias.is_contiguous():
+                     module.bias.data = module.bias.data.contiguous()
+
+             return module
+
+         model.apply(optimize_layer)
+         return model
+
+     def _enable_graph_mode(
+         self,
+         model: nn.Module,
+         example_input: torch.Tensor
+     ) -> nn.Module:
+         """Enable graph mode optimization (experimental)."""
+         try:
+             with warnings.catch_warnings():
+                 warnings.simplefilter("ignore")
+
+                 example_input = example_input.to(self.device)
+
+                 # Convert input to optimal dtype
+                 if self.config.use_bfloat16:
+                     example_input = example_input.to(dtype=torch.bfloat16)
+                 elif self.config.use_fp16:
+                     example_input = example_input.half()
+
+                 # Warm-up pass so MPS can capture and cache the compute graph
+                 with torch.no_grad():
+                     _ = model(example_input)
+
+         except Exception as e:
+             warnings.warn(f"Graph mode optimization failed: {e}")
+
+         return model
+
+     def optimize_dataloader(
+         self,
+         dataloader: torch.utils.data.DataLoader
+     ) -> torch.utils.data.DataLoader:
+         """Optimize DataLoader for the MPS backend."""
+         # DataLoader does not expose a `shuffle` attribute; infer it from the sampler
+         shuffle = isinstance(dataloader.sampler, torch.utils.data.RandomSampler)
+
+         optimized_dataloader = torch.utils.data.DataLoader(
+             dataloader.dataset,
+             batch_size=min(dataloader.batch_size, self.config.max_batch_size),
+             shuffle=shuffle,
+             num_workers=self.config.num_workers,
+             pin_memory=False,  # Not needed for unified memory
+             prefetch_factor=self.config.prefetch_factor if self.config.num_workers > 0 else None
+         )
+
+         return optimized_dataloader
+
+     def profile_model(
+         self,
+         model: nn.Module,
+         input_shape: Tuple[int, ...],
+         num_iterations: int = 100
+     ) -> Dict[str, Any]:
+         """Profile model performance on MPS."""
+         import time
+
+         model.eval()
+
+         dummy_input = torch.randn(input_shape).to(self.device)
+
+         # Convert to optimal dtype
+         if self.config.use_bfloat16:
+             dummy_input = dummy_input.to(dtype=torch.bfloat16)
+         elif self.config.use_fp16:
+             dummy_input = dummy_input.half()
+
+         torch.mps.synchronize()
+
+         # Warm up before timing
+         warmup_iterations = 10
+         for _ in range(warmup_iterations):
+             with torch.no_grad():
+                 _ = model(dummy_input)
+
+         torch.mps.synchronize()
+
+         start_time = time.perf_counter()
+         for _ in range(num_iterations):
+             with torch.no_grad():
+                 _ = model(dummy_input)
+
+         torch.mps.synchronize()
+         end_time = time.perf_counter()
+
+         avg_time = (end_time - start_time) / num_iterations
+         throughput = input_shape[0] / avg_time if avg_time > 0 else 0
+
+         memory_allocated = torch.mps.current_allocated_memory() if hasattr(torch.mps, 'current_allocated_memory') else 0
+
+         return {
+             "avg_inference_time": avg_time,
+             "throughput": throughput,
+             "memory_allocated": memory_allocated,
+             "device": "mps",
+             "dtype": str(self.get_optimal_dtype()),
+             "fp16": self.config.use_fp16,
+             "bfloat16": self.config.use_bfloat16,
+             "batch_size": input_shape[0]
+         }
+
+     @staticmethod
+     def get_mps_info() -> Dict[str, Any]:
+         """Get MPS backend information."""
+         info = {
+             "available": torch.backends.mps.is_available(),
+             "built": torch.backends.mps.is_built()
+         }
+
+         if info["available"]:
+             info["current_allocated_memory"] = torch.mps.current_allocated_memory() if hasattr(torch.mps, 'current_allocated_memory') else 0
+             info["driver_allocated_memory"] = torch.mps.driver_allocated_memory() if hasattr(torch.mps, 'driver_allocated_memory') else 0
+
+             # Check bfloat16 support
+             try:
+                 torch.tensor([1.0], dtype=torch.bfloat16, device="mps")
+                 info["bfloat16_supported"] = True
+             except (RuntimeError, TypeError):
+                 info["bfloat16_supported"] = False
+
+         return info
+
+     def get_optimization_summary(self) -> Dict[str, Any]:
+         """
+         Get a summary of the optimizations applied.
+
+         Returns:
+             Dictionary with optimization details
+         """
+         optimal_dtype = self.get_optimal_dtype()
+
+         return {
+             "device": str(self.device),
+             "optimal_dtype": str(optimal_dtype),
+             "dtype_bits": 16 if optimal_dtype in [torch.float16, torch.bfloat16] else 32,
+             "memory_reduction": "50%" if optimal_dtype in [torch.float16, torch.bfloat16] else "0%",
+             "hardware_features": {
+                 "bfloat16": self.config.use_bfloat16,
+                 "channels_last": self.config.use_channels_last,
+                 "graph_mode": self.config.use_graph_mode,
+                 "fused_operations": self.config.fuse_operations,
+                 "memory_optimized": self.config.optimize_memory
+             },
+             "gpu_family": self.gpu_validator.gpu_info.gpu_family if self.gpu_validator.gpu_info else "unknown",
+             "expected_speedup": "10-15%" if self.config.use_bfloat16 else "baseline"
+         }
+
+ class ConvBNReLU(nn.Module):
+     """Fused Conv-BatchNorm-ReLU module."""
+
+     def __init__(self, conv: nn.Conv2d, bn: nn.BatchNorm2d):
+         super().__init__()
+         self.conv = conv
+         self.bn = bn
+         self.relu = nn.ReLU(inplace=True)
+
+     def forward(self, x):
+         return self.relu(self.bn(self.conv(x)))
+
+ class LinearReLU(nn.Module):
+     """Fused Linear-ReLU module."""
+
+     def __init__(self, linear: nn.Linear):
+         super().__init__()
+         self.linear = linear
+         self.relu = nn.ReLU(inplace=True)
+
+     def forward(self, x):
+         return self.relu(self.linear(x))
+
+ class LinearGELU(nn.Module):
+     """Fused Linear-GELU module."""
+
+     def __init__(self, linear: nn.Linear):
+         super().__init__()
+         self.linear = linear
+         self.gelu = nn.GELU()
+
+     def forward(self, x):
+         return self.gelu(self.linear(x))
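
For reference, a minimal usage sketch of the module above (illustrative only, not part of the package diff; it assumes macOS on Apple silicon with an MPS-enabled PyTorch build, since MPSOptimizer raises RuntimeError when the backend is unavailable):

# Usage sketch, not shipped in the wheel; requires an MPS-capable machine.
import torch
import torch.nn as nn

from cortex.metal.mps_optimizer import MPSConfig, MPSOptimizer

# A toy model; optimize_model moves it to MPS and converts eligible
# parameters to the auto-detected 16-bit dtype.
model = nn.Sequential(
    nn.Linear(128, 256),
    nn.ReLU(),
)

config = MPSConfig()             # dtype is auto-detected by default
optimizer = MPSOptimizer(config)

example = torch.randn(8, 128)    # example input enables the graph-mode warm-up
optimized = optimizer.optimize_model(model, example_input=example)

stats = optimizer.profile_model(optimized, input_shape=(8, 128), num_iterations=20)
print(stats["avg_inference_time"], stats["throughput"], stats["dtype"])
print(optimizer.get_optimization_summary())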