adv-optm 1.0.6__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of adv-optm has been flagged as possibly problematic by the registry.
- adv_optm/__init__.py +1 -1
- adv_optm/optim/AdamW_adv.py +61 -9
- adv_optm/optim/Adopt_adv.py +435 -388
- adv_optm/optim/Lion_Prodigy_adv.py +315 -315
- adv_optm/optim/Lion_adv.py +1 -1
- adv_optm/optim/Prodigy_adv.py +78 -19
- adv_optm/optim/Simplified_AdEMAMix.py +54 -2
- adv_optm/util/Kourkoutas.py +171 -0
- {adv_optm-1.0.6.dist-info → adv_optm-1.1.0.dist-info}/METADATA +1 -1
- adv_optm-1.1.0.dist-info/RECORD +20 -0
- adv_optm-1.0.6.dist-info/RECORD +0 -19
- {adv_optm-1.0.6.dist-info → adv_optm-1.1.0.dist-info}/WHEEL +0 -0
- {adv_optm-1.0.6.dist-info → adv_optm-1.1.0.dist-info}/licenses/LICENSE +0 -0
- {adv_optm-1.0.6.dist-info → adv_optm-1.1.0.dist-info}/top_level.txt +0 -0
adv_optm/optim/Prodigy_adv.py
CHANGED
@@ -3,15 +3,18 @@ import torch.distributed as dist
 
 import math
 
+from typing import Optional, Callable
+
 from ..util.BF16_Stochastic_Rounding import add_stochastic_
 from ..util.Effective_Shape import _get_effective_shape
 from ..util.NNMF import _nnmf,_unnmf
 from ..util.OrthoGrad import _orthogonalize_gradient
 from ..util.One_Bit_Boolean import _pack_bools, _unpack_bools
+from ..util.Kourkoutas import KourkoutasHelper
 
 class Prodigy_adv(torch.optim.Optimizer):
     """
-    Implements
+    Implements an advanced Prodigy algorithm.
     This is an advanced version of Prodigy with optional features like
     low-rank factorization of optimizer states (SMMF), OrthoGrad, etc.
 
@@ -85,6 +88,31 @@ class Prodigy_adv(torch.optim.Optimizer):
         prodigy_steps (int): If greater than zero, disable Prodigy's stepsize adjustments
             after the specified optimiser step and release all state memory required by Prodigy
             (default: 0).
+        d_limiter (bool): whether to clamp the new step size estimate (`d_hat`)
+            to prevent sudden, volatile increases in the adaptive step size (`d`).
+            (default: False)
+        kourkoutas_beta (bool): whether to enable the layer-wise dynamic β₂ logic.
+            If `False`, the optimizer behaves as standard AdamW/Prodigy. (default: False)
+        beta2_min (float): The minimum value for dynamic β₂, used during periods of
+            high gradient variance ("sunspikes"). Must be less than `betas[1]`.
+            (default: 0.88)
+        ema_alpha (float): The decay rate for the Exponential Moving Average (EMA) of
+            the pooled gradient norms. Corresponds to `α` in the paper.
+            (default: 0.93)
+        tiny_spike (float): A small constant added to the denominator of the
+            "sunspike" ratio calculation to prevent division by zero. Corresponds
+            to `ε_spike` in the paper. (default: 1e-9)
+        k_warmup_steps (int): The number of initial steps during which β₂ is held
+            at a fixed beta2 value before the
+            dynamic logic activates. (default: 0)
+        k_logging (int): if > 0 and kourkoutas_beta=True, enables periodic console
+            logging of Kourkoutas-β statistics (min, max, mean of `β₂` across layers)
+            every logging steps. Useful for debugging and tuning. Set to 0 to disable
+            logging (default: 0).
+        layer_key_fn (Optional[Callable]): A function that takes a parameter `p`
+            and returns a unique, hashable key representing its "layer" or "bucket".
+            If `None`, parameters are bucketed by their memory ID (tensor-wise).
+            (default: None)
     """
 
     def __init__(
@@ -116,6 +144,15 @@ class Prodigy_adv(torch.optim.Optimizer):
         fsdp_in_use: bool = False,
         slice_p: int = 11,
         prodigy_steps: int = 0,
+        d_limiter: bool = False,
+        # K-b parameters
+        kourkoutas_beta: bool = False,
+        beta2_min: float = 0.9,
+        ema_alpha: float = 0.95,
+        tiny_spike: float = 1e-9,
+        k_warmup_steps: int = 0,
+        k_logging: int = 0,
+        layer_key_fn: Optional[Callable] = None,
     ):
         if not (lr >= 0.0):
             raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
@@ -141,8 +178,10 @@ class Prodigy_adv(torch.optim.Optimizer):
         if use_atan2 and Simplified_AdEMAMix:
             print("Warning: use_atan2 is incompatible with Simplified_AdEMAMix. Disabling use_atan2.")
             use_atan2 = False
-        if
-
+        if kourkoutas_beta and not (betas[1] > beta2_min):
+            raise ValueError(f"For Kourkoutas-β, betas[1] (as beta2_max) must be > beta2_min. Got {betas[1]} and {beta2_min}")
+        if Simplified_AdEMAMix and alpha_grad > 0 and not d_limiter:
+            # scales d_coef by alpha_grad, this force prodigy to behave well with Simplified_AdEMAMix.
             d_coef = d_coef/alpha_grad
 
         defaults = {
@@ -152,8 +191,10 @@ class Prodigy_adv(torch.optim.Optimizer):
             "beta3_ema": beta3_ema, "alpha": alpha, "t_alpha": t_alpha,
             "beta3": beta3, "d": d0, "d0": d0, "d_max": d0, "d_numerator": 0.0, "d_coef": d_coef,
             "growth_rate": growth_rate, "safeguard_warmup": safeguard_warmup, "k": 0, "slice_p": slice_p,
-            "fsdp_in_use": fsdp_in_use, "prodigy_steps": prodigy_steps,
-            "alpha_grad": alpha_grad,
+            "fsdp_in_use": fsdp_in_use, "prodigy_steps": prodigy_steps, "d_limiter": d_limiter,
+            "alpha_grad": alpha_grad,
+            "kourkoutas_beta": kourkoutas_beta, "beta2_min": beta2_min, "ema_alpha": ema_alpha,
+            "tiny_spike": tiny_spike, "k_warmup_steps": k_warmup_steps, "k_logging": k_logging,
         }
         self.stochastic_rounding = stochastic_rounding
         self.cautious_mask = cautious_mask and not Simplified_AdEMAMix
@@ -162,7 +203,13 @@ class Prodigy_adv(torch.optim.Optimizer):
         self.Simplified_AdEMAMix = Simplified_AdEMAMix
         self.factored = nnmf_factor
         self.fsdp_in_use = fsdp_in_use
+
+        self.kourkoutas_beta = kourkoutas_beta
+        self.layer_key_fn = layer_key_fn
+
         super().__init__(params, defaults)
+        if self.kourkoutas_beta:
+            self.kourkoutas_helper = KourkoutasHelper(self)
         self.init_step()
 
     @property
@@ -180,19 +227,17 @@ class Prodigy_adv(torch.optim.Optimizer):
     def init_step(self):
         """Resets accumulators and calculates dlr for the upcoming step."""
        self.d_denom = 0.0
-
+
         g_group = self.param_groups[0]
-        self.beta1, self.
+        self.beta1, self.beta2_default = g_group['betas']
         self.beta3 = g_group['beta3']
         if self.beta3 is None:
-            self.beta3 = math.sqrt(self.
-
-        k = g_group['k']
+            self.beta3 = math.sqrt(self.beta2_default)
+
         self.d = g_group['d']
         lr = g_group['lr']
 
         self.dlr = self.d * lr
-
         self.d_numerator = g_group.get('d_numerator', 0.0) * self.beta3
 
     @torch.no_grad()
@@ -211,7 +256,7 @@ class Prodigy_adv(torch.optim.Optimizer):
         state = self.state[p]
 
         # State Initialization
-        if
+        if 'step' not in state:
             state['step'] = 0
 
         should_factor = (
@@ -258,14 +303,27 @@ class Prodigy_adv(torch.optim.Optimizer):
            else:
                state['p0'] = torch.tensor(0, device=device, dtype=p.dtype)
 
+        current_step = state['step']
+        if group['kourkoutas_beta']:
+            # Call prepare_step() once at the beginning of the step for all params
+            self.kourkoutas_helper.maybe_prepare_step(current_step)
+            # Accumulate current grad's norm for the *next* step
+            self.kourkoutas_helper.accumulate_gradient_sq_norm(p, grad)
+            # Get the dynamic beta2 calculated in prepare_step()
+            beta2 = self.kourkoutas_helper.get_beta2(p, group, current_step)
+            beta3 = math.sqrt(beta2)
+        else:
+            beta2 = self.beta2_default
+            beta3 = self.beta3
+
         if self.use_AdEMAMix:
             beta3_ema = group['beta3_ema']
             alpha = group['alpha']
             t_alpha = group['t_alpha']
-
+            alpha_step = state['step'] + 1
             alpha_t = alpha
-            if t_alpha is not None and t_alpha > 0 and
-                alpha_t = min(
+            if t_alpha is not None and t_alpha > 0 and alpha_step < t_alpha:
+                alpha_t = min(alpha_step * alpha / t_alpha, alpha)
         if self.Simplified_AdEMAMix:
             alpha_grad = group["alpha_grad"]
 
@@ -295,7 +353,7 @@ class Prodigy_adv(torch.optim.Optimizer):
                del mask
 
            vt = _unnmf((state['mu_v_nmf'], state['mv_v_nmf']))
-           vt.mul_(
+           vt.mul_(beta2).addcmul_(grad_reshaped, grad_reshaped, value=self.d * self.d * (1.0 - beta2))
 
            if self.use_AdEMAMix:
                mt_slow = _unnmf((state['mu_m_slow_nmf'], state['mv_m_slow_nmf']))
@@ -368,7 +426,7 @@ class Prodigy_adv(torch.optim.Optimizer):
            else:
                update = exp_avg.clone() if self.beta1 > 0 else grad.mul(self.d)
 
-           exp_avg_sq.mul_(
+           exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=self.d * self.d * (1.0 - beta2))
 
            if group['use_atan2']:
                a = 1.2732395
@@ -393,7 +451,7 @@ class Prodigy_adv(torch.optim.Optimizer):
            self.d_numerator += (self.d / d0) * self.dlr * torch.dot(grad_flat[::slice_p], p0.data - p_flat[::slice_p]).item()
 
            alpha = ((self.d / d0) * self.d) if safeguard_warmup else ((self.d / d0) * self.dlr)
-           s.mul_(
+           s.mul_(beta3).add_(grad_flat[::slice_p], alpha=alpha)
            self.d_denom += s.abs().sum().item()
 
            del s, p0, grad_flat, p_flat, alpha
@@ -431,7 +489,6 @@ class Prodigy_adv(torch.optim.Optimizer):
            for i, p in enumerate(group['params']):
                self.step_parameter(p, group, i)
 
-
        self.calculate_d()
        self.init_step()
        return loss
@@ -460,6 +517,8 @@ class Prodigy_adv(torch.optim.Optimizer):
        d_hat = self.d
        if global_d_denom > 0:
            d_hat = d_coef * global_d_numerator / global_d_denom
+           if g_group['d_limiter']:
+               d_hat = min(self.d * (2 ** 0.25), d_hat)
        if self.d == g_group['d0']:
            self.d = max(self.d, d_hat)
        d_max = max(d_max, d_hat)
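For orientation, a minimal usage sketch of the options added to Prodigy_adv in this release follows. It is not part of the diff: the import path only mirrors the module layout above, the model and every hyperparameter value are illustrative assumptions, and all remaining constructor arguments are left at whatever defaults the package ships. Note the effect of `d_limiter` visible in the last hunk: an accepted `d_hat` is clamped to `d * 2**0.25`, so `d` can grow by at most about 19% per step.

# Illustrative usage sketch, not part of the diff; values are assumptions.
import torch
import torch.nn as nn
from adv_optm.optim.Prodigy_adv import Prodigy_adv

model = nn.Linear(128, 64)
optimizer = Prodigy_adv(
    model.parameters(),
    lr=1.0,                          # Prodigy adapts d; lr acts as a multiplier
    d_limiter=True,                  # clamp d_hat to d * 2**0.25 (~ +19% max growth per step)
    kourkoutas_beta=True,            # enable layer-wise dynamic beta2
    beta2_min=0.9,                   # floor used during gradient-norm spikes
    ema_alpha=0.95,                  # decay of the pooled gradient-norm EMA
    k_warmup_steps=100,              # hold beta2 at betas[1] for the first 100 steps
    k_logging=50,                    # print Kourkoutas-beta stats every 50 steps
    layer_key_fn=lambda p: tuple(p.shape),  # hypothetical bucketing: group params by shape
)

loss = model(torch.randn(4, 128)).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()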
adv_optm/optim/Simplified_AdEMAMix.py
CHANGED
@@ -1,4 +1,5 @@
 import torch
+from typing import Optional, Callable
 
 import math
 
@@ -7,6 +8,7 @@ from ..util.Effective_Shape import _get_effective_shape
 from ..util.NNMF import _nnmf,_unnmf
 from ..util.OrthoGrad import _orthogonalize_gradient
 from ..util.One_Bit_Boolean import _pack_bools, _unpack_bools
+from ..util.Kourkoutas import KourkoutasHelper
 
 # A little helper from the original simplified_AdEMAMix
 def linear_hl_warmup_scheduler(step, beta_end, beta_start=0, warmup=1):
@@ -47,6 +49,28 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
         stochastic_rounding (bool): whether to use stochastic
             rounding for BF16 parameter updates (default: True).
         orthogonal_gradient (bool): whether to use OrthoGrad. (default: False)
+        kourkoutas_beta (bool): whether to enable the layer-wise dynamic β₂ logic.
+            If `False`, the optimizer behaves as standard Simplified_AdEMAMix. (default: False)
+        beta2_min (float): The minimum value for dynamic β₂, used during periods of
+            high gradient variance ("sunspikes"). Must be less than `betas[1]`.
+            (default: 0.88)
+        ema_alpha (float): The decay rate for the Exponential Moving Average (EMA) of
+            the pooled gradient norms. Corresponds to `α` in the paper.
+            (default: 0.93)
+        tiny_spike (float): A small constant added to the denominator of the
+            "sunspike" ratio calculation to prevent division by zero. Corresponds
+            to `ε_spike` in the paper. (default: 1e-9)
+        k_warmup_steps (int): The number of initial steps during which β₂ is held
+            at a fixed beta2 value before the
+            dynamic logic activates. (default: 0)
+        k_logging (int): if > 0 and kourkoutas_beta=True, enables periodic console
+            logging of Kourkoutas-β statistics (min, max, mean of `β₂` across layers)
+            every logging steps. Useful for debugging and tuning. Set to 0 to disable
+            logging (default: 0).
+        layer_key_fn (Optional[Callable]): A function that takes a parameter `p`
+            and returns a unique, hashable key representing its "layer" or "bucket".
+            If `None`, parameters are bucketed by their memory ID (tensor-wise).
+            (default: None)
         nnmf_factor (bool): whether to use the factorization or disable it to use
             the uncompressed optimizer. (default: False)
     """
@@ -65,6 +89,13 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
         vector_reshape: bool = True,
         stochastic_rounding: bool = True,
         orthogonal_gradient: bool = False,
+        kourkoutas_beta: bool = False,
+        beta2_min: float = 0.9,
+        ema_alpha: float = 0.95,
+        tiny_spike: float = 1e-9,
+        k_warmup_steps: int = 0,
+        k_logging: int = 0,
+        layer_key_fn: Optional[Callable] = None,
         nnmf_factor: bool = False,
     ):
         if not (lr >= 0.0):
@@ -77,17 +108,25 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
             raise ValueError(f"Weight-decay should be >= 0.0. Got {weight_decay}")
         if not 0.0 <= alpha_grad:
             raise ValueError("Invalid alpha value: {}".format(alpha_grad))
+        if kourkoutas_beta and not (betas[1] > beta2_min): raise ValueError(f"For Kourkoutas-β, betas[1] (as beta2_max) must be > beta2_min. Got {betas[1]} and {beta2_min}")
 
         defaults = {
             "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
             "alpha_grad": alpha_grad, "beta1_warmup": beta1_warmup, "min_beta1": min_beta1,
             "vector_reshape": vector_reshape,
             "orthogonal_gradient": orthogonal_gradient, "use_bias_correction": use_bias_correction,
+            "kourkoutas_beta": kourkoutas_beta, "beta2_min": beta2_min, "ema_alpha": ema_alpha,
+            "tiny_spike": tiny_spike, "k_warmup_steps": k_warmup_steps, "k_logging": k_logging,
         }
         self.stochastic_rounding = stochastic_rounding
         self.factored = nnmf_factor
+        self.kourkoutas_beta = kourkoutas_beta
+        self.layer_key_fn = layer_key_fn
         super().__init__(params, defaults)
 
+        if self.kourkoutas_beta:
+            self.kourkoutas_helper = KourkoutasHelper(self)
+
     @property
     def supports_fused_back_pass(self):
         return True
@@ -113,7 +152,7 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
         state = self.state[p]
 
         # State Initialization
-        if
+        if 'step' not in state:
             state['step'] = 0
 
         should_factor = (
@@ -150,6 +189,16 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
             state['den_sum'] = 1.0
 
         beta1_final, beta2 = group["betas"]
+
+        current_step = state['step']
+        if group['kourkoutas_beta']:
+            # Call prepare_step() once at the beginning of the step for all params
+            self.kourkoutas_helper.maybe_prepare_step(current_step)
+            # Accumulate current grad's norm for the *next* step
+            self.kourkoutas_helper.accumulate_gradient_sq_norm(p, grad)
+            # Get the dynamic beta2 calculated in prepare_step()
+            beta2 = self.kourkoutas_helper.get_beta2(p, group, current_step)
+
         beta1_warmup = group["beta1_warmup"]
         alpha_grad = group["alpha_grad"]
 
@@ -161,7 +210,10 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
 
         if group['use_bias_correction']:
             state['num_sum'] = beta1 * state['num_sum'] + 1.0
-
+            if group['kourkoutas_beta']:
+                state['den_sum'] = group['betas'][1] * state['den_sum'] + (1.0 - group['betas'][1])
+            else:
+                state['den_sum'] = beta2 * state['den_sum'] + (1.0 - beta2)
 
         if state['factored']:
             d1, d2 = state['effective_shape']
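Simplified_AdEMAMix exposes the same Kourkoutas-β parameters as Prodigy_adv, including the `layer_key_fn` hook. Below is a hedged sketch of a custom keying function; it is not from the package. It assumes the optimizer is built from an `nn.Module` so that parameter names are available, and the `lr` value is arbitrary. Any function that returns one hashable key per parameter works, since the helper falls back to `id(p)` (tensor-wise buckets) when no key function is given.

# Illustrative sketch (not from the package): bucket parameters by the first
# component of their dotted name, so all tensors of a block share one dynamic beta2.
import torch.nn as nn
from adv_optm.optim.Simplified_AdEMAMix import Simplified_AdEMAMix

def make_layer_key_fn(model: nn.Module):
    # Map each parameter's id to the leading component of its dotted name,
    # e.g. "encoder.layers.0.weight" -> "encoder".
    name_by_id = {id(p): name.split(".")[0] for name, p in model.named_parameters()}
    return lambda p: name_by_id.get(id(p), id(p))  # unknown tensors fall back to tensor-wise keys

model = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 8))
optimizer = Simplified_AdEMAMix(
    model.parameters(),
    lr=1e-3,                          # arbitrary value for this sketch
    kourkoutas_beta=True,
    layer_key_fn=make_layer_key_fn(model),
)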
adv_optm/util/Kourkoutas.py
ADDED
@@ -0,0 +1,171 @@
+import torch
+from torch.optim import Optimizer
+from typing import Callable
+
+class KourkoutasHelper:
+    """
+    A helper class to add layer-wise Kourkoutas-β functionality to a PyTorch optimizer.
+    """
+    def __init__(self, optimizer: Optimizer):
+        # We need a reference to the optimizer to access its param_groups and state
+        if not hasattr(optimizer, 'param_groups'):
+            raise TypeError("optimizer must be a valid torch.optim.Optimizer instance.")
+        self.optimizer = optimizer
+        self.layer_state = {}
+
+        self.layer_info = {}
+        self._layer_info_built = False
+        self._current_step_prepared = -1
+
+        # Store stats for external logging (e.g., TensorBoard)
+        self.last_beta2_stats = {}
+
+        # This ensures the map is complete before the first backward pass,
+        # making it compatible with fused back pass mechanisms.
+        self._build_layer_info_if_needed()
+
+        if self.optimizer.param_groups[0].get('k_logging', 0) > 0:
+            self.print_layer_info()
+
+    def _build_layer_info_if_needed(self):
+        """Builds a map of layers and the parameters they contain."""
+        if self._layer_info_built:
+            return
+
+        if not hasattr(self.optimizer, 'layer_key_fn') or self.optimizer.layer_key_fn is None:
+            print("Warning: KourkoutasHelper requires 'layer_key_fn' on the optimizer. Defaulting to tensor-wise (id).")
+            self.optimizer.layer_key_fn = lambda p: id(p)
+
+        for group in self.optimizer.param_groups:
+            for p in group['params']:
+                # The mapping is static and should not depend on the presence of a gradient.
+                layer_key = self.optimizer.layer_key_fn(p)
+                if layer_key not in self.layer_info:
+                    self.layer_info[layer_key] = {'params': [], 'group_ref': group}
+                self.layer_info[layer_key]['params'].append(p)
+
+        k_logging_interval = self.optimizer.param_groups[0].get('k_logging', 0)
+        if k_logging_interval > 0:
+            print(f"[Kourkoutas-β Debug] Layer info built. Found {len(self.layer_info)} unique layers/buckets.")
+
+        self._layer_info_built = True
+
+    def print_layer_info(self):
+        """Prints the contents of self.layer_info for debugging."""
+        print("\n--- BEGIN self.layer_info DUMP ---")
+        if not self.layer_info:
+            print("Layer info is empty. Make sure the optimizer has parameters.")
+            return
+
+        for layer_key, info in self.layer_info.items():
+            param_count = len(info['params'])
+            first_param_details = ""
+            if param_count > 0:
+                p = info['params'][0]
+                first_param_details = f" (Example param shape: {list(p.shape)}, dtype: {p.dtype})"
+
+            print(f"Key: {layer_key}, Params: {param_count}{first_param_details}")
+
+        print("--- END self.layer_info DUMP ---\n")
+
+    def prepare_step(self, current_step: int):
+        """
+        Calculates dynamic beta2 for all layers using the completed scalar accumulators
+        from the PREVIOUS step. Should be called once at the start of an optimizer step.
+        """
+
+        beta2_log = []
+        first_layer_key = next(iter(self.layer_info), None)
+        # These are just for the sample log, initialize them
+        sun, pooled_grad_norm, prev_r_ema_val, r_ema_tensor = (torch.tensor(0.0),)*4
+
+        for layer_key, info in self.layer_info.items():
+            params, group = info['params'], info['group_ref']
+
+            first_param_in_layer = info['params'][0]
+            param_state = self.optimizer.state[first_param_in_layer]
+
+            if layer_key not in self.layer_state:
+                self.layer_state[layer_key] = {
+                    'sum_sq_accumulator': torch.tensor(0.0, device=first_param_in_layer.device, dtype=torch.float32)
+                }
+
+            if 'kourkoutas_r_ema' not in param_state:
+                param_state['kourkoutas_r_ema'] = torch.tensor(0.0, device=first_param_in_layer.device, dtype=torch.float32)
+
+            r_ema_tensor = param_state['kourkoutas_r_ema']
+            accumulator = self.layer_state[layer_key]['sum_sq_accumulator']
+
+            pooled_grad_norm = torch.sqrt(accumulator)
+            prev_r_ema_val = r_ema_tensor.item() # for logging
+
+            # Update the persistent EMA tensor in-place.
+            r_ema_tensor.mul_(group['ema_alpha']).add_(pooled_grad_norm, alpha=1.0 - group['ema_alpha'])
+
+            beta2_max = group['betas'][1]
+            sun = torch.tensor(0.0, device=r_ema_tensor.device) # Default sun to 0 for warmup
+
+            if current_step < group['k_warmup_steps']:
+                beta2 = beta2_max
+            else:
+                raw = pooled_grad_norm / (r_ema_tensor + group['tiny_spike'])
+                sun = raw / (1.0 + raw)
+                beta2 = beta2_max - (beta2_max - group['beta2_min']) * sun
+
+            # Store the final calculated beta2 in the helper's transient state for this step.
+            self.layer_state[layer_key]['dynamic_beta2'] = beta2.item() if isinstance(beta2, torch.Tensor) else beta2
+
+            # Reset the accumulator for the next optimizer step.
+            accumulator.zero_()
+
+            beta2_log.append(self.layer_state[layer_key]['dynamic_beta2'])
+
+        # Always compute stats for TensorBoard
+        if beta2_log:
+            beta2_tensor = torch.tensor(beta2_log, device='cpu')
+            self.last_beta2_stats = {
+                'min': beta2_tensor.min().item(),
+                'max': beta2_tensor.max().item(),
+                'mean': beta2_tensor.mean().item(),
+            }
+
+        # Handle periodic console logging
+        k_logging_interval = self.optimizer.param_groups[0].get('k_logging', 0)
+        is_logging_step = k_logging_interval > 0 and (current_step + 1) % k_logging_interval == 0
+        if is_logging_step and self.last_beta2_stats:
+            if first_layer_key:
+                print(f"\n[Kourkoutas-β Debug] Step {current_step + 1} - Sample Layer '{first_layer_key}':")
+                print(f"  - Grad Norm: {pooled_grad_norm.item():.4e}, Prev EMA: {prev_r_ema_val:.4e}, New EMA: {r_ema_tensor.item():.4e}")
+                print(f"  - Sunspike: {sun.item():.4f}, Dynamic Beta2: {self.layer_state[first_layer_key]['dynamic_beta2']:.4f}")
+            print(f"[Kourkoutas-β Debug] Step {current_step + 1} Overall Beta2 Stats: Min={self.last_beta2_stats['min']:.4f}, Max={self.last_beta2_stats['max']:.4f}, Mean={self.last_beta2_stats['mean']:.4f}")
+
+    def maybe_prepare_step(self, current_step: int):
+        """
+        A universal guard that calls prepare_step() exactly once per training step.
+        """
+        if self._current_step_prepared < current_step:
+            self.prepare_step(current_step)
+            self._current_step_prepared = current_step
+
+    def accumulate_gradient_sq_norm(self, p: torch.Tensor, grad: torch.Tensor):
+        """
+        Accumulates the squared L2 norm of a single gradient for the next step's calculation.
+        """
+        layer_key = self.optimizer.layer_key_fn(p)
+
+        if layer_key in self.layer_info:
+            # Initialize the transient state for this layer if it's the first time in the step.
+            if layer_key not in self.layer_state:
+                self.layer_state[layer_key] = {
+                    'sum_sq_accumulator': torch.tensor(0.0, device=p.device, dtype=torch.float32)
+                }
+            # Accumulate for the *next* step's prepare_step call
+            self.layer_state[layer_key]['sum_sq_accumulator'] += torch.sum(grad.detach().pow(2)).float()
+
+    def get_beta2(self, p: torch.Tensor, group: dict, current_step: int) -> float:
+        """
+        Gets the appropriate beta2 for the current parameter, handling warmup and dynamic value fetching.
+        """
+        layer_key = self.optimizer.layer_key_fn(p)
+        # The default is the max value, which is correct for unmapped params or edge cases
+        return self.layer_state.get(layer_key, {}).get('dynamic_beta2', group['betas'][1])
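To make the `prepare_step()` arithmetic above concrete, here is a standalone restatement of the sunspike mapping with made-up numbers. It only mirrors the three lines that compute `raw`, `sun`, and `beta2`; the constants are hypothetical stand-ins for `betas[1]`, `beta2_min`, and `tiny_spike`.

# Standalone sketch of the Kourkoutas-beta mapping used in prepare_step() above.
# Values are hypothetical; beta2_max/beta2_min stand in for betas[1] and beta2_min.
def kourkoutas_beta2(pooled_norm, ema, beta2_max=0.999, beta2_min=0.9, tiny_spike=1e-9):
    raw = pooled_norm / (ema + tiny_spike)   # spike ratio against the running EMA
    sun = raw / (1.0 + raw)                  # squashed into [0, 1)
    return beta2_max - (beta2_max - beta2_min) * sun

print(kourkoutas_beta2(pooled_norm=1.0, ema=1.0))   # steady: raw=1, sun=0.5 -> ~0.9495
print(kourkoutas_beta2(pooled_norm=10.0, ema=1.0))  # spike: sun~0.909 -> ~0.909, near beta2_min
print(kourkoutas_beta2(pooled_norm=0.1, ema=1.0))   # quiet: sun~0.091 -> ~0.990, near beta2_max

Because `sun` is bounded in [0, 1), beta2 always stays within [beta2_min, beta2_max], and a gradient norm equal to its EMA lands halfway between the two bounds.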
adv_optm-1.1.0.dist-info/RECORD
ADDED
@@ -0,0 +1,20 @@
+adv_optm/__init__.py,sha256=lNp6_DdCvw-0zok2UdMkaEyVLZIlMRSKgBp-hJ15Hao,306
+adv_optm/optim/AdamW_adv.py,sha256=ddEUVOif1gfZPgEJNrEGZ2wnha4MPMWw5ppPd8acQ3o,17457
+adv_optm/optim/Adopt_adv.py,sha256=fhH3hS9K6z5Blxc7NFfzpCrUGbl9EQnwLPmKDxBC1zg,21415
+adv_optm/optim/Lion_Prodigy_adv.py,sha256=aJ9orEEw0QYbrDzn1be0SHvOBlIkLwWG9RpWFuNMskM,13163
+adv_optm/optim/Lion_adv.py,sha256=aGNAplZlyXYgVllYcV_s4bK8iC4fv6EizFoWIMNLdBc,8299
+adv_optm/optim/Prodigy_adv.py,sha256=4O7BLGhqLW46Ff3UN9JfrktHonCYDy3ojHUfW8jtaDs,25940
+adv_optm/optim/Simplified_AdEMAMix.py,sha256=gPjMhKulzmAeO42foe-d7xW0AcB50vKFYsvHgxbD3uc,12949
+adv_optm/optim/__init__.py,sha256=pcP865H2j1tut2VfTUhzQh7V8TF_tzPjqFnjMfFed2k,382
+adv_optm/util/BF16_Stochastic_Rounding.py,sha256=Q5H0BcogmE4atP65dLoI21HKSf50lRdsBDfeF6v9Tbg,1548
+adv_optm/util/Effective_Shape.py,sha256=TBvIk1V8IuTbbBsxuekJA4e_v8JlR5Nujtut8RTWAm4,318
+adv_optm/util/Kourkoutas.py,sha256=DCsIcZ1sEeSwthN8KZH7OTKoIZJ3ah4t5DNiqxsSuCk,8344
+adv_optm/util/NNMF.py,sha256=yRf5IP5Sjq0Uf0DxN0Q8NxEGSdD-f1ULziLVDOjY8K4,639
+adv_optm/util/One_Bit_Boolean.py,sha256=Wat49esdwohuN-OHOFMW8D0aOQgV9cP5Rl8z6yfmpos,1068
+adv_optm/util/OrthoGrad.py,sha256=NzInuBQGy_Ja__M1R9XbvqVaQ0fhGbtGgFE9YON7B3I,707
+adv_optm/util/__init__.py,sha256=qoyIF0jcLjs_vSEcsv36clw5LFNBEbifyXrrVxMH-G4,349
+adv_optm-1.1.0.dist-info/licenses/LICENSE,sha256=HrhfyXIkWY2tGFK11kg7vPCqhgh5DcxleloqdhrpyMY,11558
+adv_optm-1.1.0.dist-info/METADATA,sha256=dwRwKQykba-7TP6a94qpOg6xz450QESAi5E8AnEV-iM,8422
+adv_optm-1.1.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+adv_optm-1.1.0.dist-info/top_level.txt,sha256=iNfBIIzu-lPrQ7jyC56WBCcbkRwitM2nJ15-MRQ_6fg,9
+adv_optm-1.1.0.dist-info/RECORD,,
adv_optm-1.0.6.dist-info/RECORD
DELETED
@@ -1,19 +0,0 @@
-adv_optm/__init__.py,sha256=dAbueuVEIGoYrYXx8UE4ATfFBH5wEKrpkXGPTjFH0r0,306
-adv_optm/optim/AdamW_adv.py,sha256=aTuYcJgd_EcZOrs6TDgBrBKw3wtU5LPzE5WvTBDDeEo,14317
-adv_optm/optim/Adopt_adv.py,sha256=FTpDDSlYruZDt1VVLgEI_bADiO8f26j-utQs7Gn2fFA,18108
-adv_optm/optim/Lion_Prodigy_adv.py,sha256=sGzhts9a6gHfCkuHTB5L9IrClo4c6UThzYYErBwqOaA,12844
-adv_optm/optim/Lion_adv.py,sha256=6G1CukJB_pC7l9HwFEuY1ydsNHZFabVmOvcHDsHHVuQ,8295
-adv_optm/optim/Prodigy_adv.py,sha256=G8xXLO9YBeLb9574uS0HpdY9w3ojblaV-PJFghUnToQ,22493
-adv_optm/optim/Simplified_AdEMAMix.py,sha256=tb3d6Cw_nGwcTzYUhDnKqyP7GzjD1hn8k4WqGG5lhmw,9813
-adv_optm/optim/__init__.py,sha256=pcP865H2j1tut2VfTUhzQh7V8TF_tzPjqFnjMfFed2k,382
-adv_optm/util/BF16_Stochastic_Rounding.py,sha256=Q5H0BcogmE4atP65dLoI21HKSf50lRdsBDfeF6v9Tbg,1548
-adv_optm/util/Effective_Shape.py,sha256=TBvIk1V8IuTbbBsxuekJA4e_v8JlR5Nujtut8RTWAm4,318
-adv_optm/util/NNMF.py,sha256=yRf5IP5Sjq0Uf0DxN0Q8NxEGSdD-f1ULziLVDOjY8K4,639
-adv_optm/util/One_Bit_Boolean.py,sha256=Wat49esdwohuN-OHOFMW8D0aOQgV9cP5Rl8z6yfmpos,1068
-adv_optm/util/OrthoGrad.py,sha256=NzInuBQGy_Ja__M1R9XbvqVaQ0fhGbtGgFE9YON7B3I,707
-adv_optm/util/__init__.py,sha256=qoyIF0jcLjs_vSEcsv36clw5LFNBEbifyXrrVxMH-G4,349
-adv_optm-1.0.6.dist-info/licenses/LICENSE,sha256=HrhfyXIkWY2tGFK11kg7vPCqhgh5DcxleloqdhrpyMY,11558
-adv_optm-1.0.6.dist-info/METADATA,sha256=3PslWXH0ysoiXU83vN3F9kWRw48fwUM4H1z1tMyEGvI,8422
-adv_optm-1.0.6.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-adv_optm-1.0.6.dist-info/top_level.txt,sha256=iNfBIIzu-lPrQ7jyC56WBCcbkRwitM2nJ15-MRQ_6fg,9
-adv_optm-1.0.6.dist-info/RECORD,,
{adv_optm-1.0.6.dist-info → adv_optm-1.1.0.dist-info}/WHEEL: file without changes
{adv_optm-1.0.6.dist-info → adv_optm-1.1.0.dist-info}/licenses/LICENSE: file without changes
{adv_optm-1.0.6.dist-info → adv_optm-1.1.0.dist-info}/top_level.txt: file without changes