adv-optm 1.1.0.dev1__tar.gz → 1.1.0.dev2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of adv-optm might be problematic.
- {adv_optm-1.1.0.dev1 → adv_optm-1.1.0.dev2}/PKG-INFO +1 -1
- {adv_optm-1.1.0.dev1 → adv_optm-1.1.0.dev2}/adv_optm/__init__.py +1 -1
- {adv_optm-1.1.0.dev1 → adv_optm-1.1.0.dev2}/adv_optm/optim/AdamW_adv.py +8 -17
- {adv_optm-1.1.0.dev1 → adv_optm-1.1.0.dev2}/adv_optm/optim/Adopt_adv.py +11 -16
- {adv_optm-1.1.0.dev1 → adv_optm-1.1.0.dev2}/adv_optm/optim/Prodigy_adv.py +13 -22
- {adv_optm-1.1.0.dev1 → adv_optm-1.1.0.dev2}/adv_optm/optim/Simplified_AdEMAMix.py +9 -18
- {adv_optm-1.1.0.dev1 → adv_optm-1.1.0.dev2}/adv_optm/util/Kourkoutas.py +51 -25
- {adv_optm-1.1.0.dev1 → adv_optm-1.1.0.dev2}/adv_optm.egg-info/PKG-INFO +1 -1
- {adv_optm-1.1.0.dev1 → adv_optm-1.1.0.dev2}/setup.py +1 -1
- {adv_optm-1.1.0.dev1 → adv_optm-1.1.0.dev2}/LICENSE +0 -0
- {adv_optm-1.1.0.dev1 → adv_optm-1.1.0.dev2}/README.md +0 -0
- {adv_optm-1.1.0.dev1 → adv_optm-1.1.0.dev2}/adv_optm/optim/Lion_Prodigy_adv.py +0 -0
- {adv_optm-1.1.0.dev1 → adv_optm-1.1.0.dev2}/adv_optm/optim/Lion_adv.py +0 -0
- {adv_optm-1.1.0.dev1 → adv_optm-1.1.0.dev2}/adv_optm/optim/__init__.py +0 -0
- {adv_optm-1.1.0.dev1 → adv_optm-1.1.0.dev2}/adv_optm/util/BF16_Stochastic_Rounding.py +0 -0
- {adv_optm-1.1.0.dev1 → adv_optm-1.1.0.dev2}/adv_optm/util/Effective_Shape.py +0 -0
- {adv_optm-1.1.0.dev1 → adv_optm-1.1.0.dev2}/adv_optm/util/NNMF.py +0 -0
- {adv_optm-1.1.0.dev1 → adv_optm-1.1.0.dev2}/adv_optm/util/One_Bit_Boolean.py +0 -0
- {adv_optm-1.1.0.dev1 → adv_optm-1.1.0.dev2}/adv_optm/util/OrthoGrad.py +0 -0
- {adv_optm-1.1.0.dev1 → adv_optm-1.1.0.dev2}/adv_optm/util/__init__.py +0 -0
- {adv_optm-1.1.0.dev1 → adv_optm-1.1.0.dev2}/adv_optm.egg-info/SOURCES.txt +0 -0
- {adv_optm-1.1.0.dev1 → adv_optm-1.1.0.dev2}/adv_optm.egg-info/dependency_links.txt +0 -0
- {adv_optm-1.1.0.dev1 → adv_optm-1.1.0.dev2}/adv_optm.egg-info/requires.txt +0 -0
- {adv_optm-1.1.0.dev1 → adv_optm-1.1.0.dev2}/adv_optm.egg-info/top_level.txt +0 -0
- {adv_optm-1.1.0.dev1 → adv_optm-1.1.0.dev2}/setup.cfg +0 -0

{adv_optm-1.1.0.dev1 → adv_optm-1.1.0.dev2}/adv_optm/optim/AdamW_adv.py

@@ -128,18 +128,17 @@ class AdamW_adv(torch.optim.Optimizer):
             "orthogonal_gradient": orthogonal_gradient, "use_bias_correction": use_bias_correction,
             "beta3_ema": beta3_ema, "alpha": alpha, "t_alpha": t_alpha,
             "kourkoutas_beta": kourkoutas_beta, "beta2_min": beta2_min, "ema_alpha": ema_alpha,
-            "tiny_spike": tiny_spike, "k_warmup_steps": k_warmup_steps,
+            "tiny_spike": tiny_spike, "k_warmup_steps": k_warmup_steps, "k_logging": k_logging,
         }
         self.stochastic_rounding = stochastic_rounding
         self.cautious_mask = cautious_mask
         self.grams_moment = grams_moment
         self.use_AdEMAMix = use_AdEMAMix
         self.factored = nnmf_factor
+        self.kourkoutas_beta = kourkoutas_beta
+        self.layer_key_fn = layer_key_fn
         super().__init__(params, defaults)
 
-        self.kourkoutas_beta = kourkoutas_beta
-        self.k_logging= k_logging and kourkoutas_beta
-        self.layer_key_fn = layer_key_fn and kourkoutas_beta
         if self.kourkoutas_beta:
             self.kourkoutas_helper = KourkoutasHelper(self)
 
@@ -207,13 +206,15 @@ class AdamW_adv(torch.optim.Optimizer):
                 state['exp_avg_slow'] = torch.zeros_like(p, device=device, dtype=dtype)
             state['exp_avg_sq'] = torch.zeros_like(p, device=device, dtype=dtype)
 
+        beta1, beta2 = group['betas']
+
         current_step = state['step']
         if group['kourkoutas_beta']:
+            # Call prepare_step() once at the beginning of the step for all params
             self.kourkoutas_helper.maybe_prepare_step(current_step)
+            # Accumulate current grad's norm for the *next* step
             self.kourkoutas_helper.accumulate_gradient_sq_norm(p, grad)
-
-        beta1, beta2 = group['betas']
-        if group['kourkoutas_beta']:
+            # Get the dynamic beta2 calculated in prepare_step()
            beta2 = self.kourkoutas_helper.get_beta2(p, group, current_step)
 
         step = state['step'] + 1
@@ -366,14 +367,4 @@ class AdamW_adv(torch.optim.Optimizer):
             for i, p in enumerate(group['params']):
                 self.step_parameter(p, group, i)
 
-        if self.kourkoutas_beta and self.k_logging > 0 and hasattr(self, '_beta2_log'):
-            first_param_state = self.state[self.param_groups[0]['params'][0]]
-            step_num = first_param_state['step']
-
-            if step_num > 0 and step_num % self.k_logging == 0:
-                if self._beta2_log:
-                    beta2_tensor = torch.tensor(self._beta2_log, device='cpu')
-                    print(f"Step {step_num}: Kourkoutas beta2 stats: Min={beta2_tensor.min():.4f}, Max={beta2_tensor.max():.4f}, Mean={beta2_tensor.mean():.4f}")
-            delattr(self, '_beta2_log')
-
         return loss
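
For orientation, the hunks above reduce the per-parameter bookkeeping to three helper calls inside a single `if group['kourkoutas_beta']:` block. The sketch below restates that flow as a standalone function; the function name and the `helper` argument are illustrative stand-ins for `self.kourkoutas_helper`, not part of the package.

    def kourkoutas_beta2_flow(helper, p, grad, group, state):
        """Illustrative restatement of the per-parameter flow in the hunk above."""
        beta1, beta2 = group['betas']          # static betas are the fallback
        current_step = state['step']
        if group['kourkoutas_beta']:
            # Runs the layer-wise beta2 computation once per optimizer step.
            helper.maybe_prepare_step(current_step)
            # This gradient's squared norm feeds the *next* step's computation.
            helper.accumulate_gradient_sq_norm(p, grad)
            # Dynamic beta2 derived from the *previous* step's accumulators.
            beta2 = helper.get_beta2(p, group, current_step)
        return beta1, beta2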

{adv_optm-1.1.0.dev1 → adv_optm-1.1.0.dev2}/adv_optm/optim/Adopt_adv.py

@@ -157,7 +157,7 @@ class Adopt_adv(torch.optim.Optimizer):
             "vector_reshape": vector_reshape, "beta3_ema": beta3_ema, "alpha": alpha,
             "t_alpha": t_alpha, "alpha_grad": alpha_grad,
             "kourkoutas_beta": kourkoutas_beta, "beta2_min": beta2_min, "ema_alpha": ema_alpha,
-            "tiny_spike": tiny_spike, "k_warmup_steps": k_warmup_steps,
+            "tiny_spike": tiny_spike, "k_warmup_steps": k_warmup_steps, "k_logging": k_logging,
         }
         self.clip_lambda = clip_lambda
         self.stochastic_rounding = stochastic_rounding
@@ -168,11 +168,10 @@ class Adopt_adv(torch.optim.Optimizer):
         self.use_AdEMAMix = use_AdEMAMix and not Simplified_AdEMAMix
         self.Simplified_AdEMAMix = Simplified_AdEMAMix
         self.factored = nnmf_factor
+        self.kourkoutas_beta = kourkoutas_beta
+        self.layer_key_fn = layer_key_fn
         super().__init__(params, defaults)
 
-        self.kourkoutas_beta = kourkoutas_beta
-        self.k_logging= k_logging and kourkoutas_beta
-        self.layer_key_fn = layer_key_fn and kourkoutas_beta
         if self.kourkoutas_beta:
             self.kourkoutas_helper = KourkoutasHelper(self)
 
@@ -238,13 +237,15 @@ class Adopt_adv(torch.optim.Optimizer):
                 state['exp_avg_slow'] = torch.zeros_like(p, dtype=dtype)
             state['exp_avg_sq'] = grad.square()  # v_0
 
+        beta1, beta2 = group['betas']
+
         current_step = state['step']
         if group['kourkoutas_beta']:
+            # Call prepare_step() once at the beginning of the step for all params
             self.kourkoutas_helper.maybe_prepare_step(current_step)
+            # Accumulate current grad's norm for the *next* step
             self.kourkoutas_helper.accumulate_gradient_sq_norm(p, grad)
-
-        beta1, beta2 = group['betas']
-        if group['kourkoutas_beta']:
+            # Get the dynamic beta2 calculated in prepare_step()
            beta2 = self.kourkoutas_helper.get_beta2(p, group, current_step)
 
         # The first step is for initialization only (skip when use_atan2 as it's scale invariant).
@@ -257,10 +258,10 @@ class Adopt_adv(torch.optim.Optimizer):
             alpha = group['alpha']
             t_alpha = group['t_alpha']
             # Use step+1 for 1-based step count in scheduler
-
+            alpha_step = state['step'] + 1
             alpha_t = alpha
-            if t_alpha is not None and t_alpha > 0 and
-                alpha_t = min(
+            if t_alpha is not None and t_alpha > 0 and alpha_step < t_alpha:
+                alpha_t = min(alpha_step * alpha / t_alpha, alpha)
             if self.Simplified_AdEMAMix:
                 alpha_grad = group["alpha_grad"]
 
@@ -436,10 +437,4 @@ class Adopt_adv(torch.optim.Optimizer):
         first_param_state = self.state[self.param_groups[0]['params'][0]]
         step_num = first_param_state['step']
 
-        if step_num > 0 and step_num % self.k_logging == 0:
-            if self._beta2_log:
-                beta2_tensor = torch.tensor(self._beta2_log, device='cpu')
-                print(f"Step {step_num}: Kourkoutas beta2 stats: Min={beta2_tensor.min():.4f}, Max={beta2_tensor.max():.4f}, Mean={beta2_tensor.mean():.4f}")
-        delattr(self, '_beta2_log')
-
         return loss
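
The repaired `t_alpha` branch above is a plain linear warmup of the AdEMAMix mixing coefficient. Restated as a standalone helper (the function name is not from the package; `step` is the 1-based `state['step'] + 1` used above):

    def ademamix_alpha_schedule(alpha, t_alpha, step):
        """alpha_t ramps linearly toward alpha over t_alpha steps, then stays at alpha."""
        alpha_t = alpha
        if t_alpha is not None and t_alpha > 0 and step < t_alpha:
            alpha_t = min(step * alpha / t_alpha, alpha)
        return alpha_t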

{adv_optm-1.1.0.dev1 → adv_optm-1.1.0.dev2}/adv_optm/optim/Prodigy_adv.py

@@ -189,7 +189,7 @@ class Prodigy_adv(torch.optim.Optimizer):
             "fsdp_in_use": fsdp_in_use, "prodigy_steps": prodigy_steps,
             "alpha_grad": alpha_grad,
             "kourkoutas_beta": kourkoutas_beta, "beta2_min": beta2_min, "ema_alpha": ema_alpha,
-            "tiny_spike": tiny_spike, "k_warmup_steps": k_warmup_steps,
+            "tiny_spike": tiny_spike, "k_warmup_steps": k_warmup_steps, "k_logging": k_logging,
         }
         self.stochastic_rounding = stochastic_rounding
         self.cautious_mask = cautious_mask and not Simplified_AdEMAMix
@@ -198,14 +198,13 @@ class Prodigy_adv(torch.optim.Optimizer):
         self.Simplified_AdEMAMix = Simplified_AdEMAMix
         self.factored = nnmf_factor
         self.fsdp_in_use = fsdp_in_use
-
-
+
         self.kourkoutas_beta = kourkoutas_beta
-        self.k_logging= k_logging and kourkoutas_beta
-
+        self.layer_key_fn = layer_key_fn
+
+        super().__init__(params, defaults)
         if self.kourkoutas_beta:
             self.kourkoutas_helper = KourkoutasHelper(self)
-
         self.init_step()
 
     @property
@@ -301,21 +300,23 @@ class Prodigy_adv(torch.optim.Optimizer):
 
         current_step = state['step']
         if group['kourkoutas_beta']:
+            # Call prepare_step() once at the beginning of the step for all params
             self.kourkoutas_helper.maybe_prepare_step(current_step)
+            # Accumulate current grad's norm for the *next* step
             self.kourkoutas_helper.accumulate_gradient_sq_norm(p, grad)
-
-        beta2 = self.beta2_default
-        if group['kourkoutas_beta']:
+            # Get the dynamic beta2 calculated in prepare_step()
            beta2 = self.kourkoutas_helper.get_beta2(p, group, current_step)
+        else:
+            beta2 = self.beta2_default
 
         if self.use_AdEMAMix:
             beta3_ema = group['beta3_ema']
             alpha = group['alpha']
             t_alpha = group['t_alpha']
-
+            alpha_step = state['step'] + 1
             alpha_t = alpha
-            if t_alpha is not None and t_alpha > 0 and
-                alpha_t = min(
+            if t_alpha is not None and t_alpha > 0 and alpha_step < t_alpha:
+                alpha_t = min(alpha_step * alpha / t_alpha, alpha)
             if self.Simplified_AdEMAMix:
                 alpha_grad = group["alpha_grad"]
 
@@ -481,16 +482,6 @@ class Prodigy_adv(torch.optim.Optimizer):
             for i, p in enumerate(group['params']):
                 self.step_parameter(p, group, i)
 
-        if self.kourkoutas_beta and self.k_logging > 0 and hasattr(self, '_beta2_log'):
-            first_param_state = self.state[self.param_groups[0]['params'][0]]
-            step_num = first_param_state['step']
-
-            if step_num > 0 and step_num % self.k_logging == 0:
-                if self._beta2_log:
-                    beta2_tensor = torch.tensor(self._beta2_log, device='cpu')
-                    print(f"Step {step_num}: Kourkoutas beta2 stats: Min={beta2_tensor.min():.4f}, Max={beta2_tensor.max():.4f}, Mean={beta2_tensor.mean():.4f}")
-            delattr(self, '_beta2_log')
-
         self.calculate_d()
         self.init_step()
         return loss
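
A hedged usage sketch for the new `k_logging` option: the keyword names below come from the `defaults` dict in these hunks, but the remaining constructor arguments, their default values, the example values, and the exact import path are assumptions based on the file layout shown in this diff.

    import torch
    from adv_optm.optim.Prodigy_adv import Prodigy_adv  # module/class names per this diff

    model = torch.nn.Linear(128, 10)
    opt = Prodigy_adv(
        model.parameters(),
        kourkoutas_beta=True,   # enable layer-wise dynamic beta2
        beta2_min=0.9,          # lower bound reached on large gradient spikes
        ema_alpha=0.95,         # smoothing for the pooled grad-norm EMA
        tiny_spike=1e-9,        # numerical floor in the sunspike ratio
        k_warmup_steps=100,     # hold beta2 at betas[1] for the first steps
        k_logging=50,           # print Kourkoutas-β debug stats every 50 steps
    )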

{adv_optm-1.1.0.dev1 → adv_optm-1.1.0.dev2}/adv_optm/optim/Simplified_AdEMAMix.py

@@ -116,15 +116,14 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
             "vector_reshape": vector_reshape,
             "orthogonal_gradient": orthogonal_gradient, "use_bias_correction": use_bias_correction,
             "kourkoutas_beta": kourkoutas_beta, "beta2_min": beta2_min, "ema_alpha": ema_alpha,
-            "tiny_spike": tiny_spike, "k_warmup_steps": k_warmup_steps,
+            "tiny_spike": tiny_spike, "k_warmup_steps": k_warmup_steps, "k_logging": k_logging,
         }
         self.stochastic_rounding = stochastic_rounding
         self.factored = nnmf_factor
+        self.kourkoutas_beta = kourkoutas_beta
+        self.layer_key_fn = layer_key_fn
         super().__init__(params, defaults)
 
-        self.kourkoutas_beta = kourkoutas_beta
-        self.k_logging= k_logging and kourkoutas_beta
-        self.layer_key_fn = layer_key_fn and kourkoutas_beta
         if self.kourkoutas_beta:
             self.kourkoutas_helper = KourkoutasHelper(self)
 
@@ -189,17 +188,19 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
             state['num_sum'] = 1.0
             state['den_sum'] = 1.0
 
+        beta1_final, beta2 = group["betas"]
+
         current_step = state['step']
         if group['kourkoutas_beta']:
+            # Call prepare_step() once at the beginning of the step for all params
             self.kourkoutas_helper.maybe_prepare_step(current_step)
+            # Accumulate current grad's norm for the *next* step
             self.kourkoutas_helper.accumulate_gradient_sq_norm(p, grad)
+            # Get the dynamic beta2 calculated in prepare_step()
+            beta2 = self.kourkoutas_helper.get_beta2(p, group, current_step)
 
-        beta1_final, beta2 = group["betas"]
         beta1_warmup = group["beta1_warmup"]
         alpha_grad = group["alpha_grad"]
-
-        if group['kourkoutas_beta']:
-            beta2 = self.kourkoutas_helper.get_beta2(p, group, current_step)
 
         if beta1_warmup is not None:
             step = state['step'] + 1
@@ -294,14 +295,4 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
             for i, p in enumerate(group['params']):
                 self.step_parameter(p, group, i)
 
-        if self.kourkoutas_beta and self.k_logging > 0 and hasattr(self, '_beta2_log'):
-            first_param_state = self.state[self.param_groups[0]['params'][0]]
-            step_num = first_param_state['step']
-
-            if step_num > 0 and step_num % self.k_logging == 0:
-                if self._beta2_log:
-                    beta2_tensor = torch.tensor(self._beta2_log, device='cpu')
-                    print(f"Step {step_num}: Kourkoutas beta2 stats: Min={beta2_tensor.min():.4f}, Max={beta2_tensor.max():.4f}, Mean={beta2_tensor.mean():.4f}")
-            delattr(self, '_beta2_log')
-
         return loss
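
All four optimizers now store `layer_key_fn` unconditionally instead of and-ing it with `kourkoutas_beta`; when it is left unset, the helper falls back to `lambda p: id(p)` (one bucket per parameter, visible at the top of the Kourkoutas.py hunk below). A coarser custom key is possible; the following is an illustration only, not part of the package:

    # Bucket all parameters that share a tensor shape into one Kourkoutas layer.
    def shape_key_fn(p):
        return tuple(p.shape)

    # e.g. AdamW_adv(model.parameters(), kourkoutas_beta=True, layer_key_fn=shape_key_fn)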

{adv_optm-1.1.0.dev1 → adv_optm-1.1.0.dev2}/adv_optm/util/Kourkoutas.py

@@ -28,26 +28,32 @@ class KourkoutasHelper:
             self.optimizer.layer_key_fn = lambda p: id(p)
 
         for group in self.optimizer.param_groups:
-            if not group.get('kourkoutas_beta', False):
-                continue
             for p in group['params']:
                 if p.grad is None: continue
                 layer_key = self.optimizer.layer_key_fn(p)
                 if layer_key not in self.layer_info:
                     self.layer_info[layer_key] = {'params': [], 'group_ref': group}
                 self.layer_info[layer_key]['params'].append(p)
+
+        k_logging_interval = self.optimizer.param_groups[0].get('k_logging', 0)
+        if k_logging_interval > 0:
+            print(f"[Kourkoutas-β Debug] Layer info built. Found {len(self.layer_info)} unique layers/buckets.")
+
         self._layer_info_built = True
 
-    def prepare_step(self):
+    def prepare_step(self, current_step: int):
        """
        Calculates dynamic beta2 for all layers using the completed scalar accumulators
        from the PREVIOUS step. Should be called once at the start of an optimizer step.
        """
        self._build_layer_info_if_needed()
 
-        if
-
-
+        # Check if logging is enabled for this step based on the interval
+        k_logging_interval = self.optimizer.param_groups[0].get('k_logging', 0)
+        is_logging_step = k_logging_interval > 0 and (current_step + 1) % k_logging_interval == 0
+
+        beta2_log = [] if is_logging_step else None
+        first_layer_key = next(iter(self.layer_info), None)
 
        for layer_key, info in self.layer_info.items():
            params, group = info['params'], info['group_ref']
@@ -60,49 +66,69 @@ class KourkoutasHelper:
 
            layer_state = self.layer_state[layer_key]
 
+            # Use the completed accumulator from the previous step
            pooled_grad_norm = torch.sqrt(layer_state['sum_sq_accumulator'])
 
            r_ema = layer_state['r_ema_grad_norm']
+            prev_r_ema_val = r_ema.item()  # for logging
+
+            # EMA is always updated, even during warmup
            r_ema.mul_(group['ema_alpha']).add_(pooled_grad_norm, alpha=1.0 - group['ema_alpha'])
 
-
-            sun = raw / (1.0 + raw)
+            sun = torch.tensor(0.0, device=r_ema.device)  # Default sun to 0 for warmup
            beta2_max = group['betas'][1]
-
-
-
+
+            # --- CONSOLIDATED WARMUP LOGIC ---
+            if current_step < group['k_warmup_steps']:
+                beta2 = beta2_max
+            else:
+                raw = pooled_grad_norm / (r_ema + group['tiny_spike'])
+                sun = raw / (1.0 + raw)
+                beta2 = beta2_max - (beta2_max - group['beta2_min']) * sun
+
+            layer_state['dynamic_beta2'] = beta2.item() if isinstance(beta2, torch.Tensor) else beta2
            layer_state['sum_sq_accumulator'].zero_()
 
-            if
-
+            if is_logging_step:
+                beta2_log.append(layer_state['dynamic_beta2'])
+                if layer_key == first_layer_key:
+                    print(f"\n[Kourkoutas-β Debug] Step {current_step + 1} - Sample Layer '{layer_key}':")
+                    print(f"  - Grad Norm: {pooled_grad_norm.item():.4e}, Prev EMA: {prev_r_ema_val:.4e}, New EMA: {r_ema.item():.4e}")
+                    print(f"  - Sunspike: {sun.item():.4f}, Dynamic Beta2: {layer_state['dynamic_beta2']:.4f}")
+
+        if is_logging_step and beta2_log:
+            beta2_tensor = torch.tensor(beta2_log, device='cpu')
+            print(f"[Kourkoutas-β Debug] Step {current_step + 1} Overall Beta2 Stats: Min={beta2_tensor.min():.4f}, Max={beta2_tensor.max():.4f}, Mean={beta2_tensor.mean():.4f}")
+
 
    def maybe_prepare_step(self, current_step: int):
        """
        A universal guard that calls prepare_step() exactly once per training step.
        """
        if self._current_step_prepared < current_step:
-            self.prepare_step()
+            self.prepare_step(current_step)
            self._current_step_prepared = current_step
 
    def accumulate_gradient_sq_norm(self, p: torch.Tensor, grad: torch.Tensor):
        """
        Accumulates the squared L2 norm of a single gradient for the next step's calculation.
        """
+        self._build_layer_info_if_needed()
        layer_key = self.optimizer.layer_key_fn(p)
-
-
-
-
-
-
+
+        if layer_key in self.layer_info:
+            if layer_key not in self.layer_state:
+                self.layer_state[layer_key] = {
+                    'r_ema_grad_norm': torch.tensor(0.0, device=p.device, dtype=torch.float32),
+                    'sum_sq_accumulator': torch.tensor(0.0, device=p.device, dtype=torch.float32)
+                }
+            # Accumulate for the *next* step's prepare_step call
+            self.layer_state[layer_key]['sum_sq_accumulator'] += torch.sum(grad.detach().pow(2)).float()
 
    def get_beta2(self, p: torch.Tensor, group: dict, current_step: int) -> float:
        """
        Gets the appropriate beta2 for the current parameter, handling warmup and dynamic value fetching.
        """
-        beta2_default = group['betas'][1]
-        if current_step < group['k_warmup_steps']:
-            return 0.5 * (group['beta2_min'] + beta2_default)
-
        layer_key = self.optimizer.layer_key_fn(p)
-
+        # The default is the max value, which is correct for unmapped params or edge cases
+        return self.layer_state.get(layer_key, {}).get('dynamic_beta2', group['betas'][1])
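
Spelled out, the consolidated warmup/sunspike logic added to prepare_step() computes one beta2 per layer bucket and per step. The scalar restatement below mirrors the added lines (the function name is mine; tensor handling and the EMA update are intentionally omitted):

    def kourkoutas_dynamic_beta2(pooled_grad_norm, r_ema, beta2_max, beta2_min,
                                 tiny_spike, k_warmup_steps, current_step):
        if current_step < k_warmup_steps:
            return beta2_max                              # warmup: keep the static beta2
        raw = pooled_grad_norm / (r_ema + tiny_spike)     # spike ratio vs. the grad-norm EMA
        sun = raw / (1.0 + raw)                           # "sunspike" squashed into [0, 1)
        return beta2_max - (beta2_max - beta2_min) * sun  # bigger spike -> smaller beta2

A gradient-norm spike thus pulls that layer's beta2 from betas[1] toward beta2_min, while quiet layers stay near betas[1].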