cortex-llm 1.0.0 (cortex_llm-1.0.0-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. cortex/__init__.py +73 -0
  2. cortex/__main__.py +83 -0
  3. cortex/config.py +329 -0
  4. cortex/conversation_manager.py +468 -0
  5. cortex/fine_tuning/__init__.py +8 -0
  6. cortex/fine_tuning/dataset.py +332 -0
  7. cortex/fine_tuning/mlx_lora_trainer.py +502 -0
  8. cortex/fine_tuning/trainer.py +957 -0
  9. cortex/fine_tuning/wizard.py +707 -0
  10. cortex/gpu_validator.py +467 -0
  11. cortex/inference_engine.py +727 -0
  12. cortex/metal/__init__.py +275 -0
  13. cortex/metal/gpu_validator.py +177 -0
  14. cortex/metal/memory_pool.py +886 -0
  15. cortex/metal/mlx_accelerator.py +678 -0
  16. cortex/metal/mlx_converter.py +638 -0
  17. cortex/metal/mps_optimizer.py +417 -0
  18. cortex/metal/optimizer.py +665 -0
  19. cortex/metal/performance_profiler.py +364 -0
  20. cortex/model_downloader.py +130 -0
  21. cortex/model_manager.py +2187 -0
  22. cortex/quantization/__init__.py +5 -0
  23. cortex/quantization/dynamic_quantizer.py +736 -0
  24. cortex/template_registry/__init__.py +15 -0
  25. cortex/template_registry/auto_detector.py +144 -0
  26. cortex/template_registry/config_manager.py +234 -0
  27. cortex/template_registry/interactive.py +260 -0
  28. cortex/template_registry/registry.py +347 -0
  29. cortex/template_registry/template_profiles/__init__.py +5 -0
  30. cortex/template_registry/template_profiles/base.py +142 -0
  31. cortex/template_registry/template_profiles/complex/__init__.py +5 -0
  32. cortex/template_registry/template_profiles/complex/reasoning.py +263 -0
  33. cortex/template_registry/template_profiles/standard/__init__.py +9 -0
  34. cortex/template_registry/template_profiles/standard/alpaca.py +73 -0
  35. cortex/template_registry/template_profiles/standard/chatml.py +82 -0
  36. cortex/template_registry/template_profiles/standard/gemma.py +103 -0
  37. cortex/template_registry/template_profiles/standard/llama.py +87 -0
  38. cortex/template_registry/template_profiles/standard/simple.py +65 -0
  39. cortex/ui/__init__.py +120 -0
  40. cortex/ui/cli.py +1685 -0
  41. cortex/ui/markdown_render.py +185 -0
  42. cortex/ui/terminal_app.py +534 -0
  43. cortex_llm-1.0.0.dist-info/METADATA +275 -0
  44. cortex_llm-1.0.0.dist-info/RECORD +48 -0
  45. cortex_llm-1.0.0.dist-info/WHEEL +5 -0
  46. cortex_llm-1.0.0.dist-info/entry_points.txt +2 -0
  47. cortex_llm-1.0.0.dist-info/licenses/LICENSE +21 -0
  48. cortex_llm-1.0.0.dist-info/top_level.txt +1 -0
cortex/metal/memory_pool.py
@@ -0,0 +1,886 @@
"""GPU memory pool management for pre-allocation and zero-copy operations."""

import logging
import weakref
from typing import Dict, Any, Optional, List, Tuple
from dataclasses import dataclass
from enum import Enum
import threading
from datetime import datetime
import numpy as np
import psutil
import sys
import os
import atexit

# Configure logging
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

import torch
import mlx.core as mx

# Add parent directory to path for imports
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from gpu_validator import GPUValidator

class AllocationStrategy(Enum):
    """Memory allocation strategies with MLX zero-copy support."""
    BEST_FIT = "best_fit"
    FIRST_FIT = "first_fit"
    UNIFIED = "unified"      # Unified memory for CPU/GPU sharing
    DEDICATED = "dedicated"
    ZERO_COPY = "zero_copy"  # MLX zero-copy for maximum efficiency

@dataclass
class MemoryBlock:
    """Represents a memory block in the pool."""
    block_id: str
    size: int
    offset: int
    allocated: bool
    allocation_time: Optional[datetime]
    last_access: Optional[datetime]
    device_type: str           # "mps" or "mlx"
    tensor_ref: Any            # Weak reference to tensor
    metadata: Dict[str, Any]
    is_constant: bool = False    # Whether this block is for constant memory
    is_read_only: bool = False   # Whether this block is read-only (weights)

    def is_free(self) -> bool:
        """Check if block is free for allocation."""
        if not self.allocated:
            return True
        if self.tensor_ref is not None:
            ref = self.tensor_ref()
            if ref is None:
                self.allocated = False
                return True
        return False

    def mark_allocated(self, tensor: Any) -> None:
        """Mark block as allocated."""
        self.allocated = True
        self.allocation_time = datetime.now()
        self.last_access = datetime.now()
        self.tensor_ref = weakref.ref(tensor) if tensor is not None else None

    def mark_free(self) -> None:
        """Mark block as free."""
        self.allocated = False
        self.tensor_ref = None

class MemoryPool:
    """Pre-allocated GPU memory pool for zero-copy operations."""

    DEFAULT_POOL_SIZE = 20 * 1024 * 1024 * 1024  # 20GB default target
    CONSTANT_POOL_SIZE = 64 * 1024 * 1024        # 64MB for constant memory (Metal limit)
    BLOCK_SIZES = [
        1 * 1024 * 1024,     # 1MB
        16 * 1024 * 1024,    # 16MB
        64 * 1024 * 1024,    # 64MB
        256 * 1024 * 1024,   # 256MB
        1024 * 1024 * 1024,  # 1GB
        4096 * 1024 * 1024,  # 4GB
    ]

    @classmethod
    def get_optimal_pool_size(cls, target_size: Optional[int] = None) -> int:
        """Get optimal pool size based on available memory."""
        if target_size is not None:
            return target_size

        # Get available memory
        vm = psutil.virtual_memory()
        available = vm.available
        total = vm.total

        # More aggressive memory allocation for better performance
        # Use 60% of available memory, but never more than 75% of total memory
        optimal_size = min(
            int(available * 0.60),  # 60% of available
            int(total * 0.75)       # Never more than 75% of total
        )

        # Further limit based on actual available memory
        # Can allocate up to 90% of what's available for better utilization
        max_safe_size = int(available * 0.90)
        optimal_size = min(optimal_size, max_safe_size)

        # Cap at DEFAULT_POOL_SIZE if we have plenty of memory
        if optimal_size > cls.DEFAULT_POOL_SIZE:
            optimal_size = cls.DEFAULT_POOL_SIZE

        # If available memory is very low (< 4GB), be extra conservative
        if available < 4 * 1024 * 1024 * 1024:
            optimal_size = min(optimal_size, int(available * 0.25))

        # Minimum 256MB for basic functionality (reduced from 512MB)
        min_size = 256 * 1024 * 1024
        return max(optimal_size, min_size)

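As a quick illustration of the sizing heuristic above, the following sketch (not part of the package) replays the same arithmetic for a hypothetical machine with 64 GB of RAM, 32 GB of which is currently available:

# Illustrative only: how get_optimal_pool_size() would size the pool on a
# hypothetical machine with 64 GB total RAM and 32 GB currently available.
GB = 1024**3
available, total = 32 * GB, 64 * GB                        # stand-ins for psutil readings

optimal = min(int(available * 0.60), int(total * 0.75))    # 19.2 GB vs 48 GB -> 19.2 GB
optimal = min(optimal, int(available * 0.90))              # 90% safety limit: still 19.2 GB
optimal = min(optimal, 20 * GB)                            # DEFAULT_POOL_SIZE cap: still 19.2 GB
optimal = max(optimal, 256 * 1024 * 1024)                  # 256 MB floor
print(f"{optimal / GB:.1f} GB")                            # -> 19.2 GB

So with half its RAM free, this hypothetical machine gets a pool just under the 20 GB default cap.
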
    def __init__(
        self,
        pool_size: Optional[int] = None,
        strategy: AllocationStrategy = AllocationStrategy.UNIFIED,
        device: str = "mps",
        auto_size: bool = True,
        silent: bool = False,
        use_bfloat16: Optional[bool] = None,
        enable_zero_copy: bool = True
    ):
        """
        Initialize memory pool.

        Args:
            pool_size: Total pool size in bytes (None for auto-detection)
            strategy: Allocation strategy
            device: Device type ("mps" or "mlx")
            auto_size: Automatically determine pool size based on available memory
            silent: Suppress initialization messages
            use_bfloat16: Use bfloat16 if supported (None for auto-detect)
            enable_zero_copy: Enable zero-copy allocation (effective only on the "mlx" device)
        """
        if auto_size and pool_size is None:
            self.pool_size = self.get_optimal_pool_size()
            logger.info(f"Auto-sizing memory pool to {self.pool_size / (1024**3):.1f}GB")
            if not silent:
                print(f"Auto-sizing memory pool to {self.pool_size / (1024**3):.1f}GB")
        else:
            self.pool_size = pool_size or self.get_optimal_pool_size()
            logger.info(f"Memory pool size set to {self.pool_size / (1024**3):.1f}GB")

        self.strategy = strategy
        self.device = device
        self.blocks: List[MemoryBlock] = []
        self.constant_blocks: List[MemoryBlock] = []  # Separate pool for constant memory
        self.allocated_memory = 0
        self.allocated_constant_memory = 0
        self.peak_memory = 0
        self._lock = threading.Lock()
        self.silent = silent

        # Initialize GPU validator for hardware detection
        self.gpu_validator = GPUValidator()
        self.gpu_validator.validate()

        # Determine optimal dtype
        self.use_bfloat16 = use_bfloat16
        if self.use_bfloat16 is None:
            self.use_bfloat16 = self._should_use_bfloat16()

        self.optimal_dtype = self._get_optimal_dtype()
        self.enable_zero_copy = enable_zero_copy and device == "mlx"

        # Regular and constant memory buffers
        self._mps_buffer: Optional[torch.Tensor] = None
        self._mps_constant_buffer: Optional[torch.Tensor] = None
        self._mlx_buffer: Optional[mx.array] = None
        self._mlx_constant_buffer: Optional[mx.array] = None

        # Zero-copy memory tracking
        self.zero_copy_arrays: weakref.WeakValueDictionary = weakref.WeakValueDictionary()
        self.unified_memory_regions: Dict[str, Any] = {}

        # Enable zero-copy strategy for MLX if requested
        if self.enable_zero_copy and strategy == AllocationStrategy.UNIFIED:
            self.strategy = AllocationStrategy.ZERO_COPY
            logger.info("Zero-copy memory strategy enabled for MLX")

        logger.info(f"Initializing memory pool with strategy: {self.strategy.value}")
        self._initialize_pool()

        # Register cleanup on exit
        atexit.register(self.cleanup)

    def cleanup(self) -> None:
        """Clean up allocated resources to prevent leaks."""
        try:
            # Release MPS buffers
            if self._mps_buffer is not None:
                del self._mps_buffer
                self._mps_buffer = None

            if self._mps_constant_buffer is not None:
                del self._mps_constant_buffer
                self._mps_constant_buffer = None

            # Release MLX buffers
            if self._mlx_buffer is not None:
                del self._mlx_buffer
                self._mlx_buffer = None

            if self._mlx_constant_buffer is not None:
                del self._mlx_constant_buffer
                self._mlx_constant_buffer = None

            # Force synchronization and cleanup
            if self.device == "mps" and torch.backends.mps.is_available():
                torch.mps.synchronize()
                if hasattr(torch.mps, 'empty_cache'):
                    torch.mps.empty_cache()
        except Exception:
            pass  # Ignore errors during cleanup

    def _should_use_bfloat16(self) -> bool:
        """
        Determine if bfloat16 should be used based on hardware.

        Returns:
            True if bfloat16 is supported and beneficial
        """
        if self.gpu_validator.gpu_info and self.gpu_validator.gpu_info.supports_bfloat16:
            if self.device == "mps":
                # Check PyTorch bfloat16 support on MPS
                try:
                    test_tensor = torch.tensor([1.0], dtype=torch.bfloat16, device="mps")
                    return True
                except (RuntimeError, TypeError):
                    return False
            elif self.device == "mlx":
                # Check MLX bfloat16 support
                return hasattr(mx, 'bfloat16')
        return False

    def _get_optimal_dtype(self) -> Any:
        """
        Get optimal dtype for current device and hardware.

        Returns:
            torch.dtype or mx.dtype
        """
        if self.device == "mps":
            if self.use_bfloat16:
                return torch.bfloat16
            else:
                return torch.float16
        elif self.device == "mlx":
            if self.use_bfloat16 and hasattr(mx, 'bfloat16'):
                return mx.bfloat16
            else:
                return mx.float16
        else:
            return torch.float32

    def _initialize_pool(self) -> None:
        """Initialize the memory pool with pre-allocated buffers."""
        if self.device == "mps" and torch.backends.mps.is_available():
            self._initialize_mps_pool()
        elif self.device == "mlx":
            self._initialize_mlx_pool()
        else:
            raise RuntimeError(f"Unsupported device: {self.device}")

    def _initialize_mps_pool(self) -> None:
        """Initialize MPS memory pool with optimal dtype and constant memory."""
        try:
            device = torch.device("mps")

            # Calculate number of elements based on dtype size
            if self.optimal_dtype in [torch.float16, torch.bfloat16]:
                element_size = 2  # 16-bit types
            else:
                element_size = 4  # 32-bit types

            # Initialize regular memory pool
            num_elements = self.pool_size // element_size
            self._mps_buffer = torch.empty(
                num_elements,
                dtype=self.optimal_dtype,
                device=device
            )

            # Initialize constant memory pool (for weights)
            constant_elements = self.CONSTANT_POOL_SIZE // element_size
            self._mps_constant_buffer = torch.empty(
                constant_elements,
                dtype=self.optimal_dtype,
                device=device
            )
            # Mark as read-only after initialization
            self._mps_constant_buffer.requires_grad_(False)

            dtype_name = str(self.optimal_dtype).split('.')[-1]
            logger.info(f"MPS memory pool initialized: dtype={dtype_name}, constant_pool={self.CONSTANT_POOL_SIZE / (1024*1024):.1f}MB")
            if not self.silent:
                print(f"Initialized MPS pool with dtype: {dtype_name}")
                print(f"Constant memory pool: {self.CONSTANT_POOL_SIZE / (1024*1024):.1f}MB")

            if self.strategy == AllocationStrategy.UNIFIED:
                self._create_unified_blocks()
            else:
                self._create_segmented_blocks()

            # Create constant memory blocks
            self._create_constant_blocks()

        except Exception as e:
            error_msg = str(e).lower()

            # Check if this is a dtype support issue vs a memory issue
            if self.optimal_dtype == torch.bfloat16 and 'bfloat16' in error_msg:
                # bfloat16 not supported, fall back to float16
                self.optimal_dtype = torch.float16
                self.use_bfloat16 = False
                logger.info("bfloat16 not supported, falling back to float16")
                if not self.silent:
                    print("bfloat16 not supported, falling back to float16")
                self._initialize_mps_pool()  # Retry with float16
            elif 'invalid buffer size' in error_msg or 'out of memory' in error_msg:
                # Memory allocation failed - provide a helpful error
                pool_size_gb = self.pool_size / (1024**3)
                raise RuntimeError(
                    f"Failed to allocate {pool_size_gb:.2f}GB memory pool. "
                    f"Insufficient memory available. Consider reducing pool size or "
                    f"freeing up system memory. Error: {e}"
                )
            else:
                # Other errors - pass through
                raise RuntimeError(f"Failed to initialize MPS pool: {e}")

    def _initialize_mlx_pool(self) -> None:
        """Initialize MLX memory pool with optimal dtype and zero-copy support."""
        try:
            # Calculate number of elements based on dtype size
            if self.optimal_dtype in [mx.float16, getattr(mx, 'bfloat16', mx.float16)]:
                element_size = 2  # 16-bit types
            else:
                element_size = 4  # 32-bit types

            num_elements = self.pool_size // element_size

            # For zero-copy, create unified memory that can be shared
            if self.strategy == AllocationStrategy.ZERO_COPY:
                logger.info("Creating MLX zero-copy unified memory pool")
                # MLX arrays are already zero-copy between CPU/GPU
                self._mlx_buffer = mx.zeros(
                    (num_elements,),
                    dtype=self.optimal_dtype
                )
                # Force evaluation to allocate unified memory
                mx.eval(self._mlx_buffer)

                # Also create constant buffer for weights
                constant_elements = self.CONSTANT_POOL_SIZE // element_size
                self._mlx_constant_buffer = mx.zeros(
                    (constant_elements,),
                    dtype=self.optimal_dtype
                )
                mx.eval(self._mlx_constant_buffer)

                dtype_name = str(self.optimal_dtype).split('.')[-1]
                logger.info(f"MLX zero-copy pool initialized: dtype={dtype_name}, size={self.pool_size / (1024**3):.1f}GB")
                if not self.silent:
                    print(f"Initialized MLX zero-copy pool with dtype: {dtype_name}")
                    print(f"Zero-copy unified memory: {self.pool_size / (1024**3):.1f}GB")
            else:
                # Standard MLX buffer
                self._mlx_buffer = mx.zeros(
                    (num_elements,),
                    dtype=self.optimal_dtype
                )
                mx.eval(self._mlx_buffer)

                dtype_name = str(self.optimal_dtype).split('.')[-1]
                logger.info(f"MLX pool initialized: dtype={dtype_name}")
                if not self.silent:
                    print(f"Initialized MLX pool with dtype: {dtype_name}")

            if self.strategy in [AllocationStrategy.UNIFIED, AllocationStrategy.ZERO_COPY]:
                self._create_unified_blocks()
            else:
                self._create_segmented_blocks()

            # Create constant blocks for MLX if needed
            if self._mlx_constant_buffer is not None:
                self._create_constant_blocks()

        except Exception as e:
            # Fall back to float16 if bfloat16 fails
            if hasattr(mx, 'bfloat16') and self.optimal_dtype == mx.bfloat16:
                self.optimal_dtype = mx.float16
                self.use_bfloat16 = False
                logger.info("MLX bfloat16 not supported, falling back to float16")
                if not self.silent:
                    print("bfloat16 not supported, falling back to float16")
                self._initialize_mlx_pool()  # Retry with float16
            else:
                raise RuntimeError(f"Failed to initialize MLX pool: {e}")

    def _create_unified_blocks(self) -> None:
        """Create a single unified memory block."""
        block = MemoryBlock(
            block_id="unified_0",
            size=self.pool_size,
            offset=0,
            allocated=False,
            allocation_time=None,
            last_access=None,
            device_type=self.device,
            tensor_ref=None,
            metadata={"type": "unified"},
            is_constant=False,
            is_read_only=False
        )
        self.blocks.append(block)

    def _create_constant_blocks(self) -> None:
        """Create constant memory blocks for weights."""
        # Create smaller blocks for better allocation flexibility
        block_sizes = [
            4 * 1024 * 1024,   # 4MB blocks
            16 * 1024 * 1024,  # 16MB blocks
        ]

        offset = 0
        block_id = 0
        remaining_size = self.CONSTANT_POOL_SIZE

        for size in block_sizes:
            while remaining_size >= size:
                block = MemoryBlock(
                    block_id=f"constant_{block_id}",
                    size=size,
                    offset=offset,
                    allocated=False,
                    allocation_time=None,
                    last_access=None,
                    device_type=self.device,
                    tensor_ref=None,
                    metadata={"type": "constant", "size_class": size},
                    is_constant=True,
                    is_read_only=True
                )
                self.constant_blocks.append(block)
                offset += size
                remaining_size -= size
                block_id += 1

        # Add remainder as final block
        if remaining_size > 0:
            block = MemoryBlock(
                block_id=f"constant_{block_id}",
                size=remaining_size,
                offset=offset,
                allocated=False,
                allocation_time=None,
                last_access=None,
                device_type=self.device,
                tensor_ref=None,
                metadata={"type": "constant", "size_class": "remainder"},
                is_constant=True,
                is_read_only=True
            )
            self.constant_blocks.append(block)

    def _create_segmented_blocks(self) -> None:
        """Create segmented memory blocks of various sizes."""
        offset = 0
        block_id = 0

        remaining_size = self.pool_size

        for size_class in reversed(self.BLOCK_SIZES):
            while remaining_size >= size_class:
                block = MemoryBlock(
                    block_id=f"block_{block_id}",
                    size=size_class,
                    offset=offset,
                    allocated=False,
                    allocation_time=None,
                    last_access=None,
                    device_type=self.device,
                    tensor_ref=None,
                    metadata={"size_class": size_class}
                )
                self.blocks.append(block)
                offset += size_class
                remaining_size -= size_class
                block_id += 1

        if remaining_size > 0:
            block = MemoryBlock(
                block_id=f"block_{block_id}",
                size=remaining_size,
                offset=offset,
                allocated=False,
                allocation_time=None,
                last_access=None,
                device_type=self.device,
                tensor_ref=None,
                metadata={"size_class": "remainder"}
            )
            self.blocks.append(block)

    def allocate(
        self,
        size: int,
        dtype: Optional[Any] = None,
        is_constant: bool = False
    ) -> Optional[Any]:
        """
        Allocate memory from the pool.

        Args:
            size: Size in bytes
            dtype: Data type for the tensor
            is_constant: Whether to allocate from constant memory pool

        Returns:
            Allocated tensor or None if allocation fails
        """
        with self._lock:
            if is_constant:
                # Allocate from constant memory pool
                block = self._find_constant_block(size)
                if block is None:
                    return None

                tensor = self._create_tensor_from_block(block, size, dtype, is_constant=True)
                block.mark_allocated(tensor)

                self.allocated_constant_memory += size
            else:
                # Allocate from regular pool
                block = self._find_block(size)

                if block is None:
                    self._try_defragment()
                    block = self._find_block(size)

                if block is None:
                    return None

                tensor = self._create_tensor_from_block(block, size, dtype, is_constant=False)
                block.mark_allocated(tensor)

                self.allocated_memory += size
                self.peak_memory = max(self.peak_memory, self.allocated_memory)

            return tensor

    def allocate_weights(
        self,
        size: int,
        dtype: Optional[Any] = None
    ) -> Optional[Any]:
        """
        Allocate memory for weights using constant memory.

        Args:
            size: Size in bytes
            dtype: Data type for the tensor

        Returns:
            Allocated tensor in constant memory or regular memory as fallback
        """
        # Try constant memory first
        tensor = self.allocate(size, dtype, is_constant=True)

        # Fall back to regular memory if constant memory is full
        if tensor is None:
            logger.info("Constant memory full, falling back to regular memory for weights")
            if not self.silent:
                print("Constant memory full, falling back to regular memory for weights")
            tensor = self.allocate(size, dtype, is_constant=False)

        return tensor

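A minimal usage sketch of the allocation API above, not part of the package; it assumes an Apple Silicon machine where the MPS backend is available and uses a deliberately small pool:

# Illustrative only: exercising allocate()/allocate_weights()/deallocate().
from cortex.metal.memory_pool import MemoryPool, AllocationStrategy

pool = MemoryPool(
    pool_size=512 * 1024 * 1024,           # 512MB pool, small on purpose for the example
    strategy=AllocationStrategy.UNIFIED,
    device="mps",
    auto_size=False,
    silent=True,
)

# Activations come from the regular pool; weights prefer the 64MB constant pool
# and fall back to regular memory once it is exhausted.
activations = pool.allocate(16 * 1024 * 1024)     # 16MB scratch tensor, or None on failure
weights = pool.allocate_weights(4 * 1024 * 1024)  # 4MB read-only weight buffer

pool.deallocate(activations)                      # hand the block back to the pool

Note that allocate() returns None rather than raising when no suitable block is found, so callers are expected to check the result.
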
    def _find_block(self, size: int) -> Optional[MemoryBlock]:
        """Find a suitable block for allocation."""
        if self.strategy == AllocationStrategy.BEST_FIT:
            return self._best_fit(size)
        elif self.strategy == AllocationStrategy.FIRST_FIT:
            return self._first_fit(size)
        else:
            return self._first_fit(size)

    def _find_constant_block(self, size: int) -> Optional[MemoryBlock]:
        """Find a suitable constant memory block for allocation."""
        for block in self.constant_blocks:
            if block.is_free() and block.size >= size:
                return block
        return None

    def _best_fit(self, size: int) -> Optional[MemoryBlock]:
        """Find the smallest block that fits the requested size."""
        best_block = None
        best_waste = float('inf')

        for block in self.blocks:
            if block.is_free() and block.size >= size:
                waste = block.size - size
                if waste < best_waste:
                    best_waste = waste
                    best_block = block

        return best_block

    def _first_fit(self, size: int) -> Optional[MemoryBlock]:
        """Find the first block that fits the requested size."""
        for block in self.blocks:
            if block.is_free() and block.size >= size:
                return block
        return None

    def _create_tensor_from_block(
        self,
        block: MemoryBlock,
        size: int,
        dtype: Optional[Any],
        is_constant: bool = False
    ) -> Any:
        """Create a tensor view from a memory block with zero-copy support."""
        # Use zero-copy for MLX arrays when enabled
        if self.device == "mlx" and self.strategy == AllocationStrategy.ZERO_COPY:
            return self._create_zero_copy_array(block, size, dtype, is_constant)

        if self.device == "mps":
            if dtype is None:
                dtype = self.optimal_dtype

            # Select buffer based on memory type
            buffer = self._mps_constant_buffer if is_constant else self._mps_buffer

            # Get element size based on buffer dtype
            if buffer.dtype in [torch.float16, torch.bfloat16]:
                buffer_element_size = 2
            else:
                buffer_element_size = 4

            # Calculate number of elements needed
            if dtype in [torch.float16, torch.bfloat16]:
                target_element_size = 2
            else:
                target_element_size = 4

            num_elements = size // target_element_size
            start_idx = block.offset // buffer_element_size
            end_idx = start_idx + (size // buffer_element_size)

            tensor_view = buffer[start_idx:end_idx].view(-1)

            # Convert dtype if needed
            if dtype != buffer.dtype:
                tensor_view = tensor_view.to(dtype)

            # Mark as non-gradient for constant memory
            if is_constant:
                tensor_view.requires_grad_(False)

            return tensor_view[:num_elements]

        elif self.device == "mlx":
            if dtype is None:
                dtype = self.optimal_dtype

            # Get element size based on buffer dtype
            if self._mlx_buffer.dtype in [mx.float16, getattr(mx, 'bfloat16', mx.float16)]:
                buffer_element_size = 2
            else:
                buffer_element_size = 4

            # Calculate number of elements
            if dtype in [mx.float16, getattr(mx, 'bfloat16', mx.float16)]:
                target_element_size = 2
            else:
                target_element_size = 4

            num_elements = size // target_element_size
            start_idx = block.offset // buffer_element_size
            end_idx = start_idx + (size // buffer_element_size)

            array_view = self._mlx_buffer[start_idx:end_idx]

            # Convert dtype if needed (MLX handles this automatically)
            if dtype != self._mlx_buffer.dtype:
                array_view = array_view.astype(dtype)

            return array_view[:num_elements]

    def _create_zero_copy_array(
        self,
        block: MemoryBlock,
        size: int,
        dtype: Optional[Any],
        is_constant: bool = False
    ) -> mx.array:
        """Create MLX array with zero-copy from unified memory."""
        logger.debug(f"Creating zero-copy array: size={size}, constant={is_constant}")

        if dtype is None:
            dtype = self.optimal_dtype

        # Select buffer based on memory type
        buffer = self._mlx_constant_buffer if is_constant else self._mlx_buffer

        # Calculate slice for zero-copy view
        if dtype in [mx.float16, getattr(mx, 'bfloat16', mx.float16)]:
            element_size = 2
        else:
            element_size = 4

        num_elements = size // element_size
        start_idx = block.offset // element_size
        end_idx = start_idx + num_elements

        # Create zero-copy view - no data movement
        array_view = buffer[start_idx:end_idx]

        # Convert dtype if needed (MLX does this efficiently)
        if dtype != buffer.dtype:
            array_view = array_view.astype(dtype)

        # Track zero-copy arrays for monitoring
        array_id = f"zero_copy_{block.block_id}_{id(array_view)}"
        self.zero_copy_arrays[array_id] = array_view

        # Store metadata for unified memory region
        self.unified_memory_regions[block.block_id] = {
            'size': size,
            'dtype': str(dtype),
            'zero_copy': True,
            'constant': is_constant
        }

        logger.debug(f"Zero-copy array created: id={array_id}, zero_copy_arrays={len(self.zero_copy_arrays)}")
        return array_view

    def deallocate(self, tensor: Any) -> bool:
        """
        Deallocate memory back to the pool.

        Args:
            tensor: Tensor to deallocate

        Returns:
            True if deallocation successful
        """
        with self._lock:
            for block in self.blocks:
                if block.tensor_ref is not None:
                    ref = block.tensor_ref()
                    if ref is tensor:
                        block.mark_free()
                        self.allocated_memory -= block.size
                        return True
            return False

    def _try_defragment(self) -> None:
        """Attempt to defragment the memory pool."""
        free_blocks = [b for b in self.blocks if b.is_free()]

        if len(free_blocks) < 2:
            return

        free_blocks.sort(key=lambda b: b.offset)

        merged = []
        current = free_blocks[0]

        for block in free_blocks[1:]:
            if current.offset + current.size == block.offset:
                current = MemoryBlock(
                    block_id=f"merged_{current.block_id}_{block.block_id}",
                    size=current.size + block.size,
                    offset=current.offset,
                    allocated=False,
                    allocation_time=None,
                    last_access=None,
                    device_type=self.device,
                    tensor_ref=None,
                    metadata={"merged": True}
                )
            else:
                merged.append(current)
                current = block

        merged.append(current)

        allocated_blocks = [b for b in self.blocks if not b.is_free()]
        self.blocks = allocated_blocks + merged

    def get_stats(self) -> Dict[str, Any]:
        """Get memory pool statistics."""
        with self._lock:
            free_blocks = sum(1 for b in self.blocks if b.is_free())
            allocated_blocks = len(self.blocks) - free_blocks
            free_memory = sum(b.size for b in self.blocks if b.is_free())

            # Constant memory stats
            constant_free_blocks = sum(1 for b in self.constant_blocks if b.is_free())
            constant_allocated_blocks = len(self.constant_blocks) - constant_free_blocks
            constant_free_memory = sum(b.size for b in self.constant_blocks if b.is_free())

            # Get dtype information
            dtype_name = str(self.optimal_dtype).split('.')[-1]
            dtype_bits = 16 if self.optimal_dtype in [torch.float16, torch.bfloat16,
                                                      mx.float16, getattr(mx, 'bfloat16', mx.float16)] else 32

            # Calculate zero-copy statistics
            zero_copy_count = len(self.zero_copy_arrays) if hasattr(self, 'zero_copy_arrays') else 0
            unified_regions = len(self.unified_memory_regions) if hasattr(self, 'unified_memory_regions') else 0

            stats = {
                "pool_size": self.pool_size,
                "allocated_memory": self.allocated_memory,
                "free_memory": free_memory,
                "peak_memory": self.peak_memory,
                "total_blocks": len(self.blocks),
                "allocated_blocks": allocated_blocks,
                "free_blocks": free_blocks,
                "fragmentation": 1.0 - (free_memory / (self.pool_size - self.allocated_memory + 0.01)),
                "constant_pool_size": self.CONSTANT_POOL_SIZE,
                "constant_allocated": self.allocated_constant_memory,
                "constant_free": constant_free_memory,
                "constant_blocks": len(self.constant_blocks),
                "constant_allocated_blocks": constant_allocated_blocks,
                "device": self.device,
                "strategy": self.strategy.value,
                "dtype": dtype_name,
                "dtype_bits": dtype_bits,
                "gpu_family": self.gpu_validator.gpu_info.gpu_family if self.gpu_validator.gpu_info else "unknown",
                "memory_efficiency": f"{dtype_bits}-bit precision, {50 if dtype_bits == 16 else 0}% memory savings",
                "constant_memory_benefit": "15-20% bandwidth improvement for weights"
            }

            # Add zero-copy statistics if MLX with zero-copy enabled
            if self.device == "mlx" and self.strategy == AllocationStrategy.ZERO_COPY:
                stats.update({
                    "zero_copy_enabled": True,
                    "zero_copy_arrays": zero_copy_count,
                    "unified_memory_regions": unified_regions,
                    "zero_copy_benefit": "Eliminates CPU-GPU transfer overhead"
                })
                logger.debug(f"Zero-copy stats: arrays={zero_copy_count}, regions={unified_regions}")
            else:
                stats["zero_copy_enabled"] = False

            return stats

    def reset(self) -> None:
        """Reset the memory pool."""
        with self._lock:
            for block in self.blocks:
                block.mark_free()

            self.allocated_memory = 0

            if self.strategy == AllocationStrategy.UNIFIED:
                self.blocks = []
                self._create_unified_blocks()

    def optimize_layout(self) -> None:
        """Optimize memory layout for better cache locality."""
        with self._lock:
            allocated_blocks = [b for b in self.blocks if not b.is_free()]

            allocated_blocks.sort(key=lambda b: b.last_access or datetime.min, reverse=True)

            self._try_defragment()

    def __del__(self):
        """Clean up when the pool is destroyed."""
        if self._mps_buffer is not None:
            del self._mps_buffer
        if self._mlx_buffer is not None:
            del self._mlx_buffer
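
To close, a short sketch of the MLX zero-copy path and the statistics assembled by get_stats(); again an illustration rather than documented behaviour, and it assumes the mlx package is importable on the host:

# Illustrative only: UNIFIED is promoted to ZERO_COPY internally when device="mlx".
from cortex.metal.memory_pool import MemoryPool, AllocationStrategy

mlx_pool = MemoryPool(
    pool_size=256 * 1024 * 1024,
    strategy=AllocationStrategy.UNIFIED,
    device="mlx",
    auto_size=False,
    silent=True,
)

arr = mlx_pool.allocate(8 * 1024 * 1024)   # a zero-copy view into the unified buffer

stats = mlx_pool.get_stats()
print(stats["strategy"], stats["zero_copy_enabled"], stats["zero_copy_arrays"])
# expected to print something along the lines of: zero_copy True 1

mlx_pool.reset()                           # marks every block in the pool free again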