adv-optm 1.2.dev2.tar.gz → 1.2.dev4.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {adv_optm-1.2.dev2 → adv_optm-1.2.dev4}/PKG-INFO +1 -1
- {adv_optm-1.2.dev2 → adv_optm-1.2.dev4}/adv_optm/__init__.py +3 -1
- adv_optm-1.2.dev4/adv_optm/optim/AdaMuon_adv.py +465 -0
- {adv_optm-1.2.dev2 → adv_optm-1.2.dev4}/adv_optm/optim/AdamW_adv.py +3 -3
- {adv_optm-1.2.dev2 → adv_optm-1.2.dev4}/adv_optm/optim/Adopt_adv.py +3 -6
- {adv_optm-1.2.dev2 → adv_optm-1.2.dev4}/adv_optm/optim/Muon_adv.py +74 -20
- {adv_optm-1.2.dev2 → adv_optm-1.2.dev4}/adv_optm/optim/Prodigy_adv.py +3 -3
- {adv_optm-1.2.dev2 → adv_optm-1.2.dev4}/adv_optm/optim/Simplified_AdEMAMix.py +3 -3
- {adv_optm-1.2.dev2 → adv_optm-1.2.dev4}/adv_optm/optim/__init__.py +2 -0
- {adv_optm-1.2.dev2 → adv_optm-1.2.dev4}/adv_optm/util/Kourkoutas.py +5 -5
- {adv_optm-1.2.dev2 → adv_optm-1.2.dev4}/adv_optm/util/MuonAdam_helper.py +1 -0
- {adv_optm-1.2.dev2 → adv_optm-1.2.dev4}/adv_optm.egg-info/PKG-INFO +1 -1
- {adv_optm-1.2.dev2 → adv_optm-1.2.dev4}/adv_optm.egg-info/SOURCES.txt +1 -0
- {adv_optm-1.2.dev2 → adv_optm-1.2.dev4}/setup.py +1 -1
- {adv_optm-1.2.dev2 → adv_optm-1.2.dev4}/LICENSE +0 -0
- {adv_optm-1.2.dev2 → adv_optm-1.2.dev4}/README.md +0 -0
- {adv_optm-1.2.dev2 → adv_optm-1.2.dev4}/adv_optm/optim/Lion_Prodigy_adv.py +0 -0
- {adv_optm-1.2.dev2 → adv_optm-1.2.dev4}/adv_optm/optim/Lion_adv.py +0 -0
- {adv_optm-1.2.dev2 → adv_optm-1.2.dev4}/adv_optm/util/BF16_Stochastic_Rounding.py +0 -0
- {adv_optm-1.2.dev2 → adv_optm-1.2.dev4}/adv_optm/util/Effective_Shape.py +0 -0
- {adv_optm-1.2.dev2 → adv_optm-1.2.dev4}/adv_optm/util/NNMF.py +0 -0
- {adv_optm-1.2.dev2 → adv_optm-1.2.dev4}/adv_optm/util/Newton_Schulz.py +0 -0
- {adv_optm-1.2.dev2 → adv_optm-1.2.dev4}/adv_optm/util/One_Bit_Boolean.py +0 -0
- {adv_optm-1.2.dev2 → adv_optm-1.2.dev4}/adv_optm/util/OrthoGrad.py +0 -0
- {adv_optm-1.2.dev2 → adv_optm-1.2.dev4}/adv_optm/util/__init__.py +0 -0
- {adv_optm-1.2.dev2 → adv_optm-1.2.dev4}/adv_optm.egg-info/dependency_links.txt +0 -0
- {adv_optm-1.2.dev2 → adv_optm-1.2.dev4}/adv_optm.egg-info/requires.txt +0 -0
- {adv_optm-1.2.dev2 → adv_optm-1.2.dev4}/adv_optm.egg-info/top_level.txt +0 -0
- {adv_optm-1.2.dev2 → adv_optm-1.2.dev4}/setup.cfg +0 -0
--- adv_optm-1.2.dev2/adv_optm/__init__.py
+++ adv_optm-1.2.dev4/adv_optm/__init__.py
@@ -6,6 +6,7 @@ from .optim import (
     Lion_adv,
     Lion_Prodigy_adv,
     Muon_adv,
+    AdaMuon_adv,
 )
 
 __all__ = [
@@ -16,6 +17,7 @@ __all__ = [
     "Lion_adv",
     "Lion_Prodigy_adv",
     "Muon_adv",
+    "AdaMuon_adv",
 ]
 
-__version__ = "1.2.dev2"
+__version__ = "1.2.dev4"
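Net effect of this file: the new optimizer is re-exported at the package root and the version string is bumped. A minimal smoke test of the new surface, assuming the distribution installs under the import name `adv_optm` (taken from the file list above; this is not an official example from the project):

import torch
import adv_optm
from adv_optm import AdaMuon_adv  # newly exported name in this release

print(adv_optm.__version__)                            # "1.2.dev4", per the hunk above
print(issubclass(AdaMuon_adv, torch.optim.Optimizer))  # True; see AdaMuon_adv.py below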
--- /dev/null
+++ adv_optm-1.2.dev4/adv_optm/optim/AdaMuon_adv.py
@@ -0,0 +1,465 @@
+import torch
+from typing import Optional, Callable
+
+from .AdamW_adv import AdamW_adv
+from ..util.MuonAdam_helper import MuonAdamHelper
+from ..util.Kourkoutas import KourkoutasHelper
+
+from ..util.BF16_Stochastic_Rounding import add_stochastic_
+from ..util.Newton_Schulz import _newton_schulz_iteration
+from ..util.Effective_Shape import _get_effective_shape
+from ..util.NNMF import _nnmf,_unnmf
+from ..util.One_Bit_Boolean import _pack_bools, _unpack_bools
+
+class AdaMuon_adv(torch.optim.Optimizer):
+    """
+    Implements the AdaMuon optimizer algorithm.
+
+    AdaMuon combines the geometry-aware updates of Muon with the element-wise
+    adaptivity of Adam. It is designed for 2D parameters (e.g., linear layers)
+    and can handle higher-dimensional parameters by flattening.
+
+    The algorithm incorporates three key mechanisms:
+    1. A sign-stabilized orthogonal update, where the sign of the momentum is
+       orthogonalized instead of the momentum itself.
+    2. An element-wise second momentum estimator applied to the orthogonalized
+       update directions.
+    3. An RMS-aligned rescaling strategy to match the update magnitude of Adam,
+       allowing for reuse of learning rate schedules.
+
+    Can also operate in a hybrid mode, using an auxiliary AdamW
+    optimizer for specific parameters (e.g., biases, norms, embeddings) as
+    defined by a `layer_key_fn`.
+
+    Args:
+        params (iterable): iterable of parameters to optimize or dicts defining
+            parameter groups.
+        lr (float): learning rate (default: 1e-3).
+        betas (tuple[float, float]): coefficients used for both first and second moment
+            estimation (default: (0.95, 0.95))
+        weight_decay (float): weight decay (L2 penalty) (default: 0.1).
+        eps (float): term added to the denominator for adaptive scaling to improve
+            numerical stability (default: 1e-8).
+        rms_target (float): The target Root-Mean-Square value for the final update
+            vector, used for RMS-aligned rescaling. Allows for the reuse of existing Adam
+            learning rate schedules. (default: 0.2).
+        ns_steps (int): number of Newton-Schulz iterations to perform (default: 5).
+        ns_eps (float): epsilon for Newton-Schulz normalization stability (default: 1e-7).
+        ns_coeffs (tuple[float, float, float]): The (a, b, c) coefficients for the
+            quintic polynomial in the Newton-Schulz iteration.
+            (default: (3.4445, -4.7750, 2.0315)).
+        stochastic_rounding (bool): whether to use stochastic rounding for
+            BF16 parameter updates (default: True).
+        nesterov (bool): enables Nesterov momentum (default: False).
+        use_atan2 (bool): whether to use the atan2 update rule. (default: False)
+        Simplified_AdEMAMix (bool): whether to use the Simplified AdEMAMix update rule.
+            This changes the update to `alpha_grad * grad + mt`, which can be
+            more responsive, especially for small batch sizes. (default: False)
+        alpha_grad (float): Mixing coefficient for the Simplified AdEMAMix update rule
+            (only used when `Simplified_AdEMAMix` is `True`). Controls the weight of the
+            current gradient. For small batch sizes, use high values (e.g., 10-100) to be
+            more responsive. For large batch sizes, use low values (e.g., 0-1) for
+            stability. (default: 100.0)
+        vector_reshape_muon (bool): whether to reshape 1D vectors into 2D
+            matrices for muon NewtonSchulz (default: False).
+        vector_reshape (bool): whether to reshape 1D vectors into 2D
+            matrices to apply low-rank compression (default: True).
+        nnmf_factor (bool): whether to use the factorization or disable it to use
+            the uncompressed optimizer. (default: False)
+        kourkoutas_beta (bool): whether to enable the layer-wise dynamic β₂ logic.
+            If `False`, the optimizer behaves as standard AdamW. (default: False)
+        beta2_min (float): The minimum value for dynamic β₂, used during periods of
+            high gradient variance ("sunspikes"). Must be less than `betas[1]`.
+            (default: 0.88)
+        ema_alpha (float): The decay rate for the Exponential Moving Average (EMA) of
+            the pooled gradient norms. Corresponds to `α` in the paper.
+            (default: 0.93)
+        tiny_spike (float): A small constant added to the denominator of the
+            "sunspike" ratio calculation to prevent division by zero. Corresponds
+            to `ε_spike` in the paper. (default: 1e-9)
+        k_warmup_steps (int): The number of initial steps during which β₂ is held
+            at a fixed beta2 value before the
+            dynamic logic activates. (default: 0)
+        MuonWithAuxAdam (bool): If True, enables the hybrid optimizer mode.
+            Parameters designated by `layer_key_fn` will be optimized with
+            AdamW_adv instead of Muon. (default: False)
+        layer_key_fn (Optional[Callable]): A function that takes a parameter `p`
+            and returns a key. If the key is 'adam', the parameter is handled by
+            the auxiliary AdamW optimizer. All other keys are handled by Muon.
+            Only used when `MuonWithAuxAdam` is True. (default: None)
+        adam_kwargs (Optional[dict]): A dictionary of keyword arguments to pass
+            to the auxiliary AdamW_adv optimizer. Only used when
+            `MuonWithAuxAdam` is True. (default: None)
+    """
+
+    def __init__(
+        self,
+        params,
+        lr: float = 1e-3,
+        betas: tuple[float, float] = (0.95, 0.95),
+        weight_decay: float = 0.1,
+        eps: float = 1e-8,
+        rms_target: float = 0.2,
+        ns_steps: int = 5,
+        ns_eps: float = 1e-7,
+        ns_coeffs: tuple[float, float, float] = (3.4445, -4.7750, 2.0315),
+        stochastic_rounding: bool = True,
+        use_atan2: bool = False,
+        nesterov: bool = False,
+        Simplified_AdEMAMix: bool = False,
+        alpha_grad: float = 100.0,
+        vector_reshape_muon: bool = False,
+        vector_reshape: bool = False,
+        nnmf_factor: bool = False,
+        # K-b parameters
+        kourkoutas_beta: bool = False,
+        beta2_min: float = 0.9,
+        ema_alpha: float = 0.95,
+        tiny_spike: float = 1e-9,
+        k_warmup_steps: int = 0,
+        k_logging: int = 0,
+        layer_key_kb_fn: Optional[Callable] = None,
+        # hybrid optimizer mode
+        MuonWithAuxAdam: bool = False,
+        layer_key_fn: Optional[Callable] = None,
+        muon_adam_lr: float = 1e-4,
+        adam_kwargs: Optional[dict] = None,
+    ):
+        if not (lr >= 0.0):
+            raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
+        if not (weight_decay >= 0.0):
+            raise ValueError(f"Weight-decay should be >= 0.0. Got {weight_decay}")
+        if not (ns_steps > 0):
+            raise ValueError(f"Newton-Schulz steps should be > 0. Got {ns_steps}")
+        if Simplified_AdEMAMix and nesterov:
+            print("Warning: nesterov is incompatible with Simplified_AdEMAMix, Disabling cautious.")
+            nesterov = False
+
+        muon_defaults = {
+            "lr": lr, "betas": betas, "weight_decay": weight_decay,
+            "eps": eps, "rms_target": rms_target, "ns_steps": ns_steps,
+            "ns_eps": ns_eps, "ns_coeffs": ns_coeffs, "nnmf_factor": nnmf_factor,
+            "vector_reshape": vector_reshape,
+            "vector_reshape_muon": vector_reshape_muon,
+            "nesterov":nesterov, "use_atan2":use_atan2,
+            "Simplified_AdEMAMix": Simplified_AdEMAMix, "alpha_grad": alpha_grad,
+            "_kourkoutas_beta": kourkoutas_beta, "beta2_min": beta2_min, "ema_alpha": ema_alpha,
+            "tiny_spike": tiny_spike, "k_warmup_steps": k_warmup_steps, "k_logging": k_logging,
+        }
+        self.stochastic_rounding = stochastic_rounding
+        self._kourkoutas_beta = kourkoutas_beta
+        self._kourkoutas_helper = None
+        self.layer_key_kb_fn = layer_key_kb_fn
+        self.MuonWithAuxAdam = MuonWithAuxAdam
+        self.helper = None
+        self.aux_adam = None
+
+        if not self.MuonWithAuxAdam:
+            super().__init__(params, muon_defaults)
+            return
+
+        # HYBRID OPTIMIZER LOGIC
+        adam_kwargs = adam_kwargs or {}
+        self.aux_adam = AdamW_adv(
+            [],
+            lr=muon_adam_lr,
+            **adam_kwargs,
+            _is_delegate=True
+        )
+        adam_defaults = self.aux_adam.defaults
+
+        final_param_groups = []
+        _layer_key_fn = layer_key_fn if layer_key_fn is not None else lambda p: 'muon'
+
+        for group in params:
+            # All params in a group are of the same type
+            first_param = group['params'][0]
+            key = _layer_key_fn(first_param)
+            optim_type = 'adam' if key == 'adam' else 'muon'
+
+            new_group = group.copy()
+            defaults_to_use = adam_defaults if optim_type == 'adam' else muon_defaults
+
+            for key, value in defaults_to_use.items():
+                new_group.setdefault(key, value)
+
+            final_param_groups.append(new_group)
+
+        super().__init__(final_param_groups, {})
+
+        # Now that self is initialized, create the helper
+        self.helper = MuonAdamHelper(self, layer_key_fn)
+
+
+    @property
+    def supports_fused_back_pass(self):
+        return True
+
+    @property
+    def supports_memory_efficient_fp16(self):
+        return True
+
+    @property
+    def supports_flat_params(self):
+        return False
+
+    @property
+    def kourkoutas_helper(self):
+        """
+        Exposes the kourkoutas_helper from the auxiliary AdamW optimizer,
+        if it exists. This allows external access for logging K-b.
+        """
+        if self.aux_adam and hasattr(self.aux_adam, 'kourkoutas_helper'):
+            return self.aux_adam.kourkoutas_helper
+        return None
+
+    @torch.no_grad()
+    def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
+        if group['_kourkoutas_beta'] and self._kourkoutas_helper is None:
+            self._kourkoutas_helper = KourkoutasHelper(self)
+
+        if self.MuonWithAuxAdam:
+            optim_type = self.helper.get_optimizer_type(p)
+            if optim_type == 'adam':
+                # Delegate to the AdamW_adv optimizer's logic.
+                # We need to temporarily "lend" our state and param_groups
+                self.aux_adam.state = self.state
+                self.aux_adam.param_groups = self.param_groups
+
+                # Ensure the aux optimizer uses the same Kourkoutas helper instance.
+                if self._kourkoutas_helper is not None:
+                    self.aux_adam.kourkoutas_helper = self._kourkoutas_helper
+
+                self.aux_adam.step_parameter(p, group, i)
+                return
+
+        if p.grad is None:
+            return
+
+        grad = p.grad
+        state = self.state[p]
+
+
+        # State Initialization
+        if 'step' not in state:
+            state['step'] = 0
+
+            should_factor = (
+                group['nnmf_factor'] and
+                not (len(p.shape) == 1 and not group['vector_reshape'])
+            )
+
+            state['factored'] = should_factor
+
+            state['reshaped_1d_muon'] = len(p.shape) == 1 and group['vector_reshape_muon']
+
+            dtype = torch.float32 if group['nnmf_factor'] else p.dtype
+            device = p.device
+            if state['factored'] or state['reshaped_1d_muon']:
+                state['effective_shape'] = _get_effective_shape(p.numel())
+                d1, d2 = state['effective_shape']
+            if state['factored']:
+                state['mu_m_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
+                state['mv_m_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
+                packed_d2 = (d2 + 7) // 8
+                state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=device)
+                state['mu_v_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
+                state['mv_v_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
+            else:
+                if len(p.shape) >= 2:
+                    state['momentum_buffer'] = torch.zeros_like(p)
+                    state['second_momentum_buffer'] = torch.zeros_like(p)
+                if state['reshaped_1d_muon']:
+                    state['momentum_buffer'] = torch.zeros((d1, d2), device=device, dtype=dtype)
+                    state['second_momentum_buffer'] = torch.zeros((d1, d2), device=device, dtype=dtype)
+                elif len(p.shape) == 1:
+                    state['momentum_buffer'] = torch.zeros_like(p)
+
+        # Retrieve hyperparameters
+        beta1, beta2 = group['betas']
+        current_step = state['step']
+        nesterov = group['nesterov']
+        Simplified_AdEMAMix = group['Simplified_AdEMAMix']
+        alpha_grad = group['alpha_grad']
+
+        if state['factored']: # Factored AdaMuon
+
+            # Reconstruct momentum from previous step's factors & sign
+            d1, d2 = state['effective_shape']
+            mt_buf = _unnmf((state['mu_m_nmf'], state['mv_m_nmf']))
+            unpacked_sign = _unpack_bools(state['sign'], original_m=d2)
+            torch.where(unpacked_sign, mt_buf, -mt_buf, out=mt_buf)
+            del unpacked_sign
+
+            # Update momentum in full-size
+            grad_reshaped = grad.view(d1, d2)
+            mt_buf.mul_(beta1).add_(grad_reshaped)
+
+            if nesterov:
+                signed_m_buf = torch.sign(grad_reshaped.add(mt_buf, alpha=beta1))
+            elif Simplified_AdEMAMix:
+                signed_m_buf = torch.sign(mt_buf.add(grad_reshaped, alpha=alpha_grad))
+            else:
+                signed_m_buf = torch.sign(mt_buf)
+            del grad_reshaped
+
+            update = _newton_schulz_iteration(
+                signed_m_buf,
+                steps=group['ns_steps'],
+                eps=group['ns_eps'],
+                coeffs=group['ns_coeffs'],
+            )
+
+            if group['_kourkoutas_beta']:
+                # Call prepare_step() once at the beginning of the step for all params
+                self._kourkoutas_helper.maybe_prepare_step(current_step)
+                # Accumulate current sign-stabilized orthogonal update's norm for the *next* step
+                self._kourkoutas_helper.accumulate_gradient_sq_norm(p, update.view(p.shape))
+                # Get the dynamic beta2 calculated in prepare_step()
+                beta2 = self._kourkoutas_helper.get_beta2(p, group, current_step)
+
+            # Reconstruct second momentum from previous step's factors
+            vt_buf = _unnmf((state['mu_v_nmf'], state['mv_v_nmf']))
+
+            # Update second momentum in full-size
+            vt_buf.mul_(beta2).addcmul_(update, update, value=1 - beta2)
+
+            # Apply second momentum update (adaptive scaling)
+            if group['use_atan2']:
+                a = 1.2732395
+                denom = vt_buf.sqrt()
+                update.atan2_(denom).mul_(a)
+            else:
+                denom = vt_buf.sqrt().add_(group['eps'])
+                update.div_(denom)
+            del denom
+
+            # RMS-aligned rescaling
+            rms_target = group['rms_target']
+            num_elements = update.numel()
+            scaling_factor = rms_target * (num_elements ** 0.5) / (update.norm())
+
+            update.mul_(scaling_factor)
+            update = update.view(p.shape).mul_(group['lr'])
+            del num_elements, scaling_factor
+
+            # Compress updated moments and store new factors
+            state['sign'] = _pack_bools(mt_buf > 0)
+            _nnmf(mt_buf.abs(), out=(state['mu_m_nmf'], state['mv_m_nmf']))
+            del mt_buf
+
+            _nnmf(vt_buf.abs(), out=(state['mu_v_nmf'], state['mv_v_nmf']))
+            del vt_buf
+
+        else: # Standard AdaMuon logic for non-factored tensors
+
+            if len(p.shape) >= 2 or state['reshaped_1d_muon']:
+
+                # Momentum update
+                mt_buf = state['momentum_buffer']
+                if state['reshaped_1d_muon']:
+                    d1, d2 = state['effective_shape']
+                    grad_reshaped = grad.view(d1, d2)
+                    mt_buf.mul_(beta1).add_(grad_reshaped)
+                    if nesterov:
+                        signed_m_buf = torch.sign(grad_reshaped.add(mt_buf, alpha=beta1))
+                    elif Simplified_AdEMAMix:
+                        signed_m_buf = torch.sign(mt_buf.add(grad_reshaped, alpha=alpha_grad))
+                    else:
+                        signed_m_buf = torch.sign(mt_buf)
+                    del grad_reshaped
+                else:
+                    mt_buf.mul_(beta1).add_(grad)
+                    if nesterov:
+                        signed_m_buf = torch.sign(grad.add(mt_buf, alpha=beta1))
+                    elif Simplified_AdEMAMix:
+                        signed_m_buf = torch.sign(mt_buf.add(grad, alpha=alpha_grad))
+                    else:
+                        signed_m_buf = torch.sign(mt_buf)
+
+                # Flatten if necessary (e.g., for Conv layers)
+                if len(p.shape) > 2:
+                    signed_m_buf = signed_m_buf.view(p.shape[0], -1)
+
+                # NewtonSchulz
+                update = _newton_schulz_iteration(
+                    signed_m_buf,
+                    steps=group['ns_steps'],
+                    eps=group['ns_eps'],
+                    coeffs=group['ns_coeffs'],
+                )
+
+                if len(p.shape) > 2 or state['reshaped_1d_muon']:
+                    update = update.view(p.shape)
+
+                if group['_kourkoutas_beta']:
+                    # Call prepare_step() once at the beginning of the step for all params
+                    self._kourkoutas_helper.maybe_prepare_step(current_step)
+                    # Accumulate current sign-stabilized orthogonal update's norm for the *next* step
+                    self._kourkoutas_helper.accumulate_gradient_sq_norm(p, update)
+                    # Get the dynamic beta2 calculated in prepare_step()
+                    beta2 = self._kourkoutas_helper.get_beta2(p, group, current_step)
+
+                vt_buf = state['second_momentum_buffer']
+                vt_buf.mul_(beta2).addcmul_(update, update, value=1 - beta2)
+
+                # Apply second momentum update (adaptive scaling)
+                if group['use_atan2']:
+                    a = 1.2732395
+                    denom = vt_buf.sqrt()
+                    update.atan2_(denom).mul_(a)
+                else:
+                    denom = vt_buf.sqrt().add_(group['eps'])
+                    update.div_(denom)
+                del denom
+
+                # RMS-aligned rescaling
+                rms_target = group['rms_target']
+                num_elements = update.numel()
+                scaling_factor = rms_target * (num_elements ** 0.5) / (update.norm())
+
+                update.mul_(scaling_factor)
+                del num_elements, scaling_factor
+
+                update.mul_(group['lr'])
+
+            else: # Fallback to standard SGD with momentum for 1D params (biases, etc.) when not reshaped
+                # Momentum update
+                mt_buf = state['momentum_buffer']
+                mt_buf.mul_(beta1).add_(grad)
+                if nesterov:
+                    update = grad.add(mt_buf, alpha=beta1)
+                elif Simplified_AdEMAMix:
+                    signed_m_buf = torch.sign(mt_buf.add(grad, alpha=alpha_grad))
+                else:
+                    update = mt_buf.clone()
+                update.mul_(group['lr'])
+
+        # Decoupled weight decay
+        if group["weight_decay"] != 0:
+            if p.dtype == torch.bfloat16 and self.stochastic_rounding:
+                add_stochastic_(p.data, p.data, alpha=-group["weight_decay"] * group["lr"])
+            else:
+                p.data.add_(p.data, alpha=-group["weight_decay"] * group["lr"])
+
+        if p.dtype == torch.bfloat16 and self.stochastic_rounding:
+            add_stochastic_(p.data, -update)
+        else:
+            p.data.add_(-update)
+        del update
+
+        state['step'] += 1
+
+    @torch.no_grad()
+    def step(self, closure=None):
+        """Performs a single optimization step."""
+        loss = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        for group in self.param_groups:
+            for i, p in enumerate(group['params']):
+                self.step_parameter(p, group, i)
+
+        return loss
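Taken together, the constructor and `step_parameter` above suggest the following usage pattern for the new optimizer. This is a hedged sketch based only on the signature and docstring in this diff: the param-group layout for hybrid mode and the convention that `layer_key_fn` returns 'adam' for AdamW-handled tensors are read from the code above, and nothing here is an official example shipped with the package.

import torch
from adv_optm import AdaMuon_adv

model = torch.nn.Sequential(
    torch.nn.Linear(64, 128),
    torch.nn.ReLU(),
    torch.nn.Linear(128, 10),
)

# Hybrid mode expects param groups whose members share one routing key:
# 2D weights stay on the AdaMuon path, 1D tensors (biases, norms) go to the
# auxiliary AdamW_adv delegate.
matrix_params = [p for p in model.parameters() if p.ndim >= 2]
other_params = [p for p in model.parameters() if p.ndim < 2]

opt = AdaMuon_adv(
    [{"params": matrix_params}, {"params": other_params}],
    lr=1e-3,
    betas=(0.95, 0.95),
    weight_decay=0.1,
    MuonWithAuxAdam=True,
    layer_key_fn=lambda p: "adam" if p.ndim < 2 else "muon",
    muon_adam_lr=1e-4,
)

loss = model(torch.randn(8, 64)).sum()
loss.backward()
opt.step()       # iterates the param groups and calls step_parameter per tensor
opt.zero_grad()

Note the RMS-aligned rescaling in the step above: the orthogonalized, variance-normalized update is scaled by rms_target * sqrt(numel) / ||update||, which is what lets existing Adam learning-rate schedules be reused, as the docstring states.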
--- adv_optm-1.2.dev2/adv_optm/optim/AdamW_adv.py
+++ adv_optm-1.2.dev4/adv_optm/optim/AdamW_adv.py
@@ -73,7 +73,7 @@ class AdamW_adv(torch.optim.Optimizer):
             logging of Kourkoutas-β statistics (min, max, mean of `β₂` across layers)
             every logging steps. Useful for debugging and tuning. Set to 0 to disable
             logging (default: 0).
-
+        layer_key_kb_fn (Optional[Callable]): A function that takes a parameter `p`
             and returns a unique, hashable key representing its "layer" or "bucket".
             If `None`, parameters are bucketed by their memory ID (tensor-wise).
             (default: None)
@@ -105,7 +105,7 @@ class AdamW_adv(torch.optim.Optimizer):
         tiny_spike: float = 1e-9,
         k_warmup_steps: int = 0,
         k_logging: int = 0,
-
+        layer_key_kb_fn: Optional[Callable] = None,
         nnmf_factor: bool = False,
         _is_delegate: bool = False,
     ):
@@ -137,7 +137,7 @@ class AdamW_adv(torch.optim.Optimizer):
         self.use_AdEMAMix = use_AdEMAMix
         self.factored = nnmf_factor
         self.kourkoutas_beta = kourkoutas_beta
-        self.
+        self.layer_key_kb_fn = layer_key_kb_fn
         if not _is_delegate:
             super().__init__(params, defaults)
         else:
--- adv_optm-1.2.dev2/adv_optm/optim/Adopt_adv.py
+++ adv_optm-1.2.dev4/adv_optm/optim/Adopt_adv.py
@@ -91,7 +91,7 @@ class Adopt_adv(torch.optim.Optimizer):
             logging of Kourkoutas-β statistics (min, max, mean of `β₂` across layers)
             every logging steps. Useful for debugging and tuning. Set to 0 to disable
             logging (default: 0).
-
+        layer_key_kb_fn (Optional[Callable]): A function that takes a parameter `p`
             and returns a unique, hashable key representing its "layer" or "bucket".
             If `None`, parameters are bucketed by their memory ID (tensor-wise).
             (default: None)
@@ -125,7 +125,7 @@ class Adopt_adv(torch.optim.Optimizer):
         tiny_spike: float = 1e-9,
         k_warmup_steps: int = 0,
         k_logging: int = 0,
-
+        layer_key_kb_fn: Optional[Callable] = None,
         nnmf_factor: bool = False,
     ):
         if not (lr >= 0.0):
@@ -148,9 +148,6 @@ class Adopt_adv(torch.optim.Optimizer):
             print("Warning: grams is incompatible with Simplified_AdEMAMix, Disabling grams.")
         if cautious_mask and Simplified_AdEMAMix:
             print("Warning: cautious is incompatible with Simplified_AdEMAMix, Disabling cautious.")
-        if use_atan2 and Simplified_AdEMAMix:
-            print("Warning: use_atan2 is incompatible with Simplified_AdEMAMix. Disabling use_atan2.")
-            use_atan2 = False
 
         defaults = {
             "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
@@ -169,7 +166,7 @@ class Adopt_adv(torch.optim.Optimizer):
         self.Simplified_AdEMAMix = Simplified_AdEMAMix
         self.factored = nnmf_factor
         self.kourkoutas_beta = kourkoutas_beta
-        self.
+        self.layer_key_kb_fn = layer_key_kb_fn
         super().__init__(params, defaults)
 
         if self.kourkoutas_beta:
--- adv_optm-1.2.dev2/adv_optm/optim/Muon_adv.py
+++ adv_optm-1.2.dev4/adv_optm/optim/Muon_adv.py
@@ -22,7 +22,7 @@ class Muon_adv(torch.optim.Optimizer):
     can handle other-dimensional parameters (e.g., 1D bias, 4D convolutional layers) by
     flattening/reshaping them.
 
-
+    Can also operate in a hybrid mode, using an auxiliary AdamW
     optimizer for specific parameters (e.g., biases, norms, embeddings) as
     defined by a `layer_key_fn`.
 
@@ -38,6 +38,14 @@ class Muon_adv(torch.optim.Optimizer):
         ns_coeffs (tuple[float, float, float]): The (a, b, c) coefficients for the
             quintic polynomial in the Newton-Schulz iteration.
             (default: (3.4445, -4.7750, 2.0315)).
+        Simplified_AdEMAMix (bool): whether to use the Simplified AdEMAMix update rule.
+            This changes the update to `alpha_grad * grad + mt`, which can be
+            more responsive, especially for small batch sizes. (default: False)
+        alpha_grad (float): Mixing coefficient for the Simplified AdEMAMix update rule
+            (only used when `Simplified_AdEMAMix` is `True`). Controls the weight of the
+            current gradient. For small batch sizes, use high values (e.g., 10-100) to be
+            more responsive. For large batch sizes, use low values (e.g., 0-1) for
+            stability. (default: 100.0)
         stochastic_rounding (bool): whether to use stochastic rounding for
             BF16 parameter updates (default: True).
         vector_reshape_muon (bool): whether to reshape 1D vectors into 2D
@@ -68,9 +76,11 @@ class Muon_adv(torch.optim.Optimizer):
         ns_steps: int = 5,
         ns_eps: float = 1e-7,
         ns_coeffs: tuple[float, float, float] = (3.4445, -4.7750, 2.0315),
+        Simplified_AdEMAMix: bool = False,
+        alpha_grad: float = 100.0,
         stochastic_rounding: bool = True,
         vector_reshape_muon: bool = False,
-        vector_reshape: bool =
+        vector_reshape: bool = False,
         nnmf_factor: bool = False,
         # hybrid optimizer mode
         MuonWithAuxAdam: bool = False,
@@ -86,13 +96,17 @@ class Muon_adv(torch.optim.Optimizer):
             raise ValueError(f"Weight-decay should be >= 0.0. Got {weight_decay}")
         if not (ns_steps > 0):
             raise ValueError(f"Newton-Schulz steps should be > 0. Got {ns_steps}")
+        if Simplified_AdEMAMix and nesterov:
+            print("Warning: nesterov is incompatible with Simplified_AdEMAMix, Disabling cautious.")
+            nesterov = False
 
-
+        muon_defaults = {
             "lr": lr, "beta1": beta1, "weight_decay": weight_decay,
             "nesterov": nesterov, "ns_steps": ns_steps, "ns_eps": ns_eps,
             "ns_coeffs": ns_coeffs, "nnmf_factor": nnmf_factor,
             "vector_reshape": vector_reshape,
             "vector_reshape_muon": vector_reshape_muon,
+            "Simplified_AdEMAMix": Simplified_AdEMAMix, "alpha_grad": alpha_grad,
         }
         self.stochastic_rounding = stochastic_rounding
 
@@ -100,23 +114,41 @@ class Muon_adv(torch.optim.Optimizer):
         self.helper = None
         self.aux_adam = None
 
-        if self.MuonWithAuxAdam:
-
-
-            self.aux_adam = AdamW_adv(
-                [],
-                lr=muon_adam_lr,
-                **adam_kwargs,
-                _is_delegate=True
-            )
-            # Update the defaults dictionary
-            defaults.update(self.aux_adam.defaults)
-
-        super().__init__(params, defaults)
+        if not self.MuonWithAuxAdam:
+            super().__init__(params, muon_defaults)
+            return
 
-
-
+        # HYBRID OPTIMIZER LOGIC
+        adam_kwargs = adam_kwargs or {}
+        self.aux_adam = AdamW_adv(
+            [],
+            lr=muon_adam_lr,
+            **adam_kwargs,
+            _is_delegate=True
+        )
+        adam_defaults = self.aux_adam.defaults
+
+        final_param_groups = []
+        _layer_key_fn = layer_key_fn if layer_key_fn is not None else lambda p: 'muon'
+
+        for group in params:
+            first_param = group['params'][0]
+            key = _layer_key_fn(first_param)
+            optim_type = 'adam' if key == 'adam' else 'muon'
+
+            new_group = group.copy()
+            defaults_to_use = adam_defaults if optim_type == 'adam' else muon_defaults
+
+            for key, value in defaults_to_use.items():
+                new_group.setdefault(key, value)
+
+            final_param_groups.append(new_group)
+
+        super().__init__(final_param_groups, {})
 
+        # Now that self is initialized, create the helper
+        self.helper = MuonAdamHelper(self, layer_key_fn)
+
 
     @property
     def supports_fused_back_pass(self):
@@ -130,6 +162,16 @@ class Muon_adv(torch.optim.Optimizer):
     def supports_flat_params(self):
         return False
 
+    @property
+    def kourkoutas_helper(self):
+        """
+        Exposes the kourkoutas_helper from the auxiliary AdamW optimizer,
+        if it exists. This allows external access for logging K-b.
+        """
+        if self.aux_adam and hasattr(self.aux_adam, 'kourkoutas_helper'):
+            return self.aux_adam.kourkoutas_helper
+        return None
+
     @torch.no_grad()
     def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
         if self.MuonWithAuxAdam:
@@ -165,7 +207,7 @@ class Muon_adv(torch.optim.Optimizer):
 
             dtype = torch.float32 if group['nnmf_factor'] else p.dtype
             device = p.device
-            if
+            if state['factored'] or state['reshaped_1d_muon']:
                 state['effective_shape'] = _get_effective_shape(p.numel())
                 d1, d2 = state['effective_shape']
             if state['factored']:
@@ -183,6 +225,8 @@ class Muon_adv(torch.optim.Optimizer):
 
         beta1 = group['beta1']
         nesterov = group['nesterov']
+        Simplified_AdEMAMix = group['Simplified_AdEMAMix']
+        alpha_grad = group['alpha_grad']
 
         if state['factored']: # Factored Muon
 
@@ -200,6 +244,8 @@ class Muon_adv(torch.optim.Optimizer):
             if nesterov:
                 # Nesterov momentum
                 update = grad_reshaped.add(mt_buf, alpha=beta1)
+            elif Simplified_AdEMAMix:
+                update = torch.add(mt_buf, grad_reshaped, alpha=alpha_grad)
             else:
                 # Standard momentum
                 update = mt_buf.clone()
@@ -238,6 +284,12 @@ class Muon_adv(torch.optim.Optimizer):
                         del grad_reshaped
                     else:
                         update = grad.add(mt_buf, alpha=beta1)
+                elif Simplified_AdEMAMix:
+                    if state['reshaped_1d_muon']:
+                        update = torch.add(mt_buf, grad_reshaped, alpha=alpha_grad)
+                        del grad_reshaped
+                    else:
+                        update = torch.add(mt_buf, grad, alpha=alpha_grad)
                 else:
                     # Standard momentum
                     update = mt_buf.clone()
@@ -267,6 +319,8 @@ class Muon_adv(torch.optim.Optimizer):
             if nesterov:
                 # Nesterov momentum
                 update = grad.add(mt_buf, alpha=beta1)
+            elif Simplified_AdEMAMix:
+                update = torch.add(mt_buf, grad, alpha=alpha_grad)
             else:
                 # Standard momentum
                 update = mt_buf.clone()
@@ -299,4 +353,4 @@ class Muon_adv(torch.optim.Optimizer):
         for i, p in enumerate(group['params']):
             self.step_parameter(p, group, i)
 
-        return loss
+        return loss
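The Simplified_AdEMAMix branch added above replaces Nesterov-style mixing with a gradient-heavy blend of the momentum buffer and the current gradient. A standalone illustration of just that arithmetic (not library code; shapes and values are arbitrary):

import torch

beta1, alpha_grad = 0.95, 100.0          # defaults from the signature above
m_t = torch.zeros(4, 4)                  # stands in for state['momentum_buffer']
g_t = torch.randn(4, 4)                  # current gradient

m_t.mul_(beta1).add_(g_t)                        # momentum accumulation used in Muon_adv
update = torch.add(m_t, g_t, alpha=alpha_grad)   # Simplified AdEMAMix: mt + alpha_grad * grad

With `alpha_grad` at its default of 100.0 the current gradient dominates the blend, which is why the docstring recommends large values only for small batch sizes and values near 0-1 for large batches.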
--- adv_optm-1.2.dev2/adv_optm/optim/Prodigy_adv.py
+++ adv_optm-1.2.dev4/adv_optm/optim/Prodigy_adv.py
@@ -109,7 +109,7 @@ class Prodigy_adv(torch.optim.Optimizer):
             logging of Kourkoutas-β statistics (min, max, mean of `β₂` across layers)
             every logging steps. Useful for debugging and tuning. Set to 0 to disable
             logging (default: 0).
-
+        layer_key_kb_fn (Optional[Callable]): A function that takes a parameter `p`
             and returns a unique, hashable key representing its "layer" or "bucket".
             If `None`, parameters are bucketed by their memory ID (tensor-wise).
             (default: None)
@@ -152,7 +152,7 @@ class Prodigy_adv(torch.optim.Optimizer):
         tiny_spike: float = 1e-9,
         k_warmup_steps: int = 0,
         k_logging: int = 0,
-
+        layer_key_kb_fn: Optional[Callable] = None,
     ):
         if not (lr >= 0.0):
             raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
@@ -205,7 +205,7 @@ class Prodigy_adv(torch.optim.Optimizer):
         self.fsdp_in_use = fsdp_in_use
 
         self.kourkoutas_beta = kourkoutas_beta
-        self.
+        self.layer_key_kb_fn = layer_key_kb_fn
 
         super().__init__(params, defaults)
         if self.kourkoutas_beta:
--- adv_optm-1.2.dev2/adv_optm/optim/Simplified_AdEMAMix.py
+++ adv_optm-1.2.dev4/adv_optm/optim/Simplified_AdEMAMix.py
@@ -67,7 +67,7 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
             logging of Kourkoutas-β statistics (min, max, mean of `β₂` across layers)
             every logging steps. Useful for debugging and tuning. Set to 0 to disable
             logging (default: 0).
-
+        layer_key_kb_fn (Optional[Callable]): A function that takes a parameter `p`
             and returns a unique, hashable key representing its "layer" or "bucket".
             If `None`, parameters are bucketed by their memory ID (tensor-wise).
             (default: None)
@@ -95,7 +95,7 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
         tiny_spike: float = 1e-9,
         k_warmup_steps: int = 0,
         k_logging: int = 0,
-
+        layer_key_kb_fn: Optional[Callable] = None,
         nnmf_factor: bool = False,
     ):
         if not (lr >= 0.0):
@@ -121,7 +121,7 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
         self.stochastic_rounding = stochastic_rounding
         self.factored = nnmf_factor
         self.kourkoutas_beta = kourkoutas_beta
-        self.
+        self.layer_key_kb_fn = layer_key_kb_fn
         super().__init__(params, defaults)
 
         if self.kourkoutas_beta:
--- adv_optm-1.2.dev2/adv_optm/optim/__init__.py
+++ adv_optm-1.2.dev4/adv_optm/optim/__init__.py
@@ -5,6 +5,7 @@ from .Simplified_AdEMAMix import Simplified_AdEMAMix
 from .Lion_adv import Lion_adv
 from .Lion_Prodigy_adv import Lion_Prodigy_adv
 from .Muon_adv import Muon_adv
+from .AdaMuon_adv import AdaMuon_adv
 
 __all__ = [
     "AdamW_adv",
@@ -14,4 +15,5 @@ __all__ = [
     "Lion_adv",
     "Lion_Prodigy_adv",
     "Muon_adv",
+    "AdaMuon_adv",
 ]
--- adv_optm-1.2.dev2/adv_optm/util/Kourkoutas.py
+++ adv_optm-1.2.dev4/adv_optm/util/Kourkoutas.py
@@ -32,12 +32,12 @@ class KourkoutasHelper:
         if self._layer_info_built:
             return
 
-        if hasattr(self.optimizer, '
+        if hasattr(self.optimizer, 'layer_key_kb_fn') and self.optimizer.layer_key_kb_fn is not None:
             # A custom key function was provided by the user. We will use it.
             pass
         else:
             # No key function was provided. Default to coarse, shape-based bucketing.
-            self.optimizer.
+            self.optimizer.layer_key_kb_fn = lambda p: \
                 (id(p),) if p.dim() == 2 and 1 <= p.shape[0] <= 10 and p.shape[1] in {768, 1280, 4096} \
                 else tuple(p.shape)
             # This ensures that we won't mix embeddings with tokens (1 to 10)
@@ -46,7 +46,7 @@ class KourkoutasHelper:
         for group in self.optimizer.param_groups:
             for p in group['params']:
                 # The mapping is static and should not depend on the presence of a gradient.
-                layer_key = self.optimizer.
+                layer_key = self.optimizer.layer_key_kb_fn(p)
                 if layer_key not in self.layer_info:
                     self.layer_info[layer_key] = {'params': [], 'group_ref': group}
                 self.layer_info[layer_key]['params'].append(p)
@@ -158,7 +158,7 @@ class KourkoutasHelper:
         """
         Accumulates the squared L2 norm of a single gradient for the next step's calculation.
         """
-        layer_key = self.optimizer.
+        layer_key = self.optimizer.layer_key_kb_fn(p)
 
         if layer_key in self.layer_info:
             # Initialize the transient state for this layer if it's the first time in the step.
@@ -173,6 +173,6 @@ class KourkoutasHelper:
         """
        Gets the appropriate beta2 for the current parameter, handling warmup and dynamic value fetching.
         """
-        layer_key = self.optimizer.
+        layer_key = self.optimizer.layer_key_kb_fn(p)
         # The default is the max value, which is correct for unmapped params or edge cases
         return self.layer_state.get(layer_key, {}).get('dynamic_beta2', group['betas'][1])
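These hunks make the Kourkoutas-β bucketing hook consistently named `layer_key_kb_fn` across the helper and the optimizers above, with the shape-based lambda as the fallback when no hook is given. A hedged sketch of supplying a custom bucketing function (the `kourkoutas_beta=True` keyword on `AdamW_adv` is assumed from the `self.kourkoutas_beta = kourkoutas_beta` assignment shown earlier; this is not an official example):

import torch
from adv_optm import AdamW_adv

model = torch.nn.Linear(768, 768)

# Bucket parameters by shape so that all tensors of the same shape share one
# dynamic β₂, mirroring the helper's default fallback in the hunk above.
def layer_key_by_shape(p: torch.Tensor):
    return tuple(p.shape)

opt = AdamW_adv(
    model.parameters(),
    lr=1e-3,
    kourkoutas_beta=True,                # assumed keyword; see AdamW_adv's __init__ above
    layer_key_kb_fn=layer_key_by_shape,
)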
--- adv_optm-1.2.dev2/adv_optm.egg-info/SOURCES.txt
+++ adv_optm-1.2.dev4/adv_optm.egg-info/SOURCES.txt
@@ -7,6 +7,7 @@ adv_optm.egg-info/SOURCES.txt
 adv_optm.egg-info/dependency_links.txt
 adv_optm.egg-info/requires.txt
 adv_optm.egg-info/top_level.txt
+adv_optm/optim/AdaMuon_adv.py
 adv_optm/optim/AdamW_adv.py
 adv_optm/optim/Adopt_adv.py
 adv_optm/optim/Lion_Prodigy_adv.py