adv-optm 1.0.6__py3-none-any.whl → 1.1.0.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- adv_optm/__init__.py +1 -1
- adv_optm/optim/AdamW_adv.py +59 -7
- adv_optm/optim/Adopt_adv.py +58 -7
- adv_optm/optim/Prodigy_adv.py +62 -12
- adv_optm/optim/Simplified_AdEMAMix.py +53 -1
- adv_optm/util/Kourkoutas.py +134 -0
- {adv_optm-1.0.6.dist-info → adv_optm-1.1.0.dev2.dist-info}/METADATA +1 -1
- adv_optm-1.1.0.dev2.dist-info/RECORD +20 -0
- adv_optm-1.0.6.dist-info/RECORD +0 -19
- {adv_optm-1.0.6.dist-info → adv_optm-1.1.0.dev2.dist-info}/WHEEL +0 -0
- {adv_optm-1.0.6.dist-info → adv_optm-1.1.0.dev2.dist-info}/licenses/LICENSE +0 -0
- {adv_optm-1.0.6.dist-info → adv_optm-1.1.0.dev2.dist-info}/top_level.txt +0 -0
adv_optm/__init__.py
CHANGED
adv_optm/optim/AdamW_adv.py
CHANGED
@@ -1,11 +1,12 @@
 import torch
-from typing import Optional
+from typing import Optional, Callable
 
 from ..util.BF16_Stochastic_Rounding import add_stochastic_
 from ..util.Effective_Shape import _get_effective_shape
 from ..util.NNMF import _nnmf,_unnmf
 from ..util.OrthoGrad import _orthogonalize_gradient
 from ..util.One_Bit_Boolean import _pack_bools, _unpack_bools
+from ..util.Kourkoutas import KourkoutasHelper
 
 class AdamW_adv(torch.optim.Optimizer):
     """
@@ -54,6 +55,28 @@ class AdamW_adv(torch.optim.Optimizer):
             as it gradually introduces the stabilizing slow momentum term. During
             the warmup, `alpha` ramps from 0 to its target value. If `None`,
             the scheduler is disabled. (default: None)
+        kourkoutas_beta (bool): whether to enable the layer-wise dynamic β₂ logic.
+            If `False`, the optimizer behaves as standard AdamW. (default: False)
+        beta2_min (float): The minimum value for dynamic β₂, used during periods of
+            high gradient variance ("sunspikes"). Must be less than `betas[1]`.
+            (default: 0.88)
+        ema_alpha (float): The decay rate for the Exponential Moving Average (EMA) of
+            the pooled gradient norms. Corresponds to `α` in the paper.
+            (default: 0.93)
+        tiny_spike (float): A small constant added to the denominator of the
+            "sunspike" ratio calculation to prevent division by zero. Corresponds
+            to `ε_spike` in the paper. (default: 1e-9)
+        k_warmup_steps (int): The number of initial steps during which β₂ is held
+            at a fixed average value (`(beta2_min + beta2_max) / 2`) before the
+            dynamic logic activates. (default: 0)
+        k_logging (int): if > 0 and kourkoutas_beta=True, enables periodic console
+            logging of Kourkoutas-β statistics (min, max, mean of `β₂` across layers)
+            every `k_logging` steps. Useful for debugging and tuning. Set to 0 to
+            disable logging. (default: 0)
+        layer_key_fn (Optional[Callable]): A function that takes a parameter `p`
+            and returns a unique, hashable key representing its "layer" or "bucket".
+            If `None`, parameters are bucketed by their memory ID (tensor-wise).
+            (default: None)
         nnmf_factor (bool): whether to use the factorization or disable it to use
             the uncompressed optimizer. (default: False)
     """
@@ -76,6 +99,13 @@ class AdamW_adv(torch.optim.Optimizer):
         beta3_ema: float = 0.9999,
         alpha: float = 5.0,
         t_alpha: int | None = None,
+        kourkoutas_beta: bool = False,
+        beta2_min: float = 0.88,
+        ema_alpha: float = 0.93,
+        tiny_spike: float = 1e-9,
+        k_warmup_steps: int = 0,
+        k_logging: int = 0,
+        layer_key_fn: Optional[Callable] = None,
         nnmf_factor: bool = False,
     ):
         if not (lr >= 0.0):
@@ -86,6 +116,8 @@ class AdamW_adv(torch.optim.Optimizer):
             raise ValueError(f"Epsilon should be >= 0.0. Got {eps}")
         if not (weight_decay >= 0.0):
             raise ValueError(f"Weight-decay should be >= 0.0. Got {weight_decay}")
+        if kourkoutas_beta and not (betas[1] > beta2_min): raise ValueError(f"For Kourkoutas-β, betas[1] (as beta2_max) must be > beta2_min. Got {betas[1]} and {beta2_min}")
+
         if cautious_mask and grams_moment:
             print("Warning: cautious is incompatible with grams, Disabling cautious.")
             cautious_mask = False
@@ -95,14 +127,21 @@ class AdamW_adv(torch.optim.Optimizer):
             "vector_reshape": vector_reshape, "use_atan2": use_atan2,
             "orthogonal_gradient": orthogonal_gradient, "use_bias_correction": use_bias_correction,
             "beta3_ema": beta3_ema, "alpha": alpha, "t_alpha": t_alpha,
+            "kourkoutas_beta": kourkoutas_beta, "beta2_min": beta2_min, "ema_alpha": ema_alpha,
+            "tiny_spike": tiny_spike, "k_warmup_steps": k_warmup_steps, "k_logging": k_logging,
         }
         self.stochastic_rounding = stochastic_rounding
         self.cautious_mask = cautious_mask
         self.grams_moment = grams_moment
         self.use_AdEMAMix = use_AdEMAMix
         self.factored = nnmf_factor
+        self.kourkoutas_beta = kourkoutas_beta
+        self.layer_key_fn = layer_key_fn
         super().__init__(params, defaults)
 
+        if self.kourkoutas_beta:
+            self.kourkoutas_helper = KourkoutasHelper(self)
+
     @property
     def supports_fused_back_pass(self):
         return True
@@ -127,8 +166,6 @@ class AdamW_adv(torch.optim.Optimizer):
             grad = _orthogonalize_gradient(p, grad)
         state = self.state[p]
 
-        beta1, beta2 = group['betas']
-
         # State Initialization
         if len(state) == 0:
             state['step'] = 0
@@ -148,7 +185,7 @@ class AdamW_adv(torch.optim.Optimizer):
                 d1, d2 = state['effective_shape']
 
                 # First moment (m)
-                if
+                if group['betas'][0] > 0:
                     state['mu_m_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
                     state['mv_m_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
                     if not self.grams_moment:
@@ -163,16 +200,31 @@ class AdamW_adv(torch.optim.Optimizer):
                     state['mu_v_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
                     state['mv_v_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
             else:  # Fallback to standard AdamW for non-factored tensors
-                if
+                if group['betas'][0] > 0:
                     state['exp_avg'] = torch.zeros_like(p, device=device, dtype=dtype)
                     if self.use_AdEMAMix:
                         state['exp_avg_slow'] = torch.zeros_like(p, device=device, dtype=dtype)
                 state['exp_avg_sq'] = torch.zeros_like(p, device=device, dtype=dtype)
 
+        beta1, beta2 = group['betas']
+
+        current_step = state['step']
+        if group['kourkoutas_beta']:
+            # Call prepare_step() once at the beginning of the step for all params
+            self.kourkoutas_helper.maybe_prepare_step(current_step)
+            # Accumulate current grad's norm for the *next* step
+            self.kourkoutas_helper.accumulate_gradient_sq_norm(p, grad)
+            # Get the dynamic beta2 calculated in prepare_step()
+            beta2 = self.kourkoutas_helper.get_beta2(p, group, current_step)
+
         step = state['step'] + 1
         if group['use_bias_correction']:
             bias_correction1 = 1.0 - beta1 ** step
-
+            if group['kourkoutas_beta']:
+                bias_correction2 = 1.0 - group['betas'][1] ** step
+                # Use beta2_max for bias correction
+            else:
+                bias_correction2 = 1.0 - beta2 ** step
         else:
             bias_correction1 = 1
             bias_correction2 = 1
@@ -315,4 +367,4 @@ class AdamW_adv(torch.optim.Optimizer):
             for i, p in enumerate(group['params']):
                 self.step_parameter(p, group, i)
 
-        return loss
+        return loss
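For orientation, a minimal usage sketch of the new option on `AdamW_adv`. The top-level `from adv_optm import AdamW_adv` import, the toy model, and the learning rate are illustrative assumptions; every Kourkoutas-specific argument is taken directly from the diff above.

```python
import torch
from adv_optm import AdamW_adv  # assumption: the package re-exports AdamW_adv at top level

model = torch.nn.Linear(128, 64)

# Layer-wise dynamic beta2: with layer_key_fn=None parameters are bucketed
# per tensor (by id), and beta2 statistics are printed every 100 steps.
opt = AdamW_adv(
    model.parameters(),
    lr=1e-3,                # illustrative value
    kourkoutas_beta=True,   # enable the dynamic beta2 logic
    beta2_min=0.88,         # lower bound reached during gradient-norm spikes
    k_warmup_steps=100,     # hold beta2 fixed for the first 100 steps
    k_logging=100,          # log min/max/mean beta2 every 100 steps
)

loss = model(torch.randn(8, 128)).pow(2).mean()
loss.backward()
opt.step()
opt.zero_grad()
```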
adv_optm/optim/Adopt_adv.py
CHANGED
@@ -6,6 +6,7 @@ from ..util.Effective_Shape import _get_effective_shape
 from ..util.NNMF import _nnmf, _unnmf
 from ..util.OrthoGrad import _orthogonalize_gradient
 from ..util.One_Bit_Boolean import _pack_bools, _unpack_bools
+from ..util.Kourkoutas import KourkoutasHelper
 
 class Adopt_adv(torch.optim.Optimizer):
     """
@@ -72,6 +73,28 @@ class Adopt_adv(torch.optim.Optimizer):
             current gradient. For small batch sizes, use high values (e.g., 10-100) to be
             more responsive. For large batch sizes, use low values (e.g., 0-1) for
             stability. (default: 100.0)
+        kourkoutas_beta (bool): whether to enable the layer-wise dynamic β₂ logic.
+            If `False`, the optimizer behaves as standard Adopt. (default: False)
+        beta2_min (float): The minimum value for dynamic β₂, used during periods of
+            high gradient variance ("sunspikes"). Must be less than `betas[1]`.
+            (default: 0.88)
+        ema_alpha (float): The decay rate for the Exponential Moving Average (EMA) of
+            the pooled gradient norms. Corresponds to `α` in the paper.
+            (default: 0.93)
+        tiny_spike (float): A small constant added to the denominator of the
+            "sunspike" ratio calculation to prevent division by zero. Corresponds
+            to `ε_spike` in the paper. (default: 1e-9)
+        k_warmup_steps (int): The number of initial steps during which β₂ is held
+            at a fixed average value (`(beta2_min + beta2_max) / 2`) before the
+            dynamic logic activates. (default: 0)
+        k_logging (int): if > 0 and kourkoutas_beta=True, enables periodic console
+            logging of Kourkoutas-β statistics (min, max, mean of `β₂` across layers)
+            every `k_logging` steps. Useful for debugging and tuning. Set to 0 to
+            disable logging. (default: 0)
+        layer_key_fn (Optional[Callable]): A function that takes a parameter `p`
+            and returns a unique, hashable key representing its "layer" or "bucket".
+            If `None`, parameters are bucketed by their memory ID (tensor-wise).
+            (default: None)
         nnmf_factor (bool): whether to use the factorization or disable it to use
             the uncompressed optimizer. (default: False)
     """
@@ -96,6 +119,13 @@ class Adopt_adv(torch.optim.Optimizer):
         t_alpha: int | None = None,
         Simplified_AdEMAMix: bool = False,
         alpha_grad: float = 100.0,
+        kourkoutas_beta: bool = False,
+        beta2_min: float = 0.88,
+        ema_alpha: float = 0.93,
+        tiny_spike: float = 1e-9,
+        k_warmup_steps: int = 0,
+        k_logging: int = 0,
+        layer_key_fn: Optional[Callable] = None,
         nnmf_factor: bool = False,
     ):
         if not (lr >= 0.0):
@@ -111,6 +141,7 @@ class Adopt_adv(torch.optim.Optimizer):
             cautious_mask = False
         if betas[0] == 0.0 and Simplified_AdEMAMix:
             raise ValueError(f"Beta1 cannot be 0.0 when using Simplified_AdEMAMix. Got {betas[0]}")
+        if kourkoutas_beta and not (betas[1] > beta2_min): raise ValueError(f"For Kourkoutas-β, betas[1] (as beta2_max) must be > beta2_min. Got {betas[1]} and {beta2_min}")
         if use_AdEMAMix and Simplified_AdEMAMix:
             print("Warning: use_AdEMAMix is incompatible with Simplified_AdEMAMix, Disabling use_AdEMAMix.")
         if grams_moment and Simplified_AdEMAMix:
@@ -125,6 +156,8 @@ class Adopt_adv(torch.optim.Optimizer):
             "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
             "vector_reshape": vector_reshape, "beta3_ema": beta3_ema, "alpha": alpha,
             "t_alpha": t_alpha, "alpha_grad": alpha_grad,
+            "kourkoutas_beta": kourkoutas_beta, "beta2_min": beta2_min, "ema_alpha": ema_alpha,
+            "tiny_spike": tiny_spike, "k_warmup_steps": k_warmup_steps, "k_logging": k_logging,
         }
         self.clip_lambda = clip_lambda
         self.stochastic_rounding = stochastic_rounding
@@ -135,8 +168,13 @@ class Adopt_adv(torch.optim.Optimizer):
         self.use_AdEMAMix = use_AdEMAMix and not Simplified_AdEMAMix
         self.Simplified_AdEMAMix = Simplified_AdEMAMix
         self.factored = nnmf_factor
+        self.kourkoutas_beta = kourkoutas_beta
+        self.layer_key_fn = layer_key_fn
         super().__init__(params, defaults)
 
+        if self.kourkoutas_beta:
+            self.kourkoutas_helper = KourkoutasHelper(self)
+
     @property
     def supports_fused_back_pass(self): return True
     @property
@@ -156,8 +194,6 @@ class Adopt_adv(torch.optim.Optimizer):
             grad = _orthogonalize_gradient(p, grad)
         state = self.state[p]
 
-        beta1, beta2 = group['betas']
-
         # State Initialization
         if len(state) == 0:
             state['step'] = 0
@@ -176,7 +212,7 @@ class Adopt_adv(torch.optim.Optimizer):
                 d1, d2 = state['effective_shape']
 
                 # m_0 = 0
-                if
+                if group['betas'][0] > 0:
                     state['mu_m_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
                     state['mv_m_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
                     if not self.grams_moment:
@@ -195,12 +231,23 @@ class Adopt_adv(torch.optim.Optimizer):
                 # Initialize v_0 using NMF
                 _nnmf(vt_init, out=(state['mu_v_nmf'], state['mv_v_nmf']))
             else:  # Fallback for non-factored tensors
-                if
+                if group['betas'][0] > 0:
                     state['exp_avg'] = torch.zeros_like(p, dtype=dtype)  # m_0
                     if self.use_AdEMAMix:
                         state['exp_avg_slow'] = torch.zeros_like(p, dtype=dtype)
                 state['exp_avg_sq'] = grad.square()  # v_0
 
+        beta1, beta2 = group['betas']
+
+        current_step = state['step']
+        if group['kourkoutas_beta']:
+            # Call prepare_step() once at the beginning of the step for all params
+            self.kourkoutas_helper.maybe_prepare_step(current_step)
+            # Accumulate current grad's norm for the *next* step
+            self.kourkoutas_helper.accumulate_gradient_sq_norm(p, grad)
+            # Get the dynamic beta2 calculated in prepare_step()
+            beta2 = self.kourkoutas_helper.get_beta2(p, group, current_step)
+
         # The first step is for initialization only (skip when use_atan2 as it's scale invariant).
         if state['step'] == 0 and not self.use_atan2:
             state['step'] += 1
@@ -211,10 +258,10 @@ class Adopt_adv(torch.optim.Optimizer):
             alpha = group['alpha']
             t_alpha = group['t_alpha']
             # Use step+1 for 1-based step count in scheduler
-
+            alpha_step = state['step'] + 1
             alpha_t = alpha
-            if t_alpha is not None and t_alpha > 0 and
-                alpha_t = min(
+            if t_alpha is not None and t_alpha > 0 and alpha_step < t_alpha:
+                alpha_t = min(alpha_step * alpha / t_alpha, alpha)
         if self.Simplified_AdEMAMix:
             alpha_grad = group["alpha_grad"]
 
@@ -386,4 +433,8 @@ class Adopt_adv(torch.optim.Optimizer):
             for i, p in enumerate(group['params']):
                 self.step_parameter(p, group, i)
 
+        if self.kourkoutas_beta and self.k_logging > 0 and hasattr(self, '_beta2_log'):
+            first_param_state = self.state[self.param_groups[0]['params'][0]]
+            step_num = first_param_state['step']
+
         return loss
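The `layer_key_fn` hook documented above is what turns the default tensor-wise bucketing into true layer-wise bucketing. Below is a hedged sketch of one way to build such a function; `make_layer_key_fn` is a hypothetical helper and not part of the package, but the contract (parameter in, hashable key out) matches the docstring.

```python
def make_layer_key_fn(model):
    """Map each parameter tensor to the name of its top-level submodule,
    so all tensors of one block share a single Kourkoutas-β bucket."""
    param_to_key = {}
    for module_name, module in model.named_children():
        for p in module.parameters(recurse=True):
            param_to_key[id(p)] = module_name
    # Fall back to tensor-wise bucketing (id) for anything not covered above.
    return lambda p: param_to_key.get(id(p), id(p))

# Illustrative use with one of the optimizers from this diff:
# opt = Adopt_adv(model.parameters(), kourkoutas_beta=True,
#                 layer_key_fn=make_layer_key_fn(model))
```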
adv_optm/optim/Prodigy_adv.py
CHANGED
@@ -3,11 +3,14 @@ import torch.distributed as dist
 
 import math
 
+from typing import Optional, Callable
+
 from ..util.BF16_Stochastic_Rounding import add_stochastic_
 from ..util.Effective_Shape import _get_effective_shape
 from ..util.NNMF import _nnmf,_unnmf
 from ..util.OrthoGrad import _orthogonalize_gradient
 from ..util.One_Bit_Boolean import _pack_bools, _unpack_bools
+from ..util.Kourkoutas import KourkoutasHelper
 
 class Prodigy_adv(torch.optim.Optimizer):
     """
@@ -85,6 +88,28 @@ class Prodigy_adv(torch.optim.Optimizer):
         prodigy_steps (int): If greater than zero, disable Prodigy's stepsize adjustments
             after the specified optimiser step and release all state memory required by Prodigy
             (default: 0).
+        kourkoutas_beta (bool): whether to enable the layer-wise dynamic β₂ logic.
+            If `False`, the optimizer behaves as standard AdamW/Prodigy. (default: False)
+        beta2_min (float): The minimum value for dynamic β₂, used during periods of
+            high gradient variance ("sunspikes"). Must be less than `betas[1]`.
+            (default: 0.88)
+        ema_alpha (float): The decay rate for the Exponential Moving Average (EMA) of
+            the pooled gradient norms. Corresponds to `α` in the paper.
+            (default: 0.93)
+        tiny_spike (float): A small constant added to the denominator of the
+            "sunspike" ratio calculation to prevent division by zero. Corresponds
+            to `ε_spike` in the paper. (default: 1e-9)
+        k_warmup_steps (int): The number of initial steps during which β₂ is held
+            at a fixed average value (`(beta2_min + beta2_max) / 2`) before the
+            dynamic logic activates. (default: 0)
+        k_logging (int): if > 0 and kourkoutas_beta=True, enables periodic console
+            logging of Kourkoutas-β statistics (min, max, mean of `β₂` across layers)
+            every `k_logging` steps. Useful for debugging and tuning. Set to 0 to
+            disable logging. (default: 0)
+        layer_key_fn (Optional[Callable]): A function that takes a parameter `p`
+            and returns a unique, hashable key representing its "layer" or "bucket".
+            If `None`, parameters are bucketed by their memory ID (tensor-wise).
+            (default: None)
     """
 
     def __init__(
@@ -116,6 +141,13 @@ class Prodigy_adv(torch.optim.Optimizer):
         fsdp_in_use: bool = False,
         slice_p: int = 11,
         prodigy_steps: int = 0,
+        kourkoutas_beta: bool = False,
+        beta2_min: float = 0.88,
+        ema_alpha: float = 0.93,
+        tiny_spike: float = 1e-9,
+        k_warmup_steps: int = 0,
+        k_logging: int = 0,
+        layer_key_fn: Optional[Callable] = None,
     ):
         if not (lr >= 0.0):
             raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
@@ -141,6 +173,8 @@ class Prodigy_adv(torch.optim.Optimizer):
         if use_atan2 and Simplified_AdEMAMix:
             print("Warning: use_atan2 is incompatible with Simplified_AdEMAMix. Disabling use_atan2.")
             use_atan2 = False
+        if kourkoutas_beta and not (betas[1] > beta2_min):
+            raise ValueError(f"For Kourkoutas-β, betas[1] (as beta2_max) must be > beta2_min. Got {betas[1]} and {beta2_min}")
         if Simplified_AdEMAMix and alpha_grad > 0:
             # scales d_coef by alpha_grad, this force prodigy to behave well with Simplified_AdEMAMix
             d_coef = d_coef/alpha_grad
@@ -153,7 +187,9 @@ class Prodigy_adv(torch.optim.Optimizer):
             "beta3": beta3, "d": d0, "d0": d0, "d_max": d0, "d_numerator": 0.0, "d_coef": d_coef,
             "growth_rate": growth_rate, "safeguard_warmup": safeguard_warmup, "k": 0, "slice_p": slice_p,
             "fsdp_in_use": fsdp_in_use, "prodigy_steps": prodigy_steps,
-            "alpha_grad": alpha_grad,
+            "alpha_grad": alpha_grad,
+            "kourkoutas_beta": kourkoutas_beta, "beta2_min": beta2_min, "ema_alpha": ema_alpha,
+            "tiny_spike": tiny_spike, "k_warmup_steps": k_warmup_steps, "k_logging": k_logging,
         }
         self.stochastic_rounding = stochastic_rounding
         self.cautious_mask = cautious_mask and not Simplified_AdEMAMix
@@ -162,7 +198,13 @@ class Prodigy_adv(torch.optim.Optimizer):
         self.Simplified_AdEMAMix = Simplified_AdEMAMix
         self.factored = nnmf_factor
         self.fsdp_in_use = fsdp_in_use
+
+        self.kourkoutas_beta = kourkoutas_beta
+        self.layer_key_fn = layer_key_fn
+
         super().__init__(params, defaults)
+        if self.kourkoutas_beta:
+            self.kourkoutas_helper = KourkoutasHelper(self)
         self.init_step()
 
     @property
@@ -180,19 +222,17 @@ class Prodigy_adv(torch.optim.Optimizer):
     def init_step(self):
         """Resets accumulators and calculates dlr for the upcoming step."""
         self.d_denom = 0.0
-
+
         g_group = self.param_groups[0]
-        self.beta1, self.
+        self.beta1, self.beta2_default = g_group['betas']
         self.beta3 = g_group['beta3']
         if self.beta3 is None:
-            self.beta3 = math.sqrt(self.
+            self.beta3 = math.sqrt(self.beta2_default)
 
-        k = g_group['k']
         self.d = g_group['d']
         lr = g_group['lr']
 
         self.dlr = self.d * lr
-
         self.d_numerator = g_group.get('d_numerator', 0.0) * self.beta3
 
     @torch.no_grad()
@@ -258,14 +298,25 @@ class Prodigy_adv(torch.optim.Optimizer):
         else:
             state['p0'] = torch.tensor(0, device=device, dtype=p.dtype)
 
+        current_step = state['step']
+        if group['kourkoutas_beta']:
+            # Call prepare_step() once at the beginning of the step for all params
+            self.kourkoutas_helper.maybe_prepare_step(current_step)
+            # Accumulate current grad's norm for the *next* step
+            self.kourkoutas_helper.accumulate_gradient_sq_norm(p, grad)
+            # Get the dynamic beta2 calculated in prepare_step()
+            beta2 = self.kourkoutas_helper.get_beta2(p, group, current_step)
+        else:
+            beta2 = self.beta2_default
+
         if self.use_AdEMAMix:
             beta3_ema = group['beta3_ema']
             alpha = group['alpha']
             t_alpha = group['t_alpha']
-
+            alpha_step = state['step'] + 1
             alpha_t = alpha
-            if t_alpha is not None and t_alpha > 0 and
-                alpha_t = min(
+            if t_alpha is not None and t_alpha > 0 and alpha_step < t_alpha:
+                alpha_t = min(alpha_step * alpha / t_alpha, alpha)
         if self.Simplified_AdEMAMix:
             alpha_grad = group["alpha_grad"]
 
@@ -295,7 +346,7 @@ class Prodigy_adv(torch.optim.Optimizer):
                     del mask
 
                 vt = _unnmf((state['mu_v_nmf'], state['mv_v_nmf']))
-                vt.mul_(
+                vt.mul_(beta2).addcmul_(grad_reshaped, grad_reshaped, value=self.d * self.d * (1.0 - beta2))
 
                 if self.use_AdEMAMix:
                     mt_slow = _unnmf((state['mu_m_slow_nmf'], state['mv_m_slow_nmf']))
@@ -368,7 +419,7 @@ class Prodigy_adv(torch.optim.Optimizer):
             else:
                 update = exp_avg.clone() if self.beta1 > 0 else grad.mul(self.d)
 
-            exp_avg_sq.mul_(
+            exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=self.d * self.d * (1.0 - beta2))
 
             if group['use_atan2']:
                 a = 1.2732395
@@ -431,7 +482,6 @@ class Prodigy_adv(torch.optim.Optimizer):
             for i, p in enumerate(group['params']):
                 self.step_parameter(p, group, i)
 
-
         self.calculate_d()
         self.init_step()
         return loss
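All four optimizers in this release wire the helper into their per-parameter step in the same way. The function below is a sketch of that shared call order only, under the assumption that `optimizer.kourkoutas_helper` and a per-parameter `'step'` state exist as in the diffs above; it is not code from the package.

```python
import torch

def kourkoutas_training_step(optimizer, helper, loss_fn):
    """Sketch of the call order the diff establishes inside each optimizer:
    prepare once per step, accumulate per parameter, then fetch beta2."""
    loss = loss_fn()
    loss.backward()
    for group in optimizer.param_groups:
        for p in group['params']:
            if p.grad is None:
                continue
            step = optimizer.state[p].get('step', 0)
            helper.maybe_prepare_step(step)                 # runs the per-layer beta2 refresh once
            helper.accumulate_gradient_sq_norm(p, p.grad)   # feeds the *next* step's refresh
            beta2 = helper.get_beta2(p, group, step)        # per-layer value used *this* step
            # ...the optimizer then uses `beta2` in its second-moment EMA...
    return loss
```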
adv_optm/optim/Simplified_AdEMAMix.py
CHANGED
@@ -1,4 +1,5 @@
 import torch
+from typing import Optional, Callable
 
 import math
 
@@ -7,6 +8,7 @@ from ..util.Effective_Shape import _get_effective_shape
 from ..util.NNMF import _nnmf,_unnmf
 from ..util.OrthoGrad import _orthogonalize_gradient
 from ..util.One_Bit_Boolean import _pack_bools, _unpack_bools
+from ..util.Kourkoutas import KourkoutasHelper
 
 # A little helper from the original simplified_AdEMAMix
 def linear_hl_warmup_scheduler(step, beta_end, beta_start=0, warmup=1):
@@ -47,6 +49,28 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
         stochastic_rounding (bool): whether to use stochastic
             rounding for BF16 parameter updates (default: True).
         orthogonal_gradient (bool): whether to use OrthoGrad. (default: False)
+        kourkoutas_beta (bool): whether to enable the layer-wise dynamic β₂ logic.
+            If `False`, the optimizer behaves as standard Simplified_AdEMAMix. (default: False)
+        beta2_min (float): The minimum value for dynamic β₂, used during periods of
+            high gradient variance ("sunspikes"). Must be less than `betas[1]`.
+            (default: 0.88)
+        ema_alpha (float): The decay rate for the Exponential Moving Average (EMA) of
+            the pooled gradient norms. Corresponds to `α` in the paper.
+            (default: 0.93)
+        tiny_spike (float): A small constant added to the denominator of the
+            "sunspike" ratio calculation to prevent division by zero. Corresponds
+            to `ε_spike` in the paper. (default: 1e-9)
+        k_warmup_steps (int): The number of initial steps during which β₂ is held
+            at a fixed average value (`(beta2_min + beta2_max) / 2`) before the
+            dynamic logic activates. (default: 0)
+        k_logging (int): if > 0 and kourkoutas_beta=True, enables periodic console
+            logging of Kourkoutas-β statistics (min, max, mean of `β₂` across layers)
+            every `k_logging` steps. Useful for debugging and tuning. Set to 0 to
+            disable logging. (default: 0)
+        layer_key_fn (Optional[Callable]): A function that takes a parameter `p`
+            and returns a unique, hashable key representing its "layer" or "bucket".
+            If `None`, parameters are bucketed by their memory ID (tensor-wise).
+            (default: None)
         nnmf_factor (bool): whether to use the factorization or disable it to use
             the uncompressed optimizer. (default: False)
     """
@@ -65,6 +89,13 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
         vector_reshape: bool = True,
         stochastic_rounding: bool = True,
         orthogonal_gradient: bool = False,
+        kourkoutas_beta: bool = False,
+        beta2_min: float = 0.88,
+        ema_alpha: float = 0.93,
+        tiny_spike: float = 1e-9,
+        k_warmup_steps: int = 0,
+        k_logging: int = 0,
+        layer_key_fn: Optional[Callable] = None,
         nnmf_factor: bool = False,
     ):
         if not (lr >= 0.0):
@@ -77,17 +108,25 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
             raise ValueError(f"Weight-decay should be >= 0.0. Got {weight_decay}")
         if not 0.0 <= alpha_grad:
             raise ValueError("Invalid alpha value: {}".format(alpha_grad))
+        if kourkoutas_beta and not (betas[1] > beta2_min): raise ValueError(f"For Kourkoutas-β, betas[1] (as beta2_max) must be > beta2_min. Got {betas[1]} and {beta2_min}")
 
         defaults = {
             "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
             "alpha_grad": alpha_grad, "beta1_warmup": beta1_warmup, "min_beta1": min_beta1,
             "vector_reshape": vector_reshape,
             "orthogonal_gradient": orthogonal_gradient, "use_bias_correction": use_bias_correction,
+            "kourkoutas_beta": kourkoutas_beta, "beta2_min": beta2_min, "ema_alpha": ema_alpha,
+            "tiny_spike": tiny_spike, "k_warmup_steps": k_warmup_steps, "k_logging": k_logging,
         }
         self.stochastic_rounding = stochastic_rounding
         self.factored = nnmf_factor
+        self.kourkoutas_beta = kourkoutas_beta
+        self.layer_key_fn = layer_key_fn
         super().__init__(params, defaults)
 
+        if self.kourkoutas_beta:
+            self.kourkoutas_helper = KourkoutasHelper(self)
+
     @property
     def supports_fused_back_pass(self):
         return True
@@ -150,6 +189,16 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
             state['den_sum'] = 1.0
 
         beta1_final, beta2 = group["betas"]
+
+        current_step = state['step']
+        if group['kourkoutas_beta']:
+            # Call prepare_step() once at the beginning of the step for all params
+            self.kourkoutas_helper.maybe_prepare_step(current_step)
+            # Accumulate current grad's norm for the *next* step
+            self.kourkoutas_helper.accumulate_gradient_sq_norm(p, grad)
+            # Get the dynamic beta2 calculated in prepare_step()
+            beta2 = self.kourkoutas_helper.get_beta2(p, group, current_step)
+
         beta1_warmup = group["beta1_warmup"]
         alpha_grad = group["alpha_grad"]
 
@@ -161,7 +210,10 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
 
         if group['use_bias_correction']:
             state['num_sum'] = beta1 * state['num_sum'] + 1.0
-
+            if group['kourkoutas_beta']:
+                state['den_sum'] = group['betas'][1] * state['den_sum'] + (1.0 - group['betas'][1])
+            else:
+                state['den_sum'] = beta2 * state['den_sum'] + (1.0 - beta2)
 
         if state['factored']:
             d1, d2 = state['effective_shape']
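One detail shared by `AdamW_adv` and `Simplified_AdEMAMix` above: with Kourkoutas-β active, the bias correction (respectively `den_sum`) is driven by the fixed upper bound `betas[1]` rather than the per-layer dynamic value, presumably so the correction factor stays a simple monotone function of the step count. A self-contained sketch of that pattern, with assumed example numbers:

```python
def bias_correction2(step: int, beta2_dynamic: float, beta2_max: float, kourkoutas: bool) -> float:
    # Mirrors the pattern in the diff: with Kourkoutas-β the correction uses the
    # constant beta2_max (= betas[1]), not the per-layer dynamic beta2.
    beta2 = beta2_max if kourkoutas else beta2_dynamic
    return 1.0 - beta2 ** step

print(bias_correction2(10, beta2_dynamic=0.91, beta2_max=0.999, kourkoutas=True))   # ~0.00996
print(bias_correction2(10, beta2_dynamic=0.91, beta2_max=0.999, kourkoutas=False))  # ~0.6106
```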
adv_optm/util/Kourkoutas.py
ADDED
@@ -0,0 +1,134 @@
+import torch
+from torch.optim import Optimizer
+from typing import Callable
+
+class KourkoutasHelper:
+    """
+    A helper class to add layer-wise Kourkoutas-β functionality to a PyTorch optimizer.
+    """
+    def __init__(self, optimizer: Optimizer):
+        # We need a reference to the optimizer to access its param_groups and state
+        if not hasattr(optimizer, 'param_groups'):
+            raise TypeError("optimizer must be a valid torch.optim.Optimizer instance.")
+        self.optimizer = optimizer
+
+        # State managed by the helper
+        self.layer_state = {}
+        self.layer_info = {}
+        self._layer_info_built = False
+        self._current_step_prepared = -1
+
+    def _build_layer_info_if_needed(self):
+        """Builds a map of layers and the parameters they contain."""
+        if self._layer_info_built:
+            return
+
+        if not hasattr(self.optimizer, 'layer_key_fn') or self.optimizer.layer_key_fn is None:
+            print("Warning: KourkoutasHelper requires 'layer_key_fn' on the optimizer. Defaulting to tensor-wise (id).")
+            self.optimizer.layer_key_fn = lambda p: id(p)
+
+        for group in self.optimizer.param_groups:
+            for p in group['params']:
+                if p.grad is None: continue
+                layer_key = self.optimizer.layer_key_fn(p)
+                if layer_key not in self.layer_info:
+                    self.layer_info[layer_key] = {'params': [], 'group_ref': group}
+                self.layer_info[layer_key]['params'].append(p)
+
+        k_logging_interval = self.optimizer.param_groups[0].get('k_logging', 0)
+        if k_logging_interval > 0:
+            print(f"[Kourkoutas-β Debug] Layer info built. Found {len(self.layer_info)} unique layers/buckets.")
+
+        self._layer_info_built = True
+
+    def prepare_step(self, current_step: int):
+        """
+        Calculates dynamic beta2 for all layers using the completed scalar accumulators
+        from the PREVIOUS step. Should be called once at the start of an optimizer step.
+        """
+        self._build_layer_info_if_needed()
+
+        # Check if logging is enabled for this step based on the interval
+        k_logging_interval = self.optimizer.param_groups[0].get('k_logging', 0)
+        is_logging_step = k_logging_interval > 0 and (current_step + 1) % k_logging_interval == 0
+
+        beta2_log = [] if is_logging_step else None
+        first_layer_key = next(iter(self.layer_info), None)
+
+        for layer_key, info in self.layer_info.items():
+            params, group = info['params'], info['group_ref']
+
+            if layer_key not in self.layer_state:
+                self.layer_state[layer_key] = {
+                    'r_ema_grad_norm': torch.tensor(0.0, device=params[0].device, dtype=torch.float32),
+                    'sum_sq_accumulator': torch.tensor(0.0, device=params[0].device, dtype=torch.float32)
+                }
+
+            layer_state = self.layer_state[layer_key]
+
+            # Use the completed accumulator from the previous step
+            pooled_grad_norm = torch.sqrt(layer_state['sum_sq_accumulator'])
+
+            r_ema = layer_state['r_ema_grad_norm']
+            prev_r_ema_val = r_ema.item()  # for logging
+
+            # EMA is always updated, even during warmup
+            r_ema.mul_(group['ema_alpha']).add_(pooled_grad_norm, alpha=1.0 - group['ema_alpha'])
+
+            sun = torch.tensor(0.0, device=r_ema.device)  # Default sun to 0 for warmup
+            beta2_max = group['betas'][1]
+
+            # --- CONSOLIDATED WARMUP LOGIC ---
+            if current_step < group['k_warmup_steps']:
+                beta2 = beta2_max
+            else:
+                raw = pooled_grad_norm / (r_ema + group['tiny_spike'])
+                sun = raw / (1.0 + raw)
+                beta2 = beta2_max - (beta2_max - group['beta2_min']) * sun
+
+            layer_state['dynamic_beta2'] = beta2.item() if isinstance(beta2, torch.Tensor) else beta2
+            layer_state['sum_sq_accumulator'].zero_()
+
+            if is_logging_step:
+                beta2_log.append(layer_state['dynamic_beta2'])
+                if layer_key == first_layer_key:
+                    print(f"\n[Kourkoutas-β Debug] Step {current_step + 1} - Sample Layer '{layer_key}':")
+                    print(f"  - Grad Norm: {pooled_grad_norm.item():.4e}, Prev EMA: {prev_r_ema_val:.4e}, New EMA: {r_ema.item():.4e}")
+                    print(f"  - Sunspike: {sun.item():.4f}, Dynamic Beta2: {layer_state['dynamic_beta2']:.4f}")
+
+        if is_logging_step and beta2_log:
+            beta2_tensor = torch.tensor(beta2_log, device='cpu')
+            print(f"[Kourkoutas-β Debug] Step {current_step + 1} Overall Beta2 Stats: Min={beta2_tensor.min():.4f}, Max={beta2_tensor.max():.4f}, Mean={beta2_tensor.mean():.4f}")
+
+
+    def maybe_prepare_step(self, current_step: int):
+        """
+        A universal guard that calls prepare_step() exactly once per training step.
+        """
+        if self._current_step_prepared < current_step:
+            self.prepare_step(current_step)
+            self._current_step_prepared = current_step
+
+    def accumulate_gradient_sq_norm(self, p: torch.Tensor, grad: torch.Tensor):
+        """
+        Accumulates the squared L2 norm of a single gradient for the next step's calculation.
+        """
+        self._build_layer_info_if_needed()
+        layer_key = self.optimizer.layer_key_fn(p)
+
+        if layer_key in self.layer_info:
+            if layer_key not in self.layer_state:
+                self.layer_state[layer_key] = {
+                    'r_ema_grad_norm': torch.tensor(0.0, device=p.device, dtype=torch.float32),
+                    'sum_sq_accumulator': torch.tensor(0.0, device=p.device, dtype=torch.float32)
+                }
+            # Accumulate for the *next* step's prepare_step call
+            self.layer_state[layer_key]['sum_sq_accumulator'] += torch.sum(grad.detach().pow(2)).float()
+
+    def get_beta2(self, p: torch.Tensor, group: dict, current_step: int) -> float:
+        """
+        Gets the appropriate beta2 for the current parameter, handling warmup and dynamic value fetching.
+        """
+        layer_key = self.optimizer.layer_key_fn(p)
+        # The default is the max value, which is correct for unmapped params or edge cases
+        return self.layer_state.get(layer_key, {}).get('dynamic_beta2', group['betas'][1])
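To make the sunspike formula in `prepare_step` concrete, here is a small plain-Python rework of the arithmetic. The value 0.999 for `beta2_max` is only an illustrative choice; the other constants are the defaults from this release.

```python
beta2_max, beta2_min, ema_alpha, tiny_spike = 0.999, 0.88, 0.93, 1e-9

r_ema  = 1.0   # EMA of this layer's pooled gradient norm so far
pooled = 4.0   # the layer's gradient norm suddenly spikes

# Same order as prepare_step: update the EMA first, then form the ratio.
r_ema = ema_alpha * r_ema + (1 - ema_alpha) * pooled   # 0.93*1.0 + 0.07*4.0 = 1.21
raw   = pooled / (r_ema + tiny_spike)                  # ≈ 3.306
sun   = raw / (1.0 + raw)                              # ≈ 0.768, always in [0, 1)
beta2 = beta2_max - (beta2_max - beta2_min) * sun      # ≈ 0.999 - 0.119*0.768 ≈ 0.908

# A quiet layer (pooled ≈ r_ema) gives sun ≈ 0.5 and beta2 ≈ 0.94;
# a shrinking norm drives sun toward 0 and beta2 back toward beta2_max.
print(round(beta2, 4))
```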
adv_optm-1.1.0.dev2.dist-info/RECORD
ADDED
@@ -0,0 +1,20 @@
+adv_optm/__init__.py,sha256=hkmbLr1AVDoC6VbnyTkNy-G4g5bmcLFH2Kv4dYWp9uY,311
+adv_optm/optim/AdamW_adv.py,sha256=H4XlYZELwiFvXt0A9wMlRNiw9c8rmPMspHDCvR_SZIQ,17487
+adv_optm/optim/Adopt_adv.py,sha256=PJ3ZaLgzYbvxXDS56FGjzMrVMyHDXSWdUPHnX5NpNAA,21241
+adv_optm/optim/Lion_Prodigy_adv.py,sha256=sGzhts9a6gHfCkuHTB5L9IrClo4c6UThzYYErBwqOaA,12844
+adv_optm/optim/Lion_adv.py,sha256=6G1CukJB_pC7l9HwFEuY1ydsNHZFabVmOvcHDsHHVuQ,8295
+adv_optm/optim/Prodigy_adv.py,sha256=-eMTutexbGrUQtSehKaOo6BO_p3QySpSIMgJKWvbxog,25517
+adv_optm/optim/Simplified_AdEMAMix.py,sha256=b4GaSI-TX6wFBqGxZeoJPbf2nVRCEtB3WVb1olDgY14,12980
+adv_optm/optim/__init__.py,sha256=pcP865H2j1tut2VfTUhzQh7V8TF_tzPjqFnjMfFed2k,382
+adv_optm/util/BF16_Stochastic_Rounding.py,sha256=Q5H0BcogmE4atP65dLoI21HKSf50lRdsBDfeF6v9Tbg,1548
+adv_optm/util/Effective_Shape.py,sha256=TBvIk1V8IuTbbBsxuekJA4e_v8JlR5Nujtut8RTWAm4,318
+adv_optm/util/Kourkoutas.py,sha256=6OzK96KJ7Dd9Py8hiGWszF9C_n4uVoDjFCA_EYbhL4c,6600
+adv_optm/util/NNMF.py,sha256=yRf5IP5Sjq0Uf0DxN0Q8NxEGSdD-f1ULziLVDOjY8K4,639
+adv_optm/util/One_Bit_Boolean.py,sha256=Wat49esdwohuN-OHOFMW8D0aOQgV9cP5Rl8z6yfmpos,1068
+adv_optm/util/OrthoGrad.py,sha256=NzInuBQGy_Ja__M1R9XbvqVaQ0fhGbtGgFE9YON7B3I,707
+adv_optm/util/__init__.py,sha256=qoyIF0jcLjs_vSEcsv36clw5LFNBEbifyXrrVxMH-G4,349
+adv_optm-1.1.0.dev2.dist-info/licenses/LICENSE,sha256=HrhfyXIkWY2tGFK11kg7vPCqhgh5DcxleloqdhrpyMY,11558
+adv_optm-1.1.0.dev2.dist-info/METADATA,sha256=Y2F2wkpPmdbRtHft1KdCm1D6feTmiP5kFJ6iYpSLwCo,8427
+adv_optm-1.1.0.dev2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+adv_optm-1.1.0.dev2.dist-info/top_level.txt,sha256=iNfBIIzu-lPrQ7jyC56WBCcbkRwitM2nJ15-MRQ_6fg,9
+adv_optm-1.1.0.dev2.dist-info/RECORD,,
adv_optm-1.0.6.dist-info/RECORD
DELETED
@@ -1,19 +0,0 @@
-adv_optm/__init__.py,sha256=dAbueuVEIGoYrYXx8UE4ATfFBH5wEKrpkXGPTjFH0r0,306
-adv_optm/optim/AdamW_adv.py,sha256=aTuYcJgd_EcZOrs6TDgBrBKw3wtU5LPzE5WvTBDDeEo,14317
-adv_optm/optim/Adopt_adv.py,sha256=FTpDDSlYruZDt1VVLgEI_bADiO8f26j-utQs7Gn2fFA,18108
-adv_optm/optim/Lion_Prodigy_adv.py,sha256=sGzhts9a6gHfCkuHTB5L9IrClo4c6UThzYYErBwqOaA,12844
-adv_optm/optim/Lion_adv.py,sha256=6G1CukJB_pC7l9HwFEuY1ydsNHZFabVmOvcHDsHHVuQ,8295
-adv_optm/optim/Prodigy_adv.py,sha256=G8xXLO9YBeLb9574uS0HpdY9w3ojblaV-PJFghUnToQ,22493
-adv_optm/optim/Simplified_AdEMAMix.py,sha256=tb3d6Cw_nGwcTzYUhDnKqyP7GzjD1hn8k4WqGG5lhmw,9813
-adv_optm/optim/__init__.py,sha256=pcP865H2j1tut2VfTUhzQh7V8TF_tzPjqFnjMfFed2k,382
-adv_optm/util/BF16_Stochastic_Rounding.py,sha256=Q5H0BcogmE4atP65dLoI21HKSf50lRdsBDfeF6v9Tbg,1548
-adv_optm/util/Effective_Shape.py,sha256=TBvIk1V8IuTbbBsxuekJA4e_v8JlR5Nujtut8RTWAm4,318
-adv_optm/util/NNMF.py,sha256=yRf5IP5Sjq0Uf0DxN0Q8NxEGSdD-f1ULziLVDOjY8K4,639
-adv_optm/util/One_Bit_Boolean.py,sha256=Wat49esdwohuN-OHOFMW8D0aOQgV9cP5Rl8z6yfmpos,1068
-adv_optm/util/OrthoGrad.py,sha256=NzInuBQGy_Ja__M1R9XbvqVaQ0fhGbtGgFE9YON7B3I,707
-adv_optm/util/__init__.py,sha256=qoyIF0jcLjs_vSEcsv36clw5LFNBEbifyXrrVxMH-G4,349
-adv_optm-1.0.6.dist-info/licenses/LICENSE,sha256=HrhfyXIkWY2tGFK11kg7vPCqhgh5DcxleloqdhrpyMY,11558
-adv_optm-1.0.6.dist-info/METADATA,sha256=3PslWXH0ysoiXU83vN3F9kWRw48fwUM4H1z1tMyEGvI,8422
-adv_optm-1.0.6.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-adv_optm-1.0.6.dist-info/top_level.txt,sha256=iNfBIIzu-lPrQ7jyC56WBCcbkRwitM2nJ15-MRQ_6fg,9
-adv_optm-1.0.6.dist-info/RECORD,,

{adv_optm-1.0.6.dist-info → adv_optm-1.1.0.dev2.dist-info}/WHEEL
File without changes

{adv_optm-1.0.6.dist-info → adv_optm-1.1.0.dev2.dist-info}/licenses/LICENSE
File without changes

{adv_optm-1.0.6.dist-info → adv_optm-1.1.0.dev2.dist-info}/top_level.txt
File without changes