adv-optm 1.1.0.dev1__tar.gz → 1.1.0.dev2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of adv-optm might be problematic.

Files changed (25)
  1. {adv_optm-1.1.0.dev1 → adv_optm-1.1.0.dev2}/PKG-INFO +1 -1
  2. {adv_optm-1.1.0.dev1 → adv_optm-1.1.0.dev2}/adv_optm/__init__.py +1 -1
  3. {adv_optm-1.1.0.dev1 → adv_optm-1.1.0.dev2}/adv_optm/optim/AdamW_adv.py +8 -17
  4. {adv_optm-1.1.0.dev1 → adv_optm-1.1.0.dev2}/adv_optm/optim/Adopt_adv.py +11 -16
  5. {adv_optm-1.1.0.dev1 → adv_optm-1.1.0.dev2}/adv_optm/optim/Prodigy_adv.py +13 -22
  6. {adv_optm-1.1.0.dev1 → adv_optm-1.1.0.dev2}/adv_optm/optim/Simplified_AdEMAMix.py +9 -18
  7. {adv_optm-1.1.0.dev1 → adv_optm-1.1.0.dev2}/adv_optm/util/Kourkoutas.py +51 -25
  8. {adv_optm-1.1.0.dev1 → adv_optm-1.1.0.dev2}/adv_optm.egg-info/PKG-INFO +1 -1
  9. {adv_optm-1.1.0.dev1 → adv_optm-1.1.0.dev2}/setup.py +1 -1
  10. {adv_optm-1.1.0.dev1 → adv_optm-1.1.0.dev2}/LICENSE +0 -0
  11. {adv_optm-1.1.0.dev1 → adv_optm-1.1.0.dev2}/README.md +0 -0
  12. {adv_optm-1.1.0.dev1 → adv_optm-1.1.0.dev2}/adv_optm/optim/Lion_Prodigy_adv.py +0 -0
  13. {adv_optm-1.1.0.dev1 → adv_optm-1.1.0.dev2}/adv_optm/optim/Lion_adv.py +0 -0
  14. {adv_optm-1.1.0.dev1 → adv_optm-1.1.0.dev2}/adv_optm/optim/__init__.py +0 -0
  15. {adv_optm-1.1.0.dev1 → adv_optm-1.1.0.dev2}/adv_optm/util/BF16_Stochastic_Rounding.py +0 -0
  16. {adv_optm-1.1.0.dev1 → adv_optm-1.1.0.dev2}/adv_optm/util/Effective_Shape.py +0 -0
  17. {adv_optm-1.1.0.dev1 → adv_optm-1.1.0.dev2}/adv_optm/util/NNMF.py +0 -0
  18. {adv_optm-1.1.0.dev1 → adv_optm-1.1.0.dev2}/adv_optm/util/One_Bit_Boolean.py +0 -0
  19. {adv_optm-1.1.0.dev1 → adv_optm-1.1.0.dev2}/adv_optm/util/OrthoGrad.py +0 -0
  20. {adv_optm-1.1.0.dev1 → adv_optm-1.1.0.dev2}/adv_optm/util/__init__.py +0 -0
  21. {adv_optm-1.1.0.dev1 → adv_optm-1.1.0.dev2}/adv_optm.egg-info/SOURCES.txt +0 -0
  22. {adv_optm-1.1.0.dev1 → adv_optm-1.1.0.dev2}/adv_optm.egg-info/dependency_links.txt +0 -0
  23. {adv_optm-1.1.0.dev1 → adv_optm-1.1.0.dev2}/adv_optm.egg-info/requires.txt +0 -0
  24. {adv_optm-1.1.0.dev1 → adv_optm-1.1.0.dev2}/adv_optm.egg-info/top_level.txt +0 -0
  25. {adv_optm-1.1.0.dev1 → adv_optm-1.1.0.dev2}/setup.cfg +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: adv_optm
- Version: 1.1.0.dev1
+ Version: 1.1.0.dev2
  Summary: A family of highly efficient, lightweight yet powerful optimizers.
  Home-page: https://github.com/Koratahiu/Advanced_Optimizers
  Author: Koratahiu
adv_optm/__init__.py
@@ -16,4 +16,4 @@ __all__ = [
  "Lion_Prodigy_adv",
  ]

- __version__ = "1.1.0.dev1"
+ __version__ = "1.1.0.dev2"
adv_optm/optim/AdamW_adv.py
@@ -128,18 +128,17 @@ class AdamW_adv(torch.optim.Optimizer):
  "orthogonal_gradient": orthogonal_gradient, "use_bias_correction": use_bias_correction,
  "beta3_ema": beta3_ema, "alpha": alpha, "t_alpha": t_alpha,
  "kourkoutas_beta": kourkoutas_beta, "beta2_min": beta2_min, "ema_alpha": ema_alpha,
- "tiny_spike": tiny_spike, "k_warmup_steps": k_warmup_steps,
+ "tiny_spike": tiny_spike, "k_warmup_steps": k_warmup_steps, "k_logging": k_logging,
  }
  self.stochastic_rounding = stochastic_rounding
  self.cautious_mask = cautious_mask
  self.grams_moment = grams_moment
  self.use_AdEMAMix = use_AdEMAMix
  self.factored = nnmf_factor
+ self.kourkoutas_beta = kourkoutas_beta
+ self.layer_key_fn = layer_key_fn
  super().__init__(params, defaults)

- self.kourkoutas_beta = kourkoutas_beta
- self.k_logging= k_logging and kourkoutas_beta
- self.layer_key_fn = layer_key_fn and kourkoutas_beta
  if self.kourkoutas_beta:
  self.kourkoutas_helper = KourkoutasHelper(self)

@@ -207,13 +206,15 @@ class AdamW_adv(torch.optim.Optimizer):
  state['exp_avg_slow'] = torch.zeros_like(p, device=device, dtype=dtype)
  state['exp_avg_sq'] = torch.zeros_like(p, device=device, dtype=dtype)

+ beta1, beta2 = group['betas']
+
  current_step = state['step']
  if group['kourkoutas_beta']:
+ # Call prepare_step() once at the beginning of the step for all params
  self.kourkoutas_helper.maybe_prepare_step(current_step)
+ # Accumulate current grad's norm for the *next* step
  self.kourkoutas_helper.accumulate_gradient_sq_norm(p, grad)
-
- beta1, beta2 = group['betas']
- if group['kourkoutas_beta']:
+ # Get the dynamic beta2 calculated in prepare_step()
  beta2 = self.kourkoutas_helper.get_beta2(p, group, current_step)

  step = state['step'] + 1
@@ -366,14 +367,4 @@ class AdamW_adv(torch.optim.Optimizer):
  for i, p in enumerate(group['params']):
  self.step_parameter(p, group, i)

- if self.kourkoutas_beta and self.k_logging > 0 and hasattr(self, '_beta2_log'):
- first_param_state = self.state[self.param_groups[0]['params'][0]]
- step_num = first_param_state['step']
-
- if step_num > 0 and step_num % self.k_logging == 0:
- if self._beta2_log:
- beta2_tensor = torch.tensor(self._beta2_log, device='cpu')
- print(f"Step {step_num}: Kourkoutas beta2 stats: Min={beta2_tensor.min():.4f}, Max={beta2_tensor.max():.4f}, Mean={beta2_tensor.mean():.4f}")
- delattr(self, '_beta2_log')
-
  return loss
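
Across the four optimizers touched in this release, these hunks consolidate the Kourkoutas-β handling into a single branch: the static betas are read from the param group up front, and the helper prepares the step, accumulates the gradient norm, and returns the dynamic beta2 in one place. A minimal sketch of the resulting per-parameter flow, recast here as a free function for readability (`helper` stands in for the optimizer's `kourkoutas_helper`; the surrounding state handling is omitted):

```python
from typing import Any, Dict
import torch

def resolve_betas(helper, group: Dict[str, Any], state: Dict[str, Any],
                  p: torch.Tensor, grad: torch.Tensor):
    """Sketch of the refactored flow: one branch resolves the dynamic beta2."""
    beta1, beta2 = group['betas']                  # static defaults from the param group
    current_step = state['step']
    if group['kourkoutas_beta']:
        # Runs the layer-wise beta2 computation once per optimizer step.
        helper.maybe_prepare_step(current_step)
        # This gradient's squared norm feeds the *next* step's computation.
        helper.accumulate_gradient_sq_norm(p, grad)
        # Override the static beta2 with the layer's dynamic value.
        beta2 = helper.get_beta2(p, group, current_step)
    return beta1, beta2
```

The same restructuring appears in Adopt_adv, Prodigy_adv (which falls back to `self.beta2_default` in an explicit `else`), and Simplified_AdEMAMix below. The per-step beta2 statistics printout that each `step()` method used to carry is deleted here and reimplemented inside `KourkoutasHelper.prepare_step()`, driven by the new `k_logging` group entry.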
adv_optm/optim/Adopt_adv.py
@@ -157,7 +157,7 @@ class Adopt_adv(torch.optim.Optimizer):
  "vector_reshape": vector_reshape, "beta3_ema": beta3_ema, "alpha": alpha,
  "t_alpha": t_alpha, "alpha_grad": alpha_grad,
  "kourkoutas_beta": kourkoutas_beta, "beta2_min": beta2_min, "ema_alpha": ema_alpha,
- "tiny_spike": tiny_spike, "k_warmup_steps": k_warmup_steps,
+ "tiny_spike": tiny_spike, "k_warmup_steps": k_warmup_steps, "k_logging": k_logging,
  }
  self.clip_lambda = clip_lambda
  self.stochastic_rounding = stochastic_rounding
@@ -168,11 +168,10 @@ class Adopt_adv(torch.optim.Optimizer):
  self.use_AdEMAMix = use_AdEMAMix and not Simplified_AdEMAMix
  self.Simplified_AdEMAMix = Simplified_AdEMAMix
  self.factored = nnmf_factor
+ self.kourkoutas_beta = kourkoutas_beta
+ self.layer_key_fn = layer_key_fn
  super().__init__(params, defaults)

- self.kourkoutas_beta = kourkoutas_beta
- self.k_logging= k_logging and kourkoutas_beta
- self.layer_key_fn = layer_key_fn and kourkoutas_beta
  if self.kourkoutas_beta:
  self.kourkoutas_helper = KourkoutasHelper(self)

@@ -238,13 +237,15 @@ class Adopt_adv(torch.optim.Optimizer):
  state['exp_avg_slow'] = torch.zeros_like(p, dtype=dtype)
  state['exp_avg_sq'] = grad.square() # v_0

+ beta1, beta2 = group['betas']
+
  current_step = state['step']
  if group['kourkoutas_beta']:
+ # Call prepare_step() once at the beginning of the step for all params
  self.kourkoutas_helper.maybe_prepare_step(current_step)
+ # Accumulate current grad's norm for the *next* step
  self.kourkoutas_helper.accumulate_gradient_sq_norm(p, grad)
-
- beta1, beta2 = group['betas']
- if group['kourkoutas_beta']:
+ # Get the dynamic beta2 calculated in prepare_step()
  beta2 = self.kourkoutas_helper.get_beta2(p, group, current_step)

  # The first step is for initialization only (skip when use_atan2 as it's scale invariant).
@@ -257,10 +258,10 @@ class Adopt_adv(torch.optim.Optimizer):
  alpha = group['alpha']
  t_alpha = group['t_alpha']
  # Use step+1 for 1-based step count in scheduler
- current_step = state['step'] + 1
+ alpha_step = state['step'] + 1
  alpha_t = alpha
- if t_alpha is not None and t_alpha > 0 and current_step < t_alpha:
- alpha_t = min(current_step * alpha / t_alpha, alpha)
+ if t_alpha is not None and t_alpha > 0 and alpha_step < t_alpha:
+ alpha_t = min(alpha_step * alpha / t_alpha, alpha)
  if self.Simplified_AdEMAMix:
  alpha_grad = group["alpha_grad"]

@@ -436,10 +437,4 @@ class Adopt_adv(torch.optim.Optimizer):
  first_param_state = self.state[self.param_groups[0]['params'][0]]
  step_num = first_param_state['step']

- if step_num > 0 and step_num % self.k_logging == 0:
- if self._beta2_log:
- beta2_tensor = torch.tensor(self._beta2_log, device='cpu')
- print(f"Step {step_num}: Kourkoutas beta2 stats: Min={beta2_tensor.min():.4f}, Max={beta2_tensor.max():.4f}, Mean={beta2_tensor.mean():.4f}")
- delattr(self, '_beta2_log')
-
  return loss
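
Besides the same constructor and Kourkoutas-β changes as AdamW_adv, this file (and Prodigy_adv below) renames the AdEMAMix scheduler's 1-based counter from `current_step` to `alpha_step`, so it no longer shadows the 0-based `current_step` that the Kourkoutas helper consumes. The linear alpha warmup itself is unchanged; a self-contained illustration of that schedule (the numeric values are hypothetical, not package defaults):

```python
from typing import Optional

def alpha_schedule(alpha_step: int, alpha: float, t_alpha: Optional[int]) -> float:
    """Linear warmup of the AdEMAMix alpha, as in the hunk above."""
    alpha_t = alpha
    if t_alpha is not None and t_alpha > 0 and alpha_step < t_alpha:
        alpha_t = min(alpha_step * alpha / t_alpha, alpha)
    return alpha_t

print(alpha_schedule(250, alpha=8.0, t_alpha=1000))   # 2.0 -> still ramping up
print(alpha_schedule(2000, alpha=8.0, t_alpha=1000))  # 8.0 -> fully warmed up
```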
adv_optm/optim/Prodigy_adv.py
@@ -189,7 +189,7 @@ class Prodigy_adv(torch.optim.Optimizer):
  "fsdp_in_use": fsdp_in_use, "prodigy_steps": prodigy_steps,
  "alpha_grad": alpha_grad,
  "kourkoutas_beta": kourkoutas_beta, "beta2_min": beta2_min, "ema_alpha": ema_alpha,
- "tiny_spike": tiny_spike, "k_warmup_steps": k_warmup_steps,
+ "tiny_spike": tiny_spike, "k_warmup_steps": k_warmup_steps, "k_logging": k_logging,
  }
  self.stochastic_rounding = stochastic_rounding
  self.cautious_mask = cautious_mask and not Simplified_AdEMAMix
@@ -198,14 +198,13 @@ class Prodigy_adv(torch.optim.Optimizer):
  self.Simplified_AdEMAMix = Simplified_AdEMAMix
  self.factored = nnmf_factor
  self.fsdp_in_use = fsdp_in_use
- super().__init__(params, defaults)
-
+
  self.kourkoutas_beta = kourkoutas_beta
- self.k_logging= k_logging and kourkoutas_beta
- self.layer_key_fn = layer_key_fn and kourkoutas_beta
+ self.layer_key_fn = layer_key_fn
+
+ super().__init__(params, defaults)
  if self.kourkoutas_beta:
  self.kourkoutas_helper = KourkoutasHelper(self)
-
  self.init_step()

  @property
@@ -301,21 +300,23 @@ class Prodigy_adv(torch.optim.Optimizer):

  current_step = state['step']
  if group['kourkoutas_beta']:
+ # Call prepare_step() once at the beginning of the step for all params
  self.kourkoutas_helper.maybe_prepare_step(current_step)
+ # Accumulate current grad's norm for the *next* step
  self.kourkoutas_helper.accumulate_gradient_sq_norm(p, grad)
-
- beta2 = self.beta2_default
- if group['kourkoutas_beta']:
+ # Get the dynamic beta2 calculated in prepare_step()
  beta2 = self.kourkoutas_helper.get_beta2(p, group, current_step)
+ else:
+ beta2 = self.beta2_default

  if self.use_AdEMAMix:
  beta3_ema = group['beta3_ema']
  alpha = group['alpha']
  t_alpha = group['t_alpha']
- current_step = state['step'] + 1
+ alpha_step = state['step'] + 1
  alpha_t = alpha
- if t_alpha is not None and t_alpha > 0 and current_step < t_alpha:
- alpha_t = min(current_step * alpha / t_alpha, alpha)
+ if t_alpha is not None and t_alpha > 0 and alpha_step < t_alpha:
+ alpha_t = min(alpha_step * alpha / t_alpha, alpha)
  if self.Simplified_AdEMAMix:
  alpha_grad = group["alpha_grad"]

@@ -481,16 +482,6 @@ class Prodigy_adv(torch.optim.Optimizer):
  for i, p in enumerate(group['params']):
  self.step_parameter(p, group, i)

- if self.kourkoutas_beta and self.k_logging > 0 and hasattr(self, '_beta2_log'):
- first_param_state = self.state[self.param_groups[0]['params'][0]]
- step_num = first_param_state['step']
-
- if step_num > 0 and step_num % self.k_logging == 0:
- if self._beta2_log:
- beta2_tensor = torch.tensor(self._beta2_log, device='cpu')
- print(f"Step {step_num}: Kourkoutas beta2 stats: Min={beta2_tensor.min():.4f}, Max={beta2_tensor.max():.4f}, Mean={beta2_tensor.mean():.4f}")
- delattr(self, '_beta2_log')
-
  self.calculate_d()
  self.init_step()
  return loss
adv_optm/optim/Simplified_AdEMAMix.py
@@ -116,15 +116,14 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
  "vector_reshape": vector_reshape,
  "orthogonal_gradient": orthogonal_gradient, "use_bias_correction": use_bias_correction,
  "kourkoutas_beta": kourkoutas_beta, "beta2_min": beta2_min, "ema_alpha": ema_alpha,
- "tiny_spike": tiny_spike, "k_warmup_steps": k_warmup_steps,
+ "tiny_spike": tiny_spike, "k_warmup_steps": k_warmup_steps, "k_logging": k_logging,
  }
  self.stochastic_rounding = stochastic_rounding
  self.factored = nnmf_factor
+ self.kourkoutas_beta = kourkoutas_beta
+ self.layer_key_fn = layer_key_fn
  super().__init__(params, defaults)

- self.kourkoutas_beta = kourkoutas_beta
- self.k_logging= k_logging and kourkoutas_beta
- self.layer_key_fn = layer_key_fn and kourkoutas_beta
  if self.kourkoutas_beta:
  self.kourkoutas_helper = KourkoutasHelper(self)

@@ -189,17 +188,19 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
  state['num_sum'] = 1.0
  state['den_sum'] = 1.0

+ beta1_final, beta2 = group["betas"]
+
  current_step = state['step']
  if group['kourkoutas_beta']:
+ # Call prepare_step() once at the beginning of the step for all params
  self.kourkoutas_helper.maybe_prepare_step(current_step)
+ # Accumulate current grad's norm for the *next* step
  self.kourkoutas_helper.accumulate_gradient_sq_norm(p, grad)
+ # Get the dynamic beta2 calculated in prepare_step()
+ beta2 = self.kourkoutas_helper.get_beta2(p, group, current_step)

- beta1_final, beta2 = group["betas"]
  beta1_warmup = group["beta1_warmup"]
  alpha_grad = group["alpha_grad"]
-
- if group['kourkoutas_beta']:
- beta2 = self.kourkoutas_helper.get_beta2(p, group, current_step)

  if beta1_warmup is not None:
  step = state['step'] + 1
@@ -294,14 +295,4 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
  for i, p in enumerate(group['params']):
  self.step_parameter(p, group, i)

- if self.kourkoutas_beta and self.k_logging > 0 and hasattr(self, '_beta2_log'):
- first_param_state = self.state[self.param_groups[0]['params'][0]]
- step_num = first_param_state['step']
-
- if step_num > 0 and step_num % self.k_logging == 0:
- if self._beta2_log:
- beta2_tensor = torch.tensor(self._beta2_log, device='cpu')
- print(f"Step {step_num}: Kourkoutas beta2 stats: Min={beta2_tensor.min():.4f}, Max={beta2_tensor.max():.4f}, Mean={beta2_tensor.mean():.4f}")
- delattr(self, '_beta2_log')
-
  return loss
adv_optm/util/Kourkoutas.py
@@ -28,26 +28,32 @@ class KourkoutasHelper:
  self.optimizer.layer_key_fn = lambda p: id(p)

  for group in self.optimizer.param_groups:
- if not group.get('kourkoutas_beta', False):
- continue
  for p in group['params']:
  if p.grad is None: continue
  layer_key = self.optimizer.layer_key_fn(p)
  if layer_key not in self.layer_info:
  self.layer_info[layer_key] = {'params': [], 'group_ref': group}
  self.layer_info[layer_key]['params'].append(p)
+
+ k_logging_interval = self.optimizer.param_groups[0].get('k_logging', 0)
+ if k_logging_interval > 0:
+ print(f"[Kourkoutas-β Debug] Layer info built. Found {len(self.layer_info)} unique layers/buckets.")
+
  self._layer_info_built = True

- def prepare_step(self):
+ def prepare_step(self, current_step: int):
  """
  Calculates dynamic beta2 for all layers using the completed scalar accumulators
  from the PREVIOUS step. Should be called once at the start of an optimizer step.
  """
  self._build_layer_info_if_needed()

- if hasattr(self.optimizer, 'logging') and self.optimizer.logging:
- if not hasattr(self.optimizer, '_beta2_log'):
- self.optimizer._beta2_log = []
+ # Check if logging is enabled for this step based on the interval
+ k_logging_interval = self.optimizer.param_groups[0].get('k_logging', 0)
+ is_logging_step = k_logging_interval > 0 and (current_step + 1) % k_logging_interval == 0
+
+ beta2_log = [] if is_logging_step else None
+ first_layer_key = next(iter(self.layer_info), None)

  for layer_key, info in self.layer_info.items():
  params, group = info['params'], info['group_ref']
@@ -60,49 +66,69 @@ class KourkoutasHelper:

  layer_state = self.layer_state[layer_key]

+ # Use the completed accumulator from the previous step
  pooled_grad_norm = torch.sqrt(layer_state['sum_sq_accumulator'])

  r_ema = layer_state['r_ema_grad_norm']
+ prev_r_ema_val = r_ema.item() # for logging
+
+ # EMA is always updated, even during warmup
  r_ema.mul_(group['ema_alpha']).add_(pooled_grad_norm, alpha=1.0 - group['ema_alpha'])

- raw = pooled_grad_norm / (r_ema + group['tiny_spike'])
- sun = raw / (1.0 + raw)
+ sun = torch.tensor(0.0, device=r_ema.device) # Default sun to 0 for warmup
  beta2_max = group['betas'][1]
- beta2 = beta2_max - (beta2_max - group['beta2_min']) * sun
-
- layer_state['dynamic_beta2'] = beta2.item()
+
+ # --- CONSOLIDATED WARMUP LOGIC ---
+ if current_step < group['k_warmup_steps']:
+ beta2 = beta2_max
+ else:
+ raw = pooled_grad_norm / (r_ema + group['tiny_spike'])
+ sun = raw / (1.0 + raw)
+ beta2 = beta2_max - (beta2_max - group['beta2_min']) * sun
+
+ layer_state['dynamic_beta2'] = beta2.item() if isinstance(beta2, torch.Tensor) else beta2
  layer_state['sum_sq_accumulator'].zero_()

- if hasattr(self.optimizer, 'logging') and self.optimizer.logging and hasattr(self.optimizer, '_beta2_log'):
- self.optimizer._beta2_log.append(beta2.item())
+ if is_logging_step:
+ beta2_log.append(layer_state['dynamic_beta2'])
+ if layer_key == first_layer_key:
+ print(f"\n[Kourkoutas-β Debug] Step {current_step + 1} - Sample Layer '{layer_key}':")
+ print(f" - Grad Norm: {pooled_grad_norm.item():.4e}, Prev EMA: {prev_r_ema_val:.4e}, New EMA: {r_ema.item():.4e}")
+ print(f" - Sunspike: {sun.item():.4f}, Dynamic Beta2: {layer_state['dynamic_beta2']:.4f}")
+
+ if is_logging_step and beta2_log:
+ beta2_tensor = torch.tensor(beta2_log, device='cpu')
+ print(f"[Kourkoutas-β Debug] Step {current_step + 1} Overall Beta2 Stats: Min={beta2_tensor.min():.4f}, Max={beta2_tensor.max():.4f}, Mean={beta2_tensor.mean():.4f}")
+

  def maybe_prepare_step(self, current_step: int):
  """
  A universal guard that calls prepare_step() exactly once per training step.
  """
  if self._current_step_prepared < current_step:
- self.prepare_step()
+ self.prepare_step(current_step)
  self._current_step_prepared = current_step

  def accumulate_gradient_sq_norm(self, p: torch.Tensor, grad: torch.Tensor):
  """
  Accumulates the squared L2 norm of a single gradient for the next step's calculation.
  """
+ self._build_layer_info_if_needed()
  layer_key = self.optimizer.layer_key_fn(p)
- if layer_key not in self.layer_state:
- self.layer_state[layer_key] = {
- 'r_ema_grad_norm': torch.tensor(0.0, device=p.device, dtype=torch.float32),
- 'sum_sq_accumulator': torch.tensor(0.0, device=p.device, dtype=torch.float32)
- }
- self.layer_state[layer_key]['sum_sq_accumulator'] += torch.sum(grad.detach().pow(2)).float()
+
+ if layer_key in self.layer_info:
+ if layer_key not in self.layer_state:
+ self.layer_state[layer_key] = {
+ 'r_ema_grad_norm': torch.tensor(0.0, device=p.device, dtype=torch.float32),
+ 'sum_sq_accumulator': torch.tensor(0.0, device=p.device, dtype=torch.float32)
+ }
+ # Accumulate for the *next* step's prepare_step call
+ self.layer_state[layer_key]['sum_sq_accumulator'] += torch.sum(grad.detach().pow(2)).float()

  def get_beta2(self, p: torch.Tensor, group: dict, current_step: int) -> float:
  """
  Gets the appropriate beta2 for the current parameter, handling warmup and dynamic value fetching.
  """
- beta2_default = group['betas'][1]
- if current_step < group['k_warmup_steps']:
- return 0.5 * (group['beta2_min'] + beta2_default)
-
  layer_key = self.optimizer.layer_key_fn(p)
- return self.layer_state.get(layer_key, {}).get('dynamic_beta2', beta2_default)
+ # The default is the max value, which is correct for unmapped params or edge cases
+ return self.layer_state.get(layer_key, {}).get('dynamic_beta2', group['betas'][1])
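
The behavioral changes in the helper are: warmup handling moves from get_beta2() into prepare_step(), so during the first k_warmup_steps each layer now gets beta2_max instead of the old midpoint 0.5 * (beta2_min + beta2_max), while the EMA of the pooled gradient norm still updates; accumulation is guarded so only parameters present in layer_info contribute; and logging is interval-based via the new `k_logging` group entry rather than an optimizer-level `logging` flag and `_beta2_log` buffer. A standalone sketch of the resulting per-layer rule (the numbers are hypothetical and only illustrate the shape of the mapping):

```python
import torch

def kourkoutas_beta2(pooled_grad_norm: torch.Tensor, r_ema: torch.Tensor,
                     beta2_max: float, beta2_min: float, tiny_spike: float,
                     current_step: int, k_warmup_steps: int) -> float:
    """Per-layer dynamic beta2, mirroring prepare_step() after this change."""
    if current_step < k_warmup_steps:
        return beta2_max                           # warmup: pin beta2 at its maximum
    raw = pooled_grad_norm / (r_ema + tiny_spike)  # spike ratio vs. the norm's EMA
    sun = raw / (1.0 + raw)                        # "sunspike", squashed into [0, 1)
    return float(beta2_max - (beta2_max - beta2_min) * sun)

# A gradient norm far above its EMA pushes sun toward 1 and beta2 toward beta2_min.
print(kourkoutas_beta2(torch.tensor(8.0), torch.tensor(1.0),
                       beta2_max=0.999, beta2_min=0.88, tiny_spike=1e-9,
                       current_step=500, k_warmup_steps=100))   # ≈ 0.893
```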
adv_optm.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: adv_optm
- Version: 1.1.0.dev1
+ Version: 1.1.0.dev2
  Summary: A family of highly efficient, lightweight yet powerful optimizers.
  Home-page: https://github.com/Koratahiu/Advanced_Optimizers
  Author: Koratahiu
setup.py
@@ -5,7 +5,7 @@ with open("README.md", "r", encoding="utf-8") as fh:

  setup(
  name="adv_optm",
- version="1.1.0.dev1",
+ version="1.1.0.dev2",
  author="Koratahiu",
  author_email="hiuhonor@gmail.com",
  license='Apache 2.0',