adv-optm 1.1.2__tar.gz → 1.2.dev1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (27)
  1. {adv_optm-1.1.2 → adv_optm-1.2.dev1}/PKG-INFO +1 -1
  2. {adv_optm-1.1.2 → adv_optm-1.2.dev1}/adv_optm/__init__.py +3 -1
  3. {adv_optm-1.1.2 → adv_optm-1.2.dev1}/adv_optm/optim/AdamW_adv.py +2 -2
  4. {adv_optm-1.1.2 → adv_optm-1.2.dev1}/adv_optm/optim/Adopt_adv.py +1 -1
  5. {adv_optm-1.1.2 → adv_optm-1.2.dev1}/adv_optm/optim/Lion_Prodigy_adv.py +1 -1
  6. adv_optm-1.2.dev1/adv_optm/optim/Muon_adv.py +247 -0
  7. {adv_optm-1.1.2 → adv_optm-1.2.dev1}/adv_optm/optim/Prodigy_adv.py +2 -2
  8. {adv_optm-1.1.2 → adv_optm-1.2.dev1}/adv_optm/optim/Simplified_AdEMAMix.py +2 -2
  9. {adv_optm-1.1.2 → adv_optm-1.2.dev1}/adv_optm/optim/__init__.py +2 -0
  10. adv_optm-1.2.dev1/adv_optm/util/Newton_Schulz.py +48 -0
  11. {adv_optm-1.1.2 → adv_optm-1.2.dev1}/adv_optm/util/__init__.py +2 -1
  12. {adv_optm-1.1.2 → adv_optm-1.2.dev1}/adv_optm.egg-info/PKG-INFO +1 -1
  13. {adv_optm-1.1.2 → adv_optm-1.2.dev1}/adv_optm.egg-info/SOURCES.txt +2 -0
  14. {adv_optm-1.1.2 → adv_optm-1.2.dev1}/setup.py +1 -1
  15. {adv_optm-1.1.2 → adv_optm-1.2.dev1}/LICENSE +0 -0
  16. {adv_optm-1.1.2 → adv_optm-1.2.dev1}/README.md +0 -0
  17. {adv_optm-1.1.2 → adv_optm-1.2.dev1}/adv_optm/optim/Lion_adv.py +0 -0
  18. {adv_optm-1.1.2 → adv_optm-1.2.dev1}/adv_optm/util/BF16_Stochastic_Rounding.py +0 -0
  19. {adv_optm-1.1.2 → adv_optm-1.2.dev1}/adv_optm/util/Effective_Shape.py +0 -0
  20. {adv_optm-1.1.2 → adv_optm-1.2.dev1}/adv_optm/util/Kourkoutas.py +0 -0
  21. {adv_optm-1.1.2 → adv_optm-1.2.dev1}/adv_optm/util/NNMF.py +0 -0
  22. {adv_optm-1.1.2 → adv_optm-1.2.dev1}/adv_optm/util/One_Bit_Boolean.py +0 -0
  23. {adv_optm-1.1.2 → adv_optm-1.2.dev1}/adv_optm/util/OrthoGrad.py +0 -0
  24. {adv_optm-1.1.2 → adv_optm-1.2.dev1}/adv_optm.egg-info/dependency_links.txt +0 -0
  25. {adv_optm-1.1.2 → adv_optm-1.2.dev1}/adv_optm.egg-info/requires.txt +0 -0
  26. {adv_optm-1.1.2 → adv_optm-1.2.dev1}/adv_optm.egg-info/top_level.txt +0 -0
  27. {adv_optm-1.1.2 → adv_optm-1.2.dev1}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: adv_optm
- Version: 1.1.2
+ Version: 1.2.dev1
  Summary: A family of highly efficient, lightweight yet powerful optimizers.
  Home-page: https://github.com/Koratahiu/Advanced_Optimizers
  Author: Koratahiu
@@ -5,6 +5,7 @@ from .optim import (
  Simplified_AdEMAMix,
  Lion_adv,
  Lion_Prodigy_adv,
+ Muon_adv,
  )

  __all__ = [
@@ -14,6 +15,7 @@ __all__ = [
  "Simplified_AdEMAMix",
  "Lion_adv",
  "Lion_Prodigy_adv",
+ "Muon_adv",
  ]

- __version__ = "1.1.2"
+ __version__ = "1.2.dev1"
@@ -209,7 +209,7 @@ class AdamW_adv(torch.optim.Optimizer):
  beta1, beta2 = group['betas']

  current_step = state['step']
- if group['kourkoutas_beta']:
+ if group.get('kourkoutas_beta', False):
  # Call prepare_step() once at the beginning of the step for all params
  self.kourkoutas_helper.maybe_prepare_step(current_step)
  # Accumulate current grad's norm for the *next* step
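Note on the repeated change above and in several hunks below: replacing group['kourkoutas_beta'] with group.get('kourkoutas_beta', False) (and likewise g_group['d_limiter']) makes the option lookup tolerant of param groups that do not carry the key, presumably for backward compatibility with configs created before the option existed. A tiny illustrative sketch (hypothetical dict, not from the package):

group_without_key = {"lr": 1e-3}  # a param group that predates the option

# group_without_key["kourkoutas_beta"]  # would raise KeyError
print(group_without_key.get("kourkoutas_beta", False))  # prints False: feature treated as disabled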
@@ -220,7 +220,7 @@ class AdamW_adv(torch.optim.Optimizer):
  step = state['step'] + 1
  if group['use_bias_correction']:
  bias_correction1 = 1.0 - beta1 ** step
- if group['kourkoutas_beta']:
+ if group.get('kourkoutas_beta', False):
  bias_correction2 = 1.0 - group['betas'][1] ** step
  # Use beta2_max for bias correction
  else:
@@ -240,7 +240,7 @@ class Adopt_adv(torch.optim.Optimizer):
  beta1, beta2 = group['betas']

  current_step = state['step']
- if group['kourkoutas_beta']:
+ if group.get('kourkoutas_beta', False):
  # Call prepare_step() once at the beginning of the step for all params
  self.kourkoutas_helper.maybe_prepare_step(current_step)
  # Accumulate current grad's norm for the *next* step
@@ -325,7 +325,7 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
  d_hat = self.d
  if global_d_denom > 0:
  d_hat = d_coef * global_d_numerator / global_d_denom
- if g_group['d_limiter']:
+ if g_group.get('d_limiter', False):
  d_hat = min(self.d * (2 ** 0.25), d_hat)
  if self.d == g_group['d0']:
  self.d = max(self.d, d_hat)
@@ -0,0 +1,247 @@
+ import torch
+ from typing import Optional
+
+ from ..util.BF16_Stochastic_Rounding import add_stochastic_
+ from ..util.Newton_Schulz import _newton_schulz_iteration
+ from ..util.Effective_Shape import _get_effective_shape
+ from ..util.NNMF import _nnmf,_unnmf
+ from ..util.One_Bit_Boolean import _pack_bools, _unpack_bools
+
+ class Muon_adv(torch.optim.Optimizer):
+ """
+ Implements an advanced Muon algorithm.
+
+ Muon (MomentUm Orthogonalized by Newton-Schulz) is an optimizer designed for
+ the hidden layers of neural networks. It applies SGD with momentum and then
+ orthogonalizes the resulting update matrix using a Newton-Schulz iteration.
+
+ This implementation is designed for 2D parameters (e.g., linear layers) and
+ can handle other-dimensional parameters (e.g., 1D bias, 4D convolutional layers) by
+ flattening/reshaping them.
+
+ Args:
+ params (iterable): iterable of parameters to optimize or dicts defining
+ parameter groups.
+ lr (float): learning rate (default: 1e-3).
+ beta1 (float): momentum factor (default: 0.9).
+ weight_decay (float): weight decay (L2 penalty) (default: 0).
+ nesterov (bool): enables Nesterov momentum (default: True).
+ ns_steps (int): number of Newton-Schulz iterations to perform (default: 5).
+ ns_eps (float): epsilon for Newton-Schulz normalization stability (default: 1e-7).
+ ns_coeffs (tuple[float, float, float]): The (a, b, c) coefficients for the
+ quintic polynomial in the Newton-Schulz iteration.
+ (default: (3.4445, -4.7750, 2.0315)).
+ stochastic_rounding (bool): whether to use stochastic rounding for
+ BF16 parameter updates (default: True).
+ vector_reshape_muon (bool): whether to reshape 1D vectors into 2D
+ matrices for muon NewtonSchulz (default: False).
+ vector_reshape (bool): whether to reshape 1D vectors into 2D
+ matrices to apply low-rank compression (default: True).
+ nnmf_factor (bool): whether to use the factorization or disable it to use
+ the uncompressed optimizer. (default: False)
+ """
+
+ def __init__(
+ self,
+ params,
+ lr: float = 1e-3,
+ beta1: float = 0.9,
+ weight_decay: float = 0.0,
+ nesterov: bool = True,
+ ns_steps: int = 5,
+ ns_eps: float = 1e-7,
+ ns_coeffs: tuple[float, float, float] = (3.4445, -4.7750, 2.0315),
+ stochastic_rounding: bool = True,
+ vector_reshape_muon: bool = False,
+ vector_reshape: bool = True,
+ nnmf_factor: bool = False,
+ ):
+ if not (lr >= 0.0):
+ raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
+ if not (0.0 <= beta1 < 1.0):
+ raise ValueError(f"beta1 should be in [0.0, 1.0). Got {beta1}")
+ if not (weight_decay >= 0.0):
+ raise ValueError(f"Weight-decay should be >= 0.0. Got {weight_decay}")
+ if not (ns_steps > 0):
+ raise ValueError(f"Newton-Schulz steps should be > 0. Got {ns_steps}")
+
+ defaults = {
+ "lr": lr, "beta1": beta1, "weight_decay": weight_decay,
+ "nesterov": nesterov, "ns_steps": ns_steps, "ns_eps": ns_eps,
+ "ns_coeffs": ns_coeffs, "nnmf_factor": nnmf_factor,
+ "vector_reshape": vector_reshape,
+ "vector_reshape_muon": vector_reshape_muon,
+ }
+ self.stochastic_rounding = stochastic_rounding
+ super().__init__(params, defaults)
+
+ @property
+ def supports_fused_back_pass(self):
+ return True
+
+ @property
+ def supports_memory_efficient_fp16(self):
+ return True
+
+ @property
+ def supports_flat_params(self):
+ return False
+
+ @torch.no_grad()
+ def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
+ if p.grad is None:
+ return
+
+ grad = p.grad
+ state = self.state[p]
+
+ # State Initialization
+ if 'step' not in state:
+ state['step'] = 0
+
+ should_factor = (
+ group['nnmf_factor'] and
+ not (len(p.shape) == 1 and not group['vector_reshape'])
+ )
+
+ state['factored'] = should_factor
+
+ state['reshaped_1d_muon'] = len(p.shape) == 1 and group['vector_reshape_muon']
+
+ dtype = torch.float32 if group['nnmf_factor'] else p.dtype
+ device = p.device
+ if group['vector_reshape'] or state['reshaped_1d_muon']:
+ state['effective_shape'] = _get_effective_shape(p.numel())
+ d1, d2 = state['effective_shape']
+ if state['factored']:
+ state['mu_m_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
+ state['mv_m_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
+ packed_d2 = (d2 + 7) // 8
+ state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=device)
+ else:
+ if len(p.shape) >= 2:
+ state['momentum_buffer'] = torch.zeros_like(p)
+ if state['reshaped_1d_muon']:
+ state['momentum_buffer'] = torch.zeros((d1, d2), device=device, dtype=dtype)
+ elif len(p.shape) == 1:
+ state['momentum_buffer'] = torch.zeros_like(p)
+
+ beta1 = group['beta1']
+ nesterov = group['nesterov']
+
+ if state['factored']: # Factored Muon
+
+ # Reconstruct momentum from previous step's factors & sign
+ d1, d2 = state['effective_shape']
+ mt_buf = _unnmf((state['mu_m_nmf'], state['mv_m_nmf']))
+ unpacked_sign = _unpack_bools(state['sign'], original_m=d2)
+ torch.where(unpacked_sign, mt_buf, -mt_buf, out=mt_buf)
+ del unpacked_sign
+
+ # Update momentum in full-size
+ grad_reshaped = grad.view(d1, d2)
+ mt_buf.mul_(beta1).add_(grad_reshaped)
+
+ if nesterov:
+ # Nesterov momentum
+ update = grad_reshaped.add(mt_buf, alpha=beta1)
+ else:
+ # Standard momentum
+ update = mt_buf.clone()
+ del grad_reshaped
+
+ update = _newton_schulz_iteration(
+ update,
+ steps=group['ns_steps'],
+ eps=group['ns_eps'],
+ coeffs=group['ns_coeffs'],
+ )
+
+ update = update.view(p.shape).mul_(group['lr'])
+
+ state['sign'] = _pack_bools(mt_buf > 0)
+ _nnmf(mt_buf.abs(), out=(state['mu_m_nmf'], state['mv_m_nmf']))
+ del mt_buf
+
+ else: # Standard Muon logic for non-factored tensors
+
+ if len(p.shape) >= 2 or state['reshaped_1d_muon']:
+
+ # Momentum update
+ mt_buf = state['momentum_buffer']
+ if state['reshaped_1d_muon']:
+ d1, d2 = state['effective_shape']
+ grad_reshaped = grad.view(d1, d2)
+ mt_buf.mul_(beta1).add_(grad_reshaped)
+ else:
+ mt_buf.mul_(beta1).add_(grad)
+
+ if nesterov:
+ # Nesterov momentum
+ if state['reshaped_1d_muon']:
+ update = grad_reshaped.add(mt_buf, alpha=beta1)
+ del grad_reshaped
+ else:
+ update = grad.add(mt_buf, alpha=beta1)
+ else:
+ # Standard momentum
+ update = mt_buf.clone()
+
+ # For Conv layers (4D) or other high-dim tensors, flatten to 2D
+ if len(p.shape) > 2:
+ update = update.view(p.shape[0], -1)
+
+ # NewtonSchulz
+ update = _newton_schulz_iteration(
+ update,
+ steps=group['ns_steps'],
+ eps=group['ns_eps'],
+ coeffs=group['ns_coeffs'],
+ )
+
+ # Reshape back to original if we flattened or reshaped
+ if len(p.shape) > 2 or state['reshaped_1d_muon']:
+ update = update.view(p.shape)
+
+ update.mul_(group['lr'])
+
+ else: # Fallback to standard SGD with momentum for 1D params (biases, etc.) when not reshaped
+ # Momentum update
+ mt_buf = state['momentum_buffer']
+ mt_buf.mul_(beta1).add_(grad)
+ if nesterov:
+ # Nesterov momentum
+ update = grad.add(mt_buf, alpha=beta1)
+ else:
+ # Standard momentum
+ update = mt_buf.clone()
+ update.mul_(group['lr'])
+
+ # Decoupled weight decay
+ if group["weight_decay"] != 0:
+ if p.dtype == torch.bfloat16 and self.stochastic_rounding:
+ add_stochastic_(p.data, p.data, alpha=-group["weight_decay"] * group["lr"])
+ else:
+ p.data.add_(p.data, alpha=-group["weight_decay"] * group["lr"])
+
+ if p.dtype == torch.bfloat16 and self.stochastic_rounding:
+ add_stochastic_(p.data, -update)
+ else:
+ p.data.add_(-update)
+ del update
+
+ state['step'] += 1
+
+ @torch.no_grad()
+ def step(self, closure=None):
+ """Performs a single optimization step."""
+ loss = None
+ if closure is not None:
+ with torch.enable_grad():
+ loss = closure()
+
+ for group in self.param_groups:
+ for i, p in enumerate(group['params']):
+ self.step_parameter(p, group, i)
+
+ return loss
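For orientation, a minimal usage sketch of the new optimizer (not taken from the package's docs; the model and hyperparameter values are illustrative). The constructor arguments and the exported name match the Muon_adv code and __init__.py changes in this diff:

import torch
from adv_optm import Muon_adv  # exported by adv_optm/__init__.py as of 1.2.dev1

model = torch.nn.Linear(64, 32)
opt = Muon_adv(model.parameters(), lr=1e-3, beta1=0.9, nesterov=True, ns_steps=5)

x = torch.randn(16, 64)
loss = model(x).pow(2).mean()
loss.backward()
opt.step()       # 2D weight gets the orthogonalized (Newton-Schulz) momentum update;
opt.zero_grad()  # the 1D bias falls back to SGD with momentum (vector_reshape_muon=False)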
@@ -304,7 +304,7 @@ class Prodigy_adv(torch.optim.Optimizer):
  state['p0'] = torch.tensor(0, device=device, dtype=p.dtype)

  current_step = state['step']
- if group['kourkoutas_beta']:
+ if group.get('kourkoutas_beta', False):
  # Call prepare_step() once at the beginning of the step for all params
  self.kourkoutas_helper.maybe_prepare_step(current_step)
  # Accumulate current grad's norm for the *next* step
@@ -515,7 +515,7 @@ class Prodigy_adv(torch.optim.Optimizer):
  d_hat = self.d
  if global_d_denom > 0:
  d_hat = d_coef * global_d_numerator / global_d_denom
- if g_group['d_limiter']:
+ if g_group.get('d_limiter', False):
  d_hat = min(self.d * (2 ** 0.25), d_hat)
  if self.d == g_group['d0']:
  self.d = max(self.d, d_hat)
@@ -191,7 +191,7 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
  beta1_final, beta2 = group["betas"]

  current_step = state['step']
- if group['kourkoutas_beta']:
+ if group.get('kourkoutas_beta', False):
  # Call prepare_step() once at the beginning of the step for all params
  self.kourkoutas_helper.maybe_prepare_step(current_step)
  # Accumulate current grad's norm for the *next* step
@@ -210,7 +210,7 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):

  if group['use_bias_correction']:
  state['num_sum'] = beta1 * state['num_sum'] + 1.0
- if group['kourkoutas_beta']:
+ if group.get('kourkoutas_beta', False):
  state['den_sum'] = group['betas'][1] * state['den_sum'] + (1.0 - group['betas'][1])
  else:
  state['den_sum'] = beta2 * state['den_sum'] + (1.0 - beta2)
@@ -4,6 +4,7 @@ from .Adopt_adv import Adopt_adv
  from .Simplified_AdEMAMix import Simplified_AdEMAMix
  from .Lion_adv import Lion_adv
  from .Lion_Prodigy_adv import Lion_Prodigy_adv
+ from .Muon_adv import Muon_adv

  __all__ = [
  "AdamW_adv",
@@ -12,4 +13,5 @@ __all__ = [
  "Simplified_AdEMAMix",
  "Lion_adv",
  "Lion_Prodigy_adv",
+ "Muon_adv",
  ]
@@ -0,0 +1,48 @@
+ import torch
+
+ @torch.no_grad()
+ def _newton_schulz_iteration(
+ G: torch.Tensor,
+ steps: int = 5,
+ eps: float = 1e-7,
+ coeffs: tuple[float, float, float] = (3.4445, -4.7750, 2.0315)
+ ) -> torch.Tensor:
+ """
+ Performs the Newton-Schulz iteration to find the nearest orthogonal matrix.
+ This is the core computation of the Muon optimizer.
+
+ Args:
+ G (torch.Tensor): The 2D input matrix (momentum-accumulated gradient).
+ steps (int): The number of iterations to run.
+ eps (float): Small constant for numerical stability during normalization.
+ coeffs (tuple[float, float, float]): The (a, b, c) coefficients for the
+ quintic polynomial update.
+
+ Returns:
+ torch.Tensor: The orthogonalized matrix.
+ """
+ assert G.ndim == 2, "Newton-Schulz iteration only supports 2D matrices."
+
+ a, b, c = coeffs
+
+ X = G.to(torch.bfloat16)
+
+ # Normalize the matrix
+ X.div_(X.norm() + eps)
+
+ # Handle non-square matrices by transposing the taller one
+ transposed = G.size(0) > G.size(1)
+ if transposed:
+ X = X.T
+
+ # Perform the iterative updates
+ for _ in range(steps):
+ A = X @ X.T
+ B = b * A + c * (A @ A)
+ X = a * X + B @ X
+
+ # Transpose back if necessary
+ if transposed:
+ X = X.T
+
+ return X.to(G.dtype)
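As a quick sanity check on the helper above (a sketch, not part of the package's tests), the iteration should return a matrix whose singular values are pushed toward 1, i.e. an approximately semi-orthogonal version of the input; the approximation is loose because the computation runs in bfloat16 with the quintic coefficients shown:

import torch
from adv_optm.util import _newton_schulz_iteration  # re-exported by util/__init__.py below

G = torch.randn(128, 256)                  # wide 2D matrix, as Muon would pass in
X = _newton_schulz_iteration(G, steps=5)

print(torch.linalg.svdvals(X.float()))     # singular values clustered around ~1
print(X.shape, X.dtype)                    # same shape and dtype as the input G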
@@ -2,10 +2,11 @@ from .BF16_Stochastic_Rounding import add_stochastic_
  from .Effective_Shape import _get_effective_shape
  from .One_Bit_Boolean import _pack_bools, _unpack_bools
  from .OrthoGrad import _orthogonalize_gradient
-
+ from .Newton_Schulz import _newton_schulz_iteration
  __all__ = [
  "_pack_bools", "_unpack_bools",
  "add_stochastic_",
  "_get_effective_shape",
  "_orthogonalize_gradient",
+ "_newton_schulz_iteration",
  ]
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: adv_optm
- Version: 1.1.2
+ Version: 1.2.dev1
  Summary: A family of highly efficient, lightweight yet powerful optimizers.
  Home-page: https://github.com/Koratahiu/Advanced_Optimizers
  Author: Koratahiu
@@ -11,6 +11,7 @@ adv_optm/optim/AdamW_adv.py
  adv_optm/optim/Adopt_adv.py
  adv_optm/optim/Lion_Prodigy_adv.py
  adv_optm/optim/Lion_adv.py
+ adv_optm/optim/Muon_adv.py
  adv_optm/optim/Prodigy_adv.py
  adv_optm/optim/Simplified_AdEMAMix.py
  adv_optm/optim/__init__.py
@@ -18,6 +19,7 @@ adv_optm/util/BF16_Stochastic_Rounding.py
  adv_optm/util/Effective_Shape.py
  adv_optm/util/Kourkoutas.py
  adv_optm/util/NNMF.py
+ adv_optm/util/Newton_Schulz.py
  adv_optm/util/One_Bit_Boolean.py
  adv_optm/util/OrthoGrad.py
  adv_optm/util/__init__.py
@@ -5,7 +5,7 @@ with open("README.md", "r", encoding="utf-8") as fh:

  setup(
  name="adv_optm",
- version="1.1.2",
+ version="1.2.dev1",
  author="Koratahiu",
  author_email="hiuhonor@gmail.com",
  license='Apache 2.0',