quarterbit 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quarterbit/__init__.py +94 -0
- quarterbit/torch/__init__.py +22 -0
- quarterbit/torch/functional.py +229 -0
- quarterbit/torch/optim.py +728 -0
- quarterbit/torch/utils.py +239 -0
- quarterbit-0.1.0.dist-info/METADATA +122 -0
- quarterbit-0.1.0.dist-info/RECORD +10 -0
- quarterbit-0.1.0.dist-info/WHEEL +5 -0
- quarterbit-0.1.0.dist-info/licenses/LICENSE +53 -0
- quarterbit-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,728 @@
+"""
+QuarterBit Optimizers
+=====================
+
+Precision optimizers for PyTorch.
+Copyright (c) 2026 Clouthier Simulation Labs. All rights reserved.
+"""
+
+import torch
+from torch.optim import Optimizer
+from typing import List, Optional, Callable, Iterable
+import ctypes
+from .utils import get_lib, is_available
+
+def _ptr(tensor):
+    """Get ctypes pointer to tensor data."""
+    if tensor is None:
+        return None
+    return ctypes.cast(tensor.data_ptr(), ctypes.POINTER(ctypes.c_float))
+
+
+class SGD(Optimizer):
+    """
+    EFT-powered Stochastic Gradient Descent.
+
+    Drop-in replacement for torch.optim.SGD with precise gradient accumulation.
+
+    Args:
+        params: Iterable of parameters to optimize
+        lr: Learning rate
+        momentum: Momentum factor (default: 0)
+        dampening: Dampening for momentum (default: 0)
+        weight_decay: Weight decay (L2 penalty) (default: 0)
+        nesterov: Enables Nesterov momentum (default: False)
+
+    Example:
+        >>> from quarterbit.torch import SGD
+        >>> optimizer = SGD(model.parameters(), lr=0.01, momentum=0.9)
+    """
+
+    def __init__(
+        self,
+        params: Iterable[torch.nn.Parameter],
+        lr: float = 1e-3,
+        momentum: float = 0,
+        dampening: float = 0,
+        weight_decay: float = 0,
+        nesterov: bool = False
+    ):
+        if lr < 0.0:
+            raise ValueError(f"Invalid learning rate: {lr}")
+        if momentum < 0.0:
+            raise ValueError(f"Invalid momentum value: {momentum}")
+        if weight_decay < 0.0:
+            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
+        if nesterov and (momentum <= 0 or dampening != 0):
+            raise ValueError("Nesterov momentum requires a momentum and zero dampening")
+
+        defaults = dict(
+            lr=lr, momentum=momentum, dampening=dampening,
+            weight_decay=weight_decay, nesterov=nesterov
+        )
+        super().__init__(params, defaults)
+
+        self._use_eft = is_available()
+        if self._use_eft:
+            self._lib = get_lib()
+
+        # Initialize compensation buffers
+        for group in self.param_groups:
+            for p in group['params']:
+                state = self.state[p]
+                state['weight_comp'] = torch.zeros_like(p.data)
+                if momentum != 0:
+                    state['momentum_buffer'] = torch.zeros_like(p.data)
+                    state['momentum_comp'] = torch.zeros_like(p.data)
+
+    @torch.no_grad()
+    def step(self, closure: Optional[Callable] = None):
+        """Performs a single optimization step."""
+        loss = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        for group in self.param_groups:
+            lr = group['lr']
+            momentum = group['momentum']
+            dampening = group['dampening']
+            weight_decay = group['weight_decay']
+            nesterov = group['nesterov']
+
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+
+                grad = p.grad.data
+                state = self.state[p]
+
+                if self._use_eft and p.is_cuda and p.dtype == torch.float32:
+                    # Use EFT kernel
+                    n = p.numel()
+
+                    if momentum != 0:
+                        # TODO: Implement momentum with EFT
+                        # For now, fall back to standard for momentum
+                        self._standard_step(p, grad, state, group)
+                    else:
+                        # Pure SGD with EFT
+                        self._lib.eft_sgd_step(
+                            _ptr(p.data),
+                            _ptr(state['weight_comp']),
+                            _ptr(grad),
+                            ctypes.c_float(lr),
+                            ctypes.c_float(weight_decay),
+                            ctypes.c_int(n)
+                        )
+                else:
+                    # Fallback to standard PyTorch
+                    self._standard_step(p, grad, state, group)
+
+        return loss
+
+    def _standard_step(self, p, grad, state, group):
+        """Standard PyTorch SGD step (fallback)."""
+        weight_decay = group['weight_decay']
+        momentum = group['momentum']
+        dampening = group['dampening']
+        nesterov = group['nesterov']
+        lr = group['lr']
+
+        if weight_decay != 0:
+            grad = grad.add(p.data, alpha=weight_decay)
+
+        if momentum != 0:
+            buf = state.get('momentum_buffer')
+            if buf is None:
+                buf = torch.clone(grad).detach()
+                state['momentum_buffer'] = buf
+            else:
+                buf.mul_(momentum).add_(grad, alpha=1 - dampening)
+
+            if nesterov:
+                grad = grad.add(buf, alpha=momentum)
+            else:
+                grad = buf
+
+        p.data.add_(grad, alpha=-lr)
+
+
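For reference, the compensated update that `eft_sgd_step` presumably performs on the GPU can be sketched in plain PyTorch. The wheel ships only this Python wrapper, so the helper below is an illustrative assumption rather than the shipped kernel: it applies the same update as `_standard_step` (p <- p - lr*(grad + weight_decay*p)) but carries the rounding error of each addition forward in the `weight_comp` buffer, Kahan-style.

import torch

def compensated_sgd_sketch(p, weight_comp, grad, lr, weight_decay):
    # Hypothetical pure-PyTorch stand-in for the eft_sgd_step CUDA kernel.
    update = -lr * (grad + weight_decay * p)   # what plain SGD would add this step
    y = update - weight_comp                   # re-inject the error left over from the last step
    t = p + y                                  # low-order bits of y may be rounded away here...
    weight_comp.copy_((t - p) - y)             # ...so capture exactly what was lost
    p.copy_(t)

p = torch.ones(4)
comp = torch.zeros(4)
compensated_sgd_sketch(p, comp, grad=torch.full((4,), 1e-8), lr=0.1, weight_decay=0.0)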
+class Adam(Optimizer):
+    """
+    EFT-powered Adam optimizer.
+
+    Drop-in replacement for torch.optim.Adam with precise moment accumulation.
+
+    The key insight: Adam's moment estimates (m and v) accumulate small updates
+    over millions of steps. Standard FP32 loses precision. EFT tracks the
+    exact error, ensuring stable training.
+
+    Args:
+        params: Iterable of parameters to optimize
+        lr: Learning rate (default: 1e-3)
+        betas: Coefficients for computing running averages (default: (0.9, 0.999))
+        eps: Term added to denominator for numerical stability (default: 1e-8)
+        weight_decay: Weight decay (L2 penalty) (default: 0)
+        amsgrad: Whether to use AMSGrad variant (default: False)
+
+    Example:
+        >>> from quarterbit.torch import Adam
+        >>> optimizer = Adam(model.parameters(), lr=1e-3)
+    """
+
+    def __init__(
+        self,
+        params: Iterable[torch.nn.Parameter],
+        lr: float = 1e-3,
+        betas: tuple = (0.9, 0.999),
+        eps: float = 1e-8,
+        weight_decay: float = 0,
+        amsgrad: bool = False
+    ):
+        if lr < 0.0:
+            raise ValueError(f"Invalid learning rate: {lr}")
+        if eps < 0.0:
+            raise ValueError(f"Invalid epsilon value: {eps}")
+        if not 0.0 <= betas[0] < 1.0:
+            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
+        if not 0.0 <= betas[1] < 1.0:
+            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
+        if weight_decay < 0.0:
+            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
+
+        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad)
+        super().__init__(params, defaults)
+
+        self._use_eft = is_available()
+        if self._use_eft:
+            self._lib = get_lib()
+
+    def _init_state(self, p):
+        """Initialize state for a parameter."""
+        state = self.state[p]
+        state['step'] = 0
+        state['exp_avg'] = torch.zeros_like(p.data)
+        state['exp_avg_sq'] = torch.zeros_like(p.data)
+
+        # EFT compensation buffers
+        state['weight_comp'] = torch.zeros_like(p.data)
+        state['exp_avg_comp'] = torch.zeros_like(p.data)
+        state['exp_avg_sq_comp'] = torch.zeros_like(p.data)
+
+        return state
+
+    @torch.no_grad()
+    def step(self, closure: Optional[Callable] = None):
+        """Performs a single optimization step."""
+        loss = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        for group in self.param_groups:
+            beta1, beta2 = group['betas']
+            lr = group['lr']
+            eps = group['eps']
+            weight_decay = group['weight_decay']
+            amsgrad = group['amsgrad']
+
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+
+                grad = p.grad.data
+                if grad.is_sparse:
+                    raise RuntimeError("Adam does not support sparse gradients")
+
+                # Initialize state if needed
+                state = self.state[p]
+                if len(state) == 0:
+                    state = self._init_state(p)
+
+                state['step'] += 1
+                step = state['step']
+
+                if self._use_eft and p.is_cuda and p.dtype == torch.float32 and not amsgrad:
+                    # Use EFT kernel
+                    n = p.numel()
+
+                    self._lib.eft_adam_step(
+                        _ptr(p.data),
+                        _ptr(state['weight_comp']),
+                        _ptr(grad),
+                        _ptr(state['exp_avg']),
+                        _ptr(state['exp_avg_sq']),
+                        _ptr(state['exp_avg_comp']),
+                        _ptr(state['exp_avg_sq_comp']),
+                        ctypes.c_float(lr),
+                        ctypes.c_float(beta1),
+                        ctypes.c_float(beta2),
+                        ctypes.c_float(eps),
+                        ctypes.c_float(weight_decay),
+                        ctypes.c_int(step),
+                        ctypes.c_int(n)
+                    )
+                else:
+                    # Fallback to standard PyTorch Adam
+                    self._standard_step(p, grad, state, group)
+
+        return loss
+
+    def _standard_step(self, p, grad, state, group):
+        """Standard PyTorch Adam step (fallback)."""
+        beta1, beta2 = group['betas']
+        lr = group['lr']
+        eps = group['eps']
+        weight_decay = group['weight_decay']
+        amsgrad = group['amsgrad']
+        step = state['step']
+
+        exp_avg = state['exp_avg']
+        exp_avg_sq = state['exp_avg_sq']
+
+        if weight_decay != 0:
+            grad = grad.add(p.data, alpha=weight_decay)
+
+        # Decay the first and second moment running average coefficient
+        exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
+        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
+
+        if amsgrad:
+            max_exp_avg_sq = state.get('max_exp_avg_sq')
+            if max_exp_avg_sq is None:
+                max_exp_avg_sq = torch.zeros_like(p.data)
+                state['max_exp_avg_sq'] = max_exp_avg_sq
+            torch.maximum(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
+            denom = (max_exp_avg_sq.sqrt() / (1 - beta2 ** step) ** 0.5).add_(eps)
+        else:
+            bias_correction1 = 1 - beta1 ** step
+            bias_correction2 = 1 - beta2 ** step
+            denom = (exp_avg_sq.sqrt() / (bias_correction2 ** 0.5)).add_(eps)
+
+        step_size = lr / (1 - beta1 ** step)
+        p.data.addcdiv_(exp_avg, denom, value=-step_size)
+
+
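The docstring above appeals to error-free transformations; the CUDA kernels themselves are not part of this wheel, but the primitive they are built on is the standard Knuth TwoSum, shown here for reference (plain Python; it also works elementwise on tensors):

def two_sum(a, b):
    # Error-free addition: s is the rounded sum and e the exact rounding error,
    # so a + b == s + e holds exactly in floating-point arithmetic.
    s = a + b
    b_virtual = s - a
    e = (a - (s - b_virtual)) + (b - b_virtual)
    return s, e

s, e = two_sum(1.0, 1e-17)
# s == 1.0, e == 1e-17: the low-order part an ordinary add silently drops,
# which is the quantity the *_comp buffers above presumably carry between steps.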
+class AdamW(Adam):
+    """
+    EFT-powered AdamW optimizer (Adam with decoupled weight decay).
+
+    Same as Adam but with proper weight decay (not L2 regularization).
+
+    Example:
+        >>> from quarterbit.torch import AdamW
+        >>> optimizer = AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)
+    """
+
+    def __init__(
+        self,
+        params: Iterable[torch.nn.Parameter],
+        lr: float = 1e-3,
+        betas: tuple = (0.9, 0.999),
+        eps: float = 1e-8,
+        weight_decay: float = 1e-2,
+        amsgrad: bool = False
+    ):
+        super().__init__(params, lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad)
+
+    def _standard_step(self, p, grad, state, group):
+        """AdamW step - weight decay applied directly to weights."""
+        beta1, beta2 = group['betas']
+        lr = group['lr']
+        eps = group['eps']
+        weight_decay = group['weight_decay']
+        amsgrad = group['amsgrad']
+        step = state['step']
+
+        exp_avg = state['exp_avg']
+        exp_avg_sq = state['exp_avg_sq']
+
+        # Decoupled weight decay (applied to weights, not gradients)
+        if weight_decay != 0:
+            p.data.mul_(1 - lr * weight_decay)
+
+        # Standard Adam update
+        exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
+        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
+
+        if amsgrad:
+            max_exp_avg_sq = state.get('max_exp_avg_sq')
+            if max_exp_avg_sq is None:
+                max_exp_avg_sq = torch.zeros_like(p.data)
+                state['max_exp_avg_sq'] = max_exp_avg_sq
+            torch.maximum(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
+            denom = (max_exp_avg_sq.sqrt() / (1 - beta2 ** step) ** 0.5).add_(eps)
+        else:
+            bias_correction1 = 1 - beta1 ** step
+            bias_correction2 = 1 - beta2 ** step
+            denom = (exp_avg_sq.sqrt() / (bias_correction2 ** 0.5)).add_(eps)
+
+        step_size = lr / (1 - beta1 ** step)
+        p.data.addcdiv_(exp_avg, denom, value=-step_size)
+
+
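The distinction the AdamW docstring draws ("proper weight decay, not L2 regularization") is the difference between the two fallback paths above; reduced to their update rules (an illustration, not package code):

def coupled_l2(grad, p, weight_decay):
    # Adam fallback: decay enters through the gradient, so it is also rescaled
    # by the adaptive denominator like any other gradient component.
    return grad + weight_decay * p

def decoupled_decay(p, lr, weight_decay):
    # AdamW fallback: the weights shrink by a fixed fraction each step,
    # independent of the gradient statistics.
    return p * (1 - lr * weight_decay)

decoupled_decay(1.0, lr=1e-3, weight_decay=0.01)   # -> 0.99999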
+class CompactAdam(Optimizer):
+    """
+    Memory-efficient Adam optimizer with compressed state storage.
+
+    Uses FP16+FP4 split weights and FP4+INT2 compressed optimizer states
+    for ~1.7x memory savings vs standard FP32 Adam, while maintaining
+    training quality through error feedback.
+
+    Memory budget per parameter:
+    - Standard FP32 Adam: 16 bytes/param
+    - CompactAdam (no EF): 9.25 bytes/param (1.7x savings)
+    - CompactAdam (with EF): 10.75 bytes/param (1.5x savings)
+
+    Args:
+        params: Iterable of parameters to optimize
+        lr: Learning rate (default: 1e-3)
+        betas: Coefficients for computing running averages (default: (0.9, 0.999))
+        eps: Term added to denominator for numerical stability (default: 1e-8)
+        weight_decay: Weight decay (L2 penalty) (default: 0)
+        use_error_feedback: Enable error feedback for better precision (default: True)
+
+    Example:
+        >>> from quarterbit.torch import CompactAdam
+        >>> optimizer = CompactAdam(model.parameters(), lr=1e-3)
+        >>> # For maximum memory savings (slight precision tradeoff):
+        >>> optimizer = CompactAdam(model.parameters(), lr=1e-3, use_error_feedback=False)
+    """
+
+    def __init__(
+        self,
+        params: Iterable[torch.nn.Parameter],
+        lr: float = 1e-3,
+        betas: tuple = (0.9, 0.999),
+        eps: float = 1e-8,
+        weight_decay: float = 0,
+        use_error_feedback: bool = True
+    ):
+        if lr < 0.0:
+            raise ValueError(f"Invalid learning rate: {lr}")
+        if eps < 0.0:
+            raise ValueError(f"Invalid epsilon value: {eps}")
+        if not 0.0 <= betas[0] < 1.0:
+            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
+        if not 0.0 <= betas[1] < 1.0:
+            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
+        if weight_decay < 0.0:
+            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
+
+        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
+        super().__init__(params, defaults)
+
+        self._use_ef = 1 if use_error_feedback else 0
+        self._handles = {}  # Per-parameter handles
+        self._step_count = 0
+
+        # Try to load compact library
+        self._lib = None
+        self._use_compact = False
+        try:
+            from .utils import get_lib
+            lib = get_lib()
+            # Check if compact functions exist
+            if hasattr(lib, 'compact_adam_create'):
+                self._lib = lib
+                self._use_compact = True
+        except Exception:
+            pass
+
+    def _get_handle(self, p):
+        """Get or create compact adam handle for parameter."""
+        if id(p) not in self._handles:
+            if self._use_compact and p.is_cuda and p.dtype == torch.float32:
+                n = p.numel()
+                handle = self._lib.compact_adam_create(ctypes.c_int(n), ctypes.c_int(self._use_ef))
+                if handle:
+                    self._handles[id(p)] = handle
+                else:
+                    self._handles[id(p)] = None
+            else:
+                self._handles[id(p)] = None
+        return self._handles[id(p)]
+
+    @torch.no_grad()
+    def step(self, closure: Optional[Callable] = None):
+        """Performs a single optimization step."""
+        loss = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        self._step_count += 1
+
+        for group in self.param_groups:
+            beta1, beta2 = group['betas']
+            lr = group['lr']
+            eps = group['eps']
+            weight_decay = group['weight_decay']
+
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+
+                grad = p.grad.data
+                if grad.is_sparse:
+                    raise RuntimeError("CompactAdam does not support sparse gradients")
+
+                handle = self._get_handle(p)
+
+                if handle is not None and self._use_compact:
+                    # Use compact CUDA kernel
+                    n = p.numel()
+
+                    # Apply weight decay to gradients (L2 regularization)
+                    if weight_decay != 0:
+                        grad = grad.add(p.data, alpha=weight_decay)
+
+                    self._lib.compact_adam_step(
+                        handle,
+                        _ptr(p.data),
+                        _ptr(grad),
+                        ctypes.c_float(lr),
+                        ctypes.c_float(beta1),
+                        ctypes.c_float(beta2),
+                        ctypes.c_float(eps),
+                        ctypes.c_int(self._step_count)
+                    )
+                else:
+                    # Fallback to standard Adam
+                    self._standard_step(p, grad, group)
+
+        return loss
+
+    def _standard_step(self, p, grad, group):
+        """Standard PyTorch Adam step (fallback for CPU or non-float32)."""
+        beta1, beta2 = group['betas']
+        lr = group['lr']
+        eps = group['eps']
+        weight_decay = group['weight_decay']
+
+        state = self.state[p]
+        if len(state) == 0:
+            state['step'] = 0
+            state['exp_avg'] = torch.zeros_like(p.data)
+            state['exp_avg_sq'] = torch.zeros_like(p.data)
+
+        state['step'] += 1
+        step = state['step']
+
+        exp_avg = state['exp_avg']
+        exp_avg_sq = state['exp_avg_sq']
+
+        if weight_decay != 0:
+            grad = grad.add(p.data, alpha=weight_decay)
+
+        exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
+        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
+
+        bias_correction1 = 1 - beta1 ** step
+        bias_correction2 = 1 - beta2 ** step
+        denom = (exp_avg_sq.sqrt() / (bias_correction2 ** 0.5)).add_(eps)
+
+        step_size = lr / bias_correction1
+        p.data.addcdiv_(exp_avg, denom, value=-step_size)
+
+    def __del__(self):
+        """Clean up handles on deletion."""
+        if hasattr(self, '_handles') and hasattr(self, '_lib') and self._lib is not None:
+            for handle in self._handles.values():
+                if handle is not None:
+                    try:
+                        self._lib.compact_adam_destroy(handle)
+                    except Exception:
+                        pass
+
+    def memory_savings(self):
+        """Report memory savings vs standard FP32 Adam."""
+        total_params = sum(p.numel() for group in self.param_groups for p in group['params'])
+        fp32_bytes = total_params * 16  # 4 bytes each for param, grad, m, v
+        compact_bytes = total_params * (10.75 if self._use_ef else 9.25)
+        savings = 1 - compact_bytes / fp32_bytes
+        return {
+            'total_params': total_params,
+            'fp32_bytes': fp32_bytes,
+            'compact_bytes': compact_bytes,
+            'savings_ratio': fp32_bytes / compact_bytes,
+            'savings_percent': savings * 100
+        }
+
+
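To make the byte budget above concrete, this is the arithmetic `memory_savings()` performs, worked for a hypothetical 1B-parameter model (the per-parameter figures are the ones stated in the docstring):

n_params = 1_000_000_000              # hypothetical model size
fp32_adam = n_params * 16             # param + grad + m + v at 4 bytes each, about 16.0 GB
compact_no_ef = n_params * 9.25       # about 9.25 GB, 16 / 9.25 is roughly 1.73x smaller
compact_with_ef = n_params * 10.75    # about 10.75 GB, 16 / 10.75 is roughly 1.49x smaller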
+class CompactEFTAdam(Optimizer):
+    """
+    Production Adam optimizer: Memory-efficient + EFT precision.
+
+    Combines the best of both worlds:
+    - Compressed storage: FP16+FP4 weights, FP4+INT2 states (27% memory savings)
+    - EFT precision: TwoSum/TwoProduct for exact arithmetic in Adam updates
+
+    Memory budget per parameter:
+    - Standard FP32 Adam: 16 bytes/param
+    - CompactEFTAdam: ~13.25 bytes/param (17% savings with full precision)
+
+    This is the RECOMMENDED optimizer for production training.
+
+    Args:
+        params: Iterable of parameters to optimize
+        lr: Learning rate (default: 1e-3)
+        betas: Coefficients for computing running averages (default: (0.9, 0.999))
+        eps: Term added to denominator for numerical stability (default: 1e-8)
+        weight_decay: Weight decay (L2 penalty) (default: 0)
+
+    Example:
+        >>> from quarterbit.torch import CompactEFTAdam
+        >>> optimizer = CompactEFTAdam(model.parameters(), lr=1e-3)
+    """
+
+    def __init__(
+        self,
+        params: Iterable[torch.nn.Parameter],
+        lr: float = 1e-3,
+        betas: tuple = (0.9, 0.999),
+        eps: float = 1e-8,
+        weight_decay: float = 0,
+    ):
+        if lr < 0.0:
+            raise ValueError(f"Invalid learning rate: {lr}")
+        if eps < 0.0:
+            raise ValueError(f"Invalid epsilon value: {eps}")
+        if not 0.0 <= betas[0] < 1.0:
+            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
+        if not 0.0 <= betas[1] < 1.0:
+            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
+        if weight_decay < 0.0:
+            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
+
+        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
+        super().__init__(params, defaults)
+
+        self._handles = {}
+        self._step_count = 0
+
+        # Try to load CompactEFT library
+        self._lib = None
+        self._use_compact_eft = False
+        try:
+            from .utils import _load_compact_eft_lib
+            self._lib = _load_compact_eft_lib()
+            if self._lib is not None:
+                self._use_compact_eft = True
+        except Exception:
+            pass
+
+    def _get_handle(self, p):
+        """Get or create CompactEFT handle for parameter."""
+        if id(p) not in self._handles:
+            if self._use_compact_eft and p.is_cuda and p.dtype == torch.float32:
+                handle = self._lib.compact_eft_adam_create(ctypes.c_int(p.numel()))
+                self._handles[id(p)] = handle if handle else None
+            else:
+                self._handles[id(p)] = None
+        return self._handles[id(p)]
+
+    @torch.no_grad()
+    def step(self, closure: Optional[Callable] = None):
+        """Performs a single optimization step."""
+        loss = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        self._step_count += 1
+
+        for group in self.param_groups:
+            beta1, beta2 = group['betas']
+            lr = group['lr']
+            eps = group['eps']
+            weight_decay = group['weight_decay']
+
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+
+                grad = p.grad.data
+                if grad.is_sparse:
+                    raise RuntimeError("CompactEFTAdam does not support sparse gradients")
+
+                handle = self._get_handle(p)
+
+                if handle is not None and self._use_compact_eft:
+                    # Apply weight decay to gradients (L2 regularization)
+                    if weight_decay != 0:
+                        grad = grad.add(p.data, alpha=weight_decay)
+
+                    self._lib.compact_eft_adam_step(
+                        handle,
+                        _ptr(p.data),
+                        _ptr(grad),
+                        ctypes.c_float(lr),
+                        ctypes.c_float(beta1),
+                        ctypes.c_float(beta2),
+                        ctypes.c_float(eps),
+                        ctypes.c_int(self._step_count)
+                    )
+                else:
+                    # Fallback to standard Adam
+                    self._standard_step(p, grad, group)
+
+        return loss
+
+    def _standard_step(self, p, grad, group):
+        """Standard PyTorch Adam step (fallback for CPU or non-float32)."""
+        beta1, beta2 = group['betas']
+        lr = group['lr']
+        eps = group['eps']
+        weight_decay = group['weight_decay']
+
+        state = self.state[p]
+        if len(state) == 0:
+            state['step'] = 0
+            state['exp_avg'] = torch.zeros_like(p.data)
+            state['exp_avg_sq'] = torch.zeros_like(p.data)
+
+        state['step'] += 1
+        step = state['step']
+
+        exp_avg = state['exp_avg']
+        exp_avg_sq = state['exp_avg_sq']
+
+        if weight_decay != 0:
+            grad = grad.add(p.data, alpha=weight_decay)
+
+        exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
+        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
+
+        bias_correction1 = 1 - beta1 ** step
+        bias_correction2 = 1 - beta2 ** step
+        denom = (exp_avg_sq.sqrt() / (bias_correction2 ** 0.5)).add_(eps)
+
+        step_size = lr / bias_correction1
+        p.data.addcdiv_(exp_avg, denom, value=-step_size)
+
+    def __del__(self):
+        """Clean up handles on deletion."""
+        if hasattr(self, '_handles') and hasattr(self, '_lib') and self._lib is not None:
+            for handle in self._handles.values():
+                if handle is not None:
+                    try:
+                        self._lib.compact_eft_adam_destroy(handle)
+                    except Exception:
+                        pass
+
+    def memory_savings(self):
+        """Report memory savings vs standard FP32 Adam."""
+        total_params = sum(p.numel() for group in self.param_groups for p in group['params'])
+        fp32_bytes = total_params * 16
+        # Weights: 2.75, states: 2.5, EFT comp: 4, grads: 4 = 13.25 B/param
+        compact_eft_bytes = total_params * 13.25
+        savings = 1 - compact_eft_bytes / fp32_bytes
+        return {
+            'total_params': total_params,
+            'fp32_bytes': fp32_bytes,
+            'compact_eft_bytes': compact_eft_bytes,
+            'savings_ratio': fp32_bytes / compact_eft_bytes,
+            'savings_percent': savings * 100
+        }
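A minimal end-to-end usage sketch for the module above; the model and batch are placeholders, and the `.cuda()` calls assume a GPU is present (without the native library or a GPU, the optimizers fall back to the standard PyTorch math in their `_standard_step` methods):

import torch
from quarterbit.torch import CompactEFTAdam

model = torch.nn.Linear(256, 10).cuda()              # placeholder model
optimizer = CompactEFTAdam(model.parameters(), lr=1e-3)

x = torch.randn(32, 256).cuda()                      # placeholder batch
y = torch.randint(0, 10, (32,)).cuda()

optimizer.zero_grad()
loss = torch.nn.functional.cross_entropy(model(x), y)
loss.backward()
optimizer.step()

print(optimizer.memory_savings())                    # includes 'savings_ratio' and 'savings_percent'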