adv-optm 1.2.dev5.tar.gz → 1.2.dev7.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {adv_optm-1.2.dev5 → adv_optm-1.2.dev7}/PKG-INFO +1 -1
- {adv_optm-1.2.dev5 → adv_optm-1.2.dev7}/adv_optm/__init__.py +1 -1
- {adv_optm-1.2.dev5 → adv_optm-1.2.dev7}/adv_optm/optim/Muon_adv.py +165 -18
- {adv_optm-1.2.dev5 → adv_optm-1.2.dev7}/adv_optm/util/Kourkoutas.py +21 -5
- {adv_optm-1.2.dev5 → adv_optm-1.2.dev7}/adv_optm.egg-info/PKG-INFO +1 -1
- {adv_optm-1.2.dev5 → adv_optm-1.2.dev7}/setup.py +1 -1
- {adv_optm-1.2.dev5 → adv_optm-1.2.dev7}/LICENSE +0 -0
- {adv_optm-1.2.dev5 → adv_optm-1.2.dev7}/README.md +0 -0
- {adv_optm-1.2.dev5 → adv_optm-1.2.dev7}/adv_optm/optim/AdaMuon_adv.py +0 -0
- {adv_optm-1.2.dev5 → adv_optm-1.2.dev7}/adv_optm/optim/AdamW_adv.py +0 -0
- {adv_optm-1.2.dev5 → adv_optm-1.2.dev7}/adv_optm/optim/Adopt_adv.py +0 -0
- {adv_optm-1.2.dev5 → adv_optm-1.2.dev7}/adv_optm/optim/Lion_Prodigy_adv.py +0 -0
- {adv_optm-1.2.dev5 → adv_optm-1.2.dev7}/adv_optm/optim/Lion_adv.py +0 -0
- {adv_optm-1.2.dev5 → adv_optm-1.2.dev7}/adv_optm/optim/Prodigy_adv.py +0 -0
- {adv_optm-1.2.dev5 → adv_optm-1.2.dev7}/adv_optm/optim/Simplified_AdEMAMix.py +0 -0
- {adv_optm-1.2.dev5 → adv_optm-1.2.dev7}/adv_optm/optim/__init__.py +0 -0
- {adv_optm-1.2.dev5 → adv_optm-1.2.dev7}/adv_optm/util/BF16_Stochastic_Rounding.py +0 -0
- {adv_optm-1.2.dev5 → adv_optm-1.2.dev7}/adv_optm/util/Effective_Shape.py +0 -0
- {adv_optm-1.2.dev5 → adv_optm-1.2.dev7}/adv_optm/util/MuonAdam_helper.py +0 -0
- {adv_optm-1.2.dev5 → adv_optm-1.2.dev7}/adv_optm/util/NNMF.py +0 -0
- {adv_optm-1.2.dev5 → adv_optm-1.2.dev7}/adv_optm/util/Newton_Schulz.py +0 -0
- {adv_optm-1.2.dev5 → adv_optm-1.2.dev7}/adv_optm/util/One_Bit_Boolean.py +0 -0
- {adv_optm-1.2.dev5 → adv_optm-1.2.dev7}/adv_optm/util/OrthoGrad.py +0 -0
- {adv_optm-1.2.dev5 → adv_optm-1.2.dev7}/adv_optm/util/__init__.py +0 -0
- {adv_optm-1.2.dev5 → adv_optm-1.2.dev7}/adv_optm.egg-info/SOURCES.txt +0 -0
- {adv_optm-1.2.dev5 → adv_optm-1.2.dev7}/adv_optm.egg-info/dependency_links.txt +0 -0
- {adv_optm-1.2.dev5 → adv_optm-1.2.dev7}/adv_optm.egg-info/requires.txt +0 -0
- {adv_optm-1.2.dev5 → adv_optm-1.2.dev7}/adv_optm.egg-info/top_level.txt +0 -0
- {adv_optm-1.2.dev5 → adv_optm-1.2.dev7}/setup.cfg +0 -0
adv_optm/optim/Muon_adv.py (+165 -18)

```diff
@@ -18,6 +18,10 @@ class Muon_adv(torch.optim.Optimizer):
     the hidden layers of neural networks. It applies SGD with momentum and then
     orthogonalizes the resulting update matrix using a Newton-Schulz iteration.
 
+    NorMuon (Neuron-wise Normalized Muon) extends this by adding neuron-level
+    adaptive learning rates, combining the benefits of orthogonalization with
+    second-order momentum statistics.
+
     This implementation is designed for 2D parameters (e.g., linear layers) and
     can handle other-dimensional parameters (e.g., 1D bias, 4D convolutional layers) by
     flattening/reshaping them.
```
```diff
@@ -54,6 +58,19 @@ class Muon_adv(torch.optim.Optimizer):
             matrices to apply low-rank compression (default: True).
         nnmf_factor (bool): whether to use the factorization or disable it to use
             the uncompressed optimizer. (default: False)
+        low_rank_ortho (bool): If True, enables low-rank orthogonalization, which
+            projects the update to a lower rank before orthogonalization.
+            (default: False)
+        ortho_rank (int): The rank for low-rank orthogonalization.
+            (default: 128)
+        normuon_variant (bool): If True, enables the NorMuon update rule, which adds
+            neuron-wise normalization. (default: False)
+        beta2_normuon (float): The exponential decay rate for the second moment estimates
+            used in NorMuon. (default: 0.95)
+        normuon_eps (float): Epsilon for NorMuon normalization stability. (default: 1e-8)
+        normuon_lr_scale (float): Scaling factor for the NorMuon learning rate.
+            (default: 0.2)
+        normuon_atan2 (bool): whether to use the atan2 for NorMuon. (default: False)
         MuonWithAuxAdam (bool): If True, enables the hybrid optimizer mode.
             Parameters designated by `layer_key_fn` will be optimized with
             AdamW_adv instead of Muon. (default: False)
```
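For context, a minimal usage sketch of the options added in this release. The import path and the training-loop boilerplate are assumptions for illustration, not taken from this diff; only the keyword arguments themselves appear in the changed signature.

```python
import torch
from adv_optm import Muon_adv  # assumed import path; not confirmed by this diff

model = torch.nn.Linear(512, 512)  # 2D weight handled by Muon, 1D bias by the SGD fallback

# Hypothetical configuration exercising the new 1.2.dev7 options.
opt = Muon_adv(
    model.parameters(),
    lr=0.02,
    # Low-rank orthogonalization via Gaussian sketching
    low_rank_ortho=True,
    ortho_rank=128,
    # NorMuon: neuron-wise second-moment normalization
    normuon_variant=True,
    beta2_normuon=0.95,
    normuon_eps=1e-8,
    normuon_lr_scale=0.2,
)

x = torch.randn(8, 512)
loss = model(x).square().mean()
loss.backward()
opt.step()
opt.zero_grad()
```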
```diff
@@ -82,6 +99,15 @@ class Muon_adv(torch.optim.Optimizer):
         vector_reshape_muon: bool = False,
         vector_reshape: bool = False,
         nnmf_factor: bool = False,
+        # Low-rank Muon
+        low_rank_ortho: bool = False,
+        ortho_rank: int = 128,
+        # NorMuon additions
+        normuon_variant: bool = False,
+        beta2_normuon: float = 0.95,
+        normuon_eps: float = 1e-8,
+        normuon_lr_scale: float = 0.2,
+        normuon_atan2: bool = False,
         # hybrid optimizer mode
         MuonWithAuxAdam: bool = False,
         layer_key_fn: Optional[Callable] = None,
```
```diff
@@ -92,6 +118,8 @@ class Muon_adv(torch.optim.Optimizer):
             raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
         if not (0.0 <= beta1 < 1.0):
             raise ValueError(f"beta1 should be in [0.0, 1.0). Got {beta1}")
+        if normuon_variant and not (0.0 <= beta2_normuon < 1.0):
+            raise ValueError(f"beta2_normuon should be in [0.0, 1.0) for NorMuon. Got {beta2_normuon}")
         if not (weight_decay >= 0.0):
             raise ValueError(f"Weight-decay should be >= 0.0. Got {weight_decay}")
         if not (ns_steps > 0):
```
```diff
@@ -106,10 +134,16 @@ class Muon_adv(torch.optim.Optimizer):
             "ns_coeffs": ns_coeffs, "nnmf_factor": nnmf_factor,
             "vector_reshape": vector_reshape,
             "vector_reshape_muon": vector_reshape_muon,
-            "Simplified_AdEMAMix": Simplified_AdEMAMix, "alpha_grad": alpha_grad,
+            "Simplified_AdEMAMix": Simplified_AdEMAMix, "alpha_grad": alpha_grad,
+            # Low-rank Ortho
+            "low_rank_ortho": low_rank_ortho, "ortho_rank": ortho_rank,
+            # NorMuon
+            "normuon_variant": normuon_variant, "beta2_normuon": beta2_normuon,
+            "normuon_eps": normuon_eps, "normuon_lr_scale": normuon_lr_scale,
+            "normuon_atan2": normuon_atan2,
         }
         self.stochastic_rounding = stochastic_rounding
-
+
         self.MuonWithAuxAdam = MuonWithAuxAdam
         self.helper = None
         self.aux_adam = None
```
```diff
@@ -223,6 +257,12 @@ class Muon_adv(torch.optim.Optimizer):
                 elif len(p.shape) == 1:
                     state['momentum_buffer'] = torch.zeros_like(p)
 
+                # NorMuon state initialization
+                if group['normuon_variant']:
+                    if len(p.shape) >= 2 or state['reshaped_1d_muon']:
+                        num_rows = p.shape[0] if len(p.shape) >= 2 else state['effective_shape'][0]
+                        state['normuon_v'] = torch.zeros(num_rows, device=p.device, dtype=torch.float32)
+
                 beta1 = group['beta1']
                 nesterov = group['nesterov']
                 Simplified_AdEMAMix = group['Simplified_AdEMAMix']
```
```diff
@@ -251,14 +291,60 @@ class Muon_adv(torch.optim.Optimizer):
                     update = mt_buf.clone()
                     del grad_reshaped
 
-
-
-
-
-
-
+                    # Orthogonalization step
+                    if group['low_rank_muon']:
+                        # Low-Rank Orthogonalization on the reconstructed matrix
+                        M = update
+                        r = min(group['low_rank_rank'], M.shape[0], M.shape[1])
+                        if r > 0:
+                            G_sketch = torch.randn(M.shape[1], r, device=M.device, dtype=M.dtype)
+                            MG = M @ G_sketch
+                            if MG.dtype != torch.float32:
+                                MG_dtype = M.dtype
+                                Q, _ = torch.linalg.qr(MG.float())
+                                Q = Q.to(MG_dtype)
+                            else:
+                                Q, _ = torch.linalg.qr(MG)
+                            projected_M = Q.T @ M
+                            ortho_projected_M = _newton_schulz_iteration(
+                                projected_M, steps=group['ns_steps'], eps=group['ns_eps'], coeffs=group['ns_coeffs']
+                            )
+                            update = Q @ ortho_projected_M
+                        else: # Fallback for invalid rank
+                            update = _newton_schulz_iteration(
+                                update, steps=group['ns_steps'], eps=group['ns_eps'], coeffs=group['ns_coeffs']
+                            )
+                    else:
+                        # Original full Newton-Schulz
+                        update = _newton_schulz_iteration(
+                            update,
+                            steps=group['ns_steps'],
+                            eps=group['ns_eps'],
+                            coeffs=group['ns_coeffs'],
+                        )
 
-
+
+                    if group['normuon_variant'] and 'normuon_v' in state:
+                        v_t = state['normuon_v']
+                        beta2_normuon = group['beta2_normuon']
+                        # Update 2nd moment estimate
+                        mean_squared_update = torch.mean(update.square(), dim=1)
+                        v_t.mul_(beta2_normuon).add_(mean_squared_update, alpha=1 - beta2_normuon)
+                        # Normalize update
+                        if group['normuon_atan2']:
+                            a = 1.2732395
+                            update.atan2_(v_t.sqrt().unsqueeze(1)).mul_(a)
+                        else:
+                            update.div_(v_t.sqrt().unsqueeze(1).add_(group['normuon_eps']))
+                        # Scale learning rate
+                        update_norm = torch.linalg.vector_norm(update)
+                        if update_norm > 1e-12:
+                            scaled_lr = group['normuon_lr_scale'] * group['lr'] * (p.numel()**0.5) / update_norm
+                        else:
+                            scaled_lr = 0.0
+                        update = update.view(p.shape).mul_(scaled_lr)
+                    else: # Original Muon learning rate application
+                        update = update.view(p.shape).mul_(group['lr'])
 
                     state['sign'] = _pack_bools(mt_buf > 0)
                     _nnmf(mt_buf.abs(), out=(state['mu_m_nmf'], state['mv_m_nmf']))
```
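As a reference for what the added NorMuon branch computes, here is a standalone sketch of neuron-wise (per-row) second-moment normalization followed by norm-matched learning-rate scaling. The function name and the example tensors are illustrative only; they are not part of the package.

```python
import torch

def normuon_scale(update: torch.Tensor, v: torch.Tensor, lr: float,
                  beta2: float = 0.95, eps: float = 1e-8,
                  lr_scale: float = 0.2) -> torch.Tensor:
    """Sketch of NorMuon post-processing of an orthogonalized 2D update.
    `v` is a persistent per-row (per-neuron) second-moment buffer."""
    # Per-row mean of squared entries -> EMA of the second moment.
    row_ms = update.square().mean(dim=1)
    v.mul_(beta2).add_(row_ms, alpha=1.0 - beta2)
    # Normalize each row by its RMS estimate.
    update = update / (v.sqrt().unsqueeze(1) + eps)
    # Rescale so the overall update norm is matched to lr_scale * lr * sqrt(numel).
    norm = torch.linalg.vector_norm(update)
    scaled_lr = lr_scale * lr * update.numel() ** 0.5 / norm if norm > 1e-12 else 0.0
    return update * scaled_lr

# Illustrative call: one 256x512 weight update with a fresh per-row buffer.
u = torch.randn(256, 512)
v = torch.zeros(256)
scaled = normuon_scale(u, v, lr=0.02)
```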
```diff
@@ -298,19 +384,80 @@ class Muon_adv(torch.optim.Optimizer):
                     if len(p.shape) > 2:
                         update = update.view(p.shape[0], -1)
 
-                    #
-
-
-
-
-
-
+                    # Orthogonalization step
+                    if group['low_rank_ortho']:
+                        # Low-Rank Orthogonalization based on Gaussian Sketching
+                        M = update
+                        r = min(group['ortho_rank'], M.shape[0], M.shape[1])
+
+                        if r > 0:
+                            # 1. Sketch the matrix
+                            G_sketch = torch.randn(M.shape[1], r, device=M.device, dtype=M.dtype)
+                            MG = M @ G_sketch
+
+                            # 2. QR decomposition to get orthogonal basis Q
+                            if MG.dtype != torch.float32:
+                                MG_dtype = M.dtype
+                                Q, _ = torch.linalg.qr(MG.float())
+                                Q = Q.to(MG_dtype)
+                            else:
+                                Q, _ = torch.linalg.qr(MG)
+
+                            # 3. Project M onto the basis
+                            projected_M = Q.T @ M
+
+                            # 4. Orthogonalize the smaller projected matrix
+                            ortho_projected_M = _newton_schulz_iteration(
+                                projected_M,
+                                steps=group['ns_steps'],
+                                eps=group['ns_eps'],
+                                coeffs=group['ns_coeffs'],
+                            )
+
+                            # 5. Project back to the original space
+                            update = Q @ ortho_projected_M
+                        else: # Fallback for invalid rank
+                            update = _newton_schulz_iteration(
+                                update,
+                                steps=group['ns_steps'],
+                                eps=group['ns_eps'],
+                                coeffs=group['ns_coeffs'],
+                            )
+                    else:
+                        # Original NewtonSchulz
+                        update = _newton_schulz_iteration(
+                            update,
+                            steps=group['ns_steps'],
+                            eps=group['ns_eps'],
+                            coeffs=group['ns_coeffs'],
+                        )
+
+                    # NorMuon Logic
+                    if group['normuon_variant'] and 'normuon_v' in state:
+                        v_t = state['normuon_v']
+                        beta2_normuon = group['beta2_normuon']
+                        # Update 2nd moment estimate
+                        mean_squared_update = torch.mean(update.square(), dim=1)
+                        v_t.mul_(beta2_normuon).add_(mean_squared_update, alpha=1 - beta2_normuon)
+                        # Normalize update
+                        if group['normuon_atan2']:
+                            a = 1.2732395
+                            update.atan2_(v_t.sqrt().unsqueeze(1)).mul_(a)
+                        else:
+                            update.div_(v_t.sqrt().unsqueeze(1).add_(group['normuon_eps']))
+                        # Scale learning rate
+                        update_norm = torch.linalg.vector_norm(update)
+                        if update_norm > 1e-12:
+                            scaled_lr = group['normuon_lr_scale'] * group['lr'] * (p.numel()**0.5) / update_norm
+                        else:
+                            scaled_lr = 0.0
+                        update.mul_(scaled_lr)
+                    else: # Original Muon learning rate application
+                        update.mul_(group['lr'])
 
                     # Reshape back to original if we flattened or reshaped
                     if len(p.shape) > 2 or state['reshaped_1d_muon']:
                         update = update.view(p.shape)
-
-                    update.mul_(group['lr'])
 
                 else: # Fallback to standard SGD with momentum for 1D params (biases, etc.) when not reshaped
                     # Momentum update
```
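The Gaussian-sketching path above can be read as the standalone sketch below: random sketch, QR for an orthonormal basis, Newton-Schulz on the small projected matrix, then projection back. The `newton_schulz` helper uses the commonly cited Muon quintic coefficients rather than this package's `_newton_schulz_iteration`/`ns_coeffs`, so treat it as an approximation of the idea, not the shipped implementation.

```python
import torch

def newton_schulz(G: torch.Tensor, steps: int = 5, eps: float = 1e-7) -> torch.Tensor:
    """Approximate orthogonalization of a 2D matrix via the quintic
    Newton-Schulz iteration popularized by Muon (assumed coefficients)."""
    a, b, c = 3.4445, -4.7750, 2.0315
    X = G / (G.norm() + eps)
    transposed = X.shape[0] > X.shape[1]
    if transposed:
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        X = a * X + (b * A + c * A @ A) @ X
    return X.T if transposed else X

def low_rank_orthogonalize(M: torch.Tensor, rank: int = 128) -> torch.Tensor:
    """Sketch of low-rank orthogonalization via Gaussian sketching."""
    r = min(rank, M.shape[0], M.shape[1])
    if r <= 0:
        return newton_schulz(M)                                # fallback for invalid rank
    G = torch.randn(M.shape[1], r, device=M.device, dtype=M.dtype)  # 1. Gaussian sketch
    Q, _ = torch.linalg.qr((M @ G).float())                    # 2. orthonormal basis, r columns
    Q = Q.to(M.dtype)
    projected = Q.T @ M                                        # 3. r x n projection
    ortho = newton_schulz(projected)                           # 4. orthogonalize the small matrix
    return Q @ ortho                                           # 5. project back to m x n
```

The design intent, as reflected in the comments of the diff, is to keep the expensive Newton-Schulz iterations on an r x n matrix instead of the full m x n update when r is much smaller than m.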
adv_optm/util/Kourkoutas.py (+21 -5)

```diff
@@ -86,9 +86,17 @@ class KourkoutasHelper:
         # These are just for the sample log, initialize them
         sun, pooled_grad_norm, prev_r_ema_val, r_ema_tensor = (torch.tensor(0.0),)*4
 
+        # The optimizer that owns this helper holds the master defaults for K-b.
+        # This is crucial in hybrid optimizers where some param_groups might not
+        # have all K-b keys populated, preventing KeyErrors.
+        master_defaults = self.optimizer.defaults
+
         for layer_key, info in self.layer_info.items():
             params, group = info['params'], info['group_ref']
 
+            if not group.get('kourkoutas_beta', False):
+                continue
+
             first_param_in_layer = info['params'][0]
             param_state = self.optimizer.state[first_param_in_layer]
 
```
```diff
@@ -100,6 +108,15 @@ class KourkoutasHelper:
             if 'kourkoutas_r_ema' not in param_state:
                 param_state['kourkoutas_r_ema'] = torch.tensor(0.0, device=first_param_in_layer.device, dtype=torch.float32)
 
+            # Use group-specific K-b settings, falling back to the optimizer's master defaults.
+            # This makes the helper robust against param groups that enable kourkoutas_beta
+            # but are missing the other required hyperparameters.
+            ema_alpha = group.get('ema_alpha', master_defaults['ema_alpha'])
+            beta2_max = group.get('betas', master_defaults['betas'])[1]
+            beta2_min = group.get('beta2_min', master_defaults['beta2_min'])
+            tiny_spike = group.get('tiny_spike', master_defaults['tiny_spike'])
+            k_warmup_steps = group.get('k_warmup_steps', master_defaults['k_warmup_steps'])
+
             r_ema_tensor = param_state['kourkoutas_r_ema']
             accumulator = self.layer_state[layer_key]['sum_sq_accumulator']
 
```
```diff
@@ -107,17 +124,16 @@ class KourkoutasHelper:
             prev_r_ema_val = r_ema_tensor.item() # for logging
 
             # Update the persistent EMA tensor in-place.
-            r_ema_tensor.mul_(
+            r_ema_tensor.mul_(ema_alpha).add_(pooled_grad_norm, alpha=1.0 - ema_alpha)
 
-            beta2_max = group['betas'][1]
             sun = torch.tensor(0.0, device=r_ema_tensor.device) # Default sun to 0 for warmup
 
-            if current_step <
+            if current_step < k_warmup_steps:
                 beta2 = beta2_max
             else:
-                raw = pooled_grad_norm / (r_ema_tensor +
+                raw = pooled_grad_norm / (r_ema_tensor + tiny_spike)
                 sun = raw / (1.0 + raw)
-                beta2 = beta2_max - (beta2_max -
+                beta2 = beta2_max - (beta2_max - beta2_min) * sun
 
             # Store the final calculated beta2 in the helper's transient state for this step.
             self.layer_state[layer_key]['dynamic_beta2'] = beta2.item() if isinstance(beta2, torch.Tensor) else beta2
```

(The removed lines above appear truncated because the diff viewer only rendered the changed portion of each line.)
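For reference, the dynamic beta2 that KourkoutasHelper derives from the pooled gradient norm can be summarized by this standalone sketch. The formula mirrors the diff; the function itself and the default values shown are illustrative, not the package's defaults.

```python
import torch

def dynamic_beta2(pooled_grad_norm: torch.Tensor, r_ema: torch.Tensor,
                  step: int, beta2_max: float = 0.999, beta2_min: float = 0.88,
                  ema_alpha: float = 0.95, tiny_spike: float = 1e-9,
                  k_warmup_steps: int = 0) -> float:
    """Kourkoutas-style 'sunspike' schedule: beta2 shrinks toward beta2_min
    when the current pooled gradient norm spikes above its running EMA."""
    # Update the persistent EMA of the pooled gradient norm in-place.
    r_ema.mul_(ema_alpha).add_(pooled_grad_norm, alpha=1.0 - ema_alpha)
    if step < k_warmup_steps:
        return beta2_max
    raw = pooled_grad_norm / (r_ema + tiny_spike)   # spike ratio vs. the EMA
    sun = raw / (1.0 + raw)                          # squashed into [0, 1)
    beta2 = beta2_max - (beta2_max - beta2_min) * sun
    return float(beta2)

# Illustrative call with a fresh EMA buffer.
r_ema = torch.tensor(0.0)
b2 = dynamic_beta2(torch.tensor(1.5), r_ema, step=100)
```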