adv-optm 1.0.5__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.



@@ -3,15 +3,18 @@ import torch.distributed as dist
3
3
 
4
4
  import math
5
5
 
6
+ from typing import Optional, Callable
7
+
6
8
  from ..util.BF16_Stochastic_Rounding import add_stochastic_
7
9
  from ..util.Effective_Shape import _get_effective_shape
8
10
  from ..util.NNMF import _nnmf,_unnmf
9
11
  from ..util.OrthoGrad import _orthogonalize_gradient
10
12
  from ..util.One_Bit_Boolean import _pack_bools, _unpack_bools
13
+ from ..util.Kourkoutas import KourkoutasHelper
11
14
 
12
15
  class Prodigy_adv(torch.optim.Optimizer):
13
16
  """
14
- Implements a factored Prodigy/AdamW algorithm.
17
+ Implements an advanced Prodigy algorithm.
15
18
  This is an advanced version of Prodigy with optional features like
16
19
  low-rank factorization of optimizer states (SMMF), OrthoGrad, etc.
17
20
 
@@ -85,6 +88,31 @@ class Prodigy_adv(torch.optim.Optimizer):
85
88
  prodigy_steps (int): If greater than zero, disable Prodigy's stepsize adjustments
86
89
  after the specified optimiser step and release all state memory required by Prodigy
87
90
  (default: 0).
91
+ d_limiter (bool): whether to clamp the new step size estimate (`d_hat`)
92
+ to prevent sudden, volatile increases in the adaptive step size (`d`).
93
+ (default: False)
94
+ kourkoutas_beta (bool): whether to enable the layer-wise dynamic β₂ logic.
95
+ If `False`, the optimizer behaves as standard AdamW/Prodigy. (default: False)
96
+ beta2_min (float): The minimum value for dynamic β₂, used during periods of
97
+ high gradient variance ("sunspikes"). Must be less than `betas[1]`.
98
+ (default: 0.9)
99
+ ema_alpha (float): The decay rate for the Exponential Moving Average (EMA) of
100
+ the pooled gradient norms. Corresponds to `α` in the paper.
101
+ (default: 0.95)
102
+ tiny_spike (float): A small constant added to the denominator of the
103
+ "sunspike" ratio calculation to prevent division by zero. Corresponds
104
+ to `ε_spike` in the paper. (default: 1e-9)
105
+ k_warmup_steps (int): The number of initial steps during which β₂ is held
106
+ at its maximum value (`betas[1]`) before the
107
+ dynamic logic activates. (default: 0)
108
+ k_logging (int): if > 0 and kourkoutas_beta=True, enables periodic console
109
+ logging of Kourkoutas-β statistics (min, max, mean of `β₂` across layers)
110
+ every `k_logging` steps. Useful for debugging and tuning. Set to 0 to disable
111
+ logging (default: 0).
112
+ layer_key_fn (Optional[Callable]): A function that takes a parameter `p`
113
+ and returns a unique, hashable key representing its "layer" or "bucket".
114
+ If `None`, parameters are bucketed by their memory ID (tensor-wise).
115
+ (default: None)
88
116
  """
89
117
 
90
118
  def __init__(
@@ -116,6 +144,15 @@ class Prodigy_adv(torch.optim.Optimizer):
116
144
  fsdp_in_use: bool = False,
117
145
  slice_p: int = 11,
118
146
  prodigy_steps: int = 0,
147
+ d_limiter: bool = False,
148
+ # K-b parameters
149
+ kourkoutas_beta: bool = False,
150
+ beta2_min: float = 0.9,
151
+ ema_alpha: float = 0.95,
152
+ tiny_spike: float = 1e-9,
153
+ k_warmup_steps: int = 0,
154
+ k_logging: int = 0,
155
+ layer_key_fn: Optional[Callable] = None,
119
156
  ):
120
157
  if not (lr >= 0.0):
121
158
  raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
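The hunk above adds the Kourkoutas-β and `d_limiter` keyword arguments to `Prodigy_adv.__init__`. A minimal construction sketch follows; the model, learning rate, and top-level import path are assumptions, while the keyword names come from the new signature:

```python
import torch.nn as nn
from adv_optm import Prodigy_adv  # assumed export path

model = nn.Sequential(nn.Linear(64, 128), nn.ReLU(), nn.Linear(128, 10))

# Bucket parameters by their owning module so each layer shares one dynamic beta2;
# with layer_key_fn=None the helper falls back to per-tensor bucketing via id(p).
owner = {id(p): name for name, m in model.named_modules()
         for p in m.parameters(recurse=False)}

optimizer = Prodigy_adv(
    model.parameters(),
    lr=1.0,                  # placeholder Prodigy-style setting; adjust as needed
    kourkoutas_beta=True,    # enable layer-wise dynamic beta2
    beta2_min=0.9,           # lower bound reached on gradient "sunspikes"
    ema_alpha=0.95,          # decay of the pooled gradient-norm EMA
    k_warmup_steps=100,      # hold beta2 at betas[1] for the first 100 steps
    k_logging=50,            # print beta2 stats every 50 steps
    layer_key_fn=lambda p: owner[id(p)],
    d_limiter=True,          # clamp sudden jumps in the adaptive step size d
)
```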
@@ -141,8 +178,10 @@ class Prodigy_adv(torch.optim.Optimizer):
141
178
  if use_atan2 and Simplified_AdEMAMix:
142
179
  print("Warning: use_atan2 is incompatible with Simplified_AdEMAMix. Disabling use_atan2.")
143
180
  use_atan2 = False
144
- if Simplified_AdEMAMix and alpha_grad > 0:
145
- # scales d_coef by alpha_grad, this force prodigy to behave well with Simplified_AdEMAMix
181
+ if kourkoutas_beta and not (betas[1] > beta2_min):
182
+ raise ValueError(f"For Kourkoutas-β, betas[1] (as beta2_max) must be > beta2_min. Got {betas[1]} and {beta2_min}")
183
+ if Simplified_AdEMAMix and alpha_grad > 0 and not d_limiter:
184
+ # scales d_coef by alpha_grad; this forces Prodigy to behave well with Simplified_AdEMAMix.
146
185
  d_coef = d_coef/alpha_grad
147
186
 
148
187
  defaults = {
@@ -152,8 +191,10 @@ class Prodigy_adv(torch.optim.Optimizer):
152
191
  "beta3_ema": beta3_ema, "alpha": alpha, "t_alpha": t_alpha,
153
192
  "beta3": beta3, "d": d0, "d0": d0, "d_max": d0, "d_numerator": 0.0, "d_coef": d_coef,
154
193
  "growth_rate": growth_rate, "safeguard_warmup": safeguard_warmup, "k": 0, "slice_p": slice_p,
155
- "fsdp_in_use": fsdp_in_use, "prodigy_steps": prodigy_steps,
156
- "alpha_grad": alpha_grad,
194
+ "fsdp_in_use": fsdp_in_use, "prodigy_steps": prodigy_steps, "d_limiter": d_limiter,
195
+ "alpha_grad": alpha_grad,
196
+ "kourkoutas_beta": kourkoutas_beta, "beta2_min": beta2_min, "ema_alpha": ema_alpha,
197
+ "tiny_spike": tiny_spike, "k_warmup_steps": k_warmup_steps, "k_logging": k_logging,
157
198
  }
158
199
  self.stochastic_rounding = stochastic_rounding
159
200
  self.cautious_mask = cautious_mask and not Simplified_AdEMAMix
@@ -162,7 +203,13 @@ class Prodigy_adv(torch.optim.Optimizer):
162
203
  self.Simplified_AdEMAMix = Simplified_AdEMAMix
163
204
  self.factored = nnmf_factor
164
205
  self.fsdp_in_use = fsdp_in_use
206
+
207
+ self.kourkoutas_beta = kourkoutas_beta
208
+ self.layer_key_fn = layer_key_fn
209
+
165
210
  super().__init__(params, defaults)
211
+ if self.kourkoutas_beta:
212
+ self.kourkoutas_helper = KourkoutasHelper(self)
166
213
  self.init_step()
167
214
 
168
215
  @property
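The `KourkoutasHelper` wired in above drives the per-layer β₂ described in the docstring (`beta2_min`, `ema_alpha`, `tiny_spike`). Distilled to plain Python, and mirroring the `prepare_step()` logic in the new `Kourkoutas.py` further down this diff, the per-layer update is roughly:

```python
def dynamic_beta2(pooled_grad_norm, r_ema, beta2_max, beta2_min=0.9,
                  ema_alpha=0.95, tiny_spike=1e-9):
    """Per-layer beta2 from the pooled gradient norm of the previous step."""
    # Update the running average of the pooled gradient norm.
    r_ema = ema_alpha * r_ema + (1.0 - ema_alpha) * pooled_grad_norm
    # "Sunspike" ratio in [0, 1): spikes relative to the EMA push it toward 1.
    raw = pooled_grad_norm / (r_ema + tiny_spike)
    sun = raw / (1.0 + raw)
    # Interpolate between beta2_max (calm) and beta2_min (spiking).
    return beta2_max - (beta2_max - beta2_min) * sun, r_ema
```

A sudden jump in a layer's pooled gradient norm drives `sun` toward 1 and β₂ toward `beta2_min`, shortening the second-moment memory exactly when gradients turn volatile; calm layers stay near `betas[1]`.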
@@ -180,19 +227,17 @@ class Prodigy_adv(torch.optim.Optimizer):
180
227
  def init_step(self):
181
228
  """Resets accumulators and calculates dlr for the upcoming step."""
182
229
  self.d_denom = 0.0
183
-
230
+
184
231
  g_group = self.param_groups[0]
185
- self.beta1, self.beta2 = g_group['betas']
232
+ self.beta1, self.beta2_default = g_group['betas']
186
233
  self.beta3 = g_group['beta3']
187
234
  if self.beta3 is None:
188
- self.beta3 = math.sqrt(self.beta2)
189
-
190
- k = g_group['k']
235
+ self.beta3 = math.sqrt(self.beta2_default)
236
+
191
237
  self.d = g_group['d']
192
238
  lr = g_group['lr']
193
239
 
194
240
  self.dlr = self.d * lr
195
-
196
241
  self.d_numerator = g_group.get('d_numerator', 0.0) * self.beta3
197
242
 
198
243
  @torch.no_grad()
@@ -211,7 +256,7 @@ class Prodigy_adv(torch.optim.Optimizer):
211
256
  state = self.state[p]
212
257
 
213
258
  # State Initialization
214
- if len(state) == 0:
259
+ if 'step' not in state:
215
260
  state['step'] = 0
216
261
 
217
262
  should_factor = (
@@ -258,14 +303,27 @@ class Prodigy_adv(torch.optim.Optimizer):
258
303
  else:
259
304
  state['p0'] = torch.tensor(0, device=device, dtype=p.dtype)
260
305
 
306
+ current_step = state['step']
307
+ if group['kourkoutas_beta']:
308
+ # Call prepare_step() once at the beginning of the step for all params
309
+ self.kourkoutas_helper.maybe_prepare_step(current_step)
310
+ # Accumulate current grad's norm for the *next* step
311
+ self.kourkoutas_helper.accumulate_gradient_sq_norm(p, grad)
312
+ # Get the dynamic beta2 calculated in prepare_step()
313
+ beta2 = self.kourkoutas_helper.get_beta2(p, group, current_step)
314
+ beta3 = math.sqrt(beta2)
315
+ else:
316
+ beta2 = self.beta2_default
317
+ beta3 = self.beta3
318
+
261
319
  if self.use_AdEMAMix:
262
320
  beta3_ema = group['beta3_ema']
263
321
  alpha = group['alpha']
264
322
  t_alpha = group['t_alpha']
265
- current_step = state['step'] + 1
323
+ alpha_step = state['step'] + 1
266
324
  alpha_t = alpha
267
- if t_alpha is not None and t_alpha > 0 and current_step < t_alpha:
268
- alpha_t = min(current_step * alpha / t_alpha, alpha)
325
+ if t_alpha is not None and t_alpha > 0 and alpha_step < t_alpha:
326
+ alpha_t = min(alpha_step * alpha / t_alpha, alpha)
269
327
  if self.Simplified_AdEMAMix:
270
328
  alpha_grad = group["alpha_grad"]
271
329
 
@@ -295,7 +353,7 @@ class Prodigy_adv(torch.optim.Optimizer):
295
353
  del mask
296
354
 
297
355
  vt = _unnmf((state['mu_v_nmf'], state['mv_v_nmf']))
298
- vt.mul_(self.beta2).addcmul_(grad_reshaped, grad_reshaped, value=self.d * self.d * (1.0 - self.beta2))
356
+ vt.mul_(beta2).addcmul_(grad_reshaped, grad_reshaped, value=self.d * self.d * (1.0 - beta2))
299
357
 
300
358
  if self.use_AdEMAMix:
301
359
  mt_slow = _unnmf((state['mu_m_slow_nmf'], state['mv_m_slow_nmf']))
@@ -308,11 +366,11 @@ class Prodigy_adv(torch.optim.Optimizer):
308
366
  if self.beta1 > 0:
309
367
  update = torch.add(mt, mt_slow, alpha=alpha_t)
310
368
  else:
311
- update = torch.add(grad_reshaped, mt_slow, alpha=alpha_t)
369
+ update = torch.add(grad_reshaped.mul(self.d), mt_slow, alpha=alpha_t)
312
370
  elif self.Simplified_AdEMAMix:
313
371
  update = torch.add(mt, grad_reshaped, alpha=alpha_grad * self.d)
314
372
  else:
315
- update = mt.clone() if self.beta1 > 0 else grad_reshaped.clone()
373
+ update = mt.clone() if self.beta1 > 0 else grad_reshaped.mul(self.d)
316
374
  del grad_reshaped
317
375
 
318
376
  if group['use_atan2']:
@@ -362,13 +420,13 @@ class Prodigy_adv(torch.optim.Optimizer):
362
420
  if self.beta1 > 0:
363
421
  update = torch.add(exp_avg, exp_avg_slow, alpha=alpha_t)
364
422
  else:
365
- update = torch.add(grad, exp_avg_slow, alpha=alpha_t)
423
+ update = torch.add(grad.mul(self.d), exp_avg_slow, alpha=alpha_t)
366
424
  elif self.Simplified_AdEMAMix:
367
425
  update = torch.add(exp_avg, grad, alpha=alpha_grad * self.d)
368
426
  else:
369
- update = exp_avg.clone() if self.beta1 > 0 else grad.clone()
427
+ update = exp_avg.clone() if self.beta1 > 0 else grad.mul(self.d)
370
428
 
371
- exp_avg_sq.mul_(self.beta2).addcmul_(grad, grad.conj(), value=self.d * self.d * (1.0 - self.beta2))
429
+ exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=self.d * self.d * (1.0 - beta2))
372
430
 
373
431
  if group['use_atan2']:
374
432
  a = 1.2732395
@@ -393,7 +451,7 @@ class Prodigy_adv(torch.optim.Optimizer):
393
451
  self.d_numerator += (self.d / d0) * self.dlr * torch.dot(grad_flat[::slice_p], p0.data - p_flat[::slice_p]).item()
394
452
 
395
453
  alpha = ((self.d / d0) * self.d) if safeguard_warmup else ((self.d / d0) * self.dlr)
396
- s.mul_(self.beta3).add_(grad_flat[::slice_p], alpha=alpha)
454
+ s.mul_(beta3).add_(grad_flat[::slice_p], alpha=alpha)
397
455
  self.d_denom += s.abs().sum().item()
398
456
 
399
457
  del s, p0, grad_flat, p_flat, alpha
@@ -431,7 +489,6 @@ class Prodigy_adv(torch.optim.Optimizer):
431
489
  for i, p in enumerate(group['params']):
432
490
  self.step_parameter(p, group, i)
433
491
 
434
-
435
492
  self.calculate_d()
436
493
  self.init_step()
437
494
  return loss
@@ -460,6 +517,8 @@ class Prodigy_adv(torch.optim.Optimizer):
460
517
  d_hat = self.d
461
518
  if global_d_denom > 0:
462
519
  d_hat = d_coef * global_d_numerator / global_d_denom
520
+ if g_group['d_limiter']:
521
+ d_hat = min(self.d * (2 ** 0.25), d_hat)
463
522
  if self.d == g_group['d0']:
464
523
  self.d = max(self.d, d_hat)
465
524
  d_max = max(d_max, d_hat)
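To see what the new clamp buys, the loop below traces `d` across a few steps with the limiter active; it is a simplified sketch that ignores `d_coef`, `growth_rate`, and the `d == d0` special case.

```python
# With d_limiter enabled, d_hat is capped at d * 2**0.25 (~1.19x growth per step),
# so d can at most double over any four consecutive steps no matter how large the
# raw Prodigy estimate is.
d = 1e-6
for raw_d_hat in (5e-6, 5e-6, 5e-6, 5e-6):   # hypothetical volatile estimates
    d_hat = min(d * (2 ** 0.25), raw_d_hat)  # the clamp added in this release
    d = max(d, d_hat)
print(f"{d:.3e}")  # prints 2.000e-06 rather than 5.000e-06
```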
@@ -1,4 +1,5 @@
1
1
  import torch
2
+ from typing import Optional, Callable
2
3
 
3
4
  import math
4
5
 
@@ -7,6 +8,7 @@ from ..util.Effective_Shape import _get_effective_shape
7
8
  from ..util.NNMF import _nnmf,_unnmf
8
9
  from ..util.OrthoGrad import _orthogonalize_gradient
9
10
  from ..util.One_Bit_Boolean import _pack_bools, _unpack_bools
11
+ from ..util.Kourkoutas import KourkoutasHelper
10
12
 
11
13
  # A little helper from the original simplified_AdEMAMix
12
14
  def linear_hl_warmup_scheduler(step, beta_end, beta_start=0, warmup=1):
@@ -47,6 +49,28 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
47
49
  stochastic_rounding (bool): whether to use stochastic
48
50
  rounding for BF16 parameter updates (default: True).
49
51
  orthogonal_gradient (bool): whether to use OrthoGrad. (default: False)
52
+ kourkoutas_beta (bool): whether to enable the layer-wise dynamic β₂ logic.
53
+ If `False`, the optimizer behaves as standard Simplified_AdEMAMix. (default: False)
54
+ beta2_min (float): The minimum value for dynamic β₂, used during periods of
55
+ high gradient variance ("sunspikes"). Must be less than `betas[1]`.
56
+ (default: 0.9)
57
+ ema_alpha (float): The decay rate for the Exponential Moving Average (EMA) of
58
+ the pooled gradient norms. Corresponds to `α` in the paper.
59
+ (default: 0.95)
60
+ tiny_spike (float): A small constant added to the denominator of the
61
+ "sunspike" ratio calculation to prevent division by zero. Corresponds
62
+ to `ε_spike` in the paper. (default: 1e-9)
63
+ k_warmup_steps (int): The number of initial steps during which β₂ is held
64
+ at its maximum value (`betas[1]`) before the
65
+ dynamic logic activates. (default: 0)
66
+ k_logging (int): if > 0 and kourkoutas_beta=True, enables periodic console
67
+ logging of Kourkoutas-β statistics (min, max, mean of `β₂` across layers)
68
+ every `k_logging` steps. Useful for debugging and tuning. Set to 0 to disable
69
+ logging (default: 0).
70
+ layer_key_fn (Optional[Callable]): A function that takes a parameter `p`
71
+ and returns a unique, hashable key representing its "layer" or "bucket".
72
+ If `None`, parameters are bucketed by their memory ID (tensor-wise).
73
+ (default: None)
50
74
  nnmf_factor (bool): whether to use the factorization or disable it to use
51
75
  the uncompressed optimizer. (default: False)
52
76
  """
@@ -65,6 +89,13 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
65
89
  vector_reshape: bool = True,
66
90
  stochastic_rounding: bool = True,
67
91
  orthogonal_gradient: bool = False,
92
+ kourkoutas_beta: bool = False,
93
+ beta2_min: float = 0.9,
94
+ ema_alpha: float = 0.95,
95
+ tiny_spike: float = 1e-9,
96
+ k_warmup_steps: int = 0,
97
+ k_logging: int = 0,
98
+ layer_key_fn: Optional[Callable] = None,
68
99
  nnmf_factor: bool = False,
69
100
  ):
70
101
  if not (lr >= 0.0):
@@ -77,17 +108,25 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
77
108
  raise ValueError(f"Weight-decay should be >= 0.0. Got {weight_decay}")
78
109
  if not 0.0 <= alpha_grad:
79
110
  raise ValueError("Invalid alpha value: {}".format(alpha_grad))
111
+ if kourkoutas_beta and not (betas[1] > beta2_min): raise ValueError(f"For Kourkoutas-β, betas[1] (as beta2_max) must be > beta2_min. Got {betas[1]} and {beta2_min}")
80
112
 
81
113
  defaults = {
82
114
  "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
83
115
  "alpha_grad": alpha_grad, "beta1_warmup": beta1_warmup, "min_beta1": min_beta1,
84
116
  "vector_reshape": vector_reshape,
85
117
  "orthogonal_gradient": orthogonal_gradient, "use_bias_correction": use_bias_correction,
118
+ "kourkoutas_beta": kourkoutas_beta, "beta2_min": beta2_min, "ema_alpha": ema_alpha,
119
+ "tiny_spike": tiny_spike, "k_warmup_steps": k_warmup_steps, "k_logging": k_logging,
86
120
  }
87
121
  self.stochastic_rounding = stochastic_rounding
88
122
  self.factored = nnmf_factor
123
+ self.kourkoutas_beta = kourkoutas_beta
124
+ self.layer_key_fn = layer_key_fn
89
125
  super().__init__(params, defaults)
90
126
 
127
+ if self.kourkoutas_beta:
128
+ self.kourkoutas_helper = KourkoutasHelper(self)
129
+
91
130
  @property
92
131
  def supports_fused_back_pass(self):
93
132
  return True
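As with `Prodigy_adv`, the helper is only constructed when `kourkoutas_beta` is requested. A minimal, illustrative construction — the model and learning rate are placeholders; keyword names follow the new signature:

```python
import torch.nn as nn
from adv_optm import Simplified_AdEMAMix  # assumed export path

model = nn.Linear(32, 32)
opt = Simplified_AdEMAMix(
    model.parameters(),
    lr=1e-3,                       # placeholder learning rate
    kourkoutas_beta=True,          # turn on layer-wise dynamic beta2
    beta2_min=0.9,
    k_warmup_steps=50,             # hold beta2 at betas[1] for the first 50 steps
    layer_key_fn=lambda p: id(p),  # explicit tensor-wise bucketing (the default)
)
```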
@@ -113,7 +152,7 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
113
152
  state = self.state[p]
114
153
 
115
154
  # State Initialization
116
- if len(state) == 0:
155
+ if 'step' not in state:
117
156
  state['step'] = 0
118
157
 
119
158
  should_factor = (
@@ -150,6 +189,16 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
150
189
  state['den_sum'] = 1.0
151
190
 
152
191
  beta1_final, beta2 = group["betas"]
192
+
193
+ current_step = state['step']
194
+ if group['kourkoutas_beta']:
195
+ # Call prepare_step() once at the beginning of the step for all params
196
+ self.kourkoutas_helper.maybe_prepare_step(current_step)
197
+ # Accumulate current grad's norm for the *next* step
198
+ self.kourkoutas_helper.accumulate_gradient_sq_norm(p, grad)
199
+ # Get the dynamic beta2 calculated in prepare_step()
200
+ beta2 = self.kourkoutas_helper.get_beta2(p, group, current_step)
201
+
153
202
  beta1_warmup = group["beta1_warmup"]
154
203
  alpha_grad = group["alpha_grad"]
155
204
 
@@ -161,7 +210,10 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
161
210
 
162
211
  if group['use_bias_correction']:
163
212
  state['num_sum'] = beta1 * state['num_sum'] + 1.0
164
- state['den_sum'] = beta2 * state['den_sum'] + (1.0 - beta2)
213
+ if group['kourkoutas_beta']:
214
+ state['den_sum'] = group['betas'][1] * state['den_sum'] + (1.0 - group['betas'][1])
215
+ else:
216
+ state['den_sum'] = beta2 * state['den_sum'] + (1.0 - beta2)
165
217
 
166
218
  if state['factored']:
167
219
  d1, d2 = state['effective_shape']
@@ -0,0 +1,171 @@
1
+ import torch
2
+ from torch.optim import Optimizer
3
+ from typing import Callable
4
+
5
+ class KourkoutasHelper:
6
+ """
7
+ A helper class to add layer-wise Kourkoutas-β functionality to a PyTorch optimizer.
8
+ """
9
+ def __init__(self, optimizer: Optimizer):
10
+ # We need a reference to the optimizer to access its param_groups and state
11
+ if not hasattr(optimizer, 'param_groups'):
12
+ raise TypeError("optimizer must be a valid torch.optim.Optimizer instance.")
13
+ self.optimizer = optimizer
14
+ self.layer_state = {}
15
+
16
+ self.layer_info = {}
17
+ self._layer_info_built = False
18
+ self._current_step_prepared = -1
19
+
20
+ # Store stats for external logging (e.g., TensorBoard)
21
+ self.last_beta2_stats = {}
22
+
23
+ # This ensures the map is complete before the first backward pass,
24
+ # making it compatible with fused back pass mechanisms.
25
+ self._build_layer_info_if_needed()
26
+
27
+ if self.optimizer.param_groups[0].get('k_logging', 0) > 0:
28
+ self.print_layer_info()
29
+
30
+ def _build_layer_info_if_needed(self):
31
+ """Builds a map of layers and the parameters they contain."""
32
+ if self._layer_info_built:
33
+ return
34
+
35
+ if not hasattr(self.optimizer, 'layer_key_fn') or self.optimizer.layer_key_fn is None:
36
+ print("Warning: KourkoutasHelper requires 'layer_key_fn' on the optimizer. Defaulting to tensor-wise (id).")
37
+ self.optimizer.layer_key_fn = lambda p: id(p)
38
+
39
+ for group in self.optimizer.param_groups:
40
+ for p in group['params']:
41
+ # The mapping is static and should not depend on the presence of a gradient.
42
+ layer_key = self.optimizer.layer_key_fn(p)
43
+ if layer_key not in self.layer_info:
44
+ self.layer_info[layer_key] = {'params': [], 'group_ref': group}
45
+ self.layer_info[layer_key]['params'].append(p)
46
+
47
+ k_logging_interval = self.optimizer.param_groups[0].get('k_logging', 0)
48
+ if k_logging_interval > 0:
49
+ print(f"[Kourkoutas-β Debug] Layer info built. Found {len(self.layer_info)} unique layers/buckets.")
50
+
51
+ self._layer_info_built = True
52
+
53
+ def print_layer_info(self):
54
+ """Prints the contents of self.layer_info for debugging."""
55
+ print("\n--- BEGIN self.layer_info DUMP ---")
56
+ if not self.layer_info:
57
+ print("Layer info is empty. Make sure the optimizer has parameters.")
58
+ return
59
+
60
+ for layer_key, info in self.layer_info.items():
61
+ param_count = len(info['params'])
62
+ first_param_details = ""
63
+ if param_count > 0:
64
+ p = info['params'][0]
65
+ first_param_details = f" (Example param shape: {list(p.shape)}, dtype: {p.dtype})"
66
+
67
+ print(f"Key: {layer_key}, Params: {param_count}{first_param_details}")
68
+
69
+ print("--- END self.layer_info DUMP ---\n")
70
+
71
+ def prepare_step(self, current_step: int):
72
+ """
73
+ Calculates dynamic beta2 for all layers using the completed scalar accumulators
74
+ from the PREVIOUS step. Should be called once at the start of an optimizer step.
75
+ """
76
+
77
+ beta2_log = []
78
+ first_layer_key = next(iter(self.layer_info), None)
79
+ # These are just for the sample log, initialize them
80
+ sun, pooled_grad_norm, prev_r_ema_val, r_ema_tensor = (torch.tensor(0.0),)*4
81
+
82
+ for layer_key, info in self.layer_info.items():
83
+ params, group = info['params'], info['group_ref']
84
+
85
+ first_param_in_layer = info['params'][0]
86
+ param_state = self.optimizer.state[first_param_in_layer]
87
+
88
+ if layer_key not in self.layer_state:
89
+ self.layer_state[layer_key] = {
90
+ 'sum_sq_accumulator': torch.tensor(0.0, device=first_param_in_layer.device, dtype=torch.float32)
91
+ }
92
+
93
+ if 'kourkoutas_r_ema' not in param_state:
94
+ param_state['kourkoutas_r_ema'] = torch.tensor(0.0, device=first_param_in_layer.device, dtype=torch.float32)
95
+
96
+ r_ema_tensor = param_state['kourkoutas_r_ema']
97
+ accumulator = self.layer_state[layer_key]['sum_sq_accumulator']
98
+
99
+ pooled_grad_norm = torch.sqrt(accumulator)
100
+ prev_r_ema_val = r_ema_tensor.item() # for logging
101
+
102
+ # Update the persistent EMA tensor in-place.
103
+ r_ema_tensor.mul_(group['ema_alpha']).add_(pooled_grad_norm, alpha=1.0 - group['ema_alpha'])
104
+
105
+ beta2_max = group['betas'][1]
106
+ sun = torch.tensor(0.0, device=r_ema_tensor.device) # Default sun to 0 for warmup
107
+
108
+ if current_step < group['k_warmup_steps']:
109
+ beta2 = beta2_max
110
+ else:
111
+ raw = pooled_grad_norm / (r_ema_tensor + group['tiny_spike'])
112
+ sun = raw / (1.0 + raw)
113
+ beta2 = beta2_max - (beta2_max - group['beta2_min']) * sun
114
+
115
+ # Store the final calculated beta2 in the helper's transient state for this step.
116
+ self.layer_state[layer_key]['dynamic_beta2'] = beta2.item() if isinstance(beta2, torch.Tensor) else beta2
117
+
118
+ # Reset the accumulator for the next optimizer step.
119
+ accumulator.zero_()
120
+
121
+ beta2_log.append(self.layer_state[layer_key]['dynamic_beta2'])
122
+
123
+ # Always compute stats for TensorBoard
124
+ if beta2_log:
125
+ beta2_tensor = torch.tensor(beta2_log, device='cpu')
126
+ self.last_beta2_stats = {
127
+ 'min': beta2_tensor.min().item(),
128
+ 'max': beta2_tensor.max().item(),
129
+ 'mean': beta2_tensor.mean().item(),
130
+ }
131
+
132
+ # Handle periodic console logging
133
+ k_logging_interval = self.optimizer.param_groups[0].get('k_logging', 0)
134
+ is_logging_step = k_logging_interval > 0 and (current_step + 1) % k_logging_interval == 0
135
+ if is_logging_step and self.last_beta2_stats:
136
+ if first_layer_key:
137
+ print(f"\n[Kourkoutas-β Debug] Step {current_step + 1} - Sample Layer '{first_layer_key}':")
138
+ print(f" - Grad Norm: {pooled_grad_norm.item():.4e}, Prev EMA: {prev_r_ema_val:.4e}, New EMA: {r_ema_tensor.item():.4e}")
139
+ print(f" - Sunspike: {sun.item():.4f}, Dynamic Beta2: {self.layer_state[first_layer_key]['dynamic_beta2']:.4f}")
140
+ print(f"[Kourkoutas-β Debug] Step {current_step + 1} Overall Beta2 Stats: Min={self.last_beta2_stats['min']:.4f}, Max={self.last_beta2_stats['max']:.4f}, Mean={self.last_beta2_stats['mean']:.4f}")
141
+
142
+ def maybe_prepare_step(self, current_step: int):
143
+ """
144
+ A universal guard that calls prepare_step() exactly once per training step.
145
+ """
146
+ if self._current_step_prepared < current_step:
147
+ self.prepare_step(current_step)
148
+ self._current_step_prepared = current_step
149
+
150
+ def accumulate_gradient_sq_norm(self, p: torch.Tensor, grad: torch.Tensor):
151
+ """
152
+ Accumulates the squared L2 norm of a single gradient for the next step's calculation.
153
+ """
154
+ layer_key = self.optimizer.layer_key_fn(p)
155
+
156
+ if layer_key in self.layer_info:
157
+ # Initialize the transient state for this layer if it's the first time in the step.
158
+ if layer_key not in self.layer_state:
159
+ self.layer_state[layer_key] = {
160
+ 'sum_sq_accumulator': torch.tensor(0.0, device=p.device, dtype=torch.float32)
161
+ }
162
+ # Accumulate for the *next* step's prepare_step call
163
+ self.layer_state[layer_key]['sum_sq_accumulator'] += torch.sum(grad.detach().pow(2)).float()
164
+
165
+ def get_beta2(self, p: torch.Tensor, group: dict, current_step: int) -> float:
166
+ """
167
+ Gets the appropriate beta2 for the current parameter, handling warmup and dynamic value fetching.
168
+ """
169
+ layer_key = self.optimizer.layer_key_fn(p)
170
+ # The default is the max value, which is correct for unmapped params or edge cases
171
+ return self.layer_state.get(layer_key, {}).get('dynamic_beta2', group['betas'][1])
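For external monitoring, `last_beta2_stats` can be read after each optimizer step; it is refreshed once per step by `prepare_step()` even when `step_parameter()` runs per tensor in a fused back pass. The TensorBoard wiring and the `log_kourkoutas` function name below are illustrative; `kourkoutas_helper` and `last_beta2_stats` are the attributes defined above.

```python
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter()

def log_kourkoutas(optimizer, global_step):
    # Skip silently if Kourkoutas-β is disabled or no step has run yet.
    helper = getattr(optimizer, 'kourkoutas_helper', None)
    if helper is None or not helper.last_beta2_stats:
        return
    for name, value in helper.last_beta2_stats.items():  # 'min', 'max', 'mean'
        writer.add_scalar(f"kourkoutas_beta2/{name}", value, global_step)
```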
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: adv_optm
3
- Version: 1.0.5
3
+ Version: 1.1.0
4
4
  Summary: A family of highly efficient, lightweight yet powerful optimizers.
5
5
  Home-page: https://github.com/Koratahiu/Advanced_Optimizers
6
6
  Author: Koratahiu
@@ -0,0 +1,20 @@
1
+ adv_optm/__init__.py,sha256=lNp6_DdCvw-0zok2UdMkaEyVLZIlMRSKgBp-hJ15Hao,306
2
+ adv_optm/optim/AdamW_adv.py,sha256=ddEUVOif1gfZPgEJNrEGZ2wnha4MPMWw5ppPd8acQ3o,17457
3
+ adv_optm/optim/Adopt_adv.py,sha256=fhH3hS9K6z5Blxc7NFfzpCrUGbl9EQnwLPmKDxBC1zg,21415
4
+ adv_optm/optim/Lion_Prodigy_adv.py,sha256=aJ9orEEw0QYbrDzn1be0SHvOBlIkLwWG9RpWFuNMskM,13163
5
+ adv_optm/optim/Lion_adv.py,sha256=aGNAplZlyXYgVllYcV_s4bK8iC4fv6EizFoWIMNLdBc,8299
6
+ adv_optm/optim/Prodigy_adv.py,sha256=4O7BLGhqLW46Ff3UN9JfrktHonCYDy3ojHUfW8jtaDs,25940
7
+ adv_optm/optim/Simplified_AdEMAMix.py,sha256=gPjMhKulzmAeO42foe-d7xW0AcB50vKFYsvHgxbD3uc,12949
8
+ adv_optm/optim/__init__.py,sha256=pcP865H2j1tut2VfTUhzQh7V8TF_tzPjqFnjMfFed2k,382
9
+ adv_optm/util/BF16_Stochastic_Rounding.py,sha256=Q5H0BcogmE4atP65dLoI21HKSf50lRdsBDfeF6v9Tbg,1548
10
+ adv_optm/util/Effective_Shape.py,sha256=TBvIk1V8IuTbbBsxuekJA4e_v8JlR5Nujtut8RTWAm4,318
11
+ adv_optm/util/Kourkoutas.py,sha256=DCsIcZ1sEeSwthN8KZH7OTKoIZJ3ah4t5DNiqxsSuCk,8344
12
+ adv_optm/util/NNMF.py,sha256=yRf5IP5Sjq0Uf0DxN0Q8NxEGSdD-f1ULziLVDOjY8K4,639
13
+ adv_optm/util/One_Bit_Boolean.py,sha256=Wat49esdwohuN-OHOFMW8D0aOQgV9cP5Rl8z6yfmpos,1068
14
+ adv_optm/util/OrthoGrad.py,sha256=NzInuBQGy_Ja__M1R9XbvqVaQ0fhGbtGgFE9YON7B3I,707
15
+ adv_optm/util/__init__.py,sha256=qoyIF0jcLjs_vSEcsv36clw5LFNBEbifyXrrVxMH-G4,349
16
+ adv_optm-1.1.0.dist-info/licenses/LICENSE,sha256=HrhfyXIkWY2tGFK11kg7vPCqhgh5DcxleloqdhrpyMY,11558
17
+ adv_optm-1.1.0.dist-info/METADATA,sha256=dwRwKQykba-7TP6a94qpOg6xz450QESAi5E8AnEV-iM,8422
18
+ adv_optm-1.1.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
19
+ adv_optm-1.1.0.dist-info/top_level.txt,sha256=iNfBIIzu-lPrQ7jyC56WBCcbkRwitM2nJ15-MRQ_6fg,9
20
+ adv_optm-1.1.0.dist-info/RECORD,,
@@ -1,19 +0,0 @@
1
- adv_optm/__init__.py,sha256=9sM4fP1pj9divFhLVUzHbBWe50H82H3FYIGVIHTHpkg,306
2
- adv_optm/optim/AdamW_adv.py,sha256=aTuYcJgd_EcZOrs6TDgBrBKw3wtU5LPzE5WvTBDDeEo,14317
3
- adv_optm/optim/Adopt_adv.py,sha256=KdEVSl2w1gRXFtz2fwCVT4i9inTspp-PQq3mobpa-9A,17476
4
- adv_optm/optim/Lion_Prodigy_adv.py,sha256=sGzhts9a6gHfCkuHTB5L9IrClo4c6UThzYYErBwqOaA,12844
5
- adv_optm/optim/Lion_adv.py,sha256=6G1CukJB_pC7l9HwFEuY1ydsNHZFabVmOvcHDsHHVuQ,8295
6
- adv_optm/optim/Prodigy_adv.py,sha256=8XUpu19BaBmHb-R9K3jgwySDbtVaLU1_Drtttc_zITs,22461
7
- adv_optm/optim/Simplified_AdEMAMix.py,sha256=tb3d6Cw_nGwcTzYUhDnKqyP7GzjD1hn8k4WqGG5lhmw,9813
8
- adv_optm/optim/__init__.py,sha256=pcP865H2j1tut2VfTUhzQh7V8TF_tzPjqFnjMfFed2k,382
9
- adv_optm/util/BF16_Stochastic_Rounding.py,sha256=Q5H0BcogmE4atP65dLoI21HKSf50lRdsBDfeF6v9Tbg,1548
10
- adv_optm/util/Effective_Shape.py,sha256=TBvIk1V8IuTbbBsxuekJA4e_v8JlR5Nujtut8RTWAm4,318
11
- adv_optm/util/NNMF.py,sha256=yRf5IP5Sjq0Uf0DxN0Q8NxEGSdD-f1ULziLVDOjY8K4,639
12
- adv_optm/util/One_Bit_Boolean.py,sha256=Wat49esdwohuN-OHOFMW8D0aOQgV9cP5Rl8z6yfmpos,1068
13
- adv_optm/util/OrthoGrad.py,sha256=NzInuBQGy_Ja__M1R9XbvqVaQ0fhGbtGgFE9YON7B3I,707
14
- adv_optm/util/__init__.py,sha256=qoyIF0jcLjs_vSEcsv36clw5LFNBEbifyXrrVxMH-G4,349
15
- adv_optm-1.0.5.dist-info/licenses/LICENSE,sha256=HrhfyXIkWY2tGFK11kg7vPCqhgh5DcxleloqdhrpyMY,11558
16
- adv_optm-1.0.5.dist-info/METADATA,sha256=ouxI4cwBQ2IPuOjrkA478XwSetGP6ku51vW1QxHIGcY,8422
17
- adv_optm-1.0.5.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
18
- adv_optm-1.0.5.dist-info/top_level.txt,sha256=iNfBIIzu-lPrQ7jyC56WBCcbkRwitM2nJ15-MRQ_6fg,9
19
- adv_optm-1.0.5.dist-info/RECORD,,