adv-optm 1.0.6__tar.gz → 1.1.0.dev1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of adv-optm might be problematic.

Files changed (25)
  1. {adv_optm-1.0.6 → adv_optm-1.1.0.dev1}/PKG-INFO +1 -1
  2. {adv_optm-1.0.6 → adv_optm-1.1.0.dev1}/adv_optm/__init__.py +1 -1
  3. {adv_optm-1.0.6 → adv_optm-1.1.0.dev1}/adv_optm/optim/AdamW_adv.py +68 -7
  4. {adv_optm-1.0.6 → adv_optm-1.1.0.dev1}/adv_optm/optim/Adopt_adv.py +60 -4
  5. {adv_optm-1.0.6 → adv_optm-1.1.0.dev1}/adv_optm/optim/Prodigy_adv.py +67 -8
  6. {adv_optm-1.0.6 → adv_optm-1.1.0.dev1}/adv_optm/optim/Simplified_AdEMAMix.py +62 -1
  7. adv_optm-1.1.0.dev1/adv_optm/util/Kourkoutas.py +108 -0
  8. {adv_optm-1.0.6 → adv_optm-1.1.0.dev1}/adv_optm.egg-info/PKG-INFO +1 -1
  9. {adv_optm-1.0.6 → adv_optm-1.1.0.dev1}/adv_optm.egg-info/SOURCES.txt +1 -0
  10. {adv_optm-1.0.6 → adv_optm-1.1.0.dev1}/setup.py +1 -1
  11. {adv_optm-1.0.6 → adv_optm-1.1.0.dev1}/LICENSE +0 -0
  12. {adv_optm-1.0.6 → adv_optm-1.1.0.dev1}/README.md +0 -0
  13. {adv_optm-1.0.6 → adv_optm-1.1.0.dev1}/adv_optm/optim/Lion_Prodigy_adv.py +0 -0
  14. {adv_optm-1.0.6 → adv_optm-1.1.0.dev1}/adv_optm/optim/Lion_adv.py +0 -0
  15. {adv_optm-1.0.6 → adv_optm-1.1.0.dev1}/adv_optm/optim/__init__.py +0 -0
  16. {adv_optm-1.0.6 → adv_optm-1.1.0.dev1}/adv_optm/util/BF16_Stochastic_Rounding.py +0 -0
  17. {adv_optm-1.0.6 → adv_optm-1.1.0.dev1}/adv_optm/util/Effective_Shape.py +0 -0
  18. {adv_optm-1.0.6 → adv_optm-1.1.0.dev1}/adv_optm/util/NNMF.py +0 -0
  19. {adv_optm-1.0.6 → adv_optm-1.1.0.dev1}/adv_optm/util/One_Bit_Boolean.py +0 -0
  20. {adv_optm-1.0.6 → adv_optm-1.1.0.dev1}/adv_optm/util/OrthoGrad.py +0 -0
  21. {adv_optm-1.0.6 → adv_optm-1.1.0.dev1}/adv_optm/util/__init__.py +0 -0
  22. {adv_optm-1.0.6 → adv_optm-1.1.0.dev1}/adv_optm.egg-info/dependency_links.txt +0 -0
  23. {adv_optm-1.0.6 → adv_optm-1.1.0.dev1}/adv_optm.egg-info/requires.txt +0 -0
  24. {adv_optm-1.0.6 → adv_optm-1.1.0.dev1}/adv_optm.egg-info/top_level.txt +0 -0
  25. {adv_optm-1.0.6 → adv_optm-1.1.0.dev1}/setup.cfg +0 -0

PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: adv_optm
- Version: 1.0.6
+ Version: 1.1.0.dev1
  Summary: A family of highly efficient, lightweight yet powerful optimizers.
  Home-page: https://github.com/Koratahiu/Advanced_Optimizers
  Author: Koratahiu

adv_optm/__init__.py
@@ -16,4 +16,4 @@ __all__ = [
  "Lion_Prodigy_adv",
  ]

- __version__ = "1.0.6"
+ __version__ = "1.1.0.dev1"

adv_optm/optim/AdamW_adv.py
@@ -1,11 +1,12 @@
  import torch
- from typing import Optional
+ from typing import Optional, Callable

  from ..util.BF16_Stochastic_Rounding import add_stochastic_
  from ..util.Effective_Shape import _get_effective_shape
  from ..util.NNMF import _nnmf,_unnmf
  from ..util.OrthoGrad import _orthogonalize_gradient
  from ..util.One_Bit_Boolean import _pack_bools, _unpack_bools
+ from ..util.Kourkoutas import KourkoutasHelper

  class AdamW_adv(torch.optim.Optimizer):
  """
@@ -54,6 +55,28 @@ class AdamW_adv(torch.optim.Optimizer):
  as it gradually introduces the stabilizing slow momentum term. During
  the warmup, `alpha` ramps from 0 to its target value. If `None`,
  the scheduler is disabled. (default: None)
+ kourkoutas_beta (bool): whether to enable the layer-wise dynamic β₂ logic.
+ If `False`, the optimizer behaves as standard AdamW. (default: False)
+ beta2_min (float): The minimum value for dynamic β₂, used during periods of
+ high gradient variance ("sunspikes"). Must be less than `betas[1]`.
+ (default: 0.88)
+ ema_alpha (float): The decay rate for the Exponential Moving Average (EMA) of
+ the pooled gradient norms. Corresponds to `α` in the paper.
+ (default: 0.93)
+ tiny_spike (float): A small constant added to the denominator of the
+ "sunspike" ratio calculation to prevent division by zero. Corresponds
+ to `ε_spike` in the paper. (default: 1e-9)
+ k_warmup_steps (int): The number of initial steps during which β₂ is held
+ at a fixed average value (`(beta2_min + beta2_max) / 2`) before the
+ dynamic logic activates. (default: 0)
+ k_logging (int): if > 0 and kourkoutas_beta=True, enables periodic console
+ logging of Kourkoutas-β statistics (min, max, mean of `β₂` across layers)
+ every logging steps. Useful for debugging and tuning. Set to 0 to disable
+ logging (default: 0).
+ layer_key_fn (Optional[Callable]): A function that takes a parameter `p`
+ and returns a unique, hashable key representing its "layer" or "bucket".
+ If `None`, parameters are bucketed by their memory ID (tensor-wise).
+ (default: None)
  nnmf_factor (bool): whether to use the factorization or disable it to use
  the uncompressed optimizer. (default: False)
  """
@@ -76,6 +99,13 @@ class AdamW_adv(torch.optim.Optimizer):
  beta3_ema: float = 0.9999,
  alpha: float = 5.0,
  t_alpha: int | None = None,
+ kourkoutas_beta: bool = False,
+ beta2_min: float = 0.88,
+ ema_alpha: float = 0.93,
+ tiny_spike: float = 1e-9,
+ k_warmup_steps: int = 0,
+ k_logging: int = 0,
+ layer_key_fn: Optional[Callable] = None,
  nnmf_factor: bool = False,
  ):
  if not (lr >= 0.0):
@@ -86,6 +116,8 @@ class AdamW_adv(torch.optim.Optimizer):
  raise ValueError(f"Epsilon should be >= 0.0. Got {eps}")
  if not (weight_decay >= 0.0):
  raise ValueError(f"Weight-decay should be >= 0.0. Got {weight_decay}")
+ if kourkoutas_beta and not (betas[1] > beta2_min): raise ValueError(f"For Kourkoutas-β, betas[1] (as beta2_max) must be > beta2_min. Got {betas[1]} and {beta2_min}")
+
  if cautious_mask and grams_moment:
  print("Warning: cautious is incompatible with grams, Disabling cautious.")
  cautious_mask = False
@@ -95,6 +127,8 @@ class AdamW_adv(torch.optim.Optimizer):
  "vector_reshape": vector_reshape, "use_atan2": use_atan2,
  "orthogonal_gradient": orthogonal_gradient, "use_bias_correction": use_bias_correction,
  "beta3_ema": beta3_ema, "alpha": alpha, "t_alpha": t_alpha,
+ "kourkoutas_beta": kourkoutas_beta, "beta2_min": beta2_min, "ema_alpha": ema_alpha,
+ "tiny_spike": tiny_spike, "k_warmup_steps": k_warmup_steps,
  }
  self.stochastic_rounding = stochastic_rounding
  self.cautious_mask = cautious_mask
@@ -103,6 +137,12 @@ class AdamW_adv(torch.optim.Optimizer):
  self.factored = nnmf_factor
  super().__init__(params, defaults)

+ self.kourkoutas_beta = kourkoutas_beta
+ self.k_logging= k_logging and kourkoutas_beta
+ self.layer_key_fn = layer_key_fn and kourkoutas_beta
+ if self.kourkoutas_beta:
+ self.kourkoutas_helper = KourkoutasHelper(self)
+
  @property
  def supports_fused_back_pass(self):
  return True
@@ -127,8 +167,6 @@ class AdamW_adv(torch.optim.Optimizer):
  grad = _orthogonalize_gradient(p, grad)
  state = self.state[p]

- beta1, beta2 = group['betas']
-
  # State Initialization
  if len(state) == 0:
  state['step'] = 0
@@ -148,7 +186,7 @@ class AdamW_adv(torch.optim.Optimizer):
  d1, d2 = state['effective_shape']

  # First moment (m)
- if beta1 > 0:
+ if group['betas'][0] > 0:
  state['mu_m_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
  state['mv_m_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
  if not self.grams_moment:
@@ -163,16 +201,29 @@ class AdamW_adv(torch.optim.Optimizer):
  state['mu_v_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
  state['mv_v_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
  else: # Fallback to standard AdamW for non-factored tensors
- if beta1 > 0:
+ if group['betas'][0] > 0:
  state['exp_avg'] = torch.zeros_like(p, device=device, dtype=dtype)
  if self.use_AdEMAMix:
  state['exp_avg_slow'] = torch.zeros_like(p, device=device, dtype=dtype)
  state['exp_avg_sq'] = torch.zeros_like(p, device=device, dtype=dtype)

+ current_step = state['step']
+ if group['kourkoutas_beta']:
+ self.kourkoutas_helper.maybe_prepare_step(current_step)
+ self.kourkoutas_helper.accumulate_gradient_sq_norm(p, grad)
+
+ beta1, beta2 = group['betas']
+ if group['kourkoutas_beta']:
+ beta2 = self.kourkoutas_helper.get_beta2(p, group, current_step)
+
  step = state['step'] + 1
  if group['use_bias_correction']:
  bias_correction1 = 1.0 - beta1 ** step
- bias_correction2 = 1.0 - beta2 ** step
+ if group['kourkoutas_beta']:
+ bias_correction2 = 1.0 - group['betas'][1] ** step
+ # Use beta2_max for bias correction
+ else:
+ bias_correction2 = 1.0 - beta2 ** step
  else:
  bias_correction1 = 1
  bias_correction2 = 1
@@ -315,4 +366,14 @@ class AdamW_adv(torch.optim.Optimizer):
  for i, p in enumerate(group['params']):
  self.step_parameter(p, group, i)

- return loss
+ if self.kourkoutas_beta and self.k_logging > 0 and hasattr(self, '_beta2_log'):
+ first_param_state = self.state[self.param_groups[0]['params'][0]]
+ step_num = first_param_state['step']
+
+ if step_num > 0 and step_num % self.k_logging == 0:
+ if self._beta2_log:
+ beta2_tensor = torch.tensor(self._beta2_log, device='cpu')
+ print(f"Step {step_num}: Kourkoutas beta2 stats: Min={beta2_tensor.min():.4f}, Max={beta2_tensor.max():.4f}, Mean={beta2_tensor.mean():.4f}")
+ delattr(self, '_beta2_log')
+
+ return loss
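
The Kourkoutas-β options added above are ordinary constructor keywords, so enabling the dynamic β₂ is a one-line change at optimizer construction. A minimal, hedged usage sketch follows; it assumes `AdamW_adv` is re-exported from the package root (in line with the `__all__` hunk in `adv_optm/__init__.py`), and the tiny model and training loop are placeholders, not anything from the package:

    import torch
    from adv_optm import AdamW_adv  # assumed top-level re-export

    model = torch.nn.Linear(128, 64)

    optimizer = AdamW_adv(
        model.parameters(),
        lr=1e-3,
        betas=(0.9, 0.999),    # betas[1] doubles as beta2_max once kourkoutas_beta=True
        kourkoutas_beta=True,  # enable the layer-wise dynamic beta2
        beta2_min=0.88,        # floor reached during gradient "sunspikes"
        k_warmup_steps=100,    # hold beta2 at the (beta2_min + beta2_max) / 2 midpoint first
        k_logging=50,          # per the docstring, print beta2 stats periodically
    )

    for _ in range(200):
        loss = model(torch.randn(32, 128)).pow(2).mean()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

Leaving `layer_key_fn=None` keeps the default tensor-wise bucketing described in the docstring; a per-module bucketing sketch is given after the Adopt_adv diff below.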

adv_optm/optim/Adopt_adv.py
@@ -6,6 +6,7 @@ from ..util.Effective_Shape import _get_effective_shape
  from ..util.NNMF import _nnmf, _unnmf
  from ..util.OrthoGrad import _orthogonalize_gradient
  from ..util.One_Bit_Boolean import _pack_bools, _unpack_bools
+ from ..util.Kourkoutas import KourkoutasHelper

  class Adopt_adv(torch.optim.Optimizer):
  """
@@ -72,6 +73,28 @@ class Adopt_adv(torch.optim.Optimizer):
  current gradient. For small batch sizes, use high values (e.g., 10-100) to be
  more responsive. For large batch sizes, use low values (e.g., 0-1) for
  stability. (default: 100.0)
+ kourkoutas_beta (bool): whether to enable the layer-wise dynamic β₂ logic.
+ If `False`, the optimizer behaves as standard Adopt. (default: False)
+ beta2_min (float): The minimum value for dynamic β₂, used during periods of
+ high gradient variance ("sunspikes"). Must be less than `betas[1]`.
+ (default: 0.88)
+ ema_alpha (float): The decay rate for the Exponential Moving Average (EMA) of
+ the pooled gradient norms. Corresponds to `α` in the paper.
+ (default: 0.93)
+ tiny_spike (float): A small constant added to the denominator of the
+ "sunspike" ratio calculation to prevent division by zero. Corresponds
+ to `ε_spike` in the paper. (default: 1e-9)
+ k_warmup_steps (int): The number of initial steps during which β₂ is held
+ at a fixed average value (`(beta2_min + beta2_max) / 2`) before the
+ dynamic logic activates. (default: 0)
+ k_logging (int): if > 0 and kourkoutas_beta=True, enables periodic console
+ logging of Kourkoutas-β statistics (min, max, mean of `β₂` across layers)
+ every logging steps. Useful for debugging and tuning. Set to 0 to disable
+ logging (default: 0).
+ layer_key_fn (Optional[Callable]): A function that takes a parameter `p`
+ and returns a unique, hashable key representing its "layer" or "bucket".
+ If `None`, parameters are bucketed by their memory ID (tensor-wise).
+ (default: None)
  nnmf_factor (bool): whether to use the factorization or disable it to use
  the uncompressed optimizer. (default: False)
  """
@@ -96,6 +119,13 @@ class Adopt_adv(torch.optim.Optimizer):
  t_alpha: int | None = None,
  Simplified_AdEMAMix: bool = False,
  alpha_grad: float = 100.0,
+ kourkoutas_beta: bool = False,
+ beta2_min: float = 0.88,
+ ema_alpha: float = 0.93,
+ tiny_spike: float = 1e-9,
+ k_warmup_steps: int = 0,
+ k_logging: int = 0,
+ layer_key_fn: Optional[Callable] = None,
  nnmf_factor: bool = False,
  ):
  if not (lr >= 0.0):
@@ -111,6 +141,7 @@ class Adopt_adv(torch.optim.Optimizer):
  cautious_mask = False
  if betas[0] == 0.0 and Simplified_AdEMAMix:
  raise ValueError(f"Beta1 cannot be 0.0 when using Simplified_AdEMAMix. Got {betas[0]}")
+ if kourkoutas_beta and not (betas[1] > beta2_min): raise ValueError(f"For Kourkoutas-β, betas[1] (as beta2_max) must be > beta2_min. Got {betas[1]} and {beta2_min}")
  if use_AdEMAMix and Simplified_AdEMAMix:
  print("Warning: use_AdEMAMix is incompatible with Simplified_AdEMAMix, Disabling use_AdEMAMix.")
  if grams_moment and Simplified_AdEMAMix:
@@ -125,6 +156,8 @@ class Adopt_adv(torch.optim.Optimizer):
  "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
  "vector_reshape": vector_reshape, "beta3_ema": beta3_ema, "alpha": alpha,
  "t_alpha": t_alpha, "alpha_grad": alpha_grad,
+ "kourkoutas_beta": kourkoutas_beta, "beta2_min": beta2_min, "ema_alpha": ema_alpha,
+ "tiny_spike": tiny_spike, "k_warmup_steps": k_warmup_steps,
  }
  self.clip_lambda = clip_lambda
  self.stochastic_rounding = stochastic_rounding
@@ -137,6 +170,12 @@ class Adopt_adv(torch.optim.Optimizer):
  self.factored = nnmf_factor
  super().__init__(params, defaults)

+ self.kourkoutas_beta = kourkoutas_beta
+ self.k_logging= k_logging and kourkoutas_beta
+ self.layer_key_fn = layer_key_fn and kourkoutas_beta
+ if self.kourkoutas_beta:
+ self.kourkoutas_helper = KourkoutasHelper(self)
+
  @property
  def supports_fused_back_pass(self): return True
  @property
@@ -156,8 +195,6 @@ class Adopt_adv(torch.optim.Optimizer):
  grad = _orthogonalize_gradient(p, grad)
  state = self.state[p]

- beta1, beta2 = group['betas']
-
  # State Initialization
  if len(state) == 0:
  state['step'] = 0
@@ -176,7 +213,7 @@ class Adopt_adv(torch.optim.Optimizer):
  d1, d2 = state['effective_shape']

  # m_0 = 0
- if beta1 > 0:
+ if group['betas'][0] > 0:
  state['mu_m_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
  state['mv_m_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
  if not self.grams_moment:
@@ -195,12 +232,21 @@ class Adopt_adv(torch.optim.Optimizer):
  # Initialize v_0 using NMF
  _nnmf(vt_init, out=(state['mu_v_nmf'], state['mv_v_nmf']))
  else: # Fallback for non-factored tensors
- if beta1 > 0:
+ if group['betas'][0] > 0:
  state['exp_avg'] = torch.zeros_like(p, dtype=dtype) # m_0
  if self.use_AdEMAMix:
  state['exp_avg_slow'] = torch.zeros_like(p, dtype=dtype)
  state['exp_avg_sq'] = grad.square() # v_0

+ current_step = state['step']
+ if group['kourkoutas_beta']:
+ self.kourkoutas_helper.maybe_prepare_step(current_step)
+ self.kourkoutas_helper.accumulate_gradient_sq_norm(p, grad)
+
+ beta1, beta2 = group['betas']
+ if group['kourkoutas_beta']:
+ beta2 = self.kourkoutas_helper.get_beta2(p, group, current_step)
+
  # The first step is for initialization only (skip when use_atan2 as it's scale invariant).
  if state['step'] == 0 and not self.use_atan2:
  state['step'] += 1
@@ -386,4 +432,14 @@ class Adopt_adv(torch.optim.Optimizer):
  for i, p in enumerate(group['params']):
  self.step_parameter(p, group, i)

+ if self.kourkoutas_beta and self.k_logging > 0 and hasattr(self, '_beta2_log'):
+ first_param_state = self.state[self.param_groups[0]['params'][0]]
+ step_num = first_param_state['step']
+
+ if step_num > 0 and step_num % self.k_logging == 0:
+ if self._beta2_log:
+ beta2_tensor = torch.tensor(self._beta2_log, device='cpu')
+ print(f"Step {step_num}: Kourkoutas beta2 stats: Min={beta2_tensor.min():.4f}, Max={beta2_tensor.max():.4f}, Mean={beta2_tensor.mean():.4f}")
+ delattr(self, '_beta2_log')
+
  return loss
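
The `layer_key_fn` documented above only has to map a parameter tensor to a hashable bucket key; by default each tensor is its own bucket (`id(p)`). Below is a hedged sketch of one way to pool a module's weight and bias into a single Kourkoutas-β bucket. `make_layer_key_fn` is an illustrative helper built from `named_parameters()`, not part of the package:

    import torch

    def make_layer_key_fn(model: torch.nn.Module):
        """Key each parameter by the name of its owning submodule, so e.g.
        '0.weight' and '0.bias' share one bucket."""
        param_to_layer = {}
        for name, param in model.named_parameters():
            layer_name = name.rsplit(".", 1)[0]  # strip the trailing '.weight' / '.bias'
            param_to_layer[id(param)] = layer_name
        # Fall back to tensor-wise bucketing for tensors the map does not know about.
        return lambda p: param_to_layer.get(id(p), id(p))

    model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.Linear(16, 4))
    layer_key_fn = make_layer_key_fn(model)
    # e.g. Adopt_adv(model.parameters(), kourkoutas_beta=True, layer_key_fn=layer_key_fn, ...)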

adv_optm/optim/Prodigy_adv.py
@@ -3,11 +3,14 @@ import torch.distributed as dist

  import math

+ from typing import Optional, Callable
+
  from ..util.BF16_Stochastic_Rounding import add_stochastic_
  from ..util.Effective_Shape import _get_effective_shape
  from ..util.NNMF import _nnmf,_unnmf
  from ..util.OrthoGrad import _orthogonalize_gradient
  from ..util.One_Bit_Boolean import _pack_bools, _unpack_bools
+ from ..util.Kourkoutas import KourkoutasHelper

  class Prodigy_adv(torch.optim.Optimizer):
  """
@@ -85,6 +88,28 @@ class Prodigy_adv(torch.optim.Optimizer):
  prodigy_steps (int): If greater than zero, disable Prodigy's stepsize adjustments
  after the specified optimiser step and release all state memory required by Prodigy
  (default: 0).
+ kourkoutas_beta (bool): whether to enable the layer-wise dynamic β₂ logic.
+ If `False`, the optimizer behaves as standard AdamW/Prodigy. (default: False)
+ beta2_min (float): The minimum value for dynamic β₂, used during periods of
+ high gradient variance ("sunspikes"). Must be less than `betas[1]`.
+ (default: 0.88)
+ ema_alpha (float): The decay rate for the Exponential Moving Average (EMA) of
+ the pooled gradient norms. Corresponds to `α` in the paper.
+ (default: 0.93)
+ tiny_spike (float): A small constant added to the denominator of the
+ "sunspike" ratio calculation to prevent division by zero. Corresponds
+ to `ε_spike` in the paper. (default: 1e-9)
+ k_warmup_steps (int): The number of initial steps during which β₂ is held
+ at a fixed average value (`(beta2_min + beta2_max) / 2`) before the
+ dynamic logic activates. (default: 0)
+ k_logging (int): if > 0 and kourkoutas_beta=True, enables periodic console
+ logging of Kourkoutas-β statistics (min, max, mean of `β₂` across layers)
+ every logging steps. Useful for debugging and tuning. Set to 0 to disable
+ logging (default: 0).
+ layer_key_fn (Optional[Callable]): A function that takes a parameter `p`
+ and returns a unique, hashable key representing its "layer" or "bucket".
+ If `None`, parameters are bucketed by their memory ID (tensor-wise).
+ (default: None)
  """

  def __init__(
@@ -116,6 +141,13 @@ class Prodigy_adv(torch.optim.Optimizer):
  fsdp_in_use: bool = False,
  slice_p: int = 11,
  prodigy_steps: int = 0,
+ kourkoutas_beta: bool = False,
+ beta2_min: float = 0.88,
+ ema_alpha: float = 0.93,
+ tiny_spike: float = 1e-9,
+ k_warmup_steps: int = 0,
+ k_logging: int = 0,
+ layer_key_fn: Optional[Callable] = None,
  ):
  if not (lr >= 0.0):
  raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
@@ -141,6 +173,8 @@ class Prodigy_adv(torch.optim.Optimizer):
  if use_atan2 and Simplified_AdEMAMix:
  print("Warning: use_atan2 is incompatible with Simplified_AdEMAMix. Disabling use_atan2.")
  use_atan2 = False
+ if kourkoutas_beta and not (betas[1] > beta2_min):
+ raise ValueError(f"For Kourkoutas-β, betas[1] (as beta2_max) must be > beta2_min. Got {betas[1]} and {beta2_min}")
  if Simplified_AdEMAMix and alpha_grad > 0:
  # scales d_coef by alpha_grad, this force prodigy to behave well with Simplified_AdEMAMix
  d_coef = d_coef/alpha_grad
@@ -153,7 +187,9 @@ class Prodigy_adv(torch.optim.Optimizer):
  "beta3": beta3, "d": d0, "d0": d0, "d_max": d0, "d_numerator": 0.0, "d_coef": d_coef,
  "growth_rate": growth_rate, "safeguard_warmup": safeguard_warmup, "k": 0, "slice_p": slice_p,
  "fsdp_in_use": fsdp_in_use, "prodigy_steps": prodigy_steps,
- "alpha_grad": alpha_grad,
+ "alpha_grad": alpha_grad,
+ "kourkoutas_beta": kourkoutas_beta, "beta2_min": beta2_min, "ema_alpha": ema_alpha,
+ "tiny_spike": tiny_spike, "k_warmup_steps": k_warmup_steps,
  }
  self.stochastic_rounding = stochastic_rounding
  self.cautious_mask = cautious_mask and not Simplified_AdEMAMix
@@ -163,6 +199,13 @@ class Prodigy_adv(torch.optim.Optimizer):
  self.factored = nnmf_factor
  self.fsdp_in_use = fsdp_in_use
  super().__init__(params, defaults)
+
+ self.kourkoutas_beta = kourkoutas_beta
+ self.k_logging= k_logging and kourkoutas_beta
+ self.layer_key_fn = layer_key_fn and kourkoutas_beta
+ if self.kourkoutas_beta:
+ self.kourkoutas_helper = KourkoutasHelper(self)
+
  self.init_step()

  @property
@@ -180,19 +223,17 @@ class Prodigy_adv(torch.optim.Optimizer):
  def init_step(self):
  """Resets accumulators and calculates dlr for the upcoming step."""
  self.d_denom = 0.0
-
+
  g_group = self.param_groups[0]
- self.beta1, self.beta2 = g_group['betas']
+ self.beta1, self.beta2_default = g_group['betas']
  self.beta3 = g_group['beta3']
  if self.beta3 is None:
- self.beta3 = math.sqrt(self.beta2)
+ self.beta3 = math.sqrt(self.beta2_default)

- k = g_group['k']
  self.d = g_group['d']
  lr = g_group['lr']

  self.dlr = self.d * lr
-
  self.d_numerator = g_group.get('d_numerator', 0.0) * self.beta3

  @torch.no_grad()
@@ -258,6 +299,15 @@ class Prodigy_adv(torch.optim.Optimizer):
  else:
  state['p0'] = torch.tensor(0, device=device, dtype=p.dtype)

+ current_step = state['step']
+ if group['kourkoutas_beta']:
+ self.kourkoutas_helper.maybe_prepare_step(current_step)
+ self.kourkoutas_helper.accumulate_gradient_sq_norm(p, grad)
+
+ beta2 = self.beta2_default
+ if group['kourkoutas_beta']:
+ beta2 = self.kourkoutas_helper.get_beta2(p, group, current_step)
+
  if self.use_AdEMAMix:
  beta3_ema = group['beta3_ema']
  alpha = group['alpha']
@@ -295,7 +345,7 @@ class Prodigy_adv(torch.optim.Optimizer):
  del mask

  vt = _unnmf((state['mu_v_nmf'], state['mv_v_nmf']))
- vt.mul_(self.beta2).addcmul_(grad_reshaped, grad_reshaped, value=self.d * self.d * (1.0 - self.beta2))
+ vt.mul_(beta2).addcmul_(grad_reshaped, grad_reshaped, value=self.d * self.d * (1.0 - beta2))

  if self.use_AdEMAMix:
  mt_slow = _unnmf((state['mu_m_slow_nmf'], state['mv_m_slow_nmf']))
@@ -368,7 +418,7 @@ class Prodigy_adv(torch.optim.Optimizer):
  else:
  update = exp_avg.clone() if self.beta1 > 0 else grad.mul(self.d)

- exp_avg_sq.mul_(self.beta2).addcmul_(grad, grad.conj(), value=self.d * self.d * (1.0 - self.beta2))
+ exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=self.d * self.d * (1.0 - beta2))

  if group['use_atan2']:
  a = 1.2732395
@@ -431,6 +481,15 @@ class Prodigy_adv(torch.optim.Optimizer):
  for i, p in enumerate(group['params']):
  self.step_parameter(p, group, i)

+ if self.kourkoutas_beta and self.k_logging > 0 and hasattr(self, '_beta2_log'):
+ first_param_state = self.state[self.param_groups[0]['params'][0]]
+ step_num = first_param_state['step']
+
+ if step_num > 0 and step_num % self.k_logging == 0:
+ if self._beta2_log:
+ beta2_tensor = torch.tensor(self._beta2_log, device='cpu')
+ print(f"Step {step_num}: Kourkoutas beta2 stats: Min={beta2_tensor.min():.4f}, Max={beta2_tensor.max():.4f}, Mean={beta2_tensor.mean():.4f}")
+ delattr(self, '_beta2_log')

  self.calculate_d()
  self.init_step()
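
For readability, the factored second-moment update changed in the hunks above (the `vt.mul_(beta2).addcmul_(...)` call) can be written out with Prodigy's step-size estimate d and the per-layer β₂ made explicit; this is only a restatement of the existing line, not new math:

    v_t = \beta_2^{(\ell)} \, v_{t-1} + d^2 \left(1 - \beta_2^{(\ell)}\right)\, g_t \odot g_t

where \beta_2^{(\ell)} is the layer's dynamic β₂ from `KourkoutasHelper.get_beta2` when `kourkoutas_beta` is enabled, and simply `betas[1]` otherwise.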

adv_optm/optim/Simplified_AdEMAMix.py
@@ -1,4 +1,5 @@
  import torch
+ from typing import Optional, Callable

  import math

@@ -7,6 +8,7 @@ from ..util.Effective_Shape import _get_effective_shape
  from ..util.NNMF import _nnmf,_unnmf
  from ..util.OrthoGrad import _orthogonalize_gradient
  from ..util.One_Bit_Boolean import _pack_bools, _unpack_bools
+ from ..util.Kourkoutas import KourkoutasHelper

  # A little helper from the original simplified_AdEMAMix
  def linear_hl_warmup_scheduler(step, beta_end, beta_start=0, warmup=1):
@@ -47,6 +49,28 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
  stochastic_rounding (bool): whether to use stochastic
  rounding for BF16 parameter updates (default: True).
  orthogonal_gradient (bool): whether to use OrthoGrad. (default: False)
+ kourkoutas_beta (bool): whether to enable the layer-wise dynamic β₂ logic.
+ If `False`, the optimizer behaves as standard Simplified_AdEMAMix. (default: False)
+ beta2_min (float): The minimum value for dynamic β₂, used during periods of
+ high gradient variance ("sunspikes"). Must be less than `betas[1]`.
+ (default: 0.88)
+ ema_alpha (float): The decay rate for the Exponential Moving Average (EMA) of
+ the pooled gradient norms. Corresponds to `α` in the paper.
+ (default: 0.93)
+ tiny_spike (float): A small constant added to the denominator of the
+ "sunspike" ratio calculation to prevent division by zero. Corresponds
+ to `ε_spike` in the paper. (default: 1e-9)
+ k_warmup_steps (int): The number of initial steps during which β₂ is held
+ at a fixed average value (`(beta2_min + beta2_max) / 2`) before the
+ dynamic logic activates. (default: 0)
+ k_logging (int): if > 0 and kourkoutas_beta=True, enables periodic console
+ logging of Kourkoutas-β statistics (min, max, mean of `β₂` across layers)
+ every logging steps. Useful for debugging and tuning. Set to 0 to disable
+ logging (default: 0).
+ layer_key_fn (Optional[Callable]): A function that takes a parameter `p`
+ and returns a unique, hashable key representing its "layer" or "bucket".
+ If `None`, parameters are bucketed by their memory ID (tensor-wise).
+ (default: None)
  nnmf_factor (bool): whether to use the factorization or disable it to use
  the uncompressed optimizer. (default: False)
  """
@@ -65,6 +89,13 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
  vector_reshape: bool = True,
  stochastic_rounding: bool = True,
  orthogonal_gradient: bool = False,
+ kourkoutas_beta: bool = False,
+ beta2_min: float = 0.88,
+ ema_alpha: float = 0.93,
+ tiny_spike: float = 1e-9,
+ k_warmup_steps: int = 0,
+ k_logging: int = 0,
+ layer_key_fn: Optional[Callable] = None,
  nnmf_factor: bool = False,
  ):
  if not (lr >= 0.0):
@@ -77,17 +108,26 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
  raise ValueError(f"Weight-decay should be >= 0.0. Got {weight_decay}")
  if not 0.0 <= alpha_grad:
  raise ValueError("Invalid alpha value: {}".format(alpha_grad))
+ if kourkoutas_beta and not (betas[1] > beta2_min): raise ValueError(f"For Kourkoutas-β, betas[1] (as beta2_max) must be > beta2_min. Got {betas[1]} and {beta2_min}")

  defaults = {
  "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
  "alpha_grad": alpha_grad, "beta1_warmup": beta1_warmup, "min_beta1": min_beta1,
  "vector_reshape": vector_reshape,
  "orthogonal_gradient": orthogonal_gradient, "use_bias_correction": use_bias_correction,
+ "kourkoutas_beta": kourkoutas_beta, "beta2_min": beta2_min, "ema_alpha": ema_alpha,
+ "tiny_spike": tiny_spike, "k_warmup_steps": k_warmup_steps,
  }
  self.stochastic_rounding = stochastic_rounding
  self.factored = nnmf_factor
  super().__init__(params, defaults)

+ self.kourkoutas_beta = kourkoutas_beta
+ self.k_logging= k_logging and kourkoutas_beta
+ self.layer_key_fn = layer_key_fn and kourkoutas_beta
+ if self.kourkoutas_beta:
+ self.kourkoutas_helper = KourkoutasHelper(self)
+
  @property
  def supports_fused_back_pass(self):
  return True
@@ -149,9 +189,17 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
  state['num_sum'] = 1.0
  state['den_sum'] = 1.0

+ current_step = state['step']
+ if group['kourkoutas_beta']:
+ self.kourkoutas_helper.maybe_prepare_step(current_step)
+ self.kourkoutas_helper.accumulate_gradient_sq_norm(p, grad)
+
  beta1_final, beta2 = group["betas"]
  beta1_warmup = group["beta1_warmup"]
  alpha_grad = group["alpha_grad"]
+
+ if group['kourkoutas_beta']:
+ beta2 = self.kourkoutas_helper.get_beta2(p, group, current_step)

  if beta1_warmup is not None:
  step = state['step'] + 1
@@ -161,7 +209,10 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):

  if group['use_bias_correction']:
  state['num_sum'] = beta1 * state['num_sum'] + 1.0
- state['den_sum'] = beta2 * state['den_sum'] + (1.0 - beta2)
+ if group['kourkoutas_beta']:
+ state['den_sum'] = group['betas'][1] * state['den_sum'] + (1.0 - group['betas'][1])
+ else:
+ state['den_sum'] = beta2 * state['den_sum'] + (1.0 - beta2)

  if state['factored']:
  d1, d2 = state['effective_shape']
@@ -243,4 +294,14 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
  for i, p in enumerate(group['params']):
  self.step_parameter(p, group, i)

+ if self.kourkoutas_beta and self.k_logging > 0 and hasattr(self, '_beta2_log'):
+ first_param_state = self.state[self.param_groups[0]['params'][0]]
+ step_num = first_param_state['step']
+
+ if step_num > 0 and step_num % self.k_logging == 0:
+ if self._beta2_log:
+ beta2_tensor = torch.tensor(self._beta2_log, device='cpu')
+ print(f"Step {step_num}: Kourkoutas beta2 stats: Min={beta2_tensor.min():.4f}, Max={beta2_tensor.max():.4f}, Mean={beta2_tensor.mean():.4f}")
+ delattr(self, '_beta2_log')
+
  return loss

adv_optm/util/Kourkoutas.py (new file)
@@ -0,0 +1,108 @@
+ import torch
+ from torch.optim import Optimizer
+ from typing import Callable
+
+ class KourkoutasHelper:
+ """
+ A helper class to add layer-wise Kourkoutas-β functionality to a PyTorch optimizer.
+ """
+ def __init__(self, optimizer: Optimizer):
+ # We need a reference to the optimizer to access its param_groups and state
+ if not hasattr(optimizer, 'param_groups'):
+ raise TypeError("optimizer must be a valid torch.optim.Optimizer instance.")
+ self.optimizer = optimizer
+
+ # State managed by the helper
+ self.layer_state = {}
+ self.layer_info = {}
+ self._layer_info_built = False
+ self._current_step_prepared = -1
+
+ def _build_layer_info_if_needed(self):
+ """Builds a map of layers and the parameters they contain."""
+ if self._layer_info_built:
+ return
+
+ if not hasattr(self.optimizer, 'layer_key_fn') or self.optimizer.layer_key_fn is None:
+ print("Warning: KourkoutasHelper requires 'layer_key_fn' on the optimizer. Defaulting to tensor-wise (id).")
+ self.optimizer.layer_key_fn = lambda p: id(p)
+
+ for group in self.optimizer.param_groups:
+ if not group.get('kourkoutas_beta', False):
+ continue
+ for p in group['params']:
+ if p.grad is None: continue
+ layer_key = self.optimizer.layer_key_fn(p)
+ if layer_key not in self.layer_info:
+ self.layer_info[layer_key] = {'params': [], 'group_ref': group}
+ self.layer_info[layer_key]['params'].append(p)
+ self._layer_info_built = True
+
+ def prepare_step(self):
+ """
+ Calculates dynamic beta2 for all layers using the completed scalar accumulators
+ from the PREVIOUS step. Should be called once at the start of an optimizer step.
+ """
+ self._build_layer_info_if_needed()
+
+ if hasattr(self.optimizer, 'logging') and self.optimizer.logging:
+ if not hasattr(self.optimizer, '_beta2_log'):
+ self.optimizer._beta2_log = []
+
+ for layer_key, info in self.layer_info.items():
+ params, group = info['params'], info['group_ref']
+
+ if layer_key not in self.layer_state:
+ self.layer_state[layer_key] = {
+ 'r_ema_grad_norm': torch.tensor(0.0, device=params[0].device, dtype=torch.float32),
+ 'sum_sq_accumulator': torch.tensor(0.0, device=params[0].device, dtype=torch.float32)
+ }
+
+ layer_state = self.layer_state[layer_key]
+
+ pooled_grad_norm = torch.sqrt(layer_state['sum_sq_accumulator'])
+
+ r_ema = layer_state['r_ema_grad_norm']
+ r_ema.mul_(group['ema_alpha']).add_(pooled_grad_norm, alpha=1.0 - group['ema_alpha'])
+
+ raw = pooled_grad_norm / (r_ema + group['tiny_spike'])
+ sun = raw / (1.0 + raw)
+ beta2_max = group['betas'][1]
+ beta2 = beta2_max - (beta2_max - group['beta2_min']) * sun
+
+ layer_state['dynamic_beta2'] = beta2.item()
+ layer_state['sum_sq_accumulator'].zero_()
+
+ if hasattr(self.optimizer, 'logging') and self.optimizer.logging and hasattr(self.optimizer, '_beta2_log'):
+ self.optimizer._beta2_log.append(beta2.item())
+
+ def maybe_prepare_step(self, current_step: int):
+ """
+ A universal guard that calls prepare_step() exactly once per training step.
+ """
+ if self._current_step_prepared < current_step:
+ self.prepare_step()
+ self._current_step_prepared = current_step
+
+ def accumulate_gradient_sq_norm(self, p: torch.Tensor, grad: torch.Tensor):
+ """
+ Accumulates the squared L2 norm of a single gradient for the next step's calculation.
+ """
+ layer_key = self.optimizer.layer_key_fn(p)
+ if layer_key not in self.layer_state:
+ self.layer_state[layer_key] = {
+ 'r_ema_grad_norm': torch.tensor(0.0, device=p.device, dtype=torch.float32),
+ 'sum_sq_accumulator': torch.tensor(0.0, device=p.device, dtype=torch.float32)
+ }
+ self.layer_state[layer_key]['sum_sq_accumulator'] += torch.sum(grad.detach().pow(2)).float()
+
+ def get_beta2(self, p: torch.Tensor, group: dict, current_step: int) -> float:
+ """
+ Gets the appropriate beta2 for the current parameter, handling warmup and dynamic value fetching.
+ """
+ beta2_default = group['betas'][1]
+ if current_step < group['k_warmup_steps']:
+ return 0.5 * (group['beta2_min'] + beta2_default)
+
+ layer_key = self.optimizer.layer_key_fn(p)
+ return self.layer_state.get(layer_key, {}).get('dynamic_beta2', beta2_default)
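
The core of `prepare_step` above is the "sunspike" mapping from a layer's pooled gradient norm to a β₂ in [beta2_min, beta2_max]. A plain-Python restatement with made-up numbers (nothing here comes from the package; it only mirrors the formula above):

    beta2_max, beta2_min = 0.999, 0.88
    ema_alpha, tiny_spike = 0.93, 1e-9

    r_ema = 1.0  # pretend the EMA of pooled norms has settled around 1.0
    for pooled_norm in (1.0, 1.0, 5.0):  # the third value simulates a gradient spike
        r_ema = ema_alpha * r_ema + (1.0 - ema_alpha) * pooled_norm
        raw = pooled_norm / (r_ema + tiny_spike)
        sun = raw / (1.0 + raw)  # "sunspike" ratio, squashed into [0, 1)
        beta2 = beta2_max - (beta2_max - beta2_min) * sun
        print(f"norm={pooled_norm:.1f}  sun={sun:.3f}  beta2={beta2:.4f}")

With steady norms the ratio sits near 1, so β₂ hovers around the middle of the range; a spike pushes `sun` toward 1 and β₂ toward `beta2_min`, shortening the second-moment memory exactly when gradients become volatile. Note that the real helper starts `r_ema_grad_norm` at zero, which makes the ratio large for the first few steps; that is presumably what the `k_warmup_steps` midpoint hold in `get_beta2` is meant to smooth over.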

adv_optm.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: adv_optm
- Version: 1.0.6
+ Version: 1.1.0.dev1
  Summary: A family of highly efficient, lightweight yet powerful optimizers.
  Home-page: https://github.com/Koratahiu/Advanced_Optimizers
  Author: Koratahiu

adv_optm.egg-info/SOURCES.txt
@@ -16,6 +16,7 @@ adv_optm/optim/Simplified_AdEMAMix.py
  adv_optm/optim/__init__.py
  adv_optm/util/BF16_Stochastic_Rounding.py
  adv_optm/util/Effective_Shape.py
+ adv_optm/util/Kourkoutas.py
  adv_optm/util/NNMF.py
  adv_optm/util/One_Bit_Boolean.py
  adv_optm/util/OrthoGrad.py

setup.py
@@ -5,7 +5,7 @@ with open("README.md", "r", encoding="utf-8") as fh:

  setup(
  name="adv_optm",
- version="1.0.6",
+ version="1.1.0.dev1",
  author="Koratahiu",
  author_email="hiuhonor@gmail.com",
  license='Apache 2.0',