adv-optm 1.1.3__tar.gz → 1.2.dev1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of adv-optm might be problematic.

Files changed (27)
  1. {adv_optm-1.1.3 → adv_optm-1.2.dev1}/PKG-INFO +1 -1
  2. {adv_optm-1.1.3 → adv_optm-1.2.dev1}/adv_optm/__init__.py +3 -1
  3. adv_optm-1.2.dev1/adv_optm/optim/Muon_adv.py +247 -0
  4. {adv_optm-1.1.3 → adv_optm-1.2.dev1}/adv_optm/optim/__init__.py +2 -0
  5. adv_optm-1.2.dev1/adv_optm/util/Newton_Schulz.py +48 -0
  6. {adv_optm-1.1.3 → adv_optm-1.2.dev1}/adv_optm/util/__init__.py +2 -1
  7. {adv_optm-1.1.3 → adv_optm-1.2.dev1}/adv_optm.egg-info/PKG-INFO +1 -1
  8. {adv_optm-1.1.3 → adv_optm-1.2.dev1}/adv_optm.egg-info/SOURCES.txt +2 -0
  9. {adv_optm-1.1.3 → adv_optm-1.2.dev1}/setup.py +1 -1
  10. {adv_optm-1.1.3 → adv_optm-1.2.dev1}/LICENSE +0 -0
  11. {adv_optm-1.1.3 → adv_optm-1.2.dev1}/README.md +0 -0
  12. {adv_optm-1.1.3 → adv_optm-1.2.dev1}/adv_optm/optim/AdamW_adv.py +0 -0
  13. {adv_optm-1.1.3 → adv_optm-1.2.dev1}/adv_optm/optim/Adopt_adv.py +0 -0
  14. {adv_optm-1.1.3 → adv_optm-1.2.dev1}/adv_optm/optim/Lion_Prodigy_adv.py +0 -0
  15. {adv_optm-1.1.3 → adv_optm-1.2.dev1}/adv_optm/optim/Lion_adv.py +0 -0
  16. {adv_optm-1.1.3 → adv_optm-1.2.dev1}/adv_optm/optim/Prodigy_adv.py +0 -0
  17. {adv_optm-1.1.3 → adv_optm-1.2.dev1}/adv_optm/optim/Simplified_AdEMAMix.py +0 -0
  18. {adv_optm-1.1.3 → adv_optm-1.2.dev1}/adv_optm/util/BF16_Stochastic_Rounding.py +0 -0
  19. {adv_optm-1.1.3 → adv_optm-1.2.dev1}/adv_optm/util/Effective_Shape.py +0 -0
  20. {adv_optm-1.1.3 → adv_optm-1.2.dev1}/adv_optm/util/Kourkoutas.py +0 -0
  21. {adv_optm-1.1.3 → adv_optm-1.2.dev1}/adv_optm/util/NNMF.py +0 -0
  22. {adv_optm-1.1.3 → adv_optm-1.2.dev1}/adv_optm/util/One_Bit_Boolean.py +0 -0
  23. {adv_optm-1.1.3 → adv_optm-1.2.dev1}/adv_optm/util/OrthoGrad.py +0 -0
  24. {adv_optm-1.1.3 → adv_optm-1.2.dev1}/adv_optm.egg-info/dependency_links.txt +0 -0
  25. {adv_optm-1.1.3 → adv_optm-1.2.dev1}/adv_optm.egg-info/requires.txt +0 -0
  26. {adv_optm-1.1.3 → adv_optm-1.2.dev1}/adv_optm.egg-info/top_level.txt +0 -0
  27. {adv_optm-1.1.3 → adv_optm-1.2.dev1}/setup.cfg +0 -0
{adv_optm-1.1.3 → adv_optm-1.2.dev1}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: adv_optm
- Version: 1.1.3
+ Version: 1.2.dev1
  Summary: A family of highly efficient, lightweight yet powerful optimizers.
  Home-page: https://github.com/Koratahiu/Advanced_Optimizers
  Author: Koratahiu
{adv_optm-1.1.3 → adv_optm-1.2.dev1}/adv_optm/__init__.py
@@ -5,6 +5,7 @@ from .optim import (
      Simplified_AdEMAMix,
      Lion_adv,
      Lion_Prodigy_adv,
+     Muon_adv,
  )

  __all__ = [
@@ -14,6 +15,7 @@ __all__ = [
      "Simplified_AdEMAMix",
      "Lion_adv",
      "Lion_Prodigy_adv",
+     "Muon_adv",
  ]

- __version__ = "1.1.3"
+ __version__ = "1.2.dev1"
adv_optm-1.2.dev1/adv_optm/optim/Muon_adv.py
@@ -0,0 +1,247 @@
+ import torch
+ from typing import Optional
+
+ from ..util.BF16_Stochastic_Rounding import add_stochastic_
+ from ..util.Newton_Schulz import _newton_schulz_iteration
+ from ..util.Effective_Shape import _get_effective_shape
+ from ..util.NNMF import _nnmf, _unnmf
+ from ..util.One_Bit_Boolean import _pack_bools, _unpack_bools
+
+ class Muon_adv(torch.optim.Optimizer):
+     """
+     Implements an advanced Muon algorithm.
+
+     Muon (MomentUm Orthogonalized by Newton-Schulz) is an optimizer designed for
+     the hidden layers of neural networks. It applies SGD with momentum and then
+     orthogonalizes the resulting update matrix using a Newton-Schulz iteration.
+
+     This implementation is designed for 2D parameters (e.g., linear layers) and
+     can handle parameters of other dimensions (e.g., 1D biases, 4D convolution
+     weights) by flattening/reshaping them.
+
+     Args:
+         params (iterable): iterable of parameters to optimize or dicts defining
+             parameter groups.
+         lr (float): learning rate (default: 1e-3).
+         beta1 (float): momentum factor (default: 0.9).
+         weight_decay (float): weight decay (L2 penalty) (default: 0).
+         nesterov (bool): enables Nesterov momentum (default: True).
+         ns_steps (int): number of Newton-Schulz iterations to perform (default: 5).
+         ns_eps (float): epsilon for Newton-Schulz normalization stability (default: 1e-7).
+         ns_coeffs (tuple[float, float, float]): the (a, b, c) coefficients for the
+             quintic polynomial in the Newton-Schulz iteration
+             (default: (3.4445, -4.7750, 2.0315)).
+         stochastic_rounding (bool): whether to use stochastic rounding for
+             BF16 parameter updates (default: True).
+         vector_reshape_muon (bool): whether to reshape 1D vectors into 2D
+             matrices for the Muon Newton-Schulz step (default: False).
+         vector_reshape (bool): whether to reshape 1D vectors into 2D
+             matrices to apply low-rank compression (default: True).
+         nnmf_factor (bool): whether to use NNMF factorization for the momentum
+             state; disable it to run the uncompressed optimizer (default: False).
+     """
+
+     def __init__(
+         self,
+         params,
+         lr: float = 1e-3,
+         beta1: float = 0.9,
+         weight_decay: float = 0.0,
+         nesterov: bool = True,
+         ns_steps: int = 5,
+         ns_eps: float = 1e-7,
+         ns_coeffs: tuple[float, float, float] = (3.4445, -4.7750, 2.0315),
+         stochastic_rounding: bool = True,
+         vector_reshape_muon: bool = False,
+         vector_reshape: bool = True,
+         nnmf_factor: bool = False,
+     ):
+         if not (lr >= 0.0):
+             raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
+         if not (0.0 <= beta1 < 1.0):
+             raise ValueError(f"beta1 should be in [0.0, 1.0). Got {beta1}")
+         if not (weight_decay >= 0.0):
+             raise ValueError(f"Weight-decay should be >= 0.0. Got {weight_decay}")
+         if not (ns_steps > 0):
+             raise ValueError(f"Newton-Schulz steps should be > 0. Got {ns_steps}")
+
+         defaults = {
+             "lr": lr, "beta1": beta1, "weight_decay": weight_decay,
+             "nesterov": nesterov, "ns_steps": ns_steps, "ns_eps": ns_eps,
+             "ns_coeffs": ns_coeffs, "nnmf_factor": nnmf_factor,
+             "vector_reshape": vector_reshape,
+             "vector_reshape_muon": vector_reshape_muon,
+         }
+         self.stochastic_rounding = stochastic_rounding
+         super().__init__(params, defaults)
+
+     @property
+     def supports_fused_back_pass(self):
+         return True
+
+     @property
+     def supports_memory_efficient_fp16(self):
+         return True
+
+     @property
+     def supports_flat_params(self):
+         return False
+
+     @torch.no_grad()
+     def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
+         if p.grad is None:
+             return
+
+         grad = p.grad
+         state = self.state[p]
+
+         # State initialization
+         if 'step' not in state:
+             state['step'] = 0
+
+             should_factor = (
+                 group['nnmf_factor'] and
+                 not (len(p.shape) == 1 and not group['vector_reshape'])
+             )
+             state['factored'] = should_factor
+
+             state['reshaped_1d_muon'] = len(p.shape) == 1 and group['vector_reshape_muon']
+
+             dtype = torch.float32 if group['nnmf_factor'] else p.dtype
+             device = p.device
+             if group['vector_reshape'] or state['reshaped_1d_muon']:
+                 state['effective_shape'] = _get_effective_shape(p.numel())
+                 d1, d2 = state['effective_shape']
+             if state['factored']:
+                 state['mu_m_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
+                 state['mv_m_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
+                 packed_d2 = (d2 + 7) // 8
+                 state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=device)
+             else:
+                 if len(p.shape) >= 2:
+                     state['momentum_buffer'] = torch.zeros_like(p)
+                 if state['reshaped_1d_muon']:
+                     state['momentum_buffer'] = torch.zeros((d1, d2), device=device, dtype=dtype)
+                 elif len(p.shape) == 1:
+                     state['momentum_buffer'] = torch.zeros_like(p)
+
+         beta1 = group['beta1']
+         nesterov = group['nesterov']
+
+         if state['factored']:  # Factored Muon
+
+             # Reconstruct momentum from the previous step's factors & sign
+             d1, d2 = state['effective_shape']
+             mt_buf = _unnmf((state['mu_m_nmf'], state['mv_m_nmf']))
+             unpacked_sign = _unpack_bools(state['sign'], original_m=d2)
+             torch.where(unpacked_sign, mt_buf, -mt_buf, out=mt_buf)
+             del unpacked_sign
+
+             # Update momentum at full size
+             grad_reshaped = grad.view(d1, d2)
+             mt_buf.mul_(beta1).add_(grad_reshaped)
+
+             if nesterov:
+                 # Nesterov momentum
+                 update = grad_reshaped.add(mt_buf, alpha=beta1)
+             else:
+                 # Standard momentum
+                 update = mt_buf.clone()
+             del grad_reshaped
+
+             update = _newton_schulz_iteration(
+                 update,
+                 steps=group['ns_steps'],
+                 eps=group['ns_eps'],
+                 coeffs=group['ns_coeffs'],
+             )
+
+             update = update.view(p.shape).mul_(group['lr'])
+
+             state['sign'] = _pack_bools(mt_buf > 0)
+             _nnmf(mt_buf.abs(), out=(state['mu_m_nmf'], state['mv_m_nmf']))
+             del mt_buf
+
+         else:  # Standard Muon logic for non-factored tensors
+
+             if len(p.shape) >= 2 or state['reshaped_1d_muon']:
+
+                 # Momentum update
+                 mt_buf = state['momentum_buffer']
+                 if state['reshaped_1d_muon']:
+                     d1, d2 = state['effective_shape']
+                     grad_reshaped = grad.view(d1, d2)
+                     mt_buf.mul_(beta1).add_(grad_reshaped)
+                 else:
+                     mt_buf.mul_(beta1).add_(grad)
+
+                 if nesterov:
+                     # Nesterov momentum
+                     if state['reshaped_1d_muon']:
+                         update = grad_reshaped.add(mt_buf, alpha=beta1)
+                         del grad_reshaped
+                     else:
+                         update = grad.add(mt_buf, alpha=beta1)
+                 else:
+                     # Standard momentum
+                     update = mt_buf.clone()
+
+                 # For conv layers (4D) or other high-dim tensors, flatten to 2D
+                 if len(p.shape) > 2:
+                     update = update.view(p.shape[0], -1)
+
+                 # Newton-Schulz orthogonalization
+                 update = _newton_schulz_iteration(
+                     update,
+                     steps=group['ns_steps'],
+                     eps=group['ns_eps'],
+                     coeffs=group['ns_coeffs'],
+                 )
+
+                 # Reshape back to the original shape if we flattened or reshaped
+                 if len(p.shape) > 2 or state['reshaped_1d_muon']:
+                     update = update.view(p.shape)
+
+                 update.mul_(group['lr'])
+
+             else:  # Fall back to standard SGD with momentum for 1D params (biases, etc.) when not reshaped
+                 # Momentum update
+                 mt_buf = state['momentum_buffer']
+                 mt_buf.mul_(beta1).add_(grad)
+                 if nesterov:
+                     # Nesterov momentum
+                     update = grad.add(mt_buf, alpha=beta1)
+                 else:
+                     # Standard momentum
+                     update = mt_buf.clone()
+                 update.mul_(group['lr'])
+
+         # Decoupled weight decay
+         if group["weight_decay"] != 0:
+             if p.dtype == torch.bfloat16 and self.stochastic_rounding:
+                 add_stochastic_(p.data, p.data, alpha=-group["weight_decay"] * group["lr"])
+             else:
+                 p.data.add_(p.data, alpha=-group["weight_decay"] * group["lr"])
+
+         if p.dtype == torch.bfloat16 and self.stochastic_rounding:
+             add_stochastic_(p.data, -update)
+         else:
+             p.data.add_(-update)
+         del update
+
+         state['step'] += 1
+
+     @torch.no_grad()
+     def step(self, closure=None):
+         """Performs a single optimization step."""
+         loss = None
+         if closure is not None:
+             with torch.enable_grad():
+                 loss = closure()
+
+         for group in self.param_groups:
+             for i, p in enumerate(group['params']):
+                 self.step_parameter(p, group, i)
+
+         return loss
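
For orientation, here is a minimal usage sketch of the new optimizer. It assumes the package is installed as adv_optm and exposes Muon_adv as the __init__.py changes above show; the toy model and hyperparameter values are illustrative assumptions, not recommendations from the package.

import torch
import torch.nn as nn
from adv_optm import Muon_adv

# Toy model: Muon orthogonalizes the 2D weight updates; 1D biases fall back
# to plain SGD-with-momentum unless vector_reshape_muon=True.
model = nn.Sequential(nn.Linear(64, 128), nn.ReLU(), nn.Linear(128, 10))

opt = Muon_adv(
    model.parameters(),
    lr=1e-3,
    beta1=0.9,
    weight_decay=0.01,
    nesterov=True,
    ns_steps=5,          # Newton-Schulz iterations per update
    nnmf_factor=False,   # True would store momentum as NNMF factors + packed signs
)

x = torch.randn(32, 64)
y = torch.randint(0, 10, (32,))
loss = nn.functional.cross_entropy(model(x), y)
loss.backward()
opt.step()
opt.zero_grad()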
{adv_optm-1.1.3 → adv_optm-1.2.dev1}/adv_optm/optim/__init__.py
@@ -4,6 +4,7 @@ from .Adopt_adv import Adopt_adv
  from .Simplified_AdEMAMix import Simplified_AdEMAMix
  from .Lion_adv import Lion_adv
  from .Lion_Prodigy_adv import Lion_Prodigy_adv
+ from .Muon_adv import Muon_adv

  __all__ = [
      "AdamW_adv",
@@ -12,4 +13,5 @@ __all__ = [
      "Simplified_AdEMAMix",
      "Lion_adv",
      "Lion_Prodigy_adv",
+     "Muon_adv",
  ]
adv_optm-1.2.dev1/adv_optm/util/Newton_Schulz.py
@@ -0,0 +1,48 @@
+ import torch
+
+ @torch.no_grad()
+ def _newton_schulz_iteration(
+     G: torch.Tensor,
+     steps: int = 5,
+     eps: float = 1e-7,
+     coeffs: tuple[float, float, float] = (3.4445, -4.7750, 2.0315)
+ ) -> torch.Tensor:
+     """
+     Performs the Newton-Schulz iteration to find the nearest orthogonal matrix.
+     This is the core computation of the Muon optimizer.
+
+     Args:
+         G (torch.Tensor): The 2D input matrix (momentum-accumulated gradient).
+         steps (int): The number of iterations to run.
+         eps (float): Small constant for numerical stability during normalization.
+         coeffs (tuple[float, float, float]): The (a, b, c) coefficients for the
+             quintic polynomial update.
+
+     Returns:
+         torch.Tensor: The orthogonalized matrix.
+     """
+     assert G.ndim == 2, "Newton-Schulz iteration only supports 2D matrices."
+
+     a, b, c = coeffs
+
+     X = G.to(torch.bfloat16)
+
+     # Normalize the matrix
+     X.div_(X.norm() + eps)
+
+     # Handle non-square matrices by transposing the taller one
+     transposed = G.size(0) > G.size(1)
+     if transposed:
+         X = X.T
+
+     # Perform the iterative updates
+     for _ in range(steps):
+         A = X @ X.T
+         B = b * A + c * (A @ A)
+         X = a * X + B @ X
+
+     # Transpose back if necessary
+     if transposed:
+         X = X.T
+
+     return X.to(G.dtype)
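
As a quick sanity check (a sketch, assuming the import path from the util/__init__.py change below): the iteration pushes the singular values of its input toward 1, so the Gram matrix of the result should be close to the identity. The tolerance is deliberately loose, since these quintic coefficients trade exact convergence for speed and the iteration itself runs in bfloat16.

import torch
from adv_optm.util import _newton_schulz_iteration

# Orthogonalize a random 16x64 update matrix; rows of the result should be
# roughly orthonormal, so X @ X.T is approximately the 16x16 identity.
G = torch.randn(16, 64)
X = _newton_schulz_iteration(G, steps=5)

err = (X @ X.T - torch.eye(16)).abs().max().item()
print(f"max |X X^T - I| = {err:.2f}")  # loosely small (tenths, not machine eps)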
{adv_optm-1.1.3 → adv_optm-1.2.dev1}/adv_optm/util/__init__.py
@@ -2,10 +2,11 @@ from .BF16_Stochastic_Rounding import add_stochastic_
  from .Effective_Shape import _get_effective_shape
  from .One_Bit_Boolean import _pack_bools, _unpack_bools
  from .OrthoGrad import _orthogonalize_gradient
-
+ from .Newton_Schulz import _newton_schulz_iteration
  __all__ = [
      "_pack_bools", "_unpack_bools",
      "add_stochastic_",
      "_get_effective_shape",
      "_orthogonalize_gradient",
+     "_newton_schulz_iteration",
  ]
{adv_optm-1.1.3 → adv_optm-1.2.dev1}/adv_optm.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: adv_optm
- Version: 1.1.3
+ Version: 1.2.dev1
  Summary: A family of highly efficient, lightweight yet powerful optimizers.
  Home-page: https://github.com/Koratahiu/Advanced_Optimizers
  Author: Koratahiu
{adv_optm-1.1.3 → adv_optm-1.2.dev1}/adv_optm.egg-info/SOURCES.txt
@@ -11,6 +11,7 @@ adv_optm/optim/AdamW_adv.py
  adv_optm/optim/Adopt_adv.py
  adv_optm/optim/Lion_Prodigy_adv.py
  adv_optm/optim/Lion_adv.py
+ adv_optm/optim/Muon_adv.py
  adv_optm/optim/Prodigy_adv.py
  adv_optm/optim/Simplified_AdEMAMix.py
  adv_optm/optim/__init__.py
@@ -18,6 +19,7 @@ adv_optm/util/BF16_Stochastic_Rounding.py
  adv_optm/util/Effective_Shape.py
  adv_optm/util/Kourkoutas.py
  adv_optm/util/NNMF.py
+ adv_optm/util/Newton_Schulz.py
  adv_optm/util/One_Bit_Boolean.py
  adv_optm/util/OrthoGrad.py
  adv_optm/util/__init__.py
{adv_optm-1.1.3 → adv_optm-1.2.dev1}/setup.py
@@ -5,7 +5,7 @@ with open("README.md", "r", encoding="utf-8") as fh:

  setup(
      name="adv_optm",
-     version="1.1.3",
+     version="1.2.dev1",
      author="Koratahiu",
      author_email="hiuhonor@gmail.com",
      license='Apache 2.0',