adv-optm 1.2.dev14__py3-none-any.whl → 2.dev1__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their respective public registries, and is provided for informational purposes only.

This version of adv-optm has been flagged as a potentially problematic release.

@@ -66,7 +66,7 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
66
66
  k_logging (int): if > 0 and kourkoutas_beta=True, enables periodic console
67
67
  logging of Kourkoutas-β statistics (min, max, mean of `β₂` across layers)
68
68
  every logging steps. Useful for debugging and tuning. Set to 0 to disable
69
- logging (default: 0).
69
+ logging (default: 0).
70
70
  layer_key_fn (Optional[Callable]): A function that takes a parameter `p`
71
71
  and returns a unique, hashable key representing its "layer" or "bucket".
72
72
  If `None`, parameters are bucketed by their memory ID (tensor-wise).
@@ -86,7 +86,7 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
86
86
  beta1_warmup: int | None = None,
87
87
  min_beta1: float | None = 0.9,
88
88
  use_bias_correction: bool = True,
89
- vector_reshape: bool = True,
89
+ vector_reshape: bool = False,
90
90
  stochastic_rounding: bool = True,
91
91
  orthogonal_gradient: bool = False,
92
92
  kourkoutas_beta: bool = False,
@@ -97,6 +97,8 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
97
97
  k_logging: int = 0,
98
98
  layer_key_fn: Optional[Callable] = None,
99
99
  nnmf_factor: bool = False,
100
+ # Compiled
101
+ compiled_optimizer: bool = False,
100
102
  ):
101
103
  if not (lr >= 0.0):
102
104
  raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
@@ -108,7 +110,8 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
108
110
  raise ValueError(f"Weight-decay should be >= 0.0. Got {weight_decay}")
109
111
  if not 0.0 <= alpha_grad:
110
112
  raise ValueError("Invalid alpha value: {}".format(alpha_grad))
111
- if kourkoutas_beta and not (betas[1] > beta2_min): raise ValueError(f"For Kourkoutas-β, betas[1] (as beta2_max) must be > beta2_min. Got {betas[1]} and {beta2_min}")
113
+ if kourkoutas_beta and not (betas[1] > beta2_min):
114
+ raise ValueError(f"For Kourkoutas-β, betas[1] (as beta2_max) must be > beta2_min. Got {betas[1]} and {beta2_min}")
112
115
 
113
116
  defaults = {
114
117
  "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
@@ -117,16 +120,33 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
117
120
  "orthogonal_gradient": orthogonal_gradient, "use_bias_correction": use_bias_correction,
118
121
  "kourkoutas_beta": kourkoutas_beta, "beta2_min": beta2_min, "ema_alpha": ema_alpha,
119
122
  "tiny_spike": tiny_spike, "k_warmup_steps": k_warmup_steps, "k_logging": k_logging,
123
+ "compiled_optimizer": compiled_optimizer,
120
124
  }
121
125
  self.stochastic_rounding = stochastic_rounding
122
126
  self.factored = nnmf_factor
123
127
  self.kourkoutas_beta = kourkoutas_beta
124
128
  self.layer_key_fn = layer_key_fn
129
+ self.use_bias_correction = use_bias_correction
130
+ if use_bias_correction:
131
+ self.num_sum = betas[0] * 1.0
132
+ self.den_sum = betas[1] * (1.0 - betas[1])
133
+ else:
134
+ self.num_sum = 1.0
135
+ self.den_sum = 1.0
136
+
125
137
  super().__init__(params, defaults)
126
138
 
139
+ self.init_step()
140
+
127
141
  if self.kourkoutas_beta:
128
142
  self.kourkoutas_helper = KourkoutasHelper(self)
129
143
 
144
+ self.global_step = 0
145
+
146
+ if compiled_optimizer:
147
+ torch._dynamo.config.cache_size_limit = 8192
148
+ self.compile(fullgraph=True)
149
+
130
150
  @property
131
151
  def supports_fused_back_pass(self):
132
152
  return True
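
The new `compiled_optimizer` flag wraps the per-parameter update with `torch.compile` and raises Dynamo's cache size limit. Below is a minimal sketch of that general pattern, not the optimizer's actual step: scalars that change every step (such as the learning rate) are passed as 0-dim tensors so their values are not baked into the compiled graph.

```python
import torch

def _toy_update(p: torch.Tensor, grad: torch.Tensor, lr: torch.Tensor) -> None:
    # Toy update p <- p - lr * grad; `lr` is a 0-dim tensor, so changing its
    # value does not force a recompile of the captured graph.
    p.sub_(lr * grad)

compiled_update = torch.compile(_toy_update, fullgraph=True)

p, g = torch.randn(4, 4), torch.randn(4, 4)
lr = torch.tensor(1e-3)
compiled_update(p, g, lr)   # compiles on the first call
lr.fill_(5e-4)              # new learning rate, same compiled graph
compiled_update(p, g, lr)
```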
@@ -139,29 +159,22 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
139
159
  def supports_flat_params(self):
140
160
  return False
141
161
 
142
- @torch.no_grad()
143
- def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
144
- if p.grad is None:
145
- return
162
+ def init_step(self):
163
+ for group in self.param_groups:
164
+ for p in group['params']:
165
+ self.__init_state(p, group)
146
166
 
147
- grad = p.grad
148
- if grad.dtype != torch.float32 and self.factored:
149
- grad = grad.float()
150
- if group["orthogonal_gradient"]:
151
- grad = _orthogonalize_gradient(p, grad)
167
+ @torch.no_grad()
168
+ def __init_state(self, p, group):
152
169
  state = self.state[p]
153
170
 
154
- # State Initialization
155
- if 'step' not in state:
156
- state['step'] = 0
171
+ if len(state) == 0:
157
172
 
158
- should_factor = (
173
+ state['factored'] = (
159
174
  self.factored and
160
175
  not (len(p.shape) == 1 and not group['vector_reshape'])
161
176
  )
162
177
 
163
- state['factored'] = should_factor
164
-
165
178
  dtype = torch.float32 if self.factored else p.dtype
166
179
  device = p.device
167
180
 
@@ -170,50 +183,42 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
170
183
  d1, d2 = state['effective_shape']
171
184
 
172
185
  # First moment (m)
173
- state['mu_m_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
186
+ state['mu_m_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
174
187
  state['mv_m_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
175
188
  packed_d2 = (d2 + 7) // 8
176
189
  state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=device)
177
190
  # Second moment (v)
178
- state['mu_v_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
191
+ state['mu_v_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
179
192
  state['mv_v_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
180
193
  else: # Fallback to standard optimizer for non-factored tensors
181
194
  state['exp_avg'] = torch.zeros_like(p, device=device, dtype=dtype)
182
195
  state['exp_avg_sq'] = torch.zeros_like(p, device=device, dtype=dtype)
183
-
184
- if group['use_bias_correction']:
185
- state['num_sum'] = 0.0
186
- state['den_sum'] = 0.0
187
- else:
188
- state['num_sum'] = 1.0
189
- state['den_sum'] = 1.0
190
196
 
191
- beta1_final, beta2 = group["betas"]
192
197
 
193
- current_step = state['step']
198
+
199
+ @torch.no_grad()
200
+ def __step_parameter(self, p: torch.Tensor, group: dict, lr: torch.Tensor | float, beta1, num_sum, den_sum):
201
+ if p.grad is None:
202
+ return
203
+
204
+ grad = p.grad
205
+ if grad.dtype != torch.float32 and self.factored:
206
+ grad = grad.float()
207
+ if group["orthogonal_gradient"]:
208
+ grad = _orthogonalize_gradient(p, grad)
209
+ state = self.state[p]
210
+
211
+
212
+ ___, beta2 = group["betas"]
213
+
194
214
  if group.get('kourkoutas_beta', False):
195
- # Call prepare_step() once at the beginning of the step for all params
196
- self.kourkoutas_helper.maybe_prepare_step(current_step)
197
215
  # Accumulate current grad's norm for the *next* step
198
216
  self.kourkoutas_helper.accumulate_gradient_sq_norm(p, grad)
199
217
  # Get the dynamic beta2 calculated in prepare_step()
200
- beta2 = self.kourkoutas_helper.get_beta2(p, group, current_step)
218
+ beta2 = self.kourkoutas_helper.get_beta2(p, group)
201
219
 
202
- beta1_warmup = group["beta1_warmup"]
203
220
  alpha_grad = group["alpha_grad"]
204
221
 
205
- if beta1_warmup is not None:
206
- step = state['step'] + 1
207
- beta1 = linear_hl_warmup_scheduler(step, beta_end=beta1_final, beta_start=group['min_beta1'], warmup=beta1_warmup)
208
- else:
209
- beta1 = beta1_final
210
-
211
- if group['use_bias_correction']:
212
- state['num_sum'] = beta1 * state['num_sum'] + 1.0
213
- if group.get('kourkoutas_beta', False):
214
- state['den_sum'] = group['betas'][1] * state['den_sum'] + (1.0 - group['betas'][1])
215
- else:
216
- state['den_sum'] = beta2 * state['den_sum'] + (1.0 - beta2)
217
222
 
218
223
  if state['factored']:
219
224
  d1, d2 = state['effective_shape']
@@ -233,12 +238,12 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
233
238
  update = torch.add(mt, grad_reshaped, alpha=alpha_grad)
234
239
  del grad_reshaped
235
240
 
236
- denom = vt.sqrt().add_(group['eps'] * math.sqrt(state['den_sum']))
241
+ denom = vt.sqrt().add_(group['eps'] * math.sqrt(den_sum))
237
242
  update.div_(denom)
238
243
  del denom
239
244
 
240
245
  if group['use_bias_correction']:
241
- update = (update / state['num_sum']) * math.sqrt(state['den_sum'])
246
+ update = (update / num_sum) * math.sqrt(den_sum)
242
247
 
243
248
  update = update.view(p.shape).mul_(group['lr'])
244
249
 
@@ -259,12 +264,12 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
259
264
 
260
265
  exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
261
266
 
262
- denom = exp_avg_sq.sqrt().add_(group['eps'] * math.sqrt(state['den_sum']))
267
+ denom = exp_avg_sq.sqrt().add_(group['eps'] * math.sqrt(den_sum))
263
268
  update.div_(denom)
264
269
  del denom
265
270
 
266
271
  if group['use_bias_correction']:
267
- update = (update / state['num_sum']) * math.sqrt(state['den_sum'])
272
+ update = (update / num_sum) * math.sqrt(den_sum)
268
273
 
269
274
  update.mul_(group['lr'])
270
275
 
@@ -281,7 +286,36 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
281
286
  p.data.add_(-update)
282
287
  del update
283
288
 
284
- state['step'] += 1
289
+
290
+ @torch.no_grad()
291
+ def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
292
+ if self.global_step is None and 'step' in self.state[p]:
293
+ # For backward compatibility
294
+ g_state = self.state[p]
295
+ self.global_step = g_state['step']
296
+ self.num_sum = group["betas"][0] * g_state['num_sum'] + 1.0
297
+ self.den_sum = group['betas'][1] * g_state['den_sum'] + (1.0 - group['betas'][1])
298
+
299
+ if group["beta1_warmup"] is not None:
300
+ step = self.global_step + 1
301
+ beta1 = linear_hl_warmup_scheduler(step, beta_end=group["betas"][0], beta_start=group['min_beta1'], warmup=group["beta1_warmup"])
302
+ else:
303
+ beta1 = group["betas"][0]
304
+
305
+ if group.get('kourkoutas_beta', False):
306
+ # Prepare Kourkoutas-β once per step using the global step counter.
307
+ self.kourkoutas_helper.maybe_prepare_step(self.global_step)
308
+
309
+ if not group.get('compiled_optimizer', False):
310
+ self.__step_parameter(p, group, group['lr'], beta1, self.num_sum, self.den_sum)
311
+ else:
312
+ lr_tensor = torch.tensor(group['lr'], device=p.device)
313
+ num_sum_tesnor = torch.tensor(self.num_sum, device=p.device)
314
+ den_sum_tesnor = torch.tensor(self.den_sum, device=p.device)
315
+ self._compiled_step_parameter(p, group, lr_tensor, beta1, self.num_sum, self.den_sum)
316
+
317
+ def compile(self, *args, **kwargs):
318
+ self._compiled_step_parameter = torch.compile(self.__step_parameter, *args, **kwargs)
285
319
 
286
320
  @torch.no_grad()
287
321
  def step(self, closure=None):
@@ -294,5 +328,12 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
294
328
  for group in self.param_groups:
295
329
  for i, p in enumerate(group['params']):
296
330
  self.step_parameter(p, group, i)
331
+
332
+ g_group = self.param_groups[0]
333
+ if g_group['use_bias_correction']:
334
+ self.num_sum = g_group["betas"][0] * self.num_sum + 1.0
335
+ self.den_sum = g_group['betas'][1] * self.den_sum + (1.0 - g_group['betas'][1])
336
+
337
+ self.global_step += 1
297
338
 
298
- return loss
339
+ return loss
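
The per-parameter `num_sum`/`den_sum` state is replaced here by two optimizer-level accumulators that are advanced once per `step()`. For reference, a small sketch of the recurrences and their closed forms, assuming both sums start at zero (this release seeds them with non-zero values in `__init__`, so the first steps differ slightly):

```python
import math

def bias_correction_sums(beta1: float, beta2: float, steps: int):
    # num_sum_t = beta1 * num_sum_{t-1} + 1          -> (1 - beta1**t) / (1 - beta1)
    # den_sum_t = beta2 * den_sum_{t-1} + (1 - beta2) ->  1 - beta2**t
    num_sum, den_sum = 0.0, 0.0
    for _ in range(steps):
        num_sum = beta1 * num_sum + 1.0
        den_sum = beta2 * den_sum + (1.0 - beta2)
    return num_sum, den_sum

num_sum, den_sum = bias_correction_sums(0.99, 0.999, 100)
assert math.isclose(num_sum, (1 - 0.99**100) / (1 - 0.99))
assert math.isclose(den_sum, 1 - 0.999**100)
```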
@@ -16,4 +16,4 @@ __all__ = [
16
16
  "Lion_Prodigy_adv",
17
17
  "Muon_adv",
18
18
  "AdaMuon_adv",
19
- ]
19
+ ]
@@ -44,4 +44,4 @@ def add_stochastic_(input: Tensor, other: Tensor, alpha: float = 1.0):
44
44
  result = other.clone() if other.dtype == torch.float32 else other.to(dtype=torch.float32)
45
45
 
46
46
  result.add_(input, alpha=alpha)
47
- copy_stochastic_(input, result)
47
+ copy_stochastic_(input, result)
@@ -5,4 +5,4 @@ def _get_effective_shape(numel: int) -> tuple[int, int]:
5
5
  for i in reversed(range(1, int(numel ** 0.5) + 1)):
6
6
  if numel % i == 0:
7
7
  return (numel // i, i)
8
- return (numel, 1)
8
+ return (numel, 1)
@@ -1,6 +1,5 @@
1
1
  import torch
2
2
  from torch.optim import Optimizer
3
- from typing import Callable
4
3
 
5
4
  class KourkoutasHelper:
6
5
  """
@@ -58,7 +57,7 @@ class KourkoutasHelper:
58
57
  Calculates dynamic beta2 for all layers using the completed scalar accumulators
59
58
  from the PREVIOUS step. Should be called once at the start of an optimizer step.
60
59
  """
61
-
60
+
62
61
  beta2_log = []
63
62
  # These are just for the sample log, initialize them
64
63
  sun, pooled_grad_norm, r_ema_tensor = (torch.tensor(0.0),)*3
@@ -69,7 +68,7 @@ class KourkoutasHelper:
69
68
  master_defaults = self.optimizer.defaults
70
69
 
71
70
  for layer_key, info in self.layer_info.items():
72
- params, group = info['params'], info['group_ref']
71
+ group = info['group_ref']
73
72
 
74
73
  if not group.get('kourkoutas_beta', False) and not group.get('adam_kourkoutas_beta', False):
75
74
  continue
@@ -81,7 +80,7 @@ class KourkoutasHelper:
81
80
  self.layer_state[layer_key] = {
82
81
  'sum_sq_accumulator': torch.tensor(0.0, device=first_param_in_layer.device, dtype=torch.float32)
83
82
  }
84
-
83
+
85
84
  if 'kourkoutas_r_ema' not in param_state:
86
85
  param_state['kourkoutas_r_ema'] = torch.tensor(0.0, device=first_param_in_layer.device, dtype=torch.float32)
87
86
 
@@ -96,14 +95,14 @@ class KourkoutasHelper:
96
95
 
97
96
  r_ema_tensor = param_state['kourkoutas_r_ema']
98
97
  accumulator = self.layer_state[layer_key]['sum_sq_accumulator']
99
-
98
+
100
99
  pooled_grad_norm = torch.sqrt(accumulator)
101
-
100
+
102
101
  # Update the persistent EMA tensor in-place.
103
102
  r_ema_tensor.mul_(ema_alpha).add_(pooled_grad_norm, alpha=1.0 - ema_alpha)
104
-
103
+
105
104
  sun = torch.tensor(0.0, device=r_ema_tensor.device) # Default sun to 0 for warmup
106
-
105
+
107
106
  if current_step < k_warmup_steps:
108
107
  beta2 = beta2_max
109
108
  else:
@@ -113,7 +112,7 @@ class KourkoutasHelper:
113
112
 
114
113
  # Store the final calculated beta2 in the helper's transient state for this step.
115
114
  self.layer_state[layer_key]['dynamic_beta2'] = beta2.item() if isinstance(beta2, torch.Tensor) else beta2
116
-
115
+
117
116
  # Reset the accumulator for the next optimizer step.
118
117
  accumulator.zero_()
119
118
 
@@ -149,10 +148,10 @@ class KourkoutasHelper:
149
148
  # Accumulate for the *next* step's prepare_step call
150
149
  self.layer_state[layer_key]['sum_sq_accumulator'] += torch.sum(grad.detach().pow(2)).float()
151
150
 
152
- def get_beta2(self, p: torch.Tensor, group: dict, current_step: int) -> float:
151
+ def get_beta2(self, p: torch.Tensor, group: dict) -> float:
153
152
  """
154
153
  Gets the appropriate beta2 for the current parameter, handling warmup and dynamic value fetching.
155
154
  """
156
155
  layer_key = self.optimizer.layer_key_fn(p)
157
156
  # The default is the max value, which is correct for unmapped params or edge cases
158
- return self.layer_state.get(layer_key, {}).get('dynamic_beta2', group['betas'][1])
157
+ return self.layer_state.get(layer_key, {}).get('dynamic_beta2', group['betas'][1])
adv_optm/util/NNMF.py CHANGED
@@ -9,10 +9,15 @@ def _nnmf(matrix: torch.Tensor, out: tuple):
9
9
  shape = matrix.shape
10
10
  torch.sum(matrix, dim=1, out=out[0])
11
11
  torch.sum(matrix, dim=0, out=out[1])
12
+
13
+ # Add a small epsilon for numerical stability and to remove
14
+ # data-dependent branching, making it compatible with torch.dynamo.
15
+ epsilon = 1e-12
16
+
12
17
  # Normalize one of the factors for stability
13
18
  if shape[0] < shape[1]:
14
19
  scale = out[0].sum()
15
- if scale != 0: out[0].div_(scale)
20
+ out[0].div_(scale + epsilon)
16
21
  else:
17
22
  scale = out[1].sum()
18
- if scale != 0: out[1].div_(scale)
23
+ out[1].div_(scale + epsilon)
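
For orientation, `_nnmf` performs a rank-1 non-negative factorization from row and column sums; the change above replaces the data-dependent `if scale != 0` guard with an unconditional epsilon so the function traces cleanly. A standalone sketch of the same idea (illustrative, using fresh tensors instead of the preallocated `out` buffers):

```python
import torch

def rank1_nnmf(matrix: torch.Tensor, eps: float = 1e-12):
    """Factor a non-negative matrix into two vectors whose outer product approximates it."""
    row = matrix.sum(dim=1)     # shape (d1,)
    col = matrix.sum(dim=0)     # shape (d2,)
    # Normalize the factor on the shorter side; the unconditional epsilon keeps
    # the division branch-free, which torch.compile traces without extra guards.
    if matrix.shape[0] < matrix.shape[1]:
        row = row / (row.sum() + eps)
    else:
        col = col / (col.sum() + eps)
    return row, col

m = torch.rand(8, 5)
u, v = rank1_nnmf(m)
approx = torch.outer(u, v)      # rank-1 reconstruction of `m`
```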
@@ -21,7 +21,6 @@ def _newton_schulz_iteration(
21
21
  Returns:
22
22
  torch.Tensor: The orthogonalized matrix.
23
23
  """
24
- assert G.ndim == 2, "Newton-Schulz iteration only supports 2D matrices."
25
24
 
26
25
  a, b, c = coeffs
27
26
 
@@ -45,4 +44,4 @@ def _newton_schulz_iteration(
45
44
  if transposed:
46
45
  X = X.T
47
46
 
48
- return X.to(G.dtype)
47
+ return X.to(G.dtype)
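
Only the tail of `_newton_schulz_iteration` appears in this hunk. For readers unfamiliar with it, here is a hedged, self-contained sketch of the usual quintic Newton-Schulz orthogonalization (coefficients taken from the public Muon reference implementation; the constants and edge-case handling in this package may differ):

```python
import torch

def newton_schulz_orthogonalize(G: torch.Tensor, steps: int = 5,
                                coeffs: tuple = (3.4445, -4.7750, 2.0315)) -> torch.Tensor:
    """Approximately orthogonalize a 2D matrix via quintic Newton-Schulz iterations."""
    a, b, c = coeffs
    X = G.float()
    transposed = X.shape[0] > X.shape[1]
    if transposed:                      # iterate on the wide orientation
        X = X.T
    X = X / (X.norm() + 1e-7)           # bound the spectral norm by 1
    for _ in range(steps):
        A = X @ X.T
        X = a * X + (b * A + c * A @ A) @ X
    if transposed:
        X = X.T
    return X.to(G.dtype)

Q = newton_schulz_orthogonalize(torch.randn(64, 32))   # approximately semi-orthogonal
```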
@@ -19,4 +19,4 @@ def _unpack_bools(packed_tensor: torch.Tensor, original_m: int) -> torch.Tensor:
19
19
  shifter = (2**torch.arange(8, device=packed_tensor.device, dtype=torch.uint8)).view(1, 1, 8)
20
20
  unpacked_padded = (packed_tensor.unsqueeze(2) & shifter) != 0
21
21
  unpacked = unpacked_padded.view(packed_tensor.shape[0], -1)[:, :original_m]
22
- return unpacked
22
+ return unpacked
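
`_unpack_bools` has a natural packing counterpart. A minimal sketch of the round trip (the `pack_bools` helper below is hypothetical, written only to match the unpacking layout shown here):

```python
import torch
import torch.nn.functional as F

def pack_bools(bools: torch.Tensor) -> torch.Tensor:
    """Pack a 2D boolean tensor into uint8, eight flags per byte along dim 1."""
    n, m = bools.shape
    pad = (-m) % 8                                    # pad columns to a multiple of 8
    padded = F.pad(bools.to(torch.uint8), (0, pad)).view(n, -1, 8)
    shifter = (2 ** torch.arange(8, dtype=torch.uint8)).view(1, 1, 8)
    return (padded * shifter).sum(dim=2).to(torch.uint8)

def unpack_bools(packed: torch.Tensor, original_m: int) -> torch.Tensor:
    shifter = (2 ** torch.arange(8, device=packed.device, dtype=torch.uint8)).view(1, 1, 8)
    unpacked = (packed.unsqueeze(2) & shifter) != 0
    return unpacked.view(packed.shape[0], -1)[:, :original_m]

signs = torch.rand(3, 10) > 0.5
assert torch.equal(unpack_bools(pack_bools(signs), 10), signs)
```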
@@ -2,15 +2,16 @@ import torch
2
2
 
3
3
  def _orthogonalize_gradient(p: torch.Tensor, grad: torch.Tensor) -> torch.Tensor:
4
4
  """Projects the gradient `grad` to be orthogonal to the parameter `p`."""
5
- if grad.is_sparse: raise RuntimeError("OrthoGrad logic does not support sparse gradients.")
5
+ if grad.is_sparse:
6
+ raise RuntimeError("OrthoGrad logic does not support sparse gradients.")
6
7
  original_shape = grad.shape
7
8
  original_dtype = grad.dtype
8
9
  w = p.view(-1).float()
9
10
  g = grad.view(-1).float()
10
11
  w_norm_sq = torch.dot(w, w).add_(1e-30)
11
12
  proj = torch.dot(w, g) / w_norm_sq
12
- g_orth = g.sub(w, alpha=proj)
13
+ g_orth = g.sub(w * proj)
13
14
  g_norm = g.norm(2)
14
15
  g_orth_norm = g_orth.norm(2).add_(1e-30)
15
16
  g_orth_scaled = g_orth * (g_norm / g_orth_norm)
16
- return g_orth_scaled.view(original_shape).to(original_dtype)
17
+ return g_orth_scaled.view(original_shape).to(original_dtype)
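
The change from `g.sub(w, alpha=proj)` to `g.sub(w * proj)` likely keeps the projection coefficient as a tensor rather than a Python scalar (the `alpha=` argument is documented as a Number), which also plays better with `torch.compile`. A quick illustrative check of the projection property (not the library's test code):

```python
import torch

# Illustrative check: the orthogonalized gradient should have ~zero dot
# product with the flattened parameter.
p, g = torch.randn(10, 10), torch.randn(10, 10)
w, gf = p.view(-1).float(), g.view(-1).float()

proj = torch.dot(w, gf) / torch.dot(w, w).add(1e-30)
g_orth = gf.sub(w * proj)           # tensor-valued coefficient, no scalar `alpha=` needed
print(torch.dot(w, g_orth).item())  # ~0: the component of g along p has been removed
```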
adv_optm/util/__init__.py CHANGED
@@ -10,4 +10,4 @@ __all__ = [
10
10
  "_get_effective_shape",
11
11
  "_orthogonalize_gradient",
12
12
  "_newton_schulz_iteration",
13
- ]
13
+ ]
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: adv_optm
3
- Version: 1.2.dev14
3
+ Version: 2.dev1
4
4
  Summary: A family of highly efficient, lightweight yet powerful optimizers.
5
5
  Home-page: https://github.com/Koratahiu/Advanced_Optimizers
6
6
  Author: Koratahiu
@@ -52,7 +52,7 @@ This library integrates multiple state-of-the-art optimization techniques valida
52
52
  ### **Memory-Efficient Optimization (SMMF-inspired)**
53
53
  - **Paper**: [SMMF: Square-Matricized Momentum Factorization](https://arxiv.org/abs/2412.08894)
54
54
  - **Approach**: Uses rank-1 non-negative matrix factorization with reconstruction cycle (factor → reconstruct → update → factor)
55
- - **Innovation**:
55
+ - **Innovation**:
56
56
  - First moment split into **1-bit sign + absolute value**
57
57
  - Final storage: **four factored vectors + one 1-bit sign state**
58
58
  - Preserves Adam-like update quality with drastically reduced memory
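
A toy sketch of the sign/magnitude split described above (illustrative only; the released optimizer factorizes the magnitudes into two vectors and stores the signs with the one-bit packing shown earlier in this diff):

```python
import torch

m = torch.randn(4, 6)                       # a first-moment matrix

sign = m >= 0                               # 1-bit information, suitable for bit-packing
mag = m.abs()                               # non-negative, so rank-1 factorizable

u = mag.sum(dim=1)                          # factored vectors (see the NNMF sketch earlier)
v = mag.sum(dim=0) / (mag.sum() + 1e-12)

m_approx = torch.outer(u, v) * torch.where(sign, 1.0, -1.0)   # reconstructed moment
```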
@@ -110,7 +110,7 @@ This library integrates multiple state-of-the-art optimization techniques valida
110
110
 
111
111
  ## 🛠️ Comprehensive Feature Guide
112
112
 
113
- ### A. Universal Safe Features
113
+ ### A. Universal Safe Features
114
114
  *These features work with all optimizers and are generally safe to enable.*
115
115
 
116
116
  | Feature | Description | Recommended Usage | Performance Impact | Theoretical Basis | Compatibility |
@@ -165,7 +165,7 @@ This library integrates multiple state-of-the-art optimization techniques valida
165
165
  | `beta1` | 0.99 | Controls accumulator memory length:<br>• Small BS: **0.99–0.9999**<br>• Large BS: **0.9** |
166
166
  | `Grad α` | 100 | Most critical parameter:<br>• Inversely scales with batch size<br>• **100–10** for small BS (≤32)<br>• **1–0.1** for large BS (≥512) |
167
167
 
168
- > ⚠️ **Critical**: Requires **~100x smaller learning rate** than AdamW (e.g., 1e-6 vs 1e-4).
168
+ > ⚠️ **Critical**: Requires **~100x smaller learning rate** than AdamW (e.g., 1e-6 vs 1e-4).
169
169
  > For `Prodigy_Adv`, set `initial_d` to:
170
170
  > - **LoRA**: `1e-8`
171
171
  > - **Full FT**: `1e-10`
@@ -175,10 +175,10 @@ This library integrates multiple state-of-the-art optimization techniques valida
175
175
 
176
176
  #### Performance Validation
177
177
 
178
- **Small Batch Training (SDXL, BS=2, 1.8K steps)**
178
+ **Small Batch Training (SDXL, BS=2, 1.8K steps)**
179
179
  ![Training Comparison](https://github.com/user-attachments/assets/7eff0671-cc59-47fc-8b63-d5205456d649)
180
180
 
181
- - **🟢 Prodigy_Adv** (beta1=0.9, d0=1e-5): Final LR = 2.9e-4
181
+ - **🟢 Prodigy_Adv** (beta1=0.9, d0=1e-5): Final LR = 2.9e-4
182
182
  - **🔵 Prodigy_Adv + Simplified_AdEMAMix** (beta1=0.99, α=100, d0=1e-7): Final LR = 5.8e-6
183
183
 
184
184
  **Results**:
@@ -202,8 +202,8 @@ This library integrates multiple state-of-the-art optimization techniques valida
202
202
 
203
203
  Instead of using a fixed β₂ (e.g., 0.999 or 0.95), it **dynamically modulates β₂ per layer** based on a bounded *sunspike ratio*:
204
204
 
205
- - **During gradient bursts** → β₂ ↓ toward `Lower β₂` → faster reaction
206
- - **During calm phases** → β₂ ↑ toward `The Selected β₂` → stronger smoothing
205
+ - **During gradient bursts** → β₂ ↓ toward `Lower β₂` → faster reaction
206
+ - **During calm phases** → β₂ ↑ toward `The Selected β₂` → stronger smoothing
207
207
 
208
208
  This is especially effective for **noisy training, small batch sizes, and high learning rates**, where gradient norms shift abruptly due to noise or aggressive LR schedules.
209
209
 
@@ -220,17 +220,17 @@ This is especially effective for **noisy training, small batch sizes, and high l
220
220
 
221
221
  #### 📊 Performance Validation
222
222
 
223
- **ADAMW_ADV - full SDXL finetuning (aggressive LR: 3e-5) (BS=4, 2.5K steps)**
223
+ **ADAMW_ADV - full SDXL finetuning (aggressive LR: 3e-5) (BS=4, 2.5K steps)**
224
224
  <img width="1460" height="382" alt="image" src="https://github.com/user-attachments/assets/007f278a-fbac-4f3d-9cc7-274c3b959cdd" />
225
225
 
226
- - 🟣 Fixed `beta2=0.999`
227
- - 🟠 Auto K-beta
226
+ - 🟣 Fixed `beta2=0.999`
227
+ - 🟠 Auto K-beta
228
228
 
229
- **Observations:**
229
+ **Observations:**
230
230
  - K-beta is clearly better and more robust/stable for high LRs.
231
231
 
232
- > 📚 **Reference**:
233
- > - Paper: [Kourkoutas-β: A Sunspike-Driven Adam Optimizer with Desert Flair](https://arxiv.org/abs/2508.12996)
232
+ > 📚 **Reference**:
233
+ > - Paper: [Kourkoutas-β: A Sunspike-Driven Adam Optimizer with Desert Flair](https://arxiv.org/abs/2508.12996)
234
234
  > - Code: [kbeta](https://github.com/sck-at-ucy/kbeta)
235
235
 
236
236
  ---
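
A rough sketch of the per-layer β₂ modulation described in this section, following the paper's sunspike idea. The variable names, default values, and the exact bounding formula below are assumptions for illustration; the package's `KourkoutasHelper` differs in its bookkeeping:

```python
import torch

def dynamic_beta2(pooled_grad_norm: torch.Tensor, r_ema: torch.Tensor,
                  beta2_max: float = 0.999, beta2_min: float = 0.9,
                  ema_alpha: float = 0.95, tiny_spike: float = 1e-9):
    """Sunspike-driven beta2: gradient bursts pull beta2 toward beta2_min."""
    # Running EMA of the pooled (per-layer) gradient norm.
    r_ema = ema_alpha * r_ema + (1.0 - ema_alpha) * pooled_grad_norm
    # Bounded "sunspike" ratio in [0, 1): large when the current norm dwarfs
    # its running average, small during calm phases.
    sun = pooled_grad_norm / (pooled_grad_norm + r_ema + tiny_spike)
    beta2 = beta2_max - sun * (beta2_max - beta2_min)
    return beta2, r_ema
```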
@@ -258,7 +258,7 @@ settings:
258
258
  - factored: False # Can be true or false, quality should not degrade due to Simplified_AdEMAMix’s high tolerance to 1-bit factorization.
259
259
  ```
260
260
 
261
- > ✅ **Why it works**:
261
+ > ✅ **Why it works**:
262
262
  > - `Kourkoutas-β` handles beta2 values
263
263
  > - `Simplified_AdEMAMix` ensures responsiveness in small-batch noise
264
264
  > - `OrthoGrad` prevents overfitting without weight decay
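
Expressed as code, the combination above might look roughly like this (keyword names are taken from the `Simplified_AdEMAMix` signature visible in this diff; the import path, the placeholder model, and the exact learning rate are assumptions):

```python
import torch
from adv_optm import Simplified_AdEMAMix   # import path assumed from the package layout

model = torch.nn.Linear(128, 128)           # placeholder model

opt = Simplified_AdEMAMix(
    model.parameters(),
    lr=1e-6,                     # ~100x smaller than a typical AdamW LR, per the note above
    weight_decay=0.0,            # OrthoGrad is used instead of weight decay
    kourkoutas_beta=True,        # dynamic per-layer beta2
    orthogonal_gradient=True,    # OrthoGrad
    nnmf_factor=False,           # "factored: False" from the settings block
)
```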
@@ -267,9 +267,9 @@ settings:
267
267
 
268
268
  ## 📚 References
269
269
 
270
- 1. [Revisiting BFloat16 Training](https://arxiv.org/abs/2010.06192)
271
- 2. [SMMF: Square-Matricized Momentum Factorization](https://arxiv.org/abs/2412.08894)
272
- 3. [The AdEMAMix Optimizer](https://arxiv.org/abs/2409.03137)
273
- 4. [Connections between Schedule-Free Optimizers, AdEMAMix, and Accelerated SGD](https://arxiv.org/abs/2502.02431)
274
- 5. [AdaMeM: Memory Efficient Momentum for Adafactor](https://openreview.net/forum?id=fZqMVTz7K5)
270
+ 1. [Revisiting BFloat16 Training](https://arxiv.org/abs/2010.06192)
271
+ 2. [SMMF: Square-Matricized Momentum Factorization](https://arxiv.org/abs/2412.08894)
272
+ 3. [The AdEMAMix Optimizer](https://arxiv.org/abs/2409.03137)
273
+ 4. [Connections between Schedule-Free Optimizers, AdEMAMix, and Accelerated SGD](https://arxiv.org/abs/2502.02431)
274
+ 5. [AdaMeM: Memory Efficient Momentum for Adafactor](https://openreview.net/forum?id=fZqMVTz7K5)
275
275
  6. [Kourkoutas-β: A Sunspike-Driven Adam Optimizer with Desert Flair](https://arxiv.org/abs/2508.12996)
@@ -0,0 +1,23 @@
1
+ adv_optm/__init__.py,sha256=UYBoPKsOboNHt-w9RtcMQ1UsE-yQTSm71JxNNTcTrCU,379
2
+ adv_optm/optim/AdaMuon_adv.py,sha256=-UBw_mJj8JzDAi3zQ0nLnSOgzsTzl7b7kVksDRUziEE,30582
3
+ adv_optm/optim/AdamW_adv.py,sha256=e9KWjAFjUUGy773PbaTANU2WqhvhJ1biuO7pIjVSxSM,17820
4
+ adv_optm/optim/Adopt_adv.py,sha256=ctPimRRky3vRxAw0CYfUD-0IxNO8olj-Srjngw9IoFo,21916
5
+ adv_optm/optim/Lion_Prodigy_adv.py,sha256=k76U1UzAauibTat1Mu7whjVn2gX-8F5GuOUhEWn1aYc,14126
6
+ adv_optm/optim/Lion_adv.py,sha256=qfyclaaOLtSVc2q-ZAazZ0Iu_rv5ob_Zy720iK0_Czg,7691
7
+ adv_optm/optim/Muon_adv.py,sha256=8d99NcXzLyxTbxVVXC8mHyeW7wM8jjK59QoXVTLScQA,32112
8
+ adv_optm/optim/Prodigy_adv.py,sha256=5Jlxo6hXOMpiZQuNxEQXkuiQr9IdPO9V1dNEvs_Sb5I,26296
9
+ adv_optm/optim/Simplified_AdEMAMix.py,sha256=Iwp434fsjo6SEMY6RqAdXVxeDKtg77r_pEm-BNqP5UU,14417
10
+ adv_optm/optim/__init__.py,sha256=F4f-D8QGIByXHAZAu0keJf4foA22NpK-L9QgywVxAm8,491
11
+ adv_optm/util/BF16_Stochastic_Rounding.py,sha256=b8bE7xGtJxZnQYCqdPKtYb8xYGrDftO6jCLLKLa9Ut8,1550
12
+ adv_optm/util/Effective_Shape.py,sha256=h9pF4HaCkjDyo2dxlUpM66oD6FtclQnb7yPPfvReHyI,320
13
+ adv_optm/util/Kourkoutas.py,sha256=8Lik30MACDwM77aNWmMecmPS9g31fT4jE6fuIG4QMTk,7366
14
+ adv_optm/util/NNMF.py,sha256=hrvNGERj8evhPIWnWzsKdm5DwIZblTB4pkhc9xWytSY,794
15
+ adv_optm/util/Newton_Schulz.py,sha256=5Em0PnTXic7bqU5VLWJJeJXnVt1_zqVRPc2CEZi7yLA,1301
16
+ adv_optm/util/One_Bit_Boolean.py,sha256=tE8lSnbKR3oO-EtM0Kzvf0E4hmuBvhmtFR_75su-DNI,1070
17
+ adv_optm/util/OrthoGrad.py,sha256=doP667YpdiEdP3-cpyWiRNkAdkT-nzs45VSafOCRDHw,713
18
+ adv_optm/util/__init__.py,sha256=cA5zt5dvznkOw2lqbaGvFjslznB1UEFYYZMMFsXrWBg,437
19
+ adv_optm-2.dev1.dist-info/licenses/LICENSE,sha256=HrhfyXIkWY2tGFK11kg7vPCqhgh5DcxleloqdhrpyMY,11558
20
+ adv_optm-2.dev1.dist-info/METADATA,sha256=u-2m5LUufoFimmEb87P6JGvJoJEbf0SeyL2fQw_y1rY,13983
21
+ adv_optm-2.dev1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
22
+ adv_optm-2.dev1.dist-info/top_level.txt,sha256=iNfBIIzu-lPrQ7jyC56WBCcbkRwitM2nJ15-MRQ_6fg,9
23
+ adv_optm-2.dev1.dist-info/RECORD,,
@@ -1,23 +0,0 @@
1
- adv_optm/__init__.py,sha256=D5arg90L2AukHVLCuo7eEbYCh1KtUMOnCwrxsBQgA18,380
2
- adv_optm/optim/AdaMuon_adv.py,sha256=-UBw_mJj8JzDAi3zQ0nLnSOgzsTzl7b7kVksDRUziEE,30582
3
- adv_optm/optim/AdamW_adv.py,sha256=KL9SCJWZ_ckAQEApB6ofbndVYjancN-v7Us7hJLFf54,17475
4
- adv_optm/optim/Adopt_adv.py,sha256=S8XI2YA7683jsW8p7igc2YcU30lsN0H18qL02Kpvj8E,21244
5
- adv_optm/optim/Lion_Prodigy_adv.py,sha256=LEA3UYJpPeFnmxeniLNv1u2LKKj4ufx3Bq_MLw-nWXk,14617
6
- adv_optm/optim/Lion_adv.py,sha256=aGNAplZlyXYgVllYcV_s4bK8iC4fv6EizFoWIMNLdBc,8299
7
- adv_optm/optim/Muon_adv.py,sha256=8d99NcXzLyxTbxVVXC8mHyeW7wM8jjK59QoXVTLScQA,32112
8
- adv_optm/optim/Prodigy_adv.py,sha256=lEjbtuQbomsCX39DnTPeI8Z5YG0f2aZPXN_E7-nGgWw,26060
9
- adv_optm/optim/Simplified_AdEMAMix.py,sha256=nEIA3yM11nBooKzHudB5l3x4UdFRBYRwiKVUkGmO0K8,12971
10
- adv_optm/optim/__init__.py,sha256=hpUWE6CKtt_rvMdgQVb3PtjhfZAvAxTq6hp8H8rIpBo,489
11
- adv_optm/util/BF16_Stochastic_Rounding.py,sha256=Q5H0BcogmE4atP65dLoI21HKSf50lRdsBDfeF6v9Tbg,1548
12
- adv_optm/util/Effective_Shape.py,sha256=TBvIk1V8IuTbbBsxuekJA4e_v8JlR5Nujtut8RTWAm4,318
13
- adv_optm/util/Kourkoutas.py,sha256=SSzhe0B6Zb2AXGwCKpVTLr0aaFfspcFBNZCZG3azI9k,7516
14
- adv_optm/util/NNMF.py,sha256=yRf5IP5Sjq0Uf0DxN0Q8NxEGSdD-f1ULziLVDOjY8K4,639
15
- adv_optm/util/Newton_Schulz.py,sha256=wJ_sKRaGVIsOofQ737my4ng494qX_pfgOqlDDmYtnCg,1377
16
- adv_optm/util/One_Bit_Boolean.py,sha256=Wat49esdwohuN-OHOFMW8D0aOQgV9cP5Rl8z6yfmpos,1068
17
- adv_optm/util/OrthoGrad.py,sha256=NzInuBQGy_Ja__M1R9XbvqVaQ0fhGbtGgFE9YON7B3I,707
18
- adv_optm/util/__init__.py,sha256=CXzS703GB4gil85khZi7sgKOnbzXGBOltshIOSPqj18,435
19
- adv_optm-1.2.dev14.dist-info/licenses/LICENSE,sha256=HrhfyXIkWY2tGFK11kg7vPCqhgh5DcxleloqdhrpyMY,11558
20
- adv_optm-1.2.dev14.dist-info/METADATA,sha256=g017hnuxrm1a34pjnXUDZlnUify9xQtX_ZkrbMEXLLY,14023
21
- adv_optm-1.2.dev14.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
22
- adv_optm-1.2.dev14.dist-info/top_level.txt,sha256=iNfBIIzu-lPrQ7jyC56WBCcbkRwitM2nJ15-MRQ_6fg,9
23
- adv_optm-1.2.dev14.dist-info/RECORD,,