quarterbit-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,728 @@
+ """
+ QuarterBit Optimizers
+ =====================
+
+ Precision optimizers for PyTorch.
+ Copyright (c) 2026 Clouthier Simulation Labs. All rights reserved.
+ """
+
+ import torch
+ from torch.optim import Optimizer
+ from typing import Optional, Callable, Iterable
+ import ctypes
+ from .utils import get_lib, is_available
+
+ def _ptr(tensor):
+     """Get ctypes pointer to tensor data."""
+     if tensor is None:
+         return None
+     return ctypes.cast(tensor.data_ptr(), ctypes.POINTER(ctypes.c_float))
+
+
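+ # Illustrative note (not part of the package API): _ptr hands the raw tensor
+ # address to the native kernels as a float*. On a contiguous float32 CPU
+ # tensor the pointer aliases the storage directly, e.g.:
+ #
+ #     >>> t = torch.tensor([1.5, 2.5], dtype=torch.float32)
+ #     >>> _ptr(t)[0]
+ #     1.5
+
+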
+ class SGD(Optimizer):
+     """
+     EFT-powered Stochastic Gradient Descent.
+
+     Drop-in replacement for torch.optim.SGD with precise gradient accumulation.
+
+     Args:
+         params: Iterable of parameters to optimize
+         lr: Learning rate
+         momentum: Momentum factor (default: 0)
+         dampening: Dampening for momentum (default: 0)
+         weight_decay: Weight decay (L2 penalty) (default: 0)
+         nesterov: Enables Nesterov momentum (default: False)
+
+     Example:
+         >>> from quarterbit.torch import SGD
+         >>> optimizer = SGD(model.parameters(), lr=0.01, momentum=0.9)
+     """
+
+     def __init__(
+         self,
+         params: Iterable[torch.nn.Parameter],
+         lr: float = 1e-3,
+         momentum: float = 0,
+         dampening: float = 0,
+         weight_decay: float = 0,
+         nesterov: bool = False
+     ):
+         if lr < 0.0:
+             raise ValueError(f"Invalid learning rate: {lr}")
+         if momentum < 0.0:
+             raise ValueError(f"Invalid momentum value: {momentum}")
+         if weight_decay < 0.0:
+             raise ValueError(f"Invalid weight_decay value: {weight_decay}")
+         if nesterov and (momentum <= 0 or dampening != 0):
+             raise ValueError("Nesterov momentum requires a momentum and zero dampening")
+
+         defaults = dict(
+             lr=lr, momentum=momentum, dampening=dampening,
+             weight_decay=weight_decay, nesterov=nesterov
+         )
+         super().__init__(params, defaults)
+
+         self._use_eft = is_available()
+         if self._use_eft:
+             self._lib = get_lib()
+
+         # Initialize compensation buffers
+         for group in self.param_groups:
+             for p in group['params']:
+                 state = self.state[p]
+                 state['weight_comp'] = torch.zeros_like(p.data)
+                 if momentum != 0:
+                     state['momentum_buffer'] = torch.zeros_like(p.data)
+                     state['momentum_comp'] = torch.zeros_like(p.data)
+
+     @torch.no_grad()
+     def step(self, closure: Optional[Callable] = None):
+         """Performs a single optimization step."""
+         loss = None
+         if closure is not None:
+             with torch.enable_grad():
+                 loss = closure()
+
+         for group in self.param_groups:
+             lr = group['lr']
+             momentum = group['momentum']
+             dampening = group['dampening']
+             weight_decay = group['weight_decay']
+             nesterov = group['nesterov']
+
+             for p in group['params']:
+                 if p.grad is None:
+                     continue
+
+                 grad = p.grad.data
+                 state = self.state[p]
+
+                 if self._use_eft and p.is_cuda and p.dtype == torch.float32:
+                     # Use EFT kernel
+                     n = p.numel()
+
+                     if momentum != 0:
+                         # TODO: Implement momentum with EFT
+                         # For now, fall back to standard for momentum
+                         self._standard_step(p, grad, state, group)
+                     else:
+                         # Pure SGD with EFT
+                         self._lib.eft_sgd_step(
+                             _ptr(p.data),
+                             _ptr(state['weight_comp']),
+                             _ptr(grad),
+                             ctypes.c_float(lr),
+                             ctypes.c_float(weight_decay),
+                             ctypes.c_int(n)
+                         )
+                 else:
+                     # Fallback to standard PyTorch
+                     self._standard_step(p, grad, state, group)
+
+         return loss
+
+     def _standard_step(self, p, grad, state, group):
+         """Standard PyTorch SGD step (fallback)."""
+         weight_decay = group['weight_decay']
+         momentum = group['momentum']
+         dampening = group['dampening']
+         nesterov = group['nesterov']
+         lr = group['lr']
+
+         if weight_decay != 0:
+             grad = grad.add(p.data, alpha=weight_decay)
+
+         if momentum != 0:
+             buf = state.get('momentum_buffer')
+             if buf is None:
+                 buf = torch.clone(grad).detach()
+                 state['momentum_buffer'] = buf
+             else:
+                 buf.mul_(momentum).add_(grad, alpha=1 - dampening)
+
+             if nesterov:
+                 grad = grad.add(buf, alpha=momentum)
+             else:
+                 grad = buf
+
+         p.data.add_(grad, alpha=-lr)
+
+
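+ # Illustrative sketch, not part of the published package: what "precise
+ # gradient accumulation" means in practice. The CUDA path above is assumed to
+ # fuse an update of this shape; this plain-PyTorch version keeps a per-weight
+ # compensation buffer (Kahan-style) so the rounding error of each tiny lr*grad
+ # update is carried into the next step instead of being lost.
+ def _compensated_sgd_reference(param, comp, grad, lr, weight_decay=0.0):
+     """Reference compensated SGD update (no momentum); out of place, then copied back."""
+     if weight_decay != 0:
+         grad = grad.add(param, alpha=weight_decay)
+     update = grad.mul(-lr).add_(comp)          # step to apply, plus carried error
+     new_param = param + update
+     comp.copy_(update - (new_param - param))   # low-order bits lost by the add
+     param.copy_(new_param)
+
+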
+ class Adam(Optimizer):
+     """
+     EFT-powered Adam optimizer.
+
+     Drop-in replacement for torch.optim.Adam with precise moment accumulation.
+
+     The key insight: Adam's moment estimates (m and v) accumulate small updates
+     over millions of steps. Standard FP32 loses precision. EFT tracks the
+     exact error, ensuring stable training.
+
+     Args:
+         params: Iterable of parameters to optimize
+         lr: Learning rate (default: 1e-3)
+         betas: Coefficients for computing running averages (default: (0.9, 0.999))
+         eps: Term added to denominator for numerical stability (default: 1e-8)
+         weight_decay: Weight decay (L2 penalty) (default: 0)
+         amsgrad: Whether to use AMSGrad variant (default: False)
+
+     Example:
+         >>> from quarterbit.torch import Adam
+         >>> optimizer = Adam(model.parameters(), lr=1e-3)
+     """
+
+     def __init__(
+         self,
+         params: Iterable[torch.nn.Parameter],
+         lr: float = 1e-3,
+         betas: tuple = (0.9, 0.999),
+         eps: float = 1e-8,
+         weight_decay: float = 0,
+         amsgrad: bool = False
+     ):
+         if lr < 0.0:
+             raise ValueError(f"Invalid learning rate: {lr}")
+         if eps < 0.0:
+             raise ValueError(f"Invalid epsilon value: {eps}")
+         if not 0.0 <= betas[0] < 1.0:
+             raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
+         if not 0.0 <= betas[1] < 1.0:
+             raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
+         if weight_decay < 0.0:
+             raise ValueError(f"Invalid weight_decay value: {weight_decay}")
+
+         defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad)
+         super().__init__(params, defaults)
+
+         self._use_eft = is_available()
+         if self._use_eft:
+             self._lib = get_lib()
+
+     def _init_state(self, p):
+         """Initialize state for a parameter."""
+         state = self.state[p]
+         state['step'] = 0
+         state['exp_avg'] = torch.zeros_like(p.data)
+         state['exp_avg_sq'] = torch.zeros_like(p.data)
+
+         # EFT compensation buffers
+         state['weight_comp'] = torch.zeros_like(p.data)
+         state['exp_avg_comp'] = torch.zeros_like(p.data)
+         state['exp_avg_sq_comp'] = torch.zeros_like(p.data)
+
+         return state
+
+     @torch.no_grad()
+     def step(self, closure: Optional[Callable] = None):
+         """Performs a single optimization step."""
+         loss = None
+         if closure is not None:
+             with torch.enable_grad():
+                 loss = closure()
+
+         for group in self.param_groups:
+             beta1, beta2 = group['betas']
+             lr = group['lr']
+             eps = group['eps']
+             weight_decay = group['weight_decay']
+             amsgrad = group['amsgrad']
+
+             for p in group['params']:
+                 if p.grad is None:
+                     continue
+
+                 grad = p.grad.data
+                 if grad.is_sparse:
+                     raise RuntimeError("Adam does not support sparse gradients")
+
+                 # Initialize state if needed
+                 state = self.state[p]
+                 if len(state) == 0:
+                     state = self._init_state(p)
+
+                 state['step'] += 1
+                 step = state['step']
+
+                 if self._use_eft and p.is_cuda and p.dtype == torch.float32 and not amsgrad:
+                     # Use EFT kernel
+                     n = p.numel()
+
+                     self._lib.eft_adam_step(
+                         _ptr(p.data),
+                         _ptr(state['weight_comp']),
+                         _ptr(grad),
+                         _ptr(state['exp_avg']),
+                         _ptr(state['exp_avg_sq']),
+                         _ptr(state['exp_avg_comp']),
+                         _ptr(state['exp_avg_sq_comp']),
+                         ctypes.c_float(lr),
+                         ctypes.c_float(beta1),
+                         ctypes.c_float(beta2),
+                         ctypes.c_float(eps),
+                         ctypes.c_float(weight_decay),
+                         ctypes.c_int(step),
+                         ctypes.c_int(n)
+                     )
+                 else:
+                     # Fallback to standard PyTorch Adam
+                     self._standard_step(p, grad, state, group)
+
+         return loss
+
+     def _standard_step(self, p, grad, state, group):
+         """Standard PyTorch Adam step (fallback)."""
+         beta1, beta2 = group['betas']
+         lr = group['lr']
+         eps = group['eps']
+         weight_decay = group['weight_decay']
+         amsgrad = group['amsgrad']
+         step = state['step']
+
+         exp_avg = state['exp_avg']
+         exp_avg_sq = state['exp_avg_sq']
+
+         if weight_decay != 0:
+             grad = grad.add(p.data, alpha=weight_decay)
+
+         # Decay the first and second moment running average coefficient
+         exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
+         exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
+
+         bias_correction1 = 1 - beta1 ** step
+         bias_correction2 = 1 - beta2 ** step
+
+         if amsgrad:
+             max_exp_avg_sq = state.get('max_exp_avg_sq')
+             if max_exp_avg_sq is None:
+                 max_exp_avg_sq = torch.zeros_like(p.data)
+                 state['max_exp_avg_sq'] = max_exp_avg_sq
+             torch.maximum(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
+             denom = (max_exp_avg_sq.sqrt() / (bias_correction2 ** 0.5)).add_(eps)
+         else:
+             denom = (exp_avg_sq.sqrt() / (bias_correction2 ** 0.5)).add_(eps)
+
+         step_size = lr / bias_correction1
+         p.data.addcdiv_(exp_avg, denom, value=-step_size)
+
+
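+ # Illustrative sketch, not part of the published package: the error-free
+ # transformation (EFT) primitive the Adam docstring refers to. TwoSum turns a
+ # rounded addition into a (sum, error) pair so that a + b == s + e holds
+ # exactly; the CUDA kernels are assumed to carry the e terms in the *_comp
+ # buffers above. For example, _two_sum_reference(1e16, 1.0) returns
+ # (1e16, 1.0): the 1.0 that a plain float addition would drop is kept in e.
+ def _two_sum_reference(a, b):
+     """Knuth's TwoSum on Python floats (IEEE-754 binary64 here)."""
+     s = a + b
+     bb = s - a
+     e = (a - (s - bb)) + (b - bb)
+     return s, e
+
+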
+ class AdamW(Adam):
+     """
+     EFT-powered AdamW optimizer (Adam with decoupled weight decay).
+
+     Same as Adam, except weight decay is applied directly to the weights
+     (decoupled) rather than added to the gradients as an L2 penalty.
+
+     Example:
+         >>> from quarterbit.torch import AdamW
+         >>> optimizer = AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)
+     """
+
+     def __init__(
+         self,
+         params: Iterable[torch.nn.Parameter],
+         lr: float = 1e-3,
+         betas: tuple = (0.9, 0.999),
+         eps: float = 1e-8,
+         weight_decay: float = 1e-2,
+         amsgrad: bool = False
+     ):
+         super().__init__(params, lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad)
+
+     def _standard_step(self, p, grad, state, group):
+         """AdamW step - weight decay applied directly to weights."""
+         beta1, beta2 = group['betas']
+         lr = group['lr']
+         eps = group['eps']
+         weight_decay = group['weight_decay']
+         amsgrad = group['amsgrad']
+         step = state['step']
+
+         exp_avg = state['exp_avg']
+         exp_avg_sq = state['exp_avg_sq']
+
+         # Decoupled weight decay (applied to weights, not gradients)
+         if weight_decay != 0:
+             p.data.mul_(1 - lr * weight_decay)
+
+         # Standard Adam update
+         exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
+         exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
+
+         bias_correction1 = 1 - beta1 ** step
+         bias_correction2 = 1 - beta2 ** step
+
+         if amsgrad:
+             max_exp_avg_sq = state.get('max_exp_avg_sq')
+             if max_exp_avg_sq is None:
+                 max_exp_avg_sq = torch.zeros_like(p.data)
+                 state['max_exp_avg_sq'] = max_exp_avg_sq
+             torch.maximum(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
+             denom = (max_exp_avg_sq.sqrt() / (bias_correction2 ** 0.5)).add_(eps)
+         else:
+             denom = (exp_avg_sq.sqrt() / (bias_correction2 ** 0.5)).add_(eps)
+
+         step_size = lr / bias_correction1
+         p.data.addcdiv_(exp_avg, denom, value=-step_size)
+
+
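+ # A one-line comparison of the two weight-decay conventions used above
+ # (sketch; wd = weight_decay, g_hat = bias-corrected Adam direction):
+ #
+ #   Adam  (L2 penalty):       g <- g + wd * w;       w <- w - lr * g_hat   # decay is rescaled by 1/sqrt(v)
+ #   AdamW (decoupled decay):  w <- w * (1 - lr*wd);  w <- w - lr * g_hat   # decay hits the weights directly
+
+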
+ class CompactAdam(Optimizer):
+     """
+     Memory-efficient Adam optimizer with compressed state storage.
+
+     Uses FP16+FP4 split weights and FP4+INT2 compressed optimizer states
+     for up to ~1.7x memory savings vs standard FP32 Adam, while maintaining
+     training quality through error feedback.
+
+     Memory budget per parameter:
+     - Standard FP32 Adam: 16 bytes/param
+     - CompactAdam (no EF): 9.25 bytes/param (1.7x savings)
+     - CompactAdam (with EF): 10.75 bytes/param (1.5x savings)
+
+     Args:
+         params: Iterable of parameters to optimize
+         lr: Learning rate (default: 1e-3)
+         betas: Coefficients for computing running averages (default: (0.9, 0.999))
+         eps: Term added to denominator for numerical stability (default: 1e-8)
+         weight_decay: Weight decay (L2 penalty) (default: 0)
+         use_error_feedback: Enable error feedback for better precision (default: True)
+
+     Example:
+         >>> from quarterbit.torch import CompactAdam
+         >>> optimizer = CompactAdam(model.parameters(), lr=1e-3)
+         >>> # For maximum memory savings (slight precision tradeoff):
+         >>> optimizer = CompactAdam(model.parameters(), lr=1e-3, use_error_feedback=False)
+     """
+
+     def __init__(
+         self,
+         params: Iterable[torch.nn.Parameter],
+         lr: float = 1e-3,
+         betas: tuple = (0.9, 0.999),
+         eps: float = 1e-8,
+         weight_decay: float = 0,
+         use_error_feedback: bool = True
+     ):
+         if lr < 0.0:
+             raise ValueError(f"Invalid learning rate: {lr}")
+         if eps < 0.0:
+             raise ValueError(f"Invalid epsilon value: {eps}")
+         if not 0.0 <= betas[0] < 1.0:
+             raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
+         if not 0.0 <= betas[1] < 1.0:
+             raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
+         if weight_decay < 0.0:
+             raise ValueError(f"Invalid weight_decay value: {weight_decay}")
+
+         defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
+         super().__init__(params, defaults)
+
+         self._use_ef = 1 if use_error_feedback else 0
+         self._handles = {}  # Per-parameter handles
+         self._step_count = 0
+
+         # Try to load compact library
+         self._lib = None
+         self._use_compact = False
+         try:
+             from .utils import get_lib
+             lib = get_lib()
+             # Check if compact functions exist
+             if hasattr(lib, 'compact_adam_create'):
+                 self._lib = lib
+                 self._use_compact = True
+         except Exception:
+             pass
+
+     def _get_handle(self, p):
+         """Get or create compact adam handle for parameter."""
+         if id(p) not in self._handles:
+             if self._use_compact and p.is_cuda and p.dtype == torch.float32:
+                 n = p.numel()
+                 handle = self._lib.compact_adam_create(ctypes.c_int(n), ctypes.c_int(self._use_ef))
+                 if handle:
+                     self._handles[id(p)] = handle
+                 else:
+                     self._handles[id(p)] = None
+             else:
+                 self._handles[id(p)] = None
+         return self._handles[id(p)]
+
+     @torch.no_grad()
+     def step(self, closure: Optional[Callable] = None):
+         """Performs a single optimization step."""
+         loss = None
+         if closure is not None:
+             with torch.enable_grad():
+                 loss = closure()
+
+         self._step_count += 1
+
+         for group in self.param_groups:
+             beta1, beta2 = group['betas']
+             lr = group['lr']
+             eps = group['eps']
+             weight_decay = group['weight_decay']
+
+             for p in group['params']:
+                 if p.grad is None:
+                     continue
+
+                 grad = p.grad.data
+                 if grad.is_sparse:
+                     raise RuntimeError("CompactAdam does not support sparse gradients")
+
+                 handle = self._get_handle(p)
+
+                 if handle is not None and self._use_compact:
+                     # Use compact CUDA kernel
+                     n = p.numel()
+
+                     # Apply weight decay to gradients (L2 regularization)
+                     if weight_decay != 0:
+                         grad = grad.add(p.data, alpha=weight_decay)
+
+                     self._lib.compact_adam_step(
+                         handle,
+                         _ptr(p.data),
+                         _ptr(grad),
+                         ctypes.c_float(lr),
+                         ctypes.c_float(beta1),
+                         ctypes.c_float(beta2),
+                         ctypes.c_float(eps),
+                         ctypes.c_int(self._step_count)
+                     )
+                 else:
+                     # Fallback to standard Adam
+                     self._standard_step(p, grad, group)
+
+         return loss
+
+     def _standard_step(self, p, grad, group):
+         """Standard PyTorch Adam step (fallback for CPU or non-float32)."""
+         beta1, beta2 = group['betas']
+         lr = group['lr']
+         eps = group['eps']
+         weight_decay = group['weight_decay']
+
+         state = self.state[p]
+         if len(state) == 0:
+             state['step'] = 0
+             state['exp_avg'] = torch.zeros_like(p.data)
+             state['exp_avg_sq'] = torch.zeros_like(p.data)
+
+         state['step'] += 1
+         step = state['step']
+
+         exp_avg = state['exp_avg']
+         exp_avg_sq = state['exp_avg_sq']
+
+         if weight_decay != 0:
+             grad = grad.add(p.data, alpha=weight_decay)
+
+         exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
+         exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
+
+         bias_correction1 = 1 - beta1 ** step
+         bias_correction2 = 1 - beta2 ** step
+         denom = (exp_avg_sq.sqrt() / (bias_correction2 ** 0.5)).add_(eps)
+
+         step_size = lr / bias_correction1
+         p.data.addcdiv_(exp_avg, denom, value=-step_size)
+
+     def __del__(self):
+         """Clean up handles on deletion."""
+         if hasattr(self, '_handles') and hasattr(self, '_lib') and self._lib is not None:
+             for handle in self._handles.values():
+                 if handle is not None:
+                     try:
+                         self._lib.compact_adam_destroy(handle)
+                     except Exception:
+                         pass
+
+     def memory_savings(self):
+         """Report memory savings vs standard FP32 Adam."""
+         total_params = sum(p.numel() for group in self.param_groups for p in group['params'])
+         fp32_bytes = total_params * 16  # 4 bytes each for param, grad, m, v
+         compact_bytes = total_params * (10.75 if self._use_ef else 9.25)
+         savings = 1 - compact_bytes / fp32_bytes
+         return {
+             'total_params': total_params,
+             'fp32_bytes': fp32_bytes,
+             'compact_bytes': compact_bytes,
+             'savings_ratio': fp32_bytes / compact_bytes,
+             'savings_percent': savings * 100
+         }
+
+
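+ # Worked example of the memory budget documented above (numbers assume the
+ # per-parameter byte counts quoted in the CompactAdam docstring), for a
+ # 1e9-parameter model:
+ #
+ #   FP32 Adam:                  1e9 * 16.00 B = 16.00 GB
+ #   CompactAdam (EF enabled):   1e9 * 10.75 B = 10.75 GB   (16 / 10.75 ~ 1.49x)
+ #   CompactAdam (EF disabled):  1e9 *  9.25 B =  9.25 GB   (16 /  9.25 ~ 1.73x)
+ #
+ # The same figures come out of CompactAdam.memory_savings(), e.g.
+ # optimizer.memory_savings()['savings_ratio'].
+
+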
+ class CompactEFTAdam(Optimizer):
+     """
+     Production Adam optimizer: Memory-efficient + EFT precision.
+
+     Combines the best of both worlds:
+     - Compressed storage: FP16+FP4 weights, FP4+INT2 states (~17% overall memory savings)
+     - EFT precision: TwoSum/TwoProduct for exact arithmetic in Adam updates
+
+     Memory budget per parameter:
+     - Standard FP32 Adam: 16 bytes/param
+     - CompactEFTAdam: ~13.25 bytes/param (~17% savings with full precision)
+
+     This is the RECOMMENDED optimizer for production training.
+
+     Args:
+         params: Iterable of parameters to optimize
+         lr: Learning rate (default: 1e-3)
+         betas: Coefficients for computing running averages (default: (0.9, 0.999))
+         eps: Term added to denominator for numerical stability (default: 1e-8)
+         weight_decay: Weight decay (L2 penalty) (default: 0)
+
+     Example:
+         >>> from quarterbit.torch import CompactEFTAdam
+         >>> optimizer = CompactEFTAdam(model.parameters(), lr=1e-3)
+     """
+
+     def __init__(
+         self,
+         params: Iterable[torch.nn.Parameter],
+         lr: float = 1e-3,
+         betas: tuple = (0.9, 0.999),
+         eps: float = 1e-8,
+         weight_decay: float = 0,
+     ):
+         if lr < 0.0:
+             raise ValueError(f"Invalid learning rate: {lr}")
+         if eps < 0.0:
+             raise ValueError(f"Invalid epsilon value: {eps}")
+         if not 0.0 <= betas[0] < 1.0:
+             raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
+         if not 0.0 <= betas[1] < 1.0:
+             raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
+         if weight_decay < 0.0:
+             raise ValueError(f"Invalid weight_decay value: {weight_decay}")
+
+         defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
+         super().__init__(params, defaults)
+
+         self._handles = {}
+         self._step_count = 0
+
+         # Try to load CompactEFT library
+         self._lib = None
+         self._use_compact_eft = False
+         try:
+             from .utils import _load_compact_eft_lib
+             self._lib = _load_compact_eft_lib()
+             if self._lib is not None:
+                 self._use_compact_eft = True
+         except Exception:
+             pass
+
+     def _get_handle(self, p):
+         """Get or create CompactEFT handle for parameter."""
+         if id(p) not in self._handles:
+             if self._use_compact_eft and p.is_cuda and p.dtype == torch.float32:
+                 handle = self._lib.compact_eft_adam_create(ctypes.c_int(p.numel()))
+                 self._handles[id(p)] = handle if handle else None
+             else:
+                 self._handles[id(p)] = None
+         return self._handles[id(p)]
+
+     @torch.no_grad()
+     def step(self, closure: Optional[Callable] = None):
+         """Performs a single optimization step."""
+         loss = None
+         if closure is not None:
+             with torch.enable_grad():
+                 loss = closure()
+
+         self._step_count += 1
+
+         for group in self.param_groups:
+             beta1, beta2 = group['betas']
+             lr = group['lr']
+             eps = group['eps']
+             weight_decay = group['weight_decay']
+
+             for p in group['params']:
+                 if p.grad is None:
+                     continue
+
+                 grad = p.grad.data
+                 if grad.is_sparse:
+                     raise RuntimeError("CompactEFTAdam does not support sparse gradients")
+
+                 handle = self._get_handle(p)
+
+                 if handle is not None and self._use_compact_eft:
+                     # Apply weight decay to gradients (L2 regularization)
+                     if weight_decay != 0:
+                         grad = grad.add(p.data, alpha=weight_decay)
+
+                     self._lib.compact_eft_adam_step(
+                         handle,
+                         _ptr(p.data),
+                         _ptr(grad),
+                         ctypes.c_float(lr),
+                         ctypes.c_float(beta1),
+                         ctypes.c_float(beta2),
+                         ctypes.c_float(eps),
+                         ctypes.c_int(self._step_count)
+                     )
+                 else:
+                     # Fallback to standard Adam
+                     self._standard_step(p, grad, group)
+
+         return loss
+
+     def _standard_step(self, p, grad, group):
+         """Standard PyTorch Adam step (fallback for CPU or non-float32)."""
+         beta1, beta2 = group['betas']
+         lr = group['lr']
+         eps = group['eps']
+         weight_decay = group['weight_decay']
+
+         state = self.state[p]
+         if len(state) == 0:
+             state['step'] = 0
+             state['exp_avg'] = torch.zeros_like(p.data)
+             state['exp_avg_sq'] = torch.zeros_like(p.data)
+
+         state['step'] += 1
+         step = state['step']
+
+         exp_avg = state['exp_avg']
+         exp_avg_sq = state['exp_avg_sq']
+
+         if weight_decay != 0:
+             grad = grad.add(p.data, alpha=weight_decay)
+
+         exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
+         exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
+
+         bias_correction1 = 1 - beta1 ** step
+         bias_correction2 = 1 - beta2 ** step
+         denom = (exp_avg_sq.sqrt() / (bias_correction2 ** 0.5)).add_(eps)
+
+         step_size = lr / bias_correction1
+         p.data.addcdiv_(exp_avg, denom, value=-step_size)
+
+     def __del__(self):
+         """Clean up handles on deletion."""
+         if hasattr(self, '_handles') and hasattr(self, '_lib') and self._lib is not None:
+             for handle in self._handles.values():
+                 if handle is not None:
+                     try:
+                         self._lib.compact_eft_adam_destroy(handle)
+                     except Exception:
+                         pass
+
+     def memory_savings(self):
+         """Report memory savings vs standard FP32 Adam."""
+         total_params = sum(p.numel() for group in self.param_groups for p in group['params'])
+         fp32_bytes = total_params * 16
+         # Weights: 2.75, states: 2.5, EFT comp: 4, grads: 4 = 13.25 B/param
+         compact_eft_bytes = total_params * 13.25
+         savings = 1 - compact_eft_bytes / fp32_bytes
+         return {
+             'total_params': total_params,
+             'fp32_bytes': fp32_bytes,
+             'compact_eft_bytes': compact_eft_bytes,
+             'savings_ratio': fp32_bytes / compact_eft_bytes,
+             'savings_percent': savings * 100
+         }
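+
+
+ # Usage sketch (illustrative; assumes a CUDA float32 model. On CPU, or when
+ # the native library is unavailable, every optimizer above falls back to the
+ # standard PyTorch math):
+ #
+ #     from quarterbit.torch import CompactEFTAdam
+ #
+ #     optimizer = CompactEFTAdam(model.parameters(), lr=1e-3)
+ #     for batch in loader:
+ #         loss = compute_loss(model, batch)   # compute_loss/loader/model are user-defined
+ #         loss.backward()
+ #         optimizer.step()
+ #         optimizer.zero_grad()
+ #     print(optimizer.memory_savings())       # report projected memory savings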