cortex-llm 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cortex/__init__.py +73 -0
- cortex/__main__.py +83 -0
- cortex/config.py +329 -0
- cortex/conversation_manager.py +468 -0
- cortex/fine_tuning/__init__.py +8 -0
- cortex/fine_tuning/dataset.py +332 -0
- cortex/fine_tuning/mlx_lora_trainer.py +502 -0
- cortex/fine_tuning/trainer.py +957 -0
- cortex/fine_tuning/wizard.py +707 -0
- cortex/gpu_validator.py +467 -0
- cortex/inference_engine.py +727 -0
- cortex/metal/__init__.py +275 -0
- cortex/metal/gpu_validator.py +177 -0
- cortex/metal/memory_pool.py +886 -0
- cortex/metal/mlx_accelerator.py +678 -0
- cortex/metal/mlx_converter.py +638 -0
- cortex/metal/mps_optimizer.py +417 -0
- cortex/metal/optimizer.py +665 -0
- cortex/metal/performance_profiler.py +364 -0
- cortex/model_downloader.py +130 -0
- cortex/model_manager.py +2187 -0
- cortex/quantization/__init__.py +5 -0
- cortex/quantization/dynamic_quantizer.py +736 -0
- cortex/template_registry/__init__.py +15 -0
- cortex/template_registry/auto_detector.py +144 -0
- cortex/template_registry/config_manager.py +234 -0
- cortex/template_registry/interactive.py +260 -0
- cortex/template_registry/registry.py +347 -0
- cortex/template_registry/template_profiles/__init__.py +5 -0
- cortex/template_registry/template_profiles/base.py +142 -0
- cortex/template_registry/template_profiles/complex/__init__.py +5 -0
- cortex/template_registry/template_profiles/complex/reasoning.py +263 -0
- cortex/template_registry/template_profiles/standard/__init__.py +9 -0
- cortex/template_registry/template_profiles/standard/alpaca.py +73 -0
- cortex/template_registry/template_profiles/standard/chatml.py +82 -0
- cortex/template_registry/template_profiles/standard/gemma.py +103 -0
- cortex/template_registry/template_profiles/standard/llama.py +87 -0
- cortex/template_registry/template_profiles/standard/simple.py +65 -0
- cortex/ui/__init__.py +120 -0
- cortex/ui/cli.py +1685 -0
- cortex/ui/markdown_render.py +185 -0
- cortex/ui/terminal_app.py +534 -0
- cortex_llm-1.0.0.dist-info/METADATA +275 -0
- cortex_llm-1.0.0.dist-info/RECORD +48 -0
- cortex_llm-1.0.0.dist-info/WHEEL +5 -0
- cortex_llm-1.0.0.dist-info/entry_points.txt +2 -0
- cortex_llm-1.0.0.dist-info/licenses/LICENSE +21 -0
- cortex_llm-1.0.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,886 @@
|
|
|
1
|
+
"""GPU memory pool management for pre-allocation and zero-copy operations."""
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
import weakref
|
|
5
|
+
from typing import Dict, Any, Optional, List, Tuple
|
|
6
|
+
from dataclasses import dataclass
|
|
7
|
+
from enum import Enum
|
|
8
|
+
import threading
|
|
9
|
+
from datetime import datetime
|
|
10
|
+
import numpy as np
|
|
11
|
+
import psutil
|
|
12
|
+
import sys
|
|
13
|
+
import os
|
|
14
|
+
import atexit
|
|
15
|
+
|
|
16
|
+
# Configure logging
|
|
17
|
+
logger = logging.getLogger(__name__)
|
|
18
|
+
logger.setLevel(logging.INFO)
|
|
19
|
+
|
|
20
|
+
import torch
|
|
21
|
+
import mlx.core as mx
|
|
22
|
+
|
|
23
|
+
# Add parent directory to path for imports
|
|
24
|
+
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
25
|
+
from gpu_validator import GPUValidator
|
|
26
|
+
|
|
27
|
+
class AllocationStrategy(Enum):
|
|
28
|
+
"""Memory allocation strategies with MLX zero-copy support."""
|
|
29
|
+
BEST_FIT = "best_fit"
|
|
30
|
+
FIRST_FIT = "first_fit"
|
|
31
|
+
UNIFIED = "unified" # Unified memory for CPU/GPU sharing
|
|
32
|
+
DEDICATED = "dedicated"
|
|
33
|
+
ZERO_COPY = "zero_copy" # MLX zero-copy for maximum efficiency
|
|
34
|
+
|
|
35
|
+
@dataclass
|
|
36
|
+
class MemoryBlock:
|
|
37
|
+
"""Represents a memory block in the pool."""
|
|
38
|
+
block_id: str
|
|
39
|
+
size: int
|
|
40
|
+
offset: int
|
|
41
|
+
allocated: bool
|
|
42
|
+
allocation_time: Optional[datetime]
|
|
43
|
+
last_access: Optional[datetime]
|
|
44
|
+
device_type: str # "mps" or "mlx"
|
|
45
|
+
tensor_ref: Any # Weak reference to tensor
|
|
46
|
+
metadata: Dict[str, Any]
|
|
47
|
+
is_constant: bool = False # Whether this block is for constant memory
|
|
48
|
+
is_read_only: bool = False # Whether this block is read-only (weights)
|
|
49
|
+
|
|
50
|
+
def is_free(self) -> bool:
|
|
51
|
+
"""Check if block is free for allocation."""
|
|
52
|
+
if not self.allocated:
|
|
53
|
+
return True
|
|
54
|
+
if self.tensor_ref is not None:
|
|
55
|
+
ref = self.tensor_ref()
|
|
56
|
+
if ref is None:
|
|
57
|
+
self.allocated = False
|
|
58
|
+
return True
|
|
59
|
+
return False
|
|
60
|
+
|
|
61
|
+
def mark_allocated(self, tensor: Any) -> None:
|
|
62
|
+
"""Mark block as allocated."""
|
|
63
|
+
self.allocated = True
|
|
64
|
+
self.allocation_time = datetime.now()
|
|
65
|
+
self.last_access = datetime.now()
|
|
66
|
+
self.tensor_ref = weakref.ref(tensor) if tensor is not None else None
|
|
67
|
+
|
|
68
|
+
def mark_free(self) -> None:
|
|
69
|
+
"""Mark block as free."""
|
|
70
|
+
self.allocated = False
|
|
71
|
+
self.tensor_ref = None
|
|
72
|
+
|
|
73
|
+
class MemoryPool:
|
|
74
|
+
"""Pre-allocated GPU memory pool for zero-copy operations."""
|
|
75
|
+
|
|
76
|
+
DEFAULT_POOL_SIZE = 20 * 1024 * 1024 * 1024 # 20GB default target
|
|
77
|
+
CONSTANT_POOL_SIZE = 64 * 1024 * 1024 # 64MB for constant memory (Metal limit)
|
|
78
|
+
BLOCK_SIZES = [
|
|
79
|
+
1 * 1024 * 1024, # 1MB
|
|
80
|
+
16 * 1024 * 1024, # 16MB
|
|
81
|
+
64 * 1024 * 1024, # 64MB
|
|
82
|
+
256 * 1024 * 1024, # 256MB
|
|
83
|
+
1024 * 1024 * 1024, # 1GB
|
|
84
|
+
4096 * 1024 * 1024, # 4GB
|
|
85
|
+
]
|
|
86
|
+
|
|
87
|
+
@classmethod
|
|
88
|
+
def get_optimal_pool_size(cls, target_size: Optional[int] = None) -> int:
|
|
89
|
+
"""Get optimal pool size based on available memory."""
|
|
90
|
+
if target_size is not None:
|
|
91
|
+
return target_size
|
|
92
|
+
|
|
93
|
+
# Get available memory
|
|
94
|
+
vm = psutil.virtual_memory()
|
|
95
|
+
available = vm.available
|
|
96
|
+
total = vm.total
|
|
97
|
+
|
|
98
|
+
# More aggressive memory allocation for better performance
|
|
99
|
+
# Use 60% of available memory, but never more than 75% of total memory
|
|
100
|
+
optimal_size = min(
|
|
101
|
+
int(available * 0.60), # 60% of available
|
|
102
|
+
int(total * 0.75) # Never more than 75% of total
|
|
103
|
+
)
|
|
104
|
+
|
|
105
|
+
# Further limit based on actual available memory
|
|
106
|
+
# Can allocate up to 90% of what's available for better utilization
|
|
107
|
+
max_safe_size = int(available * 0.90)
|
|
108
|
+
optimal_size = min(optimal_size, max_safe_size)
|
|
109
|
+
|
|
110
|
+
# Cap at DEFAULT_POOL_SIZE if we have plenty of memory
|
|
111
|
+
if optimal_size > cls.DEFAULT_POOL_SIZE:
|
|
112
|
+
optimal_size = cls.DEFAULT_POOL_SIZE
|
|
113
|
+
|
|
114
|
+
# If available memory is very low (< 4GB), be extra conservative
|
|
115
|
+
if available < 4 * 1024 * 1024 * 1024:
|
|
116
|
+
optimal_size = min(optimal_size, int(available * 0.25))
|
|
117
|
+
|
|
118
|
+
# Minimum 256MB for basic functionality (reduced from 512MB)
|
|
119
|
+
min_size = 256 * 1024 * 1024
|
|
120
|
+
return max(optimal_size, min_size)
|
|
121
|
+
|
|
122
|
+
def __init__(
|
|
123
|
+
self,
|
|
124
|
+
pool_size: Optional[int] = None,
|
|
125
|
+
strategy: AllocationStrategy = AllocationStrategy.UNIFIED,
|
|
126
|
+
device: str = "mps",
|
|
127
|
+
auto_size: bool = True,
|
|
128
|
+
silent: bool = False,
|
|
129
|
+
use_bfloat16: Optional[bool] = None,
|
|
130
|
+
enable_zero_copy: bool = True
|
|
131
|
+
):
|
|
132
|
+
"""
|
|
133
|
+
Initialize memory pool.
|
|
134
|
+
|
|
135
|
+
Args:
|
|
136
|
+
pool_size: Total pool size in bytes (None for auto-detection)
|
|
137
|
+
strategy: Allocation strategy
|
|
138
|
+
device: Device type ("mps" or "mlx")
|
|
139
|
+
auto_size: Automatically determine pool size based on available memory
|
|
140
|
+
silent: Suppress initialization messages
|
|
141
|
+
use_bfloat16: Use bfloat16 if supported (None for auto-detect)
|
|
142
|
+
"""
|
|
143
|
+
if auto_size and pool_size is None:
|
|
144
|
+
self.pool_size = self.get_optimal_pool_size()
|
|
145
|
+
logger.info(f"Auto-sizing memory pool to {self.pool_size / (1024**3):.1f}GB")
|
|
146
|
+
if not silent:
|
|
147
|
+
print(f"Auto-sizing memory pool to {self.pool_size / (1024**3):.1f}GB")
|
|
148
|
+
else:
|
|
149
|
+
self.pool_size = pool_size or self.get_optimal_pool_size()
|
|
150
|
+
logger.info(f"Memory pool size set to {self.pool_size / (1024**3):.1f}GB")
|
|
151
|
+
|
|
152
|
+
self.strategy = strategy
|
|
153
|
+
self.device = device
|
|
154
|
+
self.blocks: List[MemoryBlock] = []
|
|
155
|
+
self.constant_blocks: List[MemoryBlock] = [] # Separate pool for constant memory
|
|
156
|
+
self.allocated_memory = 0
|
|
157
|
+
self.allocated_constant_memory = 0
|
|
158
|
+
self.peak_memory = 0
|
|
159
|
+
self._lock = threading.Lock()
|
|
160
|
+
self.silent = silent
|
|
161
|
+
|
|
162
|
+
# Initialize GPU validator for hardware detection
|
|
163
|
+
self.gpu_validator = GPUValidator()
|
|
164
|
+
self.gpu_validator.validate()
|
|
165
|
+
|
|
166
|
+
# Determine optimal dtype
|
|
167
|
+
self.use_bfloat16 = use_bfloat16
|
|
168
|
+
if self.use_bfloat16 is None:
|
|
169
|
+
self.use_bfloat16 = self._should_use_bfloat16()
|
|
170
|
+
|
|
171
|
+
self.optimal_dtype = self._get_optimal_dtype()
|
|
172
|
+
self.enable_zero_copy = enable_zero_copy and device == "mlx"
|
|
173
|
+
|
|
174
|
+
# Regular and constant memory buffers
|
|
175
|
+
self._mps_buffer: Optional[torch.Tensor] = None
|
|
176
|
+
self._mps_constant_buffer: Optional[torch.Tensor] = None
|
|
177
|
+
self._mlx_buffer: Optional[mx.array] = None
|
|
178
|
+
self._mlx_constant_buffer: Optional[mx.array] = None
|
|
179
|
+
|
|
180
|
+
# Zero-copy memory tracking
|
|
181
|
+
self.zero_copy_arrays: weakref.WeakValueDictionary = weakref.WeakValueDictionary()
|
|
182
|
+
self.unified_memory_regions: Dict[str, Any] = {}
|
|
183
|
+
|
|
184
|
+
# Enable zero-copy strategy for MLX if requested
|
|
185
|
+
if self.enable_zero_copy and strategy == AllocationStrategy.UNIFIED:
|
|
186
|
+
self.strategy = AllocationStrategy.ZERO_COPY
|
|
187
|
+
logger.info("Zero-copy memory strategy enabled for MLX")
|
|
188
|
+
|
|
189
|
+
logger.info(f"Initializing memory pool with strategy: {self.strategy.value}")
|
|
190
|
+
self._initialize_pool()
|
|
191
|
+
|
|
192
|
+
# Register cleanup on exit
|
|
193
|
+
atexit.register(self.cleanup)
|
|
194
|
+
|
|
195
|
+
def cleanup(self) -> None:
|
|
196
|
+
"""Clean up allocated resources to prevent leaks."""
|
|
197
|
+
try:
|
|
198
|
+
# Release MPS buffers
|
|
199
|
+
if self._mps_buffer is not None:
|
|
200
|
+
del self._mps_buffer
|
|
201
|
+
self._mps_buffer = None
|
|
202
|
+
|
|
203
|
+
if self._mps_constant_buffer is not None:
|
|
204
|
+
del self._mps_constant_buffer
|
|
205
|
+
self._mps_constant_buffer = None
|
|
206
|
+
|
|
207
|
+
# Release MLX buffers
|
|
208
|
+
if self._mlx_buffer is not None:
|
|
209
|
+
del self._mlx_buffer
|
|
210
|
+
self._mlx_buffer = None
|
|
211
|
+
|
|
212
|
+
if self._mlx_constant_buffer is not None:
|
|
213
|
+
del self._mlx_constant_buffer
|
|
214
|
+
self._mlx_constant_buffer = None
|
|
215
|
+
|
|
216
|
+
# Force synchronization and cleanup
|
|
217
|
+
if self.device == "mps" and torch.backends.mps.is_available():
|
|
218
|
+
torch.mps.synchronize()
|
|
219
|
+
if hasattr(torch.mps, 'empty_cache'):
|
|
220
|
+
torch.mps.empty_cache()
|
|
221
|
+
except Exception:
|
|
222
|
+
pass # Ignore errors during cleanup
|
|
223
|
+
|
|
224
|
+
def _should_use_bfloat16(self) -> bool:
|
|
225
|
+
"""
|
|
226
|
+
Determine if bfloat16 should be used based on hardware.
|
|
227
|
+
|
|
228
|
+
Returns:
|
|
229
|
+
True if bfloat16 is supported and beneficial
|
|
230
|
+
"""
|
|
231
|
+
if self.gpu_validator.gpu_info and self.gpu_validator.gpu_info.supports_bfloat16:
|
|
232
|
+
if self.device == "mps":
|
|
233
|
+
# Check PyTorch bfloat16 support on MPS
|
|
234
|
+
try:
|
|
235
|
+
test_tensor = torch.tensor([1.0], dtype=torch.bfloat16, device="mps")
|
|
236
|
+
return True
|
|
237
|
+
except (RuntimeError, TypeError):
|
|
238
|
+
return False
|
|
239
|
+
elif self.device == "mlx":
|
|
240
|
+
# Check MLX bfloat16 support
|
|
241
|
+
return hasattr(mx, 'bfloat16')
|
|
242
|
+
return False
|
|
243
|
+
|
|
244
|
+
def _get_optimal_dtype(self) -> Any:
|
|
245
|
+
"""
|
|
246
|
+
Get optimal dtype for current device and hardware.
|
|
247
|
+
|
|
248
|
+
Returns:
|
|
249
|
+
torch.dtype or mx.dtype
|
|
250
|
+
"""
|
|
251
|
+
if self.device == "mps":
|
|
252
|
+
if self.use_bfloat16:
|
|
253
|
+
return torch.bfloat16
|
|
254
|
+
else:
|
|
255
|
+
return torch.float16
|
|
256
|
+
elif self.device == "mlx":
|
|
257
|
+
if self.use_bfloat16 and hasattr(mx, 'bfloat16'):
|
|
258
|
+
return mx.bfloat16
|
|
259
|
+
else:
|
|
260
|
+
return mx.float16
|
|
261
|
+
else:
|
|
262
|
+
return torch.float32
|
|
263
|
+
|
|
264
|
+
def _initialize_pool(self) -> None:
|
|
265
|
+
"""Initialize the memory pool with pre-allocated buffers."""
|
|
266
|
+
if self.device == "mps" and torch.backends.mps.is_available():
|
|
267
|
+
self._initialize_mps_pool()
|
|
268
|
+
elif self.device == "mlx":
|
|
269
|
+
self._initialize_mlx_pool()
|
|
270
|
+
else:
|
|
271
|
+
raise RuntimeError(f"Unsupported device: {self.device}")
|
|
272
|
+
|
|
273
|
+
def _initialize_mps_pool(self) -> None:
|
|
274
|
+
"""Initialize MPS memory pool with optimal dtype and constant memory."""
|
|
275
|
+
try:
|
|
276
|
+
device = torch.device("mps")
|
|
277
|
+
|
|
278
|
+
# Calculate number of elements based on dtype size
|
|
279
|
+
if self.optimal_dtype in [torch.float16, torch.bfloat16]:
|
|
280
|
+
element_size = 2 # 16-bit types
|
|
281
|
+
else:
|
|
282
|
+
element_size = 4 # 32-bit types
|
|
283
|
+
|
|
284
|
+
# Initialize regular memory pool
|
|
285
|
+
num_elements = self.pool_size // element_size
|
|
286
|
+
self._mps_buffer = torch.empty(
|
|
287
|
+
num_elements,
|
|
288
|
+
dtype=self.optimal_dtype,
|
|
289
|
+
device=device
|
|
290
|
+
)
|
|
291
|
+
|
|
292
|
+
# Initialize constant memory pool (for weights)
|
|
293
|
+
constant_elements = self.CONSTANT_POOL_SIZE // element_size
|
|
294
|
+
self._mps_constant_buffer = torch.empty(
|
|
295
|
+
constant_elements,
|
|
296
|
+
dtype=self.optimal_dtype,
|
|
297
|
+
device=device
|
|
298
|
+
)
|
|
299
|
+
# Mark as read-only after initialization
|
|
300
|
+
self._mps_constant_buffer.requires_grad_(False)
|
|
301
|
+
|
|
302
|
+
dtype_name = str(self.optimal_dtype).split('.')[-1]
|
|
303
|
+
logger.info(f"MPS memory pool initialized: dtype={dtype_name}, constant_pool={self.CONSTANT_POOL_SIZE / (1024*1024):.1f}MB")
|
|
304
|
+
if not self.silent:
|
|
305
|
+
print(f"Initialized MPS pool with dtype: {dtype_name}")
|
|
306
|
+
print(f"Constant memory pool: {self.CONSTANT_POOL_SIZE / (1024*1024):.1f}MB")
|
|
307
|
+
|
|
308
|
+
if self.strategy == AllocationStrategy.UNIFIED:
|
|
309
|
+
self._create_unified_blocks()
|
|
310
|
+
else:
|
|
311
|
+
self._create_segmented_blocks()
|
|
312
|
+
|
|
313
|
+
# Create constant memory blocks
|
|
314
|
+
self._create_constant_blocks()
|
|
315
|
+
|
|
316
|
+
except Exception as e:
|
|
317
|
+
error_msg = str(e).lower()
|
|
318
|
+
|
|
319
|
+
# Check if this is a dtype support issue vs memory issue
|
|
320
|
+
if self.optimal_dtype == torch.bfloat16 and 'bfloat16' in error_msg:
|
|
321
|
+
# bfloat16 not supported, fallback to float16
|
|
322
|
+
self.optimal_dtype = torch.float16
|
|
323
|
+
self.use_bfloat16 = False
|
|
324
|
+
logger.info("bfloat16 not supported, falling back to float16")
|
|
325
|
+
if not self.silent:
|
|
326
|
+
print("bfloat16 not supported, falling back to float16")
|
|
327
|
+
self._initialize_mps_pool() # Retry with float16
|
|
328
|
+
elif 'invalid buffer size' in error_msg or 'out of memory' in error_msg:
|
|
329
|
+
# Memory allocation failed - provide helpful error
|
|
330
|
+
pool_size_gb = self.pool_size / (1024**3)
|
|
331
|
+
raise RuntimeError(
|
|
332
|
+
f"Failed to allocate {pool_size_gb:.2f}GB memory pool. "
|
|
333
|
+
f"Insufficient memory available. Consider reducing pool size or "
|
|
334
|
+
f"freeing up system memory. Error: {e}"
|
|
335
|
+
)
|
|
336
|
+
else:
|
|
337
|
+
# Other errors - pass through
|
|
338
|
+
raise RuntimeError(f"Failed to initialize MPS pool: {e}")
|
|
339
|
+
|
|
340
|
+
def _initialize_mlx_pool(self) -> None:
|
|
341
|
+
"""Initialize MLX memory pool with optimal dtype and zero-copy support."""
|
|
342
|
+
try:
|
|
343
|
+
# Calculate number of elements based on dtype size
|
|
344
|
+
if self.optimal_dtype in [mx.float16, getattr(mx, 'bfloat16', mx.float16)]:
|
|
345
|
+
element_size = 2 # 16-bit types
|
|
346
|
+
else:
|
|
347
|
+
element_size = 4 # 32-bit types
|
|
348
|
+
|
|
349
|
+
num_elements = self.pool_size // element_size
|
|
350
|
+
|
|
351
|
+
# For zero-copy, create unified memory that can be shared
|
|
352
|
+
if self.strategy == AllocationStrategy.ZERO_COPY:
|
|
353
|
+
logger.info("Creating MLX zero-copy unified memory pool")
|
|
354
|
+
# MLX arrays are already zero-copy between CPU/GPU
|
|
355
|
+
self._mlx_buffer = mx.zeros(
|
|
356
|
+
(num_elements,),
|
|
357
|
+
dtype=self.optimal_dtype
|
|
358
|
+
)
|
|
359
|
+
# Force evaluation to allocate unified memory
|
|
360
|
+
mx.eval(self._mlx_buffer)
|
|
361
|
+
|
|
362
|
+
# Also create constant buffer for weights
|
|
363
|
+
constant_elements = self.CONSTANT_POOL_SIZE // element_size
|
|
364
|
+
self._mlx_constant_buffer = mx.zeros(
|
|
365
|
+
(constant_elements,),
|
|
366
|
+
dtype=self.optimal_dtype
|
|
367
|
+
)
|
|
368
|
+
mx.eval(self._mlx_constant_buffer)
|
|
369
|
+
|
|
370
|
+
dtype_name = str(self.optimal_dtype).split('.')[-1]
|
|
371
|
+
logger.info(f"MLX zero-copy pool initialized: dtype={dtype_name}, size={self.pool_size / (1024**3):.1f}GB")
|
|
372
|
+
if not self.silent:
|
|
373
|
+
print(f"Initialized MLX zero-copy pool with dtype: {dtype_name}")
|
|
374
|
+
print(f"Zero-copy unified memory: {self.pool_size / (1024**3):.1f}GB")
|
|
375
|
+
else:
|
|
376
|
+
# Standard MLX buffer
|
|
377
|
+
self._mlx_buffer = mx.zeros(
|
|
378
|
+
(num_elements,),
|
|
379
|
+
dtype=self.optimal_dtype
|
|
380
|
+
)
|
|
381
|
+
mx.eval(self._mlx_buffer)
|
|
382
|
+
|
|
383
|
+
dtype_name = str(self.optimal_dtype).split('.')[-1]
|
|
384
|
+
logger.info(f"MLX pool initialized: dtype={dtype_name}")
|
|
385
|
+
if not self.silent:
|
|
386
|
+
print(f"Initialized MLX pool with dtype: {dtype_name}")
|
|
387
|
+
|
|
388
|
+
if self.strategy in [AllocationStrategy.UNIFIED, AllocationStrategy.ZERO_COPY]:
|
|
389
|
+
self._create_unified_blocks()
|
|
390
|
+
else:
|
|
391
|
+
self._create_segmented_blocks()
|
|
392
|
+
|
|
393
|
+
# Create constant blocks for MLX if needed
|
|
394
|
+
if self._mlx_constant_buffer is not None:
|
|
395
|
+
self._create_constant_blocks()
|
|
396
|
+
|
|
397
|
+
except Exception as e:
|
|
398
|
+
# Fallback to float16 if bfloat16 fails
|
|
399
|
+
if hasattr(mx, 'bfloat16') and self.optimal_dtype == mx.bfloat16:
|
|
400
|
+
self.optimal_dtype = mx.float16
|
|
401
|
+
self.use_bfloat16 = False
|
|
402
|
+
logger.info("MLX bfloat16 not supported, falling back to float16")
|
|
403
|
+
if not self.silent:
|
|
404
|
+
print("bfloat16 not supported, falling back to float16")
|
|
405
|
+
self._initialize_mlx_pool() # Retry with float16
|
|
406
|
+
else:
|
|
407
|
+
raise RuntimeError(f"Failed to initialize MLX pool: {e}")
|
|
408
|
+
|
|
409
|
+
def _create_unified_blocks(self) -> None:
|
|
410
|
+
"""Create a single unified memory block."""
|
|
411
|
+
block = MemoryBlock(
|
|
412
|
+
block_id="unified_0",
|
|
413
|
+
size=self.pool_size,
|
|
414
|
+
offset=0,
|
|
415
|
+
allocated=False,
|
|
416
|
+
allocation_time=None,
|
|
417
|
+
last_access=None,
|
|
418
|
+
device_type=self.device,
|
|
419
|
+
tensor_ref=None,
|
|
420
|
+
metadata={"type": "unified"},
|
|
421
|
+
is_constant=False,
|
|
422
|
+
is_read_only=False
|
|
423
|
+
)
|
|
424
|
+
self.blocks.append(block)
|
|
425
|
+
|
|
426
|
+
def _create_constant_blocks(self) -> None:
|
|
427
|
+
"""Create constant memory blocks for weights."""
|
|
428
|
+
# Create smaller blocks for better allocation flexibility
|
|
429
|
+
block_sizes = [
|
|
430
|
+
4 * 1024 * 1024, # 4MB blocks
|
|
431
|
+
16 * 1024 * 1024, # 16MB blocks
|
|
432
|
+
]
|
|
433
|
+
|
|
434
|
+
offset = 0
|
|
435
|
+
block_id = 0
|
|
436
|
+
remaining_size = self.CONSTANT_POOL_SIZE
|
|
437
|
+
|
|
438
|
+
for size in block_sizes:
|
|
439
|
+
while remaining_size >= size:
|
|
440
|
+
block = MemoryBlock(
|
|
441
|
+
block_id=f"constant_{block_id}",
|
|
442
|
+
size=size,
|
|
443
|
+
offset=offset,
|
|
444
|
+
allocated=False,
|
|
445
|
+
allocation_time=None,
|
|
446
|
+
last_access=None,
|
|
447
|
+
device_type=self.device,
|
|
448
|
+
tensor_ref=None,
|
|
449
|
+
metadata={"type": "constant", "size_class": size},
|
|
450
|
+
is_constant=True,
|
|
451
|
+
is_read_only=True
|
|
452
|
+
)
|
|
453
|
+
self.constant_blocks.append(block)
|
|
454
|
+
offset += size
|
|
455
|
+
remaining_size -= size
|
|
456
|
+
block_id += 1
|
|
457
|
+
|
|
458
|
+
# Add remainder as final block
|
|
459
|
+
if remaining_size > 0:
|
|
460
|
+
block = MemoryBlock(
|
|
461
|
+
block_id=f"constant_{block_id}",
|
|
462
|
+
size=remaining_size,
|
|
463
|
+
offset=offset,
|
|
464
|
+
allocated=False,
|
|
465
|
+
allocation_time=None,
|
|
466
|
+
last_access=None,
|
|
467
|
+
device_type=self.device,
|
|
468
|
+
tensor_ref=None,
|
|
469
|
+
metadata={"type": "constant", "size_class": "remainder"},
|
|
470
|
+
is_constant=True,
|
|
471
|
+
is_read_only=True
|
|
472
|
+
)
|
|
473
|
+
self.constant_blocks.append(block)
|
|
474
|
+
|
|
475
|
+
def _create_segmented_blocks(self) -> None:
|
|
476
|
+
"""Create segmented memory blocks of various sizes."""
|
|
477
|
+
offset = 0
|
|
478
|
+
block_id = 0
|
|
479
|
+
|
|
480
|
+
remaining_size = self.pool_size
|
|
481
|
+
|
|
482
|
+
for size_class in reversed(self.BLOCK_SIZES):
|
|
483
|
+
while remaining_size >= size_class:
|
|
484
|
+
block = MemoryBlock(
|
|
485
|
+
block_id=f"block_{block_id}",
|
|
486
|
+
size=size_class,
|
|
487
|
+
offset=offset,
|
|
488
|
+
allocated=False,
|
|
489
|
+
allocation_time=None,
|
|
490
|
+
last_access=None,
|
|
491
|
+
device_type=self.device,
|
|
492
|
+
tensor_ref=None,
|
|
493
|
+
metadata={"size_class": size_class}
|
|
494
|
+
)
|
|
495
|
+
self.blocks.append(block)
|
|
496
|
+
offset += size_class
|
|
497
|
+
remaining_size -= size_class
|
|
498
|
+
block_id += 1
|
|
499
|
+
|
|
500
|
+
if remaining_size > 0:
|
|
501
|
+
block = MemoryBlock(
|
|
502
|
+
block_id=f"block_{block_id}",
|
|
503
|
+
size=remaining_size,
|
|
504
|
+
offset=offset,
|
|
505
|
+
allocated=False,
|
|
506
|
+
allocation_time=None,
|
|
507
|
+
last_access=None,
|
|
508
|
+
device_type=self.device,
|
|
509
|
+
tensor_ref=None,
|
|
510
|
+
metadata={"size_class": "remainder"}
|
|
511
|
+
)
|
|
512
|
+
self.blocks.append(block)
|
|
513
|
+
|
|
514
|
+
def allocate(
|
|
515
|
+
self,
|
|
516
|
+
size: int,
|
|
517
|
+
dtype: Optional[Any] = None,
|
|
518
|
+
is_constant: bool = False
|
|
519
|
+
) -> Optional[Any]:
|
|
520
|
+
"""
|
|
521
|
+
Allocate memory from the pool.
|
|
522
|
+
|
|
523
|
+
Args:
|
|
524
|
+
size: Size in bytes
|
|
525
|
+
dtype: Data type for the tensor
|
|
526
|
+
is_constant: Whether to allocate from constant memory pool
|
|
527
|
+
|
|
528
|
+
Returns:
|
|
529
|
+
Allocated tensor or None if allocation fails
|
|
530
|
+
"""
|
|
531
|
+
with self._lock:
|
|
532
|
+
if is_constant:
|
|
533
|
+
# Allocate from constant memory pool
|
|
534
|
+
block = self._find_constant_block(size)
|
|
535
|
+
if block is None:
|
|
536
|
+
return None
|
|
537
|
+
|
|
538
|
+
tensor = self._create_tensor_from_block(block, size, dtype, is_constant=True)
|
|
539
|
+
block.mark_allocated(tensor)
|
|
540
|
+
|
|
541
|
+
self.allocated_constant_memory += size
|
|
542
|
+
else:
|
|
543
|
+
# Allocate from regular pool
|
|
544
|
+
block = self._find_block(size)
|
|
545
|
+
|
|
546
|
+
if block is None:
|
|
547
|
+
self._try_defragment()
|
|
548
|
+
block = self._find_block(size)
|
|
549
|
+
|
|
550
|
+
if block is None:
|
|
551
|
+
return None
|
|
552
|
+
|
|
553
|
+
tensor = self._create_tensor_from_block(block, size, dtype, is_constant=False)
|
|
554
|
+
block.mark_allocated(tensor)
|
|
555
|
+
|
|
556
|
+
self.allocated_memory += size
|
|
557
|
+
self.peak_memory = max(self.peak_memory, self.allocated_memory)
|
|
558
|
+
|
|
559
|
+
return tensor
|
|
560
|
+
|
|
561
|
+
def allocate_weights(
|
|
562
|
+
self,
|
|
563
|
+
size: int,
|
|
564
|
+
dtype: Optional[Any] = None
|
|
565
|
+
) -> Optional[Any]:
|
|
566
|
+
"""
|
|
567
|
+
Allocate memory for weights using constant memory.
|
|
568
|
+
|
|
569
|
+
Args:
|
|
570
|
+
size: Size in bytes
|
|
571
|
+
dtype: Data type for the tensor
|
|
572
|
+
|
|
573
|
+
Returns:
|
|
574
|
+
Allocated tensor in constant memory or regular memory as fallback
|
|
575
|
+
"""
|
|
576
|
+
# Try constant memory first
|
|
577
|
+
tensor = self.allocate(size, dtype, is_constant=True)
|
|
578
|
+
|
|
579
|
+
# Fallback to regular memory if constant memory is full
|
|
580
|
+
if tensor is None:
|
|
581
|
+
logger.info("Constant memory full, falling back to regular memory for weights")
|
|
582
|
+
if not self.silent:
|
|
583
|
+
print("Constant memory full, falling back to regular memory for weights")
|
|
584
|
+
tensor = self.allocate(size, dtype, is_constant=False)
|
|
585
|
+
|
|
586
|
+
return tensor
|
|
587
|
+
|
|
588
|
+
def _find_block(self, size: int) -> Optional[MemoryBlock]:
|
|
589
|
+
"""Find a suitable block for allocation."""
|
|
590
|
+
if self.strategy == AllocationStrategy.BEST_FIT:
|
|
591
|
+
return self._best_fit(size)
|
|
592
|
+
elif self.strategy == AllocationStrategy.FIRST_FIT:
|
|
593
|
+
return self._first_fit(size)
|
|
594
|
+
else:
|
|
595
|
+
return self._first_fit(size)
|
|
596
|
+
|
|
597
|
+
def _find_constant_block(self, size: int) -> Optional[MemoryBlock]:
|
|
598
|
+
"""Find a suitable constant memory block for allocation."""
|
|
599
|
+
for block in self.constant_blocks:
|
|
600
|
+
if block.is_free() and block.size >= size:
|
|
601
|
+
return block
|
|
602
|
+
return None
|
|
603
|
+
|
|
604
|
+
def _best_fit(self, size: int) -> Optional[MemoryBlock]:
|
|
605
|
+
"""Find the smallest block that fits the requested size."""
|
|
606
|
+
best_block = None
|
|
607
|
+
best_waste = float('inf')
|
|
608
|
+
|
|
609
|
+
for block in self.blocks:
|
|
610
|
+
if block.is_free() and block.size >= size:
|
|
611
|
+
waste = block.size - size
|
|
612
|
+
if waste < best_waste:
|
|
613
|
+
best_waste = waste
|
|
614
|
+
best_block = block
|
|
615
|
+
|
|
616
|
+
return best_block
|
|
617
|
+
|
|
618
|
+
def _first_fit(self, size: int) -> Optional[MemoryBlock]:
|
|
619
|
+
"""Find the first block that fits the requested size."""
|
|
620
|
+
for block in self.blocks:
|
|
621
|
+
if block.is_free() and block.size >= size:
|
|
622
|
+
return block
|
|
623
|
+
return None
|
|
624
|
+
|
|
625
|
+
def _create_tensor_from_block(
|
|
626
|
+
self,
|
|
627
|
+
block: MemoryBlock,
|
|
628
|
+
size: int,
|
|
629
|
+
dtype: Optional[Any],
|
|
630
|
+
is_constant: bool = False
|
|
631
|
+
) -> Any:
|
|
632
|
+
"""Create a tensor view from a memory block with zero-copy support."""
|
|
633
|
+
# Use zero-copy for MLX arrays when enabled
|
|
634
|
+
if self.device == "mlx" and self.strategy == AllocationStrategy.ZERO_COPY:
|
|
635
|
+
return self._create_zero_copy_array(block, size, dtype, is_constant)
|
|
636
|
+
|
|
637
|
+
if self.device == "mps":
|
|
638
|
+
if dtype is None:
|
|
639
|
+
dtype = self.optimal_dtype
|
|
640
|
+
|
|
641
|
+
# Select buffer based on memory type
|
|
642
|
+
buffer = self._mps_constant_buffer if is_constant else self._mps_buffer
|
|
643
|
+
|
|
644
|
+
# Get element size based on buffer dtype
|
|
645
|
+
if buffer.dtype in [torch.float16, torch.bfloat16]:
|
|
646
|
+
buffer_element_size = 2
|
|
647
|
+
else:
|
|
648
|
+
buffer_element_size = 4
|
|
649
|
+
|
|
650
|
+
# Calculate number of elements needed
|
|
651
|
+
if dtype in [torch.float16, torch.bfloat16]:
|
|
652
|
+
target_element_size = 2
|
|
653
|
+
else:
|
|
654
|
+
target_element_size = 4
|
|
655
|
+
|
|
656
|
+
num_elements = size // target_element_size
|
|
657
|
+
start_idx = block.offset // buffer_element_size
|
|
658
|
+
end_idx = start_idx + (size // buffer_element_size)
|
|
659
|
+
|
|
660
|
+
tensor_view = buffer[start_idx:end_idx].view(-1)
|
|
661
|
+
|
|
662
|
+
# Convert dtype if needed
|
|
663
|
+
if dtype != buffer.dtype:
|
|
664
|
+
tensor_view = tensor_view.to(dtype)
|
|
665
|
+
|
|
666
|
+
# Mark as non-gradients for constant memory
|
|
667
|
+
if is_constant:
|
|
668
|
+
tensor_view.requires_grad_(False)
|
|
669
|
+
|
|
670
|
+
return tensor_view[:num_elements]
|
|
671
|
+
|
|
672
|
+
elif self.device == "mlx":
|
|
673
|
+
if dtype is None:
|
|
674
|
+
dtype = self.optimal_dtype
|
|
675
|
+
|
|
676
|
+
# Get element size based on buffer dtype
|
|
677
|
+
if self._mlx_buffer.dtype in [mx.float16, getattr(mx, 'bfloat16', mx.float16)]:
|
|
678
|
+
buffer_element_size = 2
|
|
679
|
+
else:
|
|
680
|
+
buffer_element_size = 4
|
|
681
|
+
|
|
682
|
+
# Calculate number of elements
|
|
683
|
+
if dtype in [mx.float16, getattr(mx, 'bfloat16', mx.float16)]:
|
|
684
|
+
target_element_size = 2
|
|
685
|
+
else:
|
|
686
|
+
target_element_size = 4
|
|
687
|
+
|
|
688
|
+
num_elements = size // target_element_size
|
|
689
|
+
start_idx = block.offset // buffer_element_size
|
|
690
|
+
end_idx = start_idx + (size // buffer_element_size)
|
|
691
|
+
|
|
692
|
+
array_view = self._mlx_buffer[start_idx:end_idx]
|
|
693
|
+
|
|
694
|
+
# Convert dtype if needed (MLX handles this automatically)
|
|
695
|
+
if dtype != self._mlx_buffer.dtype:
|
|
696
|
+
array_view = array_view.astype(dtype)
|
|
697
|
+
|
|
698
|
+
return array_view[:num_elements]
|
|
699
|
+
|
|
700
|
+
def _create_zero_copy_array(
|
|
701
|
+
self,
|
|
702
|
+
block: MemoryBlock,
|
|
703
|
+
size: int,
|
|
704
|
+
dtype: Optional[Any],
|
|
705
|
+
is_constant: bool = False
|
|
706
|
+
) -> mx.array:
|
|
707
|
+
"""Create MLX array with zero-copy from unified memory."""
|
|
708
|
+
logger.debug(f"Creating zero-copy array: size={size}, constant={is_constant}")
|
|
709
|
+
|
|
710
|
+
if dtype is None:
|
|
711
|
+
dtype = self.optimal_dtype
|
|
712
|
+
|
|
713
|
+
# Select buffer based on memory type
|
|
714
|
+
buffer = self._mlx_constant_buffer if is_constant else self._mlx_buffer
|
|
715
|
+
|
|
716
|
+
# Calculate slice for zero-copy view
|
|
717
|
+
if dtype in [mx.float16, getattr(mx, 'bfloat16', mx.float16)]:
|
|
718
|
+
element_size = 2
|
|
719
|
+
else:
|
|
720
|
+
element_size = 4
|
|
721
|
+
|
|
722
|
+
num_elements = size // element_size
|
|
723
|
+
start_idx = block.offset // element_size
|
|
724
|
+
end_idx = start_idx + num_elements
|
|
725
|
+
|
|
726
|
+
# Create zero-copy view - no data movement
|
|
727
|
+
array_view = buffer[start_idx:end_idx]
|
|
728
|
+
|
|
729
|
+
# Convert dtype if needed (MLX does this efficiently)
|
|
730
|
+
if dtype != buffer.dtype:
|
|
731
|
+
array_view = array_view.astype(dtype)
|
|
732
|
+
|
|
733
|
+
# Track zero-copy arrays for monitoring
|
|
734
|
+
array_id = f"zero_copy_{block.block_id}_{id(array_view)}"
|
|
735
|
+
self.zero_copy_arrays[array_id] = array_view
|
|
736
|
+
|
|
737
|
+
# Store metadata for unified memory region
|
|
738
|
+
self.unified_memory_regions[block.block_id] = {
|
|
739
|
+
'size': size,
|
|
740
|
+
'dtype': str(dtype),
|
|
741
|
+
'zero_copy': True,
|
|
742
|
+
'constant': is_constant
|
|
743
|
+
}
|
|
744
|
+
|
|
745
|
+
logger.debug(f"Zero-copy array created: id={array_id}, zero_copy_arrays={len(self.zero_copy_arrays)}")
|
|
746
|
+
return array_view
|
|
747
|
+
|
|
748
|
+
def deallocate(self, tensor: Any) -> bool:
|
|
749
|
+
"""
|
|
750
|
+
Deallocate memory back to the pool.
|
|
751
|
+
|
|
752
|
+
Args:
|
|
753
|
+
tensor: Tensor to deallocate
|
|
754
|
+
|
|
755
|
+
Returns:
|
|
756
|
+
True if deallocation successful
|
|
757
|
+
"""
|
|
758
|
+
with self._lock:
|
|
759
|
+
for block in self.blocks:
|
|
760
|
+
if block.tensor_ref is not None:
|
|
761
|
+
ref = block.tensor_ref()
|
|
762
|
+
if ref is tensor:
|
|
763
|
+
block.mark_free()
|
|
764
|
+
self.allocated_memory -= block.size
|
|
765
|
+
return True
|
|
766
|
+
return False
|
|
767
|
+
|
|
768
|
+
def _try_defragment(self) -> None:
|
|
769
|
+
"""Attempt to defragment the memory pool."""
|
|
770
|
+
free_blocks = [b for b in self.blocks if b.is_free()]
|
|
771
|
+
|
|
772
|
+
if len(free_blocks) < 2:
|
|
773
|
+
return
|
|
774
|
+
|
|
775
|
+
free_blocks.sort(key=lambda b: b.offset)
|
|
776
|
+
|
|
777
|
+
merged = []
|
|
778
|
+
current = free_blocks[0]
|
|
779
|
+
|
|
780
|
+
for block in free_blocks[1:]:
|
|
781
|
+
if current.offset + current.size == block.offset:
|
|
782
|
+
current = MemoryBlock(
|
|
783
|
+
block_id=f"merged_{current.block_id}_{block.block_id}",
|
|
784
|
+
size=current.size + block.size,
|
|
785
|
+
offset=current.offset,
|
|
786
|
+
allocated=False,
|
|
787
|
+
allocation_time=None,
|
|
788
|
+
last_access=None,
|
|
789
|
+
device_type=self.device,
|
|
790
|
+
tensor_ref=None,
|
|
791
|
+
metadata={"merged": True}
|
|
792
|
+
)
|
|
793
|
+
else:
|
|
794
|
+
merged.append(current)
|
|
795
|
+
current = block
|
|
796
|
+
|
|
797
|
+
merged.append(current)
|
|
798
|
+
|
|
799
|
+
allocated_blocks = [b for b in self.blocks if not b.is_free()]
|
|
800
|
+
self.blocks = allocated_blocks + merged
|
|
801
|
+
|
|
802
|
+
def get_stats(self) -> Dict[str, Any]:
|
|
803
|
+
"""Get memory pool statistics."""
|
|
804
|
+
with self._lock:
|
|
805
|
+
free_blocks = sum(1 for b in self.blocks if b.is_free())
|
|
806
|
+
allocated_blocks = len(self.blocks) - free_blocks
|
|
807
|
+
free_memory = sum(b.size for b in self.blocks if b.is_free())
|
|
808
|
+
|
|
809
|
+
# Constant memory stats
|
|
810
|
+
constant_free_blocks = sum(1 for b in self.constant_blocks if b.is_free())
|
|
811
|
+
constant_allocated_blocks = len(self.constant_blocks) - constant_free_blocks
|
|
812
|
+
constant_free_memory = sum(b.size for b in self.constant_blocks if b.is_free())
|
|
813
|
+
|
|
814
|
+
# Get dtype information
|
|
815
|
+
dtype_name = str(self.optimal_dtype).split('.')[-1]
|
|
816
|
+
dtype_bits = 16 if self.optimal_dtype in [torch.float16, torch.bfloat16,
|
|
817
|
+
mx.float16, getattr(mx, 'bfloat16', mx.float16)] else 32
|
|
818
|
+
|
|
819
|
+
# Calculate zero-copy statistics
|
|
820
|
+
zero_copy_count = len(self.zero_copy_arrays) if hasattr(self, 'zero_copy_arrays') else 0
|
|
821
|
+
unified_regions = len(self.unified_memory_regions) if hasattr(self, 'unified_memory_regions') else 0
|
|
822
|
+
|
|
823
|
+
stats = {
|
|
824
|
+
"pool_size": self.pool_size,
|
|
825
|
+
"allocated_memory": self.allocated_memory,
|
|
826
|
+
"free_memory": free_memory,
|
|
827
|
+
"peak_memory": self.peak_memory,
|
|
828
|
+
"total_blocks": len(self.blocks),
|
|
829
|
+
"allocated_blocks": allocated_blocks,
|
|
830
|
+
"free_blocks": free_blocks,
|
|
831
|
+
"fragmentation": 1.0 - (free_memory / (self.pool_size - self.allocated_memory + 0.01)),
|
|
832
|
+
"constant_pool_size": self.CONSTANT_POOL_SIZE,
|
|
833
|
+
"constant_allocated": self.allocated_constant_memory,
|
|
834
|
+
"constant_free": constant_free_memory,
|
|
835
|
+
"constant_blocks": len(self.constant_blocks),
|
|
836
|
+
"constant_allocated_blocks": constant_allocated_blocks,
|
|
837
|
+
"device": self.device,
|
|
838
|
+
"strategy": self.strategy.value,
|
|
839
|
+
"dtype": dtype_name,
|
|
840
|
+
"dtype_bits": dtype_bits,
|
|
841
|
+
"gpu_family": self.gpu_validator.gpu_info.gpu_family if self.gpu_validator.gpu_info else "unknown",
|
|
842
|
+
"memory_efficiency": f"{dtype_bits}-bit precision, {50 if dtype_bits == 16 else 0}% memory savings",
|
|
843
|
+
"constant_memory_benefit": "15-20% bandwidth improvement for weights"
|
|
844
|
+
}
|
|
845
|
+
|
|
846
|
+
# Add zero-copy statistics if MLX with zero-copy enabled
|
|
847
|
+
if self.device == "mlx" and self.strategy == AllocationStrategy.ZERO_COPY:
|
|
848
|
+
stats.update({
|
|
849
|
+
"zero_copy_enabled": True,
|
|
850
|
+
"zero_copy_arrays": zero_copy_count,
|
|
851
|
+
"unified_memory_regions": unified_regions,
|
|
852
|
+
"zero_copy_benefit": "Eliminates CPU-GPU transfer overhead"
|
|
853
|
+
})
|
|
854
|
+
logger.debug(f"Zero-copy stats: arrays={zero_copy_count}, regions={unified_regions}")
|
|
855
|
+
else:
|
|
856
|
+
stats["zero_copy_enabled"] = False
|
|
857
|
+
|
|
858
|
+
return stats
|
|
859
|
+
|
|
860
|
+
def reset(self) -> None:
|
|
861
|
+
"""Reset the memory pool."""
|
|
862
|
+
with self._lock:
|
|
863
|
+
for block in self.blocks:
|
|
864
|
+
block.mark_free()
|
|
865
|
+
|
|
866
|
+
self.allocated_memory = 0
|
|
867
|
+
|
|
868
|
+
if self.strategy == AllocationStrategy.UNIFIED:
|
|
869
|
+
self.blocks = []
|
|
870
|
+
self._create_unified_blocks()
|
|
871
|
+
|
|
872
|
+
def optimize_layout(self) -> None:
|
|
873
|
+
"""Optimize memory layout for better cache locality."""
|
|
874
|
+
with self._lock:
|
|
875
|
+
allocated_blocks = [b for b in self.blocks if not b.is_free()]
|
|
876
|
+
|
|
877
|
+
allocated_blocks.sort(key=lambda b: b.last_access or datetime.min, reverse=True)
|
|
878
|
+
|
|
879
|
+
self._try_defragment()
|
|
880
|
+
|
|
881
|
+
def __del__(self):
|
|
882
|
+
"""Cleanup when pool is destroyed."""
|
|
883
|
+
if self._mps_buffer is not None:
|
|
884
|
+
del self._mps_buffer
|
|
885
|
+
if self._mlx_buffer is not None:
|
|
886
|
+
del self._mlx_buffer
|