adv-optm 1.2.dev2.tar.gz → 1.2.dev3.tar.gz

This diff shows the content of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registry.

Potentially problematic release: this version of adv-optm might be problematic.

Files changed (29)
  1. {adv_optm-1.2.dev2 → adv_optm-1.2.dev3}/PKG-INFO +1 -1
  2. {adv_optm-1.2.dev2 → adv_optm-1.2.dev3}/adv_optm/__init__.py +3 -1
  3. adv_optm-1.2.dev3/adv_optm/optim/AdaMuon_adv.py +443 -0
  4. {adv_optm-1.2.dev2 → adv_optm-1.2.dev3}/adv_optm/optim/AdamW_adv.py +3 -3
  5. {adv_optm-1.2.dev2 → adv_optm-1.2.dev3}/adv_optm/optim/Adopt_adv.py +3 -6
  6. {adv_optm-1.2.dev2 → adv_optm-1.2.dev3}/adv_optm/optim/Muon_adv.py +39 -3
  7. {adv_optm-1.2.dev2 → adv_optm-1.2.dev3}/adv_optm/optim/Prodigy_adv.py +3 -3
  8. {adv_optm-1.2.dev2 → adv_optm-1.2.dev3}/adv_optm/optim/Simplified_AdEMAMix.py +3 -3
  9. {adv_optm-1.2.dev2 → adv_optm-1.2.dev3}/adv_optm/optim/__init__.py +2 -0
  10. {adv_optm-1.2.dev2 → adv_optm-1.2.dev3}/adv_optm/util/Kourkoutas.py +5 -5
  11. {adv_optm-1.2.dev2 → adv_optm-1.2.dev3}/adv_optm/util/MuonAdam_helper.py +1 -0
  12. {adv_optm-1.2.dev2 → adv_optm-1.2.dev3}/adv_optm.egg-info/PKG-INFO +1 -1
  13. {adv_optm-1.2.dev2 → adv_optm-1.2.dev3}/adv_optm.egg-info/SOURCES.txt +1 -0
  14. {adv_optm-1.2.dev2 → adv_optm-1.2.dev3}/setup.py +1 -1
  15. {adv_optm-1.2.dev2 → adv_optm-1.2.dev3}/LICENSE +0 -0
  16. {adv_optm-1.2.dev2 → adv_optm-1.2.dev3}/README.md +0 -0
  17. {adv_optm-1.2.dev2 → adv_optm-1.2.dev3}/adv_optm/optim/Lion_Prodigy_adv.py +0 -0
  18. {adv_optm-1.2.dev2 → adv_optm-1.2.dev3}/adv_optm/optim/Lion_adv.py +0 -0
  19. {adv_optm-1.2.dev2 → adv_optm-1.2.dev3}/adv_optm/util/BF16_Stochastic_Rounding.py +0 -0
  20. {adv_optm-1.2.dev2 → adv_optm-1.2.dev3}/adv_optm/util/Effective_Shape.py +0 -0
  21. {adv_optm-1.2.dev2 → adv_optm-1.2.dev3}/adv_optm/util/NNMF.py +0 -0
  22. {adv_optm-1.2.dev2 → adv_optm-1.2.dev3}/adv_optm/util/Newton_Schulz.py +0 -0
  23. {adv_optm-1.2.dev2 → adv_optm-1.2.dev3}/adv_optm/util/One_Bit_Boolean.py +0 -0
  24. {adv_optm-1.2.dev2 → adv_optm-1.2.dev3}/adv_optm/util/OrthoGrad.py +0 -0
  25. {adv_optm-1.2.dev2 → adv_optm-1.2.dev3}/adv_optm/util/__init__.py +0 -0
  26. {adv_optm-1.2.dev2 → adv_optm-1.2.dev3}/adv_optm.egg-info/dependency_links.txt +0 -0
  27. {adv_optm-1.2.dev2 → adv_optm-1.2.dev3}/adv_optm.egg-info/requires.txt +0 -0
  28. {adv_optm-1.2.dev2 → adv_optm-1.2.dev3}/adv_optm.egg-info/top_level.txt +0 -0
  29. {adv_optm-1.2.dev2 → adv_optm-1.2.dev3}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: adv_optm
- Version: 1.2.dev2
+ Version: 1.2.dev3
  Summary: A family of highly efficient, lightweight yet powerful optimizers.
  Home-page: https://github.com/Koratahiu/Advanced_Optimizers
  Author: Koratahiu
@@ -6,6 +6,7 @@ from .optim import (
  Lion_adv,
  Lion_Prodigy_adv,
  Muon_adv,
+ AdaMuon_adv,
  )

  __all__ = [
@@ -16,6 +17,7 @@ __all__ = [
  "Lion_adv",
  "Lion_Prodigy_adv",
  "Muon_adv",
+ "AdaMuon_adv",
  ]

- __version__ = "1.2.dev2"
+ __version__ = "1.2.dev3"
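
With the re-export above, the new optimizer should be importable directly from the package root once 1.2.dev3 is installed (a minimal illustration, not part of the diff):

    # assumes `pip install adv-optm==1.2.dev3`
    from adv_optm import AdaMuon_adv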
@@ -0,0 +1,443 @@
+ import torch
+ from typing import Optional, Callable
+
+ from .AdamW_adv import AdamW_adv
+ from ..util.MuonAdam_helper import MuonAdamHelper
+ from ..util.Kourkoutas import KourkoutasHelper
+
+ from ..util.BF16_Stochastic_Rounding import add_stochastic_
+ from ..util.Newton_Schulz import _newton_schulz_iteration
+ from ..util.Effective_Shape import _get_effective_shape
+ from ..util.NNMF import _nnmf,_unnmf
+ from ..util.One_Bit_Boolean import _pack_bools, _unpack_bools
+
+ class AdaMuon_adv(torch.optim.Optimizer):
+ """
+ Implements the AdaMuon optimizer algorithm.
+
+ AdaMuon combines the geometry-aware updates of Muon with the element-wise
+ adaptivity of Adam. It is designed for 2D parameters (e.g., linear layers)
+ and can handle higher-dimensional parameters by flattening.
+
+ The algorithm incorporates three key mechanisms:
+ 1. A sign-stabilized orthogonal update, where the sign of the momentum is
+ orthogonalized instead of the momentum itself.
+ 2. An element-wise second momentum estimator applied to the orthogonalized
+ update directions.
+ 3. An RMS-aligned rescaling strategy to match the update magnitude of Adam,
+ allowing for reuse of learning rate schedules.
+
+ Can also operate in a hybrid mode, using an auxiliary AdamW
+ optimizer for specific parameters (e.g., biases, norms, embeddings) as
+ defined by a `layer_key_fn`.
+
+ Args:
+ params (iterable): iterable of parameters to optimize or dicts defining
+ parameter groups.
+ lr (float): learning rate (default: 1e-3).
+ betas (tuple[float, float]): coefficients used for both first and second moment
+ estimation (default: (0.95, 0.95))
+ weight_decay (float): weight decay (L2 penalty) (default: 0.1).
+ eps (float): term added to the denominator for adaptive scaling to improve
+ numerical stability (default: 1e-8).
+ rms_target (float): The target Root-Mean-Square value for the final update
+ vector, used for RMS-aligned rescaling. Allows for the reuse of existing Adam
+ learning rate schedules. (default: 0.2).
+ ns_steps (int): number of Newton-Schulz iterations to perform (default: 5).
+ ns_eps (float): epsilon for Newton-Schulz normalization stability (default: 1e-7).
+ ns_coeffs (tuple[float, float, float]): The (a, b, c) coefficients for the
+ quintic polynomial in the Newton-Schulz iteration.
+ (default: (3.4445, -4.7750, 2.0315)).
+ stochastic_rounding (bool): whether to use stochastic rounding for
+ BF16 parameter updates (default: True).
+ nesterov (bool): enables Nesterov momentum (default: False).
+ use_atan2 (bool): whether to use the atan2 update rule. (default: False)
+ Simplified_AdEMAMix (bool): whether to use the Simplified AdEMAMix update rule.
+ This changes the update to `alpha_grad * grad + mt`, which can be
+ more responsive, especially for small batch sizes. (default: False)
+ alpha_grad (float): Mixing coefficient for the Simplified AdEMAMix update rule
+ (only used when `Simplified_AdEMAMix` is `True`). Controls the weight of the
+ current gradient. For small batch sizes, use high values (e.g., 10-100) to be
+ more responsive. For large batch sizes, use low values (e.g., 0-1) for
+ stability. (default: 100.0)
+ vector_reshape_muon (bool): whether to reshape 1D vectors into 2D
+ matrices for muon NewtonSchulz (default: False).
+ vector_reshape (bool): whether to reshape 1D vectors into 2D
+ matrices to apply low-rank compression (default: True).
+ nnmf_factor (bool): whether to use the factorization or disable it to use
+ the uncompressed optimizer. (default: False)
+ kourkoutas_beta (bool): whether to enable the layer-wise dynamic β₂ logic.
+ If `False`, the optimizer behaves as standard AdamW. (default: False)
+ beta2_min (float): The minimum value for dynamic β₂, used during periods of
+ high gradient variance ("sunspikes"). Must be less than `betas[1]`.
+ (default: 0.88)
+ ema_alpha (float): The decay rate for the Exponential Moving Average (EMA) of
+ the pooled gradient norms. Corresponds to `α` in the paper.
+ (default: 0.93)
+ tiny_spike (float): A small constant added to the denominator of the
+ "sunspike" ratio calculation to prevent division by zero. Corresponds
+ to `ε_spike` in the paper. (default: 1e-9)
+ k_warmup_steps (int): The number of initial steps during which β₂ is held
+ at a fixed beta2 value before the
+ dynamic logic activates. (default: 0)
+ MuonWithAuxAdam (bool): If True, enables the hybrid optimizer mode.
+ Parameters designated by `layer_key_fn` will be optimized with
+ AdamW_adv instead of Muon. (default: False)
+ layer_key_fn (Optional[Callable]): A function that takes a parameter `p`
+ and returns a key. If the key is 'adam', the parameter is handled by
+ the auxiliary AdamW optimizer. All other keys are handled by Muon.
+ Only used when `MuonWithAuxAdam` is True. (default: None)
+ adam_kwargs (Optional[dict]): A dictionary of keyword arguments to pass
+ to the auxiliary AdamW_adv optimizer. Only used when
+ `MuonWithAuxAdam` is True. (default: None)
+ """
+
+ def __init__(
+ self,
+ params,
+ lr: float = 1e-3,
+ betas: tuple[float, float] = (0.95, 0.95),
+ weight_decay: float = 0.1,
+ eps: float = 1e-8,
+ rms_target: float = 0.2,
+ ns_steps: int = 5,
+ ns_eps: float = 1e-7,
+ ns_coeffs: tuple[float, float, float] = (3.4445, -4.7750, 2.0315),
+ stochastic_rounding: bool = True,
+ use_atan2: bool = False,
+ nesterov: bool = False,
+ Simplified_AdEMAMix: bool = False,
+ alpha_grad: float = 100.0,
+ vector_reshape_muon: bool = False,
+ vector_reshape: bool = False,
+ nnmf_factor: bool = False,
+ # K-b parameters
+ kourkoutas_beta: bool = False,
+ beta2_min: float = 0.9,
+ ema_alpha: float = 0.95,
+ tiny_spike: float = 1e-9,
+ k_warmup_steps: int = 0,
+ k_logging: int = 0,
+ layer_key_kb_fn: Optional[Callable] = None,
+ # hybrid optimizer mode
+ MuonWithAuxAdam: bool = False,
+ layer_key_fn: Optional[Callable] = None,
+ muon_adam_lr: float = 1e-4,
+ adam_kwargs: Optional[dict] = None,
+ ):
+ if not (lr >= 0.0):
+ raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
+ if not (weight_decay >= 0.0):
+ raise ValueError(f"Weight-decay should be >= 0.0. Got {weight_decay}")
+ if not (ns_steps > 0):
+ raise ValueError(f"Newton-Schulz steps should be > 0. Got {ns_steps}")
+ if Simplified_AdEMAMix and nesterov:
+ print("Warning: nesterov is incompatible with Simplified_AdEMAMix, Disabling cautious.")
+ nesterov = False
+
+ defaults = {
+ "lr": lr, "betas": betas, "weight_decay": weight_decay,
+ "eps": eps, "rms_target": rms_target, "ns_steps": ns_steps,
+ "ns_eps": ns_eps, "ns_coeffs": ns_coeffs, "nnmf_factor": nnmf_factor,
+ "vector_reshape": vector_reshape,
+ "vector_reshape_muon": vector_reshape_muon,
+ "nesterov":nesterov, "use_atan2":use_atan2,
+ "Simplified_AdEMAMix": Simplified_AdEMAMix, "alpha_grad": alpha_grad,
+ "_kourkoutas_beta": kourkoutas_beta, "beta2_min": beta2_min, "ema_alpha": ema_alpha,
+ "tiny_spike": tiny_spike, "k_warmup_steps": k_warmup_steps, "k_logging": k_logging,
+ }
+ self.stochastic_rounding = stochastic_rounding
+ self._kourkoutas_beta = kourkoutas_beta
+ self._kourkoutas_helper = None
+ self.layer_key_kb_fn = layer_key_kb_fn
+ self.MuonWithAuxAdam = MuonWithAuxAdam
+ self.helper = None
+ self.aux_adam = None
+
+ if self.MuonWithAuxAdam:
+ adam_kwargs = adam_kwargs or {}
+ # Create a delegate AdamW optimizer to get its default hyperparameters.
+ self.aux_adam = AdamW_adv(
+ [],
+ lr=muon_adam_lr,
+ **adam_kwargs,
+ _is_delegate=True
+ )
+ # Update the defaults dictionary
+ defaults.update(self.aux_adam.defaults)
+
+ super().__init__(params, defaults)
+
+ if self.MuonWithAuxAdam:
+ self.helper = MuonAdamHelper(self, layer_key_fn)
+
+
+ @property
+ def supports_fused_back_pass(self):
+ return True
+
+ @property
+ def supports_memory_efficient_fp16(self):
+ return True
+
+ @property
+ def supports_flat_params(self):
+ return False
+
+ @property
+ def kourkoutas_helper(self):
+ """
+ Exposes the kourkoutas_helper from the auxiliary AdamW optimizer,
+ if it exists. This allows external access for logging K-b.
+ """
+ if self.aux_adam and hasattr(self.aux_adam, 'kourkoutas_helper'):
+ return self.aux_adam.kourkoutas_helper
+ return None
+
+ @torch.no_grad()
+ def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
+ if self.MuonWithAuxAdam:
+ optim_type = self.helper.get_optimizer_type(p)
+ if optim_type == 'adam':
+ # Delegate to the AdamW_adv optimizer's logic.
+ # We need to temporarily "lend" our state and param_groups
+ # to the delegate so it has the full context to work with,
+ # especially for features like Kourkoutas-beta.
+ self.aux_adam.state = self.state
+ self.aux_adam.param_groups = self.param_groups
+ self.aux_adam.step_parameter(p, group, i)
+ return
+
+ if group['_kourkoutas_beta'] and self._kourkoutas_helper is None:
+ self._kourkoutas_helper = KourkoutasHelper(self)
+
+ if p.grad is None:
+ return
+
+ grad = p.grad
+ state = self.state[p]
+
+
+ # State Initialization
+ if 'step' not in state:
+ state['step'] = 0
+
+ should_factor = (
+ group['nnmf_factor'] and
+ not (len(p.shape) == 1 and not group['vector_reshape'])
+ )
+
+ state['factored'] = should_factor
+
+ state['reshaped_1d_muon'] = len(p.shape) == 1 and group['vector_reshape_muon']
+
+ dtype = torch.float32 if group['nnmf_factor'] else p.dtype
+ device = p.device
+ if state['factored'] or state['reshaped_1d_muon']:
+ state['effective_shape'] = _get_effective_shape(p.numel())
+ d1, d2 = state['effective_shape']
+ if state['factored']:
+ state['mu_m_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
+ state['mv_m_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
+ packed_d2 = (d2 + 7) // 8
+ state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=device)
+ state['mu_v_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
+ state['mv_v_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
+ else:
+ if len(p.shape) >= 2:
+ state['momentum_buffer'] = torch.zeros_like(p)
+ state['second_momentum_buffer'] = torch.zeros_like(p)
+ if state['reshaped_1d_muon']:
+ state['momentum_buffer'] = torch.zeros((d1, d2), device=device, dtype=dtype)
+ state['second_momentum_buffer'] = torch.zeros((d1, d2), device=device, dtype=dtype)
+ elif len(p.shape) == 1:
+ state['momentum_buffer'] = torch.zeros_like(p)
+
+ # Retrieve hyperparameters
+ beta1, beta2 = group['betas']
+ current_step = state['step']
+ nesterov = group['nesterov']
+ Simplified_AdEMAMix = group['Simplified_AdEMAMix']
+ alpha_grad = group['alpha_grad']
+
+ if state['factored']: # Factored AdaMuon
+
+ # Reconstruct momentum from previous step's factors & sign
+ d1, d2 = state['effective_shape']
+ mt_buf = _unnmf((state['mu_m_nmf'], state['mv_m_nmf']))
+ unpacked_sign = _unpack_bools(state['sign'], original_m=d2)
+ torch.where(unpacked_sign, mt_buf, -mt_buf, out=mt_buf)
+ del unpacked_sign
+
+ # Update momentum in full-size
+ grad_reshaped = grad.view(d1, d2)
+ mt_buf.mul_(beta1).add_(grad_reshaped)
+
+ if nesterov:
+ signed_m_buf = torch.sign(grad_reshaped.add(mt_buf, alpha=beta1))
+ elif Simplified_AdEMAMix:
+ signed_m_buf = torch.sign(mt_buf.add(grad_reshaped, alpha=alpha_grad))
+ else:
+ signed_m_buf = torch.sign(mt_buf)
+ del grad_reshaped
+
+ update = _newton_schulz_iteration(
+ signed_m_buf,
+ steps=group['ns_steps'],
+ eps=group['ns_eps'],
+ coeffs=group['ns_coeffs'],
+ )
+
+ if group['_kourkoutas_beta']:
+ # Call prepare_step() once at the beginning of the step for all params
+ self._kourkoutas_helper.maybe_prepare_step(current_step)
+ # Accumulate current sign-stabilized orthogonal update's norm for the *next* step
+ self._kourkoutas_helper.accumulate_gradient_sq_norm(p, update.view(p.shape))
+ # Get the dynamic beta2 calculated in prepare_step()
+ beta2 = self._kourkoutas_helper.get_beta2(p, group, current_step)
+
+ # Reconstruct second momentum from previous step's factors
+ vt_buf = _unnmf((state['mu_v_nmf'], state['mv_v_nmf']))
+
+ # Update second momentum in full-size
+ vt_buf.mul_(beta2).addcmul_(update, update, value=1 - beta2)
+
+ # Apply second momentum update (adaptive scaling)
+ if group['use_atan2']:
+ a = 1.2732395
+ denom = vt_buf.sqrt()
+ update.atan2_(denom).mul_(a)
+ else:
+ denom = vt_buf.sqrt().add_(group['eps'])
+ update.div_(denom)
+ del denom
+
+ # RMS-aligned rescaling
+ rms_target = group['rms_target']
+ num_elements = update.numel()
+ scaling_factor = rms_target * (num_elements ** 0.5) / (update.norm())
+
+ update.mul_(scaling_factor)
+ update = update.view(p.shape).mul_(group['lr'])
+ del num_elements, scaling_factor
+
+ # Compress updated moments and store new factors
+ state['sign'] = _pack_bools(mt_buf > 0)
+ _nnmf(mt_buf.abs(), out=(state['mu_m_nmf'], state['mv_m_nmf']))
+ del mt_buf
+
+ _nnmf(vt_buf.abs(), out=(state['mu_v_nmf'], state['mv_v_nmf']))
+ del vt_buf
+
+ else: # Standard AdaMuon logic for non-factored tensors
+
+ if len(p.shape) >= 2 or state['reshaped_1d_muon']:
+
+ # Momentum update
+ mt_buf = state['momentum_buffer']
+ if state['reshaped_1d_muon']:
+ d1, d2 = state['effective_shape']
+ grad_reshaped = grad.view(d1, d2)
+ mt_buf.mul_(beta1).add_(grad_reshaped)
+ if nesterov:
+ signed_m_buf = torch.sign(grad_reshaped.add(mt_buf, alpha=beta1))
+ elif Simplified_AdEMAMix:
+ signed_m_buf = torch.sign(mt_buf.add(grad_reshaped, alpha=alpha_grad))
+ else:
+ signed_m_buf = torch.sign(mt_buf)
+ del grad_reshaped
+ else:
+ mt_buf.mul_(beta1).add_(grad)
+ if nesterov:
+ signed_m_buf = torch.sign(grad.add(mt_buf, alpha=beta1))
+ elif Simplified_AdEMAMix:
+ signed_m_buf = torch.sign(mt_buf.add(grad, alpha=alpha_grad))
+ else:
+ signed_m_buf = torch.sign(mt_buf)
+
+ # Flatten if necessary (e.g., for Conv layers)
+ if len(p.shape) > 2:
+ signed_m_buf = signed_m_buf.view(p.shape[0], -1)
+
+ # NewtonSchulz
+ update = _newton_schulz_iteration(
+ signed_m_buf,
+ steps=group['ns_steps'],
+ eps=group['ns_eps'],
+ coeffs=group['ns_coeffs'],
+ )
+
+ if len(p.shape) > 2 or state['reshaped_1d_muon']:
+ update = update.view(p.shape)
+
+ if group['_kourkoutas_beta']:
+ # Call prepare_step() once at the beginning of the step for all params
+ self._kourkoutas_helper.maybe_prepare_step(current_step)
+ # Accumulate current sign-stabilized orthogonal update's norm for the *next* step
+ self._kourkoutas_helper.accumulate_gradient_sq_norm(p, update)
+ # Get the dynamic beta2 calculated in prepare_step()
+ beta2 = self._kourkoutas_helper.get_beta2(p, group, current_step)
+
+ vt_buf = state['second_momentum_buffer']
+ vt_buf.mul_(beta2).addcmul_(update, update, value=1 - beta2)
+
+ # Apply second momentum update (adaptive scaling)
+ if group['use_atan2']:
+ a = 1.2732395
+ denom = vt_buf.sqrt()
+ update.atan2_(denom).mul_(a)
+ else:
+ denom = vt_buf.sqrt().add_(group['eps'])
+ update.div_(denom)
+ del denom
+
+ # RMS-aligned rescaling
+ rms_target = group['rms_target']
+ num_elements = update.numel()
+ scaling_factor = rms_target * (num_elements ** 0.5) / (update.norm())
+
+ update.mul_(scaling_factor)
+ del num_elements, scaling_factor
+
+ update.mul_(group['lr'])
+
+ else: # Fallback to standard SGD with momentum for 1D params (biases, etc.) when not reshaped
+ # Momentum update
+ mt_buf = state['momentum_buffer']
+ mt_buf.mul_(beta1).add_(grad)
+ if nesterov:
+ update = grad.add(mt_buf, alpha=beta1)
+ elif Simplified_AdEMAMix:
+ signed_m_buf = torch.sign(mt_buf.add(grad, alpha=alpha_grad))
+ else:
+ update = mt_buf.clone()
+ update.mul_(group['lr'])
+
+ # Decoupled weight decay
+ if group["weight_decay"] != 0:
+ if p.dtype == torch.bfloat16 and self.stochastic_rounding:
+ add_stochastic_(p.data, p.data, alpha=-group["weight_decay"] * group["lr"])
+ else:
+ p.data.add_(p.data, alpha=-group["weight_decay"] * group["lr"])
+
+ if p.dtype == torch.bfloat16 and self.stochastic_rounding:
+ add_stochastic_(p.data, -update)
+ else:
+ p.data.add_(-update)
+ del update
+
+ state['step'] += 1
+
+ @torch.no_grad()
+ def step(self, closure=None):
+ """Performs a single optimization step."""
+ loss = None
+ if closure is not None:
+ with torch.enable_grad():
+ loss = closure()
+
+ for group in self.param_groups:
+ for i, p in enumerate(group['params']):
+ self.step_parameter(p, group, i)
+
+ return loss
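
For orientation, a minimal usage sketch of the new class, based only on the constructor signature shown above (the module and training step are placeholders, and the keyword values are the diff's defaults):

    import torch
    from adv_optm import AdaMuon_adv

    model = torch.nn.Linear(512, 512)      # placeholder module with a 2D weight
    opt = AdaMuon_adv(
        model.parameters(),
        lr=1e-3,                           # default learning rate
        betas=(0.95, 0.95),                # first/second moment coefficients
        weight_decay=0.1,                  # decoupled weight decay
        rms_target=0.2,                    # RMS-aligned rescaling target
        ns_steps=5,                        # Newton-Schulz iterations
    )

    loss = model(torch.randn(8, 512)).pow(2).mean()
    loss.backward()
    opt.step()
    opt.zero_grad()

    # Hybrid mode per the docstring: parameters whose key function returns 'adam'
    # are delegated to the auxiliary AdamW_adv optimizer (illustrative key function).
    # opt = AdaMuon_adv(model.parameters(), MuonWithAuxAdam=True,
    #                   layer_key_fn=lambda p: 'adam' if p.ndim == 1 else 'muon')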
@@ -73,7 +73,7 @@ class AdamW_adv(torch.optim.Optimizer):
  logging of Kourkoutas-β statistics (min, max, mean of `β₂` across layers)
  every logging steps. Useful for debugging and tuning. Set to 0 to disable
  logging (default: 0).
- layer_key_fn (Optional[Callable]): A function that takes a parameter `p`
+ layer_key_kb_fn (Optional[Callable]): A function that takes a parameter `p`
  and returns a unique, hashable key representing its "layer" or "bucket".
  If `None`, parameters are bucketed by their memory ID (tensor-wise).
  (default: None)
@@ -105,7 +105,7 @@ class AdamW_adv(torch.optim.Optimizer):
  tiny_spike: float = 1e-9,
  k_warmup_steps: int = 0,
  k_logging: int = 0,
- layer_key_fn: Optional[Callable] = None,
+ layer_key_kb_fn: Optional[Callable] = None,
  nnmf_factor: bool = False,
  _is_delegate: bool = False,
  ):
@@ -137,7 +137,7 @@ class AdamW_adv(torch.optim.Optimizer):
  self.use_AdEMAMix = use_AdEMAMix
  self.factored = nnmf_factor
  self.kourkoutas_beta = kourkoutas_beta
- self.layer_key_fn = layer_key_fn
+ self.layer_key_kb_fn = layer_key_kb_fn
  if not _is_delegate:
  super().__init__(params, defaults)
  else:
@@ -91,7 +91,7 @@ class Adopt_adv(torch.optim.Optimizer):
  logging of Kourkoutas-β statistics (min, max, mean of `β₂` across layers)
  every logging steps. Useful for debugging and tuning. Set to 0 to disable
  logging (default: 0).
- layer_key_fn (Optional[Callable]): A function that takes a parameter `p`
+ layer_key_kb_fn (Optional[Callable]): A function that takes a parameter `p`
  and returns a unique, hashable key representing its "layer" or "bucket".
  If `None`, parameters are bucketed by their memory ID (tensor-wise).
  (default: None)
@@ -125,7 +125,7 @@ class Adopt_adv(torch.optim.Optimizer):
  tiny_spike: float = 1e-9,
  k_warmup_steps: int = 0,
  k_logging: int = 0,
- layer_key_fn: Optional[Callable] = None,
+ layer_key_kb_fn: Optional[Callable] = None,
  nnmf_factor: bool = False,
  ):
  if not (lr >= 0.0):
@@ -148,9 +148,6 @@ class Adopt_adv(torch.optim.Optimizer):
  print("Warning: grams is incompatible with Simplified_AdEMAMix, Disabling grams.")
  if cautious_mask and Simplified_AdEMAMix:
  print("Warning: cautious is incompatible with Simplified_AdEMAMix, Disabling cautious.")
- if use_atan2 and Simplified_AdEMAMix:
- print("Warning: use_atan2 is incompatible with Simplified_AdEMAMix. Disabling use_atan2.")
- use_atan2 = False

  defaults = {
  "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
@@ -169,7 +166,7 @@ class Adopt_adv(torch.optim.Optimizer):
  self.Simplified_AdEMAMix = Simplified_AdEMAMix
  self.factored = nnmf_factor
  self.kourkoutas_beta = kourkoutas_beta
- self.layer_key_fn = layer_key_fn
+ self.layer_key_kb_fn = layer_key_kb_fn
  super().__init__(params, defaults)

  if self.kourkoutas_beta:
@@ -22,7 +22,7 @@ class Muon_adv(torch.optim.Optimizer):
  can handle other-dimensional parameters (e.g., 1D bias, 4D convolutional layers) by
  flattening/reshaping them.

- This version can also operate in a hybrid mode, using an auxiliary AdamW
+ Can also operate in a hybrid mode, using an auxiliary AdamW
  optimizer for specific parameters (e.g., biases, norms, embeddings) as
  defined by a `layer_key_fn`.

@@ -38,6 +38,14 @@ class Muon_adv(torch.optim.Optimizer):
  ns_coeffs (tuple[float, float, float]): The (a, b, c) coefficients for the
  quintic polynomial in the Newton-Schulz iteration.
  (default: (3.4445, -4.7750, 2.0315)).
+ Simplified_AdEMAMix (bool): whether to use the Simplified AdEMAMix update rule.
+ This changes the update to `alpha_grad * grad + mt`, which can be
+ more responsive, especially for small batch sizes. (default: False)
+ alpha_grad (float): Mixing coefficient for the Simplified AdEMAMix update rule
+ (only used when `Simplified_AdEMAMix` is `True`). Controls the weight of the
+ current gradient. For small batch sizes, use high values (e.g., 10-100) to be
+ more responsive. For large batch sizes, use low values (e.g., 0-1) for
+ stability. (default: 100.0)
  stochastic_rounding (bool): whether to use stochastic rounding for
  BF16 parameter updates (default: True).
  vector_reshape_muon (bool): whether to reshape 1D vectors into 2D
@@ -68,9 +76,11 @@ class Muon_adv(torch.optim.Optimizer):
  ns_steps: int = 5,
  ns_eps: float = 1e-7,
  ns_coeffs: tuple[float, float, float] = (3.4445, -4.7750, 2.0315),
+ Simplified_AdEMAMix: bool = False,
+ alpha_grad: float = 100.0,
  stochastic_rounding: bool = True,
  vector_reshape_muon: bool = False,
- vector_reshape: bool = True,
+ vector_reshape: bool = False,
  nnmf_factor: bool = False,
  # hybrid optimizer mode
  MuonWithAuxAdam: bool = False,
@@ -86,6 +96,9 @@ class Muon_adv(torch.optim.Optimizer):
  raise ValueError(f"Weight-decay should be >= 0.0. Got {weight_decay}")
  if not (ns_steps > 0):
  raise ValueError(f"Newton-Schulz steps should be > 0. Got {ns_steps}")
+ if Simplified_AdEMAMix and nesterov:
+ print("Warning: nesterov is incompatible with Simplified_AdEMAMix, Disabling cautious.")
+ nesterov = False

  defaults = {
  "lr": lr, "beta1": beta1, "weight_decay": weight_decay,
@@ -93,6 +106,7 @@ class Muon_adv(torch.optim.Optimizer):
  "ns_coeffs": ns_coeffs, "nnmf_factor": nnmf_factor,
  "vector_reshape": vector_reshape,
  "vector_reshape_muon": vector_reshape_muon,
+ "Simplified_AdEMAMix": Simplified_AdEMAMix, "alpha_grad": alpha_grad,
  }
  self.stochastic_rounding = stochastic_rounding

@@ -130,6 +144,16 @@ class Muon_adv(torch.optim.Optimizer):
  def supports_flat_params(self):
  return False

+ @property
+ def kourkoutas_helper(self):
+ """
+ Exposes the kourkoutas_helper from the auxiliary AdamW optimizer,
+ if it exists. This allows external access for logging K-b.
+ """
+ if self.aux_adam and hasattr(self.aux_adam, 'kourkoutas_helper'):
+ return self.aux_adam.kourkoutas_helper
+ return None
+
  @torch.no_grad()
  def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
  if self.MuonWithAuxAdam:
@@ -165,7 +189,7 @@ class Muon_adv(torch.optim.Optimizer):

  dtype = torch.float32 if group['nnmf_factor'] else p.dtype
  device = p.device
- if group['vector_reshape'] or state['reshaped_1d_muon']:
+ if state['factored'] or state['reshaped_1d_muon']:
  state['effective_shape'] = _get_effective_shape(p.numel())
  d1, d2 = state['effective_shape']
  if state['factored']:
@@ -183,6 +207,8 @@ class Muon_adv(torch.optim.Optimizer):

  beta1 = group['beta1']
  nesterov = group['nesterov']
+ Simplified_AdEMAMix = group['Simplified_AdEMAMix']
+ alpha_grad = group['alpha_grad']

  if state['factored']: # Factored Muon

@@ -200,6 +226,8 @@ class Muon_adv(torch.optim.Optimizer):
  if nesterov:
  # Nesterov momentum
  update = grad_reshaped.add(mt_buf, alpha=beta1)
+ elif Simplified_AdEMAMix:
+ update = torch.add(mt_buf, grad_reshaped, alpha=alpha_grad)
  else:
  # Standard momentum
  update = mt_buf.clone()
@@ -238,6 +266,12 @@ class Muon_adv(torch.optim.Optimizer):
  del grad_reshaped
  else:
  update = grad.add(mt_buf, alpha=beta1)
+ elif Simplified_AdEMAMix:
+ if state['reshaped_1d_muon']:
+ update = torch.add(mt_buf, grad_reshaped, alpha=alpha_grad)
+ del grad_reshaped
+ else:
+ update = torch.add(mt_buf, grad, alpha=alpha_grad)
  else:
  # Standard momentum
  update = mt_buf.clone()
@@ -267,6 +301,8 @@ class Muon_adv(torch.optim.Optimizer):
  if nesterov:
  # Nesterov momentum
  update = grad.add(mt_buf, alpha=beta1)
+ elif Simplified_AdEMAMix:
+ update = torch.add(mt_buf, grad, alpha=alpha_grad)
  else:
  # Standard momentum
  update = mt_buf.clone()
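
All of the `Simplified_AdEMAMix` branches added above implement the same mixing rule the new docstring describes, `update = mt + alpha_grad * grad` after the usual momentum accumulation; a standalone sketch with illustrative names:

    import torch

    def simplified_ademamix_update(mt_buf: torch.Tensor, grad: torch.Tensor,
                                   beta1: float = 0.95,
                                   alpha_grad: float = 100.0) -> torch.Tensor:
        # EMA of gradients, updated in place as in the diff: mt <- beta1 * mt + grad
        mt_buf.mul_(beta1).add_(grad)
        # Mix in the current gradient: mt + alpha_grad * grad
        return torch.add(mt_buf, grad, alpha=alpha_grad)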
@@ -109,7 +109,7 @@ class Prodigy_adv(torch.optim.Optimizer):
  logging of Kourkoutas-β statistics (min, max, mean of `β₂` across layers)
  every logging steps. Useful for debugging and tuning. Set to 0 to disable
  logging (default: 0).
- layer_key_fn (Optional[Callable]): A function that takes a parameter `p`
+ layer_key_kb_fn (Optional[Callable]): A function that takes a parameter `p`
  and returns a unique, hashable key representing its "layer" or "bucket".
  If `None`, parameters are bucketed by their memory ID (tensor-wise).
  (default: None)
@@ -152,7 +152,7 @@ class Prodigy_adv(torch.optim.Optimizer):
  tiny_spike: float = 1e-9,
  k_warmup_steps: int = 0,
  k_logging: int = 0,
- layer_key_fn: Optional[Callable] = None,
+ layer_key_kb_fn: Optional[Callable] = None,
  ):
  if not (lr >= 0.0):
  raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
@@ -205,7 +205,7 @@ class Prodigy_adv(torch.optim.Optimizer):
  self.fsdp_in_use = fsdp_in_use

  self.kourkoutas_beta = kourkoutas_beta
- self.layer_key_fn = layer_key_fn
+ self.layer_key_kb_fn = layer_key_kb_fn

  super().__init__(params, defaults)
  if self.kourkoutas_beta:
@@ -67,7 +67,7 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
  logging of Kourkoutas-β statistics (min, max, mean of `β₂` across layers)
  every logging steps. Useful for debugging and tuning. Set to 0 to disable
  logging (default: 0).
- layer_key_fn (Optional[Callable]): A function that takes a parameter `p`
+ layer_key_kb_fn (Optional[Callable]): A function that takes a parameter `p`
  and returns a unique, hashable key representing its "layer" or "bucket".
  If `None`, parameters are bucketed by their memory ID (tensor-wise).
  (default: None)
@@ -95,7 +95,7 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
  tiny_spike: float = 1e-9,
  k_warmup_steps: int = 0,
  k_logging: int = 0,
- layer_key_fn: Optional[Callable] = None,
+ layer_key_kb_fn: Optional[Callable] = None,
  nnmf_factor: bool = False,
  ):
  if not (lr >= 0.0):
@@ -121,7 +121,7 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
  self.stochastic_rounding = stochastic_rounding
  self.factored = nnmf_factor
  self.kourkoutas_beta = kourkoutas_beta
- self.layer_key_fn = layer_key_fn
+ self.layer_key_kb_fn = layer_key_kb_fn
  super().__init__(params, defaults)

  if self.kourkoutas_beta:
@@ -5,6 +5,7 @@ from .Simplified_AdEMAMix import Simplified_AdEMAMix
  from .Lion_adv import Lion_adv
  from .Lion_Prodigy_adv import Lion_Prodigy_adv
  from .Muon_adv import Muon_adv
+ from .AdaMuon_adv import AdaMuon_adv

  __all__ = [
  "AdamW_adv",
@@ -14,4 +15,5 @@ __all__ = [
  "Lion_adv",
  "Lion_Prodigy_adv",
  "Muon_adv",
+ "AdaMuon_adv",
  ]
@@ -32,12 +32,12 @@ class KourkoutasHelper:
  if self._layer_info_built:
  return

- if hasattr(self.optimizer, 'layer_key_fn') and self.optimizer.layer_key_fn is not None:
+ if hasattr(self.optimizer, 'layer_key_kb_fn') and self.optimizer.layer_key_kb_fn is not None:
  # A custom key function was provided by the user. We will use it.
  pass
  else:
  # No key function was provided. Default to coarse, shape-based bucketing.
- self.optimizer.layer_key_fn = lambda p: \
+ self.optimizer.layer_key_kb_fn = lambda p: \
  (id(p),) if p.dim() == 2 and 1 <= p.shape[0] <= 10 and p.shape[1] in {768, 1280, 4096} \
  else tuple(p.shape)
  # This ensures that we won't mix embeddings with tokens (1 to 10)
@@ -46,7 +46,7 @@ class KourkoutasHelper:
  for group in self.optimizer.param_groups:
  for p in group['params']:
  # The mapping is static and should not depend on the presence of a gradient.
- layer_key = self.optimizer.layer_key_fn(p)
+ layer_key = self.optimizer.layer_key_kb_fn(p)
  if layer_key not in self.layer_info:
  self.layer_info[layer_key] = {'params': [], 'group_ref': group}
  self.layer_info[layer_key]['params'].append(p)
@@ -158,7 +158,7 @@ class KourkoutasHelper:
  """
  Accumulates the squared L2 norm of a single gradient for the next step's calculation.
  """
- layer_key = self.optimizer.layer_key_fn(p)
+ layer_key = self.optimizer.layer_key_kb_fn(p)

  if layer_key in self.layer_info:
  # Initialize the transient state for this layer if it's the first time in the step.
@@ -173,6 +173,6 @@ class KourkoutasHelper:
  """
  Gets the appropriate beta2 for the current parameter, handling warmup and dynamic value fetching.
  """
- layer_key = self.optimizer.layer_key_fn(p)
+ layer_key = self.optimizer.layer_key_kb_fn(p)
  # The default is the max value, which is correct for unmapped params or edge cases
  return self.layer_state.get(layer_key, {}).get('dynamic_beta2', group['betas'][1])
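
With the rename, the helper now looks up `layer_key_kb_fn` on the owning optimizer to bucket parameters for the Kourkoutas-β statistics; a hypothetical custom key function following the documented contract (one hashable key per layer/bucket) could be passed like this:

    # Hypothetical example: bucket parameters by their shape.
    def layer_key_by_shape(p):
        return tuple(p.shape)

    # opt = AdamW_adv(model.parameters(), kourkoutas_beta=True,
    #                 layer_key_kb_fn=layer_key_by_shape)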
@@ -1,3 +1,4 @@
+ import torch
  from torch.optim import Optimizer
  from typing import Callable, Optional

@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: adv_optm
- Version: 1.2.dev2
+ Version: 1.2.dev3
  Summary: A family of highly efficient, lightweight yet powerful optimizers.
  Home-page: https://github.com/Koratahiu/Advanced_Optimizers
  Author: Koratahiu
@@ -7,6 +7,7 @@ adv_optm.egg-info/SOURCES.txt
  adv_optm.egg-info/dependency_links.txt
  adv_optm.egg-info/requires.txt
  adv_optm.egg-info/top_level.txt
+ adv_optm/optim/AdaMuon_adv.py
  adv_optm/optim/AdamW_adv.py
  adv_optm/optim/Adopt_adv.py
  adv_optm/optim/Lion_Prodigy_adv.py
@@ -5,7 +5,7 @@ with open("README.md", "r", encoding="utf-8") as fh:

  setup(
  name="adv_optm",
- version="1.2.dev2",
+ version="1.2.dev3",
  author="Koratahiu",
  author_email="hiuhonor@gmail.com",
  license='Apache 2.0',