adv-optm 1.2.dev13__py3-none-any.whl → 2.dev1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of adv-optm might be problematic.
- adv_optm/__init__.py +1 -1
- adv_optm/optim/AdamW_adv.py +85 -64
- adv_optm/optim/Adopt_adv.py +114 -69
- adv_optm/optim/Lion_Prodigy_adv.py +79 -81
- adv_optm/optim/Lion_adv.py +37 -42
- adv_optm/optim/Prodigy_adv.py +105 -85
- adv_optm/optim/Simplified_AdEMAMix.py +92 -51
- adv_optm/optim/__init__.py +1 -1
- adv_optm/util/BF16_Stochastic_Rounding.py +1 -1
- adv_optm/util/Effective_Shape.py +1 -1
- adv_optm/util/Kourkoutas.py +11 -12
- adv_optm/util/NNMF.py +7 -2
- adv_optm/util/Newton_Schulz.py +1 -2
- adv_optm/util/One_Bit_Boolean.py +1 -1
- adv_optm/util/OrthoGrad.py +4 -3
- adv_optm/util/__init__.py +1 -1
- {adv_optm-1.2.dev13.dist-info → adv_optm-2.dev1.dist-info}/METADATA +20 -20
- adv_optm-2.dev1.dist-info/RECORD +23 -0
- adv_optm-1.2.dev13.dist-info/RECORD +0 -23
- {adv_optm-1.2.dev13.dist-info → adv_optm-2.dev1.dist-info}/WHEEL +0 -0
- {adv_optm-1.2.dev13.dist-info → adv_optm-2.dev1.dist-info}/licenses/LICENSE +0 -0
- {adv_optm-1.2.dev13.dist-info → adv_optm-2.dev1.dist-info}/top_level.txt +0 -0
adv_optm/optim/Prodigy_adv.py
CHANGED
@@ -50,12 +50,6 @@ class Prodigy_adv(torch.optim.Optimizer):
             before it is added to the fast momentum term (`update = mt + alpha * mt_slow`).
             A higher value increases the stabilizing influence of the slow
             momentum. (default: 5.0)
-        t_alpha (Optional[int]): The number of steps for a linear warmup of the
-            `alpha` parameter (only used when `use_AdEMAMix` is `True`). This is
-            highly recommended to prevent instability at the beginning of training,
-            as it gradually introduces the stabilizing slow momentum term. During
-            the warmup, `alpha` ramps from 0 to its target value. If `None`,
-            the scheduler is disabled.
         Simplified_AdEMAMix (bool): whether to use the Simplified AdEMAMix update rule.
             This changes the EMA to accumulator and the update numerator to `alpha_grad * grad + mt`, which can be
             more responsive, especially for small batch sizes. Enabling this will
@@ -72,7 +66,7 @@ class Prodigy_adv(torch.optim.Optimizer):
             Initial D estimate for D-adaptation (default 1e-6). Rarely needs changing.
         d_coef (float):
             Coefficient in the expression for the estimate of d (default 1.0).
-            Values such as 0.5 and 2.0 typically work as well.
+            Values such as 0.5 and 2.0 typically work as well.
             Changing this parameter is the preferred way to tune the method.
         growth_rate (float):
             prevent the D estimate from growing faster than this multiplicative rate.
@@ -82,8 +76,8 @@ class Prodigy_adv(torch.optim.Optimizer):
             If you're using sharded parameters, this should be set to True. The optimizer
             will attempt to auto-detect this, but if you're using an implementation other
             than PyTorch's builtin version, the auto-detection won't work.
-        slice_p (int): Reduce memory usage by calculating LR adaptation statistics on only every
-            pth entry of each tensor. For values greater than 1 this an an approximation to standard
+        slice_p (int): Reduce memory usage by calculating LR adaptation statistics on only every
+            pth entry of each tensor. For values greater than 1 this an an approximation to standard
             Prodigy. Values ~11 are reasonable (default 11).
         prodigy_steps (int): If greater than zero, disable Prodigy's stepsize adjustments
             after the specified optimiser step and release all state memory required by Prodigy
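For orientation, a minimal sketch of the slice_p idea described above (a hypothetical helper, not the package's internal API): only every slice_p-th entry of the flattened tensor feeds the LR-adaptation statistics, so the persistent Prodigy state is roughly 1/slice_p of the parameter count.

    import torch

    def sliced_view(t: torch.Tensor, slice_p: int = 11) -> torch.Tensor:
        # Take every slice_p-th entry of the flattened tensor.
        return t.detach().flatten().float()[::slice_p]

    p = torch.randn(1024)
    print(sliced_view(p).numel())  # 94 entries instead of 1024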
@@ -108,7 +102,7 @@ class Prodigy_adv(torch.optim.Optimizer):
         k_logging (int): if > 0 and kourkoutas_beta=True, enables periodic console
             logging of Kourkoutas-β statistics (min, max, mean of `β₂` across layers)
             every logging steps. Useful for debugging and tuning. Set to 0 to disable
-            logging (default: 0).
+            logging (default: 0).
         layer_key_fn (Optional[Callable]): A function that takes a parameter `p`
             and returns a unique, hashable key representing its "layer" or "bucket".
             If `None`, parameters are bucketed by their memory ID (tensor-wise).
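As an illustration of the `layer_key_fn` hook documented above, any hashable key works; the function and usage below are hypothetical examples, not part of the package:

    # Bucket parameters by their shape so all tensors of the same shape share
    # one Kourkoutas-β bucket.
    def layer_key_by_shape(p):
        return tuple(p.shape)

    # optimizer = Prodigy_adv(model.parameters(), kourkoutas_beta=True,
    #                         layer_key_fn=layer_key_by_shape)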
@@ -122,7 +116,7 @@ class Prodigy_adv(torch.optim.Optimizer):
         betas: tuple[float, float] = (0.9, 0.999),
         eps: float = 1e-8,
         weight_decay: float = 0.0,
-        vector_reshape: bool =
+        vector_reshape: bool = False,
         stochastic_rounding: bool = True,
         use_atan2: bool = False,
         cautious_mask: bool = False,
@@ -131,7 +125,6 @@ class Prodigy_adv(torch.optim.Optimizer):
         use_AdEMAMix: bool = False,
         beta3_ema: float = 0.9999,
         alpha: float = 5.0,
-        t_alpha: int | None = None,
         Simplified_AdEMAMix: bool = False,
         alpha_grad: float = 100.0,
         nnmf_factor: bool = False,
@@ -153,6 +146,8 @@ class Prodigy_adv(torch.optim.Optimizer):
         k_warmup_steps: int = 0,
         k_logging: int = 0,
         layer_key_fn: Optional[Callable] = None,
+        # Compiled
+        compiled_optimizer: bool = False,
     ):
         if not (lr >= 0.0):
             raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
@@ -188,13 +183,14 @@ class Prodigy_adv(torch.optim.Optimizer):
             "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
             "vector_reshape": vector_reshape, "use_atan2": use_atan2,
             "orthogonal_gradient": orthogonal_gradient,
-            "beta3_ema": beta3_ema, "alpha": alpha,
+            "beta3_ema": beta3_ema, "alpha": alpha,
             "beta3": beta3, "d": d0, "d0": d0, "d_max": d0, "d_numerator": 0.0, "d_coef": d_coef,
-            "growth_rate": growth_rate, "safeguard_warmup": safeguard_warmup, "
+            "growth_rate": growth_rate, "safeguard_warmup": safeguard_warmup, "slice_p": slice_p,
             "fsdp_in_use": fsdp_in_use, "prodigy_steps": prodigy_steps, "d_limiter": d_limiter,
             "alpha_grad": alpha_grad,
             "kourkoutas_beta": kourkoutas_beta, "beta2_min": beta2_min, "ema_alpha": ema_alpha,
             "tiny_spike": tiny_spike, "k_warmup_steps": k_warmup_steps, "k_logging": k_logging,
+            "compiled_optimizer": compiled_optimizer,
         }
         self.stochastic_rounding = stochastic_rounding
         self.cautious_mask = cautious_mask and not Simplified_AdEMAMix
@@ -203,14 +199,23 @@ class Prodigy_adv(torch.optim.Optimizer):
         self.Simplified_AdEMAMix = Simplified_AdEMAMix
         self.factored = nnmf_factor
         self.fsdp_in_use = fsdp_in_use
-
+
         self.kourkoutas_beta = kourkoutas_beta
         self.layer_key_fn = layer_key_fn
 
         super().__init__(params, defaults)
+        # Use the device of the first parameter to avoid hardcoding '.cuda()'
+        self.device = self.param_groups[0]['params'][0].device
+
         if self.kourkoutas_beta:
             self.kourkoutas_helper = KourkoutasHelper(self)
+
         self.init_step()
+        self.global_step = 0
+
+        if compiled_optimizer:
+            torch._dynamo.config.cache_size_limit = 8192
+            self.compile(fullgraph=True)
 
     @property
     def supports_fused_back_pass(self):
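For orientation, a minimal usage sketch of the new compiled_optimizer path added above. It assumes the class is importable from the package root (the __init__.py entries in the file list suggest re-exports, but the exact import path is an assumption):

    import torch
    from adv_optm import Prodigy_adv  # assumed import path

    model = torch.nn.Linear(16, 4)
    # With compiled_optimizer=True the constructor raises Dynamo's cache size
    # limit and calls self.compile(fullgraph=True), as the hunk above shows.
    opt = Prodigy_adv(model.parameters(), lr=1.0, compiled_optimizer=True)

    loss = model(torch.randn(8, 16)).pow(2).mean()
    loss.backward()
    opt.step()
    opt.zero_grad()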
@@ -240,32 +245,21 @@ class Prodigy_adv(torch.optim.Optimizer):
         self.dlr = self.d * lr
         self.d_numerator = g_group.get('d_numerator', 0.0) * self.beta3
 
-
-
-
-                return
-
-            if hasattr(p, "_fsdp_flattened"):
-                self.fsdp_in_use = True
+        for group in self.param_groups:
+            for i, p in enumerate(group['params']):
+                self.__init_state(p, group)
 
-
-
-            grad = grad.float()
-            if group["orthogonal_gradient"]:
-                grad = _orthogonalize_gradient(p, grad)
+    @torch.no_grad()
+    def __init_state(self, p, group):
         state = self.state[p]
 
-
-        if 'step' not in state:
-            state['step'] = 0
+        if len(state) == 0:
 
-
+            state['factored'] = (
                 self.factored and
                 not (len(p.shape) == 1 and not group['vector_reshape'])
             )
 
-            state['factored'] = should_factor
-
             slice_p = group['slice_p']
 
             dtype = torch.float32 if self.factored else p.dtype
@@ -277,18 +271,18 @@ class Prodigy_adv(torch.optim.Optimizer):
 
             # First moment (m)
             if self.beta1 > 0:
-                state['mu_m_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
+                state['mu_m_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
                 state['mv_m_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
                 if not self.grams_moment:
                     packed_d2 = (d2 + 7) // 8
                     state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=device)
                 if self.use_AdEMAMix:
-                    state['mu_m_slow_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
+                    state['mu_m_slow_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
                     state['mv_m_slow_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
                     packed_d2 = (d2 + 7) // 8
                     state['sign_slow'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
             # Second moment (v)
-            state['mu_v_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
+            state['mu_v_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
             state['mv_v_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
         else:  # Fallback to standard AdamW for non-factored tensors
             if self.beta1 > 0:
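The `packed_d2 = (d2 + 7) // 8` lines above store momentum signs one bit per element (the NNMF factors are non-negative, so signs live in a separate uint8 bitmap). A self-contained sketch of that packing idea; an illustration, not the package's One_Bit_Boolean code:

    import torch

    _WEIGHTS = torch.tensor([1, 2, 4, 8, 16, 32, 64, 128], dtype=torch.uint8)

    def pack_bools(b: torch.Tensor) -> torch.Tensor:
        # b: (d1, d2) bool -> (d1, (d2 + 7) // 8) uint8, eight signs per byte.
        d1, d2 = b.shape
        packed_d2 = (d2 + 7) // 8
        padded = torch.zeros(d1, packed_d2 * 8, dtype=torch.bool)
        padded[:, :d2] = b
        return (padded.view(d1, packed_d2, 8).to(torch.uint8) * _WEIGHTS).sum(-1).to(torch.uint8)

    def unpack_bools(packed: torch.Tensor, original_m: int) -> torch.Tensor:
        # Recover the first original_m boolean columns from the byte-packed form.
        d1, packed_d2 = packed.shape
        bits = (packed.unsqueeze(-1) & _WEIGHTS) != 0
        return bits.view(d1, packed_d2 * 8)[:, :original_m]

    signs = torch.randn(4, 10) > 0
    assert torch.equal(unpack_bools(pack_bools(signs), 10), signs)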
@@ -303,25 +297,30 @@ class Prodigy_adv(torch.optim.Optimizer):
             else:
                 state['p0'] = torch.tensor(0, device=device, dtype=p.dtype)
 
-
+
+    def __step_parameter(self, p: torch.Tensor, group: dict, d: torch.Tensor | float, dlr: torch.Tensor | float):
+        if p.grad is None:
+            return
+
+        grad = p.grad
+        if grad.dtype != torch.float32 and self.factored:
+            grad = grad.float()
+        if group["orthogonal_gradient"]:
+            grad = _orthogonalize_gradient(p, grad)
+
+        state = self.state[p]
+
         if group.get('kourkoutas_beta', False):
-            # Call prepare_step() once at the beginning of the step for all params
-            self.kourkoutas_helper.maybe_prepare_step(current_step)
             # Accumulate current grad's norm for the *next* step
             self.kourkoutas_helper.accumulate_gradient_sq_norm(p, grad)
             # Get the dynamic beta2 calculated in prepare_step()
-            beta2 = self.kourkoutas_helper.get_beta2(p, group
+            beta2 = self.kourkoutas_helper.get_beta2(p, group)
         else:
            beta2 = self.beta2_default
 
         if self.use_AdEMAMix:
             beta3_ema = group['beta3_ema']
             alpha = group['alpha']
-            t_alpha = group['t_alpha']
-            alpha_step = state['step'] + 1
-            alpha_t = alpha
-            if t_alpha is not None and t_alpha > 0 and alpha_step < t_alpha:
-                alpha_t = min(alpha_step * alpha / t_alpha, alpha)
         if self.Simplified_AdEMAMix:
             alpha_grad = group["alpha_grad"]
 
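The Kourkoutas-β calls above produce a per-bucket dynamic β₂. A rough, hypothetical sketch of the underlying idea (spike-driven interpolation between beta2_min and the default β₂); this is an assumption for orientation only, not the package's KourkoutasHelper logic:

    import torch

    def dynamic_beta2(grad_sq_norm, ema_norm, beta2_max=0.999, beta2_min=0.88,
                      ema_alpha=0.95, tiny_spike=1e-9):
        # Track a running average of the bucket's gradient norm; when the current
        # norm spikes relative to it, pull beta2 down toward beta2_min so the
        # second moment reacts faster. (Illustrative formula, not the package's.)
        ema_norm = ema_alpha * ema_norm + (1 - ema_alpha) * grad_sq_norm
        spike = grad_sq_norm / (ema_norm + tiny_spike)
        mix = min(spike / (1.0 + spike), 1.0)   # squash the ratio into [0, 1)
        return beta2_max - (beta2_max - beta2_min) * mix, ema_norm

    beta2, ema = dynamic_beta2(grad_sq_norm=4.0, ema_norm=1.0)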
@@ -339,11 +338,11 @@ class Prodigy_adv(torch.optim.Optimizer):
             del unpacked_sign
             # Update momentum in full-size
             if self.Simplified_AdEMAMix:
-                mt.mul_(self.beta1).add_(grad_reshaped, alpha=
+                mt.mul_(self.beta1).add_(grad_reshaped, alpha=d)
             else:
-                mt.mul_(self.beta1).add_(grad_reshaped, alpha=
+                mt.mul_(self.beta1).add_(grad_reshaped, alpha=d * (1.0 - self.beta1))
             if self.grams_moment:
-                mt
+                mt = grad_reshaped.sign().mul_(mt.abs())
             elif self.cautious_mask:
                 mask = (mt * grad_reshaped > 0).to(grad_reshaped.dtype)
                 mask.div_(mask.mean().clamp_(min=1e-3))
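The cautious-mask branch above can be read in isolation as follows (a minimal restatement with generic tensor names):

    import torch

    # Zero update components whose momentum disagrees in sign with the current
    # gradient, then rescale so the mean mask value is preserved (clamped at 1e-3).
    def cautious(update: torch.Tensor, grad: torch.Tensor) -> torch.Tensor:
        mask = (update * grad > 0).to(grad.dtype)
        mask.div_(mask.mean().clamp_(min=1e-3))
        return update * mask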
@@ -351,7 +350,7 @@ class Prodigy_adv(torch.optim.Optimizer):
                 del mask
 
             vt = _unnmf((state['mu_v_nmf'], state['mv_v_nmf']))
-            vt.mul_(beta2).addcmul_(grad_reshaped, grad_reshaped, value=
+            vt.mul_(beta2).addcmul_(grad_reshaped, grad_reshaped, value=d * d * (1.0 - beta2))
 
             if self.use_AdEMAMix:
                 mt_slow = _unnmf((state['mu_m_slow_nmf'], state['mv_m_slow_nmf']))
@@ -360,15 +359,15 @@ class Prodigy_adv(torch.optim.Optimizer):
                 unpacked_sign_slow = _unpack_bools(state['sign_slow'], original_m=d2)
                 torch.where(unpacked_sign_slow, mt_slow, -mt_slow, out=mt_slow)
                 del unpacked_sign_slow
-                mt_slow.mul_(beta3_ema).add_(grad_reshaped, alpha=
+                mt_slow.mul_(beta3_ema).add_(grad_reshaped, alpha=d * (1.0 - beta3_ema))
                 if self.beta1 > 0:
-                    update = torch.add(mt, mt_slow, alpha=
+                    update = torch.add(mt, mt_slow, alpha=alpha)
                 else:
-                    update = torch.add(grad_reshaped.mul(
+                    update = torch.add(grad_reshaped.mul(d), mt_slow, alpha=alpha)
             elif self.Simplified_AdEMAMix:
-                update = torch.add(mt, grad_reshaped, alpha=alpha_grad *
+                update = torch.add(mt, grad_reshaped, alpha=alpha_grad * d)
             else:
-                update = mt.clone() if self.beta1 > 0 else grad_reshaped.mul(
+                update = mt.clone() if self.beta1 > 0 else grad_reshaped.mul(d)
             del grad_reshaped
 
             if group['use_atan2']:
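For orientation, a minimal numeric illustration of the AdEMAMix numerator built above, `mt + alpha * mt_slow` (Prodigy's d factors are omitted here for clarity):

    import torch

    grad = torch.randn(10)
    mt, mt_slow = torch.zeros(10), torch.zeros(10)
    beta1, beta3_ema, alpha = 0.9, 0.9999, 5.0

    mt.mul_(beta1).add_(grad, alpha=1.0 - beta1)                # fast EMA
    mt_slow.mul_(beta3_ema).add_(grad, alpha=1.0 - beta3_ema)   # slow EMA
    update = torch.add(mt, mt_slow, alpha=alpha)                # mt + alpha * mt_slow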
@@ -377,10 +376,10 @@ class Prodigy_adv(torch.optim.Optimizer):
                 update.atan2_(denom).mul_(a)
             else:
                 denom = vt.sqrt()
-                update.div_(denom.add_(
+                update.div_(denom.add_(d * group['eps']))
             del denom
 
-            update = update.view(p.shape).mul_(
+            update = update.view(p.shape).mul_(dlr)
 
             # Compress updated moments and store new factors
             if self.beta1 > 0:
@@ -401,11 +400,11 @@ class Prodigy_adv(torch.optim.Optimizer):
             if self.beta1 > 0:
                 exp_avg = state['exp_avg']
                 if self.Simplified_AdEMAMix:
-                    exp_avg.mul_(self.beta1).add_(grad, alpha=
+                    exp_avg.mul_(self.beta1).add_(grad, alpha=d)
                 else:
-                    exp_avg.mul_(self.beta1).add_(grad, alpha=
+                    exp_avg.mul_(self.beta1).add_(grad, alpha=d * (1.0 - self.beta1))
                 if self.grams_moment:
-                    exp_avg = grad.sign()
+                    exp_avg = grad.sign().mul_(exp_avg.abs())
                 elif self.cautious_mask:
                     mask = (exp_avg * grad > 0).to(grad.dtype)
                     mask.div_(mask.mean().clamp_(min=1e-3))
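Read in isolation, the non-factored moment updates above fold Prodigy's d into the accumulators, mirroring reference Prodigy: the gradient is weighted by d in the first moment and by d*d in the second. A standalone restatement:

    import torch

    g = torch.randn(8)
    m, v = torch.zeros(8), torch.zeros(8)
    d, beta1, beta2 = 1e-6, 0.9, 0.999

    m.mul_(beta1).add_(g, alpha=d * (1.0 - beta1))              # first moment, scaled by d
    v.mul_(beta2).addcmul_(g, g, value=d * d * (1.0 - beta2))   # second moment, scaled by d*d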
@@ -414,17 +413,17 @@ class Prodigy_adv(torch.optim.Optimizer):
 
             if self.use_AdEMAMix:
                 exp_avg_slow = state['exp_avg_slow']
-                exp_avg_slow.mul_(beta3_ema).add_(grad, alpha=
+                exp_avg_slow.mul_(beta3_ema).add_(grad, alpha=d * (1.0 - beta3_ema))
                 if self.beta1 > 0:
-                    update = torch.add(exp_avg, exp_avg_slow, alpha=
+                    update = torch.add(exp_avg, exp_avg_slow, alpha=alpha)
                 else:
-                    update = torch.add(grad.mul(
+                    update = torch.add(grad.mul(d), exp_avg_slow, alpha=alpha)
             elif self.Simplified_AdEMAMix:
-                update = torch.add(exp_avg, grad, alpha=alpha_grad *
+                update = torch.add(exp_avg, grad, alpha=alpha_grad * d)
             else:
-                update = exp_avg.clone() if self.beta1 > 0 else grad.mul(
+                update = exp_avg.clone() if self.beta1 > 0 else grad.mul(d)
 
-            exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=
+            exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=d * d * (1.0 - beta2))
 
             if group['use_atan2']:
                 a = 1.2732395
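The use_atan2 branch that follows (with a = 1.2732395, roughly 4/pi) replaces the `sqrt(v) + eps` division with a bounded, eps-free form. A minimal standalone comparison:

    import torch

    m = torch.randn(5)
    v = torch.rand(5)
    a = 1.2732395                                    # ~ 4 / pi

    update_atan2 = torch.atan2(m, v.sqrt()).mul(a)   # bounded, no eps hyperparameter
    update_plain = m / (v.sqrt() + 1e-8)             # standard Adam-style denominator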
@@ -432,25 +431,25 @@ class Prodigy_adv(torch.optim.Optimizer):
                 update.atan2_(denom).mul_(a)
             else:
                 denom = exp_avg_sq.sqrt()
-                update.div_(denom.add_(
+                update.div_(denom.add_(d * group['eps']))
             del denom
 
-            update.mul_(
+            update.mul_(dlr)
 
         # --- Accumulate Prodigy stats ---
         prodigy_steps = group['prodigy_steps']
-        if prodigy_steps <= 0 or
+        if prodigy_steps <= 0 or self.global_step < prodigy_steps:
             d0, safeguard_warmup, slice_p = group['d0'], group['safeguard_warmup'], group['slice_p']
             s, p0 = state['s'], state['p0']
             grad_flat = grad.flatten().float()
             p_flat = p.data.flatten().float()
             p0 = p0.float()
 
-            self.d_numerator
+            self.d_numerator.add_((d / d0) * dlr * torch.dot(grad_flat[::slice_p], p0.data - p_flat[::slice_p]))
 
-            alpha = ((
+            alpha = ((d / d0) * d) if safeguard_warmup else ((d / d0) * dlr)
             s.mul_(self.beta3).add_(grad_flat[::slice_p], alpha=alpha)
-            self.d_denom
+            self.d_denom.add_(s.abs().sum())
 
             del s, p0, grad_flat, p_flat, alpha
         else:
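Schematically, the Prodigy statistics accumulated above reduce to the following self-contained sketch (single tensor, one step, no sharding; variable names mirror the hunk):

    import torch

    torch.manual_seed(0)
    p0 = torch.randn(1000)               # initial weights, kept in state['p0']
    p = p0 - 1e-3 * torch.randn(1000)    # current weights after some training
    grad = torch.randn(1000)
    d, d0, dlr, beta3, slice_p = 1e-6, 1e-6, 1e-6, 0.999, 11
    safeguard_warmup = False
    s = torch.zeros_like(p[::slice_p])

    g = grad[::slice_p]
    d_numerator = (d / d0) * dlr * torch.dot(g, p0[::slice_p] - p[::slice_p])
    s.mul_(beta3).add_(g, alpha=(d / d0) * (d if safeguard_warmup else dlr))
    d_denom = s.abs().sum()
    # calculate_d() later turns these two accumulators into a new d estimate.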
@@ -463,9 +462,9 @@ class Prodigy_adv(torch.optim.Optimizer):
         # Decoupled weight decay
         if group["weight_decay"] != 0:
             if p.dtype == torch.bfloat16 and self.stochastic_rounding:
-                add_stochastic_(p.data, p.data, alpha=-group["weight_decay"] *
+                add_stochastic_(p.data, p.data, alpha=-group["weight_decay"] * dlr)
             else:
-                p.data.add_(p.data, alpha=-group["weight_decay"] *
+                p.data.add_(p.data, alpha=-group["weight_decay"] * dlr)
 
         if p.dtype == torch.bfloat16 and self.stochastic_rounding:
             add_stochastic_(p.data, -update)
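The add_stochastic_ calls above come from the package's BF16_Stochastic_Rounding util. As an illustration of the general technique (not the package's actual implementation): accumulate in fp32, then randomize the 16 mantissa bits that a plain bf16 cast would discard, so small updates survive on average.

    import torch

    def add_stochastic(p_bf16: torch.Tensor, delta: torch.Tensor) -> None:
        acc = p_bf16.float() + delta.float()                         # fp32 accumulate
        noise = torch.randint_like(acc, low=0, high=1 << 16, dtype=torch.int32)
        # Add noise to the low 16 bits, then clear them; keeping only the high
        # 16 bits of the fp32 pattern is a stochastic round toward bf16.
        bits = acc.view(torch.int32).add_(noise).bitwise_and_(-65536)
        p_bf16.copy_(bits.view(torch.float32))                       # exact cast: low bits are zero

    w = torch.full((4,), 1.0, dtype=torch.bfloat16)
    add_stochastic(w, torch.full((4,), 1e-4))   # occasionally rounds up; unbiased on average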
@@ -473,7 +472,31 @@ class Prodigy_adv(torch.optim.Optimizer):
             p.data.add_(-update)
         del update
 
-
+    @torch.no_grad()
+    def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
+        if hasattr(p, "_fsdp_flattened"):
+            self.fsdp_in_use = True
+
+        if self.global_step is None and 'step' in self.state[p]:
+            # For backward compatibility
+            self.global_step = self.state[p]['step']
+
+        if self.kourkoutas_beta:
+            self.kourkoutas_helper.maybe_prepare_step(self.global_step)
+
+        if isinstance(self.d_numerator, float):
+            self.d_numerator = torch.tensor(self.d_numerator, device=p.device)
+            self.d_denom = torch.tensor(self.d_denom, device=p.device)
+
+        if not group.get('compiled_optimizer', False):
+            self.__step_parameter(p, group, self.d, self.dlr)
+        else:
+            d_tensor = torch.tensor(self.d, device=p.device)
+            dlr_tensor = torch.tensor(self.dlr, device=p.device)
+            self._compiled_step_parameter(p, group, d_tensor, dlr_tensor)
+
+    def compile(self, *args, **kwargs):
+        self._compiled_step_parameter = torch.compile(self.__step_parameter, *args, **kwargs)
 
     @torch.no_grad()
     def step(self, closure=None):
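One plausible reason d and dlr are wrapped in tensors on the compiled path above (an inference, not a statement from the diff): torch.compile tends to specialize on Python scalars, so a value that changes every step can trigger repeated recompilation, while a 0-d tensor is traced as a graph input. A minimal sketch:

    import torch

    @torch.compile
    def scaled(x, s):
        return x * s

    x = torch.randn(4)
    scaled(x, 0.5)                   # the float may be baked in as a constant
    scaled(x, 0.25)                  # a different float can force a recompile
    scaled(x, torch.tensor(0.125))   # a tensor scale is an input; the graph is reused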
@@ -494,23 +517,21 @@ class Prodigy_adv(torch.optim.Optimizer):
     def calculate_d(self):
         """Calculates the new `d` based on the accumulated stats."""
         g_group = self.param_groups[0]
-
+
         # Only perform d-adaptation if prodigy_steps has not been reached
-        prodigy_active = not (g_group.get('prodigy_steps', 0) > 0 and
+        prodigy_active = not (g_group.get('prodigy_steps', 0) > 0 and self.global_step >= g_group['prodigy_steps'])
 
         if prodigy_active:
             d_max, d_coef, growth_rate = g_group['d_max'], g_group['d_coef'], g_group['growth_rate']
-
+
             if self.fsdp_in_use and dist.is_available() and dist.is_initialized():
-
-                device = self.param_groups[0]['params'][0].device
-                dist_tensor = torch.tensor([self.d_numerator, self.d_denom], device=device)
+                dist_tensor = torch.stack([self.d_numerator, self.d_denom])
                 dist.all_reduce(dist_tensor, op=dist.ReduceOp.SUM)
                 global_d_numerator = dist_tensor[0].item()
                 global_d_denom = dist_tensor[1].item()
             else:
-                global_d_numerator = self.d_numerator
-                global_d_denom = self.d_denom
+                global_d_numerator = self.d_numerator.item()
+                global_d_denom = self.d_denom.item()
 
             d_hat = self.d
             if global_d_denom > 0:
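The lines that actually form the new d fall outside this hunk's context. For orientation, reference Prodigy computes it roughly as follows; this is an assumption based on the upstream algorithm, not text from this diff:

    # Assumed reference-Prodigy style update, shown only for orientation.
    d, d_coef, growth_rate = 1e-6, 1.0, float('inf')
    global_d_numerator, global_d_denom = 3e-4, 10.0

    d_hat = d_coef * global_d_numerator / global_d_denom   # 3e-05
    d = max(d, min(d_hat, d * growth_rate))                # d grows to 3e-05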
@@ -529,7 +550,6 @@ class Prodigy_adv(torch.optim.Optimizer):
                 group['d_numerator'] = global_d_numerator
                 group['d'] = self.d
                 group['d_max'] = d_max
-
+
             # Increment step counter for all groups, regardless of whether d was updated
-
-            group['k'] += 1
+        self.global_step += 1