adv-optm 1.0.6.tar.gz → 1.1.0.dev2.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of adv-optm might be problematic.

Files changed (25)
  1. {adv_optm-1.0.6 → adv_optm-1.1.0.dev2}/PKG-INFO +1 -1
  2. {adv_optm-1.0.6 → adv_optm-1.1.0.dev2}/adv_optm/__init__.py +1 -1
  3. {adv_optm-1.0.6 → adv_optm-1.1.0.dev2}/adv_optm/optim/AdamW_adv.py +59 -7
  4. {adv_optm-1.0.6 → adv_optm-1.1.0.dev2}/adv_optm/optim/Adopt_adv.py +58 -7
  5. {adv_optm-1.0.6 → adv_optm-1.1.0.dev2}/adv_optm/optim/Prodigy_adv.py +62 -12
  6. {adv_optm-1.0.6 → adv_optm-1.1.0.dev2}/adv_optm/optim/Simplified_AdEMAMix.py +53 -1
  7. adv_optm-1.1.0.dev2/adv_optm/util/Kourkoutas.py +134 -0
  8. {adv_optm-1.0.6 → adv_optm-1.1.0.dev2}/adv_optm.egg-info/PKG-INFO +1 -1
  9. {adv_optm-1.0.6 → adv_optm-1.1.0.dev2}/adv_optm.egg-info/SOURCES.txt +1 -0
  10. {adv_optm-1.0.6 → adv_optm-1.1.0.dev2}/setup.py +1 -1
  11. {adv_optm-1.0.6 → adv_optm-1.1.0.dev2}/LICENSE +0 -0
  12. {adv_optm-1.0.6 → adv_optm-1.1.0.dev2}/README.md +0 -0
  13. {adv_optm-1.0.6 → adv_optm-1.1.0.dev2}/adv_optm/optim/Lion_Prodigy_adv.py +0 -0
  14. {adv_optm-1.0.6 → adv_optm-1.1.0.dev2}/adv_optm/optim/Lion_adv.py +0 -0
  15. {adv_optm-1.0.6 → adv_optm-1.1.0.dev2}/adv_optm/optim/__init__.py +0 -0
  16. {adv_optm-1.0.6 → adv_optm-1.1.0.dev2}/adv_optm/util/BF16_Stochastic_Rounding.py +0 -0
  17. {adv_optm-1.0.6 → adv_optm-1.1.0.dev2}/adv_optm/util/Effective_Shape.py +0 -0
  18. {adv_optm-1.0.6 → adv_optm-1.1.0.dev2}/adv_optm/util/NNMF.py +0 -0
  19. {adv_optm-1.0.6 → adv_optm-1.1.0.dev2}/adv_optm/util/One_Bit_Boolean.py +0 -0
  20. {adv_optm-1.0.6 → adv_optm-1.1.0.dev2}/adv_optm/util/OrthoGrad.py +0 -0
  21. {adv_optm-1.0.6 → adv_optm-1.1.0.dev2}/adv_optm/util/__init__.py +0 -0
  22. {adv_optm-1.0.6 → adv_optm-1.1.0.dev2}/adv_optm.egg-info/dependency_links.txt +0 -0
  23. {adv_optm-1.0.6 → adv_optm-1.1.0.dev2}/adv_optm.egg-info/requires.txt +0 -0
  24. {adv_optm-1.0.6 → adv_optm-1.1.0.dev2}/adv_optm.egg-info/top_level.txt +0 -0
  25. {adv_optm-1.0.6 → adv_optm-1.1.0.dev2}/setup.cfg +0 -0
{adv_optm-1.0.6 → adv_optm-1.1.0.dev2}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: adv_optm
- Version: 1.0.6
+ Version: 1.1.0.dev2
  Summary: A family of highly efficient, lightweight yet powerful optimizers.
  Home-page: https://github.com/Koratahiu/Advanced_Optimizers
  Author: Koratahiu
{adv_optm-1.0.6 → adv_optm-1.1.0.dev2}/adv_optm/__init__.py
@@ -16,4 +16,4 @@ __all__ = [
  "Lion_Prodigy_adv",
  ]

- __version__ = "1.0.6"
+ __version__ = "1.1.0.dev2"
{adv_optm-1.0.6 → adv_optm-1.1.0.dev2}/adv_optm/optim/AdamW_adv.py
@@ -1,11 +1,12 @@
  import torch
- from typing import Optional
+ from typing import Optional, Callable

  from ..util.BF16_Stochastic_Rounding import add_stochastic_
  from ..util.Effective_Shape import _get_effective_shape
  from ..util.NNMF import _nnmf,_unnmf
  from ..util.OrthoGrad import _orthogonalize_gradient
  from ..util.One_Bit_Boolean import _pack_bools, _unpack_bools
+ from ..util.Kourkoutas import KourkoutasHelper

  class AdamW_adv(torch.optim.Optimizer):
  """
@@ -54,6 +55,28 @@ class AdamW_adv(torch.optim.Optimizer):
  as it gradually introduces the stabilizing slow momentum term. During
  the warmup, `alpha` ramps from 0 to its target value. If `None`,
  the scheduler is disabled. (default: None)
+ kourkoutas_beta (bool): whether to enable the layer-wise dynamic β₂ logic.
+ If `False`, the optimizer behaves as standard AdamW. (default: False)
+ beta2_min (float): The minimum value for dynamic β₂, used during periods of
+ high gradient variance ("sunspikes"). Must be less than `betas[1]`.
+ (default: 0.88)
+ ema_alpha (float): The decay rate for the Exponential Moving Average (EMA) of
+ the pooled gradient norms. Corresponds to `α` in the paper.
+ (default: 0.93)
+ tiny_spike (float): A small constant added to the denominator of the
+ "sunspike" ratio calculation to prevent division by zero. Corresponds
+ to `ε_spike` in the paper. (default: 1e-9)
+ k_warmup_steps (int): The number of initial steps during which β₂ is held
+ at a fixed average value (`(beta2_min + beta2_max) / 2`) before the
+ dynamic logic activates. (default: 0)
+ k_logging (int): if > 0 and kourkoutas_beta=True, enables periodic console
+ logging of Kourkoutas-β statistics (min, max, mean of `β₂` across layers)
+ every logging steps. Useful for debugging and tuning. Set to 0 to disable
+ logging (default: 0).
+ layer_key_fn (Optional[Callable]): A function that takes a parameter `p`
+ and returns a unique, hashable key representing its "layer" or "bucket".
+ If `None`, parameters are bucketed by their memory ID (tensor-wise).
+ (default: None)
  nnmf_factor (bool): whether to use the factorization or disable it to use
  the uncompressed optimizer. (default: False)
  """
@@ -76,6 +99,13 @@ class AdamW_adv(torch.optim.Optimizer):
  beta3_ema: float = 0.9999,
  alpha: float = 5.0,
  t_alpha: int | None = None,
+ kourkoutas_beta: bool = False,
+ beta2_min: float = 0.88,
+ ema_alpha: float = 0.93,
+ tiny_spike: float = 1e-9,
+ k_warmup_steps: int = 0,
+ k_logging: int = 0,
+ layer_key_fn: Optional[Callable] = None,
  nnmf_factor: bool = False,
  ):
  if not (lr >= 0.0):
@@ -86,6 +116,8 @@ class AdamW_adv(torch.optim.Optimizer):
  raise ValueError(f"Epsilon should be >= 0.0. Got {eps}")
  if not (weight_decay >= 0.0):
  raise ValueError(f"Weight-decay should be >= 0.0. Got {weight_decay}")
+ if kourkoutas_beta and not (betas[1] > beta2_min): raise ValueError(f"For Kourkoutas-β, betas[1] (as beta2_max) must be > beta2_min. Got {betas[1]} and {beta2_min}")
+
  if cautious_mask and grams_moment:
  print("Warning: cautious is incompatible with grams, Disabling cautious.")
  cautious_mask = False
@@ -95,14 +127,21 @@ class AdamW_adv(torch.optim.Optimizer):
  "vector_reshape": vector_reshape, "use_atan2": use_atan2,
  "orthogonal_gradient": orthogonal_gradient, "use_bias_correction": use_bias_correction,
  "beta3_ema": beta3_ema, "alpha": alpha, "t_alpha": t_alpha,
+ "kourkoutas_beta": kourkoutas_beta, "beta2_min": beta2_min, "ema_alpha": ema_alpha,
+ "tiny_spike": tiny_spike, "k_warmup_steps": k_warmup_steps, "k_logging": k_logging,
  }
  self.stochastic_rounding = stochastic_rounding
  self.cautious_mask = cautious_mask
  self.grams_moment = grams_moment
  self.use_AdEMAMix = use_AdEMAMix
  self.factored = nnmf_factor
+ self.kourkoutas_beta = kourkoutas_beta
+ self.layer_key_fn = layer_key_fn
  super().__init__(params, defaults)

+ if self.kourkoutas_beta:
+ self.kourkoutas_helper = KourkoutasHelper(self)
+
  @property
  def supports_fused_back_pass(self):
  return True
@@ -127,8 +166,6 @@ class AdamW_adv(torch.optim.Optimizer):
  grad = _orthogonalize_gradient(p, grad)
  state = self.state[p]

- beta1, beta2 = group['betas']
-
  # State Initialization
  if len(state) == 0:
  state['step'] = 0
@@ -148,7 +185,7 @@ class AdamW_adv(torch.optim.Optimizer):
  d1, d2 = state['effective_shape']

  # First moment (m)
- if beta1 > 0:
+ if group['betas'][0] > 0:
  state['mu_m_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
  state['mv_m_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
  if not self.grams_moment:
@@ -163,16 +200,31 @@ class AdamW_adv(torch.optim.Optimizer):
  state['mu_v_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
  state['mv_v_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
  else: # Fallback to standard AdamW for non-factored tensors
- if beta1 > 0:
+ if group['betas'][0] > 0:
  state['exp_avg'] = torch.zeros_like(p, device=device, dtype=dtype)
  if self.use_AdEMAMix:
  state['exp_avg_slow'] = torch.zeros_like(p, device=device, dtype=dtype)
  state['exp_avg_sq'] = torch.zeros_like(p, device=device, dtype=dtype)

+ beta1, beta2 = group['betas']
+
+ current_step = state['step']
+ if group['kourkoutas_beta']:
+ # Call prepare_step() once at the beginning of the step for all params
+ self.kourkoutas_helper.maybe_prepare_step(current_step)
+ # Accumulate current grad's norm for the *next* step
+ self.kourkoutas_helper.accumulate_gradient_sq_norm(p, grad)
+ # Get the dynamic beta2 calculated in prepare_step()
+ beta2 = self.kourkoutas_helper.get_beta2(p, group, current_step)
+
  step = state['step'] + 1
  if group['use_bias_correction']:
  bias_correction1 = 1.0 - beta1 ** step
- bias_correction2 = 1.0 - beta2 ** step
+ if group['kourkoutas_beta']:
+ bias_correction2 = 1.0 - group['betas'][1] ** step
+ # Use beta2_max for bias correction
+ else:
+ bias_correction2 = 1.0 - beta2 ** step
  else:
  bias_correction1 = 1
  bias_correction2 = 1
@@ -315,4 +367,4 @@ class AdamW_adv(torch.optim.Optimizer):
  for i, p in enumerate(group['params']):
  self.step_parameter(p, group, i)

- return loss
+ return loss
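
Note: the AdamW_adv changes above thread the new Kourkoutas-β options through the constructor, the per-group defaults, and the per-parameter step. A minimal usage sketch follows, assuming AdamW_adv is importable from the top-level adv_optm package and that the constructor keywords match this diff; the model and training loop are placeholders, not part of the package:

    import torch
    from adv_optm import AdamW_adv  # assumed top-level export

    model = torch.nn.Linear(128, 64)  # placeholder model

    # betas[1] acts as beta2_max; beta2_min must stay strictly below it,
    # otherwise __init__ raises a ValueError (see the check added above).
    opt = AdamW_adv(
        model.parameters(),
        lr=1e-3,
        betas=(0.9, 0.999),
        kourkoutas_beta=True,   # enable layer-wise dynamic beta2
        beta2_min=0.88,
        k_warmup_steps=100,     # hold beta2 fixed for the first 100 steps
        k_logging=50,           # print Kourkoutas-β stats every 50 steps
    )

    for _ in range(3):  # toy training loop
        loss = model(torch.randn(8, 128)).pow(2).mean()
        loss.backward()
        opt.step()
        opt.zero_grad()

Without layer_key_fn, gradients are bucketed per tensor (by memory ID), as the docstring above states.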
{adv_optm-1.0.6 → adv_optm-1.1.0.dev2}/adv_optm/optim/Adopt_adv.py
@@ -6,6 +6,7 @@ from ..util.Effective_Shape import _get_effective_shape
  from ..util.NNMF import _nnmf, _unnmf
  from ..util.OrthoGrad import _orthogonalize_gradient
  from ..util.One_Bit_Boolean import _pack_bools, _unpack_bools
+ from ..util.Kourkoutas import KourkoutasHelper

  class Adopt_adv(torch.optim.Optimizer):
  """
@@ -72,6 +73,28 @@ class Adopt_adv(torch.optim.Optimizer):
  current gradient. For small batch sizes, use high values (e.g., 10-100) to be
  more responsive. For large batch sizes, use low values (e.g., 0-1) for
  stability. (default: 100.0)
+ kourkoutas_beta (bool): whether to enable the layer-wise dynamic β₂ logic.
+ If `False`, the optimizer behaves as standard Adopt. (default: False)
+ beta2_min (float): The minimum value for dynamic β₂, used during periods of
+ high gradient variance ("sunspikes"). Must be less than `betas[1]`.
+ (default: 0.88)
+ ema_alpha (float): The decay rate for the Exponential Moving Average (EMA) of
+ the pooled gradient norms. Corresponds to `α` in the paper.
+ (default: 0.93)
+ tiny_spike (float): A small constant added to the denominator of the
+ "sunspike" ratio calculation to prevent division by zero. Corresponds
+ to `ε_spike` in the paper. (default: 1e-9)
+ k_warmup_steps (int): The number of initial steps during which β₂ is held
+ at a fixed average value (`(beta2_min + beta2_max) / 2`) before the
+ dynamic logic activates. (default: 0)
+ k_logging (int): if > 0 and kourkoutas_beta=True, enables periodic console
+ logging of Kourkoutas-β statistics (min, max, mean of `β₂` across layers)
+ every logging steps. Useful for debugging and tuning. Set to 0 to disable
+ logging (default: 0).
+ layer_key_fn (Optional[Callable]): A function that takes a parameter `p`
+ and returns a unique, hashable key representing its "layer" or "bucket".
+ If `None`, parameters are bucketed by their memory ID (tensor-wise).
+ (default: None)
  nnmf_factor (bool): whether to use the factorization or disable it to use
  the uncompressed optimizer. (default: False)
  """
@@ -96,6 +119,13 @@ class Adopt_adv(torch.optim.Optimizer):
  t_alpha: int | None = None,
  Simplified_AdEMAMix: bool = False,
  alpha_grad: float = 100.0,
+ kourkoutas_beta: bool = False,
+ beta2_min: float = 0.88,
+ ema_alpha: float = 0.93,
+ tiny_spike: float = 1e-9,
+ k_warmup_steps: int = 0,
+ k_logging: int = 0,
+ layer_key_fn: Optional[Callable] = None,
  nnmf_factor: bool = False,
  ):
  if not (lr >= 0.0):
@@ -111,6 +141,7 @@ class Adopt_adv(torch.optim.Optimizer):
  cautious_mask = False
  if betas[0] == 0.0 and Simplified_AdEMAMix:
  raise ValueError(f"Beta1 cannot be 0.0 when using Simplified_AdEMAMix. Got {betas[0]}")
+ if kourkoutas_beta and not (betas[1] > beta2_min): raise ValueError(f"For Kourkoutas-β, betas[1] (as beta2_max) must be > beta2_min. Got {betas[1]} and {beta2_min}")
  if use_AdEMAMix and Simplified_AdEMAMix:
  print("Warning: use_AdEMAMix is incompatible with Simplified_AdEMAMix, Disabling use_AdEMAMix.")
  if grams_moment and Simplified_AdEMAMix:
@@ -125,6 +156,8 @@ class Adopt_adv(torch.optim.Optimizer):
  "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
  "vector_reshape": vector_reshape, "beta3_ema": beta3_ema, "alpha": alpha,
  "t_alpha": t_alpha, "alpha_grad": alpha_grad,
+ "kourkoutas_beta": kourkoutas_beta, "beta2_min": beta2_min, "ema_alpha": ema_alpha,
+ "tiny_spike": tiny_spike, "k_warmup_steps": k_warmup_steps, "k_logging": k_logging,
  }
  self.clip_lambda = clip_lambda
  self.stochastic_rounding = stochastic_rounding
@@ -135,8 +168,13 @@ class Adopt_adv(torch.optim.Optimizer):
  self.use_AdEMAMix = use_AdEMAMix and not Simplified_AdEMAMix
  self.Simplified_AdEMAMix = Simplified_AdEMAMix
  self.factored = nnmf_factor
+ self.kourkoutas_beta = kourkoutas_beta
+ self.layer_key_fn = layer_key_fn
  super().__init__(params, defaults)

+ if self.kourkoutas_beta:
+ self.kourkoutas_helper = KourkoutasHelper(self)
+
  @property
  def supports_fused_back_pass(self): return True
  @property
@@ -156,8 +194,6 @@ class Adopt_adv(torch.optim.Optimizer):
  grad = _orthogonalize_gradient(p, grad)
  state = self.state[p]

- beta1, beta2 = group['betas']
-
  # State Initialization
  if len(state) == 0:
  state['step'] = 0
@@ -176,7 +212,7 @@ class Adopt_adv(torch.optim.Optimizer):
  d1, d2 = state['effective_shape']

  # m_0 = 0
- if beta1 > 0:
+ if group['betas'][0] > 0:
  state['mu_m_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
  state['mv_m_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
  if not self.grams_moment:
@@ -195,12 +231,23 @@ class Adopt_adv(torch.optim.Optimizer):
  # Initialize v_0 using NMF
  _nnmf(vt_init, out=(state['mu_v_nmf'], state['mv_v_nmf']))
  else: # Fallback for non-factored tensors
- if beta1 > 0:
+ if group['betas'][0] > 0:
  state['exp_avg'] = torch.zeros_like(p, dtype=dtype) # m_0
  if self.use_AdEMAMix:
  state['exp_avg_slow'] = torch.zeros_like(p, dtype=dtype)
  state['exp_avg_sq'] = grad.square() # v_0

+ beta1, beta2 = group['betas']
+
+ current_step = state['step']
+ if group['kourkoutas_beta']:
+ # Call prepare_step() once at the beginning of the step for all params
+ self.kourkoutas_helper.maybe_prepare_step(current_step)
+ # Accumulate current grad's norm for the *next* step
+ self.kourkoutas_helper.accumulate_gradient_sq_norm(p, grad)
+ # Get the dynamic beta2 calculated in prepare_step()
+ beta2 = self.kourkoutas_helper.get_beta2(p, group, current_step)
+
  # The first step is for initialization only (skip when use_atan2 as it's scale invariant).
  if state['step'] == 0 and not self.use_atan2:
  state['step'] += 1
@@ -211,10 +258,10 @@ class Adopt_adv(torch.optim.Optimizer):
  alpha = group['alpha']
  t_alpha = group['t_alpha']
  # Use step+1 for 1-based step count in scheduler
- current_step = state['step'] + 1
+ alpha_step = state['step'] + 1
  alpha_t = alpha
- if t_alpha is not None and t_alpha > 0 and current_step < t_alpha:
- alpha_t = min(current_step * alpha / t_alpha, alpha)
+ if t_alpha is not None and t_alpha > 0 and alpha_step < t_alpha:
+ alpha_t = min(alpha_step * alpha / t_alpha, alpha)
  if self.Simplified_AdEMAMix:
  alpha_grad = group["alpha_grad"]

@@ -386,4 +433,8 @@ class Adopt_adv(torch.optim.Optimizer):
  for i, p in enumerate(group['params']):
  self.step_parameter(p, group, i)

+ if self.kourkoutas_beta and self.k_logging > 0 and hasattr(self, '_beta2_log'):
+ first_param_state = self.state[self.param_groups[0]['params'][0]]
+ step_num = first_param_state['step']
+
  return loss
{adv_optm-1.0.6 → adv_optm-1.1.0.dev2}/adv_optm/optim/Prodigy_adv.py
@@ -3,11 +3,14 @@ import torch.distributed as dist

  import math

+ from typing import Optional, Callable
+
  from ..util.BF16_Stochastic_Rounding import add_stochastic_
  from ..util.Effective_Shape import _get_effective_shape
  from ..util.NNMF import _nnmf,_unnmf
  from ..util.OrthoGrad import _orthogonalize_gradient
  from ..util.One_Bit_Boolean import _pack_bools, _unpack_bools
+ from ..util.Kourkoutas import KourkoutasHelper

  class Prodigy_adv(torch.optim.Optimizer):
  """
@@ -85,6 +88,28 @@ class Prodigy_adv(torch.optim.Optimizer):
  prodigy_steps (int): If greater than zero, disable Prodigy's stepsize adjustments
  after the specified optimiser step and release all state memory required by Prodigy
  (default: 0).
+ kourkoutas_beta (bool): whether to enable the layer-wise dynamic β₂ logic.
+ If `False`, the optimizer behaves as standard AdamW/Prodigy. (default: False)
+ beta2_min (float): The minimum value for dynamic β₂, used during periods of
+ high gradient variance ("sunspikes"). Must be less than `betas[1]`.
+ (default: 0.88)
+ ema_alpha (float): The decay rate for the Exponential Moving Average (EMA) of
+ the pooled gradient norms. Corresponds to `α` in the paper.
+ (default: 0.93)
+ tiny_spike (float): A small constant added to the denominator of the
+ "sunspike" ratio calculation to prevent division by zero. Corresponds
+ to `ε_spike` in the paper. (default: 1e-9)
+ k_warmup_steps (int): The number of initial steps during which β₂ is held
+ at a fixed average value (`(beta2_min + beta2_max) / 2`) before the
+ dynamic logic activates. (default: 0)
+ k_logging (int): if > 0 and kourkoutas_beta=True, enables periodic console
+ logging of Kourkoutas-β statistics (min, max, mean of `β₂` across layers)
+ every logging steps. Useful for debugging and tuning. Set to 0 to disable
+ logging (default: 0).
+ layer_key_fn (Optional[Callable]): A function that takes a parameter `p`
+ and returns a unique, hashable key representing its "layer" or "bucket".
+ If `None`, parameters are bucketed by their memory ID (tensor-wise).
+ (default: None)
  """

  def __init__(
@@ -116,6 +141,13 @@ class Prodigy_adv(torch.optim.Optimizer):
  fsdp_in_use: bool = False,
  slice_p: int = 11,
  prodigy_steps: int = 0,
+ kourkoutas_beta: bool = False,
+ beta2_min: float = 0.88,
+ ema_alpha: float = 0.93,
+ tiny_spike: float = 1e-9,
+ k_warmup_steps: int = 0,
+ k_logging: int = 0,
+ layer_key_fn: Optional[Callable] = None,
  ):
  if not (lr >= 0.0):
  raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
@@ -141,6 +173,8 @@ class Prodigy_adv(torch.optim.Optimizer):
  if use_atan2 and Simplified_AdEMAMix:
  print("Warning: use_atan2 is incompatible with Simplified_AdEMAMix. Disabling use_atan2.")
  use_atan2 = False
+ if kourkoutas_beta and not (betas[1] > beta2_min):
+ raise ValueError(f"For Kourkoutas-β, betas[1] (as beta2_max) must be > beta2_min. Got {betas[1]} and {beta2_min}")
  if Simplified_AdEMAMix and alpha_grad > 0:
  # scales d_coef by alpha_grad, this force prodigy to behave well with Simplified_AdEMAMix
  d_coef = d_coef/alpha_grad
@@ -153,7 +187,9 @@ class Prodigy_adv(torch.optim.Optimizer):
  "beta3": beta3, "d": d0, "d0": d0, "d_max": d0, "d_numerator": 0.0, "d_coef": d_coef,
  "growth_rate": growth_rate, "safeguard_warmup": safeguard_warmup, "k": 0, "slice_p": slice_p,
  "fsdp_in_use": fsdp_in_use, "prodigy_steps": prodigy_steps,
- "alpha_grad": alpha_grad,
+ "alpha_grad": alpha_grad,
+ "kourkoutas_beta": kourkoutas_beta, "beta2_min": beta2_min, "ema_alpha": ema_alpha,
+ "tiny_spike": tiny_spike, "k_warmup_steps": k_warmup_steps, "k_logging": k_logging,
  }
  self.stochastic_rounding = stochastic_rounding
  self.cautious_mask = cautious_mask and not Simplified_AdEMAMix
@@ -162,7 +198,13 @@ class Prodigy_adv(torch.optim.Optimizer):
  self.Simplified_AdEMAMix = Simplified_AdEMAMix
  self.factored = nnmf_factor
  self.fsdp_in_use = fsdp_in_use
+
+ self.kourkoutas_beta = kourkoutas_beta
+ self.layer_key_fn = layer_key_fn
+
  super().__init__(params, defaults)
+ if self.kourkoutas_beta:
+ self.kourkoutas_helper = KourkoutasHelper(self)
  self.init_step()

  @property
@@ -180,19 +222,17 @@ class Prodigy_adv(torch.optim.Optimizer):
  def init_step(self):
  """Resets accumulators and calculates dlr for the upcoming step."""
  self.d_denom = 0.0
-
+
  g_group = self.param_groups[0]
- self.beta1, self.beta2 = g_group['betas']
+ self.beta1, self.beta2_default = g_group['betas']
  self.beta3 = g_group['beta3']
  if self.beta3 is None:
- self.beta3 = math.sqrt(self.beta2)
+ self.beta3 = math.sqrt(self.beta2_default)

- k = g_group['k']
  self.d = g_group['d']
  lr = g_group['lr']

  self.dlr = self.d * lr
-
  self.d_numerator = g_group.get('d_numerator', 0.0) * self.beta3

  @torch.no_grad()
@@ -258,14 +298,25 @@ class Prodigy_adv(torch.optim.Optimizer):
  else:
  state['p0'] = torch.tensor(0, device=device, dtype=p.dtype)

+ current_step = state['step']
+ if group['kourkoutas_beta']:
+ # Call prepare_step() once at the beginning of the step for all params
+ self.kourkoutas_helper.maybe_prepare_step(current_step)
+ # Accumulate current grad's norm for the *next* step
+ self.kourkoutas_helper.accumulate_gradient_sq_norm(p, grad)
+ # Get the dynamic beta2 calculated in prepare_step()
+ beta2 = self.kourkoutas_helper.get_beta2(p, group, current_step)
+ else:
+ beta2 = self.beta2_default
+
  if self.use_AdEMAMix:
  beta3_ema = group['beta3_ema']
  alpha = group['alpha']
  t_alpha = group['t_alpha']
- current_step = state['step'] + 1
+ alpha_step = state['step'] + 1
  alpha_t = alpha
- if t_alpha is not None and t_alpha > 0 and current_step < t_alpha:
- alpha_t = min(current_step * alpha / t_alpha, alpha)
+ if t_alpha is not None and t_alpha > 0 and alpha_step < t_alpha:
+ alpha_t = min(alpha_step * alpha / t_alpha, alpha)
  if self.Simplified_AdEMAMix:
  alpha_grad = group["alpha_grad"]

@@ -295,7 +346,7 @@ class Prodigy_adv(torch.optim.Optimizer):
  del mask

  vt = _unnmf((state['mu_v_nmf'], state['mv_v_nmf']))
- vt.mul_(self.beta2).addcmul_(grad_reshaped, grad_reshaped, value=self.d * self.d * (1.0 - self.beta2))
+ vt.mul_(beta2).addcmul_(grad_reshaped, grad_reshaped, value=self.d * self.d * (1.0 - beta2))

  if self.use_AdEMAMix:
  mt_slow = _unnmf((state['mu_m_slow_nmf'], state['mv_m_slow_nmf']))
@@ -368,7 +419,7 @@ class Prodigy_adv(torch.optim.Optimizer):
  else:
  update = exp_avg.clone() if self.beta1 > 0 else grad.mul(self.d)

- exp_avg_sq.mul_(self.beta2).addcmul_(grad, grad.conj(), value=self.d * self.d * (1.0 - self.beta2))
+ exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=self.d * self.d * (1.0 - beta2))

  if group['use_atan2']:
  a = 1.2732395
@@ -431,7 +482,6 @@ class Prodigy_adv(torch.optim.Optimizer):
  for i, p in enumerate(group['params']):
  self.step_parameter(p, group, i)

-
  self.calculate_d()
  self.init_step()
  return loss
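
Note: Prodigy_adv receives the same layer_key_fn hook as the other optimizers. The docstrings only require that the callable map a parameter tensor to a hashable bucket key; how parameters are grouped is left to the caller. A sketch of one possible policy, bucketing by the parameter's parent module name (make_layer_key_fn and the name-splitting rule are illustrative, not part of the package):

    import torch

    def make_layer_key_fn(model: torch.nn.Module):
        """Return a layer_key_fn that buckets parameters by parent module name."""
        key_by_param_id = {}
        for name, param in model.named_parameters():
            # "encoder.layers.0.weight" -> "encoder.layers.0"
            key_by_param_id[id(param)] = name.rsplit(".", 1)[0] if "." in name else name
        # Unknown parameters fall back to id(p), matching the tensor-wise default.
        return lambda p: key_by_param_id.get(id(p), id(p))

    model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.Linear(16, 4))
    # opt = Prodigy_adv(model.parameters(), kourkoutas_beta=True,
    #                   layer_key_fn=make_layer_key_fn(model))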
{adv_optm-1.0.6 → adv_optm-1.1.0.dev2}/adv_optm/optim/Simplified_AdEMAMix.py
@@ -1,4 +1,5 @@
  import torch
+ from typing import Optional, Callable

  import math

@@ -7,6 +8,7 @@ from ..util.Effective_Shape import _get_effective_shape
  from ..util.NNMF import _nnmf,_unnmf
  from ..util.OrthoGrad import _orthogonalize_gradient
  from ..util.One_Bit_Boolean import _pack_bools, _unpack_bools
+ from ..util.Kourkoutas import KourkoutasHelper

  # A little helper from the original simplified_AdEMAMix
  def linear_hl_warmup_scheduler(step, beta_end, beta_start=0, warmup=1):
@@ -47,6 +49,28 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
  stochastic_rounding (bool): whether to use stochastic
  rounding for BF16 parameter updates (default: True).
  orthogonal_gradient (bool): whether to use OrthoGrad. (default: False)
+ kourkoutas_beta (bool): whether to enable the layer-wise dynamic β₂ logic.
+ If `False`, the optimizer behaves as standard Simplified_AdEMAMix. (default: False)
+ beta2_min (float): The minimum value for dynamic β₂, used during periods of
+ high gradient variance ("sunspikes"). Must be less than `betas[1]`.
+ (default: 0.88)
+ ema_alpha (float): The decay rate for the Exponential Moving Average (EMA) of
+ the pooled gradient norms. Corresponds to `α` in the paper.
+ (default: 0.93)
+ tiny_spike (float): A small constant added to the denominator of the
+ "sunspike" ratio calculation to prevent division by zero. Corresponds
+ to `ε_spike` in the paper. (default: 1e-9)
+ k_warmup_steps (int): The number of initial steps during which β₂ is held
+ at a fixed average value (`(beta2_min + beta2_max) / 2`) before the
+ dynamic logic activates. (default: 0)
+ k_logging (int): if > 0 and kourkoutas_beta=True, enables periodic console
+ logging of Kourkoutas-β statistics (min, max, mean of `β₂` across layers)
+ every logging steps. Useful for debugging and tuning. Set to 0 to disable
+ logging (default: 0).
+ layer_key_fn (Optional[Callable]): A function that takes a parameter `p`
+ and returns a unique, hashable key representing its "layer" or "bucket".
+ If `None`, parameters are bucketed by their memory ID (tensor-wise).
+ (default: None)
  nnmf_factor (bool): whether to use the factorization or disable it to use
  the uncompressed optimizer. (default: False)
  """
@@ -65,6 +89,13 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
  vector_reshape: bool = True,
  stochastic_rounding: bool = True,
  orthogonal_gradient: bool = False,
+ kourkoutas_beta: bool = False,
+ beta2_min: float = 0.88,
+ ema_alpha: float = 0.93,
+ tiny_spike: float = 1e-9,
+ k_warmup_steps: int = 0,
+ k_logging: int = 0,
+ layer_key_fn: Optional[Callable] = None,
  nnmf_factor: bool = False,
  ):
  if not (lr >= 0.0):
@@ -77,17 +108,25 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
  raise ValueError(f"Weight-decay should be >= 0.0. Got {weight_decay}")
  if not 0.0 <= alpha_grad:
  raise ValueError("Invalid alpha value: {}".format(alpha_grad))
+ if kourkoutas_beta and not (betas[1] > beta2_min): raise ValueError(f"For Kourkoutas-β, betas[1] (as beta2_max) must be > beta2_min. Got {betas[1]} and {beta2_min}")

  defaults = {
  "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
  "alpha_grad": alpha_grad, "beta1_warmup": beta1_warmup, "min_beta1": min_beta1,
  "vector_reshape": vector_reshape,
  "orthogonal_gradient": orthogonal_gradient, "use_bias_correction": use_bias_correction,
+ "kourkoutas_beta": kourkoutas_beta, "beta2_min": beta2_min, "ema_alpha": ema_alpha,
+ "tiny_spike": tiny_spike, "k_warmup_steps": k_warmup_steps, "k_logging": k_logging,
  }
  self.stochastic_rounding = stochastic_rounding
  self.factored = nnmf_factor
+ self.kourkoutas_beta = kourkoutas_beta
+ self.layer_key_fn = layer_key_fn
  super().__init__(params, defaults)

+ if self.kourkoutas_beta:
+ self.kourkoutas_helper = KourkoutasHelper(self)
+
  @property
  def supports_fused_back_pass(self):
  return True
@@ -150,6 +189,16 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
  state['den_sum'] = 1.0

  beta1_final, beta2 = group["betas"]
+
+ current_step = state['step']
+ if group['kourkoutas_beta']:
+ # Call prepare_step() once at the beginning of the step for all params
+ self.kourkoutas_helper.maybe_prepare_step(current_step)
+ # Accumulate current grad's norm for the *next* step
+ self.kourkoutas_helper.accumulate_gradient_sq_norm(p, grad)
+ # Get the dynamic beta2 calculated in prepare_step()
+ beta2 = self.kourkoutas_helper.get_beta2(p, group, current_step)
+
  beta1_warmup = group["beta1_warmup"]
  alpha_grad = group["alpha_grad"]

@@ -161,7 +210,10 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):

  if group['use_bias_correction']:
  state['num_sum'] = beta1 * state['num_sum'] + 1.0
- state['den_sum'] = beta2 * state['den_sum'] + (1.0 - beta2)
+ if group['kourkoutas_beta']:
+ state['den_sum'] = group['betas'][1] * state['den_sum'] + (1.0 - group['betas'][1])
+ else:
+ state['den_sum'] = beta2 * state['den_sum'] + (1.0 - beta2)

  if state['factored']:
  d1, d2 = state['effective_shape']
adv_optm-1.1.0.dev2/adv_optm/util/Kourkoutas.py
@@ -0,0 +1,134 @@
+ import torch
+ from torch.optim import Optimizer
+ from typing import Callable
+
+ class KourkoutasHelper:
+ """
+ A helper class to add layer-wise Kourkoutas-β functionality to a PyTorch optimizer.
+ """
+ def __init__(self, optimizer: Optimizer):
+ # We need a reference to the optimizer to access its param_groups and state
+ if not hasattr(optimizer, 'param_groups'):
+ raise TypeError("optimizer must be a valid torch.optim.Optimizer instance.")
+ self.optimizer = optimizer
+
+ # State managed by the helper
+ self.layer_state = {}
+ self.layer_info = {}
+ self._layer_info_built = False
+ self._current_step_prepared = -1
+
+ def _build_layer_info_if_needed(self):
+ """Builds a map of layers and the parameters they contain."""
+ if self._layer_info_built:
+ return
+
+ if not hasattr(self.optimizer, 'layer_key_fn') or self.optimizer.layer_key_fn is None:
+ print("Warning: KourkoutasHelper requires 'layer_key_fn' on the optimizer. Defaulting to tensor-wise (id).")
+ self.optimizer.layer_key_fn = lambda p: id(p)
+
+ for group in self.optimizer.param_groups:
+ for p in group['params']:
+ if p.grad is None: continue
+ layer_key = self.optimizer.layer_key_fn(p)
+ if layer_key not in self.layer_info:
+ self.layer_info[layer_key] = {'params': [], 'group_ref': group}
+ self.layer_info[layer_key]['params'].append(p)
+
+ k_logging_interval = self.optimizer.param_groups[0].get('k_logging', 0)
+ if k_logging_interval > 0:
+ print(f"[Kourkoutas-β Debug] Layer info built. Found {len(self.layer_info)} unique layers/buckets.")
+
+ self._layer_info_built = True
+
+ def prepare_step(self, current_step: int):
+ """
+ Calculates dynamic beta2 for all layers using the completed scalar accumulators
+ from the PREVIOUS step. Should be called once at the start of an optimizer step.
+ """
+ self._build_layer_info_if_needed()
+
+ # Check if logging is enabled for this step based on the interval
+ k_logging_interval = self.optimizer.param_groups[0].get('k_logging', 0)
+ is_logging_step = k_logging_interval > 0 and (current_step + 1) % k_logging_interval == 0
+
+ beta2_log = [] if is_logging_step else None
+ first_layer_key = next(iter(self.layer_info), None)
+
+ for layer_key, info in self.layer_info.items():
+ params, group = info['params'], info['group_ref']
+
+ if layer_key not in self.layer_state:
+ self.layer_state[layer_key] = {
+ 'r_ema_grad_norm': torch.tensor(0.0, device=params[0].device, dtype=torch.float32),
+ 'sum_sq_accumulator': torch.tensor(0.0, device=params[0].device, dtype=torch.float32)
+ }
+
+ layer_state = self.layer_state[layer_key]
+
+ # Use the completed accumulator from the previous step
+ pooled_grad_norm = torch.sqrt(layer_state['sum_sq_accumulator'])
+
+ r_ema = layer_state['r_ema_grad_norm']
+ prev_r_ema_val = r_ema.item() # for logging
+
+ # EMA is always updated, even during warmup
+ r_ema.mul_(group['ema_alpha']).add_(pooled_grad_norm, alpha=1.0 - group['ema_alpha'])
+
+ sun = torch.tensor(0.0, device=r_ema.device) # Default sun to 0 for warmup
+ beta2_max = group['betas'][1]
+
+ # --- CONSOLIDATED WARMUP LOGIC ---
+ if current_step < group['k_warmup_steps']:
+ beta2 = beta2_max
+ else:
+ raw = pooled_grad_norm / (r_ema + group['tiny_spike'])
+ sun = raw / (1.0 + raw)
+ beta2 = beta2_max - (beta2_max - group['beta2_min']) * sun
+
+ layer_state['dynamic_beta2'] = beta2.item() if isinstance(beta2, torch.Tensor) else beta2
+ layer_state['sum_sq_accumulator'].zero_()
+
+ if is_logging_step:
+ beta2_log.append(layer_state['dynamic_beta2'])
+ if layer_key == first_layer_key:
+ print(f"\n[Kourkoutas-β Debug] Step {current_step + 1} - Sample Layer '{layer_key}':")
+ print(f" - Grad Norm: {pooled_grad_norm.item():.4e}, Prev EMA: {prev_r_ema_val:.4e}, New EMA: {r_ema.item():.4e}")
+ print(f" - Sunspike: {sun.item():.4f}, Dynamic Beta2: {layer_state['dynamic_beta2']:.4f}")
+
+ if is_logging_step and beta2_log:
+ beta2_tensor = torch.tensor(beta2_log, device='cpu')
+ print(f"[Kourkoutas-β Debug] Step {current_step + 1} Overall Beta2 Stats: Min={beta2_tensor.min():.4f}, Max={beta2_tensor.max():.4f}, Mean={beta2_tensor.mean():.4f}")
+
+
+ def maybe_prepare_step(self, current_step: int):
+ """
+ A universal guard that calls prepare_step() exactly once per training step.
+ """
+ if self._current_step_prepared < current_step:
+ self.prepare_step(current_step)
+ self._current_step_prepared = current_step
+
+ def accumulate_gradient_sq_norm(self, p: torch.Tensor, grad: torch.Tensor):
+ """
+ Accumulates the squared L2 norm of a single gradient for the next step's calculation.
+ """
+ self._build_layer_info_if_needed()
+ layer_key = self.optimizer.layer_key_fn(p)
+
+ if layer_key in self.layer_info:
+ if layer_key not in self.layer_state:
+ self.layer_state[layer_key] = {
+ 'r_ema_grad_norm': torch.tensor(0.0, device=p.device, dtype=torch.float32),
+ 'sum_sq_accumulator': torch.tensor(0.0, device=p.device, dtype=torch.float32)
+ }
+ # Accumulate for the *next* step's prepare_step call
+ self.layer_state[layer_key]['sum_sq_accumulator'] += torch.sum(grad.detach().pow(2)).float()
+
+ def get_beta2(self, p: torch.Tensor, group: dict, current_step: int) -> float:
+ """
+ Gets the appropriate beta2 for the current parameter, handling warmup and dynamic value fetching.
+ """
+ layer_key = self.optimizer.layer_key_fn(p)
+ # The default is the max value, which is correct for unmapped params or edge cases
+ return self.layer_state.get(layer_key, {}).get('dynamic_beta2', group['betas'][1])
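
Note: the heart of prepare_step() above is a scalar update per layer: take the pooled gradient norm accumulated during the previous step, refresh its EMA, squash the norm-to-EMA ratio into a "sunspike" in [0, 1), and interpolate β₂ between betas[1] (acting as beta2_max) and beta2_min. A standalone restatement with made-up numbers, useful for sanity-checking the direction of the effect (a norm spike pushes β₂ toward beta2_min):

    import torch

    # Illustrative constants; in the optimizer these come from the param group.
    beta2_max, beta2_min = 0.999, 0.88
    ema_alpha, tiny_spike = 0.93, 1e-9

    pooled_grad_norm = torch.tensor(4.0)  # sqrt of the layer's accumulated squared grad norms
    r_ema = torch.tensor(1.0)             # EMA of past pooled norms

    r_ema = ema_alpha * r_ema + (1.0 - ema_alpha) * pooled_grad_norm  # EMA update, as in prepare_step()
    raw = pooled_grad_norm / (r_ema + tiny_spike)                     # "sunspike" ratio
    sun = raw / (1.0 + raw)                                           # squashed into [0, 1)
    beta2 = beta2_max - (beta2_max - beta2_min) * sun                 # larger spike -> smaller beta2

    print(f"sun={sun.item():.3f}, beta2={beta2.item():.4f}")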
{adv_optm-1.0.6 → adv_optm-1.1.0.dev2}/adv_optm.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: adv_optm
- Version: 1.0.6
+ Version: 1.1.0.dev2
  Summary: A family of highly efficient, lightweight yet powerful optimizers.
  Home-page: https://github.com/Koratahiu/Advanced_Optimizers
  Author: Koratahiu
{adv_optm-1.0.6 → adv_optm-1.1.0.dev2}/adv_optm.egg-info/SOURCES.txt
@@ -16,6 +16,7 @@ adv_optm/optim/Simplified_AdEMAMix.py
  adv_optm/optim/__init__.py
  adv_optm/util/BF16_Stochastic_Rounding.py
  adv_optm/util/Effective_Shape.py
+ adv_optm/util/Kourkoutas.py
  adv_optm/util/NNMF.py
  adv_optm/util/One_Bit_Boolean.py
  adv_optm/util/OrthoGrad.py
{adv_optm-1.0.6 → adv_optm-1.1.0.dev2}/setup.py
@@ -5,7 +5,7 @@ with open("README.md", "r", encoding="utf-8") as fh:

  setup(
  name="adv_optm",
- version="1.0.6",
+ version="1.1.0.dev2",
  author="Koratahiu",
  author_email="hiuhonor@gmail.com",
  license='Apache 2.0',