adv-optm 1.1.0.dev1.tar.gz → 1.1.0.dev3.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of adv-optm has been flagged by the registry as possibly problematic.
- {adv_optm-1.1.0.dev1 → adv_optm-1.1.0.dev3}/PKG-INFO +1 -1
- {adv_optm-1.1.0.dev1 → adv_optm-1.1.0.dev3}/adv_optm/__init__.py +1 -1
- {adv_optm-1.1.0.dev1 → adv_optm-1.1.0.dev3}/adv_optm/optim/AdamW_adv.py +8 -17
- {adv_optm-1.1.0.dev1 → adv_optm-1.1.0.dev3}/adv_optm/optim/Adopt_adv.py +11 -16
- {adv_optm-1.1.0.dev1 → adv_optm-1.1.0.dev3}/adv_optm/optim/Prodigy_adv.py +17 -24
- {adv_optm-1.1.0.dev1 → adv_optm-1.1.0.dev3}/adv_optm/optim/Simplified_AdEMAMix.py +9 -18
- {adv_optm-1.1.0.dev1 → adv_optm-1.1.0.dev3}/adv_optm/util/Kourkoutas.py +55 -27
- {adv_optm-1.1.0.dev1 → adv_optm-1.1.0.dev3}/adv_optm.egg-info/PKG-INFO +1 -1
- {adv_optm-1.1.0.dev1 → adv_optm-1.1.0.dev3}/setup.py +1 -1
- {adv_optm-1.1.0.dev1 → adv_optm-1.1.0.dev3}/LICENSE +0 -0
- {adv_optm-1.1.0.dev1 → adv_optm-1.1.0.dev3}/README.md +0 -0
- {adv_optm-1.1.0.dev1 → adv_optm-1.1.0.dev3}/adv_optm/optim/Lion_Prodigy_adv.py +0 -0
- {adv_optm-1.1.0.dev1 → adv_optm-1.1.0.dev3}/adv_optm/optim/Lion_adv.py +0 -0
- {adv_optm-1.1.0.dev1 → adv_optm-1.1.0.dev3}/adv_optm/optim/__init__.py +0 -0
- {adv_optm-1.1.0.dev1 → adv_optm-1.1.0.dev3}/adv_optm/util/BF16_Stochastic_Rounding.py +0 -0
- {adv_optm-1.1.0.dev1 → adv_optm-1.1.0.dev3}/adv_optm/util/Effective_Shape.py +0 -0
- {adv_optm-1.1.0.dev1 → adv_optm-1.1.0.dev3}/adv_optm/util/NNMF.py +0 -0
- {adv_optm-1.1.0.dev1 → adv_optm-1.1.0.dev3}/adv_optm/util/One_Bit_Boolean.py +0 -0
- {adv_optm-1.1.0.dev1 → adv_optm-1.1.0.dev3}/adv_optm/util/OrthoGrad.py +0 -0
- {adv_optm-1.1.0.dev1 → adv_optm-1.1.0.dev3}/adv_optm/util/__init__.py +0 -0
- {adv_optm-1.1.0.dev1 → adv_optm-1.1.0.dev3}/adv_optm.egg-info/SOURCES.txt +0 -0
- {adv_optm-1.1.0.dev1 → adv_optm-1.1.0.dev3}/adv_optm.egg-info/dependency_links.txt +0 -0
- {adv_optm-1.1.0.dev1 → adv_optm-1.1.0.dev3}/adv_optm.egg-info/requires.txt +0 -0
- {adv_optm-1.1.0.dev1 → adv_optm-1.1.0.dev3}/adv_optm.egg-info/top_level.txt +0 -0
- {adv_optm-1.1.0.dev1 → adv_optm-1.1.0.dev3}/setup.cfg +0 -0
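
The functional changes below all revolve around the Kourkoutas-β dynamic beta2 machinery shared by AdamW_adv, Adopt_adv, Prodigy_adv and Simplified_AdEMAMix: k_logging moves into the param-group defaults, the helper-related attributes are set before super().__init__, and the beta2-statistics logging moves out of each optimizer's step() into KourkoutasHelper. As a rough usage sketch (the constructor signature, defaults and import path are not shown in this diff, so the values and import below are illustrative, not authoritative):

    import torch
    from adv_optm import AdamW_adv  # assumed import path

    model = torch.nn.Linear(128, 64)
    opt = AdamW_adv(
        model.parameters(),
        lr=1e-3,
        kourkoutas_beta=True,   # enable layer-wise dynamic beta2
        beta2_min=0.9,          # lower bound reached when a layer's gradients spike
        ema_alpha=0.95,         # smoothing of the per-layer gradient-norm EMA
        tiny_spike=1e-9,        # epsilon in the sunspike ratio
        k_warmup_steps=100,     # keep the static beta2 for the first steps
        k_logging=50,           # in dev3 this is a defaults entry read by KourkoutasHelper
    )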
adv_optm/optim/AdamW_adv.py

@@ -128,18 +128,17 @@ class AdamW_adv(torch.optim.Optimizer):
             "orthogonal_gradient": orthogonal_gradient, "use_bias_correction": use_bias_correction,
             "beta3_ema": beta3_ema, "alpha": alpha, "t_alpha": t_alpha,
             "kourkoutas_beta": kourkoutas_beta, "beta2_min": beta2_min, "ema_alpha": ema_alpha,
-            "tiny_spike": tiny_spike, "k_warmup_steps": k_warmup_steps,
+            "tiny_spike": tiny_spike, "k_warmup_steps": k_warmup_steps, "k_logging": k_logging,
         }
         self.stochastic_rounding = stochastic_rounding
         self.cautious_mask = cautious_mask
         self.grams_moment = grams_moment
         self.use_AdEMAMix = use_AdEMAMix
         self.factored = nnmf_factor
+        self.kourkoutas_beta = kourkoutas_beta
+        self.layer_key_fn = layer_key_fn
         super().__init__(params, defaults)

-        self.kourkoutas_beta = kourkoutas_beta
-        self.k_logging= k_logging and kourkoutas_beta
-        self.layer_key_fn = layer_key_fn and kourkoutas_beta
         if self.kourkoutas_beta:
             self.kourkoutas_helper = KourkoutasHelper(self)
@@ -207,13 +206,15 @@ class AdamW_adv(torch.optim.Optimizer):
            state['exp_avg_slow'] = torch.zeros_like(p, device=device, dtype=dtype)
            state['exp_avg_sq'] = torch.zeros_like(p, device=device, dtype=dtype)

+        beta1, beta2 = group['betas']
+
        current_step = state['step']
        if group['kourkoutas_beta']:
+            # Call prepare_step() once at the beginning of the step for all params
            self.kourkoutas_helper.maybe_prepare_step(current_step)
+            # Accumulate current grad's norm for the *next* step
            self.kourkoutas_helper.accumulate_gradient_sq_norm(p, grad)
-
-        beta1, beta2 = group['betas']
-        if group['kourkoutas_beta']:
+            # Get the dynamic beta2 calculated in prepare_step()
            beta2 = self.kourkoutas_helper.get_beta2(p, group, current_step)

        step = state['step'] + 1
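
Read together, the added lines give every optimizer the same per-parameter flow: unpack the static betas first so beta2 always has a value, then, when Kourkoutas-β is enabled, prepare the helper once per step, feed this gradient's squared norm into the accumulator used at the next step, and override beta2 with the layer's dynamic value. A condensed reading of the new code path (surrounding names such as state and grad are taken from context, not reproduced verbatim):

    beta1, beta2 = group['betas']                 # static defaults
    current_step = state['step']
    if group['kourkoutas_beta']:
        # Runs prepare_step() exactly once per training step (guarded internally).
        self.kourkoutas_helper.maybe_prepare_step(current_step)
        # Stored now, consumed by prepare_step() at the *next* step.
        self.kourkoutas_helper.accumulate_gradient_sq_norm(p, grad)
        # Layer-wise beta2 computed from the previous step's pooled gradient norms.
        beta2 = self.kourkoutas_helper.get_beta2(p, group, current_step)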
@@ -366,14 +367,4 @@ class AdamW_adv(torch.optim.Optimizer):
            for i, p in enumerate(group['params']):
                self.step_parameter(p, group, i)

-        if self.kourkoutas_beta and self.k_logging > 0 and hasattr(self, '_beta2_log'):
-            first_param_state = self.state[self.param_groups[0]['params'][0]]
-            step_num = first_param_state['step']
-
-            if step_num > 0 and step_num % self.k_logging == 0:
-                if self._beta2_log:
-                    beta2_tensor = torch.tensor(self._beta2_log, device='cpu')
-                    print(f"Step {step_num}: Kourkoutas beta2 stats: Min={beta2_tensor.min():.4f}, Max={beta2_tensor.max():.4f}, Mean={beta2_tensor.mean():.4f}")
-            delattr(self, '_beta2_log')
-
        return loss
adv_optm/optim/Adopt_adv.py

@@ -157,7 +157,7 @@ class Adopt_adv(torch.optim.Optimizer):
            "vector_reshape": vector_reshape, "beta3_ema": beta3_ema, "alpha": alpha,
            "t_alpha": t_alpha, "alpha_grad": alpha_grad,
            "kourkoutas_beta": kourkoutas_beta, "beta2_min": beta2_min, "ema_alpha": ema_alpha,
-            "tiny_spike": tiny_spike, "k_warmup_steps": k_warmup_steps,
+            "tiny_spike": tiny_spike, "k_warmup_steps": k_warmup_steps, "k_logging": k_logging,
        }
        self.clip_lambda = clip_lambda
        self.stochastic_rounding = stochastic_rounding
@@ -168,11 +168,10 @@ class Adopt_adv(torch.optim.Optimizer):
        self.use_AdEMAMix = use_AdEMAMix and not Simplified_AdEMAMix
        self.Simplified_AdEMAMix = Simplified_AdEMAMix
        self.factored = nnmf_factor
+        self.kourkoutas_beta = kourkoutas_beta
+        self.layer_key_fn = layer_key_fn
        super().__init__(params, defaults)

-        self.kourkoutas_beta = kourkoutas_beta
-        self.k_logging= k_logging and kourkoutas_beta
-        self.layer_key_fn = layer_key_fn and kourkoutas_beta
        if self.kourkoutas_beta:
            self.kourkoutas_helper = KourkoutasHelper(self)
@@ -238,13 +237,15 @@ class Adopt_adv(torch.optim.Optimizer):
            state['exp_avg_slow'] = torch.zeros_like(p, dtype=dtype)
            state['exp_avg_sq'] = grad.square() # v_0

+        beta1, beta2 = group['betas']
+
        current_step = state['step']
        if group['kourkoutas_beta']:
+            # Call prepare_step() once at the beginning of the step for all params
            self.kourkoutas_helper.maybe_prepare_step(current_step)
+            # Accumulate current grad's norm for the *next* step
            self.kourkoutas_helper.accumulate_gradient_sq_norm(p, grad)
-
-        beta1, beta2 = group['betas']
-        if group['kourkoutas_beta']:
+            # Get the dynamic beta2 calculated in prepare_step()
            beta2 = self.kourkoutas_helper.get_beta2(p, group, current_step)

        # The first step is for initialization only (skip when use_atan2 as it's scale invariant).
@@ -257,10 +258,10 @@ class Adopt_adv(torch.optim.Optimizer):
            alpha = group['alpha']
            t_alpha = group['t_alpha']
            # Use step+1 for 1-based step count in scheduler
-
+            alpha_step = state['step'] + 1
            alpha_t = alpha
-            if t_alpha is not None and t_alpha > 0 and
-                alpha_t = min(
+            if t_alpha is not None and t_alpha > 0 and alpha_step < t_alpha:
+                alpha_t = min(alpha_step * alpha / t_alpha, alpha)
        if self.Simplified_AdEMAMix:
            alpha_grad = group["alpha_grad"]
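
The rewritten AdEMAMix warmup is now self-contained: alpha_t ramps linearly toward alpha over the first t_alpha steps (1-based) and stays at alpha afterwards or when no schedule is configured. A standalone sketch of the schedule, with illustrative numbers:

    def ademamix_alpha_schedule(step, alpha, t_alpha):
        """Linear warmup of the slow-EMA mixing coefficient; step is 1-based."""
        if t_alpha is not None and t_alpha > 0 and step < t_alpha:
            return min(step * alpha / t_alpha, alpha)
        return alpha

    # e.g. alpha=5.0, t_alpha=1000: step 1 -> 0.005, step 500 -> 2.5, step >= 1000 -> 5.0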
@@ -436,10 +437,4 @@ class Adopt_adv(torch.optim.Optimizer):
            first_param_state = self.state[self.param_groups[0]['params'][0]]
            step_num = first_param_state['step']

-            if step_num > 0 and step_num % self.k_logging == 0:
-                if self._beta2_log:
-                    beta2_tensor = torch.tensor(self._beta2_log, device='cpu')
-                    print(f"Step {step_num}: Kourkoutas beta2 stats: Min={beta2_tensor.min():.4f}, Max={beta2_tensor.max():.4f}, Mean={beta2_tensor.mean():.4f}")
-            delattr(self, '_beta2_log')
-
        return loss
adv_optm/optim/Prodigy_adv.py

@@ -189,7 +189,7 @@ class Prodigy_adv(torch.optim.Optimizer):
            "fsdp_in_use": fsdp_in_use, "prodigy_steps": prodigy_steps,
            "alpha_grad": alpha_grad,
            "kourkoutas_beta": kourkoutas_beta, "beta2_min": beta2_min, "ema_alpha": ema_alpha,
-            "tiny_spike": tiny_spike, "k_warmup_steps": k_warmup_steps,
+            "tiny_spike": tiny_spike, "k_warmup_steps": k_warmup_steps, "k_logging": k_logging,
        }
        self.stochastic_rounding = stochastic_rounding
        self.cautious_mask = cautious_mask and not Simplified_AdEMAMix
@@ -198,14 +198,13 @@ class Prodigy_adv(torch.optim.Optimizer):
        self.Simplified_AdEMAMix = Simplified_AdEMAMix
        self.factored = nnmf_factor
        self.fsdp_in_use = fsdp_in_use
-
-
+
        self.kourkoutas_beta = kourkoutas_beta
-        self.
-
+        self.layer_key_fn = layer_key_fn
+
+        super().__init__(params, defaults)
        if self.kourkoutas_beta:
            self.kourkoutas_helper = KourkoutasHelper(self)
-
        self.init_step()

    @property
@@ -229,7 +228,7 @@ class Prodigy_adv(torch.optim.Optimizer):
            self.beta3 = g_group['beta3']
            if self.beta3 is None:
                self.beta3 = math.sqrt(self.beta2_default)
-
+
            self.d = g_group['d']
            lr = g_group['lr']
@@ -301,21 +300,25 @@ class Prodigy_adv(torch.optim.Optimizer):

        current_step = state['step']
        if group['kourkoutas_beta']:
+            # Call prepare_step() once at the beginning of the step for all params
            self.kourkoutas_helper.maybe_prepare_step(current_step)
+            # Accumulate current grad's norm for the *next* step
            self.kourkoutas_helper.accumulate_gradient_sq_norm(p, grad)
-
-        beta2 = self.beta2_default
-        if group['kourkoutas_beta']:
+            # Get the dynamic beta2 calculated in prepare_step()
            beta2 = self.kourkoutas_helper.get_beta2(p, group, current_step)
+            beta3 = math.sqrt(beta2)
+        else:
+            beta2 = self.beta2_default
+            beta3 = self.beta3

        if self.use_AdEMAMix:
            beta3_ema = group['beta3_ema']
            alpha = group['alpha']
            t_alpha = group['t_alpha']
-
+            alpha_step = state['step'] + 1
            alpha_t = alpha
-            if t_alpha is not None and t_alpha > 0 and
-                alpha_t = min(
+            if t_alpha is not None and t_alpha > 0 and alpha_step < t_alpha:
+                alpha_t = min(alpha_step * alpha / t_alpha, alpha)
        if self.Simplified_AdEMAMix:
            alpha_grad = group["alpha_grad"]
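
Prodigy_adv already ties beta3 to beta2 by defaulting beta3 to sqrt(beta2) when it is not given (see the @@ -229 hunk above). The new branch preserves that invariant under Kourkoutas-β: when beta2 is dynamic, beta3 is re-derived as sqrt(beta2) each step and later drives the s accumulator via s.mul_(beta3). A quick numeric illustration of how the derived beta3 follows the dynamic beta2:

    import math

    # beta3 = sqrt(beta2): calm layers keep a long memory, spiking layers shorten it.
    for beta2 in (0.999, 0.99, 0.95):
        print(f"beta2={beta2:.3f} -> beta3={math.sqrt(beta2):.4f}")
    # beta2=0.999 -> beta3=0.9995, beta2=0.990 -> beta3=0.9950, beta2=0.950 -> beta3=0.9747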
@@ -443,7 +446,7 @@ class Prodigy_adv(torch.optim.Optimizer):
            self.d_numerator += (self.d / d0) * self.dlr * torch.dot(grad_flat[::slice_p], p0.data - p_flat[::slice_p]).item()

            alpha = ((self.d / d0) * self.d) if safeguard_warmup else ((self.d / d0) * self.dlr)
-            s.mul_(
+            s.mul_(beta3).add_(grad_flat[::slice_p], alpha=alpha)
            self.d_denom += s.abs().sum().item()

            del s, p0, grad_flat, p_flat, alpha
@@ -481,16 +484,6 @@ class Prodigy_adv(torch.optim.Optimizer):
            for i, p in enumerate(group['params']):
                self.step_parameter(p, group, i)

-        if self.kourkoutas_beta and self.k_logging > 0 and hasattr(self, '_beta2_log'):
-            first_param_state = self.state[self.param_groups[0]['params'][0]]
-            step_num = first_param_state['step']
-
-            if step_num > 0 and step_num % self.k_logging == 0:
-                if self._beta2_log:
-                    beta2_tensor = torch.tensor(self._beta2_log, device='cpu')
-                    print(f"Step {step_num}: Kourkoutas beta2 stats: Min={beta2_tensor.min():.4f}, Max={beta2_tensor.max():.4f}, Mean={beta2_tensor.mean():.4f}")
-            delattr(self, '_beta2_log')
-
        self.calculate_d()
        self.init_step()
        return loss
adv_optm/optim/Simplified_AdEMAMix.py

@@ -116,15 +116,14 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
            "vector_reshape": vector_reshape,
            "orthogonal_gradient": orthogonal_gradient, "use_bias_correction": use_bias_correction,
            "kourkoutas_beta": kourkoutas_beta, "beta2_min": beta2_min, "ema_alpha": ema_alpha,
-            "tiny_spike": tiny_spike, "k_warmup_steps": k_warmup_steps,
+            "tiny_spike": tiny_spike, "k_warmup_steps": k_warmup_steps, "k_logging": k_logging,
        }
        self.stochastic_rounding = stochastic_rounding
        self.factored = nnmf_factor
+        self.kourkoutas_beta = kourkoutas_beta
+        self.layer_key_fn = layer_key_fn
        super().__init__(params, defaults)

-        self.kourkoutas_beta = kourkoutas_beta
-        self.k_logging= k_logging and kourkoutas_beta
-        self.layer_key_fn = layer_key_fn and kourkoutas_beta
        if self.kourkoutas_beta:
            self.kourkoutas_helper = KourkoutasHelper(self)
@@ -189,17 +188,19 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
            state['num_sum'] = 1.0
            state['den_sum'] = 1.0

+        beta1_final, beta2 = group["betas"]
+
        current_step = state['step']
        if group['kourkoutas_beta']:
+            # Call prepare_step() once at the beginning of the step for all params
            self.kourkoutas_helper.maybe_prepare_step(current_step)
+            # Accumulate current grad's norm for the *next* step
            self.kourkoutas_helper.accumulate_gradient_sq_norm(p, grad)
+            # Get the dynamic beta2 calculated in prepare_step()
+            beta2 = self.kourkoutas_helper.get_beta2(p, group, current_step)

-        beta1_final, beta2 = group["betas"]
        beta1_warmup = group["beta1_warmup"]
        alpha_grad = group["alpha_grad"]
-
-        if group['kourkoutas_beta']:
-            beta2 = self.kourkoutas_helper.get_beta2(p, group, current_step)

        if beta1_warmup is not None:
            step = state['step'] + 1
@@ -294,14 +295,4 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
            for i, p in enumerate(group['params']):
                self.step_parameter(p, group, i)

-        if self.kourkoutas_beta and self.k_logging > 0 and hasattr(self, '_beta2_log'):
-            first_param_state = self.state[self.param_groups[0]['params'][0]]
-            step_num = first_param_state['step']
-
-            if step_num > 0 and step_num % self.k_logging == 0:
-                if self._beta2_log:
-                    beta2_tensor = torch.tensor(self._beta2_log, device='cpu')
-                    print(f"Step {step_num}: Kourkoutas beta2 stats: Min={beta2_tensor.min():.4f}, Max={beta2_tensor.max():.4f}, Mean={beta2_tensor.mean():.4f}")
-            delattr(self, '_beta2_log')
-
        return loss
adv_optm/util/Kourkoutas.py

@@ -18,6 +18,10 @@ class KourkoutasHelper:
        self._layer_info_built = False
        self._current_step_prepared = -1

+        # This ensures the map is complete before the first backward pass,
+        # making it compatible with fused back pass mechanisms.
+        self._build_layer_info_if_needed()
+
    def _build_layer_info_if_needed(self):
        """Builds a map of layers and the parameters they contain."""
        if self._layer_info_built:
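
Building the layer map in __init__ rather than lazily inside prepare_step() means layer_key_fn is consulted before any gradients exist, which is what makes the helper safe to use with a fused backward pass. The default key (set in the next hunk) is lambda p: id(p), i.e. every parameter is its own bucket; a user-supplied layer_key_fn can pool norms per module instead. A hypothetical example of such a function (not part of the package):

    import torch

    model = torch.nn.Sequential(torch.nn.Linear(16, 32), torch.nn.Linear(32, 4))

    # Map each parameter id to the name of the module that owns it, so gradient
    # norms are pooled per layer instead of per tensor (illustrative only).
    param_to_module = {
        id(p): name
        for name, module in model.named_modules()
        for p in module.parameters(recurse=False)
    }

    def layer_key_fn(p: torch.Tensor):
        return param_to_module.get(id(p), id(p))  # unknown params fall back to per-tensor keys

    # e.g. AdamW_adv(model.parameters(), kourkoutas_beta=True, layer_key_fn=layer_key_fn)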
@@ -28,26 +32,31 @@ class KourkoutasHelper:
            self.optimizer.layer_key_fn = lambda p: id(p)

        for group in self.optimizer.param_groups:
-            if not group.get('kourkoutas_beta', False):
-                continue
            for p in group['params']:
-
+                # The mapping is static and should not depend on the presence of a gradient.
                layer_key = self.optimizer.layer_key_fn(p)
                if layer_key not in self.layer_info:
                    self.layer_info[layer_key] = {'params': [], 'group_ref': group}
                self.layer_info[layer_key]['params'].append(p)
+
+        k_logging_interval = self.optimizer.param_groups[0].get('k_logging', 0)
+        if k_logging_interval > 0:
+            print(f"[Kourkoutas-β Debug] Layer info built. Found {len(self.layer_info)} unique layers/buckets.")
+
        self._layer_info_built = True

-    def prepare_step(self):
+    def prepare_step(self, current_step: int):
        """
        Calculates dynamic beta2 for all layers using the completed scalar accumulators
        from the PREVIOUS step. Should be called once at the start of an optimizer step.
        """
-        self._build_layer_info_if_needed()

-        if
-
-
+        # Check if logging is enabled for this step based on the interval
+        k_logging_interval = self.optimizer.param_groups[0].get('k_logging', 0)
+        is_logging_step = k_logging_interval > 0 and (current_step + 1) % k_logging_interval == 0
+
+        beta2_log = [] if is_logging_step else None
+        first_layer_key = next(iter(self.layer_info), None)

        for layer_key, info in self.layer_info.items():
            params, group = info['params'], info['group_ref']
@@ -60,28 +69,47 @@ class KourkoutasHelper:

            layer_state = self.layer_state[layer_key]

+            # Use the completed accumulator from the previous step
            pooled_grad_norm = torch.sqrt(layer_state['sum_sq_accumulator'])

            r_ema = layer_state['r_ema_grad_norm']
+            prev_r_ema_val = r_ema.item() # for logging
+
+            # EMA is always updated, even during warmup
            r_ema.mul_(group['ema_alpha']).add_(pooled_grad_norm, alpha=1.0 - group['ema_alpha'])

-
-            sun = raw / (1.0 + raw)
+            sun = torch.tensor(0.0, device=r_ema.device) # Default sun to 0 for warmup
            beta2_max = group['betas'][1]
-
-
-
+
+            # --- CONSOLIDATED WARMUP LOGIC ---
+            if current_step < group['k_warmup_steps']:
+                beta2 = beta2_max
+            else:
+                raw = pooled_grad_norm / (r_ema + group['tiny_spike'])
+                sun = raw / (1.0 + raw)
+                beta2 = beta2_max - (beta2_max - group['beta2_min']) * sun
+
+            layer_state['dynamic_beta2'] = beta2.item() if isinstance(beta2, torch.Tensor) else beta2
            layer_state['sum_sq_accumulator'].zero_()

-            if
-
+            if is_logging_step:
+                beta2_log.append(layer_state['dynamic_beta2'])
+                if layer_key == first_layer_key:
+                    print(f"\n[Kourkoutas-β Debug] Step {current_step + 1} - Sample Layer '{layer_key}':")
+                    print(f"  - Grad Norm: {pooled_grad_norm.item():.4e}, Prev EMA: {prev_r_ema_val:.4e}, New EMA: {r_ema.item():.4e}")
+                    print(f"  - Sunspike: {sun.item():.4f}, Dynamic Beta2: {layer_state['dynamic_beta2']:.4f}")
+
+        if is_logging_step and beta2_log:
+            beta2_tensor = torch.tensor(beta2_log, device='cpu')
+            print(f"[Kourkoutas-β Debug] Step {current_step + 1} Overall Beta2 Stats: Min={beta2_tensor.min():.4f}, Max={beta2_tensor.max():.4f}, Mean={beta2_tensor.mean():.4f}")
+

    def maybe_prepare_step(self, current_step: int):
        """
        A universal guard that calls prepare_step() exactly once per training step.
        """
        if self._current_step_prepared < current_step:
-            self.prepare_step()
+            self.prepare_step(current_step)
        self._current_step_prepared = current_step

    def accumulate_gradient_sq_norm(self, p: torch.Tensor, grad: torch.Tensor):
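
The consolidated branch makes the beta2 rule easy to state: during the first k_warmup_steps a layer keeps beta2 at its maximum (betas[1]); afterwards the "sunspike" ratio compares the pooled gradient norm against its EMA and interpolates beta2 between beta2_max and beta2_min. A self-contained sketch of that mapping with the tensor bookkeeping stripped out (argument defaults here are illustrative, not the package's):

    def kourkoutas_beta2(pooled_grad_norm: float, r_ema: float,
                         beta2_max: float = 0.999, beta2_min: float = 0.9,
                         tiny_spike: float = 1e-9) -> float:
        """Dynamic beta2 from the sunspike ratio of grad norm vs. its EMA."""
        raw = pooled_grad_norm / (r_ema + tiny_spike)
        sun = raw / (1.0 + raw)               # in [0, 1): 0 = calm, -> 1 = spike
        return beta2_max - (beta2_max - beta2_min) * sun

    # Calm layer (norm == EMA): raw ~ 1, sun ~ 0.5 -> beta2 ~ 0.9495
    # Spiking layer (norm = 10x EMA): raw ~ 10, sun ~ 0.909 -> beta2 ~ 0.909
    print(kourkoutas_beta2(1.0, 1.0), kourkoutas_beta2(10.0, 1.0))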
@@ -89,20 +117,20 @@ class KourkoutasHelper:
        Accumulates the squared L2 norm of a single gradient for the next step's calculation.
        """
        layer_key = self.optimizer.layer_key_fn(p)
-
-
-
-
-
-
+
+        if layer_key in self.layer_info:
+            if layer_key not in self.layer_state:
+                self.layer_state[layer_key] = {
+                    'r_ema_grad_norm': torch.tensor(0.0, device=p.device, dtype=torch.float32),
+                    'sum_sq_accumulator': torch.tensor(0.0, device=p.device, dtype=torch.float32)
+                }
+            # Accumulate for the *next* step's prepare_step call
+            self.layer_state[layer_key]['sum_sq_accumulator'] += torch.sum(grad.detach().pow(2)).float()

    def get_beta2(self, p: torch.Tensor, group: dict, current_step: int) -> float:
        """
        Gets the appropriate beta2 for the current parameter, handling warmup and dynamic value fetching.
        """
-        beta2_default = group['betas'][1]
-        if current_step < group['k_warmup_steps']:
-            return 0.5 * (group['beta2_min'] + beta2_default)
-
        layer_key = self.optimizer.layer_key_fn(p)
-
+        # The default is the max value, which is correct for unmapped params or edge cases
+        return self.layer_state.get(layer_key, {}).get('dynamic_beta2', group['betas'][1])